Refactor algorithms to use array-api-compat and native batch operations #134
Conversation
@traversaro and @GiulioRomualdi the PR is a bit long and I might have overlooked some details.
Pull Request Overview
This PR introduces the use of array-api-compat to unify numpy/jax/pytorch implementations and refactors algorithms to work natively with batch operations. The key changes include replacing jax2torch with array-api-compat, implementing a unified Array API abstraction, and refactoring rigid body algorithms to support native batch computations without item assignments.
- Replaced the `jax2torch` dependency with `array-api-compat` for unified backend compatibility
- Implemented `ArrayAPISpatialMath` and related classes for unified mathematical operations across NumPy, PyTorch, and JAX
- Refactored RBD algorithms to work with batch operations natively (removed vmapping and item assignments)
Reviewed Changes
Copilot reviewed 37 out of 37 changed files in this pull request and generated 6 comments.
Summary per file:

| File | Description |
|---|---|
| tests/test_pytorch_batch.py | Enhanced test suite with proper batch validation against idyntree references |
| tests/test_pytorch.py | Fixed concatenation operation to use PyTorch instead of NumPy |
| tests/test_jax_batch.py | Added comprehensive JAX batch testing with gradient validation |
| tests/conftest.py | Updated to use scipy for rotation matrix generation instead of SpatialMath |
| src/adam/pytorch/torch_like.py | Refactored to use new ArrayAPI abstraction |
| src/adam/pytorch/computations.py | Updated device/dtype handling and removed unnecessary reshaping |
| src/adam/pytorch/computation_batch.py | Major refactor from JAX-based to native PyTorch batch operations |
| src/adam/core/array_api_math.py | New unified Array API abstraction module |
| src/adam/core/spatial_math.py | Enhanced with new batching and concatenation methods |
| src/adam/core/rbd_algorithms.py | Complete refactor for native batch support without item assignments |
| pyproject.toml | Updated dependencies to replace jax2torch with array-api-compat |
tests/test_pytorch_batch.py (outdated)

```python
@pytest.fixture(scope="module")
def setup_test(tests_setup) -> KinDynComputationsBatch | RobotCfg | State:
    robot_cfg, state = tests_setup
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
```
Copilot AI (Sep 10, 2025)
[nitpick] The device selection logic is duplicated across multiple test files. Consider extracting this into a utility function or test fixture for better maintainability.
This may be a good idea, also to make it easier to change in the future (for example, to automatically use ROCm if it is available).
Super true! I did a quick edit in ea12f46 (#134)!
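A shared helper along these lines could live in the test configuration (a sketch only; the helper name and fixture scope here are assumptions, not necessarily what ea12f46 actually does):

```python
import pytest
import torch


def resolve_device() -> torch.device:
    """Pick the best available accelerator, falling back to CPU.

    Centralizing this choice makes it easy to extend later
    (e.g. checking for ROCm or MPS backends as well).
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


@pytest.fixture(scope="session")
def device() -> torch.device:
    # Session scope: the device does not change between tests.
    return resolve_device()
```

Each test file can then depend on the `device` fixture instead of repeating the `torch.cuda.is_available()` branch.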
```python
base_velocity = torch.zeros(
    batch_size + (6,),
    dtype=base_transform.dtype,
    device=base_transform.device,
)
```
Copilot AI (Sep 10, 2025)
[nitpick] Creating zero tensors with the same dtype and device as input tensors repeatedly could be inefficient. Consider using torch.zeros_like with appropriate reshaping where possible.
Suggested change:

```diff
-base_velocity = torch.zeros(
-    batch_size + (6,),
-    dtype=base_transform.dtype,
-    device=base_transform.device,
-)
+base_velocity = torch.zeros_like(base_transform[..., 0, :]).reshape(batch_size + (6,))
```
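For context, a small sketch of the allocation pattern under discussion (shapes are hypothetical; this assumes `base_transform` is a batch of 4x4 homogeneous transforms):

```python
import torch

# Hypothetical batch of 4x4 homogeneous transforms.
batch_size = (8, 3)
base_transform = torch.rand(batch_size + (4, 4), dtype=torch.float64)

# Original pattern: explicit dtype/device propagation from the input tensor.
base_velocity = torch.zeros(
    batch_size + (6,),
    dtype=base_transform.dtype,
    device=base_transform.device,
)

# Caveat for the zeros_like variant: base_transform[..., 0, :] has shape
# batch_size + (4,), which has a different element count than
# batch_size + (6,), so reshaping it to that target would fail here.
assert base_velocity.shape == batch_size + (6,)
assert base_velocity.dtype == base_transform.dtype
```

Given the shape mismatch noted in the comment, the explicit `torch.zeros(..., dtype=..., device=...)` call may remain the safer choice whenever the target shape does not match a slice of the input.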
```python
s = self.sin(q)

# Build rotation matrix components
I = self.factory.eye(q.shape + (3,))
```
Copilot AI (Sep 10, 2025)
The eye function is being called with a tuple concatenation q.shape + (3,) but the factory.eye method expects just the dimension as an integer. This will likely cause a runtime error.
Suggested change:

```diff
-I = self.factory.eye(q.shape + (3,))
+# Create a batch of 3x3 identity matrices matching the shape of q
+I = self.tile(self.factory.eye(3), q.shape + (1, 1))
```
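The tile-based construction can be checked in isolation; a minimal NumPy sketch (batch shapes are hypothetical):

```python
import numpy as np

# Hypothetical batch of joint angles: 8 samples, 5 joints each.
q = np.zeros((8, 5))

# Tiling a single 3x3 identity with reps q.shape + (1, 1) first promotes
# eye(3) to shape (1, 1, 3, 3), then repeats it across the batch axes,
# yielding one identity matrix per (sample, joint) pair.
I = np.tile(np.eye(3), q.shape + (1, 1))

assert I.shape == (8, 5, 3, 3)
```

An alternative with the same result but no copies would be `np.broadcast_to(np.eye(3), q.shape + (3, 3))`, at the cost of producing a read-only view.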
```python
tau_joints.append(col)

tau_joints_vec = self.math.concatenate(tau_joints, axis=-1)
```
Copilot AI (Sep 10, 2025)
The concatenation assumes tau_base has shape (..., 6) and tau_joints_vec has shape (..., n), but tau_base should be reshaped to (..., 6, 1) to match the expected return shape (..., 6+n, 1) based on the function's docstring.
Suggested change:

```diff
+tau_base = self.math.reshape(tau_base, batch_size + (6, 1))
```
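A minimal NumPy sketch of the shape fix being suggested (all shapes here are hypothetical illustrations of the docstring's `(..., 6+n, 1)` contract):

```python
import numpy as np

# Hypothetical shapes: one batch axis, n joints.
batch_size = (8,)
n = 5
tau_base = np.zeros(batch_size + (6,))        # base wrench, shape (..., 6)
tau_joints_vec = np.zeros(batch_size + (n, 1))  # joint torques as columns

# Reshape the base wrench to a column so both operands share the trailing
# (k, 1) layout before concatenating along the second-to-last axis.
tau_base = tau_base.reshape(batch_size + (6, 1))
tau = np.concatenate([tau_base, tau_joints_vec], axis=-2)

assert tau.shape == batch_size + (6 + n, 1)
```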
```python
def __init__(self, xp: Union[cs.SX, cs.DM, None] = None):
    self._xp = cs.SX if xp is None else xp

@staticmethod
def eye(x: int) -> "CasadiLike":
```
Copilot AI (Sep 10, 2025)
The parameter xp is typed as Union[cs.SX, cs.DM, None] but cs.SX and cs.DM are classes, not instances. The parameter should likely be a class or a string identifier, and the default should be the class cs.SX, not an instance.
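The pattern being flagged, accepting a class (not an instance) with a class as the default, can be illustrated with plain Python stand-ins (`MatrixA`/`MatrixB` are hypothetical placeholders for `cs.SX`/`cs.DM`):

```python
from typing import Optional, Type


class MatrixA:  # stand-in for cs.SX
    pass


class MatrixB:  # stand-in for cs.DM
    pass


class Factory:
    def __init__(self, xp: Optional[Type] = None):
        # Store the *class* to instantiate later; default to MatrixA.
        # Typing the parameter as Optional[Type] (or Type[MatrixA] |
        # Type[MatrixB] | None) makes the class-vs-instance intent explicit.
        self._xp = MatrixA if xp is None else xp
```

With this typing, `Factory()` defaults to `MatrixA` and `Factory(MatrixB)` selects the alternative backend class, mirroring how a CasADi factory could switch between symbolic and numeric matrix types.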
Thanks @traversaro!
This PR introduces the use of array-api-compat (see https://data-apis.org/array-api-compat and #98) to unify the numpy/jax/pytorch implementations. I also took the occasion to refactor the algorithms to work natively with batch operations (so no vmapping).
I had, however, to make the casadi implementation a bit more complex to keep it compatible with the new implementation.
The tests are passing, and the pytorch and jax implementations are tested in batched and CUDA cases.
Note: I could have removed the jax/numpy/pytorch interfaces now that they are unified, but I wanted to keep the API compatible with the previous implementation.
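As an illustration of the "no item assignments" style that such a refactor requires, a batched skew-symmetric construction can be written with stacking alone, so the same code traces cleanly under JAX's jit and works unchanged in NumPy or PyTorch (a sketch; the function name and shapes are assumptions, not the PR's actual code):

```python
import numpy as np


def skew_batch(v):
    """Batched skew-symmetric matrices built without item assignment.

    v has shape (..., 3); the result has shape (..., 3, 3). Rows are
    assembled with stack() instead of writing into a preallocated array,
    which is the pattern functional backends like JAX require.
    """
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    zero = np.zeros_like(x)
    rows = [
        np.stack([zero, -z, y], axis=-1),
        np.stack([z, zero, -x], axis=-1),
        np.stack([-y, x, zero], axis=-1),
    ]
    return np.stack(rows, axis=-2)


S = skew_batch(np.ones((4, 3)))
assert S.shape == (4, 3, 3)
```

Because only `stack`, `zeros_like`, and indexing reads are used, the NumPy namespace could be swapped for the one returned by `array_api_compat.array_namespace(v)` to serve all three backends from a single code path.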
Copilot summary
This pull request introduces a major update to the codebase by replacing the use of the jax2torch library with the new array-api-compat package and implementing a unified Array API abstraction for mathematical operations. This change improves backend compatibility and simplifies code maintenance, especially for NumPy, PyTorch, and JAX. The documentation and installation instructions have also been updated to reflect these changes. Additionally, some minor improvements were made to mathematical computations and imports.

Array API abstraction and dependency updates:
- Added the array-api-compat dependency in all relevant places (setup.cfg, ci_env.yml, ci_env_win.yml) and removed jax2torch from dependencies and requirements. [1] [2] [3] [4] [5] [6]
- Added the src/adam/core/array_api_math.py module, introducing ArrayAPILike, ArrayAPIFactory, and ArrayAPISpatialMath classes for unified mathematical operations across supported backends.
- Exported the new classes in src/adam/core/__init__.py.

Documentation and installation instructions:
- Updated installation instructions to remove jax2torch and add array-api-compat, including changes in README.md, docs/index.rst, and docs/modules/pytorch_batched.rst. [1] [2] [3] [4] [5]

Minor code improvements:
📚 Documentation preview 📚: https://adam-docs--134.org.readthedocs.build/en/134/