Draft
Commits (29)
2567fe5
added demo notebooks
threewisemonkeys-as Aug 27, 2020
1aebe3c
Merge branch 'master' of https://github.com/SforAiDl/genrl
threewisemonkeys-as Aug 31, 2020
e3b8a8a
Merge branch 'master' of https://github.com/SforAiDl/genrl
threewisemonkeys-as Sep 3, 2020
4cba727
initial structure
threewisemonkeys-as Sep 3, 2020
3d233c4
added mp
threewisemonkeys-as Sep 6, 2020
1c504cc
add files
threewisemonkeys-as Sep 20, 2020
2c3298a
added new structure on rpc
threewisemonkeys-as Oct 1, 2020
73586d5
working structure
threewisemonkeys-as Oct 7, 2020
9ef6845
fixed integration bugs
threewisemonkeys-as Oct 7, 2020
072d545
removed unneccary files
threewisemonkeys-as Oct 7, 2020
64db1c1
added support for running from multiple scripts
threewisemonkeys-as Oct 8, 2020
4d57a06
added evaluate to trainer
threewisemonkeys-as Oct 8, 2020
f325429
added proxy getter
threewisemonkeys-as Oct 23, 2020
7ce19ec
added rpc backend option
threewisemonkeys-as Oct 23, 2020
cfba909
added logging to trainer
threewisemonkeys-as Oct 23, 2020
992a3a9
Added more options to trainer
threewisemonkeys-as Oct 23, 2020
bf1a50a
moved load weights to user
threewisemonkeys-as Oct 23, 2020
e2eef66
decreased number of eval its
threewisemonkeys-as Oct 23, 2020
837eb18
removed train wrapper
threewisemonkeys-as Oct 23, 2020
7fcbb23
removed loop to user fn
threewisemonkeys-as Oct 26, 2020
0002fa4
added example for secondary node
threewisemonkeys-as Oct 26, 2020
bebf50f
removed original exmpale
threewisemonkeys-as Oct 26, 2020
29bd1d6
removed fn
threewisemonkeys-as Oct 26, 2020
18536a2
shifted examples
threewisemonkeys-as Oct 29, 2020
8f859d6
shifted logger to base class
threewisemonkeys-as Oct 29, 2020
555e290
added on policy example
threewisemonkeys-as Oct 29, 2020
59e960c
removed temp example
threewisemonkeys-as Oct 29, 2020
8d5a8b6
got on policy distributed example to work
threewisemonkeys-as Oct 29, 2020
8030b2a
formatting
threewisemonkeys-as Oct 29, 2020
205 changes: 205 additions & 0 deletions genrl/trainers/distributed.py
@@ -0,0 +1,205 @@
import copy
import multiprocessing as mp

import numpy as np
import reverb
import tensorflow as tf
import torch

from genrl.trainers import Trainer


class ReverbReplayBuffer:
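    """Replay buffer backed by an in-process Reverb server.

    Transitions are stored as (obs, action, reward, next_obs, done) tuples in a
    single uniform-sampling table and read back through a batched ReplayDataset
    numpy iterator.
    """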
def __init__(
self,
size,
batch_size,
obs_shape,
action_shape,
discrete=True,
reward_shape=(1,),
done_shape=(1,),
n_envs=1,
):
self.size = size
self.obs_shape = (n_envs, *obs_shape)
self.action_shape = (n_envs, *action_shape)
self.reward_shape = (n_envs, *reward_shape)
self.done_shape = (n_envs, *done_shape)
self.n_envs = n_envs
self.action_dtype = np.int64 if discrete else np.float32

self._pos = 0
self._table = reverb.Table(
name="replay_buffer",
sampler=reverb.selectors.Uniform(),
remover=reverb.selectors.Fifo(),
max_size=self.size,
rate_limiter=reverb.rate_limiters.MinSize(2),
)
self._server = reverb.Server(tables=[self._table], port=None)
self._server_address = f"localhost:{self._server.port}"
self._client = reverb.Client(self._server_address)
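        # TF dataset that streams transitions from the Reverb table; batched and
        # exposed as a numpy iterator for sample() below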
self._dataset = reverb.ReplayDataset(
server_address=self._server_address,
table="replay_buffer",
max_in_flight_samples_per_worker=2 * batch_size,
            dtypes=(np.float32, self.action_dtype, np.float32, np.float32, np.bool_),
shapes=(
tf.TensorShape([n_envs, *obs_shape]),
tf.TensorShape([n_envs, *action_shape]),
tf.TensorShape([n_envs, *reward_shape]),
tf.TensorShape([n_envs, *obs_shape]),
tf.TensorShape([n_envs, *done_shape]),
),
)
self._iterator = self._dataset.batch(batch_size).as_numpy_iterator()

def push(self, inp):
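        # inp is expected as (obs, action, reward, next_obs, done); each element
        # is cast and reshaped to the per-env shapes declared in __init__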
i = []
i.append(np.array(inp[0], dtype=np.float32).reshape(self.obs_shape))
i.append(np.array(inp[1], dtype=self.action_dtype).reshape(self.action_shape))
i.append(np.array(inp[2], dtype=np.float32).reshape(self.reward_shape))
i.append(np.array(inp[3], dtype=np.float32).reshape(self.obs_shape))
        i.append(np.array(inp[4], dtype=np.bool_).reshape(self.done_shape))

self._client.insert(i, priorities={"replay_buffer": 1.0})
if self._pos < self.size:
self._pos += 1

def extend(self, inp):
for sample in inp:
self.push(sample)

def sample(self, *args, **kwargs):
sample = next(self._iterator)
obs, a, r, next_obs, d = [torch.from_numpy(t).float() for t in sample.data]
return obs, a, r.reshape(-1, self.n_envs), next_obs, d.reshape(-1, self.n_envs)

def __len__(self):
return self._pos

def __del__(self):
self._server.stop()


class DistributedOffPolicyTrainer(Trainer):
"""Distributed Off Policy Trainer Class

Trainer class for Distributed Off Policy Agents

"""

def __init__(
self,
*args,
env,
agent,
max_ep_len: int = 500,
max_timesteps: int = 5000,
update_interval: int = 50,
buffer_server_port=None,
param_server_port=None,
**kwargs,
):
super(DistributedOffPolicyTrainer, self).__init__(
*args, off_policy=True, max_timesteps=max_timesteps, **kwargs
)
self.env = env
self.agent = agent
self.max_ep_len = max_ep_len
self.update_interval = update_interval
self.buffer_server_port = buffer_server_port
self.param_server_port = param_server_port

def train(self, n_actors, max_buffer_size, batch_size, max_updates):
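        # Replay table that the actor processes write transitions into and the
        # learner samples from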
buffer_server = reverb.Server(
tables=[
reverb.Table(
name="replay_buffer",
sampler=reverb.selectors.Uniform(),
remover=reverb.selectors.Fifo(),
max_size=max_buffer_size,
rate_limiter=reverb.rate_limiters.MinSize(2),
)
],
port=self.buffer_server_port,
)
        buffer_server_address = f"localhost:{buffer_server.port}"

        # Single-slot table used to broadcast the latest agent weights to the actors
        param_server = reverb.Server(
            tables=[
                reverb.Table(
                    name="parameters",
                    sampler=reverb.selectors.Uniform(),
                    remover=reverb.selectors.Fifo(),
                    max_size=1,
                    rate_limiter=reverb.rate_limiters.MinSize(1),
                )
            ],
            port=self.param_server_port,
        )
        param_server_address = f"localhost:{param_server.port}"

actor_procs = []
for _ in range(n_actors):
p = mp.Process(
target=self._run_actor,
args=(
copy.deepcopy(self.agent),
copy.deepcopy(self.env),
buffer_server_address,
param_server_address,
),
)
p.daemon = True
            p.start()
            actor_procs.append(p)

learner_proc = mp.Process(
target=self._run_learner,
args=(
self.agent,
max_updates,
buffer_server_address,
param_server_address,
batch_size,
),
)
        learner_proc.daemon = True
        learner_proc.start()
        learner_proc.join()

        # Actors collect experience in an infinite loop, so stop them once the
        # learner has finished its updates
        for p in actor_procs:
            p.terminate()

    def _run_actor(self, agent, env, buffer_server_address, param_server_address):
        buffer_client = reverb.Client(buffer_server_address)
        param_client = reverb.Client(param_server_address)

        state = env.reset()

        while True:
            # Pull the latest weights published by the learner (blocks until the
            # parameter table has at least one entry)
            params = next(param_client.sample(table="parameters", num_samples=1))
            agent.load_weights(params)

            # Assumes the agent exposes genrl's standard select_action API
            action = agent.select_action(state)
            next_state, reward, done, info = env.step(action)

            # Store the transition before advancing the state
            buffer_client.insert(
                [state, action, reward, done, next_state],
                priorities={"replay_buffer": 1.0},
            )

            # Reset at episode boundaries (assumes a single, non-vectorized env)
            if done:
                state = env.reset()
            else:
                state = next_state

    def _run_learner(
        self,
        agent,
        max_updates,
        buffer_server_address,
        param_server_address,
        batch_size,
    ):
        buffer_client = reverb.Client(buffer_server_address)
        param_client = reverb.Client(param_server_address)

        for _ in range(max_updates):
            # Pull a batch of transitions through the client; a ReplayDataset
            # would also require dtypes and shapes, which are not known here
            samples = list(
                buffer_client.sample(table="replay_buffer", num_samples=batch_size)
            )
            agent.update_params(samples)
            # Publish the updated weights so the actors can load them
            param_client.insert(agent.get_weights(), priorities={"parameters": 1.0})
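
For reference, a minimal usage sketch of the trainer above. The agent and environment construction is assumed (any genrl off-policy agent exposing select_action, update_params, get_weights and a user-defined load_weights should fit), and depending on what the base genrl Trainer expects through *args, agent and env may also need to be passed positionally.

# Illustrative only -- `agent` and `env` are assumed to be constructed elsewhere
from genrl.trainers.distributed import DistributedOffPolicyTrainer

trainer = DistributedOffPolicyTrainer(
    agent=agent,
    env=env,
    buffer_server_port=None,  # None lets Reverb pick a free port
    param_server_port=None,
)
trainer.train(n_actors=2, max_buffer_size=100000, batch_size=64, max_updates=1000)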