feat: implement zero-shot inference pipeline for Day 3
Created complete inference infrastructure:
**New modules (src/inference/)**:
- data_fetcher.py: DataFetcher class for preparing Chronos 2 input
  * Loads unified features from HF Dataset or local parquet
  * Identifies 615 future covariates from metadata
  * Prepares context windows (configurable length)
  * Formats data for predict_df() API
  * Handles multivariate forecasting (38 borders)
- chronos_pipeline.py: ChronosForecaster class for inference
  * Loads Chronos 2 Large (710M params) with GPU support
  * Zero-shot inference via predict_df() API
  * Probabilistic forecasts (mean, median, quantiles)
  * Performance benchmarking utilities
  * Parquet export functionality

**Testing**:
- scripts/test_inference_pipeline.py: Comprehensive smoke test
  * Tests data loading, model loading, inference
  * Validates output quality and performance
  * Estimates 14-day forecast time
  * Single border × 7 days test case

**HuggingFace Space**:
- Fixed BUILD_ERROR by adding jupyterlab to requirements
- Space rebuild in progress (commit a7e66e0)

**Status**: Ready for local testing and smoke test execution
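A minimal end-to-end sketch of how the two new classes are intended to fit together (illustrative only: it assumes `src/` is on the import path, as the smoke test arranges; the output path is a placeholder, and the exact `predict_df()` behaviour depends on the installed `chronos` package):

```python
# Illustrative usage sketch of the new inference modules (not part of this commit).
from datetime import timedelta

from inference.data_fetcher import DataFetcher          # assumes src/ is on sys.path
from inference.chronos_pipeline import ChronosForecaster

fetcher = DataFetcher(use_local=True, context_length=512)
fetcher.load_data()
_, max_date = fetcher.get_available_dates()

# Forecast from 14 days before the end of the data so actuals exist for evaluation.
context_df, future_df = fetcher.prepare_inference_data(
    forecast_date=max_date - timedelta(days=14),
    prediction_length=336,  # 14 days x 24 hours
)

forecaster = ChronosForecaster(model_name="amazon/chronos-2-large", device="auto")
forecaster.load_model()
forecasts = forecaster.predict(context_df=context_df, future_df=future_df, prediction_length=336)
forecaster.save_forecasts(forecasts, "data/evaluation/forecast_14d.parquet")  # placeholder path
```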
Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <[email protected]>
**Files changed**:
- scripts/test_inference_pipeline.py (+177, -0)
- src/inference/__init__.py (+6, -0)
- src/inference/chronos_pipeline.py (+280, -0)
- src/inference/data_fetcher.py (+252, -0)
**scripts/test_inference_pipeline.py** (new file, 177 lines)

```python
"""
Smoke test for zero-shot inference pipeline

Tests:
1. Data loading and preparation
2. Chronos 2 model loading
3. Inference on single border (7 days)
4. Output validation
5. Performance metrics
"""

import sys
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))

from inference.data_fetcher import DataFetcher
from inference.chronos_pipeline import ChronosForecaster
from datetime import datetime, timedelta
import torch
import pandas as pd

def main():
    print("="*60)
    print("FBMC Chronos 2 Zero-Shot Inference - Smoke Test")
    print("="*60)

    # Step 1: Check environment
    print("\n[1] Checking environment...")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("Running on CPU (inference will be slower)")

    # Step 2: Initialize DataFetcher
    print("\n[2] Initializing DataFetcher...")
    fetcher = DataFetcher(
        use_local=True,  # Use local files for testing
        context_length=512  # Use 512 hours context
    )

    # Step 3: Load data
    print("\n[3] Loading unified features...")
    fetcher.load_data()

    # Get available date range
    min_date, max_date = fetcher.get_available_dates()
    print(f"Available data: {min_date} to {max_date}")

    # Select forecast date (use last month as test)
    forecast_date = max_date - timedelta(days=30)
    print(f"Test forecast date: {forecast_date}")

    # Step 4: Prepare inference data (single border, 7 days)
    print("\n[4] Preparing inference data (1 border, 7 days)...")
    test_border = fetcher.target_borders[0]  # Use first border
    print(f"Test border: {test_border}")

    context_df, future_df = fetcher.prepare_inference_data(
        forecast_date=forecast_date,
        prediction_length=168,  # 7 days
        borders=[test_border]
    )

    print(f"Context shape: {context_df.shape}")
    print(f"Future shape: {future_df.shape}")

    # Validate data
    print("\n[5] Validating prepared data...")
    assert 'timestamp' in context_df.columns, "Missing timestamp column"
    assert 'border' in context_df.columns, "Missing border column"
    assert 'target' in context_df.columns, "Missing target column"
    assert len(context_df) > 0, "Empty context data"
    assert len(future_df) > 0, "Empty future data"
    print("[+] Data validation passed!")

    # Check for NaN values
    context_nulls = context_df.isnull().sum().sum()
    future_nulls = future_df.isnull().sum().sum()
    print(f"Context NaN count: {context_nulls}")
    print(f"Future NaN count: {future_nulls}")

    if context_nulls > 0 or future_nulls > 0:
        print("[!] Warning: Data contains NaN values (will be handled by model)")

    # Step 6: Initialize Chronos 2 forecaster
    print("\n[6] Initializing Chronos 2 forecaster...")
    forecaster = ChronosForecaster(
        model_name="amazon/chronos-2-large",
        device="auto"  # Will use GPU if available
    )

    # Step 7: Load model
    print("\n[7] Loading Chronos 2 Large model...")
    print("(This may take a few minutes on first load)")
    forecaster.load_model()
    print("[+] Model loaded successfully!")

    # Step 8: Run inference
    print("\n[8] Running zero-shot inference...")
    print(f"Forecasting {test_border} for 7 days (168 hours)")

    forecasts = forecaster.predict_single_border(
        border=test_border,
        context_df=context_df,
        future_df=future_df,
        prediction_length=168,
        num_samples=100  # 100 samples for probabilistic forecast
    )

    print(f"[+] Inference complete! Forecast shape: {forecasts.shape}")

    # Step 9: Validate forecasts
    print("\n[9] Validating forecasts...")
    assert len(forecasts) > 0, "Empty forecasts"
    assert 'timestamp' in forecasts.columns or forecasts.index.name == 'timestamp', "Missing timestamp"

    # Check for reasonable values
    if 'mean' in forecasts.columns:
        mean_forecast = forecasts['mean']
        print("Forecast statistics:")
        print(f"  Mean: {mean_forecast.mean():.2f} MW")
        print(f"  Min: {mean_forecast.min():.2f} MW")
        print(f"  Max: {mean_forecast.max():.2f} MW")
        print(f"  Std: {mean_forecast.std():.2f} MW")

        # Sanity check: values should be reasonable for power capacity
        assert mean_forecast.min() >= 0, "Negative forecasts detected"
        assert mean_forecast.max() < 20000, "Unreasonably high forecasts"
        print("[+] Forecast validation passed!")

    # Step 10: Benchmark performance
    print("\n[10] Benchmarking inference performance...")
    metrics = forecaster.benchmark_inference(
        context_df=context_df,
        future_df=future_df,
        prediction_length=168
    )

    print("Performance metrics:")
    for key, value in metrics.items():
        print(f"  {key}: {value}")

    # Check if we meet the 5-minute target (for 14 days)
    # Scale to 14-day estimate
    estimated_14d_time = metrics['inference_time_sec'] * (336 / 168)
    print(f"\nEstimated time for 14-day forecast: {estimated_14d_time:.1f}s ({estimated_14d_time/60:.1f} min)")

    if estimated_14d_time < 300:  # 5 minutes
        print("[+] Performance target met! (<5 min for 14 days)")
    else:
        print("[!] Warning: May not meet 5-minute target for 14 days")

    # Step 11: Save test forecasts
    print("\n[11] Saving test forecasts...")
    output_path = "data/evaluation/smoke_test_forecast.parquet"
    forecaster.save_forecasts(forecasts, output_path)
    print(f"[+] Saved to: {output_path}")

    # Summary
    print("\n" + "="*60)
    print("SMOKE TEST SUMMARY")
    print("="*60)
    print("[+] All tests passed!")
    print(f"[+] Border: {test_border}")
    print("[+] Forecast length: 168 hours (7 days)")
    print(f"[+] Inference time: {metrics['inference_time_sec']:.1f}s")
    print(f"[+] Output shape: {forecasts.shape}")
    print("\n[+] Ready for full inference run!")
    print("="*60)

if __name__ == "__main__":
    main()
```
**src/inference/__init__.py** (new file, 6 lines)

```python
"""Zero-shot inference pipeline for FBMC flow forecasting"""

from .data_fetcher import DataFetcher
from .chronos_pipeline import ChronosForecaster

__all__ = ['DataFetcher', 'ChronosForecaster']
```
**src/inference/chronos_pipeline.py** (new file, 280 lines)

```python
"""
Chronos 2 Zero-Shot Inference Pipeline

Handles:
1. Loading Chronos 2 Large model (710M params)
2. Running zero-shot inference using predict_df() API
3. GPU/CPU device mapping
4. Saving predictions to parquet
"""

from pathlib import Path
from typing import Optional, Dict, List
import pandas as pd
import torch
from datetime import datetime
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ChronosForecaster:
    """
    Zero-shot forecaster using Chronos 2 Large model.

    Features:
    - Multivariate forecasting (multiple borders simultaneously)
    - Covariate support (615 future covariates)
    - Large context window (up to 8,192 hours)
    - DataFrame API for easy data handling
    """

    def __init__(
        self,
        model_name: str = "amazon/chronos-2-large",
        device: str = "auto",
        torch_dtype: str = "float32"
    ):
        """
        Initialize Chronos 2 forecaster.

        Args:
            model_name: HuggingFace model name (default: chronos-2-large)
            device: Device to run on ('auto', 'cuda', 'cpu')
            torch_dtype: Torch dtype ('float32', 'float16', 'bfloat16')
        """
        self.model_name = model_name
        self.device = self._resolve_device(device)
        self.torch_dtype = self._resolve_dtype(torch_dtype)
        self.pipeline = None

        logger.info("ChronosForecaster initialized:")
        logger.info(f"  Model: {model_name}")
        logger.info(f"  Device: {self.device}")
        logger.info(f"  Dtype: {self.torch_dtype}")

    def _resolve_device(self, device: str) -> str:
        """Resolve device string to actual device."""
        if device == "auto":
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device

    def _resolve_dtype(self, dtype_str: str) -> torch.dtype:
        """Resolve dtype string to torch dtype."""
        dtype_map = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16
        }
        return dtype_map.get(dtype_str, torch.float32)

    def load_model(self):
        """Load Chronos 2 model from HuggingFace."""
        if self.pipeline is not None:
            logger.info("Model already loaded")
            return

        logger.info(f"Loading {self.model_name}...")
        logger.info("This may take a few minutes on first load...")

        try:
            from chronos import Chronos2Pipeline

            # Load with device_map for GPU support
            self.pipeline = Chronos2Pipeline.from_pretrained(
                self.model_name,
                device_map=self.device if self.device == "cuda" else None,
                torch_dtype=self.torch_dtype
            )

            # Move to device if not using device_map
            if self.device == "cpu":
                self.pipeline = self.pipeline.to(self.device)

            logger.info(f"Model loaded successfully on {self.device}")

            # Print GPU info if available
            if self.device == "cuda":
                gpu_name = torch.cuda.get_device_name(0)
                gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
                logger.info(f"GPU: {gpu_name} ({gpu_memory:.1f} GB VRAM)")

        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise

    def predict(
        self,
        context_df: pd.DataFrame,
        future_df: pd.DataFrame,
        prediction_length: int = 336,
        id_column: str = "border",
        timestamp_column: str = "timestamp",
        num_samples: int = 100
    ) -> pd.DataFrame:
        """
        Run zero-shot inference using Chronos 2.

        Args:
            context_df: Historical data (timestamp, border, target, features)
            future_df: Future covariates (timestamp, border, future_covariates)
            prediction_length: Number of hours to forecast
            id_column: Column name for border ID
            timestamp_column: Column name for timestamp
            num_samples: Number of samples for probabilistic forecast

        Returns:
            forecasts_df: DataFrame with predictions (timestamp, border, mean, median, q10, q90)
        """
        if self.pipeline is None:
            self.load_model()

        logger.info("Running zero-shot inference...")
        logger.info(f"Context shape: {context_df.shape}")
        logger.info(f"Future shape: {future_df.shape}")
        logger.info(f"Prediction length: {prediction_length} hours")
        logger.info(f"Borders: {context_df[id_column].nunique()}")

        try:
            # Run inference
            forecasts = self.pipeline.predict_df(
                context_df=context_df,
                future_df=future_df,
                prediction_length=prediction_length,
                id_column=id_column,
                timestamp_column=timestamp_column,
                num_samples=num_samples
            )

            logger.info(f"Inference complete! Forecast shape: {forecasts.shape}")

            # Add metadata
            forecasts['forecast_date'] = context_df[timestamp_column].max()
            forecasts['model'] = self.model_name

            return forecasts

        except Exception as e:
            logger.error(f"Inference failed: {e}")
            raise

    def predict_single_border(
        self,
        border: str,
        context_df: pd.DataFrame,
        future_df: pd.DataFrame,
        prediction_length: int = 336,
        num_samples: int = 100
    ) -> pd.DataFrame:
        """
        Run inference for a single border (useful for testing).

        Args:
            border: Border name (e.g., 'AT_CZ')
            context_df: Historical data
            future_df: Future covariates
            prediction_length: Hours to forecast
            num_samples: Samples for probabilistic forecast

        Returns:
            forecasts_df: Predictions for single border
        """
        logger.info(f"Running inference for border: {border}")

        # Filter for single border
        context_border = context_df[context_df['border'] == border].copy()
        future_border = future_df[future_df['border'] == border].copy()

        # Run prediction
        forecasts = self.predict(
            context_df=context_border,
            future_df=future_border,
            prediction_length=prediction_length,
            num_samples=num_samples
        )

        return forecasts

    def save_forecasts(
        self,
        forecasts: pd.DataFrame,
        output_path: str,
        include_metadata: bool = True
    ):
        """
        Save forecasts to parquet file.

        Args:
            forecasts: Forecast DataFrame
            output_path: Path to save parquet file
            include_metadata: Include model metadata
        """
        logger.info(f"Saving forecasts to: {output_path}")

        # Create output directory if needed
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Add metadata
        if include_metadata:
            forecasts = forecasts.copy()
            forecasts['saved_at'] = datetime.now()

        # Save to parquet
        forecasts.to_parquet(output_path, index=False)

        logger.info(f"Saved {len(forecasts)} rows to {output_path}")

    def benchmark_inference(
        self,
        context_df: pd.DataFrame,
        future_df: pd.DataFrame,
        prediction_length: int = 336
    ) -> Dict[str, float]:
        """
        Benchmark inference speed and memory usage.

        Args:
            context_df: Historical data
            future_df: Future covariates
            prediction_length: Hours to forecast

        Returns:
            metrics: Dict with inference_time_sec, gpu_memory_mb
        """
        import time

        logger.info("Benchmarking inference performance...")

        # Record start time and memory
        start_time = time.time()
        if self.device == "cuda":
            torch.cuda.reset_peak_memory_stats()

        # Run inference
        _ = self.predict(
            context_df=context_df,
            future_df=future_df,
            prediction_length=prediction_length
        )

        # Record end time and memory
        end_time = time.time()
        inference_time = end_time - start_time

        metrics = {
            'inference_time_sec': inference_time,
            'borders': context_df['border'].nunique(),
            'prediction_length': prediction_length
        }

        if self.device == "cuda":
            peak_memory = torch.cuda.max_memory_allocated() / 1e6  # MB
            metrics['gpu_memory_mb'] = peak_memory

        logger.info(f"Inference time: {inference_time:.2f}s")
        if 'gpu_memory_mb' in metrics:
            logger.info(f"Peak GPU memory: {metrics['gpu_memory_mb']:.1f} MB")

        return metrics
```
**src/inference/data_fetcher.py** (new file, 252 lines)

```python
"""
Data Fetcher for Zero-Shot Inference

Prepares data for Chronos 2 inference by:
1. Loading unified features from HuggingFace Dataset
2. Identifying future covariates from metadata
3. Preparing context window (historical data)
4. Preparing future covariates for forecast horizon
5. Formatting data for Chronos 2 predict_df() API
"""

from pathlib import Path
from typing import Tuple, List, Optional
import pandas as pd
import polars as pl
from datetime import datetime, timedelta
from datasets import load_dataset
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DataFetcher:
    """
    Fetches and prepares data for zero-shot Chronos 2 inference.

    Handles:
    - Loading unified features (2,553 features)
    - Identifying future covariates (615 features)
    - Creating context windows for each border
    - Extending future covariates into forecast horizon
    """

    def __init__(
        self,
        dataset_name: str = "evgueni-p/fbmc-features-24month",
        local_features_path: Optional[str] = None,
        local_metadata_path: Optional[str] = None,
        context_length: int = 512,
        use_local: bool = False
    ):
        """
        Initialize DataFetcher.

        Args:
            dataset_name: HuggingFace dataset name
            local_features_path: Path to local features parquet file
            local_metadata_path: Path to local metadata CSV
            context_length: Number of hours to use as context (default: 512)
            use_local: If True, load from local files instead of HF Dataset
        """
        self.dataset_name = dataset_name
        self.local_features_path = local_features_path or "data/processed/features_unified_24month.parquet"
        self.local_metadata_path = local_metadata_path or "data/processed/features_unified_metadata.csv"
        self.context_length = context_length
        self.use_local = use_local

        # Will be loaded lazily
        self.features_df: Optional[pl.DataFrame] = None
        self.metadata_df: Optional[pd.DataFrame] = None
        self.future_covariate_cols: Optional[List[str]] = None
        self.target_borders: Optional[List[str]] = None

    def load_data(self):
        """Load unified features and metadata."""
        logger.info("Loading unified features and metadata...")

        if self.use_local:
            # Load from local files
            logger.info(f"Loading features from: {self.local_features_path}")
            self.features_df = pl.read_parquet(self.local_features_path)

            logger.info(f"Loading metadata from: {self.local_metadata_path}")
            self.metadata_df = pd.read_csv(self.local_metadata_path)
        else:
            # Load from HuggingFace Dataset
            logger.info(f"Loading features from HF Dataset: {self.dataset_name}")
            dataset = load_dataset(self.dataset_name, split="train")
            self.features_df = pl.from_pandas(dataset.to_pandas())

            # Try to load metadata from HF Dataset
            try:
                metadata_dataset = load_dataset(self.dataset_name, data_files="metadata.csv", split="train")
                self.metadata_df = metadata_dataset.to_pandas()
            except Exception:
                logger.warning("Could not load metadata from HF Dataset, falling back to local")
                self.metadata_df = pd.read_csv(self.local_metadata_path)

        # Ensure timestamp column is datetime
        if 'timestamp' in self.features_df.columns:
            self.features_df = self.features_df.with_columns(
                pl.col('timestamp').str.to_datetime()
            )

        logger.info(f"Loaded {len(self.features_df)} rows, {len(self.features_df.columns)} columns")
        logger.info(f"Date range: {self.features_df['timestamp'].min()} to {self.features_df['timestamp'].max()}")

        # Identify future covariates
        self._identify_future_covariates()

        # Identify target borders
        self._identify_target_borders()

    def _identify_future_covariates(self):
        """Identify columns that are future covariates from metadata."""
        logger.info("Identifying future covariates from metadata...")

        # Filter for future covariates
        future_cov_meta = self.metadata_df[
            self.metadata_df['is_future_covariate'] == True
        ]

        self.future_covariate_cols = future_cov_meta['feature_name'].tolist()

        logger.info(f"Found {len(self.future_covariate_cols)} future covariates")
        logger.info(f"Categories: {future_cov_meta['category'].value_counts().to_dict()}")

    def _identify_target_borders(self):
        """Identify target borders from NTC columns."""
        logger.info("Identifying target borders...")

        # Find all ntc_actual_* columns
        ntc_cols = [col for col in self.features_df.columns if col.startswith('ntc_actual_')]

        # Extract border names
        self.target_borders = [col.replace('ntc_actual_', '') for col in ntc_cols]

        logger.info(f"Found {len(self.target_borders)} target borders")
        logger.info(f"Borders: {', '.join(self.target_borders[:5])}...")

    def prepare_inference_data(
        self,
        forecast_date: datetime,
        prediction_length: int = 336,  # 14 days
        borders: Optional[List[str]] = None
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Prepare context and future data for Chronos 2 inference.

        Args:
            forecast_date: The date to forecast from (as-of date)
            prediction_length: Number of hours to forecast (default: 336 = 14 days)
            borders: List of borders to forecast (default: all borders)

        Returns:
            context_df: Historical data (timestamp, border, target, all features)
            future_df: Future covariates (timestamp, border, future covariates only)
        """
        if self.features_df is None:
            self.load_data()

        borders = borders or self.target_borders

        logger.info(f"Preparing inference data for {len(borders)} borders")
        logger.info(f"Forecast date: {forecast_date}")
        logger.info(f"Context length: {self.context_length} hours")
        logger.info(f"Prediction length: {prediction_length} hours")

        # Extract context window (historical data)
        context_start = forecast_date - timedelta(hours=self.context_length)
        context_df = self.features_df.filter(
            (pl.col('timestamp') >= context_start) &
            (pl.col('timestamp') < forecast_date)
        )

        logger.info(f"Context window: {context_df['timestamp'].min()} to {context_df['timestamp'].max()}")
        logger.info(f"Context rows: {len(context_df)}")

        # Prepare context data for each border
        context_dfs = []
        for border in borders:
            ntc_col = f'ntc_actual_{border}'

            if ntc_col not in context_df.columns:
                logger.warning(f"Border {border} not found in features, skipping")
                continue

            # Select: timestamp, target, all features
            border_context = context_df.select([
                'timestamp',
                pl.lit(border).alias('border'),
                pl.col(ntc_col).alias('target'),
                *[col for col in context_df.columns if col not in ['timestamp', ntc_col]]
            ])

            context_dfs.append(border_context)

        # Combine all borders
        context_combined = pl.concat(context_dfs)

        logger.info(f"Combined context shape: {context_combined.shape}")

        # Prepare future covariates
        # For MVP: Use last known values or simple forward-fill
        # TODO: In production, fetch fresh weather forecasts, generate temporal features
        logger.info("Preparing future covariates...")

        future_dfs = []
        for border in borders:
            # Create future timestamps
            future_timestamps = pd.date_range(
                start=forecast_date,
                periods=prediction_length,
                freq='H'
            )

            # Get last known values of future covariates
            last_row = context_df.filter(pl.col('timestamp') == context_df['timestamp'].max())

            # Extract future covariate values
            future_values = last_row.select(self.future_covariate_cols)

            # Repeat for all future timestamps
            future_border_df = pl.DataFrame({
                'timestamp': future_timestamps,
                'border': [border] * len(future_timestamps)
            })

            # Add future covariate values (forward-fill from last known)
            for col in self.future_covariate_cols:
                if col in future_values.columns:
                    value = future_values[col][0]
                    future_border_df = future_border_df.with_columns(
                        pl.lit(value).alias(col)
                    )

            future_dfs.append(future_border_df)

        # Combine all borders
        future_combined = pl.concat(future_dfs)

        logger.info(f"Future covariates shape: {future_combined.shape}")

        # Convert to pandas for Chronos 2
        context_pd = context_combined.to_pandas()
        future_pd = future_combined.to_pandas()

        logger.info("Data preparation complete!")
        logger.info(f"Context: {context_pd.shape}, Future: {future_pd.shape}")

        return context_pd, future_pd

    def get_available_dates(self) -> Tuple[datetime, datetime]:
        """Get the available date range in the dataset."""
        if self.features_df is None:
            self.load_data()

        min_date = self.features_df['timestamp'].min()
        max_date = self.features_df['timestamp'].max()

        return min_date, max_date
```
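The forward-fill of future covariates above is an MVP shortcut called out in the TODO. Deterministic calendar features, at least, could instead be recomputed over the forecast horizon rather than held at their last observed value. A minimal sketch of that idea, using hypothetical column names (`hour_of_day`, `day_of_week`) that may not match the actual feature set:

```python
import polars as pl

def refresh_calendar_covariates(future_border_df: pl.DataFrame) -> pl.DataFrame:
    """Recompute deterministic time features over the horizon instead of
    forward-filling them. Column names here are illustrative only."""
    return future_border_df.with_columns(
        pl.col("timestamp").dt.hour().alias("hour_of_day"),
        pl.col("timestamp").dt.weekday().alias("day_of_week"),
    )
```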