tennisvision / detector.py
Onur Çopur
first commit
3b90d9c
raw
history blame
4.73 kB
"""
Ball Detection Module using Ultralytics YOLOv8 models (RT-DETR is not currently supported).
This module provides a unified interface for detecting tennis balls
in video frames using state-of-the-art object detection models.
"""
import logging
from typing import List, Optional, Sequence, Tuple

import numpy as np
import torch
from ultralytics import YOLO
class BallDetector:
    """
    Wrapper class for ball detection using Ultralytics YOLOv8 models.

    Only model names starting with ``'yolov8'`` are currently accepted; any
    other name (e.g. an RT-DETR model) is rejected when the model is loaded.

    Attributes:
        model_name (str): Name of the detection model ('yolov8n', 'yolov8s', etc.)
        confidence_threshold (float): Minimum confidence score for detections
        device (str): Device to run inference on ('cuda' or 'cpu')
        class_ids (Optional[Tuple[int, ...]]): Class indices forwarded to the
            model's ``classes`` filter; None disables class filtering
        min_box_size (float): Boxes whose width or height is not strictly
            greater than this many pixels are discarded as noise
        model: The loaded Ultralytics YOLO model
    """

    def __init__(
        self,
        model_name: str = "yolov8n",
        confidence_threshold: float = 0.3,
        device: Optional[str] = None,
        class_ids: Optional[Sequence[int]] = (32,),
        min_box_size: float = 5,
    ):
        """
        Initialize the ball detector and load its model.

        Args:
            model_name: Model identifier (e.g., 'yolov8n', 'yolov8s'); the
                weights file ``f"{model_name}.pt"`` is loaded.
            confidence_threshold: Minimum confidence for valid detections.
            device: Compute device ('cuda', 'cpu', ...), or None to use 'cuda'
                when ``torch.cuda.is_available()`` and 'cpu' otherwise.
            class_ids: Class indices to keep. The default ``(32,)`` is the
                "sports ball" class of the COCO dataset; pass the ball class
                of a custom-trained model, or None to keep every class.
            min_box_size: Minimum box width/height in pixels (exclusive).

        Raises:
            RuntimeError: If the model name is unsupported or loading fails.
        """
        self.model_name = model_name
        self.confidence_threshold = confidence_threshold
        # Defensive copy so later mutation of the caller's sequence has no effect.
        self.class_ids = None if class_ids is None else tuple(class_ids)
        self.min_box_size = min_box_size

        # Auto-detect device if not specified
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        self.model = self._load_model()

    def _load_model(self) -> "YOLO":
        """
        Load the detection model named by ``self.model_name``.

        Returns:
            Loaded YOLO model instance, moved to ``self.device``.

        Raises:
            RuntimeError: If the model name is not supported (the underlying
                ValueError is attached as ``__cause__``) or if constructing
                the model / moving it to the device fails.
        """
        try:
            if not self.model_name.startswith('yolov8'):
                raise ValueError(f"Unsupported model: {self.model_name}")
            model = YOLO(f'{self.model_name}.pt')
            model.to(self.device)
        except Exception as e:
            # Callers historically receive RuntimeError for every load failure;
            # chain the original error so its traceback is not lost.
            raise RuntimeError(
                f"Failed to load model {self.model_name}: {e}"
            ) from e
        return model

    def detect(self, frame: np.ndarray) -> List[Tuple[int, int, int, int, float]]:
        """
        Detect tennis balls in a single frame.

        This is best-effort: any exception raised during inference or result
        parsing is logged (with traceback) and an empty list is returned.

        Args:
            frame: Input frame as numpy array (H, W, 3) in BGR format
        Returns:
            List of detections, each as (x1, y1, x2, y2, confidence), where
            (x1, y1) is top-left and (x2, y2) is bottom-right. Coordinates are
            truncated to int; the list is sorted by confidence, highest first.
        """
        try:
            results = self.model.predict(
                frame,
                conf=self.confidence_threshold,
                device=self.device,
                verbose=False,
                classes=None if self.class_ids is None else list(self.class_ids),
            )
            if not results or results[0].boxes is None:
                return []
            return self._parse_boxes(results[0].boxes)
        except Exception as e:
            logging.getLogger(__name__).exception("Detection error: %s", e)
            return []

    def _parse_boxes(self, boxes) -> List[Tuple[int, int, int, int, float]]:
        """Convert a result's ``boxes`` into size-filtered, confidence-sorted tuples."""
        # A single .cpu() call per tensor instead of one per box.
        xyxy = np.asarray(boxes.xyxy.cpu().numpy(), dtype=float).reshape(-1, 4)
        confs = np.asarray(boxes.conf.cpu().numpy(), dtype=float).reshape(-1)

        # Boxes at or below the minimum size are most likely noise.
        detections = [
            (int(x1), int(y1), int(x2), int(y2), float(conf))
            for (x1, y1, x2, y2), conf in zip(xyxy, confs)
            if x2 - x1 > self.min_box_size and y2 - y1 > self.min_box_size
        ]
        detections.sort(key=lambda det: det[4], reverse=True)
        return detections

    def get_ball_center(
        self,
        detection: Tuple[int, int, int, int, float]
    ) -> Tuple[float, float]:
        """
        Calculate the center point of a ball detection.

        Args:
            detection: Bounding box as (x1, y1, x2, y2, confidence)
        Returns:
            Center coordinates as (cx, cy)
        """
        x1, y1, x2, y2, _ = detection
        return (x1 + x2) / 2.0, (y1 + y2) / 2.0

    def get_ball_size(
        self,
        detection: Tuple[int, int, int, int, float]
    ) -> Tuple[float, float]:
        """
        Calculate the width and height of a ball detection.

        Args:
            detection: Bounding box as (x1, y1, x2, y2, confidence)
        Returns:
            Size as (width, height); ints when the box coordinates are ints.
        """
        x1, y1, x2, y2, _ = detection
        return x2 - x1, y2 - y1