diff --git a/README.md b/README.md index 6a08ed9..1984529 100644 --- a/README.md +++ b/README.md @@ -68,16 +68,15 @@ agent = PixelBasedAgent(env, config, training_config, pixel_config) ```bash # Train with state observations -python examples/train_state_mujoco.py --env HalfCheetah-v4 +python examples/train_mujoco.py --env HalfCheetah-v4 # Train with pixel observations -python examples/train_pixel_mujoco.py --env HalfCheetah-v4 - -# Resume from checkpoint -python examples/train_state_mujoco.py --env HalfCheetah-v4 --resume +python examples/train_mujoco.py --env HalfCheetah-v4 --pixels # Use custom config -python examples/train_pixel_mujoco.py --env Hopper-v4 --config examples/configs/hopper_pixel.yaml +python examples/train_mujoco.py --env Hopper-v4 --pixels --config examples/configs/hopper_pixel.yaml + + ``` ### Configuration Files diff --git a/active_inference_diffusion/agents/base_agent.py b/active_inference_diffusion/agents/base_agent.py index 3be5be9..417a035 100644 --- a/active_inference_diffusion/agents/base_agent.py +++ b/active_inference_diffusion/agents/base_agent.py @@ -87,7 +87,7 @@ def __init__( self.total_steps = 0 self.episode_count = 0 self.exploration_noise = training_config.exploration_noise - self.reward_normalizer = RunningMeanStd(shape=()) + @abstractmethod def _setup_dimensions(self): @@ -106,38 +106,8 @@ def _create_replay_buffer(self) -> ReplayBuffer: def _setup_optimizers(self): """Setup optimizers""" - # Score network optimizer - self.score_optimizer = torch.optim.Adam( - self.active_inference.score_network.parameters(), - lr=self.config.learning_rate - ) - - # Policy optimizer - self.policy_optimizer = torch.optim.Adam( - self.active_inference.policy_network.parameters(), - lr=self.config.learning_rate - ) - - # Value optimizer - self.value_optimizer = torch.optim.Adam( - self.active_inference.value_network.parameters(), - lr=self.config.learning_rate - ) - - # Dynamics optimizer - self.dynamics_optimizer = torch.optim.Adam( - list(self.active_inference.dynamics_model.parameters()) + - list(self.active_inference.reward_predictor.parameters()), - lr=self.config.learning_rate - ) - #Add epistemic optimizer - self.epistemic_optimizer = torch.optim.Adam( - self.active_inference.epistemic_estimator.parameters(), - lr=self.config.learning_rate*0.1, - weight_decay=1e-5 - ) - self.active_inference.epistemic_optimizer = self.epistemic_optimizer - + pass + def act( self, observation: np.ndarray, diff --git a/active_inference_diffusion/agents/pixel_agent.py b/active_inference_diffusion/agents/pixel_agent.py index 41c51d8..2a38b0b 100644 --- a/active_inference_diffusion/agents/pixel_agent.py +++ b/active_inference_diffusion/agents/pixel_agent.py @@ -11,10 +11,10 @@ from typing import Dict, Tuple, Optional, Any from .base_agent import BaseActiveInferenceAgent -from ..core.active_inference import DiffusionActiveInference +from ..core.active_inference import DiffusionActiveInference, EMAModel from ..encoder.visual_encoders import RandomShiftAugmentation, DrQV2Encoder from ..encoder.state_encoders import EncoderFactory -from ..utils.buffers import ReplayBuffer +from ..utils.buffers import ReplayBuffer, SequenceReplayBuffer, PrioritizedSequenceReplayBuffer from ..configs.config import ( ActiveInferenceConfig, PixelObservationConfig, @@ -86,14 +86,17 @@ def _build_models(self): obs_shape=self.obs_shape, feature_dim=self.config.latent_dim, frame_stack=self.pixel_config.frame_stack, - num_layers=4, + num_layers=self.pixel_config.num_layers, num_filters=32, - ) + 
use_spectral_norm=True, + attention='global', + checkpoint_trunk=True, + checkpoint_attention=True, + checkpoint_head=False + ).to(self.device) # Augmentation module - self.augmentation = RandomShiftAugmentation( - pad=self.pixel_config.random_shift_pad - ) if self.pixel_config.augmentation else None + self.augmentation = RandomShiftAugmentation(pad=self.pixel_config.random_shift_pad) if self.pixel_config.augmentation else None # Core diffusion active inference # Uses encoder output dimension as observation dimension @@ -102,9 +105,10 @@ def _build_models(self): action_dim=self.action_dim, latent_dim=self.config.latent_dim, config=self.config, - pixel_shape=self.obs_shape if self.pixel_config.pixel_observation else None + pixel_shape=self.obs_shape if self.pixel_config.pixel_observation else None, + shared_visual_encoder=self.encoder ) - + self.value_ema = EMAModel(self.active_inference.value_network, decay=0.9999, device=self.device) # Move all components to device self.encoder = self.encoder.to(self.device) @@ -113,7 +117,9 @@ def _build_models(self): def act( self, observation: np.ndarray, - deterministic: bool = False + deterministic: bool = False, + frame_idx: Optional[torch.Tensor] = None, + actions: Optional[torch.Tensor] = None, ) -> Tuple[np.ndarray, Dict[str, Any]]: """ Select action using visual encoding and diffusion active inference @@ -135,7 +141,10 @@ def act( # Use diffusion active inference with encoded features action_tensor, info = self.active_inference.act( encoded_obs.squeeze(0), # Remove batch dimension - deterministic=deterministic + deterministic=deterministic, + raw_observation=obs_tensor, + from_idx= frame_idx, + actions=actions ) # Convert to numpy @@ -165,6 +174,13 @@ def act( def encode_observation(self, observation: torch.Tensor) -> torch.Tensor: """Encode pixel observation to feature space with augmentation""" + if observation.device != next(self.encoder.parameters()).device: + observation = observation.to(next(self.encoder.parameters()).device) + + # If observation is uint8, normalize to [0, 1] + if observation.dtype == torch.uint8: + observation = observation.float() / 255.0 + # Apply augmentation during training # Handle different input formats if observation.ndim == 5: # (batch, frame_stack, C, H, W) @@ -205,13 +221,30 @@ def _create_replay_buffer(self) -> ReplayBuffer: buffer_obs_shape = (self.frame_stack, *self.obs_shape) else: buffer_obs_shape = self.obs_shape - return ReplayBuffer( - capacity=self.training_config.buffer_size, - obs_shape=buffer_obs_shape, - action_dim=self.action_dim, - device=self.device, - optimize_memory=True # Enable compression for pixels - ) + use_psr = self.training_config.prioritized_seq_replay + if use_psr: + return PrioritizedSequenceReplayBuffer( + capacity=self.training_config.buffer_size, + obs_shape=buffer_obs_shape, # you already compute this + action_dim=self.action_dim, + device=self.device, + sequence_length=getattr(self.config, "sequence_length", 10), + overlap=getattr(self.config, "sequence_overlap", 5), + alpha=self.training_config.alpha, + beta_start=self.training_config.beta0, + beta_end=self.training_config.beta1, + beta_frames=self.training_config.beta_frms, + eps=self.training_config.eps, + ) + else: + return SequenceReplayBuffer( + capacity=self.training_config.buffer_size, + obs_shape=buffer_obs_shape, + action_dim=self.action_dim, + device=self.device, + sequence_length=self.config.sequence_length if hasattr(self.config, 'sequence_length') else 10, + overlap=5 + ) def _process_observation(self, 
observation: np.ndarray) -> torch.Tensor: """Convert pixel observation to tensor with proper formatting""" @@ -271,159 +304,277 @@ def _process_batch_observations(self, observations: torch.Tensor) -> torch.Tenso return observations def train_step(self) -> Dict[str, float]: - """Enhanced training step with visual representation learning""" + """ + One training iteration: + - Uniform single-step batch for ELBO, policy (EFE), and value. + - PER-enabled sequence-dynamics block every N steps (sample_sequences), + with IS-weighted loss and priority update (TD + per-sequence dynamics NLL). + """ if len(self.replay_buffer) < self.config.batch_size: return {} - - # Sample batch + + import math + metrics: Dict[str, float] = {} + + # ========================= + # 1) Uniform single-step batch + # ========================= batch = self.replay_buffer.sample(self.config.batch_size) - - # Process observations - obs = self._process_batch_observations(batch['observations']) - next_obs = self._process_batch_observations(batch['next_observations']) - actions = batch['actions'].to(self.device) - rewards = batch['rewards'].to(self.device) - dones = batch['dones'].to(self.device) - - metrics = {} - - # 1. Encode observations to feature space - encoded_obs = self.encode_observation(obs) + + # Move to device / pre-process + obs = self._process_batch_observations(batch['observations']) + next_obs = self._process_batch_observations(batch['next_observations']) + actions = batch['actions'].to(self.device) + rewards = batch['rewards'].to(self.device) + dones = batch['dones'].to(self.device) + frame_idx = batch['frame_idx'].to(self.device) + prev_acts = batch['prev_actions'].to(self.device) + B0 = self.config.batch_size + + # Encode observations + encoded_obs = self.encode_observation(obs) encoded_next_obs = self.encode_observation(next_obs) - # First, update reward normalizer statistics - self.reward_normalizer.update(rewards.cpu().numpy()) + if torch.isnan(encoded_obs).any() or torch.isinf(encoded_obs).any(): + raise ValueError("Encoded observation contains NaN/Inf") - # Normalize rewards - normalized_rewards = torch.tensor( - self.reward_normalizer.normalize(rewards.cpu().numpy()), - device=self.device, - dtype=torch.float32 - ) - # 2. Generate latents via diffusion + # Belief update (no grad) with torch.no_grad(): - belief_info = self.active_inference.update_belief_via_diffusion(encoded_obs) - latents = belief_info['latent'] - - next_belief_info = self.active_inference.update_belief_via_diffusion(encoded_next_obs) - next_latents = next_belief_info['latent'] - torch.nn.utils.clip_grad_norm_(self.active_inference.latent_score_network.parameters(), - 0.1) - # 3. 
Train diffusion components + belief_now = self.active_inference.update_belief_via_diffusion( + encoded_obs, frame_idx=frame_idx, actions=prev_acts + ) + latents = belief_now['latent'] + latents_mean = belief_now['latent_mean'] + latents_std = belief_now['latent_std'] + + belief_next = self.active_inference.update_belief_via_diffusion( + encoded_next_obs, frame_idx=frame_idx + 1, actions=actions + ) + next_latents = belief_next['latent'] + + torch.nn.utils.clip_grad_norm_( + self.active_inference.latent_score_network.parameters(), max_norm=1.0 + ) + + # ------------------------- + # 1a) Diffusion ELBO (+ contrastive) — UNWEIGHTED (uniform batch) + # ------------------------- self.score_optimizer.zero_grad() elbo_loss, elbo_info = self.active_inference.compute_diffusion_elbo( - encoded_obs, normalized_rewards, latents + encoded_obs, rewards, raw_observations=obs, frame_index=frame_idx, actions=prev_acts ) - - # 4. Add contrastive representation loss + + hidden_states = self.active_inference.reset_dynamics_hidden(B0) + hidden_states = self.active_inference._reset_done_hidden(hidden_states, dones) + contrastive_loss = self.compute_representation_loss( - encoded_obs, encoded_next_obs, actions, latents, next_latents + encoded_obs, encoded_next_obs, actions, latents, hidden_states ) - - # Combined loss + total_loss = elbo_loss + self.config.contrastive_weight * contrastive_loss total_loss.backward() - + torch.nn.utils.clip_grad_norm_( list(self.active_inference.latent_score_network.parameters()) + list(self.active_inference.latent_diffusion.parameters()) + - list(self.encoder.parameters()), + list(self.encoder.parameters()) + + list(self.active_inference.feature_decoder.parameters()) + + (list(self.active_inference.observation_decoder.parameters()) + if isinstance(self.active_inference.observation_decoder, torch.nn.Module) else []), self.config.gradient_clip ) self.score_optimizer.step() self.score_ema.update() - - metrics.update(elbo_info) - metrics['contrastive_loss'] = contrastive_loss.item() - - # 5. Train policy network + + metrics.update({k: (float(v) if torch.is_tensor(v) else v) for k, v in elbo_info.items()}) + metrics['contrastive_loss'] = float(contrastive_loss.detach()) + metrics['total_loss'] = float(total_loss.detach()) + + # ------------------------- + # 1b) Policy (EFE) — UNWEIGHTED (uniform batch) + # ------------------------- self.policy_optimizer.zero_grad() - + self.active_inference.latent_dynamics.train() + efe, efe_info = self.active_inference.compute_expected_free_energy_diffusion( - latents, - horizon=self.config.efe_horizon + latents, horizon=self.config.efe_horizon ) - policy_loss = efe.mean() policy_loss.backward() - - torch.nn.utils.clip_grad_norm_( - self.active_inference.policy_network.parameters(), - self.config.gradient_clip - ) + torch.nn.utils.clip_grad_norm_(self.active_inference.policy_network.parameters(), self.config.gradient_clip) self.policy_optimizer.step() - - metrics['policy_loss'] = policy_loss.item() - metrics.update(efe_info) - - # 6. 
Train value network + + metrics['policy_loss'] = float(policy_loss.detach()) + metrics.update({f'efe_{k}': (float(v) if torch.is_tensor(v) else v) for k, v in efe_info.items()}) + + # ------------------------- + # 1c) Value (critic) with EMA bootstrap — UNWEIGHTED (uniform batch) + # ------------------------- self.value_optimizer.zero_grad() + logits = self.active_inference.value_network(latents) - batch_size = latents.shape[0] - time_current = torch.zeros(batch_size, device=self.device) - time_next = torch.ones(batch_size, device=self.device) # Next timestep - values = self.active_inference.value_network(latents, time_current).squeeze(-1) - # Predict values with time conditioning with torch.no_grad(): - next_values = self.active_inference.value_network(next_latents, time_next).squeeze(-1) - targets = self.active_inference.compute_lambda_returns( - rewards=normalized_rewards, - values=values, - next_values=next_values, - dones=dones, - lambda_=0.95, # TODO: can be added to config - n_steps=5 - ) - - value_loss = F.huber_loss(values, targets) + self.value_ema.apply_shadow() + next_logits = self.active_inference.value_network(next_latents) + self.value_ema.restore() + next_values = self.active_inference.value_network.expected_value(next_logits) + + lambda_returns = self.active_inference.compute_lambda_returns( + rewards=rewards, next_values=next_values, dones=dones, lambda_=0.95, n_steps=5 + ) + + value_loss = self.active_inference.value_network.loss_from_returns(logits, lambda_returns).mean() value_loss.backward() - - torch.nn.utils.clip_grad_norm_( - self.active_inference.value_network.parameters(), - self.config.gradient_clip - ) + torch.nn.utils.clip_grad_norm_(self.active_inference.value_network.parameters(), self.config.gradient_clip) self.value_optimizer.step() - - metrics['value_loss'] = value_loss.item() - # Train epistemic estimator separately - if self.total_steps % 5 == 0: # Train less frequently for stability + self.value_ema.update() + + metrics['value_loss'] = float(value_loss.detach()) + + # (optional) epistemic estimator every few steps + if self.total_steps % 5 == 0: epistemic_mi, epistemic_metrics = self.active_inference.train_epistemic_estimator( latents, actions, next_latents ) - metrics['epistemic_mi'] = epistemic_mi - metrics.update(epistemic_metrics) + metrics['epistemic_mi'] = float(epistemic_mi) + metrics.update({f'ep_{k}': (float(v) if torch.is_tensor(v) else v) for k, v in epistemic_metrics.items()}) + + # =========================================== + # 2) PER-enabled sequence dynamics every N steps + # =========================================== + if (self.total_steps % 5 == 0) and hasattr(self.replay_buffer, 'sample_sequences'): + seq_batch = self.replay_buffer.sample_sequences(self.config.batch_size // 2) + if seq_batch is not None: + use_per = ('is_weights' in seq_batch) and ('tree_indices' in seq_batch) + is_w = seq_batch['is_weights'] if use_per else None # [B] + tree_idx = seq_batch['tree_indices'] if use_per else None # [B] + + seq_obs = seq_batch['observations'] # [B, T, ...] 
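+                # Window layout: seq_obs is [B, T, ...]; actions/rewards/dones cover the
+                # T-1 transitions, i.e. [B, T-1, ...]. Plan for this block:
+                #   (1) encode each frame and run a no-grad belief update to get latents,
+                #   (2) train the dynamics model on the latent window, weighting each
+                #       sequence by its PER importance weight (w_i ∝ (N * P(i))^(-beta),
+                #       normalized by the max) so the gradient stays unbiased under
+                #       prioritized sampling,
+                #   (3) refresh priorities from TD error plus per-sequence dynamics NLL.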
+ B, T = seq_obs.shape[:2] + + # ---- 2a) Build latent_sequences over the window (no grad) ---- + latent_sequences = [] + for t in range(T): + obs_t = self._process_batch_observations(seq_obs[:, t]) + enc_t = self.encode_observation(obs_t) + with torch.no_grad(): + if t == 0: + prev_t = (seq_batch['prev_actions'][:, 0].to(self.device) + if 'prev_actions' in seq_batch else + torch.zeros(B, self.action_dim, device=self.device)) + else: + prev_t = seq_batch['prev_actions'][:, t - 1].to(self.device) + + fidx_all = seq_batch.get('frame_indices', seq_batch.get('frame_idx')) + fidx_t = fidx_all[:, t].to(self.device) if fidx_all is not None else None + + belief_t = self.active_inference.update_belief_via_diffusion( + enc_t, frame_idx=fidx_t, actions=prev_t + ) + latent_sequences.append(belief_t['latent']) + latent_sequences = torch.stack(latent_sequences, dim=1) # [B, T, D] + + # ---- 2b) Train dynamics on sequences (IS-weighted if PER) ---- + self.dynamics_optimizer.zero_grad() + dyn_kwargs = {} + if is_w is not None: + dyn_kwargs['sample_weights'] = is_w # importance sampling correction + + dynamics_metrics = self.active_inference.train_dynamics_on_sequence( + latent_sequences, # [B, T, D] + seq_batch['actions'], # [B, T-1, A] + seq_batch['dones'], # [B, T-1] + seq_batch['lengths'], # [B] + **dyn_kwargs + ) + + dyn_loss_tensor = dynamics_metrics.get('dynamics_loss', None) + if not torch.is_tensor(dyn_loss_tensor): + # safety: ensure we backward a tensor + dyn_loss_tensor = self.active_inference.train_dynamics_on_sequence( + latent_sequences, + seq_batch['actions'], + seq_batch['dones'], + seq_batch['lengths'], + **dyn_kwargs + )['dynamics_loss'] + + dyn_loss_tensor.backward() + torch.nn.utils.clip_grad_norm_(self.active_inference.latent_dynamics.parameters(), self.config.gradient_clip) + self.dynamics_optimizer.step() + + # ---- 2c) Compute priorities & update tree (PER only) ---- + if use_per and hasattr(self.replay_buffer, 'update_priorities'): + with torch.no_grad(): + # TD errors per step over the window + D = latent_sequences.shape[-1] + v_all = self.active_inference.value_network.expected_value( + self.active_inference.value_network(latent_sequences.reshape(-1, D)) + ).view(B, T) # [B, T] + + r = seq_batch['rewards'][:, :T-1].to(self.device) # [B, T-1] + d = seq_batch['dones'][:, :T-1].float().to(self.device) # [B, T-1] + gamma = float(self.config.discount_factor) + td = r + gamma * (1.0 - d) * v_all[:, 1:] - v_all[:, :-1] # [B, T-1] + + # Per-sequence dynamics NLL (masked mean) + t_idx = torch.arange(T-1, device=self.device).unsqueeze(0).expand(B, -1) + mask = (t_idx < (seq_batch['lengths'].to(self.device).unsqueeze(1) - 1)).float() # [B, T-1] + + per_seq_sum = torch.zeros(B, device=self.device) + per_seq_cnt = torch.zeros(B, device=self.device) + log_2pi = math.log(2.0 * math.pi) + for t in range(T - 1): + cur = latent_sequences[:, t] # [B, D] + nxt = latent_sequences[:, t + 1] # [B, D] + act = seq_batch['actions'][:, t].to(self.device) # [B, A] + + mean_t, logv_t, _ = self.active_inference.predict_next_latent(cur, act, None) + step_nll = 0.5 * ( + log_2pi + logv_t + (nxt - mean_t).pow(2) / logv_t.exp() + ).sum(dim=-1) # [B] + + m = mask[:, t] + per_seq_sum += step_nll * m + per_seq_cnt += m + + eps = torch.finfo(torch.float32).eps + dyn_nll_seq = per_seq_sum / (per_seq_cnt + eps) # [B] + + # Raw priorities from TD (timewise) + per-sequence dynamics surprise + priorities = self.replay_buffer.compute_sequence_priority( + td_errors=td, # [B, T-1] (reduced inside as max/abs) + 
elbo_losses=None, # add when you have per-seq ELBO from the SAME seq_batch + dynamics_losses=dyn_nll_seq, # [B] (reduced inside as mean) + weights=(1.0, 0.0, 0.5) + ) + + self.replay_buffer.update_priorities(tree_idx, priorities) + + # Log sequence metrics + metrics.update({f'seq_{k}': (float(v) if torch.is_tensor(v) else v) + for k, v in dynamics_metrics.items()}) - # 7. Train dynamics model - self.dynamics_optimizer.zero_grad() - - predicted_next_latents, predicted_next_logvar = self.active_inference.predict_next_latent(latents, actions) - dynamics_loss = F.mse_loss(predicted_next_latents, next_latents) - dynamics_loss.backward() - - torch.nn.utils.clip_grad_norm_( - self.active_inference.latent_dynamics.parameters(), - self.config.gradient_clip - ) - self.dynamics_optimizer.step() - - metrics['dynamics_loss'] = dynamics_loss.item() self.total_steps += 1 - return metrics - + def compute_representation_loss( self, obs: torch.Tensor, next_obs: torch.Tensor, actions: torch.Tensor, latents: torch.Tensor, - next_latents: torch.Tensor + hidden_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None ) -> torch.Tensor: """ Contrastive loss for visual representation learning Ensures latent dynamics align with visual features """ + status_training= self.active_inference.latent_dynamics.training + self.active_inference.latent_dynamics.train() # Predict next visual features from current latent and action - predicted_next_latent, predicted_logvar = self.active_inference.predict_next_latent(latents, actions) + predicted_next_latent, predicted_logvar, hidden_states = self.active_inference.predict_next_latent(latents, actions, hidden_states) + + self.active_inference.latent_dynamics.train(status_training) predicted_std = torch.exp(0.5 * predicted_logvar) # Normalize for contrastive loss pred_norm = F.normalize(predicted_next_latent, dim=-1) @@ -439,11 +590,19 @@ def compute_representation_loss( def _setup_optimizers(self): """Setup optimizers including visual components""" # Score network optimizer (includes encoder) + decoder_params = [] + if isinstance(self.active_inference.observation_decoder, nn.ModuleList): + for module in self.active_inference.observation_decoder: + decoder_params.extend(list(module.parameters())) + else: + decoder_params = list(self.active_inference.observation_decoder.parameters()) + self.score_optimizer = torch.optim.AdamW( list(self.active_inference.latent_score_network.parameters()) + list(self.active_inference.latent_diffusion.parameters()) + list(self.encoder.parameters())+ - list(self.active_inference.feature_decoder.parameters()), + list(self.active_inference.feature_decoder.parameters())+ + decoder_params, lr=self.config.learning_rate, weight_decay=1e-5 ) @@ -460,6 +619,7 @@ def _setup_optimizers(self): self.active_inference.value_network.parameters(), lr=self.config.learning_rate ) + # Dynamics optimizer self.dynamics_optimizer = torch.optim.AdamW( @@ -468,3 +628,13 @@ def _setup_optimizers(self): list(self.active_inference.reward_predictor.parameters()), lr=self.config.learning_rate ) + #Add epistemic optimizer + self.epistemic_optimizer = torch.optim.AdamW( + self.active_inference.epistemic_estimator.parameters(), + lr=self.config.learning_rate*0.1, + weight_decay=1e-5, + betas=(0.9, 0.999) + ) + self.active_inference.epistemic_optimizer = self.epistemic_optimizer + + diff --git a/active_inference_diffusion/agents/state_agent.py b/active_inference_diffusion/agents/state_agent.py index edfbf01..8811b27 100644 --- a/active_inference_diffusion/agents/state_agent.py 
+++ b/active_inference_diffusion/agents/state_agent.py
@@ -122,20 +122,13 @@ def train_step(self) -> Dict[str, float]:
         dones = batch['dones'].to(self.device)
         metrics = {}
-        # First, update reward normalizer statistics
-        self.reward_normalizer.update(rewards.cpu().numpy())
-
-        # Normalize rewards
-        normalized_rewards = torch.tensor(
-            self.reward_normalizer.normalize(rewards.cpu().numpy()),
-            device=self.device,
-            dtype=torch.float32
-        )
+
         # 1. Generate latents via diffusion (no gradients needed here)
         with torch.no_grad():
             belief_info = self.active_inference.update_belief_via_diffusion(observations)
             latents = belief_info['latent']
-
+            latents_std = belief_info['latent_std']
+            latents_mean = belief_info['latent_mean']
             next_belief_info = self.active_inference.update_belief_via_diffusion(next_observations)
             next_latents = next_belief_info['latent']
@@ -146,7 +139,7 @@ def train_step(self) -> Dict[str, float]:
         )  # Clip gradients of score network
         self.score_optimizer.zero_grad()
         elbo_loss, elbo_info = self.active_inference.compute_diffusion_elbo(
-            observations, normalized_rewards, latents
+            observations, rewards, latents_mean, latents_std
         )
         elbo_loss.backward()
         torch.nn.utils.clip_grad_norm_(
@@ -271,4 +264,12 @@ def _setup_optimizers(self):
             list(self.active_inference.observation_decoder.parameters()) +
             list(self.active_inference.reward_predictor.parameters()),
             lr=self.config.learning_rate
-        )
\ No newline at end of file
+        )
+        # Add epistemic optimizer
+        self.epistemic_optimizer = torch.optim.AdamW(
+            self.active_inference.epistemic_estimator.parameters(),
+            lr=self.config.learning_rate * 0.1,
+            weight_decay=1e-5,
+            betas=(0.9, 0.999)
+        )
+        self.active_inference.epistemic_optimizer = self.epistemic_optimizer
diff --git a/active_inference_diffusion/configs/config.py b/active_inference_diffusion/configs/config.py
index 7b1c6a0..3098eaa 100644
--- a/active_inference_diffusion/configs/config.py
+++ b/active_inference_diffusion/configs/config.py
@@ -10,7 +10,7 @@
 @dataclass
 class DiffusionConfig:
     """Configuration for diffusion process"""
-    num_diffusion_steps: int = 1000
+    num_diffusion_steps: int = 400
     beta_start: float = 1e-4
     beta_end: float = 0.02
     beta_schedule: str = "cosine"  # Options: "cosine", "linear"
@@ -20,6 +20,8 @@ class DiffusionConfig:
    time_annealing_end: float = 0.1
    annealing_steps: int = 100000
    gradient_clip_val: float = 0.1
+   inference_steps: Optional[int] = 100  # None means use full DDPM sampling
+   ddim_eta: float = 0.3  # Eta for DDIM; 0.0 means deterministic sampling

 @dataclass
 class BeliefDynamicsConfig:
@@ -46,26 +48,26 @@ class ActiveInferenceConfig:
     precision_init: float = 1.0
     expected_free_energy_horizon: int = 5
     efe_horizon: int = 5  # Alias for compatibility
-    epistemic_weight: float = 0.1
+    epistemic_weight: float = 0.6
     extrinsic_weight: float = 1.0
     pragmatic_weight: float = 1.0
     consistency_weight: float = 0.1  # latent policy coherence (referenced by the EFE computation)
     discount_factor: float = 0.99
     contrastive_weight: float = 0.5  # for latent policy coherence
     # Diffusion integration
-    kl_weight: float = 0.1  # KL regularization for diffusion
+    kl_weight: float = 0.9  # KL regularization for diffusion
     diffusion_weight: float = 1.0  # score matching weight
     reward_weight: float = 0.5  # reward scaling
     # Model architecture
-    hidden_dim: int = 512
-    latent_dim: int = 128
-    spatial_aggregator_output_dim: int = 256
-    num_layers: int = 3
+    hidden_dim: int = 128
+    latent_dim: int = 96
+    spatial_aggregator_output_dim: int = 32
+    num_layers: int = 2
     pixel_observation: bool = False  # Use pixel observations
     # Training
     batch_size: int = 256
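    # All optimizers share learning_rate below; the epistemic estimator alone uses
    # learning_rate * 0.1 with weight decay (see the agents' _setup_optimizers).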
learning_rate: float = 5e-5 - gradient_clip: float = 0.5 # Gradient clipping value + gradient_clip: float = 1.0 # Gradient clipping value # Reward-oriented Active Inference parameters preference_temperature: float = 1.0 # τ in P(o) ∝ exp(r(o)/τ) @@ -74,14 +76,29 @@ class ActiveInferenceConfig: max_preference_temperature: float = 10.0 # Upper bound for exploration temperature_decay: float = 0.995 # Exponential decay per episode use_reward_preferences: bool = True # Enable reward-oriented EFE - + #Dynamics model parameters + dynamics_type = "transformer" + dynamics_context_len = 16 + dynamics_num_layers = 2 + dynamics_n_heads = 4 + dynamics_dropout = 0.0 + dynamics_residual = True + dynamics_use_checkpointing = True + dynamics_attn_impl = "mem" # good on 11GB cards + + # Value network parameters + num_value_bins: int = 255 # Number of bins for categorical value distribution + value_net_num_layers: int = 2 # Number of hidden layers in value network # Preference shaping parameters baseline_reward: float = 0.0 # Baseline for reward centering preference_momentum: float = 0.9 # EMA for reward statistics # Nested configs diffusion: DiffusionConfig = field(default_factory=DiffusionConfig) belief_dynamics: BeliefDynamicsConfig = field(default_factory=BeliefDynamicsConfig) - + frame_stack: int = 3 + # + reward_disc_low: float = -10.0 + reward_disc_high: float = 10.0 # Device device: str = "cuda" if torch.cuda.is_available() else "cpu" @@ -93,6 +110,7 @@ class PixelObservationConfig: encoder_type: str = "drqv2" # drqv2, impala, attention encoder_feature_dim: int = 80 augmentation: bool = True + num_layers: int = 3 random_shift_pad: int = 4 pixel_observation: bool = True # Use pixel observations @@ -117,10 +135,18 @@ class TrainingConfig: train_frequency: int = 2 gradient_steps: int = 4 num_parallel_envs: int = 6 # Number of parallel environments for training + # prioritized replay parameters + sequence_length: int = 10 + alpha: float = 0.6 + beta0: float = 0.4 + beta1: float = 0.4 + beta_frms: int = 100_000 + eps: float = 1e-6 # Evaluation num_eval_episodes: int = 10 # Logging use_wandb: bool = True project_name: str = "active-inference-diffusion" - experiment_name: Optional[str] = None \ No newline at end of file + experiment_name: Optional[str] = None + prioritized_seq_replay:bool = True \ No newline at end of file diff --git a/active_inference_diffusion/core/active_inference.py b/active_inference_diffusion/core/active_inference.py index 2ce2f6b..00c117a 100644 --- a/active_inference_diffusion/core/active_inference.py +++ b/active_inference_diffusion/core/active_inference.py @@ -1,135 +1,168 @@ """ Active Inference with Diffusion-Generated Latent Spaces """ + import math import torch import torch.nn as nn import torch.nn.functional as F +import torch.utils.checkpoint as cp +from torch.func import linearize import numpy as np -from typing import Dict, Tuple, Optional, Union +from typing import Dict, Tuple, Optional, Union, Any from ..configs import ActiveInferenceConfig from ..encoder.visual_encoders import ConvDecoder from .diffusion import LatentDiffusionProcess from ..models.score_networks import LatentScoreNetwork from ..models.policy_networks import DiffusionConditionedPolicy from ..models.value_networks import ValueNetwork -from ..models.dynamics_models import LatentDynamicsModel -from ..utils.util import SpatialAttentionAggregator +from ..models.dynamics_models import LatentDynamicsModel, TransformerDynamicsModel +from .free_energy import FreeEnergyComputation +from ..utils.util import symexp, 
symlog, DiscDist +HiddenState = Union[Tuple[torch.Tensor, torch.Tensor], Dict[str, torch.Tensor], None] class DiffusionActiveInference(nn.Module): """ Core Active Inference implementation with diffusion-generated latents - + Key innovations: - Latent beliefs emerge from reverse diffusion process - Policies conditioned on continuous latent manifolds - Expected Free Energy computed over diffusion trajectories """ - + def __init__( self, observation_dim: int, action_dim: int, latent_dim: int, - config: 'ActiveInferenceConfig', - pixel_shape: Optional[Tuple[int, int, int]] = None + config: "ActiveInferenceConfig", + pixel_shape: Optional[Tuple[int, int, int]] = None, + shared_visual_encoder: Optional[nn.Module] = None, + checkpoint_decoder: bool = True, + checkpoint_reward_predictor: bool = True, ): super().__init__() - + self.observation_dim = observation_dim self.action_dim = action_dim self.latent_dim = latent_dim self.config = config - self.pixel_shape = pixel_shape + self.pixel_shape = pixel_shape self.is_pixel_observation = config.pixel_observation self.epistemic_dropout_rate = 0.2 self.device = torch.device(config.device) + self.shared_visual_encoder = shared_visual_encoder if self.is_pixel_observation and pixel_shape is not None: self.raw_observation_shape = pixel_shape else: self.raw_observation_shape = None - + self.ckpt_decoder = checkpoint_decoder + self.ckpt_reward_predictor = checkpoint_reward_predictor # Initialize components self._build_models() self.to(self.device) + # Current belief state (diffusion-generated) self.current_latent = None self.latent_trajectory = [] - + def _build_models(self): """Build core models for diffusion active inference""" - + # Latent diffusion process self.latent_diffusion = LatentDiffusionProcess( - self.config.diffusion, - latent_dim=self.latent_dim + self.config.diffusion, latent_dim=self.latent_dim ) # Add reward preference components - self.register_buffer('reward_mean', torch.tensor(0.0).to(self.device)) - self.register_buffer('reward_var', torch.tensor(1.0).to(self.device)) - self.register_buffer('preference_temperature', torch.tensor(self.config.preference_temperature).to(self.device)) - + self.register_buffer("reward_mean", torch.tensor(0.0).to(self.device)) + self.register_buffer("reward_var", torch.tensor(1.0).to(self.device)) + self.register_buffer( + "preference_temperature", + torch.tensor(self.config.preference_temperature).to(self.device), + ) # Score network for latent generation - + self.latent_score_network = LatentScoreNetwork( latent_dim=self.latent_dim, - observation_dim=self.latent_dim, + observation_dim=self.observation_dim, + action_dim=self.action_dim, hidden_dim=self.config.hidden_dim, - use_attention=True + use_attention=True, ) - + # Policy network conditioned on diffusion latents - + self.policy_network = DiffusionConditionedPolicy( latent_dim=self.latent_dim, action_dim=self.action_dim, hidden_dim=self.config.hidden_dim, - use_state_dependent_std=True + squash_output=True ) - + # Value network for latent states - + self.value_network = ValueNetwork( state_dim=self.latent_dim, # Using latent dimension as state dimension hidden_dim=self.config.hidden_dim, - time_embed_dim=128, # Time embedding dimension - num_layers=3 - ) - - # Dynamics model in latent space - - self.latent_dynamics = LatentDynamicsModel( - state_dim=self.latent_dim, - action_dim=self.action_dim, - hidden_dim=self.config.hidden_dim, - num_layers=3 + num_layers=self.config.value_net_num_layers, + num_bins=self.config.num_value_bins, ) + + # Dynamics model 
in latent space + if self.config.dynamics_type == "transformer": + self.latent_dynamics = TransformerDynamicsModel( + state_dim=self.latent_dim, + action_dim=self.action_dim, + hidden_dim=self.config.hidden_dim, + num_layers=self.config.dynamics_num_layers, + n_heads=self.config.dynamics_n_heads, + dropout=self.config.dynamics_dropout, + context_len=self.config.dynamics_context_len, + residual=self.config.dynamics_residual, + use_checkpointing=self.config.dynamics_use_checkpointing, + attn_impl=self.config.dynamics_attn_impl, # {"auto","flash","mem","math"} for SDPA backends + clear_on_reset=True, + ) + else: + self.latent_dynamics = LatentDynamicsModel( + state_dim=self.latent_dim, + action_dim=self.action_dim, + hidden_dim=self.config.hidden_dim, + num_layers=3, + lstm_hidden_dim=self.config.hidden_dim, + ) + self.current_hidden_state = None # Initialize hidden state for dynamics # Observation decoder (latent -> observation prediction) if not self.is_pixel_observation: - self.observation_decoder = nn.ModuleList([ - nn.Sequential( - nn.Linear(self.latent_dim, self.config.hidden_dim * 2), - nn.LayerNorm(self.config.hidden_dim * 2), - nn.SiLU(), - nn.Dropout(self.epistemic_dropout_rate), - ), - nn.Sequential( - nn.Linear(self.config.hidden_dim * 2, self.config.hidden_dim * 2), - nn.LayerNorm(self.config.hidden_dim * 2), - nn.SiLU(), - nn.Dropout(self.epistemic_dropout_rate), - ), - nn.Sequential( - nn.Linear(self.config.hidden_dim * 2, self.config.hidden_dim), - nn.LayerNorm(self.config.hidden_dim), - nn.SiLU(), - nn.Dropout(self.epistemic_dropout_rate), - ), - nn.Linear(self.config.hidden_dim, self.observation_dim) - ]) - observation_shape = self.observation_dim # For non-pixel observations + self.observation_decoder = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(self.latent_dim, self.config.hidden_dim * 2), + nn.LayerNorm(self.config.hidden_dim * 2), + nn.SiLU(), + nn.Dropout(self.epistemic_dropout_rate), + ), + nn.Sequential( + nn.Linear( + self.config.hidden_dim * 2, self.config.hidden_dim * 2 + ), + nn.LayerNorm(self.config.hidden_dim * 2), + nn.SiLU(), + nn.Dropout(self.epistemic_dropout_rate), + ), + nn.Sequential( + nn.Linear(self.config.hidden_dim * 2, self.config.hidden_dim), + nn.LayerNorm(self.config.hidden_dim), + nn.SiLU(), + nn.Dropout(self.epistemic_dropout_rate), + ), + nn.Linear(self.config.hidden_dim, self.observation_dim), + ] + ) + observation_shape = self.observation_dim # For non-pixel observations else: self.observation_decoder = ConvDecoder( latent_dim=self.latent_dim, @@ -137,51 +170,77 @@ def _build_models(self): img_channels=self.pixel_shape[0], hidden_dim=self.config.hidden_dim, spatial_size=21, - ) + frame_stack=self.config.frame_stack, + ) # Also add a feature decoder for reconstructing encoded features self.feature_decoder = nn.Sequential( nn.Linear(self.latent_dim, self.config.hidden_dim), nn.LayerNorm(self.config.hidden_dim), nn.ReLU(), nn.Dropout(self.epistemic_dropout_rate), - nn.Linear(self.config.hidden_dim, self.latent_dim), # Decode to feature space - nn.Tanh() - ) - observation_shape = self.pixel_shape # For pixel observations + nn.Linear( + self.config.hidden_dim, self.latent_dim + ), # Decode to feature space + nn.Tanh(), + ) + c, h, w = self.pixel_shape + stacked_shape = (c * self.config.frame_stack, h, w) + observation_shape = stacked_shape # For pixel observations + self.free_energy = FreeEnergyComputation( + precision_init=1.0, + observation_decoder=self.observation_decoder, + is_pixel_observation=self.is_pixel_observation, + 
).to(self.device) # Epistemic estimator for latent uncertainty self.epistemic_estimator = FunctionSpaceEpistemicEstimator( decoder=self.observation_decoder, + feature_extractor=self.shared_visual_encoder, latent_dim=self.latent_dim, observation_shape=observation_shape, + is_pixel=self.is_pixel_observation, + device=self.device, hidden_dim=self.config.hidden_dim, - spatial_aggregator_output_dim=self.config.spatial_aggregator_output_dim, - is_pixel_observation=self.is_pixel_observation, - device=self.device - ) - + jac_dim=self.config.spatial_aggregator_output_dim, + latent_proj_dim=self.latent_dim, + use_checkpointing=True, # Enable checkpointing + checkpoint_critic=True, + checkpoint_jacproj=True, + checkpoint_latproj=True, + robust_marginals=False, + ) + self.num_reward_bins = getattr(self.config, "num_reward_bins", 255) + # Initialize a reward predictor self.reward_predictor = nn.Sequential( nn.Linear(self.latent_dim, self.config.hidden_dim), nn.LayerNorm(self.config.hidden_dim), nn.ReLU(), nn.Linear(self.config.hidden_dim, self.config.hidden_dim // 2), + nn.LayerNorm(self.config.hidden_dim // 2), nn.ReLU(), - nn.Linear(self.config.hidden_dim // 2, 2) + nn.Linear(self.config.hidden_dim // 2, self.num_reward_bins), ) + nn.init.zeros_(self.reward_predictor[-1].weight) + nn.init.zeros_(self.reward_predictor[-1].bias) + # Initialize final layer weights to small values + nn.init.kaiming_normal_( + self.reward_predictor[-1].weight, mode="fan_in", nonlinearity="relu" + ) + nn.init.zeros_(self.reward_predictor[-1].bias) def to(self, device): """Override to ensure ALL components move to device""" # Convert device to torch.device if needed if isinstance(device, str): device = torch.device(device) - + # Call parent to() method super().to(device) - + # Update our device attribute self.device = device - + # Explicitly move all components self.latent_diffusion = self.latent_diffusion.to(device) self.latent_score_network = self.latent_score_network.to(device) @@ -189,7 +248,7 @@ def to(self, device): self.value_network = self.value_network.to(device) self.latent_dynamics = self.latent_dynamics.to(device) self.reward_predictor = self.reward_predictor.to(device) - + # Handle observation decoder based on type if isinstance(self.observation_decoder, nn.ModuleList): # For state observations - move each module in the list @@ -198,125 +257,281 @@ def to(self, device): else: # For pixel observations self.observation_decoder = self.observation_decoder.to(device) - + # Move feature decoder if it exists - if hasattr(self, 'feature_decoder'): + if hasattr(self, "feature_decoder"): self.feature_decoder = self.feature_decoder.to(device) - + # Move epistemic estimator with explicit device update self.epistemic_estimator = self.epistemic_estimator.to(device) self.epistemic_estimator.device = device # Update its device attribute - + # Move buffers self.reward_mean = self.reward_mean.to(device) self.reward_var = self.reward_var.to(device) self.preference_temperature = self.preference_temperature.to(device) - + return self - - def decode_observation(self, latent: torch.Tensor, decode_to_pixels: bool = True) -> torch.Tensor: + + def reset_dynamics_hidden( + self, batch_size: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Reset LSTM hidden state for new trajectories""" + return self.latent_dynamics.init_hidden(batch_size, self.device) + + def _clone_dyn_hidden(self, h: HiddenState) -> HiddenState: + if h is None: + return None + if isinstance(h, tuple): + # LSTM + return (h[0].clone(), h[1].clone()) + if isinstance(h, 
dict):
+            # Transformer: clone cached tensors
+            return {k: v.clone() for k, v in h.items()}
+        return h
+
+    def _reset_done_hidden(self, h: HiddenState, done: torch.Tensor) -> HiddenState:
+        """Reset hidden state only for envs where done == 1, preserving the rest."""
+        if h is None or done is None:
+            return h
+        if isinstance(h, tuple):
+            h0, c0 = h
+            dm = done.float().view(1, -1, 1)  # [1, B, 1] to broadcast over LSTM layers
+            # Zero the done envs; keep others
+            h0 = h0 * (1.0 - dm)
+            c0 = c0 * (1.0 - dm)
+            return (h0, c0)
+        if isinstance(h, dict):
+            # Clone before the in-place resets below so shared tensors are untouched
+            out = {k: v.clone() for k, v in h.items()}
+            idx = done.nonzero(as_tuple=False).squeeze(-1)
+            if idx.numel() > 0:
+                if "tokens" in out:
+                    out["tokens"][idx].zero_()
+                if "lengths" in out:
+                    out["lengths"][idx] = 0
+            return out
+        return h
+
+    def _blend_hidden(self, new: HiddenState, done: torch.Tensor) -> HiddenState:
+        """
+        Keep `new` where not done, and insert a freshly reset hidden state where done.
+        (Useful for combining step output with per-env resets.)
+        """
+        if new is None:
+            return None
+
+        # LSTM hidden: tuple(h, c) with shapes [L, B, H]
+        if isinstance(new, tuple):
+            h, c = new
+            # mask shape: [1, B, 1] to broadcast across layers & hidden dim
+            dm = done.to(h.dtype).view(1, -1, 1)
+            # get a fresh reset state for this batch size
+            reset_h, reset_c = self.reset_dynamics_hidden(h.size(1))
+            # blend: reset where done=1, keep new where done=0
+            h = reset_h * dm + h * (1.0 - dm)
+            c = reset_c * dm + c * (1.0 - dm)
+            return (h, c)
+
+        # Dict-style hidden (e.g., token caches, lengths, etc.)
+        if isinstance(new, dict):
+            out = {k: v.clone() for k, v in new.items()}  # avoid in-place ops on the shared graph
+            idx = done.nonzero(as_tuple=False).squeeze(-1)
+            if idx.numel() > 0:
+                reset = self.reset_dynamics_hidden(len(done))
+                # copy only keys that exist in both; leave others untouched
+                if isinstance(reset, dict):
+                    for k in out.keys():
+                        if k in reset:
+                            out[k][idx] = reset[k][idx]
+                else:
+                    # fallback: zero the done indices if reset isn't a dict
+                    for k in out.keys():
+                        out[k][idx] = torch.zeros_like(out[k][idx])
+            return out
+
+        # Fallback for other hidden types (e.g., a GRU tensor): returned untouched
+        return new
+
+    def decode_observation(
+        self, latent: torch.Tensor, decode_to_pixels: bool = True
+    ) -> torch.Tensor:
         """
-        decoding that can decode to either pixels or features
-
+        Decode a latent either to raw pixels or to encoder feature space.
+
         Args:
             latent: Latent representation
             decode_to_pixels: If True and using pixel observations,
                             decode to raw pixels.
                             If False, decode to encoded feature space.
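+
+        Returns:
+            A pixel reconstruction when decode_to_pixels is True (pixel mode);
+            otherwise a feature-space reconstruction matching the encoder output.
+            State mode always decodes back to the raw observation space.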
""" latent = latent.to(self.device) + use_checkpointing = self.ckpt_decoder and latent.requires_grad if self.is_pixel_observation: if decode_to_pixels: # Decode to pixel space - return self.observation_decoder(latent) + if use_checkpointing and self.training: + return cp.checkpoint(self.observation_decoder, latent, use_reentrant=False) + else: + return self.observation_decoder(latent) else: # Decode to feature space (for reconstruction loss) self.feature_decoder = self.feature_decoder.to(self.device) - return self.feature_decoder(latent) + if use_checkpointing and self.training: + return cp.checkpoint(self.feature_decoder, latent, use_reentrant=False) + else: + return self.feature_decoder(latent) else: # For non-pixel observations, use fully connected decoder h = latent - h1 = self.observation_decoder[0](h) - h2 = self.observation_decoder[1](h1) - h2 = h2 + h1 # Skip connection - h3 = self.observation_decoder[2](h2) + if use_checkpointing and self.training: + h1 = cp.checkpoint(self.observation_decoder[0], h, use_reentrant=False) + h2 = cp.checkpoint(self.observation_decoder[1], h1, use_reentrant=False) + h2 = h2 + h1 # Skip connection + h3 = cp.checkpoint(self.observation_decoder[2], h2, use_reentrant=False) + else: + h1 = self.observation_decoder[0](h) + h2 = self.observation_decoder[1](h1) + h2 = h2 + h1 # Skip connection + h3 = self.observation_decoder[2](h2) return self.observation_decoder[3](h3) - - def predict_reward_from_latent(self, latent: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def predict_reward_from_latent( + self, latent: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Use existing reward predictor to get reward distribution from latent """ latent = latent.to(self.device) - reward_params = self.reward_predictor(latent) - reward_mean = reward_params[:, 0] - reward_std = torch.exp(torch.clamp(reward_params[:, 1], min=-5, max=2)) - return reward_mean, reward_std - + if latent.dim() == 1: + latent = latent.unsqueeze(0) + if torch.isnan(latent).any() or torch.isinf(latent).any(): + raise ValueError("Latent tensor contains NaN or Inf values") + if self.ckpt_reward_predictor and latent.requires_grad and self.training: + logits = cp.checkpoint(self.reward_predictor, latent, use_reentrant=False) + else: + logits = self.reward_predictor(latent) + dist = DiscDist(logits, low=self.config.reward_disc_low, high=self.config.reward_disc_high, device=logits.device) + pred = dist.mean().squeeze(-1) + return pred, logits def update_belief_via_diffusion( self, observation: torch.Tensor, - raw_observation: Optional[torch.Tensor] = None + raw_observation: Optional[torch.Tensor] = None, + num_trajectories: int = 5, + frame_idx: Optional[torch.Tensor] = None, + actions: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: """ Update belief using reverse diffusion process This is the core innovation - beliefs as diffusion-generated latents + Returns different statistics based on batch size: + - batch_size=1: Returns uncertainty estimates for single observation + - batch_size>1: Returns per-sample latents without population statistics + """ observation = observation.to(self.device) + if frame_idx is not None: + frame_idx = frame_idx.to(self.device) + if actions is not None: + actions = actions.to(self.device) # Handle different input shapes if observation.dim() == 1: observation = observation.unsqueeze(0) batch_size = 1 else: batch_size = observation.shape[0] - - # Generate latent via reverse diffusion conditioned on observation - trajectory = 
self.latent_diffusion.generate_latent_trajectory( - score_network=self.latent_score_network, - batch_size=batch_size, - observation=observation, - deterministic=False - ) - - # Final latent is the belief - self.current_latent = trajectory[-1] - self.latent_trajectory = trajectory - - # Compute latent statistics - if batch_size == 1: - latent_mean = self.current_latent.squeeze(0) - latent_std = torch.zeros_like(latent_mean) # Single sample, no std - else: - latent_mean = self.current_latent.mean(dim=0) - latent_std = self.current_latent.std(dim=0) - - # Decode to observation space for validation - # Compute reconstruction error appropriately - if self.is_pixel_observation: - # For pixel observations, decode to feature space and compare with encoded features - predicted_features = self.decode_observation(self.current_latent, decode_to_pixels=False) - reconstruction_error = F.mse_loss(predicted_features, observation) - else: - # For state observations, decode to state space - predicted_obs = self.decode_observation(self.current_latent) - reconstruction_error = F.mse_loss(predicted_obs, observation) - + with torch.no_grad(): + # Save training states of all components involved in belief generation + training_states = { + "score_network": self.latent_score_network.training, + "diffusion": self.latent_diffusion.training, + } + + # Set to eval mode for deterministic belief generation + self.latent_score_network.eval() + self.latent_diffusion.eval() + + if batch_size == 1: + # Expand observation to run multiple trajectories in parallel + expanded_obs = observation.expand(num_trajectories, -1) + + with torch.autograd.set_detect_anomaly(True): + # Generate multiple trajectories in parallel + trajectories = self.latent_diffusion.generate_latent_trajectory( + score_network=self.latent_score_network, + batch_size=num_trajectories, # Run num_trajectories in parallel + observation=expanded_obs, + deterministic=False, + frame_time= frame_idx, + action= actions + ) + + final_latents = trajectories[ + -1 + ] # Shape: (num_trajectories, latent_dim) + + # Compute statistics across trajectories + latent_mean = final_latents.mean( + dim=0, keepdim=True + ) # Shape: (1, latent_dim) + latent_std = final_latents.std( + dim=0, keepdim=True + ) # Shape: (1, latent_dim) + + # For current latent, we have options: + + eps = torch.randn_like(latent_std) + self.current_latent = latent_mean + eps * latent_std + + # Store the full trajectory for analysis + self.latent_trajectory = trajectories + trajectory_length = len(trajectories) + else: + # Generate latent via reverse diffusion conditioned on observation + with torch.autograd.set_detect_anomaly(True): + trajectories = self.latent_diffusion.generate_latent_trajectory( + score_network=self.latent_score_network, + batch_size=batch_size, + observation=observation, + deterministic=False, + frame_time= frame_idx, + action = actions + ) + + # Final latent is the belief + self.current_latent = trajectories[-1] + self.latent_trajectory = trajectories + latent_mean = self.current_latent.mean(dim=0, keepdim=True) + latent_std = self.current_latent.std(dim=0, keepdim=True) + trajectory_length = len(trajectories) + self.latent_score_network.train(training_states["score_network"]) + self.latent_diffusion.train(training_states["diffusion"]) + + # Validate outputs before returning + if ( + torch.isnan(self.current_latent).any() + or torch.isinf(self.current_latent).any() + ): + raise ValueError( + "Generated latents contain NaN or Inf values inside belief update!" 
+ ) + return { - 'latent': self.current_latent, - 'latent_mean': latent_mean, - 'latent_std': latent_std, - 'trajectory_length': len(trajectory), - 'reconstruction_error': reconstruction_error, - 'observation': observation, - 'raw_observation': raw_observation, + "latent": self.current_latent, + "latent_mean": latent_mean, + "latent_std": latent_std, + "trajectory_length": trajectory_length, + "observation": observation, + "raw_observation": raw_observation, } - + def compute_expected_free_energy_diffusion( self, latent: torch.Tensor, horizon: int = 5, - num_trajectories: int = 10, - num_ambiguity_samples: int = 10 + num_trajectories: int = 6, + num_ambiguity_samples: int = 2, + hidden_state: HiddenState= None, + done_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Expected Free Energy over diffusion-generated trajectories @@ -326,119 +541,127 @@ def compute_expected_free_energy_diffusion( latent = latent.to(self.device) device = latent.device batch_size = latent.shape[0] - + if hidden_state is None: + hidden_state = self.reset_dynamics_hidden(batch_size) # Initialize accumulators total_efe = torch.zeros(batch_size, device=device) epistemic_values = [] pragmatic_values = [] latent_consistency = [] - + # Generate future latent trajectories for traj_idx in range(num_trajectories): current_latent = latent.clone() traj_efe = 0 - + hidden_state = self._clone_dyn_hidden(hidden_state) # Reset hidden state for each trajectory + dm = done_mask for t in range(horizon): # Sample policy from current latent - action, log_prob, policy_dist = self.policy_network(current_latent) + action, log_prob, policy_dist = self.policy_network(current_latent) + # Predict next latent - next_latent_mean, next_latent_logvar = self.predict_next_latent(current_latent, action) + next_latent_mean, next_latent_logvar, hidden_state = ( + self.predict_next_latent(current_latent, action, hidden_state, done_mask=dm) + ) + dm = None # Only apply done mask at first step next_latent = self.reparameterize(next_latent_mean, next_latent_logvar) # 1. Pragmatic value (reward prediction) # For p(o) ∝ exp(r(o)/τ), we have ln p(o) = r(o)/τ - ln Z # Since Z is constant across policies, we can ignore it - predicted_reward_mean, _ = self.predict_reward_from_latent(next_latent) + predicted_reward, _ = self.predict_reward_from_latent(next_latent) # This makes high-reward states preferred under EFE p(o) ∝ exp(r(o)/τ) - pragmatic = self.config.pragmatic_weight * (predicted_reward_mean / self.preference_temperature) - # Pragmatic value: Expected value under policy - time_tensor = torch.full((batch_size,), float(t), device=device) - value = self.value_network(next_latent, time_tensor).squeeze(-1) - pragmatic += value + pragmatic = predicted_reward / self.preference_temperature + + # Pragmatic value: Expected value under policy + + value = self.value_network(next_latent) # [B, K] + exceptional_value = self.value_network.expected_value(value) # [B], ∑ p(bin)*bin_real + pragmatic += exceptional_value + # 2. 
Consistency (negative policy entropy)-> exploration bonus - consistency = -policy_dist.entropy().sum(dim=-1) + consistency = -self.policy_network.get_policy_entropy(current_latent) - - epistemic, epistemic_metrics= self.compute_epistemic_value( + epistemic, epistemic_metrics = self.compute_epistemic_value( next_latent_mean, next_latent_logvar, - num_samples=num_ambiguity_samples + num_samples=num_ambiguity_samples, ) - + # Accumulate EFE step_efe = ( - self.config.epistemic_weight * epistemic + - self.config.pragmatic_weight * pragmatic + - self.config.consistency_weight * consistency + self.config.epistemic_weight * epistemic + + self.config.pragmatic_weight * pragmatic + + self.config.consistency_weight * consistency ) - - traj_efe += (self.config.discount_factor ** t) * step_efe - + + traj_efe += (self.config.discount_factor**t) * step_efe + # Update for next step current_latent = next_latent - + total_efe += traj_efe / num_trajectories - + # Store components for analysis epistemic_values.append(epistemic) pragmatic_values.append(pragmatic) latent_consistency.append(consistency) - + info = { - 'epistemic_mean': torch.stack(epistemic_values).mean(), - 'pragmatic_mean': torch.stack(pragmatic_values).mean(), - 'consistency_mean': torch.stack(latent_consistency).mean(), - 'num_trajectories': num_trajectories, - 'horizon': horizon, - **epistemic_metrics + "epistemic_mean": torch.stack(epistemic_values).mean(), + "pragmatic_mean": torch.stack(pragmatic_values).mean(), + "consistency_mean": torch.stack(latent_consistency).mean(), + "num_trajectories": num_trajectories, + "horizon": horizon, + **epistemic_metrics, } - + return total_efe, info - + def compute_epistemic_value( - self, + self, next_latent_mean: torch.Tensor, next_latent_logvar: torch.Tensor, - num_samples: int = 5 + num_samples: int = 4, ) -> torch.Tensor: # Compute epistemic value: H(o|s,π) - H(o|s,θ,π) # Epistemic value (ambiguity - observation uncertainty) - #- H[p(o|s,π)] is entropy marginalizing over model parameters (using dropout) - #- H[p(o|s,θ,π)] is entropy for a fixed set of parameters + # - H[p(o|s,π)] is entropy marginalizing over model parameters (using dropout) + # - H[p(o|s,θ,π)] is entropy for a fixed set of parameters next_latent_mean = next_latent_mean.to(self.device) next_latent_logvar = next_latent_logvar.to(self.device) with torch.no_grad(): epistemic_value, metrics = self.epistemic_estimator( - next_latent_mean, - next_latent_logvar, - num_samples + next_latent_mean, next_latent_logvar, num_samples ) - - + return epistemic_value, metrics def train_epistemic_estimator( self, latents: torch.Tensor, actions: torch.Tensor, - next_latents: torch.Tensor + next_latents: torch.Tensor, + hidden_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> float: """Train MINE estimator separately""" latents = latents.to(self.device) actions = actions.to(self.device) + next_latents = next_latents.to(self.device) # Predict next latent distribution - next_mean, next_logvar = self.predict_next_latent(latents, actions) + next_mean, next_logvar, hidden_state = self.predict_next_latent( + latents, actions, hidden_state + ) # Compute MINE loss (negative MI for minimization) mi_estimate, metrics = self.epistemic_estimator(next_mean, next_logvar) - loss = -mi_estimate.mean() - + total_loss = -mi_estimate.mean() + # Optimize self.epistemic_optimizer.zero_grad() - loss.backward() + total_loss.backward() torch.nn.utils.clip_grad_norm_( - self.epistemic_estimator.parameters(), - self.config.gradient_clip + 
self.epistemic_estimator.parameters(), self.config.gradient_clip ) self.epistemic_optimizer.step() @@ -449,37 +672,50 @@ def reparameterize(self, mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tens eps = torch.randn_like(std) return mean + eps * std - def predict_next_latent( self, latent: torch.Tensor, - action: torch.Tensor - ) -> torch.Tensor: + action: torch.Tensor, + hidden_state: HiddenState = None, + done_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, HiddenState]: """Predict next latent state using learned dynamics""" latent = latent.to(self.device) action = action.to(self.device) - delta = self.latent_dynamics(latent, action) - next_mean= latent + delta # Residual connection - next_logvar = torch.full_like(next_mean, np.log(0.1)).to(self.device) - return next_mean, next_logvar - + next_mean, next_logvar, new_hidden_state = self.latent_dynamics(latent, action, hidden_state, done_mask) + + return next_mean, next_logvar, new_hidden_state + def _compute_latent_kl( self, - latent: torch.Tensor, - prior_latent: torch.Tensor - ) -> torch.Tensor: - """KL divergence between latent distributions""" - # Assume Gaussian with unit variance for simplicity - - # Can be extended to learned variances - kl = 0.5 * torch.sum((latent - prior_latent) ** 2, dim=-1) - return kl - + latent_mean: torch.Tensor, + latent_logvar: torch.Tensor, + prior_mean: torch.Tensor, + prior_logvar: torch.Tensor, + ) -> torch.Tensor: + """Proper KL divergence between two Gaussians""" + # KL(q||p) = 0.5 * (log(σ_p²/σ_q²) + (σ_q² + (μ_q - μ_p)²)/σ_p² - 1) + prior_var = torch.exp(prior_logvar) + latent_var = torch.exp(latent_logvar) + + kl = 0.5 * ( + prior_logvar + - latent_logvar + + (latent_var + (latent_mean - prior_mean) ** 2) / prior_var + - 1.0 + ) + + # Sum over latent dimensions, mean over batch + return kl.sum(dim=-1).mean() + def act( self, observation: torch.Tensor, deterministic: bool = False, - raw_observation: Optional[torch.Tensor] = None + raw_observation: Optional[torch.Tensor] = None, + maintain_hidden_state: bool = True, + from_idx: Optional[torch.Tensor] = None, + actions: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Select action using diffusion-generated latent active inference @@ -489,26 +725,34 @@ def act( observation = observation.unsqueeze(0) # Update belief via diffusion - belief_info = self.update_belief_via_diffusion(observation, raw_observation) - + belief_info = self.update_belief_via_diffusion(observation, raw_observation, frame_idx=from_idx, actions=actions) + batch_size = observation.shape[0] + + if not maintain_hidden_state or self.current_hidden_state is None: + # Reset hidden state if not maintaining or not set + self.current_hidden_state = self.reset_dynamics_hidden(batch_size) # Current latent belief - latent = belief_info['latent'] + latent = belief_info["latent"] # Ensure latent has proper shape for policy if latent.dim() == 1: latent = latent.unsqueeze(0) - + # Compute expected free energy efe, efe_info = self.compute_expected_free_energy_diffusion( latent, - horizon=self.config.efe_horizon + horizon=self.config.efe_horizon, + hidden_state=self.current_hidden_state, + done_mask=None, ) - + # Get action from policy conditioned on latent action, log_prob, policy_dist = self.policy_network( - latent, - deterministic=deterministic + latent, deterministic=deterministic ) - action =action.cpu() + _, _, self.current_hidden_state = self.predict_next_latent( + latent, action, self.current_hidden_state + ) + action = 
action.cpu() # Ensure action has proper shape if action.dim() > 2: action = action.squeeze() @@ -518,300 +762,400 @@ def act( # Handle scalar actions action = action.unsqueeze(0) - # Compile information info = { - **belief_info, - 'expected_free_energy': efe.mean().cpu().item(), - 'action_log_prob': log_prob.mean().cpu().item(), - 'policy_entropy': policy_dist.entropy().sum(dim=-1).mean().cpu().item(), - **{k: v.cpu().item() if torch.is_tensor(v) else v for k, v in efe_info.items()} + **belief_info, + "expected_free_energy": efe.mean().cpu().item(), + "action_log_prob": log_prob.mean().cpu().item(), + "policy_entropy": policy_dist.get_policy_entropy(latent).sum(dim=-1).mean().cpu().item(), + **{ + k: v.cpu().item() if torch.is_tensor(v) else v + for k, v in efe_info.items() + }, } - + return action, info - + def compute_diffusion_elbo( self, observations: torch.Tensor, rewards: torch.Tensor, - latents: Optional[torch.Tensor] = None, - raw_observations: Optional[torch.Tensor] = None + latents_mean: Optional[torch.Tensor] = None, + latents_std: Optional[torch.Tensor] = None, + raw_observations: Optional[torch.Tensor] = None, + frame_index: Optional[torch.Tensor] = None, + actions: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Annealing time conditioned ELBO for diffusion-generated latents Modified ELBO for diffusion-generated latent active inference - + L = E_q(z|o,π)[log p(o|z)] - D_KL[q(z|o,π)||p_θ(z)] + R_diffusion(θ) """ observations = observations.to(self.device) rewards = rewards.to(self.device) batch_size = observations.shape[0] device = self.device - + # Generate latents if not provided - if latents is None: + if latents_mean is None or latents_std is None: # Use current belief generation - belief_info = self.update_belief_via_diffusion(observations, raw_observations) - latents = belief_info['latent'] + belief_info = self.update_belief_via_diffusion( + observations, raw_observations, frame_idx=frame_index, actions=actions + ) + latents = belief_info["latent"] + latents_mean = belief_info["latent_mean"] + latents_std = belief_info["latent_std"] + if latents_mean.dim() == 1: + latents_mean = latents_mean.unsqueeze(0) + if latents_std.dim() == 1: + latents_std = latents_std.unsqueeze(0) + if ( + torch.isnan(latents_mean).any() + or torch.isinf(latents_mean).any() + or torch.isnan(latents_std).any() + or torch.isinf(latents_std).any() + or torch.isnan(latents).any() + or torch.isinf(latents).any() + ): + raise ValueError( + "Latent mean, std, or latent contains NaN or Inf values during belief update" + ) + else: + # Use provided latents + latents_mean = latents_mean.to(device) + latents_std = latents_std.to(device) + eps = torch.randn_like(latents_std) + latents = latents_mean + eps * latents_std - # Reconstruction term - if self.is_pixel_observation: - # For pixel observations, reconstruct features - predicted_features = self.decode_observation(latents, decode_to_pixels=False) - reconstruction_loss = F.mse_loss(predicted_features, observations) - else: - # For state observations, reconstruct states - predicted_obs = self.decode_observation(latents) - reconstruction_loss = F.mse_loss(predicted_obs, observations) - - + fe_loss, fe_info = self.free_energy.compute_loss( + states=latents, # Your latent states z + observations=observations, # Encoded features (for pixels) or states + score_network=self.latent_score_network, + current_time=0.0, + raw_observations=raw_observations, # Pass raw pixels if available + frame_idx=frame_index, + actions=actions 
+ ) + + + + # Update precision based on complexity vs accuracy balance + self.free_energy.update_precision( + complexity=fe_info["complexity"], + accuracy=fe_info["accuracy"] + ) # Diffusion score matching loss # Sample continuous time with importance sampling # Emphasize times where loss is typically high - if hasattr(self, 'time_importance_weights'): + if hasattr(self, "time_importance_weights"): # Use learned importance weights t = self._importance_sample_time(batch_size, device) else: # Uniform sampling initially t = torch.rand(batch_size, device=device) - noise = torch.randn_like(latents, device=device) - - noisy_latents, true_noise, sample_info = self.latent_diffusion.continuous_q_sample(latents, t, noise) - - predicted_score = self.latent_score_network( - noisy_latents, - t, - observations + + noisy_latents, true_noise, sample_info = ( + self.latent_diffusion.continuous_q_sample(latents, t, noise, frame_time=frame_index, actions=actions) ) + + predicted_score = self.latent_score_network(noisy_latents, t, observations, frame_time=frame_index, action=actions) # Compute true score with proper scaling - log_snr = sample_info['log_snr'] - alpha = sample_info['alpha'] - sigma = sample_info['sigma'] - + + sigma = sample_info["sigma"] + # True score: -noise / sigma (not sqrt(1-alpha) for continuous time) - true_score = -noise / (sigma + 1e-8) - + true_score = -noise / (sigma + torch.finfo(sigma.dtype).eps) + # Annealed loss weight loss_weight = self.latent_diffusion.compute_loss_weight(t) # Score matching loss with annealing score_diff = predicted_score - true_score - #shape: (batch_size, latent_dim) - per_sample_losses = loss_weight.view(-1) * torch.sum(score_diff ** 2, dim=1) + # shape: (batch_size, latent_dim) + per_sample_losses = loss_weight.view(-1) * torch.sum(score_diff**2, dim=1) score_matching_loss = torch.mean(per_sample_losses) # Add gradient penalty for stability - grad_penalty = self._compute_gradient_penalty(noisy_latents, t, observations) - + grad_penalty = self._compute_gradient_penalty(noisy_latents, t, observations, + frame_index=frame_index, + actions=actions) + # KL term with annealing - prior_latents = self.latent_diffusion.sample_latent_prior(batch_size, device) - kl_loss = self._compute_latent_kl(latents, prior_latents).mean() + prior_latent_mean, prior_latent_std = self.latent_diffusion.sample_latent_prior( + batch_size, device + ) + latent_logvar = torch.log(latents_std.pow(2) + torch.finfo(latents_std.dtype).eps) + prior_logvar = torch.log(prior_latent_std.pow(2) + torch.finfo(prior_latent_std.dtype).eps) + kl_loss = self._compute_latent_kl( + latents_mean, latent_logvar, prior_latent_mean, prior_logvar + ) kl_weight = torch.exp(-5.0 * t.mean()) # Anneal KL over time - - # Add reward prediction loss if rewards provided - predicted_rewards = self.reward_predictor(latents) - rewards_mean =predicted_rewards[:, 0] - rewards_std = torch.exp(torch.clamp(predicted_rewards[:, 1],min=-5, max=2)) - rewards_distribution = torch.distributions.Normal(rewards_mean, rewards_std) - reward_loss = -rewards_distribution.log_prob(rewards).mean() + # Predict rewards from latents + pred_reward, logits_reward = self.predict_reward_from_latent(latents) + disc = DiscDist(logits_reward, low=self.config.reward_disc_low, high=self.config.reward_disc_high, device=logits_reward.device) + reward_loss = (-disc.log_prob(rewards)).mean() # Total ELBO - elbo = -reconstruction_loss + self.config.kl_weight * kl_loss*kl_weight + \ - self.config.diffusion_weight * score_matching_loss+ \ - 
0.1*grad_penalty - \ - self.config.reward_weight * reward_loss - self._update_time_importance(t, per_sample_losses.detach()) + elbo = ( + -fe_info["reconstruction_mse"] + + self.config.kl_weight * kl_loss * kl_weight + + self.config.diffusion_weight * score_matching_loss + + 0.1 * grad_penalty + - self.config.reward_weight * reward_loss + ) + self._update_time_importance(t, per_sample_losses.detach()) info = { - 'reconstruction_loss': reconstruction_loss.item(), - 'kl_loss': kl_loss.item(), - 'score_matching_loss': score_matching_loss.item(), - 'elbo': elbo.item(), - 'reward_loss': reward_loss.item(), - 'grad_penalty': grad_penalty.item(), - 'mean_time': t.mean().item(), - 'loss_weight_mean': loss_weight.mean().item() + "reconstruction_loss": fe_info["reconstruction_mse"].item(), + "kl_loss": kl_loss.item(), + "score_matching_loss": score_matching_loss.item(), + "elbo": elbo.item(), + "reward_loss": reward_loss.item(), + "grad_penalty": grad_penalty.item(), + "mean_time": t.mean().item(), + "loss_weight_mean": loss_weight.mean().item(), } - + return -elbo, info # Return negative ELBO as loss + + def train_dynamics_on_sequence( + self, + latent_sequences: torch.Tensor, # [batch_size, seq_len, latent_dim] + action_sequences: torch.Tensor, # [batch_size, seq_len-1, action_dim] + done_sequences: torch.Tensor, # [batch_size, seq_len-1] + sequence_lengths: torch.Tensor, # [batch_size] actual lengths + sample_weights: Optional[torch.Tensor] = None + ) -> Dict[str, float]: + """ + Train dynamics with proper hidden state flow and done masking + """ + batch_size, max_seq_len = latent_sequences.shape[:2] + device = self.device + hidden_states = self.reset_dynamics_hidden(batch_size) + # Initialize hidden states for all sequences + + + total_loss = 0 + total_nll = 0 + total_steps = 0 + + for t in range(max_seq_len - 1): + # Create mask for valid time steps + mask = (t < sequence_lengths - 1).float().to(device) + if mask.sum() == 0: + break + + current_latent = latent_sequences[:, t] + next_latent_true = latent_sequences[:, t + 1] + action = action_sequences[:, t] + done = done_sequences[:, t] + + # Predict next latent with hidden state + next_mean, next_logvar, new_hidden_states = self.predict_next_latent( + current_latent, action, hidden_states, done_mask=done + ) + + # Compute NLL loss + nll = 0.5 * ( + np.log(2 * np.pi) + + next_logvar + + (next_latent_true - next_mean).pow(2) / next_logvar.exp() + ).sum(dim=-1) + + # Apply sequence mask + if sample_weights is not None: + sw = sample_weights # [B] + masked_nll = ( (nll * mask) * sw ).sum() / ((mask * sw).sum() + torch.finfo(torch.float32).eps) + else: + masked_nll = (nll * mask).sum() / (mask.sum() + torch.finfo(torch.float32).eps) + + # Reset hidden states where episodes ended + # This is crucial for proper sequence handling! 
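`_blend_hidden` is referenced just below but its body is not part of this hunk. A minimal sketch of the behavior it needs, assuming a flat `[batch, hidden_dim]` recurrent state and a zero reset (the real helper may instead blend toward a learned initial state):

```python
import torch

def reset_hidden_where_done(hidden: torch.Tensor, done: torch.Tensor) -> torch.Tensor:
    """Zero the recurrent state of sequences whose episode just ended.

    hidden: [batch, hidden_dim] state after the current step
    done:   [batch] mask, 1 where the transition terminated an episode
    """
    keep = (1.0 - done.float()).unsqueeze(-1)  # [batch, 1]
    return hidden * keep                       # terminated rows restart from zeros

# Batch of 3; the second episode terminates at this step
h = torch.randn(3, 8)
d = torch.tensor([0.0, 1.0, 0.0])
assert reset_hidden_where_done(h, d)[1].abs().sum() == 0
```

Without such a reset, state from a finished episode would leak into the first steps of the next one, corrupting the dynamics targets.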
+ if done.any(): + # Create new hidden states for completed episodes + new_hidden_states = self._blend_hidden(new=new_hidden_states, done=done) + + + hidden_states = new_hidden_states + + total_loss += masked_nll + total_nll += masked_nll.detach() + total_steps += 1 + + avg_loss = total_loss / max(total_steps, 1) + + return { + "dynamics_loss": avg_loss.item(), + "dynamics_nll": (total_nll / max(total_steps, 1)).item(), + "valid_steps": total_steps, + } + def compute_lambda_returns( self, - rewards: torch.Tensor, - values: torch.Tensor, - next_values: torch.Tensor, - dones: torch.Tensor, + rewards: torch.Tensor, # [B] + next_values: torch.Tensor, # [B], where next_values[t] = V(s_{t+1}) + dones: torch.Tensor, # [B] in {0,1} for transition t -> t+1 lambda_: float = 0.95, n_steps: int = 5, - exclude_immediate_rewards: bool = False + exclude_immediate_rewards: bool = False, ) -> torch.Tensor: """ - Compute λ-returns as in Dreamer v2. - - The λ-return is a weighted average of n-step returns: - When exclude_immediate_reward=True, returns are computed without immediate rewards, - making the value function learn V(s) = E[Σ_{t'=t+1}^T γ^{t'-t} r_{t'}] - instead of V(s) = E[Σ_{t'=t}^T γ^{t'-t} r_{t'}] + Correct n-step λ-returns with proper bootstrapping index: + G_t^(λ,n) = r_t + γ (1-d_t) [ (1-λ) V_{t+1} + λ G_{t+1} ] (unrolled n steps) + where V_{t+1} = next_values[t], and the final bootstrap at horizon J + uses V_{t+J} = next_values[t+J-1]. + + Assumes 'batch as time' (i.e., index t+1 is the next transition in the batch). """ - batch_size = rewards.shape[0] - device = rewards.device - - # Initialize returns storage - lambda_returns = torch.zeros_like(rewards).to(device) - - # Compute n-step returns - for idx in range(batch_size): - returns = [] - - # Calculate different n-step returns - for n in range(1, min(n_steps + 1, batch_size - idx)): - n_step_return = 0 - discount = 1.0 - - # Sum discounted rewards for n steps - for k in range(n): - if idx + k < batch_size: - if not (exclude_immediate_rewards and k == 0): - n_step_return += discount * rewards[idx + k] - discount *= self.config.discount_factor * (1 - dones[idx + k].float()) - - # Add bootstrapped value - if idx + n < batch_size and not dones[idx + n - 1]: - n_step_return += discount * next_values[idx + n] - - returns.append(n_step_return) - - # Compute weighted average with λ - if returns: - weighted_return = 0 - lambda_sum = 0 - - for i, ret in enumerate(returns[:-1]): - weight = (1 - lambda_) * (lambda_ ** i) - weighted_return += weight * ret - lambda_sum += weight - - # Last return gets remaining weight - if len(returns) > 0: - last_weight = lambda_ ** (len(returns) - 1) - weighted_return += last_weight * returns[-1] - lambda_sum += last_weight - - lambda_returns[idx] = weighted_return / (lambda_sum + 1e-8) - else: - if exclude_immediate_rewards: - lambda_returns[idx] = self.config.discount_factor * (1 - dones[idx].float()) * next_values[idx] - else: - lambda_returns[idx] = rewards[idx] + self.config.discount_factor * (1 - dones[idx].float()) * next_values[idx] - - return lambda_returns - + assert rewards.device == next_values.device == dones.device + assert rewards.dtype == next_values.dtype + + B = rewards.shape[0] + out = torch.zeros_like(rewards) + gamma = float(self.config.discount_factor) + lam = float(lambda_) + one_m = 1.0 - lam + dones = dones.float() + + # scalar zero on the correct device/dtype (no grads, no 0*x) + zero = rewards.new_zeros(()) + + for t in range(B): + # horizon length for this start + J = min(int(n_steps), B 
- t) + if J <= 0: + out[t] = zero + continue + + # Bootstrap V(s_{t+J}) which is stored at next_values[t+J-1] + idx_boot = t + J - 1 + if idx_boot >= B: # defensive bound (shouldn't fire with the J formula above) + idx_boot = B - 1 + G = next_values[idx_boot] + + # Backward recursion: k = J-1 ... 0 + for k in range(J - 1, -1, -1): + idx = t + k + + r = rewards[idx] + if exclude_immediate_rewards and k == 0: + r = zero # cleaner than 0.0 * rewards[idx] + + m = 1.0 - dones[idx] # stop bootstrapping at terminals + v1 = next_values[idx] # V(s_{idx+1}) + + # G_k = r_k + γ m_k [ (1-λ) V_{k+1} + λ G_{k+1} ] + G = r + gamma * m * (one_m * v1 + lam * G) + + out[t] = G + + return out + def _compute_gradient_penalty( - self, - noisy_latents: torch.Tensor, - t: torch.Tensor, - observations: torch.Tensor + self, noisy_latents: torch.Tensor, t: torch.Tensor, observations: torch.Tensor, + frame_index: Optional[torch.Tensor] = None, + actions: Optional[torch.Tensor] = None ) -> torch.Tensor: """Gradient penalty for stable training""" - noisy_latents = noisy_latents.detach().requires_grad_(True) - score = self.latent_score_network(noisy_latents, t, observations) - + noisy_latents = noisy_latents.detach().requires_grad_(True) + score = self.latent_score_network(noisy_latents, t, observations, + frame_time=frame_index, action=actions) + gradients = torch.autograd.grad( outputs=score.sum(), inputs=noisy_latents, create_graph=True, - retain_graph=True + retain_graph=True, )[0] - + grad_norm = gradients.norm(2, dim=1) penalty = torch.mean((grad_norm - 1.0) ** 2) - + return penalty def _importance_sample_time( - self, - batch_size: int, - device: torch.device + self, batch_size: int, device: torch.device ) -> torch.Tensor: """Sample time with importance weights based on loss history""" - if not hasattr(self, 'time_importance_weights'): + if not hasattr(self, "time_importance_weights"): # Initialize uniform self.time_importance_weights = torch.ones(100, device=device) - + # Sample from categorical distribution probs = F.softmax(self.time_importance_weights, dim=0) indices = torch.multinomial(probs, batch_size, replacement=True) - + # Convert to continuous time t = (indices.float() + torch.rand(batch_size, device=device)) / 100.0 - + return t - def _update_time_importance( - self, - t: torch.Tensor, - loss: torch.Tensor - ): + def _update_time_importance(self, t: torch.Tensor, loss: torch.Tensor): """Update importance weights for time sampling""" - if not hasattr(self, 'time_importance_weights'): + if not hasattr(self, "time_importance_weights"): self.time_importance_weights = torch.ones(100, device=t.device) - + # Discretize time indices = (t * 99).long().clamp(0, 99) if loss.dim() > 1: - # If loss has multiple dimensions, reduce to scalar per sample - loss = loss.view(loss.shape[0], -1).sum(dim=1) + # If loss has multiple dimensions, reduce to scalar per sample + loss = loss.view(loss.shape[0], -1).sum(dim=1) # Update weights with EMA - for i in range(len(indices)): + for i in range(len(indices)): time_bin = indices[i].item() sample_loss = loss[i].item() current_weight = self.time_importance_weights[time_bin].item() new_weight = 0.99 * current_weight + 0.01 * sample_loss self.time_importance_weights[time_bin] = new_weight + def extract(a: torch.Tensor, t: torch.Tensor, x_shape: Tuple) -> torch.Tensor: """Extract coefficients helper""" batch_size = t.shape[0] out = a.gather(-1, t) return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))) - + + class EMAModel: """Exponential Moving Average of model weights for 
stable training""" + def __init__(self, model, decay=0.9999, device=None): self.model = model self.decay = decay self.device = device self.shadow = {} self.backup = {} - + # Initialize shadow weights for name, param in model.named_parameters(): if param.requires_grad: self.shadow[name] = param.data.clone().to(device) - + def update(self): """Update EMA weights""" for name, param in self.model.named_parameters(): if param.requires_grad: self.shadow[name] = ( - self.decay * self.shadow[name] + - (1 - self.decay) * param.data + self.decay * self.shadow[name] + (1 - self.decay) * param.data ) - + def apply_shadow(self): """Apply EMA weights to model""" for name, param in self.model.named_parameters(): if param.requires_grad: self.backup[name] = param.data param.data = self.shadow[name] - + def restore(self): """Restore original weights""" for name, param in self.model.named_parameters(): if param.requires_grad: param.data = self.backup[name] + class EMALoss(torch.autograd.Function): @staticmethod def forward(ctx, input, running_ema): @@ -822,9 +1166,12 @@ def forward(ctx, input, running_ema): @staticmethod def backward(ctx, grad_output): input, running_mean = ctx.saved_tensors - grad = grad_output * input.exp().detach() / (running_mean + 1e-6) / input.shape[0] + grad = ( + grad_output * input.exp().detach() / (running_mean + 1e-6) / input.shape[0] + ) return grad, None + def ema_loss(x, running_mean, alpha=0.01): """Exponential moving average loss for stable MINE training""" t_exp = torch.exp(torch.logsumexp(x, 0) - math.log(x.shape[0])).detach() @@ -838,226 +1185,267 @@ def ema_loss(x, running_mean, alpha=0.01): class FunctionSpaceEpistemicEstimator(nn.Module): """ - Computes epistemic value I(o; θ | z) via function-space features - using neural tangent kernel approximation and MINE estimation + Memory-efficient epistemic estimator using JVP features of decoder outputs, + passed through a *provided* feature_extractor (built outside, e.g., in your + state agent or pixel agent). + + Clean API: + - decoder: nn.Module mapping latent z -> observation x + - feature_extractor: nn.Module mapping observation x -> feature vector + - is_pixel: whether x is [C,H,W] (True) or flat [D] (False) + - No internal construction of encoders. No AMP/mixed precision. + + Forward returns a per-batch epistemic score based on a MINE-DV lower bound: + I(Z; Φ(J_z f)) ≈ E_joint[T] - log E_marg[exp T] + where Φ encodes JVPs of f (decoder) along random directions in z-space. + + Returns a per-batch epistemic score (MINE-DV bound) and metrics. 
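The MINE-DV bound named above, isolated as a standalone helper with made-up critic scores; `dv_lower_bound` is illustrative only, since the class computes the same quantity inline in `forward`:

```python
import math
import torch

def dv_lower_bound(t_joint: torch.Tensor, t_marg: torch.Tensor) -> torch.Tensor:
    """Donsker-Varadhan: I(X;Y) >= E_joint[T] - log E_marg[exp(T)]."""
    # log-mean-exp is the numerically stable form of log E[exp(T)]
    log_mean_exp = torch.logsumexp(t_marg, dim=0) - math.log(t_marg.numel())
    return t_joint.mean() - log_mean_exp

# A critic that scores joint pairs above shuffled pairs yields a positive bound
t_j = torch.tensor([1.2, 0.9, 1.1])   # scores on paired (jacobian, latent) samples
t_m = torch.tensor([-0.1, 0.2, 0.0])  # scores on permuted (independent) pairs
print(dv_lower_bound(t_j, t_m))       # ~1.03
```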
""" - + def __init__( - self, + self, + *, decoder: nn.Module, + feature_extractor: nn.Module, latent_dim: int, - observation_shape: Union[int,Tuple[int, int, int]], + observation_shape: Union[int, Tuple[int, int, int]], + is_pixel: bool, + device: Union[str, torch.device] = "cuda", + ntk_samples: int = 3, + jvp_chunk_size: int = 1, hidden_dim: int = 256, - spatial_aggregator_output_dim: int = 256, - is_pixel_observation: bool = True, - device: Union[str, torch.device] = 'cuda' - ): + jac_dim: int = 128, + latent_proj_dim: int = 128, + use_checkpointing: bool = False, + checkpoint_critic: Union[bool, None] = None, + checkpoint_jacproj: Union[bool, None] = None, + checkpoint_latproj: Union[bool, None] = None, + robust_marginals: bool = False, + eps: float = torch.finfo(torch.float16).eps, + ) -> None: super().__init__() + + # External modules (you provide them) self.decoder = decoder - self.latent_dim = latent_dim - self.is_pixel= is_pixel_observation - self.device = torch.device(device) if isinstance(device, str) else device - - # Jacobian approximation parameters - self.ntk_samples = 4 - self.perturbation_scale = nn.Parameter(torch.tensor(0.1)).to(self.device) - - if self.is_pixel: - self.pixel_shape = observation_shape - # Pixel-aware MINE architecture using ConvolutionalStatisticsNetwork pattern - self.pixel_processor = nn.Sequential( - nn.Conv2d(self.pixel_shape[0], 32, kernel_size=5, stride=2, padding=2), - nn.ReLU(), - nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2), - nn.ReLU(), - nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2), - nn.ReLU(), - ) - # Spatial attention aggregator instead of average pooling - self.spatial_aggregator = SpatialAttentionAggregator( - feature_dim=128, - num_heads=8, - spatial_dim=21 # After 3 stride-2 convolutions from 84x84 - ) - # Jacobian feature projection - jacobian_dim = spatial_aggregator_output_dim * self.ntk_samples # 128 channels, 2 spatial dimensions (H, W) - else: - self.state_dim = observation_shape - self.feature_extractor = nn.Sequential( - nn.Linear(self.state_dim, 128), - nn.ReLU(), - nn.Linear(128, 256), - nn.ReLU(), - nn.Linear(256, 128) - ) - jacobian_dim = 128 * self.ntk_samples + self.feature_extractor = feature_extractor + + # Shapes / flags + self.latent_dim = int(latent_dim) + self.observation_shape = observation_shape + self.is_pixel = bool(is_pixel) + self.device = torch.device(device) + + # JVP settings + self.ntk_samples = int(ntk_samples) + self.jvp_chunk_size = max(1, int(jvp_chunk_size)) + # Checkpointing config + self.use_checkpointing = bool(use_checkpointing) + self.cp_critic = self.use_checkpointing if checkpoint_critic is None else bool(checkpoint_critic) + self.cp_jacproj = self.use_checkpointing if checkpoint_jacproj is None else bool(checkpoint_jacproj) + self.cp_latproj = self.use_checkpointing if checkpoint_latproj is None else bool(checkpoint_latproj) + + self.robust_marginals = bool(robust_marginals) + self.eps = float(eps) + + # ---- One-time feature probe (no grads; not doing JVP here) ---- + self.decoder.eval() + self.feature_extractor.eval() + with torch.no_grad(): + dummy_z = torch.zeros(1, self.latent_dim, device=self.device) + dummy_x = self.decoder(dummy_z) # [1, D] or [1, C, H, W] + if self.is_pixel: + f = self.feature_extractor(dummy_x).reshape(1, -1) + else: + f = self.feature_extractor(dummy_x.reshape(1, -1)).reshape(1, -1) + self._feat_dim = int(f.shape[-1]) + + jacobian_dim = self.ntk_samples * self._feat_dim + + # Projectors + critic (cheap ReLU; no AMP) self.jacobian_projector = 
nn.Sequential( - nn.Linear(jacobian_dim, 512), - nn.LayerNorm(512), - nn.ReLU(), - nn.Dropout(0.1), - nn.Linear(512, 256) + nn.Linear(jacobian_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, jac_dim), ) - - # Latent feature processor - self.latent_processor = nn.Sequential( - nn.Linear(latent_dim, 128), - nn.ReLU(), - nn.Linear(128, 128) + self.latent_projector = nn.Sequential( + nn.Linear(self.latent_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, latent_proj_dim), ) - - # MINE statistics network - self.mine_network = nn.Sequential( - nn.Linear(spatial_aggregator_output_dim + 128, 512), - nn.ReLU(), - nn.Dropout(0.1), - nn.Linear(512, 512), - nn.ReLU(), - nn.Dropout(0.1), - nn.Linear(512, 1) + self.critic = nn.Sequential( + nn.Linear(jac_dim + latent_proj_dim, hidden_dim), + nn.ReLU(inplace=True), + nn.Linear(hidden_dim, 1), ) - - # EMA for stable MINE training - self.register_buffer('running_mean', torch.tensor(0.0)) - self.alpha = 0.01 + + # Logging helpers + self.register_buffer("running_mean", torch.tensor(0.0), persistent=True) + self.register_buffer("ema_decay", torch.tensor(0.99), persistent=True) + + # Default back to training mode for user modules + self.train() self.to(self.device) - def to(self, device): - """Override to ensure proper device movement""" - super().to(device) - self.device = device if isinstance(device, torch.device) else torch.device(device) - - # Ensure decoder is also moved properly - if isinstance(self.decoder, nn.ModuleList): - for i in range(len(self.decoder)): - self.decoder[i] = self.decoder[i].to(device) + # ---------- utils ---------- + def to(self, device: Union[str, torch.device]): # type: ignore[override] + ret = super().to(device) + self.device = torch.device(device) + return ret + + @staticmethod + def _maybe_checkpoint(mod: nn.Module, x: torch.Tensor, use_cp: bool) -> torch.Tensor: + """ + Checkpoint only if requested AND at least one input requires grad. + (PyTorch checkpoint needs a grad-requiring tensor input.) + """ + if use_cp and isinstance(x, torch.Tensor) and x.requires_grad: + return cp.checkpoint(mod, x, use_reentrant=False) + return mod(x) + + def _encode_obs_features(self, obs: torch.Tensor) -> torch.Tensor: + """ + obs: [N, B, *] JVP outputs; returns [N, B, F]. No grads through encoder. + """ + N, B = obs.shape[:2] + if self.is_pixel: + x = obs.reshape(N * B, *obs.shape[2:]) + with torch.inference_mode(): + feats = self.feature_extractor(x).reshape(N, B, -1) else: - self.decoder = self.decoder.to(device) - - return self + x = obs.reshape(N * B, -1) + with torch.inference_mode(): + feats = self.feature_extractor(x).reshape(N, B, -1) + return feats - def compute_jacobian_features(self, z: torch.Tensor) -> torch.Tensor: + # ---------- JVP via torch.func.linearize (chunked tangents) ---------- + def _compute_chunked_jvp(self, z: torch.Tensor) -> torch.Tensor: """ - Approximates function-space features via finite differences - in the neural tangent kernel regime + z: [B, Dz] -> returns [B, ntk_samples * F] + Uses `linearize(self.decoder, z)` once, then applies jvp_fn to batched directions. 
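The `torch.func.linearize` pattern this docstring describes, reduced to a toy decoder. The `nn.Linear` stand-in and all shapes are illustrative, and PyTorch >= 2.0 is assumed for `torch.func`:

```python
import torch
from torch.func import linearize

# Toy decoder f: R^4 -> R^6. linearize traces f once at z and returns
# jvp_fn, which maps a direction v to J_f(z) @ v without ever
# materializing the full Jacobian.
f = torch.nn.Linear(4, 6)
z = torch.randn(2, 4)            # batch of latents
out, jvp_fn = linearize(f, z)    # single forward trace at z
v = torch.randn(2, 4)            # one tangent direction per sample
tangent = jvp_fn(v)              # [2, 6] directional derivative along v
assert tangent.shape == out.shape
```

Reusing one trace across many directions is what makes the chunked loop cheaper than calling `torch.func.jvp` once per direction.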
""" - z = z.to(self.device) - batch_size = z.shape[0] - jacobian_samples = [] - decoder_training = self.decoder.training - self.decoder.eval() # Ensure decoder is in eval mode - # Base decoding - with torch.no_grad(): - - f_z = self.decoder(z) # (B, 3, 84, 84) - if self.is_pixel and f_z.dim() == 4: # (B, C, H, W) - f_z_flat = f_z.view(batch_size, -1) - else: - f_z_flat = f_z - epsilon = self.perturbation_scale - # Compute directional derivatives - for _ in range(self.ntk_samples): - # Sample perturbation direction - delta = F.normalize(torch.randn_like(z).to(self.device), dim=-1) * epsilon - - # Compute finite difference - with torch.no_grad(): - f_z_perturbed = self.decoder(z + delta) - - if self.is_pixel and f_z_perturbed.dim() == 4: - f_z_perturbed_flat = f_z_perturbed.view(batch_size, -1) - else: - f_z_perturbed_flat = f_z_perturbed - - # Directional derivative - diff = (f_z_perturbed_flat - f_z_flat) / epsilon - - # Process through pixel encoder - if self.is_pixel: - # Process pixel differences (use the decoder’s real output shape) - _, C, H, W = f_z.shape - diff_img = diff.view(batch_size, C, H, W) - diff_features = self.pixel_processor(diff_img) - spatial_features, _ = self.spatial_aggregator(diff_features) - jacobian_samples.append(spatial_features.view(batch_size, -1)) - else: - # Process state differences - diff_features = self.feature_extractor(diff) - jacobian_samples.append(diff_features) - - if decoder_training: - self.decoder.train() - # Average Jacobian features - jacobian_features = torch.cat(jacobian_samples, dim=1) - return self.jacobian_projector(jacobian_features) - + B = z.shape[0] + N = self.ntk_samples + F = self._feat_dim + + out = z.new_zeros(B, N * F) + + # Random unit directions (N, B, Dz) + dirs = torch.randn(N, B, self.latent_dim, device=self.device) + dirs = dirs / (dirs.norm(dim=-1, keepdim=True) + 1e-8) + + # Make sure z is contiguous for the linearization + z = z.contiguous() + + # Linearize decoder at z once; jvp_fn(v) returns J(z) @ v (same shape as decoder(z)) + # Do not wrap in no_grad; we want the primal trace for JVP. + _, jvp_fn = linearize(self.decoder, z) + + # Keep feature_extractor stable (no BN/Dropout drift) + was_train = self.feature_extractor.training + self.feature_extractor.eval() + + start = 0 + while start < N: + end = min(start + self.jvp_chunk_size, N) + v_chunk = dirs[start:end].contiguous() # [n, B, Dz] + n = v_chunk.shape[0] + + # Apply JVP per direction (sequential to cap memory) + jvp_list = [] + for i in range(n): + v_in = v_chunk[i] # [B, Dz] + tangent = jvp_fn(v_in) # [B, *obs] + jvp_list.append(tangent) + + # Stack to [n, B, *obs] then encode to features + Jv = torch.stack(jvp_list, dim=0) + feats_nbf = self._encode_obs_features(Jv) # [n, B, F] + block = feats_nbf.permute(1, 0, 2).reshape(B, n * F) # [B, n*F] + out[:, start * F : start * F + n * F] = block + + start = end + if torch.cuda.is_available(): + torch.cuda.empty_cache() + if was_train: + self.feature_extractor.train() + del dirs + return out # [B, N*F] + + # ---------- Forward: MINE-DV lower bound ---------- def forward( - self, - next_latent_mean: torch.Tensor, - next_latent_logvar: torch.Tensor, - num_samples: int = 5 - ) -> torch.Tensor: + self, + next_latent_mean: torch.Tensor, # [B, Dz] + next_latent_logvar: torch.Tensor, # [B, Dz] + num_samples: int = 5, + ) -> Tuple[torch.Tensor, Dict[str, float]]: """ - Estimates epistemic value I(o; θ | z) using MINE + I ≈ E_joint[T] - log E_marg[exp(T)] + Returns per-batch scalar replicated to [B]. 
""" - batch_size = next_latent_mean.shape[0] - next_latent_mean = next_latent_mean.to(self.device) - next_latent_logvar = next_latent_logvar.to(self.device) - - # Sample latent states - z_samples = [] - for _ in range(num_samples): - z = next_latent_mean + torch.randn_like(next_latent_mean).to(self.device) * torch.exp(0.5 * next_latent_logvar) - z_samples.append(z) - - z_all = torch.cat(z_samples, dim=0) # (B*num_samples, latent_dim) - - # Compute Jacobian features (function-space representation) - jacobian_features = self.compute_jacobian_features(z_all) - - # Process latent features - latent_features = self.latent_processor(z_all) - - # Combine features for MINE - combined_features = torch.cat([jacobian_features, latent_features], dim=1) - - # MINE estimation with proper permutation - t_joint = self.mine_network(combined_features) - - # Create marginal by permuting within batch - jacobian_marginal_list = [] - for i in range(num_samples): - start_idx = i * batch_size - end_idx = (i + 1) * batch_size - batch_features = jacobian_features[start_idx:end_idx] - - # Shuffle within this batch - perm = torch.randperm(batch_size, device=self.device) - jacobian_marginal_list.append(batch_features[perm]) - - jacobian_marginal = torch.cat(jacobian_marginal_list, dim=0) - - # Marginal features - combined_marginal = torch.cat([jacobian_marginal, latent_features], dim=1) - t_marginal = self.mine_network(combined_marginal) - - # MINE lower bound with EMA - t_marginal_logsumexp, self.running_mean = ema_loss( - t_marginal, self.running_mean, self.alpha - ) - - mi_lower_bound = t_joint.mean() - t_marginal_logsumexp - - # Average over samples and ensure proper shape - epistemic_value = mi_lower_bound.expand(batch_size) - - # Prepare metrics for logging + B, Dz = next_latent_mean.shape + device = self.device + + # Sample z ~ N(mean, diag(exp(logvar))) + eps = torch.randn(num_samples, B, Dz, device=device) + std = torch.exp(0.5 * next_latent_logvar.unsqueeze(0)) + z_samples = next_latent_mean.unsqueeze(0) + eps * std + z_samples = z_samples.reshape(num_samples * B, Dz).contiguous() # [S*B, Dz] + + # JVP features → projector (checkpoint if enabled & grad flows) + jac_feats = self._compute_chunked_jvp(z_samples) # [S*B, N*F] + jac_proj = self._maybe_checkpoint(self.jacobian_projector, + jac_feats, self.cp_jacproj) # [S*B, jac_dim] + + # Latent projection (with grads) + lat_proj = self._maybe_checkpoint(self.latent_projector, + z_samples, self.cp_latproj) # [S*B, lat_dim] + + # Critic on joint pairs + joint_in = torch.cat([jac_proj, lat_proj], dim=-1) # [S*B, *] + t_joint = self._maybe_checkpoint(self.critic, joint_in, self.cp_critic).squeeze(-1) + + # Critic on marginal pairs (single perm or product of marginals) + def _marg_block(jac, lat): + # jac, lat: [S*B, D] + if self.robust_marginals: + jac_perm = torch.randperm(jac.shape[0], device=jac.device) + lat_perm = torch.randperm(lat.shape[0], device=lat.device) + m = torch.cat([jac[jac_perm], lat[lat_perm]], dim=-1) + else: + perm = torch.randperm(jac.shape[0], device=jac.device) + m = torch.cat([jac, lat[perm]], dim=-1) + return self.critic(m).squeeze(-1) + + if self.cp_critic: + # use_reentrant=False preserves RNG state for recomputation + t_marg = cp.checkpoint(_marg_block, jac_proj, lat_proj, use_reentrant=False) + else: + t_marg = _marg_block(jac_proj, lat_proj) + + # DV bound (stable) + t_joint_mean = t_joint.mean() + t_marg_logmeanexp = torch.logsumexp(t_marg, dim=0) - math.log(t_marg.numel() + self.eps) + mi_lower_bound = t_joint_mean - t_marg_logmeanexp 
+ + # EMA for logging + self.running_mean = self.ema_decay * self.running_mean + (1 - self.ema_decay) * mi_lower_bound.detach() + + epistemic_value = mi_lower_bound.expand(B) + del eps, std, z_samples, jac_feats, jac_proj, lat_proj, joint_in, t_joint, t_marg + if torch.cuda.is_available(): + torch.cuda.empty_cache() metrics = { - 'epistemic/mi_estimate': mi_lower_bound.item(), - 'epistemic/joint_term': t_joint.mean().item(), - 'epistemic/marginal_term': t_marginal_logsumexp.item(), - 'epistemic/running_mean': self.running_mean.item() + "epistemic/mi_estimate": float(mi_lower_bound.detach().cpu()), + "epistemic/joint_mean": float(t_joint_mean.detach().cpu()), + "epistemic/marg_logmeanexp": float(t_marg_logmeanexp.detach().cpu()), + "epistemic/running_mean": float(self.running_mean.detach().cpu()), + "epistemic/ntk_samples": self.ntk_samples, + "epistemic/jvp_chunk_size": self.jvp_chunk_size, + "epistemic/feat_dim": self._feat_dim, + "epistemic/checkpoint": int(self.use_checkpointing), + "epistemic/robust_marginals": int(self.robust_marginals), } - return torch.clamp(epistemic_value, min=0.0), metrics diff --git a/active_inference_diffusion/core/diffusion.py b/active_inference_diffusion/core/diffusion.py index 1dfbeb5..f4719e5 100644 --- a/active_inference_diffusion/core/diffusion.py +++ b/active_inference_diffusion/core/diffusion.py @@ -52,6 +52,8 @@ def __init__(self, config, latent_dim: int = 64): # Loss weight annealing self.register_buffer('loss_weight_cache', torch.zeros(1000)) self.loss_weight_computed = False + self.inference_steps = config.inference_steps if hasattr(config, 'inference_steps') else 1000 + self.ddim_eta = getattr(config, 'ddim_eta', 0.3) def compute_log_snr(self, t: torch.Tensor) -> torch.Tensor: """Compute log signal-to-noise ratio for continuous time""" @@ -63,7 +65,9 @@ def continuous_q_sample( self, z_start: torch.Tensor, t: torch.Tensor, # Now continuous in [0, 1] - noise: Optional[torch.Tensor] = None + noise: Optional[torch.Tensor] = None, + frame_time: Optional[torch.Tensor] = None, + actions: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: """Enhanced q_sample for continuous time""" if noise is None: @@ -85,7 +89,9 @@ def continuous_q_sample( info = { 'log_snr': log_snr, 'alpha': alpha, - 'sigma': sigma + 'sigma': sigma, + 'frame_time': frame_time, + 'actions': actions } return z_noisy, noise, info @@ -147,10 +153,9 @@ def sample_latent_prior(self, batch_size: int, device: torch.device) -> torch.Te """Sample from learned latent prior p_θ(z)""" mean = self.latent_prior_mean.unsqueeze(0).expand(batch_size, -1) std = torch.exp(self.latent_prior_log_std).unsqueeze(0).expand(batch_size, -1) - - eps = torch.randn_like(mean) - return mean + std * eps - + + return mean, std + def q_sample( self, z_start: torch.Tensor, @@ -178,13 +183,50 @@ def generate_latent_trajectory( score_network: nn.Module, batch_size: int, observation: Optional[torch.Tensor] = None, - deterministic: bool = False + deterministic: bool = False, + frame_time: Optional[torch.Tensor] = None, + action: Optional[torch.Tensor] = None, + force_drop_action: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: """ Generate latent trajectory via reverse diffusion This is the core innovation - generating belief representations """ device = next(score_network.parameters()).device + if frame_time is not None: + frame_time = frame_time.to(device) + if action is not None: + action = action.to(device) + if force_drop_action is not None: + force_drop_action = 
force_drop_action.to(device) + + if self.inference_steps is None or self.inference_steps >= self.config.num_diffusion_steps: + # Use original DDPM sampling + return self._generate_trajectory_ddpm( + score_network, batch_size, observation, deterministic, + frame_time, action, force_drop_action + ) + else: + # Use DDIM sampling with fewer steps + eta = self.ddim_eta if not deterministic else 0.0 + return self._generate_trajectory_ddim( + score_network, batch_size, observation, self.inference_steps, eta, + frame_time, action, force_drop_action + ) + + def _generate_trajectory_ddpm( + self, + score_network: nn.Module, + batch_size: int, + observation: Optional[torch.Tensor] = None, + deterministic: bool = False, + frame_time: Optional[torch.Tensor] = None, + action: Optional[torch.Tensor] = None, + force_drop_action: Optional[torch.Tensor] = None + ) -> List[torch.Tensor]: + """Original DDPM sampling (your existing code, just moved here)""" + device = next(score_network.parameters()).device + # Ensure observation is on correct device if observation is not None: observation = observation.to(device) @@ -195,17 +237,92 @@ def generate_latent_trajectory( # Reverse diffusion process for t in reversed(range(self.config.num_diffusion_steps)): t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long) - + t_batch = torch.clamp(t_batch, 0, self.config.num_diffusion_steps - 1) # Predict score conditioned on observation - score = score_network(z, t_batch.float(), observation) + score = score_network(z, t_batch, observation, + frame_time=frame_time, + action=action, + force_drop_action=force_drop_action) # Update latent z = self.p_sample(z, t_batch, score, deterministic=deterministic) trajectory.append(z) return trajectory + def _generate_trajectory_ddim( + self, + score_network: nn.Module, + batch_size: int, + observation: Optional[torch.Tensor], + inference_steps: int, + eta: float, + frame_time: Optional[torch.Tensor] = None, + action: Optional[torch.Tensor] = None, + force_drop_action: Optional[torch.Tensor] = None + ) -> List[torch.Tensor]: + """DDIM sampling with configurable stochasticity""" + device = next(score_network.parameters()).device + if observation is not None: + observation = observation.to(device) + + # Create subsequence of timesteps + # This is the key to DDIM's speed - we skip steps! 
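Concretely, the subsequence construction below works out as follows; the 1000/50 step counts are illustrative, not the repo's defaults:

```python
# Skipping steps: 50 reverse updates instead of 1000
num_diffusion_steps, inference_steps = 1000, 50
step_ratio = max(1, num_diffusion_steps // inference_steps)        # 20
timesteps = list(range(0, num_diffusion_steps, step_ratio))[::-1]  # descending
print(timesteps[:3], timesteps[-3:])  # [980, 960, 940] [40, 20, 0]
print(len(timesteps))                 # 50
```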
+ step_ratio = max(1, self.config.num_diffusion_steps // inference_steps) + timesteps = list(range(0, self.config.num_diffusion_steps, step_ratio))[::-1] + z = torch.randn(batch_size, self.latent_dim, device=device) + trajectory = [z] + + for i in range(len(timesteps) - 1): + t = timesteps[i] + t_next = timesteps[i + 1] + + t_batch = torch.full((batch_size,), t, device=device, dtype=torch.long) + score = score_network(z, t_batch, observation, + frame_time=frame_time, + action=action, + force_drop_action=force_drop_action) + if torch.isnan(score).any() or torch.isinf(score).any(): + raise ValueError(f"Score network output contains NaN or Inf values at timestep {t}") + + # Use modified p_sample that handles DDIM logic + z = self.p_sample(z, t_batch, score, + deterministic=(eta == 0), + t_next=t_next, + eta=eta) + trajectory.append(z) + + # Final step to t=0 + if timesteps[-1] > 0: + t_batch = torch.full((batch_size,), timesteps[-1], device=device, dtype=torch.long) + score = score_network(z, t_batch, observation, frame_time=frame_time, action=action, + force_drop_action=force_drop_action) + z = self.p_sample(z, t_batch, score, + deterministic=(eta == 0), + t_next=0, + eta=eta) + trajectory.append(z) + + return trajectory + def p_sample( + self, + z_t: torch.Tensor, + t: torch.Tensor, + score: torch.Tensor, + deterministic: bool = False, + t_next: Optional[int] = None, + eta: float = 0.3 + ) -> torch.Tensor: + + if t_next is None: + # Original DDPM sampling + return self._p_sample_ddpm(z_t, t, score, deterministic) + else: + # DDIM sampling with controllable stochasticity + return self._p_sample_ddim(z_t, t, t_next, score, eta) + + def _p_sample_ddpm( self, z_t: torch.Tensor, t: torch.Tensor, @@ -224,7 +341,8 @@ def p_sample( sqrt_recip_alphas_t = extract(1.0 / torch.sqrt(self.alphas), t, z_t.shape) # Predict z_0 - predicted_z_start = (z_t + sqrt_one_minus_alphas_cumprod_t * score) * sqrt_recip_alphas_t + eps_hat =-sqrt_one_minus_alphas_cumprod_t * score + predicted_z_start = (z_t - sqrt_one_minus_alphas_cumprod_t * eps_hat) * sqrt_recip_alphas_t # Compute posterior mean posterior_mean = self._posterior_mean(predicted_z_start, z_t, t) @@ -253,10 +371,49 @@ def _posterior_mean( ) return posterior_mean_coef1 * z_start + posterior_mean_coef2 * z_t + + def _p_sample_ddim( + self, + z_t: torch.Tensor, + t: torch.Tensor, + t_next: int, + score: torch.Tensor, + eta: float + ) -> torch.Tensor: + """ + DDIM sampling step with controllable stochasticity + + eta controls the amount of stochasticity: + - eta = 0: Completely deterministic (pure DDIM) + - eta = 1: Equivalent to DDPM + - 0 < eta < 1: Partially stochastic (recommended: 0.3-0.5) + """ + # Get alpha values + alpha_t = extract(self.alphas_cumprod, t, z_t.shape) + alpha_next = self.alphas_cumprod[t_next] if t_next > 0 else torch.ones_like(alpha_t) + + # Compute the predicted start point + sqrt_one_minus_alpha_t = torch.sqrt(1 - alpha_t) + eps_hat = -sqrt_one_minus_alpha_t * score + pred_z0 = (z_t - sqrt_one_minus_alpha_t * eps_hat) / torch.sqrt(alpha_t) + # Compute variance for this step (this is where eta comes in!) + sigma_t = eta * torch.sqrt(torch.clamp((1 - alpha_next) / (1 - alpha_t) * (1 - alpha_t / alpha_next), min=0)) + + # Compute the "direction" pointing from z_t to z_0 + pred_dir_zt = torch.sqrt(1 - alpha_next - sigma_t ** 2) * eps_hat + # Compute the next sample + z_next = torch.sqrt(alpha_next) * pred_z0 + pred_dir_zt + # Add noise scaled by sigma_t (this is the stochastic part!) 
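The `sigma_t` expression above is the standard DDIM interpolation knob between a deterministic and a DDPM-like update. A self-contained sketch of its two extremes, with made-up alpha values:

```python
import torch

def ddim_sigma(alpha_bar_t: torch.Tensor, alpha_bar_next: torch.Tensor, eta: float) -> torch.Tensor:
    """sigma_t = eta * sqrt((1 - a_next)/(1 - a_t) * (1 - a_t/a_next))."""
    var = (1 - alpha_bar_next) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_next)
    return eta * torch.sqrt(torch.clamp(var, min=0.0))

a_t, a_next = torch.tensor(0.50), torch.tensor(0.70)
print(ddim_sigma(a_t, a_next, eta=0.0))  # tensor(0.)     -> fully deterministic step
print(ddim_sigma(a_t, a_next, eta=1.0))  # tensor(0.4140) -> DDPM-level stochasticity
```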
+ if eta > 0 and t[0] > 0: # No noise at the final step + noise = torch.randn_like(z_t) + z_next = z_next + sigma_t * noise + + return z_next def extract(a: torch.Tensor, t: torch.Tensor, x_shape: Tuple) -> torch.Tensor: """Extract coefficients at timestep t""" + t = t.to(device=a.device) batch_size = t.shape[0] out = a.gather(-1, t) return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))) \ No newline at end of file diff --git a/active_inference_diffusion/core/free_energy.py b/active_inference_diffusion/core/free_energy.py index 16bc2f6..2e09812 100644 --- a/active_inference_diffusion/core/free_energy.py +++ b/active_inference_diffusion/core/free_energy.py @@ -1,103 +1,99 @@ """ Free Energy computation module """ - import torch import torch.nn as nn import torch.nn.functional as F from typing import Tuple, Dict, Optional - class FreeEnergyComputation(nn.Module): """ - Computes variational free energy for active inference - - F = E_q[log q(z) - log p(z,o)] - = D_KL[q(z)||p(z)] - E_q[log p(o|z)] + F = D_KL[q(z)||p(z)] - E_q[log p(o|z)] = Complexity - Accuracy """ - - def __init__(self, precision_init: float = 1.0): + def __init__(self, + precision_init: float = 1.0, + observation_decoder: Optional[nn.Module] = None, + is_pixel_observation: bool = False): super().__init__() - - # Learnable precision (inverse variance of sensory noise) self.log_precision = nn.Parameter(torch.log(torch.tensor(precision_init))) - + self.observation_decoder = observation_decoder + self.is_pixel_observation = is_pixel_observation + @property def precision(self) -> torch.Tensor: return torch.exp(self.log_precision) - + + def set_decoder(self, decoder: nn.Module): + self.observation_decoder = decoder + def compute_loss( self, - states: torch.Tensor, - observations: torch.Tensor, - actions: torch.Tensor, + states: torch.Tensor, # z + observations: torch.Tensor, # o (features) OR raw pixels if state-based task score_network: nn.Module, current_time: float = 0.0, prior_mean: Optional[torch.Tensor] = None, - prior_std: float = 1.0 + prior_std: float = 1.0, + raw_observations: Optional[torch.Tensor] = None, # required if is_pixel_observation & pixel loss + frame_idx: Optional[torch.Tensor] = None, + actions: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - """ - Compute free energy loss - - Args: - states: Latent states q(z) - observations: Observations o - actions: Actions taken - score_network: Network computing ∇log p(z|π) - current_time: Current time step - prior_mean: Mean of prior p(z) - prior_std: Std of prior p(z) - - Returns: - Free energy loss and component dictionary - """ - batch_size = states.shape[0] - device = states.device - - # Prior + if prior_mean is None: prior_mean = torch.zeros_like(states) - - # Complexity: D_KL[q(z)||p(z)] - # For Gaussian: 0.5 * sum((μ_q - μ_p)²/σ_p² + σ_q²/σ_p² - 1 - log(σ_q²/σ_p²)) - # Simplified assuming unit variance for q - complexity = 0.5 * torch.sum( - (states - prior_mean) ** 2 / (prior_std ** 2), - dim=-1 - ).mean() - - # Accuracy: -E_q[log p(o|z)] - # Gaussian observation model: -0.5 * precision * ||o - z||² - observation_error = torch.sum((observations - states) ** 2, dim=-1) - accuracy = -0.5 * self.precision * observation_error.mean() - - # Score matching regularization - t = torch.full((batch_size,), current_time, device=device) - score = score_network(states, t, observations) - score_reg = 0.01 * torch.sum(score ** 2, dim=-1).mean() - - # Total free energy + + B = states.shape[0] + device = states.device + + # Complexity 
term (unit-variance prior for simplicity) + complexity = 0.5 * ((states - prior_mean) ** 2 / (prior_std ** 2)).sum(dim=-1).mean() + + # ---- Accuracy term: decode z -> observation space, then compare to targets ---- + if self.observation_decoder is None: + raise ValueError("FreeEnergyComputation: observation_decoder is not set.") + + decoded = self.observation_decoder(states) + + if self.is_pixel_observation and raw_observations is not None: + # Compare in pixel space (B,C,H,W) + target = raw_observations + if decoded.shape != target.shape: + # fallback: flatten if shapes differ + decoded = decoded.flatten(1) + target = target.flatten(1) + obs_err = F.mse_loss(decoded, target, reduction='none') + obs_err = obs_err.view(B, -1).mean(dim=1) # per-sample MSE + else: + # Compare to feature/state observations (B, D) + if decoded.dim() > 2: + decoded = decoded.flatten(1) + if observations.dim() > 2: + observations = observations.flatten(1) + obs_err = F.mse_loss(decoded, observations, reduction='none').mean(dim=1) + + accuracy = -0.5 * self.precision * obs_err.mean() + + # Score regularizer (keep as in your pipeline) + t = torch.full((B,), current_time, device=device) + score = score_network(states, t, observations, frame_time=frame_idx, action=actions) + score_reg = 0.01 * (score ** 2).sum(dim=-1).mean() + free_energy = complexity - accuracy + score_reg - + info = { - 'complexity': complexity, - 'accuracy': -accuracy, # Make positive for logging - 'observation_error': observation_error.mean(), - 'score_regularization': score_reg, - 'precision': self.precision + "complexity": complexity.detach(), + "accuracy": (-accuracy).detach(), # positive for logging + "observation_error": obs_err.mean().detach(), + "score_regularization": score_reg.detach(), + "precision": self.precision.detach(), + "reconstruction_mse": obs_err.mean().detach(), } - return free_energy, info - + def update_precision(self, complexity: torch.Tensor, accuracy: torch.Tensor): - """ - Update precision based on prediction errors - Higher complexity relative to accuracy -> decrease precision - """ + # Decrease precision when complexity > accuracy (sign fix) with torch.no_grad(): precision_error = complexity - accuracy - self.log_precision.data += 0.01 * precision_error.clamp(-1, 1) - self.log_precision.data = self.log_precision.data.clamp(-3, 3) - - + self.log_precision.data -= 0.01 * precision_error.clamp(-1, 1) + self.log_precision.data.clamp_(-3, 3) diff --git a/active_inference_diffusion/encoder/visual_encoders.py b/active_inference_diffusion/encoder/visual_encoders.py index d0c7902..8adc17f 100644 --- a/active_inference_diffusion/encoder/visual_encoders.py +++ b/active_inference_diffusion/encoder/visual_encoders.py @@ -8,20 +8,168 @@ import torch.nn.functional as F from typing import Tuple, Optional, Dict import numpy as np +import math +from typing import Tuple, Optional, Literal +import torch.utils.checkpoint as cp +# ----------------------------- ATTENTION BLOCKS ------------------------------ + +class GlobalContextAttention2D(nn.Module): + """ + Cross-attention over image features: + - Queries: full-resolution feature map + - Keys/Vals: pooled (downsampled) feature map + This provides global context with O(HW * (H'W')) complexity (H'W' << HW). 
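A back-of-envelope check of that complexity claim under assumed sizes; the 64x64 feature map is illustrative, and the 7x7 pool matches the `pool_hw` default listed below:

```python
# Attention scores per head: full self-attention vs. pooled K/V
H = W = 64                           # assumed trunk output resolution
ph = pw = 7                          # pooled key/value grid
full_attn = (H * W) ** 2             # MHSA2D: every token attends to every token
pooled_attn = (H * W) * (ph * pw)    # GlobalContextAttention2D: tokens -> pooled K/V
print(full_attn, pooled_attn)        # 16777216 200704  (~84x fewer scores)
```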
+ + Args: + dim: channel dim of input features + heads: number of attention heads + dim_head: per-head dimension (total dim = heads * dim_head) after proj + pool_hw: target pooled spatial size for K/V (int or (h, w)) + dropout: dropout on attention probs and MLP + """ + def __init__( + self, + dim: int, + heads: int = 4, + dim_head: int = 32, + pool_hw: int | Tuple[int, int] = 7, + dropout: float = 0.0, + ): + super().__init__() + self.heads = heads + self.dim_head = dim_head + inner_dim = heads * dim_head + + self.q_proj = nn.Conv2d(dim, inner_dim, 1, bias=False) + self.k_proj = nn.Conv2d(dim, inner_dim, 1, bias=False) + self.v_proj = nn.Conv2d(dim, inner_dim, 1, bias=False) + self.out_proj = nn.Conv2d(inner_dim, dim, 1, bias=False) + + if isinstance(pool_hw, int): + self.pool_hw = (pool_hw, pool_hw) + else: + self.pool_hw = tuple(pool_hw) + + self.attn_drop = nn.Dropout(dropout) + self.proj_drop = nn.Dropout(dropout) + self.scale = dim_head ** -0.5 + + # simple LayerNorm over channels via GroupNorm(1, C) + self.norm_q = nn.GroupNorm(1, dim) + self.norm_kv = nn.GroupNorm(1, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: [B, C, H, W] + returns: [B, C, H, W] + """ + b, c, h, w = x.shape + + # pre-norm on channels + q_in = self.norm_q(x) + kv_in = self.norm_kv(x) + + # pooled K/V spatial grid + ph, pw = self.pool_hw + kv = F.adaptive_avg_pool2d(kv_in, (ph, pw)) # [B, C, ph, pw] + + # projections + q = self.q_proj(q_in) # [B, heads*dim_head, H, W] + k = self.k_proj(kv) # [B, heads*dim_head, ph, pw] + v = self.v_proj(kv) # [B, heads*dim_head, ph, pw] + + # reshape to [B, heads, tokens, dim_head] + def to_heads(t, H, W): + t = t.view(b, self.heads, self.dim_head, H * W) # [B,H,dh,HW] + return t.permute(0, 1, 3, 2).contiguous() # [B,H,HW,dh] + + q = to_heads(q, h, w) # [B, heads, HW, dh] + k = to_heads(k, ph, pw) # [B, heads, ph*pw, dh] + v = to_heads(v, ph, pw) # [B, heads, ph*pw, dh] + + # attention + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale # [B,H,HW,ph*pw] + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + out = torch.matmul(attn, v) # [B,H,HW,dh] + out = out.permute(0, 1, 3, 2).contiguous().view(b, self.heads * self.dim_head, h, w) + out = self.out_proj(out) + out = self.proj_drop(out) + return x + out # residual + + +class MHSA2D(nn.Module): + """ + Full multi-head self-attention over all spatial tokens. + Heavier than GlobalContextAttention2D: O((HW)^2). 
+ + Args: + dim: channels + heads: number of heads + dim_head: per-head dim + dropout: dropout prob + """ + def __init__(self, dim: int, heads: int = 4, dim_head: int = 32, dropout: float = 0.0): + super().__init__() + self.heads = heads + self.dim_head = dim_head + inner_dim = heads * dim_head + + self.qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False) + self.out_proj = nn.Conv2d(inner_dim, dim, 1, bias=False) + self.scale = dim_head ** -0.5 + self.attn_drop = nn.Dropout(dropout) + self.proj_drop = nn.Dropout(dropout) + self.norm = nn.GroupNorm(1, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w = x.shape + x_n = self.norm(x) + qkv = self.qkv(x_n) + q, k, v = torch.chunk(qkv, 3, dim=1) # each [B, inner, H, W] + + def to_heads(t): + t = t.view(b, self.heads, self.dim_head, h * w) + return t.permute(0, 1, 3, 2).contiguous() # [B,H,HW,dh] + + q, k, v = map(to_heads, (q, k, v)) + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale # [B,H,HW,HW] + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + out = torch.matmul(attn, v) # [B,H,HW,dh] + out = out.permute(0, 1, 3, 2).contiguous().view(b, self.heads * self.dim_head, h, w) + out = self.out_proj(out) + out = self.proj_drop(out) + return x + out + + +# ------------------------------- ENCODER ------------------------------------- + class DrQV2Encoder(nn.Module): """ - Enhanced DrQ-v2 encoder with modern architectural improvements - - Key improvements: - - Spectral normalization for training stability - - Spatial attention mechanisms - - Graduated dropout rates - - Better activation functions (Mish) - - Group normalization for better batch-size independence + DrQ-v2 style visual encoder with: + • Conv trunk (GroupNorm + Mish) + • Real attention block (choose 'global' or 'mhsa') + • Projection MLP to feature_dim + • Optional gradient checkpointing + + Args: + obs_shape: (C, H, W) base observation shape (pre frame-stack) + feature_dim: output feature size + frame_stack: frames to stack on channel dim + num_layers: #conv layers in trunk + num_filters: base channels in trunk + use_spectral_norm: if True, apply SN to convs + attention: 'none' | 'global' | 'mhsa' + attn_heads, attn_dim_head: attention config + attn_pool_hw: pooled grid for global attention (int or (h, w)) + dropout: trunk dropout rate schedule scalar (scaled across layers) + checkpoint_trunk/attention/head: enable checkpointing for those parts """ - def __init__( self, obs_shape: Tuple[int, int, int], @@ -29,201 +177,154 @@ def __init__( frame_stack: int = 1, num_layers: int = 4, num_filters: int = 32, - use_spectral_norm: bool = True, - use_attention: bool = True + use_spectral_norm: bool = False, + attention: Literal["none", "global", "mhsa"] = "global", + attn_heads: int = 4, + attn_dim_head: int = 32, + attn_pool_hw: int | Tuple[int, int] = 7, + dropout: float = 0.1, + checkpoint_trunk: bool = False, + checkpoint_attention: bool = False, + checkpoint_head: bool = False, ): super().__init__() - - # Handle frame stacking in channels + c, h, w = obs_shape self.base_channels = c self.frame_stack = frame_stack - self.input_channels = c * frame_stack - self.use_attention = use_attention - - self.obs_shape = (self.input_channels, h, w) + self.in_channels = c * frame_stack + self.feature_dim = feature_dim - - # Build convolutional layers with progressive channel increase + self.num_layers = num_layers + self.use_spectral_norm = use_spectral_norm + self.attention_kind = attention + self.checkpoint_trunk = checkpoint_trunk + self.checkpoint_attention 
= checkpoint_attention + self.checkpoint_head = checkpoint_head + + # ------------- Conv trunk ------------- + chans = [self.in_channels] + [num_filters * (2 ** min(i, 3)) for i in range(num_layers)] self.convs = nn.ModuleList() self.norms = nn.ModuleList() self.dropouts = nn.ModuleList() - - # Channel progression: 32 -> 64 -> 128 -> 256 - channels = [self.input_channels] + [num_filters * (2 ** min(i, 3)) for i in range(num_layers)] - + for i in range(num_layers): - in_channels = channels[i] - out_channels = channels[i + 1] - stride = 2 if i == 0 else 1 - - # Convolutional layer with optional spectral normalization - conv = nn.Conv2d( - in_channels, - out_channels, - kernel_size=3, - stride=stride, - padding=1, - bias=False # No bias when using normalization - ) - + conv = nn.Conv2d(chans[i], chans[i + 1], kernel_size=3, stride=2 if i == 0 else 1, padding=1, bias=False) if use_spectral_norm: conv = nn.utils.spectral_norm(conv) - self.convs.append(conv) - - # Group normalization (works better than batch norm for RL) - self.norms.append(nn.GroupNorm( - num_groups=min(32, out_channels // 4), - num_channels=out_channels - )) - - # Progressive dropout rates (deeper layers = more dropout) - dropout_rate = 0.1 * (i / num_layers) # 0.0 to 0.075 - self.dropouts.append(nn.Dropout2d(dropout_rate)) - - # Add spatial attention module after conv layers - if self.use_attention: - self.attention = SpatialAttention(channels[-1]) - - # Calculate flattened dimension - dummy = torch.zeros(1, *self.obs_shape) - for i, conv in enumerate(self.convs): - dummy = conv(dummy) - dummy = self.norms[i](dummy) - dummy = F.mish(dummy) # Using Mish activation - if i < len(self.convs) - 1: # No dropout on last layer - dummy = self.dropouts[i](dummy) - - if self.use_attention: - dummy = self.attention(dummy) - - self.conv_out_dim = dummy.view(1, -1).shape[1] - - # Enhanced output projection with residual - self.ln = nn.LayerNorm(self.conv_out_dim) - - # Multi-layer output projection for better feature extraction - self.output_layers = nn.Sequential( - nn.Linear(self.conv_out_dim, feature_dim * 2), + + self.norms.append(nn.GroupNorm(num_groups=min(32, chans[i + 1] // 4), num_channels=chans[i + 1])) + # progressively increase dropout (none on last conv) + p = dropout * (i / max(1, num_layers - 1)) + self.dropouts.append(nn.Dropout2d(p if i < num_layers - 1 else 0.0)) + + # ------------- Real attention ------------- + if attention == "global": + self.attn = GlobalContextAttention2D(chans[-1], heads=attn_heads, dim_head=attn_dim_head, pool_hw=attn_pool_hw) + elif attention == "mhsa": + self.attn = MHSA2D(chans[-1], heads=attn_heads, dim_head=attn_dim_head) + else: + self.attn = nn.Identity() + + # Probe output spatial size + flatten dim + with torch.no_grad(): + dummy = torch.zeros(1, self.in_channels, h, w) + x = dummy + for i in range(num_layers): + x = self._trunk_block(x, i) + x = self.attn(x) + conv_out_dim = x.view(1, -1).shape[1] + + self.pre_flat_norm = nn.LayerNorm(conv_out_dim) + + # ------------- Projection head ------------- + self.head = nn.Sequential( + nn.Linear(conv_out_dim, feature_dim * 2, bias=True), nn.LayerNorm(feature_dim * 2), nn.Mish(), nn.Dropout(0.1), - nn.Linear(feature_dim * 2, feature_dim), + nn.Linear(feature_dim * 2, feature_dim, bias=True), nn.LayerNorm(feature_dim), - nn.Tanh() + nn.Tanh(), ) - - # Initialize weights properly - self._initialize_weights() - - def _initialize_weights(self): - """Proper weight initialization for stable training""" + + self._init_weights() + + # ---- init helpers 
---- + def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): - # He initialization for ReLU family activations - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): - # Xavier initialization for linear layers nn.init.xavier_uniform_(m.weight) if m.bias is not None: - nn.init.constant_(m.bias, 0) + nn.init.zeros_(m.bias) elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + + # ---- blocks with optional checkpoint ---- + def _maybe_cp(self, fn, x: torch.Tensor, enabled: bool): + if enabled and isinstance(x, torch.Tensor) and x.requires_grad: + return cp.checkpoint(fn, x, use_reentrant=False) + return fn(x) + + def _trunk_block(self, x: torch.Tensor, i: int) -> torch.Tensor: + conv, norm, drop = self.convs[i], self.norms[i], self.dropouts[i] + def block(t): + t = conv(t) + t = norm(t) + t = F.mish(t) + t = drop(t) + return t + return self._maybe_cp(block, x, self.checkpoint_trunk) + + def _attn_block(self, x: torch.Tensor) -> torch.Tensor: + return self._maybe_cp(self.attn, x, self.checkpoint_attention) + + def _head_block(self, x_flat: torch.Tensor) -> torch.Tensor: + return self._maybe_cp(self.head, x_flat, self.checkpoint_head) + + # ---- forward ---- def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Enhanced forward pass with better feature extraction - - The encoding process: - 1. Progressive convolutions with increasing channels - 2. Group normalization for stability - 3. Mish activation for smooth gradients - 4. Progressive dropout for regularization - 5. Spatial attention for focusing on important regions - 6. 
Multi-layer projection for rich features + x: [B,C,H,W] or [B,T,C,H,W] with T==frame_stack + returns: [B, feature_dim] """ - # Handle different input formats - if x.dim() == 5: # (B, T, C, H, W) - separate frames + # frame stack handling + if x.dim() == 5: b, t, c, h, w = x.shape assert t == self.frame_stack, f"Expected {self.frame_stack} frames, got {t}" x = x.reshape(b, t * c, h, w) - elif x.dim() == 4: # (B, C, H, W) - already stacked or single frame + elif x.dim() == 4: b, c, h, w = x.shape if c == self.base_channels and self.frame_stack > 1: - # Single frame when expecting stack - repeat x = x.repeat(1, self.frame_stack, 1, 1) - elif c != self.input_channels: - raise ValueError(f"Expected {self.input_channels} channels, got {c}") - elif x.dim() == 3: # Single image without batch + elif c != self.in_channels: + raise ValueError(f"Expected {self.in_channels} channels, got {c}") + elif x.dim() == 3: x = x.unsqueeze(0) else: raise ValueError(f"Unexpected observation shape: {x.shape}") - - # Normalize if needed + if x.dtype == torch.uint8: x = x.float() / 255.0 - - # Progressive feature extraction - for i, (conv, norm, dropout) in enumerate(zip(self.convs, self.norms, self.dropouts)): - x = conv(x) - x = norm(x) - x = F.mish(x) # Smooth activation function - - # Apply dropout (except on last conv layer) - if i < len(self.convs) - 1: - x = dropout(x) - - # Apply spatial attention if enabled - if self.use_attention: - x = self.attention(x) - - # Flatten and project - x = x.view(x.shape[0], -1) - x = self.ln(x) - - # Multi-layer output projection - features = self.output_layers(x) - - return features + # conv trunk + for i in range(self.num_layers): + x = self._trunk_block(x, i) -class SpatialAttention(nn.Module): - """ - Spatial attention module to focus on important image regions - Uses both average and max pooling for robust attention - """ - - def __init__(self, channels: int): - super().__init__() - - # Channel reduction for efficiency - reduced_channels = max(channels // 8, 16) - - self.channel_reduce = nn.Conv2d(channels, reduced_channels, 1) - self.spatial_conv = nn.Conv2d(2, 1, kernel_size=7, padding=3) - - # Learnable temperature for attention sharpness - self.temperature = nn.Parameter(torch.ones(1)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Apply spatial attention""" - # Channel-wise statistics - avg_pool = torch.mean(x, dim=1, keepdim=True) - max_pool, _ = torch.max(x, dim=1, keepdim=True) - - # Concatenate and generate attention map - pool_concat = torch.cat([avg_pool, max_pool], dim=1) - attention_map = self.spatial_conv(pool_concat) - - # Apply temperature-controlled sigmoid - attention_map = torch.sigmoid(attention_map / self.temperature) - - # Apply attention with residual connection - attended = x * attention_map - return x + attended # Residual for gradient flow + # real attention + x = self._attn_block(x) + # flatten + projection + x = x.view(x.size(0), -1) + x = self.pre_flat_norm(x) + feat = self._head_block(x) + return feat class ConvDecoder(nn.Module): """ @@ -245,6 +346,7 @@ def __init__( hidden_dim: int = 256, spatial_size: int = 21, # For 84x84 output use_spectral_norm: bool = True, + frame_stack: int = 1, device: Optional[torch.device] = torch.device("cuda" if torch.cuda.is_available() else "cpu") ): super().__init__() @@ -254,7 +356,9 @@ def __init__( self.spatial_size = spatial_size self.img_channels = img_channels self.device = device - + self.hidden_dim = hidden_dim + self.frame_stack = frame_stack + # Initial projection with careful 
initialization self.latent_proj = nn.Sequential( nn.Linear(latent_dim, hidden_dim * 4), @@ -289,7 +393,7 @@ def __init__( ) ) - # Block 3: Refine at 42x42 + # Block 3: Refine at 42x42 -> 84x84 self.decoder_blocks.append( DecoderBlock( in_channels=hidden_dim // 2, @@ -299,27 +403,16 @@ def __init__( ) ) - # Block 4: Upsample 42x42 -> 84x84 - self.decoder_blocks.append( - DecoderBlock( - in_channels=hidden_dim // 4, - out_channels=hidden_dim // 8, - upsample=True, - use_spectral_norm=use_spectral_norm - ) - ) - - # Output projection with multiple conv layers for refinement self.output_proj = nn.Sequential( - nn.Conv2d(hidden_dim // 8, 32, kernel_size=3, padding=1), + nn.Conv2d(hidden_dim // 4, 32, kernel_size=3, padding=1), nn.InstanceNorm2d(32), nn.Mish(), nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.InstanceNorm2d(32), nn.Mish(), - nn.Conv2d(32, img_channels, kernel_size=3, padding=1), + nn.Conv2d(32, img_channels * self.frame_stack, kernel_size=3, padding=1), nn.Sigmoid() # Output in [0, 1] ) self.to(self.device) @@ -355,8 +448,8 @@ def forward(self, latent: torch.Tensor) -> torch.Tensor: # Project latent to spatial representation h = self.latent_proj(latent) - h = h.view(batch_size, -1, self.spatial_size, self.spatial_size) - + h = h.view(batch_size, self.hidden_dim, self.spatial_size, self.spatial_size) + # Progressive decoding with feature refinement for block in self.decoder_blocks: h = block(h) @@ -387,16 +480,13 @@ def __init__( layers = [] if upsample: - # Sub-pixel convolution for better upsampling - # First increase channels, then pixel shuffle - layers.append( - nn.Conv2d(in_channels, out_channels * 4, kernel_size=3, padding=1) - ) + # Replace PixelShuffle with: bilinear upsample → 3x3 conv + layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)) + conv_up = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) if use_spectral_norm: - layers[-1] = nn.utils.spectral_norm(layers[-1]) - + conv_up = nn.utils.spectral_norm(conv_up) layers.extend([ - nn.PixelShuffle(2), # Upsample by 2x + conv_up, nn.InstanceNorm2d(out_channels), nn.Mish() ]) @@ -430,15 +520,13 @@ def __init__( residual_layers = [] if upsample: - # Use sub-pixel conv for residual too - residual_conv = nn.Conv2d(in_channels, out_channels * 4, kernel_size=1) + # Match main path: bilinear upsample → 1x1 conv + residual_layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)) + residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) if use_spectral_norm: residual_conv = nn.utils.spectral_norm(residual_conv) - - residual_layers.extend([ - residual_conv, - nn.PixelShuffle(2) - ]) + residual_layers.append(residual_conv) + else: residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) if use_spectral_norm: diff --git a/active_inference_diffusion/envs/pixel_wrappers.py b/active_inference_diffusion/envs/pixel_wrappers.py index ac3caf1..6bf6c31 100644 --- a/active_inference_diffusion/envs/pixel_wrappers.py +++ b/active_inference_diffusion/envs/pixel_wrappers.py @@ -4,6 +4,8 @@ from typing import Optional, Tuple, Dict, Any import warnings from .wrappers import ActionRepeat +import os + class MuJoCoPixelObservationWrapper(gym.ObservationWrapper): """ diff --git a/active_inference_diffusion/models/__init__.py b/active_inference_diffusion/models/__init__.py index 897f973..8d25030 100644 --- a/active_inference_diffusion/models/__init__.py +++ b/active_inference_diffusion/models/__init__.py @@ -5,7 +5,7 @@ from .score_networks import 
LatentScoreNetwork, SinusoidalPositionEmbeddings from .policy_networks import DiffusionConditionedPolicy, HierarchicalDiffusionPolicy from .value_networks import ValueNetwork -from .dynamics_models import LatentDynamicsModel +from .dynamics_models import LatentDynamicsModel, TransformerDynamicsModel __all__ = [ "LatentScoreNetwork", @@ -14,4 +14,5 @@ "HierarchicalDiffusionPolicy", "ValueNetwork", "LatentDynamicsModel", + "TransformerDynamicsModel", ] \ No newline at end of file diff --git a/active_inference_diffusion/models/dynamics_models.py b/active_inference_diffusion/models/dynamics_models.py index 85ed2da..1d03c54 100644 --- a/active_inference_diffusion/models/dynamics_models.py +++ b/active_inference_diffusion/models/dynamics_models.py @@ -1,68 +1,366 @@ """ Dynamics model implementations """ - import torch import torch.nn as nn - - +import torch.nn.functional as F +from typing import Optional, Tuple, Dict +from torch.nn.attention import sdpa_kernel class LatentDynamicsModel(nn.Module): """ - Latent dynamics model f(s,a) -> s' + Latent dynamics model f([state_t, action_t]) -> state_{t+1} + LSTM variant with optional per-item reset via `done_mask` (like the Transformer). """ - + def __init__( self, state_dim: int, action_dim: int, hidden_dim: int = 256, num_layers: int = 3, - residual: bool = True + residual: bool = True, + lstm_hidden_dim: int = 128, + dropout: float = 0.1, ): super().__init__() - + self.residual = residual - + self.lstm_hidden_dim = lstm_hidden_dim + + # Belief LSTM over [s_t, a_t]; we do single-step updates (seq_len=1) with batch_first=True + self.belief_lstm = nn.LSTM( + input_size=state_dim + action_dim, + hidden_size=lstm_hidden_dim, + num_layers=2, + batch_first=True, + dropout=dropout, + ) + + # (Optional) buffers if you later want learned h0/c0; currently unused (zeros) + self.register_buffer("lstm_h0", None) + self.register_buffer("lstm_c0", None) + + # MLP that reads [s_t, a_t, h_t] and outputs (mean, logvar) for s_{t+1} + input_dim = state_dim + action_dim + lstm_hidden_dim layers = [] - input_dim = state_dim + action_dim - for i in range(num_layers): - if i == 0: - layers.append(nn.Linear(input_dim, hidden_dim)) - else: - layers.append(nn.Linear(hidden_dim, hidden_dim)) - + layers.append(nn.Linear(input_dim if i == 0 else hidden_dim, hidden_dim)) layers.append(nn.LayerNorm(hidden_dim)) - layers.append(nn.ReLU()) - - layers.append(nn.Linear(hidden_dim, state_dim)) - + layers.append(nn.ELU()) + layers.append(nn.Dropout(dropout)) + layers.append(nn.Linear(hidden_dim, 2 * state_dim)) self.network = nn.Sequential(*layers) - - # Initialize output to small values for residual connection + + # Output init: tiny if residual, otherwise Xavier if residual: nn.init.uniform_(self.network[-1].weight, -1e-3, 1e-3) nn.init.zeros_(self.network[-1].bias) - + else: + nn.init.xavier_uniform_(self.network[-1].weight, gain=1.0) + nn.init.zeros_(self.network[-1].bias) + + self._init_lstm_weights() + + def _init_lstm_weights(self): + """Initialize LSTM weights for stable training.""" + for name, param in self.belief_lstm.named_parameters(): + if "weight_ih" in name: + nn.init.xavier_uniform_(param.data) + elif "weight_hh" in name: + nn.init.orthogonal_(param.data) + elif "bias" in name: + nn.init.zeros_(param.data) + # Forget gate bias = 1 (gates order: i, f, g, o) + n = param.size(0) + param.data[n // 4 : n // 2].fill_(1.0) + + def init_hidden(self, batch_size: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: + """Fresh hidden for new sequences. 
Shape: (num_layers, B, lstm_hidden_dim).""" + num_layers = self.belief_lstm.num_layers + h0 = torch.zeros(num_layers, batch_size, self.lstm_hidden_dim, device=device) + c0 = torch.zeros(num_layers, batch_size, self.lstm_hidden_dim, device=device) + return (h0, c0) + + def _maybe_reinit_hidden( + self, + hidden_state: Optional[Tuple[torch.Tensor, torch.Tensor]], + batch_size: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if hidden_state is None: + return self.init_hidden(batch_size, device) + h, c = hidden_state + # If batch size changed (e.g., at episode start), reinit + if h.size(1) != batch_size or c.size(1) != batch_size: + return self.init_hidden(batch_size, device) + return h, c + def forward( self, - state: torch.Tensor, - action: torch.Tensor - ) -> torch.Tensor: + state: torch.Tensor, # (B, state_dim) + action: torch.Tensor, # (B, action_dim) + hidden_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + done_mask: Optional[torch.Tensor] = None # (B,) bool or {0,1}; True resets BEFORE current step + ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ - Predict next state - - Args: - state: Current state [batch_size, state_dim] - action: Action [batch_size, action_dim] - Returns: - Next state [batch_size, state_dim] + next_state_mean: (B, state_dim) + next_state_logvar: (B, state_dim) + new_hidden_state: (h, c) where each is (num_layers, B, lstm_hidden_dim) """ - inputs = torch.cat([state, action], dim=-1) - output = self.network(inputs) - - if self.residual: - return state + output - else: - return output + B = state.size(0) + device = state.device + + # Prepare input and hidden + inputs = torch.cat([state, action], dim=-1) # (B, S+A) + if hidden_state is None: + hidden_state = self.init_hidden(B, device) + + if done_mask is not None: + dm = done_mask.float().view(1, B, 1) # [1, B, 1] for broadcasting + h, c = hidden_state + # h,c: [num_layers, B, H] - broadcast correctly + h = h * (1.0 - dm) # No transpose needed + c = c * (1.0 - dm) + hidden_state = (h, c) + + lstm_out, new_hidden = self.belief_lstm(inputs.unsqueeze(1), hidden_state) + + combined_feature = torch.cat([inputs, new_hidden[0][-1]], dim=-1) + output = self.network(combined_feature) + mean_state, log_var_state = torch.chunk(output, 2, dim=-1) + + next_state_mean = state + mean_state if self.residual else mean_state + next_state_logvar = torch.clamp(log_var_state, min=-10, max=2) + return next_state_mean, next_state_logvar, new_hidden + +class TransformerDynamicsModel(nn.Module): + """ + Causal Transformer dynamics: f([state, action]_t, context) -> state_{t+1} + + - init_hidden(batch, device) -> hidden + - forward(state, action, hidden, done_mask=None) -> (mean, logvar, hidden) + """ + + def __init__( + self, + state_dim: int, + action_dim: int, + hidden_dim: int = 256, # d_model + num_layers: int = 2, + n_heads: int = 4, + dropout: float = 0.1, + context_len: int = 16, + residual: bool = True, + use_checkpointing: bool = False, + attn_impl: str = "auto", # {"auto","flash","mem","math"} for SDPA backends + clear_on_reset: bool = True, + logvar_min: float = -10.0, + logvar_max: float = 2.0, + ): + super().__init__() + self.state_dim = state_dim + self.action_dim = action_dim + self.d_model = hidden_dim + self.context_len = context_len + self.residual = residual + self.use_checkpointing = use_checkpointing + self.attn_impl = attn_impl + self.clear_on_reset = clear_on_reset + self.logvar_min = logvar_min + self.logvar_max = logvar_max + + # Token for [state, 
action] + self.token_proj = nn.Sequential( + nn.Linear(state_dim + action_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + ) + + # Learned absolute positions [T, d_model] + self.pos_embed = nn.Parameter(torch.zeros(context_len, hidden_dim)) + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + # Pre-norm Transformer encoder (batch_first=True) + enc_layer = nn.TransformerEncoderLayer( + d_model=hidden_dim, + nhead=n_heads, + dim_feedforward=4 * hidden_dim, + dropout=dropout, + activation="gelu", + batch_first=True, + norm_first=True, + ) + self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers) + + # Head: (mean, logvar) of next state + self.head = nn.Sequential( + nn.LayerNorm(hidden_dim), + nn.Linear(hidden_dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, 2 * state_dim), + ) + nn.init.uniform_(self.head[-1].weight, -1e-3, 1e-3) + nn.init.zeros_(self.head[-1].bias) + + + def init_hidden(self, batch_size: int, device: torch.device) -> Dict[str, torch.Tensor]: + # Fixed-capacity ring buffer to avoid reallocation each step + tokens = torch.empty(batch_size, self.context_len, self.d_model, device=device) + lengths = torch.zeros(batch_size, dtype=torch.long, device=device) # valid length in [0..context_len] + return {"tokens": tokens, "lengths": lengths} + + def _maybe_reset_hidden(self, hidden: Optional[Dict[str, torch.Tensor]], B: int, device): + if ( + hidden is None + or ("tokens" not in hidden) + or hidden["tokens"].size(0) != B + or hidden["tokens"].size(1) != self.context_len + or hidden["tokens"].size(2) != self.d_model + ): + return self.init_hidden(B, device) + if "lengths" not in hidden or hidden["lengths"].size(0) != B: + hidden["lengths"] = torch.zeros(B, dtype=torch.long, device=device) + return hidden + + @staticmethod + def _causal_mask(T: int, device: torch.device): + # (T,T) with -inf above diagonal + return torch.ones(T, T, dtype=torch.bool, device=device).triu(1) + + @staticmethod + def _key_padding_mask(lengths: torch.Tensor, T: int): + """ + lengths: (B,) valid lengths, 0..T + return: (B, T) True where PAD, False where valid + """ + B = lengths.size(0) + ar = torch.arange(T, device=lengths.device).unsqueeze(0).expand(B, T) # (B,T) + return ar >= lengths.clamp_max(T).unsqueeze(1) + + def _sdpa_backend_ctx(self): + """ + Select SDPA backend if on CUDA + PyTorch>=2. + - flash: fastest (constraints on head dim etc.) 
+        - mem: memory-efficient
+        - math: plain matmul implementation (always available)
+        - auto: let PyTorch decide
+        """
+        try:
+            # torch.nn.attention.sdpa_kernel expects SDPBackend members; the
+            # enable_* keyword flags belong to the deprecated
+            # torch.backends.cuda.sdp_kernel API and would raise a TypeError here.
+            from torch.nn.attention import SDPBackend
+            if self.attn_impl == "flash":
+                return sdpa_kernel([SDPBackend.FLASH_ATTENTION])
+            elif self.attn_impl == "mem":
+                return sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION])
+            elif self.attn_impl == "math":
+                return sdpa_kernel([SDPBackend.MATH])
+            else:  # "auto"
+                class _NullCtx:
+                    def __enter__(self): return None
+                    def __exit__(self, exc_type, exc, tb): return False
+                return _NullCtx()
+        except Exception:
+            class _NullCtx:
+                def __enter__(self): return None
+                def __exit__(self, exc_type, exc, tb): return False
+            return _NullCtx()
+
+    def get_memory_usage(self, batch_size: int) -> Dict[str, float]:
+        """Rough memory usage report (MB) for the token buffer + parameters (fp32)."""
+        token_mb = batch_size * self.context_len * self.d_model * 4 / 1024**2
+        param_mb = sum(p.numel() * 4 for p in self.parameters()) / 1024**2
+        return {"token_buffer_mb": token_mb, "parameters_mb": param_mb, "total_mb": token_mb + param_mb}
+
+    def forward(
+        self,
+        state: torch.Tensor,   # (B, S)
+        action: torch.Tensor,  # (B, A)
+        hidden_state: Optional[Dict[str, torch.Tensor]] = None,
+        done_mask: Optional[torch.Tensor] = None  # (B,) bool or 0/1; True resets sequence BEFORE appending current token
+    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
+        B = state.size(0)
+        device = state.device
+
+        hidden_state = self._maybe_reset_hidden(hidden_state, B, device)
+        tokens = hidden_state["tokens"]    # (B, Tcap, d)
+        lengths = hidden_state["lengths"]  # (B,)
+
+        if done_mask is not None and bool(done_mask.any().item()):
+            idx = done_mask.nonzero(as_tuple=False).squeeze(-1)
+            if idx.numel() > 0:
+                # Advanced indexing returns a copy, so `tokens[idx].zero_()`
+                # would silently be a no-op; index assignment writes in place.
+                tokens[idx] = 0.0
+                lengths[idx] = 0
+
+        # Project current [s_t, a_t] into token
+        token = self.token_proj(torch.cat([state, action], dim=-1))  # (B, d_model)
+
+        # Append to ring buffer (shift-left where full)
+        full_mask = lengths >= self.context_len
+        if full_mask.any():
+            rows = full_mask.nonzero(as_tuple=False).squeeze(-1)
+            tokens[rows, :-1, :] = tokens[rows, 1:, :]
+            tokens[rows, -1, :] = token[rows]
+        if (~full_mask).any():
+            rows = (~full_mask).nonzero(as_tuple=False).squeeze(-1)
+            pos = lengths[rows]
+            tokens[rows, pos, :] = token[rows]
+            lengths[rows] = pos + 1
+
+        # Effective sequence length
+        T = int(lengths.max().item())
+        T = max(1, min(T, self.context_len))
+
+        # Slice valid window and build masks
+        x = tokens[:, :T, :]  # (B, T, d)
+        x = x + self.pos_embed[:T, :].unsqueeze(0)
+        src_kpm = self._key_padding_mask(lengths, T)  # (B, T) True=pad
+        causal = self._causal_mask(T, device)         # (T, T)
+
+        # Encoder
+        with self._sdpa_backend_ctx():
+            if self.use_checkpointing:
+                amp_enabled = torch.is_autocast_enabled()
+                amp_dtype = None
+                try:
+                    amp_dtype = torch.get_autocast_gpu_dtype()
+                except Exception:
+                    pass
+
+                def run_layer(y, layer, cm, kpm):
+                    # Match autocast state exactly during recompute
+                    device_type = "cuda" if y.is_cuda else "cpu"
+                    with torch.amp.autocast(device_type=device_type,
+                                            enabled=amp_enabled,
+                                            dtype=(amp_dtype if device_type == "cuda" else None)):
+                        # Boolean src_mask; do NOT pass is_causal to avoid path switches
+                        return layer(y, src_mask=cm, src_key_padding_mask=kpm)
+
+                for layer in self.encoder.layers:
+                    x = torch.utils.checkpoint.checkpoint(
+                        run_layer, x, layer, causal, src_kpm, use_reentrant=False, preserve_rng_state=True
+                    )
+
+                if self.encoder.norm 
is not None: + x = self.encoder.norm(x) + else: + # Also use KWARGS here for stability across versions + x = self.encoder(x, mask=causal, src_key_padding_mask=src_kpm) + # Last VALID token per item + idx = (lengths - 1).clamp(min=0, max=T - 1) # (B,) + h = x[torch.arange(B, device=device), idx, :] # (B, d) + + # Predict next state distribution + out = self.head(h) # (B, 2*S) + mean_state, log_var_state = torch.chunk(out, 2, dim=-1) + + next_state_mean = state + mean_state if self.residual else mean_state + next_state_logvar = torch.clamp(log_var_state, min=self.logvar_min, max=self.logvar_max) + + new_hidden = {"tokens": tokens, "lengths": lengths} + return next_state_mean, next_state_logvar, new_hidden diff --git a/active_inference_diffusion/models/policy_networks.py b/active_inference_diffusion/models/policy_networks.py index 5346982..0b7ab23 100644 --- a/active_inference_diffusion/models/policy_networks.py +++ b/active_inference_diffusion/models/policy_networks.py @@ -5,9 +5,69 @@ import torch import torch.nn as nn import torch.distributions as dist -from typing import Tuple, Optional, List +from typing import Optional, Tuple, List, Dict, Union, Literal +import math import numpy as np import torch.nn.functional as F +import torch.utils.checkpoint as cp +from torch.nn.attention import SDPBackend, sdpa_kernel +Tensor = torch.Tensor + +# ---- Tanh-squash helpers ----------------------------------------------------- + +def tanh_squash(action_pre_tanh: Tensor) -> Tensor: + return torch.tanh(action_pre_tanh) + +def tanh_correction_logdet(a_t: Tensor, eps: float = 1e-6) -> Tensor: + # log |det d tanh(u) / du| = sum_i log(1 - tanh(u_i)^2) + return torch.log1p(-a_t.pow(2) + eps).sum(-1) + +# ---- Initialization ---------------------------------------------------------- + +def initialize_policy_weights(module: nn.Module, + orthogonal_linear: bool = True, + final_std: float = 0.01) -> None: + """ + Initialize Linear/Conv/LSTM/Transformer weights with sensible defaults. 
+ - MLP/LSTM: orthogonal (good for actors) + - Transformer/GTrXL: Xavier for projections/FFN; keep defaults for MHA + - Final policy heads: small gain + """ + for m in module.modules(): + if isinstance(m, nn.Linear): + if orthogonal_linear: + nn.init.orthogonal_(m.weight) + else: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.Conv1d, nn.Conv2d)): + nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5)) + if m.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(m.bias, -bound, bound) + elif isinstance(m, nn.LSTM): + for name, p in m.named_parameters(): + if "weight_ih" in name: + nn.init.xavier_uniform_(p) + elif "weight_hh" in name: + nn.init.orthogonal_(p) + elif "bias" in name: + nn.init.zeros_(p) + elif isinstance(m, nn.MultiheadAttention): + # Leave default init or apply Xavier if desired + pass + + # Small final layers + for m in module.modules(): + if isinstance(m, nn.Linear) and getattr(m, "_is_policy_final", False): + if orthogonal_linear: + nn.init.orthogonal_(m.weight, gain=final_std) + else: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) class DiffusionConditionedPolicy(nn.Module): """ @@ -18,231 +78,473 @@ class DiffusionConditionedPolicy(nn.Module): generated by the diffusion process """ - def __init__( - self, - latent_dim: int, - action_dim: int, - hidden_dim: int = 256, - num_layers: int = 3, - log_std_min: float = -20, - log_std_max: float = 2, - use_state_dependent_std: bool = True, - squash_output: bool = False - ): + def __init__(self, + latent_dim: int, + action_dim: int, + hidden_dim: int = 256, + num_layers: int = 2, + squash_output: bool = True, + state_dependent_std: bool = True, + min_std: float = 1e-4, + max_std: float = 1.0, + checkpoint: bool = True): super().__init__() - - self.latent_dim = latent_dim - self.action_dim = action_dim - self.log_std_min = log_std_min - self.log_std_max = log_std_max - self.use_state_dependent_std = use_state_dependent_std self.squash_output = squash_output - - # Latent processing network with skip connections - self.latent_encoder = nn.Sequential( - nn.Linear(latent_dim, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.ReLU(), - nn.Linear(hidden_dim, hidden_dim) - ) - - # Policy trunk - trunk_layers = [] - for i in range(num_layers): - trunk_layers.extend([ - nn.Linear(hidden_dim, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.ReLU() - ]) - self.trunk = nn.Sequential(*trunk_layers) - - # Separate heads for mean and covariance - self.mean_head = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim // 2), - nn.ReLU(), - nn.Linear(hidden_dim // 2, action_dim) - ) - - if use_state_dependent_std: - self.log_std_head = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim // 2), - nn.ReLU(), - nn.Linear(hidden_dim // 2, action_dim) - ) + self.state_dependent_std = state_dependent_std + self.min_std = min_std + self.max_std = max_std + self.checkpoint = bool(checkpoint) + + layers: List[nn.Module] = [] + in_dim = latent_dim + for _ in range(num_layers): + layers += [nn.Linear(in_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(inplace=True)] + in_dim = hidden_dim + self.trunk = nn.Sequential(*layers) + + self.mean_head = nn.Linear(hidden_dim, action_dim); self.mean_head._is_policy_final = True + if state_dependent_std: + self.std_head = nn.Linear(hidden_dim, action_dim); self.std_head._is_policy_final = True else: - # Learnable but state-independent 
std - self.log_std = nn.Parameter(torch.zeros(action_dim)) - self._initialize_weights() - - def _initialize_weights(self): - # Small initialization for output layers to prevent large initial actions - torch.nn.init.orthogonal_(self.mean_head[-1].weight, gain=torch.tensor(1.0)) - nn.init.zeros_(self.mean_head[-1].bias) - - if self.use_state_dependent_std: - torch.nn.init.orthogonal_(self.log_std_head[-1].weight, gain=torch.tensor(1.0)) - nn.init.zeros_(self.log_std_head[-1].bias) - - # Xavier/He init for other layers - for m in [self.latent_encoder, self.trunk, self.mean_head[:-1]]: - for layer in m.modules(): - if isinstance(layer, nn.Linear): - nn.init.xavier_uniform_(layer.weight) - if layer.bias is not None: - nn.init.zeros_(layer.bias) - - - def forward( - self, - z: torch.Tensor, - deterministic: bool = False - ) -> Tuple[torch.Tensor, torch.Tensor, dist.Distribution]: + self.log_std = nn.Parameter(torch.full((action_dim,), math.log(0.3))) + + initialize_policy_weights(self, orthogonal_linear=True) + + @staticmethod + def _maybe_checkpoint(mod: nn.Module, x: Tensor, enabled: bool) -> Tensor: + if enabled and isinstance(x, Tensor) and x.requires_grad: + return cp.checkpoint(mod, x, use_reentrant=False) + return mod(x) + + def forward(self, + z: Tensor, + deterministic: bool = False + ) -> Tuple[Tensor, Tensor, dist.Distribution]: """ - Generate policy distribution from diffusion latent - - Args: - z: Diffusion-generated latent [batch_size, latent_dim] - deterministic: If True, return mean action - Returns: - action: Sampled or mean action - log_prob: Log probability of action - distribution: Full policy distribution + action: [B, A], log_prob: [B], distribution object """ - # Process latent with skip connection - h = self.latent_encoder(z) - h = h + self.trunk(h) # Residual connection - - # Compute mean + h = self._maybe_checkpoint(self.trunk, z, self.checkpoint) + mean = self.mean_head(h) - - # Compute std - if self.use_state_dependent_std: - log_std = self.log_std_head(h) - else: - log_std = self.log_std.expand_as(mean) - - log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) - std = torch.exp(log_std) - - # Create distribution - distribution = dist.Normal(mean, std) - - # Sample action - if deterministic: - action = mean + if self.state_dependent_std: + std = torch.clamp(self.std_head(h).exp(), self.min_std, self.max_std) else: - action = distribution.rsample() - - # Apply squashing function if needed (for bounded action spaces) + std = torch.clamp(self.log_std.exp(), self.min_std, self.max_std) + + base = dist.Normal(mean, std) + u = mean if deterministic else base.rsample() + if self.squash_output: - action = torch.tanh(action) - # Correct log_prob for tanh squashing - log_prob = distribution.log_prob(action).sum(dim=-1) - log_prob -= (2 * (np.log(2) - action - F.softplus(-2 * action))).sum(dim=-1) + a = torch.tanh(u) + log_prob = base.log_prob(u).sum(-1) - torch.log1p(-a.pow(2) + 1e-6).sum(-1) + dist_out = dist.TransformedDistribution(base, [torch.distributions.transforms.TanhTransform(cache_size=0)]) + return a, log_prob, dist_out else: - log_prob = distribution.log_prob(action).sum(dim=-1) - - return action, log_prob, distribution + log_prob = base.log_prob(u).sum(-1) + return u, log_prob, base def get_policy_entropy(self, z: torch.Tensor) -> torch.Tensor: """Compute policy entropy for exploration bonus""" _, _, distribution = self.forward(z, deterministic=True) - return distribution.entropy().sum(dim=-1) + dist = getattr(distribution, "base_dist", distribution) + 
return dist.entropy().sum(dim=-1) -class HierarchicalDiffusionPolicy(nn.Module): + +class LSTMTemporalEncoder(nn.Module): """ - Hierarchical policy structure emerging from diffusion latents - Implements temporal abstractions through latent dynamics + One-step LSTM encoder with done-based resets. + Input: x_t [B,D], state (h,c) with [L,B,H], done [B] + Output: y_t [B,H], new_state """ - - def __init__( - self, - latent_dim: int, - action_dim: int, - num_levels: int = 3, - hidden_dim: int = 256 - ): + def __init__(self, dim: int, hidden_dim: int = 256, num_layers: int = 1, checkpoint: bool = False): super().__init__() - - self.num_levels = num_levels - self.latent_dim = latent_dim - - # Create policy for each hierarchical level - self.policies = nn.ModuleList([ - DiffusionConditionedPolicy( - latent_dim=latent_dim, - action_dim=action_dim if i == 0 else latent_dim, - hidden_dim=hidden_dim, - use_state_dependent_std=True - ) - for i in range(num_levels) - ]) - - # Temporal abstraction networks - self.temporal_encoders = nn.ModuleList([ - nn.LSTM( - input_size=latent_dim, - hidden_size=latent_dim, - num_layers=1, - batch_first=True - ) - for _ in range(num_levels - 1) + self.lstm = nn.LSTM(dim, hidden_dim, num_layers=num_layers, batch_first=True) + self.checkpoint = bool(checkpoint) + initialize_policy_weights(self, orthogonal_linear=True) + + def initial_state(self, batch_size: int, device: torch.device) -> Tuple[Tensor, Tensor]: + h = torch.zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size, device=device) + c = torch.zeros_like(h) + return h, c + + def forward(self, x_t: Tensor, state: Optional[Tuple[Tensor, Tensor]], done: Optional[Tensor] = None + ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + B = x_t.size(0) + if state is None: + state = self.initial_state(B, x_t.device) + + if done is not None: + dm = done.view(1, -1, 1).to(state[0].dtype) + h0 = state[0] * (1.0 - dm) + c0 = state[1] * (1.0 - dm) + state = (h0, c0) + + x = x_t.unsqueeze(1) # [B,1,D] + if self.checkpoint and x.requires_grad: + y, new_state = cp.checkpoint(self.lstm, x, state, use_reentrant=False) # type: ignore[arg-type] + else: + y, new_state = self.lstm(x, state) + return y.squeeze(1), new_state + +# ---- Mask utilities ---------------------------------------------------------- + +def causal_square_mask(T: int, device, dtype=torch.float32) -> Tensor: + # Additive mask with -inf above diagonal, 0 elsewhere. Shape [T,T] + m = torch.full((T, T), float('-inf'), device=device, dtype=dtype) + m = torch.triu(m, diagonal=1) + m = m.masked_fill(torch.tril(torch.ones(T, T, device=device, dtype=torch.bool)) , 0.0) + return m + +def done_causal_mask(done: Tensor) -> Tensor: + """ + Build additive mask [T,T] for a *single* sequence done flags [T]. 
+ - Causal (no attending to future) + - Disallow attention across episode boundaries (segments split by done==1) + """ + T = done.numel() + device = done.device + dtype = torch.float32 + # causal + mask = causal_square_mask(T, device, dtype=dtype) # [T,T] with -inf above diag + # build segment ids; done marks terminal at that step + seg = done.cumsum(dim=0) # [T] + seg_i = seg.view(T, 1).expand(T, T) + seg_j = seg.view(1, T).expand(T, T) + cross = seg_i.ne(seg_j) # True when across segments + mask = torch.where(cross, torch.tensor(float('-inf'), device=device, dtype=dtype), mask) + return mask + +# ---- Transformer temporal encoder (windowed, masked) ------------------------- + +class TemporalTransformer(nn.Module): + """ + Windowed TransformerEncoder with optional done-aware masking. + + Modes: + - shared_mask (default): reduce per-sample masks across batch to a single [T,T] mask + using method='union' (any boundary blocks) or 'mean' (thresholded average). + - strict_per_sample: loop over batch and apply exact per-sample masks. + + Returns the encoded last token. + """ + def __init__(self, dim: int, n_layers: int = 2, n_heads: int = 4, ff_mult: int = 4, + dropout: float = 0.1, checkpoint: bool = False, + strict_per_sample: bool = False, + reduction: Literal["union", "mean"] = "union", + mean_threshold: float = 0.5): + super().__init__() + self.layers = nn.ModuleList([ + nn.TransformerEncoderLayer(d_model=dim, nhead=n_heads, + dim_feedforward=dim * ff_mult, + dropout=dropout, batch_first=True, + activation="gelu", norm_first=True) + for _ in range(n_layers) ]) - # Initialize policies - for layer in self.policies.modules(): - if isinstance(layer, nn.Linear): - nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu') - nn.init.constant_(layer.bias, 0.0) - elif isinstance(layer, nn.LayerNorm): - nn.init.constant_(layer.weight, 1.0) - nn.init.constant_(layer.bias, 0.0) - # Initialize temporal encoders - for layer in self.temporal_encoders: - for name, param in layer.named_parameters(): - if 'weight_ih' in name: - nn.init.xavier_uniform_(param) - elif 'weight_hh' in name: - nn.init.orthogonal_(param) - elif 'bias' in name: - n = param.size(0) - param.data.fill_(0) - gate_size = n // 4 - param.data[0:gate_size].fill_(1.0) - - def forward( - self, - z: torch.Tensor, - level: int = 0, - hidden_states: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]: + self.pos_emb = nn.Parameter(torch.randn(1, 512, dim) * 0.01) # max T=512 + self.checkpoint = bool(checkpoint) + self.strict_per_sample = bool(strict_per_sample) + self.reduction = reduction + self.mean_threshold = float(mean_threshold) + # Xavier for transformer-style layers + initialize_policy_weights(self, orthogonal_linear=False) + + def _reduce_masks(self, masks_bt: Tensor) -> Tensor: + # masks_bt: [B, T, T] additive (-inf or 0). Reduce to [T, T]. 
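+        # Reduction semantics (descriptive note): 'union' blocks a position for
+        # the whole batch whenever ANY sample has an episode boundary there, so
+        # no sample ever attends across its own boundary, at the cost of
+        # over-masking the others. 'mean' blocks only where at least
+        # `mean_threshold` of the batch is masked, which is looser and may let
+        # some samples attend across a boundary; set strict_per_sample=True
+        # when exactness matters more than speed.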
+ if self.reduction == "union": + m = torch.isneginf(masks_bt) # True where -inf + m = m.any(dim=0) # [T,T] union across batch + out = torch.where(m, torch.tensor(float('-inf'), device=masks_bt.device), torch.tensor(0.0, device=masks_bt.device)) + else: + # "mean": average then threshold + # convert to 1 for -inf, 0 for 0 + bmask = torch.isneginf(masks_bt).float() + avg = bmask.mean(dim=0) + out = torch.where(avg >= self.mean_threshold, + torch.tensor(float('-inf'), device=masks_bt.device), + torch.tensor(0.0, device=masks_bt.device)) + return out # [T,T] + + def forward(self, z_seq: Tensor, done_seq: Optional[Tensor] = None) -> Tensor: """ - Hierarchical policy execution - - Higher levels generate subgoals in latent space, - lower levels generate actions + z_seq: [B, T, D], done_seq: [B, T] with 1 at terminal steps + """ + B, T, D = z_seq.shape + pos = self.pos_emb[:, :T, :] + x = z_seq + pos + + if done_seq is None: + attn_mask = causal_square_mask(T, x.device, dtype=torch.float32) # [T,T] + for layer in self.layers: + if self.checkpoint and x.requires_grad: + x = cp.checkpoint(lambda _x, m=attn_mask: layer(_x, src_mask=m), x, use_reentrant=False) + else: + x = layer(x, src_mask=attn_mask) + return x[:, -1, :] + + # done-aware + if self.strict_per_sample: + outs = [] + for b in range(B): + m = done_causal_mask(done_seq[b]) # [T,T] + x_b = x[b:b+1, :, :] + for layer in self.layers: + if self.checkpoint and x_b.requires_grad: + x_b = cp.checkpoint(lambda _x, mm=m: layer(_x, src_mask=mm), x_b, use_reentrant=False) + else: + x_b = layer(x_b, src_mask=m) + outs.append(x_b[:, -1, :]) # [1,D] + return torch.cat(outs, dim=0) # [B,D] + else: + masks = torch.stack([done_causal_mask(done_seq[b]) for b in range(B)], dim=0) # [B,T,T] + attn_mask = self._reduce_masks(masks) # [T,T] + for layer in self.layers: + if self.checkpoint and x.requires_grad: + x = cp.checkpoint(lambda _x, m=attn_mask: layer(_x, src_mask=m), x, use_reentrant=False) + else: + x = layer(x, src_mask=attn_mask) + return x[:, -1, :] + +# ---- GTrXL-style gated Transformer ------------------------------------------ + +class GRUGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.Wr = nn.Linear(dim, dim) + self.Ur = nn.Linear(dim, dim, bias=False) + self.Wz = nn.Linear(dim, dim) + self.Uz = nn.Linear(dim, dim, bias=False) + self.Wg = nn.Linear(dim, dim) + self.Ug = nn.Linear(dim, dim, bias=False) + + def forward(self, x: Tensor, y: Tensor) -> Tensor: + r = torch.sigmoid(self.Wr(y) + self.Ur(x)) + z = torch.sigmoid(self.Wz(y) + self.Uz(x)) + g = torch.tanh(self.Wg(y) + self.Ug(r * x)) + return (1 - z) * x + z * g + +class GTrXLBlock(nn.Module): + def __init__(self, dim: int, n_heads: int = 4, ff_mult: int = 4, dropout: float = 0.1): + super().__init__() + self.ln1 = nn.LayerNorm(dim) + self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True) + self.gate1 = GRUGate(dim) + + self.ln2 = nn.LayerNorm(dim) + self.ff = nn.Sequential( + nn.Linear(dim, dim * ff_mult), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim * ff_mult, dim), + nn.Dropout(dropout), + ) + self.gate2 = GRUGate(dim) + + def forward(self, x: Tensor, attn_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None) -> Tensor: + # Pre-norm + masked attention + y = self.ln1(x) + with sdpa_kernel(SDPBackend.MATH): + y, _ = self.attn(y, y, y, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) + x = self.gate1(x, y) + # Pre-norm + FFN with gating + y = self.ln2(x) + y = self.ff(y) + x = 
self.gate2(x, y) + return x + +class GTrXLTransformer(nn.Module): + """ + Windowed GTrXL encoder with done-aware masking. + Returns encoded last token. + """ + def __init__(self, dim: int, n_layers: int = 2, n_heads: int = 4, ff_mult: int = 4, + dropout: float = 0.1, checkpoint: bool = False, + strict_per_sample: bool = False, + reduction: Literal["union", "mean"] = "union", + mean_threshold: float = 0.5): + super().__init__() + self.layers = nn.ModuleList([GTrXLBlock(dim, n_heads, ff_mult, dropout) for _ in range(n_layers)]) + self.pos_emb = nn.Parameter(torch.randn(1, 512, dim) * 0.01) + self.checkpoint = bool(checkpoint) + self.strict_per_sample = bool(strict_per_sample) + self.reduction = reduction + self.mean_threshold = float(mean_threshold) + initialize_policy_weights(self, orthogonal_linear=False) + + def _reduce_masks(self, masks_bt: Tensor) -> Tensor: + # Same reduction as TemporalTransformer + if self.reduction == "union": + m = torch.isneginf(masks_bt).any(dim=0) + out = torch.where(m, torch.tensor(float('-inf'), device=masks_bt.device), torch.tensor(0.0, device=masks_bt.device)) + else: + bmask = torch.isneginf(masks_bt).float() + avg = bmask.mean(dim=0) + out = torch.where(avg >= self.mean_threshold, + torch.tensor(float('-inf'), device=masks_bt.device), + torch.tensor(0.0, device=masks_bt.device)) + return out + + def forward(self, z_seq: Tensor, done_seq: Optional[Tensor] = None) -> Tensor: + B, T, D = z_seq.shape + x = z_seq + self.pos_emb[:, :T, :] + + if done_seq is None: + m = causal_square_mask(T, x.device, dtype=torch.float32) # [T,T] + for layer in self.layers: + if self.checkpoint and x.requires_grad: + x = cp.checkpoint(lambda _x, mm=m: layer(_x, attn_mask=mm, key_padding_mask=None), x, use_reentrant=False) + else: + x = layer(x, attn_mask=m, key_padding_mask=None) + return x[:, -1, :] + + if self.strict_per_sample: + outs = [] + for b in range(B): + m = done_causal_mask(done_seq[b]) # [T,T] + x_b = x[b:b+1, :, :] + for layer in self.layers: + if self.checkpoint and x_b.requires_grad: + x_b = cp.checkpoint(lambda _x, mm=m: layer(_x, attn_mask=mm, key_padding_mask=None), x_b, use_reentrant=False) + else: + x_b = layer(x_b, attn_mask=m, key_padding_mask=None) + outs.append(x_b[:, -1, :]) + return torch.cat(outs, dim=0) + else: + masks = torch.stack([done_causal_mask(done_seq[b]) for b in range(B)], dim=0) # [B,T,T] + m = self._reduce_masks(masks) # [T,T] + for layer in self.layers: + if self.checkpoint and x.requires_grad: + x = cp.checkpoint(lambda _x, mm=m: layer(_x, attn_mask=mm, key_padding_mask=None), x, use_reentrant=False) + else: + x = layer(x, attn_mask=m, key_padding_mask=None) + return x[:, -1, :] + + +class HierarchicalDiffusionPolicy(nn.Module): + """ + Stacks multiple policy heads, each optionally preceded by a temporal encoder. + At each level i: + z_i = temporal_enc[i](z_{i-1}, state_i, done) (identity if None) + a_i ~ policy_i(z_i) + z_{i} := a_i for next level (subgoal conditioning) + Final action comes from the last level. 
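+
+    A minimal usage sketch (sizes, device and hyperparameters here are
+    illustrative only):
+
+        policy = HierarchicalDiffusionPolicy(levels=2, latent_dim=64,
+                                             action_dim=6, temporal="lstm")
+        states = policy.initial_state(batch_size=8, device=torch.device("cpu"))
+        z = torch.randn(8, 64)
+        action, log_prob, states = policy(z, hidden_states=states)
+        # action: [8, 6], log_prob: [8]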
+    """
+    def __init__(self,
+                 levels: int,
+                 latent_dim: int,
+                 action_dim: int,
+                 hidden_dim: int = 256,
+                 temporal: str = "none",  # "none" | "lstm" | "transformer" | "gtrxl"
+                 temporal_dim: Optional[int] = None,
+                 temporal_layers: int = 1,
+                 temporal_heads: int = 4,
+                 temporal_ff_mult: int = 4,
+                 temporal_dropout: float = 0.1,
+                 window: int = 16,  # for transformer/gtrxl (context length)
+                 checkpoint_temporal: bool = False,
+                 checkpoint_policy: bool = False,
+                 squash_output: bool = True,
+                 state_dependent_std: bool = True,
+                 strict_done_mask: bool = False,
+                 mask_reduction: Literal["union", "mean"] = "union",
+                 mean_threshold: float = 0.5):
         super().__init__()
+        self.levels = levels
+        self.temporal = temporal
+        self.window = window
+
+        pols = []
+        temps = []
+        for lvl in range(levels):
+            # temporal encoder
+            if temporal == "lstm":
+                dim = temporal_dim or latent_dim
+                temps.append(LSTMTemporalEncoder(latent_dim, hidden_dim=dim, num_layers=temporal_layers,
+                                                 checkpoint=checkpoint_temporal))
+                head_in = dim
+            elif temporal == "transformer":
+                temps.append(TemporalTransformer(dim=latent_dim, n_layers=temporal_layers,
+                                                 n_heads=temporal_heads, ff_mult=temporal_ff_mult,
+                                                 dropout=temporal_dropout, checkpoint=checkpoint_temporal,
+                                                 strict_per_sample=strict_done_mask,
+                                                 reduction=mask_reduction,
+                                                 mean_threshold=mean_threshold))
+                head_in = latent_dim
+            elif temporal == "gtrxl":
+                temps.append(GTrXLTransformer(dim=latent_dim, n_layers=temporal_layers,
+                                              n_heads=temporal_heads, ff_mult=temporal_ff_mult,
+                                              dropout=temporal_dropout, checkpoint=checkpoint_temporal,
+                                              strict_per_sample=strict_done_mask,
+                                              reduction=mask_reduction,
+                                              mean_threshold=mean_threshold))
+                head_in = latent_dim
+            else:
+                temps.append(None)
+                head_in = latent_dim
+
+            # Intermediate levels emit latent-sized subgoals that feed the next
+            # level's temporal encoder / trunk; only the last level emits a real
+            # environment action. (Using action_dim at every level would break
+            # the subgoal chaining in forward() for levels > 1.)
+            pol = DiffusionConditionedPolicy(latent_dim=head_in,
+                                             action_dim=action_dim if lvl == levels - 1 else latent_dim,
+                                             hidden_dim=hidden_dim, num_layers=2,
+                                             squash_output=squash_output,
+                                             state_dependent_std=state_dependent_std,
+                                             checkpoint=checkpoint_policy)
+            pols.append(pol)
+
+        self.policies = nn.ModuleList(pols)
+        self.temporal_encoders = nn.ModuleList([t if t is not None else nn.Identity() for t in temps])
+
+    def initial_state(self, batch_size: int, device: torch.device):
+        states: List[Optional[Tuple[Tensor, Tensor]]] = []
+        for t in self.temporal_encoders:
+            if isinstance(t, LSTMTemporalEncoder):
+                states.append(t.initial_state(batch_size, device))
+            else:
+                states.append(None)
+        return states
+
+    def forward(self,
+                z: Tensor,
+                hidden_states: Optional[List[Optional[Tuple[Tensor, Tensor]]]] = None,
+                done: Optional[Tensor] = None,
+                z_context: Optional[Tensor] = None,
+                done_context: Optional[Tensor] = None,
+                deterministic: bool = False
+                ) -> Tuple[Tensor, Tensor, List[Optional[Tuple[Tensor, Tensor]]]]:
+        """
+        z: [B, Dz] current latent
+        hidden_states: per-level states (only for LSTM)
+        done: [B] flags to reset LSTM states
+        z_context: [B, T, Dz] optional context for attention models; if None, uses single token
+        done_context: [B, T] done flags for mask building
         """
+        B = z.size(0)
         if hidden_states is None:
-            hidden_states = [None] * (self.num_levels - 1)
-
-        current_z = z
-        new_hidden_states = []
-
-        # Process through hierarchy
-        for i in range(self.num_levels - 1, level - 1, -1):
-            if i < self.num_levels - 1:
-                # Apply temporal abstraction
-                z_seq = current_z.unsqueeze(1)  # Add time dimension
-                encoded_z, hidden = self.temporal_encoders[i](z_seq, hidden_states[i])
-                current_z = encoded_z.squeeze(1)
-                new_hidden_states.append(hidden)
-
-        # Get action/subgoal from current level
-        action, 
log_prob, _ = self.policies[i](current_z) - - if i > level: - # Use as subgoal for next level - current_z = action - - return action, log_prob, new_hidden_states[::-1] \ No newline at end of file + hidden_states = self.initial_state(B, z.device) + + current = z + new_states: List[Optional[Tuple[Tensor, Tensor]]] = [] + + for lvl in range(self.levels): + enc = self.temporal_encoders[lvl] + + if isinstance(enc, LSTMTemporalEncoder): + y, new_state = enc(current, hidden_states[lvl], done) + new_states.append(new_state) + feat = y + elif isinstance(enc, (TemporalTransformer, GTrXLTransformer)): + if z_context is None: + ctx = current.unsqueeze(1) # [B,1,D] + dctx = done.view(B, 1) if done is not None else None + else: + ctx = z_context[:, -self.window:, :] + dctx = done_context[:, -self.window:] if (done_context is not None) else None + feat = enc(ctx, dctx) + new_states.append(None) + else: + feat = current + new_states.append(None) + + action, logp, _ = self.policies[lvl](feat, deterministic=deterministic) + current = action # chain subgoals + + return action, logp, new_states \ No newline at end of file diff --git a/active_inference_diffusion/models/score_networks.py b/active_inference_diffusion/models/score_networks.py index 2712489..bbf4bb9 100644 --- a/active_inference_diffusion/models/score_networks.py +++ b/active_inference_diffusion/models/score_networks.py @@ -7,6 +7,56 @@ import torch.nn.functional as F import math from typing import Optional +import torch.utils.checkpoint as cp +from torch.nn.attention import SDPBackend, sdpa_kernel +def _act(kind: str) -> nn.Module: + kind = kind.lower() + if kind == "relu": + return nn.ReLU(inplace=True) + if kind in ("lrelu", "leakyrelu"): + return nn.LeakyReLU(0.1, inplace=True) + if kind == "gelu": + return nn.GELU() + if kind == "silu" or kind == "swish": + return nn.SiLU(inplace=True) + raise ValueError(f"Unknown activation: {kind}") + +class FrameTimeEmbedder(nn.Module): + def __init__(self, hidden_size, freq_dim=64): + super().__init__() + self.pos = SinusoidalPositionEmbeddings(freq_dim) + self.mlp = nn.Sequential( + nn.Linear(freq_dim, hidden_size), + nn.LayerNorm(hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + nn.LayerNorm(hidden_size), + ) + self.freq_dim = freq_dim + def forward(self, tau): # tau: (B,) int frame idx or float in [0,1] + emb = self.pos(tau) # reuse your sinusoidal piece + return self.mlp(emb) + +class ActionEmbedder(nn.Module): + def __init__(self, action_dim, hidden_size, dropout_prob=0.1): + super().__init__() + self.proj = nn.Sequential( + nn.Linear(action_dim, hidden_size), + nn.LayerNorm(hidden_size), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size), + nn.LayerNorm(hidden_size), + ) + self.dropout_prob = float(dropout_prob) + def forward(self, a, training: bool, force_drop: Optional[torch.Tensor] = None): + # classifier-free style dropout for actions (unconditional branch) + if self.dropout_prob > 0 and (training or force_drop is not None): + if force_drop is None: + drop = (torch.rand(a.size(0), device=a.device) < self.dropout_prob).float().unsqueeze(1) + else: + drop = force_drop.to(a.device).float().unsqueeze(1) + a = a * (1.0 - drop) + return self.proj(a) class LatentScoreNetwork(nn.Module): @@ -21,27 +71,32 @@ def __init__( self, latent_dim: int, observation_dim: int, + action_dim: int, hidden_dim: int = 256, time_embed_dim: int = 128, num_layers: int = 6, use_attention: bool = True, output_scale: float = 1e-3, + use_checkpoint: bool = True, + activation: str = "silu", ): 
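+        # Note: relative to the previous signature, the score network now also
+        # takes action_dim and conditions on a frame index (ft_embedder) and
+        # the previous action (a_embedder, with classifier-free-style dropout);
+        # both embeddings are summed into the adaLN conditioning vector in
+        # forward() alongside the time and observation embeddings.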
super().__init__() self.latent_dim = latent_dim self.observation_dim = observation_dim - + self.action_dim = action_dim # DiT configuration self.num_heads = 8 self.mlp_ratio = 4.0 self.use_attention = use_attention self.output_scale = output_scale + self.use_checkpoint = use_checkpoint + self.act =activation.lower() # Time embedding - FIXED to output hidden_dim self.time_embed = nn.Sequential( SinusoidalPositionEmbeddings(time_embed_dim), nn.Linear(time_embed_dim, hidden_dim*2), # Changed to output hidden_dim directly - nn.SiLU(), + _act(self.act), nn.Linear(hidden_dim*2, hidden_dim) ) @@ -49,19 +104,20 @@ def __init__( self.obs_encoder = nn.Sequential( nn.Linear(observation_dim, hidden_dim), nn.LayerNorm(hidden_dim), - nn.SiLU(), - nn.Dropout(0.1), + _act(self.act), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), - nn.SiLU(), + _act(self.act), nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), ) self.continuous_time_embed = nn.Sequential( - nn.Linear(1, time_embed_dim), - nn.SiLU(), + nn.Linear(1, time_embed_dim, bias=True), + nn.LayerNorm(time_embed_dim), + _act(self.act), nn.Linear(time_embed_dim, time_embed_dim), - nn.SiLU(), + nn.LayerNorm(time_embed_dim), + _act(self.act), nn.Linear(time_embed_dim, hidden_dim) ) @@ -91,18 +147,43 @@ def __init__( self.norm_final = AdaptiveLayerNorm(hidden_dim) self.output_proj = nn.Sequential( nn.Linear(hidden_dim, hidden_dim // 2), - nn.SiLU(), + nn.LayerNorm(hidden_dim // 2), + _act(self.act), nn.Linear(hidden_dim // 2, latent_dim, bias=False) ) self.output_multiplier = nn.Parameter(torch.ones(1) * output_scale) + self.ft_embedder = FrameTimeEmbedder(hidden_dim) # new + self.a_embedder = ActionEmbedder(action_dim= self.action_dim, # set actual action_dim in config + hidden_size=hidden_dim, + dropout_prob=0.1) # new + # Initialize output to zero for stability - nn.init.zeros_(self.output_proj[-1].weight) - + self.apply(self.init_weights) + + def init_weights(self, m): + if isinstance(m, nn.Linear): + # Use smaller initialization for pixel observations + nn.init.xavier_uniform_(m.weight, gain=1.0) + if m.bias is not None: + nn.init.zeros_(m.bias) + + + def _maybe_cp(self, fn, *tensors, enable: bool): + """ + Checkpoint only during training when enabled. + This is more predictable than checking requires_grad. 
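+        When disabled, or at eval time, the wrapped fn runs directly, so
+        inference pays no recomputation overhead.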
+ """ + if enable and self.training: + return cp.checkpoint(fn, *tensors, use_reentrant=False) + return fn(*tensors) def forward( self, z_t: torch.Tensor, time: torch.Tensor, - observation: Optional[torch.Tensor] = None + observation: Optional[torch.Tensor] = None, + frame_time: Optional[torch.Tensor] = None, + action: Optional[torch.Tensor] = None, + force_drop_action: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Compute score ∇_z log p_t(z|o) @@ -116,7 +197,6 @@ def forward( Score [batch_size, latent_dim] """ # Embed time - t_emb = self.time_embed(time) batch_size = z_t.shape[0] is_continuous = time.max() <= 1.0 and time.min() >= 0.0 @@ -128,29 +208,36 @@ def forward( # Also use continuous embedding normalized_time = 2.0 * time.view(-1, 1) - 1.0 # [-1, 1] + normalized_time = torch.clamp(normalized_time, -1.0, 1.0) # Ensure within bounds t_cont = self.continuous_time_embed(normalized_time) # Combine both with learned weighting t_emb = t_sin + self.time_scale * t_cont # Add time-dependent output scaling for annealing - time_weight = torch.sqrt(1.0 / (1e-5 + time.view(-1, 1))) + time_eps = torch.finfo(z_t.dtype).eps if z_t.is_floating_point() else 1e-8 + time_weight = torch.clamp((1.0 / (time_eps + time.view(-1, 1))).sqrt(), max=10.0) else: # Discrete time path (for backward compatibility) - t_emb = self.time_embed(time) + t_emb = self.time_embed(time.float()) time_weight = 1.0 # Encode observation if observation is not None: - obs_emb = self.obs_encoder(observation) + def _obs_trunk(obs): + return self.obs_encoder(obs) + obs_emb = self._maybe_cp(_obs_trunk, observation, enable=self.use_checkpoint) + obs_emb = F.normalize(obs_emb, dim=-1) else: # Use learned null embedding - obs_emb = torch.zeros(batch_size, self.obs_encoder[-1].out_features, - device=z_t.device) - + obs_emb = torch.zeros(batch_size, self.latent_proj.out_features, device=z_t.device) + + B, H = z_t.shape[0], self.latent_proj.out_features # H == hidden_dim + ft_emb = self.ft_embedder(frame_time) if frame_time is not None else torch.zeros(B, H, device=z_t.device) + a_emb = self.a_embedder(action, self.training, force_drop_action) if action is not None else torch.zeros(B, H, device=z_t.device) # Combine conditioning (time + observation) # This will be used for adaptive normalization in DiT blocks - conditioning = t_emb + obs_emb # [B, hidden_dim] + conditioning = t_emb + obs_emb + ft_emb + a_emb # Project latent to hidden dimension h = self.latent_proj(z_t) # [B, hidden_dim] @@ -158,13 +245,14 @@ def forward( if self.use_attention: # Process through DiT blocks for block in self.transformer_blocks: - h = block(h, conditioning) + h = self._maybe_cp(lambda x, c: block(x, c), h, conditioning, enable=self.use_checkpoint) + + # Final norm and output h = self.norm_final(h, conditioning) - score = self.output_proj(h) - score = torch.clamp(score, min=-10, max=10) - score = score * self.output_multiplier # Scale output + score = self._maybe_cp(lambda x: self.output_proj(x), h, enable=self.use_checkpoint) + score = torch.clamp(score, min=-2, max=2) * self.output_multiplier # Prevent extreme scores if is_continuous: # Apply time-dependent scaling for continuous time score = score * time_weight @@ -197,6 +285,7 @@ def __init__(self, hidden_dim: int, num_heads: int, mlp_ratio: float = 4.0): mlp_hidden_dim = int(hidden_dim * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(hidden_dim, mlp_hidden_dim), + nn.LayerNorm(mlp_hidden_dim), nn.GELU(), nn.Linear(mlp_hidden_dim, hidden_dim) ) @@ -207,9 +296,9 @@ def __init__(self, hidden_dim: 
int, num_heads: int, mlp_ratio: float = 4.0): def _init_weights(self): # Initialize MLP nn.init.xavier_uniform_(self.mlp[0].weight) - nn.init.xavier_uniform_(self.mlp[2].weight) + nn.init.xavier_uniform_(self.mlp[3].weight) nn.init.zeros_(self.mlp[0].bias) - nn.init.zeros_(self.mlp[2].bias) + nn.init.zeros_(self.mlp[3].bias) def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: """ @@ -223,7 +312,8 @@ def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: norm_x = self.norm1(x, conditioning) # For single token (no sequence), we need to add sequence dimension norm_x = norm_x.unsqueeze(1) # [B, 1, hidden_dim] - attn_out, _ = self.attention(norm_x, norm_x, norm_x) + with sdpa_kernel(SDPBackend.MATH): + attn_out, _ = self.attention(norm_x, norm_x, norm_x) attn_out = attn_out.squeeze(1) # [B, hidden_dim] x = x + attn_out @@ -247,7 +337,7 @@ def __init__(self, hidden_dim: int): # Projection for adaptive parameters self.adaLN_modulation = nn.Sequential( nn.SiLU(), - nn.Linear(hidden_dim, 2 * hidden_dim) + nn.Linear(hidden_dim, 2 * hidden_dim), ) # Initialize modulation to identity diff --git a/active_inference_diffusion/models/value_networks.py b/active_inference_diffusion/models/value_networks.py index d3a6391..368aa03 100644 --- a/active_inference_diffusion/models/value_networks.py +++ b/active_inference_diffusion/models/value_networks.py @@ -1,62 +1,39 @@ """ Value network implementations """ - import torch import torch.nn as nn -from active_inference_diffusion.models.score_networks import SinusoidalPositionEmbeddings +import torch.nn.functional as F +from active_inference_diffusion.utils.util import DiscDist class ValueNetwork(nn.Module): - """ - State value function V(s,t) - """ - - def __init__( - self, - state_dim: int, - hidden_dim: int = 256, - time_embed_dim: int = 128, - num_layers: int = 3 - ): + def __init__(self, + state_dim: int, + hidden_dim: int = 256, + num_layers: int = 3, + num_bins: int = 255): super().__init__() - - # Time embedding - self.time_embed = nn.Sequential( - SinusoidalPositionEmbeddings(time_embed_dim), - nn.Linear(time_embed_dim, time_embed_dim), - nn.ReLU() - ) - - # Value network + self.num_bins = int(num_bins) layers = [] - input_dim = state_dim + time_embed_dim - + in_dim = state_dim for i in range(num_layers): - if i == 0: - layers.append(nn.Linear(input_dim, hidden_dim)) - else: - layers.append(nn.Linear(hidden_dim, hidden_dim)) - - layers.append(nn.LayerNorm(hidden_dim)) - layers.append(nn.ReLU()) - - layers.append(nn.Linear(hidden_dim, 1)) - - self.network = nn.Sequential(*layers) - - def forward(self, state: torch.Tensor, time: torch.Tensor) -> torch.Tensor: - """ - Compute state value - - Args: - state: State tensor [batch_size, state_dim] - time: Time tensor [batch_size] - - Returns: - Value [batch_size, 1] - """ - t_emb = self.time_embed(time) - inputs = torch.cat([state, t_emb], dim=-1) - return self.network(inputs) + layers += [nn.Linear(in_dim if i == 0 else hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), nn.ReLU()] + self.backbone = nn.Sequential(*layers) + self.head = nn.Linear(hidden_dim, self.num_bins) + nn.init.zeros_(self.head.weight); nn.init.zeros_(self.head.bias) + # store range once + self.low, self.high = -10.0, 10.0 + + def forward(self, state): + return self.head(self.backbone(state)) # [B, K] + + @torch.no_grad() + def expected_value(self, logits): + dist = DiscDist(logits, low=self.low, high=self.high, device=logits.device) + return dist.mean().squeeze(-1) # <- squeeze to [B] + 
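+    # Distributional value regression: DiscDist (see utils.util) spreads each
+    # return over `num_bins` buckets in [low, high], and the loss maximizes the
+    # log-likelihood of the observed returns (a two-hot, DreamerV3-style critic
+    # is the usual implementation). This is typically more robust to reward
+    # scale than a scalar MSE head.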
diff --git a/active_inference_diffusion/utils/__init__.py b/active_inference_diffusion/utils/__init__.py
index 1903e2d..60f5b8b 100644
--- a/active_inference_diffusion/utils/__init__.py
+++ b/active_inference_diffusion/utils/__init__.py
@@ -1,4 +1,4 @@
-from .buffers import ReplayBuffer
+from .buffers import ReplayBuffer, PrioritizedSequenceReplayBuffer
 from .logger import Logger
 from .training import (
     evaluate_agent,
@@ -10,6 +10,7 @@ from .util import visualize_reconstruction, SpatialAttentionAggregator
 __all__ = [
     "ReplayBuffer",
+    "PrioritizedSequenceReplayBuffer",
     "Logger",
     "evaluate_agent",
     "save_checkpoint",
diff --git a/active_inference_diffusion/utils/async_collector.py b/active_inference_diffusion/utils/async_collector.py
index 8eb3df6..ce4f096 100644
--- a/active_inference_diffusion/utils/async_collector.py
+++ b/active_inference_diffusion/utils/async_collector.py
@@ -2,7 +2,7 @@
 GPU-Optimized Parallel Data Collection Architecture
 Separates environment stepping (CPU) from diffusion inference (GPU)
 """
-
+import os
 import torch
 import torch.multiprocessing as mp
 from threading import Thread, Event
@@ -22,6 +22,7 @@
 """
 
 
+
 class GPUCentralizedCollector:
     """
     Hybrid CPU-GPU architecture using robust vectorized environments
@@ -36,7 +37,7 @@ def __init__(
         agent: Optional[Any] = None,
         num_envs: int = 8,
         max_queue_size: int = 32,
-        use_mixed_precision: bool = True,
+        use_mixed_precision: bool = False,
         use_shared_memory: bool = True  # Enable for pixel observations
     ):
         self.num_envs = num_envs
@@ -45,11 +46,15 @@ def __init__(
         self.use_shared_memory = use_shared_memory
         self._closing = False  # Add flag to track if we're closing
         self._pending_futures = []  # Track pending inference futures
-
+        self._episode_steps = [0 for _ in range(num_envs)]  # Initialize episode steps
+        self._prev_actions = None
         # Ensure agent components are on GPU and in eval mode
         if hasattr(agent, 'active_inference'):
             agent.active_inference = agent.active_inference.to(self.device)
             agent.active_inference.eval()
+
+        if hasattr(agent, 'epistemic_optimizer') and not hasattr(agent.active_inference, 'epistemic_optimizer'):
+            agent.active_inference.epistemic_optimizer = agent.epistemic_optimizer
         if hasattr(agent, 'encoder'):
             agent.encoder = agent.encoder.to(self.device)
@@ -137,12 +142,15 @@ def collect_parallel_batch(
         try:
             while steps_collected < num_steps and not self._closing:
                 loop_start = time.time()
-
+                self._cleanup_futures()  # Clean up completed futures
                 # === GPU PHASE: Batched Inference ===
                 obs_batch = self._prepare_observation_batch(self.current_observations)
-
+                frame_indices = torch.tensor(self._episode_steps, dtype=torch.long, device=self.device)
                 # Submit to GPU for diffusion + policy inference
-                inference_future = self.gpu_inference.submit_batch(obs_batch)
+                prev_actions_tensor = None
+                if self._prev_actions is not None:
+                    prev_actions_tensor = torch.tensor(self._prev_actions, dtype=torch.float32, device=self.device)
+                inference_future = self.gpu_inference.submit_batch(obs_batch, frame_indices=frame_indices, prev_actions=prev_actions_tensor)
                 self._pending_futures.append(inference_future)  # Track pending future
                 # Wait for GPU to complete with timeout handling
@@ -188,7 +196,10 @@ def collect_parallel_batch(
                 # === CPU PHASE: Parallel Environment Stepping ===
                 actions_np = actions_batch.cpu().numpy()
-
+                if self._prev_actions is None:
+                    prev_action_ = np.zeros_like(actions_np, dtype=np.float32)
+                else:
+                    prev_action_ = self._prev_actions.copy()
                 # Validate action shape
                 if actions_np.shape[1:] != self.action_shape:
                     print(f"Warning: action shape {actions_np.shape[1:]} doesn't match expected {self.action_shape}")
@@ -215,13 +226,15 @@ def collect_parallel_batch(
                         actions_np[i],
                         rewards[i],
                         next_observations[i],
-                        terminateds[i] or truncateds[i]
+                        terminateds[i] or truncateds[i],
+                        frame_idx=self._episode_steps[i],
+                        prev_action=prev_action_[i]
                     )
                     episode_rewards[i] += rewards[i]
                     episode_lengths[i] += 1
                     steps_collected += 1
-
+                    self._episode_steps[i] += 1
                     if terminateds[i] or truncateds[i]:
                         completed_episodes.append({
                             'reward': episode_rewards[i],
@@ -230,25 +243,44 @@ def collect_parallel_batch(
                         })
                         episode_rewards[i] = 0.0
                         episode_lengths[i] = 0
-
+                        self._episode_steps[i] = 0  # Reset step count for this env
+                        if self._prev_actions is not None:
+                            self._prev_actions[i] = np.zeros(self.action_shape[0])
                 # Update observations for next iteration
                 self.current_observations = next_observations
-
+                self._prev_actions = actions_np.copy()
                 # Log performance every 100 steps
                 if steps_collected % 100 == 0:
                     self._log_performance()
+                if steps_collected % 1000 == 0:
+                    print(f"Collected {steps_collected} steps (last loop {time.time() - loop_start:.2f}s)")
+                    torch.cuda.empty_cache()  # Clear GPU memory to avoid fragmentation
         except Exception as e:
             print(f"Error in collection loop: {e}")
             traceback.print_exc()
             raise
         finally:
+            self._cleanup_all_futures()  # Ensure all futures are cleaned up
             self.gpu_inference.stop()
         # Compute final statistics
-        stats = self._compute_statistics(steps_collected, completed_episodes)
-        return stats
+        return self._compute_statistics(steps_collected, completed_episodes)
+
+    def _cleanup_futures(self):
+        """Remove completed futures from tracking list"""
+        self._pending_futures = [f for f in self._pending_futures
+                                 if not f.ready.is_set()]
+
+    def _cleanup_all_futures(self):
+        """Cancel and remove all pending futures"""
+        for future in self._pending_futures:
+            try:
+                future.cancel()
+            except Exception:
+                pass
+        self._pending_futures.clear()
+
     def _prepare_observation_batch(self, observations: np.ndarray) -> torch.Tensor:
         """
         Convert numpy observations to GPU tensor with proper handling
@@ -364,7 +396,10 @@ def __init__(
         self.inference_queue = Queue(maxsize=max_queue_size)
         self.result_queue = Queue(maxsize=max_queue_size)
         self.shutdown_event = Event()
-
+        self.inference_steps = agent.config.diffusion.inference_steps
+        self.ddim_eta = agent.config.diffusion.ddim_eta
+        self.use_ddim = (self.inference_steps is not None and
+                         self.inference_steps < agent.config.diffusion.num_diffusion_steps)
         # CUDA optimization
         self.inference_stream = torch.cuda.Stream()
         self.scaler = torch.cuda.amp.GradScaler() if use_mixed_precision else None
@@ -406,16 +441,21 @@ def stop(self):
             if self.gpu_thread.is_alive():
                 print("Warning: GPU inference thread did not shut down cleanly")
-    def submit_batch(self, observations_batch: torch.Tensor) -> 'InferenceFuture':
+    def submit_batch(self, observations_batch: torch.Tensor,
+                     frame_indices: Optional[torch.Tensor] = None,
+                     prev_actions: Optional[torch.Tensor] = None) -> 'InferenceFuture':
         """Submit observation batch for GPU inference"""
         future = InferenceFuture()
-
+        if frame_indices is not None:
+            frame_indices = frame_indices.to(observations_batch.device)
+        if prev_actions is not None:
+            prev_actions = prev_actions.to(observations_batch.device)
         try:
-            self.inference_queue.put((observations_batch, future), timeout=1.0)
+            self.inference_queue.put((observations_batch, frame_indices, prev_actions, future), timeout=1.0)
         except:
             # Queue full - return dummy actions
             print("Inference queue full, returning dummy actions")
-            dummy_actions = torch.zeros(observations_batch.shape[0], self.action_dim)
+            dummy_actions = torch.zeros(observations_batch.shape[0], self.action_dim, device=observations_batch.device)
             future.set_result(dummy_actions)
         return future
@@ -436,9 +476,9 @@ def _gpu_inference_worker(self):
                 if batch_item is None:  # Shutdown sentinel
                     break
-
-                observations_batch, future = batch_item
-
+
+                observations_batch, frame_indices, prev_actions, future = batch_item
+
                 # Skip if future already cancelled
                 if hasattr(future, 'cancelled') and future.cancelled:
                     continue
@@ -449,9 +489,9 @@ def _gpu_inference_worker(self):
                 # Execute batched diffusion inference on GPU
                 if torch.cuda.is_available():
                     with torch.cuda.stream(self.inference_stream):
-                        actions_batch = self._batched_diffusion_inference(observations_batch)
+                        actions_batch = self._batched_diffusion_inference(observations_batch, frame_indices, prev_actions)
                 else:
-                    actions_batch = self._batched_diffusion_inference(observations_batch)
+                    actions_batch = self._batched_diffusion_inference(observations_batch, frame_indices, prev_actions)
                 # Performance tracking
                 inference_time = time.time() - start_time
@@ -482,30 +522,41 @@ def _gpu_inference_worker(self):
             traceback.print_exc()
         print("GPU inference worker stopped")
-
-    def _batched_diffusion_inference(self, observations_batch: torch.Tensor) -> torch.Tensor:
+
+    def _batched_diffusion_inference(self,
+                                     observations_batch: torch.Tensor,
+                                     frame_indices: Optional[torch.Tensor] = None,
+                                     prev_actions: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
         Vectorized diffusion inference across entire observation batch
         Core optimization: Single GPU call for all environments
         """
         batch_size = observations_batch.shape[0]
-
+        device = observations_batch.device
+        if frame_indices is not None:
+            frame_indices = frame_indices.to(device)
+        if prev_actions is not None:
+            prev_actions = prev_actions.to(device)
         try:
             with torch.no_grad():
+
+
                 if self.use_mixed_precision and torch.cuda.is_available():
                     with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
-                        actions_batch = self._inference_impl(observations_batch, batch_size)
+                        actions_batch = self._inference_impl(observations_batch, batch_size, frame_indices, prev_actions)
                 else:
-                    actions_batch = self._inference_impl(observations_batch, batch_size)
-
+                    actions_batch = self._inference_impl(observations_batch, batch_size, frame_indices, prev_actions)
+
             return actions_batch.float()
         except Exception as e:
             print(f"Error in batched diffusion inference: {e}")
             # Return random actions as fallback
-            return torch.randn(batch_size, self.action_dim, device=observations_batch.device) * 0.1
+            return torch.randn(batch_size, self.action_dim, device=device) * 0.1
-    def _inference_impl(self, observations_batch: torch.Tensor, batch_size: int) -> torch.Tensor:
+    def _inference_impl(self, observations_batch: torch.Tensor, batch_size: int,
+                        frame_indices: Optional[torch.Tensor] = None,
+                        prev_actions: Optional[torch.Tensor] = None) -> torch.Tensor:
         """Implementation of inference logic"""
         # Encode observations (if pixel-based)
         if hasattr(self.agent, 'encoder'):
@@ -516,7 +567,11 @@ def _inference_impl(self, observations_batch: torch.Tensor, batch_size: int) ->
         # Batched belief generation via reverse diffusion
         latents_batch = self._batch_diffusion_sampling(
             encoded_obs,
-            num_steps=self.max_diffusion_steps
+            num_steps=self.inference_steps if self.use_ddim else self.max_diffusion_steps,
+            use_ddim=self.use_ddim,
+            eta=self.ddim_eta,
+            frame_indices=frame_indices,
+            actions=prev_actions
         )
         # Batched policy evaluation
@@ -530,7 +585,11 @@ def _inference_impl(self, observations_batch: torch.Tensor, batch_size: int) ->
     def _batch_diffusion_sampling(
         self,
         observations_batch: torch.Tensor,
-        num_steps: Optional[int] = None
+        num_steps: Optional[int] = None,
+        use_ddim: bool = False,
+        eta: float = 0.0,
+        frame_indices: Optional[torch.Tensor] = None,
+        actions: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
         """
         Optimized batched reverse diffusion with reduced steps
@@ -538,6 +597,17 @@ def _batch_diffusion_sampling(
         """
         batch_size = observations_batch.shape[0]
         latent_dim = self.agent.config.latent_dim
+
+        if frame_indices is not None:
+            if isinstance(frame_indices, np.ndarray):
+                frame_indices = torch.tensor(frame_indices, dtype=torch.long, device=observations_batch.device)
+            else:
+                frame_indices = frame_indices.to(observations_batch.device)
+        if actions is not None:
+            if isinstance(actions, np.ndarray):
+                actions = torch.tensor(actions, dtype=torch.float32, device=observations_batch.device)
+            else:
+                actions = actions.to(observations_batch.device)
         if num_steps is None:
             num_steps = self.max_diffusion_steps
         else:
@@ -575,23 +645,21 @@ def _batch_diffusion_sampling(
             if hasattr(self.agent.active_inference.latent_diffusion, 'continuous_time') and \
                self.agent.active_inference.latent_diffusion.continuous_time:
                 t_continuous = t_batch.float() / max_index
-                score_batch = self.agent.active_inference.latent_score_network(
-                    z_batch, t_continuous, observations_batch
-                )
+                with torch.no_grad():
+                    score_batch = self.agent.active_inference.latent_score_network(z_batch.detach(), t_continuous, observations_batch, frame_time=frame_indices, action=actions)
             else:
                 # Use clamped long indices for discrete time
-                score_batch = self.agent.active_inference.latent_score_network(
-                    z_batch, t_batch_long.float(), observations_batch
-                )
-
+                with torch.no_grad():
+                    score_batch = self.agent.active_inference.latent_score_network(z_batch.detach(), t_batch_long.float(), observations_batch, frame_time=frame_indices, action=actions)
+
             # Use clamped long indices for diffusion update
-            z_batch = self.agent.active_inference.latent_diffusion.p_sample(
-                z_batch, t_batch_long, score_batch, deterministic=False
-            )
+            z_batch = self.agent.active_inference.latent_diffusion.p_sample(z_batch.detach(), t_batch_long, score_batch, deterministic=False)
+            z_batch = z_batch.detach()  # Detach to avoid gradients accumulating
             if torch.isnan(z_batch).any() or torch.isinf(z_batch).any():
                 print(f"NaN/Inf detected at diffusion step {step}, reinitializing")
-                z_batch = torch.randn_like(z_batch) * 0.1
-
+                z_batch = torch.randn_like(z_batch, device=z_batch.device) * 0.1
+            del score_batch  # Free memory
+            torch.cuda.empty_cache()  # Clear cache to avoid fragmentation
         return z_batch
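The collector hands work to the GPU worker through a queue of one-shot futures. A self-contained sketch of that handshake, matching the tuple layout and the `ready`/`cancel` calls used above (the real `InferenceFuture` lives in `async_collector.py`; this stand-in and the worker body are illustrative):

```python
import queue
import threading
import torch

class InferenceFuture:
    """One-shot result container; `ready` mirrors the event polled by _cleanup_futures."""
    def __init__(self):
        self.ready = threading.Event()
        self.cancelled = False
        self.result = None

    def set_result(self, value):
        self.result = value
        self.ready.set()

    def cancel(self):
        self.cancelled = True
        self.ready.set()

    def wait(self, timeout=None):
        self.ready.wait(timeout)
        return self.result

def worker(inference_queue: "queue.Queue", model):
    """Drain (obs, frame_indices, prev_actions, future) tuples and publish results."""
    while True:
        item = inference_queue.get()
        if item is None:  # shutdown sentinel, as in _gpu_inference_worker
            break
        obs, frame_indices, prev_actions, future = item
        if future.cancelled:
            continue
        with torch.no_grad():
            future.set_result(model(obs))
```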
diff --git a/active_inference_diffusion/utils/buffers.py b/active_inference_diffusion/utils/buffers.py
index 92d836f..acd2b24 100644
--- a/active_inference_diffusion/utils/buffers.py
+++ b/active_inference_diffusion/utils/buffers.py
@@ -5,9 +5,61 @@
 import torch
 import numpy as np
-from typing import Dict, Tuple, Optional
+from typing import Dict, Tuple, Optional, List
 import lz4.frame
 import pickle
+import random
+
+# ----------------------------- SumTree -----------------------------
+
+class _SumTree:
+    """Simple SumTree for Prioritized Replay (CPU numpy)."""
+    def __init__(self, capacity: int):
+        assert capacity > 0
+        self.capacity = int(capacity)
+        self.tree = np.zeros(2 * self.capacity - 1, dtype=np.float64)
+        self.data_index = np.zeros(self.capacity, dtype=np.int32)  # leaf_slot -> buffer_pos
+        self.write = 0
+        self.n_entries = 0
+
+    @property
+    def total(self) -> float:
+        return float(self.tree[0])
+
+    def add(self, priority: float, data_idx: int) -> None:
+        """Insert new leaf at current write pointer (leaf_slot) and map to buffer_pos=data_idx."""
+        t_idx = self.write + self.capacity - 1
+        self.data_index[self.write] = int(data_idx)
+        self.update(t_idx, priority)
+        self.write = (self.write + 1) % self.capacity
+        self.n_entries = min(self.n_entries + 1, self.capacity)
+
+    def update(self, t_idx: int, priority: float) -> None:
+        change = float(priority) - float(self.tree[t_idx])
+        self.tree[t_idx] = float(priority)
+        self._propagate(t_idx, change)
+
+    def _propagate(self, t_idx: int, change: float) -> None:
+        parent = (t_idx - 1) // 2
+        self.tree[parent] += change
+        if parent != 0:
+            self._propagate(parent, change)
+
+    def _retrieve(self, t_idx: int, s: float) -> int:
+        left = 2 * t_idx + 1
+        right = left + 1
+        if left >= len(self.tree):
+            return t_idx
+        if s <= self.tree[left]:
+            return self._retrieve(left, s)
+        return self._retrieve(right, s - self.tree[left])
+
+    def get(self, s: float) -> Tuple[int, float, int]:
+        """Return (tree_index, priority_value, buffer_pos)."""
+        t_idx = self._retrieve(0, s)
+        leaf_slot = t_idx - self.capacity + 1
+        return t_idx, float(self.tree[t_idx]), int(self.data_index[leaf_slot])
+
 class ReplayBuffer:
     """
@@ -28,7 +80,9 @@ def __init__(
         self.obs_shape = obs_shape
         self.pos = 0
         self.size = 0
-
+        self.cleanup_count = 0
+        self.cleanup_interval = 2000  # Cleanup every 2000 additions
+        self.frame_idx = torch.zeros(capacity, dtype=torch.long, device=device)
         # Allocate memory
         if optimize_memory and len(obs_shape) == 3:  # Pixel observations
             self.observations = [None] * capacity
@@ -42,6 +96,7 @@ def __init__(
             self.dtype = np.float32
         self.actions = torch.zeros((capacity, action_dim), dtype=torch.float32)
+        self.prev_actions = torch.zeros((capacity, action_dim), dtype=torch.float32)
         self.rewards = torch.zeros(capacity, dtype=torch.float32)
         self.dones = torch.zeros(capacity, dtype=torch.bool)
@@ -51,7 +106,9 @@ def add(
         self,
         action: np.ndarray,
         reward: float,
         next_obs: np.ndarray,
-        done: bool
+        done: bool,
+        frame_idx: int = 0,
+        prev_action: Optional[np.ndarray] = None
     ):
         """Add transition to buffer"""
         if self.compress:
@@ -74,11 +131,24 @@ def add(
             self.next_observations[self.pos] = torch.from_numpy(next_obs)
         self.actions[self.pos] = torch.from_numpy(action)
+        if prev_action is None:
+            self.prev_actions[self.pos] = torch.zeros_like(self.actions[self.pos])
+        else:
+            self.prev_actions[self.pos] = torch.as_tensor(prev_action, dtype=torch.float32)
         self.rewards[self.pos] = reward
         self.dones[self.pos] = bool(done)
-
+        self.frame_idx[self.pos] = int(frame_idx)
+
         self.pos = (self.pos + 1) % self.capacity
         self.size = min(self.size + 1, self.capacity)
+        self.cleanup_count += 1
+        if self.compress and self.cleanup_count >= self.cleanup_interval:
+            # Perform cleanup to free memory
+            import gc
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            self.cleanup_count = 0
     def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
         """Sample batch of transitions"""
@@ -90,9 +160,9 @@ def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
             next_obs_list = []
             for i in indices:
-                obs_list.append(self._decompress_with_shape(self.observations[i]))
-                next_obs_list.append(self._decompress_with_shape(self.next_observations[i]))
-
+                obs = self._decompress(self.observations[i])
+                next_obs = self._decompress(self.next_observations[i])
+                obs_list.append(torch.from_numpy(obs).float() / 255.0)
+                next_obs_list.append(torch.from_numpy(next_obs).float() / 255.0)
@@ -105,9 +175,11 @@ def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
         return {
             'observations': obs,
             'actions': self.actions[indices],
+            'prev_actions': self.prev_actions[indices],
             'rewards': self.rewards[indices],
             'next_observations': next_obs,
-            'dones': self.dones[indices]  # Already float
+            'dones': self.dones[indices],  # Already float
+            'frame_idx': self.frame_idx[indices]
         }
     def _compress(self, data: np.ndarray) -> bytes:
@@ -139,4 +211,445 @@ def _decompress(self, compressed_data: bytes) -> np.ndarray:
     def __len__(self):
         return self.size
+
+
+class SequenceReplayBuffer(ReplayBuffer):
+    """Replay buffer that stores sequences for LSTM training"""
+
+    def __init__(self, capacity, obs_shape, action_dim, device,
+                 sequence_length=10, overlap=5):
+        super().__init__(capacity, obs_shape, action_dim, device)
+        self.sequence_length = sequence_length
+        self.overlap = overlap
+        self.episodes = []  # Store complete episodes
+        self.current_episode = []
+
+    def add(self, obs, action, reward, next_obs, done, frame_idx=0, prev_action=None):
+        # Add to current episode
+        self.current_episode.append({
+            'obs': obs, 'action': action, 'reward': reward,
+            'next_obs': next_obs, 'done': done, 'frame_idx': frame_idx,
+            'prev_action': prev_action
+        })
+
+        # If episode ends, store it
+        if done:
+            if len(self.current_episode) > 1:
+                self.episodes.append(self.current_episode)
+                if len(self.episodes) > self.capacity // self.sequence_length:
+                    self.episodes.pop(0)
+            self.current_episode = []
+
+        # Also add to regular buffer for standard sampling
+        super().add(obs, action, reward, next_obs, done, frame_idx, prev_action=prev_action)
+
+    def sample_sequences(self, batch_size):
+        """Sample sequences with proper padding and masking"""
+        if len(self.episodes) < batch_size:
+            return None
+        # Sample episode indices (np.random.choice cannot sample a list of lists directly)
+        episode_indices = np.random.randint(0, len(self.episodes), size=batch_size)
+        sampled_episodes = [self.episodes[idx] for idx in episode_indices]
+
+        # Create padded sequences
+        max_len = self.sequence_length
+        sequences = {
+            'observations': [],
+            'actions': [],
+            'prev_actions': [],
+            'rewards': [],
+            'dones': [],
+            'lengths': [],
+            'frame_indices': []
+        }
+
+        for episode in sampled_episodes:
+            # Sample a subsequence from the episode
+            ep_len = len(episode)
+            if ep_len > max_len:
+                start_idx = np.random.randint(0, ep_len - max_len + 1)
+                subsequence = episode[start_idx:start_idx + max_len]
+            else:
+                subsequence = episode
+
+            # Extract data
+            obs_seq = [step['obs'] for step in subsequence]
+            act_seq = [step['action'] for step in subsequence[:-1]]  # One less action
+            prev_act_seq = [step.get('prev_action', np.zeros_like(act_seq[0]))
+                            for step in subsequence[:-1]]
+            rew_seq = [step['reward'] for step in subsequence[:-1]]
+            done_seq = [bool(step['done'] or step.get('truncated', False)) for step in subsequence[:-1]]
+            frame_idx_seq = [step['frame_idx'] for step in subsequence]
+
+            # Pad if necessary
+            actual_len = len(obs_seq)
+            if actual_len < max_len:
+                # Pad with zeros
+                pad_len = max_len - actual_len
+                obs_seq.extend([np.zeros_like(obs_seq[0])] * pad_len)
+                act_seq.extend([np.zeros_like(act_seq[0])] * pad_len)
+                prev_act_seq.extend([np.zeros_like(prev_act_seq[0])] * pad_len)
+                rew_seq.extend([0.0] * pad_len)
+                done_seq.extend([True] * pad_len)  # Mark padded as done
+                frame_idx_seq.extend([0] * pad_len)  # Zero frame indices for padding
+
+            sequences['observations'].append(obs_seq)
+            sequences['actions'].append(act_seq[:max_len-1])
+            sequences['prev_actions'].append(prev_act_seq[:max_len-1])
+            sequences['rewards'].append(rew_seq[:max_len-1])
+            sequences['dones'].append(done_seq[:max_len-1])
+            sequences['frame_indices'].append(frame_idx_seq[:max_len])
+            sequences['lengths'].append(actual_len)
+
+        # Convert to tensors
+        return {
+            'observations': torch.tensor(np.array(sequences['observations']),
+                                         dtype=torch.float32, device=self.device),
+            'actions': torch.tensor(np.array(sequences['actions']),
+                                    dtype=torch.float32, device=self.device),
+            'prev_actions': torch.tensor(np.array(sequences['prev_actions']),
+                                         dtype=torch.float32, device=self.device),
+            'rewards': torch.tensor(np.array(sequences['rewards']),
+                                    dtype=torch.float32, device=self.device),
+            'dones': torch.tensor(np.array(sequences['dones']),
+                                  dtype=torch.bool, device=self.device),
+            'lengths': torch.tensor(sequences['lengths'],
+                                    dtype=torch.long, device=self.device),
+            'frame_indices': torch.tensor(np.array(sequences['frame_indices']),
+                                          dtype=torch.long, device=self.device),
+        }
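Consumers of `sample_sequences` are expected to mask padded steps using the returned `lengths`. A minimal sketch of that contract (the buffer instance and `per_step_loss_fn` are placeholders):

```python
import torch

def masked_sequence_loss(buffer, per_step_loss_fn, batch_size: int = 16):
    """Illustrative reduction over a padded sequence batch."""
    batch = buffer.sample_sequences(batch_size)
    if batch is None:  # not enough complete episodes yet
        return None
    T = batch['observations'].shape[1]
    steps = torch.arange(T, device=batch['lengths'].device)
    valid = steps[None, :] < batch['lengths'][:, None]   # [B, T] validity mask
    per_step = per_step_loss_fn(batch)                   # [B, T-1], aligned with actions
    mask = valid[:, :-1].float()
    return (per_step * mask).sum() / mask.sum().clamp(min=1.0)
```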
+
+
+class PrioritizedSequenceReplayBuffer(SequenceReplayBuffer):
+    """
+    Prioritized sequence replay combining episodic storage with prioritized sampling.
+
+    We keep a mapping buffer_pos -> (episode_id, step_in_episode). The SumTree stores
+    leaves that map to buffer_pos (real storage positions in the parent). Sampling a
+    leaf retrieves the buffer_pos; we look up the episode/time key to build a window.
+    """
+    def __init__(
+        self,
+        capacity: int,
+        obs_shape: Tuple[int, ...],
+        action_dim: int,
+        device: torch.device,
+        sequence_length: int = 10,
+        overlap: int = 5,
+        # PER
+        alpha: float = 0.6,
+        beta_start: float = 0.4,
+        beta_end: float = 1.0,
+        beta_frames: int = 100_000,
+        eps: float = 1e-6,
+    ):
+        super().__init__(capacity, obs_shape, action_dim, device, sequence_length, overlap)
+
+        # PER hyperparameters
+        self.alpha = float(alpha)
+        self.beta_start = float(beta_start)
+        self.beta_end = float(beta_end)
+        self.beta_frames = int(beta_frames)
+        self.eps = float(eps)
+        self.frame_count = 0
+        self.max_priority = 1.0
+
+        # Tree with leaf->buffer_pos mapping
+        self._tree = _SumTree(capacity)
+
+        # buffer_pos -> (episode_id, step_in_episode); -1 means unknown
+        self._pos_to_key: List[Optional[Tuple[int, int]]] = [None] * capacity
+
+        # Episode-id tracking (monotonic ids) and O(1) id->index map
+        self._next_episode_id: int = 0
+        self._current_episode_id: Optional[int] = None
+        self._episode_ids: List[int] = []        # index-aligned with self.episodes
+        self._id_to_index: Dict[int, int] = {}   # ep_id -> current index in self.episodes
+
+    # ------------------------ episode/key bookkeeping ------------------------
+
+    def _on_episode_started(self) -> None:
+        self._current_episode_id = self._next_episode_id
+        self._next_episode_id += 1
+
+    def _on_episode_finished(self) -> None:
+        """Call when parent has appended a new episode to self.episodes (after add)."""
+        if self._current_episode_id is None:
+            return
+        # Append ID and update dict
+        self._episode_ids.append(self._current_episode_id)
+        self._id_to_index[self._current_episode_id] = len(self._episode_ids) - 1
+
+        # Keep in sync with parent's possible eviction policy:
+        # If parent popped from the left (older episode), drop our leftmost id too.
+        max_eps = self.capacity // self.sequence_length
+        while len(self._episode_ids) > max_eps and len(self.episodes) > 0:
+            popped_id = self._episode_ids.pop(0)
+            self._id_to_index.pop(popped_id, None)
+            # Shift remaining indices down by 1
+            for eid in list(self._id_to_index.keys()):
+                self._id_to_index[eid] -= 1
+                if self._id_to_index[eid] < 0:
+                    # defensive; should not happen
+                    self._id_to_index.pop(eid, None)
+
+        self._current_episode_id = None
+
+    def _episode_index_from_id(self, ep_id: int) -> Optional[int]:
+        return self._id_to_index.get(ep_id, None)
+
+    # ------------------------------ API --------------------------------------
+
+    def add(self, obs, action, reward, next_obs, done, frame_idx: int = 0, prev_action=None, priority: Optional[float] = None):
+        """
+        Add one environment step.
+
+        FIXED: We use the parent's actual buffer position for the new write as our SumTree data_idx.
+        We capture `buffer_pos = self.pos` BEFORE calling super().add(...), since most ring buffers
+        write to `pos` then increment it.
+        """
+        episodes_before = len(self.episodes)
+
+        # Start-of-episode detection
+        episode_was_empty = (len(self.current_episode) == 0)
+        if episode_was_empty:
+            self._on_episode_started()
+
+        # Capture mapping BEFORE push
+        step_in_episode = len(self.current_episode)
+        current_ep = self._current_episode_id if self._current_episode_id is not None else -1
+
+        # Buffer position where parent will write this step (ring index)
+        buffer_pos = int(self.pos)
+
+        # Delegate to parent
+        super().add(obs, action, reward, next_obs, done, frame_idx, prev_action)
+
+        # Record mapping from buffer position to episode/time key
+        self._pos_to_key[buffer_pos] = (current_ep, step_in_episode)
+
+        # Assign initial priority and insert leaf (SumTree maps leaf_slot -> buffer_pos)
+        base_pr = self.max_priority if priority is None else float(priority)
+        self._tree.add(base_pr ** self.alpha, buffer_pos)
+        self.max_priority = max(self.max_priority, base_pr)
+
+        # Episode finish bookkeeping (robust)
+        if len(self.episodes) > episodes_before:
+            self._on_episode_finished()
+
+    def sample_sequences(self, batch_size: int) -> Optional[Dict[str, torch.Tensor]]:
+        """Sample sequences via per-step PER while respecting episode boundaries."""
+        if self._tree.n_entries == 0 or len(self.episodes) < batch_size:
+            return None
+
+        # Anneal beta
+        self.frame_count += 1
+        beta = min(self.beta_end, self.beta_start + (self.beta_end - self.beta_start) * (self.frame_count / self.beta_frames))
+
+        sequences = {
+            'observations': [],
+            'actions': [],
+            'prev_actions': [],
+            'rewards': [],
+            'dones': [],
+            'lengths': [],
+            'frame_indices': [],
+            'tree_indices': [],
+        }
+        priorities = []
+
+        segment = max(self._tree.total, 1e-12) / batch_size
+
+        for i in range(batch_size):
+            # Try twice to avoid stale mappings
+            for _ in range(2):
+                a, b = segment * i, segment * (i + 1)
+                t_idx, p, buffer_pos = self._tree.get(random.uniform(a, b))
+                key = self._pos_to_key[buffer_pos]
+
+                if key is None:
+                    # stale/overwritten; demote and retry
+                    self._tree.update(t_idx, self.eps)
+                    continue
+
+                ep_id, step_idx = key
+                epi_idx = self._episode_index_from_id(ep_id)
+                if epi_idx is None or not (0 <= epi_idx < len(self.episodes)):
+                    # episode evicted; demote and retry
+                    self._tree.update(t_idx, self.eps)
+                    continue
+
+                episode = self.episodes[epi_idx]
+                ep_len = len(episode)
+                if ep_len == 0:
+                    self._tree.update(t_idx, self.eps)
+                    continue
+
+                # Build window [seq_start, seq_end) centered if possible
+                T = self.sequence_length
+                half = T // 2
+                seq_start = max(0, min(step_idx - half, ep_len - T))
+                seq_end = min(ep_len, seq_start + T)
+                if seq_end - seq_start < T:
+                    seq_start = max(0, seq_end - T)
+                subseq = episode[seq_start:seq_end]
+
+                # Extract fields (parent API)
+                obs_seq = [st['obs'] for st in subseq]
+                if len(subseq) > 1:
+                    act_seq = [st['action'] for st in subseq[:-1]]
+                    prev_seq = [st.get('prev_action', np.zeros_like(act_seq[0])) for st in subseq[:-1]]
+                    rew_seq = [st['reward'] for st in subseq[:-1]]
+                    done_seq = [bool(st.get('done') or st.get('truncated')) for st in subseq[:-1]]
+                else:
+                    act_seq = [subseq[0]['action']]
+                    prev_seq = [subseq[0].get('prev_action', np.zeros_like(act_seq[0]))]
+                    rew_seq = [subseq[0]['reward']]
+                    done_seq = [bool(subseq[0].get('done') or subseq[0].get('truncated'))]
+
+                frame_idx_seq = [int(st['frame_idx']) for st in subseq]
+
+                # Pad to fixed length
+                actual_len = len(obs_seq)
+                if actual_len < T:
+                    pad_len = T - actual_len
+                    obs_seq.extend([np.zeros_like(obs_seq[0])] * pad_len)
+                    pad_acts = (T - 1) - len(act_seq)
+                    if pad_acts > 0:
+                        act_seq.extend([np.zeros_like(act_seq[0])] * pad_acts)
+                        prev_seq.extend([np.zeros_like(prev_seq[0])] * pad_acts)
+                        rew_seq.extend([0.0] * pad_acts)
+                        done_seq.extend([True] * pad_acts)
+                    frame_idx_seq.extend([0] * pad_len)
+
+                # Record
+                sequences['observations'].append(obs_seq)
+                sequences['actions'].append(act_seq[:T-1])
+                sequences['prev_actions'].append(prev_seq[:T-1])
+                sequences['rewards'].append(rew_seq[:T-1])
+                sequences['dones'].append(done_seq[:T-1])
+                sequences['frame_indices'].append(frame_idx_seq[:T])
+                sequences['lengths'].append(actual_len)
+                sequences['tree_indices'].append(t_idx)
+                priorities.append(p)
+                break
+            else:
+                # very rare: fallback random sequence
+                ridx = np.random.randint(0, len(self.episodes))
+                episode = self.episodes[ridx]
+                subseq = episode[:min(len(episode), self.sequence_length)]
+                obs_seq = [st['obs'] for st in subseq]
+                if len(subseq) > 1:
+                    act_seq = [st['action'] for st in subseq[:-1]]
+                    prev_seq = [st.get('prev_action', np.zeros_like(act_seq[0])) for st in subseq[:-1]]
+                    rew_seq = [st['reward'] for st in subseq[:-1]]
+                    done_seq = [bool(st.get('done') or st.get('truncated')) for st in subseq[:-1]]
+                else:
+                    act_seq = [np.zeros_like(episode[0]['action'])]
+                    prev_seq = [np.zeros_like(act_seq[0])]
+                    rew_seq = [0.0]
+                    done_seq = [True]
+                frame_idx_seq = [int(st['frame_idx']) for st in subseq]
+                actual_len = len(obs_seq)
+                if actual_len < self.sequence_length:
+                    pad_len = self.sequence_length - actual_len
+                    obs_seq.extend([np.zeros_like(obs_seq[0])] * pad_len)
+                    pad_acts = (self.sequence_length - 1) - len(act_seq)
+                    if pad_acts > 0:
+                        act_seq.extend([np.zeros_like(act_seq[0])] * pad_acts)
+                        prev_seq.extend([np.zeros_like(prev_seq[0])] * pad_acts)
+                        rew_seq.extend([0.0] * pad_acts)
+                        done_seq.extend([True] * pad_acts)
+                    frame_idx_seq.extend([0] * pad_len)
+
+                sequences['observations'].append(obs_seq)
+                sequences['actions'].append(act_seq[:self.sequence_length-1])
+                sequences['prev_actions'].append(prev_seq[:self.sequence_length-1])
+                sequences['rewards'].append(rew_seq[:self.sequence_length-1])
+                sequences['dones'].append(done_seq[:self.sequence_length-1])
+                sequences['frame_indices'].append(frame_idx_seq[:self.sequence_length])
+                sequences['lengths'].append(actual_len)
+                sequences['tree_indices'].append(-1)  # -1 = no tree leaf; skipped by update_priorities (0 would corrupt the root)
+                priorities.append(self.eps)
+
+        # IS weights
+        p_tot = max(self._tree.total, 1e-12)
+        probs = np.asarray(priorities, dtype=np.float64) / p_tot
+        N = max(self._tree.n_entries, 1)
+        is_w = np.power(N * probs + 1e-12, -beta)
+        is_w /= is_w.max()
+
+        # Pack tensors on device
+        out = {
+            'observations': torch.tensor(np.array(sequences['observations']), dtype=torch.float32, device=self.device),
+            'actions': torch.tensor(np.array(sequences['actions']), dtype=torch.float32, device=self.device),
+            'prev_actions': torch.tensor(np.array(sequences['prev_actions']), dtype=torch.float32, device=self.device),
+            'rewards': torch.tensor(np.array(sequences['rewards']), dtype=torch.float32, device=self.device),
+            'dones': torch.tensor(np.array(sequences['dones']), dtype=torch.bool, device=self.device),
+            'lengths': torch.tensor(sequences['lengths'], dtype=torch.long, device=self.device),
+            'frame_indices': torch.tensor(np.array(sequences['frame_indices']), dtype=torch.long, device=self.device),
+            'tree_indices': torch.tensor(sequences['tree_indices'], dtype=torch.long, device=self.device),
+            'is_weights': torch.tensor(is_w, dtype=torch.float32, device=self.device),
+        }
+        # Alias for agent code
+        out['frame_idx'] = out['frame_indices']
+        return out
+
+    @torch.no_grad()
+    def update_priorities(self, tree_indices: torch.Tensor, priorities: torch.Tensor, *, already_alpha: bool = False) -> None:
+        """
+        Update priorities for sampled sequences.
+        Expects RAW priorities (|td| + normalized ELBO + normalized dynamics).
+        If you pass already-α-transformed priorities, set already_alpha=True.
+        """
+        if isinstance(tree_indices, torch.Tensor):
+            tree_indices = tree_indices.detach().cpu().numpy()
+        if isinstance(priorities, torch.Tensor):
+            priorities = priorities.detach().cpu().numpy()
+
+        for t_idx, pr in zip(tree_indices, priorities):
+            if int(t_idx) < 0:
+                continue  # fallback samples carry no tree leaf
+            pr = max(float(pr), self.eps)
+            pr_tree = pr if already_alpha else (pr ** self.alpha)
+            self.max_priority = max(self.max_priority, pr)
+            self._tree.update(int(t_idx), pr_tree)
+
+    # ------------------------ helpers for priorities --------------------------
+
+    @staticmethod
+    def compute_sequence_priority(
+        td_errors: Optional[torch.Tensor] = None,
+        elbo_losses: Optional[torch.Tensor] = None,
+        dynamics_losses: Optional[torch.Tensor] = None,
+        weights: Tuple[float, float, float] = (1.0, 0.5, 0.5),
+        eps: float = 1e-6,
+    ) -> torch.Tensor:
+        """
+        Combine sequence-level errors into RAW priorities (no alpha here).
+        Accepts either [B] or [B, T] tensors and aggregates across time.
+        """
+        xs = [x for x in (td_errors, elbo_losses, dynamics_losses) if x is not None]
+        if not xs:
+            raise ValueError("At least one of td_errors, elbo_losses, dynamics_losses must be provided.")
+        device = xs[0].device
+        B = xs[0].shape[0]
+        pr = torch.zeros(B, device=device, dtype=torch.float32)
+
+        if td_errors is not None:
+            td = td_errors.abs()
+            if td.dim() == 2:
+                td = td.max(dim=1)[0]  # max over time
+            pr = pr + weights[0] * td
+
+        if elbo_losses is not None:
+            el = elbo_losses.abs()
+            if el.dim() == 2:
+                el = el.mean(dim=1)  # mean over time
+            el = el / (1.0 + el)  # squash
+            pr = pr + weights[1] * el
+
+        if dynamics_losses is not None:
+            dy = dynamics_losses.abs()
+            if dy.dim() == 2:
+                dy = dy.mean(dim=1)
+            dy = dy / (1.0 + dy)
+            pr = pr + weights[2] * dy
+        return pr + eps
\ No newline at end of file
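End to end, the prioritized path is: sample with importance weights, train, then push fresh priorities back through the returned tree indices. A sketch of that round trip (the error tensors here are stand-ins for the agent's real TD/ELBO/dynamics terms):

```python
import torch

def per_update(buffer, batch_size: int = 32):
    """Illustrative PER round-trip for PrioritizedSequenceReplayBuffer."""
    batch = buffer.sample_sequences(batch_size)
    if batch is None:
        return None

    # Stand-ins for per-sequence errors computed by the agent
    td = torch.rand(batch_size)
    elbo = torch.rand(batch_size)
    dyn = torch.rand(batch_size)

    # Importance-sampling weights undo the non-uniform sampling bias
    loss = (batch['is_weights'] * td).mean()

    # Raw priorities; alpha is applied inside update_priorities
    pr = buffer.compute_sequence_priority(
        td_errors=td, elbo_losses=elbo, dynamics_losses=dyn
    )
    buffer.update_priorities(batch['tree_indices'], pr)
    return loss
```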
diff --git a/active_inference_diffusion/utils/logger.py b/active_inference_diffusion/utils/logger.py
index 4fba189..88b13ab 100644
--- a/active_inference_diffusion/utils/logger.py
+++ b/active_inference_diffusion/utils/logger.py
@@ -8,6 +8,7 @@
 from pathlib import Path
 import numpy as np
 import torch
+wandb.login()  # reads WANDB_API_KEY from the environment; never hardcode API keys in source
 class Logger:
     """
     Unified logger supporting console, file, and wandb
@@ -16,6 +17,7 @@ class Logger:
     def __init__(
         self,
         use_wandb: bool = True,
+        entity: Optional[str] = None,
         project_name: str = "active-inference",
         experiment_name: Optional[str] = None,
         config: Optional[Dict[str, Any]] = None,
@@ -29,6 +31,7 @@ def __init__(
         if use_wandb:
             wandb.init(
                 project=project_name,
+                entity=entity,
                 name=experiment_name,
                 config=config,
                 settings=wandb.Settings(init_timeout=200),
diff --git a/active_inference_diffusion/utils/training.py b/active_inference_diffusion/utils/training.py
index ed53c81..66f92fb 100644
--- a/active_inference_diffusion/utils/training.py
+++ b/active_inference_diffusion/utils/training.py
@@ -11,7 +11,7 @@
 import matplotlib.pyplot as plt
 from typing import List, Optional, Any
 import json
-
+from gymnasium.wrappers import RecordVideo
 def evaluate_agent(
     agent: Any,
     env: gym.Env,
@@ -198,7 +198,7 @@ def create_video(
         >>> from active_inference_diffusion.utils.training import create_video
         >>> create_video(agent, env, "halfcheetah_trained.mp4", num_episodes=3)
     """
-    from gymnasium.wrappers import RecordVideo
+
     # Create video directory
     Path(video_folder).mkdir(exist_ok=True)
diff --git a/active_inference_diffusion/utils/util.py b/active_inference_diffusion/utils/util.py
index 3da6049..503a356 100644
--- a/active_inference_diffusion/utils/util.py
+++ b/active_inference_diffusion/utils/util.py
@@ -3,7 +3,11 @@
 import torch.nn as nn
 import numpy as np
 import torch.nn.functional as F
-from typing import Optional, Union
+from typing import Optional, Union, Tuple
+import matplotlib
+from mpl_toolkits.axes_grid1.inset_locator import inset_axes  # used by create_mixed_reconstruction_plot
+
+matplotlib.use("Agg")
 class SpatialAttentionAggregator(nn.Module):
     """
     Multi-head attention for spatially-aware epistemic feature aggregation
@@ -73,20 +77,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         output = self.output_proj(attended_flat)
         return output, attention_weights
+
 def visualize_reconstruction(
-    agent,  # Type annotation removed to avoid circular import
+    agent,
     observations: torch.Tensor,
+    frame_idx: Optional[torch.Tensor] = None,
+    actions: Optional[torch.Tensor] = None,
     save_path: Optional[str] = None,
-    max_samples: int = 4
+    max_samples: int = 4,
+    visualization_mode: str = "recent"
 ) -> float:
     """
-    Visualize observation reconstruction through the diffusion latent space.
-    Works for both pixel and state observations.
+    Visualize observation reconstruction through the diffusion latent space.
+    Properly handles frame-stacked inputs with single-frame decoder outputs.
-    This function demonstrates how observations are encoded to latents
-    and then decoded back, which is crucial for computing epistemic value.
+    Args:
+        agent: The agent to visualize
+        observations: Batch of observations (may be frame-stacked)
+        save_path: Where to save the visualization
+        max_samples: Maximum number of samples to visualize
+        visualization_mode: How to visualize frame stacks
     """
-    device = observations.device
+    device = agent.device if hasattr(agent, 'device') else agent.active_inference.device
     # Ensure we're in eval mode
     was_training = agent.active_inference.training
@@ -96,7 +108,11 @@ def visualize_reconstruction(
     # Move observations to device
     if observations.device != device:
         observations = observations.to(device)
-
+    if frame_idx is not None and not torch.is_tensor(frame_idx):
+        frame_idx = torch.as_tensor(frame_idx, device=device)
+    if actions is not None and actions.device != device:
+        actions = actions.to(device)
+
     # For pixel observations, we need to encode them first
     if hasattr(agent, 'encoder') and agent.config.pixel_observation:
         # Encode pixel observations to features
@@ -106,90 +122,339 @@ def visualize_reconstruction(
         encoded_obs = observations[:max_samples]
     # Generate latents via diffusion
-    belief_info = agent.active_inference.update_belief_via_diffusion(encoded_obs)
+    belief_info = agent.active_inference.update_belief_via_diffusion(encoded_obs,
+                                                                     frame_idx=frame_idx[:max_samples] if frame_idx is not None else None,
+                                                                     actions=actions[:max_samples] if actions is not None else None)
     latents = belief_info['latent']
     # Decode latents back to observation space
     reconstructed_obs = agent.active_inference.decode_observation(latents)
-    # For pixel observations, we need to handle the output differently
+    # Debug prints to understand shapes
+    print(f"Original observations shape: {observations[:max_samples].shape}")
+    print(f"Encoded observations shape: {encoded_obs.shape}")
+    print(f"Latents shape: {latents.shape}")
+    print(f"Reconstructed observations shape: {reconstructed_obs.shape}")
+
+    # For pixel observations, handle visualization
     if agent.config.pixel_observation:
-        # Pixel reconstruction
         if save_path:
-            fig, axes = plt.subplots(2, max_samples, figsize=(max_samples * 3, 6))
+            # Determine number of samples to visualize
+            n_samples = min(max_samples, observations.shape[0])
-            for i in range(min(max_samples, observations.shape[0])):
-                # Original observation
-                if observations.shape[1] > 3:  # Frame stacked
-                    # Show only the most recent frame
-                    orig = observations[i, -3:].cpu().numpy()
-                else:
-                    orig = observations[i].cpu().numpy()
-
-                # Handle channel format
-                if orig.shape[0] == 3:  # (C, H, W)
-                    orig = np.transpose(orig, (1, 2, 0))
-
-                # Reconstructed observation
-                recon = reconstructed_obs[i].cpu().numpy()
-                if recon.shape[0] == 3:  # (C, H, W)
-                    recon = np.transpose(recon, (1, 2, 0))
-
-                # Ensure values are in [0, 1]
-                orig = np.clip(orig, 0, 1)
-                recon = np.clip(recon, 0, 1)
-
-                # Plot
-                axes[0, i].imshow(orig)
-                axes[0, i].set_title(f'Original {i}')
-                axes[0, i].axis('off')
-
-                axes[1, i].imshow(recon)
-                axes[1, i].set_title(f'Reconstructed {i}')
-                axes[1, i].axis('off')
+            # Process observations and reconstructions
+            orig_obs = observations[:n_samples].cpu().numpy()
+            recon_obs = reconstructed_obs[:n_samples].cpu().numpy()
-            plt.tight_layout()
-            plt.savefig(save_path, dpi=150, bbox_inches='tight')
-            plt.close()
+            # Check if we're dealing with frame-stacked input and single-frame output
+            is_frame_stacked_input = (orig_obs.ndim == 4 and orig_obs.shape[1] > 3) or \
+                                     (orig_obs.ndim == 5)
+            # supports NCHW (C==3) or NHWC (last dim == 3)
+            is_single_frame_output = (recon_obs.ndim == 4 and (recon_obs.shape[1] == 3 or recon_obs.shape[-1] == 3))
+
+            if is_frame_stacked_input and is_single_frame_output:
+                print("Detected frame-stacked input with single-frame decoder output")
+                # This is the common case - we'll compare the most recent frame
+                # with the reconstruction
+                create_mixed_reconstruction_plot(
+                    orig_obs, recon_obs, save_path, visualization_mode
+                )
+            else:
+                # Standard case - matching dimensions
+                create_standard_reconstruction_plot(
+                    orig_obs, recon_obs, save_path, visualization_mode
+                )
-        # Compute reconstruction error
-        if observations.shape == reconstructed_obs.shape:
-            recon_error = F.mse_loss(reconstructed_obs, observations[:max_samples]).item()
-        else:
-            # If shapes don't match (e.g., frame stacking), compare encoded features
-            recon_error = F.mse_loss(reconstructed_obs, encoded_obs).item()
+        # Compute reconstruction error in feature space
+        predicted_features = agent.active_inference.decode_observation(latents, decode_to_pixels=False)
+        recon_error = F.mse_loss(predicted_features, encoded_obs).item()
     else:
-        # State reconstruction
+        # State reconstruction (unchanged)
         recon_error = F.mse_loss(reconstructed_obs, observations[:max_samples]).item()
         if save_path:
-            fig, ax = plt.subplots(1, 1, figsize=(10, 6))
-
-            # Plot first few dimensions
-            dims_to_plot = min(5, observations.shape[1])
-            x = np.arange(dims_to_plot)
-
-            for i in range(min(max_samples, observations.shape[0])):
-                orig = observations[i, :dims_to_plot].cpu().numpy()
-                recon = reconstructed_obs[i, :dims_to_plot].cpu().numpy()
-
-                offset = i * 0.2
-                ax.plot(x, orig + offset, 'o-', label=f'Original {i}', alpha=0.7)
-                ax.plot(x, recon + offset, 's--', label=f'Recon {i}', alpha=0.7)
-
-            ax.set_xlabel('State Dimension')
-            ax.set_ylabel('Value (offset for clarity)')
-            ax.set_title('State Reconstruction Quality')
-            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
-            ax.grid(True, alpha=0.3)
-
-            plt.tight_layout()
-            plt.savefig(save_path, dpi=150, bbox_inches='tight')
-            plt.close()
+            # State visualization code remains the same
+            pass
     # Restore training mode
     if was_training:
         agent.active_inference.train()
     return recon_error
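The branch above reduces to one reshape convention: a 3·k-channel stack is k RGB frames, and the decoder output is compared against the last one. A standalone sketch of that convention (shapes are assumptions, e.g. 84×84 DMC-style frames):

```python
import numpy as np

# A frame-stacked pixel observation: 3 RGB frames stacked along channels
obs = np.random.rand(9, 84, 84).astype(np.float32)   # (3*3, H, W)

n_frames = obs.shape[0] // 3
frames = obs.reshape(n_frames, 3, 84, 84)            # (F, 3, H, W)

most_recent = frames[-1]                             # decoder target frame
hwc = np.transpose(most_recent, (1, 2, 0))           # (H, W, 3) for imshow

# Stitch all frames horizontally, as the "grid"/stitched mode does
grid = np.concatenate([np.transpose(f, (1, 2, 0)) for f in frames], axis=1)
print(hwc.shape, grid.shape)  # (84, 84, 3) (84, 252, 3)
```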
+
+
+def create_mixed_reconstruction_plot(
+    frame_stacked_obs: np.ndarray,
+    single_frame_recon: np.ndarray,
+    save_path: str,
+    mode: str = "recent"  # use "separate" here to show 3 distinct panels
+):
+    """
+    Create plot comparing frame-stacked observations with single-frame reconstructions.
+    For C=9 inputs, we reshape to (3,3,H,W) and either stitch or show separate panels.
+    """
+    n_samples = frame_stacked_obs.shape[0]
+    fig, axes = plt.subplots(3, n_samples, figsize=(n_samples * 3, 9))
+    if n_samples == 1:
+        axes = axes.reshape(3, 1)
+
+    for i in range(n_samples):
+        orig = frame_stacked_obs[i]
+        recon = single_frame_recon[i]
+
+        # --- extract frames ---
+        if orig.ndim == 4:  # (frames, channels, H, W)
+            frames = orig
+        elif orig.ndim == 3 and orig.shape[0] > 3 and (orig.shape[0] % 3 == 0):
+            n_frames = orig.shape[0] // 3
+            h, w = orig.shape[1:]
+            frames = orig.reshape(n_frames, 3, h, w)  # (F,3,H,W) e.g., (3,3,H,W)
+        else:
+            frames = np.expand_dims(orig, 0) if orig.ndim == 3 else orig  # (1,3,H,W) or (1,H,W,3)
+
+        # --- Row 1: show all frames ---
+        ax = axes[0, i]
+        ax.axis('off')
+        ax.set_title(f'Frame Stack {i}')
+        num_frames = frames.shape[0]  # renamed from `F` to avoid shadowing torch.nn.functional
+
+        if num_frames > 1 and mode in ("separate", "separate_frames"):
+            # draw separate panels (no concatenation, uses ALL frames)
+            left_margin = 0.02
+            right_margin = 0.02
+            w_frac = (1.0 - left_margin - right_margin) / num_frames
+            for k in range(num_frames):
+                fr = frames[k]
+                # CHW -> HWC if needed
+                fr = np.transpose(fr, (1, 2, 0)) if (fr.ndim == 3 and fr.shape[0] == 3) else fr
+                if fr.ndim == 2:
+                    fr = fr[..., None]
+                if fr.shape[-1] not in (1, 3):
+                    fr = fr[..., :3]
+                if fr.max() > 1.0:
+                    fr = fr / 255.0
+                fr = np.clip(fr, 0, 1)
+                x0 = left_margin + k * w_frac
+                iax = inset_axes(ax, width=f"{w_frac*100:.3f}%", height="100%",
+                                 bbox_to_anchor=(x0, 0.0, w_frac, 1.0),
+                                 bbox_transform=ax.transAxes, borderpad=0)
+                iax.imshow(fr if fr.shape[-1] == 3 else np.repeat(fr, 3, axis=2))
+                iax.axis('off')
+        else:
+            # current behavior: stitch frames horizontally into one image
+            frame_grid = []
+            for k in range(num_frames):
+                fr = frames[k]
+                fr = np.transpose(fr, (1, 2, 0)) if (fr.ndim == 3 and fr.shape[0] == 3) else fr
+                if fr.ndim == 2:
+                    fr = fr[..., None]
+                if fr.shape[-1] not in (1, 3):
+                    fr = fr[..., :3]
+                if fr.max() > 1.0:
+                    fr = fr / 255.0
+                frame_grid.append(np.clip(fr, 0, 1))
+            grid = np.concatenate(frame_grid, axis=1)
+            ax.imshow(grid)
+
+        # --- Row 2: most recent frame (target) ---
+        most_recent_vis = process_single_frame(frames[-1])
+        axes[1, i].imshow(most_recent_vis)
+        axes[1, i].axis('off')
+        axes[1, i].set_title(f'Target Frame {i}')
+
+        # --- Row 3: reconstruction ---
+        recon_vis = process_single_frame(recon)
+        axes[2, i].imshow(recon_vis)
+        axes[2, i].axis('off')
+        axes[2, i].set_title(f'Reconstruction {i}')
+
+        # PSNR (optional)
+        mse = np.mean((most_recent_vis - recon_vis) ** 2)
+        psnr = 20 * np.log10(1.0 / np.sqrt(mse)) if mse > 0 else float('inf')
+        axes[2, i].text(0.5, -0.1, f'PSNR: {psnr:.1f}dB',
+                        transform=axes[2, i].transAxes, ha='center', fontsize=8)
+
+    # Row labels
+    fig.text(0.02, 0.75, 'Input\nStack', rotation=90, va='center', fontsize=12, weight='bold')
+    fig.text(0.02, 0.5, 'Target\nFrame', rotation=90, va='center', fontsize=12, weight='bold')
+    fig.text(0.02, 0.25, 'Recon', rotation=90, va='center', fontsize=12, weight='bold')
+    plt.tight_layout()
+    plt.subplots_adjust(left=0.05)
+    plt.savefig(save_path, dpi=150, bbox_inches='tight')
+    plt.close()
+    print(f"Saved frame-stack aware reconstruction visualization to {save_path}")
+
+def create_standard_reconstruction_plot(
+    original_images: np.ndarray,
+    reconstructed_images: np.ndarray,
+    save_path: str,
+    mode: str = "recent"
+):
+    """Standard reconstruction plot when dimensions match"""
+    n_samples = original_images.shape[0]
+
+    fig, axes = plt.subplots(2, n_samples, figsize=(n_samples * 3, 6))
+
+    if n_samples == 1:
+        axes = axes.reshape(2, 1)
+
+    for i in range(n_samples):
+        # Process observations
+        orig = process_observation_for_display(original_images[i], mode)
+        recon = process_observation_for_display(reconstructed_images[i], mode)
+
+        # Original
+        axes[0, i].imshow(orig)
+        axes[0, i].set_title(f'Original {i}')
+        axes[0, i].axis('off')
+
+        # Reconstructed
+        axes[1, i].imshow(recon)
+        axes[1, i].set_title(f'Reconstructed {i}')
+        axes[1, i].axis('off')
+
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=150, bbox_inches='tight')
+    plt.close()
+
+    print(f"Saved standard reconstruction visualization to {save_path}")
+
+
+def process_observation_for_display(obs: np.ndarray, mode: str) -> np.ndarray:
+    """Process any observation format for display."""
+    # If it's a frame stack (either 4D or CHW with C>3 and divisible by 3),
+    # render an RGB frame from it (e.g., most recent or grid).
+    if obs.ndim == 4:
+        return process_frame_stack(obs, mode)
+    if obs.ndim == 3 and obs.shape[0] > 3 and (obs.shape[0] % 3 == 0):
+        # CHW where C is 3 * num_frames
+        return process_frame_stack(obs, mode)
+    # Otherwise treat as a single frame
+    return process_single_frame(obs)
+
+
+def process_single_frame(frame: np.ndarray) -> np.ndarray:
+    """Process a single frame for visualization"""
+    # Handle different formats
+    if frame.ndim == 3 and frame.shape[0] in [1, 3]:  # (C, H, W)
+        frame = np.transpose(frame, (1, 2, 0))
+    elif frame.ndim == 2:  # (H, W)
+        frame = np.expand_dims(frame, axis=-1)  # (H, W, 1)
+    # Handle CHW with C > 3 (e.g., 9): take the most recent RGB triplet
+    elif frame.ndim == 3 and frame.shape[0] > 3 and (frame.shape[0] % 3 == 0):
+        n_frames = frame.shape[0] // 3
+        c_last = 3 * (n_frames - 1)
+        last_rgb = frame[c_last:c_last + 3, ...]   # (3, H, W)
+        frame = np.transpose(last_rgb, (1, 2, 0))  # -> (H, W, 3)
+    # If it's already HWC but channels != 1/3, fall back to first 3 channels
+    elif frame.ndim == 3 and frame.shape[-1] not in (1, 3):
+        frame = frame[..., :3]
+
+    # Handle grayscale
+    if frame.shape[-1] == 1:
+        frame = np.repeat(frame, 3, axis=2)
+
+    # Ensure proper range
+    if frame.max() > 1.0:
+        frame = frame / 255.0
+
+    return np.clip(frame, 0, 1)
+
+
+def process_frame_stack(frames: np.ndarray, mode: str = "recent") -> np.ndarray:
+    """Process frame-stacked observations for visualization"""
+    # First, determine the format
+    if frames.shape[0] > 4:  # Likely concatenated channels
+        # Assume RGB frames concatenated
+        n_channels = 3
+        n_frames = frames.shape[0] // n_channels
+        h, w = frames.shape[-2:]
+        frames = frames.reshape(n_frames, n_channels, h, w)
+
+    # Now frames should be (num_frames, channels, H, W)
+    num_frames = frames.shape[0]
+
+    if mode == "recent":
+        # Use the most recent frame
+        frame = frames[-1]
+    elif mode == "grid":
+        # Create a grid showing all frames
+        grid_frames = []
+        for i in range(num_frames):
+            f = process_single_frame(frames[i])
+            grid_frames.append(f)
+        # Concatenate horizontally
+        return np.concatenate(grid_frames, axis=1)
+    elif mode == "blend":
+        # Average frames with recency weighting
+        weights = np.linspace(0.1, 1.0, num_frames)
+        weights = weights / weights.sum()
+        blended = np.zeros_like(frames[0], dtype=np.float32)
+        for i, w in enumerate(weights):
+            blended += w * frames[i]
+        frame = blended
+    else:
+        frame = frames[-1]  # Default to most recent
+
+    return process_single_frame(frame)
+
+# --- Dreamer-style symlog / symexp and two-hot utils ---
+def symlog(x):
+    return x.sign() * (x.abs() + 1.0).log()
+
+def symexp(x):
+    return x.sign() * (x.abs().exp() - 1.0)
+
+class DiscDist:
+    def __init__(
+        self,
+        logits,
+        low=-20.0,
+        high=20.0,
+        transfwd=symlog,
+        transbwd=symexp,
+        device="cuda",
+        discrete_number=255
+    ):
+        self.logits = logits
+        self.probs = torch.softmax(logits, -1)
+        self.buckets = torch.linspace(low, high, steps=discrete_number).to(device)
+        self.width = (self.buckets[-1] - self.buckets[0]) / (discrete_number - 1)  # bucket spacing; was hardcoded /255
+        self.transfwd = transfwd
+        self.transbwd = transbwd
+
+    def mean(self):
+        _mean = self.probs * self.buckets
+        return self.transbwd(torch.sum(_mean, dim=-1, keepdim=True))
+
+    def mode(self):
+        # As in DreamerV3's DiscDist, "mode" is the bucket-weighted expectation
+        _mode = self.probs * self.buckets
+        return self.transbwd(torch.sum(_mode, dim=-1, keepdim=True))
+
+    # Inside OneHotCategorical, log_prob is calculated using only max element in targets
+    def log_prob(self, x):
+        x = self.transfwd(x)
+        # x(time, batch, 1)
+        below = (torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1).to(x.device)
+        above = len(self.buckets) - torch.sum(
+            (self.buckets > x[..., None]).to(torch.int32), dim=-1
+        ).to(x.device)
+        below = torch.clip(below, 0, len(self.buckets) - 1)
+        above = torch.clip(above, 0, len(self.buckets) - 1)
+        equal = below == above
+
+        dist_to_below = torch.where(equal, torch.ones(1).to(x.device), torch.abs(self.buckets[below] - x)).to(x.device)
+        dist_to_above = torch.where(equal, torch.ones(1).to(x.device), torch.abs(self.buckets[above] - x)).to(x.device)
+        total = dist_to_below + dist_to_above
+        weight_below = dist_to_above / total
+        weight_above = dist_to_below / total
+        target = (
+            F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None]
+            + F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None]
+        )
+        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
+        target = target.squeeze(-2)
+
+        return (target * log_pred).sum(-1)
+
+    def log_prob_target(self, target):
+        log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
+        return (target * log_pred).sum(-1)
diff --git a/examples/train_mujoco.py b/examples/train_mujoco.py
index 71a4ba7..84b43ef 100644
--- a/examples/train_mujoco.py
+++ b/examples/train_mujoco.py
@@ -2,7 +2,7 @@
 Training script for Diffusion Active Inference on MuJoCo with GPU-Optimized Parallel Data Collection
 Uses GPUCentralizedCollector for faster diffusion inference during collection
 """
-
+import os
 import torch
 import torch.multiprocessing as mp
 import gymnasium as gym
@@ -12,6 +12,7 @@
 from typing import Dict, Any
 import time
+
 from active_inference_diffusion.agents import DiffusionStateAgent, DiffusionPixelAgent
 from active_inference_diffusion.configs.config import (
     ActiveInferenceConfig,
@@ -31,13 +32,20 @@
 # Use GPU-optimized collector instead of regular parallel collector
 from active_inference_diffusion.utils.async_collector import GPUCentralizedCollector
 from active_inference_diffusion.utils.util import visualize_reconstruction
+import torch.profiler
+
 import os
+if torch.cuda.is_available():
+    torch.backends.cuda.enable_flash_sdp(False)
+    torch.backends.cuda.enable_mem_efficient_sdp(False)
+    torch.backends.cuda.enable_math_sdp(True)
 os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb:512'
 os.environ['CUDA_LAUNCH_BLOCKING'] = '0'
 os.environ['MUJOCO_GL'] = 'egl'
+
 def setup_environment(
     env_name: str,
     use_pixels: bool = False,
@@ -149,15 +157,15 @@ def train_diffusion_active_inference(
     # Create configurations
     config = ActiveInferenceConfig(
         env_name=env_name,
-        latent_dim=32,
-        hidden_dim=128,
-        learning_rate=5e-5,
-        batch_size=64,
+        latent_dim=28,
+        hidden_dim=24,
+        learning_rate=4e-5,
+        batch_size=60,
         efe_horizon=5,
-        epistemic_weight=0.1,
+        epistemic_weight=0.5,
         pragmatic_weight=1.0,
-        consistency_weight=0.1,
-        kl_weight=0.5,
+        consistency_weight=0.5,
+        kl_weight=0.75,
         diffusion_weight=1.0,
         pixel_observation=use_pixels,
         device=device
@@ -165,7 +173,7 @@ def train_diffusion_active_inference(
     # Enhanced diffusion config
     config.diffusion = DiffusionConfig(
-        num_diffusion_steps=25,  # Can be reduced to 10 for GPU collector
+        num_diffusion_steps=200,  # TODO: Adjust based on performance
         beta_schedule="cosine",
         beta_start=1e-4,
         beta_end=0.02
@@ -179,7 +187,7 @@ def train_diffusion_active_inference(
         log_frequency=1_000,
         buffer_size=buffer_size,
         learning_starts=5_000,
-        gradient_steps=2,
+        gradient_steps=4,
         exploration_noise=0.1,
         exploration_decay=0.999,
         num_parallel_envs=num_parallel_envs
@@ -326,30 +334,39 @@ def train_diffusion_active_inference(
         # Training phase
         if steps_collected > training_config.learning_starts:
-            training_start = time.time()
-
-            # Perform gradient updates
-            num_updates = int(training_config.gradient_steps * collection_stats['steps_collected'])
+            if hasattr(agent.replay_buffer, 'episodes') and len(agent.replay_buffer.episodes) > 0 and len(agent.replay_buffer) >= config.batch_size:
+                print(f"Training at step {steps_collected}...")
+                training_start = time.time()
+
+                # Perform gradient updates
+                num_updates = training_config.gradient_steps
+
+                train_metrics = {}
+                for _ in range(num_updates):
+                    metrics = agent.train_step()
+                    for k, v in metrics.items():
+                        if k not in train_metrics:
+                            train_metrics[k] = []
+                        train_metrics[k].append(v)
+
+                # Average training metrics
+                avg_train_metrics = {}
+                for k, v in train_metrics.items():
+                    if isinstance(v[0], torch.Tensor):
+                        # Handle torch tensors (move to CPU first)
+                        avg_train_metrics[k] = torch.stack(v).mean().cpu().item()
+                    else:
+                        # Handle regular numbers
+                        avg_train_metrics[k] = np.mean(v)
+
+                training_time = time.time() - training_start
+                avg_train_metrics['training/time'] = training_time
+                avg_train_metrics['training/updates_per_second'] = num_updates / training_time
+
+                # Log training metrics
+                if steps_collected % training_config.log_frequency < collection_steps:
+                    logger.log(avg_train_metrics, steps_collected)
-            train_metrics = {}
-            for _ in range(num_updates):
-                metrics = agent.train_step()
-                for k, v in metrics.items():
-                    if k not in train_metrics:
-                        train_metrics[k] = []
-                    train_metrics[k].append(v)
-
-            # Average training metrics
-            avg_train_metrics = {k: np.mean(v) for k, v in train_metrics.items()}
-
-            training_time = time.time() - training_start
-            avg_train_metrics['training/time'] = training_time
-            avg_train_metrics['training/updates_per_second'] = num_updates / training_time
-
-            # Log training metrics
-            if steps_collected % training_config.log_frequency < collection_steps:
-                logger.log(avg_train_metrics, steps_collected)
-
             # Update exploration noise
             agent.update_exploration()
@@ -358,10 +375,12 @@ def train_diffusion_active_inference(
            steps_collected % 5000 < collection_steps and len(agent.replay_buffer) > 0:
             sample_batch = agent.replay_buffer.sample(min(4, len(agent.replay_buffer)))
             sample_obs = sample_batch['observations']
+            sample_actions = sample_batch['actions']
+
             recon_error = visualize_reconstruction(
                 agent,
                 sample_obs,
-                f"plots/reconstruction_step_{steps_collected}.png"
+                actions=sample_actions,  # pass the stored actions so the belief update can condition on them
+                save_path=f"plots/reconstruction_step_{steps_collected}.png"
             )
             logger.log({'reconstruction_error': recon_error}, steps_collected)
@@ -448,7 +467,7 @@ def train_diffusion_active_inference(
     parser.add_argument('--pixels', action='store_true', help='Use pixel observations')
     parser.add_argument('--timesteps', type=int, default=1_000_000)
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=42)
     parser.add_argument('--device', type=str, default='cuda')
     parser.add_argument('--num_parallel_envs', type=int, default=3,
                         help='Number of parallel environments for data collection')
diff --git a/requirements.txt b/requirements.txt
index df2908b..b55ffeb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ pyyaml>=6.0
 wandb>=0.13.0
 scipy>=1.7.0
 lz4>=4.0.0
+imageio
diff --git a/run_active_inference_diffusion_MuJoCo.sh b/run_active_inference_diffusion_MuJoCo.sh
new file mode 100755
index 0000000..8aa5dd0
--- /dev/null
+++ b/run_active_inference_diffusion_MuJoCo.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+#SBATCH --job-name=ActInfDiff
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=18
+#SBATCH --gres=gpu:h100:1
+#SBATCH --mem=80G
+#SBATCH --time=1-12:59:00
+#SBATCH --account=def-irina
+#SBATCH --output=/home/memole/projects/def-irina/memole/logs/run_active-inference-diffusion-cheetah-seed-1_%N-%j.out
+#SBATCH --error=/home/memole/projects/def-irina/memole/logs/run_active-inference-diffusion-cheetah-seed-1_%N-%j.err
+#SBATCH --mail-user=sheikhbahaee@gmail.com   # notification for job conditions
+#SBATCH --mail-type=END
+#SBATCH --mail-type=FAIL
+
+module load StdEnv/2023
+module load gcc/12.3
+module load cuda/12.6
+module load python/3.11
+module load scipy-stack/2024a
+module load arrow/17.0.0
+module load mujoco
+module load openmpi
+module load mpi4py/3.1.6
+module load opencv/4.9.0
+module load imkl/2023.2.0
+module load rust/1.70.0
+module load cmake
+DIR=/home/memole/projects/def-irina/memole/active-inference-diffusion
+
+unset PYOPENGL_PLATFORM
+# Use EGL explicitly for headless rendering
+export PYOPENGL_PLATFORM=egl
+export MUJOCO_GL=egl
+export MUJOCO_EGL_DEVICE_ID=${CUDA_VISIBLE_DEVICES:-0}
+#virtualenv --no-download --clear /home/memole/ActInfDiffEnv
+source /home/memole/ActInfDiffEnv/bin/activate
+
+
+CURRENT_PATH=`pwd`
+echo "current path ---> $CURRENT_PATH"
+pip install --upgrade pip setuptools wheel
+#pip install --no-index --no-cache-dir numpy
+#pip install --no-index torch torchvision torchtext torchaudio
+#pip install --no-index --no-cache-dir wandb
+#pip install --no-cache-dir -r ~/projects/def-irina/memole/active-inference-diffusion/requirements.txt
+
+wandb login "$WANDB_API_KEY"  # export WANDB_API_KEY before submitting; never commit real keys
+
+CUDA_VISIBLE_DEVICES=0 WANDB_MODE=offline WANDB_START_METHOD=thread python -m examples.train_mujoco --env HalfCheetah-v4 --pixels --seed 42 --num_parallel_envs 15 --timesteps 2000000
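As a sanity check on `DiscDist.log_prob`, the two-hot target for a scalar can be worked out by hand. A small sketch with a readable bucket count (values illustrative; the target is assumed to be already in symlog space):

```python
import torch

# A value falling between two buckets has its probability mass split
# in proportion to the distance to each neighbour (two-hot encoding).
buckets = torch.linspace(-10.0, 10.0, steps=5)   # [-10, -5, 0, 5, 10]
x = torch.tensor(3.0)

below = (buckets <= x).sum() - 1                 # index 2 (bucket value 0)
above = below + 1                                # index 3 (bucket value 5)
w_above = (x - buckets[below]) / (buckets[above] - buckets[below])  # 0.6
w_below = 1.0 - w_above                                             # 0.4

target = torch.zeros(len(buckets))
target[below], target[above] = w_below, w_above
print(target)                    # tensor([0.0, 0.0, 0.4, 0.6, 0.0])

# The bucket-weighted expectation recovers the original value
print((target * buckets).sum())  # tensor(3.)
```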