Skip to content

Commit b2f01e0

Browse files
committed
Switch to test_transform
1 parent cdff90d commit b2f01e0

File tree

2 files changed

+25
-72
lines changed

2 files changed

+25
-72
lines changed

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@ pocl_jll = "7.0"
2828
ParallelTestRunner = "1.0.1"
2929

3030
[sources]
31-
ParallelTestRunner = {url="https://github.com/JuliaTesting/ParallelTestRunner.jl", rev="vc/custom_record"}
31+
ParallelTestRunner = {url="https://github.com/JuliaTesting/ParallelTestRunner.jl", rev="vc/test_transform"}

test/runtests.jl

Lines changed: 24 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -8,96 +8,45 @@ import Test
88
## --platform selector
99
do_platform, platform_filter = ParallelTestRunner.extract_flag!(ARGS, "--platform", nothing)
1010

11-
custom_record_init = quote
12-
import ParallelTestRunner: Test
13-
struct OpenCLTestRecord <: ParallelTestRunner.AbstractTestRecord
14-
# TODO: Would it be better to wrap "ParallelTestRunner.TestRecord "
15-
value::Any # AbstractTestSet or TestSetException
16-
output::String # captured stdout/stderr
17-
18-
# stats
19-
time::Float64
20-
bytes::UInt64
21-
gctime::Float64
22-
rss::UInt64
23-
end
24-
function ParallelTestRunner.memory_usage(rec::OpenCLTestRecord)
25-
return rec.rss
26-
end
27-
function ParallelTestRunner.test_IOContext(::Type{OpenCLTestRecord}, stdout::IO, stderr::IO, lock::ReentrantLock, name_align::Int64)
28-
return ParallelTestRunner.test_IOContext(ParallelTestRunner.TestRecord, stdout, stderr, lock, name_align)
29-
end
30-
31-
const targets = []
32-
using OpenCL, IOCapture
33-
34-
function ParallelTestRunner.execute(::Type{OpenCLTestRecord}, mod, f, name, color, (; platform_filter))
11+
test_transform = function(test, expr)
12+
# targets is a global variable that is defined in init_code
13+
return quote
3514
if isempty(targets)
3615
for platform in cl.platforms(),
3716
device in cl.devices(platform)
38-
if platform_filter !== nothing
17+
if $(platform_filter) !== nothing
3918
# filter on the name or vendor
4019
names = lowercase.([platform.name, platform.vendor])
41-
if !any(contains(platform_filter), names)
20+
if !any(contains($(platform_filter)), names)
4221
continue
4322
end
4423
end
4524
push!(targets, (; platform, device))
4625
end
4726
if isempty(targets)
48-
if platform_filter === nothing
27+
if $(platform_filter) === nothing
4928
throw(ArgumentError("No OpenCL platforms found"))
5029
else
51-
throw(ArgumentError("No OpenCL platforms found matching $platform_filter"))
30+
throw(ArgumentError("No OpenCL platforms found matching $($(platform_filter))"))
5231
end
5332
end
5433
end
5534

5635
# some tests require native execution capabilities
57-
requires_il = name in ["atomics", "execution", "intrinsics", "kernelabstractions"] ||
58-
startswith(name, "gpuarrays/")
59-
60-
data = @eval mod begin
61-
GC.gc(true)
62-
Random.seed!(1)
63-
OpenCL.allowscalar(false)
64-
65-
mktemp() do path, io
66-
stats = redirect_stdio(stdout=io, stderr=io) do
67-
@timed try
68-
@testset $(Expr(:$, :name)) begin
69-
@testset "\$(device.name)" for (; platform, device) in $(Expr(:$, :targets))
70-
cl.platform!(platform)
71-
cl.device!(device)
72-
73-
if !$(Expr(:$, :requires_il)) || "cl_khr_il_program" in device.extensions
74-
$(Expr(:$, :f))
75-
end
76-
end
77-
end
78-
catch err
79-
isa(err, Test.TestSetException) || rethrow()
80-
81-
# return the error to package it into a TestRecord
82-
err
83-
end
84-
end
85-
close(io)
86-
output = read(path, String)
87-
(; testset=stats.value, output, stats.time, stats.bytes, stats.gctime)
36+
requires_il = $(test) in ["atomics", "execution", "intrinsics", "kernelabstractions"] ||
37+
startswith($(test), "gpuarrays/")
8838

39+
@testset "\$(device.name)" for (; platform, device) in targets
40+
cl.platform!(platform)
41+
cl.device!(device)
42+
43+
if !requires_il || "cl_khr_il_program" in device.extensions
44+
$(expr)
8945
end
9046
end
91-
92-
# process results
93-
rss = Sys.maxrss()
94-
record = OpenCLTestRecord(data..., rss)
95-
96-
GC.gc(true)
97-
return record
9847
end
99-
end # quote
100-
eval(custom_record_init)
48+
end
49+
10150

10251
# register custom tests that do not correspond to files in the test directory
10352
custom_tests = Dict{String, Expr}()
@@ -116,7 +65,8 @@ const GPUArraysTestSuite = let
11665
end
11766

11867
for name in keys(GPUArraysTestSuite.tests)
119-
custom_tests["GPUArraysTestSuite/$name"] = :(GPUArraysTestSuite.tests[$name](CLArray))
68+
test = "GPUArraysTestSuite/$name"
69+
custom_tests[test] = test_transform(test, :(GPUArraysTestSuite.tests[$name](CLArray)))
12070
end
12171

12272
function test_filter(test)
@@ -131,6 +81,9 @@ end
13181
const init_code = quote
13282
using OpenCL, pocl_jll
13383

84+
OpenCL.allowscalar(false)
85+
const targets = []
86+
13487
# GPUArrays has a testsuite that isn't part of the main package.
13588
# Include it directly.
13689
const GPUArraysTestSuite = let
@@ -186,5 +139,5 @@ const init_code = quote
186139
end
187140
end
188141

189-
runtests(OpenCL, ARGS; custom_tests, test_filter, init_code, custom_record_init,
190-
RecordType=OpenCLTestRecord, custom_args=(;platform_filter))
142+
143+
runtests(OpenCL, ARGS; custom_tests, test_filter, init_code, test_transform)

0 commit comments

Comments
 (0)