Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions examples/tutorial/act/act_training_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""This script demonstrates how to train ACT Policy on a real-world dataset."""

from pathlib import Path

import torch

from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors


def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
if delta_indices is None:
return [0]

return [i / fps for i in delta_indices]


output_directory = Path("outputs/robot_learning_tutorial/act")
output_directory.mkdir(parents=True, exist_ok=True)

# Select your device
device = torch.device("mps") # or "cuda" or "cpu"

dataset_id = "lerobot/svla_so101_pickplace"

# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)

output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}

cfg = ACTConfig(input_features=input_features, output_features=output_features)
policy = ACTPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)

policy.train()
policy.to(device)

# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}

# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}

# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)

# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)

# Number of training steps and logging frequency
training_steps = 1
log_freq = 1

# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()

if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break

# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)

# Save all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
57 changes: 57 additions & 0 deletions examples/tutorial/act/act_using_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch

from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower

device = torch.device("mps") # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_act"
model = ACTPolicy.from_pretrained(model_id)

dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)

# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"

# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"

MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20

# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}

robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()

for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)

obs = preprocess(obs_frame)

action = model.select_action(obs)
action = postprocess(action)

action = make_robot_action(action, dataset_metadata.features)

robot.send_action(action)

print("Episode finished! Starting new episode...")
11 changes: 11 additions & 0 deletions examples/tutorial/async-inf/policy_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from lerobot.async_inference.configs import PolicyServerConfig
from lerobot.async_inference.policy_server import serve

host = ... # something like "127.0.0.1" if you're exposing to localhost
port = ... # something like 8080

config = PolicyServerConfig(
host=host,
port=port,
)
serve(config)
55 changes: 55 additions & 0 deletions examples/tutorial/async-inf/robot_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import threading

from lerobot.async_inference.configs import RobotClientConfig
from lerobot.async_inference.helpers import visualize_action_queue_size
from lerobot.async_inference.robot_client import RobotClient
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.robots.so100_follower import SO100FollowerConfig

# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
# check the config.json on the Hub for the policy you are using to see the expected camera specs
camera_cfg = {
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}

# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"

# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"

robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)

server_address = ... # something like "127.0.0.1:8080" if using localhost

# 3. Create client configuration
client_cfg = RobotClientConfig(
robot=robot_cfg,
server_address=server_address,
policy_device="mps",
policy_type="act",
pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
chunk_size_threshold=0.5, # g
actions_per_chunk=50, # make sure this is less than the max actions of the policy
)

# 4. Create and start client
client = RobotClient(client_cfg)

# 5. Provide a textual description of the task
task = ...

if client.start():
# Start action receiver thread
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
action_receiver_thread.start()

try:
# Run the control loop
client.control_loop(task)
except KeyboardInterrupt:
client.stop()
action_receiver_thread.join()
# (Optionally) plot the action queue size
visualize_action_queue_size(client.action_queue_size)
99 changes: 99 additions & 0 deletions examples/tutorial/diffusion/diffusion_training_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""This script demonstrates how to train Diffusion Policy on a real-world dataset."""

from pathlib import Path

import torch

from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors


def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
if delta_indices is None:
return [0]

return [i / fps for i in delta_indices]


output_directory = Path("outputs/robot_learning_tutorial/diffusion")
output_directory.mkdir(parents=True, exist_ok=True)

# Select your device
device = torch.device("mps") # or "cuda" or "cpu"

dataset_id = "lerobot/svla_so101_pickplace"

# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)

output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}

cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
policy = DiffusionPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)

policy.train()
policy.to(device)

# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}

# add image features if they are present
delta_timestamps |= {
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}

# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)

# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=True,
)

# Number of training steps and logging frequency
training_steps = 1
log_freq = 1

# Run training loop
step = 0
done = False
while not done:
for batch in dataloader:
batch = preprocessor(batch)
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()

if step % log_freq == 0:
print(f"step: {step} loss: {loss.item():.3f}")
step += 1
if step >= training_steps:
done = True
break

# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)

# Save all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
60 changes: 60 additions & 0 deletions examples/tutorial/diffusion/diffusion_using_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch

from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower

device = torch.device("mps") # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_diffusion"

model = DiffusionPolicy.from_pretrained(model_id)

dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(
model.config, model_id, dataset_stats=dataset_metadata.stats
)

MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20


# # find ports using lerobot-find-port
follower_port = ... # something like "/dev/tty.usbmodem58760431631"

# # the robot ids are used the load the right calibration files
follower_id = ... # something like "follower_so100"

# Robot and environment configuration
# Camera keys must match the name and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
camera_config = {
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}

robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()


for _ in range(MAX_EPISODES):
for _ in range(MAX_STEPS_PER_EPISODE):
obs = robot.get_observation()
obs_frame = build_inference_frame(
observation=obs, ds_features=dataset_metadata.features, device=device
)

obs = preprocess(obs_frame)

action = model.select_action(obs)
action = postprocess(action)
action = make_robot_action(action, dataset_metadata.features)
robot.send_action(action)

print("Episode finished! Starting new episode...")
Loading