Skip to content

Commit b823526

Browse files
committed
fix: tests used older defn
1 parent a4130f4 commit b823526

File tree

2 files changed

+65
-26
lines changed

2 files changed

+65
-26
lines changed

Diff for: src/Tracing.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ Base.@nospecializeinfer function traced_type_inner(
302302
return TracedRArray{T,N}
303303
elseif mode == TracedSetPath
304304
if batchmode == BatchNone
305-
return T
305+
return TracedRArray{T,N}
306306
elseif batchmode == BatchArray
307307
if tobatch === nothing
308308
TracedRArray{T,N - 1}

Diff for: test/tracing.jl

+64-25
Original file line numberDiff line numberDiff line change
@@ -56,33 +56,65 @@ using Test
5656
(Complex{UInt128}, Complex{UInt128}, TracedRNumber{Complex{UInt128}}),
5757

5858
# RArray types
59-
(ConcreteRArray{Float64,0}, TracedRArray{Float64,0}, TracedRArray{Float64, 0}),
60-
(ConcreteRArray{Float64,1}, TracedRArray{Float64,1}, TracedRArray{Float64, 1}),
61-
(ConcreteRArray{Float64,2}, TracedRArray{Float64,2}, TracedRArray{Float64, 2}),
62-
(ConcreteRArray{Float64,3}, TracedRArray{Float64,3}, TracedRArray{Float64, 3}),
59+
(
60+
ConcreteRArray{Float64,0},
61+
TracedRArray{Float64,0},
62+
TracedRArray{Float64,0},
63+
),
64+
(
65+
ConcreteRArray{Float64,1},
66+
TracedRArray{Float64,1},
67+
TracedRArray{Float64,1},
68+
),
69+
(
70+
ConcreteRArray{Float64,2},
71+
TracedRArray{Float64,2},
72+
TracedRArray{Float64,2},
73+
),
74+
(
75+
ConcreteRArray{Float64,3},
76+
TracedRArray{Float64,3},
77+
TracedRArray{Float64,3},
78+
),
6379

6480
# Array types
65-
(Array{Float64,1}, Array{Float64,1}, Array{TracedRNumber{Float64}, 1}),
66-
(Array{ConcreteRArray{Float64,2},1}, Array{TracedRArray{Float64,2},1}, Array{TracedRArray{Float64,2}, 1}),
81+
(Array{Float64,1}, Array{Float64,1}, Array{TracedRNumber{Float64},1}),
82+
(
83+
Array{ConcreteRArray{Float64,2},1},
84+
Array{TracedRArray{Float64,2},1},
85+
Array{TracedRArray{Float64,2},1},
86+
),
6787

6888
# Union types
69-
(Union{Nothing,Int}, Union{Nothing,Int}, Union{Nothing, TracedRNumber{Int}}),
89+
(Union{Nothing,Int}, Union{Nothing,Int}, Union{Nothing,TracedRNumber{Int}}),
7090
(
7191
Union{Nothing,ConcreteRArray{Float64,1}},
7292
Union{Nothing,TracedRArray{Float64,1}},
73-
Union{Nothing, TracedRArray{Float64, 1}}
93+
Union{Nothing,TracedRArray{Float64,1}},
7494
),
7595

7696
# Ptr types
7797
(Ptr{Float64}, Ptr{Float64}, Ptr{TracedRNumber{Float64}}),
78-
(Ptr{ConcreteRArray{Float64,1}}, Ptr{TracedRArray{Float64,1}}, Ptr{TracedRArray{Float64,1}}),
79-
(Core.LLVMPtr{Float64}, Core.LLVMPtr{Float64}, Core.LLVMPtr{TracedRNumber{Float64}}),
98+
(
99+
Ptr{ConcreteRArray{Float64,1}},
100+
Ptr{TracedRArray{Float64,1}},
101+
Ptr{TracedRArray{Float64,1}},
102+
),
103+
(
104+
Core.LLVMPtr{Float64},
105+
Core.LLVMPtr{Float64},
106+
Core.LLVMPtr{TracedRNumber{Float64}},
107+
),
80108
(
81109
Core.LLVMPtr{ConcreteRArray{Float64,1}},
82110
Core.LLVMPtr{TracedRArray{Float64,1}},
83-
Core.LLVMPtr{TracedRArray{Float64,1}}
111+
Core.LLVMPtr{TracedRArray{Float64,1}},
112+
),
113+
(
114+
Base.RefValue{Float64},
115+
Base.RefValue{Float64},
116+
Base.RefValue{TracedRNumber{Float64}},
84117
),
85-
(Base.RefValue{Float64}, Base.RefValue{Float64}, Base.RefValue{TracedRNumber{Float64}}),
86118
(
87119
Base.RefValue{ConcreteRArray{Float64,1}},
88120
Base.RefValue{TracedRArray{Float64,1}},
@@ -93,23 +125,30 @@ using Test
93125
(Val{0}, Val{0}, Val{0}),
94126
(Val{0.5}, Val{0.5}, Val{0.5}),
95127
(Val{:x}, Val{:x}, Val{:x}),
96-
97-
98-
(Dict{Int, ConcreteRArray{Float64,0}}, Dict{Int, TracedRArray{Float64,0}}, Dict{Int, TracedRArray{Float64, 0}}),
128+
(
129+
Dict{Int,ConcreteRArray{Float64,0}},
130+
Dict{Int,TracedRArray{Float64,0}},
131+
Dict{Int,TracedRArray{Float64,0}},
132+
),
99133
(Dict{Int}, Dict{Int}, Dict{Int}),
100134
(Dict, Dict, Dict),
101-
((Dict{A, ConcreteRArray{Float64,0}} where A), (Dict{A, TracedRArray{Float64,0}} where A), (Dict{A, TracedRArray{Float64,0}} where A)),
102-
103-
(Base.Pairs{Symbol, Union{}}, Base.Pairs{Symbol, Union{}}, Base.Pairs{Symbol, Union{}})
135+
(
136+
(Dict{A,ConcreteRArray{Float64,0}} where {A}),
137+
(Dict{A,TracedRArray{Float64,0}} where {A}),
138+
(Dict{A,TracedRArray{Float64,0}} where {A}),
139+
),
140+
(
141+
Base.Pairs{Symbol,Union{}},
142+
Base.Pairs{Symbol,Union{}},
143+
Base.Pairs{Symbol,Union{}},
144+
),
104145
]
105146
tracedty = traced_type(
106-
origty, Val(ConcreteToTraced), Union{}
147+
origty, Val(ConcreteToTraced), Union{}, Reactant.BatchNone, nothing
107148
)
108149
@test tracedty == targetty
109150

110-
tracedty2 = traced_type(
111-
origty, Val(ConcreteToTraced), ReactantPrimitive
112-
)
151+
tracedty2 = traced_type(origty, Val(ConcreteToTraced), ReactantPrimitive)
113152
@test tracedty2 == targetty
114153
end
115154

@@ -120,21 +159,21 @@ using Test
120159
TracedRArray{Float64,3},
121160
]
122161
@test_throws Union{ErrorException,String} traced_type(
123-
type, Val(ConcreteToTraced), Union{}
162+
type, Val(ConcreteToTraced), Union{}, Reactant.BatchNone, nothing
124163
)
125164
end
126165
end
127166
@testset "traced_type exceptions" begin
128167
@test_throws TracedTypeError Reactant.traced_type(
129-
Real, Val(Reactant.ArrayToConcrete), Union{}
168+
Real, Val(Reactant.ArrayToConcrete), Union{}, Reactant.BatchNone, nothing
130169
)
131170

132171
struct Node
133172
x::Vector{Float64}
134173
y::Union{Nothing,Node}
135174
end
136175
@test_throws NoFieldMatchError traced_type(
137-
Node, Val(ArrayToConcrete), Union{}
176+
Node, Val(ArrayToConcrete), Union{}, Reactant.BatchNone, nothing
138177
)
139178
end
140179
end

0 commit comments

Comments
 (0)