198 changes: 183 additions & 15 deletions amp_rsl_rl/algorithms/amp_ppo.py
@@ -8,6 +8,8 @@

from typing import Optional, Tuple, Dict, Any

import itertools

import torch
import torch.nn as nn
import torch.optim as optim
@@ -90,6 +92,8 @@ def __init__(
schedule: str = "fixed",
desired_kl: float = 0.01,
amp_replay_buffer_size: int = 100000,
distillation_cfg: dict | None = None,
teacher: ActorCritic | None = None,
device: str = "cpu",
) -> None:
# Set device and learning hyperparameters
@@ -106,11 +110,44 @@
# The discriminator expects concatenated observations, so the replay buffer uses half the dimension.
obs_dim: int = self.discriminator.input_dim // 2
self.amp_storage: ReplayBuffer = ReplayBuffer(
obs_dim=obs_dim, buffer_size=amp_replay_buffer_size, device=device
obs_dim=obs_dim, buffer_size=amp_replay_buffer_size, device=self.device
)
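# Example: a discriminator scoring concatenated (obs, next_obs) pairs with
# input_dim == 84 stores 42-dimensional observations here (numbers illustrative).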
self.amp_data: AMPLoader = amp_data
self.amp_normalizer: Optional[Any] = amp_normalizer

# Check if distillation is enabled
if distillation_cfg is not None:
if teacher is None:
raise ValueError(
"The teacher is None, still distillation is enabled. This is not supported."
)

self.teacher: ActorCritic = teacher
self.teacher.to(self.device)
self.teacher.eval()

# Freeze the teacher's parameters (requires_grad=False); the teacher is not trained.
for param in self.teacher.parameters():
param.requires_grad_(False)

self.distillation_storage: Optional[RolloutStorage] = (
None # Will be initialized later once environment parameters are known
)
self.distillation_transition: RolloutStorage.Transition = (
RolloutStorage.Transition()
)
self.distillation_cfg = distillation_cfg
self.distill_loss_type = self.distillation_cfg.get("loss_type", "mse")
if self.distill_loss_type not in ["mse", "kl"]:
raise ValueError(
"Currently we support only MSE or KL divergence as loss for the distillation."
)
else:
self.teacher = None
self.distillation_cfg = None
self.distillation_storage = None
self.distillation_transition = None
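
# A minimal example configuration (a sketch; "loss_type" is read above and
# "lam" is read in update(); the value 1.0 is illustrative):
#
#   distillation_cfg = {"loss_type": "mse", "lam": 1.0}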

# Set up the actor-critic (policy) and move it to the device.
self.actor_critic = actor_critic
self.actor_critic.to(self.device)
@@ -182,6 +219,41 @@ def init_storage(
device=self.device,
)

def init_storage_distillation(
self,
num_envs: int,
num_transitions_per_env: int,
student_obs_shape: Tuple[int, ...],
teacher_obs_shape: Tuple[int, ...],
action_shape: Tuple[int, ...],
) -> None:
"""
Initializes the rollout storage for transitions collected during distillation.

Parameters
----------
num_envs : int
Number of parallel environments.
num_transitions_per_env : int
Number of transitions to store per environment.
student_obs_shape : tuple
Shape of the observations for the student.
teacher_obs_shape : tuple
Shape of the observations for the teacher.
action_shape : tuple
Shape of the actions taken by the policy.
"""
self.distillation_storage = RolloutStorage(
training_type="distillation",
num_envs=num_envs,
num_transitions_per_env=num_transitions_per_env,
obs_shape=student_obs_shape,
privileged_obs_shape=teacher_obs_shape,
actions_shape=action_shape,
rnd_state_shape=None,
device=self.device,
)
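
# A usage sketch (hypothetical runner code; the `env` attributes below are
# assumptions, not part of this PR):
#
#   algo.init_storage_distillation(
#       num_envs=env.num_envs,
#       num_transitions_per_env=24,
#       student_obs_shape=(env.num_obs,),
#       teacher_obs_shape=(env.num_privileged_obs,),
#       action_shape=(env.num_actions,),
#   )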

def test_mode(self) -> None:
"""
Sets the actor-critic model to evaluation mode.
@@ -238,6 +310,36 @@ def act_amp(self, amp_obs: torch.Tensor) -> None:
"""
self.amp_transition.observations = amp_obs

def act_distillation(
self, student_obs: torch.Tensor, teacher_obs: torch.Tensor
) -> None:
"""
Records the student and teacher observations into the distillation transition storage.
Computes the actions for both the student and teacher networks.
This is used for distillation training.

Parameters
----------
student_obs : torch.Tensor
The observation from the student (actor-critic).
teacher_obs : torch.Tensor
The observation from the teacher (expert).
"""

if self.distillation_transition is None:
return

# Compute the student and teacher actions.
self.distillation_transition.actions = self.actor_critic.act(
student_obs
).detach()
with torch.no_grad():  # Guard the teacher forward pass
self.distillation_transition.privileged_actions = (
self.teacher.act_inference(teacher_obs).detach()
)
# Record the observations.
self.distillation_transition.observations = student_obs
self.distillation_transition.privileged_observations = teacher_obs
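
# Per-step rollout sketch when distillation is enabled (illustrative only;
# the runner loop and variable names below are assumptions, not part of
# this class):
#
#   actions = algo.act(obs, critic_obs)
#   algo.act_distillation(student_obs=obs, teacher_obs=privileged_obs)
#   obs, rewards, dones, infos = env.step(actions)
#   algo.process_env_step(rewards, dones, infos)
#   algo.process_distillation_step(rewards, dones, infos)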

def process_env_step(
self, rewards: torch.Tensor, dones: torch.Tensor, infos: Dict[str, Any]
) -> None:
@@ -285,6 +387,34 @@ def process_amp_step(self, amp_obs: torch.Tensor) -> None:
self.amp_storage.insert(self.amp_transition.observations, amp_obs)
self.amp_transition.clear()

def process_distillation_step(
self, rewards: torch.Tensor, dones: torch.Tensor, infos: Dict[str, Any]
) -> None:
"""
Processes the outcome of an environment step by recording rewards and handling bootstrapping
for timeouts. Also resets the actor-critic's hidden state if necessary.

This function is used in case of distillation

Parameters
----------
rewards : torch.Tensor
Rewards received from the environment.
dones : torch.Tensor
Tensor indicating if an episode is finished.
infos : dict
Additional information from the environment, which may include 'time_outs'.
"""

if self.distillation_storage is None:
return

self.distillation_transition.rewards = rewards
self.distillation_transition.dones = dones

self.distillation_storage.add_transitions(self.distillation_transition)
self.distillation_transition.clear()

def compute_returns(self, last_critic_obs: torch.Tensor) -> None:
"""
Computes the discounted returns and advantages based on the critic's evaluation of the last observation.
@@ -339,7 +469,9 @@ def discriminator_expert_loss(
expected = torch.ones_like(discriminator_output).to(self.device)
return loss_fn(discriminator_output, expected)

def update(self) -> Tuple[float, float, float, float, float, float, float, float]:
def update(
self,
) -> Tuple[float, float, float, float, float, float, float, float, float]:
"""
Performs a single update step for both the actor-critic (PPO) and the AMP discriminator.
It iterates over mini-batches of data, computes surrogate, value, AMP and gradient penalty losses,
@@ -350,12 +482,13 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
tuple
A tuple containing mean losses and statistics:
(mean_value_loss, mean_surrogate_loss, mean_amp_loss, mean_grad_pen_loss,
mean_policy_pred, mean_expert_pred, mean_accuracy_policy, mean_accuracy_expert)
mean_policy_pred, mean_expert_pred, mean_accuracy_policy, mean_accuracy_expert, mean_dist_loss)
"""
# Initialize mean loss and accuracy statistics.
mean_value_loss: float = 0.0
mean_surrogate_loss: float = 0.0
mean_amp_loss: float = 0.0
mean_dist_loss: float = 0.0
mean_grad_pen_loss: float = 0.0
mean_policy_pred: float = 0.0
mean_expert_pred: float = 0.0
@@ -390,11 +523,14 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
// self.num_mini_batches,
)

if self.distillation_storage is not None:
distill_gen = self.distillation_storage.generator()
else:
# Emit a dummy 5-tuple of Nones forever so the zip below stays aligned
# when distillation is disabled.
distill_gen = itertools.repeat((None, None, None, None, None))

# Loop over mini-batches from the environment transitions and AMP data.
for sample, sample_amp_policy, sample_amp_expert in zip(
generator, amp_policy_generator, amp_expert_generator
):
# Unpack the mini-batch sample from the environment.
for (
(
obs_batch,
critic_obs_batch,
Expand All @@ -408,7 +544,16 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
hid_states_batch,
masks_batch,
rnd_state_batch,
) = sample
),
(policy_state, policy_next_state),
(expert_state, expert_next_state),
dist_batch,
) in zip(
generator,
amp_policy_generator,
amp_expert_generator,
distill_gen,
):

# Forward pass through the actor to get current policy outputs.
self.actor_critic.act(
@@ -475,10 +620,6 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
- self.entropy_coef * entropy_batch.mean()
)

# Process AMP loss by unpacking policy and expert AMP samples.
policy_state, policy_next_state = sample_amp_policy
expert_state, expert_next_state = sample_amp_expert

# Normalize AMP observations if a normalizer is provided.
if self.amp_normalizer is not None:
with torch.no_grad():
@@ -512,11 +653,33 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float

# Compute gradient penalty to stabilize discriminator training.
grad_pen_loss = self.discriminator.compute_grad_pen(
*sample_amp_expert, lambda_=10
expert_state=expert_state,
expert_next_state=expert_next_state,
lambda_=10.0,
)

# The final loss combines the PPO loss with AMP losses.
# Compute the distillation loss if enabled.
dist_loss = 0.0
if dist_batch[0] is not None:
obs_s, obs_t, _, a_teacher, _ = dist_batch
if self.distill_loss_type == "mse":
# Match the student's deterministic action to the teacher's action.
student_mean = self.actor_critic.act_inference(obs_s)
dist_loss = torch.nn.functional.mse_loss(student_mean, a_teacher)
else:  # KL divergence
with torch.no_grad():
self.teacher.update_distribution(obs_t)
tdist = self.teacher.distribution
self.actor_critic.update_distribution(obs_s)
sdist = self.actor_critic.distribution
# KL(teacher || student), averaged over the batch.
dist_loss = torch.distributions.kl_divergence(tdist, sdist).mean()

# The final loss combines the PPO loss with AMP losses and distillation loss if available
loss = ppo_loss + (amp_loss + grad_pen_loss)
if self.distillation_storage is not None:
lambda_dist = self.distillation_cfg["lam"]
loss += lambda_dist * dist_loss
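# With distillation enabled, the full objective is:
# ppo_loss + amp_loss + grad_pen_loss + lambda_dist * dist_loss.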

# Backpropagation and optimizer step.
self.optimizer.zero_grad()
@@ -537,6 +700,7 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
mean_value_loss += value_loss.item()
mean_surrogate_loss += surrogate_loss.item()
mean_amp_loss += amp_loss.item()
mean_dist_loss += dist_loss.item() if self.distillation_storage is not None else 0.0
mean_grad_pen_loss += grad_pen_loss.item()
mean_policy_pred += policy_d_prob.mean().item()
mean_expert_pred += expert_d_prob.mean().item()
@@ -558,14 +722,17 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
mean_value_loss /= num_updates
mean_surrogate_loss /= num_updates
mean_amp_loss /= num_updates
mean_dist_loss /= num_updates
mean_grad_pen_loss /= num_updates
mean_policy_pred /= num_updates
mean_expert_pred /= num_updates
mean_accuracy_policy /= mean_accuracy_policy_elem
mean_accuracy_expert /= mean_accuracy_expert_elem

# Clear the storage for the next update cycle.
# Clear the storages for the next update cycle.
self.storage.clear()
if self.distillation_storage is not None:
self.distillation_storage.clear()

return (
mean_value_loss,
@@ -576,4 +743,5 @@ def update(self) -> Tuple[float, float, float, float, float, float, float, float
mean_expert_pred,
mean_accuracy_policy,
mean_accuracy_expert,
mean_dist_loss,
)
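
# Callers unpacking update() must now handle nine values (a sketch; the
# left-hand names are illustrative):
#
#   (value_loss, surrogate_loss, amp_loss, grad_pen_loss, policy_pred,
#    expert_pred, acc_policy, acc_expert, dist_loss) = algo.update()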