Hello @mimoralea,
I am having difficulty getting the SAC algorithm to perform well.
I tried running SAC on the Pendulum, Hopper, and HalfCheetah environments, but the agent does not seem to learn anything during training; rewards stay stuck at roughly 30-40 at best. After rechecking my code against the grokking GitHub code several times and not finding any mistakes, I tried modifying the hyperparameters, with no luck.
I then tried running the grokking training notebook itself, and it just keeps running for multiple hours without printing any message or finishing. The only difference is that I am using "import gymnasium as gym", although I am not sure whether that could be the problem. I have also attached the requirements.txt of my conda env for reference.
I would be much obliged if you could help me find the root of the issue.
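For reference, this is my understanding of the gymnasium API that I adapted the code to (a minimal sketch, assuming a recent gymnasium release; Pendulum-v1 is just a placeholder environment):

import gymnasium as gym

env = gym.make('Pendulum-v1')        # render mode would be set here, e.g. render_mode='rgb_array'
state, info = env.reset(seed=12)     # reset() returns (obs, info) and takes the seed
action = env.action_space.sample()
next_state, reward, terminated, truncated, info = env.step(action)  # 5-tuple instead of the old 4-tuple
done = terminated or truncated       # the old single done flag is the OR of both
env.close()
# gymnasium.Env has no seed() method, and the classic gym.wrappers.Monitor is
# replaced by gymnasium.wrappers.RecordVideo.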
requirements.txt
Code:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import os, os.path
import time, tempfile, json, subprocess, base64, io, gc, random, glob
import matplotlib.pyplot
from itertools import count
from IPython.display import display, HTML
#import gym
import gymnasium as gym
from gym import wrappers
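# note: this pulls wrappers (e.g. Monitor) from the legacy gym package, while the
# environments above are created with gymnasium; I am not sure the two can be mixed.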
import torch.nn.functional as F
import pybullet_envs_gymnasium
from IPython.display import display
LEAVE_PRINT_EVERY_N_SECS = 300
ERASE_LINE = '\x1b[2K'
EPS = 1e-6
BEEP = lambda: os.system("printf '\a'")
RESULTS_DIR = os.path.join('..', 'results')
SEEDS = (12, 34, 56, 78, 90)
def get_gif_html(env_videos, title, subtitle_eps=None, max_n_videos=4):
    videos = np.array(env_videos)
    if len(videos) == 0:
        return
    n_videos = max(1, min(max_n_videos, len(videos)))
    idxs = np.linspace(0, len(videos) - 1, n_videos).astype(int) if n_videos > 1 else [-1,]
    videos = videos[idxs,...]
    strm = '<h2>{}<h2>'.format(title)
    for video_path, meta_path in videos:
        basename = os.path.splitext(video_path)[0]
        gif_path = basename + '.gif'
        if not os.path.exists(gif_path):
            ps = subprocess.Popen(
                ('ffmpeg',
                 '-i', video_path,
                 '-r', '7',
                 '-f', 'image2pipe',
                 '-vcodec', 'ppm',
                 '-crf', '20',
                 '-vf', 'scale=512:-1',
                 '-'),
                stdout=subprocess.PIPE)
            output = subprocess.check_output(
                ('convert',
                 '-coalesce',
                 '-delay', '7',
                 '-loop', '0',
                 '-fuzz', '2%',
                 '+dither',
                 '-deconstruct',
                 '-layers', 'Optimize',
                 '-', gif_path),
                stdin=ps.stdout)
            ps.wait()
        gif = io.open(gif_path, 'r+b').read()
        encoded = base64.b64encode(gif)
        with open(meta_path) as data_file:
            meta = json.load(data_file)
        html_tag = """
        <h3>{0}<h3/>
        <img src="data:image/gif;base64,{1}" />"""
        prefix = 'Trial ' if subtitle_eps is None else 'Episode '
        sufix = str(meta['episode_id'] if subtitle_eps is None \
                    else subtitle_eps[meta['episode_id']])
        strm += html_tag.format(prefix + sufix, encoded.decode('ascii'))
    return strm
class RenderUint8(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
    def render(self, mode='rgb_array'):
        frame = self.env.render(mode=mode)
        return frame.astype(np.uint8)
def get_make_env_fn(**kargs):
    def make_env_fn(env_name, seed=None, render=None, record=False,
                    unwrapped=False, monitor_mode=None,
                    inner_wrappers=None, outer_wrappers=None):
        mdir = tempfile.mkdtemp()
        env = None
        if render:
            try:
                env = gym.make(env_name, render=render)
            except:
                pass
        if env is None:
            env = gym.make(env_name)
        # note: gymnasium environments have no seed() method (seeding goes through
        # reset(seed=...)), so this call relies on the legacy gym API.
        if seed is not None: env.seed(seed)
        env = env.unwrapped if unwrapped else env
        if inner_wrappers:
            for wrapper in inner_wrappers:
                env = wrapper(env)
        # note: wrappers.Monitor comes from the legacy gym import above; gymnasium
        # provides RecordVideo instead.
        env = wrappers.Monitor(
            env, mdir, force=True,
            mode=monitor_mode,
            video_callable=lambda e_idx: record) if monitor_mode else env
        if outer_wrappers:
            for wrapper in outer_wrappers:
                env = wrapper(env)
        return env
    return make_env_fn, kargs
class ReplayBuffer():
    def __init__(self,
                 max_size=10000,
                 batch_size=64):
        self.ss_mem = np.empty(shape=(max_size), dtype=np.ndarray)
        self.as_mem = np.empty(shape=(max_size), dtype=np.ndarray)
        self.rs_mem = np.empty(shape=(max_size), dtype=np.ndarray)
        self.ps_mem = np.empty(shape=(max_size), dtype=np.ndarray)
        self.ds_mem = np.empty(shape=(max_size), dtype=np.ndarray)
        self.max_size = max_size
        self.batch_size = batch_size
        self._idx = 0
        self.size = 0
    def store(self, sample):
        s, a, r, p, d = sample
        self.ss_mem[self._idx] = s
        self.as_mem[self._idx] = a
        self.rs_mem[self._idx] = r
        self.ps_mem[self._idx] = p
        self.ds_mem[self._idx] = d
        self._idx += 1
        self._idx = self._idx % self.max_size
        self.size += 1
        self.size = min(self.size, self.max_size)
    def sample(self, batch_size=None):
        if batch_size == None:
            batch_size = self.batch_size
        idxs = np.random.choice(
            self.size, batch_size, replace=False)
        experiences = np.vstack(self.ss_mem[idxs]), \
                      np.vstack(self.as_mem[idxs]), \
                      np.vstack(self.rs_mem[idxs]), \
                      np.vstack(self.ps_mem[idxs]), \
                      np.vstack(self.ds_mem[idxs])
        return experiences
    def __len__(self):
        return self.size
class FCQSA(nn.Module):
    def __init__(self,
                 input_dim,
                 output_dim,
                 hidden_dims=(32,32),
                 activation_fc=F.relu):
        super(FCQSA, self).__init__()
        self.activation_fc = activation_fc
        self.input_layer = nn.Linear(input_dim + output_dim, hidden_dims[0])
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims)-1):
            hidden_layer = nn.Linear(hidden_dims[i], hidden_dims[i+1])
            self.hidden_layers.append(hidden_layer)
        self.output_layer = nn.Linear(hidden_dims[-1], 1)
        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda:0"
        self.device = torch.device(device)
        self.to(self.device)
    def _format(self, state, action):
        x, u = state, action
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x,
                             device=self.device,
                             dtype=torch.float32)
            x = x.unsqueeze(0)
        if not isinstance(u, torch.Tensor):
            u = torch.tensor(u,
                             device=self.device,
                             dtype=torch.float32)
            u = u.unsqueeze(0)
        return x, u
    def forward(self, state, action):
        x, u = self._format(state, action)
        x = self.activation_fc(self.input_layer(torch.cat((x, u), dim=1)))
        for i, hidden_layer in enumerate(self.hidden_layers):
            x = self.activation_fc(hidden_layer(x))
        x = self.output_layer(x)
        return x
    def load(self, experiences):
        states, actions, new_states, rewards, is_terminals = experiences
        states = torch.from_numpy(states).float().to(self.device)
        actions = torch.from_numpy(actions).float().to(self.device)
        new_states = torch.from_numpy(new_states).float().to(self.device)
        rewards = torch.from_numpy(rewards).float().to(self.device)
        is_terminals = torch.from_numpy(is_terminals).float().to(self.device)
        return states, actions, new_states, rewards, is_terminals
class FCGP(nn.Module):
    def __init__(self,
                 input_dim,
                 action_bounds,
                 log_std_min=-20,
                 log_std_max=2,
                 hidden_dims=(32,32),
                 activation_fc=F.relu,
                 entropy_lr=0.001):
        super(FCGP, self).__init__()
        self.activation_fc = activation_fc
        self.env_min, self.env_max = action_bounds
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.input_layer = nn.Linear(input_dim,
                                     hidden_dims[0])
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims)-1):
            hidden_layer = nn.Linear(
                hidden_dims[i], hidden_dims[i+1])
            self.hidden_layers.append(hidden_layer)
        self.output_layer_mean = nn.Linear(hidden_dims[-1], len(self.env_max))
        self.output_layer_log_std = nn.Linear(hidden_dims[-1], len(self.env_max))
        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda:0"
        self.device = torch.device(device)
        self.to(self.device)
        self.env_min = torch.tensor(self.env_min,
                                    device=self.device,
                                    dtype=torch.float32)
        self.env_max = torch.tensor(self.env_max,
                                    device=self.device,
                                    dtype=torch.float32)
        self.nn_min = F.tanh(torch.Tensor([float('-inf')])).to(self.device)
        self.nn_max = F.tanh(torch.Tensor([float('inf')])).to(self.device)
        self.rescale_fn = lambda x: (x - self.nn_min) * (self.env_max - self.env_min) / \
                                    (self.nn_max - self.nn_min) + self.env_min
        self.target_entropy = -np.prod(self.env_max.shape)
        self.logalpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optimizer = optim.Adam([self.logalpha], lr=entropy_lr)
    def _format(self, state):
        x = state
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x,
                             device=self.device,
                             dtype=torch.float32)
            x = x.unsqueeze(0)
        return x
    def forward(self, state):
        x = self._format(state)
        x = self.activation_fc(self.input_layer(x))
        for hidden_layer in self.hidden_layers:
            x = self.activation_fc(hidden_layer(x))
        x_mean = self.output_layer_mean(x)
        x_log_std = self.output_layer_log_std(x)
        x_log_std = torch.clamp(x_log_std,
                                self.log_std_min,
                                self.log_std_max)
        return x_mean, x_log_std
    def full_pass(self, state, epsilon=1e-6):
        mean, log_std = self.forward(state)
        pi_s = Normal(mean, log_std.exp())
        pre_tanh_action = pi_s.rsample()
        tanh_action = torch.tanh(pre_tanh_action)
        action = self.rescale_fn(tanh_action)
        log_prob = pi_s.log_prob(pre_tanh_action) - torch.log(
            (1 - tanh_action.pow(2)).clamp(0, 1) + epsilon)
        log_prob = log_prob.sum(dim=1, keepdim=True)
        return action, log_prob, self.rescale_fn(torch.tanh(mean)) # used only in optimize_model().
    def _update_exploration_ratio(self, greedy_action, action_taken):
        env_min, env_max = self.env_min.cpu().numpy(), self.env_max.cpu().numpy()
        self.exploration_ratio = np.mean(abs((greedy_action - action_taken)/(env_max - env_min)))
    def _get_actions(self, state):
        mean, log_std = self.forward(state)
        action = self.rescale_fn(torch.tanh(Normal(mean, log_std.exp()).sample()))
        greedy_action = self.rescale_fn(torch.tanh(mean))
        random_action = np.random.uniform(low=self.env_min.cpu().numpy(),
                                          high=self.env_max.cpu().numpy())
        action_shape = self.env_max.cpu().numpy().shape
        action = action.detach().cpu().numpy().reshape(action_shape)
        greedy_action = greedy_action.detach().cpu().numpy().reshape(action_shape)
        random_action = random_action.reshape(action_shape)
        return action, greedy_action, random_action
    def select_random_action(self, state):
        action, greedy_action, random_action = self._get_actions(state)
        self._update_exploration_ratio(greedy_action, random_action)
        return random_action # use in training-interaction-step, if len(RB) < min samples.
    def select_greedy_action(self, state):
        action, greedy_action, random_action = self._get_actions(state)
        self._update_exploration_ratio(greedy_action, greedy_action)
        return greedy_action # used only in evaluate().
    def select_action(self, state):
        action, greedy_action, random_action = self._get_actions(state)
        self._update_exploration_ratio(greedy_action, action)
        return action # use in training-interaction-step, if len(RB) > min samples.
class SAC():
    def __init__(self,
                 replay_buffer_fn,
                 policy_model_fn,
                 policy_max_grad_norm,
                 policy_optimizer_fn,
                 policy_optimizer_lr,
                 value_model_fn,
                 value_max_grad_norm,
                 value_optimizer_fn,
                 value_optimizer_lr,
                 n_warmup_batches,
                 update_target_every_steps,
                 tau):
        self.replay_buffer_fn = replay_buffer_fn
        self.policy_model_fn = policy_model_fn
        self.policy_max_grad_norm = policy_max_grad_norm
        self.policy_optimizer_fn = policy_optimizer_fn
        self.policy_optimizer_lr = policy_optimizer_lr
        self.value_model_fn = value_model_fn
        self.value_max_grad_norm = value_max_grad_norm
        self.value_optimizer_fn = value_optimizer_fn
        self.value_optimizer_lr = value_optimizer_lr
        self.n_warmup_batches = n_warmup_batches
        self.update_target_every_steps = update_target_every_steps
        self.tau = tau
    def optimize_model(self, experiences):
        states, actions, rewards, next_states, is_terminals = experiences
        batch_size = len(is_terminals)
        # policy loss
        current_actions, logpi_s, _ = self.policy_model.full_pass(states)
        target_alpha = (logpi_s + self.policy_model.target_entropy).detach()
        alpha_loss = -(self.policy_model.logalpha * target_alpha).mean()
        self.policy_model.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.policy_model.alpha_optimizer.step()
        alpha = self.policy_model.logalpha.exp()
        current_q_sa_a = self.online_value_model_a(states, current_actions)
        current_q_sa_b = self.online_value_model_b(states, current_actions)
        current_q_sa = torch.min(current_q_sa_a, current_q_sa_b)
        policy_loss = (alpha * logpi_s - current_q_sa).mean()
        # Q loss
        ap, logpi_sp, _ = self.policy_model.full_pass(next_states)
        q_spap_a = self.target_value_model_a(next_states, ap)
        q_spap_b = self.target_value_model_b(next_states, ap)
        q_spap = torch.min(q_spap_a, q_spap_b) - alpha * logpi_sp
        target_q_sa = (rewards + self.gamma * q_spap * (1 - is_terminals)).detach()
        q_sa_a = self.online_value_model_a(states, actions) # actions are from RB.
        q_sa_b = self.online_value_model_b(states, actions)
        qa_loss = (q_sa_a - target_q_sa).pow(2).mul(0.5).mean()
        qb_loss = (q_sa_b - target_q_sa).pow(2).mul(0.5).mean()
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(),
                                       self.policy_max_grad_norm)
        self.policy_optimizer.step()
        self.value_optimizer_a.zero_grad()
        qa_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.online_value_model_a.parameters(),
                                       self.value_max_grad_norm)
        self.value_optimizer_a.step()
        self.value_optimizer_b.zero_grad()
        qb_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.online_value_model_b.parameters(),
                                       self.value_max_grad_norm)
        self.value_optimizer_b.step()
        """
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_model.parameters(),
                                       self.policy_max_grad_norm)
        self.policy_optimizer.step()
        """
    def interaction_step(self, state, env):
        min_samples = self.replay_buffer.batch_size * self.n_warmup_batches
        if len(self.replay_buffer) < min_samples:
            action = self.policy_model.select_random_action(state)
        else:
            action = self.policy_model.select_action(state)
        new_state, reward, is_terminal, is_truncated, info = env.step(action)
        #is_truncated = 'TimeLimit.truncated' in info and info['TimeLimit.truncated']
        is_failure = is_terminal and not is_truncated
        experience = (state, action, reward, new_state, float(is_failure))
        self.replay_buffer.store(experience)
        self.episode_reward[-1] += reward
        self.episode_timestep[-1] += 1
        self.episode_exploration[-1] += self.policy_model.exploration_ratio
        # note: only the terminated flag is returned here; with gymnasium's 5-tuple step,
        # a time-limit truncation (is_truncated) does not end the episode loop in train().
        return new_state, is_terminal
    def update_value_networks(self, tau=None):
        tau = self.tau if tau is None else tau
        for target, online in zip(self.target_value_model_a.parameters(),
                                  self.online_value_model_a.parameters()):
            target_ratio = (1.0 - tau) * target.data
            online_ratio = tau * online.data
            mixed_weights = target_ratio + online_ratio
            target.data.copy_(mixed_weights)
        for target, online in zip(self.target_value_model_b.parameters(),
                                  self.online_value_model_b.parameters()):
            target_ratio = (1.0 - tau) * target.data
            online_ratio = tau * online.data
            mixed_weights = target_ratio + online_ratio
            target.data.copy_(mixed_weights)
    def train(self, make_env_fn, make_env_kargs, seed, gamma,
              max_minutes, max_episodes, goal_mean_100_reward):
        training_start, last_debug_time = time.time(), float('-inf')
        self.checkpoint_dir = tempfile.mkdtemp()
        self.make_env_fn = make_env_fn
        self.make_env_kargs = make_env_kargs
        self.seed = seed
        self.gamma = gamma
        env = self.make_env_fn(**self.make_env_kargs, seed=self.seed)
        torch.manual_seed(self.seed) ; np.random.seed(self.seed) ; random.seed(self.seed)
        nS, nA = env.observation_space.shape[0], env.action_space.shape[0]
        action_bounds = env.action_space.low, env.action_space.high
        self.episode_timestep = []
        self.episode_reward = []
        self.episode_seconds = []
        self.evaluation_scores = []
        self.episode_exploration = []
        self.target_value_model_a = self.value_model_fn(nS, nA)
        self.online_value_model_a = self.value_model_fn(nS, nA)
        self.target_value_model_b = self.value_model_fn(nS, nA)
        self.online_value_model_b = self.value_model_fn(nS, nA)
        self.update_value_networks(tau=1.0)
        self.policy_model = self.policy_model_fn(nS, action_bounds)
        self.value_optimizer_a = self.value_optimizer_fn(self.online_value_model_a,
                                                         self.value_optimizer_lr)
        self.value_optimizer_b = self.value_optimizer_fn(self.online_value_model_b,
                                                         self.value_optimizer_lr)
        self.policy_optimizer = self.policy_optimizer_fn(self.policy_model,
                                                         self.policy_optimizer_lr)
        self.replay_buffer = self.replay_buffer_fn()
        result = np.empty((max_episodes, 5))
        result[:] = np.nan
        training_time = 0
        for episode in range(1, max_episodes + 1):
            episode_start = time.time()
            state, info = env.reset()
            is_terminal = False
            self.episode_reward.append(0.0)
            self.episode_timestep.append(0.0)
            self.episode_exploration.append(0.0)
            for step in count():
                state, is_terminal = self.interaction_step(state, env)
                min_samples = self.replay_buffer.batch_size * self.n_warmup_batches
                if len(self.replay_buffer) > min_samples:
                    experiences = self.replay_buffer.sample()
                    experiences = self.online_value_model_a.load(experiences)
                    self.optimize_model(experiences)
                if np.sum(self.episode_timestep) % self.update_target_every_steps == 0:
                    self.update_value_networks()
                if is_terminal:
                    gc.collect()
                    break
            # stats
            episode_elapsed = time.time() - episode_start
            self.episode_seconds.append(episode_elapsed)
            training_time += episode_elapsed
            evaluation_score, _ = self.evaluate(self.policy_model, env)
            self.save_checkpoint(episode-1, self.policy_model)
            total_step = int(np.sum(self.episode_timestep))
            self.evaluation_scores.append(evaluation_score)
            mean_10_reward = np.mean(self.episode_reward[-10:])
            std_10_reward = np.std(self.episode_reward[-10:])
            mean_100_reward = np.mean(self.episode_reward[-100:])
            std_100_reward = np.std(self.episode_reward[-100:])
            mean_100_eval_score = np.mean(self.evaluation_scores[-100:])
            std_100_eval_score = np.std(self.evaluation_scores[-100:])
            lst_100_exp_rat = np.array(
                self.episode_exploration[-100:])/np.array(self.episode_timestep[-100:])
            mean_100_exp_rat = np.mean(lst_100_exp_rat)
            std_100_exp_rat = np.std(lst_100_exp_rat)
            wallclock_elapsed = time.time() - training_start
            result[episode-1] = total_step, mean_100_reward, \
                mean_100_eval_score, training_time, wallclock_elapsed
            reached_debug_time = time.time() - last_debug_time >= LEAVE_PRINT_EVERY_N_SECS
            reached_max_minutes = wallclock_elapsed >= max_minutes * 60
            reached_max_episodes = episode >= max_episodes
            reached_goal_mean_reward = mean_100_eval_score >= goal_mean_100_reward
            training_is_over = reached_max_minutes or \
                               reached_max_episodes or \
                               reached_goal_mean_reward
            elapsed_str = time.strftime("%H:%M:%S", time.gmtime(time.time() - training_start))
            debug_message = 'el {}, ep {:04}, ts {:07}, '
            debug_message += 'ar 10 {:05.1f}\u00B1{:05.1f}, '
            debug_message += '100 {:05.1f}\u00B1{:05.1f}, '
            debug_message += 'ex 100 {:02.1f}\u00B1{:02.1f}, '
            debug_message += 'ev {:05.1f}\u00B1{:05.1f}'
            debug_message = debug_message.format(
                elapsed_str, episode-1, total_step, mean_10_reward, std_10_reward,
                mean_100_reward, std_100_reward, mean_100_exp_rat, std_100_exp_rat,
                mean_100_eval_score, std_100_eval_score)
            print(debug_message, end='\r', flush=True)
            if reached_debug_time or training_is_over:
                print(ERASE_LINE + debug_message, flush=True)
                last_debug_time = time.time()
            if training_is_over:
                if reached_max_minutes: print(u'--> reached_max_minutes \u2715')
                if reached_max_episodes: print(u'--> reached_max_episodes \u2715')
                if reached_goal_mean_reward: print(u'--> reached_goal_mean_reward \u2713')
                break
        final_eval_score, score_std = self.evaluate(self.policy_model, env, n_episodes=100)
        wallclock_time = time.time() - training_start
        print('Training complete.')
        print('Final evaluation score {:.2f}\u00B1{:.2f} in {:.2f}s training time,'
              ' {:.2f}s wall-clock time.\n'.format(
                  final_eval_score, score_std, training_time, wallclock_time))
        env.close() ; del env
        self.get_cleaned_checkpoints()
        return result, final_eval_score, training_time, wallclock_time
    def evaluate(self, eval_policy_model, eval_env, n_episodes=1):
        rs = []
        for _ in range(n_episodes):
            s, info = eval_env.reset()
            d = False
            rs.append(0)
            for _ in count():
                a = eval_policy_model.select_greedy_action(s)
                # gymnasium's step() returns a 5-tuple, so unpack both end-of-episode flags
                s, r, terminated, truncated, _ = eval_env.step(a)
                d = terminated or truncated
                rs[-1] += r
                if d: break
        return np.mean(rs), np.std(rs)
    def get_cleaned_checkpoints(self, n_checkpoints=4):
        try:
            return self.checkpoint_paths
        except AttributeError:
            self.checkpoint_paths = {}
        paths = glob.glob(os.path.join(self.checkpoint_dir, '*.tar'))
        paths_dic = {int(path.split('.')[-2]):path for path in paths}
        last_ep = max(paths_dic.keys())
        # checkpoint_idxs = np.geomspace(1, last_ep+1, n_checkpoints, endpoint=True, dtype=int)-1
        checkpoint_idxs = np.linspace(1, last_ep+1, n_checkpoints, endpoint=True, dtype=int)-1  # np.int was removed in NumPy 1.24+
        for idx, path in paths_dic.items():
            if idx in checkpoint_idxs:
                self.checkpoint_paths[idx] = path
            else:
                os.unlink(path)
        return self.checkpoint_paths
    def demo_last(self, title='Fully-trained {} Agent', n_episodes=2, max_n_videos=2):
        env = self.make_env_fn(**self.make_env_kargs, monitor_mode='evaluation', render=True, record=True)
        checkpoint_paths = self.get_cleaned_checkpoints()
        last_ep = max(checkpoint_paths.keys())
        self.policy_model.load_state_dict(torch.load(checkpoint_paths[last_ep]))
        self.evaluate(self.policy_model, env, n_episodes=n_episodes)
        env.close()
        data = get_gif_html(env_videos=env.videos,
                            title=title.format(self.__class__.__name__),
                            max_n_videos=max_n_videos)
        del env
        return HTML(data=data)
    def demo_progression(self, title='{} Agent progression', max_n_videos=4):
        env = self.make_env_fn(**self.make_env_kargs, monitor_mode='evaluation', render=True, record=True)
        checkpoint_paths = self.get_cleaned_checkpoints()
        for i in sorted(checkpoint_paths.keys()):
            self.policy_model.load_state_dict(torch.load(checkpoint_paths[i]))
            self.evaluate(self.policy_model, env, n_episodes=1)
        env.close()
        data = get_gif_html(env_videos=env.videos,
                            title=title.format(self.__class__.__name__),
                            subtitle_eps=sorted(checkpoint_paths.keys()),
                            max_n_videos=max_n_videos)
        del env
        return HTML(data=data)
    def save_checkpoint(self, episode_idx, model):
        torch.save(model.state_dict(),
                   os.path.join(self.checkpoint_dir, 'model.{}.tar'.format(episode_idx)))
        return
sac_results = []
best_agent, best_eval_score = None, float('-inf')
for seed in SEEDS:
    environment_settings = {
        'env_name': 'HalfCheetahBulletEnv-v0',
        'gamma': 0.99,
        'max_minutes': 300,
        'max_episodes': 10000,
        'goal_mean_100_reward': 2000
    }
    policy_model_fn = lambda nS, bounds: FCGP(nS, bounds, hidden_dims=(256,256))
    policy_max_grad_norm = float('inf')
    policy_optimizer_fn = lambda net, lr: optim.Adam(net.parameters(), lr=lr)
    policy_optimizer_lr = 0.0003
    value_model_fn = lambda nS, nA: FCQSA(nS, nA, hidden_dims=(256,256))
    value_max_grad_norm = float('inf')
    value_optimizer_fn = lambda net, lr: optim.Adam(net.parameters(), lr=lr)
    value_optimizer_lr = 0.0005
    replay_buffer_fn = lambda: ReplayBuffer(max_size=100000, batch_size=64)
    n_warmup_batches = 10
    update_target_every_steps = 1
    tau = 0.001
    env_name, gamma, max_minutes, \
    max_episodes, goal_mean_100_reward = environment_settings.values()
    agent = SAC(replay_buffer_fn,
                policy_model_fn,
                policy_max_grad_norm,
                policy_optimizer_fn,
                policy_optimizer_lr,
                value_model_fn,
                value_max_grad_norm,
                value_optimizer_fn,
                value_optimizer_lr,
                n_warmup_batches,
                update_target_every_steps,
                tau)
    make_env_fn, make_env_kargs = get_make_env_fn(env_name=env_name, inner_wrappers=[RenderUint8])
    result, final_eval_score, training_time, wallclock_time = agent.train(
        make_env_fn, make_env_kargs, seed, gamma, max_minutes, max_episodes, goal_mean_100_reward)
    sac_results.append(result)
    if final_eval_score > best_eval_score:
        best_eval_score = final_eval_score
        best_agent = agent
sac_results = np.array(sac_results)
_ = BEEP()
Regards