diff --git a/src/adam/core/rbd_algorithms.py b/src/adam/core/rbd_algorithms.py index 2b873df2..97cca9b6 100644 --- a/src/adam/core/rbd_algorithms.py +++ b/src/adam/core/rbd_algorithms.py @@ -4,7 +4,7 @@ from adam.core.constants import Representations from adam.core.spatial_math import ArrayLike, SpatialMath -from adam.model import Model, Node +from adam.model import Model, Node, Joint class RBDAlgorithms: @@ -24,6 +24,13 @@ def __init__(self, model: Model, math: SpatialMath) -> None: self.frame_velocity_representation = ( Representations.MIXED_REPRESENTATION ) # default + # Cache root quantities that are reused at every call + self._root_spatial_transform = self.math.spatial_transform( + self.math.factory.eye(3), + self.math.factory.zeros((3, 1)), + ) + self._root_motion_subspace = self.math.factory.eye(6) + self._prepare_tree_cache() def set_frame_velocity_representation(self, representation: Representations): """Sets the frame velocity representation @@ -48,179 +55,159 @@ def crba(self, base_transform: npt.ArrayLike, joint_positions: npt.ArrayLike): base_transform, joint_positions ) - model = self.model - Nnodes = model.N - n = model.NDoF - root_name = self.root_link + math = self.math + n = self.NDoF + node_count = self._node_count + node_indices = self._node_indices + rev_indices = self._rev_node_indices + parent_indices = self._parent_indices + joint_indices = self._joint_indices_per_node + motion_subspaces = self._motion_subspaces + inertias = self._spatial_inertias + joints = self._joints_per_node + root_idx = self._root_index + joint_to_node = self._joint_index_to_node + + if len(joint_positions.shape) >= 2: + batch_shape = joint_positions.shape[:-1] + elif base_transform.ndim > 2: + batch_shape = base_transform.shape[:-2] + else: + batch_shape = () - Ic = [None] * Nnodes # (...,6,6) - X_p = [None] * Nnodes # (...,6,6) - Phi = [None] * Nnodes # (...,6,ri) + if joint_positions.shape[-1] > 0: + zero_q = math.zeros_like(joint_positions[..., 0]) + else: + zero_q = math.zeros_like(base_transform[..., 0, 0]) + + def tile_batch(arr): + if not batch_shape: + return arr + reps = batch_shape + (1,) * len(arr.shape) + return math.tile(arr, reps) + + Xup = [None] * node_count + Phi = [None] * node_count + Ic = [None] * node_count + + for idx in node_indices: + Ic[idx] = tile_batch(inertias[idx]) + if idx == root_idx: + Xup[idx] = tile_batch(self._root_spatial_transform) + Phi[idx] = tile_batch(self._root_motion_subspace) + else: + joint = joints[idx] + joint_idx = joint_indices[idx] + q_i = ( + joint_positions[..., joint_idx] if joint_idx is not None else zero_q + ) + X_current = ( + joint.spatial_transform(q=q_i) + if joint is not None + else self._root_spatial_transform + ) + Xup[idx] = X_current + Phi[idx] = tile_batch(motion_subspaces[idx]) + + Xup_T = [math.swapaxes(Xup[i], -2, -1) for i in range(node_count)] + Ic_comp = Ic[:] + + for idx in rev_indices: + parent = parent_indices[idx] + if parent >= 0: + Ic_comp[parent] = Ic_comp[parent] + math.mtimes( + math.mtimes(Xup_T[idx], Ic_comp[idx]), + Xup[idx], + ) - batch = joint_positions.shape[:-1] if len(joint_positions.shape) >= 2 else () + def block_index(node_idx: int) -> int | None: + if node_idx == root_idx: + return 0 + joint_idx = joint_indices[node_idx] + if joint_idx is None: + return None + return 1 + int(joint_idx) - for i, node in enumerate(model.tree): - link_i, joint_i, link_pi = node.get_elements() + sizes = [6] + [1] * n + blocks: list[list[npt.ArrayLike | None]] = [ + [None for _ in range(n + 1)] for _ in range(n + 1) + ] + + for idx in rev_indices: + Phi_i = Phi[idx] + Phi_T = math.swapaxes(Phi_i, -2, -1) + F = math.mtimes(Ic_comp[idx], Phi_i) + ri = block_index(idx) + + if idx == root_idx: + blocks[0][0] = math.mtimes(Phi_T, F) + continue - # spatial inertia (broadcast as needed) - inertia = link_i.spatial_inertia() - Ic[i] = self.math.tile(inertia, batch + (1, 1)) if batch else inertia + if ri is not None: + blocks[ri][ri] = math.mtimes(Phi_T, F) - if link_i.name == root_name: - # root: X_p = spatial_transform(I, 0), Phi = I_6 - eye3 = self.math.factory.eye(3) - zeros = self.math.factory.zeros((3, 1)) - xp_root = self.math.spatial_transform(eye3, zeros) - phi_root = self.math.factory.eye(6) - X_p[i] = self.math.tile(xp_root, batch + (1, 1)) if batch else xp_root - Phi[i] = self.math.tile(phi_root, batch + (1, 1)) if batch else phi_root - else: - # joint transform and motion subspace - if (joint_i is not None) and (joint_i.idx is not None): - q_i = joint_positions[..., joint_i.idx] - else: - q_i = self.math.zeros_like(joint_positions[..., 0]) - X_p[i] = joint_i.spatial_transform(q=q_i) - - Si = joint_i.motion_subspace() - Phi[i] = self.math.tile(Si, batch + (1, 1)) if batch else Si - - T = lambda x: self.math.swapaxes(x, -2, -1) # transpose last two dims - X_p_T = [T(X_p[k]) for k in range(Nnodes)] - - for i, node in reversed(list(enumerate(model.tree))): - link_i, joint_i, link_pi = node.get_elements() - if link_i.name != root_name: - pi = model.tree.get_idx_from_name(link_pi.name) - Ic[pi] = Ic[pi] + X_p_T[i] @ Ic[i] @ X_p[i] - - # Map nodes to (row/col) block indices in M - # base is block 0 (size 6), actuated joints follow (size 1 each) - def block_index(node): - link_i, joint_i, _ = node.get_elements() - if link_i.name == root_name: - return 0 - if (joint_i is not None) and (joint_i.idx is not None): - return 1 + int(joint_i.idx) - return None # fixed joints do not appear in M/Jcm - - # Prepare a (n+1) x (n+1) grid of blocks (filled later, zeros where missing) - blocks = [[None for _ in range(n + 1)] for _ in range(n + 1)] - - # Assemble M - for i, node in reversed(list(enumerate(model.tree))): - link_i, joint_i, link_pi = node.get_elements() - ri = block_index(node) - - F = Ic[i] @ Phi[i] # (...,6,ri_i) (ri_i is 6 for base, 1 for joint) - - # Diagonal terms - if ( - (link_i.name != root_name) - and (joint_i is not None) - and (joint_i.idx is not None) - ): - # joint diagonal (1x1) - blocks[ri][ri] = T(Phi[i]) @ F - if link_i.name == root_name: - # base diagonal (6x6) - blocks[0][0] = T(Phi[i]) @ F - - # Off-diagonal terms along path to root - j = i - link_j, joint_j, link_pj = model.tree[j].get_elements() - while link_j.name != root_name: - F = X_p_T[j] @ F - j = model.tree.get_idx_from_name(model.tree[j].parent.name) - link_j, joint_j, link_pj = model.tree[j].get_elements() - - rj = block_index(model.tree[j]) + current = idx + F_path = F + while current != root_idx: + parent = parent_indices[current] + F_path = math.mtimes(Xup_T[current], F_path) + rj = block_index(parent) if rj is None: + current = parent continue - - Bij = T(F) @ Phi[j] # shapes adapt: (6,1), (1,6), or (1,1) - - if ( - (link_i.name == root_name) - and (joint_j is not None) - and (joint_j.idx is not None) - ): - # base–joint and its symmetric - blocks[0][rj] = Bij - blocks[rj][0] = T(Bij) - - elif ( - (link_j.name == root_name) - and (joint_i is not None) - and (joint_i.idx is not None) - ): - # joint–base and its symmetric - blocks[ri][0] = Bij - blocks[0][ri] = T(Bij) - - elif ( - (joint_i is not None) - and (joint_i.idx is not None) - and (joint_j is not None) - and (joint_j.idx is not None) - ): - # joint–joint (scalar) and symmetric (same scalar) - blocks[ri][rj] = Bij - blocks[rj][ri] = Bij - - # Replace missing blocks with zeros and concatenate into a full matrix - batch = joint_positions.shape[:-1] if len(joint_positions.shape) >= 2 else () - sizes = [6] + [1] * n + Bij = math.mtimes(math.swapaxes(F_path, -2, -1), Phi[parent]) + if ri is None: + current = parent + continue + blocks[ri][rj] = Bij + blocks[rj][ri] = math.swapaxes(Bij, -2, -1) + current = parent row_tensors = [] for r in range(n + 1): row_blocks = [] for c in range(n + 1): - B = blocks[r][c] - if B is None: - B = self.math.factory.zeros(batch + (sizes[r], sizes[c])) - row_blocks.append(B) - row_tensors.append(self.math.concatenate(row_blocks, axis=-1)) - M = self.math.concatenate(row_tensors, axis=-2) # (..., 6+n, 6+n) - - # Orin's O_X_G (centroidal transform) - A = T(M[..., :3, 3:6]) / M[..., 0, 0][..., None, None] # (...,3,3) - I3 = self.math.factory.eye(batch + (3,)) - Z3 = self.math.factory.zeros(batch + (3, 3)) - top = self.math.concatenate([I3, A], axis=-1) # (...,3,6) - bot = self.math.concatenate([Z3, I3], axis=-1) # (...,3,6) - O_X_G = self.math.concatenate([top, bot], axis=-2) # (...,6,6) - - # Propagate centroidal transform and build Jcm - X_G = [None] * Nnodes - for i, node in enumerate(model.tree): - link_i, joint_i, link_pi = node.get_elements() - if link_i.name == root_name: - X_G[i] = O_X_G + block = blocks[r][c] + if block is None: + block = math.factory.zeros(batch_shape + (sizes[r], sizes[c])) + row_blocks.append(block) + row_tensors.append(math.concatenate(row_blocks, axis=-1)) + M = math.concatenate(row_tensors, axis=-2) + + A = math.swapaxes(M[..., :3, 3:6], -2, -1) / M[..., 0, 0][..., None, None] + I3 = math.factory.eye(batch_shape + (3,)) + Z3 = math.factory.zeros(batch_shape + (3, 3)) + top = math.concatenate([I3, A], axis=-1) + bot = math.concatenate([Z3, I3], axis=-1) + O_X_G = math.concatenate([top, bot], axis=-2) + + X_G = [None] * node_count + for idx in node_indices: + if idx == root_idx: + X_G[idx] = O_X_G else: - pi = model.tree.get_idx_from_name(link_pi.name) - X_G[i] = X_p[i] @ X_G[pi] - - root_idx = model.tree.get_idx_from_name(root_name) - J_base = T(X_G[root_idx]) @ Ic[root_idx] @ Phi[root_idx] # (...,6,6) + parent = parent_indices[idx] + X_G[idx] = math.mtimes(Xup[idx], X_G[parent]) - # collect joint columns in index order - idx2node = {} - for i, node in enumerate(model.tree): - _, joint_i, _ = node.get_elements() - if (joint_i is not None) and (joint_i.idx is not None): - idx2node[int(joint_i.idx)] = i + J_base = math.mtimes( + math.swapaxes(X_G[root_idx], -2, -1), + math.mtimes(Ic_comp[root_idx], Phi[root_idx]), + ) + zero_col = math.factory.zeros(batch_shape + (6, 1)) joint_cols = [] for jidx in range(n): - if jidx in idx2node: - i = idx2node[jidx] - col = T(X_G[i]) @ Ic[i] @ Phi[i] # (...,6,1) + node_idx = joint_to_node.get(jidx) + if node_idx is not None: + col = math.mtimes( + math.swapaxes(X_G[node_idx], -2, -1), + math.mtimes(Ic_comp[node_idx], Phi[node_idx]), + ) else: - col = self.math.factory.zeros(batch + (6, 1)) + col = zero_col joint_cols.append(col) - Jcm = self.math.concatenate([J_base] + joint_cols, axis=-1) # (...,6,6+n) + Jcm = math.concatenate([J_base] + joint_cols, axis=-1) if ( self.frame_velocity_representation @@ -229,17 +216,21 @@ def block_index(node): return M, Jcm if self.frame_velocity_representation == Representations.MIXED_REPRESENTATION: - Xm = self.math.adjoint_mixed_inverse(base_transform) # (...,6,6) - In = self.math.factory.eye(batch + (n,)) - Z6n = self.math.factory.zeros(batch + (6, n)) - Zn6 = self.math.factory.zeros(batch + (n, 6)) + Xm = math.adjoint_mixed_inverse(base_transform) + In = math.factory.eye(batch_shape + (n,)) + Z6n = math.factory.zeros(batch_shape + (6, n)) + Zn6 = math.factory.zeros(batch_shape + (n, 6)) - top = self.math.concatenate([Xm, Z6n], axis=-1) # (...,6,6+n) - bot = self.math.concatenate([Zn6, In], axis=-1) # (...,n,6+n) - X_to_mixed = self.math.concatenate([top, bot], axis=-2) # (...,6+n,6+n) + top = math.concatenate([Xm, Z6n], axis=-1) + bot = math.concatenate([Zn6, In], axis=-1) + X_to_mixed = math.concatenate([top, bot], axis=-2) - M_mixed = T(X_to_mixed) @ M @ X_to_mixed - Jcm_mixed = T(Xm) @ Jcm @ X_to_mixed + M_mixed = math.mtimes( + math.swapaxes(X_to_mixed, -2, -1), math.mtimes(M, X_to_mixed) + ) + Jcm_mixed = math.mtimes( + math.swapaxes(Xm, -2, -1), math.mtimes(Jcm, X_to_mixed) + ) return M_mixed, Jcm_mixed raise ValueError( @@ -619,132 +610,125 @@ def rnea( ) ) - model = self.model - Nnodes = model.N - n = model.NDoF - root_name = self.root_link - - T = lambda X: self.math.swapaxes(X, -2, -1) # transpose last two dims - batch_size = base_transform.shape[:-2] if len(base_transform.shape) > 2 else () + math = self.math + n = self.NDoF + node_count = self._node_count + parent_indices = self._parent_indices + joint_indices = self._joint_indices_per_node + motion_subspaces = self._motion_subspaces + inertias = self._spatial_inertias + joints = self._joints_per_node + root_idx = self._root_index + node_indices = self._node_indices + rev_indices = self._rev_node_indices + + batch_shape = ( + tuple(base_transform.shape[:-2]) if base_transform.ndim > 2 else () + ) - gravity_X = self.math.adjoint_mixed_inverse(base_transform) # (...,6,6) + gravity_X = math.adjoint_mixed_inverse(base_transform) # (...,6,6) if ( self.frame_velocity_representation == Representations.BODY_FIXED_REPRESENTATION ): - B_X_BI = self.math.factory.eye(batch_size + (6,)) # (...,6,6) - transformed_acc = self.math.factory.zeros(batch_size + (6,)) # (...,6) + B_X_BI = math.factory.eye(batch_shape + (6,)) + transformed_acc = math.factory.zeros(batch_shape + (6,)) elif self.frame_velocity_representation == Representations.MIXED_REPRESENTATION: - B_X_BI = self.math.adjoint_mixed_inverse(base_transform) # (...,6,6) - omega = base_velocity[..., 3:] # (...,3) - vlin = base_velocity[..., :3] # (...,3) - # Use matrix-vector multiplication properly for batched operations - skew_omega_times_vlin = self.math.mxv( - self.math.skew(omega), vlin - ) # (...,3) - top3 = -self.math.mxv(B_X_BI[..., :3, :3], skew_omega_times_vlin) # (...,3) - bot3 = self.math.factory.zeros(batch_size + (3,)) - transformed_acc = self.math.concatenate([top3, bot3], axis=-1) # (...,6) + B_X_BI = math.adjoint_mixed_inverse(base_transform) + omega = base_velocity[..., 3:] + vlin = base_velocity[..., :3] + skew_omega_times_vlin = math.mxv(math.skew(omega), vlin) + top3 = -math.mxv(B_X_BI[..., :3, :3], skew_omega_times_vlin) + bot3 = math.factory.zeros(batch_shape + (3,)) + transformed_acc = math.concatenate([top3, bot3], axis=-1) else: raise NotImplementedError( "Only BODY_FIXED_REPRESENTATION and MIXED_REPRESENTATION are implemented" ) - # base spatial acceleration (bias + gravity) - a0 = -(self.math.mxv(gravity_X, g)) + transformed_acc # (...,6) + a0 = -(math.mxv(gravity_X, g)) + transformed_acc - Ic, X_p, Phi = [None] * Nnodes, [None] * Nnodes, [None] * Nnodes - v, a, f = [None] * Nnodes, [None] * Nnodes, [None] * Nnodes + if n > 0: + zero_q = math.zeros_like(joint_positions[..., 0]) + zero_qd = math.zeros_like(joint_velocities[..., 0]) + else: + zero_q = math.zeros_like(base_velocity[..., 0]) + zero_qd = zero_q + + Ic = [None] * node_count + Xup = [None] * node_count + v = [None] * node_count + a = [None] * node_count + f = [None] * node_count + + for idx in node_indices: + Ic[idx] = inertias[idx] + + if idx == root_idx: + Xup[idx] = self._root_spatial_transform + v[idx] = math.mxv(B_X_BI, base_velocity) + a[idx] = math.mxv(Xup[idx], a0) + continue + + joint = joints[idx] + parent = parent_indices[idx] + joint_idx = joint_indices[idx] + + q = joint_positions[..., joint_idx] if joint_idx is not None else zero_q + qd = joint_velocities[..., joint_idx] if joint_idx is not None else zero_qd - for i, node in enumerate(model.tree): - node: Node - link_i, joint_i, link_pi = node.get_elements() + X = joint.spatial_transform(q=q) + Xup[idx] = X - inertia = link_i.spatial_inertia() - Ic[i] = ( - self.math.tile(inertia, batch_size + (1, 1)) if batch_size else inertia + Phi_i = motion_subspaces[idx] + phi_qd = math.vxs(Phi_i, qd) + v[idx] = math.mxv(X, v[parent]) + phi_qd + a[idx] = math.mxv(X, a[parent]) + math.mxv( + math.spatial_skew(v[idx]), phi_qd ) - if link_i.name == root_name: - eye3 = self.math.factory.eye(3) - zeros = self.math.factory.zeros((3, 1)) - xp_root = self.math.spatial_transform(eye3, zeros) - phi_root = self.math.factory.eye(6) - X_p[i] = ( - self.math.tile(xp_root, batch_size + (1, 1)) - if batch_size - else xp_root - ) - Phi[i] = ( - self.math.tile(phi_root, batch_size + (1, 1)) - if batch_size - else phi_root - ) - v[i] = self.math.mxv(B_X_BI, base_velocity) # (...,6) - a[i] = self.math.mxv(X_p[i], a0) # (...,6) - else: - q = ( - joint_positions[..., joint_i.idx] - if (joint_i is not None) and (joint_i.idx is not None) - else self.math.zeros_like(joint_positions[..., 0]) - ) - qd = ( - joint_velocities[..., joint_i.idx] - if (joint_i is not None) and (joint_i.idx is not None) - else self.math.zeros_like(joint_velocities[..., 0]) - ) + f[idx] = math.mxv(Ic[idx], a[idx]) + math.mxv( + math.spatial_skew_star(v[idx]), math.mxv(Ic[idx], v[idx]) + ) - X_p[i] = joint_i.spatial_transform(q=q) # (...,6,6) - Si = joint_i.motion_subspace() # (6,) - Phi[i] = ( - self.math.tile(Si, batch_size + (1, 1)) if batch_size else Si - ) # (...,6,1) - pi = model.tree.get_idx_from_name(link_pi.name) + # Root wrench contribution (skipped in loop above) + f[root_idx] = math.mxv(Ic[root_idx], a[root_idx]) + math.mxv( + math.spatial_skew_star(v[root_idx]), math.mxv(Ic[root_idx], v[root_idx]) + ) - phi_qd = self.math.vxs(Phi[i], qd) # (...,6) + tau_base = None + tau_joint_cols = [None] * n if n > 0 else [] - v[i] = self.math.mxv(X_p[i], v[pi]) + phi_qd # (...,6) - a[i] = self.math.mxv(X_p[i], a[pi]) + self.math.mxv( - self.math.spatial_skew(v[i]), phi_qd - ) # (...,6) + for idx in rev_indices: + Phi_i = motion_subspaces[idx] + Fi = f[idx] + Phi_T = math.swapaxes(Phi_i, -2, -1) - f[i] = self.math.mxv(Ic[i], a[i]) + self.math.mxv( - self.math.spatial_skew_star(v[i]), self.math.mxv(Ic[i], v[i]) - ) # (...,6) + if idx == root_idx: + tau_base = math.mxv(Phi_T, Fi) + else: + joint_idx = joint_indices[idx] + if joint_idx is not None: + tau_joint_cols[joint_idx] = math.mxv(Phi_T, Fi) + parent = parent_indices[idx] + if parent >= 0: + f[parent] = f[parent] + math.mxv( + math.swapaxes(Xup[idx], -2, -1), Fi + ) - tau_base = None - tau_joint_by_idx = {} - - for i, node in reversed(list(enumerate(model.tree))): - link_i, joint_i, link_pi = node.get_elements() - - if link_i.name == root_name: - tau_base = self.math.mxv(T(Phi[i]), f[i]) # (...,6) - elif (joint_i is not None) and (joint_i.idx is not None): - # (Phi^T f) -> (...,1,6) @ (...,6) -> (...,1) - tau_joint_by_idx[int(joint_i.idx)] = self.math.mxv( - T(Phi[i]), f[i] - ) # (...,1) - - if link_i.name != root_name: - pi = model.tree.get_idx_from_name(link_pi.name) - f[pi] = f[pi] + self.math.mxv(T(X_p[i]), f[i]) # (...,6) - - tau_base = self.math.mxv(T(B_X_BI), tau_base) # (...,6) - - tau_joints = [] - for ii in range(n): - col = ( - tau_joint_by_idx[ii] - if ii in tau_joint_by_idx - else self.math.factory.zeros(batch_size + (1,)) - ) - tau_joints.append(col) + tau_base = math.mxv(math.swapaxes(B_X_BI, -2, -1), tau_base) - tau_joints_vec = self.math.concatenate(tau_joints, axis=-1) + if n > 0: + zero_tau = math.factory.zeros(batch_shape + (1,)) + tau_joints = [ + col if col is not None else zero_tau for col in tau_joint_cols + ] + tau_joints_vec = math.concatenate(tau_joints, axis=-1) + else: + tau_joints_vec = math.factory.zeros(batch_shape + (0,)) - return self.math.concatenate([tau_base, tau_joints_vec], axis=-1) + return math.concatenate([tau_base, tau_joints_vec], axis=-1) def aba( self, @@ -770,12 +754,17 @@ def aba( Returns: accelerations (npt.ArrayLike): The spatial acceleration of the base and joints accelerations """ - model = self.model math = self.math - - Nnodes = model.N - n = model.NDoF - root_name = self.root_link + n = self.NDoF + node_count = self._node_count + parent_indices = self._parent_indices + joint_indices = self._joint_indices_per_node + motion_subspaces = self._motion_subspaces + inertias = self._spatial_inertias + joints = self._joints_per_node + root_idx = self._root_index + node_indices = self._node_indices + rev_indices = self._rev_node_indices ( base_transform, @@ -793,26 +782,26 @@ def aba( g, ) - T = lambda X: math.swapaxes(X, -2, -1) - - batch = base_transform.shape[:-2] if base_transform.ndim > 2 else () + batch_shape = ( + tuple(base_transform.shape[:-2]) if base_transform.ndim > 2 else () + ) if external_wrenches is not None: - generalized_ext = self.math.factory.zeros(batch + (6 + n,)) + generalized_ext = math.factory.zeros(batch_shape + (6 + n,)) for frame, wrench in external_wrenches.items(): wrench_arr = self._convert_to_arraylike(wrench) J = self.jacobian(frame, base_transform, joint_positions) - generalized_ext = generalized_ext + self.math.mxv(T(J), wrench_arr) + generalized_ext = generalized_ext + math.mxv( + math.swapaxes(J, -2, -1), wrench_arr + ) base_ext = generalized_ext[..., :6] joint_ext = ( - generalized_ext[..., 6:] - if n > 0 - else self.math.zeros_like(joint_torques) + generalized_ext[..., 6:] if n > 0 else math.zeros_like(joint_torques) ) else: - base_ext = self.math.factory.zeros(batch + (6,)) - joint_ext = self.math.zeros_like(joint_torques) + base_ext = math.factory.zeros(batch_shape + (6,)) + joint_ext = math.zeros_like(joint_torques) joint_torques_eff = joint_torques + joint_ext @@ -822,7 +811,7 @@ def aba( self.frame_velocity_representation == Representations.BODY_FIXED_REPRESENTATION ): - B_X_BI = math.factory.eye(batch + (6,)) if batch else math.factory.eye(6) + B_X_BI = math.factory.eye(batch_shape + (6,)) else: raise NotImplementedError( "Only BODY_FIXED_REPRESENTATION and MIXED_REPRESENTATION are implemented" @@ -834,10 +823,10 @@ def aba( a0_input = math.mxv(math.adjoint_mixed_inverse(base_transform), g) def zeros6(): - return math.factory.zeros(batch + (6,)) if batch else math.factory.zeros(6) + return math.factory.zeros(batch_shape + (6,)) def eye6(): - return math.factory.eye(batch + (6,)) if batch else math.factory.eye(6) + return math.factory.eye(batch_shape + (6,)) def expand_to_match(vec, reference): expanded = math.expand_dims(vec, axis=-1) @@ -847,77 +836,79 @@ def expand_to_match(vec, reference): expanded = math.expand_dims(expanded, axis=-1) return expanded - Xup = [None] * Nnodes - Scols: list[ArrayLike | None] = [None] * Nnodes - v = [None] * Nnodes - c = [None] * Nnodes - IA = [None] * Nnodes - pA = [None] * Nnodes - g_acc = [None] * Nnodes - - for idx, node in enumerate(model.tree): - link_i, joint_i, link_pi = node.get_elements() - - inertia = link_i.spatial_inertia() - IA[idx] = math.tile(inertia, batch + (1, 1)) if batch else inertia - - if link_i.name == root_name: + if n > 0: + zero_q = math.zeros_like(joint_positions[..., 0]) + else: + zero_q = math.zeros_like(base_velocity[..., 0]) + + def tile_batch(arr): + if not batch_shape: + return arr + reps = batch_shape + (1,) * len(arr.shape) + return math.tile(arr, reps) + + Xup = [None] * node_count + Scols: list[ArrayLike | None] = [None] * node_count + v = [None] * node_count + c = [None] * node_count + IA = [None] * node_count + pA = [None] * node_count + g_acc = [None] * node_count + + for idx in node_indices: + IA[idx] = tile_batch(inertias[idx]) + + if idx == root_idx: Xup[idx] = eye6() v[idx] = base_velocity_body c[idx] = zeros6() g_acc[idx] = a0_input else: - pi = model.tree.get_idx_from_name(link_pi.name) - - if joint_i is not None: - q_i = ( - joint_positions[..., joint_i.idx] - if joint_i.idx is not None - else math.zeros_like(joint_positions[..., 0]) - ) - Xup[idx] = joint_i.spatial_transform(q=q_i) - else: - Xup[idx] = eye6() + parent = parent_indices[idx] + joint = joints[idx] + joint_idx = joint_indices[idx] + q_i = ( + joint_positions[..., joint_idx] if joint_idx is not None else zero_q + ) + X_current = joint.spatial_transform(q=q_i) + Xup[idx] = X_current - g_acc[idx] = math.mxv(Xup[idx], g_acc[pi]) + g_acc[idx] = math.mxv(X_current, g_acc[parent]) - if (joint_i is not None) and (joint_i.idx is not None): - Si = joint_i.motion_subspace() + if joint_idx is not None: + Si = tile_batch(motion_subspaces[idx]) Scols[idx] = Si - qd_i = joint_velocities[..., joint_i.idx] + qd_i = joint_velocities[..., joint_idx] vJ = math.vxs(Si, qd_i) else: Scols[idx] = None vJ = zeros6() - v[idx] = math.mxv(Xup[idx], v[pi]) + vJ + v[idx] = math.mxv(X_current, v[parent]) + vJ c[idx] = math.mxv(math.spatial_skew(v[idx]), vJ) pA[idx] = math.mxv( math.spatial_skew_star(v[idx]), math.mxv(IA[idx], v[idx]) ) - d_list: list[ArrayLike | None] = [None] * Nnodes - u_list: list[ArrayLike | None] = [None] * Nnodes - U_list: list[ArrayLike | None] = [None] * Nnodes + d_list: list[ArrayLike | None] = [None] * node_count + inv_d_list: list[ArrayLike | None] = [None] * node_count + u_list: list[ArrayLike | None] = [None] * node_count + U_list: list[ArrayLike | None] = [None] * node_count - for idx, node in reversed(list(enumerate(model.tree))): - link_i, joint_i, link_pi = node.get_elements() - - if link_i.name == root_name: + for idx in rev_indices: + if idx == root_idx: continue - pi = model.tree.get_idx_from_name(link_pi.name) - - Xpt = T(Xup[idx]) - + parent = parent_indices[idx] + Xpt = math.swapaxes(Xup[idx], -2, -1) + Si = Scols[idx] if Scols[idx] is not None: - S_i = Scols[idx] - U_i = math.mtimes(IA[idx], S_i) - d_i = math.mtimes(T(S_i), U_i) - tau_i = joint_torques_eff[..., joint_i.idx] - tau_vec = tau_i - Si_T_pA = math.mxv(T(S_i), pA[idx])[..., 0] + U_i = math.mtimes(IA[idx], Si) + d_i = math.mtimes(math.swapaxes(Si, -2, -1), U_i) + joint_idx = joint_indices[idx] + tau_vec = joint_torques_eff[..., joint_idx] + Si_T_pA = math.mxv(math.swapaxes(Si, -2, -1), pA[idx])[..., 0] u_i = tau_vec - Si_T_pA d_list[idx] = d_i @@ -925,75 +916,63 @@ def expand_to_match(vec, reference): U_list[idx] = U_i inv_d = math.inv(d_i) - Ia = IA[idx] - math.mtimes(U_i, math.mtimes(inv_d, T(U_i))) - u_i_expanded = expand_to_match(u_i, inv_d) - gain = math.mtimes(inv_d, u_i_expanded) - # Extract column vector + inv_d_list[idx] = inv_d + Ia = IA[idx] - math.mtimes( + U_i, math.mtimes(inv_d, math.swapaxes(U_i, -2, -1)) + ) + gain = math.mtimes(inv_d, expand_to_match(u_i, inv_d)) gain_vec = gain[..., 0] pa = pA[idx] + math.mxv(Ia, c[idx]) + math.mxv(U_i, gain_vec) else: Ia = IA[idx] pa = pA[idx] + math.mxv(Ia, c[idx]) - IA[pi] = IA[pi] + math.mtimes(math.mtimes(Xpt, Ia), Xup[idx]) - pA[pi] = pA[pi] + math.mxv(Xpt, pa) + IA[parent] = IA[parent] + math.mtimes(math.mtimes(Xpt, Ia), Xup[idx]) + pA[parent] = pA[parent] + math.mxv(Xpt, pa) - root_idx = model.tree.get_idx_from_name(root_name) rhs_root = base_ext_body - pA[root_idx] + math.mxv(IA[root_idx], a0_input) a_base = math.solve(IA[root_idx], rhs_root) - a = [None] * Nnodes + a = [None] * node_count a[root_idx] = a_base qdd_entries: list[ArrayLike | None] = [None] * n if n > 0 else [] - for idx, node in enumerate(model.tree): - link_i, joint_i, link_pi = node.get_elements() - - if link_i.name == root_name: + for idx in node_indices: + if idx == root_idx: continue - pi = model.tree.get_idx_from_name(link_pi.name) - a_pre = math.mxv(Xup[idx], a[pi]) + c[idx] + parent = parent_indices[idx] + a_pre = math.mxv(Xup[idx], a[parent]) + c[idx] free_acc = g_acc[idx] rel_acc = a_pre - free_acc if free_acc is not None else a_pre - if ( - Scols[idx] is not None - and (joint_i is not None) - and (joint_i.idx is not None) - ): - S_i = Scols[idx] + Si = Scols[idx] + joint_idx = joint_indices[idx] + + if Si is not None and joint_idx is not None: U_i = U_list[idx] - U_T_rel_acc = math.mxv(T(U_i), rel_acc)[..., 0] + U_T_rel_acc = math.mxv(math.swapaxes(U_i, -2, -1), rel_acc)[..., 0] num = u_list[idx] - U_T_rel_acc - inv_d = math.inv(d_list[idx]) + inv_d = inv_d_list[idx] num_expanded = expand_to_match(num, inv_d) gain_qdd = math.mtimes(inv_d, num_expanded) qdd_col = gain_qdd[..., 0] - if joint_i.idx < n: - qdd_entries[joint_i.idx] = qdd_col - a_correction_vec = math.mxv(S_i, qdd_col) + if joint_idx < n: + qdd_entries[joint_idx] = qdd_col + a_correction_vec = math.mxv(Si, qdd_col) a[idx] = a_pre + a_correction_vec else: a[idx] = a_pre if n > 0: - qdd_cols = [] - for entry in qdd_entries: - if entry is None: - qdd_cols.append( - math.factory.zeros(batch + (1,)) - if batch - else math.factory.zeros((1,)) - ) - else: - qdd_cols.append(entry) + zero_col = math.factory.zeros(batch_shape + (1,)) + qdd_cols = [ + entry if entry is not None else zero_col for entry in qdd_entries + ] joint_qdd = math.concatenate(qdd_cols, axis=-1) else: - joint_qdd = ( - math.factory.zeros(batch + (0,)) if batch else math.factory.zeros((0,)) - ) + joint_qdd = math.factory.zeros(batch_shape + (0,)) if self.frame_velocity_representation == Representations.MIXED_REPRESENTATION: Xm = math.adjoint_mixed(base_transform) @@ -1022,3 +1001,38 @@ def _convert_to_arraylike(self, *args): else: converted.append(self.math.asarray(arg)) return converted[0] if len(converted) == 1 else converted + + def _prepare_tree_cache(self) -> None: + """Pre-compute static tree data so the dynamic algorithms avoid repeated Python work.""" + nodes = list(self.model.tree) + node_count = len(nodes) + self._node_indices = tuple(range(node_count)) + self._rev_node_indices = tuple(reversed(self._node_indices)) + self._parent_indices = [-1] * node_count + self._joint_indices_per_node: list[int | None] = [None] * node_count + self._motion_subspaces: list[npt.ArrayLike] = [None] * node_count + self._spatial_inertias: list[npt.ArrayLike] = [None] * node_count + self._joints_per_node: list[Joint | None] = [None] * node_count + self._joint_index_to_node: dict[int, int] = {} + + for idx, node in enumerate(nodes): + link, joint, parent_link = node.get_elements() + self._joints_per_node[idx] = joint + self._spatial_inertias[idx] = link.spatial_inertia() + parent_idx = ( + self.model.tree.get_idx_from_name(parent_link.name) + if parent_link is not None + else -1 + ) + self._parent_indices[idx] = parent_idx + if joint is None: + self._motion_subspaces[idx] = self._root_motion_subspace + self._joint_indices_per_node[idx] = None + else: + self._motion_subspaces[idx] = joint.motion_subspace() + self._joint_indices_per_node[idx] = joint.idx + if joint.idx is not None: + self._joint_index_to_node[int(joint.idx)] = idx + + self._root_index = self.model.tree.get_idx_from_name(self.root_link) + self._node_count = node_count diff --git a/src/adam/parametric/pytorch/computations_parametric.py b/src/adam/parametric/pytorch/computations_parametric.py index 7bf0ba6b..73e615c3 100644 --- a/src/adam/parametric/pytorch/computations_parametric.py +++ b/src/adam/parametric/pytorch/computations_parametric.py @@ -26,7 +26,7 @@ def __init__( device: torch.device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ), - dtypes: torch.dtype = torch.float64, + dtype: torch.dtype = torch.float64, root_link: str = None, gravity: np.array = torch.tensor([0, 0, -9.80665, 0, 0, 0]), ) -> None: @@ -37,9 +37,9 @@ def __init__( links_name_list (list): list of parametric links root_link (str, optional): Deprecated. The root link is automatically chosen as the link with no parent in the URDF. Defaults to None. """ - ref = torch.tensor(0.0, dtype=dtypes, device=device) + ref = torch.tensor(0.0, dtype=dtype, device=device) self.math = SpatialMath(spec=spec_from_reference(ref)) - self.g = gravity.to(dtype=dtypes, device=device) + self.g = gravity.to(dtype=dtype, device=device) self.links_name_list = links_name_list self.joints_name_list = joints_name_list self.urdfstring = urdfstring diff --git a/src/adam/pytorch/computation_batch.py b/src/adam/pytorch/computation_batch.py index a902c4ee..7b8c2911 100644 --- a/src/adam/pytorch/computation_batch.py +++ b/src/adam/pytorch/computation_batch.py @@ -27,7 +27,7 @@ def __init__( device: torch.device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ), - dtype: torch.dtype = torch.float64, + dtype: torch.dtype = torch.float32, root_link: str = None, gravity: torch.Tensor = torch.as_tensor([0, 0, -9.80665, 0, 0, 0]), ) -> None: diff --git a/src/adam/pytorch/computations.py b/src/adam/pytorch/computations.py index 5d4ca488..ddc4b9f8 100644 --- a/src/adam/pytorch/computations.py +++ b/src/adam/pytorch/computations.py @@ -23,7 +23,7 @@ def __init__( device: torch.device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ), - dtypes: torch.dtype = torch.float64, + dtype: torch.dtype = torch.float32, root_link: str = None, gravity: np.array = torch.tensor([0, 0, -9.80665, 0, 0, 0]), ) -> None: @@ -33,14 +33,14 @@ def __init__( joints_name_list (list): list of the actuated joints root_link (str, optional): Deprecated. The root link is automatically chosen as the link with no parent in the URDF. Defaults to None. """ - ref = torch.tensor(0.0, dtype=dtypes, device=device) + ref = torch.tensor(0.0, dtype=dtype, device=device) spec = spec_from_reference(ref) math = SpatialMath(spec=spec) factory = URDFModelFactory(path=urdfstring, math=math) model = Model.build(factory=factory, joints_name_list=joints_name_list) self.rbdalgos = RBDAlgorithms(model=model, math=math) self.NDoF = self.rbdalgos.NDoF - self.g = gravity.to(dtype=dtypes, device=device) + self.g = gravity.to(dtype=dtype, device=device) if root_link is not None: warnings.warn( "The root_link argument is not used. The root link is automatically chosen as the link with no parent in the URDF", diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py index 629beca8..64a0cb75 100644 --- a/tests/test_pytorch.py +++ b/tests/test_pytorch.py @@ -10,7 +10,10 @@ def setup_test(tests_setup, device) -> KinDynComputations | RobotCfg | State: robot_cfg, state = tests_setup adam_kin_dyn = KinDynComputations( - robot_cfg.model_path, robot_cfg.joints_name_list, device=device + robot_cfg.model_path, + robot_cfg.joints_name_list, + device=device, + dtype=torch.float64, ) adam_kin_dyn.set_frame_velocity_representation(robot_cfg.velocity_representation) # convert state quantities to torch tensors diff --git a/tests/test_pytorch_batch.py b/tests/test_pytorch_batch.py index c24abc61..18e2e0be 100644 --- a/tests/test_pytorch_batch.py +++ b/tests/test_pytorch_batch.py @@ -12,7 +12,10 @@ def setup_test(tests_setup, device) -> KinDynComputationsBatch | RobotCfg | Stat robot_cfg, state = tests_setup adam_kin_dyn = KinDynComputationsBatch( - robot_cfg.model_path, robot_cfg.joints_name_list, device=device + robot_cfg.model_path, + robot_cfg.joints_name_list, + device=device, + dtype=torch.float64, ) adam_kin_dyn.set_frame_velocity_representation(robot_cfg.velocity_representation)