diff --git a/mujoco_playground/_src/mjx_env.py b/mujoco_playground/_src/mjx_env.py index b17413211..9a25a438a 100644 --- a/mujoco_playground/_src/mjx_env.py +++ b/mujoco_playground/_src/mjx_env.py @@ -343,7 +343,10 @@ def get_image(state, modify_scn_fn=None) -> np.ndarray: mujoco.mj_forward(mj_model, d) renderer.update_scene(d, camera=camera, scene_option=scene_option) if modify_scn_fn is not None: - modify_scn_fn(renderer.scene) + if modify_scene_fn.__code__.co_argcount == 1: + modify_scn_fn(renderer.scene) + elif modify_scene_fn.__code__.co_argcount == 2: + modify_scn_fn(renderer.scene, state) return renderer.render() if isinstance(trajectory, list):