diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 29a9a28744..33ee380b62 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -397,4 +397,21 @@ function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTr return C end +LinearAlgebra.transpose(a::AnyTracedRArray) = error("transpose not defined for $(typeof(a)).") + +function LinearAlgebra.transpose!(B::AnyTracedRVector, A::AnyTracedRMatrix) + LinearAlgebra.check_transpose_axes((size(B,1), size(B,2)), size(A)) + set_mlir_data!(B, get_mlir_data(Ops.reshape(A, length(B)))) +end + +function LinearAlgebra.transpose!(B::AnyTracedRMatrix, A::AnyTracedRVector) + LinearAlgebra.check_transpose_axes(size(B), (size(A, 1), size(A, 2))) + set_mlir_data!(B, get_mlir_data(Ops.broadcast_in_dim(A, [2], [1, length(A)]))) +end + +function LinearAlgebra.transpose!(B::AnyTracedRMatrix, A::AnyTracedRMatrix) + LinearAlgebra.check_transpose_axes(size(B), size(A)) + set_mlir_data!(B, get_mlir_data(Ops.transpose(A, [2,1]))) +end + end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index cd804d150e..82303b2db9 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -183,3 +183,29 @@ end end end end + +@testset "transpose!" begin + v = zeros(5) + M = rand(1, 5) + v_ra = Reactant.to_rarray(v) + M_ra = Reactant.to_rarray(M) + + @jit transpose!(v_ra, M_ra) + @test v_ra ≈ transpose!(v, M) + + v = rand(7) + M = zeros(1, 7) + v_ra = Reactant.to_rarray(v) + M_ra = Reactant.to_rarray(M) + + @jit transpose!(M_ra, v_ra) + @test M_ra ≈ transpose!(M, v) + + A = rand(3, 7) + B = rand(7, 3) + A_ra = Reactant.to_rarray(A) + B_ra = Reactant.to_rarray(B) + @jit transpose!(B_ra, A_ra) + @test B_ra ≈ transpose!(B, A) +end +