""" Visualization utilities for ball tracking. This module provides functions for rendering bounding boxes, trajectories, and creating 2D trajectory plots with speed-based color coding. """ import cv2 import numpy as np import matplotlib.pyplot as plt import matplotlib.colors as mcolors from typing import List, Tuple, Optional from matplotlib.figure import Figure def draw_detection( frame: np.ndarray, detection: Tuple[int, int, int, int, float], color: Tuple[int, int, int] = (0, 255, 0), thickness: int = 2 ) -> np.ndarray: """ Draw a bounding box for a detection on the frame. Args: frame: Input frame (BGR format) detection: Bounding box as (x1, y1, x2, y2, confidence) color: Box color in BGR format thickness: Line thickness Returns: Frame with drawn bounding box """ x1, y1, x2, y2, conf = detection # Draw rectangle cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness) # Draw confidence label label = f"{conf:.2f}" label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) label_y = max(y1 - 10, label_size[1]) cv2.rectangle( frame, (x1, label_y - label_size[1] - 5), (x1 + label_size[0], label_y + 5), color, -1 ) cv2.putText( frame, label, (x1, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1 ) return frame def draw_trajectory_trail( frame: np.ndarray, positions: List[Tuple[float, float]], color: Tuple[int, int, int] = (0, 255, 255), max_points: int = 20 ) -> np.ndarray: """ Draw a trail showing recent ball positions. Args: frame: Input frame (BGR format) positions: List of (x, y) positions (most recent last) color: Trail color in BGR format max_points: Maximum number of points to show Returns: Frame with drawn trajectory trail """ if len(positions) < 2: return frame # Use only recent positions recent = positions[-max_points:] # Draw lines connecting positions with fading effect for i in range(1, len(recent)): # Calculate alpha (opacity) based on position in trail alpha = i / len(recent) # Blend color with background pt1 = (int(recent[i - 1][0]), int(recent[i - 1][1])) pt2 = (int(recent[i][0]), int(recent[i][1])) # Draw line with thickness varying by position thickness = max(1, int(2 * alpha)) line_color = tuple(int(c * alpha) for c in color) cv2.line(frame, pt1, pt2, line_color, thickness, cv2.LINE_AA) # Draw circle at current position if len(recent) > 0: curr_pos = (int(recent[-1][0]), int(recent[-1][1])) cv2.circle(frame, curr_pos, 5, color, -1, cv2.LINE_AA) return frame def draw_speed_label( frame: np.ndarray, position: Tuple[float, float], speed: float, fps: float, color: Tuple[int, int, int] = (255, 255, 255) ) -> np.ndarray: """ Draw speed information near the ball position. Args: frame: Input frame (BGR format) position: Ball position as (x, y) speed: Speed in pixels per second fps: Video frame rate color: Text color in BGR format Returns: Frame with speed label """ x, y = int(position[0]), int(position[1]) # Convert pixel speed to approximate real-world units # (This is a rough estimate; proper conversion requires camera calibration) speed_kmh = speed * 0.01 # Rough approximation label = f"{speed_kmh:.1f} km/h" # Draw label with background font = cv2.FONT_HERSHEY_SIMPLEX font_scale = 0.6 thickness = 2 label_size, _ = cv2.getTextSize(label, font, font_scale, thickness) # Position label above the ball label_x = x - label_size[0] // 2 label_y = y - 20 # Ensure label stays within frame label_x = max(0, min(label_x, frame.shape[1] - label_size[0])) label_y = max(label_size[1] + 5, label_y) # Draw background rectangle cv2.rectangle( frame, (label_x - 5, label_y - label_size[1] - 5), (label_x + label_size[0] + 5, label_y + 5), (0, 0, 0), -1 ) # Draw text cv2.putText( frame, label, (label_x, label_y), font, font_scale, color, thickness, cv2.LINE_AA ) return frame def draw_info_panel( frame: np.ndarray, frame_num: int, total_frames: int, fps: float, detection_conf: Optional[float] = None ) -> np.ndarray: """ Draw an information panel at the top of the frame. Args: frame: Input frame (BGR format) frame_num: Current frame number total_frames: Total number of frames fps: Video frame rate detection_conf: Detection confidence (if available) Returns: Frame with info panel """ # Create semi-transparent overlay overlay = frame.copy() cv2.rectangle(overlay, (0, 0), (frame.shape[1], 60), (0, 0, 0), -1) frame = cv2.addWeighted(overlay, 0.6, frame, 0.4, 0) # Draw text information font = cv2.FONT_HERSHEY_SIMPLEX font_scale = 0.6 color = (255, 255, 255) thickness = 2 # Frame counter frame_text = f"Frame: {frame_num}/{total_frames}" cv2.putText(frame, frame_text, (10, 25), font, font_scale, color, thickness) # Time time_text = f"Time: {frame_num / fps:.2f}s" cv2.putText(frame, time_text, (10, 50), font, font_scale, color, thickness) # Detection confidence (if available) if detection_conf is not None: conf_text = f"Confidence: {detection_conf:.2%}" cv2.putText(frame, conf_text, (250, 25), font, font_scale, color, thickness) return frame def create_trajectory_plot( trajectory: List[Tuple[float, float, float, float, int]], fps: float, output_path: Optional[str] = None ) -> Figure: """ Create a 2D trajectory plot color-coded by speed. Args: trajectory: List of (x, y, vx, vy, frame_num) tuples fps: Video frame rate output_path: Path to save plot (optional) Returns: Matplotlib Figure object """ if len(trajectory) == 0: # Create empty plot fig, ax = plt.subplots(figsize=(10, 8)) ax.text( 0.5, 0.5, "No trajectory data available", ha='center', va='center', fontsize=14 ) ax.set_xlim(0, 1) ax.set_ylim(0, 1) return fig # Extract coordinates and velocities x_coords = [p[0] for p in trajectory] y_coords = [p[1] for p in trajectory] vx = [p[2] for p in trajectory] vy = [p[3] for p in trajectory] # Calculate speeds speeds = [np.sqrt(vx[i]**2 + vy[i]**2) / (1.0 / fps) for i in range(len(vx))] # Create figure fig, ax = plt.subplots(figsize=(12, 10)) # Normalize speeds for color mapping if max(speeds) > 0: norm = mcolors.Normalize(vmin=min(speeds), vmax=max(speeds)) colormap = plt.cm.jet else: norm = None colormap = None # Plot trajectory with color-coded speeds for i in range(1, len(x_coords)): if norm is not None: color = colormap(norm(speeds[i])) else: color = 'blue' ax.plot( [x_coords[i - 1], x_coords[i]], [y_coords[i - 1], y_coords[i]], color=color, linewidth=2, alpha=0.7 ) # Add start and end markers ax.scatter(x_coords[0], y_coords[0], c='green', s=100, marker='o', label='Start', zorder=5, edgecolors='black', linewidths=2) ax.scatter(x_coords[-1], y_coords[-1], c='red', s=100, marker='X', label='End', zorder=5, edgecolors='black', linewidths=2) # Formatting ax.set_xlabel('X Position (pixels)', fontsize=12, fontweight='bold') ax.set_ylabel('Y Position (pixels)', fontsize=12, fontweight='bold') ax.set_title('Tennis Ball Trajectory (Color = Speed)', fontsize=14, fontweight='bold') ax.legend(loc='best', fontsize=10) ax.grid(True, alpha=0.3) ax.invert_yaxis() # Invert Y-axis to match image coordinates # Add colorbar if norm is not None: sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm) sm.set_array([]) cbar = plt.colorbar(sm, ax=ax, label='Speed (pixels/sec)') plt.tight_layout() # Save if path provided if output_path: try: plt.savefig(output_path, dpi=150, bbox_inches='tight') except Exception as e: print(f"Error saving plot: {str(e)}") return fig