From bb3449762e9d37d7e97e5a5a6a69bd9081394303 Mon Sep 17 00:00:00 2001 From: tharittk <tharit.tangkij@gmail.com> Date: Thu, 27 Feb 2025 21:49:13 +0700 Subject: [PATCH 1/3] dot (2 args) and cross --- src/stdlibs/LinearAlgebra.jl | 31 +++++++++++++++++ test/integration/linear_algebra.jl | 54 ++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 29a9a28744..b771b74a6c 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -397,4 +397,35 @@ function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTr return C end +function LinearAlgebra.dot(x::TracedRArray{T}, y::TracedRArray{T}) where {T} + lx = length(x) + if lx != length(y) + throw(DimensionMismatch(lazy"first array has length $(lx) which does not match the length of the second, $(length(y)).")) + end + + if T <: Complex + return Ops.dot_general(Ops.conj(x), y; contracting_dimensions = [[1], [1]]) + else + return Ops.dot_general(x, y; contracting_dimensions = [[1], [1]]) + end +end + +function LinearAlgebra.cross(a::AnyTracedRVector{T}, b::AnyTracedRVector{T}) where{T} + if !(length(a) == length(b) == 3) + throw(DimensionMismatch("cross product is only defined for vectors of length 3")) + end + a = materialize_traced_array(a) + b = materialize_traced_array(b) + + a1, a2, a3 = a + b1, b2, b3 = b + + c = [a2*b3-a3*b2, a3*b1-a1*b3, a1*b2-a2*b1] + c = TracedUtils.promote_to(TracedRArray{T, 1}, c) + + return TracedRNumber{T}((), c.mlir_data) +end + + + end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index cd804d150e..8aabadcb25 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -183,3 +183,57 @@ end end end end + +@testset "dot" begin + a = [1, 2, 3, 4] + b = [-2, 5, 6, 7] + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + + @test @jit dot(a_ra, b_ra) ≈ dot(a, b) + + a = rand(4) + b = rand(4) + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + + @test @jit dot(a_ra, b_ra) ≈ dot(a, b) + + a = rand(Complex{Float64}, 4) + b = rand(Complex{Float64}, 4) + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + ab_ra = @jit dot(a_ra, b_ra) + ab = dot(a,b) + + @test ab_ra ≈ ab + + # I found this strange. If in dot_generat I did not apply Ops.conj, + # calling the line below still gives test passed while the one above does not + # @test @jit dot(a_ra, b_ra) ≈ dot(a, b) +end + + +@testset "cross" begin + a = [1, 2, 3] + b = [-2, 5, 7] + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + + @test @jit cross(a_ra, b_ra) ≈ cross(a, b) + + a = rand(3) + b = rand(3) + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + + @test @jit dot(a_ra, b_ra) ≈ dot(a, b) + + a = rand(Complex{Float64}, 3) + b = rand(Complex{Float64}, 3) + a_ra = Reactant.to_rarray(a) + b_ra = Reactant.to_rarray(b) + + @test @jit cross(a_ra, b_ra) ≈ cross(a,b) +end + From 8ee2450e2135e4b3900bb8a262600167718164ff Mon Sep 17 00:00:00 2001 From: tharittk <tharit.tangkij@gmail.com> Date: Thu, 27 Feb 2025 23:32:18 +0700 Subject: [PATCH 2/3] add type promotion and adding adjoint --- src/stdlibs/LinearAlgebra.jl | 43 ++++++++++++++++++++++------ test/integration/linear_algebra.jl | 45 ++++++++++++++++++++++++------ 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index b771b74a6c..44886ae95d 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -397,20 +397,20 @@ function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTr return C end -function LinearAlgebra.dot(x::TracedRArray{T}, y::TracedRArray{T}) where {T} +function LinearAlgebra.dot(x::TracedRArray{T}, y::TracedRArray) where {T} lx = length(x) if lx != length(y) throw(DimensionMismatch(lazy"first array has length $(lx) which does not match the length of the second, $(length(y)).")) end if T <: Complex - return Ops.dot_general(Ops.conj(x), y; contracting_dimensions = [[1], [1]]) - else - return Ops.dot_general(x, y; contracting_dimensions = [[1], [1]]) + x = Ops.conj(x) end + + return Ops.dot_general(x, y; contracting_dimensions = [[1], [1]]) end -function LinearAlgebra.cross(a::AnyTracedRVector{T}, b::AnyTracedRVector{T}) where{T} +function LinearAlgebra.cross(a::AnyTracedRVector{T1}, b::AnyTracedRVector{T2}) where {T1, T2} if !(length(a) == length(b) == 3) throw(DimensionMismatch("cross product is only defined for vectors of length 3")) end @@ -419,13 +419,40 @@ function LinearAlgebra.cross(a::AnyTracedRVector{T}, b::AnyTracedRVector{T}) whe a1, a2, a3 = a b1, b2, b3 = b - c = [a2*b3-a3*b2, a3*b1-a1*b3, a1*b2-a2*b1] - c = TracedUtils.promote_to(TracedRArray{T, 1}, c) + + T = promote_type(T1, T2) - return TracedRNumber{T}((), c.mlir_data) + return TracedUtils.promote_to(TracedRArray{T, 1}, c) end +function LinearAlgebra.adjoint!(B::AnyTracedRVector{T1}, A::AnyTracedRMatrix{T2}) where {T1, T2} + LinearAlgebra.check_transpose_axes((size(B,1), size(B,2)), size(A)) + T = promote_type(T1, T2) + if T <: Complex + A = Ops.conj(A) + end + AT = TracedUtils.promote_to(TracedRArray{T, 2}, A) + set_mlir_data!(B, get_mlir_data(Ops.reshape(AT, length(B)))) +end +function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRVector{T2}) where {T1, T2} + LinearAlgebra.check_transpose_axes(size(B), (size(A, 1), size(A, 2))) + T = promote_type(T1, T2) + if T <: Complex + A = Ops.conj(A) + end + set_mlir_data!(B, get_mlir_data(Ops.broadcast_in_dim(A, [2], [1, length(A)]))) +end + +function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRMatrix{T2}) where {T1, T2} + LinearAlgebra.check_transpose_axes(size(B), size(A)) + T = promote_type(T1, T2) + if T <: Complex + A = Ops.conj(A) + end + AT = TracedUtils.promote_to(TracedRArray{T, 2}, Ops.transpose(A, [2,1])) + set_mlir_data!(B, get_mlir_data(AT)) +end end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 8aabadcb25..f49f1114d5 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -192,7 +192,7 @@ end @test @jit dot(a_ra, b_ra) ≈ dot(a, b) - a = rand(4) + a = rand(Int64, 4) b = rand(4) a_ra = Reactant.to_rarray(a) b_ra = Reactant.to_rarray(b) @@ -207,13 +207,8 @@ end ab = dot(a,b) @test ab_ra ≈ ab - - # I found this strange. If in dot_generat I did not apply Ops.conj, - # calling the line below still gives test passed while the one above does not - # @test @jit dot(a_ra, b_ra) ≈ dot(a, b) end - @testset "cross" begin a = [1, 2, 3] b = [-2, 5, 7] @@ -222,7 +217,7 @@ end @test @jit cross(a_ra, b_ra) ≈ cross(a, b) - a = rand(3) + a = rand(Int64, 3) b = rand(3) a_ra = Reactant.to_rarray(a) b_ra = Reactant.to_rarray(b) @@ -235,5 +230,39 @@ end b_ra = Reactant.to_rarray(b) @test @jit cross(a_ra, b_ra) ≈ cross(a,b) -end +end + +@testset "adjoint!" begin + v = zeros(5) + M = rand(1, 5) + v_ra = Reactant.to_rarray(v) + M_ra = Reactant.to_rarray(M) + + @jit adjoint!(v_ra, M_ra) + @test v_ra ≈ adjoint!(v, M) + + v = rand(7) + M = zeros(1, 7) + v_ra = Reactant.to_rarray(v) + M_ra = Reactant.to_rarray(M) + + @jit adjoint!(M_ra, v_ra) + @test M_ra ≈ adjoint!(M, v) + + A = [1 2; 3 4; 5 6] + B = fill(Float64(0), (2,3)) + A_ra = Reactant.to_rarray(A) + B_ra = Reactant.to_rarray(B) + + @jit adjoint!(B_ra, A_ra) + @test B_ra ≈ adjoint!(B, A) + + A = rand(Complex{Float64}, (2, 3)) + B = rand(Complex{Float64}, (3, 2)) + A_ra = Reactant.to_rarray(A) + B_ra = Reactant.to_rarray(B) + + @jit adjoint!(B_ra, A_ra) + @test B_ra ≈ adjoint!(B, A) +end \ No newline at end of file From 24af2a7e66d401d01f95e5d3410845042fc7e4ae Mon Sep 17 00:00:00 2001 From: Avik Pal <avik.pal.2017@gmail.com> Date: Thu, 27 Feb 2025 18:01:21 -0500 Subject: [PATCH 3/3] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- src/stdlibs/LinearAlgebra.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index b828ffc2f6..445105c1c9 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -434,6 +434,7 @@ function LinearAlgebra.adjoint!(B::AnyTracedRVector{T1}, A::AnyTracedRMatrix{T2} end AT = TracedUtils.promote_to(TracedRArray{T, 2}, A) set_mlir_data!(B, get_mlir_data(Ops.reshape(AT, length(B)))) + return B end function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRVector{T2}) where {T1, T2} @@ -443,6 +444,7 @@ function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRVector{T2} A = Ops.conj(A) end set_mlir_data!(B, get_mlir_data(Ops.broadcast_in_dim(A, [2], [1, length(A)]))) + return B end function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRMatrix{T2}) where {T1, T2} @@ -453,6 +455,7 @@ function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRMatrix{T2} end AT = TracedUtils.promote_to(TracedRArray{T, 2}, Ops.transpose(A, [2,1])) set_mlir_data!(B, get_mlir_data(AT)) + return B end function LinearAlgebra.axpy!(α::Number, x::TracedRArray{T}, y::TracedRArray{T}) where {T}