diff --git a/src/adam/parametric/pytorch/computations_parametric.py b/src/adam/parametric/pytorch/computations_parametric.py index daa66d6c..1b1923a7 100644 --- a/src/adam/parametric/pytorch/computations_parametric.py +++ b/src/adam/parametric/pytorch/computations_parametric.py @@ -23,9 +23,7 @@ def __init__( joints_name_list: list, links_name_list: list, root_link: str = None, - gravity: np.array = torch.tensor( - [0, 0, -9.80665, 0, 0, 0], dtype=torch.float64 - ), + gravity: np.array = torch.tensor([0, 0, -9.80665, 0, 0, 0]), ) -> None: """ Args: @@ -35,7 +33,7 @@ def __init__( root_link (str, optional): Deprecated. The root link is automatically chosen as the link with no parent in the URDF. Defaults to None. """ self.math = SpatialMath() - self.g = gravity + self.g = gravity.to(torch.get_default_dtype()) self.links_name_list = links_name_list self.joints_name_list = joints_name_list self.urdfstring = urdfstring diff --git a/src/adam/pytorch/computation_batch.py b/src/adam/pytorch/computation_batch.py index 9e0a34e3..4296be76 100644 --- a/src/adam/pytorch/computation_batch.py +++ b/src/adam/pytorch/computation_batch.py @@ -34,12 +34,18 @@ def __init__( joints_name_list (list): list of the actuated joints root_link (str, optional): Deprecated. The root link is automatically chosen as the link with no parent in the URDF. Defaults to None. """ + + def to_default_dtype(tensor): + """Converts a JAX tensor to the default floating-point type (float32 or float64).""" + default_dtype = jnp.array(0.0).dtype # Get the default floating-point dtype + return tensor.astype(default_dtype) + math = SpatialMath() factory = URDFModelFactory(path=urdfstring, math=math) model = Model.build(factory=factory, joints_name_list=joints_name_list) self.rbdalgos = RBDAlgorithms(model=model, math=math) self.NDoF = self.rbdalgos.NDoF - self.g = gravity + self.g = to_default_dtype(gravity) self.funcs = {} if root_link is not None: warnings.warn( diff --git a/src/adam/pytorch/computations.py b/src/adam/pytorch/computations.py index 3f39d495..ee7189bd 100644 --- a/src/adam/pytorch/computations.py +++ b/src/adam/pytorch/computations.py @@ -20,9 +20,7 @@ def __init__( urdfstring: str, joints_name_list: list = None, root_link: str = None, - gravity: np.array = torch.tensor( - [0, 0, -9.80665, 0, 0, 0], dtype=torch.float64 - ), + gravity: np.array = torch.tensor([0, 0, -9.80665, 0, 0, 0]), ) -> None: """ Args: @@ -35,7 +33,7 @@ def __init__( model = Model.build(factory=factory, joints_name_list=joints_name_list) self.rbdalgos = RBDAlgorithms(model=model, math=math) self.NDoF = self.rbdalgos.NDoF - self.g = gravity + self.g = gravity.to(torch.get_default_dtype()) if root_link is not None: warnings.warn( "The root_link argument is not used. The root link is automatically chosen as the link with no parent in the URDF", diff --git a/src/adam/pytorch/torch_like.py b/src/adam/pytorch/torch_like.py index de2cb812..a08e5cae 100644 --- a/src/adam/pytorch/torch_like.py +++ b/src/adam/pytorch/torch_like.py @@ -17,9 +17,9 @@ class TorchLike(ArrayLike): array: torch.Tensor def __post_init__(self): - """Converts array to double precision""" - if self.array.dtype != torch.float64: - self.array = self.array.double() + """Converts array to the desired type""" + if self.array.dtype != torch.get_default_dtype(): + self.array = self.array.to(torch.get_default_dtype()) def __setitem__(self, idx, value: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike": """Overrides set item operator"""