"""Model loading and caching - optimized for NVIDIA A10G GPUs."""
import os
import torch
from transformers import (
    Wav2Vec2Processor, 
    Wav2Vec2ForCTC,
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    WhisperProcessor,
)
from typing import Tuple, Optional
from data.manager import MODEL_OUTPUT_DIR

# Available models for experimentation
AVAILABLE_MODELS = {
    "Wav2Vec2 Base (960h)": "facebook/wav2vec2-base-960h",
    "Wav2Vec2 Large (960h)": "facebook/wav2vec2-large-960h",
    "Wav2Vec2 Base (100h)": "facebook/wav2vec2-base",
    "OWSM v3.1 Small (ESPnet)": "espnet/owsm_v3.1_ebf_small",
}

# Fine-tuned model key (added dynamically if model exists)
FINE_TUNED_MODEL_KEY = "Fine-tuned Whisper (Caribbean Voices)"

# Global variables for model caching
current_model_name = None
current_processor = None
current_model = None
# Optimized for CUDA on A10G GPUs
device = "cuda" if torch.cuda.is_available() else "cpu"


def is_fine_tuned_model_available() -> bool:
    """Check if a fine-tuned model exists in MODEL_OUTPUT_DIR"""
    if not os.path.exists(MODEL_OUTPUT_DIR):
        return False
    
    # Check for model files directly in MODEL_OUTPUT_DIR
    model_files = [
        os.path.join(MODEL_OUTPUT_DIR, "pytorch_model.bin"),
        os.path.join(MODEL_OUTPUT_DIR, "model.safetensors"),
        os.path.join(MODEL_OUTPUT_DIR, "config.json"),
        os.path.join(MODEL_OUTPUT_DIR, "model_index.json"),  # For ESPnet models
    ]
    
    if any(os.path.exists(f) for f in model_files):
        return True
    
    # Check for ESPnet model in subdirectory
    espnet_dir = os.path.join(MODEL_OUTPUT_DIR, "espnet_model")
    if os.path.exists(espnet_dir):
        espnet_files = [
            os.path.join(espnet_dir, "config.json"),
            os.path.join(espnet_dir, "pytorch_model.bin"),
            os.path.join(espnet_dir, "model.safetensors"),
        ]
        if any(os.path.exists(f) for f in espnet_files):
            return True
    
    # Check for wrapper state dict (ESPnet wrapper models)
    wrapper_state = os.path.join(MODEL_OUTPUT_DIR, "wrapper_state_dict.pt")
    if os.path.exists(wrapper_state):
        return True
    
    # If directory exists and has any content, consider it a valid model directory
    # (some models might save differently)
    if os.path.isdir(MODEL_OUTPUT_DIR):
        try:
            contents = os.listdir(MODEL_OUTPUT_DIR)
            # If directory has any files (not just empty), it might be a model
            if contents:
                # Check for any common model file extensions
                model_extensions = ['.bin', '.safetensors', '.pt', '.json', '.txt']
                for item in contents:
                    item_path = os.path.join(MODEL_OUTPUT_DIR, item)
                    if os.path.isfile(item_path):
                        if any(item.endswith(ext) for ext in model_extensions):
                            return True
        except (OSError, PermissionError):
            pass
    
    return False


def load_model(model_key: str) -> Tuple[Optional[object], object]:
    """Load a model and processor, caching them for efficiency"""
    global current_model_name, current_processor, current_model
    
    # Handle fine-tuned model
    if model_key == FINE_TUNED_MODEL_KEY:
        if not is_fine_tuned_model_available():
            raise FileNotFoundError(
                f"Fine-tuned model not found at {MODEL_OUTPUT_DIR}. "
                f"Please train a model first."
            )
        
        if current_model_name != MODEL_OUTPUT_DIR:
            print(f"Loading fine-tuned model from {MODEL_OUTPUT_DIR}...")
            
            # Try the Whisper-specific processor first, falling back to AutoProcessor
            try:
                processor = WhisperProcessor.from_pretrained(MODEL_OUTPUT_DIR)
            except Exception:
                processor = AutoProcessor.from_pretrained(MODEL_OUTPUT_DIR)
            
            # Load model - use AutoModelForSpeechSeq2Seq for compatibility
            model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_OUTPUT_DIR)
            
            model.to(device)
            model.eval()
            
            current_processor = processor
            current_model = model
            current_model_name = MODEL_OUTPUT_DIR
            print(f"✓ Fine-tuned model loaded on {device}")
        
        return current_processor, current_model
    
    model_path = AVAILABLE_MODELS[model_key]
    
    # Handle OWSM models differently - optimized for A10G GPUs with Flash Attention
    if "OWSM" in model_key:
        try:
            from espnet2.bin.s2t_inference import Speech2Text
            if current_model_name != model_path:
                # Flash Attention should be available on A10G GPUs
                try:
                    import flash_attn
                    print("Loading OWSM model with Flash Attention optimization (A10G GPU)...")
                except ImportError:
                    print("⚠ Loading OWSM model without Flash Attention (performance may be suboptimal)")
                    print("  Flash Attention should be installed on A10G GPUs - check build logs")
                
                # Load onto the selected device so the log line below is accurate
                current_model = Speech2Text.from_pretrained(model_path, device=device)
                current_model_name = model_path
                print(f"✓ OWSM model loaded on {device}")
            return None, current_model  # No processor for ESPnet
        except ImportError:
            raise ImportError("ESPnet not installed. Install with: pip install espnet espnet_model_zoo")
    
    # Standard HuggingFace models
    if current_model_name != model_path:
        print(f"Loading model: {model_path}")
        current_processor = Wav2Vec2Processor.from_pretrained(model_path)
        current_model = Wav2Vec2ForCTC.from_pretrained(model_path)
        current_model.to(device)
        current_model.eval()
        current_model_name = model_path
        print(f"Model loaded on {device}")
    
    return current_processor, current_model
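

# Example: transcribing a clip with a model returned by load_model().
# A minimal usage sketch (kept in comments so it never runs on import);
# `audio` is an assumed 16 kHz mono float32 numpy array, e.g. from librosa.
#
#     processor, model = load_model("Wav2Vec2 Base (960h)")
#     inputs = processor(audio, sampling_rate=16000, return_tensors="pt").to(device)
#     with torch.no_grad():
#         logits = model(**inputs).logits
#     pred_ids = torch.argmax(logits, dim=-1)
#     text = processor.batch_decode(pred_ids)[0]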


def get_available_models():
    """Get list of available model keys, including fine-tuned model if available"""
    models = list(AVAILABLE_MODELS.keys())
    
    # Add fine-tuned model if it exists
    if is_fine_tuned_model_available():
        models.append(FINE_TUNED_MODEL_KEY)
    
    return models
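

# Example: feeding the dynamic model list into a UI picker. A minimal
# sketch; the Gradio dropdown is an assumption about the front end, not
# something this module prescribes.
#
#     import gradio as gr
#     model_dropdown = gr.Dropdown(choices=get_available_models(), label="Model")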


def load_checkpoint(checkpoint_path: str) -> Tuple[Optional[object], object]:
    """
    Load a model from a specific checkpoint directory or Hugging Face Hub.
    
    Args:
        checkpoint_path: Path to checkpoint directory or HF Hub repo ID (e.g., "data/owsm_caribbean_finetuned/checkpoint-2000" or "shaun3141/caribbean-voices-owsm-finetuned")
    
    Returns:
        Tuple of (processor, model)
    """
    hf_token = os.getenv("HF_TOKEN")
    
    # Check if this is an HF Hub repo ID (contains / and doesn't exist locally)
    is_hf_hub = "/" in checkpoint_path and not os.path.exists(checkpoint_path)
    
    if is_hf_hub:
        print(f"Loading checkpoint from Hugging Face Hub: {checkpoint_path}...")
        if not hf_token:
            raise ValueError(f"HF_TOKEN required to load model from Hub: {checkpoint_path}")
    else:
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")
        print(f"Loading checkpoint from {checkpoint_path}...")
    
    # Try the Whisper-specific processor first, falling back to AutoProcessor
    try:
        processor = WhisperProcessor.from_pretrained(checkpoint_path, token=hf_token if is_hf_hub else None)
    except Exception:
        processor = AutoProcessor.from_pretrained(checkpoint_path, token=hf_token if is_hf_hub else None)
    
    # Load model
    model = AutoModelForSpeechSeq2Seq.from_pretrained(checkpoint_path, token=hf_token if is_hf_hub else None)
    
    model.to(device)
    model.eval()
    
    print(f"✓ Checkpoint loaded on {device}")
    
    return processor, model
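

# Example: generating a transcript with a checkpoint loaded above.
# A minimal sketch (kept in comments so it never runs on import); `audio`
# is an assumed 16 kHz mono float32 numpy array, and loading from the Hub
# requires HF_TOKEN in the environment.
#
#     processor, model = load_checkpoint("shaun3141/caribbean-voices-owsm-finetuned")
#     features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features.to(device)
#     with torch.no_grad():
#         generated_ids = model.generate(features)
#     text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]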


def get_available_checkpoints(include_base_models: bool = True) -> list:
    """
    Get list of available checkpoints in MODEL_OUTPUT_DIR and Hugging Face Hub.
    
    Args:
        include_base_models: If True, include base models from AVAILABLE_MODELS
    
    Returns:
        List of checkpoint names (e.g., ["Final Model", "checkpoint-1000", "checkpoint-2000", "OWSM v3.1 Small (ESPnet)"])
    """
    checkpoints = []
    
    # Add base models first if requested
    if include_base_models:
        checkpoints.extend(list(AVAILABLE_MODELS.keys()))
    
    # Check Hugging Face Hub for uploaded models (persistent storage)
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        try:
            from huggingface_hub import HfApi
            from datetime import datetime
            api = HfApi(token=hf_token)
            
            # Search for all versioned OWSM models
            try:
                models = api.list_models(
                    author="shaun3141",
                    search="caribbean-voices-owsm-finetuned",
                    token=hf_token
                )
                for model_info in models:
                    repo_id = model_info.id
                    if "caribbean-voices-owsm-finetuned-" in repo_id:
                        # Extract the YYYYMMDD-HHMMSS timestamp, which spans
                        # the last two dash-separated segments of the repo name
                        timestamp_str = "-".join(repo_id.split("-")[-2:])
                        try:
                            # Parse timestamp: YYYYMMDD-HHMMSS
                            dt = datetime.strptime(timestamp_str, "%Y%m%d-%H%M%S")
                            readable_date = dt.strftime("%Y-%m-%d %H:%M")
                            checkpoint_name = f"OWSM Finetuned ({readable_date})"
                            checkpoints.append(checkpoint_name)
                        except ValueError:
                            # Fallback if timestamp parsing fails
                            checkpoints.append(f"OWSM Finetuned ({timestamp_str})")
            except Exception:
                pass
            
            # Search for all versioned Whisper models
            try:
                models = api.list_models(
                    author="shaun3141",
                    search="caribbean-voices-whisper-finetuned",
                    token=hf_token
                )
                for model_info in models:
                    repo_id = model_info.id
                    if "caribbean-voices-whisper-finetuned-" in repo_id:
                        # Extract the YYYYMMDD-HHMMSS timestamp, which spans
                        # the last two dash-separated segments of the repo name
                        timestamp_str = "-".join(repo_id.split("-")[-2:])
                        try:
                            # Parse timestamp: YYYYMMDD-HHMMSS
                            dt = datetime.strptime(timestamp_str, "%Y%m%d-%H%M%S")
                            readable_date = dt.strftime("%Y-%m-%d %H:%M")
                            checkpoint_name = f"Whisper Finetuned ({readable_date})"
                            checkpoints.append(checkpoint_name)
                        except ValueError:
                            # Fallback if timestamp parsing fails
                            checkpoints.append(f"Whisper Finetuned ({timestamp_str})")
            except Exception:
                pass
        except Exception:
            pass  # Silently fail if HF Hub check fails
    
    if not os.path.exists(MODEL_OUTPUT_DIR):
        return checkpoints
    
    # Check if final model exists locally (improved detection)
    if is_fine_tuned_model_available():
        checkpoints.append("Final Model")
    
    # Find checkpoint directories
    if os.path.isdir(MODEL_OUTPUT_DIR):
        try:
            for item in os.listdir(MODEL_OUTPUT_DIR):
                checkpoint_path = os.path.join(MODEL_OUTPUT_DIR, item)
                if os.path.isdir(checkpoint_path):
                    if item.startswith("checkpoint-"):
                        # Verify it's a valid checkpoint (has config.json or pytorch_model.bin)
                        if (os.path.exists(os.path.join(checkpoint_path, "config.json")) or
                            os.path.exists(os.path.join(checkpoint_path, "pytorch_model.bin")) or
                            os.path.exists(os.path.join(checkpoint_path, "model.safetensors"))):
                            checkpoints.append(item)
                    # Also check for other model directories that might not follow checkpoint-* naming
                    elif item not in ["espnet_model"]:  # Skip known subdirectories
                        # Check if this directory contains model files
                        if (os.path.exists(os.path.join(checkpoint_path, "config.json")) or
                            os.path.exists(os.path.join(checkpoint_path, "pytorch_model.bin")) or
                            os.path.exists(os.path.join(checkpoint_path, "model.safetensors")) or
                            os.path.exists(os.path.join(checkpoint_path, "wrapper_state_dict.pt"))):
                            # Add as a checkpoint option
                            checkpoints.append(item)
        except (OSError, PermissionError) as e:
            print(f"⚠ Warning: Could not list checkpoints in {MODEL_OUTPUT_DIR}: {e}")
    
    # Sort checkpoints: base models first, then versioned HF Hub models (newest first), then local checkpoints, then final model
    def get_sort_key(name):
        if name in AVAILABLE_MODELS:
            return (0, name)  # Base models first
        elif name.startswith("OWSM Finetuned (") or name.startswith("Whisper Finetuned ("):
            # Extract date for sorting (newest first)
            try:
                date_str = name.split("(")[1].rstrip(")")
                from datetime import datetime
                dt = datetime.strptime(date_str, "%Y-%m-%d %H:%M")
                # Return negative timestamp so newest comes first
                return (1, -dt.timestamp())
            except (IndexError, ValueError):
                return (1, 0)
        elif name == "Final Model":
            return (3, float('inf'))  # Final model last
        elif name.startswith("checkpoint-"):
            try:
                step_num = int(name.split("-")[1])
                return (2, step_num)  # Local checkpoints in middle, sorted by step
            except (IndexError, ValueError):
                return (2, 0)
        else:
            return (4, name)  # Unknown items last
    
    checkpoints.sort(key=get_sort_key)
    
    return checkpoints


def get_checkpoint_path(checkpoint_name: str) -> str:
    """
    Get full path to a checkpoint directory or HF Hub repo ID.
    
    Args:
        checkpoint_name: Name like "Final Model", "OWSM Finetuned (2024-01-15 14:30)", "checkpoint-2000", or a base-model key from AVAILABLE_MODELS
    
    Returns:
        Full path to checkpoint directory or HF Hub repo ID
    """
    if checkpoint_name == "Final Model":
        return MODEL_OUTPUT_DIR
    elif checkpoint_name in AVAILABLE_MODELS:
        # Base models resolve to their Hub repo IDs, not local paths
        return AVAILABLE_MODELS[checkpoint_name]
    elif checkpoint_name.startswith("OWSM Finetuned (") or checkpoint_name.startswith("Whisper Finetuned ("):
        # Extract timestamp from checkpoint name and find matching repo
        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN required to load models from Hub")
        
        from huggingface_hub import HfApi
        from datetime import datetime
        api = HfApi(token=hf_token)
        
        # Extract date from checkpoint name: "OWSM Finetuned (2024-01-15 14:30)"
        try:
            date_str = checkpoint_name.split("(")[1].rstrip(")")
            # Parse the readable date back to timestamp format. Seconds were
            # dropped when the display name was built, so this reconstruction
            # may miss the exact repo; the fallback search below covers that.
            dt = datetime.strptime(date_str, "%Y-%m-%d %H:%M")
            timestamp_str = dt.strftime("%Y%m%d-%H%M%S")
            
            # Determine model type and search for matching repo
            if checkpoint_name.startswith("OWSM"):
                repo_pattern = f"shaun3141/caribbean-voices-owsm-finetuned-{timestamp_str}"
            else:
                repo_pattern = f"shaun3141/caribbean-voices-whisper-finetuned-{timestamp_str}"
            
            # Verify repo exists
            try:
                api.model_info(repo_pattern, token=hf_token)
                return repo_pattern
            except Exception:
                # Try to find closest match if exact timestamp doesn't match
                search_term = "caribbean-voices-owsm-finetuned" if checkpoint_name.startswith("OWSM") else "caribbean-voices-whisper-finetuned"
                models = api.list_models(author="shaun3141", search=search_term, token=hf_token)
                # Return the most recent matching model
                matching_models = [m.id for m in models if timestamp_str[:8] in m.id]  # Match by date
                if matching_models:
                    return matching_models[0]
                raise FileNotFoundError(f"Model not found: {checkpoint_name}")
        except (IndexError, ValueError) as e:
            raise ValueError(f"Could not parse checkpoint name: {checkpoint_name}. Error: {e}")
    else:
        return os.path.join(MODEL_OUTPUT_DIR, checkpoint_name)
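

# Example: end-to-end checkpoint selection, as a UI dropdown might do it.
# A minimal sketch tying the helpers together; base models are excluded
# here because they load through load_model() rather than load_checkpoint().
#
#     names = get_available_checkpoints(include_base_models=False)
#     if names:
#         path_or_repo = get_checkpoint_path(names[-1])
#         processor, model = load_checkpoint(path_or_repo)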