From 33df709732e40164dafd2b8afc7636491284633b Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 6 Sep 2022 23:11:43 +0200 Subject: [PATCH 01/31] Clone file --- cleanrl/td3_droq_continuous_action_jax.py | 322 ++++++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 cleanrl/td3_droq_continuous_action_jax.py diff --git a/cleanrl/td3_droq_continuous_action_jax.py b/cleanrl/td3_droq_continuous_action_jax.py new file mode 100644 index 000000000..7b96897cd --- /dev/null +++ b/cleanrl/td3_droq_continuous_action_jax.py @@ -0,0 +1,322 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_action_jaxpy +import argparse +import os +import random +import time +from distutils.util import strtobool +from typing import Sequence + +import flax +import flax.linen as nn +import gym +import jax +import jax.numpy as jnp +import numpy as np +import optax +import pybullet_envs # noqa +from flax.training.train_state import TrainState +from stable_baselines3.common.buffers import ReplayBuffer +from torch.utils.tensorboard import SummaryWriter + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="cleanRL", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="weather to capture videos of the agent performances (check out `videos` folder)") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="HalfCheetah-v2", + help="the id of the environment") + parser.add_argument("--total-timesteps", type=int, default=1000000, + help="total timesteps of the experiments") + parser.add_argument("--learning-rate", type=float, default=3e-4, + help="the learning rate of the optimizer") + parser.add_argument("--buffer-size", type=int, default=int(1e6), + help="the replay memory buffer size") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--tau", type=float, default=0.005, + help="target smoothing coefficient (default: 0.005)") + parser.add_argument("--policy-noise", type=float, default=0.2, + help="the scale of policy noise") + parser.add_argument("--batch-size", type=int, default=256, + help="the batch size of sample from the reply memory") + parser.add_argument("--exploration-noise", type=float, default=0.1, + help="the scale of exploration noise") + parser.add_argument("--learning-starts", type=int, default=25e3, + help="timestep to start learning") + parser.add_argument("--policy-frequency", type=int, default=2, + help="the frequency of training policy (delayed)") + parser.add_argument("--noise-clip", type=float, default=0.5, + help="noise clip parameter of the Target Policy Smoothing Regularization") + args = parser.parse_args() + # fmt: on + return args + + +def make_env(env_id, seed, idx, capture_video, run_name): + def thunk(): + env = gym.make(env_id) + 
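# [Editor's note, annotation only, not part of this patch] make_env returns a thunk so that the
# vectorized env wrapper can construct each environment lazily. The RecordEpisodeStatistics wrapper
# applied just below is what populates info["episode"]["r"] / ["l"], which the training loop reads
# for logging, and RecordVideo records rollouts of the first env when --capture-video is set.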
env = gym.wrappers.RecordEpisodeStatistics(env) + if capture_video: + if idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + return thunk + + +# ALGO LOGIC: initialize agent here: +class QNetwork(nn.Module): + @nn.compact + def __call__(self, x: jnp.ndarray, a: jnp.ndarray): + x = jnp.concatenate([x, a], -1) + x = nn.Dense(256)(x) + x = nn.relu(x) + x = nn.Dense(256)(x) + x = nn.relu(x) + x = nn.Dense(1)(x) + return x + + +class Actor(nn.Module): + action_dim: Sequence[int] + action_scale: Sequence[int] + action_bias: Sequence[int] + + @nn.compact + def __call__(self, x): + x = nn.Dense(256)(x) + x = nn.relu(x) + x = nn.Dense(256)(x) + x = nn.relu(x) + x = nn.Dense(self.action_dim)(x) + x = nn.tanh(x) + x = x * self.action_scale + self.action_bias + return x + + +class TrainState(TrainState): + target_params: flax.core.FrozenDict + + +if __name__ == "__main__": + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + key = jax.random.PRNGKey(args.seed) + key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4) + + # env setup + envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) + assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" + + max_action = float(envs.single_action_space.high[0]) + envs.single_observation_space.dtype = np.float32 + rb = ReplayBuffer( + args.buffer_size, + envs.single_observation_space, + envs.single_action_space, + device="cpu", + handle_timeout_termination=True, + ) + + # TRY NOT TO MODIFY: start the game + obs = envs.reset() + actor = Actor( + action_dim=np.prod(envs.single_action_space.shape), + action_scale=jnp.array((envs.action_space.high - envs.action_space.low) / 2.0), + action_bias=jnp.array((envs.action_space.high + envs.action_space.low) / 2.0), + ) + actor_state = TrainState.create( + apply_fn=actor.apply, + params=actor.init(actor_key, obs), + target_params=actor.init(actor_key, obs), + tx=optax.adam(learning_rate=args.learning_rate), + ) + qf = QNetwork() + qf1_state = TrainState.create( + apply_fn=qf.apply, + params=qf.init(qf1_key, obs, envs.action_space.sample()), + target_params=qf.init(qf1_key, obs, envs.action_space.sample()), + tx=optax.adam(learning_rate=args.learning_rate), + ) + qf2_state = TrainState.create( + apply_fn=qf.apply, + params=qf.init(qf2_key, obs, envs.action_space.sample()), + target_params=qf.init(qf2_key, obs, envs.action_space.sample()), + tx=optax.adam(learning_rate=args.learning_rate), + ) + actor.apply = jax.jit(actor.apply) + qf.apply = jax.jit(qf.apply) + + @jax.jit + def update_critic( + actor_state: TrainState, + qf1_state: TrainState, + qf2_state: TrainState, + observations: np.ndarray, + actions: np.ndarray, + next_observations: np.ndarray, + rewards: np.ndarray, + dones: np.ndarray, + key: jnp.ndarray, + ): + # TODO Maybe pre-generate a lot of random keys + # also 
check https://jax.readthedocs.io/en/latest/jax.random.html + key, noise_key = jax.random.split(key, 2) + clipped_noise = jnp.clip( + (jax.random.normal(noise_key, actions[0].shape) * args.policy_noise), + -args.noise_clip, + args.noise_clip, + ) + next_state_actions = jnp.clip( + actor.apply(actor_state.target_params, next_observations) + clipped_noise, + envs.single_action_space.low[0], + envs.single_action_space.high[0], + ) + qf1_next_target = qf.apply(qf1_state.target_params, next_observations, next_state_actions).reshape(-1) + qf2_next_target = qf.apply(qf2_state.target_params, next_observations, next_state_actions).reshape(-1) + min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target) + next_q_value = (rewards + (1 - dones) * args.gamma * (min_qf_next_target)).reshape(-1) + + def mse_loss(params): + qf_a_values = qf.apply(params, observations, actions).squeeze() + return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean() + + (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params) + (qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params) + qf1_state = qf1_state.apply_gradients(grads=grads1) + qf2_state = qf2_state.apply_gradients(grads=grads2) + + return (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key + + @jax.jit + def update_actor( + actor_state: TrainState, + qf1_state: TrainState, + qf2_state: TrainState, + observations: np.ndarray, + ): + def actor_loss(params): + return -qf.apply(qf1_state.params, observations, actor.apply(params, observations)).mean() + + actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params) + actor_state = actor_state.apply_gradients(grads=grads) + actor_state = actor_state.replace( + target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau) + ) + + qf1_state = qf1_state.replace( + target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau) + ) + qf2_state = qf2_state.replace( + target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau) + ) + return actor_state, (qf1_state, qf2_state), actor_loss_value + + start_time = time.time() + for global_step in range(args.total_timesteps): + # ALGO LOGIC: put action logic here + if global_step < args.learning_starts: + actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) + else: + actions = actor.apply(actor_state.params, obs) + actions = np.array( + [ + ( + jax.device_get(actions)[0] + + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape[0]) + ).clip(envs.single_action_space.low, envs.single_action_space.high) + ] + ) + + # TRY NOT TO MODIFY: execute the game and log data. 
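# [Editor's note, annotation only, not part of this patch] Action selection above follows TD3:
# for the first --learning-starts steps actions are sampled uniformly from the action space;
# afterwards the deterministic actor output is perturbed with Gaussian noise of scale
# max_action * exploration_noise and clipped back into the valid action bounds before env.step().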
+ next_obs, rewards, dones, infos = envs.step(actions) + + # TRY NOT TO MODIFY: record rewards for plotting purposes + for info in infos: + if "episode" in info.keys(): + print(f"global_step={global_step}, episodic_return={info['episode']['r']}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + break + + # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation` + real_next_obs = next_obs.copy() + for idx, d in enumerate(dones): + if d: + real_next_obs[idx] = infos[idx]["terminal_observation"] + rb.add(obs, real_next_obs, actions, rewards, dones, infos) + + # TRY NOT TO MODIFY: CRUCIAL step easy to overlook + obs = next_obs + + # ALGO LOGIC: training. + if global_step > args.learning_starts: + data = rb.sample(args.batch_size) + + (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key = update_critic( + actor_state, + qf1_state, + qf2_state, + data.observations.numpy(), + data.actions.numpy(), + data.next_observations.numpy(), + data.rewards.flatten().numpy(), + data.dones.flatten().numpy(), + key, + ) + + if global_step % args.policy_frequency == 0: + actor_state, (qf1_state, qf2_state), actor_loss_value = update_actor( + actor_state, + qf1_state, + qf2_state, + data.observations.numpy(), + ) + if global_step % 100 == 0: + writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step) + writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step) + writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step) + writer.add_scalar("losses/qf2_values", qf2_a_values.item(), global_step) + writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step) + print("SPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + + envs.close() + writer.close() From 6ed7655e5c2ebe1943b11263fab85fe6166c3a42 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 6 Sep 2022 23:47:10 +0200 Subject: [PATCH 02/31] Fixes and reformating --- cleanrl/td3_droq_continuous_action_jax.py | 215 +++++++++++++++------- 1 file changed, 145 insertions(+), 70 deletions(-) diff --git a/cleanrl/td3_droq_continuous_action_jax.py b/cleanrl/td3_droq_continuous_action_jax.py index 7b96897cd..c00c77a34 100644 --- a/cleanrl/td3_droq_continuous_action_jax.py +++ b/cleanrl/td3_droq_continuous_action_jax.py @@ -17,6 +17,7 @@ from flax.training.train_state import TrainState from stable_baselines3.common.buffers import ReplayBuffer from torch.utils.tensorboard import SummaryWriter +from stable_baselines3.common.vec_env import DummyVecEnv def parse_args(): @@ -38,9 +39,9 @@ def parse_args(): # Algorithm specific arguments parser.add_argument("--env-id", type=str, default="HalfCheetah-v2", help="the id of the environment") - parser.add_argument("--total-timesteps", type=int, default=1000000, + parser.add_argument("-n", "--total-timesteps", type=int, default=1000000, help="total timesteps of the experiments") - parser.add_argument("--learning-rate", type=float, default=3e-4, + parser.add_argument("-lr", "--learning-rate", type=float, default=3e-4, help="the learning rate of the optimizer") parser.add_argument("--buffer-size", type=int, default=int(1e6), help="the replay memory buffer size") @@ -54,8 +55,10 @@ def parse_args(): help="the batch size of sample from the reply memory") parser.add_argument("--exploration-noise", 
type=float, default=0.1, help="the scale of exploration noise") - parser.add_argument("--learning-starts", type=int, default=25e3, + parser.add_argument("--learning-starts", type=int, default=1000, help="timestep to start learning") + parser.add_argument("--gradient-steps", type=int, default=1, + help="Number of gradient steps to perform after each rollout") parser.add_argument("--policy-frequency", type=int, default=2, help="the frequency of training policy (delayed)") parser.add_argument("--noise-clip", type=float, default=0.5, @@ -95,8 +98,8 @@ def __call__(self, x: jnp.ndarray, a: jnp.ndarray): class Actor(nn.Module): action_dim: Sequence[int] - action_scale: Sequence[int] - action_bias: Sequence[int] + action_scale: float + action_bias: float @nn.compact def __call__(self, x): @@ -110,11 +113,11 @@ def __call__(self, x): return x -class TrainState(TrainState): +class RLTrainState(TrainState): target_params: flax.core.FrozenDict -if __name__ == "__main__": +def main(): args = parse_args() run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" if args.track: @@ -132,7 +135,8 @@ class TrainState(TrainState): writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + "|param|value|\n|-|-|\n%s" + % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), ) # TRY NOT TO MODIFY: seeding @@ -142,15 +146,21 @@ class TrainState(TrainState): key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4) # env setup - envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) - assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" - - max_action = float(envs.single_action_space.high[0]) - envs.single_observation_space.dtype = np.float32 + envs = DummyVecEnv( + [make_env(args.env_id, args.seed, 0, args.capture_video, run_name)] + ) + assert isinstance( + envs.action_space, gym.spaces.Box + ), "only continuous action space is supported" + + # Assume that all dimensions share the same bound + min_action = float(envs.action_space.low[0]) + max_action = float(envs.action_space.high[0]) + envs.observation_space.dtype = np.float32 rb = ReplayBuffer( args.buffer_size, - envs.single_observation_space, - envs.single_action_space, + envs.observation_space, + envs.action_space, device="cpu", handle_timeout_termination=True, ) @@ -158,27 +168,27 @@ class TrainState(TrainState): # TRY NOT TO MODIFY: start the game obs = envs.reset() actor = Actor( - action_dim=np.prod(envs.single_action_space.shape), - action_scale=jnp.array((envs.action_space.high - envs.action_space.low) / 2.0), - action_bias=jnp.array((envs.action_space.high + envs.action_space.low) / 2.0), + action_dim=np.prod(envs.action_space.shape), + action_scale=(max_action - min_action) / 2.0, + action_bias=(max_action + min_action) / 2.0, ) - actor_state = TrainState.create( + actor_state = RLTrainState.create( apply_fn=actor.apply, params=actor.init(actor_key, obs), target_params=actor.init(actor_key, obs), tx=optax.adam(learning_rate=args.learning_rate), ) qf = QNetwork() - qf1_state = TrainState.create( + qf1_state = RLTrainState.create( apply_fn=qf.apply, - params=qf.init(qf1_key, obs, envs.action_space.sample()), - target_params=qf.init(qf1_key, obs, envs.action_space.sample()), + params=qf.init(qf1_key, obs, jnp.array([envs.action_space.sample()])), + target_params=qf.init(qf1_key, obs, 
jnp.array([envs.action_space.sample()])), tx=optax.adam(learning_rate=args.learning_rate), ) - qf2_state = TrainState.create( + qf2_state = RLTrainState.create( apply_fn=qf.apply, - params=qf.init(qf2_key, obs, envs.action_space.sample()), - target_params=qf.init(qf2_key, obs, envs.action_space.sample()), + params=qf.init(qf2_key, obs, jnp.array([envs.action_space.sample()])), + target_params=qf.init(qf2_key, obs, jnp.array([envs.action_space.sample()])), tx=optax.adam(learning_rate=args.learning_rate), ) actor.apply = jax.jit(actor.apply) @@ -186,9 +196,9 @@ class TrainState(TrainState): @jax.jit def update_critic( - actor_state: TrainState, - qf1_state: TrainState, - qf2_state: TrainState, + actor_state: RLTrainState, + qf1_state: RLTrainState, + qf2_state: RLTrainState, observations: np.ndarray, actions: np.ndarray, next_observations: np.ndarray, @@ -206,62 +216,92 @@ def update_critic( ) next_state_actions = jnp.clip( actor.apply(actor_state.target_params, next_observations) + clipped_noise, - envs.single_action_space.low[0], - envs.single_action_space.high[0], + envs.action_space.low[0], + envs.action_space.high[0], ) - qf1_next_target = qf.apply(qf1_state.target_params, next_observations, next_state_actions).reshape(-1) - qf2_next_target = qf.apply(qf2_state.target_params, next_observations, next_state_actions).reshape(-1) + qf1_next_target = qf.apply( + qf1_state.target_params, next_observations, next_state_actions + ).reshape(-1) + qf2_next_target = qf.apply( + qf2_state.target_params, next_observations, next_state_actions + ).reshape(-1) min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target) - next_q_value = (rewards + (1 - dones) * args.gamma * (min_qf_next_target)).reshape(-1) + next_q_value = ( + rewards + (1 - dones) * args.gamma * (min_qf_next_target) + ).reshape(-1) def mse_loss(params): qf_a_values = qf.apply(params, observations, actions).squeeze() return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean() - (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params) - (qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params) + (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad( + mse_loss, has_aux=True + )(qf1_state.params) + (qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad( + mse_loss, has_aux=True + )(qf2_state.params) qf1_state = qf1_state.apply_gradients(grads=grads1) qf2_state = qf2_state.apply_gradients(grads=grads2) - return (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key + return ( + (qf1_state, qf2_state), + (qf1_loss_value, qf2_loss_value), + (qf1_a_values, qf2_a_values), + key, + ) @jax.jit def update_actor( - actor_state: TrainState, - qf1_state: TrainState, - qf2_state: TrainState, + actor_state: RLTrainState, + qf1_state: RLTrainState, + qf2_state: RLTrainState, observations: np.ndarray, ): def actor_loss(params): - return -qf.apply(qf1_state.params, observations, actor.apply(params, observations)).mean() + return -qf.apply( + qf1_state.params, observations, actor.apply(params, observations) + ).mean() actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params) actor_state = actor_state.apply_gradients(grads=grads) actor_state = actor_state.replace( - target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau) + target_params=optax.incremental_update( + actor_state.params, actor_state.target_params, args.tau + ) ) qf1_state = 
qf1_state.replace( - target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau) + target_params=optax.incremental_update( + qf1_state.params, qf1_state.target_params, args.tau + ) ) qf2_state = qf2_state.replace( - target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau) + target_params=optax.incremental_update( + qf2_state.params, qf2_state.target_params, args.tau + ) ) return actor_state, (qf1_state, qf2_state), actor_loss_value start_time = time.time() + n_updates = 0 for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: - actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)]) + actions = np.array( + [envs.action_space.sample() for _ in range(envs.num_envs)] + ) else: actions = actor.apply(actor_state.params, obs) actions = np.array( [ ( jax.device_get(actions)[0] - + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape[0]) - ).clip(envs.single_action_space.low, envs.single_action_space.high) + + np.random.normal( + 0, + max_action * args.exploration_noise, + size=envs.action_space.shape[0], + ) + ).clip(envs.action_space.low, envs.action_space.high) ] ) @@ -271,16 +311,26 @@ def actor_loss(params): # TRY NOT TO MODIFY: record rewards for plotting purposes for info in infos: if "episode" in info.keys(): - print(f"global_step={global_step}, episodic_return={info['episode']['r']}") - writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) - writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + print( + f"global_step={global_step + 1}, episodic_return={info['episode']['r']:.2f}" + ) + writer.add_scalar( + "charts/episodic_return", info["episode"]["r"], global_step + ) + writer.add_scalar( + "charts/episodic_length", info["episode"]["l"], global_step + ) break # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation` real_next_obs = next_obs.copy() - for idx, d in enumerate(dones): - if d: + for idx, done in enumerate(dones): + if done: real_next_obs[idx] = infos[idx]["terminal_observation"] + # Timeout handling done inside the replay buffer + # if infos[idx].get("TimeLimit.truncated", False) == True: + # real_dones[idx] = False + rb.add(obs, real_next_obs, actions, rewards, dones, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook @@ -288,35 +338,60 @@ def actor_loss(params): # ALGO LOGIC: training. 
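# [Editor's note, annotation only, not part of this patch] rb.sample() below returns a
# stable-baselines3 ReplayBufferSamples namedtuple of torch tensors, which is why each field is
# converted with .numpy() before being passed to the jitted JAX update functions. With
# handle_timeout_termination=True the buffer also masks out terminations caused by TimeLimit
# truncation when it builds the sampled dones, which is what the "Timeout handling done inside
# the replay buffer" comment above refers to.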
if global_step > args.learning_starts: - data = rb.sample(args.batch_size) - - (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key = update_critic( - actor_state, - qf1_state, - qf2_state, - data.observations.numpy(), - data.actions.numpy(), - data.next_observations.numpy(), - data.rewards.flatten().numpy(), - data.dones.flatten().numpy(), - key, - ) - - if global_step % args.policy_frequency == 0: - actor_state, (qf1_state, qf2_state), actor_loss_value = update_actor( + for _ in range(args.gradient_steps): + n_updates += 1 + data = rb.sample(args.batch_size) + + ( + (qf1_state, qf2_state), + (qf1_loss_value, qf2_loss_value), + (qf1_a_values, qf2_a_values), + key, + ) = update_critic( actor_state, qf1_state, qf2_state, data.observations.numpy(), + data.actions.numpy(), + data.next_observations.numpy(), + data.rewards.flatten().numpy(), + data.dones.flatten().numpy(), + key, ) + + if n_updates % args.policy_frequency == 0: + ( + actor_state, + (qf1_state, qf2_state), + actor_loss_value, + ) = update_actor( + actor_state, + qf1_state, + qf2_state, + data.observations.numpy(), + ) if global_step % 100 == 0: writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step) writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step) writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step) writer.add_scalar("losses/qf2_values", qf2_a_values.item(), global_step) - writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step) - print("SPS:", int(global_step / (time.time() - start_time))) - writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) + writer.add_scalar( + "losses/actor_loss", actor_loss_value.item(), global_step + ) + print("FPS:", int(global_step / (time.time() - start_time))) + writer.add_scalar( + "charts/SPS", + int(global_step / (time.time() - start_time)), + global_step, + ) envs.close() writer.close() + + +if __name__ == "__main__": + + try: + main() + except KeyboardInterrupt: + pass From 6a289cae5c363461424adf97ff7bfc6cc29d9166 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 7 Sep 2022 00:50:09 +0200 Subject: [PATCH 03/31] Add dropout and layernorm --- cleanrl/td3_droq_continuous_action_jax.py | 83 +++++++++++++++++++---- 1 file changed, 68 insertions(+), 15 deletions(-) diff --git a/cleanrl/td3_droq_continuous_action_jax.py b/cleanrl/td3_droq_continuous_action_jax.py index c00c77a34..deb7a4166 100644 --- a/cleanrl/td3_droq_continuous_action_jax.py +++ b/cleanrl/td3_droq_continuous_action_jax.py @@ -4,7 +4,7 @@ import random import time from distutils.util import strtobool -from typing import Sequence +from typing import Sequence, Optional import flax import flax.linen as nn @@ -58,7 +58,11 @@ def parse_args(): parser.add_argument("--learning-starts", type=int, default=1000, help="timestep to start learning") parser.add_argument("--gradient-steps", type=int, default=1, - help="Number of gradient steps to perform after each rollout") + help="Number of gradient steps to perform after each rollout") + # Argument for dropout rate + parser.add_argument("--dropout-rate", type=float, default=0.0) + # Argument for layer normalization + parser.add_argument("--layer-norm", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) parser.add_argument("--policy-frequency", type=int, default=2, help="the frequency of training policy (delayed)") parser.add_argument("--noise-clip", type=float, default=0.5, @@ -85,12 +89,23 @@ def thunk(): # 
ALGO LOGIC: initialize agent here: class QNetwork(nn.Module): + use_layer_norm: bool = False + dropout_rate: Optional[float] = None + @nn.compact - def __call__(self, x: jnp.ndarray, a: jnp.ndarray): + def __call__(self, x: jnp.ndarray, a: jnp.ndarray, training: bool = False): x = jnp.concatenate([x, a], -1) x = nn.Dense(256)(x) + if self.dropout_rate is not None and self.dropout_rate > 0: + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) + if self.use_layer_norm: + x = nn.LayerNorm()(x) x = nn.relu(x) x = nn.Dense(256)(x) + if self.dropout_rate is not None and self.dropout_rate > 0: + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) + if self.use_layer_norm: + x = nn.LayerNorm()(x) x = nn.relu(x) x = nn.Dense(1)(x) return x @@ -144,6 +159,7 @@ def main(): np.random.seed(args.seed) key = jax.random.PRNGKey(args.seed) key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4) + key, dropout_key1, dropout_key2 = jax.random.split(key, 3) # env setup envs = DummyVecEnv( @@ -178,21 +194,37 @@ def main(): target_params=actor.init(actor_key, obs), tx=optax.adam(learning_rate=args.learning_rate), ) - qf = QNetwork() + qf = QNetwork(dropout_rate=args.dropout_rate, use_layer_norm=args.layer_norm) qf1_state = RLTrainState.create( apply_fn=qf.apply, - params=qf.init(qf1_key, obs, jnp.array([envs.action_space.sample()])), - target_params=qf.init(qf1_key, obs, jnp.array([envs.action_space.sample()])), + params=qf.init( + {"params": qf1_key, "dropout": dropout_key1}, + obs, + jnp.array([envs.action_space.sample()]), + ), + target_params=qf.init( + {"params": qf1_key, "dropout": dropout_key1}, + obs, + jnp.array([envs.action_space.sample()]), + ), tx=optax.adam(learning_rate=args.learning_rate), ) qf2_state = RLTrainState.create( apply_fn=qf.apply, - params=qf.init(qf2_key, obs, jnp.array([envs.action_space.sample()])), - target_params=qf.init(qf2_key, obs, jnp.array([envs.action_space.sample()])), + params=qf.init( + {"params": qf2_key, "dropout": dropout_key2}, + obs, + jnp.array([envs.action_space.sample()]), + ), + target_params=qf.init( + {"params": qf2_key, "dropout": dropout_key2}, + obs, + jnp.array([envs.action_space.sample()]), + ), tx=optax.adam(learning_rate=args.learning_rate), ) actor.apply = jax.jit(actor.apply) - qf.apply = jax.jit(qf.apply) + qf.apply = jax.jit(qf.apply, static_argnames=("dropout_rate", "use_layer_norm")) @jax.jit def update_critic( @@ -208,7 +240,9 @@ def update_critic( ): # TODO Maybe pre-generate a lot of random keys # also check https://jax.readthedocs.io/en/latest/jax.random.html - key, noise_key = jax.random.split(key, 2) + key, noise_key, dropout_key_1, dropout_key_2 = jax.random.split(key, 4) + key, dropout_noise_key = jax.random.split(key, 2) + clipped_noise = jnp.clip( (jax.random.normal(noise_key, actions[0].shape) * args.policy_noise), -args.noise_clip, @@ -220,10 +254,18 @@ def update_critic( envs.action_space.high[0], ) qf1_next_target = qf.apply( - qf1_state.target_params, next_observations, next_state_actions + qf1_state.target_params, + next_observations, + next_state_actions, + True, + rngs={"dropout": dropout_key_1}, ).reshape(-1) qf2_next_target = qf.apply( - qf2_state.target_params, next_observations, next_state_actions + qf2_state.target_params, + next_observations, + next_state_actions, + True, + rngs={"dropout": dropout_key_2}, ).reshape(-1) min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target) next_q_value = ( @@ -231,7 +273,9 @@ def update_critic( ).reshape(-1) def mse_loss(params): - qf_a_values = 
qf.apply(params, observations, actions).squeeze() + qf_a_values = qf.apply( + params, observations, actions, True, rngs={"dropout": dropout_noise_key} + ).squeeze() return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean() (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad( @@ -256,10 +300,17 @@ def update_actor( qf1_state: RLTrainState, qf2_state: RLTrainState, observations: np.ndarray, + key: jnp.ndarray, ): + key, dropout_key = jax.random.split(key, 2) + def actor_loss(params): return -qf.apply( - qf1_state.params, observations, actor.apply(params, observations) + qf1_state.params, + observations, + actor.apply(params, observations), + True, + rngs={"dropout": dropout_key}, ).mean() actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params) @@ -280,7 +331,7 @@ def actor_loss(params): qf2_state.params, qf2_state.target_params, args.tau ) ) - return actor_state, (qf1_state, qf2_state), actor_loss_value + return actor_state, (qf1_state, qf2_state), actor_loss_value, key start_time = time.time() n_updates = 0 @@ -364,11 +415,13 @@ def actor_loss(params): actor_state, (qf1_state, qf2_state), actor_loss_value, + key, ) = update_actor( actor_state, qf1_state, qf2_state, data.observations.numpy(), + key, ) if global_step % 100 == 0: writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step) From d3ef56b44efedaafa4413666f779d172dadc4fe3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 16 Sep 2022 09:09:00 +0200 Subject: [PATCH 04/31] Add evaluation and tqdm progress bar --- cleanrl/td3_droq_continuous_action_jax.py | 50 +++++++++++++++++++---- 1 file changed, 42 insertions(+), 8 deletions(-) diff --git a/cleanrl/td3_droq_continuous_action_jax.py b/cleanrl/td3_droq_continuous_action_jax.py index deb7a4166..6b24517a9 100644 --- a/cleanrl/td3_droq_continuous_action_jax.py +++ b/cleanrl/td3_droq_continuous_action_jax.py @@ -4,7 +4,7 @@ import random import time from distutils.util import strtobool -from typing import Sequence, Optional +from typing import Optional, Sequence import flax import flax.linen as nn @@ -16,8 +16,9 @@ import pybullet_envs # noqa from flax.training.train_state import TrainState from stable_baselines3.common.buffers import ReplayBuffer -from torch.utils.tensorboard import SummaryWriter from stable_baselines3.common.vec_env import DummyVecEnv +from torch.utils.tensorboard import SummaryWriter +from tqdm.rich import tqdm def parse_args(): @@ -65,6 +66,11 @@ def parse_args(): parser.add_argument("--layer-norm", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) parser.add_argument("--policy-frequency", type=int, default=2, help="the frequency of training policy (delayed)") + parser.add_argument("--eval-freq", type=int, default=-1) + parser.add_argument("--n-eval-envs", type=int, default=1) + parser.add_argument("--n-eval-episodes", type=int, default=10) + parser.add_argument("--verbose", type=int, default=1) + parser.add_argument("--noise-clip", type=float, default=0.5, help="noise clip parameter of the Target Policy Smoothing Regularization") args = parser.parse_args() @@ -72,7 +78,7 @@ def parse_args(): return args -def make_env(env_id, seed, idx, capture_video, run_name): +def make_env(env_id, seed, idx, capture_video=False, run_name=""): def thunk(): env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) @@ -132,6 +138,20 @@ class RLTrainState(TrainState): target_params: flax.core.FrozenDict +def evaluate_policy(eval_env, actor, actor_state, n_eval_episodes: int = 10): + 
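    # [Editor's note, annotation only, not part of this patch] Runs n_eval_episodes full episodes
    # in the evaluation env using the deterministic actor output (no exploration noise) and
    # returns the mean and standard deviation of the episode returns. eval_env is expected to be
    # a single-environment VecEnv, so `done` below is a length-1 array whose truthiness ends each
    # episode loop.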
eval_episode_rewards = [] + for _ in range(n_eval_episodes): + obs = eval_env.reset() + done = False + episode_reward = 0 + while not done: + action = np.array(actor.apply(actor_state.params, obs)) + obs, reward, done, _ = eval_env.step(action) + episode_reward += reward + eval_episode_rewards.append(episode_reward) + return np.mean(eval_episode_rewards), np.std(eval_episode_rewards) + + def main(): args = parse_args() run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" @@ -165,6 +185,8 @@ def main(): envs = DummyVecEnv( [make_env(args.env_id, args.seed, 0, args.capture_video, run_name)] ) + eval_envs = DummyVecEnv([make_env(args.env_id, args.seed + 1, 0)]) + assert isinstance( envs.action_space, gym.spaces.Box ), "only continuous action space is supported" @@ -335,7 +357,7 @@ def actor_loss(params): start_time = time.time() n_updates = 0 - for global_step in range(args.total_timesteps): + for global_step in tqdm(range(args.total_timesteps)): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: actions = np.array( @@ -362,9 +384,10 @@ def actor_loss(params): # TRY NOT TO MODIFY: record rewards for plotting purposes for info in infos: if "episode" in info.keys(): - print( - f"global_step={global_step + 1}, episodic_return={info['episode']['r']:.2f}" - ) + if args.verbose >= 2: + print( + f"global_step={global_step + 1}, episodic_return={info['episode']['r']:.2f}" + ) writer.add_scalar( "charts/episodic_return", info["episode"]["r"], global_step ) @@ -423,6 +446,16 @@ def actor_loss(params): data.observations.numpy(), key, ) + + fps = int(global_step / (time.time() - start_time)) + if args.eval_freq > 0 and global_step % args.eval_freq == 0: + mean_reward, std_reward = evaluate_policy(eval_envs, actor, actor_state) + print( + f"global_step={global_step}, mean_eval_reward={mean_reward:.2f} +/- {std_reward:.2f} - {fps} fps" + ) + writer.add_scalar("charts/mean_eval_reward", mean_reward, global_step) + writer.add_scalar("charts/std_eval_reward", std_reward, global_step) + if global_step % 100 == 0: writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step) writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step) @@ -431,7 +464,8 @@ def actor_loss(params): writer.add_scalar( "losses/actor_loss", actor_loss_value.item(), global_step ) - print("FPS:", int(global_step / (time.time() - start_time))) + if args.verbose >= 2: + print("FPS:", fps) writer.add_scalar( "charts/SPS", int(global_step / (time.time() - start_time)), From 9704f1d17ededa4e65aba918544dbdc23bd9af24 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 16 Sep 2022 09:21:17 +0200 Subject: [PATCH 05/31] Different dropout keys --- cleanrl/td3_droq_continuous_action_jax.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cleanrl/td3_droq_continuous_action_jax.py b/cleanrl/td3_droq_continuous_action_jax.py index 6b24517a9..d56f14136 100644 --- a/cleanrl/td3_droq_continuous_action_jax.py +++ b/cleanrl/td3_droq_continuous_action_jax.py @@ -263,7 +263,7 @@ def update_critic( # TODO Maybe pre-generate a lot of random keys # also check https://jax.readthedocs.io/en/latest/jax.random.html key, noise_key, dropout_key_1, dropout_key_2 = jax.random.split(key, 4) - key, dropout_noise_key = jax.random.split(key, 2) + key, dropout_key_3, dropout_key_4 = jax.random.split(key, 3) clipped_noise = jnp.clip( (jax.random.normal(noise_key, actions[0].shape) * args.policy_noise), @@ -294,18 +294,18 @@ def update_critic( rewards + (1 - dones) * 
args.gamma * (min_qf_next_target) ).reshape(-1) - def mse_loss(params): + def mse_loss(params, noise_key): qf_a_values = qf.apply( - params, observations, actions, True, rngs={"dropout": dropout_noise_key} + params, observations, actions, True, rngs={"dropout": noise_key} ).squeeze() return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean() (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad( mse_loss, has_aux=True - )(qf1_state.params) + )(qf1_state.params, dropout_key_3) (qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad( mse_loss, has_aux=True - )(qf2_state.params) + )(qf2_state.params, dropout_key_4) qf1_state = qf1_state.apply_gradients(grads=grads1) qf2_state = qf2_state.apply_gradients(grads=grads2) From f0cc8ffccc8ad2deeb5ab20399ec90eee035587f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 16 Sep 2022 10:16:39 +0200 Subject: [PATCH 06/31] Separate q network target update --- cleanrl/td3_droq_continuous_action_jax.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/cleanrl/td3_droq_continuous_action_jax.py b/cleanrl/td3_droq_continuous_action_jax.py index d56f14136..21f028e54 100644 --- a/cleanrl/td3_droq_continuous_action_jax.py +++ b/cleanrl/td3_droq_continuous_action_jax.py @@ -343,6 +343,20 @@ def actor_loss(params): ) ) + # qf1_state = qf1_state.replace( + # target_params=optax.incremental_update( + # qf1_state.params, qf1_state.target_params, args.tau + # ) + # ) + # qf2_state = qf2_state.replace( + # target_params=optax.incremental_update( + # qf2_state.params, qf2_state.target_params, args.tau + # ) + # ) + return actor_state, (qf1_state, qf2_state), actor_loss_value, key + + @jax.jit + def update_q_target_networks(qf1_state, qf2_state): qf1_state = qf1_state.replace( target_params=optax.incremental_update( qf1_state.params, qf1_state.target_params, args.tau @@ -353,7 +367,7 @@ def actor_loss(params): qf2_state.params, qf2_state.target_params, args.tau ) ) - return actor_state, (qf1_state, qf2_state), actor_loss_value, key + return qf1_state, qf2_state start_time = time.time() n_updates = 0 @@ -433,6 +447,9 @@ def actor_loss(params): key, ) + # TODO: check if we need to update actor target too + qf1_state, qf2_state = update_q_target_networks(qf1_state, qf2_state) + if n_updates % args.policy_frequency == 0: ( actor_state, From 23e4d3b1fad6ade45045d934d5f05208677000bc Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 16 Sep 2022 10:24:03 +0200 Subject: [PATCH 07/31] Try to jit the for loop --- cleanrl/td3_droq_continuous_action_jax.py | 95 ++++++++++++++--------- 1 file changed, 58 insertions(+), 37 deletions(-) diff --git a/cleanrl/td3_droq_continuous_action_jax.py b/cleanrl/td3_droq_continuous_action_jax.py index 21f028e54..7c39e16f2 100644 --- a/cleanrl/td3_droq_continuous_action_jax.py +++ b/cleanrl/td3_droq_continuous_action_jax.py @@ -369,6 +369,54 @@ def update_q_target_networks(qf1_state, qf2_state): ) return qf1_state, qf2_state + @jax.jit + def train(qf1_state, qf2_state, actor_state, key, n_updates): + for _ in range(args.gradient_steps): + n_updates += 1 + data = rb.sample(args.batch_size) + + ( + (qf1_state, qf2_state), + (qf1_loss_value, qf2_loss_value), + (qf1_a_values, qf2_a_values), + key, + ) = update_critic( + actor_state, + qf1_state, + qf2_state, + data.observations.numpy(), + data.actions.numpy(), + data.next_observations.numpy(), + data.rewards.flatten().numpy(), + data.dones.flatten().numpy(), + key, + ) + + # TODO: check if we need to update actor target too + 
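            # [Editor's note, annotation only, not part of this patch] update_q_target_networks
            # applies Polyak averaging via optax.incremental_update, i.e.
            #     target_params <- tau * params + (1 - tau) * target_params,  with tau = args.tau,
            # which is equivalent to
            #     jax.tree_util.tree_map(lambda p, tp: tau * p + (1 - tau) * tp, params, target_params)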
qf1_state, qf2_state = update_q_target_networks(qf1_state, qf2_state) + + ( + actor_state, + (qf1_state, qf2_state), + actor_loss_value, + key, + ) = update_actor( + actor_state, + qf1_state, + qf2_state, + data.observations.numpy(), + key, + ) + return ( + n_updates, + (qf1_state, qf2_state), + actor_state, + key, + (qf1_loss_value, qf2_loss_value), + actor_loss_value, + (qf1_a_values, qf2_a_values), + ) + start_time = time.time() n_updates = 0 for global_step in tqdm(range(args.total_timesteps)): @@ -426,45 +474,18 @@ def update_q_target_networks(qf1_state, qf2_state): # ALGO LOGIC: training. if global_step > args.learning_starts: - for _ in range(args.gradient_steps): - n_updates += 1 - data = rb.sample(args.batch_size) - - ( - (qf1_state, qf2_state), - (qf1_loss_value, qf2_loss_value), - (qf1_a_values, qf2_a_values), - key, - ) = update_critic( - actor_state, - qf1_state, - qf2_state, - data.observations.numpy(), - data.actions.numpy(), - data.next_observations.numpy(), - data.rewards.flatten().numpy(), - data.dones.flatten().numpy(), - key, - ) - - # TODO: check if we need to update actor target too - qf1_state, qf2_state = update_q_target_networks(qf1_state, qf2_state) - - if n_updates % args.policy_frequency == 0: - ( - actor_state, - (qf1_state, qf2_state), - actor_loss_value, - key, - ) = update_actor( - actor_state, - qf1_state, - qf2_state, - data.observations.numpy(), - key, - ) + ( + n_updates, + (qf1_state, qf2_state), + actor_state, + key, + (qf1_loss_value, qf2_loss_value), + actor_loss_value, + (qf1_a_values, qf2_a_values), + ) = train(qf1_state, qf2_state, actor_state, key, n_updates) fps = int(global_step / (time.time() - start_time)) + if args.eval_freq > 0 and global_step % args.eval_freq == 0: mean_reward, std_reward = evaluate_policy(eval_envs, actor, actor_state) print( From df61ae5f1c9b0251becb54e8a1412e4f2b15e27e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 16 Sep 2022 11:01:59 +0200 Subject: [PATCH 08/31] Add no jit train version --- cleanrl/td3_droq_continuous_action_jax.py | 55 ++++++++++++++++++++++- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/cleanrl/td3_droq_continuous_action_jax.py b/cleanrl/td3_droq_continuous_action_jax.py index 7c39e16f2..804f1d44e 100644 --- a/cleanrl/td3_droq_continuous_action_jax.py +++ b/cleanrl/td3_droq_continuous_action_jax.py @@ -373,8 +373,8 @@ def update_q_target_networks(qf1_state, qf2_state): def train(qf1_state, qf2_state, actor_state, key, n_updates): for _ in range(args.gradient_steps): n_updates += 1 + # TODO: replace with jitable replay buffer, currently buggy (same samples are returned) data = rb.sample(args.batch_size) - ( (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), @@ -417,6 +417,54 @@ def train(qf1_state, qf2_state, actor_state, key, n_updates): (qf1_a_values, qf2_a_values), ) + def train_no_jit(qf1_state, qf2_state, actor_state, key, n_updates): + actor_loss_value = 0.0 + for _ in range(args.gradient_steps): + n_updates += 1 + data = rb.sample(args.batch_size) + + ( + (qf1_state, qf2_state), + (qf1_loss_value, qf2_loss_value), + (qf1_a_values, qf2_a_values), + key, + ) = update_critic( + actor_state, + qf1_state, + qf2_state, + data.observations.numpy(), + data.actions.numpy(), + data.next_observations.numpy(), + data.rewards.flatten().numpy(), + data.dones.flatten().numpy(), + key, + ) + + # TODO: check if we need to update actor target too + qf1_state, qf2_state = update_q_target_networks(qf1_state, qf2_state) + if n_updates % args.policy_frequency == 0: + ( + 
actor_state, + (qf1_state, qf2_state), + actor_loss_value, + key, + ) = update_actor( + actor_state, + qf1_state, + qf2_state, + data.observations.numpy(), + key, + ) + return ( + n_updates, + (qf1_state, qf2_state), + actor_state, + key, + (qf1_loss_value, qf2_loss_value), + actor_loss_value, + (qf1_a_values, qf2_a_values), + ) + start_time = time.time() n_updates = 0 for global_step in tqdm(range(args.total_timesteps)): @@ -474,6 +522,9 @@ def train(qf1_state, qf2_state, actor_state, key, n_updates): # ALGO LOGIC: training. if global_step > args.learning_starts: + # TODO: fix when train_freq > 1 + train_fn = train if args.policy_frequency == args.gradient_steps else train_no_jit + ( n_updates, (qf1_state, qf2_state), @@ -482,7 +533,7 @@ def train(qf1_state, qf2_state, actor_state, key, n_updates): (qf1_loss_value, qf2_loss_value), actor_loss_value, (qf1_a_values, qf2_a_values), - ) = train(qf1_state, qf2_state, actor_state, key, n_updates) + ) = train_fn(qf1_state, qf2_state, actor_state, key, n_updates) fps = int(global_step / (time.time() - start_time)) From f7b4e7c40aae77db3fe87065eea072972d1e49be Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 16 Sep 2022 11:02:18 +0200 Subject: [PATCH 09/31] Revert "Add no jit train version" This reverts commit df61ae5f1c9b0251becb54e8a1412e4f2b15e27e. --- cleanrl/td3_droq_continuous_action_jax.py | 55 +---------------------- 1 file changed, 2 insertions(+), 53 deletions(-) diff --git a/cleanrl/td3_droq_continuous_action_jax.py b/cleanrl/td3_droq_continuous_action_jax.py index 804f1d44e..7c39e16f2 100644 --- a/cleanrl/td3_droq_continuous_action_jax.py +++ b/cleanrl/td3_droq_continuous_action_jax.py @@ -373,8 +373,8 @@ def update_q_target_networks(qf1_state, qf2_state): def train(qf1_state, qf2_state, actor_state, key, n_updates): for _ in range(args.gradient_steps): n_updates += 1 - # TODO: replace with jitable replay buffer, currently buggy (same samples are returned) data = rb.sample(args.batch_size) + ( (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), @@ -417,54 +417,6 @@ def train(qf1_state, qf2_state, actor_state, key, n_updates): (qf1_a_values, qf2_a_values), ) - def train_no_jit(qf1_state, qf2_state, actor_state, key, n_updates): - actor_loss_value = 0.0 - for _ in range(args.gradient_steps): - n_updates += 1 - data = rb.sample(args.batch_size) - - ( - (qf1_state, qf2_state), - (qf1_loss_value, qf2_loss_value), - (qf1_a_values, qf2_a_values), - key, - ) = update_critic( - actor_state, - qf1_state, - qf2_state, - data.observations.numpy(), - data.actions.numpy(), - data.next_observations.numpy(), - data.rewards.flatten().numpy(), - data.dones.flatten().numpy(), - key, - ) - - # TODO: check if we need to update actor target too - qf1_state, qf2_state = update_q_target_networks(qf1_state, qf2_state) - if n_updates % args.policy_frequency == 0: - ( - actor_state, - (qf1_state, qf2_state), - actor_loss_value, - key, - ) = update_actor( - actor_state, - qf1_state, - qf2_state, - data.observations.numpy(), - key, - ) - return ( - n_updates, - (qf1_state, qf2_state), - actor_state, - key, - (qf1_loss_value, qf2_loss_value), - actor_loss_value, - (qf1_a_values, qf2_a_values), - ) - start_time = time.time() n_updates = 0 for global_step in tqdm(range(args.total_timesteps)): @@ -522,9 +474,6 @@ def train_no_jit(qf1_state, qf2_state, actor_state, key, n_updates): # ALGO LOGIC: training. 
if global_step > args.learning_starts: - # TODO: fix when train_freq > 1 - train_fn = train if args.policy_frequency == args.gradient_steps else train_no_jit - ( n_updates, (qf1_state, qf2_state), @@ -533,7 +482,7 @@ def train_no_jit(qf1_state, qf2_state, actor_state, key, n_updates): (qf1_loss_value, qf2_loss_value), actor_loss_value, (qf1_a_values, qf2_a_values), - ) = train_fn(qf1_state, qf2_state, actor_state, key, n_updates) + ) = train(qf1_state, qf2_state, actor_state, key, n_updates) fps = int(global_step / (time.time() - start_time)) From 373aabb1a0ede55525f8f2a8001134e7477ed13e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 16 Sep 2022 11:02:23 +0200 Subject: [PATCH 10/31] Revert "Try to jit the for loop" This reverts commit 23e4d3b1fad6ade45045d934d5f05208677000bc. --- cleanrl/td3_droq_continuous_action_jax.py | 95 +++++++++-------------- 1 file changed, 37 insertions(+), 58 deletions(-) diff --git a/cleanrl/td3_droq_continuous_action_jax.py b/cleanrl/td3_droq_continuous_action_jax.py index 7c39e16f2..21f028e54 100644 --- a/cleanrl/td3_droq_continuous_action_jax.py +++ b/cleanrl/td3_droq_continuous_action_jax.py @@ -369,54 +369,6 @@ def update_q_target_networks(qf1_state, qf2_state): ) return qf1_state, qf2_state - @jax.jit - def train(qf1_state, qf2_state, actor_state, key, n_updates): - for _ in range(args.gradient_steps): - n_updates += 1 - data = rb.sample(args.batch_size) - - ( - (qf1_state, qf2_state), - (qf1_loss_value, qf2_loss_value), - (qf1_a_values, qf2_a_values), - key, - ) = update_critic( - actor_state, - qf1_state, - qf2_state, - data.observations.numpy(), - data.actions.numpy(), - data.next_observations.numpy(), - data.rewards.flatten().numpy(), - data.dones.flatten().numpy(), - key, - ) - - # TODO: check if we need to update actor target too - qf1_state, qf2_state = update_q_target_networks(qf1_state, qf2_state) - - ( - actor_state, - (qf1_state, qf2_state), - actor_loss_value, - key, - ) = update_actor( - actor_state, - qf1_state, - qf2_state, - data.observations.numpy(), - key, - ) - return ( - n_updates, - (qf1_state, qf2_state), - actor_state, - key, - (qf1_loss_value, qf2_loss_value), - actor_loss_value, - (qf1_a_values, qf2_a_values), - ) - start_time = time.time() n_updates = 0 for global_step in tqdm(range(args.total_timesteps)): @@ -474,18 +426,45 @@ def train(qf1_state, qf2_state, actor_state, key, n_updates): # ALGO LOGIC: training. 
if global_step > args.learning_starts: - ( - n_updates, - (qf1_state, qf2_state), - actor_state, - key, - (qf1_loss_value, qf2_loss_value), - actor_loss_value, - (qf1_a_values, qf2_a_values), - ) = train(qf1_state, qf2_state, actor_state, key, n_updates) + for _ in range(args.gradient_steps): + n_updates += 1 + data = rb.sample(args.batch_size) + + ( + (qf1_state, qf2_state), + (qf1_loss_value, qf2_loss_value), + (qf1_a_values, qf2_a_values), + key, + ) = update_critic( + actor_state, + qf1_state, + qf2_state, + data.observations.numpy(), + data.actions.numpy(), + data.next_observations.numpy(), + data.rewards.flatten().numpy(), + data.dones.flatten().numpy(), + key, + ) - fps = int(global_step / (time.time() - start_time)) + # TODO: check if we need to update actor target too + qf1_state, qf2_state = update_q_target_networks(qf1_state, qf2_state) + if n_updates % args.policy_frequency == 0: + ( + actor_state, + (qf1_state, qf2_state), + actor_loss_value, + key, + ) = update_actor( + actor_state, + qf1_state, + qf2_state, + data.observations.numpy(), + key, + ) + + fps = int(global_step / (time.time() - start_time)) if args.eval_freq > 0 and global_step % args.eval_freq == 0: mean_reward, std_reward = evaluate_policy(eval_envs, actor, actor_state) print( From 85fa1434db8c47a57235d7121abca8f5a7c94d64 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 16 Sep 2022 11:05:45 +0200 Subject: [PATCH 11/31] Revert "Separate q network target update" This reverts commit f0cc8ffccc8ad2deeb5ab20399ec90eee035587f. --- cleanrl/td3_droq_continuous_action_jax.py | 19 +------------------ 1 file changed, 1 insertion(+), 18 deletions(-) diff --git a/cleanrl/td3_droq_continuous_action_jax.py b/cleanrl/td3_droq_continuous_action_jax.py index 21f028e54..d56f14136 100644 --- a/cleanrl/td3_droq_continuous_action_jax.py +++ b/cleanrl/td3_droq_continuous_action_jax.py @@ -343,20 +343,6 @@ def actor_loss(params): ) ) - # qf1_state = qf1_state.replace( - # target_params=optax.incremental_update( - # qf1_state.params, qf1_state.target_params, args.tau - # ) - # ) - # qf2_state = qf2_state.replace( - # target_params=optax.incremental_update( - # qf2_state.params, qf2_state.target_params, args.tau - # ) - # ) - return actor_state, (qf1_state, qf2_state), actor_loss_value, key - - @jax.jit - def update_q_target_networks(qf1_state, qf2_state): qf1_state = qf1_state.replace( target_params=optax.incremental_update( qf1_state.params, qf1_state.target_params, args.tau @@ -367,7 +353,7 @@ def update_q_target_networks(qf1_state, qf2_state): qf2_state.params, qf2_state.target_params, args.tau ) ) - return qf1_state, qf2_state + return actor_state, (qf1_state, qf2_state), actor_loss_value, key start_time = time.time() n_updates = 0 @@ -447,9 +433,6 @@ def update_q_target_networks(qf1_state, qf2_state): key, ) - # TODO: check if we need to update actor target too - qf1_state, qf2_state = update_q_target_networks(qf1_state, qf2_state) - if n_updates % args.policy_frequency == 0: ( actor_state, From 60f63e1a08cdf7e65e37a138c2fe2af93a6ee502 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 16 Sep 2022 14:39:34 +0200 Subject: [PATCH 12/31] TQC + TD3 + DroQ first attempt --- cleanrl/tqc_td3_jax.py | 525 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 525 insertions(+) create mode 100644 cleanrl/tqc_td3_jax.py diff --git a/cleanrl/tqc_td3_jax.py b/cleanrl/tqc_td3_jax.py new file mode 100644 index 000000000..424cc8924 --- /dev/null +++ b/cleanrl/tqc_td3_jax.py @@ -0,0 +1,525 @@ +# docs and experiment results can be 
found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_action_jaxpy +import argparse +import os +import random +import time +from distutils.util import strtobool +from typing import Optional, Sequence + +import flax +import flax.linen as nn +import gym +import jax +import jax.numpy as jnp +import numpy as np +import optax +import pybullet_envs # noqa +from flax.training.train_state import TrainState +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.vec_env import DummyVecEnv +from torch.utils.tensorboard import SummaryWriter +from tqdm.rich import tqdm + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="cleanRL", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="weather to capture videos of the agent performances (check out `videos` folder)") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="HalfCheetah-v2", + help="the id of the environment") + parser.add_argument("-n", "--total-timesteps", type=int, default=1000000, + help="total timesteps of the experiments") + parser.add_argument("-lr", "--learning-rate", type=float, default=3e-4, + help="the learning rate of the optimizer") + parser.add_argument("--buffer-size", type=int, default=int(1e6), + help="the replay memory buffer size") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--tau", type=float, default=0.005, + help="target smoothing coefficient (default: 0.005)") + parser.add_argument("--policy-noise", type=float, default=0.2, + help="the scale of policy noise") + parser.add_argument("--batch-size", type=int, default=256, + help="the batch size of sample from the reply memory") + parser.add_argument("--exploration-noise", type=float, default=0.1, + help="the scale of exploration noise") + parser.add_argument("--learning-starts", type=int, default=1000, + help="timestep to start learning") + parser.add_argument("--gradient-steps", type=int, default=1, + help="Number of gradient steps to perform after each rollout") + # Argument for dropout rate + parser.add_argument("--dropout-rate", type=float, default=0.0) + # Argument for layer normalization + parser.add_argument("--layer-norm", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) + parser.add_argument("--policy-frequency", type=int, default=2, + help="the frequency of training policy (delayed)") + parser.add_argument("--eval-freq", type=int, default=-1) + parser.add_argument("--n-eval-envs", type=int, default=1) + parser.add_argument("--n-eval-episodes", type=int, default=10) + parser.add_argument("--verbose", type=int, default=1) + + parser.add_argument("--noise-clip", type=float, default=0.5, + help="noise clip parameter of the Target Policy Smoothing Regularization") + args = parser.parse_args() + # fmt: 
on + return args + + +def make_env(env_id, seed, idx, capture_video=False, run_name=""): + def thunk(): + env = gym.make(env_id) + env = gym.wrappers.RecordEpisodeStatistics(env) + if capture_video: + if idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + return thunk + + +# ALGO LOGIC: initialize agent here: +class QNetwork(nn.Module): + use_layer_norm: bool = False + dropout_rate: Optional[float] = None + n_quantiles: int = 25 + + @nn.compact + def __call__(self, x: jnp.ndarray, a: jnp.ndarray, training: bool = False): + x = jnp.concatenate([x, a], -1) + x = nn.Dense(256)(x) + if self.dropout_rate is not None and self.dropout_rate > 0: + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) + if self.use_layer_norm: + x = nn.LayerNorm()(x) + x = nn.relu(x) + x = nn.Dense(256)(x) + if self.dropout_rate is not None and self.dropout_rate > 0: + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) + if self.use_layer_norm: + x = nn.LayerNorm()(x) + x = nn.relu(x) + x = nn.Dense(self.n_quantiles)(x) + return x + + +class Actor(nn.Module): + action_dim: Sequence[int] + action_scale: float + action_bias: float + + @nn.compact + def __call__(self, x): + x = nn.Dense(256)(x) + x = nn.relu(x) + x = nn.Dense(256)(x) + x = nn.relu(x) + x = nn.Dense(self.action_dim)(x) + x = nn.tanh(x) + x = x * self.action_scale + self.action_bias + return x + + +class RLTrainState(TrainState): + target_params: flax.core.FrozenDict + + +def evaluate_policy(eval_env, actor, actor_state, n_eval_episodes: int = 10): + eval_episode_rewards = [] + for _ in range(n_eval_episodes): + obs = eval_env.reset() + done = False + episode_reward = 0 + while not done: + action = np.array(actor.apply(actor_state.params, obs)) + obs, reward, done, _ = eval_env.step(action) + episode_reward += reward + eval_episode_rewards.append(episode_reward) + return np.mean(eval_episode_rewards), np.std(eval_episode_rewards) + + +def main(): + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" + % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + np.random.seed(args.seed) + key = jax.random.PRNGKey(args.seed) + key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4) + key, dropout_key1, dropout_key2 = jax.random.split(key, 3) + + # env setup + envs = DummyVecEnv( + [make_env(args.env_id, args.seed, 0, args.capture_video, run_name)] + ) + eval_envs = DummyVecEnv([make_env(args.env_id, args.seed + 1, 0)]) + + assert isinstance( + envs.action_space, gym.spaces.Box + ), "only continuous action space is supported" + + # Assume that all dimensions share the same bound + min_action = float(envs.action_space.low[0]) + max_action = float(envs.action_space.high[0]) + envs.observation_space.dtype = np.float32 + rb = ReplayBuffer( + args.buffer_size, + envs.observation_space, + envs.action_space, + device="cpu", + handle_timeout_termination=True, + ) + + # Sort and drop top k quantiles to control overestimation. 
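+ # For reference (derived from the constants below, illustrative only): the 2 critics
+ # produce 2 * 25 = 50 candidate target quantiles, and dropping the top 2 per critic
+ # keeps 50 - 2 * 2 = 46 of them for the truncated (TQC-style) target.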
+ n_quantiles = 25 + n_critics = 2 + quantiles_total = n_quantiles * n_critics + top_quantiles_to_drop_per_net = 2 + n_target_quantiles = quantiles_total - top_quantiles_to_drop_per_net * n_critics + + # TRY NOT TO MODIFY: start the game + obs = envs.reset() + actor = Actor( + action_dim=np.prod(envs.action_space.shape), + action_scale=(max_action - min_action) / 2.0, + action_bias=(max_action + min_action) / 2.0, + ) + actor_state = RLTrainState.create( + apply_fn=actor.apply, + params=actor.init(actor_key, obs), + target_params=actor.init(actor_key, obs), + tx=optax.adam(learning_rate=args.learning_rate), + ) + qf = QNetwork(dropout_rate=args.dropout_rate, use_layer_norm=args.layer_norm) + qf1_state = RLTrainState.create( + apply_fn=qf.apply, + params=qf.init( + {"params": qf1_key, "dropout": dropout_key1}, + obs, + jnp.array([envs.action_space.sample()]), + ), + target_params=qf.init( + {"params": qf1_key, "dropout": dropout_key1}, + obs, + jnp.array([envs.action_space.sample()]), + ), + tx=optax.adam(learning_rate=args.learning_rate), + ) + qf2_state = RLTrainState.create( + apply_fn=qf.apply, + params=qf.init( + {"params": qf2_key, "dropout": dropout_key2}, + obs, + jnp.array([envs.action_space.sample()]), + ), + target_params=qf.init( + {"params": qf2_key, "dropout": dropout_key2}, + obs, + jnp.array([envs.action_space.sample()]), + ), + tx=optax.adam(learning_rate=args.learning_rate), + ) + actor.apply = jax.jit(actor.apply) + qf.apply = jax.jit(qf.apply, static_argnames=("dropout_rate", "use_layer_norm")) + + @jax.jit + def update_critic( + actor_state: RLTrainState, + qf1_state: RLTrainState, + qf2_state: RLTrainState, + observations: np.ndarray, + actions: np.ndarray, + next_observations: np.ndarray, + rewards: np.ndarray, + dones: np.ndarray, + key: jnp.ndarray, + ): + # TODO Maybe pre-generate a lot of random keys + # also check https://jax.readthedocs.io/en/latest/jax.random.html + key, noise_key, dropout_key_1, dropout_key_2 = jax.random.split(key, 4) + key, dropout_key_3, dropout_key_4 = jax.random.split(key, 3) + + clipped_noise = jnp.clip( + (jax.random.normal(noise_key, actions[0].shape) * args.policy_noise), + -args.noise_clip, + args.noise_clip, + ) + next_state_actions = jnp.clip( + actor.apply(actor_state.target_params, next_observations) + clipped_noise, + envs.action_space.low[0], + envs.action_space.high[0], + ) + qf1_next_quantiles = qf.apply( + qf1_state.target_params, + next_observations, + next_state_actions, + True, + rngs={"dropout": dropout_key_1}, + ) + qf2_next_quantiles = qf.apply( + qf2_state.target_params, + next_observations, + next_state_actions, + True, + rngs={"dropout": dropout_key_2}, + ) + + # Concatenate quantiles from both critics to get a single tensor + # batch x quantiles + qf_next_quantiles = jnp.concatenate((qf1_next_quantiles, qf2_next_quantiles), axis=1) + + # sort next quantiles with jax + next_quantiles = jnp.sort(qf_next_quantiles) + # Keep only the quantiles we need + next_target_quantiles = next_quantiles[:, :n_target_quantiles] + + + target_quantiles = ( + rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * args.gamma * (next_target_quantiles) + ) + + # Make target_quantiles broadcastable to (batch_size, n_quantiles, n_target_quantiles). 
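+ # Shape note: expand_dims below maps target_quantiles from (batch_size, n_target_quantiles)
+ # to (batch_size, 1, n_target_quantiles); the n_quantiles axis is filled in later by
+ # broadcasting against current_quantiles inside the quantile loss.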
+ target_quantiles = jnp.expand_dims(target_quantiles, axis=1) + + + def huber_quantile_loss(params, noise_key): + # Compute huber quantile loss + current_quantiles = qf.apply( + params, observations, actions, True, rngs={"dropout": noise_key} + ) + # convert to shape: (batch_size, n_quantiles, 1) + current_quantiles = jnp.expand_dims(current_quantiles, axis=-1) + + # Cumulative probabilities to calculate quantiles. + # shape: (n_quantiles,) + cum_prob = (jnp.arange(n_quantiles, dtype=jnp.float32) + 0.5) / n_quantiles + # convert to shape: (1, n_quantiles, 1) + cum_prob = jnp.expand_dims(cum_prob, axis=(0, -1)) + + # TQC + # target_quantiles: (batch_size, 1, n_target_quantiles) -> (batch_size, 1, 1, n_target_quantiles) + # current_quantiles: (batch_size, n_critics, n_quantiles) -> (batch_size, n_critics, n_quantiles, 1) + # pairwise_delta: (batch_size, n_critics, n_quantiles, n_target_quantiles) + # Note: in both cases, the loss has the same shape as pairwise_delta + + pairwise_delta = target_quantiles - current_quantiles + abs_pairwise_delta = jnp.abs(pairwise_delta) + huber_loss = jnp.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta**2 * 0.5) + loss = jnp.abs(cum_prob - (pairwise_delta < 0).astype(jnp.float32)) * huber_loss + return loss.mean() + + qf1_loss_value, grads1 = jax.value_and_grad( + huber_quantile_loss, has_aux=False + )(qf1_state.params, dropout_key_3) + qf2_loss_value, grads2 = jax.value_and_grad( + huber_quantile_loss, has_aux=False + )(qf2_state.params, dropout_key_4) + qf1_state = qf1_state.apply_gradients(grads=grads1) + qf2_state = qf2_state.apply_gradients(grads=grads2) + + return ( + (qf1_state, qf2_state), + (qf1_loss_value, qf2_loss_value), + key, + ) + + @jax.jit + def update_actor( + actor_state: RLTrainState, + qf1_state: RLTrainState, + qf2_state: RLTrainState, + observations: np.ndarray, + key: jnp.ndarray, + ): + key, dropout_key = jax.random.split(key, 2) + + def actor_loss(params): + return -qf.apply( + qf1_state.params, + observations, + actor.apply(params, observations), + True, + rngs={"dropout": dropout_key}, + ).mean() + + actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params) + actor_state = actor_state.apply_gradients(grads=grads) + actor_state = actor_state.replace( + target_params=optax.incremental_update( + actor_state.params, actor_state.target_params, args.tau + ) + ) + + qf1_state = qf1_state.replace( + target_params=optax.incremental_update( + qf1_state.params, qf1_state.target_params, args.tau + ) + ) + qf2_state = qf2_state.replace( + target_params=optax.incremental_update( + qf2_state.params, qf2_state.target_params, args.tau + ) + ) + return actor_state, (qf1_state, qf2_state), actor_loss_value, key + + start_time = time.time() + n_updates = 0 + for global_step in tqdm(range(args.total_timesteps)): + # for global_step in range(args.total_timesteps): + # ALGO LOGIC: put action logic here + if global_step < args.learning_starts: + actions = np.array( + [envs.action_space.sample() for _ in range(envs.num_envs)] + ) + else: + actions = actor.apply(actor_state.params, obs) + actions = np.array( + [ + ( + jax.device_get(actions)[0] + + np.random.normal( + 0, + max_action * args.exploration_noise, + size=envs.action_space.shape[0], + ) + ).clip(envs.action_space.low, envs.action_space.high) + ] + ) + + # TRY NOT TO MODIFY: execute the game and log data. 
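+ # Note: stable-baselines3's DummyVecEnv auto-resets finished sub-envs, so `next_obs`
+ # already holds the reset observation; the true final observation is recovered below
+ # from infos[idx]["terminal_observation"] before the transition is stored.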
+ next_obs, rewards, dones, infos = envs.step(actions) + + # TRY NOT TO MODIFY: record rewards for plotting purposes + for info in infos: + if "episode" in info.keys(): + if args.verbose >= 2: + print( + f"global_step={global_step + 1}, episodic_return={info['episode']['r']:.2f}" + ) + writer.add_scalar( + "charts/episodic_return", info["episode"]["r"], global_step + ) + writer.add_scalar( + "charts/episodic_length", info["episode"]["l"], global_step + ) + break + + # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation` + real_next_obs = next_obs.copy() + for idx, done in enumerate(dones): + if done: + real_next_obs[idx] = infos[idx]["terminal_observation"] + # Timeout handling done inside the replay buffer + # if infos[idx].get("TimeLimit.truncated", False) == True: + # real_dones[idx] = False + + rb.add(obs, real_next_obs, actions, rewards, dones, infos) + + # TRY NOT TO MODIFY: CRUCIAL step easy to overlook + obs = next_obs + + # ALGO LOGIC: training. + if global_step > args.learning_starts: + for _ in range(args.gradient_steps): + n_updates += 1 + data = rb.sample(args.batch_size) + + ( + (qf1_state, qf2_state), + (qf1_loss_value, qf2_loss_value), + key, + ) = update_critic( + actor_state, + qf1_state, + qf2_state, + data.observations.numpy(), + data.actions.numpy(), + data.next_observations.numpy(), + data.rewards.flatten().numpy(), + data.dones.flatten().numpy(), + key, + ) + + if n_updates % args.policy_frequency == 0: + ( + actor_state, + (qf1_state, qf2_state), + actor_loss_value, + key, + ) = update_actor( + actor_state, + qf1_state, + qf2_state, + data.observations.numpy(), + key, + ) + + fps = int(global_step / (time.time() - start_time)) + if args.eval_freq > 0 and global_step % args.eval_freq == 0: + mean_reward, std_reward = evaluate_policy(eval_envs, actor, actor_state) + print( + f"global_step={global_step}, mean_eval_reward={mean_reward:.2f} +/- {std_reward:.2f} - {fps} fps" + ) + writer.add_scalar("charts/mean_eval_reward", mean_reward, global_step) + writer.add_scalar("charts/std_eval_reward", std_reward, global_step) + + if global_step % 100 == 0: + writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step) + writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step) + # writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step) + # writer.add_scalar("losses/qf2_values", qf2_a_values.item(), global_step) + writer.add_scalar( + "losses/actor_loss", actor_loss_value.item(), global_step + ) + if args.verbose >= 2: + print("FPS:", fps) + writer.add_scalar( + "charts/SPS", + int(global_step / (time.time() - start_time)), + global_step, + ) + + envs.close() + writer.close() + + +if __name__ == "__main__": + + try: + main() + except KeyboardInterrupt: + pass From 44f3a9b774422e04df259c30374a782a99182e9c Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 16 Sep 2022 16:16:58 +0200 Subject: [PATCH 13/31] Add number of quantiles to drop as param --- cleanrl/tqc_td3_jax.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cleanrl/tqc_td3_jax.py b/cleanrl/tqc_td3_jax.py index 424cc8924..1f91fe5e1 100644 --- a/cleanrl/tqc_td3_jax.py +++ b/cleanrl/tqc_td3_jax.py @@ -70,6 +70,8 @@ def parse_args(): parser.add_argument("--n-eval-envs", type=int, default=1) parser.add_argument("--n-eval-episodes", type=int, default=10) parser.add_argument("--verbose", type=int, default=1) + # top quantiles to drop per net + parser.add_argument("--top-quantile-drop-per-net", type=int, default=2) 
parser.add_argument("--noise-clip", type=float, default=0.5, help="noise clip parameter of the Target Policy Smoothing Regularization") @@ -208,7 +210,7 @@ def main(): n_quantiles = 25 n_critics = 2 quantiles_total = n_quantiles * n_critics - top_quantiles_to_drop_per_net = 2 + top_quantiles_to_drop_per_net = args.top_quantiles_to_drop_per_net n_target_quantiles = quantiles_total - top_quantiles_to_drop_per_net * n_critics # TRY NOT TO MODIFY: start the game From 5156d7818370d13bf001843ad6689999488aa009 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 16 Sep 2022 16:21:20 +0200 Subject: [PATCH 14/31] Fixes and reformat --- cleanrl/td3_droq_continuous_action_jax.py | 77 +++++---------------- cleanrl/tqc_td3_jax.py | 84 ++++++----------------- 2 files changed, 39 insertions(+), 122 deletions(-) diff --git a/cleanrl/td3_droq_continuous_action_jax.py b/cleanrl/td3_droq_continuous_action_jax.py index d56f14136..e9f92be5d 100644 --- a/cleanrl/td3_droq_continuous_action_jax.py +++ b/cleanrl/td3_droq_continuous_action_jax.py @@ -170,8 +170,7 @@ def main(): writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters", - "|param|value|\n|-|-|\n%s" - % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), ) # TRY NOT TO MODIFY: seeding @@ -182,14 +181,10 @@ def main(): key, dropout_key1, dropout_key2 = jax.random.split(key, 3) # env setup - envs = DummyVecEnv( - [make_env(args.env_id, args.seed, 0, args.capture_video, run_name)] - ) + envs = DummyVecEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) eval_envs = DummyVecEnv([make_env(args.env_id, args.seed + 1, 0)]) - assert isinstance( - envs.action_space, gym.spaces.Box - ), "only continuous action space is supported" + assert isinstance(envs.action_space, gym.spaces.Box), "only continuous action space is supported" # Assume that all dimensions share the same bound min_action = float(envs.action_space.low[0]) @@ -290,22 +285,14 @@ def update_critic( rngs={"dropout": dropout_key_2}, ).reshape(-1) min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target) - next_q_value = ( - rewards + (1 - dones) * args.gamma * (min_qf_next_target) - ).reshape(-1) + next_q_value = (rewards + (1 - dones) * args.gamma * (min_qf_next_target)).reshape(-1) def mse_loss(params, noise_key): - qf_a_values = qf.apply( - params, observations, actions, True, rngs={"dropout": noise_key} - ).squeeze() + qf_a_values = qf.apply(params, observations, actions, True, rngs={"dropout": noise_key}).squeeze() return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean() - (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad( - mse_loss, has_aux=True - )(qf1_state.params, dropout_key_3) - (qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad( - mse_loss, has_aux=True - )(qf2_state.params, dropout_key_4) + (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params, dropout_key_3) + (qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params, dropout_key_4) qf1_state = qf1_state.apply_gradients(grads=grads1) qf2_state = qf2_state.apply_gradients(grads=grads2) @@ -338,20 +325,14 @@ def actor_loss(params): actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params) actor_state = actor_state.apply_gradients(grads=grads) actor_state = actor_state.replace( - target_params=optax.incremental_update( - 
actor_state.params, actor_state.target_params, args.tau - ) + target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau) ) qf1_state = qf1_state.replace( - target_params=optax.incremental_update( - qf1_state.params, qf1_state.target_params, args.tau - ) + target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau) ) qf2_state = qf2_state.replace( - target_params=optax.incremental_update( - qf2_state.params, qf2_state.target_params, args.tau - ) + target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau) ) return actor_state, (qf1_state, qf2_state), actor_loss_value, key @@ -360,9 +341,7 @@ def actor_loss(params): for global_step in tqdm(range(args.total_timesteps)): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: - actions = np.array( - [envs.action_space.sample() for _ in range(envs.num_envs)] - ) + actions = np.array([envs.action_space.sample() for _ in range(envs.num_envs)]) else: actions = actor.apply(actor_state.params, obs) actions = np.array( @@ -385,15 +364,9 @@ def actor_loss(params): for info in infos: if "episode" in info.keys(): if args.verbose >= 2: - print( - f"global_step={global_step + 1}, episodic_return={info['episode']['r']:.2f}" - ) - writer.add_scalar( - "charts/episodic_return", info["episode"]["r"], global_step - ) - writer.add_scalar( - "charts/episodic_length", info["episode"]["l"], global_step - ) + print(f"global_step={global_step + 1}, episodic_return={info['episode']['r']:.2f}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) break # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation` @@ -416,12 +389,7 @@ def actor_loss(params): n_updates += 1 data = rb.sample(args.batch_size) - ( - (qf1_state, qf2_state), - (qf1_loss_value, qf2_loss_value), - (qf1_a_values, qf2_a_values), - key, - ) = update_critic( + ((qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key,) = update_critic( actor_state, qf1_state, qf2_state, @@ -434,12 +402,7 @@ def actor_loss(params): ) if n_updates % args.policy_frequency == 0: - ( - actor_state, - (qf1_state, qf2_state), - actor_loss_value, - key, - ) = update_actor( + (actor_state, (qf1_state, qf2_state), actor_loss_value, key,) = update_actor( actor_state, qf1_state, qf2_state, @@ -450,9 +413,7 @@ def actor_loss(params): fps = int(global_step / (time.time() - start_time)) if args.eval_freq > 0 and global_step % args.eval_freq == 0: mean_reward, std_reward = evaluate_policy(eval_envs, actor, actor_state) - print( - f"global_step={global_step}, mean_eval_reward={mean_reward:.2f} +/- {std_reward:.2f} - {fps} fps" - ) + print(f"global_step={global_step}, mean_eval_reward={mean_reward:.2f} +/- {std_reward:.2f} - {fps} fps") writer.add_scalar("charts/mean_eval_reward", mean_reward, global_step) writer.add_scalar("charts/std_eval_reward", std_reward, global_step) @@ -461,9 +422,7 @@ def actor_loss(params): writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step) writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step) writer.add_scalar("losses/qf2_values", qf2_a_values.item(), global_step) - writer.add_scalar( - "losses/actor_loss", actor_loss_value.item(), global_step - ) + writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step) if args.verbose >= 2: print("FPS:", fps) writer.add_scalar( 
diff --git a/cleanrl/tqc_td3_jax.py b/cleanrl/tqc_td3_jax.py index 1f91fe5e1..68c75a3b5 100644 --- a/cleanrl/tqc_td3_jax.py +++ b/cleanrl/tqc_td3_jax.py @@ -71,7 +71,7 @@ def parse_args(): parser.add_argument("--n-eval-episodes", type=int, default=10) parser.add_argument("--verbose", type=int, default=1) # top quantiles to drop per net - parser.add_argument("--top-quantile-drop-per-net", type=int, default=2) + parser.add_argument("-t", "--top-quantiles-to-drop-per-net", type=int, default=2) parser.add_argument("--noise-clip", type=float, default=0.5, help="noise clip parameter of the Target Policy Smoothing Regularization") @@ -173,8 +173,7 @@ def main(): writer = SummaryWriter(f"runs/{run_name}") writer.add_text( "hyperparameters", - "|param|value|\n|-|-|\n%s" - % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), ) # TRY NOT TO MODIFY: seeding @@ -185,14 +184,10 @@ def main(): key, dropout_key1, dropout_key2 = jax.random.split(key, 3) # env setup - envs = DummyVecEnv( - [make_env(args.env_id, args.seed, 0, args.capture_video, run_name)] - ) + envs = DummyVecEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) eval_envs = DummyVecEnv([make_env(args.env_id, args.seed + 1, 0)]) - assert isinstance( - envs.action_space, gym.spaces.Box - ), "only continuous action space is supported" + assert isinstance(envs.action_space, gym.spaces.Box), "only continuous action space is supported" # Assume that all dimensions share the same bound min_action = float(envs.action_space.low[0]) @@ -308,21 +303,15 @@ def update_critic( next_quantiles = jnp.sort(qf_next_quantiles) # Keep only the quantiles we need next_target_quantiles = next_quantiles[:, :n_target_quantiles] - - target_quantiles = ( - rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * args.gamma * (next_target_quantiles) - ) + target_quantiles = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * args.gamma * (next_target_quantiles) # Make target_quantiles broadcastable to (batch_size, n_quantiles, n_target_quantiles). 
target_quantiles = jnp.expand_dims(target_quantiles, axis=1) - def huber_quantile_loss(params, noise_key): # Compute huber quantile loss - current_quantiles = qf.apply( - params, observations, actions, True, rngs={"dropout": noise_key} - ) + current_quantiles = qf.apply(params, observations, actions, True, rngs={"dropout": noise_key}) # convert to shape: (batch_size, n_quantiles, 1) current_quantiles = jnp.expand_dims(current_quantiles, axis=-1) @@ -331,7 +320,7 @@ def huber_quantile_loss(params, noise_key): cum_prob = (jnp.arange(n_quantiles, dtype=jnp.float32) + 0.5) / n_quantiles # convert to shape: (1, n_quantiles, 1) cum_prob = jnp.expand_dims(cum_prob, axis=(0, -1)) - + # TQC # target_quantiles: (batch_size, 1, n_target_quantiles) -> (batch_size, 1, 1, n_target_quantiles) # current_quantiles: (batch_size, n_critics, n_quantiles) -> (batch_size, n_critics, n_quantiles, 1) @@ -344,12 +333,8 @@ def huber_quantile_loss(params, noise_key): loss = jnp.abs(cum_prob - (pairwise_delta < 0).astype(jnp.float32)) * huber_loss return loss.mean() - qf1_loss_value, grads1 = jax.value_and_grad( - huber_quantile_loss, has_aux=False - )(qf1_state.params, dropout_key_3) - qf2_loss_value, grads2 = jax.value_and_grad( - huber_quantile_loss, has_aux=False - )(qf2_state.params, dropout_key_4) + qf1_loss_value, grads1 = jax.value_and_grad(huber_quantile_loss, has_aux=False)(qf1_state.params, dropout_key_3) + qf2_loss_value, grads2 = jax.value_and_grad(huber_quantile_loss, has_aux=False)(qf2_state.params, dropout_key_4) qf1_state = qf1_state.apply_gradients(grads=grads1) qf2_state = qf2_state.apply_gradients(grads=grads2) @@ -381,32 +366,24 @@ def actor_loss(params): actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params) actor_state = actor_state.apply_gradients(grads=grads) actor_state = actor_state.replace( - target_params=optax.incremental_update( - actor_state.params, actor_state.target_params, args.tau - ) + target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau) ) qf1_state = qf1_state.replace( - target_params=optax.incremental_update( - qf1_state.params, qf1_state.target_params, args.tau - ) + target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau) ) qf2_state = qf2_state.replace( - target_params=optax.incremental_update( - qf2_state.params, qf2_state.target_params, args.tau - ) + target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau) ) return actor_state, (qf1_state, qf2_state), actor_loss_value, key start_time = time.time() n_updates = 0 for global_step in tqdm(range(args.total_timesteps)): - # for global_step in range(args.total_timesteps): + # for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: - actions = np.array( - [envs.action_space.sample() for _ in range(envs.num_envs)] - ) + actions = np.array([envs.action_space.sample() for _ in range(envs.num_envs)]) else: actions = actor.apply(actor_state.params, obs) actions = np.array( @@ -429,15 +406,9 @@ def actor_loss(params): for info in infos: if "episode" in info.keys(): if args.verbose >= 2: - print( - f"global_step={global_step + 1}, episodic_return={info['episode']['r']:.2f}" - ) - writer.add_scalar( - "charts/episodic_return", info["episode"]["r"], global_step - ) - writer.add_scalar( - "charts/episodic_length", info["episode"]["l"], global_step - ) + print(f"global_step={global_step + 1}, episodic_return={info['episode']['r']:.2f}") + 
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) break # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation` @@ -460,11 +431,7 @@ def actor_loss(params): n_updates += 1 data = rb.sample(args.batch_size) - ( - (qf1_state, qf2_state), - (qf1_loss_value, qf2_loss_value), - key, - ) = update_critic( + ((qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), key,) = update_critic( actor_state, qf1_state, qf2_state, @@ -477,12 +444,7 @@ def actor_loss(params): ) if n_updates % args.policy_frequency == 0: - ( - actor_state, - (qf1_state, qf2_state), - actor_loss_value, - key, - ) = update_actor( + (actor_state, (qf1_state, qf2_state), actor_loss_value, key,) = update_actor( actor_state, qf1_state, qf2_state, @@ -493,9 +455,7 @@ def actor_loss(params): fps = int(global_step / (time.time() - start_time)) if args.eval_freq > 0 and global_step % args.eval_freq == 0: mean_reward, std_reward = evaluate_policy(eval_envs, actor, actor_state) - print( - f"global_step={global_step}, mean_eval_reward={mean_reward:.2f} +/- {std_reward:.2f} - {fps} fps" - ) + print(f"global_step={global_step}, mean_eval_reward={mean_reward:.2f} +/- {std_reward:.2f} - {fps} fps") writer.add_scalar("charts/mean_eval_reward", mean_reward, global_step) writer.add_scalar("charts/std_eval_reward", std_reward, global_step) @@ -504,9 +464,7 @@ def actor_loss(params): writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step) # writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step) # writer.add_scalar("losses/qf2_values", qf2_a_values.item(), global_step) - writer.add_scalar( - "losses/actor_loss", actor_loss_value.item(), global_step - ) + writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step) if args.verbose >= 2: print("FPS:", fps) writer.add_scalar( From 8aaca4f56d14fad86331bd157a5c056faba9d408 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 16 Sep 2022 17:16:38 +0200 Subject: [PATCH 15/31] n_units as param --- cleanrl/tqc_td3_jax.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/cleanrl/tqc_td3_jax.py b/cleanrl/tqc_td3_jax.py index 68c75a3b5..6d9934fea 100644 --- a/cleanrl/tqc_td3_jax.py +++ b/cleanrl/tqc_td3_jax.py @@ -72,6 +72,7 @@ def parse_args(): parser.add_argument("--verbose", type=int, default=1) # top quantiles to drop per net parser.add_argument("-t", "--top-quantiles-to-drop-per-net", type=int, default=2) + parser.add_argument("--n-units", type=int, default=256) parser.add_argument("--noise-clip", type=float, default=0.5, help="noise clip parameter of the Target Policy Smoothing Regularization") @@ -100,17 +101,18 @@ class QNetwork(nn.Module): use_layer_norm: bool = False dropout_rate: Optional[float] = None n_quantiles: int = 25 + n_units: int = 256 @nn.compact def __call__(self, x: jnp.ndarray, a: jnp.ndarray, training: bool = False): x = jnp.concatenate([x, a], -1) - x = nn.Dense(256)(x) + x = nn.Dense(self.n_units)(x) if self.dropout_rate is not None and self.dropout_rate > 0: x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) if self.use_layer_norm: x = nn.LayerNorm()(x) x = nn.relu(x) - x = nn.Dense(256)(x) + x = nn.Dense(self.n_units)(x) if self.dropout_rate is not None and self.dropout_rate > 0: x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) if self.use_layer_norm: @@ -124,12 +126,13 @@ class Actor(nn.Module): action_dim: Sequence[int] action_scale: 
float action_bias: float + n_units: int = 256 @nn.compact def __call__(self, x): - x = nn.Dense(256)(x) + x = nn.Dense(self.n_units)(x) x = nn.relu(x) - x = nn.Dense(256)(x) + x = nn.Dense(self.n_units)(x) x = nn.relu(x) x = nn.Dense(self.action_dim)(x) x = nn.tanh(x) @@ -214,6 +217,7 @@ def main(): action_dim=np.prod(envs.action_space.shape), action_scale=(max_action - min_action) / 2.0, action_bias=(max_action + min_action) / 2.0, + n_units=args.n_units, ) actor_state = RLTrainState.create( apply_fn=actor.apply, @@ -221,7 +225,9 @@ def main(): target_params=actor.init(actor_key, obs), tx=optax.adam(learning_rate=args.learning_rate), ) - qf = QNetwork(dropout_rate=args.dropout_rate, use_layer_norm=args.layer_norm) + qf = QNetwork( + dropout_rate=args.dropout_rate, use_layer_norm=args.layer_norm, n_units=args.n_units, n_quantiles=n_quantiles + ) qf1_state = RLTrainState.create( apply_fn=qf.apply, params=qf.init( From aabf78968bfe3e49322d8872cd1bcf1a704a71e3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 17 Sep 2022 11:24:26 +0200 Subject: [PATCH 16/31] Add train method --- cleanrl/tqc_td3_jax.py | 87 ++++++++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 38 deletions(-) diff --git a/cleanrl/tqc_td3_jax.py b/cleanrl/tqc_td3_jax.py index 6d9934fea..f66557280 100644 --- a/cleanrl/tqc_td3_jax.py +++ b/cleanrl/tqc_td3_jax.py @@ -4,6 +4,7 @@ import random import time from distutils.util import strtobool +from functools import partial from typing import Optional, Sequence import flax @@ -318,21 +319,16 @@ def update_critic( def huber_quantile_loss(params, noise_key): # Compute huber quantile loss current_quantiles = qf.apply(params, observations, actions, True, rngs={"dropout": noise_key}) - # convert to shape: (batch_size, n_quantiles, 1) + # convert to shape: (batch_size, n_quantiles, 1) for broadcast current_quantiles = jnp.expand_dims(current_quantiles, axis=-1) # Cumulative probabilities to calculate quantiles. 
# shape: (n_quantiles,) cum_prob = (jnp.arange(n_quantiles, dtype=jnp.float32) + 0.5) / n_quantiles - # convert to shape: (1, n_quantiles, 1) + # convert to shape: (1, n_quantiles, 1) for broadcast cum_prob = jnp.expand_dims(cum_prob, axis=(0, -1)) - # TQC - # target_quantiles: (batch_size, 1, n_target_quantiles) -> (batch_size, 1, 1, n_target_quantiles) - # current_quantiles: (batch_size, n_critics, n_quantiles) -> (batch_size, n_critics, n_quantiles, 1) - # pairwise_delta: (batch_size, n_critics, n_quantiles, n_target_quantiles) - # Note: in both cases, the loss has the same shape as pairwise_delta - + # pairwise_delta: (batch_size, n_quantiles, n_target_quantiles) pairwise_delta = target_quantiles - current_quantiles abs_pairwise_delta = jnp.abs(pairwise_delta) huber_loss = jnp.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta**2 * 0.5) @@ -350,13 +346,14 @@ def huber_quantile_loss(params, noise_key): key, ) - @jax.jit + @partial(jax.jit, static_argnames=["update_actor"]) def update_actor( actor_state: RLTrainState, qf1_state: RLTrainState, qf2_state: RLTrainState, observations: np.ndarray, key: jnp.ndarray, + update_actor: bool, ): key, dropout_key = jax.random.split(key, 2) @@ -369,11 +366,15 @@ def actor_loss(params): rngs={"dropout": dropout_key}, ).mean() - actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params) - actor_state = actor_state.apply_gradients(grads=grads) - actor_state = actor_state.replace( - target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau) - ) + if update_actor: + actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params) + actor_state = actor_state.apply_gradients(grads=grads) + # TODO: check with and without updating target actor + actor_state = actor_state.replace( + target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau) + ) + else: + actor_loss_value = 0.0 qf1_state = qf1_state.replace( target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau) @@ -383,6 +384,33 @@ def actor_loss(params): ) return actor_state, (qf1_state, qf2_state), actor_loss_value, key + def train(n_updates: int, qf1_state, qf2_state, actor_state, key): + for _ in range(args.gradient_steps): + n_updates += 1 + data = rb.sample(args.batch_size) + + ((qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), key,) = update_critic( + actor_state, + qf1_state, + qf2_state, + data.observations.numpy(), + data.actions.numpy(), + data.next_observations.numpy(), + data.rewards.flatten().numpy(), + data.dones.flatten().numpy(), + key, + ) + + (actor_state, (qf1_state, qf2_state), actor_loss_value, key,) = update_actor( + actor_state, + qf1_state, + qf2_state, + data.observations.numpy(), + key, + update_actor=n_updates % args.policy_frequency, + ) + return n_updates, qf1_state, qf2_state, actor_state, key, (qf1_loss_value, qf2_loss_value, actor_loss_value) + start_time = time.time() n_updates = 0 for global_step in tqdm(range(args.total_timesteps)): @@ -433,30 +461,13 @@ def actor_loss(params): # ALGO LOGIC: training. 
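# Note: the per-step gradient updates now happen inside the train() helper defined
# above in this patch (args.gradient_steps critic updates plus the delayed actor update).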
if global_step > args.learning_starts: - for _ in range(args.gradient_steps): - n_updates += 1 - data = rb.sample(args.batch_size) - - ((qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), key,) = update_critic( - actor_state, - qf1_state, - qf2_state, - data.observations.numpy(), - data.actions.numpy(), - data.next_observations.numpy(), - data.rewards.flatten().numpy(), - data.dones.flatten().numpy(), - key, - ) - - if n_updates % args.policy_frequency == 0: - (actor_state, (qf1_state, qf2_state), actor_loss_value, key,) = update_actor( - actor_state, - qf1_state, - qf2_state, - data.observations.numpy(), - key, - ) + n_updates, qf1_state, qf2_state, actor_state, key, (qf1_loss_value, qf2_loss_value, actor_loss_value) = train( + n_updates, + qf1_state, + qf2_state, + actor_state, + key, + ) fps = int(global_step / (time.time() - start_time)) if args.eval_freq > 0 and global_step % args.eval_freq == 0: From cc74d9e973a4b2699f5cdcf0860f23cb9ff56710 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 17 Sep 2022 12:24:48 +0200 Subject: [PATCH 17/31] JIT train loop --- cleanrl/tqc_td3_jax.py | 53 +++++++++++++++++++++++++++++++++--------- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/cleanrl/tqc_td3_jax.py b/cleanrl/tqc_td3_jax.py index f66557280..cf581a35d 100644 --- a/cleanrl/tqc_td3_jax.py +++ b/cleanrl/tqc_td3_jax.py @@ -5,7 +5,7 @@ import time from distutils.util import strtobool from functools import partial -from typing import Optional, Sequence +from typing import NamedTuple, Optional, Sequence import flax import flax.linen as nn @@ -22,6 +22,14 @@ from tqdm.rich import tqdm +class ReplayBufferSamplesNp(NamedTuple): + observations: np.ndarray + actions: np.ndarray + next_observations: np.ndarray + dones: np.ndarray + rewards: np.ndarray + + def parse_args(): # fmt: off parser = argparse.ArgumentParser() @@ -384,30 +392,41 @@ def actor_loss(params): ) return actor_state, (qf1_state, qf2_state), actor_loss_value, key - def train(n_updates: int, qf1_state, qf2_state, actor_state, key): - for _ in range(args.gradient_steps): + @jax.jit + def train(data: ReplayBufferSamplesNp, n_updates: int, qf1_state, qf2_state, actor_state, key): + for i in range(args.gradient_steps): n_updates += 1 - data = rb.sample(args.batch_size) + + def slice(x): + assert x.shape[0] % args.gradient_steps == 0 + batch_size = args.batch_size + batch_size = x.shape[0] // args.gradient_steps + return x[batch_size * i : batch_size * (i + 1)] ((qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), key,) = update_critic( actor_state, qf1_state, qf2_state, - data.observations.numpy(), - data.actions.numpy(), - data.next_observations.numpy(), - data.rewards.flatten().numpy(), - data.dones.flatten().numpy(), + slice(data.observations), + slice(data.actions), + slice(data.next_observations), + slice(data.rewards), + slice(data.dones), key, ) + # sanity check + # otherwise must use update_actor=n_updates % args.policy_frequency, + # which is not jitable + assert args.policy_frequency <= args.gradient_steps (actor_state, (qf1_state, qf2_state), actor_loss_value, key,) = update_actor( actor_state, qf1_state, qf2_state, - data.observations.numpy(), + slice(data.observations), key, - update_actor=n_updates % args.policy_frequency, + update_actor=((i + 1) % args.policy_frequency) == 0, + # update_actor=(n_updates % args.policy_frequency) == 0, ) return n_updates, qf1_state, qf2_state, actor_state, key, (qf1_loss_value, qf2_loss_value, actor_loss_value) @@ -461,7 +480,19 @@ def train(n_updates: int, 
qf1_state, qf2_state, actor_state, key): # ALGO LOGIC: training. if global_step > args.learning_starts: + # Sample all at once for efficiency (so we can jit the for loop) + data = rb.sample(args.batch_size * args.gradient_steps) + # Convert to numpy + data = ReplayBufferSamplesNp( + data.observations.numpy(), + data.actions.numpy(), + data.next_observations.numpy(), + data.dones.numpy().flatten(), + data.rewards.numpy().flatten(), + ) + n_updates, qf1_state, qf2_state, actor_state, key, (qf1_loss_value, qf2_loss_value, actor_loss_value) = train( + data, n_updates, qf1_state, qf2_state, From 80589795d7b19e87bb733e212d19fb56343df5ee Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Sep 2022 19:00:12 +0200 Subject: [PATCH 18/31] Debug jit --- cleanrl/tqc_td3_jax.py | 62 ++++++++++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/cleanrl/tqc_td3_jax.py b/cleanrl/tqc_td3_jax.py index cf581a35d..bd616c547 100644 --- a/cleanrl/tqc_td3_jax.py +++ b/cleanrl/tqc_td3_jax.py @@ -354,14 +354,14 @@ def huber_quantile_loss(params, noise_key): key, ) - @partial(jax.jit, static_argnames=["update_actor"]) + # @partial(jax.jit, static_argnames=["update_actor"]) + @jax.jit def update_actor( actor_state: RLTrainState, qf1_state: RLTrainState, qf2_state: RLTrainState, observations: np.ndarray, key: jnp.ndarray, - update_actor: bool, ): key, dropout_key = jax.random.split(key, 2) @@ -374,15 +374,12 @@ def actor_loss(params): rngs={"dropout": dropout_key}, ).mean() - if update_actor: - actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params) - actor_state = actor_state.apply_gradients(grads=grads) - # TODO: check with and without updating target actor - actor_state = actor_state.replace( - target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau) - ) - else: - actor_loss_value = 0.0 + actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params) + actor_state = actor_state.apply_gradients(grads=grads) + # TODO: check with and without updating target actor + actor_state = actor_state.replace( + target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau) + ) qf1_state = qf1_state.replace( target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau) @@ -418,17 +415,24 @@ def slice(x): # sanity check # otherwise must use update_actor=n_updates % args.policy_frequency, # which is not jitable - assert args.policy_frequency <= args.gradient_steps - (actor_state, (qf1_state, qf2_state), actor_loss_value, key,) = update_actor( - actor_state, - qf1_state, - qf2_state, - slice(data.observations), - key, - update_actor=((i + 1) % args.policy_frequency) == 0, - # update_actor=(n_updates % args.policy_frequency) == 0, - ) - return n_updates, qf1_state, qf2_state, actor_state, key, (qf1_loss_value, qf2_loss_value, actor_loss_value) + # assert args.policy_frequency <= args.gradient_steps + (actor_state, (qf1_state, qf2_state), actor_loss_value, key,) = update_actor( + actor_state, + qf1_state, + qf2_state, + slice(data.observations), + key, + # update_actor=((i + 1) % args.policy_frequency) == 0, + # update_actor=(n_updates % args.policy_frequency) == 0, + ) + return ( + n_updates, + qf1_state, + qf2_state, + actor_state, + key, + (qf1_loss_value, qf2_loss_value, actor_loss_value, slice(data.rewards)), + ) start_time = time.time() n_updates = 0 @@ -491,7 +495,14 @@ def slice(x): data.rewards.numpy().flatten(), ) - n_updates, qf1_state, 
qf2_state, actor_state, key, (qf1_loss_value, qf2_loss_value, actor_loss_value) = train( + ( + n_updates, + qf1_state, + qf2_state, + actor_state, + key, + (qf1_loss_value, qf2_loss_value, actor_loss_value, reward), + ) = train( data, n_updates, qf1_state, @@ -500,6 +511,11 @@ def slice(x): key, ) + # print(global_step) + # print(reward[:5]) + # if global_step > 10005: + # exit() + fps = int(global_step / (time.time() - start_time)) if args.eval_freq > 0 and global_step % args.eval_freq == 0: mean_reward, std_reward = evaluate_policy(eval_envs, actor, actor_state) From 99686c8abf8c9e1c6834541dc1b83f842058a7ec Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Sep 2022 19:20:33 +0200 Subject: [PATCH 19/31] Cleanup + faster eval --- cleanrl/tqc_td3_jax.py | 53 +++++++++++++++++------------------------- 1 file changed, 21 insertions(+), 32 deletions(-) diff --git a/cleanrl/tqc_td3_jax.py b/cleanrl/tqc_td3_jax.py index bd616c547..7832ada7a 100644 --- a/cleanrl/tqc_td3_jax.py +++ b/cleanrl/tqc_td3_jax.py @@ -4,7 +4,6 @@ import random import time from distutils.util import strtobool -from functools import partial from typing import NamedTuple, Optional, Sequence import flax @@ -18,6 +17,8 @@ from flax.training.train_state import TrainState from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.vec_env import DummyVecEnv +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.env_util import make_vec_env from torch.utils.tensorboard import SummaryWriter from tqdm.rich import tqdm @@ -76,7 +77,7 @@ def parse_args(): parser.add_argument("--policy-frequency", type=int, default=2, help="the frequency of training policy (delayed)") parser.add_argument("--eval-freq", type=int, default=-1) - parser.add_argument("--n-eval-envs", type=int, default=1) + parser.add_argument("--n-eval-envs", type=int, default=5) parser.add_argument("--n-eval-episodes", type=int, default=10) parser.add_argument("--verbose", type=int, default=1) # top quantiles to drop per net @@ -149,22 +150,18 @@ def __call__(self, x): return x -class RLTrainState(TrainState): - target_params: flax.core.FrozenDict +class Agent: + def __init__(self, actor, actor_state) -> None: + self.actor = actor + self.actor_state = actor_state + def predict(self, obervations: np.ndarray, deterministic=True, state=None, episode_start=None): + actions = np.array(self.actor.apply(self.actor_state.params, obervations)) + return actions, None -def evaluate_policy(eval_env, actor, actor_state, n_eval_episodes: int = 10): - eval_episode_rewards = [] - for _ in range(n_eval_episodes): - obs = eval_env.reset() - done = False - episode_reward = 0 - while not done: - action = np.array(actor.apply(actor_state.params, obs)) - obs, reward, done, _ = eval_env.step(action) - episode_reward += reward - eval_episode_rewards.append(episode_reward) - return np.mean(eval_episode_rewards), np.std(eval_episode_rewards) + +class RLTrainState(TrainState): + target_params: flax.core.FrozenDict def main(): @@ -197,7 +194,7 @@ def main(): # env setup envs = DummyVecEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) - eval_envs = DummyVecEnv([make_env(args.env_id, args.seed + 1, 0)]) + eval_envs = make_vec_env(args.env_id, n_envs=args.n_eval_envs, seed=args.seed) assert isinstance(envs.action_space, gym.spaces.Box), "only continuous action space is supported" @@ -234,6 +231,9 @@ def main(): target_params=actor.init(actor_key, obs), tx=optax.adam(learning_rate=args.learning_rate), ) 
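+ # The Agent wrapper defined above only mimics the predict() interface expected by
+ # stable-baselines3's evaluate_policy(), returning (actions, state), so the JAX actor
+ # can be scored by the vectorized eval envs directly.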
+ + agent = Agent(actor, actor_state) + qf = QNetwork( dropout_rate=args.dropout_rate, use_layer_norm=args.layer_norm, n_units=args.n_units, n_quantiles=n_quantiles ) @@ -431,7 +431,7 @@ def slice(x): qf2_state, actor_state, key, - (qf1_loss_value, qf2_loss_value, actor_loss_value, slice(data.rewards)), + (qf1_loss_value, qf2_loss_value, actor_loss_value), ) start_time = time.time() @@ -495,14 +495,7 @@ def slice(x): data.rewards.numpy().flatten(), ) - ( - n_updates, - qf1_state, - qf2_state, - actor_state, - key, - (qf1_loss_value, qf2_loss_value, actor_loss_value, reward), - ) = train( + (n_updates, qf1_state, qf2_state, actor_state, key, (qf1_loss_value, qf2_loss_value, actor_loss_value),) = train( data, n_updates, qf1_state, @@ -511,14 +504,10 @@ def slice(x): key, ) - # print(global_step) - # print(reward[:5]) - # if global_step > 10005: - # exit() - fps = int(global_step / (time.time() - start_time)) if args.eval_freq > 0 and global_step % args.eval_freq == 0: - mean_reward, std_reward = evaluate_policy(eval_envs, actor, actor_state) + agent.actor, agent.actor_state = actor, actor_state + mean_reward, std_reward = evaluate_policy(agent, eval_envs, n_eval_episodes=args.n_eval_episodes) print(f"global_step={global_step}, mean_eval_reward={mean_reward:.2f} +/- {std_reward:.2f} - {fps} fps") writer.add_scalar("charts/mean_eval_reward", mean_reward, global_step) writer.add_scalar("charts/std_eval_reward", std_reward, global_step) From d5704b301954d60c38075d1d95a7cf15f64f08c3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Sep 2022 19:52:47 +0200 Subject: [PATCH 20/31] Try ADAN --- cleanrl/tqc_td3_jax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cleanrl/tqc_td3_jax.py b/cleanrl/tqc_td3_jax.py index 7832ada7a..afaa01b68 100644 --- a/cleanrl/tqc_td3_jax.py +++ b/cleanrl/tqc_td3_jax.py @@ -229,7 +229,7 @@ def main(): apply_fn=actor.apply, params=actor.init(actor_key, obs), target_params=actor.init(actor_key, obs), - tx=optax.adam(learning_rate=args.learning_rate), + tx=optax.adan(learning_rate=args.learning_rate), ) agent = Agent(actor, actor_state) @@ -249,7 +249,7 @@ def main(): obs, jnp.array([envs.action_space.sample()]), ), - tx=optax.adam(learning_rate=args.learning_rate), + tx=optax.adan(learning_rate=args.learning_rate), ) qf2_state = RLTrainState.create( apply_fn=qf.apply, @@ -263,7 +263,7 @@ def main(): obs, jnp.array([envs.action_space.sample()]), ), - tx=optax.adam(learning_rate=args.learning_rate), + tx=optax.adan(learning_rate=args.learning_rate), ) actor.apply = jax.jit(actor.apply) qf.apply = jax.jit(qf.apply, static_argnames=("dropout_rate", "use_layer_norm")) From 047c3146f394ee4194b994ad41b10fc99f039bdd Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Sep 2022 21:11:34 +0200 Subject: [PATCH 21/31] Revert "Try ADAN" This reverts commit d5704b301954d60c38075d1d95a7cf15f64f08c3. 
--- cleanrl/tqc_td3_jax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cleanrl/tqc_td3_jax.py b/cleanrl/tqc_td3_jax.py index afaa01b68..7832ada7a 100644 --- a/cleanrl/tqc_td3_jax.py +++ b/cleanrl/tqc_td3_jax.py @@ -229,7 +229,7 @@ def main(): apply_fn=actor.apply, params=actor.init(actor_key, obs), target_params=actor.init(actor_key, obs), - tx=optax.adan(learning_rate=args.learning_rate), + tx=optax.adam(learning_rate=args.learning_rate), ) agent = Agent(actor, actor_state) @@ -249,7 +249,7 @@ def main(): obs, jnp.array([envs.action_space.sample()]), ), - tx=optax.adan(learning_rate=args.learning_rate), + tx=optax.adam(learning_rate=args.learning_rate), ) qf2_state = RLTrainState.create( apply_fn=qf.apply, @@ -263,7 +263,7 @@ def main(): obs, jnp.array([envs.action_space.sample()]), ), - tx=optax.adan(learning_rate=args.learning_rate), + tx=optax.adam(learning_rate=args.learning_rate), ) actor.apply = jax.jit(actor.apply) qf.apply = jax.jit(qf.apply, static_argnames=("dropout_rate", "use_layer_norm")) From 443dc712d041f02b85b1f744b19dd8d04bb44305 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Sep 2022 21:37:30 +0200 Subject: [PATCH 22/31] Revert "Revert "Try ADAN"" This reverts commit 047c3146f394ee4194b994ad41b10fc99f039bdd. --- cleanrl/tqc_td3_jax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cleanrl/tqc_td3_jax.py b/cleanrl/tqc_td3_jax.py index 7832ada7a..afaa01b68 100644 --- a/cleanrl/tqc_td3_jax.py +++ b/cleanrl/tqc_td3_jax.py @@ -229,7 +229,7 @@ def main(): apply_fn=actor.apply, params=actor.init(actor_key, obs), target_params=actor.init(actor_key, obs), - tx=optax.adam(learning_rate=args.learning_rate), + tx=optax.adan(learning_rate=args.learning_rate), ) agent = Agent(actor, actor_state) @@ -249,7 +249,7 @@ def main(): obs, jnp.array([envs.action_space.sample()]), ), - tx=optax.adam(learning_rate=args.learning_rate), + tx=optax.adan(learning_rate=args.learning_rate), ) qf2_state = RLTrainState.create( apply_fn=qf.apply, @@ -263,7 +263,7 @@ def main(): obs, jnp.array([envs.action_space.sample()]), ), - tx=optax.adam(learning_rate=args.learning_rate), + tx=optax.adan(learning_rate=args.learning_rate), ) actor.apply = jax.jit(actor.apply) qf.apply = jax.jit(qf.apply, static_argnames=("dropout_rate", "use_layer_norm")) From 8f3beece630d0ef16ac9db0c2b8960c2061f8d62 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Sep 2022 21:37:54 +0200 Subject: [PATCH 23/31] Sort important and Try ADAN again This reverts commit 047c3146f394ee4194b994ad41b10fc99f039bdd. 
--- cleanrl/tqc_td3_jax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cleanrl/tqc_td3_jax.py b/cleanrl/tqc_td3_jax.py index afaa01b68..6cc57da07 100644 --- a/cleanrl/tqc_td3_jax.py +++ b/cleanrl/tqc_td3_jax.py @@ -16,9 +16,9 @@ import pybullet_envs # noqa from flax.training.train_state import TrainState from stable_baselines3.common.buffers import ReplayBuffer -from stable_baselines3.common.vec_env import DummyVecEnv -from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.vec_env import DummyVecEnv from torch.utils.tensorboard import SummaryWriter from tqdm.rich import tqdm From 940a4b66d7aa9cbc53498710e26f577ed20f0e16 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sun, 18 Sep 2022 22:29:06 +0200 Subject: [PATCH 24/31] Back to ADAM --- cleanrl/tqc_td3_jax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cleanrl/tqc_td3_jax.py b/cleanrl/tqc_td3_jax.py index 6cc57da07..8193d72ae 100644 --- a/cleanrl/tqc_td3_jax.py +++ b/cleanrl/tqc_td3_jax.py @@ -229,7 +229,7 @@ def main(): apply_fn=actor.apply, params=actor.init(actor_key, obs), target_params=actor.init(actor_key, obs), - tx=optax.adan(learning_rate=args.learning_rate), + tx=optax.adam(learning_rate=args.learning_rate), ) agent = Agent(actor, actor_state) @@ -249,7 +249,7 @@ def main(): obs, jnp.array([envs.action_space.sample()]), ), - tx=optax.adan(learning_rate=args.learning_rate), + tx=optax.adam(learning_rate=args.learning_rate), ) qf2_state = RLTrainState.create( apply_fn=qf.apply, @@ -263,7 +263,7 @@ def main(): obs, jnp.array([envs.action_space.sample()]), ), - tx=optax.adan(learning_rate=args.learning_rate), + tx=optax.adam(learning_rate=args.learning_rate), ) actor.apply = jax.jit(actor.apply) qf.apply = jax.jit(qf.apply, static_argnames=("dropout_rate", "use_layer_norm")) From bcfee189c73f7820ecd362237cde773629a647a9 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 19 Sep 2022 19:29:31 +0200 Subject: [PATCH 25/31] Rename file --- cleanrl/{td3_droq_continuous_action_jax.py => td3_droq_jax.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename cleanrl/{td3_droq_continuous_action_jax.py => td3_droq_jax.py} (100%) diff --git a/cleanrl/td3_droq_continuous_action_jax.py b/cleanrl/td3_droq_jax.py similarity index 100% rename from cleanrl/td3_droq_continuous_action_jax.py rename to cleanrl/td3_droq_jax.py From d68b262c30b2859bae5f94c7bf95f4af86eaa826 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 19 Sep 2022 19:34:55 +0200 Subject: [PATCH 26/31] Add fast eval for TD3 + DroQo --- cleanrl/td3_droq_jax.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/cleanrl/td3_droq_jax.py b/cleanrl/td3_droq_jax.py index e9f92be5d..c482fb612 100644 --- a/cleanrl/td3_droq_jax.py +++ b/cleanrl/td3_droq_jax.py @@ -16,6 +16,8 @@ import pybullet_envs # noqa from flax.training.train_state import TrainState from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv from torch.utils.tensorboard import SummaryWriter from tqdm.rich import tqdm @@ -67,7 +69,7 @@ def parse_args(): parser.add_argument("--policy-frequency", type=int, default=2, help="the frequency of training policy 
(delayed)") parser.add_argument("--eval-freq", type=int, default=-1) - parser.add_argument("--n-eval-envs", type=int, default=1) + parser.add_argument("--n-eval-envs", type=int, default=5) parser.add_argument("--n-eval-episodes", type=int, default=10) parser.add_argument("--verbose", type=int, default=1) @@ -134,22 +136,18 @@ def __call__(self, x): return x -class RLTrainState(TrainState): - target_params: flax.core.FrozenDict +class Agent: + def __init__(self, actor, actor_state) -> None: + self.actor = actor + self.actor_state = actor_state + def predict(self, obervations: np.ndarray, deterministic=True, state=None, episode_start=None): + actions = np.array(self.actor.apply(self.actor_state.params, obervations)) + return actions, None -def evaluate_policy(eval_env, actor, actor_state, n_eval_episodes: int = 10): - eval_episode_rewards = [] - for _ in range(n_eval_episodes): - obs = eval_env.reset() - done = False - episode_reward = 0 - while not done: - action = np.array(actor.apply(actor_state.params, obs)) - obs, reward, done, _ = eval_env.step(action) - episode_reward += reward - eval_episode_rewards.append(episode_reward) - return np.mean(eval_episode_rewards), np.std(eval_episode_rewards) + +class RLTrainState(TrainState): + target_params: flax.core.FrozenDict def main(): @@ -182,7 +180,7 @@ def main(): # env setup envs = DummyVecEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) - eval_envs = DummyVecEnv([make_env(args.env_id, args.seed + 1, 0)]) + eval_envs = make_vec_env(args.env_id, n_envs=args.n_eval_envs, seed=args.seed) assert isinstance(envs.action_space, gym.spaces.Box), "only continuous action space is supported" @@ -211,6 +209,9 @@ def main(): target_params=actor.init(actor_key, obs), tx=optax.adam(learning_rate=args.learning_rate), ) + + agent = Agent(actor, actor_state) + qf = QNetwork(dropout_rate=args.dropout_rate, use_layer_norm=args.layer_norm) qf1_state = RLTrainState.create( apply_fn=qf.apply, @@ -412,7 +413,8 @@ def actor_loss(params): fps = int(global_step / (time.time() - start_time)) if args.eval_freq > 0 and global_step % args.eval_freq == 0: - mean_reward, std_reward = evaluate_policy(eval_envs, actor, actor_state) + agent.actor, agent.actor_state = actor, actor_state + mean_reward, std_reward = evaluate_policy(agent, eval_envs, n_eval_episodes=args.n_eval_episodes) print(f"global_step={global_step}, mean_eval_reward={mean_reward:.2f} +/- {std_reward:.2f} - {fps} fps") writer.add_scalar("charts/mean_eval_reward", mean_reward, global_step) writer.add_scalar("charts/std_eval_reward", std_reward, global_step) From 70aa57d0a09da9602ba9c8e3ad9e63c61b60d7c4 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 23 Sep 2022 19:16:26 +0200 Subject: [PATCH 27/31] Add buggy sac implementation --- cleanrl/tqc_sac_jax.py | 581 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 581 insertions(+) create mode 100644 cleanrl/tqc_sac_jax.py diff --git a/cleanrl/tqc_sac_jax.py b/cleanrl/tqc_sac_jax.py new file mode 100644 index 000000000..5e06c154f --- /dev/null +++ b/cleanrl/tqc_sac_jax.py @@ -0,0 +1,581 @@ +# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_action_jaxpy +import argparse +import os +import random +import time +from distutils.util import strtobool +from typing import NamedTuple, Optional, Sequence, Any + +import flax +import flax.linen as nn +import gym +import jax +import jax.numpy as jnp +import numpy as np +import optax +import pybullet_envs # noqa +from 
flax.training.train_state import TrainState +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.vec_env import DummyVecEnv +from torch.utils.tensorboard import SummaryWriter +from tqdm.rich import tqdm +import tensorflow_probability + +tfp = tensorflow_probability.substrates.jax +tfd = tfp.distributions + + +class ReplayBufferSamplesNp(NamedTuple): + observations: np.ndarray + actions: np.ndarray + next_observations: np.ndarray + dones: np.ndarray + rewards: np.ndarray + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser() + parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), + help="the name of this experiment") + parser.add_argument("--seed", type=int, default=1, + help="seed of the experiment") + parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="if toggled, this experiment will be tracked with Weights and Biases") + parser.add_argument("--wandb-project-name", type=str, default="cleanRL", + help="the wandb's project name") + parser.add_argument("--wandb-entity", type=str, default=None, + help="the entity (team) of wandb's project") + parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, + help="weather to capture videos of the agent performances (check out `videos` folder)") + + # Algorithm specific arguments + parser.add_argument("--env-id", type=str, default="HalfCheetah-v2", + help="the id of the environment") + parser.add_argument("-n", "--total-timesteps", type=int, default=1000000, + help="total timesteps of the experiments") + parser.add_argument("-lr", "--learning-rate", type=float, default=3e-4, + help="the learning rate of the optimizer") + parser.add_argument("--buffer-size", type=int, default=int(1e6), + help="the replay memory buffer size") + parser.add_argument("--gamma", type=float, default=0.99, + help="the discount factor gamma") + parser.add_argument("--tau", type=float, default=0.005, + help="target smoothing coefficient (default: 0.005)") + parser.add_argument("--batch-size", type=int, default=256, + help="the batch size of sample from the reply memory") + parser.add_argument("--learning-starts", type=int, default=1000, + help="timestep to start learning") + parser.add_argument("--gradient-steps", type=int, default=1, + help="Number of gradient steps to perform after each rollout") + # Argument for dropout rate + parser.add_argument("--dropout-rate", type=float, default=0.0) + # Argument for layer normalization + parser.add_argument("--layer-norm", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) + parser.add_argument("--policy-frequency", type=int, default=2, + help="the frequency of training policy (delayed)") + parser.add_argument("--eval-freq", type=int, default=-1) + parser.add_argument("--n-eval-envs", type=int, default=5) + parser.add_argument("--n-eval-episodes", type=int, default=10) + parser.add_argument("--verbose", type=int, default=1) + # top quantiles to drop per net + parser.add_argument("-t", "--top-quantiles-to-drop-per-net", type=int, default=2) + parser.add_argument("--n-units", type=int, default=256) + args = parser.parse_args() + # fmt: on + return args + + +def make_env(env_id, seed, idx, capture_video=False, run_name=""): + def thunk(): + env = gym.make(env_id) + env = 
gym.wrappers.RecordEpisodeStatistics(env) + if capture_video: + if idx == 0: + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + env.seed(seed) + env.action_space.seed(seed) + env.observation_space.seed(seed) + return env + + return thunk + + +# from https://github.com/ikostrikov/walk_in_the_park +# otherwise mode is not define for Squashed Gaussian +class TanhTransformedDistribution(tfd.TransformedDistribution): + def __init__(self, distribution: tfd.Distribution, validate_args: bool = False): + super().__init__(distribution=distribution, bijector=tfp.bijectors.Tanh(), validate_args=validate_args) + + def mode(self) -> jnp.ndarray: + return self.bijector.forward(self.distribution.mode()) + + @classmethod + def _parameter_properties(cls, dtype: Optional[Any], num_classes=None): + td_properties = super()._parameter_properties(dtype, num_classes=num_classes) + del td_properties["bijector"] + return td_properties + + +class Temperature(nn.Module): + initial_temperature: float = 1.0 + + @nn.compact + def __call__(self) -> jnp.ndarray: + log_temp = self.param("log_temp", init_fn=lambda key: jnp.full((), jnp.log(self.initial_temperature))) + return jnp.exp(log_temp) + + +# ALGO LOGIC: initialize agent here: +class QNetwork(nn.Module): + use_layer_norm: bool = False + dropout_rate: Optional[float] = None + n_quantiles: int = 25 + n_units: int = 256 + + @nn.compact + def __call__(self, x: jnp.ndarray, a: jnp.ndarray, training: bool = False): + x = jnp.concatenate([x, a], -1) + x = nn.Dense(self.n_units)(x) + if self.dropout_rate is not None and self.dropout_rate > 0: + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) + if self.use_layer_norm: + x = nn.LayerNorm()(x) + x = nn.relu(x) + x = nn.Dense(self.n_units)(x) + if self.dropout_rate is not None and self.dropout_rate > 0: + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) + if self.use_layer_norm: + x = nn.LayerNorm()(x) + x = nn.relu(x) + x = nn.Dense(self.n_quantiles)(x) + return x + + +class Actor(nn.Module): + action_dim: Sequence[int] + n_units: int = 256 + log_std_min: float = -20 + log_std_max: float = 2 + + @nn.compact + def __call__(self, x): + x = nn.Dense(self.n_units)(x) + x = nn.relu(x) + x = nn.Dense(self.n_units)(x) + x = nn.relu(x) + mean = nn.Dense(self.action_dim)(x) + log_std = nn.Dense(self.action_dim)(x) + log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) + dist = TanhTransformedDistribution( + tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), + ) + return dist + + +class Agent: + def __init__(self, actor, actor_state) -> None: + self.actor = actor + self.actor_state = actor_state + + def predict(self, obervations: np.ndarray, deterministic=True, state=None, episode_start=None): + actions = np.array(self.actor.apply(self.actor_state.params, obervations).mode()) + return actions, None + + +class RLTrainState(TrainState): + target_params: flax.core.FrozenDict + + +def main(): + args = parse_args() + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + if args.track: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=vars(args), + name=run_name, + monitor_gym=True, + save_code=True, + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # TRY NOT TO MODIFY: seeding + random.seed(args.seed) + 
np.random.seed(args.seed) + key = jax.random.PRNGKey(args.seed) + key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4) + key, dropout_key1, dropout_key2, ent_key = jax.random.split(key, 4) + + # env setup + envs = DummyVecEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) + eval_envs = make_vec_env(args.env_id, n_envs=args.n_eval_envs, seed=args.seed) + + assert isinstance(envs.action_space, gym.spaces.Box), "only continuous action space is supported" + + # Assume that all dimensions share the same bound + min_action = float(envs.action_space.low[0]) + max_action = float(envs.action_space.high[0]) + # For now assumed low=-1, high=1 + # TODO: handle any action space boundary + action_scale = ((max_action - min_action) / 2.0,) + action_bias = ((max_action + min_action) / 2.0,) + + envs.observation_space.dtype = np.float32 + rb = ReplayBuffer( + args.buffer_size, + envs.observation_space, + envs.action_space, + device="cpu", + handle_timeout_termination=True, + ) + + # Sort and drop top k quantiles to control overestimation. + n_quantiles = 25 + n_critics = 2 + quantiles_total = n_quantiles * n_critics + top_quantiles_to_drop_per_net = args.top_quantiles_to_drop_per_net + n_target_quantiles = quantiles_total - top_quantiles_to_drop_per_net * n_critics + + # TRY NOT TO MODIFY: start the game + obs = envs.reset() + actor = Actor( + action_dim=np.prod(envs.action_space.shape), + n_units=args.n_units, + ) + actor_state = TrainState.create( + apply_fn=actor.apply, + params=actor.init(actor_key, obs), + tx=optax.adam(learning_rate=args.learning_rate), + ) + + ent_coef_init = 0.1 + ent_coef = Temperature(ent_coef_init) + ent_coef_state = TrainState.create( + apply_fn=ent_coef.apply, params=ent_coef.init(ent_key)["params"], tx=optax.adam(learning_rate=args.learning_rate) + ) + + agent = Agent(actor, actor_state) + + qf = QNetwork( + dropout_rate=args.dropout_rate, use_layer_norm=args.layer_norm, n_units=args.n_units, n_quantiles=n_quantiles + ) + qf1_state = RLTrainState.create( + apply_fn=qf.apply, + params=qf.init( + {"params": qf1_key, "dropout": dropout_key1}, + obs, + jnp.array([envs.action_space.sample()]), + ), + target_params=qf.init( + {"params": qf1_key, "dropout": dropout_key1}, + obs, + jnp.array([envs.action_space.sample()]), + ), + tx=optax.adam(learning_rate=args.learning_rate), + ) + qf2_state = RLTrainState.create( + apply_fn=qf.apply, + params=qf.init( + {"params": qf2_key, "dropout": dropout_key2}, + obs, + jnp.array([envs.action_space.sample()]), + ), + target_params=qf.init( + {"params": qf2_key, "dropout": dropout_key2}, + obs, + jnp.array([envs.action_space.sample()]), + ), + tx=optax.adam(learning_rate=args.learning_rate), + ) + actor.apply = jax.jit(actor.apply) + qf.apply = jax.jit(qf.apply, static_argnames=("dropout_rate", "use_layer_norm")) + + @jax.jit + def update_critic( + actor_state: TrainState, + qf1_state: RLTrainState, + qf2_state: RLTrainState, + ent_coef_state: TrainState, + observations: np.ndarray, + actions: np.ndarray, + next_observations: np.ndarray, + rewards: np.ndarray, + dones: np.ndarray, + key: jnp.ndarray, + ): + # TODO Maybe pre-generate a lot of random keys + # also check https://jax.readthedocs.io/en/latest/jax.random.html + key, noise_key, dropout_key_1, dropout_key_2 = jax.random.split(key, 4) + key, dropout_key_3, dropout_key_4 = jax.random.split(key, 3) + # sample action from the actor + dist = actor.apply(actor_state.params, next_observations) + next_state_actions = dist.sample(seed=noise_key) + next_log_prob = 
dist.log_prob(next_state_actions) + + ent_coef_value = ent_coef.apply({"params": ent_coef_state.params}) + + + qf1_next_quantiles = qf.apply( + qf1_state.target_params, + next_observations, + next_state_actions, + True, + rngs={"dropout": dropout_key_1}, + ) + qf2_next_quantiles = qf.apply( + qf2_state.target_params, + next_observations, + next_state_actions, + True, + rngs={"dropout": dropout_key_2}, + ) + + # Concatenate quantiles from both critics to get a single tensor + # batch x quantiles + qf_next_quantiles = jnp.concatenate((qf1_next_quantiles, qf2_next_quantiles), axis=1) + + # sort next quantiles with jax + next_quantiles = jnp.sort(qf_next_quantiles) + # Keep only the quantiles we need + next_target_quantiles = next_quantiles[:, :n_target_quantiles] + + # td error + entropy term + next_target_quantiles = next_target_quantiles - ent_coef_value * next_log_prob.reshape(-1, 1) + target_quantiles = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * args.gamma * next_target_quantiles + + # Make target_quantiles broadcastable to (batch_size, n_quantiles, n_target_quantiles). + target_quantiles = jnp.expand_dims(target_quantiles, axis=1) + + def huber_quantile_loss(params, noise_key): + # Compute huber quantile loss + current_quantiles = qf.apply(params, observations, actions, True, rngs={"dropout": noise_key}) + # convert to shape: (batch_size, n_quantiles, 1) for broadcast + current_quantiles = jnp.expand_dims(current_quantiles, axis=-1) + + # Cumulative probabilities to calculate quantiles. + # shape: (n_quantiles,) + cum_prob = (jnp.arange(n_quantiles, dtype=jnp.float32) + 0.5) / n_quantiles + # convert to shape: (1, n_quantiles, 1) for broadcast + cum_prob = jnp.expand_dims(cum_prob, axis=(0, -1)) + + # pairwise_delta: (batch_size, n_quantiles, n_target_quantiles) + pairwise_delta = target_quantiles - current_quantiles + abs_pairwise_delta = jnp.abs(pairwise_delta) + huber_loss = jnp.where(abs_pairwise_delta > 1, abs_pairwise_delta - 0.5, pairwise_delta**2 * 0.5) + loss = jnp.abs(cum_prob - (pairwise_delta < 0).astype(jnp.float32)) * huber_loss + return loss.mean() + + qf1_loss_value, grads1 = jax.value_and_grad(huber_quantile_loss, has_aux=False)(qf1_state.params, dropout_key_3) + qf2_loss_value, grads2 = jax.value_and_grad(huber_quantile_loss, has_aux=False)(qf2_state.params, dropout_key_4) + qf1_state = qf1_state.apply_gradients(grads=grads1) + qf2_state = qf2_state.apply_gradients(grads=grads2) + + return ( + (qf1_state, qf2_state, ent_coef_state), + (qf1_loss_value, qf2_loss_value), + key, + ) + + # @partial(jax.jit, static_argnames=["update_actor"]) + @jax.jit + def update_actor( + actor_state: RLTrainState, + qf1_state: RLTrainState, + qf2_state: RLTrainState, + ent_coef_state: TrainState, + observations: np.ndarray, + key: jnp.ndarray, + ): + key, dropout_key, noise_key = jax.random.split(key, 3) + + def actor_loss(params): + + dist = actor.apply(actor_state.params, observations) + actor_actions = dist.sample(seed=noise_key) + log_prob = dist.log_prob(actions).reshape(-1, 1) + + qf_pi = ( + qf.apply( + qf1_state.params, + observations, + actor_actions, + True, + rngs={"dropout": dropout_key}, + ) + # .mean(axis=2) TODO: add second qf + .mean(axis=1, keepdims=True) + ) + ent_coef_value = ent_coef.apply({"params": ent_coef_state.params}) + return (ent_coef_value * log_prob - qf_pi).mean() + + actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params) + actor_state = actor_state.apply_gradients(grads=grads) + + qf1_state = qf1_state.replace( + 
target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau) + ) + qf2_state = qf2_state.replace( + target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau) + ) + return actor_state, (qf1_state, qf2_state), actor_loss_value, key + + @jax.jit + def train(data: ReplayBufferSamplesNp, n_updates: int, qf1_state, qf2_state, actor_state, ent_coef_state, key): + for i in range(args.gradient_steps): + n_updates += 1 + + def slice(x): + assert x.shape[0] % args.gradient_steps == 0 + batch_size = args.batch_size + batch_size = x.shape[0] // args.gradient_steps + return x[batch_size * i : batch_size * (i + 1)] + + ((qf1_state, qf2_state, ent_coef_state), (qf1_loss_value, qf2_loss_value), key,) = update_critic( + actor_state, + qf1_state, + qf2_state, + ent_coef_state, + slice(data.observations), + slice(data.actions), + slice(data.next_observations), + slice(data.rewards), + slice(data.dones), + key, + ) + + # sanity check + # otherwise must use update_actor=n_updates % args.policy_frequency, + # which is not jitable + # assert args.policy_frequency <= args.gradient_steps + (actor_state, (qf1_state, qf2_state), actor_loss_value, key,) = update_actor( + actor_state, + qf1_state, + qf2_state, + ent_coef_state, + slice(data.observations), + key, + # update_actor=((i + 1) % args.policy_frequency) == 0, + # update_actor=(n_updates % args.policy_frequency) == 0, + ) + return ( + n_updates, + qf1_state, + qf2_state, + actor_state, + ent_coef_state, + key, + (qf1_loss_value, qf2_loss_value, actor_loss_value), + ) + + start_time = time.time() + n_updates = 0 + for global_step in tqdm(range(args.total_timesteps)): + # for global_step in range(args.total_timesteps): + # ALGO LOGIC: put action logic here + if global_step < args.learning_starts: + actions = np.array([envs.action_space.sample() for _ in range(envs.num_envs)]) + else: + # TODO: JIT sampling? + key, exploration_key = jax.random.split(key, 2) + actions = np.array(actor.apply(actor_state.params, obs).sample(seed=exploration_key)) + + # TRY NOT TO MODIFY: execute the game and log data. + next_obs, rewards, dones, infos = envs.step(actions) + + # TRY NOT TO MODIFY: record rewards for plotting purposes + for info in infos: + if "episode" in info.keys(): + if args.verbose >= 2: + print(f"global_step={global_step + 1}, episodic_return={info['episode']['r']:.2f}") + writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) + writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) + break + + # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation` + real_next_obs = next_obs.copy() + for idx, done in enumerate(dones): + if done: + real_next_obs[idx] = infos[idx]["terminal_observation"] + # Timeout handling done inside the replay buffer + # if infos[idx].get("TimeLimit.truncated", False) == True: + # real_dones[idx] = False + + rb.add(obs, real_next_obs, actions, rewards, dones, infos) + + # TRY NOT TO MODIFY: CRUCIAL step easy to overlook + obs = next_obs + + # ALGO LOGIC: training. 
+ if global_step > args.learning_starts: + # Sample all at once for efficiency (so we can jit the for loop) + data = rb.sample(args.batch_size * args.gradient_steps) + # Convert to numpy + data = ReplayBufferSamplesNp( + data.observations.numpy(), + data.actions.numpy(), + data.next_observations.numpy(), + data.dones.numpy().flatten(), + data.rewards.numpy().flatten(), + ) + + ( + n_updates, + qf1_state, + qf2_state, + actor_state, + ent_coef_state, + key, + (qf1_loss_value, qf2_loss_value, actor_loss_value), + ) = train( + data, + n_updates, + qf1_state, + qf2_state, + actor_state, + ent_coef_state, + key, + ) + + fps = int(global_step / (time.time() - start_time)) + if args.eval_freq > 0 and global_step % args.eval_freq == 0: + agent.actor, agent.actor_state = actor, actor_state + mean_reward, std_reward = evaluate_policy(agent, eval_envs, n_eval_episodes=args.n_eval_episodes) + print(f"global_step={global_step}, mean_eval_reward={mean_reward:.2f} +/- {std_reward:.2f} - {fps} fps") + writer.add_scalar("charts/mean_eval_reward", mean_reward, global_step) + writer.add_scalar("charts/std_eval_reward", std_reward, global_step) + + if global_step % 100 == 0: + writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step) + writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step) + # writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step) + # writer.add_scalar("losses/qf2_values", qf2_a_values.item(), global_step) + writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step) + if args.verbose >= 2: + print("FPS:", fps) + writer.add_scalar( + "charts/SPS", + int(global_step / (time.time() - start_time)), + global_step, + ) + + envs.close() + writer.close() + + +if __name__ == "__main__": + + try: + main() + except KeyboardInterrupt: + pass From 21361c3c16cb7d6c6d85c01aa6b83bddeee77d1a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 23 Sep 2022 20:08:20 +0200 Subject: [PATCH 28/31] Bug fixes and faster sampling (still not working) --- cleanrl/tqc_sac_jax.py | 116 +++++++++++++++++++++++++++++++++++------ 1 file changed, 101 insertions(+), 15 deletions(-) diff --git a/cleanrl/tqc_sac_jax.py b/cleanrl/tqc_sac_jax.py index 5e06c154f..a86e1a5d1 100644 --- a/cleanrl/tqc_sac_jax.py +++ b/cleanrl/tqc_sac_jax.py @@ -3,8 +3,9 @@ import os import random import time +from dataclasses import dataclass from distutils.util import strtobool -from typing import NamedTuple, Optional, Sequence, Any +from typing import Any, NamedTuple, Optional, Sequence, Union import flax import flax.linen as nn @@ -14,6 +15,7 @@ import numpy as np import optax import pybullet_envs # noqa +import tensorflow_probability from flax.training.train_state import TrainState from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.env_util import make_vec_env @@ -21,7 +23,6 @@ from stable_baselines3.common.vec_env import DummyVecEnv from torch.utils.tensorboard import SummaryWriter from tqdm.rich import tqdm -import tensorflow_probability tfp = tensorflow_probability.substrates.jax tfd = tfp.distributions @@ -35,6 +36,57 @@ class ReplayBufferSamplesNp(NamedTuple): rewards: np.ndarray +class RescaleAction(gym.ActionWrapper): + """Affinely rescales the continuous action space of the environment to the range [min_action, max_action]. + + The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. 
If :attr:`min_action` + or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space. + + """ + + def __init__( + self, + env: gym.Env, + min_action: int = -1, + max_action: int = 1, + ): + """Initializes the :class:`RescaleAction` wrapper. + + Args: + env (Env): The environment to apply the wrapper + min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar. + max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar. + """ + assert isinstance(env.action_space, gym.spaces.Box), f"expected Box action space, got {type(env.action_space)}" + assert np.less_equal(min_action, max_action).all(), (min_action, max_action) + + super().__init__(env) + self.min_action = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action + self.max_action = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + max_action + self.action_space = gym.spaces.Box( + low=min_action, + high=max_action, + shape=env.action_space.shape, + dtype=env.action_space.dtype, + ) + + def action(self, action): + """Rescales the action affinely from [:attr:`min_action`, :attr:`max_action`] to the action space of the base environment, :attr:`env`. + + Args: + action: The action to rescale + + Returns: + The rescaled action + """ + action = np.clip(action, self.min_action, self.max_action) + low = self.env.action_space.low + high = self.env.action_space.high + action = low + (high - low) * ((action - self.min_action) / (self.max_action - self.min_action)) + action = np.clip(action, low, high) + return action + + def parse_args(): # fmt: off parser = argparse.ArgumentParser() @@ -91,6 +143,8 @@ def parse_args(): def make_env(env_id, seed, idx, capture_video=False, run_name=""): def thunk(): env = gym.make(env_id) + if env_id == "Pendulum-v1": + env = RescaleAction(env) env = gym.wrappers.RecordEpisodeStatistics(env) if capture_video: if idx == 0: @@ -175,13 +229,14 @@ def __call__(self, x): return dist +@dataclass class Agent: - def __init__(self, actor, actor_state) -> None: - self.actor = actor - self.actor_state = actor_state + actor: Actor + actor_state: TrainState def predict(self, obervations: np.ndarray, deterministic=True, state=None, episode_start=None): - actions = np.array(self.actor.apply(self.actor_state.params, obervations).mode()) + # actions = np.array(self.actor.apply(self.actor_state.params, obervations).mode()) + actions = np.array(self.select_action(self.actor_state, obervations)) return actions, None @@ -219,7 +274,7 @@ def main(): # env setup envs = DummyVecEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) - eval_envs = make_vec_env(args.env_id, n_envs=args.n_eval_envs, seed=args.seed) + eval_envs = make_vec_env(args.env_id, n_envs=args.n_eval_envs, seed=args.seed, wrapper_class=RescaleAction) assert isinstance(envs.action_space, gym.spaces.Box), "only continuous action space is supported" @@ -228,6 +283,7 @@ def main(): max_action = float(envs.action_space.high[0]) # For now assumed low=-1, high=1 # TODO: handle any action space boundary + action_scale = ((max_action - min_action) / 2.0,) action_bias = ((max_action + min_action) / 2.0,) @@ -247,6 +303,9 @@ def main(): top_quantiles_to_drop_per_net = args.top_quantiles_to_drop_per_net n_target_quantiles = quantiles_total - top_quantiles_to_drop_per_net * n_critics + # automatically set target entropy if needed + target_entropy = 
-np.prod(envs.action_space.shape).astype(np.float32) + # TRY NOT TO MODIFY: start the game obs = envs.reset() actor = Actor( @@ -259,7 +318,7 @@ def main(): tx=optax.adam(learning_rate=args.learning_rate), ) - ent_coef_init = 0.1 + ent_coef_init = 1.0 ent_coef = Temperature(ent_coef_init) ent_coef_state = TrainState.create( apply_fn=ent_coef.apply, params=ent_coef.init(ent_key)["params"], tx=optax.adam(learning_rate=args.learning_rate) @@ -301,6 +360,16 @@ def main(): actor.apply = jax.jit(actor.apply) qf.apply = jax.jit(qf.apply, static_argnames=("dropout_rate", "use_layer_norm")) + @jax.jit + def sample_action(actor_state, obervations, key): + return actor.apply(actor_state.params, obervations).sample(seed=key) + + @jax.jit + def select_action(actor_state, obervations): + return actor.apply(actor_state.params, obervations).mode() + + agent.select_action = select_action + @jax.jit def update_critic( actor_state: TrainState, @@ -325,7 +394,6 @@ def update_critic( ent_coef_value = ent_coef.apply({"params": ent_coef_state.params}) - qf1_next_quantiles = qf.apply( qf1_state.target_params, next_observations, @@ -401,7 +469,7 @@ def update_actor( def actor_loss(params): - dist = actor.apply(actor_state.params, observations) + dist = actor.apply(params, observations) actor_actions = dist.sample(seed=noise_key) log_prob = dist.log_prob(actions).reshape(-1, 1) @@ -417,18 +485,33 @@ def actor_loss(params): .mean(axis=1, keepdims=True) ) ent_coef_value = ent_coef.apply({"params": ent_coef_state.params}) - return (ent_coef_value * log_prob - qf_pi).mean() + return (ent_coef_value * log_prob - qf_pi).mean(), -log_prob.mean() - actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params) + (actor_loss_value, entropy), grads = jax.value_and_grad(actor_loss, has_aux=True)(actor_state.params) actor_state = actor_state.apply_gradients(grads=grads) + # TODO: move update to critic update qf1_state = qf1_state.replace( target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau) ) qf2_state = qf2_state.replace( target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau) ) - return actor_state, (qf1_state, qf2_state), actor_loss_value, key + return actor_state, (qf1_state, qf2_state), actor_loss_value, key, entropy + + @jax.jit + def update_temperature(ent_coef_state: TrainState, + entropy: float): + + def temperature_loss(temp_params): + ent_coef_value = ent_coef.apply({"params": temp_params}) + temp_loss = ent_coef_value * (entropy - target_entropy).mean() + return temp_loss + + ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params) + ent_coef_state = ent_coef_state.apply_gradients(grads=grads) + + return ent_coef_state, ent_coef_loss @jax.jit def train(data: ReplayBufferSamplesNp, n_updates: int, qf1_state, qf2_state, actor_state, ent_coef_state, key): @@ -458,7 +541,7 @@ def slice(x): # otherwise must use update_actor=n_updates % args.policy_frequency, # which is not jitable # assert args.policy_frequency <= args.gradient_steps - (actor_state, (qf1_state, qf2_state), actor_loss_value, key,) = update_actor( + (actor_state, (qf1_state, qf2_state), actor_loss_value, key, entropy) = update_actor( actor_state, qf1_state, qf2_state, @@ -468,6 +551,8 @@ def slice(x): # update_actor=((i + 1) % args.policy_frequency) == 0, # update_actor=(n_updates % args.policy_frequency) == 0, ) + ent_coef_state, _ = update_temperature(ent_coef_state, entropy) + return ( n_updates, qf1_state, @@ -488,7 +573,8 @@ def 
slice(x): else: # TODO: JIT sampling? key, exploration_key = jax.random.split(key, 2) - actions = np.array(actor.apply(actor_state.params, obs).sample(seed=exploration_key)) + # actions = np.array(actor.apply(actor_state.params, obs).sample(seed=exploration_key)) + actions = np.array(sample_action(actor_state, obs, exploration_key)) # TRY NOT TO MODIFY: execute the game and log data. next_obs, rewards, dones, infos = envs.step(actions) From c88338690eab7be9c02b63851d7b222bc8b08fa2 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 23 Sep 2022 21:38:32 +0200 Subject: [PATCH 29/31] Bug fixes, SAC now working --- cleanrl/tqc_sac_jax.py | 68 ++++++++++++++++++++---------------------- 1 file changed, 33 insertions(+), 35 deletions(-) diff --git a/cleanrl/tqc_sac_jax.py b/cleanrl/tqc_sac_jax.py index a86e1a5d1..79947319c 100644 --- a/cleanrl/tqc_sac_jax.py +++ b/cleanrl/tqc_sac_jax.py @@ -126,7 +126,7 @@ def parse_args(): parser.add_argument("--dropout-rate", type=float, default=0.0) # Argument for layer normalization parser.add_argument("--layer-norm", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True) - parser.add_argument("--policy-frequency", type=int, default=2, + parser.add_argument("--policy-frequency", type=int, default=1, help="the frequency of training policy (delayed)") parser.add_argument("--eval-freq", type=int, default=-1) parser.add_argument("--n-eval-envs", type=int, default=5) @@ -143,8 +143,7 @@ def parse_args(): def make_env(env_id, seed, idx, capture_video=False, run_name=""): def thunk(): env = gym.make(env_id) - if env_id == "Pendulum-v1": - env = RescaleAction(env) + # env = RescaleAction(env) env = gym.wrappers.RecordEpisodeStatistics(env) if capture_video: if idx == 0: @@ -160,12 +159,17 @@ def thunk(): # from https://github.com/ikostrikov/walk_in_the_park # otherwise mode is not define for Squashed Gaussian class TanhTransformedDistribution(tfd.TransformedDistribution): - def __init__(self, distribution: tfd.Distribution, validate_args: bool = False): + def __init__(self, distribution: tfd.Distribution, validate_args: bool = False, threshold=0.999): + self._threshold = threshold super().__init__(distribution=distribution, bijector=tfp.bijectors.Tanh(), validate_args=validate_args) def mode(self) -> jnp.ndarray: return self.bijector.forward(self.distribution.mode()) + # def log_prob(self, actions: jnp.ndarray) -> jnp.ndarray: + # actions = jnp.clip(actions, -self._threshold, self._threshold) + # return super().log_prob(actions) + @classmethod def _parameter_properties(cls, dtype: Optional[Any], num_classes=None): td_properties = super()._parameter_properties(dtype, num_classes=num_classes) @@ -173,13 +177,13 @@ def _parameter_properties(cls, dtype: Optional[Any], num_classes=None): return td_properties -class Temperature(nn.Module): - initial_temperature: float = 1.0 +class EntropyCoef(nn.Module): + ent_coef_init: float = 1.0 @nn.compact def __call__(self) -> jnp.ndarray: - log_temp = self.param("log_temp", init_fn=lambda key: jnp.full((), jnp.log(self.initial_temperature))) - return jnp.exp(log_temp) + log_ent_coef = self.param("log_ent_coef", init_fn=lambda key: jnp.full((), jnp.log(self.ent_coef_init))) + return jnp.exp(log_ent_coef) # ALGO LOGIC: initialize agent here: @@ -223,6 +227,7 @@ def __call__(self, x): mean = nn.Dense(self.action_dim)(x) log_std = nn.Dense(self.action_dim)(x) log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) + # dist = tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)) dist = 
TanhTransformedDistribution( tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), ) @@ -274,7 +279,7 @@ def main(): # env setup envs = DummyVecEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) - eval_envs = make_vec_env(args.env_id, n_envs=args.n_eval_envs, seed=args.seed, wrapper_class=RescaleAction) + eval_envs = make_vec_env(args.env_id, n_envs=args.n_eval_envs, seed=args.seed) # wrapper_class=RescaleAction assert isinstance(envs.action_space, gym.spaces.Box), "only continuous action space is supported" @@ -319,7 +324,7 @@ def main(): ) ent_coef_init = 1.0 - ent_coef = Temperature(ent_coef_init) + ent_coef = EntropyCoef(ent_coef_init) ent_coef_state = TrainState.create( apply_fn=ent_coef.apply, params=ent_coef.init(ent_key)["params"], tx=optax.adam(learning_rate=args.learning_rate) ) @@ -362,7 +367,9 @@ def main(): @jax.jit def sample_action(actor_state, obervations, key): - return actor.apply(actor_state.params, obervations).sample(seed=key) + dist = actor.apply(actor_state.params, obervations) + action = dist.sample(seed=key) + return action @jax.jit def select_action(actor_state, obervations): @@ -419,6 +426,7 @@ def update_critic( next_target_quantiles = next_quantiles[:, :n_target_quantiles] # td error + entropy term + # ent_coef_value = 0.0 next_target_quantiles = next_target_quantiles - ent_coef_value * next_log_prob.reshape(-1, 1) target_quantiles = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * args.gamma * next_target_quantiles @@ -471,20 +479,14 @@ def actor_loss(params): dist = actor.apply(params, observations) actor_actions = dist.sample(seed=noise_key) - log_prob = dist.log_prob(actions).reshape(-1, 1) - - qf_pi = ( - qf.apply( - qf1_state.params, - observations, - actor_actions, - True, - rngs={"dropout": dropout_key}, - ) - # .mean(axis=2) TODO: add second qf - .mean(axis=1, keepdims=True) - ) + log_prob = dist.log_prob(actor_actions).reshape(-1, 1) + + qf_pi = qf.apply(qf1_state.params, observations, actor_actions, True, rngs={"dropout": dropout_key},).mean( + axis=1, keepdims=True + ) # .mean(axis=2) TODO: add second qf + ent_coef_value = ent_coef.apply({"params": ent_coef_state.params}) + # ent_coef_value = 0.01 return (ent_coef_value * log_prob - qf_pi).mean(), -log_prob.mean() (actor_loss_value, entropy), grads = jax.value_and_grad(actor_loss, has_aux=True)(actor_state.params) @@ -500,13 +502,12 @@ def actor_loss(params): return actor_state, (qf1_state, qf2_state), actor_loss_value, key, entropy @jax.jit - def update_temperature(ent_coef_state: TrainState, - entropy: float): - + def update_temperature(ent_coef_state: TrainState, entropy: float): def temperature_loss(temp_params): ent_coef_value = ent_coef.apply({"params": temp_params}) - temp_loss = ent_coef_value * (entropy - target_entropy).mean() - return temp_loss + # ent_coef_loss = (jnp.log(ent_coef_value) * (entropy - target_entropy)).mean() + ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean() + return ent_coef_loss ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params) ent_coef_state = ent_coef_state.apply_gradients(grads=grads) @@ -537,10 +538,6 @@ def slice(x): key, ) - # sanity check - # otherwise must use update_actor=n_updates % args.policy_frequency, - # which is not jitable - # assert args.policy_frequency <= args.gradient_steps (actor_state, (qf1_state, qf2_state), actor_loss_value, key, entropy) = update_actor( actor_state, qf1_state, @@ -548,8 +545,6 @@ def slice(x): ent_coef_state, slice(data.observations), 
key, - # update_actor=((i + 1) % args.policy_frequency) == 0, - # update_actor=(n_updates % args.policy_frequency) == 0, ) ent_coef_state, _ = update_temperature(ent_coef_state, entropy) @@ -576,6 +571,7 @@ def slice(x): # actions = np.array(actor.apply(actor_state.params, obs).sample(seed=exploration_key)) actions = np.array(sample_action(actor_state, obs, exploration_key)) + # actions = np.clip(actions, -1.0, 1.0) # TRY NOT TO MODIFY: execute the game and log data. next_obs, rewards, dones, infos = envs.step(actions) @@ -642,6 +638,8 @@ def slice(x): writer.add_scalar("charts/std_eval_reward", std_reward, global_step) if global_step % 100 == 0: + ent_coef_value = ent_coef.apply({"params": ent_coef_state.params}) + writer.add_scalar("losses/ent_coef_value", ent_coef_value.item(), global_step) writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step) writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step) # writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step) From f455b4e81af39d76e18ab059d22e0ca8f8efa5ec Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 23 Sep 2022 21:40:24 +0200 Subject: [PATCH 30/31] Cleanup --- cleanrl/tqc_sac_jax.py | 73 +++--------------------------------------- 1 file changed, 5 insertions(+), 68 deletions(-) diff --git a/cleanrl/tqc_sac_jax.py b/cleanrl/tqc_sac_jax.py index 79947319c..77d6b1189 100644 --- a/cleanrl/tqc_sac_jax.py +++ b/cleanrl/tqc_sac_jax.py @@ -5,7 +5,7 @@ import time from dataclasses import dataclass from distutils.util import strtobool -from typing import Any, NamedTuple, Optional, Sequence, Union +from typing import Any, NamedTuple, Optional, Sequence import flax import flax.linen as nn @@ -36,57 +36,6 @@ class ReplayBufferSamplesNp(NamedTuple): rewards: np.ndarray -class RescaleAction(gym.ActionWrapper): - """Affinely rescales the continuous action space of the environment to the range [min_action, max_action]. - - The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action` - or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space. - - """ - - def __init__( - self, - env: gym.Env, - min_action: int = -1, - max_action: int = 1, - ): - """Initializes the :class:`RescaleAction` wrapper. - - Args: - env (Env): The environment to apply the wrapper - min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar. - max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar. - """ - assert isinstance(env.action_space, gym.spaces.Box), f"expected Box action space, got {type(env.action_space)}" - assert np.less_equal(min_action, max_action).all(), (min_action, max_action) - - super().__init__(env) - self.min_action = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action - self.max_action = np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + max_action - self.action_space = gym.spaces.Box( - low=min_action, - high=max_action, - shape=env.action_space.shape, - dtype=env.action_space.dtype, - ) - - def action(self, action): - """Rescales the action affinely from [:attr:`min_action`, :attr:`max_action`] to the action space of the base environment, :attr:`env`. 
- - Args: - action: The action to rescale - - Returns: - The rescaled action - """ - action = np.clip(action, self.min_action, self.max_action) - low = self.env.action_space.low - high = self.env.action_space.high - action = low + (high - low) * ((action - self.min_action) / (self.max_action - self.min_action)) - action = np.clip(action, low, high) - return action - - def parse_args(): # fmt: off parser = argparse.ArgumentParser() @@ -159,17 +108,12 @@ def thunk(): # from https://github.com/ikostrikov/walk_in_the_park # otherwise mode is not define for Squashed Gaussian class TanhTransformedDistribution(tfd.TransformedDistribution): - def __init__(self, distribution: tfd.Distribution, validate_args: bool = False, threshold=0.999): - self._threshold = threshold + def __init__(self, distribution: tfd.Distribution, validate_args: bool = False): super().__init__(distribution=distribution, bijector=tfp.bijectors.Tanh(), validate_args=validate_args) def mode(self) -> jnp.ndarray: return self.bijector.forward(self.distribution.mode()) - # def log_prob(self, actions: jnp.ndarray) -> jnp.ndarray: - # actions = jnp.clip(actions, -self._threshold, self._threshold) - # return super().log_prob(actions) - @classmethod def _parameter_properties(cls, dtype: Optional[Any], num_classes=None): td_properties = super()._parameter_properties(dtype, num_classes=num_classes) @@ -279,18 +223,14 @@ def main(): # env setup envs = DummyVecEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]) - eval_envs = make_vec_env(args.env_id, n_envs=args.n_eval_envs, seed=args.seed) # wrapper_class=RescaleAction + eval_envs = make_vec_env(args.env_id, n_envs=args.n_eval_envs, seed=args.seed) assert isinstance(envs.action_space, gym.spaces.Box), "only continuous action space is supported" # Assume that all dimensions share the same bound - min_action = float(envs.action_space.low[0]) - max_action = float(envs.action_space.high[0]) + # min_action = float(envs.action_space.low[0]) + # max_action = float(envs.action_space.high[0]) # For now assumed low=-1, high=1 - # TODO: handle any action space boundary - - action_scale = ((max_action - min_action) / 2.0,) - action_bias = ((max_action + min_action) / 2.0,) envs.observation_space.dtype = np.float32 rb = ReplayBuffer( @@ -463,7 +403,6 @@ def huber_quantile_loss(params, noise_key): key, ) - # @partial(jax.jit, static_argnames=["update_actor"]) @jax.jit def update_actor( actor_state: RLTrainState, @@ -566,9 +505,7 @@ def slice(x): if global_step < args.learning_starts: actions = np.array([envs.action_space.sample() for _ in range(envs.num_envs)]) else: - # TODO: JIT sampling? 
key, exploration_key = jax.random.split(key, 2) - # actions = np.array(actor.apply(actor_state.params, obs).sample(seed=exploration_key)) actions = np.array(sample_action(actor_state, obs, exploration_key)) # actions = np.clip(actions, -1.0, 1.0) From 7eb2c4fbc59bcbd2694cf863806bdecc44dfc9e5 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 24 Sep 2022 18:49:49 +0200 Subject: [PATCH 31/31] Match DroQ implementation --- cleanrl/tqc_sac_jax.py | 48 +++++++++++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/cleanrl/tqc_sac_jax.py b/cleanrl/tqc_sac_jax.py index 77d6b1189..ec75a418d 100644 --- a/cleanrl/tqc_sac_jax.py +++ b/cleanrl/tqc_sac_jax.py @@ -366,7 +366,6 @@ def update_critic( next_target_quantiles = next_quantiles[:, :n_target_quantiles] # td error + entropy term - # ent_coef_value = 0.0 next_target_quantiles = next_target_quantiles - ent_coef_value * next_log_prob.reshape(-1, 1) target_quantiles = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * args.gamma * next_target_quantiles @@ -412,7 +411,7 @@ def update_actor( observations: np.ndarray, key: jnp.ndarray, ): - key, dropout_key, noise_key = jax.random.split(key, 3) + key, dropout_key_1, dropout_key_2, noise_key = jax.random.split(key, 4) def actor_loss(params): @@ -420,25 +419,45 @@ def actor_loss(params): actor_actions = dist.sample(seed=noise_key) log_prob = dist.log_prob(actor_actions).reshape(-1, 1) - qf_pi = qf.apply(qf1_state.params, observations, actor_actions, True, rngs={"dropout": dropout_key},).mean( - axis=1, keepdims=True - ) # .mean(axis=2) TODO: add second qf + qf1_pi = qf.apply( + qf1_state.params, + observations, + actor_actions, + True, + rngs={"dropout": dropout_key_1}, + ) + qf2_pi = qf.apply( + qf2_state.params, + observations, + actor_actions, + True, + rngs={"dropout": dropout_key_2}, + ) + qf1_pi = jnp.expand_dims(qf1_pi, axis=-1) + qf2_pi = jnp.expand_dims(qf2_pi, axis=-1) + + # Concatenate quantiles from both critics + # (batch, n_quantiles, n_critics) + qf_pi = jnp.concatenate((qf1_pi, qf2_pi), axis=1) + qf_pi = qf_pi.mean(axis=2).mean(axis=1, keepdims=True) ent_coef_value = ent_coef.apply({"params": ent_coef_state.params}) - # ent_coef_value = 0.01 return (ent_coef_value * log_prob - qf_pi).mean(), -log_prob.mean() (actor_loss_value, entropy), grads = jax.value_and_grad(actor_loss, has_aux=True)(actor_state.params) actor_state = actor_state.apply_gradients(grads=grads) - # TODO: move update to critic update + return actor_state, (qf1_state, qf2_state), actor_loss_value, key, entropy + + @jax.jit + def soft_update(qf1_state: RLTrainState, qf2_state: RLTrainState): qf1_state = qf1_state.replace( target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau) ) qf2_state = qf2_state.replace( target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau) ) - return actor_state, (qf1_state, qf2_state), actor_loss_value, key, entropy + return qf1_state, qf2_state @jax.jit def update_temperature(ent_coef_state: TrainState, entropy: float): @@ -454,7 +473,15 @@ def temperature_loss(temp_params): return ent_coef_state, ent_coef_loss @jax.jit - def train(data: ReplayBufferSamplesNp, n_updates: int, qf1_state, qf2_state, actor_state, ent_coef_state, key): + def train( + data: ReplayBufferSamplesNp, + n_updates: int, + qf1_state: RLTrainState, + qf2_state: RLTrainState, + actor_state: TrainState, + ent_coef_state: TrainState, + key, + ): for i in range(args.gradient_steps): n_updates += 1 @@ -476,6 
+503,7 @@ def slice(x): slice(data.dones), key, ) + qf1_state, qf2_state = soft_update(qf1_state, qf2_state) (actor_state, (qf1_state, qf2_state), actor_loss_value, key, entropy) = update_actor( actor_state, @@ -499,8 +527,8 @@ def slice(x): start_time = time.time() n_updates = 0 + # for global_step in range(args.total_timesteps): for global_step in tqdm(range(args.total_timesteps)): - # for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: actions = np.array([envs.action_space.sample() for _ in range(envs.num_envs)])