@@ -56,33 +56,65 @@ using Test
56
56
(Complex{UInt128}, Complex{UInt128}, TracedRNumber{Complex{UInt128}}),
57
57
58
58
# RArray types
59
- (ConcreteRArray{Float64,0 }, TracedRArray{Float64,0 }, TracedRArray{Float64, 0 }),
60
- (ConcreteRArray{Float64,1 }, TracedRArray{Float64,1 }, TracedRArray{Float64, 1 }),
61
- (ConcreteRArray{Float64,2 }, TracedRArray{Float64,2 }, TracedRArray{Float64, 2 }),
62
- (ConcreteRArray{Float64,3 }, TracedRArray{Float64,3 }, TracedRArray{Float64, 3 }),
59
+ (
60
+ ConcreteRArray{Float64,0 },
61
+ TracedRArray{Float64,0 },
62
+ TracedRArray{Float64,0 },
63
+ ),
64
+ (
65
+ ConcreteRArray{Float64,1 },
66
+ TracedRArray{Float64,1 },
67
+ TracedRArray{Float64,1 },
68
+ ),
69
+ (
70
+ ConcreteRArray{Float64,2 },
71
+ TracedRArray{Float64,2 },
72
+ TracedRArray{Float64,2 },
73
+ ),
74
+ (
75
+ ConcreteRArray{Float64,3 },
76
+ TracedRArray{Float64,3 },
77
+ TracedRArray{Float64,3 },
78
+ ),
63
79
64
80
# Array types
65
- (Array{Float64,1 }, Array{Float64,1 }, Array{TracedRNumber{Float64}, 1 }),
66
- (Array{ConcreteRArray{Float64,2 },1 }, Array{TracedRArray{Float64,2 },1 }, Array{TracedRArray{Float64,2 }, 1 }),
81
+ (Array{Float64,1 }, Array{Float64,1 }, Array{TracedRNumber{Float64},1 }),
82
+ (
83
+ Array{ConcreteRArray{Float64,2 },1 },
84
+ Array{TracedRArray{Float64,2 },1 },
85
+ Array{TracedRArray{Float64,2 },1 },
86
+ ),
67
87
68
88
# Union types
69
- (Union{Nothing,Int}, Union{Nothing,Int}, Union{Nothing, TracedRNumber{Int}}),
89
+ (Union{Nothing,Int}, Union{Nothing,Int}, Union{Nothing,TracedRNumber{Int}}),
70
90
(
71
91
Union{Nothing,ConcreteRArray{Float64,1 }},
72
92
Union{Nothing,TracedRArray{Float64,1 }},
73
- Union{Nothing, TracedRArray{Float64, 1 }}
93
+ Union{Nothing,TracedRArray{Float64,1 }},
74
94
),
75
95
76
96
# Ptr types
77
97
(Ptr{Float64}, Ptr{Float64}, Ptr{TracedRNumber{Float64}}),
78
- (Ptr{ConcreteRArray{Float64,1 }}, Ptr{TracedRArray{Float64,1 }}, Ptr{TracedRArray{Float64,1 }}),
79
- (Core. LLVMPtr{Float64}, Core. LLVMPtr{Float64}, Core. LLVMPtr{TracedRNumber{Float64}}),
98
+ (
99
+ Ptr{ConcreteRArray{Float64,1 }},
100
+ Ptr{TracedRArray{Float64,1 }},
101
+ Ptr{TracedRArray{Float64,1 }},
102
+ ),
103
+ (
104
+ Core. LLVMPtr{Float64},
105
+ Core. LLVMPtr{Float64},
106
+ Core. LLVMPtr{TracedRNumber{Float64}},
107
+ ),
80
108
(
81
109
Core. LLVMPtr{ConcreteRArray{Float64,1 }},
82
110
Core. LLVMPtr{TracedRArray{Float64,1 }},
83
- Core. LLVMPtr{TracedRArray{Float64,1 }}
111
+ Core. LLVMPtr{TracedRArray{Float64,1 }},
112
+ ),
113
+ (
114
+ Base. RefValue{Float64},
115
+ Base. RefValue{Float64},
116
+ Base. RefValue{TracedRNumber{Float64}},
84
117
),
85
- (Base. RefValue{Float64}, Base. RefValue{Float64}, Base. RefValue{TracedRNumber{Float64}}),
86
118
(
87
119
Base. RefValue{ConcreteRArray{Float64,1 }},
88
120
Base. RefValue{TracedRArray{Float64,1 }},
@@ -93,23 +125,30 @@ using Test
93
125
(Val{0 }, Val{0 }, Val{0 }),
94
126
(Val{0.5 }, Val{0.5 }, Val{0.5 }),
95
127
(Val{:x }, Val{:x }, Val{:x }),
96
-
97
-
98
- (Dict{Int, ConcreteRArray{Float64,0 }}, Dict{Int, TracedRArray{Float64,0 }}, Dict{Int, TracedRArray{Float64, 0 }}),
128
+ (
129
+ Dict{Int,ConcreteRArray{Float64,0 }},
130
+ Dict{Int,TracedRArray{Float64,0 }},
131
+ Dict{Int,TracedRArray{Float64,0 }},
132
+ ),
99
133
(Dict{Int}, Dict{Int}, Dict{Int}),
100
134
(Dict, Dict, Dict),
101
- ((Dict{A, ConcreteRArray{Float64,0 }} where A), (Dict{A, TracedRArray{Float64,0 }} where A), (Dict{A, TracedRArray{Float64,0 }} where A)),
102
-
103
- (Base. Pairs{Symbol, Union{}}, Base. Pairs{Symbol, Union{}}, Base. Pairs{Symbol, Union{}})
135
+ (
136
+ (Dict{A,ConcreteRArray{Float64,0 }} where {A}),
137
+ (Dict{A,TracedRArray{Float64,0 }} where {A}),
138
+ (Dict{A,TracedRArray{Float64,0 }} where {A}),
139
+ ),
140
+ (
141
+ Base. Pairs{Symbol,Union{}},
142
+ Base. Pairs{Symbol,Union{}},
143
+ Base. Pairs{Symbol,Union{}},
144
+ ),
104
145
]
105
146
tracedty = traced_type (
106
- origty, Val (ConcreteToTraced), Union{}
147
+ origty, Val (ConcreteToTraced), Union{}, Reactant . BatchNone, nothing
107
148
)
108
149
@test tracedty == targetty
109
150
110
- tracedty2 = traced_type (
111
- origty, Val (ConcreteToTraced), ReactantPrimitive
112
- )
151
+ tracedty2 = traced_type (origty, Val (ConcreteToTraced), ReactantPrimitive)
113
152
@test tracedty2 == targetty
114
153
end
115
154
@@ -120,21 +159,21 @@ using Test
120
159
TracedRArray{Float64,3 },
121
160
]
122
161
@test_throws Union{ErrorException,String} traced_type (
123
- type, Val (ConcreteToTraced), Union{}
162
+ type, Val (ConcreteToTraced), Union{}, Reactant . BatchNone, nothing
124
163
)
125
164
end
126
165
end
127
166
@testset " traced_type exceptions" begin
128
167
@test_throws TracedTypeError Reactant. traced_type (
129
- Real, Val (Reactant. ArrayToConcrete), Union{}
168
+ Real, Val (Reactant. ArrayToConcrete), Union{}, Reactant . BatchNone, nothing
130
169
)
131
170
132
171
struct Node
133
172
x:: Vector{Float64}
134
173
y:: Union{Nothing,Node}
135
174
end
136
175
@test_throws NoFieldMatchError traced_type (
137
- Node, Val (ArrayToConcrete), Union{}
176
+ Node, Val (ArrayToConcrete), Union{}, Reactant . BatchNone, nothing
138
177
)
139
178
end
140
179
end
0 commit comments