Skip to content

Commit

Permalink
test: of loss function expression on workers
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Feb 9, 2025
1 parent c7877f3 commit 4eedbfe
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 3 deletions.
5 changes: 2 additions & 3 deletions src/Configure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,8 @@ function move_functions_to_workers(
continue
end
ops = (options.loss_function_expression,)
example_inputs = (
create_expression(zero(T), options, dataset), dataset, options
)
ex = create_expression(zero(T), options, dataset)
example_inputs = (ex, dataset, options)
elseif function_set == :complexity_mapping
if !(options.complexity_mapping isa Function)
continue
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ include("test_mlj.jl")
include("test_custom_operators_multiprocessing.jl")
end

@testitem "Testing whether we can move loss function expression to workers." tags = [:part2] begin
include("test_loss_function_expression_multiprocessing.jl")
end

@testitem "Test whether the precompilation script works." tags = [:part2] begin
include("test_precompilation.jl")
end
Expand Down
48 changes: 48 additions & 0 deletions test/test_loss_function_expression_multiprocessing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using SymbolicRegression
using Test

defs = quote
using SymbolicRegression

early_stop(loss, c) = ((loss <= 1e-10) && (c <= 4))
function my_loss_expression(ex::Expression, dataset::Dataset, options::Options)
prediction, complete = eval_tree_array(ex, dataset.X, options)
if !complete
return Inf
end
return sum((prediction .- dataset.y) .^ 2) / dataset.n
end
end

# This is needed as workers are initialized in `Core.Main`!
if (@__MODULE__) != Core.Main
Core.eval(Core.Main, defs)
eval(:(using Main: early_stop, my_loss_expression))
else
eval(defs)
end

X = randn(Float32, 5, 100)
y = @. 2 * cos(X[4, :])

options = SymbolicRegression.Options(;
binary_operators=[*, +],
unary_operators=[cos],
early_stop_condition=early_stop,
loss_function_expression=my_loss_expression,
)

hof = equation_search(
X,
y;
weights=ones(Float32, 100),
options=options,
niterations=1_000_000_000,
numprocs=2,
parallelism=:multiprocessing,
)

@test any(
early_stop(member.loss, length(get_tree(member.tree))) for
member in hof.members[hof.exists]
)

0 comments on commit 4eedbfe

Please sign in to comment.