@@ -15,6 +15,10 @@ using Adapt
15
15
import KernelAbstractions
16
16
import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config
17
17
18
+ @static if isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.10
19
+ import KernelAbstractions: POCL
20
+ end
21
+
18
22
19
23
#
20
24
# Device functionality
@@ -40,30 +44,30 @@ Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[])
40
44
# # executed on-device
41
45
42
46
# array type
47
+ @static if ! isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.9
48
+ struct JLDeviceArray{T, N} <: AbstractDeviceArray{T, N}
49
+ data:: Vector{UInt8}
50
+ offset:: Int
51
+ dims:: Dims{N}
52
+ end
43
53
44
- struct JLDeviceArray{T, N} <: AbstractDeviceArray{T, N}
45
- data:: Vector{UInt8}
46
- offset:: Int
47
- dims:: Dims{N}
48
- end
54
+ Base. elsize (:: Type{<:JLDeviceArray{T}} ) where {T} = sizeof (T)
49
55
50
- Base. elsize (:: Type{<:JLDeviceArray{T}} ) where {T} = sizeof (T)
56
+ Base. size (x:: JLDeviceArray ) = x. dims
57
+ Base. sizeof (x:: JLDeviceArray ) = Base. elsize (x) * length (x)
51
58
52
- Base. size ( x:: JLDeviceArray ) = x . dims
53
- Base . sizeof (x :: JLDeviceArray ) = Base. elsize (x) * length (x)
59
+ Base. unsafe_convert ( :: Type{Ptr{T}} , x:: JLDeviceArray{T} ) where {T} =
60
+ convert (Ptr{T}, pointer (x . data)) + x . offset * Base. elsize (x)
54
61
55
- Base. unsafe_convert (:: Type{Ptr{T}} , x:: JLDeviceArray{T} ) where {T} =
56
- convert (Ptr{T}, pointer (x. data)) + x. offset* Base. elsize (x)
62
+ # conversion of untyped data to a typed Array
63
+ function typed_data (x:: JLDeviceArray{T} ) where {T}
64
+ unsafe_wrap (Array, pointer (x), x. dims)
65
+ end
57
66
58
- # conversion of untyped data to a typed Array
59
- function typed_data (x:: JLDeviceArray{T} ) where {T}
60
- unsafe_wrap (Array, pointer (x), x. dims)
67
+ @inline Base. getindex (A:: JLDeviceArray , index:: Integer ) = getindex (typed_data (A), index)
68
+ @inline Base. setindex! (A:: JLDeviceArray , x, index:: Integer ) = setindex! (typed_data (A), x, index)
61
69
end
62
70
63
- @inline Base. getindex (A:: JLDeviceArray , index:: Integer ) = getindex (typed_data (A), index)
64
- @inline Base. setindex! (A:: JLDeviceArray , x, index:: Integer ) = setindex! (typed_data (A), x, index)
65
-
66
-
67
71
#
68
72
# Host abstractions
69
73
#
@@ -236,7 +240,7 @@ Base.convert(::Type{T}, x::T) where T <: JLArray = x
236
240
237
241
# # broadcast
238
242
239
- using Base. Broadcast: BroadcastStyle, Broadcasted
243
+ import Base. Broadcast: BroadcastStyle, Broadcasted
240
244
241
245
struct JLArrayStyle{N} <: AbstractGPUArrayStyle{N} end
242
246
JLArrayStyle {M} (:: Val{N} ) where {N,M} = JLArrayStyle {N} ()
335
339
336
340
# # GPUArrays interfaces
337
341
338
- Adapt. adapt_storage (:: Adaptor , x:: JLArray{T,N} ) where {T,N} =
339
- JLDeviceArray {T,N} (x. data[], x. offset, x. dims)
342
+ @static if ! isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.9
343
+ Adapt. adapt_storage (:: Adaptor , x:: JLArray{T,N} ) where {T,N} =
344
+ JLDeviceArray {T,N} (x. data[], x. offset, x. dims)
345
+ else
346
+ function Adapt. adapt_storage (:: Adaptor , x:: JLArray{T,N} ) where {T,N}
347
+ arr = typed_data (x)
348
+ Adapt. adapt_storage (POCL. KernelAdaptor ([pointer (arr)]), arr)
349
+ end
350
+ end
340
351
341
352
function GPUArrays. mapreducedim! (f, op, R:: AnyJLArray , A:: Union{AbstractArray,Broadcast.Broadcasted} ;
342
353
init= nothing )
@@ -377,10 +388,18 @@ KernelAbstractions.allocate(::JLBackend, ::Type{T}, dims::Tuple) where T = JLArr
377
388
return ndrange, workgroupsize, iterspace, dynamic
378
389
end
379
390
380
- KernelAbstractions. isgpu (b:: JLBackend ) = false
391
+ @static if isdefined (JLArrays. KernelAbstractions, :isgpu ) # KA v0.9
392
+ KernelAbstractions. isgpu (b:: JLBackend ) = false
393
+ end
381
394
382
- function convert_to_cpu (obj:: Kernel{JLBackend, W, N, F} ) where {W, N, F}
383
- return Kernel {typeof(KernelAbstractions.CPU(; static = obj.backend.static)), W, N, F} (KernelAbstractions. CPU (; static = obj. backend. static), obj. f)
395
+ @static if ! isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.9
396
+ function convert_to_cpu (obj:: Kernel{JLBackend, W, N, F} ) where {W, N, F}
397
+ return Kernel {typeof(KernelAbstractions.CPU(; static = obj.backend.static)), W, N, F} (KernelAbstractions. CPU (; static = obj. backend. static), obj. f)
398
+ end
399
+ else
400
+ function convert_to_cpu (obj:: Kernel{JLBackend, W, N, F} ) where {W, N, F}
401
+ return Kernel {typeof(KernelAbstractions.POCLBackend()), W, N, F} (KernelAbstractions. POCLBackend (), obj. f)
402
+ end
384
403
end
385
404
386
405
function (obj:: Kernel{JLBackend} )(args... ; ndrange= nothing , workgroupsize= nothing )
391
410
392
411
Adapt. adapt_storage (:: JLBackend , a:: Array ) = Adapt. adapt (JLArrays. JLArray, a)
393
412
Adapt. adapt_storage (:: JLBackend , a:: JLArrays.JLArray ) = a
394
- Adapt. adapt_storage (:: KernelAbstractions.CPU , a:: JLArrays.JLArray ) = convert (Array, a)
413
+
414
+ @static if ! isdefined (JLArrays. KernelAbstractions, :POCL ) # KA v0.9
415
+ Adapt. adapt_storage (:: KernelAbstractions.CPU , a:: JLArrays.JLArray ) = convert (Array, a)
416
+ else
417
+ Adapt. adapt_storage (:: KernelAbstractions.POCLBackend , a:: JLArrays.JLArray ) = convert (Array, a)
418
+ end
395
419
396
420
end
0 commit comments