tennisvision / tracker.py
Onur Çopur
first commit
3b90d9c
"""
Ball Tracking Module using Kalman Filter.
This module implements a Kalman filter-based tracker for smoothing
and predicting tennis ball positions across video frames.
"""
import numpy as np
from typing import Optional, Tuple, List
from filterpy.kalman import KalmanFilter
class BallTracker:
    """
    Kalman filter-based tracker for tennis ball position and velocity.

    The tracker maintains state estimates for:
    - Position (x, y), in the units of the measurements (pixels)
    - Velocity (vx, vy), in measurement units per second, because the
      constant-velocity transition advances position by ``velocity * dt``

    Attributes:
        dt (float): Time step between frames (1/fps), in seconds
        process_noise (float): Process noise variance used to scale ``Q``
        measurement_noise (float): Measurement noise variance placed on the
            diagonal of ``R``
        max_missing_frames (int): Maximum consecutive frames without a
            detection that are bridged by prediction before the tracker resets
        kf (KalmanFilter): Underlying filterpy Kalman filter
        initialized (bool): Whether a detection has seeded the filter state
        missing_frames (int): Length of the current run of frames without a
            detection
        trajectory (list): History of (x, y, vx, vy, frame_num) tuples
        frame_count (int): Number of calls to ``update`` so far
    """

    def __init__(
        self,
        dt: float = 1.0 / 30.0,
        process_noise: float = 0.1,
        measurement_noise: float = 10.0,
        max_missing_frames: int = 10
    ):
        """
        Initialize the ball tracker.

        Args:
            dt: Time step between frames (seconds)
            process_noise: Process noise variance (white-noise acceleration
                intensity) that scales the ``Q`` matrix
            measurement_noise: Measurement noise variance used on the
                diagonal of ``R``
            max_missing_frames: Max consecutive frames without detection that
                are predicted through; one more missing frame resets the tracker
        """
        self.dt = dt
        self.process_noise = process_noise
        self.measurement_noise = measurement_noise
        self.max_missing_frames = max_missing_frames

        # Initialize Kalman filter
        self.kf = self._create_kalman_filter()

        # Tracking state
        self.initialized = False
        self.missing_frames = 0
        # List of (x, y, vx, vy, frame_num); kept across resets.
        self.trajectory: List[Tuple[float, float, float, float, int]] = []
        self.frame_count = 0

    def _create_kalman_filter(self) -> "KalmanFilter":
        """
        Create and configure a Kalman filter for 2D position tracking.

        State vector: [x, y, vx, vy]
        Measurement vector: [x, y]

        Returns:
            Configured KalmanFilter instance
        """
        kf = KalmanFilter(dim_x=4, dim_z=2)

        # State transition matrix (constant velocity model):
        # position += velocity * dt, velocity unchanged.
        kf.F = np.array([
            [1, 0, self.dt, 0],
            [0, 1, 0, self.dt],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])

        # Measurement matrix (observe position only)
        kf.H = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 0]
        ])

        # Measurement noise covariance: measurement_noise is used as a variance.
        kf.R = np.eye(2) * self.measurement_noise

        # Process noise covariance: discrete white-noise acceleration model
        # with variance q, laid out for the interleaved [x, y, vx, vy] state.
        q = self.process_noise
        kf.Q = np.array([
            [q * self.dt**4 / 4, 0, q * self.dt**3 / 2, 0],
            [0, q * self.dt**4 / 4, 0, q * self.dt**3 / 2],
            [q * self.dt**3 / 2, 0, q * self.dt**2, 0],
            [0, q * self.dt**3 / 2, 0, q * self.dt**2]
        ])

        # Initial state covariance (large: initial state is uncertain)
        kf.P = np.eye(4) * 100

        return kf

    def _record_state(self) -> Tuple[float, float, float, float]:
        """Append the current filter state to the trajectory and return it."""
        x, y, vx, vy = (float(v) for v in self.kf.x)
        self.trajectory.append((x, y, vx, vy, self.frame_count))
        return (x, y, vx, vy)

    def update(
        self,
        measurement: Optional[Tuple[float, float]] = None
    ) -> Optional[Tuple[float, float, float, float]]:
        """
        Update tracker with a new measurement or predict if no detection.

        Args:
            measurement: Ball center position as (x, y), or None if not detected

        Returns:
            Estimated state as (x, y, vx, vy), or None if the tracker is not
            initialized or has just been reset after too many missing frames
        """
        self.frame_count += 1

        if measurement is not None:
            if not self.initialized:
                # Seed the state with the first detection; velocity is
                # unknown, so start it at zero.
                self.kf.x = np.array([
                    measurement[0],
                    measurement[1],
                    0.0,
                    0.0
                ])
                self.initialized = True
            else:
                # Regular predict/correct cycle.
                z = np.array([measurement[0], measurement[1]])
                self.kf.predict()
                self.kf.update(z)
            self.missing_frames = 0
            return self._record_state()

        # No detection and nothing to predict from yet.
        if not self.initialized:
            return None

        # No detection: coast on the constant-velocity prediction.
        self.kf.predict()
        self.missing_frames += 1

        # Give up once the gap exceeds max_missing_frames.
        if self.missing_frames > self.max_missing_frames:
            self.reset()
            return None

        return self._record_state()

    def reset(self):
        """Reset filter state to uninitialized; trajectory and frame_count are kept."""
        self.kf = self._create_kalman_filter()
        self.initialized = False
        self.missing_frames = 0

    def get_trajectory(self) -> List[Tuple[float, float, float, float, int]]:
        """
        Get the complete trajectory history.

        Returns:
            List of trajectory points as (x, y, vx, vy, frame_num). This is the
            tracker's internal list, not a copy.
        """
        return self.trajectory

    def get_speed(self, state: Tuple[float, float, float, float]) -> float:
        """
        Calculate speed from velocity components.

        Args:
            state: Tracker state as (x, y, vx, vy)

        Returns:
            Speed in pixels per second
        """
        _, _, vx, vy = state
        # vx, vy are already per-second: F advances position by v * dt, so
        # no further division by dt is needed (that would scale by the fps).
        return float(np.hypot(vx, vy))

    def get_last_n_positions(self, n: int = 20) -> List[Tuple[float, float]]:
        """
        Get the last N tracked positions for trail visualization.

        Args:
            n: Number of recent positions to return; values <= 0 yield []

        Returns:
            List of (x, y) coordinates, oldest first
        """
        # Guard explicitly: trajectory[-0:] would return the whole history.
        if n <= 0:
            return []
        return [(x, y) for x, y, _, _, _ in self.trajectory[-n:]]

    def is_initialized(self) -> bool:
        """Check if tracker has been initialized with a detection."""
        return self.initialized