
Commit 01836e0

Merge pull request #248 from cpnota/release/0.7.1
Release/0.7.1
2 parents 67b27aa + 074d0ca commit 01836e0


46 files changed: +707 -159 lines changed

.github/workflows/python-package.yml

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ jobs:
         pip install torch==1.8.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
         make install
         AutoROM -v
+        python -m atari_py.import_roms $(python -c 'import site; print(site.getsitepackages()[0])')/multi_agent_ale_py/ROM
     - name: Lint code
       run: |
         make lint
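
The added CI step imports the ROMs shipped with multi_agent_ale_py into atari_py. As a rough illustration of what the shell substitution resolves to, the sketch below (a standalone snippet, not taken from the repository) builds the same ROM path in Python:

# Sketch: build the ROM directory path used by the CI step above.
# Assumes multi_agent_ale_py is installed in the first site-packages entry,
# exactly as the workflow command does.
import os
import site

rom_dir = os.path.join(site.getsitepackages()[0], "multi_agent_ale_py", "ROM")
print(rom_dir)  # this path is then passed to `python -m atari_py.import_roms`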

all/agents/a2c.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ def _make_buffer(self):
         )
 
 
-class A2CTestAgent(Agent):
+class A2CTestAgent(Agent, ParallelAgent):
     def __init__(self, features, policy):
         self.features = features
         self.policy = policy

all/agents/dqn.py

Lines changed: 3 additions & 7 deletions
@@ -81,12 +81,8 @@ def _should_train(self):
 
 
 class DQNTestAgent(Agent):
-    def __init__(self, q, n_actions, exploration=0.):
-        self.q = q
-        self.n_actions = n_actions
-        self.exploration = 0.001
+    def __init__(self, policy):
+        self.policy = policy
 
     def act(self, state):
-        if np.random.rand() < self.exploration:
-            return np.random.randint(0, self.n_actions)
-        return torch.argmax(self.q.eval(state)).item()
+        return self.policy.eval(state)
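
The test agent no longer implements epsilon-greedy exploration itself; it simply delegates action selection to whatever policy object it is given. A hedged usage sketch (the GreedyPolicy name and signature are assumptions for illustration, not confirmed by this diff):

# Usage sketch: exploration now lives in the policy, not the test agent.
# GreedyPolicy and its arguments are assumed here for illustration.
from all.policies import GreedyPolicy

policy = GreedyPolicy(q, n_actions, epsilon=0.001)  # q and n_actions defined elsewhere
test_agent = DQNTestAgent(policy)
action = test_agent.act(state)  # equivalent to policy.eval(state)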

all/agents/sac.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def _train(self):
 
             # adjust temperature
             temperature_grad = (_log_probs + self.entropy_target).mean()
-            self.temperature += self.lr_temperature * temperature_grad.detach()
+            self.temperature = max(0, self.temperature + self.lr_temperature * temperature_grad.detach())
 
             # additional debugging info
             self.writer.add_loss('entropy', -_log_probs.mean())
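
The temperature update is now clamped at zero, so a large gradient step can no longer drive the entropy temperature negative. A minimal standalone sketch of the arithmetic (plain floats, illustrative values only):

# Minimal sketch of the clamped temperature update with made-up numbers.
temperature = 0.01
lr_temperature = 0.1
temperature_grad = -0.5  # e.g. the value of (log_probs + entropy_target).mean()

temperature = max(0, temperature + lr_temperature * temperature_grad)
print(temperature)  # 0, instead of -0.04 with the old unclamped update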

all/agents/vqn.py

Lines changed: 6 additions & 1 deletion
@@ -50,4 +50,9 @@ def _train(self, reward, next_state):
         self.q.reinforce(loss)
 
 
-VQNTestAgent = DQNTestAgent
+class VQNTestAgent(Agent, ParallelAgent):
+    def __init__(self, policy):
+        self.policy = policy
+
+    def act(self, state):
+        return self.policy.eval(state)

all/agents/vsarsa.py

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
 from torch.nn.functional import mse_loss
-from ._agent import Agent
 from ._parallel_agent import ParallelAgent
-from .dqn import DQNTestAgent
+from .vqn import VQNTestAgent
 
 
 class VSarsa(ParallelAgent):
@@ -47,4 +46,4 @@ def _train(self, reward, next_state, next_action):
         self.q.reinforce(loss)
 
 
-VSarsaTestAgent = DQNTestAgent
+VSarsaTestAgent = VQNTestAgent

all/bodies/atari.py

Lines changed: 3 additions & 2 deletions
@@ -6,7 +6,8 @@
 
 class DeepmindAtariBody(Body):
     def __init__(self, agent, lazy_frames=False, episodic_lives=True, frame_stack=4, clip_rewards=True):
-        agent = FrameStack(agent, lazy=lazy_frames, size=frame_stack)
+        if frame_stack > 1:
+            agent = FrameStack(agent, lazy=lazy_frames, size=frame_stack)
         if clip_rewards:
             agent = ClipRewards(agent)
         if episodic_lives:
@@ -19,7 +20,7 @@ def process_state(self, state):
         if 'life_lost' not in state:
             return state
 
-        if len(state) == 1:
+        if len(state.shape) == 0:
             if state['life_lost']:
                 return state.update('mask', 0.)
             return state
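
The life-lost handling now checks the state's shape rather than its length, distinguishing a single State (empty shape) from a batched StateArray. A toy illustration of the check (shapes are assumed values, not taken from the library):

# Toy illustration of the shape-based check; not the library's State classes.
single_shape = ()      # a single State has an empty shape
batched_shape = (16,)  # a StateArray over 16 sub-environments

print(len(single_shape) == 0)   # True: handle the single-environment case
print(len(batched_shape) == 0)  # False: fall through to the batched case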

all/bodies/vision.py

Lines changed: 2 additions & 3 deletions
@@ -69,10 +69,9 @@ def update(self, key, value):
         x = {}
         for k in self.keys():
             if not k == key:
-                x[k] = super().__getitem__(k)
+                x[k] = dict.__getitem__(self, k)
         x[key] = value
-        state = LazyState(x, device=self.device)
-        state.to_cache = self.to_cache
+        state = LazyState.from_state(x, x['observation'], self.to_cache)
         return state
 
     def to(self, device):
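
Calling dict.__getitem__ directly copies the stored values without going through LazyState's own __getitem__, which presumably evaluates lazy entries on access. A toy example (not the library's LazyState) of the difference:

# Toy dict subclass showing why bypassing the overridden __getitem__ matters.
class LazyDict(dict):
    def __getitem__(self, key):
        value = dict.__getitem__(self, key)
        # eagerly evaluate lazy (callable) entries on normal access
        return value() if callable(value) else value

d = LazyDict(observation=lambda: "decoded frame")
print(d["observation"])                    # "decoded frame" (evaluated)
print(dict.__getitem__(d, "observation"))  # the lambda itself (still lazy)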

all/environments/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -1,9 +1,12 @@
 from ._environment import Environment
-from._multiagent_environment import MultiagentEnvironment
+from ._multiagent_environment import MultiagentEnvironment
+from ._vector_environment import VectorEnvironment
 from .gym import GymEnvironment
 from .atari import AtariEnvironment
 from .multiagent_atari import MultiagentAtariEnv
 from .multiagent_pettingzoo import MultiagentPettingZooEnv
+from .duplicate_env import DuplicateEnvironment
+from .vector_env import GymVectorEnvironment
 from .pybullet import PybulletEnvironment
 
 __all__ = [
@@ -13,5 +16,7 @@
     "AtariEnvironment",
     "MultiagentAtariEnv",
     "MultiagentPettingZooEnv",
+    "GymVectorEnvironment",
+    "DuplicateEnvironment",
     "PybulletEnvironment",
 ]
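
With these exports, the new vector-environment classes are importable from the package root. A minimal usage sketch (constructor arguments omitted, since they are not part of this diff):

# The new classes can now be imported directly from all.environments.
from all.environments import DuplicateEnvironment, GymVectorEnvironment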
all/environments/_vector_environment.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+from abc import ABC, abstractmethod
+
+
+class VectorEnvironment(ABC):
+    """
+    A reinforcement learning vector Environment.
+
+    Similar to a regular RL environment except many environments are stacked together
+    in the observations, rewards, and dones, and the vector environment expects
+    an action to be given for each environment in step.
+
+    Also, since sub-environments are done at different times, you do not need to
+    manually reset the environments when they are done, rather the vector environment
+    automatically resets environments when they are complete.
+    """
+
+    @property
+    @abstractmethod
+    def name(self):
+        """
+        The name of the environment.
+        """
+
+    @abstractmethod
+    def reset(self):
+        """
+        Reset the environment and return a new initial state.
+
+        Returns
+        -------
+        State
+            The initial state for the next episode.
+        """
+
+    @abstractmethod
+    def step(self, action):
+        """
+        Apply an action and get the next state.
+
+        Parameters
+        ----------
+        action : Action
+            The action to apply at the current time step.
+
+        Returns
+        -------
+        all.environments.State
+            The State of the environment after the action is applied.
+            This State object includes both the done flag and any additional "info"
+        float
+            The reward achieved by the previous action
+        """
+
+    @abstractmethod
+    def close(self):
+        """
+        Clean up any extraneous environment objects.
+        """
+
+    @property
+    @abstractmethod
+    def state_array(self):
+        """
+        A StateArray of the Environments at the current timestep.
+        """
+
+    @property
+    @abstractmethod
+    def state_space(self):
+        """
+        The Space representing the range of observable states for each environment.
+
+        Returns
+        -------
+        Space
+            An object of type Space that represents possible states the agent may observe
+        """
+
+    @property
+    def observation_space(self):
+        """
+        Alias for Environment.state_space.
+
+        Returns
+        -------
+        Space
+            An object of type Space that represents possible states the agent may observe
+        """
+        return self.state_space
+
+    @property
+    @abstractmethod
+    def action_space(self):
+        """
+        The Space representing the range of possible actions for each environment.
+
+        Returns
+        -------
+        Space
+            An object of type Space that represents possible actions the agent may take
+        """
+
+    @property
+    @abstractmethod
+    def device(self):
+        """
+        The torch device the environment lives on.
+        """
+
+    @property
+    @abstractmethod
+    def num_envs(self):
+        """
+        Number of environments in vector. This is the number of actions step() expects as input
+        and the number of observations, dones, etc returned by the environment.
+        """
