Evgueni Poloukarov and Claude committed
Commit 44b73f4 · 1 Parent(s): af88e60

feat: implement zero-shot inference pipeline for Day 3


Created complete inference infrastructure:

**New modules (src/inference/)** (usage sketch after this list):
- data_fetcher.py: DataFetcher class for preparing Chronos 2 input
* Loads unified features from HF Dataset or local parquet
* Identifies 615 future covariates from metadata
* Prepares context windows (configurable length)
* Formats data for predict_df() API
* Handles multivariate forecasting (38 borders)

- chronos_pipeline.py: ChronosForecaster class for inference
* Loads Chronos 2 Large (710M params) with GPU support
* Zero-shot inference via predict_df() API
* Probabilistic forecasts (mean, median, quantiles)
* Performance benchmarking utilities
* Parquet export functionality
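
A minimal sketch of how the two modules compose for the full 14-day, all-border run (illustrative only: the as-of date offset and output path are placeholders; the class and method names come from the modules added in this commit):

```python
# Assumes src/ is on sys.path, as done in scripts/test_inference_pipeline.py.
from datetime import timedelta

from inference.data_fetcher import DataFetcher
from inference.chronos_pipeline import ChronosForecaster

# Prepare context and future-covariate frames for all 38 borders.
fetcher = DataFetcher(use_local=True, context_length=512)
fetcher.load_data()
_, max_date = fetcher.get_available_dates()
forecast_date = max_date - timedelta(days=30)  # illustrative as-of date

context_df, future_df = fetcher.prepare_inference_data(
    forecast_date=forecast_date,
    prediction_length=336,  # 14 days
)

# Zero-shot inference with Chronos 2 Large, then export to parquet.
forecaster = ChronosForecaster(model_name="amazon/chronos-2-large", device="auto")
forecaster.load_model()
forecasts = forecaster.predict(
    context_df=context_df,
    future_df=future_df,
    prediction_length=336,
    num_samples=100,
)
forecaster.save_forecasts(forecasts, "data/evaluation/forecast_14d.parquet")  # illustrative path
```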

**Testing**:
- scripts/test_inference_pipeline.py: Comprehensive smoke test
* Tests data loading, model loading, inference
* Validates output quality and performance
* Estimates 14-day forecast time
* Single border × 7 days test case
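
The smoke test inserts `src/` into `sys.path` itself, so it is meant to be run directly from the repository root; every check is assertion-based, so any failure surfaces as a traceback and a non-zero exit.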

**HuggingFace Space**:
- Fixed BUILD_ERROR by adding jupyterlab to requirements
- Space rebuild in progress (commit a7e66e0)

**Status**: Ready for local testing and smoke test execution

Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>

scripts/test_inference_pipeline.py ADDED
@@ -0,0 +1,177 @@
"""
Smoke test for zero-shot inference pipeline

Tests:
1. Data loading and preparation
2. Chronos 2 model loading
3. Inference on single border (7 days)
4. Output validation
5. Performance metrics
"""

import sys
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))

from inference.data_fetcher import DataFetcher
from inference.chronos_pipeline import ChronosForecaster
from datetime import datetime, timedelta
import torch
import pandas as pd

def main():
    print("="*60)
    print("FBMC Chronos 2 Zero-Shot Inference - Smoke Test")
    print("="*60)

    # Step 1: Check environment
    print("\n[1] Checking environment...")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("Running on CPU (inference will be slower)")

    # Step 2: Initialize DataFetcher
    print("\n[2] Initializing DataFetcher...")
    fetcher = DataFetcher(
        use_local=True,      # Use local files for testing
        context_length=512   # Use 512 hours context
    )

    # Step 3: Load data
    print("\n[3] Loading unified features...")
    fetcher.load_data()

    # Get available date range
    min_date, max_date = fetcher.get_available_dates()
    print(f"Available data: {min_date} to {max_date}")

    # Select forecast date (use last month as test)
    forecast_date = max_date - timedelta(days=30)
    print(f"Test forecast date: {forecast_date}")

    # Step 4: Prepare inference data (single border, 7 days)
    print("\n[4] Preparing inference data (1 border, 7 days)...")
    test_border = fetcher.target_borders[0]  # Use first border
    print(f"Test border: {test_border}")

    context_df, future_df = fetcher.prepare_inference_data(
        forecast_date=forecast_date,
        prediction_length=168,  # 7 days
        borders=[test_border]
    )

    print(f"Context shape: {context_df.shape}")
    print(f"Future shape: {future_df.shape}")

    # Step 5: Validate prepared data
    print("\n[5] Validating prepared data...")
    assert 'timestamp' in context_df.columns, "Missing timestamp column"
    assert 'border' in context_df.columns, "Missing border column"
    assert 'target' in context_df.columns, "Missing target column"
    assert len(context_df) > 0, "Empty context data"
    assert len(future_df) > 0, "Empty future data"
    print("[+] Data validation passed!")

    # Check for NaN values
    context_nulls = context_df.isnull().sum().sum()
    future_nulls = future_df.isnull().sum().sum()
    print(f"Context NaN count: {context_nulls}")
    print(f"Future NaN count: {future_nulls}")

    if context_nulls > 0 or future_nulls > 0:
        print("[!] Warning: Data contains NaN values (will be handled by model)")

    # Step 6: Initialize Chronos 2 forecaster
    print("\n[6] Initializing Chronos 2 forecaster...")
    forecaster = ChronosForecaster(
        model_name="amazon/chronos-2-large",
        device="auto"  # Will use GPU if available
    )

    # Step 7: Load model
    print("\n[7] Loading Chronos 2 Large model...")
    print("(This may take a few minutes on first load)")
    forecaster.load_model()
    print("[+] Model loaded successfully!")

    # Step 8: Run inference
    print("\n[8] Running zero-shot inference...")
    print(f"Forecasting {test_border} for 7 days (168 hours)")

    forecasts = forecaster.predict_single_border(
        border=test_border,
        context_df=context_df,
        future_df=future_df,
        prediction_length=168,
        num_samples=100  # 100 samples for probabilistic forecast
    )

    print(f"[+] Inference complete! Forecast shape: {forecasts.shape}")

    # Step 9: Validate forecasts
    print("\n[9] Validating forecasts...")
    assert len(forecasts) > 0, "Empty forecasts"
    assert 'timestamp' in forecasts.columns or forecasts.index.name == 'timestamp', "Missing timestamp"

    # Check for reasonable values
    if 'mean' in forecasts.columns:
        mean_forecast = forecasts['mean']
        print("Forecast statistics:")
        print(f"  Mean: {mean_forecast.mean():.2f} MW")
        print(f"  Min: {mean_forecast.min():.2f} MW")
        print(f"  Max: {mean_forecast.max():.2f} MW")
        print(f"  Std: {mean_forecast.std():.2f} MW")

        # Sanity check: values should be reasonable for power capacity
        assert mean_forecast.min() >= 0, "Negative forecasts detected"
        assert mean_forecast.max() < 20000, "Unreasonably high forecasts"
        print("[+] Forecast validation passed!")

    # Step 10: Benchmark performance
    print("\n[10] Benchmarking inference performance...")
    metrics = forecaster.benchmark_inference(
        context_df=context_df,
        future_df=future_df,
        prediction_length=168
    )

    print("Performance metrics:")
    for key, value in metrics.items():
        print(f"  {key}: {value}")

    # Check if we meet the 5-minute target (for 14 days)
    # by scaling the 7-day timing to a 14-day estimate
    estimated_14d_time = metrics['inference_time_sec'] * (336 / 168)
    print(f"\nEstimated time for 14-day forecast: {estimated_14d_time:.1f}s ({estimated_14d_time/60:.1f} min)")

    if estimated_14d_time < 300:  # 5 minutes
        print("[+] Performance target met! (<5 min for 14 days)")
    else:
        print("[!] Warning: May not meet 5-minute target for 14 days")

    # Step 11: Save test forecasts
    print("\n[11] Saving test forecasts...")
    output_path = "data/evaluation/smoke_test_forecast.parquet"
    forecaster.save_forecasts(forecasts, output_path)
    print(f"[+] Saved to: {output_path}")

    # Summary
    print("\n" + "="*60)
    print("SMOKE TEST SUMMARY")
    print("="*60)
    print("[+] All tests passed!")
    print(f"[+] Border: {test_border}")
    print(f"[+] Forecast length: 168 hours (7 days)")
    print(f"[+] Inference time: {metrics['inference_time_sec']:.1f}s")
    print(f"[+] Output shape: {forecasts.shape}")
    print("\n[+] Ready for full inference run!")
    print("="*60)

if __name__ == "__main__":
    main()
src/inference/__init__.py ADDED
@@ -0,0 +1,6 @@
"""Zero-shot inference pipeline for FBMC flow forecasting"""

from .data_fetcher import DataFetcher
from .chronos_pipeline import ChronosForecaster

__all__ = ['DataFetcher', 'ChronosForecaster']
src/inference/chronos_pipeline.py ADDED
@@ -0,0 +1,280 @@
"""
Chronos 2 Zero-Shot Inference Pipeline

Handles:
1. Loading Chronos 2 Large model (710M params)
2. Running zero-shot inference using predict_df() API
3. GPU/CPU device mapping
4. Saving predictions to parquet
"""

from pathlib import Path
from typing import Optional, Dict, List
import pandas as pd
import torch
from datetime import datetime
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ChronosForecaster:
    """
    Zero-shot forecaster using Chronos 2 Large model.

    Features:
    - Multivariate forecasting (multiple borders simultaneously)
    - Covariate support (615 future covariates)
    - Large context window (up to 8,192 hours)
    - DataFrame API for easy data handling
    """

    def __init__(
        self,
        model_name: str = "amazon/chronos-2-large",
        device: str = "auto",
        torch_dtype: str = "float32"
    ):
        """
        Initialize Chronos 2 forecaster.

        Args:
            model_name: HuggingFace model name (default: chronos-2-large)
            device: Device to run on ('auto', 'cuda', 'cpu')
            torch_dtype: Torch dtype ('float32', 'float16', 'bfloat16')
        """
        self.model_name = model_name
        self.device = self._resolve_device(device)
        self.torch_dtype = self._resolve_dtype(torch_dtype)
        self.pipeline = None

        logger.info("ChronosForecaster initialized:")
        logger.info(f"  Model: {model_name}")
        logger.info(f"  Device: {self.device}")
        logger.info(f"  Dtype: {self.torch_dtype}")

    def _resolve_device(self, device: str) -> str:
        """Resolve device string to actual device."""
        if device == "auto":
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device

    def _resolve_dtype(self, dtype_str: str) -> torch.dtype:
        """Resolve dtype string to torch dtype."""
        dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16
        }
        return dtype_map.get(dtype_str, torch.float32)

    def load_model(self):
        """Load Chronos 2 model from HuggingFace."""
        if self.pipeline is not None:
            logger.info("Model already loaded")
            return

        logger.info(f"Loading {self.model_name}...")
        logger.info("This may take a few minutes on first load...")

        try:
            from chronos import Chronos2Pipeline

            # Load with device_map for GPU support
            self.pipeline = Chronos2Pipeline.from_pretrained(
                self.model_name,
                device_map=self.device if self.device == "cuda" else None,
                torch_dtype=self.torch_dtype
            )

            # Move to device if not using device_map
            if self.device == "cpu":
                self.pipeline = self.pipeline.to(self.device)

            logger.info(f"Model loaded successfully on {self.device}")

            # Print GPU info if available
            if self.device == "cuda":
                gpu_name = torch.cuda.get_device_name(0)
                gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
                logger.info(f"GPU: {gpu_name} ({gpu_memory:.1f} GB VRAM)")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def predict(
        self,
        context_df: pd.DataFrame,
        future_df: pd.DataFrame,
        prediction_length: int = 336,
        id_column: str = "border",
        timestamp_column: str = "timestamp",
        num_samples: int = 100
    ) -> pd.DataFrame:
        """
        Run zero-shot inference using Chronos 2.

        Args:
            context_df: Historical data (timestamp, border, target, features)
            future_df: Future covariates (timestamp, border, future_covariates)
            prediction_length: Number of hours to forecast
            id_column: Column name for border ID
            timestamp_column: Column name for timestamp
            num_samples: Number of samples for probabilistic forecast

        Returns:
            forecasts_df: DataFrame with predictions (timestamp, border, mean, median, q10, q90)
        """
        if self.pipeline is None:
            self.load_model()

        logger.info("Running zero-shot inference...")
        logger.info(f"Context shape: {context_df.shape}")
        logger.info(f"Future shape: {future_df.shape}")
        logger.info(f"Prediction length: {prediction_length} hours")
        logger.info(f"Borders: {context_df[id_column].nunique()}")

        try:
            # Run inference
            forecasts = self.pipeline.predict_df(
                context_df=context_df,
                future_df=future_df,
                prediction_length=prediction_length,
                id_column=id_column,
                timestamp_column=timestamp_column,
                num_samples=num_samples
            )

            logger.info(f"Inference complete! Forecast shape: {forecasts.shape}")

            # Add metadata
            forecasts['forecast_date'] = context_df[timestamp_column].max()
            forecasts['model'] = self.model_name

            return forecasts

        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise

    def predict_single_border(
        self,
        border: str,
        context_df: pd.DataFrame,
        future_df: pd.DataFrame,
        prediction_length: int = 336,
        num_samples: int = 100
    ) -> pd.DataFrame:
        """
        Run inference for a single border (useful for testing).

        Args:
            border: Border name (e.g., 'AT_CZ')
            context_df: Historical data
            future_df: Future covariates
            prediction_length: Hours to forecast
            num_samples: Samples for probabilistic forecast

        Returns:
            forecasts_df: Predictions for single border
        """
        logger.info(f"Running inference for border: {border}")

        # Filter for single border
        context_border = context_df[context_df['border'] == border].copy()
        future_border = future_df[future_df['border'] == border].copy()

        # Run prediction
        forecasts = self.predict(
            context_df=context_border,
            future_df=future_border,
            prediction_length=prediction_length,
            num_samples=num_samples
        )

        return forecasts

    def save_forecasts(
        self,
        forecasts: pd.DataFrame,
        output_path: str,
        include_metadata: bool = True
    ):
        """
        Save forecasts to parquet file.

        Args:
            forecasts: Forecast DataFrame
            output_path: Path to save parquet file
            include_metadata: Include model metadata
        """
        logger.info(f"Saving forecasts to: {output_path}")

        # Create output directory if needed
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Add metadata
        if include_metadata:
            forecasts = forecasts.copy()
            forecasts['saved_at'] = datetime.now()

        # Save to parquet
        forecasts.to_parquet(output_path, index=False)

        logger.info(f"Saved {len(forecasts)} rows to {output_path}")

    def benchmark_inference(
        self,
        context_df: pd.DataFrame,
        future_df: pd.DataFrame,
        prediction_length: int = 336
    ) -> Dict[str, float]:
        """
        Benchmark inference speed and memory usage.

        Args:
            context_df: Historical data
            future_df: Future covariates
            prediction_length: Hours to forecast

        Returns:
            metrics: Dict with inference_time_sec, gpu_memory_mb
        """
        import time

        logger.info("Benchmarking inference performance...")

        # Record start time and memory
        start_time = time.time()
        if self.device == "cuda":
            torch.cuda.reset_peak_memory_stats()

        # Run inference
        _ = self.predict(
            context_df=context_df,
            future_df=future_df,
            prediction_length=prediction_length
        )

        # Record end time and memory
        end_time = time.time()
        inference_time = end_time - start_time

        metrics = {
            'inference_time_sec': inference_time,
            'borders': context_df['border'].nunique(),
            'prediction_length': prediction_length
        }

        if self.device == "cuda":
            peak_memory = torch.cuda.max_memory_allocated() / 1e6  # MB
            metrics['gpu_memory_mb'] = peak_memory

        logger.info(f"Inference time: {inference_time:.2f}s")
        if 'gpu_memory_mb' in metrics:
            logger.info(f"Peak GPU memory: {metrics['gpu_memory_mb']:.1f} MB")

        return metrics
src/inference/data_fetcher.py ADDED
@@ -0,0 +1,252 @@
"""
Data Fetcher for Zero-Shot Inference

Prepares data for Chronos 2 inference by:
1. Loading unified features from HuggingFace Dataset
2. Identifying future covariates from metadata
3. Preparing context window (historical data)
4. Preparing future covariates for forecast horizon
5. Formatting data for Chronos 2 predict_df() API
"""

from pathlib import Path
from typing import Tuple, List, Optional
import pandas as pd
import polars as pl
from datetime import datetime, timedelta
from datasets import load_dataset
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DataFetcher:
    """
    Fetches and prepares data for zero-shot Chronos 2 inference.

    Handles:
    - Loading unified features (2,553 features)
    - Identifying future covariates (615 features)
    - Creating context windows for each border
    - Extending future covariates into forecast horizon
    """

    def __init__(
        self,
        dataset_name: str = "evgueni-p/fbmc-features-24month",
        local_features_path: Optional[str] = None,
        local_metadata_path: Optional[str] = None,
        context_length: int = 512,
        use_local: bool = False
    ):
        """
        Initialize DataFetcher.

        Args:
            dataset_name: HuggingFace dataset name
            local_features_path: Path to local features parquet file
            local_metadata_path: Path to local metadata CSV
            context_length: Number of hours to use as context (default: 512)
            use_local: If True, load from local files instead of HF Dataset
        """
        self.dataset_name = dataset_name
        self.local_features_path = local_features_path or "data/processed/features_unified_24month.parquet"
        self.local_metadata_path = local_metadata_path or "data/processed/features_unified_metadata.csv"
        self.context_length = context_length
        self.use_local = use_local

        # Will be loaded lazily
        self.features_df: Optional[pl.DataFrame] = None
        self.metadata_df: Optional[pd.DataFrame] = None
        self.future_covariate_cols: Optional[List[str]] = None
        self.target_borders: Optional[List[str]] = None

    def load_data(self):
        """Load unified features and metadata."""
        logger.info("Loading unified features and metadata...")

        if self.use_local:
            # Load from local files
            logger.info(f"Loading features from: {self.local_features_path}")
            self.features_df = pl.read_parquet(self.local_features_path)

            logger.info(f"Loading metadata from: {self.local_metadata_path}")
            self.metadata_df = pd.read_csv(self.local_metadata_path)
        else:
            # Load from HuggingFace Dataset
            logger.info(f"Loading features from HF Dataset: {self.dataset_name}")
            dataset = load_dataset(self.dataset_name, split="train")
            self.features_df = pl.from_pandas(dataset.to_pandas())

            # Try to load metadata from HF Dataset
            try:
                metadata_dataset = load_dataset(self.dataset_name, data_files="metadata.csv", split="train")
                self.metadata_df = metadata_dataset.to_pandas()
            except:
                logger.warning("Could not load metadata from HF Dataset, falling back to local")
                self.metadata_df = pd.read_csv(self.local_metadata_path)

        # Ensure timestamp column is datetime
        if 'timestamp' in self.features_df.columns:
            self.features_df = self.features_df.with_columns(
                pl.col('timestamp').str.to_datetime()
            )

        logger.info(f"Loaded {len(self.features_df)} rows, {len(self.features_df.columns)} columns")
        logger.info(f"Date range: {self.features_df['timestamp'].min()} to {self.features_df['timestamp'].max()}")

        # Identify future covariates
        self._identify_future_covariates()

        # Identify target borders
        self._identify_target_borders()

    def _identify_future_covariates(self):
        """Identify columns that are future covariates from metadata."""
        logger.info("Identifying future covariates from metadata...")

        # Filter for future covariates
        future_cov_meta = self.metadata_df[
            self.metadata_df['is_future_covariate'] == True
        ]

        self.future_covariate_cols = future_cov_meta['feature_name'].tolist()

        logger.info(f"Found {len(self.future_covariate_cols)} future covariates")
        logger.info(f"Categories: {future_cov_meta['category'].value_counts().to_dict()}")

    def _identify_target_borders(self):
        """Identify target borders from NTC columns."""
        logger.info("Identifying target borders...")

        # Find all ntc_actual_* columns
        ntc_cols = [col for col in self.features_df.columns if col.startswith('ntc_actual_')]

        # Extract border names
        self.target_borders = [col.replace('ntc_actual_', '') for col in ntc_cols]

        logger.info(f"Found {len(self.target_borders)} target borders")
        logger.info(f"Borders: {', '.join(self.target_borders[:5])}...")

    def prepare_inference_data(
        self,
        forecast_date: datetime,
        prediction_length: int = 336,  # 14 days
        borders: Optional[List[str]] = None
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Prepare context and future data for Chronos 2 inference.

        Args:
            forecast_date: The date to forecast from (as-of date)
            prediction_length: Number of hours to forecast (default: 336 = 14 days)
            borders: List of borders to forecast (default: all borders)

        Returns:
            context_df: Historical data (timestamp, border, target, all features)
            future_df: Future covariates (timestamp, border, future covariates only)
        """
        if self.features_df is None:
            self.load_data()

        borders = borders or self.target_borders

        logger.info(f"Preparing inference data for {len(borders)} borders")
        logger.info(f"Forecast date: {forecast_date}")
        logger.info(f"Context length: {self.context_length} hours")
        logger.info(f"Prediction length: {prediction_length} hours")

        # Extract context window (historical data)
        context_start = forecast_date - timedelta(hours=self.context_length)
        context_df = self.features_df.filter(
            (pl.col('timestamp') >= context_start) &
            (pl.col('timestamp') < forecast_date)
        )

        logger.info(f"Context window: {context_df['timestamp'].min()} to {context_df['timestamp'].max()}")
        logger.info(f"Context rows: {len(context_df)}")

        # Prepare context data for each border
        context_dfs = []
        for border in borders:
            ntc_col = f'ntc_actual_{border}'

            if ntc_col not in context_df.columns:
                logger.warning(f"Border {border} not found in features, skipping")
                continue

            # Select: timestamp, target, all features
            border_context = context_df.select([
                'timestamp',
                pl.lit(border).alias('border'),
                pl.col(ntc_col).alias('target'),
                *[col for col in context_df.columns if col not in ['timestamp', ntc_col]]
            ])

            context_dfs.append(border_context)

        # Combine all borders
        context_combined = pl.concat(context_dfs)

        logger.info(f"Combined context shape: {context_combined.shape}")

        # Prepare future covariates
        # For MVP: Use last known values or simple forward-fill
        # TODO: In production, fetch fresh weather forecasts, generate temporal features
        logger.info("Preparing future covariates...")

        future_dfs = []
        for border in borders:
            # Create future timestamps
            future_timestamps = pd.date_range(
                start=forecast_date,
                periods=prediction_length,
                freq='H'
            )

            # Get last known values of future covariates
            last_row = context_df.filter(pl.col('timestamp') == context_df['timestamp'].max())

            # Extract future covariate values
            future_values = last_row.select(self.future_covariate_cols)

            # Repeat for all future timestamps
            future_border_df = pl.DataFrame({
                'timestamp': future_timestamps,
                'border': [border] * len(future_timestamps)
            })

            # Add future covariate values (forward-fill from last known)
            for col in self.future_covariate_cols:
                if col in future_values.columns:
                    value = future_values[col][0]
                    future_border_df = future_border_df.with_columns(
                        pl.lit(value).alias(col)
                    )

            future_dfs.append(future_border_df)

        # Combine all borders
        future_combined = pl.concat(future_dfs)

        logger.info(f"Future covariates shape: {future_combined.shape}")

        # Convert to pandas for Chronos 2
        context_pd = context_combined.to_pandas()
        future_pd = future_combined.to_pandas()

        logger.info("Data preparation complete!")
        logger.info(f"Context: {context_pd.shape}, Future: {future_pd.shape}")

        return context_pd, future_pd

    def get_available_dates(self) -> Tuple[datetime, datetime]:
        """Get the available date range in the dataset."""
        if self.features_df is None:
            self.load_data()

        min_date = self.features_df['timestamp'].min()
        max_date = self.features_df['timestamp'].max()

        return min_date, max_date