Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e2eb9fb

Browse files
committedJan 17, 2025·
fix: update traced_type
1 parent b823526 commit e2eb9fb

File tree

1 file changed

+144
-54
lines changed

1 file changed

+144
-54
lines changed
 

‎src/Tracing.jl

+144-54
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ end
1313
BatchArray = 3
1414
end
1515

16-
Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(args::Vararg))
16+
Base.@nospecializeinfer function traced_type_inner(
17+
@nospecialize(T::Type), seen, mode::TraceMode, @nospecialize(args::Vararg)
18+
)
1719
if T === Any
1820
return T
1921
end
@@ -134,18 +136,36 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type), seen,
134136
throw(NoFieldMatchError(T, TT2))
135137
end
136138

137-
Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{Union{}}), seen, mode::TraceMode, @nospecialize(args::Vararg))
139+
Base.@nospecializeinfer function traced_type_inner(
140+
@nospecialize(T::Type{Union{}}), seen, mode::TraceMode, @nospecialize(args::Vararg)
141+
)
138142
return T
139143
end
140144

141-
for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, AbstractFloat, Integer, RNumber)
142-
@eval Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:$T}), seen, mode::TraceMode, @nospecialize(args::Vararg))
145+
for T in (
146+
DataType,
147+
Module,
148+
Nothing,
149+
Symbol,
150+
AbstractChar,
151+
AbstractString,
152+
AbstractFloat,
153+
Integer,
154+
RNumber,
155+
)
156+
@eval Base.@nospecializeinfer function traced_type_inner(
157+
@nospecialize(T::Type{<:$T}), seen, mode::TraceMode, @nospecialize(args::Vararg)
158+
)
143159
return T
144160
end
145161
end
146162

147163
Base.@nospecializeinfer function traced_type_inner(
148-
@nospecialize(T::Type{<:ReactantPrimitive}), seen, @nospecialize(mode::TraceMode), @nospecialize(track_numbers::Type), @nospecialize(args::Vararg)
164+
@nospecialize(T::Type{<:ReactantPrimitive}),
165+
seen,
166+
@nospecialize(mode::TraceMode),
167+
@nospecialize(track_numbers::Type),
168+
@nospecialize(args::Vararg)
149169
)
150170
if Mode == ArrayToConcrete && T <: track_numbers
151171
return ConcreteRNumber{T}
@@ -154,7 +174,10 @@ Base.@nospecializeinfer function traced_type_inner(
154174
end
155175

156176
Base.@nospecializeinfer function traced_type_inner(
157-
@nospecialize(C::Type{<:Complex}), seen, @nospecialize(mode::TraceMode), @nospecialize(args::Vararg)
177+
@nospecialize(C::Type{<:Complex}),
178+
seen,
179+
@nospecialize(mode::TraceMode),
180+
@nospecialize(args::Vararg)
158181
)
159182
if !(C isa UnionAll)
160183
return Complex{traced_type_inner(C.parameters[1], seen, mode, args...)}
@@ -163,7 +186,9 @@ Base.@nospecializeinfer function traced_type_inner(
163186
end
164187
end
165188

166-
Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Function}), seen, mode::TraceMode, @nospecialize(args::Vararg))
189+
Base.@nospecializeinfer function traced_type_inner(
190+
@nospecialize(T::Type{<:Function}), seen, mode::TraceMode, @nospecialize(args::Vararg)
191+
)
167192
# functions are directly returned
168193
if sizeof(T) == 0
169194
return T
@@ -190,7 +215,9 @@ end
190215
@inline is_concrete_tuple(x::T2) where {T2} =
191216
(x <: Tuple) && !(x === Tuple) && !(x isa UnionAll)
192217

193-
Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Tuple}), seen, mode::TraceMode, @nospecialize(args::Vararg))
218+
Base.@nospecializeinfer function traced_type_inner(
219+
@nospecialize(T::Type{<:Tuple}), seen, mode::TraceMode, @nospecialize(args::Vararg)
220+
)
194221
if !Base.isconcretetype(T) || !is_concrete_tuple(T) || T isa UnionAll
195222
throw(AssertionError("Type $T is not concrete type or concrete tuple"))
196223
elseif is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters)
@@ -204,19 +231,27 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Tuple
204231
return Tuple{TT...}
205232
end
206233

207-
Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:NamedTuple}), seen, mode::TraceMode, @nospecialize(args::Vararg))
234+
Base.@nospecializeinfer function traced_type_inner(
235+
@nospecialize(T::Type{<:NamedTuple}), seen, mode::TraceMode, @nospecialize(args::Vararg)
236+
)
208237
N = T.parameters[1]
209238
V = T.parameters[2]
210239
return NamedTuple{N,traced_type_inner(V, seen, mode, args...)}
211240
end
212241

213-
214242
Base.@nospecializeinfer @inline dict_key(::Type{<:AbstractDict}) = nothing
215-
Base.@nospecializeinfer @inline dict_key(::Type{<:AbstractDict{K}}) where K = K
243+
Base.@nospecializeinfer @inline dict_key(::Type{<:AbstractDict{K}}) where {K} = K
216244
Base.@nospecializeinfer @inline dict_value(::Type{<:AbstractDict}) = nothing
217-
Base.@nospecializeinfer @inline dict_value(::Type{<:(AbstractDict{K,V} where K)}) where V = V
245+
Base.@nospecializeinfer @inline dict_value(
246+
::Type{<:(AbstractDict{K,V} where {K})}
247+
) where {V} = V
218248

219-
Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:AbstractDict}), seen, mode::TraceMode, @nospecialize(args::Vararg))
249+
Base.@nospecializeinfer function traced_type_inner(
250+
@nospecialize(T::Type{<:AbstractDict}),
251+
seen,
252+
mode::TraceMode,
253+
@nospecialize(args::Vararg)
254+
)
220255
V = dict_value(T)
221256
if V === nothing
222257
return T
@@ -234,13 +269,16 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Abstr
234269
if K !== nothing
235270
return dictty{K,V2}
236271
else
237-
return (dictty{KT,V2} where KT)
272+
return (dictty{KT,V2} where {KT})
238273
end
239274
end
240275
end
241276

242277
Base.@nospecializeinfer function traced_type_inner(
243-
@nospecialize(T0::Type{<:ConcreteRNumber}), seen, mode::TraceMode, @nospecialize(args::Vararg)
278+
@nospecialize(T0::Type{<:ConcreteRNumber}),
279+
seen,
280+
mode::TraceMode,
281+
@nospecialize(args::Vararg)
244282
)
245283
T = T0.parameters[1]
246284
if mode == ConcreteToTraced
@@ -251,15 +289,22 @@ Base.@nospecializeinfer function traced_type_inner(
251289
throw("Abstract RNumber cannot be made concrete")
252290
end
253291
end
254-
255-
Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::UnionAll)) = UnionAll(TV.var, base_typet(TV.body))
256-
Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::DataType)) = TracedRArray{TV.parameters...}
257-
258-
Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::UnionAll)) = UnionAll(TV.var, base_typec(TV.body))
259-
Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) = (TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...}
292+
293+
Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::UnionAll)) =
294+
UnionAll(TV.var, base_typet(TV.body))
295+
Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::DataType)) =
296+
TracedRArray{TV.parameters...}
297+
298+
Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::UnionAll)) =
299+
UnionAll(TV.var, base_typec(TV.body))
300+
Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) =
301+
(TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...}
260302

261303
Base.@nospecializeinfer function traced_type_inner(
262-
@nospecialize(T::Type{<:ConcreteRArray}), seen, mode::TraceMode, @nospecialize(args::Vararg)
304+
@nospecialize(T::Type{<:ConcreteRArray}),
305+
seen,
306+
mode::TraceMode,
307+
@nospecialize(args::Vararg)
263308
)
264309
if mode == ConcreteToTraced
265310
return base_typet(T)
@@ -270,7 +315,12 @@ Base.@nospecializeinfer function traced_type_inner(
270315
end
271316
end
272317

273-
Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:ConcreteRNG}), seen, mode::TraceMode, @nospecialize(args::Vararg))
318+
Base.@nospecializeinfer function traced_type_inner(
319+
@nospecialize(T::Type{<:ConcreteRNG}),
320+
seen,
321+
mode::TraceMode,
322+
@nospecialize(args::Vararg)
323+
)
274324
if mode == ConcreteToTraced
275325
return TracedRNG
276326
elseif mode == TracedToConcrete
@@ -281,16 +331,23 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Concr
281331
end
282332

283333
Base.@nospecializeinfer function traced_type_inner(
284-
::Type{<:MissingTracedValue}, seen, mode::TraceMode, @nospecialize(track_numbers), @nospecialize(batchmode), @nospecialize(tobatch)
334+
::Type{<:MissingTracedValue},
335+
seen,
336+
mode::TraceMode,
337+
@nospecialize(track_numbers),
338+
@nospecialize(batchmode),
339+
@nospecialize(tobatch)
285340
)
286-
error("This should not happen...")
341+
return error("This should not happen...")
287342
end
288343

289-
@inline base_typec(TV::TT) where {TT<:UnionAll} = UnionAll(TV.var, base_typec(TV.body))
290-
@inline base_typec(TV::TT) where {TT<:DataType} = ConcreteRArray{TV.parameters...}
291-
292344
Base.@nospecializeinfer function traced_type_inner(
293-
TR::Type{<:TracedRArray}, seen, mode::TraceMode, @nospecialize(track_numbers), @nospecialize(batchmode), @nospecialize(tobatch)
345+
TR::Type{<:TracedRArray},
346+
seen,
347+
mode::TraceMode,
348+
@nospecialize(track_numbers),
349+
@nospecialize(batchmode),
350+
@nospecialize(tobatch)
294351
)
295352
T = TR.parameters[1]
296353
N = TR.parameters[2]
@@ -317,7 +374,9 @@ Base.@nospecializeinfer function traced_type_inner(
317374
end
318375
end
319376

320-
Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:TracedRNG}), seen, mode::TraceMode, @nospecialize(args::Vararg))
377+
Base.@nospecializeinfer function traced_type_inner(
378+
@nospecialize(T::Type{<:TracedRNG}), seen, mode::TraceMode, @nospecialize(args::Vararg)
379+
)
321380
if mode == ConcreteToTraced
322381
throw("TracedRNG cannot be traced")
323382
elseif mode == TracedToConcrete
@@ -329,7 +388,9 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:Trace
329388
end
330389
end
331390

332-
Base.@nospecializeinfer function traced_type_inner(@nospecialize(T::Type{<:XLAArray}), seen, mode::TraceMode, @nospecialize(args::Vararg))
391+
Base.@nospecializeinfer function traced_type_inner(
392+
@nospecialize(T::Type{<:XLAArray}), seen, mode::TraceMode, @nospecialize(args::Vararg)
393+
)
333394
throw("XLA $T array cannot be traced")
334395
end
335396

@@ -346,13 +407,20 @@ Base.@nospecializeinfer function traced_type_inner(
346407
end
347408

348409
for P in (Ptr, Core.LLVMPtr, Base.RefValue)
349-
@eval Base.@nospecializeinfer function traced_type_inner(@nospecialize(PT::Type{<:$P}), seen, mode::TraceMode, @nospecialize(args::Vararg))
410+
@eval Base.@nospecializeinfer function traced_type_inner(
411+
@nospecialize(PT::Type{<:$P}), seen, mode::TraceMode, @nospecialize(args::Vararg)
412+
)
350413
T = eltype(PT)
351414
return $P{traced_type_inner(T, seen, mode, args...)}
352415
end
353416
end
354417

355-
Base.@nospecializeinfer function traced_type_inner(@nospecialize(VT::Type{<:Val}), seen, @nospecialize(mode::TraceMode), @nospecialize(args::Vararg))
418+
Base.@nospecializeinfer function traced_type_inner(
419+
@nospecialize(VT::Type{<:Val}),
420+
seen,
421+
@nospecialize(mode::TraceMode),
422+
@nospecialize(args::Vararg)
423+
)
356424
if VT isa UnionAll
357425
return VT
358426
end
@@ -363,7 +431,7 @@ Base.@nospecializeinfer function traced_type_inner(@nospecialize(VT::Type{<:Val}
363431
throw("Val type $(Val{T}) cannot be traced")
364432
end
365433

366-
const traced_type_cache = Dict{Tuple{TraceMode, Type}, Dict{Type, Type}}()
434+
const traced_type_cache = Dict{Tuple{TraceMode,Type,Any,Any},Dict{Type,Type}}()
367435

368436
# function traced_type_generator(world::UInt, source, self, @nospecialize(T::Type), @nospecialize(mode::Type{<:Val}), @nospecialize(track_numbers::Type))
369437
# @nospecialize
@@ -456,16 +524,18 @@ const traced_type_cache = Dict{Tuple{TraceMode, Type}, Dict{Type, Type}}()
456524
# $(Expr(:meta, :generated, traced_type_generator))
457525
# end
458526

459-
Base.@assume_effects :total @inline function traced_type(T::Type, ::Val{mode}, track_numbers::Type) where mode
527+
Base.@assume_effects :total @inline function traced_type(
528+
T::Type, ::Val{mode}, track_numbers::Type, batchmode, tobatch
529+
) where {mode}
460530
cache = nothing
461-
cache_key = (mode, track_numbers)
531+
cache_key = (mode, track_numbers, batchmode, tobatch)
462532
if haskey(traced_type_cache, cache_key)
463533
cache = traced_type_cache[cache_key]
464534
else
465-
cache = Dict{Type, Type}()
535+
cache = Dict{Type,Type}()
466536
traced_type_cache[cache_key] = cache
467537
end
468-
res1 = traced_type_inner(T, cache, mode, track_numbers)
538+
return traced_type_inner(T, cache, mode, track_numbers, batchmode, tobatch)
469539
end
470540

471541
abstract type TracedTypeException <: Exception end
@@ -506,9 +576,9 @@ function make_tracer(
506576
@nospecialize(prev),
507577
@nospecialize(path),
508578
mode;
509-
@nospecialize(track_numbers::Type=Union{}),
510-
@nospecialize(batchmode=BatchNone),
511-
@nospecialize(tobatch=nothing),
579+
@nospecialize(track_numbers::Type = Union{}),
580+
@nospecialize(batchmode = BatchNone),
581+
@nospecialize(tobatch = nothing),
512582
kwargs...,
513583
)
514584
if mode != NoStopTracedTrack && haskey(seen, prev)
@@ -601,7 +671,9 @@ function make_tracer(
601671
return res
602672
end
603673

604-
function make_tracer(seen, prev::ConcreteRNumber{T}, @nospecialize(path), mode; kwargs...) where {T}
674+
function make_tracer(
675+
seen, prev::ConcreteRNumber{T}, @nospecialize(path), mode; kwargs...
676+
) where {T}
605677
if mode == ArrayToConcrete
606678
return prev
607679
end
@@ -772,7 +844,12 @@ function make_tracer(
772844
end
773845

774846
function make_tracer(
775-
seen, @nospecialize(prev::Number), @nospecialize(path), mode; @nospecialize(track_numbers::Type=Union{}), kwargs...
847+
seen,
848+
@nospecialize(prev::Number),
849+
@nospecialize(path),
850+
mode;
851+
@nospecialize(track_numbers::Type = Union{}),
852+
kwargs...,
776853
)
777854
RT = Core.Typeof(prev)
778855
if RT <: track_numbers
@@ -815,7 +892,14 @@ function make_tracer(
815892
end
816893

817894
function make_tracer(
818-
seen, @nospecialize(prev::Array), @nospecialize(path), mode; @nospecialize(track_numbers::Type=Union{}), @nospecialize(batchmode=BatchNone), @nospecialize(tobatch=nothing), kwargs...
895+
seen,
896+
@nospecialize(prev::Array),
897+
@nospecialize(path),
898+
mode;
899+
@nospecialize(track_numbers::Type = Union{}),
900+
@nospecialize(batchmode = BatchNone),
901+
@nospecialize(tobatch = nothing),
902+
kwargs...,
819903
)
820904
RT = Core.Typeof(prev)
821905
if mode != NoStopTracedTrack && haskey(seen, prev)
@@ -854,9 +938,7 @@ function make_tracer(
854938
return newa
855939
end
856940

857-
function make_tracer(
858-
seen, @nospecialize(prev::Tuple), @nospecialize(path), mode; kwargs...
859-
)
941+
function make_tracer(seen, @nospecialize(prev::Tuple), @nospecialize(path), mode; kwargs...)
860942
return (
861943
(
862944
make_tracer(seen, v, append_path(path, i), mode; kwargs...) for
@@ -870,9 +952,9 @@ function make_tracer(
870952
@nospecialize(prev::NamedTuple),
871953
@nospecialize(path),
872954
mode;
873-
@nospecialize(track_numbers::Type=Union{}),
874-
@nospecialize(batchmode=BatchNone),
875-
@nospecialize(tobatch=nothing),
955+
@nospecialize(track_numbers::Type = Union{}),
956+
@nospecialize(batchmode = BatchNone),
957+
@nospecialize(tobatch = nothing),
876958
kwargs...,
877959
)
878960
NT = Core.Typeof(prev)
@@ -918,15 +1000,23 @@ end
9181000
return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete; track_numbers)
9191001
end
9201002

921-
function to_rarray_internal(@nospecialize(::TracedRArray), @nospecialize(track_numbers::Type))
1003+
function to_rarray_internal(
1004+
@nospecialize(::TracedRArray), @nospecialize(track_numbers::Type)
1005+
)
9221006
return error("Cannot convert TracedRArray to ConcreteRArray")
9231007
end
924-
@inline to_rarray_internal(@nospecialize(x::ConcreteRArray), @nospecialize(track_numbers::Type)) = x
925-
@inline function to_rarray_internal(@nospecialize(x::Array{<:ReactantPrimitive}), @nospecialize(track_numbers::Type))
1008+
@inline to_rarray_internal(
1009+
@nospecialize(x::ConcreteRArray), @nospecialize(track_numbers::Type)
1010+
) = x
1011+
@inline function to_rarray_internal(
1012+
@nospecialize(x::Array{<:ReactantPrimitive}), @nospecialize(track_numbers::Type)
1013+
)
9261014
return ConcreteRArray(x)
9271015
end
9281016

929-
@inline to_rarray_internal(@nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type)) = x
1017+
@inline to_rarray_internal(
1018+
@nospecialize(x::ConcreteRNumber), @nospecialize(track_numbers::Type)
1019+
) = x
9301020
@inline function to_rarray_internal(
9311021
@nospecialize(x::ReactantPrimitive), @nospecialize(track_numbers::Type)
9321022
)

0 commit comments

Comments
 (0)
Please sign in to comment.