From 020052fb4692e67bd95b433c84ebbbbc32063c79 Mon Sep 17 00:00:00 2001 From: Denis Kocetkov Date: Wed, 26 Mar 2025 20:34:36 +0200 Subject: [PATCH 1/2] conceptual processors implementation, not runnable --- fast_llm/config.py | 89 +++++++++++++++++-- fast_llm/data/preparator/gpt_memmap/config.py | 4 + .../data/preparator/gpt_memmap/prepare.py | 3 + .../hf_processors/configs/agregator.py | 21 +++++ .../preparator/hf_processors/configs/base.py | 19 ++++ .../configs/doc_length_filter.py | 22 +++++ .../implementations/agregator.py | 11 +++ .../implementations/doc_length_filter.py | 7 ++ 8 files changed, 169 insertions(+), 7 deletions(-) create mode 100644 fast_llm/data/preparator/hf_processors/configs/agregator.py create mode 100644 fast_llm/data/preparator/hf_processors/configs/base.py create mode 100644 fast_llm/data/preparator/hf_processors/configs/doc_length_filter.py create mode 100644 fast_llm/data/preparator/hf_processors/implementations/agregator.py create mode 100644 fast_llm/data/preparator/hf_processors/implementations/doc_length_filter.py diff --git a/fast_llm/config.py b/fast_llm/config.py index f1c889658..c2c5434e9 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -9,7 +9,16 @@ import yaml -from fast_llm.utils import Assert, Tag, get_type_name, header, log, pop_nested_dict_value, set_nested_dict_value +from fast_llm.utils import ( + Assert, + Tag, + Registry, + get_type_name, + header, + log, + pop_nested_dict_value, + set_nested_dict_value, +) logger = logging.getLogger(__name__) @@ -634,17 +643,17 @@ def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None: value = str(value) return value - def to_copy[ - T - ](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T: + def to_copy[T]( + self: T, + *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], + strict: bool = True, + ) -> T: return self.from_dict(self, *updates, strict=strict) def to_serialized(self, verbose: int | None = FieldVerboseLevel.core) -> dict[str, typing.Any]: return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True) - def to_logs[ - T - ]( + def to_logs[T]( self, verbose: int | None = FieldVerboseLevel.core, log_fn: typing.Callable[[str], T] = logger.info, @@ -916,3 +925,69 @@ def __init__(self, config: ConfigType, *args, **kwargs): @property def config(self) -> ConfigType: return self._config + + +@config_class() +class TypeableConfig(Config): + """ + Base Config class that instantiates a subclass type + based on the 'type' field in config files or params. + The root class must define its own _registry, and + subclasses must set a unique _type. Final classes + to be instantiated should have _abstract as False. + """ + + _abstract: typing.ClassVar[bool] = True + _registry: typing.ClassVar[Registry[str, type["TypeableConfig"]] | None] = None + + type_: typing.ClassVar[str | None] = None + type: str | None = Field( + default=None, + desc="Config specifieble type of the class.", + hint=FieldHint.core, + ) + + def _validate(self) -> None: + if self.type is None: + self.type = self.type_ + # Should be handled in `from_dict`, but can fail if instantiating directly. 
+ Assert.eq(self.type, self.__class__.type_) + super()._validate() + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + type_ = default.get("type") + if type_ is None: + actual_cls = cls + else: + if type_ not in cls._registry: + raise ValueError( + f"Unknown {cls._registry.name} type {type_}." f" Available types: {list(cls._registry.keys())}" + ) + actual_cls = cls._registry[type_] + Assert.custom(issubclass, actual_cls, cls) + if actual_cls == cls: + return super()._from_dict(default, strict=strict, flat=flat) + else: + return actual_cls._from_dict(default, strict=strict, flat=flat) + + def __init_subclass__(cls) -> None: + registry = getattr(cls, "_registry") + if registry is None: + raise ValueError(f"Sublass {cls.__name__} or one of its parents needs to set __registry") + if cls._abstract and cls.type_ is not None: + # Abstract classes should not have a `type_` + raise ValueError(f"Abstract class {cls.__name__} has type = {cls.type_}, expected None.") + if cls.type_ is not None: + if cls.type_ in registry: + raise ValueError( + f"Registry {cls._registry.name} already contains type {cls.type_}." + f" Make sure all classes either have a unique or `None` type." + ) + registry[cls.type_] = cls + super().__init_subclass__() diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2c4311c37..9b3a790f7 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -8,6 +8,8 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert +from fast_llm.data.preparator.hf_processors.configs.agregator import AgregatorConfig + if typing.TYPE_CHECKING: from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator MEMMAP_DTYPES = { @@ -165,6 +167,8 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, ) + processors: AgregatorConfig = Field(default=AgregatorConfig) + def _validate(self) -> None: assert self.tokenizer.path is not None if self.dataset.data_type is not None: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b3dae1df1..3f970e91d 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -221,6 +221,9 @@ def run(self) -> None: else: tokenize_fn = self._tokenize_batch + # Process dataset before tokenizing + dataset = self._config.processors.apply(dataset) + # Tokenize the dataset in parallel tokenized_dataset = dataset.map( tokenize_fn, diff --git a/fast_llm/data/preparator/hf_processors/configs/agregator.py b/fast_llm/data/preparator/hf_processors/configs/agregator.py new file mode 100644 index 000000000..f3ce4e49f --- /dev/null +++ b/fast_llm/data/preparator/hf_processors/configs/agregator.py @@ -0,0 +1,21 @@ +import datasets + +from fast_llm.data.preparator.hf_processors.configs.base import Applicable, ShardProcessorConfig +from fast_llm.config import Field, Config, config_class + +from fast_llm.data.preparator.hf_processors.configs.doc_length_filter import DocLengthFilterConfig + +def default_processors(): + """Default processors to apply""" + return [DocLengthFilterConfig()] + + +@config_class +class AgregatorConfig(Config, Applicable): + steps: list[ShardProcessorConfig] = Field(default_factory=default_processors) + + def apply(self, dataset: datasets.Dataset) -> 
datasets.Dataset: + from fast_llm.data.preparator.hf_processors.implementations.agregator import apply + return apply(self, dataset) + + diff --git a/fast_llm/data/preparator/hf_processors/configs/base.py b/fast_llm/data/preparator/hf_processors/configs/base.py new file mode 100644 index 000000000..a77a76a00 --- /dev/null +++ b/fast_llm/data/preparator/hf_processors/configs/base.py @@ -0,0 +1,19 @@ +import abc +import typing +import datasets + +from fast_llm.config import TypeableConfig, config_class +from fast_llm.utils import Registry + + +class Applicable: + @abc.abstractmethod + def apply(self, dataset: datasets.Dataset) -> datasets.Dataset: + raise NotImplementedError + + +@config_class() +class ShardProcessorConfig(TypeableConfig, Applicable): + _registry: typing.ClassVar[Registry[str, type["ShardProcessorConfig"]] | None] = Registry( + "ShardProcessorConfig", {} + ) diff --git a/fast_llm/data/preparator/hf_processors/configs/doc_length_filter.py b/fast_llm/data/preparator/hf_processors/configs/doc_length_filter.py new file mode 100644 index 000000000..a344ec5ba --- /dev/null +++ b/fast_llm/data/preparator/hf_processors/configs/doc_length_filter.py @@ -0,0 +1,22 @@ +import abc +import typing +import datasets + +from fast_llm.data.preparator.hf_processors.configs.base import Applicable, ShardProcessorConfig +from fast_llm.config import Field, config_class + + +@config_class +class DocLengthFilterConfig(ShardProcessorConfig): + _abstract: typing.ClassVar[bool] = False + type_: typing.ClassVar[str | None] = "length_filter" + + field: str = Field(default='text') + min_length_chars: int = Field(default=0) + max_length_chars: int = Field(default=1_000_000) + + def apply(self, dataset: datasets.Dataset) -> datasets.Dataset: + from fast_llm.data.preparator.hf_processors.implementations.doc_length_filter import apply + return apply(self, dataset) + + diff --git a/fast_llm/data/preparator/hf_processors/implementations/agregator.py b/fast_llm/data/preparator/hf_processors/implementations/agregator.py new file mode 100644 index 000000000..9cb3984cc --- /dev/null +++ b/fast_llm/data/preparator/hf_processors/implementations/agregator.py @@ -0,0 +1,11 @@ +import datasets + +from fast_llm.data.preparator.hf_processors.configs.agregator import AgregatorConfig + +def apply(config: AgregatorConfig, dataset: datasets.Dataset) -> datasets.Dataset: + # do something before applyting each processor + for step in config.steps: + dataset = step.apply(dataset) + # compute metrics + # save meterics, from all ranks? 
+ return dataset \ No newline at end of file diff --git a/fast_llm/data/preparator/hf_processors/implementations/doc_length_filter.py b/fast_llm/data/preparator/hf_processors/implementations/doc_length_filter.py new file mode 100644 index 000000000..7e0a6a8ae --- /dev/null +++ b/fast_llm/data/preparator/hf_processors/implementations/doc_length_filter.py @@ -0,0 +1,7 @@ +import datasets + +from fast_llm.data.preparator.hf_processors.configs.doc_length_filter import DocLengthFilterConfig + +def apply(config: DocLengthFilterConfig, dataset: datasets.Dataset) -> datasets.Dataset: + # do dataset.filter eliminating too long or too short docs + return dataset \ No newline at end of file From eac822c2c3747fca0d5d425e8439bde03e80da11 Mon Sep 17 00:00:00 2001 From: Denis Kocetkov Date: Tue, 1 Apr 2025 17:34:05 +0300 Subject: [PATCH 2/2] processors implementaion based on feedback, unittest work for all except for clamav, integration not tested --- fast_llm/config.py | 66 ------- fast_llm/data/preparator/gpt_memmap/config.py | 48 ++--- .../gpt_memmap/distributed_config.py | 38 ++++ .../gpt_memmap/hf_processors/configs.py | 172 ++++++++++++++++++ .../hf_processors/processor_metrics_logger.py | 86 +++++++++ .../gpt_memmap/hf_processors/processors.py | 163 +++++++++++++++++ .../data/preparator/gpt_memmap/prepare.py | 19 +- .../hf_processors/configs/agregator.py | 21 --- .../preparator/hf_processors/configs/base.py | 19 -- .../configs/doc_length_filter.py | 22 --- .../implementations/agregator.py | 11 -- .../implementations/doc_length_filter.py | 7 - tests/data/test_prepare_hf_processors.py | 86 +++++++++ 13 files changed, 575 insertions(+), 183 deletions(-) create mode 100644 fast_llm/data/preparator/gpt_memmap/distributed_config.py create mode 100644 fast_llm/data/preparator/gpt_memmap/hf_processors/configs.py create mode 100644 fast_llm/data/preparator/gpt_memmap/hf_processors/processor_metrics_logger.py create mode 100644 fast_llm/data/preparator/gpt_memmap/hf_processors/processors.py delete mode 100644 fast_llm/data/preparator/hf_processors/configs/agregator.py delete mode 100644 fast_llm/data/preparator/hf_processors/configs/base.py delete mode 100644 fast_llm/data/preparator/hf_processors/configs/doc_length_filter.py delete mode 100644 fast_llm/data/preparator/hf_processors/implementations/agregator.py delete mode 100644 fast_llm/data/preparator/hf_processors/implementations/doc_length_filter.py create mode 100644 tests/data/test_prepare_hf_processors.py diff --git a/fast_llm/config.py b/fast_llm/config.py index c2c5434e9..ba7ce47e4 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -925,69 +925,3 @@ def __init__(self, config: ConfigType, *args, **kwargs): @property def config(self) -> ConfigType: return self._config - - -@config_class() -class TypeableConfig(Config): - """ - Base Config class that instantiates a subclass type - based on the 'type' field in config files or params. - The root class must define its own _registry, and - subclasses must set a unique _type. Final classes - to be instantiated should have _abstract as False. - """ - - _abstract: typing.ClassVar[bool] = True - _registry: typing.ClassVar[Registry[str, type["TypeableConfig"]] | None] = None - - type_: typing.ClassVar[str | None] = None - type: str | None = Field( - default=None, - desc="Config specifieble type of the class.", - hint=FieldHint.core, - ) - - def _validate(self) -> None: - if self.type is None: - self.type = self.type_ - # Should be handled in `from_dict`, but can fail if instantiating directly. 
- Assert.eq(self.type, self.__class__.type_) - super()._validate() - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - type_ = default.get("type") - if type_ is None: - actual_cls = cls - else: - if type_ not in cls._registry: - raise ValueError( - f"Unknown {cls._registry.name} type {type_}." f" Available types: {list(cls._registry.keys())}" - ) - actual_cls = cls._registry[type_] - Assert.custom(issubclass, actual_cls, cls) - if actual_cls == cls: - return super()._from_dict(default, strict=strict, flat=flat) - else: - return actual_cls._from_dict(default, strict=strict, flat=flat) - - def __init_subclass__(cls) -> None: - registry = getattr(cls, "_registry") - if registry is None: - raise ValueError(f"Sublass {cls.__name__} or one of its parents needs to set __registry") - if cls._abstract and cls.type_ is not None: - # Abstract classes should not have a `type_` - raise ValueError(f"Abstract class {cls.__name__} has type = {cls.type_}, expected None.") - if cls.type_ is not None: - if cls.type_ in registry: - raise ValueError( - f"Registry {cls._registry.name} already contains type {cls.type_}." - f" Make sure all classes either have a unique or `None` type." - ) - registry[cls.type_] = cls - super().__init_subclass__() diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 9b3a790f7..2f1a24b34 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -1,4 +1,3 @@ -import os import pathlib import typing @@ -8,7 +7,8 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert -from fast_llm.data.preparator.hf_processors.configs.agregator import AgregatorConfig +from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig +from fast_llm.data.preparator.gpt_memmap.hf_processors.configs import HFProcessorConfig, ProcessorsConfig if typing.TYPE_CHECKING: from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator @@ -79,39 +79,6 @@ class GPTHuggingfaceDatasetConfig(Config): ) -@config_class -class DatasetPreparatorDistributedConfig(Config): - # TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig - - default_world_size: typing.ClassVar[int] = int(os.environ.get("WORLD_SIZE", 1)) - default_rank: typing.ClassVar[int] = int(os.environ.get("RANK", 0)) - world_size: int = Field( - default=None, - desc="Size of the world group. Typically provided by torchrun or equivalent through the `WORLD_SIZE` environment variable.", - hint=FieldHint.expert, - valid=check_field(Assert.gt, 0), - ) - rank: int = Field( - default=None, - desc="Rank of the local process. 
Typically provided by torchrun or equivalent through the `RANK` environment variable.", - hint=FieldHint.expert, - valid=check_field(Assert.geq, 0), - ) - backend: str = Field( - default="gloo", - desc="Distributed backend to use.", - hint=FieldHint.optional, - ) - - def _validate(self) -> None: - if self.world_size is None: - self.world_size = self.default_world_size - if self.rank is None: - self.rank = self.default_rank - super()._validate() - Assert.in_range(self.rank, 0, self.world_size) - - @config_class() class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): preparator_name: typing.ClassVar[str] = "gpt_memmap" @@ -167,7 +134,8 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): hint=FieldHint.optional, ) - processors: AgregatorConfig = Field(default=AgregatorConfig) + # TODO: Add desc and hint. + processors: ProcessorsConfig = Field(default=ProcessorsConfig) def _validate(self) -> None: assert self.tokenizer.path is not None @@ -175,6 +143,14 @@ def _validate(self) -> None: Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV) super()._validate() + # Propagete datasaet field name and workers count if not set in processors' configs. + for processor_config_field_name in self.processors.get_processor_types_map().keys(): + config: HFProcessorConfig = getattr(self.processors, processor_config_field_name) + if config.field is None: + config.field = self.dataset.field + if config.num_proc is None: + config.num_proc = self.tokenize_workers + @classmethod def get_dataset_preparator_class(cls) -> type["GPTMemmapDatasetPreparator"]: from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator diff --git a/fast_llm/data/preparator/gpt_memmap/distributed_config.py b/fast_llm/data/preparator/gpt_memmap/distributed_config.py new file mode 100644 index 000000000..7c653a4ba --- /dev/null +++ b/fast_llm/data/preparator/gpt_memmap/distributed_config.py @@ -0,0 +1,38 @@ +import os +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.utils import Assert + + +@config_class +class DatasetPreparatorDistributedConfig(Config): + # TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig + + default_world_size: typing.ClassVar[int] = int(os.environ.get("WORLD_SIZE", 1)) + default_rank: typing.ClassVar[int] = int(os.environ.get("RANK", 0)) + world_size: int = Field( + default=None, + desc="Size of the world group. Typically provided by torchrun or equivalent through the `WORLD_SIZE` environment variable.", + hint=FieldHint.expert, + valid=check_field(Assert.gt, 0), + ) + rank: int = Field( + default=None, + desc="Rank of the local process. 
Typically provided by torchrun or equivalent through the `RANK` environment variable.", + hint=FieldHint.expert, + valid=check_field(Assert.geq, 0), + ) + backend: str = Field( + default="gloo", + desc="Distributed backend to use.", + hint=FieldHint.optional, + ) + + def _validate(self) -> None: + if self.world_size is None: + self.world_size = self.default_world_size + if self.rank is None: + self.rank = self.default_rank + super()._validate() + Assert.in_range(self.rank, 0, self.world_size) \ No newline at end of file diff --git a/fast_llm/data/preparator/gpt_memmap/hf_processors/configs.py b/fast_llm/data/preparator/gpt_memmap/hf_processors/configs.py new file mode 100644 index 000000000..e87438179 --- /dev/null +++ b/fast_llm/data/preparator/gpt_memmap/hf_processors/configs.py @@ -0,0 +1,172 @@ +import abc +import datasets +import typing + +from fast_llm.config import Config, Configurable, Field, FieldUpdate, config_class +from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig + + +# TODO: Add desc and hint to all fields. + + +@config_class +class HFProcessorConfig(Config): + use_processor: bool = Field(default=True) + human_readable_name: str = Field(default="") + batch_size: int | None = Field(default=None) + num_proc: int | None = Field(default=None) + field: str | None = Field(default=None) + + +class HFProcessor[ConfigType: HFProcessorConfig](Configurable[ConfigType], abc.ABC): + config_class: typing.ClassVar[type[HFProcessorConfig]] = HFProcessorConfig + + def __init__(self, config: ConfigType, distributed_config: DatasetPreparatorDistributedConfig, *args, **kwargs): + super().__init__(config, *args, **kwargs) + + self._distributed_config = distributed_config + + @abc.abstractmethod + def apply(self, dataset: datasets.Dataset) -> datasets.Dataset: + raise NotImplementedError + + +@config_class +class DocLengthFilterProcessorConfig(HFProcessorConfig): + human_readable_name: str | None = FieldUpdate(default="Document Length Filter") + min_length_chars: int = Field(default=0) + max_length_chars: int = Field(default=1_000_000) + + +class DocLengthFilterProcessor[ConfigType: DocLengthFilterProcessorConfig](HFProcessor[ConfigType]): + config_class: typing.ClassVar[type[DocLengthFilterProcessorConfig]] = DocLengthFilterProcessorConfig + + def apply(self, dataset: datasets.Dataset) -> datasets.Dataset: + from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_doc_length_filter_processor + + return apply_doc_length_filter_processor(self._config, dataset) + + +@config_class +class NGramRepetitionFilterProcessorConfig(HFProcessorConfig): + human_readable_name: str | None = FieldUpdate(default="N-Gram Repetition Filter") + n: int = Field(default=5) + max_repetitions: int = Field(default=32) + + +class NGramRepetitionFilterProcessor[ConfigType: NGramRepetitionFilterProcessorConfig](HFProcessor[ConfigType]): + config_class: typing.ClassVar[type[NGramRepetitionFilterProcessorConfig]] = NGramRepetitionFilterProcessorConfig + + def apply(self, dataset: datasets.Dataset) -> datasets.Dataset: + from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import ( + apply_ngram_repetition_filter_processor, + ) + + return apply_ngram_repetition_filter_processor(self._config, dataset) + + +@config_class +class FrequencyBasedFilterProcessorConfig(HFProcessorConfig): + human_readable_name: str | None = FieldUpdate(default="Frequency-Based Filter") + max_single_word_ratio: float = Field(default=0.3) + max_top_two_word_ratio: float 
= Field(default=0.5) + + +class FrequencyBasedFilterProcessor[ConfigType: FrequencyBasedFilterProcessorConfig](HFProcessor[ConfigType]): + config_class: typing.ClassVar[type[FrequencyBasedFilterProcessorConfig]] = FrequencyBasedFilterProcessorConfig + + def apply(self, dataset: datasets.Dataset) -> datasets.Dataset: + from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_frequency_based_filter_processor + + return apply_frequency_based_filter_processor(self._config, dataset) + + +@config_class +class BinaryContentFilterProcessorConfig(HFProcessorConfig): + human_readable_name: str | None = FieldUpdate(default="Binary Content Filter") + max_bin_ratio: float = Field(default=0.5) + + +class BinaryContentFilterProcessor[ConfigType: BinaryContentFilterProcessorConfig](HFProcessor[ConfigType]): + config_class: typing.ClassVar[type[BinaryContentFilterProcessorConfig]] = BinaryContentFilterProcessorConfig + + def apply(self, dataset: datasets.Dataset) -> datasets.Dataset: + from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_binary_content_filter_processor + + return apply_binary_content_filter_processor(self._config, dataset) + + +@config_class +class NumericalContentFilterProcessorConfig(HFProcessorConfig): + human_readable_name: str | None = FieldUpdate(default="Numerical Content Filter") + max_numeric_token_ratio: float = Field(default=0.5) + + +class NumericalContentFilterProcessor[ConfigType: NumericalContentFilterProcessorConfig](HFProcessor[ConfigType]): + config_class: typing.ClassVar[type[NumericalContentFilterProcessorConfig]] = NumericalContentFilterProcessorConfig + + def apply(self, dataset: datasets.Dataset) -> datasets.Dataset: + from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import ( + apply_numerical_content_filter_processor, + ) + + return apply_numerical_content_filter_processor(self._config, dataset) + + +@config_class +class PiiRedactionProcessorConfig(HFProcessorConfig): + use_processor: bool = FieldUpdate(default=False) + human_readable_name: str | None = FieldUpdate(default="PII Redaction Processor") + # TODO: make enum + redaction_method: str = Field(default="remove") # Options: 'remove', 'mask' + + +class PiiRedactionProcessor[ConfigType: PiiRedactionProcessorConfig](HFProcessor[ConfigType]): + config_class: typing.ClassVar[type[PiiRedactionProcessorConfig]] = PiiRedactionProcessorConfig + + def apply(self, dataset: datasets.Dataset) -> datasets.Dataset: + from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_pii_redaction_processor + + return apply_pii_redaction_processor(self._config, self._distributed_config, dataset) + + +@config_class +class MalwareRemovalProcessorConfig(HFProcessorConfig): + use_processor: bool = FieldUpdate(default=False) + human_readable_name: str | None = FieldUpdate(default="Malware Removal Processor") + + +class MalwareRemovalProcessor[ConfigType: MalwareRemovalProcessorConfig](HFProcessor[ConfigType]): + config_class: typing.ClassVar[type[MalwareRemovalProcessorConfig]] = MalwareRemovalProcessorConfig + + def apply(self, dataset: datasets.Dataset) -> datasets.Dataset: + from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_malware_removal_processor + + return apply_malware_removal_processor(self._config, dataset) + + +@config_class +class ProcessorsConfig(Config): + doc_length: DocLengthFilterProcessorConfig = Field(default=DocLengthFilterProcessorConfig) + n_gramms: NGramRepetitionFilterProcessorConfig = 
Field(default=NGramRepetitionFilterProcessorConfig) + frequency: FrequencyBasedFilterProcessorConfig = Field(default=FrequencyBasedFilterProcessorConfig) + binary: BinaryContentFilterProcessorConfig = Field(default=BinaryContentFilterProcessorConfig) + numerical: NumericalContentFilterProcessorConfig = Field(default=NumericalContentFilterProcessorConfig) + pii: PiiRedactionProcessorConfig = Field(default=PiiRedactionProcessorConfig) + malware: MalwareRemovalProcessorConfig = Field(default=MalwareRemovalProcessorConfig) + + # TODO: add validation so all steps are actual field names + order: list[str] = Field( + default_factory=lambda: ["doc_length", "n_gramms", "frequency", "binary", "numerical", "pii", "malware"] + ) + + def get_processor_types_map(self): + return { + "doc_length": DocLengthFilterProcessor, + "n_gramms": NGramRepetitionFilterProcessor, + "frequency": FrequencyBasedFilterProcessor, + "binary": BinaryContentFilterProcessor, + "numerical": NumericalContentFilterProcessor, + "pii": PiiRedactionProcessor, + "malware": MalwareRemovalProcessor, + } diff --git a/fast_llm/data/preparator/gpt_memmap/hf_processors/processor_metrics_logger.py b/fast_llm/data/preparator/gpt_memmap/hf_processors/processor_metrics_logger.py new file mode 100644 index 000000000..79e8cc202 --- /dev/null +++ b/fast_llm/data/preparator/gpt_memmap/hf_processors/processor_metrics_logger.py @@ -0,0 +1,86 @@ +import datasets +import pathlib +import time +import typing + +import torch +import torch.distributed + +from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig + + +class ProcessorMetricsLogger: + def __init__( + self, distributed_config: DatasetPreparatorDistributedConfig, field: str, num_proc: int, batch_size: int + ): + self.start_time = None + self.distributed_config = distributed_config + self.field = field + self.num_proc = num_proc + self.batch_size = batch_size + self.local_times = [] + self.local_doc_lengths = [] + self.local_chars = [] + + def start(self): + self.start_time = time.time() + + def stop(self, dataset: datasets.Dataset, step_name: str): + # TODO: seems generated nonsense, rewrite manually + elapsed_time = time.time() - self.start_time + num_rows = len(dataset) + + def compute_doc_lengths(batch): + return {"doc_lengths": [len(doc) for doc in batch[self.field]]} + + doc_lengths = dataset.map( + compute_doc_lengths, batched=True, batch_size=self.batch_size, num_proc=self.num_proc + ) + doc_lengths = sum(doc_lengths["doc_lengths"], []) + num_chars = sum(doc_lengths) + + self.local_times.append(elapsed_time) + self.local_doc_lengths.extend(doc_lengths) + self.local_chars.append(num_chars) + + local_stats = torch.tensor( + [num_rows, num_chars, min(doc_lengths, default=0), max(doc_lengths, default=0)], dtype=torch.long + ) + all_stats = [ + torch.zeros_like(local_stats) for _ in range(torch.distributed.get_world_size(self.process_group)) + ] + + if torch.distributed.is_initialized(): + torch.distributed.all_gather(all_stats, local_stats, group=self.process_group) + + if self.rank == 0: + all_times = torch.tensor(self.local_times) + all_chars = torch.tensor(self.local_chars) + min_time, max_time, avg_time = all_times.min().item(), all_times.max().item(), all_times.mean().item() + min_chars, max_chars, total_chars = all_chars.min().item(), all_chars.max().item(), all_chars.sum().item() + min_doc_length = min(stat[2].item() for stat in all_stats) + max_doc_length = max(stat[3].item() for stat in all_stats) + total_rows = sum(stat[0].item() for 
stat in all_stats) + + return { + "step_name": step_name, + "elapsed_time": {"min": min_time, "max": max_time, "avg": avg_time}, + "document_length": {"min": min_doc_length, "max": max_doc_length, "total": total_chars}, + "total_rows": total_rows, + } + return None + + @classmethod + def format(cls, metrics: dict[str, typing.Any]): + return ( + f"Processor {metrics['step_name']}' applied, max shard processing time {metrics['elapsed_time']['max']}," + f" number of rows remained in the dataset {metrics['total_rows']}," + f" number of characters remained in the dataset {metrics['document_length']['total']}" + ) + + @classmethod + def save_as_yaml(cls, file_name: pathlib.Path, metrics: list[dict[str, typing.Any]]): + import yaml + + with file_name.with_suffix(".yaml").open("wt") as f: + yaml.safe_dump(metrics, f) diff --git a/fast_llm/data/preparator/gpt_memmap/hf_processors/processors.py b/fast_llm/data/preparator/gpt_memmap/hf_processors/processors.py new file mode 100644 index 000000000..82cc1ca1c --- /dev/null +++ b/fast_llm/data/preparator/gpt_memmap/hf_processors/processors.py @@ -0,0 +1,163 @@ +import collections +import datasets +import logging +import re + + +from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig +from fast_llm.data.preparator.gpt_memmap.hf_processors.configs import ( + DocLengthFilterProcessorConfig, + NGramRepetitionFilterProcessorConfig, + FrequencyBasedFilterProcessorConfig, + BinaryContentFilterProcessorConfig, + NumericalContentFilterProcessorConfig, + PiiRedactionProcessorConfig, + MalwareRemovalProcessorConfig, +) + + +logger = logging.getLogger(__name__) + +WORD_PATTERN = r"\b\w+(?:'\w+)?\b" +NUMBER_PATTERN = r"\b\d+\b" + + +def apply_doc_length_filter_processor( + config: DocLengthFilterProcessorConfig, dataset: datasets.Dataset +) -> datasets.Dataset: + return dataset.filter( + lambda batch: [ + config.min_length_chars <= len(text) <= config.max_length_chars for text in batch[config.field] + ], + num_proc=config.num_proc, + batched=True, + batch_size=config.batch_size, + ) + + +def apply_ngram_repetition_filter_processor( + config: NGramRepetitionFilterProcessorConfig, dataset: datasets.Dataset +) -> datasets.Dataset: + def has_repeated_ngrams(batch): + results = [] + word_pattern = re.compile(WORD_PATTERN) + for text in batch[config.field]: + words = word_pattern.findall(text) + ngrams = [tuple(words[i : i + config.n]) for i in range(len(words) - config.n + 1)] + ngram_counts = collections.Counter(ngrams) + results.append(max(ngram_counts.values(), default=0) <= config.max_repetitions) + return results + + return dataset.filter( + has_repeated_ngrams, + num_proc=config.num_proc, + batched=True, + batch_size=config.batch_size, + ) + + +def apply_frequency_based_filter_processor( + config: FrequencyBasedFilterProcessorConfig, dataset: datasets.Dataset +) -> datasets.Dataset: + def exceeds_word_frequency_threshold(batch): + results = [] + word_pattern = re.compile(WORD_PATTERN) + for text in batch[config.field]: + words = word_pattern.findall(text) + total_words = len(words) + word_counts = collections.Counter(words) + most_common = word_counts.most_common(2) + + if most_common and (most_common[0][1] / total_words) > config.max_single_word_ratio: + results.append(False) + elif ( + len(most_common) > 1 + and ((most_common[0][1] + most_common[1][1]) / total_words) > config.max_top_two_word_ratio + ): + results.append(False) + else: + results.append(True) + return results + + return dataset.filter( + 
exceeds_word_frequency_threshold, + num_proc=config.num_proc, + batched=True, + batch_size=config.batch_size, + ) + + +def apply_binary_content_filter_processor( + config: BinaryContentFilterProcessorConfig, dataset: datasets.Dataset +) -> datasets.Dataset: + def is_binary(batch): + return [ + not sum(1 for char in text if char.isprintable()) / len(text) < config.max_bin_ratio + for text in batch[config.field] + ] + + return dataset.filter(is_binary, num_proc=config.num_proc, batched=True, batch_size=config.batch_size) + + +def apply_numerical_content_filter_processor( + config: NumericalContentFilterProcessorConfig, dataset: datasets.Dataset +) -> datasets.Dataset: + def exceeds_numeric_threshold(batch): + results = [] + number_pattern = re.compile(NUMBER_PATTERN) + for text in batch[config.field]: + tokens = number_pattern.findall(text) + results.append((len(tokens) / max(1, len(text.split()))) <= config.max_numeric_token_ratio) + return results + + return dataset.filter( + exceeds_numeric_threshold, num_proc=config.num_proc, batched=True, batch_size=config.batch_size + ) + + +def apply_pii_redaction_processor( + config: PiiRedactionProcessorConfig, + distributed_condig: DatasetPreparatorDistributedConfig, + dataset: datasets.Dataset, +) -> datasets.Dataset: + # TODO: check if multiprocessing is possible + # TODO: manage explicit model download and loading as now it + # internally install a python package which is not transferable to workrs + + from presidio_analyzer import AnalyzerEngine + from presidio_anonymizer import AnonymizerEngine + + analyzer = AnalyzerEngine() + anonymizer = AnonymizerEngine() + + def redact_pii(batch): + results = [] + for text in batch[config.field]: + entities = analyzer.analyze( + text=text, entities=["PERSON", "EMAIL_ADDRESS", "PHONE_NUMBER", "CREDIT_CARD"], language="en" + ) + if config.redaction_method == "remove": + for result in reversed(entities): + text = text[: result.start] + "" + text[result.end :] + elif config.redaction_method == "mask": + text = anonymizer.anonymize(text, entities).text + else: + raise ValueError(f"Unkown redaction method: {config.redaction_method}") + results.append(text) + return {config.field: results} + + return dataset.map(redact_pii, num_proc=None, batched=True, batch_size=config.batch_size) + + +def apply_malware_removal_processor( + config: MalwareRemovalProcessorConfig, dataset: datasets.Dataset +) -> datasets.Dataset: + # TODO: this is not working, scan_bytes does not exist. 
+ # Rewrite either with downloading virus definitions file, + # loading dataset and running a file or use clamav directly + import clamav + + def is_malicious(batch): + return [not clamav.scan_bytes(text.encode()) for text in batch[config.field]] + + return dataset.filter(is_malicious, num_proc=config.num_proc, batched=True, batch_size=config.batch_size) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 3f970e91d..a643a4120 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -28,6 +28,8 @@ from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum +from fast_llm.data.preparator.gpt_memmap.hf_processors.configs import HFProcessor +from fast_llm.data.preparator.gpt_memmap.hf_processors.processor_metrics_logger import ProcessorMetricsLogger logger = logging.getLogger(__name__) @@ -222,7 +224,22 @@ def run(self) -> None: tokenize_fn = self._tokenize_batch # Process dataset before tokenizing - dataset = self._config.processors.apply(dataset) + metrics = [] + pml = ProcessorMetricsLogger() + for processor_config_field_name in self._config.processors.order: + processor: HFProcessor = self._config.processors.get_processor_types_map()[processor_config_field_name]( + config=getattr(self._config.processors, processor_config_field_name), + distributed_config=self._config.distributed, + ) + pml.start() + dataset = processor.apply(dataset) + processor_metrics = pml.stop(dataset, processor.config.human_readable_name) + metrics.append(processor_metrics) + if self._config.distributed.rank == 0: + logger.info(ProcessorMetricsLogger.format(processor_metrics)) + + if self._config.distributed.rank == 0: + ProcessorMetricsLogger.save_as_yaml(pathlib.Path(self._config.output_path) / "processors_log", metrics) # Tokenize the dataset in parallel tokenized_dataset = dataset.map( diff --git a/fast_llm/data/preparator/hf_processors/configs/agregator.py b/fast_llm/data/preparator/hf_processors/configs/agregator.py deleted file mode 100644 index f3ce4e49f..000000000 --- a/fast_llm/data/preparator/hf_processors/configs/agregator.py +++ /dev/null @@ -1,21 +0,0 @@ -import datasets - -from fast_llm.data.preparator.hf_processors.configs.base import Applicable, ShardProcessorConfig -from fast_llm.config import Field, Config, config_class - -from fast_llm.data.preparator.hf_processors.configs.doc_length_filter import DocLengthFilterConfig - -def default_processors(): - """Default processors to apply""" - return [DocLengthFilterConfig()] - - -@config_class -class AgregatorConfig(Config, Applicable): - steps: list[ShardProcessorConfig] = Field(default_factory=default_processors) - - def apply(self, dataset: datasets.Dataset) -> datasets.Dataset: - from fast_llm.data.preparator.hf_processors.implementations.agregator import apply - return apply(self, dataset) - - diff --git a/fast_llm/data/preparator/hf_processors/configs/base.py b/fast_llm/data/preparator/hf_processors/configs/base.py deleted file mode 100644 index a77a76a00..000000000 --- a/fast_llm/data/preparator/hf_processors/configs/base.py +++ /dev/null @@ -1,19 +0,0 @@ -import abc -import typing -import datasets - -from fast_llm.config import TypeableConfig, config_class -from fast_llm.utils import Registry - - -class Applicable: - @abc.abstractmethod - def apply(self, dataset: datasets.Dataset) -> 
datasets.Dataset: - raise NotImplementedError - - -@config_class() -class ShardProcessorConfig(TypeableConfig, Applicable): - _registry: typing.ClassVar[Registry[str, type["ShardProcessorConfig"]] | None] = Registry( - "ShardProcessorConfig", {} - ) diff --git a/fast_llm/data/preparator/hf_processors/configs/doc_length_filter.py b/fast_llm/data/preparator/hf_processors/configs/doc_length_filter.py deleted file mode 100644 index a344ec5ba..000000000 --- a/fast_llm/data/preparator/hf_processors/configs/doc_length_filter.py +++ /dev/null @@ -1,22 +0,0 @@ -import abc -import typing -import datasets - -from fast_llm.data.preparator.hf_processors.configs.base import Applicable, ShardProcessorConfig -from fast_llm.config import Field, config_class - - -@config_class -class DocLengthFilterConfig(ShardProcessorConfig): - _abstract: typing.ClassVar[bool] = False - type_: typing.ClassVar[str | None] = "length_filter" - - field: str = Field(default='text') - min_length_chars: int = Field(default=0) - max_length_chars: int = Field(default=1_000_000) - - def apply(self, dataset: datasets.Dataset) -> datasets.Dataset: - from fast_llm.data.preparator.hf_processors.implementations.doc_length_filter import apply - return apply(self, dataset) - - diff --git a/fast_llm/data/preparator/hf_processors/implementations/agregator.py b/fast_llm/data/preparator/hf_processors/implementations/agregator.py deleted file mode 100644 index 9cb3984cc..000000000 --- a/fast_llm/data/preparator/hf_processors/implementations/agregator.py +++ /dev/null @@ -1,11 +0,0 @@ -import datasets - -from fast_llm.data.preparator.hf_processors.configs.agregator import AgregatorConfig - -def apply(config: AgregatorConfig, dataset: datasets.Dataset) -> datasets.Dataset: - # do something before applyting each processor - for step in config.steps: - dataset = step.apply(dataset) - # compute metrics - # save meterics, from all ranks? 
- return dataset \ No newline at end of file diff --git a/fast_llm/data/preparator/hf_processors/implementations/doc_length_filter.py b/fast_llm/data/preparator/hf_processors/implementations/doc_length_filter.py deleted file mode 100644 index 7e0a6a8ae..000000000 --- a/fast_llm/data/preparator/hf_processors/implementations/doc_length_filter.py +++ /dev/null @@ -1,7 +0,0 @@ -import datasets - -from fast_llm.data.preparator.hf_processors.configs.doc_length_filter import DocLengthFilterConfig - -def apply(config: DocLengthFilterConfig, dataset: datasets.Dataset) -> datasets.Dataset: - # do dataset.filter eliminating too long or too short docs - return dataset \ No newline at end of file diff --git a/tests/data/test_prepare_hf_processors.py b/tests/data/test_prepare_hf_processors.py new file mode 100644 index 000000000..bfaf34245 --- /dev/null +++ b/tests/data/test_prepare_hf_processors.py @@ -0,0 +1,86 @@ +import datasets + +from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig +from fast_llm.data.preparator.gpt_memmap.hf_processors.configs import ( + DocLengthFilterProcessorConfig, + NGramRepetitionFilterProcessorConfig, + FrequencyBasedFilterProcessorConfig, + BinaryContentFilterProcessorConfig, + NumericalContentFilterProcessorConfig, + PiiRedactionProcessorConfig, + MalwareRemovalProcessorConfig, + DocLengthFilterProcessor, + NGramRepetitionFilterProcessor, + FrequencyBasedFilterProcessor, + BinaryContentFilterProcessor, + NumericalContentFilterProcessor, + PiiRedactionProcessor, + MalwareRemovalProcessor, +) + + +def create_test_dataset(data): + return datasets.Dataset.from_dict({"text": data}) + + +def test_doc_length_filter_processor(): + dataset = create_test_dataset(["short", "this is a medium length sentence", "this is a very long text" * 100]) + config = DocLengthFilterProcessorConfig(min_length_chars=10, max_length_chars=50, field="text") + processor = DocLengthFilterProcessor(config, DatasetPreparatorDistributedConfig()) + filtered_dataset = processor.apply(dataset) + assert len(filtered_dataset) == 1 # Only one entry should match the criteria + + +def test_ngram_repetition_filter_processor(): + dataset = create_test_dataset( + ["word word word", "word word word word", "unique words here", "repeat repeat repeat repeat repeat"] + ) + config = NGramRepetitionFilterProcessorConfig(n=2, max_repetitions=2, field="text") + processor = NGramRepetitionFilterProcessor(config, DatasetPreparatorDistributedConfig()) + filtered_dataset = processor.apply(dataset) + assert len(filtered_dataset) == 2 # Only "word word word" and "unique words here" should remain + + +def test_frequency_based_filter_processor(): + dataset = create_test_dataset(["hello hello hello world", "this is fine just because", "spam spam spam spam spam"]) + config = FrequencyBasedFilterProcessorConfig(max_single_word_ratio=0.4, max_top_two_word_ratio=0.6, field="text") + processor = FrequencyBasedFilterProcessor(config, DatasetPreparatorDistributedConfig()) + filtered_dataset = processor.apply(dataset) + assert len(filtered_dataset) == 1 # Only "this is fine" should remain + + +def test_binary_content_filter_processor(): + dataset = create_test_dataset(["hello world", b"\x00\x00\x01\x02bin".decode("utf8"), "normal text"]) + config = BinaryContentFilterProcessorConfig(max_bin_ratio=0.5, field="text") + processor = BinaryContentFilterProcessor(config, DatasetPreparatorDistributedConfig()) + filtered_dataset = processor.apply(dataset) + assert len(filtered_dataset) == 2 # Binary 
data should be removed + + +def test_numerical_content_filter_processor(): + dataset = create_test_dataset( + ["123 456 789", "some words and 123", "almost all numbers 123 456 789 101112 131415"] + ) + config = NumericalContentFilterProcessorConfig(max_numeric_token_ratio=0.5, field="text") + processor = NumericalContentFilterProcessor(config, DatasetPreparatorDistributedConfig()) + filtered_dataset = processor.apply(dataset) + assert len(filtered_dataset) == 1 # Only "some words and 123" should remain + + +# TODO: Make optional conditioned on library installed +def test_pii_redaction_processor(): + dataset = create_test_dataset(["My name is John Doe", "Contact me at john@example.com", "This is safe text"]) + config = PiiRedactionProcessorConfig(redaction_method="remove", field="text") + processor = PiiRedactionProcessor(config, DatasetPreparatorDistributedConfig()) + processed_dataset = processor.apply(dataset) + assert "John Doe" not in processed_dataset["text"] + assert "john@example.com" not in processed_dataset["text"] + + +# TODO: Make optional conditioned on library installed +# def test_malware_removal_processor(): +# dataset = create_test_dataset(["malicious_code();", "safe text", "virus_payload();"]) +# config = MalwareRemovalProcessorConfig(field="text") +# processor = MalwareRemovalProcessor(config, DatasetPreparatorDistributedConfig()) +# filtered_dataset = processor.apply(dataset) +# assert len(filtered_dataset) == 1 # Only "safe text" should remain
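
For reviewers, a minimal end-to-end sketch of driving the new processor pipeline outside of `prepare.py`, following the loop added there. It assumes the processor config classes accept keyword arguments the way the new unit tests construct them, and that `ProcessorsConfig` can be built the same way; the sample documents, thresholds, and the `text` field name are illustrative and not part of the patch.

    import datasets

    from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig
    from fast_llm.data.preparator.gpt_memmap.hf_processors.configs import (
        DocLengthFilterProcessorConfig,
        NumericalContentFilterProcessorConfig,
        ProcessorsConfig,
    )

    # Illustrative in-memory dataset; a real run uses the dataset loaded by the preparator.
    dataset = datasets.Dataset.from_dict(
        {"text": ["ok length document with words", "1 2 3 4 5 6 7 8 9", "x"]}
    )

    # Configure two steps and run them in the configured order (assumed keyword construction).
    processors_config = ProcessorsConfig(
        doc_length=DocLengthFilterProcessorConfig(field="text", min_length_chars=5, max_length_chars=10_000),
        numerical=NumericalContentFilterProcessorConfig(field="text", max_numeric_token_ratio=0.5),
        order=["doc_length", "numerical"],
    )
    distributed_config = DatasetPreparatorDistributedConfig()

    for name in processors_config.order:
        processor_config = getattr(processors_config, name)
        if not processor_config.use_processor:
            # `prepare.py` currently instantiates every step in `order`; skipping disabled ones here
            # reflects the apparent intent of the `use_processor` flag.
            continue
        processor = processors_config.get_processor_types_map()[name](
            config=processor_config, distributed_config=distributed_config
        )
        dataset = processor.apply(dataset)

    print(len(dataset))  # 1: the too-short and the numeric-heavy documents are filtered out.

In `prepare.py` the same loop runs before tokenization and additionally wraps each step with `ProcessorMetricsLogger`, writing the collected metrics to `processors_log.yaml` on rank 0.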