"""Model loading and caching - optimized for NVIDIA A10G GPUs."""

import os
from datetime import datetime
from typing import Optional, Tuple

import torch
from transformers import (
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    WhisperProcessor,
)

from data.manager import MODEL_OUTPUT_DIR

# Available models for experimentation
AVAILABLE_MODELS = {
    "Wav2Vec2 Base (960h)": "facebook/wav2vec2-base-960h",
    "Wav2Vec2 Large (960h)": "facebook/wav2vec2-large-960h",
    "Wav2Vec2 Base (100h)": "facebook/wav2vec2-base",
    "OWSM v3.1 Small (ESPnet)": "espnet/owsm_v3.1_ebf_small",
}

# Fine-tuned model key (added dynamically if model exists)
FINE_TUNED_MODEL_KEY = "Fine-tuned Whisper (Caribbean Voices)"

# Global variables for model caching
current_model_name = None
current_processor = None
current_model = None

# Optimized for CUDA on A10G GPUs
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face Hub account that hosts the versioned fine-tuned repos.
_HF_HUB_AUTHOR = "shaun3141"

# (display label, repo-name stem) for each fine-tuned model family on the Hub.
# Repos are named "<stem>-<YYYYMMDD-HHMMSS>".
_HUB_MODEL_FAMILIES = (
    ("OWSM", "caribbean-voices-owsm-finetuned"),
    ("Whisper", "caribbean-voices-whisper-finetuned"),
)

# Timestamp format used in Hub repo names, and the format shown to users.
_HUB_TIMESTAMP_FORMAT = "%Y%m%d-%H%M%S"
_DISPLAY_DATE_FORMAT = "%Y-%m-%d %H:%M"

# Files whose presence marks a directory as a HuggingFace-style checkpoint.
_CHECKPOINT_MARKER_FILES = ("config.json", "pytorch_model.bin", "model.safetensors")


def _any_exists(directory: str, filenames) -> bool:
    """Return True if at least one of *filenames* exists inside *directory*."""
    return any(os.path.exists(os.path.join(directory, name)) for name in filenames)


def _load_seq2seq_processor(source: str, token: Optional[str] = None):
    """Load a processor for *source*, preferring WhisperProcessor.

    Falls back to AutoProcessor when WhisperProcessor cannot load the
    checkpoint. ``token`` is only forwarded when it is set.
    """
    kwargs = {"token": token} if token else {}
    try:
        return WhisperProcessor.from_pretrained(source, **kwargs)
    except Exception:
        # Deliberately broad: any Whisper-specific load failure should fall
        # through to the generic loader (but not KeyboardInterrupt/SystemExit).
        return AutoProcessor.from_pretrained(source, **kwargs)


def is_fine_tuned_model_available() -> bool:
    """Check if a fine-tuned model exists in MODEL_OUTPUT_DIR.

    Returns:
        True when MODEL_OUTPUT_DIR contains recognisable model files (directly,
        in the ``espnet_model`` subdirectory, as a wrapper state dict, or any
        regular file with a common model/config extension); False otherwise.
    """
    if not os.path.exists(MODEL_OUTPUT_DIR):
        return False

    # Model files directly in MODEL_OUTPUT_DIR ("model_index.json" is for
    # ESPnet models).
    root_files = (
        "pytorch_model.bin",
        "model.safetensors",
        "config.json",
        "model_index.json",
    )
    if _any_exists(MODEL_OUTPUT_DIR, root_files):
        return True

    # ESPnet model stored in a subdirectory.
    espnet_dir = os.path.join(MODEL_OUTPUT_DIR, "espnet_model")
    if _any_exists(espnet_dir, _CHECKPOINT_MARKER_FILES):
        return True

    # Wrapper state dict (ESPnet wrapper models).
    if os.path.exists(os.path.join(MODEL_OUTPUT_DIR, "wrapper_state_dict.pt")):
        return True

    # Last resort: some models save differently, so accept any regular file
    # with a common model/config extension.
    if os.path.isdir(MODEL_OUTPUT_DIR):
        model_extensions = (".bin", ".safetensors", ".pt", ".json", ".txt")
        try:
            for item in os.listdir(MODEL_OUTPUT_DIR):
                if item.endswith(model_extensions) and os.path.isfile(
                    os.path.join(MODEL_OUTPUT_DIR, item)
                ):
                    return True
        except OSError:
            # Unreadable directory: treat as "no model available".
            pass

    return False


def load_model(model_key: str) -> Tuple[Optional[object], object]:
    """Load a model and processor, caching them for efficiency.

    The most recently loaded model is kept in module-level globals and reused
    when the same key is requested again.

    Args:
        model_key: A key of AVAILABLE_MODELS, or FINE_TUNED_MODEL_KEY.

    Returns:
        Tuple of (processor, model). The processor is None for OWSM (ESPnet)
        models.

    Raises:
        FileNotFoundError: The fine-tuned model was requested but is missing.
        KeyError: ``model_key`` is not a known model.
        ImportError: An OWSM model was requested but ESPnet is unavailable.
    """
    global current_model_name, current_processor, current_model

    # Handle fine-tuned model
    if model_key == FINE_TUNED_MODEL_KEY:
        if not is_fine_tuned_model_available():
            raise FileNotFoundError(
                f"Fine-tuned model not found at {MODEL_OUTPUT_DIR}. "
                f"Please train a model first."
            )

        if current_model_name != MODEL_OUTPUT_DIR:
            print(f"Loading fine-tuned model from {MODEL_OUTPUT_DIR}...")
            processor = _load_seq2seq_processor(MODEL_OUTPUT_DIR)

            # Use AutoModelForSpeechSeq2Seq for compatibility
            model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_OUTPUT_DIR)
            model.to(device)
            model.eval()

            # Commit to the cache only once everything loaded successfully.
            current_processor = processor
            current_model = model
            current_model_name = MODEL_OUTPUT_DIR
            print(f"✓ Fine-tuned model loaded on {device}")

        return current_processor, current_model

    try:
        model_path = AVAILABLE_MODELS[model_key]
    except KeyError:
        raise KeyError(
            f"Unknown model key: {model_key!r}. "
            f"Available: {list(AVAILABLE_MODELS)}"
        ) from None

    # Handle OWSM models differently - optimized for A10G GPUs with Flash Attention
    if "OWSM" in model_key:
        try:
            from espnet2.bin.s2t_inference import Speech2Text

            if current_model_name != model_path:
                # Flash Attention should be available on A10G GPUs
                try:
                    import flash_attn  # noqa: F401  (availability probe only)

                    print("Loading OWSM model with Flash Attention optimization (A10G GPU)...")
                except ImportError:
                    print("⚠ Loading OWSM model without Flash Attention (performance may be suboptimal)")
                    print("   Flash Attention should be installed on A10G GPUs - check build logs")

                # NOTE(review): no device is passed here, although the message
                # below reports `device` - confirm where ESPnet places the model.
                owsm_model = Speech2Text.from_pretrained(model_path)
                current_model = owsm_model
                current_model_name = model_path
                print(f"✓ OWSM model loaded on {device}")

            return None, current_model  # No processor for ESPnet
        except ImportError as err:
            raise ImportError(
                "ESPnet not installed. Install with: pip install espnet espnet_model_zoo"
            ) from err

    # Standard HuggingFace models
    if current_model_name != model_path:
        print(f"Loading model: {model_path}")
        # Load into locals first so a failure part-way through cannot leave a
        # mismatched processor/model pair in the cache.
        processor = Wav2Vec2Processor.from_pretrained(model_path)
        model = Wav2Vec2ForCTC.from_pretrained(model_path)
        model.to(device)
        model.eval()

        current_processor = processor
        current_model = model
        current_model_name = model_path
        print(f"Model loaded on {device}")

    return current_processor, current_model


def get_available_models():
    """Get list of available model keys, including fine-tuned model if available."""
    models = list(AVAILABLE_MODELS)

    # Add fine-tuned model if it exists
    if is_fine_tuned_model_available():
        models.append(FINE_TUNED_MODEL_KEY)

    return models


def load_checkpoint(checkpoint_path: str) -> Tuple[Optional[object], object]:
    """Load a model from a specific checkpoint directory or Hugging Face Hub.

    This function does not touch the ``load_model`` cache.

    Args:
        checkpoint_path: Path to checkpoint directory or HF Hub repo ID
            (e.g., "data/owsm_caribbean_finetuned/checkpoint-2000" or
            "shaun3141/caribbean-voices-owsm-finetuned")

    Returns:
        Tuple of (processor, model)

    Raises:
        ValueError: The path looks like a Hub repo ID but HF_TOKEN is not set.
        FileNotFoundError: A local checkpoint path does not exist.
    """
    hf_token = os.getenv("HF_TOKEN")

    # Treated as an HF Hub repo ID when it contains "/" and doesn't exist locally
    is_hf_hub = "/" in checkpoint_path and not os.path.exists(checkpoint_path)

    if is_hf_hub:
        print(f"Loading checkpoint from Hugging Face Hub: {checkpoint_path}...")
        if not hf_token:
            raise ValueError(f"HF_TOKEN required to load model from Hub: {checkpoint_path}")
    else:
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
        print(f"Loading checkpoint from {checkpoint_path}...")

    # The token is only sent for Hub downloads, never for local paths.
    hub_token = hf_token if is_hf_hub else None

    processor = _load_seq2seq_processor(checkpoint_path, token=hub_token)

    model = AutoModelForSpeechSeq2Seq.from_pretrained(checkpoint_path, token=hub_token)
    model.to(device)
    model.eval()

    print(f"✓ Checkpoint loaded on {device}")
    return processor, model


def _format_hub_timestamp(timestamp_str: str) -> str:
    """Convert a repo timestamp (YYYYMMDD-HHMMSS) to "YYYY-MM-DD HH:MM".

    Returns *timestamp_str* unchanged when it does not match the format.
    """
    try:
        parsed = datetime.strptime(timestamp_str, _HUB_TIMESTAMP_FORMAT)
    except ValueError:
        return timestamp_str
    return parsed.strftime(_DISPLAY_DATE_FORMAT)


def _list_hub_checkpoints(hf_token: str) -> list:
    """Return display names of versioned fine-tuned models found on the Hub.

    Best-effort: failures are reported as warnings and never raised.
    """
    names = []
    try:
        from huggingface_hub import HfApi

        api = HfApi(token=hf_token)
    except Exception as err:
        print(f"⚠ Warning: Could not query Hugging Face Hub for checkpoints: {err}")
        return names

    for label, stem in _HUB_MODEL_FAMILIES:
        marker = f"{stem}-"
        try:
            models = api.list_models(author=_HF_HUB_AUTHOR, search=stem, token=hf_token)
            for model_info in models:
                repo_id = model_info.id
                if marker not in repo_id:
                    continue
                # The timestamp itself contains a dash, so take everything
                # after the stem rather than the last dash-separated piece.
                timestamp_str = repo_id.split(marker, 1)[1]
                names.append(f"{label} Finetuned ({_format_hub_timestamp(timestamp_str)})")
        except Exception as err:
            # Keep whatever was collected so far; other families may still work.
            print(f"⚠ Warning: Could not list {label} models on Hugging Face Hub: {err}")

    return names


def _list_local_checkpoints() -> list:
    """Return "Final Model" and checkpoint directory names in MODEL_OUTPUT_DIR."""
    names = []
    if not os.path.exists(MODEL_OUTPUT_DIR):
        return names

    # Check if final model exists locally
    if is_fine_tuned_model_available():
        names.append("Final Model")

    if not os.path.isdir(MODEL_OUTPUT_DIR):
        return names

    try:
        for item in os.listdir(MODEL_OUTPUT_DIR):
            item_path = os.path.join(MODEL_OUTPUT_DIR, item)
            if not os.path.isdir(item_path):
                continue

            if item.startswith("checkpoint-"):
                # Only list checkpoints that actually contain model files.
                if _any_exists(item_path, _CHECKPOINT_MARKER_FILES):
                    names.append(item)
            elif item != "espnet_model":  # Skip known subdirectories
                # Other model directories that don't follow checkpoint-* naming.
                if _any_exists(
                    item_path, _CHECKPOINT_MARKER_FILES + ("wrapper_state_dict.pt",)
                ):
                    names.append(item)
    except OSError as e:
        print(f"⚠ Warning: Could not list checkpoints in {MODEL_OUTPUT_DIR}: {e}")

    return names


def _checkpoint_sort_key(name: str):
    """Sort key: base models, Hub models (newest first), local checkpoints
    (by step), "Final Model", then anything else."""
    if name in AVAILABLE_MODELS:
        return (0, name)  # Base models first
    if name.startswith(("OWSM Finetuned (", "Whisper Finetuned (")):
        try:
            date_str = name.split("(")[1].rstrip(")")
            parsed = datetime.strptime(date_str, _DISPLAY_DATE_FORMAT)
        except ValueError:
            return (1, 0)  # Unparsable dates sort after all dated entries
        # Negative timestamp so the newest comes first
        return (1, -parsed.timestamp())
    if name == "Final Model":
        return (3, float("inf"))  # Final model after local checkpoints
    if name.startswith("checkpoint-"):
        try:
            return (2, int(name.split("-")[1]))  # Sorted by training step
        except ValueError:
            return (2, 0)
    return (4, name)  # Unknown items last


def get_available_checkpoints(include_base_models: bool = True) -> list:
    """Get list of available checkpoints in MODEL_OUTPUT_DIR and Hugging Face Hub.

    The Hub is only queried when the HF_TOKEN environment variable is set.

    Args:
        include_base_models: If True, include base models from AVAILABLE_MODELS

    Returns:
        List of checkpoint names (e.g., ["OWSM v3.1 Small (ESPnet)",
        "OWSM Finetuned (2024-01-15 14:30)", "checkpoint-1000",
        "checkpoint-2000", "Final Model"]), ordered as: base models,
        Hub models (newest first), local checkpoints by step, final model,
        other model directories.
    """
    checkpoints = []

    # Add base models first if requested
    if include_base_models:
        checkpoints.extend(AVAILABLE_MODELS)

    # Check Hugging Face Hub for uploaded models (persistent storage)
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        checkpoints.extend(_list_hub_checkpoints(hf_token))

    checkpoints.extend(_list_local_checkpoints())

    # Always sort, so the order does not depend on MODEL_OUTPUT_DIR existing.
    checkpoints.sort(key=_checkpoint_sort_key)
    return checkpoints


def _resolve_hub_repo(api, hf_token: str, repo_stem: str, checkpoint_name: str) -> str:
    """Map a Hub display name back to the repo ID it was derived from.

    Raises FileNotFoundError when no matching repo is found.
    """
    label_text = checkpoint_name.split("(", 1)[1].rstrip(")")
    try:
        parsed = datetime.strptime(label_text, _DISPLAY_DATE_FORMAT)
    except ValueError:
        # Names built from an unparsable repo suffix carry that suffix verbatim.
        parsed = None

    suffix = parsed.strftime(_HUB_TIMESTAMP_FORMAT) if parsed else label_text
    exact_repo = f"{_HF_HUB_AUTHOR}/{repo_stem}-{suffix}"
    try:
        api.model_info(exact_repo, token=hf_token)
        return exact_repo
    except Exception as err:
        if parsed is None:
            raise FileNotFoundError(f"Model not found: {checkpoint_name}") from err

    # The display name drops the seconds, so the exact repo usually does not
    # exist. Match on the minute first, then fall back to the same day.
    models = api.list_models(author=_HF_HUB_AUTHOR, search=repo_stem, token=hf_token)
    repo_ids = [m.id for m in models]
    for partial in (parsed.strftime("%Y%m%d-%H%M"), parsed.strftime("%Y%m%d")):
        matches = [repo_id for repo_id in repo_ids if f"{repo_stem}-{partial}" in repo_id]
        if matches:
            # Timestamps sort lexicographically, so max() is the most recent.
            return max(matches)

    raise FileNotFoundError(f"Model not found: {checkpoint_name}")


def get_checkpoint_path(checkpoint_name: str) -> str:
    """Get full path to a checkpoint directory or HF Hub repo ID.

    Args:
        checkpoint_name: Name like "Final Model",
            "OWSM Finetuned (2024-01-15 14:30)", or "checkpoint-2000"

    Returns:
        Full path to checkpoint directory or HF Hub repo ID

    Raises:
        ValueError: HF_TOKEN is missing for a Hub checkpoint, or the name could
            not be resolved to a Hub repo (the underlying error is chained).
    """
    if checkpoint_name == "Final Model":
        return MODEL_OUTPUT_DIR

    repo_stem = None
    for label, stem in _HUB_MODEL_FAMILIES:
        if checkpoint_name.startswith(f"{label} Finetuned ("):
            repo_stem = stem
            break

    if repo_stem is None:
        # Anything else is a directory name inside MODEL_OUTPUT_DIR.
        return os.path.join(MODEL_OUTPUT_DIR, checkpoint_name)

    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN required to load models from Hub")

    from huggingface_hub import HfApi

    api = HfApi(token=hf_token)
    try:
        return _resolve_hub_repo(api, hf_token, repo_stem, checkpoint_name)
    except Exception as e:
        # Callers have always received ValueError for every resolution failure
        # (including "not found"), so that contract is kept.
        raise ValueError(
            f"Could not parse checkpoint name: {checkpoint_name}. Error: {e}"
        ) from e