Skip to content

Commit 7a3c0cb

Browse files
alex-benefacebook-github-bot
authored andcommitted
Increase performance for conversions including axis angles (#1948)
Summary: This is an extension of #1544 with various speed, stability, and readability improvements. (I could not find a way to make a commit to the existing PR). This PR is still based on the [Rodrigues' rotation formula](https://en.wikipedia.org/wiki/Rotation_formalisms_in_three_dimensions#Rotation_matrix_%E2%86%94_Euler_axis/angle). The motivation is the same; this change speeds up the conversions up to 10x, depending on the device, batch size, etc. ### Notes - As the angles get very close to `π`, the existing implementation and the proposed one start to differ. However, (my understanding is that) this is not a problem as the axis can not be stably inferred from the rotation matrix in this case in general. - bottler , I tried to follow similar conventions as existing functions to deal with weird angles, let me know if something needs to be changed to merge this. Pull Request resolved: #1948 Reviewed By: MichaelRamamonjisoa Differential Revision: D69193009 Pulled By: bottler fbshipit-source-id: e5ed34b45b625114ec4419bb89e22a6aefad4eeb
1 parent 215590b commit 7a3c0cb

File tree

2 files changed

+83
-33
lines changed

2 files changed

+83
-33
lines changed

pytorch3d/transforms/rotation_conversions.py

+76-31
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Ten
463463
return out[..., 1:]
464464

465465

466-
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
466+
def axis_angle_to_matrix(axis_angle: torch.Tensor, fast: bool = False) -> torch.Tensor:
467467
"""
468468
Convert rotations given as axis/angle to rotation matrices.
469469
@@ -472,27 +472,93 @@ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
472472
as a tensor of shape (..., 3), where the magnitude is
473473
the angle turned anticlockwise in radians around the
474474
vector's direction.
475+
fast: Whether to use the new faster implementation (based on the
476+
Rodrigues formula) instead of the original implementation (which
477+
first converted to a quaternion and then back to a rotation matrix).
475478
476479
Returns:
477480
Rotation matrices as tensor of shape (..., 3, 3).
478481
"""
479-
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
482+
if not fast:
483+
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
480484

485+
shape = axis_angle.shape
486+
device, dtype = axis_angle.device, axis_angle.dtype
481487

482-
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
488+
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True).unsqueeze(-1)
489+
490+
rx, ry, rz = axis_angle[..., 0], axis_angle[..., 1], axis_angle[..., 2]
491+
zeros = torch.zeros(shape[:-1], dtype=dtype, device=device)
492+
cross_product_matrix = torch.stack(
493+
[zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1
494+
).view(shape + (3,))
495+
cross_product_matrix_sqrd = cross_product_matrix @ cross_product_matrix
496+
497+
identity = torch.eye(3, dtype=dtype, device=device)
498+
angles_sqrd = angles * angles
499+
angles_sqrd = torch.where(angles_sqrd == 0, 1, angles_sqrd)
500+
return (
501+
identity.expand(cross_product_matrix.shape)
502+
+ torch.sinc(angles / torch.pi) * cross_product_matrix
503+
+ ((1 - torch.cos(angles)) / angles_sqrd) * cross_product_matrix_sqrd
504+
)
505+
506+
507+
def matrix_to_axis_angle(matrix: torch.Tensor, fast: bool = False) -> torch.Tensor:
483508
"""
484509
Convert rotations given as rotation matrices to axis/angle.
485510
486511
Args:
487512
matrix: Rotation matrices as tensor of shape (..., 3, 3).
513+
fast: Whether to use the new faster implementation (based on the
514+
Rodrigues formula) instead of the original implementation (which
515+
first converted to a quaternion and then back to a rotation matrix).
488516
489517
Returns:
490518
Rotations given as a vector in axis angle form, as a tensor
491519
of shape (..., 3), where the magnitude is the angle
492520
turned anticlockwise in radians around the vector's
493521
direction.
522+
494523
"""
495-
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
524+
if not fast:
525+
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
526+
527+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
528+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
529+
530+
omegas = torch.stack(
531+
[
532+
matrix[..., 2, 1] - matrix[..., 1, 2],
533+
matrix[..., 0, 2] - matrix[..., 2, 0],
534+
matrix[..., 1, 0] - matrix[..., 0, 1],
535+
],
536+
dim=-1,
537+
)
538+
norms = torch.norm(omegas, p=2, dim=-1, keepdim=True)
539+
traces = torch.diagonal(matrix, dim1=-2, dim2=-1).sum(-1).unsqueeze(-1)
540+
angles = torch.atan2(norms, traces - 1)
541+
542+
zeros = torch.zeros(3, dtype=matrix.dtype, device=matrix.device)
543+
omegas = torch.where(torch.isclose(angles, torch.zeros_like(angles)), zeros, omegas)
544+
545+
near_pi = torch.isclose(((traces - 1) / 2).abs(), torch.ones_like(traces)).squeeze(
546+
-1
547+
)
548+
549+
axis_angles = torch.empty_like(omegas)
550+
axis_angles[~near_pi] = (
551+
0.5 * omegas[~near_pi] / torch.sinc(angles[~near_pi] / torch.pi)
552+
)
553+
554+
# this derives from: nnT = (R + 1) / 2
555+
n = 0.5 * (
556+
matrix[near_pi][..., 0, :]
557+
+ torch.eye(1, 3, dtype=matrix.dtype, device=matrix.device)
558+
)
559+
axis_angles[near_pi] = angles[near_pi] * n / torch.norm(n)
560+
561+
return axis_angles
496562

497563

498564
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
@@ -509,22 +575,10 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
509575
quaternions with real part first, as tensor of shape (..., 4).
510576
"""
511577
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
512-
half_angles = angles * 0.5
513-
eps = 1e-6
514-
small_angles = angles.abs() < eps
515-
sin_half_angles_over_angles = torch.empty_like(angles)
516-
sin_half_angles_over_angles[~small_angles] = (
517-
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
518-
)
519-
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
520-
# so sin(x/2)/x is about 1/2 - (x*x)/48
521-
sin_half_angles_over_angles[small_angles] = (
522-
0.5 - (angles[small_angles] * angles[small_angles]) / 48
578+
sin_half_angles_over_angles = 0.5 * torch.sinc(angles * 0.5 / torch.pi)
579+
return torch.cat(
580+
[torch.cos(angles * 0.5), axis_angle * sin_half_angles_over_angles], dim=-1
523581
)
524-
quaternions = torch.cat(
525-
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
526-
)
527-
return quaternions
528582

529583

530584
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
@@ -543,18 +597,9 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
543597
"""
544598
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
545599
half_angles = torch.atan2(norms, quaternions[..., :1])
546-
angles = 2 * half_angles
547-
eps = 1e-6
548-
small_angles = angles.abs() < eps
549-
sin_half_angles_over_angles = torch.empty_like(angles)
550-
sin_half_angles_over_angles[~small_angles] = (
551-
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
552-
)
553-
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
554-
# so sin(x/2)/x is about 1/2 - (x*x)/48
555-
sin_half_angles_over_angles[small_angles] = (
556-
0.5 - (angles[small_angles] * angles[small_angles]) / 48
557-
)
600+
sin_half_angles_over_angles = 0.5 * torch.sinc(half_angles / torch.pi)
601+
# angles/2 are between [-pi/2, pi/2], thus sin_half_angles_over_angles
602+
# can't be zero
558603
return quaternions[..., 1:] / sin_half_angles_over_angles
559604

560605

tests/test_rotation_conversions.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ def test_from_axis_angle(self):
204204
n_repetitions = 20
205205
data = torch.rand(n_repetitions, 3)
206206
matrices = axis_angle_to_matrix(data)
207+
self.assertClose(data, matrix_to_axis_angle(matrices), atol=2e-6)
208+
self.assertClose(data, matrix_to_axis_angle(matrices, fast=True), atol=2e-6)
209+
matrices = axis_angle_to_matrix(data, fast=True)
207210
mdata = matrix_to_axis_angle(matrices)
208211
self.assertClose(data, mdata, atol=2e-6)
209212

@@ -221,8 +224,10 @@ def test_to_axis_angle(self):
221224
"""mtx -> axis_angle -> mtx"""
222225
data = random_rotations(13, dtype=torch.float64)
223226
euler_angles = matrix_to_axis_angle(data)
224-
mdata = axis_angle_to_matrix(euler_angles)
225-
self.assertClose(data, mdata)
227+
euler_angles_fast = matrix_to_axis_angle(data)
228+
self.assertClose(data, axis_angle_to_matrix(euler_angles))
229+
self.assertClose(data, axis_angle_to_matrix(euler_angles_fast))
230+
self.assertClose(data, axis_angle_to_matrix(euler_angles, fast=True))
226231

227232
def test_quaternion_application(self):
228233
"""Applying a quaternion is the same as applying the matrix."""

0 commit comments

Comments
 (0)