Skip to content

Commit 17148cd

Browse files
authored
Merge pull request #81 from russellb/filterblock-multi-value
Allow FilterByValueBlock to handle one or many values
2 parents d6091ff + 3e21ca8 commit 17148cd

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

src/instructlab/sdg/filterblock.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,21 @@ class FilterByValueBlock(Block):
1313
def __init__(
1414
self, filter_column, filter_value, operation, convert_dtype=None, **batch_kwargs
1515
) -> None:
16+
"""
17+
Initializes a new instance of the FilterByValueBlock class.
18+
19+
Parameters:
20+
- filter_column (str): The name of the column in the dataset to apply the filter on.
21+
- filter_value (any or list of any): The value(s) to filter by.
22+
- operation (callable): A function that takes two arguments (column value and filter value) and returns a boolean indicating whether the row should be included in the filtered dataset.
23+
- convert_dtype (callable, optional): A function to convert the data type of the filter column before applying the filter. Defaults to None.
24+
- **batch_kwargs: Additional kwargs for batch processing.
25+
26+
Returns:
27+
None
28+
"""
1629
super().__init__(block_name=self.__class__.__name__)
17-
self.value = filter_value
30+
self.value = filter_value if isinstance(filter_value, list) else [filter_value]
1831
self.column_name = filter_column
1932
self.operation = operation
2033
self.convert_dtype = convert_dtype
@@ -38,6 +51,8 @@ def generate(self, samples) -> Dataset:
3851
)
3952

4053
return samples.filter(
41-
lambda x: self.operation(x[self.column_name], self.value),
54+
lambda x: any(
55+
self.operation(x[self.column_name], value) for value in self.value
56+
),
4257
num_proc=self.num_procs,
4358
)

tests/test_filterblock.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ def setUp(self):
1818
operation=operator.eq,
1919
convert_dtype=int,
2020
)
21+
self.block_with_list = FilterByValueBlock(
22+
filter_column="age",
23+
filter_value=[30, 35],
24+
operation=operator.eq,
25+
convert_dtype=int,
26+
)
2127
self.dataset = Dataset.from_dict(
2228
{"age": ["25", "30", "35", "forty", "45"]},
2329
features=Features({"age": Value("string")}),
@@ -29,3 +35,10 @@ def test_generate_mixed_types(self, mock_logger):
2935
self.assertEqual(len(filtered_dataset), 1)
3036
self.assertEqual(filtered_dataset["age"], [30])
3137
mock_logger.error.assert_called()
38+
39+
@patch("instructlab.sdg.filterblock.logger")
40+
def test_generate_mixed_types_multi_value(self, mock_logger):
41+
filtered_dataset = self.block_with_list.generate(self.dataset)
42+
self.assertEqual(len(filtered_dataset), 2)
43+
self.assertEqual(filtered_dataset["age"], [30, 35])
44+
mock_logger.error.assert_called()

0 commit comments

Comments
 (0)