@@ -463,7 +463,7 @@ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Ten
463
463
return out [..., 1 :]
464
464
465
465
466
- def axis_angle_to_matrix (axis_angle : torch .Tensor ) -> torch .Tensor :
466
+ def axis_angle_to_matrix (axis_angle : torch .Tensor , fast : bool = False ) -> torch .Tensor :
467
467
"""
468
468
Convert rotations given as axis/angle to rotation matrices.
469
469
@@ -472,27 +472,93 @@ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
472
472
as a tensor of shape (..., 3), where the magnitude is
473
473
the angle turned anticlockwise in radians around the
474
474
vector's direction.
475
+ fast: Whether to use the new faster implementation (based on the
476
+ Rodrigues formula) instead of the original implementation (which
477
+ first converted to a quaternion and then back to a rotation matrix).
475
478
476
479
Returns:
477
480
Rotation matrices as tensor of shape (..., 3, 3).
478
481
"""
479
- return quaternion_to_matrix (axis_angle_to_quaternion (axis_angle ))
482
+ if not fast :
483
+ return quaternion_to_matrix (axis_angle_to_quaternion (axis_angle ))
480
484
485
+ shape = axis_angle .shape
486
+ device , dtype = axis_angle .device , axis_angle .dtype
481
487
482
- def matrix_to_axis_angle (matrix : torch .Tensor ) -> torch .Tensor :
488
+ angles = torch .norm (axis_angle , p = 2 , dim = - 1 , keepdim = True ).unsqueeze (- 1 )
489
+
490
+ rx , ry , rz = axis_angle [..., 0 ], axis_angle [..., 1 ], axis_angle [..., 2 ]
491
+ zeros = torch .zeros (shape [:- 1 ], dtype = dtype , device = device )
492
+ cross_product_matrix = torch .stack (
493
+ [zeros , - rz , ry , rz , zeros , - rx , - ry , rx , zeros ], dim = - 1
494
+ ).view (shape + (3 ,))
495
+ cross_product_matrix_sqrd = cross_product_matrix @ cross_product_matrix
496
+
497
+ identity = torch .eye (3 , dtype = dtype , device = device )
498
+ angles_sqrd = angles * angles
499
+ angles_sqrd = torch .where (angles_sqrd == 0 , 1 , angles_sqrd )
500
+ return (
501
+ identity .expand (cross_product_matrix .shape )
502
+ + torch .sinc (angles / torch .pi ) * cross_product_matrix
503
+ + ((1 - torch .cos (angles )) / angles_sqrd ) * cross_product_matrix_sqrd
504
+ )
505
+
506
+
507
+ def matrix_to_axis_angle (matrix : torch .Tensor , fast : bool = False ) -> torch .Tensor :
483
508
"""
484
509
Convert rotations given as rotation matrices to axis/angle.
485
510
486
511
Args:
487
512
matrix: Rotation matrices as tensor of shape (..., 3, 3).
513
+ fast: Whether to use the new faster implementation (based on the
514
+ Rodrigues formula) instead of the original implementation (which
515
+ first converted to a quaternion and then back to a rotation matrix).
488
516
489
517
Returns:
490
518
Rotations given as a vector in axis angle form, as a tensor
491
519
of shape (..., 3), where the magnitude is the angle
492
520
turned anticlockwise in radians around the vector's
493
521
direction.
522
+
494
523
"""
495
- return quaternion_to_axis_angle (matrix_to_quaternion (matrix ))
524
+ if not fast :
525
+ return quaternion_to_axis_angle (matrix_to_quaternion (matrix ))
526
+
527
+ if matrix .size (- 1 ) != 3 or matrix .size (- 2 ) != 3 :
528
+ raise ValueError (f"Invalid rotation matrix shape { matrix .shape } ." )
529
+
530
+ omegas = torch .stack (
531
+ [
532
+ matrix [..., 2 , 1 ] - matrix [..., 1 , 2 ],
533
+ matrix [..., 0 , 2 ] - matrix [..., 2 , 0 ],
534
+ matrix [..., 1 , 0 ] - matrix [..., 0 , 1 ],
535
+ ],
536
+ dim = - 1 ,
537
+ )
538
+ norms = torch .norm (omegas , p = 2 , dim = - 1 , keepdim = True )
539
+ traces = torch .diagonal (matrix , dim1 = - 2 , dim2 = - 1 ).sum (- 1 ).unsqueeze (- 1 )
540
+ angles = torch .atan2 (norms , traces - 1 )
541
+
542
+ zeros = torch .zeros (3 , dtype = matrix .dtype , device = matrix .device )
543
+ omegas = torch .where (torch .isclose (angles , torch .zeros_like (angles )), zeros , omegas )
544
+
545
+ near_pi = torch .isclose (((traces - 1 ) / 2 ).abs (), torch .ones_like (traces )).squeeze (
546
+ - 1
547
+ )
548
+
549
+ axis_angles = torch .empty_like (omegas )
550
+ axis_angles [~ near_pi ] = (
551
+ 0.5 * omegas [~ near_pi ] / torch .sinc (angles [~ near_pi ] / torch .pi )
552
+ )
553
+
554
+ # this derives from: nnT = (R + 1) / 2
555
+ n = 0.5 * (
556
+ matrix [near_pi ][..., 0 , :]
557
+ + torch .eye (1 , 3 , dtype = matrix .dtype , device = matrix .device )
558
+ )
559
+ axis_angles [near_pi ] = angles [near_pi ] * n / torch .norm (n )
560
+
561
+ return axis_angles
496
562
497
563
498
564
def axis_angle_to_quaternion (axis_angle : torch .Tensor ) -> torch .Tensor :
@@ -509,22 +575,10 @@ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
509
575
quaternions with real part first, as tensor of shape (..., 4).
510
576
"""
511
577
angles = torch .norm (axis_angle , p = 2 , dim = - 1 , keepdim = True )
512
- half_angles = angles * 0.5
513
- eps = 1e-6
514
- small_angles = angles .abs () < eps
515
- sin_half_angles_over_angles = torch .empty_like (angles )
516
- sin_half_angles_over_angles [~ small_angles ] = (
517
- torch .sin (half_angles [~ small_angles ]) / angles [~ small_angles ]
518
- )
519
- # for x small, sin(x/2) is about x/2 - (x/2)^3/6
520
- # so sin(x/2)/x is about 1/2 - (x*x)/48
521
- sin_half_angles_over_angles [small_angles ] = (
522
- 0.5 - (angles [small_angles ] * angles [small_angles ]) / 48
578
+ sin_half_angles_over_angles = 0.5 * torch .sinc (angles * 0.5 / torch .pi )
579
+ return torch .cat (
580
+ [torch .cos (angles * 0.5 ), axis_angle * sin_half_angles_over_angles ], dim = - 1
523
581
)
524
- quaternions = torch .cat (
525
- [torch .cos (half_angles ), axis_angle * sin_half_angles_over_angles ], dim = - 1
526
- )
527
- return quaternions
528
582
529
583
530
584
def quaternion_to_axis_angle (quaternions : torch .Tensor ) -> torch .Tensor :
@@ -543,18 +597,9 @@ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
543
597
"""
544
598
norms = torch .norm (quaternions [..., 1 :], p = 2 , dim = - 1 , keepdim = True )
545
599
half_angles = torch .atan2 (norms , quaternions [..., :1 ])
546
- angles = 2 * half_angles
547
- eps = 1e-6
548
- small_angles = angles .abs () < eps
549
- sin_half_angles_over_angles = torch .empty_like (angles )
550
- sin_half_angles_over_angles [~ small_angles ] = (
551
- torch .sin (half_angles [~ small_angles ]) / angles [~ small_angles ]
552
- )
553
- # for x small, sin(x/2) is about x/2 - (x/2)^3/6
554
- # so sin(x/2)/x is about 1/2 - (x*x)/48
555
- sin_half_angles_over_angles [small_angles ] = (
556
- 0.5 - (angles [small_angles ] * angles [small_angles ]) / 48
557
- )
600
+ sin_half_angles_over_angles = 0.5 * torch .sinc (half_angles / torch .pi )
601
+ # angles/2 are between [-pi/2, pi/2], thus sin_half_angles_over_angles
602
+ # can't be zero
558
603
return quaternions [..., 1 :] / sin_half_angles_over_angles
559
604
560
605
0 commit comments