diff --git a/isaacgymenvs/tasks/amp/humanoid_amp_base.py b/isaacgymenvs/tasks/amp/humanoid_amp_base.py index 9b1bc5f22..e387a6e92 100644 --- a/isaacgymenvs/tasks/amp/humanoid_amp_base.py +++ b/isaacgymenvs/tasks/amp/humanoid_amp_base.py @@ -145,6 +145,7 @@ def create_sim(self): def reset_idx(self, env_ids): self._reset_actors(env_ids) + self._reset_env_tensors(env_ids) self._refresh_sim_tensors() self._compute_observations(env_ids) return @@ -342,17 +343,18 @@ def _compute_humanoid_obs(self, env_ids=None): return obs def _reset_actors(self, env_ids): + self._root_states[env_ids] = self._initial_root_states[env_ids] self._dof_pos[env_ids] = self._initial_dof_pos[env_ids] self._dof_vel[env_ids] = self._initial_dof_vel[env_ids] - env_ids_int32 = env_ids.to(dtype=torch.int32) - self.gym.set_actor_root_state_tensor_indexed(self.sim, - gymtorch.unwrap_tensor(self._initial_root_states), - gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) + return - self.gym.set_dof_state_tensor_indexed(self.sim, - gymtorch.unwrap_tensor(self._dof_state), - gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) + def _reset_env_tensors(self, env_ids): + env_ids_int32 = env_ids.to(dtype=torch.int32) + self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._root_states), + gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) + self.gym.set_dof_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._dof_state), + gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) self.progress_buf[env_ids] = 0 self.reset_buf[env_ids] = 0 diff --git a/isaacgymenvs/tasks/humanoid_amp.py b/isaacgymenvs/tasks/humanoid_amp.py index a2e5c9b11..aa91b919f 100644 --- a/isaacgymenvs/tasks/humanoid_amp.py +++ b/isaacgymenvs/tasks/humanoid_amp.py @@ -159,23 +159,13 @@ def _reset_actors(self, env_ids): else: assert(False), "Unsupported state initialization strategy: {:s}".format(str(self._state_init)) - self.progress_buf[env_ids] = 0 - self.reset_buf[env_ids] = 0 - self._terminate_buf[env_ids] = 0 - return def _reset_default(self, env_ids): + self._root_states[env_ids] = self._initial_root_states[env_ids] self._dof_pos[env_ids] = self._initial_dof_pos[env_ids] self._dof_vel[env_ids] = self._initial_dof_vel[env_ids] - env_ids_int32 = env_ids.to(dtype=torch.int32) - self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._initial_root_states), - gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) - - self.gym.set_dof_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._dof_state), - gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) - self._reset_default_env_ids = env_ids return @@ -264,11 +254,6 @@ def _set_env_state(self, env_ids, root_pos, root_rot, dof_pos, root_vel, root_an self._dof_pos[env_ids] = dof_pos self._dof_vel[env_ids] = dof_vel - env_ids_int32 = env_ids.to(dtype=torch.int32) - self.gym.set_actor_root_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._root_states), - gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) - self.gym.set_dof_state_tensor_indexed(self.sim, gymtorch.unwrap_tensor(self._dof_state), - gymtorch.unwrap_tensor(env_ids_int32), len(env_ids_int32)) return def _update_hist_amp_obs(self, env_ids=None):