Skip to content

Commit

Permalink
fixing placeholders
Browse files Browse the repository at this point in the history
  • Loading branch information
Jason Peng committed Sep 17, 2024
1 parent b3d0ce9 commit c93d87c
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion a1/bc_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,5 @@ def _compute_actor_loss(self, norm_obs, norm_expert_a):
TODO 1.2: Implement code to calculate the loss for training the policy.
'''
# placeholder
loss = torch.zeros(1)
loss = torch.zeros(1, device=self._device)
return loss
4 changes: 2 additions & 2 deletions a2/cem_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def _eval_candidates(self, candidates):
n = candidates.shape[0]

# placeholder
rets = torch.zeros(n, device=self._device)
ep_lens = torch.zeros(n, device=self._device)
rets = torch.zeros(n, dtype=torch.float64, device=self._device)
ep_lens = torch.zeros(n, dtype=torch.float64, device=self._device)

return rets, ep_lens

Expand Down
4 changes: 2 additions & 2 deletions a2/pg_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _calc_critic_loss(self, norm_obs, tar_val):
'''

# placeholder
loss = torch.zeros(1)
loss = torch.zeros(1, device=self._device)
return loss

def _calc_actor_loss(self, norm_obs, norm_a, adv):
Expand All @@ -229,5 +229,5 @@ def _calc_actor_loss(self, norm_obs, norm_a, adv):
'''

# placeholder
loss = torch.zeros(1)
loss = torch.zeros(1, device=self._device)
return loss
2 changes: 1 addition & 1 deletion a3/dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _compute_q_loss(self, norm_obs, a, tar_vals):
'''

# placeholder
loss = torch.zeros(1)
loss = torch.zeros(1, device=self._device)

return loss

Expand Down

0 comments on commit c93d87c

Please sign in to comment.