From 0288030ea2c0e24cb337d8cd1183f5dac0931221 Mon Sep 17 00:00:00 2001 From: Trinity Chung Date: Wed, 24 Sep 2025 17:51:31 -0400 Subject: [PATCH 1/4] add sensor.draw_debug --- examples/sensors/contact_force_go2.py | 98 +++++ examples/sensors/force.py | 84 ---- examples/sensors/{imu.py => imu_franka.py} | 11 +- genesis/recorders/__init__.py | 5 +- genesis/recorders/base_recorder.py | 7 +- genesis/recorders/plotters.py | 452 ++++++++++++++------- genesis/sensors/base_sensor.py | 25 +- genesis/sensors/contact_force.py | 78 +++- genesis/sensors/imu.py | 44 +- genesis/sensors/sensor_manager.py | 7 + genesis/vis/rasterizer_context.py | 7 +- tests/test_recorders.py | 3 +- 12 files changed, 543 insertions(+), 278 deletions(-) create mode 100644 examples/sensors/contact_force_go2.py delete mode 100644 examples/sensors/force.py rename examples/sensors/{imu.py => imu_franka.py} (90%) diff --git a/examples/sensors/contact_force_go2.py b/examples/sensors/contact_force_go2.py new file mode 100644 index 000000000..0b7d49ec8 --- /dev/null +++ b/examples/sensors/contact_force_go2.py @@ -0,0 +1,98 @@ +import argparse + +from tqdm import tqdm + +import genesis as gs +from genesis.recorders.plotters import IS_MATPLOTLIB_AVAILABLE, IS_PYQTGRAPH_AVAILABLE + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-dt", "--timestep", type=float, default=1e-2, help="Simulation time step") + parser.add_argument("-v", "--vis", action="store_true", default=True, help="Show visualization GUI") + parser.add_argument("-nv", "--no-vis", action="store_false", dest="vis", help="Disable visualization GUI") + parser.add_argument("-c", "--cpu", action="store_true", help="Use CPU instead of GPU") + parser.add_argument("-t", "--seconds", type=float, default=2, help="Number of seconds to simulate") + parser.add_argument("-f", "--force", action="store_true", default=True, help="Use ContactForceSensor (xyz float)") + parser.add_argument("-nf", "--no-force", action="store_false", dest="force", help="Use ContactSensor (boolean)") + + args = parser.parse_args() + + ########################## init ########################## + gs.init(backend=gs.cpu if args.cpu else gs.gpu, logging_level=None) + + ########################## scene setup ########################## + scene = gs.Scene( + sim_options=gs.options.SimOptions(dt=args.timestep), + rigid_options=gs.options.RigidOptions( + use_gjk_collision=True, + constraint_timeconst=max(0.01, 2 * args.timestep), + ), + vis_options=gs.options.VisOptions(show_world_frame=True), + profiling_options=gs.options.ProfilingOptions(show_FPS=False), + show_viewer=args.vis, + ) + + scene.add_entity(gs.morphs.Plane()) + + foot_link_names = ["FR_foot", "FL_foot", "RR_foot", "RL_foot"] + go2 = scene.add_entity( + gs.morphs.URDF( + file="urdf/go2/urdf/go2.urdf", + pos=(0.0, 0.0, 0.2), + links_to_keep=foot_link_names, + ) + ) + + for link_name in foot_link_names: + if args.force: + force_sensor = scene.add_sensor( + gs.sensors.ContactForce( + entity_idx=go2.idx, + link_idx_local=go2.get_link(link_name).idx_local, + draw_debug=True, + ) + ) + if IS_PYQTGRAPH_AVAILABLE: + force_sensor.start_recording( + gs.recorders.PyQtLinePlot( + title="Force Sensor Data", + labels=["force_x", "force_y", "force_z"], + ) + ) + elif IS_MATPLOTLIB_AVAILABLE: + print("pyqtgraph not found, falling back to matplotlib.") + force_sensor.start_recording( + gs.recorders.MPLLinePlot( + title="Force Sensor Data", + labels=["force_x", "force_y", "force_z"], + ) + ) + else: + print("matplotlib or pyqtgraph not found, skipping real-time plotting.") + else: + contact_sensor = scene.add_sensor( + gs.sensors.Contact( + entity_idx=go2.idx, + link_idx_local=go2.get_link(link_name).idx_local, + draw_debug=True, + ) + ) + + scene.build() + + try: + steps = int(args.seconds / args.timestep) + for _ in tqdm(range(steps)): + scene.step() + + except KeyboardInterrupt: + gs.logger.info("Simulation interrupted, exiting.") + finally: + gs.logger.info("Simulation finished.") + + scene.stop_recording() + + +if __name__ == "__main__": + main() diff --git a/examples/sensors/force.py b/examples/sensors/force.py deleted file mode 100644 index 527c4f1da..000000000 --- a/examples/sensors/force.py +++ /dev/null @@ -1,84 +0,0 @@ -import argparse - -from tqdm import tqdm - -import genesis as gs -from genesis.recorders.plotters import IS_MATPLOTLIB_AVAILABLE, IS_PYQTGRAPH_AVAILABLE - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("-dt", "--timestep", type=float, default=1e-2, help="Simulation time step") - parser.add_argument("-v", "--vis", action="store_true", help="Show visualization GUI", default=True) - parser.add_argument("-nv", "--no-vis", action="store_false", dest="vis", help="Disable visualization GUI") - parser.add_argument("-c", "--cpu", action="store_true", help="Use CPU instead of GPU") - parser.add_argument("-t", "--seconds", type=float, default=2, help="Number of seconds to simulate") - - args = parser.parse_args() - - ########################## init ########################## - gs.init(backend=gs.cpu if args.cpu else gs.gpu, logging_level=None) - - ########################## scene setup ########################## - scene = gs.Scene( - sim_options=gs.options.SimOptions(dt=args.timestep), - rigid_options=gs.options.RigidOptions( - use_gjk_collision=True, - constraint_timeconst=max(0.01, 2 * args.timestep), - ), - vis_options=gs.options.VisOptions(show_world_frame=True), - profiling_options=gs.options.ProfilingOptions(show_FPS=False), - show_viewer=args.vis, - ) - - scene.add_entity(gs.morphs.Plane()) - - box = scene.add_entity( - morph=gs.morphs.Box( - size=(0.05, 0.05, 0.05), - pos=(0.0, 0.0, 0.2), - ), - material=gs.materials.Rigid(rho=1.0), - ) - # load the hand .urdf - hand = scene.add_entity( - morph=gs.morphs.URDF( - file="urdf/shadow_hand/shadow_hand.urdf", - pos=(0.0, -0.3, 0.1), - euler=(-90.0, 0.0, 0.0), - fixed=True, # Fix the base so the whole hand doesn't flop on the ground - ), - material=gs.materials.Rigid(), - ) - palm = hand.get_link("palm") - - force_sensor = scene.add_sensor(gs.sensors.ContactForce(entity_idx=hand.idx, link_idx_local=palm.idx_local)) - - labels = ["force_x", "force_y", "force_z"] - if IS_PYQTGRAPH_AVAILABLE: - force_sensor.start_recording(gs.recorders.PyQtPlot(title="Force Sensor Measured Data", labels=labels)) - elif IS_MATPLOTLIB_AVAILABLE: - print("pyqtgraph not found, falling back to matplotlib.") - force_sensor.start_recording(gs.recorders.MPLPlot(title="Force Sensor Measured Data", labels=labels)) - else: - print("matplotlib or pyqtgraph not found, skipping real-time plotting.") - - force_sensor.start_recording(gs.recorders.NPZFile(filename="force_data.npz", save_on_reset=True)) - - scene.build() - - try: - steps = int(args.seconds / args.timestep) - for _ in tqdm(range(steps)): - scene.step() - - except KeyboardInterrupt: - gs.logger.info("Simulation interrupted, exiting.") - finally: - gs.logger.info("Simulation finished.") - - scene.stop_recording() - - -if __name__ == "__main__": - main() diff --git a/examples/sensors/imu.py b/examples/sensors/imu_franka.py similarity index 90% rename from examples/sensors/imu.py rename to examples/sensors/imu_franka.py index 90f308add..a3c255cae 100644 --- a/examples/sensors/imu.py +++ b/examples/sensors/imu_franka.py @@ -44,6 +44,7 @@ def main(): gs.sensors.IMU( entity_idx=franka.idx, link_idx_local=end_effector.idx_local, + pos_offset=(0.0, 0.0, 0.2), # noise parameters acc_axes_skew=(0.0, 0.01, 0.02), gyro_axes_skew=(0.03, 0.04, 0.05), @@ -54,22 +55,24 @@ def main(): delay=0.01, jitter=0.01, interpolate=True, + # visualize + draw_debug=True, ) ) labels = {"lin_acc": ("acc_x", "acc_y", "acc_z"), "ang_vel": ("gyro_x", "gyro_y", "gyro_z")} if args.vis: if IS_PYQTGRAPH_AVAILABLE: - imu.start_recording(gs.recorders.PyQtPlot(title="IMU Measured Data", labels=labels)) + imu.start_recording(gs.recorders.PyQtLinePlot(title="IMU Measured Data", labels=labels)) scene.start_recording( imu.read_ground_truth, - gs.recorders.PyQtPlot(title="IMU Ground Truth Data", labels=labels), + gs.recorders.PyQtLinePlot(title="IMU Ground Truth Data", labels=labels), ) elif IS_MATPLOTLIB_AVAILABLE: gs.logger.info("pyqtgraph not found, falling back to matplotlib.") - imu.start_recording(gs.recorders.MPLPlot(title="IMU Measured Data", labels=labels)) + imu.start_recording(gs.recorders.MPLLinePlot(title="IMU Measured Data", labels=labels)) scene.start_recording( imu.read_ground_truth, - gs.recorders.MPLPlot(title="IMU Ground Truth Data", labels=labels), + gs.recorders.MPLLinePlot(title="IMU Ground Truth Data", labels=labels), ) else: print("matplotlib or pyqtgraph not found, skipping real-time plotting.") diff --git a/genesis/recorders/__init__.py b/genesis/recorders/__init__.py index 1fccec378..c58d1298e 100644 --- a/genesis/recorders/__init__.py +++ b/genesis/recorders/__init__.py @@ -2,6 +2,7 @@ from .file_writers import CSVFileWriterOptions as CSVFile from .file_writers import NPZFileWriterOptions as NPZFile from .file_writers import VideoFileWriterOptions as VideoFile -from .plotters import MPLPlotterOptions as MPLPlot -from .plotters import PyQtPlotterOptions as PyQtPlot +from .plotters import MPLImagePlotterOptions as MPLImagePlot +from .plotters import MPLLinePlotterOptions as MPLLinePlot +from .plotters import PyQtLinePlotterOptions as PyQtLinePlot from .recorder_manager import RecorderManager, register_recording diff --git a/genesis/recorders/base_recorder.py b/genesis/recorders/base_recorder.py index bdedd22c0..cb31724b3 100644 --- a/genesis/recorders/base_recorder.py +++ b/genesis/recorders/base_recorder.py @@ -1,12 +1,13 @@ import queue import threading import time -from typing import Callable, Generic, TypeVar +from typing import TYPE_CHECKING, Callable, Generic, TypeVar import genesis as gs from genesis.options import Options -from .recorder_manager import RecorderManager +if TYPE_CHECKING: + from .recorder_manager import RecorderManager T = TypeVar("T") @@ -50,7 +51,7 @@ class Recorder(Generic[T]): done through the RecorderManager. """ - def __init__(self, manager: RecorderManager, options: RecorderOptions, data_func: Callable[[], T]): + def __init__(self, manager: "RecorderManager", options: RecorderOptions, data_func: Callable[[], T]): self._options = options self._manager = manager self._data_func = data_func diff --git a/genesis/recorders/plotters.py b/genesis/recorders/plotters.py index 85ce4a2d8..6523a4f1c 100644 --- a/genesis/recorders/plotters.py +++ b/genesis/recorders/plotters.py @@ -4,8 +4,9 @@ import time from collections import defaultdict from collections.abc import Sequence +from dataclasses import dataclass from functools import partial -from typing import Any +from typing import Any, Generic, TypeVar import numpy as np import torch @@ -34,10 +35,10 @@ pass -COLORS = itertools.cycle(("r", "g", "b", "c", "m", "y", "w")) +COLORS = itertools.cycle(("r", "g", "b", "c", "m", "y")) -def _data_to_array(data: Any) -> np.ndarray: +def _data_to_array(data: Sequence) -> np.ndarray: if isinstance(data, torch.Tensor): data = tensor_to_array(data) return np.atleast_1d(data) @@ -45,23 +46,14 @@ def _data_to_array(data: Any) -> np.ndarray: class BasePlotterOptions(RecorderOptions): """ - Base class for live line plot visualization of scalar data. - - The recorded data_func should return scalar data (single scalar, a tuple of scalars, or a dict with string keys and - scalar or tuple of scalars as values). + Base class for plot visualization. Parameters ---------- title: str The title of the plot. - labels: tuple[str] | dict[str, tuple[str]] | None - The labels for the plot. The length of the labels should match the length of the data. - If a dict is provided, the data should also be a dict of tuples of strings that match the length of the data. - The keys will be used as subplot titles and the values will be used as labels within each subplot. window_size: tuple[int, int] The size of the window in pixels. - history_length: int - The maximum number of previous data to store. save_to_filename: str | None If provided, the animation will be saved to a file with the given filename. show_window: bool | None @@ -69,21 +61,61 @@ class BasePlotterOptions(RecorderOptions): """ title: str = "" - labels: tuple[str, ...] | dict[str, tuple[str, ...]] | None = None window_size: tuple[int, int] = (800, 600) - history_length: int = 100 save_to_filename: str | None = None show_window: bool | None = None class BasePlotter(Recorder): - """Base class for real-time plotters with shared functionality.""" def build(self): super().build() - self.show_window = self._options.show_window if self._options.show_window is not None else has_display() + def process(self, data, cur_time): + pass # allow super() calls + + def cleanup(self): + pass # allow super() calls + + +@dataclass +class LinePlotterMixinOptions: + """ + Mixin class for live line plot visualization of scalar data. + + The recorded data_func should return scalar data (single scalar, a tuple of scalars, or a dict with string keys and + scalar or tuple of scalars as values). + + Parameters + ---------- + labels: tuple[str] | dict[str, tuple[str]] | None + The labels for the plot. The length of the labels should match the length of the data. + If a dict is provided, the data should also be a dict of tuples of strings that match the length of the data. + The keys will be used as subplot titles and the values will be used as labels within each subplot. + x_label: str, optional + Label for the horizontal axis. + y_label: str, optional + Label for the vertical axis. + history_length: int + The maximum number of previous data to store. + """ + + labels: tuple[str, ...] | dict[str, tuple[str, ...]] | None = None + x_label: str = "" + y_label: str = "" + history_length: int = 100 + + +LinePlotterOptionsMixinT = TypeVar("LinePlotterOptionsMixinT", bound=LinePlotterMixinOptions) + + +class LinePlotterMixin(Generic[LinePlotterOptionsMixinT]): + """Base class for real-time plotters with shared functionality.""" + + def build(self): + super().build() + self.x_data: list[float] = [] self.y_data: defaultdict[str, defaultdict[str, list[float]]] = defaultdict(lambda: defaultdict(list)) @@ -93,13 +125,13 @@ def build(self): if self._options.labels is not None: self._setup_plot_structure(self._options.labels) + else: + self._setup_plot_structure(self._data_func()) self.video_writer = None if self._options.save_to_filename: def _get_video_frame_buffer(plotter): - from matplotlib.backends.backend_agg import FigureCanvasAgg - # Make sure that all the data in the pipe has been processed before rendering anything if not plotter._frames_buffer: if plotter._data_queue is not None and not plotter._data_queue.empty(): @@ -131,7 +163,7 @@ def cleanup(self): self._frames_buffer.clear() self.video_writer = None - def _setup_plot_structure(self, labels_or_data: dict[str, Any] | Any): + def _setup_plot_structure(self, labels_or_data: dict[str, Sequence] | Sequence): """Set up the plot structure based on labels or first data sample.""" if isinstance(labels_or_data, dict): self.is_dict_data = True @@ -158,8 +190,6 @@ def _setup_plot_structure(self, labels_or_data: dict[str, Any] | Any): def process(self, data, cur_time): """Process new data point and update plot.""" - if self.subplot_structure is None: - self._setup_plot_structure(data) if self.is_dict_data: processed_data = {} @@ -224,7 +254,69 @@ def get_image_array(self): raise NotImplementedError -class PyQtPlotterOptions(BasePlotterOptions): +class BasePyQtPlotter(BasePlotter): + """ + Base class for PyQt based plotters. + """ + + def build(self): + if not IS_PYQTGRAPH_AVAILABLE: + gs.raise_exception( + f"{type(self).__name__} pyqtgraph is not installed. Please install it with `pip install pyqtgraph`." + ) + + super().build() + + self.app: pg.QtWidgets.QApplication | None = None + self.widget: pg.GraphicsLayoutWidget | None = None + self.plot_widgets: list[pg.PlotWidget] = [] + if not pg.QtWidgets.QApplication.instance(): + self.app = pg.QtWidgets.QApplication([]) + else: + self.app = pg.QtWidgets.QApplication.instance() + + self.widget = pg.GraphicsLayoutWidget(show=self.show_window, title=self._options.title) + if self.show_window: + gs.logger.info(f"[{type(self).__name__}] created PyQtGraph window") + self.widget.resize(*self._options.window_size) + + def cleanup(self): + super().cleanup() + + if self.widget: + try: + self.widget.close() + gs.logger.debug(f"[{type(self).__name__}] closed PyQtGraph window") + except Exception as e: + gs.logger.warning(f"[{type(self).__name__}] Error closing window: {e}") + finally: + self.widget = None + self.plot_widgets.clear() + + @property + def run_in_thread(self) -> bool: + return True + + def get_image_array(self): + """ + Capture the plot image as a video frame. + + Returns + ------- + image_array : np.ndarray + The RGB image as a numpy array. + """ + pixmap = self.widget.grab() + qimage = pixmap.toImage() + + qimage = qimage.convertToFormat(pg.QtGui.QImage.Format_RGB888) + ptr = qimage.bits() + ptr.setsize(qimage.byteCount()) + + return np.array(ptr).reshape((qimage.height(), qimage.width(), 3)) + + +class PyQtLinePlotterOptions(BasePlotterOptions, LinePlotterMixinOptions): """ Live line plot visualization of data using PyQtGraph. @@ -232,54 +324,38 @@ class PyQtPlotterOptions(BasePlotterOptions): ---------- title: str The title of the plot. - labels: tuple[str] | dict[str, tuple[str]] | None - The labels for the plot. The length of the labels should match the length of the data. - If a dict is provided, the data should also be a dict of tuples of strings that match the length of the data. - The keys will be used as subplot titles and the values will be used as labels within each subplot. window_size: tuple[int, int] The size of the window in pixels. - history_length: int - The maximum number of previous data to store. save_to_filename: str | None If provided, the animation will be saved to a file with the given filename. show_window: bool | None Whether to show the window. If not provided, it will be set to True if a display is connected, False otherwise. + labels: tuple[str] | dict[str, tuple[str]] | None + The labels for the plot. The length of the labels should match the length of the data. + If a dict is provided, the data should also be a dict of tuples of strings that match the length of the data. + The keys will be used as subplot titles and the values will be used as labels within each subplot. + x_label: str, optional + Label for the horizontal axis. + y_label: str, optional + Label for the vertical axis. + history_length: int + The maximum number of previous data to store. """ pass -@register_recording(PyQtPlotterOptions) -class PyQtPlotter(BasePlotter): +@register_recording(PyQtLinePlotterOptions) +class PyQtLinePlotter(LinePlotterMixin[PyQtLinePlotterOptions], BasePyQtPlotter): """ Real-time plot using PyQt for live sensor data visualization. - - Inherits common plotting functionality from BasePlotter. """ def build(self): - if not IS_PYQTGRAPH_AVAILABLE: - gs.raise_exception( - "[PyQtPlotter] pyqtgraph is not installed. Please install it with `pip install pyqtgraph`." - ) - super().build() - self.app: pg.QtWidgets.QApplication | None = None - self.widget: pg.GraphicsLayoutWidget | None = None - self.plot_widgets: list[pg.PlotWidget] = [] self.curves: dict[str, list[pg.PlotCurveItem]] = {} - if not pg.QtWidgets.QApplication.instance(): - self.app = pg.QtWidgets.QApplication([]) - else: - self.app = pg.QtWidgets.QApplication.instance() - - self.widget = pg.GraphicsLayoutWidget(show=self.show_window, title=self._options.title) - self.widget.resize(*self._options.window_size) - - gs.logger.info("[PyQtPlotter] created PyQtGraph window") - # create plots for each subplot for subplot_idx, (subplot_key, channel_labels) in enumerate(self.subplot_structure.items()): # add new row if not the first plot @@ -287,8 +363,8 @@ def build(self): self.widget.nextRow() plot_widget = self.widget.addPlot(title=subplot_key if self.is_dict_data else self._options.title) - plot_widget.setLabel("left", "Value") - plot_widget.setLabel("bottom", "Time") + plot_widget.setLabel("bottom", self._options.x_label) + plot_widget.setLabel("left", self._options.y_label) plot_widget.showGrid(x=True, y=True, alpha=0.3) plot_widget.addLegend() @@ -317,20 +393,59 @@ def _update_plot(self): def cleanup(self): super().cleanup() - if self.widget: + self.curves.clear() + + +class BaseMPLPlotter(BasePlotter): + """ + Base class for matplotlib based plotters. + """ + + def build(self): + if not IS_MATPLOTLIB_AVAILABLE: + gs.raise_exception( + f"{type(self).__name__} matplotlib is not installed. Please install it with `pip install matplotlib>=3.7.0`." + ) + + super().build() + + import matplotlib.pyplot as plt + + self.fig: plt.Figure | None = None + self._lock = threading.Lock() + + # matplotlib figsize uses inches + dpi = mpl.rcParams.get("figure.dpi", 100) + self.figsize = (self._options.window_size[0] / dpi, self._options.window_size[1] / dpi) + + def _show_fig(self): + if self.show_window: + self.fig.show() + gs.logger.info(f"[{type(self).__name__}] created matplotlib window") + + def cleanup(self): + """Clean up matplotlib resources.""" + super().cleanup() + + # Logger may not be available anymore + logger_exists = hasattr(gs, "logger") + + if self.fig is not None: try: - self.widget.close() - gs.logger.debug("[PyQtPlotter] closed PyQtGraph window") + import matplotlib.pyplot as plt + + plt.close(self.fig) + if logger_exists: + gs.logger.debug(f"[{type(self).__name__}] Closed matplotlib window") except Exception as e: - gs.logger.warning(f"[PyQtPlotter] Error closing window: {e}") + if logger_exists: + gs.logger.warning(f"[{type(self).__name__}] Error closing window: {e}") finally: - self.widget = None - self.plot_widgets.clear() - self.curves.clear() + self.fig = None @property def run_in_thread(self) -> bool: - return True + return not self.show_window or gs.platform != "macOS" def get_image_array(self): """ @@ -341,93 +456,100 @@ def get_image_array(self): image_array : np.ndarray The RGB image as a numpy array. """ - pixmap = self.widget.grab() - qimage = pixmap.toImage() + from matplotlib.backends.backend_agg import FigureCanvasAgg - qimage = qimage.convertToFormat(pg.QtGui.QImage.Format_RGB888) - ptr = qimage.bits() - ptr.setsize(qimage.byteCount()) + self._lock.acquire() + if isinstance(self.fig.canvas, FigureCanvasAgg): + # Read internal buffer + width, height = self.fig.canvas.get_width_height(physical=True) + rgba_array_flat = np.frombuffer(self.fig.canvas.buffer_rgba(), dtype=np.uint8) + rgb_array = rgba_array_flat.reshape((height, width, 4))[..., :3] - return np.array(ptr).reshape((qimage.height(), qimage.width(), 3)) + # Rescale image if necessary + if (width, height) != tuple(self._options.window_size): + img = Image.fromarray(rgb_array) + img = img.resize(self._options.window_size, resample=Image.BILINEAR) + rgb_array = np.asarray(img) + else: + rgb_array = rgb_array.copy() + else: + # Slower but more generic fallback only if necessary + buffer = io.BytesIO() + self.fig.canvas.print_figure(buffer, format="png", dpi="figure") + buffer.seek(0) + img = Image.open(buffer) + rgb_array = np.asarray(img.convert("RGB")) + self._lock.release() + + return rgb_array -class MPLPlotterOptions(BasePlotterOptions): +class MPLLinePlotterOptions(BasePlotterOptions, LinePlotterMixinOptions): """ - Live line plot visualization of data using MatPlotLib. + Live line plot visualization of data using matplotlib. Parameters ---------- title: str The title of the plot. - labels: tuple[str] | dict[str, tuple[str]] | None - The labels for the plot. The length of the labels should match the length of the data. - If a dict is provided, the data should also be a dict of tuples of strings that match the length of the data. - The keys will be used as subplot titles and the values will be used as labels within each subplot. window_size: tuple[int, int] The size of the window in pixels. - history_length: int - The maximum number of previous data to store. save_to_filename: str | None If provided, the animation will be saved to a file with the given filename. show_window: bool | None Whether to show the window. If not provided, it will be set to True if a display is connected, False otherwise. + labels: tuple[str] | dict[str, tuple[str]] | None + The labels for the plot. The length of the labels should match the length of the data. + If a dict is provided, the data should also be a dict of tuples of strings that match the length of the data. + The keys will be used as subplot titles and the values will be used as labels within each subplot. + x_label: str, optional + Label for the horizontal axis. + y_label: str, optional + Label for the vertical axis. + history_length: int + The maximum number of previous data to store. """ pass -@register_recording(MPLPlotterOptions) -class MPLPlotter(BasePlotter): +@register_recording(MPLLinePlotterOptions) +class MPLLinePlotter(LinePlotterMixin[MPLLinePlotterOptions], BaseMPLPlotter): """ - Real-time plot using MatPlotLib for live sensor data visualization. + Real-time plot using matplotlib for live sensor data visualization. Inherits common plotting functionality from BasePlotter. """ def build(self): - if not IS_MATPLOTLIB_AVAILABLE: - gs.raise_exception( - "[MPLPlotter] matplotlib is not installed. Please install it with `pip install matplotlib>=3.7.0`." - ) super().build() import matplotlib.pyplot as plt - self.fig: plt.Figure | None = None self.axes: list[plt.Axes] = [] self.lines: dict[str, list[plt.Line2D]] = {} self.backgrounds: list[Any] = [] - self._lock = threading.Lock() - - gs.logger.info("[MPLPlotter] created Matplotlib window") - - # create figure and subplots + # Create figure and subplots n_subplots = len(self.subplot_structure) - dpi = mpl.rcParams.get("figure.dpi", 100) - # matplotlib figsize uses inches - figsize = (self._options.window_size[0] / dpi, self._options.window_size[1] / dpi) - if n_subplots == 1: - self.fig, ax = plt.subplots(figsize=figsize) + self.fig, ax = plt.subplots(figsize=self.figsize) self.axes = [ax] else: - self.fig, axes = plt.subplots(n_subplots, 1, figsize=figsize, sharex=True, constrained_layout=True) + self.fig, axes = plt.subplots(n_subplots, 1, figsize=self.figsize, sharex=True, constrained_layout=True) self.axes = axes if isinstance(axes, (list, tuple, np.ndarray)) else [axes] self.fig.suptitle(self._options.title) - # create lines for each subplot + # Create lines for each subplot for subplot_idx, (subplot_key, channel_labels) in enumerate(self.subplot_structure.items()): ax = self.axes[subplot_idx] - ax.set_xlabel("Time") - ax.set_ylabel("Value") + ax.set_xlabel(self._options.x_label) + ax.set_ylabel(self._options.y_label) ax.grid(True, alpha=0.3) - # set subplot title if we have multiple subplots if self.is_dict_data and n_subplots > 1: ax.set_title(subplot_key) - # create lines for this subplot subplot_lines = [] for color, channel_label in zip(COLORS, channel_labels): @@ -436,22 +558,19 @@ def build(self): self.lines[subplot_key] = subplot_lines - ax.set_xlim(0, 10) - ax.set_ylim(-1, 1) - # Legend must be outside, otherwise it will not play well with blitting self.fig.legend(ncol=sum(map(len, self.lines.values())), loc="outside lower center") - - if self.show_window: - self.fig.show() self.fig.canvas.draw() for ax in self.axes: self.backgrounds.append(self.fig.canvas.copy_from_bbox(ax.bbox)) + self._show_fig() + def _update_plot(self): - # Update each subplot self._lock.acquire() + + # Update each subplot for subplot_idx, (subplot_key, subplot_lines) in enumerate(self.lines.items()): ax = self.axes[subplot_idx] @@ -504,67 +623,84 @@ def _update_plot(self): self._lock.release() def cleanup(self): - """Clean up Matplotlib resources.""" super().cleanup() - # Logger may not be available anymore - logger_exists = hasattr(gs, "logger") + self.lines.clear() + self.backgrounds.clear() - if self.fig is not None: - try: - import matplotlib.pyplot as plt - plt.close(self.fig) - if logger_exists: - gs.logger.debug("[MPLPlotter] Closed Matplotlib window") - except Exception as e: - if logger_exists: - gs.logger.warning(f"[MPLPlotter] Error closing window: {e}") - finally: - self.lines.clear() - self.backgrounds.clear() - self.fig = None +class MPLImagePlotterOptions(BasePlotterOptions): + """ + Live visualization of image data using matplotlib. - @property - def run_in_thread(self) -> bool: - return gs.platform != "macOS" + Parameters + ---------- + title: str + The title of the plot. + window_size: tuple[int, int] + The size of the window in pixels. + save_to_filename: str | None + If provided, the animation will be saved to a file with the given filename. + show_window: bool | None + Whether to show the window. If not provided, it will be set to True if a display is connected, False otherwise. + """ - def get_image_array(self): - """ - Capture the plot image as a video frame. + pass - Returns - ------- - image_array : np.ndarray - The RGB image as a numpy array. - """ - from matplotlib.backends.backend_agg import FigureCanvasAgg - self._lock.acquire() - if isinstance(self.fig.canvas, FigureCanvasAgg): - # Must force rendering manually - # FIXME: Check if necessary - # FigureCanvasAgg.draw(self.fig.canvas) +@register_recording(MPLImagePlotterOptions) +class MPLImagePlotter(BaseMPLPlotter): + """ + Live image viewer using matplotlib. + """ - # Read internal buffer - width, height = self.fig.canvas.get_width_height(physical=True) - rgba_array_flat = np.frombuffer(self.fig.canvas.buffer_rgba(), dtype=np.uint8) - rgb_array = rgba_array_flat.reshape((height, width, 4))[..., :3] + def build(self): + super().build() - # Rescale image if necessary - if (width, height) != tuple(self._options.window_size): - img = Image.fromarray(rgb_array) - img = img.resize(self._options.window_size, resample=Image.BILINEAR) - rgb_array = np.asarray(img) - else: - rgb_array = rgb_array.copy() + import matplotlib.pyplot as plt + + self.image_plot = None + self.background = None + + self.fig, self.ax = plt.subplots(figsize=self.figsize) + self.fig.tight_layout(pad=0) + self.ax.set_axis_off() + self.fig.subplots_adjust(left=0, right=1, top=1, bottom=0) + self.image_plot = self.ax.imshow(np.zeros((1, 1)), cmap="plasma", origin="upper", aspect="auto") + self._show_fig() + + def process(self, data, cur_time): + """Process new image data and update display.""" + if isinstance(data, torch.Tensor): + img_data = tensor_to_array(data) else: - # Slower but more generic fallback only if necessary - buffer = io.BytesIO() - self.fig.canvas.print_figure(buffer, format="png", dpi="figure") - buffer.seek(0) - img = Image.open(buffer) - rgb_array = np.asarray(img.convert("RGB")) - self._lock.release() + img_data = np.asarray(data) - return rgb_array + if img_data.ndim == 3 and img_data.shape[0] == 1: + img_data = img_data[0] # remove batch dimension + # TODO: color images? + # elif img_data.ndim == 3 and img_data.shape[-1] in [1, 3, 4]: + elif img_data.ndim == 3 and img_data.shape[-1] == 1: + img_data = img_data.squeeze(-1) + + vmin, vmax = np.min(img_data), np.max(img_data) + + current_vmin, current_vmax = self.image_plot.get_clim() + if vmin != current_vmin or vmax != current_vmax: + self.image_plot.set_clim(vmin, vmax) + self.fig.canvas.draw() + self.background = self.fig.canvas.copy_from_bbox(self.ax.bbox) + + self.fig.canvas.restore_region(self.background) + self.image_plot.set_data(img_data) + self.ax.draw_artist(self.image_plot) + self.fig.canvas.blit(self.ax.bbox) + + self.fig.canvas.flush_events() + + def cleanup(self): + super().cleanup() + + self.ax = None + self.image_plot = None + self.background = None diff --git a/genesis/sensors/base_sensor.py b/genesis/sensors/base_sensor.py index d6c632303..b02f35a6e 100644 --- a/genesis/sensors/base_sensor.py +++ b/genesis/sensors/base_sensor.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from functools import partial -from typing import TYPE_CHECKING, Sequence, Generic, TypeVar +from typing import TYPE_CHECKING, Generic, Sequence, TypeVar import gstaichi as ti import numpy as np @@ -17,6 +17,7 @@ if TYPE_CHECKING: from genesis.recorders.base_recorder import Recorder, RecorderOptions from genesis.utils.ring_buffer import TensorRingBuffer + from genesis.vis.rasterizer_context import RasterizerContext from .sensor_manager import SensorManager @@ -43,19 +44,23 @@ def _to_tuple(*values: NumericType | torch.Tensor, length_per_value: int = 3) -> class SensorOptions(Options): """ Base class for all sensor options. + Each sensor should have their own options class that inherits from this class. The options class should be registered with the SensorManager using the @register_sensor decorator. Parameters ---------- delay : float - The read delay time in seconds. Data read will be outdated by this amount. + The read delay time in seconds. Data read will be outdated by this amount. Defaults to 0.0 (no delay). update_ground_truth_only : bool - If True, the sensor will only update the ground truth data, and not the measured data. + If True, the sensor will only update the ground truth data, and not the measured data. Defaults to False. + draw_debug : bool + If True and visualizer is active, the sensor will draw debug shapes in the scene. Defaults to False. """ delay: float = 0.0 update_ground_truth_only: bool = False + draw_debug: bool = False def validate(self, scene): """ @@ -203,6 +208,12 @@ def _get_cache_dtype(cls) -> torch.dtype: """ raise NotImplementedError(f"{cls.__name__} has not implemented `get_cache_dtype()`.") + def _draw_debug(self, context: "RasterizerContext"): + """ + Draw debug shapes for the sensor in the scene. + """ + raise NotImplementedError(f"{type(self).__name__} has not implemented `draw_debug()`.") + # =============================== public shared methods =============================== @gs.assert_built @@ -386,10 +397,10 @@ def build(self): batch_size = self._manager._sim._B - link_start = self._shared_metadata.solver.entities[self._options.entity_idx].link_start - self._shared_metadata.links_idx = concat_with_tensor( - self._shared_metadata.links_idx, self._options.link_idx_local + link_start - ) + entity = self._shared_metadata.solver.entities[self._options.entity_idx] + self.link_idx = self._options.link_idx_local + entity.link_start + self.link = entity.links[self._options.link_idx_local] + self._shared_metadata.links_idx = concat_with_tensor(self._shared_metadata.links_idx, self.link_idx) self._shared_metadata.offsets_pos = concat_with_tensor( self._shared_metadata.offsets_pos, self._options.pos_offset, diff --git a/genesis/sensors/contact_force.py b/genesis/sensors/contact_force.py index 2a53699df..aae9e4c84 100644 --- a/genesis/sensors/contact_force.py +++ b/genesis/sensors/contact_force.py @@ -7,8 +7,8 @@ import genesis as gs from genesis.engine.solvers import RigidSolver -from genesis.utils.geom import ti_inv_transform_by_quat -from genesis.utils.misc import concat_with_tensor, make_tensor_field +from genesis.utils.geom import ti_inv_transform_by_quat, transform_by_quat +from genesis.utils.misc import concat_with_tensor, make_tensor_field, tensor_to_array from .base_sensor import ( MaybeTuple3FType, @@ -27,6 +27,7 @@ if TYPE_CHECKING: from genesis.utils.ring_buffer import TensorRingBuffer + from genesis.vis.rasterizer_context import RasterizerContext @ti.kernel @@ -79,8 +80,17 @@ class ContactSensorOptions(RigidSensorOptionsMixin, SensorOptions): The delay in seconds before the sensor data is read. update_ground_truth_only : bool If True, the sensor will only update the ground truth data, and not the measured data. + draw_debug : bool, optional + If True and the rasterizer visualization is active, a sphere will be drawn at the sensor's position. + debug_sphere_radius : float, optional + The radius of the debug sphere. Defaults to 0.05. + debug_color : float, optional + The rgba color of the debug sphere. Defaults to (1.0, 0.0, 1.0, 0.5). """ + debug_sphere_radius: float = 0.05 + debug_color: tuple[float, float, float, float] = (1.0, 0.0, 1.0, 0.5) + @dataclass class ContactSensorMetadata(SharedSensorMetadata): @@ -104,14 +114,17 @@ def build(self): if self._shared_metadata.solver is None: self._shared_metadata.solver = self._manager._sim.rigid_solver - link_start = self._shared_metadata.solver.entities[self._options.entity_idx].link_start + entity = self._shared_metadata.solver.entities[self._options.entity_idx] + self.link_idx = self._options.link_idx_local + entity.link_start + self.link = entity.links[self._options.link_idx_local] + self._shared_metadata.expanded_links_idx = concat_with_tensor( - self._shared_metadata.expanded_links_idx, - link_start + self._options.link_idx_local, - expand=(1,), - dim=0, + self._shared_metadata.expanded_links_idx, self.link_idx, expand=(1,), dim=0 ) + if self._options.draw_debug: + self.debug_object = None + def _get_return_format(self) -> tuple[int, ...]: return (1,) @@ -143,6 +156,25 @@ def _update_shared_cache( buffered_data.append(shared_ground_truth_cache) cls._apply_delay_to_shared_cache(shared_metadata, shared_cache, buffered_data) + def _draw_debug(self, context: "RasterizerContext"): + """ + Draw debug sphere when the sensor detects contact. + + Only draws for first environment. + """ + envs_idx = 0 if self._manager._sim.n_envs > 0 else None + + pos = self.link.get_pos(envs_idx=envs_idx).squeeze(0) + is_contact = self.read(envs_idx=envs_idx).squeeze(0).item() + + if self.debug_object is not None: + context.clear_debug_object(self.debug_object) + + if is_contact: + self.debug_object = context.draw_debug_sphere( + pos=pos, radius=self._options.debug_sphere_radius, color=self._options.debug_color + ) + # ========================================================================================================== @@ -180,11 +212,20 @@ class ContactForceSensorOptions(RigidSensorOptionsMixin, NoisySensorOptionsMixin Otherwise, the sensor data at the closest time step will be used. Default is False. update_ground_truth_only : bool, optional If True, the sensor will only update the ground truth data, and not the measured data. + draw_debug : bool, optional + If True and the rasterizer visualization is active, an arrow for the contact force will be drawn. + debug_color : float, optional + The rgba color of the debug arrow. Defaults to (1.0, 0.0, 1.0, 0.5). + debug_scale : float, optional + The scale factor for the debug force arrow. Defaults to 0.01. """ min_force: MaybeTuple3FType = 0.0 max_force: MaybeTuple3FType = np.inf + debug_color: tuple[float, float, float, float] = (1.0, 0.0, 1.0, 0.5) + debug_scale: float = 0.01 + def validate(self, scene): super().validate(scene) if not ( @@ -244,6 +285,9 @@ def build(self): _to_tuple(self._options.max_force, length_per_value=3), ) + if self._options.draw_debug: + self.debug_object = None + def _get_return_format(self) -> tuple[int, ...]: return (3,) @@ -299,7 +343,25 @@ def _update_shared_cache( cls._add_noise_drift_bias(shared_metadata, shared_cache) shared_cache_per_sensor = shared_cache.reshape(shared_cache.shape[0], -1, 3) # B, n_sensors * 3 # clip for max force - shared_cache_per_sensor.clamp_(max=shared_metadata.max_force) + shared_cache_per_sensor.clamp_(min=-shared_metadata.max_force, max=shared_metadata.max_force) # set to 0 for undetectable force shared_cache_per_sensor[torch.abs(shared_cache_per_sensor) < shared_metadata.min_force] = 0.0 cls._quantize_to_resolution(shared_metadata.resolution, shared_cache) + + def _draw_debug(self, context: "RasterizerContext"): + """ + Draw debug arrow representing the contact force. + + Only draws for first environment. + """ + envs_idx = 0 if self._manager._sim.n_envs > 0 else None + + pos = self.link.get_pos(envs_idx=envs_idx).squeeze(0) + quat = self.link.get_quat(envs_idx=envs_idx).squeeze(0) + + force = self.read(envs_idx=envs_idx) + vec = tensor_to_array(transform_by_quat(force.squeeze(0) * self._options.debug_scale, quat)) + + if self.debug_object is not None: + context.clear_debug_object(self.debug_object) + self.debug_object = context.draw_debug_arrow(pos=pos, vec=vec, color=self._options.debug_color) diff --git a/genesis/sensors/imu.py b/genesis/sensors/imu.py index 3aad834dd..86c812e05 100644 --- a/genesis/sensors/imu.py +++ b/genesis/sensors/imu.py @@ -1,23 +1,19 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import gstaichi as ti import numpy as np import torch import genesis as gs -from genesis.utils.geom import ( - inv_transform_by_trans_quat, - transform_quat_by_quat, -) -from genesis.utils.misc import concat_with_tensor, make_tensor_field +from genesis.utils.geom import inv_transform_by_trans_quat, transform_by_quat, transform_quat_by_quat +from genesis.utils.misc import concat_with_tensor, make_tensor_field, tensor_to_array from .base_sensor import ( MaybeTuple3FType, NoisySensorMetadataMixin, NoisySensorMixin, NoisySensorOptionsMixin, - NumericType, RigidSensorMetadataMixin, RigidSensorMixin, RigidSensorOptionsMixin, @@ -30,6 +26,7 @@ if TYPE_CHECKING: from genesis.utils.ring_buffer import TensorRingBuffer + from genesis.vis.rasterizer_context import RasterizerContext Matrix3x3Type = tuple[tuple[float, float, float], tuple[float, float, float], tuple[float, float, float]] MaybeMatrix3x3Type = Matrix3x3Type | MaybeTuple3FType @@ -128,6 +125,12 @@ class IMUOptions(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions Otherwise, the sensor data at the closest time step will be used. Default is False. update_ground_truth_only : bool, optional If True, the sensor will only update the ground truth data, and not the measured data. + draw_debug : bool, optional + If True and the rasterizer visualization is active, an arrow for linear acceleration will be drawn. + debug_acc_color : float, optional + The rgba color of the debug acceleration arrow. Defaults to (1.0, 0.0, 0.0, 0.5). + debug_acc_scale: float, optional + The scale factor for the debug acceleration arrow. Defaults to 0.01. """ acc_resolution: MaybeTuple3FType = 0.0 @@ -141,6 +144,9 @@ class IMUOptions(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions acc_random_walk: MaybeTuple3FType = 0.0 gyro_random_walk: MaybeTuple3FType = 0.0 + debug_acc_color: tuple[float, float, float, float] = (1.0, 0.0, 0.0, 0.5) + debug_acc_scale: float = 0.01 + def validate(self, scene): super().validate(scene) self._validate_axes_skew(self.acc_axes_skew) @@ -245,6 +251,11 @@ def build(self): dim=1, ) + if self._options.draw_debug: + self.debug_object = None + self.quat_offset = self._shared_metadata.offsets_quat[0, self._idx] + self.pos_offset = self._shared_metadata.offsets_pos[0, self._idx] + def _get_return_format(self) -> dict[str, tuple[int, ...]]: return { "lin_acc": (3,), @@ -310,3 +321,22 @@ def _update_shared_cache( # apply additive noise and bias to the shared cache cls._add_noise_drift_bias(shared_metadata, shared_cache) cls._quantize_to_resolution(shared_metadata.resolution, shared_cache) + + def _draw_debug(self, context: "RasterizerContext"): + """ + Draw debug arrow for the IMU acceleration. + + Only draws for first environment. + """ + envs_idx = 0 if self._manager._sim.n_envs > 0 else None + + quat = self.link.get_quat(envs_idx=envs_idx).squeeze(0) + pos = self.link.get_pos(envs_idx=envs_idx).squeeze(0) + transform_by_quat(self.pos_offset, quat) + + data = self.read(envs_idx=envs_idx) + vec = data["lin_acc"].squeeze(0) * self._options.debug_acc_scale + vec = tensor_to_array(transform_by_quat(vec, transform_quat_by_quat(self.quat_offset, quat))) + + if self.debug_object is not None: + context.clear_debug_object(self.debug_object) + self.debug_object = context.draw_debug_arrow(pos=pos, vec=vec, color=self._options.debug_acc_color) diff --git a/genesis/sensors/sensor_manager.py b/genesis/sensors/sensor_manager.py index 33265ba9c..627545599 100644 --- a/genesis/sensors/sensor_manager.py +++ b/genesis/sensors/sensor_manager.py @@ -5,6 +5,8 @@ from genesis.utils.ring_buffer import TensorRingBuffer if TYPE_CHECKING: + from genesis.vis.rasterizer_context import RasterizerContext + from .base_sensor import Sensor, SensorOptions, SharedSensorMetadata @@ -84,6 +86,11 @@ def step(self): self._buffered_data[dtype][:, cache_slice], ) + def draw_debug(self, context: "RasterizerContext"): + for sensor in self.sensors: + if sensor._options.draw_debug: + sensor._draw_debug(context) + def reset(self, envs_idx=None): envs_idx = self._sim._scene._sanitize_envs_idx(envs_idx) for dtype in self._buffered_data.keys(): diff --git a/genesis/vis/rasterizer_context.py b/genesis/vis/rasterizer_context.py index d72910518..98563ecea 100644 --- a/genesis/vis/rasterizer_context.py +++ b/genesis/vis/rasterizer_context.py @@ -2,13 +2,10 @@ import torch import trimesh -import gstaichi as ti - import genesis as gs import genesis.utils.geom as gu import genesis.utils.mesh as mu import genesis.utils.particle as pu - from genesis.ext import pyrender from genesis.ext.pyrender.jit_render import JITRenderer from genesis.utils.misc import tensor_to_array, ti_to_numpy @@ -806,6 +803,9 @@ def update_fem(self, buffer_updates): update_data = self._scene.reorder_vertices(node, vertices) buffer_updates[self._scene.get_buffer_id(node, "pos")] = update_data + def update_sensors(self): + self.sim._sensor_manager.draw_debug(self) + def on_lights(self): for light in self.lights: self.add_light(light) @@ -949,6 +949,7 @@ def update(self): self.update_sph(self.buffer) self.update_pbd(self.buffer) self.update_fem(self.buffer) + self.update_sensors() def add_light(self, light): # light direction is light pose's -z frame diff --git a/tests/test_recorders.py b/tests/test_recorders.py index 9d3dffd05..46534eade 100644 --- a/tests/test_recorders.py +++ b/tests/test_recorders.py @@ -4,7 +4,6 @@ import pytest import genesis as gs -from genesis.utils.misc import tensor_to_array from genesis.utils.image_exporter import as_grayscale_image from .utils import assert_allclose, rgb_array_to_png_bytes @@ -66,7 +65,7 @@ def dummy_data_func(): plotter = scene.start_recording( data_func=dummy_data_func, - rec_options=gs.recorders.MPLPlot( + rec_options=gs.recorders.MPLLinePlot( labels={"a": ("x", "y", "z"), "b": ("u", "v")}, title="Test MPLPlotter", history_length=HISTORY_LENGTH, From 91bcc0a0e2e51fe177610f2282dd6ec361ee1fa7 Mon Sep 17 00:00:00 2001 From: Trinity Chung Date: Wed, 24 Sep 2025 19:02:20 -0400 Subject: [PATCH 2/4] add imu gyro draw_debug, update doc --- examples/sensors/imu_franka.py | 2 +- genesis/recorders/plotters.py | 9 ++------- genesis/sensors/imu.py | 30 +++++++++++++++++++++--------- 3 files changed, 24 insertions(+), 17 deletions(-) diff --git a/examples/sensors/imu_franka.py b/examples/sensors/imu_franka.py index a3c255cae..85df8c6f4 100644 --- a/examples/sensors/imu_franka.py +++ b/examples/sensors/imu_franka.py @@ -44,7 +44,7 @@ def main(): gs.sensors.IMU( entity_idx=franka.idx, link_idx_local=end_effector.idx_local, - pos_offset=(0.0, 0.0, 0.2), + pos_offset=(0.0, 0.0, 0.15), # noise parameters acc_axes_skew=(0.0, 0.01, 0.02), gyro_axes_skew=(0.03, 0.04, 0.05), diff --git a/genesis/recorders/plotters.py b/genesis/recorders/plotters.py index 6523a4f1c..d8c559547 100644 --- a/genesis/recorders/plotters.py +++ b/genesis/recorders/plotters.py @@ -652,6 +652,8 @@ class MPLImagePlotterOptions(BasePlotterOptions): class MPLImagePlotter(BaseMPLPlotter): """ Live image viewer using matplotlib. + + The image data should be an array-like object with shape (H, W), (H, W, 1), (H, W, 3), or (H, W, 4). """ def build(self): @@ -676,13 +678,6 @@ def process(self, data, cur_time): else: img_data = np.asarray(data) - if img_data.ndim == 3 and img_data.shape[0] == 1: - img_data = img_data[0] # remove batch dimension - # TODO: color images? - # elif img_data.ndim == 3 and img_data.shape[-1] in [1, 3, 4]: - elif img_data.ndim == 3 and img_data.shape[-1] == 1: - img_data = img_data.squeeze(-1) - vmin, vmax = np.min(img_data), np.max(img_data) current_vmin, current_vmax = self.image_plot.get_clim() diff --git a/genesis/sensors/imu.py b/genesis/sensors/imu.py index 86c812e05..153cf1f6a 100644 --- a/genesis/sensors/imu.py +++ b/genesis/sensors/imu.py @@ -128,9 +128,13 @@ class IMUOptions(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions draw_debug : bool, optional If True and the rasterizer visualization is active, an arrow for linear acceleration will be drawn. debug_acc_color : float, optional - The rgba color of the debug acceleration arrow. Defaults to (1.0, 0.0, 0.0, 0.5). + The rgba color of the debug acceleration arrow. Defaults to (0.0, 1.0, 1.0, 0.5). debug_acc_scale: float, optional The scale factor for the debug acceleration arrow. Defaults to 0.01. + debug_gyro_color : float, optional + The rgba color of the debug gyroscope arrow. Defaults to (1.0, 1.0, 0.0, 0.5). + debug_gyro_scale: float, optional + The scale factor for the debug gyroscope arrow. Defaults to 0.01. """ acc_resolution: MaybeTuple3FType = 0.0 @@ -144,8 +148,10 @@ class IMUOptions(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions acc_random_walk: MaybeTuple3FType = 0.0 gyro_random_walk: MaybeTuple3FType = 0.0 - debug_acc_color: tuple[float, float, float, float] = (1.0, 0.0, 0.0, 0.5) + debug_acc_color: tuple[float, float, float, float] = (0.0, 1.0, 1.0, 0.5) debug_acc_scale: float = 0.01 + debug_gyro_color: tuple[float, float, float, float] = (1.0, 1.0, 0.0, 0.5) + debug_gyro_scale: float = 0.01 def validate(self, scene): super().validate(scene) @@ -252,7 +258,7 @@ def build(self): ) if self._options.draw_debug: - self.debug_object = None + self.debug_objects = [None, None] self.quat_offset = self._shared_metadata.offsets_quat[0, self._idx] self.pos_offset = self._shared_metadata.offsets_pos[0, self._idx] @@ -334,9 +340,15 @@ def _draw_debug(self, context: "RasterizerContext"): pos = self.link.get_pos(envs_idx=envs_idx).squeeze(0) + transform_by_quat(self.pos_offset, quat) data = self.read(envs_idx=envs_idx) - vec = data["lin_acc"].squeeze(0) * self._options.debug_acc_scale - vec = tensor_to_array(transform_by_quat(vec, transform_quat_by_quat(self.quat_offset, quat))) - - if self.debug_object is not None: - context.clear_debug_object(self.debug_object) - self.debug_object = context.draw_debug_arrow(pos=pos, vec=vec, color=self._options.debug_acc_color) + acc_vec = data["lin_acc"].squeeze(0) * self._options.debug_acc_scale + gyro_vec = data["ang_vel"].squeeze(0) * self._options.debug_gyro_scale + # transform from local frame to world frame + offset_quat = transform_quat_by_quat(self.quat_offset, quat) + acc_vec = tensor_to_array(transform_by_quat(acc_vec, offset_quat)) + gyro_vec = tensor_to_array(transform_by_quat(gyro_vec, offset_quat)) + + for debug_object in self.debug_objects: + if debug_object is not None: + context.clear_debug_object(debug_object) + self.debug_objects[0] = context.draw_debug_arrow(pos=pos, vec=acc_vec, color=self._options.debug_acc_color) + self.debug_objects[1] = context.draw_debug_arrow(pos=pos, vec=gyro_vec, color=self._options.debug_gyro_color) From 8fdbdd65e15860b6c67f31829765daa681eaa45b Mon Sep 17 00:00:00 2001 From: Trinity Chung Date: Tue, 30 Sep 2025 01:13:57 -0400 Subject: [PATCH 3/4] address review comments 1 --- examples/sensors/contact_force_go2.py | 56 +++-- genesis/recorders/plotters.py | 312 ++++++++++++++------------ genesis/sensors/base_sensor.py | 12 +- genesis/sensors/contact_force.py | 72 +++--- genesis/sensors/imu.py | 20 +- tests/test_recorders.py | 10 +- 6 files changed, 260 insertions(+), 222 deletions(-) diff --git a/examples/sensors/contact_force_go2.py b/examples/sensors/contact_force_go2.py index 0b7d49ec8..9a9178822 100644 --- a/examples/sensors/contact_force_go2.py +++ b/examples/sensors/contact_force_go2.py @@ -46,46 +46,42 @@ def main(): for link_name in foot_link_names: if args.force: - force_sensor = scene.add_sensor( - gs.sensors.ContactForce( - entity_idx=go2.idx, - link_idx_local=go2.get_link(link_name).idx_local, - draw_debug=True, - ) + sensor_options = gs.sensors.ContactForce( + entity_idx=go2.idx, + link_idx_local=go2.get_link(link_name).idx_local, + draw_debug=True, + ) + plot_kwargs = dict( + title=f"{link_name} Force Sensor Data", + labels=["force_x", "force_y", "force_z"], ) - if IS_PYQTGRAPH_AVAILABLE: - force_sensor.start_recording( - gs.recorders.PyQtLinePlot( - title="Force Sensor Data", - labels=["force_x", "force_y", "force_z"], - ) - ) - elif IS_MATPLOTLIB_AVAILABLE: - print("pyqtgraph not found, falling back to matplotlib.") - force_sensor.start_recording( - gs.recorders.MPLLinePlot( - title="Force Sensor Data", - labels=["force_x", "force_y", "force_z"], - ) - ) - else: - print("matplotlib or pyqtgraph not found, skipping real-time plotting.") else: - contact_sensor = scene.add_sensor( - gs.sensors.Contact( - entity_idx=go2.idx, - link_idx_local=go2.get_link(link_name).idx_local, - draw_debug=True, - ) + sensor_options = gs.sensors.Contact( + entity_idx=go2.idx, + link_idx_local=go2.get_link(link_name).idx_local, + draw_debug=True, + ) + plot_kwargs = dict( + title=f"{link_name} Contact Sensor Data", + labels=["in_contact"], ) + sensor = scene.add_sensor(sensor_options) + + if IS_PYQTGRAPH_AVAILABLE: + sensor.start_recording(gs.recorders.PyQtLinePlot(**plot_kwargs)) + elif IS_MATPLOTLIB_AVAILABLE: + print("pyqtgraph not found, falling back to matplotlib.") + sensor.start_recording(gs.recorders.MPLLinePlot(**plot_kwargs)) + else: + print("matplotlib or pyqtgraph not found, skipping real-time plotting.") + scene.build() try: steps = int(args.seconds / args.timestep) for _ in tqdm(range(steps)): scene.step() - except KeyboardInterrupt: gs.logger.info("Simulation interrupted, exiting.") finally: diff --git a/genesis/recorders/plotters.py b/genesis/recorders/plotters.py index d8c559547..01a2a7081 100644 --- a/genesis/recorders/plotters.py +++ b/genesis/recorders/plotters.py @@ -6,7 +6,7 @@ from collections.abc import Sequence from dataclasses import dataclass from functools import partial -from typing import Any, Generic, TypeVar +from typing import Any, Callable, T import numpy as np import torch @@ -16,7 +16,7 @@ from genesis.utils import has_display, tensor_to_array from .base_recorder import Recorder, RecorderOptions -from .recorder_manager import register_recording +from .recorder_manager import RecorderManager, register_recording IS_PYQTGRAPH_AVAILABLE = False try: @@ -68,15 +68,64 @@ class BasePlotterOptions(RecorderOptions): class BasePlotter(Recorder): + def __init__(self, manager: "RecorderManager", options: RecorderOptions, data_func: Callable[[], T]): + super().__init__(manager, options, data_func) + self._frames_buffer: list[np.ndarray] = [] + def build(self): super().build() self.show_window = self._options.show_window if self._options.show_window is not None else has_display() + self.video_writer = None + if self._options.save_to_filename: + + def _get_video_frame_buffer(plotter): + # Make sure that all the data in the pipe has been processed before rendering anything + if not plotter._frames_buffer: + if plotter._data_queue is not None and not plotter._data_queue.empty(): + while not plotter._frames_buffer: + time.sleep(0.1) + + return plotter._frames_buffer.pop(0) + + self.video_writer = self._manager.add_recorder( + data_func=partial(_get_video_frame_buffer, self), + rec_options=gs.recorders.VideoFile( + filename=self._options.save_to_filename, + hz=self._options.hz, + ), + ) + def process(self, data, cur_time): - pass # allow super() calls + # Update plot + self._update_plot() + + # Render frame if necessary + if self._options.save_to_filename: + self._frames_buffer.append(self.get_image_array()) def cleanup(self): - pass # allow super() calls + if self.video_writer is not None: + self.video_writer.stop() + self._frames_buffer.clear() + self.video_writer = None + + def _update_plot(self): + """ + Update plot. + """ + raise NotImplementedError(f"[{type(self).__name__}] _update_plot() is not implemented.") + + def get_image_array(self): + """ + Capture the plot image as a video frame. + + Returns + ------- + image_array : np.ndarray + The RGB image as a numpy array. + """ + raise NotImplementedError(f"[{type(self).__name__}] get_image_array() is not implemented.") @dataclass @@ -107,94 +156,70 @@ class LinePlotterMixinOptions: history_length: int = 100 -LinePlotterOptionsMixinT = TypeVar("LinePlotterOptionsMixinT", bound=LinePlotterMixinOptions) - - -class LinePlotterMixin(Generic[LinePlotterOptionsMixinT]): - """Base class for real-time plotters with shared functionality.""" +class LinePlotHelper: + """ + Helper class that manages line plot data. - def build(self): - super().build() + Use composition pattern. + """ - self.x_data: list[float] = [] - self.y_data: defaultdict[str, defaultdict[str, list[float]]] = defaultdict(lambda: defaultdict(list)) + def __init__(self, options: LinePlotterMixinOptions, data: dict[str, Sequence] | Sequence): + self._x_data: list[float] = [] + self._y_data: defaultdict[str, defaultdict[str, list[float]]] = defaultdict(lambda: defaultdict(list)) + self._history_length = options.history_length # Note that these attributes will be set during first data processing or initialization - self.is_dict_data: bool | None = None - self.subplot_structure: dict[str, tuple[str, ...]] | None = None - - if self._options.labels is not None: - self._setup_plot_structure(self._options.labels) + self._is_dict_data: bool | None = None + self._subplot_structure: dict[str, tuple[str, ...]] = {} + + if isinstance(data, dict): + self._is_dict_data = True + + if options.labels is not None: + assert isinstance( + options.labels, dict + ), f"[{type(self).__name__}] Labels must be a dict when data is a dict" + assert set(options.labels.keys()) == set( + data.keys() + ), f"[{type(self).__name__}] Label keys must match data keys" + + for key in data.keys(): + data_values = _data_to_array(data[key]) + label_values = options.labels[key] + assert len(label_values) == len( + data_values + ), f"[{type(self).__name__}] Label count must match data count for key '{key}'" + self._subplot_structure[key] = tuple(label_values) + else: + self._subplot_structure = {} + for key, values in data.items(): + values = _data_to_array(values) + self._subplot_structure[key] = tuple(f"{key}_{i}" for i in range(len(values))) else: - self._setup_plot_structure(self._data_func()) - - self.video_writer = None - if self._options.save_to_filename: - - def _get_video_frame_buffer(plotter): - # Make sure that all the data in the pipe has been processed before rendering anything - if not plotter._frames_buffer: - if plotter._data_queue is not None and not plotter._data_queue.empty(): - while not plotter._frames_buffer: - time.sleep(0.1) - - return plotter._frames_buffer.pop(0) - - self.video_writer = self._manager.add_recorder( - data_func=partial(_get_video_frame_buffer, self), - rec_options=gs.recorders.VideoFile( - filename=self._options.save_to_filename, - hz=self._options.hz, - ), - ) - self._frames_buffer: list[np.ndarray] = [] - - def reset(self, envs_idx=None): - super().reset(envs_idx) - - # no envs specific resetting supported - self.x_data.clear() - self.y_data.clear() + self._is_dict_data = False + data = _data_to_array(data) - def cleanup(self): - """Clean up resources.""" - if self.video_writer is not None: - self.video_writer.stop() - self._frames_buffer.clear() - self.video_writer = None + if options.labels is not None: + if not isinstance(options.labels, Sequence): + options.labels = (options.labels,) + assert len(options.labels) == len(data), f"[{type(self).__name__}] Label count must match data count" + plot_labels = tuple(options.labels) + else: + plot_labels = tuple(f"data_{i}" for i in range(len(data))) - def _setup_plot_structure(self, labels_or_data: dict[str, Sequence] | Sequence): - """Set up the plot structure based on labels or first data sample.""" - if isinstance(labels_or_data, dict): - self.is_dict_data = True - next_dict_value = next(iter(labels_or_data.values())) + self._subplot_structure = {"main": plot_labels} - if isinstance(next_dict_value, (torch.Tensor, np.ndarray)): - # data was provided - self.subplot_structure = {} - for key, values in labels_or_data.items(): - values = _data_to_array(values) - self.subplot_structure[key] = tuple(f"{key}_{i}" for i in range(len(values))) - elif isinstance(next_dict_value, Sequence) and isinstance(next_dict_value[0], str): - # labels were provided - self.subplot_structure = {k: tuple(v) for k, v in labels_or_data.items()} - else: - gs.raise_exception(f"[{type(self).__name__}] Unsupported input argument type: {type(labels_or_data)}") - else: - self.is_dict_data = False - if not isinstance(labels_or_data, Sequence): - labels_or_data = (labels_or_data,) - if isinstance(labels_or_data[0], (int, float, np.number)): - labels_or_data = [f"data_{i}" for i in range(len(labels_or_data))] - self.subplot_structure = {"main": tuple(labels_or_data)} + def clear_data(self): + self._x_data.clear() + self._y_data.clear() def process(self, data, cur_time): """Process new data point and update plot.""" - if self.is_dict_data: + if self._is_dict_data: processed_data = {} for key, values in data.items(): - if key not in self.subplot_structure: + if key not in self._subplot_structure: continue # skip keys not included in subplot structure values = _data_to_array(values) processed_data[key] = values @@ -203,11 +228,11 @@ def process(self, data, cur_time): processed_data = {"main": data} # Update time data - self.x_data.append(cur_time) + self._x_data.append(cur_time) # Update y data for each subplot for subplot_key, subplot_data in processed_data.items(): - channel_labels = self.subplot_structure[subplot_key] + channel_labels = self._subplot_structure[subplot_key] if len(subplot_data) != len(channel_labels): gs.logger.warning( f"[{type(self).__name__}] Data length ({len(subplot_data)}) doesn't match " @@ -217,41 +242,33 @@ def process(self, data, cur_time): for i, channel_label in enumerate(channel_labels): if i < len(subplot_data): - self.y_data[subplot_key][channel_label].append(float(subplot_data[i])) + self._y_data[subplot_key][channel_label].append(float(subplot_data[i])) # Maintain rolling history window - if len(self.x_data) > self._options.history_length: - self.x_data.pop(0) - for subplot_key in self.y_data: - for channel_label in self.y_data[subplot_key]: + if len(self._x_data) > self._history_length: + self._x_data.pop(0) + for subplot_key in self._y_data: + for channel_label in self._y_data[subplot_key]: try: - self.y_data[subplot_key][channel_label].pop(0) + self._y_data[subplot_key][channel_label].pop(0) except IndexError: break # empty, nothing to do. - # Update plot - self._update_plot() - - # Render frame if necessary - if self._options.save_to_filename: - self._frames_buffer.append(self.get_image_array()) + @property + def x_data(self): + return self._x_data - def _update_plot(self): - """ - Update plot. - """ - raise NotImplementedError + @property + def y_data(self): + return self._y_data - def get_image_array(self): - """ - Capture the plot image as a video frame. + @property + def is_dict_data(self): + return self._is_dict_data - Returns - ------- - image_array : np.ndarray - The RGB image as a numpy array. - """ - raise NotImplementedError + @property + def subplot_structure(self): + return self._subplot_structure class BasePyQtPlotter(BasePlotter): @@ -304,22 +321,23 @@ def get_image_array(self): Returns ------- image_array : np.ndarray - The RGB image as a numpy array. + The image as a numpy array in (b,g,r,a) format. """ pixmap = self.widget.grab() qimage = pixmap.toImage() - qimage = qimage.convertToFormat(pg.QtGui.QImage.Format_RGB888) - ptr = qimage.bits() - ptr.setsize(qimage.byteCount()) - - return np.array(ptr).reshape((qimage.height(), qimage.width(), 3)) + # pyqtgraph provides imageToArray but it always outputs (b,g,r,a) format + # https://pyqtgraph.readthedocs.io/en/latest/api_reference/functions.html#pyqtgraph.functions.imageToArray + return pg.imageToArray(qimage, copy=True, transpose=True) class PyQtLinePlotterOptions(BasePlotterOptions, LinePlotterMixinOptions): """ Live line plot visualization of data using PyQtGraph. + The recorded data_func should return scalar data (single scalar, a tuple of scalars, or a dict with string keys and + scalar or tuple of scalars as values). + Parameters ---------- title: str @@ -346,23 +364,21 @@ class PyQtLinePlotterOptions(BasePlotterOptions, LinePlotterMixinOptions): @register_recording(PyQtLinePlotterOptions) -class PyQtLinePlotter(LinePlotterMixin[PyQtLinePlotterOptions], BasePyQtPlotter): - """ - Real-time plot using PyQt for live sensor data visualization. - """ +class PyQtLinePlotter(BasePyQtPlotter): def build(self): super().build() + self.line_plot = LinePlotHelper(options=self._options, data=self._data_func()) self.curves: dict[str, list[pg.PlotCurveItem]] = {} # create plots for each subplot - for subplot_idx, (subplot_key, channel_labels) in enumerate(self.subplot_structure.items()): + for subplot_idx, (subplot_key, channel_labels) in enumerate(self.line_plot.subplot_structure.items()): # add new row if not the first plot if subplot_idx > 0: self.widget.nextRow() - plot_widget = self.widget.addPlot(title=subplot_key if self.is_dict_data else self._options.title) + plot_widget = self.widget.addPlot(title=subplot_key if self.line_plot.is_dict_data else self._options.title) plot_widget.setLabel("bottom", self._options.x_label) plot_widget.setLabel("left", self._options.y_label) plot_widget.showGrid(x=True, y=True, alpha=0.3) @@ -379,20 +395,23 @@ def build(self): self.curves[subplot_key] = subplot_curves + def process(self, data, cur_time): + self.line_plot.process(data, cur_time) + super().process(data, cur_time) + def _update_plot(self): # update all curves for subplot_key, curves in self.curves.items(): - channel_labels = self.subplot_structure[subplot_key] + channel_labels = self.line_plot.subplot_structure[subplot_key] for curve, channel_label in zip(curves, channel_labels): - y_data = self.y_data[subplot_key][channel_label] - curve.setData(x=self.x_data, y=y_data) + curve.setData(x=self.line_plot.x_data, y=self.line_plot.y_data[subplot_key][channel_label]) if self.app: self.app.processEvents() def cleanup(self): super().cleanup() - + self.line_plot.clear_data() self.curves.clear() @@ -443,10 +462,6 @@ def cleanup(self): finally: self.fig = None - @property - def run_in_thread(self) -> bool: - return not self.show_window or gs.platform != "macOS" - def get_image_array(self): """ Capture the plot image as a video frame. @@ -483,11 +498,19 @@ def get_image_array(self): return rgb_array + @property + def run_in_thread(self) -> bool: + # matplotlib throws NSInternalInconsistencyException when trying to use threading for visualization on macOS + return not self.show_window or gs.platform != "macOS" + class MPLLinePlotterOptions(BasePlotterOptions, LinePlotterMixinOptions): """ Live line plot visualization of data using matplotlib. + The recorded data_func should return scalar data (single scalar, a tuple of scalars, or a dict with string keys and + scalar or tuple of scalars as values). + Parameters ---------- title: str @@ -514,16 +537,13 @@ class MPLLinePlotterOptions(BasePlotterOptions, LinePlotterMixinOptions): @register_recording(MPLLinePlotterOptions) -class MPLLinePlotter(LinePlotterMixin[MPLLinePlotterOptions], BaseMPLPlotter): - """ - Real-time plot using matplotlib for live sensor data visualization. - - Inherits common plotting functionality from BasePlotter. - """ +class MPLLinePlotter(BaseMPLPlotter): def build(self): super().build() + self.line_plot = LinePlotHelper(options=self._options, data=self._data_func()) + import matplotlib.pyplot as plt self.axes: list[plt.Axes] = [] @@ -531,7 +551,7 @@ def build(self): self.backgrounds: list[Any] = [] # Create figure and subplots - n_subplots = len(self.subplot_structure) + n_subplots = len(self.line_plot.subplot_structure) if n_subplots == 1: self.fig, ax = plt.subplots(figsize=self.figsize) self.axes = [ax] @@ -541,13 +561,13 @@ def build(self): self.fig.suptitle(self._options.title) # Create lines for each subplot - for subplot_idx, (subplot_key, channel_labels) in enumerate(self.subplot_structure.items()): + for subplot_idx, (subplot_key, channel_labels) in enumerate(self.line_plot.subplot_structure.items()): ax = self.axes[subplot_idx] ax.set_xlabel(self._options.x_label) ax.set_ylabel(self._options.y_label) ax.grid(True, alpha=0.3) - if self.is_dict_data and n_subplots > 1: + if self.line_plot.is_dict_data and n_subplots > 1: ax.set_title(subplot_key) subplot_lines = [] @@ -567,6 +587,10 @@ def build(self): self._show_fig() + def process(self, data, cur_time): + self.line_plot.process(data, cur_time) + super().process(data, cur_time) + def _update_plot(self): self._lock.acquire() @@ -576,8 +600,8 @@ def _update_plot(self): # Check if axis limits need updating for this subplot limits_changed = False - if self.x_data: - x_min, x_max = min(self.x_data), max(self.x_data) + if self.line_plot.x_data: + x_min, x_max = min(self.line_plot.x_data), max(self.line_plot.x_data) x_range = x_max - x_min if x_range == 0: x_range = 1 @@ -588,8 +612,8 @@ def _update_plot(self): # Update y limits based on all data in this subplot all_y_values = [] - for channel_label in self.y_data[subplot_key]: - all_y_values.extend(self.y_data[subplot_key][channel_label]) + for channel_label in self.line_plot.y_data[subplot_key]: + all_y_values.extend(self.line_plot.y_data[subplot_key][channel_label]) if all_y_values: y_min, y_max = min(all_y_values), max(all_y_values) @@ -610,10 +634,10 @@ def _update_plot(self): self.fig.canvas.restore_region(self.backgrounds[subplot_idx]) # Update lines - channel_labels = self.subplot_structure[subplot_key] + channel_labels = self.line_plot.subplot_structure[subplot_key] for line, channel_label in zip(subplot_lines, channel_labels): - y_data = self.y_data[subplot_key][channel_label] - line.set_data(self.x_data, y_data) + y_data = self.line_plot.y_data[subplot_key][channel_label] + line.set_data(self.line_plot.x_data, y_data) ax.draw_artist(line) # Blit the updated subplot @@ -624,7 +648,7 @@ def _update_plot(self): def cleanup(self): super().cleanup() - + self.line_plot.clear_data() self.lines.clear() self.backgrounds.clear() @@ -633,6 +657,8 @@ class MPLImagePlotterOptions(BasePlotterOptions): """ Live visualization of image data using matplotlib. + The image data should be an array-like object with shape (H, W), (H, W, 1), (H, W, 3), or (H, W, 4). + Parameters ---------- title: str diff --git a/genesis/sensors/base_sensor.py b/genesis/sensors/base_sensor.py index b02f35a6e..3c17944de 100644 --- a/genesis/sensors/base_sensor.py +++ b/genesis/sensors/base_sensor.py @@ -15,6 +15,7 @@ from genesis.utils.misc import concat_with_tensor, make_tensor_field if TYPE_CHECKING: + from genesis.engine.entities.rigid_entity.rigid_link import RigidLink from genesis.recorders.base_recorder import Recorder, RecorderOptions from genesis.utils.ring_buffer import TensorRingBuffer from genesis.vis.rasterizer_context import RasterizerContext @@ -389,6 +390,10 @@ class RigidSensorMixin(Generic[RigidSensorMetadataMixinT]): Base sensor class for sensors that are attached to a RigidEntity. """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._link: "RigidLink" | None = None + def build(self): super().build() @@ -398,9 +403,10 @@ def build(self): batch_size = self._manager._sim._B entity = self._shared_metadata.solver.entities[self._options.entity_idx] - self.link_idx = self._options.link_idx_local + entity.link_start - self.link = entity.links[self._options.link_idx_local] - self._shared_metadata.links_idx = concat_with_tensor(self._shared_metadata.links_idx, self.link_idx) + self._link = entity.links[self._options.link_idx_local] + self._shared_metadata.links_idx = concat_with_tensor( + self._shared_metadata.links_idx, self._options.link_idx_local + entity.link_start + ) self._shared_metadata.offsets_pos = concat_with_tensor( self._shared_metadata.offsets_pos, self._options.pos_offset, diff --git a/genesis/sensors/contact_force.py b/genesis/sensors/contact_force.py index aae9e4c84..255819dbb 100644 --- a/genesis/sensors/contact_force.py +++ b/genesis/sensors/contact_force.py @@ -7,7 +7,7 @@ import genesis as gs from genesis.engine.solvers import RigidSolver -from genesis.utils.geom import ti_inv_transform_by_quat, transform_by_quat +from genesis.utils.geom import ti_inv_transform_by_quat, trans_to_T, transform_by_quat from genesis.utils.misc import concat_with_tensor, make_tensor_field, tensor_to_array from .base_sensor import ( @@ -26,9 +26,13 @@ from .sensor_manager import register_sensor if TYPE_CHECKING: + from genesis.engine.entities.rigid_entity.rigid_link import RigidLink + from genesis.ext.pyrender.mesh import Mesh from genesis.utils.ring_buffer import TensorRingBuffer from genesis.vis.rasterizer_context import RasterizerContext + from .sensor_manager import SensorManager + @ti.kernel def _kernel_get_contacts_forces( @@ -81,7 +85,7 @@ class ContactSensorOptions(RigidSensorOptionsMixin, SensorOptions): update_ground_truth_only : bool If True, the sensor will only update the ground truth data, and not the measured data. draw_debug : bool, optional - If True and the rasterizer visualization is active, a sphere will be drawn at the sensor's position. + If True and the interactive viewer is active, a sphere will be drawn at the sensor's position. debug_sphere_radius : float, optional The radius of the debug sphere. Defaults to 0.05. debug_color : float, optional @@ -109,22 +113,24 @@ class ContactSensor(Sensor): Sensor that returns bool based on whether associated RigidLink is in contact. """ + def __init__(self, sensor_options: ContactSensorOptions, sensor_idx: int, sensor_manager: "SensorManager"): + super().__init__(sensor_options, sensor_idx, sensor_manager) + self._link: "RigidLink" | None = None + self.debug_object: "Mesh" | None = None + def build(self): super().build() if self._shared_metadata.solver is None: self._shared_metadata.solver = self._manager._sim.rigid_solver entity = self._shared_metadata.solver.entities[self._options.entity_idx] - self.link_idx = self._options.link_idx_local + entity.link_start - self.link = entity.links[self._options.link_idx_local] + link_idx = self._options.link_idx_local + entity.link_start + self._link = entity.links[self._options.link_idx_local] self._shared_metadata.expanded_links_idx = concat_with_tensor( - self._shared_metadata.expanded_links_idx, self.link_idx, expand=(1,), dim=0 + self._shared_metadata.expanded_links_idx, link_idx, expand=(1,), dim=0 ) - if self._options.draw_debug: - self.debug_object = None - def _get_return_format(self) -> tuple[int, ...]: return (1,) @@ -160,20 +166,23 @@ def _draw_debug(self, context: "RasterizerContext"): """ Draw debug sphere when the sensor detects contact. - Only draws for first environment. + Only draws for first rendered environment. """ - envs_idx = 0 if self._manager._sim.n_envs > 0 else None - - pos = self.link.get_pos(envs_idx=envs_idx).squeeze(0) - is_contact = self.read(envs_idx=envs_idx).squeeze(0).item() + env_idx = context.rendered_envs_idx[0] - if self.debug_object is not None: - context.clear_debug_object(self.debug_object) + pos = self._link.get_pos(envs_idx=env_idx)[0] + is_contact = self.read(envs_idx=env_idx if self._manager._sim.n_envs > 0 else None).item() if is_contact: - self.debug_object = context.draw_debug_sphere( - pos=pos, radius=self._options.debug_sphere_radius, color=self._options.debug_color - ) + if self.debug_object is None: + self.debug_object = context.draw_debug_sphere( + pos=pos, radius=self._options.debug_sphere_radius, color=self._options.debug_color + ) + else: + context.update_debug_objects([self.debug_object], trans_to_T(pos).unsqueeze(0)) + elif self.debug_object is not None: + context.clear_debug_object(self.debug_object) + self.debug_object = None # ========================================================================================================== @@ -190,9 +199,9 @@ class ContactForceSensorOptions(RigidSensorOptionsMixin, NoisySensorOptionsMixin link_idx_local : int, optional The local index of the RigidLink of the RigidEntity to which this sensor is attached. min_force : float | tuple[float, float, float], optional - The minimum detectable force per each axis. Values below this will be treated as 0. Default is 0. + The minimum detectable absolute force per each axis. Values below this will be treated as 0. Default is 0. max_force : float | tuple[float, float, float], optional - The maximum output force per each axis. Values above this will be clipped. Default is infinity. + The maximum output absolute force per each axis. Values above this will be clipped. Default is infinity. resolution : float | tuple[float, float, float], optional The measurement resolution of each axis of force (smallest increment of change in the sensor reading). Default is 0.0, which means no quantization is applied. @@ -213,7 +222,7 @@ class ContactForceSensorOptions(RigidSensorOptionsMixin, NoisySensorOptionsMixin update_ground_truth_only : bool, optional If True, the sensor will only update the ground truth data, and not the measured data. draw_debug : bool, optional - If True and the rasterizer visualization is active, an arrow for the contact force will be drawn. + If True and the interactive viewer is active, an arrow for the contact force will be drawn. debug_color : float, optional The rgba color of the debug arrow. Defaults to (1.0, 0.0, 1.0, 0.5). debug_scale : float, optional @@ -267,6 +276,10 @@ class ContactForceSensor( Sensor that returns the total contact force being applied to the associated RigidLink in its local frame. """ + def __init__(self, sensor_options: ContactForceSensorOptions, sensor_idx: int, sensor_manager: "SensorManager"): + super().__init__(sensor_options, sensor_idx, sensor_manager) + self.debug_object: "Mesh" | None = None + def build(self): if not (isinstance(self._options.resolution, tuple) and len(self._options.resolution) == 3): self._options.resolution = tuple([self._options.resolution] * 3) @@ -285,9 +298,6 @@ def build(self): _to_tuple(self._options.max_force, length_per_value=3), ) - if self._options.draw_debug: - self.debug_object = None - def _get_return_format(self) -> tuple[int, ...]: return (3,) @@ -352,16 +362,16 @@ def _draw_debug(self, context: "RasterizerContext"): """ Draw debug arrow representing the contact force. - Only draws for first environment. + Only draws for first rendered environment. """ - envs_idx = 0 if self._manager._sim.n_envs > 0 else None + env_idx = context.rendered_envs_idx[0] - pos = self.link.get_pos(envs_idx=envs_idx).squeeze(0) - quat = self.link.get_quat(envs_idx=envs_idx).squeeze(0) + pos = self._link.get_pos(envs_idx=env_idx) + quat = self._link.get_quat(envs_idx=env_idx) - force = self.read(envs_idx=envs_idx) - vec = tensor_to_array(transform_by_quat(force.squeeze(0) * self._options.debug_scale, quat)) + force = self.read(envs_idx=env_idx if self._manager._sim.n_envs > 0 else None) + vec = tensor_to_array(transform_by_quat(force * self._options.debug_scale, quat)) if self.debug_object is not None: context.clear_debug_object(self.debug_object) - self.debug_object = context.draw_debug_arrow(pos=pos, vec=vec, color=self._options.debug_color) + self.debug_object = context.draw_debug_arrow(pos=pos[0], vec=vec[0], color=self._options.debug_color) diff --git a/genesis/sensors/imu.py b/genesis/sensors/imu.py index 153cf1f6a..9f17cbcc6 100644 --- a/genesis/sensors/imu.py +++ b/genesis/sensors/imu.py @@ -126,7 +126,7 @@ class IMUOptions(RigidSensorOptionsMixin, NoisySensorOptionsMixin, SensorOptions update_ground_truth_only : bool, optional If True, the sensor will only update the ground truth data, and not the measured data. draw_debug : bool, optional - If True and the rasterizer visualization is active, an arrow for linear acceleration will be drawn. + If True and the interactive viewer is active, an arrow for linear acceleration will be drawn. debug_acc_color : float, optional The rgba color of the debug acceleration arrow. Defaults to (0.0, 1.0, 1.0, 0.5). debug_acc_scale: float, optional @@ -332,16 +332,16 @@ def _draw_debug(self, context: "RasterizerContext"): """ Draw debug arrow for the IMU acceleration. - Only draws for first environment. + Only draws for first rendered environment. """ - envs_idx = 0 if self._manager._sim.n_envs > 0 else None + env_idx = context.rendered_envs_idx[0] - quat = self.link.get_quat(envs_idx=envs_idx).squeeze(0) - pos = self.link.get_pos(envs_idx=envs_idx).squeeze(0) + transform_by_quat(self.pos_offset, quat) + quat = self._link.get_quat(envs_idx=env_idx) + pos = self._link.get_pos(envs_idx=env_idx) + transform_by_quat(self.pos_offset, quat) - data = self.read(envs_idx=envs_idx) - acc_vec = data["lin_acc"].squeeze(0) * self._options.debug_acc_scale - gyro_vec = data["ang_vel"].squeeze(0) * self._options.debug_gyro_scale + data = self.read(envs_idx=env_idx) + acc_vec = data["lin_acc"] * self._options.debug_acc_scale + gyro_vec = data["ang_vel"] * self._options.debug_gyro_scale # transform from local frame to world frame offset_quat = transform_quat_by_quat(self.quat_offset, quat) acc_vec = tensor_to_array(transform_by_quat(acc_vec, offset_quat)) @@ -350,5 +350,5 @@ def _draw_debug(self, context: "RasterizerContext"): for debug_object in self.debug_objects: if debug_object is not None: context.clear_debug_object(debug_object) - self.debug_objects[0] = context.draw_debug_arrow(pos=pos, vec=acc_vec, color=self._options.debug_acc_color) - self.debug_objects[1] = context.draw_debug_arrow(pos=pos, vec=gyro_vec, color=self._options.debug_gyro_color) + self.debug_objects[0] = context.draw_debug_arrow(pos=pos[0], vec=acc_vec, color=self._options.debug_acc_color) + self.debug_objects[1] = context.draw_debug_arrow(pos=pos[0], vec=gyro_vec, color=self._options.debug_gyro_color) diff --git a/tests/test_recorders.py b/tests/test_recorders.py index 46534eade..82df309c6 100644 --- a/tests/test_recorders.py +++ b/tests/test_recorders.py @@ -84,14 +84,14 @@ def dummy_data_func(): if plotter.run_in_thread: plotter.sync() - assert call_count == STEPS // 2 - assert len(plotter.x_data) == HISTORY_LENGTH - assert np.isclose(plotter.x_data[-1], STEPS * DT, atol=gs.EPS) + assert call_count == STEPS // 2 + 1 # one additional call during plot setup + assert len(plotter.line_plot.x_data) == HISTORY_LENGTH + assert np.isclose(plotter.line_plot.x_data[-1], STEPS * DT, atol=gs.EPS) assert rgb_array_to_png_bytes(plotter.get_image_array()) == png_snapshot assert len(buffers) == 5 - assert_allclose([cur_time for data, cur_time in buffers], np.arange(STEPS + 1)[::2][1:] * DT, tol=gs.EPS) - for rgb_diff in np.diff([data for data, cur_time in buffers], axis=0): + assert_allclose([cur_time for _, cur_time in buffers], np.arange(STEPS + 1)[::2][1:] * DT, tol=gs.EPS) + for rgb_diff in np.diff([data for data, _ in buffers], axis=0): assert rgb_diff.max() > 10.0 # Intentionally do not stop the recording to test the destructor From c42c67616e132717350aa572aef11c524257deb5 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Tue, 30 Sep 2025 10:20:57 +0200 Subject: [PATCH 4/4] Update snapshot. --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index f16ab2d36..9849f0c6a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -32,7 +32,7 @@ DEFAULT_BRANCH_NAME = "main" HUGGINGFACE_ASSETS_REVISION = "4d96c3512df4421d4dd3d626055d0d1ebdfdd7cc" -HUGGINGFACE_SNAPSHOT_REVISION = "74f5b178fb96dfa17a05d98585af8e212db9b4e6" +HUGGINGFACE_SNAPSHOT_REVISION = "a6fd3b99364b927dd5367488e58cd251f254fa94" MESH_EXTENSIONS = (".mtl", *MESH_FORMATS, *GLTF_FORMATS, *USD_FORMATS) IMAGE_EXTENSIONS = (".png", ".jpg")