|
404 | 404 | dC = rocBLAS.trsm('L', 'U', 'N', 'N', one(T), dA, dB) |
405 | 405 | @test collect(dC) ≈ triu(A) \ B |
406 | 406 | end |
| 407 | + @testset "trsm strided" begin |
| 408 | + A = rand(T, 2m, 2m) |
| 409 | + dA = view(ROCArray(A),1:m,1:m) |
| 410 | + B = rand(T, 2m, 2m) |
| 411 | + dB = view(ROCArray(B),1:m,1:m) |
| 412 | + dC = rocBLAS.trsm('L', 'U', 'N', 'N', one(T), dA, dB) |
| 413 | + @test collect(dC) ≈ triu(view(A,1:m,1:m)) \ view(B,1:m,1:m) |
| 414 | + end |
| 415 | + |
407 | 416 | @testset "trsm_batched" begin |
408 | 417 | batch_count = 3 |
409 | 418 | A = [rand(T, m, m) for ix in 1:batch_count] |
|
441 | 450 | @test collect(dC) ≈ triu(A) * B |
442 | 451 | end |
443 | 452 |
|
| 453 | + @testset "trmm strided" begin |
| 454 | + A = rand(T, 2m, 2m) |
| 455 | + dA = view(ROCArray(A),1:m,1:m) |
| 456 | + B = rand(T, 2m, 2m) |
| 457 | + dB = view(ROCArray(B),1:m,1:m) |
| 458 | + dC = rocBLAS.trmm('L', 'U', 'N', 'N', one(T), dA, dB) |
| 459 | + @test collect(dC) ≈ triu(view(A,1:m,1:m)) * view(B,1:m,1:m) |
| 460 | + end |
| 461 | + |
444 | 462 | @testset "triangular-triangular mul" for (TRa, ta, TRb, tb) in ( |
445 | 463 | (UpperTriangular, identity, LowerTriangular, identity), |
446 | 464 | (LowerTriangular, identity, UpperTriangular, identity), |
|
578 | 596 | h_C = triu(h_C) |
579 | 597 | @test C ≈ h_C |
580 | 598 | end |
| 599 | + @testset "syrk strided T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64) |
| 600 | + # generate parameters |
| 601 | + α = rand(T) |
| 602 | + β = rand(T) |
| 603 | + A = rand(T, 2m, 2m) |
| 604 | + Abad = rand(T, 2m + 1, 2m + 1) |
| 605 | + C = rand(T, 2m, 2m) |
| 606 | + # move to device |
| 607 | + d_A, d_Abad = ROCArray(A), ROCArray(Abad) |
| 608 | + C = C + transpose(C) |
| 609 | + d_C = ROCArray(C) |
| 610 | + A_view = view(A,1:m,1:m) |
| 611 | + C = α*(A_view*transpose(A_view)) + β*view(C,1:m,1:m) |
| 612 | + rocBLAS.syrk!('U','N',α,view(d_A,1:m,1:m),β,view(d_C,1:m,1:m)) |
| 613 | + # move back to host and compare |
| 614 | + C = triu(C) |
| 615 | + h_C = Array(view(d_C,1:m,1:m)) |
| 616 | + h_C = triu(h_C) |
| 617 | + @test C ≈ h_C |
| 618 | + end |
581 | 619 | @testset "syr2k T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64) |
582 | 620 | # generate parameters |
583 | 621 | α = rand(T) |
|
0 commit comments