diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 5deca92df..900a7b375 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -30,8 +30,7 @@ jobs:
- "part2"
- "part3"
julia-version:
- - "1.6"
- - "1.8"
+ - "1.10"
- "1"
os:
- ubuntu-latest
@@ -54,15 +53,6 @@ jobs:
- os: macOS-latest
julia-version: "1"
test: "part3"
- - os: ubuntu-latest
- julia-version: "~1.11.0-0"
- test: "part1"
- - os: ubuntu-latest
- julia-version: "~1.11.0-0"
- test: "part2"
- - os: ubuntu-latest
- julia-version: "~1.11.0-0"
- test: "part3"
steps:
- uses: actions/checkout@v4
diff --git a/Project.toml b/Project.toml
index 6f899cb60..9f0bf2341 100644
--- a/Project.toml
+++ b/Project.toml
@@ -18,7 +18,6 @@ LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
-PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -56,7 +55,6 @@ LossFunctions = "0.10, 0.11"
MLJModelInterface = "~1.5, ~1.6, ~1.7, ~1.8, ~1.9, ~1.10, ~1.11"
MacroTools = "0.4, 0.5"
Optim = "~1.8, ~1.9"
-PackageExtensionCompat = "1"
Pkg = "<0.0.1, 1"
PrecompileTools = "1"
Printf = "<0.0.1, 1"
@@ -67,7 +65,7 @@ SpecialFunctions = "0.10.1, 1, 2"
StatsBase = "0.33, 0.34"
SymbolicUtils = "0.19, ^1.0.5, 2, 3"
TOML = "<0.0.1, 1"
-julia = "1.6"
+julia = "1.10"
[extras]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl
index 7ec312618..9e4862348 100644
--- a/benchmark/benchmarks.jl
+++ b/benchmark/benchmarks.jl
@@ -1,6 +1,8 @@
using BenchmarkTools
using SymbolicRegression, BenchmarkTools, Random
using SymbolicRegression.AdaptiveParsimonyModule: RunningSearchStatistics
+using SymbolicRegression.MutateModule: next_generation
+using SymbolicRegression.RecorderModule: RecordType
using SymbolicRegression.PopulationModule: best_of_sample
using SymbolicRegression.ConstantOptimizationModule: optimize_constants
using SymbolicRegression.CheckConstraintsModule: check_constraints
@@ -93,6 +95,57 @@ function create_utils_benchmark()
)
)
+ suite["next_generation_x100"] = @benchmarkable(
+ let
+ for member in members
+ next_generation(
+ dataset,
+ member,
+ temperature,
+ curmaxsize,
+ rss,
+ options;
+ tmp_recorder=recorder,
+ )
+ end
+ end,
+ setup = (
+ nfeatures = 1;
+ dataset = Dataset(randn(nfeatures, 32), randn(32));
+ mutation_weights = MutationWeights(;
+ mutate_constant=1.0,
+ mutate_operator=1.0,
+ swap_operands=1.0,
+ rotate_tree=1.0,
+ add_node=1.0,
+ insert_node=1.0,
+ simplify=0.0,
+ randomize=0.0,
+ do_nothing=0.0,
+ form_connection=0.0,
+ break_connection=0.0,
+ );
+ options = Options(;
+ unary_operators=[sin, cos], binary_operators=[+, -, *, /], mutation_weights
+ );
+ recorder = RecordType();
+ temperature = 1.0;
+ curmaxsize = 20;
+ rss = RunningSearchStatistics(; options);
+ trees = [
+ gen_random_tree_fixed_size(15, options, nfeatures, Float64) for _ in 1:100
+ ];
+ expressions = [
+ Expression(tree; operators=options.operators, variable_names=["x1"]) for
+ tree in trees
+ ];
+ members = [
+ PopMember(dataset, expression, options; deterministic=false) for
+ expression in expressions
+ ]
+ )
+ )
+
ntrees = 10
suite["optimize_constants_x10"] = @benchmarkable(
foreach(members) do member
diff --git a/docs/make.jl b/docs/make.jl
index f3ad5f756..b0546821b 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -1,7 +1,21 @@
using Documenter
using SymbolicUtils
using SymbolicRegression
-using SymbolicRegression: Dataset, update_baseline_loss!
+using SymbolicRegression:
+ AbstractExpression,
+ ExpressionInterface,
+ Dataset,
+ update_baseline_loss!,
+ AbstractMutationWeights,
+ AbstractOptions,
+ mutate!,
+ condition_mutation_weights!,
+ sample_mutation,
+ MutationResult,
+ AbstractRuntimeOptions,
+ AbstractSearchState,
+ @extend_operators
+using DynamicExpressions
DocMeta.setdocmeta!(
SymbolicRegression, :DocTestSetup, :(using LossFunctions); recursive=true
@@ -40,14 +54,8 @@ readme = replace(
# We prepend the `<div align="center">` with a ```@raw html
# and append the `</div>` with a ```:
-readme = replace(
-    readme,
-    r"<div align=\"center\">" => s"```@raw html\n<div align=\"center\">",
-)
-readme = replace(
-    readme,
-    r"</div>" => s"</div>\n```",
-)
+readme = replace(readme, r"<div align=\"center\">" => s"```@raw html\n<div align=\"center\">")
+readme = replace(readme, r"</div>" => s"</div>\n```")

# Then, we surround ```mermaid\n...\n``` snippets
# with ```@raw html\n<div class="mermaid">\n...\n</div>\n```:
@@ -96,6 +104,7 @@ makedocs(;
"API" => "api.md",
"Losses" => "losses.md",
"Types" => "types.md",
+ "Customization" => "customization.md",
],
)
diff --git a/docs/src/api.md b/docs/src/api.md
index d9ac1fa97..b7e0f7e28 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -7,7 +7,7 @@ SRRegressor
MultitargetSRRegressor
```
-## equation_search
+## Low-Level API
```@docs
equation_search
diff --git a/docs/src/customization.md b/docs/src/customization.md
new file mode 100644
index 000000000..2a9d9a072
--- /dev/null
+++ b/docs/src/customization.md
@@ -0,0 +1,61 @@
+# Customization
+
+Many parts of SymbolicRegression.jl are designed to be customizable.
+
+The normal way to do this in Julia is to define a new type that subtypes
+an abstract type from a package, and then extend the package's internal
+methods by defining them on your new type.
+
+## Custom Options
+
+For example, you can define a custom options type:
+
+```@docs
+AbstractOptions
+```
+
+For most functions in SymbolicRegression.jl, you can define a new method
+on your custom options type to customize behavior.
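+
+As a minimal sketch (the `MyOptions` name and `my_parameter` field are purely
+illustrative), you might wrap the standard `Options` and forward field access
+to it:
+
+```julia
+using SymbolicRegression
+
+struct MyOptions{O<:Options} <: SymbolicRegression.AbstractOptions
+    my_parameter::Float64  # hypothetical extra setting
+    sr_options::O          # the standard options, wrapped
+end
+
+# Forward everything else to the wrapped `Options`, so the rest of the
+# search machinery finds the fields it expects:
+function Base.getproperty(options::MyOptions, k::Symbol)
+    if k in (:my_parameter, :sr_options)
+        return getfield(options, k)
+    else
+        return getproperty(getfield(options, :sr_options), k)
+    end
+end
+```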
+
+## Custom Mutations
+
+You can define custom mutation operators by defining a new method on
+`mutate!`, as well as subtyping `AbstractMutationWeights`:
+
+```@docs
+mutate!
+AbstractMutationWeights
+condition_mutation_weights!
+sample_mutation
+MutationResult
+```
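+
+As a sketch (the `MyMutationWeights` type and `:my_mutation` field are
+hypothetical, and a real subtype should also carry the standard weights the
+search machinery expects), a new mutation could be wired up as:
+
+```julia
+using SymbolicRegression
+using SymbolicRegression: AbstractMutationWeights, AbstractOptions, MutationResult
+
+Base.@kwdef mutable struct MyMutationWeights <: AbstractMutationWeights
+    mutate_constant::Float64 = 0.05
+    mutate_operator::Float64 = 0.5
+    my_mutation::Float64 = 0.1  # weight for the new, custom mutation
+end
+
+# Mutations are sampled from the field names of the weights type,
+# and each one dispatches to `mutate!` via `Val`:
+function SymbolicRegression.mutate!(
+    tree::N, member::P, ::Val{:my_mutation}, weights::AbstractMutationWeights,
+    options::AbstractOptions; kws...,
+) where {N,P}
+    # ... rewrite `tree` here ...
+    return MutationResult{N,P}(; tree=tree)
+end
+```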
+
+## Custom Expressions
+
+You can create your own expression types by defining a new type that extends `AbstractExpression`.
+
+```@docs
+AbstractExpression
+ExpressionInterface
+```
+
+The interface is fairly flexible, and permits you to define specific functional forms,
+extra parameters, etc. See the documentation of DynamicExpressions.jl for more details on what
+methods you need to implement. Then, for SymbolicRegression.jl, you would
+pass `expression_type` to the `Options` constructor, as well as any
+`expression_options` you need (as a `NamedTuple`).
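+
+For example, using the built-in `ParametricExpression` (a sketch; the
+`max_parameters` option is specific to parametric expressions):
+
+```julia
+using SymbolicRegression
+
+options = Options(;
+    binary_operators=[+, -, *],
+    expression_type=ParametricExpression,
+    expression_options=(; max_parameters=2),
+)
+```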
+
+You may also need to overload `SymbolicRegression.ExpressionBuilder.extra_init_params` if
+your expression needs additional parameters. See the method for `ParametricExpression`
+as an example.
+
+## Other Customizations
+
+Other internal abstract types include the following:
+
+```@docs
+AbstractRuntimeOptions
+AbstractSearchState
+```
+
+These let you include custom state variables and runtime options.
diff --git a/docs/src/types.md b/docs/src/types.md
index 5c7277c57..92bf5632e 100644
--- a/docs/src/types.md
+++ b/docs/src/types.md
@@ -60,24 +60,6 @@ ParametricNode
These types allow you to define expressions with parameters that can be tuned to fit the data better. You can specify the maximum number of parameters using the `expression_options` argument in `SRRegressor`.
-## Custom Expressions
-
-You can create your own expression types by defining a new type that extends `AbstractExpression`.
-
-```@docs
-AbstractExpression
-```
-
-The interface is fairly flexible, and permits you define specific functional forms,
-extra parameters, etc. See the documentation of DynamicExpressions.jl for more details on what
-methods you need to implement. Then, for SymbolicRegression.jl, you would
-pass `expression_type` to the `Options` constructor, as well as any
-`expression_options` you need (as a `NamedTuple`).
-
-If needed, you may need to overload `SymbolicRegression.ExpressionBuilder.extra_init_params` in
-case your expression needs additional parameters. See the method for `ParametricExpression`
-as an example.
-
## Population
Groups of equations are given as a population, which is
diff --git a/src/AdaptiveParsimony.jl b/src/AdaptiveParsimony.jl
index b438aef4b..aa33fa613 100644
--- a/src/AdaptiveParsimony.jl
+++ b/src/AdaptiveParsimony.jl
@@ -1,6 +1,6 @@
module AdaptiveParsimonyModule
-using ..CoreModule: Options, MAX_DEGREE
+using ..CoreModule: AbstractOptions, MAX_DEGREE
"""
RunningSearchStatistics
@@ -23,7 +23,7 @@ struct RunningSearchStatistics
normalized_frequencies::Vector{Float64} # Stores `frequencies`, but normalized (updated once in a while)
end
-function RunningSearchStatistics(; options::Options, window_size::Int=100000)
+function RunningSearchStatistics(; options::AbstractOptions, window_size::Int=100000)
maxsize = options.maxsize
actualMaxsize = maxsize + MAX_DEGREE
init_frequencies = ones(Float64, actualMaxsize)
diff --git a/src/CheckConstraints.jl b/src/CheckConstraints.jl
index 7f6093631..df748c218 100644
--- a/src/CheckConstraints.jl
+++ b/src/CheckConstraints.jl
@@ -2,12 +2,12 @@ module CheckConstraintsModule
using DynamicExpressions:
AbstractExpressionNode, AbstractExpression, get_tree, count_depth, tree_mapreduce
-using ..CoreModule: Options
+using ..CoreModule: AbstractOptions
using ..ComplexityModule: compute_complexity, past_complexity_limit
# Check if any binary operator are overly complex
function flag_bin_operator_complexity(
- tree::AbstractExpressionNode, op, cons, options::Options
+ tree::AbstractExpressionNode, op, cons, options::AbstractOptions
)::Bool
any(tree) do subtree
if subtree.degree == 2 && subtree.op == op
@@ -27,7 +27,7 @@ Check if any unary operators are overly complex.
This assumes you have already checked whether the constraint is > -1.
"""
function flag_una_operator_complexity(
- tree::AbstractExpressionNode, op, cons, options::Options
+ tree::AbstractExpressionNode, op, cons, options::AbstractOptions
)::Bool
any(tree) do subtree
if subtree.degree == 1 && tree.op == op
@@ -52,7 +52,7 @@ function count_max_nestedness(tree, degree, op)
end
"""Check if there are any illegal combinations of operators"""
-function flag_illegal_nests(tree::AbstractExpressionNode, options::Options)::Bool
+function flag_illegal_nests(tree::AbstractExpressionNode, options::AbstractOptions)::Bool
# We search from the top first, then from child nodes at end.
(nested_constraints = options.nested_constraints) === nothing && return false
for (degree, op_idx, op_constraint) in nested_constraints
@@ -72,7 +72,7 @@ end
"""Check if user-passed constraints are violated or not"""
function check_constraints(
ex::AbstractExpression,
- options::Options,
+ options::AbstractOptions,
maxsize::Int,
cursize::Union{Int,Nothing}=nothing,
)::Bool
@@ -81,7 +81,7 @@ function check_constraints(
end
function check_constraints(
tree::AbstractExpressionNode,
- options::Options,
+ options::AbstractOptions,
maxsize::Int,
cursize::Union{Int,Nothing}=nothing,
)::Bool
@@ -103,7 +103,7 @@ function check_constraints(
end
check_constraints(
- ex::Union{AbstractExpression,AbstractExpressionNode}, options::Options
+ ex::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions
)::Bool = check_constraints(ex, options, options.maxsize)
end
diff --git a/src/Complexity.jl b/src/Complexity.jl
index dccb05bd3..dec8fb63e 100644
--- a/src/Complexity.jl
+++ b/src/Complexity.jl
@@ -2,10 +2,10 @@ module ComplexityModule
using DynamicExpressions:
AbstractExpression, AbstractExpressionNode, get_tree, count_nodes, tree_mapreduce
-using ..CoreModule: Options, ComplexityMapping
+using ..CoreModule: AbstractOptions, ComplexityMapping
function past_complexity_limit(
- tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options, limit
+ tree::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions, limit
)::Bool
return compute_complexity(tree, options) > limit
end
@@ -18,12 +18,12 @@ However, it could use the custom settings in options.complexity_mapping
if these are defined.
"""
function compute_complexity(
- tree::AbstractExpression, options::Options; break_sharing=Val(false)
+ tree::AbstractExpression, options::AbstractOptions; break_sharing=Val(false)
)
return compute_complexity(get_tree(tree), options; break_sharing)
end
function compute_complexity(
- tree::AbstractExpressionNode, options::Options; break_sharing=Val(false)
+ tree::AbstractExpressionNode, options::AbstractOptions; break_sharing=Val(false)
)::Int
if options.complexity_mapping.use
raw_complexity = _compute_complexity(
diff --git a/src/Configure.jl b/src/Configure.jl
index 5ccc08100..eefd63619 100644
--- a/src/Configure.jl
+++ b/src/Configure.jl
@@ -29,7 +29,7 @@ end
const TEST_INPUTS = collect(range(-100, 100; length=99))
-function assert_operators_well_defined(T, options::Options)
+function assert_operators_well_defined(T, options::AbstractOptions)
test_input = if T <: Complex
(x -> convert(T, x)).(TEST_INPUTS .+ TEST_INPUTS .* im)
else
@@ -45,7 +45,7 @@ end
# Check for errors before they happen
function test_option_configuration(
- parallelism, datasets::Vector{D}, options::Options, verbosity
+ parallelism, datasets::Vector{D}, options::AbstractOptions, verbosity
) where {T,D<:Dataset{T}}
if options.deterministic && parallelism != :serial
error("Determinism is only guaranteed for serial mode.")
@@ -84,7 +84,7 @@ end
# Check for errors before they happen
function test_dataset_configuration(
- dataset::Dataset{T}, options::Options, verbosity
+ dataset::Dataset{T}, options::AbstractOptions, verbosity
) where {T<:DATA_TYPE}
n = dataset.n
if n != size(dataset.X, 2) ||
@@ -113,7 +113,7 @@ end
""" Move custom operators and loss functions to workers, if undefined """
function move_functions_to_workers(
- procs, options::Options, dataset::Dataset{T}, verbosity
+ procs, options::AbstractOptions, dataset::Dataset{T}, verbosity
) where {T}
# All the types of functions we need to move to workers:
function_sets = (
@@ -168,7 +168,7 @@ function move_functions_to_workers(
end
end
-function copy_definition_to_workers(op, procs, options::Options, verbosity)
+function copy_definition_to_workers(op, procs, options::AbstractOptions, verbosity)
name = nameof(op)
verbosity > 0 && @info "Copying definition of $op to workers..."
src_ms = methods(op).ms
@@ -191,7 +191,9 @@ function test_function_on_workers(example_inputs, op, procs)
end
end
-function activate_env_on_workers(procs, project_path::String, options::Options, verbosity)
+function activate_env_on_workers(
+ procs, project_path::String, options::AbstractOptions, verbosity
+)
verbosity > 0 && @info "Activating environment on workers."
@everywhere procs begin
Base.MainInclude.eval(
@@ -203,7 +205,9 @@ function activate_env_on_workers(procs, project_path::String, options::Options,
end
end
-function import_module_on_workers(procs, filename::String, options::Options, verbosity)
+function import_module_on_workers(
+ procs, filename::String, options::AbstractOptions, verbosity
+)
loaded_modules_head_worker = [k.name for (k, _) in Base.loaded_modules]
included_as_local = "SymbolicRegression" ∉ loaded_modules_head_worker
@@ -251,7 +255,7 @@ function import_module_on_workers(procs, filename::String, options::Options, ver
return nothing
end
-function test_module_on_workers(procs, options::Options, verbosity)
+function test_module_on_workers(procs, options::AbstractOptions, verbosity)
verbosity > 0 && @info "Testing module on workers..."
futures = []
for proc in procs
@@ -268,7 +272,7 @@ function test_module_on_workers(procs, options::Options, verbosity)
end
function test_entire_pipeline(
- procs, dataset::Dataset{T}, options::Options, verbosity
+ procs, dataset::Dataset{T}, options::AbstractOptions, verbosity
) where {T<:DATA_TYPE}
futures = []
verbosity > 0 && @info "Testing entire pipeline on workers..."
@@ -310,7 +314,7 @@ function configure_workers(;
procs::Union{Vector{Int},Nothing},
numprocs::Int,
addprocs_function::Function,
- options::Options,
+ options::AbstractOptions,
project_path,
file,
exeflags::Cmd,
@@ -325,10 +329,6 @@ function configure_workers(;
end
if we_created_procs
- if VERSION < v"1.9.0"
- # On newer Julia; environment is activated automatically
- activate_env_on_workers(procs, project_path, options, verbosity)
- end
import_module_on_workers(procs, file, options, verbosity)
end
diff --git a/src/ConstantOptimization.jl b/src/ConstantOptimization.jl
index fe66b4f5d..92b5d0c5d 100644
--- a/src/ConstantOptimization.jl
+++ b/src/ConstantOptimization.jl
@@ -11,13 +11,13 @@ using DynamicExpressions:
get_scalar_constants,
set_scalar_constants!,
extract_gradient
-using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options
+using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options
using ..UtilsModule: get_birth_order
using ..LossFunctionsModule: eval_loss, loss_to_score, batch_sample
using ..PopMemberModule: PopMember
function optimize_constants(
- dataset::Dataset{T,L}, member::P, options::Options
+ dataset::Dataset{T,L}, member::P, options::AbstractOptions
)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}}
if options.batching
dispatch_optimize_constants(
@@ -28,7 +28,7 @@ function optimize_constants(
end
end
function dispatch_optimize_constants(
- dataset::Dataset{T,L}, member::P, options::Options, idx
+ dataset::Dataset{T,L}, member::P, options::AbstractOptions, idx
) where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}}
nconst = count_constants_for_optimization(member.tree)
nconst == 0 && return (member, 0.0)
@@ -103,7 +103,7 @@ function _optimize_constants(
return member, num_evals
end
-struct Evaluator{N<:AbstractExpression,R,D<:Dataset,O<:Options,I} <: Function
+struct Evaluator{N<:AbstractExpression,R,D<:Dataset,O<:AbstractOptions,I} <: Function
tree::N
refs::R
dataset::D
diff --git a/src/Core.jl b/src/Core.jl
index 8f917c906..0a16b52d0 100644
--- a/src/Core.jl
+++ b/src/Core.jl
@@ -13,9 +13,8 @@ include("Options.jl")
using .ProgramConstantsModule:
MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE
using .DatasetModule: Dataset
-using .MutationWeightsModule: MutationWeights, sample_mutation
-using .OptionsStructModule: Options, ComplexityMapping, specialized_options
-using .OptionsModule: Options
+using .MutationWeightsModule: AbstractMutationWeights, MutationWeights, sample_mutation
+using .OptionsStructModule: AbstractOptions, Options, ComplexityMapping, specialized_options
using .OperatorsModule:
plus,
sub,
diff --git a/src/Dataset.jl b/src/Dataset.jl
index 99c31ee3d..c8cb9767a 100644
--- a/src/Dataset.jl
+++ b/src/Dataset.jl
@@ -2,7 +2,7 @@ module DatasetModule
using DynamicQuantities: Quantity
-using ..UtilsModule: subscriptify, get_base_type, @constfield
+using ..UtilsModule: subscriptify, get_base_type
using ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE
using ...InterfaceDynamicQuantitiesModule: get_si_units, get_sym_units
@@ -57,24 +57,24 @@ mutable struct Dataset{
XUS<:Union{AbstractVector{<:Quantity},Nothing},
YUS<:Union{Quantity,Nothing},
}
- @constfield X::AX
- @constfield y::AY
- @constfield index::Int
- @constfield n::Int
- @constfield nfeatures::Int
- @constfield weighted::Bool
- @constfield weights::AW
- @constfield extra::NT
- @constfield avg_y::Union{T,Nothing}
+ const X::AX
+ const y::AY
+ const index::Int
+ const n::Int
+ const nfeatures::Int
+ const weighted::Bool
+ const weights::AW
+ const extra::NT
+ const avg_y::Union{T,Nothing}
use_baseline::Bool
baseline_loss::L
- @constfield variable_names::Array{String,1}
- @constfield display_variable_names::Array{String,1}
- @constfield y_variable_name::String
- @constfield X_units::XU
- @constfield y_units::YU
- @constfield X_sym_units::XUS
- @constfield y_sym_units::YUS
+ const variable_names::Array{String,1}
+ const display_variable_names::Array{String,1}
+ const y_variable_name::String
+ const X_units::XU
+ const y_units::YU
+ const X_sym_units::XUS
+ const y_sym_units::YUS
end
"""
diff --git a/src/DimensionalAnalysis.jl b/src/DimensionalAnalysis.jl
index cc9440db1..d469e6848 100644
--- a/src/DimensionalAnalysis.jl
+++ b/src/DimensionalAnalysis.jl
@@ -3,7 +3,7 @@ module DimensionalAnalysisModule
using DynamicExpressions: AbstractExpression, AbstractExpressionNode, get_tree
using DynamicQuantities: Quantity, DimensionError, AbstractQuantity, constructorof
-using ..CoreModule: Options, Dataset
+using ..CoreModule: AbstractOptions, Dataset
using ..UtilsModule: safe_call
import DynamicQuantities: dimension, ustrip
@@ -180,12 +180,12 @@ function violates_dimensional_constraints_dispatch(
end
"""
- violates_dimensional_constraints(tree::AbstractExpressionNode, dataset::Dataset, options::Options)
+ violates_dimensional_constraints(tree::AbstractExpressionNode, dataset::Dataset, options::AbstractOptions)
Checks whether an expression violates dimensional constraints.
"""
function violates_dimensional_constraints(
- tree::AbstractExpressionNode, dataset::Dataset, options::Options
+ tree::AbstractExpressionNode, dataset::Dataset, options::AbstractOptions
)
X = dataset.X
return violates_dimensional_constraints(
@@ -193,7 +193,7 @@ function violates_dimensional_constraints(
)
end
function violates_dimensional_constraints(
- tree::AbstractExpression, dataset::Dataset, options::Options
+ tree::AbstractExpression, dataset::Dataset, options::AbstractOptions
)
return violates_dimensional_constraints(get_tree(tree), dataset, options)
end
@@ -202,7 +202,7 @@ function violates_dimensional_constraints(
X_units::AbstractVector{<:Quantity},
y_units::Union{Quantity,Nothing},
x::AbstractVector{T},
- options::Options,
+ options::AbstractOptions,
) where {T}
allow_wildcards = !(options.dimensionless_constants_only)
dimensional_output = violates_dimensional_constraints_dispatch(
@@ -220,12 +220,20 @@ function violates_dimensional_constraints(
return violates
end
function violates_dimensional_constraints(
- ::AbstractExpressionNode{T}, ::Nothing, ::Quantity, ::AbstractVector{T}, ::Options
+ ::AbstractExpressionNode{T},
+ ::Nothing,
+ ::Quantity,
+ ::AbstractVector{T},
+ ::AbstractOptions,
) where {T}
return error("This should never happen. Please submit a bug report.")
end
function violates_dimensional_constraints(
- ::AbstractExpressionNode{T}, ::Nothing, ::Nothing, ::AbstractVector{T}, ::Options
+ ::AbstractExpressionNode{T},
+ ::Nothing,
+ ::Nothing,
+ ::AbstractVector{T},
+ ::AbstractOptions,
) where {T}
return false
end
diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl
index a54d97bdd..12b20a06c 100644
--- a/src/ExpressionBuilder.jl
+++ b/src/ExpressionBuilder.jl
@@ -17,7 +17,7 @@ using DynamicExpressions:
eval_tree_array
using Random: default_rng, AbstractRNG
using StatsBase: StatsBase
-using ..CoreModule: Options, Dataset, DATA_TYPE
+using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE
using ..HallOfFameModule: HallOfFame
using ..LossFunctionsModule: maybe_getindex
using ..InterfaceDynamicExpressionsModule: expected_array_type
@@ -32,7 +32,7 @@ import ..LossFunctionsModule: eval_tree_dispatch
import ..ConstantOptimizationModule: count_constants_for_optimization
@unstable function create_expression(
- t::T, options::Options, dataset::Dataset{T,L}, ::Val{embed}=Val(false)
+ t::T, options::AbstractOptions, dataset::Dataset{T,L}, ::Val{embed}=Val(false)
) where {T,L,embed}
return create_expression(
constructorof(options.node_type)(; val=t), options, dataset, Val(embed)
@@ -40,7 +40,7 @@ import ..ConstantOptimizationModule: count_constants_for_optimization
end
@unstable function create_expression(
t::AbstractExpressionNode{T},
- options::Options,
+ options::AbstractOptions,
dataset::Dataset{T,L},
::Val{embed}=Val(false),
) where {T,L,embed}
@@ -49,12 +49,12 @@ end
)
end
function create_expression(
- ex::AbstractExpression{T}, ::Options, ::Dataset{T,L}, ::Val{embed}=Val(false)
+ ex::AbstractExpression{T}, ::AbstractOptions, ::Dataset{T,L}, ::Val{embed}=Val(false)
) where {T,L,embed}
return ex
end
@unstable function init_params(
- options::Options,
+ options::AbstractOptions,
dataset::Dataset{T,L},
prototype::Union{Nothing,AbstractExpression},
::Val{embed},
@@ -71,7 +71,7 @@ end
function extra_init_params(
::Type{E},
prototype::Union{Nothing,AbstractExpression},
- options::Options,
+ options::AbstractOptions,
dataset::Dataset{T},
::Val{embed},
) where {T,embed,E<:AbstractExpression}
@@ -80,7 +80,7 @@ end
function extra_init_params(
::Type{E},
prototype::Union{Nothing,ParametricExpression},
- options::Options,
+ options::AbstractOptions,
dataset::Dataset{T},
::Val{embed},
) where {T,embed,E<:ParametricExpression}
@@ -95,8 +95,8 @@ function extra_init_params(
return (; parameters=_parameters, parameter_names)
end
-consistency_checks(::Options, prototype::Nothing) = nothing
-function consistency_checks(options::Options, prototype)
+consistency_checks(::AbstractOptions, prototype::Nothing) = nothing
+function consistency_checks(options::AbstractOptions, prototype)
if prototype === nothing
return nothing
end
@@ -120,12 +120,12 @@ end
@unstable begin
function embed_metadata(
- ex::AbstractExpression, options::Options, dataset::Dataset{T,L}
+ ex::AbstractExpression, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L}
return with_metadata(ex; init_params(options, dataset, ex, Val(true))...)
end
function embed_metadata(
- member::PopMember, options::Options, dataset::Dataset{T,L}
+ member::PopMember, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L}
return PopMember(
embed_metadata(member.tree, options, dataset),
@@ -138,37 +138,39 @@ end
)
end
function embed_metadata(
- pop::Population, options::Options, dataset::Dataset{T,L}
+ pop::Population, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L}
return Population(
map(member -> embed_metadata(member, options, dataset), pop.members)
)
end
function embed_metadata(
- hof::HallOfFame, options::Options, dataset::Dataset{T,L}
+ hof::HallOfFame, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L}
return HallOfFame(
map(member -> embed_metadata(member, options, dataset), hof.members), hof.exists
)
end
function embed_metadata(
- vec::Vector{H}, options::Options, dataset::Dataset{T,L}
+ vec::Vector{H}, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L,H<:Union{HallOfFame,Population,PopMember}}
return map(elem -> embed_metadata(elem, options, dataset), vec)
end
end
"""Strips all metadata except for top-level information"""
-function strip_metadata(ex::Expression, options::Options, dataset::Dataset{T,L}) where {T,L}
+function strip_metadata(
+ ex::Expression, options::AbstractOptions, dataset::Dataset{T,L}
+) where {T,L}
return with_metadata(ex; init_params(options, dataset, ex, Val(false))...)
end
function strip_metadata(
- ex::ParametricExpression, options::Options, dataset::Dataset{T,L}
+ ex::ParametricExpression, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L}
return with_metadata(ex; init_params(options, dataset, ex, Val(false))...)
end
function strip_metadata(
- member::PopMember, options::Options, dataset::Dataset{T,L}
+ member::PopMember, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L}
return PopMember(
strip_metadata(member.tree, options, dataset),
@@ -181,12 +183,12 @@ function strip_metadata(
)
end
function strip_metadata(
- pop::Population, options::Options, dataset::Dataset{T,L}
+ pop::Population, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L}
return Population(map(member -> strip_metadata(member, options, dataset), pop.members))
end
function strip_metadata(
- hof::HallOfFame, options::Options, dataset::Dataset{T,L}
+ hof::HallOfFame, options::AbstractOptions, dataset::Dataset{T,L}
) where {T,L}
return HallOfFame(
map(member -> strip_metadata(member, options, dataset), hof.members), hof.exists
@@ -194,7 +196,7 @@ function strip_metadata(
end
function eval_tree_dispatch(
- tree::ParametricExpression{T}, dataset::Dataset{T}, options::Options, idx
+ tree::ParametricExpression{T}, dataset::Dataset{T}, options::AbstractOptions, idx
) where {T<:DATA_TYPE}
A = expected_array_type(dataset.X)
return eval_tree_array(
@@ -210,7 +212,7 @@ function make_random_leaf(
::Type{T},
::Type{N},
rng::AbstractRNG=default_rng(),
- options::Union{Options,Nothing}=nothing,
+ options::Union{AbstractOptions,Nothing}=nothing,
) where {T<:DATA_TYPE,N<:ParametricNode}
choice = rand(rng, 1:3)
if choice == 1
@@ -263,7 +265,7 @@ end
function mutate_constant(
ex::ParametricExpression{T},
temperature,
- options::Options,
+ options::AbstractOptions,
rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
if rand(rng, Bool)
@@ -280,10 +282,10 @@ function mutate_constant(
end
end
-@unstable function get_operators(ex::AbstractExpression, options::Options)
+@unstable function get_operators(ex::AbstractExpression, options::AbstractOptions)
return get_operators(ex, options.operators)
end
-@unstable function get_operators(ex::AbstractExpressionNode, options::Options)
+@unstable function get_operators(ex::AbstractExpressionNode, options::AbstractOptions)
return get_operators(ex, options.operators)
end
diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl
index 19c52f933..71032dfd5 100644
--- a/src/HallOfFame.jl
+++ b/src/HallOfFame.jl
@@ -3,7 +3,7 @@ module HallOfFameModule
using DynamicExpressions: AbstractExpression, string_tree
using ..UtilsModule: split_string
using ..CoreModule:
- MAX_DEGREE, Options, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression
+ MAX_DEGREE, AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression
using ..ComplexityModule: compute_complexity
using ..PopMemberModule: PopMember
using ..InterfaceDynamicExpressionsModule: format_dimensions, WILDCARD_UNIT_STRING
@@ -48,7 +48,7 @@ function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where
end
"""
- HallOfFame(options::Options, dataset::Dataset{T,L}) where {T<:DATA_TYPE,L<:LOSS_TYPE}
+ HallOfFame(options::AbstractOptions, dataset::Dataset{T,L}) where {T<:DATA_TYPE,L<:LOSS_TYPE}
Create empty HallOfFame. The HallOfFame stores a list
of `PopMember` objects in `.members`, which is enumerated
@@ -57,11 +57,11 @@ by size (i.e., `.members[1]` is the constant solution).
has been instantiated or not.
Arguments:
-- `options`: Options containing specification about deterministic.
+- `options`: `AbstractOptions` specifying, among other things, whether the search is deterministic.
- `dataset`: Dataset containing the input data.
"""
function HallOfFame(
- options::Options, dataset::Dataset{T,L}
+ options::AbstractOptions, dataset::Dataset{T,L}
) where {T<:DATA_TYPE,L<:LOSS_TYPE}
actualMaxsize = options.maxsize + MAX_DEGREE
base_tree = create_expression(zero(T), options, dataset)
diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl
index d5cf52300..887627b7d 100644
--- a/src/InterfaceDynamicExpressions.jl
+++ b/src/InterfaceDynamicExpressions.jl
@@ -11,14 +11,14 @@ using DynamicExpressions:
Node,
GraphNode
using DynamicQuantities: dimension, ustrip
-using ..CoreModule: Options
+using ..CoreModule: AbstractOptions
using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap
using ..UtilsModule: subscriptify
import ..deprecate_varmap
"""
- eval_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options; kws...)
+ eval_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::AbstractOptions; kws...)
Evaluate a binary tree (equation) over a given input data matrix. The
operators contain all of the operators used. This function fuses doublets
@@ -41,7 +41,7 @@ which speed up evaluation significantly.
# Arguments
- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The root node of the tree to evaluate.
- `X::AbstractArray`: The input data to evaluate the tree on.
-- `options::Options`: Options used to define the operators used in the tree.
+- `options::AbstractOptions`: Options used to define the operators used in the tree.
# Returns
- `(output, complete)::Tuple{AbstractVector, Bool}`: the result,
@@ -53,7 +53,7 @@ which speed up evaluation significantly.
function DE.eval_tree_array(
tree::Union{AbstractExpressionNode,AbstractExpression},
X::AbstractMatrix,
- options::Options;
+ options::AbstractOptions;
kws...,
)
A = expected_array_type(X)
@@ -70,7 +70,7 @@ function DE.eval_tree_array(
tree::ParametricExpression,
X::AbstractMatrix,
classes::AbstractVector{<:Integer},
- options::Options;
+ options::AbstractOptions;
kws...,
)
A = expected_array_type(X)
@@ -91,7 +91,7 @@ function expected_array_type(X::AbstractArray)
end
"""
- eval_diff_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options, direction::Int)
+ eval_diff_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::AbstractOptions, direction::Int)
Compute the forward derivative of an expression, using a similar
structure and optimization to eval_tree_array. `direction` is the index of a particular
@@ -102,7 +102,7 @@ respect to `x1`.
- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate.
- `X::AbstractArray`: The data matrix, with each column being a data point.
-- `options::Options`: The options containing the operators used to create the `tree`.
+- `options::AbstractOptions`: The options containing the operators used to create the `tree`.
- `direction::Int`: The index of the variable to take the derivative with respect to.
# Returns
@@ -113,7 +113,7 @@ respect to `x1`.
function DE.eval_diff_tree_array(
tree::Union{AbstractExpression,AbstractExpressionNode},
X::AbstractArray,
- options::Options,
+ options::AbstractOptions,
direction::Int,
)
A = expected_array_type(X)
@@ -124,7 +124,7 @@ function DE.eval_diff_tree_array(
end
"""
- eval_grad_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options; variable::Bool=false)
+ eval_grad_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::AbstractOptions; variable::Bool=false)
Compute the forward-mode derivative of an expression, using a similar
structure and optimization to eval_tree_array. `variable` specifies whether
@@ -135,7 +135,7 @@ to every constant in the expression.
- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate.
- `X::AbstractArray`: The data matrix, with each column being a data point.
-- `options::Options`: The options containing the operators used to create the `tree`.
+- `options::AbstractOptions`: The options containing the operators used to create the `tree`.
- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`),
or with respect to every constant in the expression (`variable=false`).
@@ -147,7 +147,7 @@ to every constant in the expression.
function DE.eval_grad_tree_array(
tree::Union{AbstractExpression,AbstractExpressionNode},
X::AbstractArray,
- options::Options;
+ options::AbstractOptions;
kws...,
)
A = expected_array_type(X)
@@ -158,14 +158,14 @@ function DE.eval_grad_tree_array(
end
"""
- differentiable_eval_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options)
+ differentiable_eval_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::AbstractOptions)
Evaluate an expression tree in a way that can be auto-differentiated.
"""
function DE.differentiable_eval_tree_array(
tree::Union{AbstractExpression,AbstractExpressionNode},
X::AbstractArray,
- options::Options,
+ options::AbstractOptions,
)
A = expected_array_type(X)
# TODO: Add `AbstractExpression` implementation in `Expression.jl`
@@ -177,20 +177,20 @@ end
const WILDCARD_UNIT_STRING = "[?]"
"""
- string_tree(tree::AbstractExpressionNode, options::Options; kws...)
+ string_tree(tree::AbstractExpressionNode, options::AbstractOptions; kws...)
Convert an equation to a string.
# Arguments
- `tree::AbstractExpressionNode`: The equation to convert to a string.
-- `options::Options`: The options holding the definition of operators.
+- `options::AbstractOptions`: The options holding the definition of operators.
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables
to print for each feature.
"""
@inline function DE.string_tree(
tree::Union{AbstractExpression,AbstractExpressionNode},
- options::Options;
+ options::AbstractOptions;
raw::Bool=true,
X_sym_units=nothing,
y_sym_units=nothing,
@@ -283,24 +283,27 @@ end
end
"""
- print_tree(tree::AbstractExpressionNode, options::Options; kws...)
+ print_tree(tree::AbstractExpressionNode, options::AbstractOptions; kws...)
Print an equation
# Arguments
- `tree::AbstractExpressionNode`: The equation to convert to a string.
-- `options::Options`: The options holding the definition of operators.
+- `options::AbstractOptions`: The options holding the definition of operators.
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables
to print for each feature.
"""
function DE.print_tree(
- tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws...
+ tree::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions; kws...
)
return DE.print_tree(tree, DE.get_operators(tree, options); kws...)
end
function DE.print_tree(
- io::IO, tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws...
+ io::IO,
+ tree::Union{AbstractExpression,AbstractExpressionNode},
+ options::AbstractOptions;
+ kws...,
)
return DE.print_tree(io, tree, DE.get_operators(tree, options); kws...)
end
@@ -317,7 +320,7 @@ defined.
"""
macro extend_operators(options)
operators = :($(options).operators)
- type_requirements = Options
+ type_requirements = AbstractOptions
alias_operators = gensym("alias_operators")
return quote
if !isa($(options), $type_requirements)
@@ -340,7 +343,7 @@ function define_alias_operators(operators)
end
function (tree::Union{AbstractExpression,AbstractExpressionNode})(
- X, options::Options; kws...
+ X, options::AbstractOptions; kws...
)
return tree(
X,
@@ -351,7 +354,10 @@ function (tree::Union{AbstractExpression,AbstractExpressionNode})(
)
end
function DE.EvaluationHelpersModule._grad_evaluator(
- tree::Union{AbstractExpression,AbstractExpressionNode}, X, options::Options; kws...
+ tree::Union{AbstractExpression,AbstractExpressionNode},
+ X,
+ options::AbstractOptions;
+ kws...,
)
return DE.EvaluationHelpersModule._grad_evaluator(
tree, X, DE.get_operators(tree, options); turbo=options.turbo, kws...
diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl
index a84218879..b41ad3a38 100644
--- a/src/LossFunctions.jl
+++ b/src/LossFunctions.jl
@@ -6,7 +6,7 @@ using DynamicExpressions:
using LossFunctions: LossFunctions
using LossFunctions: SupervisedLoss
using ..InterfaceDynamicExpressionsModule: expected_array_type
-using ..CoreModule: Options, Dataset, create_expression, DATA_TYPE, LOSS_TYPE
+using ..CoreModule: AbstractOptions, Dataset, create_expression, DATA_TYPE, LOSS_TYPE
using ..ComplexityModule: compute_complexity
using ..DimensionalAnalysisModule: violates_dimensional_constraints
@@ -44,7 +44,7 @@ end
function eval_tree_dispatch(
tree::Union{AbstractExpression{T},AbstractExpressionNode{T}},
dataset::Dataset{T},
- options::Options,
+ options::AbstractOptions,
idx,
) where {T<:DATA_TYPE}
A = expected_array_type(dataset.X)
@@ -55,7 +55,7 @@ end
function _eval_loss(
tree::Union{AbstractExpression{T},AbstractExpressionNode{T}},
dataset::Dataset{T,L},
- options::Options,
+ options::AbstractOptions,
regularization::Bool,
idx,
)::L where {T<:DATA_TYPE,L<:LOSS_TYPE}
@@ -84,7 +84,11 @@ end
# This evaluates function F:
function evaluator(
- f::F, tree::AbstractExpressionNode{T}, dataset::Dataset{T,L}, options::Options, idx
+ f::F,
+ tree::AbstractExpressionNode{T},
+ dataset::Dataset{T,L},
+ options::AbstractOptions,
+ idx,
)::L where {T<:DATA_TYPE,L<:LOSS_TYPE,F}
if hasmethod(f, typeof((tree, dataset, options, idx)))
# If user defines method that accepts batching indices:
@@ -105,7 +109,7 @@ end
function eval_loss(
tree::Union{AbstractExpression{T},AbstractExpressionNode{T}},
dataset::Dataset{T,L},
- options::Options;
+ options::AbstractOptions;
regularization::Bool=true,
idx=nothing,
)::L where {T<:DATA_TYPE,L<:LOSS_TYPE}
@@ -122,7 +126,7 @@ end
function eval_loss_batched(
tree::Union{AbstractExpression{T},AbstractExpressionNode{T}},
dataset::Dataset{T,L},
- options::Options;
+ options::AbstractOptions;
regularization::Bool=true,
idx=nothing,
)::L where {T<:DATA_TYPE,L<:LOSS_TYPE}
@@ -148,7 +152,7 @@ function loss_to_score(
use_baseline::Bool,
baseline::L,
member,
- options::Options,
+ options::AbstractOptions,
complexity::Union{Int,Nothing}=nothing,
)::L where {L<:LOSS_TYPE}
# TODO: Come up with a more general normalization scheme.
@@ -167,7 +171,10 @@ end
# Score an equation
function score_func(
- dataset::Dataset{T,L}, member, options::Options; complexity::Union{Int,Nothing}=nothing
+ dataset::Dataset{T,L},
+ member,
+ options::AbstractOptions;
+ complexity::Union{Int,Nothing}=nothing,
)::Tuple{L,L} where {T<:DATA_TYPE,L<:LOSS_TYPE}
result_loss = eval_loss(get_tree_from_member(member), dataset, options)
score = loss_to_score(
@@ -185,7 +192,7 @@ end
function score_func_batched(
dataset::Dataset{T,L},
member,
- options::Options;
+ options::AbstractOptions;
complexity::Union{Int,Nothing}=nothing,
idx=nothing,
)::Tuple{L,L} where {T<:DATA_TYPE,L<:LOSS_TYPE}
@@ -202,12 +209,12 @@ function score_func_batched(
end
"""
- update_baseline_loss!(dataset::Dataset{T,L}, options::Options) where {T<:DATA_TYPE,L<:LOSS_TYPE}
+ update_baseline_loss!(dataset::Dataset{T,L}, options::AbstractOptions) where {T<:DATA_TYPE,L<:LOSS_TYPE}
Update the baseline loss of the dataset using the loss function specified in `options`.
"""
function update_baseline_loss!(
- dataset::Dataset{T,L}, options::Options
+ dataset::Dataset{T,L}, options::AbstractOptions
) where {T<:DATA_TYPE,L<:LOSS_TYPE}
example_tree = create_expression(zero(T), options, dataset)
# constructorof(options.node_type)(T; val=dataset.avg_y)
@@ -226,7 +233,7 @@ end
function dimensional_regularization(
tree::Union{AbstractExpression{T},AbstractExpressionNode{T}},
dataset::Dataset{T,L},
- options::Options,
+ options::AbstractOptions,
) where {T<:DATA_TYPE,L<:LOSS_TYPE}
if !violates_dimensional_constraints(tree, dataset, options)
return zero(L)
diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl
index 2bbce47f1..028493ff7 100644
--- a/src/MLJInterface.jl
+++ b/src/MLJInterface.jl
@@ -23,9 +23,8 @@ using DynamicQuantities:
ustrip,
dimension
using LossFunctions: SupervisedLoss
-using Compat: allequal, stack
using ..InterfaceDynamicQuantitiesModule: get_dimensions_type
-using ..CoreModule: Options, Dataset, MutationWeights, LOSS_TYPE
+using ..CoreModule: Options, Dataset, AbstractMutationWeights, MutationWeights, LOSS_TYPE
using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS
using ..ComplexityModule: compute_complexity
using ..HallOfFameModule: HallOfFame, format_hall_of_fame
diff --git a/src/Migration.jl b/src/Migration.jl
index daab9255f..d08dab2ac 100644
--- a/src/Migration.jl
+++ b/src/Migration.jl
@@ -1,20 +1,20 @@
module MigrationModule
using StatsBase: StatsBase
-using ..CoreModule: Options
+using ..CoreModule: AbstractOptions
using ..PopulationModule: Population
using ..PopMemberModule: PopMember, reset_birth!
using ..UtilsModule: poisson_sample
"""
- migrate!(migration::Pair{Population{T,L},Population{T,L}}, options::Options; frac::AbstractFloat)
+ migrate!(migration::Pair{Population{T,L},Population{T,L}}, options::AbstractOptions; frac::AbstractFloat)
Migrate a fraction of the population from one population to the other, creating copies
to do so. The original migrant population is not modified. Pass with, e.g.,
`migrate!(migration_candidates => destination, options; frac=0.1)`
"""
function migrate!(
- migration::Pair{Vector{PM},P}, options::Options; frac::AbstractFloat
+ migration::Pair{Vector{PM},P}, options::AbstractOptions; frac::AbstractFloat
) where {T,L,N,PM<:PopMember{T,L,N},P<:Population{T,L,N}}
base_pop = migration.second
population_size = length(base_pop.members)
diff --git a/src/Mutate.jl b/src/Mutate.jl
index c559d42ba..0f383177b 100644
--- a/src/Mutate.jl
+++ b/src/Mutate.jl
@@ -10,7 +10,8 @@ using DynamicExpressions:
count_scalar_constants,
simplify_tree!,
combine_operators
-using ..CoreModule: Options, MutationWeights, Dataset, RecordType, sample_mutation
+using ..CoreModule:
+ AbstractOptions, AbstractMutationWeights, Dataset, RecordType, sample_mutation
using ..ComplexityModule: compute_complexity
using ..LossFunctionsModule: score_func, score_func_batched
using ..CheckConstraintsModule: check_constraints
@@ -32,8 +33,68 @@ using ..MutationFunctionsModule:
using ..ConstantOptimizationModule: optimize_constants
using ..RecorderModule: @recorder
+abstract type AbstractMutationResult{N<:AbstractExpression,P<:PopMember} end
+
+"""
+ MutationResult{N<:AbstractExpression,P<:PopMember}
+
+Represents the result of a mutation operation in the genetic programming algorithm. This struct is used to return values from `mutate!` functions.
+
+# Fields
+
+- `tree::Union{N, Nothing}`: The mutated expression tree, if applicable. Either `tree` or `member` must be set, but not both.
+- `member::Union{P, Nothing}`: The mutated population member, if applicable. Either `member` or `tree` must be set, but not both.
+- `num_evals::Float64`: The number of evaluations performed during the mutation, which defaults to `0.0`. Only used by mutations that evaluate the expression, such as `optimize`.
+- `return_immediately::Bool`: If `true`, the mutation process should return immediately, bypassing further checks. This is used for mutations like `simplify` or `optimize`, where the loss value of the result is already known.
+
+# Usage
+
+This struct encapsulates the result of a mutation operation. Either a new expression tree or a new population member is returned, but not both.
+
+Return the `member` when you want to return immediately and have already
+computed the loss value as part of the mutation.
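+
+For example, a mutation that only rewrites the expression would return
+`MutationResult{N,P}(; tree=new_tree)`, while one that has already scored a
+new member would return
+`MutationResult{N,P}(; member=new_member, return_immediately=true)`.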
+"""
+struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationResult{N,P}
+ tree::Union{N,Nothing}
+ member::Union{P,Nothing}
+ num_evals::Float64
+ return_immediately::Bool
+
+ # Explicit constructor with keyword arguments
+ function MutationResult{_N,_P}(;
+ tree::Union{_N,Nothing}=nothing,
+ member::Union{_P,Nothing}=nothing,
+ num_evals::Float64=0.0,
+ return_immediately::Bool=false,
+ ) where {_N<:AbstractExpression,_P<:PopMember}
+ @assert(
+ (tree === nothing) ⊻ (member === nothing),
+ "Mutation result must return either a tree or a pop member, not both"
+ )
+ return new{_N,_P}(tree, member, num_evals, return_immediately)
+ end
+end
+
+"""
+ condition_mutation_weights!(weights::AbstractMutationWeights, member::PopMember, options::AbstractOptions, curmaxsize::Int)
+
+Adjusts the mutation weights based on the properties of the current member and options.
+
+This function modifies the mutation weights to ensure that the mutations applied to the member are appropriate given its current state and the provided options. It can be overloaded to customize the behavior for different types of expressions or members.
+
+Note that the weights are already copied before this function is called, so you may modify them freely.
+
+# Arguments
+- `weights::AbstractMutationWeights`: The mutation weights to be adjusted.
+- `member::PopMember`: The current population member being mutated.
+- `options::AbstractOptions`: The options that guide the mutation process.
+- `curmaxsize::Int`: The current maximum size constraint for the member's expression tree.
+"""
function condition_mutation_weights!(
- weights::MutationWeights, member::PopMember, options::Options, curmaxsize::Int
+ weights::AbstractMutationWeights,
+ member::PopMember,
+ options::AbstractOptions,
+ curmaxsize::Int,
)
tree = get_tree(member.tree)
if !preserve_sharing(typeof(member.tree))
@@ -81,9 +142,9 @@ Use this to modify how `mutate_constant` changes for an expression type.
"""
function condition_mutate_constant!(
::Type{<:AbstractExpression},
- weights::MutationWeights,
+ weights::AbstractMutationWeights,
member::PopMember,
- options::Options,
+ options::AbstractOptions,
curmaxsize::Int,
)
n_constants = count_scalar_constants(member.tree)
@@ -93,9 +154,9 @@ function condition_mutate_constant!(
end
function condition_mutate_constant!(
::Type{<:ParametricExpression},
- weights::MutationWeights,
+ weights::AbstractMutationWeights,
member::PopMember,
- options::Options,
+ options::AbstractOptions,
curmaxsize::Int,
)
# Avoid modifying the mutate_constant weight, since
@@ -111,13 +172,12 @@ function next_generation(
temperature,
curmaxsize::Int,
running_search_statistics::RunningSearchStatistics,
- options::Options;
+ options::AbstractOptions;
tmp_recorder::RecordType,
)::Tuple{
P,Bool,Float64
} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:PopMember{T,L,N}}
parent_ref = member.ref
- mutation_accepted = false
num_evals = 0.0
#TODO - reconsider this
@@ -137,8 +197,6 @@ function next_generation(
mutation_choice = sample_mutation(weights)
successful_mutation = false
- #TODO: Currently we dont take this \/ into account
- is_success_always_possible = true
attempts = 0
max_attempts = 10
@@ -148,133 +206,41 @@ function next_generation(
local tree
while (!successful_mutation) && attempts < max_attempts
tree = copy_node(member.tree)
- successful_mutation = true
- if mutation_choice == :mutate_constant
- tree = mutate_constant(tree, temperature, options)
- @recorder tmp_recorder["type"] = "constant"
- is_success_always_possible = true
- # Mutating a constant shouldn't invalidate an already-valid function
- elseif mutation_choice == :mutate_operator
- tree = mutate_operator(tree, options)
- @recorder tmp_recorder["type"] = "operator"
- is_success_always_possible = true
- # Can always mutate to the same operator
-
- elseif mutation_choice == :swap_operands
- tree = swap_operands(tree)
- @recorder tmp_recorder["type"] = "swap_operands"
- is_success_always_possible = true
-
- elseif mutation_choice == :add_node
- if rand() < 0.5
- tree = append_random_op(tree, options, nfeatures)
- @recorder tmp_recorder["type"] = "append_op"
- else
- tree = prepend_random_op(tree, options, nfeatures)
- @recorder tmp_recorder["type"] = "prepend_op"
- end
- is_success_always_possible = false
- # Can potentially have a situation without success
- elseif mutation_choice == :insert_node
- tree = insert_random_op(tree, options, nfeatures)
- @recorder tmp_recorder["type"] = "insert_op"
- is_success_always_possible = false
- elseif mutation_choice == :delete_node
- tree = delete_random_op!(tree, options, nfeatures)
- @recorder tmp_recorder["type"] = "delete_op"
- is_success_always_possible = true
- elseif mutation_choice == :simplify
- @assert options.should_simplify
- simplify_tree!(tree, options.operators)
- tree = combine_operators(tree, options.operators)
- @recorder tmp_recorder["type"] = "partial_simplify"
- mutation_accepted = true
- is_success_always_possible = true
- return (
- PopMember(
- tree,
- beforeScore,
- beforeLoss,
- options;
- parent=parent_ref,
- deterministic=options.deterministic,
- ),
- mutation_accepted,
- num_evals,
- )
- # Simplification shouldn't hurt complexity; unless some non-symmetric constraint
- # to commutative operator...
- elseif mutation_choice == :randomize
- # We select a random size, though the generated tree
- # may have fewer nodes than we request.
- tree_size_to_generate = rand(1:curmaxsize)
- tree = with_contents(
- tree,
- gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T),
- )
- @recorder tmp_recorder["type"] = "regenerate"
- is_success_always_possible = true
- elseif mutation_choice == :optimize
- cur_member = PopMember(
- tree,
- beforeScore,
- beforeLoss,
- options,
- compute_complexity(member, options);
- parent=parent_ref,
- deterministic=options.deterministic,
- )
- cur_member, new_num_evals = optimize_constants(dataset, cur_member, options)
- num_evals += new_num_evals
- @recorder tmp_recorder["type"] = "optimize"
- mutation_accepted = true
- is_success_always_possible = true
- return (cur_member, mutation_accepted, num_evals)
- elseif mutation_choice == :do_nothing
- @recorder begin
- tmp_recorder["type"] = "identity"
- tmp_recorder["result"] = "accept"
- tmp_recorder["reason"] = "identity"
- end
- mutation_accepted = true
- is_success_always_possible = true
- return (
- PopMember(
- tree,
- beforeScore,
- beforeLoss,
- options,
- compute_complexity(member, options);
- parent=parent_ref,
- deterministic=options.deterministic,
- ),
- mutation_accepted,
- num_evals,
+ mutation_result = _dispatch_mutations!(
+ tree,
+ member,
+ mutation_choice,
+ options.mutation_weights,
+ options;
+ recorder=tmp_recorder,
+ temperature,
+ dataset,
+ score=beforeScore,
+ loss=beforeLoss,
+ parent_ref,
+ curmaxsize,
+ nfeatures,
+ )
+ mutation_result::AbstractMutationResult{N,P}
+ num_evals += mutation_result.num_evals::Float64
+
+ if mutation_result.return_immediately
+ @assert(
+ mutation_result.member isa P,
+ "Mutation result must return a `PopMember` if `return_immediately` is true"
)
- elseif mutation_choice == :form_connection
- tree = form_random_connection!(tree)
- @recorder tmp_recorder["type"] = "form_connection"
- is_success_always_possible = true
- elseif mutation_choice == :break_connection
- tree = break_random_connection!(tree)
- @recorder tmp_recorder["type"] = "break_connection"
- is_success_always_possible = true
- elseif mutation_choice == :rotate_tree
- tree = randomly_rotate_tree!(tree)
- @recorder tmp_recorder["type"] = "rotate_tree"
- is_success_always_possible = true
+ return mutation_result.member::P, true, num_evals
else
- error("Unknown mutation choice: $mutation_choice")
+ @assert(
+ mutation_result.tree isa N,
+ "Mutation result must return a tree if `return_immediately` is false"
+ )
+ tree = mutation_result.tree::N
+ successful_mutation = check_constraints(tree, options, curmaxsize)
+ attempts += 1
end
-
- successful_mutation =
- successful_mutation && check_constraints(tree, options, curmaxsize)
-
- attempts += 1
end
- #############################################
- tree::AbstractExpression
if !successful_mutation
@recorder begin
@@ -389,9 +355,300 @@ function next_generation(
end
end
+@generated function _dispatch_mutations!(
+ tree::AbstractExpression,
+ member::PopMember,
+ mutation_choice::Symbol,
+ weights::W,
+ options::AbstractOptions;
+ kws...,
+) where {W<:AbstractMutationWeights}
+ mutation_choices = fieldnames(W)
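+    # Compile one branch per field of the weights type, so the runtime
+    # `Symbol` in `mutation_choice` maps to a static `Val`-dispatched
+    # `mutate!` call rather than requiring dynamic dispatch.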
+ quote
+ Base.Cartesian.@nif(
+ $(length(mutation_choices)),
+ i -> mutation_choice == $(mutation_choices)[i],
+ i -> begin
+ @assert mutation_choice == $(mutation_choices)[i]
+ mutate!(
+ tree, member, Val($(mutation_choices)[i]), weights, options; kws...
+ )
+ end,
+ )
+ end
+end
+
+"""
+ mutate!(
+ tree::N,
+ member::P,
+ ::Val{S},
+ mutation_weights::AbstractMutationWeights,
+ options::AbstractOptions;
+ kws...,
+ ) where {N<:AbstractExpression,P<:PopMember,S}
+
+Perform a mutation on the given `tree` and `member` using the specified mutation type `S`.
+Various keyword arguments (`kws`) provide access to additional data needed by some mutations.
+
+You may overload this function to handle new mutation types for new `AbstractMutationWeights` types.
+
+# Keywords
+
+- `temperature`: The temperature parameter for annealing-based mutations.
+- `dataset::Dataset`: The dataset used for scoring.
+- `score`: The score of the member before mutation.
+- `loss`: The loss of the member before mutation.
+- `curmaxsize`: The current maximum size constraint, which may be different from `options.maxsize`.
+- `nfeatures`: The number of features in the dataset.
+- `parent_ref`: Reference to the mutated member's parent (only used for logging purposes).
+- `recorder::RecordType`: A recorder to log mutation details.
+
+# Returns
+
+A `MutationResult{N,P}` object containing the mutated tree or member (but not both),
+the number of evaluations performed, if any, and whether to return immediately from
+the mutation function, or to let the `next_generation` function handle accepting or
+rejecting the mutation. For example, a `simplify` operation will not change the loss,
+so it can always return immediately.
+"""
+function mutate!(
+ ::N, ::P, ::Val{S}, ::AbstractMutationWeights, ::AbstractOptions; kws...
+) where {N<:AbstractExpression,P<:PopMember,S}
+ return error("Unknown mutation choice: $S")
+end
+
+function mutate!(
+ tree::N,
+ member::P,
+ ::Val{:mutate_constant},
+ ::AbstractMutationWeights,
+ options::AbstractOptions;
+ recorder::RecordType,
+ temperature,
+ kws...,
+) where {N<:AbstractExpression,P<:PopMember}
+ tree = mutate_constant(tree, temperature, options)
+ @recorder recorder["type"] = "mutate_constant"
+ return MutationResult{N,P}(; tree=tree)
+end
+
+function mutate!(
+ tree::N,
+ member::P,
+ ::Val{:mutate_operator},
+ ::AbstractMutationWeights,
+ options::AbstractOptions;
+ recorder::RecordType,
+ kws...,
+) where {N<:AbstractExpression,P<:PopMember}
+ tree = mutate_operator(tree, options)
+ @recorder recorder["type"] = "mutate_operator"
+ return MutationResult{N,P}(; tree=tree)
+end
+
+function mutate!(
+ tree::N,
+ member::P,
+ ::Val{:swap_operands},
+ ::AbstractMutationWeights,
+ options::AbstractOptions;
+ recorder::RecordType,
+ kws...,
+) where {N<:AbstractExpression,P<:PopMember}
+ tree = swap_operands(tree)
+ @recorder recorder["type"] = "swap_operands"
+ return MutationResult{N,P}(; tree=tree)
+end
+
+function mutate!(
+ tree::N,
+ member::P,
+ ::Val{:add_node},
+ ::AbstractMutationWeights,
+ options::AbstractOptions;
+ recorder::RecordType,
+ nfeatures,
+ kws...,
+) where {N<:AbstractExpression,P<:PopMember}
+ if rand() < 0.5
+ tree = append_random_op(tree, options, nfeatures)
+ @recorder recorder["type"] = "add_node:append"
+ else
+ tree = prepend_random_op(tree, options, nfeatures)
+ @recorder recorder["type"] = "add_node:prepend"
+ end
+ return MutationResult{N,P}(; tree=tree)
+end
+
+function mutate!(
+ tree::N,
+ member::P,
+ ::Val{:insert_node},
+ ::AbstractMutationWeights,
+ options::AbstractOptions;
+ recorder::RecordType,
+ nfeatures,
+ kws...,
+) where {N<:AbstractExpression,P<:PopMember}
+ tree = insert_random_op(tree, options, nfeatures)
+ @recorder recorder["type"] = "insert_node"
+ return MutationResult{N,P}(; tree=tree)
+end
+
+function mutate!(
+ tree::N,
+ member::P,
+ ::Val{:delete_node},
+ ::AbstractMutationWeights,
+ options::AbstractOptions;
+ recorder::RecordType,
+ nfeatures,
+ kws...,
+) where {N<:AbstractExpression,P<:PopMember}
+ tree = delete_random_op!(tree, options, nfeatures)
+ @recorder recorder["type"] = "delete_node"
+ return MutationResult{N,P}(; tree=tree)
+end
+
+function mutate!(
+ tree::N,
+ member::P,
+ ::Val{:form_connection},
+ ::AbstractMutationWeights,
+ options::AbstractOptions;
+ recorder::RecordType,
+ kws...,
+) where {N<:AbstractExpression,P<:PopMember}
+ tree = form_random_connection!(tree)
+ @recorder recorder["type"] = "form_connection"
+ return MutationResult{N,P}(; tree=tree)
+end
+
+function mutate!(
+ tree::N,
+ member::P,
+ ::Val{:break_connection},
+ ::AbstractMutationWeights,
+ options::AbstractOptions;
+ recorder::RecordType,
+ kws...,
+) where {N<:AbstractExpression,P<:PopMember}
+ tree = break_random_connection!(tree)
+ @recorder recorder["type"] = "break_connection"
+ return MutationResult{N,P}(; tree=tree)
+end
+
+function mutate!(
+ tree::N,
+ member::P,
+ ::Val{:rotate_tree},
+ ::AbstractMutationWeights,
+ options::AbstractOptions;
+ recorder::RecordType,
+ kws...,
+) where {N<:AbstractExpression,P<:PopMember}
+ tree = randomly_rotate_tree!(tree)
+ @recorder recorder["type"] = "rotate_tree"
+ return MutationResult{N,P}(; tree=tree)
+end
+
+# Handle mutations that require early return
+function mutate!(
+ tree::N,
+ member::P,
+ ::Val{:simplify},
+ ::AbstractMutationWeights,
+ options::AbstractOptions;
+ recorder::RecordType,
+ parent_ref,
+ kws...,
+) where {N<:AbstractExpression,P<:PopMember}
+ @assert options.should_simplify
+ simplify_tree!(tree, options.operators)
+ tree = combine_operators(tree, options.operators)
+ @recorder recorder["type"] = "simplify"
+ return MutationResult{N,P}(;
+ member=PopMember(
+ tree,
+ member.score,
+ member.loss,
+ options;
+ parent=parent_ref,
+ deterministic=options.deterministic,
+ ),
+ return_immediately=true,
+ )
+end
+
+function mutate!(
+ tree::N,
+ ::P,
+ ::Val{:randomize},
+ ::AbstractMutationWeights,
+ options::AbstractOptions;
+ recorder::RecordType,
+ curmaxsize,
+ nfeatures,
+ kws...,
+) where {T,N<:AbstractExpression{T},P<:PopMember}
+ tree_size_to_generate = rand(1:curmaxsize)
+ tree = with_contents(
+ tree, gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T)
+ )
+ @recorder recorder["type"] = "randomize"
+ return MutationResult{N,P}(; tree=tree)
+end
+
+function mutate!(
+ tree::N,
+ member::P,
+ ::Val{:optimize},
+ ::AbstractMutationWeights,
+ options::AbstractOptions;
+ recorder::RecordType,
+ dataset::Dataset,
+ kws...,
+) where {N<:AbstractExpression,P<:PopMember}
+ cur_member, new_num_evals = optimize_constants(dataset, member, options)
+ @recorder recorder["type"] = "optimize"
+ return MutationResult{N,P}(;
+ member=cur_member, num_evals=new_num_evals, return_immediately=true
+ )
+end
+
+function mutate!(
+ tree::N,
+ member::P,
+ ::Val{:do_nothing},
+ ::AbstractMutationWeights,
+ options::AbstractOptions;
+ recorder::RecordType,
+ parent_ref,
+ kws...,
+) where {N<:AbstractExpression,P<:PopMember}
+ @recorder begin
+ recorder["type"] = "identity"
+ recorder["result"] = "accept"
+ recorder["reason"] = "identity"
+ end
+ return MutationResult{N,P}(;
+ member=PopMember(
+ tree,
+ member.score,
+ member.loss,
+ options,
+ compute_complexity(tree, options);
+ parent=parent_ref,
+ deterministic=options.deterministic,
+ ),
+ return_immediately=true,
+ )
+end
+
"""Generate a generation via crossover of two members."""
function crossover_generation(
- member1::P, member2::P, dataset::D, curmaxsize::Int, options::Options
+ member1::P, member2::P, dataset::D, curmaxsize::Int, options::AbstractOptions
)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:PopMember{T,L,N}}
tree1 = member1.tree
tree2 = member2.tree
@@ -460,4 +717,4 @@ function crossover_generation(
return baby1, baby2, crossover_accepted, num_evals
end
-end
+end # module MutateModule
diff --git a/src/MutationFunctions.jl b/src/MutationFunctions.jl
index 844160de7..496d584dd 100644
--- a/src/MutationFunctions.jl
+++ b/src/MutationFunctions.jl
@@ -14,8 +14,7 @@ using DynamicExpressions:
count_nodes,
has_constants,
has_operators
-using Compat: Returns, @inline
-using ..CoreModule: Options, DATA_TYPE
+using ..CoreModule: AbstractOptions, DATA_TYPE
"""
random_node(tree::AbstractNode; filter::F=Returns(true))
@@ -50,14 +49,16 @@ end
"""Randomly convert an operator into another one (binary->binary; unary->unary)"""
function mutate_operator(
- ex::AbstractExpression{T}, options::Options, rng::AbstractRNG=default_rng()
+ ex::AbstractExpression{T}, options::AbstractOptions, rng::AbstractRNG=default_rng()
) where {T<:DATA_TYPE}
tree = get_contents(ex)
ex = with_contents(ex, mutate_operator(tree, options, rng))
return ex
end
function mutate_operator(
- tree::AbstractExpressionNode{T}, options::Options, rng::AbstractRNG=default_rng()
+ tree::AbstractExpressionNode{T},
+ options::AbstractOptions,
+ rng::AbstractRNG=default_rng(),
) where {T}
if !(has_operators(tree))
return tree
@@ -73,7 +74,10 @@ end
"""Randomly perturb a constant"""
function mutate_constant(
- ex::AbstractExpression{T}, temperature, options::Options, rng::AbstractRNG=default_rng()
+ ex::AbstractExpression{T},
+ temperature,
+ options::AbstractOptions,
+ rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
tree = get_contents(ex)
ex = with_contents(ex, mutate_constant(tree, temperature, options, rng))
@@ -82,7 +86,7 @@ end
function mutate_constant(
tree::AbstractExpressionNode{T},
temperature,
- options::Options,
+ options::AbstractOptions,
rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
# T is between 0 and 1.
@@ -116,7 +120,7 @@ end
"""Add a random unary/binary operation to the end of a tree"""
function append_random_op(
ex::AbstractExpression{T},
- options::Options,
+ options::AbstractOptions,
nfeatures::Int,
rng::AbstractRNG=default_rng();
makeNewBinOp::Union{Bool,Nothing}=nothing,
@@ -127,7 +131,7 @@ function append_random_op(
end
function append_random_op(
tree::AbstractExpressionNode{T},
- options::Options,
+ options::AbstractOptions,
nfeatures::Int,
rng::AbstractRNG=default_rng();
makeNewBinOp::Union{Bool,Nothing}=nothing,
@@ -160,7 +164,7 @@ end
"""Insert random node"""
function insert_random_op(
ex::AbstractExpression{T},
- options::Options,
+ options::AbstractOptions,
nfeatures::Int,
rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
@@ -170,7 +174,7 @@ function insert_random_op(
end
function insert_random_op(
tree::AbstractExpressionNode{T},
- options::Options,
+ options::AbstractOptions,
nfeatures::Int,
rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
@@ -194,7 +198,7 @@ end
"""Add random node to the top of a tree"""
function prepend_random_op(
ex::AbstractExpression{T},
- options::Options,
+ options::AbstractOptions,
nfeatures::Int,
rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
@@ -204,7 +208,7 @@ function prepend_random_op(
end
function prepend_random_op(
tree::AbstractExpressionNode{T},
- options::Options,
+ options::AbstractOptions,
nfeatures::Int,
rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
@@ -230,7 +234,7 @@ function make_random_leaf(
::Type{T},
::Type{N},
rng::AbstractRNG=default_rng(),
- ::Union{Options,Nothing}=nothing,
+ ::Union{AbstractOptions,Nothing}=nothing,
) where {T<:DATA_TYPE,N<:AbstractExpressionNode}
if rand(rng, Bool)
return constructorof(N)(T; val=randn(rng, T))
@@ -255,7 +259,7 @@ end
"""Select a random node, and splice it out of the tree."""
function delete_random_op!(
ex::AbstractExpression{T},
- options::Options,
+ options::AbstractOptions,
nfeatures::Int,
rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
@@ -265,7 +269,7 @@ function delete_random_op!(
end
function delete_random_op!(
tree::AbstractExpressionNode{T},
- options::Options,
+ options::AbstractOptions,
nfeatures::Int,
rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
@@ -310,7 +314,11 @@ end
"""Create a random equation by appending random operators"""
function gen_random_tree(
- length::Int, options::Options, nfeatures::Int, ::Type{T}, rng::AbstractRNG=default_rng()
+ length::Int,
+ options::AbstractOptions,
+ nfeatures::Int,
+ ::Type{T},
+ rng::AbstractRNG=default_rng(),
) where {T<:DATA_TYPE}
# Note that this base tree is just a placeholder; it will be replaced.
tree = constructorof(options.node_type)(T; val=convert(T, 1))
@@ -323,7 +331,7 @@ end
function gen_random_tree_fixed_size(
node_count::Int,
- options::Options,
+ options::AbstractOptions,
nfeatures::Int,
::Type{T},
rng::AbstractRNG=default_rng(),
diff --git a/src/MutationWeights.jl b/src/MutationWeights.jl
index f7c0f8a70..9de15af7d 100644
--- a/src/MutationWeights.jl
+++ b/src/MutationWeights.jl
@@ -3,11 +3,78 @@ module MutationWeightsModule
using StatsBase: StatsBase
"""
- MutationWeights(;kws...)
+ AbstractMutationWeights
+
+An abstract type that defines the interface for mutation weight structures in the
+symbolic regression framework. Subtypes of `AbstractMutationWeights` specify how often
+different mutation operations occur during the mutation process.
+
+You can create custom mutation weight types by subtyping `AbstractMutationWeights` and defining your own mutation operations. Additionally, you can overload the `sample_mutation` function to handle sampling from your custom mutation types.
+
+# Usage
+
+To create a custom mutation weighting scheme with new mutation types, define a new subtype of `AbstractMutationWeights` and implement the necessary fields. Here's an example using `Base.@kwdef` to define the struct with default values:
+
+```julia
+using SymbolicRegression: AbstractMutationWeights
+
+# Define custom mutation weights with default values
+Base.@kwdef struct MyMutationWeights <: AbstractMutationWeights
+ mutate_constant::Float64 = 0.1
+ mutate_operator::Float64 = 0.2
+ custom_mutation::Float64 = 0.7
+end
+```
+
+Next, overload the `sample_mutation` function to include your custom mutation types:
+
+```julia
+# Define the list of mutation names (symbols)
+const MY_MUTATIONS = [
+ :mutate_constant,
+ :mutate_operator,
+ :custom_mutation
+]
+
+# Import the `sample_mutation` function to overload it
+import SymbolicRegression: sample_mutation
+using StatsBase: StatsBase
+
+# Overload the `sample_mutation` function
+function sample_mutation(w::MyMutationWeights)
+ weights = [
+ w.mutate_constant,
+ w.mutate_operator,
+ w.custom_mutation
+ ]
+ weights = weights ./ sum(weights) # Normalize weights to sum to 1.0
+ return StatsBase.sample(MY_MUTATIONS, StatsBase.Weights(weights))
+end
+
+# Pass it when defining `Options`:
+using SymbolicRegression: Options
+options = Options(mutation_weights=MyMutationWeights())
+```
+
+This allows you to customize the mutation sampling process to include your custom mutations according to their specified weights.
+
+To integrate your custom mutations into the mutation process, define a corresponding `mutate!` method for each new mutation type (see [`mutate!`](@ref SymbolicRegression.MutateModule.mutate!)).
+
+# See Also
+
+- [`MutationWeights`](@ref): A concrete implementation of `AbstractMutationWeights` that defines default mutation weightings.
+- [`sample_mutation`](@ref): Function to sample a mutation based on current mutation weights.
+- [`mutate!`](@ref SymbolicRegression.MutateModule.mutate!): Function to apply a mutation to an expression tree.
+- [`AbstractOptions`](@ref SymbolicRegression.OptionsStruct.AbstractOptions): See how to extend abstract types for customizing options.
+"""
+abstract type AbstractMutationWeights end
+
+"""
+ MutationWeights(;kws...) <: AbstractMutationWeights
This defines how often different mutations occur. These weightings
will be normalized to sum to 1.0 after initialization.
+
# Arguments
+
- `mutate_constant::Float64`: How often to mutate a constant.
- `mutate_operator::Float64`: How often to mutate an operator.
- `swap_operands::Float64`: How often to swap the operands of a binary operator.
@@ -27,8 +94,12 @@ will be normalized to sum to 1.0 after initialization.
- `break_connection::Float64`: **Only used for `GraphNode`, not regular `Node`**.
Otherwise, this will automatically be set to 0.0. How often to break a
connection between two nodes.
+
+# See Also
+
+- [`AbstractMutationWeights`](@ref SymbolicRegression.CoreModule.MutationWeightsModule.AbstractMutationWeights): Use to define custom mutation weight types.
"""
-Base.@kwdef mutable struct MutationWeights
+Base.@kwdef mutable struct MutationWeights <: AbstractMutationWeights
mutate_constant::Float64 = 0.048
mutate_operator::Float64 = 0.47
swap_operands::Float64 = 0.1
@@ -60,7 +131,7 @@ let contents = [Expr(:., :w, QuoteNode(field)) for field in mutations]
end
"""Sample a mutation, given the weightings."""
-function sample_mutation(w::MutationWeights)
+function sample_mutation(w::AbstractMutationWeights)
weights = convert(Vector, w)
return StatsBase.sample(v_mutations, StatsBase.Weights(weights))
end
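With this generalization, sampling from the default weights is simply a weighted draw over the built-in mutation names. A quick usage sketch:

```julia
using SymbolicRegression: MutationWeights, sample_mutation

# Unspecified fields keep their defaults, so this only biases the draw:
w = MutationWeights(; add_node=2.0, mutate_operator=1.0)
choice = sample_mutation(w)  # a Symbol such as :add_node, drawn by weight
```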
diff --git a/src/Options.jl b/src/Options.jl
index 0815bf744..701fde2ff 100644
--- a/src/Options.jl
+++ b/src/Options.jl
@@ -25,7 +25,7 @@ using ..OperatorsModule:
safe_sqrt,
safe_acosh,
atanh_clip
-using ..MutationWeightsModule: MutationWeights, mutations
+using ..MutationWeightsModule: AbstractMutationWeights, MutationWeights, mutations
import ..OptionsStructModule: Options
using ..OptionsStructModule: ComplexityMapping, operator_specialization
using ..UtilsModule: max_ops, @save_kwargs, @ignore
@@ -199,7 +199,7 @@ function inverse_unaopmap(op::F) where {F}
return op
end
-create_mutation_weights(w::MutationWeights) = w
+create_mutation_weights(w::AbstractMutationWeights) = w
create_mutation_weights(w::NamedTuple) = MutationWeights(; w...)
const deprecated_options_mapping = Base.ImmutableDict(
@@ -282,7 +282,7 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators
- `DWDMarginLoss(q)`.
- `loss_function`: Alternatively, you may redefine the loss used
as any function of `tree::AbstractExpressionNode{T}`, `dataset::Dataset{T}`,
- and `options::Options`, so long as you output a non-negative
+ and `options::AbstractOptions`, so long as you output a non-negative
scalar of type `T`. This is useful if you want to use a loss
that takes into account derivatives, or correlations across
the dataset. This also means you could use a custom evaluation
@@ -388,7 +388,7 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators
- `probability_negate_constant`: Probability of negating a constant in the equation
when mutating it.
- `mutation_weights`: Relative probabilities of the mutations. The struct
- `MutationWeights` should be passed to these options.
+ `MutationWeights` (or any `AbstractMutationWeights`) should be passed to these options.
See its documentation on `MutationWeights` for the different weights.
- `crossover_probability`: Probability of performing crossover.
- `annealing`: Whether to use simulated annealing.
@@ -429,7 +429,7 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators
"""
"""
- Options(;kws...)
+ Options(;kws...) <: AbstractOptions
Construct options for `equation_search` and other functions.
The current arguments have been tuned using the median values from
@@ -471,7 +471,7 @@ $(OPTION_DESCRIPTIONS)
annealing::Bool=false,
batching::Bool=false,
batch_size::Integer=50,
- mutation_weights::Union{MutationWeights,AbstractVector,NamedTuple}=MutationWeights(),
+ mutation_weights::Union{AbstractMutationWeights,AbstractVector,NamedTuple}=MutationWeights(),
crossover_probability::Real=0.066,
warmup_maxsize_by::Real=0.0,
use_frequency::Bool=true,
@@ -737,6 +737,7 @@ $(OPTION_DESCRIPTIONS)
node_type,
expression_type,
typeof(expression_options),
+ typeof(set_mutation_weights),
turbo,
bumper,
deprecated_return_state,
diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl
index 2417b75ba..ba20fe92c 100644
--- a/src/OptionsStruct.jl
+++ b/src/OptionsStruct.jl
@@ -6,7 +6,7 @@ using DynamicExpressions:
AbstractOperatorEnum, AbstractExpressionNode, AbstractExpression, OperatorEnum
using LossFunctions: SupervisedLoss
-import ..MutationWeightsModule: MutationWeights
+import ..MutationWeightsModule: AbstractMutationWeights
"""
This struct defines how complexity is calculated.
@@ -121,17 +121,71 @@ else
@eval operator_specialization(O::Type{<:OperatorEnum}) = O
end
+"""
+ AbstractOptions
+
+An abstract type that stores all search hyperparameters for SymbolicRegression.jl.
+The standard implementation is [`Options`](@ref).
+
+You may wish to create a new subtype of `AbstractOptions` to override certain functions
+or create new behavior. Ensure that this new type has all properties of [`Options`](@ref).
+
+For example, if we have new options that we want to add to `Options`:
+
+```julia
+Base.@kwdef struct MyNewOptions
+ a::Float64 = 1.0
+ b::Int = 3
+end
+```
+
+we can create a combined options type that forwards properties to each corresponding type:
+
+```julia
+struct MyOptions{O<:SymbolicRegression.Options} <: SymbolicRegression.AbstractOptions
+ new_options::MyNewOptions
+ sr_options::O
+end
+const NEW_OPTIONS_KEYS = fieldnames(MyNewOptions)
+
+# Constructor with both sets of parameters:
+function MyOptions(; kws...)
+ new_options_keys = filter(k -> k in NEW_OPTIONS_KEYS, keys(kws))
+ new_options = MyNewOptions(; NamedTuple(new_options_keys .=> Tuple(kws[k] for k in new_options_keys))...)
+ sr_options_keys = filter(k -> !(k in NEW_OPTIONS_KEYS), keys(kws))
+ sr_options = SymbolicRegression.Options(; NamedTuple(sr_options_keys .=> Tuple(kws[k] for k in sr_options_keys))...)
+ return MyOptions(new_options, sr_options)
+end
+
+# Make all `Options` available while also making `new_options` accessible
+function Base.getproperty(options::MyOptions, k::Symbol)
+ if k in NEW_OPTIONS_KEYS
+ return getproperty(getfield(options, :new_options), k)
+ else
+ return getproperty(getfield(options, :sr_options), k)
+ end
+end
+
+Base.propertynames(options::MyOptions) = (NEW_OPTIONS_KEYS..., fieldnames(SymbolicRegression.Options)...)
+```
+
+which lets you access `a` and `b` from `MyOptions` objects, while also making
+all properties of `Options` available to internal methods in SymbolicRegression.jl.
+"""
+abstract type AbstractOptions end
+
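Continuing the docstring's `MyOptions` sketch (hypothetical fields `a` and `b` from `MyNewOptions`), construction and property access would look roughly like:

```julia
options = MyOptions(; a=2.0, binary_operators=[+, *])
options.a          # 2.0, served from the `new_options` field
options.b          # 3, the `MyNewOptions` default
options.operators  # forwarded from the wrapped `Options`
```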
struct Options{
CM<:ComplexityMapping,
OP<:AbstractOperatorEnum,
N<:AbstractExpressionNode,
E<:AbstractExpression,
EO<:NamedTuple,
+ MW<:AbstractMutationWeights,
_turbo,
_bumper,
_return_state,
AD,
-}
+} <: AbstractOptions
operators::OP
bin_constraints::Vector{Tuple{Int,Int}}
una_constraints::Vector{Int}
@@ -156,7 +210,7 @@ struct Options{
annealing::Bool
batching::Bool
batch_size::Int
- mutation_weights::MutationWeights
+ mutation_weights::MW
crossover_probability::Float32
warmup_maxsize_by::Float32
use_frequency::Bool
@@ -223,6 +277,7 @@ function Base.print(io::IO, options::Options)
end
Base.show(io::IO, ::MIME"text/plain", options::Options) = Base.print(io, options)
+specialized_options(options::AbstractOptions) = options
@unstable function specialized_options(options::Options)
return _specialized_options(options)
end
diff --git a/src/PopMember.jl b/src/PopMember.jl
index 84f29f451..63cfebfb4 100644
--- a/src/PopMember.jl
+++ b/src/PopMember.jl
@@ -2,7 +2,7 @@ module PopMemberModule
using DispatchDoctor: @unstable
using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree
-using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE, create_expression
+using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, create_expression
import ..ComplexityModule: compute_complexity
using ..UtilsModule: get_birth_order
using ..LossFunctionsModule: score_func
@@ -61,7 +61,7 @@ function PopMember(
t::AbstractExpression{T},
score::L,
loss::L,
- options::Union{Options,Nothing}=nothing,
+ options::Union{AbstractOptions,Nothing}=nothing,
complexity::Union{Int,Nothing}=nothing;
ref::Int=-1,
parent::Int=-1,
@@ -93,7 +93,7 @@ end
PopMember(
dataset::Dataset{T,L},
t::AbstractExpression{T},
- options::Options
+ options::AbstractOptions
)
Create a population member with a birth date at the current time.
@@ -103,12 +103,12 @@ Automatically compute the score for this tree.
- `dataset::Dataset{T,L}`: The dataset to evaluate the tree on.
- `t::AbstractExpression{T}`: The tree for the population member.
-- `options::Options`: What options to use.
+- `options::AbstractOptions`: What options to use.
"""
function PopMember(
dataset::Dataset{T,L},
tree::Union{AbstractExpressionNode{T},AbstractExpression{T}},
- options::Options,
+ options::AbstractOptions,
complexity::Union{Int,Nothing}=nothing;
ref::Int=-1,
parent::Int=-1,
@@ -148,7 +148,7 @@ end
# Can read off complexity directly from pop members
function compute_complexity(
- member::PopMember, options::Options; break_sharing=Val(false)
+ member::PopMember, options::AbstractOptions; break_sharing=Val(false)
)::Int
complexity = getfield(member, :complexity)
complexity == -1 && return recompute_complexity!(member, options; break_sharing)
@@ -156,7 +156,7 @@ function compute_complexity(
return complexity
end
function recompute_complexity!(
- member::PopMember, options::Options; break_sharing=Val(false)
+ member::PopMember, options::AbstractOptions; break_sharing=Val(false)
)::Int
complexity = compute_complexity(member.tree, options; break_sharing)
setfield!(member, :complexity, complexity)
diff --git a/src/Population.jl b/src/Population.jl
index 9bc1df309..3a544730b 100644
--- a/src/Population.jl
+++ b/src/Population.jl
@@ -3,7 +3,7 @@ module PopulationModule
using StatsBase: StatsBase
using DispatchDoctor: @unstable
using DynamicExpressions: AbstractExpression, string_tree
-using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE
+using ..CoreModule: AbstractOptions, Dataset, RecordType, DATA_TYPE, LOSS_TYPE
using ..ComplexityModule: compute_complexity
using ..LossFunctionsModule: score_func, update_baseline_loss!
using ..AdaptiveParsimonyModule: RunningSearchStatistics
@@ -27,14 +27,14 @@ end
"""
Population(dataset::Dataset{T,L};
- population_size, nlength::Int=3, options::Options,
+ population_size, nlength::Int=3, options::AbstractOptions,
nfeatures::Int)
Create random population and score them on the dataset.
"""
function Population(
dataset::Dataset{T,L};
- options::Options,
+ options::AbstractOptions,
population_size=nothing,
nlength::Int=3,
nfeatures::Int,
@@ -62,7 +62,7 @@ end
"""
Population(X::AbstractMatrix{T}, y::AbstractVector{T};
population_size, nlength::Int=3,
- options::Options, nfeatures::Int,
+ options::AbstractOptions, nfeatures::Int,
loss_type::Type=Nothing)
Create random population and score them on the dataset.
@@ -72,7 +72,7 @@ Create random population and score them on the dataset.
y::AbstractVector{T};
population_size=nothing,
nlength::Int=3,
- options::Options,
+ options::AbstractOptions,
nfeatures::Int,
loss_type::Type{L}=Nothing,
npop=nothing,
@@ -99,7 +99,7 @@ function Base.copy(pop::P)::P where {T,L,N,P<:Population{T,L,N}}
end
# Sample random members of the population, and make a new one
-function sample_pop(pop::P, options::Options)::P where {P<:Population}
+function sample_pop(pop::P, options::AbstractOptions)::P where {P<:Population}
return Population(
StatsBase.sample(pop.members, options.tournament_selection_n; replace=false)
)
@@ -109,7 +109,7 @@ end
function best_of_sample(
pop::Population{T,L,N},
running_search_statistics::RunningSearchStatistics,
- options::Options,
+ options::AbstractOptions,
) where {T,L,N}
sample = sample_pop(pop, options)
return _best_of_sample(
@@ -117,7 +117,9 @@ function best_of_sample(
)::PopMember{T,L,N}
end
function _best_of_sample(
- members::Vector{P}, running_search_statistics::RunningSearchStatistics, options::Options
+ members::Vector{P},
+ running_search_statistics::RunningSearchStatistics,
+ options::AbstractOptions,
) where {T,L,P<:PopMember{T,L}}
p = options.tournament_selection_p
n = length(members) # == tournament_selection_n
@@ -166,7 +168,7 @@ const CACHED_WEIGHTS =
PerThreadCache{Dict{Tuple{Int,Float32},typeof(test_weights)}}()
end
-@unstable function get_tournament_selection_weights(@nospecialize(options::Options))
+@unstable function get_tournament_selection_weights(@nospecialize(options::AbstractOptions))
n = options.tournament_selection_n
p = options.tournament_selection_p
# Computing the weights for the tournament becomes quite expensive,
@@ -179,7 +181,7 @@ const CACHED_WEIGHTS =
end
function finalize_scores(
- dataset::Dataset{T,L}, pop::P, options::Options
+ dataset::Dataset{T,L}, pop::P, options::AbstractOptions
)::Tuple{P,Float64} where {T,L,P<:Population{T,L}}
need_recalculate = options.batching
num_evals = 0.0
@@ -200,7 +202,7 @@ function best_sub_pop(pop::P; topn::Int=10)::P where {P<:Population}
return Population(pop.members[best_idx[1:topn]])
end
-function record_population(pop::Population, options::Options)::RecordType
+function record_population(pop::Population, options::AbstractOptions)::RecordType
return RecordType(
"population" => [
RecordType(
diff --git a/src/Recorder.jl b/src/Recorder.jl
index a25ac0e78..171a1f46b 100644
--- a/src/Recorder.jl
+++ b/src/Recorder.jl
@@ -2,7 +2,7 @@ module RecorderModule
using ..CoreModule: RecordType
-"Assumes that `options` holds the user options::Options"
+"Assumes that `options` holds the user options::AbstractOptions"
macro recorder(ex)
quote
if $(esc(:options)).use_recorder
diff --git a/src/RegularizedEvolution.jl b/src/RegularizedEvolution.jl
index 141e85888..06358a328 100644
--- a/src/RegularizedEvolution.jl
+++ b/src/RegularizedEvolution.jl
@@ -1,7 +1,7 @@
module RegularizedEvolutionModule
using DynamicExpressions: string_tree
-using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE
+using ..CoreModule: AbstractOptions, Dataset, RecordType, DATA_TYPE, LOSS_TYPE
using ..PopulationModule: Population, best_of_sample
using ..AdaptiveParsimonyModule: RunningSearchStatistics
using ..MutateModule: next_generation, crossover_generation
@@ -16,7 +16,7 @@ function reg_evol_cycle(
temperature,
curmaxsize::Int,
running_search_statistics::RunningSearchStatistics,
- options::Options,
+ options::AbstractOptions,
record::RecordType,
)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:Population{T,L}}
# Batch over each subsample. Can give 15% improvement in speed; probably moreso for large pops.
diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl
index a540e35f6..f1ba1cfb3 100644
--- a/src/SearchUtils.jl
+++ b/src/SearchUtils.jl
@@ -4,13 +4,13 @@ This includes: process management, stdin reading, checking for early stops."""
module SearchUtilsModule
using Printf: @printf, @sprintf
-using Distributed: Distributed, @spawnat, Future, procs
+using Distributed: Distributed, @spawnat, Future, procs, addprocs
using StatsBase: mean
using DispatchDoctor: @unstable
using DynamicExpressions: AbstractExpression, string_tree
using ..UtilsModule: subscriptify
-using ..CoreModule: Dataset, Options, MAX_DEGREE, RecordType
+using ..CoreModule: Dataset, AbstractOptions, Options, MAX_DEGREE, RecordType
using ..ComplexityModule: compute_complexity
using ..PopulationModule: Population
using ..PopMemberModule: PopMember
@@ -19,16 +19,33 @@ using ..ProgressBarsModule: WrappedProgressBar, set_multiline_postfix!, manually
using ..AdaptiveParsimonyModule: RunningSearchStatistics
"""
- RuntimeOptions{N,PARALLELISM,DIM_OUT,RETURN_STATE}
+ AbstractRuntimeOptions
+
+An abstract type representing runtime configuration parameters for the symbolic regression algorithm.
+
+`AbstractRuntimeOptions` is used by `equation_search` to control runtime aspects such
+as parallelism and iteration limits. By subtyping `AbstractRuntimeOptions`, advanced
+users can customize runtime behavior by passing their own subtype to `equation_search`.
+
+# See Also
+
+- [`RuntimeOptions`](@ref): Default implementation used by `equation_search`.
+- [`equation_search`](@ref SymbolicRegression.equation_search): Main function to perform symbolic regression.
+- [`AbstractOptions`](@ref SymbolicRegression.CoreModule.OptionsStruct.AbstractOptions): See how to extend abstract types for customizing options.
+
+"""
+abstract type AbstractRuntimeOptions end
+
+"""
+ RuntimeOptions{N,PARALLELISM,DIM_OUT,RETURN_STATE} <: AbstractRuntimeOptions
Parameters for a search that are passed to `equation_search` directly,
rather than set within `Options`. This is to differentiate between
parameters that relate to processing and the duration of the search,
and parameters dealing with the search hyperparameters themselves.
"""
-Base.@kwdef struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE}
+struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE} <: AbstractRuntimeOptions
niterations::Int64
- total_cycles::Int64
numprocs::Int64
init_procs::Union{Vector{Int},Nothing}
addprocs_function::Function
@@ -57,6 +74,141 @@ function Base.propertynames(roptions::RuntimeOptions)
return (Base.fieldnames(typeof(roptions))..., :parallelism, :dim_out, :return_state)
end
+@unstable function RuntimeOptions(;
+ niterations::Int=10,
+ nout::Int=1,
+ options::AbstractOptions=Options(),
+ parallelism=:multithreading,
+ numprocs::Union{Int,Nothing}=nothing,
+ procs::Union{Vector{Int},Nothing}=nothing,
+ addprocs_function::Union{Function,Nothing}=nothing,
+ heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing,
+ runtests::Bool=true,
+ return_state::Union{Bool,Nothing,Val}=nothing,
+ verbosity::Union{Int,Nothing}=nothing,
+ progress::Union{Bool,Nothing}=nothing,
+ v_dim_out::Val{DIM_OUT}=Val(nothing),
+) where {DIM_OUT}
+ concurrency = if parallelism in (:multithreading, "multithreading")
+ :multithreading
+ elseif parallelism in (:multiprocessing, "multiprocessing")
+ :multiprocessing
+ elseif parallelism in (:serial, "serial")
+ :serial
+ else
+ error(
+ "Invalid parallelism mode: $parallelism. " *
+ "You must choose one of :multithreading, :multiprocessing, or :serial.",
+ )
+ :serial
+ end
+ not_distributed = concurrency in (:multithreading, :serial)
+ not_distributed &&
+ procs !== nothing &&
+ error(
+ "`procs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.",
+ )
+ not_distributed &&
+ numprocs !== nothing &&
+ error(
+ "`numprocs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.",
+ )
+
+ _return_state = if return_state isa Val
+ first(typeof(return_state).parameters)
+ else
+ if options.return_state === Val(nothing)
+ return_state === nothing ? false : return_state
+ else
+ @assert(
+ return_state === nothing,
+ "You cannot set `return_state` in both the `AbstractOptions` and in the passed arguments."
+ )
+ first(typeof(options.return_state).parameters)
+ end
+ end
+
+ dim_out = if DIM_OUT === nothing
+ nout > 1 ? 2 : 1
+ else
+ DIM_OUT
+ end
+ _numprocs::Int = if numprocs === nothing
+ if procs === nothing
+ 4
+ else
+ length(procs)
+ end
+ else
+ if procs === nothing
+ numprocs
+ else
+ @assert length(procs) == numprocs
+ numprocs
+ end
+ end
+
+ _verbosity = if verbosity === nothing && options.verbosity === nothing
+ 1
+ elseif verbosity === nothing && options.verbosity !== nothing
+ options.verbosity
+ elseif verbosity !== nothing && options.verbosity === nothing
+ verbosity
+ else
+ error(
+ "You cannot set `verbosity` in both the search parameters `AbstractOptions` and the call to `equation_search`.",
+ )
+ 1
+ end
+ _progress::Bool = if progress === nothing && options.progress === nothing
+ (_verbosity > 0) && nout == 1
+ elseif progress === nothing && options.progress !== nothing
+ options.progress
+ elseif progress !== nothing && options.progress === nothing
+ progress
+ else
+ error(
+ "You cannot set `progress` in both the search parameters `AbstractOptions` and the call to `equation_search`.",
+ )
+ false
+ end
+
+ _addprocs_function = addprocs_function === nothing ? addprocs : addprocs_function
+
+ exeflags = if concurrency == :multiprocessing
+ heap_size_hint_in_megabytes = floor(
+ Int, (
+ if heap_size_hint_in_bytes === nothing
+ (Sys.free_memory() / _numprocs)
+ else
+ heap_size_hint_in_bytes
+ end
+ ) / 1024^2
+ )
+ _verbosity > 0 &&
+ heap_size_hint_in_bytes === nothing &&
+ @info "Automatically setting `--heap-size-hint=$(heap_size_hint_in_megabytes)M` on each Julia process. You can configure this with the `heap_size_hint_in_bytes` parameter."
+
+ `--heap-size-hint=$(heap_size_hint_in_megabytes)M`
+ else
+ ``
+ end
+
+ return RuntimeOptions{concurrency,dim_out,_return_state}(
+ niterations,
+ _numprocs,
+ procs,
+ _addprocs_function,
+ exeflags,
+ runtests,
+ _verbosity,
+ _progress,
+ Val(concurrency),
+ Val(dim_out),
+ Val(_return_state),
+ )
+end
+
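With this constructor, runtime configuration can be built up front and handed to `equation_search` through the new `runtime_options` keyword (see the `src/SymbolicRegression.jl` changes below). A hedged sketch:

```julia
using SymbolicRegression

X = randn(2, 32)
y = 2 .* cos.(X[2, :]) .+ X[1, :] .^ 2
dataset = Dataset(X, y)
options = Options(; binary_operators=[+, *], unary_operators=[cos])

# Equivalent to passing `niterations`/`parallelism` directly as keywords:
ropt = SymbolicRegression.RuntimeOptions(;
    niterations=5, parallelism=:serial, options
)
hof = equation_search([dataset]; options, runtime_options=ropt)
```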
"""A simple dictionary to track worker allocations."""
const WorkerAssignments = Dict{Tuple{Int,Int},Int}
@@ -126,7 +278,7 @@ macro sr_spawner(expr, kws...)
end
function init_dummy_pops(
- npops::Int, datasets::Vector{D}, options::Options
+ npops::Int, datasets::Vector{D}, options::AbstractOptions
) where {T,L,D<:Dataset{T,L}}
prototype = Population(
first(datasets);
@@ -201,14 +353,14 @@ function check_for_user_quit(reader::StdinReader)::Bool
return false
end
-function check_for_loss_threshold(halls_of_fame, options::Options)::Bool
+function check_for_loss_threshold(halls_of_fame, options::AbstractOptions)::Bool
return _check_for_loss_threshold(halls_of_fame, options.early_stop_condition, options)
end
-function _check_for_loss_threshold(_, ::Nothing, ::Options)
+function _check_for_loss_threshold(_, ::Nothing, ::AbstractOptions)
return false
end
-function _check_for_loss_threshold(halls_of_fame, f::F, options::Options) where {F}
+function _check_for_loss_threshold(halls_of_fame, f::F, options::AbstractOptions) where {F}
return all(halls_of_fame) do hof
any(hof.members[hof.exists]) do member
f(member.loss, compute_complexity(member, options))::Bool
@@ -216,12 +368,12 @@ function _check_for_loss_threshold(halls_of_fame, f::F, options::Options) where
end
end
-function check_for_timeout(start_time::Float64, options::Options)::Bool
+function check_for_timeout(start_time::Float64, options::AbstractOptions)::Bool
return options.timeout_in_seconds !== nothing &&
time() - start_time > options.timeout_in_seconds::Float64
end
-function check_max_evals(num_evals, options::Options)::Bool
+function check_max_evals(num_evals, options::AbstractOptions)::Bool
return options.max_evals !== nothing && options.max_evals::Int <= sum(sum, num_evals)
end
@@ -277,7 +429,7 @@ function update_progress_bar!(
progress_bar::WrappedProgressBar,
hall_of_fame::HallOfFame{T,L},
dataset::Dataset{T,L},
- options::Options,
+ options::AbstractOptions,
equation_speed::Vector{Float32},
head_node_occupation::Float64,
parallelism=:serial,
@@ -306,7 +458,7 @@ end
function print_search_state(
hall_of_fames,
datasets;
- options::Options,
+ options::AbstractOptions,
equation_speed::Vector{Float32},
total_cycles::Int,
cycles_remaining::Vector{Int},
@@ -370,13 +522,34 @@ end
load_saved_population(::Nothing; kws...) = nothing
"""
- SearchState{PopType,HallOfFameType,WorkerOutputType,ChannelType}
+ AbstractSearchState{T,L,N}
+
+An abstract type encapsulating the internal state of the search process during symbolic regression.
+
+`AbstractSearchState` instances hold information like populations and progress metrics,
+used internally by `equation_search`. Subtyping `AbstractSearchState` allows
+customization of search state management.
+
+Look through the source of `equation_search` to see how this is used.
+
+# See Also
+
+- [`SearchState`](@ref): Default implementation of `AbstractSearchState`.
+- [`equation_search`](@ref SymbolicRegression.equation_search): Function where `AbstractSearchState` is utilized.
+- [`AbstractOptions`](@ref SymbolicRegression.CoreModule.OptionsStruct.AbstractOptions): See how to extend abstract types for customizing options.
+
+"""
+abstract type AbstractSearchState{T,L,N<:AbstractExpression{T}} end
-The state of a search, including the populations, worker outputs, tasks, and
+"""
+ SearchState{T,L,N,WorkerOutputType,ChannelType} <: AbstractSearchState{T,L,N}
+
+The state of the search, including the populations, worker outputs, tasks, and
channels. This is used to manage the search and keep track of runtime variables
in a single struct.
"""
-Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,ChannelType}
+Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,ChannelType} <:
+ AbstractSearchState{T,L,N}
procs::Vector{Int}
we_created_procs::Bool
worker_output::Vector{Vector{WorkerOutputType}}
@@ -396,7 +569,7 @@ Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,Cha
end
function save_to_file(
- dominating, nout::Integer, j::Integer, dataset::Dataset{T,L}, options::Options
+ dominating, nout::Integer, j::Integer, dataset::Dataset{T,L}, options::AbstractOptions
) where {T,L}
output_file = options.output_file
if nout > 1
@@ -443,7 +616,9 @@ end
For searches where the maxsize gradually increases, this function returns the
current maxsize.
"""
-function get_cur_maxsize(; options::Options, total_cycles::Int, cycles_remaining::Int)
+function get_cur_maxsize(;
+ options::AbstractOptions, total_cycles::Int, cycles_remaining::Int
+)
cycles_elapsed = total_cycles - cycles_remaining
fraction_elapsed = 1.0f0 * cycles_elapsed / total_cycles
in_warmup_period = fraction_elapsed <= options.warmup_maxsize_by
@@ -502,7 +677,7 @@ function construct_datasets(
end
function update_hall_of_fame!(
- hall_of_fame::HallOfFame, members::Vector{PM}, options::Options
+ hall_of_fame::HallOfFame, members::Vector{PM}, options::AbstractOptions
) where {PM<:PopMember}
for member in members
size = compute_complexity(member, options)
diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl
index 582f25b15..d15e7914c 100644
--- a/src/SingleIteration.jl
+++ b/src/SingleIteration.jl
@@ -3,7 +3,7 @@ module SingleIterationModule
using ADTypes: AutoEnzyme
using DynamicExpressions: AbstractExpression, string_tree, simplify_tree!, combine_operators
using ..UtilsModule: @threads_if
-using ..CoreModule: Options, Dataset, RecordType, create_expression
+using ..CoreModule: AbstractOptions, Dataset, RecordType, create_expression
using ..ComplexityModule: compute_complexity
using ..PopMemberModule: generate_reference
using ..PopulationModule: Population, finalize_scores
@@ -23,7 +23,7 @@ function s_r_cycle(
curmaxsize::Int,
running_search_statistics::RunningSearchStatistics;
verbosity::Int=0,
- options::Options,
+ options::AbstractOptions,
record::RecordType,
)::Tuple{
P,HallOfFame{T,L,N},Float64
@@ -98,7 +98,7 @@ function s_r_cycle(
end
function optimize_and_simplify_population(
- dataset::D, pop::P, options::Options, curmaxsize::Int, record::RecordType
+ dataset::D, pop::P, options::AbstractOptions, curmaxsize::Int, record::RecordType
)::Tuple{P,Float64} where {T,L,D<:Dataset{T,L},P<:Population{T,L}}
array_num_evals = zeros(Float64, pop.n)
do_optimization = rand(pop.n) .< options.optimizer_probability
diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl
index 8dfb53837..02070f324 100644
--- a/src/SymbolicRegression.jl
+++ b/src/SymbolicRegression.jl
@@ -30,6 +30,7 @@ export Population,
compute_complexity,
@parse_expression,
parse_expression,
+ @declare_expression_operator,
print_tree,
string_tree,
eval_tree_array,
@@ -80,7 +81,6 @@ export Population,
using Distributed
using Printf: @printf, @sprintf
-using PackageExtensionCompat: @require_extensions
using Pkg: Pkg
using TOML: parsefile
using Random: seed!, shuffle!
@@ -95,8 +95,10 @@ using DynamicExpressions:
NodeSampler,
AbstractExpression,
AbstractExpressionNode,
+ ExpressionInterface,
@parse_expression,
parse_expression,
+ @declare_expression_operator,
copy_node,
set_node!,
string_tree,
@@ -151,6 +153,21 @@ using DynamicExpressions: with_type_parameters
LogitDistLoss,
QuantileLoss,
LogCoshLoss
+using Compat: @compat
+
+@compat public AbstractOptions,
+AbstractRuntimeOptions,
+RuntimeOptions,
+AbstractMutationWeights,
+mutate!,
+condition_mutation_weights!,
+sample_mutation,
+MutationResult,
+AbstractSearchState,
+SearchState
+# ^ We can add new functions here based on requests from users.
+# However, I don't want to add many functions without knowing what
+# users will actually want to overload.
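Since these names are declared `public` but not exported, downstream code should import them explicitly:

```julia
using SymbolicRegression: AbstractOptions, AbstractMutationWeights, mutate!
```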
# https://discourse.julialang.org/t/how-to-find-out-the-version-of-a-package-from-its-module/37755/15
const PACKAGE_VERSION = try
@@ -210,8 +227,11 @@ using .CoreModule:
LOSS_TYPE,
RecordType,
Dataset,
+ AbstractOptions,
Options,
+ AbstractMutationWeights,
MutationWeights,
+ sample_mutation,
plus,
sub,
mult,
@@ -253,12 +273,15 @@ using .PopMemberModule: PopMember, reset_birth!
using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample
using .HallOfFameModule:
HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve
+using .MutateModule: mutate!, condition_mutation_weights!, MutationResult
using .SingleIterationModule: s_r_cycle, optimize_and_simplify_population
using .ProgressBarsModule: WrappedProgressBar
using .RecorderModule: @recorder, find_iteration_from_record
using .MigrationModule: migrate!
using .SearchUtilsModule:
+ AbstractSearchState,
SearchState,
+ AbstractRuntimeOptions,
RuntimeOptions,
WorkerAssignments,
DefaultWorkerOutputType,
@@ -312,7 +335,7 @@ which is useful for debugging and profiling.
More iterations will improve the results.
- `weights::Union{AbstractMatrix{T}, AbstractVector{T}, Nothing}=nothing`: Optionally
weight the loss for each `y` by this value (same shape as `y`).
-- `options::Options=Options()`: The options for the search, such as
+- `options::AbstractOptions=Options()`: The options for the search, such as
which operators to use, evolution hyperparameters, etc.
- `variable_names::Union{Vector{String}, Nothing}=nothing`: The names
of each feature in `X`, which will be used during printing of equations.
@@ -390,7 +413,7 @@ function equation_search(
y::AbstractMatrix{T};
niterations::Int=10,
weights::Union{AbstractMatrix{T},AbstractVector{T},Nothing}=nothing,
- options::Options=Options(),
+ options::AbstractOptions=Options(),
variable_names::Union{AbstractVector{String},Nothing}=nothing,
display_variable_names::Union{AbstractVector{String},Nothing}=variable_names,
y_variable_names::Union{String,AbstractVector{String},Nothing}=nothing,
@@ -478,149 +501,23 @@ end
function equation_search(
datasets::Vector{D};
- niterations::Int=10,
- options::Options=Options(),
- parallelism=:multithreading,
- numprocs::Union{Int,Nothing}=nothing,
- procs::Union{Vector{Int},Nothing}=nothing,
- addprocs_function::Union{Function,Nothing}=nothing,
- heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing,
- runtests::Bool=true,
+ options::AbstractOptions=Options(),
saved_state=nothing,
- return_state::Union{Bool,Nothing,Val}=nothing,
- verbosity::Union{Int,Nothing}=nothing,
- progress::Union{Bool,Nothing}=nothing,
- v_dim_out::Val{DIM_OUT}=Val(nothing),
-) where {DIM_OUT,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}}
- concurrency = if parallelism in (:multithreading, "multithreading")
- :multithreading
- elseif parallelism in (:multiprocessing, "multiprocessing")
- :multiprocessing
- elseif parallelism in (:serial, "serial")
- :serial
+ runtime_options::Union{AbstractRuntimeOptions,Nothing}=nothing,
+ runtime_options_kws...,
+) where {T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}}
+ runtime_options = if runtime_options === nothing
+ RuntimeOptions(; options, nout=length(datasets), runtime_options_kws...)
else
- error(
- "Invalid parallelism mode: $parallelism. " *
- "You must choose one of :multithreading, :multiprocessing, or :serial.",
- )
- :serial
- end
- not_distributed = concurrency in (:multithreading, :serial)
- not_distributed &&
- procs !== nothing &&
- error(
- "`procs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.",
- )
- not_distributed &&
- numprocs !== nothing &&
- error(
- "`numprocs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.",
- )
-
- _return_state = if return_state isa Val
- first(typeof(return_state).parameters)
- else
- if options.return_state === Val(nothing)
- return_state === nothing ? false : return_state
- else
- @assert(
- return_state === nothing,
- "You cannot set `return_state` in both the `Options` and in the passed arguments."
- )
- first(typeof(options.return_state).parameters)
- end
- end
-
- dim_out = if DIM_OUT === nothing
- length(datasets) > 1 ? 2 : 1
- else
- DIM_OUT
- end
- _numprocs::Int = if numprocs === nothing
- if procs === nothing
- 4
- else
- length(procs)
- end
- else
- if procs === nothing
- numprocs
- else
- @assert length(procs) == numprocs
- numprocs
- end
- end
-
- _verbosity = if verbosity === nothing && options.verbosity === nothing
- 1
- elseif verbosity === nothing && options.verbosity !== nothing
- options.verbosity
- elseif verbosity !== nothing && options.verbosity === nothing
- verbosity
- else
- error(
- "You cannot set `verbosity` in both the search parameters `Options` and the call to `equation_search`.",
- )
- 1
- end
- _progress::Bool = if progress === nothing && options.progress === nothing
- (_verbosity > 0) && length(datasets) == 1
- elseif progress === nothing && options.progress !== nothing
- options.progress
- elseif progress !== nothing && options.progress === nothing
- progress
- else
- error(
- "You cannot set `progress` in both the search parameters `Options` and the call to `equation_search`.",
- )
- false
- end
-
- _addprocs_function = addprocs_function === nothing ? addprocs : addprocs_function
-
- exeflags = if VERSION >= v"1.9" && concurrency == :multiprocessing
- heap_size_hint_in_megabytes = floor(
- Int, (
- if heap_size_hint_in_bytes === nothing
- (Sys.free_memory() / _numprocs)
- else
- heap_size_hint_in_bytes
- end
- ) / 1024^2
- )
- _verbosity > 0 &&
- heap_size_hint_in_bytes === nothing &&
- @info "Automatically setting `--heap-size-hint=$(heap_size_hint_in_megabytes)M` on each Julia process. You can configure this with the `heap_size_hint_in_bytes` parameter."
-
- `--heap-size-hint=$(heap_size_hint_in_megabytes)M`
- else
- ``
+ runtime_options
end
# Underscores here mean that we have mutated the variable
- return _equation_search(
- datasets,
- RuntimeOptions(;
- niterations=niterations,
- total_cycles=options.populations * niterations,
- numprocs=_numprocs,
- init_procs=procs,
- addprocs_function=_addprocs_function,
- exeflags=exeflags,
- runtests=runtests,
- verbosity=_verbosity,
- progress=_progress,
- parallelism=Val(concurrency),
- dim_out=Val(dim_out),
- return_state=Val(_return_state),
- ),
- options,
- saved_state,
- )
+ return _equation_search(datasets, runtime_options, options, saved_state)
end
@noinline function _equation_search(
- datasets::Vector{D}, ropt::RuntimeOptions, options::Options, saved_state
+ datasets::Vector{D}, ropt::AbstractRuntimeOptions, options::AbstractOptions, saved_state
) where {D<:Dataset}
_validate_options(datasets, ropt, options)
state = _create_workers(datasets, ropt, options)
@@ -632,7 +529,7 @@ end
end
function _validate_options(
- datasets::Vector{D}, ropt::RuntimeOptions, options::Options
+ datasets::Vector{D}, ropt::AbstractRuntimeOptions, options::AbstractOptions
) where {T,L,D<:Dataset{T,L}}
example_dataset = first(datasets)
nout = length(datasets)
@@ -662,7 +559,7 @@ function _validate_options(
return nothing
end
@stable default_mode = "disable" function _create_workers(
- datasets::Vector{D}, ropt::RuntimeOptions, options::Options
+ datasets::Vector{D}, ropt::AbstractRuntimeOptions, options::AbstractOptions
) where {T,L,D<:Dataset{T,L}}
stdin_reader = watch_stream(stdin)
@@ -723,10 +620,11 @@ end
halls_of_fame = Vector{HallOfFameType}(undef, nout)
- cycles_remaining = [ropt.total_cycles for j in 1:nout]
+ total_cycles = ropt.niterations * options.populations
+ cycles_remaining = [total_cycles for j in 1:nout]
cur_maxsizes = [
- get_cur_maxsize(; options, ropt.total_cycles, cycles_remaining=cycles_remaining[j])
- for j in 1:nout
+ get_cur_maxsize(; options, total_cycles, cycles_remaining=cycles_remaining[j]) for
+ j in 1:nout
]
return SearchState{T,L,typeof(example_ex),WorkerOutputType,ChannelType}(;
@@ -749,7 +647,11 @@ end
)
end
function _initialize_search!(
- state::SearchState{T,L,N}, datasets, ropt::RuntimeOptions, options::Options, saved_state
+ state::AbstractSearchState{T,L,N},
+ datasets,
+ ropt::AbstractRuntimeOptions,
+ options::AbstractOptions,
+ saved_state,
) where {T,L,N}
nout = length(datasets)
@@ -823,7 +725,10 @@ function _initialize_search!(
return nothing
end
function _warmup_search!(
- state::SearchState{T,L,N}, datasets, ropt::RuntimeOptions, options::Options
+ state::AbstractSearchState{T,L,N},
+ datasets,
+ ropt::AbstractRuntimeOptions,
+ options::AbstractOptions,
) where {T,L,N}
nout = length(datasets)
for j in 1:nout, i in 1:(options.populations)
@@ -864,7 +769,10 @@ function _warmup_search!(
return nothing
end
function _main_search_loop!(
- state::SearchState{T,L,N}, datasets, ropt::RuntimeOptions, options::Options
+ state::AbstractSearchState{T,L,N},
+ datasets,
+ ropt::AbstractRuntimeOptions,
+ options::AbstractOptions,
) where {T,L,N}
ropt.verbosity > 0 && @info "Started!"
nout = length(datasets)
@@ -1017,8 +925,9 @@ function _main_search_loop!(
)
end
+ total_cycles = ropt.niterations * options.populations
state.cur_maxsizes[j] = get_cur_maxsize(;
- options, ropt.total_cycles, cycles_remaining=state.cycles_remaining[j]
+ options, total_cycles, cycles_remaining=state.cycles_remaining[j]
)
move_window!(state.all_running_search_statistics[j])
if ropt.progress
@@ -1062,12 +971,13 @@ function _main_search_loop!(
# Dominating pareto curve - must be better than all simpler equations
head_node_occupation = estimate_work_fraction(resource_monitor)
+ total_cycles = ropt.niterations * options.populations
print_search_state(
state.halls_of_fame,
datasets;
options,
equation_speed,
- ropt.total_cycles,
+ total_cycles,
state.cycles_remaining,
head_node_occupation,
parallelism=ropt.parallelism,
@@ -1092,7 +1002,9 @@ function _main_search_loop!(
end
return nothing
end
-function _tear_down!(state::SearchState, ropt::RuntimeOptions, options::Options)
+function _tear_down!(
+ state::AbstractSearchState, ropt::AbstractRuntimeOptions, options::AbstractOptions
+)
close_reader!(state.stdin_reader)
# Safely close all processes or threads
if ropt.parallelism == :multiprocessing
@@ -1107,7 +1019,10 @@ function _tear_down!(state::SearchState, ropt::RuntimeOptions, options::Options)
return nothing
end
function _format_output(
- state::SearchState, datasets, ropt::RuntimeOptions, options::Options
+ state::AbstractSearchState,
+ datasets,
+ ropt::AbstractRuntimeOptions,
+ options::AbstractOptions,
)
nout = length(datasets)
out_hof = if ropt.dim_out == 1
@@ -1128,7 +1043,7 @@ end
@stable default_mode = "disable" function _dispatch_s_r_cycle(
in_pop::Population{T,L,N},
dataset::Dataset,
- options::Options;
+ options::AbstractOptions;
pop::Int,
out::Int,
iteration::Int,
@@ -1171,10 +1086,6 @@ end
include("MLJInterface.jl")
using .MLJInterfaceModule: SRRegressor, MultitargetSRRegressor
-function __init__()
- @require_extensions
-end
-
# Hack to get static analysis to work from within tests:
@ignore include("../test/runtests.jl")
diff --git a/src/Utils.jl b/src/Utils.jl
index a667b6987..473845651 100644
--- a/src/Utils.jl
+++ b/src/Utils.jl
@@ -187,11 +187,6 @@ function _save_kwargs(log_variable::Symbol, fdef::Expr)
end
end
-# Allows using `const` fields in older versions of Julia.
-macro constfield(ex)
- return esc(VERSION < v"1.8.0" ? ex : Expr(:const, ex))
-end
-
json3_write(args...) = error("Please load the JSON3.jl package.")
"""
diff --git a/src/deprecates.jl b/src/deprecates.jl
index 54816a408..c8e0b4d57 100644
--- a/src/deprecates.jl
+++ b/src/deprecates.jl
@@ -3,15 +3,6 @@ using Base: @deprecate
import .HallOfFameModule: calculate_pareto_frontier
import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size
-@deprecate(
- gen_random_tree(length::Int, options::Options, nfeatures::Int, t::Type),
- gen_random_tree(length, options, nfeatures, t)
-)
-@deprecate(
- gen_random_tree_fixed_size(node_count::Int, options::Options, nfeatures::Int, t::Type),
- gen_random_tree_fixed_size(node_count, options, nfeatures, t)
-)
-
@deprecate(
calculate_pareto_frontier(X, y, hallOfFame, options; weights=nothing, varMap=nothing),
calculate_pareto_frontier(hallOfFame)
@@ -40,7 +31,7 @@ import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size
niterations::Int=10,
weights::Union{AbstractMatrix{T},AbstractVector{T},Nothing}=nothing,
variable_names::Union{Vector{String},Nothing}=nothing,
- options::Options=Options(),
+ options::AbstractOptions=Options(),
parallelism=:multithreading,
numprocs::Union{Int,Nothing}=nothing,
procs::Union{Vector{Int},Nothing}=nothing,
@@ -75,7 +66,7 @@ import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size
EquationSearch(
datasets::Vector{D};
niterations::Int=10,
- options::Options=Options(),
+ options::AbstractOptions=Options(),
parallelism=:multithreading,
numprocs::Union{Int,Nothing}=nothing,
procs::Union{Vector{Int},Nothing}=nothing,
diff --git a/src/precompile.jl b/src/precompile.jl
index 87442695d..13aaac06f 100644
--- a/src/precompile.jl
+++ b/src/precompile.jl
@@ -30,8 +30,6 @@ macro maybe_compile_workload(mode, ex)
end
end
-const PRECOMPILE_OPTIMIZATION = VERSION >= v"1.9.0-DEV.0"
-
"""`mode=:precompile` will use `@precompile_*` directives; `mode=:compile` runs."""
function do_precompilation(::Val{mode}) where {mode}
@maybe_setup_workload mode begin
@@ -57,13 +55,12 @@ function do_precompilation(::Val{mode}) where {mode}
simplify=1.0,
randomize=1.0,
do_nothing=1.0,
- optimize=PRECOMPILE_OPTIMIZATION ? 1.0 : 0.0,
+ optimize=1.0,
),
fraction_replaced=0.2,
fraction_replaced_hof=0.2,
define_helper_functions=false,
- optimizer_probability=PRECOMPILE_OPTIMIZATION ? 0.05 : 0.0,
- should_optimize_constants=PRECOMPILE_OPTIMIZATION,
+ optimizer_probability=0.05,
save_to_file=false,
)
state = equation_search(
diff --git a/test/test_aqua.jl b/test/test_aqua.jl
index 6a4153f66..ef84cc03c 100644
--- a/test/test_aqua.jl
+++ b/test/test_aqua.jl
@@ -3,4 +3,4 @@ using Aqua
Aqua.test_all(SymbolicRegression; ambiguities=false)
-VERSION >= v"1.9" && Aqua.test_ambiguities(SymbolicRegression)
+Aqua.test_ambiguities(SymbolicRegression)
diff --git a/test/test_mlj.jl b/test/test_mlj.jl
index 3ef8d7b93..d26773485 100644
--- a/test/test_mlj.jl
+++ b/test/test_mlj.jl
@@ -87,10 +87,8 @@ end
@test ypred_mixed == hcat(ypred_good[:, 1], ypred_bad[:, 2], ypred_good[:, 3])
@test_throws AssertionError predict(mach, (data=X,))
- VERSION >= v"1.8" &&
- @test_throws "If specifying an equation index during" predict(mach, (data=X,))
- VERSION >= v"1.8" &&
- @test_throws "If specifying an equation index during" predict(mach, (X=X, idx=1))
+ @test_throws "If specifying an equation index during" predict(mach, (data=X,))
+ @test_throws "If specifying an equation index during" predict(mach, (X=X, idx=1))
end
@testitem "Variable names - named outputs" tags = [:part1] begin
@@ -112,7 +110,7 @@ end
test_outs = predict(mach, X)
@test isempty(setdiff((:c1, :c2), keys(test_outs)))
@test_throws AssertionError predict(mach, (a1=randn(32), b2=randn(32)))
- VERSION >= v"1.8" && @test_throws "Variable names do not match fitted" predict(
+ @test_throws "Variable names do not match fitted" predict(
mach, (b1=randn(32), a2=randn(32))
)
end
@@ -146,15 +144,13 @@ end
rng = MersenneTwister(0)
mach = machine(model, randn(rng, 32, 3), randn(rng, 32); scitype_check_level=0)
@test_throws AssertionError @quiet(fit!(mach))
- VERSION >= v"1.8" &&
- @test_throws "For single-output regression, please" @quiet(fit!(mach))
+ @test_throws "For single-output regression, please" @quiet(fit!(mach))
model = SRRegressor()
rng = MersenneTwister(0)
mach = machine(model, randn(rng, 32, 3), randn(rng, 32, 2); scitype_check_level=0)
@test_throws AssertionError @quiet(fit!(mach))
- VERSION >= v"1.8" &&
- @test_throws "For multi-output regression, please" @quiet(fit!(mach))
+ @test_throws "For multi-output regression, please" @quiet(fit!(mach))
model = SRRegressor(; verbosity=0)
rng = MersenneTwister(0)
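All of the `VERSION >= v"1.8"` guards removed in this file protected one feature: `@test_throws` learned to match a substring of the thrown error's message in Julia 1.8. With the floor at 1.10, the guards are dead code. A self-contained illustration (the error text and `throw_hint` helper are illustrative, not from this package):

```julia
using Test

throw_hint() = error("If specifying an equation index during prediction, pass a NamedTuple")

# Matching the exception type works on every Julia version:
@test_throws ErrorException throw_hint()
# Matching a substring of the message requires Julia >= 1.8, hence the old guards:
@test_throws "If specifying an equation index during" throw_hint()
```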
diff --git a/test/test_operators.jl b/test/test_operators.jl
index 47b83418a..1221ba6c7 100644
--- a/test/test_operators.jl
+++ b/test/test_operators.jl
@@ -99,10 +99,9 @@ end
@test_throws ErrorException SymbolicRegression.assert_operators_well_defined(
ComplexF64, options
)
- VERSION >= v"1.8" &&
- @test_throws "complex plane" SymbolicRegression.assert_operators_well_defined(
- ComplexF64, options
- )
+ @test_throws "complex plane" SymbolicRegression.assert_operators_well_defined(
+ ComplexF64, options
+ )
end
@testset "Operators which return the wrong type should fail" begin
@@ -111,10 +110,9 @@ end
@test_throws ErrorException SymbolicRegression.assert_operators_well_defined(
Float64, options
)
- VERSION >= v"1.8" &&
- @test_throws "returned an output of type" SymbolicRegression.assert_operators_well_defined(
- Float64, options
- )
+ @test_throws "returned an output of type" SymbolicRegression.assert_operators_well_defined(
+ Float64, options
+ )
@test_nowarn SymbolicRegression.assert_operators_well_defined(Float32, options)
end
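Same pattern as in `test_mlj.jl`: these guards existed only for string-matching `@test_throws`. For context, a hedged sketch of the failure mode the "complex plane" assertion reports (`real_only_log` is a hypothetical stand-in, not an operator from this package):

```julia
# An operator defined only for reals cannot be evaluated on ComplexF64 inputs,
# which is the kind of problem assert_operators_well_defined surfaces by
# test-evaluating each operator on the requested element type.
real_only_log(x::Real) = x > 0 ? log(x) : zero(x)

real_only_log(2.0)          # fine
# real_only_log(2.0 + 0im)  # MethodError: no method for Complex arguments
```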
diff --git a/test/test_print.jl b/test/test_print.jl
index fff027504..96500fa0c 100644
--- a/test/test_print.jl
+++ b/test/test_print.jl
@@ -1,13 +1,19 @@
using SymbolicRegression
using SymbolicRegression.UtilsModule: split_string
+using DynamicExpressions: DynamicExpressions as DE
include("test_params.jl")
## Test Base.print
options = Options(;
- default_params..., binary_operators=(+, *, /, -), unary_operators=(cos, sin)
+ default_params...,
+ binary_operators=(+, *, /, -),
+ unary_operators=(cos, sin),
+ populations=1,
)
+@test DE.OperatorEnumConstructionModule.LATEST_OPERATORS[].unaops == (cos, sin)
+
f = (x1, x2, x3) -> (sin(cos(sin(cos(x1) * x3) * 3.0) * -0.5) + 2.0) * 5.0
tree = f(Node("x1"), Node("x2"), Node("x3"))
@@ -26,6 +32,8 @@ equation_search(
parallelism=:multithreading,
)
+@test DE.OperatorEnumConstructionModule.LATEST_OPERATORS[].unaops == (cos, sin)
+
s = repr(tree)
true_s = "(sin(cos(sin(cos(v1) * v3) * 3.0) * -0.5) + 2.0) * 5.0"
@test s == true_s
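The two added `LATEST_OPERATORS` assertions bracket the search to check that neither `Options` construction nor `equation_search` clobbers the operator set that DynamicExpressions last registered, since bare-`Node` printing (the `repr(tree)` above) resolves operator names through that global record. A hedged sketch of the mechanism, using only public DynamicExpressions constructors:

```julia
using DynamicExpressions

# Constructing an OperatorEnum records it as the "latest" operator set and
# defines the convenience overloads used on the next two lines.
operators = OperatorEnum(; binary_operators=(+, *), unary_operators=(cos,))
x1 = Node{Float64}(; feature=1)
tree = cos(x1 * 2.0)

# Printing a bare Node falls back to the latest registered operators; if an
# intervening Options(...) had replaced them, the rendered names could differ.
repr(tree)  # expected: "cos(x1 * 2.0)"
```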
diff --git a/test/test_units.jl b/test/test_units.jl
index e35e0a185..a586f5e3c 100644
--- a/test/test_units.jl
+++ b/test/test_units.jl
@@ -311,8 +311,7 @@ end
# TODO: Should return same quantity as input
@test typeof(ypred.a[begin]) <: Quantity
@test typeof(y.a[begin]) <: RealQuantity
- VERSION >= v"1.8" &&
- @eval @test(typeof(ypred.b[begin]) == typeof(y.b[begin]), broken = true)
+ @eval @test(typeof(ypred.b[begin]) == typeof(y.b[begin]), broken = true)
end
end
@@ -322,8 +321,7 @@ end
X = randn(11, 50)
y = randn(50)
- VERSION >= v"1.8.0" &&
- @test_throws("Number of features", Dataset(X, y; X_units=["m", "1"], y_units="kg"))
+ @test_throws("Number of features", Dataset(X, y; X_units=["m", "1"], y_units="kg"))
end
@testitem "Should print units" tags = [:part3] begin