Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
34413b1
add control_dofs_force to the tracked function for gradient flow
SonSang May 27, 2026
cf9cc32
add scene-level backward api and unit tests
SonSang May 27, 2026
c977d11
refactor forward_dynamics and forward_kinematics to remove unused BW …
SonSang May 27, 2026
a446ea9
add differentiable contact for plane vs. convex shapes
SonSang May 27, 2026
0b3b785
add manual bw kernels for adding inequality constraints
SonSang May 27, 2026
53642b5
use CG in backward pass when n_constraints=0, which is unrelibale for…
SonSang May 27, 2026
e7ad301
implement robust and efficient rigid solver's backward pass
SonSang May 28, 2026
c5eb3fa
minor fix
SonSang May 28, 2026
e7a61d3
add unit tests for differentiability
SonSang May 28, 2026
9919a63
minor fix
SonSang May 28, 2026
4c69c88
fix bug
SonSang May 28, 2026
b3cb60c
minor bug fix
SonSang May 28, 2026
a89beb0
remove redundant round trip through numpy
SonSang May 28, 2026
c0a3940
add host guard for unsupported constraint add functions
SonSang May 28, 2026
094464b
include gpu version diff contact fd test
SonSang Jun 2, 2026
c8ea21d
Merge branch 'main' into 20260527_diff_rigid_demo_prod
SonSang Jun 2, 2026
b9eff6f
ruff
SonSang Jun 2, 2026
a09bf66
add backward pass for frictionloss inequality constraints
SonSang Jun 2, 2026
3ab8d22
implement backward pass for equality joint
SonSang Jun 2, 2026
782b36b
add backward pass for equality constraint for connect
SonSang Jun 2, 2026
e94e3ed
add backward pass for weld equality constraints
SonSang Jun 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions genesis/assets/xml/cartpole.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<mujoco model="cartpole">
<option gravity="0 0 -9.81"/>
<worldbody>
<body name="cart" pos="0 0 0">
<joint name="slider" type="slide" axis="1 0 0" range="-4 4" damping="0.0"/>
<inertial pos="0 0 0" mass="1.0" diaginertia="1.0 1.0 1.0"/>
<geom name="cart_geom" type="box" size="0.25 0.25 0.1" contype="0" conaffinity="0" rgba="0 0 0.8 1"/>
<body name="pole" pos="0 0 0">
<joint name="hinge" type="hinge" axis="0 1 0" damping="0.0"/>
<inertial pos="0 0 0.5" mass="10.0" diaginertia="1.0 1.0 1.0"/>
<geom name="pole_geom" type="box" pos="0 0 0.5" size="0.025 0.025 0.5" contype="0" conaffinity="0" rgba="1 1 1 1"/>
</body>
</body>
</worldbody>
</mujoco>
46 changes: 46 additions & 0 deletions genesis/assets/xml/grad/all_eq_fric.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
<mujoco model="all_eq_fric">
<!-- Integration scene exercising every differentiated constraint group:
* frictionloss : on j1
* equality JOINT: j1 <-> j2 (linear polycoef)
* equality CONNECT: arm3 anchor <-> arm4 anchor
* equality WELD : arm5 <-> arm6
Each group acts on a disjoint pair of bodies so the constraint solver
does not face an over-constrained system within any one pair. -->
<worldbody>
<body name="arm1" pos="0 0 0">
<joint name="j1" type="hinge" axis="0 1 0" frictionloss="0.5"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
<body name="arm2" pos="0 0.2 0">
<joint name="j2" type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
<body name="arm3" pos="0 0.4 0">
<joint name="j3" type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
<body name="arm4" pos="0 0.6 0">
<joint name="j4" type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
<body name="arm5" pos="0 0.8 0">
<joint name="j5" type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
<body name="arm6" pos="0 1.0 0">
<joint name="j6" type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</worldbody>
<equality>
<joint joint1="j1" joint2="j2" polycoef="0 1 0 0 0" solimp="0.95 0.99 0.001" solref="0.005 1"/>
<connect body1="arm3" body2="arm4" anchor="0.2 0 0" solimp="0.95 0.99 0.001" solref="0.005 1"/>
<weld body1="arm5" body2="arm6" relpose="0 -0.2 0 1 0 0 0" solimp="0.95 0.99 0.001" solref="0.005 1"/>
</equality>
</mujoco>
9 changes: 9 additions & 0 deletions genesis/assets/xml/grad/capsule.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<mujoco model="capsule">
<compiler angle="degree"/>
<worldbody>
<body name="capsule" pos="0 0 0">
<geom type="capsule" size="0.1 0.2"/>
<joint name="capsule_joint" type="free"/>
</body>
</worldbody>
</mujoco>
17 changes: 17 additions & 0 deletions genesis/assets/xml/grad/connect_loop.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<mujoco model="connect_loop">
<worldbody>
<body name="arm1" pos="0 0 0">
<joint name="j1" type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
<body name="arm2" pos="0 0.3 0">
<joint name="j2" type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</worldbody>
<equality>
<connect body1="arm1" body2="arm2" anchor="0.2 0 0" solimp="0.95 0.99 0.001" solref="0.005 1"/>
</equality>
</mujoco>
9 changes: 9 additions & 0 deletions genesis/assets/xml/grad/free.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<mujoco model="free">
<worldbody>
<body name="chassis" pos="0 0 0">
<freejoint/>
<inertial mass="1.0" pos="0 0 0" diaginertia="0.1 0.1 0.1"/>
<geom type="box" size="0.1 0.1 0.1" contype="0" conaffinity="0"/>
</body>
</worldbody>
</mujoco>
14 changes: 14 additions & 0 deletions genesis/assets/xml/grad/free_with_revolute.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
<mujoco model="free_with_child">
<worldbody>
<body name="chassis" pos="0 0 0">
<freejoint/>
<inertial mass="1.0" pos="0 0 0" diaginertia="0.1 0.1 0.1"/>
<geom type="box" size="0.1 0.1 0.1" contype="0" conaffinity="0"/>
<body name="arm" pos="0.2 0 0">
<joint type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</body>
</worldbody>
</mujoco>
17 changes: 17 additions & 0 deletions genesis/assets/xml/grad/hinge_pair_joint_eq_linear.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<mujoco model="hinge_pair_joint_eq_linear">
<worldbody>
<body name="arm1" pos="0 0 0">
<joint name="j1" type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
<body name="arm2" pos="0.2 0 0">
<joint name="j2" type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</body>
</worldbody>
<equality>
<joint joint1="j1" joint2="j2" polycoef="0 1 0 0 0" solimp="0.95 0.99 0.001" solref="0.005 1"/>
</equality>
</mujoco>
17 changes: 17 additions & 0 deletions genesis/assets/xml/grad/hinge_pair_joint_eq_quadratic.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<mujoco model="hinge_pair_joint_eq_quadratic">
<worldbody>
<body name="arm1" pos="0 0 0">
<joint name="j1" type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
<body name="arm2" pos="0.2 0 0">
<joint name="j2" type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</body>
</worldbody>
<equality>
<joint joint1="j1" joint2="j2" polycoef="0 1 0.5 0 0" solimp="0.95 0.99 0.001" solref="0.005 1"/>
</equality>
</mujoco>
9 changes: 9 additions & 0 deletions genesis/assets/xml/grad/prismatic.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<mujoco model="prismatic">
<worldbody>
<body name="slider" pos="0 0 0">
<joint type="slide" axis="1 0 0"/>
<inertial mass="0.5" pos="0 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="box" size="0.05 0.05 0.05" contype="0" conaffinity="0"/>
</body>
</worldbody>
</mujoco>
9 changes: 9 additions & 0 deletions genesis/assets/xml/grad/revolute.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<mujoco model="revolute">
<worldbody>
<body name="arm" pos="0 0 0">
<joint type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</worldbody>
</mujoco>
19 changes: 19 additions & 0 deletions genesis/assets/xml/grad/revolute_chain3.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
<mujoco model="chain3">
<worldbody>
<body name="l1" pos="0 0 0">
<joint type="hinge" axis="0 1 0"/>
<inertial mass="0.3" pos="0.1 0 0" diaginertia="0.005 0.005 0.005"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
<body name="l2" pos="0.2 0 0">
<joint type="hinge" axis="0 1 0"/>
<inertial mass="0.3" pos="0.1 0 0" diaginertia="0.005 0.005 0.005"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
<body name="l3" pos="0.2 0 0">
<joint type="hinge" axis="0 1 0"/>
<inertial mass="0.3" pos="0.1 0 0" diaginertia="0.005 0.005 0.005"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</body>
</body>
</worldbody>
</mujoco>
9 changes: 9 additions & 0 deletions genesis/assets/xml/grad/revolute_frictionloss.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<mujoco model="revolute_frictionloss">
<worldbody>
<body name="arm" pos="0 0 0">
<joint type="hinge" axis="0 1 0" frictionloss="0.5"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</worldbody>
</mujoco>
10 changes: 10 additions & 0 deletions genesis/assets/xml/grad/slider_limit.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
<mujoco model="slider_limit">
<option gravity="0 0 0"/>
<worldbody>
<body name="cart" pos="0 0 0">
<joint name="slider" type="slide" axis="1 0 0" range="-4 4" damping="0.0"/>
<inertial pos="0 0 0" mass="1.0" diaginertia="1.0 1.0 1.0"/>
<geom type="box" size="0.25 0.25 0.1" contype="0" conaffinity="0"/>
</body>
</worldbody>
</mujoco>
9 changes: 9 additions & 0 deletions genesis/assets/xml/grad/spherical.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
<mujoco model="spherical">
<worldbody>
<body name="ball" pos="0 0 0">
<joint type="ball"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</worldbody>
</mujoco>
17 changes: 17 additions & 0 deletions genesis/assets/xml/grad/weld_pair.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<mujoco model="weld_pair">
<worldbody>
<body name="arm1" pos="0 0 0">
<joint name="j1" type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
<body name="arm2" pos="0 0.3 0">
<joint name="j2" type="hinge" axis="0 1 0"/>
<inertial mass="0.5" pos="0.1 0 0" diaginertia="0.01 0.01 0.01"/>
<geom type="capsule" fromto="0 0 0 0.2 0 0" size="0.02" contype="0" conaffinity="0"/>
</body>
</worldbody>
<equality>
<weld body1="arm1" body2="arm2" relpose="0 -0.3 0 1 0 0 0" solimp="0.95 0.99 0.001" solref="0.005 1"/>
</equality>
</mujoco>
27 changes: 27 additions & 0 deletions genesis/assets/xml/hopper.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
<mujoco model="hopper">
<compiler angle="radian"/>
<default>
<joint limited="true" armature="1" damping="1"/>
<geom condim="3" friction="0.9 0.005 0.0001" rgba="0.8 0.6 0.4 1"/>
</default>
<worldbody>
<body name="torso" pos="0 0 1.25">
<joint name="rootx" pos="0 0 0" axis="1 0 0" type="slide" limited="false" armature="0" damping="0"/>
<joint name="rootz" pos="0 0 0" axis="0 0 1" type="slide" limited="false" armature="0" damping="0"/>
<joint name="rooty" pos="0 0 0" axis="0 1 0" type="hinge" limited="false" armature="0" damping="0"/>
<geom name="torso_geom" type="capsule" size="0.05 0.2"/>
<body name="thigh" pos="0 0 -0.2">
<joint name="thigh_joint" pos="0 0 0" axis="0 -1 0" type="hinge" range="-2.61799 0"/>
<geom name="thigh_geom" type="capsule" size="0.05 0.225" pos="0 0 -0.225"/>
<body name="leg" pos="0 0 -0.7">
<joint name="leg_joint" pos="0 0 0.25" axis="0 -1 0" type="hinge" range="-2.61799 0"/>
<geom name="leg_geom" type="capsule" size="0.04 0.25"/>
<body name="foot" pos="0 0 -0.25">
<joint name="foot_joint" pos="0 0 0" axis="0 -1 0" type="hinge" range="-0.785398 0.785398"/>
<geom name="foot_geom" type="capsule" size="0.06 0.195" pos="0.06 0 0" quat="0.707107 0 -0.707107 0" friction="2 0.005 0.0001"/>
</body>
</body>
</body>
</body>
</worldbody>
</mujoco>
19 changes: 18 additions & 1 deletion genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(
self._load_model()

# Initialize target variables and checkpoint
self._tgt_keys = ("pos", "quat", "qpos", "dofs_velocity")
self._tgt_keys = ("pos", "quat", "qpos", "dofs_velocity", "control_dofs_force")
self._tgt = dict()
self._tgt_buffer = list()
self._ckpt = dict()
Expand Down Expand Up @@ -1156,6 +1156,8 @@ def process_input(self, in_backward=False):
self.set_quat(**data_kwargs)
case "set_dofs_velocity":
self.set_dofs_velocity(**data_kwargs)
case "control_dofs_force":
self.control_dofs_force(**data_kwargs)
case _:
gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}")

Expand Down Expand Up @@ -1188,6 +1190,15 @@ def process_input_grad(self):
data_kwargs["dofs_idx_local"],
data_kwargs["envs_idx"],
)

case "control_dofs_force":
force = data_kwargs.pop("force")
if force is not None and force.requires_grad:
force._backward_from_qd(
self.set_dofs_force_grad,
data_kwargs["dofs_idx_local"],
data_kwargs["envs_idx"],
)
case _:
gs.raise_exception(f"Invalid target key: {key} not in {self._tgt_keys}")

Expand Down Expand Up @@ -3559,6 +3570,11 @@ def set_dofs_velocity_grad(self, dofs_idx_local, envs_idx, velocity_grad):
dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
self._solver.set_dofs_velocity_grad(dofs_idx, envs_idx, velocity_grad.data)

@gs.assert_built
def set_dofs_force_grad(self, dofs_idx_local, envs_idx, force_grad):
dofs_idx = self._get_global_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
self._solver.set_dofs_force_grad(dofs_idx, envs_idx, force_grad.data)

# ------------------------------------------------------------------------------------
# ----------------------------- DOF property setters ---------------------------------
# ------------------------------------------------------------------------------------
Expand Down Expand Up @@ -3592,6 +3608,7 @@ def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zer
# ------------------------------------------------------------------------------------

@gs.assert_built
@tracked
def control_dofs_force(self, force, dofs_idx_local=None, envs_idx=None):
"""
Control the entity's dofs' motor force. This is used for force/torque control.
Expand Down
50 changes: 48 additions & 2 deletions genesis/engine/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,13 +977,16 @@ def reset(self, state: SimState | None = None, envs_idx=None):
self._reset(state, envs_idx=envs_idx)
self._recorder_manager.reset(envs_idx)

def _reset(self, state: SimState | None = None, *, envs_idx=None):
def _reset(self, state: SimState | None = None, *, envs_idx=None, keep_init: bool = False):
if self._is_built:
if state is None:
state = self._init_state
else:
assert isinstance(state, SimState), "state must be a SimState object"
self._init_state = state
# `keep_init=True` restores the state without making it the new
# init, so a later bare `reset()` still rewinds to the true init.
if not keep_init:
self._init_state = state
self._sim.reset(state, envs_idx)
else:
self._init_state = self._get_state()
Expand All @@ -1004,6 +1007,49 @@ def _reset(self, state: SimState | None = None, *, envs_idx=None):
def _reset_grad(self):
self._backward_ready = True

@gs.assert_built
def backward(self, loss: torch.Tensor, *args, **kwargs):
"""Differentiate `loss` and restore the terminal physics state.

Wraps the snapshot/backward/restore dance that differentiable rollouts
otherwise have to perform by hand. `scene._backward()` rewinds physics
state to step 0 as a side-effect of unrolling the adstack, so the safe
pattern is to snapshot the terminal state *before* backward and restore
it *after*:

snapshot = scene.get_state() # terminal state
loss.backward() # rewinds physics to step 0
scene.reset(snapshot) # restore + clear grads + re-arm

This method does exactly that, so callers can just write
`scene.backward(loss)`. Afterwards the scene sits at the terminal physics
state with grads cleared and forward/backward re-armed — ready to continue
the rollout or to be reset to a fresh init.

The registered initial state (`reset()` with no args) is left untouched.

Parameters
----------
loss : torch.Tensor
Scalar loss to differentiate. Extra args/kwargs (e.g. `gradient`,
`retain_graph`) are forwarded to `torch.autograd.backward`.
"""
# Snapshot the terminal state before backward rewinds physics to step 0.
snapshot = self.get_state()
# `scene._backward()` re-enters the torch graph from each step's queried
# states (`_backward_from_qd` -> `state.backward(retain_graph=True)`), so
# the graph must survive the initial autograd pass.
kwargs.setdefault("retain_graph", True)
# Functional `torch.autograd.backward` fills torch + queried-state grads
# WITHOUT triggering `gs.Tensor.backward`'s auto `scene._backward()`, so
# we drive the sim unroll explicitly below.
torch.autograd.backward(loss, *args, **kwargs)
self._backward()
# Restore to the terminal snapshot; `keep_init=True` preserves the real
# initial state so a later bare `reset()` still rewinds to it.
self._reset(snapshot, keep_init=True)
return snapshot

def _get_state(self):
return self._sim.get_state()

Expand Down
Loading
Loading