diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 7f5b2c6e..c024d717 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -63,6 +63,78 @@ def svdvals(x: Array) -> Array: vector_norm = get_xp(da)(_linalg.vector_norm) diagonal = get_xp(da)(_linalg.diagonal) +# Calculate determinant via PLU decomp +def det(x: Array) -> Array: + import scipy.linalg + + # L has det 1 so don't need to worry about it + p, _, u = da.linalg.lu(x) + + # TODO: numerical stability? + u_det = da.prod(da.diag(u)) + + # Now, time to calculate determinant of p + + # (from reading the source code) + # We know that dask lu decomp forces square chunks + # We also know that the P matrix will only be non-zero + # for a block i, j if and only if i = j + + # So we will calculate the determinant of each block on + # the diagonal (of blocks) + + # This isn't ideal, but hopefully still lets out of core work + # properly since each block should be able to fit in memory + + blocks_shape = p.blocks.shape + n_row_blocks = blocks_shape[0] + + p_det = 1 + for i in range(n_row_blocks): + p_det *= scipy.linalg.det(p.blocks[i, i].compute()) + return p_det * u_det + +SlogdetResult = _linalg.SlogdetResult + +# Calculate determinant via PLU decomp +def slogdet(x: Array) -> Array: + import scipy.linalg + + # L has det 1 so don't need to worry about it + p, _, u = da.linalg.lu(x) + + u_diag = da.diag(u) + neg_cnt = (u_diag < 0).sum() + + u_logabsdet = da.sum(da.log(da.abs(u_diag))) + + # Now, time to calculate determinant of p + + # (from reading the source code) + # We know that dask lu decomp forces square chunks + # We also know that the P matrix will only be non-zero + # for a block i, j if and only if i = j + + # So we will calculate the determinant of each block on + # the diagonal (of blocks) + + # This isn't ideal, but hopefully still lets out of core work + # properly since each block should be able to fit in memory + + blocks_shape = p.blocks.shape + n_row_blocks = blocks_shape[0] + + sign = 1 + for i in range(n_row_blocks): + sign *= scipy.linalg.det(p.blocks[i, i].compute()) + + if neg_cnt % 2 != 0: + sign *= -1 + return SlogdetResult(sign, u_logabsdet) + + + + __all__ = linalg_all + ["trace", "outer", "matmul", "tensordot", "matrix_transpose", "vecdot", "EighResult", "QRResult", "SlogdetResult", "SVDResult", "qr", diff --git a/dask-xfails.txt b/dask-xfails.txt index 0d74ecbb..4c90f221 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -80,6 +80,9 @@ array_api_tests/test_linalg.py::test_cholesky array_api_tests/test_linalg.py::test_tensordot # probably same reason for failing as numpy array_api_tests/test_linalg.py::test_trace +# our version depends on dask's LU, which doesn't support ndim > 2 +array_api_tests/test_linalg.py::test_det +array_api_tests/test_linalg.py::test_slogdet # AssertionError: out.dtype=uint64, but should be uint8 [tensordot(uint8, uint8)] array_api_tests/test_linalg.py::test_linalg_tensordot @@ -97,18 +100,14 @@ array_api_tests/test_linalg.py::test_linalg_matmul # Linalg - these don't exist in dask array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cross] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.det] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigh] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.eigvalsh] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_power] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv] -array_api_tests/test_signatures.py::test_extension_func_signature[linalg.slogdet] array_api_tests/test_linalg.py::test_cross -array_api_tests/test_linalg.py::test_det array_api_tests/test_linalg.py::test_eigh array_api_tests/test_linalg.py::test_eigvalsh array_api_tests/test_linalg.py::test_pinv -array_api_tests/test_linalg.py::test_slogdet array_api_tests/test_has_names.py::test_has_names[linalg-cross] array_api_tests/test_has_names.py::test_has_names[linalg-det] array_api_tests/test_has_names.py::test_has_names[linalg-eigh]