diff --git a/a1/bc_agent.py b/a1/bc_agent.py
index e75ba2a..b22f9b9 100644
--- a/a1/bc_agent.py
+++ b/a1/bc_agent.py
@@ -140,5 +140,5 @@ def _compute_actor_loss(self, norm_obs, norm_expert_a):
         TODO 1.2: Implement code to calculate the loss for training the policy.
         '''
         # placeholder
-        loss = torch.zeros(1)
+        loss = torch.zeros(1, device=self._device)
         return loss
diff --git a/a2/cem_agent.py b/a2/cem_agent.py
index cb1ba00..8655a2d 100644
--- a/a2/cem_agent.py
+++ b/a2/cem_agent.py
@@ -150,8 +150,8 @@ def _eval_candidates(self, candidates):
         n = candidates.shape[0]
 
         # placeholder
-        rets = torch.zeros(n, device=self._device)
-        ep_lens = torch.zeros(n, device=self._device)
+        rets = torch.zeros(n, dtype=torch.float64, device=self._device)
+        ep_lens = torch.zeros(n, dtype=torch.float64, device=self._device)
 
         return rets, ep_lens
 
diff --git a/a2/pg_agent.py b/a2/pg_agent.py
index 4058eb6..1f52f93 100644
--- a/a2/pg_agent.py
+++ b/a2/pg_agent.py
@@ -218,7 +218,7 @@ def _calc_critic_loss(self, norm_obs, tar_val):
         '''
 
         # placeholder
-        loss = torch.zeros(1)
+        loss = torch.zeros(1, device=self._device)
         return loss
 
     def _calc_actor_loss(self, norm_obs, norm_a, adv):
@@ -229,5 +229,5 @@ def _calc_actor_loss(self, norm_obs, norm_a, adv):
         '''
 
         # placeholder
-        loss = torch.zeros(1)
+        loss = torch.zeros(1, device=self._device)
         return loss
\ No newline at end of file
diff --git a/a3/dqn_agent.py b/a3/dqn_agent.py
index aa1c6c0..a1b6d60 100644
--- a/a3/dqn_agent.py
+++ b/a3/dqn_agent.py
@@ -172,7 +172,7 @@ def _compute_q_loss(self, norm_obs, a, tar_vals):
         '''
 
         # placeholder
-        loss = torch.zeros(1)
+        loss = torch.zeros(1, device=self._device)
         return loss
 
 
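
Each hunk applies the same device-placement pattern: a tensor created with a bare `torch.zeros(1)` is allocated on the CPU, so later arithmetic against network outputs that live on a GPU raises a device-mismatch error. Below is a minimal sketch of that pattern; the standalone `device` variable stands in for the agents' `self._device` and the surrounding class is not shown, so treat the names as assumptions rather than code from the repo.

```python
import torch

# Stand-in for the agents' self._device (assumption, not repo code).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net_out = torch.randn(4, device=device)  # e.g. a network output living on `device`
loss = torch.zeros(1)                    # bare placeholder: always allocated on the CPU
# On a CUDA machine, `loss + net_out.mean()` would raise a device-mismatch RuntimeError.

loss = torch.zeros(1, device=device)     # placeholder created on the same device
total = loss + net_out.mean()            # safe: both operands live on `device`
print(total.item())
```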