diff --git a/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml b/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml index 85ba6bd351..6cd24209ce 100644 --- a/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml +++ b/examples/qwen_gr00t/conf/train/qwen_gr00t.yaml @@ -120,6 +120,9 @@ data: state_key: eepose # Path to the training data data_path: /workspace/datasets/IPEC-COMMUNITY/libero_goal_no_noops_1.0.0_lerobot/ + image_key_order: + - observation.images.image + - observation.images.wrist_image # Path to VLM co-training data (WDS/Energon format). Leave unset to disable co-training. # vlm_data_path: /workspace/datasets/vlm_cotrain/ tolerance_s: 0.0001 diff --git a/flagscale/models/vla/utils.py b/flagscale/models/vla/utils.py index b0ee5fdd12..3c2dba6329 100644 --- a/flagscale/models/vla/utils.py +++ b/flagscale/models/vla/utils.py @@ -1,3 +1,26 @@ +from flagscale.models.configs.types import FeatureType, PolicyFeature + + +def reorder_visual_input_features( + input_features: dict[str, PolicyFeature], + preferred_image_order: list[str] | None = None, +) -> dict[str, PolicyFeature]: + """Reorder visual features while leaving other inputs untouched.""" + + if not preferred_image_order: + return dict(input_features) + + visual_keys = [key for key, ft in input_features.items() if ft.type == FeatureType.VISUAL] + ordered_visual_keys = [key for key in preferred_image_order if key in visual_keys] + # Keep dataset-defined order for visual keys that are not listed in the recipe. + ordered_visual_keys.extend(key for key in visual_keys if key not in ordered_visual_keys) + + reordered = {key: ft for key, ft in input_features.items() if ft.type != FeatureType.VISUAL} + for key in ordered_visual_keys: + reordered[key] = input_features[key] + return reordered + + def get_vlm_config(vlm_config) -> dict: """ Extract common fields from any VLM config, handling structural differences. diff --git a/flagscale/train/train_qwen_gr00t.py b/flagscale/train/train_qwen_gr00t.py index aaf5ab3e94..df34f4728c 100644 --- a/flagscale/train/train_qwen_gr00t.py +++ b/flagscale/train/train_qwen_gr00t.py @@ -1,5 +1,6 @@ # Mainly adopted from # https://github.com/huggingface/lerobot/blob/2b304eeb841ae6c371e3dd341bbbb9dd254b07cb/src/lerobot/scripts/lerobot_train.py +# ruff: noqa: I001 import argparse import os @@ -48,6 +49,8 @@ ) from flagscale.train.utils.random_utils import serialize_rng_state, deserialize_rng_state from flagscale.train.utils.optim_setup import setup_optimizer_and_scheduler +from flagscale.models.vla.qwen_gr00t import QwenGr00t +from flagscale.models.vla.utils import reorder_visual_input_features from flagscale.models.vla import TrainablePolicy from flagscale.models.vla.pretrained_config import PreTrainedConfig from flagscale.platforms import get_platform @@ -206,6 +209,24 @@ def _format_meter_val(meter: AverageMeter) -> str: return " ".join(display_list) +def make_policy( + config: TrainConfig, + ds_meta: LeRobotDatasetMetadata | None = None, +): + features = dataset_to_policy_features(ds_meta.features) + + # Use == instead of `is` for FeatureType.ACTION comparison + # because flagscale.FeatureType and lerobot.FeatureType are different enum classes + output_features = { + key: ft + for key, ft in features.items() + if ft.type == FeatureType.ACTION + } + input_features = {key: ft for key, ft in features.items() if key not in output_features} + # Preserve dataset feature order unless the recipe explicitly requests a visual reorder. + image_key_order = list(getattr(config.data, "image_key_order", []) or []) + if image_key_order: + input_features = reorder_visual_input_features(input_features, image_key_order) def make_pre_post_processors( policy, diff --git a/tests/unit_tests/models/vla/test_utils.py b/tests/unit_tests/models/vla/test_utils.py index 0ff384d481..7fa133716e 100644 --- a/tests/unit_tests/models/vla/test_utils.py +++ b/tests/unit_tests/models/vla/test_utils.py @@ -1,6 +1,7 @@ import unittest -from flagscale.models.vla.utils import get_vlm_config +from flagscale.models.configs.types import FeatureType, PolicyFeature +from flagscale.models.vla.utils import get_vlm_config, reorder_visual_input_features class MockConfigDirect: @@ -32,3 +33,52 @@ def test_nested_config(self): def test_invalid_config_raises(self): with self.assertRaises(ValueError): get_vlm_config(MockConfigInvalid()) + + +class TestOrderVisualInputFeatures(unittest.TestCase): + def test_prefers_explicit_image_order(self): + input_features = { + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(7,)), + "observation.images.wrist_image": PolicyFeature( + type=FeatureType.VISUAL, shape=(3, 224, 224) + ), + "observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + + reordered = reorder_visual_input_features( + input_features, + ["observation.images.image", "observation.images.wrist_image"], + ) + + visual_keys = [key for key, ft in reordered.items() if ft.type == FeatureType.VISUAL] + self.assertEqual( + visual_keys, + ["observation.images.image", "observation.images.wrist_image"], + ) + self.assertIn("observation.state", reordered) + + def test_appends_unlisted_visual_keys(self): + input_features = { + "observation.images.left_wrist_0_rgb": PolicyFeature( + type=FeatureType.VISUAL, shape=(3, 224, 224) + ), + "observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + "observation.images.right_wrist_0_rgb": PolicyFeature( + type=FeatureType.VISUAL, shape=(3, 224, 224) + ), + } + + reordered = reorder_visual_input_features( + input_features, + ["observation.images.image"], + ) + + visual_keys = [key for key, ft in reordered.items() if ft.type == FeatureType.VISUAL] + self.assertEqual( + visual_keys, + [ + "observation.images.image", + "observation.images.left_wrist_0_rgb", + "observation.images.right_wrist_0_rgb", + ], + ) diff --git a/tests/unit_tests/train/test_train_qwen_gr00t.py b/tests/unit_tests/train/test_train_qwen_gr00t.py new file mode 100644 index 0000000000..f955986270 --- /dev/null +++ b/tests/unit_tests/train/test_train_qwen_gr00t.py @@ -0,0 +1,111 @@ +import unittest +from types import SimpleNamespace +from unittest.mock import patch + +from flagscale.models.configs.types import FeatureType +from flagscale.train.train_qwen_gr00t import make_policy + + +class FakePolicy: + def __init__(self, config): + self.config = config + self.input_features = None + self.output_features = None + self.device = None + + def to(self, device): + self.device = device + return self + + +class FakeMetadata: + def __init__(self, features): + self.features = features + + +def make_image_feature(): + return { + "dtype": "image", + "shape": (224, 224, 3), + "names": ["height", "width", "channels"], + } + + +class TestMakePolicyImageOrder(unittest.TestCase): + def test_respects_explicit_image_key_order(self): + ds_meta = FakeMetadata( + { + "observation.images.wrist_image": make_image_feature(), + "observation.images.image": make_image_feature(), + "observation.state": {"dtype": "float32", "shape": (7,)}, + "action": {"dtype": "float32", "shape": (7,)}, + } + ) + config = SimpleNamespace( + data=SimpleNamespace( + image_key_order=[ + "observation.images.image", + "observation.images.wrist_image", + ] + ) + ) + + with patch("flagscale.train.train_qwen_gr00t.QwenGr00t", FakePolicy): + policy = make_policy(config=config, ds_meta=ds_meta) + + visual_keys = [ + key for key, ft in policy.input_features.items() if ft.type == FeatureType.VISUAL + ] + self.assertEqual( + visual_keys, + ["observation.images.image", "observation.images.wrist_image"], + ) + self.assertEqual(policy.device, "cuda") + + def test_keeps_dataset_visual_order_without_image_key_order(self): + ds_meta = FakeMetadata( + { + "observation.images.wrist_image": make_image_feature(), + "observation.images.image": make_image_feature(), + "observation.state": {"dtype": "float32", "shape": (7,)}, + "action": {"dtype": "float32", "shape": (7,)}, + } + ) + config = SimpleNamespace(data=SimpleNamespace()) + + with patch("flagscale.train.train_qwen_gr00t.QwenGr00t", FakePolicy): + policy = make_policy(config=config, ds_meta=ds_meta) + + visual_keys = [ + key for key, ft in policy.input_features.items() if ft.type == FeatureType.VISUAL + ] + self.assertEqual( + visual_keys, + ["observation.images.wrist_image", "observation.images.image"], + ) + + def test_appends_visual_keys_not_listed_in_config(self): + ds_meta = FakeMetadata( + { + "observation.images.left_wrist_0_rgb": make_image_feature(), + "observation.images.image": make_image_feature(), + "observation.images.right_wrist_0_rgb": make_image_feature(), + "action": {"dtype": "float32", "shape": (7,)}, + } + ) + config = SimpleNamespace(data=SimpleNamespace(image_key_order=["observation.images.image"])) + + with patch("flagscale.train.train_qwen_gr00t.QwenGr00t", FakePolicy): + policy = make_policy(config=config, ds_meta=ds_meta) + + visual_keys = [ + key for key, ft in policy.input_features.items() if ft.type == FeatureType.VISUAL + ] + self.assertEqual( + visual_keys, + [ + "observation.images.image", + "observation.images.left_wrist_0_rgb", + "observation.images.right_wrist_0_rgb", + ], + )