diff --git a/amp_rsl_rl/algorithms/amp_ppo.py b/amp_rsl_rl/algorithms/amp_ppo.py
index 1781604..731ca3a 100644
--- a/amp_rsl_rl/algorithms/amp_ppo.py
+++ b/amp_rsl_rl/algorithms/amp_ppo.py
@@ -8,6 +8,8 @@
 from typing import Optional, Tuple, Dict, Any
 
+import itertools
+
 import torch
 import torch.nn as nn
 import torch.optim as optim
 
@@ -90,6 +92,8 @@ def __init__(
         schedule: str = "fixed",
         desired_kl: float = 0.01,
         amp_replay_buffer_size: int = 100000,
+        distillation_cfg: dict | None = None,
+        teacher: ActorCritic | None = None,
         device: str = "cpu",
     ) -> None:
         # Set device and learning hyperparameters
@@ -106,11 +110,44 @@ def __init__(
         # The discriminator expects concatenated observations, so the replay buffer uses half the dimension.
         obs_dim: int = self.discriminator.input_dim // 2
         self.amp_storage: ReplayBuffer = ReplayBuffer(
-            obs_dim=obs_dim, buffer_size=amp_replay_buffer_size, device=device
+            obs_dim=obs_dim, buffer_size=amp_replay_buffer_size, device=self.device
        )
         self.amp_data: AMPLoader = amp_data
         self.amp_normalizer: Optional[Any] = amp_normalizer
 
+        # Check if distillation is enabled
+        if distillation_cfg is not None:
+            if teacher is None:
+                raise ValueError(
+                    "Distillation is enabled but no teacher model was provided. This is not supported."
+                )
+
+            self.teacher: ActorCritic = teacher
+            self.teacher.to(self.device)
+            self.teacher.eval()
+
+            # Disable gradients for the teacher parameters: the teacher is never trained.
+            for param in self.teacher.parameters():
+                param.requires_grad_(False)
+
+            self.distillation_storage: Optional[RolloutStorage] = (
+                None  # Will be initialized later once environment parameters are known
+            )
+            self.distillation_transition: RolloutStorage.Transition = (
+                RolloutStorage.Transition()
+            )
+            self.distillation_cfg = distillation_cfg
+            self.distill_loss_type = self.distillation_cfg.get("loss_type", "mse")
+            if self.distill_loss_type not in ["mse", "kl"]:
+                raise ValueError(
+                    "Only 'mse' and 'kl' are currently supported as distillation loss types."
+                )
+        else:
+            self.teacher = None
+            self.distillation_cfg = None
+            self.distillation_storage = None
+            self.distillation_transition = None
+
         # Set up the actor-critic (policy) and move it to the device.
         self.actor_critic = actor_critic
         self.actor_critic.to(self.device)
@@ -182,6 +219,41 @@ def init_storage(
             device=self.device,
         )
 
+    def init_storage_distillation(
+        self,
+        num_envs: int,
+        num_transitions_per_env: int,
+        student_obs_shape: Tuple[int, ...],
+        teacher_obs_shape: Tuple[int, ...],
+        action_shape: Tuple[int, ...],
+    ) -> None:
+        """
+        Initializes the storage for the transitions collected during distillation.
+
+        Parameters
+        ----------
+        num_envs : int
+            Number of parallel environments.
+        num_transitions_per_env : int
+            Number of transitions to store per environment.
+        student_obs_shape : tuple
+            Shape of the observations for the student.
+        teacher_obs_shape : tuple
+            Shape of the observations for the teacher.
+        action_shape : tuple
+            Shape of the actions taken by the policy.
+        """
+        self.distillation_storage = RolloutStorage(
+            training_type="distillation",
+            num_envs=num_envs,
+            num_transitions_per_env=num_transitions_per_env,
+            obs_shape=student_obs_shape,
+            privileged_obs_shape=teacher_obs_shape,
+            actions_shape=action_shape,
+            rnd_state_shape=None,
+            device=self.device,
+        )
+
     def test_mode(self) -> None:
         """
         Sets the actor-critic model to evaluation mode.
@@ -238,6 +310,36 @@ def act_amp(self, amp_obs: torch.Tensor) -> None:
         """
         self.amp_transition.observations = amp_obs
 
+    def act_distillation(
+        self, student_obs: torch.Tensor, teacher_obs: torch.Tensor
+    ) -> None:
+        """
+        Records the student and teacher observations in the distillation transition and
+        computes the actions of both the student and the teacher networks.
+        This is used during distillation training.
+
+        Parameters
+        ----------
+        student_obs : torch.Tensor
+            The observation from the student (actor-critic).
+        teacher_obs : torch.Tensor
+            The observation from the teacher (expert).
+        """
+
+        if self.distillation_transition is None:
+            return
+
+        # compute the actions
+        self.distillation_transition.actions = self.actor_critic.act(
+            student_obs
+        ).detach()
+        with torch.no_grad():  # the teacher forward pass does not need gradients
+            self.distillation_transition.privileged_actions = (
+                self.teacher.act_inference(teacher_obs).detach()
+            )
+        # record the observations
+        self.distillation_transition.observations = student_obs
+        self.distillation_transition.privileged_observations = teacher_obs
+
     def process_env_step(
         self, rewards: torch.Tensor, dones: torch.Tensor, infos: Dict[str, Any]
     ) -> None:
@@ -285,6 +387,34 @@ def process_amp_step(self, amp_obs: torch.Tensor) -> None:
         self.amp_storage.insert(self.amp_transition.observations, amp_obs)
         self.amp_transition.clear()
 
+    def process_distillation_step(
+        self, rewards: torch.Tensor, dones: torch.Tensor, infos: Dict[str, Any]
+    ) -> None:
+        """
+        Records the rewards and termination flags of the current distillation transition
+        and adds the transition to the distillation rollout storage.
+
+        This function is only used when distillation is enabled.
+
+        Parameters
+        ----------
+        rewards : torch.Tensor
+            Rewards received from the environment.
+        dones : torch.Tensor
+            Tensor indicating if an episode is finished.
+        infos : dict
+            Additional information from the environment, which may include 'time_outs'.
+        """
+
+        if self.distillation_storage is None:
+            return
+
+        self.distillation_transition.rewards = rewards
+        self.distillation_transition.dones = dones
+
+        self.distillation_storage.add_transitions(self.distillation_transition)
+        self.distillation_transition.clear()
+
     def compute_returns(self, last_critic_obs: torch.Tensor) -> None:
         """
         Computes the discounted returns and advantages based on the critic's evaluation of the last observation.
@@ -339,7 +469,9 @@ def discriminator_expert_loss(
         expected = torch.ones_like(discriminator_output).to(self.device)
         return loss_fn(discriminator_output, expected)
 
-    def update(self) -> Tuple[float, float, float, float, float, float, float, float]:
+    def update(
+        self,
+    ) -> Tuple[float, float, float, float, float, float, float, float, float]:
         """
         Performs a single update step for both the actor-critic (PPO) and the AMP discriminator.
         It iterates over mini-batches of data, computes surrogate, value, AMP and gradient penalty losses,
@@ -350,12 +482,13 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float]:
         tuple
             A tuple containing mean losses and statistics:
             (mean_value_loss, mean_surrogate_loss, mean_amp_loss, mean_grad_pen_loss,
-            mean_policy_pred, mean_expert_pred, mean_accuracy_policy, mean_accuracy_expert)
+            mean_policy_pred, mean_expert_pred, mean_accuracy_policy, mean_accuracy_expert, mean_dist_loss)
         """
         # Initialize mean loss and accuracy statistics.
         mean_value_loss: float = 0.0
         mean_surrogate_loss: float = 0.0
         mean_amp_loss: float = 0.0
+        mean_dist_loss: float = 0.0
         mean_grad_pen_loss: float = 0.0
         mean_policy_pred: float = 0.0
         mean_expert_pred: float = 0.0
@@ -390,11 +523,14 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float]:
             // self.num_mini_batches,
         )
 
+        if self.distillation_storage is not None:
+            distill_gen = self.distillation_storage.generator()
+        else:
+            # emits a dummy 5-tuple of Nones forever
+            distill_gen = itertools.repeat((None, None, None, None, None))
+
         # Loop over mini-batches from the environment transitions and AMP data.
-        for sample, sample_amp_policy, sample_amp_expert in zip(
-            generator, amp_policy_generator, amp_expert_generator
-        ):
-            # Unpack the mini-batch sample from the environment.
+        for (
             (
                 obs_batch,
                 critic_obs_batch,
@@ -408,7 +544,16 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float]:
                 hid_states_batch,
                 masks_batch,
                 rnd_state_batch,
-            ) = sample
+            ),
+            (policy_state, policy_next_state),
+            (expert_state, expert_next_state),
+            dist_batch,
+        ) in zip(
+            generator,
+            amp_policy_generator,
+            amp_expert_generator,
+            distill_gen,
+        ):
 
             # Forward pass through the actor to get current policy outputs.
             self.actor_critic.act(
@@ -475,10 +620,6 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float]:
                 - self.entropy_coef * entropy_batch.mean()
             )
 
-            # Process AMP loss by unpacking policy and expert AMP samples.
-            policy_state, policy_next_state = sample_amp_policy
-            expert_state, expert_next_state = sample_amp_expert
-
             # Normalize AMP observations if a normalizer is provided.
             if self.amp_normalizer is not None:
                 with torch.no_grad():
@@ -512,11 +653,33 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float]:
             # Compute gradient penalty to stabilize discriminator training.
             grad_pen_loss = self.discriminator.compute_grad_pen(
-                *sample_amp_expert, lambda_=10
+                expert_state=expert_state,
+                expert_next_state=expert_next_state,
+                lambda_=10.0,
             )
 
-            # The final loss combines the PPO loss with AMP losses.
+            # Compute the distillation loss if enabled
+            dist_loss = 0.0
+            if dist_batch[0] is not None:
+                obs_s, obs_t, _, a_teacher, _ = dist_batch
+                if self.distill_loss_type == "mse":
+                    student_mean = self.actor_critic.act_inference(obs_s)
+                    dist_loss = torch.nn.functional.mse_loss(student_mean, a_teacher)
+                else:  # KL
+                    with torch.no_grad():
+                        self.teacher.update_distribution(obs_t)
+                        tdist = self.teacher.distribution
+                    self.actor_critic.update_distribution(obs_s)
+                    sdist = self.actor_critic.distribution
+                    dist_loss = torch.distributions.kl_divergence(tdist, sdist).mean()
+
+            # The final loss combines the PPO loss with the AMP losses and, if enabled, the distillation loss.
             loss = ppo_loss + (amp_loss + grad_pen_loss)
+            if self.distillation_storage is not None:
+                lambda_dist = self.distillation_cfg["lam"]
+                loss += lambda_dist * dist_loss
 
             # Backpropagation and optimizer step.
             self.optimizer.zero_grad()
@@ -537,6 +700,7 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float]:
             mean_value_loss += value_loss.item()
             mean_surrogate_loss += surrogate_loss.item()
             mean_amp_loss += amp_loss.item()
+            mean_dist_loss += dist_loss.item() if self.distillation_storage is not None else 0.0
             mean_grad_pen_loss += grad_pen_loss.item()
             mean_policy_pred += policy_d_prob.mean().item()
             mean_expert_pred += expert_d_prob.mean().item()
@@ -558,14 +722,17 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float]:
         mean_value_loss /= num_updates
         mean_surrogate_loss /= num_updates
         mean_amp_loss /= num_updates
+        mean_dist_loss /= num_updates
         mean_grad_pen_loss /= num_updates
         mean_policy_pred /= num_updates
         mean_expert_pred /= num_updates
         mean_accuracy_policy /= mean_accuracy_policy_elem
         mean_accuracy_expert /= mean_accuracy_expert_elem
 
-        # Clear the storage for the next update cycle.
+        # Clear the storages for the next update cycle.
         self.storage.clear()
+        if self.distillation_storage is not None:
+            self.distillation_storage.clear()
 
         return (
             mean_value_loss,
@@ -576,4 +743,5 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float]:
             mean_expert_pred,
             mean_accuracy_policy,
             mean_accuracy_expert,
+            mean_dist_loss,
         )
diff --git a/amp_rsl_rl/runners/amp_on_policy_runner.py b/amp_rsl_rl/runners/amp_on_policy_runner.py
index 5e074e5..7a78df5 100644
--- a/amp_rsl_rl/runners/amp_on_policy_runner.py
+++ b/amp_rsl_rl/runners/amp_on_policy_runner.py
@@ -130,6 +130,8 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
         self.alg_cfg = train_cfg["algorithm"]
         self.policy_cfg = train_cfg["policy"]
         self.discriminator_cfg = train_cfg["discriminator"]
+        self.dist_cfg = self.cfg.get("distill", None)
+        self.teacher_cfg = self.cfg.get("teacher", None)
         self.device = device
         self.env = env
 
@@ -163,8 +165,6 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
             actuated_joint_names,
         )
 
-        # self.env.unwrapped.scene["robot"].joint_names)
-        # amp_data = AMPLoader(num_amp_obs, self.device)
         self.amp_normalizer = Normalizer(num_amp_obs)
 
         self.discriminator = Discriminator(
@@ -175,6 +175,43 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
             device=self.device,
         ).to(self.device)
 
+        # Distillation is enabled by providing 'distill' and 'teacher' entries in the training configuration.
+        teacher = None
+        num_teacher_obs = None
+
+        if self.dist_cfg is None and self.teacher_cfg is not None:
+            raise ValueError(
+                "You need to provide a distillation configuration if you want to use a teacher model."
+            )
+        if self.dist_cfg is not None and self.teacher_cfg is None:
+            raise ValueError(
+                "You need to provide a teacher configuration if you want to use distillation."
+            )
+
+        if self.dist_cfg is not None and self.teacher_cfg is not None:
+            # Load the distillation configuration
+            num_teacher_obs = extras["observations"]["teacher"].shape[1]
+            if "teacher_critic" in extras["observations"]:
+                num_teacher_critic_obs = extras["observations"]["teacher_critic"].shape[
+                    1
+                ]
+            else:
+                num_teacher_critic_obs = num_teacher_obs
+            # Initialize the teacher
+            teacher_class = eval(self.teacher_cfg.pop("class_name"))
+            teacher = teacher_class(
+                num_teacher_obs,
+                num_teacher_critic_obs,
+                self.env.num_actions,
+                **self.teacher_cfg,
+            ).to(self.device)
+            teacher.load_state_dict(
+                torch.load(
+                    self.dist_cfg["teacher_model_path"], map_location=self.device
+                )["model_state_dict"]
+            )
+            teacher.eval()
+
         # Initialize the PPO algorithm
         alg_class = eval(self.alg_cfg.pop("class_name"))  # AMP_PPO
@@ -184,6 +221,8 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
             amp_data=amp_data,
             amp_normalizer=self.amp_normalizer,
             device=self.device,
+            distillation_cfg=self.dist_cfg,
+            teacher=teacher,
             **self.alg_cfg,
         )
         self.num_steps_per_env = self.cfg["num_steps_per_env"]
@@ -208,6 +247,15 @@ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
             [self.env.num_actions],
         )
 
+        if self.dist_cfg is not None:
+            self.alg.init_storage_distillation(
+                num_envs=self.env.num_envs,
+                num_transitions_per_env=self.num_steps_per_env,
+                student_obs_shape=[num_obs],
+                teacher_obs_shape=[num_teacher_obs],
+                action_shape=[self.env.num_actions],
+            )
+
         # Log
         self.log_dir = log_dir
         self.writer = None
@@ -293,6 +341,7 @@ def update_run_name_with_sequence(prefix: str) -> None:
         obs, extras = self.env.get_observations()
         amp_obs = extras["observations"]["amp"]
         critic_obs = extras["observations"].get("critic", obs)
+        teacher_obs = extras["observations"].get("teacher", None)
         obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
         amp_obs = amp_obs.to(self.device)
         self.train_mode()  # switch to train mode (for dropout for example)
@@ -320,9 +369,12 @@ def update_run_name_with_sequence(prefix: str) -> None:
                 for i in range(self.num_steps_per_env):
                     actions = self.alg.act(obs, critic_obs)
                     self.alg.act_amp(amp_obs)
+                    self.alg.act_distillation(obs, teacher_obs)
                     obs, rewards, dones, infos = self.env.step(actions)
                     _, extras = self.env.get_observations()
                     next_amp_obs = extras["observations"]["amp"]
+                    if self.dist_cfg is not None:
+                        next_teacher_obs = extras["observations"]["teacher"]
                     obs = self.obs_normalizer(obs)
                     if "critic" in infos["observations"]:
                         critic_obs = self.critic_obs_normalizer(
@@ -338,6 +390,9 @@ def update_run_name_with_sequence(prefix: str) -> None:
                     )
                     next_amp_obs = next_amp_obs.to(self.device)
 
+                    if self.dist_cfg is not None:
+                        teacher_obs = next_teacher_obs.to(self.device)
+
                     # Process the AMP reward
                     style_rewards = self.discriminator.predict_reward(
                         amp_obs, next_amp_obs, normalizer=self.amp_normalizer
@@ -351,6 +406,7 @@ def update_run_name_with_sequence(prefix: str) -> None:
 
                     self.alg.process_env_step(rewards, dones, infos)
                     self.alg.process_amp_step(next_amp_obs)
+                    self.alg.process_distillation_step(rewards, dones, infos)
 
                     # The next observation becomes the current observation for the next step
                     amp_obs = torch.clone(next_amp_obs)
@@ -395,6 +451,7 @@ def update_run_name_with_sequence(prefix: str) -> None:
                 mean_expert_pred,
                 mean_accuracy_policy,
                 mean_accuracy_expert,
+                mean_dist_loss,
             ) = self.alg.update()
             stop = time.time()
             learn_time = stop - start
@@ -471,6 +528,11 @@ def log(self, locs: dict, width: int = 80, pad: int = 35):
"Loss/accuracy_expert", locs["mean_accuracy_expert"], locs["it"] ) + if self.dist_cfg is not None: + self.writer.add_scalar( + "Loss/distillation_loss", locs["mean_dist_loss"], locs["it"] + ) + self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"]) self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), locs["it"]) self.writer.add_scalar("Perf/total_fps", fps, locs["it"])