From 92e04002f2d9e5987879509c95593b6e699827f1 Mon Sep 17 00:00:00 2001 From: HUANG TZU-CHUN Date: Fri, 10 Oct 2025 11:33:01 +0800 Subject: [PATCH 1/6] chore: Add policies_compilation directory --- benchmarks/policies_compilation/.gitkeep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 benchmarks/policies_compilation/.gitkeep diff --git a/benchmarks/policies_compilation/.gitkeep b/benchmarks/policies_compilation/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 From 46edcb8b1119726145f804a0d4a508ec01e7626c Mon Sep 17 00:00:00 2001 From: HUANG TZU-CHUN Date: Sat, 18 Oct 2025 22:00:39 +0800 Subject: [PATCH 2/6] feat: Add torch.compile benchmark script Adds `benchmark_inference_compile_lerobot.py` from https://gist.github.com/AdilZouitine/3574664e4cf71605986b49e9148d29ab. --- .../benchmark_inference_compile_lerobot.py | 655 ++++++++++++++++++ 1 file changed, 655 insertions(+) create mode 100644 benchmarks/policies_compilation/benchmark_inference_compile_lerobot.py diff --git a/benchmarks/policies_compilation/benchmark_inference_compile_lerobot.py b/benchmarks/policies_compilation/benchmark_inference_compile_lerobot.py new file mode 100644 index 0000000000..6ed7e3caf8 --- /dev/null +++ b/benchmarks/policies_compilation/benchmark_inference_compile_lerobot.py @@ -0,0 +1,655 @@ +#!/usr/bin/env python +""" +Comprehensive torch.compile benchmark for LeRobot policies. + +Compares main branch performance with torch.compile optimized version. 
+ +Usage: + python benchmark_compile_safe.py --policy act --device cuda + python benchmark_compile_safe.py --policy diffusion --device cpu +""" + +import argparse +import json +import time +from collections.abc import Callable +from typing import Any + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata + +# Safe imports for all processor functions +from lerobot.policies.act.processor_act import make_act_pre_post_processors +from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors +from lerobot.policies.factory import make_policy, make_policy_config +from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors +from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors +from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors +from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors +from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors + + +class PolicyConfig: + """Type-safe policy configuration""" + + def __init__( + self, + processor_func: Callable, + config_kwargs: dict[str, Any], + delta_timestamps_func: Callable[[Any, float], dict[str, list[float]]], + ): + self.processor_func = processor_func + self.config_kwargs = config_kwargs + self.delta_timestamps_func = delta_timestamps_func + + +# Safe policy configuration mapping +POLICY_CONFIGS: dict[str, PolicyConfig] = { + "act": PolicyConfig( + processor_func=make_act_pre_post_processors, + config_kwargs={"use_vae": False, "chunk_size": 100, "n_action_steps": 10}, + delta_timestamps_func=lambda cfg, fps: {"action": [i / fps for i in range(cfg.chunk_size)]}, + ), + "diffusion": PolicyConfig( + processor_func=make_diffusion_pre_post_processors, + config_kwargs={"n_obs_steps": 2, "horizon": 16, "n_action_steps": 8}, + delta_timestamps_func=lambda 
cfg, fps: {"action": [i / fps for i in range(cfg.n_action_steps)]}, + ), + "pi0": PolicyConfig( + processor_func=make_pi0_pre_post_processors, + config_kwargs={"n_obs_steps": 1, "chunk_size": 50, "n_action_steps": 50}, + delta_timestamps_func=lambda cfg, fps: {"action": [i / fps for i in range(cfg.chunk_size)]}, + ), + "pi0fast": PolicyConfig( + processor_func=make_pi0fast_pre_post_processors, + config_kwargs={"n_obs_steps": 1, "chunk_size": 50, "n_action_steps": 50}, + delta_timestamps_func=lambda cfg, fps: {"action": [i / fps for i in range(cfg.chunk_size)]}, + ), + "tdmpc": PolicyConfig( + processor_func=make_tdmpc_pre_post_processors, + config_kwargs={"horizon": 5, "n_action_steps": 1}, + delta_timestamps_func=lambda cfg, fps: {"action": [0]}, + ), + "vqbet": PolicyConfig( + processor_func=make_vqbet_pre_post_processors, + config_kwargs={"n_obs_steps": 1, "chunk_size": 8, "n_action_steps": 8}, + delta_timestamps_func=lambda cfg, fps: {"action": [i / fps for i in range(cfg.chunk_size)]}, + ), + "smolvla": PolicyConfig( + processor_func=make_smolvla_pre_post_processors, + config_kwargs={"n_obs_steps": 1, "chunk_size": 50, "n_action_steps": 50}, + delta_timestamps_func=lambda cfg, fps: {"action": [i / fps for i in range(cfg.chunk_size)]}, + ), +} + + +class TorchCompileBenchmark: + """Safe torch.compile benchmark for LeRobot policies""" + + def __init__(self, policy_name: str, device: str, repo_id: str = "AdilZtn/grab_red_cube_test_25"): + self.policy_name = policy_name + self.device = torch.device(device) + self.repo_id = repo_id + + # Benchmark parameters + self.n_inference = 100 + self.n_training = 50 + self.batch_size = 8 + self.warmup_steps = 10 + self.tolerance = 1e-5 + + print(f"šŸ¤– Torch.compile Benchmark for {policy_name.upper()} Policy") + print(f"Device: {device}") + print(f"PyTorch version: {torch.__version__}") + print(f"Dataset: {repo_id}") + print("=" * 60) + + def setup_policy_and_data(self) -> tuple[Any, dict, Any, Any]: + """Setup policy, data, 
and processors""" + torch.manual_seed(42) + np.random.seed(42) + + # Load dataset metadata + ds_meta = LeRobotDatasetMetadata(self.repo_id) + + # Get policy configuration safely + if self.policy_name not in POLICY_CONFIGS: + available_policies = list(POLICY_CONFIGS.keys()) + raise ValueError(f"Policy '{self.policy_name}' not supported. Available: {available_policies}") + + policy_config = POLICY_CONFIGS[self.policy_name] + + # Create policy configuration + cfg = make_policy_config(self.policy_name, device=str(self.device), **policy_config.config_kwargs) + + # Create policy + policy = make_policy(cfg, ds_meta=ds_meta) + policy.to(self.device) + + # Create processors using the safe function reference + preprocessor, postprocessor = policy_config.processor_func(cfg, dataset_stats=ds_meta.stats) + + # Setup dataset with appropriate delta_timestamps + delta_timestamps = policy_config.delta_timestamps_func(cfg, ds_meta.fps) + dataset = LeRobotDataset(self.repo_id, episodes=[0], delta_timestamps=delta_timestamps) + dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False, drop_last=True) + + # Get sample batch + sample_batch = next(iter(dataloader)) + + # Move to device + for key in sample_batch: + if isinstance(sample_batch[key], torch.Tensor): + sample_batch[key] = sample_batch[key].to(self.device) + + # Preprocess + sample_batch = preprocessor(sample_batch) + + return policy, sample_batch, preprocessor, postprocessor + + def benchmark_inference(self, policy, batch) -> tuple[float, torch.Tensor]: + """Benchmark inference performance""" + policy.eval() + + # Warmup + with torch.no_grad(): + for _ in range(self.warmup_steps): + _ = policy.select_action(batch) + if hasattr(policy, "reset"): + policy.reset() + + # Benchmark + if torch.cuda.is_available(): + torch.cuda.synchronize() + + start_time = time.perf_counter() + + with torch.no_grad(): + for _ in range(self.n_inference): + action = policy.select_action(batch) + if hasattr(policy, "reset"): + 
policy.reset() + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + end_time = time.perf_counter() + avg_time = (end_time - start_time) / self.n_inference * 1000 # ms + + return avg_time, action + + def benchmark_training(self, policy, batch) -> tuple[float, list[float], list[float]]: + """Benchmark training performance""" + policy.train() + optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4) + + # Warmup + for _ in range(self.warmup_steps): + optimizer.zero_grad() + loss, _ = policy.forward(batch) + loss.backward() + optimizer.step() + if hasattr(policy, "reset"): + policy.reset() + + # Benchmark + if torch.cuda.is_available(): + torch.cuda.synchronize() + + start_time = time.perf_counter() + losses = [] + grad_norms = [] + + for _ in range(self.n_training): + optimizer.zero_grad() + loss, loss_dict = policy.forward(batch) + loss.backward() + + # Calculate gradient norm safely + grad_tensors = [p.grad.detach() for p in policy.parameters() if p.grad is not None] + if grad_tensors: + total_norm = torch.norm(torch.stack([torch.norm(g) for g in grad_tensors])) + else: + total_norm = torch.tensor(0.0) + + optimizer.step() + + losses.append(loss.item()) + grad_norms.append(total_norm.item()) + + if hasattr(policy, "reset"): + policy.reset() + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + end_time = time.perf_counter() + avg_time = (end_time - start_time) / self.n_training * 1000 # ms + + return avg_time, losses, grad_norms + + def test_correctness(self, policy_original, policy_compiled, batch) -> dict[str, Any]: + """Test numerical correctness between original and compiled versions""" + print("šŸ” Testing correctness...") + + # Test inference + policy_original.eval() + policy_compiled.eval() + + with torch.no_grad(): + # Reset both policies to same state + if hasattr(policy_original, "reset"): + policy_original.reset() + if hasattr(policy_compiled, "reset"): + policy_compiled.reset() + + action_original = 
policy_original.select_action(batch) + action_compiled = policy_compiled.select_action(batch) + + action_diff = torch.abs(action_original - action_compiled).max().item() + inference_correct = action_diff < self.tolerance + + # Test training forward pass + policy_original.train() + policy_compiled.train() + + # Use same random state + torch.manual_seed(42) + loss_original, _ = policy_original.forward(batch) + + torch.manual_seed(42) + loss_compiled, _ = policy_compiled.forward(batch) + + loss_diff = torch.abs(loss_original - loss_compiled).item() + training_correct = loss_diff < self.tolerance + + return { + "inference_correct": inference_correct, + "training_correct": training_correct, + "action_diff": action_diff, + "loss_diff": loss_diff, + } + + def generate_report(self, results: dict[str, Any]) -> str: + """Generate comprehensive markdown report including failures""" + + # Handle cases where benchmarks might have failed + inference_time_orig = results.get("time_original_inference", float("inf")) + inference_time_comp = results.get("time_compiled_inference", float("inf")) + training_time_orig = results.get("time_original_training", float("inf")) + training_time_comp = results.get("time_compiled_training", float("inf")) + + speedup_inference = results.get("speedup_inference", 0.0) + speedup_training = results.get("speedup_training", 0.0) + + compilation_successful = results.get("compilation_successful", False) + correctness_passed = results.get("correctness_passed", False) + inference_benchmarked = results.get("inference_benchmarked", False) + training_benchmarked = results.get("training_benchmarked", False) + + report = f"""# Torch.compile Benchmark Report: {self.policy_name.upper()} + +## Environment +- **Policy**: {self.policy_name} +- **Device**: {self.device} +- **PyTorch**: {torch.__version__} +- **Dataset**: {self.repo_id} +- **Batch Size**: {self.batch_size} +- **Benchmark Parameters**: {self.n_inference} inference runs, {self.n_training} training runs + 
+## šŸ”§ Compilation Results +- **Status**: {"āœ… SUCCESS" if compilation_successful else "āŒ FAILED"}""" + + if not compilation_successful: + compilation_error = results.get("compilation_error", "Unknown error") + report += f""" +- **Error**: `{compilation_error}` +- **Impact**: Policy will fall back to eager execution""" + + report += f""" + +## šŸŽÆ Correctness Results +- **Status**: {"āœ… PASSED" if correctness_passed else "āŒ FAILED"} +- **Inference**: {"PASSED" if results["correctness"]["inference_correct"] else "FAILED"} +- **Training**: {"PASSED" if results["correctness"]["training_correct"] else "FAILED"} + +### Detailed Differences +- **Max Action Difference**: {results["correctness"]["action_diff"]:.2e} (threshold: {self.tolerance:.2e}) +- **Loss Difference**: {results["correctness"]["loss_diff"]:.2e} (threshold: {self.tolerance:.2e})""" + + if not correctness_passed: + action_diff = results["correctness"]["action_diff"] + loss_diff = results["correctness"]["loss_diff"] + + report += f""" + +### āš ļø Correctness Analysis +- **Action diff magnitude**: {action_diff:.2e} ({"SEVERE" if action_diff > 1e-3 else "MODERATE" if action_diff > 1e-4 else "MINOR"}) +- **Loss diff magnitude**: {loss_diff:.2e} ({"SEVERE" if loss_diff > 1e-3 else "MODERATE" if loss_diff > 1e-4 else "MINOR"}) +- **Likely causes**: Graph breaks, dynamic shapes, numerical precision issues""" + + report += """ + +## ⚔ Performance Results + +### Inference Performance""" + + if inference_benchmarked: + report += f""" +- **Original**: {inference_time_orig:.2f} ms/iter +- **Compiled**: {inference_time_comp:.2f} ms/iter +- **šŸš€ Speedup**: {speedup_inference:.2f}x""" + + if speedup_inference < 1.0: + report += " (āš ļø SLOWDOWN)" + elif speedup_inference < 1.1: + report += " (āš ļø INSUFFICIENT)" + else: + report += """ +- **Status**: āŒ Benchmark failed +- **Reason**: Could not complete inference timing""" + + report += """ + +### Training Performance""" + + if training_benchmarked: + 
report += f""" +- **Original**: {training_time_orig:.2f} ms/iter +- **Compiled**: {training_time_comp:.2f} ms/iter +- **šŸš€ Speedup**: {speedup_training:.2f}x""" + + if speedup_training < 1.0: + report += " (āš ļø SLOWDOWN)" + elif speedup_training < 1.1: + report += " (āš ļø INSUFFICIENT)" + else: + report += """ +- **Status**: āŒ Benchmark failed +- **Reason**: Could not complete training timing""" + + # Consistency metrics if available + loss_consistency = results.get("loss_consistency", float("inf")) + grad_norm_consistency = results.get("grad_norm_consistency", float("inf")) + + if loss_consistency != float("inf") and grad_norm_consistency != float("inf"): + report += f""" + +### Consistency Metrics +- **Average Loss Difference**: {loss_consistency:.2e} +- **Average Grad Norm Difference**: {grad_norm_consistency:.2e}""" + + report += f""" + +## šŸ“‹ Success Criteria Analysis +- **āœ… Compilation**: {"PASSED" if compilation_successful else "FAILED"} +- **āœ… Correctness**: {"PASSED" if correctness_passed else "FAILED"} +- **āœ… Performance**: {"PASSED" if speedup_inference > 1.1 and speedup_training > 1.1 else "FAILED"} +- **āœ… Benchmarking**: {"PASSED" if inference_benchmarked and training_benchmarked else "FAILED"} + +## šŸŽÆ Overall Result +{"šŸŽÆ SUCCESS: Policy is torch.compile compatible!" if results["success"] else "āŒ NEEDS WORK: torch.compile not yet functional"} + +## šŸ› ļø Next Steps""" + + if not compilation_successful: + report += """ +1. **Fix compilation errors** - Enable torch._dynamo verbose mode for details +2. **Identify graph breaks** - Look for .item(), dynamic shapes, control flow +3. **Test incrementally** - Fix one issue at a time""" + elif not correctness_passed: + report += """ +1. **Debug numerical differences** - Check for precision issues +2. **Verify tensor operations** - Ensure deterministic behavior +3. 
**Test with smaller tolerance** - May be acceptable for some use cases""" + elif not (speedup_inference > 1.1 and speedup_training > 1.1): + report += """ +1. **Profile bottlenecks** - Use torch.profiler to identify slow ops +2. **Optimize compilation mode** - Try 'reduce-overhead' or 'max-autotune' modes +3. **Check graph breaks** - Even successful compilation can have breaks""" + else: + report += """ +1. **Document changes** - Create clear PR with benchmark results +2. **Add tests** - Ensure torch.compile compatibility is maintained +3. **Consider edge cases** - Test with different batch sizes, VAE modes, etc.""" + + report += f""" + +## šŸ” Raw Data +```json +{json.dumps(results, indent=2)} +``` +""" + return report + + def run_benchmark(self) -> dict[str, Any]: + """Run complete benchmark suite""" + + # Setup + print("šŸ“¦ Setting up policy and data...") + try: + policy, batch, preprocessor, postprocessor = self.setup_policy_and_data() + except Exception as e: + return {"success": False, "error": f"Setup failed: {str(e)}"} + + # Test compilation + print("šŸ”§ Testing torch.compile...") + compilation_error = None + policy_compiled = None + + try: + # Enable verbose compilation debugging + import torch._dynamo + + torch._dynamo.config.verbose = True + torch._dynamo.config.suppress_errors = False + + policy_compiled = torch.compile(policy, mode="default") + + # Force compilation by running once + policy_compiled.eval() + with torch.no_grad(): + _ = policy_compiled.select_action(batch) + + print("āœ… Compilation successful!") + + except Exception as e: + compilation_error = str(e) + print(f"āŒ Compilation failed: {e}") + # Continue with analysis even if compilation fails + policy_compiled = policy # Use original policy for comparison + + # Test correctness (always run this for analysis) + print("šŸ” Testing correctness...") + correctness = self.test_correctness(policy, policy_compiled, batch) + + # Print detailed correctness results + print(f" Inference 
correct: {correctness['inference_correct']}") + print(f" Training correct: {correctness['training_correct']}") + print(f" Max action difference: {correctness['action_diff']:.2e}") + print(f" Loss difference: {correctness['loss_diff']:.2e}") + + correctness_passed = correctness["inference_correct"] and correctness["training_correct"] + if correctness_passed: + print("āœ… Correctness tests passed!") + else: + print("āš ļø Correctness test failed - continuing with timing analysis...") + + # Always run timing benchmarks for analysis + print(f"šŸš€ Benchmarking inference ({self.n_inference} runs)...") + try: + time_orig_inf, _ = self.benchmark_inference(policy, batch) + time_comp_inf, _ = self.benchmark_inference(policy_compiled, batch) + inference_benchmarked = True + except Exception as e: + print(f"āŒ Inference benchmark failed: {e}") + time_orig_inf = time_comp_inf = float("inf") + inference_benchmarked = False + + print(f"šŸš€ Benchmarking training ({self.n_training} runs)...") + try: + time_orig_train, losses_orig, grad_norms_orig = self.benchmark_training(policy, batch) + time_comp_train, losses_comp, grad_norms_comp = self.benchmark_training(policy_compiled, batch) + training_benchmarked = True + + loss_consistency = ( + np.mean(np.abs(np.array(losses_orig) - np.array(losses_comp))) + if losses_orig and losses_comp + else float("inf") + ) + grad_norm_consistency = ( + np.mean(np.abs(np.array(grad_norms_orig) - np.array(grad_norms_comp))) + if grad_norms_orig and grad_norms_comp + else float("inf") + ) + except Exception as e: + print(f"āŒ Training benchmark failed: {e}") + time_orig_train = time_comp_train = float("inf") + loss_consistency = grad_norm_consistency = float("inf") + training_benchmarked = False + + # Calculate speedups (handle inf cases) + speedup_inference = ( + time_orig_inf / time_comp_inf if inference_benchmarked and time_comp_inf > 0 else 0.0 + ) + speedup_training = ( + time_orig_train / time_comp_train if training_benchmarked and 
time_comp_train > 0 else 0.0 + ) + + # Success criteria + compilation_successful = compilation_error is None + performance_good = speedup_inference > 1.1 and speedup_training > 1.1 + + success = ( + compilation_successful + and correctness_passed + and performance_good + and inference_benchmarked + and training_benchmarked + ) + + results = { + "success": success, + "policy": self.policy_name, + "device": str(self.device), + "pytorch_version": torch.__version__, + # Compilation results + "compilation_successful": compilation_successful, + "compilation_error": compilation_error, + # Correctness results + "correctness": correctness, + "correctness_passed": correctness_passed, + # Timing results + "inference_benchmarked": inference_benchmarked, + "training_benchmarked": training_benchmarked, + "time_original_inference": time_orig_inf, + "time_compiled_inference": time_comp_inf, + "speedup_inference": speedup_inference, + "time_original_training": time_orig_train, + "time_compiled_training": time_comp_train, + "speedup_training": speedup_training, + # Consistency metrics + "loss_consistency": loss_consistency, + "grad_norm_consistency": grad_norm_consistency, + } + + # Print detailed summary + print("\nšŸ“Š DETAILED RESULTS SUMMARY") + print("=" * 60) + print(f"šŸ”§ Compilation: {'āœ… SUCCESS' if compilation_successful else 'āŒ FAILED'}") + if compilation_error: + print(f" Error: {compilation_error}") + + print(f"šŸŽÆ Correctness: {'āœ… PASSED' if correctness_passed else 'āŒ FAILED'}") + print(f" Inference diff: {correctness['action_diff']:.2e} (threshold: {self.tolerance:.2e})") + print(f" Training diff: {correctness['loss_diff']:.2e} (threshold: {self.tolerance:.2e})") + + print("⚔ Performance:") + if inference_benchmarked: + print(f" Inference: {time_orig_inf:.2f}ms → {time_comp_inf:.2f}ms ({speedup_inference:.2f}x)") + else: + print(" Inference: āŒ Benchmark failed") + + if training_benchmarked: + print( + f" Training: {time_orig_train:.2f}ms → 
{time_comp_train:.2f}ms ({speedup_training:.2f}x)" + ) + else: + print(" Training: āŒ Benchmark failed") + + print(f"šŸ† Overall: {'āœ… SUCCESS' if success else 'āŒ NEEDS WORK'}") + + return results + + +def main(): + parser = argparse.ArgumentParser(description="Torch.compile benchmark for LeRobot policies") + parser.add_argument( + "--policy", required=True, choices=list(POLICY_CONFIGS.keys()), help="Policy to benchmark" + ) + parser.add_argument( + "--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run on" + ) + parser.add_argument("--output", help="Output file for results (optional)") + parser.add_argument("--n-inference", type=int, default=100, help="Number of inference runs") + parser.add_argument("--n-training", type=int, default=50, help="Number of training runs") + parser.add_argument("--batch-size", type=int, default=8, help="Batch size to use") + + args = parser.parse_args() + + # Run benchmark + benchmark = TorchCompileBenchmark(args.policy, args.device) + + # Override default parameters if provided + if args.n_inference: + benchmark.n_inference = args.n_inference + if args.n_training: + benchmark.n_training = args.n_training + if args.batch_size: + benchmark.batch_size = args.batch_size + + results = benchmark.run_benchmark() + + # Always generate a report (even for failures) + if "error" in results and len(results) == 2: # Only basic error info + # Simple failure case - create basic report + simple_report = f"""# Benchmark Failed: {benchmark.policy_name.upper()} + +## Error +{results.get("error", "Unknown error")} + +## Environment +- **Policy**: {benchmark.policy_name} +- **Device**: {benchmark.device} +- **PyTorch**: {torch.__version__} + +## Next Steps +1. Check the error message above +2. Ensure the policy and dataset are properly configured +3. 
Try with a simpler configuration first +""" + if args.output: + with open(args.output, "w") as f: + f.write(simple_report) + print(f"\nšŸ’¾ Basic error report saved to {args.output}") + else: + print("\nšŸ“ BASIC ERROR REPORT:") + print(simple_report) + else: + # Full benchmark was attempted - generate comprehensive report + report = benchmark.generate_report(results) + + if args.output: + with open(args.output, "w") as f: + f.write(report) + print(f"\nšŸ’¾ Comprehensive report saved to {args.output}") + else: + print("\nšŸ“ COMPREHENSIVE REPORT:") + print(report) + + +if __name__ == "__main__": + main() From 545aed6234fc6178191f17098c3150529b9bd453 Mon Sep 17 00:00:00 2001 From: HUANG TZU-CHUN Date: Sat, 18 Oct 2025 22:53:59 +0800 Subject: [PATCH 3/6] fix: Remove deprecated pi0fast from benchmark script --- .../benchmark_inference_compile_lerobot.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/benchmarks/policies_compilation/benchmark_inference_compile_lerobot.py b/benchmarks/policies_compilation/benchmark_inference_compile_lerobot.py index 6ed7e3caf8..ca3bd4407d 100644 --- a/benchmarks/policies_compilation/benchmark_inference_compile_lerobot.py +++ b/benchmarks/policies_compilation/benchmark_inference_compile_lerobot.py @@ -26,7 +26,6 @@ from lerobot.policies.diffusion.processor_diffusion import make_diffusion_pre_post_processors from lerobot.policies.factory import make_policy, make_policy_config from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors -from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors from lerobot.policies.smolvla.processor_smolvla import make_smolvla_pre_post_processors from lerobot.policies.tdmpc.processor_tdmpc import make_tdmpc_pre_post_processors from lerobot.policies.vqbet.processor_vqbet import make_vqbet_pre_post_processors @@ -63,11 +62,6 @@ def __init__( config_kwargs={"n_obs_steps": 1, "chunk_size": 50, "n_action_steps": 50}, delta_timestamps_func=lambda cfg, fps: 
{"action": [i / fps for i in range(cfg.chunk_size)]}, ), - "pi0fast": PolicyConfig( - processor_func=make_pi0fast_pre_post_processors, - config_kwargs={"n_obs_steps": 1, "chunk_size": 50, "n_action_steps": 50}, - delta_timestamps_func=lambda cfg, fps: {"action": [i / fps for i in range(cfg.chunk_size)]}, - ), "tdmpc": PolicyConfig( processor_func=make_tdmpc_pre_post_processors, config_kwargs={"horizon": 5, "n_action_steps": 1}, From 67e341e3c780bf2bfbbda6e45deb227cdfbcc6d0 Mon Sep 17 00:00:00 2001 From: HUANG TZU-CHUN Date: Sat, 18 Oct 2025 23:01:50 +0800 Subject: [PATCH 4/6] chore: Remove .gitkeep file from policies_compilation directory --- benchmarks/policies_compilation/.gitkeep | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 benchmarks/policies_compilation/.gitkeep diff --git a/benchmarks/policies_compilation/.gitkeep b/benchmarks/policies_compilation/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 From febc052a75abc1d46c436a7cdf23baddffbc5f1d Mon Sep 17 00:00:00 2001 From: HUANG TZU-CHUN Date: Sat, 18 Oct 2025 23:09:47 +0800 Subject: [PATCH 5/6] feat: Add baseline benchmark report for ACT policy --- .../baseline_act_report.md | 97 +++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 benchmarks/policies_compilation/baseline_act_report.md diff --git a/benchmarks/policies_compilation/baseline_act_report.md b/benchmarks/policies_compilation/baseline_act_report.md new file mode 100644 index 0000000000..95da26a568 --- /dev/null +++ b/benchmarks/policies_compilation/baseline_act_report.md @@ -0,0 +1,97 @@ +# Torch.compile Benchmark Report: ACT + +## Environment + +- **Policy**: act +- **Device**: cuda +- **PyTorch**: 2.7.1+cu126 +- **Dataset**: AdilZtn/grab_red_cube_test_25 +- **Batch Size**: 8 +- **Benchmark Parameters**: 100 inference runs, 50 training runs + +## 🔧 Compilation Results + +- **Status**: ✅ SUCCESS + +## 🎯 Correctness Results + +- **Status**: ❌ FAILED +- **Inference**: FAILED +-
**Training**: FAILED + +### Detailed Differences + +- **Max Action Difference**: 5.57e-02 (threshold: 1.00e-05) +- **Loss Difference**: 8.93e-05 (threshold: 1.00e-05) + +### ⚠️ Correctness Analysis + +- **Action diff magnitude**: 5.57e-02 (SEVERE) +- **Loss diff magnitude**: 8.93e-05 (MINOR) +- **Likely causes**: Graph breaks, dynamic shapes, numerical precision issues + +## ⚡ Performance Results + +### Inference Performance + +- **Original**: 22.14 ms/iter +- **Compiled**: 22.95 ms/iter +- **🚀 Speedup**: 0.96x (⚠️ SLOWDOWN) + +### Training Performance + +- **Original**: 68.74 ms/iter +- **Compiled**: 61.31 ms/iter +- **🚀 Speedup**: 1.12x + +### Consistency Metrics + +- **Average Loss Difference**: 6.78e-03 +- **Average Grad Norm Difference**: 1.53e+00 + +## 📋 Success Criteria Analysis + +- **✅ Compilation**: PASSED +- **✅ Correctness**: FAILED +- **✅ Performance**: FAILED +- **✅ Benchmarking**: PASSED + +## 🎯 Overall Result + +❌ NEEDS WORK: torch.compile not yet functional + +## 🛠️ Next Steps + +1. **Debug numerical differences** - Check for precision issues +2. **Verify tensor operations** - Ensure deterministic behavior +3.
**Test with smaller tolerance** - May be acceptable for some use cases + +## 🔍 Raw Data + +```json +{ + "success": false, + "policy": "act", + "device": "cuda", + "pytorch_version": "2.7.1+cu126", + "compilation_successful": true, + "compilation_error": null, + "correctness": { + "inference_correct": false, + "training_correct": false, + "action_diff": 0.05568695068359375, + "loss_diff": 8.934736251831055e-5 + }, + "correctness_passed": false, + "inference_benchmarked": true, + "training_benchmarked": true, + "time_original_inference": 22.140723338816315, + "time_compiled_inference": 22.948127458803356, + "speedup_inference": 0.9648161218628187, + "time_original_training": 68.73527298215777, + "time_compiled_training": 61.31022967863828, + "speedup_training": 1.1211061081068912, + "loss_consistency": 0.0067824447154998775, + "grad_norm_consistency": 1.5347092628479004 +} +``` From 631a55a3745dc6017368f5d2516f2660e2a595d0 Mon Sep 17 00:00:00 2001 From: HUANG TZU-CHUN Date: Mon, 20 Oct 2025 09:15:42 +0800 Subject: [PATCH 6/6] fix: Remove .item() call in act policy forward() to prevent torch.compile graph break Removed .item() calls from loss_dict in forward() to avoid breaking the torch.compile computation graph. The tensor-to-scalar conversion is now handled in the training script instead.
--- .../baseline_act_report.md | 30 +++++++++---------- src/lerobot/policies/act/modeling_act.py | 4 +-- src/lerobot/scripts/lerobot_train.py | 1 + 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/benchmarks/policies_compilation/baseline_act_report.md b/benchmarks/policies_compilation/baseline_act_report.md index 95da26a568..14e55412ff 100644 --- a/benchmarks/policies_compilation/baseline_act_report.md +++ b/benchmarks/policies_compilation/baseline_act_report.md @@ -34,20 +34,20 @@ ### Inference Performance -- **Original**: 22.14 ms/iter -- **Compiled**: 22.95 ms/iter -- **🚀 Speedup**: 0.96x (⚠️ SLOWDOWN) +- **Original**: 21.75 ms/iter +- **Compiled**: 21.46 ms/iter +- **🚀 Speedup**: 1.01x (⚠️ INSUFFICIENT) ### Training Performance -- **Original**: 68.74 ms/iter -- **Compiled**: 61.31 ms/iter +- **Original**: 68.59 ms/iter +- **Compiled**: 61.15 ms/iter - **🚀 Speedup**: 1.12x ### Consistency Metrics -- **Average Loss Difference**: 6.78e-03 -- **Average Grad Norm Difference**: 1.53e+00 +- **Average Loss Difference**: 4.87e-03 +- **Average Grad Norm Difference**: 1.60e+00 ## 📋 Success Criteria Analysis @@ -85,13 +85,13 @@ "correctness_passed": false, "inference_benchmarked": true, "training_benchmarked": true, - "time_original_inference": 22.140723338816315, - "time_compiled_inference": 22.948127458803356, - "speedup_inference": 0.9648161218628187, - "time_original_training": 68.73527298215777, - "time_compiled_training": 61.31022967863828, - "speedup_training": 1.1211061081068912, - "loss_consistency": 0.0067824447154998775, - "grad_norm_consistency": 1.5347092628479004 + "time_original_inference": 21.745667571667582, + "time_compiled_inference": 21.46407115040347, + "speedup_inference": 1.0131194319703334, + "time_original_training": 68.5850445041433, + "time_compiled_training": 61.15469123702496, + "speedup_training": 1.1215009530228772, + "loss_consistency": 0.004866420030593872, + "grad_norm_consistency": 1.599840692281723 }
``` diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index b7cbcd0610..f23a1cf7df 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -144,7 +144,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) ).mean() - loss_dict = {"l1_loss": l1_loss.item()} + loss_dict = {"l1_loss": l1_loss} if self.config.use_vae: # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for # each dimension independently, we sum over the latent dimension to get the total @@ -153,7 +153,7 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: mean_kld = ( (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() ) - loss_dict["kld_loss"] = mean_kld.item() + loss_dict["kld_loss"] = mean_kld loss = l1_loss + mean_kld * self.config.kl_weight else: loss = l1_loss diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 84eb81ad43..288ac9e577 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -89,6 +89,7 @@ def update_policy( # Let accelerator handle mixed precision with accelerator.autocast(): loss, output_dict = policy.forward(batch) + output_dict = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in output_dict.items()} # TODO(rcadene): policy.unnormalize_outputs(out_dict) # Use accelerator's backward method