File size: 5,161 Bytes
2dc6653
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#!/usr/bin/env python3
"""
Diagnostic script to test inference pipeline components
Run this in the Space environment to identify issues
"""

import sys
import os
from datetime import datetime

print("="*60)
print("FBMC CHRONOS-2 DIAGNOSTIC SCRIPT")
print("="*60)

# Test 1: Python environment
print("\n[1] Python Environment")
print(f"  Python version: {sys.version}")
print(f"  Python path: {sys.executable}")

# Test 2: import each third-party dependency in its own try block so that one
# failing import does not hide the status of the others. Any exception raised
# while importing (or while querying the imported module) is printed, never
# propagated.
print("\n[2] Importing Dependencies")

try:
    import torch

    print(f"  PyTorch: {torch.__version__}")
    print(f"  CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        # Only device 0 is reported.
        print(f"  CUDA device: {torch.cuda.get_device_name(0)}")
except Exception as exc:
    print(f"  PyTorch ERROR: {exc}")

try:
    import polars as pl

    print(f"  Polars: {pl.__version__}")
except Exception as exc:
    print(f"  Polars ERROR: {exc}")

try:
    import numpy as np

    print(f"  NumPy: {np.__version__}")
except Exception as exc:
    print(f"  NumPy ERROR: {exc}")

try:
    from chronos import ChronosPipeline

    print("  Chronos: OK")
except Exception as exc:
    print(f"  Chronos ERROR: {exc}")

try:
    from datasets import load_dataset

    print("  HF Datasets: OK")
except Exception as exc:
    print(f"  HF Datasets ERROR: {exc}")

# Test 3: environment configuration. HF_TOKEN is reported only as SET / NOT SET
# so the secret itself never appears in the output; DEVICE is shown with
# "cuda" as the displayed fallback when the variable is unset.
# NOTE(review): this script only prints DEVICE; the model check below chooses
# its device from torch.cuda.is_available() instead.
print("\n[3] Environment Variables")
print("  HF_TOKEN:", "SET" if os.environ.get("HF_TOKEN") else "NOT SET")
print("  DEVICE:", os.environ.get("DEVICE", "cuda"))

# Test 4: download the "train" split of the feature dataset from the Hugging
# Face Hub (authenticating with HF_TOKEN when it is set), wrap its Arrow table
# in a Polars DataFrame, and count the columns prefixed "target_border_".
# Any failure is printed together with its traceback.
print("\n[4] Loading Dataset")
try:
    from datasets import load_dataset

    hf_token = os.getenv("HF_TOKEN")
    print("  Loading evgueni-p/fbmc-features-24month...")
    dataset = load_dataset(
        "evgueni-p/fbmc-features-24month",
        split="train",
        token=hf_token,
    )
    print(f"  Dataset rows: {len(dataset)}")

    # Hand the dataset's underlying Arrow table straight to Polars.
    import polars as pl

    df = pl.from_arrow(dataset.data.table)
    print(f"  Polars shape: {df.shape}")

    # Target series are identified purely by their column-name prefix.
    target_cols = [name for name in df.columns if name.startswith("target_border_")]
    print(f"  Target borders: {len(target_cols)}")
    if target_cols:
        print(f"  First border: {target_cols[0]}")

except Exception as exc:
    print(f"  Dataset ERROR: {exc}")
    import traceback

    traceback.print_exc()

# Test 5: load amazon/chronos-t5-large and run a small probabilistic forecast
# on random data, then derive median / 10% / 90% quantiles from the samples.
# Any failure is printed together with its traceback.
print("\n[5] Loading Chronos Model")
try:
    from chronos import ChronosPipeline
    import numpy as np
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"  Device: {device}")
    print("  Loading amazon/chronos-t5-large...")

    pipeline = ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-large",
        device_map=device,
        # Half-precision weights on GPU; full float32 on CPU.
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    )
    print("  Model loaded successfully!")

    # Test inference with dummy data
    print("\n  Testing inference with dummy data...")
    dummy_context = np.random.randn(512).astype(np.float32)

    # ChronosPipeline.predict accepts a torch.Tensor (or a list of 1-D
    # tensors) as context, not a NumPy array, so convert first.
    forecast = pipeline.predict(
        context=torch.tensor(dummy_context),
        prediction_length=24,
        num_samples=5,
    )

    # Per the Chronos documentation the result is shaped
    # (num_series, num_samples, prediction_length). Move it to CPU float32
    # before converting: NumPy cannot read CUDA tensors or bfloat16 data.
    forecast_np = forecast.to(device="cpu", dtype=torch.float32).numpy()
    print(f"  Forecast shape: {forecast_np.shape}")

    # Quantiles must be taken across the *sample* axis of the single series;
    # reducing axis 0 of the 3-D array would only collapse the series axis.
    samples = forecast_np[0]
    median = np.median(samples, axis=0)
    q10 = np.quantile(samples, 0.1, axis=0)
    q90 = np.quantile(samples, 0.9, axis=0)

    print("  Quantiles calculated successfully!")
    print(f"    Median shape: {median.shape}")
    print(f"    Q10 shape: {q10.shape}")
    print(f"    Q90 shape: {q90.shape}")

except Exception as e:
    print(f"  Model ERROR: {e}")
    import traceback
    traceback.print_exc()

# Test 6: import the project's forecasting classes from the src package. Each
# import is checked independently; failures print the error and its traceback.
print("\n[6] Testing Module Imports")
try:
    from src.forecasting.dynamic_forecast import DynamicForecast
    print("  DynamicForecast: OK")
except Exception as e:
    print(f"  DynamicForecast ERROR: {e}")
    import traceback
    traceback.print_exc()

try:
    from src.forecasting.feature_availability import FeatureAvailability
    print("  FeatureAvailability: OK")
except Exception as e:
    print(f"  FeatureAvailability ERROR: {e}")
    # Print the traceback here as well (like every other failing check), so an
    # error raised inside the module can be located, not just its message.
    import traceback
    traceback.print_exc()

# Test 7: end-to-end smoke test of the project's run_inference entry point for
# a fixed date, writing into /tmp, then read back the parquet file whose path
# it returns and check that it actually contains forecast data.
print("\n[7] Full Pipeline Test (Minimal)")
try:
    print("  Testing run_inference function...")
    from src.forecasting.chronos_inference import run_inference

    # This will be slow but should work
    print("  Running smoke test for 2025-09-30...")
    print("  (This may take 60+ seconds...)")

    result_path = run_inference(
        run_date="2025-09-30",
        forecast_type="smoke_test",
        output_dir="/tmp",
    )

    print(f"  Result file: {result_path}")

    # Check if file has data
    import polars as pl
    df = pl.read_parquet(result_path)
    print(f"  Result shape: {df.shape}")
    print(f"  Columns: {df.columns}")

    # A usable forecast needs at least one column besides the first (timestamp)
    # column AND at least one row; without the row check a zero-row file with
    # forecast columns would be reported as a success.
    if len(df.columns) <= 1:
        print("  [ERROR] Forecast is empty (only timestamps)")
    elif df.height == 0:
        print("  [ERROR] Forecast has no rows")
    else:
        print("  [SUCCESS] Forecast has data!")

except Exception as e:
    print(f"  Pipeline ERROR: {e}")
    import traceback
    traceback.print_exc()

print("\n" + "="*60)
print("DIAGNOSTIC COMPLETE")
print("="*60)