def denormalize_action(
    action: torch.Tensor,
    method: str,
    stats: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Map a normalized action tensor back to raw action units.

    This is the inverse of ``normalize_action`` for the same ``method`` and
    ``stats``. The scale term is clamped to ``1e-8`` in the same way as in
    ``normalize_action``, so the round trip also holds for dimensions whose
    range (or standard deviation) is zero. Without the clamp such dimensions
    are multiplied by zero and collapse to the offset.

    Args:
        action: Normalized action values; must broadcast against the stats.
        method: One of ``"quantile"``, ``"meanstd"`` or ``"minmax"``.
        stats: Statistics tensors. ``"quantile"`` reads ``q01``/``q99``,
            ``"meanstd"`` reads ``mean``/``std`` and ``"minmax"`` reads
            ``min``/``max``.

    Returns:
        The action expressed in raw (un-normalized) units.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == "quantile":
        q01, q99 = stats["q01"], stats["q99"]
        # Same floor as the forward pass so both directions share one scale.
        span = (q99 - q01).clamp(min=1e-8)
        return 0.5 * (action + 1.0) * span + q01
    if method == "meanstd":
        return action * stats["std"].clamp(min=1e-8) + stats["mean"]
    if method == "minmax":
        lo, hi = stats["min"], stats["max"]
        span = (hi - lo).clamp(min=1e-8)
        return 0.5 * (action + 1.0) * span + lo
    raise ValueError(f"Unknown normalization method: {method!r}")
set(result) == set(_RAW_STATS) + for key, value in result.items(): + assert isinstance(value, np.ndarray) + assert value.dtype == np.float32 + np.testing.assert_array_equal(value, np.array(_RAW_STATS[key], dtype=np.float32)) + + +def test_load_action_stats_filters_unknown_keys(tmp_path): + raw = {**_RAW_STATS, "extra_field": [1.0, 2.0]} + p = tmp_path / "stats.json" + p.write_text(json.dumps(raw)) + result = load_action_stats(str(p)) + assert "extra_field" not in result + + +def test_load_action_stats_missing_file(): + with pytest.raises(FileNotFoundError): + load_action_stats("/nonexistent/path/stats.json") + + +# --------------------------------------------------------------------------- +# normalize_action / denormalize_action — round-trip identity +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("method", ["quantile", "meanstd", "minmax"]) +def test_round_trip(method): + action = _action() + stats = _tensor_stats() + normalized = normalize_action(action, method, stats) + recovered = denormalize_action(normalized, method, stats) + torch.testing.assert_close(recovered, action, atol=1e-5, rtol=1e-5) + + +# --------------------------------------------------------------------------- +# normalize_action — endpoint correctness +# --------------------------------------------------------------------------- + + +def test_normalize_quantile_endpoints(): + stats = _tensor_stats() + q01, q99 = stats["q01"], stats["q99"] + assert torch.allclose(normalize_action(q01.unsqueeze(0), "quantile", stats), torch.full((1, 3), -1.0)) + assert torch.allclose(normalize_action(q99.unsqueeze(0), "quantile", stats), torch.full((1, 3), 1.0)) + + +def test_normalize_minmax_endpoints(): + stats = _tensor_stats() + lo, hi = stats["min"], stats["max"] + assert torch.allclose(normalize_action(lo.unsqueeze(0), "minmax", stats), torch.full((1, 3), -1.0)) + assert torch.allclose(normalize_action(hi.unsqueeze(0), "minmax", stats), 
torch.full((1, 3), 1.0)) + + +def test_normalize_meanstd_zero_mean(): + stats = _tensor_stats() + result = normalize_action(stats["mean"].unsqueeze(0), "meanstd", stats) + assert torch.allclose(result, torch.zeros(1, 3)) + + +# --------------------------------------------------------------------------- +# denormalize_action — endpoint correctness +# --------------------------------------------------------------------------- + + +def test_denormalize_quantile_endpoints(): + stats = _tensor_stats() + q01, q99 = stats["q01"], stats["q99"] + assert torch.allclose(denormalize_action(torch.full((1, 3), -1.0), "quantile", stats), q01.unsqueeze(0)) + assert torch.allclose(denormalize_action(torch.full((1, 3), 1.0), "quantile", stats), q99.unsqueeze(0)) + + +def test_denormalize_minmax_endpoints(): + stats = _tensor_stats() + lo, hi = stats["min"], stats["max"] + assert torch.allclose(denormalize_action(torch.full((1, 3), -1.0), "minmax", stats), lo.unsqueeze(0)) + assert torch.allclose(denormalize_action(torch.full((1, 3), 1.0), "minmax", stats), hi.unsqueeze(0)) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +def test_normalize_zero_range_no_nan(): + stats = {k: torch.zeros(3) for k in ("q01", "q99", "mean", "std", "min", "max")} + action = torch.ones(1, 3) + for method in ("quantile", "meanstd", "minmax"): + result = normalize_action(action, method, stats) + assert torch.isfinite(result).all(), f"{method} produced non-finite output with zero range" + + +def test_normalize_unknown_method_raises(): + with pytest.raises(ValueError, match="Unknown normalization method"): + normalize_action(_action(), "unknown_method", _tensor_stats()) + + +def test_denormalize_unknown_method_raises(): + with pytest.raises(ValueError, match="Unknown normalization method"): + denormalize_action(_action(), "unknown_method", _tensor_stats()) diff --git 
a/cosmos_framework/data/vfm/action/agibot_fk.py b/cosmos_framework/data/vfm/action/agibot_fk.py new file mode 100644 index 00000000..10382844 --- /dev/null +++ b/cosmos_framework/data/vfm/action/agibot_fk.py @@ -0,0 +1,398 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Lightweight AgiBot World forward kinematics for datasets and viewers.""" + +from __future__ import annotations + +import xml.etree.ElementTree as ET +from functools import lru_cache + +import numpy as np + +from cosmos_framework.data.vfm.action.agibot_spec import ( + AGIBOT_WORLD_ARM_JOINT_NAMES_LEFT, + AGIBOT_WORLD_ARM_JOINT_NAMES_RIGHT, + AGIBOT_WORLD_ARM_STATE_SLICE, + AGIBOT_WORLD_EXT_ARM_STATE_SLICE, + AGIBOT_WORLD_EXT_STATE_HEAD_PITCH_IDX, + AGIBOT_WORLD_EXT_STATE_HEAD_YAW_IDX, + AGIBOT_WORLD_EXT_STATE_LEFT_HAND_SLICE, + AGIBOT_WORLD_EXT_STATE_RIGHT_HAND_SLICE, + AGIBOT_WORLD_EXT_STATE_ROBOT_ORIENTATION_SLICE, + AGIBOT_WORLD_EXT_STATE_ROBOT_POSITION_SLICE, + AGIBOT_WORLD_EXT_STATE_WAIST_LIFT_IDX, + AGIBOT_WORLD_EXT_STATE_WAIST_PITCH_IDX, + AGIBOT_WORLD_GRIPPER_OPEN_ACTUATOR_DEG, + AGIBOT_WORLD_GRIPPER_OPEN_ANGLE_RAD, + AGIBOT_WORLD_HEAD_CAMERA_LINK_NAME, + AGIBOT_WORLD_HEAD_PITCH_JOINT_NAME, + AGIBOT_WORLD_HEAD_YAW_JOINT_NAME, + AGIBOT_WORLD_LEFT_EE_LINK_NAME, + AGIBOT_WORLD_LEFT_GRIPPER_JOINT_MIMICS, + AGIBOT_WORLD_RIGHT_EE_LINK_NAME, + AGIBOT_WORLD_RIGHT_GRIPPER_JOINT_MIMICS, + AGIBOT_WORLD_STATE_HEAD_PITCH_IDX, + AGIBOT_WORLD_STATE_HEAD_YAW_IDX, + AGIBOT_WORLD_STATE_WAIST_LIFT_IDX, + AGIBOT_WORLD_STATE_WAIST_PITCH_IDX, + AGIBOT_WORLD_WAIST_LIFT_JOINT_NAME, + AGIBOT_WORLD_WAIST_PITCH_JOINT_NAME, + get_agibot_world_embodiment_spec, + get_agibot_world_kind_spec, + get_agibot_world_urdf_path, +) +from cosmos_framework.data.vfm.action.pose_utils import convert_rotation + +_GRIPPER_VALUE_EPS = 1e-4 +_QUATERNION_NORM_EPS = 1e-8 +_GRIPPER_ACTUATOR_OVERSHOOT_DEG = 5.0 +# Main-branch wrist rotations 
def _normalize_quaternions_xyzw(quaternions: np.ndarray) -> np.ndarray:
    """Normalize ``xyzw`` quaternions, treating near-zero rows as identity.

    Works on any array whose last axis holds the four quaternion components,
    e.g. a single quaternion ``[4]`` or a batch ``[T,4]``. Rows whose norm is
    below ``_QUATERNION_NORM_EPS`` (or is NaN) are replaced by the identity
    quaternion ``(0, 0, 0, 1)`` instead of being divided by a tiny norm.

    Args:
        quaternions: Array-like of ``xyzw`` quaternions, shape ``[...,4]``.

    Returns:
        A new ``float32`` array of the same shape; the input is not modified.
    """

    quats = np.asarray(quaternions, dtype=np.float32)  # [...,4]
    norms = np.linalg.norm(quats, axis=-1, keepdims=True)  # [...,1]
    valid = norms >= _QUATERNION_NORM_EPS  # [...,1]
    # Divide invalid rows by 1 so no warnings/NaNs arise from rows that are
    # replaced by the identity anyway.
    safe_norms = np.where(valid, norms, np.float32(1.0))  # [...,1]
    identity = np.asarray([0.0, 0.0, 0.0, 1.0], dtype=np.float32)  # [4]
    normalized = np.where(valid, quats / safe_norms, identity)  # [...,4]
    return normalized.astype(np.float32, copy=False)
def _invert_rigid_transform(transform: np.ndarray) -> np.ndarray:
    """Return the inverse of a single 4x4 homogeneous rigid transform.

    The inverse is built in closed form: the rotation block is transposed and
    the translation becomes ``-(R^T @ t)``. The result is always ``float32``.
    """

    rot_transposed = transform[:3, :3].T.astype(np.float32, copy=False)  # [3,3]
    translation = transform[:3, 3]  # [3]

    result = np.eye(4, dtype=np.float32)  # [4,4]
    result[:3, 3] = -(rot_transposed @ translation)  # [3]
    result[:3, :3] = rot_transposed
    return result
def apply_agibot_gripper_to_opencv(
    poses_by_name: dict[str, np.ndarray],
    to_opencv_by_wrist: dict[str, np.ndarray],
) -> dict[str, np.ndarray]:
    """Post-rotate AgiBot gripper wrist poses into OpenCV convention.

    Every pose array is copied to ``float32``. For names that also appear in
    ``to_opencv_by_wrist`` the rotation block is right-multiplied by the
    matching 3x3 matrix; translations and all other poses are left as they
    are. Wrist names without a pose entry are ignored, and the input arrays
    are never modified.
    """

    converted: dict[str, np.ndarray] = {}
    for pose_name, pose_stack in poses_by_name.items():
        pose_copy = pose_stack.astype(np.float32, copy=True)  # [...,4,4]
        if pose_name in to_opencv_by_wrist:
            rotation_fix = to_opencv_by_wrist[pose_name].astype(pose_copy.dtype)  # [3,3]
            pose_copy[..., :3, :3] = pose_copy[..., :3, :3] @ rotation_fix  # [...,3,3]
        converted[pose_name] = pose_copy
    return converted
def _extract_joint_values_from_state(state: np.ndarray, embodiment_type: str) -> dict[str, float]:
    """Map one observation.state vector to the URDF joint names used for FK.

    The ``agibot_world_gripper_ext`` embodiment reads arm, head and waist
    values from the ext-layout offsets; every other embodiment uses the
    standard layout. The arm slice holds the left-arm joints followed by the
    right-arm joints, in the order of the two joint-name tuples. Gripper
    finger joints are added by ``_set_gripper_joint_values_from_state``.

    Args:
        state: One state vector, indexed with the layout constants.
        embodiment_type: Embodiment name selecting the state layout.

    Returns:
        Mapping from URDF joint name to joint value as a Python ``float``.
    """

    if embodiment_type == "agibot_world_gripper_ext":
        # Ext layout: 94-dim state with joints at different offsets.
        arm_slice = AGIBOT_WORLD_EXT_ARM_STATE_SLICE
        head_yaw_idx = AGIBOT_WORLD_EXT_STATE_HEAD_YAW_IDX
        head_pitch_idx = AGIBOT_WORLD_EXT_STATE_HEAD_PITCH_IDX
        waist_lift_idx = AGIBOT_WORLD_EXT_STATE_WAIST_LIFT_IDX
        waist_pitch_idx = AGIBOT_WORLD_EXT_STATE_WAIST_PITCH_IDX
    else:
        arm_slice = AGIBOT_WORLD_ARM_STATE_SLICE
        head_yaw_idx = AGIBOT_WORLD_STATE_HEAD_YAW_IDX
        head_pitch_idx = AGIBOT_WORLD_STATE_HEAD_PITCH_IDX
        waist_lift_idx = AGIBOT_WORLD_STATE_WAIST_LIFT_IDX
        waist_pitch_idx = AGIBOT_WORLD_STATE_WAIST_PITCH_IDX

    arm_state = state[arm_slice]
    joint_values = {
        AGIBOT_WORLD_WAIST_LIFT_JOINT_NAME: float(state[waist_lift_idx]),
        AGIBOT_WORLD_WAIST_PITCH_JOINT_NAME: float(state[waist_pitch_idx]),
        AGIBOT_WORLD_HEAD_YAW_JOINT_NAME: float(state[head_yaw_idx]),
        AGIBOT_WORLD_HEAD_PITCH_JOINT_NAME: float(state[head_pitch_idx]),
    }
    # Left joints come first, so the right-arm offset follows from the name
    # tuples themselves instead of a hard-coded count.
    arm_joint_names = AGIBOT_WORLD_ARM_JOINT_NAMES_LEFT + AGIBOT_WORLD_ARM_JOINT_NAMES_RIGHT
    joint_values.update({name: float(arm_state[idx]) for idx, name in enumerate(arm_joint_names)})
    _set_gripper_joint_values_from_state(joint_values, state, embodiment_type)
    return joint_values
def compute_fk_transforms(
    state: np.ndarray,
    embodiment_type: str,
) -> dict[str, np.ndarray]:
    """Compute head-camera and wrist link transforms for a single state.

    The state is converted to URDF joint values, run through the cached FK
    engine, and the three link poses of interest are returned as ``float32``
    arrays under the keys ``head_camera``, ``right_wrist`` and ``left_wrist``.
    """

    engine = _get_fk_engine()
    joint_values = _extract_joint_values_from_state(state, embodiment_type)
    poses_by_link = engine.link_poses(joint_values)

    # Output key -> URDF link whose pose is reported under that key.
    link_for_output = (
        ("head_camera", AGIBOT_WORLD_HEAD_CAMERA_LINK_NAME),
        ("right_wrist", AGIBOT_WORLD_RIGHT_EE_LINK_NAME),
        ("left_wrist", AGIBOT_WORLD_LEFT_EE_LINK_NAME),
    )
    return {
        output_name: poses_by_link[link_name].astype(np.float32, copy=False)
        for output_name, link_name in link_for_output
    }
def convert_gripper_state_to_open_fraction(values: np.ndarray) -> np.ndarray:
    """Convert observed AgiBot gripper state to viewer/dataset open fractions.

    The shared viewer/action convention is ``0=closed`` and ``1=open``.
    The unit of ``values`` is detected from the min/max of the whole array:

    * Non-positive values with minimum in ``[-pi/4, 0]`` are treated as
      URDF-style angles in radians and scaled by ``-value / (pi/4)``.
    * Non-positive values with minimum in ``[-45, 0]`` are treated as the
      same angle in degrees.
    * Non-negative values are treated as actuator-close angle units: ``0`` is
      open and ``120`` is closed. Some episodes contain small closed-state
      overshoot above ``120``; those values are accepted and clipped to fully
      closed. Small open-state sensor jitter such as ``0.217`` must therefore
      remain nearly fully open, not be interpreted as a normalized close
      fraction.

    Args:
        values: Gripper readings of any shape; converted to ``float32``.

    Returns:
        ``float32`` open fractions in ``[0, 1]`` with the same shape. An empty
        input is returned unchanged (as ``float32``).

    Raises:
        ValueError: If any value is NaN/Inf, or the value range matches none
            of the supported conventions.
    """

    values = np.asarray(values, dtype=np.float32)
    if values.size == 0:
        return values
    if not np.isfinite(values).all():
        raise ValueError("AgiBot gripper values contain NaN or Inf values.")

    min_value = float(np.min(values))
    max_value = float(np.max(values))
    open_angle_deg = np.degrees(AGIBOT_WORLD_GRIPPER_OPEN_ANGLE_RAD)
    # Radians are tested before degrees: any range that fits [-pi/4, 0] also
    # fits the wider degree range.
    if (
        min_value < -_GRIPPER_VALUE_EPS
        and min_value >= -AGIBOT_WORLD_GRIPPER_OPEN_ANGLE_RAD - _GRIPPER_VALUE_EPS
        and max_value <= _GRIPPER_VALUE_EPS
    ):
        return _scale_negative_to_unit_interval(values, AGIBOT_WORLD_GRIPPER_OPEN_ANGLE_RAD)
    if (
        min_value < -_GRIPPER_VALUE_EPS
        and min_value >= -open_angle_deg - _GRIPPER_VALUE_EPS
        and max_value <= _GRIPPER_VALUE_EPS
    ):
        return _scale_negative_to_unit_interval(values, open_angle_deg)
    max_actuator_value = AGIBOT_WORLD_GRIPPER_OPEN_ACTUATOR_DEG + _GRIPPER_ACTUATOR_OVERSHOOT_DEG
    if min_value >= -_GRIPPER_VALUE_EPS and max_value <= max_actuator_value + _GRIPPER_VALUE_EPS:
        close_fraction = _scale_to_unit_interval(values, AGIBOT_WORLD_GRIPPER_OPEN_ACTUATOR_DEG)  # [*]
        return (1.0 - close_fraction).astype(np.float32, copy=False)  # [*]

    # List every accepted convention, including the negative-degree range.
    raise ValueError(
        f"Unsupported AgiBot gripper value range; min={min_value:.4f}, max={max_value:.4f}. "
        f"Expected URDF angle in radians [-pi/4,0], URDF angle in degrees [-{open_angle_deg:.1f},0], "
        f"or actuator-close degrees [0,{max_actuator_value:.1f}] "
        f"(values above {AGIBOT_WORLD_GRIPPER_OPEN_ACTUATOR_DEG:.1f} are clipped closed)."
    )
+AGIBOT_WORLD_EXT_ARM_STATE_SLICE = slice(54, 68) +AGIBOT_WORLD_EXT_STATE_HEAD_YAW_IDX = 82 +AGIBOT_WORLD_EXT_STATE_HEAD_PITCH_IDX = 83 +AGIBOT_WORLD_EXT_STATE_WAIST_PITCH_IDX = 84 +AGIBOT_WORLD_EXT_STATE_WAIST_LIFT_IDX = 85 +AGIBOT_WORLD_EXT_STATE_ROBOT_POSITION_SLICE = slice(86, 89) +AGIBOT_WORLD_EXT_STATE_ROBOT_ORIENTATION_SLICE = slice(89, 93) +AGIBOT_WORLD_EXT_STATE_LEFT_HAND_SLICE = slice(0, 1) +AGIBOT_WORLD_EXT_STATE_RIGHT_HAND_SLICE = slice(1, 2) +AGIBOT_WORLD_HEAD_CAMERA_LINK_NAME = "head_camera_link" +AGIBOT_WORLD_LEFT_EE_LINK_NAME = "gripper_l_base_link" +AGIBOT_WORLD_RIGHT_EE_LINK_NAME = "gripper_r_base_link" +AGIBOT_WORLD_ARM_JOINT_NAMES_LEFT = tuple(f"idx{4 + i:02d}_left_arm_joint{i}" for i in range(1, 8)) +AGIBOT_WORLD_ARM_JOINT_NAMES_RIGHT = tuple(f"idx{11 + i:02d}_right_arm_joint{i}" for i in range(1, 8)) +AGIBOT_WORLD_WAIST_LIFT_JOINT_NAME = "idx01_waist_lift_joint" +AGIBOT_WORLD_WAIST_PITCH_JOINT_NAME = "idx02_waist_pitch_joint" +AGIBOT_WORLD_HEAD_YAW_JOINT_NAME = "idx03_head_yaw_joint" +AGIBOT_WORLD_GRIPPER_OPEN_ANGLE_RAD = math.pi / 4.0 +AGIBOT_WORLD_GRIPPER_OPEN_ACTUATOR_DEG = 120.0 +AGIBOT_WORLD_LEFT_GRIPPER_JOINT_MIMICS = ( + ("idx31_gripper_l_inner_joint1", 1.0, 0.0), + ("idx32_gripper_l_inner_joint3", 0.1, 0.0), + ("idx33_gripper_l_inner_joint4", 0.25, 0.0), + ("idx39_gripper_l_inner_joint0", -0.7, 0.0), + ("idx41_gripper_l_outer_joint1", -1.0, 0.0), + ("idx42_gripper_l_outer_joint3", 0.1, 0.0), + ("idx43_gripper_l_outer_joint4", -0.25, 0.0), + ("idx49_gripper_l_outer_joint0", 0.7, 0.0), +) +AGIBOT_WORLD_RIGHT_GRIPPER_JOINT_MIMICS = ( + ("idx71_gripper_r_inner_joint1", 1.0, 0.0), + ("idx72_gripper_r_inner_joint3", 0.1, 0.0), + ("idx73_gripper_r_inner_joint4", 0.25, 0.0), + ("idx79_gripper_r_inner_joint0", -0.7, 0.0), + ("idx81_gripper_r_outer_joint1", -1.0, 0.0), + ("idx82_gripper_r_outer_joint3", 0.1, 0.0), + ("idx83_gripper_r_outer_joint4", -0.25, 0.0), + ("idx89_gripper_r_outer_joint0", 0.7, 0.0), +) + + +@dataclass(frozen=True) +class 
AgibotWorldKindSpec: + """Layout metadata shared across all embodiments of one hand kind.""" + + kind: AgibotWorldKind + state_hand_slice: slice + + +@dataclass(frozen=True) +class AgibotWorldEmbodimentSpec: + """Per-embodiment metadata shared by training and visualization code.""" + + embodiment_type: str + kind: AgibotWorldKind + + +AGIBOT_WORLD_KIND_SPECS: dict[AgibotWorldKind, AgibotWorldKindSpec] = { + "gripper": AgibotWorldKindSpec( + kind="gripper", + state_hand_slice=slice(14, 16), + ), +} + +AGIBOT_WORLD_EMBODIMENT_SPECS: dict[str, AgibotWorldEmbodimentSpec] = { + "agibot_world_gripper": AgibotWorldEmbodimentSpec( + embodiment_type="agibot_world_gripper", + kind="gripper", + ), + "agibot_world_gripper_ext": AgibotWorldEmbodimentSpec( + embodiment_type="agibot_world_gripper_ext", + kind="gripper", + ), +} + + +def get_agibot_world_embodiment_spec(embodiment_type: str) -> AgibotWorldEmbodimentSpec: + """Return the registered spec for one AgiBot embodiment.""" + + try: + return AGIBOT_WORLD_EMBODIMENT_SPECS[embodiment_type] + except KeyError as exc: + raise ValueError( + f"Unknown AgiBot World embodiment_type={embodiment_type!r}. " + f"Expected one of {sorted(AGIBOT_WORLD_EMBODIMENT_SPECS)}." 
def get_agibot_world_kind_spec(embodiment_type: str | AgibotWorldKind) -> AgibotWorldKindSpec:
    """Resolve an embodiment type or kind to its shared layout metadata.

    A value that is already a key of ``AGIBOT_WORLD_KIND_SPECS`` is looked up
    directly; anything else is first mapped to its kind through
    ``get_agibot_world_kind``.
    """

    if embodiment_type in AGIBOT_WORLD_KIND_SPECS:
        return AGIBOT_WORLD_KIND_SPECS[embodiment_type]
    resolved_kind = get_agibot_world_kind(embodiment_type)
    return AGIBOT_WORLD_KIND_SPECS[resolved_kind]
+""" + +from cosmos_framework.data.vfm.action.datasets.agibotworld_beta_lerobot_dataset import AgiBotWorldBetaLeRobotDataset +from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset +from cosmos_framework.data.vfm.action.datasets.bridge_orig_lerobot_dataset import BridgeOrigLeRobotDataset from cosmos_framework.data.vfm.action.datasets.droid_lerobot_dataset import DROIDLeRobotDataset +from cosmos_framework.data.vfm.action.datasets.robomind_franka_dataset import RoboMINDFrankaDataset -__all__ = ["DROIDLeRobotDataset"] +__all__ = [ + "ActionBaseDataset", + "AgiBotWorldBetaLeRobotDataset", + "BridgeOrigLeRobotDataset", + "DROIDLeRobotDataset", + "RoboMINDFrankaDataset", +] diff --git a/cosmos_framework/data/vfm/action/datasets/agibotworld_beta_lerobot_dataset.py b/cosmos_framework/data/vfm/action/datasets/agibotworld_beta_lerobot_dataset.py new file mode 100644 index 00000000..f95feea8 --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/agibotworld_beta_lerobot_dataset.py @@ -0,0 +1,277 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
def _split_task_for_caption(task: str) -> tuple[str, str]:
    """Split a task string at its first ``|`` into two stripped parts.

    Returns ``(text_before, text_after)``. When ``task`` contains no ``|`` the
    whole stripped string is returned as the first part and the second part
    is empty.
    """
    if "|" not in task:
        return task.strip(), ""
    before, after = task.split("|", 1)
    return before.strip(), after.strip()
def _compute_idle_frames_agibot(action: torch.Tensor) -> int:
    """Count the leading low-motion frames of a 29D AgiBot FK action chunk.

    The shared ``compute_idle_frames`` expects one rotation group after each
    position block; AgiBot's action spec has three such groups plus grippers.
    For cookbook inference, idle frames are metadata only, so this conservative
    implementation returns the length of the initial low-motion streak.

    Column layout (per frame)::

        [head_pos(0:3), head_rot6d(3:9),
         right_pos(9:12), right_rot6d(12:18), right_gripper(18),
         left_pos(19:22), left_rot6d(22:28), left_gripper(28)]

    Args:
        action: ``[T, 29]`` action tensor (unnormalized deltas).

    Returns:
        Number of consecutive frames, starting at frame 0, whose largest
        translation magnitude and gripper change are all below ``1e-3``.
        Returns 0 for an empty tensor.
    """
    if action.numel() == 0:
        return 0
    frames = action.detach()
    # Translation deltas of the three tracked bodies. The left wrist starts at
    # column 19, because column 18 holds the right gripper value.
    translation = torch.cat([frames[:, 0:3], frames[:, 9:12], frames[:, 19:22]], dim=-1).abs()
    # Grippers are absolute open fractions, so motion is the magnitude of the
    # frame-to-frame change (opening and closing both count).
    grippers = frames[:, [18, 28]]
    gripper_delta = grippers.diff(dim=0, prepend=grippers[0:1]).abs()
    motion = torch.cat([translation, gripper_delta], dim=-1).amax(dim=-1)

    count = 0
    for is_idle in (motion < 1e-3).tolist():
        if not is_idle:
            break
        count += 1
    return count
+ """ + + + def __init__( + self, + root: str, + fps: float = 10.0, + chunk_length: int = 16, + mode: str = "joint", + pose_convention: PoseConvention = "backward_framewise", + tolerance_s: float = 3e-4, + viewpoint: Viewpoint = "concat_view", + action_normalization: str | None = "quantile", + sample_stride: int = 1, + ) -> None: + if viewpoint not in ("concat_view", "ego_view"): + raise NotImplementedError("Supported viewpoints are concat_view and ego_view.") + super().__init__( + root=root, + domain_name="agibotworld", + fps=fps, + chunk_length=chunk_length, + mode=mode, + pose_convention=pose_convention, + tolerance_s=tolerance_s, + viewpoint=viewpoint, + action_normalization=action_normalization, + sample_stride=sample_stride, + ) + self._rows_by_episode: dict[int, list[dict[str, Any]]] = {} + for row in self._rows: + self._rows_by_episode.setdefault(int(row["episode_index"]), []).append(row) + self._timestamps_by_episode = { + episode_id: np.asarray([float(row["timestamp"]) for row in rows], dtype=np.float64) + for episode_id, rows in self._rows_by_episode.items() + } + + @property + def action_dim(self) -> int: + return 29 + + def _action_spec(self) -> ActionSpec: + return build_action_spec( + Pos(prefix="head"), + Rot("rot6d", prefix="head"), + Pos(prefix="right"), + Rot("rot6d", prefix="right"), + Gripper(prefix="right"), + Pos(prefix="left"), + Rot("rot6d", prefix="left"), + Gripper(prefix="left"), + ) + + @classmethod + def _stats_path(cls) -> Path: + return _NORMALIZER_PATH + + def _compute_idle_frames(self, action: torch.Tensor) -> int: + return _compute_idle_frames_agibot(action) + + def __len__(self) -> int: + return max(0, (len(self._rows) - self._chunk_length + self._sample_stride - 1) // self._sample_stride) + + def __getitem__(self, idx: int) -> dict[str, Any]: + mode = self._choose_mode() + row_idx = int(idx) * self._sample_stride + start_row = self._rows[row_idx] + observation_rows = self._select_observation_rows(start_row) + episode = 
self._episodes[int(observation_rows[0]["episode_index"])] + task = self._tasks[int(observation_rows[0]["task_index"])] + ai_caption, debug_caption = _split_task_for_caption(task) + + video = self._load_video(episode, observation_rows) + action, extras = self._build_fk_action(observation_rows) + if self._viewpoint == "concat_view": + extras["additional_view_description"] = ( + "The top row shows the head-mounted camera view looking down at the workspace. " + "The bottom row contains two horizontally concatenated wrist-mounted camera views: " + "the left hand camera on the left and the right hand camera on the right." + ) + if debug_caption: + extras["debug_caption"] = debug_caption + + return self._build_result( + mode=mode, + video=video, + action=action, + ai_caption=ai_caption, + action_spec_names=self.action_names, + **extras, + ) + + def _select_observation_rows(self, start_row: dict[str, Any]) -> list[dict[str, Any]]: + """Select T+1 rows at this wrapper's target FPS within one episode.""" + + episode_id = int(start_row["episode_index"]) + rows = self._rows_by_episode[episode_id] + timestamps = self._timestamps_by_episode[episode_id] + start_frame = int(start_row["frame_index"]) + start_ts = float(start_row["timestamp"]) + target_ts = start_ts + np.arange(self._chunk_length + 1, dtype=np.float64) / self._fps + indices = np.searchsorted(timestamps, target_ts, side="left") + indices = np.minimum(indices, len(rows) - 1) + prev = np.maximum(indices - 1, 0) + choose_prev = np.abs(timestamps[prev] - target_ts) <= np.abs(timestamps[indices] - target_ts) + indices = np.where(choose_prev, prev, indices) + if int(indices[-1]) <= start_frame: + raise IndexError(f"Could not select {self._chunk_length + 1} frames from episode {episode_id} at fps={self._fps}.") + return [rows[int(i)] for i in indices] + + def _load_video(self, episode: dict[str, Any], observation_rows: list[dict[str, Any]]) -> torch.Tensor: + if self._viewpoint == "ego_view": + return 
self._load_video_key(episode, observation_rows, _HEAD_KEY) + + # Prefer a pre-rendered concat view if present. The local asset includes + # metadata for this key but not the public mp4, so the fallback composes + # it from the three camera streams. + concat_path = self._video_path(episode, _CONCAT_KEY) + if concat_path.exists(): + return self._load_video_key(episode, observation_rows, _CONCAT_KEY) + top = self._load_video_key(episode, observation_rows, _HEAD_KEY) + left = self._load_video_key(episode, observation_rows, _HAND_LEFT_KEY) + right = self._load_video_key(episode, observation_rows, _HAND_RIGHT_KEY) + return self._compose_multi_view(top, left, right) + + def _load_video_key(self, episode: dict[str, Any], observation_rows: list[dict[str, Any]], key: str) -> torch.Tensor: + timestamps = [float(row["timestamp"]) for row in observation_rows] + return decode_video_frames( + self._video_path(episode, key), + [float(episode.get(f"videos/{key}/from_timestamp", 0.0)) + ts for ts in timestamps], + self._tolerance_s, + ) + + def _compose_multi_view(self, top: torch.Tensor, left: torch.Tensor, right: torch.Tensor) -> torch.Tensor: + # Inputs are [T,C,H,W] float tensors in [0,1]. 
+ _, _, h_top, w_top = top.shape + half_h, half_w = h_top // 2, w_top // 2 + left = F.interpolate(left, size=(half_h, half_w), mode="bilinear", align_corners=False) + right = F.interpolate(right, size=(half_h, half_w), mode="bilinear", align_corners=False) + bottom = torch.cat([left, right], dim=-1) + return torch.cat([top, bottom], dim=-2) + + def _build_fk_action(self, rows: list[dict[str, Any]]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + effector_pos = np.asarray([row[_EFFECTOR_KEY] for row in rows], dtype=np.float32) + joint_pos = np.asarray([row[_JOINT_KEY] for row in rows], dtype=np.float32) + head_pos = np.asarray([row[_HEAD_STATE_KEY] for row in rows], dtype=np.float32) + waist_pos = np.asarray([row[_WAIST_KEY] for row in rows], dtype=np.float32) + robot_pos = np.asarray([row[_ROBOT_POSITION_KEY] for row in rows], dtype=np.float32) + robot_quat = np.asarray([row[_ROBOT_ORIENTATION_KEY] for row in rows], dtype=np.float32) + states_np = _assemble_agibot_world_state(effector_pos, joint_pos, head_pos, waist_pos) + + native_fk = compute_fk_transforms_batch(states_np, "agibot_world_gripper") + native_fk = apply_robot_base_motion_to_poses(native_fk, robot_pos, robot_quat) + fk = apply_agibot_gripper_to_opencv(native_fk, AGIBOT_WORLD_GRIPPER_TO_OPENCV_BY_WRIST) + + head_rel = pose_abs_to_rel(fk["head_camera"], rotation_format="rot6d", pose_convention=self._pose_convention) + right_rel = pose_abs_to_rel(fk["right_wrist"], rotation_format="rot6d", pose_convention=self._pose_convention) + left_rel = pose_abs_to_rel(fk["left_wrist"], rotation_format="rot6d", pose_convention=self._pose_convention) + right_gripper = convert_gripper_state_to_open_fraction(effector_pos[1:, 1:2]) + left_gripper = convert_gripper_state_to_open_fraction(effector_pos[1:, 0:1]) + action_np = np.concatenate([head_rel, right_rel, right_gripper, left_rel, left_gripper], axis=-1).astype( + np.float32 + ) + extras = { + "initial_pose": torch.from_numpy(fk["head_camera"][0].copy()).float(), + 
"initial_pose_right": torch.from_numpy(fk["right_wrist"][0].copy()).float(), + "initial_pose_left": torch.from_numpy(fk["left_wrist"][0].copy()).float(), + } + return torch.from_numpy(action_np).float(), extras diff --git a/cosmos_framework/data/vfm/action/datasets/base_dataset.py b/cosmos_framework/data/vfm/action/datasets/base_dataset.py new file mode 100644 index 00000000..564d48e5 --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/base_dataset.py @@ -0,0 +1,204 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: OpenMDW-1.1 + +"""Abstract base class for Action LeRobot datasets.""" + +from __future__ import annotations + +import json +import random +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any + +import numpy as np +import pyarrow.parquet as pq +import torch +from torch.utils.data import Dataset + +from cosmos_framework.data.vfm.action.action_normalization import load_action_stats, normalize_action +from cosmos_framework.data.vfm.action.action_spec import ActionSpec +from cosmos_framework.data.vfm.action.domain_utils import get_domain_id +from cosmos_framework.data.vfm.action.pose_utils import compute_idle_frames + +_MODE_CHOICES = ("forward_dynamics", "inverse_dynamics", "policy") + + +class ActionBaseDataset(ABC, Dataset): + """Abstract base for Action LeRobot datasets. + + Subclasses must implement the abstract methods listed below. 
+ """ + + def __init__( + self, + root: str, + domain_name: str, + fps: float, + chunk_length: int, + mode: str, + pose_convention: str, + tolerance_s: float, + viewpoint: str, + action_normalization: str | None = "quantile", + sample_stride: int = 1, + ) -> None: + super().__init__() + if pose_convention != "backward_framewise": + raise NotImplementedError(f"{type(self).__name__} only supports backward_framewise pose deltas.") + + self._fps = float(fps) + self._dt = 1.0 / self._fps + self._chunk_length = int(chunk_length) + self._sample_stride = int(sample_stride) + if self._sample_stride < 1: + raise ValueError(f"sample_stride must be >= 1, got {self._sample_stride}") + self._mode = mode + self._pose_convention = pose_convention + self._tolerance_s = float(tolerance_s) + self._viewpoint = viewpoint + self._domain_id = get_domain_id(domain_name) + self._action_normalization = action_normalization + self._norm_stats: dict[str, torch.Tensor] | None = None + + self._root = Path(root) + self._info = json.loads((self._root / "meta" / "info.json").read_text()) + self._episodes = { + int(row["episode_index"]): row + for path in sorted((self._root / "meta" / "episodes").glob("chunk-*/file-*.parquet")) + for row in pq.read_table(path).to_pylist() + } + self._tasks = { + int(row["task_index"]): str(row["task"]) + for row in pq.read_table(self._root / "meta" / "tasks.parquet").to_pylist() + } + self._rows = sorted( + ( + row + for path in sorted((self._root / "data").glob("chunk-*/file-*.parquet")) + for row in pq.read_table(path).to_pylist() + ), + key=lambda row: int(row["index"]), + ) + + @property + def fps(self) -> float: + return self._fps + + @property + def chunk_length(self) -> int: + return self._chunk_length + + @property + def mode(self) -> str: + return self._mode + + @mode.setter + def mode(self, value: str) -> None: + self._mode = value + + @property + def domain_id(self) -> int: + return self._domain_id + + @property + def action_normalization(self) -> str: + 
return self._action_normalization + + @property + @abstractmethod + def action_dim(self) -> int: ... + + @abstractmethod + def _action_spec(self) -> ActionSpec: ... + + @property + def action_names(self) -> list[str]: + return self._action_spec().names + + @classmethod + @abstractmethod + def _stats_path(cls) -> Path: + """Return the path to the stats JSON file for this dataset.""" + ... + + @classmethod + def load_action_stats(cls) -> dict[str, torch.Tensor]: + """Return action normalization stats for this dataset as torch tensors.""" + return { + key: torch.from_numpy(value).float() + for key, value in load_action_stats(str(cls._stats_path())).items() + } + + @abstractmethod + def __getitem__(self, idx: int) -> dict[str, Any]: ... + + def _compute_idle_frames(self, action: torch.Tensor) -> int: + return compute_idle_frames( + action, + self._action_spec(), + eps_t=5e-3 / self._fps, + eps_r=np.deg2rad(1.5) / self._fps, + eps_g=1e-2, + joint_threshold=5e-3 / self._fps, + min_streak=3, + ) + + def _choose_mode(self) -> str: + if self._mode == "joint": + return random.choice(_MODE_CHOICES) + return self._mode + + def _video_path(self, episode: dict[str, Any], video_key: str) -> Path: + chunk_idx = int( + episode.get( + f"videos/{video_key}/chunk_index", + episode.get(f"videos/{video_key}/episode_chunk", episode.get("data/chunk_index", 0)), + ) + ) + file_idx = int( + episode.get( + f"videos/{video_key}/file_index", + episode.get(f"videos/{video_key}/episode_file", episode.get("data/file_index", 0)), + ) + ) + rel = self._info["video_path"].format( + video_key=video_key, + chunk_index=chunk_idx, + file_index=file_idx, + episode_chunk=chunk_idx, + episode_file=file_idx, + ) + return self._root / rel + + def _load_norm_stats(self) -> dict[str, torch.Tensor]: + if self._norm_stats is None: + self._norm_stats = self.load_action_stats() + return self._norm_stats + + def _build_result( + self, + *, + mode: str, + video: torch.Tensor, + action: torch.Tensor, + ai_caption: 
str, + **extras: Any, + ) -> dict[str, Any]: + idle_frames = self._compute_idle_frames(action) + normalized_action = normalize_action(action, self.action_normalization, self._load_norm_stats()) + formatted_video = (video * 255.0).clamp(0.0, 255.0).to(torch.uint8).permute(1, 0, 2, 3) + return { + "ai_caption": ai_caption, + "video": formatted_video, + "action": normalized_action, + "conditioning_fps": torch.tensor(self._fps, dtype=torch.long), + "mode": mode, + "domain_id": torch.tensor(self._domain_id, dtype=torch.long), + "viewpoint": self._viewpoint, + "idle_frames": torch.tensor(idle_frames, dtype=torch.long), + **extras, + } + + def __len__(self) -> int: + return max(0, (len(self._rows) - self._chunk_length + self._sample_stride - 1) // self._sample_stride) diff --git a/cosmos_framework/data/vfm/action/datasets/bridge_orig_lerobot_dataset.py b/cosmos_framework/data/vfm/action/datasets/bridge_orig_lerobot_dataset.py new file mode 100644 index 00000000..5992ce67 --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/bridge_orig_lerobot_dataset.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: OpenMDW-1.1 + +"""Bridge Orig LeRobot dataset.""" + +from __future__ import annotations + +import random +from pathlib import Path +from typing import Any, Literal + +import numpy as np +import torch +from lerobot.datasets.video_utils import decode_video_frames + +from cosmos_framework.data.vfm.action.action_spec import ActionSpec, Gripper, Pos, Rot, build_action_spec +from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset +from cosmos_framework.data.vfm.action.pose_utils import ( + build_abs_pose_from_components, + pose_abs_to_rel, +) + +PoseConvention = Literal["backward_framewise"] +Viewpoint = Literal["ego_view"] + +_IMAGE_FEATURE = "observation.images.image_0" +_STATE_FEATURE = "observation.state" +_ACTION_FEATURE = "action" + +# Raw Bridge state -> kinematics frame. The WidowX controller records +# R_state = R_fk @ DEFAULT_ROTATION.T, so R_fk = R_state @ DEFAULT_ROTATION. +_DEFAULT_ROTATION = np.array( + [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]], + dtype=np.float32, +) + +# Kinematics frame -> OpenCV frame used by Cosmos action. +_BRIDGE_TO_OPENCV = np.array( + [[0.0, 0.0, 1.0], [-1.0, 0.0, 0.0], [0.0, -1.0, 0.0]], + dtype=np.float32, +) + +# Re-reference from ee_gripper_link to gripper_link in the kinematics frame. +_TCP_TO_FLANGE = np.array( + [ + [1.0, 0.0, 0.0, -0.093575], + [0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=np.float32, +) + +_NORMALIZER_PATH = Path(__file__).parent / "stats/bridge_orig_lerobot_stats.json" + + +class BridgeOrigLeRobotDataset(ActionBaseDataset): + """Bridge Orig dataset with 10D cartesian actions: + + [pos_delta(3), rot6d_delta(6), gripper(1)] + + Uses a single ``image_0`` ego-view video, backward-framewise rot6d actions, + and quantile normalization. 
+ """ + + + def __init__( + self, + root: str, + fps: float = 5.0, + chunk_length: int = 16, + mode: str = "joint", + pose_convention: PoseConvention = "backward_framewise", + tolerance_s: float = 1e-4, + viewpoint: Viewpoint = "ego_view", + action_normalization: str | None = "quantile", + sample_stride: int = 1, + ) -> None: + if viewpoint != "ego_view": + raise NotImplementedError("This minimal Bridge dataset only supports ego_view.") + super().__init__( + root=root, + domain_name="bridge_orig_lerobot", + fps=fps, + chunk_length=chunk_length, + mode=mode, + pose_convention=pose_convention, + tolerance_s=tolerance_s, + viewpoint=viewpoint, + action_normalization=action_normalization, + sample_stride=sample_stride, + ) + + @property + def action_dim(self) -> int: + return 10 + + def _action_spec(self) -> ActionSpec: + return build_action_spec(Pos(), Rot("rot6d"), Gripper()) + + @classmethod + def _stats_path(cls) -> Path: + return _NORMALIZER_PATH + + def __getitem__(self, idx: int) -> dict[str, Any]: + mode = self._choose_mode() + idx = int(idx) + first_row = self._rows[idx] + episode = self._episodes[int(first_row["episode_index"])] + + row_idx = idx * self._sample_stride + observation_rows = self._rows[row_idx : row_idx + self._chunk_length + 1] + action_rows = observation_rows[: self._chunk_length] + + video = self._load_video(episode, observation_rows) + raw_action, initial_pose = self._build_raw_action(observation_rows, action_rows) + task = self._tasks[int(observation_rows[0]["task_index"])] + ai_caption = random.choice([part.strip() for part in task.split(" | ") if part.strip()] or [task]) + + return self._build_result( + mode=mode, + video=video, + action=raw_action, + ai_caption=ai_caption, + initial_pose=initial_pose, + ) + + def _load_video(self, episode: dict[str, Any], observation_rows: list[dict[str, Any]]) -> torch.Tensor: + timestamps = [float(row["timestamp"]) for row in observation_rows] + return decode_video_frames( + self._video_path(episode, 
_IMAGE_FEATURE), + [float(episode.get(f"videos/{_IMAGE_FEATURE}/from_timestamp", 0.0)) + ts for ts in timestamps], + self._tolerance_s, + ) + + def _build_raw_action( + self, + observation_rows: list[dict[str, Any]], + action_rows: list[dict[str, Any]], + ) -> tuple[torch.Tensor, torch.Tensor]: + state = np.asarray([row[_STATE_FEATURE] for row in observation_rows], dtype=np.float32) + poses_abs = build_abs_pose_from_components(state[:, 0:3], state[:, 3:6], "euler_xyz") + + poses_abs[:, :3, :3] = poses_abs[:, :3, :3] @ _DEFAULT_ROTATION.astype(poses_abs.dtype) + poses_abs = poses_abs @ _TCP_TO_FLANGE.astype(poses_abs.dtype) + poses_abs[:, :3, :3] = poses_abs[:, :3, :3] @ _BRIDGE_TO_OPENCV.astype(poses_abs.dtype) + + initial_pose = torch.from_numpy(poses_abs[0].copy()).float() + poses_rel = pose_abs_to_rel(poses_abs, rotation_format="rot6d", pose_convention=self._pose_convention) + gripper = np.asarray([row[_ACTION_FEATURE][6] for row in action_rows], dtype=np.float32).reshape(-1, 1) + action = np.concatenate([poses_rel[-self._chunk_length :], gripper[-self._chunk_length :]], axis=-1) + return torch.from_numpy(action).float(), initial_pose diff --git a/cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py b/cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py index 204df695..631f1e97 100644 --- a/cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py +++ b/cosmos_framework/data/vfm/action/datasets/droid_lerobot_dataset.py @@ -16,14 +16,11 @@ import torch.nn.functional as F import torchvision.transforms as T from lerobot.datasets.video_utils import decode_video_frames -from torch.utils.data import Dataset -from cosmos_framework.data.vfm.action.action_normalization import load_action_stats, normalize_action -from cosmos_framework.data.vfm.action.action_spec import Gripper, Joint, Pos, Rot, build_action_spec -from cosmos_framework.data.vfm.action.domain_utils import get_domain_id +from cosmos_framework.data.vfm.action.action_spec 
import ActionSpec, Gripper, Joint, Pos, Rot, build_action_spec +from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset from cosmos_framework.data.vfm.action.pose_utils import ( build_abs_pose_from_components, - compute_idle_frames, pose_abs_to_rel, ) @@ -55,11 +52,10 @@ dtype=np.float32, ) -_NORMALIZER_PATH = Path(__file__).parent / "droid_lerobot_normalization.json" -_MODE_CHOICES = ("forward_dynamics", "inverse_dynamics", "policy") +_NORMALIZER_PATH = Path(__file__).parent / "stats/droid_lerobot_stats.json" -class DROIDLeRobotDataset(Dataset): +class DROIDLeRobotDataset(ActionBaseDataset): """DROID Action dataset. Two action layouts: @@ -75,7 +71,7 @@ class DROIDLeRobotDataset(Dataset): def __init__( self, - root: str = "/path/to/cosmos3_action_datasets/droid_plus_lerobot_640x360_20260412", + root: str, fps: float = 15.0, chunk_length: int = 16, mode: str = "joint", @@ -89,23 +85,28 @@ def __init__( use_filter_dict: bool = False, filter_dict_path: str | None = None, ) -> None: - super().__init__() - if pose_convention != "backward_framewise": - raise NotImplementedError("This minimal DROID dataset only supports backward_framewise pose deltas.") if viewpoint != "concat_view": raise NotImplementedError("This minimal DROID dataset only supports concat_view.") if action_space not in _ACTION_SPACES: raise NotImplementedError(f"action_space must be one of {_ACTION_SPACES}, got {action_space!r}.") if use_state and action_space != "joint_pos": raise NotImplementedError("use_state is only supported with action_space='joint_pos'.") + if use_filter_dict and not filter_dict_path: + raise ValueError("use_filter_dict=True requires filter_dict_path") + + # joint_pos uses raw joint values — disable normalization at the base level. 
+ super().__init__( + root=root, + domain_name="droid_lerobot", + fps=fps, + chunk_length=chunk_length, + mode=mode, + pose_convention=pose_convention, + tolerance_s=tolerance_s, + viewpoint=viewpoint, + action_normalization=None if action_space == "joint_pos" else action_normalization, + ) - self._fps = float(fps) - self._dt = 1.0 / self._fps - self._chunk_length = int(chunk_length) - self._mode = mode - self._pose_convention = pose_convention - self._tolerance_s = float(tolerance_s) - self._viewpoint = viewpoint self._action_space = action_space self._use_state = bool(use_state) # Per-sample image augmentation (random crop+rescale + color jitter), applied @@ -117,25 +118,7 @@ def __init__( # keep-ranges JSON is supplied via filter_dict_path (an internal data artifact). self._use_filter_dict = bool(use_filter_dict) self._filter_dict_path = filter_dict_path - if self._use_filter_dict and not self._filter_dict_path: - raise ValueError("use_filter_dict=True requires filter_dict_path") - # joint_pos trains on raw 8D joint values (the internal canonical run - # leaves action_normalization=None); ee_pose keeps quantile normalization. - self._action_normalization = None if action_space == "joint_pos" else action_normalization - self._domain_id = get_domain_id("droid_lerobot") - self._norm_stats: dict[str, torch.Tensor] | None = None - - self._root = Path(root) - self._info = json.loads((self._root / "meta" / "info.json").read_text()) - self._episodes = { - int(row["episode_index"]): row - for path in sorted((self._root / "meta" / "episodes").glob("chunk-*/file-*.parquet")) - for row in pq.read_table(path).to_pylist() - } - self._tasks = { - int(row["task_index"]): str(row["task"]) - for row in pq.read_table(self._root / "meta" / "tasks.parquet").to_pylist() - } + # Compact, lazy frame index. Materializing every frame as a Python dict # (``sorted(... 
pq.read_table(path).to_pylist() ...)``) does not scale: # the full DROID success shard is ~18M frames, which is tens of GB of @@ -215,43 +198,18 @@ def __init__( self._seg_win_start = np.asarray(seg_win_start, dtype=np.int64) self._seg_cum = np.cumsum(seg_len).astype(np.int64) if seg_len else np.zeros(0, dtype=np.int64) - @property - def fps(self) -> float: - return self._fps - - @property - def chunk_length(self) -> int: - return self._chunk_length - - @property - def mode(self) -> str: - return self._mode - - @mode.setter - def mode(self, value: str) -> None: - self._mode = value - - @property - def domain_id(self) -> int: - return self._domain_id - @property def action_dim(self) -> int: return 8 if self._action_space == "joint_pos" else 10 - def _action_spec(self): + def _action_spec(self) -> ActionSpec: if self._action_space == "joint_pos": return build_action_spec(Joint(n=7, label="joint"), Gripper()) return build_action_spec(Pos(), Rot("rot6d"), Gripper()) - @property - def action_names(self) -> list[str]: - return self._action_spec().names - - def _choose_mode(self) -> str: - if self._mode == "joint": - return random.choice(_MODE_CHOICES) - return self._mode + @classmethod + def _stats_path(cls) -> Path: + return _NORMALIZER_PATH def _window_rows(self, start: int, stop: int, episode_index: int) -> list[dict[str, Any]]: """Reconstruct the per-frame dicts the sample builder consumes for the @@ -371,28 +329,6 @@ def _load_concat_video( bottom = torch.cat([left, right], dim=-1) return torch.cat([wrist, bottom], dim=-2) - def _video_path(self, episode: dict[str, Any], video_key: str) -> Path: - chunk_idx = int( - episode.get( - f"videos/{video_key}/chunk_index", - episode.get(f"videos/{video_key}/episode_chunk", episode.get("data/chunk_index", 0)), - ) - ) - file_idx = int( - episode.get( - f"videos/{video_key}/file_index", - episode.get(f"videos/{video_key}/episode_file", episode.get("data/file_index", 0)), - ) - ) - rel = self._info["video_path"].format( - 
video_key=video_key, - chunk_index=chunk_idx, - file_index=file_idx, - episode_chunk=chunk_idx, - episode_file=file_idx, - ) - return self._root / rel - def _build_raw_action( self, observation_rows: list[dict[str, Any]], @@ -404,56 +340,13 @@ def _build_raw_action( initial_pose = torch.from_numpy(poses_abs[0].copy()).float() poses_rel = pose_abs_to_rel(poses_abs, rotation_format="rot6d", pose_convention=self._pose_convention) - gripper = np.asarray([row["action.gripper_position"] for row in action_rows], dtype=np.float32).reshape(-1, 1) + gripper = np.asarray( + [row[_ACTION_GRIPPER_FEATURE] for row in action_rows], dtype=np.float32 + ).reshape(-1, 1) gripper = 1.0 - gripper action = np.concatenate([poses_rel[-self._chunk_length :], gripper[-self._chunk_length :]], axis=-1) return torch.from_numpy(action).float(), initial_pose - def _build_result( - self, - *, - mode: str, - video: torch.Tensor, - action: torch.Tensor, - ai_caption: str, - **extras: Any, - ) -> dict[str, Any]: - spec = self._action_spec() - idle_frames = compute_idle_frames( - action, - spec, - eps_t=5e-3 / self._fps, - eps_r=np.deg2rad(1.5) / self._fps, - eps_g=1e-2, - joint_threshold=5e-3 / self._fps, - min_streak=3, - ) - if self._action_normalization is None: - out_action = action - else: - out_action = normalize_action(action, self._action_normalization, self._load_norm_stats()) - formatted_video = (video * 255.0).clamp(0.0, 255.0).to(torch.uint8).permute(1, 0, 2, 3) - return { - "ai_caption": ai_caption, - "video": formatted_video, - "action": out_action, - "conditioning_fps": torch.tensor(self._fps, dtype=torch.long), - "mode": mode, - "domain_id": torch.tensor(self._domain_id, dtype=torch.long), - "viewpoint": self._viewpoint, - "idle_frames": torch.tensor(idle_frames, dtype=torch.long), - **extras, - } - - def _load_norm_stats(self) -> dict[str, torch.Tensor]: - if self._norm_stats is not None: - return self._norm_stats - self._norm_stats = { - key: torch.from_numpy(value).float() - for 
key, value in load_action_stats(str(_NORMALIZER_PATH)).items() - } - return self._norm_stats - def __len__(self) -> int: if self._use_filter_dict: return int(self._seg_cum[-1]) if self._seg_cum.size else 0
diff --git a/cosmos_framework/data/vfm/action/datasets/robomind_franka_dataset.py b/cosmos_framework/data/vfm/action/datasets/robomind_franka_dataset.py
new file mode 100644
index 00000000..136cd6c0
--- /dev/null
+++ b/cosmos_framework/data/vfm/action/datasets/robomind_franka_dataset.py
@@ -0,0 +1,248 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: OpenMDW-1.1
+
+"""RoboMIND Franka LeRobot dataset."""
+
+from __future__ import annotations
+
+import random
+from pathlib import Path
+from typing import Any, Literal
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from lerobot.datasets.video_utils import decode_video_frames
+
+from cosmos_framework.data.vfm.action.action_spec import ActionSpec, Gripper, Pos, Rot, build_action_spec
+from cosmos_framework.data.vfm.action.datasets.base_dataset import ActionBaseDataset
+from cosmos_framework.data.vfm.action.pose_utils import (
+    build_abs_pose_from_components,
+    pose_abs_to_rel,
+)
+
+PoseConvention = Literal["backward_framewise"]
+Viewpoint = Literal["concat_view"]
+
+_IMAGE_FEATURES = {
+    "front": "observation.images.camera_front",
+    "left": "observation.images.camera_left",
+    "right": "observation.images.camera_right",
+}
+_STATE_FEATURE = "observation.states.end_effector"
+_ACTION_FEATURE = "actions.joint_position"
+
+# 90-degree clockwise rotation about the Z axis in the local frame. This matches
+# the production RoboMIND Franka wrapper conversion to OpenCV coordinates.
+# NOTE(review): numerically this is R_z(+90 deg) in the right-handed convention;
+# confirm that "clockwise" refers to the intended viewing direction.
+_ROBOMIND_FRANKA_TO_OPENCV: np.ndarray = np.array(
+    [[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
+    dtype=np.float32,
+)
+
+_NORMALIZER_PATH = Path(__file__).parent / "stats/robomind_franka_stats.json"
+
+
+def _dual_arm_action_spec() -> ActionSpec:
+    """Build the dual-arm action spec: left arm first, then right.
+
+    Each arm contributes a position term, a ``rot6d`` rotation term and a
+    gripper term, in that order.
+    """
+    return build_action_spec(
+        Pos(prefix="left"),
+        Rot("rot6d", prefix="left"),
+        Gripper(prefix="left"),
+        Pos(prefix="right"),
+        Rot("rot6d", prefix="right"),
+        Gripper(prefix="right"),
+    )
+
+
+class RoboMINDFrankaDataset(ActionBaseDataset):
+    """RoboMIND Franka dual-arm dataset with 20D cartesian actions::
+
+        [left_pos_delta(3), left_rot6d_delta(6), left_gripper(1),
+         right_pos_delta(3), right_rot6d_delta(6), right_gripper(1)]
+
+    Single-arm shards, split/filter logic, image augmentation, fast
+    initialization, and alternate viewpoints are omitted.
+    """
+
+    def __init__(
+        self,
+        root: str,
+        fps: float = 10.0,
+        chunk_length: int = 16,
+        mode: str = "joint",
+        embodiment_type: str = "robomind-franka-dual",
+        pose_convention: PoseConvention = "backward_framewise",
+        tolerance_s: float = 1e-4,
+        viewpoint: Viewpoint = "concat_view",
+        action_normalization: str | None = "quantile",
+        sample_stride: int = 1,
+    ) -> None:
+        """Validate the supported configuration and defer to the base dataset.
+
+        ``embodiment_type`` is forwarded to the base class as ``domain_name``;
+        every other argument is passed through under its own name.
+
+        Raises:
+            NotImplementedError: If ``embodiment_type`` is not
+                ``"robomind-franka-dual"`` or ``viewpoint`` is not ``"concat_view"``.
+        """
+        if embodiment_type != "robomind-franka-dual":
+            raise NotImplementedError("This minimal RoboMIND dataset only supports robomind-franka-dual.")
+        if viewpoint != "concat_view":
+            raise NotImplementedError("This minimal RoboMIND dataset only supports concat_view.")
+        self._embodiment_type = embodiment_type
+        super().__init__(
+            root=root,
+            domain_name=embodiment_type,
+            fps=fps,
+            chunk_length=chunk_length,
+            mode=mode,
+            pose_convention=pose_convention,
+            tolerance_s=tolerance_s,
+            viewpoint=viewpoint,
+            action_normalization=action_normalization,
+            sample_stride=sample_stride,
+        )
+
+    @property
+    def action_dim(self) -> int:
+        """Width of the action vector: two arms x (3 pos + 6 rot6d + 1 gripper)."""
+        return 20
+
+    def _action_spec(self) -> ActionSpec:
+        """Return the dual-arm action spec built by ``_dual_arm_action_spec``."""
+        return _dual_arm_action_spec()
+
+    @classmethod
+    def _stats_path(cls) -> Path:
+        """Return the location of the bundled RoboMIND Franka stats JSON."""
+        return _NORMALIZER_PATH
+
+    def __getitem__(self, idx: int) -> dict[str, Any]:
+        """Assemble one sample: tiled video, action chunk, caption and initial poses.
+
+        ``idx`` addresses strided windows: the window starts at row
+        ``idx * sample_stride`` and is sliced ``chunk_length + 1`` rows long; its
+        first ``chunk_length`` rows also supply the gripper commands.
+        """
+        mode = self._choose_mode()
+        row_idx = int(idx) * self._sample_stride
+        observation_rows = self._rows[row_idx : row_idx + self._chunk_length + 1]
+        action_rows = observation_rows[: self._chunk_length]
+        # Resolve the episode from the first row of the window itself. Looking it
+        # up at ``self._rows[idx]`` only agrees with the window when
+        # ``sample_stride == 1``; for larger strides it can name a different
+        # episode and therefore the wrong video files.
+        episode = self._episodes[int(observation_rows[0]["episode_index"])]
+
+        video = self._load_concat_video(episode, observation_rows)
+        raw_action, initial_pose_left, initial_pose_right = self._build_raw_action(observation_rows, action_rows)
+        task = self._tasks[int(observation_rows[0]["task_index"])]
+        # A task string may hold several " | "-separated phrasings: pick one at
+        # random, falling back to the raw string when nothing survives stripping.
+        ai_caption = random.choice([part.strip() for part in task.split(" | ") if part.strip()] or [task])
+
+        return self._build_result(
+            mode=mode,
+            video=video,
+            action=raw_action,
+            ai_caption=ai_caption,
+            initial_pose=initial_pose_left,
+            initial_pose_right=initial_pose_right,
+            additional_view_description=(
+                "The top row shows a third-person perspective looking towards the dual-arm Franka robot from the front. "
+                "The bottom-left view looks at the scene from the left side, and the bottom-right view looks at the scene from the right side."
+            ),
+        )
+
+    def _load_concat_video(
+        self,
+        episode: dict[str, Any],
+        observation_rows: list[dict[str, Any]],
+    ) -> torch.Tensor:
+        """Decode the three camera streams and tile them into one frame per timestep.
+
+        The front view fills the top of the frame at its decoded size; the left and
+        right views are bilinearly resized to half that height and placed side by
+        side underneath it.
+        """
+        timestamps = [float(row["timestamp"]) for row in observation_rows]
+        frames_by_view = {
+            name: decode_video_frames(
+                self._video_path(episode, video_key),
+                # Row timestamps are offset by the per-video ``from_timestamp`` entry of
+                # the episode record (0.0 when that key is absent).
+                [float(episode.get(f"videos/{video_key}/from_timestamp", 0.0)) + ts for ts in timestamps],
+                self._tolerance_s,
+            )
+            for name, video_key in _IMAGE_FEATURES.items()
+        }
+
+        front = frames_by_view["front"]
+        left = frames_by_view["left"]
+        right = frames_by_view["right"]
+        _, _, h_front, w_front = front.shape
+        half_h = h_front // 2
+        # Split the width so both halves add up to ``w_front`` exactly. Using
+        # ``w_front // 2`` for both leaves the bottom row one pixel short when the
+        # width is odd, and the vertical concatenation below then fails.
+        left_w = w_front // 2
+        right_w = w_front - left_w
+        left = F.interpolate(left, size=(half_h, left_w), mode="bilinear", align_corners=False)
+        right = F.interpolate(right, size=(half_h, right_w), mode="bilinear", align_corners=False)
+        bottom = torch.cat([left, right], dim=-1)
+        return torch.cat([front, bottom], dim=-2)
+
+    def _build_relative_poses(
+        self,
+        positions: np.ndarray,
+        euler_xyz: np.ndarray,
+    ) -> tuple[np.ndarray, torch.Tensor]:
+        """Turn one arm's positions and Euler angles into relative rot6d poses.
+
+        Returns:
+            The output of ``pose_abs_to_rel`` and a float tensor copy of the first
+            absolute pose, taken after the rotation block has been re-based.
+        """
+        poses_abs = build_abs_pose_from_components(positions, euler_xyz, "euler_xyz")
+        # Right-multiplying applies the fixed axis change in each pose's local frame.
+        poses_abs[:, :3, :3] = poses_abs[:, :3, :3] @ _ROBOMIND_FRANKA_TO_OPENCV
+        initial_pose = torch.from_numpy(poses_abs[0].copy()).float()
+        poses_rel = pose_abs_to_rel(poses_abs, rotation_format="rot6d", pose_convention=self._pose_convention)
+        return poses_rel, initial_pose
+
+    def _build_raw_action(
+        self,
+        observation_rows: list[dict[str, Any]],
+        action_rows: list[dict[str, Any]],
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Build the action chunk plus the initial pose of each arm.
+
+        End-effector state columns 0:3 / 3:6 are read as the left arm's position
+        and Euler angles, columns 6:9 / 9:12 as the right arm's. The gripper terms
+        are ``1 - x`` of columns 7 and 15 of the joint-position action.
+        """
+        state = np.asarray([row[_STATE_FEATURE] for row in observation_rows], dtype=np.float32)
+        gripper = np.asarray([row[_ACTION_FEATURE] for row in action_rows], dtype=np.float32)
+
+        poses_rel_left, initial_pose_left = self._build_relative_poses(state[:, 0:3], state[:, 3:6])
+        poses_rel_right, initial_pose_right = self._build_relative_poses(state[:, 6:9], state[:, 9:12])
+        # Every term keeps its trailing ``chunk_length`` rows before being joined
+        # column-wise. NOTE(review): assumes ``pose_abs_to_rel`` yields at least
+        # that many rows - confirm against its contract.
+        action = np.concatenate(
+            [
+                poses_rel_left[-self._chunk_length :],
+                1.0 - gripper[-self._chunk_length :, [7]],
+                poses_rel_right[-self._chunk_length :],
+                1.0 - gripper[-self._chunk_length :, [15]],
+            ],
+            axis=-1,
+        )
+        return torch.from_numpy(action).float(), initial_pose_left, initial_pose_right
diff --git a/cosmos_framework/data/vfm/action/datasets/stats/agibotworld_beta_lerobot_stats.json b/cosmos_framework/data/vfm/action/datasets/stats/agibotworld_beta_lerobot_stats.json
new file mode 100644
index 00000000..970ac30d
--- /dev/null
+++ b/cosmos_framework/data/vfm/action/datasets/stats/agibotworld_beta_lerobot_stats.json
@@ -0,0 +1,4 @@
+{
+  "q01": [-0.000167, -0.007272, -0.014935, 0.999999, -0.000306, -0.000594, -0.000260, 0.999227, -0.025516, -0.012912, -0.017163, -0.017614, 0.994613, -0.064506, -0.053231, -0.066267, 0.994383, -0.051163, 0.000000, -0.011640, -0.015508, -0.013880, 0.996511, -0.050126, -0.040305, -0.047330, 0.996618, -0.038303, 0.000000],
+  "q99": [ 0.000164, 0.004822, 0.013706, 1.000000, 0.000240, 0.000703, 0.000278, 1.000000, 0.030090, 0.013182, 0.016960, 0.016101, 1.000000, 0.066268, 0.053905, 0.064357, 1.000000, 0.052547, 1.000000, 0.010890, 0.015347, 0.012968, 1.000000, 0.047482, 0.042217, 0.050173, 1.000000, 0.041428, 1.000000]
+}
diff --git a/cosmos_framework/data/vfm/action/datasets/stats/bridge_orig_lerobot_stats.json b/cosmos_framework/data/vfm/action/datasets/stats/bridge_orig_lerobot_stats.json
new file mode 100644
index 00000000..66d1d799
--- /dev/null
+++ b/cosmos_framework/data/vfm/action/datasets/stats/bridge_orig_lerobot_stats.json
@@ -0,0 +1,4 @@
+{
+  "q01": [-0.038884, -0.028667, -0.037840, 0.976292, -0.163098, -0.081545, -0.160193, 0.976322, -0.078872, 0.000000],
+  "q99": [ 0.039722, 0.029068, 0.026702, 1.000000, 0.160195, 0.081655, 0.163227, 1.000000, 0.095189, 1.000000]
+}
diff --git a/cosmos_framework/data/vfm/action/datasets/droid_lerobot_normalization.json b/cosmos_framework/data/vfm/action/datasets/stats/droid_lerobot_stats.json
similarity index
100% rename from cosmos_framework/data/vfm/action/datasets/droid_lerobot_normalization.json rename to cosmos_framework/data/vfm/action/datasets/stats/droid_lerobot_stats.json diff --git a/cosmos_framework/data/vfm/action/datasets/stats/robomind_franka_stats.json b/cosmos_framework/data/vfm/action/datasets/stats/robomind_franka_stats.json new file mode 100644 index 00000000..66e3c3ce --- /dev/null +++ b/cosmos_framework/data/vfm/action/datasets/stats/robomind_franka_stats.json @@ -0,0 +1,4 @@ +{ + "q01": [-0.051367, -0.031964, -0.046482, 0.988101, -0.053179, -0.128603, -0.075432, 0.994427, -0.059973, 0.000000, -0.035108, -0.021212, -0.029788, 0.986086, -0.098043, -0.111441, -0.093441, 0.991492, -0.058030, 0.000000], + "q99": [ 0.043729, 0.021737, 0.036738, 1.000000, 0.075612, 0.102791, 0.053223, 1.000000, 0.077057, 1.000000, 0.047581, 0.021270, 0.025712, 1.000000, 0.095525, 0.126049, 0.098778, 1.000000, 0.041914, 0.995443] +} diff --git a/cosmos_framework/data/vfm/action/urdf_visualizer/G1_omnipicker_calibrated.urdf b/cosmos_framework/data/vfm/action/urdf_visualizer/G1_omnipicker_calibrated.urdf new file mode 100644 index 00000000..bd83679e --- /dev/null +++ b/cosmos_framework/data/vfm/action/urdf_visualizer/G1_omnipicker_calibrated.urdf @@ -0,0 +1,1350 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +