Skip to content

Commit 01369e9

Browse files
committed
fix: JET flagged issue
1 parent fff15ed commit 01369e9

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

src/LossFunctions.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,12 @@ function eval_loss(
144144
)::L where {T<:DATA_TYPE,L<:LOSS_TYPE}
145145
loss_val = if !isnothing(options.loss_function)
146146
f = options.loss_function::Function
147-
evaluator(f, get_tree(tree)::AbstractExpressionNode, dataset, options, idx)
147+
inner_tree = tree isa AbstractExpression ? get_tree(tree) : tree
148+
evaluator(f, inner_tree, dataset, options, idx)
148149
elseif !isnothing(options.loss_function_expression)
149150
f = options.loss_function_expression::Function
150-
evaluator(f, tree::AbstractExpression, dataset, options, idx)
151+
@assert tree isa AbstractExpression
152+
evaluator(f, tree, dataset, options, idx)
151153
else
152154
_eval_loss(tree, dataset, options, regularization, idx)
153155
end

0 commit comments

Comments
 (0)