Skip to content

Commit da57d59

Browse files
authored
Merge branch 'main' into feature-aiomysql
2 parents 0e5100b + 75d6a4d commit da57d59

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

tests/mlmodel_sklearn/test_discriminant_analysis_models.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
)
1919

2020
from newrelic.api.background_task import background_task
21+
from newrelic.common.package_version_utils import get_package_version_tuple
22+
23+
SKLEARN_VERSION = get_package_version_tuple("sklearn")
24+
SKLEARN_VERSION_GT_1_6_0 = SKLEARN_VERSION >= (1, 6, 0)
2125

2226

2327
@pytest.mark.parametrize(
@@ -38,8 +42,14 @@ def test_model_methods_wrapped_in_function_trace(discriminant_analysis_model_nam
3842
"QuadraticDiscriminantAnalysis": [
3943
("Function/MLModel/Sklearn/Named/QuadraticDiscriminantAnalysis.fit", 1),
4044
("Function/MLModel/Sklearn/Named/QuadraticDiscriminantAnalysis.predict", 1),
41-
("Function/MLModel/Sklearn/Named/QuadraticDiscriminantAnalysis.predict_proba", 2),
42-
("Function/MLModel/Sklearn/Named/QuadraticDiscriminantAnalysis.predict_log_proba", 1),
45+
(
46+
"Function/MLModel/Sklearn/Named/QuadraticDiscriminantAnalysis.predict_proba",
47+
1 if SKLEARN_VERSION_GT_1_6_0 else 2,
48+
),
49+
(
50+
"Function/MLModel/Sklearn/Named/QuadraticDiscriminantAnalysis.predict_log_proba",
51+
2 if SKLEARN_VERSION_GT_1_6_0 else 1,
52+
),
4353
],
4454
}
4555

0 commit comments

Comments
 (0)