Skip to content

Commit 86ea424

Browse files
committed
Fix sample_weight docstrings
1 parent 167402a commit 86ea424

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

quantile_forest/_quantile_forest.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,7 @@ def fit(self, X, y, sample_weight=None, sparse_pickle=False):
127127
sample_weight : array-like of shape (n_samples,), default=None
128128
Sample weights. If None, then samples are equally weighted. Splits
129129
that would create child nodes with net zero or negative weight are
130-
ignored while searching for a split in each node. In the case of
131-
classification, splits are also ignored if they would result in any
132-
single class carrying a negative weight in either child node.
130+
ignored while searching for a split in each node.
133131
134132
sparse_pickle : bool, default=False
135133
Pickle the underlying data structure using a SciPy sparse matrix.
@@ -230,12 +228,12 @@ def _get_y_train_leaves_slice(
230228
X_leaves_bootstrap : array-like of shape (n_samples,)
231229
Leaf node indices of the bootstrap training samples.
232230
233-
sample_weight : array-like of shape (n_samples,), default=None
231+
sample_weight : array-like of shape (n_samples, n_outputs), \
232+
default=None
234233
Sample weights. If None, then samples are equally weighted. Splits
235234
that would create child nodes with net zero or negative weight are
236-
ignored while searching for a split in each node. In the case of
237-
classification, splits are also ignored if they would result in any
238-
single class carrying a negative weight in either child node.
235+
ignored while searching for a split in each node. For each output,
236+
the ordering of the weights correspond to the sorted samples.
239237
240238
leaf_subsample : bool
241239
Subsample leaf nodes. If True, leaves are randomly sampled to size
@@ -261,6 +259,9 @@ def _get_y_train_leaves_slice(
261259
"""
262260
n_outputs = bootstrap_indices.shape[1]
263261

262+
if sample_weight is not None:
263+
sample_weight = np.squeeze(sample_weight)
264+
264265
shape = (max_node_count, n_outputs, max_samples_leaf)
265266
y_train_leaves_slice = np.zeros(shape, dtype=np.int64)
266267

@@ -319,10 +320,12 @@ def _get_y_train_leaves(self, X, y, sorter=None, sample_weight=None):
319320
The indices that would sort the target values in ascending order.
320321
Used to associate ``est.apply`` outputs with sorted target values.
321322
322-
sample_weight : array-like of shape (n_samples,), default=None
323+
sample_weight : array-like of shape (n_samples, n_outputs), \
324+
default=None
323325
Sample weights. If None, then samples are equally weighted. Splits
324326
that would create child nodes with net zero or negative weight are
325-
ignored while searching for a split in each node.
327+
ignored while searching for a split in each node. For each output,
328+
the ordering of the weights correspond to the sorted samples.
326329
327330
Returns
328331
-------
@@ -394,9 +397,6 @@ def _get_y_train_leaves(self, X, y, sorter=None, sample_weight=None):
394397
if sample_count > max_samples_leaf:
395398
max_samples_leaf = sample_count
396399

397-
if sample_weight is not None:
398-
sample_weight = np.squeeze(sample_weight)
399-
400400
y_train_leaves = [
401401
self._get_y_train_leaves_slice(
402402
bootstrap_indices[:, i],

0 commit comments

Comments
 (0)