diff --git a/src/adam/casadi/casadi_like.py b/src/adam/casadi/casadi_like.py index b5804b23..8d131f9f 100644 --- a/src/adam/casadi/casadi_like.py +++ b/src/adam/casadi/casadi_like.py @@ -306,7 +306,63 @@ def transpose(self, x: CasadiLike, dims: tuple) -> CasadiLike: # Only 2-D supported; any request means "swap last two" return CasadiLike(x.array.T) - # --- algebra shortcuts used by algorithms --- + @staticmethod + def expand_dims(x: CasadiLike, axis: int) -> CasadiLike: + """Expand dimensions of a CasADi array. + + Args: + x: Input array (CasadiLike) + axis: Position where new axis is to be inserted + + Returns: + CasadiLike: Array with expanded dimensions + """ + # If axis=-1, we're adding a column dimension to make it (n,1) + if axis == -1: + # Reshape to column vector + return CasadiLike(cs.reshape(x.array, (-1, 1))) + else: + # For other axes, just return as is (CasADi is 2D only) + return x + + @staticmethod + def inv(x: CasadiLike) -> CasadiLike: + """Matrix inversion for CasADi. + + Args: + x: Matrix to invert (CasadiLike) + + Returns: + CasadiLike: Inverse of x + """ + return CasadiLike(cs.inv(x.array)) + + @staticmethod + def solve(A: CasadiLike, B: CasadiLike) -> CasadiLike: + """Solve linear system Ax = B for x using CasADi. + + Args: + A: Coefficient matrix (CasadiLike) + B: Right-hand side vector or matrix (CasadiLike) + + Returns: + CasadiLike: Solution x + """ + return CasadiLike(cs.solve(A.array, B.array)) + + @staticmethod + def mtimes(A: CasadiLike, B: CasadiLike) -> CasadiLike: + """Matrix-matrix multiplication for CasADi. + + Args: + A: First matrix (CasadiLike) + B: Second matrix (CasadiLike) + + Returns: + CasadiLike: Result of A @ B + """ + return CasadiLike(cs.mtimes(A.array, B.array)) + @staticmethod def mxv(m: CasadiLike, v: CasadiLike) -> CasadiLike: """Matrix-vector multiplication for CasADi. diff --git a/src/adam/casadi/computations.py b/src/adam/casadi/computations.py index 6f5dfe7b..3fba0daf 100644 --- a/src/adam/casadi/computations.py +++ b/src/adam/casadi/computations.py @@ -467,6 +467,83 @@ def gravity_term(self, base_transform: cs.SX, joint_positions: cs.SX) -> cs.SX: self.g, ).array + def aba( + self, + base_transform: cs.SX, + joint_positions: cs.SX, + base_velocity: cs.SX, + joint_velocities: cs.SX, + joint_torques: cs.SX, + external_wrenches: dict[str, cs.SX] | None = None, + ) -> cs.SX: + """Featherstone Articulated-Body Algorithm (floating base, O(n)). + + Args: + base_transform (cs.SX): The homogenous transform from base to world frame + joint_positions (cs.SX): The joints position + base_velocity (cs.SX): The base velocity + joint_velocities (cs.SX): The joint velocities + joint_torques (cs.SX): The joint torques + external_wrenches (dict[str, cs.SX], optional): External wrenches applied to the robot. Defaults to None. + + Returns: + cs.SX: The base acceleration and the joint accelerations + """ + if ( + isinstance(base_transform, cs.MX) + and isinstance(joint_positions, cs.MX) + and isinstance(base_velocity, cs.MX) + and isinstance(joint_velocities, cs.MX) + and isinstance(joint_torques, cs.MX) + ): + raise ValueError( + "You are using casadi MX. Please use the function KinDynComputations.aba_fun()" + ) + + return self.rbdalgos.aba( + base_transform, + joint_positions, + base_velocity, + joint_velocities, + joint_torques, + self.g, + external_wrenches, + ).array + + def aba_fun(self) -> cs.Function: + """Returns the Articulated Body Algorithm function for forward dynamics + + Returns: + qdd (casADi function): The joint accelerations and base acceleration + """ + base_transform = cs.SX.sym("H", 4, 4) + joint_positions = cs.SX.sym("s", self.NDoF) + base_velocity = cs.SX.sym("v_b", 6) + joint_velocities = cs.SX.sym("s_dot", self.NDoF) + joint_torques = cs.SX.sym("tau", self.NDoF) + + qdd = self.rbdalgos.aba( + base_transform, + joint_positions, + base_velocity, + joint_velocities, + joint_torques, + self.g, + None, # external_wrenches not supported in symbolic form + ) + return cs.Function( + "qdd", + [ + base_transform, + joint_positions, + base_velocity, + joint_velocities, + joint_torques, + ], + [qdd.array], + self.f_opts, + ) + def CoM_position(self, base_transform: cs.SX, joint_positions: cs.SX) -> cs.SX: """Returns the CoM position diff --git a/src/adam/core/array_api_math.py b/src/adam/core/array_api_math.py index 42ebc4cd..e19379ae 100644 --- a/src/adam/core/array_api_math.py +++ b/src/adam/core/array_api_math.py @@ -238,3 +238,15 @@ def expand_dims(self, x: ArrayAPILike, axis: int) -> ArrayAPILike: def transpose(self, x: ArrayAPILike, dims: tuple) -> ArrayAPILike: xp = self._xp(x.array) return self.factory.asarray(xp.permute_dims(x.array, dims)) + + def inv(self, x: ArrayAPILike) -> ArrayAPILike: + xp = self._xp(x.array) + return self.factory.asarray(xp.linalg.inv(x.array)) + + def mtimes(self, A: ArrayAPILike, B: ArrayAPILike) -> ArrayAPILike: + xp = self._xp(A.array, B.array) + return self.factory.asarray(xp.matmul(A.array, B.array)) + + def solve(self, A: ArrayAPILike, B: ArrayAPILike) -> ArrayAPILike: + xp = self._xp(A.array, B.array) + return self.factory.asarray(xp.linalg.solve(A.array, B.array)) diff --git a/src/adam/core/rbd_algorithms.py b/src/adam/core/rbd_algorithms.py index 4750cd39..2b873df2 100644 --- a/src/adam/core/rbd_algorithms.py +++ b/src/adam/core/rbd_algorithms.py @@ -37,9 +37,12 @@ def crba(self, base_transform: npt.ArrayLike, joint_positions: npt.ArrayLike): """ Batched Composite Rigid Body Algorithm (CRBA) + Orin's Centroidal Momentum Matrix. - - Mirrors the reference implementation’s control flow for readability. - - No array/tensor item-assignments (blocks collected then concatenated). - - Supports batched inputs via broadcasting. + Args: + base_transform (npt.ArrayLike): The homogenous transform from base to world frame + joint_positions (npt.ArrayLike): The joints position + + Returns: + M, Jcm (npt.ArrayLike, npt.ArrayLike): The mass matrix and the centroidal momentum matrix """ base_transform, joint_positions = self._convert_to_arraylike( base_transform, joint_positions @@ -180,7 +183,7 @@ def block_index(node): row_tensors.append(self.math.concatenate(row_blocks, axis=-1)) M = self.math.concatenate(row_tensors, axis=-2) # (..., 6+n, 6+n) - # Orin’s O_X_G (centroidal transform) + # Orin's O_X_G (centroidal transform) A = T(M[..., :3, 3:6]) / M[..., 0, 0][..., None, None] # (...,3,3) I3 = self.math.factory.eye(batch + (3,)) Z3 = self.math.factory.zeros(batch + (3, 3)) @@ -246,12 +249,13 @@ def block_index(node): def forward_kinematics( self, frame, base_transform: npt.ArrayLike, joint_positions: npt.ArrayLike ) -> npt.ArrayLike: - """Computes the forward kinematics relative to the specified frame + """Computes the forward kinematics relative to the specified `frame`. Args: frame (str): The frame to which the fk will be computed base_transform (npt.ArrayLike): The homogenous transform from base to world frame joint_positions (npt.ArrayLike): The joints position + Returns: I_H_L (npt.ArrayLike): The fk represented as Homogenous transformation matrix """ @@ -273,7 +277,7 @@ def forward_kinematics( def joints_jacobian( self, frame: str, joint_positions: npt.ArrayLike ) -> npt.ArrayLike: - """Returns the Jacobian relative to the specified frame + """Returns the Jacobian relative to the specified `frame`. Args: frame (str): The frame to which the jacobian will be computed @@ -315,6 +319,17 @@ def joints_jacobian( def jacobian( self, frame: str, base_transform: npt.ArrayLike, joint_positions: npt.ArrayLike ) -> npt.ArrayLike: + """Returns the Jacobian for `frame`. + + Args: + frame (str): The frame to which the jacobian will be computed + base_transform (npt.ArrayLike): The homogenous transform from base to world frame + joint_positions (npt.ArrayLike): The joints position + + Returns: + npt.ArrayLike: The Jacobian for the specified frame + """ + base_transform, joint_positions = self._convert_to_arraylike( base_transform, joint_positions ) @@ -330,7 +345,7 @@ def jacobian( == Representations.BODY_FIXED_REPRESENTATION ): return J_tot - elif self.frame_velocity_representation == Representations.MIXED_REPRESENTATION: + if self.frame_velocity_representation == Representations.MIXED_REPRESENTATION: w_H_L = w_H_B @ B_H_L LI_X_L = self.math.adjoint_mixed(w_H_L) @@ -358,13 +373,14 @@ def jacobian( def relative_jacobian( self, frame: str, joint_positions: npt.ArrayLike ) -> npt.ArrayLike: - """Returns the Jacobian between the root link and a specified frame + """Returns the Jacobian between the root link and a specified `frame`. + Args: frame (str): The tip of the chain joint_positions (npt.ArrayLike): The joints position Returns: - J (npt.ArrayLike): The 6 x NDoF Jacobian between the root and the frame + J (npt.ArrayLike): The 6 x NDoF Jacobian between the root and the `frame` """ joint_positions = self._convert_to_arraylike(joint_positions) if ( @@ -390,7 +406,19 @@ def jacobian_dot( base_velocity: npt.ArrayLike, joint_velocities: npt.ArrayLike, ) -> npt.ArrayLike: - """Returns the Jacobian time derivative for `frame`.""" + """Returns the Jacobian time derivative for `frame`. + + Args: + frame (str): The frame to which the jacobian will be computed + base_transform (npt.ArrayLike): The homogenous transform from base to world frame + joint_positions (npt.ArrayLike): The joints position + base_velocity (npt.ArrayLike): The spatial velocity of the base + joint_velocities (npt.ArrayLike): The joints velocities + + Returns: + J_dot (npt.ArrayLike): The Jacobian derivative relative to the frame + + """ base_transform, joint_positions, base_velocity, joint_velocities = ( self._convert_to_arraylike( @@ -444,7 +472,6 @@ def jacobian_dot( H_j = joint.homogeneous(q=q) B_H_j = B_H_j @ H_j L_H_j = L_H_B @ B_H_j - S = joint.motion_subspace() J_j = self.math.adjoint(L_H_j) @ S @@ -574,8 +601,17 @@ def rnea( g: npt.ArrayLike, ) -> npt.ArrayLike: """ - Batched Recursive Newton–Euler (reduced: no joint/base accelerations, no external forces). - Returns tau with shape (..., 6+n, 1). No item assignment; no squeeze helper. + Batched Recursive Newton-Euler (reduced: no joint/base accelerations, no external forces). + + Args: + base_transform (npt.ArrayLike): The homogenous transform from base to world frame + joint_positions (npt.ArrayLike): The joints position + base_velocity (npt.ArrayLike): The base spatial velocity + joint_velocities (npt.ArrayLike): The joints velocities + g (npt.ArrayLike): The gravity vector + + Returns: + tau (npt.ArrayLike): The vector of generalized forces """ base_transform, joint_positions, base_velocity, joint_velocities, g = ( self._convert_to_arraylike( @@ -710,10 +746,272 @@ def rnea( return self.math.concatenate([tau_base, tau_joints_vec], axis=-1) - def aba(self): - raise NotImplementedError + def aba( + self, + base_transform: npt.ArrayLike, + joint_positions: npt.ArrayLike, + base_velocity: npt.ArrayLike, + joint_velocities: npt.ArrayLike, + joint_torques: npt.ArrayLike, + g: npt.ArrayLike, + external_wrenches: dict[str, npt.ArrayLike] | None = None, + ) -> npt.ArrayLike: + """Featherstone Articulated Body Algorithm for floating-base forward dynamics. + + Args: + base_transform (npt.ArrayLike): The homogenous transform from base to world frame + joint_positions (npt.ArrayLike): The joints position + base_velocity (npt.ArrayLike): The spatial velocity of the base + joint_velocities (npt.ArrayLike): The joints velocities + joint_torques (npt.ArrayLike): The joints torques/forces + g (npt.ArrayLike | None, optional): The gravity vector + external_wrenches (dict[str, npt.ArrayLike] | None, optional): A dictionary of external wrenches applied to specific links. Keys are link names, and values are 6D wrench vectors. Defaults to None. + + Returns: + accelerations (npt.ArrayLike): The spatial acceleration of the base and joints accelerations + """ + model = self.model + math = self.math + + Nnodes = model.N + n = model.NDoF + root_name = self.root_link + + ( + base_transform, + joint_positions, + base_velocity, + joint_velocities, + joint_torques, + g, + ) = self._convert_to_arraylike( + base_transform, + joint_positions, + base_velocity, + joint_velocities, + joint_torques, + g, + ) + + T = lambda X: math.swapaxes(X, -2, -1) + + batch = base_transform.shape[:-2] if base_transform.ndim > 2 else () + + if external_wrenches is not None: + generalized_ext = self.math.factory.zeros(batch + (6 + n,)) + for frame, wrench in external_wrenches.items(): + wrench_arr = self._convert_to_arraylike(wrench) + J = self.jacobian(frame, base_transform, joint_positions) + generalized_ext = generalized_ext + self.math.mxv(T(J), wrench_arr) + + base_ext = generalized_ext[..., :6] + joint_ext = ( + generalized_ext[..., 6:] + if n > 0 + else self.math.zeros_like(joint_torques) + ) + else: + base_ext = self.math.factory.zeros(batch + (6,)) + joint_ext = self.math.zeros_like(joint_torques) + + joint_torques_eff = joint_torques + joint_ext + + if self.frame_velocity_representation == Representations.MIXED_REPRESENTATION: + B_X_BI = math.adjoint_mixed_inverse(base_transform) + elif ( + self.frame_velocity_representation + == Representations.BODY_FIXED_REPRESENTATION + ): + B_X_BI = math.factory.eye(batch + (6,)) if batch else math.factory.eye(6) + else: + raise NotImplementedError( + "Only BODY_FIXED_REPRESENTATION and MIXED_REPRESENTATION are implemented" + ) + + base_velocity_body = math.mxv(B_X_BI, base_velocity) + base_ext_body = math.mxv(B_X_BI, base_ext) + + a0_input = math.mxv(math.adjoint_mixed_inverse(base_transform), g) + + def zeros6(): + return math.factory.zeros(batch + (6,)) if batch else math.factory.zeros(6) + + def eye6(): + return math.factory.eye(batch + (6,)) if batch else math.factory.eye(6) + + def expand_to_match(vec, reference): + expanded = math.expand_dims(vec, axis=-1) + expanded_ndim = expanded.ndim + reference_ndim = reference.ndim + if expanded_ndim != reference_ndim: + expanded = math.expand_dims(expanded, axis=-1) + return expanded + + Xup = [None] * Nnodes + Scols: list[ArrayLike | None] = [None] * Nnodes + v = [None] * Nnodes + c = [None] * Nnodes + IA = [None] * Nnodes + pA = [None] * Nnodes + g_acc = [None] * Nnodes + + for idx, node in enumerate(model.tree): + link_i, joint_i, link_pi = node.get_elements() + + inertia = link_i.spatial_inertia() + IA[idx] = math.tile(inertia, batch + (1, 1)) if batch else inertia + + if link_i.name == root_name: + Xup[idx] = eye6() + v[idx] = base_velocity_body + c[idx] = zeros6() + g_acc[idx] = a0_input + else: + pi = model.tree.get_idx_from_name(link_pi.name) + + if joint_i is not None: + q_i = ( + joint_positions[..., joint_i.idx] + if joint_i.idx is not None + else math.zeros_like(joint_positions[..., 0]) + ) + Xup[idx] = joint_i.spatial_transform(q=q_i) + else: + Xup[idx] = eye6() + + g_acc[idx] = math.mxv(Xup[idx], g_acc[pi]) + + if (joint_i is not None) and (joint_i.idx is not None): + Si = joint_i.motion_subspace() + Scols[idx] = Si + qd_i = joint_velocities[..., joint_i.idx] + vJ = math.vxs(Si, qd_i) + else: + Scols[idx] = None + vJ = zeros6() + + v[idx] = math.mxv(Xup[idx], v[pi]) + vJ + c[idx] = math.mxv(math.spatial_skew(v[idx]), vJ) + + pA[idx] = math.mxv( + math.spatial_skew_star(v[idx]), math.mxv(IA[idx], v[idx]) + ) + + d_list: list[ArrayLike | None] = [None] * Nnodes + u_list: list[ArrayLike | None] = [None] * Nnodes + U_list: list[ArrayLike | None] = [None] * Nnodes + + for idx, node in reversed(list(enumerate(model.tree))): + link_i, joint_i, link_pi = node.get_elements() + + if link_i.name == root_name: + continue + + pi = model.tree.get_idx_from_name(link_pi.name) + + Xpt = T(Xup[idx]) + + if Scols[idx] is not None: + S_i = Scols[idx] + U_i = math.mtimes(IA[idx], S_i) + d_i = math.mtimes(T(S_i), U_i) + tau_i = joint_torques_eff[..., joint_i.idx] + tau_vec = tau_i + Si_T_pA = math.mxv(T(S_i), pA[idx])[..., 0] + u_i = tau_vec - Si_T_pA + + d_list[idx] = d_i + u_list[idx] = u_i + U_list[idx] = U_i + + inv_d = math.inv(d_i) + Ia = IA[idx] - math.mtimes(U_i, math.mtimes(inv_d, T(U_i))) + u_i_expanded = expand_to_match(u_i, inv_d) + gain = math.mtimes(inv_d, u_i_expanded) + # Extract column vector + gain_vec = gain[..., 0] + pa = pA[idx] + math.mxv(Ia, c[idx]) + math.mxv(U_i, gain_vec) + else: + Ia = IA[idx] + pa = pA[idx] + math.mxv(Ia, c[idx]) + + IA[pi] = IA[pi] + math.mtimes(math.mtimes(Xpt, Ia), Xup[idx]) + pA[pi] = pA[pi] + math.mxv(Xpt, pa) + + root_idx = model.tree.get_idx_from_name(root_name) + rhs_root = base_ext_body - pA[root_idx] + math.mxv(IA[root_idx], a0_input) + a_base = math.solve(IA[root_idx], rhs_root) + + a = [None] * Nnodes + a[root_idx] = a_base + + qdd_entries: list[ArrayLike | None] = [None] * n if n > 0 else [] + + for idx, node in enumerate(model.tree): + link_i, joint_i, link_pi = node.get_elements() + + if link_i.name == root_name: + continue + + pi = model.tree.get_idx_from_name(link_pi.name) + a_pre = math.mxv(Xup[idx], a[pi]) + c[idx] + free_acc = g_acc[idx] + rel_acc = a_pre - free_acc if free_acc is not None else a_pre + + if ( + Scols[idx] is not None + and (joint_i is not None) + and (joint_i.idx is not None) + ): + S_i = Scols[idx] + U_i = U_list[idx] + U_T_rel_acc = math.mxv(T(U_i), rel_acc)[..., 0] + num = u_list[idx] - U_T_rel_acc + inv_d = math.inv(d_list[idx]) + num_expanded = expand_to_match(num, inv_d) + gain_qdd = math.mtimes(inv_d, num_expanded) + qdd_col = gain_qdd[..., 0] + if joint_i.idx < n: + qdd_entries[joint_i.idx] = qdd_col + a_correction_vec = math.mxv(S_i, qdd_col) + a[idx] = a_pre + a_correction_vec + else: + a[idx] = a_pre + + if n > 0: + qdd_cols = [] + for entry in qdd_entries: + if entry is None: + qdd_cols.append( + math.factory.zeros(batch + (1,)) + if batch + else math.factory.zeros((1,)) + ) + else: + qdd_cols.append(entry) + joint_qdd = math.concatenate(qdd_cols, axis=-1) + else: + joint_qdd = ( + math.factory.zeros(batch + (0,)) if batch else math.factory.zeros((0,)) + ) + + if self.frame_velocity_representation == Representations.MIXED_REPRESENTATION: + Xm = math.adjoint_mixed(base_transform) + base_vel_mixed = math.mxv(Xm, base_velocity_body) + Xm_dot = math.adjoint_mixed_derivative(base_transform, base_vel_mixed) + base_acc = math.mxv(Xm, a_base) + math.mxv(Xm_dot, base_velocity_body) + else: + base_acc = a_base + + return math.concatenate([base_acc, joint_qdd], axis=-1) def _convert_to_arraylike(self, *args): + """Convert inputs to ArrayLike if they are not already. + Args: + *args: Input arguments to be converted. + Returns: + Converted arguments as ArrayLike. + """ if not args: raise ValueError("At least one argument is required") @@ -723,5 +1021,4 @@ def _convert_to_arraylike(self, *args): converted.append(arg) else: converted.append(self.math.asarray(arg)) - return converted[0] if len(converted) == 1 else converted diff --git a/src/adam/core/spatial_math.py b/src/adam/core/spatial_math.py index d00d6f84..1d1fdde1 100644 --- a/src/adam/core/spatial_math.py +++ b/src/adam/core/spatial_math.py @@ -226,6 +226,17 @@ def transpose(self, x: npt.ArrayLike, dims: tuple) -> npt.ArrayLike: """ pass + @abc.abstractmethod + def inv(self, x: npt.ArrayLike) -> npt.ArrayLike: + """ + Args: + x (npt.ArrayLike): input array + + Returns: + npt.ArrayLike: inverse of the array + """ + pass + @abc.abstractmethod def mtimes(self, x: npt.ArrayLike, y: npt.ArrayLike) -> npt.ArrayLike: pass @@ -665,7 +676,7 @@ def mxv(self, m: npt.ArrayLike, v: npt.ArrayLike) -> npt.ArrayLike: res = m @ v[..., None] return res[..., 0] # Remove the extra dimension - def vxs(self, v: npt.ArrayLike, c: npt.ArrayLike) -> npt.ArrayLike: + def vxs(self, v: npt.ArrayLike, s: npt.ArrayLike) -> npt.ArrayLike: """ Args: v (npt.ArrayLike): Vector @@ -675,8 +686,8 @@ def vxs(self, v: npt.ArrayLike, c: npt.ArrayLike) -> npt.ArrayLike: """ if v.shape[-1] == 1: v = v[..., 0] - c = c[..., None] # Add extra dimension - return v * c + s = s[..., None] # Add extra dimension + return v * s def adjoint_inverse(self, H: npt.ArrayLike) -> npt.ArrayLike: """ diff --git a/src/adam/jax/computations.py b/src/adam/jax/computations.py index aadb3e7f..fa12a8b8 100644 --- a/src/adam/jax/computations.py +++ b/src/adam/jax/computations.py @@ -252,6 +252,39 @@ def CoM_jacobian( base_transform, joint_positions ).array.squeeze() + def aba( + self, + base_transform: jnp.array, + joint_positions: jnp.array, + base_velocity: jnp.array, + joint_velocities: jnp.array, + joint_torques: jnp.array, + external_wrenches: dict[str, jnp.array] | None = None, + ) -> jnp.array: + """Featherstone Articulated-Body Algorithm (floating base, O(n)). + + Args: + base_transform (jnp.array): The homogenous transform from base to world frame + joint_positions (jnp.array): The joints position + base_velocity (jnp.array): The base velocity + joint_velocities (jnp.array): The joint velocities + joint_torques (jnp.array): The joint torques + external_wrenches (dict[str, jnp.array], optional): External wrenches applied to the robot. Defaults to None. + + Returns: + jnp.array: The base acceleration and the joint accelerations + """ + + return self.rbdalgos.aba( + base_transform, + joint_positions, + base_velocity, + joint_velocities, + joint_torques, + self.g, + external_wrenches, + ).array.squeeze() + def get_total_mass(self) -> float: """Returns the total mass of the robot diff --git a/src/adam/jax/jax_like.py b/src/adam/jax/jax_like.py index eefcc548..f59871cc 100644 --- a/src/adam/jax/jax_like.py +++ b/src/adam/jax/jax_like.py @@ -35,3 +35,20 @@ def __init__(self, spec: ArraySpec | None = None): class SpatialMath(ArrayAPISpatialMath): def __init__(self, spec: ArraySpec | None = None): super().__init__(JaxLikeFactory(spec=spec)) + + def solve(self, A: ArrayAPILike, B: ArrayAPILike) -> ArrayAPILike: + """Override solve to handle JAX's batched solve API correctly + + JAX requires b to have shape (..., N, M) for batched solves, not just (..., N). + This follows JAX's recommendation: use solve(a, b[..., None]).squeeze(-1) for 1D solves. + """ + a_arr = A.array + b_arr = B.array + + # If b is 1D per batch (shape like (batch, N)), add extra dimension for JAX + if b_arr.ndim > 1 and a_arr.ndim == b_arr.ndim + 1: + result = jnp.linalg.solve(a_arr, b_arr[..., None]).squeeze(-1) + else: + result = jnp.linalg.solve(a_arr, b_arr) + + return self.factory.asarray(result) diff --git a/src/adam/numpy/computations.py b/src/adam/numpy/computations.py index 0c418bbe..74c6942e 100644 --- a/src/adam/numpy/computations.py +++ b/src/adam/numpy/computations.py @@ -267,3 +267,36 @@ def get_total_mass(self) -> float: mass: The total mass """ return self.rbdalgos.get_total_mass() + + def aba( + self, + base_transform: np.ndarray, + joint_positions: np.ndarray, + base_velocity: np.ndarray, + joint_velocities: np.ndarray, + joint_torques: np.ndarray, + external_wrenches: dict[str, np.ndarray] | None = None, + ) -> np.ndarray: + """Featherstone Articulated-Body Algorithm (floating base, O(n)). + + Args: + base_transform (np.ndarray): The homogenous transform from base to world frame + joint_positions (np.ndarray): The joints position + base_velocity (np.ndarray): The base velocity + joint_velocities (np.ndarray): The joint velocities + joint_torques (np.ndarray): The joint torques + external_wrenches (dict[str, np.ndarray], optional): External wrenches applied to the robot. Defaults to None. + + Returns: + np.ndarray: The base acceleration and the joint accelerations + """ + + return self.rbdalgos.aba( + base_transform, + joint_positions, + base_velocity, + joint_velocities, + joint_torques, + self.g, + external_wrenches, + ).array.squeeze() diff --git a/src/adam/pytorch/computation_batch.py b/src/adam/pytorch/computation_batch.py index c8a385c0..a902c4ee 100644 --- a/src/adam/pytorch/computation_batch.py +++ b/src/adam/pytorch/computation_batch.py @@ -279,3 +279,36 @@ def get_total_mass(self) -> float: mass: The total mass """ return self.rbdalgos.get_total_mass() + + def aba( + self, + base_transform: torch.Tensor, + joint_positions: torch.Tensor, + base_velocity: torch.Tensor, + joint_velocities: torch.Tensor, + joint_torques: torch.Tensor, + external_wrenches: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + """Featherstone Articulated-Body Algorithm (floating base, O(n)). + + Args: + base_transform (torch.Tensor): The batch of homogenous transforms from base to world frame + joint_positions (torch.Tensor): The batch of joints position + base_velocity (torch.Tensor): The batch of base velocity + joint_velocities (torch.Tensor): The batch of joint velocities + joint_torques (torch.Tensor): The batch of joint torques + external_wrenches (dict[str, torch.Tensor], optional): External wrenches applied to the robot. Defaults to None. + + Returns: + torch.Tensor: The batch of base acceleration and the joint accelerations + """ + + return self.rbdalgos.aba( + base_transform, + joint_positions, + base_velocity, + joint_velocities, + joint_torques, + self.g, + external_wrenches, + ).array.squeeze() diff --git a/src/adam/pytorch/computations.py b/src/adam/pytorch/computations.py index c076bd1a..5d4ca488 100644 --- a/src/adam/pytorch/computations.py +++ b/src/adam/pytorch/computations.py @@ -271,6 +271,39 @@ def gravity_term( self.g, ).array.squeeze() + def aba( + self, + base_transform: torch.Tensor, + joint_positions: torch.Tensor, + base_velocity: torch.Tensor, + joint_velocities: torch.Tensor, + joint_torques: torch.Tensor, + external_wrenches: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + """Featherstone Articulated-Body Algorithm (floating base, O(n)). + + Args: + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position + base_velocity (torch.Tensor): The base velocity + joint_velocities (torch.Tensor): The joint velocities + joint_torques (torch.Tensor): The joint torques + external_wrenches (dict[str, torch.Tensor], optional): External wrenches applied to the robot. Defaults to None. + + Returns: + torch.Tensor: The base acceleration and the joint accelerations + """ + + return self.rbdalgos.aba( + base_transform, + joint_positions, + base_velocity, + joint_velocities, + joint_torques, + self.g, + external_wrenches, + ).array.squeeze() + def get_total_mass(self) -> float: """Returns the total mass of the robot diff --git a/src/adam/pytorch/torch_like.py b/src/adam/pytorch/torch_like.py index 13254eb5..da6d619b 100644 --- a/src/adam/pytorch/torch_like.py +++ b/src/adam/pytorch/torch_like.py @@ -38,3 +38,7 @@ def __init__(self, spec: ArraySpec | None = None): class SpatialMath(ArrayAPISpatialMath): def __init__(self, spec: ArraySpec | None = None): super().__init__(TorchLikeFactory(spec=spec)) + + def solve(self, A: ArrayAPILike, B: ArrayAPILike) -> ArrayAPILike: + """Override solve to use torch.linalg.solve directly to avoid array_api_compat bug""" + return self.factory.asarray(torch.linalg.solve(A.array, B.array)) diff --git a/tests/test_casadi.py b/tests/test_casadi.py index 109bc479..10a10ae2 100644 --- a/tests/test_casadi.py +++ b/tests/test_casadi.py @@ -173,3 +173,52 @@ def test_gravity_term(setup_test): assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4) adam_gravity = cs.DM(adam_kin_dyn.gravity_term_fun()(state.H, state.joints_pos)) assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4) + + +def test_aba(setup_test): + adam_kin_dyn, robot_cfg, state = setup_test + torques = np.random.randn(len(state.joints_pos)) * 10 + H = state.H + joints_pos = state.joints_pos + base_vel = state.base_vel + joints_vel = state.joints_vel + + wrenches = { + "l_sole": np.random.randn(6) * 10, + "torso_1": np.random.randn(6) * 10, + "head": np.random.randn(6) * 10, + } + + # Test direct method (SX) - with external wrenches + adam_qdd = cs.DM( + adam_kin_dyn.aba( + base_transform=H, + joint_positions=joints_pos, + base_velocity=base_vel, + joint_velocities=joints_vel, + joint_torques=torques, + external_wrenches=wrenches, + ) + ) + # Verify using the equations of motion: M @ qdd + h = tau + J^T @ wrench + M = cs.DM(adam_kin_dyn.mass_matrix(H, joints_pos)) + h = cs.DM(adam_kin_dyn.bias_force(H, joints_pos, base_vel, joints_vel)) + + generalized_external_wrenches = np.zeros(6 + len(joints_pos)) + for frame, wrench in wrenches.items(): + J = cs.DM(adam_kin_dyn.jacobian(frame, H, joints_pos)) + generalized_external_wrenches += (J.T @ wrench).full().flatten() + + base_wrench = np.zeros(6) + full_tau = np.concatenate([base_wrench, torques]) + residual = ( + M @ adam_qdd + h - full_tau + ).full().flatten() - generalized_external_wrenches + + assert residual == pytest.approx(0.0, abs=1e-4) + # Test function method + adam_qdd_fun = cs.DM( + adam_kin_dyn.aba_fun()(H, joints_pos, base_vel, joints_vel, torques) + ) + residual_fun = (M @ adam_qdd_fun + h - full_tau).full().flatten() + assert residual_fun == pytest.approx(0.0, abs=1e-4) diff --git a/tests/test_idyntree_conversion.py b/tests/test_idyntree_conversion.py index 129b185f..82520b54 100644 --- a/tests/test_idyntree_conversion.py +++ b/tests/test_idyntree_conversion.py @@ -175,3 +175,52 @@ def test_gravity_term(setup_test): assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4) adam_gravity = cs.DM(adam_kin_dyn.gravity_term_fun()(state.H, state.joints_pos)) assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4) + + +def test_aba(setup_test): + adam_kin_dyn, robot_cfg, state = setup_test + torques = np.random.randn(len(state.joints_pos)) * 10 + H = state.H + joints_pos = state.joints_pos + base_vel = state.base_vel + joints_vel = state.joints_vel + + wrenches = { + "l_sole": np.random.randn(6) * 10, + "torso_1": np.random.randn(6) * 10, + "head": np.random.randn(6) * 10, + } + + # Test direct method (SX) - with external wrenches + adam_qdd = cs.DM( + adam_kin_dyn.aba( + base_transform=H, + joint_positions=joints_pos, + base_velocity=base_vel, + joint_velocities=joints_vel, + joint_torques=torques, + external_wrenches=wrenches, + ) + ) + # Verify using the equations of motion: M @ qdd + h = tau + J^T @ wrench + M = cs.DM(adam_kin_dyn.mass_matrix(H, joints_pos)) + h = cs.DM(adam_kin_dyn.bias_force(H, joints_pos, base_vel, joints_vel)) + + generalized_external_wrenches = np.zeros(6 + len(joints_pos)) + for frame, wrench in wrenches.items(): + J = cs.DM(adam_kin_dyn.jacobian(frame, H, joints_pos)) + generalized_external_wrenches += (J.T @ wrench).full().flatten() + + base_wrench = np.zeros(6) + full_tau = np.concatenate([base_wrench, torques]) + residual = ( + M @ adam_qdd + h - full_tau + ).full().flatten() - generalized_external_wrenches + + assert residual == pytest.approx(0.0, abs=1e-4) + # Test function method + adam_qdd_fun = cs.DM( + adam_kin_dyn.aba_fun()(H, joints_pos, base_vel, joints_vel, torques) + ) + residual_fun = (M @ adam_qdd_fun + h - full_tau).full().flatten() + assert residual_fun == pytest.approx(0.0, abs=1e-4) diff --git a/tests/test_jax.py b/tests/test_jax.py index 7a35681c..652f7ded 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -1,8 +1,8 @@ import numpy as np +import jax.numpy as jnp import pytest from conftest import RobotCfg, State from jax import config - from adam.jax import KinDynComputations config.update("jax_enable_x64", True) @@ -71,7 +71,7 @@ def test_jacobian_dot(setup_test): idyn_jacobian_dot_nu = robot_cfg.idyn_function_values.jacobian_dot_nu adam_jacobian_dot_nu = adam_kin_dyn.jacobian_dot( "l_sole", state.H, state.joints_pos, state.base_vel, state.joints_vel - ) @ np.concatenate((state.base_vel, state.joints_vel)) + ) @ jnp.concatenate((state.base_vel, state.joints_vel)) assert idyn_jacobian_dot_nu - adam_jacobian_dot_nu == pytest.approx(0.0, abs=1e-5) @@ -119,3 +119,41 @@ def test_gravity_term(setup_test): idyn_gravity = robot_cfg.idyn_function_values.gravity_term adam_gravity = adam_kin_dyn.gravity_term(state.H, state.joints_pos) assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4) + + +def test_aba(setup_test): + adam_kin_dyn, robot_cfg, state = setup_test + torques = np.random.randn(len(state.joints_pos)) * 10 + H = state.H + joints_pos = state.joints_pos + base_vel = state.base_vel + joints_vel = state.joints_vel + + wrenches = { + "l_sole": np.random.randn(6) * 10, + "torso_1": np.random.randn(6) * 10, + "head": np.random.randn(6) * 10, + } + + adam_qdd = adam_kin_dyn.aba( + base_transform=H, + joint_positions=joints_pos, + base_velocity=base_vel, + joint_velocities=joints_vel, + joint_torques=torques, + external_wrenches=wrenches, + ) + + M = adam_kin_dyn.mass_matrix(H, joints_pos) + h = adam_kin_dyn.bias_force(H, joints_pos, base_vel, joints_vel) + + generalized_external_wrenches = jnp.zeros(6 + len(joints_pos)) + for frame, wrench in wrenches.items(): + J = adam_kin_dyn.jacobian(frame, H, joints_pos) + generalized_external_wrenches += J.T @ wrench + + base_wrench = np.zeros(6) + full_tau = jnp.concatenate([base_wrench, torques]) + residual = M @ adam_qdd + h - full_tau - generalized_external_wrenches + + assert residual == pytest.approx(0.0, abs=1e-4) diff --git a/tests/test_jax_batch.py b/tests/test_jax_batch.py index 72f1e47d..144c3c9a 100644 --- a/tests/test_jax_batch.py +++ b/tests/test_jax_batch.py @@ -616,3 +616,83 @@ def gravity_sum(H, joints_pos): # Verify batch variation assert not jnp.allclose(adam_gravity[0], adam_gravity[1], atol=1e-6) + + +def test_aba(setup_test): + """Test Articulated Body Algorithm with batched inputs""" + adam_kin_dyn, robot_cfg, state, batch_size = setup_test + n_joints = robot_cfg.n_dof + + # Create random torques for the batch + torques = jnp.array(np.random.randn(batch_size, n_joints) * 10) + + # Create random wrenches for multiple frames + wrenches = { + "l_sole": jnp.array(np.random.randn(batch_size, 6) * 10), + "torso_1": jnp.array(np.random.randn(batch_size, 6) * 10), + "head": jnp.array(np.random.randn(batch_size, 6) * 10), + } + + # Compute ABA + adam_qdd = adam_kin_dyn.aba( + base_transform=state.H, + joint_positions=state.joints_pos, + base_velocity=state.base_vel, + joint_velocities=state.joints_vel, + joint_torques=torques, + external_wrenches=wrenches, + ) + + # Check output shape + assert adam_qdd.shape == (batch_size, 6 + n_joints) + + # Test gradient computation + def aba_sum(H, joints_pos, base_vel, joints_vel, torques): + return adam_kin_dyn.aba( + H, joints_pos, base_vel, joints_vel, torques, wrenches + ).sum() + + grad_fn = grad(aba_sum, argnums=(0, 1, 2, 3, 4)) + try: + grad_results = grad_fn( + state.H, state.joints_pos, state.base_vel, state.joints_vel, torques + ) + assert all(g is not None for g in grad_results) + except Exception as e: + raise ValueError(f"Gradient computation failed: {e}") + + # Verify using the equations of motion: M @ qdd + h = tau + J^T @ wrench + M = adam_kin_dyn.mass_matrix(state.H, state.joints_pos) + h = adam_kin_dyn.bias_force( + state.H, state.joints_pos, state.base_vel, state.joints_vel + ) + + # Compute generalized external wrenches + generalized_external_wrenches = jnp.zeros((batch_size, 6 + n_joints)) + for frame, wrench in wrenches.items(): + J = adam_kin_dyn.jacobian(frame, state.H, state.joints_pos) + # J shape: (batch_size, 6, 6+n_joints), wrench shape: (batch_size, 6) + # J^T @ wrench for each batch element + generalized_external_wrenches += jnp.einsum( + "bij,bj->bi", J.transpose((0, 2, 1)), wrench + ) + + # Create full generalized forces (base wrench is zero + joint torques) + base_wrench = jnp.zeros((batch_size, 6)) + full_tau = jnp.concatenate([base_wrench, torques], axis=1) + + # Compute residual: M @ qdd + h - tau - J^T @ wrench = 0 + residual = ( + jnp.einsum("bij,bj->bi", M, adam_qdd) + + h + - full_tau + - generalized_external_wrenches + ) + + # Assert residual is close to zero + assert jnp.allclose( + residual, jnp.zeros_like(residual), atol=1e-4 + ), f"Residual max: {jnp.abs(residual).max()}" + + # Verify batch variation + assert not jnp.allclose(adam_qdd[0], adam_qdd[1], atol=1e-6) diff --git a/tests/test_numpy.py b/tests/test_numpy.py index 2b2aa169..b1ca3388 100644 --- a/tests/test_numpy.py +++ b/tests/test_numpy.py @@ -116,3 +116,41 @@ def test_gravity_term(setup_test): idyn_gravity = robot_cfg.idyn_function_values.gravity_term adam_gravity = adam_kin_dyn.gravity_term(state.H, state.joints_pos) assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4) + + +def test_aba(setup_test): + adam_kin_dyn, robot_cfg, state = setup_test + torques = np.random.randn(len(state.joints_pos)) * 10 + H = state.H + joints_pos = state.joints_pos + base_vel = state.base_vel + joints_vel = state.joints_vel + + wrenches = { + "l_sole": np.random.randn(6) * 10, + "torso_1": np.random.randn(6) * 10, + "head": np.random.randn(6) * 10, + } + + adam_qdd = adam_kin_dyn.aba( + base_transform=H, + joint_positions=joints_pos, + base_velocity=base_vel, + joint_velocities=joints_vel, + joint_torques=torques, + external_wrenches=wrenches, + ) + + M = adam_kin_dyn.mass_matrix(H, joints_pos) + h = adam_kin_dyn.bias_force(H, joints_pos, base_vel, joints_vel) + + generalized_external_wrenches = np.zeros(6 + len(joints_pos)) + for frame, wrench in wrenches.items(): + J = adam_kin_dyn.jacobian(frame, H, joints_pos) + generalized_external_wrenches += J.T @ wrench + + base_wrench = np.zeros(6) + full_tau = np.concatenate([base_wrench, torques]) + residual = M @ adam_qdd + h - full_tau - generalized_external_wrenches + + assert residual == pytest.approx(0.0, abs=1e-4) diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 5cfe1501..629beca8 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -130,3 +130,59 @@ def test_gravity_term(setup_test): idyn_gravity = robot_cfg.idyn_function_values.gravity_term adam_gravity = adam_kin_dyn.gravity_term(state.H, state.joints_pos) assert idyn_gravity - to_numpy(adam_gravity) == pytest.approx(0.0, abs=1e-4) + + +def test_aba(setup_test): + adam_kin_dyn, robot_cfg, state = setup_test + torques = ( + torch.randn( + len(state.joints_pos), + dtype=state.joints_pos.dtype, + device=state.joints_pos.device, + ) + * 10 + ) + H = state.H + joints_pos = state.joints_pos + base_vel = state.base_vel + joints_vel = state.joints_vel + + wrenches = { + "l_sole": torch.randn( + 6, dtype=state.joints_pos.dtype, device=state.joints_pos.device + ) + * 10, + "torso_1": torch.randn( + 6, dtype=state.joints_pos.dtype, device=state.joints_pos.device + ) + * 10, + "head": torch.randn( + 6, dtype=state.joints_pos.dtype, device=state.joints_pos.device + ) + * 10, + } + + adam_qdd = adam_kin_dyn.aba( + base_transform=H, + joint_positions=joints_pos, + base_velocity=base_vel, + joint_velocities=joints_vel, + joint_torques=torques, + external_wrenches=wrenches, + ) + + M = adam_kin_dyn.mass_matrix(H, joints_pos) + h = adam_kin_dyn.bias_force(H, joints_pos, base_vel, joints_vel) + + generalized_external_wrenches = torch.zeros( + 6 + len(joints_pos), dtype=H.dtype, device=H.device + ) + for frame, wrench in wrenches.items(): + J = adam_kin_dyn.jacobian(frame, H, joints_pos) + generalized_external_wrenches += J.T @ wrench + + base_wrench = torch.zeros(6, dtype=H.dtype, device=H.device) + full_tau = torch.concatenate([base_wrench, torques]) + residual = M @ adam_qdd + h - full_tau - generalized_external_wrenches + + assert to_numpy(residual) == pytest.approx(0.0, abs=1e-4) diff --git a/tests/test_pytorch_batch.py b/tests/test_pytorch_batch.py index 5870a980..c24abc61 100644 --- a/tests/test_pytorch_batch.py +++ b/tests/test_pytorch_batch.py @@ -515,3 +515,77 @@ def test_gravity_term(setup_test): # Verify batch variation (random inputs should produce different outputs) assert not torch.allclose(adam_gravity[0], adam_gravity[1], atol=1e-6) + + +def test_aba(setup_test): + """Test Articulated Body Algorithm with batched inputs""" + adam_kin_dyn, robot_cfg, state, batch_size = setup_test + n_joints = robot_cfg.n_dof + + # Get device from state tensors + device = state.H.device + + # Create random torques for the batch + torques = torch.randn(batch_size, n_joints, device=device, dtype=torch.float64) * 10 + # Create random wrenches for multiple frames + wrenches = { + "l_sole": torch.randn(batch_size, 6, device=device, dtype=torch.float64) * 10, + "torso_1": torch.randn(batch_size, 6, device=device, dtype=torch.float64) * 10, + "head": torch.randn(batch_size, 6, device=device, dtype=torch.float64) * 10, + } + + # Compute ABA + adam_qdd = adam_kin_dyn.aba( + base_transform=state.H, + joint_positions=state.joints_pos, + base_velocity=state.base_vel, + joint_velocities=state.joints_vel, + joint_torques=torques, + external_wrenches=wrenches, + ) + # Check output shape + assert adam_qdd.shape == (batch_size, 6 + n_joints) + + # Verify using the equations of motion: M @ qdd + h = tau + J^T @ wrench + M = adam_kin_dyn.mass_matrix(state.H, state.joints_pos) + h = adam_kin_dyn.bias_force( + state.H, state.joints_pos, state.base_vel, state.joints_vel + ) + + # Compute generalized external wrenches + generalized_external_wrenches = torch.zeros( + batch_size, 6 + n_joints, device=device, dtype=torch.float64 + ) + for frame, wrench in wrenches.items(): + J = adam_kin_dyn.jacobian(frame, state.H, state.joints_pos) + # J shape: (batch_size, 6, 6+n_joints), wrench shape: (batch_size, 6) + # J^T @ wrench for each batch element + generalized_external_wrenches += torch.einsum( + "bij,bj->bi", J.transpose(1, 2), wrench + ) + + # Create full generalized forces (base wrench is zero + joint torques) + base_wrench = torch.zeros(batch_size, 6, device=device, dtype=torch.float64) + full_tau = torch.cat([base_wrench, torques], dim=1) + + # Compute residual: M @ qdd + h - tau - J^T @ wrench = 0 + residual = ( + torch.einsum("bij,bj->bi", M, adam_qdd) + + h + - full_tau + - generalized_external_wrenches + ) + + # Assert residual is close to zero + assert torch.allclose( + residual, torch.zeros_like(residual), atol=1e-4 + ), f"Residual max: {residual.abs().max().item()}" + + # Test gradient computation + try: + adam_qdd.sum().backward() + except Exception as e: + raise ValueError(f"Gradient computation failed: {e}") + + # Verify batch variation (random inputs should produce different outputs) + assert not torch.allclose(adam_qdd[0], adam_qdd[1], atol=1e-6)