diff --git a/examples/libero_sac/run_pi05_libero_sac_separated.sh b/examples/libero_sac/run_pi05_libero_sac_separated.sh index 2ae28b6..0ebd81f 100644 --- a/examples/libero_sac/run_pi05_libero_sac_separated.sh +++ b/examples/libero_sac/run_pi05_libero_sac_separated.sh @@ -94,7 +94,7 @@ $PYTHON -m verl_vla.trainer.main_sac \ env.actor.model.action_dim=7 \ env.train.device=$ENV_DEVICE \ env.train.max_episode_steps=$MAX_EPISODE_STEPS \ - +env.train.async_reset=$ASYNC_RESET \ + env.train.async_reset=$ASYNC_RESET \ $MAX_INTERACTIONS_CONFIG \ +env.train.pipeline_stage_num=$NUM_STAGE \ env.train.video_cfg.save_video=True \ diff --git a/src/verl_vla/env_loop/env_loop.py b/src/verl_vla/env_loop/env_loop.py index 37b291c..28d3fb0 100644 --- a/src/verl_vla/env_loop/env_loop.py +++ b/src/verl_vla/env_loop/env_loop.py @@ -54,7 +54,9 @@ def __init__(self, env_wg: RayWorkerGroup, rollout_wg: RayWorkerGroup, config: D self.envs_per_stage = self.total_envs // self.stage_num self.default_max_interactions = config.env.train.max_episode_steps // config.env.actor.model.num_action_chunks - self.configured_max_interactions = config.env.train.get("max_interactions", self.default_max_interactions) + self.configured_max_interactions = config.env.train.get("max_interactions") + if self.configured_max_interactions is None: + self.configured_max_interactions = self.default_max_interactions self.max_interactions = self.configured_max_interactions self.warmup_max_interactions = False diff --git a/src/verl_vla/teleop/config.py b/src/verl_vla/teleop/config.py index 2dc2f45..e71f894 100644 --- a/src/verl_vla/teleop/config.py +++ b/src/verl_vla/teleop/config.py @@ -47,6 +47,24 @@ class XRControllerTeleopConfig: max_events: int = 256 +@dataclass(frozen=True) +class GamepadTeleopConfig: + pos_sensitivity: float = 0.5 + rot_sensitivity: float = 0.5 + intervention_button: str = "RT" + gripper_button: str = "X" + button_threshold: float = 0.5 + max_events: int = 256 + left_stick_x_axis: str = "axis_0" + left_stick_y_axis: str = "axis_1" + right_stick_y_axis: str = "axis_3" + right_stick_x_axis: str = "axis_2" + dpad_up_button: str = "DUp" + dpad_down_button: str = "DDown" + dpad_left_button: str = "DLeft" + dpad_right_button: str = "DRight" + + @dataclass(frozen=True) class TeleopConfig: enable: bool = False @@ -55,6 +73,7 @@ class TeleopConfig: server: TeleopServerConfig = field(default_factory=TeleopServerConfig) keyboard: KeyboardTeleopConfig = field(default_factory=KeyboardTeleopConfig) xr_controller: XRControllerTeleopConfig = field(default_factory=XRControllerTeleopConfig) + gamepad: GamepadTeleopConfig = field(default_factory=GamepadTeleopConfig) def load_teleop_config(cfg: DictConfig | Any, device: str | None = None) -> TeleopConfig: @@ -96,6 +115,13 @@ def load_teleop_config(cfg: DictConfig | Any, device: str | None = None) -> Tele xr_controller_cfg = XRControllerTeleopConfig( **{key: xr_controller_raw[key] for key in XRControllerTeleopConfig.__annotations__ if key in xr_controller_raw} ) + gamepad_raw = raw.get("gamepad", {}) + if isinstance(gamepad_raw, DictConfig): + gamepad_raw = OmegaConf.to_container(gamepad_raw, resolve=True) + gamepad_raw = dict(gamepad_raw or {}) + gamepad_cfg = GamepadTeleopConfig( + **{key: gamepad_raw[key] for key in GamepadTeleopConfig.__annotations__ if key in gamepad_raw} + ) devices = raw.get("devices") if devices is None: devices = [raw.get("device", TeleopConfig.device)] @@ -114,4 +140,5 @@ def load_teleop_config(cfg: DictConfig | Any, device: str | None = None) -> Tele server=server_cfg, keyboard=keyboard_cfg, xr_controller=xr_controller_cfg, + gamepad=gamepad_cfg, ) diff --git a/src/verl_vla/teleop/devices/__init__.py b/src/verl_vla/teleop/devices/__init__.py index 123e48f..2d23ad8 100644 --- a/src/verl_vla/teleop/devices/__init__.py +++ b/src/verl_vla/teleop/devices/__init__.py @@ -13,12 +13,15 @@ # limitations under the License. from verl_vla.teleop.devices.device_base import DeviceBase, DeviceEvent +from verl_vla.teleop.devices.gamepad import GamepadDevice, GamepadDeviceCfg from verl_vla.teleop.devices.keyboard import KeyboardDevice, KeyboardDeviceCfg from verl_vla.teleop.devices.xr_controller import XRControllerDevice, XRControllerDeviceCfg __all__ = [ "DeviceBase", "DeviceEvent", + "GamepadDevice", + "GamepadDeviceCfg", "KeyboardDevice", "KeyboardDeviceCfg", "XRControllerDevice", diff --git a/src/verl_vla/teleop/devices/gamepad.py b/src/verl_vla/teleop/devices/gamepad.py new file mode 100644 index 0000000..30767b6 --- /dev/null +++ b/src/verl_vla/teleop/devices/gamepad.py @@ -0,0 +1,122 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any + +from typing_extensions import override + +from verl_vla.teleop.devices.device_base import DeviceBase, DeviceEvent + + +@dataclass(frozen=True) +class GamepadDeviceCfg: + max_events: int = 256 + + +class GamepadDevice(DeviceBase): + name = "gamepad" + + def __init__(self, cfg: GamepadDeviceCfg | None = None): + self.cfg = cfg or GamepadDeviceCfg() + super().__init__(max_events=self.cfg.max_events) + self._latest_state: dict[str, Any] = {} + self._button_states: dict[str, bool] = {} + self._axis_values: dict[str, float] = {} + self._connected = False + self._device_id = "" + + @override + def reset(self) -> None: + with self._lock: + self._latest_state.clear() + self._button_states.clear() + self._axis_values.clear() + self._events.clear() + self._connected = False + self._device_id = "" + + @override + def handle_event(self, event: DeviceEvent) -> None: + with self._lock: + event_type = event.event_type.lower() + if event_type == "gamepad_update": + self._latest_state = dict(event.raw) + buttons_raw = event.raw.get("buttons") + if not isinstance(buttons_raw, dict): + buttons_raw = {} + axes_raw = event.raw.get("axes") + if not isinstance(axes_raw, dict): + axes_raw = {} + + for key, value in buttons_raw.items(): + if isinstance(value, dict): + self._button_states[key] = bool(value.get("pressed", False)) + else: + self._button_states[key] = bool(value) + + for key, value in axes_raw.items(): + self._axis_values[key] = float(value) if isinstance(value, int | float) else 0.0 + + self._connected = True + self._device_id = str(event.raw.get("id", "")) + + elif event_type == "gamepad_disconnect": + self._connected = False + self._device_id = "" + self._button_states.clear() + self._axis_values.clear() + + self._record_event(event) + + @override + def snapshot(self) -> dict[str, Any]: + with self._lock: + pressed_buttons = sorted([k for k, v in self._button_states.items() if v]) + active_axes = {k: round(v, 3) for k, v in self._axis_values.items() if abs(v) > 0.01} + is_active = bool(pressed_buttons) or bool(active_axes) + return { + "device": self.name, + "connected": self._connected, + "device_id": self._device_id, + "pressed_buttons": pressed_buttons, + "active_axes": active_axes, + "timestamp": self._latest_state.get("timestamp"), + "active": is_active, + "key_bindings": self.key_bindings(), + } + + def key_bindings(self) -> dict[str, str]: + return { + "Left Stick Y": "+x / -x", + "Left Stick X": "+y / -y", + "Right Stick Y": "+z / -z", + "Right Stick X": "+yaw / -yaw", + "D-Pad Left/Right": "+roll / -roll", + "D-Pad Up/Down": "+pitch / -pitch", + "RT": "intervention (hold)", + "X": "toggle gripper", + } + + def is_active(self) -> bool: + with self._lock: + return any(self._button_states.values()) or any(abs(v) > 0.01 for v in self._axis_values.values()) + + def get_button(self, button_name: str) -> bool: + with self._lock: + return self._button_states.get(button_name, False) + + def get_axis(self, axis_name: str) -> float: + with self._lock: + return self._axis_values.get(axis_name, 0.0) diff --git a/src/verl_vla/teleop/obs_server/html/gamepad_device.js b/src/verl_vla/teleop/obs_server/html/gamepad_device.js new file mode 100644 index 0000000..93c2456 --- /dev/null +++ b/src/verl_vla/teleop/obs_server/html/gamepad_device.js @@ -0,0 +1,190 @@ +/* +Copyright 2026 Bytedance Ltd. and/or its affiliates + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +class TeleopGamepadDevice { + constructor(socketProvider) { + this.socketProvider = socketProvider; + this.pollInterval = null; + this.connected = false; + this.lastButtons = {}; + this.lastAxes = {}; + this.deadzone = 0.15; + this.threshold = 0.5; + this.handleConnect = this.handleConnect.bind(this); + this.handleDisconnect = this.handleDisconnect.bind(this); + this.pollState = this.pollState.bind(this); + + this.buttonNames = { + 0: "A", + 1: "B", + 2: "X", + 3: "Y", + 4: "LB", + 5: "RB", + 6: "LT", + 7: "RT", + 8: "View", + 9: "Menu", + 10: "LS", + 11: "RS", + 12: "DUp", + 13: "DDown", + 14: "DLeft", + 15: "DRight" + }; + } + + attach() { + window.addEventListener("gamepadconnected", this.handleConnect); + window.addEventListener("gamepaddisconnected", this.handleDisconnect); + this.startPolling(); + } + + detach() { + window.removeEventListener("gamepadconnected", this.handleConnect); + window.removeEventListener("gamepaddisconnected", this.handleDisconnect); + this.stopPolling(); + } + + handleConnect(event) { + this.connected = true; + console.log("Gamepad connected:", event.gamepad.id); + } + + handleDisconnect(event) { + this.connected = false; + this.lastButtons = {}; + this.lastAxes = {}; + console.log("Gamepad disconnected:", event.gamepad.id); + } + + startPolling() { + if (this.pollInterval) { + clearInterval(this.pollInterval); + } + this.pollInterval = setInterval(this.pollState, 16); + } + + stopPolling() { + if (this.pollInterval) { + clearInterval(this.pollInterval); + this.pollInterval = null; + } + } + + applyDeadzone(value) { + if (Math.abs(value) < this.deadzone) { + return 0; + } + return value; + } + + pollState() { + const gamepads = navigator.getGamepads ? navigator.getGamepads() : []; + let hasActiveGamepad = false; + + for (const gamepad of gamepads) { + if (!gamepad || !gamepad.connected) { + continue; + } + + hasActiveGamepad = true; + const buttons = {}; + const axes = {}; + let changed = false; + + for (let i = 0; i < gamepad.buttons.length; i++) { + const button = gamepad.buttons[i]; + // 使用友好名称,如 "A", "B", "RT", "LT" 等 + const key = this.buttonNames[i] || `button_${i}`; + buttons[key] = { + pressed: button.pressed, + touched: button.touched, + value: button.value + }; + if (this.lastButtons[key] === undefined || + this.lastButtons[key].pressed !== button.pressed || + Math.abs(this.lastButtons[key].value - button.value) > 0.01) { + changed = true; + } + } + + for (let i = 0; i < gamepad.axes.length; i++) { + const rawValue = gamepad.axes[i]; + const value = this.applyDeadzone(rawValue); + const key = `axis_${i}`; + axes[key] = value; + if (this.lastAxes[key] === undefined || + Math.abs(this.lastAxes[key] - value) > 0.01) { + changed = true; + } + } + + if (changed || Object.keys(this.lastButtons).length === 0) { + this.sendState(gamepad, buttons, axes); + this.lastButtons = buttons; + this.lastAxes = axes; + } + break; + } + + if (!hasActiveGamepad && Object.keys(this.lastButtons).length > 0) { + this.sendDisconnect(); + } + } + + sendState(gamepad, buttons, axes) { + const socket = this.socketProvider(); + if (!socket || socket.readyState !== WebSocket.OPEN) { + return; + } + + const payload = { + event_type: "gamepad_update", + timestamp: Date.now() / 1000, + id: gamepad.id, + index: gamepad.index, + mapping: gamepad.mapping, + buttons, + axes + }; + + socket.send(JSON.stringify({ + type: "gamepad_update", + device: "gamepad", + payload + })); + } + + sendDisconnect() { + const socket = this.socketProvider(); + if (!socket || socket.readyState !== WebSocket.OPEN) { + return; + } + + socket.send(JSON.stringify({ + type: "gamepad_update", + device: "gamepad", + payload: { + event_type: "gamepad_disconnect", + timestamp: Date.now() / 1000 + } + })); + + this.lastButtons = {}; + this.lastAxes = {}; + } +} \ No newline at end of file diff --git a/src/verl_vla/teleop/obs_server/html/index.html b/src/verl_vla/teleop/obs_server/html/index.html index 78fae08..f740b72 100644 --- a/src/verl_vla/teleop/obs_server/html/index.html +++ b/src/verl_vla/teleop/obs_server/html/index.html @@ -63,6 +63,7 @@ +