Skip to content

Commit 6e684c3

Browse files
committed
fix: tracing
1 parent f776a4c commit 6e684c3

2 files changed

Lines changed: 43 additions & 25 deletions

File tree

src/Tracing.jl

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -285,36 +285,25 @@ Base.@nospecializeinfer function traced_type_inner(
285285
mode::TraceMode,
286286
@nospecialize(args::Vararg)
287287
)
288-
T = T0.parameters[1]
289288
if mode == ConcreteToTraced
290-
return TracedRNumber{T}
289+
return TracedRNumber{T0.parameters[1]}
291290
elseif mode == TracedToConcrete
292-
return ConcreteRNumber{T}
291+
return T0
293292
else
294293
throw("Abstract RNumber cannot be made concrete")
295294
end
296295
end
297296

298-
Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::UnionAll)) =
299-
UnionAll(TV.var, base_typet(TV.body))
300-
Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::DataType)) =
301-
TracedRArray{TV.parameters...}
302-
303-
Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::UnionAll)) =
304-
UnionAll(TV.var, base_typec(TV.body))
305-
Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) =
306-
(TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...}
307-
308297
Base.@nospecializeinfer function traced_type_inner(
309-
@nospecialize(T::Type{<:ConcreteRArray}),
298+
@nospecialize(CA::Type{<:ConcreteRArray}),
310299
seen,
311300
mode::TraceMode,
312301
@nospecialize(args::Vararg)
313302
)
314303
if mode == ConcreteToTraced
315-
return base_typet(T)
304+
return TracedRArray{CA.parameters[1],CA.parameters[2]}
316305
elseif mode == TracedToConcrete
317-
return T
306+
return CA
318307
else
319308
throw("Abstract RArray cannot be made concrete")
320309
end
@@ -346,6 +335,38 @@ Base.@nospecializeinfer function traced_type_inner(
346335
return error("This should not happen...")
347336
end
348337

338+
Base.@nospecializeinfer function traced_type_inner(
339+
TR::Type{<:TracedRNumber},
340+
seen,
341+
mode::TraceMode,
342+
@nospecialize(track_numbers),
343+
@nospecialize(batchmode),
344+
@nospecialize(tobatch)
345+
)
346+
T = TR.parameters[1]
347+
if mode == ConcreteToTraced
348+
throw("TracedRArray $(TracedRArray{T,N}) cannot be traced")
349+
elseif mode == TracedToConcrete
350+
return ConcreteRNumber{T}
351+
elseif mode == TracedTrack || mode == NoStopTracedTrack
352+
return TracedRNumber{T}
353+
elseif mode == TracedSetPath
354+
if batchmode == BatchNone
355+
return TracedRNumber{T}
356+
elseif batchmode == BatchScalar
357+
if tobatch === nothing
358+
return TracedRNumber{T}
359+
else
360+
return TracedRArray{T,length(tobatch)}
361+
end
362+
else
363+
error("Cannot BatchArray on a scalar")
364+
end
365+
else
366+
throw("$(TracedRNumber{T}) cannot be made concrete in mode $mode")
367+
end
368+
end
369+
349370
Base.@nospecializeinfer function traced_type_inner(
350371
TR::Type{<:TracedRArray},
351372
seen,
@@ -359,7 +380,7 @@ Base.@nospecializeinfer function traced_type_inner(
359380
if mode == ConcreteToTraced
360381
throw("TracedRArray $(TracedRArray{T,N}) cannot be traced")
361382
elseif mode == TracedToConcrete
362-
return base_typec(TracedRArray{T,N})
383+
return ConcreteRArray{T,N}
363384
elseif mode == TracedTrack || mode == NoStopTracedTrack
364385
return TracedRArray{T,N}
365386
elseif mode == TracedSetPath

test/basic.jl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,18 @@ end
3434
end
3535

3636
sinexp(x) = sin(exp(x))
37-
sinexpbc(x) = sinexp.(x)
3837

3938
@testset "Broadcast combined" begin
4039
x = rand(2, 10)
4140

42-
r_res = sinexpbc(x)
41+
r_res = sinexp.(x)
4342

4443
a = Reactant.ConcreteRArray(x)
4544

46-
c_res = @allowscalar sinexpbc(a)
45+
c_res = @allowscalar sinexp.(a)
4746
@test c_res r_res
4847

49-
@test @jit(sinexpbc(a)) r_res
48+
@test @jit(sinexp.(a)) r_res
5049
end
5150

5251
sumexp(x) = sum(exp, x)
@@ -82,13 +81,11 @@ end
8281
@test f_res r_res
8382
end
8483

85-
bcast_cos(x) = cos.(x)
86-
8784
@testset "Basic cos" begin
8885
x = rand(3, 2)
8986
c = Reactant.ConcreteRArray(x)
9087

91-
@test @jit(bcast_cos(c)) cos.(x)
88+
@test @jit(cos.(c)) cos.(x)
9289
end
9390

9491
f_var(args...) = sum(args)
@@ -376,7 +373,7 @@ end
376373
b = Reactant.to_rarray(_b)
377374
c = Reactant.to_rarray(_c)
378375

379-
# vcat test
376+
# vcat test
380377
y = @jit vcat(a, b)
381378
@test y == vcat(a, _b)
382379
@test y isa ConcreteRArray{typeof_a,1}

0 commit comments

Comments
 (0)