diff --git a/dopamine/labs/atari_100k/configs/SPR.gin b/dopamine/labs/atari_100k/configs/SPR.gin new file mode 100644 index 00000000..6c8e91b0 --- /dev/null +++ b/dopamine/labs/atari_100k/configs/SPR.gin @@ -0,0 +1,51 @@ +# Self-Predictive Representations (SPR) from Schwarzer et al. (2021). +import dopamine.jax.agents.dqn.dqn_agent +import dopamine.jax.networks +import dopamine.discrete_domains.gym_lib +import dopamine.discrete_domains.run_experiment +import dopamine.replay_memory.prioritized_replay_buffer +import dopamine.labs.atari_100k.spr_networks +import dopamine.labs.atari_100k.spr_agent + +# Parameters taken from Data Regularized-Q (DrQ; Kostrikov et al., 2020) are highlighted by comments +JaxDQNAgent.gamma = 0.99 +JaxDQNAgent.update_horizon = 10 # DrQ (instead of 3) +JaxDQNAgent.min_replay_history = 2000 # DrQ (instead of 20000) +JaxDQNAgent.update_period = 1 # DrQ (rather than 4) +JaxDQNAgent.target_update_period = 1 # DrQ (rather than 8000) +JaxDQNAgent.epsilon_train = 0.00 +JaxDQNAgent.epsilon_eval = 0.001 +JaxDQNAgent.epsilon_decay_period = 2001 # DrQ +JaxDQNAgent.optimizer = 'adam' + +SPRAgent.noisy = True +SPRAgent.dueling = True +SPRAgent.double_dqn = True +SPRAgent.distributional = True +SPRAgent.num_atoms = 51 +SPRAgent.log_every = 100 +SPRAgent.num_updates_per_train_step = 2 +SPRAgent.spr_weight = 5 +SPRAgent.jumps = 5 +SPRAgent.data_augmentation = True +SPRAgent.replay_scheme = 'prioritized' +SPRAgent.network = @spr_networks.SPRNetwork +SPRAgent.epsilon_fn = @jax.agents.dqn.dqn_agent.linearly_decaying_epsilon + +# Note these parameters are from DER (van Hasselt et al., 2019) +create_optimizer.learning_rate = 0.0001 +create_optimizer.eps = 0.00015 + +atari_lib.create_atari_environment.game_name = 'Pong' +# Atari 100K benchmark doesn't use sticky actions. +atari_lib.create_atari_environment.sticky_actions = False +AtariPreprocessing.terminal_on_life_loss = True +Runner.num_iterations = 1 +Runner.training_steps = 100000 # agent steps +MaxEpisodeEvalRunner.num_eval_episodes = 100 # agent episodes +Runner.max_steps_per_episode = 27000 # agent steps + +DeterministicOutOfGraphPrioritizedTemporalReplayBuffer.replay_capacity = 200000 +DeterministicOutOfGraphPrioritizedTemporalReplayBuffer.batch_size = 32 +DeterministicOutOfGraphTemporalReplayBuffer.replay_capacity = 200000 +DeterministicOutOfGraphTemporalReplayBuffer.batch_size = 32 \ No newline at end of file diff --git a/dopamine/labs/atari_100k/replay_memory/__init__.py b/dopamine/labs/atari_100k/replay_memory/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dopamine/labs/atari_100k/replay_memory/parallel_deterministic_sum_tree.py b/dopamine/labs/atari_100k/replay_memory/parallel_deterministic_sum_tree.py new file mode 100644 index 00000000..d17d0417 --- /dev/null +++ b/dopamine/labs/atari_100k/replay_memory/parallel_deterministic_sum_tree.py @@ -0,0 +1,147 @@ +# coding=utf-8 +# Copyright 2021 The Atari 100k Precipice Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
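For context, a config like SPR.gin above is consumed through gin bindings. The sketch below is illustrative only: it assumes Dopamine's stock run_experiment helpers, while the Atari 100k lab ships its own trainer and agent factory, and the output directory and game override are hypothetical.

    from dopamine.discrete_domains import run_experiment

    GIN_FILE = 'dopamine/labs/atari_100k/configs/SPR.gin'
    # Hypothetical override: change the game without editing the config file.
    GIN_BINDINGS = ["atari_lib.create_atari_environment.game_name = 'Breakout'"]

    # Parses SPR.gin and applies the extra bindings on top of it.
    run_experiment.load_gin_configs([GIN_FILE], GIN_BINDINGS)

    # Assumes `create_agent`/`Runner` are bound to the SPR classes (the lab's
    # own trainer normally does this wiring).
    runner = run_experiment.create_runner('/tmp/spr_pong')
    runner.run_experiment()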
+"""A sum tree data structure that uses JAX for controlling randomness.""" + +from dopamine.replay_memory import sum_tree +import jax +from jax import numpy as jnp +import numpy as np +import time +import functools + + +@jax.jit +def step(i, args): + query_value, index, nodes = args + left_child = index * 2 + 1 + left_sum = nodes[left_child] + index = jax.lax.cond(query_value < left_sum, lambda x: x, lambda x: x + 1, + left_child) + query_value = jax.lax.cond(query_value < left_sum, lambda x: x, + lambda x: x - left_sum, query_value) + return query_value, index, nodes + + +@jax.jit +@functools.partial(jax.vmap, in_axes=(None, None, 0, None, None)) +def parallel_stratified_sample(rng, nodes, i, n, depth): + rng = jax.random.fold_in(rng, i) + total_priority = nodes[0] + upper_bound = (i + 1) / n + lower_bound = i / n + query = jax.random.uniform(rng, minval=lower_bound, maxval=upper_bound) + _, index, _ = jax.lax.fori_loop(0, depth, step, + (query * total_priority, 0, nodes)) + return index + + +class DeterministicSumTree(sum_tree.SumTree): + """A sum tree data structure for storing replay priorities. + + In contrast to the original implementation, this uses JAX for handling + randomness, which allows us to reproduce the same results when using the same + seed. + """ + + def __init__(self, capacity): + """Creates the sum tree data structure for the given replay capacity. + Args: + capacity: int, the maximum number of elements that can be stored in this + data structure. + Raises: + ValueError: If requested capacity is not positive. + """ + assert isinstance(capacity, int) + if capacity <= 0: + raise ValueError( + 'Sum tree capacity should be positive. Got: {}'.format(capacity)) + + self.nodes = [] + self.depth = int(np.ceil(np.log2(capacity))) + self.low_idx = (2**self.depth) - 1 # pri_idx + low_idx -> tree_idx + self.high_idx = capacity + self.low_idx + self.nodes = np.zeros(2**(self.depth + 1) - 1) # Double precision. + + self.max_recorded_priority = 1.0 + + def _total_priority(self): + """Returns the sum of all priorities stored in this sum tree. + Returns: + float, sum of priorities stored in this sum tree. + """ + return self.nodes[0] + + def sample(self, rng, query_value=None): + """Samples an element from the sum tree. + This function is designed to be jitted, so it does not have the same + checks as the original. + """ + # Sample a value in range [0, R), where R is the value stored at the root. + nodes = jnp.array(self.nodes) + query_value = ( + jax.random.uniform(rng) if query_value is None else query_value) + query_value *= self._total_priority() + + # Now traverse the sum tree. + _, index, _ = jax.lax.fori_loop(0, self.depth, step, + (query_value, 0, nodes)) + return index - self.low_idx + + def stratified_sample(self, batch_size, rng): + """Performs stratified sampling using the sum tree.""" + if self._total_priority() == 0.0: + raise Exception('Cannot sample from an empty sum tree.') + indices = parallel_stratified_sample(rng, self.nodes, + jnp.arange(batch_size), batch_size, + self.depth) + return indices - self.low_idx + + def get(self, node_index): + """Returns the value of the leaf node corresponding to the index. + Args: + node_index: The index of the leaf node. + Returns: + The value of the leaf node. + """ + return self.nodes[node_index + self.low_idx] + + def set(self, node_index, value): + """Sets the value of a leaf node and updates internal nodes accordingly. + This operation takes O(log(capacity)). + Args: + node_index: int, the index of the leaf node to be updated. 
+ value: float, the value which we assign to the node. This value must be + nonnegative. Setting value = 0 will cause the element to never be + sampled. + Raises: + ValueError: If the given value is negative. + """ + if value < 0.0: + raise ValueError( + 'Sum tree values should be nonnegative. Got {}'.format(value)) + node_index = node_index + self.low_idx + self.max_recorded_priority = max(value, self.max_recorded_priority) + + delta_value = value - self.nodes[node_index] + + # Now traverse back the tree, adjusting all sums along the way. + for _ in reversed(range(self.depth)): + # Note: Adding a delta leads to some tolerable numerical inaccuracies. + self.nodes[node_index] += delta_value + node_index = (node_index - 1) // 2 + + self.nodes[node_index] += delta_value + assert node_index == 0, ('Sum tree traversal failed, final node index ' + 'is not 0.') diff --git a/dopamine/labs/atari_100k/replay_memory/time_batch_replay_buffer.py b/dopamine/labs/atari_100k/replay_memory/time_batch_replay_buffer.py new file mode 100644 index 00000000..64107f95 --- /dev/null +++ b/dopamine/labs/atari_100k/replay_memory/time_batch_replay_buffer.py @@ -0,0 +1,862 @@ +# coding=utf-8 +# Copyright 2018 The Dopamine Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The standard DQN replay memory. +This implementation is an out-of-graph replay memory + in-graph wrapper. It +supports vanilla n-step updates of the form typically found in the literature, +i.e. where rewards are accumulated for n steps and the intermediate trajectory +is not exposed to the agent. This does not allow, for example, performing +off-policy corrections. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import gzip +import math +import os +import pickle +import time + +from absl import logging +import numpy as np +import tensorflow as tf + +import gin.tf +import jax +from jax import numpy as jnp +import functools + +from dopamine.labs.atari_100k.replay_memory import parallel_deterministic_sum_tree as sum_tree + +# Defines a type describing part of the tuple returned by the replay +# memory. Each element of the tuple is a tensor of shape [batch, ...] where +# ... is defined the 'shape' field of ReplayElement. The tensor type is +# given by the 'type' field. The 'name' field is for convenience and ease of +# debugging. +ReplayElement = ( + collections.namedtuple('shape_type', ['name', 'shape', 'type'])) + +# A prefix that can not collide with variable names for checkpoint files. +STORE_FILENAME_PREFIX = '$store$_' + +# This constant determines how many iterations a checkpoint is kept for. +CHECKPOINT_DURATION = 4 + + +def modulo_range(start, length, modulo): + for i in range(length): + yield (start + i) % modulo + + +def invalid_range(cursor, replay_capacity, stack_size, update_horizon): + """Returns a array with the indices of cursor-related invalid transitions. 
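Two things make the sum tree above reproducible and batch-friendly: every batch slot i draws its query from the stratum [i/n, (i+1)/n) of the total priority mass, and jax.random.fold_in(rng, i) derives a per-slot key from a single seed, so the whole batch can be vmapped. A self-contained sketch of just that query-generation rule, using toy priorities rather than this buffer's API:

    import jax
    import jax.numpy as jnp


    def stratified_queries(rng, priorities, n):
      """One query value per stratum, each in [0, priorities.sum())."""
      total = priorities.sum()

      def one_slot(i):
        # Per-slot key: deterministic for a fixed seed and slot index.
        key = jax.random.fold_in(rng, i)
        return jax.random.uniform(key, minval=i / n, maxval=(i + 1) / n) * total

      return jax.vmap(one_slot)(jnp.arange(n))


    print(stratified_queries(jax.random.PRNGKey(0),
                             jnp.array([0.5, 1.0, 2.0, 0.5]), n=4))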
+ There are update_horizon + stack_size invalid indices: + - The update_horizon indices before the cursor, because we do not have a + valid N-step transition (including the next state). + - The stack_size indices on or immediately after the cursor. + If N = update_horizon, K = stack_size, and the cursor is at c, invalid + indices are: + c - N, c - N + 1, ..., c, c + 1, ..., c + K - 1. + It handles special cases in a circular buffer in the beginning and the end. + Args: + cursor: int, the position of the cursor. + replay_capacity: int, the size of the replay memory. + stack_size: int, the size of the stacks returned by the replay memory. + update_horizon: int, the agent's update horizon. + Returns: + np.array of size stack_size with the invalid indices. + """ + assert cursor < replay_capacity + return np.array([(cursor - update_horizon + i) % replay_capacity + for i in range(stack_size + update_horizon)]) + + +@gin.configurable +class DeterministicOutOfGraphTemporalReplayBuffer(object): + """A simple out-of-graph Replay Buffer. + Stores transitions, state, action, reward, next_state, terminal (and any + extra contents specified) in a circular buffer and provides a uniform + transition sampling function. + When the states consist of stacks of observations storing the states is + inefficient. This class writes observations and constructs the stacked states + at sample time. + Attributes: + add_count: int, counter of how many transitions have been added (including + the blank ones at the beginning of an episode). + invalid_range: np.array, an array with the indices of cursor-related invalid + transitions + """ + + def __init__(self, + observation_shape, + stack_size, + replay_capacity, + batch_size, + jumps, + update_horizon=1, + gamma=0.99, + max_sample_attempts=1000, + extra_storage_types=None, + observation_dtype=np.uint8, + terminal_dtype=np.uint8, + action_shape=(), + action_dtype=np.int32, + reward_shape=(), + reward_dtype=np.float32): + """Initializes OutOfGraphReplayBuffer. + Args: + observation_shape: tuple of ints. + stack_size: int, number of frames to use in state stack. + replay_capacity: int, number of transitions to keep in memory. + batch_size: int. + update_horizon: int, length of update ('n' in n-step update). + gamma: int, the discount factor. + max_sample_attempts: int, the maximum number of attempts allowed to + get a sample. + extra_storage_types: list of ReplayElements defining the type of the extra + contents that will be stored and returned by sample_transition_batch. + observation_dtype: np.dtype, type of the observations. Defaults to + np.uint8 for Atari 2600. + terminal_dtype: np.dtype, type of the terminals. Defaults to np.uint8 for + Atari 2600. + action_shape: tuple of ints, the shape for the action vector. Empty tuple + means the action is a scalar. + action_dtype: np.dtype, type of elements in the action. + reward_shape: tuple of ints, the shape of the reward vector. Empty tuple + means the reward is a scalar. + reward_dtype: np.dtype, type of elements in the reward. + Raises: + ValueError: If replay_capacity is too small to hold at least one + transition. 
+ """ + assert isinstance(observation_shape, tuple) + if replay_capacity < update_horizon + stack_size: + raise ValueError('There is not enough capacity to cover ' + 'update_horizon and stack_size.') + + logging.info('Creating a %s replay memory with the following parameters:', + self.__class__.__name__) + logging.info('\t observation_shape: %s', str(observation_shape)) + logging.info('\t observation_dtype: %s', str(observation_dtype)) + logging.info('\t terminal_dtype: %s', str(terminal_dtype)) + logging.info('\t stack_size: %d', stack_size) + logging.info('\t replay_capacity: %d', replay_capacity) + logging.info('\t batch_size: %d', batch_size) + logging.info('\t update_horizon: %d', update_horizon) + logging.info('\t gamma: %f', gamma) + + self._action_shape = action_shape + self._action_dtype = action_dtype + self._reward_shape = reward_shape + self._reward_dtype = reward_dtype + self._observation_shape = observation_shape + self._stack_size = stack_size + self._state_shape = self._observation_shape + (self._stack_size,) + self._replay_capacity = replay_capacity + self._batch_size = batch_size + self._update_horizon = update_horizon + self._gamma = gamma + self._observation_dtype = observation_dtype + self._terminal_dtype = terminal_dtype + self._max_sample_attempts = max_sample_attempts + self._jumps = jumps + if extra_storage_types: + self._extra_storage_types = extra_storage_types + else: + self._extra_storage_types = [] + self._create_storage() + self.add_count = np.array(0) + self.invalid_range = np.zeros((self._stack_size)) + # When the horizon is > 1, we compute the sum of discounted rewards as a dot + # product using the precomputed vector . + self._cumulative_discount_vector = np.array( + [math.pow(self._gamma, n) for n in range(update_horizon)], + dtype=np.float32) + self._next_experience_is_episode_start = True + self._episode_end_indices = set() + + def _create_storage(self): + """Creates the numpy arrays used to store transitions. + """ + self._store = {} + for storage_element in self.get_storage_signature(): + array_shape = [self._replay_capacity] + list(storage_element.shape) + self._store[storage_element.name] = np.empty( + array_shape, dtype=storage_element.type) + + def get_add_args_signature(self): + """The signature of the add function. + Note - Derived classes may return a different signature. + Returns: + list of ReplayElements defining the type of the argument signature needed + by the add function. + """ + return self.get_storage_signature() + + def get_storage_signature(self): + """Returns a default list of elements to be stored in this replay memory. + Note - Derived classes may return a different signature. + Returns: + list of ReplayElements defining the type of the contents stored. + """ + storage_elements = [ + ReplayElement('observation', self._observation_shape, + self._observation_dtype), + ReplayElement('action', self._action_shape, self._action_dtype), + ReplayElement('reward', self._reward_shape, self._reward_dtype), + ReplayElement('terminal', (), self._terminal_dtype) + ] + + for extra_replay_element in self._extra_storage_types: + storage_elements.append(extra_replay_element) + return storage_elements + + def _add_zero_transition(self): + """Adds a padding transition filled with zeros (Used in episode beginnings). 
+ """ + zero_transition = [] + for element_type in self.get_add_args_signature(): + zero_transition.append( + np.zeros(element_type.shape, dtype=element_type.type)) + self._episode_end_indices.discard(self.cursor()) # If present + self._add(*zero_transition) + + def add(self, + observation, + action, + reward, + terminal, + *args, + priority=None, + episode_end=False): + """Adds a transition to the replay memory. + This function checks the types and handles the padding at the beginning of + an episode. Then it calls the _add function. + Since the next_observation in the transition will be the observation added + next there is no need to pass it. + If the replay memory is at capacity the oldest transition will be discarded. + Args: + observation: np.array with shape observation_shape. + action: int, the action in the transition. + reward: float, the reward received in the transition. + terminal: np.dtype, acts as a boolean indicating whether the transition + was terminal (1) or not (0). + *args: extra contents with shapes and dtypes according to + extra_storage_types. + priority: float, unused in the circular replay buffer, but may be used + in child classes like PrioritizedReplayBuffer. + episode_end: bool, whether this experience is the last experience in + the episode. This is useful for tasks that terminate due to time-out, + but do not end on a terminal state. Overloading 'terminal' may not + be sufficient in this case, since 'terminal' is passed to the agent + for training. 'episode_end' allows the replay buffer to determine + episode boundaries without passing that information to the agent. + """ + if priority is not None: + args = args + (priority,) + + self._check_add_types(observation, action, reward, terminal, *args) + if self._next_experience_is_episode_start: + for _ in range(self._stack_size - 1): + # Child classes can rely on the padding transitions being filled with + # zeros. This is useful when there is a priority argument. + self._add_zero_transition() + self._next_experience_is_episode_start = False + + if episode_end or terminal: + self._episode_end_indices.add(self.cursor()) + self._next_experience_is_episode_start = True + else: + self._episode_end_indices.discard(self.cursor()) # If present + + self._add(observation, action, reward, terminal, *args) + + def _add(self, *args): + """Internal add method to add to the storage arrays. + Args: + *args: All the elements in a transition. + """ + self._check_args_length(*args) + transition = { + e.name: args[idx] for idx, e in enumerate(self.get_add_args_signature()) + } + self._add_transition(transition) + + def _add_transition(self, transition): + """Internal add method to add transition dictionary to storage arrays. + Args: + transition: The dictionary of names and values of the transition + to add to the storage. + """ + cursor = self.cursor() + for arg_name in transition: + self._store[arg_name][cursor] = transition[arg_name] + + self.add_count += 1 + self.invalid_range = invalid_range(self.cursor(), self._replay_capacity, + self._stack_size, self._update_horizon) + + def _check_args_length(self, *args): + """Check if args passed to the add method have the same length as storage. + Args: + *args: Args for elements used in storage. + Raises: + ValueError: If args have wrong length. 
+ """ + if len(args) != len(self.get_add_args_signature()): + raise ValueError('Add expects {} elements, received {}'.format( + len(self.get_add_args_signature()), len(args))) + + def _check_add_types(self, *args): + """Checks if args passed to the add method match those of the storage. + Args: + *args: Args whose types need to be validated. + Raises: + ValueError: If args have wrong shape or dtype. + """ + self._check_args_length(*args) + for i, (arg_element, store_element) in enumerate( + zip(args, self.get_add_args_signature())): + if isinstance(arg_element, np.ndarray): + arg_shape = arg_element.shape + elif isinstance(arg_element, tuple) or isinstance(arg_element, list): + # TODO(b/80536437). This is not efficient when arg_element is a list. + arg_shape = np.array(arg_element).shape + else: + # Assume it is scalar. + arg_shape = tuple() + store_element_shape = tuple(store_element.shape) + if arg_shape != store_element_shape: + raise ValueError('arg {} has shape {}, expected {}'.format( + i, arg_shape, store_element_shape)) + + def is_empty(self): + """Is the Replay Buffer empty?""" + return self.add_count == 0 + + def is_full(self): + """Is the Replay Buffer full?""" + return self.add_count >= self._replay_capacity + + def cursor(self): + """Index to the location where the next transition will be written.""" + return self.add_count % self._replay_capacity + + def parallel_get_stack(self, indices, element_name): + array = self._store[element_name] + result = np.take( + array, + np.arange(-self._stack_size + 1, 1)[:, None] + indices[None, :], + axis=0, + mode="wrap") + result = np.moveaxis(result, 0, -1) + return result + + def get_observation_stack(self, index): + return self._get_element_stack(index, 'observation') + + def get_range(self, array, start_index, end_index): + """Returns the range of array at the index handling wraparound if necessary. + Args: + array: np.array, the array to get the stack from. + start_index: int, index to the start of the range to be returned. Range + will wraparound if start_index is smaller than 0. + end_index: int, exclusive end index. Range will wraparound if end_index + exceeds replay_capacity. + Returns: + np.array, with shape [end_index - start_index, array.shape[1:]]. + """ + assert end_index > start_index, 'end_index must be larger than start_index' + assert end_index >= 0 + assert start_index < self._replay_capacity + if not self.is_full(): + assert end_index <= self.cursor(), ( + 'Index {} has not been added.'.format(start_index)) + + # Fast slice read when there is no wraparound. + if start_index % self._replay_capacity < end_index % self._replay_capacity: + return_array = array[start_index:end_index, ...] + # Slow list read. + else: + indices = [(start_index + i) % self._replay_capacity + for i in range(end_index - start_index)] + return_array = array[indices, ...] + return return_array + + def get_observation_stack(self, index): + return self._get_element_stack(index, 'observation') + + def _get_element_stack(self, index, element_name): + state = self.get_range(self._store[element_name], + index - self._stack_size + 1, index + 1) + # The stacking axis is 0 but the agent expects as the last axis. + return np.moveaxis(state, 0, -1) + + def get_terminal_stack(self, index): + return self.get_range(self._store['terminal'], index - self._stack_size + 1, + index + 1) + + def is_valid_transition(self, index): + """Checks if the index contains a valid transition. 
+ Checks for collisions with the end of episodes and the current position + of the cursor. + Args: + index: int, the index to the state in the transition. + Returns: + Is the index valid: Boolean. + """ + index = int(index) + # Check the index is in the valid range + if index < 0 or index >= self._replay_capacity: + return False + if not self.is_full(): + # The indices and next_indices must be smaller than the cursor. + if index >= self.cursor() - self._update_horizon - self._jumps: + return False + # The first few indices contain the padding states of the first episode. + if index < self._stack_size - 1: + return False + + # Skip transitions that straddle the cursor. + if index in set(self.invalid_range): + return False + + # If there are terminal flags in any other frame other than the last one + # the stack is not valid, so don't sample it. + if self.get_terminal_stack(index)[:-1].any(): + return False + + # If the episode ends before the update horizon, without a terminal signal, + # it is invalid. + for i in modulo_range(index, self._update_horizon, self._replay_capacity): + if i in self._episode_end_indices and not self._store['terminal'][i]: + return False + + return True + + def _create_batch_arrays(self, batch_size): + """Create a tuple of arrays with the type of get_transition_elements. + When using the WrappedReplayBuffer with staging enabled it is important to + create new arrays every sample because StaginArea keeps a pointer to the + returned arrays. + Args: + batch_size: (int) number of transitions returned. If None the default + batch_size will be used. + Returns: + Tuple of np.arrays with the shape and type of get_transition_elements. + """ + transition_elements = self.get_transition_elements(batch_size) + batch_arrays = [] + for element in transition_elements: + batch_arrays.append(np.empty(element.shape, dtype=element.type)) + return tuple(batch_arrays) + + def sample_index_batch(self, batch_size): + """Returns a batch of valid indices sampled uniformly. + + Args: + batch_size: int, number of indices returned. + + Returns: + list of ints, a batch of valid indices sampled uniformly. + + Raises: + RuntimeError: If the batch was not constructed after maximum number of + tries. + """ + self._rng, rng = jax.random.split(self._rng) + if self.is_full(): + # add_count >= self._replay_capacity > self._stack_size + min_id = self.cursor() - self._replay_capacity + self._stack_size - 1 + max_id = self.cursor() - self._update_horizon - self._jumps + else: + # add_count < self._replay_capacity + min_id = self._stack_size - 1 + max_id = self.cursor() - self._update_horizon - self._jumps + if max_id <= min_id: + raise RuntimeError('Cannot sample a batch with fewer than stack size ' + '({}) + update_horizon ({}) transitions.'.format( + self._stack_size, self._update_horizon)) + indices = jax.random.randint(rng, (batch_size,), min_id, + max_id) % self._replay_capacity + allowed_attempts = self._max_sample_attempts + indices = np.array(indices) + for i in range(len(indices)): + if not self.is_valid_transition(indices[i]): + if allowed_attempts == 0: + raise RuntimeError( + 'Max sample attempts: Tried {} times but only sampled {}' + ' valid indices. Batch size is {}'.format( + self._max_sample_attempts, i, batch_size)) + index = indices[i] + while not self.is_valid_transition(index) and allowed_attempts > 0: + # If index i is not valid keep sampling others. Note that this + # is not stratified. 
+ self._rng, rng = jax.random.split(self._rng) + index = jax.random.randint(rng, + (), min_id, max_id) % self._replay_capacity + allowed_attempts -= 1 + indices[i] = index + return indices + + def restore_leading_dims(self, batch_size, jumps, tensor): + return tensor.reshape(batch_size, jumps, *tensor.shape[1:]) + + def sample_transition_batch(self, + rng, + batch_size=None, + indices=None, + jumps=None): + """Returns a batch of transitions (including any extra contents). + If get_transition_elements has been overridden and defines elements not + stored in self._store, an empty array will be returned and it will be + left to the child class to fill it. For example, for the child class + OutOfGraphPrioritizedReplayBuffer, the contents of the + sampling_probabilities are stored separately in a sum tree. + When the transition is terminal next_state_batch has undefined contents. + NOTE: This transition contains the indices of the sampled elements. These + are only valid during the call to sample_transition_batch, i.e. they may + be used by subclasses of this replay buffer but may point to different data + as soon as sampling is done. + Args: + rng: JAX PRNG key used to sample indices when `indices` is None. + batch_size: int, number of transitions returned. If None, the default + batch_size will be used. + indices: None or list of ints, the indices of every transition in the + batch. If None, sample the indices uniformly. + jumps: None or int, number of consecutive transitions returned per index. + If None, the default self._jumps will be used. + Returns: + transition_batch: tuple of np.arrays with the shape and type as in + get_transition_elements(). + Raises: + ValueError: If an element to be sampled is missing from the replay buffer. + """ + self._rng = rng + if batch_size is None: + batch_size = self._batch_size + if jumps is None: + jumps = self._jumps + if indices is None: + indices = self.sample_index_batch(batch_size) + assert len(indices) == batch_size + transition_elements = self.get_transition_elements(batch_size) + state_indices = indices[:, None] + np.arange(jumps)[None, :] + state_indices = state_indices.reshape(batch_size * jumps) + + # Shape: update_horizon x batch_size*jumps. + # Offset by one: the window starts one step before each state; the leading + # terminal is zeroed out below, and rewards are read at + # trajectory_indices + 1. + trajectory_indices = (np.arange(-1, self._update_horizon - 1)[:, None] + + state_indices[None, :]) % self._replay_capacity + trajectory_terminals = self._store["terminal"][trajectory_indices] + trajectory_terminals[0, :] = 0 + is_terminal_transition = trajectory_terminals.any(0) + valid_mask = (1 - trajectory_terminals).cumprod(0) + trajectory_discount_vector = valid_mask * self._cumulative_discount_vector[:, + None] + trajectory_rewards = self._store['reward'][(trajectory_indices + 1) % + self._replay_capacity] + returns = np.sum(trajectory_discount_vector * trajectory_rewards, axis=0) + + next_indices = (state_indices + + self._update_horizon) % self._replay_capacity + outputs = [] + + for element in transition_elements: + name = element.name + if name == 'state': + output = self.parallel_get_stack(state_indices, "observation") + output = self.restore_leading_dims(batch_size, jumps, output) + elif name == 'reward': + # Compute the discounted sum of rewards in the trajectory.
+ output = returns + output = self.restore_leading_dims(batch_size, jumps, output) + elif name == 'next_state': + output = self.parallel_get_stack(next_indices, "observation") + output = self.restore_leading_dims(batch_size, jumps, output) + elif name == "same_trajectory": + output = self._store["terminal"][state_indices] + output = self.restore_leading_dims(batch_size, jumps, output) + output[0, :] = 0 + output = (1 - output).cumprod(1) + elif name == 'valid': + output = np.array([self.is_valid_transition(i) for i in state_indices]) + output = self.restore_leading_dims(batch_size, jumps, output) + elif name in ('next_action', 'next_reward'): + output = self._store[name.lstrip('next_')][next_indices] + output = self.restore_leading_dims(batch_size, jumps, output) + elif element.name == 'terminal': + output = is_terminal_transition + output = self.restore_leading_dims(batch_size, jumps, output) + elif name == 'indices': + output = indices + elif name in self._store.keys(): + output = self._store[name][state_indices] + output = self.restore_leading_dims(batch_size, jumps, output) + else: + continue + outputs.append(output) + return outputs + + def get_transition_elements(self, batch_size=None, jumps=None): + """Returns a 'type signature' for sample_transition_batch. + Args: + batch_size: int, number of transitions returned. If None, the default + batch_size will be used. + Returns: + signature: A namedtuple describing the method's return type signature. + """ + jumps = self._jumps if jumps is None else jumps + batch_size = self._batch_size if batch_size is None else batch_size + + transition_elements = [ + ReplayElement('state', (batch_size, jumps) + self._state_shape, + self._observation_dtype), + ReplayElement('action', (batch_size, jumps) + self._action_shape, + self._action_dtype), + ReplayElement('reward', (batch_size, jumps) + self._reward_shape, + self._reward_dtype), + ReplayElement('next_state', (batch_size, jumps) + self._state_shape, + self._observation_dtype), + ReplayElement('next_action', (batch_size, jumps) + self._action_shape, + self._action_dtype), + ReplayElement('next_reward', (batch_size, jumps) + self._reward_shape, + self._reward_dtype), + ReplayElement('terminal', (batch_size, jumps), self._terminal_dtype), + ReplayElement('same_trajectory', (batch_size, jumps), + self._terminal_dtype), + ReplayElement('valid', (batch_size, jumps), self._terminal_dtype), + ReplayElement('indices', (batch_size,), np.int32) + ] + for element in self._extra_storage_types: + transition_elements.append( + ReplayElement(element.name, + (batch_size, jumps) + tuple(element.shape), + element.type)) + return transition_elements + + def _generate_filename(self, checkpoint_dir, name, suffix): + return os.path.join(checkpoint_dir, '{}_ckpt.{}.gz'.format(name, suffix)) + + def _return_checkpointable_elements(self): + """Return the dict of elements of the class for checkpointing. + Returns: + checkpointable_elements: dict containing all non private (starting with + _) members + all the arrays inside self._store. + """ + checkpointable_elements = {} + for member_name, member in self.__dict__.items(): + if member_name == '_store': + for array_name, array in self._store.items(): + checkpointable_elements[STORE_FILENAME_PREFIX + array_name] = array + elif not member_name.startswith('_'): + checkpointable_elements[member_name] = member + return checkpointable_elements + + def save(self, checkpoint_dir, iteration_number): + """Save the OutOfGraphReplayBuffer attributes into a file. 
+ This method will save all the replay buffer's state in a single file. + Args: + checkpoint_dir: str, the directory where numpy checkpoint files should be + saved. + iteration_number: int, iteration_number to use as a suffix in naming + numpy checkpoint files. + """ + if not tf.io.gfile.exists(checkpoint_dir): + return + + checkpointable_elements = self._return_checkpointable_elements() + + for attr in checkpointable_elements: + filename = self._generate_filename(checkpoint_dir, attr, iteration_number) + with tf.io.gfile.GFile(filename, 'wb') as f: + with gzip.GzipFile(fileobj=f) as outfile: + # Checkpoint the np arrays in self._store with np.save instead of + # pickling the dictionary is critical for file size and performance. + # STORE_FILENAME_PREFIX indicates that the variable is contained in + # self._store. + if attr.startswith(STORE_FILENAME_PREFIX): + array_name = attr[len(STORE_FILENAME_PREFIX):] + np.save(outfile, self._store[array_name], allow_pickle=False) + # Some numpy arrays might not be part of storage + elif isinstance(self.__dict__[attr], np.ndarray): + np.save(outfile, self.__dict__[attr], allow_pickle=False) + else: + pickle.dump(self.__dict__[attr], outfile) + + # After writing a checkpoint file, we garbage collect the checkpoint file + # that is four versions old. + stale_iteration_number = iteration_number - CHECKPOINT_DURATION + if stale_iteration_number >= 0: + stale_filename = self._generate_filename(checkpoint_dir, attr, + stale_iteration_number) + try: + tf.io.gfile.remove(stale_filename) + except tf.errors.NotFoundError: + pass + + def load(self, checkpoint_dir, suffix): + """Restores the object from bundle_dictionary and numpy checkpoints. + Args: + checkpoint_dir: str, the directory where to read the numpy checkpointed + files from. + suffix: str, the suffix to use in numpy checkpoint files. + Raises: + NotFoundError: If not all expected files are found in directory. + """ + save_elements = self._return_checkpointable_elements() + # We will first make sure we have all the necessary files available to avoid + # loading a partially-specified (i.e. corrupted) replay buffer. + for attr in save_elements: + filename = self._generate_filename(checkpoint_dir, attr, suffix) + if not tf.io.gfile.exists(filename): + raise tf.errors.NotFoundError(None, None, + 'Missing file: {}'.format(filename)) + # If we've reached this point then we have verified that all expected files + # are available. 
+ for attr in save_elements: + filename = self._generate_filename(checkpoint_dir, attr, suffix) + with tf.io.gfile.GFile(filename, 'rb') as f: + with gzip.GzipFile(fileobj=f) as infile: + if attr.startswith(STORE_FILENAME_PREFIX): + array_name = attr[len(STORE_FILENAME_PREFIX):] + self._store[array_name] = np.load(infile, allow_pickle=False) + elif isinstance(self.__dict__[attr], np.ndarray): + self.__dict__[attr] = np.load(infile, allow_pickle=False) + else: + self.__dict__[attr] = pickle.load(infile) + + +@gin.configurable +class DeterministicOutOfGraphPrioritizedTemporalReplayBuffer( + DeterministicOutOfGraphTemporalReplayBuffer): + """Deterministic version of prioritized replay buffer.""" + + def __init__(self, + observation_shape, + stack_size, + replay_capacity, + batch_size, + update_horizon=1, + jumps=0, + gamma=0.99, + max_sample_attempts=1000, + extra_storage_types=None, + observation_dtype=np.uint8, + terminal_dtype=np.uint8, + action_shape=(), + action_dtype=np.int32, + reward_shape=(), + reward_dtype=np.float32): + super().__init__( + observation_shape=observation_shape, + stack_size=stack_size, + replay_capacity=replay_capacity, + batch_size=batch_size, + update_horizon=update_horizon, + gamma=gamma, + max_sample_attempts=max_sample_attempts, + extra_storage_types=extra_storage_types, + observation_dtype=observation_dtype, + terminal_dtype=terminal_dtype, + jumps=jumps, + action_shape=action_shape, + action_dtype=action_dtype, + reward_shape=reward_shape, + reward_dtype=reward_dtype) + + self.sum_tree = sum_tree.DeterministicSumTree(replay_capacity) + + def get_add_args_signature(self): + """The signature of the add function.""" + parent_add_signature = super().get_add_args_signature() + add_signature = parent_add_signature + [ + ReplayElement('priority', (), np.float32) + ] + return add_signature + + def _add(self, *args): + """Internal add method to add to the underlying memory arrays.""" + self._check_args_length(*args) + + # Use Schaul et al.'s (2015) scheme of setting the priority of new elements + # to the maximum priority so far. + # Picks out 'priority' from arguments and adds it to the sum_tree. + transition = {} + for i, element in enumerate(self.get_add_args_signature()): + if element.name == 'priority': + priority = args[i] + else: + transition[element.name] = args[i] + + self.sum_tree.set(self.cursor(), priority) + super()._add_transition(transition) + + def sample_index_batch(self, batch_size): + """Returns a batch of valid indices sampled as in Schaul et al. (2015).""" + # Sample stratified indices. Some of them might be invalid. + indices = self.sum_tree.stratified_sample(batch_size, self._rng) + allowed_attempts = self._max_sample_attempts + indices = np.array(indices) + for i in range(len(indices)): + if not self.is_valid_transition(indices[i]): + if allowed_attempts == 0: + raise RuntimeError( + 'Max sample attempts: Tried {} times but only sampled {}' + ' valid indices. Batch size is {}'.format( + self._max_sample_attempts, i, batch_size)) + index = indices[i] + while not self.is_valid_transition(index) and allowed_attempts > 0: + # If index i is not valid keep sampling others. + # Note that this is not stratified. 
+ self._rng, rng = jax.random.split(self._rng) + index = int(self.sum_tree.sample(rng=rng)) + allowed_attempts -= 1 + indices[i] = index + return indices + + def sample_transition_batch(self, rng, batch_size=None, indices=None): + """Returns a batch of transitions with extra storage and the priorities.""" + transition = super().sample_transition_batch(rng, batch_size, indices) + + # By convention, the indices are always the last element of the batch + indices = transition[-1] + priority = self.get_priority(indices) + transition.append(priority) + return transition + + def set_priority(self, indices, priorities): + """Sets the priority of the given elements according to Schaul et al.""" + assert indices.dtype == np.int32, ('Indices must be integers, ' + 'given: {}'.format(indices.dtype)) + for index, priority in zip(indices, priorities): + self.sum_tree.set(index, priority) + + def get_priority(self, indices): + """Fetches the priorities correspond to a batch of memory indices.""" + assert indices.shape, 'Indices must be an array.' + assert indices.dtype == np.int32, ('Indices must be int32s, ' + 'given: {}'.format(indices.dtype)) + batch_size = len(indices) + priority_batch = np.empty((batch_size), dtype=np.float32) + for i, memory_index in enumerate(indices): + priority_batch[i] = self.sum_tree.get(memory_index) + return priority_batch + + def get_transition_elements(self, batch_size=None): + """Returns a 'type signature' for sample_transition_batch.""" + parent_transition_type = (super().get_transition_elements(batch_size)) + probablilities_type = [ + ReplayElement('sampling_probabilities', (batch_size,), np.float32) + ] + return parent_transition_type + probablilities_type diff --git a/dopamine/labs/atari_100k/spr_agent.py b/dopamine/labs/atari_100k/spr_agent.py new file mode 100644 index 00000000..51396583 --- /dev/null +++ b/dopamine/labs/atari_100k/spr_agent.py @@ -0,0 +1,531 @@ +# coding=utf-8 +# Copyright 2021 The Atari 100k Precipice Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""An implementation of SPR in Jax. + +Includes the features included in the full Rainbow agent. Designed to work with +an optimized replay buffer that returns subsequences rather than individual +transitions. + +Some details differ from the original implementation due to differences in +the underlying Rainbow implementations. In particular: +* Dueling networks in Dopamine separate at the final layer, not the penultimate + layer as in the original. +* Dopamine's prioritized experience replay does not decay its exponent over time. + +We find that these changes do not drastically impact the overall performance of +the algorithm, however. + +Details on Rainbow are available in +"Rainbow: Combining Improvements in Deep Reinforcement Learning" by Hessel et +al. (2018). For details on SPR, see +"Data-Efficient Reinforcement Learning with Self-Predictive Representations" by +Schwarzer et al (2021). 
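One reading aid for the loss used by this agent: both the online predictions and the target projections are l2-normalized before the squared error is taken, so the SPR term is a rescaled negative cosine similarity (for unit vectors, ||p - t||^2 = 2 - 2 cos(p, t)). A minimal sketch of that term, with hypothetical [jumps, batch, latent_dim] shapes matching the layout the training step transposes to:

    import jax.numpy as jnp


    def spr_consistency_loss(predictions, targets):
      """Squared error between l2-normalized latents, summed over features.

      predictions, targets: float arrays of shape [jumps, batch, latent_dim]
      (hypothetical shapes for illustration).
      """
      predictions = predictions / jnp.linalg.norm(
          predictions, 2, -1, keepdims=True)
      targets = targets / jnp.linalg.norm(targets, 2, -1, keepdims=True)
      return jnp.power(predictions - targets, 2).sum(-1)  # -> [jumps, batch]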
+""" + +import collections +import copy +import functools +import time + +from absl import logging +from dopamine.jax import losses +from dopamine.jax.agents.dqn import dqn_agent +from dopamine.jax.agents.rainbow import rainbow_agent as dopamine_rainbow_agent +from dopamine.replay_memory import prioritized_replay_buffer +import gin +import jax +import jax.numpy as jnp +import numpy as onp +import tensorflow as tf +import optax +from dopamine.labs.atari_100k import spr_networks as networks +from dopamine.labs.atari_100k.replay_memory import time_batch_replay_buffer as tdrbs + + +@functools.partial(jax.jit, static_argnums=(0, 4, 5, 6, 7, 8, 10, 11)) +def select_action(network_def, params, state, rng, num_actions, eval_mode, + epsilon_eval, epsilon_train, epsilon_decay_period, + training_steps, min_replay_history, epsilon_fn, support): + """Select an action from the set of available actions.""" + epsilon = jnp.where( + eval_mode, epsilon_eval, + epsilon_fn(epsilon_decay_period, training_steps, min_replay_history, + epsilon_train)) + + rng, rng1, rng2, rng3 = jax.random.split(rng, num=4) + p = jax.random.uniform(rng1) + q_values = network_def.apply( + params, state, key=rng2, eval_mode=eval_mode, support=support).q_values + + best_actions = jnp.argmax(q_values, axis=-1) + return rng, jnp.where(p <= epsilon, + jax.random.randint(rng3, (), 0, num_actions), + best_actions) + + +@functools.partial( + jax.vmap, in_axes=(None, 0, 0, None, None), axis_name="batch") +def get_logits(model, states, actions, do_rollout, rng): + results = model(states, actions=actions, do_rollout=do_rollout, key=rng)[0] + return results.logits, results.latent + + +@functools.partial( + jax.vmap, in_axes=(None, 0, 0, None, None), axis_name="batch") +def get_q_values(model, states, actions, do_rollout, rng): + results = model(states, actions=actions, do_rollout=do_rollout, key=rng)[0] + return results.q_values, results.latent + + +@functools.partial(jax.vmap, in_axes=(None, 0, None), axis_name="batch") +def get_spr_targets(model, states, key): + results = model(states, key) + return results + + +@functools.partial(jax.jit, static_argnums=(0, 3, 13, 14, 15, 17)) +def train(network_def, online_params, target_params, optimizer, optimizer_state, states, actions, next_states, + rewards, terminals, same_traj_mask, loss_weights, support, + cumulative_gamma, double_dqn, distributional, rng, spr_weight): + """Run a training step.""" + + current_state = states[:, 0] + online_params = online_params + # Split the current rng into 2 for updating the rng after this call + rng, rng1, rng2 = jax.random.split(rng, num=3) + use_spr = spr_weight > 0 + + def q_online(state, key, actions=None, do_rollout=False): + return network_def.apply( + online_params, + state, + actions=actions, + do_rollout=do_rollout, + key=key, + support=support, + mutable=["batch_stats"]) + + def q_target(state, key): + return network_def.apply( + target_params, state, key=key, support=support, mutable=["batch_stats"]) + + def encode_project(state, key): + latent, _ = network_def.apply( + target_params, + state, + method=network_def.encode, + mutable=["batch_stats"]) + latent = latent.reshape(-1) + return network_def.apply( + target_params, + latent, + key=key, + eval_mode=True, + method=network_def.project) + + def loss_fn(params, target, spr_targets, loss_multipliers): + """Computes the distributional loss for C51 or huber loss for DQN.""" + + def q_online(state, key, actions=None, do_rollout=False): + return network_def.apply( + params, + state, + actions=actions, + 
do_rollout=do_rollout, + key=key, + support=support, + mutable=["batch_stats"]) + + if distributional: + (logits, spr_predictions) = get_logits(q_online, current_state, + actions[:, :-1], use_spr, rng) + logits = jnp.squeeze(logits) + # Fetch the logits for its selected action. We use vmap to perform this + # indexing across the batch. + chosen_action_logits = jax.vmap(lambda x, y: x[y])(logits, actions[:, 0]) + dqn_loss = jax.vmap(losses.softmax_cross_entropy_loss_with_logits)( + target, chosen_action_logits) + else: + q_values, spr_predictions = get_q_values(q_online, current_state, + actions[:, :-1], use_spr, rng) + q_values = jnp.squeeze(q_values) + replay_chosen_q = jax.vmap(lambda x, y: x[y])(q_values, actions[:, 0]) + dqn_loss = jax.vmap(losses.huber_loss)(target, replay_chosen_q) + + if use_spr: + spr_predictions = spr_predictions.transpose(1, 0, 2) + spr_predictions = spr_predictions / jnp.linalg.norm( + spr_predictions, 2, -1, keepdims=True) + spr_targets = spr_targets / jnp.linalg.norm( + spr_targets, 2, -1, keepdims=True) + spr_loss = jnp.power(spr_predictions - spr_targets, 2).sum(-1) + spr_loss = (spr_loss * same_traj_mask.transpose(1, 0)).mean(0) + else: + spr_loss = 0 + + loss = dqn_loss + spr_weight * spr_loss + + mean_loss = jnp.mean(loss_multipliers * loss) + return mean_loss, (loss, dqn_loss, spr_loss) + + # Use the weighted mean loss for gradient computation. + grad_fn = jax.value_and_grad(loss_fn, has_aux=True) + target = target_output(q_online, q_target, next_states, rewards, terminals, + support, cumulative_gamma, double_dqn, distributional, + rng1) + + if use_spr: + future_states = states[:, 1:] + spr_targets = get_spr_targets( + encode_project, future_states.reshape(-1, *future_states.shape[2:]), + rng1) + spr_targets = spr_targets.reshape(*future_states.shape[:2], + *spr_targets.shape[1:]).transpose( + 1, 0, 2) + else: + spr_targets = None + + # Get the unweighted loss without taking its mean for updating priorities. + (mean_loss, (loss, dqn_loss, + spr_loss)), grad = grad_fn(online_params, target, spr_targets, + loss_weights) + updates, optimizer_state = optimizer.update(grad, optimizer_state) + online_params = optax.apply_updates(online_params, updates) + return optimizer_state, online_params, loss, mean_loss, dqn_loss, spr_loss, rng2 + + +@functools.partial( + jax.vmap, + in_axes=(None, None, 0, 0, 0, None, None, None, None, None), + axis_name="batch") +def target_output(model, target_network, next_states, rewards, terminals, + support, cumulative_gamma, double_dqn, distributional, rng): + """Builds the C51 target distribution or DQN target Q-values.""" + is_terminal_multiplier = 1. - terminals.astype(jnp.float32) + # Incorporate terminal state to discount factor. 
+ gamma_with_terminal = cumulative_gamma * is_terminal_multiplier + + target_network_dist, _ = target_network(next_states, key=rng) + if double_dqn: + # Use the current network for the action selection + next_state_target_outputs, _ = model(next_states, key=rng) + else: + next_state_target_outputs = target_network_dist + # Action selection using Q-values for next-state + q_values = jnp.squeeze(next_state_target_outputs.q_values) + next_qt_argmax = jnp.argmax(q_values) + + if distributional: + # Compute the target Q-value distribution + probabilities = jnp.squeeze(target_network_dist.probabilities) + next_probabilities = probabilities[next_qt_argmax] + target_support = rewards + gamma_with_terminal * support + target = dopamine_rainbow_agent.project_distribution( + target_support, next_probabilities, support) + else: + # Compute the target Q-value + next_q_values = jnp.squeeze(target_network_dist.q_values) + replay_next_qt_max = next_q_values[next_qt_argmax] + target = rewards + gamma_with_terminal * replay_next_qt_max + + return jax.lax.stop_gradient(target) + + +@gin.configurable +class SPRAgent(dqn_agent.JaxDQNAgent): + """A compact implementation of SPR in Jax.""" + + def __init__(self, + num_actions, + noisy=False, + dueling=False, + double_dqn=False, + distributional=True, + data_augmentation=False, + num_updates_per_train_step=2, + network=networks.SPRNetwork, + num_atoms=51, + vmax=10., + vmin=None, + jumps=5, + spr_weight=5, + log_every=1, + epsilon_fn=dqn_agent.linearly_decaying_epsilon, + replay_scheme='prioritized', + replay_type='deterministic', + summary_writer=None, + seed=None): + """Initializes the agent and constructs the necessary components. + + Args: + num_actions: int, number of actions the agent can take at any state. + noisy: bool, Whether to use noisy networks or not. + dueling: bool, Whether to use dueling network architecture or not. + double_dqn: bool, Whether to use Double DQN or not. + distributional: bool, whether to use distributional RL or not. + data_augmentation: bool, Whether to use data augmentation or not. + num_updates_per_train_step: int, Number of gradient updates every training + step. Defaults to 1. + network: flax.linen Module, neural network used by the agent initialized + by shape in _create_network below. See + dopamine.jax.networks.RainbowNetwork as an example. + num_atoms: int, the number of buckets of the value function distribution. + vmax: float, the value distribution support is [vmin, vmax]. + vmin: float, the value distribution support is [vmin, vmax]. If vmin is + None, it is set to -vmax. + epsilon_fn: function expecting 4 parameters: (decay_period, step, + warmup_steps, epsilon). This function should return the epsilon value + used for exploration during training. + replay_scheme: str, 'prioritized' or 'uniform', the sampling scheme of the + replay memory. + replay_type: str, 'deterministic' or 'regular', specifies the type of + replay buffer to create. + summary_writer: SummaryWriter object, for outputting training statistics. + seed: int, a seed for Jax RNG and initialization. 
+ """ + logging.info('Creating %s agent with the following parameters:', + self.__class__.__name__) + logging.info('\t double_dqn: %s', double_dqn) + logging.info('\t noisy_networks: %s', noisy) + logging.info('\t dueling_dqn: %s', dueling) + logging.info('\t distributional: %s', distributional) + logging.info('\t data_augmentation: %s', data_augmentation) + logging.info('\t replay_scheme: %s', replay_scheme) + logging.info('\t num_updates_per_train_step: %d', + num_updates_per_train_step) + # We need this because some tools convert round floats into ints. + vmax = float(vmax) + self._num_atoms = num_atoms + vmin = vmin if vmin else -vmax + self._support = jnp.linspace(vmin, vmax, num_atoms) + self._replay_scheme = replay_scheme + self._replay_type = replay_type + self._double_dqn = double_dqn + self._noisy = noisy + self._dueling = dueling + self._distributional = distributional + self._data_augmentation = data_augmentation + self._num_updates_per_train_step = num_updates_per_train_step + self._jumps = jumps + self.spr_weight = spr_weight + self.log_every = log_every + super().__init__( + num_actions=num_actions, + network=functools.partial( + network, + num_atoms=num_atoms, + noisy=self._noisy, + dueling=self._dueling, + distributional=self._distributional), + epsilon_fn=dqn_agent.identity_epsilon if self._noisy else epsilon_fn, + summary_writer=summary_writer, + seed=seed) + + def _build_networks_and_optimizer(self): + self._rng, rng = jax.random.split(self._rng) + self.online_params = self.network_def.init( + rng, + x=self.state, + actions=jnp.zeros((5,)), + do_rollout=self.spr_weight > 0, + support=self._support) + self.optimizer = dqn_agent.create_optimizer(self._optimizer_name) + self.optimizer_state = self.optimizer.init(self.online_params) + self.target_network_params = copy.deepcopy(self.online_params) + + def _build_replay_buffer(self): + """Creates the replay buffer used by the agent.""" + if self._replay_scheme not in ['uniform', 'prioritized']: + raise ValueError('Invalid replay scheme: {}'.format(self._replay_scheme)) + if self._replay_type not in ['deterministic']: + raise ValueError('Invalid replay type: {}'.format(self._replay_type)) + if self._replay_scheme == "prioritized": + return tdrbs.DeterministicOutOfGraphPrioritizedTemporalReplayBuffer( + observation_shape=self.observation_shape, + stack_size=self.stack_size, + update_horizon=self.update_horizon, + gamma=self.gamma, + jumps=self._jumps + 1, + observation_dtype=self.observation_dtype, + ) + else: + return tdrbs.DeterministicOutOfGraphTemporalReplayBuffer( + observation_shape=self.observation_shape, + stack_size=self.stack_size, + update_horizon=self.update_horizon, + gamma=self.gamma, + jumps=self._jumps + 1, + observation_dtype=self.observation_dtype, + ) + + def _sample_from_replay_buffer(self): + self._rng, rng = jax.random.split(self._rng) + samples = self._replay.sample_transition_batch(rng) + types = self._replay.get_transition_elements() + self.replay_elements = collections.OrderedDict() + for element, element_type in zip(samples, types): + self.replay_elements[element_type.name] = element + + def _training_step_update(self): + """Gradient update during every training step.""" + self._sample_from_replay_buffer() + + # Add code for data augmentation. 
+ self._rng, rng1, rng2 = jax.random.split(self._rng, num=3) + states = networks.process_inputs( + self.replay_elements['state'], + rng=rng1, + data_augmentation=self._data_augmentation) + next_states = networks.process_inputs( + self.replay_elements['next_state'][:, 0], + rng=rng2, + data_augmentation=self._data_augmentation) + + if self._replay_scheme == 'prioritized': + # The original prioritized experience replay uses a linear exponent + # schedule 0.4 -> 1.0. Comparing the schedule to a fixed exponent of + # 0.5 on 5 games (Asterix, Pong, Q*Bert, Seaquest, Space Invaders) + # suggested a fixed exponent actually performs better, except on Pong. + probs = self.replay_elements['sampling_probabilities'] + # Weight the loss by the inverse priorities. + loss_weights = 1.0 / jnp.sqrt(probs + 1e-10) + loss_weights /= jnp.max(loss_weights) + else: + # Uniform weights if not using prioritized replay. + loss_weights = jnp.ones(states.shape[0]) + + self.optimizer_state, self.online_params, loss, mean_loss,\ + dqn_loss, spr_loss,\ + self._rng = train( + self.network_def, self.online_params, self.target_network_params, + self.optimizer, self.optimizer_state, + states, + self.replay_elements['action'], + next_states, + self.replay_elements['reward'][:, 0], + self.replay_elements['terminal'][:, 0], + self.replay_elements['same_trajectory'][:, 1:], loss_weights, + self._support, self.cumulative_gamma, self._double_dqn, + self._distributional, self._rng, self.spr_weight + ) + + if self._replay_scheme == 'prioritized': + # Rainbow and prioritized replay are parametrized by an exponent + # alpha, but in both cases it is set to 0.5 - for simplicity's sake we + # leave it as is here, using the more direct sqrt(). Taking the square + # root "makes sense", as we are dealing with a squared loss. Add a + # small nonzero value to the loss to avoid 0 priority items. While + # technically this may be okay, setting all items to 0 priority will + # cause troubles, and also result in 1.0 / 0.0 = NaN correction terms. + self._replay.set_priority(self.replay_elements['indices'], + jnp.sqrt(dqn_loss + 1e-10)) + + if self.summary_writer is not None: + summary = tf.compat.v1.Summary(value=[ + tf.compat.v1.Summary.Value( + tag='TotalLoss', simple_value=float(mean_loss)), + tf.compat.v1.Summary.Value( + tag='DQNLoss', simple_value=float(dqn_loss.mean())), + tf.compat.v1.Summary.Value( + tag='SPRLoss', simple_value=float(spr_loss.mean())) + ]) + self.summary_writer.add_summary(summary, self.training_steps) + + def _store_transition(self, + last_observation, + action, + reward, + is_terminal, + *args, + priority=None, + episode_end=False): + """Stores a transition when in training mode.""" + is_prioritized = ( + isinstance(self._replay, + prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer) + or isinstance( + self._replay, + tdrbs.DeterministicOutOfGraphPrioritizedTemporalReplayBuffer)) + if is_prioritized and priority is None: + if self._replay_scheme == 'uniform': + priority = 1. + else: + priority = self._replay.sum_tree.max_recorded_priority + + if not self.eval_mode: + self._replay.add( + last_observation, + action, + reward, + is_terminal, + *args, + priority=priority, + episode_end=episode_end) + + def _train_step(self): + """Runs a single training step. + + Runs training if both: + (1) A minimum number of frames have been added to the replay buffer. + (2) `training_steps` is a multiple of `update_period`. 
+ + Also, syncs weights from online_network_params to target_network_params if + training steps is a multiple of target update period. + """ + if self._replay.add_count > self.min_replay_history: + if self.training_steps % self.update_period == 0: + for _ in range(self._num_updates_per_train_step): + self._training_step_update() + + if self.training_steps % self.target_update_period == 0: + self._sync_weights() + + self.training_steps += 1 + + def begin_episode(self, observation): + """Returns the agent's first action for this episode.""" + self._reset_state() + self._record_observation(observation) + + if not self.eval_mode: + self._train_step() + + state = networks.process_inputs(self.state, data_augmentation=False) + self._rng, self.action = select_action( + self.network_def, self.online_params, state, self._rng, + self.num_actions, self.eval_mode, self.epsilon_eval, self.epsilon_train, + self.epsilon_decay_period, self.training_steps, self.min_replay_history, + self.epsilon_fn, self._support) + + self.action = onp.asarray(self.action) + return self.action + + def step(self, reward, observation): + """Records the most recent transition and returns the agent's next action.""" + self._last_observation = self._observation + self._record_observation(observation) + + if not self.eval_mode: + self._store_transition(self._last_observation, self.action, reward, False) + self._train_step() + + state = networks.process_inputs(self.state, data_augmentation=False) + self._rng, self.action = select_action( + self.network_def, self.online_params, state, self._rng, + self.num_actions, self.eval_mode, self.epsilon_eval, self.epsilon_train, + self.epsilon_decay_period, self.training_steps, self.min_replay_history, + self.epsilon_fn, self._support) + self.action = onp.asarray(self.action) + return self.action diff --git a/dopamine/labs/atari_100k/spr_networks.py b/dopamine/labs/atari_100k/spr_networks.py new file mode 100644 index 00000000..a0ee5227 --- /dev/null +++ b/dopamine/labs/atari_100k/spr_networks.py @@ -0,0 +1,462 @@ +# coding=utf-8 +# Copyright 2021 The Atari 100k Precipice Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Various networks for Jax Dopamine agents.""" + +import functools +import time +import collections + +SPROutputType = collections.namedtuple( + 'RL_network', ['q_values', 'logits', 'probabilities', "latent"]) +from flax import linen as nn +import gin +import jax +from jax import random +import numpy as onp + +from typing import (Any, Callable, Optional, Tuple) + +from jax import lax +from jax.nn import initializers +import jax.numpy as jnp + +PRNGKey = Any +Array = Any +Shape = Tuple[int] +Dtype = Any + + +def _absolute_dims(rank, dims): + return tuple([rank + dim if dim < 0 else dim for dim in dims]) + + +from flax.linen.module import Module, compact, merge_param + +# --------------------------- < Data Augmentation >------------------------------ + + +def _random_crop(key, img, cropped_shape): + """Random crop an image.""" + _, width, height = cropped_shape[:-1] + key_x, key_y = random.split(key, 2) + x = random.randint(key_x, shape=(), minval=0, maxval=img.shape[1] - width) + y = random.randint(key_y, shape=(), minval=0, maxval=img.shape[2] - height) + return img[:, x:x + width, y:y + height] + + +# @functools.partial(jax.jit, static_argnums=(3,)) +@functools.partial(jax.vmap, in_axes=(0, 0, 0, None)) +def _crop_with_indices(img, x, y, cropped_shape): + cropped_image = (jax.lax.dynamic_slice(img, [x, y, 0], cropped_shape[1:])) + return cropped_image + + +def _per_image_random_crop(key, img, cropped_shape): + """Random crop an image.""" + batch_size, width, height = cropped_shape[:-1] + key_x, key_y = random.split(key, 2) + x = random.randint( + key_x, shape=(batch_size,), minval=0, maxval=img.shape[1] - width) + y = random.randint( + key_y, shape=(batch_size,), minval=0, maxval=img.shape[2] - height) + return _crop_with_indices(img, x, y, cropped_shape) + + +def _intensity_aug(key, x, scale=0.05): + """Follows the code in Schwarzer et al. (2020) for intensity augmentation.""" + r = random.normal(key, shape=(x.shape[0], 1, 1, 1)) + noise = 1.0 + (scale * jnp.clip(r, -2.0, 2.0)) + return x * noise + + +@jax.jit +def drq_image_aug(key, obs, img_pad=4): + """Padding and cropping for DrQ.""" + flat_obs = obs.reshape(-1, *obs.shape[-3:]) + paddings = [(0, 0), (img_pad, img_pad), (img_pad, img_pad), (0, 0)] + cropped_shape = flat_obs.shape + # The reference uses ReplicationPad2d in pytorch, but it is not available + # in Jax. Use 'edge' instead. + flat_obs = jnp.pad(flat_obs, paddings, 'edge') + key1, key2 = random.split(key, num=2) + cropped_obs = _per_image_random_crop(key2, flat_obs, cropped_shape) + # cropped_obs = _random_crop(key2, flat_obs, cropped_shape) + aug_obs = _intensity_aug(key1, cropped_obs) + return aug_obs.reshape(*obs.shape) + + +# --------------------------- < NoisyNetwork >--------------------------------- + + +@gin.configurable +class NoisyNetwork(nn.Module): + """Noisy Network from Fortunato et al. (2018).""" + features: int = 512 + + @staticmethod + def sample_noise(key, shape): + return random.normal(key, shape) + + @staticmethod + def f(x): + # See (10) and (11) in Fortunato et al. (2018). + return jnp.multiply(jnp.sign(x), jnp.power(jnp.abs(x), 0.5)) + + @nn.compact + def __call__(self, x, rng_key, bias=True, kernel_init=None, eval_mode=False): + """ + Assumes no batch dimension. 
+
+    Returns:
+      The (noisy) linear transformation of `x`.
+    """
+
+    def mu_init(key, shape):
+      # Initialization of mean noise parameters (Section 3.2)
+      low = -1 / jnp.power(x.shape[-1], 0.5)
+      high = 1 / jnp.power(x.shape[-1], 0.5)
+      return random.uniform(key, minval=low, maxval=high, shape=shape)
+
+    def sigma_init(key, shape, dtype=jnp.float32):
+      # Initialization of sigma noise parameters (Section 3.2)
+      return jnp.ones(shape, dtype) * (0.5 / onp.sqrt(x.shape[-1]))
+
+    # Factored Gaussian noise in (10) and (11) in Fortunato et al. (2018).
+    p = NoisyNetwork.sample_noise(rng_key, [x.shape[-1], 1])
+    q = NoisyNetwork.sample_noise(rng_key, [1, self.features])
+    f_p = NoisyNetwork.f(p)
+    f_q = NoisyNetwork.f(q)
+    w_epsilon = f_p * f_q
+    b_epsilon = jnp.squeeze(f_q)
+
+    # See (8) and (9) in Fortunato et al. (2018) for output computation.
+    # The 'kernell'/'biass' parameter names follow Dopamine's NoisyNetwork.
+    w_mu = self.param('kernel', mu_init, (x.shape[-1], self.features))
+    w_sigma = self.param('kernell', sigma_init, (x.shape[-1], self.features))
+    w_epsilon = jnp.where(
+        eval_mode,
+        onp.zeros(shape=(x.shape[-1], self.features), dtype=onp.float32),
+        w_epsilon)
+    w = w_mu + jnp.multiply(w_sigma, w_epsilon)
+    ret = jnp.matmul(x, w)
+
+    b_epsilon = jnp.where(eval_mode,
+                          onp.zeros(shape=(self.features,), dtype=onp.float32),
+                          b_epsilon)
+    b_mu = self.param('bias', mu_init, (self.features,))
+    b_sigma = self.param('biass', sigma_init, (self.features,))
+    b = b_mu + jnp.multiply(b_sigma, b_epsilon)
+    return jnp.where(bias, ret + b, ret)
+
+
+# --------------------------- < RainbowNetwork >---------------------------------
+
+
+class NoStatsBatchNorm(nn.Module):
+  """A BatchNorm variant that does not track running statistics.
+
+  This is useful where maintaining mutable batch statistics is impractical,
+  for example inside scanned modules such as the SPR transition model.
+
+  Attributes:
+    use_running_average: unused; kept for interface parity with nn.BatchNorm.
+    axis: the feature or non-batch axis of the input.
+    epsilon: a small float added to variance to avoid dividing by zero.
+    dtype: the dtype of the computation (default: float32).
+    use_bias: if True, bias (beta) is added.
+    use_scale: if True, multiply by scale (gamma).
+      When the next layer is linear (also e.g. nn.relu), this can be disabled
+      since the scaling will be done by the next layer.
+    bias_init: initializer for bias, by default, zero.
+    scale_init: initializer for scale, by default, one.
+    axis_name: the axis name used to combine batch statistics from multiple
+      devices. See `jax.pmap` for a description of axis names (default: None).
+    axis_index_groups: groups of axis indices within that named axis
+      representing subsets of devices to reduce over (default: None). For
+      example, `[[0, 1], [2, 3]]` would independently batch-normalize over
+      the examples on the first two and last two devices. See `jax.lax.psum`
+      for more details.
+  """
+  use_running_average: Optional[bool] = None
+  axis: int = -1
+  epsilon: float = 1e-5
+  dtype: Dtype = jnp.float32
+  use_bias: bool = True
+  use_scale: bool = True
+  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
+  scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
+  axis_name: Optional[str] = None
+  axis_index_groups: Any = None
+
+  @compact
+  def __call__(self, x, use_running_average: Optional[bool] = None):
+    """Normalizes the input using batch statistics.
+
+    NOTE:
+      During initialization (when parameters are mutable) the running average
+      of the batch statistics will not be updated. 
Therefore, the inputs + fed during initialization don't need to match that of the actual input + distribution and the reduction axis (set with `axis_name`) does not have + to exist. + + Args: + x: the input to be normalized. + use_running_average: if true, the statistics stored in batch_stats + will be used instead of computing the batch statistics on the input. + + Returns: + Normalized inputs (the same shape as inputs). + """ + x = jnp.asarray(x, jnp.float32) + axis = self.axis if isinstance(self.axis, tuple) else (self.axis,) + axis = _absolute_dims(x.ndim, axis) + feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape)) + reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis) + reduction_axis = tuple(i for i in range(x.ndim) if i not in axis) + + # see NOTE above on initialization behavior + initializing = self.is_mutable_collection('params') + + mean = jnp.mean(x, axis=reduction_axis, keepdims=False) + mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False) + if self.axis_name is not None and not initializing: + concatenated_mean = jnp.concatenate([mean, mean2]) + mean, mean2 = jnp.split( + lax.pmean( + concatenated_mean, + axis_name=self.axis_name, + axis_index_groups=self.axis_index_groups), 2) + var = mean2 - lax.square(mean) + + y = x - mean.reshape(feature_shape) + mul = lax.rsqrt(var + self.epsilon) + if self.use_scale: + scale = self.param('scale', self.scale_init, + reduced_feature_shape).reshape(feature_shape) + mul = mul * scale + y = y * mul + if self.use_bias: + bias = self.param('bias', self.bias_init, + reduced_feature_shape).reshape(feature_shape) + y = y + bias + return jnp.asarray(y, self.dtype) + + +def feature_layer(noisy, features): + """Network feature layer depending on whether noisy_nets are used on or not.""" + if noisy: + net = NoisyNetwork(features=features) + else: + net = nn.Dense(features, kernel_init=nn.initializers.xavier_uniform()) + + def apply(x, key, eval_mode): + if noisy: + return net(x, key, True, None, eval_mode) + else: + return net(x) + + return net, apply + + +def process_inputs(x, data_augmentation=False, rng=None): + """Input normalization and if specified, data augmentation.""" + out = x.astype(jnp.float32) / 255. 
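+  # The DrQ-style augmentation below requires an explicit PRNG key.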
+  if data_augmentation:
+    if rng is None:
+      raise ValueError('Pass rng when using data augmentation')
+    out = drq_image_aug(rng, out)
+  return out
+
+
+def renormalize(tensor, has_batch=False):
+  """Min-max normalizes a tensor to [0, 1] (per element of the batch axis)."""
+  shape = tensor.shape
+  if not has_batch:
+    tensor = jnp.expand_dims(tensor, 0)
+  tensor = tensor.reshape(tensor.shape[0], -1)
+  max_value = jnp.max(tensor, axis=-1, keepdims=True)
+  min_value = jnp.min(tensor, axis=-1, keepdims=True)
+  return ((tensor - min_value) /
+          (max_value - min_value + 1e-5)).reshape(*shape)
+
+
+class ConvTMCell(nn.Module):
+  """MuZero-style convolutional transition-model cell used for SPR rollouts."""
+  num_actions: int
+  latent_dim: int
+  renormalize: bool
+
+  def setup(self):
+    self.bn = NoStatsBatchNorm(axis=-1, axis_name="batch")
+
+  @nn.compact
+  def __call__(self, x, action, eval_mode=False, key=None):
+    sizes = [self.latent_dim, self.latent_dim]
+    kernel_sizes = [3, 3]
+    stride_sizes = [1, 1]
+
+    # Broadcast the one-hot action over the spatial dimensions and append it
+    # to the latent as extra channels.
+    action_onehot = jax.nn.one_hot(action, self.num_actions)
+    action_onehot = jax.lax.broadcast(action_onehot, (x.shape[-3], x.shape[-2]))
+    x = jnp.concatenate([x, action_onehot], -1)
+    for layer in range(1):
+      x = nn.Conv(
+          features=sizes[layer],
+          kernel_size=(kernel_sizes[layer], kernel_sizes[layer]),
+          strides=(stride_sizes[layer], stride_sizes[layer]),
+          kernel_init=nn.initializers.xavier_uniform())(
+              x)
+      x = nn.relu(x)
+      # x = self.bn(x, use_running_average=False)
+    x = nn.Conv(
+        features=sizes[-1],
+        kernel_size=(kernel_sizes[-1], kernel_sizes[-1]),
+        strides=(stride_sizes[-1], stride_sizes[-1]),
+        kernel_init=nn.initializers.xavier_uniform())(
+            x)
+    x = nn.relu(x)
+
+    if self.renormalize:
+      x = renormalize(x)
+
+    return x, x
+
+
+class RainbowCNN(nn.Module):
+  """The standard three-layer DQN/Rainbow convolutional encoder."""
+  padding: Any = "SAME"
+
+  @nn.compact
+  def __call__(self, x):
+    # x = x[None, Ellipsis]
+    hidden_sizes = [32, 64, 64]
+    kernel_sizes = [8, 4, 3]
+    stride_sizes = [4, 2, 1]
+    for layer in range(3):
+      x = nn.Conv(
+          features=hidden_sizes[layer],
+          kernel_size=(kernel_sizes[layer], kernel_sizes[layer]),
+          strides=(stride_sizes[layer], stride_sizes[layer]),
+          kernel_init=nn.initializers.xavier_uniform(),
+          padding=self.padding)(
+              x)
+      x = nn.relu(x)
+    return x
+
+
+class TransitionModel(nn.Module):
+  """Applies ConvTMCell over a sequence of actions via nn.scan."""
+  num_actions: int
+  latent_dim: int
+  renormalize: bool
+
+  @nn.compact
+  def __call__(self, x, action):
+    scan = nn.scan(
+        ConvTMCell,
+        in_axes=0,
+        out_axes=0,
+        variable_broadcast=['params'],
+        split_rngs={'params': False})(
+            latent_dim=self.latent_dim,
+            num_actions=self.num_actions,
+            renormalize=self.renormalize)
+    return scan(x, action)
+
+
+@gin.configurable
+class SPRNetwork(nn.Module):
+  """Jax network for SPR: a Rainbow backbone with an SPR prediction head.
+
+  Attributes:
+    num_actions: int, number of actions the agent can take at any state.
+    num_atoms: int, the number of buckets of the value function distribution.
+    noisy: bool, whether to use noisy networks.
+    dueling: bool, whether to use dueling network architecture.
+    distributional: bool, whether to use distributional RL.
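+    renormalize: bool, whether to rescale latent representations to [0, 1]
+      after the encoder and each transition-model step.
+    padding: padding mode for the convolutional layers of the encoder.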
+ """ + num_actions: int + num_atoms: int + noisy: bool + dueling: bool + distributional: bool + renormalize: bool = True + padding: Any = "SAME" + + def setup(self): + self.transition_model = TransitionModel( + num_actions=self.num_actions, + latent_dim=64, + renormalize=self.renormalize) + self.projection, self.apply_projection = feature_layer(self.noisy, 512) + self.predictor = nn.Dense(512) + self.encoder = RainbowCNN(padding=self.padding) + + def encode(self, x): + latent = self.encoder(x) + if self.renormalize: + latent = renormalize(latent) + return latent + + def project(self, x, key, eval_mode): + return self.apply_projection(x, key=key, eval_mode=eval_mode) + + @functools.partial(jax.vmap, in_axes=(None, 0, None, None)) + def spr_predict(self, x, key, eval_mode): + return self.predictor( + self.apply_projection(x, key=key, eval_mode=eval_mode)) + + def spr_rollout(self, latent, actions, key): + _, pred_latents = self.transition_model(latent, actions) + predictions = self.spr_predict( + pred_latents.reshape(pred_latents.shape[0], -1), key, True) + return predictions + + @nn.compact + def __call__(self, + x, + support, + actions=None, + do_rollout=False, + eval_mode=False, + key=None): + # Generate a random number generation key if not provided + if key is None: + key = random.PRNGKey(int(time.time() * 1e6)) + + latent = self.encode(x) + x = self.apply_projection(latent.reshape(-1), key, + eval_mode) # Single hidden layer of size 512 + x = nn.relu(x) + + if self.dueling: + _, adv_net = feature_layer(self.noisy, self.num_actions * self.num_atoms) + _, val_net = feature_layer(self.noisy, self.num_atoms) + adv = adv_net(x, key, eval_mode) + value = val_net(x, key, eval_mode) + adv = adv.reshape((self.num_actions, self.num_atoms)) + value = value.reshape((1, self.num_atoms)) + logits = value + (adv - (jnp.mean(adv, -2, keepdims=True))) + else: + _, adv_net = feature_layer(self.noisy, self.num_actions * self.num_atoms) + x = adv_net(x, key, eval_mode) + logits = x.reshape((self.num_actions, self.num_atoms)) + + if do_rollout: + latent = self.spr_rollout(latent, actions, key) + + if self.distributional: + probabilities = jnp.squeeze(nn.softmax(logits)) + q_values = jnp.squeeze(jnp.sum(support * probabilities, axis=-1)) + return SPROutputType(q_values, logits, probabilities, latent) + + q_values = jnp.squeeze(logits) + return SPROutputType(q_values, None, None, latent) diff --git a/dopamine/labs/atari_100k/train.py b/dopamine/labs/atari_100k/train.py index b7a29a47..0f4d7a22 100644 --- a/dopamine/labs/atari_100k/train.py +++ b/dopamine/labs/atari_100k/train.py @@ -25,14 +25,13 @@ from absl import logging from dopamine.discrete_domains import run_experiment from dopamine.discrete_domains import train as base_train -from dopamine.labs.atari_100k import atari_100k_rainbow_agent +from dopamine.labs.atari_100k import atari_100k_rainbow_agent, spr_agent from dopamine.labs.atari_100k import eval_run_experiment import numpy as np import tensorflow as tf - FLAGS = flags.FLAGS -AGENTS = ['DER', 'DrQ', 'OTRainbow', 'DrQ_eps'] +AGENTS = ['DER', 'DrQ', 'OTRainbow', 'DrQ_eps', 'SPR'] # flags are defined when importing run_xm_preprocessing flags.DEFINE_enum('agent', 'DER', AGENTS, 'Name of the agent.') @@ -41,11 +40,17 @@ 'Whether to use `MaxEpisodeEvalRunner` or not.') -def create_agent(sess, # pylint: disable=unused-argument - environment, - seed, - summary_writer=None): +def create_agent( + sess, # pylint: disable=unused-argument + environment, + seed, + agent, + summary_writer=None): """Helper 
function for creating an Atari 100k agent (SPR or full Rainbow)."""
+  if agent == 'SPR':
+    return spr_agent.SPRAgent(num_actions=environment.action_space.n,
+                              seed=seed,
+                              summary_writer=summary_writer)
   return atari_100k_rainbow_agent.Atari100kRainbowAgent(
       num_actions=environment.action_space.n,
       seed=seed,
@@ -72,7 +77,8 @@ def main(unused_argv):
   gin_files, gin_bindings = FLAGS.gin_files, FLAGS.gin_bindings
   run_experiment.load_gin_configs(gin_files, gin_bindings)
   # Set the Jax agent seed using the run number
-  create_agent_fn = functools.partial(create_agent, seed=FLAGS.run_number)
+  create_agent_fn = functools.partial(
+      create_agent, seed=FLAGS.run_number, agent=FLAGS.agent)
   if FLAGS.max_episode_eval:
     runner_fn = eval_run_experiment.MaxEpisodeEvalRunner
     logging.info('Using MaxEpisodeEvalRunner for evaluation.')