Changes from all commits · 37 commits
d0d0cff
Remove SB3 as a necessary module
pseudo-rnd-thoughts Mar 19, 2025
7865501
Remove additional sb3 checks
pseudo-rnd-thoughts Mar 19, 2025
bbc117f
Update poetry.lock and requirements + add `from __future__ import ann…
pseudo-rnd-thoughts Mar 19, 2025
6587ee6
Add `get_obs_shape` function
pseudo-rnd-thoughts Mar 19, 2025
d3da8ff
Remove `stable-baselines3` from test installs
pseudo-rnd-thoughts Mar 19, 2025
11f70c5
Add copyright notices
pseudo-rnd-thoughts Mar 25, 2025
61739bd
Update pre-commit
pseudo-rnd-thoughts Jul 4, 2025
7356872
Update tests
pseudo-rnd-thoughts Jul 4, 2025
11782c3
Update docs and tests
pseudo-rnd-thoughts Jul 4, 2025
e03aac1
Merge branch 'master' into replace-poetry-with-uv
pseudo-rnd-thoughts Jul 4, 2025
cd90426
Fix pre-commit and workflow
pseudo-rnd-thoughts Jul 4, 2025
4346403
Update workflow action versions
pseudo-rnd-thoughts Jul 4, 2025
a88614c
Add typing_extensions>=4.6.0 for optuna
pseudo-rnd-thoughts Jul 4, 2025
e2680d8
Increase mujoco bound
pseudo-rnd-thoughts Jul 4, 2025
26b95c4
Merge branch 'replace-poetry-with-uv' into remove-sb3
pseudo-rnd-thoughts Jul 4, 2025
7ca4c96
Remove sb3 from requirements
pseudo-rnd-thoughts Jul 4, 2025
cf6554d
Update to Gymnasium v1.0
pseudo-rnd-thoughts Jul 4, 2025
6e5693d
Revert mujoco<=2.3.3
pseudo-rnd-thoughts Jul 4, 2025
bf811e0
Add `export MUJOCO_GL=egl` as environment variable
pseudo-rnd-thoughts Jul 4, 2025
7d84668
Use MUJOCO_GL egl in mujoco tests
pseudo-rnd-thoughts Jul 4, 2025
2d490b9
Split up the `test_mujoco.py` file
pseudo-rnd-thoughts Jul 5, 2025
6aa2982
Specify the `DISPLAY: :0`
pseudo-rnd-thoughts Jul 5, 2025
9026018
register ale-py envs
pseudo-rnd-thoughts Jul 5, 2025
7350f6e
Add xvfb screen
pseudo-rnd-thoughts Jul 5, 2025
c2f7c6d
Merge branch 'replace-poetry-with-uv' into remove-sb3
pseudo-rnd-thoughts Jul 5, 2025
800be9b
Update lock
pseudo-rnd-thoughts Jul 5, 2025
4339766
Update pytorch version to 2.4.1 as 2.5.1 doesn't have py38 wheels
pseudo-rnd-thoughts Jul 5, 2025
ba42e1c
Add optax and chex to requirements
pseudo-rnd-thoughts Jul 5, 2025
b08a2e7
Merge branch 'remove-sb3' into gymnasium-v1.0
pseudo-rnd-thoughts Jul 5, 2025
274f9d8
Update requirements
pseudo-rnd-thoughts Jul 5, 2025
5dca27d
Update to new `RecordEpisodeInfos`
pseudo-rnd-thoughts Jul 5, 2025
a84ea18
Update to new `RecordEpisodeInfos`
pseudo-rnd-thoughts Jul 5, 2025
56941e2
Fix tests
pseudo-rnd-thoughts Jul 5, 2025
016ca33
Fix PettingZoo, mujoco and envpool tests
pseudo-rnd-thoughts Jul 5, 2025
7c2660b
Fix saving episode info for atari implementations
pseudo-rnd-thoughts Jul 5, 2025
40af2d0
Fix saving episode info for atari implementations for eval
pseudo-rnd-thoughts Jul 5, 2025
21e6c7b
Merge branch 'master-upstream' into gymnasium-v1.0
pseudo-rnd-thoughts Jul 8, 2025
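Several of the commits above deal with headless MuJoCo rendering in CI (`Add export MUJOCO_GL=egl as environment variable`, `Use MUJOCO_GL egl in mujoco tests`, the xvfb screen and `DISPLAY: :0`). A minimal sketch of what that setting amounts to on the Python side, assuming the `mujoco` extra and an EGL driver are installed; the env id below is illustrative and the actual workflow wiring is not shown in this diff:

```python
# Hedged sketch of headless MuJoCo rendering; not the PR's workflow code.
import os

os.environ.setdefault("MUJOCO_GL", "egl")  # must be set before MuJoCo is loaded

import gymnasium as gym  # noqa: E402  (import after the env var is set)

env = gym.make("Hopper-v5", render_mode="rgb_array")  # any MuJoCo env id works here
env.reset(seed=0)
frame = env.render()  # RGB array rendered offscreen via EGL, no X display needed
env.close()
```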
18 changes: 7 additions & 11 deletions .github/workflows/tests.yaml
@@ -10,7 +10,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-22.04]
runs-on: ${{ matrix.os }}
steps:
@@ -47,7 +47,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-22.04]
runs-on: ${{ matrix.os }}
steps:
@@ -67,8 +67,6 @@ jobs:
- name: Install jax
if: runner.os == 'Linux' || runner.os == 'macOS'
run: uv pip install ".[pytest, atari, jax]"
- name: Run gymnasium migration dependencies
run: uv run pip install "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1"
- name: Run gymnasium tests
run: uv run pytest tests/test_atari_gymnasium.py
- name: Run gymnasium tests with jax
@@ -79,7 +77,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10"]
os: [ubuntu-22.04]
runs-on: ${{ matrix.os }}
steps:
@@ -94,16 +92,14 @@ jobs:
# procgen tests
- name: Install core dependencies
run: uv pip install ".[pytest, procgen]"
- name: Downgrade setuptools
run: uv run pip install setuptools==59.5.0
- name: Run procgen tests
run: uv run pytest tests/test_procgen.py

test-mujoco-envs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-22.04]
runs-on: ${{ matrix.os }}
steps:
@@ -136,7 +132,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10"]
os: [ubuntu-22.04]
runs-on: ${{ matrix.os }}
steps:
@@ -158,7 +154,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-22.04]
runs-on: ${{ matrix.os }}
steps:
@@ -180,7 +176,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-22.04]
runs-on: ${{ matrix.os }}
steps:
15 changes: 9 additions & 6 deletions cleanrl/c51.py
@@ -153,7 +153,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):

# env setup
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)],
autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

@@ -187,11 +188,13 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
for info in infos["final_info"]:
if info and "episode" in info:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
episodes_over = np.nonzero(infos["final_info"]["_episode"])[0]
episodic_returns = infos["final_info"]["episode"]["r"][episodes_over]
episodic_lengths = infos["final_info"]["episode"]["l"][episodes_over]
for episodic_return, episodic_length in zip(episodic_returns, episodic_lengths):
print(f"global_step={global_step}, episodic_return={episodic_return}")
writer.add_scalar("charts/episodic_return", episodic_return, global_step)
writer.add_scalar("charts/episodic_length", episodic_length, global_step)

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
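The episode-logging change above recurs throughout this PR: with Gymnasium v1.0's vector API and `AutoresetMode.SAME_STEP`, `infos["final_info"]` is a dict of arrays (one slot per sub-environment) rather than a list of per-env dicts, and the boolean `_episode` mask marks which sub-environments finished on this step. A minimal sketch of reading it, assuming each env is wrapped with `RecordEpisodeStatistics` as in CleanRL's `make_env`; the helper name is ours, not the PR's:

```python
import numpy as np

def log_finished_episodes(infos, global_step, writer):
    """Sketch: pull episodic returns/lengths for sub-envs that ended this step."""
    if "final_info" not in infos or "episode" not in infos["final_info"]:
        return
    final_info = infos["final_info"]
    episodes_over = np.nonzero(final_info["_episode"])[0]  # indices of finished envs
    returns = final_info["episode"]["r"][episodes_over]
    lengths = final_info["episode"]["l"][episodes_over]
    for episodic_return, episodic_length in zip(returns, lengths):
        print(f"global_step={global_step}, episodic_return={episodic_return}")
        writer.add_scalar("charts/episodic_return", episodic_return, global_step)
        writer.add_scalar("charts/episodic_length", episodic_length, global_step)
```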
22 changes: 14 additions & 8 deletions cleanrl/c51_atari.py
@@ -4,6 +4,7 @@
import time
from dataclasses import dataclass

import ale_py
import gymnasium as gym
import numpy as np
import torch
@@ -21,6 +22,8 @@
)
from cleanrl_utils.buffers import ReplayBuffer

gym.register_envs(ale_py)


@dataclass
class Args:
@@ -98,8 +101,8 @@ def thunk():
env = FireResetEnv(env)
env = ClipRewardEnv(env)
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env = gym.wrappers.GrayscaleObservation(env)
env = gym.wrappers.FrameStackObservation(env, 4)

env.action_space.seed(seed)
return env
@@ -175,7 +178,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):

# env setup
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)],
autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

@@ -210,11 +214,13 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
for info in infos["final_info"]:
if info and "episode" in info:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
episodes_over = np.nonzero(infos["final_info"]["_episode"])[0]
episodic_returns = infos["final_info"]["episode"]["r"][episodes_over]
episodic_lengths = infos["final_info"]["episode"]["l"][episodes_over]
for episodic_return, episodic_length in zip(episodic_returns, episodic_lengths):
print(f"global_step={global_step}, episodic_return={episodic_return}")
writer.add_scalar("charts/episodic_return", episodic_return, global_step)
writer.add_scalar("charts/episodic_length", episodic_length, global_step)

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
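The Atari files now register ALE environments explicitly and use the Gymnasium v1.0 wrapper names (`GrayscaleObservation` and `FrameStackObservation` replace the old `GrayScaleObservation` and `FrameStack`). A minimal sketch of that setup, assuming `ale-py` is installed; the default env id is illustrative, not taken from this diff:

```python
import ale_py
import gymnasium as gym

gym.register_envs(ale_py)  # make ALE env ids visible to gym.make

def make_atari_env(env_id: str = "BreakoutNoFrameskip-v4"):  # illustrative env id
    env = gym.make(env_id)
    env = gym.wrappers.ResizeObservation(env, (84, 84))
    env = gym.wrappers.GrayscaleObservation(env)      # renamed in Gymnasium v1.0
    env = gym.wrappers.FrameStackObservation(env, 4)  # replaces FrameStack
    return env
```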
24 changes: 15 additions & 9 deletions cleanrl/c51_atari_jax.py
@@ -7,6 +7,7 @@
# see https://github.com/google/jax/discussions/6332#discussioncomment-1279991
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"

import ale_py
import flax
import flax.linen as nn
import gymnasium as gym
@@ -27,6 +28,8 @@
)
from cleanrl_utils.buffers import ReplayBuffer

gym.register_envs(ale_py)


@dataclass
class Args:
@@ -100,8 +103,8 @@ def thunk():
env = FireResetEnv(env)
env = ClipRewardEnv(env)
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env = gym.wrappers.GrayscaleObservation(env)
env = gym.wrappers.FrameStackObservation(env, 4)

env.action_space.seed(seed)
return env
@@ -173,7 +176,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):

# env setup
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)],
autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

@@ -268,12 +272,14 @@ def get_action(q_state, obs):
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
for info in infos["final_info"]:
if info and "episode" in info:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
if "final_info" in infos and "episode" in infos["final_info"]:
episodes_over = np.nonzero(infos["final_info"]["_episode"])[0]
episodic_returns = infos["final_info"]["episode"]["r"][episodes_over]
episodic_lengths = infos["final_info"]["episode"]["l"][episodes_over]
for episodic_return, episodic_length in zip(episodic_returns, episodic_lengths):
print(f"global_step={global_step}, episodic_return={episodic_return}")
writer.add_scalar("charts/episodic_return", episodic_return, global_step)
writer.add_scalar("charts/episodic_length", episodic_length, global_step)

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
15 changes: 9 additions & 6 deletions cleanrl/c51_jax.py
@@ -146,7 +146,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):

# env setup
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)],
autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

@@ -233,11 +234,13 @@ def loss(q_params, observations, actions, target_pmfs):

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
for info in infos["final_info"]:
if info and "episode" in info:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
episodes_over = np.nonzero(infos["final_info"]["_episode"])[0]
episodic_returns = infos["final_info"]["episode"]["r"][episodes_over]
episodic_lengths = infos["final_info"]["episode"]["l"][episodes_over]
for episodic_return, episodic_length in zip(episodic_returns, episodic_lengths):
print(f"global_step={global_step}, episodic_return={episodic_return}")
writer.add_scalar("charts/episodic_return", episodic_return, global_step)
writer.add_scalar("charts/episodic_length", episodic_length, global_step)

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
16 changes: 10 additions & 6 deletions cleanrl/ddpg_continuous_action.py
@@ -145,7 +145,9 @@ def forward(self, x):
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed, 0, args.capture_video, run_name)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP
)
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

actor = Actor(envs).to(device)
@@ -184,11 +186,13 @@ def forward(self, x):

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
for info in infos["final_info"]:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break
episodes_over = np.nonzero(infos["final_info"]["_episode"])[0]
episodic_returns = infos["final_info"]["episode"]["r"][episodes_over]
episodic_lengths = infos["final_info"]["episode"]["l"][episodes_over]
for episodic_return, episodic_length in zip(episodic_returns, episodic_lengths):
print(f"global_step={global_step}, episodic_return={episodic_return}")
writer.add_scalar("charts/episodic_return", episodic_return, global_step)
writer.add_scalar("charts/episodic_length", episodic_length, global_step)

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
16 changes: 10 additions & 6 deletions cleanrl/ddpg_continuous_action_jax.py
@@ -138,7 +138,9 @@ class TrainState(TrainState):
key, actor_key, qf1_key = jax.random.split(key, 3)

# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed, 0, args.capture_video, run_name)], autoreset_mode=gym.vector.AutoresetMode.SAME_STEP
)
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

max_action = float(envs.single_action_space.high[0])
@@ -238,11 +240,13 @@ def actor_loss(params):

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
for info in infos["final_info"]:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break
episodes_over = np.nonzero(infos["final_info"]["_episode"])[0]
episodic_returns = infos["final_info"]["episode"]["r"][episodes_over]
episodic_lengths = infos["final_info"]["episode"]["l"][episodes_over]
for episodic_return, episodic_length in zip(episodic_returns, episodic_lengths):
print(f"global_step={global_step}, episodic_return={episodic_return}")
writer.add_scalar("charts/episodic_return", episodic_return, global_step)
writer.add_scalar("charts/episodic_length", episodic_length, global_step)

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
15 changes: 9 additions & 6 deletions cleanrl/dqn.py
@@ -140,7 +140,8 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):

# env setup
envs = gym.vector.SyncVectorEnv(
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
[make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)],
autoreset_mode=gym.vector.AutoresetMode.SAME_STEP,
)
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

@@ -174,11 +175,13 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
for info in infos["final_info"]:
if info and "episode" in info:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
episodes_over = np.nonzero(infos["final_info"]["_episode"])[0]
episodic_returns = infos["final_info"]["episode"]["r"][episodes_over]
episodic_lengths = infos["final_info"]["episode"]["l"][episodes_over]
for episodic_return, episodic_length in zip(episodic_returns, episodic_lengths):
print(f"global_step={global_step}, episodic_return={episodic_return}")
writer.add_scalar("charts/episodic_return", episodic_return, global_step)
writer.add_scalar("charts/episodic_length", episodic_length, global_step)

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()