Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion data_juicer/core/data/ray_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,9 @@ def process_batch_arrow(table: pyarrow.Table):
process_batch_arrow, batch_format="pyarrow", batch_size=DEFAULT_BATCH_SIZE
)
cached_columns.add(Fields.stats)
if op.use_ray_actor():
prepare_for_ray_map_batches = getattr(op, "_prepare_for_ray_map_batches", None)
use_instance_for_ray_tasks = bool(prepare_for_ray_map_batches and prepare_for_ray_map_batches())
if op.use_ray_actor() and not use_instance_for_ray_tasks:
compute = get_compute_strategy(op.__class__, concurrency=op.num_proc)
self.data = self.data.map_batches(
op.__class__,
Expand All @@ -301,6 +303,8 @@ def process_batch_arrow(table: pyarrow.Table):
compute=compute,
runtime_env=op.runtime_env,
)
if use_instance_for_ray_tasks:
self.data = self.data.materialize()
if op.stats_export_path is not None:
self.data.write_json(op.stats_export_path, force_ascii=False)
# Wrap process method with tracer for sample-level collection
Expand Down
9 changes: 9 additions & 0 deletions data_juicer/ops/deduplicator/ray_basic_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def _ensure_actors(self):
RemoteDedupSet = self._RemoteDedupSet or get_remote_dedup_set()
self._dedup_sets = [RemoteDedupSet.remote() for _ in range(self.dedup_set_num)]

def prepare_for_ray_execution(self):
    """Make sure the shared dedup-set actors exist ahead of Ray execution.

    Delegates to ``_ensure_actors`` so the actor handles are already
    attached to this backend before it is serialized into Ray tasks.
    """
    self._ensure_actors()

def is_unique(self, md5_value: str):
self._ensure_actors()
dedup_set_id = int.from_bytes(md5_value.encode(), byteorder="little") % MERSENNE_PRIME % self.dedup_set_num
Expand Down Expand Up @@ -133,6 +137,11 @@ def __init__(
else:
raise ValueError(f"Unknown backend: {backend}")

def _prepare_for_ray_map_batches(self):
    """Prepare this op's backend before it is handed to Ray ``map_batches``.

    An actor-based backend gets its shared actors created up front; any
    other backend needs no preparation here.

    Returns:
        bool: always ``True``, regardless of the backend type.
    """
    backend = self.backend
    if isinstance(backend, ActorBackend):
        backend.prepare_for_ray_execution()
    return True

def calculate_hash(self, sample, context=False):
    """Compute the hash value of ``sample``.

    This implementation is only a placeholder: it raises unconditionally.
    Presumably concrete deduplicators override it -- confirm against the
    subclasses.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
Expand Down
194 changes: 191 additions & 3 deletions tests/ops/deduplicator/test_ray_document_deduplicator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from unittest.mock import patch

from data_juicer.core.data import NestedDataset as Dataset

Expand All @@ -9,8 +10,39 @@

class RayDocumentDeduplicatorTest(DataJuicerTestCaseBase):

def _run_ray_cross_block_dedup(self, samples, op):
    """Run ``op`` over ``samples`` in a multi-block Ray dataset.

    Returns every row left in the dataset after processing.
    """
    ray_dataset = self._build_ray_cross_block_dataset(samples)
    ray_dataset.process([op])
    surviving_rows = ray_dataset.data.take_all()
    return surviving_rows

def _build_ray_cross_block_dataset(self, samples):
    """Wrap ``samples`` in a ``RayDataset`` with one requested block per row.

    Each row is given an empty stats field (unless the sample already
    carries one), and automatic op parallelism is switched off.
    """
    import ray

    from data_juicer.core.data.ray_dataset import RayDataset
    from data_juicer.utils.constant import Fields

    rows = []
    for sample in samples:
        row = {Fields.stats: {}}
        row.update(sample)
        rows.append(row)

    # Ask Ray for as many blocks as rows so duplicates land in different blocks.
    ray_data = ray.data.from_items(rows, override_num_blocks=len(rows))
    return RayDataset(
        ray_data,
        cfg={'auto_op_parallelism': False},
        auto_op_parallelism=False,
    )

def _run_doc_dedup(self, dataset: Dataset, target_list, op):
    """Run ``op`` once through a ``RayDataset`` and compare with the target.

    Args:
        dataset: input dataset; its rows are copied into a Ray dataset.
        target_list: expected rows, each a dict keyed by ``'text'``.
            Sorted in place before the comparison.
        op: the deduplicator under test; only ``op.text_key`` is kept
            from each output row.
    """
    import ray

    from data_juicer.core.data.ray_dataset import RayDataset

    ray_dataset = RayDataset(
        ray.data.from_items(dataset.to_list()),
        cfg={'auto_op_parallelism': False},
        auto_op_parallelism=False,
    )
    # The op is executed exactly once: a deduplicator keeps state between
    # runs, so a second pass over the same data would skew the result.
    ray_dataset.process([op])
    res_list = [
        {op.text_key: sample[op.text_key]}
        for sample in ray_dataset.data.take_all()
    ]
    # Row order coming back from Ray is not relied upon; sort both sides.
    res_list.sort(key=lambda x: x['text'])
    target_list.sort(key=lambda x: x['text'])
    self.assertEqual(res_list, target_list)
Expand Down Expand Up @@ -47,7 +79,14 @@ def test_english_deduplication(self):
'This paper proposed a novel method on LLM pretraining.'
}]
dataset = self.generate_dataset(ds_list)
op = RayDocumentDeduplicator(lowercase=False, ignore_non_character=False)
op = RayDocumentDeduplicator(
lowercase=False,
ignore_non_character=False,
dedup_set_num=1,
batch_size=1,
num_proc=2,
auto_op_parallelism=False,
)
self._run_doc_dedup(dataset, tgt_list, op)

@TEST_TAG("ray")
Expand Down Expand Up @@ -94,9 +133,158 @@ def test_chinese_deduplication(self):
},
]
dataset = self.generate_dataset(ds_list)
op = RayDocumentDeduplicator(lowercase=False, ignore_non_character=False)
op = RayDocumentDeduplicator(
lowercase=False,
ignore_non_character=False,
dedup_set_num=1,
batch_size=1,
num_proc=2,
auto_op_parallelism=False,
)
self._run_doc_dedup(dataset, tgt_list, op)

@TEST_TAG("ray")
def test_ray_actor_backend_deduplicates_across_blocks(self):
    """Eight identical rows spread over Ray blocks must collapse to one."""
    samples = [{'text': 'duplicate across ray blocks'} for _ in range(8)]
    op = RayDocumentDeduplicator(
        lowercase=False,
        ignore_non_character=False,
        dedup_set_num=1,
        batch_size=1,
        num_proc=4,
        auto_op_parallelism=False,
    )

    survivors = self._run_ray_cross_block_dedup(samples, op)

    self.assertEqual(len(survivors), 1)
    self.assertEqual(survivors[0]['text'], 'duplicate across ray blocks')

@TEST_TAG("ray")
def test_ray_actor_execution_mode_still_shares_dedup_sets(self):
    """With ``ray_execution_mode='actor'`` duplicates still collapse to one."""
    samples = [{'text': 'duplicate with actor execution mode'} for _ in range(8)]
    op = RayDocumentDeduplicator(
        lowercase=False,
        ignore_non_character=False,
        dedup_set_num=1,
        batch_size=1,
        num_proc=4,
        auto_op_parallelism=False,
        ray_execution_mode='actor',
    )

    survivors = self._run_ray_cross_block_dedup(samples, op)

    self.assertEqual(len(survivors), 1)
    self.assertEqual(survivors[0]['text'], 'duplicate with actor execution mode')

@TEST_TAG("ray")
def test_ray_basic_deduplicator_subclasses_share_dedup_sets(self):
    """Image and video deduplicators must also dedup across Ray blocks."""
    from data_juicer.ops.deduplicator.ray_image_deduplicator import RayImageDeduplicator
    from data_juicer.ops.deduplicator.ray_video_deduplicator import RayVideoDeduplicator

    # Constructor arguments shared by every operator under test.
    op_kwargs = {
        'dedup_set_num': 1,
        'batch_size': 1,
        'num_proc': 4,
        'auto_op_parallelism': False,
    }
    cases = (
        (RayImageDeduplicator, {'images': []}),
        (RayVideoDeduplicator, {'videos': []}),
    )
    for op_cls, sample in cases:
        with self.subTest(op_cls=op_cls.__name__):
            op = op_cls(**op_kwargs)

            survivors = self._run_ray_cross_block_dedup([sample] * 8, op)

            self.assertEqual(len(survivors), 1)

@TEST_TAG("ray")
def test_repeated_execution_keeps_materialized_dedup_result(self):
    """Reading the processed dataset twice must yield the same single row."""
    samples = [
        {'text': 'duplicate across repeated executions'}
        for _ in range(8)
    ]
    op = RayDocumentDeduplicator(
        lowercase=False,
        ignore_non_character=False,
        dedup_set_num=1,
        batch_size=1,
        num_proc=4,
        auto_op_parallelism=False,
    )
    dataset = self._build_ray_cross_block_dataset(samples)

    dataset.process([op])

    # First read: count the rows. Second read: fetch them.
    self.assertEqual(dataset.data.count(), 1)
    rows = dataset.data.take_all()

    self.assertEqual(len(rows), 1)
    self.assertEqual(rows[0]['text'], 'duplicate across repeated executions')

@TEST_TAG("ray")
def test_stats_export_does_not_consume_dedup_state_before_filter(self):
    """Exporting stats must not change the final deduplicated result."""

    def counting_write_json(dataset, *args, **kwargs):
        # Stand-in for ``write_json``: calls ``count()`` on the dataset
        # instead of writing any file.
        return dataset.count()

    samples = [{'text': 'duplicate with stats export'} for _ in range(8)]

    with patch('ray.data.Dataset.write_json', counting_write_json):
        op = RayDocumentDeduplicator(
            lowercase=False,
            ignore_non_character=False,
            dedup_set_num=1,
            batch_size=1,
            num_proc=4,
            auto_op_parallelism=False,
            stats_export_path='mock_stats_export_path',
        )
        dataset = self._build_ray_cross_block_dataset(samples)

        dataset.process([op])
        rows = dataset.data.take_all()

        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0]['text'], 'duplicate with stats export')

def test_prepare_for_ray_execution_reuses_existing_actor_handles(self):
    """A second ``prepare_for_ray_execution`` call must not create new actors."""
    from data_juicer.ops.deduplicator.ray_basic_deduplicator import ActorBackend

    class CountingDedupSet:
        """Fake remote class that counts how often ``remote()`` is invoked."""

        calls = 0

        @classmethod
        def remote(cls):
            cls.calls += 1
            return object()

    backend = ActorBackend(dedup_set_num=2, RemoteDedupSet=CountingDedupSet)

    backend.prepare_for_ray_execution()
    handles_after_first_call = backend._dedup_sets
    backend.prepare_for_ray_execution()

    # Same list object, and exactly one ``remote()`` call per dedup set.
    self.assertIs(backend._dedup_sets, handles_after_first_call)
    self.assertEqual(CountingDedupSet.calls, 2)

def test_redis_backend_requests_ray_materialization(self):
    """``_prepare_for_ray_map_batches`` is truthy for a Redis backend too."""
    from data_juicer.ops.deduplicator.ray_basic_deduplicator import RedisBackend

    op = RayDocumentDeduplicator(
        lowercase=False,
        ignore_non_character=False,
        dedup_set_num=1,
        batch_size=1,
        num_proc=4,
        auto_op_parallelism=False,
    )
    # Allocate the backend without running ``__init__``.
    redis_backend = RedisBackend.__new__(RedisBackend)
    op.backend = redis_backend

    prepared = op._prepare_for_ray_map_batches()

    self.assertTrue(prepared)


# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Loading