Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
7c4d5e6
First dirty but working implementation of aba. tested with numpy inte…
Giulero Sep 25, 2025
6e58c83
aba working with random base pose and random joint positions
Giulero Sep 29, 2025
a9e949a
aba passing with random inputs apart from external wrenches
Giulero Oct 1, 2025
28699de
aba passing in mixed. no random wrenches
Giulero Oct 1, 2025
47ac7cb
aba passing also for external wrenches (in numpy backend)
Giulero Oct 1, 2025
a296a2d
Move aba function at the end of the class
Giulero Oct 1, 2025
9c0c347
Clean up numpy aba test
Giulero Oct 1, 2025
8d88ba5
Refactor RBDAlgorithms for CasADi compatibility
Giulero Oct 1, 2025
1f5537e
Implement JAX and PyTorch solve methods to handle batched solves corr…
Giulero Oct 1, 2025
e6e2504
Add matrix operations and dimension expansion methods to SpatialMath
Giulero Oct 1, 2025
643ca98
Add wrapping aba to interfaces
Giulero Oct 1, 2025
d6a6ee4
Add ABA tests for various configurations and batch processing
Giulero Oct 1, 2025
6e318d4
Some simplifications in the aba logic
Giulero Oct 1, 2025
9ec7783
Simplify code structure in aba
Giulero Oct 2, 2025
6443788
Remove debug print
Giulero Oct 2, 2025
8309f1c
Apply black
Giulero Oct 2, 2025
113eef8
add test in idyntree conversion and remove useless call in casadi test
Giulero Oct 2, 2025
bf0c8cd
Refactor array access
Giulero Oct 2, 2025
0611196
Remove unnecessary gravity vector documentation. Clean tests
Giulero Oct 2, 2025
8c699b7
Format with black
Giulero Oct 2, 2025
8531a16
Small refactor (remove line)
Giulero Oct 2, 2025
5da9f79
Update some missing doc in RBDAlgorithms class
Giulero Oct 2, 2025
44b89da
Update tests/test_jax.py
Giulero Oct 2, 2025
ad5bf88
Remove leftover import
Giulero Oct 2, 2025
f9ddd62
Put back jax numpy import
Giulero Oct 2, 2025
b07aa6b
Update src/adam/jax/computations.py
Giulero Oct 7, 2025
a955363
Fix return descriptions in aba
Giulero Oct 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion src/adam/casadi/casadi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,63 @@ def transpose(self, x: CasadiLike, dims: tuple) -> CasadiLike:
# Only 2-D supported; any request means "swap last two"
return CasadiLike(x.array.T)

# --- algebra shortcuts used by algorithms ---
@staticmethod
def expand_dims(x: CasadiLike, axis: int) -> CasadiLike:
"""Expand dimensions of a CasADi array.

Args:
x: Input array (CasadiLike)
axis: Position where new axis is to be inserted

Returns:
CasadiLike: Array with expanded dimensions
"""
# If axis=-1, we're adding a column dimension to make it (n,1)
if axis == -1:
# Reshape to column vector
return CasadiLike(cs.reshape(x.array, (-1, 1)))
else:
# For other axes, just return as is (CasADi is 2D only)
return x

@staticmethod
def inv(x: CasadiLike) -> CasadiLike:
"""Matrix inversion for CasADi.

Args:
x: Matrix to invert (CasadiLike)

Returns:
CasadiLike: Inverse of x
"""
return CasadiLike(cs.inv(x.array))

@staticmethod
def solve(A: CasadiLike, B: CasadiLike) -> CasadiLike:
"""Solve linear system Ax = B for x using CasADi.

Args:
A: Coefficient matrix (CasadiLike)
B: Right-hand side vector or matrix (CasadiLike)

Returns:
CasadiLike: Solution x
"""
return CasadiLike(cs.solve(A.array, B.array))

@staticmethod
def mtimes(A: CasadiLike, B: CasadiLike) -> CasadiLike:
"""Matrix-matrix multiplication for CasADi.

Args:
A: First matrix (CasadiLike)
B: Second matrix (CasadiLike)

Returns:
CasadiLike: Result of A @ B
"""
return CasadiLike(cs.mtimes(A.array, B.array))

@staticmethod
def mxv(m: CasadiLike, v: CasadiLike) -> CasadiLike:
"""Matrix-vector multiplication for CasADi.
Expand Down
77 changes: 77 additions & 0 deletions src/adam/casadi/computations.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,83 @@ def gravity_term(self, base_transform: cs.SX, joint_positions: cs.SX) -> cs.SX:
self.g,
).array

def aba(
self,
base_transform: cs.SX,
joint_positions: cs.SX,
base_velocity: cs.SX,
joint_velocities: cs.SX,
joint_torques: cs.SX,
external_wrenches: dict[str, cs.SX] | None = None,
) -> cs.SX:
"""Featherstone Articulated-Body Algorithm (floating base, O(n)).

Args:
base_transform (cs.SX): The homogenous transform from base to world frame
joint_positions (cs.SX): The joints position
base_velocity (cs.SX): The base velocity
joint_velocities (cs.SX): The joint velocities
joint_torques (cs.SX): The joint torques
external_wrenches (dict[str, cs.SX], optional): External wrenches applied to the robot. Defaults to None.

Returns:
cs.SX: The joint accelerations and the base acceleration
"""
if (
isinstance(base_transform, cs.MX)
and isinstance(joint_positions, cs.MX)
and isinstance(base_velocity, cs.MX)
and isinstance(joint_velocities, cs.MX)
and isinstance(joint_torques, cs.MX)
):
raise ValueError(
"You are using casadi MX. Please use the function KinDynComputations.aba_fun()"
)

return self.rbdalgos.aba(
base_transform,
joint_positions,
base_velocity,
joint_velocities,
joint_torques,
self.g,
external_wrenches,
).array

def aba_fun(self) -> cs.Function:
"""Returns the Articulated Body Algorithm function for forward dynamics

Returns:
qdd (casADi function): The joint accelerations and base acceleration
"""
base_transform = cs.SX.sym("H", 4, 4)
joint_positions = cs.SX.sym("s", self.NDoF)
base_velocity = cs.SX.sym("v_b", 6)
joint_velocities = cs.SX.sym("s_dot", self.NDoF)
joint_torques = cs.SX.sym("tau", self.NDoF)

qdd = self.rbdalgos.aba(
base_transform,
joint_positions,
base_velocity,
joint_velocities,
joint_torques,
self.g,
None, # external_wrenches not supported in symbolic form
)
return cs.Function(
"qdd",
[
base_transform,
joint_positions,
base_velocity,
joint_velocities,
joint_torques,
],
[qdd.array],
self.f_opts,
)

def CoM_position(self, base_transform: cs.SX, joint_positions: cs.SX) -> cs.SX:
"""Returns the CoM position

Expand Down
12 changes: 12 additions & 0 deletions src/adam/core/array_api_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,15 @@ def expand_dims(self, x: ArrayAPILike, axis: int) -> ArrayAPILike:
def transpose(self, x: ArrayAPILike, dims: tuple) -> ArrayAPILike:
xp = self._xp(x.array)
return self.factory.asarray(xp.permute_dims(x.array, dims))

def inv(self, x: ArrayAPILike) -> ArrayAPILike:
xp = self._xp(x.array)
return self.factory.asarray(xp.linalg.inv(x.array))

def mtimes(self, A: ArrayAPILike, B: ArrayAPILike) -> ArrayAPILike:
xp = self._xp(A.array, B.array)
return self.factory.asarray(xp.matmul(A.array, B.array))

def solve(self, A: ArrayAPILike, B: ArrayAPILike) -> ArrayAPILike:
xp = self._xp(A.array, B.array)
return self.factory.asarray(xp.linalg.solve(A.array, B.array))
Loading
Loading