
Commit 326aca0

Add API Examples (#2289)
* (unscrewing things up) (#2288)
* fix: expose a function explicitly building a frame for inference
* fix: first make dataset frame, then make ready for inference
* fix: reducing reliance on lerobot record for policy's outputs too
* fix: encapsulating squeezing out + device handling from predict action
* fix: remove duplicated call to build_inference_frame and add a function to only perform data type handling (whole conversion is: keys matching + data type conversion)
* refactor(envs): add custom-observation-size (#2167)
* fix: add MockMotorBus to MockRobot
* rl: first drafts
* add: all components of HIL SERL
* fix: actor block works
* fix: less friction, less friction
* add: hil-serl complete example
* fix: dataset names
* fix: restructuring example folder
* fix: act works but found bug in how ACT works
* fix: same path for both pre and postprocessors
* fix: paths
* add: example usage for act
* add: using ACT example
* fix: training examples
* fix: using examples
* fix: camera index
* fix: rename workflows into tutorial so that the path of the files is lerobot/examples/tutorial/...
* fix: upload everything in one repo
* fix: model name
* fix: simplify model path
* add: VLAs example

---------

Signed-off-by: Francesco Capuano <[email protected]>

* fix: minor fix using named attributes
* fix: change model to act
* fix: named attributes for inference frame building
* fix: minor fixes to smolvla
* fix: small changes to pi0
* remove: old file that should have never been committed (oops, sorry)

---------

Signed-off-by: Francesco Capuano <[email protected]>
1 parent be46bde commit 326aca0

File tree

10 files changed: +920, -0 lines

Lines changed: 98 additions & 0 deletions
"""This script demonstrates how to train ACT Policy on a real-world dataset."""

from pathlib import Path

import torch

from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors


def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
    if delta_indices is None:
        return [0]

    return [i / fps for i in delta_indices]


output_directory = Path("outputs/robot_learning_tutorial/act")
output_directory.mkdir(parents=True, exist_ok=True)

# Select your device
device = torch.device("mps")  # or "cuda" or "cpu"

dataset_id = "lerobot/svla_so101_pickplace"

# The dataset metadata specifies the inputs the model expects and the outputs it produces
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)

output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}

cfg = ACTConfig(input_features=input_features, output_features=output_features)
policy = ACTPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)

policy.train()
policy.to(device)

# To perform action chunking, ACT expects a chunk of future actions as targets
delta_timestamps = {
    "action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}

# Add image features, if present
delta_timestamps |= {
    k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}

# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)

# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=device.type != "cpu",
    drop_last=True,
)

# Number of training steps and logging frequency (set to 1 here just to demonstrate the loop)
training_steps = 1
log_freq = 1

# Run the training loop
step = 0
done = False
while not done:
    for batch in dataloader:
        batch = preprocessor(batch)
        loss, _ = policy.forward(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % log_freq == 0:
            print(f"step: {step} loss: {loss.item():.3f}")
        step += 1
        if step >= training_steps:
            done = True
            break

# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)

# Push all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
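
For reference, the make_delta_timestamps helper above simply converts frame-index offsets into time offsets (in seconds) at the dataset's frame rate. A standalone sketch of that computation, with an illustrative fps and index list rather than the values ACT actually uses:

# Illustrative values only: turn frame-index offsets into time offsets at a given fps
fps = 30
delta_indices = [0, 1, 2]  # this frame and the next two frames

delta_timestamps = [i / fps for i in delta_indices]
print(delta_timestamps)  # [0.0, 0.0333..., 0.0666...]
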
Lines changed: 57 additions & 0 deletions
import torch

from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower

device = torch.device("mps")  # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_act"
model = ACTPolicy.from_pretrained(model_id)
# Run inference on the selected device, in eval mode
model.to(device)
model.eval()

dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the dataset metadata, typically tens of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)

# Find ports using lerobot-find-port
follower_port = ...  # something like "/dev/tty.usbmodem58760431631"

# The robot id is used to load the right calibration files
follower_id = ...  # something like "follower_so100"

MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20

# Robot and environment configuration
# Camera keys must match the names and resolutions of the cameras used for training!
# You can check the camera keys expected by a model in the info.json file on its Hub model card
camera_config = {
    "side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
    "up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}

robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()

for _ in range(MAX_EPISODES):
    for _ in range(MAX_STEPS_PER_EPISODE):
        obs = robot.get_observation()
        obs_frame = build_inference_frame(
            observation=obs, ds_features=dataset_metadata.features, device=device
        )

        obs = preprocess(obs_frame)

        action = model.select_action(obs)
        action = postprocess(action)

        action = make_robot_action(action, dataset_metadata.features)

        robot.send_action(action)

    print("Episode finished! Starting new episode...")
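
The loop above leaves the follower arm connected when the script ends or crashes. A hedged variant of the same control loop, continuing from the objects defined above and assuming SO100Follower exposes a disconnect() method (as lerobot robots generally do), that always releases the robot on exit:

# Sketch: same control loop, but the robot is released even if inference raises
try:
    for _ in range(MAX_EPISODES):
        for _ in range(MAX_STEPS_PER_EPISODE):
            obs = robot.get_observation()
            obs_frame = build_inference_frame(
                observation=obs, ds_features=dataset_metadata.features, device=device
            )
            action = postprocess(model.select_action(preprocess(obs_frame)))
            robot.send_action(make_robot_action(action, dataset_metadata.features))
        print("Episode finished! Starting new episode...")
finally:
    robot.disconnect()  # assumed API: releases motors and cameras
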
Lines changed: 11 additions & 0 deletions
from lerobot.async_inference.configs import PolicyServerConfig
from lerobot.async_inference.policy_server import serve

host = ...  # something like "127.0.0.1" if you're serving on localhost
port = ...  # something like 8080

config = PolicyServerConfig(
    host=host,
    port=port,
)
serve(config)
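
A filled-in sketch of the same server setup, using the example values from the comments above (illustrative only); the async client in the next example assumes a server is already listening at this address:

from lerobot.async_inference.configs import PolicyServerConfig
from lerobot.async_inference.policy_server import serve

# Illustrative host/port; pick whatever address the client should connect to
config = PolicyServerConfig(host="127.0.0.1", port=8080)
serve(config)  # runs until interrupted; start it before launching the client
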
Lines changed: 55 additions & 0 deletions
import threading

from lerobot.async_inference.configs import RobotClientConfig
from lerobot.async_inference.helpers import visualize_action_queue_size
from lerobot.async_inference.robot_client import RobotClient
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.robots.so100_follower import SO100FollowerConfig

# These cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
# Check the config.json on the Hub for the policy you are using to see the expected camera specs
camera_cfg = {
    "up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
    "side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}

# Find ports using lerobot-find-port
follower_port = ...  # something like "/dev/tty.usbmodem58760431631"

# The robot id is used to load the right calibration files
follower_id = ...  # something like "follower_so100"

robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)

server_address = ...  # something like "127.0.0.1:8080" if using localhost

# Create the client configuration
client_cfg = RobotClientConfig(
    robot=robot_cfg,
    server_address=server_address,
    policy_device="mps",
    policy_type="act",
    pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
    chunk_size_threshold=0.5,  # request a new action chunk once the queue drops below this fraction of a chunk
    actions_per_chunk=50,  # make sure this does not exceed the policy's maximum chunk size
)

# Create and start the client
client = RobotClient(client_cfg)

# Provide a textual description of the task
task = ...

if client.start():
    # Start the action receiver thread
    action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
    action_receiver_thread.start()

    try:
        # Run the control loop
        client.control_loop(task)
    except KeyboardInterrupt:
        client.stop()
        action_receiver_thread.join()
        # (Optionally) plot the action queue size over time
        visualize_action_queue_size(client.action_queue_size)
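
To make the two chunking parameters above concrete: with actions_per_chunk=50 and chunk_size_threshold=0.5, and assuming the threshold is expressed as a fraction of a chunk, the client keeps executing queued actions and requests a fresh chunk once roughly 25 actions remain. A small sketch of that arithmetic:

# Sketch of the queue bookkeeping implied by the client configuration above
# Assumption: chunk_size_threshold is a fraction of actions_per_chunk
actions_per_chunk = 50
chunk_size_threshold = 0.5

refill_level = int(actions_per_chunk * chunk_size_threshold)
print(f"request a new chunk when ~{refill_level} actions remain in the queue")  # ~25
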
Lines changed: 99 additions & 0 deletions
"""This script demonstrates how to train Diffusion Policy on a real-world dataset."""

from pathlib import Path

import torch

from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors


def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
    if delta_indices is None:
        return [0]

    return [i / fps for i in delta_indices]


output_directory = Path("outputs/robot_learning_tutorial/diffusion")
output_directory.mkdir(parents=True, exist_ok=True)

# Select your device
device = torch.device("mps")  # or "cuda" or "cpu"

dataset_id = "lerobot/svla_so101_pickplace"

# The dataset metadata specifies the inputs the model expects and the outputs it produces
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)

output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}

cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
policy = DiffusionPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)

policy.train()
policy.to(device)

# Diffusion Policy conditions on a short history of observations and predicts a chunk of future actions
delta_timestamps = {
    "observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
    "action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}

# Add image features, if present
delta_timestamps |= {
    k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}

# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)

# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=device.type != "cpu",
    drop_last=True,
)

# Number of training steps and logging frequency (set to 1 here just to demonstrate the loop)
training_steps = 1
log_freq = 1

# Run the training loop
step = 0
done = False
while not done:
    for batch in dataloader:
        batch = preprocessor(batch)
        loss, _ = policy.forward(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % log_freq == 0:
            print(f"step: {step} loss: {loss.item():.3f}")
        step += 1
        if step >= training_steps:
            done = True
            break

# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)

# Push all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
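
After training, the checkpoint and processors saved in output_directory can be reloaded locally for a quick sanity check before relying on the Hub copy. A hedged sketch, assuming from_pretrained accepts a local directory and that make_pre_post_processors takes a pretrained path the same way the inference example below passes a Hub id:

from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors

# Reload the checkpoint written by the training script above (local path assumed to work like a Hub id)
checkpoint_dir = "outputs/robot_learning_tutorial/diffusion"
policy = DiffusionPolicy.from_pretrained(checkpoint_dir)

dataset_metadata = LeRobotDatasetMetadata("lerobot/svla_so101_pickplace")
preprocess, postprocess = make_pre_post_processors(
    policy.config, checkpoint_dir, dataset_stats=dataset_metadata.stats
)
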
Lines changed: 60 additions & 0 deletions
import torch

from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower

device = torch.device("mps")  # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_diffusion"

model = DiffusionPolicy.from_pretrained(model_id)
# Run inference on the selected device, in eval mode
model.to(device)
model.eval()

dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the dataset metadata, typically tens of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(
    model.config, model_id, dataset_stats=dataset_metadata.stats
)

MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20


# Find ports using lerobot-find-port
follower_port = ...  # something like "/dev/tty.usbmodem58760431631"

# The robot id is used to load the right calibration files
follower_id = ...  # something like "follower_so100"

# Robot and environment configuration
# Camera keys must match the names and resolutions of the cameras used for training!
# You can check the camera keys expected by a model in the info.json file on its Hub model card
camera_config = {
    "side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
    "up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}

robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()


for _ in range(MAX_EPISODES):
    for _ in range(MAX_STEPS_PER_EPISODE):
        obs = robot.get_observation()
        obs_frame = build_inference_frame(
            observation=obs, ds_features=dataset_metadata.features, device=device
        )

        obs = preprocess(obs_frame)

        action = model.select_action(obs)
        action = postprocess(action)
        action = make_robot_action(action, dataset_metadata.features)
        robot.send_action(action)

    print("Episode finished! Starting new episode...")
