Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 79 additions & 2 deletions src/adam/casadi/casadi_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,26 @@ def __truediv__(self, other: Union["CasadiLike", npt.ArrayLike]) -> "CasadiLike"

def __setitem__(self, idx, value: Union["CasadiLike", npt.ArrayLike]):
"""Overrides set item operator"""
self.array[idx] = value.array if type(self) is type(value) else value
if idx is Ellipsis:
self.array = value.array if isinstance(value, CasadiLike) else value
elif isinstance(idx, tuple) and Ellipsis in idx:
idx = tuple(slice(None) if i is Ellipsis else i for i in idx)
self.array[idx] = value.array if isinstance(value, CasadiLike) else value
else:
self.array[idx] = value.array if isinstance(value, CasadiLike) else value

def __getitem__(self, idx) -> "CasadiLike":
"""Overrides get item operator"""
return CasadiLike(self.array[idx])
if idx is Ellipsis:
# Handle the case where only Ellipsis is passed
return CasadiLike(self.array)
elif isinstance(idx, tuple) and Ellipsis in idx:
# Handle the case where Ellipsis is part of a tuple
idx = tuple(slice(None) if k is Ellipsis else k for k in idx)
return CasadiLike(self.array[idx])
else:
# For other cases, delegate to the CasADi object's __getitem__
return CasadiLike(self.array[idx])

@property
def T(self) -> "CasadiLike":
Expand Down Expand Up @@ -129,6 +144,68 @@ def array(*x) -> "CasadiLike":
"""
return CasadiLike(cs.SX(*x))

@staticmethod
def zeros_like(x) -> CasadiLike:
"""
Args:
x (npt.ArrayLike): matrix

Returns:
npt.ArrayLike: zero matrix of dimension x
"""

kind = (
cs.DM
if (isinstance(x, CasadiLike) and isinstance(x.array, cs.DM))
or isinstance(x, cs.DM)
else cs.SX
)

return (
CasadiLike(kind.zeros(x.array.shape))
if isinstance(x, CasadiLike)
else (
CasadiLike(kind.zeros(x.shape))
if isinstance(x, (cs.SX, cs.DM))
else (
TypeError(f"Unsupported type for zeros_like: {type(x)}")
if isinstance(x, CasadiLike)
else CasadiLike(kind.zeros(x.shape))
)
)
)

@staticmethod
def ones_like(x) -> CasadiLike:
"""
Args:
x (npt.ArrayLike): matrix

Returns:
npt.ArrayLike: Identity matrix of dimension x
"""

kind = (
cs.DM
if (isinstance(x, CasadiLike) and isinstance(x.array, cs.DM))
or isinstance(x, cs.DM)
else cs.SX
)

return (
CasadiLike(kind.ones(x.array.shape))
if isinstance(x, CasadiLike)
else (
CasadiLike(kind.ones(x.shape))
if isinstance(x, (cs.SX, cs.DM))
else (
TypeError(f"Unsupported type for ones_like: {type(x)}")
if isinstance(x, CasadiLike)
else CasadiLike(kind.ones(x.shape))
)
)
)


class SpatialMath(SpatialMath):

Expand Down
30 changes: 28 additions & 2 deletions src/adam/core/spatial_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def T(self):
class ArrayLikeFactory(abc.ABC):
"""Abstract class for a generic Array wrapper. Every method should be implemented for every data type."""

@staticmethod
@abc.abstractmethod
def zeros(self, x: npt.ArrayLike) -> npt.ArrayLike:
def zeros(x: npt.ArrayLike) -> npt.ArrayLike:
"""
Args:
x (npt.ArrayLike): matrix dimension
Expand All @@ -79,8 +80,9 @@ def zeros(self, x: npt.ArrayLike) -> npt.ArrayLike:
"""
pass

@staticmethod
@abc.abstractmethod
def eye(self, x: npt.ArrayLike) -> npt.ArrayLike:
def eye(x: npt.ArrayLike) -> npt.ArrayLike:
"""
Args:
x (npt.ArrayLike): matrix dimension
Expand All @@ -90,6 +92,30 @@ def eye(self, x: npt.ArrayLike) -> npt.ArrayLike:
"""
pass

@staticmethod
@abc.abstractmethod
def zeros_like(x: npt.ArrayLike) -> npt.ArrayLike:
"""
Args:
x (npt.ArrayLike): matrix

Returns:
npt.ArrayLike: zero matrix of dimension x
"""
pass

@staticmethod
@abc.abstractmethod
def ones_like(x: npt.ArrayLike) -> npt.ArrayLike:
"""
Args:
x (npt.ArrayLike): matrix

Returns:
npt.ArrayLike: ones matrix of dimension x
"""
pass


class SpatialMath:
"""Class implementing the main geometric functions used for computing rigid-body algorithm
Expand Down
30 changes: 30 additions & 0 deletions src/adam/jax/jax_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,36 @@ def array(x) -> "JaxLike":
"""
return JaxLike(jnp.array(x))

@staticmethod
def zeros_like(x) -> JaxLike:
"""
Args:
x (npt.ArrayLike): matrix

Returns:
npt.ArrayLike: zero matrix of dimension x
"""
return (
JaxLike(jnp.zeros_like(x.array))
if isinstance(x, JaxLike)
else JaxLike(jnp.zeros_like(x))
)

@staticmethod
def ones_like(x) -> JaxLike:
"""
Args:
x (npt.ArrayLike): matrix

Returns:
npt.ArrayLike: Ones matrix of dimension x
"""
return (
JaxLike(jnp.ones_like(x.array))
if isinstance(x, JaxLike)
else JaxLike(jnp.ones_like(x))
)


class SpatialMath(SpatialMath):
def __init__(self):
Expand Down
30 changes: 30 additions & 0 deletions src/adam/numpy/numpy_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,36 @@ def array(x) -> "NumpyLike":
"""
return NumpyLike(np.array(x))

@staticmethod
def zeros_like(x) -> NumpyLike:
"""
Args:
x (npt.ArrayLike): matrix

Returns:
npt.ArrayLike: zero matrix of dimension x
"""
return (
NumpyLike(np.zeros_like(x.array))
if isinstance(x, NumpyLike)
else NumpyLike(np.zeros_like(x))
)

@staticmethod
def ones_like(x) -> NumpyLike:
"""
Args:
x (npt.ArrayLike): matrix

Returns:
npt.ArrayLike: Ones matrix of dimension x
"""
return (
NumpyLike(np.ones_like(x.array))
if isinstance(x, NumpyLike)
else NumpyLike(np.ones_like(x))
)


class SpatialMath(SpatialMath):
def __init__(self):
Expand Down
30 changes: 30 additions & 0 deletions src/adam/pytorch/torch_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,36 @@ def array(x: ntp.ArrayLike) -> "TorchLike":
"""
return TorchLike(torch.tensor(x))

@staticmethod
def zeros_like(x) -> TorchLike:
"""
Args:
x (npt.ArrayLike): matrix

Returns:
npt.ArrayLike: zero matrix of dimension x
"""
return (
TorchLike(torch.zeros_like(x.array))
if isinstance(x, TorchLike)
else TorchLike(torch.zeros_like(x))
)

@staticmethod
def ones_like(x) -> TorchLike:
"""
Args:
x (npt.ArrayLike): matrix

Returns:
npt.ArrayLike: Identity matrix of dimension x
"""
return (
TorchLike(torch.ones_like(x.array))
if isinstance(x, TorchLike)
else TorchLike(torch.ones_like(x))
)


class SpatialMath(SpatialMath):
def __init__(self):
Expand Down
13 changes: 13 additions & 0 deletions tests/test_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from conftest import RobotCfg, State

from adam.casadi import KinDynComputations
from adam.casadi.casadi_like import CasadiLike, CasadiLikeFactory


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -173,3 +174,15 @@ def test_gravity_term(setup_test):
assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4)
adam_gravity = cs.DM(adam_kin_dyn.gravity_term_fun()(state.H, state.joints_pos))
assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4)


def test_casadi_like():
B = cs.DM([[1.0, 2.0], [3.0, 4.0]])
B_like = CasadiLike(B)
assert B_like[...].array - B == pytest.approx(0.0, abs=1e-5)

ones = CasadiLikeFactory.ones_like(B)
assert ones[...].array - cs.DM.ones(2, 2) == pytest.approx(0.0, abs=1e-5)

zeros = CasadiLikeFactory.zeros_like(B)
assert zeros[...].array - cs.DM.zeros(2, 2) == pytest.approx(0.0, abs=1e-5)
15 changes: 15 additions & 0 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
import pytest
from conftest import RobotCfg, State
from jax import config
import jax.numpy as jnp


from adam.jax import KinDynComputations
from adam.jax.jax_like import JaxLike, JaxLikeFactory

config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -119,3 +122,15 @@ def test_gravity_term(setup_test):
idyn_gravity = robot_cfg.idyn_function_values.gravity_term
adam_gravity = adam_kin_dyn.gravity_term(state.H, state.joints_pos)
assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4)


def test_jax_like():
B = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
B_like = JaxLike(B)
assert B_like[...].array - B == pytest.approx(0.0, abs=1e-5)

ones = JaxLikeFactory.ones_like(B_like)
assert ones.array - jnp.ones_like(B) == pytest.approx(0.0, abs=1e-5)

zeros = JaxLikeFactory.zeros_like(B_like)
assert zeros.array - jnp.zeros_like(B) == pytest.approx(0.0, abs=1e-5)
13 changes: 13 additions & 0 deletions tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from conftest import RobotCfg, State

from adam.numpy import KinDynComputations
from adam.numpy.numpy_like import NumpyLike, NumpyLikeFactory


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -116,3 +117,15 @@ def test_gravity_term(setup_test):
idyn_gravity = robot_cfg.idyn_function_values.gravity_term
adam_gravity = adam_kin_dyn.gravity_term(state.H, state.joints_pos)
assert idyn_gravity - adam_gravity == pytest.approx(0.0, abs=1e-4)


def test_numpy_like():
B = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
B_like = NumpyLike(B)
assert B_like[...].array - B == pytest.approx(0.0, abs=1e-5)

ones = NumpyLikeFactory.ones_like(B_like)
assert ones.array - np.ones_like(B) == pytest.approx(0.0, abs=1e-5)

zeros = NumpyLikeFactory.zeros_like(B_like)
assert zeros.array - np.zeros_like(B) == pytest.approx(0.0, abs=1e-5)
13 changes: 13 additions & 0 deletions tests/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from conftest import RobotCfg, State

from adam.pytorch import KinDynComputations
from adam.pytorch.torch_like import TorchLike, TorchLikeFactory

torch.set_default_dtype(torch.float64)

Expand Down Expand Up @@ -128,3 +129,15 @@ def test_gravity_term(setup_test):
idyn_gravity = robot_cfg.idyn_function_values.gravity_term
adam_gravity = adam_kin_dyn.gravity_term(state.H, state.joints_pos)
assert idyn_gravity - adam_gravity.numpy() == pytest.approx(0.0, abs=1e-4)


def test_torch_like():
B = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
B_like = TorchLike(B)
assert B_like[...].array - B == pytest.approx(0.0, abs=1e-5)

ones = TorchLikeFactory.ones_like(B_like)
assert ones.array - np.ones_like(B) == pytest.approx(0.0, abs=1e-5)

zeros = TorchLikeFactory.zeros_like(B_like)
assert zeros.array - np.zeros_like(B) == pytest.approx(0.0, abs=1e-5)
Loading