shaun3141's picture
Add timestamp-based versioning to HF Hub model repositories - each training run creates unique versioned repo
60f01ba
"""Model loading and caching - optimized for NVIDIA A10G GPUs."""
import os
import torch
from transformers import (
Wav2Vec2Processor,
Wav2Vec2ForCTC,
AutoProcessor,
AutoModelForSpeechSeq2Seq,
WhisperProcessor,
)
from typing import Tuple, Optional
from data.manager import MODEL_OUTPUT_DIR
# Display label -> pretrained model identifier for every base model that can
# be selected for experimentation.
AVAILABLE_MODELS = {
    "Wav2Vec2 Base (960h)": "facebook/wav2vec2-base-960h",
    "Wav2Vec2 Large (960h)": "facebook/wav2vec2-large-960h",
    "Wav2Vec2 Base (100h)": "facebook/wav2vec2-base",
    "OWSM v3.1 Small (ESPnet)": "espnet/owsm_v3.1_ebf_small",
}

# Label of the locally fine-tuned model. It is deliberately not a key of
# AVAILABLE_MODELS; it is offered only when a fine-tuned model exists on disk.
FINE_TUNED_MODEL_KEY = "Fine-tuned Whisper (Caribbean Voices)"

# Single-slot cache: identity, processor and model of the last loaded model.
current_model_name = current_processor = current_model = None

# Use CUDA whenever torch reports a usable GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def is_fine_tuned_model_available() -> bool:
    """Report whether MODEL_OUTPUT_DIR appears to contain a fine-tuned model.

    The checks run from most to least specific: well-known model files at the
    top level, model files inside an ``espnet_model`` subdirectory, an ESPnet
    wrapper state dict, and finally any regular top-level file carrying a
    common model-related extension.

    Returns:
        True as soon as one check succeeds; False when the directory is
        missing, holds nothing recognisable, or cannot be listed.
    """
    root = MODEL_OUTPUT_DIR
    if not os.path.exists(root):
        return False

    # Well-known artefacts written directly into the output directory
    # (model_index.json covers ESPnet models).
    for marker in ("pytorch_model.bin", "model.safetensors", "config.json", "model_index.json"):
        if os.path.exists(os.path.join(root, marker)):
            return True

    # ESPnet models may live in their own subdirectory.
    nested_root = os.path.join(root, "espnet_model")
    if os.path.exists(nested_root):
        for marker in ("config.json", "pytorch_model.bin", "model.safetensors"):
            if os.path.exists(os.path.join(nested_root, marker)):
                return True

    # State dict saved by ESPnet wrapper models.
    if os.path.exists(os.path.join(root, "wrapper_state_dict.pt")):
        return True

    # Last resort: some models save differently, so accept any regular file
    # with a common model-related extension.
    if not os.path.isdir(root):
        return False
    try:
        entries = os.listdir(root)
    except (OSError, PermissionError):
        return False
    known_suffixes = ('.bin', '.safetensors', '.pt', '.json', '.txt')
    return any(
        entry.endswith(known_suffixes) and os.path.isfile(os.path.join(root, entry))
        for entry in entries
    )
def load_model(model_key: str) -> Tuple[Optional[object], object]:
    """Load a model and its processor, reusing the cached pair when possible.

    The most recently loaded model is remembered in the module-level
    ``current_model_name`` / ``current_processor`` / ``current_model``
    globals. Requesting the same model again returns the cached objects
    without reloading.

    Args:
        model_key: ``FINE_TUNED_MODEL_KEY`` or a key of ``AVAILABLE_MODELS``.

    Returns:
        ``(processor, model)``. For OWSM (ESPnet) models the processor is
        ``None`` because ESPnet's ``Speech2Text`` has no separate processor.

    Raises:
        FileNotFoundError: The fine-tuned model was requested but none was
            found in ``MODEL_OUTPUT_DIR``.
        ImportError: An OWSM model was requested but ESPnet could not be
            imported.
        KeyError: ``model_key`` is neither the fine-tuned key nor a key of
            ``AVAILABLE_MODELS``.
    """
    global current_model_name, current_processor, current_model

    # Fine-tuned model stored on local disk.
    if model_key == FINE_TUNED_MODEL_KEY:
        if not is_fine_tuned_model_available():
            raise FileNotFoundError(
                f"Fine-tuned model not found at {MODEL_OUTPUT_DIR}. "
                f"Please train a model first."
            )
        if current_model_name != MODEL_OUTPUT_DIR:
            print(f"Loading fine-tuned model from {MODEL_OUTPUT_DIR}...")
            # Prefer WhisperProcessor and fall back to AutoProcessor. Catch
            # Exception rather than everything so Ctrl-C still interrupts.
            try:
                processor = WhisperProcessor.from_pretrained(MODEL_OUTPUT_DIR)
            except Exception:
                processor = AutoProcessor.from_pretrained(MODEL_OUTPUT_DIR)
            # AutoModelForSpeechSeq2Seq is used for compatibility.
            model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_OUTPUT_DIR)
            model.to(device)
            model.eval()
            # Commit to the cache only once everything loaded successfully.
            current_processor = processor
            current_model = model
            current_model_name = MODEL_OUTPUT_DIR
            print(f"✓ Fine-tuned model loaded on {device}")
        return current_processor, current_model

    model_path = AVAILABLE_MODELS[model_key]

    # OWSM models are loaded through ESPnet instead of transformers.
    if "OWSM" in model_key:
        try:
            from espnet2.bin.s2t_inference import Speech2Text
            if current_model_name != model_path:
                # Importing flash_attn is only a probe to report whether the
                # optimisation is installed; the module itself is not used here.
                try:
                    import flash_attn  # noqa: F401
                    print("Loading OWSM model with Flash Attention optimization (A10G GPU)...")
                except ImportError:
                    print("⚠ Loading OWSM model without Flash Attention (performance may be suboptimal)")
                    print(" Flash Attention should be installed on A10G GPUs - check build logs")
                # Pass the device explicitly: without it the model is not
                # placed on `device`, contradicting the log line below.
                owsm_model = Speech2Text.from_pretrained(model_path, device=device)
                current_model = owsm_model
                # Drop any processor left over from a previously cached model.
                current_processor = None
                current_model_name = model_path
                print(f"✓ OWSM model loaded on {device}")
            return None, current_model  # No processor for ESPnet
        except ImportError as err:
            # Chain the original error so the missing module stays visible.
            raise ImportError(
                "ESPnet not installed. Install with: pip install espnet espnet_model_zoo"
            ) from err

    # Standard HuggingFace Wav2Vec2 models.
    if current_model_name != model_path:
        print(f"Loading model: {model_path}")
        # Load into locals first so a failure cannot leave the cache holding
        # a new processor next to the previous model.
        processor = Wav2Vec2Processor.from_pretrained(model_path)
        model = Wav2Vec2ForCTC.from_pretrained(model_path)
        model.to(device)
        model.eval()
        current_processor = processor
        current_model = model
        current_model_name = model_path
        print(f"Model loaded on {device}")
    return current_processor, current_model
def get_available_models():
    """Return the selectable model keys.

    The result always contains every key of AVAILABLE_MODELS, followed by
    FINE_TUNED_MODEL_KEY when a fine-tuned model is available.
    """
    keys = [*AVAILABLE_MODELS]
    # The fine-tuned entry is optional and goes last.
    if is_fine_tuned_model_available():
        keys.append(FINE_TUNED_MODEL_KEY)
    return keys
def load_checkpoint(checkpoint_path: str) -> Tuple[Optional[object], object]:
    """
    Load a model from a specific checkpoint directory or Hugging Face Hub.

    A path is treated as a Hub repo ID only when it does not exist locally
    and has the ``namespace/name`` shape of a repo ID. Any other missing path
    (for example a nested ``dir/sub/checkpoint-N``) is reported as a missing
    local checkpoint.

    Args:
        checkpoint_path: Path to checkpoint directory or HF Hub repo ID (e.g., "data/owsm_caribbean_finetuned/checkpoint-2000" or "shaun3141/caribbean-voices-owsm-finetuned")

    Returns:
        Tuple of (processor, model)

    Raises:
        ValueError: A Hub repo ID was given but ``HF_TOKEN`` is not set.
        FileNotFoundError: A local path was given but does not exist.
    """
    hf_token = os.getenv("HF_TOKEN")
    # Hub repo IDs are "namespace/name"; anything that exists locally wins.
    is_hf_hub = _looks_like_hub_repo_id(checkpoint_path) and not os.path.exists(checkpoint_path)
    if is_hf_hub:
        print(f"Loading checkpoint from Hugging Face Hub: {checkpoint_path}...")
        if not hf_token:
            raise ValueError(f"HF_TOKEN required to load model from Hub: {checkpoint_path}")
    else:
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
        print(f"Loading checkpoint from {checkpoint_path}...")

    # The token is only forwarded for Hub loads.
    auth_token = hf_token if is_hf_hub else None

    # Prefer WhisperProcessor and fall back to AutoProcessor. Catch Exception
    # rather than everything so Ctrl-C still interrupts the load.
    try:
        processor = WhisperProcessor.from_pretrained(checkpoint_path, token=auth_token)
    except Exception:
        processor = AutoProcessor.from_pretrained(checkpoint_path, token=auth_token)

    model = AutoModelForSpeechSeq2Seq.from_pretrained(checkpoint_path, token=auth_token)
    model.to(device)
    model.eval()
    print(f"✓ Checkpoint loaded on {device}")
    return processor, model


def _looks_like_hub_repo_id(path: str) -> bool:
    """Return True when ``path`` has the ``namespace/name`` shape of a Hub repo ID."""
    # Explicitly local spellings (relative, absolute, home, Windows) are not repo IDs.
    if "\\" in path or path.startswith((".", "/", "~")):
        return False
    namespace, _, name = path.partition("/")
    # Exactly one separator with a non-empty part on each side.
    return bool(namespace) and bool(name) and "/" not in name
def get_available_checkpoints(include_base_models: bool = True) -> list:
    """
    Get list of available checkpoints in MODEL_OUTPUT_DIR and Hugging Face Hub.

    Hub models are only looked up when the ``HF_TOKEN`` environment variable
    is set. Hub and filesystem errors are reported with a printed warning and
    never raised, so the function always returns whatever it could collect.

    Args:
        include_base_models: If True, include base models from AVAILABLE_MODELS

    Returns:
        List of checkpoint names (e.g., ["Final Model", "checkpoint-1000", "checkpoint-2000", "OWSM v3.1 Small (ESPnet)"]),
        ordered as: base models (by name), versioned Hub models (newest
        first), local ``checkpoint-N`` directories (by step), "Final Model",
        then any other local model directories (by name).
    """
    from datetime import datetime

    checkpoints = []
    if include_base_models:
        checkpoints.extend(AVAILABLE_MODELS)

    # Versioned models uploaded to the Hugging Face Hub (persistent storage).
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        try:
            from huggingface_hub import HfApi
            api = HfApi(token=hf_token)
        except Exception as exc:
            print(f"⚠ Warning: Could not connect to Hugging Face Hub: {exc}")
        else:
            hub_families = (
                ("caribbean-voices-owsm-finetuned", "OWSM Finetuned"),
                ("caribbean-voices-whisper-finetuned", "Whisper Finetuned"),
            )
            # Each family is queried independently so one failure does not
            # hide the other family's models.
            for repo_prefix, label in hub_families:
                try:
                    _append_hub_checkpoints(checkpoints, api, hf_token, repo_prefix, label)
                except Exception as exc:
                    print(f"⚠ Warning: Could not list {label} models on Hugging Face Hub: {exc}")

    # Final model saved directly in MODEL_OUTPUT_DIR.
    if is_fine_tuned_model_available():
        checkpoints.append("Final Model")

    # Intermediate checkpoints and other model directories on local disk.
    if os.path.isdir(MODEL_OUTPUT_DIR):
        weight_markers = ("config.json", "pytorch_model.bin", "model.safetensors")
        try:
            for item in os.listdir(MODEL_OUTPUT_DIR):
                candidate = os.path.join(MODEL_OUTPUT_DIR, item)
                if not os.path.isdir(candidate):
                    continue
                if item.startswith("checkpoint-"):
                    markers = weight_markers
                elif item != "espnet_model":  # Skip known subdirectories
                    # Directories outside the checkpoint-* naming scheme also
                    # qualify through an ESPnet wrapper state dict.
                    markers = weight_markers + ("wrapper_state_dict.pt",)
                else:
                    continue
                if any(os.path.exists(os.path.join(candidate, marker)) for marker in markers):
                    checkpoints.append(item)
        except (OSError, PermissionError) as e:
            print(f"⚠ Warning: Could not list checkpoints in {MODEL_OUTPUT_DIR}: {e}")

    def get_sort_key(name):
        """Map a checkpoint name to a (group, position) tuple for sorting."""
        if name in AVAILABLE_MODELS:
            return (0, name)  # Base models first
        if name.startswith(("OWSM Finetuned (", "Whisper Finetuned (")):
            try:
                date_str = name.split("(")[1].rstrip(")")
                dt = datetime.strptime(date_str, "%Y-%m-%d %H:%M")
                # Negative timestamp so the newest model comes first.
                return (1, -dt.timestamp())
            except (IndexError, ValueError, OverflowError, OSError):
                return (1, 0)
        if name == "Final Model":
            return (3, float('inf'))  # Final model after local checkpoints
        if name.startswith("checkpoint-"):
            try:
                return (2, int(name.split("-")[1]))  # Sorted by training step
            except (IndexError, ValueError):
                return (2, 0)
        return (4, name)  # Unknown items last

    # Always sort, including when MODEL_OUTPUT_DIR does not exist, so the
    # ordering does not depend on whether a local model directory is present.
    checkpoints.sort(key=get_sort_key)
    return checkpoints


def _append_hub_checkpoints(checkpoints, api, hf_token, repo_prefix, label):
    """Append a display name for every versioned ``<repo_prefix>-<timestamp>`` Hub repo."""
    from datetime import datetime

    marker = f"{repo_prefix}-"
    for model_info in api.list_models(author="shaun3141", search=repo_prefix, token=hf_token):
        repo_id = model_info.id
        if marker not in repo_id:
            continue
        # The timestamp is everything after the prefix. It contains a "-"
        # itself (YYYYMMDD-HHMMSS), so it cannot be taken from the last
        # "-"-separated field.
        timestamp_str = repo_id.split(marker, 1)[1]
        try:
            dt = datetime.strptime(timestamp_str, "%Y%m%d-%H%M%S")
        except ValueError:
            # Unexpected suffix: show it verbatim instead of dropping the repo.
            checkpoints.append(f"{label} ({timestamp_str})")
        else:
            readable_date = dt.strftime("%Y-%m-%d %H:%M")
            checkpoints.append(f"{label} ({readable_date})")
def get_checkpoint_path(checkpoint_name: str) -> str:
    """
    Get full path to a checkpoint directory or HF Hub repo ID.

    Hub names carry only minute precision while repo names carry seconds, so
    after trying the exact repo (seconds ``00``) the lookup searches for a
    repo created in the same minute, and only then for one created on the
    same day. When several repos match, the latest timestamp wins.

    Args:
        checkpoint_name: Name like "Final Model", "OWSM Finetuned (2024-01-15 14:30)", or "checkpoint-2000"

    Returns:
        Full path to checkpoint directory or HF Hub repo ID

    Raises:
        ValueError: ``HF_TOKEN`` is not set, the date inside the name cannot
            be parsed, or the Hub could not be queried.
        FileNotFoundError: No Hub repo matches the name. The raised error is
            also a ``ValueError`` so existing ``except ValueError`` handlers
            keep working.
    """
    if checkpoint_name == "Final Model":
        return MODEL_OUTPUT_DIR
    if not checkpoint_name.startswith(("OWSM Finetuned (", "Whisper Finetuned (")):
        return os.path.join(MODEL_OUTPUT_DIR, checkpoint_name)

    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN required to load models from Hub")

    # Parse the readable date first: "OWSM Finetuned (2024-01-15 14:30)".
    # Only genuine parse failures are reported as parse errors.
    from datetime import datetime
    try:
        date_str = checkpoint_name.split("(")[1].rstrip(")")
        dt = datetime.strptime(date_str, "%Y-%m-%d %H:%M")
    except (IndexError, ValueError) as e:
        raise ValueError(f"Could not parse checkpoint name: {checkpoint_name}. Error: {e}") from e

    from huggingface_hub import HfApi
    api = HfApi(token=hf_token)

    family = "owsm" if checkpoint_name.startswith("OWSM") else "whisper"
    search_term = f"caribbean-voices-{family}-finetuned"

    # Exact match: only succeeds when the repo's seconds happen to be "00".
    exact_repo = f"shaun3141/{search_term}-{dt.strftime('%Y%m%d-%H%M%S')}"
    try:
        api.model_info(exact_repo, token=hf_token)
        return exact_repo
    except Exception:
        pass  # Not found (or not reachable): fall through to the search below.

    try:
        repo_ids = [
            m.id for m in api.list_models(author="shaun3141", search=search_term, token=hf_token)
        ]
    except Exception as e:
        raise ValueError(
            f"Could not query Hugging Face Hub for checkpoint: {checkpoint_name}. Error: {e}"
        ) from e

    same_minute_marker = f"{search_term}-{dt.strftime('%Y%m%d-%H%M')}"
    same_day_marker = dt.strftime("%Y%m%d")
    same_minute = [repo_id for repo_id in repo_ids if same_minute_marker in repo_id]
    same_day = [repo_id for repo_id in repo_ids if same_day_marker in repo_id]
    for candidates in (same_minute, same_day):
        if candidates:
            # Timestamps sort lexicographically, so max() is the most recent.
            return max(candidates)

    raise _HubCheckpointNotFoundError(f"Model not found: {checkpoint_name}")


class _HubCheckpointNotFoundError(FileNotFoundError, ValueError):
    """No Hub repo matches a checkpoint name; catchable as either base type."""