From c7877f3aeae2719926cbf24d16382e65daf3756c Mon Sep 17 00:00:00 2001
From: MilesCranmer
Date: Sun, 9 Feb 2025 17:00:10 +0000
Subject: [PATCH 1/2] fix: loss_function_expression in distributed mode

---
 Project.toml     | 2 +-
 src/Configure.jl | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 47b4d0509..5e31c9d4b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "SymbolicRegression"
 uuid = "8254be44-1295-4e6a-a16d-46603ac705cb"
 authors = ["MilesCranmer "]
-version = "1.7.0"
+version = "1.7.1"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/src/Configure.jl b/src/Configure.jl
index a16280f48..bae991075 100644
--- a/src/Configure.jl
+++ b/src/Configure.jl
@@ -125,6 +125,7 @@ function move_functions_to_workers(
         :elementwise_loss,
         :early_stop_condition,
         :loss_function,
+        :loss_function_expression,
         :complexity_mapping,
     )
 
@@ -157,6 +158,14 @@ function move_functions_to_workers(
             end
             ops = (options.loss_function,)
             example_inputs = (Node(T; val=zero(T)), dataset, options)
+        elseif function_set == :loss_function_expression
+            if options.loss_function_expression === nothing
+                continue
+            end
+            ops = (options.loss_function_expression,)
+            example_inputs = (
+                create_expression(zero(T), options, dataset), dataset, options
+            )
         elseif function_set == :complexity_mapping
             if !(options.complexity_mapping isa Function)
                 continue

From 4eedbfea9429ecf739bbc920cca107fc9e22beb8 Mon Sep 17 00:00:00 2001
From: MilesCranmer
Date: Sun, 9 Feb 2025 21:31:52 +0000
Subject: [PATCH 2/2] test: of loss function expression on workers

---
 src/Configure.jl                              |  5 +-
 test/runtests.jl                              |  4 ++
 ...oss_function_expression_multiprocessing.jl | 48 +++++++++++++++++++
 3 files changed, 54 insertions(+), 3 deletions(-)
 create mode 100644 test/test_loss_function_expression_multiprocessing.jl

diff --git a/src/Configure.jl b/src/Configure.jl
index bae991075..d0a31138c 100644
--- a/src/Configure.jl
+++ b/src/Configure.jl
@@ -163,9 +163,8 @@ function move_functions_to_workers(
                 continue
             end
             ops = (options.loss_function_expression,)
-            example_inputs = (
-                create_expression(zero(T), options, dataset), dataset, options
-            )
+            ex = create_expression(zero(T), options, dataset)
+            example_inputs = (ex, dataset, options)
         elseif function_set == :complexity_mapping
             if !(options.complexity_mapping isa Function)
                 continue
diff --git a/test/runtests.jl b/test/runtests.jl
index 71ebe7f25..c67187a6d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -146,6 +146,10 @@ include("test_mlj.jl")
     include("test_custom_operators_multiprocessing.jl")
 end
 
+@testitem "Testing whether we can move loss function expression to workers." tags = [:part2] begin
+    include("test_loss_function_expression_multiprocessing.jl")
+end
+
 @testitem "Test whether the precompilation script works." tags = [:part2] begin
     include("test_precompilation.jl")
 end
diff --git a/test/test_loss_function_expression_multiprocessing.jl b/test/test_loss_function_expression_multiprocessing.jl
new file mode 100644
index 000000000..e44d78c1a
--- /dev/null
+++ b/test/test_loss_function_expression_multiprocessing.jl
@@ -0,0 +1,48 @@
+using SymbolicRegression
+using Test
+
+defs = quote
+    using SymbolicRegression
+
+    early_stop(loss, c) = ((loss <= 1e-10) && (c <= 4))
+    function my_loss_expression(ex::Expression, dataset::Dataset, options::Options)
+        prediction, complete = eval_tree_array(ex, dataset.X, options)
+        if !complete
+            return Inf
+        end
+        return sum((prediction .- dataset.y) .^ 2) / dataset.n
+    end
+end
+
+# This is needed as workers are initialized in `Core.Main`!
+if (@__MODULE__) != Core.Main
+    Core.eval(Core.Main, defs)
+    eval(:(using Main: early_stop, my_loss_expression))
+else
+    eval(defs)
+end
+
+X = randn(Float32, 5, 100)
+y = @. 2 * cos(X[4, :])
+
+options = SymbolicRegression.Options(;
+    binary_operators=[*, +],
+    unary_operators=[cos],
+    early_stop_condition=early_stop,
+    loss_function_expression=my_loss_expression,
+)
+
+hof = equation_search(
+    X,
+    y;
+    weights=ones(Float32, 100),
+    options=options,
+    niterations=1_000_000_000,
+    numprocs=2,
+    parallelism=:multiprocessing,
+)
+
+@test any(
+    early_stop(member.loss, length(get_tree(member.tree))) for
+    member in hof.members[hof.exists]
+)