# aurora_pipeline.py
# End-to-end pipeline for CAMS data → Aurora model → predictions → NetCDF
import subprocess
import os

def get_freest_cuda_device_id():
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'],
            stdout=subprocess.PIPE, encoding='utf-8'
        )
        memory_free = [int(x) for x in result.stdout.strip().split('\n')]
        device_id = memory_free.index(max(memory_free))
        return str(device_id)
    except Exception as e:
        print(f"Could not query nvidia-smi, defaulting to 0. Error: {e}")
        return "0"

# Set CUDA_VISIBLE_DEVICES before importing torch
os.environ["CUDA_VISIBLE_DEVICES"] = get_freest_cuda_device_id()


import torch
import xarray as xr
import pickle
from pathlib import Path
import numpy as np
import zipfile
import cdsapi
from huggingface_hub import hf_hub_download
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from datetime import datetime, timedelta
from aurora import Batch, Metadata, AuroraAirPollution, rollout


class AuroraPipeline:
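    """End-to-end wrapper around the Aurora air-pollution model: loads the
    pretrained static variables from the Hugging Face Hub, builds input batches
    from extracted CAMS NetCDF files, runs the model, and writes the predictions
    back out as CAMS-style NetCDF files."""
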
    def __init__(self, 
             extracted_dir="downloads/extracted", 
             static_path="static_vars.pkl", 
             model_ckpt="aurora-0.4-air-pollution.ckpt", 
             model_repo="microsoft/aurora",
             device=None,
             cpu_only=False):
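        """Configure data paths, model checkpoint, and compute device.

        If `device` is None or "cuda", a CUDA device is used when one is
        available and `cpu_only` is not set; otherwise the pipeline falls back
        to CPU.
        """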

        if device is None or device == "cuda":
            # CUDA_VISIBLE_DEVICES is set, so use 'cuda:0'
            device = "cuda:0" if torch.cuda.is_available() and not cpu_only else "cpu"

        self.extracted_dir = Path(extracted_dir)
        self.static_path = Path(static_path)
        self.model_ckpt = model_ckpt
        self.model_repo = model_repo
        self.device = device
        self.cpu_only = cpu_only or (device == "cpu")
        self.static_vars = self._load_static_vars()
        self.model = None

    def _load_static_vars(self):
        """Load static variables from Hugging Face Hub"""
        static_path = hf_hub_download(
            repo_id="microsoft/aurora",
            filename="aurora-0.4-air-pollution-static.pickle",
        )
        if not Path(static_path).exists():
            raise FileNotFoundError(f"Static variables file not found: {static_path}")
        with open(static_path, "rb") as f:
            static_vars = pickle.load(f)
        return static_vars

    def create_batch(self, date_str, Batch, Metadata, time_index=1):
        """Create a batch for Aurora model from CAMS data
        
        Args:
            date_str: Date string (YYYY-MM-DD)
            Batch: Aurora Batch class
            Metadata: Aurora Metadata class
            time_index: 0 for T-1 (first time), 1 for T (second time).
                Used only for the metadata timestamp; both time steps are kept
                as model input.
        """
        surface_path = self.extracted_dir / f"{date_str}-cams-surface.nc"
        atmos_path = self.extracted_dir / f"{date_str}-cams-atmospheric.nc"
        if not surface_path.exists() or not atmos_path.exists():
            raise FileNotFoundError(f"Missing CAMS files for {date_str} in {self.extracted_dir}")

        surf_vars_ds = xr.open_dataset(surface_path, engine="netcdf4", decode_timedelta=True)
        atmos_vars_ds = xr.open_dataset(atmos_path, engine="netcdf4", decode_timedelta=True)

        # Select zero-hour forecast but keep both time steps
        surf_vars_ds = surf_vars_ds.isel(forecast_period=0)
        atmos_vars_ds = atmos_vars_ds.isel(forecast_period=0)
        
        # Don't select time index - Aurora needs both T-1 and T as input
        print("🕐 Using both time steps (T-1 and T) as input for Aurora")

        # Get the time for metadata (use the specified time_index for metadata only)
        selected_time = surf_vars_ds.forecast_reference_time.values[time_index].astype("datetime64[s]").tolist()

        batch = Batch(
            surf_vars={
                "2t": torch.from_numpy(surf_vars_ds["t2m"].values[None]),
                "10u": torch.from_numpy(surf_vars_ds["u10"].values[None]),
                "10v": torch.from_numpy(surf_vars_ds["v10"].values[None]),
                "msl": torch.from_numpy(surf_vars_ds["msl"].values[None]),
                "pm1": torch.from_numpy(surf_vars_ds["pm1"].values[None]),
                "pm2p5": torch.from_numpy(surf_vars_ds["pm2p5"].values[None]),
                "pm10": torch.from_numpy(surf_vars_ds["pm10"].values[None]),
                "tcco": torch.from_numpy(surf_vars_ds["tcco"].values[None]),
                "tc_no": torch.from_numpy(surf_vars_ds["tc_no"].values[None]),
                "tcno2": torch.from_numpy(surf_vars_ds["tcno2"].values[None]),
                "gtco3": torch.from_numpy(surf_vars_ds["gtco3"].values[None]),
                "tcso2": torch.from_numpy(surf_vars_ds["tcso2"].values[None]),
            },
            static_vars={k: torch.from_numpy(v) for k, v in self.static_vars.items()},
            atmos_vars={
                "t": torch.from_numpy(atmos_vars_ds["t"].values[None]),
                "u": torch.from_numpy(atmos_vars_ds["u"].values[None]),
                "v": torch.from_numpy(atmos_vars_ds["v"].values[None]),
                "q": torch.from_numpy(atmos_vars_ds["q"].values[None]),
                "z": torch.from_numpy(atmos_vars_ds["z"].values[None]),
                "co": torch.from_numpy(atmos_vars_ds["co"].values[None]),
                "no": torch.from_numpy(atmos_vars_ds["no"].values[None]),
                "no2": torch.from_numpy(atmos_vars_ds["no2"].values[None]),
                "go3": torch.from_numpy(atmos_vars_ds["go3"].values[None]),
                "so2": torch.from_numpy(atmos_vars_ds["so2"].values[None]),
            },
            metadata=Metadata(
                lat=torch.from_numpy(atmos_vars_ds.latitude.values),
                lon=torch.from_numpy(atmos_vars_ds.longitude.values),
                time=(selected_time,),
                atmos_levels=tuple(int(level) for level in atmos_vars_ds.pressure_level.values),
            ),
        )
        return batch

    def load_model(self, AuroraAirPollution):
        """Load Aurora model and move to device"""
        import gc
        
        # Check memory BEFORE loading
        if torch.cuda.is_available():
            print("📊 GPU Memory BEFORE loading model:")
            print(f"   Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
            print(f"   Reserved:  {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
            print(f"   Free:      {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved(0)) / 1024**3:.2f} GB")
        
        # Clear cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
        
        model = AuroraAirPollution()
        
        # Check AFTER initialization but BEFORE loading checkpoint
        if torch.cuda.is_available():
            print("📊 GPU Memory AFTER model init:")
            print(f"   Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
        
        model.load_checkpoint(self.model_repo, self.model_ckpt)
        
        # Check AFTER loading checkpoint
        if torch.cuda.is_available():
            print("📊 GPU Memory AFTER checkpoint load:")
            print(f"   Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
        
        model.eval()
        model = model.to(self.device)
        
        # Check AFTER moving to device
        if torch.cuda.is_available():
            print("📊 GPU Memory AFTER moving to device:")
            print(f"   Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
            print(f"   Reserved:  {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
        
        self.model = model
        print(f"✅ Model loaded on {self.device}")
        return model

    def predict(self, batch, rollout, steps=4):
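        """Roll the model forward `steps` steps (12 hours each) and return the
        per-step predictions moved back to CPU."""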
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        
        # Move batch to device
        batch = batch.to(self.device)
        
        with torch.inference_mode():
            predictions = [pred.to("cpu") for pred in rollout(self.model, batch, steps=steps)]
        
        return predictions

    def save_predictions_to_netcdf(self, predictions, output_dir, date_str):
        """Save each prediction step as separate NetCDF files in CAMS format"""
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        print(f"💾 Saving {len(predictions)} prediction steps as separate files")
        
        generation_date = datetime.now().strftime("%Y%m%d")
        saved_files = []

        for step_idx, pred in enumerate(predictions):
            step_num = step_idx + 1
            
            # Create filename: predictiondate_step_generationdate.nc
            filename = f"{date_str}_step{step_num:02d}_{generation_date}.nc"
            file_path = output_dir / filename
            
            # Extract coordinates from this step's metadata
            metadata = pred.metadata
            lats = metadata.lat.cpu().numpy() if hasattr(metadata.lat, 'cpu') else metadata.lat.numpy()
            lons = metadata.lon.cpu().numpy() if hasattr(metadata.lon, 'cpu') else metadata.lon.numpy()
            
            # Create CAMS-compatible coordinates and dimensions
            # CAMS format uses: forecast_period, forecast_reference_time, latitude, longitude
            coords = {
                'forecast_period': ('forecast_period', [0]),  # Single forecast period
                'forecast_reference_time': ('forecast_reference_time', [0, 1]),  # Two reference times (T-1, T)
                'latitude': ('latitude', lats),
                'longitude': ('longitude', lons)
            }
            
            # Add valid_time variable (CAMS format)
            data_vars = {
                'valid_time': (['forecast_reference_time', 'forecast_period'], 
                              np.array([[step_num * 12], [step_num * 12]]))  # Same forecast hours for both ref times
            }
            
            # Add surface variables in CAMS format: (forecast_period, forecast_reference_time, latitude, longitude)
            # Map Aurora variable names to CAMS variable names
            aurora_to_cams_surface = {
                '2t': 't2m',      # 2 metre temperature
                '10u': 'u10',     # 10 metre U wind component  
                '10v': 'v10',     # 10 metre V wind component
                'msl': 'msl',     # Mean sea level pressure (same)
                'pm1': 'pm1',     # PM1 (same)
                'pm2p5': 'pm2p5', # PM2.5 (same)
                'pm10': 'pm10',   # PM10 (same)
                'tcco': 'tcco',   # Total column CO (same)
                'tc_no': 'tc_no', # Total column NO (same)
                'tcno2': 'tcno2', # Total column NO2 (same)
                'gtco3': 'gtco3', # Total column O3 (same)
                'tcso2': 'tcso2'  # Total column SO2 (same)
            }
            
            for aurora_var, var_tensor in pred.surf_vars.items():
                cams_var = aurora_to_cams_surface.get(aurora_var, aurora_var)  # Use CAMS name or fallback to Aurora name
                
                var_data = var_tensor.cpu().numpy() if hasattr(var_tensor, 'cpu') else var_tensor.numpy()
                var_data = np.squeeze(var_data)
                
                # Ensure 2D for surface variables
                if var_data.ndim > 2:
                    while var_data.ndim > 2:
                        var_data = var_data[0]
                elif var_data.ndim < 2:
                    raise ValueError(f"Surface variable {aurora_var} has insufficient dimensions: {var_data.shape}")
                
                # Expand to CAMS format: (1, 2, lat, lon) - same data for both forecast reference times
                cams_data = np.broadcast_to(var_data[np.newaxis, np.newaxis, :, :], (1, 2, var_data.shape[0], var_data.shape[1]))
                data_vars[cams_var] = (['forecast_period', 'forecast_reference_time', 'latitude', 'longitude'], cams_data)
            
            # Add atmospheric variables if present
            # CAMS format: (forecast_period, forecast_reference_time, pressure_level, latitude, longitude)
            # Map Aurora atmospheric variable names to CAMS names
            aurora_to_cams_atmos = {
                't': 't',     # Temperature (same)
                'u': 'u',     # U wind component (same)
                'v': 'v',     # V wind component (same)
                'q': 'q',     # Specific humidity (same)
                'z': 'z',     # Geopotential (same)
                'co': 'co',   # Carbon monoxide (same)
                'no': 'no',   # Nitrogen monoxide (same)
                'no2': 'no2', # Nitrogen dioxide (same)
                'go3': 'go3', # Ozone (same)
                'so2': 'so2'  # Sulphur dioxide (same)
            }
            if hasattr(pred, 'atmos_vars') and pred.atmos_vars:
                atmos_levels = list(metadata.atmos_levels) if hasattr(metadata, 'atmos_levels') else None
                if atmos_levels:
                    coords['pressure_level'] = ('pressure_level', atmos_levels)
                    
                    for aurora_var, var_tensor in pred.atmos_vars.items():
                        cams_var = aurora_to_cams_atmos.get(aurora_var, aurora_var)  # Use CAMS name or fallback
                        
                        var_data = var_tensor.cpu().numpy() if hasattr(var_tensor, 'cpu') else var_tensor.numpy()
                        var_data = np.squeeze(var_data)
                        
                        # Ensure 3D for atmospheric variables (pressure, lat, lon)
                        if var_data.ndim > 3:
                            while var_data.ndim > 3:
                                var_data = var_data[0]
                        elif var_data.ndim < 3:
                            raise ValueError(f"Atmospheric variable {aurora_var} has insufficient dimensions: {var_data.shape}")
                        
                        # Expand to CAMS format: (1, 2, pressure, lat, lon) - same data for both forecast reference times
                        cams_data = np.broadcast_to(var_data[np.newaxis, np.newaxis, :, :, :], 
                                                  (1, 2, var_data.shape[0], var_data.shape[1], var_data.shape[2]))
                        data_vars[cams_var] = (['forecast_period', 'forecast_reference_time', 'pressure_level', 'latitude', 'longitude'], cams_data)
            
            # Create dataset for this step
            ds = xr.Dataset(data_vars, coords=coords)
            
            # Add attributes
            ds.attrs.update({
                'title': f'Aurora Air Pollution Prediction - Step {step_num}',
                'source': 'Aurora model by Microsoft Research',
                'prediction_date': date_str,
                'step': step_num,
                'forecast_hours': step_num * 12,
                'generation_date': generation_date,
                'creation_time': datetime.now().isoformat(),
                'spatial_resolution': f"{abs(lons[1] - lons[0]):.3f} degrees"
            })
            
            # Add variable attributes (using CAMS variable names)
            var_attrs = {
                't2m': {'long_name': '2 metre temperature', 'units': 'K'},
                'u10': {'long_name': '10 metre U wind component', 'units': 'm s-1'},
                'v10': {'long_name': '10 metre V wind component', 'units': 'm s-1'},
                'msl': {'long_name': 'Mean sea level pressure', 'units': 'Pa'},
                'pm1': {'long_name': 'Particulate matter d < 1 um', 'units': 'kg m-3'},
                'pm2p5': {'long_name': 'Particulate matter d < 2.5 um', 'units': 'kg m-3'},
                'pm10': {'long_name': 'Particulate matter d < 10 um', 'units': 'kg m-3'},
                'tcco': {'long_name': 'Total column carbon monoxide', 'units': 'kg m-2'},
                'tc_no': {'long_name': 'Total column nitrogen monoxide', 'units': 'kg m-2'},
                'tcno2': {'long_name': 'Total column nitrogen dioxide', 'units': 'kg m-2'},
                'gtco3': {'long_name': 'Total column ozone', 'units': 'kg m-2'},
                'tcso2': {'long_name': 'Total column sulphur dioxide', 'units': 'kg m-2'},
                # Atmospheric variables
                't': {'long_name': 'Temperature', 'units': 'K'},
                'u': {'long_name': 'U component of wind', 'units': 'm s-1'},
                'v': {'long_name': 'V component of wind', 'units': 'm s-1'},
                'q': {'long_name': 'Specific humidity', 'units': 'kg kg-1'},
                'z': {'long_name': 'Geopotential', 'units': 'm2 s-2'},
                'co': {'long_name': 'Carbon monoxide', 'units': 'kg kg-1'},
                'no': {'long_name': 'Nitrogen monoxide', 'units': 'kg kg-1'},
                'no2': {'long_name': 'Nitrogen dioxide', 'units': 'kg kg-1'},
                'go3': {'long_name': 'Ozone', 'units': 'kg kg-1'},
                'so2': {'long_name': 'Sulphur dioxide', 'units': 'kg kg-1'}
            }
            
            for var_name, attrs in var_attrs.items():
                if var_name in ds.data_vars:
                    ds[var_name].attrs.update(attrs)
            
            # Save to NetCDF
            ds.to_netcdf(file_path, format='NETCDF4')
            saved_files.append(str(file_path))
            print(f"   ✅ Step {step_num}: {filename}")
        
        print(f"✅ Saved {len(saved_files)} prediction files")
        return saved_files

    def _save_predictions_single_file(self, predictions, output_path):
        """Save all prediction steps to a single NetCDF file (new method)"""
        # Get metadata from first prediction
        first_pred = predictions[0]
        metadata = first_pred.metadata

        # Extract coordinates
        lats = metadata.lat.cpu().numpy() if hasattr(metadata.lat, 'cpu') else metadata.lat.numpy()
        lons = metadata.lon.cpu().numpy() if hasattr(metadata.lon, 'cpu') else metadata.lon.numpy()

        # Create step coordinate
        steps = np.arange(len(predictions))

        # Prepare data variables
        data_vars = {}
        coords = {
            'step': ('step', steps),
            'lat': ('lat', lats),
            'lon': ('lon', lons)
        }

        # Add surface variables
        surf_var_names = list(first_pred.surf_vars.keys())
        for var in surf_var_names:
            # Stack predictions along step dimension
            var_data_list = []
            for pred in predictions:
                var_tensor = pred.surf_vars[var]
                # Move to CPU and convert to numpy
                var_data = var_tensor.cpu().numpy() if hasattr(var_tensor, 'cpu') else var_tensor.numpy()

                # Robust dimension handling: squeeze all singleton dimensions and keep only last 2 (lat, lon)
                var_data = np.squeeze(var_data)  # Remove all singleton dimensions

                # Ensure we have exactly 2 dimensions (lat, lon) for surface variables
                if var_data.ndim > 2:
                    # Collapse any leading extra dimensions, keeping only (lat, lon)
                    while var_data.ndim > 2:
                        var_data = var_data[0]
                elif var_data.ndim < 2:
                    raise ValueError(f"Surface variable {var} has insufficient dimensions: {var_data.shape}")

                var_data_list.append(var_data)

            # Stack along step dimension: (steps, lat, lon)
            arr = np.stack(var_data_list, axis=0)
            data_vars[var] = (['step', 'lat', 'lon'], arr)

        # Add atmospheric variables if present
        if hasattr(first_pred, 'atmos_vars') and first_pred.atmos_vars:
            atmos_levels = list(metadata.atmos_levels) if hasattr(metadata, 'atmos_levels') else None
            if atmos_levels:
                coords['pressure_level'] = ('pressure_level', atmos_levels)

                atmos_var_names = list(first_pred.atmos_vars.keys())
                for var in atmos_var_names:
                    var_data_list = []
                    for pred in predictions:
                        var_tensor = pred.atmos_vars[var]
                        # Move to CPU and convert to numpy
                        var_data = var_tensor.cpu().numpy() if hasattr(var_tensor, 'cpu') else var_tensor.numpy()

                        # Robust dimension handling: squeeze singleton dimensions but keep 3D structure
                        var_data = np.squeeze(var_data)  # Remove singleton dimensions

                        # Ensure we have exactly 3 dimensions (levels, lat, lon) for atmospheric variables
                        if var_data.ndim > 3:
                            # Collapse any leading extra dimensions, keeping only (levels, lat, lon)
                            while var_data.ndim > 3:
                                var_data = var_data[0]
                        elif var_data.ndim < 3:
                            raise ValueError(f"Atmospheric variable {var} has insufficient dimensions: {var_data.shape}")

                        var_data_list.append(var_data)

                    # Stack along step dimension: (steps, levels, lat, lon)
                    arr = np.stack(var_data_list, axis=0)
                    data_vars[var] = (['step', 'pressure_level', 'lat', 'lon'], arr)

        # Create dataset
        ds = xr.Dataset(data_vars, coords=coords)

        # Add global attributes
        ds.attrs.update({
            'title': 'Aurora Air Pollution Model Predictions',
            'source': 'Aurora model by Microsoft Research',
            'creation_date': datetime.now().isoformat(),
            'forecast_steps': len(predictions),
            'spatial_resolution': f"{abs(lons[1] - lons[0]):.3f} degrees",
            'Conventions': 'CF-1.8'
        })

        # Add variable attributes for better visualization
        var_attrs = {
            '2t': {'long_name': '2 metre temperature', 'units': 'K'},
            '10u': {'long_name': '10 metre U wind component', 'units': 'm s-1'},
            '10v': {'long_name': '10 metre V wind component', 'units': 'm s-1'},
            'msl': {'long_name': 'Mean sea level pressure', 'units': 'Pa'},
            'pm1': {'long_name': 'Particulate matter d < 1 um', 'units': 'kg m-3'},
            'pm2p5': {'long_name': 'Particulate matter d < 2.5 um', 'units': 'kg m-3'},
            'pm10': {'long_name': 'Particulate matter d < 10 um', 'units': 'kg m-3'},
            'tcco': {'long_name': 'Total column carbon monoxide', 'units': 'kg m-2'},
            'tc_no': {'long_name': 'Total column nitrogen monoxide', 'units': 'kg m-2'},
            'tcno2': {'long_name': 'Total column nitrogen dioxide', 'units': 'kg m-2'},
            'gtco3': {'long_name': 'Total column ozone', 'units': 'kg m-2'},
            'tcso2': {'long_name': 'Total column sulphur dioxide', 'units': 'kg m-2'}
        }

        for var_name, attrs in var_attrs.items():
            if var_name in ds.data_vars:
                ds[var_name].attrs.update(attrs)

        # Save to NetCDF
        ds.to_netcdf(output_path, format='NETCDF4')
        print(f"✅ Predictions saved to {output_path}")
        print(f"   Variables: {list(ds.data_vars.keys())}")
        print(f"   Steps: {len(steps)}")
        print(f"   Spatial grid: {len(lats)}x{len(lons)}")

        return output_path

    def _save_predictions_original_method(self, predictions, output_path):
        """Fallback: Save predictions using the original method (separate files per step)"""
        output_dir = Path(output_path)
        output_dir.mkdir(exist_ok=True)

        for step, pred in enumerate(predictions):
            # Create xarray dataset for surface variables
            surf_data = {}
            for var_name, var_data in pred.surf_vars.items():
                surf_data[var_name] = (
                    ["time", "batch", "lat", "lon"],
                    var_data.cpu().numpy() if hasattr(var_data, 'cpu') else var_data.numpy()
                )

            # Create xarray dataset for atmospheric variables
            atmos_data = {}
            for var_name, var_data in pred.atmos_vars.items():
                atmos_data[var_name] = (
                    ["time", "batch", "level", "lat", "lon"],
                    var_data.cpu().numpy() if hasattr(var_data, 'cpu') else var_data.numpy()
                )

            # Create surface dataset
            surf_ds = xr.Dataset(
                surf_data,
                coords={
                    "time": [pred.metadata.time[0]],
                    "batch": [0],
                    "lat": pred.metadata.lat.cpu().numpy() if hasattr(pred.metadata.lat, 'cpu') else pred.metadata.lat.numpy(),
                    "lon": pred.metadata.lon.cpu().numpy() if hasattr(pred.metadata.lon, 'cpu') else pred.metadata.lon.numpy(),
                }
            )

            # Create atmospheric dataset
            atmos_ds = xr.Dataset(
                atmos_data,
                coords={
                    "time": [pred.metadata.time[0]],
                    "batch": [0],
                    "level": list(pred.metadata.atmos_levels),
                    "lat": pred.metadata.lat.cpu().numpy() if hasattr(pred.metadata.lat, 'cpu') else pred.metadata.lat.numpy(),
                    "lon": pred.metadata.lon.cpu().numpy() if hasattr(pred.metadata.lon, 'cpu') else pred.metadata.lon.numpy(),
                }
            )

            # Save to NetCDF
            surf_filename = f"step_{step:02d}_surface.nc"
            atmos_filename = f"step_{step:02d}_atmospheric.nc"

            surf_ds.to_netcdf(output_dir / surf_filename)
            atmos_ds.to_netcdf(output_dir / atmos_filename)

            print(f"Saved step {step} predictions (fallback method)")

        return output_dir

    def run_pipeline(self, date_str, Batch, Metadata, AuroraAirPollution, rollout, steps=4, output_path=None):
        """Full pipeline: batch creation, model loading, prediction, save output"""
        batch = self.create_batch(date_str, Batch, Metadata)
        self.load_model(AuroraAirPollution)
        predictions = self.predict(batch, rollout, steps=steps)
        if output_path:
            self.save_predictions_to_netcdf(predictions, output_path, date_str)
        return predictions

    def run_aurora_prediction_pipeline(self, date_str, Batch, Metadata, AuroraAirPollution, rollout, steps=4, base_predictions_dir="predictions"):
        """Enhanced Aurora prediction pipeline with organized storage"""
        print(f"🚀 Starting Aurora prediction pipeline for {date_str}")
        print(f"📊 Forward prediction steps: {steps} (covering {steps * 12} hours)")
        
        # Create organized directory structure
        run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        run_dir = Path(base_predictions_dir) / f"{date_str}_run_{run_timestamp}"
        run_dir.mkdir(parents=True, exist_ok=True)
        
        # Load model once
        print("🧠 Loading Aurora model...")
        self.load_model(AuroraAirPollution)
        
        # Use the latest timestamp (index 1) for prediction
        print("📥 Creating input batch for T (second time)...")
        batch = self.create_batch(date_str, Batch, Metadata, time_index=1)
        
        # Run predictions
        print(f"⚡ Running {steps} prediction steps...")
        predictions = self.predict(batch, rollout, steps=steps)
        
        # Save predictions as separate files
        saved_files = self.save_predictions_to_netcdf(predictions, run_dir, date_str)
        
        # Save metadata about the run
        run_metadata = {
            "date": date_str,
            "run_timestamp": run_timestamp,
            "steps": steps,
            "time_coverage_hours": steps * 12,
            "input_times": ["T-1", "T"],
            "prediction_files": saved_files,
            "run_directory": str(run_dir)
        }
        
        metadata_file = run_dir / "run_metadata.json"
        import json
        with open(metadata_file, 'w') as f:
            json.dump(run_metadata, f, indent=2)
        
        print("✅ Aurora prediction pipeline completed")
        print(f"📁 Results saved to: {run_dir}")
        print(f"📊 Coverage: {steps * 12} hours forward from {date_str}")
        
        return run_metadata

    @staticmethod
    def list_prediction_runs(base_predictions_dir="predictions"):
        """List all available prediction runs with metadata"""
        runs = []
        predictions_path = Path(base_predictions_dir)
        
        if not predictions_path.exists():
            return runs
        
        for run_dir in predictions_path.iterdir():
            if run_dir.is_dir() and "_run_" in run_dir.name:
                metadata_file = run_dir / "run_metadata.json"
                
                if metadata_file.exists():
                    try:
                        import json
                        with open(metadata_file, 'r') as f:
                            metadata = json.load(f)
                        
                        # Check if any prediction files exist (new format with separate step files)
                        nc_files = list(run_dir.glob("*.nc"))
                        has_predictions = len(nc_files) > 0
                        
                        # Add additional info
                        metadata['available'] = has_predictions
                        metadata['run_dir'] = str(run_dir)
                        metadata['relative_path'] = run_dir.name
                        metadata['prediction_files'] = [f.name for f in nc_files]
                        metadata['num_files'] = len(nc_files)
                        
                        runs.append(metadata)
                    except Exception as e:
                        print(f"⚠️  Could not read metadata for {run_dir}: {e}")
        
        # Sort by run timestamp (newest first)
        runs.sort(key=lambda x: x.get('run_timestamp', ''), reverse=True)
        return runs

# Example usage (not run on import)
if __name__ == "__main__":
    pass
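
    # Minimal usage sketch. The date, directories, and step count below are
    # illustrative placeholders and assume the matching CAMS files
    # ({date}-cams-surface.nc / {date}-cams-atmospheric.nc) already exist in
    # downloads/extracted.
    pipeline = AuroraPipeline(extracted_dir="downloads/extracted")
    run_info = pipeline.run_aurora_prediction_pipeline(
        "2024-01-01", Batch, Metadata, AuroraAirPollution, rollout,
        steps=4, base_predictions_dir="predictions",
    )
    print(f"Run directory: {run_info['run_directory']}")

    # List previously generated runs, newest first.
    for run in AuroraPipeline.list_prediction_runs("predictions"):
        print(run["date"], run["run_timestamp"], run["num_files"])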