diff --git a/.github/workflows/production.yml b/.github/workflows/production.yml index 449daac00..3927e9718 100644 --- a/.github/workflows/production.yml +++ b/.github/workflows/production.yml @@ -24,7 +24,7 @@ jobs: WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} HF_TOKEN: ${{ secrets.HF_TOKEN }} HF_HUB_DOWNLOAD_TIMEOUT: 60 - GENESIS_IMAGE_VER: "1_3" + GENESIS_IMAGE_VER: "1_4" TIMEOUT_MINUTES: 180 FORCE_COLOR: 1 PY_COLORS: 1 diff --git a/examples/sensors/contact_force_go2.py b/examples/sensors/contact_force_go2.py index 9a9178822..7895f926a 100644 --- a/examples/sensors/contact_force_go2.py +++ b/examples/sensors/contact_force_go2.py @@ -1,4 +1,5 @@ import argparse +import os from tqdm import tqdm @@ -8,11 +9,11 @@ def main(): parser = argparse.ArgumentParser() - parser.add_argument("-dt", "--timestep", type=float, default=1e-2, help="Simulation time step") + parser.add_argument("-dt", "--timestep", type=float, default=0.01, help="Simulation time step") parser.add_argument("-v", "--vis", action="store_true", default=True, help="Show visualization GUI") parser.add_argument("-nv", "--no-vis", action="store_false", dest="vis", help="Disable visualization GUI") parser.add_argument("-c", "--cpu", action="store_true", help="Use CPU instead of GPU") - parser.add_argument("-t", "--seconds", type=float, default=2, help="Number of seconds to simulate") + parser.add_argument("-t", "--seconds", type=float, default=2.0, help="Number of seconds to simulate") parser.add_argument("-f", "--force", action="store_true", default=True, help="Use ContactForceSensor (xyz float)") parser.add_argument("-nf", "--no-force", action="store_false", dest="force", help="Use ContactSensor (boolean)") @@ -23,19 +24,25 @@ def main(): ########################## scene setup ########################## scene = gs.Scene( - sim_options=gs.options.SimOptions(dt=args.timestep), + sim_options=gs.options.SimOptions( + dt=args.timestep, + ), rigid_options=gs.options.RigidOptions( - use_gjk_collision=True, constraint_timeconst=max(0.01, 2 * args.timestep), + use_gjk_collision=True, + ), + vis_options=gs.options.VisOptions( + show_world_frame=True, + ), + profiling_options=gs.options.ProfilingOptions( + show_FPS=False, ), - vis_options=gs.options.VisOptions(show_world_frame=True), - profiling_options=gs.options.ProfilingOptions(show_FPS=False), show_viewer=args.vis, ) scene.add_entity(gs.morphs.Plane()) - foot_link_names = ["FR_foot", "FL_foot", "RR_foot", "RL_foot"] + foot_link_names = ("FR_foot", "FL_foot", "RR_foot", "RL_foot") go2 = scene.add_entity( gs.morphs.URDF( file="urdf/go2/urdf/go2.urdf", @@ -79,7 +86,7 @@ def main(): scene.build() try: - steps = int(args.seconds / args.timestep) + steps = int(args.seconds / args.timestep) if "PYTEST_VERSION" not in os.environ else 5 for _ in tqdm(range(steps)): scene.step() except KeyboardInterrupt: diff --git a/examples/sensors/imu_franka.py b/examples/sensors/imu_franka.py index 85df8c6f4..6cca4e250 100644 --- a/examples/sensors/imu_franka.py +++ b/examples/sensors/imu_franka.py @@ -1,4 +1,5 @@ import argparse +import os import numpy as np from tqdm import tqdm @@ -21,14 +22,20 @@ def main(): ########################## create a scene ########################## scene = gs.Scene( - sim_options=gs.options.SimOptions(dt=args.timestep), - vis_options=gs.options.VisOptions(show_world_frame=False), + sim_options=gs.options.SimOptions( + dt=args.timestep, + ), + vis_options=gs.options.VisOptions( + show_world_frame=False, + ), viewer_options=gs.options.ViewerOptions( camera_pos=(3.5, 0.0, 2.5), camera_lookat=(0.0, 0.0, 0.5), camera_fov=40, ), - profiling_options=gs.options.ProfilingOptions(show_FPS=False), + profiling_options=gs.options.ProfilingOptions( + show_FPS=False, + ), show_viewer=args.vis, ) @@ -40,6 +47,7 @@ def main(): end_effector = franka.get_link("hand") motors_dof = (0, 1, 2, 3, 4, 5, 6) + ########################## record sensor data ########################## imu = scene.add_sensor( gs.sensors.IMU( entity_idx=franka.idx, @@ -59,25 +67,38 @@ def main(): draw_debug=True, ) ) - labels = {"lin_acc": ("acc_x", "acc_y", "acc_z"), "ang_vel": ("gyro_x", "gyro_y", "gyro_z")} if args.vis: + xyz = ("x", "y", "z") + labels = {"lin_acc": xyz, "true_lin_acc": xyz, "ang_vel": xyz, "true_ang_vel": xyz} + + def data_func(): + data = imu.read() + true_data = imu.read_ground_truth() + return { + "lin_acc": data.lin_acc, + "true_lin_acc": true_data.lin_acc, + "ang_vel": data.ang_vel, + "true_ang_vel": true_data.ang_vel, + } + if IS_PYQTGRAPH_AVAILABLE: - imu.start_recording(gs.recorders.PyQtLinePlot(title="IMU Measured Data", labels=labels)) scene.start_recording( - imu.read_ground_truth, + data_func, gs.recorders.PyQtLinePlot(title="IMU Ground Truth Data", labels=labels), ) elif IS_MATPLOTLIB_AVAILABLE: gs.logger.info("pyqtgraph not found, falling back to matplotlib.") - imu.start_recording(gs.recorders.MPLLinePlot(title="IMU Measured Data", labels=labels)) scene.start_recording( - imu.read_ground_truth, + data_func, gs.recorders.MPLLinePlot(title="IMU Ground Truth Data", labels=labels), ) else: print("matplotlib or pyqtgraph not found, skipping real-time plotting.") - imu.start_recording(gs.recorders.NPZFile(filename="imu_data.npz")) + scene.start_recording( + data_func=lambda: imu.read()._asdict(), + rec_options=gs.recorders.NPZFile(filename="imu_data.npz"), + ) ########################## build ########################## scene.build() @@ -98,7 +119,7 @@ def main(): rate = np.deg2rad(2.0) try: - steps = int(args.seconds / args.timestep) + steps = int(args.seconds / args.timestep) if "PYTEST_VERSION" not in os.environ else 5 for i in tqdm(range(steps)): scene.step() diff --git a/examples/sensors/lidar_teleop.py b/examples/sensors/lidar_teleop.py new file mode 100644 index 000000000..dc71ac908 --- /dev/null +++ b/examples/sensors/lidar_teleop.py @@ -0,0 +1,242 @@ +import argparse +import os +import threading + +import numpy as np + +import genesis as gs +from genesis.sensors.raycaster.patterns import DepthCameraPattern, GridPattern, SphericalPattern +from genesis.utils.geom import euler_to_quat + +IS_PYNPUT_AVAILABLE = False +try: + from pynput import keyboard + + IS_PYNPUT_AVAILABLE = True +except ImportError: + pass + +# Position and angle increments for keyboard teleop control +KEY_DPOS = 0.1 +KEY_DANGLE = 0.1 + +# Movement when no keyboard control is available +MOVE_RADIUS = 1.0 +MOVE_RATE = 1.0 / 100.0 + +# Number of obstacles to create in a ring around the robot +NUM_CYLINDERS = 8 +NUM_BOXES = 6 +CYLINDER_RING_RADIUS = 3.0 +BOX_RING_RADIUS = 5.0 + + +class KeyboardDevice: + def __init__(self): + self.pressed_keys = set() + self.lock = threading.Lock() + self.listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release) + + def start(self): + self.listener.start() + + def stop(self): + try: + self.listener.stop() + except NotImplementedError: + # Dummy backend does not implement stop + pass + self.listener.join() + + def on_press(self, key: "keyboard.Key"): + with self.lock: + self.pressed_keys.add(key) + + def on_release(self, key: "keyboard.Key"): + with self.lock: + self.pressed_keys.discard(key) + + def get_cmd(self): + return self.pressed_keys + + +def main(): + parser = argparse.ArgumentParser(description="Genesis LiDAR/Depth Camera Visualization with Keyboard Teleop") + parser.add_argument("-B", "--n_envs", type=int, default=0, help="Number of environments to replicate") + parser.add_argument("--cpu", action="store_true", help="Run on CPU instead of GPU") + parser.add_argument("--use-box", action="store_true", help="Use Box as robot instead of Go2") + parser.add_argument("-f", "--fixed", action="store_true", help="Load obstacles as fixed.") + parser.add_argument( + "--pattern", + type=str, + default="spherical", + choices=["spherical", "depth", "grid"], + help="Sensor pattern type", + ) + + args = parser.parse_args() + + gs.init(backend=gs.cpu if args.cpu else gs.gpu, precision="32", logging_level="info") + + scene = gs.Scene( + viewer_options=gs.options.ViewerOptions( + camera_pos=(6.0, 6.0, 4.0), + camera_lookat=(0.0, 0.0, 0.5), + camera_fov=60, + max_FPS=60, + ), + profiling_options=gs.options.ProfilingOptions(show_FPS=False), + show_viewer=True, + ) + + scene.add_entity(gs.morphs.Plane()) + + # create ring of obstacles to visualize raycaster sensor hits + for i in range(NUM_CYLINDERS): + angle = 2 * np.pi * i / NUM_CYLINDERS + x = CYLINDER_RING_RADIUS * np.cos(angle) + y = CYLINDER_RING_RADIUS * np.sin(angle) + scene.add_entity( + gs.morphs.Cylinder( + height=1.5, + radius=0.3, + pos=(x, y, 0.75), + fixed=args.fixed, + ) + ) + + for i in range(NUM_BOXES): + angle = 2 * np.pi * i / NUM_BOXES + np.pi / 6 + x = BOX_RING_RADIUS * np.cos(angle) + y = BOX_RING_RADIUS * np.sin(angle) + scene.add_entity( + gs.morphs.Box( + size=(0.5, 0.5, 2.0), + pos=(x, y, 1.0), + fixed=args.fixed, + ) + ) + + robot_kwargs = dict( + pos=(0.0, 0.0, 0.35), + quat=(1.0, 0.0, 0.0, 0.0), + fixed=True, + ) + + if args.use_box: + robot = scene.add_entity(gs.morphs.Box(size=(0.1, 0.1, 0.1), **robot_kwargs)) + pos_offset = (0.0, 0.0, 0.2) + else: + robot = scene.add_entity(gs.morphs.URDF(file="urdf/go2/urdf/go2.urdf", **robot_kwargs)) + pos_offset = (0.3, 0.0, 0.1) + + sensor_kwargs = dict( + entity_idx=robot.idx, + pos_offset=pos_offset, + euler_offset=(0.0, 0.0, 0.0), + return_world_frame=True, + draw_debug=True, + ) + + if args.pattern == "depth": + sensor = scene.add_sensor(gs.sensors.DepthCamera(pattern=DepthCameraPattern(), **sensor_kwargs)) + scene.start_recording( + data_func=(lambda: sensor.read_image()[0]) if args.n_envs > 0 else sensor.read_image, + rec_options=gs.recorders.MPLImagePlot(), + ) + else: + if args.pattern == "grid": + pattern_cfg = GridPattern() + else: + if args.pattern != "spherical": + gs.logger.warning(f"Unrecognized raycaster pattern: {args.pattern}. Using 'spherical' instead.") + pattern_cfg = SphericalPattern() + + sensor = scene.add_sensor(gs.sensors.Lidar(pattern=pattern_cfg, **sensor_kwargs)) + + scene.build(n_envs=args.n_envs) + + if IS_PYNPUT_AVAILABLE: + kb = KeyboardDevice() + kb.start() + + print("Keyboard Controls:") + # Avoid using same keys as interactive viewer keyboard controls + print("[↑/↓/←/→]: Move XY") + print("[j/k]: Down/Up") + print("[n/m]: Roll CCW/CW") + print("[,/.]: Pitch Up/Down") + print("[o/p]: Yaw CCW/CW") + print("[\\]: Reset") + print("[esc]: Quit") + else: + print("Keyboard teleop is disabled since pynput is not installed. To install, run `pip install pynput`.") + + init_pos = np.array([0.0, 0.0, 0.35], dtype=np.float32) + init_euler = np.array([0.0, 0.0, 0.0], dtype=np.float32) + + target_pos = init_pos.copy() + target_euler = init_euler.copy() + + def apply_pose_to_all_envs(pos_np: np.ndarray, quat_np: np.ndarray): + if args.n_envs > 0: + pos_np = np.expand_dims(pos_np, axis=0).repeat(args.n_envs, axis=0) + quat_np = np.expand_dims(quat_np, axis=0).repeat(args.n_envs, axis=0) + robot.set_pos(pos_np) + robot.set_quat(quat_np) + + apply_pose_to_all_envs(target_pos, euler_to_quat(target_euler)) + + try: + while True: + if IS_PYNPUT_AVAILABLE: + pressed = kb.pressed_keys.copy() + if keyboard.Key.esc in pressed: + break + if keyboard.KeyCode.from_char("\\") in pressed: + target_pos[:] = init_pos + target_euler[:] = init_euler + + if keyboard.Key.up in pressed: + target_pos[0] += KEY_DPOS + if keyboard.Key.down in pressed: + target_pos[0] -= KEY_DPOS + if keyboard.Key.right in pressed: + target_pos[1] -= KEY_DPOS + if keyboard.Key.left in pressed: + target_pos[1] += KEY_DPOS + if keyboard.KeyCode.from_char("j") in pressed: + target_pos[2] -= KEY_DPOS + if keyboard.KeyCode.from_char("k") in pressed: + target_pos[2] += KEY_DPOS + + if keyboard.KeyCode.from_char("n") in pressed: + target_euler[0] += KEY_DANGLE # roll CCW around +X + if keyboard.KeyCode.from_char("m") in pressed: + target_euler[0] -= KEY_DANGLE # roll CW around +X + if keyboard.KeyCode.from_char(",") in pressed: + target_euler[1] += KEY_DANGLE # pitch up around +Y + if keyboard.KeyCode.from_char(".") in pressed: + target_euler[1] -= KEY_DANGLE # pitch down around +Y + if keyboard.KeyCode.from_char("o") in pressed: + target_euler[2] += KEY_DANGLE # yaw CCW around +Z + if keyboard.KeyCode.from_char("p") in pressed: + target_euler[2] -= KEY_DANGLE # yaw CW around +Z + else: + # move in a circle if no keyboard control + target_pos[0] = MOVE_RADIUS * np.cos(scene.t * MOVE_RATE) + target_pos[1] = MOVE_RADIUS * np.sin(scene.t * MOVE_RATE) + + apply_pose_to_all_envs(target_pos, euler_to_quat(target_euler)) + scene.step() + + if "PYTEST_VERSION" in os.environ: + break + except KeyboardInterrupt: + gs.logger.info("Simulation interrupted, exiting.") + finally: + gs.logger.info("Simulation finished.") + + +if __name__ == "__main__": + main() diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index 17b3a582c..2524d1fc5 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -1,9 +1,9 @@ from copy import copy from itertools import chain -from typing import Literal, TYPE_CHECKING +from typing import TYPE_CHECKING, Literal -import numpy as np import gstaichi as ti +import numpy as np import torch import trimesh @@ -11,9 +11,8 @@ from genesis.engine.materials.base import Material from genesis.options.morphs import Morph from genesis.options.surfaces import Surface -from genesis.utils import geom as gu from genesis.utils import array_class -from genesis.utils import linalg as lu +from genesis.utils import geom as gu from genesis.utils import mesh as mu from genesis.utils import mjcf as mju from genesis.utils import terrain as tu @@ -27,8 +26,8 @@ from .rigid_link import RigidLink if TYPE_CHECKING: - from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver from genesis.engine.scene import Scene + from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver @ti.data_oriented @@ -3004,7 +3003,7 @@ def _kernel_get_free_verts( tensor: ti.types.ndarray(), free_verts_idx_local: ti.types.ndarray(), verts_state_start: ti.i32, - free_verts_state: array_class.FreeVertsState, + free_verts_state: array_class.VertsState, ): n_verts = free_verts_idx_local.shape[0] _B = tensor.shape[0] @@ -3018,7 +3017,7 @@ def _kernel_get_fixed_verts( tensor: ti.types.ndarray(), fixed_verts_idx_local: ti.types.ndarray(), verts_state_start: ti.i32, - fixed_verts_state: array_class.FixedVertsState, + fixed_verts_state: array_class.VertsState, ): n_verts = fixed_verts_idx_local.shape[0] _B = tensor.shape[0] diff --git a/genesis/engine/entities/rigid_entity/rigid_geom.py b/genesis/engine/entities/rigid_entity/rigid_geom.py index ae15913e7..edb375e4c 100644 --- a/genesis/engine/entities/rigid_entity/rigid_geom.py +++ b/genesis/engine/entities/rigid_entity/rigid_geom.py @@ -18,9 +18,9 @@ from genesis.utils.misc import tensor_to_array, DeprecationError if TYPE_CHECKING: - from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver from genesis.engine.materials.rigid import Rigid as RigidMaterial from genesis.engine.mesh import Mesh + from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver from .rigid_entity import RigidEntity from .rigid_link import RigidLink @@ -1091,7 +1091,7 @@ def _kernel_get_vgeoms_quat(tensor: ti.types.ndarray(), vgeom_idx: ti.i32, vgeom @ti.kernel def _kernel_get_free_verts( - tensor: ti.types.ndarray(), verts_state_start: ti.i32, n_verts: ti.i32, free_verts_state: array_class.FreeVertsState + tensor: ti.types.ndarray(), verts_state_start: ti.i32, n_verts: ti.i32, free_verts_state: array_class.VertsState ): _B = free_verts_state.pos.shape[1] for i_v_, i, i_b in ti.ndrange(n_verts, 3, _B): @@ -1104,7 +1104,7 @@ def _kernel_get_fixed_verts( tensor: ti.types.ndarray(), verts_state_start: ti.i32, n_verts: ti.i32, - fixed_verts_state: array_class.FixedVertsState, + fixed_verts_state: array_class.VertsState, ): for i_v_, i in ti.ndrange(n_verts, 3): i_v = i_v_ + verts_state_start diff --git a/genesis/engine/scene.py b/genesis/engine/scene.py index 0cfd38cb2..d33a30f55 100644 --- a/genesis/engine/scene.py +++ b/genesis/engine/scene.py @@ -773,7 +773,7 @@ def build( env_spacing=(0.0, 0.0), n_envs_per_row: int | None = None, center_envs_at_origin=True, - compile_kernels=True, + compile_kernels=None, ): """ Builds the scene once all entities have been added. This operation is required before running the simulation. @@ -781,16 +781,23 @@ def build( Parameters ---------- n_envs : int - Number of parallel environments to create. If `n_envs` is 0, the scene will not have a batching dimension. If `n_envs` is greater than 0, the first dimension of all the input and returned states will be the batch dimension. + Number of parallel environments to create. + If `n_envs` is 0, the scene will not have a batching dimension. When greater than 0, the first dimension of + all the input and returned states will be the batch dimension. env_spacing : tuple of float, shape (2,) - The spacing between adjacent environments in the scene. This is for visualization purposes only and does not change simulation-related poses. + The spacing between adjacent environments in the scene. + This is for visualization purposes only and does not change simulation-related poses. n_envs_per_row : int The number of environments per row for visualization. If None, it will be set to `sqrt(n_envs)`. center_envs_at_origin : bool Whether to put the center of all the environments at the origin (for visualization only). - compile_kernels : bool - Whether to compile the simulation kernels inside `build()`. If False, the kernels will not be compiled (or loaded if found in the cache) until the first call of `scene.step()`. This is useful for cases you don't want to run the actual simulation, but rather just want to visualize the created scene. + compile_kernels : bool, optional + This parameter is deprecated and will be removed in future release. """ + if compile_kernels is not None: + warn_once("`compile_kernels` is deprecated and will be removed in future release.") + compile_kernels = True + with gs.logger.timer(f"Building scene ~~~<{self._uid}>~~~..."): self._parallelize(n_envs, env_spacing, n_envs_per_row, center_envs_at_origin) @@ -803,10 +810,9 @@ def build( self._is_built = True - if compile_kernels: - with gs.logger.timer("Compiling simulation kernels..."): - self._sim.step() - self._reset() + with gs.logger.timer("Compiling simulation kernels..."): + self._sim.step() + self._reset() # visualizer with gs.logger.timer("Building visualizer..."): diff --git a/genesis/engine/simulator.py b/genesis/engine/simulator.py index 5574cc99f..8980347c1 100644 --- a/genesis/engine/simulator.py +++ b/genesis/engine/simulator.py @@ -210,13 +210,13 @@ def build(self): if self.n_envs > 0 and self.sf_solver.is_active(): gs.raise_exception("Batching is not supported for SF solver as of now.") - self._sensor_manager.build() - # hybrid for entity in self._entities: if isinstance(entity, HybridEntity): entity.build() + self._sensor_manager.build() + def reset(self, state: SimState, envs_idx=None): for solver, solver_state in zip(self._solvers, state): if solver.n_entities > 0: diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index 054851ea8..270bbfb26 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -1,5 +1,4 @@ -from dataclasses import dataclass -from typing import Literal, TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import gstaichi as ti import numpy as np @@ -7,28 +6,24 @@ import torch import genesis as gs -import genesis.utils.geom as gu import genesis.utils.array_class as array_class - +import genesis.utils.geom as gu from genesis.engine.entities import AvatarEntity, DroneEntity, RigidEntity from genesis.engine.entities.base_entity import Entity -from genesis.engine.solvers.rigid.contact_island import ContactIsland from genesis.engine.states.solvers import RigidSolverState from genesis.options.solvers import RigidOptions -from genesis.styles import colors, formats from genesis.utils import linalg as lu -from genesis.utils.misc import ti_to_torch, DeprecationError, ALLOCATE_TENSOR_WARNING +from genesis.utils.misc import ALLOCATE_TENSOR_WARNING, DeprecationError, ti_to_torch +from genesis.utils.sdf_decomp import SDF -from ....utils.sdf_decomp import SDF from ..base_solver import Solver +from .collider_decomp import Collider from .constraint_solver_decomp import ConstraintSolver from .constraint_solver_decomp_island import ConstraintSolverIsland -from .collider_decomp import Collider from .rigid_solver_decomp_util import func_wakeup_entity_and_its_temp_island if TYPE_CHECKING: import genesis.engine.solvers.rigid.array_class - from genesis.engine.scene import Scene from genesis.engine.simulator import Simulator @@ -5069,8 +5064,8 @@ def kernel_update_verts_for_geoms( geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, verts_info: array_class.VertsInfo, - free_verts_state: array_class.FreeVertsState, - fixed_verts_state: array_class.FixedVertsState, + free_verts_state: array_class.VertsState, + fixed_verts_state: array_class.VertsState, ): n_geoms = geoms_idx.shape[0] _B = geoms_state.verts_updated.shape[1] @@ -5086,8 +5081,8 @@ def func_update_verts_for_geom( geoms_state: array_class.GeomsState, geoms_info: array_class.GeomsInfo, verts_info: array_class.VertsInfo, - free_verts_state: array_class.FreeVertsState, - fixed_verts_state: array_class.FixedVertsState, + free_verts_state: array_class.VertsState, + fixed_verts_state: array_class.VertsState, ): if not geoms_state.verts_updated[i_g, i_b]: if geoms_info.is_fixed[i_g]: @@ -5114,8 +5109,8 @@ def func_update_verts_for_geom( def func_update_all_verts( geoms_state: array_class.GeomsState, verts_info: array_class.VertsInfo, - free_verts_state: array_class.FreeVertsState, - fixed_verts_state: array_class.FixedVertsState, + free_verts_state: array_class.VertsState, + fixed_verts_state: array_class.VertsState, ): n_verts = verts_info.geom_idx.shape[0] _B = geoms_state.pos.shape[1] @@ -5133,6 +5128,29 @@ def func_update_all_verts( @ti.kernel(pure=gs.use_pure) +def kernel_update_all_verts( + geoms_state: array_class.GeomsState, + verts_info: array_class.VertsInfo, + free_verts_state: array_class.VertsState, + fixed_verts_state: array_class.VertsState, +): + n_verts = verts_info.geom_idx.shape[0] + _B = geoms_state.pos.shape[1] + for i_v, i_b in ti.ndrange(n_verts, _B): + g_pos = geoms_state.pos[verts_info.geom_idx[i_v], i_b] + g_quat = geoms_state.quat[verts_info.geom_idx[i_v], i_b] + verts_state_idx = verts_info.verts_state_idx[i_v] + if not verts_info.is_fixed[i_v]: + free_verts_state.pos[verts_state_idx, i_b] = gu.ti_transform_by_trans_quat( + verts_info.init_pos[i_v], g_pos, g_quat + ) + elif i_b == 0: + fixed_verts_state.pos[verts_state_idx] = gu.ti_transform_by_trans_quat( + verts_info.init_pos[i_v], g_pos, g_quat + ) + + +@ti.kernel def kernel_update_geom_aabbs( geoms_state: array_class.GeomsState, geoms_init_AABB: array_class.GeomsInitAABB, diff --git a/genesis/ext/pyrender/viewer.py b/genesis/ext/pyrender/viewer.py index 8dc1ac515..979addbc6 100644 --- a/genesis/ext/pyrender/viewer.py +++ b/genesis/ext/pyrender/viewer.py @@ -1225,8 +1225,8 @@ def get_program(self, vertex_shader, fragment_shader, geometry_shader=None, defi def start(self, auto_refresh=True): import pyglet # For some reason, this is necessary if 'pyglet.window.xlib' fails to import... try: - import pyglet.window.xlib - xlib_exceptions = (pyglet.window.xlib.XlibException,) + import pyglet.window.xlib, pyglet.display.xlib + xlib_exceptions = (pyglet.window.xlib.XlibException, pyglet.display.xlib.NoSuchDisplayException) except ImportError: xlib_exceptions = () @@ -1235,12 +1235,14 @@ def start(self, auto_refresh=True): confs = [ pyglet.gl.Config( depth_size=24, + alpha_size=8, # This parameter is essential to ensure proper pixel matching across platforms double_buffer=True, # Double buffering to avoid flickering major_version=TARGET_OPEN_GL_MAJOR, minor_version=TARGET_OPEN_GL_MINOR, ), pyglet.gl.Config( depth_size=24, + alpha_size=8, double_buffer=True, major_version=MIN_OPEN_GL_MAJOR, minor_version=MIN_OPEN_GL_MINOR, diff --git a/genesis/options/morphs.py b/genesis/options/morphs.py index a5006eb78..2a463602b 100644 --- a/genesis/options/morphs.py +++ b/genesis/options/morphs.py @@ -11,13 +11,11 @@ import numpy as np import genesis as gs -import genesis.utils.geom as gu import genesis.utils.misc as mu from .misc import CoacdOptions from .options import Options - URDF_FORMAT = ".urdf" MJCF_FORMAT = ".xml" MESH_FORMATS = (".obj", ".ply", ".stl") diff --git a/genesis/recorders/base_recorder.py b/genesis/recorders/base_recorder.py index cb31724b3..c3a34a694 100644 --- a/genesis/recorders/base_recorder.py +++ b/genesis/recorders/base_recorder.py @@ -33,7 +33,7 @@ class RecorderOptions(Options): buffer_size: int = 0 buffer_full_wait_time: float = 0.1 - def validate(self): + def model_post_init(self, context): """Validate the recorder options values before the recorder is added to the scene.""" if self.hz is not None and self.hz < gs.EPS: gs.raise_exception(f"[{type(self).__name__}] recording hz should be greater than 0.") diff --git a/genesis/recorders/file_writers.py b/genesis/recorders/file_writers.py index c1b13cde2..e5c48c656 100644 --- a/genesis/recorders/file_writers.py +++ b/genesis/recorders/file_writers.py @@ -1,7 +1,7 @@ import csv import os -from pathlib import Path from collections import defaultdict +from pathlib import Path import numpy as np import torch @@ -13,7 +13,6 @@ from .base_recorder import Recorder, RecorderOptions from .recorder_manager import register_recording - IS_PYAV_AVAILABLE = False try: import av @@ -107,10 +106,10 @@ class VideoFileWriterOptions(BaseFileWriterOptions): bitrate: float = 1.0 codec_options: dict[str, str] = Field(default_factory=dict) - def validate(self): - super().validate() + def model_post_init(self, context): + super().model_post_init(context) - if not self.codec in av.codecs_available: + if self.codec not in av.codecs_available: gs.raise_exception(f"[{type(self).__name__}] Codec '{self._options.codec}' not supported.") if not self.filename.endswith(".mp4"): @@ -240,8 +239,8 @@ class CSVFileWriterOptions(BaseFileWriterOptions): header: tuple[str, ...] | None = None save_every_write: bool = False - def validate(self): - super().validate() + def model_post_init(self, context): + super().model_post_init(context) if not self.filename.lower().endswith(".csv"): gs.raise_exception(f"[{type(self).__name__}] CSV output must be a .csv file") @@ -322,8 +321,8 @@ class NPZFileWriterOptions(BaseFileWriterOptions): If True, a counter will be added to the filename and incremented on each reset. """ - def validate(self): - super().validate() + def model_post_init(self, context): + super().model_post_init(context) if not self.filename.lower().endswith(".npz"): gs.raise_exception(f"[{type(self).__name__}] NPZ output must be an .npz file") diff --git a/genesis/recorders/recorder_manager.py b/genesis/recorders/recorder_manager.py index bfc37b4b3..6ad6c2d3b 100644 --- a/genesis/recorders/recorder_manager.py +++ b/genesis/recorders/recorder_manager.py @@ -42,7 +42,6 @@ def add_recorder(self, data_func: Callable[[], Any], rec_options: "RecorderOptio recorder : Recorder The created recorder object. """ - rec_options.validate() recorder_cls = RecorderManager.RECORDER_TYPES_MAP[type(rec_options)] recorder = recorder_cls(self, rec_options, data_func) self._recorders.append(recorder) diff --git a/genesis/sensors/__init__.py b/genesis/sensors/__init__.py index 8c43ba5ca..3ac1e4f2c 100644 --- a/genesis/sensors/__init__.py +++ b/genesis/sensors/__init__.py @@ -2,4 +2,7 @@ from .contact_force import ContactForceSensorOptions as ContactForce from .contact_force import ContactSensorOptions as Contact from .imu import IMUOptions as IMU +from .raycaster import DepthCameraOptions as DepthCamera +from .raycaster import RaycasterOptions as Lidar +from .raycaster import RaycasterOptions as Raycaster from .sensor_manager import SensorManager diff --git a/genesis/sensors/base_sensor.py b/genesis/sensors/base_sensor.py index 3c17944de..5d39f2117 100644 --- a/genesis/sensors/base_sensor.py +++ b/genesis/sensors/base_sensor.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from functools import partial -from typing import TYPE_CHECKING, Generic, Sequence, TypeVar +from typing import TYPE_CHECKING, Generic, Sequence, Type, TypeVar import gstaichi as ti import numpy as np @@ -16,6 +16,7 @@ if TYPE_CHECKING: from genesis.engine.entities.rigid_entity.rigid_link import RigidLink + from genesis.engine.scene import Scene from genesis.recorders.base_recorder import Recorder, RecorderOptions from genesis.utils.ring_buffer import TensorRingBuffer from genesis.vis.rasterizer_context import RasterizerContext @@ -63,9 +64,11 @@ class SensorOptions(Options): update_ground_truth_only: bool = False draw_debug: bool = False - def validate(self, scene): + def validate(self, scene: "Scene"): """ Validate the sensor options values before the sensor is added to the scene. + + Use pydantic's model_post_init() for validation that does not require scene context. """ delay_hz = self.delay / scene._sim.dt if not np.isclose(delay_hz, round(delay_hz), atol=gs.EPS): @@ -75,6 +78,7 @@ def validate(self, scene): ) +# Note: dataclass is used as opposed to pydantic.BaseModel since torch.Tensors are not supported by default @dataclass class SharedSensorMetadata: """ @@ -103,7 +107,9 @@ class Sensor(RBC, Generic[SharedSensorMetadataT]): the shared cache to return the correct data. """ - def __init__(self, sensor_options: "SensorOptions", sensor_idx: int, sensor_manager: "SensorManager"): + def __init__( + self, sensor_options: "SensorOptions", sensor_idx: int, data_cls: Type[tuple], sensor_manager: "SensorManager" + ): self._options: "SensorOptions" = sensor_options self._idx: int = sensor_idx self._manager: "SensorManager" = sensor_manager @@ -114,14 +120,13 @@ def __init__(self, sensor_options: "SensorOptions", sensor_idx: int, sensor_mana self._delay_ts = round(self._options.delay / self._dt) self._cache_slices: list[slice] = [] - self._return_format = self._get_return_format() - is_return_dict = isinstance(self._return_format, dict) - if is_return_dict: - self._return_shapes = self._return_format.values() - self._get_formatted_data = self._get_formatted_data_dict - else: - self._return_shapes = (self._return_format,) - self._get_formatted_data = self._get_formatted_data_tuple + self._return_data_class = data_cls + return_format = self._get_return_format() + assert len(return_format) > 0 + if isinstance(return_format[0], int): + return_format = (return_format,) + self._return_shapes: tuple[tuple[int, ...], ...] = return_format + self._cache_size = 0 for shape in self._return_shapes: data_size = np.prod(shape) @@ -163,17 +168,15 @@ def reset(cls, shared_metadata: SharedSensorMetadataT, envs_idx): """ pass - def _get_return_format(self) -> dict[str, tuple[int, ...]] | tuple[int, ...]: + def _get_return_format(self) -> tuple[int | tuple[int, ...], ...]: """ Get the data format of the read() return value. Returns ------- - return_format : dict | tuple - - If tuple, the final shape of the read() return value. - e.g. (2, 3) means read() will return a tensor of shape (2, 3). - - If dict a dictionary with string keys and tensor values will be returned. - e.g. {"pos": (3,), "quat": (4,)} returns a dict of tensors [0:3] and [3:7] from the cache. + return_format : tuple[tuple[int, ...], ...] + The output shape(s) of the tensor data returned by read(), e.g. (2, 3) means read() will return a single + tensor of shape (2, 3) and ((3,), (3,)) would return two tensors of shape (3,). """ raise NotImplementedError(f"{type(self).__name__} has not implemented `get_return_format()`.") @@ -209,7 +212,7 @@ def _get_cache_dtype(cls) -> torch.dtype: """ raise NotImplementedError(f"{cls.__name__} has not implemented `get_cache_dtype()`.") - def _draw_debug(self, context: "RasterizerContext"): + def _draw_debug(self, context: "RasterizerContext", buffer_updates: dict[str, np.ndarray]): """ Draw debug shapes for the sensor in the scene. """ @@ -306,9 +309,9 @@ def _apply_delay_to_shared_cache( tensor_start += tensor_size - def _get_return_values(self, tensor: torch.Tensor, envs_idx=None) -> list[torch.Tensor]: + def _get_formatted_data(self, tensor: torch.Tensor, envs_idx=None) -> torch.Tensor: """ - Preprares the given tensor into multiple tensors matching `self._return_shapes`. + Returns tensor(s) matching the return format. Note that this method does not clone the data tensor, it should have been cloned by the caller. """ @@ -319,20 +322,13 @@ def _get_return_values(self, tensor: torch.Tensor, envs_idx=None) -> list[torch. for i, shape in enumerate(self._return_shapes): field_data = tensor_chunk[..., self._cache_slices[i]].reshape((len(envs_idx), *shape)) - if self._manager._sim.n_envs == 0: field_data = field_data.squeeze(0) return_values.append(field_data) - return return_values - - def _get_formatted_data_dict(self, tensor: torch.Tensor, envs_idx=None) -> dict[str, torch.Tensor]: - """Returns a dictionary of tensors matching the return format.""" - return dict(zip(self._return_format.keys(), self._get_return_values(tensor, envs_idx))) - - def _get_formatted_data_tuple(self, tensor: torch.Tensor, envs_idx=None) -> torch.Tensor: - """Returns a tensor matching the return format.""" - return self._get_return_values(tensor, envs_idx)[0] + if len(return_values) == 1: + return return_values[0] + return self._return_data_class(*return_values) def _sanitize_envs_idx(self, envs_idx) -> torch.Tensor: return self._manager._sim._scene._sanitize_envs_idx(envs_idx) @@ -359,7 +355,7 @@ class RigidSensorOptionsMixin: pos_offset: Tuple3FType = (0.0, 0.0, 0.0) euler_offset: Tuple3FType = (0.0, 0.0, 0.0) - def validate(self, scene): + def validate(self, scene: "Scene"): super().validate(scene) if self.entity_idx < 0 or self.entity_idx >= len(scene.entities): gs.raise_exception(f"Invalid RigidEntity index {self.entity_idx}.") @@ -436,16 +432,12 @@ class NoisySensorOptionsMixin: The standard deviation of the additive white noise. random_walk : float | tuple[float, ...], optional The standard deviation of the random walk, which acts as accumulated bias drift. - delay : float, optional - The delay in seconds, affecting how outdated the sensor data is when it is read. jitter : float, optional The jitter in seconds modeled as a a random additive delay sampled from a normal distribution. Jitter cannot be greater than delay. `interpolate` should be True when `jitter` is greater than 0. interpolate : bool, optional If True, the sensor data is interpolated between data points for delay + jitter. Otherwise, the sensor data at the closest time step will be used. Default is False. - update_ground_truth_only : bool, optional - If True, the sensor will only update the ground truth data, and not the measured data. """ resolution: float | tuple[float, ...] = 0.0 @@ -455,8 +447,7 @@ class NoisySensorOptionsMixin: jitter: float = 0.0 interpolate: bool = False - def validate(self, scene): - super().validate(scene) + def model_post_init(self, _): if self.jitter > 0 and not self.interpolate: gs.raise_exception(f"{type(self).__name__}: `interpolate` should be True when `jitter` is greater than 0.") if self.jitter > self.delay: diff --git a/genesis/sensors/contact_force.py b/genesis/sensors/contact_force.py index 255819dbb..c1283daaf 100644 --- a/genesis/sensors/contact_force.py +++ b/genesis/sensors/contact_force.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Sequence, Type import gstaichi as ti import numpy as np @@ -76,16 +76,6 @@ class ContactSensorOptions(RigidSensorOptionsMixin, SensorOptions): Parameters ---------- - entity_idx : int - The global entity index of the RigidEntity to which this sensor is attached. - link_idx_local : int, optional - The local index of the RigidLink of the RigidEntity to which this sensor is attached. - delay : float - The delay in seconds before the sensor data is read. - update_ground_truth_only : bool - If True, the sensor will only update the ground truth data, and not the measured data. - draw_debug : bool, optional - If True and the interactive viewer is active, a sphere will be drawn at the sensor's position. debug_sphere_radius : float, optional The radius of the debug sphere. Defaults to 0.05. debug_color : float, optional @@ -106,15 +96,22 @@ class ContactSensorMetadata(SharedSensorMetadata): expanded_links_idx: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_int) -@register_sensor(ContactSensorOptions, ContactSensorMetadata) +@register_sensor(ContactSensorOptions, ContactSensorMetadata, tuple) @ti.data_oriented class ContactSensor(Sensor): """ Sensor that returns bool based on whether associated RigidLink is in contact. """ - def __init__(self, sensor_options: ContactSensorOptions, sensor_idx: int, sensor_manager: "SensorManager"): - super().__init__(sensor_options, sensor_idx, sensor_manager) + def __init__( + self, + sensor_options: ContactSensorOptions, + sensor_idx: int, + data_cls: Type[tuple], + sensor_manager: "SensorManager", + ): + super().__init__(sensor_options, sensor_idx, data_cls, sensor_manager) + self._link: "RigidLink" | None = None self.debug_object: "Mesh" | None = None @@ -162,16 +159,16 @@ def _update_shared_cache( buffered_data.append(shared_ground_truth_cache) cls._apply_delay_to_shared_cache(shared_metadata, shared_cache, buffered_data) - def _draw_debug(self, context: "RasterizerContext"): + def _draw_debug(self, context: "RasterizerContext", buffer_updates: dict[str, np.ndarray]): """ Draw debug sphere when the sensor detects contact. Only draws for first rendered environment. """ - env_idx = context.rendered_envs_idx[0] + env_idx = context.rendered_envs_idx[0] if self._manager._sim.n_envs > 0 else None - pos = self._link.get_pos(envs_idx=env_idx)[0] - is_contact = self.read(envs_idx=env_idx if self._manager._sim.n_envs > 0 else None).item() + pos = self._link.get_pos(envs_idx=env_idx).squeeze(0) + is_contact = self.read(envs_idx=env_idx).item() if is_contact: if self.debug_object is None: @@ -179,7 +176,7 @@ def _draw_debug(self, context: "RasterizerContext"): pos=pos, radius=self._options.debug_sphere_radius, color=self._options.debug_color ) else: - context.update_debug_objects([self.debug_object], trans_to_T(pos).unsqueeze(0)) + buffer_updates.update(context.get_buffer_debug_objects([self.debug_object], [trans_to_T(pos)])) elif self.debug_object is not None: context.clear_debug_object(self.debug_object) self.debug_object = None @@ -194,35 +191,10 @@ class ContactForceSensorOptions(RigidSensorOptionsMixin, NoisySensorOptionsMixin Parameters ---------- - entity_idx : int - The global entity index of the RigidEntity to which this sensor is attached. - link_idx_local : int, optional - The local index of the RigidLink of the RigidEntity to which this sensor is attached. min_force : float | tuple[float, float, float], optional The minimum detectable absolute force per each axis. Values below this will be treated as 0. Default is 0. max_force : float | tuple[float, float, float], optional The maximum output absolute force per each axis. Values above this will be clipped. Default is infinity. - resolution : float | tuple[float, float, float], optional - The measurement resolution of each axis of force (smallest increment of change in the sensor reading). - Default is 0.0, which means no quantization is applied. - bias : float | tuple[float, float, float], optional - The constant additive bias of the sensor. - noise : float | tuple[float, float, float], optional - The standard deviation of the additive white noise. - random_walk : float | tuple[float, float, float], optional - The standard deviation of the random walk, which acts as accumulated bias drift. - delay : float, optional - The delay in seconds, affecting how outdated the sensor data is when it is read. - jitter : float, optional - The jitter in seconds modeled as a a random additive delay sampled from a normal distribution. - Jitter cannot be greater than delay. `interpolate` should be True when `jitter` is greater than 0. - interpolate : bool, optional - If True, the sensor data is interpolated between data points for delay + jitter. - Otherwise, the sensor data at the closest time step will be used. Default is False. - update_ground_truth_only : bool, optional - If True, the sensor will only update the ground truth data, and not the measured data. - draw_debug : bool, optional - If True and the interactive viewer is active, an arrow for the contact force will be drawn. debug_color : float, optional The rgba color of the debug arrow. Defaults to (1.0, 0.0, 1.0, 0.5). debug_scale : float, optional @@ -235,8 +207,7 @@ class ContactForceSensorOptions(RigidSensorOptionsMixin, NoisySensorOptionsMixin debug_color: tuple[float, float, float, float] = (1.0, 0.0, 1.0, 0.5) debug_scale: float = 0.01 - def validate(self, scene): - super().validate(scene) + def model_post_init(self, _): if not ( isinstance(self.min_force, float) or (isinstance(self.min_force, Sequence) and len(self.min_force) == 3) ): @@ -263,9 +234,10 @@ class ContactForceSensorMetadata(RigidSensorMetadataMixin, NoisySensorMetadataMi min_force: torch.Tensor = make_tensor_field((0, 3)) max_force: torch.Tensor = make_tensor_field((0, 3)) + output_forces: torch.Tensor = make_tensor_field((0, 0)) # FIXME: remove once we have contiguous cache slices -@register_sensor(ContactForceSensorOptions, ContactForceSensorMetadata) +@register_sensor(ContactForceSensorOptions, ContactForceSensorMetadata, tuple) @ti.data_oriented class ContactForceSensor( RigidSensorMixin[ContactForceSensorMetadata], @@ -276,8 +248,15 @@ class ContactForceSensor( Sensor that returns the total contact force being applied to the associated RigidLink in its local frame. """ - def __init__(self, sensor_options: ContactForceSensorOptions, sensor_idx: int, sensor_manager: "SensorManager"): - super().__init__(sensor_options, sensor_idx, sensor_manager) + def __init__( + self, + sensor_options: ContactForceSensorOptions, + sensor_idx: int, + data_cls: Type[tuple], + sensor_manager: "SensorManager", + ): + super().__init__(sensor_options, sensor_idx, data_cls, sensor_manager) + self.debug_object: "Mesh" | None = None def build(self): @@ -298,6 +277,13 @@ def build(self): _to_tuple(self._options.max_force, length_per_value=3), ) + if self._shared_metadata.output_forces.numel() == 0: + self._shared_metadata.output_forces.reshape(self._manager._sim._B, 0) + self._shared_metadata.output_forces = concat_with_tensor( + self._shared_metadata.output_forces, + torch.empty((self._manager._sim._B, 3), dtype=gs.tc_float, device=gs.device), + ) + def _get_return_format(self) -> tuple[int, ...]: return (3,) @@ -313,7 +299,11 @@ def _update_shared_ground_truth_cache( all_contacts = shared_metadata.solver.collider.get_contacts(as_tensor=True, to_torch=True) force, link_a, link_b = all_contacts["force"], all_contacts["link_a"], all_contacts["link_b"] - shared_ground_truth_cache.fill_(0.0) + if not shared_ground_truth_cache.is_contiguous(): + shared_metadata.output_forces.fill_(0.0) + else: + shared_ground_truth_cache.fill_(0.0) + if link_a.shape[-1] == 0: return # no contacts @@ -330,8 +320,10 @@ def _update_shared_ground_truth_cache( link_b.contiguous(), links_quat.contiguous(), shared_metadata.links_idx, - shared_ground_truth_cache, + shared_ground_truth_cache if shared_ground_truth_cache.is_contiguous() else shared_metadata.output_forces, ) + if not shared_ground_truth_cache.is_contiguous(): + shared_ground_truth_cache[:] = shared_metadata.output_forces @classmethod def _update_shared_cache( @@ -358,7 +350,7 @@ def _update_shared_cache( shared_cache_per_sensor[torch.abs(shared_cache_per_sensor) < shared_metadata.min_force] = 0.0 cls._quantize_to_resolution(shared_metadata.resolution, shared_cache) - def _draw_debug(self, context: "RasterizerContext"): + def _draw_debug(self, context: "RasterizerContext", buffer_updates: dict[str, np.ndarray]): """ Draw debug arrow representing the contact force. @@ -366,12 +358,12 @@ def _draw_debug(self, context: "RasterizerContext"): """ env_idx = context.rendered_envs_idx[0] - pos = self._link.get_pos(envs_idx=env_idx) - quat = self._link.get_quat(envs_idx=env_idx) + pos = self._link.get_pos(envs_idx=env_idx).squeeze(0) + quat = self._link.get_quat(envs_idx=env_idx).squeeze(0) force = self.read(envs_idx=env_idx if self._manager._sim.n_envs > 0 else None) - vec = tensor_to_array(transform_by_quat(force * self._options.debug_scale, quat)) + vec = tensor_to_array(transform_by_quat(force.squeeze(0) * self._options.debug_scale, quat)) if self.debug_object is not None: context.clear_debug_object(self.debug_object) - self.debug_object = context.draw_debug_arrow(pos=pos[0], vec=vec[0], color=self._options.debug_color) + self.debug_object = context.draw_debug_arrow(pos=pos, vec=vec, color=self._options.debug_color) diff --git a/genesis/sensors/imu.py b/genesis/sensors/imu.py index 9f17cbcc6..52c2793e4 100644 --- a/genesis/sensors/imu.py +++ b/genesis/sensors/imu.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NamedTuple, Type import gstaichi as ti import numpy as np @@ -25,6 +25,7 @@ from .sensor_manager import register_sensor if TYPE_CHECKING: + from genesis.ext.pyrender.mesh import Mesh from genesis.utils.ring_buffer import TensorRingBuffer from genesis.vis.rasterizer_context import RasterizerContext @@ -81,14 +82,6 @@ class IMUOptions(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions Parameters ---------- - entity_idx : int - The global entity index of the RigidEntity to which this IMU sensor is attached. - link_idx_local : int, optional - The local index of the RigidLink of the RigidEntity to which this IMU sensor is attached. - pos_offset : tuple[float, float, float], optional - The positional offset of the IMU sensor from the RigidLink. - euler_offset : tuple[float, float, float], optional - The rotational offset of the IMU sensor from the RigidLink in degrees. acc_resolution : float, optional The measurement resolution of the accelerometer (smallest increment of change in the sensor reading). Default is 0.0, which means no quantization is applied. @@ -115,18 +108,6 @@ class IMUOptions(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions The standard deviation of the white noise for each axis of the gyroscope. gyro_random_walk : tuple[float, float, float] The standard deviation of the bias drift for each axis of the gyroscope. - delay : float, optional - The delay in seconds, affecting how outdated the sensor data is when it is read. - jitter : float, optional - The jitter in seconds modeled as a a random additive delay sampled from a normal distribution. - Jitter cannot be greater than delay. `interpolate` should be True when `jitter` is greater than 0. - interpolate : bool, optional - If True, the sensor data is interpolated between data points for delay + jitter. - Otherwise, the sensor data at the closest time step will be used. Default is False. - update_ground_truth_only : bool, optional - If True, the sensor will only update the ground truth data, and not the measured data. - draw_debug : bool, optional - If True and the interactive viewer is active, an arrow for linear acceleration will be drawn. debug_acc_color : float, optional The rgba color of the debug acceleration arrow. Defaults to (0.0, 1.0, 1.0, 0.5). debug_acc_scale: float, optional @@ -153,8 +134,7 @@ class IMUOptions(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions debug_gyro_color: tuple[float, float, float, float] = (1.0, 1.0, 0.0, 0.5) debug_gyro_scale: float = 0.01 - def validate(self, scene): - super().validate(scene) + def model_post_init(self, _): self._validate_axes_skew(self.acc_axes_skew) self._validate_axes_skew(self.gyro_axes_skew) @@ -177,13 +157,31 @@ class IMUSharedMetadata(RigidSensorMetadataMixin, NoisySensorMetadataMixin, Shar gyro_indices: torch.Tensor = make_tensor_field((0, 0), dtype_factory=lambda: gs.tc_int) -@register_sensor(IMUOptions, IMUSharedMetadata) +class IMUData(NamedTuple): + lin_acc: torch.Tensor + ang_vel: torch.Tensor + + +@register_sensor(IMUOptions, IMUSharedMetadata, IMUData) @ti.data_oriented class IMUSensor( RigidSensorMixin[IMUSharedMetadata], NoisySensorMixin[IMUSharedMetadata], Sensor[IMUSharedMetadata], ): + def __init__( + self, + options: IMUOptions, + shared_metadata: IMUSharedMetadata, + data_cls: Type[IMUData], + manager: "gs.SensorManager", + ): + super().__init__(options, shared_metadata, data_cls, manager) + + self.debug_objects: list["Mesh | None"] = [None, None] + self.quat_offset: torch.Tensor + self.pos_offset: torch.Tensor + @gs.assert_built def set_acc_axes_skew(self, axes_skew: MaybeMatrix3x3Type, envs_idx=None): envs_idx = self._sanitize_envs_idx(envs_idx) @@ -256,17 +254,12 @@ def build(self): expand=(self._manager._sim._B, 2, 3, 3), dim=1, ) - if self._options.draw_debug: - self.debug_objects = [None, None] self.quat_offset = self._shared_metadata.offsets_quat[0, self._idx] - self.pos_offset = self._shared_metadata.offsets_pos[0, self._idx] + self.pos_offset = self._shared_metadata.offsets_pos[0, self._idx].unsqueeze(0) - def _get_return_format(self) -> dict[str, tuple[int, ...]]: - return { - "lin_acc": (3,), - "ang_vel": (3,), - } + def _get_return_format(self) -> tuple[tuple[int, ...], ...]: + return (3,), (3,) @classmethod def _get_cache_dtype(cls) -> torch.dtype: @@ -328,7 +321,7 @@ def _update_shared_cache( cls._add_noise_drift_bias(shared_metadata, shared_cache) cls._quantize_to_resolution(shared_metadata.resolution, shared_cache) - def _draw_debug(self, context: "RasterizerContext"): + def _draw_debug(self, context: "RasterizerContext", buffer_updates: dict[str, np.ndarray]): """ Draw debug arrow for the IMU acceleration. @@ -339,16 +332,19 @@ def _draw_debug(self, context: "RasterizerContext"): quat = self._link.get_quat(envs_idx=env_idx) pos = self._link.get_pos(envs_idx=env_idx) + transform_by_quat(self.pos_offset, quat) - data = self.read(envs_idx=env_idx) - acc_vec = data["lin_acc"] * self._options.debug_acc_scale - gyro_vec = data["ang_vel"] * self._options.debug_gyro_scale + # cannot specify envs_idx for read() when n_envs=0 + data = self.read(envs_idx=env_idx if self._manager._sim.n_envs > 0 else None) + acc_vec = data.lin_acc * self._options.debug_acc_scale + gyro_vec = data.ang_vel * self._options.debug_gyro_scale + # transform from local frame to world frame offset_quat = transform_quat_by_quat(self.quat_offset, quat) - acc_vec = tensor_to_array(transform_by_quat(acc_vec, offset_quat)) - gyro_vec = tensor_to_array(transform_by_quat(gyro_vec, offset_quat)) + acc_vec = tensor_to_array(transform_by_quat(acc_vec, offset_quat)).flatten() + gyro_vec = tensor_to_array(transform_by_quat(gyro_vec, offset_quat)).flatten() for debug_object in self.debug_objects: if debug_object is not None: context.clear_debug_object(debug_object) + self.debug_objects[0] = context.draw_debug_arrow(pos=pos[0], vec=acc_vec, color=self._options.debug_acc_color) self.debug_objects[1] = context.draw_debug_arrow(pos=pos[0], vec=gyro_vec, color=self._options.debug_gyro_color) diff --git a/genesis/sensors/raycaster/__init__.py b/genesis/sensors/raycaster/__init__.py new file mode 100644 index 000000000..fc48bfd79 --- /dev/null +++ b/genesis/sensors/raycaster/__init__.py @@ -0,0 +1,8 @@ +from .depth_camera import DepthCameraOptions +from .patterns import ( + DepthCameraPattern, + GridPattern, + RaycastPattern, + SphericalPattern, +) +from .raycaster import RaycasterOptions diff --git a/genesis/sensors/raycaster/depth_camera.py b/genesis/sensors/raycaster/depth_camera.py new file mode 100644 index 000000000..7cedfdb54 --- /dev/null +++ b/genesis/sensors/raycaster/depth_camera.py @@ -0,0 +1,35 @@ +import torch + +from genesis.sensors.sensor_manager import register_sensor + +from .patterns import DepthCameraPattern +from .raycaster import RaycasterData, RaycasterOptions, RaycasterSensor, RaycasterSharedMetadata + + +class DepthCameraOptions(RaycasterOptions): + """ + Depth camera that uses ray casting to obtain depth images. + + Parameters + ---------- + pattern: DepthCameraPattern + The raycasting pattern configuration for the sensor. + """ + + pattern: DepthCameraPattern + + +@register_sensor(DepthCameraOptions, RaycasterSharedMetadata, RaycasterData) +class DepthCameraSensor(RaycasterSensor): + def read_image(self) -> torch.Tensor: + """ + Read the depth image from the sensor. + + This method uses the hit distances from the underlying RaycasterSensor.read() method and reshapes into image. + + Returns + ------- + torch.Tensor + The depth image with shape (height, width). + """ + return self.read().distances.reshape(self._options.pattern.height, self._options.pattern.width) diff --git a/genesis/sensors/raycaster/patterns.py b/genesis/sensors/raycaster/patterns.py new file mode 100644 index 000000000..a4cf1629f --- /dev/null +++ b/genesis/sensors/raycaster/patterns.py @@ -0,0 +1,300 @@ +import math +from dataclasses import dataclass +from typing import Sequence + +import torch + +import genesis as gs +from genesis.utils.geom import spherical_to_cartesian + + +@dataclass +class RaycastPattern: + """ + Base class for raycast patterns. + """ + + def __init__(self): + self._return_shape: tuple[int, ...] = self._get_return_shape() + self._ray_dirs: torch.Tensor = torch.empty((*self._return_shape, 3), dtype=gs.tc_float, device=gs.device) + self._ray_starts: torch.Tensor = torch.empty((*self._return_shape, 3), dtype=gs.tc_float, device=gs.device) + self.compute_ray_dirs() + self.compute_ray_starts() + + def _get_return_shape(self) -> tuple[int, ...]: + """Get the shape of the ray vectors, e.g. (n_scan_lines, n_points_per_line) or (n_rays,)""" + raise NotImplementedError(f"{type(self).__name__} must implement `get_return_shape()`.") + + def compute_ray_dirs(self): + """ + Update ray_dirs, the local direction vectors of the rays. + """ + raise NotImplementedError(f"{type(self).__name__} must implement `compute_ray_dirs()`.") + + def compute_ray_starts(self): + """ + Update ray_starts, the local start positions of the rays. + + As a default, all rays will start at the local origin. + """ + self._ray_starts.fill_(0.0) + + @property + def return_shape(self) -> tuple[int, ...]: + return self._return_shape + + @property + def ray_dirs(self) -> torch.Tensor: + return self._ray_dirs + + @property + def ray_starts(self) -> torch.Tensor: + return self._ray_starts + + +# ============================== Generic Patterns ============================== + + +class GridPattern(RaycastPattern): + """ + Configuration for grid-based ray casting. + + Defines a 2D grid of rays in the sensor coordinate system. + + Parameters + ---------- + resolution : float + Grid spacing in meters. + size : tuple[float, float] + Grid dimensions (length, width) in meters. + direction : tuple[float, float, float] + Ray direction vector. + """ + + def __init__( + self, + resolution: float = 0.1, + size: tuple[float, float] = (2.0, 2.0), + direction: tuple[float, float, float] = (0.0, 0.0, -1.0), + ): + if resolution < 1e-3: + gs.raise_exception(f"Resolution should be at least 1e-3 (1mm). Got `{resolution}`.") + self.coords = [ + torch.arange(-size / 2, size / 2 + gs.EPS, resolution, dtype=gs.tc_float, device=gs.device) for size in size + ] + self.direction = torch.tensor(direction, dtype=gs.tc_float, device=gs.device) + + super().__init__() + + def _get_return_shape(self) -> tuple[int, ...]: + return (len(self.coords[0]), len(self.coords[1])) + + def compute_ray_dirs(self): + self._ray_dirs[:] = self.direction.expand((*self._return_shape, 3)) + + def compute_ray_starts(self): + grid_x, grid_y = torch.meshgrid(*self.coords, indexing="xy") + self._ray_starts[..., 0] = grid_x + self._ray_starts[..., 1] = grid_y + self._ray_starts[..., 2] = 0.0 + + +def _generate_uniform_angles( + n_points: tuple[int, int], + fov: tuple[float | tuple[float, float] | None, float | tuple[float, float] | None], + res: tuple[float | None, float | None], + angles: tuple[Sequence[float] | None, Sequence[float] | None], +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Helper function to generate uniform angles given various formats (n and fov, res and fov, or angles). + """ + return_angles = [] + + for n_points_i, fov_i, res_i, angles_i in zip(n_points, fov, res, angles): + if angles_i is None: + assert fov_i is not None, "FOV should be provided if angles not given." + + if res_i is not None: + if isinstance(fov_i, Sequence): + f_min, f_max = fov_i + else: + f_max = fov_i / 2.0 + f_min = -f_max + n_points_i = math.ceil((f_max - f_min) / res_i) + 1 + + assert n_points_i is not None + + if isinstance(fov_i, Sequence): + f_min, f_max = fov_i + fov_size = f_max - f_min + else: + f_max = fov_i / 2.0 + f_min = -f_max + fov_size = fov_i + + assert fov_size <= 360.0 + gs.EPS, "FOV should not be larger than a full rotation." + + # Avoid duplicate angle at 0/360 degrees + if fov_size >= 360.0 - gs.EPS: + f_max -= fov_size / (n_points_i - 1) * 0.5 + + angles_i = torch.linspace(f_min, f_max, n_points_i, dtype=gs.tc_float, device=gs.device) + else: + angles_i = torch.tensor(angles_i, dtype=gs.tc_float, device=gs.device) + + return_angles.append(torch.deg2rad(angles_i)) + + return tuple(return_angles) + + +class SphericalPattern(RaycastPattern): + """ + Configuration for spherical ray pattern. + + Either specify: + - (`n_points`, `fov`) for uniform spacing by count. + - (`angular_resolution`, `fov`) for uniform spacing by resolution. + - `angles` for custom angles. + + + Parameters + ---------- + fov: tuple[float | tuple[float, float], float | tuple[float, float]] + Field of view in degrees for horizontal and vertical directions. Defaults to (360.0, 30.0). + If a single float is provided, the FOV is centered around 0 degrees. + If a tuple is provided, it specifies the (min, max) angles. + n_points: tuple[int, int] + Number of horizontal/azimuth and vertical/elevation scan lines. Defaults to (64, 128). + angular_resolution: tuple[float, float], optional + Horizontal and vertical angular resolution in degrees. Overrides n_points if provided. + angles: tuple[Sequence[float], Sequence[float]], optional + Array of horizontal/vertical angles. Overrides the other options if provided. + """ + + def __init__( + self, + fov: tuple[float | tuple[float, float], float | tuple[float, float]] = (360.0, 60.0), + n_points: tuple[int, int] = (128, 64), + angular_resolution: tuple[float | None, float | None] = (None, None), + angles: tuple[Sequence[float] | None, Sequence[float] | None] = (None, None), + ): + for fov_i in fov: + if (isinstance(fov_i, float) and (fov_i < 0 or fov_i > 360.0 + gs.EPS)) or ( + isinstance(fov_i, tuple) and (fov_i[1] - fov_i[0] > 360.0 + gs.EPS) + ): + gs.raise_exception(f"[{type(self).__name__}] FOV should be between 0 and 360. Got: {fov}.") + + self.angles = _generate_uniform_angles(n_points, fov, angular_resolution, angles) + + super().__init__() + + def _get_return_shape(self) -> tuple[int, ...]: + return tuple(len(a) for a in self.angles) + + def compute_ray_dirs(self): + meshgrid = torch.meshgrid(*self.angles, indexing="ij") + self._ray_dirs[:] = spherical_to_cartesian(*meshgrid) + + +# ============================== Camera Patterns ============================== + + +def _compute_focal_lengths( + width: int, height: int, fov_horizontal: float | None, fov_vertical: float | None +) -> tuple[float, float]: + """ + Helper function to compute focal lengths given image dimensions and FOV. + """ + if fov_horizontal is not None and fov_vertical is None: + fh_rad = math.radians(fov_horizontal) + fv_rad = 2.0 * math.atan((height / width) * math.tan(fh_rad / 2.0)) + elif fov_vertical is not None and fov_horizontal is None: + fv_rad = math.radians(fov_vertical) + fh_rad = 2.0 * math.atan((width / height) * math.tan(fv_rad / 2.0)) + else: + fh_rad = math.radians(fov_horizontal) + fv_rad = math.radians(fov_vertical) + + fx = width / (2.0 * math.tan(fh_rad / 2.0)) + fy = height / (2.0 * math.tan(fv_rad / 2.0)) + + return fx, fy + + +class DepthCameraPattern(RaycastPattern): + """Configuration for pinhole depth camera ray casting. + + Parameters + ---------- + width : int + Image width in pixels. + height : int + Image height in pixels. + fx : float | None + Focal length in x direction (pixels). Computed from FOV if None. + fy : float | None + Focal length in y direction (pixels). Computed from FOV if None. + cx : float | None + Principal point x coordinate (pixels). Defaults to image center if None. + cy : float | None + Principal point y coordinate (pixels). Defaults to image center if None. + fov_horizontal : float + Horizontal field of view in degrees. Used to compute fx if fx is None. + fov_vertical : float | None + Vertical field of view in degrees. Used to compute fy if fy is None. + """ + + def __init__( + self, + width: int = 128, + height: int = 96, + fx: float | None = None, + fy: float | None = None, + cx: float | None = None, + cy: float | None = None, + fov_horizontal: float = 90.0, + fov_vertical: float | None = None, + ): + + if width <= 0 or height <= 0: + gs.raise_exception(f"[{type(self).__name__}] Image dimensions must be positive. Got: ({width}, {height})") + + self.width = width + self.height = height + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + self.fov_horizontal = fov_horizontal + self.fov_vertical = fov_vertical + + super().__init__() + + def _get_return_shape(self) -> tuple[int, ...]: + return (self.height, self.width) + + def compute_ray_dirs(self): + W, H = int(self.width), int(self.height) + + fx, fy, cx, cy = self.fx, self.fy, self.cx, self.cy + if fx is None or fy is None: + fx, fy = _compute_focal_lengths(W, H, self.fov_horizontal, self.fov_vertical) + if cx is None: + cx = W * 0.5 + if cy is None: + cy = H * 0.5 + + u = torch.arange(0, W, dtype=gs.tc_float, device=gs.device) + 0.5 + v = torch.arange(0, H, dtype=gs.tc_float, device=gs.device) + 0.5 + uu, vv = torch.meshgrid(u, v, indexing="xy") + + # standard camera frame coordinates + x_c = (uu - cx) / fx + y_c = (vv - cy) / fy + z_c = torch.ones_like(x_c) + + # transform to robotics camera frame + dirs = torch.stack([z_c, -x_c, -y_c], dim=-1) + dirs /= torch.linalg.norm(dirs, dim=-1, keepdim=True) + + self._ray_dirs[:] = dirs diff --git a/genesis/sensors/raycaster/raycaster.py b/genesis/sensors/raycaster/raycaster.py new file mode 100644 index 000000000..b8e6df31b --- /dev/null +++ b/genesis/sensors/raycaster/raycaster.py @@ -0,0 +1,605 @@ +import math +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, NamedTuple, Type + +import gstaichi as ti +import numpy as np +import torch +from pydantic import Field + +import genesis as gs +import genesis.engine.solvers.rigid.rigid_solver_decomp as rigid_solver_decomp +import genesis.utils.array_class as array_class +from genesis.engine.bvh import AABB, LBVH +from genesis.sensors.sensor_manager import register_sensor +from genesis.utils.geom import ( + ti_normalize, + ti_transform_by_quat, + ti_transform_by_trans_quat, + trans_to_T, + transform_by_quat, + transform_by_trans_quat, +) +from genesis.utils.misc import concat_with_tensor, make_tensor_field, ti_to_torch +from genesis.vis.rasterizer_context import RasterizerContext + +from ..base_sensor import ( + RigidSensorMetadataMixin, + RigidSensorMixin, + RigidSensorOptionsMixin, + Sensor, + SensorOptions, + SharedSensorMetadata, +) +from .patterns import RaycastPattern + +if TYPE_CHECKING: + from genesis.ext.pyrender.mesh import Mesh + from genesis.utils.ring_buffer import TensorRingBuffer + + +DEBUG_COLORS = ( + (1.0, 0.2, 0.2, 1.0), + (0.2, 1.0, 0.2, 1.0), + (0.2, 0.6, 1.0, 1.0), + (1.0, 1.0, 0.2, 1.0), +) +# A constant stack size should be sufficient for BVH traversal. +# https://madmann91.github.io/2021/01/06/bvhs-part-2.html +# https://forums.developer.nvidia.com/t/thinking-parallel-part-ii-tree-traversal-on-the-gpu/148342 +STACK_SIZE = ti.static(64) + + +@ti.func +def ray_triangle_intersection(ray_start, ray_dir, v0, v1, v2): + """ + Möller-Trumbore ray-triangle intersection. + + Returns: vec4(t, u, v, hit) where hit=1.0 if intersection found, 0.0 otherwise + """ + result = ti.Vector.zero(gs.ti_float, 4) + + edge1 = v1 - v0 + edge2 = v2 - v0 + + # Begin calculating determinant - also used to calculate u parameter + h = ray_dir.cross(edge2) + a = edge1.dot(h) + + # Check all conditions in sequence without early returns + valid = True + + t = gs.ti_float(0.0) + u = gs.ti_float(0.0) + v = gs.ti_float(0.0) + f = gs.ti_float(0.0) + s = ti.Vector.zero(gs.ti_float, 3) + q = ti.Vector.zero(gs.ti_float, 3) + + # If determinant is near zero, ray lies in plane of triangle + if ti.abs(a) < gs.EPS: + valid = False + + if valid: + f = 1.0 / a + s = ray_start - v0 + u = f * s.dot(h) + + if u < 0.0 or u > 1.0: + valid = False + + if valid: + q = s.cross(edge1) + v = f * ray_dir.dot(q) + + if v < 0.0 or u + v > 1.0: + valid = False + + if valid: + # At this stage we can compute t to find out where the intersection point is on the line + t = f * edge2.dot(q) + + # Ray intersection + if t <= gs.EPS: + valid = False + + if valid: + result = ti.math.vec4(t, u, v, 1.0) + + return result + + +@ti.func +def ray_aabb_intersection(ray_start, ray_dir, aabb_min, aabb_max): + """ + Fast ray-AABB intersection test. + Returns the t value of intersection, or -1.0 if no intersection. + """ + result = -1.0 + + # Use the slab method for ray-AABB intersection + sign = ti.select(ray_dir >= 0.0, 1.0, -1.0) + ray_dir = sign * ti.max(ti.abs(ray_dir), gs.EPS) + inv_dir = 1.0 / ray_dir + + t1 = (aabb_min - ray_start) * inv_dir + t2 = (aabb_max - ray_start) * inv_dir + + tmin = ti.min(t1, t2) + tmax = ti.max(t1, t2) + + t_near = ti.max(tmin.x, tmin.y, tmin.z, 0.0) + t_far = ti.min(tmax.x, tmax.y, tmax.z) + + # Check if ray intersects AABB + if t_near <= t_far: + result = t_near + + return result + + +@ti.kernel +def kernel_update_aabbs( + map_faces: ti.template(), + free_verts_state: array_class.VertsState, + fixed_verts_state: array_class.VertsState, + verts_info: array_class.VertsInfo, + faces_info: array_class.FacesInfo, + aabb_state: array_class.AABBState, +): + for i_b, i_f_ in ti.ndrange(free_verts_state.pos.shape[1], map_faces.shape[0]): + i_f = map_faces[i_f_] + aabb_state.aabbs[i_b, i_f].min.fill(ti.math.inf) + aabb_state.aabbs[i_b, i_f].max.fill(-ti.math.inf) + + for i in ti.static(range(3)): + i_v = verts_info.verts_state_idx[faces_info.verts_idx[i_f][i]] + if verts_info.is_fixed[faces_info.verts_idx[i_f][i]]: + pos_v = fixed_verts_state.pos[i_v] + aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) + aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) + else: + pos_v = free_verts_state.pos[i_v, i_b] + aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) + aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) + + +@ti.kernel +def kernel_cast_rays( + map_faces: ti.template(), + fixed_verts_state: array_class.VertsState, + free_verts_state: array_class.VertsState, + verts_info: array_class.VertsInfo, + faces_info: array_class.FacesInfo, + bvh_nodes: ti.template(), + bvh_morton_codes: ti.template(), # maps sorted leaves to original triangle indices + links_pos: ti.types.ndarray(ndim=3), # [n_env, n_sensors, 3] + links_quat: ti.types.ndarray(ndim=3), # [n_env, n_sensors, 4] + ray_starts: ti.types.ndarray(ndim=2), # [n_points, 3] + ray_directions: ti.types.ndarray(ndim=2), # [n_points, 3] + max_ranges: ti.types.ndarray(ndim=1), # [n_sensors] + no_hit_values: ti.types.ndarray(ndim=1), # [n_sensors] + is_world_frame: ti.types.ndarray(ndim=1), # [n_sensors] + points_to_sensor_idx: ti.types.ndarray(ndim=1), # [n_points] + sensor_cache_offsets: ti.types.ndarray(ndim=1), # [n_sensors] - cache start index for each sensor + sensor_point_offsets: ti.types.ndarray(ndim=1), # [n_sensors] - point start index for each sensor + sensor_point_counts: ti.types.ndarray(ndim=1), # [n_sensors] - number of points for each sensor + output_hits: ti.types.ndarray(ndim=2), # [n_env, total_cache_size] +): + """ + Taichi kernel for ray casting, accelerated by a Bounding Volume Hierarchy (BVH). + + The result `output_hits` will be a 2D array of shape (n_env, total_cache_size) where in the second dimension, + each sensor's data is stored as [sensor_points (n_points * 3), sensor_ranges (n_points)]. + """ + + n_triangles = map_faces.shape[0] + n_points = ray_starts.shape[0] + # batch, point + for i_b, i_p in ti.ndrange(output_hits.shape[0], n_points): + i_s = points_to_sensor_idx[i_p] + + # --- 1. Setup Ray --- + link_pos = ti.math.vec3(links_pos[i_b, i_s, 0], links_pos[i_b, i_s, 1], links_pos[i_b, i_s, 2]) + link_quat = ti.math.vec4( + links_quat[i_b, i_s, 0], links_quat[i_b, i_s, 1], links_quat[i_b, i_s, 2], links_quat[i_b, i_s, 3] + ) + + ray_start_local = ti.math.vec3(ray_starts[i_p, 0], ray_starts[i_p, 1], ray_starts[i_p, 2]) + ray_start_world = ti_transform_by_trans_quat(ray_start_local, link_pos, link_quat) + + ray_dir_local = ti.math.vec3(ray_directions[i_p, 0], ray_directions[i_p, 1], ray_directions[i_p, 2]) + ray_direction_world = ti_normalize(ti_transform_by_quat(ray_dir_local, link_quat)) + + # --- 2. BVH Traversal --- + max_range = max_ranges[i_s] + hit_face = -1 + + # Stack for non-recursive traversal + node_stack = ti.Vector.zero(ti.i32, STACK_SIZE) + node_stack[0] = 0 # Start traversal at the root node (index 0) + stack_idx = 1 + + while stack_idx > 0: + stack_idx -= 1 + node_idx = node_stack[stack_idx] + + node = bvh_nodes[i_b, node_idx] + + # Check if ray hits the node's bounding box + aabb_t = ray_aabb_intersection(ray_start_world, ray_direction_world, node.bound.min, node.bound.max) + + if aabb_t >= 0.0 and aabb_t < max_range: + if node.left == -1: # is leaf node + # A leaf node corresponds to one of the sorted triangles. Find the original triangle index. + sorted_leaf_idx = node_idx - (n_triangles - 1) + original_tri_idx = ti.cast(bvh_morton_codes[0, sorted_leaf_idx][1], ti.i32) + + i_f = map_faces[original_tri_idx] + is_fixed = verts_info.is_fixed[faces_info.verts_idx[i_f][0]] + + v0 = ti.Vector.zero(gs.ti_float, 3) + v1 = ti.Vector.zero(gs.ti_float, 3) + v2 = ti.Vector.zero(gs.ti_float, 3) + + if is_fixed: + v0 = fixed_verts_state.pos[verts_info.verts_state_idx[faces_info.verts_idx[i_f][0]]] + v1 = fixed_verts_state.pos[verts_info.verts_state_idx[faces_info.verts_idx[i_f][1]]] + v2 = fixed_verts_state.pos[verts_info.verts_state_idx[faces_info.verts_idx[i_f][2]]] + + else: + v0 = free_verts_state.pos[verts_info.verts_state_idx[faces_info.verts_idx[i_f][0]], i_b] + v1 = free_verts_state.pos[verts_info.verts_state_idx[faces_info.verts_idx[i_f][1]], i_b] + v2 = free_verts_state.pos[verts_info.verts_state_idx[faces_info.verts_idx[i_f][2]], i_b] + + # Perform the expensive ray-triangle intersection test + hit_result = ray_triangle_intersection(ray_start_world, ray_direction_world, v0, v1, v2) + + if hit_result.w > 0.0 and hit_result.x < max_range and hit_result.x >= 0.0: + max_range = hit_result.x + hit_face = i_f + # hit_u, hit_v could be stored here if needed + + else: # It's an INTERNAL node + # Push children onto the stack for further traversal + # Make sure stack doesn't overflow + if stack_idx < ti.static(STACK_SIZE - 2): + node_stack[stack_idx] = node.left + node_stack[stack_idx + 1] = node.right + stack_idx += 2 + + # --- 3. Process Hit Result --- + # The format of output_hits is: [sensor1 points][sensor1 ranges][sensor2 points][sensor2 ranges]... + i_p_sensor = i_p - sensor_point_offsets[i_s] + i_p_offset = sensor_cache_offsets[i_s] # cumulative cache offset for this sensor + n_points_in_sensor = sensor_point_counts[i_s] # number of points in this sensor + + i_p_dist = i_p_offset + n_points_in_sensor * 3 + i_p_sensor # index for distance output + + if hit_face >= 0: + dist = max_range + # Store distance at: cache_offset + (num_points_in_sensor * 3) + point_idx_in_sensor + output_hits[i_b, i_p_dist] = dist + + if is_world_frame[i_s]: + hit_point = ray_start_world + dist * ray_direction_world + + # Store points at: cache_offset + point_idx_in_sensor * 3 + output_hits[i_b, i_p_offset + i_p_sensor * 3 + 0] = hit_point.x + output_hits[i_b, i_p_offset + i_p_sensor * 3 + 1] = hit_point.y + output_hits[i_b, i_p_offset + i_p_sensor * 3 + 2] = hit_point.z + else: + # Local frame output along provided local ray direction + hit_point = dist * ti_normalize( + ti.math.vec3(ray_directions[i_p, 0], ray_directions[i_p, 1], ray_directions[i_p, 2]) + ) + output_hits[i_b, i_p_offset + i_p_sensor * 3 + 0] = hit_point.x + output_hits[i_b, i_p_offset + i_p_sensor * 3 + 1] = hit_point.y + output_hits[i_b, i_p_offset + i_p_sensor * 3 + 2] = hit_point.z + + else: + # No hit + output_hits[i_b, i_p_offset + i_p_sensor * 3 + 0] = 0.0 + output_hits[i_b, i_p_offset + i_p_sensor * 3 + 1] = 0.0 + output_hits[i_b, i_p_offset + i_p_sensor * 3 + 2] = 0.0 + output_hits[i_b, i_p_dist] = no_hit_values[i_s] + + +class RaycasterOptions(RigidSensorOptionsMixin, SensorOptions): + """ + Raycaster sensor that performs ray casting to get distance measurements and point clouds. + + Parameters + ---------- + pattern: RaycastPatternOptions + The raycasting pattern for the sensor. + min_range : float, optional + The minimum sensing range in meters. Defaults to 0.0. + max_range : float, optional + The maximum sensing range in meters. Defaults to 20.0. + no_hit_value : float, optional + The value to return for no hit. Defaults to max_range if not specified. + return_world_frame : bool, optional + Whether to return points in the world frame. Defaults to False (local frame). + debug_sphere_radius: float, optional + The radius of each debug sphere drawn in the scene. Defaults to 0.02. + debug_ray_start_color: float, optional + The color of each debug ray start sphere drawn in the scene. Defaults to (0.5, 0.5, 1.0, 1.0). + debug_ray_hit_color: float, optional + The color of each debug ray hit point sphere drawn in the scene. Defaults to (1.0, 0.5, 0.5, 1.0). + """ + + pattern: RaycastPattern + min_range: float = 0.0 + max_range: float = 20.0 + no_hit_value: float = Field(default_factory=lambda data: data["max_range"]) + return_world_frame: bool = False + + debug_sphere_radius: float = 0.02 + debug_ray_start_color: tuple[float, float, float, float] = (0.5, 0.5, 1.0, 1.0) + debug_ray_hit_color: tuple[float, float, float, float] = (1.0, 0.5, 0.5, 1.0) + + def model_post_init(self, _): + if self.min_range < 0.0: + gs.raise_exception(f"[{type(self).__name__}] min_range should be non-negative. Got: {self.min_range}.") + if self.max_range <= self.min_range: + gs.raise_exception( + f"[{type(self).__name__}] max_range {self.max_range} should be greater than min_range {self.min_range}." + ) + + +@dataclass +class RaycasterSharedMetadata(RigidSensorMetadataMixin, SharedSensorMetadata): + bvh: LBVH | None = None + aabb: AABB | None = None + needs_aabb_update: bool = False + map_faces: Any | None = None + n_faces: int = 0 + + sensors_ray_start_idx: list[int] = field(default_factory=list) + total_n_rays: int = 0 + + min_ranges: torch.Tensor = make_tensor_field((0,)) + max_ranges: torch.Tensor = make_tensor_field((0,)) + no_hit_values: torch.Tensor = make_tensor_field((0,)) + return_world_frame: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_bool) + + patterns: list[RaycastPattern] = field(default_factory=list) + ray_dirs: torch.Tensor = make_tensor_field((0, 3)) + ray_starts: torch.Tensor = make_tensor_field((0, 3)) + ray_starts_world: torch.Tensor = make_tensor_field((0, 3)) + ray_dirs_world: torch.Tensor = make_tensor_field((0, 3)) + + points_to_sensor_idx: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_int) + sensor_cache_offsets: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_int) + sensor_point_offsets: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_int) + sensor_point_counts: torch.Tensor = make_tensor_field((0,), dtype_factory=lambda: gs.tc_int) + output_hits: torch.Tensor = make_tensor_field((0, 0)) # FIXME: remove once we have contiguous cache slices + + +class RaycasterData(NamedTuple): + points: torch.Tensor + distances: torch.Tensor + + +@register_sensor(RaycasterOptions, RaycasterSharedMetadata, RaycasterData) +@ti.data_oriented +class RaycasterSensor(RigidSensorMixin, Sensor): + + def __init__( + self, + options: RaycasterOptions, + shared_metadata: RaycasterSharedMetadata, + data_cls: Type[RaycasterData], + manager: "gs.SensorManager", + ): + super().__init__(options, shared_metadata, data_cls, manager) + self.debug_objects: list["Mesh | None"] = [] + self.ray_starts: torch.Tensor = torch.empty((0, 3), device=gs.device, dtype=gs.tc_float) + + @classmethod + def _build_bvh(cls, shared_metadata: RaycasterSharedMetadata): + n_faces = shared_metadata.solver.faces_info.geom_idx.shape[0] + torch_map_faces = torch.arange(n_faces, dtype=torch.int32, device=gs.device) + + shared_metadata.map_faces = ti.field(ti.i32, (n_faces)) + shared_metadata.map_faces.from_torch(torch_map_faces) + shared_metadata.n_faces = n_faces + + shared_metadata.aabb = AABB(n_batches=shared_metadata.solver.free_verts_state.pos.shape[1], n_aabbs=n_faces) + + rigid_solver_decomp.kernel_update_all_verts( + geoms_state=shared_metadata.solver.geoms_state, + verts_info=shared_metadata.solver.verts_info, + free_verts_state=shared_metadata.solver.free_verts_state, + fixed_verts_state=shared_metadata.solver.fixed_verts_state, + ) + + kernel_update_aabbs( + map_faces=shared_metadata.map_faces, + free_verts_state=shared_metadata.solver.free_verts_state, + fixed_verts_state=shared_metadata.solver.fixed_verts_state, + verts_info=shared_metadata.solver.verts_info, + faces_info=shared_metadata.solver.faces_info, + aabb_state=shared_metadata.aabb, + ) + shared_metadata.bvh = LBVH(shared_metadata.aabb) + shared_metadata.bvh.build() + + def build(self): + super().build() # set shared metadata from RigidSensorMixin + + # first lidar sensor initialization: build aabb and bvh + if self._shared_metadata.bvh is None: + geom_is_fixed = ti_to_torch(self._shared_metadata.solver.geoms_info.is_fixed) + self._shared_metadata.needs_aabb_update = bool((~geom_is_fixed).any().item()) + + self._shared_metadata.output_hits = torch.empty( + (self._manager._sim._B, 0), device=gs.device, dtype=gs.tc_float + ) + self._shared_metadata.sensor_cache_offsets = concat_with_tensor( + self._shared_metadata.sensor_cache_offsets, 0 + ) + self._build_bvh(self._shared_metadata) + + self._shared_metadata.patterns.append(self._options.pattern) + pos_offset = self._shared_metadata.offsets_pos[0, -1, :] # all envs have same offset on build + quat_offset = self._shared_metadata.offsets_quat[0, -1, :] + + ray_starts = self._options.pattern.ray_starts.reshape(-1, 3) + self.ray_starts = transform_by_trans_quat(ray_starts, pos_offset, quat_offset) + self._shared_metadata.ray_starts = torch.cat([self._shared_metadata.ray_starts, self.ray_starts]) + + ray_dirs = self._options.pattern.ray_dirs.reshape(-1, 3) + ray_dirs = transform_by_quat(ray_dirs, quat_offset) + self._shared_metadata.ray_dirs = torch.cat([self._shared_metadata.ray_dirs, ray_dirs]) + + num_rays = math.prod(self._options.pattern.return_shape) + self._shared_metadata.sensors_ray_start_idx.append(self._shared_metadata.total_n_rays) + + # These fields are used to properly index into the big cache tensor in kernel_cast_rays + self._shared_metadata.output_hits = concat_with_tensor( + self._shared_metadata.output_hits, + torch.empty((self._manager._sim._B, self._cache_size), device=gs.device, dtype=gs.tc_float), + dim=-1, + ) + self._shared_metadata.sensor_cache_offsets = concat_with_tensor( + self._shared_metadata.sensor_cache_offsets, self._cache_size + ) + self._shared_metadata.sensor_point_offsets = concat_with_tensor( + self._shared_metadata.sensor_point_offsets, self._shared_metadata.total_n_rays + ) + self._shared_metadata.sensor_point_counts = concat_with_tensor( + self._shared_metadata.sensor_point_counts, num_rays + ) + self._shared_metadata.total_n_rays += num_rays + + self._shared_metadata.points_to_sensor_idx = concat_with_tensor( + self._shared_metadata.points_to_sensor_idx, + [self._idx] * num_rays, + flatten=True, + ) + self._shared_metadata.return_world_frame = concat_with_tensor( + self._shared_metadata.return_world_frame, self._options.return_world_frame + ) + self._shared_metadata.min_ranges = concat_with_tensor(self._shared_metadata.min_ranges, self._options.min_range) + self._shared_metadata.max_ranges = concat_with_tensor(self._shared_metadata.max_ranges, self._options.max_range) + no_hit_value = self._options.no_hit_value if self._options.no_hit_value is not None else self._options.max_range + self._shared_metadata.no_hit_values = concat_with_tensor(self._shared_metadata.no_hit_values, no_hit_value) + + @classmethod + def reset(cls, shared_metadata: RaycasterSharedMetadata, envs_idx): + super().reset(shared_metadata, envs_idx) + cls._build_bvh(shared_metadata) + + def _get_return_format(self) -> tuple[tuple[int, ...], ...]: + shape = self._options.pattern.return_shape + return (*shape, 3), shape + + @classmethod + def _get_cache_dtype(cls) -> torch.dtype: + return gs.tc_float + + @classmethod + def _update_shared_ground_truth_cache( + cls, shared_metadata: RaycasterSharedMetadata, shared_ground_truth_cache: torch.Tensor + ): + if not shared_metadata.needs_aabb_update: + rigid_solver_decomp.kernel_update_all_verts( + geoms_state=shared_metadata.solver.geoms_state, + verts_info=shared_metadata.solver.verts_info, + free_verts_state=shared_metadata.solver.free_verts_state, + fixed_verts_state=shared_metadata.solver.fixed_verts_state, + ) + + kernel_update_aabbs( + map_faces=shared_metadata.map_faces, + free_verts_state=shared_metadata.solver.free_verts_state, + fixed_verts_state=shared_metadata.solver.fixed_verts_state, + verts_info=shared_metadata.solver.verts_info, + faces_info=shared_metadata.solver.faces_info, + aabb_state=shared_metadata.aabb, + ) + + links_pos = shared_metadata.solver.get_links_pos(links_idx=shared_metadata.links_idx) + links_quat = shared_metadata.solver.get_links_quat(links_idx=shared_metadata.links_idx) + if shared_metadata.solver.n_envs == 0: + links_pos = links_pos.unsqueeze(0) + links_quat = links_quat.unsqueeze(0) + + kernel_cast_rays( + map_faces=shared_metadata.map_faces, + fixed_verts_state=shared_metadata.solver.fixed_verts_state, + free_verts_state=shared_metadata.solver.free_verts_state, + verts_info=shared_metadata.solver.verts_info, + faces_info=shared_metadata.solver.faces_info, + bvh_nodes=shared_metadata.bvh.nodes, + bvh_morton_codes=shared_metadata.bvh.morton_codes, + links_pos=links_pos.contiguous(), + links_quat=links_quat.contiguous(), + ray_starts=shared_metadata.ray_starts, + ray_directions=shared_metadata.ray_dirs, + max_ranges=shared_metadata.max_ranges, + no_hit_values=shared_metadata.no_hit_values, + is_world_frame=shared_metadata.return_world_frame, + points_to_sensor_idx=shared_metadata.points_to_sensor_idx, + sensor_cache_offsets=shared_metadata.sensor_cache_offsets, + sensor_point_offsets=shared_metadata.sensor_point_offsets, + sensor_point_counts=shared_metadata.sensor_point_counts, + output_hits=( + shared_ground_truth_cache if shared_ground_truth_cache.is_contiguous() else shared_metadata.output_hits + ), + ) + if not shared_ground_truth_cache.is_contiguous(): + shared_ground_truth_cache[:] = shared_metadata.output_hits + + @classmethod + def _update_shared_cache( + cls, + shared_metadata: RaycasterSharedMetadata, + shared_ground_truth_cache: torch.Tensor, + shared_cache: torch.Tensor, + buffered_data: "TensorRingBuffer", + ): + buffered_data.append(shared_ground_truth_cache) + cls._apply_delay_to_shared_cache(shared_metadata, shared_cache, buffered_data) + + def _draw_debug(self, context: "RasterizerContext", buffer_updates: dict[str, np.ndarray]): + """ + Draw hit points as spheres in the scene. + + Only draws for first rendered environment. + """ + env_idx = context.rendered_envs_idx[0] + + points = self.read(envs_idx=env_idx if self._manager._sim.n_envs > 0 else None).points.reshape(-1, 3) + + pos = self._link.get_pos(envs_idx=env_idx) + quat = self._link.get_quat(envs_idx=env_idx) + + ray_starts = transform_by_trans_quat(self.ray_starts, pos, quat) + + if not self._options.return_world_frame: + points = transform_by_trans_quat(points + self.ray_starts, pos, quat) + + if not self.debug_objects: + for ray_start in ray_starts: + self.debug_objects.append( + context.draw_debug_sphere( + ray_start, + radius=self._options.debug_sphere_radius, + color=self._options.debug_ray_start_color, + ) + ) + for point in points: + self.debug_objects.append( + context.draw_debug_sphere( + point, + radius=self._options.debug_sphere_radius, + color=self._options.debug_ray_hit_color, + ) + ) + else: + buffer_updates.update( + context.get_buffer_debug_objects(self.debug_objects, trans_to_T(torch.cat([ray_starts, points], dim=0))) + ) diff --git a/genesis/sensors/sensor_manager.py b/genesis/sensors/sensor_manager.py index 627545599..834f99419 100644 --- a/genesis/sensors/sensor_manager.py +++ b/genesis/sensors/sensor_manager.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Type +import numpy as np import torch from genesis.utils.ring_buffer import TensorRingBuffer @@ -11,7 +12,7 @@ class SensorManager: - SENSOR_TYPES_MAP: dict[Type["SensorOptions"], tuple[Type["Sensor"], Type["SharedSensorMetadata"]]] = {} + SENSOR_TYPES_MAP: dict[Type["SensorOptions"], tuple[Type["Sensor"], Type["SharedSensorMetadata"], Type[tuple]]] = {} def __init__(self, sim): self._sim = sim @@ -27,11 +28,11 @@ def __init__(self, sim): def create_sensor(self, sensor_options: "SensorOptions") -> "Sensor": sensor_options.validate(self._sim.scene) - sensor_cls, metadata_cls = SensorManager.SENSOR_TYPES_MAP[type(sensor_options)] + sensor_cls, metadata_cls, data_cls = SensorManager.SENSOR_TYPES_MAP[type(sensor_options)] self._sensors_by_type.setdefault(sensor_cls, []) if sensor_cls not in self._sensors_metadata: self._sensors_metadata[sensor_cls] = metadata_cls() - sensor = sensor_cls(sensor_options, len(self._sensors_by_type[sensor_cls]), self) + sensor = sensor_cls(sensor_options, len(self._sensors_by_type[sensor_cls]), data_cls, self) self._sensors_by_type[sensor_cls].append(sensor) return sensor @@ -86,10 +87,10 @@ def step(self): self._buffered_data[dtype][:, cache_slice], ) - def draw_debug(self, context: "RasterizerContext"): + def draw_debug(self, context: "RasterizerContext", buffer_updates: dict[str, np.ndarray]): for sensor in self.sensors: if sensor._options.draw_debug: - sensor._draw_debug(context) + sensor._draw_debug(context, buffer_updates) def reset(self, envs_idx=None): envs_idx = self._sim._scene._sanitize_envs_idx(envs_idx) @@ -119,9 +120,11 @@ def sensors(self): return tuple([sensor for sensor_list in self._sensors_by_type.values() for sensor in sensor_list]) -def register_sensor(options_cls: Type["SensorOptions"], metadata_cls: Type["SharedSensorMetadata"]): +def register_sensor( + options_cls: Type["SensorOptions"], metadata_cls: Type["SharedSensorMetadata"], data_cls: Type[tuple] +): def _impl(sensor_cls: Type["Sensor"]): - SensorManager.SENSOR_TYPES_MAP[options_cls] = sensor_cls, metadata_cls + SensorManager.SENSOR_TYPES_MAP[options_cls] = sensor_cls, metadata_cls, data_cls return sensor_cls return _impl diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index 37cf008a4..4e27faf91 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -1,14 +1,12 @@ import dataclasses -import inspect -import os from functools import partial -from typing import Any, Callable, Type, cast +import os import gstaichi as ti -from gstaichi.lang._fast_caching import FIELD_METADATA_CACHE_VALUE, args_hasher +from gstaichi.lang._fast_caching import FIELD_METADATA_CACHE_VALUE +import numpy as np import genesis as gs -import numpy as np # as a temporary solution, we get is_ndarray from os's environment variable use_ndarray = os.environ.get("GS_USE_NDARRAY", "0") == "1" @@ -1851,14 +1849,21 @@ def __init__(self): return ClassEdgesInfo() -# =========================================== FreeVertsState =========================================== +# =========================================== VertsState =========================================== @dataclasses.dataclass -class StructFreeVertsState: +class StructVertsState: pos: V_ANNOTATION +@ti.data_oriented +class ClassVertsState: + def __init__(self, kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def get_free_verts_state(solver): shape = solver._batch_shape(solver.n_free_verts_) kwargs = { @@ -1866,24 +1871,9 @@ def get_free_verts_state(solver): } if use_ndarray: - return StructFreeVertsState(**kwargs) + return StructVertsState(**kwargs) else: - - @ti.data_oriented - class ClassFreeVertsState: - def __init__(self): - for k, v in kwargs.items(): - setattr(self, k, v) - - return ClassFreeVertsState() - - -# =========================================== FixedVertsState =========================================== - - -@dataclasses.dataclass -class StructFixedVertsState: - pos: V_ANNOTATION + return ClassVertsState(kwargs) def get_fixed_verts_state(solver): @@ -1893,16 +1883,9 @@ def get_fixed_verts_state(solver): } if use_ndarray: - return StructFixedVertsState(**kwargs) + return StructVertsState(**kwargs) else: - - @ti.data_oriented - class ClassFixedVertsState: - def __init__(self): - for k, v in kwargs.items(): - setattr(self, k, v) - - return ClassFixedVertsState() + return ClassVertsState(kwargs) # =========================================== VvertsInfo =========================================== @@ -2251,8 +2234,7 @@ def __init__(self, solver): LinksInfo = ti.template() if not use_ndarray else StructLinksInfo JointsInfo = ti.template() if not use_ndarray else StructJointsInfo JointsState = ti.template() if not use_ndarray else StructJointsState -FreeVertsState = ti.template() if not use_ndarray else StructFreeVertsState -FixedVertsState = ti.template() if not use_ndarray else StructFixedVertsState +VertsState = ti.template() if not use_ndarray else StructVertsState VertsInfo = ti.template() if not use_ndarray else StructVertsInfo EdgesInfo = ti.template() if not use_ndarray else StructEdgesInfo FacesInfo = ti.template() if not use_ndarray else StructFacesInfo @@ -2273,3 +2255,4 @@ def __init__(self, solver): SDFInfo = ti.template() if not use_ndarray else StructSDFInfo ContactIslandState = ti.template() if not use_ndarray else StructContactIslandState DiffContactInput = ti.template() if not use_ndarray else StructDiffContactInput +AABBState = ti.template() diff --git a/genesis/utils/geom.py b/genesis/utils/geom.py index 3c46b2c30..d651c968c 100644 --- a/genesis/utils/geom.py +++ b/genesis/utils/geom.py @@ -1584,6 +1584,31 @@ def transform_inertia_by_T(inertia_tensor, T, mass): return R @ inertia_tensor @ R.T + translation_inertia +def spherical_to_cartesian(theta: torch.Tensor, phi: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Convert spherical coordinates to Cartesian coordinates. + + Parameters + ---------- + theta : torch.Tensor + Horizontal angles in radians. + phi : torch.Tensor + Vertical angles in radians. + + Returns + ------- + vectors : torch.Tensor + Vectors in cartesian coordinates as tensor of shape (..., 3). + """ + cos_phi = torch.cos(phi) + + x = torch.cos(theta) * cos_phi # forward + y = torch.sin(theta) * cos_phi # left + z = torch.sin(phi) # up + + return torch.stack([x, y, z], dim=-1) + + def slerp(q0, q1, t): """ Perform spherical linear interpolation between two quaternions. diff --git a/genesis/utils/misc.py b/genesis/utils/misc.py index 7a4909a3c..928004c1d 100644 --- a/genesis/utils/misc.py +++ b/genesis/utils/misc.py @@ -3,29 +3,28 @@ import functools import logging import math +import os import platform import random -import types -import shutil import sys -import os +import types import weakref from collections import OrderedDict from dataclasses import dataclass, field from typing import Any, Callable, NoReturn, Optional, Type -import numpy as np import cpuinfo +import gstaichi as ti +import numpy as np import psutil import pyglet import torch -import gstaichi as ti from gstaichi.lang.util import is_ti_template, to_pytorch_type, to_numpy_type from gstaichi._kernels import tensor_to_ext_arr, matrix_to_ext_arr, ndarray_to_ext_arr, ndarray_matrix_to_ext_arr from gstaichi.lang import impl -from gstaichi.types import primitive_types from gstaichi.lang.exception import handle_exception_from_cpp +from gstaichi.types import primitive_types import genesis as gs from genesis.constants import backend as gs_backend @@ -312,7 +311,9 @@ def is_approx_multiple(a, b, tol=1e-7): return abs(a % b) < tol or abs(b - (a % b)) < tol -def concat_with_tensor(tensor: torch.Tensor, value, expand: tuple[int, ...] | None = None, dim: int = 0): +def concat_with_tensor( + tensor: torch.Tensor, value, expand: tuple[int, ...] | None = None, dim: int = 0, flatten: bool = False +): """Helper method to concatenate a value (not necessarily a tensor) with a tensor.""" if not isinstance(value, torch.Tensor): value = torch.tensor([value], dtype=tensor.dtype, device=tensor.device) @@ -320,6 +321,8 @@ def concat_with_tensor(tensor: torch.Tensor, value, expand: tuple[int, ...] | No value = value.expand(*expand) if dim < 0: dim = tensor.ndim + dim + if flatten: + value = value.flatten() assert ( 0 <= dim < tensor.ndim and tensor.ndim == value.ndim @@ -590,14 +593,14 @@ def _ti_to_python( is_out_of_bounds = not (0 <= mask < _ti_data_shape[i]) elif isinstance(mask, torch.Tensor): if not mask.ndim <= 1: - gs.raise_exception(f"Expecting 1D tensor for masks.") + gs.raise_exception("Expecting 1D tensor for masks.") # Resort on post-mortem analysis for bounds check because runtime would be to costly is_out_of_bounds = None else: # np.ndarray, list, tuple, range try: mask_start, mask_end = min(mask), max(mask) except ValueError: - gs.raise_exception(f"Expecting 1D tensor for masks.") + gs.raise_exception("Expecting 1D tensor for masks.") is_out_of_bounds = not (0 <= mask_start <= mask_end < _ti_data_shape[i]) if is_out_of_bounds: gs.raise_exception("Masks are out-of-range.") diff --git a/genesis/vis/rasterizer_context.py b/genesis/vis/rasterizer_context.py index c3fa3a856..6ede33367 100644 --- a/genesis/vis/rasterizer_context.py +++ b/genesis/vis/rasterizer_context.py @@ -803,8 +803,8 @@ def update_fem(self, buffer_updates): update_data = self._scene.reorder_vertices(node, vertices) buffer_updates[self._scene.get_buffer_id(node, "pos")] = update_data - def update_sensors(self): - self.sim._sensor_manager.draw_debug(self) + def update_sensors(self, buffer_updates): + self.sim._sensor_manager.draw_debug(self, buffer_updates) def on_lights(self): for light in self.lights: @@ -907,15 +907,20 @@ def draw_debug_points(self, poss, colors=(1.0, 0.0, 0.0, 0.5)): self.add_external_node(node) return node - def update_debug_objects(self, objs, poses): - poses = tensor_to_array(poses) + def get_buffer_debug_objects(self, objs, poses): buffer_updates = {} for obj, pose in zip(objs, poses): + pose = tensor_to_array(pose) + if pose.ndim != 3: + pose = np.tile(pose[np.newaxis], (max(self.scene.n_envs, 1), 1, 1)) obj._bounds = None obj.primitives[0].poses = pose node = self.external_nodes[obj.name] - buffer_updates[self._scene.get_buffer_id(node, "model")] = poses.swapaxes(-2, -1) - self.jit.update_buffer(buffer_updates) + buffer_updates[self._scene.get_buffer_id(node, "model")] = pose.transpose((0, 2, 1)) + return buffer_updates + + def update_debug_objects(self, objs, poses): + self.jit.update_buffer(self.get_buffer_debug_objects(objs, poses)) def clear_debug_object(self, obj): self.clear_external_node(obj) @@ -949,7 +954,7 @@ def update(self, force_render: bool = False): self.update_sph(self.buffer) self.update_pbd(self.buffer) self.update_fem(self.buffer) - self.update_sensors() + self.update_sensors(self.buffer) def add_light(self, light): # light direction is light pose's -z frame diff --git a/genesis/vis/viewer.py b/genesis/vis/viewer.py index a15f80b45..4eba3af82 100644 --- a/genesis/vis/viewer.py +++ b/genesis/vis/viewer.py @@ -142,7 +142,7 @@ def setup_camera(self): pose = gu.pos_lookat_up_to_T(self._camera_init_pos, self._camera_init_lookat, self._camera_up) self._camera_node = self.context.add_node(pyrender.PerspectiveCamera(yfov=yfov), pose=pose) - def update(self, auto_refresh=None): + def update(self, auto_refresh=None, force=False): if self._followed_entity is not None: self.update_following() @@ -150,7 +150,7 @@ def update(self, auto_refresh=None): with self.lock: # Update context - self.context.update() + self.context.update(force) # Refresh viewer by default if and if this is possible if auto_refresh is None: diff --git a/genesis/vis/visualizer.py b/genesis/vis/visualizer.py index 1f2c14b48..0f88295af 100644 --- a/genesis/vis/visualizer.py +++ b/genesis/vis/visualizer.py @@ -42,7 +42,7 @@ def __init__(self, scene, show_viewer, vis_options, viewer_options, renderer_opt self._context = RasterizerContext(vis_options) try: - screen_height, _, screen_scale = gs.utils.try_get_display_size() + screen_height, _screen_width, screen_scale = gs.utils.try_get_display_size() self._has_display = True except Exception as e: if show_viewer: diff --git a/pyproject.toml b/pyproject.toml index 024e493b8..89026c6ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ requires-python = ">=3.10,<3.14" dependencies = [ "psutil", "gstaichi==2.5.0", - "pydantic>=2.7.1", + "pydantic>=2.11.0", "numpy>=1.26.4", "trimesh", "scipy>=1.14", @@ -79,7 +79,7 @@ dev = [ # - 16.0 is causing pytest-xdist to crash in case of failure or skipped tests "pytest-rerunfailures<16.0", "syrupy", - "huggingface_hub", + "huggingface_hub[hf_xet]", "wandb", "ipython", # * Mujoco 3.3.6 made contact islands an opt-out option instead of opt-in, @@ -154,7 +154,7 @@ addopts = [ "--random-order-seed=0", # "--max-worker-restart=0", "--durations=0", - "--durations-min=40.0", + "--durations-min=100.0", "-m not (benchmarks or examples)", ] filterwarnings = [ diff --git a/tests/conftest.py b/tests/conftest.py index e6cca7734..4d721159d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,8 +44,8 @@ if not has_display and has_egl: # It is necessary to configure pyglet in headless mode if necessary before importing Genesis. # Note that environment variables are used instead of global options to ease option propagation to subprocesses. + pyglet.options["headless"] = True os.environ["PYGLET_HEADLESS"] = "1" - os.environ["GS_VIEWER_ALLOW_OFFSCREEN"] = "1" IS_INTERACTIVE_VIEWER_AVAILABLE = has_display or has_egl @@ -92,6 +92,10 @@ def pytest_cmdline_main(config: pytest.Config) -> None: if show_viewer: config.option.numprocesses = 0 + # Force headless rendering if available and the interactive viewer is disabled + if not show_viewer and has_egl: + pyglet.options["headless"] = True + # Disable low-level parallelization if distributed framework is enabled. # FIXME: It should be set to `max(int(physical_core_count / num_workers), 1)`, but 'num_workers' may be unknown. if not is_benchmarks and config.option.numprocesses != 0: diff --git a/tests/test_examples.py b/tests/test_examples.py index 9dbc0229b..ecaf27fc9 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,8 +1,6 @@ import os import sys import subprocess -import shutil -import shlex from pathlib import Path import pytest @@ -27,8 +25,6 @@ "multi_gpu.py", "fem_cube_linked_with_arm.py", # FIXME: segfault on exit "single_franka_batch_render.py", # FIXME: segfault on exit - "imu_franka.py", # FIXME: broken - "contact_force_go2.py", # FIXME: broken "cut_dragon.py", # FIXME: Only supported on Linux } diff --git a/tests/test_integration.py b/tests/test_integration.py index 1f8dc5f1a..cf25a121e 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -314,4 +314,4 @@ def test_franka_panda_grasp_fem_entity(primitive_type, show_viewer): franka.control_dofs_force(np.array([-1.0, -1.0]), fingers_dof) scene.step() box_pos_post = obj.get_state().pos.mean(dim=-2) - assert_allclose(box_pos_f, box_pos_post, atol=2e-4) + assert_allclose(box_pos_f, box_pos_post, atol=5e-4) diff --git a/tests/test_render.py b/tests/test_render.py index f69fff0f3..dad31f8bb 100644 --- a/tests/test_render.py +++ b/tests/test_render.py @@ -33,12 +33,11 @@ class RENDERER_TYPE(enum.IntEnum): def renderer(renderer_type): if renderer_type == RENDERER_TYPE.RASTERIZER: return gs.renderers.Rasterizer() - elif renderer_type == RENDERER_TYPE.RAYTRACER: + if renderer_type == RENDERER_TYPE.RAYTRACER: return gs.renderers.RayTracer() - else: - return gs.renderers.BatchRenderer( - use_rasterizer=renderer_type == RENDERER_TYPE.BATCHRENDER_RASTERIZER, - ) + return gs.renderers.BatchRenderer( + use_rasterizer=renderer_type == RENDERER_TYPE.BATCHRENDER_RASTERIZER, + ) @pytest.fixture(scope="function") @@ -114,15 +113,33 @@ def test_render_api(show_viewer, renderer_type, renderer): "renderer_type", [RENDERER_TYPE.RASTERIZER, RENDERER_TYPE.BATCHRENDER_RASTERIZER, RENDERER_TYPE.BATCHRENDER_RAYTRACER], ) -def test_deterministic(tmp_path, show_viewer, tol): +def test_deterministic(tmp_path, renderer_type, renderer, show_viewer, tol): scene = gs.Scene( vis_options=gs.options.VisOptions( # rendered_envs_idx=(0, 1, 2), env_separate_rigid=False, ), + renderer=renderer, show_viewer=show_viewer, show_FPS=False, ) + if renderer_type in (RENDERER_TYPE.BATCHRENDER_RASTERIZER, RENDERER_TYPE.BATCHRENDER_RAYTRACER): + scene.add_light( + pos=(0.0, 0.0, 1.5), + dir=(1.0, 1.0, -2.0), + directional=True, + castshadow=True, + cutoff=45.0, + intensity=0.5, + ) + scene.add_light( + pos=(4.0, -4.0, 4.0), + dir=(-1.0, 1.0, -1.0), + directional=False, + castshadow=True, + cutoff=45.0, + intensity=0.5, + ) plane = scene.add_entity( morph=gs.morphs.Plane(), surface=gs.surfaces.Aluminium( @@ -264,7 +281,7 @@ def test_deterministic(tmp_path, show_viewer, tol): rgb_array, *_ = cam.render( rgb=True, depth=False, segmentation=False, colorize_seg=False, normal=False, force_render=True ) - assert np.max(np.std(rgb_array.reshape((-1, 3)), axis=0)) > 10.0 + assert tensor_to_array(rgb_array).reshape((-1, 3)).astype(np.float32).std(axis=0).max() > 10.0 robots_rgb_arrays.append(rgb_array) steps_rgb_arrays.append(robots_rgb_arrays) @@ -296,6 +313,8 @@ def test_render_api_advanced(tmp_path, n_envs, show_viewer, png_snapshot, render shadow=(renderer_type != RENDERER_TYPE.RASTERIZER), ), renderer=renderer, + show_viewer=False, + show_FPS=False, ) plane = scene.add_entity( morph=gs.morphs.Plane(), @@ -522,6 +541,7 @@ def test_segmentation_map(segmentation_level, particle_mode, renderer_type, rend ), renderer=renderer, show_viewer=False, + show_FPS=False, ) robot = scene.add_entity( @@ -589,9 +609,10 @@ def test_segmentation_map(segmentation_level, particle_mode, renderer_type, rend assert_array_equal(np.sort(np.unique(seg.flat)), np.arange(0, seg_num)) -@pytest.mark.parametrize("renderer_type", [RENDERER_TYPE.RASTERIZER]) +@pytest.mark.required @pytest.mark.parametrize("n_envs", [0, 2]) -def test_camera_follow_entity(n_envs, show_viewer): +@pytest.mark.parametrize("renderer_type", [RENDERER_TYPE.RASTERIZER]) +def test_camera_follow_entity(n_envs, renderer, show_viewer): CAM_RES = (100, 100) scene = gs.Scene( @@ -599,7 +620,9 @@ def test_camera_follow_entity(n_envs, show_viewer): rendered_envs_idx=[1] if n_envs else None, segmentation_level="entity", ), + renderer=renderer, show_viewer=False, + show_FPS=False, ) for pos in ((1.0, 0.0, 0.0), (-1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, -1.0, 0.0)): obj = scene.add_entity( @@ -775,18 +798,19 @@ def test_point_cloud(renderer_type, renderer, show_viewer): @pytest.mark.required @pytest.mark.parametrize("renderer_type", [RENDERER_TYPE.RASTERIZER]) -def test_draw_debug(show_viewer): +def test_draw_debug(renderer, show_viewer): if "GS_DISABLE_OFFSCREEN_MARKERS" in os.environ: pytest.skip("Offscreen rendering of markers is forcibly disabled. Skipping...") scene = gs.Scene( + renderer=renderer, show_viewer=show_viewer, + show_FPS=False, ) cam = scene.add_camera( pos=(3.5, 0.5, 2.5), lookat=(0.0, 0.0, 0.5), up=(0.0, 0.0, 1.0), - fov=40, res=(640, 640), GUI=show_viewer, ) @@ -842,10 +866,124 @@ def test_draw_debug(show_viewer): assert_allclose(np.std(rgb_array.reshape((-1, 3)), axis=0), 0.0, tol=gs.EPS) +@pytest.mark.required +@pytest.mark.parametrize("n_envs", [0, 2]) +@pytest.mark.parametrize("renderer_type", [RENDERER_TYPE.RASTERIZER]) +@pytest.mark.skipif(not IS_INTERACTIVE_VIEWER_AVAILABLE, reason="Interactive viewer not supported on this platform.") +def test_sensors_draw_debug(n_envs, renderer, png_snapshot): + """Test that sensor debug drawing works correctly and renders visible debug elements.""" + scene = gs.Scene( + viewer_options=gs.options.ViewerOptions( + camera_pos=(2.0, 2.0, 2.0), + camera_lookat=(0.0, 0.0, 0.2), + # Force screen-independent low-quality resolution when running unit tests for consistency + res=(640, 480), + # Enable running in background thread if supported by the platform + run_in_thread=(sys.platform == "linux"), + ), + profiling_options=gs.options.ProfilingOptions( + show_FPS=False, + ), + renderer=renderer, + show_viewer=True, + ) + + scene.add_entity(gs.morphs.Plane()) + + floating_box = scene.add_entity( + gs.morphs.Box( + size=(0.1, 0.1, 0.1), + pos=(0.0, 0.0, 0.5), + fixed=True, + ) + ) + scene.add_sensor( + gs.sensors.IMU( + entity_idx=floating_box.idx, + pos_offset=(0.0, 0.0, 0.1), + draw_debug=True, + ) + ) + + ground_box = scene.add_entity( + gs.morphs.Box( + size=(0.4, 0.2, 0.1), + pos=(-0.25, 0.0, 0.05), + ) + ) + scene.add_sensor( + gs.sensors.Contact( + entity_idx=ground_box.idx, + draw_debug=True, + debug_sphere_radius=0.08, + debug_color=(1.0, 0.5, 1.0, 1.0), + ) + ) + scene.add_sensor( + gs.sensors.ContactForce( + entity_idx=ground_box.idx, + draw_debug=True, + debug_scale=0.01, + ) + ) + scene.add_sensor( + gs.sensors.Raycaster( + pattern=gs.sensors.raycaster.GridPattern( + resolution=0.2, + size=(0.4, 0.4), + direction=(0.0, 0.0, -1.0), + ), + entity_idx=floating_box.idx, + pos_offset=(0.2, 0.0, -0.1), + return_world_frame=True, + draw_debug=True, + ) + ) + scene.add_sensor( + gs.sensors.Raycaster( + pattern=gs.sensors.raycaster.SphericalPattern( + n_points=(6, 6), + fov=(60.0, (-120.0, -60.0)), + ), + entity_idx=floating_box.idx, + pos_offset=(0.0, 0.5, 0.0), + return_world_frame=False, + draw_debug=True, + debug_sphere_radius=0.01, + debug_ray_start_color=(1.0, 1.0, 0.0, 1.0), + debug_ray_hit_color=(0.5, 1.0, 1.0, 1.0), + ) + ) + + scene.build(n_envs=n_envs) + + for _ in range(5): + scene.step() + + pyrender_viewer = scene.visualizer.viewer._pyrender_viewer + assert pyrender_viewer.is_active + rgb_arr, *_ = pyrender_viewer.render_offscreen( + pyrender_viewer._camera_node, + pyrender_viewer._renderer, + rgb=True, + depth=False, + seg=False, + normal=False, + ) + + if sys.platform == "darwin": + glinfo = pyrender_viewer.context.get_info() + renderer = glinfo.get_renderer() + if renderer == "Apple Software Renderer": + pytest.xfail("Tile ground colors are altered on Apple Software Renderer.") + + assert rgb_array_to_png_bytes(rgb_arr) == png_snapshot + + @pytest.mark.required @pytest.mark.parametrize("renderer_type", [RENDERER_TYPE.RASTERIZER]) @pytest.mark.skipif(not IS_INTERACTIVE_VIEWER_AVAILABLE, reason="Interactive viewer not supported on this platform.") -def test_interactive_viewer_key_press(tmp_path, monkeypatch, png_snapshot, show_viewer): +def test_interactive_viewer_key_press(tmp_path, monkeypatch, renderer, png_snapshot, show_viewer): IMAGE_FILENAME = tmp_path / "screenshot.png" # Mock 'get_save_filename' to avoid poping up an interactive dialog @@ -878,7 +1016,9 @@ def on_key_press(self, symbol: int, modifiers: int): # 'EventLoop.run() must be called from the same thread that imports pyglet.app'. run_in_thread=(sys.platform == "linux"), ), + renderer=renderer, show_viewer=True, + show_FPS=False, ) cube = scene.add_entity( gs.morphs.Box( @@ -922,9 +1062,9 @@ def on_key_press(self, symbol: int, modifiers: int): @pytest.mark.parametrize( "renderer_type", - [RENDERER_TYPE.RASTERIZER], + [RENDERER_TYPE.RASTERIZER, RENDERER_TYPE.BATCHRENDER_RASTERIZER, RENDERER_TYPE.BATCHRENDER_RAYTRACER], ) -def test_render_planes(tmp_path, png_snapshot, renderer): +def test_render_planes(tmp_path, png_snapshot, renderer_type, renderer): CAM_RES = (256, 256) for test_idx, (plane_size, tile_size) in enumerate( @@ -936,7 +1076,26 @@ def test_render_planes(tmp_path, png_snapshot, renderer): ): scene = gs.Scene( renderer=renderer, + show_viewer=False, + show_FPS=False, ) + if renderer_type in (RENDERER_TYPE.BATCHRENDER_RASTERIZER, RENDERER_TYPE.BATCHRENDER_RAYTRACER): + scene.add_light( + pos=(0.0, 0.0, 1.5), + dir=(1.0, 1.0, -2.0), + directional=True, + castshadow=True, + cutoff=45.0, + intensity=0.5, + ) + scene.add_light( + pos=(4.0, -4.0, 4.0), + dir=(-1.0, 1.0, -1.0), + directional=False, + castshadow=True, + cutoff=45.0, + intensity=0.5, + ) plane = scene.add_entity( gs.morphs.Plane(plane_size=plane_size, tile_size=tile_size), ) @@ -1002,6 +1161,7 @@ def test_batch_deformable_render(tmp_path, monkeypatch, png_snapshot): visualize_sph_boundary=True, ), show_viewer=True, + show_FPS=False, ) plane = scene.add_entity( diff --git a/tests/test_sensors.py b/tests/test_sensors.py index 1816d6b3a..96892852f 100644 --- a/tests/test_sensors.py +++ b/tests/test_sensors.py @@ -4,7 +4,7 @@ import genesis as gs -from .utils import assert_allclose, assert_array_equal +from .utils import assert_allclose, assert_array_equal, rgb_array_to_png_bytes def expand_batch_dim(values: tuple[float, ...], n_envs: int) -> tuple[float, ...] | np.ndarray: @@ -84,12 +84,12 @@ def test_imu_sensor(show_viewer, tol, n_envs): # IMU should calculate "classical linear acceleration" using the local frame without accounting for gravity # acc_classical_lin_z = - theta_dot ** 2 - cos(theta) * g - assert_allclose(imu_biased.read()["lin_acc"], expand_batch_dim(BIAS, n_envs), tol=tol) - assert_allclose(imu_biased.read()["ang_vel"], expand_batch_dim(BIAS, n_envs), tol=tol) - assert_allclose(imu_delayed.read()["lin_acc"], 0.0, tol=tol) - assert_allclose(imu_delayed.read()["ang_vel"], 0.0, tol=tol) - assert_allclose(imu_noisy.read()["lin_acc"], 0.0, tol=1e-1) - assert_allclose(imu_noisy.read()["ang_vel"], 0.0, tol=1e-1) + assert_allclose(imu_biased.read().lin_acc, expand_batch_dim(BIAS, n_envs), tol=tol) + assert_allclose(imu_biased.read().ang_vel, expand_batch_dim(BIAS, n_envs), tol=tol) + assert_allclose(imu_delayed.read().lin_acc, 0.0, tol=tol) + assert_allclose(imu_delayed.read().ang_vel, 0.0, tol=tol) + assert_allclose(imu_noisy.read().lin_acc, 0.0, tol=1e-1) + assert_allclose(imu_noisy.read().ang_vel, 0.0, tol=1e-1) # shift COM to induce angular velocity com_shift = torch.tensor([[0.1, 0.1, 0.1]]) @@ -109,21 +109,21 @@ def test_imu_sensor(show_viewer, tol, n_envs): for _ in range(DELAY_STEPS): scene.step() - assert_array_equal(imu_delayed.read()["lin_acc"], true_imu_delayed_reading["lin_acc"]) - assert_array_equal(imu_delayed.read()["ang_vel"], true_imu_delayed_reading["ang_vel"]) + assert_array_equal(imu_delayed.read().lin_acc, true_imu_delayed_reading.lin_acc) + assert_array_equal(imu_delayed.read().ang_vel, true_imu_delayed_reading.ang_vel) # let box collide with ground for _ in range(20): scene.step() - assert_array_equal(imu_biased.read_ground_truth()["lin_acc"], imu_delayed.read_ground_truth()["lin_acc"]) - assert_array_equal(imu_biased.read_ground_truth()["ang_vel"], imu_delayed.read_ground_truth()["ang_vel"]) + assert_array_equal(imu_biased.read_ground_truth().lin_acc, imu_delayed.read_ground_truth().lin_acc) + assert_array_equal(imu_biased.read_ground_truth().ang_vel, imu_delayed.read_ground_truth().ang_vel) with np.testing.assert_raises(AssertionError, msg="Angular velocity should not be zero due to COM shift"): - assert_allclose(imu_biased.read_ground_truth()["ang_vel"], 0.0, tol=tol) + assert_allclose(imu_biased.read_ground_truth().ang_vel, 0.0, tol=tol) with np.testing.assert_raises(AssertionError, msg="Delayed data should not be equal to the ground truth data"): - assert_array_equal(imu_delayed.read()["lin_acc"] - imu_delayed.read_ground_truth()["lin_acc"], 0.0) + assert_array_equal(imu_delayed.read().lin_acc - imu_delayed.read_ground_truth().lin_acc, 0.0) zero_com_shift = torch.tensor([[0.0, 0.0, 0.0]]) box.set_COM_shift(zero_com_shift.expand((n_envs, 1, 3)) if n_envs > 0 else zero_com_shift) @@ -132,22 +132,22 @@ def test_imu_sensor(show_viewer, tol, n_envs): for _ in range(80): scene.step() - assert_allclose(imu_skewed.read()["lin_acc"], -GRAVITY, tol=5e-6) + assert_allclose(imu_skewed.read().lin_acc, -GRAVITY, tol=5e-6) assert_allclose( - imu_biased.read()["lin_acc"], + imu_biased.read().lin_acc, expand_batch_dim((BIAS[0], BIAS[1], BIAS[2] - GRAVITY), n_envs), tol=5e-6, ) - assert_allclose(imu_biased.read()["ang_vel"], expand_batch_dim(BIAS, n_envs), tol=1e-5) + assert_allclose(imu_biased.read().ang_vel, expand_batch_dim(BIAS, n_envs), tol=1e-5) scene.reset() - assert_allclose(imu_biased.read()["lin_acc"], 0.0, tol=gs.EPS) # biased, but cache hasn't been updated yet - assert_allclose(imu_delayed.read()["lin_acc"], 0.0, tol=gs.EPS) - assert_allclose(imu_noisy.read()["ang_vel"], 0.0, tol=gs.EPS) + assert_allclose(imu_biased.read().lin_acc, 0.0, tol=gs.EPS) # biased, but cache hasn't been updated yet + assert_allclose(imu_delayed.read().lin_acc, 0.0, tol=gs.EPS) + assert_allclose(imu_noisy.read().ang_vel, 0.0, tol=gs.EPS) scene.step() - assert_allclose(imu_biased.read()["lin_acc"], expand_batch_dim(BIAS, n_envs), tol=tol) + assert_allclose(imu_biased.read().lin_acc, expand_batch_dim(BIAS, n_envs), tol=tol) @pytest.mark.required @@ -248,3 +248,102 @@ def test_rigid_tactile_sensors_gravity_force(show_viewer, tol, n_envs): tol=NOISE * 10, err_msg="ContactForceSensor should read bias and noise and -gravity (normal) force clipped by max_force.", ) + + +@pytest.mark.required +@pytest.mark.parametrize("n_envs", [0, 2]) +def test_raycaster_hits(show_viewer, tol, n_envs): + """Test if the Raycaster sensor with GridPattern rays pointing to ground returns the correct distance.""" + EXPECTED_DISTANCE = 1.2 + NUM_RAYS_XY = 3 + BOX_HEIGHT = 0.2 + SPHERE_POS = (4.0, 0.0, 1.0) + RAYCAST_GRID_SIZE = 0.5 + + scene = gs.Scene( + profiling_options=gs.options.ProfilingOptions(show_FPS=False), + show_viewer=show_viewer, + ) + + scene.add_entity(gs.morphs.Plane()) + + box_obstacle = scene.add_entity( + gs.morphs.Box( + size=(RAYCAST_GRID_SIZE / 2.0, RAYCAST_GRID_SIZE / 2.0, BOX_HEIGHT), + # pos=(0.0, 0.0, -BOX_HEIGHT), # init below ground to not interfere with first raycast + pos=(RAYCAST_GRID_SIZE, RAYCAST_GRID_SIZE, EXPECTED_DISTANCE / 2.0 + BOX_HEIGHT / 2.0), + ), + ) + grid_sensor_box = scene.add_entity( + gs.morphs.Box( + size=(0.1, 0.1, 0.1), + pos=(0.0, 0.0, EXPECTED_DISTANCE + BOX_HEIGHT), + fixed=True, + ), + ) + grid_raycaster = scene.add_sensor( + gs.sensors.Raycaster( + pattern=gs.sensors.raycaster.GridPattern( + resolution=1.0 / (NUM_RAYS_XY - 1.0), + size=(1.0, 1.0), + direction=(0.0, 0.0, -1.0), # pointing downwards to ground + ), + entity_idx=grid_sensor_box.idx, + pos_offset=(0.0, 0.0, -BOX_HEIGHT), + return_world_frame=True, + draw_debug=True, + ) + ) + + spherical_sensor = scene.add_entity( + gs.morphs.Sphere( + radius=EXPECTED_DISTANCE, + pos=SPHERE_POS, + fixed=True, + ), + ) + spherical_raycaster = scene.add_sensor( + gs.sensors.Raycaster( + pattern=gs.sensors.raycaster.SphericalPattern( + n_points=(NUM_RAYS_XY, NUM_RAYS_XY), + ), + entity_idx=spherical_sensor.idx, + return_world_frame=False, + ) + ) + + scene.build(n_envs=n_envs) + + scene.step() + + grid_hits = grid_raycaster.read().points + grid_distances = grid_raycaster.read().distances + spherical_distances = spherical_raycaster.read().distances + + expected_shape = (NUM_RAYS_XY, NUM_RAYS_XY) if n_envs == 0 else (n_envs, NUM_RAYS_XY, NUM_RAYS_XY) + assert grid_distances.shape == spherical_distances.shape == expected_shape + + grid_distance_min = grid_distances.min() + assert grid_distances.min() < EXPECTED_DISTANCE - tol, "Raycaster grid pattern should have hit obstacle" + ground_hit_mask = grid_distances > grid_distance_min + tol + grid_hits = grid_hits[ground_hit_mask] + grid_distances = grid_distances[ground_hit_mask] + + assert_allclose( + grid_hits[..., 2], + 0.0, + tol=tol, + err_msg="Raycaster grid pattern should hit ground (z≈0)", + ) + assert_allclose( + grid_distances, + EXPECTED_DISTANCE, + tol=tol, + err_msg=f"Raycaster grid pattern should measure {EXPECTED_DISTANCE}m to ground plane", + ) + assert_allclose( + spherical_distances, + EXPECTED_DISTANCE, + tol=1e-2, # since sphere mesh is discretized, we need a larger tolerance here + err_msg=f"Raycaster spherical pattern should measure {EXPECTED_DISTANCE}m to the sphere around it", + ) diff --git a/tests/utils.py b/tests/utils.py index e8c0c5640..baf1ae4a1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -32,7 +32,7 @@ DEFAULT_BRANCH_NAME = "main" HUGGINGFACE_ASSETS_REVISION = "16e4eae0024312b84518f4b555dd630d6b34095a" -HUGGINGFACE_SNAPSHOT_REVISION = "0db0ca5941d6b64c58d9e9711abe62e3a50738ac" +HUGGINGFACE_SNAPSHOT_REVISION = "15e836c732972cd8ddf57e43136986b34653b279" MESH_EXTENSIONS = (".mtl", *MESH_FORMATS, *GLTF_FORMATS, *USD_FORMATS) IMAGE_EXTENSIONS = (".png", ".jpg") @@ -211,7 +211,7 @@ def get_hf_dataset( continue ext = path.suffix.lower() - if not ext in (URDF_FORMAT, MJCF_FORMAT, *IMAGE_EXTENSIONS, *MESH_EXTENSIONS): + if ext not in (URDF_FORMAT, MJCF_FORMAT, *IMAGE_EXTENSIONS, *MESH_EXTENSIONS): continue has_files = True @@ -223,19 +223,19 @@ def get_hf_dataset( try: ET.parse(path) except ET.ParseError as e: - raise HTTPError(f"Impossible to parse XML file.") from e + raise HTTPError("Impossible to parse XML file.") from e elif path.suffix.lower() in IMAGE_EXTENSIONS: try: Image.open(path) except UnidentifiedImageError as e: - raise HTTPError(f"Impossible to parse Image file.") from e + raise HTTPError("Impossible to parse Image file.") from e elif path.suffix.lower() in MESH_EXTENSIONS: # TODO: Validating mesh files is more tricky. Ignoring them for now. pass if not has_files: raise HTTPError("No file downloaded.") - except (HTTPError, FileNotFoundError) as e: + except (HTTPError, FileNotFoundError): if i == num_retry - 1: raise print(f"Failed to download assets from HuggingFace dataset. Trying again in {retry_delay}s...")