"""
Ball Detection Module using YOLO and RT-DETR models.

This module provides a unified interface for detecting tennis balls in video
frames using state-of-the-art object detection models.

NOTE(review): ``BallDetector`` currently only accepts YOLOv8 model names;
RT-DETR is not wired up yet.
"""

import logging
from typing import List, Optional, Tuple

import numpy as np
import torch
from ultralytics import YOLO

logger = logging.getLogger(__name__)

# COCO class index for "sports ball"; used as the default class filter.
SPORTS_BALL_CLASS_ID = 32


class BallDetector:
    """
    Wrapper class for ball detection using YOLOv8.

    Attributes:
        model_name (str): Name of the detection model ('yolov8n', 'yolov8s', etc.)
        confidence_threshold (float): Minimum confidence score for detections
        device (str): Device to run inference on ('cuda' or 'cpu')
        min_box_size (float): Detections whose width or height is not strictly
            greater than this many pixels are discarded as noise
        ball_class_id (int): Class index passed to ``predict(classes=[...])``
        model (YOLO): The loaded Ultralytics model
    """

    def __init__(
        self,
        model_name: str = "yolov8n",
        confidence_threshold: float = 0.3,
        device: Optional[str] = None,
        *,
        min_box_size: float = 5,
        ball_class_id: int = SPORTS_BALL_CLASS_ID,
    ):
        """
        Initialize the ball detector and load the model.

        Args:
            model_name: Model identifier (e.g., 'yolov8n', 'yolov8s')
            confidence_threshold: Minimum confidence for valid detections
            device: Compute device ('cuda', 'cpu', or None for auto-detect)
            min_box_size: Minimum box width/height in pixels (exclusive);
                smaller boxes are treated as noise. Defaults to 5.
            ball_class_id: Model class index to detect. Defaults to the COCO
                "sports ball" class (32).

        Raises:
            RuntimeError: If the model name is unsupported or loading fails.
        """
        self.model_name = model_name
        self.confidence_threshold = confidence_threshold
        self.min_box_size = min_box_size
        self.ball_class_id = ball_class_id

        # Auto-detect device if not specified: prefer CUDA when available.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.model = self._load_model()

    def _load_model(self) -> YOLO:
        """
        Load the configured detection model onto ``self.device``.

        Returns:
            Loaded YOLO model instance.

        Raises:
            RuntimeError: If ``model_name`` does not start with 'yolov8'
                (the underlying ``ValueError`` is attached as ``__cause__``),
                or if constructing/moving the model fails for any other reason.
        """
        try:
            if not self.model_name.startswith("yolov8"):
                raise ValueError(f"Unsupported model: {self.model_name}")
            # Load YOLOv8 weights by name from Ultralytics.
            model = YOLO(f"{self.model_name}.pt")
            model.to(self.device)
        except Exception as e:
            # Callers see a single exception type; keep the original cause.
            raise RuntimeError(f"Failed to load model {self.model_name}: {e}") from e
        return model

    def detect(self, frame: np.ndarray) -> List[Tuple[int, int, int, int, float]]:
        """
        Detect tennis balls in a single frame.

        Args:
            frame: Input frame as numpy array (H, W, 3) in BGR format

        Returns:
            List of detections, each as (x1, y1, x2, y2, confidence) where
            (x1, y1) is top-left and (x2, y2) is bottom-right, sorted by
            confidence (highest first). Coordinates are truncated to int.
            Best-effort: if inference or parsing raises, the error is logged
            and an empty list is returned.
        """
        try:
            results = self.model.predict(
                frame,
                conf=self.confidence_threshold,
                device=self.device,
                verbose=False,
                classes=[self.ball_class_id],
            )

            if len(results) == 0 or results[0].boxes is None:
                return []

            boxes = results[0].boxes
            # One device->host transfer for all boxes instead of one per box.
            all_xyxy = boxes.xyxy.cpu().numpy()
            all_conf = boxes.conf.cpu().numpy()

            detections = []
            for (x1, y1, x2, y2), confidence in zip(all_xyxy, all_conf):
                # Filter tiny boxes (likely noise).
                if (x2 - x1) > self.min_box_size and (y2 - y1) > self.min_box_size:
                    detections.append(
                        (int(x1), int(y1), int(x2), int(y2), float(confidence))
                    )

            detections.sort(key=lambda det: det[4], reverse=True)
            return detections

        except Exception:
            # Deliberate best-effort: one bad frame must not stop processing,
            # but keep the traceback for diagnosis.
            logger.exception("Detection error")
            return []

    def get_ball_center(
        self,
        detection: Tuple[int, int, int, int, float]
    ) -> Tuple[float, float]:
        """
        Calculate the center point of a ball detection.

        Args:
            detection: Bounding box as (x1, y1, x2, y2, confidence)

        Returns:
            Center coordinates as (cx, cy)
        """
        x1, y1, x2, y2, _ = detection
        cx = (x1 + x2) / 2.0
        cy = (y1 + y2) / 2.0
        return cx, cy

    def get_ball_size(
        self,
        detection: Tuple[int, int, int, int, float]
    ) -> Tuple[float, float]:
        """
        Calculate the width and height of a ball detection.

        Args:
            detection: Bounding box as (x1, y1, x2, y2, confidence)

        Returns:
            Size as (width, height), i.e. (x2 - x1, y2 - y1)
        """
        x1, y1, x2, y2, _ = detection
        width = x2 - x1
        height = y2 - y1
        return width, height