Skip to content

Commit

Permalink
Merge pull request #743 from SciML/scimlsensitivity
Browse files Browse the repository at this point in the history
update to SciMLSensitivity
  • Loading branch information
ChrisRackauckas authored Jun 25, 2022
2 parents 4dcb907 + 4f3d433 commit e32422d
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 13 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
name = "DiffEqFlux"
uuid = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
authors = ["Chris Rackauckas <[email protected]>"]
version = "1.50.0"
version = "1.51.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Expand All @@ -32,6 +31,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand All @@ -43,7 +43,6 @@ Cassette = "0.3.7"
ConsoleProgressMonitor = "0.1"
DataInterpolations = "3.3"
DiffEqBase = "6.41"
DiffEqSensitivity = "6.65"
DiffResults = "1.0"
Distributions = "0.23, 0.24, 0.25"
DistributionsAD = "0.6"
Expand All @@ -62,6 +61,7 @@ RecursiveArrayTools = "2"
Reexport = "0.2, 1"
Requires = "0.5, 1.0"
SciMLBase = "1"
SciMLSensitivity = "7"
StaticArrays = "0.11, 0.12, 1"
TerminalLoggers = "0.1"
Zygote = "0.5, 0.6"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
[deps]
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -18,6 +17,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"

Expand Down
4 changes: 2 additions & 2 deletions docs/src/examples/GPUs.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ condition is a GPU array. Thus, for example, we can define a neural ODE by hand
that runs on the GPU (if no GPU is available, the calculation defaults back to the CPU):

```julia
using DifferentialEquations, Flux, DiffEqFlux, DiffEqSensitivity
using DifferentialEquations, Flux, DiffEqFlux, SciMLSensitivity

using Random
rng = Random.default_rng()
Expand Down Expand Up @@ -70,7 +70,7 @@ same code works on CPUs and GPUs, dependent on `using CUDA`.

```julia
using Flux, DiffEqFlux, Optimization, OptimizationFlux, Zygote,
OrdinaryDiffEq, Plots, CUDA, DiffEqSensitivity, Random, ComponentArrays
OrdinaryDiffEq, Plots, CUDA, SciMLSensitivity, Random, ComponentArrays
CUDA.allowscalar(false) # Makes sure no slow operations are occuring

#rng for Lux.setup
Expand Down
2 changes: 1 addition & 1 deletion docs/src/examples/collocation.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pretraining the neural network against a smoothed collocation of the
data. First the example and then an explanation.

```@example collocation_cp
using Lux, DiffEqFlux, OrdinaryDiffEq, DiffEqSensitivity, Optimization, OptimizationFlux, Plots
using Lux, DiffEqFlux, OrdinaryDiffEq, SciMLSensitivity, Optimization, OptimizationFlux, Plots
using Random
rng = Random.default_rng()
Expand Down
4 changes: 2 additions & 2 deletions src/DiffEqFlux.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
module DiffEqFlux

using Adapt, Base.Iterators, ConsoleProgressMonitor, DataInterpolations,
DiffEqBase, DiffEqSensitivity, DiffResults, Distributions, DistributionsAD,
DiffEqBase, SciMLSensitivity, DiffResults, Distributions, DistributionsAD,
ForwardDiff, Optimization, OptimizationPolyalgorithms, LinearAlgebra,
Logging, LoggingExtras, Printf, ProgressLogging, Random, RecursiveArrayTools,
Reexport, SciMLBase, StaticArrays, TerminalLoggers, Zygote, ZygoteRules

@reexport using DiffEqSensitivity
@reexport using SciMLSensitivity
@reexport using Zygote

# deprecate
Expand Down
4 changes: 2 additions & 2 deletions src/fast_layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,12 +407,12 @@ paramlength(f::StaticDense{out,in,bias}) where {out,in,bias} = out*(in + bias)
initial_params(f::StaticDense) = f.initial_params()

# Override FastDense to exclude the branch from the check
function Cassette.overdub(ctx::DiffEqSensitivity.HasBranchingCtx, f::FastDense, x, p)
function Cassette.overdub(ctx::SciMLSensitivity.HasBranchingCtx, f::FastDense, x, p)
y = reshape(p[1:(f.out*f.in)],f.out,f.in)*x
Cassette.@overdub ctx f.σ.(y)
end

function Cassette.overdub(ctx::DiffEqSensitivity.HasBranchingCtx, f::StaticDense{out,in,bias}, x, p) where {out,in,bias}
function Cassette.overdub(ctx::SciMLSensitivity.HasBranchingCtx, f::StaticDense{out,in,bias}, x, p) where {out,in,bias}
y = reshape(p[1:(out*in)],out,in)*x
Cassette.@overdub ctx f.σ.(y)
end
4 changes: 2 additions & 2 deletions src/neural_de.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ derivatives of the loss backwards in time.
```julia
NeuralODE(model,tspan,alg=nothing,args...;kwargs...)
NeuralODE(model::FastChain,tspan,alg=nothing,args...;
sensealg=InterpolatingAdjoint(autojacvec=DiffEqSensitivity.ReverseDiffVJP(true)),
sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.ReverseDiffVJP(true)),
kwargs...)
```
Expand Down Expand Up @@ -490,7 +490,7 @@ the constraint equations.
```julia
NeuralODEMM(model,constraints_model,tspan,mass_matrix,alg=nothing,args...;kwargs...)
NeuralODEMM(model::FastChain,tspan,mass_matrix,alg=nothing,args...;
sensealg=InterpolatingAdjoint(autojacvec=DiffEqSensitivity.ReverseDiffVJP(true)),
sensealg=InterpolatingAdjoint(autojacvec=SciMLSensitivity.ReverseDiffVJP(true)),
kwargs...)
```
Expand Down

0 comments on commit e32422d

Please sign in to comment.