Skip to content

Commit

Permalink
Merge pull request #396 from MilesCranmer/fix-dual-constraint
Browse files Browse the repository at this point in the history
fix: higher order safe operators
  • Loading branch information
MilesCranmer authored Dec 26, 2024
2 parents 1c64682 + 12449ca commit 5dbd23a
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SymbolicRegression"
uuid = "8254be44-1295-4e6a-a16d-46603ac705cb"
authors = ["MilesCranmer <[email protected]>"]
version = "1.5.0"
version = "1.5.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion src/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ const Dual = ForwardDiff.Dual
#binary: mod
#unary: exp, abs, log1p, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, erf, erfc, gamma, relu, round, floor, ceil, round, sign.

const FloatOrDual = Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}
const FloatOrDual = Union{AbstractFloat,Dual}
# Note that a complex dual is Complex{<:Dual}, so we are safe to use this signature.

# Use some fast operators from https://github.com/JuliaLang/julia/blob/81597635c4ad1e8c2e1c5753fda4ec0e7397543f/base/fastmath.jl
# Define allowed operators. Any julia operator can also be used.
Expand Down
38 changes: 38 additions & 0 deletions test/test_composable_expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,41 @@ end
X = stack(([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]); dims=1)
@test expr(X) [1.0, 2.0] .- sin.([3.0, 4.0] .- [5.0, 6.0]) .+ 2.5
end

@testitem "Test higher-order derivatives of safe_log with DynamicDiff" tags = [:part3] begin
using SymbolicRegression
using SymbolicRegression: D, safe_log, ValidVector
using DynamicExpressions: OperatorEnum
using ForwardDiff: DimensionMismatch

operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(safe_log,))
variable_names = ["x"]
x = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names)

# Test first and second derivatives of log(x)
structure = TemplateStructure{(:f,)}(
((; f), (x,)) ->
ValidVector([(f(x).x[1], D(f, 1)(x).x[1], D(D(f, 1), 1)(x).x[1])], true),
)
expr = TemplateExpression((; f=log(x)); structure, operators, variable_names)

# Test at x = 2.0 where log(x) is well-defined
X = [2.0]'
result = only(expr(X))
@test result !== nothing
@test result[1] == log(2.0) # function value
@test result[2] == 1 / 2.0 # first derivative
@test result[3] == -1 / 4.0 # second derivative

# We handle invalid ranges gracefully:
X_invalid = [-1.0]'
result = only(expr(X_invalid))
@test result !== nothing
@test isnan(result[1])
@test result[2] == 0.0
@test result[3] == 0.0

# Eventually we want to support complex numbers:
X_complex = [-1.0 - 1.0im]'
@test_throws DimensionMismatch expr(X_complex)
end

0 comments on commit 5dbd23a

Please sign in to comment.