Skip to content

Commit

Permalink
Update unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Feb 10, 2024
1 parent 31a7b51 commit cd737a9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
4 changes: 3 additions & 1 deletion quantile_forest/_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def __init__(
}
super().__init__(**init_dict)

self.param_validation = hasattr(self, "_parameter_constraints")

def fit(self, X, y, sample_weight=None, sparse_pickle=False):
"""Build a forest from the training set (X, y).
Expand Down Expand Up @@ -151,7 +153,7 @@ def fit(self, X, y, sample_weight=None, sparse_pickle=False):
self : object
Fitted estimator.
"""
if param_validation:
if self.param_validation:
self._validate_params()
else:
if isinstance(self.max_samples_leaf, (Integral, np.integer)):
Expand Down
21 changes: 10 additions & 11 deletions quantile_forest/tests/test_quantile_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def check_regression_toy(name, weighted_quantile):
)
assert_allclose(y_pred, y_true)

assert regr._more_tags()


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
@pytest.mark.parametrize("weighted_quantile", [True, False])
Expand Down Expand Up @@ -682,17 +684,14 @@ def check_max_samples_leaf(name):
for max_1, max_2 in zip(max_leaf_sizes[::], max_leaf_sizes[1::]):
assert max_1 <= max_2

# Check error if `max_samples_leaf` <= 0.
est = ForestRegressor(n_estimators=1, max_samples_leaf=0)
assert_raises(ValueError, est.fit, X, y)

# Check error if `max_samples_leaf` is float larger than 1.
est = ForestRegressor(n_estimators=1, max_samples_leaf=1.5)
assert_raises(ValueError, est.fit, X, y)

# Check error if `max_samples_leaf` is not int, float, or None.
est = ForestRegressor(n_estimators=1, max_samples_leaf="None")
assert_raises(ValueError, est.fit, X, y)
# Check error if `max_samples_leaf` <= 0, float larger than 1, or string.
for max_samples_leaf in [0, 1.5, "None"]:
for param_validation in [True, False]:
est = ForestRegressor(n_estimators=1, max_samples_leaf=max_samples_leaf)
est.param_validation = param_validation
assert_raises(ValueError, est.fit, X, y)
est.max_samples_leaf = max_samples_leaf
assert_raises(ValueError, est._get_y_train_leaves, X, 1)


@pytest.mark.parametrize("name", FOREST_REGRESSORS)
Expand Down

0 comments on commit cd737a9

Please sign in to comment.