@@ -50,7 +50,8 @@ mutable struct TaskLocalState
5050 math_mode = something (default_math_mode[],
5151 Base. JLOptions (). fast_math== 1 ? FAST_MATH : DEFAULT_MATH)
5252 math_precision = something (default_math_precision[], :TensorFloat32 )
53- new (dev, ctx, Base. fill (nothing , ndevices ()), math_mode, math_precision)
53+ new (dev, ctx, Union{Nothing,CuStream}[nothing for _ in 1 : ndevices ()],
54+ math_mode, math_precision)
5455 end
5556end
5657
159160@inline function context! (f:: Function , ctx:: CuContext ; skip_destroyed:: Bool = false )
160161 # @inline so that the kwarg method is inlined too and we can const-prop skip_destroyed
161162 if isvalid (ctx)
162- old_ctx = context! (ctx)
163+ old_ctx = context! (ctx):: Union{CuContext,Nothing}
163164 try
164165 f ()
165166 finally
187188
# Per-device context slots. The container is declared to hold
# `Vector{Union{Nothing,CuContext}}`, so the initial value is built with that
# exact element type: an untyped `[nothing for ...]` comprehension would
# produce a `Vector{Nothing}` instead.
const __device_contexts = LazyInitialized{Vector{Union{Nothing,CuContext}}}()

# Return the shared slot vector; the `do` block supplies its initial contents,
# one `nothing` entry per device reported by `ndevices()`.
device_contexts() = get!(__device_contexts) do
    Union{Nothing,CuContext}[nothing for _ in 1:ndevices()]
end
192193function device_context (i:: Int )
193194 contexts = device_contexts ()
@@ -419,8 +420,8 @@ function PerDevice{T}() where {T}
419420 PerDevice {T} (ReentrantLock (), values)
420421end
421422
# Return the slot vector backing `x`; the `do` block supplies its initial
# contents, one `nothing` entry per device reported by `ndevices()`.
#
# Each slot is either `nothing` or a `(context, value)` tuple, so the vector is
# created with the element type `Union{Nothing,Tuple{CuContext,T}}` up front.
# A plain `fill(nothing, n)` would be a `Vector{Nothing}`, which cannot store
# such tuples.
get_values(x::PerDevice{T}) where {T} = get!(x.values) do
    Union{Nothing,Tuple{CuContext,T}}[nothing for _ in 1:ndevices()]
end
425426
426427function Base. get (x:: PerDevice , dev:: CuDevice , val)
@@ -437,20 +438,20 @@ function Base.get(x::PerDevice, dev::CuDevice, val)
437438 end
438439end
439440
"""
    get!(constructor, x::PerDevice{T}, dev::CuDevice)

Return the value stored in `x` for device `dev`.

The slot for `dev` is (re)populated with `(context(), constructor())` when it is
still `nothing`, or when the context recorded in the slot is not identical
(`!==`) to the one `device_context` currently reports for that device. The
second element of the slot tuple is returned.
"""
function Base.get!(constructor::F, x::PerDevice{T}, dev::CuDevice) where {F,T}
    y = get_values(x)
    id = deviceid(dev) + 1      # `deviceid` is offset by one to index `y`
    ctx = device_context(id)    # may be nothing
    @inbounds begin
        # test-lock-test: check without the lock first, then repeat the check
        # while holding `x.lock` before writing the slot.
        # The `::Tuple` assertions narrow the `Union{Nothing,Tuple}` element
        # after the `=== nothing` test, since the slot is re-read each time.
        if y[id] === nothing || (y[id]::Tuple)[1] !== ctx
            Base.@lock x.lock begin
                if y[id] === nothing || (y[id]::Tuple)[1] !== ctx
                    y[id] = (context(), constructor())
                end
            end
        end
        (y[id]::Tuple)[2]
    end
end
456457
0 commit comments