Skip to content

Commit 87de2f9

Browse files
test for strided syrk, trmm, trsm
Test for strided syrk, trmm,trsm
2 parents dd2332b + 3cdd8d1 commit 87de2f9

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

test/rocarray/blas.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,15 @@ end
404404
dC = rocBLAS.trsm('L', 'U', 'N', 'N', one(T), dA, dB)
405405
@test collect(dC) triu(A) \ B
406406
end
407+
@testset "trsm strided" begin
408+
A = rand(T, 2m, 2m)
409+
dA = view(ROCArray(A),1:m,1:m)
410+
B = rand(T, 2m, 2m)
411+
dB = view(ROCArray(B),1:m,1:m)
412+
dC = rocBLAS.trsm('L', 'U', 'N', 'N', one(T), dA, dB)
413+
@test collect(dC) triu(view(A,1:m,1:m)) \ view(B,1:m,1:m)
414+
end
415+
407416
@testset "trsm_batched" begin
408417
batch_count = 3
409418
A = [rand(T, m, m) for ix in 1:batch_count]
@@ -441,6 +450,15 @@ end
441450
@test collect(dC) triu(A) * B
442451
end
443452

453+
@testset "trmm strided" begin
454+
A = rand(T, 2m, 2m)
455+
dA = view(ROCArray(A),1:m,1:m)
456+
B = rand(T, 2m, 2m)
457+
dB = view(ROCArray(B),1:m,1:m)
458+
dC = rocBLAS.trmm('L', 'U', 'N', 'N', one(T), dA, dB)
459+
@test collect(dC) triu(view(A,1:m,1:m)) * view(B,1:m,1:m)
460+
end
461+
444462
@testset "triangular-triangular mul" for (TRa, ta, TRb, tb) in (
445463
(UpperTriangular, identity, LowerTriangular, identity),
446464
(LowerTriangular, identity, UpperTriangular, identity),
@@ -578,6 +596,26 @@ end
578596
h_C = triu(h_C)
579597
@test C h_C
580598
end
599+
@testset "syrk strided T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
600+
# generate parameters
601+
α = rand(T)
602+
β = rand(T)
603+
A = rand(T, 2m, 2m)
604+
Abad = rand(T, 2m + 1, 2m + 1)
605+
C = rand(T, 2m, 2m)
606+
# move to device
607+
d_A, d_Abad = ROCArray(A), ROCArray(Abad)
608+
C = C + transpose(C)
609+
d_C = ROCArray(C)
610+
A_view = view(A,1:m,1:m)
611+
C = α*(A_view*transpose(A_view)) + β*view(C,1:m,1:m)
612+
rocBLAS.syrk!('U','N',α,view(d_A,1:m,1:m),β,view(d_C,1:m,1:m))
613+
# move back to host and compare
614+
C = triu(C)
615+
h_C = Array(view(d_C,1:m,1:m))
616+
h_C = triu(h_C)
617+
@test C h_C
618+
end
581619
@testset "syr2k T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
582620
# generate parameters
583621
α = rand(T)

0 commit comments

Comments
 (0)