"""
Ball Tracking Module using Kalman Filter.

This module implements a Kalman filter-based tracker for smoothing and
predicting tennis ball positions across video frames.
"""

import numpy as np
from typing import Optional, Tuple, List
from filterpy.kalman import KalmanFilter


class BallTracker:
    """
    Kalman filter-based tracker for tennis ball position and velocity.

    The tracker maintains a constant-velocity state estimate of:
    - Position (x, y), in the same units as the measurements (e.g. pixels)
    - Velocity (vx, vy), in measurement units per second (the transition
      matrix advances position by velocity * dt)

    Attributes:
        dt (float): Time step between frames in seconds (1/fps)
        process_noise (float): Acceleration noise variance; scales the
            process noise covariance Q
        measurement_noise (float): Measurement noise variance; the diagonal
            of the measurement noise covariance R
        max_missing_frames (int): Number of consecutive frames without a
            detection that are bridged by prediction; one more missing frame
            resets the tracker
        initialized (bool): True once a first detection has seeded the state
        missing_frames (int): Current run of consecutive frames without a
            detection
        trajectory (list): History of (x, y, vx, vy, frame_num) estimates;
            kept across resets
        frame_count (int): Number of update() calls so far (1-based frame
            index of the most recent call)
    """

    def __init__(
        self,
        dt: float = 1.0 / 30.0,
        process_noise: float = 0.1,
        measurement_noise: float = 10.0,
        max_missing_frames: int = 10
    ):
        """
        Initialize the ball tracker.

        Args:
            dt: Time step between frames (seconds)
            process_noise: Acceleration noise variance used to build Q
            measurement_noise: Measurement noise variance used to build R
            max_missing_frames: Max consecutive frames without detection that
                are predicted through before the tracker resets
        """
        self.dt = dt
        self.process_noise = process_noise
        self.measurement_noise = measurement_noise
        self.max_missing_frames = max_missing_frames

        # Kalman filter holding the current [x, y, vx, vy] estimate.
        self.kf = self._create_kalman_filter()

        # Tracking state
        self.initialized = False
        self.missing_frames = 0
        # Each entry: (x, y, vx, vy, frame_num)
        self.trajectory: List[Tuple[float, float, float, float, int]] = []
        self.frame_count = 0

    def _create_kalman_filter(self) -> KalmanFilter:
        """
        Create and configure a Kalman filter for 2D position tracking.

        State vector: [x, y, vx, vy]
        Measurement vector: [x, y]

        Returns:
            Configured KalmanFilter instance
        """
        dt = self.dt
        q = self.process_noise
        kf = KalmanFilter(dim_x=4, dim_z=2)

        # State transition matrix (constant velocity model): position is
        # advanced by velocity * dt, so velocity is in units per second.
        kf.F = np.array([
            [1.0, 0.0, dt, 0.0],
            [0.0, 1.0, 0.0, dt],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ])

        # Measurement matrix (observe position only)
        kf.H = np.array([
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
        ])

        # Measurement noise covariance: independent x/y noise of the given variance.
        kf.R = np.eye(2) * self.measurement_noise

        # Process noise covariance for a discrete white-noise acceleration
        # model: per axis Q = q * G G^T with G = [dt^2 / 2, dt]^T.
        q_pp = q * dt**4 / 4  # position-position
        q_pv = q * dt**3 / 2  # position-velocity
        q_vv = q * dt**2      # velocity-velocity
        kf.Q = np.array([
            [q_pp, 0.0, q_pv, 0.0],
            [0.0, q_pp, 0.0, q_pv],
            [q_pv, 0.0, q_vv, 0.0],
            [0.0, q_pv, 0.0, q_vv],
        ])

        # Initial state covariance: broad uncertainty, since velocity is
        # unknown at initialization.
        kf.P = np.eye(4) * 100.0

        return kf

    def _record_state(self) -> Tuple[float, float, float, float]:
        """Append the current filter state to the trajectory and return it as (x, y, vx, vy)."""
        # ravel() accepts both a 1-D state and filterpy's (4, 1) column vector.
        x, y, vx, vy = (float(v) for v in np.ravel(self.kf.x))
        self.trajectory.append((x, y, vx, vy, self.frame_count))
        return (x, y, vx, vy)

    def update(
        self,
        measurement: Optional[Tuple[float, float]] = None
    ) -> Optional[Tuple[float, float, float, float]]:
        """
        Update tracker with a new measurement or predict if no detection.

        Every call advances ``frame_count``. A detection either seeds the
        tracker (first detection: position from the measurement, zero
        velocity) or runs a predict/update cycle. A missing detection runs
        predict only. After more than ``max_missing_frames`` consecutive
        misses the tracker is reset.

        Args:
            measurement: Ball center position as (x, y), or None if not detected

        Returns:
            Estimated state as (x, y, vx, vy), or None if the tracker is not
            initialized or was reset on this call
        """
        self.frame_count += 1

        if measurement is None:
            if not self.initialized:
                return None
            # No detection: coast on the motion model.
            self.kf.predict()
            self.missing_frames += 1
            if self.missing_frames > self.max_missing_frames:
                self.reset()
                return None
            return self._record_state()

        if not self.initialized:
            # Seed the state with the first detection and zero velocity.
            self.kf.x = np.array([measurement[0], measurement[1], 0.0, 0.0])
            self.initialized = True
        else:
            self.kf.predict()
            self.kf.update(np.array([measurement[0], measurement[1]]))
        self.missing_frames = 0
        return self._record_state()

    def reset(self):
        """Reset the filter to an uninitialized state; trajectory and frame_count are kept."""
        self.kf = self._create_kalman_filter()
        self.initialized = False
        self.missing_frames = 0

    def get_trajectory(self) -> List[Tuple[float, float, float, float, int]]:
        """
        Get the complete trajectory history.

        Returns:
            A copy of the list of trajectory points as
            (x, y, vx, vy, frame_num); mutating it does not affect the tracker
        """
        return list(self.trajectory)

    def get_speed(self, state: Tuple[float, float, float, float]) -> float:
        """
        Calculate speed from velocity components.

        The velocity components are already expressed per second, because
        the transition matrix advances position by velocity * dt. No further
        scaling by dt is applied.

        Args:
            state: Tracker state as (x, y, vx, vy)

        Returns:
            Speed in measurement units (e.g. pixels) per second
        """
        _, _, vx, vy = state
        return float(np.hypot(vx, vy))

    def get_last_n_positions(self, n: int = 20) -> List[Tuple[float, float]]:
        """
        Get the last N tracked positions for trail visualization.

        Args:
            n: Number of recent positions to return; values <= 0 yield []

        Returns:
            List of (x, y) coordinates, oldest first
        """
        # Guard: trajectory[-0:] would return the entire history, and a
        # negative n would slice from the front.
        if n <= 0:
            return []
        return [(x, y) for x, y, _, _, _ in self.trajectory[-n:]]

    def is_initialized(self) -> bool:
        """Check if tracker has been initialized with a detection."""
        return self.initialized