Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions data_juicer/core/data/ray_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ray
from jsonargparse import Namespace
from loguru import logger
from ray.data import ActorPoolStrategy
from ray.data._internal.util import get_compute_strategy

from data_juicer.core.data import DJDataset
Expand Down Expand Up @@ -237,9 +238,16 @@ def process_batch_arrow(table: pyarrow.Table):

try:
if op.use_ray_actor():
compute = get_compute_strategy(op.__class__, concurrency=op.num_proc)
self.data = self.data.map_batches(
op.__class__,
# Repartition right before GPU actor stage to ensure enough data blocks
# Pipeline: Read(streaming) → Repartition → GPU actors
# Note: override_num_blocks cannot be passed to map_batches for actors,
# so we repartition beforehand instead
override_num_blocks = getattr(op, 'override_num_blocks', None)
if override_num_blocks is not None:
self.data = self.data.repartition(override_num_blocks)

compute = ActorPoolStrategy(size=op.num_proc)
map_batches_kwargs = dict(
fn_args=None,
fn_kwargs=None,
fn_constructor_args=op._init_args,
Expand All @@ -251,17 +259,21 @@ def process_batch_arrow(table: pyarrow.Table):
batch_format="pyarrow",
runtime_env=op.runtime_env,
)
self.data = self.data.map_batches(op.__class__, **map_batches_kwargs)
else:
compute = get_compute_strategy(op.process, concurrency=op.num_proc)
self.data = self.data.map_batches(
op.process,
map_batches_kwargs = dict(
batch_size=batch_size,
batch_format="pyarrow",
num_cpus=op.num_cpus,
num_gpus=op.num_gpus,
compute=compute,
runtime_env=op.runtime_env,
)
override_num_blocks = getattr(op, 'override_num_blocks', None)
if override_num_blocks is not None:
map_batches_kwargs['override_num_blocks'] = override_num_blocks
self.data = self.data.map_batches(op.process, **map_batches_kwargs)
finally:
# Restore original process method
if tracer and should_trace_op(tracer, op._name) and original_process:
Expand All @@ -280,9 +292,16 @@ def process_batch_arrow(table: pyarrow.Table):
)
cached_columns.add(Fields.stats)
if op.use_ray_actor():
compute = get_compute_strategy(op.__class__, concurrency=op.num_proc)
self.data = self.data.map_batches(
op.__class__,
# Repartition AFTER CPU preprocessing but BEFORE GPU actor stage
# Pipeline: Read(streaming) → CPU preprocessing(streaming) → Repartition → GPU actors
# This allows Read → CPU to stream freely without pipeline barriers,
# then repartition splits collapsed blocks into enough pieces for GPU utilization
override_num_blocks = getattr(op, 'override_num_blocks', None)
if override_num_blocks is not None:
self.data = self.data.repartition(override_num_blocks)

compute = ActorPoolStrategy(size=op.num_proc)
map_batches_kwargs = dict(
fn_args=None,
fn_kwargs=None,
fn_constructor_args=op._init_args,
Expand All @@ -294,17 +313,21 @@ def process_batch_arrow(table: pyarrow.Table):
batch_format="pyarrow",
runtime_env=op.runtime_env,
)
self.data = self.data.map_batches(op.__class__, **map_batches_kwargs)
else:
compute = get_compute_strategy(op.compute_stats, concurrency=op.num_proc)
self.data = self.data.map_batches(
op.compute_stats,
map_batches_kwargs = dict(
batch_size=batch_size,
batch_format="pyarrow",
num_cpus=op.num_cpus,
num_gpus=op.num_gpus,
compute=compute,
runtime_env=op.runtime_env,
)
override_num_blocks = getattr(op, 'override_num_blocks', None)
if override_num_blocks is not None:
map_batches_kwargs['override_num_blocks'] = override_num_blocks
self.data = self.data.map_batches(op.compute_stats, **map_batches_kwargs)
if op.stats_export_path is not None:
self.data.write_json(op.stats_export_path, force_ascii=False)
# Wrap process method with tracer for sample-level collection
Expand Down
33 changes: 33 additions & 0 deletions data_juicer/core/elasticjuicer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
ElasticJuicer: Adaptive Resource Scheduling for Data-Juicer

A system that provides dynamic resource management and OOM prevention for
multimodal data processing pipelines.
"""

# Version string for the elasticjuicer package.
__version__ = "0.1.0"

# Core ElasticJuicer classes
from .elastic_juicer import ElasticJuicer
from .scheduler.scheduler_config import SchedulerConfig
from .scheduler.tower import Tower
from .scheduler.captain import Captain, CaptainPool
from .scheduler.micro_scheduler import MicroScheduler


# Lazy import for tuner (requires ray dependency)
def get_pbt_tuner():
    """Return the ``PBTTuner`` class, importing it only when called.

    The import is deferred to call time so that importing this package
    does not load ``.tuner.pbt_tuner`` eagerly (the tuner is noted in this
    file as requiring the ray dependency).

    Returns:
        The ``PBTTuner`` class object (not an instance).
    """
    from .tuner.pbt_tuner import PBTTuner as tuner_cls

    return tuner_cls


# Public API of the package, as exposed by ``from ... import *``.
__all__ = [
    # NOTE(review): "profiler" is not imported or defined in this module.
    # It only resolves on a star-import if a `profiler` submodule exists in
    # this package -- TODO confirm, otherwise the star-import raises
    # AttributeError.
    "profiler",
    # Classes imported eagerly at the top of this module.
    "ElasticJuicer",
    "SchedulerConfig",
    "Tower",
    "Captain",
    "CaptainPool",
    "MicroScheduler",
    # Accessor that imports the tuner lazily.
    "get_pbt_tuner",
]
Loading