Skip to content
Draft
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions data_juicer/core/data/ray_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def process_batch_arrow(table: pyarrow.Table):

try:
if op.use_ray_actor():
compute = get_compute_strategy(op.__class__, concurrency=op.num_proc)
# Use concurrency= directly for better GPU utilization
# (get_compute_strategy may limit parallelism)
self.data = self.data.map_batches(
op.__class__,
fn_args=None,
Expand All @@ -247,7 +248,7 @@ def process_batch_arrow(table: pyarrow.Table):
batch_size=batch_size,
num_cpus=op.num_cpus,
num_gpus=op.num_gpus,
compute=compute,
concurrency=op.num_proc,
batch_format="pyarrow",
runtime_env=op.runtime_env,
)
Expand Down Expand Up @@ -280,7 +281,7 @@ def process_batch_arrow(table: pyarrow.Table):
)
cached_columns.add(Fields.stats)
if op.use_ray_actor():
compute = get_compute_strategy(op.__class__, concurrency=op.num_proc)
# Use concurrency= directly for better GPU utilization
self.data = self.data.map_batches(
op.__class__,
fn_args=None,
Expand All @@ -290,7 +291,7 @@ def process_batch_arrow(table: pyarrow.Table):
batch_size=batch_size,
num_cpus=op.num_cpus,
num_gpus=op.num_gpus,
compute=compute,
concurrency=op.num_proc,
batch_format="pyarrow",
runtime_env=op.runtime_env,
)
Expand Down
20 changes: 20 additions & 0 deletions data_juicer/core/executor/concurrency_scoping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Utility for scoping op concurrency when running partitions concurrently."""


def scope_op_concurrency(op, max_concurrent_partitions: int) -> int:
    """Return the concurrency a single partition should use for this op.

    When multiple partitions run concurrently, each partition should use a
    fraction of the total GPU/actor resources to avoid over-subscription.

    Args:
        op: An operator instance with ``use_ray_actor()`` and ``num_proc``.
        max_concurrent_partitions: How many partitions will run in parallel.
            A missing or non-positive value (``None``, ``0``, negative) is
            treated as "no concurrent partitions", i.e. no scoping.

    Returns:
        The concurrency value the partition should pass through to
        ``map_batches``. ``op.num_proc`` is returned unchanged when the op
        does not run as a Ray actor, when ``num_proc`` is unset/non-positive
        (auto mode), or when there is nothing to divide by. Otherwise the
        result is ``num_proc // max_concurrent_partitions``, never below 1.
    """
    num_proc = op.num_proc
    # CPU ops or auto-mode: leave the configured value untouched.
    if not op.use_ray_actor() or not num_proc or num_proc <= 0:
        return num_proc
    # Guard the division: a zero/None partition count would raise
    # ZeroDivisionError/TypeError; with fewer than one concurrent partition
    # the op simply keeps its full concurrency.
    if not max_concurrent_partitions or max_concurrent_partitions < 1:
        return num_proc
    return max(1, num_proc // max_concurrent_partitions)
245 changes: 243 additions & 2 deletions data_juicer/core/executor/ray_executor_partitioned.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,16 +273,60 @@ def _configure_partitioning(self):
logger.warning("Legacy num_partitions detected, overriding partition configuration")

self.partition_mode = mode
self.num_partitions = num_of_partitions
self.partition_size = partition_size
self.max_size_mb = max_size_mb

# Resolve max_concurrent_partitions.
# "auto" (default) → detect from Ray cluster GPU count, fall back to 1.
# Explicit int → use as-is.
raw_max_conc = ConfigAccessor.get(partition_cfg, "max_concurrent_partitions", "auto")
self.max_concurrent_partitions = self._resolve_max_concurrent(raw_max_conc)

# Ensure we have at least as many partitions as concurrent slots,
# otherwise some GPUs would sit idle.
if self.max_concurrent_partitions > num_of_partitions:
logger.info(
f"num_of_partitions ({num_of_partitions}) < "
f"max_concurrent_partitions ({self.max_concurrent_partitions}), "
f"raising num_of_partitions to {self.max_concurrent_partitions}"
)
num_of_partitions = self.max_concurrent_partitions

self.num_partitions = num_of_partitions

if mode == "manual":
logger.info(f"Manual partition mode: using {self.num_partitions} partitions")
else: # auto mode
logger.info(f"Auto partition mode: will determine optimal partitioning based on data characteristics")
logger.info(f"Fallback partition size: {self.partition_size} samples, max {self.max_size_mb} MB")

if self.max_concurrent_partitions > 1:
logger.info(
f"Concurrent partition processing enabled: "
f"max_concurrent_partitions={self.max_concurrent_partitions}"
)

@staticmethod
def _resolve_max_concurrent(raw_value) -> int:
"""Resolve max_concurrent_partitions from config value.

* ``"auto"`` → number of GPUs visible to Ray (falls back to 1).
* An explicit int is returned as-is (minimum 1).
"""
if isinstance(raw_value, str) and raw_value.lower() == "auto":
try:
num_gpus = int(ray.cluster_resources().get("GPU", 0))
except Exception:
num_gpus = 0
Comment thread
cyruszhang marked this conversation as resolved.
Outdated
if num_gpus > 1:
logger.info(
f"Auto-detected {num_gpus} GPUs in Ray cluster, " f"setting max_concurrent_partitions={num_gpus}"
)
return num_gpus
# No GPUs or single GPU → sequential
return 1
return max(1, int(raw_value))

def _configure_auto_partitioning(self, dataset, ops):
"""Configure partitioning using the partition size optimizer for auto mode."""
try:
Expand Down Expand Up @@ -498,6 +542,10 @@ def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List):
f"{partitioning_info.total_rows} total rows"
)

# Branch: concurrent vs sequential partition processing
if self.max_concurrent_partitions > 1:
return self._process_partitions_concurrent(partitions, ops, partitioning_info)

# Process each partition separately with checkpointing
logger.info("Processing partitions with checkpointing support...")
processed_partitions = []
Expand Down Expand Up @@ -541,6 +589,193 @@ def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List):
# Return as RayDataset wrapper
return RayDataset(merged_dataset, cfg=self.cfg)

def _process_partitions_concurrent(self, partitions, ops, partitioning_info):
    """Process partitions concurrently as Ray remote tasks.

    Each non-empty partition is submitted as a Ray remote task that
    independently loads ops from config, scopes concurrency, and processes
    data with its own checkpoint manager. Results are collected in
    partition order and unioned.

    Note that every eligible partition is submitted up front; the resolved
    ``max_concurrent_partitions`` value is only used to scale down each
    op's ``num_proc`` inside the tasks, not to throttle submission.

    Args:
        partitions: The datasets produced by splitting the input; each must
            support ``count()`` and be accepted by ``RayDataset``.
        ops: The operator list of the driver; only its length is used here,
            to decide whether a partition is already fully checkpointed.
        partitioning_info: Unused by this method.

    Returns:
        A ``RayDataset`` wrapping the union of all processed partitions.
        When every partition is empty, the first (empty) input partition
        is wrapped instead.

    Raises:
        ValueError: If ``partitions`` is empty.
        Exception: Whatever a failed remote task raised, re-raised after
            being logged.
    """
    if not partitions:
        raise ValueError("No partitions to process")

    max_conc = min(self.max_concurrent_partitions, len(partitions))
    logger.info(f"Processing {len(partitions)} partitions concurrently (max_concurrent_partitions={max_conc})")

    # Serialisable values extracted from self (avoid serialising the executor)
    cfg = self.cfg
    ckpt_enabled = self.ckpt_manager.checkpoint_enabled
    ckpt_strategy = self.ckpt_manager.checkpoint_strategy
    ckpt_dir = self.ckpt_manager.ckpt_dir
    ckpt_n_ops = getattr(self.ckpt_manager, "checkpoint_n_ops", 1)
    ckpt_op_names = getattr(self.ckpt_manager, "checkpoint_op_names", [])
    op_fusion_enabled = getattr(cfg, "op_fusion", False)

    @ray.remote(num_cpus=0)
    def _process_single_partition_task(
        partition_data,
        partition_id,
        cfg,
        max_concurrent_partitions,
        ckpt_enabled,
        ckpt_strategy,
        ckpt_dir,
        ckpt_n_ops,
        ckpt_op_names,
        op_fusion_enabled,
    ):
        """Ray remote task that processes one partition end-to-end."""
        from loguru import logger as task_logger

        from data_juicer.core.data.ray_dataset import RayDataset
        from data_juicer.core.executor.concurrency_scoping import (
            scope_op_concurrency,
        )
        from data_juicer.ops import load_ops
        from data_juicer.ops.op_fusion import fuse_operators
        from data_juicer.utils.ckpt_utils import RayCheckpointManager

        task_logger.info(f"[Partition {partition_id}] Starting remote processing")

        # Re-create ops from config to avoid serialisation issues
        task_ops = load_ops(cfg.process)
        if op_fusion_enabled:
            task_ops = fuse_operators(task_ops)

        # Scope concurrency and fix actor mode for each op.
        # The remote task has no GPU, so use_cuda() returns False and
        # ops default to task mode (model reloads per batch). Force
        # actor mode for GPU ops so the model loads once per actor.
        for op in task_ops:
            if getattr(op, "num_gpus", 0) and op.num_gpus > 0:
                op.ray_execution_mode = "actor"
            op.num_proc = scope_op_concurrency(op, max_concurrent_partitions)

        # Create local checkpoint manager
        ckpt_manager = RayCheckpointManager(
            ckpt_dir=ckpt_dir,
            checkpoint_enabled=ckpt_enabled,
            checkpoint_strategy=ckpt_strategy,
            checkpoint_n_ops=ckpt_n_ops,
            checkpoint_op_names=ckpt_op_names,
        )

        # Load the latest checkpoint (if any) exactly once, then decide
        # whether it is the final result or merely a resume point.
        start_op_idx = 0
        partition_dataset = None
        latest_checkpoint = ckpt_manager.find_latest_checkpoint(partition_id)
        if latest_checkpoint:
            ckpt_op_idx = latest_checkpoint[0]
            loaded = ckpt_manager.load_checkpoint(
                ckpt_op_idx,
                latest_checkpoint[1],
                partition_id,
                cfg=cfg,
            )
            if loaded is not None:
                if ckpt_op_idx >= len(task_ops) - 1:
                    task_logger.info(f"[Partition {partition_id}] All ops checkpointed, loading from checkpoint")
                    return loaded.data.materialize()
                partition_dataset = loaded
                start_op_idx = ckpt_op_idx + 1
                task_logger.info(f"[Partition {partition_id}] Resuming from op {start_op_idx}")

        # No usable checkpoint: start from the raw partition data.
        if partition_dataset is None:
            partition_dataset = RayDataset(partition_data, cfg=cfg)

        # Process ops one-by-one with checkpointing
        for abs_idx, op in enumerate(task_ops[start_op_idx:], start=start_op_idx):
            task_logger.info(f"[Partition {partition_id}] Processing op {abs_idx}: {op._name}")
            partition_dataset = partition_dataset.process([op])

            # Checkpoint if needed
            if ckpt_manager.should_checkpoint(abs_idx, op._name):
                partition_dataset.data = partition_dataset.data.materialize()
                ckpt_manager.save_checkpoint(
                    partition_dataset.data,
                    abs_idx,
                    partition_id,
                )

        # Final materialize
        partition_dataset.data = partition_dataset.data.materialize()
        return partition_dataset.data

    # Submit tasks. ``pending`` maps partition index -> either an ObjectRef
    # (submitted task) or an already-materialised dataset (loaded from a
    # checkpoint). Skipped partitions get no entry at all.
    pending = {}
    for i, partition in enumerate(partitions):
        # Skip empty partitions to avoid wasting GPU resources
        try:
            row_count = partition.count()
        except Exception:
            row_count = -1  # can't determine, submit anyway
        if row_count == 0:
            logger.info(f"Partition {i}: empty (0 rows), skipping")
            continue

        # Check if partition is fully checkpointed before submitting
        latest_ckpt = self.ckpt_manager.find_latest_checkpoint(i)
        if latest_ckpt and latest_ckpt[0] >= len(ops) - 1:
            logger.info(f"Partition {i}: already fully checkpointed, loading from checkpoint")
            loaded = self.ckpt_manager.load_checkpoint(latest_ckpt[0], latest_ckpt[1], i, cfg=self.cfg)
            if loaded is not None:
                pending[i] = loaded.data.materialize()
                continue

        self._log_event(
            event_type=EventType.PARTITION_START,
            message=f"Starting concurrent processing of partition {i + 1}/{len(partitions)}",
            partition_id=i,
        )
        pending[i] = _process_single_partition_task.remote(
            partition,
            i,
            cfg,
            max_conc,
            ckpt_enabled,
            ckpt_strategy,
            ckpt_dir,
            ckpt_n_ops,
            ckpt_op_names,
            op_fusion_enabled,
        )

    # Collect results. Iterate over what was actually registered (dict
    # insertion order == ascending partition index) so that skipped empty
    # partitions do not trigger a KeyError.
    processed_partitions = []
    for i, result in pending.items():
        if isinstance(result, ray.ObjectRef):
            try:
                result = ray.get(result)
            except Exception as e:
                logger.error(f"Partition {i}: failed with error: {e}")
                raise
            logger.info(f"Partition {i}: completed successfully")
        processed_partitions.append(result)
        self._log_event(
            event_type=EventType.PARTITION_COMPLETE,
            message=f"Completed concurrent processing of partition {i + 1}/{len(partitions)}",
            partition_id=i,
        )

    # Union results
    logger.info("Merging concurrently processed partitions...")
    if not processed_partitions:
        # Every partition was empty, so there is nothing to union; hand
        # back one of the (empty) inputs rather than failing on index 0.
        logger.warning("All partitions were empty; returning an empty dataset")
        merged_dataset = partitions[0]
    elif len(processed_partitions) == 1:
        merged_dataset = processed_partitions[0]
    else:
        # Single variadic union instead of a chain of pairwise unions.
        first, *rest = processed_partitions
        merged_dataset = first.union(*rest)

    return RayDataset(merged_dataset, cfg=self.cfg)

def _process_with_convergence(self, dataset: RayDataset, ops: List, convergence_points: List[int]):
"""
Process dataset with convergence support for global operations.
Expand Down Expand Up @@ -954,7 +1189,13 @@ def _split_dataset_deterministic(self, dataset: RayDataset) -> tuple:
# Check for existing partitioning info (resumption case)
saved_info = self._load_partitioning_info()

# Split the dataset
# Ensure enough blocks so split() doesn't produce empty partitions.
# split() distributes by blocks — if there are fewer non-empty
# blocks than partitions, some partitions get 0 rows.
# Always repartition to num_partitions to avoid materializing the
# dataset just to check num_blocks() (which kills lazy evaluation).
dataset.data = dataset.data.repartition(self.num_partitions)

logger.info(f"Splitting dataset into {self.num_partitions} partitions (deterministic mode)...")
partitions = dataset.data.split(self.num_partitions)
logger.info(f"Created {len(partitions)} partitions")
Expand Down
Loading
Loading