Skip to content

Commit

Permalink
Merge pull request #411 from MilesCranmer/reduce-test-time
Browse files Browse the repository at this point in the history
Reduce test time
  • Loading branch information
MilesCranmer authored Feb 8, 2025
2 parents 859de9f + 630530a commit 6890e1d
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 35 deletions.
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand Down
20 changes: 10 additions & 10 deletions test/test_abstract_numbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@ using Random
include("test_params.jl")

# Return the underlying real type `BT` of a complex number type,
# e.g. `get_base_type(ComplexF32) == Float32`.
function get_base_type(::Type{<:Complex{BT}}) where {BT}
    return BT
end
# Early-stopping criterion: stop once the loss is at most 1e-2 (converted to
# the loss's own numeric type `L`) AND the complexity `c` is at most 15.
function early_stop(loss::L, c) where {L}
    loss_is_small = loss <= L(1e-2)
    is_simple_enough = c <= 15
    return loss_is_small && is_simple_enough
end
# Elementwise loss: squared absolute value of the residual.
# `abs2` makes this valid for complex-valued predictions as well.
function example_loss(prediction, target)
    residual = prediction - target
    return abs2(residual)
end

# Shared search configuration for all complex-number test types below
# (hoisted to top level so it is built only once).
options = SymbolicRegression.Options(;
    unary_operators=[cos],
    binary_operators=[+, *, -, /],
    elementwise_loss=example_loss,
    early_stop_condition=early_stop,
    populations=20,
)

for T in (ComplexF16, ComplexF32, ComplexF64)
L = get_base_type(T)
@testset "Test search with $T type" begin
X = randn(MersenneTwister(0), T, 1, 100)
y = @. (2 - 0.5im) * cos((1 + 1im) * X[1, :]) |> T

early_stop(loss::L, c) where {L} = ((loss <= L(1e-2)) && (c <= 15))

options = SymbolicRegression.Options(;
binary_operators=[+, *, -, /],
unary_operators=[cos],
populations=20,
early_stop_condition=early_stop,
elementwise_loss=(prediction, target) -> abs2(prediction - target),
)

dataset = Dataset(X, y, L)
hof = if T == ComplexF16
equation_search([dataset]; options=options, niterations=1_000_000_000)
Expand Down
28 changes: 16 additions & 12 deletions test/test_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,10 @@ end
using LoopVectorization: LoopVectorization as _
include("test_params.jl")

binary_operators = [plus, sub, mult, /, ^, greater, logical_or, logical_and, cond]
unary_operators = [square, cube, log, log2, log10, log1p, sqrt, atanh, acosh, neg, relu]
options = Options(; binary_operators, unary_operators)
all_binary_operators = [plus, sub, mult, /, ^, greater, logical_or, logical_and, cond]
all_unary_operators = [
square, cube, log, log2, log10, log1p, sqrt, atanh, acosh, neg, relu
]

function test_part(tree, Xpart, options)
y, completed = eval_tree_array(tree, Xpart, options)
Expand All @@ -216,21 +217,24 @@ end
eval_warnings = @capture_err begin
y_turbo, _ = eval_tree_array(tree, Xpart, options; turbo=true)
end
test_info(@test(y[1] ≈ y_turbo[1] && eval_warnings == "")) do
test_info(@test(y ≈ y_turbo && eval_warnings == "")) do
@info T tree X[:, seed] y y_turbo eval_warnings
end
end

for T in (Float32, Float64),
index_bin in 1:length(binary_operators),
index_una in 1:length(unary_operators)
index_bin in 1:length(all_binary_operators),
index_una in 1:length(all_unary_operators)

x1, x2 = Node(T; feature=1), Node(T; feature=2)
tree = Node(index_bin, x1, Node(index_una, x2))
X = rand(MersenneTwister(0), T, 2, 20)
for seed in 1:20
Xpart = X[:, [seed]]
test_part(tree, Xpart, options)
let
x1, x2 = Node(T; feature=1), Node(T; feature=2)
tree = Node(index_bin, x1, Node(index_una, x2))
options = Options(;
binary_operators=all_binary_operators[[index_bin]],
unary_operators=all_unary_operators[[index_una]],
)
X = rand(MersenneTwister(0), T, 2, 20)
test_part(tree, X, options)
end
end
end
Expand Down
33 changes: 20 additions & 13 deletions test/test_tree_construction.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
using SymbolicRegression
using Random
using Compat: Fix
using SymbolicRegression: eval_loss, score_func, Dataset
using ForwardDiff
include("test_params.jl")

x1 = 2.0

# Construct an `Options` whose operator set contains the supplied binary and
# unary operators (on top of a fixed base set), forwarding any extra keyword
# arguments so callers can override the defaults set here.
function make_options_maker(binop, unaop; kw...)
    @nospecialize binop unaop kw
    # Silence output for the gamma runs; otherwise be maximally verbose.
    verbosity_level = unaop == gamma ? 0 : Int(1e9)
    return Options(;
        default_params...,
        binary_operators=(+, *, ^, /, binop),
        unary_operators=(unaop, abs),
        populations=4,
        verbosity=verbosity_level,
        kw...,
    )
end

# Initialize functions in Base....
for unaop in [cos, exp, safe_log, safe_log2, safe_log10, safe_sqrt, relu, gamma, safe_acosh]
for binop in [sub]
function make_options(; kw...)
return Options(;
default_params...,
binary_operators=(+, *, ^, /, binop),
unary_operators=(unaop, abs),
populations=4,
verbosity=(unaop == gamma) ? 0 : Int(1e9),
kw...,
)
end
for unaop in
[cos, exp, safe_log, safe_log2, safe_log10, safe_sqrt, relu, gamma, safe_acosh],
binop in [sub]

let
make_options = Fix{1}(Fix{2}(make_options_maker, unaop), binop)
options = make_options()
@extend_operators options

Expand All @@ -36,7 +43,7 @@ for unaop in [cos, exp, safe_log, safe_log2, safe_log10, safe_sqrt, relu, gamma,

true_result = f_true(x1)

result = eval(Meta.parse(string_tree(const_tree, make_options())))
result = eval(Meta.parse(string_tree(const_tree, options)))

# Test Basics
@test n == 9
Expand Down

0 comments on commit 6890e1d

Please sign in to comment.