Add API Examples #2289 (Merged)
Commits (7):
2749d9f  (unscrewing things up) (#2288)  [fracapuano]
272bdaa  fix: minor fix using named attributes  [fracapuano]
ec4218d  fix: change model to act  [fracapuano]
80e87ed  fix: named attributes for inference frame building  [fracapuano]
9124772  fix: minor fixes to smolvla  [fracapuano]
d2fc577  fix: small changes to pi0  [fracapuano]
bd108af  remove: old file that should have never been committed (ups sorry sorry)  [fracapuano]
New file (+98 lines): training ACT on a real-world dataset

```python
"""This script demonstrates how to train ACT Policy on a real-world dataset."""

from pathlib import Path

import torch

from lerobot.configs.types import FeatureType
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import dataset_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors


def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
    if delta_indices is None:
        return [0]

    return [i / fps for i in delta_indices]
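
# For example, with fps=30 and delta_indices=[0, 1, 2] (illustrative values),
# make_delta_timestamps returns [0.0, 1/30, 2/30], i.e. roughly [0.0, 0.033, 0.067]:
# the offsets, in seconds, of the frames the dataset loads around each sampled index.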

output_directory = Path("outputs/robot_learning_tutorial/act")
output_directory.mkdir(parents=True, exist_ok=True)

# Select your device
device = torch.device("mps")  # or "cuda" or "cpu"

dataset_id = "lerobot/svla_so101_pickplace"

# This specifies the inputs the model will be expecting and the outputs it will produce
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
features = dataset_to_policy_features(dataset_metadata.features)

output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}

cfg = ACTConfig(input_features=input_features, output_features=output_features)
policy = ACTPolicy(cfg)
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)

policy.train()
policy.to(device)

# To perform action chunking, ACT expects a given number of actions as targets
delta_timestamps = {
    "action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
}

# add image features if they are present
delta_timestamps |= {
    k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
}

# Instantiate the dataset
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)

# Create the optimizer and dataloader for offline training
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
batch_size = 32
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=device.type != "cpu",
    drop_last=True,
)

# Number of training steps and logging frequency
training_steps = 1
log_freq = 1

# Run training loop
step = 0
done = False
while not done:
    for batch in dataloader:
        batch = preprocessor(batch)
        loss, _ = policy.forward(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % log_freq == 0:
            print(f"step: {step} loss: {loss.item():.3f}")
        step += 1
        if step >= training_steps:
            done = True
            break

# Save the policy checkpoint, alongside the pre/post processors
policy.save_pretrained(output_directory)
preprocessor.save_pretrained(output_directory)
postprocessor.save_pretrained(output_directory)

# Save all assets to the Hub
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
```
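To sanity-check the saved checkpoint without touching the Hub, it can be reloaded from the local output directory; a minimal sketch, assuming `from_pretrained` accepts a local path as in the usual Hugging Face convention:

```python
from lerobot.policies.act.modeling_act import ACTPolicy

# Reload the checkpoint written by save_pretrained above (local path assumed to work)
policy = ACTPolicy.from_pretrained("outputs/robot_learning_tutorial/act")
policy.eval()
```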
New file (+57 lines): running the trained ACT policy on a real SO100 follower

```python
import torch

from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower

device = torch.device("mps")  # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_act"
model = ACTPolicy.from_pretrained(model_id)

dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)

# Find ports using lerobot-find-port
follower_port = ...  # something like "/dev/tty.usbmodem58760431631"

# The robot id is used to load the right calibration files
follower_id = ...  # something like "follower_so100"

MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20

# Robot and environment configuration
# Camera keys must match the names and resolutions of the cameras used for training!
# You can check the camera keys expected by a model in the info.json on its model card on the Hub
camera_config = {
    "side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
    "up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}

robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()

for _ in range(MAX_EPISODES):
    for _ in range(MAX_STEPS_PER_EPISODE):
        obs = robot.get_observation()
        obs_frame = build_inference_frame(
            observation=obs, ds_features=dataset_metadata.features, device=device
        )

        obs = preprocess(obs_frame)

        action = model.select_action(obs)
        action = postprocess(action)

        action = make_robot_action(action, dataset_metadata.features)

        robot.send_action(action)

    print("Episode finished! Starting new episode...")
```
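As written, the inner loop sends actions as fast as the policy produces them. In practice you may want to pace each iteration to a fixed control frequency; a minimal sketch of such pacing (the 30 Hz target is an assumption, matching the camera fps above):

```python
import time

CONTROL_HZ = 30  # illustrative target, matching the camera fps above

t0 = time.perf_counter()
# ... get the observation, run the policy, send the action ...
elapsed = time.perf_counter() - t0
time.sleep(max(0.0, 1 / CONTROL_HZ - elapsed))  # sleep off the rest of the control period
```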
New file (+11 lines): serving a policy for asynchronous inference

```python
from lerobot.async_inference.configs import PolicyServerConfig
from lerobot.async_inference.policy_server import serve

host = ...  # something like "127.0.0.1" if you're serving on localhost
port = ...  # something like 8080

config = PolicyServerConfig(
    host=host,
    port=port,
)
serve(config)
```
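For a quick local test, the placeholders can be filled in with the values the comments suggest (illustrative only):

```python
# Illustrative values taken from the comments above: serve on localhost, port 8080
config = PolicyServerConfig(host="127.0.0.1", port=8080)
serve(config)
```

The robot client below would then use "127.0.0.1:8080" as its server_address.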
New file (+55 lines): robot client for asynchronous inference

```python
import threading

from lerobot.async_inference.configs import RobotClientConfig
from lerobot.async_inference.helpers import visualize_action_queue_size
from lerobot.async_inference.robot_client import RobotClient
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.robots.so100_follower import SO100FollowerConfig

# These cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
# Check the config.json on the Hub for the policy you are using to see the expected camera specs
camera_cfg = {
    "up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
    "side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}

# Find ports using lerobot-find-port
follower_port = ...  # something like "/dev/tty.usbmodem58760431631"

# The robot id is used to load the right calibration files
follower_id = ...  # something like "follower_so100"

robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)

server_address = ...  # something like "127.0.0.1:8080" if using localhost

# Create the client configuration
client_cfg = RobotClientConfig(
    robot=robot_cfg,
    server_address=server_address,
    policy_device="mps",
    policy_type="act",
    pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
    chunk_size_threshold=0.5,  # fraction of the action queue below which a new chunk is requested
    actions_per_chunk=50,  # make sure this is less than the max actions of the policy
)
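
# With the illustrative settings above, and assuming chunk_size_threshold is
# interpreted as a fraction of actions_per_chunk, the client would request a
# fresh chunk once fewer than 0.5 * 50 = 25 actions remain in its queue.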

# Create and start the client
client = RobotClient(client_cfg)

# Provide a textual description of the task
task = ...

if client.start():
    # Start the action receiver thread
    action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
    action_receiver_thread.start()

    try:
        # Run the control loop
        client.control_loop(task)
    except KeyboardInterrupt:
        client.stop()
        action_receiver_thread.join()
        # (Optionally) plot the action queue size
        visualize_action_queue_size(client.action_queue_size)
```
New file (+99 lines): training Diffusion Policy on a real-world dataset

```python
| """This script demonstrates how to train Diffusion Policy on a real-world dataset.""" | ||
|
|
||
| from pathlib import Path | ||
|
|
||
| import torch | ||
|
|
||
| from lerobot.configs.types import FeatureType | ||
| from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata | ||
| from lerobot.datasets.utils import dataset_to_policy_features | ||
| from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig | ||
| from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy | ||
| from lerobot.policies.factory import make_pre_post_processors | ||
|
|
||
|
|
||
| def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]: | ||
| if delta_indices is None: | ||
| return [0] | ||
|
|
||
| return [i / fps for i in delta_indices] | ||
|
|
||
|
|
||
| output_directory = Path("outputs/robot_learning_tutorial/diffusion") | ||
| output_directory.mkdir(parents=True, exist_ok=True) | ||
|
|
||
| # Select your device | ||
| device = torch.device("mps") # or "cuda" or "cpu" | ||
|
|
||
| dataset_id = "lerobot/svla_so101_pickplace" | ||
|
|
||
| # This specifies the inputs the model will be expecting and the outputs it will produce | ||
| dataset_metadata = LeRobotDatasetMetadata(dataset_id) | ||
| features = dataset_to_policy_features(dataset_metadata.features) | ||
|
|
||
| output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} | ||
| input_features = {key: ft for key, ft in features.items() if key not in output_features} | ||
|
|
||
| cfg = DiffusionConfig(input_features=input_features, output_features=output_features) | ||
| policy = DiffusionPolicy(cfg) | ||
| preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats) | ||
|
|
||
| policy.train() | ||
| policy.to(device) | ||
|
|
||
| # To perform action chunking, ACT expects a given number of actions as targets | ||
| delta_timestamps = { | ||
| "observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps), | ||
| "action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps), | ||
| } | ||
|
|
||
| # add image features if they are present | ||
| delta_timestamps |= { | ||
| k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features | ||
| } | ||
|
|
||
| # Instantiate the dataset | ||
| dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps) | ||
|
|
||
| # Create the optimizer and dataloader for offline training | ||
| optimizer = cfg.get_optimizer_preset().build(policy.parameters()) | ||
| batch_size = 32 | ||
| dataloader = torch.utils.data.DataLoader( | ||
| dataset, | ||
| batch_size=batch_size, | ||
| shuffle=True, | ||
| pin_memory=device.type != "cpu", | ||
| drop_last=True, | ||
| ) | ||
|
|
||
| # Number of training steps and logging frequency | ||
| training_steps = 1 | ||
| log_freq = 1 | ||
|
|
||
| # Run training loop | ||
| step = 0 | ||
| done = False | ||
| while not done: | ||
| for batch in dataloader: | ||
| batch = preprocessor(batch) | ||
| loss, _ = policy.forward(batch) | ||
| loss.backward() | ||
| optimizer.step() | ||
| optimizer.zero_grad() | ||
|
|
||
| if step % log_freq == 0: | ||
| print(f"step: {step} loss: {loss.item():.3f}") | ||
| step += 1 | ||
| if step >= training_steps: | ||
| done = True | ||
| break | ||
|
|
||
| # Save the policy checkpoint, alongside the pre/post processors | ||
| policy.save_pretrained(output_directory) | ||
| preprocessor.save_pretrained(output_directory) | ||
| postprocessor.save_pretrained(output_directory) | ||
|
|
||
| # Save all assets to the Hub | ||
| policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion") | ||
| preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion") | ||
| postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion") |
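Unlike ACT, Diffusion Policy also conditions on a short history of observations, which is why `observation.state` gets its own delta timestamps above. The window sizes can be tuned on the config; a minimal sketch, with field names assumed from lerobot's default DiffusionConfig:

```python
# Illustrative override of the default windows (field names are an assumption)
cfg = DiffusionConfig(
    input_features=input_features,
    output_features=output_features,
    n_obs_steps=2,  # observation frames fed to the policy
    horizon=16,  # length of the action sequence the model denoises
    n_action_steps=8,  # actions actually executed from each predicted chunk
)
```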
New file (+60 lines): running the trained Diffusion Policy on a real SO100 follower

```python
import torch

from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so100_follower.so100_follower import SO100Follower

device = torch.device("mps")  # or "cuda" or "cpu"
model_id = "fracapuano/robot_learning_tutorial_diffusion"

model = DiffusionPolicy.from_pretrained(model_id)

dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
preprocess, postprocess = make_pre_post_processors(
    model.config, model_id, dataset_stats=dataset_metadata.stats
)

MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20

# Find ports using lerobot-find-port
follower_port = ...  # something like "/dev/tty.usbmodem58760431631"

# The robot id is used to load the right calibration files
follower_id = ...  # something like "follower_so100"

# Robot and environment configuration
# Camera keys must match the names and resolutions of the cameras used for training!
# You can check the camera keys expected by a model in the info.json on its model card on the Hub
camera_config = {
    "side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
    "up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
}

robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
robot = SO100Follower(robot_cfg)
robot.connect()

for _ in range(MAX_EPISODES):
    for _ in range(MAX_STEPS_PER_EPISODE):
        obs = robot.get_observation()
        obs_frame = build_inference_frame(
            observation=obs, ds_features=dataset_metadata.features, device=device
        )

        obs = preprocess(obs_frame)

        action = model.select_action(obs)
        action = postprocess(action)
        action = make_robot_action(action, dataset_metadata.features)
        robot.send_action(action)

    print("Episode finished! Starting new episode...")
```