1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -96,6 +96,7 @@ Guidelines for modifications:
* Manuel Schweiger
* Masoud Moghani
* Maurice Rahme
* Michael Groom
* Michael Gussert
* Michael Noseworthy
* Michael Lin
8 changes: 8 additions & 0 deletions source/isaaclab/docs/CHANGELOG.rst
@@ -1,6 +1,14 @@
Changelog
---------

Unreleased
~~~~~~~~~~

Added
^^^^^
* Added :class:`isaaclab.utils.modifiers.DelayedObservation` for stochastic latency and multi-rate observation modeling (per-env lags, ``hold_prob``, ``update_period``, ``per_env_phase``).
* Added :class:`isaaclab.utils.modifiers.DelayedObservationCfg` configuration class.

0.46.1 (2025-09-10)
~~~~~~~~~~~~~~~~~~~

2 changes: 2 additions & 0 deletions source/isaaclab/isaaclab/utils/modifiers/__init__.py
@@ -60,6 +60,8 @@
from .modifier_cfg import DigitalFilterCfg
from .modifier import Integrator
from .modifier_cfg import IntegratorCfg
from .modifier import DelayedObservation
from .modifier_cfg import DelayedObservationCfg

# isort: on
from .modifier import bias, clip, scale
182 changes: 182 additions & 0 deletions source/isaaclab/isaaclab/utils/modifiers/modifier.py
@@ -9,6 +9,8 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING

from isaaclab.utils.buffers import DelayBuffer

from .modifier_base import ModifierBase

if TYPE_CHECKING:
@@ -257,3 +259,183 @@ def __call__(self, data: torch.Tensor) -> torch.Tensor:
self.y_prev[:] = data

return self.integral


class DelayedObservation(ModifierBase):
r"""A modifier used to return a stochastically delayed (stale) version of
an observation term. This can also be used to model multi-rate
observations for non-sensor terms, e.g., pure MDP terms or proprioceptive terms.

This modifier takes an existing observation term/function, pushes each new batched
observation into a DelayBuffer, and returns an older sample according to a
per-environment integer time-lag. Lags are drawn uniformly from
[min_lag, max_lag], with an optional probability to *hold* the previous lag (to
mimic repeated frames). With 'update_period>0' (multi-rate), new lags are applied only
on refresh ticks, which occur every update_period policy steps. Between refreshes the
realised lag can increase at most by +1 (frame hold). This process is
causal: the lag for each environment can only increase by 1 each step,
ensuring that the returned observation is never older than the previous
step's lagged observation.

Shapes are preserved: the returned tensor has the exact shape of the wrapped
term (``[num_envs, *obs_shape]``).

***Configuration:***
min_lag (int): Minimum time-lag (in steps) to sample. Defaults to 0.
max_lag (int): Maximum time-lag (in steps) to sample. Defaults to 3.
per_env (bool): If True, sample a different lag for each environment.
If False, use the same lag for all environments. Defaults to True.
hold_prob (float): Probability in [0, 1] of holding the previous lag
instead of sampling a new one. Defaults to 0.0 (always sample a new lag).
update_period (int): If > 0, apply newly sampled lags only every
``update_period`` policy steps (models a lower sensor cadence). Between
updates, the lag can increase by at most +1 each step (frame hold). If 0,
a new lag is sampled every step, subject to the causal clamp. Defaults to 1.
per_env_phase (bool): Only relevant if ``update_period > 0``. If True,
each environment has a different random phase offset for lag updates.
If False, all environments update their lag simultaneously. Defaults to True.

***Example:***
.. code-block:: python

# create a height_scan observation term that uses the delayed observation modifier
from isaaclab.utils import modifiers

height_scan = ObservationTermCfg(
func=mdp.height_scan,
params={"sensor_cfg": SceneEntityCfg("height_scanner")},
noise=Unoise(n_min=-0.1, n_max=0.1),
clip=(-1.0, 1.0),
modifiers=[
modifiers.DelayedObservationCfg(
min_lag=0,
max_lag=3,
per_env=True,
hold_prob=0.66,
update_period=3,
per_env_phase=True,
)
],
)
"""

def __init__(self, cfg: modifier_cfg.DelayedObservationCfg, data_dim: tuple[int, ...], device: str):
"""Initialize the DelayedObservation modifier.

Args:
cfg: Configuration parameters.
"""
# initialize parent class
super().__init__(cfg, data_dim, device)
if cfg.min_lag < 0 or cfg.max_lag < cfg.min_lag:
raise ValueError("DelayedObservation: require 0 <= min_lag <= max_lag.")
if cfg.hold_prob < 0.0 or cfg.hold_prob > 1.0:
raise ValueError("DelayedObservation: hold_prob must be in [0, 1].")
if cfg.update_period < 0:
raise ValueError("DelayedObservation: update_period must be non-negative.")
if cfg.update_period > 0 and cfg.update_period > cfg.max_lag:
raise ValueError("DelayedObservation: update_period must be <= max_lag.")

self._batch_size = data_dim[0]
self._expand_1d = len(data_dim) == 1 # e.g., input shape (N,)
self._feature_shape = (
(1,) if self._expand_1d else data_dim[1:]
) # ensure at least one feature dim for DelayBuffer resets

# state
self._buf = DelayBuffer(history_length=cfg.max_lag + 1, batch_size=self._batch_size, device=device)
self._prev_realized_lags: torch.Tensor | None = None # [N]
self._phases: torch.Tensor | None = None # [N] if multi-rate
self._step: int = 0

# prefill buffer with zeros so early delays are valid
zeros = torch.zeros((self._batch_size, *self._feature_shape), device=device)
for _ in range(cfg.max_lag + 1):
self._buf.compute(zeros)

def reset(self, env_ids: Sequence[int] | None = None):
"""Resets the delay buffer and internal state. Since the DelayBuffer
does not support partial resets, if env_ids is not None, only the
previous lags for those envs are reset to zero, forcing the
latest observation to be returned on the next call preventing
observations from before the reset being returned.

Args:
env_ids: The environment ids. Defaults to None, in which case
all environments are considered.
"""
if env_ids is None:
self._buf.reset()
self._prev_realized_lags = None
self._phases = None
self._step = 0
# prefill again with zeros
zeros = torch.zeros((self._batch_size, *self._feature_shape), device=self._device)
for _ in range(self._cfg.max_lag + 1):
self._buf.compute(zeros)
else:
if self._prev_realized_lags is not None:
self._prev_realized_lags[env_ids] = 0

def __call__(self, data: torch.Tensor) -> torch.Tensor:
"""Add the current data to the delay buffer and return a stale sample
according to the current lag for each environment.

Args:
data: The data to apply delay to.

Returns:
Delayed data. Shape is the same as data.
"""
cfg = self._cfg
self._step += 1

data_in = data.unsqueeze(-1) if (self._expand_1d and data.dim() == 1) else data

# initialize phases for multi-rate on first use
if cfg.update_period > 0 and self._phases is None:
if cfg.per_env_phase:
self._phases = torch.randint(0, cfg.update_period, (self._data_dim[0],), device=self._device)
else:
self._phases = torch.zeros(self._data_dim[0], dtype=torch.long, device=self._device)

# sample desired lags in [min_lag, max_lag]
if cfg.min_lag == cfg.max_lag:
desired_lags = torch.full((self._data_dim[0],), cfg.max_lag, dtype=torch.long, device=self._device)
else:
desired_lags = torch.randint(cfg.min_lag, cfg.max_lag + 1, (self._data_dim[0],), device=self._device)

if not cfg.per_env:
desired_lags = torch.full_like(desired_lags, desired_lags[0])

# optional: hold previous realized lag
if cfg.hold_prob > 0.0 and self._prev_realized_lags is not None:
hold_mask = torch.rand((self._data_dim[0],), device=self._device) < cfg.hold_prob
desired_lags = torch.where(hold_mask, self._prev_realized_lags, desired_lags)

# multi-rate update behavior
if cfg.update_period > 0:
refresh_mask = ((self._step - self._phases) % cfg.update_period) == 0
if self._prev_realized_lags is None:
realized_lags = desired_lags
else:
# between refreshes, lag can only increase by +1 (clamped)
hold_realized_lags = (self._prev_realized_lags + 1).clamp(max=cfg.max_lag)
realized_lags = torch.where(refresh_mask, desired_lags, hold_realized_lags)
else:
# every step: causal clamp (at most +1 step older)
if self._prev_realized_lags is None:
realized_lags = desired_lags
else:
realized_lags = torch.minimum(desired_lags, self._prev_realized_lags + 1)

realized_lags = realized_lags.clamp(min=cfg.min_lag, max=cfg.max_lag)
self._prev_realized_lags = realized_lags

# return stale sample
self._buf.set_time_lag(realized_lags)
out = self._buf.compute(data_in)

if self._expand_1d and out.dim() == 2 and data.dim() == 1:
out = out.squeeze(-1)
return out
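
For reviewers, the realized-lag recurrence implemented in `__call__` above can be summarized in a few lines. The snippet below is a minimal standalone sketch of that rule (single environment, `hold_prob = 0`, illustrative parameter values), not the class itself:

```python
# Standalone sketch of the lag-update rule (illustration only, not the DelayedObservation class).
# Assumptions: one environment, hold_prob = 0, and the illustrative values below.
import torch

min_lag, max_lag, update_period = 0, 3, 3
prev_lag = None
step = 0
torch.manual_seed(0)

for t in range(10):
    step += 1
    # desired lag drawn uniformly from [min_lag, max_lag]
    desired = int(torch.randint(min_lag, max_lag + 1, (1,)))
    if update_period > 0:
        # multi-rate: apply the newly sampled lag only on refresh ticks
        if prev_lag is None or (step % update_period) == 0:
            realized = desired
        else:
            # between refreshes the lag can only grow by +1 (frame hold)
            realized = min(prev_lag + 1, max_lag)
    else:
        # every-step mode: causal clamp, at most one step older than the previous step
        realized = desired if prev_lag is None else min(desired, prev_lag + 1)
    realized = max(min_lag, min(realized, max_lag))
    prev_lag = realized
    print(f"t={t}: desired={desired}, realized={realized}")
```

Feeding a monotonically increasing counter through the modifier and computing `t - output` recovers this realized-lag sequence once the zero-prefilled warm-up has passed, which is how the tests below verify the behavior.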
44 changes: 44 additions & 0 deletions source/isaaclab/isaaclab/utils/modifiers/modifier_cfg.py
@@ -76,3 +76,47 @@ class IntegratorCfg(ModifierCfg):

dt: float = MISSING
"""The time step of the integrator."""


@configclass
class DelayedObservationCfg(ModifierCfg):
"""Configuration parameters for a delayed observation modifier.

For more information, please check the :class:`DelayedObservation` class.
"""

func: type[modifier.DelayedObservation] = modifier.DelayedObservation
"""The delayed observation function to be called for applying the delay."""

# Lag parameters
min_lag: int = 0
"""The minimum lag (in number of policy steps) to be applied to the observations. Defaults to 0."""

max_lag: int = 3
"""The maximum lag (in number of policy steps) to be applied to the observations.

This value must be greater than or equal to :attr:`min_lag`.
"""

per_env: bool = True
"""Whether to use a separate lag for each environment."""

hold_prob: float = 0.0
"""The probability of holding the previous lag when updating the lag."""

# multi-rate emulation parameters (optional)
update_period: int = 1
"""The period (in number of policy steps) at which the lag is updated.

If set to 0, the lag is sampled once at the beginning and remains constant throughout the simulation.
If set to a positive integer, the lag is updated every `update_period` policy steps. Defaults to 1.

This value must be less than or equal to :attr:`max_lag` if it is greater than 0.
"""

per_env_phase: bool = True
"""Whether to use a separate phase for each environment when updating the lag.

If set to True, each environment will have its own phase when updating the lag. If set to False, all
environments will share the same phase. Defaults to True.
"""
96 changes: 96 additions & 0 deletions source/isaaclab/test/utils/test_modifiers.py
@@ -224,3 +224,99 @@ def test_integral(device):

# check if the modified data is close to the expected result
torch.testing.assert_close(processed_data, test_cfg.result)


def _counter_batch(t: int, shape, device):
return torch.full(shape, float(t), device=device)


@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
def test_delayed_observation_fixed_lag(device):
"""Fixed lag (L=2) should return t-2 after warmup; shape preserved."""
if device.startswith("cuda") and not torch.cuda.is_available():
pytest.skip("CUDA not available")

# config: fixed lag 2, single vector obs (3 envs)
cfg = modifiers.DelayedObservationCfg(min_lag=2, max_lag=2, per_env=True, hold_prob=0.0, update_period=0)
init_data = torch.zeros(3, device=device) # shape carried into modifier ctor

# choose iterations past warmup (max_lag+1 pushes) so last output reflects real history
num_iter = cfg.max_lag + 6
expected_final = torch.full_like(init_data, float((num_iter - 1) - 2))

test_cfg = ModifierTestCfg(cfg=cfg, init_data=init_data, result=expected_final, num_iter=num_iter)

# create a modifier instance
modifier_obj = test_cfg.cfg.func(test_cfg.cfg, test_cfg.init_data.shape, device=device)

for _ in range(3): # a few trials with reset
modifier_obj.reset()
for t in range(test_cfg.num_iter):
data = _counter_batch(t, test_cfg.init_data.shape, device)
processed = modifier_obj(data)
assert processed.shape == data.shape, "Modified data shape does not equal original"

torch.testing.assert_close(processed, test_cfg.result)


@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
def test_delayed_observation_multi_rate_period_3(device):
"""Multi-rate cadence: refresh every 3 steps with desired lag=3; holds in between."""
if device.startswith("cuda") and not torch.cuda.is_available():
pytest.skip("CUDA not available")

# single env scalar obs; deterministic cadence (per_env_phase=False)
cfg = modifiers.DelayedObservationCfg(
min_lag=3, max_lag=3, per_env=True, hold_prob=0.0, update_period=3, per_env_phase=False
)
init_data = torch.zeros(1, device=device)

num_iter = cfg.max_lag + 10

# compute expected final value: last t minus realized lag under the 3-step cadence
realized = None
for t in range(num_iter):
if realized is None:
realized = 3
elif ((t + 1) % cfg.update_period) == 0: # refresh on every 3rd call
realized = 3
else:
realized = min(realized + 1, cfg.max_lag)
expected_final = torch.tensor([float((num_iter - 1) - realized)], device=device)

test_cfg = ModifierTestCfg(cfg=cfg, init_data=init_data, result=expected_final, num_iter=num_iter)

modifier_obj = test_cfg.cfg.func(test_cfg.cfg, test_cfg.init_data.shape, device=device)

for _ in range(2):
modifier_obj.reset()
for t in range(test_cfg.num_iter):
data = _counter_batch(t, test_cfg.init_data.shape, device)
processed = modifier_obj(data)
assert processed.shape == data.shape
torch.testing.assert_close(processed, test_cfg.result)


@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
def test_delayed_observation_bounds_and_causality(device):
"""Lag stays within [min_lag,max_lag] and obeys causal clamp: lag_t <= lag_{t-1}+1."""
if device.startswith("cuda") and not torch.cuda.is_available():
pytest.skip("CUDA not available")

cfg = modifiers.DelayedObservationCfg(min_lag=0, max_lag=4, per_env=True, hold_prob=0.0, update_period=0)
init_data = torch.zeros(4, device=device)

modifier_obj = cfg.func(cfg, init_data.shape, device=device)

prev_lag = None
num_iter = cfg.max_lag + 20
for t in range(num_iter):
out = modifier_obj(_counter_batch(t, init_data.shape, device))
# infer realized lag from the counter signal: lag = t - out
lag = (t - out).to(torch.long)

if t >= (cfg.max_lag + 1): # after warmup
assert torch.all(lag >= cfg.min_lag) and torch.all(lag <= cfg.max_lag)
if prev_lag is not None:
assert torch.all(lag <= prev_lag + 1)
prev_lag = lag