README.md (11 changes: 5 additions & 6 deletions)
@@ -68,16 +68,15 @@ agent = PixelBasedAgent(env, config, training_config, pixel_config)

```bash
# Train with state observations
-python examples/train_state_mujoco.py --env HalfCheetah-v4
+python examples/train_mujoco.py --env HalfCheetah-v4

# Train with pixel observations
python examples/train_pixel_mujoco.py --env HalfCheetah-v4

# Resume from checkpoint
-python examples/train_state_mujoco.py --env HalfCheetah-v4 --resume
+python examples/train_mujoco.py --env HalfCheetah-v4 --pixels

# Use custom config
-python examples/train_pixel_mujoco.py --env Hopper-v4 --config examples/configs/hopper_pixel.yaml
+python examples/train_mujoco.py --env Hopper-v4 --pixels --config examples/configs/hopper_pixel.yaml


```
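The consolidated `examples/train_mujoco.py` itself is not part of this diff. As a rough, hypothetical sketch of how a single entry point could dispatch on the `--pixels` flag: the only call taken from the README is `PixelBasedAgent(env, config, training_config, pixel_config)`; the argument names, the `build_agent` helper, and the state-based branch are assumptions, not the actual script.

```python
# Hypothetical sketch only -- not the actual examples/train_mujoco.py from this PR.
import argparse

import gymnasium as gym


def build_agent(env: gym.Env, use_pixels: bool) -> str:
    """Placeholder for agent construction.

    The README only shows `PixelBasedAgent(env, config, training_config, pixel_config)`;
    the real script would build the config objects and instantiate the matching agent here.
    """
    return "PixelBasedAgent" if use_pixels else "StateBasedAgent (assumed name)"


def main() -> None:
    parser = argparse.ArgumentParser(description="Train an agent on a MuJoCo task")
    parser.add_argument("--env", default="HalfCheetah-v4")
    parser.add_argument("--pixels", action="store_true",
                        help="train from pixel observations instead of state vectors")
    parser.add_argument("--config", default=None, help="optional YAML config override")
    parser.add_argument("--resume", action="store_true", help="resume from a checkpoint")
    args = parser.parse_args()

    # Pixel training needs rendered frames; state training can use the default observation.
    env = gym.make(args.env, render_mode="rgb_array" if args.pixels else None)
    agent = build_agent(env, args.pixels)
    print(f"Would train {agent} on {args.env} (config={args.config}, resume={args.resume})")


if __name__ == "__main__":
    main()
```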

### Configuration Files
active_inference_diffusion/agents/base_agent.py (36 changes: 3 additions & 33 deletions)
@@ -87,7 +87,7 @@ def __init__(
        self.total_steps = 0
        self.episode_count = 0
        self.exploration_noise = training_config.exploration_noise
-        self.reward_normalizer = RunningMeanStd(shape=())


    @abstractmethod
    def _setup_dimensions(self):
@@ -106,38 +106,8 @@ def _create_replay_buffer(self) -> ReplayBuffer:

    def _setup_optimizers(self):
        """Setup optimizers"""
-        # Score network optimizer
-        self.score_optimizer = torch.optim.Adam(
-            self.active_inference.score_network.parameters(),
-            lr=self.config.learning_rate
-        )
-
-        # Policy optimizer
-        self.policy_optimizer = torch.optim.Adam(
-            self.active_inference.policy_network.parameters(),
-            lr=self.config.learning_rate
-        )
-
-        # Value optimizer
-        self.value_optimizer = torch.optim.Adam(
-            self.active_inference.value_network.parameters(),
-            lr=self.config.learning_rate
-        )
-
-        # Dynamics optimizer
-        self.dynamics_optimizer = torch.optim.Adam(
-            list(self.active_inference.dynamics_model.parameters()) +
-            list(self.active_inference.reward_predictor.parameters()),
-            lr=self.config.learning_rate
-        )
-
-        # Add epistemic optimizer
-        self.epistemic_optimizer = torch.optim.Adam(
-            self.active_inference.epistemic_estimator.parameters(),
-            lr=self.config.learning_rate*0.1,
-            weight_decay=1e-5
-        )
-        self.active_inference.epistemic_optimizer = self.epistemic_optimizer
-
+        pass

    def act(
        self,
        observation: np.ndarray,