""" TennisVision - AI Ball Tracker Gradio application for tennis ball detection and tracking. """ import gradio as gr import cv2 import numpy as np import os import tempfile from pathlib import Path from typing import Tuple, Optional from detector import BallDetector from tracker import BallTracker from utils import ( VideoReader, VideoWriter, export_trajectory_csv, validate_video_file, create_output_directory, draw_detection, draw_trajectory_trail, draw_speed_label, draw_info_panel, create_trajectory_plot ) def process_video( video_path: str, model_name: str, confidence_threshold: float, progress=gr.Progress() ) -> Tuple[Optional[str], Optional[str], Optional[str], str]: """ Process a video to track the tennis ball. Args: video_path: Path to input video file model_name: Detection model identifier confidence_threshold: Minimum detection confidence progress: Gradio progress tracker Returns: Tuple of (output_video_path, csv_path, plot_path, status_message) """ try: # Validate input video is_valid, msg = validate_video_file(video_path) if not is_valid: return None, None, None, f"❌ Error: {msg}" progress(0, desc="Initializing models...") # Initialize detector and tracker detector = BallDetector( model_name=model_name, confidence_threshold=confidence_threshold ) # Read video properties with VideoReader(video_path) as reader: video_props = reader.get_properties() fps = video_props['fps'] frame_count = video_props['frame_count'] width = video_props['width'] height = video_props['height'] # Initialize tracker tracker = BallTracker(dt=1.0 / fps, max_missing_frames=int(fps * 0.5)) # Create temporary output files output_dir = create_output_directory("output") temp_video = tempfile.NamedTemporaryFile( delete=False, suffix='.mp4', dir=output_dir ) output_video_path = temp_video.name temp_video.close() csv_path = output_dir / "trajectory.csv" plot_path = output_dir / "trajectory_plot.png" progress(0.1, desc="Processing frames...") # Process video detection_count = 0 with VideoReader(video_path) as reader, \ VideoWriter(output_video_path, fps, width, height) as writer: for frame_num, frame in reader.read_frames(): # Update progress progress_pct = 0.1 + 0.7 * (frame_num / frame_count) progress( progress_pct, desc=f"Processing frame {frame_num + 1}/{frame_count}" ) # Detect ball detections = detector.detect(frame) # Update tracker if len(detections) > 0: # Use highest confidence detection best_detection = detections[0] cx, cy = detector.get_ball_center(best_detection) state = tracker.update((cx, cy)) detection_count += 1 # Draw detection box frame = draw_detection(frame, best_detection) else: # Predict without detection state = tracker.update(None) # Draw trajectory and info if tracker is active if state is not None and tracker.is_initialized(): x, y, vx, vy = state # Draw trajectory trail positions = tracker.get_last_n_positions(20) frame = draw_trajectory_trail(frame, positions) # Calculate and draw speed speed = tracker.get_speed(state) frame = draw_speed_label(frame, (x, y), speed, fps) # Draw info panel conf = detections[0][4] if len(detections) > 0 else None frame = draw_info_panel(frame, frame_num + 1, frame_count, fps, conf) # Write frame writer.write_frame(frame) # Export trajectory data progress(0.8, desc="Exporting trajectory data...") trajectory = tracker.get_trajectory() if len(trajectory) == 0: return None, None, None, "❌ No ball detected in video. Try lowering the confidence threshold." # Export CSV export_success = export_trajectory_csv(trajectory, fps, str(csv_path)) if not export_success: csv_path = None # Create trajectory plot progress(0.9, desc="Creating trajectory plot...") try: create_trajectory_plot(trajectory, fps, str(plot_path)) except Exception as e: print(f"Failed to create plot: {e}") plot_path = None progress(1.0, desc="Complete!") # Generate status message status = f"""✅ **Processing Complete!** **Video Info:** - Total Frames: {frame_count} - Frame Rate: {fps:.1f} FPS - Resolution: {width}x{height} **Tracking Results:** - Ball Detected: {detection_count} frames ({100 * detection_count / frame_count:.1f}%) - Trajectory Points: {len(trajectory)} **Outputs:** - Processed video with overlays - Trajectory CSV with {len(trajectory)} data points - 2D trajectory plot color-coded by speed """ return ( output_video_path, str(csv_path) if csv_path else None, str(plot_path) if plot_path else None, status ) except Exception as e: error_msg = f"❌ **Error during processing:** {str(e)}" print(error_msg) import traceback traceback.print_exc() return None, None, None, error_msg # Create Gradio interface def create_interface(): """Create and configure the Gradio interface.""" with gr.Blocks( title="TennisVision - AI Ball Tracker", theme=gr.themes.Soft() ) as app: gr.Markdown( """ # 🎾 TennisVision - AI Ball Tracker Upload a tennis video to automatically detect and track the ball using state-of-the-art computer vision models. **Features:** - Real-time ball detection with YOLOv8 - Smooth trajectory tracking with Kalman filter - Speed estimation and visualization - Downloadable outputs (video, CSV, plot) """ ) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### ⚙️ Input & Settings") video_input = gr.Video( label="Upload Tennis Video", sources=["upload"] ) model_dropdown = gr.Dropdown( choices=["yolov8n", "yolov8s", "yolov8m"], value="yolov8n", label="Detection Model", info="yolov8n is fastest, yolov8m is most accurate" ) confidence_slider = gr.Slider( minimum=0.1, maximum=0.9, value=0.3, step=0.05, label="Confidence Threshold", info="Lower = more detections (may include false positives)" ) process_btn = gr.Button( "🚀 Run Tracking", variant="primary", size="lg" ) gr.Markdown( """ ### 💡 Tips - Use short clips (5-15 seconds) for faster processing - Ensure the ball is visible and in motion - Lower confidence threshold if ball is not detected - YOLOv8n provides fastest inference (~30 FPS) """ ) with gr.Column(scale=2): gr.Markdown("### 📊 Results") status_output = gr.Markdown( "Upload a video and click **Run Tracking** to begin." ) with gr.Tabs(): with gr.Tab("📹 Processed Video"): video_output = gr.Video( label="Tracked Video", show_label=False ) with gr.Tab("📈 Trajectory Plot"): plot_output = gr.Image( label="2D Trajectory", show_label=False ) with gr.Tab("📥 Downloads"): gr.Markdown("### Download Files") csv_output = gr.File( label="Trajectory Data (CSV)" ) video_download = gr.File( label="Processed Video (MP4)" ) # Event handlers process_btn.click( fn=process_video, inputs=[video_input, model_dropdown, confidence_slider], outputs=[video_output, csv_output, plot_output, status_output] ).then( fn=lambda x: x, inputs=[video_output], outputs=[video_download] ) gr.Markdown( """ --- ### 📚 About **TennisVision** uses YOLOv8 for ball detection and Kalman filtering for smooth trajectory tracking. The system estimates ball speed and visualizes the complete trajectory with color-coded speed indicators. **Model:** YOLOv8 (Ultralytics) **Tracking:** Kalman Filter **Framework:** Gradio + OpenCV Built for deployment on Hugging Face Spaces 🤗 """ ) return app if __name__ == "__main__": # Create output directory create_output_directory("output") # Launch app app = create_interface() app.launch( share=False, server_name="0.0.0.0", server_port=7860 )