This module provides autoregressive wrappers for various spatial models (FNO, Diffusion) to handle temporal evolution in physics simulations, particularly for MHD and plasma physics applications.
This repo implements a flexible framework for autoregressive temporal modeling with different spatial processing backends:
- Fourier Neural Operator (FNO): Efficient spectral method for spatial processing
- Diffusion Models: Generative approach for spatial field prediction
- Modular Design: Easy to extend with new spatial models via `AutoregressiveBase`
```
AutoregressiveBase (Abstract)
├── AutoregressiveFNO (FNO backend)
└── AutoregressiveDiffusion (Diffusion backend)
```
- `AutoregressiveBase`: Abstract base class handling temporal logic
- `Loss`: Combined loss with relative Lp and temporal consistency terms
- `MultivariableFNO`: FNO implementation for multi-variable spatial fields
- `SpatialDiffusionModel`: Diffusion model for spatial field generation
```python
import torch

# NOTE: adjust the import path to wherever these classes live in the repo
from autoregressive import AutoregressiveFNO, MultivariableFNO, Loss

# Create base FNO model
fno = MultivariableFNO(
    n_vars=1,
    n_modes=(8, 8),
    hidden_channels=16,
    n_layers=4
)

# Wrap with autoregressive functionality
model = AutoregressiveFNO(
    fno_model=fno,
    t_in=10,                    # Input timesteps
    t_out=5,                    # Output timesteps to predict
    step_size=1,                # Single-step prediction
    teacher_forcing_ratio=0.5
)

# Create loss function
loss_fn = Loss(
    d=2,                 # Spatial dimension
    p=2,                 # L2 norm
    alpha_temporal=0.5,  # Temporal consistency weight
    learnable=False      # Fixed weights
)

# Training loop (simplified)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for batch in train_loader:
    x = batch["input_fields"]   # [B, t_in, n_vars, H, W]
    y = batch["output_fields"]  # [B, t_out, n_vars, H, W]

    pred = model(x, target=y, grid=grid)  # `grid` holds spatial coordinates (see note below)
    loss = loss_fn(pred, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
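The snippet above references a `grid` of spatial coordinates that is not constructed there. A common convention (assumed here; the repo's data pipeline may build it differently) is a normalized coordinate mesh:

```python
import torch

# Normalized (x, y) coordinate grid for a 256x256 field, shape [H, W, 2]
H, W = 256, 256
ys, xs = torch.linspace(0, 1, H), torch.linspace(0, 1, W)
gy, gx = torch.meshgrid(ys, xs, indexing="ij")
grid = torch.stack([gx, gy], dim=-1)
```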
```bash
# Basic training
python train_ar_fno.py \
--data-path data/bout_ml_dataset_256.pt \
--experiment-name my_experiment \
--n-epochs 500 \
--batch-size 16
# With custom hyperparameters
python train_ar_fno.py \
--experiment-name advanced_experiment \
--n-modes 12 12 \
--hidden-channels 32 \
--n-layers 6 \
--t-in 10 \
--t-out 5 \
--alpha-temporal 0.3 \
--lr 1e-3 \
--grad-clip 1.0
# Resume training from checkpoint
python train_ar_fno.py \
--experiment-name my_experiment \
--resume
# With teacher forcing schedule
python train_ar_fno.py \
--experiment-name tf_experiment \
    --tf-schedule "10:0.9,30:0.7,60:0.5,120:0.3"
```

Model arguments:
- `--n-modes`: Number of Fourier modes (default: `[8, 8]`)
- `--hidden-channels`: Hidden channels in the FNO (default: 16)
- `--n-layers`: Number of FNO layers (default: 4)
- `--t-in`: Input timesteps (default: 10)
- `--t-out`: Output timesteps (default: 5)
- `--step-size`: Autoregressive step size (default: 1)

Training arguments:
- `--n-epochs`: Number of epochs (default: 500)
- `--batch-size`: Batch size (default: 16)
- `--lr`: Learning rate (default: 1e-3)
- `--weight-decay`: Weight decay (default: 1e-5)
- `--grad-clip`: Gradient clipping threshold (default: 1.0)

Loss arguments:
- `--alpha-temporal`: Temporal consistency weight (default: 0.5)
- `--learnable-loss`: Use learnable loss weights (flag)
Teacher forcing helps stabilize training by occasionally using ground truth instead of predictions during autoregressive rollout.
- `--tf-initial`: Initial teacher forcing ratio (default: 1.0)
- `--tf-schedule`: Schedule as "epoch:ratio" pairs (default: `"150:1.0,250:0.9,300:0.7,5000:0.5"`)
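The schedule string is consumed inside `train_ar_fno.py`; the exact logic there may differ, but as a rough sketch it maps each "epoch:ratio" milestone to the ratio used from that epoch onward:

```python
def tf_ratio_for_epoch(schedule: str, epoch: int, initial: float = 1.0) -> float:
    """Illustrative parser for an "epoch:ratio" schedule string (not the repo's actual code)."""
    ratio = initial
    milestones = sorted(
        (int(e), float(r))
        for e, r in (pair.split(":") for pair in schedule.split(","))
    )
    for milestone_epoch, milestone_ratio in milestones:
        if epoch >= milestone_epoch:
            ratio = milestone_ratio
    return ratio

# e.g. with the default schedule, epoch 260 uses a ratio of 0.9
print(tf_ratio_for_epoch("150:1.0,250:0.9,300:0.7,5000:0.5", 260))
```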
The module expects data in the following format:

```python
{
    'data': torch.Tensor,             # [n_trajectories, n_timesteps, H, W, n_vars]
    'constant_scalars': torch.Tensor, # [n_trajectories, n_scalars]
    'constant_fields': torch.Tensor   # [H, W, n_fields]
}
```
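Assuming the dataset file used in the training commands is a pickled dictionary in this layout, it can be inspected directly:

```python
import torch

dataset = torch.load("data/bout_ml_dataset_256.pt")

print(dataset["data"].shape)              # [n_trajectories, n_timesteps, H, W, n_vars]
print(dataset["constant_scalars"].shape)  # [n_trajectories, n_scalars]
print(dataset["constant_fields"].shape)   # [H, W, n_fields]
```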
Use the provided helpers to build data loaders:

```python
from load_data import load_data, parse_args

args = parse_args()
train_loader, val_loader, full_data = load_data(args)
```

The `AutoregressiveBase` class provides:
- Temporal encoding via 3D convolutions
- Multi-step prediction support
- Teacher forcing during training
- Pure inference without teacher forcing
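For the last point, a sketch of evaluation-time use, assuming `target` is optional in the forward signature (as in the training example above, just without it):

```python
model.eval()
with torch.no_grad():
    # No target passed: the rollout feeds the model's own predictions back in
    pred = model(x, grid=grid)  # [B, t_out, n_vars, H, W]
```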
The Loss module combines:
- Relative Lp Loss: Normalizes by target magnitude
- Temporal Consistency Loss: Penalizes inconsistent temporal derivatives
- Learnable Weights: Optional automatic loss weight balancing
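A minimal sketch of the first two terms, assuming fields shaped `[B, T, n_vars, H, W]` (the repo's `Loss` may normalize and combine them differently):

```python
import torch

def relative_lp(pred, target, p=2, eps=1e-8):
    # ||pred - target||_p / ||target||_p, averaged over the batch
    num = (pred - target).flatten(1).norm(p=p, dim=1)
    den = target.flatten(1).norm(p=p, dim=1)
    return (num / (den + eps)).mean()

def temporal_consistency(pred, target, p=2, eps=1e-8):
    # Compare finite-difference time derivatives of prediction and target
    return relative_lp(pred[:, 1:] - pred[:, :-1],
                       target[:, 1:] - target[:, :-1], p=p, eps=eps)

# total = relative_lp(pred, y) + alpha_temporal * temporal_consistency(pred, y)
```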
Automatic checkpoint loading and saving:
- Saves best model based on validation loss
- Stores optimizer and scheduler states
- Preserves normalization statistics
- Tracks training history
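Restoring a saved model might look like this (the checkpoint key name is an assumption; check what `train_ar_fno.py` actually writes):

```python
import torch

ckpt = torch.load("results/my_experiment/checkpoints/best_model.pth", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])  # key name assumed
```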
Training also produces visual diagnostics:
- Prediction GIFs for qualitative assessment
- Loss component tracking
- Long-term prediction rollout
- Error evolution plots
Extend `AutoregressiveBase` to use custom spatial models:

```python
# NOTE: adjust the import path to wherever AutoregressiveBase lives in the repo
from autoregressive import AutoregressiveBase

class AutoregressiveCustom(AutoregressiveBase):
    def __init__(self, custom_model, **kwargs):
        super().__init__(n_vars=custom_model.n_vars, **kwargs)
        self.custom_model = custom_model

    def _spatial_forward(self, x, grid=None):
        # Implement your spatial processing here
        return self.custom_model(x, grid)
```
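Hypothetical usage with a toy spatial model (the `IdentitySpatialModel` below is only for illustration; any `nn.Module` exposing `n_vars` and accepting `(x, grid)` should work):

```python
import torch.nn as nn

class IdentitySpatialModel(nn.Module):
    """Toy stand-in spatial model, not part of the repo."""
    def __init__(self, n_vars=1):
        super().__init__()
        self.n_vars = n_vars

    def forward(self, x, grid=None):
        return x

model = AutoregressiveCustom(
    custom_model=IdentitySpatialModel(n_vars=1),
    t_in=10,
    t_out=5,
    step_size=1,
    teacher_forcing_ratio=0.5,
)
```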
Enable automatic loss weight learning:

```python
loss_fn = Loss(
    d=2,
    p=2,
    alpha_temporal=0.5,
    learnable=True  # Enable learnable weights
)

# Include loss parameters in optimizer
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(loss_fn.parameters()),
    lr=1e-3
)
```

Configure the model for multi-step prediction:
```python
model = AutoregressiveFNO(
    fno_model=fno,
    t_in=10,
    t_out=20,                  # Predict 20 steps
    step_size=5,               # Predict 5 steps at once
    teacher_forcing_ratio=0.5
)
```
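With this configuration each forward pass emits 5 frames, so covering the 20 output timesteps takes 4 autoregressive iterations instead of 20 single-step ones.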
Training creates the following directory structure:

```
results/
└── experiment_name/
    ├── checkpoints/
    │   └── best_model.pth
    ├── visualizations/
    │   ├── predictions_epoch_*.png
    │   └── ...
    ├── gifs/
    │   ├── training_epoch_*.gif
    │   └── ...
    ├── plots/
    │   ├── loss_history.png
    │   └── loss_weights_history.png
    ├── long_term_predictions/
    │   ├── error_evolution.png
    │   └── prediction_*.gif
    └── config.txt
```
Recommended training practices:
- Start with teacher forcing: Use a high initial teacher forcing ratio (0.9-1.0)
- Gradually reduce: Schedule teacher forcing reduction over training
- Monitor loss components: Watch relative Lp vs temporal consistency
- Validate frequently: Check validation performance every 10-20 epochs
- Use gradient clipping: Prevent exploding gradients (clip at 1.0)
- Normalize data: Compute and use dataset statistics (see the sketch after this list)
- Save checkpoints: Enable automatic checkpoint saving
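For the normalization point, a minimal sketch of computing per-variable statistics from the raw tensor described in the data format section (the training script may compute these per variable or per trajectory differently):

```python
import torch

dataset = torch.load("data/bout_ml_dataset_256.pt")
fields = dataset["data"]  # [n_trajectories, n_timesteps, H, W, n_vars]

# One mean/std per variable, pooled over trajectories, time, and space
mean = fields.mean(dim=(0, 1, 2, 3), keepdim=True)
std = fields.std(dim=(0, 1, 2, 3), keepdim=True)

normalized = (fields - mean) / (std + 1e-8)
```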
If training is unstable or diverges:
- Reduce the learning rate
- Tighten gradient clipping (lower the `--grad-clip` threshold)
- Adjust the teacher forcing schedule
- Check data normalization
If long-horizon predictions degrade:
- Increase `t_in` (more context)
- Reduce `step_size` (finer predictions)
- Tune `alpha_temporal` (temporal consistency)
- Use teacher forcing for longer
If you run out of GPU memory:
- Reduce `batch_size`
- Reduce `hidden_channels`
- Reduce `n_layers`
- Use gradient checkpointing (see the sketch below)