Skip to content

Commit eae9751

Browse files
peterbell10 authored and pytorchmergebot committed
Fix linalg_eigvals invalid use of composite dispatch key (pytorch#121142)
`linalg_eigvals_out` calls into a dispatch stub, so it only supports CPU and CUDA strided tensors, but it incorrectly claimed to be a composite op. `linalg_eigvals` also shouldn't defer to the out variant inside a `CompositeImplicitAutograd` op, as not all types support out variants. Instead, I add a new helper `_linalg_eigvals` which does the same thing in a non-composite operator. Pull Request resolved: pytorch#121142 Approved by: https://github.com/lezcano
1 parent 393b4ab commit eae9751

7 files changed

Lines changed: 41 additions & 9 deletions

File tree

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
#include <ATen/ops/_linalg_eigh.h>
2929
#include <ATen/ops/_linalg_eigh_meta.h>
3030
#include <ATen/ops/_linalg_eigh_native.h>
31+
#include <ATen/ops/_linalg_eigvals.h>
32+
#include <ATen/ops/_linalg_eigvals_native.h>
3133
#include <ATen/ops/_linalg_solve_ex.h>
3234
#include <ATen/ops/_linalg_solve_ex_meta.h>
3335
#include <ATen/ops/_linalg_solve_ex_native.h>
@@ -3100,12 +3102,13 @@ Tensor linalg_eigvals(const Tensor& input) {
31003102
if (_may_require_fw_or_bw_grad(input)) {
31013103
return std::get<0>(at::linalg_eig(input));
31023104
}
3105+
return at::_linalg_eigvals(input);
3106+
}
31033107

3108+
Tensor _linalg_eigvals(const Tensor& input) {
31043109
ScalarType complex_dtype = toComplexType(input.scalar_type());
31053110
Tensor values = at::empty({0}, input.options().dtype(complex_dtype));
3106-
3107-
at::linalg_eigvals_outf(input, values);
3108-
3111+
linalg_eigvals_out(input, values);
31093112
return values;
31103113
}
31113114

aten/src/ATen/native/native_functions.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13812,11 +13812,18 @@
1381213812
dispatch:
1381313813
CPU, CUDA: linalg_eig_out
1381413814

13815+
- func: _linalg_eigvals(Tensor self) -> Tensor
13816+
python_module: linalg
13817+
dispatch:
13818+
CPU, CUDA: _linalg_eigvals
13819+
1381513820
- func: linalg_eigvals(Tensor self) -> Tensor
1381613821
python_module: linalg
1381713822

1381813823
- func: linalg_eigvals.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
1381913824
python_module: linalg
13825+
dispatch:
13826+
CPU, CUDA: linalg_eigvals_out
1382013827

1382113828
# This function is exposes the `compute_v` flag, which is then used to implement `linalg.eigh` and
1382213829
# `linalg.eigvalsh` as composite functions that call this one

test/expect/HasDecompTest.test_has_decomposition.expect

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ aten::_linalg_det
387387
aten::_linalg_det.result
388388
aten::_linalg_eigh
389389
aten::_linalg_eigh.eigenvalues
390+
aten::_linalg_eigvals
390391
aten::_linalg_slogdet
391392
aten::_linalg_slogdet.sign
392393
aten::_linalg_solve_ex
@@ -844,6 +845,7 @@ aten::linalg_cholesky_ex
844845
aten::linalg_cholesky_ex.L
845846
aten::linalg_eig
846847
aten::linalg_eig.out
848+
aten::linalg_eigvals.out
847849
aten::linalg_householder_product
848850
aten::linalg_householder_product.out
849851
aten::linalg_inv_ex

test/functorch/test_aotdispatch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4469,7 +4469,6 @@ def forward(self, x):
44694469
xfail('combinations', ''), # aten.masked_select.default
44704470
xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
44714471
xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
4472-
xfail('linalg.eigvals', ''), # aten.linalg_eig.default - couldn't find symbolic meta function/decomposition
44734472
xfail('linalg.lstsq', ''), # aten.linalg_lstsq.default - couldn't find symbolic meta function/decomposition
44744473
xfail('linalg.lstsq', 'grad_oriented'), # aten.linalg_lstsq.default - couldn't find symbolic meta funct...
44754474
xfail('linalg.lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/deco...

test/test_meta.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -651,8 +651,6 @@ def run_meta_crossref(
651651
torch.kthvalue : {f64, i32, i64, u8, i16, f16, bf16, i8, f32},
652652
torch.nn.functional.ctc_loss : {f64, f32},
653653
torch.nn.functional.gaussian_nll_loss : {f16, f64, bf16, f32},
654-
torch.linalg.eig : {f64, f32, c128, c64},
655-
torch.linalg.eigvals : {f64, f32, c128, c64},
656654
torch.linalg.lstsq : {f64, f32, c128, c64},
657655
}
658656

@@ -800,7 +798,6 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
800798
meta_dispatch_expected_failures = {
801799
aten.allclose.default: {f16, bf16, f32, f64, c64, c128}, # NotImplementedError: 'aten::_local_scalar_dense'
802800
aten.geqrf.default : {c64, c128, f64, f32},
803-
aten.linalg_eig.default : {c64, c128, f64, f32},
804801
aten.linalg_lstsq.default : {c64, c128, f64, f32},
805802
aten.masked_select.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
806803
aten.masked_select.out : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},

test/test_proxy_tensor.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1875,8 +1875,6 @@ def f(t):
18751875
}
18761876

18771877
symbolic_tensor_failures = {
1878-
xfail('linalg.eig'),
1879-
xfail('linalg.eigvals'),
18801878
xfail('combinations', ''),
18811879
xfail('geqrf', ''), # aten.geqrf.default - couldn't find symbolic meta function/decomposition
18821880
xfail('histc', ''), # Could not run 'aten::histc' with arguments from the 'Meta' backend. This could be because...

torch/_meta_registrations.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,32 @@ def meta__linalg_eigh(
788788
return vals, vecs
789789

790790

791+
@register_meta([aten._linalg_eigvals.default, aten.linalg_eigvals.out])
792+
@out_wrapper()
793+
def meta__linalg_eigvals(input: Tensor) -> Tensor:
794+
squareCheckInputs(input, "linalg.eigvals")
795+
complex_dtype = (
796+
input.dtype
797+
if utils.is_complex_dtype(input.dtype)
798+
else utils.corresponding_complex_dtype(input.dtype)
799+
)
800+
return input.new_empty(input.shape[:-1], dtype=complex_dtype)
801+
802+
803+
@register_meta([aten.linalg_eig])
804+
@out_wrapper("eigenvalues", "eigenvectors")
805+
def meta_linalg_eig(input: Tensor):
806+
squareCheckInputs(input, "linalg.eig")
807+
complex_dtype = (
808+
input.dtype
809+
if utils.is_complex_dtype(input.dtype)
810+
else utils.corresponding_complex_dtype(input.dtype)
811+
)
812+
values = input.new_empty(input.shape[:-1], dtype=complex_dtype)
813+
vectors = input.new_empty(input.shape, dtype=complex_dtype)
814+
return values, vectors
815+
816+
791817
def cloneBatchedColumnMajor(src: Tensor) -> Tensor:
792818
return src.mT.clone(memory_format=torch.contiguous_format).transpose(-2, -1)
793819

0 commit comments

Comments
 (0)