diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index 0208e112..3b61a2e9 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -542,9 +542,7 @@ def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool = False) -> torch.Tens zeros = torch.zeros(3, dtype=matrix.dtype, device=matrix.device) omegas = torch.where(torch.isclose(angles, torch.zeros_like(angles)), zeros, omegas) - near_pi = torch.isclose(((traces - 1) / 2).abs(), torch.ones_like(traces)).squeeze( - -1 - ) + near_pi = angles.isclose(angles.new_tensor(torch.pi)).squeeze(-1) axis_angles = torch.empty_like(omegas) axis_angles[~near_pi] = (