@@ -463,7 +463,7 @@ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Ten
463463 return out [..., 1 :]
464464
465465
466- def axis_angle_to_matrix (axis_angle : torch .Tensor ) -> torch .Tensor :
466+ def axis_angle_to_matrix (axis_angle : torch .Tensor , fast : bool = False ) -> torch .Tensor :
467467 """
468468 Convert rotations given as axis/angle to rotation matrices.
469469
@@ -472,27 +472,93 @@ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
472472 as a tensor of shape (..., 3), where the magnitude is
473473 the angle turned anticlockwise in radians around the
474474 vector's direction.
475+ fast: Whether to use the new faster implementation (based on the
476+ Rodrigues formula) instead of the original implementation (which
477+ first converted to a quaternion and then back to a rotation matrix).
475478
476479 Returns:
477480 Rotation matrices as tensor of shape (..., 3, 3).
478481 """
479- return quaternion_to_matrix (axis_angle_to_quaternion (axis_angle ))
482+ if not fast :
483+ return quaternion_to_matrix (axis_angle_to_quaternion (axis_angle ))
480484
485+ shape = axis_angle .shape
486+ device , dtype = axis_angle .device , axis_angle .dtype
481487
482- def matrix_to_axis_angle (matrix : torch .Tensor ) -> torch .Tensor :
488+ angles = torch .norm (axis_angle , p = 2 , dim = - 1 , keepdim = True ).unsqueeze (- 1 )
489+
490+ rx , ry , rz = axis_angle [..., 0 ], axis_angle [..., 1 ], axis_angle [..., 2 ]
491+ zeros = torch .zeros (shape [:- 1 ], dtype = dtype , device = device )
492+ cross_product_matrix = torch .stack (
493+ [zeros , - rz , ry , rz , zeros , - rx , - ry , rx , zeros ], dim = - 1
494+ ).view (shape + (3 ,))
495+ cross_product_matrix_sqrd = cross_product_matrix @ cross_product_matrix
496+
497+ identity = torch .eye (3 , dtype = dtype , device = device )
498+ angles_sqrd = angles * angles
499+ angles_sqrd = torch .where (angles_sqrd == 0 , 1 , angles_sqrd )
500+ return (
501+ identity .expand (cross_product_matrix .shape )
502+ + torch .sinc (angles / torch .pi ) * cross_product_matrix
503+ + ((1 - torch .cos (angles )) / angles_sqrd ) * cross_product_matrix_sqrd
504+ )
505+
506+
507+ def matrix_to_axis_angle (matrix : torch .Tensor , fast : bool = False ) -> torch .Tensor :
483508 """
484509 Convert rotations given as rotation matrices to axis/angle.
485510
486511 Args:
487512 matrix: Rotation matrices as tensor of shape (..., 3, 3).
513+ fast: Whether to use the new faster implementation (based on the
514+ Rodrigues formula) instead of the original implementation (which
515+ first converted to a quaternion and then back to a rotation matrix).
488516
489517 Returns:
490518 Rotations given as a vector in axis angle form, as a tensor
491519 of shape (..., 3), where the magnitude is the angle
492520 turned anticlockwise in radians around the vector's
493521 direction.
522+
494523 """
495- return quaternion_to_axis_angle (matrix_to_quaternion (matrix ))
524+ if not fast :
525+ return quaternion_to_axis_angle (matrix_to_quaternion (matrix ))
526+
527+ if matrix .size (- 1 ) != 3 or matrix .size (- 2 ) != 3 :
528+ raise ValueError (f"Invalid rotation matrix shape { matrix .shape } ." )
529+
530+ omegas = torch .stack (
531+ [
532+ matrix [..., 2 , 1 ] - matrix [..., 1 , 2 ],
533+ matrix [..., 0 , 2 ] - matrix [..., 2 , 0 ],
534+ matrix [..., 1 , 0 ] - matrix [..., 0 , 1 ],
535+ ],
536+ dim = - 1 ,
537+ )
538+ norms = torch .norm (omegas , p = 2 , dim = - 1 , keepdim = True )
539+ traces = torch .diagonal (matrix , dim1 = - 2 , dim2 = - 1 ).sum (- 1 ).unsqueeze (- 1 )
540+ angles = torch .atan2 (norms , traces - 1 )
541+
542+ zeros = torch .zeros (3 , dtype = matrix .dtype , device = matrix .device )
543+ omegas = torch .where (torch .isclose (angles , torch .zeros_like (angles )), zeros , omegas )
544+
545+ near_pi = torch .isclose (((traces - 1 ) / 2 ).abs (), torch .ones_like (traces )).squeeze (
546+ - 1
547+ )
548+
549+ axis_angles = torch .empty_like (omegas )
550+ axis_angles [~ near_pi ] = (
551+ 0.5 * omegas [~ near_pi ] / torch .sinc (angles [~ near_pi ] / torch .pi )
552+ )
553+
554+ # this derives from: nnT = (R + 1) / 2
555+ n = 0.5 * (
556+ matrix [near_pi ][..., 0 , :]
557+ + torch .eye (1 , 3 , dtype = matrix .dtype , device = matrix .device )
558+ )
559+ axis_angles [near_pi ] = angles [near_pi ] * n / torch .norm (n )
560+
561+ return axis_angles
496562
497563
498564def axis_angle_to_quaternion (axis_angle : torch .Tensor ) -> torch .Tensor :
@@ -509,22 +575,10 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
509575 quaternions with real part first, as tensor of shape (..., 4).
510576 """
511577 angles = torch .norm (axis_angle , p = 2 , dim = - 1 , keepdim = True )
512- half_angles = angles * 0.5
513- eps = 1e-6
514- small_angles = angles .abs () < eps
515- sin_half_angles_over_angles = torch .empty_like (angles )
516- sin_half_angles_over_angles [~ small_angles ] = (
517- torch .sin (half_angles [~ small_angles ]) / angles [~ small_angles ]
518- )
519- # for x small, sin(x/2) is about x/2 - (x/2)^3/6
520- # so sin(x/2)/x is about 1/2 - (x*x)/48
521- sin_half_angles_over_angles [small_angles ] = (
522- 0.5 - (angles [small_angles ] * angles [small_angles ]) / 48
578+ sin_half_angles_over_angles = 0.5 * torch .sinc (angles * 0.5 / torch .pi )
579+ return torch .cat (
580+ [torch .cos (angles * 0.5 ), axis_angle * sin_half_angles_over_angles ], dim = - 1
523581 )
524- quaternions = torch .cat (
525- [torch .cos (half_angles ), axis_angle * sin_half_angles_over_angles ], dim = - 1
526- )
527- return quaternions
528582
529583
530584def quaternion_to_axis_angle (quaternions : torch .Tensor ) -> torch .Tensor :
@@ -543,18 +597,9 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
543597 """
544598 norms = torch .norm (quaternions [..., 1 :], p = 2 , dim = - 1 , keepdim = True )
545599 half_angles = torch .atan2 (norms , quaternions [..., :1 ])
546- angles = 2 * half_angles
547- eps = 1e-6
548- small_angles = angles .abs () < eps
549- sin_half_angles_over_angles = torch .empty_like (angles )
550- sin_half_angles_over_angles [~ small_angles ] = (
551- torch .sin (half_angles [~ small_angles ]) / angles [~ small_angles ]
552- )
553- # for x small, sin(x/2) is about x/2 - (x/2)^3/6
554- # so sin(x/2)/x is about 1/2 - (x*x)/48
555- sin_half_angles_over_angles [small_angles ] = (
556- 0.5 - (angles [small_angles ] * angles [small_angles ]) / 48
557- )
600+ sin_half_angles_over_angles = 0.5 * torch .sinc (half_angles / torch .pi )
601+ # angles/2 are between [-pi/2, pi/2], thus sin_half_angles_over_angles
602+ # can't be zero
558603 return quaternions [..., 1 :] / sin_half_angles_over_angles
559604
560605
0 commit comments