diff --git a/strands_robots_sim/policies/groot/__init__.py b/strands_robots_sim/policies/groot/__init__.py index 89b62ce..6a34d3d 100644 --- a/strands_robots_sim/policies/groot/__init__.py +++ b/strands_robots_sim/policies/groot/__init__.py @@ -1,6 +1,10 @@ #!/usr/bin/env python3 """GR00T Policy — natural language robot control via GR00T inference servers. +Adapts observation formatting and transport to the active protocol +(``sim_wrapper`` or ``direct``). Data configs are unchanged; the +protocol only controls HOW data is shaped and sent. + SPDX-License-Identifier: Apache-2.0 """ @@ -12,7 +16,7 @@ from .. import Policy from .client import GR00TClient -from .data_config import load_data_config +from .data_config import get_protocol, load_data_config logger = logging.getLogger(__name__) @@ -21,24 +25,32 @@ class Gr00tPolicy(Policy): """GR00T policy: connects to a GR00T inference server via ZMQ.""" def __init__(self, data_config: Union[str, dict], host: str = "localhost", port: int = 5555, **kwargs): - """Initialize GR00T policy. + protocol_override = kwargs.pop("protocol", kwargs.pop("groot_version", None)) + # Map legacy version aliases to protocol names + _aliases = {"n1d6": "sim_wrapper", "n1.6": "sim_wrapper", "n1d5": "direct", "n1.5": "direct"} + if protocol_override in _aliases: + protocol_override = _aliases[protocol_override] - Args: - data_config: Config name (e.g. "libero") or dict with video/state/action/language keys - host: Inference service host - port: Inference service port - """ - self.config = load_data_config(data_config) + self.config = load_data_config(data_config, protocol=protocol_override) self.data_config_name = data_config if isinstance(data_config, str) else "custom" - self.client = GR00TClient(host=host, port=port) + + self.protocol_name = self.config.get("protocol", "auto") + self.protocol = get_protocol(self.protocol_name) + + self.client = GR00TClient(host=host, port=port, protocol=self.protocol_name) self.camera_keys = self.config["video"] self.state_keys = self.config["state"] self.action_keys = self.config["action"] self.language_keys = self.config["language"] - self.robot_state_keys = [] + self.robot_state_keys: List[str] = [] + + logger.info(f"🧠 GR00T Policy: {self.data_config_name} @ {host}:{port} (protocol: {self.protocol_name})") - logger.info(f"🧠 GR00T Policy: {self.data_config_name} @ {host}:{port}") + # Backward-compat alias + @property + def groot_version(self) -> str: + return self.protocol_name @property def provider_name(self) -> str: @@ -47,280 +59,291 @@ def provider_name(self) -> str: def set_robot_state_keys(self, robot_state_keys: List[str]) -> None: self.robot_state_keys = robot_state_keys + # ------------------------------------------------------------------ + # Main entry point + # ------------------------------------------------------------------ + async def get_actions(self, observation_dict: Dict[str, Any], instruction: str, **kwargs) -> List[Dict[str, Any]]: - """Get actions from GR00T policy server. + obs = self._build_observation(observation_dict, instruction) + try: + action_chunk = self.client.get_action(obs) + except Exception as e: + logger.error(f"GR00T inference failed: {e}") + action_chunk = self._create_fallback_actions() + return self._to_robot_actions(action_chunk) - Args: - observation_dict: Robot observations (cameras + state) - instruction: Natural language instruction + # ------------------------------------------------------------------ + # Observation building — driven by protocol descriptor + # ------------------------------------------------------------------ - Returns: - List of action dicts for robot execution - """ - obs = {} + def _build_observation(self, observation_dict: Dict[str, Any], instruction: str) -> dict: + """Build observation dict. Protocol controls shape/dtype/wrapping.""" + obs: dict = {} + video_ndim = self.protocol.get("video_ndim", 4) + state_dtype = self.protocol.get("state_dtype", np.float64) - # Camera observations with resizing + # Video for vkey in self.camera_keys: cam = self._find_camera(vkey, observation_dict) - if cam and cam in observation_dict: - image = observation_dict[cam] - if "so100" in self.data_config_name.lower(): - obs[vkey] = self._resize_image(image, target_size=(720, 1280)) - else: - obs[vkey] = self._resize_image(image, target_size=(256, 256)) - else: - if "so100" in self.data_config_name.lower(): - obs[vkey] = np.zeros((720, 1280, 3), dtype=np.uint8) - else: - obs[vkey] = np.zeros((256, 256, 3), dtype=np.uint8) - - # State observations - robot_state_parts = [] - for k in self.robot_state_keys: - value = observation_dict.get(k, 0.0) - if isinstance(value, (list, np.ndarray)): - robot_state_parts.extend(np.atleast_1d(value).flatten()) - else: - robot_state_parts.append(float(value)) - robot_state = np.array(robot_state_parts, dtype=np.float64) - + img = observation_dict[cam] if cam and cam in observation_dict else None + img = ( + self._resize_image(img, self._image_size()) + if img is not None + else np.zeros((*self._image_size(), 3), dtype=np.uint8) + ) + obs[vkey] = self._add_video_dims(img, video_ndim) + + # State if "libero" in self.data_config_name.lower(): - self._map_libero_state(obs, observation_dict) + self._map_libero_state(obs, observation_dict, state_dtype, video_ndim) else: - self._map_state(obs, robot_state) + self._map_state(obs, observation_dict, state_dtype) - # Language instruction + # Language if self.language_keys: - obs[self.language_keys[0]] = instruction + lang_type = self.protocol.get("language_type", "list") + obs[self.language_keys[0]] = (instruction,) if lang_type == "tuple" else [instruction] - # Batch dimension - for k in obs: - if isinstance(obs[k], np.ndarray) and k.startswith("video."): - obs[k] = np.expand_dims(obs[k], axis=0) - elif isinstance(obs[k], str): - obs[k] = [obs[k]] + return obs - try: - action_chunk = self.client.get_action(obs) - except Exception as e: - logger.error(f"GR00T inference failed: {e}") - action_chunk = self._create_fallback_actions() + @staticmethod + def _add_video_dims(image: np.ndarray, ndim: int) -> np.ndarray: + image = image.astype(np.uint8) + assert image.ndim == 3, f"Expected (H, W, C) image, got ndim={image.ndim} shape={image.shape}" + if ndim == 5: + return image.reshape(1, 1, *image.shape) # (B, T, H, W, C) + return np.expand_dims(image, 0) # (B, H, W, C) - return self._to_robot_actions(action_chunk) + def _image_size(self) -> tuple: + return (720, 1280) if "so100" in self.data_config_name.lower() else (256, 256) - def _find_camera(self, video_key: str, obs: dict) -> str: - """Map GR00T video key to available camera key.""" - if video_key in obs: - return video_key + # ------------------------------------------------------------------ + # State mapping + # ------------------------------------------------------------------ - name = video_key.replace("video.", "") - if name in obs: - return name + def _map_libero_state(self, obs: dict, env_obs: dict, dtype, video_ndim: int): + """Decompose Libero eef_pos / eef_quat into state.x, state.y … state.gripper. - # Libero-specific aliases - libero_aliases = { - "image": ["front_camera", "agentview_image", "front", "webcam", "main"], - "wrist_image": ["wrist_camera", "robot0_eye_in_hand_image", "wrist", "hand", "end_effector"], - } - if name in libero_aliases: - for candidate in libero_aliases[name]: - if candidate in obs: - return candidate + If eef_pos/eef_quat are missing, zero-valued state entries are still + added so the server always receives a complete observation. + """ + eef_pos = env_obs.get("robot0_eef_pos") + eef_quat = env_obs.get("robot0_eef_quat") + gripper = env_obs.get("robot0_gripper_qpos", np.array([0.0, 0.0])) + + if eef_pos is None or eef_quat is None: + logger.warning("robot0_eef_pos/eef_quat missing from observation — using zeros") + eef_pos = eef_pos if eef_pos is not None else np.zeros(3) + eef_quat = eef_quat if eef_quat is not None else np.array([0, 0, 0, 1.0]) + rpy = self._quat2axisangle(eef_quat) + + scalars = {"x": eef_pos[0], "y": eef_pos[1], "z": eef_pos[2], "roll": rpy[0], "pitch": rpy[1], "yaw": rpy[2]} + + for name, val in scalars.items(): + key = f"state.{name}" + if video_ndim == 5: + obs[key] = np.array([[[val]]], dtype=dtype) + else: + obs[key] = np.array([[val]], dtype=dtype) + + gripper_arr = np.asarray(gripper, dtype=dtype) + if video_ndim == 5: + obs["state.gripper"] = gripper_arr.reshape(1, 1, -1) + else: + obs["state.gripper"] = np.expand_dims(gripper_arr, 0) + + def _map_state(self, obs: dict, env_obs: dict, dtype): + parts = [] + for k in self.robot_state_keys: + v = env_obs.get(k, 0.0) + parts.extend(np.atleast_1d(v).flatten() if isinstance(v, (list, np.ndarray)) else [float(v)]) + state = np.array(parts, dtype=dtype) + + name = self.data_config_name.lower() + if "so100" in name and len(state) >= 6: + obs["state.single_arm"] = state[:5] + obs["state.gripper"] = state[5:6] + elif "fourier_gr1" in name and len(state) >= 14: + obs["state.left_arm"] = state[:7] + obs["state.right_arm"] = state[7:14] + elif "unitree_g1" in name and len(state) >= 14: + obs["state.left_arm"] = state[:7] + obs["state.right_arm"] = state[7:14] + elif "bimanual_panda" in name and len(state) >= 12: + obs["state.right_arm_eef_pos"] = state[:3] + obs["state.right_arm_eef_quat"] = state[3:7] + obs["state.left_arm_eef_pos"] = state[7:10] + obs["state.left_arm_eef_quat"] = state[10:14] + elif self.state_keys and len(state) > 0: + obs[self.state_keys[0]] = state + # ------------------------------------------------------------------ + # Camera helpers + # ------------------------------------------------------------------ + + def _find_camera(self, video_key: str, obs: dict) -> str: + name = video_key.replace("video.", "") + for candidate in (video_key, name): + if candidate in obs: + return candidate aliases = { + "image": ["front_camera", "agentview_image", "front", "webcam", "main"], + "wrist_image": ["wrist_camera", "robot0_eye_in_hand_image", "wrist", "hand", "end_effector"], "webcam": ["webcam", "front", "wrist", "main"], "front": ["front", "webcam", "top", "ego_view", "main"], "wrist": ["wrist", "hand", "end_effector", "gripper"], "ego_view": ["front", "ego_view", "webcam", "main"], - "top": ["top", "overhead", "front"], - "side": ["side", "lateral", "left", "right"], "rs_view": ["rs_view", "front", "ego_view", "webcam"], } - for candidate in aliases.get(name, [name]): - if candidate in obs: - return candidate - - # Fallback: first camera-like key + for c in aliases.get(name, [name]): + if c in obs: + return c cams = [ k for k in obs - if any(n in k.lower() for n in ["camera", "image", "webcam", "front", "wrist", "video", "rgb", "depth"]) - and not k.startswith("state.") - and not k.startswith("robot0_joint") - and not k.startswith("robot0_eef") - and not k.startswith("robot0_gripper") + if any(n in k.lower() for n in ("camera", "image", "webcam", "front", "wrist", "video", "rgb")) + and not k.startswith(("state.", "robot0_joint", "robot0_eef", "robot0_gripper")) ] return cams[0] if cams else None - def _resize_image(self, image: np.ndarray, target_size: tuple = (256, 256)) -> np.ndarray: - """Resize image to match GR00T server expectations.""" + def _resize_image(self, image: np.ndarray, target: tuple = (256, 256)) -> np.ndarray: + """Resize image to target (H, W). Always returns a 3-D (H, W, C) array.""" try: - if len(image.shape) == 4: + if image.ndim == 4: image = image[0] - elif len(image.shape) == 2: + elif image.ndim == 2: image = image[..., np.newaxis] - h, w = image.shape[:2] - th, tw = target_size - if h == th and w == tw: + th, tw = target + if (h, w) == (th, tw): return image - try: - import cv2 # nosec B404 + import cv2 return cv2.resize(image, (tw, th), interpolation=cv2.INTER_LINEAR) except ImportError: pass - try: from scipy.ndimage import zoom - factors = (th / h, tw / w, 1) if len(image.shape) == 3 else (th / h, tw / w) - return zoom(image, factors, order=1).astype(image.dtype) + return zoom(image, (th / h, tw / w, 1) if image.ndim == 3 else (th / h, tw / w), order=1).astype( + image.dtype + ) except ImportError: pass - - # Numpy nearest-neighbor fallback - h_idx = np.linspace(0, h - 1, th).astype(int) - w_idx = np.linspace(0, w - 1, tw).astype(int) - if len(image.shape) == 3: - return image[np.ix_(h_idx, w_idx, range(image.shape[2]))] - return image[np.ix_(h_idx, w_idx)] + hi, wi = np.linspace(0, h - 1, th).astype(int), np.linspace(0, w - 1, tw).astype(int) + return image[np.ix_(hi, wi, range(image.shape[2]))] if image.ndim == 3 else image[np.ix_(hi, wi)] except Exception: + # Ensure we always return 3-D so _add_video_dims doesn't fail + if image.ndim == 2: + return image[..., np.newaxis] + if image.ndim == 4: + return image[0] return image - def _map_libero_state(self, obs: dict, observation_dict: dict): - """Map Libero end-effector pose to GR00T state format.""" - if "robot0_eef_pos" in observation_dict and "robot0_eef_quat" in observation_dict: - xyz = observation_dict["robot0_eef_pos"] - quat = observation_dict["robot0_eef_quat"] - gripper = observation_dict.get("robot0_gripper_qpos", np.array([0.0, 0.0])) - rpy = self._quat2axisangle(quat) - obs["state.x"] = np.array([[xyz[0]]]) - obs["state.y"] = np.array([[xyz[1]]]) - obs["state.z"] = np.array([[xyz[2]]]) - obs["state.roll"] = np.array([[rpy[0]]]) - obs["state.pitch"] = np.array([[rpy[1]]]) - obs["state.yaw"] = np.array([[rpy[2]]]) - obs["state.gripper"] = np.expand_dims(gripper, axis=0) - else: - for key in ("x", "y", "z", "roll", "pitch", "yaw"): - obs[f"state.{key}"] = np.array([[0.0]], dtype=np.float64) - obs["state.gripper"] = np.array([[0.0]], dtype=np.float64) - - def _map_state(self, obs: dict, state: np.ndarray): - """Map robot state array to GR00T state keys.""" - name = self.data_config_name.lower() - if "so100" in name and len(state) >= 6: - obs["state.single_arm"] = state[:5].astype(np.float64) - obs["state.gripper"] = state[5:6].astype(np.float64) - elif "fourier_gr1" in name and len(state) >= 14: - obs["state.left_arm"] = state[:7].astype(np.float64) - obs["state.right_arm"] = state[7:14].astype(np.float64) - elif "unitree_g1" in name and len(state) >= 14: - obs["state.left_arm"] = state[:7].astype(np.float64) - obs["state.right_arm"] = state[7:14].astype(np.float64) - elif "bimanual_panda" in name and len(state) >= 12: - obs["state.right_arm_eef_pos"] = state[:3].astype(np.float64) - obs["state.right_arm_eef_quat"] = state[3:7].astype(np.float64) - obs["state.left_arm_eef_pos"] = state[7:10].astype(np.float64) - obs["state.left_arm_eef_quat"] = state[10:14].astype(np.float64) - elif self.state_keys and len(state) > 0: - obs[self.state_keys[0]] = state.astype(np.float64) + # ------------------------------------------------------------------ + # Action conversion + # ------------------------------------------------------------------ def _to_robot_actions(self, chunk: dict) -> List[Dict[str, Any]]: - """Convert GR00T action chunk to list of robot action dicts.""" - act_key = None - for k in self.action_keys: - base = k.replace("action.", "") if k.startswith("action.") else k - full = f"action.{base}" - if full in chunk: - act_key = full - break - if not act_key: - act_keys = [k for k in chunk if k.startswith("action.")] - act_key = act_keys[0] if act_keys else None + # Strip batch dim if protocol says response has one + if self.protocol.get("response_batch_dim"): + chunk = { + k: v[0] if isinstance(v, np.ndarray) and v.ndim == 3 and v.shape[0] == 1 else v + for k, v in chunk.items() + } + + act_key = self._find_action_key(chunk) if not act_key: return [] - horizon = chunk[act_key].shape[0] - actions = [] + actions: list = [] if "libero" in self.data_config_name.lower(): for i in range(horizon): - action_array = self._to_libero_action(chunk, idx=i) - actions.append({"action": action_array.tolist()}) + actions.append({"action": self._to_libero_action(chunk, i).tolist()}) else: for i in range(horizon): parts = [] for k in self.action_keys: - mod = k.split(".")[-1] - if f"action.{mod}" in chunk: - parts.append(np.atleast_1d(chunk[f"action.{mod}"][i])) - if not parts: - for k, v in chunk.items(): - if k.startswith("action."): - parts.append(np.atleast_1d(v[i])) - - concat = np.concatenate(parts) if parts else np.zeros(len(self.robot_state_keys) or 6) - actions.append( - {k: float(concat[j]) if j < len(concat) else 0.0 for j, k in enumerate(self.robot_state_keys)} - ) - + mod = k.split(".")[-1] if "." in k else k + for c in (k, f"action.{mod}", mod): + if c in chunk: + parts.append(np.atleast_1d(chunk[c][i]).flatten()) + break + cat = np.concatenate(parts) if parts else np.zeros(len(self.robot_state_keys) or 6) + actions.append({k: float(cat[j]) if j < len(cat) else 0.0 for j, k in enumerate(self.robot_state_keys)}) return actions - @staticmethod - def _quat2axisangle(quat: np.ndarray) -> np.ndarray: - """Convert quaternion (x,y,z,w) to axis-angle (roll,pitch,yaw).""" - quat = np.array(quat) - quat[3] = np.clip(quat[3], -1.0, 1.0) - den = np.sqrt(1.0 - quat[3] * quat[3]) - if math.isclose(den, 0.0): - return np.zeros(3) - return (quat[:3] * 2.0 * math.acos(quat[3])) / den - - def _to_libero_action(self, action_chunk: dict, idx: int = 0) -> np.ndarray: - """Convert GR00T action chunk to Libero 7-dim: [dx,dy,dz,droll,dpitch,dyaw,gripper].""" - components = [] + def _find_action_key(self, chunk: dict) -> str: + for k in self.action_keys: + base = k.split(".")[-1] if "." in k else k + for c in (k, f"action.{base}", base): + if c in chunk: + return c + for k in chunk: + if k.startswith("action."): + return k + return None + + def _to_libero_action(self, chunk: dict, idx: int = 0) -> np.ndarray: + parts = [] for key in ("x", "y", "z", "roll", "pitch", "yaw", "gripper"): - full_key = f"action.{key}" - if full_key in action_chunk: - components.append(np.atleast_1d(action_chunk[full_key][idx])[0]) + for c in (f"action.{key}", key): + if c in chunk: + parts.append(float(np.asarray(chunk[c][idx]).flatten()[0])) + break else: - components.append(0.0) - action = np.array(components, dtype=np.float32) - action = self._normalize_gripper(action) - assert len(action) == 7 # nosec B101 + parts.append(0.0) + action = np.array(parts, dtype=np.float32) + action[-1] = np.sign(1 - 2 * action[-1]) # gripper [0,1] → {+1,−1} return action - @staticmethod - def _normalize_gripper(action: np.ndarray, binarize: bool = True) -> np.ndarray: - """Normalize gripper action from [0,1] to [+1,-1].""" - action[..., -1] = 1 - 2 * action[..., -1] - if binarize: - action[..., -1] = np.sign(action[..., -1]) - return action + # Typical dimensionality for known action key fragments. + _ACTION_DIM_PATTERNS = { + "joint_pos": 7, + "joint_vel": 7, + "eef_pos": 3, + "eef_quat": 4, + "eef_rot": 3, + "gripper_qpos": 1, + "gripper_close": 1, + "gripper": 1, + "left_arm": 7, + "right_arm": 7, + "left_hand": 1, + "right_hand": 1, + "single_arm": 5, + } + + @classmethod + def _infer_action_dim(cls, key: str) -> int: + """Infer action dimensionality from a key name like 'action.robot0_joint_pos'.""" + name = key.split(".")[-1] if "." in key else key + # Try longest suffix match first (e.g. "gripper_qpos" before "gripper") + for pattern in sorted(cls._ACTION_DIM_PATTERNS, key=len, reverse=True): + if name.endswith(pattern): + return cls._ACTION_DIM_PATTERNS[pattern] + return 1 def _create_fallback_actions(self) -> dict: - """Create zero-action fallback when inference fails.""" + h = 16 if self.protocol_name == "sim_wrapper" else 8 chunk = {} - horizon = 8 for key in self.action_keys: - mod = key.split(".")[-1] - if "joint_pos" in mod.lower(): - dim = 7 - elif "eef_pos" in mod.lower(): - dim = 3 - elif "eef_quat" in mod.lower(): - dim = 4 - elif "gripper" in mod.lower(): - dim = 1 - else: - dim = len(self.robot_state_keys) // 5 if self.robot_state_keys else 7 - chunk[f"action.{mod}"] = np.zeros((horizon, dim), dtype=np.float64) - if not chunk: - chunk["action.robot0_joint_pos"] = np.zeros((horizon, 7), dtype=np.float64) + dim = self._infer_action_dim(key) + chunk[key] = np.zeros((h, dim), dtype=np.float32) return chunk + # ------------------------------------------------------------------ + # Math + # ------------------------------------------------------------------ + + @staticmethod + def _quat2axisangle(quat: np.ndarray) -> np.ndarray: + q = np.array(quat) + q[3] = np.clip(q[3], -1.0, 1.0) + den = np.sqrt(1.0 - q[3] * q[3]) + return np.zeros(3) if math.isclose(den, 0.0) else (q[:3] * 2.0 * math.acos(q[3])) / den + __all__ = ["Gr00tPolicy"] diff --git a/strands_robots_sim/policies/groot/client.py b/strands_robots_sim/policies/groot/client.py index e3f43a0..867dc04 100644 --- a/strands_robots_sim/policies/groot/client.py +++ b/strands_robots_sim/policies/groot/client.py @@ -1,6 +1,18 @@ #!/usr/bin/env python3 """GR00T inference client — thin ZMQ wrapper for policy server communication. +Supports two transport protocols controlled by the ``protocol`` parameter: + +* ``"sim_wrapper"`` — Wraps observations in ``{"observation": obs}`` and + handles tuple ``(action, info)`` responses. Used by Isaac-GR00T N1.6+ + servers running ``Gr00tSimPolicyWrapper``. + +* ``"direct"`` — Sends observations as a flat data dict. Used by + Isaac-GR00T N1.5 servers. + +* ``"auto"`` (default) — Tries ``sim_wrapper`` first, falls back to + ``direct`` on error. + SPDX-License-Identifier: Apache-2.0 """ @@ -30,20 +42,60 @@ def _decode(obj): class GR00TClient: """Minimal ZMQ client for GR00T inference servers.""" - def __init__(self, host="localhost", port=5555): + def __init__(self, host="localhost", port=5555, protocol="auto"): self.ctx = zmq.Context() self.sock = self.ctx.socket(zmq.REQ) self.sock.connect(f"tcp://{host}:{port}") + self.protocol = protocol def get_action(self, observations): - """Send observations, receive action chunk.""" - request = {"endpoint": "get_action", "data": observations} + """Send observations and receive an action chunk. + + The request format adapts to the configured protocol. + """ + if self.protocol == "sim_wrapper": + return self._request({"observation": observations}) + elif self.protocol == "direct": + return self._request_flat(observations) + else: + # Auto-detect: try sim_wrapper first; if the response lacks + # action keys, retry with the direct protocol. + try: + result = self._request({"observation": observations}) + if self._has_action_keys(result): + return result + except Exception: + pass + return self._request_flat(observations) + + def _request(self, data): + """Send wrapped request and parse response.""" + raw = self._send_recv(data) + # sim_wrapper returns (action_dict, info_dict) tuple + if isinstance(raw, (list, tuple)) and len(raw) == 2: + return raw[0] + return raw + + def _request_flat(self, observations): + """Send flat request (legacy protocol).""" + return self._send_recv(observations) + + def _send_recv(self, data): + """Low-level send/receive with error handling.""" + request = {"endpoint": "get_action", "data": data} self.sock.send(msgpack.packb(request, default=_encode)) response = msgpack.unpackb(self.sock.recv(), object_hook=_decode) if isinstance(response, dict) and "error" in response: raise RuntimeError(f"GR00T server error: {response['error']}") return response + @staticmethod + def _has_action_keys(result): + """Check whether the response looks like a valid action chunk.""" + if not isinstance(result, dict) or not result: + return False + return any(k.startswith("action.") or k in ("action",) for k in result) + def ping(self): """Check server connectivity.""" try: diff --git a/strands_robots_sim/policies/groot/data_config.py b/strands_robots_sim/policies/groot/data_config.py index 82cf9b4..abde82f 100644 --- a/strands_robots_sim/policies/groot/data_config.py +++ b/strands_robots_sim/policies/groot/data_config.py @@ -1,10 +1,49 @@ #!/usr/bin/env python3 """GR00T data configurations — robot embodiment key mappings. +Each config maps modality names (video, state, action, language) to the +keys expected by the GR00T inference server. + +Configs may optionally include a ``protocol`` field that selects the +transport format. Two protocols exist: + +* ``"sim_wrapper"`` — For ``Gr00tSimPolicyWrapper`` servers (Isaac-GR00T + N1.6+). 5-D video, float32 state, wrapped request envelope. +* ``"direct"`` — For bare-policy servers (Isaac-GR00T N1.5 and earlier). + 4-D video, float64 state, flat request. + +When ``protocol`` is absent the client auto-detects at runtime. + SPDX-License-Identifier: Apache-2.0 """ -# Each config: (video_keys, state_keys, action_keys, language_keys) +import numpy as np + +# --------------------------------------------------------------------------- +# Protocol descriptors — HOW observations are formatted and sent. +# --------------------------------------------------------------------------- + +PROTOCOLS = { + "sim_wrapper": { + "video_ndim": 5, # (B, T, H, W, C) + "state_dtype": np.float32, + "request_wrap": "observation", + "response_batch_dim": True, + "language_type": "tuple", + }, + "direct": { + "video_ndim": 4, # (B, H, W, C) + "state_dtype": np.float64, + "request_wrap": None, + "response_batch_dim": False, + "language_type": "list", + }, +} + +# --------------------------------------------------------------------------- +# Embodiment data configs — WHAT keys to use (unchanged from original). +# --------------------------------------------------------------------------- + DATA_CONFIGS = { "fourier_gr1_arms_only": { "video": ["video.ego_view"], @@ -49,6 +88,7 @@ "action.robot0_gripper_qpos", ], "language": ["annotation.human.action.task_description"], + "protocol": "sim_wrapper", }, "libero_spatial": { "video": ["video.image", "video.wrist_image"], @@ -61,6 +101,7 @@ "action.robot0_gripper_qpos", ], "language": ["annotation.human.action.task_description"], + "protocol": "sim_wrapper", }, "libero_goal": { "video": ["video.image", "video.wrist_image"], @@ -73,6 +114,7 @@ "action.robot0_gripper_qpos", ], "language": ["annotation.human.action.task_description"], + "protocol": "sim_wrapper", }, "libero_meanstd": { "video": ["video.image", "video.wrist_image"], @@ -89,15 +131,58 @@ } -def load_data_config(name): - """Load a data config by name. Returns dict with video/state/action/language keys.""" +def get_protocol(name): + """Return a copy of the named protocol descriptor, or empty dict.""" + return dict(PROTOCOLS.get(name, {})) + + +def load_data_config(name, protocol=None): + """Load a data config by name. + + Args: + name: Config name (e.g. "libero") or dict. + Supports "name:protocol" syntax (e.g. "libero:sim_wrapper"). + protocol: Override the protocol field in the returned config. + + Returns: + Dict with video / state / action / language keys and optional + ``protocol`` field. + """ if isinstance(name, dict): return name + + # Support "name:protocol" shorthand + if isinstance(name, str) and ":" in name: + name, protocol = name.split(":", 1) + + config = _lookup(name) + if config is None: + raise ValueError(f"Unknown data_config '{name}'. Available: {list(DATA_CONFIGS.keys())}") + + if protocol: + config["protocol"] = protocol + + return config + + +def _lookup(name): + """Exact or fuzzy config lookup. Returns a shallow copy. + + Resolution order: + 1. Exact match against DATA_CONFIGS keys. + 2. Fuzzy: any name containing "libero" falls back to the base + ``libero`` config (or ``libero_meanstd`` if "goal"/"meanstd" + appears in the name). This lets callers pass suite names like + ``"libero_object"`` without registering every variant. + 3. None if nothing matches — caller should raise ValueError. + + Note: the ``protocol`` field in the returned config can still be + overridden by ``load_data_config(..., protocol=...)`` after lookup. + """ if name in DATA_CONFIGS: - return DATA_CONFIGS[name] - # Fuzzy match: any name containing "libero" falls back to the base libero config + return dict(DATA_CONFIGS[name]) if isinstance(name, str) and "libero" in name.lower(): if "goal" in name.lower() or "meanstd" in name.lower(): - return DATA_CONFIGS["libero_meanstd"] - return DATA_CONFIGS["libero"] - raise ValueError(f"Unknown data_config '{name}'. Available: {list(DATA_CONFIGS.keys())}") + return dict(DATA_CONFIGS["libero_meanstd"]) + return dict(DATA_CONFIGS["libero"]) + return None diff --git a/tests/test_groot_n1d6.py b/tests/test_groot_n1d6.py new file mode 100644 index 0000000..304f0b0 --- /dev/null +++ b/tests/test_groot_n1d6.py @@ -0,0 +1,262 @@ +"""Tests for GR00T multi-protocol support (sim_wrapper + direct). + +Verifies that observation formatting adapts to the active protocol +while data configs remain unchanged. All tests are mock-only — +GR00TClient is patched so no real ZMQ socket is created. +""" + +from unittest.mock import patch + +import numpy as np +import pytest + +from strands_robots_sim.policies.groot.data_config import PROTOCOLS, load_data_config + +pytestmark = pytest.mark.mock + +_CLIENT_PATH = "strands_robots_sim.policies.groot.GR00TClient" + + +# -- Fixtures ---------------------------------------------------------------- + + +@pytest.fixture +def libero_obs(): + return { + "robot0_joint_pos": np.zeros(7), + "robot0_joint_vel": np.zeros(7), + "robot0_eef_pos": np.array([0.1, 0.2, 0.3]), + "robot0_eef_quat": np.array([0.0, 0.0, 0.0, 1.0]), + "robot0_gripper_qpos": np.array([0.02, -0.02]), + "agentview_image": np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8), + "robot0_eye_in_hand_image": np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8), + } + + +@pytest.fixture +def state_keys(): + return ["robot0_joint_pos", "robot0_joint_vel", "robot0_eef_pos", "robot0_eef_quat", "robot0_gripper_qpos"] + + +def _make_policy(protocol, state_keys): + from strands_robots_sim.policies.groot import Gr00tPolicy + + with patch(_CLIENT_PATH): + p = Gr00tPolicy(data_config="libero", host="localhost", port=9999, protocol=protocol) + p.set_robot_state_keys(state_keys) + return p + + +# -- DataConfig tests -------------------------------------------------------- + + +class TestDataConfig: + """Configs are unchanged from original; protocol is additive.""" + + def test_original_configs_unchanged(self): + """Every original config still has its original keys.""" + for name in ("libero", "libero_spatial", "libero_goal", "libero_meanstd"): + cfg = load_data_config(name) + assert "video.image" in cfg["video"] + assert cfg["action"][0].startswith("action.robot0_") + assert cfg["state"] == ["state"] + + def test_libero_has_sim_wrapper_protocol(self): + assert load_data_config("libero").get("protocol") == "sim_wrapper" + + def test_meanstd_has_no_protocol(self): + assert "protocol" not in load_data_config("libero_meanstd") + + def test_protocol_override(self): + cfg = load_data_config("libero", protocol="direct") + assert cfg["protocol"] == "direct" + + def test_colon_syntax(self): + cfg = load_data_config("libero:direct") + assert cfg["protocol"] == "direct" + + def test_dict_passthrough(self): + d = {"video": ["v"], "state": ["s"], "action": ["a"], "language": ["l"]} + assert load_data_config(d) == d + + def test_unknown_raises(self): + with pytest.raises(ValueError): + load_data_config("nonexistent") + + def test_fuzzy_match(self): + assert load_data_config("libero_custom") is not None + + def test_protocols_complete(self): + for name, p in PROTOCOLS.items(): + for field in ("video_ndim", "state_dtype", "request_wrap", "response_batch_dim", "language_type"): + assert field in p, f"{name} missing {field}" + + +# -- sim_wrapper observation -------------------------------------------------- + + +class TestSimWrapperObservation: + + def test_video_5d(self, libero_obs, state_keys): + obs = _make_policy("sim_wrapper", state_keys)._build_observation(libero_obs, "test") + assert obs["video.image"].ndim == 5 + assert obs["video.image"].shape == (1, 1, 256, 256, 3) + assert obs["video.image"].dtype == np.uint8 + + def test_state_float32(self, libero_obs, state_keys): + obs = _make_policy("sim_wrapper", state_keys)._build_observation(libero_obs, "test") + assert obs["state.x"].dtype == np.float32 + + def test_state_3d(self, libero_obs, state_keys): + obs = _make_policy("sim_wrapper", state_keys)._build_observation(libero_obs, "test") + assert obs["state.x"].ndim == 3 + + def test_state_values(self, libero_obs, state_keys): + obs = _make_policy("sim_wrapper", state_keys)._build_observation(libero_obs, "test") + assert float(obs["state.x"].flat[0]) == pytest.approx(0.1) + assert float(obs["state.y"].flat[0]) == pytest.approx(0.2) + assert float(obs["state.z"].flat[0]) == pytest.approx(0.3) + + def test_gripper_multi_dim(self, libero_obs, state_keys): + obs = _make_policy("sim_wrapper", state_keys)._build_observation(libero_obs, "test") + assert obs["state.gripper"].shape[-1] == 2 # robot0_gripper_qpos is 2-dim + + def test_language_tuple(self, libero_obs, state_keys): + obs = _make_policy("sim_wrapper", state_keys)._build_observation(libero_obs, "hi") + assert obs["annotation.human.action.task_description"] == ("hi",) + + +# -- direct observation ------------------------------------------------------- + + +class TestDirectObservation: + + def test_video_4d(self, libero_obs, state_keys): + obs = _make_policy("direct", state_keys)._build_observation(libero_obs, "test") + assert obs["video.image"].ndim == 4 + + def test_state_float64(self, libero_obs, state_keys): + obs = _make_policy("direct", state_keys)._build_observation(libero_obs, "test") + assert obs["state.x"].dtype == np.float64 + + def test_language_list(self, libero_obs, state_keys): + obs = _make_policy("direct", state_keys)._build_observation(libero_obs, "hi") + assert obs["annotation.human.action.task_description"] == ["hi"] + + +# -- Action conversion -------------------------------------------------------- + + +class TestActionConversion: + + def test_libero_7dim(self, state_keys): + p = _make_policy("sim_wrapper", state_keys) + chunk = { + f"action.{k}": np.random.randn(16, 1).astype(np.float32) + for k in ("x", "y", "z", "roll", "pitch", "yaw", "gripper") + } + actions = p._to_robot_actions(chunk) + assert len(actions) == 16 + assert all(len(a["action"]) == 7 for a in actions) + + def test_batch_dim_stripped(self, state_keys): + p = _make_policy("sim_wrapper", state_keys) + chunk = {f"action.{k}": np.zeros((1, 8, 1)) for k in ("x", "y", "z", "roll", "pitch", "yaw", "gripper")} + actions = p._to_robot_actions(chunk) + assert len(actions) == 8 + + def test_fallback(self, state_keys): + chunk = _make_policy("sim_wrapper", state_keys)._create_fallback_actions() + assert all(isinstance(v, np.ndarray) for v in chunk.values()) + + def test_fallback_per_key_dims(self, state_keys): + """Fallback actions should have correct per-key dimensionality.""" + chunk = _make_policy("sim_wrapper", state_keys)._create_fallback_actions() + assert chunk["action.robot0_joint_pos"].shape == (16, 7) + assert chunk["action.robot0_eef_pos"].shape == (16, 3) + assert chunk["action.robot0_eef_quat"].shape == (16, 4) + assert chunk["action.robot0_gripper_qpos"].shape == (16, 1) + + def test_fallback_direct_horizon(self, state_keys): + """Direct protocol fallback should use horizon=8.""" + chunk = _make_policy("direct", state_keys)._create_fallback_actions() + assert chunk["action.robot0_joint_pos"].shape[0] == 8 + + +# -- Defensive edge cases ---------------------------------------------------- + + +class TestDefensiveEdgeCases: + + def test_missing_eef_pos_uses_zeros(self, state_keys): + """State should still be populated when eef_pos is missing.""" + obs_no_eef = { + "robot0_joint_pos": np.zeros(7), + "robot0_joint_vel": np.zeros(7), + "robot0_gripper_qpos": np.array([0.02, -0.02]), + "agentview_image": np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8), + "robot0_eye_in_hand_image": np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8), + } + p = _make_policy("sim_wrapper", state_keys) + built = p._build_observation(obs_no_eef, "test") + assert "state.x" in built + assert float(built["state.x"].flat[0]) == 0.0 + + def test_video_ndim_assertion(self, state_keys): + """_add_video_dims should reject non-3D input.""" + from strands_robots_sim.policies.groot import Gr00tPolicy + + with pytest.raises(AssertionError, match="Expected.*H, W, C"): + Gr00tPolicy._add_video_dims(np.zeros((256, 256), dtype=np.uint8), ndim=5) + + def test_video_ndim_4d_rejected(self, state_keys): + """Already-batched 4D image should be rejected.""" + from strands_robots_sim.policies.groot import Gr00tPolicy + + with pytest.raises(AssertionError, match="Expected.*H, W, C"): + Gr00tPolicy._add_video_dims(np.zeros((1, 256, 256, 3), dtype=np.uint8), ndim=5) + + +# -- Protocol selection ------------------------------------------------------- + + +class TestProtocolSelection: + + def test_colon_sim_wrapper(self): + from strands_robots_sim.policies.groot import Gr00tPolicy + + with patch(_CLIENT_PATH): + assert ( + Gr00tPolicy(data_config="libero:sim_wrapper", host="localhost", port=9999).protocol_name + == "sim_wrapper" + ) + + def test_colon_direct(self): + from strands_robots_sim.policies.groot import Gr00tPolicy + + with patch(_CLIENT_PATH): + assert Gr00tPolicy(data_config="libero:direct", host="localhost", port=9999).protocol_name == "direct" + + def test_legacy_n1d6(self): + from strands_robots_sim.policies.groot import Gr00tPolicy + + with patch(_CLIENT_PATH): + assert ( + Gr00tPolicy(data_config="libero", host="localhost", port=9999, groot_version="n1d6").protocol_name + == "sim_wrapper" + ) + + def test_legacy_n1d5(self): + from strands_robots_sim.policies.groot import Gr00tPolicy + + with patch(_CLIENT_PATH): + assert ( + Gr00tPolicy(data_config="libero", host="localhost", port=9999, groot_version="n1d5").protocol_name + == "direct" + ) + + def test_default_from_config(self): + from strands_robots_sim.policies.groot import Gr00tPolicy + + with patch(_CLIENT_PATH): + assert Gr00tPolicy(data_config="libero", host="localhost", port=9999).protocol_name == "sim_wrapper"