|
1 | 1 | using DiffEqBase
|
| 2 | +import ChainRulesCore |
2 | 3 |
|
3 | 4 | # Define and register smooth functions
|
4 | 5 | # These are "smooth" aka differentiable and avoid Gibbs effect
|
@@ -496,19 +497,32 @@ Base.:+(x::Number, y::Parameter) = x + y.ref
|
496 | 497 | Base.:+(y::Parameter, x::Number) = Base.:+(x, y)
|
497 | 498 | Base.:+(x::Parameter, y::Parameter) = x.ref + y.ref
|
498 | 499 |
|
| 500 | +Base.:-(y::Parameter) = -y.ref |
499 | 501 | Base.:-(x::Number, y::Parameter) = x - y.ref
|
500 | 502 | Base.:-(y::Parameter, x::Number) = y.ref - x
|
501 | 503 | Base.:-(x::Parameter, y::Parameter) = x.ref - y.ref
|
502 | 504 |
|
503 |
| -Base.:^(x::Number, y::Parameter) = Base.power_by_squaring(x, y.ref) |
504 |
| -Base.:^(y::Parameter, x::Number) = Base.power_by_squaring(y.ref, x) |
505 |
| -Base.:^(x::Parameter, y::Parameter) = Base.power_by_squaring(x.ref, y.ref) |
| 505 | +Base.:^(x::Number, y::Parameter) = Base.:^(x, y.ref) |
| 506 | +Base.:^(y::Parameter, x::Number) = Base.:^(y.ref, x) |
| 507 | +Base.:^(x::Parameter, y::Parameter) = Base.:^(x.ref, y.ref) |
506 | 508 |
|
507 | 509 | Base.isless(x::Parameter, y::Number) = Base.isless(x.ref, y)
|
508 | 510 | Base.isless(y::Number, x::Parameter) = Base.isless(y, x.ref)
|
509 | 511 |
|
510 | 512 | Base.copy(x::Parameter{T}) where {T} = Parameter{T}(copy(x.data), x.ref)
|
511 | 513 |
|
| 514 | +ifelse(c::Bool, x::Parameter, y::Parameter) = ifelse(c, x.ref, y.ref) |
| 515 | +ifelse(c::Bool, x::Parameter, y::Number) = ifelse(c, x.ref, y) |
| 516 | +ifelse(c::Bool, x::Number, y::Parameter) = ifelse(c, x, y.ref) |
| 517 | + |
| 518 | +Base.max(x::Number, y::Parameter) = max(x, y.ref) |
| 519 | +Base.max(x::Parameter, y::Number) = max(x.ref, y) |
| 520 | +Base.max(x::Parameter, y::Parameter) = max(x.ref, y.ref) |
| 521 | + |
| 522 | +Base.min(x::Number, y::Parameter) = min(x, y.ref) |
| 523 | +Base.min(x::Parameter, y::Number) = min(x.ref, y) |
| 524 | +Base.min(x::Parameter, y::Parameter) = min(x.ref, y.ref) |
| 525 | + |
512 | 526 | function Base.show(io::IO, m::MIME"text/plain", p::Parameter)
|
513 | 527 | if !isempty(p.data)
|
514 | 528 | print(io, p.data)
|
@@ -575,6 +589,9 @@ function Symbolics.derivative(::typeof(get_sampled_data), args::NTuple{2, Any},
|
575 | 589 | memory = @inbounds args[2]
|
576 | 590 | first_order_backwards_difference(t, memory)
|
577 | 591 | end
|
| 592 | +function ChainRulesCore.frule((_, ẋ, _), ::typeof(get_sampled_data), t, memory) |
| 593 | + first_order_backwards_difference(t, memory) * ẋ |
| 594 | +end |
578 | 595 |
|
579 | 596 | """
|
580 | 597 | SampledData(; name, buffer)
|
@@ -614,6 +631,9 @@ function SampledData(data::Vector{T}, dt::T, circular_buffer = true; name) where
|
614 | 631 | end
|
615 | 632 |
|
616 | 633 | Base.convert(::Type{T}, x::Parameter{T}) where {T <: Real} = x.ref
|
| 634 | +function Base.convert(::Type{<:Parameter{T}}, x::Number) where {T <: Real} |
| 635 | + Parameter{T}(T[], x, true) |
| 636 | +end |
617 | 637 |
|
618 | 638 | # Beta Code for potential AE Hack ----------------------
|
619 | 639 | function set_sampled_data!(memory::Parameter{T}, t, x, Δt::Parameter{T}) where {T}
|
@@ -649,6 +669,14 @@ function Symbolics.derivative(::typeof(set_sampled_data!), args::NTuple{4, Any},
|
649 | 669 | first_order_backwards_difference(t, x, Δt, memory)
|
650 | 670 | end
|
651 | 671 | Symbolics.derivative(::typeof(set_sampled_data!), args::NTuple{4, Any}, ::Val{3}) = 1 #set_sampled_data returns x, therefore d/dx (x) = 1
|
| 672 | +function ChainRulesCore.frule((_, _, ṫ, ẋ, _), |
| 673 | + ::typeof(set_sampled_data!), |
| 674 | + memory, |
| 675 | + t, |
| 676 | + x, |
| 677 | + Δt) |
| 678 | + first_order_backwards_difference(t, x, Δt, memory) * ṫ + ẋ |
| 679 | +end |
652 | 680 |
|
653 | 681 | function first_order_backwards_difference(t, x, Δt, memory)
|
654 | 682 | x1 = set_sampled_data!(memory, t, x, Δt)
|
|
0 commit comments