Commit 3b90d9c · Parent(s): 8353beb
Onur Çopur committed: first commit
Files changed:
- .gitignore +67 -0
- CLAUDE.md +88 -0
- README.md +280 -7
- __init__.py +14 -0
- app.py +322 -4
- detector.py +155 -0
- packages.txt +7 -0
- requirements.txt +9 -0
- tracker.py +210 -0
- utils/__init__.py +32 -0
- utils/io_utils.py +287 -0
- utils/visualization.py +315 -0
.gitignore
ADDED
@@ -0,0 +1,67 @@

```
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual environments
venv/
env/
ENV/
env.bak/
venv.bak/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store

# Output files
output/
*.mp4
*.avi
*.mov
*.csv
*.png
!example_videos/*.mp4

# Model weights (downloaded at runtime)
*.pt
*.pth
*.onnx

# Jupyter
.ipynb_checkpoints/
*.ipynb

# Gradio
gradio_cached_examples/
flagged/

# Logs
*.log
logs/

# Temporary files
tmp/
temp/
*.tmp
```
CLAUDE.md
ADDED
@@ -0,0 +1,88 @@

# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Running the Application

Start the Gradio web interface:
```bash
python app.py
```

The app binds to 0.0.0.0 and serves on `http://localhost:7860` (port 7860) by default.

## Architecture Overview

This is a Gradio-based web application for tennis ball tracking using computer vision. The processing pipeline has three main stages:

1. **Detection** ([detector.py](detector.py)): YOLOv8 detects tennis balls (COCO class 32) in each frame
2. **Tracking** ([tracker.py](tracker.py)): Kalman filter smooths trajectories and predicts positions during occlusion
3. **Visualization** ([utils/visualization.py](utils/visualization.py)): Overlays trajectory trails, bounding boxes, speed labels, and generates plots

### Key Design Patterns

**Processing Flow** ([app.py](app.py:30-188)):
- `process_video()` orchestrates the full pipeline
- Uses context managers (`VideoReader`, `VideoWriter`) for safe I/O
- Frame-by-frame processing: detect → update tracker → render overlays → write frame
- Trajectory data accumulated in tracker, exported to CSV at end (loop sketched below)
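
A condensed sketch of that loop, mirroring `process_video()` in [app.py](app.py); validation, progress reporting, and the info panel are omitted, and the sample path is illustrative:

```python
from detector import BallDetector
from tracker import BallTracker
from utils import (VideoReader, VideoWriter, draw_detection,
                   draw_trajectory_trail, draw_speed_label,
                   export_trajectory_csv)

video_path, out_path = "example_videos/serve.mp4", "output/tracked_video.mp4"

with VideoReader(video_path) as reader:
    props = reader.get_properties()
fps, w, h = props['fps'], props['width'], props['height']

detector = BallDetector(model_name="yolov8n", confidence_threshold=0.3)
tracker = BallTracker(dt=1.0 / fps, max_missing_frames=int(fps * 0.5))

with VideoReader(video_path) as reader, VideoWriter(out_path, fps, w, h) as writer:
    for frame_num, frame in reader.read_frames():
        detections = detector.detect(frame)   # sorted, highest confidence first
        if detections:
            cx, cy = detector.get_ball_center(detections[0])
            state = tracker.update((cx, cy))  # predict + correct with measurement
            frame = draw_detection(frame, detections[0])
        else:
            state = tracker.update(None)      # coast on the motion model
        if state is not None and tracker.is_initialized():
            frame = draw_trajectory_trail(frame, tracker.get_last_n_positions(20))
            frame = draw_speed_label(frame, (state[0], state[1]),
                                     tracker.get_speed(state), fps)
        writer.write_frame(frame)

export_trajectory_csv(tracker.get_trajectory(), fps, "output/trajectory.csv")
```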
+
|
| 30 |
+
**State Management** ([tracker.py](tracker.py:13-211)):
|
| 31 |
+
- `BallTracker` maintains Kalman filter state vector: `[x, y, vx, vy]`
|
| 32 |
+
- Handles initialization on first detection
|
| 33 |
+
- Predicts ball position when detection is lost (up to `max_missing_frames`)
|
| 34 |
+
- Resets tracker if ball missing too long
|
| 35 |
+
|
| 36 |
+
**Detection Selection** ([app.py](app.py:104-115)):
|
| 37 |
+
- When multiple detections occur, uses highest confidence detection
|
| 38 |
+
- Minimum box size filter (5x5 pixels) in [detector.py](detector.py:107)
|
| 39 |
+
|
| 40 |
+
## Module Dependencies
|
| 41 |
+
|
| 42 |
+
```
|
| 43 |
+
app.py (main entry point)
|
| 44 |
+
├── detector.py (BallDetector class)
|
| 45 |
+
├── tracker.py (BallTracker class)
|
| 46 |
+
└── utils/
|
| 47 |
+
├── io_utils.py (VideoReader, VideoWriter, CSV export)
|
| 48 |
+
└── visualization.py (drawing functions, plotting)
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
All utilities are imported via `utils/__init__.py` which re-exports from submodules.
|
| 52 |
+
|
| 53 |
+
## Configuration Parameters
|
| 54 |
+
|
| 55 |
+
**Detector** ([detector.py](detector.py:24-49)):
|
| 56 |
+
- `model_name`: 'yolov8n' (fastest), 'yolov8s', 'yolov8m' (most accurate)
|
| 57 |
+
- `confidence_threshold`: 0.1-0.9 (lower = more sensitive)
|
| 58 |
+
- `device`: auto-detected ('cuda' if available, else 'cpu')
|
| 59 |
+
|
| 60 |
+
**Tracker** ([tracker.py](tracker.py:28-47)):
|
| 61 |
+
- `dt`: time step, calculated as `1.0 / fps`
|
| 62 |
+
- `max_missing_frames`: typically `int(fps * 0.5)` (half-second tolerance)
|
| 63 |
+
- `process_noise`: 0.1 (Kalman filter Q matrix)
|
| 64 |
+
- `measurement_noise`: 10.0 (Kalman filter R matrix)
|
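
A rule-of-thumb tuning note (standard Kalman filter behavior, not specific to this repo): raising `measurement_noise` relative to `process_noise` trusts the constant-velocity model more, giving a smoother but laggier track; lowering it follows detections more tightly. A sketch with illustrative values:

```python
from tracker import BallTracker

fps = 30.0  # stand-in; read from the input video in practice

# Smoother, laggier: trust the motion model over noisy detections.
smooth = BallTracker(dt=1.0 / fps, process_noise=0.05, measurement_noise=25.0)

# Tighter, jumpier: follow detections closely.
tight = BallTracker(dt=1.0 / fps, process_noise=0.5, measurement_noise=2.0)
```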

## Output Files

All outputs saved to `output/` directory:
- `tracked_video.mp4`: Video with overlays (bounding boxes, trails, speed labels, info panel)
- `trajectory.csv`: Frame-by-frame data with columns: frame, timestamp, x/y position, velocity, speed
- `trajectory_plot.png`: 2D plot with color-coded speed gradient (blue = slow, red = fast)

## Speed Estimation

Speed calculated from Kalman filter velocity: `speed = sqrt(vx² + vy²) / dt`

**Units**: Speed is in pixels/second. The "km/h" label uses a rough approximation (`speed * 0.01`); real-world conversion requires camera calibration.

## Deployment Notes

**Hugging Face Spaces**:
- Use YOLOv8n model for free tier (limited GPU)
- App automatically creates `output/` directory on startup
- Gradio interface configured with `share=False` by default

**Model Downloads**:
- YOLOv8 weights downloaded automatically by Ultralytics on first run
- Models cached in `~/.cache/torch/hub/` directory
README.md
CHANGED
@@ -1,13 +1,286 @@

The placeholder front-matter fields (`title:`, `emoji:`, `colorFrom:`, `colorTo:`, `sdk_version:`) and `license: apache-2.0` are removed; `sdk: gradio`, `app_file: app.py`, and `pinned: false` carry over. New content:

# 🎾 TennisVision – AI Ball Tracker

[Python](https://www.python.org/downloads/) · [Gradio](https://gradio.app/) · [Ultralytics YOLOv8](https://github.com/ultralytics/ultralytics)

A comprehensive AI-powered demo application that detects and tracks tennis balls in video clips using state-of-the-art computer vision models.

## 🌟 Features

- **Real-time Ball Detection**: Uses YOLOv8 models for accurate tennis ball detection
- **Kalman Filter Tracking**: Smooth trajectory tracking with velocity estimation
- **Speed Visualization**: Real-time speed estimation and overlay
- **2D Trajectory Plot**: Color-coded trajectory visualization based on ball speed
- **Data Export**: Download processed videos and trajectory data in CSV format
- **Interactive UI**: User-friendly Gradio interface with adjustable parameters

## 🎯 Demo Capabilities

1. ✅ Accept short tennis video clips (user-uploaded)
2. ✅ Detect and track tennis balls frame-by-frame using YOLOv8
3. ✅ Visualize ball trajectory and speed as video overlays
4. ✅ Display 2D trajectory plot (X vs Y) color-coded by speed
5. ✅ Provide downloadable outputs:
   - Processed video with trajectory overlays
   - CSV file with timestamped coordinates and speed estimates
   - Trajectory plot image

## 🏗️ Project Structure

```
tennisvision/
├── app.py                 # Gradio application entry point
├── detector.py            # YOLOv8 detection model wrapper
├── tracker.py             # Kalman filter-based ball tracker
├── utils/
│   ├── __init__.py        # Utility package initialization
│   ├── visualization.py   # Overlay rendering and plotting
│   └── io_utils.py        # Video I/O and CSV export
├── example_videos/        # Sample tennis videos (user-provided)
├── output/                # Generated output files
├── requirements.txt       # Python dependencies
└── README.md              # This file
```

## 🚀 Quick Start

### Local Installation

1. **Clone or download this repository**

```bash
git clone <repository-url>
cd tennisvision
```

2. **Install dependencies**

```bash
pip install -r requirements.txt
```

3. **Run the application**

```bash
python app.py
```

4. **Open your browser** to `http://localhost:7860`

### Usage

1. **Upload a tennis video** (MP4, AVI, or MOV format)
2. **Select detection model**:
   - `yolov8n` - Fastest (recommended for Hugging Face Spaces)
   - `yolov8s` - Balanced speed/accuracy
   - `yolov8m` - Most accurate
3. **Adjust confidence threshold** (0.1 - 0.9):
   - Lower values detect more balls but may include false positives
   - Higher values are more conservative
4. **Click "Run Tracking"** and wait for processing
5. **View results** in tabs:
   - Processed video with overlays
   - 2D trajectory plot
   - Download CSV and video files

## 🤗 Hugging Face Spaces Deployment

### Method 1: Using the Web Interface

1. Create a new Space on [Hugging Face](https://huggingface.co/spaces)
2. Select **Gradio** as the SDK
3. Upload all files from the `tennisvision/` directory
4. The Space will automatically build and deploy

### Method 2: Using Git

1. Create a new Space and clone it:

```bash
git clone https://huggingface.co/spaces/<your-username>/<space-name>
cd <space-name>
```

2. Copy all files from `tennisvision/`:

```bash
cp -r path/to/tennisvision/* .
```

3. Commit and push:

```bash
git add .
git commit -m "Initial commit: TennisVision ball tracker"
git push
```

### Configuration Files for Hugging Face

Create a `README.md` in the Space root with:

```yaml
---
title: TennisVision - AI Ball Tracker
emoji: 🎾
colorFrom: green
colorTo: blue
sdk: gradio
sdk_version: 4.0.0
app_file: app.py
pinned: false
---
```

## 📊 Output Files

### 1. Processed Video (`tracked_video.mp4`)
- Original video with overlays
- Ball bounding boxes
- Trajectory trail (last 20 positions)
- Speed labels
- Info panel with frame count and timestamp

### 2. Trajectory CSV (`trajectory.csv`)
Columns (a loading sketch follows):
- `frame`: Frame number
- `timestamp_sec`: Time in seconds
- `x_pixels`, `y_pixels`: Ball center coordinates
- `velocity_x_px_per_sec`, `velocity_y_px_per_sec`: Velocity components
- `speed_px_per_sec`: Instantaneous speed
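
A quick way to inspect the export, using only the standard library (a sketch; pandas works just as well):

```python
import csv

# Load output/trajectory.csv and report the peak speed (pixel units).
with open("output/trajectory.csv", newline="") as f:
    rows = list(csv.DictReader(f))

peak = max(rows, key=lambda r: float(r["speed_px_per_sec"]))
print(f"{len(rows)} points; peak speed {float(peak['speed_px_per_sec']):.0f} px/s "
      f"at t = {peak['timestamp_sec']} s")
```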

### 3. Trajectory Plot (`trajectory_plot.png`)
- 2D visualization of ball path
- Color gradient representing speed (blue = slow, red = fast)
- Start and end markers

## 🛠️ Technical Details

### Detection Models

**YOLOv8** (You Only Look Once v8) from Ultralytics:
- Pre-trained on COCO dataset
- Detects sports balls (class 32)
- Variants: `yolov8n` (nano), `yolov8s` (small), `yolov8m` (medium)

### Tracking Algorithm

**Kalman Filter**:
- State vector: `[x, y, vx, vy]` (position and velocity)
- Constant velocity motion model (equations below)
- Predicts ball position when detection is lost
- Smooths noisy detections
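
Concretely, the constant-velocity model in `tracker.py` advances the state each frame as (taken from its transition matrix):

```
x_{k+1}  = x_k + vx_k * dt
y_{k+1}  = y_k + vy_k * dt
vx_{k+1} = vx_k
vy_{k+1} = vy_k
```

Only the position (x, y) is measured; the filter infers the velocity components.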

### Speed Estimation

```
speed = sqrt(vx² + vy²) / dt
```
where `dt = 1 / fps`

*Note: Speed is in pixels/sec. Real-world conversion requires camera calibration.*
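
A worked example, plugging illustrative numbers into the formula above:

```python
import math

fps = 30.0
dt = 1.0 / fps
vx, vy = 3.0, 4.0   # velocity components from the Kalman state

speed = math.sqrt(vx**2 + vy**2) / dt   # sqrt(9 + 16) / (1/30) = 150.0
print(f"{speed:.1f} px/s")              # 150.0 px/s
```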

## 🎨 Visualization Features

- **Bounding Boxes**: Green boxes around detected balls
- **Trajectory Trail**: Fading trail showing recent positions
- **Speed Label**: Real-time speed estimate (km/h approximation)
- **Info Panel**: Frame number, timestamp, detection confidence
- **2D Plot**: Complete trajectory with color-coded speed

## ⚠️ Limitations

- Requires a visible ball in the video
- Works best with clear, high-resolution footage
- Speed estimates are in pixel units (not calibrated to real-world units)
- Processing time scales with video length and model size
- Free Hugging Face Spaces have limited GPU resources (use YOLOv8n)

## 🔧 Configuration

### Detector Parameters

```python
detector = BallDetector(
    model_name="yolov8n",         # Model variant
    confidence_threshold=0.3,     # Min confidence score
    device="cuda"                 # or "cpu"
)
```

### Tracker Parameters

```python
tracker = BallTracker(
    dt=1.0/30.0,                  # Time step (1/fps)
    process_noise=0.1,            # Process noise std
    measurement_noise=10.0,       # Measurement noise std
    max_missing_frames=10         # Max frames without detection
)
```

## 🧪 Example Videos

Place sample videos in `example_videos/`:
- `serve.mp4` - Tennis serve motion
- `rally.mp4` - Rally with multiple ball trajectories

*Note: Sample videos are not included in the repository. Use your own tennis footage.*

## 🐛 Troubleshooting

**No ball detected:**
- Lower the confidence threshold
- Ensure the ball is clearly visible
- Try a different model (yolov8s or yolov8m)

**Slow processing:**
- Use the yolov8n model
- Process shorter clips
- Use a GPU if available

**Poor tracking accuracy:**
- Increase the confidence threshold
- Adjust the Kalman filter parameters
- Use higher resolution video

## 📚 Dependencies

- **torch** >= 2.0.0 - Deep learning framework
- **ultralytics** >= 8.0.0 - YOLOv8 implementation
- **opencv-python-headless** == 4.8.1.78 - Video processing
- **gradio** >= 4.0.0 - Web interface
- **matplotlib** >= 3.7.0 - Plotting
- **filterpy** >= 1.4.5 - Kalman filter
- **numpy** >= 1.24.0 - Numerical operations

## 🙏 Acknowledgments

- [Ultralytics YOLOv8](https://github.com/ultralytics/ultralytics) - Object detection
- [Gradio](https://gradio.app/) - Web interface framework
- [FilterPy](https://github.com/rlabbe/filterpy) - Kalman filter implementation

## 📄 License

This project is open-source and available for educational and research purposes.

## 🤝 Contributing

Contributions are welcome! Areas for improvement:
- [ ] Add RT-DETR model support
- [ ] Implement bounce detection
- [ ] Add an FPS benchmark display
- [ ] Camera calibration for real-world speed
- [ ] Multi-ball tracking
- [ ] Player detection and tracking

## 📧 Contact

For questions or issues, please open an issue on the repository.

---

**Built with ❤️ for the tennis and computer vision community**

🎾 Enjoy tracking! 🚀
__init__.py
ADDED
@@ -0,0 +1,14 @@

```python
"""
TennisVision - AI Ball Tracker

A comprehensive tennis ball tracking system using YOLOv8 detection
and Kalman filter-based tracking.
"""

__version__ = "1.0.0"
__author__ = "TennisVision Team"

from .detector import BallDetector
from .tracker import BallTracker

__all__ = ['BallDetector', 'BallTracker']
```
app.py
CHANGED
@@ -1,7 +1,325 @@

The previous file was a 7-line stub whose only non-blank line was `import gradio as gr` (retained below). New content:

```python
"""
TennisVision - AI Ball Tracker
Gradio application for tennis ball detection and tracking.
"""

import gradio as gr
import cv2
import numpy as np
import os
import tempfile
from pathlib import Path
from typing import Tuple, Optional

from detector import BallDetector
from tracker import BallTracker
from utils import (
    VideoReader,
    VideoWriter,
    export_trajectory_csv,
    validate_video_file,
    create_output_directory,
    draw_detection,
    draw_trajectory_trail,
    draw_speed_label,
    draw_info_panel,
    create_trajectory_plot
)


def process_video(
    video_path: str,
    model_name: str,
    confidence_threshold: float,
    progress=gr.Progress()
) -> Tuple[Optional[str], Optional[str], Optional[str], str]:
    """
    Process a video to track the tennis ball.

    Args:
        video_path: Path to input video file
        model_name: Detection model identifier
        confidence_threshold: Minimum detection confidence
        progress: Gradio progress tracker

    Returns:
        Tuple of (output_video_path, csv_path, plot_path, status_message)
    """
    try:
        # Validate input video
        is_valid, msg = validate_video_file(video_path)
        if not is_valid:
            return None, None, None, f"❌ Error: {msg}"

        progress(0, desc="Initializing models...")

        # Initialize detector and tracker
        detector = BallDetector(
            model_name=model_name,
            confidence_threshold=confidence_threshold
        )

        # Read video properties
        with VideoReader(video_path) as reader:
            video_props = reader.get_properties()

        fps = video_props['fps']
        frame_count = video_props['frame_count']
        width = video_props['width']
        height = video_props['height']

        # Initialize tracker
        tracker = BallTracker(dt=1.0 / fps, max_missing_frames=int(fps * 0.5))

        # Create temporary output files
        output_dir = create_output_directory("output")
        temp_video = tempfile.NamedTemporaryFile(
            delete=False, suffix='.mp4', dir=output_dir
        )
        output_video_path = temp_video.name
        temp_video.close()

        csv_path = output_dir / "trajectory.csv"
        plot_path = output_dir / "trajectory_plot.png"

        progress(0.1, desc="Processing frames...")

        # Process video
        detection_count = 0
        with VideoReader(video_path) as reader, \
             VideoWriter(output_video_path, fps, width, height) as writer:

            for frame_num, frame in reader.read_frames():
                # Update progress
                progress_pct = 0.1 + 0.7 * (frame_num / frame_count)
                progress(
                    progress_pct,
                    desc=f"Processing frame {frame_num + 1}/{frame_count}"
                )

                # Detect ball
                detections = detector.detect(frame)

                # Update tracker
                if len(detections) > 0:
                    # Use highest confidence detection
                    best_detection = detections[0]
                    cx, cy = detector.get_ball_center(best_detection)
                    state = tracker.update((cx, cy))
                    detection_count += 1

                    # Draw detection box
                    frame = draw_detection(frame, best_detection)
                else:
                    # Predict without detection
                    state = tracker.update(None)

                # Draw trajectory and info if tracker is active
                if state is not None and tracker.is_initialized():
                    x, y, vx, vy = state

                    # Draw trajectory trail
                    positions = tracker.get_last_n_positions(20)
                    frame = draw_trajectory_trail(frame, positions)

                    # Calculate and draw speed
                    speed = tracker.get_speed(state)
                    frame = draw_speed_label(frame, (x, y), speed, fps)

                    # Draw info panel
                    conf = detections[0][4] if len(detections) > 0 else None
                    frame = draw_info_panel(frame, frame_num + 1, frame_count, fps, conf)

                # Write frame
                writer.write_frame(frame)

        # Export trajectory data
        progress(0.8, desc="Exporting trajectory data...")
        trajectory = tracker.get_trajectory()

        if len(trajectory) == 0:
            return None, None, None, "❌ No ball detected in video. Try lowering the confidence threshold."

        # Export CSV
        export_success = export_trajectory_csv(trajectory, fps, str(csv_path))
        if not export_success:
            csv_path = None

        # Create trajectory plot
        progress(0.9, desc="Creating trajectory plot...")
        try:
            create_trajectory_plot(trajectory, fps, str(plot_path))
        except Exception as e:
            print(f"Failed to create plot: {e}")
            plot_path = None

        progress(1.0, desc="Complete!")

        # Generate status message
        status = f"""✅ **Processing Complete!**

**Video Info:**
- Total Frames: {frame_count}
- Frame Rate: {fps:.1f} FPS
- Resolution: {width}x{height}

**Tracking Results:**
- Ball Detected: {detection_count} frames ({100 * detection_count / frame_count:.1f}%)
- Trajectory Points: {len(trajectory)}

**Outputs:**
- Processed video with overlays
- Trajectory CSV with {len(trajectory)} data points
- 2D trajectory plot color-coded by speed
"""

        return (
            output_video_path,
            str(csv_path) if csv_path else None,
            str(plot_path) if plot_path else None,
            status
        )

    except Exception as e:
        error_msg = f"❌ **Error during processing:** {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, None, None, error_msg


# Create Gradio interface
def create_interface():
    """Create and configure the Gradio interface."""

    with gr.Blocks(
        title="TennisVision - AI Ball Tracker",
        theme=gr.themes.Soft()
    ) as app:
        gr.Markdown(
            """
            # 🎾 TennisVision - AI Ball Tracker

            Upload a tennis video to automatically detect and track the ball using
            state-of-the-art computer vision models.

            **Features:**
            - Real-time ball detection with YOLOv8
            - Smooth trajectory tracking with Kalman filter
            - Speed estimation and visualization
            - Downloadable outputs (video, CSV, plot)
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### ⚙️ Input & Settings")

                video_input = gr.Video(
                    label="Upload Tennis Video",
                    sources=["upload"]
                )

                model_dropdown = gr.Dropdown(
                    choices=["yolov8n", "yolov8s", "yolov8m"],
                    value="yolov8n",
                    label="Detection Model",
                    info="yolov8n is fastest, yolov8m is most accurate"
                )

                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.3,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Lower = more detections (may include false positives)"
                )

                process_btn = gr.Button(
                    "🚀 Run Tracking",
                    variant="primary",
                    size="lg"
                )

                gr.Markdown(
                    """
                    ### 💡 Tips
                    - Use short clips (5-15 seconds) for faster processing
                    - Ensure the ball is visible and in motion
                    - Lower confidence threshold if ball is not detected
                    - YOLOv8n provides fastest inference (~30 FPS)
                    """
                )

            with gr.Column(scale=2):
                gr.Markdown("### 📊 Results")

                status_output = gr.Markdown(
                    "Upload a video and click **Run Tracking** to begin."
                )

                with gr.Tabs():
                    with gr.Tab("📹 Processed Video"):
                        video_output = gr.Video(
                            label="Tracked Video",
                            show_label=False
                        )

                    with gr.Tab("📈 Trajectory Plot"):
                        plot_output = gr.Image(
                            label="2D Trajectory",
                            show_label=False
                        )

                    with gr.Tab("📥 Downloads"):
                        gr.Markdown("### Download Files")
                        csv_output = gr.File(
                            label="Trajectory Data (CSV)"
                        )
                        video_download = gr.File(
                            label="Processed Video (MP4)"
                        )

        # Event handlers
        process_btn.click(
            fn=process_video,
            inputs=[video_input, model_dropdown, confidence_slider],
            outputs=[video_output, csv_output, plot_output, status_output]
        ).then(
            fn=lambda x: x,
            inputs=[video_output],
            outputs=[video_download]
        )

        gr.Markdown(
            """
            ---
            ### 📚 About

            **TennisVision** uses YOLOv8 for ball detection and Kalman filtering
            for smooth trajectory tracking. The system estimates ball speed and
            visualizes the complete trajectory with color-coded speed indicators.

            **Model:** YOLOv8 (Ultralytics)
            **Tracking:** Kalman Filter
            **Framework:** Gradio + OpenCV

            Built for deployment on Hugging Face Spaces 🤗
            """
        )

    return app


if __name__ == "__main__":
    # Create output directory
    create_output_directory("output")

    # Launch app
    app = create_interface()
    app.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860
    )
```
detector.py
ADDED
@@ -0,0 +1,155 @@

```python
"""
Ball Detection Module using YOLO and RT-DETR models.

This module provides a unified interface for detecting tennis balls
in video frames using state-of-the-art object detection models.
"""

import torch
import numpy as np
from typing import List, Tuple, Optional
from ultralytics import YOLO


class BallDetector:
    """
    Wrapper class for ball detection using YOLOv8 or RT-DETR.

    Attributes:
        model_name (str): Name of the detection model ('yolov8n', 'yolov8s', etc.)
        confidence_threshold (float): Minimum confidence score for detections
        device (str): Device to run inference on ('cuda' or 'cpu')
    """

    def __init__(
        self,
        model_name: str = "yolov8n",
        confidence_threshold: float = 0.3,
        device: Optional[str] = None
    ):
        """
        Initialize the ball detector.

        Args:
            model_name: Model identifier (e.g., 'yolov8n', 'yolov8s')
            confidence_threshold: Minimum confidence for valid detections
            device: Compute device ('cuda', 'cpu', or None for auto-detect)
        """
        self.model_name = model_name
        self.confidence_threshold = confidence_threshold

        # Auto-detect device if not specified
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        # Load model
        self.model = self._load_model()

    def _load_model(self) -> YOLO:
        """
        Load the specified detection model.

        Returns:
            Loaded YOLO model instance

        Raises:
            ValueError: If model name is not supported
        """
        try:
            if self.model_name.startswith('yolov8'):
                # Load YOLOv8 model from Ultralytics
                model = YOLO(f'{self.model_name}.pt')
                model.to(self.device)
                return model
            else:
                raise ValueError(f"Unsupported model: {self.model_name}")
        except Exception as e:
            raise RuntimeError(f"Failed to load model {self.model_name}: {str(e)}")

    def detect(self, frame: np.ndarray) -> List[Tuple[int, int, int, int, float]]:
        """
        Detect tennis balls in a single frame.

        Args:
            frame: Input frame as numpy array (H, W, 3) in BGR format

        Returns:
            List of detections, each as (x1, y1, x2, y2, confidence)
            where (x1, y1) is top-left and (x2, y2) is bottom-right
        """
        try:
            # Run inference
            results = self.model.predict(
                frame,
                conf=self.confidence_threshold,
                device=self.device,
                verbose=False,
                classes=[32]  # Sports ball class in COCO dataset
            )

            detections = []

            # Parse results
            if len(results) > 0 and results[0].boxes is not None:
                boxes = results[0].boxes

                for box in boxes:
                    # Extract bounding box coordinates
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    confidence = float(box.conf[0].cpu().numpy())

                    # Filter small detections (likely noise)
                    width = x2 - x1
                    height = y2 - y1

                    if width > 5 and height > 5:  # Minimum size threshold
                        detections.append((
                            int(x1), int(y1), int(x2), int(y2), confidence
                        ))

            # Sort by confidence (highest first)
            detections.sort(key=lambda x: x[4], reverse=True)

            return detections

        except Exception as e:
            print(f"Detection error: {str(e)}")
            return []

    def get_ball_center(
        self,
        detection: Tuple[int, int, int, int, float]
    ) -> Tuple[float, float]:
        """
        Calculate the center point of a ball detection.

        Args:
            detection: Bounding box as (x1, y1, x2, y2, confidence)

        Returns:
            Center coordinates as (cx, cy)
        """
        x1, y1, x2, y2, _ = detection
        cx = (x1 + x2) / 2.0
        cy = (y1 + y2) / 2.0
        return cx, cy

    def get_ball_size(
        self,
        detection: Tuple[int, int, int, int, float]
    ) -> Tuple[float, float]:
        """
        Calculate the width and height of a ball detection.

        Args:
            detection: Bounding box as (x1, y1, x2, y2, confidence)

        Returns:
            Size as (width, height)
        """
        x1, y1, x2, y2, _ = detection
        width = x2 - x1
        height = y2 - y1
        return width, height
```
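
A quick standalone check of the detector (a sketch; the image path is illustrative):

```python
import cv2
from detector import BallDetector

detector = BallDetector(model_name="yolov8n", confidence_threshold=0.3)

frame = cv2.imread("frame.png")  # any BGR frame; path is illustrative
for x1, y1, x2, y2, conf in detector.detect(frame):
    print(f"ball at ({x1}, {y1})-({x2}, {y2}), conf={conf:.2f}")
```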
packages.txt
ADDED
@@ -0,0 +1,7 @@

```
libgl1
libglib2.0-0
libsm6
libxext6
libxrender-dev
libgomp1
ffmpeg
```
requirements.txt
ADDED
@@ -0,0 +1,9 @@

```
torch>=2.0.0
torchvision>=0.15.0
opencv-python-headless==4.8.1.78
ultralytics>=8.0.0
numpy>=1.24.0
gradio>=4.0.0
matplotlib>=3.7.0
filterpy>=1.4.5
Pillow>=10.0.0
```
tracker.py
ADDED
@@ -0,0 +1,210 @@

```python
"""
Ball Tracking Module using Kalman Filter.

This module implements a Kalman filter-based tracker for smoothing
and predicting tennis ball positions across video frames.
"""

import numpy as np
from typing import Optional, Tuple, List
from filterpy.kalman import KalmanFilter


class BallTracker:
    """
    Kalman filter-based tracker for tennis ball position and velocity.

    The tracker maintains state estimates for:
    - Position (x, y)
    - Velocity (vx, vy)

    Attributes:
        dt (float): Time step between frames (1/fps)
        process_noise (float): Process noise covariance
        measurement_noise (float): Measurement noise covariance
        max_missing_frames (int): Maximum frames without detection before reset
    """

    def __init__(
        self,
        dt: float = 1.0 / 30.0,
        process_noise: float = 0.1,
        measurement_noise: float = 10.0,
        max_missing_frames: int = 10
    ):
        """
        Initialize the ball tracker.

        Args:
            dt: Time step between frames (seconds)
            process_noise: Process noise standard deviation
            measurement_noise: Measurement noise standard deviation
            max_missing_frames: Max consecutive frames without detection
        """
        self.dt = dt
        self.process_noise = process_noise
        self.measurement_noise = measurement_noise
        self.max_missing_frames = max_missing_frames

        # Initialize Kalman filter
        self.kf = self._create_kalman_filter()

        # Tracking state
        self.initialized = False
        self.missing_frames = 0
        self.trajectory = []  # List of (x, y, vx, vy, frame_num)
        self.frame_count = 0

    def _create_kalman_filter(self) -> KalmanFilter:
        """
        Create and configure a Kalman filter for 2D position tracking.

        State vector: [x, y, vx, vy]
        Measurement vector: [x, y]

        Returns:
            Configured KalmanFilter instance
        """
        kf = KalmanFilter(dim_x=4, dim_z=2)

        # State transition matrix (constant velocity model)
        kf.F = np.array([
            [1, 0, self.dt, 0],
            [0, 1, 0, self.dt],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])

        # Measurement matrix (observe position only)
        kf.H = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, 0]
        ])

        # Measurement noise covariance
        kf.R = np.eye(2) * self.measurement_noise

        # Process noise covariance
        q = self.process_noise
        kf.Q = np.array([
            [q * self.dt**4 / 4, 0, q * self.dt**3 / 2, 0],
            [0, q * self.dt**4 / 4, 0, q * self.dt**3 / 2],
            [q * self.dt**3 / 2, 0, q * self.dt**2, 0],
            [0, q * self.dt**3 / 2, 0, q * self.dt**2]
        ])

        # Initial state covariance
        kf.P = np.eye(4) * 100

        return kf

    def update(
        self,
        measurement: Optional[Tuple[float, float]] = None
    ) -> Optional[Tuple[float, float, float, float]]:
        """
        Update tracker with a new measurement or predict if no detection.

        Args:
            measurement: Ball center position as (x, y), or None if not detected

        Returns:
            Estimated state as (x, y, vx, vy) or None if tracker not initialized
        """
        self.frame_count += 1

        if measurement is not None:
            # Detection available
            if not self.initialized:
                # Initialize tracker with first detection
                self.kf.x = np.array([
                    measurement[0],
                    measurement[1],
                    0.0,
                    0.0
                ])
                self.initialized = True
                self.missing_frames = 0
            else:
                # Update with measurement
                z = np.array([measurement[0], measurement[1]])
                self.kf.predict()
                self.kf.update(z)
                self.missing_frames = 0

            # Record trajectory
            x, y, vx, vy = self.kf.x
            self.trajectory.append((
                float(x), float(y), float(vx), float(vy), self.frame_count
            ))

            return (float(x), float(y), float(vx), float(vy))

        else:
            # No detection - predict only
            if self.initialized:
                self.kf.predict()
                self.missing_frames += 1

                # Reset if too many missing frames
                if self.missing_frames > self.max_missing_frames:
                    self.reset()
                    return None

                # Return prediction
                x, y, vx, vy = self.kf.x
                self.trajectory.append((
                    float(x), float(y), float(vx), float(vy), self.frame_count
                ))
                return (float(x), float(y), float(vx), float(vy))

            return None

    def reset(self):
        """Reset tracker to uninitialized state."""
        self.kf = self._create_kalman_filter()
        self.initialized = False
        self.missing_frames = 0

    def get_trajectory(self) -> List[Tuple[float, float, float, float, int]]:
        """
        Get the complete trajectory history.

        Returns:
            List of trajectory points as (x, y, vx, vy, frame_num)
        """
        return self.trajectory

    def get_speed(self, state: Tuple[float, float, float, float]) -> float:
        """
        Calculate speed from velocity components.

        Args:
            state: Tracker state as (x, y, vx, vy)

        Returns:
            Speed in pixels per second
        """
        _, _, vx, vy = state
        speed = np.sqrt(vx**2 + vy**2) / self.dt
        return float(speed)

    def get_last_n_positions(self, n: int = 20) -> List[Tuple[float, float]]:
        """
        Get the last N tracked positions for trail visualization.

        Args:
            n: Number of recent positions to return

        Returns:
            List of (x, y) coordinates
        """
        if len(self.trajectory) == 0:
            return []

        recent = self.trajectory[-n:]
        return [(x, y) for x, y, _, _, _ in recent]

    def is_initialized(self) -> bool:
        """Check if tracker has been initialized with a detection."""
        return self.initialized
```
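
A minimal interaction with the tracker, showing the occlusion behavior described above (values are illustrative):

```python
from tracker import BallTracker

tr = BallTracker(dt=1.0 / 30.0, max_missing_frames=5)
tr.update((100.0, 200.0))   # first detection initializes the state, zero velocity
tr.update((103.0, 198.0))   # predict + correct with the new measurement
state = tr.update(None)     # missed detection: coast on the constant-velocity model
print(state)                # predicted (x, y, vx, vy); None once the reset triggers
```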
utils/__init__.py
ADDED
@@ -0,0 +1,32 @@

```python
"""Utility modules for TennisVision."""

from .visualization import (
    draw_detection,
    draw_trajectory_trail,
    draw_speed_label,
    draw_info_panel,
    create_trajectory_plot
)

from .io_utils import (
    VideoReader,
    VideoWriter,
    export_trajectory_csv,
    get_video_info,
    validate_video_file,
    create_output_directory
)

__all__ = [
    'draw_detection',
    'draw_trajectory_trail',
    'draw_speed_label',
    'draw_info_panel',
    'create_trajectory_plot',
    'VideoReader',
    'VideoWriter',
    'export_trajectory_csv',
    'get_video_info',
    'validate_video_file',
    'create_output_directory'
]
```
utils/io_utils.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
I/O utilities for video processing and data export.
|
| 3 |
+
|
| 4 |
+
This module provides functions for reading/writing videos,
|
| 5 |
+
exporting trajectory data to CSV, and handling file operations.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import csv
|
| 10 |
+
import numpy as np
|
| 11 |
+
from typing import List, Tuple, Optional, Generator
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class VideoReader:
|
| 16 |
+
"""
|
| 17 |
+
Context manager for reading video files frame by frame.
|
| 18 |
+
|
| 19 |
+
Attributes:
|
| 20 |
+
video_path (str): Path to input video file
|
| 21 |
+
cap (cv2.VideoCapture): OpenCV video capture object
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, video_path: str):
|
| 25 |
+
"""
|
| 26 |
+
Initialize video reader.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
video_path: Path to the video file
|
| 30 |
+
|
| 31 |
+
Raises:
|
| 32 |
+
FileNotFoundError: If video file doesn't exist
|
| 33 |
+
RuntimeError: If video cannot be opened
|
| 34 |
+
"""
|
| 35 |
+
self.video_path = video_path
|
| 36 |
+
|
| 37 |
+
if not Path(video_path).exists():
|
| 38 |
+
raise FileNotFoundError(f"Video file not found: {video_path}")
|
| 39 |
+
|
| 40 |
+
self.cap = cv2.VideoCapture(video_path)
|
| 41 |
+
|
| 42 |
+
if not self.cap.isOpened():
|
| 43 |
+
raise RuntimeError(f"Failed to open video: {video_path}")
|
| 44 |
+
|
| 45 |
+
def __enter__(self):
|
| 46 |
+
"""Context manager entry."""
|
| 47 |
+
return self
|
| 48 |
+
|
| 49 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 50 |
+
"""Context manager exit - release video capture."""
|
| 51 |
+
self.cap.release()
|
| 52 |
+
|
| 53 |
+
def get_properties(self) -> dict:
|
| 54 |
+
"""
|
| 55 |
+
Get video properties.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
Dictionary containing fps, frame_count, width, height
|
| 59 |
+
"""
|
| 60 |
+
return {
|
| 61 |
+
'fps': self.cap.get(cv2.CAP_PROP_FPS),
|
| 62 |
+
'frame_count': int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)),
|
| 63 |
+
'width': int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
|
| 64 |
+
'height': int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
def read_frames(self) -> Generator[Tuple[int, np.ndarray], None, None]:
|
| 68 |
+
"""
|
| 69 |
+
Generator that yields frames from the video.
|
| 70 |
+
|
| 71 |
+
Yields:
|
| 72 |
+
Tuple of (frame_number, frame_array)
|
| 73 |
+
"""
|
| 74 |
+
frame_num = 0
|
| 75 |
+
while True:
|
| 76 |
+
ret, frame = self.cap.read()
|
| 77 |
+
if not ret:
|
| 78 |
+
break
|
| 79 |
+
yield frame_num, frame
|
| 80 |
+
frame_num += 1
|
| 81 |
+
|
| 82 |
+
def read_frame(self) -> Tuple[bool, Optional[np.ndarray]]:
|
| 83 |
+
"""
|
| 84 |
+
Read a single frame.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
Tuple of (success, frame) where success is a boolean
|
| 88 |
+
"""
|
| 89 |
+
return self.cap.read()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class VideoWriter:
|
| 93 |
+
"""
|
| 94 |
+
Context manager for writing video files.
|
| 95 |
+
|
| 96 |
+
Attributes:
|
| 97 |
+
output_path (str): Path to output video file
|
| 98 |
+
fps (float): Frame rate
|
| 99 |
+
width (int): Frame width
|
| 100 |
+
height (int): Frame height
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
def __init__(
|
| 104 |
+
self,
|
| 105 |
+
output_path: str,
|
| 106 |
+
fps: float,
|
| 107 |
+
width: int,
|
| 108 |
+
height: int,
|
| 109 |
+
codec: str = 'mp4v'
|
| 110 |
+
):
|
| 111 |
+
"""
|
| 112 |
+
Initialize video writer.
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
output_path: Path to save the video
|
| 116 |
+
fps: Frame rate
|
| 117 |
+
width: Frame width in pixels
|
| 118 |
+
height: Frame height in pixels
|
| 119 |
+
codec: Video codec fourcc code
|
| 120 |
+
"""
|
| 121 |
+
self.output_path = output_path
|
| 122 |
+
self.fps = fps
|
| 123 |
+
self.width = width
|
| 124 |
+
self.height = height
|
| 125 |
+
|
| 126 |
+
# Create output directory if it doesn't exist
|
| 127 |
+
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 128 |
+
|
| 129 |
+
# Initialize video writer
|
| 130 |
+
fourcc = cv2.VideoWriter_fourcc(*codec)
|
| 131 |
+
self.writer = cv2.VideoWriter(
|
| 132 |
+
output_path,
|
| 133 |
+
fourcc,
|
| 134 |
+
fps,
|
| 135 |
+
(width, height)
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
if not self.writer.isOpened():
|
| 139 |
+
raise RuntimeError(f"Failed to create video writer: {output_path}")
|
| 140 |
+
|
| 141 |
+
def __enter__(self):
|
| 142 |
+
"""Context manager entry."""
|
| 143 |
+
return self
|
| 144 |
+
|
| 145 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 146 |
+
"""Context manager exit - release video writer."""
|
| 147 |
+
self.writer.release()
|
| 148 |
+
|
| 149 |
+
def write_frame(self, frame: np.ndarray):
|
| 150 |
+
"""
|
| 151 |
+
Write a single frame to the video.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
frame: Frame array in BGR format
|
| 155 |
+
"""
|
| 156 |
+
# Ensure frame has correct dimensions
|
| 157 |
+
if frame.shape[1] != self.width or frame.shape[0] != self.height:
|
| 158 |
+
frame = cv2.resize(frame, (self.width, self.height))
|
| 159 |
+
|
| 160 |
+
self.writer.write(frame)
|
| 161 |
+
|
| 162 |
+
def export_trajectory_csv(
    trajectory: List[Tuple[float, float, float, float, int]],
    fps: float,
    output_path: str
) -> bool:
    """
    Export trajectory data to CSV file.

    Args:
        trajectory: List of (x, y, vx, vy, frame_num) tuples, with vx and vy
            in pixels per frame
        fps: Video frame rate
        output_path: Path to save CSV file

    Returns:
        True if successful, False otherwise
    """
    try:
        # Create output directory if needed
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)

            # Write header
            writer.writerow([
                'frame',
                'timestamp_sec',
                'x_pixels',
                'y_pixels',
                'velocity_x_px_per_sec',
                'velocity_y_px_per_sec',
                'speed_px_per_sec'
            ])

            # Write data rows
            for x, y, vx, vy, frame_num in trajectory:
                timestamp = frame_num / fps
                # Dividing the per-frame velocity by the frame interval
                # (1/fps) converts it to pixels per second
                speed = np.sqrt(vx**2 + vy**2) / (1.0 / fps)

                writer.writerow([
                    frame_num,
                    f"{timestamp:.3f}",
                    f"{x:.2f}",
                    f"{y:.2f}",
                    f"{vx / (1.0 / fps):.2f}",
                    f"{vy / (1.0 / fps):.2f}",
                    f"{speed:.2f}"
                ])

        return True

    except Exception as e:
        print(f"Error exporting CSV: {str(e)}")
        return False

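# Sample of the CSV produced above (illustrative values at 30 fps: per-frame
# velocities vx = 2.0 and vy = -1.0 convert to 60.00 and -30.00 px/sec):
#
#     frame,timestamp_sec,x_pixels,y_pixels,velocity_x_px_per_sec,velocity_y_px_per_sec,speed_px_per_sec
#     42,1.400,640.50,360.25,60.00,-30.00,67.08
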
def get_video_info(video_path: str) -> Optional[dict]:
    """
    Get basic information about a video file.

    Args:
        video_path: Path to video file

    Returns:
        Dictionary with video properties or None if failed
    """
    try:
        with VideoReader(video_path) as reader:
            return reader.get_properties()
    except Exception as e:
        print(f"Error reading video info: {str(e)}")
        return None


def validate_video_file(video_path: str) -> Tuple[bool, str]:
    """
    Validate that a video file exists and can be opened.

    Args:
        video_path: Path to video file

    Returns:
        Tuple of (is_valid, error_message)
    """
    if not video_path:
        return False, "No video path provided"

    path = Path(video_path)

    if not path.exists():
        return False, f"Video file not found: {video_path}"

    if not path.is_file():
        return False, f"Path is not a file: {video_path}"

    # Try to open the video
    try:
        with VideoReader(video_path) as reader:
            props = reader.get_properties()

            if props['frame_count'] == 0:
                return False, "Video has no frames"

            if props['fps'] <= 0:
                return False, "Invalid video frame rate"

            return True, "Valid video file"

    except Exception as e:
        return False, f"Failed to open video: {str(e)}"


def create_output_directory(output_dir: str = "output") -> Path:
    """
    Create output directory if it doesn't exist.

    Args:
        output_dir: Directory name/path

    Returns:
        Path object for the output directory
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    return output_path
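Taken together, these helpers cover the I/O side of a processing run: validate the upload, read its properties, and export results. A minimal sketch of the intended flow (the input path and the two-row trajectory are placeholders):

```python
from utils.io_utils import (
    validate_video_file, get_video_info,
    export_trajectory_csv, create_output_directory
)

ok, msg = validate_video_file("input.mp4")  # placeholder path
if not ok:
    raise SystemExit(msg)

props = get_video_info("input.mp4")
print(f"{props['frame_count']} frames at {props['fps']:.1f} fps")

# (x, y, vx, vy, frame_num) rows, with vx/vy in pixels per frame
trajectory = [
    (640.5, 360.2, 2.0, -1.0, 0),
    (642.5, 359.2, 2.0, -1.0, 1),
]
create_output_directory("output")
export_trajectory_csv(trajectory, props['fps'], "output/trajectory.csv")
```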
utils/visualization.py
ADDED
@@ -0,0 +1,315 @@
"""
Visualization utilities for ball tracking.

This module provides functions for rendering bounding boxes, trajectories,
and creating 2D trajectory plots with speed-based color coding.
"""

import cv2
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from typing import List, Tuple, Optional
from matplotlib.figure import Figure


def draw_detection(
    frame: np.ndarray,
    detection: Tuple[int, int, int, int, float],
    color: Tuple[int, int, int] = (0, 255, 0),
    thickness: int = 2
) -> np.ndarray:
    """
    Draw a bounding box for a detection on the frame.

    Args:
        frame: Input frame (BGR format)
        detection: Bounding box as (x1, y1, x2, y2, confidence)
        color: Box color in BGR format
        thickness: Line thickness

    Returns:
        Frame with drawn bounding box
    """
    x1, y1, x2, y2, conf = detection

    # Draw rectangle
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness)

    # Draw confidence label
    label = f"{conf:.2f}"
    label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
    label_y = max(y1 - 10, label_size[1])

    cv2.rectangle(
        frame,
        (x1, label_y - label_size[1] - 5),
        (x1 + label_size[0], label_y + 5),
        color,
        -1
    )
    cv2.putText(
        frame,
        label,
        (x1, label_y),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        (0, 0, 0),
        1
    )

    return frame


def draw_trajectory_trail(
    frame: np.ndarray,
    positions: List[Tuple[float, float]],
    color: Tuple[int, int, int] = (0, 255, 255),
    max_points: int = 20
) -> np.ndarray:
    """
    Draw a trail showing recent ball positions.

    Args:
        frame: Input frame (BGR format)
        positions: List of (x, y) positions (most recent last)
        color: Trail color in BGR format
        max_points: Maximum number of points to show

    Returns:
        Frame with drawn trajectory trail
    """
    if len(positions) < 2:
        return frame

    # Use only recent positions
    recent = positions[-max_points:]

    # Draw lines connecting positions with fading effect
    for i in range(1, len(recent)):
        # Calculate alpha (opacity) based on position in trail
        alpha = i / len(recent)

        # Endpoints of this trail segment
        pt1 = (int(recent[i - 1][0]), int(recent[i - 1][1]))
        pt2 = (int(recent[i][0]), int(recent[i][1]))

        # Fade older segments: thinner lines, color scaled toward black
        thickness = max(1, int(2 * alpha))
        line_color = tuple(int(c * alpha) for c in color)

        cv2.line(frame, pt1, pt2, line_color, thickness, cv2.LINE_AA)

    # Draw circle at current position (recent always has >= 2 points here)
    curr_pos = (int(recent[-1][0]), int(recent[-1][1]))
    cv2.circle(frame, curr_pos, 5, color, -1, cv2.LINE_AA)

    return frame


def draw_speed_label(
    frame: np.ndarray,
    position: Tuple[float, float],
    speed: float,
    fps: float,
    color: Tuple[int, int, int] = (255, 255, 255)
) -> np.ndarray:
    """
    Draw speed information near the ball position.

    Args:
        frame: Input frame (BGR format)
        position: Ball position as (x, y)
        speed: Speed in pixels per second
        fps: Video frame rate (currently unused; kept for a future
            calibration-based conversion)
        color: Text color in BGR format

    Returns:
        Frame with speed label
    """
    x, y = int(position[0]), int(position[1])

    # Convert pixel speed to approximate real-world units
    # (This is a rough estimate; proper conversion requires camera calibration)
    speed_kmh = speed * 0.01  # Rough approximation

    label = f"{speed_kmh:.1f} km/h"

    # Draw label with background
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 2
    label_size, _ = cv2.getTextSize(label, font, font_scale, thickness)

    # Position label above the ball
    label_x = x - label_size[0] // 2
    label_y = y - 20

    # Ensure label stays within frame
    label_x = max(0, min(label_x, frame.shape[1] - label_size[0]))
    label_y = max(label_size[1] + 5, label_y)

    # Draw background rectangle
    cv2.rectangle(
        frame,
        (label_x - 5, label_y - label_size[1] - 5),
        (label_x + label_size[0] + 5, label_y + 5),
        (0, 0, 0),
        -1
    )

    # Draw text
    cv2.putText(
        frame,
        label,
        (label_x, label_y),
        font,
        font_scale,
        color,
        thickness,
        cv2.LINE_AA
    )

    return frame

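# Note: a calibrated px/s -> km/h conversion would replace the rough factor
# above with a measured scale (sketch; court_length_px must be measured
# from the footage):
#
#     meters_per_pixel = 23.77 / court_length_px  # 23.77 m = baseline-to-baseline
#     speed_kmh = speed * meters_per_pixel * 3.6  # px/s -> m/s -> km/h
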
def draw_info_panel(
    frame: np.ndarray,
    frame_num: int,
    total_frames: int,
    fps: float,
    detection_conf: Optional[float] = None
) -> np.ndarray:
    """
    Draw an information panel at the top of the frame.

    Args:
        frame: Input frame (BGR format)
        frame_num: Current frame number
        total_frames: Total number of frames
        fps: Video frame rate
        detection_conf: Detection confidence (if available)

    Returns:
        Frame with info panel
    """
    # Create semi-transparent overlay
    overlay = frame.copy()
    cv2.rectangle(overlay, (0, 0), (frame.shape[1], 60), (0, 0, 0), -1)
    frame = cv2.addWeighted(overlay, 0.6, frame, 0.4, 0)

    # Draw text information
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    color = (255, 255, 255)
    thickness = 2

    # Frame counter
    frame_text = f"Frame: {frame_num}/{total_frames}"
    cv2.putText(frame, frame_text, (10, 25), font, font_scale, color, thickness)

    # Time
    time_text = f"Time: {frame_num / fps:.2f}s"
    cv2.putText(frame, time_text, (10, 50), font, font_scale, color, thickness)

    # Detection confidence (if available)
    if detection_conf is not None:
        conf_text = f"Confidence: {detection_conf:.2%}"
        cv2.putText(frame, conf_text, (250, 25), font, font_scale, color, thickness)

    return frame


def create_trajectory_plot(
    trajectory: List[Tuple[float, float, float, float, int]],
    fps: float,
    output_path: Optional[str] = None
) -> Figure:
    """
    Create a 2D trajectory plot color-coded by speed.

    Args:
        trajectory: List of (x, y, vx, vy, frame_num) tuples
        fps: Video frame rate
        output_path: Path to save plot (optional)

    Returns:
        Matplotlib Figure object
    """
    if len(trajectory) == 0:
        # Create empty plot
        fig, ax = plt.subplots(figsize=(10, 8))
        ax.text(
            0.5, 0.5, "No trajectory data available",
            ha='center', va='center', fontsize=14
        )
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        return fig

    # Extract coordinates and velocities
    x_coords = [p[0] for p in trajectory]
    y_coords = [p[1] for p in trajectory]
    vx = [p[2] for p in trajectory]
    vy = [p[3] for p in trajectory]

    # Calculate speeds (per-frame velocities converted to pixels/sec)
    speeds = [np.sqrt(vx[i]**2 + vy[i]**2) / (1.0 / fps) for i in range(len(vx))]

    # Create figure
    fig, ax = plt.subplots(figsize=(12, 10))

    # Normalize speeds for color mapping
    if max(speeds) > 0:
        norm = mcolors.Normalize(vmin=min(speeds), vmax=max(speeds))
        colormap = plt.cm.jet
    else:
        norm = None
        colormap = None

    # Plot trajectory with color-coded speeds
    for i in range(1, len(x_coords)):
        if norm is not None:
            color = colormap(norm(speeds[i]))
        else:
            color = 'blue'

        ax.plot(
            [x_coords[i - 1], x_coords[i]],
            [y_coords[i - 1], y_coords[i]],
            color=color,
            linewidth=2,
            alpha=0.7
        )

    # Add start and end markers
    ax.scatter(x_coords[0], y_coords[0], c='green', s=100, marker='o',
               label='Start', zorder=5, edgecolors='black', linewidths=2)
    ax.scatter(x_coords[-1], y_coords[-1], c='red', s=100, marker='X',
               label='End', zorder=5, edgecolors='black', linewidths=2)

    # Formatting
    ax.set_xlabel('X Position (pixels)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Y Position (pixels)', fontsize=12, fontweight='bold')
    ax.set_title('Tennis Ball Trajectory (Color = Speed)', fontsize=14, fontweight='bold')
    ax.legend(loc='best', fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.invert_yaxis()  # Invert Y-axis to match image coordinates

    # Add colorbar
    if norm is not None:
        sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
        sm.set_array([])
        plt.colorbar(sm, ax=ax, label='Speed (pixels/sec)')

    plt.tight_layout()

    # Save if path provided
    if output_path:
        try:
            plt.savefig(output_path, dpi=150, bbox_inches='tight')
        except Exception as e:
            print(f"Error saving plot: {str(e)}")

    return fig
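As a quick smoke test, the drawing helpers and the plot can be exercised on synthetic data; the blank frame, detection box, and trajectory values below are made up for illustration:

```python
import numpy as np
from utils.visualization import (
    draw_detection, draw_trajectory_trail,
    draw_info_panel, create_trajectory_plot
)

frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # blank BGR frame
frame = draw_detection(frame, (600, 300, 640, 340, 0.87))
frame = draw_trajectory_trail(frame, [(610.0, 322.0), (620.0, 320.0), (630.0, 319.0)])
frame = draw_info_panel(frame, frame_num=42, total_frames=300, fps=30.0,
                        detection_conf=0.87)

# (x, y, vx, vy, frame_num), with vx/vy in pixels per frame
trajectory = [
    (610.0, 322.0, 10.0, -2.0, 40),
    (620.0, 320.0, 10.0, -1.0, 41),
    (630.0, 319.0, 10.0, -1.0, 42),
]
fig = create_trajectory_plot(trajectory, fps=30.0, output_path="trajectory_plot.png")
```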