diff --git a/src/instructlab/sdg/filterblock.py b/src/instructlab/sdg/filterblock.py index e6d4eb24..794a4071 100644 --- a/src/instructlab/sdg/filterblock.py +++ b/src/instructlab/sdg/filterblock.py @@ -20,13 +20,20 @@ def __init__( self.convert_dtype = convert_dtype self.num_procs = batch_kwargs.get("num_procs", 1) + def _convert_dtype(self, sample): + try: + sample[self.column_name] = self.convert_dtype(sample[self.column_name]) + except ValueError as e: + logger.error( + "Error converting dtype: %s, filling with None to be filtered later", e + ) + sample[self.column_name] = None + return sample + def generate(self, samples) -> Dataset: if self.convert_dtype: samples = samples.map( - lambda x: { - **x, - self.column_name: self.convert_dtype(x[self.column_name]), - }, + self._convert_dtype, num_proc=self.num_procs, ) diff --git a/tests/test_filterblock.py b/tests/test_filterblock.py new file mode 100644 index 00000000..9ee47c65 --- /dev/null +++ b/tests/test_filterblock.py @@ -0,0 +1,31 @@ +# Standard +from unittest.mock import patch +import operator +import unittest + +# Third Party +from datasets import Dataset, Features, Value + +# First Party +from instructlab.sdg.filterblock import FilterByValueBlock + + +class TestFilterByValueBlock(unittest.TestCase): + def setUp(self): + self.block = FilterByValueBlock( + filter_column="age", + filter_value=30, + operation=operator.eq, + convert_dtype=int, + ) + self.dataset = Dataset.from_dict( + {"age": ["25", "30", "35", "forty", "45"]}, + features=Features({"age": Value("string")}), + ) + + @patch("instructlab.sdg.filterblock.logger") + def test_generate_mixed_types(self, mock_logger): + filtered_dataset = self.block.generate(self.dataset) + self.assertEqual(len(filtered_dataset), 1) + self.assertEqual(filtered_dataset["age"], [30]) + mock_logger.error.assert_called()