
Conversation


@Giulero Giulero commented Sep 9, 2025

This PR introduces the use of array-api-compat (see https://data-apis.org/array-api-compat and #98) to unify the numpy/jax/pytorch implementations.
I also took the occasion to refactor the algorithms to work natively with batch operations (so no vmapping).
I had, however, to make the casadi implementation a bit more complex to keep it compatible with the new implementation.
The tests are passing, and the pytorch and jax implementations are tested in batched and CUDA cases.

Note: I could have removed the jax/numpy/pytorch interfaces now that they are unified, but I wanted to keep the API compatible with the previous implementation.

Copilot summary

This pull request introduces a major update to the codebase by replacing the use of the jax2torch library with the new array-api-compat package and implementing a unified Array API abstraction for mathematical operations. This change improves backend compatibility and simplifies code maintenance, especially for NumPy, PyTorch, and JAX. The documentation and installation instructions have also been updated to reflect these changes. Additionally, some minor improvements were made to mathematical computations and imports.

Array API abstraction and dependency updates:

  • Added the new array-api-compat dependency to all relevant places (setup.cfg, ci_env.yml, ci_env_win.yml) and removed jax2torch from dependencies and requirements. [1] [2] [3] [4] [5] [6]
  • Implemented the new src/adam/core/array_api_math.py module, introducing ArrayAPILike, ArrayAPIFactory, and ArrayAPISpatialMath classes for unified mathematical operations across supported backends.
  • Refactored rigid body algorithms to work with array-api-compat (removed item assignments and enabled native batch computations).
  • Updated imports to include new Array API math classes in src/adam/core/__init__.py.
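As a rough sketch of the pattern the new module builds on (the function `unit` and the import fallback are illustrative, not the PR's actual code): with array-api-compat, one function can serve NumPy, PyTorch, and JAX by dispatching on the namespace of its input array.

```python
import numpy as np

try:
    # array-api-compat resolves the Array API namespace of any supported array
    from array_api_compat import array_namespace
except ImportError:  # fallback so this sketch still runs without the package
    def array_namespace(*arrays):
        return np

def unit(v):
    """Normalize the last axis of v; unchanged across NumPy/PyTorch/JAX inputs."""
    xp = array_namespace(v)
    norm = xp.sum(v * v, axis=-1, keepdims=True) ** 0.5
    return v / norm

v = np.array([[3.0, 4.0]])
print(unit(v))  # [[0.6 0.8]]
```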

Documentation and installation instructions:

  • Updated documentation and installation instructions to remove references to jax2torch and add array-api-compat, including changes in README.md, docs/index.rst, and docs/modules/pytorch_batched.rst. [1] [2] [3] [4] [5]

Minor code improvements:

  • Simplified velocity and gravity term computations in CasADi and JAX by removing unnecessary reshaping and using direct array operations. [1] [2] [3] [4]

📚 Documentation preview 📚: https://adam-docs--134.org.readthedocs.build/en/134/

@Giulero Giulero requested a review from Copilot September 9, 2025 13:29



Giulero commented Sep 9, 2025

@traversaro and @GiulioRomualdi the PR is a bit long and I might have overlooked some details.

@Giulero Giulero requested a review from Copilot September 10, 2025 09:30
Contributor

Copilot AI left a comment


Pull Request Overview

This PR introduces the use of array-api-compat to unify numpy/jax/pytorch implementations and refactors algorithms to work natively with batch operations. The key changes include replacing jax2torch with array-api-compat, implementing a unified Array API abstraction, and refactoring rigid body algorithms to support native batch computations without item assignments.

  • Replaced jax2torch dependency with array-api-compat for unified backend compatibility
  • Implemented ArrayAPISpatialMath and related classes for unified mathematical operations across NumPy, PyTorch, and JAX
  • Refactored RBD algorithms to work with batch operations natively (removed vmapping and item assignments)
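To illustrate what "no item assignments" means in practice (a hedged sketch, not the actual adam code): a batched skew-symmetric operator can be built purely from stacking, so it broadcasts over any batch shape and stays compatible with backends such as JAX where arrays are immutable.

```python
import numpy as np

def skew(v):
    # Map (..., 3) vectors to (..., 3, 3) skew-symmetric matrices using only
    # stacking -- no item assignment, so the same code handles arbitrary batch
    # shapes and immutable-array backends.
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    o = np.zeros_like(x)
    return np.stack([
        np.stack([o, -z, y], axis=-1),
        np.stack([z, o, -x], axis=-1),
        np.stack([-y, x, o], axis=-1),
    ], axis=-2)

S = skew(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))
assert S.shape == (2, 3, 3)
```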

Reviewed Changes

Copilot reviewed 37 out of 37 changed files in this pull request and generated 6 comments.

Summary per file:

  • tests/test_pytorch_batch.py — Enhanced test suite with proper batch validation against idyntree references
  • tests/test_pytorch.py — Fixed concatenation operation to use PyTorch instead of NumPy
  • tests/test_jax_batch.py — Added comprehensive JAX batch testing with gradient validation
  • tests/conftest.py — Updated to use scipy for rotation matrix generation instead of SpatialMath
  • src/adam/pytorch/torch_like.py — Refactored to use the new ArrayAPI abstraction
  • src/adam/pytorch/computations.py — Updated device/dtype handling and removed unnecessary reshaping
  • src/adam/pytorch/computation_batch.py — Major refactor from JAX-based to native PyTorch batch operations
  • src/adam/core/array_api_math.py — New unified Array API abstraction module
  • src/adam/core/spatial_math.py — Enhanced with new batching and concatenation methods
  • src/adam/core/rbd_algorithms.py — Complete refactor for native batch support without item assignments
  • pyproject.toml — Updated dependencies to replace jax2torch with array-api-compat


@pytest.fixture(scope="module")
def setup_test(tests_setup) -> KinDynComputationsBatch | RobotCfg | State:
    robot_cfg, state = tests_setup
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Copilot AI Sep 10, 2025


[nitpick] The device selection logic is duplicated across multiple test files. Consider extracting this into a utility function or test fixture for better maintainability.

Contributor


This may be a good idea, also to make it easy to change in the future (for example, to automatically use ROCm if it is available).

Collaborator Author


Super true! I did a quick edit in ea12f46 (#134)!
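A minimal version of such a shared helper (the name `default_device` and its placement are hypothetical; the actual change is in commit ea12f46) might look like:

```python
import torch

def default_device() -> torch.device:
    # Single place to decide the test device; easy to extend later
    # (e.g. prefer ROCm or MPS when available).
    return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

device = default_device()
```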

Comment on lines +233 to +237

base_velocity = torch.zeros(
    batch_size + (6,),
    dtype=base_transform.dtype,
    device=base_transform.device,
)

Copilot AI Sep 10, 2025


[nitpick] Creating zero tensors with the same dtype and device as input tensors repeatedly could be inefficient. Consider using torch.zeros_like with appropriate reshaping where possible.

Suggested change

base_velocity = torch.zeros(
    batch_size + (6,),
    dtype=base_transform.dtype,
    device=base_transform.device,
)
base_velocity = base_transform.new_zeros(batch_size + (6,))

s = self.sin(q)

# Build rotation matrix components
I = self.factory.eye(q.shape + (3,))

Copilot AI Sep 10, 2025


The eye function is being called with a tuple concatenation q.shape + (3,) but the factory.eye method expects just the dimension as an integer. This will likely cause a runtime error.

Suggested change
I = self.factory.eye(q.shape + (3,))
# Create a batch of 3x3 identity matrices matching the shape of q
I = self.tile(self.factory.eye(3), q.shape + (1, 1))

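The fix above relies on tiling a single identity matrix across the batch. In plain NumPy terms (a sketch of the idea, not the factory's real API), that is:

```python
import numpy as np

def batched_eye(batch_shape, n):
    # One (n, n) identity tiled over the batch: result shape batch_shape + (n, n).
    return np.tile(np.eye(n), batch_shape + (1, 1))

I = batched_eye((4, 2), 3)
assert I.shape == (4, 2, 3, 3)
```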
tau_joints.append(col)

tau_joints_vec = self.math.concatenate(tau_joints, axis=-1)


Copilot AI Sep 10, 2025


The concatenation assumes tau_base has shape (..., 6) and tau_joints_vec has shape (..., n), but tau_base should be reshaped to (..., 6, 1) to match the expected return shape (..., 6+n, 1) based on the function's docstring.

Suggested change
tau_base = self.math.reshape(tau_base, batch_size + (6, 1))

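In NumPy terms, the shape bookkeeping this comment describes (all names, the batch shape, and the joint count here are illustrative) works out as:

```python
import numpy as np

batch = (4,)  # hypothetical batch shape
n = 3         # hypothetical number of joints

tau_base = np.zeros(batch + (6,))                        # (..., 6)
tau_joints = [np.zeros(batch + (1,)) for _ in range(n)]  # n columns of (..., 1)

tau_joints_vec = np.concatenate(tau_joints, axis=-1)     # (..., n)
tau_base = np.reshape(tau_base, batch + (6, 1))          # (..., 6, 1)
tau = np.concatenate(
    [tau_base, tau_joints_vec[..., None]], axis=-2       # (..., 6 + n, 1)
)
assert tau.shape == batch + (6 + n, 1)
```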
Comment on lines +107 to +109

def __init__(self, xp: Union[cs.SX, cs.DM, None] = None):
    self._xp = cs.SX if xp is None else xp

Copilot AI Sep 10, 2025


The parameter xp is typed as Union[cs.SX, cs.DM, None] but cs.SX and cs.DM are classes, not instances. The parameter should likely be a class or a string identifier, and the default should be the class cs.SX, not an instance.


Giulero commented Sep 18, 2025

Thanks @traversaro!
As this is a huge code-base modification, I created a release (https://github.com/ami-iit/adam/releases/tag/v0.3.5) to pin the "old" code base so that, in case something is off, we can go back there.

@Giulero Giulero merged commit fe457a6 into main Sep 18, 2025
15 of 17 checks passed

3 participants