diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 5deca92df..900a7b375 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -30,8 +30,7 @@ jobs: - "part2" - "part3" julia-version: - - "1.6" - - "1.8" + - "1.10" - "1" os: - ubuntu-latest @@ -54,15 +53,6 @@ jobs: - os: macOS-latest julia-version: "1" test: "part3" - - os: ubuntu-latest - julia-version: "~1.11.0-0" - test: "part1" - - os: ubuntu-latest - julia-version: "~1.11.0-0" - test: "part2" - - os: ubuntu-latest - julia-version: "~1.11.0-0" - test: "part3" steps: - uses: actions/checkout@v4 diff --git a/Project.toml b/Project.toml index 6f899cb60..9f0bf2341 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,6 @@ LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Optim = "429524aa-4258-5aef-a3af-852621145aeb" -PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -56,7 +55,6 @@ LossFunctions = "0.10, 0.11" MLJModelInterface = "~1.5, ~1.6, ~1.7, ~1.8, ~1.9, ~1.10, ~1.11" MacroTools = "0.4, 0.5" Optim = "~1.8, ~1.9" -PackageExtensionCompat = "1" Pkg = "<0.0.1, 1" PrecompileTools = "1" Printf = "<0.0.1, 1" @@ -67,7 +65,7 @@ SpecialFunctions = "0.10.1, 1, 2" StatsBase = "0.33, 0.34" SymbolicUtils = "0.19, ^1.0.5, 2, 3" TOML = "<0.0.1, 1" -julia = "1.6" +julia = "1.10" [extras] Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 7ec312618..9e4862348 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -1,6 +1,8 @@ using BenchmarkTools using SymbolicRegression, BenchmarkTools, Random using SymbolicRegression.AdaptiveParsimonyModule: RunningSearchStatistics +using SymbolicRegression.MutateModule: next_generation +using SymbolicRegression.RecorderModule: RecordType using SymbolicRegression.PopulationModule: best_of_sample using SymbolicRegression.ConstantOptimizationModule: optimize_constants using SymbolicRegression.CheckConstraintsModule: check_constraints @@ -93,6 +95,57 @@ function create_utils_benchmark() ) ) + suite["next_generation_x100"] = @benchmarkable( + let + for member in members + next_generation( + dataset, + member, + temperature, + curmaxsize, + rss, + options; + tmp_recorder=recorder, + ) + end + end, + setup = ( + nfeatures = 1; + dataset = Dataset(randn(nfeatures, 32), randn(32)); + mutation_weights = MutationWeights(; + mutate_constant=1.0, + mutate_operator=1.0, + swap_operands=1.0, + rotate_tree=1.0, + add_node=1.0, + insert_node=1.0, + simplify=0.0, + randomize=0.0, + do_nothing=0.0, + form_connection=0.0, + break_connection=0.0, + ); + options = Options(; + unary_operators=[sin, cos], binary_operators=[+, -, *, /], mutation_weights + ); + recorder = RecordType(); + temperature = 1.0; + curmaxsize = 20; + rss = RunningSearchStatistics(; options); + trees = [ + gen_random_tree_fixed_size(15, options, nfeatures, Float64) for _ in 1:100 + ]; + expressions = [ + Expression(tree; operators=options.operators, variable_names=["x1"]) for + tree in trees + ]; + members = [ + PopMember(dataset, expression, options; deterministic=false) for + expression in expressions + ] + ) + ) + ntrees = 10 suite["optimize_constants_x10"] = @benchmarkable( foreach(members) do member diff --git a/docs/make.jl b/docs/make.jl index 
f3ad5f756..b0546821b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,7 +1,21 @@ using Documenter using SymbolicUtils using SymbolicRegression -using SymbolicRegression: Dataset, update_baseline_loss! +using SymbolicRegression: + AbstractExpression, + ExpressionInterface, + Dataset, + update_baseline_loss!, + AbstractMutationWeights, + AbstractOptions, + mutate!, + condition_mutation_weights!, + sample_mutation, + MutationResult, + AbstractRuntimeOptions, + AbstractSearchState, + @extend_operators +using DynamicExpressions DocMeta.setdocmeta!( SymbolicRegression, :DocTestSetup, :(using LossFunctions); recursive=true ) @@ -40,14 +54,8 @@ readme = replace( # We prepend the `<div align="center">` with a ```@raw html # and append the `</div>` with a ```: -readme = replace( - readme, - r"<div align=\"center\">" => s"```@raw html\n<div align=\"center\">", -) -readme = replace( - readme, - r"</div>" => s"</div>\n```", -) +readme = replace(readme, r"<div align=\"center\">" => s"```@raw html\n<div align=\"center\">") +readme = replace(readme, r"</div>" => s"</div>\n```") # Then, we surround ```mermaid\n...\n``` snippets # with ```@raw html\n<div class="mermaid">\n...\n</div>
```: @@ -96,6 +104,7 @@ makedocs(; "API" => "api.md", "Losses" => "losses.md", "Types" => "types.md", + "Customization" => "customization.md", ], ) diff --git a/docs/src/api.md b/docs/src/api.md index d9ac1fa97..b7e0f7e28 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -7,7 +7,7 @@ SRRegressor MultitargetSRRegressor ``` -## equation_search +## Low-Level API ```@docs equation_search diff --git a/docs/src/customization.md b/docs/src/customization.md new file mode 100644 index 000000000..2a9d9a072 --- /dev/null +++ b/docs/src/customization.md @@ -0,0 +1,61 @@ +# Customization + +Many parts of SymbolicRegression.jl are designed to be customizable. + +The normal way to do this in Julia is to define a new type that subtypes +an abstract type from the package, and then extend the package's internal +methods with new methods specialized to your type. + +## Custom Options + +For example, you can define a custom options type: + +```@docs +AbstractOptions +``` + +For most functions in SymbolicRegression.jl, you can then define a new method +on your custom options type to implement custom behavior. + +## Custom Mutations + +You can define custom mutation operators by defining a new method on +`mutate!`, as well as subtyping `AbstractMutationWeights`: + +```@docs +mutate! +AbstractMutationWeights +condition_mutation_weights! +sample_mutation +MutationResult +``` + +## Custom Expressions + +You can create your own expression types by defining a new type that extends `AbstractExpression`. + +```@docs +AbstractExpression +ExpressionInterface +``` + +The interface is fairly flexible, and permits you to define specific functional forms, +extra parameters, etc. See the documentation of DynamicExpressions.jl for more details on what +methods you need to implement. Then, for SymbolicRegression.jl, you would +pass `expression_type` to the `Options` constructor, as well as any +`expression_options` you need (as a `NamedTuple`). + +You may also need to overload `SymbolicRegression.ExpressionBuilder.extra_init_params` in +case your expression needs additional parameters. See the method for `ParametricExpression` +as an example. + +## Other Customizations + +Other internal abstract types include the following: + +```@docs +AbstractRuntimeOptions +AbstractSearchState +``` + +These let you include custom state variables and runtime options. diff --git a/docs/src/types.md b/docs/src/types.md index 5c7277c57..92bf5632e 100644 --- a/docs/src/types.md +++ b/docs/src/types.md @@ -60,24 +60,6 @@ ParametricNode These types allow you to define expressions with parameters that can be tuned to fit the data better. You can specify the maximum number of parameters using the `expression_options` argument in `SRRegressor`. -## Custom Expressions - -You can create your own expression types by defining a new type that extends `AbstractExpression`. - -```@docs -AbstractExpression -``` - -The interface is fairly flexible, and permits you define specific functional forms, -extra parameters, etc. See the documentation of DynamicExpressions.jl for more details on what -methods you need to implement. Then, for SymbolicRegression.jl, you would -pass `expression_type` to the `Options` constructor, as well as any -`expression_options` you need (as a `NamedTuple`). - -If needed, you may need to overload `SymbolicRegression.ExpressionBuilder.extra_init_params` in -case your expression needs additional parameters. See the method for `ParametricExpression` -as an example.
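To make the new `customization.md` page concrete, here is a minimal sketch of the custom-mutation pattern it describes. This is not part of the PR: `MyMutationWeights` and `:duplicate_tree` are hypothetical names, and the sketch assumes `sample_mutation` draws from the fields of the weights type, as the docs above suggest.

```julia
using SymbolicRegression
using SymbolicRegression:
    AbstractMutationWeights, AbstractOptions, MutationResult, mutate!, PopMember
using DynamicExpressions: AbstractExpression

# Hypothetical weights type: two standard mutations plus one custom one.
# `sample_mutation` draws a field name in proportion to these values.
Base.@kwdef struct MyMutationWeights <: AbstractMutationWeights
    mutate_constant::Float64 = 0.5
    mutate_operator::Float64 = 0.5
    duplicate_tree::Float64 = 0.1  # hypothetical custom mutation
end

# Overload `mutate!` on the matching `Val` tag; this method is selected
# whenever `:duplicate_tree` is sampled from the weights.
function SymbolicRegression.mutate!(
    tree::N,
    member::P,
    ::Val{:duplicate_tree},
    weights::MyMutationWeights,
    options::AbstractOptions;
    kws...,
) where {N<:AbstractExpression,P<:PopMember}
    new_tree = copy(tree)  # a real mutation would rewrite `new_tree` here
    return MutationResult{N,P}(; tree=new_tree)
end
```

The mutated tree is handed back through `MutationResult`, so `next_generation` still applies the usual constraint checks and accept/reject logic.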
- ## Population Groups of equations are given as a population, which is diff --git a/src/AdaptiveParsimony.jl b/src/AdaptiveParsimony.jl index b438aef4b..aa33fa613 100644 --- a/src/AdaptiveParsimony.jl +++ b/src/AdaptiveParsimony.jl @@ -1,6 +1,6 @@ module AdaptiveParsimonyModule -using ..CoreModule: Options, MAX_DEGREE +using ..CoreModule: AbstractOptions, MAX_DEGREE """ RunningSearchStatistics @@ -23,7 +23,7 @@ struct RunningSearchStatistics normalized_frequencies::Vector{Float64} # Stores `frequencies`, but normalized (updated once in a while) end -function RunningSearchStatistics(; options::Options, window_size::Int=100000) +function RunningSearchStatistics(; options::AbstractOptions, window_size::Int=100000) maxsize = options.maxsize actualMaxsize = maxsize + MAX_DEGREE init_frequencies = ones(Float64, actualMaxsize) diff --git a/src/CheckConstraints.jl b/src/CheckConstraints.jl index 7f6093631..df748c218 100644 --- a/src/CheckConstraints.jl +++ b/src/CheckConstraints.jl @@ -2,12 +2,12 @@ module CheckConstraintsModule using DynamicExpressions: AbstractExpressionNode, AbstractExpression, get_tree, count_depth, tree_mapreduce -using ..CoreModule: Options +using ..CoreModule: AbstractOptions using ..ComplexityModule: compute_complexity, past_complexity_limit # Check if any binary operator are overly complex function flag_bin_operator_complexity( - tree::AbstractExpressionNode, op, cons, options::Options + tree::AbstractExpressionNode, op, cons, options::AbstractOptions )::Bool any(tree) do subtree if subtree.degree == 2 && subtree.op == op @@ -27,7 +27,7 @@ Check if any unary operators are overly complex. This assumes you have already checked whether the constraint is > -1. """ function flag_una_operator_complexity( - tree::AbstractExpressionNode, op, cons, options::Options + tree::AbstractExpressionNode, op, cons, options::AbstractOptions )::Bool any(tree) do subtree if subtree.degree == 1 && tree.op == op @@ -52,7 +52,7 @@ function count_max_nestedness(tree, degree, op) end """Check if there are any illegal combinations of operators""" -function flag_illegal_nests(tree::AbstractExpressionNode, options::Options)::Bool +function flag_illegal_nests(tree::AbstractExpressionNode, options::AbstractOptions)::Bool # We search from the top first, then from child nodes at end. 
(nested_constraints = options.nested_constraints) === nothing && return false for (degree, op_idx, op_constraint) in nested_constraints @@ -72,7 +72,7 @@ end """Check if user-passed constraints are violated or not""" function check_constraints( ex::AbstractExpression, - options::Options, + options::AbstractOptions, maxsize::Int, cursize::Union{Int,Nothing}=nothing, )::Bool @@ -81,7 +81,7 @@ function check_constraints( end function check_constraints( tree::AbstractExpressionNode, - options::Options, + options::AbstractOptions, maxsize::Int, cursize::Union{Int,Nothing}=nothing, )::Bool @@ -103,7 +103,7 @@ function check_constraints( end check_constraints( - ex::Union{AbstractExpression,AbstractExpressionNode}, options::Options + ex::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions )::Bool = check_constraints(ex, options, options.maxsize) end diff --git a/src/Complexity.jl b/src/Complexity.jl index dccb05bd3..dec8fb63e 100644 --- a/src/Complexity.jl +++ b/src/Complexity.jl @@ -2,10 +2,10 @@ module ComplexityModule using DynamicExpressions: AbstractExpression, AbstractExpressionNode, get_tree, count_nodes, tree_mapreduce -using ..CoreModule: Options, ComplexityMapping +using ..CoreModule: AbstractOptions, ComplexityMapping function past_complexity_limit( - tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options, limit + tree::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions, limit )::Bool return compute_complexity(tree, options) > limit end @@ -18,12 +18,12 @@ However, it could use the custom settings in options.complexity_mapping if these are defined. """ function compute_complexity( - tree::AbstractExpression, options::Options; break_sharing=Val(false) + tree::AbstractExpression, options::AbstractOptions; break_sharing=Val(false) ) return compute_complexity(get_tree(tree), options; break_sharing) end function compute_complexity( - tree::AbstractExpressionNode, options::Options; break_sharing=Val(false) + tree::AbstractExpressionNode, options::AbstractOptions; break_sharing=Val(false) )::Int if options.complexity_mapping.use raw_complexity = _compute_complexity( diff --git a/src/Configure.jl b/src/Configure.jl index 5ccc08100..eefd63619 100644 --- a/src/Configure.jl +++ b/src/Configure.jl @@ -29,7 +29,7 @@ end const TEST_INPUTS = collect(range(-100, 100; length=99)) -function assert_operators_well_defined(T, options::Options) +function assert_operators_well_defined(T, options::AbstractOptions) test_input = if T <: Complex (x -> convert(T, x)).(TEST_INPUTS .+ TEST_INPUTS .* im) else @@ -45,7 +45,7 @@ end # Check for errors before they happen function test_option_configuration( - parallelism, datasets::Vector{D}, options::Options, verbosity + parallelism, datasets::Vector{D}, options::AbstractOptions, verbosity ) where {T,D<:Dataset{T}} if options.deterministic && parallelism != :serial error("Determinism is only guaranteed for serial mode.") @@ -84,7 +84,7 @@ end # Check for errors before they happen function test_dataset_configuration( - dataset::Dataset{T}, options::Options, verbosity + dataset::Dataset{T}, options::AbstractOptions, verbosity ) where {T<:DATA_TYPE} n = dataset.n if n != size(dataset.X, 2) || @@ -113,7 +113,7 @@ end """ Move custom operators and loss functions to workers, if undefined """ function move_functions_to_workers( - procs, options::Options, dataset::Dataset{T}, verbosity + procs, options::AbstractOptions, dataset::Dataset{T}, verbosity ) where {T} # All the types of functions we need to move to 
workers: function_sets = ( @@ -168,7 +168,7 @@ function move_functions_to_workers( end end -function copy_definition_to_workers(op, procs, options::Options, verbosity) +function copy_definition_to_workers(op, procs, options::AbstractOptions, verbosity) name = nameof(op) verbosity > 0 && @info "Copying definition of $op to workers..." src_ms = methods(op).ms @@ -191,7 +191,9 @@ function test_function_on_workers(example_inputs, op, procs) end end -function activate_env_on_workers(procs, project_path::String, options::Options, verbosity) +function activate_env_on_workers( + procs, project_path::String, options::AbstractOptions, verbosity +) verbosity > 0 && @info "Activating environment on workers." @everywhere procs begin Base.MainInclude.eval( @@ -203,7 +205,9 @@ function activate_env_on_workers(procs, project_path::String, options::Options, end end -function import_module_on_workers(procs, filename::String, options::Options, verbosity) +function import_module_on_workers( + procs, filename::String, options::AbstractOptions, verbosity +) loaded_modules_head_worker = [k.name for (k, _) in Base.loaded_modules] included_as_local = "SymbolicRegression" ∉ loaded_modules_head_worker @@ -251,7 +255,7 @@ function import_module_on_workers(procs, filename::String, options::Options, ver return nothing end -function test_module_on_workers(procs, options::Options, verbosity) +function test_module_on_workers(procs, options::AbstractOptions, verbosity) verbosity > 0 && @info "Testing module on workers..." futures = [] for proc in procs @@ -268,7 +272,7 @@ function test_module_on_workers(procs, options::Options, verbosity) end function test_entire_pipeline( - procs, dataset::Dataset{T}, options::Options, verbosity + procs, dataset::Dataset{T}, options::AbstractOptions, verbosity ) where {T<:DATA_TYPE} futures = [] verbosity > 0 && @info "Testing entire pipeline on workers..." 
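For readers of this Configure.jl change: these helpers automate what one would otherwise do by hand with Distributed.jl, namely defining custom operators and losses on every worker before the search calls them there. A standalone sketch of the manual version (not part of the PR; `myloss` is a made-up name):

```julia
using Distributed

procs = addprocs(2)

# Custom operators and loss functions must exist on every worker
# before a distributed search can call them:
@everywhere procs begin
    myloss(x) = abs2(x)
end

fetch(@spawnat procs[1] myloss(3.0))  # returns 9.0
rmprocs(procs)
```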
@@ -310,7 +314,7 @@ function configure_workers(; procs::Union{Vector{Int},Nothing}, numprocs::Int, addprocs_function::Function, - options::Options, + options::AbstractOptions, project_path, file, exeflags::Cmd, @@ -325,10 +329,6 @@ function configure_workers(; end if we_created_procs - if VERSION < v"1.9.0" - # On newer Julia; environment is activated automatically - activate_env_on_workers(procs, project_path, options, verbosity) - end import_module_on_workers(procs, file, options, verbosity) end diff --git a/src/ConstantOptimization.jl b/src/ConstantOptimization.jl index fe66b4f5d..92b5d0c5d 100644 --- a/src/ConstantOptimization.jl +++ b/src/ConstantOptimization.jl @@ -11,13 +11,13 @@ using DynamicExpressions: get_scalar_constants, set_scalar_constants!, extract_gradient -using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options +using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options using ..UtilsModule: get_birth_order using ..LossFunctionsModule: eval_loss, loss_to_score, batch_sample using ..PopMemberModule: PopMember function optimize_constants( - dataset::Dataset{T,L}, member::P, options::Options + dataset::Dataset{T,L}, member::P, options::AbstractOptions )::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}} if options.batching dispatch_optimize_constants( @@ -28,7 +28,7 @@ function optimize_constants( end end function dispatch_optimize_constants( - dataset::Dataset{T,L}, member::P, options::Options, idx + dataset::Dataset{T,L}, member::P, options::AbstractOptions, idx ) where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}} nconst = count_constants_for_optimization(member.tree) nconst == 0 && return (member, 0.0) @@ -103,7 +103,7 @@ function _optimize_constants( return member, num_evals end -struct Evaluator{N<:AbstractExpression,R,D<:Dataset,O<:Options,I} <: Function +struct Evaluator{N<:AbstractExpression,R,D<:Dataset,O<:AbstractOptions,I} <: Function tree::N refs::R dataset::D diff --git a/src/Core.jl b/src/Core.jl index 8f917c906..0a16b52d0 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -13,9 +13,8 @@ include("Options.jl") using .ProgramConstantsModule: MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE using .DatasetModule: Dataset -using .MutationWeightsModule: MutationWeights, sample_mutation -using .OptionsStructModule: Options, ComplexityMapping, specialized_options -using .OptionsModule: Options +using .MutationWeightsModule: AbstractMutationWeights, MutationWeights, sample_mutation +using .OptionsStructModule: AbstractOptions, Options, ComplexityMapping, specialized_options using .OperatorsModule: plus, sub, diff --git a/src/Dataset.jl b/src/Dataset.jl index 99c31ee3d..c8cb9767a 100644 --- a/src/Dataset.jl +++ b/src/Dataset.jl @@ -2,7 +2,7 @@ module DatasetModule using DynamicQuantities: Quantity -using ..UtilsModule: subscriptify, get_base_type, @constfield +using ..UtilsModule: subscriptify, get_base_type using ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE using ...InterfaceDynamicQuantitiesModule: get_si_units, get_sym_units @@ -57,24 +57,24 @@ mutable struct Dataset{ XUS<:Union{AbstractVector{<:Quantity},Nothing}, YUS<:Union{Quantity,Nothing}, } - @constfield X::AX - @constfield y::AY - @constfield index::Int - @constfield n::Int - @constfield nfeatures::Int - @constfield weighted::Bool - @constfield weights::AW - @constfield extra::NT - @constfield avg_y::Union{T,Nothing} + const X::AX + const y::AY + const index::Int + const n::Int + const nfeatures::Int 
+ const weighted::Bool + const weights::AW + const extra::NT + const avg_y::Union{T,Nothing} use_baseline::Bool baseline_loss::L - @constfield variable_names::Array{String,1} - @constfield display_variable_names::Array{String,1} - @constfield y_variable_name::String - @constfield X_units::XU - @constfield y_units::YU - @constfield X_sym_units::XUS - @constfield y_sym_units::YUS + const variable_names::Array{String,1} + const display_variable_names::Array{String,1} + const y_variable_name::String + const X_units::XU + const y_units::YU + const X_sym_units::XUS + const y_sym_units::YUS end """ diff --git a/src/DimensionalAnalysis.jl b/src/DimensionalAnalysis.jl index cc9440db1..d469e6848 100644 --- a/src/DimensionalAnalysis.jl +++ b/src/DimensionalAnalysis.jl @@ -3,7 +3,7 @@ module DimensionalAnalysisModule using DynamicExpressions: AbstractExpression, AbstractExpressionNode, get_tree using DynamicQuantities: Quantity, DimensionError, AbstractQuantity, constructorof -using ..CoreModule: Options, Dataset +using ..CoreModule: AbstractOptions, Dataset using ..UtilsModule: safe_call import DynamicQuantities: dimension, ustrip @@ -180,12 +180,12 @@ function violates_dimensional_constraints_dispatch( end """ - violates_dimensional_constraints(tree::AbstractExpressionNode, dataset::Dataset, options::Options) + violates_dimensional_constraints(tree::AbstractExpressionNode, dataset::Dataset, options::AbstractOptions) Checks whether an expression violates dimensional constraints. """ function violates_dimensional_constraints( - tree::AbstractExpressionNode, dataset::Dataset, options::Options + tree::AbstractExpressionNode, dataset::Dataset, options::AbstractOptions ) X = dataset.X return violates_dimensional_constraints( @@ -193,7 +193,7 @@ function violates_dimensional_constraints( ) end function violates_dimensional_constraints( - tree::AbstractExpression, dataset::Dataset, options::Options + tree::AbstractExpression, dataset::Dataset, options::AbstractOptions ) return violates_dimensional_constraints(get_tree(tree), dataset, options) end @@ -202,7 +202,7 @@ function violates_dimensional_constraints( X_units::AbstractVector{<:Quantity}, y_units::Union{Quantity,Nothing}, x::AbstractVector{T}, - options::Options, + options::AbstractOptions, ) where {T} allow_wildcards = !(options.dimensionless_constants_only) dimensional_output = violates_dimensional_constraints_dispatch( @@ -220,12 +220,20 @@ function violates_dimensional_constraints( return violates end function violates_dimensional_constraints( - ::AbstractExpressionNode{T}, ::Nothing, ::Quantity, ::AbstractVector{T}, ::Options + ::AbstractExpressionNode{T}, + ::Nothing, + ::Quantity, + ::AbstractVector{T}, + ::AbstractOptions, ) where {T} return error("This should never happen. 
Please submit a bug report.") end function violates_dimensional_constraints( - ::AbstractExpressionNode{T}, ::Nothing, ::Nothing, ::AbstractVector{T}, ::Options + ::AbstractExpressionNode{T}, + ::Nothing, + ::Nothing, + ::AbstractVector{T}, + ::AbstractOptions, ) where {T} return false end diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index a54d97bdd..12b20a06c 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -17,7 +17,7 @@ using DynamicExpressions: eval_tree_array using Random: default_rng, AbstractRNG using StatsBase: StatsBase -using ..CoreModule: Options, Dataset, DATA_TYPE +using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE using ..HallOfFameModule: HallOfFame using ..LossFunctionsModule: maybe_getindex using ..InterfaceDynamicExpressionsModule: expected_array_type @@ -32,7 +32,7 @@ import ..LossFunctionsModule: eval_tree_dispatch import ..ConstantOptimizationModule: count_constants_for_optimization @unstable function create_expression( - t::T, options::Options, dataset::Dataset{T,L}, ::Val{embed}=Val(false) + t::T, options::AbstractOptions, dataset::Dataset{T,L}, ::Val{embed}=Val(false) ) where {T,L,embed} return create_expression( constructorof(options.node_type)(; val=t), options, dataset, Val(embed) @@ -40,7 +40,7 @@ import ..ConstantOptimizationModule: count_constants_for_optimization end @unstable function create_expression( t::AbstractExpressionNode{T}, - options::Options, + options::AbstractOptions, dataset::Dataset{T,L}, ::Val{embed}=Val(false), ) where {T,L,embed} @@ -49,12 +49,12 @@ end ) end function create_expression( - ex::AbstractExpression{T}, ::Options, ::Dataset{T,L}, ::Val{embed}=Val(false) + ex::AbstractExpression{T}, ::AbstractOptions, ::Dataset{T,L}, ::Val{embed}=Val(false) ) where {T,L,embed} return ex end @unstable function init_params( - options::Options, + options::AbstractOptions, dataset::Dataset{T,L}, prototype::Union{Nothing,AbstractExpression}, ::Val{embed}, @@ -71,7 +71,7 @@ end function extra_init_params( ::Type{E}, prototype::Union{Nothing,AbstractExpression}, - options::Options, + options::AbstractOptions, dataset::Dataset{T}, ::Val{embed}, ) where {T,embed,E<:AbstractExpression} @@ -80,7 +80,7 @@ end function extra_init_params( ::Type{E}, prototype::Union{Nothing,ParametricExpression}, - options::Options, + options::AbstractOptions, dataset::Dataset{T}, ::Val{embed}, ) where {T,embed,E<:ParametricExpression} @@ -95,8 +95,8 @@ function extra_init_params( return (; parameters=_parameters, parameter_names) end -consistency_checks(::Options, prototype::Nothing) = nothing -function consistency_checks(options::Options, prototype) +consistency_checks(::AbstractOptions, prototype::Nothing) = nothing +function consistency_checks(options::AbstractOptions, prototype) if prototype === nothing return nothing end @@ -120,12 +120,12 @@ end @unstable begin function embed_metadata( - ex::AbstractExpression, options::Options, dataset::Dataset{T,L} + ex::AbstractExpression, options::AbstractOptions, dataset::Dataset{T,L} ) where {T,L} return with_metadata(ex; init_params(options, dataset, ex, Val(true))...) 
end function embed_metadata( - member::PopMember, options::Options, dataset::Dataset{T,L} + member::PopMember, options::AbstractOptions, dataset::Dataset{T,L} ) where {T,L} return PopMember( embed_metadata(member.tree, options, dataset), @@ -138,37 +138,39 @@ end ) end function embed_metadata( - pop::Population, options::Options, dataset::Dataset{T,L} + pop::Population, options::AbstractOptions, dataset::Dataset{T,L} ) where {T,L} return Population( map(member -> embed_metadata(member, options, dataset), pop.members) ) end function embed_metadata( - hof::HallOfFame, options::Options, dataset::Dataset{T,L} + hof::HallOfFame, options::AbstractOptions, dataset::Dataset{T,L} ) where {T,L} return HallOfFame( map(member -> embed_metadata(member, options, dataset), hof.members), hof.exists ) end function embed_metadata( - vec::Vector{H}, options::Options, dataset::Dataset{T,L} + vec::Vector{H}, options::AbstractOptions, dataset::Dataset{T,L} ) where {T,L,H<:Union{HallOfFame,Population,PopMember}} return map(elem -> embed_metadata(elem, options, dataset), vec) end end """Strips all metadata except for top-level information""" -function strip_metadata(ex::Expression, options::Options, dataset::Dataset{T,L}) where {T,L} +function strip_metadata( + ex::Expression, options::AbstractOptions, dataset::Dataset{T,L} +) where {T,L} return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) end function strip_metadata( - ex::ParametricExpression, options::Options, dataset::Dataset{T,L} + ex::ParametricExpression, options::AbstractOptions, dataset::Dataset{T,L} ) where {T,L} return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) end function strip_metadata( - member::PopMember, options::Options, dataset::Dataset{T,L} + member::PopMember, options::AbstractOptions, dataset::Dataset{T,L} ) where {T,L} return PopMember( strip_metadata(member.tree, options, dataset), @@ -181,12 +183,12 @@ function strip_metadata( ) end function strip_metadata( - pop::Population, options::Options, dataset::Dataset{T,L} + pop::Population, options::AbstractOptions, dataset::Dataset{T,L} ) where {T,L} return Population(map(member -> strip_metadata(member, options, dataset), pop.members)) end function strip_metadata( - hof::HallOfFame, options::Options, dataset::Dataset{T,L} + hof::HallOfFame, options::AbstractOptions, dataset::Dataset{T,L} ) where {T,L} return HallOfFame( map(member -> strip_metadata(member, options, dataset), hof.members), hof.exists @@ -194,7 +196,7 @@ function strip_metadata( end function eval_tree_dispatch( - tree::ParametricExpression{T}, dataset::Dataset{T}, options::Options, idx + tree::ParametricExpression{T}, dataset::Dataset{T}, options::AbstractOptions, idx ) where {T<:DATA_TYPE} A = expected_array_type(dataset.X) return eval_tree_array( @@ -210,7 +212,7 @@ function make_random_leaf( ::Type{T}, ::Type{N}, rng::AbstractRNG=default_rng(), - options::Union{Options,Nothing}=nothing, + options::Union{AbstractOptions,Nothing}=nothing, ) where {T<:DATA_TYPE,N<:ParametricNode} choice = rand(rng, 1:3) if choice == 1 @@ -263,7 +265,7 @@ end function mutate_constant( ex::ParametricExpression{T}, temperature, - options::Options, + options::AbstractOptions, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} if rand(rng, Bool) @@ -280,10 +282,10 @@ function mutate_constant( end end -@unstable function get_operators(ex::AbstractExpression, options::Options) +@unstable function get_operators(ex::AbstractExpression, options::AbstractOptions) return get_operators(ex, 
options.operators) end -@unstable function get_operators(ex::AbstractExpressionNode, options::Options) +@unstable function get_operators(ex::AbstractExpressionNode, options::AbstractOptions) return get_operators(ex, options.operators) end diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index 19c52f933..71032dfd5 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -3,7 +3,7 @@ module HallOfFameModule using DynamicExpressions: AbstractExpression, string_tree using ..UtilsModule: split_string using ..CoreModule: - MAX_DEGREE, Options, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression + MAX_DEGREE, AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression using ..ComplexityModule: compute_complexity using ..PopMemberModule: PopMember using ..InterfaceDynamicExpressionsModule: format_dimensions, WILDCARD_UNIT_STRING @@ -48,7 +48,7 @@ function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where end """ - HallOfFame(options::Options, dataset::Dataset{T,L}) where {T<:DATA_TYPE,L<:LOSS_TYPE} + HallOfFame(options::AbstractOptions, dataset::Dataset{T,L}) where {T<:DATA_TYPE,L<:LOSS_TYPE} Create empty HallOfFame. The HallOfFame stores a list of `PopMember` objects in `.members`, which is enumerated by size (i.e., `.members[1]` is the constant solution). has been instantiated or not. Arguments: -- `options`: Options containing specification about deterministic. +- `options`: The `AbstractOptions` for the search, including whether the run is deterministic. - `dataset`: Dataset containing the input data. """ function HallOfFame( - options::Options, dataset::Dataset{T,L} + options::AbstractOptions, dataset::Dataset{T,L} ) where {T<:DATA_TYPE,L<:LOSS_TYPE} actualMaxsize = options.maxsize + MAX_DEGREE base_tree = create_expression(zero(T), options, dataset) diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl index d5cf52300..887627b7d 100644 --- a/src/InterfaceDynamicExpressions.jl +++ b/src/InterfaceDynamicExpressions.jl @@ -11,14 +11,14 @@ using DynamicExpressions: Node, GraphNode using DynamicQuantities: dimension, ustrip -using ..CoreModule: Options +using ..CoreModule: AbstractOptions using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap using ..UtilsModule: subscriptify import ..deprecate_varmap """ - eval_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options; kws...) + eval_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::AbstractOptions; kws...) Evaluate a binary tree (equation) over a given input data matrix. The operators contain all of the operators used. This function fuses doublets @@ -41,7 +41,7 @@ which speed up evaluation significantly. # Arguments - `tree::Union{AbstractExpression,AbstractExpressionNode}`: The root node of the tree to evaluate. - `X::AbstractArray`: The input data to evaluate the tree on. -- `options::Options`: Options used to define the operators used in the tree. +- `options::AbstractOptions`: Options used to define the operators used in the tree. # Returns - `(output, complete)::Tuple{AbstractVector, Bool}`: the result, @@ -53,7 +53,7 @@
function DE.eval_tree_array( tree::Union{AbstractExpressionNode,AbstractExpression}, X::AbstractMatrix, - options::Options; + options::AbstractOptions; kws..., ) A = expected_array_type(X) @@ -70,7 +70,7 @@ function DE.eval_tree_array( tree::ParametricExpression, X::AbstractMatrix, classes::AbstractVector{<:Integer}, - options::Options; + options::AbstractOptions; kws..., ) A = expected_array_type(X) @@ -91,7 +91,7 @@ function expected_array_type(X::AbstractArray) end """ - eval_diff_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options, direction::Int) + eval_diff_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::AbstractOptions, direction::Int) Compute the forward derivative of an expression, using a similar structure and optimization to eval_tree_array. `direction` is the index of a particular @@ -102,7 +102,7 @@ respect to `x1`. - `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate. - `X::AbstractArray`: The data matrix, with each column being a data point. -- `options::Options`: The options containing the operators used to create the `tree`. +- `options::AbstractOptions`: The options containing the operators used to create the `tree`. - `direction::Int`: The index of the variable to take the derivative with respect to. # Returns @@ -113,7 +113,7 @@ respect to `x1`. function DE.eval_diff_tree_array( tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, - options::Options, + options::AbstractOptions, direction::Int, ) A = expected_array_type(X) @@ -124,7 +124,7 @@ function DE.eval_diff_tree_array( end """ - eval_grad_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options; variable::Bool=false) + eval_grad_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::AbstractOptions; variable::Bool=false) Compute the forward-mode derivative of an expression, using a similar structure and optimization to eval_tree_array. `variable` specifies whether @@ -135,7 +135,7 @@ to every constant in the expression. - `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate. - `X::AbstractArray`: The data matrix, with each column being a data point. -- `options::Options`: The options containing the operators used to create the `tree`. +- `options::AbstractOptions`: The options containing the operators used to create the `tree`. - `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`), or with respect to every constant in the expression (`variable=false`). @@ -147,7 +147,7 @@ to every constant in the expression. function DE.eval_grad_tree_array( tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, - options::Options; + options::AbstractOptions; kws..., ) A = expected_array_type(X) @@ -158,14 +158,14 @@ function DE.eval_grad_tree_array( end """ - differentiable_eval_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options) + differentiable_eval_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::AbstractOptions) Evaluate an expression tree in a way that can be auto-differentiated. 
""" function DE.differentiable_eval_tree_array( tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, - options::Options, + options::AbstractOptions, ) A = expected_array_type(X) # TODO: Add `AbstractExpression` implementation in `Expression.jl` @@ -177,20 +177,20 @@ end const WILDCARD_UNIT_STRING = "[?]" """ - string_tree(tree::AbstractExpressionNode, options::Options; kws...) + string_tree(tree::AbstractExpressionNode, options::AbstractOptions; kws...) Convert an equation to a string. # Arguments - `tree::AbstractExpressionNode`: The equation to convert to a string. -- `options::Options`: The options holding the definition of operators. +- `options::AbstractOptions`: The options holding the definition of operators. - `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables to print for each feature. """ @inline function DE.string_tree( tree::Union{AbstractExpression,AbstractExpressionNode}, - options::Options; + options::AbstractOptions; raw::Bool=true, X_sym_units=nothing, y_sym_units=nothing, @@ -283,24 +283,27 @@ end end """ - print_tree(tree::AbstractExpressionNode, options::Options; kws...) + print_tree(tree::AbstractExpressionNode, options::AbstractOptions; kws...) Print an equation # Arguments - `tree::AbstractExpressionNode`: The equation to convert to a string. -- `options::Options`: The options holding the definition of operators. +- `options::AbstractOptions`: The options holding the definition of operators. - `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables to print for each feature. """ function DE.print_tree( - tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws... + tree::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions; kws... ) return DE.print_tree(tree, DE.get_operators(tree, options); kws...) end function DE.print_tree( - io::IO, tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws... + io::IO, + tree::Union{AbstractExpression,AbstractExpressionNode}, + options::AbstractOptions; + kws..., ) return DE.print_tree(io, tree, DE.get_operators(tree, options); kws...) end @@ -317,7 +320,7 @@ defined. """ macro extend_operators(options) operators = :($(options).operators) - type_requirements = Options + type_requirements = AbstractOptions alias_operators = gensym("alias_operators") return quote if !isa($(options), $type_requirements) @@ -340,7 +343,7 @@ function define_alias_operators(operators) end function (tree::Union{AbstractExpression,AbstractExpressionNode})( - X, options::Options; kws... + X, options::AbstractOptions; kws... ) return tree( X, @@ -351,7 +354,10 @@ function (tree::Union{AbstractExpression,AbstractExpressionNode})( ) end function DE.EvaluationHelpersModule._grad_evaluator( - tree::Union{AbstractExpression,AbstractExpressionNode}, X, options::Options; kws... + tree::Union{AbstractExpression,AbstractExpressionNode}, + X, + options::AbstractOptions; + kws..., ) return DE.EvaluationHelpersModule._grad_evaluator( tree, X, DE.get_operators(tree, options); turbo=options.turbo, kws... 
diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index a84218879..b41ad3a38 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -6,7 +6,7 @@ using DynamicExpressions: using LossFunctions: LossFunctions using LossFunctions: SupervisedLoss using ..InterfaceDynamicExpressionsModule: expected_array_type -using ..CoreModule: Options, Dataset, create_expression, DATA_TYPE, LOSS_TYPE +using ..CoreModule: AbstractOptions, Dataset, create_expression, DATA_TYPE, LOSS_TYPE using ..ComplexityModule: compute_complexity using ..DimensionalAnalysisModule: violates_dimensional_constraints @@ -44,7 +44,7 @@ end function eval_tree_dispatch( tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, dataset::Dataset{T}, - options::Options, + options::AbstractOptions, idx, ) where {T<:DATA_TYPE} A = expected_array_type(dataset.X) @@ -55,7 +55,7 @@ end function _eval_loss( tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, dataset::Dataset{T,L}, - options::Options, + options::AbstractOptions, regularization::Bool, idx, )::L where {T<:DATA_TYPE,L<:LOSS_TYPE} @@ -84,7 +84,11 @@ end # This evaluates function F: function evaluator( - f::F, tree::AbstractExpressionNode{T}, dataset::Dataset{T,L}, options::Options, idx + f::F, + tree::AbstractExpressionNode{T}, + dataset::Dataset{T,L}, + options::AbstractOptions, + idx, )::L where {T<:DATA_TYPE,L<:LOSS_TYPE,F} if hasmethod(f, typeof((tree, dataset, options, idx))) # If user defines method that accepts batching indices: @@ -105,7 +109,7 @@ end function eval_loss( tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, dataset::Dataset{T,L}, - options::Options; + options::AbstractOptions; regularization::Bool=true, idx=nothing, )::L where {T<:DATA_TYPE,L<:LOSS_TYPE} @@ -122,7 +126,7 @@ end function eval_loss_batched( tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, dataset::Dataset{T,L}, - options::Options; + options::AbstractOptions; regularization::Bool=true, idx=nothing, )::L where {T<:DATA_TYPE,L<:LOSS_TYPE} @@ -148,7 +152,7 @@ function loss_to_score( use_baseline::Bool, baseline::L, member, - options::Options, + options::AbstractOptions, complexity::Union{Int,Nothing}=nothing, )::L where {L<:LOSS_TYPE} # TODO: Come up with a more general normalization scheme. @@ -167,7 +171,10 @@ end # Score an equation function score_func( - dataset::Dataset{T,L}, member, options::Options; complexity::Union{Int,Nothing}=nothing + dataset::Dataset{T,L}, + member, + options::AbstractOptions; + complexity::Union{Int,Nothing}=nothing, )::Tuple{L,L} where {T<:DATA_TYPE,L<:LOSS_TYPE} result_loss = eval_loss(get_tree_from_member(member), dataset, options) score = loss_to_score( @@ -185,7 +192,7 @@ end function score_func_batched( dataset::Dataset{T,L}, member, - options::Options; + options::AbstractOptions; complexity::Union{Int,Nothing}=nothing, idx=nothing, )::Tuple{L,L} where {T<:DATA_TYPE,L<:LOSS_TYPE} @@ -202,12 +209,12 @@ function score_func_batched( end """ - update_baseline_loss!(dataset::Dataset{T,L}, options::Options) where {T<:DATA_TYPE,L<:LOSS_TYPE} + update_baseline_loss!(dataset::Dataset{T,L}, options::AbstractOptions) where {T<:DATA_TYPE,L<:LOSS_TYPE} Update the baseline loss of the dataset using the loss function specified in `options`. 
""" function update_baseline_loss!( - dataset::Dataset{T,L}, options::Options + dataset::Dataset{T,L}, options::AbstractOptions ) where {T<:DATA_TYPE,L<:LOSS_TYPE} example_tree = create_expression(zero(T), options, dataset) # constructorof(options.node_type)(T; val=dataset.avg_y) @@ -226,7 +233,7 @@ end function dimensional_regularization( tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, dataset::Dataset{T,L}, - options::Options, + options::AbstractOptions, ) where {T<:DATA_TYPE,L<:LOSS_TYPE} if !violates_dimensional_constraints(tree, dataset, options) return zero(L) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 2bbce47f1..028493ff7 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -23,9 +23,8 @@ using DynamicQuantities: ustrip, dimension using LossFunctions: SupervisedLoss -using Compat: allequal, stack using ..InterfaceDynamicQuantitiesModule: get_dimensions_type -using ..CoreModule: Options, Dataset, MutationWeights, LOSS_TYPE +using ..CoreModule: Options, Dataset, AbstractMutationWeights, MutationWeights, LOSS_TYPE using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame diff --git a/src/Migration.jl b/src/Migration.jl index daab9255f..d08dab2ac 100644 --- a/src/Migration.jl +++ b/src/Migration.jl @@ -1,20 +1,20 @@ module MigrationModule using StatsBase: StatsBase -using ..CoreModule: Options +using ..CoreModule: AbstractOptions using ..PopulationModule: Population using ..PopMemberModule: PopMember, reset_birth! using ..UtilsModule: poisson_sample """ - migrate!(migration::Pair{Population{T,L},Population{T,L}}, options::Options; frac::AbstractFloat) + migrate!(migration::Pair{Population{T,L},Population{T,L}}, options::AbstractOptions; frac::AbstractFloat) Migrate a fraction of the population from one population to the other, creating copies to do so. The original migrant population is not modified. Pass with, e.g., `migrate!(migration_candidates => destination, options; frac=0.1)` """ function migrate!( - migration::Pair{Vector{PM},P}, options::Options; frac::AbstractFloat + migration::Pair{Vector{PM},P}, options::AbstractOptions; frac::AbstractFloat ) where {T,L,N,PM<:PopMember{T,L,N},P<:Population{T,L,N}} base_pop = migration.second population_size = length(base_pop.members) diff --git a/src/Mutate.jl b/src/Mutate.jl index c559d42ba..0f383177b 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -10,7 +10,8 @@ using DynamicExpressions: count_scalar_constants, simplify_tree!, combine_operators -using ..CoreModule: Options, MutationWeights, Dataset, RecordType, sample_mutation +using ..CoreModule: + AbstractOptions, AbstractMutationWeights, Dataset, RecordType, sample_mutation using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: score_func, score_func_batched using ..CheckConstraintsModule: check_constraints @@ -32,8 +33,68 @@ using ..MutationFunctionsModule: using ..ConstantOptimizationModule: optimize_constants using ..RecorderModule: @recorder +abstract type AbstractMutationResult{N<:AbstractExpression,P<:PopMember} end + +""" + MutationResult{N<:AbstractExpression,P<:PopMember} + +Represents the result of a mutation operation in the genetic programming algorithm. This struct is used to return values from `mutate!` functions. + +# Fields + +- `tree::Union{N, Nothing}`: The mutated expression tree, if applicable. Either `tree` or `member` must be set, but not both. 
+- `member::Union{P, Nothing}`: The mutated population member, if applicable. Either `member` or `tree` must be set, but not both. +- `num_evals::Float64`: The number of evaluations performed during the mutation, which defaults to `0.0`. Only used by mutations such as `optimize` that evaluate the expression. +- `return_immediately::Bool`: If `true`, the mutation process returns immediately, bypassing further checks. This is used for mutations like `simplify` or `optimize`, where the loss value of the result is already known. + +# Usage + +This struct encapsulates the result of a mutation operation. Either a new expression tree or a new population member is returned, but not both. + +Return a `member` when the mutation has already computed the loss value and +should return immediately. +""" +struct MutationResult{N<:AbstractExpression,P<:PopMember} <: AbstractMutationResult{N,P} + tree::Union{N,Nothing} + member::Union{P,Nothing} + num_evals::Float64 + return_immediately::Bool + + # Explicit constructor with keyword arguments + function MutationResult{_N,_P}(; + tree::Union{_N,Nothing}=nothing, + member::Union{_P,Nothing}=nothing, + num_evals::Float64=0.0, + return_immediately::Bool=false, + ) where {_N<:AbstractExpression,_P<:PopMember} + @assert( + (tree === nothing) ⊻ (member === nothing), + "Mutation result must return either a tree or a pop member, not both" + ) + return new{_N,_P}(tree, member, num_evals, return_immediately) + end +end + +""" + condition_mutation_weights!(weights::AbstractMutationWeights, member::PopMember, options::AbstractOptions, curmaxsize::Int) + +Adjusts the mutation weights based on the properties of the current member and options. + +This function modifies the mutation weights to ensure that the mutations applied to the member are appropriate given its current state and the provided options. It can be overloaded to customize the behavior for different types of expressions or members. + +Note that the weights have already been copied, so you may modify them in place without affecting the caller. + +# Arguments +- `weights::AbstractMutationWeights`: The mutation weights to be adjusted. +- `member::PopMember`: The current population member being mutated. +- `options::AbstractOptions`: The options that guide the mutation process. +- `curmaxsize::Int`: The current maximum size constraint for the member's expression tree. +""" function condition_mutation_weights!( - weights::MutationWeights, member::PopMember, options::Options, curmaxsize::Int + weights::AbstractMutationWeights, + member::PopMember, + options::AbstractOptions, + curmaxsize::Int, ) tree = get_tree(member.tree) if !preserve_sharing(typeof(member.tree)) @@ -81,9 +142,9 @@ Use this to modify how `mutate_constant` changes for an expression type. 
""" function condition_mutate_constant!( ::Type{<:AbstractExpression}, - weights::MutationWeights, + weights::AbstractMutationWeights, member::PopMember, - options::Options, + options::AbstractOptions, curmaxsize::Int, ) n_constants = count_scalar_constants(member.tree) @@ -93,9 +154,9 @@ function condition_mutate_constant!( end function condition_mutate_constant!( ::Type{<:ParametricExpression}, - weights::MutationWeights, + weights::AbstractMutationWeights, member::PopMember, - options::Options, + options::AbstractOptions, curmaxsize::Int, ) # Avoid modifying the mutate_constant weight, since @@ -111,13 +172,12 @@ function next_generation( temperature, curmaxsize::Int, running_search_statistics::RunningSearchStatistics, - options::Options; + options::AbstractOptions; tmp_recorder::RecordType, )::Tuple{ P,Bool,Float64 } where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:PopMember{T,L,N}} parent_ref = member.ref - mutation_accepted = false num_evals = 0.0 #TODO - reconsider this @@ -137,8 +197,6 @@ function next_generation( mutation_choice = sample_mutation(weights) successful_mutation = false - #TODO: Currently we dont take this \/ into account - is_success_always_possible = true attempts = 0 max_attempts = 10 @@ -148,133 +206,41 @@ function next_generation( local tree while (!successful_mutation) && attempts < max_attempts tree = copy_node(member.tree) - successful_mutation = true - if mutation_choice == :mutate_constant - tree = mutate_constant(tree, temperature, options) - @recorder tmp_recorder["type"] = "constant" - is_success_always_possible = true - # Mutating a constant shouldn't invalidate an already-valid function - elseif mutation_choice == :mutate_operator - tree = mutate_operator(tree, options) - @recorder tmp_recorder["type"] = "operator" - is_success_always_possible = true - # Can always mutate to the same operator - - elseif mutation_choice == :swap_operands - tree = swap_operands(tree) - @recorder tmp_recorder["type"] = "swap_operands" - is_success_always_possible = true - - elseif mutation_choice == :add_node - if rand() < 0.5 - tree = append_random_op(tree, options, nfeatures) - @recorder tmp_recorder["type"] = "append_op" - else - tree = prepend_random_op(tree, options, nfeatures) - @recorder tmp_recorder["type"] = "prepend_op" - end - is_success_always_possible = false - # Can potentially have a situation without success - elseif mutation_choice == :insert_node - tree = insert_random_op(tree, options, nfeatures) - @recorder tmp_recorder["type"] = "insert_op" - is_success_always_possible = false - elseif mutation_choice == :delete_node - tree = delete_random_op!(tree, options, nfeatures) - @recorder tmp_recorder["type"] = "delete_op" - is_success_always_possible = true - elseif mutation_choice == :simplify - @assert options.should_simplify - simplify_tree!(tree, options.operators) - tree = combine_operators(tree, options.operators) - @recorder tmp_recorder["type"] = "partial_simplify" - mutation_accepted = true - is_success_always_possible = true - return ( - PopMember( - tree, - beforeScore, - beforeLoss, - options; - parent=parent_ref, - deterministic=options.deterministic, - ), - mutation_accepted, - num_evals, - ) - # Simplification shouldn't hurt complexity; unless some non-symmetric constraint - # to commutative operator... - elseif mutation_choice == :randomize - # We select a random size, though the generated tree - # may have fewer nodes than we request. 
- tree_size_to_generate = rand(1:curmaxsize) - tree = with_contents( - tree, - gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T), - ) - @recorder tmp_recorder["type"] = "regenerate" - is_success_always_possible = true - elseif mutation_choice == :optimize - cur_member = PopMember( - tree, - beforeScore, - beforeLoss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, - ) - cur_member, new_num_evals = optimize_constants(dataset, cur_member, options) - num_evals += new_num_evals - @recorder tmp_recorder["type"] = "optimize" - mutation_accepted = true - is_success_always_possible = true - return (cur_member, mutation_accepted, num_evals) - elseif mutation_choice == :do_nothing - @recorder begin - tmp_recorder["type"] = "identity" - tmp_recorder["result"] = "accept" - tmp_recorder["reason"] = "identity" - end - mutation_accepted = true - is_success_always_possible = true - return ( - PopMember( - tree, - beforeScore, - beforeLoss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, - ), - mutation_accepted, - num_evals, + mutation_result = _dispatch_mutations!( + tree, + member, + mutation_choice, + options.mutation_weights, + options; + recorder=tmp_recorder, + temperature, + dataset, + score=beforeScore, + loss=beforeLoss, + parent_ref, + curmaxsize, + nfeatures, + ) + mutation_result::AbstractMutationResult{N,P} + num_evals += mutation_result.num_evals::Float64 + + if mutation_result.return_immediately + @assert( + mutation_result.member isa P, + "Mutation result must return a `PopMember` if `return_immediately` is true" ) - elseif mutation_choice == :form_connection - tree = form_random_connection!(tree) - @recorder tmp_recorder["type"] = "form_connection" - is_success_always_possible = true - elseif mutation_choice == :break_connection - tree = break_random_connection!(tree) - @recorder tmp_recorder["type"] = "break_connection" - is_success_always_possible = true - elseif mutation_choice == :rotate_tree - tree = randomly_rotate_tree!(tree) - @recorder tmp_recorder["type"] = "rotate_tree" - is_success_always_possible = true + return mutation_result.member::P, true, num_evals else - error("Unknown mutation choice: $mutation_choice") + @assert( + mutation_result.tree isa N, + "Mutation result must return a tree if `return_immediately` is false" + ) + tree = mutation_result.tree::N + successful_mutation = check_constraints(tree, options, curmaxsize) + attempts += 1 end - - successful_mutation = - successful_mutation && check_constraints(tree, options, curmaxsize) - - attempts += 1 end - ############################################# - tree::AbstractExpression if !successful_mutation @recorder begin @@ -389,9 +355,300 @@ function next_generation( end end +@generated function _dispatch_mutations!( + tree::AbstractExpression, + member::PopMember, + mutation_choice::Symbol, + weights::W, + options::AbstractOptions; + kws..., +) where {W<:AbstractMutationWeights} + mutation_choices = fieldnames(W) + quote + Base.Cartesian.@nif( + $(length(mutation_choices)), + i -> mutation_choice == $(mutation_choices)[i], + i -> begin + @assert mutation_choice == $(mutation_choices)[i] + mutate!( + tree, member, Val($(mutation_choices)[i]), weights, options; kws... 
+ ) + end, + ) + end +end + +""" + mutate!( + tree::N, + member::P, + ::Val{S}, + mutation_weights::AbstractMutationWeights, + options::AbstractOptions; + kws..., + ) where {N<:AbstractExpression,P<:PopMember,S} + +Perform a mutation on the given `tree` and `member` using the specified mutation type `S`. +Various `kws` are provided to access other data needed for some mutations. + +You may overload this function to handle new mutation types for new `AbstractMutationWeights` types. + +# Keywords + +- `temperature`: The temperature parameter for annealing-based mutations. +- `dataset::Dataset`: The dataset used for scoring. +- `score`: The score of the member before mutation. +- `loss`: The loss of the member before mutation. +- `curmaxsize`: The current maximum size constraint, which may be different from `options.maxsize`. +- `nfeatures`: The number of features in the dataset. +- `parent_ref`: Reference to the mutated member's parent (only used for logging purposes). +- `recorder::RecordType`: A recorder to log mutation details. + +# Returns + +A `MutationResult{N,P}` object containing the mutated tree or member (but not both), +the number of evaluations performed, if any, and whether to return immediately from +the mutation function, or to let the `next_generation` function handle accepting or +rejecting the mutation. For example, a `simplify` operation will not change the loss, +so it can always return immediately. +""" +function mutate!( + ::N, ::P, ::Val{S}, ::AbstractMutationWeights, ::AbstractOptions; kws... +) where {N<:AbstractExpression,P<:PopMember,S} + return error("Unknown mutation choice: $S") +end + +function mutate!( + tree::N, + member::P, + ::Val{:mutate_constant}, + ::AbstractMutationWeights, + options::AbstractOptions; + recorder::RecordType, + temperature, + kws..., +) where {N<:AbstractExpression,P<:PopMember} + tree = mutate_constant(tree, temperature, options) + @recorder recorder["type"] = "mutate_constant" + return MutationResult{N,P}(; tree=tree) +end + +function mutate!( + tree::N, + member::P, + ::Val{:mutate_operator}, + ::AbstractMutationWeights, + options::AbstractOptions; + recorder::RecordType, + kws..., +) where {N<:AbstractExpression,P<:PopMember} + tree = mutate_operator(tree, options) + @recorder recorder["type"] = "mutate_operator" + return MutationResult{N,P}(; tree=tree) +end + +function mutate!( + tree::N, + member::P, + ::Val{:swap_operands}, + ::AbstractMutationWeights, + options::AbstractOptions; + recorder::RecordType, + kws..., +) where {N<:AbstractExpression,P<:PopMember} + tree = swap_operands(tree) + @recorder recorder["type"] = "swap_operands" + return MutationResult{N,P}(; tree=tree) +end + +function mutate!( + tree::N, + member::P, + ::Val{:add_node}, + ::AbstractMutationWeights, + options::AbstractOptions; + recorder::RecordType, + nfeatures, + kws..., +) where {N<:AbstractExpression,P<:PopMember} + if rand() < 0.5 + tree = append_random_op(tree, options, nfeatures) + @recorder recorder["type"] = "add_node:append" + else + tree = prepend_random_op(tree, options, nfeatures) + @recorder recorder["type"] = "add_node:prepend" + end + return MutationResult{N,P}(; tree=tree) +end + +function mutate!( + tree::N, + member::P, + ::Val{:insert_node}, + ::AbstractMutationWeights, + options::AbstractOptions; + recorder::RecordType, + nfeatures, + kws..., +) where {N<:AbstractExpression,P<:PopMember} + tree = insert_random_op(tree, options, nfeatures) + @recorder recorder["type"] = "insert_node" + return MutationResult{N,P}(; tree=tree) +end + 
+function mutate!( + tree::N, + member::P, + ::Val{:delete_node}, + ::AbstractMutationWeights, + options::AbstractOptions; + recorder::RecordType, + nfeatures, + kws..., +) where {N<:AbstractExpression,P<:PopMember} + tree = delete_random_op!(tree, options, nfeatures) + @recorder recorder["type"] = "delete_node" + return MutationResult{N,P}(; tree=tree) +end + +function mutate!( + tree::N, + member::P, + ::Val{:form_connection}, + ::AbstractMutationWeights, + options::AbstractOptions; + recorder::RecordType, + kws..., +) where {N<:AbstractExpression,P<:PopMember} + tree = form_random_connection!(tree) + @recorder recorder["type"] = "form_connection" + return MutationResult{N,P}(; tree=tree) +end + +function mutate!( + tree::N, + member::P, + ::Val{:break_connection}, + ::AbstractMutationWeights, + options::AbstractOptions; + recorder::RecordType, + kws..., +) where {N<:AbstractExpression,P<:PopMember} + tree = break_random_connection!(tree) + @recorder recorder["type"] = "break_connection" + return MutationResult{N,P}(; tree=tree) +end + +function mutate!( + tree::N, + member::P, + ::Val{:rotate_tree}, + ::AbstractMutationWeights, + options::AbstractOptions; + recorder::RecordType, + kws..., +) where {N<:AbstractExpression,P<:PopMember} + tree = randomly_rotate_tree!(tree) + @recorder recorder["type"] = "rotate_tree" + return MutationResult{N,P}(; tree=tree) +end + +# Handle mutations that require early return +function mutate!( + tree::N, + member::P, + ::Val{:simplify}, + ::AbstractMutationWeights, + options::AbstractOptions; + recorder::RecordType, + parent_ref, + kws..., +) where {N<:AbstractExpression,P<:PopMember} + @assert options.should_simplify + simplify_tree!(tree, options.operators) + tree = combine_operators(tree, options.operators) + @recorder recorder["type"] = "simplify" + return MutationResult{N,P}(; + member=PopMember( + tree, + member.score, + member.loss, + options; + parent=parent_ref, + deterministic=options.deterministic, + ), + return_immediately=true, + ) +end + +function mutate!( + tree::N, + ::P, + ::Val{:randomize}, + ::AbstractMutationWeights, + options::AbstractOptions; + recorder::RecordType, + curmaxsize, + nfeatures, + kws..., +) where {T,N<:AbstractExpression{T},P<:PopMember} + tree_size_to_generate = rand(1:curmaxsize) + tree = with_contents( + tree, gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T) + ) + @recorder recorder["type"] = "randomize" + return MutationResult{N,P}(; tree=tree) +end + +function mutate!( + tree::N, + member::P, + ::Val{:optimize}, + ::AbstractMutationWeights, + options::AbstractOptions; + recorder::RecordType, + dataset::Dataset, + kws..., +) where {N<:AbstractExpression,P<:PopMember} + cur_member, new_num_evals = optimize_constants(dataset, member, options) + @recorder recorder["type"] = "optimize" + return MutationResult{N,P}(; + member=cur_member, num_evals=new_num_evals, return_immediately=true + ) +end + +function mutate!( + tree::N, + member::P, + ::Val{:do_nothing}, + ::AbstractMutationWeights, + options::AbstractOptions; + recorder::RecordType, + parent_ref, + kws..., +) where {N<:AbstractExpression,P<:PopMember} + @recorder begin + recorder["type"] = "identity" + recorder["result"] = "accept" + recorder["reason"] = "identity" + end + return MutationResult{N,P}(; + member=PopMember( + tree, + member.score, + member.loss, + options, + compute_complexity(tree, options); + parent=parent_ref, + deterministic=options.deterministic, + ), + return_immediately=true, + ) +end + """Generate a generation via 
crossover of two members.""" function crossover_generation( - member1::P, member2::P, dataset::D, curmaxsize::Int, options::Options + member1::P, member2::P, dataset::D, curmaxsize::Int, options::AbstractOptions )::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:PopMember{T,L,N}} tree1 = member1.tree tree2 = member2.tree @@ -460,4 +717,4 @@ function crossover_generation( return baby1, baby2, crossover_accepted, num_evals end -end +end # module MutateModule diff --git a/src/MutationFunctions.jl b/src/MutationFunctions.jl index 844160de7..496d584dd 100644 --- a/src/MutationFunctions.jl +++ b/src/MutationFunctions.jl @@ -14,8 +14,7 @@ using DynamicExpressions: count_nodes, has_constants, has_operators -using Compat: Returns, @inline -using ..CoreModule: Options, DATA_TYPE +using ..CoreModule: AbstractOptions, DATA_TYPE """ random_node(tree::AbstractNode; filter::F=Returns(true)) @@ -50,14 +49,16 @@ end """Randomly convert an operator into another one (binary->binary; unary->unary)""" function mutate_operator( - ex::AbstractExpression{T}, options::Options, rng::AbstractRNG=default_rng() + ex::AbstractExpression{T}, options::AbstractOptions, rng::AbstractRNG=default_rng() ) where {T<:DATA_TYPE} tree = get_contents(ex) ex = with_contents(ex, mutate_operator(tree, options, rng)) return ex end function mutate_operator( - tree::AbstractExpressionNode{T}, options::Options, rng::AbstractRNG=default_rng() + tree::AbstractExpressionNode{T}, + options::AbstractOptions, + rng::AbstractRNG=default_rng(), ) where {T} if !(has_operators(tree)) return tree @@ -73,7 +74,10 @@ end """Randomly perturb a constant""" function mutate_constant( - ex::AbstractExpression{T}, temperature, options::Options, rng::AbstractRNG=default_rng() + ex::AbstractExpression{T}, + temperature, + options::AbstractOptions, + rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} tree = get_contents(ex) ex = with_contents(ex, mutate_constant(tree, temperature, options, rng)) @@ -82,7 +86,7 @@ end function mutate_constant( tree::AbstractExpressionNode{T}, temperature, - options::Options, + options::AbstractOptions, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} # T is between 0 and 1. 
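The `mutate!` methods above all share one contract: each returns a `MutationResult` carrying either a mutated `tree` (which `next_generation` then checks against constraints and scores) or a finished `member` with `return_immediately=true`. As a hedged sketch of extending this interface, the following defines a hypothetical `:shuffle_features` mutation on an invented `MyMutationWeights` type; the mutation body is a placeholder, and a real extension would also overload `sample_mutation` as described in the `AbstractMutationWeights` docstring below.

```julia
using SymbolicRegression:
    AbstractExpression, AbstractMutationWeights, AbstractOptions, MutationResult, PopMember
import SymbolicRegression: mutate!

# Hypothetical weights type adding one custom mutation field:
Base.@kwdef struct MyMutationWeights <: AbstractMutationWeights
    mutate_constant::Float64 = 0.5
    shuffle_features::Float64 = 0.5
end

# Dispatch on `Val(:shuffle_features)`, mirroring the built-in methods.
# Extra keywords (recorder, temperature, dataset, ...) are absorbed by `kws...`.
function mutate!(
    tree::N,
    member::P,
    ::Val{:shuffle_features},
    ::MyMutationWeights,
    options::AbstractOptions;
    kws...,
) where {N<:AbstractExpression,P<:PopMember}
    new_tree = copy(tree)  # placeholder for a real tree transformation
    return MutationResult{N,P}(; tree=new_tree)
end
```

Because `_dispatch_mutations!` is generated from `fieldnames(typeof(weights))`, any field on the custom weights type is routed to the matching `Val`-specialized `mutate!` method automatically.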
@@ -116,7 +120,7 @@ end """Add a random unary/binary operation to the end of a tree""" function append_random_op( ex::AbstractExpression{T}, - options::Options, + options::AbstractOptions, nfeatures::Int, rng::AbstractRNG=default_rng(); makeNewBinOp::Union{Bool,Nothing}=nothing, @@ -127,7 +131,7 @@ function append_random_op( end function append_random_op( tree::AbstractExpressionNode{T}, - options::Options, + options::AbstractOptions, nfeatures::Int, rng::AbstractRNG=default_rng(); makeNewBinOp::Union{Bool,Nothing}=nothing, @@ -160,7 +164,7 @@ end """Insert random node""" function insert_random_op( ex::AbstractExpression{T}, - options::Options, + options::AbstractOptions, nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} @@ -170,7 +174,7 @@ function insert_random_op( end function insert_random_op( tree::AbstractExpressionNode{T}, - options::Options, + options::AbstractOptions, nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} @@ -194,7 +198,7 @@ end """Add random node to the top of a tree""" function prepend_random_op( ex::AbstractExpression{T}, - options::Options, + options::AbstractOptions, nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} @@ -204,7 +208,7 @@ function prepend_random_op( end function prepend_random_op( tree::AbstractExpressionNode{T}, - options::Options, + options::AbstractOptions, nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} @@ -230,7 +234,7 @@ function make_random_leaf( ::Type{T}, ::Type{N}, rng::AbstractRNG=default_rng(), - ::Union{Options,Nothing}=nothing, + ::Union{AbstractOptions,Nothing}=nothing, ) where {T<:DATA_TYPE,N<:AbstractExpressionNode} if rand(rng, Bool) return constructorof(N)(T; val=randn(rng, T)) @@ -255,7 +259,7 @@ end """Select a random node, and splice it out of the tree.""" function delete_random_op!( ex::AbstractExpression{T}, - options::Options, + options::AbstractOptions, nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} @@ -265,7 +269,7 @@ function delete_random_op!( end function delete_random_op!( tree::AbstractExpressionNode{T}, - options::Options, + options::AbstractOptions, nfeatures::Int, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} @@ -310,7 +314,11 @@ end """Create a random equation by appending random operators""" function gen_random_tree( - length::Int, options::Options, nfeatures::Int, ::Type{T}, rng::AbstractRNG=default_rng() + length::Int, + options::AbstractOptions, + nfeatures::Int, + ::Type{T}, + rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} # Note that this base tree is just a placeholder; it will be replaced. tree = constructorof(options.node_type)(T; val=convert(T, 1)) @@ -323,7 +331,7 @@ end function gen_random_tree_fixed_size( node_count::Int, - options::Options, + options::AbstractOptions, nfeatures::Int, ::Type{T}, rng::AbstractRNG=default_rng(), diff --git a/src/MutationWeights.jl b/src/MutationWeights.jl index f7c0f8a70..9de15af7d 100644 --- a/src/MutationWeights.jl +++ b/src/MutationWeights.jl @@ -3,11 +3,78 @@ module MutationWeightsModule using StatsBase: StatsBase """ - MutationWeights(;kws...) + AbstractMutationWeights + +An abstract type that defines the interface for mutation weight structures in the symbolic regression framework. Subtypes of `AbstractMutationWeights` specify how often different mutation operations occur during the mutation process. + +You can create custom mutation weight types by subtyping `AbstractMutationWeights` and defining your own mutation operations. 
Additionally, you can overload the `sample_mutation` function to handle sampling from your custom mutation types. + +# Usage + +To create a custom mutation weighting scheme with new mutation types, define a new subtype of `AbstractMutationWeights` and implement the necessary fields. Here's an example using `Base.@kwdef` to define the struct with default values: + +```julia +using SymbolicRegression: AbstractMutationWeights + +# Define custom mutation weights with default values +Base.@kwdef struct MyMutationWeights <: AbstractMutationWeights + mutate_constant::Float64 = 0.1 + mutate_operator::Float64 = 0.2 + custom_mutation::Float64 = 0.7 +end +``` + +Next, overload the `sample_mutation` function to include your custom mutation types: + +```julia +# Define the list of mutation names (symbols) +const MY_MUTATIONS = [ + :mutate_constant, + :mutate_operator, + :custom_mutation +] + +# Import the `sample_mutation` function to overload it +import SymbolicRegression: sample_mutation +using StatsBase: StatsBase + +# Overload the `sample_mutation` function +function sample_mutation(w::MyMutationWeights) + weights = [ + w.mutate_constant, + w.mutate_operator, + w.custom_mutation + ] + weights = weights ./ sum(weights) # Normalize weights to sum to 1.0 + return StatsBase.sample(MY_MUTATIONS, StatsBase.Weights(weights)) +end + +# Pass it when defining `Options`: +using SymbolicRegression: Options +options = Options(mutation_weights=MyMutationWeights()) +``` + +This allows you to customize the mutation sampling process to include your custom mutations according to their specified weights. + +To integrate your custom mutations into the mutation process, ensure that the mutation functions corresponding to your custom mutation types are defined and properly registered with the symbolic regression framework. You may need to define methods for `mutate!` that handle your custom mutation types. + +# See Also + +- [`MutationWeights`](@ref): A concrete implementation of `AbstractMutationWeights` that defines default mutation weightings. +- [`sample_mutation`](@ref): Function to sample a mutation based on current mutation weights. +- [`mutate!`](@ref SymbolicRegression.MutateModule.mutate!): Function to apply a mutation to an expression tree. +- [`AbstractOptions`](@ref SymbolicRegression.OptionsStruct.AbstractOptions): See how to extend abstract types for customizing options. +""" +abstract type AbstractMutationWeights end + +""" + MutationWeights(;kws...) <: AbstractMutationWeights This defines how often different mutations occur. These weightings will be normalized to sum to 1.0 after initialization. + # Arguments + - `mutate_constant::Float64`: How often to mutate a constant. - `mutate_operator::Float64`: How often to mutate an operator. - `swap_operands::Float64`: How often to swap the operands of a binary operator. @@ -27,8 +94,12 @@ will be normalized to sum to 1.0 after initialization. - `break_connection::Float64`: **Only used for `GraphNode`, not regular `Node`**. Otherwise, this will automatically be set to 0.0. How often to break a connection between two nodes. + +# See Also + +- [`AbstractMutationWeights`](@ref SymbolicRegression.CoreModule.MutationWeightsModule.AbstractMutationWeights): Use to define custom mutation weight types. 
""" -Base.@kwdef mutable struct MutationWeights +Base.@kwdef mutable struct MutationWeights <: AbstractMutationWeights mutate_constant::Float64 = 0.048 mutate_operator::Float64 = 0.47 swap_operands::Float64 = 0.1 @@ -60,7 +131,7 @@ let contents = [Expr(:., :w, QuoteNode(field)) for field in mutations] end """Sample a mutation, given the weightings.""" -function sample_mutation(w::MutationWeights) +function sample_mutation(w::AbstractMutationWeights) weights = convert(Vector, w) return StatsBase.sample(v_mutations, StatsBase.Weights(weights)) end diff --git a/src/Options.jl b/src/Options.jl index 0815bf744..701fde2ff 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -25,7 +25,7 @@ using ..OperatorsModule: safe_sqrt, safe_acosh, atanh_clip -using ..MutationWeightsModule: MutationWeights, mutations +using ..MutationWeightsModule: AbstractMutationWeights, MutationWeights, mutations import ..OptionsStructModule: Options using ..OptionsStructModule: ComplexityMapping, operator_specialization using ..UtilsModule: max_ops, @save_kwargs, @ignore @@ -199,7 +199,7 @@ function inverse_unaopmap(op::F) where {F} return op end -create_mutation_weights(w::MutationWeights) = w +create_mutation_weights(w::AbstractMutationWeights) = w create_mutation_weights(w::NamedTuple) = MutationWeights(; w...) const deprecated_options_mapping = Base.ImmutableDict( @@ -282,7 +282,7 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators - `DWDMarginLoss(q)`. - `loss_function`: Alternatively, you may redefine the loss used as any function of `tree::AbstractExpressionNode{T}`, `dataset::Dataset{T}`, - and `options::Options`, so long as you output a non-negative + and `options::AbstractOptions`, so long as you output a non-negative scalar of type `T`. This is useful if you want to use a loss that takes into account derivatives, or correlations across the dataset. This also means you could use a custom evaluation @@ -388,7 +388,7 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators - `probability_negate_constant`: Probability of negating a constant in the equation when mutating it. - `mutation_weights`: Relative probabilities of the mutations. The struct - `MutationWeights` should be passed to these options. + `MutationWeights` (or any `AbstractMutationWeights`) should be passed to these options. See its documentation on `MutationWeights` for the different weights. - `crossover_probability`: Probability of performing crossover. - `annealing`: Whether to use simulated annealing. @@ -429,7 +429,7 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators """ """ - Options(;kws...) + Options(;kws...) <: AbstractOptions Construct options for `equation_search` and other functions. 
The current arguments have been tuned using the median values from @@ -471,7 +471,7 @@ $(OPTION_DESCRIPTIONS) annealing::Bool=false, batching::Bool=false, batch_size::Integer=50, - mutation_weights::Union{MutationWeights,AbstractVector,NamedTuple}=MutationWeights(), + mutation_weights::Union{AbstractMutationWeights,AbstractVector,NamedTuple}=MutationWeights(), crossover_probability::Real=0.066, warmup_maxsize_by::Real=0.0, use_frequency::Bool=true, @@ -737,6 +737,7 @@ $(OPTION_DESCRIPTIONS) node_type, expression_type, typeof(expression_options), + typeof(set_mutation_weights), turbo, bumper, deprecated_return_state, diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index 2417b75ba..ba20fe92c 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -6,7 +6,7 @@ using DynamicExpressions: AbstractOperatorEnum, AbstractExpressionNode, AbstractExpression, OperatorEnum using LossFunctions: SupervisedLoss -import ..MutationWeightsModule: MutationWeights +import ..MutationWeightsModule: AbstractMutationWeights """ This struct defines how complexity is calculated. @@ -121,17 +121,71 @@ else @eval operator_specialization(O::Type{<:OperatorEnum}) = O end +""" + AbstractOptions + +An abstract type that stores all search hyperparameters for SymbolicRegression.jl. +The standard implementation is [`Options`](@ref). + +You may wish to create a new subtype of `AbstractOptions` to override certain functions +or create new behavior. Ensure that this new type has all properties of [`Options`](@ref). + +For example, if we have new options that we want to add to `Options`: + +```julia +Base.@kwdef struct MyNewOptions + a::Float64 = 1.0 + b::Int = 3 +end +``` + +we can create a combined options type that forwards properties to each corresponding type: + +```julia +struct MyOptions{O<:SymbolicRegression.Options} <: SymbolicRegression.AbstractOptions + new_options::MyNewOptions + sr_options::O +end +const NEW_OPTIONS_KEYS = fieldnames(MyNewOptions) + +# Constructor with both sets of parameters: +function MyOptions(; kws...) + new_options_keys = filter(k -> k in NEW_OPTIONS_KEYS, keys(kws)) + new_options = MyNewOptions(; NamedTuple(new_options_keys .=> Tuple(kws[k] for k in new_options_keys))...) + sr_options_keys = filter(k -> !(k in NEW_OPTIONS_KEYS), keys(kws)) + sr_options = SymbolicRegression.Options(; NamedTuple(sr_options_keys .=> Tuple(kws[k] for k in sr_options_keys))...) + return MyOptions(new_options, sr_options) +end + +# Make all `Options` available while also making `new_options` accessible +function Base.getproperty(options::MyOptions, k::Symbol) + if k in NEW_OPTIONS_KEYS + return getproperty(getfield(options, :new_options), k) + else + return getproperty(getfield(options, :sr_options), k) + end +end + +Base.propertynames(options::MyOptions) = (NEW_OPTIONS_KEYS..., fieldnames(SymbolicRegression.Options)...) +``` + +which would let you access `a` and `b` from `MyOptions` objects, as well as make +all properties of `Options` available for internal methods in SymbolicRegression.jl. +""" +abstract type AbstractOptions end + struct Options{ CM<:ComplexityMapping, OP<:AbstractOperatorEnum, N<:AbstractExpressionNode, E<:AbstractExpression, EO<:NamedTuple, + MW<:AbstractMutationWeights, _turbo, _bumper, _return_state, AD, -} +} <: AbstractOptions operators::OP bin_constraints::Vector{Tuple{Int,Int}} una_constraints::Vector{Int} @@ -156,7 +210,7 @@ struct Options{ annealing::Bool batching::Bool batch_size::Int - mutation_weights::MutationWeights + mutation_weights::MW crossover_probability::Float32 warmup_maxsize_by::Float32 use_frequency::Bool @@ -223,6 +277,7 @@ function Base.print(io::IO, options::Options) end Base.show(io::IO, ::MIME"text/plain", options::Options) = Base.print(io, options) +specialized_options(options::AbstractOptions) = options @unstable function specialized_options(options::Options) return _specialized_options(options) end diff --git a/src/PopMember.jl b/src/PopMember.jl index 84f29f451..63cfebfb4 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -2,7 +2,7 @@ module PopMemberModule using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree -using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE, create_expression +using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, create_expression import ..ComplexityModule: compute_complexity using ..UtilsModule: get_birth_order using ..LossFunctionsModule: score_func @@ -61,7 +61,7 @@ function PopMember( t::AbstractExpression{T}, score::L, loss::L, - options::Union{Options,Nothing}=nothing, + options::Union{AbstractOptions,Nothing}=nothing, complexity::Union{Int,Nothing}=nothing; ref::Int=-1, parent::Int=-1, @@ -93,7 +93,7 @@ end PopMember( dataset::Dataset{T,L}, t::AbstractExpression{T}, - options::Options + options::AbstractOptions ) Create a population member with a birth date at the current time. @@ -103,12 +103,12 @@ Automatically compute the score for this tree. - `dataset::Dataset{T,L}`: The dataset to evaluate the tree on. - `t::AbstractExpression{T}`: The tree for the population member. -- `options::Options`: What options to use. +- `options::AbstractOptions`: What options to use.
""" function PopMember( dataset::Dataset{T,L}, tree::Union{AbstractExpressionNode{T},AbstractExpression{T}}, - options::Options, + options::AbstractOptions, complexity::Union{Int,Nothing}=nothing; ref::Int=-1, parent::Int=-1, @@ -148,7 +148,7 @@ end # Can read off complexity directly from pop members function compute_complexity( - member::PopMember, options::Options; break_sharing=Val(false) + member::PopMember, options::AbstractOptions; break_sharing=Val(false) )::Int complexity = getfield(member, :complexity) complexity == -1 && return recompute_complexity!(member, options; break_sharing) @@ -156,7 +156,7 @@ function compute_complexity( return complexity end function recompute_complexity!( - member::PopMember, options::Options; break_sharing=Val(false) + member::PopMember, options::AbstractOptions; break_sharing=Val(false) )::Int complexity = compute_complexity(member.tree, options; break_sharing) setfield!(member, :complexity, complexity) diff --git a/src/Population.jl b/src/Population.jl index 9bc1df309..3a544730b 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -3,7 +3,7 @@ module PopulationModule using StatsBase: StatsBase using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, string_tree -using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE +using ..CoreModule: AbstractOptions, Dataset, RecordType, DATA_TYPE, LOSS_TYPE using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: score_func, update_baseline_loss! using ..AdaptiveParsimonyModule: RunningSearchStatistics @@ -27,14 +27,14 @@ end """ Population(dataset::Dataset{T,L}; - population_size, nlength::Int=3, options::Options, + population_size, nlength::Int=3, options::AbstractOptions, nfeatures::Int) Create random population and score them on the dataset. """ function Population( dataset::Dataset{T,L}; - options::Options, + options::AbstractOptions, population_size=nothing, nlength::Int=3, nfeatures::Int, @@ -62,7 +62,7 @@ end """ Population(X::AbstractMatrix{T}, y::AbstractVector{T}; population_size, nlength::Int=3, - options::Options, nfeatures::Int, + options::AbstractOptions, nfeatures::Int, loss_type::Type=Nothing) Create random population and score them on the dataset. @@ -72,7 +72,7 @@ Create random population and score them on the dataset. 
y::AbstractVector{T}; population_size=nothing, nlength::Int=3, - options::Options, + options::AbstractOptions, nfeatures::Int, loss_type::Type{L}=Nothing, npop=nothing, @@ -99,7 +99,7 @@ function Base.copy(pop::P)::P where {T,L,N,P<:Population{T,L,N}} end # Sample random members of the population, and make a new one -function sample_pop(pop::P, options::Options)::P where {P<:Population} +function sample_pop(pop::P, options::AbstractOptions)::P where {P<:Population} return Population( StatsBase.sample(pop.members, options.tournament_selection_n; replace=false) ) @@ -109,7 +109,7 @@ end function best_of_sample( pop::Population{T,L,N}, running_search_statistics::RunningSearchStatistics, - options::Options, + options::AbstractOptions, ) where {T,L,N} sample = sample_pop(pop, options) return _best_of_sample( @@ -117,7 +117,9 @@ function best_of_sample( )::PopMember{T,L,N} end function _best_of_sample( - members::Vector{P}, running_search_statistics::RunningSearchStatistics, options::Options + members::Vector{P}, + running_search_statistics::RunningSearchStatistics, + options::AbstractOptions, ) where {T,L,P<:PopMember{T,L}} p = options.tournament_selection_p n = length(members) # == tournament_selection_n @@ -166,7 +168,7 @@ const CACHED_WEIGHTS = PerThreadCache{Dict{Tuple{Int,Float32},typeof(test_weights)}}() end -@unstable function get_tournament_selection_weights(@nospecialize(options::Options)) +@unstable function get_tournament_selection_weights(@nospecialize(options::AbstractOptions)) n = options.tournament_selection_n p = options.tournament_selection_p # Computing the weights for the tournament becomes quite expensive, @@ -179,7 +181,7 @@ const CACHED_WEIGHTS = end function finalize_scores( - dataset::Dataset{T,L}, pop::P, options::Options + dataset::Dataset{T,L}, pop::P, options::AbstractOptions )::Tuple{P,Float64} where {T,L,P<:Population{T,L}} need_recalculate = options.batching num_evals = 0.0 @@ -200,7 +202,7 @@ function best_sub_pop(pop::P; topn::Int=10)::P where {P<:Population} return Population(pop.members[best_idx[1:topn]]) end -function record_population(pop::Population, options::Options)::RecordType +function record_population(pop::Population, options::AbstractOptions)::RecordType return RecordType( "population" => [ RecordType( diff --git a/src/Recorder.jl b/src/Recorder.jl index a25ac0e78..171a1f46b 100644 --- a/src/Recorder.jl +++ b/src/Recorder.jl @@ -2,7 +2,7 @@ module RecorderModule using ..CoreModule: RecordType -"Assumes that `options` holds the user options::Options" +"Assumes that `options` holds the user options::AbstractOptions" macro recorder(ex) quote if $(esc(:options)).use_recorder diff --git a/src/RegularizedEvolution.jl b/src/RegularizedEvolution.jl index 141e85888..06358a328 100644 --- a/src/RegularizedEvolution.jl +++ b/src/RegularizedEvolution.jl @@ -1,7 +1,7 @@ module RegularizedEvolutionModule using DynamicExpressions: string_tree -using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE +using ..CoreModule: AbstractOptions, Dataset, RecordType, DATA_TYPE, LOSS_TYPE using ..PopulationModule: Population, best_of_sample using ..AdaptiveParsimonyModule: RunningSearchStatistics using ..MutateModule: next_generation, crossover_generation @@ -16,7 +16,7 @@ function reg_evol_cycle( temperature, curmaxsize::Int, running_search_statistics::RunningSearchStatistics, - options::Options, + options::AbstractOptions, record::RecordType, )::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:Population{T,L}} # Batch over each subsample. 
Can give 15% improvement in speed; probably moreso for large pops. diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index a540e35f6..f1ba1cfb3 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -4,13 +4,13 @@ This includes: process management, stdin reading, checking for early stops.""" module SearchUtilsModule using Printf: @printf, @sprintf -using Distributed: Distributed, @spawnat, Future, procs +using Distributed: Distributed, @spawnat, Future, procs, addprocs using StatsBase: mean using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, string_tree using ..UtilsModule: subscriptify -using ..CoreModule: Dataset, Options, MAX_DEGREE, RecordType +using ..CoreModule: Dataset, AbstractOptions, Options, MAX_DEGREE, RecordType using ..ComplexityModule: compute_complexity using ..PopulationModule: Population using ..PopMemberModule: PopMember @@ -19,16 +19,33 @@ using ..ProgressBarsModule: WrappedProgressBar, set_multiline_postfix!, manually using ..AdaptiveParsimonyModule: RunningSearchStatistics """ - RuntimeOptions{N,PARALLELISM,DIM_OUT,RETURN_STATE} + AbstractRuntimeOptions + +An abstract type representing runtime configuration parameters for the symbolic regression algorithm. + +`AbstractRuntimeOptions` is used by `equation_search` to control runtime aspects such +as parallelism and iteration limits. By subtyping `AbstractRuntimeOptions`, advanced users +can customize runtime behaviors by passing it to `equation_search`. + +# See Also + +- [`RuntimeOptions`](@ref): Default implementation used by `equation_search`. +- [`equation_search`](@ref SymbolicRegression.equation_search): Main function to perform symbolic regression. +- [`AbstractOptions`](@ref SymbolicRegression.CoreModule.OptionsStruct.AbstractOptions): See how to extend abstract types for customizing options. + +""" +abstract type AbstractRuntimeOptions end + +""" + RuntimeOptions{N,PARALLELISM,DIM_OUT,RETURN_STATE} <: AbstractRuntimeOptions Parameters for a search that are passed to `equation_search` directly, rather than set within `Options`. This is to differentiate between parameters that relate to processing and the duration of the search, and parameters dealing with the search hyperparameters itself. """ -Base.@kwdef struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE} +struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE} <: AbstractRuntimeOptions niterations::Int64 - total_cycles::Int64 numprocs::Int64 init_procs::Union{Vector{Int},Nothing} addprocs_function::Function @@ -57,6 +74,141 @@ function Base.propertynames(roptions::RuntimeOptions) return (Base.fieldnames(typeof(roptions))..., :parallelism, :dim_out, :return_state) end +@unstable function RuntimeOptions(; + niterations::Int=10, + nout::Int=1, + options::AbstractOptions=Options(), + parallelism=:multithreading, + numprocs::Union{Int,Nothing}=nothing, + procs::Union{Vector{Int},Nothing}=nothing, + addprocs_function::Union{Function,Nothing}=nothing, + heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing, + runtests::Bool=true, + return_state::Union{Bool,Nothing,Val}=nothing, + verbosity::Union{Int,Nothing}=nothing, + progress::Union{Bool,Nothing}=nothing, + v_dim_out::Val{DIM_OUT}=Val(nothing), +) where {DIM_OUT} + concurrency = if parallelism in (:multithreading, "multithreading") + :multithreading + elseif parallelism in (:multiprocessing, "multiprocessing") + :multiprocessing + elseif parallelism in (:serial, "serial") + :serial + else + error( + "Invalid parallelism mode: $parallelism. 
" * + "You must choose one of :multithreading, :multiprocessing, or :serial.", + ) + :serial + end + not_distributed = concurrency in (:multithreading, :serial) + not_distributed && + procs !== nothing && + error( + "`procs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.", + ) + not_distributed && + numprocs !== nothing && + error( + "`numprocs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.", + ) + + _return_state = if return_state isa Val + first(typeof(return_state).parameters) + else + if options.return_state === Val(nothing) + return_state === nothing ? false : return_state + else + @assert( + return_state === nothing, + "You cannot set `return_state` in both the `AbstractOptions` and in the passed arguments." + ) + first(typeof(options.return_state).parameters) + end + end + + dim_out = if DIM_OUT === nothing + nout > 1 ? 2 : 1 + else + DIM_OUT + end + _numprocs::Int = if numprocs === nothing + if procs === nothing + 4 + else + length(procs) + end + else + if procs === nothing + numprocs + else + @assert length(procs) == numprocs + numprocs + end + end + + _verbosity = if verbosity === nothing && options.verbosity === nothing + 1 + elseif verbosity === nothing && options.verbosity !== nothing + options.verbosity + elseif verbosity !== nothing && options.verbosity === nothing + verbosity + else + error( + "You cannot set `verbosity` in both the search parameters `AbstractOptions` and the call to `equation_search`.", + ) + 1 + end + _progress::Bool = if progress === nothing && options.progress === nothing + (_verbosity > 0) && nout == 1 + elseif progress === nothing && options.progress !== nothing + options.progress + elseif progress !== nothing && options.progress === nothing + progress + else + error( + "You cannot set `progress` in both the search parameters `AbstractOptions` and the call to `equation_search`.", + ) + false + end + + _addprocs_function = addprocs_function === nothing ? addprocs : addprocs_function + + exeflags = if concurrency == :multiprocessing + heap_size_hint_in_megabytes = floor( + Int, ( + if heap_size_hint_in_bytes === nothing + (Sys.free_memory() / _numprocs) + else + heap_size_hint_in_bytes + end + ) / 1024^2 + ) + _verbosity > 0 && + heap_size_hint_in_bytes === nothing && + @info "Automatically setting `--heap-size-hint=$(heap_size_hint_in_megabytes)M` on each Julia process. You can configure this with the `heap_size_hint_in_bytes` parameter." + + `--heap-size=$(heap_size_hint_in_megabytes)M` + else + `` + end + + return RuntimeOptions{concurrency,dim_out,_return_state}( + niterations, + _numprocs, + procs, + _addprocs_function, + exeflags, + runtests, + _verbosity, + _progress, + Val(concurrency), + Val(dim_out), + Val(_return_state), + ) +end + """A simple dictionary to track worker allocations.""" const WorkerAssignments = Dict{Tuple{Int,Int},Int} @@ -126,7 +278,7 @@ macro sr_spawner(expr, kws...) 
end function init_dummy_pops( - npops::Int, datasets::Vector{D}, options::Options + npops::Int, datasets::Vector{D}, options::AbstractOptions ) where {T,L,D<:Dataset{T,L}} prototype = Population( first(datasets); @@ -201,14 +353,14 @@ function check_for_user_quit(reader::StdinReader)::Bool return false end -function check_for_loss_threshold(halls_of_fame, options::Options)::Bool +function check_for_loss_threshold(halls_of_fame, options::AbstractOptions)::Bool return _check_for_loss_threshold(halls_of_fame, options.early_stop_condition, options) end -function _check_for_loss_threshold(_, ::Nothing, ::Options) +function _check_for_loss_threshold(_, ::Nothing, ::AbstractOptions) return false end -function _check_for_loss_threshold(halls_of_fame, f::F, options::Options) where {F} +function _check_for_loss_threshold(halls_of_fame, f::F, options::AbstractOptions) where {F} return all(halls_of_fame) do hof any(hof.members[hof.exists]) do member f(member.loss, compute_complexity(member, options))::Bool @@ -216,12 +368,12 @@ function _check_for_loss_threshold(halls_of_fame, f::F, options::Options) where end end -function check_for_timeout(start_time::Float64, options::Options)::Bool +function check_for_timeout(start_time::Float64, options::AbstractOptions)::Bool return options.timeout_in_seconds !== nothing && time() - start_time > options.timeout_in_seconds::Float64 end -function check_max_evals(num_evals, options::Options)::Bool +function check_max_evals(num_evals, options::AbstractOptions)::Bool return options.max_evals !== nothing && options.max_evals::Int <= sum(sum, num_evals) end @@ -277,7 +429,7 @@ function update_progress_bar!( progress_bar::WrappedProgressBar, hall_of_fame::HallOfFame{T,L}, dataset::Dataset{T,L}, - options::Options, + options::AbstractOptions, equation_speed::Vector{Float32}, head_node_occupation::Float64, parallelism=:serial, @@ -306,7 +458,7 @@ end function print_search_state( hall_of_fames, datasets; - options::Options, + options::AbstractOptions, equation_speed::Vector{Float32}, total_cycles::Int, cycles_remaining::Vector{Int}, @@ -370,13 +522,34 @@ end load_saved_population(::Nothing; kws...) = nothing """ - SearchState{PopType,HallOfFameType,WorkerOutputType,ChannelType} + AbstractSearchState{T,L,N} + +An abstract type encapsulating the internal state of the search process during symbolic regression. + +`AbstractSearchState` instances hold information like populations and progress metrics, +used internally by `equation_search`. Subtyping `AbstractSearchState` allows +customization of search state management. + +Look through the source of `equation_search` to see how this is used. + +# See Also + +- [`SearchState`](@ref): Default implementation of `AbstractSearchState`. +- [`equation_search`](@ref SymbolicRegression.equation_search): Function where `AbstractSearchState` is utilized. +- [`AbstractOptions`](@ref SymbolicRegression.CoreModule.OptionsStruct.AbstractOptions): See how to extend abstract types for customizing options. + +""" +abstract type AbstractSearchState{T,L,N<:AbstractExpression{T}} end -The state of a search, including the populations, worker outputs, tasks, and +""" + SearchState{T,L,N,WorkerOutputType,ChannelType} <: AbstractSearchState{T,L,N} + +The state of the search, including the populations, worker outputs, tasks, and channels. This is used to manage the search and keep track of runtime variables in a single struct. 
""" -Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,ChannelType} +Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,ChannelType} <: + AbstractSearchState{T,L,N} procs::Vector{Int} we_created_procs::Bool worker_output::Vector{Vector{WorkerOutputType}} @@ -396,7 +569,7 @@ Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,Cha end function save_to_file( - dominating, nout::Integer, j::Integer, dataset::Dataset{T,L}, options::Options + dominating, nout::Integer, j::Integer, dataset::Dataset{T,L}, options::AbstractOptions ) where {T,L} output_file = options.output_file if nout > 1 @@ -443,7 +616,9 @@ end For searches where the maxsize gradually increases, this function returns the current maxsize. """ -function get_cur_maxsize(; options::Options, total_cycles::Int, cycles_remaining::Int) +function get_cur_maxsize(; + options::AbstractOptions, total_cycles::Int, cycles_remaining::Int +) cycles_elapsed = total_cycles - cycles_remaining fraction_elapsed = 1.0f0 * cycles_elapsed / total_cycles in_warmup_period = fraction_elapsed <= options.warmup_maxsize_by @@ -502,7 +677,7 @@ function construct_datasets( end function update_hall_of_fame!( - hall_of_fame::HallOfFame, members::Vector{PM}, options::Options + hall_of_fame::HallOfFame, members::Vector{PM}, options::AbstractOptions ) where {PM<:PopMember} for member in members size = compute_complexity(member, options) diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl index 582f25b15..d15e7914c 100644 --- a/src/SingleIteration.jl +++ b/src/SingleIteration.jl @@ -3,7 +3,7 @@ module SingleIterationModule using ADTypes: AutoEnzyme using DynamicExpressions: AbstractExpression, string_tree, simplify_tree!, combine_operators using ..UtilsModule: @threads_if -using ..CoreModule: Options, Dataset, RecordType, create_expression +using ..CoreModule: AbstractOptions, Dataset, RecordType, create_expression using ..ComplexityModule: compute_complexity using ..PopMemberModule: generate_reference using ..PopulationModule: Population, finalize_scores @@ -23,7 +23,7 @@ function s_r_cycle( curmaxsize::Int, running_search_statistics::RunningSearchStatistics; verbosity::Int=0, - options::Options, + options::AbstractOptions, record::RecordType, )::Tuple{ P,HallOfFame{T,L,N},Float64 @@ -98,7 +98,7 @@ function s_r_cycle( end function optimize_and_simplify_population( - dataset::D, pop::P, options::Options, curmaxsize::Int, record::RecordType + dataset::D, pop::P, options::AbstractOptions, curmaxsize::Int, record::RecordType )::Tuple{P,Float64} where {T,L,D<:Dataset{T,L},P<:Population{T,L}} array_num_evals = zeros(Float64, pop.n) do_optimization = rand(pop.n) .< options.optimizer_probability diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 8dfb53837..02070f324 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -30,6 +30,7 @@ export Population, compute_complexity, @parse_expression, parse_expression, + @declare_expression_operator, print_tree, string_tree, eval_tree_array, @@ -80,7 +81,6 @@ export Population, using Distributed using Printf: @printf, @sprintf -using PackageExtensionCompat: @require_extensions using Pkg: Pkg using TOML: parsefile using Random: seed!, shuffle! 
@@ -95,8 +95,10 @@ using DynamicExpressions: NodeSampler, AbstractExpression, AbstractExpressionNode, + ExpressionInterface, @parse_expression, parse_expression, + @declare_expression_operator, copy_node, set_node!, string_tree, @@ -151,6 +153,21 @@ using DynamicExpressions: with_type_parameters LogitDistLoss, QuantileLoss, LogCoshLoss +using Compat: @compat + +@compat public AbstractOptions, +AbstractRuntimeOptions, +RuntimeOptions, +AbstractMutationWeights, +mutate!, +condition_mutation_weights!, +sample_mutation, +MutationResult, +AbstractSearchState, +SearchState +# ^ We can add new functions here based on requests from users. +# However, I don't want to add many functions without knowing what +# users will actually want to overload. # https://discourse.julialang.org/t/how-to-find-out-the-version-of-a-package-from-its-module/37755/15 const PACKAGE_VERSION = try @@ -210,8 +227,11 @@ using .CoreModule: LOSS_TYPE, RecordType, Dataset, + AbstractOptions, Options, + AbstractMutationWeights, MutationWeights, + sample_mutation, plus, sub, mult, @@ -253,12 +273,15 @@ using .PopMemberModule: PopMember, reset_birth! using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample using .HallOfFameModule: HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve +using .MutateModule: mutate!, condition_mutation_weights!, MutationResult using .SingleIterationModule: s_r_cycle, optimize_and_simplify_population using .ProgressBarsModule: WrappedProgressBar using .RecorderModule: @recorder, find_iteration_from_record using .MigrationModule: migrate! using .SearchUtilsModule: + AbstractSearchState, SearchState, + AbstractRuntimeOptions, RuntimeOptions, WorkerAssignments, DefaultWorkerOutputType, @@ -312,7 +335,7 @@ which is useful for debugging and profiling. More iterations will improve the results. - `weights::Union{AbstractMatrix{T}, AbstractVector{T}, Nothing}=nothing`: Optionally weight the loss for each `y` by this value (same shape as `y`). -- `options::Options=Options()`: The options for the search, such as +- `options::AbstractOptions=Options()`: The options for the search, such as which operators to use, evolution hyperparameters, etc. - `variable_names::Union{Vector{String}, Nothing}=nothing`: The names of each feature in `X`, which will be used during printing of equations. 
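With the new `RuntimeOptions` keyword constructor above and the simplified `equation_search` entry point in the hunks that follow, runtime configuration can now be built up front and passed in as one object. A minimal usage sketch with invented toy data (`RuntimeOptions` is marked `public` rather than exported, hence the explicit import):

```julia
using SymbolicRegression
using SymbolicRegression: Dataset, RuntimeOptions

X = randn(Float64, 2, 100)
y = 2 .* X[1, :] .+ cos.(X[2, :])

options = Options(; binary_operators=[+, *, -], unary_operators=[cos])
dataset = Dataset(X, y)

# Build the runtime configuration once, then pass it through directly;
# any keywords omitted here fall back to the constructor defaults shown above.
ropt = RuntimeOptions(; niterations=5, nout=1, options, parallelism=:serial)
hall_of_fame = equation_search([dataset]; options, runtime_options=ropt)
```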
@@ -390,7 +413,7 @@ function equation_search( y::AbstractMatrix{T}; niterations::Int=10, weights::Union{AbstractMatrix{T},AbstractVector{T},Nothing}=nothing, - options::Options=Options(), + options::AbstractOptions=Options(), variable_names::Union{AbstractVector{String},Nothing}=nothing, display_variable_names::Union{AbstractVector{String},Nothing}=variable_names, y_variable_names::Union{String,AbstractVector{String},Nothing}=nothing, @@ -478,149 +501,23 @@ end function equation_search( datasets::Vector{D}; - niterations::Int=10, - options::Options=Options(), - parallelism=:multithreading, - numprocs::Union{Int,Nothing}=nothing, - procs::Union{Vector{Int},Nothing}=nothing, - addprocs_function::Union{Function,Nothing}=nothing, - heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing, - runtests::Bool=true, + options::AbstractOptions=Options(), saved_state=nothing, - return_state::Union{Bool,Nothing,Val}=nothing, - verbosity::Union{Int,Nothing}=nothing, - progress::Union{Bool,Nothing}=nothing, - v_dim_out::Val{DIM_OUT}=Val(nothing), -) where {DIM_OUT,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}} - concurrency = if parallelism in (:multithreading, "multithreading") - :multithreading - elseif parallelism in (:multiprocessing, "multiprocessing") - :multiprocessing - elseif parallelism in (:serial, "serial") - :serial + runtime_options::Union{AbstractRuntimeOptions,Nothing}=nothing, + runtime_options_kws..., +) where {T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}} + runtime_options = if runtime_options === nothing + RuntimeOptions(; options, nout=length(datasets), runtime_options_kws...) else - error( - "Invalid parallelism mode: $parallelism. " * - "You must choose one of :multithreading, :multiprocessing, or :serial.", - ) - :serial - end - not_distributed = concurrency in (:multithreading, :serial) - not_distributed && - procs !== nothing && - error( - "`procs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.", - ) - not_distributed && - numprocs !== nothing && - error( - "`numprocs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.", - ) - - _return_state = if return_state isa Val - first(typeof(return_state).parameters) - else - if options.return_state === Val(nothing) - return_state === nothing ? false : return_state - else - @assert( - return_state === nothing, - "You cannot set `return_state` in both the `Options` and in the passed arguments." - ) - first(typeof(options.return_state).parameters) - end - end - - dim_out = if DIM_OUT === nothing - length(datasets) > 1 ? 
2 : 1 - else - DIM_OUT - end - _numprocs::Int = if numprocs === nothing - if procs === nothing - 4 - else - length(procs) - end - else - if procs === nothing - numprocs - else - @assert length(procs) == numprocs - numprocs - end - end - - _verbosity = if verbosity === nothing && options.verbosity === nothing - 1 - elseif verbosity === nothing && options.verbosity !== nothing - options.verbosity - elseif verbosity !== nothing && options.verbosity === nothing - verbosity - else - error( - "You cannot set `verbosity` in both the search parameters `Options` and the call to `equation_search`.", - ) - 1 - end - _progress::Bool = if progress === nothing && options.progress === nothing - (_verbosity > 0) && length(datasets) == 1 - elseif progress === nothing && options.progress !== nothing - options.progress - elseif progress !== nothing && options.progress === nothing - progress - else - error( - "You cannot set `progress` in both the search parameters `Options` and the call to `equation_search`.", - ) - false - end - - _addprocs_function = addprocs_function === nothing ? addprocs : addprocs_function - - exeflags = if VERSION >= v"1.9" && concurrency == :multiprocessing - heap_size_hint_in_megabytes = floor( - Int, ( - if heap_size_hint_in_bytes === nothing - (Sys.free_memory() / _numprocs) - else - heap_size_hint_in_bytes - end - ) / 1024^2 - ) - _verbosity > 0 && - heap_size_hint_in_bytes === nothing && - @info "Automatically setting `--heap-size-hint=$(heap_size_hint_in_megabytes)M` on each Julia process. You can configure this with the `heap_size_hint_in_bytes` parameter." - - `--heap-size=$(heap_size_hint_in_megabytes)M` - else - `` + runtime_options end # Underscores here mean that we have mutated the variable - return _equation_search( - datasets, - RuntimeOptions(; - niterations=niterations, - total_cycles=options.populations * niterations, - numprocs=_numprocs, - init_procs=procs, - addprocs_function=_addprocs_function, - exeflags=exeflags, - runtests=runtests, - verbosity=_verbosity, - progress=_progress, - parallelism=Val(concurrency), - dim_out=Val(dim_out), - return_state=Val(_return_state), - ), - options, - saved_state, - ) + return _equation_search(datasets, runtime_options, options, saved_state) end @noinline function _equation_search( - datasets::Vector{D}, ropt::RuntimeOptions, options::Options, saved_state + datasets::Vector{D}, ropt::AbstractRuntimeOptions, options::AbstractOptions, saved_state ) where {D<:Dataset} _validate_options(datasets, ropt, options) state = _create_workers(datasets, ropt, options) @@ -632,7 +529,7 @@ end end function _validate_options( - datasets::Vector{D}, ropt::RuntimeOptions, options::Options + datasets::Vector{D}, ropt::AbstractRuntimeOptions, options::AbstractOptions ) where {T,L,D<:Dataset{T,L}} example_dataset = first(datasets) nout = length(datasets) @@ -662,7 +559,7 @@ function _validate_options( return nothing end @stable default_mode = "disable" function _create_workers( - datasets::Vector{D}, ropt::RuntimeOptions, options::Options + datasets::Vector{D}, ropt::AbstractRuntimeOptions, options::AbstractOptions ) where {T,L,D<:Dataset{T,L}} stdin_reader = watch_stream(stdin) @@ -723,10 +620,11 @@ end halls_of_fame = Vector{HallOfFameType}(undef, nout) - cycles_remaining = [ropt.total_cycles for j in 1:nout] + total_cycles = ropt.niterations * options.populations + cycles_remaining = [total_cycles for j in 1:nout] cur_maxsizes = [ - get_cur_maxsize(; options, ropt.total_cycles, cycles_remaining=cycles_remaining[j]) - for j in 1:nout + 
get_cur_maxsize(; options, total_cycles, cycles_remaining=cycles_remaining[j]) for + j in 1:nout ] return SearchState{T,L,typeof(example_ex),WorkerOutputType,ChannelType}(; @@ -749,7 +647,11 @@ end ) end function _initialize_search!( - state::SearchState{T,L,N}, datasets, ropt::RuntimeOptions, options::Options, saved_state + state::AbstractSearchState{T,L,N}, + datasets, + ropt::AbstractRuntimeOptions, + options::AbstractOptions, + saved_state, ) where {T,L,N} nout = length(datasets) @@ -823,7 +725,10 @@ function _initialize_search!( return nothing end function _warmup_search!( - state::SearchState{T,L,N}, datasets, ropt::RuntimeOptions, options::Options + state::AbstractSearchState{T,L,N}, + datasets, + ropt::AbstractRuntimeOptions, + options::AbstractOptions, ) where {T,L,N} nout = length(datasets) for j in 1:nout, i in 1:(options.populations) @@ -864,7 +769,10 @@ function _warmup_search!( return nothing end function _main_search_loop!( - state::SearchState{T,L,N}, datasets, ropt::RuntimeOptions, options::Options + state::AbstractSearchState{T,L,N}, + datasets, + ropt::AbstractRuntimeOptions, + options::AbstractOptions, ) where {T,L,N} ropt.verbosity > 0 && @info "Started!" nout = length(datasets) @@ -1017,8 +925,9 @@ function _main_search_loop!( ) end + total_cycles = ropt.niterations * options.populations state.cur_maxsizes[j] = get_cur_maxsize(; - options, ropt.total_cycles, cycles_remaining=state.cycles_remaining[j] + options, total_cycles, cycles_remaining=state.cycles_remaining[j] ) move_window!(state.all_running_search_statistics[j]) if ropt.progress @@ -1062,12 +971,13 @@ function _main_search_loop!( # Dominating pareto curve - must be better than all simpler equations head_node_occupation = estimate_work_fraction(resource_monitor) + total_cycles = ropt.niterations * options.populations print_search_state( state.halls_of_fame, datasets; options, equation_speed, - ropt.total_cycles, + total_cycles, state.cycles_remaining, head_node_occupation, parallelism=ropt.parallelism, @@ -1092,7 +1002,9 @@ function _main_search_loop!( end return nothing end -function _tear_down!(state::SearchState, ropt::RuntimeOptions, options::Options) +function _tear_down!( + state::AbstractSearchState, ropt::AbstractRuntimeOptions, options::AbstractOptions +) close_reader!(state.stdin_reader) # Safely close all processes or threads if ropt.parallelism == :multiprocessing @@ -1107,7 +1019,10 @@ function _tear_down!(state::SearchState, ropt::RuntimeOptions, options::Options) return nothing end function _format_output( - state::SearchState, datasets, ropt::RuntimeOptions, options::Options + state::AbstractSearchState, + datasets, + ropt::AbstractRuntimeOptions, + options::AbstractOptions, ) nout = length(datasets) out_hof = if ropt.dim_out == 1 @@ -1128,7 +1043,7 @@ end @stable default_mode = "disable" function _dispatch_s_r_cycle( in_pop::Population{T,L,N}, dataset::Dataset, - options::Options; + options::AbstractOptions; pop::Int, out::Int, iteration::Int, @@ -1171,10 +1086,6 @@ end include("MLJInterface.jl") using .MLJInterfaceModule: SRRegressor, MultitargetSRRegressor -function __init__() - @require_extensions -end - # Hack to get static analysis to work from within tests: @ignore include("../test/runtests.jl") diff --git a/src/Utils.jl b/src/Utils.jl index a667b6987..473845651 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -187,11 +187,6 @@ function _save_kwargs(log_variable::Symbol, fdef::Expr) end end -# Allows using `const` fields in older versions of Julia. 
-macro constfield(ex) - return esc(VERSION < v"1.8.0" ? ex : Expr(:const, ex)) -end - json3_write(args...) = error("Please load the JSON3.jl package.") """ diff --git a/src/deprecates.jl b/src/deprecates.jl index 54816a408..c8e0b4d57 100644 --- a/src/deprecates.jl +++ b/src/deprecates.jl @@ -3,15 +3,6 @@ using Base: @deprecate import .HallOfFameModule: calculate_pareto_frontier import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size -@deprecate( - gen_random_tree(length::Int, options::Options, nfeatures::Int, t::Type), - gen_random_tree(length, options, nfeatures, t) -) -@deprecate( - gen_random_tree_fixed_size(node_count::Int, options::Options, nfeatures::Int, t::Type), - gen_random_tree_fixed_size(node_count, options, nfeatures, t) -) - @deprecate( calculate_pareto_frontier(X, y, hallOfFame, options; weights=nothing, varMap=nothing), calculate_pareto_frontier(hallOfFame) @@ -40,7 +31,7 @@ import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size niterations::Int=10, weights::Union{AbstractMatrix{T},AbstractVector{T},Nothing}=nothing, variable_names::Union{Vector{String},Nothing}=nothing, - options::Options=Options(), + options::AbstractOptions=Options(), parallelism=:multithreading, numprocs::Union{Int,Nothing}=nothing, procs::Union{Vector{Int},Nothing}=nothing, @@ -75,7 +66,7 @@ import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size EquationSearch( datasets::Vector{D}; niterations::Int=10, - options::Options=Options(), + options::AbstractOptions=Options(), parallelism=:multithreading, numprocs::Union{Int,Nothing}=nothing, procs::Union{Vector{Int},Nothing}=nothing, diff --git a/src/precompile.jl b/src/precompile.jl index 87442695d..13aaac06f 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -30,8 +30,6 @@ macro maybe_compile_workload(mode, ex) end end -const PRECOMPILE_OPTIMIZATION = VERSION >= v"1.9.0-DEV.0" - """`mode=:precompile` will use `@precompile_*` directives; `mode=:compile` runs.""" function do_precompilation(::Val{mode}) where {mode} @maybe_setup_workload mode begin @@ -57,13 +55,12 @@ function do_precompilation(::Val{mode}) where {mode} simplify=1.0, randomize=1.0, do_nothing=1.0, - optimize=PRECOMPILE_OPTIMIZATION ? 1.0 : 0.0, + optimize=1.0, ), fraction_replaced=0.2, fraction_replaced_hof=0.2, define_helper_functions=false, - optimizer_probability=PRECOMPILE_OPTIMIZATION ? 
0.05 : 0.0, - should_optimize_constants=PRECOMPILE_OPTIMIZATION, + optimizer_probability=0.05, save_to_file=false, ) state = equation_search( diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 6a4153f66..ef84cc03c 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -3,4 +3,4 @@ using Aqua Aqua.test_all(SymbolicRegression; ambiguities=false) -VERSION >= v"1.9" && Aqua.test_ambiguities(SymbolicRegression) +Aqua.test_ambiguities(SymbolicRegression) diff --git a/test/test_mlj.jl b/test/test_mlj.jl index 3ef8d7b93..d26773485 100644 --- a/test/test_mlj.jl +++ b/test/test_mlj.jl @@ -87,10 +87,8 @@ end @test ypred_mixed == hcat(ypred_good[:, 1], ypred_bad[:, 2], ypred_good[:, 3]) @test_throws AssertionError predict(mach, (data=X,)) - VERSION >= v"1.8" && - @test_throws "If specifying an equation index during" predict(mach, (data=X,)) - VERSION >= v"1.8" && - @test_throws "If specifying an equation index during" predict(mach, (X=X, idx=1)) + @test_throws "If specifying an equation index during" predict(mach, (data=X,)) + @test_throws "If specifying an equation index during" predict(mach, (X=X, idx=1)) end @testitem "Variable names - named outputs" tags = [:part1] begin @@ -112,7 +110,7 @@ end test_outs = predict(mach, X) @test isempty(setdiff((:c1, :c2), keys(test_outs))) @test_throws AssertionError predict(mach, (a1=randn(32), b2=randn(32))) - VERSION >= v"1.8" && @test_throws "Variable names do not match fitted" predict( + @test_throws "Variable names do not match fitted" predict( mach, (b1=randn(32), a2=randn(32)) ) end @@ -146,15 +144,13 @@ end rng = MersenneTwister(0) mach = machine(model, randn(rng, 32, 3), randn(rng, 32); scitype_check_level=0) @test_throws AssertionError @quiet(fit!(mach)) - VERSION >= v"1.8" && - @test_throws "For single-output regression, please" @quiet(fit!(mach)) + @test_throws "For single-output regression, please" @quiet(fit!(mach)) model = SRRegressor() rng = MersenneTwister(0) mach = machine(model, randn(rng, 32, 3), randn(rng, 32, 2); scitype_check_level=0) @test_throws AssertionError @quiet(fit!(mach)) - VERSION >= v"1.8" && - @test_throws "For multi-output regression, please" @quiet(fit!(mach)) + @test_throws "For multi-output regression, please" @quiet(fit!(mach)) model = SRRegressor(; verbosity=0) rng = MersenneTwister(0) diff --git a/test/test_operators.jl b/test/test_operators.jl index 47b83418a..1221ba6c7 100644 --- a/test/test_operators.jl +++ b/test/test_operators.jl @@ -99,10 +99,9 @@ end @test_throws ErrorException SymbolicRegression.assert_operators_well_defined( ComplexF64, options ) - VERSION >= v"1.8" && - @test_throws "complex plane" SymbolicRegression.assert_operators_well_defined( - ComplexF64, options - ) + @test_throws "complex plane" SymbolicRegression.assert_operators_well_defined( + ComplexF64, options + ) end @testset "Operators which return the wrong type should fail" begin @@ -111,10 +110,9 @@ end @test_throws ErrorException SymbolicRegression.assert_operators_well_defined( Float64, options ) - VERSION >= v"1.8" && - @test_throws "returned an output of type" SymbolicRegression.assert_operators_well_defined( - Float64, options - ) + @test_throws "returned an output of type" SymbolicRegression.assert_operators_well_defined( + Float64, options + ) @test_nowarn SymbolicRegression.assert_operators_well_defined(Float32, options) end diff --git a/test/test_print.jl b/test/test_print.jl index fff027504..96500fa0c 100644 --- a/test/test_print.jl +++ b/test/test_print.jl @@ -1,13 +1,19 @@ using SymbolicRegression using 
SymbolicRegression.UtilsModule: split_string +using DynamicExpressions: DynamicExpressions as DE include("test_params.jl") ## Test Base.print options = Options(; - default_params..., binary_operators=(+, *, /, -), unary_operators=(cos, sin) + default_params..., + binary_operators=(+, *, /, -), + unary_operators=(cos, sin), + populations=1, ) +@test DE.OperatorEnumConstructionModule.LATEST_OPERATORS[].unaops == (cos, sin) + f = (x1, x2, x3) -> (sin(cos(sin(cos(x1) * x3) * 3.0) * -0.5) + 2.0) * 5.0 tree = f(Node("x1"), Node("x2"), Node("x3")) @@ -26,6 +32,8 @@ equation_search( parallelism=:multithreading, ) +@test DE.OperatorEnumConstructionModule.LATEST_OPERATORS[].unaops == (cos, sin) + s = repr(tree) true_s = "(sin(cos(sin(cos(v1) * v3) * 3.0) * -0.5) + 2.0) * 5.0" @test s == true_s diff --git a/test/test_units.jl b/test/test_units.jl index e35e0a185..a586f5e3c 100644 --- a/test/test_units.jl +++ b/test/test_units.jl @@ -311,8 +311,7 @@ end # TODO: Should return same quantity as input @test typeof(ypred.a[begin]) <: Quantity @test typeof(y.a[begin]) <: RealQuantity - VERSION >= v"1.8" && - @eval @test(typeof(ypred.b[begin]) == typeof(y.b[begin]), broken = true) + @eval @test(typeof(ypred.b[begin]) == typeof(y.b[begin]), broken = true) end end @@ -322,8 +321,7 @@ end X = randn(11, 50) y = randn(50) - VERSION >= v"1.8.0" && - @test_throws("Number of features", Dataset(X, y; X_units=["m", "1"], y_units="kg")) + @test_throws("Number of features", Dataset(X, y; X_units=["m", "1"], y_units="kg")) end @testitem "Should print units" tags = [:part3] begin