Onur Çopur committed on
Commit 3b90d9c · Parent: 8353beb

first commit

Files changed (12):
  1. .gitignore +67 -0
  2. CLAUDE.md +88 -0
  3. README.md +280 -7
  4. __init__.py +14 -0
  5. app.py +322 -4
  6. detector.py +155 -0
  7. packages.txt +7 -0
  8. requirements.txt +9 -0
  9. tracker.py +210 -0
  10. utils/__init__.py +32 -0
  11. utils/io_utils.py +287 -0
  12. utils/visualization.py +315 -0
.gitignore ADDED
@@ -0,0 +1,67 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual environments
+ venv/
+ env/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+ .DS_Store
+
+ # Output files
+ output/
+ *.mp4
+ *.avi
+ *.mov
+ *.csv
+ *.png
+ !example_videos/*.mp4
+
+ # Model weights (downloaded at runtime)
+ *.pt
+ *.pth
+ *.onnx
+
+ # Jupyter
+ .ipynb_checkpoints/
+ *.ipynb
+
+ # Gradio
+ gradio_cached_examples/
+ flagged/
+
+ # Logs
+ *.log
+ logs/
+
+ # Temporary files
+ tmp/
+ temp/
+ *.tmp
CLAUDE.md ADDED
@@ -0,0 +1,88 @@
+ # CLAUDE.md
+
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+ ## Running the Application
+
+ Start the Gradio web interface:
+ ```bash
+ python app.py
+ ```
+
+ The app launches at `http://localhost:7860` by default (port 7860, bound to `0.0.0.0`).
+
+ ## Architecture Overview
+
+ This is a Gradio-based web application for tennis ball tracking using computer vision. The processing pipeline has three main stages:
+
+ 1. **Detection** ([detector.py](detector.py)): YOLOv8 detects tennis balls (COCO class 32) in each frame
+ 2. **Tracking** ([tracker.py](tracker.py)): a Kalman filter smooths trajectories and predicts positions during occlusion
+ 3. **Visualization** ([utils/visualization.py](utils/visualization.py)): overlays trajectory trails, bounding boxes, and speed labels, and generates plots
+
+ ### Key Design Patterns
+
+ **Processing Flow** ([app.py](app.py:30-188)):
+ - `process_video()` orchestrates the full pipeline
+ - Uses context managers (`VideoReader`, `VideoWriter`) for safe I/O
+ - Frame-by-frame processing: detect → update tracker → render overlays → write frame
+ - Trajectory data is accumulated in the tracker and exported to CSV at the end
+
+ **State Management** ([tracker.py](tracker.py:13-211)):
+ - `BallTracker` maintains the Kalman filter state vector `[x, y, vx, vy]`
+ - Handles initialization on the first detection
+ - Predicts the ball position when detection is lost (for up to `max_missing_frames` frames)
+ - Resets the tracker if the ball has been missing too long
+
+ **Detection Selection** ([app.py](app.py:104-115)):
+ - When multiple detections occur, the highest-confidence one is used
+ - A minimum box size filter (5x5 pixels) in [detector.py](detector.py:107) discards noise
+
+ ## Module Dependencies
+
+ ```
+ app.py (main entry point)
+ ├── detector.py (BallDetector class)
+ ├── tracker.py (BallTracker class)
+ └── utils/
+     ├── io_utils.py (VideoReader, VideoWriter, CSV export)
+     └── visualization.py (drawing functions, plotting)
+ ```
+
+ All utilities are imported via `utils/__init__.py`, which re-exports from the submodules.
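+
+ For example, [app.py](app.py) pulls everything it needs through the package root (a snippet of its actual imports):
+
+ ```python
+ from utils import (
+     VideoReader,
+     VideoWriter,
+     export_trajectory_csv,
+     draw_detection,
+     create_trajectory_plot,
+ )
+ ```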
+
+ ## Configuration Parameters
+
+ **Detector** ([detector.py](detector.py:24-49)):
+ - `model_name`: `'yolov8n'` (fastest), `'yolov8s'`, `'yolov8m'` (most accurate)
+ - `confidence_threshold`: 0.1-0.9 (lower = more sensitive)
+ - `device`: auto-detected (`'cuda'` if available, else `'cpu'`)
+
+ **Tracker** ([tracker.py](tracker.py:28-47)):
+ - `dt`: time step, calculated as `1.0 / fps`
+ - `max_missing_frames`: typically `int(fps * 0.5)` (half-second tolerance)
+ - `process_noise`: 0.1 (Kalman filter Q matrix)
+ - `measurement_noise`: 10.0 (Kalman filter R matrix)
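+
+ For a 30 FPS clip, the values [app.py](app.py) actually passes work out to:
+
+ ```python
+ detector = BallDetector(model_name="yolov8n", confidence_threshold=0.3)
+ tracker = BallTracker(dt=1.0 / 30.0, max_missing_frames=int(30 * 0.5))
+ ```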
+
+ ## Output Files
+
+ All outputs are saved to the `output/` directory:
+ - `tracked_video.mp4`: video with overlays (bounding boxes, trails, speed labels, info panel)
+ - `trajectory.csv`: frame-by-frame data with columns: frame, timestamp, x/y position, velocity, speed
+ - `trajectory_plot.png`: 2D plot with a color-coded speed gradient (blue = slow, red = fast)
+
+ ## Speed Estimation
+
+ Speed is calculated from the Kalman filter velocity: `speed = sqrt(vx² + vy²) / dt`
+
+ **Units**: Speed is in pixels/second. The "km/h" label uses a rough approximation (`speed * 0.01`); a real-world conversion requires camera calibration.
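+
+ A minimal calibration sketch, assuming the court length visible in the frame has been measured (the 23.77 m baseline-to-baseline distance is standard; `court_length_px` below is a hypothetical measurement, not something the code computes):
+
+ ```python
+ COURT_LENGTH_M = 23.77    # regulation court, baseline to baseline
+ court_length_px = 900.0   # measured in the frame (assumption)
+ m_per_px = COURT_LENGTH_M / court_length_px
+
+ def px_per_sec_to_kmh(speed_px_per_sec: float) -> float:
+     """Convert pixel speed to km/h (1 m/s = 3.6 km/h)."""
+     return speed_px_per_sec * m_per_px * 3.6
+ ```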
+
+ ## Deployment Notes
+
+ **Hugging Face Spaces**:
+ - Use the YOLOv8n model on the free tier (limited GPU)
+ - The app automatically creates the `output/` directory on startup
+ - The Gradio interface is configured with `share=False` by default
+
+ **Model Downloads**:
+ - YOLOv8 weights are downloaded automatically by Ultralytics on first run
+ - Models are cached in the `~/.cache/torch/hub/` directory
README.md CHANGED
@@ -1,13 +1,286 @@
  ---
- title: Tennisvision
- emoji: 📚
- colorFrom: yellow
- colorTo: purple
  sdk: gradio
- sdk_version: 5.49.1
  app_file: app.py
  pinned: false
- license: apache-2.0
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # 🎾 TennisVision – AI Ball Tracker
+
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
+ [![Gradio](https://img.shields.io/badge/gradio-4.0+-orange.svg)](https://gradio.app/)
+ [![YOLOv8](https://img.shields.io/badge/YOLOv8-Ultralytics-00ADD8.svg)](https://github.com/ultralytics/ultralytics)
+
+ A comprehensive AI-powered tennis ball tracking demo that detects and tracks tennis balls in video clips using state-of-the-art computer vision models.
+
+ ## 🌟 Features
+
+ - **Real-time Ball Detection**: Uses YOLOv8 models for accurate tennis ball detection
+ - **Kalman Filter Tracking**: Smooth trajectory tracking with velocity estimation
+ - **Speed Visualization**: Real-time speed estimation and overlay
+ - **2D Trajectory Plot**: Color-coded trajectory visualization based on ball speed
+ - **Data Export**: Download processed videos and trajectory data in CSV format
+ - **Interactive UI**: User-friendly Gradio interface with adjustable parameters
+
+ ## 🎯 Demo Capabilities
+
+ 1. ✅ Accept short tennis video clips (user-uploaded)
+ 2. ✅ Detect and track tennis balls frame-by-frame using YOLOv8
+ 3. ✅ Visualize ball trajectory and speed as video overlays
+ 4. ✅ Display a 2D trajectory plot (X vs Y) color-coded by speed
+ 5. ✅ Provide downloadable outputs:
+    - Processed video with trajectory overlays
+    - CSV file with timestamped coordinates and speed estimates
+    - Trajectory plot image
+
+ ## 🏗️ Project Structure
+
+ ```
+ tennisvision/
+ ├── app.py                # Gradio application entry point
+ ├── detector.py           # YOLOv8 detection model wrapper
+ ├── tracker.py            # Kalman filter-based ball tracker
+ ├── utils/
+ │   ├── __init__.py       # Utility package initialization
+ │   ├── visualization.py  # Overlay rendering and plotting
+ │   └── io_utils.py       # Video I/O and CSV export
+ ├── example_videos/       # Sample tennis videos (user-provided)
+ ├── output/               # Generated output files
+ ├── requirements.txt      # Python dependencies
+ └── README.md             # This file
+ ```
+
+ ## 🚀 Quick Start
+
+ ### Local Installation
+
+ 1. **Clone or download this repository**
+
+ ```bash
+ git clone <repository-url>
+ cd tennisvision
+ ```
+
+ 2. **Install dependencies**
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 3. **Run the application**
+
+ ```bash
+ python app.py
+ ```
+
+ 4. **Open your browser** at `http://localhost:7860`
+
+ ### Usage
+
+ 1. **Upload a tennis video** (MP4, AVI, or MOV format)
+ 2. **Select a detection model**:
+    - `yolov8n` - Fastest (recommended for Hugging Face Spaces)
+    - `yolov8s` - Balanced speed/accuracy
+    - `yolov8m` - Most accurate
+ 3. **Adjust the confidence threshold** (0.1 - 0.9):
+    - Lower values detect more balls but may include false positives
+    - Higher values are more conservative
+ 4. **Click "Run Tracking"** and wait for processing
+ 5. **View results** in the tabs:
+    - Processed video with overlays
+    - 2D trajectory plot
+    - Downloadable CSV and video files
+
+ ## 🤗 Hugging Face Spaces Deployment
+
+ ### Method 1: Using the Web Interface
+
+ 1. Create a new Space on [Hugging Face](https://huggingface.co/spaces)
+ 2. Select **Gradio** as the SDK
+ 3. Upload all files from the `tennisvision/` directory
+ 4. The Space will automatically build and deploy
+
+ ### Method 2: Using Git
+
+ 1. Create a new Space and clone it:
+
+ ```bash
+ git clone https://huggingface.co/spaces/<your-username>/<space-name>
+ cd <space-name>
+ ```
+
+ 2. Copy all files from `tennisvision/`:
+
+ ```bash
+ cp -r path/to/tennisvision/* .
+ ```
+
+ 3. Commit and push:
+
+ ```bash
+ git add .
+ git commit -m "Initial commit: TennisVision ball tracker"
+ git push
+ ```
+
+ ### Configuration Files for Hugging Face
+
+ Create a `README.md` in the Space root with:
+
+ ```yaml
  ---
+ title: TennisVision - AI Ball Tracker
+ emoji: 🎾
+ colorFrom: green
+ colorTo: blue
  sdk: gradio
+ sdk_version: 4.0.0
  app_file: app.py
  pinned: false
  ---
+ ```
+
+ ## 📊 Output Files
+
+ ### 1. Processed Video (`tracked_video.mp4`)
+ - Original video with overlays
+ - Ball bounding boxes
+ - Trajectory trail (last 20 positions)
+ - Speed labels
+ - Info panel with frame count and timestamp
+
+ ### 2. Trajectory CSV (`trajectory.csv`)
+ Columns:
+ - `frame`: Frame number
+ - `timestamp_sec`: Time in seconds
+ - `x_pixels`, `y_pixels`: Ball center coordinates
+ - `velocity_x_px_per_sec`, `velocity_y_px_per_sec`: Velocity components
+ - `speed_px_per_sec`: Instantaneous speed
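+
+ A quick way to inspect the export, as a sketch (it assumes only the column names above):
+
+ ```python
+ import csv
+
+ with open("output/trajectory.csv") as f:
+     rows = list(csv.DictReader(f))
+
+ # Peak speed over the clip, in pixels/second
+ peak = max(float(r["speed_px_per_sec"]) for r in rows)
+ print(f"{len(rows)} points, peak speed {peak:.1f} px/s")
+ ```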
+
+ ### 3. Trajectory Plot (`trajectory_plot.png`)
+ - 2D visualization of the ball path
+ - Color gradient representing speed (blue = slow, red = fast)
+ - Start and end markers
+
+ ## 🛠️ Technical Details
+
+ ### Detection Models
+
+ **YOLOv8** (You Only Look Once v8) from Ultralytics:
+ - Pre-trained on the COCO dataset
+ - Detects sports balls (class 32)
+ - Variants: `yolov8n` (nano), `yolov8s` (small), `yolov8m` (medium)
+
+ ### Tracking Algorithm
+
+ **Kalman Filter**:
+ - State vector: `[x, y, vx, vy]` (position and velocity)
+ - Constant velocity motion model (see the sketch below)
+ - Predicts ball position when detection is lost
+ - Smooths noisy detections
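+
+ Concretely, the constant-velocity model used in [tracker.py](tracker.py), with `dt = 1 / fps`:
+
+ ```python
+ import numpy as np
+
+ dt = 1.0 / 30.0  # example: a 30 FPS clip
+
+ # State transition: x += vx*dt, y += vy*dt; velocities carry over unchanged.
+ F = np.array([
+     [1, 0, dt,  0],
+     [0, 1,  0, dt],
+     [0, 0,  1,  0],
+     [0, 0,  0,  1],
+ ])
+
+ # Measurement model: only the (x, y) position is observed.
+ H = np.array([
+     [1, 0, 0, 0],
+     [0, 1, 0, 0],
+ ])
+ ```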
+
+ ### Speed Estimation
+
+ ```
+ speed = sqrt(vx² + vy²) / dt
+ ```
+ where `dt = 1 / fps`
+
+ *Note: Speed is in pixels/sec. Real-world conversion requires camera calibration.*
+
+ ## 🎨 Visualization Features
+
+ - **Bounding Boxes**: Green boxes around detected balls
+ - **Trajectory Trail**: Fading trail showing recent positions
+ - **Speed Label**: Real-time speed estimate (km/h approximation)
+ - **Info Panel**: Frame number, timestamp, detection confidence
+ - **2D Plot**: Complete trajectory with color-coded speed
+
+ ## ⚠️ Limitations
+
+ - Requires a visible ball in the video
+ - Works best with clear, high-resolution footage
+ - Speed estimates are in pixels (not calibrated to real-world units)
+ - Processing time scales with video length and model size
+ - Free Hugging Face Spaces have limited GPU resources (use YOLOv8n)
+
+ ## 🔧 Configuration
+
+ ### Detector Parameters
+
+ ```python
+ detector = BallDetector(
+     model_name="yolov8n",      # Model variant
+     confidence_threshold=0.3,  # Min confidence score
+     device="cuda"              # or "cpu"
+ )
+ ```
+
+ ### Tracker Parameters
+
+ ```python
+ tracker = BallTracker(
+     dt=1.0/30.0,             # Time step (1/fps)
+     process_noise=0.1,       # Process noise std
+     measurement_noise=10.0,  # Measurement noise std
+     max_missing_frames=10    # Max frames without detection
+ )
+ ```
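+
+ A minimal frame loop wiring the two together, condensed from the actual pipeline in [app.py](app.py) (the video path is just an example):
+
+ ```python
+ import cv2
+
+ cap = cv2.VideoCapture("example_videos/rally.mp4")
+ while True:
+     ok, frame = cap.read()
+     if not ok:
+         break
+     detections = detector.detect(frame)   # sorted by confidence, highest first
+     if detections:
+         cx, cy = detector.get_ball_center(detections[0])
+         state = tracker.update((cx, cy))  # correct with the measurement
+     else:
+         state = tracker.update(None)      # predict through the gap
+ cap.release()
+ ```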
+
+ ## 🧪 Example Videos
+
+ Place sample videos in `example_videos/`:
+ - `serve.mp4` - Tennis serve motion
+ - `rally.mp4` - Rally with multiple ball trajectories
+
+ *Note: Sample videos are not included in the repository. Use your own tennis footage.*
+
+ ## 🐛 Troubleshooting
+
+ **No ball detected:**
+ - Lower the confidence threshold
+ - Ensure the ball is clearly visible
+ - Try a different model (`yolov8s` or `yolov8m`)
+
+ **Slow processing:**
+ - Use the `yolov8n` model
+ - Process shorter clips
+ - Use a GPU if available
+
+ **Poor tracking accuracy:**
+ - Increase the confidence threshold
+ - Adjust the Kalman filter parameters
+ - Use higher-resolution video
+
+ ## 📚 Dependencies
+
+ - **torch** >= 2.0.0 - Deep learning framework
+ - **ultralytics** >= 8.0.0 - YOLOv8 implementation
+ - **opencv-python-headless** == 4.8.1.78 - Video processing
+ - **gradio** >= 4.0.0 - Web interface
+ - **matplotlib** >= 3.7.0 - Plotting
+ - **filterpy** >= 1.4.5 - Kalman filter
+ - **numpy** >= 1.24.0 - Numerical operations
+
+ ## 🙏 Acknowledgments
+
+ - [Ultralytics YOLOv8](https://github.com/ultralytics/ultralytics) - Object detection
+ - [Gradio](https://gradio.app/) - Web interface framework
+ - [FilterPy](https://github.com/rlabbe/filterpy) - Kalman filter implementation
+
+ ## 📄 License
+
+ This project is open source and available for educational and research purposes.
+
+ ## 🤝 Contributing
+
+ Contributions are welcome! Areas for improvement:
+ - [ ] Add RT-DETR model support
+ - [ ] Implement bounce detection
+ - [ ] Add FPS benchmark display
+ - [ ] Camera calibration for real-world speed
+ - [ ] Multi-ball tracking
+ - [ ] Player detection and tracking
+
+ ## 📧 Contact
+
+ For questions or issues, please open an issue on the repository.
+
+ ---
+
+ **Built with ❤️ for the tennis and computer vision community**
+
+ 🎾 Enjoy tracking! 🚀
__init__.py ADDED
@@ -0,0 +1,14 @@
+ """
+ TennisVision - AI Ball Tracker
+
+ A comprehensive tennis ball tracking system using YOLOv8 detection
+ and Kalman filter-based tracking.
+ """
+
+ __version__ = "1.0.0"
+ __author__ = "TennisVision Team"
+
+ from .detector import BallDetector
+ from .tracker import BallTracker
+
+ __all__ = ['BallDetector', 'BallTracker']
app.py CHANGED
@@ -1,7 +1,325 @@
+ """
+ TennisVision - AI Ball Tracker
+ Gradio application for tennis ball detection and tracking.
+ """
+
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import cv2
+ import numpy as np
+ import os
+ import tempfile
+ from pathlib import Path
+ from typing import Tuple, Optional
+
+ from detector import BallDetector
+ from tracker import BallTracker
+ from utils import (
+     VideoReader,
+     VideoWriter,
+     export_trajectory_csv,
+     validate_video_file,
+     create_output_directory,
+     draw_detection,
+     draw_trajectory_trail,
+     draw_speed_label,
+     draw_info_panel,
+     create_trajectory_plot
+ )
+
+
+ def process_video(
+     video_path: str,
+     model_name: str,
+     confidence_threshold: float,
+     progress=gr.Progress()
+ ) -> Tuple[Optional[str], Optional[str], Optional[str], str]:
+     """
+     Process a video to track the tennis ball.
+
+     Args:
+         video_path: Path to input video file
+         model_name: Detection model identifier
+         confidence_threshold: Minimum detection confidence
+         progress: Gradio progress tracker
+
+     Returns:
+         Tuple of (output_video_path, csv_path, plot_path, status_message)
+     """
+     try:
+         # Validate input video
+         is_valid, msg = validate_video_file(video_path)
+         if not is_valid:
+             return None, None, None, f"❌ Error: {msg}"
+
+         progress(0, desc="Initializing models...")
+
+         # Initialize detector and tracker
+         detector = BallDetector(
+             model_name=model_name,
+             confidence_threshold=confidence_threshold
+         )
+
+         # Read video properties
+         with VideoReader(video_path) as reader:
+             video_props = reader.get_properties()
+
+         fps = video_props['fps']
+         frame_count = video_props['frame_count']
+         width = video_props['width']
+         height = video_props['height']
+
+         # Initialize tracker
+         tracker = BallTracker(dt=1.0 / fps, max_missing_frames=int(fps * 0.5))
+
+         # Create temporary output files
+         output_dir = create_output_directory("output")
+         temp_video = tempfile.NamedTemporaryFile(
+             delete=False, suffix='.mp4', dir=output_dir
+         )
+         output_video_path = temp_video.name
+         temp_video.close()
+
+         csv_path = output_dir / "trajectory.csv"
+         plot_path = output_dir / "trajectory_plot.png"
+
+         progress(0.1, desc="Processing frames...")
+
+         # Process video
+         detection_count = 0
+         with VideoReader(video_path) as reader, \
+              VideoWriter(output_video_path, fps, width, height) as writer:
+
+             for frame_num, frame in reader.read_frames():
+                 # Update progress
+                 progress_pct = 0.1 + 0.7 * (frame_num / frame_count)
+                 progress(
+                     progress_pct,
+                     desc=f"Processing frame {frame_num + 1}/{frame_count}"
+                 )
+
+                 # Detect ball
+                 detections = detector.detect(frame)
+
+                 # Update tracker
+                 if len(detections) > 0:
+                     # Use highest confidence detection
+                     best_detection = detections[0]
+                     cx, cy = detector.get_ball_center(best_detection)
+                     state = tracker.update((cx, cy))
+                     detection_count += 1
+
+                     # Draw detection box
+                     frame = draw_detection(frame, best_detection)
+                 else:
+                     # Predict without detection
+                     state = tracker.update(None)
+
+                 # Draw trajectory and info if tracker is active
+                 if state is not None and tracker.is_initialized():
+                     x, y, vx, vy = state
+
+                     # Draw trajectory trail
+                     positions = tracker.get_last_n_positions(20)
+                     frame = draw_trajectory_trail(frame, positions)
+
+                     # Calculate and draw speed
+                     speed = tracker.get_speed(state)
+                     frame = draw_speed_label(frame, (x, y), speed, fps)
+
+                 # Draw info panel
+                 conf = detections[0][4] if len(detections) > 0 else None
+                 frame = draw_info_panel(frame, frame_num + 1, frame_count, fps, conf)
+
+                 # Write frame
+                 writer.write_frame(frame)
+
+         # Export trajectory data
+         progress(0.8, desc="Exporting trajectory data...")
+         trajectory = tracker.get_trajectory()
+
+         if len(trajectory) == 0:
+             return None, None, None, "❌ No ball detected in video. Try lowering the confidence threshold."
+
+         # Export CSV
+         export_success = export_trajectory_csv(trajectory, fps, str(csv_path))
+         if not export_success:
+             csv_path = None
+
+         # Create trajectory plot
+         progress(0.9, desc="Creating trajectory plot...")
+         try:
+             create_trajectory_plot(trajectory, fps, str(plot_path))
+         except Exception as e:
+             print(f"Failed to create plot: {e}")
+             plot_path = None
+
+         progress(1.0, desc="Complete!")
+
+         # Generate status message
+         status = f"""✅ **Processing Complete!**
+
+ **Video Info:**
+ - Total Frames: {frame_count}
+ - Frame Rate: {fps:.1f} FPS
+ - Resolution: {width}x{height}
+
+ **Tracking Results:**
+ - Ball Detected: {detection_count} frames ({100 * detection_count / frame_count:.1f}%)
+ - Trajectory Points: {len(trajectory)}
+
+ **Outputs:**
+ - Processed video with overlays
+ - Trajectory CSV with {len(trajectory)} data points
+ - 2D trajectory plot color-coded by speed
+ """
+
+         return (
+             output_video_path,
+             str(csv_path) if csv_path else None,
+             str(plot_path) if plot_path else None,
+             status
+         )
+
+     except Exception as e:
+         error_msg = f"❌ **Error during processing:** {str(e)}"
+         print(error_msg)
+         import traceback
+         traceback.print_exc()
+         return None, None, None, error_msg
+
+
+ # Create Gradio interface
+ def create_interface():
+     """Create and configure the Gradio interface."""
+
+     with gr.Blocks(
+         title="TennisVision - AI Ball Tracker",
+         theme=gr.themes.Soft()
+     ) as app:
+         gr.Markdown(
+             """
+             # 🎾 TennisVision - AI Ball Tracker
+
+             Upload a tennis video to automatically detect and track the ball using
+             state-of-the-art computer vision models.
+
+             **Features:**
+             - Real-time ball detection with YOLOv8
+             - Smooth trajectory tracking with Kalman filter
+             - Speed estimation and visualization
+             - Downloadable outputs (video, CSV, plot)
+             """
+         )
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 gr.Markdown("### ⚙️ Input & Settings")
+
+                 video_input = gr.Video(
+                     label="Upload Tennis Video",
+                     sources=["upload"]
+                 )
+
+                 model_dropdown = gr.Dropdown(
+                     choices=["yolov8n", "yolov8s", "yolov8m"],
+                     value="yolov8n",
+                     label="Detection Model",
+                     info="yolov8n is fastest, yolov8m is most accurate"
+                 )
+
+                 confidence_slider = gr.Slider(
+                     minimum=0.1,
+                     maximum=0.9,
+                     value=0.3,
+                     step=0.05,
+                     label="Confidence Threshold",
+                     info="Lower = more detections (may include false positives)"
+                 )
+
+                 process_btn = gr.Button(
+                     "🚀 Run Tracking",
+                     variant="primary",
+                     size="lg"
+                 )
+
+                 gr.Markdown(
+                     """
+                     ### 💡 Tips
+                     - Use short clips (5-15 seconds) for faster processing
+                     - Ensure the ball is visible and in motion
+                     - Lower the confidence threshold if the ball is not detected
+                     - YOLOv8n provides the fastest inference (~30 FPS)
+                     """
+                 )
+
+             with gr.Column(scale=2):
+                 gr.Markdown("### 📊 Results")
+
+                 status_output = gr.Markdown(
+                     "Upload a video and click **Run Tracking** to begin."
+                 )
+
+                 with gr.Tabs():
+                     with gr.Tab("📹 Processed Video"):
+                         video_output = gr.Video(
+                             label="Tracked Video",
+                             show_label=False
+                         )
+
+                     with gr.Tab("📈 Trajectory Plot"):
+                         plot_output = gr.Image(
+                             label="2D Trajectory",
+                             show_label=False
+                         )
+
+                     with gr.Tab("📥 Downloads"):
+                         gr.Markdown("### Download Files")
+                         csv_output = gr.File(
+                             label="Trajectory Data (CSV)"
+                         )
+                         video_download = gr.File(
+                             label="Processed Video (MP4)"
+                         )
+
+         # Event handlers
+         process_btn.click(
+             fn=process_video,
+             inputs=[video_input, model_dropdown, confidence_slider],
+             outputs=[video_output, csv_output, plot_output, status_output]
+         ).then(
+             fn=lambda x: x,
+             inputs=[video_output],
+             outputs=[video_download]
+         )
+
+         gr.Markdown(
+             """
+             ---
+             ### 📚 About
+
+             **TennisVision** uses YOLOv8 for ball detection and Kalman filtering
+             for smooth trajectory tracking. The system estimates ball speed and
+             visualizes the complete trajectory with color-coded speed indicators.
+
+             **Model:** YOLOv8 (Ultralytics)
+             **Tracking:** Kalman Filter
+             **Framework:** Gradio + OpenCV
+
+             Built for deployment on Hugging Face Spaces 🤗
+             """
+         )
+
+     return app
+
+
+ if __name__ == "__main__":
+     # Create output directory
+     create_output_directory("output")
+
+     # Launch app
+     app = create_interface()
+     app.launch(
+         share=False,
+         server_name="0.0.0.0",
+         server_port=7860
+     )
detector.py ADDED
@@ -0,0 +1,155 @@
+ """
+ Ball Detection Module using YOLO and RT-DETR models.
+
+ This module provides a unified interface for detecting tennis balls
+ in video frames using state-of-the-art object detection models.
+ """
+
+ import torch
+ import numpy as np
+ from typing import List, Tuple, Optional
+ from ultralytics import YOLO
+
+
+ class BallDetector:
+     """
+     Wrapper class for ball detection using YOLOv8 or RT-DETR.
+
+     Attributes:
+         model_name (str): Name of the detection model ('yolov8n', 'yolov8s', etc.)
+         confidence_threshold (float): Minimum confidence score for detections
+         device (str): Device to run inference on ('cuda' or 'cpu')
+     """
+
+     def __init__(
+         self,
+         model_name: str = "yolov8n",
+         confidence_threshold: float = 0.3,
+         device: Optional[str] = None
+     ):
+         """
+         Initialize the ball detector.
+
+         Args:
+             model_name: Model identifier (e.g., 'yolov8n', 'yolov8s')
+             confidence_threshold: Minimum confidence for valid detections
+             device: Compute device ('cuda', 'cpu', or None for auto-detect)
+         """
+         self.model_name = model_name
+         self.confidence_threshold = confidence_threshold
+
+         # Auto-detect device if not specified
+         if device is None:
+             self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         else:
+             self.device = device
+
+         # Load model
+         self.model = self._load_model()
+
+     def _load_model(self) -> YOLO:
+         """
+         Load the specified detection model.
+
+         Returns:
+             Loaded YOLO model instance
+
+         Raises:
+             ValueError: If the model name is not supported
+         """
+         try:
+             if self.model_name.startswith('yolov8'):
+                 # Load YOLOv8 model from Ultralytics
+                 model = YOLO(f'{self.model_name}.pt')
+                 model.to(self.device)
+                 return model
+             else:
+                 raise ValueError(f"Unsupported model: {self.model_name}")
+         except Exception as e:
+             raise RuntimeError(f"Failed to load model {self.model_name}: {str(e)}")
+
+     def detect(self, frame: np.ndarray) -> List[Tuple[int, int, int, int, float]]:
+         """
+         Detect tennis balls in a single frame.
+
+         Args:
+             frame: Input frame as numpy array (H, W, 3) in BGR format
+
+         Returns:
+             List of detections, each as (x1, y1, x2, y2, confidence),
+             where (x1, y1) is top-left and (x2, y2) is bottom-right
+         """
+         try:
+             # Run inference
+             results = self.model.predict(
+                 frame,
+                 conf=self.confidence_threshold,
+                 device=self.device,
+                 verbose=False,
+                 classes=[32]  # Sports ball class in COCO dataset
+             )
+
+             detections = []
+
+             # Parse results
+             if len(results) > 0 and results[0].boxes is not None:
+                 boxes = results[0].boxes
+
+                 for box in boxes:
+                     # Extract bounding box coordinates
+                     x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
+                     confidence = float(box.conf[0].cpu().numpy())
+
+                     # Filter small detections (likely noise)
+                     width = x2 - x1
+                     height = y2 - y1
+
+                     if width > 5 and height > 5:  # Minimum size threshold
+                         detections.append((
+                             int(x1), int(y1), int(x2), int(y2), confidence
+                         ))
+
+             # Sort by confidence (highest first)
+             detections.sort(key=lambda x: x[4], reverse=True)
+
+             return detections
+
+         except Exception as e:
+             print(f"Detection error: {str(e)}")
+             return []
+
+     def get_ball_center(
+         self,
+         detection: Tuple[int, int, int, int, float]
+     ) -> Tuple[float, float]:
+         """
+         Calculate the center point of a ball detection.
+
+         Args:
+             detection: Bounding box as (x1, y1, x2, y2, confidence)
+
+         Returns:
+             Center coordinates as (cx, cy)
+         """
+         x1, y1, x2, y2, _ = detection
+         cx = (x1 + x2) / 2.0
+         cy = (y1 + y2) / 2.0
+         return cx, cy
+
+     def get_ball_size(
+         self,
+         detection: Tuple[int, int, int, int, float]
+     ) -> Tuple[float, float]:
+         """
+         Calculate the width and height of a ball detection.
+
+         Args:
+             detection: Bounding box as (x1, y1, x2, y2, confidence)
+
+         Returns:
+             Size as (width, height)
+         """
+         x1, y1, x2, y2, _ = detection
+         width = x2 - x1
+         height = y2 - y1
+         return width, height
packages.txt ADDED
@@ -0,0 +1,7 @@
+ libgl1
+ libglib2.0-0
+ libsm6
+ libxext6
+ libxrender-dev
+ libgomp1
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ opencv-python-headless==4.8.1.78
+ ultralytics>=8.0.0
+ numpy>=1.24.0
+ gradio>=4.0.0
+ matplotlib>=3.7.0
+ filterpy>=1.4.5
+ Pillow>=10.0.0
tracker.py ADDED
@@ -0,0 +1,210 @@
+ """
+ Ball Tracking Module using a Kalman Filter.
+
+ This module implements a Kalman filter-based tracker for smoothing
+ and predicting tennis ball positions across video frames.
+ """
+
+ import numpy as np
+ from typing import Optional, Tuple, List
+ from filterpy.kalman import KalmanFilter
+
+
+ class BallTracker:
+     """
+     Kalman filter-based tracker for tennis ball position and velocity.
+
+     The tracker maintains state estimates for:
+     - Position (x, y)
+     - Velocity (vx, vy)
+
+     Attributes:
+         dt (float): Time step between frames (1/fps)
+         process_noise (float): Process noise covariance
+         measurement_noise (float): Measurement noise covariance
+         max_missing_frames (int): Maximum frames without detection before reset
+     """
+
+     def __init__(
+         self,
+         dt: float = 1.0 / 30.0,
+         process_noise: float = 0.1,
+         measurement_noise: float = 10.0,
+         max_missing_frames: int = 10
+     ):
+         """
+         Initialize the ball tracker.
+
+         Args:
+             dt: Time step between frames (seconds)
+             process_noise: Process noise standard deviation
+             measurement_noise: Measurement noise standard deviation
+             max_missing_frames: Max consecutive frames without detection
+         """
+         self.dt = dt
+         self.process_noise = process_noise
+         self.measurement_noise = measurement_noise
+         self.max_missing_frames = max_missing_frames
+
+         # Initialize Kalman filter
+         self.kf = self._create_kalman_filter()
+
+         # Tracking state
+         self.initialized = False
+         self.missing_frames = 0
+         self.trajectory = []  # List of (x, y, vx, vy, frame_num)
+         self.frame_count = 0
+
+     def _create_kalman_filter(self) -> KalmanFilter:
+         """
+         Create and configure a Kalman filter for 2D position tracking.
+
+         State vector: [x, y, vx, vy]
+         Measurement vector: [x, y]
+
+         Returns:
+             Configured KalmanFilter instance
+         """
+         kf = KalmanFilter(dim_x=4, dim_z=2)
+
+         # State transition matrix (constant velocity model)
+         kf.F = np.array([
+             [1, 0, self.dt, 0],
+             [0, 1, 0, self.dt],
+             [0, 0, 1, 0],
+             [0, 0, 0, 1]
+         ])
+
+         # Measurement matrix (observe position only)
+         kf.H = np.array([
+             [1, 0, 0, 0],
+             [0, 1, 0, 0]
+         ])
+
+         # Measurement noise covariance
+         kf.R = np.eye(2) * self.measurement_noise
+
+         # Process noise covariance
+         q = self.process_noise
+         kf.Q = np.array([
+             [q * self.dt**4 / 4, 0, q * self.dt**3 / 2, 0],
+             [0, q * self.dt**4 / 4, 0, q * self.dt**3 / 2],
+             [q * self.dt**3 / 2, 0, q * self.dt**2, 0],
+             [0, q * self.dt**3 / 2, 0, q * self.dt**2]
+         ])
+
+         # Initial state covariance
+         kf.P = np.eye(4) * 100
+
+         return kf
+
+     def update(
+         self,
+         measurement: Optional[Tuple[float, float]] = None
+     ) -> Optional[Tuple[float, float, float, float]]:
+         """
+         Update the tracker with a new measurement, or predict if no detection.
+
+         Args:
+             measurement: Ball center position as (x, y), or None if not detected
+
+         Returns:
+             Estimated state as (x, y, vx, vy), or None if the tracker is not initialized
+         """
+         self.frame_count += 1
+
+         if measurement is not None:
+             # Detection available
+             if not self.initialized:
+                 # Initialize tracker with first detection
+                 self.kf.x = np.array([
+                     measurement[0],
+                     measurement[1],
+                     0.0,
+                     0.0
+                 ])
+                 self.initialized = True
+                 self.missing_frames = 0
+             else:
+                 # Update with measurement
+                 z = np.array([measurement[0], measurement[1]])
+                 self.kf.predict()
+                 self.kf.update(z)
+                 self.missing_frames = 0
+
+             # Record trajectory
+             x, y, vx, vy = self.kf.x
+             self.trajectory.append((
+                 float(x), float(y), float(vx), float(vy), self.frame_count
+             ))
+
+             return (float(x), float(y), float(vx), float(vy))
+
+         else:
+             # No detection - predict only
+             if self.initialized:
+                 self.kf.predict()
+                 self.missing_frames += 1
+
+                 # Reset if too many missing frames
+                 if self.missing_frames > self.max_missing_frames:
+                     self.reset()
+                     return None
+
+                 # Return prediction
+                 x, y, vx, vy = self.kf.x
+                 self.trajectory.append((
+                     float(x), float(y), float(vx), float(vy), self.frame_count
+                 ))
+                 return (float(x), float(y), float(vx), float(vy))
+
+             return None
+
+     def reset(self):
+         """Reset the tracker to the uninitialized state."""
+         self.kf = self._create_kalman_filter()
+         self.initialized = False
+         self.missing_frames = 0
+
+     def get_trajectory(self) -> List[Tuple[float, float, float, float, int]]:
+         """
+         Get the complete trajectory history.
+
+         Returns:
+             List of trajectory points as (x, y, vx, vy, frame_num)
+         """
+         return self.trajectory
+
+     def get_speed(self, state: Tuple[float, float, float, float]) -> float:
+         """
+         Calculate speed from velocity components.
+
+         Args:
+             state: Tracker state as (x, y, vx, vy)
+
+         Returns:
+             Speed in pixels per second
+         """
+         _, _, vx, vy = state
+         speed = np.sqrt(vx**2 + vy**2) / self.dt
+         return float(speed)
+
+     def get_last_n_positions(self, n: int = 20) -> List[Tuple[float, float]]:
+         """
+         Get the last N tracked positions for trail visualization.
+
+         Args:
+             n: Number of recent positions to return
+
+         Returns:
+             List of (x, y) coordinates
+         """
+         if len(self.trajectory) == 0:
+             return []
+
+         recent = self.trajectory[-n:]
+         return [(x, y) for x, y, _, _, _ in recent]
+
+     def is_initialized(self) -> bool:
+         """Check whether the tracker has been initialized with a detection."""
+         return self.initialized
utils/__init__.py ADDED
@@ -0,0 +1,32 @@
+ """Utility modules for TennisVision."""
+
+ from .visualization import (
+     draw_detection,
+     draw_trajectory_trail,
+     draw_speed_label,
+     draw_info_panel,
+     create_trajectory_plot
+ )
+
+ from .io_utils import (
+     VideoReader,
+     VideoWriter,
+     export_trajectory_csv,
+     get_video_info,
+     validate_video_file,
+     create_output_directory
+ )
+
+ __all__ = [
+     'draw_detection',
+     'draw_trajectory_trail',
+     'draw_speed_label',
+     'draw_info_panel',
+     'create_trajectory_plot',
+     'VideoReader',
+     'VideoWriter',
+     'export_trajectory_csv',
+     'get_video_info',
+     'validate_video_file',
+     'create_output_directory'
+ ]
utils/io_utils.py ADDED
@@ -0,0 +1,287 @@
+ """
+ I/O utilities for video processing and data export.
+
+ This module provides functions for reading/writing videos,
+ exporting trajectory data to CSV, and handling file operations.
+ """
+
+ import cv2
+ import csv
+ import numpy as np
+ from typing import List, Tuple, Optional, Generator
+ from pathlib import Path
+
+
+ class VideoReader:
+     """
+     Context manager for reading video files frame by frame.
+
+     Attributes:
+         video_path (str): Path to input video file
+         cap (cv2.VideoCapture): OpenCV video capture object
+     """
+
+     def __init__(self, video_path: str):
+         """
+         Initialize the video reader.
+
+         Args:
+             video_path: Path to the video file
+
+         Raises:
+             FileNotFoundError: If the video file doesn't exist
+             RuntimeError: If the video cannot be opened
+         """
+         self.video_path = video_path
+
+         if not Path(video_path).exists():
+             raise FileNotFoundError(f"Video file not found: {video_path}")
+
+         self.cap = cv2.VideoCapture(video_path)
+
+         if not self.cap.isOpened():
+             raise RuntimeError(f"Failed to open video: {video_path}")
+
+     def __enter__(self):
+         """Context manager entry."""
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         """Context manager exit - release video capture."""
+         self.cap.release()
+
+     def get_properties(self) -> dict:
+         """
+         Get video properties.
+
+         Returns:
+             Dictionary containing fps, frame_count, width, height
+         """
+         return {
+             'fps': self.cap.get(cv2.CAP_PROP_FPS),
+             'frame_count': int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)),
+             'width': int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
+             'height': int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+         }
+
+     def read_frames(self) -> Generator[Tuple[int, np.ndarray], None, None]:
+         """
+         Generator that yields frames from the video.
+
+         Yields:
+             Tuple of (frame_number, frame_array)
+         """
+         frame_num = 0
+         while True:
+             ret, frame = self.cap.read()
+             if not ret:
+                 break
+             yield frame_num, frame
+             frame_num += 1
+
+     def read_frame(self) -> Tuple[bool, Optional[np.ndarray]]:
+         """
+         Read a single frame.
+
+         Returns:
+             Tuple of (success, frame) where success is a boolean
+         """
+         return self.cap.read()
+
+
+ class VideoWriter:
+     """
+     Context manager for writing video files.
+
+     Attributes:
+         output_path (str): Path to output video file
+         fps (float): Frame rate
+         width (int): Frame width
+         height (int): Frame height
+     """
+
+     def __init__(
+         self,
+         output_path: str,
+         fps: float,
+         width: int,
+         height: int,
+         codec: str = 'mp4v'
+     ):
+         """
+         Initialize the video writer.
+
+         Args:
+             output_path: Path to save the video
+             fps: Frame rate
+             width: Frame width in pixels
+             height: Frame height in pixels
+             codec: Video codec fourcc code
+         """
+         self.output_path = output_path
+         self.fps = fps
+         self.width = width
+         self.height = height
+
+         # Create output directory if it doesn't exist
+         Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+
+         # Initialize video writer
+         fourcc = cv2.VideoWriter_fourcc(*codec)
+         self.writer = cv2.VideoWriter(
+             output_path,
+             fourcc,
+             fps,
+             (width, height)
+         )
+
+         if not self.writer.isOpened():
+             raise RuntimeError(f"Failed to create video writer: {output_path}")
+
+     def __enter__(self):
+         """Context manager entry."""
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         """Context manager exit - release video writer."""
+         self.writer.release()
+
+     def write_frame(self, frame: np.ndarray):
+         """
+         Write a single frame to the video.
+
+         Args:
+             frame: Frame array in BGR format
+         """
+         # Ensure frame has correct dimensions
+         if frame.shape[1] != self.width or frame.shape[0] != self.height:
+             frame = cv2.resize(frame, (self.width, self.height))
+
+         self.writer.write(frame)
+
+
+ def export_trajectory_csv(
+     trajectory: List[Tuple[float, float, float, float, int]],
+     fps: float,
+     output_path: str
+ ) -> bool:
+     """
+     Export trajectory data to a CSV file.
+
+     Args:
+         trajectory: List of (x, y, vx, vy, frame_num) tuples
+         fps: Video frame rate
+         output_path: Path to save CSV file
+
+     Returns:
+         True if successful, False otherwise
+     """
+     try:
+         # Create output directory if needed
+         Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+
+         with open(output_path, 'w', newline='') as csvfile:
+             writer = csv.writer(csvfile)
+
+             # Write header
+             writer.writerow([
+                 'frame',
+                 'timestamp_sec',
+                 'x_pixels',
+                 'y_pixels',
+                 'velocity_x_px_per_sec',
+                 'velocity_y_px_per_sec',
+                 'speed_px_per_sec'
+             ])
+
+             # Write data rows
+             for x, y, vx, vy, frame_num in trajectory:
+                 timestamp = frame_num / fps
+                 speed = np.sqrt(vx**2 + vy**2) / (1.0 / fps)
+
+                 writer.writerow([
+                     frame_num,
+                     f"{timestamp:.3f}",
+                     f"{x:.2f}",
+                     f"{y:.2f}",
+                     f"{vx / (1.0 / fps):.2f}",
+                     f"{vy / (1.0 / fps):.2f}",
+                     f"{speed:.2f}"
+                 ])
+
+         return True
+
+     except Exception as e:
+         print(f"Error exporting CSV: {str(e)}")
+         return False
+
+
+ def get_video_info(video_path: str) -> Optional[dict]:
+     """
+     Get basic information about a video file.
+
+     Args:
+         video_path: Path to video file
+
+     Returns:
+         Dictionary with video properties, or None on failure
+     """
+     try:
+         with VideoReader(video_path) as reader:
+             return reader.get_properties()
+     except Exception as e:
+         print(f"Error reading video info: {str(e)}")
+         return None
+
+
+ def validate_video_file(video_path: str) -> Tuple[bool, str]:
+     """
+     Validate that a video file exists and can be opened.
+
+     Args:
+         video_path: Path to video file
+
+     Returns:
+         Tuple of (is_valid, error_message)
+     """
+     if not video_path:
+         return False, "No video path provided"
+
+     path = Path(video_path)
+
+     if not path.exists():
+         return False, f"Video file not found: {video_path}"
+
+     if not path.is_file():
+         return False, f"Path is not a file: {video_path}"
+
+     # Try to open the video
+     try:
+         with VideoReader(video_path) as reader:
+             props = reader.get_properties()
+
+             if props['frame_count'] == 0:
+                 return False, "Video has no frames"
+
+             if props['fps'] <= 0:
+                 return False, "Invalid video frame rate"
+
+             return True, "Valid video file"
+
+     except Exception as e:
+         return False, f"Failed to open video: {str(e)}"
+
+
+ def create_output_directory(output_dir: str = "output") -> Path:
+     """
+     Create the output directory if it doesn't exist.
+
+     Args:
+         output_dir: Directory name/path
+
+     Returns:
+         Path object for the output directory
+     """
+     output_path = Path(output_dir)
+     output_path.mkdir(parents=True, exist_ok=True)
+     return output_path
utils/visualization.py ADDED
@@ -0,0 +1,315 @@
+ """
+ Visualization utilities for ball tracking.
+
+ This module provides functions for rendering bounding boxes, trajectories,
+ and creating 2D trajectory plots with speed-based color coding.
+ """
+
+ import cv2
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import matplotlib.colors as mcolors
+ from typing import List, Tuple, Optional
+ from matplotlib.figure import Figure
+
+
+ def draw_detection(
+     frame: np.ndarray,
+     detection: Tuple[int, int, int, int, float],
+     color: Tuple[int, int, int] = (0, 255, 0),
+     thickness: int = 2
+ ) -> np.ndarray:
+     """
+     Draw a bounding box for a detection on the frame.
+
+     Args:
+         frame: Input frame (BGR format)
+         detection: Bounding box as (x1, y1, x2, y2, confidence)
+         color: Box color in BGR format
+         thickness: Line thickness
+
+     Returns:
+         Frame with the drawn bounding box
+     """
+     x1, y1, x2, y2, conf = detection
+
+     # Draw rectangle
+     cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness)
+
+     # Draw confidence label
+     label = f"{conf:.2f}"
+     label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+     label_y = max(y1 - 10, label_size[1])
+
+     cv2.rectangle(
+         frame,
+         (x1, label_y - label_size[1] - 5),
+         (x1 + label_size[0], label_y + 5),
+         color,
+         -1
+     )
+     cv2.putText(
+         frame,
+         label,
+         (x1, label_y),
+         cv2.FONT_HERSHEY_SIMPLEX,
+         0.5,
+         (0, 0, 0),
+         1
+     )
+
+     return frame
+
+
+ def draw_trajectory_trail(
+     frame: np.ndarray,
+     positions: List[Tuple[float, float]],
+     color: Tuple[int, int, int] = (0, 255, 255),
+     max_points: int = 20
+ ) -> np.ndarray:
+     """
+     Draw a trail showing recent ball positions.
+
+     Args:
+         frame: Input frame (BGR format)
+         positions: List of (x, y) positions (most recent last)
+         color: Trail color in BGR format
+         max_points: Maximum number of points to show
+
+     Returns:
+         Frame with the drawn trajectory trail
+     """
+     if len(positions) < 2:
+         return frame
+
+     # Use only recent positions
+     recent = positions[-max_points:]
+
+     # Draw lines connecting positions with a fading effect
+     for i in range(1, len(recent)):
+         # Calculate alpha (opacity) based on position in trail
+         alpha = i / len(recent)
+
+         # Blend color with background
+         pt1 = (int(recent[i - 1][0]), int(recent[i - 1][1]))
+         pt2 = (int(recent[i][0]), int(recent[i][1]))
+
+         # Draw line with thickness varying by position
+         thickness = max(1, int(2 * alpha))
+         line_color = tuple(int(c * alpha) for c in color)
+
+         cv2.line(frame, pt1, pt2, line_color, thickness, cv2.LINE_AA)
+
+     # Draw circle at current position
+     if len(recent) > 0:
+         curr_pos = (int(recent[-1][0]), int(recent[-1][1]))
+         cv2.circle(frame, curr_pos, 5, color, -1, cv2.LINE_AA)
+
+     return frame
+
+
+ def draw_speed_label(
+     frame: np.ndarray,
+     position: Tuple[float, float],
+     speed: float,
+     fps: float,
+     color: Tuple[int, int, int] = (255, 255, 255)
+ ) -> np.ndarray:
+     """
+     Draw speed information near the ball position.
+
+     Args:
+         frame: Input frame (BGR format)
+         position: Ball position as (x, y)
+         speed: Speed in pixels per second
+         fps: Video frame rate
+         color: Text color in BGR format
+
+     Returns:
+         Frame with the speed label
+     """
+     x, y = int(position[0]), int(position[1])
+
+     # Convert pixel speed to approximate real-world units
+     # (This is a rough estimate; proper conversion requires camera calibration)
+     speed_kmh = speed * 0.01  # Rough approximation
+
+     label = f"{speed_kmh:.1f} km/h"
+
+     # Draw label with background
+     font = cv2.FONT_HERSHEY_SIMPLEX
+     font_scale = 0.6
+     thickness = 2
+     label_size, _ = cv2.getTextSize(label, font, font_scale, thickness)
+
+     # Position label above the ball
+     label_x = x - label_size[0] // 2
+     label_y = y - 20
+
+     # Ensure label stays within the frame
+     label_x = max(0, min(label_x, frame.shape[1] - label_size[0]))
+     label_y = max(label_size[1] + 5, label_y)
+
+     # Draw background rectangle
+     cv2.rectangle(
+         frame,
+         (label_x - 5, label_y - label_size[1] - 5),
+         (label_x + label_size[0] + 5, label_y + 5),
+         (0, 0, 0),
+         -1
+     )
+
+     # Draw text
+     cv2.putText(
+         frame,
+         label,
+         (label_x, label_y),
+         font,
+         font_scale,
+         color,
+         thickness,
+         cv2.LINE_AA
+     )
+
+     return frame
+
+
+ def draw_info_panel(
+     frame: np.ndarray,
+     frame_num: int,
+     total_frames: int,
+     fps: float,
+     detection_conf: Optional[float] = None
+ ) -> np.ndarray:
+     """
+     Draw an information panel at the top of the frame.
+
+     Args:
+         frame: Input frame (BGR format)
+         frame_num: Current frame number
+         total_frames: Total number of frames
+         fps: Video frame rate
+         detection_conf: Detection confidence (if available)
+
+     Returns:
+         Frame with the info panel
+     """
+     # Create semi-transparent overlay
+     overlay = frame.copy()
+     cv2.rectangle(overlay, (0, 0), (frame.shape[1], 60), (0, 0, 0), -1)
+     frame = cv2.addWeighted(overlay, 0.6, frame, 0.4, 0)
+
+     # Draw text information
+     font = cv2.FONT_HERSHEY_SIMPLEX
+     font_scale = 0.6
+     color = (255, 255, 255)
+     thickness = 2
+
+     # Frame counter
+     frame_text = f"Frame: {frame_num}/{total_frames}"
+     cv2.putText(frame, frame_text, (10, 25), font, font_scale, color, thickness)
+
+     # Time
+     time_text = f"Time: {frame_num / fps:.2f}s"
+     cv2.putText(frame, time_text, (10, 50), font, font_scale, color, thickness)
+
+     # Detection confidence (if available)
+     if detection_conf is not None:
+         conf_text = f"Confidence: {detection_conf:.2%}"
+         cv2.putText(frame, conf_text, (250, 25), font, font_scale, color, thickness)
+
+     return frame
+
+
+ def create_trajectory_plot(
+     trajectory: List[Tuple[float, float, float, float, int]],
+     fps: float,
+     output_path: Optional[str] = None
+ ) -> Figure:
+     """
+     Create a 2D trajectory plot color-coded by speed.
+
+     Args:
+         trajectory: List of (x, y, vx, vy, frame_num) tuples
+         fps: Video frame rate
+         output_path: Path to save the plot (optional)
+
+     Returns:
+         Matplotlib Figure object
+     """
+     if len(trajectory) == 0:
+         # Create empty plot
+         fig, ax = plt.subplots(figsize=(10, 8))
+         ax.text(
+             0.5, 0.5, "No trajectory data available",
+             ha='center', va='center', fontsize=14
+         )
+         ax.set_xlim(0, 1)
+         ax.set_ylim(0, 1)
+         return fig
+
+     # Extract coordinates and velocities
+     x_coords = [p[0] for p in trajectory]
+     y_coords = [p[1] for p in trajectory]
+     vx = [p[2] for p in trajectory]
+     vy = [p[3] for p in trajectory]
+
+     # Calculate speeds
+     speeds = [np.sqrt(vx[i]**2 + vy[i]**2) / (1.0 / fps) for i in range(len(vx))]
+
+     # Create figure
+     fig, ax = plt.subplots(figsize=(12, 10))
+
+     # Normalize speeds for color mapping
+     if max(speeds) > 0:
+         norm = mcolors.Normalize(vmin=min(speeds), vmax=max(speeds))
+         colormap = plt.cm.jet
+     else:
+         norm = None
+         colormap = None
+
+     # Plot trajectory with color-coded speeds
+     for i in range(1, len(x_coords)):
+         if norm is not None:
+             color = colormap(norm(speeds[i]))
+         else:
+             color = 'blue'
+
+         ax.plot(
+             [x_coords[i - 1], x_coords[i]],
+             [y_coords[i - 1], y_coords[i]],
+             color=color,
+             linewidth=2,
+             alpha=0.7
+         )
+
+     # Add start and end markers
+     ax.scatter(x_coords[0], y_coords[0], c='green', s=100, marker='o',
+                label='Start', zorder=5, edgecolors='black', linewidths=2)
+     ax.scatter(x_coords[-1], y_coords[-1], c='red', s=100, marker='X',
+                label='End', zorder=5, edgecolors='black', linewidths=2)
+
+     # Formatting
+     ax.set_xlabel('X Position (pixels)', fontsize=12, fontweight='bold')
+     ax.set_ylabel('Y Position (pixels)', fontsize=12, fontweight='bold')
+     ax.set_title('Tennis Ball Trajectory (Color = Speed)', fontsize=14, fontweight='bold')
+     ax.legend(loc='best', fontsize=10)
+     ax.grid(True, alpha=0.3)
+     ax.invert_yaxis()  # Invert Y-axis to match image coordinates
+
+     # Add colorbar
+     if norm is not None:
+         sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
+         sm.set_array([])
+         cbar = plt.colorbar(sm, ax=ax, label='Speed (pixels/sec)')
+
+     plt.tight_layout()
+
+     # Save if a path is provided
+     if output_path:
+         try:
+             plt.savefig(output_path, dpi=150, bbox_inches='tight')
+         except Exception as e:
+             print(f"Error saving plot: {str(e)}")
+
+     return fig