|
22 | 22 | x_sizes = [(2, 14), (3, 3, 1, 3)]
|
23 | 23 |
|
24 | 24 | model_type = (:deq, :skipdeq, :skipregdeq)
|
25 |
| - solvers = (VCAB3(), Tsit5(), NewtonRaphson(), SimpleLimitedMemoryBroyden()) |
| 25 | + solvers = (VCAB3(), Tsit5(), |
| 26 | + NewtonRaphson(; autodiff=AutoForwardDiff(; chunksize=12)), |
| 27 | + SimpleLimitedMemoryBroyden()) |
26 | 28 | jacobian_regularizations = Any[nothing, AutoZygote()]
|
27 | 29 | !ongpu && push!(jacobian_regularizations, AutoFiniteDiff())
|
28 | 30 |
|
|
31 | 33 |
|
32 | 34 | @testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip(base_models,
|
33 | 35 | init_models, x_sizes)
|
34 |
| - @info solver, mtype, jacobian_regularization, base_model, init_model, x_size |
35 |
| - |
36 | 36 | model = if mtype === :deq
|
37 | 37 | DeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
|
38 | 38 | elseif mtype === :skipdeq
|
|
48 | 48 | x = randn(rng, Float32, x_size...) |> dev
|
49 | 49 | z, st = model(x, ps, st)
|
50 | 50 |
|
51 |
| - opt_broken = solver isa NewtonRaphson || |
52 |
| - solver isa SimpleLimitedMemoryBroyden |
53 |
| - @jet model(x, ps, st) opt_broken=opt_broken # Broken due to nfe dynamic dispatch |
| 51 | + opt_broken = solver isa SimpleLimitedMemoryBroyden |
| 52 | + @jet model(x, ps, st) opt_broken=opt_broken |
54 | 53 |
|
55 | 54 | @test all(isfinite, z)
|
56 | 55 | @test size(z) == size(x)
|
@@ -107,20 +106,18 @@ end
|
107 | 106 | scales = [((4,), (3,), (2,), (1,))]
|
108 | 107 |
|
109 | 108 | model_type = (:deq, :skipdeq, :skipregdeq, :node)
|
110 |
| - solvers = (VCAB3(), Tsit5(), NewtonRaphson(), SimpleLimitedMemoryBroyden()) |
| 109 | + solvers = (VCAB3(), Tsit5(), |
| 110 | + NewtonRaphson(; autodiff=AutoForwardDiff(; chunksize=12)), |
| 111 | + SimpleLimitedMemoryBroyden()) |
111 | 112 | jacobian_regularizations = (nothing,)
|
112 | 113 |
|
113 | 114 | for mtype in model_type, jacobian_regularization in jacobian_regularizations
|
114 | 115 | @testset "Solver: $(__nameof(solver))" for solver in solvers
|
115 | 116 | @testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(main_layers,
|
116 | 117 | mapping_layers, init_layers, x_sizes, scales)
|
117 |
| - @info solver, mtype, jacobian_regularization, main_layer, mapping_layer, |
118 |
| - init_layer, x_size, scale |
119 |
| - |
120 | 118 | model = if mtype === :deq
|
121 | 119 | MultiScaleDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,
|
122 |
| - solver, |
123 |
| - scale; jacobian_regularization) |
| 120 | + solver, scale; jacobian_regularization) |
124 | 121 | elseif mtype === :skipdeq
|
125 | 122 | MultiScaleSkipDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,
|
126 | 123 | init_layer, solver, scale; jacobian_regularization)
|
|
140 | 137 | z, st = model(x, ps, st)
|
141 | 138 | z_ = DEQs.__flatten_vcat(z)
|
142 | 139 |
|
143 |
| - opt_broken = solver isa NewtonRaphson || |
144 |
| - solver isa SimpleLimitedMemoryBroyden |
| 140 | + opt_broken = solver isa SimpleLimitedMemoryBroyden |
145 | 141 | @jet model(x, ps, st) opt_broken=opt_broken # Broken due to nfe dynamic dispatch
|
146 | 142 |
|
147 | 143 | @test all(isfinite, z_)
|
|
0 commit comments