diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index 716f793d8..e46ce55c2 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -9,6 +9,8 @@ import genesis as gs from genesis.engine.materials.base import Material +from genesis.engine.states.entities import RigidEntityState +from genesis.engine.states.cache import QueriedStates from genesis.options.morphs import Morph from genesis.options.surfaces import Surface from genesis.utils import geom as gu @@ -18,7 +20,7 @@ from genesis.utils import mjcf as mju from genesis.utils import terrain as tu from genesis.utils import urdf as uu -from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, tensor_to_array, ti_to_torch +from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, ti_to_torch, to_gs_tensor from ..base_entity import Entity from .rigid_equality import RigidEquality @@ -93,6 +95,11 @@ def __init__( self._load_model() + # For differentiability + self._init_tgt_vars() + self._init_ckpt() + self._queried_states = QueriedStates() + def _load_model(self): self._links = gs.List() self._joints = gs.List() @@ -827,6 +834,22 @@ def _add_equality(self, name, type, objs_name, data, sol_params): self._equalities.append(equality) return equality + def _init_tgt_vars(self): + # Temp variable to store targets for next step + self._tgt = dict() + self._tgt_buffer = dict() + self._init_tgt_keys() + + for key in self._tgt_keys: + self._tgt[key] = None + self._tgt_buffer[key] = list() + + def _init_tgt_keys(self): + self._tgt_keys = ["pos", "quat", "vel", "ang"] + + def _init_ckpt(self): + self._ckpt = dict() + # ------------------------------------------------------------------------------------ # --------------------------------- Jacobian & IK ------------------------------------ # ------------------------------------------------------------------------------------ @@ -1574,6 +1597,81 @@ def plan_path( # ---------------------------------- control & io ------------------------------------ # ------------------------------------------------------------------------------------ + def process_input(self, in_backward=False): + if in_backward: + # Use negative index because buffer length might not be full + index = self._sim.cur_step_local - self._sim._steps_local + for key in self._tgt_keys: + self._tgt[key] = self._tgt_buffer[key][index] + else: + for key in self._tgt_keys: + self._tgt_buffer[key].append(self._tgt[key]) + + # set_pos followed by set_vel, because set_pos resets velocity. + if self._tgt["pos"] is not None: + self._tgt["pos"].assert_contiguous() + self._tgt["pos"].assert_sceneless() + self.set_pos(self._tgt["pos"]) + + if self._tgt["quat"] is not None: + self._tgt["quat"].assert_contiguous() + self._tgt["quat"].assert_sceneless() + self.set_quat(self._tgt["quat"]) + + if self._tgt["vel"] is not None: + self._tgt["vel"].assert_contiguous() + self._tgt["vel"].assert_sceneless() + self.set_vel(self._tgt["vel"]) + + if self._tgt["ang"] is not None: + self._tgt["ang"].assert_contiguous() + self._tgt["ang"].assert_sceneless() + self.set_ang(self._tgt["ang"]) + + for key in self._tgt_keys: + self._tgt[key] = None + + def process_input_grad(self): + _tgt_pos = self._tgt_buffer["pos"].pop() + _tgt_quat = self._tgt_buffer["quat"].pop() + _tgt_vel = self._tgt_buffer["vel"].pop() + _tgt_ang = self._tgt_buffer["ang"].pop() + + if _tgt_pos is not None and _tgt_pos.requires_grad: + _tgt_pos._backward_from_ti(self.set_pos_grad, self._sim.cur_substep_local) + + if _tgt_quat is not None and _tgt_quat.requires_grad: + _tgt_quat._backward_from_ti(self.set_quat_grad, self._sim.cur_substep_local) + + if _tgt_vel is not None and _tgt_vel.requires_grad: + _tgt_vel._backward_from_ti(self.set_vel_grad, self._sim.cur_substep_local) + + if _tgt_ang is not None and _tgt_ang.requires_grad: + _tgt_ang._backward_from_ti(self.set_ang_grad, self._sim.cur_substep_local) + + def save_ckpt(self, ckpt_name): + if ckpt_name not in self._ckpt: + self._ckpt[ckpt_name] = { + "_tgt_buffer": dict(), + } + + for key in self._tgt_keys: + self._ckpt[ckpt_name]["_tgt_buffer"][key] = list(self._tgt_buffer[key]) + self._tgt_buffer[key].clear() + + def load_ckpt(self, ckpt_name): + for key in self._tgt_keys: + self._tgt_buffer[key] = list(self._ckpt[ckpt_name]["_tgt_buffer"][key]) + + @gs.assert_built + def get_state(self): + state = RigidEntityState(self, self._sim.cur_step_global) + self._kernel_get_base_link_state(self._sim.cur_substep_local, state.pos, state.quat) + # Store all queried states to track gradient flow + self._queried_states.append(state) + + return state + def get_joint(self, name=None, uid=None): """ Get a RigidJoint object by name or uid. @@ -1879,74 +1977,6 @@ def get_links_invweight(self, links_idx_local=None, envs_idx=None, *, unsafe=Fal links_idx = self._get_idx(links_idx_local, self.n_links, self._link_start, unsafe=True) return self._solver.get_links_invweight(links_idx, envs_idx, unsafe=unsafe) - @gs.assert_built - def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False): - """ - Set position of the entity's base link. - - Parameters - ---------- - pos : array_like - The position to set. - relative : bool, optional - Whether the position to set is absolute or relative to the initial (not current!) position. Defaults to - False. - zero_velocity : bool, optional - Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a - sudden change in entity pose. - envs_idx : None | array_like, optional - The indices of the environments. If None, all environments will be considered. Defaults to None. - """ - if not unsafe: - _pos = torch.as_tensor(pos, dtype=gs.tc_float, device=gs.device).contiguous() - if _pos is not pos: - gs.logger.debug(ALLOCATE_TENSOR_WARNING) - pos = _pos - self._solver.set_base_links_pos( - pos.unsqueeze(-2), - self._base_links_idx, - envs_idx, - relative=relative, - unsafe=unsafe, - skip_forward=zero_velocity, - ) - if zero_velocity: - self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) - - @gs.assert_built - def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False): - """ - Set quaternion of the entity's base link. - - Parameters - ---------- - quat : array_like - The quaternion to set. - relative : bool, optional - Whether the quaternion to set is absolute or relative to the initial (not current!) quaternion. Defaults to - False. - zero_velocity : bool, optional - Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a - sudden change in entity pose. - envs_idx : None | array_like, optional - The indices of the environments. If None, all environments will be considered. Defaults to None. - """ - if not unsafe: - _quat = torch.as_tensor(quat, dtype=gs.tc_float, device=gs.device).contiguous() - if _quat is not quat: - gs.logger.debug(ALLOCATE_TENSOR_WARNING) - quat = _quat - self._solver.set_base_links_quat( - quat.unsqueeze(-2), - self._base_links_idx, - envs_idx, - relative=relative, - unsafe=unsafe, - skip_forward=zero_velocity, - ) - if zero_velocity: - self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) - @gs.assert_built def get_verts(self): """ @@ -1977,6 +2007,15 @@ def _update_verts_for_geom(self): i_g = i_g_ + self._geom_start self._solver.update_verts_for_geom(i_g) + @ti.kernel + def _kernel_get_base_link_state(self, f: ti.i32, pos: ti.types.ndarray(), quat: ti.types.ndarray()): + for i_b in ti.ndrange(self._solver._B): + l = self._solver.links_state[f, self.base_link_idx, i_b] + for j in ti.static(range(3)): + pos[i_b, j] = l.pos[f, self.base_link_idx, i_b][j] + for j in ti.static(range(4)): + quat[i_b, j] = l.quat[f, self.base_link_idx, i_b][j] + @gs.assert_built def get_AABB(self): """ @@ -2041,6 +2080,144 @@ def _get_idx(self, idx_local, idx_local_max, idx_global_start=0, *, unsafe=False return idx_global + def set_position(self, pos): + """ + Save the given position (entity's base link) to the target tensor. + """ + pos = to_gs_tensor(pos) + self._tgt["pos"] = pos + + def set_quaternion(self, quat): + """ + Save the given quaternion (entity's base link) to the target tensor. + """ + quat = to_gs_tensor(quat) + self._tgt["quat"] = quat + + def set_velocity(self, vel): + """ + Save the given velocity (entity's base link) to the target tensor. + """ + vel = to_gs_tensor(vel) + self._tgt["vel"] = vel + + def set_angular_velocity(self, ang): + """ + Save the given angular velocity (entity's base link) to the target tensor. + """ + ang = to_gs_tensor(ang) + self._tgt["ang"] = ang + + @gs.assert_built + def set_pos(self, pos): + """ + Set the position of the entity's base link. + + TODO: Compare with the original set_pos in the latest Genesis and fix backward compatibility. + """ + links_idx = self._base_links_idx + envs_idx = self.solver._scene._sanitize_envs_idx(None, unsafe=False) + + pos = pos.unsqueeze(0).unsqueeze(0) + pos = pos.expand((len(envs_idx), len(links_idx), -1)) + self.solver.set_base_links_pos(pos, links_idx, envs_idx) + + @gs.assert_built + def set_pos_grad(self, f, pos_grad): + links_idx = self._base_links_idx + envs_idx = self.solver._scene._sanitize_envs_idx(None, unsafe=False) + + tmp_pos_grad = pos_grad.clone().unsqueeze(0).unsqueeze(0) + tmp_pos_grad = tmp_pos_grad.expand((len(envs_idx), len(links_idx), -1)) + self.solver._kernel_set_links_pos_grad(f, tmp_pos_grad, links_idx, envs_idx) + pos_grad.data = tmp_pos_grad.sum(dim=0).sum(dim=0) + + @gs.assert_built + def set_quat(self, quat): + """ + Set the quaternion of the entity's base link. + + TODO: Compare with the original set_quat in the latest Genesis and fix backward compatibility. + """ + + links_idx = self._base_links_idx + envs_idx = self.solver._scene._sanitize_envs_idx(None, unsafe=False) + + quat = quat.unsqueeze(0).unsqueeze(0) + quat = quat.expand((len(envs_idx), len(links_idx), -1)) + self.solver.set_base_links_quat(quat, links_idx, envs_idx) + + @gs.assert_built + def set_quat_grad(self, f, quat_grad): + + links_idx = self._base_links_idx + envs_idx = self.solver._scene._sanitize_envs_idx(None, unsafe=False) + + tmp_quat_grad = quat_grad.clone().unsqueeze(0).unsqueeze(0) + tmp_quat_grad = tmp_quat_grad.expand((len(envs_idx), len(links_idx), -1)) + self.solver._kernel_set_links_quat_grad(f, tmp_quat_grad, links_idx, envs_idx) + quat_grad.data = tmp_quat_grad.sum(dim=0).sum(dim=0) + + @gs.assert_built + def set_vel(self, vel): + """ + Set the velocity of the entity's base link. + + TODO: Compare with the original set_vel in the latest Genesis and fix backward compatibility. + """ + + assert self.n_dofs == 6, "set_vel is only supported for 6 dof entities" + + dofs_idx = [i for i in range(self.dof_start, self.dof_start + 3)] + envs_idx = self.solver._scene._sanitize_envs_idx(None, unsafe=False) + + vel = vel.unsqueeze(0) + vel = vel.expand((len(envs_idx), -1)) + self.solver.set_dofs_velocity(vel, dofs_idx, envs_idx) + + @gs.assert_built + def set_vel_grad(self, f, vel_grad): + + assert self.n_dofs == 6, "set_vel is only supported for 6 dof entities" + + dofs_idx = to_gs_tensor([i for i in range(self.dof_start, self.dof_start + 3)]) + envs_idx = self.solver._scene._sanitize_envs_idx(None, unsafe=False) + + tmp_vel_grad = vel_grad.clone().unsqueeze(0) + tmp_vel_grad = tmp_vel_grad.expand((len(envs_idx), -1)) + self.solver._kernel_set_dofs_velocity_grad(f, tmp_vel_grad, dofs_idx, envs_idx) + vel_grad.data = tmp_vel_grad.sum(dim=0) + + @gs.assert_built + def set_ang(self, ang): + """ + Set the angular velocity of the entity's base link. + + TODO: Compare with the original set_ang in the latest Genesis and fix backward compatibility. + """ + + assert self.n_dofs == 6, "set_ang is only supported for 6 dof entities" + + dofs_idx = [i for i in range(self.dof_start + 3, self.dof_end)] + envs_idx = self.solver._scene._sanitize_envs_idx(None, unsafe=False) + + ang = ang.unsqueeze(0) + ang = ang.expand((len(envs_idx), -1)) + self.solver.set_dofs_velocity(ang, dofs_idx, envs_idx) + + @gs.assert_built + def set_ang_grad(self, f, ang_grad): + + assert self.n_dofs == 6, "set_ang is only supported for 6 dof entities" + + dofs_idx = to_gs_tensor([i for i in range(self.dof_start + 3, self.dof_end)]) + envs_idx = self.solver._scene._sanitize_envs_idx(None, unsafe=False) + + tmp_ang_grad = ang_grad.clone().unsqueeze(0) + tmp_ang_grad = tmp_ang_grad.expand((len(envs_idx), -1)) + self.solver._kernel_set_dofs_velocity_grad(f, tmp_ang_grad, dofs_idx, envs_idx) + ang_grad.data = tmp_ang_grad.sum(dim=0) + @gs.assert_built def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True, unsafe=False): """ @@ -2139,12 +2316,14 @@ def set_dofs_damping(self, damping, dofs_idx_local=None, envs_idx=None, *, unsaf self._solver.set_dofs_damping(damping, dofs_idx, envs_idx, unsafe=unsafe) @gs.assert_built - def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *, unsafe=False): + def set_dofs_velocity(self, f, velocity=None, dofs_idx_local=None, envs_idx=None, *, unsafe=False): """ Set the entity's dofs' velocity. Parameters ---------- + f : int + The substep index. velocity : array_like | None The velocity to set. Zero if not specified. dofs_idx_local : None | array_like, optional @@ -2153,7 +2332,7 @@ def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, * The indices of the environments. If None, all environments will be considered. Defaults to None. """ dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) - self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=False, unsafe=unsafe) + self._solver.set_dofs_velocity(f, velocity, dofs_idx, envs_idx, skip_forward=False, unsafe=unsafe) @gs.assert_built def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None, *, unsafe=False): @@ -2457,17 +2636,19 @@ def get_mass_mat(self, envs_idx=None, decompose=False, *, unsafe=False): return self._solver.get_mass_mat(dofs_idx, envs_idx, decompose, unsafe=unsafe) @gs.assert_built - def zero_all_dofs_velocity(self, envs_idx=None, *, unsafe=False): + def zero_all_dofs_velocity(self, f, envs_idx=None, *, unsafe=False): """ Zero the velocity of all the entity's dofs. Parameters ---------- + f : int + The substep index. envs_idx : None | array_like, optional The indices of the environments. If None, all environments will be considered. Defaults to None. """ dofs_idx_local = torch.arange(self.n_dofs, dtype=gs.tc_int, device=gs.device) - self.set_dofs_velocity(None, dofs_idx_local, envs_idx, unsafe=unsafe) + self.set_dofs_velocity(f, None, dofs_idx_local, envs_idx, unsafe=unsafe) @gs.assert_built def detect_collision(self, env_idx=0): @@ -2717,6 +2898,44 @@ def get_mass(self): mass += link.get_mass() return mass + @gs.assert_built + def reset_grad(self): + for key in self._tgt_keys: + self._tgt_buffer[key].clear() + self._queried_states.clear() + + @gs.assert_built + def collect_output_grads(self): + """ + Collect gradients from external queried states. + """ + if self._sim.cur_step_global in self._queried_states: + # one step could have multiple states + for state in self._queried_states[self._sim.cur_step_global]: + self.add_grad_from_state(state) + + @gs.assert_built + def add_grad_from_state(self, state): + if state.pos.grad is not None: + state.pos.assert_contiguous() + self.set_frame_add_grad_pos(self._sim.cur_substep_local, state.pos.grad) + + if state.quat.grad is not None: + state.quat.assert_contiguous() + self.set_frame_add_grad_quat(self._sim.cur_substep_local, state.quat.grad) + + @ti.kernel + def set_frame_add_grad_pos(self, f: ti.i32, pos_grad: ti.types.ndarray()): + for i_b in ti.ndrange(self._solver._B): + for j in ti.static(range(3)): + self._solver.links_state.pos.grad[f, self.base_link_idx, i_b][j] += pos_grad[i_b, j] + + @ti.kernel + def set_frame_add_grad_quat(self, f: ti.i32, quat_grad: ti.types.ndarray()): + for i_b in ti.ndrange(self._solver._B): + for j in ti.static(range(4)): + self._solver.links_state.quat.grad[f, self.base_link_idx, i_b][j] += quat_grad[i_b, j] + # ------------------------------------------------------------------------------------ # ----------------------------------- properties ------------------------------------- # ------------------------------------------------------------------------------------ diff --git a/genesis/engine/simulator.py b/genesis/engine/simulator.py index ea8871d3f..d066ef67e 100644 --- a/genesis/engine/simulator.py +++ b/genesis/engine/simulator.py @@ -267,18 +267,13 @@ def f_global_to_s_global(self, f_global): # ------------------------------------------------------------------------------------ def step(self, in_backward=False): - if self._rigid_only: # "Only Advance!" --Thomas Wade :P - for _ in range(self._substeps): - self.rigid_solver.substep() - self._cur_substep_global += 1 - else: - self.process_input(in_backward=in_backward) - for _ in range(self._substeps): - self.substep(self.cur_substep_local) + self.process_input(in_backward=in_backward) + for _ in range(self._substeps): + self.substep(self.cur_substep_local) - self._cur_substep_global += 1 - if self.cur_substep_local == 0 and not in_backward: - self.save_ckpt() + self._cur_substep_global += 1 + if self.cur_substep_local == 0 and not in_backward: + self.save_ckpt() if self.rigid_solver.is_active(): self.rigid_solver.clear_external_force() diff --git a/genesis/engine/solvers/rigid/collider_decomp.py b/genesis/engine/solvers/rigid/collider_decomp.py index a4b1e9448..dcd88f78c 100644 --- a/genesis/engine/solvers/rigid/collider_decomp.py +++ b/genesis/engine/solvers/rigid/collider_decomp.py @@ -307,12 +307,13 @@ def clear(self, envs_idx=None): self._collider_state, ) - def detection(self) -> None: + def detection(self, f) -> None: # from genesis.utils.tools import create_timer self._contacts_info_cache = {} # timer = create_timer(name="69477ab0-5e75-47cb-a4a5-d4eebd9336ca", level=3, ti_sync=True, skip_first_call=True) rigid_solver.kernel_update_geom_aabbs( + f, self._solver.geoms_state, self._solver.geoms_init_AABB, self._solver._static_rigid_sim_config, @@ -320,6 +321,7 @@ def detection(self) -> None: ) # timer.stamp("func_update_aabbs") func_broad_phase( + f, self._solver.links_state, self._solver.links_info, self._solver.geoms_state, @@ -334,6 +336,7 @@ def detection(self) -> None: ) # timer.stamp("func_broad_phase") func_narrow_phase_convex_vs_convex( + f, self._solver.links_state, self._solver.links_info, self._solver.geoms_state, @@ -356,6 +359,7 @@ def detection(self) -> None: self._support_field._support_field_static_config, ) func_narrow_phase_convex_specializations( + f, self._solver.geoms_state, self._solver.geoms_info, self._solver.geoms_init_AABB, @@ -387,6 +391,7 @@ def detection(self) -> None: # timer.stamp("func_narrow_phase_any_vs_terrain") if self._collider_static_config.has_nonconvex_nonterrain: func_narrow_phase_nonconvex_vs_nonterrain( + f, self._solver.links_state, self._solver.links_info, self._solver.geoms_state, @@ -632,17 +637,19 @@ def collider_kernel_get_contacts( @ti.func def func_point_in_geom_aabb( + i_f, i_g, i_b, geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, point: ti.types.vector(3), ): - return (point < geoms_state.aabb_max[i_g, i_b]).all() and (point > geoms_state.aabb_min[i_g, i_b]).all() + return (point < geoms_state.aabb_max[i_f, i_g, i_b]).all() and (point > geoms_state.aabb_min[i_f, i_g, i_b]).all() @ti.func def func_is_geom_aabbs_overlap( + f, i_ga, i_gb, i_b, @@ -650,8 +657,8 @@ def func_is_geom_aabbs_overlap( geoms_info: array_class.GeomsInfo, ): return not ( - (geoms_state.aabb_max[i_ga, i_b] <= geoms_state.aabb_min[i_gb, i_b]).any() - or (geoms_state.aabb_min[i_ga, i_b] >= geoms_state.aabb_max[i_gb, i_b]).any() + (geoms_state.aabb_max[f, i_ga, i_b] <= geoms_state.aabb_min[f, i_gb, i_b]).any() + or (geoms_state.aabb_min[f, i_ga, i_b] >= geoms_state.aabb_max[f, i_gb, i_b]).any() ) @@ -701,6 +708,7 @@ def func_contact_sphere_sdf( @ti.func def func_contact_vertex_sdf( + i_f, i_ga, i_gb, i_b, @@ -710,8 +718,8 @@ def func_contact_vertex_sdf( collider_static_config: ti.template(), sdf_info: array_class.SDFInfo, ): - ga_pos = geoms_state.pos[i_ga, i_b] - ga_quat = geoms_state.quat[i_ga, i_b] + ga_pos = geoms_state.pos[i_f, i_ga, i_b] + ga_quat = geoms_state.quat[i_f, i_ga, i_b] is_col = False penetration = gs.ti_float(0.0) @@ -720,8 +728,8 @@ def func_contact_vertex_sdf( for i_v in range(geoms_info.vert_start[i_ga], geoms_info.vert_end[i_ga]): vertex_pos = gu.ti_transform_by_trans_quat(verts_info.init_pos[i_v], ga_pos, ga_quat) - if func_point_in_geom_aabb(i_gb, i_b, geoms_state, geoms_info, vertex_pos): - new_penetration = -sdf.sdf_func_world(geoms_state, geoms_info, sdf_info, vertex_pos, i_gb, i_b) + if func_point_in_geom_aabb(i_f, i_gb, i_b, geoms_state, geoms_info, vertex_pos): + new_penetration = -sdf.sdf_func_world(geoms_state, geoms_info, sdf_info, vertex_pos, i_f, i_gb, i_b) if new_penetration > penetration: is_col = True contact_pos = vertex_pos @@ -730,7 +738,7 @@ def func_contact_vertex_sdf( # Compute contact normal only once, and only in case of contact if is_col: normal = sdf.sdf_func_normal_world( - geoms_state, geoms_info, collider_static_config, sdf_info, contact_pos, i_gb, i_b + geoms_state, geoms_info, collider_static_config, sdf_info, contact_pos, i_f, i_gb, i_b ) # The contact point must be offsetted by half the penetration depth @@ -741,6 +749,7 @@ def func_contact_vertex_sdf( @ti.func def func_contact_edge_sdf( + i_f, i_ga, i_gb, i_b, @@ -766,26 +775,26 @@ def func_contact_edge_sdf( i_v1 = edges_info.v1[i_e] p_0 = gu.ti_transform_by_trans_quat( - verts_info.init_pos[i_v0], geoms_state.pos[i_ga, i_b], geoms_state.quat[i_ga, i_b] + verts_info.init_pos[i_v0], geoms_state.pos[i_f, i_ga, i_b], geoms_state.quat[i_f, i_ga, i_b] ) p_1 = gu.ti_transform_by_trans_quat( - verts_info.init_pos[i_v1], geoms_state.pos[i_ga, i_b], geoms_state.quat[i_ga, i_b] + verts_info.init_pos[i_v1], geoms_state.pos[i_f, i_ga, i_b], geoms_state.quat[i_f, i_ga, i_b] ) vec_01 = gu.ti_normalize(p_1 - p_0) sdf_grad_0_b = sdf.sdf_func_grad_world( - geoms_state, geoms_info, collider_static_config, sdf_info, p_0, i_gb, i_b + geoms_state, geoms_info, collider_static_config, sdf_info, p_0, i_f, i_gb, i_b ) sdf_grad_1_b = sdf.sdf_func_grad_world( - geoms_state, geoms_info, collider_static_config, sdf_info, p_1, i_gb, i_b + geoms_state, geoms_info, collider_static_config, sdf_info, p_1, i_f, i_gb, i_b ) # check if the edge on a is facing towards mesh b (I am not 100% sure about this, subject to removal) sdf_grad_0_a = sdf.sdf_func_grad_world( - geoms_state, geoms_info, collider_static_config, sdf_info, p_0, i_ga, i_b + geoms_state, geoms_info, collider_static_config, sdf_info, p_0, i_f, i_ga, i_b ) sdf_grad_1_a = sdf.sdf_func_grad_world( - geoms_state, geoms_info, collider_static_config, sdf_info, p_1, i_ga, i_b + geoms_state, geoms_info, collider_static_config, sdf_info, p_1, i_f, i_ga, i_b ) normal_edge_0 = sdf_grad_0_a - sdf_grad_0_a.dot(vec_01) * vec_01 normal_edge_1 = sdf_grad_1_a - sdf_grad_1_a.dot(vec_01) * vec_01 @@ -799,7 +808,7 @@ def func_contact_edge_sdf( p_mid = 0.5 * (p_0 + p_1) if ( sdf.sdf_func_grad_world( - geoms_state, geoms_info, collider_static_config, sdf_info, p_mid, i_gb, i_b + geoms_state, geoms_info, collider_static_config, sdf_info, p_mid, i_f, i_gb, i_b ).dot(vec_01) < 0 ): @@ -809,12 +818,12 @@ def func_contact_edge_sdf( cur_length = 0.5 * cur_length p = 0.5 * (p_0 + p_1) - new_penetration = -sdf.sdf_func_world(geoms_state, geoms_info, sdf_info, p, i_gb, i_b) + new_penetration = -sdf.sdf_func_world(geoms_state, geoms_info, sdf_info, p, i_f, i_gb, i_b) if new_penetration > penetration: is_col = True normal = sdf.sdf_func_normal_world( - geoms_state, geoms_info, collider_static_config, sdf_info, p, i_gb, i_b + geoms_state, geoms_info, collider_static_config, sdf_info, p, i_f, i_gb, i_b ) contact_pos = p penetration = new_penetration @@ -1125,6 +1134,7 @@ def func_add_prism_vert( @ti.func def func_check_collision_valid( + f, i_ga, i_gb, i_b, @@ -1155,8 +1165,8 @@ def func_check_collision_valid( I_la = [i_la, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_la I_lb = [i_lb, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_lb - if (links_state.hibernated[i_la, i_b] and links_info.is_fixed[I_lb]) or ( - links_state.hibernated[i_lb, i_b] and links_info.is_fixed[I_la] + if (links_state.hibernated[f, i_la, i_b] and links_info.is_fixed[I_lb]) or ( + links_state.hibernated[f, i_lb, i_b] and links_info.is_fixed[I_la] ): is_valid = False @@ -1166,6 +1176,7 @@ def func_check_collision_valid( @gs.maybe_pure @ti.kernel def func_broad_phase( + f: ti.i32, links_state: array_class.LinksState, links_info: array_class.LinksInfo, geoms_state: array_class.GeomsState, @@ -1195,16 +1206,16 @@ def func_broad_phase( # copy updated geom aabbs to buffer for sorting if collider_state.first_time[i_b]: for i in range(n_geoms): - collider_state.sort_buffer.value[2 * i, i_b] = geoms_state.aabb_min[i, i_b][axis] + collider_state.sort_buffer.value[2 * i, i_b] = geoms_state.aabb_min[f, i, i_b][axis] collider_state.sort_buffer.i_g[2 * i, i_b] = i collider_state.sort_buffer.is_max[2 * i, i_b] = 0 - collider_state.sort_buffer.value[2 * i + 1, i_b] = geoms_state.aabb_max[i, i_b][axis] + collider_state.sort_buffer.value[2 * i + 1, i_b] = geoms_state.aabb_max[f, i, i_b][axis] collider_state.sort_buffer.i_g[2 * i + 1, i_b] = i collider_state.sort_buffer.is_max[2 * i + 1, i_b] = 1 - geoms_state.min_buffer_idx[i, i_b] = 2 * i - geoms_state.max_buffer_idx[i, i_b] = 2 * i + 1 + geoms_state.min_buffer_idx[f, i, i_b] = 2 * i + geoms_state.max_buffer_idx[f, i, i_b] = 2 * i + 1 collider_state.first_time[i_b] = False @@ -1214,11 +1225,11 @@ def func_broad_phase( for i in range(n_geoms * 2): if collider_state.sort_buffer.is_max[i, i_b]: collider_state.sort_buffer.value[i, i_b] = geoms_state.aabb_max[ - collider_state.sort_buffer.i_g[i, i_b], i_b + f, collider_state.sort_buffer.i_g[i, i_b], i_b ][axis] else: collider_state.sort_buffer.value[i, i_b] = geoms_state.aabb_min[ - collider_state.sort_buffer.i_g[i, i_b], i_b + f, collider_state.sort_buffer.i_g[i, i_b], i_b ][axis] # insertion sort, which has complexity near O(n) for nearly sorted array @@ -1235,9 +1246,9 @@ def func_broad_phase( if ti.static(static_rigid_sim_config.use_hibernation): if collider_state.sort_buffer.is_max[j, i_b]: - geoms_state.max_buffer_idx[collider_state.sort_buffer.i_g[j, i_b], i_b] = j + 1 + geoms_state.max_buffer_idx[f, collider_state.sort_buffer.i_g[j, i_b], i_b] = j + 1 else: - geoms_state.min_buffer_idx[collider_state.sort_buffer.i_g[j, i_b], i_b] = j + 1 + geoms_state.min_buffer_idx[f, collider_state.sort_buffer.i_g[j, i_b], i_b] = j + 1 j -= 1 collider_state.sort_buffer.value[j + 1, i_b] = key_value @@ -1246,11 +1257,12 @@ def func_broad_phase( if ti.static(static_rigid_sim_config.use_hibernation): if key_is_max: - geoms_state.max_buffer_idx[key_i_g, i_b] = j + 1 + geoms_state.max_buffer_idx[f, key_i_g, i_b] = j + 1 else: - geoms_state.min_buffer_idx[key_i_g, i_b] = j + 1 + geoms_state.min_buffer_idx[f, key_i_g, i_b] = j + 1 # sweep over the sorted AABBs to find potential collision pairs + # TODO: merge hibernation code collider_state.n_broad_pairs[i_b] = 0 if ti.static(not static_rigid_sim_config.use_hibernation): n_active = 0 @@ -1263,6 +1275,7 @@ def func_broad_phase( i_ga, i_gb = i_gb, i_ga if not func_check_collision_valid( + f, i_ga, i_gb, i_b, @@ -1276,7 +1289,7 @@ def func_broad_phase( ): continue - if not func_is_geom_aabbs_overlap(i_ga, i_gb, i_b, geoms_state, geoms_info): + if not func_is_geom_aabbs_overlap(f, i_ga, i_gb, i_b, geoms_state, geoms_info): # Clear collision normal cache if not in contact if ti.static(not static_rigid_sim_config.enable_mujoco_compatibility): # self.contact_cache[i_ga, i_gb, i_b].i_va_ws = -1 @@ -1307,7 +1320,7 @@ def func_broad_phase( n_active_awake = 0 n_active_hib = 0 for i in range(2 * n_geoms): - is_incoming_geom_hibernated = geoms_state.hibernated[collider_state.sort_buffer.i_g[i, i_b], i_b] + is_incoming_geom_hibernated = geoms_state.hibernated[f, collider_state.sort_buffer.i_g[i, i_b], i_b] if not collider_state.sort_buffer.is_max[i, i_b]: # both awake and hibernated geom check with active awake geoms @@ -1318,6 +1331,7 @@ def func_broad_phase( i_ga, i_gb = i_gb, i_ga if not func_check_collision_valid( + f, i_ga, i_gb, i_b, @@ -1331,7 +1345,7 @@ def func_broad_phase( ): continue - if not func_is_geom_aabbs_overlap(i_ga, i_gb, i_b, geoms_state, geoms_info): + if not func_is_geom_aabbs_overlap(f, i_ga, i_gb, i_b, geoms_state, geoms_info): # Clear collision normal cache if not in contact if ti.static(not static_rigid_sim_config.enable_mujoco_compatibility): # self.contact_cache[i_ga, i_gb, i_b].i_va_ws = -1 @@ -1353,6 +1367,7 @@ def func_broad_phase( i_ga, i_gb = i_gb, i_ga if not func_check_collision_valid( + f, i_ga, i_gb, i_b, @@ -1366,7 +1381,7 @@ def func_broad_phase( ): continue - if not func_is_geom_aabbs_overlap(i_ga, i_gb, i_b, geoms_state, geoms_info): + if not func_is_geom_aabbs_overlap(f, i_ga, i_gb, i_b, geoms_state, geoms_info): # Clear collision normal cache if not in contact # self.contact_cache[i_ga, i_gb, i_b].i_va_ws = -1 collider_state.contact_cache.normal[i_ga, i_gb, i_b] = ti.Vector.zero( @@ -1413,6 +1428,7 @@ def func_broad_phase( @gs.maybe_pure @ti.kernel def func_narrow_phase_convex_vs_convex( + i_f: ti.i32, links_state: array_class.LinksState, links_info: array_class.LinksInfo, geoms_state: array_class.GeomsState, @@ -1465,6 +1481,7 @@ def func_narrow_phase_convex_vs_convex( ): if ti.static(sys.platform == "darwin"): func_convex_convex_contact( + i_f=i_f, i_ga=i_ga, i_gb=i_gb, i_b=i_b, @@ -1491,6 +1508,7 @@ def func_narrow_phase_convex_vs_convex( else: if not (geoms_info.type[i_ga] == gs.GEOM_TYPE.PLANE and geoms_info.type[i_gb] == gs.GEOM_TYPE.BOX): func_convex_convex_contact( + i_f=i_f, i_ga=i_ga, i_gb=i_gb, i_b=i_b, @@ -1519,6 +1537,7 @@ def func_narrow_phase_convex_vs_convex( @gs.maybe_pure @ti.kernel def func_narrow_phase_convex_specializations( + i_f: ti.i32, geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, geoms_init_AABB: array_class.GeomsInitAABB, @@ -1543,6 +1562,7 @@ def func_narrow_phase_convex_specializations( if ti.static(sys.platform != "darwin"): if geoms_info.type[i_ga] == gs.GEOM_TYPE.PLANE and geoms_info.type[i_gb] == gs.GEOM_TYPE.BOX: func_plane_box_contact( + i_f, i_ga, i_gb, i_b, @@ -1560,6 +1580,7 @@ def func_narrow_phase_convex_specializations( if ti.static(static_rigid_sim_config.box_box_detection): if geoms_info.type[i_ga] == gs.GEOM_TYPE.BOX and geoms_info.type[i_gb] == gs.GEOM_TYPE.BOX: func_box_box_contact( + i_f, i_ga, i_gb, i_b, @@ -1627,6 +1648,7 @@ def func_narrow_phase_any_vs_terrain( @gs.maybe_pure @ti.kernel def func_narrow_phase_nonconvex_vs_nonterrain( + i_f: ti.i32, links_state: array_class.LinksState, links_info: array_class.LinksInfo, geoms_state: array_class.GeomsState, @@ -1674,6 +1696,7 @@ def func_narrow_phase_nonconvex_vs_nonterrain( contact_pos_i = ti.Vector.zero(gs.ti_float, 3) if not is_col: is_col_i, normal_i, penetration_i, contact_pos_i = func_contact_vertex_sdf( + i_f, i_ga, i_gb, i_b, @@ -1699,11 +1722,12 @@ def func_narrow_phase_nonconvex_vs_nonterrain( if ti.static(static_rigid_sim_config.enable_multi_contact): if not is_col and is_col_i: - ga_pos, ga_quat = geoms_state.pos[i_ga, i_b], geoms_state.quat[i_ga, i_b] - gb_pos, gb_quat = geoms_state.pos[i_gb, i_b], geoms_state.quat[i_gb, i_b] + ga_pos, ga_quat = geoms_state.pos[i_f, i_ga, i_b], geoms_state.quat[i_f, i_ga, i_b] + gb_pos, gb_quat = geoms_state.pos[i_f, i_gb, i_b], geoms_state.quat[i_f, i_gb, i_b] # Perturb geom_a around two orthogonal axes to find multiple contacts axis_0, axis_1 = func_contact_orthogonals( + i_f, i_ga, i_gb, normal_i, @@ -1721,12 +1745,13 @@ def func_narrow_phase_nonconvex_vs_nonterrain( axis = (2 * (i_rot % 2) - 1) * axis_0 + (1 - 2 * ((i_rot // 2) % 2)) * axis_1 qrot = gu.ti_rotvec_to_quat(collider_static_config.mc_perturbation * axis) - func_rotate_frame(i_ga, contact_pos_i, qrot, i_b, geoms_state, geoms_info) + func_rotate_frame(i_f, i_ga, contact_pos_i, qrot, i_b, geoms_state, geoms_info) func_rotate_frame( - i_gb, contact_pos_i, gu.ti_inv_quat(qrot), i_b, geoms_state, geoms_info + i_f, i_gb, contact_pos_i, gu.ti_inv_quat(qrot), i_b, geoms_state, geoms_info ) is_col, normal, penetration, contact_pos = func_contact_vertex_sdf( + i_f, i_ga, i_gb, i_b, @@ -1795,11 +1820,12 @@ def func_narrow_phase_nonconvex_vs_nonterrain( ) n_con += 1 - geoms_state.pos[i_ga, i_b], geoms_state.quat[i_ga, i_b] = ga_pos, ga_quat - geoms_state.pos[i_gb, i_b], geoms_state.quat[i_gb, i_b] = gb_pos, gb_quat + geoms_state.pos[i_f, i_ga, i_b], geoms_state.quat[i_f, i_ga, i_b] = ga_pos, ga_quat + geoms_state.pos[i_f, i_gb, i_b], geoms_state.quat[i_f, i_gb, i_b] = gb_pos, gb_quat if not is_col: # check edge-edge if vertex-face is not detected is_col, normal, penetration, contact_pos = func_contact_edge_sdf( + i_f, i_ga, i_gb, i_b, @@ -1827,6 +1853,7 @@ def func_narrow_phase_nonconvex_vs_nonterrain( @ti.func def func_plane_box_contact( + i_f, i_ga, i_gb, i_b, @@ -1840,8 +1867,8 @@ def func_plane_box_contact( collider_info: array_class.ColliderInfo, collider_static_config: ti.template(), ): - ga_pos, ga_quat = geoms_state.pos[i_ga, i_b], geoms_state.quat[i_ga, i_b] - gb_pos, gb_quat = geoms_state.pos[i_gb, i_b], geoms_state.quat[i_gb, i_b] + ga_pos, ga_quat = geoms_state.pos[i_f, i_ga, i_b], geoms_state.quat[i_f, i_ga, i_b] + gb_pos, gb_quat = geoms_state.pos[i_f, i_gb, i_b], geoms_state.quat[i_f, i_gb, i_b] plane_dir = ti.Vector( [geoms_info.data[i_ga][0], geoms_info.data[i_ga][1], geoms_info.data[i_ga][2]], dt=gs.ti_float @@ -1849,7 +1876,7 @@ def func_plane_box_contact( plane_dir = gu.ti_transform_by_quat(plane_dir, ga_quat) normal = -plane_dir.normalized() - v1, _ = support_field._func_support_box(geoms_state, geoms_info, normal, i_gb, i_b) + v1, _ = support_field._func_support_box(geoms_state, geoms_info, normal, i_f, i_gb, i_b) penetration = normal.dot(v1 - ga_pos) if penetration > 0.0: @@ -1956,6 +1983,7 @@ def func_compute_tolerance( @ti.func def func_contact_orthogonals( + i_f, i_ga, i_gb, normal: ti.types.vector(3), @@ -1997,7 +2025,7 @@ def func_contact_orthogonals( # Compute orthogonal basis mixing principal inertia axes of geometry with contact normal i_l = geoms_info.link_idx[i_g] - rot = gu.ti_quat_to_R(links_state.i_quat[i_l, i_b]) + rot = gu.ti_quat_to_R(links_state.i_quat[i_f, i_l, i_b]) axis_idx = gs.ti_int(0) axis_angle_max = gs.ti_float(0.0) for i in ti.static(range(3)): @@ -2015,6 +2043,7 @@ def func_contact_orthogonals( @ti.func def func_convex_convex_contact( + i_f, i_ga, i_gb, i_b, @@ -2041,6 +2070,7 @@ def func_convex_convex_contact( if geoms_info.type[i_ga] == gs.GEOM_TYPE.PLANE and geoms_info.type[i_gb] == gs.GEOM_TYPE.BOX: if ti.static(sys.platform == "darwin"): func_plane_box_contact( + i_f=i_f, i_ga=i_ga, i_gb=i_gb, i_b=i_b, @@ -2070,8 +2100,8 @@ def func_convex_convex_contact( tolerance = func_compute_tolerance(i_ga, i_gb, i_b, geoms_info, geoms_init_AABB, collider_static_config) # Backup state before local perturbation - ga_pos, ga_quat = geoms_state.pos[i_ga, i_b], geoms_state.quat[i_ga, i_b] - gb_pos, gb_quat = geoms_state.pos[i_gb, i_b], geoms_state.quat[i_gb, i_b] + ga_pos, ga_quat = geoms_state.pos[i_f, i_ga, i_b], geoms_state.quat[i_f, i_ga, i_b] + gb_pos, gb_quat = geoms_state.pos[i_f, i_gb, i_b], geoms_state.quat[i_f, i_gb, i_b] # Pre-allocate some buffers is_col_0 = False @@ -2095,8 +2125,8 @@ def func_convex_convex_contact( # otherwise it would be more sensitive to ill-conditionning. axis = (2 * (i_detection % 2) - 1) * axis_0 + (1 - 2 * ((i_detection // 2) % 2)) * axis_1 qrot = gu.ti_rotvec_to_quat(collider_static_config.mc_perturbation * axis) - func_rotate_frame(i_ga, contact_pos_0, qrot, i_b, geoms_state, geoms_info) - func_rotate_frame(i_gb, contact_pos_0, gu.ti_inv_quat(qrot), i_b, geoms_state, geoms_info) + func_rotate_frame(i_f, i_ga, contact_pos_0, qrot, i_b, geoms_state, geoms_info) + func_rotate_frame(i_f, i_gb, contact_pos_0, gu.ti_inv_quat(qrot), i_b, geoms_state, geoms_info) if (multi_contact and is_col_0) or (i_detection == 0): try_sdf = False @@ -2105,7 +2135,7 @@ def func_convex_convex_contact( [geoms_info.data[i_ga][0], geoms_info.data[i_ga][1], geoms_info.data[i_ga][2]], dt=gs.ti_float, ) - plane_dir = gu.ti_transform_by_quat(plane_dir, geoms_state.quat[i_ga, i_b]) + plane_dir = gu.ti_transform_by_quat(plane_dir, geoms_state.quat[i_f, i_ga, i_b]) normal = -plane_dir.normalized() v1 = mpr.support_driver( @@ -2117,10 +2147,11 @@ def func_convex_convex_contact( support_field_info, support_field_static_config, normal, + i_f, i_gb, i_b, ) - penetration = normal.dot(v1 - geoms_state.pos[i_ga, i_b]) + penetration = normal.dot(v1 - geoms_state.pos[i_f, i_ga, i_b]) contact_pos = v1 - 0.5 * penetration * normal is_col = penetration > 0 else: @@ -2156,6 +2187,7 @@ def func_convex_convex_contact( mpr_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -2192,6 +2224,7 @@ def func_convex_convex_contact( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -2252,6 +2285,7 @@ def func_convex_convex_contact( # ) # self.contact_cache[i_ga, i_gb, i_b].i_va_ws = i_va is_col_i, normal_i, penetration_i, contact_pos_i = func_contact_vertex_sdf( + i_f, i_ga if i_sdf == 0 else i_gb, i_gb if i_sdf == 0 else i_ga, i_b, @@ -2313,6 +2347,7 @@ def func_convex_convex_contact( if multi_contact: # perturb geom_a around two orthogonal axes to find multiple contacts axis_0, axis_1 = func_contact_orthogonals( + i_f, i_ga, i_gb, normal, @@ -2406,14 +2441,15 @@ def func_convex_convex_contact( ) n_con = n_con + 1 - geoms_state.pos[i_ga, i_b] = ga_pos - geoms_state.quat[i_ga, i_b] = ga_quat - geoms_state.pos[i_gb, i_b] = gb_pos - geoms_state.quat[i_gb, i_b] = gb_quat + geoms_state.pos[i_f, i_ga, i_b] = ga_pos + geoms_state.quat[i_f, i_ga, i_b] = ga_quat + geoms_state.pos[i_f, i_gb, i_b] = gb_pos + geoms_state.quat[i_f, i_gb, i_b] = gb_quat @ti.func def func_rotate_frame( + i_f, i_g, contact_pos: ti.types.vector(3), qrot: ti.types.vector(4), @@ -2421,16 +2457,17 @@ def func_rotate_frame( geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, ): - geoms_state.quat[i_g, i_b] = gu.ti_transform_quat_by_quat(geoms_state.quat[i_g, i_b], qrot) + geoms_state.quat[i_f, i_g, i_b] = gu.ti_transform_quat_by_quat(geoms_state.quat[i_f, i_g, i_b], qrot) - rel = contact_pos - geoms_state.pos[i_g, i_b] + rel = contact_pos - geoms_state.pos[i_f, i_g, i_b] vec = gu.ti_transform_by_quat(rel, qrot) vec = vec - rel - geoms_state.pos[i_g, i_b] = geoms_state.pos[i_g, i_b] - vec + geoms_state.pos[i_f, i_g, i_b] = geoms_state.pos[i_f, i_g, i_b] - vec @ti.func def func_box_box_contact( + i_f, i_ga, i_gb, i_b, @@ -2459,10 +2496,10 @@ def func_box_box_contact( margin2 = margin * margin rotmore = ti.Matrix.zero(gs.ti_float, 3, 3) - ga_pos = geoms_state.pos[i_ga, i_b] - gb_pos = geoms_state.pos[i_gb, i_b] - ga_quat = geoms_state.quat[i_ga, i_b] - gb_quat = geoms_state.quat[i_gb, i_b] + ga_pos = geoms_state.pos[i_f, i_ga, i_b] + gb_pos = geoms_state.pos[i_f, i_gb, i_b] + ga_quat = geoms_state.quat[i_f, i_ga, i_b] + gb_quat = geoms_state.quat[i_f, i_gb, i_b] size1 = ( ti.Vector([geoms_info.data[i_ga][0], geoms_info.data[i_ga][1], geoms_info.data[i_ga][2]], dt=gs.ti_float) / 2 diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py index f200c31f8..04c889bfb 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py @@ -188,8 +188,9 @@ def reset(self, envs_idx=None): static_rigid_sim_cache_key=self._solver._static_rigid_sim_cache_key, ) - def add_equality_constraints(self): + def add_equality_constraints(self, f): add_equality_constraints( + i_f=f, links_info=self._solver.links_info, links_state=self._solver.links_state, dofs_state=self._solver.dofs_state, @@ -203,8 +204,9 @@ def add_equality_constraints(self): static_rigid_sim_cache_key=self._solver._static_rigid_sim_cache_key, ) - def add_frictionloss_constraints(self): + def add_frictionloss_constraints(self, f): add_frictionloss_constraints( + i_f=f, links_info=self._solver.links_info, joints_info=self._solver.joints_info, dofs_info=self._solver.dofs_info, @@ -215,8 +217,9 @@ def add_frictionloss_constraints(self): static_rigid_sim_cache_key=self._solver._static_rigid_sim_cache_key, ) - def add_collision_constraints(self): + def add_collision_constraints(self, f): add_collision_constraints( + i_f=f, links_info=self._solver.links_info, links_state=self._solver.links_state, dofs_state=self._solver.dofs_state, @@ -226,8 +229,9 @@ def add_collision_constraints(self): static_rigid_sim_cache_key=self._solver._static_rigid_sim_cache_key, ) - def add_joint_limit_constraints(self): + def add_joint_limit_constraints(self, f): add_joint_limit_constraints( + i_f=f, links_info=self._solver.links_info, joints_info=self._solver.joints_info, dofs_info=self._solver.dofs_info, @@ -238,7 +242,7 @@ def add_joint_limit_constraints(self): static_rigid_sim_cache_key=self._solver._static_rigid_sim_cache_key, ) - def resolve(self): + def resolve(self, f): # Early return if there is nothing to solve if not self._solver._enable_collision and not self._solver._enable_joint_limit: has_equality_constraints = np.any(self.constraint_state.ti_n_equalities.to_numpy()) @@ -249,6 +253,7 @@ def resolve(self): # timer = create_timer(name="resolve", level=3, ti_sync=True, skip_first_call=True) func_init_solver( + i_f=f, dofs_state=self._solver.dofs_state, entities_info=self._solver.entities_info, constraint_state=self.constraint_state, @@ -256,8 +261,10 @@ def resolve(self): static_rigid_sim_config=self._solver._static_rigid_sim_config, static_rigid_sim_cache_key=self._solver._static_rigid_sim_cache_key, ) + print("before solve qacc[2, 0]:", self.constraint_state.qacc.to_numpy()[2, 0]) # timer.stamp("_func_init_solver") func_solve( + i_f=f, entities_info=self._solver.entities_info, dofs_state=self._solver.dofs_state, constraint_state=self.constraint_state, @@ -265,20 +272,24 @@ def resolve(self): static_rigid_sim_config=self._solver._static_rigid_sim_config, static_rigid_sim_cache_key=self._solver._static_rigid_sim_cache_key, ) + print("after solve qacc[2, 0]:", self.constraint_state.qacc.to_numpy()[2, 0]) # timer.stamp("_func_solve") func_update_qacc( + i_f=f, dofs_state=self._solver.dofs_state, constraint_state=self.constraint_state, static_rigid_sim_config=self._solver._static_rigid_sim_config, static_rigid_sim_cache_key=self._solver._static_rigid_sim_cache_key, ) # timer.stamp("_func_update_qacc") + print("after update dofs acc[2, 0]:", self._solver.dofs_state.acc.to_numpy()[f, 2, 0]) if self._solver._static_rigid_sim_config.noslip_iterations > 0: self.noslip() func_update_contact_force( + i_f=f, links_state=self._solver.links_state, collider_state=self._collider._collider_state, constraint_state=self.constraint_state, @@ -490,6 +501,7 @@ def constraint_solver_kernel_reset( @gs.maybe_pure @ti.kernel def add_collision_constraints( + i_f: ti.i32, links_info: array_class.LinksInfo, links_state: array_class.LinksState, dofs_state: array_class.DofsState, @@ -498,8 +510,8 @@ def add_collision_constraints( static_rigid_sim_config: ti.template(), static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): - _B = dofs_state.ctrl_mode.shape[1] - n_dofs = dofs_state.ctrl_mode.shape[0] + _B = dofs_state.ctrl_mode.shape[2] + n_dofs = dofs_state.ctrl_mode.shape[1] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_b in range(_B): @@ -553,16 +565,16 @@ def add_collision_constraints( for i_d_ in range(links_info.n_dofs[link_maybe_batch]): i_d = links_info.dof_end[link_maybe_batch] - 1 - i_d_ - cdof_ang = dofs_state.cdof_ang[i_d, i_b] - cdot_vel = dofs_state.cdof_vel[i_d, i_b] + cdof_ang = dofs_state.cdof_ang[i_f, i_d, i_b] + cdot_vel = dofs_state.cdof_vel[i_f, i_d, i_b] t_quat = gu.ti_identity_quat() - t_pos = contact_data_pos - links_state.root_COM[link, i_b] + t_pos = contact_data_pos - links_state.root_COM[i_f, link, i_b] _, vel = gu.ti_transform_motion_by_trans_quat(cdof_ang, cdot_vel, t_pos, t_quat) diff = sign * vel jac = diff @ n - jac_qvel = jac_qvel + jac * dofs_state.vel[i_d, i_b] + jac_qvel = jac_qvel + jac * dofs_state.vel[i_f, i_d, i_b] constraint_state.jac[n_con, i_d, i_b] = constraint_state.jac[n_con, i_d, i_b] + jac if ti.static(static_rigid_sim_config.sparse_solve): @@ -588,6 +600,7 @@ def add_collision_constraints( @ti.func def func_equality_connect( + i_f, i_b, i_e, links_info: array_class.LinksInfo, @@ -598,7 +611,7 @@ def func_equality_connect( collider_state: array_class.ColliderState, static_rigid_sim_config: ti.template(), ): - n_dofs = dofs_state.ctrl_mode.shape[0] + n_dofs = dofs_state.ctrl_mode.shape[1] link1_idx = equalities_info.eq_obj1id[i_e, i_b] link2_idx = equalities_info.eq_obj2id[i_e, i_b] @@ -623,13 +636,13 @@ def func_equality_connect( # Transform anchor positions to global coordinates global_anchor1 = gu.ti_transform_by_trans_quat( pos=anchor1_pos, - trans=links_state.pos[link1_idx, i_b], - quat=links_state.quat[link1_idx, i_b], + trans=links_state.pos[i_f, link1_idx, i_b], + quat=links_state.quat[i_f, link1_idx, i_b], ) global_anchor2 = gu.ti_transform_by_trans_quat( pos=anchor2_pos, - trans=links_state.pos[link2_idx, i_b], - quat=links_state.quat[link2_idx, i_b], + trans=links_state.pos[i_f, link2_idx, i_b], + quat=links_state.quat[i_f, link2_idx, i_b], ) invweight = links_info.invweight[link_a_maybe_batch][0] + links_info.invweight[link_b_maybe_batch][0] @@ -639,7 +652,7 @@ def func_equality_connect( ti.atomic_add(constraint_state.n_constraints_equality[i_b], 1) if ti.static(static_rigid_sim_config.sparse_solve): - for i_d_ in range(collider_state.jac_n_relevant_dofs[n_con, i_b]): + for i_d_ in range(constraint_state.jac_n_relevant_dofs[n_con, i_b]): i_d = constraint_state.jac_relevant_dofs[n_con, i_d_, i_b] constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0) else: @@ -662,16 +675,16 @@ def func_equality_connect( for i_d_ in range(links_info.n_dofs[link_maybe_batch]): i_d = links_info.dof_end[link_maybe_batch] - 1 - i_d_ - cdof_ang = dofs_state.cdof_ang[i_d, i_b] - cdot_vel = dofs_state.cdof_vel[i_d, i_b] + cdof_ang = dofs_state.cdof_ang[i_f, i_d, i_b] + cdot_vel = dofs_state.cdof_vel[i_f, i_d, i_b] t_quat = gu.ti_identity_quat() - t_pos = pos - links_state.root_COM[link, i_b] + t_pos = pos - links_state.root_COM[i_f, link, i_b] ang, vel = gu.ti_transform_motion_by_trans_quat(cdof_ang, cdot_vel, t_pos, t_quat) diff = sign * vel jac = diff[i_3] - jac_qvel = jac_qvel + jac * dofs_state.vel[i_d, i_b] + jac_qvel = jac_qvel + jac * dofs_state.vel[i_f, i_d, i_b] constraint_state.jac[n_con, i_d, i_b] = constraint_state.jac[n_con, i_d, i_b] + jac if ti.static(static_rigid_sim_config.sparse_solve): @@ -697,6 +710,7 @@ def func_equality_connect( @ti.func def func_equality_joint( + i_f, i_b, i_e, joints_info: array_class.JointsInfo, @@ -739,8 +753,8 @@ def func_equality_joint( for i_d in range(n_dofs): constraint_state.jac[n_con, i_d, i_b] = gs.ti_float(0.0) - pos1 = rigid_global_info.qpos[i_qpos1, i_b] - pos2 = rigid_global_info.qpos[i_qpos2, i_b] + pos1 = rigid_global_info.qpos[i_f, i_qpos1, i_b] + pos2 = rigid_global_info.qpos[i_f, i_qpos2, i_b] ref1 = rigid_global_info.qpos0[i_qpos1, i_b] ref2 = rigid_global_info.qpos0[i_qpos2, i_b] @@ -759,8 +773,8 @@ def func_equality_joint( constraint_state.jac[n_con, i_dof1, i_b] = gs.ti_float(1.0) constraint_state.jac[n_con, i_dof2, i_b] = -deriv jac_qvel = ( - constraint_state.jac[n_con, i_dof1, i_b] * dofs_state.vel[i_dof1, i_b] - + constraint_state.jac[n_con, i_dof2, i_b] * dofs_state.vel[i_dof2, i_b] + constraint_state.jac[n_con, i_dof1, i_b] * dofs_state.vel[i_f, i_dof1, i_b] + + constraint_state.jac[n_con, i_dof2, i_b] * dofs_state.vel[i_f, i_dof2, i_b] ) invweight = dofs_info.invweight[I_dof1] + dofs_info.invweight[I_dof2] @@ -776,6 +790,7 @@ def func_equality_joint( @gs.maybe_pure @ti.kernel def add_equality_constraints( + i_f: ti.i32, links_info: array_class.LinksInfo, links_state: array_class.LinksState, dofs_state: array_class.DofsState, @@ -788,12 +803,13 @@ def add_equality_constraints( static_rigid_sim_config: ti.template(), static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): - _B = dofs_state.ctrl_mode.shape[1] + _B = dofs_state.ctrl_mode.shape[2] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) for i_b in range(_B): for i_e in range(constraint_state.ti_n_equalities[i_b]): if equalities_info.eq_type[i_e, i_b] == gs.EQUALITY_TYPE.CONNECT: func_equality_connect( + i_f, i_b, i_e, links_info=links_info, @@ -807,6 +823,7 @@ def add_equality_constraints( elif equalities_info.eq_type[i_e, i_b] == gs.EQUALITY_TYPE.WELD: func_equality_weld( + i_f, i_b, i_e, links_info=links_info, @@ -818,6 +835,7 @@ def add_equality_constraints( ) elif equalities_info.eq_type[i_e, i_b] == gs.EQUALITY_TYPE.JOINT: func_equality_joint( + i_f, i_b, i_e, joints_info=joints_info, @@ -832,6 +850,7 @@ def add_equality_constraints( @ti.func def func_equality_weld( + i_f, i_b, i_e, links_info: array_class.LinksInfo, @@ -841,7 +860,7 @@ def func_equality_weld( constraint_state: array_class.ConstraintState, static_rigid_sim_config: ti.template(), ): - n_dofs = dofs_state.ctrl_mode.shape[0] + n_dofs = dofs_state.ctrl_mode.shape[1] # TODO: sparse mode # Get equality info for this constraint @@ -883,21 +902,21 @@ def func_equality_weld( # Transform anchor positions to global coordinates global_anchor1 = gu.ti_transform_by_trans_quat( pos=anchor1_pos, - trans=links_state.pos[link1_idx, i_b], - quat=links_state.quat[link1_idx, i_b], + trans=links_state.pos[i_f, link1_idx, i_b], + quat=links_state.quat[i_f, link1_idx, i_b], ) global_anchor2 = gu.ti_transform_by_trans_quat( pos=anchor2_pos, - trans=links_state.pos[link2_idx, i_b], - quat=links_state.quat[link2_idx, i_b], + trans=links_state.pos[i_f, link2_idx, i_b], + quat=links_state.quat[i_f, link2_idx, i_b], ) pos_error = global_anchor1 - global_anchor2 # Compute orientation error. # For weld: compute q = body1_quat * relpose, then error = (inv(body2_quat) * q) - quat_body1 = links_state.quat[link1_idx, i_b] - quat_body2 = links_state.quat[link2_idx, i_b] + quat_body1 = links_state.quat[i_f, link1_idx, i_b] + quat_body2 = links_state.quat[i_f, link2_idx, i_b] q = gu.ti_quat_mul(quat_body1, relpose) inv_quat_body2 = gu.ti_inv_quat(quat_body2) error_quat = gu.ti_quat_mul(inv_quat_body2, q) @@ -937,15 +956,15 @@ def func_equality_weld( for i_d_ in range(links_info.n_dofs[link_maybe_batch]): i_d = links_info.dof_end[link_maybe_batch] - 1 - i_d_ - cdof_ang = dofs_state.cdof_ang[i_d, i_b] - cdot_vel = dofs_state.cdof_vel[i_d, i_b] + cdof_ang = dofs_state.cdof_ang[i_f, i_d, i_b] + cdot_vel = dofs_state.cdof_vel[i_f, i_d, i_b] t_quat = gu.ti_identity_quat() - t_pos = pos_anchor - links_state.root_COM[link, i_b] + t_pos = pos_anchor - links_state.root_COM[i_f, link, i_b] ang, vel = gu.ti_transform_motion_by_trans_quat(cdof_ang, cdot_vel, t_pos, t_quat) diff = sign * vel jac = diff[i] - jac_qvel += jac * dofs_state.vel[i_d, i_b] + jac_qvel += jac * dofs_state.vel[i_f, i_d, i_b] constraint_state.jac[n_con, i_d, i_b] += jac if ti.static(static_rigid_sim_config.sparse_solve): @@ -981,7 +1000,7 @@ def func_equality_weld( for i_d_ in range(links_info.n_dofs[link_maybe_batch]): i_d = links_info.dof_end[link_maybe_batch] - 1 - i_d_ - jac = sign * dofs_state.cdof_ang[i_d, i_b] + jac = sign * dofs_state.cdof_ang[i_f, i_d, i_b] for i_con in range(n_con, n_con + 3): constraint_state.jac[i_con, i_d, i_b] = constraint_state.jac[i_con, i_d, i_b] + jac[i_con - n_con] @@ -1004,7 +1023,7 @@ def func_equality_weld( for i_con in range(n_con, n_con + 3): constraint_state.jac[i_con, i_d, i_b] = 0.5 * quat3[i_con - n_con + 1] * torquescale jac_qvel[i_con - n_con] = ( - jac_qvel[i_con - n_con] + constraint_state.jac[i_con, i_d, i_b] * dofs_state.vel[i_d, i_b] + jac_qvel[i_con - n_con] + constraint_state.jac[i_con, i_d, i_b] * dofs_state.vel[i_f, i_d, i_b] ) for i_con in range(n_con, n_con + 3): @@ -1022,6 +1041,7 @@ def func_equality_weld( @gs.maybe_pure @ti.kernel def add_joint_limit_constraints( + i_f: ti.i32, links_info: array_class.LinksInfo, joints_info: array_class.JointsInfo, dofs_info: array_class.DofsInfo, @@ -1033,7 +1053,7 @@ def add_joint_limit_constraints( ): _B = constraint_state.jac.shape[2] n_links = links_info.root_idx.shape[0] - n_dofs = dofs_state.ctrl_mode.shape[0] + n_dofs = dofs_state.ctrl_mode.shape[1] # TODO: sparse mode ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) @@ -1048,13 +1068,13 @@ def add_joint_limit_constraints( i_q = joints_info.q_start[I_j] i_d = joints_info.dof_start[I_j] I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - pos_delta_min = rigid_global_info.qpos[i_q, i_b] - dofs_info.limit[I_d][0] - pos_delta_max = dofs_info.limit[I_d][1] - rigid_global_info.qpos[i_q, i_b] + pos_delta_min = rigid_global_info.qpos[i_f, i_q, i_b] - dofs_info.limit[I_d][0] + pos_delta_max = dofs_info.limit[I_d][1] - rigid_global_info.qpos[i_f, i_q, i_b] pos_delta = min(pos_delta_min, pos_delta_max) if pos_delta < 0: jac = (pos_delta_min < pos_delta_max) * 2 - 1 - jac_qvel = jac * dofs_state.vel[i_d, i_b] + jac_qvel = jac * dofs_state.vel[i_f, i_d, i_b] imp, aref = gu.imp_aref(joints_info.sol_params[I_j], pos_delta, jac_qvel, pos_delta) diag = ti.max(dofs_info.invweight[I_d] * (1 - imp) / imp, gs.EPS) @@ -1081,6 +1101,7 @@ def add_joint_limit_constraints( @gs.maybe_pure @ti.kernel def add_frictionloss_constraints( + i_f: ti.i32, links_info: array_class.LinksInfo, joints_info: array_class.JointsInfo, dofs_info: array_class.DofsInfo, @@ -1092,7 +1113,7 @@ def add_frictionloss_constraints( ): _B = constraint_state.jac.shape[2] n_links = links_info.root_idx.shape[0] - n_dofs = dofs_state.ctrl_mode.shape[0] + n_dofs = dofs_state.ctrl_mode.shape[1] # TODO: sparse mode ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) @@ -1108,7 +1129,7 @@ def add_frictionloss_constraints( if dofs_info.frictionloss[I_d] > gs.EPS: jac = 1.0 - jac_qvel = jac * dofs_state.vel[i_d, i_b] + jac_qvel = jac * dofs_state.vel[i_f, i_d, i_b] imp, aref = gu.imp_aref(joints_info.sol_params[I_j], 0.0, jac_qvel, 0.0) diag = ti.max(dofs_info.invweight[I_d] * (1.0 - imp) / imp, gs.EPS) @@ -1126,6 +1147,7 @@ def add_frictionloss_constraints( @ti.func def func_nt_hessian_incremental( + i_f, i_b, entities_info: array_class.EntitiesInfo, constraint_state: array_class.ConstraintState, @@ -1184,6 +1206,7 @@ def func_nt_hessian_incremental( if rank < n_dofs: func_nt_hessian_direct( + i_f, i_b, entities_info=entities_info, constraint_state=constraint_state, @@ -1226,6 +1249,7 @@ def func_nt_hessian_incremental( if rank < n_dofs: func_nt_hessian_direct( + i_f, i_b, entities_info=entities_info, constraint_state=constraint_state, @@ -1237,6 +1261,7 @@ def func_nt_hessian_incremental( @ti.func def func_nt_hessian_direct( + i_f, i_b, entities_info: array_class.EntitiesInfo, constraint_state: array_class.ConstraintState, @@ -1286,7 +1311,7 @@ def func_nt_hessian_direct( for i_d1 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): for i_d2 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): constraint_state.nt_H[i_d1, i_d2, i_b] = ( - constraint_state.nt_H[i_d1, i_d2, i_b] + rigid_global_info.mass_mat[i_d1, i_d2, i_b] + constraint_state.nt_H[i_d1, i_d2, i_b] + rigid_global_info.mass_mat[i_f, i_d1, i_d2, i_b] ) # self.nt_ori_H[i_d1, i_d2, i_b] = self.nt_H[i_d1, i_d2, i_b] @@ -1351,18 +1376,19 @@ def func_nt_chol_solve( @gs.maybe_pure @ti.kernel def func_update_contact_force( + i_f: ti.i32, links_state: array_class.LinksState, collider_state: array_class.ColliderState, constraint_state: array_class.ConstraintState, static_rigid_sim_config: ti.template(), static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): - n_links = links_state.contact_force.shape[0] - _B = links_state.contact_force.shape[1] + n_links = links_state.contact_force.shape[1] + _B = links_state.contact_force.shape[2] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_l, i_b in ti.ndrange(n_links, _B): - links_state.contact_force[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + links_state.contact_force[i_f, i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_b in range(_B): @@ -1384,35 +1410,39 @@ def func_update_contact_force( collider_state.contact_data.force[i_c, i_b] = force - links_state.contact_force[contact_data_link_a, i_b] = ( - links_state.contact_force[contact_data_link_a, i_b] - force + links_state.contact_force[i_f, contact_data_link_a, i_b] = ( + links_state.contact_force[i_f, contact_data_link_a, i_b] - force ) - links_state.contact_force[contact_data_link_b, i_b] = ( - links_state.contact_force[contact_data_link_b, i_b] + force + links_state.contact_force[i_f, contact_data_link_b, i_b] = ( + links_state.contact_force[i_f, contact_data_link_b, i_b] + force ) @gs.maybe_pure @ti.kernel def func_update_qacc( + i_f: ti.i32, dofs_state: array_class.DofsState, constraint_state: array_class.ConstraintState, static_rigid_sim_config: ti.template(), static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): - n_dofs = dofs_state.acc.shape[0] - _B = dofs_state.acc.shape[1] + n_dofs = dofs_state.acc.shape[1] + _B = dofs_state.acc.shape[2] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_d, i_b in ti.ndrange(n_dofs, _B): - dofs_state.acc[i_d, i_b] = constraint_state.qacc[i_d, i_b] - dofs_state.qf_constraint[i_d, i_b] = constraint_state.qfrc_constraint[i_d, i_b] - dofs_state.force[i_d, i_b] = dofs_state.qf_smooth[i_d, i_b] + constraint_state.qfrc_constraint[i_d, i_b] + dofs_state.acc[i_f, i_d, i_b] = constraint_state.qacc[i_d, i_b] + dofs_state.qf_constraint[i_f, i_d, i_b] = constraint_state.qfrc_constraint[i_d, i_b] + dofs_state.force[i_f, i_d, i_b] = ( + dofs_state.qf_smooth[i_f, i_d, i_b] + constraint_state.qfrc_constraint[i_d, i_b] + ) constraint_state.qacc_ws[i_d, i_b] = constraint_state.qacc[i_d, i_b] @gs.maybe_pure @ti.kernel def func_solve( + i_f: ti.i32, entities_info: array_class.EntitiesInfo, dofs_state: array_class.DofsState, constraint_state: array_class.ConstraintState, @@ -1430,6 +1460,7 @@ def func_solve( tol_scaled = (rigid_global_info.meaninertia[i_b] * ti.max(1, n_dofs)) * static_rigid_sim_config.tolerance for it in range(static_rigid_sim_config.iterations): func_solve_body( + i_f, i_b, entities_info=entities_info, dofs_state=dofs_state, @@ -1443,6 +1474,7 @@ def func_solve( @ti.func def func_ls_init( + i_f, i_b, entities_info: array_class.EntitiesInfo, dofs_state: array_class.DofsState, @@ -1458,7 +1490,7 @@ def func_ls_init( for i_d1 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): mv = gs.ti_float(0.0) for i_d2 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - mv += rigid_global_info.mass_mat[i_d1, i_d2, i_b] * constraint_state.search[i_d2, i_b] + mv += rigid_global_info.mass_mat[i_f, i_d1, i_d2, i_b] * constraint_state.search[i_d2, i_b] constraint_state.mv[i_d1, i_b] = mv for i_c in range(constraint_state.n_constraints[i_b]): @@ -1478,7 +1510,7 @@ def func_ls_init( for i_d in range(n_dofs): quad_gauss_1 += ( constraint_state.search[i_d, i_b] * constraint_state.Ma[i_d, i_b] - - constraint_state.search[i_d, i_b] * dofs_state.force[i_d, i_b] + - constraint_state.search[i_d, i_b] * dofs_state.force[i_f, i_d, i_b] ) quad_gauss_2 += 0.5 * constraint_state.search[i_d, i_b] * constraint_state.mv[i_d, i_b] for _i0 in range(1): @@ -1553,8 +1585,8 @@ def func_ls_point_fn( @ti.func -def func_no_linesearch(i_b, constraint_state: array_class.ConstraintState): - func_ls_init(i_b) +def func_no_linesearch(i_f, i_b, constraint_state: array_class.ConstraintState): + func_ls_init(i_f, i_b) n_dofs = constraint_state.search.shape[0] constraint_state.improved[i_b] = 1 @@ -1567,6 +1599,7 @@ def func_no_linesearch(i_b, constraint_state: array_class.ConstraintState): @ti.func def func_linesearch( + i_f, i_b, entities_info: array_class.EntitiesInfo, dofs_state: array_class.DofsState, @@ -1597,6 +1630,7 @@ def func_linesearch( res_alpha = 0.0 else: func_ls_init( + i_f, i_b, entities_info=entities_info, dofs_state=dofs_state, @@ -1802,6 +1836,7 @@ def update_bracket( @ti.func def func_solve_body( + i_f, i_b, entities_info: array_class.EntitiesInfo, dofs_state: array_class.DofsState, @@ -1811,6 +1846,7 @@ def func_solve_body( ): n_dofs = constraint_state.qacc.shape[0] alpha = func_linesearch( + i_f, i_b, entities_info=entities_info, dofs_state=dofs_state, @@ -1837,6 +1873,7 @@ def func_solve_body( constraint_state.cg_prev_Mgrad[i_d, i_b] = constraint_state.Mgrad[i_d, i_b] func_update_constraint( + i_f, i_b, qacc=constraint_state.qacc, Ma=constraint_state.Ma, @@ -1848,6 +1885,7 @@ def func_solve_body( if ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton): func_nt_hessian_incremental( + i_f, i_b, entities_info=entities_info, constraint_state=constraint_state, @@ -1856,6 +1894,7 @@ def func_solve_body( ) func_update_gradient( + i_f, i_b, dofs_state=dofs_state, entities_info=entities_info, @@ -1902,6 +1941,7 @@ def func_solve_body( @ti.func def func_update_constraint( + i_f, i_b, qacc: array_class.V_ANNOTATION, Ma: array_class.V_ANNOTATION, @@ -1960,11 +2000,15 @@ def func_update_constraint( constraint_state.qfrc_constraint[i_d, i_b] = qfrc_constraint # (Mx - Mx') * (x - x') for i_d in range(n_dofs): - v = 0.5 * (Ma[i_d, i_b] - dofs_state.force[i_d, i_b]) * (qacc[i_d, i_b] - dofs_state.acc_smooth[i_d, i_b]) + v = ( + 0.5 + * (Ma[i_d, i_b] - dofs_state.force[i_f, i_d, i_b]) + * (qacc[i_d, i_b] - dofs_state.acc_smooth[i_f, i_d, i_b]) + ) constraint_state.gauss[i_b] = constraint_state.gauss[i_b] + v cost[i_b] = cost[i_b] + v - # D * (Jx - aref) ** 2 + # # D * (Jx - aref) ** 2 for i_c in range(constraint_state.n_constraints[i_b]): cost[i_b] = cost[i_b] + 0.5 * ( constraint_state.efc_D[i_c, i_b] @@ -1976,6 +2020,7 @@ def func_update_constraint( @ti.func def func_update_gradient( + i_f, i_b, dofs_state: array_class.DofsState, entities_info: array_class.EntitiesInfo, @@ -1987,13 +2032,14 @@ def func_update_gradient( for i_d in range(n_dofs): constraint_state.grad[i_d, i_b] = ( - constraint_state.Ma[i_d, i_b] - dofs_state.force[i_d, i_b] - constraint_state.qfrc_constraint[i_d, i_b] + constraint_state.Ma[i_d, i_b] - dofs_state.force[i_f, i_d, i_b] - constraint_state.qfrc_constraint[i_d, i_b] ) if ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.CG): rigid_solver.func_solve_mass_batched( constraint_state.grad, constraint_state.Mgrad, + i_f, i_b, entities_info=entities_info, rigid_global_info=rigid_global_info, @@ -2004,6 +2050,20 @@ def func_update_gradient( func_nt_chol_solve(i_b, constraint_state=constraint_state) +@ti.func +def copy_acc_smooth( + i_f: ti.i32, + dofs_state: array_class.DofsState, + constraint_state: array_class.ConstraintState, + static_rigid_sim_config: ti.template(), +): + n_dofs = constraint_state.qacc_smooth.shape[0] + _B = constraint_state.qacc_smooth.shape[1] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(n_dofs, _B): + constraint_state.qacc_smooth[i_d, i_b] = dofs_state.acc_smooth[i_f, i_d, i_b] + + @ti.func def initialize_Jaref( qacc: array_class.V_ANNOTATION, @@ -2028,6 +2088,7 @@ def initialize_Jaref( @ti.func def initialize_Ma( + i_f: ti.i32, Ma: array_class.V_ANNOTATION, qacc: array_class.V_ANNOTATION, entities_info: array_class.EntitiesInfo, @@ -2035,7 +2096,7 @@ def initialize_Ma( static_rigid_sim_config: ti.template(), ): - _B = rigid_global_info.mass_mat.shape[2] + _B = rigid_global_info.mass_mat.shape[3] n_entities = entities_info.n_links.shape[0] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) for i_e, i_b in ti.ndrange(n_entities, _B): @@ -2043,13 +2104,14 @@ def initialize_Ma( i_d1 = entities_info.dof_start[i_e] + i_d1_ Ma_ = gs.ti_float(0.0) for i_d2 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - Ma_ += rigid_global_info.mass_mat[i_d1, i_d2, i_b] * qacc[i_d2, i_b] + Ma_ += rigid_global_info.mass_mat[i_f, i_d1, i_d2, i_b] * qacc[i_d2, i_b] Ma[i_d1, i_b] = Ma_ @gs.maybe_pure @ti.kernel def func_init_solver( + i_f: ti.i32, dofs_state: array_class.DofsState, entities_info: array_class.EntitiesInfo, constraint_state: array_class.ConstraintState, @@ -2057,8 +2119,17 @@ def func_init_solver( static_rigid_sim_config: ti.template(), static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): - _B = dofs_state.acc_smooth.shape[1] - n_dofs = dofs_state.acc_smooth.shape[0] + _B = dofs_state.acc_smooth.shape[2] + n_dofs = dofs_state.acc_smooth.shape[1] + + # copy dofs_state.acc_smooth to constraint_state.qacc_smooth + copy_acc_smooth( + i_f=i_f, + dofs_state=dofs_state, + constraint_state=constraint_state, + static_rigid_sim_config=static_rigid_sim_config, + ) + # check if warm start initialize_Jaref( qacc=constraint_state.qacc_ws, @@ -2067,6 +2138,7 @@ def func_init_solver( ) initialize_Ma( + i_f=i_f, Ma=constraint_state.Ma_ws, qacc=constraint_state.qacc_ws, entities_info=entities_info, @@ -2077,6 +2149,7 @@ def func_init_solver( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_b in range(_B): func_update_constraint( + i_f, i_b, qacc=constraint_state.qacc_ws, Ma=constraint_state.Ma_ws, @@ -2086,14 +2159,15 @@ def func_init_solver( static_rigid_sim_config=static_rigid_sim_config, ) initialize_Jaref( - qacc=dofs_state.acc_smooth, + qacc=constraint_state.qacc_smooth, constraint_state=constraint_state, static_rigid_sim_config=static_rigid_sim_config, ) initialize_Ma( + i_f=i_f, Ma=constraint_state.Ma, - qacc=dofs_state.acc_smooth, + qacc=constraint_state.qacc_smooth, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, @@ -2101,8 +2175,9 @@ def func_init_solver( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_b in range(_B): func_update_constraint( + i_f, i_b, - qacc=dofs_state.acc_smooth, + qacc=constraint_state.qacc_smooth, Ma=constraint_state.Ma, cost=constraint_state.cost, dofs_state=dofs_state, @@ -2115,7 +2190,7 @@ def func_init_solver( constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b] constraint_state.Ma[i_d, i_b] = constraint_state.Ma_ws[i_d, i_b] else: - constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b] + constraint_state.qacc[i_d, i_b] = constraint_state.qacc_smooth[i_d, i_b] initialize_Jaref( qacc=constraint_state.qacc, constraint_state=constraint_state, @@ -2126,6 +2201,7 @@ def func_init_solver( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_b in range(_B): func_update_constraint( + i_f, i_b, qacc=constraint_state.qacc, Ma=constraint_state.Ma, @@ -2136,6 +2212,7 @@ def func_init_solver( ) if ti.static(static_rigid_sim_config.solver_type == gs.constraint_solver.Newton): func_nt_hessian_direct( + i_f, i_b, entities_info=entities_info, rigid_global_info=rigid_global_info, @@ -2144,6 +2221,7 @@ def func_init_solver( ) func_update_gradient( + i_f, i_b, dofs_state=dofs_state, entities_info=entities_info, diff --git a/genesis/engine/solvers/rigid/gjk_decomp.py b/genesis/engine/solvers/rigid/gjk_decomp.py index ca5b1631f..670368968 100644 --- a/genesis/engine/solvers/rigid/gjk_decomp.py +++ b/genesis/engine/solvers/rigid/gjk_decomp.py @@ -168,6 +168,7 @@ def func_gjk_contact( gjk_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -210,6 +211,7 @@ def func_gjk_contact( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -275,6 +277,7 @@ def func_gjk_contact( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -299,6 +302,7 @@ def func_gjk_contact( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -306,7 +310,7 @@ def func_gjk_contact( # Run EPA from the polytope if polytope_flag == EPA_POLY_INIT_RETURN_CODE.SUCCESS: - i_f = func_epa( + i_face = func_epa( geoms_state, geoms_info, verts_info, @@ -317,6 +321,7 @@ def func_gjk_contact( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -327,7 +332,7 @@ def func_gjk_contact( # (1) [i_f] should be a valid face index in the polytope (>= 0), # (2) Both of the geometries should be discrete, # (3) [enable_mujoco_multi_contact] should be True. Default to False. - if i_f >= 0 and func_is_discrete_geoms(geoms_info, i_ga, i_gb, i_b): + if i_face >= 0 and func_is_discrete_geoms(geoms_info, i_ga, i_gb, i_b): func_multi_contact( geoms_state, geoms_info, @@ -335,6 +340,7 @@ def func_gjk_contact( faces_info, gjk_state, gjk_static_config, + i_f, i_ga, i_gb, i_b, @@ -353,6 +359,7 @@ def func_gjk_contact( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -379,6 +386,7 @@ def func_gjk_contact( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -425,6 +433,7 @@ def func_gjk( gjk_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -481,8 +490,8 @@ def func_gjk( early_stop = False # Set initial guess of support vector using the positions, which should be a non-zero vector. - approx_witness_point_obj1 = geoms_state.pos[i_ga, i_b] - approx_witness_point_obj2 = geoms_state.pos[i_gb, i_b] + approx_witness_point_obj1 = geoms_state.pos[i_f, i_ga, i_b] + approx_witness_point_obj2 = geoms_state.pos[i_f, i_gb, i_b] support_vector = approx_witness_point_obj1 - approx_witness_point_obj2 if support_vector.dot(support_vector) < gjk_static_config.FLOAT_MIN_SQ: support_vector = gs.ti_vec3(1.0, 0.0, 0.0) @@ -522,6 +531,7 @@ def func_gjk( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -567,6 +577,7 @@ def func_gjk( gjk_static_config=gjk_static_config, support_field_info=support_field_info, support_field_static_config=support_field_static_config, + i_f=i_f, i_ga=i_ga, i_gb=i_gb, i_b=i_b, @@ -661,6 +672,7 @@ def func_gjk_intersect( gjk_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -751,6 +763,7 @@ def func_gjk_intersect( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -1086,6 +1099,7 @@ def func_epa( gjk_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_frame, i_ga, i_gb, i_b, @@ -1149,6 +1163,7 @@ def func_epa( gjk_static_config, support_field_info, support_field_static_config, + i_frame, i_ga, i_gb, i_b, @@ -1481,6 +1496,7 @@ def func_epa_init_polytope_2d( gjk_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -1551,6 +1567,7 @@ def func_epa_init_polytope_2d( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -1624,6 +1641,7 @@ def func_epa_init_polytope_3d( gjk_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -1679,6 +1697,7 @@ def func_epa_init_polytope_3d( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -1843,6 +1862,7 @@ def func_epa_support( gjk_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -1876,6 +1896,7 @@ def func_epa_support( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -2148,6 +2169,7 @@ def func_multi_contact( faces_info: array_class.FacesInfo, gjk_state: array_class.GJKState, gjk_static_config: ti.template(), + i_frame, i_ga, i_gb, i_b, @@ -2213,7 +2235,7 @@ def func_multi_contact( nnorms = 0 if geom_type == gs.GEOM_TYPE.BOX: nnorms = func_potential_box_normals( - geoms_state, geoms_info, gjk_state, gjk_static_config, i_g, i_b, nface, v1i, v2i, v3i, t_dir + geoms_state, geoms_info, gjk_state, gjk_static_config, i_frame, i_g, i_b, nface, v1i, v2i, v3i, t_dir ) elif geom_type == gs.GEOM_TYPE.MESH: nnorms = func_potential_mesh_normals( @@ -2223,6 +2245,7 @@ def func_multi_contact( faces_info, gjk_state, gjk_static_config, + i_frame, i_g, i_b, nface, @@ -2265,7 +2288,7 @@ def func_multi_contact( nnorms = 0 if geom_type == gs.GEOM_TYPE.BOX: nnorms = func_potential_box_edge_normals( - geoms_state, geoms_info, gjk_state, gjk_static_config, i_g, i_b, nface, v1, v2, v1i, v2i + geoms_state, geoms_info, gjk_state, gjk_static_config, i_frame, i_g, i_b, nface, v1, v2, v1i, v2i ) elif geom_type == gs.GEOM_TYPE.MESH: nnorms = func_potential_mesh_edge_normals( @@ -2275,6 +2298,7 @@ def func_multi_contact( faces_info, gjk_state, gjk_static_config, + i_frame, i_g, i_b, nface, @@ -2421,6 +2445,7 @@ def func_potential_box_normals( geoms_info: array_class.GeomsInfo, gjk_state: array_class.GJKState, gjk_static_config: ti.template(), + i_f, i_g, i_b, dim, @@ -2439,7 +2464,7 @@ def func_potential_box_normals( We identify related face normals to the simplex by checking the vertex indices of the simplex. """ - g_quat = geoms_state.quat[i_g, i_b] + g_quat = geoms_state.quat[i_f, i_g, i_b] # Change to local vertex indices v1 -= geoms_info.vert_start[i_g] @@ -2528,7 +2553,7 @@ def func_potential_box_normals( if is_degenerate_simplex: n_normals = ( 1 - if func_box_normal_from_collision_normal(geoms_state, gjk_state, gjk_static_config, i_g, i_b, dir) + if func_box_normal_from_collision_normal(geoms_state, gjk_state, gjk_static_config, i_f, i_g, i_b, dir) == RETURN_CODE.SUCCESS else 0 ) @@ -2581,6 +2606,7 @@ def func_box_normal_from_collision_normal( geoms_state: array_class.GeomsState, gjk_state: array_class.GJKState, gjk_static_config: ti.template(), + i_f, i_g, i_b, dir, @@ -2595,7 +2621,7 @@ def func_box_normal_from_collision_normal( ) # Get local collision normal - g_quat = geoms_state.quat[i_g, i_b] + g_quat = geoms_state.quat[i_f, i_g, i_b] local_dir = gu.ti_transform_by_quat(dir, gu.ti_inv_quat(g_quat)) local_dir = local_dir.normalized() @@ -2620,6 +2646,7 @@ def func_potential_mesh_normals( faces_info: array_class.FacesInfo, gjk_state: array_class.GJKState, gjk_static_config: ti.template(), + i_f, i_g, i_b, dim, @@ -2639,7 +2666,7 @@ def func_potential_mesh_normals( We identify related face normals to the simplex by checking the vertex indices of the simplex. """ # Get the geometry state and quaternion - g_quat = geoms_state.quat[i_g, i_b] + g_quat = geoms_state.quat[i_f, i_g, i_b] # Number of potential face normals n_normals = 0 @@ -2722,6 +2749,7 @@ def func_potential_box_edge_normals( geoms_info: array_class.GeomsInfo, gjk_state: array_class.GJKState, gjk_static_config: ti.template(), + i_f, i_g, i_b, dim, @@ -2740,8 +2768,8 @@ def func_potential_box_edge_normals( We identify related edge normals to the simplex by checking the vertex indices of the simplex. """ # Get the geometry state and quaternion - g_pos = geoms_state.pos[i_g, i_b] - g_quat = geoms_state.quat[i_g, i_b] + g_pos = geoms_state.pos[i_f, i_g, i_b] + g_quat = geoms_state.quat[i_f, i_g, i_b] g_size_x = geoms_info.data[0] * 0.5 g_size_y = geoms_info.data[1] * 0.5 g_size_z = geoms_info.data[2] * 0.5 @@ -2788,6 +2816,7 @@ def func_potential_mesh_edge_normals( faces_info: array_class.FacesInfo, gjk_state: array_class.GJKState, gjk_static_config: ti.template(), + i_f, i_g, i_b, dim, @@ -2806,8 +2835,8 @@ def func_potential_mesh_edge_normals( We identify related edge normals to the simplex by checking the vertex indices of the simplex. """ # Get the geometry state and quaternion - g_pos = geoms_state.pos[i_g, i_b] - g_quat = geoms_state.quat[i_g, i_b] + g_pos = geoms_state.pos[i_f, i_g, i_b] + g_quat = geoms_state.quat[i_f, i_g, i_b] # Number of potential face normals n_normals = 0 @@ -3346,6 +3375,7 @@ def func_support( gjk_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -3380,6 +3410,7 @@ def func_support( support_field_info, support_field_static_config, d, + i_f, i_g, i_b, i, @@ -3552,6 +3583,7 @@ def support_mesh( gjk_state: array_class.GJKState, gjk_static_config: ti.template(), direction, + i_f, i_g, i_b, i_o, @@ -3559,8 +3591,8 @@ def support_mesh( """ Find the support point on a mesh in the given direction. """ - g_quat = geoms_state.quat[i_g, i_b] - g_pos = geoms_state.pos[i_g, i_b] + g_quat = geoms_state.quat[i_f, i_g, i_b] + g_pos = geoms_state.pos[i_f, i_g, i_b] d_mesh = gu.ti_transform_by_quat(direction, gu.ti_inv_quat(g_quat)) # Exhaustively search for the vertex with maximum dot product @@ -3606,6 +3638,7 @@ def support_driver( support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), direction, + i_f, i_g, i_b, i_o, @@ -3619,20 +3652,20 @@ def support_driver( geom_type = geoms_info.type[i_g] if geom_type == gs.GEOM_TYPE.SPHERE: - v = support_field._func_support_sphere(geoms_state, geoms_info, direction, i_g, i_b, shrink_sphere) + v = support_field._func_support_sphere(geoms_state, geoms_info, direction, i_f, i_g, i_b, shrink_sphere) elif geom_type == gs.GEOM_TYPE.ELLIPSOID: - v = support_field._func_support_ellipsoid(geoms_state, geoms_info, direction, i_g, i_b) + v = support_field._func_support_ellipsoid(geoms_state, geoms_info, direction, i_f, i_g, i_b) elif geom_type == gs.GEOM_TYPE.CAPSULE: - v = support_field._func_support_capsule(geoms_state, geoms_info, direction, i_g, i_b, shrink_sphere) + v = support_field._func_support_capsule(geoms_state, geoms_info, direction, i_f, i_g, i_b, shrink_sphere) elif geom_type == gs.GEOM_TYPE.BOX: - v, vid = support_field._func_support_box(geoms_state, geoms_info, direction, i_g, i_b) + v, vid = support_field._func_support_box(geoms_state, geoms_info, direction, i_f, i_g, i_b) elif geom_type == gs.GEOM_TYPE.TERRAIN: if ti.static(collider_static_config.has_terrain): - v, vid = support_field._func_support_prism(collider_state, direction, i_g, i_b) + v, vid = support_field._func_support_prism(collider_state, direction, i_f, i_g, i_b) elif geom_type == gs.GEOM_TYPE.MESH and static_rigid_sim_config.enable_mujoco_compatibility: # If mujoco-compatible, do exhaustive search for the vertex v, vid = support_mesh( - geoms_state, geoms_info, verts_info, gjk_state, gjk_static_config, direction, i_g, i_b, i_o + geoms_state, geoms_info, verts_info, gjk_state, gjk_static_config, direction, i_f, i_g, i_b, i_o ) else: v, vid = support_field._func_support_world( @@ -3641,6 +3674,7 @@ def support_driver( support_field_info, support_field_static_config, direction, + i_f, i_g, i_b, ) @@ -3659,6 +3693,7 @@ def func_safe_gjk( gjk_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -3707,6 +3742,7 @@ def func_safe_gjk( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -3729,6 +3765,7 @@ def func_safe_gjk( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -3804,6 +3841,7 @@ def func_safe_gjk( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -3979,6 +4017,7 @@ def func_search_valid_simplex_vertex( gjk_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -4009,7 +4048,7 @@ def func_search_valid_simplex_vertex( id2 = geoms_info.vert_start[i_gb] + j for p in range(2): obj = func_get_discrete_geom_vertex( - geoms_state, geoms_info, verts_info, i_ga if p == 0 else i_gb, i_b, i if p == 0 else j + geoms_state, geoms_info, verts_info, i_f, i_ga if p == 0 else i_gb, i_b, i if p == 0 else j ) if p == 0: obj1 = obj @@ -4046,6 +4085,7 @@ def func_search_valid_simplex_vertex( gjk_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -4080,6 +4120,7 @@ def func_get_discrete_geom_vertex( geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, verts_info: array_class.VertsInfo, + i_f, i_g, i_b, i_v, @@ -4088,8 +4129,8 @@ def func_get_discrete_geom_vertex( Get the discrete vertex of the geometry for the given index [i_v]. """ geom_type = geoms_info.type[i_g] - g_pos = geoms_state.pos[i_g, i_b] - g_quat = geoms_state.quat[i_g, i_b] + g_pos = geoms_state.pos[i_f, i_g, i_b] + g_quat = geoms_state.quat[i_f, i_g, i_b] # Get the vertex position in the local frame of the geometry. v = ti.Vector([0.0, 0.0, 0.0], dt=gs.ti_float) @@ -4160,6 +4201,7 @@ def func_safe_gjk_support( gjk_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -4199,6 +4241,7 @@ def func_safe_gjk_support( geoms_info, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -4226,6 +4269,7 @@ def func_safe_gjk_support( support_field_info, support_field_static_config, d, + i_f, i_g, i_b, j, @@ -4267,6 +4311,7 @@ def count_support_driver( support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), d, + i_f, i_g, i_b, ): @@ -4276,7 +4321,7 @@ def count_support_driver( geom_type = geoms_info.type[i_g] count = 1 if geom_type == gs.GEOM_TYPE.BOX: - count = support_field._func_count_supports_box(geoms_state, geoms_info, d, i_g, i_b) + count = support_field._func_count_supports_box(geoms_state, geoms_info, d, i_f, i_g, i_b) elif geom_type == gs.GEOM_TYPE.MESH: count = support_field._func_count_supports_world( geoms_state, @@ -4284,6 +4329,7 @@ def count_support_driver( support_field_info, support_field_static_config, d, + i_f, i_g, i_b, ) @@ -4296,6 +4342,7 @@ def func_count_support( geoms_info: array_class.GeomsInfo, support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -4312,6 +4359,7 @@ def func_count_support( support_field_info, support_field_static_config, dir if i == 0 else -dir, + i_f, i_ga if i == 0 else i_gb, i_b, ) @@ -4331,6 +4379,7 @@ def func_safe_epa( gjk_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_frame, i_ga, i_gb, i_b, @@ -4393,6 +4442,7 @@ def func_safe_epa( gjk_static_config, support_field_info, support_field_static_config, + i_frame, i_ga, i_gb, i_b, diff --git a/genesis/engine/solvers/rigid/mpr_decomp.py b/genesis/engine/solvers/rigid/mpr_decomp.py index c10dad24f..5bd5ae40b 100644 --- a/genesis/engine/solvers/rigid/mpr_decomp.py +++ b/genesis/engine/solvers/rigid/mpr_decomp.py @@ -190,25 +190,26 @@ def support_driver( support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), direction, + i_f, i_g, i_b, ): v = ti.Vector.zero(gs.ti_float, 3) geom_type = geoms_info.type[i_g] if geom_type == gs.GEOM_TYPE.SPHERE: - v = support_field._func_support_sphere(geoms_state, geoms_info, direction, i_g, i_b, False) + v = support_field._func_support_sphere(geoms_state, geoms_info, direction, i_f, i_g, i_b, False) elif geom_type == gs.GEOM_TYPE.ELLIPSOID: - v = support_field._func_support_ellipsoid(geoms_state, geoms_info, direction, i_g, i_b) + v = support_field._func_support_ellipsoid(geoms_state, geoms_info, direction, i_f, i_g, i_b) elif geom_type == gs.GEOM_TYPE.CAPSULE: - v = support_field._func_support_capsule(geoms_state, geoms_info, direction, i_g, i_b, False) + v = support_field._func_support_capsule(geoms_state, geoms_info, direction, i_f, i_g, i_b, False) elif geom_type == gs.GEOM_TYPE.BOX: - v, _ = support_field._func_support_box(geoms_state, geoms_info, direction, i_g, i_b) + v, _ = support_field._func_support_box(geoms_state, geoms_info, direction, i_f, i_g, i_b) elif geom_type == gs.GEOM_TYPE.TERRAIN: if ti.static(collider_static_config.has_terrain): - v, _ = support_field._func_support_prism(collider_state, direction, i_g, i_b) + v, _ = support_field._func_support_prism(collider_state, direction, i_f, i_g, i_b) else: v, _ = support_field._func_support_world( - geoms_state, geoms_info, support_field_info, support_field_static_config, direction, i_g, i_b + geoms_state, geoms_info, support_field_info, support_field_static_config, direction, i_f, i_g, i_b ) return v @@ -223,6 +224,7 @@ def compute_support( support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), direction, + i_f, i_ga, i_gb, i_b, @@ -236,6 +238,7 @@ def compute_support( support_field_info, support_field_static_config, direction, + i_f, i_ga, i_b, ) @@ -248,6 +251,7 @@ def compute_support( support_field_info, support_field_static_config, -direction, + i_f, i_gb, i_b, ) @@ -296,6 +300,7 @@ def mpr_refine_portal( mpr_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -317,6 +322,7 @@ def mpr_refine_portal( support_field_info, support_field_static_config, direction, + i_f, i_ga, i_gb, i_b, @@ -401,6 +407,7 @@ def mpr_find_penetration( collider_static_config: ti.template(), mpr_state: array_class.MPRState, mpr_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -423,6 +430,7 @@ def mpr_find_penetration( support_field_info, support_field_static_config, direction, + i_f, i_ga, i_gb, i_b, @@ -505,6 +513,7 @@ def mpr_discover_portal( collider_static_config: ti.template(), mpr_state: array_class.MPRState, mpr_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -530,6 +539,7 @@ def mpr_discover_portal( support_field_info, support_field_static_config, direction, + i_f, i_ga, i_gb, i_b, @@ -563,6 +573,7 @@ def mpr_discover_portal( support_field_info, support_field_static_config, direction, + i_f, i_ga, i_gb, i_b, @@ -600,6 +611,7 @@ def mpr_discover_portal( support_field_info, support_field_static_config, direction, + i_f, i_ga, i_gb, i_b, @@ -653,6 +665,7 @@ def guess_geoms_center( geoms_init_AABB: array_class.GeomsInitAABB, static_rigid_sim_config: ti.template(), mpr_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -685,10 +698,10 @@ def guess_geoms_center( # respective geometry. If one of the center is off, its offset from the original center is divided by 2 and the # signed distance is computed once again until to find a valid point. This procedure should be cheap. - g_pos_a = geoms_state.pos[i_ga, i_b] - g_pos_b = geoms_state.pos[i_gb, i_b] - g_quat_a = geoms_state.quat[i_ga, i_b] - g_quat_b = geoms_state.quat[i_gb, i_b] + g_pos_a = geoms_state.pos[i_f, i_ga, i_b] + g_pos_b = geoms_state.pos[i_f, i_gb, i_b] + g_quat_a = geoms_state.quat[i_f, i_ga, i_b] + g_quat_b = geoms_state.quat[i_f, i_gb, i_b] center_a = gu.ti_transform_by_trans_quat(geoms_info.center[i_ga], g_pos_a, g_quat_a) center_b = gu.ti_transform_by_trans_quat(geoms_info.center[i_gb], g_pos_b, g_quat_b) @@ -742,6 +755,7 @@ def func_mpr_contact_from_centers( mpr_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -758,6 +772,7 @@ def func_mpr_contact_from_centers( collider_static_config=collider_static_config, mpr_state=mpr_state, mpr_static_config=mpr_static_config, + i_f=i_f, i_ga=i_ga, i_gb=i_gb, i_b=i_b, @@ -785,6 +800,7 @@ def func_mpr_contact_from_centers( mpr_static_config, support_field_info, support_field_static_config, + i_f, i_ga, i_gb, i_b, @@ -801,6 +817,7 @@ def func_mpr_contact_from_centers( collider_static_config, mpr_state, mpr_static_config, + i_f, i_ga, i_gb, i_b, @@ -821,6 +838,7 @@ def func_mpr_contact( mpr_static_config: ti.template(), support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), + i_f, i_ga, i_gb, i_b, @@ -832,6 +850,7 @@ def func_mpr_contact( geoms_init_AABB, static_rigid_sim_config, mpr_static_config, + i_f, i_ga, i_gb, i_b, @@ -848,6 +867,7 @@ def func_mpr_contact( mpr_static_config=mpr_static_config, support_field_info=support_field_info, support_field_static_config=support_field_static_config, + i_f=i_f, i_ga=i_ga, i_gb=i_gb, i_b=i_b, diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 79c28648e..cf64a3bbc 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -203,6 +203,12 @@ def build(self): self._n_entities = self.n_entities self._n_equalities = self.n_equalities + self._max_n_links_per_entity = self.max_n_links_per_entity + self._max_n_joints_per_link = self.max_n_joints_per_link + self._max_n_dofs_per_joint = self.max_n_dofs_per_joint + self._max_n_dofs_per_entity = self.max_n_dofs_per_entity + self._max_n_dofs_per_link = self.max_n_dofs_per_link + self._geoms = self.geoms self._vgeoms = self.vgeoms self._links = self.links @@ -303,6 +309,8 @@ def build(self): self._init_collider() self._init_constraint_solver() + self.rigid_adjoint_cache = self.data_manager.rigid_adjoint_cache + self._init_invweight_and_meaninertia(force_update=False) def _init_invweight_and_meaninertia(self, envs_idx=None, *, force_update=True, unsafe=False): @@ -322,8 +330,9 @@ def _init_invweight_and_meaninertia(self, envs_idx=None, *, force_update=True, u qpos = ti_to_torch(self.qpos0, envs_idx, transpose=True, unsafe=True) if self.n_envs == 0: qpos = qpos.squeeze(0) - self.set_qpos(qpos, envs_idx=envs_idx if self.n_envs > 0 else None) + self.set_qpos(0, qpos, envs_idx=envs_idx if self.n_envs > 0 else None) kernel_forward_kinematics_links_geoms( + 0, envs_idx, links_state=self.links_state, links_info=self.links_info, @@ -342,6 +351,7 @@ def _init_invweight_and_meaninertia(self, envs_idx=None, *, force_update=True, u # Compute mass matrix without any implicit damping terms # TODO: This kernel could be optimized to take `envs_idx` as input if performance is critical. kernel_compute_mass_matrix( + f=0, links_state=self.links_state, links_info=self.links_info, dofs_state=self.dofs_state, @@ -354,11 +364,17 @@ def _init_invweight_and_meaninertia(self, envs_idx=None, *, force_update=True, u ) # Define some proxies for convenience - mass_mat_D_inv = self._rigid_global_info.mass_mat_D_inv.to_numpy() - mass_mat_L = self._rigid_global_info.mass_mat_L.to_numpy() - offsets = self.links_state.i_pos.to_numpy() - cdof_ang = self.dofs_state.cdof_ang.to_numpy() - cdof_vel = self.dofs_state.cdof_vel.to_numpy() + # mass_mat_D_inv = self._rigid_global_info.mass_mat_D_inv.to_numpy()[0, :, 0] + # mass_mat_L = self._rigid_global_info.mass_mat_L.to_numpy()[0, :, :, 0] + # offsets = self.links_state.i_pos.to_numpy()[0, :, 0] + # cdof_ang = self.dofs_state.cdof_ang.to_numpy()[0, :, 0] + # cdof_vel = self.dofs_state.cdof_vel.to_numpy()[0, :, 0] + mass_mat_D_inv = self._rigid_global_info.mass_mat_D_inv.to_numpy()[0] + mass_mat_L = self._rigid_global_info.mass_mat_L.to_numpy()[0] + offsets = self.links_state.i_pos.to_numpy()[0] + cdof_ang = self.dofs_state.cdof_ang.to_numpy()[0] + cdof_vel = self.dofs_state.cdof_vel.to_numpy()[0] + links_joint_start = self.links_info.joint_start.to_numpy() links_joint_end = self.links_info.joint_end.to_numpy() links_dof_end = self.links_info.dof_end.to_numpy() @@ -687,7 +703,19 @@ def _init_link_fields(self): is_init_qpos_out_of_bounds |= (joint.dofs_limit[0, 0] > init_qpos[joint.q_start]).any() is_init_qpos_out_of_bounds |= (init_qpos[joint.q_start] > joint.dofs_limit[0, 1]).any() # init_qpos[joint.q_start] = np.clip(init_qpos[joint.q_start], *joint.dofs_limit[0]) - self.qpos.from_numpy(init_qpos) + + init_qpos, qs_idx, envs_idx = self._sanitize_1D_io_variables( + init_qpos.T if self.n_envs > 0 else self.init_qpos, + inputs_idx=None, + input_size=self.n_qs, + envs_idx=None, + idx_name="qs_idx", + skip_allocation=True, + unsafe=False, + ) + if self.n_envs == 0: + init_qpos = init_qpos.unsqueeze(0) + kernel_set_qpos(0, init_qpos, qs_idx, envs_idx, self._rigid_global_info, self._static_rigid_sim_config) if is_init_qpos_out_of_bounds: gs.logger.warning( "Reference robot position exceeds joint limits." @@ -908,7 +936,7 @@ def _init_equality_fields(self): static_rigid_sim_cache_key=self._static_rigid_sim_cache_key, ) if self._use_contact_island: - gs.logger.warn("contact island is not supported for equality constraints yet") + gs.logger.warning("contact island is not supported for equality constraints yet") def _init_envs_offset(self): self.envs_offset = self._rigid_global_info.envs_offset @@ -951,12 +979,13 @@ def _init_constraint_solver(self): else: self.constraint_solver = ConstraintSolver(self) - def substep(self): - # from genesis.utils.tools import create_timer + def substep(self, f): + from genesis.engine.couplers import SAPCoupler # timer = create_timer("rigid", level=1, ti_sync=True, skip_first_call=True) kernel_step_1( + f=f, links_state=self.links_state, links_info=self.links_info, joints_state=self.joints_state, @@ -977,9 +1006,10 @@ def substep(self): if isinstance(self.sim.coupler, SAPCoupler): self.update_qvel() else: - self._func_constraint_force() + self._func_constraint_force(f) # timer.stamp("constraint_force") kernel_step_2( + f=f, dofs_state=self.dofs_state, dofs_info=self.dofs_info, links_info=self.links_info, @@ -998,6 +1028,8 @@ def substep(self): ) # timer.stamp("kernel_step_2") + pass + def _kernel_detect_collision(self): self.collider.clear() self.collider.detection() @@ -1011,18 +1043,19 @@ def detect_collision(self, env_idx=0): collision_pairs[:, 1] = self.collider._collider_state.contact_data.geom_b.to_numpy()[:n_collision, env_idx] return collision_pairs - def _func_constraint_force(self): + def _func_constraint_force(self, f): # from genesis.utils.tools import create_timer # timer = create_timer(name="constraint_force", level=2, ti_sync=True, skip_first_call=True) self._func_constraint_clear() # timer.stamp("constraint_solver.clear") if not self._disable_constraint and not self._use_contact_island: - self.constraint_solver.add_equality_constraints() + self.constraint_solver.add_equality_constraints(f) # timer.stamp("constraint_solver.add_equality_constraints") if self._enable_collision: - self.collider.detection() + self.collider.detection(f) + print("n contacts: ", self.collider._collider_state.n_contacts.to_numpy()) # timer.stamp("detection") if not self._disable_constraint: @@ -1030,14 +1063,14 @@ def _func_constraint_force(self): self.constraint_solver.add_constraints() # timer.stamp("constraint_solver.add_constraints") else: - self.constraint_solver.add_frictionloss_constraints() + self.constraint_solver.add_frictionloss_constraints(f) if self._enable_collision: - self.constraint_solver.add_collision_constraints() + self.constraint_solver.add_collision_constraints(f) if self._enable_joint_limit: - self.constraint_solver.add_joint_limit_constraints() + self.constraint_solver.add_joint_limit_constraints(f) # timer.stamp("constraint_solver.add_other_constraints") - self.constraint_solver.resolve() + self.constraint_solver.resolve(f) # timer.stamp("constraint_solver.resolve") def _func_constraint_clear(self): @@ -1046,8 +1079,9 @@ def _func_constraint_clear(self): self.constraint_solver.constraint_state.n_constraints_frictionloss.fill(0) self.collider._collider_state.n_contacts.fill(0) - def _func_forward_dynamics(self): + def _func_forward_dynamics(self, f): kernel_forward_dynamics( + f=f, links_state=self.links_state, links_info=self.links_info, dofs_state=self.dofs_state, @@ -1062,8 +1096,9 @@ def _func_forward_dynamics(self): contact_island_state=self.constraint_solver.contact_island.contact_island_state, ) - def _func_update_acc(self): + def _func_update_acc(self, f): kernel_update_acc( + f=f, dofs_state=self.dofs_state, links_info=self.links_info, links_state=self.links_state, @@ -1287,22 +1322,17 @@ def apply_links_external_torque( ) @ti.kernel - def update_qvel(self): - if ti.static(self._use_hibernation): - ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) - for i_b in range(self._B): - for i_d_ in range(self.n_awake_dofs[i_b]): - i_d = self.awake_dofs[i_d_, i_b] - self.dofs_state.vel_prev[i_d, i_b] = self.dofs_state.vel[i_d, i_b] - self.dofs_state.vel[i_d, i_b] = ( - self.dofs_state.vel[i_d, i_b] + self.dofs_state.acc[i_d, i_b] * self._substep_dt - ) - else: - ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(self.n_dofs, self._B): - self.dofs_state.vel_prev[i_d, i_b] = self.dofs_state.vel[i_d, i_b] - self.dofs_state.vel[i_d, i_b] = ( - self.dofs_state.vel[i_d, i_b] + self.dofs_state.acc[i_d, i_b] * self._substep_dt + def update_qvel(self, f: ti.i32): + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, self._B) if ti.static(self._use_hibernation) else ti.ndrange(self.n_dofs, self._B) + ): + for i_1 in range(self.n_awake_dofs[i_b]) if ti.static(self._use_hibernation) else range(1): + i_d = self.awake_dofs[i_1, i_b] if ti.static(self._use_hibernation) else i_0 + + self.dofs_state.vel_prev[f, i_d, i_b] = self.dofs_state.vel[f, i_d, i_b] + self.dofs_state.vel[f, i_d, i_b] = ( + self.dofs_state.vel[f, i_d, i_b] + self.dofs_state.acc[f, i_d, i_b] * self._substep_dt ) @ti.kernel @@ -1326,7 +1356,7 @@ def update_qacc_from_qvel_delta(self): def substep_pre_coupling(self, f): if self.is_active(): - self.substep() + self.substep(f) def substep_pre_coupling_grad(self, f): pass @@ -1371,6 +1401,7 @@ def reset_grad(self): def update_geoms_render_T(self): kernel_update_geoms_render_T( + self.sim.cur_substep_local, self._geoms_render_T, geoms_state=self.geoms_state, rigid_global_info=self._rigid_global_info, @@ -1407,6 +1438,7 @@ def get_state(self, f): # static_rigid_sim_config: ti.template(), kernel_get_state( + f=f, qpos=state.qpos, vel=state.dofs_vel, links_pos=state.links_pos, @@ -1429,6 +1461,7 @@ def set_state(self, f, state, envs_idx=None): if self.is_active(): envs_idx = self._scene._sanitize_envs_idx(envs_idx) kernel_set_state( + f, qpos=state.qpos, dofs_vel=state.dofs_vel, links_pos=state.links_pos, @@ -1444,6 +1477,7 @@ def set_state(self, f, state, envs_idx=None): static_rigid_sim_config=self._static_rigid_sim_config, ) kernel_forward_kinematics_links_geoms( + f, envs_idx, links_state=self.links_state, links_info=self.links_info, @@ -1466,16 +1500,118 @@ def set_state(self, f, state, envs_idx=None): self._cur_step = -1 def process_input(self, in_backward=False): - pass + for entity in self._entities: + entity.process_input(in_backward=in_backward) def process_input_grad(self): - pass + for entity in self._entities: + entity.process_input_grad() def save_ckpt(self, ckpt_name): - pass + if self._sim.requires_grad: + if ckpt_name not in self._ckpt: + self._ckpt[ckpt_name] = dict() + self._ckpt[ckpt_name]["qpos"] = torch.zeros((self._B, self.n_qs), dtype=gs.tc_float) + self._ckpt[ckpt_name]["vel"] = torch.zeros((self._B, self.n_dofs), dtype=gs.tc_float) + self._ckpt[ckpt_name]["links_pos"] = torch.zeros((self._B, self.n_links, 3), dtype=gs.tc_float) + self._ckpt[ckpt_name]["links_quat"] = torch.zeros((self._B, self.n_links, 4), dtype=gs.tc_float) + self._ckpt[ckpt_name]["i_pos_shift"] = torch.zeros((self._B, self.n_links, 3), dtype=gs.tc_float) + self._ckpt[ckpt_name]["mass_shift"] = torch.zeros((self._B, self.n_links), dtype=gs.tc_float) + self._ckpt[ckpt_name]["friction_ratio"] = torch.zeros((self._B, self.n_geoms), dtype=gs.tc_float) + + self._kernel_get_state( + 0, + self._ckpt[ckpt_name]["qpos"], + self._ckpt[ckpt_name]["vel"], + self._ckpt[ckpt_name]["links_pos"], + self._ckpt[ckpt_name]["links_quat"], + self._ckpt[ckpt_name]["i_pos_shift"], + self._ckpt[ckpt_name]["mass_shift"], + self._ckpt[ckpt_name]["friction_ratio"], + ) + + for entity in self._entities: + entity.save_ckpt(ckpt_name) + + # Restart from frame 0 in memory + kernel_copy_frame( + self._sim.substeps_local, + 0, + self.links_state, + self.joints_state, + self.dofs_state, + self.geoms_state, + self._rigid_global_info, + self._static_rigid_sim_config, + ) def load_ckpt(self, ckpt_name): - pass + self.copy_frame(0, self._sim.substeps_local) + self.copy_grad(0, self._sim.substeps_local) + + if self._sim.requires_grad: + self.reset_grad_till_frame(self._sim.substeps_local) + + envs_idx = self._sanitize_envs_idx(None) + self._kernel_set_state( + 0, + self._ckpt[ckpt_name]["qpos"], + self._ckpt[ckpt_name]["vel"], + self._ckpt[ckpt_name]["links_pos"], + self._ckpt[ckpt_name]["links_quat"], + self._ckpt[ckpt_name]["i_pos_shift"], + self._ckpt[ckpt_name]["mass_shift"], + self._ckpt[ckpt_name]["friction_ratio"], + envs_idx, + ) + + for entity in self._entities: + entity.load_ckpt(ckpt_name=ckpt_name) + + @ti.kernel + def copy_grad(self, source: ti.i32, target: ti.i32): + + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) + for i_q, i_b in ti.ndrange(self.n_qs, self._B): + self.qpos.grad[target, i_q, i_b] = self.qpos.grad[source, i_q, i_b] + + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(self.n_dofs, self._B): + self.dofs_state.grad[target, i_d, i_b].vel = self.dofs_state.grad[source, i_d, i_b].vel + + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(self.n_links, self._B): + for i in ti.static(range(3)): + self.links_state.grad[target, i_l, i_b].pos[i] = self.links_state.grad[source, i_l, i_b].pos[i] + self.links_state.grad[target, i_l, i_b].i_pos_shift[i] = self.links_state.grad[ + source, i_l, i_b + ].i_pos_shift[i] + for i in ti.static(range(4)): + self.links_state.grad[target, i_l, i_b].quat[i] = self.links_state.grad[source, i_l, i_b].quat[i] + self.links_state.grad[target, i_l, i_b].mass_shift = self.links_state.grad[source, i_l, i_b].mass_shift + + ti.loop_config(serialize=self._para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(self.n_geoms, self._B): + self.geoms_state.grad[target, i_l, i_b].friction_ratio = self.geoms_state.grad[ + source, i_l, i_b + ].friction_ratio + + @ti.kernel + def reset_grad_till_frame(self, f: ti.i32): + for i_f, i_n, i_b in ti.ndrange(f, self.n_qs, self._B): + self.qpos.grad[i_f, i_n, i_b] = 0 + + for i_f, i_n, i_b in ti.ndrange(f, self.n_dofs, self._B): + self.dofs_state.grad[i_f, i_n, i_b].vel = 0 + + for i_f, i_n, i_b in ti.ndrange(f, self.n_links, self._B): + self.links_state.grad[i_f, i_n, i_b].pos = 0 + self.links_state.grad[i_f, i_n, i_b].i_pos_shift = 0 + self.links_state.grad[i_f, i_n, i_b].quat = 0 + self.links_state.grad[i_f, i_n, i_b].mass_shift = 0 + + for i_f, i_n, i_b in ti.ndrange(f, self.n_geoms, self._B): + self.geoms_state.grad[i_f, i_n, i_b].friction_ratio = 0 def is_active(self): return self.n_links > 0 @@ -1674,6 +1810,7 @@ def set_base_links_pos( if not unsafe and not torch.isin(links_idx, self._base_links_idx).all(): gs.raise_exception("`links_idx` contains at least one link that is not a base link.") kernel_set_links_pos( + 0, relative, pos, links_idx, @@ -1685,6 +1822,7 @@ def set_base_links_pos( ) if not skip_forward: kernel_forward_kinematics_links_geoms( + 0, envs_idx, links_state=self.links_state, links_info=self.links_info, @@ -1716,6 +1854,7 @@ def set_base_links_quat( if not unsafe and not torch.isin(links_idx, self._base_links_idx).all(): gs.raise_exception("`links_idx` contains at least one link that is not a base link.") kernel_set_links_quat( + 0, relative, quat, links_idx, @@ -1727,6 +1866,7 @@ def set_base_links_quat( ) if not skip_forward: kernel_forward_kinematics_links_geoms( + 0, envs_idx, links_state=self.links_state, links_info=self.links_info, @@ -1789,13 +1929,13 @@ def set_geoms_friction_ratio(self, friction_ratio, geoms_idx=None, envs_idx=None friction_ratio, geoms_idx, envs_idx, self.geoms_state, self._static_rigid_sim_config ) - def set_qpos(self, qpos, qs_idx=None, envs_idx=None, *, skip_forward=False, unsafe=False): + def set_qpos(self, f, qpos, qs_idx=None, envs_idx=None, *, skip_forward=False, unsafe=False): qpos, qs_idx, envs_idx = self._sanitize_1D_io_variables( qpos, qs_idx, self.n_qs, envs_idx, idx_name="qs_idx", skip_allocation=True, unsafe=unsafe ) if self.n_envs == 0: qpos = qpos.unsqueeze(0) - kernel_set_qpos(qpos, qs_idx, envs_idx, self._rigid_global_info, self._static_rigid_sim_config) + kernel_set_qpos(f, qpos, qs_idx, envs_idx, self._rigid_global_info, self._static_rigid_sim_config) self.collider.reset(envs_idx) self.collider.clear(envs_idx) if self.constraint_solver is not None: @@ -1803,6 +1943,7 @@ def set_qpos(self, qpos, qs_idx=None, envs_idx=None, *, skip_forward=False, unsa self.constraint_solver.clear(envs_idx) if not skip_forward: kernel_forward_kinematics_links_geoms( + f, envs_idx, links_state=self.links_state, links_info=self.links_info, @@ -1983,14 +2124,15 @@ def set_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None, *, skip_forw ) if velocity is None: - kernel_set_dofs_zero_velocity(dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config) + kernel_set_dofs_zero_velocity(0, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config) else: if self.n_envs == 0: velocity = velocity.unsqueeze(0) - kernel_set_dofs_velocity(velocity, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config) + kernel_set_dofs_velocity(0, velocity, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config) if not skip_forward: kernel_forward_kinematics_links_geoms( + 0, envs_idx, links_state=self.links_state, links_info=self.links_info, @@ -2013,6 +2155,7 @@ def set_dofs_position(self, position, dofs_idx=None, envs_idx=None, *, skip_forw if self.n_envs == 0: position = position.unsqueeze(0) kernel_set_dofs_position( + 0, position, dofs_idx, envs_idx, @@ -2030,6 +2173,7 @@ def set_dofs_position(self, position, dofs_idx=None, envs_idx=None, *, skip_forw self.constraint_solver.clear(envs_idx) if not skip_forward: kernel_forward_kinematics_links_geoms( + 0, envs_idx, links_state=self.links_state, links_info=self.links_info, @@ -2228,15 +2372,15 @@ def get_dofs_control_force(self, dofs_idx=None, envs_idx=None, *, unsafe=False): return _tensor def get_dofs_force(self, dofs_idx=None, envs_idx=None, *, unsafe=False): - tensor = ti_to_torch(self.dofs_state.force, envs_idx, dofs_idx, transpose=True, unsafe=unsafe) + tensor = ti_to_torch(self.dofs_state.force, envs_idx, dofs_idx, transpose=True, unsafe=unsafe)[:, 0] return tensor.squeeze(0) if self.n_envs == 0 else tensor def get_dofs_velocity(self, dofs_idx=None, envs_idx=None, *, unsafe=False): - tensor = ti_to_torch(self.dofs_state.vel, envs_idx, dofs_idx, transpose=True, unsafe=unsafe) + tensor = ti_to_torch(self.dofs_state.vel, envs_idx, dofs_idx, transpose=True, unsafe=unsafe)[:, 0] return tensor.squeeze(0) if self.n_envs == 0 else tensor def get_dofs_position(self, dofs_idx=None, envs_idx=None, *, unsafe=False): - tensor = ti_to_torch(self.dofs_state.pos, envs_idx, dofs_idx, transpose=True, unsafe=unsafe) + tensor = ti_to_torch(self.dofs_state.pos, envs_idx, dofs_idx, transpose=True, unsafe=unsafe)[:, 0] return tensor.squeeze(0) if self.n_envs == 0 else tensor def get_dofs_kp(self, dofs_idx=None, envs_idx=None, *, unsafe=False): @@ -2385,6 +2529,7 @@ def clear_external_force(self): def update_vgeoms(self): kernel_update_vgeoms( + self.sim.cur_substep_local, self.vgeoms_info, self.vgeoms_state, self.links_state, @@ -2527,12 +2672,24 @@ def n_links(self): return self._n_links return len(self.links) + @property + def max_n_links_per_entity(self): + if self.is_built: + return self._max_n_links_per_entity + return max([len(entity.links) for entity in self._entities]) if len(self._entities) > 0 else 0 + @property def n_joints(self): if self.is_built: return self._n_joints return len(self.joints) + @property + def max_n_joints_per_link(self): + if self.is_built: + return self._max_n_joints_per_link + return max([len(link.joints) for link in self.links]) if len(self.links) > 0 else 0 + @property def n_geoms(self): if self.is_built: @@ -2605,6 +2762,24 @@ def n_dofs(self): return self._n_dofs return sum(entity.n_dofs for entity in self._entities) + @property + def max_n_dofs_per_entity(self): + if self.is_built: + return self._max_n_dofs_per_entity + return max([entity.n_dofs for entity in self._entities]) if len(self._entities) > 0 else 0 + + @property + def max_n_dofs_per_link(self): + if self.is_built: + return self._max_n_dofs_per_link + return max([link.n_dofs for link in self.links]) if len(self.links) > 0 else 0 + + @property + def max_n_dofs_per_joint(self): + if self.is_built: + return self._max_n_dofs_per_joint + return max([joint.n_dofs for joint in self.joints]) if len(self.joints) > 0 else 0 + @property def init_qpos(self): if len(self._entities) == 0: @@ -2631,6 +2806,7 @@ def equalities(self): @gs.maybe_pure @ti.kernel def kernel_compute_mass_matrix( + f: ti.i32, # taichi variables links_state: array_class.LinksState, links_info: array_class.LinksInfo, @@ -2643,6 +2819,7 @@ def kernel_compute_mass_matrix( decompose: ti.template(), ): func_compute_mass_matrix( + f=f, implicit_damping=False, links_state=links_state, links_info=links_info, @@ -2654,6 +2831,7 @@ def kernel_compute_mass_matrix( ) if decompose: func_factor_mass( + f=f, implicit_damping=False, entities_info=entities_info, dofs_state=dofs_state, @@ -2711,7 +2889,7 @@ def kernel_init_meaninertia( static_rigid_sim_config: ti.template(), static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): - n_dofs = rigid_global_info.mass_mat.shape[0] + n_dofs = rigid_global_info.mass_mat.shape[1] n_entities = entities_info.n_links.shape[0] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) for i_b_ in range(envs_idx.shape[0]): @@ -2721,7 +2899,7 @@ def kernel_init_meaninertia( for i_e in range(n_entities): for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.meaninertia[i_b] += rigid_global_info.mass_mat[i_d, i_d, i_b] + rigid_global_info.meaninertia[i_b] += rigid_global_info.mass_mat[0, i_d, i_d, i_b] rigid_global_info.meaninertia[i_b] = rigid_global_info.meaninertia[i_b] / n_dofs else: rigid_global_info.meaninertia[i_b] = 1.0 @@ -2750,8 +2928,9 @@ def kernel_init_dof_fields( static_rigid_sim_config: ti.template(), static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): - n_dofs = dofs_state.ctrl_mode.shape[0] - _B = dofs_state.ctrl_mode.shape[1] + n_steps = dofs_state.ctrl_mode.shape[0] + n_dofs = dofs_state.ctrl_mode.shape[1] + _B = dofs_state.ctrl_mode.shape[2] for I in ti.grouped(dofs_info.invweight): i = I[0] # batching (if any) will be the second dim @@ -2772,13 +2951,13 @@ def kernel_init_dof_fields( dofs_info.kv[I] = dofs_kv[i] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i, b in ti.ndrange(n_dofs, _B): - dofs_state.ctrl_mode[i, b] = gs.CTRL_MODE.FORCE + for f, i, b in ti.ndrange(n_steps, n_dofs, _B): + dofs_state.ctrl_mode[f, i, b] = gs.CTRL_MODE.FORCE if ti.static(static_rigid_sim_config.use_hibernation): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i, b in ti.ndrange(n_dofs, _B): - dofs_state.hibernated[i, b] = False + for f, i, b in ti.ndrange(n_steps, n_dofs, _B): + dofs_state.hibernated[f, i, b] = False rigid_global_info.awake_dofs[i, b] = i ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) @@ -2814,7 +2993,7 @@ def kernel_init_link_fields( static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): n_links = links_parent_idx.shape[0] - _B = links_state.pos.shape[1] + _B = links_state.pos.shape[2] for I in ti.grouped(links_info.invweight): i = I[0] @@ -2852,10 +3031,10 @@ def kernel_init_link_fields( # Update state for root fixed link. Their state will not be updated in forward kinematics later but can be manually changed by user. if links_info.parent_idx[I] == -1 and links_info.is_fixed[I]: for j in ti.static(range(4)): - links_state.quat[i, b][j] = links_quat[i, j] + links_state.quat[0, i, b][j] = links_quat[i, j] for j in ti.static(range(3)): - links_state.pos[i, b][j] = links_pos[i, j] + links_state.pos[0, i, b][j] = links_pos[i, j] for j in ti.static(range(3)): links_state.i_pos_shift[i, b][j] = 0.0 @@ -2864,7 +3043,7 @@ def kernel_init_link_fields( if ti.static(static_rigid_sim_config.use_hibernation): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) for i, b in ti.ndrange(n_links, _B): - links_state.hibernated[i, b] = False + links_state.hibernated[0, i, b] = False rigid_global_info.awake_links[i, b] = i ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) @@ -3017,8 +3196,9 @@ def kernel_init_geom_fields( static_rigid_sim_config: ti.template(), static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): + n_steps = geoms_state.pos.shape[0] n_geoms = geoms_pos.shape[0] - _B = geoms_state.friction_ratio.shape[1] + _B = geoms_state.pos.shape[2] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) for i in range(n_geoms): for j in ti.static(range(3)): @@ -3236,6 +3416,7 @@ def kernel_init_equality_fields( @gs.maybe_pure @ti.kernel def kernel_forward_dynamics( + f: ti.i32, links_state: array_class.LinksState, links_info: array_class.LinksInfo, dofs_state: array_class.DofsState, @@ -3250,6 +3431,7 @@ def kernel_forward_dynamics( contact_island_state: array_class.ContactIslandState, ): func_forward_dynamics( + f=f, links_state=links_state, links_info=links_info, dofs_state=dofs_state, @@ -3267,6 +3449,7 @@ def kernel_forward_dynamics( @gs.maybe_pure @ti.kernel def kernel_update_acc( + f: ti.i32, dofs_state: array_class.DofsState, links_info: array_class.LinksInfo, links_state: array_class.LinksState, @@ -3276,6 +3459,7 @@ def kernel_update_acc( static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): func_update_acc( + f=f, update_cacc=True, dofs_state=dofs_state, links_info=links_info, @@ -3298,6 +3482,7 @@ def func_vel_at_point(pos_world, link_idx, i_b, links_state: array_class.LinksSt @ti.func def func_compute_mass_matrix( + f: ti.i32, implicit_damping: ti.template(), # taichi variables links_state: array_class.LinksState, @@ -3308,171 +3493,141 @@ def func_compute_mass_matrix( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): - _B = links_state.pos.shape[1] - n_links = links_state.pos.shape[0] + _B = links_state.pos.shape[2] + n_links = links_state.pos.shape[1] n_entities = entities_info.n_links.shape[0] - n_dofs = dofs_state.f_ang.shape[0] - - if ti.static(static_rigid_sim_config.use_hibernation): - # crb initialize - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - links_state.crb_inertial[i_l, i_b] = links_state.cinr_inertial[i_l, i_b] - links_state.crb_pos[i_l, i_b] = links_state.cinr_pos[i_l, i_b] - links_state.crb_quat[i_l, i_b] = links_state.cinr_quat[i_l, i_b] - links_state.crb_mass[i_l, i_b] = links_state.cinr_mass[i_l, i_b] - - # crb - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i in range(entities_info.n_links[i_e]): - i_l = entities_info.link_end[i_e] - 1 - i - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] - - if i_p != -1: - links_state.crb_inertial[i_p, i_b] = ( - links_state.crb_inertial[i_p, i_b] + links_state.crb_inertial[i_l, i_b] - ) - links_state.crb_mass[i_p, i_b] = links_state.crb_mass[i_p, i_b] + links_state.crb_mass[i_l, i_b] - - links_state.crb_pos[i_p, i_b] = links_state.crb_pos[i_p, i_b] + links_state.crb_pos[i_l, i_b] - links_state.crb_quat[i_p, i_b] = links_state.crb_quat[i_p, i_b] + links_state.crb_quat[i_l, i_b] - - # mass_mat - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - dofs_state.f_ang[i_d, i_b], dofs_state.f_vel[i_d, i_b] = gu.inertial_mul( - links_state.crb_pos[i_l, i_b], - links_state.crb_inertial[i_l, i_b], - links_state.crb_mass[i_l, i_b], - dofs_state.cdof_vel[i_d, i_b], - dofs_state.cdof_ang[i_d, i_b], - ) - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - for j_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - rigid_global_info.mass_mat[i_d, j_d, i_b] = ( - dofs_state.f_ang[i_d, i_b].dot(dofs_state.cdof_ang[j_d, i_b]) - + dofs_state.f_vel[i_d, i_b].dot(dofs_state.cdof_vel[j_d, i_b]) - ) * rigid_global_info.mass_parent_mask[i_d, j_d] - - # FIXME: Updating the lower-part of the mass matrix is irrelevant - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - for j_d in range(i_d + 1, entities_info.dof_end[i_e]): - rigid_global_info.mass_mat[i_d, j_d, i_b] = rigid_global_info.mass_mat[j_d, i_d, i_b] - - # Take into account motor armature - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat[i_d, i_d, i_b] = ( - rigid_global_info.mass_mat[i_d, i_d, i_b] + dofs_info.armature[I_d] - ) + n_dofs = dofs_state.f_ang.shape[1] - # Take into account first-order correction terms for implicit integration scheme right away - if ti.static(implicit_damping): - for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat[i_d, i_d, i_b] += ( - dofs_info.damping[I_d] * static_rigid_sim_config.substep_dt - ) - if (dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION) or ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY - ): - # qM += d qfrc_actuator / d qvel - rigid_global_info.mass_mat[i_d, i_d, i_b] += ( - dofs_info.kv[I_d] * static_rigid_sim_config.substep_dt - ) - else: - # crb initialize - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): - links_state.crb_inertial[i_l, i_b] = links_state.cinr_inertial[i_l, i_b] - links_state.crb_pos[i_l, i_b] = links_state.cinr_pos[i_l, i_b] - links_state.crb_quat[i_l, i_b] = links_state.cinr_quat[i_l, i_b] - links_state.crb_mass[i_l, i_b] = links_state.cinr_mass[i_l, i_b] + # crb initialize + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_links, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_l = rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 + + links_state.crb_inertial[f, i_l, i_b] = links_state.cinr_inertial[f, i_l, i_b] + links_state.crb_pos[f, i_l, i_b] = links_state.cinr_pos[f, i_l, i_b] + links_state.crb_quat[f, i_l, i_b] = links_state.cinr_quat[f, i_l, i_b] + links_state.crb_mass[f, i_l, i_b] = links_state.cinr_mass[f, i_l, i_b] + + # crb + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_entities, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - # crb - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e, i_b in ti.ndrange(n_entities, _B): for i in range(entities_info.n_links[i_e]): i_l = entities_info.link_end[i_e] - 1 - i I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_p = links_info.parent_idx[I_l] if i_p != -1: - links_state.crb_inertial[i_p, i_b] = ( - links_state.crb_inertial[i_p, i_b] + links_state.crb_inertial[i_l, i_b] + links_state.crb_inertial[f, i_p, i_b] = ( + links_state.crb_inertial[f, i_p, i_b] + links_state.crb_inertial[f, i_l, i_b] + ) + links_state.crb_mass[f, i_p, i_b] = ( + links_state.crb_mass[f, i_p, i_b] + links_state.crb_mass[f, i_l, i_b] ) - links_state.crb_mass[i_p, i_b] = links_state.crb_mass[i_p, i_b] + links_state.crb_mass[i_l, i_b] - links_state.crb_pos[i_p, i_b] = links_state.crb_pos[i_p, i_b] + links_state.crb_pos[i_l, i_b] - links_state.crb_quat[i_p, i_b] = links_state.crb_quat[i_p, i_b] + links_state.crb_quat[i_l, i_b] + links_state.crb_pos[f, i_p, i_b] = ( + links_state.crb_pos[f, i_p, i_b] + links_state.crb_pos[f, i_l, i_b] + ) + links_state.crb_quat[f, i_p, i_b] = ( + links_state.crb_quat[f, i_p, i_b] + links_state.crb_quat[f, i_l, i_b] + ) - # mass_mat - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): + # mass_mat + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_links, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_l = rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - dofs_state.f_ang[i_d, i_b], dofs_state.f_vel[i_d, i_b] = gu.inertial_mul( - links_state.crb_pos[i_l, i_b], - links_state.crb_inertial[i_l, i_b], - links_state.crb_mass[i_l, i_b], - dofs_state.cdof_vel[i_d, i_b], - dofs_state.cdof_ang[i_d, i_b], + dofs_state.f_ang[f, i_d, i_b], dofs_state.f_vel[f, i_d, i_b] = gu.inertial_mul( + links_state.crb_pos[f, i_l, i_b], + links_state.crb_inertial[f, i_l, i_b], + links_state.crb_mass[f, i_l, i_b], + dofs_state.cdof_vel[f, i_d, i_b], + dofs_state.cdof_ang[f, i_d, i_b], ) - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_e, i_b in ti.ndrange(n_entities, _B): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_entities, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) + for i_d, j_d in ti.ndrange( (entities_info.dof_start[i_e], entities_info.dof_end[i_e]), (entities_info.dof_start[i_e], entities_info.dof_end[i_e]), ): - rigid_global_info.mass_mat[i_d, j_d, i_b] = ( - dofs_state.f_ang[i_d, i_b].dot(dofs_state.cdof_ang[j_d, i_b]) - + dofs_state.f_vel[i_d, i_b].dot(dofs_state.cdof_vel[j_d, i_b]) + rigid_global_info.mass_mat[f, i_d, j_d, i_b] = ( + dofs_state.f_ang[f, i_d, i_b].dot(dofs_state.cdof_ang[f, j_d, i_b]) + + dofs_state.f_vel[f, i_d, i_b].dot(dofs_state.cdof_vel[f, j_d, i_b]) ) * rigid_global_info.mass_parent_mask[i_d, j_d] # FIXME: Updating the lower-part of the mass matrix is irrelevant for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): for j_d in range(i_d + 1, entities_info.dof_end[i_e]): - rigid_global_info.mass_mat[i_d, j_d, i_b] = rigid_global_info.mass_mat[j_d, i_d, i_b] - - # Take into account motor armature - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat[i_d, i_d, i_b] = ( - rigid_global_info.mass_mat[i_d, i_d, i_b] + dofs_info.armature[I_d] - ) + rigid_global_info.mass_mat[f, i_d, j_d, i_b] = rigid_global_info.mass_mat[f, j_d, i_d, i_b] - # Take into account first-order correction terms for implicit integration scheme right away - if ti.static(implicit_damping): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): + # Take into account motor armature + for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.damping[I_d] * static_rigid_sim_config.substep_dt - if (dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION) or ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY - ): - # qM += d qfrc_actuator / d qvel - rigid_global_info.mass_mat[i_d, i_d, i_b] += dofs_info.kv[I_d] * static_rigid_sim_config.substep_dt + rigid_global_info.mass_mat[f, i_d, i_d, i_b] = ( + rigid_global_info.mass_mat[f, i_d, i_d, i_b] + dofs_info.armature[I_d] + ) + + # Take into account first-order correction terms for implicit integration scheme right away + if ti.static(implicit_damping): + for i_d in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): + I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d + rigid_global_info.mass_mat[f, i_d, i_d, i_b] += ( + dofs_info.damping[I_d] * static_rigid_sim_config.substep_dt + ) + if (dofs_state.ctrl_mode[f, i_d, i_b] == gs.CTRL_MODE.POSITION) or ( + dofs_state.ctrl_mode[f, i_d, i_b] == gs.CTRL_MODE.VELOCITY + ): + # qM += d qfrc_actuator / d qvel + rigid_global_info.mass_mat[f, i_d, i_d, i_b] += ( + dofs_info.kv[I_d] * static_rigid_sim_config.substep_dt + ) @ti.func def func_factor_mass( + f: ti.i32, implicit_damping: ti.template(), entities_info: array_class.EntitiesInfo, dofs_state: array_class.DofsState, @@ -3483,98 +3638,71 @@ def func_factor_mass( """ Compute Cholesky decomposition (L^T @ D @ L) of mass matrix. """ - _B = dofs_state.ctrl_mode.shape[1] + _B = dofs_state.ctrl_mode.shape[2] n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - - if rigid_global_info._mass_mat_mask[i_e, i_b] == 1: - entity_dof_start = entities_info.dof_start[i_e] - entity_dof_end = entities_info.dof_end[i_e] - n_dofs = entities_info.n_dofs[i_e] - - for i_d in range(entity_dof_start, entity_dof_end): - for j_d in range(entity_dof_start, i_d + 1): - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b] - - if ti.static(implicit_damping): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( - dofs_info.damping[I_d] * static_rigid_sim_config.substep_dt - ) - if ti.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): - if (dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION) or ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY - ): - rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( - dofs_info.kv[I_d] * static_rigid_sim_config.substep_dt - ) - - for i_d_ in range(n_dofs): - i_d = entity_dof_end - i_d_ - 1 - rigid_global_info.mass_mat_D_inv[i_d, i_b] = 1.0 / rigid_global_info.mass_mat_L[i_d, i_d, i_b] - - for j_d_ in range(i_d - entity_dof_start): - j_d = i_d - j_d_ - 1 - a = rigid_global_info.mass_mat_L[i_d, j_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] - for k_d in range(entity_dof_start, j_d + 1): - rigid_global_info.mass_mat_L[j_d, k_d, i_b] -= ( - a * rigid_global_info.mass_mat_L[i_d, k_d, i_b] - ) - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = a + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_entities, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them. - rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0 - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_e, i_b in ti.ndrange(n_entities, _B): - if rigid_global_info._mass_mat_mask[i_e, i_b] == 1: + if rigid_global_info._mass_mat_mask[f, i_e, i_b] == 1: entity_dof_start = entities_info.dof_start[i_e] entity_dof_end = entities_info.dof_end[i_e] n_dofs = entities_info.n_dofs[i_e] for i_d in range(entity_dof_start, entity_dof_end): for j_d in range(entity_dof_start, i_d + 1): - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = rigid_global_info.mass_mat[i_d, j_d, i_b] + rigid_global_info.mass_mat_L[f, i_d, j_d, i_b] = rigid_global_info.mass_mat[f, i_d, j_d, i_b] if ti.static(implicit_damping): I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( + rigid_global_info.mass_mat_L[f, i_d, i_d, i_b] += ( dofs_info.damping[I_d] * static_rigid_sim_config.substep_dt ) if ti.static(static_rigid_sim_config.integrator == gs.integrator.implicitfast): - if (dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION) or ( - dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY + if (dofs_state.ctrl_mode[f, i_d, i_b] == gs.CTRL_MODE.POSITION) or ( + dofs_state.ctrl_mode[f, i_d, i_b] == gs.CTRL_MODE.VELOCITY ): - rigid_global_info.mass_mat_L[i_d, i_d, i_b] += ( + rigid_global_info.mass_mat_L[f, i_d, i_d, i_b] += ( dofs_info.kv[I_d] * static_rigid_sim_config.substep_dt ) for i_d_ in range(n_dofs): i_d = entity_dof_end - i_d_ - 1 - rigid_global_info.mass_mat_D_inv[i_d, i_b] = 1.0 / rigid_global_info.mass_mat_L[i_d, i_d, i_b] + rigid_global_info.mass_mat_D_inv[f, i_d, i_b] = 1.0 / rigid_global_info.mass_mat_L[f, i_d, i_d, i_b] for j_d_ in range(i_d - entity_dof_start): j_d = i_d - j_d_ - 1 - a = rigid_global_info.mass_mat_L[i_d, j_d, i_b] * rigid_global_info.mass_mat_D_inv[i_d, i_b] + a = ( + rigid_global_info.mass_mat_L[f, i_d, j_d, i_b] + * rigid_global_info.mass_mat_D_inv[f, i_d, i_b] + ) for k_d in range(entity_dof_start, j_d + 1): - rigid_global_info.mass_mat_L[j_d, k_d, i_b] -= ( - a * rigid_global_info.mass_mat_L[i_d, k_d, i_b] + rigid_global_info.mass_mat_L[f, j_d, k_d, i_b] -= ( + a * rigid_global_info.mass_mat_L[f, i_d, k_d, i_b] ) - rigid_global_info.mass_mat_L[i_d, j_d, i_b] = a + rigid_global_info.mass_mat_L[f, i_d, j_d, i_b] = a # FIXME: Diagonal coeffs of L are ignored in computations, so no need to update them. - rigid_global_info.mass_mat_L[i_d, i_d, i_b] = 1.0 + rigid_global_info.mass_mat_L[f, i_d, i_d, i_b] = 1.0 @ti.func def func_solve_mass_batched( vec: array_class.V_ANNOTATION, out: array_class.V_ANNOTATION, + i_f: ti.int32, i_b: ti.int32, entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, @@ -3582,35 +3710,64 @@ def func_solve_mass_batched( ): n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - if rigid_global_info._mass_mat_mask[i_e, i_b] == 1: - entity_dof_start = entities_info.dof_start[i_e] - entity_dof_end = entities_info.dof_end[i_e] - n_dofs = entities_info.n_dofs[i_e] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_0 in ( + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_entities) + ): + i_e = rigid_global_info.awake_entities[i_0, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 - # Step 1: Solve w st. L^T @ w = y - for i_d_ in range(n_dofs): - i_d = entity_dof_end - i_d_ - 1 - out[i_d, i_b] = vec[i_d, i_b] - for j_d in range(i_d + 1, entity_dof_end): - out[i_d, i_b] -= rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b] + if rigid_global_info._mass_mat_mask[i_f, i_e, i_b] == 1: + entity_dof_start = entities_info.dof_start[i_e] + entity_dof_end = entities_info.dof_end[i_e] + n_dofs = entities_info.n_dofs[i_e] - # Step 2: z = D^{-1} w - for i_d in range(entity_dof_start, entity_dof_end): - out[i_d, i_b] *= rigid_global_info.mass_mat_D_inv[i_d, i_b] + # Step 1: Solve w st. L^T @ w = y + for i_d_ in range(n_dofs): + i_d = entity_dof_end - i_d_ - 1 + out[i_d, i_b] = vec[i_d, i_b] + for j_d in range(i_d + 1, entity_dof_end): + out[i_d, i_b] -= rigid_global_info.mass_mat_L[i_f, j_d, i_d, i_b] * out[j_d, i_b] - # Step 3: Solve x st. L @ x = z - for i_d in range(entity_dof_start, entity_dof_end): - for j_d in range(entity_dof_start, i_d): - out[i_d, i_b] -= rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b] - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_e in range(n_entities): - if rigid_global_info._mass_mat_mask[i_e, i_b] == 1: + # Step 2: z = D^{-1} w + for i_d in range(entity_dof_start, entity_dof_end): + out[i_d, i_b] *= rigid_global_info.mass_mat_D_inv[i_f, i_d, i_b] + + # Step 3: Solve x st. L @ x = z + for i_d in range(entity_dof_start, entity_dof_end): + for j_d in range(entity_dof_start, i_d): + out[i_d, i_b] -= rigid_global_info.mass_mat_L[i_f, i_d, j_d, i_b] * out[j_d, i_b] + + +@ti.func +def func_solve_mass( + f: ti.i32, + vec: array_class.V_ANNOTATION, + out: array_class.V_ANNOTATION, + dofs_state: array_class.DofsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + _B = dofs_state.acc.shape[2] + n_entities = entities_info.n_links.shape[0] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_b in range(_B): + for i_0 in ( + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_entities) + ): + i_e = ( + rigid_global_info.awake_entities[i_0, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) + + if rigid_global_info._mass_mat_mask[f, i_e, i_b] == 1: entity_dof_start = entities_info.dof_start[i_e] entity_dof_end = entities_info.dof_end[i_e] n_dofs = entities_info.n_dofs[i_e] @@ -3618,39 +3775,18 @@ def func_solve_mass_batched( # Step 1: Solve w st. L^T @ w = y for i_d_ in range(n_dofs): i_d = entity_dof_end - i_d_ - 1 - out[i_d, i_b] = vec[i_d, i_b] + out[f, i_d, i_b] = vec[f, i_d, i_b] for j_d in range(i_d + 1, entity_dof_end): - out[i_d, i_b] -= rigid_global_info.mass_mat_L[j_d, i_d, i_b] * out[j_d, i_b] + out[f, i_d, i_b] -= rigid_global_info.mass_mat_L[f, j_d, i_d, i_b] * out[f, j_d, i_b] # Step 2: z = D^{-1} w for i_d in range(entity_dof_start, entity_dof_end): - out[i_d, i_b] *= rigid_global_info.mass_mat_D_inv[i_d, i_b] + out[f, i_d, i_b] *= rigid_global_info.mass_mat_D_inv[f, i_d, i_b] # Step 3: Solve x st. L @ x = z for i_d in range(entity_dof_start, entity_dof_end): for j_d in range(entity_dof_start, i_d): - out[i_d, i_b] -= rigid_global_info.mass_mat_L[i_d, j_d, i_b] * out[j_d, i_b] - - -@ti.func -def func_solve_mass( - vec: array_class.V_ANNOTATION, - out: array_class.V_ANNOTATION, - entities_info: array_class.EntitiesInfo, - rigid_global_info: array_class.RigidGlobalInfo, - static_rigid_sim_config: ti.template(), -): - _B = out.shape[1] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_b in range(_B): - func_solve_mass_batched( - vec, - out, - i_b, - entities_info=entities_info, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - ) + out[f, i_d, i_b] -= rigid_global_info.mass_mat_L[f, i_d, j_d, i_b] * out[f, j_d, i_b] @gs.maybe_pure @@ -3970,6 +4106,7 @@ def kernel_rigid_entity_inverse_kinematics( # decomposed kernels should happen in the block below. This block will be handled by composer and composed into a single kernel @ti.func def func_forward_dynamics( + f: ti.i32, links_state: array_class.LinksState, links_info: array_class.LinksInfo, dofs_state: array_class.DofsState, @@ -3983,6 +4120,7 @@ def func_forward_dynamics( contact_island_state: array_class.ContactIslandState, ): func_compute_mass_matrix( + f=f, implicit_damping=ti.static(static_rigid_sim_config.integrator == gs.integrator.approximate_implicitfast), links_state=links_state, links_info=links_info, @@ -3993,6 +4131,7 @@ def func_forward_dynamics( static_rigid_sim_config=static_rigid_sim_config, ) func_factor_mass( + f=f, implicit_damping=False, entities_info=entities_info, dofs_state=dofs_state, @@ -4001,6 +4140,7 @@ def func_forward_dynamics( static_rigid_sim_config=static_rigid_sim_config, ) func_torque_and_passive_force( + f=f, entities_state=entities_state, entities_info=entities_info, dofs_state=dofs_state, @@ -4014,6 +4154,7 @@ def func_forward_dynamics( contact_island_state=contact_island_state, ) func_update_acc( + f=f, update_cacc=False, dofs_state=dofs_state, links_info=links_info, @@ -4023,6 +4164,7 @@ def func_forward_dynamics( static_rigid_sim_config=static_rigid_sim_config, ) func_update_force( + f=f, links_state=links_state, links_info=links_info, entities_info=entities_info, @@ -4031,6 +4173,7 @@ def func_forward_dynamics( ) # self._func_actuation() func_bias_force( + f=f, dofs_state=dofs_state, links_state=links_state, links_info=links_info, @@ -4038,6 +4181,7 @@ def func_forward_dynamics( static_rigid_sim_config=static_rigid_sim_config, ) func_compute_qacc( + f=f, dofs_state=dofs_state, entities_info=entities_info, rigid_global_info=rigid_global_info, @@ -4062,6 +4206,7 @@ def kernel_clear_external_force( @ti.func def func_update_cartesian_space( + f, i_b, links_state: array_class.LinksState, links_info: array_class.LinksInfo, @@ -4075,7 +4220,9 @@ def func_update_cartesian_space( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): + # TODO: Non-differentiable function, should we make it differentiable? func_forward_kinematics( + f, i_b, links_state=links_state, links_info=links_info, @@ -4088,6 +4235,7 @@ def func_update_cartesian_space( static_rigid_sim_config=static_rigid_sim_config, ) func_COM_links( + f, i_b, links_state=links_state, links_info=links_info, @@ -4100,6 +4248,7 @@ def func_update_cartesian_space( static_rigid_sim_config=static_rigid_sim_config, ) func_forward_velocity( + f, i_b, entities_info=entities_info, links_info=links_info, @@ -4111,6 +4260,7 @@ def func_update_cartesian_space( ) func_update_geoms( + f, i_b=i_b, entities_info=entities_info, geoms_info=geoms_info, @@ -4124,6 +4274,7 @@ def func_update_cartesian_space( @gs.maybe_pure @ti.kernel def kernel_step_1( + f: ti.i32, links_state: array_class.LinksState, links_info: array_class.LinksInfo, joints_state: array_class.JointsState, @@ -4140,10 +4291,11 @@ def kernel_step_1( static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): if ti.static(static_rigid_sim_config.enable_mujoco_compatibility): - _B = links_state.pos.shape[1] + _B = links_state.pos.shape[2] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_b in range(_B): func_update_cartesian_space( + f=f, i_b=i_b, links_state=links_state, links_info=links_info, @@ -4159,6 +4311,7 @@ def kernel_step_1( ) func_forward_dynamics( + f=f, links_state=links_state, links_info=links_info, dofs_state=dofs_state, @@ -4175,6 +4328,7 @@ def kernel_step_1( @ti.func def func_implicit_damping( + f: ti.i32, dofs_state: array_class.DofsState, dofs_info: array_class.DofsInfo, entities_info: array_class.EntitiesInfo, @@ -4183,7 +4337,7 @@ def func_implicit_damping( ): n_entities = entities_info.dof_start.shape[0] - _B = dofs_state.ctrl_mode.shape[1] + _B = dofs_state.ctrl_mode.shape[2] # Determine whether the mass matrix must be re-computed to take into account first-order correction terms. # Note that avoiding inverting the mass matrix twice would not only speed up simulation but also improving # numerical stability as computing post-damping accelerations from forces is not necessary anymore. @@ -4192,7 +4346,7 @@ def func_implicit_damping( or static_rigid_sim_config.integrator == gs.integrator.Euler ): for i_e, i_b in ti.ndrange(n_entities, _B): - rigid_global_info._mass_mat_mask[i_e, i_b] = 0 + rigid_global_info._mass_mat_mask[f, i_e, i_b] = 0 ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) for i_e, i_b in ti.ndrange(n_entities, _B): @@ -4201,15 +4355,16 @@ def func_implicit_damping( for i_d in range(entity_dof_start, entity_dof_end): I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d if dofs_info.damping[I_d] > gs.EPS: - rigid_global_info._mass_mat_mask[i_e, i_b] = 1 + rigid_global_info._mass_mat_mask[f, i_e, i_b] = 1 if ti.static(static_rigid_sim_config.integrator != gs.integrator.Euler): if ( - (dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION) - or (dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY) + (dofs_state.ctrl_mode[f, i_d, i_b] == gs.CTRL_MODE.POSITION) + or (dofs_state.ctrl_mode[f, i_d, i_b] == gs.CTRL_MODE.VELOCITY) ) and dofs_info.kv[I_d] > gs.EPS: - rigid_global_info._mass_mat_mask[i_e, i_b] = 1 + rigid_global_info._mass_mat_mask[f, i_e, i_b] = 1 func_factor_mass( + f=f, implicit_damping=True, entities_info=entities_info, dofs_state=dofs_state, @@ -4218,8 +4373,10 @@ def func_implicit_damping( static_rigid_sim_config=static_rigid_sim_config, ) func_solve_mass( + f=f, vec=dofs_state.force, out=dofs_state.acc, + dofs_state=dofs_state, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, @@ -4231,12 +4388,13 @@ def func_implicit_damping( or static_rigid_sim_config.integrator == gs.integrator.Euler ): for i_e, i_b in ti.ndrange(n_entities, _B): - rigid_global_info._mass_mat_mask[i_e, i_b] = 1 + rigid_global_info._mass_mat_mask[f, i_e, i_b] = 1 @gs.maybe_pure @ti.kernel def kernel_step_2( + f: ti.i32, dofs_state: array_class.DofsState, dofs_info: array_class.DofsInfo, links_info: array_class.LinksInfo, @@ -4259,6 +4417,7 @@ def kernel_step_2( # before and after integration under the effect of external forces and constraints. This means that # acceleration data will be shifted one timestep in the past, but there isn't really any way around. func_update_acc( + f=f, update_cacc=True, dofs_state=dofs_state, links_info=links_info, @@ -4270,6 +4429,7 @@ def kernel_step_2( if ti.static(static_rigid_sim_config.integrator != gs.integrator.approximate_implicitfast): func_implicit_damping( + f=f, dofs_state=dofs_state, dofs_info=dofs_info, entities_info=entities_info, @@ -4278,6 +4438,7 @@ def kernel_step_2( ) func_integrate( + f=f, dofs_state=dofs_state, links_info=links_info, joints_info=joints_info, @@ -4305,10 +4466,11 @@ def kernel_step_2( ) if ti.static(not static_rigid_sim_config.enable_mujoco_compatibility): - _B = links_state.pos.shape[1] + _B = links_state.pos.shape[2] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_b in range(_B): func_update_cartesian_space( + f=f + 1, i_b=i_b, links_state=links_state, links_info=links_info, @@ -4327,6 +4489,7 @@ def kernel_step_2( @gs.maybe_pure @ti.kernel def kernel_forward_kinematics_links_geoms( + f: ti.i32, envs_idx: ti.types.ndarray(), links_state: array_class.LinksState, links_info: array_class.LinksInfo, @@ -4341,10 +4504,12 @@ def kernel_forward_kinematics_links_geoms( static_rigid_sim_config: ti.template(), static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): + # TODO: Non-differentiable function, should we make it differentiable? for i_b_ in range(envs_idx.shape[0]): i_b = envs_idx[i_b_] func_update_cartesian_space( + f=f, i_b=i_b, links_state=links_state, links_info=links_info, @@ -4362,6 +4527,7 @@ def kernel_forward_kinematics_links_geoms( @ti.func def func_COM_links( + f, i_b, links_state: array_class.LinksState, links_info: array_class.LinksInfo, @@ -4375,311 +4541,185 @@ def func_COM_links( ): n_links = links_info.root_idx.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - links_state.root_COM[i_l, i_b].fill(0.0) - links_state.mass_sum[i_l, i_b] = 0.0 + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_links) + ): + i_l = rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + links_state.root_COM[f, i_l, i_b].fill(0.0) + links_state.mass_sum[f, i_l, i_b] = 0.0 - mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] - ( - links_state.i_pos[i_l, i_b], - links_state.i_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat( - links_info.inertial_pos[I_l] + links_state.i_pos_shift[i_l, i_b], - links_info.inertial_quat[I_l], - links_state.pos[i_l, i_b], - links_state.quat[i_l, i_b], - ) + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_links) + ): + i_l = rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_r = links_info.root_idx[I_l] - links_state.mass_sum[i_r, i_b] += mass - links_state.root_COM[i_r, i_b] += mass * links_state.i_pos[i_l, i_b] + mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] + ( + links_state.i_pos[f, i_l, i_b], + links_state.i_quat[f, i_l, i_b], + ) = gu.ti_transform_pos_quat_by_trans_quat( + links_info.inertial_pos[I_l] + links_state.i_pos_shift[i_l, i_b], + links_info.inertial_quat[I_l], + links_state.pos[f, i_l, i_b], + links_state.quat[f, i_l, i_b], + ) - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + i_r = links_info.root_idx[I_l] + links_state.mass_sum[f, i_r, i_b] += mass + links_state.root_COM[f, i_r, i_b] += mass * links_state.i_pos[f, i_l, i_b] - i_r = links_info.root_idx[I_l] - if i_l == i_r: - links_state.root_COM[i_l, i_b] = links_state.root_COM[i_l, i_b] / links_state.mass_sum[i_l, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_links) + ): + i_l = rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + i_r = links_info.root_idx[I_l] + if i_l == i_r and links_state.mass_sum[f, i_l, i_b] > 0.0: + links_state.root_COM[f, i_l, i_b] = links_state.root_COM[f, i_l, i_b] / links_state.mass_sum[f, i_l, i_b] - i_r = links_info.root_idx[I_l] - links_state.root_COM[i_l, i_b] = links_state.root_COM[i_r, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_links) + ): + i_l = rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + i_r = links_info.root_idx[I_l] + links_state.root_COM[f, i_l, i_b] = links_state.root_COM[f, i_r, i_b] - i_r = links_info.root_idx[I_l] - links_state.i_pos[i_l, i_b] = links_state.i_pos[i_l, i_b] - links_state.root_COM[i_l, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_links) + ): + i_l = rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_inertial = links_info.inertial_i[I_l] - i_mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] - ( - links_state.cinr_inertial[i_l, i_b], - links_state.cinr_pos[i_l, i_b], - links_state.cinr_quat[i_l, i_b], - links_state.cinr_mass[i_l, i_b], - ) = gu.ti_transform_inertia_by_trans_quat( - i_inertial, i_mass, links_state.i_pos[i_l, i_b], links_state.i_quat[i_l, i_b] - ) + i_r = links_info.root_idx[I_l] + links_state.i_pos[f, i_l, i_b] = links_state.i_pos[f, i_l, i_b] - links_state.root_COM[f, i_l, i_b] + + i_inertial = links_info.inertial_i[I_l] + i_mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] + ( + links_state.cinr_inertial[f, i_l, i_b], + links_state.cinr_pos[f, i_l, i_b], + links_state.cinr_quat[f, i_l, i_b], + links_state.cinr_mass[f, i_l, i_b], + ) = gu.ti_transform_inertia_by_trans_quat( + i_inertial, i_mass, links_state.i_pos[f, i_l, i_b], links_state.i_quat[f, i_l, i_b] + ) - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_links) + ): + i_l = rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] + if links_info.n_dofs[I_l] == 0: + continue - _i_j = links_info.joint_start[I_l] - _I_j = [_i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else _i_j - joint_type = joints_info.type[_I_j] + i_p = links_info.parent_idx[I_l] - p_pos = ti.Vector.zero(gs.ti_float, 3) - p_quat = gu.ti_identity_quat() - if i_p != -1: - p_pos = links_state.pos[i_p, i_b] - p_quat = links_state.quat[i_p, i_b] + _i_j = links_info.joint_start[I_l] + _I_j = [_i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else _i_j + joint_type = joints_info.type[_I_j] - if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1): - links_state.j_pos[i_l, i_b] = links_state.pos[i_l, i_b] - links_state.j_quat[i_l, i_b] = links_state.quat[i_l, i_b] - else: - ( - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat(links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat) + p_pos = ti.Vector.zero(gs.ti_float, 3) + p_quat = gu.ti_identity_quat() + if i_p != -1: + p_pos = links_state.pos[f, i_p, i_b] + p_quat = links_state.quat[f, i_p, i_b] - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1): + links_state.j_pos[f, i_l, i_b] = links_state.pos[f, i_l, i_b] + links_state.j_quat[f, i_l, i_b] = links_state.quat[f, i_l, i_b] + else: + ( + links_state.j_pos[f, i_l, i_b], + links_state.j_quat[f, i_l, i_b], + ) = gu.ti_transform_pos_quat_by_trans_quat(links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat) - ( - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat( - joints_info.pos[I_j], - gu.ti_identity_quat(), - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) + for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - # cdof_fn - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - - i_j = links_info.joint_start[I_l] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] - - if joint_type == gs.JOINT_TYPE.FREE: - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - dofs_state.cdof_vel[i_d, i_b] = dofs_info.motion_vel[I_d] - dofs_state.cdof_ang[i_d, i_b] = gu.ti_transform_by_quat( - dofs_info.motion_ang[I_d], links_state.j_quat[i_l, i_b] - ) - - offset_pos = links_state.root_COM[i_l, i_b] - links_state.j_pos[i_l, i_b] - ( - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - ) = gu.ti_transform_motion_by_trans_quat( - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - offset_pos, - gu.ti_identity_quat(), - ) - - dofs_state.cdofvel_ang[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - dofs_state.cdofvel_vel[i_d, i_b] = dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - - elif joint_type == gs.JOINT_TYPE.FIXED: - pass - else: - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - motion_vel = dofs_info.motion_vel[I_d] - motion_ang = dofs_info.motion_ang[I_d] - - dofs_state.cdof_ang[i_d, i_b] = gu.ti_transform_by_quat(motion_ang, links_state.j_quat[i_l, i_b]) - dofs_state.cdof_vel[i_d, i_b] = gu.ti_transform_by_quat(motion_vel, links_state.j_quat[i_l, i_b]) - - offset_pos = links_state.root_COM[i_l, i_b] - links_state.j_pos[i_l, i_b] - ( - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - ) = gu.ti_transform_motion_by_trans_quat( - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], - offset_pos, - gu.ti_identity_quat(), - ) - - dofs_state.cdofvel_ang[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - dofs_state.cdofvel_vel[i_d, i_b] = dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - links_state.root_COM[i_l, i_b].fill(0.0) - links_state.mass_sum[i_l, i_b] = 0.0 - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] - ( - links_state.i_pos[i_l, i_b], - links_state.i_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat( - links_info.inertial_pos[I_l] + links_state.i_pos_shift[i_l, i_b], - links_info.inertial_quat[I_l], - links_state.pos[i_l, i_b], - links_state.quat[i_l, i_b], - ) - - i_r = links_info.root_idx[I_l] - links_state.mass_sum[i_r, i_b] += mass - links_state.root_COM[i_r, i_b] += mass * links_state.i_pos[i_l, i_b] - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - i_r = links_info.root_idx[I_l] - if i_l == i_r: - if links_state.mass_sum[i_l, i_b] > 0.0: - links_state.root_COM[i_l, i_b] = links_state.root_COM[i_l, i_b] / links_state.mass_sum[i_l, i_b] - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - i_r = links_info.root_idx[I_l] - links_state.root_COM[i_l, i_b] = links_state.root_COM[i_r, i_b] - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - i_r = links_info.root_idx[I_l] - links_state.i_pos[i_l, i_b] = links_state.i_pos[i_l, i_b] - links_state.root_COM[i_l, i_b] - - i_inertial = links_info.inertial_i[I_l] - i_mass = links_info.inertial_mass[I_l] + links_state.mass_shift[i_l, i_b] - ( - links_state.cinr_inertial[i_l, i_b], - links_state.cinr_pos[i_l, i_b], - links_state.cinr_quat[i_l, i_b], - links_state.cinr_mass[i_l, i_b], - ) = gu.ti_transform_inertia_by_trans_quat( - i_inertial, i_mass, links_state.i_pos[i_l, i_b], links_state.i_quat[i_l, i_b] - ) - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue - - i_p = links_info.parent_idx[I_l] - - _i_j = links_info.joint_start[I_l] - _I_j = [_i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else _i_j - joint_type = joints_info.type[_I_j] - - p_pos = ti.Vector.zero(gs.ti_float, 3) - p_quat = gu.ti_identity_quat() - if i_p != -1: - p_pos = links_state.pos[i_p, i_b] - p_quat = links_state.quat[i_p, i_b] - - if joint_type == gs.JOINT_TYPE.FREE or (links_info.is_fixed[I_l] and i_p == -1): - links_state.j_pos[i_l, i_b] = links_state.pos[i_l, i_b] - links_state.j_quat[i_l, i_b] = links_state.quat[i_l, i_b] - else: ( - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat(links_info.pos[I_l], links_info.quat[I_l], p_pos, p_quat) + links_state.j_pos[f, i_l, i_b], + links_state.j_quat[f, i_l, i_b], + ) = gu.ti_transform_pos_quat_by_trans_quat( + joints_info.pos[I_j], + gu.ti_identity_quat(), + links_state.j_pos[f, i_l, i_b], + links_state.j_quat[f, i_l, i_b], + ) - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l_ in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_links) + ): + i_l = rigid_global_info.awake_links[i_l_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_l_ + I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - ( - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat( - joints_info.pos[I_j], - gu.ti_identity_quat(), - links_state.j_pos[i_l, i_b], - links_state.j_quat[i_l, i_b], - ) + if links_info.n_dofs[I_l] == 0: + continue - # cdof_fn - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l in range(n_links): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue + for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): + offset_pos = links_state.root_COM[f, i_l, i_b] - joints_state.xanchor[f, i_j, i_b] + I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + joint_type = joints_info.type[I_j] - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - offset_pos = links_state.root_COM[i_l, i_b] - joints_state.xanchor[i_j, i_b] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] + dof_start = joints_info.dof_start[I_j] - dof_start = joints_info.dof_start[I_j] + if joint_type == gs.JOINT_TYPE.REVOLUTE: + dofs_state.cdof_ang[f, dof_start, i_b] = joints_state.xaxis[f, i_j, i_b] + dofs_state.cdof_vel[f, dof_start, i_b] = joints_state.xaxis[f, i_j, i_b].cross(offset_pos) + elif joint_type == gs.JOINT_TYPE.PRISMATIC: + dofs_state.cdof_ang[f, dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[f, dof_start, i_b] = joints_state.xaxis[f, i_j, i_b] + elif joint_type == gs.JOINT_TYPE.SPHERICAL: + xmat_T = gu.ti_quat_to_R(links_state.quat[f, i_l, i_b]).transpose() + for i in ti.static(range(3)): + dofs_state.cdof_ang[f, i + dof_start, i_b] = xmat_T[i, :] + dofs_state.cdof_vel[f, i + dof_start, i_b] = xmat_T[i, :].cross(offset_pos) + elif joint_type == gs.JOINT_TYPE.FREE: + for i in ti.static(range(3)): + dofs_state.cdof_ang[f, i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[f, i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) + dofs_state.cdof_vel[f, i + dof_start, i_b][i] = 1.0 - if joint_type == gs.JOINT_TYPE.REVOLUTE: - dofs_state.cdof_ang[dof_start, i_b] = joints_state.xaxis[i_j, i_b] - dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b].cross(offset_pos) - elif joint_type == gs.JOINT_TYPE.PRISMATIC: - dofs_state.cdof_ang[dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) - dofs_state.cdof_vel[dof_start, i_b] = joints_state.xaxis[i_j, i_b] - elif joint_type == gs.JOINT_TYPE.SPHERICAL: - xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b]).transpose() - for i in ti.static(range(3)): - dofs_state.cdof_ang[i + dof_start, i_b] = xmat_T[i, :] - dofs_state.cdof_vel[i + dof_start, i_b] = xmat_T[i, :].cross(offset_pos) - elif joint_type == gs.JOINT_TYPE.FREE: - for i in ti.static(range(3)): - dofs_state.cdof_ang[i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) - dofs_state.cdof_vel[i + dof_start, i_b] = ti.Vector.zero(gs.ti_float, 3) - dofs_state.cdof_vel[i + dof_start, i_b][i] = 1.0 - - xmat_T = gu.ti_quat_to_R(links_state.quat[i_l, i_b]).transpose() - for i in ti.static(range(3)): - dofs_state.cdof_ang[i + dof_start + 3, i_b] = xmat_T[i, :] - dofs_state.cdof_vel[i + dof_start + 3, i_b] = xmat_T[i, :].cross(offset_pos) + xmat_T = gu.ti_quat_to_R(links_state.quat[f, i_l, i_b]).transpose() + for i in ti.static(range(3)): + dofs_state.cdof_ang[f, i + dof_start + 3, i_b] = xmat_T[i, :] + dofs_state.cdof_vel[f, i + dof_start + 3, i_b] = xmat_T[i, :].cross(offset_pos) - for i_d in range(dof_start, joints_info.dof_end[I_j]): - dofs_state.cdofvel_ang[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - dofs_state.cdofvel_vel[i_d, i_b] = dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] + for i_d in range(dof_start, joints_info.dof_end[I_j]): + dofs_state.cdofvel_ang[f, i_d, i_b] = dofs_state.cdof_ang[f, i_d, i_b] * dofs_state.vel[f, i_d, i_b] + dofs_state.cdofvel_vel[f, i_d, i_b] = dofs_state.cdof_vel[f, i_d, i_b] * dofs_state.vel[f, i_d, i_b] @ti.func def func_forward_kinematics( + f, i_b, links_state: array_class.LinksState, links_info: array_class.LinksInfo, @@ -4691,44 +4731,36 @@ def func_forward_kinematics( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): + # TODO: Non-differentiable function, should we make it differentiable? n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - func_forward_kinematics_entity( - i_e, - i_b, - links_state, - links_info, - joints_state, - joints_info, - dofs_state, - dofs_info, - entities_info, - rigid_global_info, - static_rigid_sim_config, - ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e in range(n_entities): - func_forward_kinematics_entity( - i_e, - i_b, - links_state, - links_info, - joints_state, - joints_info, - dofs_state, - dofs_info, - entities_info, - rigid_global_info, - static_rigid_sim_config, - ) + + for i_e_ in ( + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_entities) + ): + i_e = ( + rigid_global_info.awake_entities[i_e_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_e_ + ) + func_forward_kinematics_entity( + f, + i_e, + i_b, + links_state, + links_info, + joints_state, + joints_info, + dofs_state, + dofs_info, + entities_info, + rigid_global_info, + static_rigid_sim_config, + ) @ti.func def func_forward_velocity( + f, i_b, entities_info: array_class.EntitiesInfo, links_info: array_class.LinksInfo, @@ -4739,35 +4771,26 @@ def func_forward_velocity( static_rigid_sim_config: ti.template(), ): n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - func_forward_velocity_entity( - i_e=i_e, - i_b=i_b, - entities_info=entities_info, - links_info=links_info, - links_state=links_state, - joints_info=joints_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e in range(n_entities): - func_forward_velocity_entity( - i_e=i_e, - i_b=i_b, - entities_info=entities_info, - links_info=links_info, - links_state=links_state, - joints_info=joints_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - ) + for i_e_ in ( + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_entities) + ): + i_e = ( + rigid_global_info.awake_entities[i_e_, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_e_ + ) + func_forward_velocity_entity( + f, + i_e=i_e, + i_b=i_b, + entities_info=entities_info, + links_info=links_info, + links_state=links_state, + joints_info=joints_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) @gs.maybe_pure @@ -4806,6 +4829,7 @@ def kernel_forward_kinematics_entity( @ti.func def func_forward_kinematics_entity( + f, i_e, i_b, links_state: array_class.LinksState, @@ -4824,8 +4848,8 @@ def func_forward_kinematics_entity( pos = links_info.pos[I_l] quat = links_info.quat[I_l] if links_info.parent_idx[I_l] != -1: - parent_pos = links_state.pos[links_info.parent_idx[I_l], i_b] - parent_quat = links_state.quat[links_info.parent_idx[I_l], i_b] + parent_pos = links_state.pos[f, links_info.parent_idx[I_l], i_b] + parent_quat = links_state.quat[f, links_info.parent_idx[I_l], i_b] pos = parent_pos + gu.ti_transform_by_quat(pos, parent_quat) quat = gu.ti_transform_quat_by_quat(quat, parent_quat) @@ -4838,14 +4862,14 @@ def func_forward_kinematics_entity( # compute axis and anchor if joint_type == gs.JOINT_TYPE.FREE: - joints_state.xanchor[i_j, i_b] = ti.Vector( + joints_state.xanchor[f, i_j, i_b] = ti.Vector( [ - rigid_global_info.qpos[q_start, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], + rigid_global_info.qpos[f, q_start, i_b], + rigid_global_info.qpos[f, q_start + 1, i_b], + rigid_global_info.qpos[f, q_start + 2, i_b], ] ) - joints_state.xaxis[i_j, i_b] = ti.Vector([0.0, 0.0, 1.0]) + joints_state.xaxis[f, i_j, i_b] = ti.Vector([0.0, 0.0, 1.0]) elif joint_type == gs.JOINT_TYPE.FIXED: pass else: @@ -4855,70 +4879,74 @@ def func_forward_kinematics_entity( elif joint_type == gs.JOINT_TYPE.PRISMATIC: axis = dofs_info.motion_vel[I_d] - joints_state.xanchor[i_j, i_b] = gu.ti_transform_by_quat(joints_info.pos[I_j], quat) + pos - joints_state.xaxis[i_j, i_b] = gu.ti_transform_by_quat(axis, quat) + joints_state.xanchor[f, i_j, i_b] = gu.ti_transform_by_quat(joints_info.pos[I_j], quat) + pos + joints_state.xaxis[f, i_j, i_b] = gu.ti_transform_by_quat(axis, quat) if joint_type == gs.JOINT_TYPE.FREE: pos = ti.Vector( [ - rigid_global_info.qpos[q_start, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], + rigid_global_info.qpos[f, q_start, i_b], + rigid_global_info.qpos[f, q_start + 1, i_b], + rigid_global_info.qpos[f, q_start + 2, i_b], ], dt=gs.ti_float, ) quat = ti.Vector( [ - rigid_global_info.qpos[q_start + 3, i_b], - rigid_global_info.qpos[q_start + 4, i_b], - rigid_global_info.qpos[q_start + 5, i_b], - rigid_global_info.qpos[q_start + 6, i_b], + rigid_global_info.qpos[f, q_start + 3, i_b], + rigid_global_info.qpos[f, q_start + 4, i_b], + rigid_global_info.qpos[f, q_start + 5, i_b], + rigid_global_info.qpos[f, q_start + 6, i_b], ], dt=gs.ti_float, ) xyz = gu.ti_quat_to_xyz(quat) for i in ti.static(range(3)): - dofs_state.pos[dof_start + i, i_b] = pos[i] - dofs_state.pos[dof_start + 3 + i, i_b] = xyz[i] + dofs_state.pos[f, dof_start + i, i_b] = pos[i] + dofs_state.pos[f, dof_start + 3 + i, i_b] = xyz[i] elif joint_type == gs.JOINT_TYPE.FIXED: pass elif joint_type == gs.JOINT_TYPE.SPHERICAL: qloc = ti.Vector( [ - rigid_global_info.qpos[q_start, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], - rigid_global_info.qpos[q_start + 3, i_b], + rigid_global_info.qpos[f, q_start, i_b], + rigid_global_info.qpos[f, q_start + 1, i_b], + rigid_global_info.qpos[f, q_start + 2, i_b], + rigid_global_info.qpos[f, q_start + 3, i_b], ], dt=gs.ti_float, ) xyz = gu.ti_quat_to_xyz(qloc) for i in ti.static(range(3)): - dofs_state.pos[dof_start + i, i_b] = xyz[i] + dofs_state.pos[f, dof_start + i, i_b] = xyz[i] quat = gu.ti_transform_quat_by_quat(qloc, quat) - pos = joints_state.xanchor[i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat) + pos = joints_state.xanchor[f, i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat) elif joint_type == gs.JOINT_TYPE.REVOLUTE: axis = dofs_info.motion_ang[I_d] - dofs_state.pos[dof_start, i_b] = ( - rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] + dofs_state.pos[f, dof_start, i_b] = ( + rigid_global_info.qpos[f, q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] ) - qloc = gu.ti_rotvec_to_quat(axis * dofs_state.pos[dof_start, i_b]) + qloc = gu.ti_rotvec_to_quat(axis * dofs_state.pos[f, dof_start, i_b]) quat = gu.ti_transform_quat_by_quat(qloc, quat) - pos = joints_state.xanchor[i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat) + pos = joints_state.xanchor[f, i_j, i_b] - gu.ti_transform_by_quat(joints_info.pos[I_j], quat) else: # joint_type == gs.JOINT_TYPE.PRISMATIC: - dofs_state.pos[dof_start, i_b] = ( - rigid_global_info.qpos[q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] + dofs_state.pos[f, dof_start, i_b] = ( + rigid_global_info.qpos[f, q_start, i_b] - rigid_global_info.qpos0[q_start, i_b] ) - pos = pos + joints_state.xaxis[i_j, i_b] * dofs_state.pos[dof_start, i_b] + pos = pos + joints_state.xaxis[f, i_j, i_b] * dofs_state.pos[f, dof_start, i_b] # Skip link pose update for fixed root links to let users manually overwrite them - if not (links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]): - links_state.pos[i_l, i_b] = pos - links_state.quat[i_l, i_b] = quat + # if not (links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]): + links_state.pos[f, i_l, i_b] = pos + links_state.quat[f, i_l, i_b] = quat + # print(f"setting {f}, {i_l}, {i_b} pos: {links_state.pos[f, i_l, i_b]}, quat: {links_state.quat[f, i_l, i_b]}") + # else: + # print(f"skipping link pose update for fixed root link {f}, {i_l}, {i_b}") @ti.func def func_forward_velocity_entity( + f, i_e, i_b, entities_info: array_class.EntitiesInfo, @@ -4935,8 +4963,8 @@ def func_forward_velocity_entity( cvel_vel = ti.Vector.zero(gs.ti_float, 3) cvel_ang = ti.Vector.zero(gs.ti_float, 3) if links_info.parent_idx[I_l] != -1: - cvel_vel = links_state.cd_vel[links_info.parent_idx[I_l], i_b] - cvel_ang = links_state.cd_ang[links_info.parent_idx[I_l], i_b] + cvel_vel = links_state.cd_vel[f, links_info.parent_idx[I_l], i_b] + cvel_ang = links_state.cd_ang[f, links_info.parent_idx[I_l], i_b] for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j @@ -4947,52 +4975,54 @@ def func_forward_velocity_entity( if joint_type == gs.JOINT_TYPE.FREE: for i_3 in ti.static(range(3)): cvel_vel = ( - cvel_vel + dofs_state.cdof_vel[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + cvel_vel + + dofs_state.cdof_vel[f, dof_start + i_3, i_b] * dofs_state.vel[f, dof_start + i_3, i_b] ) cvel_ang = ( - cvel_ang + dofs_state.cdof_ang[dof_start + i_3, i_b] * dofs_state.vel[dof_start + i_3, i_b] + cvel_ang + + dofs_state.cdof_ang[f, dof_start + i_3, i_b] * dofs_state.vel[f, dof_start + i_3, i_b] ) for i_3 in ti.static(range(3)): ( - dofs_state.cdofd_ang[dof_start + i_3, i_b], - dofs_state.cdofd_vel[dof_start + i_3, i_b], + dofs_state.cdofd_ang[f, dof_start + i_3, i_b], + dofs_state.cdofd_vel[f, dof_start + i_3, i_b], ) = ti.Vector.zero(gs.ti_float, 3), ti.Vector.zero(gs.ti_float, 3) ( - dofs_state.cdofd_ang[dof_start + i_3 + 3, i_b], - dofs_state.cdofd_vel[dof_start + i_3 + 3, i_b], + dofs_state.cdofd_ang[f, dof_start + i_3 + 3, i_b], + dofs_state.cdofd_vel[f, dof_start + i_3 + 3, i_b], ) = gu.motion_cross_motion( cvel_ang, cvel_vel, - dofs_state.cdof_ang[dof_start + i_3 + 3, i_b], - dofs_state.cdof_vel[dof_start + i_3 + 3, i_b], + dofs_state.cdof_ang[f, dof_start + i_3 + 3, i_b], + dofs_state.cdof_vel[f, dof_start + i_3 + 3, i_b], ) for i_3 in ti.static(range(3)): cvel_vel = ( cvel_vel - + dofs_state.cdof_vel[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] + + dofs_state.cdof_vel[f, dof_start + i_3 + 3, i_b] * dofs_state.vel[f, dof_start + i_3 + 3, i_b] ) cvel_ang = ( cvel_ang - + dofs_state.cdof_ang[dof_start + i_3 + 3, i_b] * dofs_state.vel[dof_start + i_3 + 3, i_b] + + dofs_state.cdof_ang[f, dof_start + i_3 + 3, i_b] * dofs_state.vel[f, dof_start + i_3 + 3, i_b] ) else: for i_d in range(dof_start, joints_info.dof_end[I_j]): - dofs_state.cdofd_ang[i_d, i_b], dofs_state.cdofd_vel[i_d, i_b] = gu.motion_cross_motion( + dofs_state.cdofd_ang[f, i_d, i_b], dofs_state.cdofd_vel[f, i_d, i_b] = gu.motion_cross_motion( cvel_ang, cvel_vel, - dofs_state.cdof_ang[i_d, i_b], - dofs_state.cdof_vel[i_d, i_b], + dofs_state.cdof_ang[f, i_d, i_b], + dofs_state.cdof_vel[f, i_d, i_b], ) for i_d in range(dof_start, joints_info.dof_end[I_j]): - cvel_vel = cvel_vel + dofs_state.cdof_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - cvel_ang = cvel_ang + dofs_state.cdof_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] + cvel_vel = cvel_vel + dofs_state.cdof_vel[f, i_d, i_b] * dofs_state.vel[f, i_d, i_b] + cvel_ang = cvel_ang + dofs_state.cdof_ang[f, i_d, i_b] * dofs_state.vel[f, i_d, i_b] - links_state.cd_vel[i_l, i_b] = cvel_vel - links_state.cd_ang[i_l, i_b] = cvel_ang + links_state.cd_vel[f, i_l, i_b] = cvel_vel + links_state.cd_ang[f, i_l, i_b] = cvel_ang @gs.maybe_pure @@ -5023,6 +5053,7 @@ def kernel_update_geoms( @ti.func def func_update_geoms( + f, i_b, entities_info: array_class.EntitiesInfo, geoms_info: array_class.GeomsInfo, @@ -5035,37 +5066,31 @@ def func_update_geoms( NOTE: this only update geom pose, not its verts and else. """ n_geoms = geoms_info.pos.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_g in range(entities_info.geom_start[i_e], entities_info.geom_end[i_e]): - ( - geoms_state.pos[i_g, i_b], - geoms_state.quat[i_g, i_b], - ) = gu.ti_transform_pos_quat_by_trans_quat( - geoms_info.pos[i_g], - geoms_info.quat[i_g], - links_state.pos[geoms_info.link_idx[i_g], i_b], - links_state.quat[geoms_info.link_idx[i_g], i_b], - ) + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_0 in ( + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(n_geoms) + ): + i_e = rigid_global_info.awake_entities[i_0, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else 0 + + for i_1 in ( + range(entities_info.geom_start[i_e], entities_info.geom_end[i_e]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_g = i_1 if ti.static(static_rigid_sim_config.use_hibernation) else i_0 - geoms_state.verts_updated[i_g, i_b] = 0 - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_g in range(n_geoms): ( - geoms_state.pos[i_g, i_b], - geoms_state.quat[i_g, i_b], + geoms_state.pos[f, i_g, i_b], + geoms_state.quat[f, i_g, i_b], ) = gu.ti_transform_pos_quat_by_trans_quat( geoms_info.pos[i_g], geoms_info.quat[i_g], - links_state.pos[geoms_info.link_idx[i_g], i_b], - links_state.quat[geoms_info.link_idx[i_g], i_b], + links_state.pos[f, geoms_info.link_idx[i_g], i_b], + links_state.quat[f, geoms_info.link_idx[i_g], i_b], ) - geoms_state.verts_updated[i_g, i_b] = 0 - @gs.maybe_pure @ti.kernel @@ -5077,7 +5102,7 @@ def kernel_update_verts_for_geom( free_verts_state: array_class.FreeVertsState, fixed_verts_state: array_class.FixedVertsState, ): - _B = geoms_state.verts_updated.shape[1] + _B = geoms_state.verts_updated.shape[2] for i_b in range(_B): func_update_verts_for_geom(i_g, i_b, geoms_state, geoms_info, verts_info, free_verts_state, fixed_verts_state) @@ -5129,17 +5154,18 @@ def func_update_all_verts(self): @gs.maybe_pure @ti.kernel def kernel_update_geom_aabbs( + f: ti.i32, geoms_state: array_class.GeomsState, geoms_init_AABB: array_class.GeomsInitAABB, static_rigid_sim_config: ti.template(), static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): - n_geoms = geoms_state.pos.shape[0] - _B = geoms_state.pos.shape[1] + n_geoms = geoms_state.pos.shape[1] + _B = geoms_state.pos.shape[2] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) for i_g, i_b in ti.ndrange(n_geoms, _B): - g_pos = geoms_state.pos[i_g, i_b] - g_quat = geoms_state.quat[i_g, i_b] + g_pos = geoms_state.pos[f, i_g, i_b] + g_quat = geoms_state.quat[f, i_g, i_b] lower = gu.ti_vec3(ti.math.inf) upper = gu.ti_vec3(-ti.math.inf) @@ -5148,13 +5174,14 @@ def kernel_update_geom_aabbs( lower = ti.min(lower, corner_pos) upper = ti.max(upper, corner_pos) - geoms_state.aabb_min[i_g, i_b] = lower - geoms_state.aabb_max[i_g, i_b] = upper + geoms_state.aabb_min[f, i_g, i_b] = lower + geoms_state.aabb_max[f, i_g, i_b] = upper @gs.maybe_pure @ti.kernel def kernel_update_vgeoms( + f: ti.int32, vgeoms_info: array_class.VGeomsInfo, vgeoms_state: array_class.VGeomsState, links_state: array_class.LinksState, @@ -5165,14 +5192,14 @@ def kernel_update_vgeoms( Vgeoms are only for visualization purposes. """ n_vgeoms = vgeoms_info.link_idx.shape[0] - _B = links_state.pos.shape[1] + _B = links_state.pos.shape[2] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) for i_g, i_b in ti.ndrange(n_vgeoms, _B): vgeoms_state.pos[i_g, i_b], vgeoms_state.quat[i_g, i_b] = gu.ti_transform_pos_quat_by_trans_quat( vgeoms_info.pos[i_g], vgeoms_info.quat[i_g], - links_state.pos[vgeoms_info.link_idx[i_g], i_b], - links_state.quat[vgeoms_info.link_idx[i_g], i_b], + links_state.pos[f, vgeoms_info.link_idx[i_g], i_b], + links_state.quat[f, vgeoms_info.link_idx[i_g], i_b], ) @@ -5408,25 +5435,30 @@ def func_clear_external_force( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): - _B = links_state.pos.shape[1] - n_links = links_state.pos.shape[0] + _B = links_state.pos.shape[2] + n_links = links_state.pos.shape[1] + n_frames = links_state.pos.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - links_state.cfrc_applied_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - links_state.cfrc_applied_vel[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_l, i_b in ti.ndrange(n_links, _B): - links_state.cfrc_applied_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - links_state.cfrc_applied_vel[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_f, i_0, i_b in ( + ti.ndrange(n_frames, 1, _B) + if ti.static(static_rigid_sim_config.use_hibernation) + else ti.ndrange(n_frames, n_links, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_l = rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 + + links_state.cfrc_applied_ang[i_f, i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + links_state.cfrc_applied_vel[i_f, i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) @ti.func def func_torque_and_passive_force( + f: ti.i32, entities_state: array_class.EntitiesState, entities_info: array_class.EntitiesInfo, dofs_state: array_class.DofsState, @@ -5440,8 +5472,8 @@ def func_torque_and_passive_force( contact_island_state: array_class.ContactIslandState, ): n_entities = entities_info.n_links.shape[0] - _B = dofs_state.ctrl_mode.shape[1] - n_dofs = dofs_state.ctrl_mode.shape[0] + _B = dofs_state.ctrl_mode.shape[2] + n_dofs = dofs_state.ctrl_mode.shape[1] n_links = links_info.root_idx.shape[0] # compute force based on each dof's ctrl mode @@ -5460,19 +5492,19 @@ def func_torque_and_passive_force( for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d force = gs.ti_float(0.0) - if dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.FORCE: - force = dofs_state.ctrl_force[i_d, i_b] - elif dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.VELOCITY: - force = dofs_info.kv[I_d] * (dofs_state.ctrl_vel[i_d, i_b] - dofs_state.vel[i_d, i_b]) - elif dofs_state.ctrl_mode[i_d, i_b] == gs.CTRL_MODE.POSITION and not ( + if dofs_state.ctrl_mode[f, i_d, i_b] == gs.CTRL_MODE.FORCE: + force = dofs_state.ctrl_force[f, i_d, i_b] + elif dofs_state.ctrl_mode[f, i_d, i_b] == gs.CTRL_MODE.VELOCITY: + force = dofs_info.kv[I_d] * (dofs_state.ctrl_vel[f, i_d, i_b] - dofs_state.vel[f, i_d, i_b]) + elif dofs_state.ctrl_mode[f, i_d, i_b] == gs.CTRL_MODE.POSITION and not ( joint_type == gs.JOINT_TYPE.FREE and i_d >= links_info.dof_start[I_l] + 3 ): force = ( - dofs_info.kp[I_d] * (dofs_state.ctrl_pos[i_d, i_b] - dofs_state.pos[i_d, i_b]) - - dofs_info.kv[I_d] * dofs_state.vel[i_d, i_b] + dofs_info.kp[I_d] * (dofs_state.ctrl_pos[f, i_d, i_b] - dofs_state.pos[f, i_d, i_b]) + - dofs_info.kv[I_d] * dofs_state.vel[f, i_d, i_b] ) - dofs_state.qf_applied[i_d, i_b] = ti.math.clamp( + dofs_state.qf_applied[f, i_d, i_b] = ti.math.clamp( force, dofs_info.force_range[I_d][0], dofs_info.force_range[I_d][1], @@ -5483,24 +5515,24 @@ def func_torque_and_passive_force( dof_start = links_info.dof_start[I_l] if joint_type == gs.JOINT_TYPE.FREE and ( - dofs_state.ctrl_mode[dof_start + 3, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[dof_start + 4, i_b] == gs.CTRL_MODE.POSITION - or dofs_state.ctrl_mode[dof_start + 5, i_b] == gs.CTRL_MODE.POSITION + dofs_state.ctrl_mode[f, dof_start + 3, i_b] == gs.CTRL_MODE.POSITION + or dofs_state.ctrl_mode[f, dof_start + 4, i_b] == gs.CTRL_MODE.POSITION + or dofs_state.ctrl_mode[f, dof_start + 5, i_b] == gs.CTRL_MODE.POSITION ): xyz = ti.Vector( [ - dofs_state.pos[0 + 3 + dof_start, i_b], - dofs_state.pos[1 + 3 + dof_start, i_b], - dofs_state.pos[2 + 3 + dof_start, i_b], + dofs_state.pos[f, 0 + 3 + dof_start, i_b], + dofs_state.pos[f, 1 + 3 + dof_start, i_b], + dofs_state.pos[f, 2 + 3 + dof_start, i_b], ], dt=gs.ti_float, ) ctrl_xyz = ti.Vector( [ - dofs_state.ctrl_pos[0 + 3 + dof_start, i_b], - dofs_state.ctrl_pos[1 + 3 + dof_start, i_b], - dofs_state.ctrl_pos[2 + 3 + dof_start, i_b], + dofs_state.ctrl_pos[f, 0 + 3 + dof_start, i_b], + dofs_state.ctrl_pos[f, 1 + 3 + dof_start, i_b], + dofs_state.ctrl_pos[f, 2 + 3 + dof_start, i_b], ], dt=gs.ti_float, ) @@ -5514,72 +5546,53 @@ def func_torque_and_passive_force( for j in ti.static(range(3)): i_d = dof_start + 3 + j I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - force = dofs_info.kp[I_d] * rotvec[j] - dofs_info.kv[I_d] * dofs_state.vel[i_d, i_b] + force = dofs_info.kp[I_d] * rotvec[j] - dofs_info.kv[I_d] * dofs_state.vel[f, i_d, i_b] - dofs_state.qf_applied[i_d, i_b] = ti.math.clamp( + dofs_state.qf_applied[f, i_d, i_b] = ti.math.clamp( force, dofs_info.force_range[I_d][0], dofs_info.force_range[I_d][1] ) if ti.abs(force) > gs.EPS: wakeup = True - if ti.static(static_rigid_sim_config.use_hibernation) and entities_state.hibernated[i_e, i_b] and wakeup: - func_wakeup_entity_and_its_temp_island( - i_e, - i_b, - entities_state, - entities_info, - dofs_state, - links_state, - geoms_state, - rigid_global_info, - contact_island_state, - ) - - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]): - i_d = rigid_global_info.awake_dofs[i_d_, i_b] - I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - - dofs_state.qf_passive[i_d, i_b] = -dofs_info.damping[I_d] * dofs_state.vel[i_d, i_b] - - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - if links_info.n_dofs[I_l] == 0: - continue + if ti.static(static_rigid_sim_config.use_hibernation): + if entities_state.hibernated[i_e, i_b] and wakeup: + func_wakeup_entity_and_its_temp_island( + i_e, + i_b, + entities_state, + entities_info, + dofs_state, + links_state, + geoms_state, + rigid_global_info, + contact_island_state, + ) - i_j = links_info.joint_start[I_l] - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - joint_type = joints_info.type[I_j] - - if joint_type != gs.JOINT_TYPE.FREE and joint_type != gs.JOINT_TYPE.FIXED: - q_start = links_info.q_start[I_l] - dof_start = links_info.dof_start[I_l] - dof_end = links_info.dof_end[I_l] - - for j_d in range(dof_end - dof_start): - I_d = ( - [dof_start + j_d, i_b] - if ti.static(static_rigid_sim_config.batch_dofs_info) - else dof_start + j_d - ) - dofs_state.qf_passive[dof_start + j_d, i_b] += ( - -rigid_global_info.qpos[q_start + j_d, i_b] * dofs_info.stiffness[I_d] - ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_dofs, _B): + for i_1 in ( + range(rigid_global_info.n_awake_dofs[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_d = rigid_global_info.awake_dofs[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 I_d = [i_d, i_b] if ti.static(static_rigid_sim_config.batch_dofs_info) else i_d - dofs_state.qf_passive[i_d, i_b] = -dofs_info.damping[I_d] * dofs_state.vel[i_d, i_b] - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): + dofs_state.qf_passive[f, i_d, i_b] = -dofs_info.damping[I_d] * dofs_state.vel[f, i_d, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_links, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_l = rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + if links_info.n_dofs[I_l] == 0: continue @@ -5598,13 +5611,14 @@ def func_torque_and_passive_force( if ti.static(static_rigid_sim_config.batch_dofs_info) else dof_start + j_d ) - dofs_state.qf_passive[dof_start + j_d, i_b] += ( - -rigid_global_info.qpos[q_start + j_d, i_b] * dofs_info.stiffness[I_d] + dofs_state.qf_passive[f, dof_start + j_d, i_b] += ( + -rigid_global_info.qpos[f, q_start + j_d, i_b] * dofs_info.stiffness[I_d] ) @ti.func def func_update_acc( + f: ti.i32, update_cacc: ti.template(), dofs_state: array_class.DofsState, links_info: array_class.LinksInfo, @@ -5613,173 +5627,133 @@ def func_update_acc( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): - _B = dofs_state.ctrl_mode.shape[1] - n_links = links_info.root_idx.shape[0] + _B = dofs_state.ctrl_mode.shape[2] n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] - - if i_p == -1: - links_state.cdd_vel[i_l, i_b] = -rigid_global_info.gravity[i_b] * ( - 1 - entities_info.gravity_compensation[i_e] - ) - links_state.cdd_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - links_state.cacc_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - else: - links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_p, i_b] - links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_p, i_b] - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = links_state.cacc_lin[i_p, i_b] - links_state.cacc_ang[i_l, i_b] = links_state.cacc_ang[i_p, i_b] - - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - local_cdd_vel = dofs_state.cdofd_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - local_cdd_ang = dofs_state.cdofd_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_l, i_b] + local_cdd_vel - links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_l, i_b] + local_cdd_ang - if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = ( - links_state.cacc_lin[i_l, i_b] - + local_cdd_vel - + dofs_state.cdof_vel[i_d, i_b] * dofs_state.acc[i_d, i_b] - ) - links_state.cacc_ang[i_l, i_b] = ( - links_state.cacc_ang[i_l, i_b] - + local_cdd_ang - + dofs_state.cdof_ang[i_d, i_b] * dofs_state.acc[i_d, i_b] - ) - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e, i_b in ti.ndrange(n_entities, _B): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_entities, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) + for i_l in range(entities_info.link_start[i_e], entities_info.link_end[i_e]): I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_p = links_info.parent_idx[I_l] if i_p == -1: - links_state.cdd_vel[i_l, i_b] = -rigid_global_info.gravity[i_b] * ( + links_state.cdd_vel[f, i_l, i_b] = -rigid_global_info.gravity[i_b] * ( 1 - entities_info.gravity_compensation[i_e] ) - links_state.cdd_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + links_state.cdd_ang[f, i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) - links_state.cacc_ang[i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + links_state.cacc_lin[f, i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) + links_state.cacc_ang[f, i_l, i_b] = ti.Vector.zero(gs.ti_float, 3) else: - links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_p, i_b] - links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_p, i_b] + links_state.cdd_vel[f, i_l, i_b] = links_state.cdd_vel[f, i_p, i_b] + links_state.cdd_ang[f, i_l, i_b] = links_state.cdd_ang[f, i_p, i_b] if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = links_state.cacc_lin[i_p, i_b] - links_state.cacc_ang[i_l, i_b] = links_state.cacc_ang[i_p, i_b] + links_state.cacc_lin[f, i_l, i_b] = links_state.cacc_lin[f, i_p, i_b] + links_state.cacc_ang[f, i_l, i_b] = links_state.cacc_ang[f, i_p, i_b] for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): # cacc = cacc_parent + cdofdot * qvel + cdof * qacc - local_cdd_vel = dofs_state.cdofd_vel[i_d, i_b] * dofs_state.vel[i_d, i_b] - local_cdd_ang = dofs_state.cdofd_ang[i_d, i_b] * dofs_state.vel[i_d, i_b] - links_state.cdd_vel[i_l, i_b] = links_state.cdd_vel[i_l, i_b] + local_cdd_vel - links_state.cdd_ang[i_l, i_b] = links_state.cdd_ang[i_l, i_b] + local_cdd_ang + local_cdd_vel = dofs_state.cdofd_vel[f, i_d, i_b] * dofs_state.vel[f, i_d, i_b] + local_cdd_ang = dofs_state.cdofd_ang[f, i_d, i_b] * dofs_state.vel[f, i_d, i_b] + links_state.cdd_vel[f, i_l, i_b] = links_state.cdd_vel[f, i_l, i_b] + local_cdd_vel + links_state.cdd_ang[f, i_l, i_b] = links_state.cdd_ang[f, i_l, i_b] + local_cdd_ang if ti.static(update_cacc): - links_state.cacc_lin[i_l, i_b] = ( - links_state.cacc_lin[i_l, i_b] + links_state.cacc_lin[f, i_l, i_b] = ( + links_state.cacc_lin[f, i_l, i_b] + local_cdd_vel - + dofs_state.cdof_vel[i_d, i_b] * dofs_state.acc[i_d, i_b] + + dofs_state.cdof_vel[f, i_d, i_b] * dofs_state.acc[f, i_d, i_b] ) - links_state.cacc_ang[i_l, i_b] = ( - links_state.cacc_ang[i_l, i_b] + links_state.cacc_ang[f, i_l, i_b] = ( + links_state.cacc_ang[f, i_l, i_b] + local_cdd_ang - + dofs_state.cdof_ang[i_d, i_b] * dofs_state.acc[i_d, i_b] + + dofs_state.cdof_ang[f, i_d, i_b] * dofs_state.acc[f, i_d, i_b] ) @ti.func def func_update_force( + f: ti.i32, links_state: array_class.LinksState, links_info: array_class.LinksInfo, entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): - _B = links_state.pos.shape[1] + _B = links_state.pos.shape[2] n_links = links_info.root_idx.shape[0] n_entities = entities_info.n_links.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - - f1_ang, f1_vel = gu.inertial_mul( - links_state.cinr_pos[i_l, i_b], - links_state.cinr_inertial[i_l, i_b], - links_state.cinr_mass[i_l, i_b], - links_state.cdd_vel[i_l, i_b], - links_state.cdd_ang[i_l, i_b], - ) - f2_ang, f2_vel = gu.inertial_mul( - links_state.cinr_pos[i_l, i_b], - links_state.cinr_inertial[i_l, i_b], - links_state.cinr_mass[i_l, i_b], - links_state.cd_vel[i_l, i_b], - links_state.cd_ang[i_l, i_b], - ) - f2_ang, f2_vel = gu.motion_cross_force( - links_state.cd_ang[i_l, i_b], links_state.cd_vel[i_l, i_b], f2_ang, f2_vel - ) - - links_state.cfrc_vel[i_l, i_b] = f1_vel + f2_vel + links_state.cfrc_applied_vel[i_l, i_b] - links_state.cfrc_ang[i_l, i_b] = f1_ang + f2_ang + links_state.cfrc_applied_ang[i_l, i_b] + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_links, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_l = rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i in range(entities_info.n_links[i_e]): - i_l = entities_info.link_end[i_e] - 1 - i - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - i_p = links_info.parent_idx[I_l] - if i_p != -1: - links_state.cfrc_vel[i_p, i_b] = links_state.cfrc_vel[i_p, i_b] + links_state.cfrc_vel[i_l, i_b] - links_state.cfrc_ang[i_p, i_b] = links_state.cfrc_ang[i_p, i_b] + links_state.cfrc_ang[i_l, i_b] - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): f1_ang, f1_vel = gu.inertial_mul( - links_state.cinr_pos[i_l, i_b], - links_state.cinr_inertial[i_l, i_b], - links_state.cinr_mass[i_l, i_b], - links_state.cdd_vel[i_l, i_b], - links_state.cdd_ang[i_l, i_b], + links_state.cinr_pos[f, i_l, i_b], + links_state.cinr_inertial[f, i_l, i_b], + links_state.cinr_mass[f, i_l, i_b], + links_state.cdd_vel[f, i_l, i_b], + links_state.cdd_ang[f, i_l, i_b], ) f2_ang, f2_vel = gu.inertial_mul( - links_state.cinr_pos[i_l, i_b], - links_state.cinr_inertial[i_l, i_b], - links_state.cinr_mass[i_l, i_b], - links_state.cd_vel[i_l, i_b], - links_state.cd_ang[i_l, i_b], + links_state.cinr_pos[f, i_l, i_b], + links_state.cinr_inertial[f, i_l, i_b], + links_state.cinr_mass[f, i_l, i_b], + links_state.cd_vel[f, i_l, i_b], + links_state.cd_ang[f, i_l, i_b], ) f2_ang, f2_vel = gu.motion_cross_force( - links_state.cd_ang[i_l, i_b], links_state.cd_vel[i_l, i_b], f2_ang, f2_vel + links_state.cd_ang[f, i_l, i_b], links_state.cd_vel[f, i_l, i_b], f2_ang, f2_vel ) - links_state.cfrc_vel[i_l, i_b] = f1_vel + f2_vel + links_state.cfrc_applied_vel[i_l, i_b] - links_state.cfrc_ang[i_l, i_b] = f1_ang + f2_ang + links_state.cfrc_applied_ang[i_l, i_b] + links_state.cfrc_vel[f, i_l, i_b] = f1_vel + f2_vel + links_state.cfrc_applied_vel[f, i_l, i_b] + links_state.cfrc_ang[f, i_l, i_b] = f1_ang + f2_ang + links_state.cfrc_applied_ang[f, i_l, i_b] + + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_entities, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_e, i_b in ti.ndrange(n_entities, _B): for i in range(entities_info.n_links[i_e]): i_l = entities_info.link_end[i_e] - 1 - i I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l i_p = links_info.parent_idx[I_l] if i_p != -1: - links_state.cfrc_vel[i_p, i_b] = links_state.cfrc_vel[i_p, i_b] + links_state.cfrc_vel[i_l, i_b] - links_state.cfrc_ang[i_p, i_b] = links_state.cfrc_ang[i_p, i_b] + links_state.cfrc_ang[i_l, i_b] + links_state.cfrc_vel[f, i_p, i_b] = ( + links_state.cfrc_vel[f, i_p, i_b] + links_state.cfrc_vel[f, i_l, i_b] + ) + links_state.cfrc_ang[f, i_p, i_b] = ( + links_state.cfrc_ang[f, i_p, i_b] + links_state.cfrc_ang[f, i_l, i_b] + ) @ti.func @@ -5808,92 +5782,86 @@ def func_actuation(self): @ti.func def func_bias_force( + f: ti.i32, dofs_state: array_class.DofsState, links_state: array_class.LinksState, links_info: array_class.LinksInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): - _B = dofs_state.ctrl_mode.shape[1] + _B = dofs_state.ctrl_mode.shape[2] n_links = links_info.root_idx.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - dofs_state.qf_bias[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b].dot( - links_state.cfrc_ang[i_l, i_b] - ) + dofs_state.cdof_vel[i_d, i_b].dot(links_state.cfrc_vel[i_l, i_b]) - - dofs_state.force[i_d, i_b] = ( - dofs_state.qf_passive[i_d, i_b] - - dofs_state.qf_bias[i_d, i_b] - + dofs_state.qf_applied[i_d, i_b] - # + self.dofs_state.qf_actuator[i_d, i_b] - ) - - dofs_state.qf_smooth[i_d, i_b] = dofs_state.force[i_d, i_b] - - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_links, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_l = rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l for i_d in range(links_info.dof_start[I_l], links_info.dof_end[I_l]): - dofs_state.qf_bias[i_d, i_b] = dofs_state.cdof_ang[i_d, i_b].dot( - links_state.cfrc_ang[i_l, i_b] - ) + dofs_state.cdof_vel[i_d, i_b].dot(links_state.cfrc_vel[i_l, i_b]) - - dofs_state.force[i_d, i_b] = ( - dofs_state.qf_passive[i_d, i_b] - - dofs_state.qf_bias[i_d, i_b] - + dofs_state.qf_applied[i_d, i_b] + dofs_state.qf_bias[f, i_d, i_b] = dofs_state.cdof_ang[f, i_d, i_b].dot( + links_state.cfrc_ang[f, i_l, i_b] + ) + dofs_state.cdof_vel[f, i_d, i_b].dot(links_state.cfrc_vel[f, i_l, i_b]) + + dofs_state.force[f, i_d, i_b] = ( + dofs_state.qf_passive[f, i_d, i_b] + - dofs_state.qf_bias[f, i_d, i_b] + + dofs_state.qf_applied[f, i_d, i_b] # + self.dofs_state.qf_actuator[i_d, i_b] ) - dofs_state.qf_smooth[i_d, i_b] = dofs_state.force[i_d, i_b] + dofs_state.qf_smooth[f, i_d, i_b] = dofs_state.force[f, i_d, i_b] @ti.func def func_compute_qacc( + f: ti.i32, dofs_state: array_class.DofsState, entities_info: array_class.EntitiesInfo, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): - _B = dofs_state.ctrl_mode.shape[1] + _B = dofs_state.ctrl_mode.shape[2] n_entities = entities_info.n_links.shape[0] func_solve_mass( + f=f, vec=dofs_state.force, out=dofs_state.acc_smooth, + dofs_state=dofs_state, entities_info=entities_info, rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, ) - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_b in range(_B): - for i_e_ in range(rigid_global_info.n_awake_entities[i_b]): - i_e = rigid_global_info.awake_entities[i_e_, i_b] - for i_d1_ in range(entities_info.n_dofs[i_e]): - i_d1 = entities_info.dof_start[i_e] + i_d1_ - dofs_state.acc[i_d1, i_b] = dofs_state.acc_smooth[i_d1, i_b] - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) - for i_e, i_b in ti.ndrange(n_entities, _B): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_entities, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_entities[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_e = ( + rigid_global_info.awake_entities[i_1, i_b] + if ti.static(static_rigid_sim_config.use_hibernation) + else i_0 + ) + for i_d1_ in range(entities_info.n_dofs[i_e]): i_d1 = entities_info.dof_start[i_e] + i_d1_ - dofs_state.acc[i_d1, i_b] = dofs_state.acc_smooth[i_d1, i_b] + dofs_state.acc[f, i_d1, i_b] = dofs_state.acc_smooth[f, i_d1, i_b] @ti.func def func_integrate( + f: ti.i32, dofs_state: array_class.DofsState, links_info: array_class.LinksInfo, joints_info: array_class.JointsInfo, @@ -5901,115 +5869,35 @@ def func_integrate( static_rigid_sim_config: ti.template(), ): - _B = dofs_state.ctrl_mode.shape[1] - n_dofs = dofs_state.ctrl_mode.shape[0] + _B = dofs_state.ctrl_mode.shape[2] + n_dofs = dofs_state.ctrl_mode.shape[1] n_links = links_info.root_idx.shape[0] - if ti.static(static_rigid_sim_config.use_hibernation): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]): - i_d = rigid_global_info.awake_dofs[i_d_, i_b] - dofs_state.vel[i_d, i_b] = ( - dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * static_rigid_sim_config.substep_dt - ) - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b in range(_B): - for i_l_ in range(rigid_global_info.n_awake_links[i_b]): - i_l = rigid_global_info.awake_links[i_l_, i_b] - I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l - - for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): - I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j - dof_start = joints_info.dof_start[I_j] - q_start = joints_info.q_start[I_j] - q_end = joints_info.q_end[I_j] - - joint_type = joints_info.type[I_j] - - if joint_type == gs.JOINT_TYPE.FREE: - rot = ti.Vector( - [ - rigid_global_info.qpos[q_start + 3, i_b], - rigid_global_info.qpos[q_start + 4, i_b], - rigid_global_info.qpos[q_start + 5, i_b], - rigid_global_info.qpos[q_start + 6, i_b], - ] - ) - ang = ( - ti.Vector( - [ - dofs_state.vel[dof_start + 3, i_b], - dofs_state.vel[dof_start + 4, i_b], - dofs_state.vel[dof_start + 5, i_b], - ] - ) - * static_rigid_sim_config.substep_dt - ) - qrot = gu.ti_rotvec_to_quat(ang) - rot = gu.ti_transform_quat_by_quat(qrot, rot) - pos = ti.Vector( - [ - rigid_global_info.qpos[q_start, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], - ] - ) - vel = ti.Vector( - [ - dofs_state.vel[dof_start, i_b], - dofs_state.vel[dof_start + 1, i_b], - dofs_state.vel[dof_start + 2, i_b], - ] - ) - pos = pos + vel * static_rigid_sim_config.substep_dt - for j in ti.static(range(3)): - rigid_global_info.qpos[q_start + j, i_b] = pos[j] - for j in ti.static(range(4)): - rigid_global_info.qpos[q_start + j + 3, i_b] = rot[j] - elif joint_type == gs.JOINT_TYPE.FIXED: - pass - elif joint_type == gs.JOINT_TYPE.SPHERICAL: - rot = ti.Vector( - [ - rigid_global_info.qpos[q_start + 0, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], - rigid_global_info.qpos[q_start + 3, i_b], - ] - ) - ang = ( - ti.Vector( - [ - dofs_state.vel[dof_start + 3, i_b], - dofs_state.vel[dof_start + 4, i_b], - dofs_state.vel[dof_start + 5, i_b], - ] - ) - * static_rigid_sim_config.substep_dt - ) - qrot = gu.ti_rotvec_to_quat(ang) - rot = gu.ti_transform_quat_by_quat(qrot, rot) - for j in ti.static(range(4)): - rigid_global_info.qpos[q_start + j, i_b] = rot[j] - - else: - for j in range(q_end - q_start): - rigid_global_info.qpos[q_start + j, i_b] = ( - rigid_global_info.qpos[q_start + j, i_b] - + dofs_state.vel[dof_start + j, i_b] * static_rigid_sim_config.substep_dt - ) - - else: - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in ti.ndrange(n_dofs, _B): - dofs_state.vel[i_d, i_b] = ( - dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * static_rigid_sim_config.substep_dt + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_dofs, _B): + for i_1 in ( + range(rigid_global_info.n_awake_dofs[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_d = rigid_global_info.awake_dofs[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 + + dofs_state.vel[f + 1, i_d, i_b] = ( + dofs_state.vel[f, i_d, i_b] + dofs_state.acc[f, i_d, i_b] * static_rigid_sim_config.substep_dt ) - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_l, i_b in ti.ndrange(n_links, _B): + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_0, i_b in ( + ti.ndrange(1, _B) if ti.static(static_rigid_sim_config.use_hibernation) else ti.ndrange(n_links, _B) + ): + for i_1 in ( + range(rigid_global_info.n_awake_links[i_b]) + if ti.static(static_rigid_sim_config.use_hibernation) + else range(1) + ): + i_l = rigid_global_info.awake_links[i_1, i_b] if ti.static(static_rigid_sim_config.use_hibernation) else i_0 I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + if links_info.n_dofs[I_l] == 0: continue @@ -6024,52 +5912,220 @@ def func_integrate( if joint_type == gs.JOINT_TYPE.FREE: pos = ti.Vector( [ - rigid_global_info.qpos[q_start, i_b], - rigid_global_info.qpos[q_start + 1, i_b], - rigid_global_info.qpos[q_start + 2, i_b], + rigid_global_info.qpos[f, q_start, i_b], + rigid_global_info.qpos[f, q_start + 1, i_b], + rigid_global_info.qpos[f, q_start + 2, i_b], ] ) vel = ti.Vector( [ - dofs_state.vel[dof_start, i_b], - dofs_state.vel[dof_start + 1, i_b], - dofs_state.vel[dof_start + 2, i_b], + dofs_state.vel[f + 1, dof_start, i_b], + dofs_state.vel[f + 1, dof_start + 1, i_b], + dofs_state.vel[f + 1, dof_start + 2, i_b], ] ) - pos = pos + vel * static_rigid_sim_config.substep_dt + pos_ = pos + vel * static_rigid_sim_config.substep_dt for j in ti.static(range(3)): - rigid_global_info.qpos[q_start + j, i_b] = pos[j] + rigid_global_info.qpos[f + 1, q_start + j, i_b] = pos_[j] if joint_type == gs.JOINT_TYPE.SPHERICAL or joint_type == gs.JOINT_TYPE.FREE: rot_offset = 3 if joint_type == gs.JOINT_TYPE.FREE else 0 rot = ti.Vector( [ - rigid_global_info.qpos[q_start + rot_offset + 0, i_b], - rigid_global_info.qpos[q_start + rot_offset + 1, i_b], - rigid_global_info.qpos[q_start + rot_offset + 2, i_b], - rigid_global_info.qpos[q_start + rot_offset + 3, i_b], + rigid_global_info.qpos[f, q_start + rot_offset + 0, i_b], + rigid_global_info.qpos[f, q_start + rot_offset + 1, i_b], + rigid_global_info.qpos[f, q_start + rot_offset + 2, i_b], + rigid_global_info.qpos[f, q_start + rot_offset + 3, i_b], ] ) ang = ( ti.Vector( [ - dofs_state.vel[dof_start + rot_offset + 0, i_b], - dofs_state.vel[dof_start + rot_offset + 1, i_b], - dofs_state.vel[dof_start + rot_offset + 2, i_b], + dofs_state.vel[f + 1, dof_start + rot_offset + 0, i_b], + dofs_state.vel[f + 1, dof_start + rot_offset + 1, i_b], + dofs_state.vel[f + 1, dof_start + rot_offset + 2, i_b], ] ) * static_rigid_sim_config.substep_dt ) qrot = gu.ti_rotvec_to_quat(ang) - rot = gu.ti_transform_quat_by_quat(qrot, rot) + rot_ = gu.ti_transform_quat_by_quat(qrot, rot) for j in ti.static(range(4)): - rigid_global_info.qpos[q_start + j + rot_offset, i_b] = rot[j] + rigid_global_info.qpos[f + 1, q_start + j + rot_offset, i_b] = rot_[j] else: for j in range(q_end - q_start): - rigid_global_info.qpos[q_start + j, i_b] = ( - rigid_global_info.qpos[q_start + j, i_b] - + dofs_state.vel[dof_start + j, i_b] * static_rigid_sim_config.substep_dt + rigid_global_info.qpos[f + 1, q_start + j, i_b] = ( + rigid_global_info.qpos[f, q_start + j, i_b] + + dofs_state.vel[f + 1, dof_start + j, i_b] * static_rigid_sim_config.substep_dt ) + # FIXME: hibernation implementation is different from the non-hibernation one, check if the logic is same + + # if ti.static(static_rigid_sim_config.use_hibernation): + # ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + # for i_b in range(_B): + # for i_d_ in range(rigid_global_info.n_awake_dofs[i_b]): + # i_d = rigid_global_info.awake_dofs[i_d_, i_b] + # dofs_state.vel[i_d, i_b] = ( + # dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * static_rigid_sim_config.substep_dt + # ) + + # ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + # for i_b in range(_B): + # for i_l_ in range(rigid_global_info.n_awake_links[i_b]): + # i_l = rigid_global_info.awake_links[i_l_, i_b] + # I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + + # for i_j in range(links_info.joint_start[I_l], links_info.joint_end[I_l]): + # I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + # dof_start = joints_info.dof_start[I_j] + # q_start = joints_info.q_start[I_j] + # q_end = joints_info.q_end[I_j] + + # joint_type = joints_info.type[I_j] + + # if joint_type == gs.JOINT_TYPE.FREE: + # rot = ti.Vector( + # [ + # rigid_global_info.qpos[q_start + 3, i_b], + # rigid_global_info.qpos[q_start + 4, i_b], + # rigid_global_info.qpos[q_start + 5, i_b], + # rigid_global_info.qpos[q_start + 6, i_b], + # ] + # ) + # ang = ( + # ti.Vector( + # [ + # dofs_state.vel[dof_start + 3, i_b], + # dofs_state.vel[dof_start + 4, i_b], + # dofs_state.vel[dof_start + 5, i_b], + # ] + # ) + # * static_rigid_sim_config.substep_dt + # ) + # qrot = gu.ti_rotvec_to_quat(ang) + # rot = gu.ti_transform_quat_by_quat(qrot, rot) + # pos = ti.Vector( + # [ + # rigid_global_info.qpos[q_start, i_b], + # rigid_global_info.qpos[q_start + 1, i_b], + # rigid_global_info.qpos[q_start + 2, i_b], + # ] + # ) + # vel = ti.Vector( + # [ + # dofs_state.vel[dof_start, i_b], + # dofs_state.vel[dof_start + 1, i_b], + # dofs_state.vel[dof_start + 2, i_b], + # ] + # ) + # pos = pos + vel * static_rigid_sim_config.substep_dt + # for j in ti.static(range(3)): + # rigid_global_info.qpos[q_start + j, i_b] = pos[j] + # for j in ti.static(range(4)): + # rigid_global_info.qpos[q_start + j + 3, i_b] = rot[j] + # elif joint_type == gs.JOINT_TYPE.FIXED: + # pass + # elif joint_type == gs.JOINT_TYPE.SPHERICAL: + # rot = ti.Vector( + # [ + # rigid_global_info.qpos[q_start + 0, i_b], + # rigid_global_info.qpos[q_start + 1, i_b], + # rigid_global_info.qpos[q_start + 2, i_b], + # rigid_global_info.qpos[q_start + 3, i_b], + # ] + # ) + # ang = ( + # ti.Vector( + # [ + # dofs_state.vel[dof_start + 3, i_b], + # dofs_state.vel[dof_start + 4, i_b], + # dofs_state.vel[dof_start + 5, i_b], + # ] + # ) + # * static_rigid_sim_config.substep_dt + # ) + # qrot = gu.ti_rotvec_to_quat(ang) + # rot = gu.ti_transform_quat_by_quat(qrot, rot) + # for j in ti.static(range(4)): + # rigid_global_info.qpos[q_start + j, i_b] = rot[j] + + # else: + # for j in range(q_end - q_start): + # rigid_global_info.qpos[q_start + j, i_b] = ( + # rigid_global_info.qpos[q_start + j, i_b] + # + dofs_state.vel[dof_start + j, i_b] * static_rigid_sim_config.substep_dt + # ) + + # else: + # ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + # for i_d, i_b in ti.ndrange(n_dofs, _B): + # dofs_state.vel[i_d, i_b] = ( + # dofs_state.vel[i_d, i_b] + dofs_state.acc[i_d, i_b] * static_rigid_sim_config.substep_dt + # ) + + # ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + # for i_l, i_b in ti.ndrange(n_links, _B): + # I_l = [i_l, i_b] if ti.static(static_rigid_sim_config.batch_links_info) else i_l + # if links_info.n_dofs[I_l] == 0: + # continue + + # dof_start = links_info.dof_start[I_l] + # q_start = links_info.q_start[I_l] + # q_end = links_info.q_end[I_l] + + # i_j = links_info.joint_start[I_l] + # I_j = [i_j, i_b] if ti.static(static_rigid_sim_config.batch_joints_info) else i_j + # joint_type = joints_info.type[I_j] + + # if joint_type == gs.JOINT_TYPE.FREE: + # pos = ti.Vector( + # [ + # rigid_global_info.qpos[q_start, i_b], + # rigid_global_info.qpos[q_start + 1, i_b], + # rigid_global_info.qpos[q_start + 2, i_b], + # ] + # ) + # vel = ti.Vector( + # [ + # dofs_state.vel[dof_start, i_b], + # dofs_state.vel[dof_start + 1, i_b], + # dofs_state.vel[dof_start + 2, i_b], + # ] + # ) + # pos = pos + vel * static_rigid_sim_config.substep_dt + # for j in ti.static(range(3)): + # rigid_global_info.qpos[q_start + j, i_b] = pos[j] + # if joint_type == gs.JOINT_TYPE.SPHERICAL or joint_type == gs.JOINT_TYPE.FREE: + # rot_offset = 3 if joint_type == gs.JOINT_TYPE.FREE else 0 + # rot = ti.Vector( + # [ + # rigid_global_info.qpos[q_start + rot_offset + 0, i_b], + # rigid_global_info.qpos[q_start + rot_offset + 1, i_b], + # rigid_global_info.qpos[q_start + rot_offset + 2, i_b], + # rigid_global_info.qpos[q_start + rot_offset + 3, i_b], + # ] + # ) + # ang = ( + # ti.Vector( + # [ + # dofs_state.vel[dof_start + rot_offset + 0, i_b], + # dofs_state.vel[dof_start + rot_offset + 1, i_b], + # dofs_state.vel[dof_start + rot_offset + 2, i_b], + # ] + # ) + # * static_rigid_sim_config.substep_dt + # ) + # qrot = gu.ti_rotvec_to_quat(ang) + # rot = gu.ti_transform_quat_by_quat(qrot, rot) + # for j in ti.static(range(4)): + # rigid_global_info.qpos[q_start + j + rot_offset, i_b] = rot[j] + # else: + # for j in range(q_end - q_start): + # rigid_global_info.qpos[q_start + j, i_b] = ( + # rigid_global_info.qpos[q_start + j, i_b] + # + dofs_state.vel[dof_start + j, i_b] * static_rigid_sim_config.substep_dt + # ) + @ti.func def func_integrate_dq_entity( @@ -6154,19 +6210,20 @@ def func_integrate_dq_entity( @gs.maybe_pure @ti.kernel def kernel_update_geoms_render_T( + f: ti.int32, geoms_render_T: ti.types.ndarray(), geoms_state: array_class.GeomsState, rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): - n_geoms = geoms_state.pos.shape[0] - _B = geoms_state.pos.shape[1] + n_geoms = geoms_state.pos.shape[1] + _B = geoms_state.pos.shape[2] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_g, i_b in ti.ndrange(n_geoms, _B): geom_T = gu.ti_trans_quat_to_T( - geoms_state.pos[i_g, i_b] + rigid_global_info.envs_offset[i_b], - geoms_state.quat[i_g, i_b], + geoms_state.pos[f, i_g, i_b] + rigid_global_info.envs_offset[i_b], + geoms_state.quat[f, i_g, i_b], ) for i, j in ti.static(ti.ndrange(4, 4)): geoms_render_T[i_g, i_b, i, j] = ti.cast(geom_T[i, j], ti.float32) @@ -6184,7 +6241,7 @@ def kernel_update_vgeoms_render_T( static_rigid_sim_cache_key: array_class.StaticRigidSimCacheKey, ): n_vgeoms = vgeoms_info.link_idx.shape[0] - _B = links_state.pos.shape[1] + _B = links_state.pos.shape[2] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_g, i_b in ti.ndrange(n_vgeoms, _B): geom_T = gu.ti_trans_quat_to_T( @@ -6198,6 +6255,7 @@ def kernel_update_vgeoms_render_T( @gs.maybe_pure @ti.kernel def kernel_get_state( + f: ti.i32, qpos: ti.types.ndarray(), vel: ti.types.ndarray(), links_pos: ti.types.ndarray(), @@ -6221,19 +6279,19 @@ def kernel_get_state( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_q, i_b in ti.ndrange(n_qs, _B): - qpos[i_b, i_q] = rigid_global_info.qpos[i_q, i_b] + qpos[i_b, i_q] = rigid_global_info.qpos[f, i_q, i_b] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_d, i_b in ti.ndrange(n_dofs, _B): - vel[i_b, i_d] = dofs_state.vel[i_d, i_b] + vel[i_b, i_d] = dofs_state.vel[f, i_d, i_b] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_l, i_b in ti.ndrange(n_links, _B): for i in ti.static(range(3)): - links_pos[i_b, i_l, i] = links_state.pos[i_l, i_b][i] + links_pos[i_b, i_l, i] = links_state.pos[f, i_l, i_b][i] i_pos_shift[i_b, i_l, i] = links_state.i_pos_shift[i_l, i_b][i] for i in ti.static(range(4)): - links_quat[i_b, i_l, i] = links_state.quat[i_l, i_b][i] + links_quat[i_b, i_l, i] = links_state.quat[f, i_l, i_b][i] mass_shift[i_b, i_l] = links_state.mass_shift[i_l, i_b] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) @@ -6244,6 +6302,7 @@ def kernel_get_state( @gs.maybe_pure @ti.kernel def kernel_set_state( + f: ti.int32, qpos: ti.types.ndarray(), dofs_vel: ti.types.ndarray(), links_pos: ti.types.ndarray(), @@ -6266,19 +6325,19 @@ def kernel_set_state( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_q, i_b_ in ti.ndrange(n_qs, envs_idx.shape[0]): - rigid_global_info.qpos[i_q, envs_idx[i_b_]] = qpos[envs_idx[i_b_], i_q] + rigid_global_info.qpos[f, i_q, envs_idx[i_b_]] = qpos[envs_idx[i_b_], i_q] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_d, i_b_ in ti.ndrange(n_dofs, envs_idx.shape[0]): - dofs_state.vel[i_d, envs_idx[i_b_]] = dofs_vel[envs_idx[i_b_], i_d] + dofs_state.vel[f, i_d, envs_idx[i_b_]] = dofs_vel[envs_idx[i_b_], i_d] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_l, i_b_ in ti.ndrange(n_links, envs_idx.shape[0]): for i in ti.static(range(3)): - links_state.pos[i_l, envs_idx[i_b_]][i] = links_pos[envs_idx[i_b_], i_l, i] + links_state.pos[f, i_l, envs_idx[i_b_]][i] = links_pos[envs_idx[i_b_], i_l, i] links_state.i_pos_shift[i_l, envs_idx[i_b_]][i] = i_pos_shift[envs_idx[i_b_], i_l, i] for i in ti.static(range(4)): - links_state.quat[i_l, envs_idx[i_b_]][i] = links_quat[envs_idx[i_b_], i_l, i] + links_state.quat[f, i_l, envs_idx[i_b_]][i] = links_quat[envs_idx[i_b_], i_l, i] links_state.mass_shift[i_l, envs_idx[i_b_]] = mass_shift[envs_idx[i_b_], i_l] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) @@ -6289,6 +6348,7 @@ def kernel_set_state( @gs.maybe_pure @ti.kernel def kernel_set_links_pos( + f: ti.i32, relative: ti.i32, pos: ti.types.ndarray(), links_idx: ti.types.ndarray(), @@ -6307,24 +6367,25 @@ def kernel_set_links_pos( if links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]: for i in ti.static(range(3)): - links_state.pos[i_l, i_b][i] = pos[i_b_, i_l_, i] + links_state.pos[f, i_l, i_b][i] = pos[i_b_, i_l_, i] if relative: for i in ti.static(range(3)): - links_state.pos[i_l, i_b][i] = links_state.pos[i_l, i_b][i] + links_info.pos[I_l][i] + links_state.pos[f, i_l, i_b][i] = links_state.pos[f, i_l, i_b][i] + links_info.pos[I_l][i] else: q_start = links_info.q_start[I_l] for i in ti.static(range(3)): - rigid_global_info.qpos[q_start + i, i_b] = pos[i_b_, i_l_, i] + rigid_global_info.qpos[f, q_start + i, i_b] = pos[i_b_, i_l_, i] if relative: for i in ti.static(range(3)): - rigid_global_info.qpos[q_start + i, i_b] = ( - rigid_global_info.qpos[q_start + i, i_b] + rigid_global_info.qpos0[q_start + i, i_b] + rigid_global_info.qpos[f, q_start + i, i_b] = ( + rigid_global_info.qpos[f, q_start + i, i_b] + rigid_global_info.qpos0[q_start + i, i_b] ) @gs.maybe_pure @ti.kernel def kernel_set_links_quat( + f: ti.i32, relative: ti.i32, quat: ti.types.ndarray(), links_idx: ti.types.ndarray(), @@ -6352,7 +6413,7 @@ def kernel_set_links_quat( dt=gs.ti_float, ) if links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]: - links_state.quat[i_l, i_b] = gu.ti_transform_quat_by_quat(links_info.quat[I_l], quat_) + links_state.quat[f, i_l, i_b] = gu.ti_transform_quat_by_quat(links_info.quat[I_l], quat_) else: q_start = links_info.q_start[I_l] quat0 = ti.Vector( @@ -6366,15 +6427,15 @@ def kernel_set_links_quat( ) quat_ = gu.ti_transform_quat_by_quat(quat0, quat_) for i in ti.static(range(4)): - rigid_global_info.qpos[q_start + i + 3, i_b] = quat_[i] + rigid_global_info.qpos[f, q_start + i + 3, i_b] = quat_[i] else: if links_info.parent_idx[I_l] == -1 and links_info.is_fixed[I_l]: for i in ti.static(range(4)): - links_state.quat[i_l, i_b][i] = quat[i_b_, i_l_, i] + links_state.quat[f, i_l, i_b][i] = quat[i_b_, i_l_, i] else: q_start = links_info.q_start[I_l] for i in ti.static(range(4)): - rigid_global_info.qpos[q_start + i + 3, i_b] = quat[i_b_, i_l_, i] + rigid_global_info.qpos[f, q_start + i + 3, i_b] = quat[i_b_, i_l_, i] @gs.maybe_pure @@ -6441,6 +6502,7 @@ def kernel_set_geoms_friction_ratio( @gs.maybe_pure @ti.kernel def kernel_set_qpos( + f: ti.i32, qpos: ti.types.ndarray(), qs_idx: ti.types.ndarray(), envs_idx: ti.types.ndarray(), @@ -6449,7 +6511,7 @@ def kernel_set_qpos( ): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_q_, i_b_ in ti.ndrange(qs_idx.shape[0], envs_idx.shape[0]): - rigid_global_info.qpos[qs_idx[i_q_], envs_idx[i_b_]] = qpos[i_b_, i_q_] + rigid_global_info.qpos[f, qs_idx[i_q_], envs_idx[i_b_]] = qpos[i_b_, i_q_] @gs.maybe_pure @@ -6670,6 +6732,7 @@ def kernel_set_dofs_limit( @gs.maybe_pure @ti.kernel def kernel_set_dofs_velocity( + f: ti.int32, velocity: ti.types.ndarray(), dofs_idx: ti.types.ndarray(), envs_idx: ti.types.ndarray(), @@ -6678,12 +6741,13 @@ def kernel_set_dofs_velocity( ): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): - dofs_state.vel[dofs_idx[i_d_], envs_idx[i_b_]] = velocity[i_b_, i_d_] + dofs_state.vel[f, dofs_idx[i_d_], envs_idx[i_b_]] = velocity[i_b_, i_d_] @gs.maybe_pure @ti.kernel def kernel_set_dofs_zero_velocity( + f: ti.int32, dofs_idx: ti.types.ndarray(), envs_idx: ti.types.ndarray(), dofs_state: array_class.DofsState, @@ -6691,12 +6755,13 @@ def kernel_set_dofs_zero_velocity( ): ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): - dofs_state.vel[dofs_idx[i_d_], envs_idx[i_b_]] = 0.0 + dofs_state.vel[f, dofs_idx[i_d_], envs_idx[i_b_]] = 0.0 @gs.maybe_pure @ti.kernel def kernel_set_dofs_position( + i_f: ti.i32, position: ti.types.ndarray(), dofs_idx: ti.types.ndarray(), envs_idx: ti.types.ndarray(), @@ -6711,7 +6776,7 @@ def kernel_set_dofs_position( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) for i_d_, i_b_ in ti.ndrange(dofs_idx.shape[0], envs_idx.shape[0]): - dofs_state.pos[dofs_idx[i_d_], envs_idx[i_b_]] = position[i_b_, i_d_] + dofs_state.pos[i_f, dofs_idx[i_d_], envs_idx[i_b_]] = position[i_b_, i_d_] # Note that qpos must be updated, as dofs_state.pos is not used for actual IK. # TODO: Make this more efficient by only taking care of releavant qs/dofs. @@ -6735,37 +6800,39 @@ def kernel_set_dofs_position( elif joint_type == gs.JOINT_TYPE.FREE: xyz = ti.Vector( [ - dofs_state.pos[0 + 3 + dof_start, i_b], - dofs_state.pos[1 + 3 + dof_start, i_b], - dofs_state.pos[2 + 3 + dof_start, i_b], + dofs_state.pos[i_f, 0 + 3 + dof_start, i_b], + dofs_state.pos[i_f, 1 + 3 + dof_start, i_b], + dofs_state.pos[i_f, 2 + 3 + dof_start, i_b], ], dt=gs.ti_float, ) quat = gu.ti_xyz_to_quat(xyz) for i in ti.static(range(3)): - rigid_global_info.qpos[i + q_start, i_b] = dofs_state.pos[i + dof_start, i_b] + rigid_global_info.qpos[i_f, i + q_start, i_b] = dofs_state.pos[i_f, i + dof_start, i_b] for i in ti.static(range(4)): - rigid_global_info.qpos[i + 3 + q_start, i_b] = quat[i] + rigid_global_info.qpos[i_f, i + 3 + q_start, i_b] = quat[i] elif joint_type == gs.JOINT_TYPE.SPHERICAL: xyz = ti.Vector( [ - dofs_state.pos[0 + dof_start, i_b], - dofs_state.pos[1 + dof_start, i_b], - dofs_state.pos[2 + dof_start, i_b], + dofs_state.pos[i_f, 0 + dof_start, i_b], + dofs_state.pos[i_f, 1 + dof_start, i_b], + dofs_state.pos[i_f, 2 + dof_start, i_b], ], dt=gs.ti_float, ) quat = gu.ti_xyz_to_quat(xyz) for i_q_ in ti.static(range(4)): i_q = q_start + i_q_ - rigid_global_info.qpos[i_q, i_b] = quat[i_q_] + rigid_global_info.qpos[i_f, i_q, i_b] = quat[i_q_] else: # (gs.JOINT_TYPE.REVOLUTE, gs.JOINT_TYPE.PRISMATIC) for i_d_ in range(links_info.dof_end[I_l] - dof_start): i_q = q_start + i_d_ i_d = dof_start + i_d_ - rigid_global_info.qpos[i_q, i_b] = rigid_global_info.qpos0[i_q, i_b] + dofs_state.pos[i_d, i_b] + rigid_global_info.qpos[i_f, i_q, i_b] = ( + rigid_global_info.qpos0[i_q, i_b] + dofs_state.pos[i_f, i_d, i_b] + ) @gs.maybe_pure @@ -6993,3 +7060,67 @@ def kernel_set_geoms_friction( ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL) for i_g_ in ti.ndrange(geoms_idx.shape[0]): geoms_info.friction[geoms_idx[i_g_]] = friction[i_g_] + + +@gs.maybe_pure +@ti.kernel +def kernel_copy_frame( + source: ti.i32, + target: ti.i32, + links_state: array_class.LinksState, + joints_state: array_class.JointsState, + dofs_state: array_class.DofsState, + geoms_state: array_class.GeomsState, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + """ + Copy state needed to proceed to the next frame from source to target frame. + """ + # rigid_global_info + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_q, i_b in ti.ndrange(rigid_global_info.qpos.shape[1], rigid_global_info.qpos.shape[2]): + rigid_global_info.qpos[target, i_q, i_b] = rigid_global_info.qpos[source, i_q, i_b] + + # links_state + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_l, i_b in ti.ndrange(links_state.pos.shape[1], links_state.pos.shape[2]): + links_state.pos[target, i_l, i_b] = links_state.pos[source, i_l, i_b] + links_state.quat[target, i_l, i_b] = links_state.quat[source, i_l, i_b] + links_state.root_COM[target, i_l, i_b] = links_state.root_COM[source, i_l, i_b] + links_state.mass_sum[target, i_l, i_b] = links_state.mass_sum[source, i_l, i_b] + links_state.i_pos[target, i_l, i_b] = links_state.i_pos[source, i_l, i_b] + links_state.i_quat[target, i_l, i_b] = links_state.i_quat[source, i_l, i_b] + links_state.cinr_inertial[target, i_l, i_b] = links_state.cinr_inertial[source, i_l, i_b] + links_state.cinr_pos[target, i_l, i_b] = links_state.cinr_pos[source, i_l, i_b] + links_state.cinr_quat[target, i_l, i_b] = links_state.cinr_quat[source, i_l, i_b] + links_state.cinr_mass[target, i_l, i_b] = links_state.cinr_mass[source, i_l, i_b] + links_state.j_pos[target, i_l, i_b] = links_state.j_pos[source, i_l, i_b] + links_state.j_quat[target, i_l, i_b] = links_state.j_quat[source, i_l, i_b] + links_state.cd_vel[target, i_l, i_b] = links_state.cd_vel[source, i_l, i_b] + links_state.cd_ang[target, i_l, i_b] = links_state.cd_ang[source, i_l, i_b] + + # joints_state + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_j, i_b in ti.ndrange(joints_state.xanchor.shape[1], joints_state.xanchor.shape[2]): + joints_state.xanchor[target, i_j, i_b] = joints_state.xanchor[source, i_j, i_b] + joints_state.xaxis[target, i_j, i_b] = joints_state.xaxis[source, i_j, i_b] + + # dofs_state + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in ti.ndrange(dofs_state.pos.shape[1], dofs_state.pos.shape[2]): + dofs_state.pos[target, i_d, i_b] = dofs_state.pos[source, i_d, i_b] + dofs_state.vel[target, i_d, i_b] = dofs_state.vel[source, i_d, i_b] + dofs_state.cdof_ang[target, i_d, i_b] = dofs_state.cdof_ang[source, i_d, i_b] + dofs_state.cdof_vel[target, i_d, i_b] = dofs_state.cdof_vel[source, i_d, i_b] + dofs_state.cdofvel_ang[target, i_d, i_b] = dofs_state.cdofvel_ang[source, i_d, i_b] + dofs_state.cdofvel_vel[target, i_d, i_b] = dofs_state.cdofvel_vel[source, i_d, i_b] + dofs_state.cdofd_ang[target, i_d, i_b] = dofs_state.cdofd_ang[source, i_d, i_b] + dofs_state.cdofd_vel[target, i_d, i_b] = dofs_state.cdofd_vel[source, i_d, i_b] + + # geoms_state + ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_g, i_b in ti.ndrange(geoms_state.pos.shape[1], geoms_state.pos.shape[2]): + geoms_state.pos[target, i_g, i_b] = geoms_state.pos[source, i_g, i_b] + geoms_state.quat[target, i_g, i_b] = geoms_state.quat[source, i_g, i_b] + geoms_state.verts_updated[target, i_g, i_b] = geoms_state.verts_updated[source, i_g, i_b] diff --git a/genesis/engine/solvers/rigid/support_field_decomp.py b/genesis/engine/solvers/rigid/support_field_decomp.py index 878727592..b101d0189 100644 --- a/genesis/engine/solvers/rigid/support_field_decomp.py +++ b/genesis/engine/solvers/rigid/support_field_decomp.py @@ -129,6 +129,7 @@ def _func_support_world( support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), d, + i_f, i_g, i_b, ): @@ -136,8 +137,8 @@ def _func_support_world( support position for a world direction """ - g_pos = geoms_state.pos[i_g, i_b] - g_quat = geoms_state.quat[i_g, i_b] + g_pos = geoms_state.pos[i_f, i_g, i_b] + g_quat = geoms_state.quat[i_f, i_g, i_b] d_mesh = gu.ti_transform_by_quat(d, gu.ti_inv_quat(g_quat)) v, vid = _func_support_mesh(support_field_info, support_field_static_config, d_mesh, i_g) v_ = gu.ti_transform_by_trans_quat(v, g_pos, g_quat) @@ -199,11 +200,12 @@ def _func_support_sphere( geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, d, + i_f, i_g, i_b, shrink, ): - sphere_center = geoms_state.pos[i_g, i_b] + sphere_center = geoms_state.pos[i_f, i_g, i_b] sphere_radius = geoms_info.data[i_g][0] # Shrink the sphere to a point @@ -218,10 +220,11 @@ def _func_support_ellipsoid( geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, d, + i_f, i_g, i_b, ): - ellipsoid_center = geoms_state.pos[i_g, i_b] + ellipsoid_center = geoms_state.pos[i_f, i_g, i_b] ellipsoid_scaled_axis = ti.Vector( [ geoms_info.data[i_g][0] ** 2, @@ -230,7 +233,7 @@ def _func_support_ellipsoid( ], dt=gs.ti_float, ) - ellipsoid_scaled_axis = gu.ti_transform_by_quat(ellipsoid_scaled_axis, geoms_state.quat[i_g, i_b]) + ellipsoid_scaled_axis = gu.ti_transform_by_quat(ellipsoid_scaled_axis, geoms_state.quat[i_f, i_g, i_b]) dist = ellipsoid_scaled_axis / ti.sqrt(d.dot(1.0 / ellipsoid_scaled_axis)) return ellipsoid_center + d * dist @@ -240,13 +243,14 @@ def _func_support_capsule( geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, d, + i_f, i_g, i_b, shrink, ): res = gs.ti_vec3(0, 0, 0) - g_pos = geoms_state.pos[i_g, i_b] - g_quat = geoms_state.quat[i_g, i_b] + g_pos = geoms_state.pos[i_f, i_g, i_b] + g_quat = geoms_state.quat[i_f, i_g, i_b] capsule_center = g_pos capsule_radius = geoms_info.data[i_g][0] capsule_halflength = 0.5 * geoms_info.data[i_g][1] @@ -267,6 +271,7 @@ def _func_support_capsule( def _func_support_prism( collider_state: array_class.ColliderState, d, + i_f, i_g, i_b, ): @@ -290,11 +295,12 @@ def _func_support_box( geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, d, + i_f, i_g, i_b, ): - g_pos = geoms_state.pos[i_g, i_b] - g_quat = geoms_state.quat[i_g, i_b] + g_pos = geoms_state.pos[i_f, i_g, i_b] + g_quat = geoms_state.quat[i_f, i_g, i_b] d_box = gu.ti_inv_transform_by_quat(d, g_quat) v_ = ti.Vector( @@ -318,13 +324,14 @@ def _func_count_supports_world( support_field_info: array_class.SupportFieldInfo, support_field_static_config: ti.template(), d, + i_f, i_g, i_b, ): """ Count the number of valid support points for the given world direction. """ - d_mesh = gu.ti_transform_by_quat(d, gu.ti_inv_quat(geoms_state.quat[i_g, i_b])) + d_mesh = gu.ti_transform_by_quat(d, gu.ti_inv_quat(geoms_state.quat[i_f, i_g, i_b])) return _func_count_supports_mesh( geoms_state, geoms_info, support_field_info, support_field_static_config, d_mesh, i_g ) @@ -386,6 +393,7 @@ def _func_count_supports_box( geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, d, + i_f, i_g, i_b, ): @@ -395,7 +403,7 @@ def _func_count_supports_box( If the direction has 1 zero component, there are 2 possible support points. If the direction has 2 zero components, there are 4 possible support points. """ - g_quat = geoms_state.quat[i_g, i_b] + g_quat = geoms_state.quat[i_f, i_g, i_b] d_box = gu.ti_inv_transform_by_quat(d, g_quat) return 2 ** (d_box == 0.0).cast(gs.ti_int).sum() diff --git a/genesis/engine/states/entities.py b/genesis/engine/states/entities.py index 4c6bbad5a..ce15fe061 100644 --- a/genesis/engine/states/entities.py +++ b/genesis/engine/states/entities.py @@ -188,3 +188,45 @@ def vel(self): @property def active(self): return self._active + + +class RigidEntityState(RBC): + """ + Dynamic state queried from a genesis RigidEntity. + """ + + def __init__(self, entity, s_global): + self._entity = entity + self._s_global = s_global + + args = { + "dtype": gs.tc_float, + "requires_grad": self._entity.scene.requires_grad, + "scene": self._entity.scene, + } + + num_batch = self._entity._solver._B + self._pos = gs.zeros((num_batch, 3), **args) + self._quat = gs.zeros((num_batch, 4), **args) + + def serializable(self): + self._entity = None + + self._pos = self._pos.detach() + self._quat = self._quat.detach() + + @property + def entity(self): + return self._entity + + @property + def s_global(self): + return self._s_global + + @property + def pos(self): + return self._pos + + @property + def quat(self): + return self._quat diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index 5ffe91ad7..a858845f5 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -44,6 +44,7 @@ class StructRigidGlobalInfo: def get_rigid_global_info(solver): f_batch = solver._batch_shape + n_frame = solver._sim.substeps_local + 1 # Basic fields kwargs = { @@ -54,14 +55,22 @@ def get_rigid_global_info(solver): "n_awake_links": V(dtype=gs.ti_int, shape=f_batch()), "awake_links": V(dtype=gs.ti_int, shape=f_batch(solver.n_links)), "qpos0": V(dtype=gs.ti_float, shape=solver._batch_shape(solver.n_qs_)), - "qpos": V(dtype=gs.ti_float, shape=solver._batch_shape(solver.n_qs_)), + "qpos": V(dtype=gs.ti_float, shape=solver._batch_shape((n_frame, solver.n_qs_))), "links_T": V_MAT(n=4, m=4, dtype=gs.ti_float, shape=solver.n_links), "envs_offset": V_VEC(3, dtype=gs.ti_float, shape=f_batch()), "geoms_init_AABB": V_VEC(3, dtype=gs.ti_float, shape=(solver.n_geoms_, 8)), - "mass_mat": V(dtype=gs.ti_float, shape=solver._batch_shape((solver.n_dofs_, solver.n_dofs_))), - "mass_mat_L": V(dtype=gs.ti_float, shape=solver._batch_shape((solver.n_dofs_, solver.n_dofs_))), - "mass_mat_D_inv": V(dtype=gs.ti_float, shape=solver._batch_shape((solver.n_dofs_,))), - "_mass_mat_mask": V(dtype=gs.ti_int, shape=solver._batch_shape(solver.n_entities_)), + "mass_mat": V(dtype=gs.ti_float, shape=solver._batch_shape((n_frame, solver.n_dofs_, solver.n_dofs_))), + "mass_mat_L": V(dtype=gs.ti_float, shape=solver._batch_shape((n_frame, solver.n_dofs_, solver.n_dofs_))), + "mass_mat_D_inv": V( + dtype=gs.ti_float, + shape=solver._batch_shape( + ( + n_frame, + solver.n_dofs_, + ) + ), + ), + "_mass_mat_mask": V(dtype=gs.ti_int, shape=solver._batch_shape((n_frame, solver.n_entities_))), "meaninertia": V(dtype=gs.ti_float, shape=solver._batch_shape()), "mass_parent_mask": V(dtype=gs.ti_float, shape=(solver.n_dofs_, solver.n_dofs_)), "gravity": V_VEC(3, dtype=gs.ti_float, shape=f_batch()), @@ -136,6 +145,7 @@ class StructConstraintState: qfrc_constraint: V_ANNOTATION qacc: V_ANNOTATION qacc_ws: V_ANNOTATION + qacc_smooth: V_ANNOTATION qacc_prev: V_ANNOTATION cost_ws: V_ANNOTATION gauss: V_ANNOTATION @@ -210,6 +220,7 @@ def get_constraint_state(constraint_solver, solver): "qfrc_constraint": V(dtype=gs.ti_float, shape=solver._batch_shape(solver.n_dofs_)), "qacc": V(dtype=gs.ti_float, shape=solver._batch_shape(solver.n_dofs_)), "qacc_ws": V(dtype=gs.ti_float, shape=solver._batch_shape(solver.n_dofs_)), + "qacc_smooth": V(dtype=gs.ti_float, shape=solver._batch_shape(solver.n_dofs_)), "qacc_prev": V(dtype=gs.ti_float, shape=solver._batch_shape(solver.n_dofs_)), "cost_ws": V(gs.ti_float, shape=solver._batch_shape()), "gauss": V(gs.ti_float, shape=solver._batch_shape()), @@ -1252,7 +1263,7 @@ class StructDofsState: def get_dofs_state(solver): - shape = solver._batch_shape(solver.n_dofs_) + shape = solver._batch_shape((solver._sim.substeps_local + 1, solver.n_dofs_)) kwargs = { "force": V(dtype=gs.ti_float, shape=shape), "qf_bias": V(dtype=gs.ti_float, shape=shape), @@ -1335,7 +1346,8 @@ class StructLinksState: def get_links_state(solver): - shape = solver._batch_shape(solver.n_links_) + shape = solver._batch_shape((solver._sim.substeps_local + 1, solver.n_links_)) + shape_ = solver._batch_shape((solver.n_links_,)) kwargs = { "cinr_inertial": V(dtype=gs.ti_mat3, shape=shape), "cinr_pos": V(dtype=gs.ti_vec3, shape=shape), @@ -1359,8 +1371,6 @@ def get_links_state(solver): "cd_vel": V(dtype=gs.ti_vec3, shape=shape), "mass_sum": V(dtype=gs.ti_float, shape=shape), "root_COM": V(dtype=gs.ti_vec3, shape=shape), - "mass_shift": V(dtype=gs.ti_float, shape=shape), - "i_pos_shift": V(dtype=gs.ti_vec3, shape=shape), "cacc_ang": V(dtype=gs.ti_vec3, shape=shape), "cacc_lin": V(dtype=gs.ti_vec3, shape=shape), "cfrc_ang": V(dtype=gs.ti_vec3, shape=shape), @@ -1369,6 +1379,9 @@ def get_links_state(solver): "cfrc_applied_vel": V(dtype=gs.ti_vec3, shape=shape), "contact_force": V(dtype=gs.ti_vec3, shape=shape), "hibernated": V(dtype=gs.ti_int, shape=shape), + # These are only updated by user input + "mass_shift": V(dtype=gs.ti_float, shape=shape_), + "i_pos_shift": V(dtype=gs.ti_vec3, shape=shape_), } if use_ndarray: @@ -1490,7 +1503,7 @@ class StructJointsState: def get_joints_state(solver): - shape = solver._batch_shape(solver.n_joints_) + shape = solver._batch_shape((solver._sim.substeps_local + 1, solver.n_joints_)) kwargs = { "xanchor": V(dtype=gs.ti_vec3, shape=shape), "xaxis": V(dtype=gs.ti_vec3, shape=shape), @@ -1600,11 +1613,13 @@ class StructGeomsState: min_buffer_idx: V_ANNOTATION max_buffer_idx: V_ANNOTATION hibernated: V_ANNOTATION + # These are only updated by user input friction_ratio: V_ANNOTATION def get_geoms_state(solver): - shape = solver._batch_shape(solver.n_geoms_) + shape = solver._batch_shape((solver._sim.substeps_local + 1, solver.n_geoms_)) + shape_ = solver._batch_shape((solver.n_geoms_,)) kwargs = { "pos": V(dtype=gs.ti_vec3, shape=shape), "quat": V(dtype=gs.ti_vec4, shape=shape), @@ -1614,7 +1629,7 @@ def get_geoms_state(solver): "min_buffer_idx": V(dtype=gs.ti_int, shape=shape), "max_buffer_idx": V(dtype=gs.ti_int, shape=shape), "hibernated": V(dtype=gs.ti_int, shape=shape), - "friction_ratio": V(dtype=gs.ti_float, shape=shape), + "friction_ratio": V(dtype=gs.ti_float, shape=shape_), } if use_ndarray: @@ -2065,6 +2080,52 @@ def get_static_rigid_sim_cache_key(solver): return StaticRigidSimCacheKey(**kwargs) +# =========================================== RigidAdjointCache =========================================== + + +@dataclasses.dataclass +class StructRigidAdjointCache: + forward_kinematics_joint_pos: V_ANNOTATION + forward_kinematics_joint_quat: V_ANNOTATION + i_pos: V_ANNOTATION + j_pos: V_ANNOTATION + j_quat: V_ANNOTATION + root_COM: V_ANNOTATION + mass_mat_L0: V_ANNOTATION + mass_mat_L1: V_ANNOTATION + + +def get_rigid_adjoint_cache(solver): + n_frame = solver._sim.substeps_local + 1 + + kwargs = { + "forward_kinematics_joint_pos_in": V(dtype=gs.ti_vec3, shape=solver._batch_shape((n_frame, solver.n_joints_))), + "forward_kinematics_joint_pos_out": V(dtype=gs.ti_vec3, shape=solver._batch_shape((n_frame, solver.n_joints_))), + "forward_kinematics_joint_quat_in": V(dtype=gs.ti_vec4, shape=solver._batch_shape((n_frame, solver.n_joints_))), + "forward_kinematics_joint_quat_out": V( + dtype=gs.ti_vec4, shape=solver._batch_shape((n_frame, solver.n_joints_)) + ), + "i_pos": V(dtype=gs.ti_vec3, shape=solver._batch_shape((n_frame, solver.n_links_))), + "j_pos": V(dtype=gs.ti_vec3, shape=solver._batch_shape((n_frame, solver.n_links_))), + "j_quat": V(dtype=gs.ti_vec4, shape=solver._batch_shape((n_frame, solver.n_links_))), + "root_COM": V(dtype=gs.ti_vec3, shape=solver._batch_shape((n_frame, solver.n_links_))), + "mass_mat_L0": V(dtype=gs.ti_float, shape=solver._batch_shape((n_frame, solver.n_dofs_, solver.n_dofs_))), + "mass_mat_L1": V(dtype=gs.ti_float, shape=solver._batch_shape((n_frame, solver.n_dofs_, solver.n_dofs_))), + } + + if use_ndarray: + return StructRigidAdjointCache(**kwargs) + else: + + @ti.data_oriented + class ClassRigidAdjointCache: + def __init__(self): + for k, v in kwargs.items(): + setattr(self, k, v) + + return ClassRigidAdjointCache() + + # =========================================== DataManager =========================================== @@ -2100,6 +2161,8 @@ def __init__(self, solver): self.entities_info = get_entities_info(solver) self.entities_state = get_entities_state(solver) + self.rigid_adjoint_cache = get_rigid_adjoint_cache(solver) + # we will use struct for DofsState and DofsInfo after Hugh adds array_struct feature to gstaichi DofsState = ti.template() if not use_ndarray else StructDofsState @@ -2132,3 +2195,4 @@ def __init__(self, solver): GJKState = ti.template() if not use_ndarray else StructGJKState SDFInfo = ti.template() if not use_ndarray else StructSDFInfo ContactIslandState = ti.template() if not use_ndarray else StructContactIslandState +RigidAdjointCache = ti.template() if not use_ndarray else StructRigidAdjointCache diff --git a/genesis/utils/sdf_decomp.py b/genesis/utils/sdf_decomp.py index 8f8499e54..dce491391 100644 --- a/genesis/utils/sdf_decomp.py +++ b/genesis/utils/sdf_decomp.py @@ -73,6 +73,7 @@ def sdf_func_world( geoms_info: array_class.GeomsInfo, sdf_info: array_class.SDFInfo, pos_world, + frame_idx, geom_idx, batch_idx, ): @@ -80,8 +81,8 @@ def sdf_func_world( sdf value from world coordinate """ - g_pos = geoms_state.pos[geom_idx, batch_idx] - g_quat = geoms_state.quat[geom_idx, batch_idx] + g_pos = geoms_state.pos[frame_idx, geom_idx, batch_idx] + g_quat = geoms_state.quat[frame_idx, geom_idx, batch_idx] sd = gs.ti_float(0.0) if geoms_info.type[geom_idx] == gs.GEOM_TYPE.SPHERE: @@ -168,11 +169,12 @@ def sdf_func_grad_world( collider_static_config: ti.template(), sdf_info: array_class.SDFInfo, pos_world, + frame_idx, geom_idx, batch_idx, ): - g_pos = geoms_state.pos[geom_idx, batch_idx] - g_quat = geoms_state.quat[geom_idx, batch_idx] + g_pos = geoms_state.pos[frame_idx, geom_idx, batch_idx] + g_quat = geoms_state.quat[frame_idx, geom_idx, batch_idx] grad_world = ti.Vector.zero(gs.ti_float, 3) if geoms_info.type[geom_idx] == gs.GEOM_TYPE.SPHERE: @@ -272,11 +274,14 @@ def sdf_func_normal_world( collider_static_config: ti.template(), sdf_info: array_class.SDFInfo, pos_world, + frame_idx, geom_idx, batch_idx, ): return gu.ti_normalize( - sdf_func_grad_world(geoms_state, geoms_info, collider_static_config, sdf_info, pos_world, geom_idx, batch_idx) + sdf_func_grad_world( + geoms_state, geoms_info, collider_static_config, sdf_info, pos_world, frame_idx, geom_idx, batch_idx + ) ) diff --git a/tests/utils.py b/tests/utils.py index 0d63b93a5..736217117 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -286,17 +286,17 @@ def init_simulators(gs_sim, mj_sim=None, qpos=None, qvel=None): gs_robot.set_dofs_velocity(qvel) # TODO: This should be moved in `set_state`, `set_qpos`, `set_dofs_position`, `set_dofs_velocity` gs_sim.rigid_solver.dofs_state.qf_constraint.fill(0.0) - gs_sim.rigid_solver._func_forward_dynamics() - gs_sim.rigid_solver._func_constraint_force() - gs_sim.rigid_solver._func_update_acc() + gs_sim.rigid_solver._func_forward_dynamics(0) + gs_sim.rigid_solver._func_constraint_force(0) + gs_sim.rigid_solver._func_update_acc(0) if gs_sim.scene.visualizer: gs_sim.scene.visualizer.update() if mj_sim is not None: mujoco.mj_resetData(mj_sim.model, mj_sim.data) - mj_sim.data.qpos[mj_qs_idx] = gs_sim.rigid_solver.qpos.to_numpy()[:, 0] - mj_sim.data.qvel[mj_dofs_idx] = gs_sim.rigid_solver.dofs_state.vel.to_numpy()[:, 0] + mj_sim.data.qpos[mj_qs_idx] = gs_sim.rigid_solver.qpos.to_numpy()[0, :, 0] + mj_sim.data.qvel[mj_dofs_idx] = gs_sim.rigid_solver.dofs_state.vel.to_numpy()[0, :, 0] mujoco.mj_forward(mj_sim.model, mj_sim.data) @@ -771,19 +771,19 @@ def check_mujoco_data_consistency( (mj_bodies_idx, _, mj_qs_idx, mj_dofs_idx, _, _) = mj_maps # crb - gs_crb_inertial = gs_sim.rigid_solver.links_state.crb_inertial.to_numpy()[:, 0].reshape([-1, 9])[ + gs_crb_inertial = gs_sim.rigid_solver.links_state.crb_inertial.to_numpy()[0, :, 0].reshape([-1, 9])[ :, [0, 4, 8, 1, 2, 5] ] mj_crb_inertial = mj_sim.data.crb[:, :6] # upper-triangular part assert_allclose(gs_crb_inertial[gs_bodies_idx], mj_crb_inertial[mj_bodies_idx], tol=tol) - gs_crb_pos = gs_sim.rigid_solver.links_state.crb_pos.to_numpy()[:, 0] + gs_crb_pos = gs_sim.rigid_solver.links_state.crb_pos.to_numpy()[0, :, 0] mj_crb_pos = mj_sim.data.crb[:, 6:9] assert_allclose(gs_crb_pos[gs_bodies_idx], mj_crb_pos[mj_bodies_idx], tol=tol) - gs_crb_mass = gs_sim.rigid_solver.links_state.crb_mass.to_numpy()[:, 0] + gs_crb_mass = gs_sim.rigid_solver.links_state.crb_mass.to_numpy()[0, :, 0] mj_crb_mass = mj_sim.data.crb[:, 9] assert_allclose(gs_crb_mass[gs_bodies_idx], mj_crb_mass[mj_bodies_idx], tol=tol) - gs_mass_mat = gs_sim.rigid_solver.mass_mat.to_numpy()[:, :, 0] + gs_mass_mat = gs_sim.rigid_solver.mass_mat.to_numpy()[0, :, :, 0] mj_mass_mat = np.zeros((mj_sim.model.nv, mj_sim.model.nv)) mujoco.mj_fullM(mj_sim.model, mj_mass_mat, mj_sim.data.qM) assert_allclose(gs_mass_mat[gs_dofs_idx][:, gs_dofs_idx], mj_mass_mat[mj_dofs_idx][:, mj_dofs_idx], tol=tol) @@ -793,13 +793,13 @@ def check_mujoco_data_consistency( assert_allclose(gs_meaninertia, mj_meaninertia, tol=tol) # Pre-constraint so-called bias forces in configuration space - gs_qfrc_bias = gs_sim.rigid_solver.dofs_state.qf_bias.to_numpy()[:, 0] + gs_qfrc_bias = gs_sim.rigid_solver.dofs_state.qf_bias.to_numpy()[0, :, 0] mj_qfrc_bias = mj_sim.data.qfrc_bias assert_allclose(gs_qfrc_bias, mj_qfrc_bias[mj_dofs_idx], tol=tol) - gs_qfrc_passive = gs_sim.rigid_solver.dofs_state.qf_passive.to_numpy()[:, 0] + gs_qfrc_passive = gs_sim.rigid_solver.dofs_state.qf_passive.to_numpy()[0, :, 0] mj_qfrc_passive = mj_sim.data.qfrc_passive assert_allclose(gs_qfrc_passive, mj_qfrc_passive[mj_dofs_idx], tol=tol) - gs_qfrc_actuator = gs_sim.rigid_solver.dofs_state.qf_applied.to_numpy()[:, 0] + gs_qfrc_actuator = gs_sim.rigid_solver.dofs_state.qf_applied.to_numpy()[0, :, 0] mj_qfrc_actuator = mj_sim.data.qfrc_actuator assert_allclose(gs_qfrc_actuator, mj_qfrc_actuator[mj_dofs_idx], tol=tol) @@ -895,19 +895,20 @@ def check_mujoco_data_consistency( mj_efc_vel = mj_sim.data.efc_vel assert_allclose(gs_efc_vel[gs_sidx], mj_efc_vel[mj_sidx], tol=tol) - gs_qfrc_constraint = gs_sim.rigid_solver.dofs_state.qf_constraint.to_numpy()[:, 0] + gs_qfrc_constraint = gs_sim.rigid_solver.dofs_state.qf_constraint.to_numpy()[0, :, 0] mj_qfrc_constraint = mj_sim.data.qfrc_constraint assert_allclose(gs_qfrc_constraint[gs_dofs_idx], mj_qfrc_constraint[mj_dofs_idx], tol=tol) - gs_qfrc_all = gs_sim.rigid_solver.dofs_state.force.to_numpy()[:, 0] + gs_qfrc_all = gs_sim.rigid_solver.dofs_state.force.to_numpy()[0, :, 0] mj_qfrc_all = mj_sim.data.qfrc_smooth + mj_sim.data.qfrc_constraint assert_allclose(gs_qfrc_all[gs_dofs_idx], mj_qfrc_all[mj_dofs_idx], tol=tol) - gs_qfrc_smooth = gs_sim.rigid_solver.dofs_state.qf_smooth.to_numpy()[:, 0] + gs_qfrc_smooth = gs_sim.rigid_solver.dofs_state.qf_smooth.to_numpy()[0, :, 0] mj_qfrc_smooth = mj_sim.data.qfrc_smooth assert_allclose(gs_qfrc_smooth[gs_dofs_idx], mj_qfrc_smooth[mj_dofs_idx], tol=tol) - gs_qacc_smooth = gs_sim.rigid_solver.dofs_state.acc_smooth.to_numpy()[:, 0] + gs_qacc_smooth = gs_sim.rigid_solver.dofs_state.acc_smooth.to_numpy()[0, :, 0] + print(gs_sim.rigid_solver.dofs_state.acc_smooth.to_numpy()) mj_qacc_smooth = mj_sim.data.qacc_smooth assert_allclose(gs_qacc_smooth[gs_dofs_idx], mj_qacc_smooth[mj_dofs_idx], tol=tol) @@ -920,66 +921,66 @@ def check_mujoco_data_consistency( mj_qacc_pre = mj_sim.data.qacc assert_allclose(gs_qacc_pre[gs_dofs_idx], mj_qacc_pre[mj_dofs_idx], tol=tol) - gs_qvel = gs_sim.rigid_solver.dofs_state.vel.to_numpy()[:, 0] + gs_qvel = gs_sim.rigid_solver.dofs_state.vel.to_numpy()[0, :, 0] mj_qvel = mj_sim.data.qvel assert_allclose(gs_qvel[gs_dofs_idx], mj_qvel[mj_dofs_idx], tol=tol) - gs_qpos = gs_sim.rigid_solver.qpos.to_numpy()[:, 0] + gs_qpos = gs_sim.rigid_solver.qpos.to_numpy()[0, :, 0] mj_qpos = mj_sim.data.qpos assert_allclose(gs_qpos[gs_q_idx], mj_qpos[mj_qs_idx], tol=tol) # ------------------------------------------------------------------------ - gs_com = gs_sim.rigid_solver.links_state.root_COM.to_numpy()[:, 0] + gs_com = gs_sim.rigid_solver.links_state.root_COM.to_numpy()[0, :, 0] gs_root_idx = np.unique(gs_sim.rigid_solver.links_info.root_idx.to_numpy()[gs_bodies_idx]) mj_com = mj_sim.data.subtree_com mj_root_idx = np.unique(mj_sim.model.body_rootid[mj_bodies_idx]) assert_allclose(gs_com[gs_root_idx], mj_com[mj_root_idx], tol=tol) - gs_xipos = gs_sim.rigid_solver.links_state.i_pos.to_numpy()[:, 0] + gs_xipos = gs_sim.rigid_solver.links_state.i_pos.to_numpy()[0, :, 0] mj_xipos = mj_sim.data.xipos - mj_sim.data.subtree_com[mj_sim.model.body_rootid] assert_allclose(gs_xipos[gs_bodies_idx], mj_xipos[mj_bodies_idx], tol=tol) - gs_xpos = gs_sim.rigid_solver.links_state.pos.to_numpy()[:, 0] + gs_xpos = gs_sim.rigid_solver.links_state.pos.to_numpy()[0, :, 0] mj_xpos = mj_sim.data.xpos assert_allclose(gs_xpos[gs_bodies_idx], mj_xpos[mj_bodies_idx], tol=tol) - gs_xquat = gs_sim.rigid_solver.links_state.quat.to_numpy()[:, 0] + gs_xquat = gs_sim.rigid_solver.links_state.quat.to_numpy()[0, :, 0] gs_xmat = gu.quat_to_R(gs_xquat).reshape([-1, 9]) mj_xmat = mj_sim.data.xmat assert_allclose(gs_xmat[gs_bodies_idx], mj_xmat[mj_bodies_idx], tol=tol) - gs_cd_vel = gs_sim.rigid_solver.links_state.cd_vel.to_numpy()[:, 0] + gs_cd_vel = gs_sim.rigid_solver.links_state.cd_vel.to_numpy()[0, :, 0] mj_cd_vel = mj_sim.data.cvel[:, 3:] assert_allclose(gs_cd_vel[gs_bodies_idx], mj_cd_vel[mj_bodies_idx], tol=tol) - gs_cd_ang = gs_sim.rigid_solver.links_state.cd_ang.to_numpy()[:, 0] + gs_cd_ang = gs_sim.rigid_solver.links_state.cd_ang.to_numpy()[0, :, 0] mj_cd_ang = mj_sim.data.cvel[:, :3] assert_allclose(gs_cd_ang[gs_bodies_idx], mj_cd_ang[mj_bodies_idx], tol=tol) - gs_cdof_vel = gs_sim.rigid_solver.dofs_state.cdof_vel.to_numpy()[:, 0] + gs_cdof_vel = gs_sim.rigid_solver.dofs_state.cdof_vel.to_numpy()[0, :, 0] mj_cdof_vel = mj_sim.data.cdof[:, 3:] assert_allclose(gs_cdof_vel[gs_dofs_idx], mj_cdof_vel[mj_dofs_idx], tol=tol) - gs_cdof_ang = gs_sim.rigid_solver.dofs_state.cdof_ang.to_numpy()[:, 0] + gs_cdof_ang = gs_sim.rigid_solver.dofs_state.cdof_ang.to_numpy()[0, :, 0] mj_cdof_ang = mj_sim.data.cdof[:, :3] assert_allclose(gs_cdof_ang[gs_dofs_idx], mj_cdof_ang[mj_dofs_idx], tol=tol) mj_cdof_dot_ang = mj_sim.data.cdof_dot[:, :3] - gs_cdof_dot_ang = gs_sim.rigid_solver.dofs_state.cdofd_ang.to_numpy()[:, 0] + gs_cdof_dot_ang = gs_sim.rigid_solver.dofs_state.cdofd_ang.to_numpy()[0, :, 0] assert_allclose(gs_cdof_dot_ang[gs_dofs_idx], mj_cdof_dot_ang[mj_dofs_idx], tol=tol) mj_cdof_dot_vel = mj_sim.data.cdof_dot[:, 3:] - gs_cdof_dot_vel = gs_sim.rigid_solver.dofs_state.cdofd_vel.to_numpy()[:, 0] + gs_cdof_dot_vel = gs_sim.rigid_solver.dofs_state.cdofd_vel.to_numpy()[0, :, 0] assert_allclose(gs_cdof_dot_vel[gs_dofs_idx], mj_cdof_dot_vel[mj_dofs_idx], tol=tol) # cinr - gs_cinr_inertial = gs_sim.rigid_solver.links_state.cinr_inertial.to_numpy()[:, 0].reshape([-1, 9])[ + gs_cinr_inertial = gs_sim.rigid_solver.links_state.cinr_inertial.to_numpy()[0, :, 0].reshape([-1, 9])[ :, [0, 4, 8, 1, 2, 5] ] mj_cinr_inertial = mj_sim.data.cinert[:, :6] # upper-triangular part assert_allclose(gs_cinr_inertial[gs_bodies_idx], mj_cinr_inertial[mj_bodies_idx], tol=tol) - gs_cinr_pos = gs_sim.rigid_solver.links_state.cinr_pos.to_numpy()[:, 0] + gs_cinr_pos = gs_sim.rigid_solver.links_state.cinr_pos.to_numpy()[0, :, 0] mj_cinr_pos = mj_sim.data.cinert[:, 6:9] assert_allclose(gs_cinr_pos[gs_bodies_idx], mj_cinr_pos[mj_bodies_idx], tol=tol) - gs_cinr_mass = gs_sim.rigid_solver.links_state.cinr_mass.to_numpy()[:, 0] + gs_cinr_mass = gs_sim.rigid_solver.links_state.cinr_mass.to_numpy()[0, :, 0] mj_cinr_mass = mj_sim.data.cinert[:, 9] assert_allclose(gs_cinr_mass[gs_bodies_idx], mj_cinr_mass[mj_bodies_idx], tol=tol) @@ -1002,13 +1003,13 @@ def simulate_and_check_mujoco_consistency(gs_sim, mj_sim, qpos=None, qvel=None, check_mujoco_data_consistency(gs_sim, mj_sim, qvel_prev=qvel_prev, tol=tol) # Keep Mujoco and Genesis simulation in sync to avoid drift over time - mj_sim.data.qpos[mj_qs_idx] = gs_sim.rigid_solver.qpos.to_numpy()[:, 0] - mj_sim.data.qvel[mj_dofs_idx] = gs_sim.rigid_solver.dofs_state.vel.to_numpy()[:, 0] + mj_sim.data.qpos[mj_qs_idx] = gs_sim.rigid_solver.qpos.to_numpy()[0, :, 0] + mj_sim.data.qvel[mj_dofs_idx] = gs_sim.rigid_solver.dofs_state.vel.to_numpy()[0, :, 0] mj_sim.data.qacc_warmstart[mj_dofs_idx] = gs_sim.rigid_solver.constraint_solver.qacc_ws.to_numpy()[:, 0] - mj_sim.data.qacc_smooth[mj_dofs_idx] = gs_sim.rigid_solver.dofs_state.acc_smooth.to_numpy()[:, 0] + mj_sim.data.qacc_smooth[mj_dofs_idx] = gs_sim.rigid_solver.dofs_state.acc_smooth.to_numpy()[0, :, 0] # Backup current velocity - qvel_prev = gs_sim.rigid_solver.dofs_state.vel.to_numpy()[:, 0] + qvel_prev = gs_sim.rigid_solver.dofs_state.vel.to_numpy()[0, :, 0] # Do a single simulation step (eventually with substeps for Genesis) mujoco.mj_step(mj_sim.model, mj_sim.data)