Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/qwen_gr00t/conf/train/qwen_gr00t.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ data:
state_key: eepose
# Path to the training data
data_path: /workspace/datasets/IPEC-COMMUNITY/libero_goal_no_noops_1.0.0_lerobot/
image_key_order:
- observation.images.image
- observation.images.wrist_image
# Path to VLM co-training data (WDS/Energon format). Leave unset to disable co-training.
# vlm_data_path: /workspace/datasets/vlm_cotrain/
tolerance_s: 0.0001
Expand Down
23 changes: 23 additions & 0 deletions flagscale/models/vla/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,26 @@
from flagscale.models.configs.types import FeatureType, PolicyFeature


def reorder_visual_input_features(
input_features: dict[str, PolicyFeature],
preferred_image_order: list[str] | None = None,
) -> dict[str, PolicyFeature]:
"""Reorder visual features while leaving other inputs untouched."""

if not preferred_image_order:
return dict(input_features)

visual_keys = [key for key, ft in input_features.items() if ft.type == FeatureType.VISUAL]
ordered_visual_keys = [key for key in preferred_image_order if key in visual_keys]
# Keep dataset-defined order for visual keys that are not listed in the recipe.
ordered_visual_keys.extend(key for key in visual_keys if key not in ordered_visual_keys)

reordered = {key: ft for key, ft in input_features.items() if ft.type != FeatureType.VISUAL}
for key in ordered_visual_keys:
reordered[key] = input_features[key]
return reordered


def get_vlm_config(vlm_config) -> dict:
"""
Extract common fields from any VLM config, handling structural differences.
Expand Down
21 changes: 21 additions & 0 deletions flagscale/train/train_qwen_gr00t.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Mainly adopted from
# https://github.com/huggingface/lerobot/blob/2b304eeb841ae6c371e3dd341bbbb9dd254b07cb/src/lerobot/scripts/lerobot_train.py
# ruff: noqa: I001

import argparse
import os
Expand Down Expand Up @@ -48,6 +49,8 @@
)
from flagscale.train.utils.random_utils import serialize_rng_state, deserialize_rng_state
from flagscale.train.utils.optim_setup import setup_optimizer_and_scheduler
from flagscale.models.vla.qwen_gr00t import QwenGr00t
from flagscale.models.vla.utils import reorder_visual_input_features
from flagscale.models.vla import TrainablePolicy
from flagscale.models.vla.pretrained_config import PreTrainedConfig
from flagscale.platforms import get_platform
Expand Down Expand Up @@ -206,6 +209,24 @@ def _format_meter_val(meter: AverageMeter) -> str:
return " ".join(display_list)


def make_policy(
config: TrainConfig,
ds_meta: LeRobotDatasetMetadata | None = None,
):
features = dataset_to_policy_features(ds_meta.features)

# Use == instead of `is` for FeatureType.ACTION comparison
# because flagscale.FeatureType and lerobot.FeatureType are different enum classes
output_features = {
key: ft
for key, ft in features.items()
if ft.type == FeatureType.ACTION
}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
# Preserve dataset feature order unless the recipe explicitly requests a visual reorder.
image_key_order = list(getattr(config.data, "image_key_order", []) or [])
if image_key_order:
input_features = reorder_visual_input_features(input_features, image_key_order)

def make_pre_post_processors(
policy,
Expand Down
52 changes: 51 additions & 1 deletion tests/unit_tests/models/vla/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

from flagscale.models.vla.utils import get_vlm_config
from flagscale.models.configs.types import FeatureType, PolicyFeature
from flagscale.models.vla.utils import get_vlm_config, reorder_visual_input_features


class MockConfigDirect:
Expand Down Expand Up @@ -32,3 +33,52 @@ def test_nested_config(self):
def test_invalid_config_raises(self):
with self.assertRaises(ValueError):
get_vlm_config(MockConfigInvalid())


class TestOrderVisualInputFeatures(unittest.TestCase):
def test_prefers_explicit_image_order(self):
input_features = {
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(7,)),
"observation.images.wrist_image": PolicyFeature(
type=FeatureType.VISUAL, shape=(3, 224, 224)
),
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}

reordered = reorder_visual_input_features(
input_features,
["observation.images.image", "observation.images.wrist_image"],
)

visual_keys = [key for key, ft in reordered.items() if ft.type == FeatureType.VISUAL]
self.assertEqual(
visual_keys,
["observation.images.image", "observation.images.wrist_image"],
)
self.assertIn("observation.state", reordered)

def test_appends_unlisted_visual_keys(self):
input_features = {
"observation.images.left_wrist_0_rgb": PolicyFeature(
type=FeatureType.VISUAL, shape=(3, 224, 224)
),
"observation.images.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
"observation.images.right_wrist_0_rgb": PolicyFeature(
type=FeatureType.VISUAL, shape=(3, 224, 224)
),
}

reordered = reorder_visual_input_features(
input_features,
["observation.images.image"],
)

visual_keys = [key for key, ft in reordered.items() if ft.type == FeatureType.VISUAL]
self.assertEqual(
visual_keys,
[
"observation.images.image",
"observation.images.left_wrist_0_rgb",
"observation.images.right_wrist_0_rgb",
],
)
111 changes: 111 additions & 0 deletions tests/unit_tests/train/test_train_qwen_gr00t.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import unittest
from types import SimpleNamespace
from unittest.mock import patch

from flagscale.models.configs.types import FeatureType
from flagscale.train.train_qwen_gr00t import make_policy


class FakePolicy:
def __init__(self, config):
self.config = config
self.input_features = None
self.output_features = None
self.device = None

def to(self, device):
self.device = device
return self


class FakeMetadata:
def __init__(self, features):
self.features = features


def make_image_feature():
return {
"dtype": "image",
"shape": (224, 224, 3),
"names": ["height", "width", "channels"],
}


class TestMakePolicyImageOrder(unittest.TestCase):
def test_respects_explicit_image_key_order(self):
ds_meta = FakeMetadata(
{
"observation.images.wrist_image": make_image_feature(),
"observation.images.image": make_image_feature(),
"observation.state": {"dtype": "float32", "shape": (7,)},
"action": {"dtype": "float32", "shape": (7,)},
}
)
config = SimpleNamespace(
data=SimpleNamespace(
image_key_order=[
"observation.images.image",
"observation.images.wrist_image",
]
)
)

with patch("flagscale.train.train_qwen_gr00t.QwenGr00t", FakePolicy):
policy = make_policy(config=config, ds_meta=ds_meta)

visual_keys = [
key for key, ft in policy.input_features.items() if ft.type == FeatureType.VISUAL
]
self.assertEqual(
visual_keys,
["observation.images.image", "observation.images.wrist_image"],
)
self.assertEqual(policy.device, "cuda")

def test_keeps_dataset_visual_order_without_image_key_order(self):
ds_meta = FakeMetadata(
{
"observation.images.wrist_image": make_image_feature(),
"observation.images.image": make_image_feature(),
"observation.state": {"dtype": "float32", "shape": (7,)},
"action": {"dtype": "float32", "shape": (7,)},
}
)
config = SimpleNamespace(data=SimpleNamespace())

with patch("flagscale.train.train_qwen_gr00t.QwenGr00t", FakePolicy):
policy = make_policy(config=config, ds_meta=ds_meta)

visual_keys = [
key for key, ft in policy.input_features.items() if ft.type == FeatureType.VISUAL
]
self.assertEqual(
visual_keys,
["observation.images.wrist_image", "observation.images.image"],
)

def test_appends_visual_keys_not_listed_in_config(self):
ds_meta = FakeMetadata(
{
"observation.images.left_wrist_0_rgb": make_image_feature(),
"observation.images.image": make_image_feature(),
"observation.images.right_wrist_0_rgb": make_image_feature(),
"action": {"dtype": "float32", "shape": (7,)},
}
)
config = SimpleNamespace(data=SimpleNamespace(image_key_order=["observation.images.image"]))

with patch("flagscale.train.train_qwen_gr00t.QwenGr00t", FakePolicy):
policy = make_policy(config=config, ds_meta=ds_meta)

visual_keys = [
key for key, ft in policy.input_features.items() if ft.type == FeatureType.VISUAL
]
self.assertEqual(
visual_keys,
[
"observation.images.image",
"observation.images.left_wrist_0_rgb",
"observation.images.right_wrist_0_rgb",
],
)
Loading