diff --git a/src/adam/parametric/pytorch/computations_parametric.py b/src/adam/parametric/pytorch/computations_parametric.py index daa66d6c..f9694028 100644 --- a/src/adam/parametric/pytorch/computations_parametric.py +++ b/src/adam/parametric/pytorch/computations_parametric.py @@ -23,9 +23,7 @@ def __init__( joints_name_list: list, links_name_list: list, root_link: str = None, - gravity: np.array = torch.tensor( - [0, 0, -9.80665, 0, 0, 0], dtype=torch.float64 - ), + gravity: np.array = torch.tensor([0, 0, -9.80665, 0, 0, 0]), ) -> None: """ Args: diff --git a/src/adam/pytorch/computations.py b/src/adam/pytorch/computations.py index 3f39d495..6e753fcd 100644 --- a/src/adam/pytorch/computations.py +++ b/src/adam/pytorch/computations.py @@ -20,9 +20,7 @@ def __init__( urdfstring: str, joints_name_list: list = None, root_link: str = None, - gravity: np.array = torch.tensor( - [0, 0, -9.80665, 0, 0, 0], dtype=torch.float64 - ), + gravity: np.array = torch.tensor([0, 0, -9.80665, 0, 0, 0]), ) -> None: """ Args: diff --git a/src/adam/pytorch/torch_like.py b/src/adam/pytorch/torch_like.py index de2cb812..342bf0d8 100644 --- a/src/adam/pytorch/torch_like.py +++ b/src/adam/pytorch/torch_like.py @@ -17,9 +17,9 @@ class TorchLike(ArrayLike): array: torch.Tensor def __post_init__(self): - """Converts array to double precision""" - if self.array.dtype != torch.float64: - self.array = self.array.double() + """Converts array to the default type used in the library""" + if self.array.dtype != torch.get_default_dtype(): + self.array = self.array.to(torch.get_default_dtype()) def __setitem__(self, idx, value: Union["TorchLike", ntp.ArrayLike]) -> "TorchLike": """Overrides set item operator"""