Skip to content

Commit 7e42d0e

Browse files
authored
Merge pull request #146 from SciML/ap/fix_tests
Retrigger Tests
2 parents 61970f7 + 7bc6dba commit 7e42d0e

File tree

12 files changed

+77
-49
lines changed

12 files changed

+77
-49
lines changed

.github/workflows/CI.yml

-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ jobs:
2323
matrix:
2424
version:
2525
- '1'
26-
- '~1.10.0-0'
2726
steps:
2827
- uses: actions/checkout@v4
2928
- uses: julia-actions/setup-julia@v1

Project.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
name = "DeepEquilibriumNetworks"
22
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
33
authors = ["Avik Pal <[email protected]>"]
4-
version = "2.0.2"
4+
version = "2.0.3"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1111
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
12+
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1314
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
15+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1416
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1517
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1618
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -32,9 +34,11 @@ ChainRulesCore = "1"
3234
ConcreteStructs = "0.2"
3335
ConstructionBase = "1"
3436
DiffEqBase = "6.119"
37+
FastClosures = "0.3"
3538
LinearAlgebra = "1"
3639
LinearSolve = "2.21.2"
3740
Lux = "0.5.11"
41+
PrecompileTools = "1"
3842
Random = "1"
3943
SciMLBase = "2"
4044
SciMLSensitivity = "7.43"

docs/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
55
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
6+
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
67
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
78
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
89
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
@@ -20,6 +21,7 @@ DeepEquilibriumNetworks = "2"
2021
Documenter = "1"
2122
DocumenterCitations = "1"
2223
LinearSolve = "2"
24+
LoggingExtras = "1"
2325
Lux = "0.5"
2426
LuxCUDA = "0.3"
2527
MLDataUtils = "0.5"

docs/ref.bib

-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ @article{baideep2019
1919
year = {2019},
2020
note = {arXiv: 1909.01377},
2121
keywords = {Statistics - Machine Learning, Computer Science - Machine Learning},
22-
annote = {Comment: NeurIPS 2019 Spotlight Oral},
2322
file = {Bai et al. - 2019 - Deep Equilibrium Models.pdf:files/245/Bai et al. - 2019 - Deep Equilibrium Models.pdf:application/pdf},
2423
}
2524

@@ -35,7 +34,6 @@ @article{baimultiscale2020
3534
year = {2020},
3635
note = {arXiv: 2006.08656},
3736
keywords = {Statistics - Machine Learning, Computer Science - Machine Learning, Computer Science - Computer Vision and Pattern Recognition},
38-
annote = {Comment: NeurIPS 2020 Oral},
3937
file = {Bai et al. - 2020 - Multiscale Deep Equilibrium Models.pdf:files/248/Bai et al. - 2020 - Multiscale Deep Equilibrium Models.pdf:application/pdf},
4038
}
4139

docs/src/tutorials/basic_mnist_deq.md

+19-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack
44

55
```@example basic_mnist_deq
66
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
7-
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve
7+
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
88
using MLDatasets: MNIST
99
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
1010
@@ -20,6 +20,18 @@ const cdev = cpu_device()
2020
const gdev = gpu_device()
2121
```
2222

23+
SciMLBase now emits a regular warning instead of a `depwarn`, which pollutes the output. We can suppress
24+
it with the following logger:
25+
26+
```@example basic_mnist_deq
27+
function remove_syms_warning(log_args)
28+
return log_args.message !=
29+
"The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead."
30+
end
31+
32+
filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger())
33+
```
34+
2335
We can now construct our dataloader.
2436

2537
```@example basic_mnist_deq
@@ -175,15 +187,19 @@ and end up using solvers like `Broyden`, but we can simply slap in any of the fa
175187
from NonlinearSolve.jl. Here we will use the Newton-Krylov method:
176188

177189
```@example basic_mnist_deq
178-
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq)
190+
with_logger(filtered_logger) do
191+
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq)
192+
end
179193
nothing # hide
180194
```
181195

182196
We can also train a continuous DEQ by passing in an ODE solver. Here we will use `VCAB3()`
183197
which tends to be quite fast for continuous Neural Network problems.
184198

185199
```@example basic_mnist_deq
186-
train_model(VCAB3(), :deq)
200+
with_logger(filtered_logger) do
201+
train_model(VCAB3(), :deq)
202+
end
187203
nothing # hide
188204
```
189205

docs/src/tutorials/reduced_dim_deq.md

+14-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ same MNIST example as before, but this time we will use a reduced state size.
66

77
```@example reduced_dim_mnist
88
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
9-
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve
9+
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
1010
using MLDatasets: MNIST
1111
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
1212
@@ -16,6 +16,13 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true
1616
const cdev = cpu_device()
1717
const gdev = gpu_device()
1818
19+
function remove_syms_warning(log_args)
20+
return log_args.message !=
21+
"The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead."
22+
end
23+
24+
filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger())
25+
1926
function onehot(labels_raw)
2027
return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
2128
end
@@ -168,11 +175,15 @@ Now we can train our model. We can't use `:regdeq` here currently, but we will s
168175
in the future.
169176

170177
```@example reduced_dim_mnist
171-
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq)
178+
with_logger(filtered_logger) do
179+
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq)
180+
end
172181
nothing # hide
173182
```
174183

175184
```@example reduced_dim_mnist
176-
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq)
185+
with_logger(filtered_logger) do
186+
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq)
187+
end
177188
nothing # hide
178189
```

src/DeepEquilibriumNetworks.jl

+12-9
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
module DeepEquilibriumNetworks
22

3-
using ADTypes,
4-
DiffEqBase, LinearAlgebra, Lux, Random, SciMLBase, Statistics, SteadyStateDiffEq
3+
import PrecompileTools: @recompile_invalidations
54

6-
import ChainRulesCore as CRC
7-
import ConcreteStructs: @concrete
8-
import ConstructionBase: constructorof
9-
import Lux: AbstractExplicitLayer, AbstractExplicitContainerLayer
10-
import TruncatedStacktraces: @truncate_stacktrace
5+
@recompile_invalidations begin
6+
using ADTypes, DiffEqBase, FastClosures, LinearAlgebra, Lux, Random, SciMLBase,
7+
Statistics, SteadyStateDiffEq
118

12-
import SciMLBase: AbstractNonlinearAlgorithm,
13-
AbstractODEAlgorithm, _unwrap_val, NonlinearSolution
9+
import ChainRulesCore as CRC
10+
import ConcreteStructs: @concrete
11+
import ConstructionBase: constructorof
12+
import Lux: AbstractExplicitLayer, AbstractExplicitContainerLayer
13+
import SciMLBase: AbstractNonlinearAlgorithm,
14+
AbstractODEAlgorithm, _unwrap_val, NonlinearSolution
15+
import TruncatedStacktraces: @truncate_stacktrace
16+
end
1417

1518
# Useful Constants
1619
const DEQs = DeepEquilibriumNetworks

src/layers.jl

+8-3
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,12 @@ function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType}
9797

9898
model = Lux.Experimental.StatefulLuxLayer(deq.model, nothing, st.model)
9999

100-
dudt(u, p, t) = model((u, p.x), p.ps) .- u
100+
dudt = @closure (u, p, t) -> begin
101+
# The type-assert is needed because of an upstream Lux issue with type stability of
102+
# conv with Dual numbers
103+
y = model((u, p.x), p.ps)::typeof(u)
104+
return y .- u
105+
end
101106

102107
prob = __construct_prob(pType, ODEFunction{false}(dudt), z, (; ps=ps.model, x))
103108
alg = __normalize_alg(deq)
@@ -144,7 +149,7 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing]
144149
145150
## Example
146151
147-
```jldoctest
152+
```julia
148153
julia> using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
149154
150155
julia> model = DeepEquilibriumNetwork(Parallel(+, Dense(2, 2; use_bias=false),
@@ -225,7 +230,7 @@ For keyword arguments, see [`DeepEquilibriumNetwork`](@ref).
225230
226231
## Example
227232
228-
```jldoctest
233+
```julia
229234
julia> using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve
230235
231236
julia> main_layers = (Parallel(+, Dense(4 => 4, tanh; use_bias=false),

test/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1919
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2020
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
2121
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
22+
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
2223
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2324

2425
[compat]

test/layers.jl

+10-14
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ end
2222
x_sizes = [(2, 14), (3, 3, 1, 3)]
2323

2424
model_type = (:deq, :skipdeq, :skipregdeq)
25-
solvers = (VCAB3(), Tsit5(), NewtonRaphson(), SimpleLimitedMemoryBroyden())
25+
solvers = (VCAB3(), Tsit5(),
26+
NewtonRaphson(; autodiff=AutoForwardDiff(; chunksize=12)),
27+
SimpleLimitedMemoryBroyden())
2628
jacobian_regularizations = Any[nothing, AutoZygote()]
2729
!ongpu && push!(jacobian_regularizations, AutoFiniteDiff())
2830

@@ -31,8 +33,6 @@ end
3133

3234
@testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip(base_models,
3335
init_models, x_sizes)
34-
@info solver, mtype, jacobian_regularization, base_model, init_model, x_size
35-
3636
model = if mtype === :deq
3737
DeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
3838
elseif mtype === :skipdeq
@@ -48,9 +48,8 @@ end
4848
x = randn(rng, Float32, x_size...) |> dev
4949
z, st = model(x, ps, st)
5050

51-
opt_broken = solver isa NewtonRaphson ||
52-
solver isa SimpleLimitedMemoryBroyden
53-
@jet model(x, ps, st) opt_broken=opt_broken # Broken due to nfe dynamic dispatch
51+
opt_broken = solver isa SimpleLimitedMemoryBroyden
52+
@jet model(x, ps, st) opt_broken=opt_broken
5453

5554
@test all(isfinite, z)
5655
@test size(z) == size(x)
@@ -107,20 +106,18 @@ end
107106
scales = [((4,), (3,), (2,), (1,))]
108107

109108
model_type = (:deq, :skipdeq, :skipregdeq, :node)
110-
solvers = (VCAB3(), Tsit5(), NewtonRaphson(), SimpleLimitedMemoryBroyden())
109+
solvers = (VCAB3(), Tsit5(),
110+
NewtonRaphson(; autodiff=AutoForwardDiff(; chunksize=12)),
111+
SimpleLimitedMemoryBroyden())
111112
jacobian_regularizations = (nothing,)
112113

113114
for mtype in model_type, jacobian_regularization in jacobian_regularizations
114115
@testset "Solver: $(__nameof(solver))" for solver in solvers
115116
@testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(main_layers,
116117
mapping_layers, init_layers, x_sizes, scales)
117-
@info solver, mtype, jacobian_regularization, main_layer, mapping_layer,
118-
init_layer, x_size, scale
119-
120118
model = if mtype === :deq
121119
MultiScaleDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,
122-
solver,
123-
scale; jacobian_regularization)
120+
solver, scale; jacobian_regularization)
124121
elseif mtype === :skipdeq
125122
MultiScaleSkipDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,
126123
init_layer, solver, scale; jacobian_regularization)
@@ -140,8 +137,7 @@ end
140137
z, st = model(x, ps, st)
141138
z_ = DEQs.__flatten_vcat(z)
142139

143-
opt_broken = solver isa NewtonRaphson ||
144-
solver isa SimpleLimitedMemoryBroyden
140+
opt_broken = solver isa SimpleLimitedMemoryBroyden
145141
@jet model(x, ps, st) opt_broken=opt_broken # Broken due to nfe dynamic dispatch
146142

147143
@test all(isfinite, z_)

test/qa.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@ import ChainRulesCore as CRC
33

44
@testset "Aqua" begin
55
Aqua.test_all(DeepEquilibriumNetworks; ambiguities=false)
6-
Aqua.test_ambiguities(DeepEquilibriumNetworks; recursive=false,
7-
exclude=[CRC.rrule, CRC.frule])
6+
Aqua.test_ambiguities(DeepEquilibriumNetworks; recursive=false)
87
end

test/runtests.jl

+5-11
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,7 @@
1-
using SafeTestsets, Test
1+
using SafeTestsets, Test, TestSetExtensions
22

3-
@testset "Deep Equilibrium Networks" begin
4-
@safetestset "Quality Assurance" begin
5-
include("qa.jl")
6-
end
7-
@safetestset "Utilities" begin
8-
include("utils.jl")
9-
end
10-
@safetestset "Layers" begin
11-
include("layers.jl")
12-
end
3+
@testset ExtendedTestSet "Deep Equilibrium Networks" begin
4+
@safetestset "Quality Assurance" include("qa.jl")
5+
@safetestset "Utilities" include("utils.jl")
6+
@safetestset "Layers" include("layers.jl")
137
end

0 commit comments

Comments
 (0)