From bb3449762e9d37d7e97e5a5a6a69bd9081394303 Mon Sep 17 00:00:00 2001
From: tharittk <tharit.tangkij@gmail.com>
Date: Thu, 27 Feb 2025 21:49:13 +0700
Subject: [PATCH 1/3] dot (2 args) and cross

---
 src/stdlibs/LinearAlgebra.jl       | 31 +++++++++++++++++
 test/integration/linear_algebra.jl | 54 ++++++++++++++++++++++++++++++
 2 files changed, 85 insertions(+)

diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl
index 29a9a28744..b771b74a6c 100644
--- a/src/stdlibs/LinearAlgebra.jl
+++ b/src/stdlibs/LinearAlgebra.jl
@@ -397,4 +397,35 @@ function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTr
     return C
 end
 
+function LinearAlgebra.dot(x::TracedRArray{T}, y::TracedRArray{T}) where {T}
+    lx = length(x)
+    if lx != length(y)
+         throw(DimensionMismatch(lazy"first array has length $(lx) which does not match the length of the second, $(length(y))."))
+    end
+    
+    if T <: Complex
+        return Ops.dot_general(Ops.conj(x), y; contracting_dimensions = [[1], [1]]) 
+    else
+        return Ops.dot_general(x, y; contracting_dimensions = [[1], [1]]) 
+    end
+end
+
+function LinearAlgebra.cross(a::AnyTracedRVector{T}, b::AnyTracedRVector{T}) where{T}
+    if !(length(a) == length(b) == 3)
+         throw(DimensionMismatch("cross product is only defined for vectors of length 3"))
+    end
+    a = materialize_traced_array(a)
+    b = materialize_traced_array(b)
+
+    a1, a2, a3 = a
+    b1, b2, b3 = b
+   
+    c = [a2*b3-a3*b2, a3*b1-a1*b3, a1*b2-a2*b1]
+    c = TracedUtils.promote_to(TracedRArray{T, 1}, c)
+
+    return TracedRNumber{T}((), c.mlir_data) 
+end
+
+
+
 end
diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl
index cd804d150e..8aabadcb25 100644
--- a/test/integration/linear_algebra.jl
+++ b/test/integration/linear_algebra.jl
@@ -183,3 +183,57 @@ end
         end
     end
 end
+
+@testset "dot" begin
+    a = [1, 2, 3, 4]
+    b = [-2, 5, 6, 7]
+    a_ra = Reactant.to_rarray(a)
+    b_ra = Reactant.to_rarray(b)
+
+    @test @jit dot(a_ra, b_ra) ≈ dot(a, b)
+
+    a = rand(4)
+    b = rand(4)
+    a_ra = Reactant.to_rarray(a)
+    b_ra = Reactant.to_rarray(b)
+    
+    @test @jit dot(a_ra, b_ra) ≈ dot(a, b)
+
+    a = rand(Complex{Float64}, 4)
+    b = rand(Complex{Float64}, 4)
+    a_ra = Reactant.to_rarray(a)
+    b_ra = Reactant.to_rarray(b)
+    ab_ra = @jit dot(a_ra, b_ra)
+    ab = dot(a,b)
+    
+    @test ab_ra ≈ ab 
+    
+    # I found this strange. If in dot_generat I did not apply Ops.conj,
+    # calling the line below still gives test passed while the one above does not
+    # @test @jit dot(a_ra, b_ra) ≈ dot(a, b)
+end
+
+
+@testset "cross" begin
+    a = [1, 2, 3]
+    b = [-2, 5, 7]
+    a_ra = Reactant.to_rarray(a)
+    b_ra = Reactant.to_rarray(b)
+
+    @test @jit cross(a_ra, b_ra) ≈ cross(a, b)
+
+    a = rand(3)
+    b = rand(3)
+    a_ra = Reactant.to_rarray(a)
+    b_ra = Reactant.to_rarray(b)
+    
+    @test @jit dot(a_ra, b_ra) ≈ dot(a, b)
+
+    a = rand(Complex{Float64}, 3)
+    b = rand(Complex{Float64}, 3)
+    a_ra = Reactant.to_rarray(a)
+    b_ra = Reactant.to_rarray(b)
+
+    @test @jit cross(a_ra, b_ra) ≈ cross(a,b)
+end    
+

From 8ee2450e2135e4b3900bb8a262600167718164ff Mon Sep 17 00:00:00 2001
From: tharittk <tharit.tangkij@gmail.com>
Date: Thu, 27 Feb 2025 23:32:18 +0700
Subject: [PATCH 2/3] add type promotion and adding adjoint

---
 src/stdlibs/LinearAlgebra.jl       | 43 ++++++++++++++++++++++------
 test/integration/linear_algebra.jl | 45 ++++++++++++++++++++++++------
 2 files changed, 72 insertions(+), 16 deletions(-)

diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl
index b771b74a6c..44886ae95d 100644
--- a/src/stdlibs/LinearAlgebra.jl
+++ b/src/stdlibs/LinearAlgebra.jl
@@ -397,20 +397,20 @@ function LinearAlgebra._kron!(C::AnyTracedRMatrix, A::AnyTracedRMatrix, B::AnyTr
     return C
 end
 
-function LinearAlgebra.dot(x::TracedRArray{T}, y::TracedRArray{T}) where {T}
+function LinearAlgebra.dot(x::TracedRArray{T}, y::TracedRArray) where {T}
     lx = length(x)
     if lx != length(y)
          throw(DimensionMismatch(lazy"first array has length $(lx) which does not match the length of the second, $(length(y))."))
     end
     
     if T <: Complex
-        return Ops.dot_general(Ops.conj(x), y; contracting_dimensions = [[1], [1]]) 
-    else
-        return Ops.dot_general(x, y; contracting_dimensions = [[1], [1]]) 
+        x = Ops.conj(x)
     end
+    
+    return Ops.dot_general(x, y; contracting_dimensions = [[1], [1]]) 
 end
 
-function LinearAlgebra.cross(a::AnyTracedRVector{T}, b::AnyTracedRVector{T}) where{T}
+function LinearAlgebra.cross(a::AnyTracedRVector{T1}, b::AnyTracedRVector{T2}) where {T1, T2}
     if !(length(a) == length(b) == 3)
          throw(DimensionMismatch("cross product is only defined for vectors of length 3"))
     end
@@ -419,13 +419,40 @@ function LinearAlgebra.cross(a::AnyTracedRVector{T}, b::AnyTracedRVector{T}) whe
 
     a1, a2, a3 = a
     b1, b2, b3 = b
-   
     c = [a2*b3-a3*b2, a3*b1-a1*b3, a1*b2-a2*b1]
-    c = TracedUtils.promote_to(TracedRArray{T, 1}, c)
+    
+    T = promote_type(T1, T2)
 
-    return TracedRNumber{T}((), c.mlir_data) 
+    return TracedUtils.promote_to(TracedRArray{T, 1}, c)
 end
 
+function LinearAlgebra.adjoint!(B::AnyTracedRVector{T1}, A::AnyTracedRMatrix{T2}) where {T1, T2}
+    LinearAlgebra.check_transpose_axes((size(B,1), size(B,2)), size(A))
+    T = promote_type(T1, T2)
+    if T <: Complex
+        A = Ops.conj(A)
+    end
+    AT = TracedUtils.promote_to(TracedRArray{T, 2}, A)
+    set_mlir_data!(B, get_mlir_data(Ops.reshape(AT, length(B))))
+end
 
+function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRVector{T2}) where {T1, T2}
+    LinearAlgebra.check_transpose_axes(size(B), (size(A, 1), size(A, 2)))
+    T = promote_type(T1, T2)
+    if T <: Complex
+        A = Ops.conj(A)
+    end
+    set_mlir_data!(B, get_mlir_data(Ops.broadcast_in_dim(A, [2], [1, length(A)])))
+end
+
+function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRMatrix{T2}) where {T1, T2}
+    LinearAlgebra.check_transpose_axes(size(B), size(A))
+    T = promote_type(T1, T2)
+    if T <: Complex
+        A = Ops.conj(A)
+    end
+    AT = TracedUtils.promote_to(TracedRArray{T, 2}, Ops.transpose(A, [2,1]))
+    set_mlir_data!(B, get_mlir_data(AT))
+end
 
 end
diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl
index 8aabadcb25..f49f1114d5 100644
--- a/test/integration/linear_algebra.jl
+++ b/test/integration/linear_algebra.jl
@@ -192,7 +192,7 @@ end
 
     @test @jit dot(a_ra, b_ra) ≈ dot(a, b)
 
-    a = rand(4)
+    a = rand(Int64, 4)
     b = rand(4)
     a_ra = Reactant.to_rarray(a)
     b_ra = Reactant.to_rarray(b)
@@ -207,13 +207,8 @@ end
     ab = dot(a,b)
     
     @test ab_ra ≈ ab 
-    
-    # I found this strange. If in dot_generat I did not apply Ops.conj,
-    # calling the line below still gives test passed while the one above does not
-    # @test @jit dot(a_ra, b_ra) ≈ dot(a, b)
 end
 
-
 @testset "cross" begin
     a = [1, 2, 3]
     b = [-2, 5, 7]
@@ -222,7 +217,7 @@ end
 
     @test @jit cross(a_ra, b_ra) ≈ cross(a, b)
 
-    a = rand(3)
+    a = rand(Int64, 3)
     b = rand(3)
     a_ra = Reactant.to_rarray(a)
     b_ra = Reactant.to_rarray(b)
@@ -235,5 +230,39 @@ end
     b_ra = Reactant.to_rarray(b)
 
     @test @jit cross(a_ra, b_ra) ≈ cross(a,b)
-end    
+end  
+
+@testset "adjoint!" begin
+    v = zeros(5)
+    M = rand(1, 5)
+    v_ra = Reactant.to_rarray(v)
+    M_ra = Reactant.to_rarray(M)
+
+    @jit adjoint!(v_ra, M_ra)
+    @test v_ra ≈ adjoint!(v, M)
+
+    v = rand(7)
+    M = zeros(1, 7)
+    v_ra = Reactant.to_rarray(v)
+    M_ra = Reactant.to_rarray(M)
+
+    @jit adjoint!(M_ra, v_ra)
+    @test M_ra ≈ adjoint!(M, v)
+
+    A = [1 2; 3 4; 5 6]
+    B = fill(Float64(0), (2,3))
+    A_ra = Reactant.to_rarray(A)
+    B_ra = Reactant.to_rarray(B)
+    
+    @jit adjoint!(B_ra, A_ra)
+    @test B_ra ≈ adjoint!(B, A)
+
+    A = rand(Complex{Float64}, (2, 3))
+    B = rand(Complex{Float64}, (3, 2))
+    A_ra = Reactant.to_rarray(A)
+    B_ra = Reactant.to_rarray(B)
+    
+    @jit adjoint!(B_ra, A_ra)
+    @test B_ra ≈ adjoint!(B, A)
 
+end
\ No newline at end of file

From 24af2a7e66d401d01f95e5d3410845042fc7e4ae Mon Sep 17 00:00:00 2001
From: Avik Pal <avik.pal.2017@gmail.com>
Date: Thu, 27 Feb 2025 18:01:21 -0500
Subject: [PATCH 3/3] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com>
---
 src/stdlibs/LinearAlgebra.jl | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl
index b828ffc2f6..445105c1c9 100644
--- a/src/stdlibs/LinearAlgebra.jl
+++ b/src/stdlibs/LinearAlgebra.jl
@@ -434,6 +434,7 @@ function LinearAlgebra.adjoint!(B::AnyTracedRVector{T1}, A::AnyTracedRMatrix{T2}
     end
     AT = TracedUtils.promote_to(TracedRArray{T, 2}, A)
     set_mlir_data!(B, get_mlir_data(Ops.reshape(AT, length(B))))
+    return B
 end
 
 function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRVector{T2}) where {T1, T2}
@@ -443,6 +444,7 @@ function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRVector{T2}
         A = Ops.conj(A)
     end
     set_mlir_data!(B, get_mlir_data(Ops.broadcast_in_dim(A, [2], [1, length(A)])))
+    return B
 end
 
 function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRMatrix{T2}) where {T1, T2}
@@ -453,6 +455,7 @@ function LinearAlgebra.adjoint!(B::AnyTracedRMatrix{T1}, A::AnyTracedRMatrix{T2}
     end
     AT = TracedUtils.promote_to(TracedRArray{T, 2}, Ops.transpose(A, [2,1]))
     set_mlir_data!(B, get_mlir_data(AT))
+    return B
 end
                   
 function LinearAlgebra.axpy!(α::Number, x::TracedRArray{T}, y::TracedRArray{T}) where {T}