|
|
"""Model loading and caching - optimized for NVIDIA A10G GPUs.""" |
|
|
import os |
|
|
import torch |
|
|
from transformers import ( |
|
|
Wav2Vec2Processor, |
|
|
Wav2Vec2ForCTC, |
|
|
AutoProcessor, |
|
|
AutoModelForSpeechSeq2Seq, |
|
|
WhisperProcessor, |
|
|
) |
|
|
from typing import Tuple, Optional |
|
|
from data.manager import MODEL_OUTPUT_DIR |
|
|
|
|
|
|
|
|
# Selectable base models: UI label -> Hugging Face Hub repo ID.
# Keys containing "OWSM" are loaded through ESPnet in load_model; all other
# entries are loaded as Wav2Vec2Processor + Wav2Vec2ForCTC, so they must point
# at CTC fine-tuned checkpoints that ship a tokenizer.
AVAILABLE_MODELS = {
    "Wav2Vec2 Base (960h)": "facebook/wav2vec2-base-960h",
    "Wav2Vec2 Large (960h)": "facebook/wav2vec2-large-960h",
    # The 100h LibriSpeech CTC fine-tune. "facebook/wav2vec2-base" is the
    # self-supervised pretrained-only checkpoint (no tokenizer, no fine-tuned
    # CTC head) and cannot transcribe with the Wav2Vec2 loader.
    "Wav2Vec2 Base (100h)": "facebook/wav2vec2-base-100h",
    "OWSM v3.1 Small (ESPnet)": "espnet/owsm_v3.1_ebf_small",
}

# Picker label for the locally fine-tuned model; load_model resolves it to
# MODEL_OUTPUT_DIR instead of looking it up in AVAILABLE_MODELS.
FINE_TUNED_MODEL_KEY = "Fine-tuned Whisper (Caribbean Voices)"

# Single-entry cache maintained by load_model: the identifier of the loaded
# model (a Hub repo ID or MODEL_OUTPUT_DIR) and its processor/model objects.
current_model_name = None
current_processor = None
current_model = None

# Inference device: CUDA when a GPU is visible to torch, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
def is_fine_tuned_model_available() -> bool:
    """Report whether MODEL_OUTPUT_DIR appears to hold a fine-tuned model.

    Probes run from most to least specific and the first hit wins:

    1. ``pytorch_model.bin``, ``model.safetensors``, ``config.json`` or
       ``model_index.json`` in the directory root;
    2. ``config.json``, ``pytorch_model.bin`` or ``model.safetensors`` inside
       an ``espnet_model`` sub-directory;
    3. ``wrapper_state_dict.pt`` in the root;
    4. lenient fallback: any regular file in the root whose name ends in
       ``.bin``, ``.safetensors``, ``.pt``, ``.json`` or ``.txt``.

    Returns:
        True when a probe matches; False when none does, when the directory
        does not exist, or when it cannot be listed.
    """
    root = MODEL_OUTPUT_DIR
    if not os.path.exists(root):
        return False

    def _has_any(folder, filenames):
        # True if at least one of the named entries exists inside `folder`.
        return any(os.path.exists(os.path.join(folder, fname)) for fname in filenames)

    if _has_any(root, ("pytorch_model.bin", "model.safetensors", "config.json", "model_index.json")):
        return True

    espnet_folder = os.path.join(root, "espnet_model")
    if os.path.exists(espnet_folder) and _has_any(
        espnet_folder, ("config.json", "pytorch_model.bin", "model.safetensors")
    ):
        return True

    if os.path.exists(os.path.join(root, "wrapper_state_dict.pt")):
        return True

    if not os.path.isdir(root):
        return False

    try:
        entries = os.listdir(root)
    except (OSError, PermissionError):
        return False

    # Last resort: accept any plausibly model-related file in the root.
    suffixes = (".bin", ".safetensors", ".pt", ".json", ".txt")
    return any(
        entry.endswith(suffixes) and os.path.isfile(os.path.join(root, entry))
        for entry in entries
    )
|
|
|
|
|
|
|
|
def load_model(model_key: str) -> Tuple[Optional[object], object]:
    """Load a model and processor, caching them for efficiency.

    The most recently loaded model is kept in the module globals
    ``current_model_name``, ``current_processor`` and ``current_model``. A
    repeated request for the same model returns the cached objects. The cache
    is only updated after every load step succeeds, so a failed load never
    leaves a processor paired with an unrelated model.

    Args:
        model_key: A key of ``AVAILABLE_MODELS`` or ``FINE_TUNED_MODEL_KEY``.

    Returns:
        ``(processor, model)``. For OWSM keys the processor is ``None`` and
        the model is an ESPnet ``Speech2Text`` object.

    Raises:
        FileNotFoundError: ``FINE_TUNED_MODEL_KEY`` was requested but no
            fine-tuned model was detected in ``MODEL_OUTPUT_DIR``.
        ImportError: an OWSM model was requested but ESPnet (or one of its
            loader dependencies) is not importable.
        KeyError: ``model_key`` is not a known model key.
    """
    global current_model_name, current_processor, current_model

    if model_key == FINE_TUNED_MODEL_KEY:
        if not is_fine_tuned_model_available():
            raise FileNotFoundError(
                f"Fine-tuned model not found at {MODEL_OUTPUT_DIR}. "
                f"Please train a model first."
            )

        if current_model_name != MODEL_OUTPUT_DIR:
            print(f"Loading fine-tuned model from {MODEL_OUTPUT_DIR}...")
            try:
                processor = WhisperProcessor.from_pretrained(MODEL_OUTPUT_DIR)
            except Exception:
                # Not loadable as a Whisper processor: let AutoProcessor pick
                # the processor class from the saved configuration.
                processor = AutoProcessor.from_pretrained(MODEL_OUTPUT_DIR)

            model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_OUTPUT_DIR)
            model.to(device)
            model.eval()

            # Commit to the cache only once everything loaded successfully.
            current_processor = processor
            current_model = model
            current_model_name = MODEL_OUTPUT_DIR
            print(f"✓ Fine-tuned model loaded on {device}")

        return current_processor, current_model

    model_path = AVAILABLE_MODELS[model_key]

    if "OWSM" in model_key:
        # ImportErrors raised while loading (not only by the import below) are
        # reported as a missing-ESPnet problem, since they stem from the
        # ESPnet / espnet_model_zoo installation.
        try:
            from espnet2.bin.s2t_inference import Speech2Text

            if current_model_name != model_path:
                try:
                    import flash_attn  # noqa: F401  (availability probe only)
                    print("Loading OWSM model with Flash Attention optimization (A10G GPU)...")
                except ImportError:
                    print("⚠ Loading OWSM model without Flash Attention (performance may be suboptimal)")
                    print(" Flash Attention should be installed on A10G GPUs - check build logs")

                # Speech2Text is built on CPU unless a device is passed, so
                # forward the detected device explicitly.
                owsm_model = Speech2Text.from_pretrained(model_path, device=device)

                current_processor = None  # OWSM has no separate processor
                current_model = owsm_model
                current_model_name = model_path
                print(f"✓ OWSM model loaded on {device}")

            return None, current_model

        except ImportError as err:
            raise ImportError("ESPnet not installed. Install with: pip install espnet espnet_model_zoo") from err

    if current_model_name != model_path:
        print(f"Loading model: {model_path}")
        processor = Wav2Vec2Processor.from_pretrained(model_path)
        model = Wav2Vec2ForCTC.from_pretrained(model_path)
        model.to(device)
        model.eval()

        # Commit to the cache only once both objects loaded successfully.
        current_processor = processor
        current_model = model
        current_model_name = model_path
        print(f"Model loaded on {device}")

    return current_processor, current_model
|
|
|
|
|
|
|
|
def get_available_models():
    """Return the model keys that can currently be selected.

    Every key of ``AVAILABLE_MODELS`` is included in definition order, and
    ``FINE_TUNED_MODEL_KEY`` is appended last only when
    ``is_fine_tuned_model_available()`` reports a model in MODEL_OUTPUT_DIR.
    """
    keys = [*AVAILABLE_MODELS]
    if not is_fine_tuned_model_available():
        return keys
    return keys + [FINE_TUNED_MODEL_KEY]
|
|
|
|
|
|
|
|
def load_checkpoint(checkpoint_path: str) -> Tuple[Optional[object], object]:
    """
    Load a model from a specific checkpoint directory or Hugging Face Hub.

    A value is treated as a Hub repo ID only when it does not exist locally
    and has the ``namespace/name`` shape: exactly one ``/`` and not starting
    with ``/`` or ``.``. Anything else is a local path. A missing local
    checkpoint such as ``data/run/checkpoint-2000`` therefore raises
    ``FileNotFoundError`` instead of being looked up on the Hub.
    Unlike load_model, the result is not cached.

    Args:
        checkpoint_path: Path to checkpoint directory or HF Hub repo ID (e.g., "data/owsm_caribbean_finetuned/checkpoint-2000" or "shaun3141/caribbean-voices-owsm-finetuned")

    Returns:
        Tuple of (processor, model); the model is moved to ``device`` and
        put in eval mode.

    Raises:
        ValueError: a Hub repo ID was given but HF_TOKEN is not set.
        FileNotFoundError: a local checkpoint path does not exist.
    """
    hf_token = os.getenv("HF_TOKEN")

    # Hub repo IDs are exactly "namespace/name"; deeper or explicitly
    # relative/absolute paths are always local.
    is_hf_hub = (
        not os.path.exists(checkpoint_path)
        and checkpoint_path.count("/") == 1
        and not checkpoint_path.startswith(("/", "."))
    )

    if is_hf_hub:
        print(f"Loading checkpoint from Hugging Face Hub: {checkpoint_path}...")
        if not hf_token:
            raise ValueError(f"HF_TOKEN required to load model from Hub: {checkpoint_path}")
        token = hf_token
    else:
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
        print(f"Loading checkpoint from {checkpoint_path}...")
        token = None  # local files: no credentials needed

    try:
        processor = WhisperProcessor.from_pretrained(checkpoint_path, token=token)
    except Exception:
        # Not loadable as a Whisper processor: let AutoProcessor pick the class.
        processor = AutoProcessor.from_pretrained(checkpoint_path, token=token)

    model = AutoModelForSpeechSeq2Seq.from_pretrained(checkpoint_path, token=token)
    model.to(device)
    model.eval()

    print(f"✓ Checkpoint loaded on {device}")
    return processor, model
|
|
|
|
|
|
|
|
def _hub_checkpoint_names(hf_token: str) -> list:
    """Return display names for fine-tuned models published on the Hugging Face Hub.

    Repos named ``shaun3141/<prefix>-<YYYYmmdd-HHMMSS>`` are shown as
    ``"<label> (YYYY-mm-dd HH:MM)"``. When the suffix after the prefix is not
    such a timestamp, the name is ``"<label> (<suffix>)"``. Each model family
    is listed independently and best-effort: a failure prints a warning and
    contributes no names.
    """
    from datetime import datetime

    from huggingface_hub import HfApi

    api = HfApi(token=hf_token)
    names = []
    for label, prefix in (
        ("OWSM Finetuned", "caribbean-voices-owsm-finetuned"),
        ("Whisper Finetuned", "caribbean-voices-whisper-finetuned"),
    ):
        marker = f"{prefix}-"
        try:
            for model_info in api.list_models(author="shaun3141", search=prefix, token=hf_token):
                repo_id = model_info.id
                if marker not in repo_id:
                    continue
                # The timestamp itself contains a '-', so take everything after
                # the prefix rather than the last '-'-separated piece.
                suffix = repo_id.split(marker, 1)[1]
                try:
                    created = datetime.strptime(suffix, "%Y%m%d-%H%M%S")
                    names.append(f"{label} ({created:%Y-%m-%d %H:%M})")
                except ValueError:
                    names.append(f"{label} ({suffix})")
        except Exception as err:
            print(f"⚠ Warning: Could not list {label} models on the Hub: {err}")
    return names


def _local_checkpoint_names() -> list:
    """Return the loadable sub-directories of MODEL_OUTPUT_DIR.

    ``checkpoint-*`` directories qualify when they contain ``config.json``,
    ``pytorch_model.bin`` or ``model.safetensors``. Any other directory except
    ``espnet_model`` also qualifies through ``wrapper_state_dict.pt``.
    """
    names = []
    try:
        entries = os.listdir(MODEL_OUTPUT_DIR)
    except (OSError, PermissionError) as e:
        print(f"⚠ Warning: Could not list checkpoints in {MODEL_OUTPUT_DIR}: {e}")
        return names

    hf_files = ("config.json", "pytorch_model.bin", "model.safetensors")
    for item in entries:
        path = os.path.join(MODEL_OUTPUT_DIR, item)
        if not os.path.isdir(path):
            continue
        if item.startswith("checkpoint-"):
            required = hf_files
        elif item != "espnet_model":
            required = hf_files + ("wrapper_state_dict.pt",)
        else:
            continue
        if any(os.path.exists(os.path.join(path, fname)) for fname in required):
            names.append(item)
    return names


def _checkpoint_sort_key(name: str) -> tuple:
    """Sort key: base models (by name), Hub fine-tunes (newest first), local
    ``checkpoint-N`` (by step), "Final Model", then anything else (by name)."""
    from datetime import datetime

    if name in AVAILABLE_MODELS:
        return (0, name)
    if name.startswith(("OWSM Finetuned (", "Whisper Finetuned (")):
        try:
            date_str = name.split("(")[1].rstrip(")")
            created = datetime.strptime(date_str, "%Y-%m-%d %H:%M")
            return (1, -created.timestamp())
        except (ValueError, OverflowError, OSError):
            # Undated Hub names sort after all dated ones.
            return (1, 0)
    if name == "Final Model":
        return (3, float("inf"))
    if name.startswith("checkpoint-"):
        try:
            return (2, int(name.split("-")[1]))
        except ValueError:
            return (2, 0)
    return (4, name)


def get_available_checkpoints(include_base_models: bool = True) -> list:
    """
    Get list of available checkpoints in MODEL_OUTPUT_DIR and Hugging Face Hub.

    Hub models are only listed when HF_TOKEN is set, and Hub failures never
    prevent local checkpoints from being listed. The result is always sorted:
    base models first, then Hub fine-tunes (newest first), local
    ``checkpoint-N`` directories by step, "Final Model", and any other local
    model directories.

    Args:
        include_base_models: If True, include base models from AVAILABLE_MODELS

    Returns:
        List of checkpoint names (e.g., ["OWSM v3.1 Small (ESPnet)", "checkpoint-1000", "checkpoint-2000", "Final Model"])
    """
    checkpoints = []

    if include_base_models:
        checkpoints.extend(AVAILABLE_MODELS)

    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        try:
            checkpoints.extend(_hub_checkpoint_names(hf_token))
        except Exception as err:
            # Best-effort: a missing huggingface_hub package or a Hub outage
            # must not break the local listing.
            print(f"⚠ Warning: Could not query the Hugging Face Hub: {err}")

    if os.path.exists(MODEL_OUTPUT_DIR):
        if is_fine_tuned_model_available():
            checkpoints.append("Final Model")
        if os.path.isdir(MODEL_OUTPUT_DIR):
            checkpoints.extend(_local_checkpoint_names())

    # Always sort, so the order does not depend on whether MODEL_OUTPUT_DIR exists.
    checkpoints.sort(key=_checkpoint_sort_key)
    return checkpoints
|
|
|
|
|
|
|
|
def get_checkpoint_path(checkpoint_name: str) -> str:
    """
    Get full path to a checkpoint directory or HF Hub repo ID.

    Hub display names look like ``"<Family> Finetuned (<value>)"``. Usually
    ``<value>`` is ``YYYY-mm-dd HH:MM``, the repo timestamp truncated to
    minutes. Resolution works as follows:

    - Try the exact repo
      ``shaun3141/caribbean-voices-<family>-finetuned-YYYYmmdd-HHMM00``.
    - If that fails, search the account for a repo whose name contains the
      timestamp to the minute.
    - A ``<value>`` that is not a date is matched verbatim against the text
      following ``caribbean-voices-<family>-finetuned-``.

    Args:
        checkpoint_name: Name like "Final Model", "OWSM Finetuned (2024-01-15 14:30)", or "checkpoint-2000"

    Returns:
        Full path to checkpoint directory or HF Hub repo ID

    Raises:
        ValueError: for Hub names, when HF_TOKEN is not set, the name has an
            empty value in parentheses, the Hub lookup fails, or no repo
            matches.
    """
    if checkpoint_name == "Final Model":
        return MODEL_OUTPUT_DIR

    if not checkpoint_name.startswith(("OWSM Finetuned (", "Whisper Finetuned (")):
        # Any other name is a sub-directory of MODEL_OUTPUT_DIR (e.g. checkpoint-2000).
        return os.path.join(MODEL_OUTPUT_DIR, checkpoint_name)

    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError("HF_TOKEN required to load models from Hub")

    from huggingface_hub import HfApi
    from datetime import datetime

    api = HfApi(token=hf_token)
    owner = "shaun3141"
    prefix = (
        "caribbean-voices-owsm-finetuned"
        if checkpoint_name.startswith("OWSM")
        else "caribbean-voices-whisper-finetuned"
    )

    value = checkpoint_name.split("(", 1)[1].rstrip(")").strip()
    if not value:
        raise ValueError(f"Could not parse checkpoint name: {checkpoint_name}")

    try:
        created = datetime.strptime(value, "%Y-%m-%d %H:%M")
    except ValueError:
        created = None

    if created is not None:
        exact_repo = f"{owner}/{prefix}-{created:%Y%m%d-%H%M%S}"
        try:
            api.model_info(exact_repo, token=hf_token)
            return exact_repo
        except Exception:
            # The display name drops seconds, so the exact guess often misses;
            # fall back to searching at minute precision.
            pass
        fragment = f"{created:%Y%m%d-%H%M}"
    else:
        fragment = value

    needle = f"{prefix}-{fragment}"
    try:
        matches = [
            m.id
            for m in api.list_models(author=owner, search=prefix, token=hf_token)
            if needle in m.id
        ]
    except Exception as err:
        raise ValueError(f"Could not look up {checkpoint_name} on the Hub: {err}") from err

    if not matches:
        raise ValueError(f"Model not found on the Hub: {checkpoint_name}")

    # Several repos may share the same minute; pick the latest deterministically.
    return max(matches)
|
|
|
|
|
|