diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 7ff8bcbebd..445105c1c9 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -397,6 +397,67 @@ function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTr return C end +function LinearAlgebra.dot(x::TracedRArray{T}, y::TracedRArray) where {T} + lx = length(x) + if lx != length(y) + throw(DimensionMismatch(lazy"first array has length $(lx) which does not match the length of the second, $(length(y)).")) + end + + if T <: Complex + x = Ops.conj(x) + end + + return Ops.dot_general(x, y; contracting_dimensions = [[1], [1]]) +end + +function LinearAlgebra.cross(a::AnyTracedRVector{T1}, b::AnyTracedRVector{T2}) where {T1, T2} + if !(length(a) == length(b) == 3) + throw(DimensionMismatch("cross product is only defined for vectors of length 3")) + end + a = materialize_traced_array(a) + b = materialize_traced_array(b) + + a1, a2, a3 = a + b1, b2, b3 = b + c = [a2*b3-a3*b2, a3*b1-a1*b3, a1*b2-a2*b1] + + T = promote_type(T1, T2) + + return TracedUtils.promote_to(TracedRArray{T, 1}, c) +end + +function LinearAlgebra.adjoint!(B::AnyTracedRVector{T1}, A::AnyTracedRMatrix{T2}) where {T1, T2} + LinearAlgebra.check_transpose_axes((size(B,1), size(B,2)), size(A)) + T = promote_type(T1, T2) + if T <: Complex + A = Ops.conj(A) + end + AT = TracedUtils.promote_to(TracedRArray{T, 2}, A) + set_mlir_data!(B, get_mlir_data(Ops.reshape(AT, length(B)))) + return B +end + +function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRVector{T2}) where {T1, T2} + LinearAlgebra.check_transpose_axes(size(B), (size(A, 1), size(A, 2))) + T = promote_type(T1, T2) + if T <: Complex + A = Ops.conj(A) + end + set_mlir_data!(B, get_mlir_data(Ops.broadcast_in_dim(A, [2], [1, length(A)]))) + return B +end + +function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRMatrix{T2}) where {T1, T2} + LinearAlgebra.check_transpose_axes(size(B), size(A)) + T = promote_type(T1, T2) + if T <: Complex + A = Ops.conj(A) + end + AT = TracedUtils.promote_to(TracedRArray{T, 2}, Ops.transpose(A, [2,1])) + set_mlir_data!(B, get_mlir_data(AT)) + return B +end + function LinearAlgebra.axpy!(α::Number, x::TracedRArray{T}, y::TracedRArray{T}) where {T} if length(x) != length(y) throw(DimensionMismatch(lazy"x has length $(length(x)), but y has length $(length(y))")) diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 114b352524..516c317a66 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -184,6 +184,88 @@ end end end +@testset "dot" begin + a = [1, 2, 3, 4] + b = [-2, 5, 6, 7] + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + + @test @jit dot(a_ra, b_ra) ≈ dot(a, b) + + a = rand(Int64, 4) + b = rand(4) + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + + @test @jit dot(a_ra, b_ra) ≈ dot(a, b) + + a = rand(Complex{Float64}, 4) + b = rand(Complex{Float64}, 4) + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + ab_ra = @jit dot(a_ra, b_ra) + ab = dot(a,b) + + @test ab_ra ≈ ab +end + +@testset "cross" begin + a = [1, 2, 3] + b = [-2, 5, 7] + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + + @test @jit cross(a_ra, b_ra) ≈ cross(a, b) + + a = rand(Int64, 3) + b = rand(3) + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + + @test @jit dot(a_ra, b_ra) ≈ dot(a, b) + + a = rand(Complex{Float64}, 3) + b = rand(Complex{Float64}, 3) + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + + @test @jit cross(a_ra, b_ra) ≈ cross(a,b) +end + +@testset "adjoint!" begin + v = zeros(5) + M = rand(1, 5) + v_ra = Reactant.to_rarray(v) + M_ra = Reactant.to_rarray(M) + + @jit adjoint!(v_ra, M_ra) + @test v_ra ≈ adjoint!(v, M) + + v = rand(7) + M = zeros(1, 7) + v_ra = Reactant.to_rarray(v) + M_ra = Reactant.to_rarray(M) + + @jit adjoint!(M_ra, v_ra) + @test M_ra ≈ adjoint!(M, v) + + A = [1 2; 3 4; 5 6] + B = fill(Float64(0), (2,3)) + A_ra = Reactant.to_rarray(A) + B_ra = Reactant.to_rarray(B) + + @jit adjoint!(B_ra, A_ra) + @test B_ra ≈ adjoint!(B, A) + + A = rand(Complex{Float64}, (2, 3)) + B = rand(Complex{Float64}, (3, 2)) + A_ra = Reactant.to_rarray(A) + B_ra = Reactant.to_rarray(B) + + @jit adjoint!(B_ra, A_ra) + @test B_ra ≈ adjoint!(B, A) +end + @testset "axpy!" begin α = 3 x = rand(Int64, 4) @@ -220,7 +302,6 @@ end @jit axpy!(α, x_ra, y_ra) @test y_ra ≈ axpy!(α, x, y) - end @testset "axpby!" begin @@ -262,8 +343,4 @@ end @jit axpby!(α, x_ra, β, y_ra) @test y_ra ≈ axpby!(α, x, β, y) - end - - -