
Commit 812fd36
Workable version for aamas
1 parent 08a0fb1 · commit 812fd36

20 files changed: +1626 -0 lines

algos/ddpg.py

+112
@@ -0,0 +1,112 @@
from copy import deepcopy

import numpy as np
import torch
from torch import nn
from torch.optim import Adam
import torch.nn.functional as F


from .nn import Critic, MLP
from .utils import initialize_weight


class DeterministicPolicy(nn.Module):

    def __init__(self, state_shape, action_shape, hidden_units=(256, 256),
                 hidden_activation=nn.ReLU(inplace=True)):
        super().__init__()

        self.mlp = MLP(
            input_dim=state_shape[0],
            output_dim=action_shape[0],
            hidden_units=hidden_units,
            hidden_activation=hidden_activation,
        ).apply(initialize_weight)

    def forward(self, states):
        return torch.tanh(self.mlp(states))


class DDPG:
    def __init__(self, state_shape, action_shape, max_action=1, discount=0.99, tau=5e-3,
                 batch_size=256, device="cuda:0", seed=0, logger=None):
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.actor = DeterministicPolicy(
            state_shape=state_shape,
            action_shape=action_shape,
            hidden_units=[256, 256],
            hidden_activation=nn.ReLU(inplace=True)
        ).to(device)
        self.actor_target = deepcopy(self.actor)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic = Critic(state_shape, action_shape).to(device)
        self.critic_target = deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), weight_decay=1e-2, lr=3e-4)

        self.logger = logger

        self.expl_noise = 0.1
        self.action_shape = action_shape
        self.dtype = torch.float
        self.discount = discount
        self.tau = tau
        self.batch_size = batch_size
        self.max_action = max_action
        self.device = device

    def exploit(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
        return self.actor(state).cpu().data.numpy().flatten()

    def explore(self, state):
        state = torch.tensor(
            state, dtype=self.dtype, device=self.device).unsqueeze_(0)

        with torch.no_grad():
            noise = (torch.randn(self.action_shape) * self.max_action * self.expl_noise).to(self.device)
            action = self.actor(state) + noise

        a = action.cpu().numpy()[0]
        return np.clip(a, -self.max_action, self.max_action)

    def update(self, batch):
        # Sample replay buffer
        # state, action, next_state, reward, done = batch
        state, action, reward, done, next_state = batch

        # Compute the target Q value
        target_Q = self.critic_target(next_state, self.actor_target(next_state))
        target_Q = reward + (1.0 - done) * self.discount * target_Q.detach()

        # Get current Q estimate
        current_Q = self.critic(state, action)

        # Compute critic loss
        critic_loss = F.mse_loss(current_Q, target_Q)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Compute actor loss
        actor_loss = -self.critic(state, self.actor(state)).mean()

        # Optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Update the frozen target models
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def accumulate_action_gradient(self, *args):
        pass
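For reference, here is a minimal driver loop showing how the DDPG class above is typically used. The ReplayBuffer, the HalfCheetah environment, and the pre-0.26 Gym reset/step API are illustrative assumptions and not part of this commit; the only constraint taken from the code itself is that sample() must return tensors in the order (state, action, reward, done, next_state) that update() unpacks.

import random
from collections import deque

import gym
import numpy as np
import torch

from algos.ddpg import DDPG


class ReplayBuffer:
    """Minimal FIFO buffer (assumption); sample() matches the tuple order DDPG.update() expects."""

    def __init__(self, capacity, device):
        self.buf = deque(maxlen=capacity)
        self.device = device

    def add(self, state, action, reward, done, next_state):
        self.buf.append((state, action, reward, done, next_state))

    def sample(self, batch_size):
        batch = random.sample(self.buf, batch_size)
        s, a, r, d, ns = map(np.array, zip(*batch))
        to = lambda x: torch.as_tensor(x, dtype=torch.float32, device=self.device)
        return to(s), to(a), to(r).unsqueeze(1), to(d).unsqueeze(1), to(ns)


device = "cuda:0" if torch.cuda.is_available() else "cpu"
env = gym.make("HalfCheetah-v3")  # example environment, not from this commit
agent = DDPG(env.observation_space.shape, env.action_space.shape,
             max_action=float(env.action_space.high[0]), device=device)
buffer = ReplayBuffer(int(1e6), device)

state = env.reset()
for step in range(1_000_000):
    action = agent.explore(state)                     # noisy action for data collection
    next_state, reward, done, _ = env.step(action)
    buffer.add(state, action, reward, float(done), next_state)
    state = env.reset() if done else next_state

    if step >= 10_000:                                # warm-up period before learning (assumption)
        agent.update(buffer.sample(agent.batch_size))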

algos/mocco.py

+219
@@ -0,0 +1,219 @@
from copy import deepcopy
import math

import numpy as np
import torch
from torch import nn
from torch.optim import Adam
import torch.nn.functional as F


from .nn import DoubleCritic, MLP, MCCritic
from .utils import Clamp, initialize_weight, soft_update, disable_gradient


class DeterministicPolicy(nn.Module):

    def __init__(self, state_shape, action_shape, hidden_units=(256, 256),
                 hidden_activation=nn.ReLU(inplace=True)):
        super().__init__()

        self.mlp = MLP(
            input_dim=state_shape[0],
            output_dim=action_shape[0],
            hidden_units=hidden_units,
            hidden_activation=hidden_activation,
        ).apply(initialize_weight)

    def forward(self, states):
        return torch.tanh(self.mlp(states))


class MOCCO:

    def __init__(self, state_shape, action_shape, device, seed, batch_size=256, policy_noise=0.2,
                 expl_noise=0.1, noise_clip=0.5, beta=0.1, policy_freq=2, gamma=0.99, lr_actor=3e-4, lr_critic=3e-4,
                 max_action=1.0, target_update_coef=5e-3, log_every=5000, logger=None):
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.update_step = 0
        self.state_shape = state_shape
        self.action_shape = action_shape
        self.dtype = torch.uint8 if len(state_shape) == 3 else torch.float
        self.device = device
        self.batch_size = batch_size
        self.gamma = gamma
        self.policy_noise = policy_noise
        self.expl_noise = expl_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        self.beta = beta
        self.max_action = max_action
        self.discount = gamma
        self.log_every = log_every
        self.logger = logger

        self.actor = DeterministicPolicy(
            state_shape=self.state_shape,
            action_shape=self.action_shape,
            hidden_units=[256, 256],
            hidden_activation=nn.ReLU(inplace=True)
        ).to(self.device)

        self.actor_target = deepcopy(self.actor).to(self.device).eval()

        self.critic = DoubleCritic(
            state_shape=self.state_shape,
            action_shape=self.action_shape,
            hidden_units=[256, 256],
            hidden_activation=nn.ReLU(inplace=True)
        ).to(self.device)

        self.critic_target = deepcopy(self.critic).to(self.device).eval()
        disable_gradient(self.critic_target)

        self.optim_actor = Adam(self.actor.parameters(), lr=lr_actor)
        self.optim_critic = Adam(self.critic.parameters(), lr=lr_critic)

        self.target_update_coef = target_update_coef

        # Monte-Carlo critic: regressed onto Monte-Carlo return targets and used both as a
        # regularizer for the TD critic and as the source of guided exploration noise.
        self.critic_mc = MCCritic(
            state_shape=self.state_shape,
            action_shape=self.action_shape,
            hidden_units=[256, 256],
            hidden_activation=nn.ReLU(inplace=True)
        ).to(self.device)
        self.optim_critic_mc = Adam(self.critic_mc.parameters(), lr=lr_critic)

        # Running statistics of the MC critic's action-gradient over the last 10 states,
        # used to scale the guided noise per action dimension.
        noise_std = 0.58
        self.da_std_buf = np.zeros((10, *action_shape))
        self.norm_noise = np.sqrt(action_shape[0]) * noise_std
        self.da_std_cnt = 0
        self.da_std_max = np.zeros(*action_shape)

    def get_guided_noise(self, state, a_pi=None, with_info=False):
        if a_pi is None:
            a_pi = self.actor(state)  # [1, ACTION_DIM]

        # Gradient of the MC critic's value w.r.t. the policy action.
        d_a = self.critic_mc.get_action_grad(self.optim_critic_mc, state, a_pi)  # [1, ACTION_DIM]
        da_std = self.da_std_buf.std(axis=0)
        # Per-dimension scale: running std of the gradient relative to its historical maximum.
        scale = torch.tensor(da_std / self.da_std_max).float().to(self.device)

        # Normalize the gradient to a fixed norm, then modulate each dimension by its scale.
        d_a_norm = torch.linalg.norm(d_a, dim=1, keepdim=True)
        d_a_normalized = d_a / d_a_norm * self.norm_noise
        noise = d_a_normalized * scale  # [1, ACTION_DIM]

        if with_info:
            return noise, da_std, scale
        return noise

    def explore(self, state):
        self.accumulate_action_gradient(state)

        state = torch.tensor(
            state, dtype=self.dtype, device=self.device).unsqueeze_(0)

        a_pi = self.actor(state)
        noise, da_std, scale = self.get_guided_noise(state, a_pi=a_pi, with_info=True)
        noise = noise.cpu()

        # Logging
        if self.update_step % self.log_every == 0:
            for i_a in range(noise.shape[1]):
                self.logger.log_scalar(f"guided_noise/noise_a{i_a}", noise[0, i_a].item(), self.update_step)
                self.logger.log_scalar(f"guided_noise_da/da_run_std_{i_a}", da_std[i_a].item(), self.update_step)
                self.logger.log_scalar(f"guided_noise_scale/scale_a{i_a}", scale[i_a].item(), self.update_step)

        a_noised = (a_pi.detach().cpu() + noise).numpy()[0]
        return np.clip(a_noised, -self.max_action, self.max_action)

    def update(self, batch, batch_mc):
        self.update_step += 1
        self.update_critic_mc(*batch_mc)
        self.update_critic(*batch)

        if self.update_step % self.policy_freq == 0:
            self.update_actor(batch[0])
            soft_update(self.critic_target, self.critic, self.target_update_coef)
            soft_update(self.actor_target, self.actor, self.target_update_coef)

    def update_critic_mc(self, states, actions, qs_mc):
        # Regress all three MC critic heads onto the Monte-Carlo return targets.
        q1, q2, q3 = self.critic_mc(states, actions)
        loss_mc = (q1 - qs_mc).pow(2).mean() + (q2 - qs_mc).pow(2).mean() + (q3 - qs_mc).pow(2).mean()
        self.optim_critic_mc.zero_grad()
        loss_mc.backward()
        self.optim_critic_mc.step()

        if self.update_step % self.log_every == 0:
            self.logger.log_scalar("algo/mc_loss", loss_mc.item(), self.update_step)

    def update_critic(self, states, actions, rewards, dones, next_states):
        q1, _ = self.critic(states, actions)

        with torch.no_grad():
            # Select action according to policy and add clipped noise
            noise = (
                torch.randn_like(actions) * self.policy_noise
            ).clamp(-self.noise_clip, self.noise_clip)

            next_actions = self.actor_target(next_states) + noise
            next_actions = next_actions.clamp(-self.max_action, self.max_action)

            q_next, _ = self.critic_target(next_states, next_actions)

        q_target = rewards + (1.0 - dones) * self.discount * q_next

        # Penalize disagreement between the TD estimate and the averaged MC critic estimate.
        q_mc_1, q_mc_2, q_mc_3 = self.critic_mc(states, actions)
        q_mc_cat = torch.cat((q_mc_1, q_mc_2, q_mc_3), dim=1)
        q_mc = torch.mean(q_mc_cat, dim=1, keepdim=True).detach()
        mc_error = self.beta * (q1 - q_mc).pow(2).mean()

        td_error1 = (q1 - q_target).pow(2).mean()
        loss_critic = td_error1 + mc_error

        self.optim_critic.zero_grad()
        loss_critic.backward()
        self.optim_critic.step()

        if self.update_step % self.log_every == 0:
            self.logger.log_scalar("algo/q1", q1.detach().mean().cpu(), self.update_step)
            self.logger.log_scalar("algo/q_target", q_target.mean().cpu(), self.update_step)
            self.logger.log_scalar("algo/q_mc", q_mc.mean().cpu(), self.update_step)
            self.logger.log_scalar("algo/abs_q_err", (q1 - q_target).detach().mean().cpu(), self.update_step)
            self.logger.log_scalar("algo/critic_loss", loss_critic.item(), self.update_step)
            self.logger.log_scalar("algo/td_error", td_error1.detach().item(), self.update_step)
            self.logger.log_scalar("algo/mc_error", mc_error.detach().item(), self.update_step)
            self.logger.log_scalar("algo/q1_grad_norm", self.critic.q1.get_layer_norm(), self.update_step)
            self.logger.log_scalar("algo/actor_grad_norm", self.actor.mlp.get_layer_norm(), self.update_step)
            # self.logger.log_scalar("algo/bn_mccritic", self.critic_mc.bn.weight.mean().cpu().item(), self.update_step)

    def update_actor(self, states):
        actions = self.actor(states)
        qs1 = self.critic.Q1(states, actions)
        loss_actor = -qs1.mean()

        self.optim_actor.zero_grad()
        loss_actor.backward()
        self.optim_actor.step()

        if self.update_step % self.log_every == 0:
            self.logger.log_scalar("algo/loss_actor", loss_actor.item(), self.update_step)

    def exploit(self, state):
        state = torch.tensor(
            state, dtype=self.dtype, device=self.device).unsqueeze_(0)
        with torch.no_grad():
            action = self.actor(state)
        return action.cpu().numpy()[0]

    def accumulate_action_gradient(self, state):
        # Track the MC critic's action-gradient at recent states to maintain the running
        # std / max statistics used to scale the guided noise.
        state = torch.tensor(
            state, dtype=self.dtype, device=self.device).unsqueeze_(0)
        a_pi = self.actor(state)
        d_a = self.critic_mc.get_action_grad(self.optim_critic_mc, state, a_pi).detach()
        self.da_std_buf[self.da_std_cnt, :] = d_a.cpu().numpy().flatten()
        self.da_std_cnt = (self.da_std_cnt + 1) % self.da_std_buf.shape[0]

        da_std = self.da_std_buf.std(axis=0)
        self.da_std_max = np.maximum(self.da_std_max, da_std)
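mocco.py relies on helpers from .utils (initialize_weight, soft_update, disable_gradient) and network classes from .nn that are part of this commit but not shown above. Below is a minimal sketch of what the .utils helpers are assumed to do, consistent with how they are called here: soft_update(target, source, coef) Polyak-averages the target parameters toward the online network, mirroring the inline loop in ddpg.py, and the weight initializer shown is just one plausible choice.

import numpy as np
import torch
from torch import nn


def soft_update(target, source, tau):
    # Polyak averaging: target <- tau * source + (1 - tau) * target
    with torch.no_grad():
        for t, s in zip(target.parameters(), source.parameters()):
            t.data.mul_(1.0 - tau)
            t.data.add_(tau * s.data)


def disable_gradient(network):
    # Freeze a network so the optimizer never updates it (used for the target critic).
    for p in network.parameters():
        p.requires_grad = False


def initialize_weight(m):
    # Assumed scheme: orthogonal weights and zero biases for linear layers.
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight, gain=np.sqrt(2.0))
        nn.init.constant_(m.bias, 0.0)

Usage mirrors the DDPG loop sketched earlier, except that update() consumes two batches per step: a transition batch (states, actions, rewards, dones, next_states) for the TD critic and an MC batch (states, actions, qs_mc) of Monte-Carlo return targets for critic_mc. During exploration, instead of isotropic Gaussian noise the agent perturbs the policy action with guided noise: the MC critic's action-gradient, rescaled to the fixed norm sqrt(action_dim) * 0.58 and modulated per dimension by the ratio of the gradient's recent standard deviation to its historical maximum.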
