Skip to content

Commit eecd77a

Browse files
authored
Merge pull request #826
Coverage for Strided Arrays for syrk, trsm, trmm and potrf
2 parents 15824e6 + d66023e commit eecd77a

File tree

4 files changed

+60
-7
lines changed

4 files changed

+60
-7
lines changed

src/blas/wrappers.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ for (fname, elty) in ((:rocblas_dsyrk,:Float64),
814814
@eval begin
815815
function syrk!(
816816
uplo::Char, trans::Char, alpha::($elty),
817-
A::ROCVecOrMat{$elty}, beta::($elty), C::ROCMatrix{$elty},
817+
A::StridedROCVecOrMat{$elty}, beta::($elty), C::StridedROCMatrix{$elty},
818818
)
819819
mC, n = size(C)
820820
if mC != n throw(DimensionMismatch("C must be square")) end
@@ -979,7 +979,7 @@ for (mmname, smname, elty) in
979979
@eval begin
980980
function trmm!(
981981
side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),
982-
A::ROCMatrix{$elty}, B::ROCMatrix{$elty}, C::ROCMatrix{$elty},
982+
A::StridedROCMatrix{$elty}, B::StridedROCMatrix{$elty}, C::StridedROCMatrix{$elty},
983983
)
984984
m, n = size(B)
985985
mA, nA = size(A)
@@ -997,13 +997,13 @@ for (mmname, smname, elty) in
997997
end
998998
function trmm(
999999
side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),
1000-
A::ROCMatrix{$elty}, B::ROCMatrix{$elty},
1000+
A::StridedROCMatrix{$elty}, B::StridedROCMatrix{$elty},
10011001
)
10021002
trmm!(side, uplo, transa, diag, alpha, A, B, similar(B))
10031003
end
10041004
function trsm!(
10051005
side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),
1006-
A::ROCMatrix{$elty}, B::ROCMatrix{$elty},
1006+
A::StridedROCMatrix{$elty}, B::StridedROCMatrix{$elty},
10071007
)
10081008
m, n = size(B)
10091009
mA, nA = size(A)
@@ -1018,7 +1018,7 @@ for (mmname, smname, elty) in
10181018
end
10191019
function trsm(
10201020
side::Char, uplo::Char, transa::Char, diag::Char, alpha::($elty),
1021-
A::ROCMatrix{$elty}, B::ROCMatrix{$elty},
1021+
A::StridedROCMatrix{$elty}, B::StridedROCMatrix{$elty},
10221022
)
10231023
trsm!(side, uplo, transa, diag, alpha, A, copy(B))
10241024
end

src/solver/highlevel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ for (fname, elty) in (
55
(:rocsolver_zpotrf, :ComplexF64),
66
)
77
@eval begin
8-
function potrf!(uplo::Char, A::ROCMatrix{$elty})
8+
function potrf!(uplo::Char, A::StridedROCMatrix{$elty})
99
chkuplo(uplo)
1010
n = checksquare(A)
1111
lda = max(1, stride(A, 2))
@@ -630,7 +630,7 @@ end
630630

631631
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
632632
@eval begin
633-
LinearAlgebra.LAPACK.potrf!(uplo::Char, A::ROCMatrix{$elty}) = rocSOLVER.potrf!(uplo, A)
633+
LinearAlgebra.LAPACK.potrf!(uplo::Char, A::StridedROCMatrix{$elty}) = rocSOLVER.potrf!(uplo, A)
634634
LinearAlgebra.LAPACK.potrs!(uplo::Char, A::ROCMatrix{$elty}, B::ROCVecOrMat{$elty}) = rocSOLVER.potrs!(uplo, A, B)
635635
LinearAlgebra.LAPACK.sytrf!(uplo::Char, A::ROCMatrix{$elty}) = rocSOLVER.sytrf!(uplo, A)
636636
LinearAlgebra.LAPACK.sytrf!(uplo::Char, A::ROCMatrix{$elty}, ipiv::ROCVector{Cint}) = rocSOLVER.sytrf!(uplo, A, ipiv)

test/rocarray/blas.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,15 @@ end
404404
dC = rocBLAS.trsm('L', 'U', 'N', 'N', one(T), dA, dB)
405405
@test collect(dC) triu(A) \ B
406406
end
407+
@testset "trsm strided" begin
408+
A = rand(T, 2m, 2m)
409+
dA = view(ROCArray(A),1:m,1:m)
410+
B = rand(T, 2m, 2m)
411+
dB = view(ROCArray(B),1:m,1:m)
412+
dC = rocBLAS.trsm('L', 'U', 'N', 'N', one(T), dA, dB)
413+
@test collect(dC) triu(view(A,1:m,1:m)) \ view(B,1:m,1:m)
414+
end
415+
407416
@testset "trsm_batched" begin
408417
batch_count = 3
409418
A = [rand(T, m, m) for ix in 1:batch_count]
@@ -441,6 +450,15 @@ end
441450
@test collect(dC) triu(A) * B
442451
end
443452

453+
@testset "trmm strided" begin
454+
A = rand(T, 2m, 2m)
455+
dA = view(ROCArray(A),1:m,1:m)
456+
B = rand(T, 2m, 2m)
457+
dB = view(ROCArray(B),1:m,1:m)
458+
dC = rocBLAS.trmm('L', 'U', 'N', 'N', one(T), dA, dB)
459+
@test collect(dC) triu(view(A,1:m,1:m)) * view(B,1:m,1:m)
460+
end
461+
444462
@testset "triangular-triangular mul" for (TRa, ta, TRb, tb) in (
445463
(UpperTriangular, identity, LowerTriangular, identity),
446464
(LowerTriangular, identity, UpperTriangular, identity),
@@ -578,6 +596,26 @@ end
578596
h_C = triu(h_C)
579597
@test C h_C
580598
end
599+
@testset "syrk strided T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
600+
# generate parameters
601+
α = rand(T)
602+
β = rand(T)
603+
A = rand(T, 2m, 2m)
604+
Abad = rand(T, 2m + 1, 2m + 1)
605+
C = rand(T, 2m, 2m)
606+
# move to device
607+
d_A, d_Abad = ROCArray(A), ROCArray(Abad)
608+
C = C + transpose(C)
609+
d_C = ROCArray(C)
610+
A_view = view(A,1:m,1:m)
611+
C = α*(A_view*transpose(A_view)) + β*view(C,1:m,1:m)
612+
rocBLAS.syrk!('U','N',α,view(d_A,1:m,1:m),β,view(d_C,1:m,1:m))
613+
# move back to host and compare
614+
C = triu(C)
615+
h_C = Array(view(d_C,1:m,1:m))
616+
h_C = triu(h_C)
617+
@test C h_C
618+
end
581619
@testset "syr2k T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
582620
# generate parameters
583621
α = rand(T)

test/rocarray/solver.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,21 @@ end
183183
LAPACK.potrs!('U',A,B)
184184
@test B collect(d_B)
185185
end
186+
@testset "elty = $elty strided" for elty in [Float32, Float64, ComplexF32, ComplexF64]
187+
A = rand(elty,n*2,n*2)
188+
A = A*A' + I
189+
B = rand(elty,n,p)
190+
d_A = view(ROCArray(A),1:n,1:n)
191+
d_B = ROCArray(B)
192+
193+
LAPACK.potrf!('L',d_A)
194+
LAPACK.potrs!('U',copy(d_A),d_B)
195+
LAPACK.potrf!('L',view(A,1:n,1:n))
196+
LAPACK.potrs!('U',copy(view(A,1:n,1:n)),B)
197+
@test B collect(d_B)
198+
end
199+
200+
186201
end
187202

188203
@testset "sytrf!" begin

0 commit comments

Comments
 (0)