diff --git a/data_juicer/core/data/ray_dataset.py b/data_juicer/core/data/ray_dataset.py index 538a90bbb9..7503fdd91a 100644 --- a/data_juicer/core/data/ray_dataset.py +++ b/data_juicer/core/data/ray_dataset.py @@ -275,7 +275,9 @@ def process_batch_arrow(table: pyarrow.Table): process_batch_arrow, batch_format="pyarrow", batch_size=DEFAULT_BATCH_SIZE ) cached_columns.add(Fields.stats) - if op.use_ray_actor(): + prepare_for_ray_map_batches = getattr(op, "_prepare_for_ray_map_batches", None) + use_instance_for_ray_tasks = bool(prepare_for_ray_map_batches and prepare_for_ray_map_batches()) + if op.use_ray_actor() and not use_instance_for_ray_tasks: compute = get_compute_strategy(op.__class__, concurrency=op.num_proc) self.data = self.data.map_batches( op.__class__, @@ -301,6 +303,8 @@ def process_batch_arrow(table: pyarrow.Table): compute=compute, runtime_env=op.runtime_env, ) + if use_instance_for_ray_tasks: + self.data = self.data.materialize() if op.stats_export_path is not None: self.data.write_json(op.stats_export_path, force_ascii=False) # Wrap process method with tracer for sample-level collection diff --git a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py index b7b50172a4..1ff09a8886 100644 --- a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py @@ -74,6 +74,10 @@ def _ensure_actors(self): RemoteDedupSet = self._RemoteDedupSet or get_remote_dedup_set() self._dedup_sets = [RemoteDedupSet.remote() for _ in range(self.dedup_set_num)] + def prepare_for_ray_execution(self): + """Create shared actors before this backend is serialized to Ray tasks.""" + self._ensure_actors() + def is_unique(self, md5_value: str): self._ensure_actors() dedup_set_id = int.from_bytes(md5_value.encode(), byteorder="little") % MERSENNE_PRIME % self.dedup_set_num @@ -133,6 +137,11 @@ def __init__( else: raise ValueError(f"Unknown backend: {backend}") + def _prepare_for_ray_map_batches(self): + if isinstance(self.backend, ActorBackend): + self.backend.prepare_for_ray_execution() + return True + def calculate_hash(self, sample, context=False): """Calculate hash value for the sample.""" raise NotImplementedError diff --git a/tests/ops/deduplicator/test_ray_document_deduplicator.py b/tests/ops/deduplicator/test_ray_document_deduplicator.py index e8bf23183a..12d15c5d78 100644 --- a/tests/ops/deduplicator/test_ray_document_deduplicator.py +++ b/tests/ops/deduplicator/test_ray_document_deduplicator.py @@ -1,4 +1,5 @@ import unittest +from unittest.mock import patch from data_juicer.core.data import NestedDataset as Dataset @@ -9,8 +10,39 @@ class RayDocumentDeduplicatorTest(DataJuicerTestCaseBase): + def _run_ray_cross_block_dedup(self, samples, op): + dataset = self._build_ray_cross_block_dataset(samples) + dataset.process([op]) + return dataset.data.take_all() + + def _build_ray_cross_block_dataset(self, samples): + import ray + + from data_juicer.core.data.ray_dataset import RayDataset + from data_juicer.utils.constant import Fields + + ds_list = [{Fields.stats: {}, **sample} for sample in samples] + return RayDataset( + ray.data.from_items(ds_list, override_num_blocks=len(ds_list)), + cfg={'auto_op_parallelism': False}, + auto_op_parallelism=False, + ) + def _run_doc_dedup(self, dataset: Dataset, target_list, op): - res_list = self.run_single_op(dataset, op, [op.text_key]) + import ray + + from data_juicer.core.data.ray_dataset import RayDataset + + dataset = RayDataset( + ray.data.from_items(dataset.to_list()), + cfg={'auto_op_parallelism': False}, + auto_op_parallelism=False, + ) + dataset.process([op]) + res_list = [ + {op.text_key: sample[op.text_key]} + for sample in dataset.data.take_all() + ] res_list.sort(key=lambda x: x['text']) target_list.sort(key=lambda x: x['text']) self.assertEqual(res_list, target_list) @@ -47,7 +79,14 @@ def test_english_deduplication(self): 'This paper proposed a novel method on LLM pretraining.' }] dataset = self.generate_dataset(ds_list) - op = RayDocumentDeduplicator(lowercase=False, ignore_non_character=False) + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=2, + auto_op_parallelism=False, + ) self._run_doc_dedup(dataset, tgt_list, op) @TEST_TAG("ray") @@ -94,9 +133,158 @@ def test_chinese_deduplication(self): }, ] dataset = self.generate_dataset(ds_list) - op = RayDocumentDeduplicator(lowercase=False, ignore_non_character=False) + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=2, + auto_op_parallelism=False, + ) self._run_doc_dedup(dataset, tgt_list, op) + @TEST_TAG("ray") + def test_ray_actor_backend_deduplicates_across_blocks(self): + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=4, + auto_op_parallelism=False, + ) + + res_list = self._run_ray_cross_block_dedup( + [{'text': 'duplicate across ray blocks'} for _ in range(8)], + op, + ) + + self.assertEqual(len(res_list), 1) + self.assertEqual(res_list[0]['text'], 'duplicate across ray blocks') + + @TEST_TAG("ray") + def test_ray_actor_execution_mode_still_shares_dedup_sets(self): + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=4, + auto_op_parallelism=False, + ray_execution_mode='actor', + ) + + res_list = self._run_ray_cross_block_dedup( + [{'text': 'duplicate with actor execution mode'} for _ in range(8)], + op, + ) + + self.assertEqual(len(res_list), 1) + self.assertEqual(res_list[0]['text'], 'duplicate with actor execution mode') + + @TEST_TAG("ray") + def test_ray_basic_deduplicator_subclasses_share_dedup_sets(self): + from data_juicer.ops.deduplicator.ray_image_deduplicator import RayImageDeduplicator + from data_juicer.ops.deduplicator.ray_video_deduplicator import RayVideoDeduplicator + + cases = [ + (RayImageDeduplicator, {'images': []}), + (RayVideoDeduplicator, {'videos': []}), + ] + for op_cls, sample in cases: + with self.subTest(op_cls=op_cls.__name__): + op = op_cls( + dedup_set_num=1, + batch_size=1, + num_proc=4, + auto_op_parallelism=False, + ) + + res_list = self._run_ray_cross_block_dedup([sample for _ in range(8)], op) + + self.assertEqual(len(res_list), 1) + + @TEST_TAG("ray") + def test_repeated_execution_keeps_materialized_dedup_result(self): + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=4, + auto_op_parallelism=False, + ) + dataset = self._build_ray_cross_block_dataset([{ + 'text': 'duplicate across repeated executions', + } for _ in range(8)]) + + dataset.process([op]) + self.assertEqual(dataset.data.count(), 1) + res_list = dataset.data.take_all() + + self.assertEqual(len(res_list), 1) + self.assertEqual(res_list[0]['text'], 'duplicate across repeated executions') + + @TEST_TAG("ray") + def test_stats_export_does_not_consume_dedup_state_before_filter(self): + def materializing_write_json(dataset, *args, **kwargs): + return dataset.count() + + with patch('ray.data.Dataset.write_json', materializing_write_json): + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=4, + auto_op_parallelism=False, + stats_export_path='mock_stats_export_path', + ) + dataset = self._build_ray_cross_block_dataset([{ + 'text': 'duplicate with stats export', + } for _ in range(8)]) + + dataset.process([op]) + res_list = dataset.data.take_all() + + self.assertEqual(len(res_list), 1) + self.assertEqual(res_list[0]['text'], 'duplicate with stats export') + + def test_prepare_for_ray_execution_reuses_existing_actor_handles(self): + from data_juicer.ops.deduplicator.ray_basic_deduplicator import ActorBackend + + class RemoteDedupSet: + calls = 0 + + @classmethod + def remote(cls): + cls.calls += 1 + return object() + + backend = ActorBackend(dedup_set_num=2, RemoteDedupSet=RemoteDedupSet) + + backend.prepare_for_ray_execution() + dedup_sets = backend._dedup_sets + backend.prepare_for_ray_execution() + + self.assertIs(backend._dedup_sets, dedup_sets) + self.assertEqual(RemoteDedupSet.calls, 2) + + def test_redis_backend_requests_ray_materialization(self): + from data_juicer.ops.deduplicator.ray_basic_deduplicator import RedisBackend + + op = RayDocumentDeduplicator( + lowercase=False, + ignore_non_character=False, + dedup_set_num=1, + batch_size=1, + num_proc=4, + auto_op_parallelism=False, + ) + op.backend = RedisBackend.__new__(RedisBackend) + + self.assertTrue(op._prepare_for_ray_map_batches()) + if __name__ == '__main__': unittest.main()