13
13
BatchArray = 3
14
14
end
15
15
16
- Base. @nospecializeinfer function traced_type_inner (@nospecialize (T:: Type ), seen, mode:: TraceMode , @nospecialize (args:: Vararg ))
16
+ Base. @nospecializeinfer function traced_type_inner (
17
+ @nospecialize (T:: Type ), seen, mode:: TraceMode , @nospecialize (args:: Vararg )
18
+ )
17
19
if T === Any
18
20
return T
19
21
end
@@ -134,18 +136,36 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen,
134
136
throw (NoFieldMatchError (T, TT2))
135
137
end
136
138
137
- Base. @nospecializeinfer function traced_type_inner (@nospecialize (T:: Type{Union{}} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg ))
139
+ Base. @nospecializeinfer function traced_type_inner (
140
+ @nospecialize (T:: Type{Union{}} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg )
141
+ )
138
142
return T
139
143
end
140
144
141
- for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, AbstractFloat, Integer, RNumber)
142
- @eval Base. @nospecializeinfer function traced_type_inner (@nospecialize (T:: Type{<:$T} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg ))
145
+ for T in (
146
+ DataType,
147
+ Module,
148
+ Nothing,
149
+ Symbol,
150
+ AbstractChar,
151
+ AbstractString,
152
+ AbstractFloat,
153
+ Integer,
154
+ RNumber,
155
+ )
156
+ @eval Base. @nospecializeinfer function traced_type_inner (
157
+ @nospecialize (T:: Type{<:$T} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg )
158
+ )
143
159
return T
144
160
end
145
161
end
146
162
147
163
Base. @nospecializeinfer function traced_type_inner (
148
- @nospecialize (T:: Type{<:ReactantPrimitive} ), seen, @nospecialize (mode:: TraceMode ), @nospecialize (track_numbers:: Type ), @nospecialize (args:: Vararg )
164
+ @nospecialize (T:: Type{<:ReactantPrimitive} ),
165
+ seen,
166
+ @nospecialize (mode:: TraceMode ),
167
+ @nospecialize (track_numbers:: Type ),
168
+ @nospecialize (args:: Vararg )
149
169
)
150
170
if Mode == ArrayToConcrete && T <: track_numbers
151
171
return ConcreteRNumber{T}
@@ -154,7 +174,10 @@ Base.@nospecializeinfer function traced_type_inner(
154
174
end
155
175
156
176
Base. @nospecializeinfer function traced_type_inner (
157
- @nospecialize (C:: Type{<:Complex} ), seen, @nospecialize (mode:: TraceMode ), @nospecialize (args:: Vararg )
177
+ @nospecialize (C:: Type{<:Complex} ),
178
+ seen,
179
+ @nospecialize (mode:: TraceMode ),
180
+ @nospecialize (args:: Vararg )
158
181
)
159
182
if ! (C isa UnionAll)
160
183
return Complex{traced_type_inner (C. parameters[1 ], seen, mode, args... )}
@@ -163,7 +186,9 @@ Base.@nospecializeinfer function traced_type_inner(
163
186
end
164
187
end
165
188
166
- Base. @nospecializeinfer function traced_type_inner (@nospecialize (T:: Type{<:Function} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg ))
189
+ Base. @nospecializeinfer function traced_type_inner (
190
+ @nospecialize (T:: Type{<:Function} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg )
191
+ )
167
192
# functions are directly returned
168
193
if sizeof (T) == 0
169
194
return T
190
215
@inline is_concrete_tuple (x:: T2 ) where {T2} =
191
216
(x <: Tuple ) && ! (x === Tuple) && ! (x isa UnionAll)
192
217
193
- Base. @nospecializeinfer function traced_type_inner (@nospecialize (T:: Type{<:Tuple} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg ))
218
+ Base. @nospecializeinfer function traced_type_inner (
219
+ @nospecialize (T:: Type{<:Tuple} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg )
220
+ )
194
221
if ! Base. isconcretetype (T) || ! is_concrete_tuple (T) || T isa UnionAll
195
222
throw (AssertionError (" Type $T is not concrete type or concrete tuple" ))
196
223
elseif is_concrete_tuple (T) && any (T2 isa Core. TypeofVararg for T2 in T. parameters)
@@ -204,19 +231,27 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Tuple
204
231
return Tuple{TT... }
205
232
end
206
233
207
- Base. @nospecializeinfer function traced_type_inner (@nospecialize (T:: Type{<:NamedTuple} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg ))
234
+ Base. @nospecializeinfer function traced_type_inner (
235
+ @nospecialize (T:: Type{<:NamedTuple} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg )
236
+ )
208
237
N = T. parameters[1 ]
209
238
V = T. parameters[2 ]
210
239
return NamedTuple{N,traced_type_inner (V, seen, mode, args... )}
211
240
end
212
241
213
-
214
242
Base. @nospecializeinfer @inline dict_key (:: Type{<:AbstractDict} ) = nothing
215
- Base. @nospecializeinfer @inline dict_key (:: Type{<:AbstractDict{K}} ) where K = K
243
+ Base. @nospecializeinfer @inline dict_key (:: Type{<:AbstractDict{K}} ) where {K} = K
216
244
Base. @nospecializeinfer @inline dict_value (:: Type{<:AbstractDict} ) = nothing
217
- Base. @nospecializeinfer @inline dict_value (:: Type{<:(AbstractDict{K,V} where K)} ) where V = V
245
+ Base. @nospecializeinfer @inline dict_value (
246
+ :: Type{<:(AbstractDict{K,V} where {K})}
247
+ ) where {V} = V
218
248
219
- Base. @nospecializeinfer function traced_type_inner (@nospecialize (T:: Type{<:AbstractDict} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg ))
249
+ Base. @nospecializeinfer function traced_type_inner (
250
+ @nospecialize (T:: Type{<:AbstractDict} ),
251
+ seen,
252
+ mode:: TraceMode ,
253
+ @nospecialize (args:: Vararg )
254
+ )
220
255
V = dict_value (T)
221
256
if V === nothing
222
257
return T
@@ -234,13 +269,16 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Abstr
234
269
if K != = nothing
235
270
return dictty{K,V2}
236
271
else
237
- return (dictty{KT,V2} where KT )
272
+ return (dictty{KT,V2} where {KT} )
238
273
end
239
274
end
240
275
end
241
276
242
277
Base. @nospecializeinfer function traced_type_inner (
243
- @nospecialize (T0:: Type{<:ConcreteRNumber} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg )
278
+ @nospecialize (T0:: Type{<:ConcreteRNumber} ),
279
+ seen,
280
+ mode:: TraceMode ,
281
+ @nospecialize (args:: Vararg )
244
282
)
245
283
T = T0. parameters[1 ]
246
284
if mode == ConcreteToTraced
@@ -251,15 +289,22 @@ Base.@nospecializeinfer function traced_type_inner(
251
289
throw (" Abstract RNumber cannot be made concrete" )
252
290
end
253
291
end
254
-
255
- Base. @nospecializeinfer @inline base_typet (@nospecialize (TV:: UnionAll )) = UnionAll (TV. var, base_typet (TV. body))
256
- Base. @nospecializeinfer @inline base_typet (@nospecialize (TV:: DataType )) = TracedRArray{TV. parameters... }
257
-
258
- Base. @nospecializeinfer @inline base_typec (@nospecialize (TV:: UnionAll )) = UnionAll (TV. var, base_typec (TV. body))
259
- Base. @nospecializeinfer @inline base_typec (@nospecialize (TV:: DataType )) = (TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV. parameters... }
292
+
293
+ Base. @nospecializeinfer @inline base_typet (@nospecialize (TV:: UnionAll )) =
294
+ UnionAll (TV. var, base_typet (TV. body))
295
+ Base. @nospecializeinfer @inline base_typet (@nospecialize (TV:: DataType )) =
296
+ TracedRArray{TV. parameters... }
297
+
298
+ Base. @nospecializeinfer @inline base_typec (@nospecialize (TV:: UnionAll )) =
299
+ UnionAll (TV. var, base_typec (TV. body))
300
+ Base. @nospecializeinfer @inline base_typec (@nospecialize (TV:: DataType )) =
301
+ (TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV. parameters... }
260
302
261
303
Base. @nospecializeinfer function traced_type_inner (
262
- @nospecialize (T:: Type{<:ConcreteRArray} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg )
304
+ @nospecialize (T:: Type{<:ConcreteRArray} ),
305
+ seen,
306
+ mode:: TraceMode ,
307
+ @nospecialize (args:: Vararg )
263
308
)
264
309
if mode == ConcreteToTraced
265
310
return base_typet (T)
@@ -270,7 +315,12 @@ Base.@nospecializeinfer function traced_type_inner(
270
315
end
271
316
end
272
317
273
- Base. @nospecializeinfer function traced_type_inner (@nospecialize (T:: Type{<:ConcreteRNG} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg ))
318
+ Base. @nospecializeinfer function traced_type_inner (
319
+ @nospecialize (T:: Type{<:ConcreteRNG} ),
320
+ seen,
321
+ mode:: TraceMode ,
322
+ @nospecialize (args:: Vararg )
323
+ )
274
324
if mode == ConcreteToTraced
275
325
return TracedRNG
276
326
elseif mode == TracedToConcrete
@@ -281,16 +331,23 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Concr
281
331
end
282
332
283
333
Base. @nospecializeinfer function traced_type_inner (
284
- :: Type{<:MissingTracedValue} , seen, mode:: TraceMode , @nospecialize (track_numbers), @nospecialize (batchmode), @nospecialize (tobatch)
334
+ :: Type{<:MissingTracedValue} ,
335
+ seen,
336
+ mode:: TraceMode ,
337
+ @nospecialize (track_numbers),
338
+ @nospecialize (batchmode),
339
+ @nospecialize (tobatch)
285
340
)
286
- error (" This should not happen..." )
341
+ return error (" This should not happen..." )
287
342
end
288
343
289
- @inline base_typec (TV:: TT ) where {TT<: UnionAll } = UnionAll (TV. var, base_typec (TV. body))
290
- @inline base_typec (TV:: TT ) where {TT<: DataType } = ConcreteRArray{TV. parameters... }
291
-
292
344
Base. @nospecializeinfer function traced_type_inner (
293
- TR:: Type{<:TracedRArray} , seen, mode:: TraceMode , @nospecialize (track_numbers), @nospecialize (batchmode), @nospecialize (tobatch)
345
+ TR:: Type{<:TracedRArray} ,
346
+ seen,
347
+ mode:: TraceMode ,
348
+ @nospecialize (track_numbers),
349
+ @nospecialize (batchmode),
350
+ @nospecialize (tobatch)
294
351
)
295
352
T = TR. parameters[1 ]
296
353
N = TR. parameters[2 ]
@@ -317,7 +374,9 @@ Base.@nospecializeinfer function traced_type_inner(
317
374
end
318
375
end
319
376
320
- Base. @nospecializeinfer function traced_type_inner (@nospecialize (T:: Type{<:TracedRNG} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg ))
377
+ Base. @nospecializeinfer function traced_type_inner (
378
+ @nospecialize (T:: Type{<:TracedRNG} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg )
379
+ )
321
380
if mode == ConcreteToTraced
322
381
throw (" TracedRNG cannot be traced" )
323
382
elseif mode == TracedToConcrete
@@ -329,7 +388,9 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Trace
329
388
end
330
389
end
331
390
332
- Base. @nospecializeinfer function traced_type_inner (@nospecialize (T:: Type{<:XLAArray} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg ))
391
+ Base. @nospecializeinfer function traced_type_inner (
392
+ @nospecialize (T:: Type{<:XLAArray} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg )
393
+ )
333
394
throw (" XLA $T array cannot be traced" )
334
395
end
335
396
@@ -346,13 +407,20 @@ Base.@nospecializeinfer function traced_type_inner(
346
407
end
347
408
348
409
for P in (Ptr, Core. LLVMPtr, Base. RefValue)
349
- @eval Base. @nospecializeinfer function traced_type_inner (@nospecialize (PT:: Type{<:$P} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg ))
410
+ @eval Base. @nospecializeinfer function traced_type_inner (
411
+ @nospecialize (PT:: Type{<:$P} ), seen, mode:: TraceMode , @nospecialize (args:: Vararg )
412
+ )
350
413
T = eltype (PT)
351
414
return $ P{traced_type_inner (T, seen, mode, args... )}
352
415
end
353
416
end
354
417
355
- Base. @nospecializeinfer function traced_type_inner (@nospecialize (VT:: Type{<:Val} ), seen, @nospecialize (mode:: TraceMode ), @nospecialize (args:: Vararg ))
418
+ Base. @nospecializeinfer function traced_type_inner (
419
+ @nospecialize (VT:: Type{<:Val} ),
420
+ seen,
421
+ @nospecialize (mode:: TraceMode ),
422
+ @nospecialize (args:: Vararg )
423
+ )
356
424
if VT isa UnionAll
357
425
return VT
358
426
end
@@ -363,7 +431,7 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(VT::Type{<:Val}
363
431
throw (" Val type $(Val{T}) cannot be traced" )
364
432
end
365
433
366
- const traced_type_cache = Dict {Tuple{TraceMode, Type}, Dict{Type, Type}} ()
434
+ const traced_type_cache = Dict {Tuple{TraceMode,Type,Any,Any}, Dict{Type,Type}} ()
367
435
368
436
# function traced_type_generator(world::UInt, source, self, @nospecialize(T::Type), @nospecialize(mode::Type{<:Val}), @nospecialize(track_numbers::Type))
369
437
# @nospecialize
@@ -456,16 +524,18 @@ const traced_type_cache = Dict{Tuple{TraceMode, Type}, Dict{Type, Type}}()
456
524
# $(Expr(:meta, :generated, traced_type_generator))
457
525
# end
458
526
459
- Base. @assume_effects :total @inline function traced_type (T:: Type , :: Val{mode} , track_numbers:: Type ) where mode
527
+ Base. @assume_effects :total @inline function traced_type (
528
+ T:: Type , :: Val{mode} , track_numbers:: Type , batchmode, tobatch
529
+ ) where {mode}
460
530
cache = nothing
461
- cache_key = (mode, track_numbers)
531
+ cache_key = (mode, track_numbers, batchmode, tobatch )
462
532
if haskey (traced_type_cache, cache_key)
463
533
cache = traced_type_cache[cache_key]
464
534
else
465
- cache = Dict {Type, Type} ()
535
+ cache = Dict {Type,Type} ()
466
536
traced_type_cache[cache_key] = cache
467
537
end
468
- res1 = traced_type_inner (T, cache, mode, track_numbers)
538
+ return traced_type_inner (T, cache, mode, track_numbers, batchmode, tobatch )
469
539
end
470
540
471
541
abstract type TracedTypeException <: Exception end
@@ -506,9 +576,9 @@ function make_tracer(
506
576
@nospecialize (prev),
507
577
@nospecialize (path),
508
578
mode;
509
- @nospecialize (track_numbers:: Type = Union{}),
510
- @nospecialize (batchmode= BatchNone),
511
- @nospecialize (tobatch= nothing ),
579
+ @nospecialize (track_numbers:: Type = Union{}),
580
+ @nospecialize (batchmode = BatchNone),
581
+ @nospecialize (tobatch = nothing ),
512
582
kwargs... ,
513
583
)
514
584
if mode != NoStopTracedTrack && haskey (seen, prev)
@@ -601,7 +671,9 @@ function make_tracer(
601
671
return res
602
672
end
603
673
604
- function make_tracer (seen, prev:: ConcreteRNumber{T} , @nospecialize (path), mode; kwargs... ) where {T}
674
+ function make_tracer (
675
+ seen, prev:: ConcreteRNumber{T} , @nospecialize (path), mode; kwargs...
676
+ ) where {T}
605
677
if mode == ArrayToConcrete
606
678
return prev
607
679
end
@@ -772,7 +844,12 @@ function make_tracer(
772
844
end
773
845
774
846
function make_tracer (
775
- seen, @nospecialize (prev:: Number ), @nospecialize (path), mode; @nospecialize (track_numbers:: Type = Union{}), kwargs...
847
+ seen,
848
+ @nospecialize (prev:: Number ),
849
+ @nospecialize (path),
850
+ mode;
851
+ @nospecialize (track_numbers:: Type = Union{}),
852
+ kwargs... ,
776
853
)
777
854
RT = Core. Typeof (prev)
778
855
if RT <: track_numbers
@@ -815,7 +892,14 @@ function make_tracer(
815
892
end
816
893
817
894
function make_tracer (
818
- seen, @nospecialize (prev:: Array ), @nospecialize (path), mode; @nospecialize (track_numbers:: Type = Union{}), @nospecialize (batchmode= BatchNone), @nospecialize (tobatch= nothing ), kwargs...
895
+ seen,
896
+ @nospecialize (prev:: Array ),
897
+ @nospecialize (path),
898
+ mode;
899
+ @nospecialize (track_numbers:: Type = Union{}),
900
+ @nospecialize (batchmode = BatchNone),
901
+ @nospecialize (tobatch = nothing ),
902
+ kwargs... ,
819
903
)
820
904
RT = Core. Typeof (prev)
821
905
if mode != NoStopTracedTrack && haskey (seen, prev)
@@ -854,9 +938,7 @@ function make_tracer(
854
938
return newa
855
939
end
856
940
857
- function make_tracer (
858
- seen, @nospecialize (prev:: Tuple ), @nospecialize (path), mode; kwargs...
859
- )
941
+ function make_tracer (seen, @nospecialize (prev:: Tuple ), @nospecialize (path), mode; kwargs... )
860
942
return (
861
943
(
862
944
make_tracer (seen, v, append_path (path, i), mode; kwargs... ) for
@@ -870,9 +952,9 @@ function make_tracer(
870
952
@nospecialize (prev:: NamedTuple ),
871
953
@nospecialize (path),
872
954
mode;
873
- @nospecialize (track_numbers:: Type = Union{}),
874
- @nospecialize (batchmode= BatchNone),
875
- @nospecialize (tobatch= nothing ),
955
+ @nospecialize (track_numbers:: Type = Union{}),
956
+ @nospecialize (batchmode = BatchNone),
957
+ @nospecialize (tobatch = nothing ),
876
958
kwargs... ,
877
959
)
878
960
NT = Core. Typeof (prev)
@@ -918,15 +1000,23 @@ end
918
1000
return make_tracer (OrderedIdDict (), x, (), Reactant. ArrayToConcrete; track_numbers)
919
1001
end
920
1002
921
- function to_rarray_internal (@nospecialize (:: TracedRArray ), @nospecialize (track_numbers:: Type ))
1003
+ function to_rarray_internal (
1004
+ @nospecialize (:: TracedRArray ), @nospecialize (track_numbers:: Type )
1005
+ )
922
1006
return error (" Cannot convert TracedRArray to ConcreteRArray" )
923
1007
end
924
- @inline to_rarray_internal (@nospecialize (x:: ConcreteRArray ), @nospecialize (track_numbers:: Type )) = x
925
- @inline function to_rarray_internal (@nospecialize (x:: Array{<:ReactantPrimitive} ), @nospecialize (track_numbers:: Type ))
1008
+ @inline to_rarray_internal (
1009
+ @nospecialize (x:: ConcreteRArray ), @nospecialize (track_numbers:: Type )
1010
+ ) = x
1011
+ @inline function to_rarray_internal (
1012
+ @nospecialize (x:: Array{<:ReactantPrimitive} ), @nospecialize (track_numbers:: Type )
1013
+ )
926
1014
return ConcreteRArray (x)
927
1015
end
928
1016
929
- @inline to_rarray_internal (@nospecialize (x:: ConcreteRNumber ), @nospecialize (track_numbers:: Type )) = x
1017
+ @inline to_rarray_internal (
1018
+ @nospecialize (x:: ConcreteRNumber ), @nospecialize (track_numbers:: Type )
1019
+ ) = x
930
1020
@inline function to_rarray_internal (
931
1021
@nospecialize (x:: ReactantPrimitive ), @nospecialize (track_numbers:: Type )
932
1022
)
0 commit comments