Skip to content

Commit 81b2c38

Browse files
authored
Merge pull request #204 from SciML/myb/chainrules
Add derivative frules
2 parents 79e91be + de800db commit 81b2c38

File tree

5 files changed

+40
-5
lines changed

5 files changed

+40
-5
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ authors = ["Chris Rackauckas and Julia Computing"]
44
version = "2.1.0"
55

66
[deps]
7+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
78
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
89
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
1112
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1213

1314
[compat]
15+
ChainRulesCore = "1"
1416
DiffEqBase = "6"
1517
IfElse = "0.1"
1618
ModelingToolkit = "8.50"

src/Blocks/Blocks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ The module `Blocks` contains common input-output components, referred to as bloc
33
"""
44
module Blocks
55
using ModelingToolkit, Symbolics
6-
using IfElse: ifelse
6+
import IfElse: ifelse
77
import ..@symcheck
88
using ModelingToolkit: getdefault
99

src/Blocks/sources.jl

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using DiffEqBase
2+
import ChainRulesCore
23

34
# Define and register smooth functions
45
# These are "smooth" aka differentiable and avoid Gibbs effect
@@ -496,19 +497,32 @@ Base.:+(x::Number, y::Parameter) = x + y.ref
496497
Base.:+(y::Parameter, x::Number) = Base.:+(x, y)
497498
Base.:+(x::Parameter, y::Parameter) = x.ref + y.ref
498499

500+
Base.:-(y::Parameter) = -y.ref
499501
Base.:-(x::Number, y::Parameter) = x - y.ref
500502
Base.:-(y::Parameter, x::Number) = y.ref - x
501503
Base.:-(x::Parameter, y::Parameter) = x.ref - y.ref
502504

503-
Base.:^(x::Number, y::Parameter) = Base.power_by_squaring(x, y.ref)
504-
Base.:^(y::Parameter, x::Number) = Base.power_by_squaring(y.ref, x)
505-
Base.:^(x::Parameter, y::Parameter) = Base.power_by_squaring(x.ref, y.ref)
505+
Base.:^(x::Number, y::Parameter) = Base.:^(x, y.ref)
506+
Base.:^(y::Parameter, x::Number) = Base.:^(y.ref, x)
507+
Base.:^(x::Parameter, y::Parameter) = Base.:^(x.ref, y.ref)
506508

507509
Base.isless(x::Parameter, y::Number) = Base.isless(x.ref, y)
508510
Base.isless(y::Number, x::Parameter) = Base.isless(y, x.ref)
509511

510512
Base.copy(x::Parameter{T}) where {T} = Parameter{T}(copy(x.data), x.ref)
511513

514+
ifelse(c::Bool, x::Parameter, y::Parameter) = ifelse(c, x.ref, y.ref)
515+
ifelse(c::Bool, x::Parameter, y::Number) = ifelse(c, x.ref, y)
516+
ifelse(c::Bool, x::Number, y::Parameter) = ifelse(c, x, y.ref)
517+
518+
Base.max(x::Number, y::Parameter) = max(x, y.ref)
519+
Base.max(x::Parameter, y::Number) = max(x.ref, y)
520+
Base.max(x::Parameter, y::Parameter) = max(x.ref, y.ref)
521+
522+
Base.min(x::Number, y::Parameter) = min(x, y.ref)
523+
Base.min(x::Parameter, y::Number) = min(x.ref, y)
524+
Base.min(x::Parameter, y::Parameter) = min(x.ref, y.ref)
525+
512526
function Base.show(io::IO, m::MIME"text/plain", p::Parameter)
513527
if !isempty(p.data)
514528
print(io, p.data)
@@ -575,6 +589,9 @@ function Symbolics.derivative(::typeof(get_sampled_data), args::NTuple{2, Any},
575589
memory = @inbounds args[2]
576590
first_order_backwards_difference(t, memory)
577591
end
592+
function ChainRulesCore.frule((_, ẋ, _), ::typeof(get_sampled_data), t, memory)
593+
first_order_backwards_difference(t, memory) *
594+
end
578595

579596
"""
580597
SampledData(; name, buffer)
@@ -614,6 +631,9 @@ function SampledData(data::Vector{T}, dt::T, circular_buffer = true; name) where
614631
end
615632

616633
Base.convert(::Type{T}, x::Parameter{T}) where {T <: Real} = x.ref
634+
function Base.convert(::Type{<:Parameter{T}}, x::Number) where {T <: Real}
635+
Parameter{T}(T[], x, true)
636+
end
617637

618638
# Beta Code for potential AE Hack ----------------------
619639
function set_sampled_data!(memory::Parameter{T}, t, x, Δt::Parameter{T}) where {T}
@@ -649,6 +669,14 @@ function Symbolics.derivative(::typeof(set_sampled_data!), args::NTuple{4, Any},
649669
first_order_backwards_difference(t, x, Δt, memory)
650670
end
651671
Symbolics.derivative(::typeof(set_sampled_data!), args::NTuple{4, Any}, ::Val{3}) = 1 #set_sampled_data returns x, therefore d/dx (x) = 1
672+
function ChainRulesCore.frule((_, _, ṫ, ẋ, _),
673+
::typeof(set_sampled_data!),
674+
memory,
675+
t,
676+
x,
677+
Δt)
678+
first_order_backwards_difference(t, x, Δt, memory) *+
679+
end
652680

653681
function first_order_backwards_difference(t, x, Δt, memory)
654682
x1 = set_sampled_data!(memory, t, x, Δt)

src/Hydraulic/IsothermalCompressible/utils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import ChainRulesCore
2+
13
regPow(x, a, delta = 0.01) = x * (x * x + delta * delta)^((a - 1) / 2);
24
regRoot(x, delta = 0.01) = regPow(x, 0.5, delta)
35

@@ -116,6 +118,9 @@ end
116118
@register_symbolic friction_factor(dm, area, d_h, viscosity, shape_factor)
117119
Symbolics.derivative(::typeof(friction_factor), args, ::Val{1}) = 0
118120
Symbolics.derivative(::typeof(friction_factor), args, ::Val{4}) = 0
121+
function ChainRulesCore.frule(_, ::typeof(friction_factor), args...)
122+
(friction_factor(args...), ChainRulesCore.ZeroTangent)
123+
end
119124

120125
function transition(x1, x2, y1, y2, x)
121126
u = (x - x1) / (x2 - x1)

src/Mechanical/Translational/Translational.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ include("utils.jl")
1818
export Mass, Spring, Damper, Fixed
1919
include("components.jl")
2020

21-
export Force, Position
21+
export Force, Position, Velocity, Acceleration
2222
include("sources.jl")
2323

2424
export ForceSensor, PositionSensor

0 commit comments

Comments
 (0)