diff --git a/data_juicer/core/data/ray_dataset.py b/data_juicer/core/data/ray_dataset.py
index 2d8b198565..acca1a8b8f 100644
--- a/data_juicer/core/data/ray_dataset.py
+++ b/data_juicer/core/data/ray_dataset.py
@@ -237,7 +237,8 @@ def process_batch_arrow(table: pyarrow.Table):
try:
if op.use_ray_actor():
- compute = get_compute_strategy(op.__class__, concurrency=op.num_proc)
+ # Use concurrency= directly for better GPU utilization
+ # (get_compute_strategy may limit parallelism)
self.data = self.data.map_batches(
op.__class__,
fn_args=None,
@@ -247,7 +248,7 @@ def process_batch_arrow(table: pyarrow.Table):
batch_size=batch_size,
num_cpus=op.num_cpus,
num_gpus=op.num_gpus,
- compute=compute,
+ concurrency=op.num_proc,
batch_format="pyarrow",
runtime_env=op.runtime_env,
)
@@ -280,7 +281,7 @@ def process_batch_arrow(table: pyarrow.Table):
)
cached_columns.add(Fields.stats)
if op.use_ray_actor():
- compute = get_compute_strategy(op.__class__, concurrency=op.num_proc)
+ # Use concurrency= directly for better GPU utilization
self.data = self.data.map_batches(
op.__class__,
fn_args=None,
@@ -290,7 +291,7 @@ def process_batch_arrow(table: pyarrow.Table):
batch_size=batch_size,
num_cpus=op.num_cpus,
num_gpus=op.num_gpus,
- compute=compute,
+ concurrency=op.num_proc,
batch_format="pyarrow",
runtime_env=op.runtime_env,
)
diff --git a/data_juicer/core/executor/concurrency_scoping.py b/data_juicer/core/executor/concurrency_scoping.py
new file mode 100644
index 0000000000..906f8cbe37
--- /dev/null
+++ b/data_juicer/core/executor/concurrency_scoping.py
@@ -0,0 +1,20 @@
+"""Utility for scoping op concurrency when running partitions concurrently."""
+
+
+def scope_op_concurrency(op, max_concurrent_partitions: int) -> int:
+ """Return the concurrency a single partition should use for this op.
+
+ When multiple partitions run concurrently, each partition should use a
+ fraction of the op's ``num_proc`` so the partitions together do not over-subscribe actors.
+
+ Args:
+ op: An operator instance with ``use_ray_actor()`` and ``num_proc``.
+ max_concurrent_partitions: How many partitions will run in parallel; used as a divisor, so it must be >= 1.
+
+ Returns:
+ ``op.num_proc`` unchanged (possibly ``None`` or a non-positive value) for non-actor ops or when
+ ``num_proc`` is unset; otherwise ``op.num_proc // max_concurrent_partitions``, floored at 1.
+ """
+ if not op.use_ray_actor() or not op.num_proc or op.num_proc <= 0:
+ return op.num_proc # non-actor ops or unset (auto) num_proc pass through unchanged
+ return max(1, op.num_proc // max_concurrent_partitions) # never scale an actor op below one worker
diff --git a/data_juicer/core/executor/ray_executor_partitioned.py b/data_juicer/core/executor/ray_executor_partitioned.py
index ddcb1fc442..ff63f01e00 100644
--- a/data_juicer/core/executor/ray_executor_partitioned.py
+++ b/data_juicer/core/executor/ray_executor_partitioned.py
@@ -273,16 +273,61 @@ def _configure_partitioning(self):
logger.warning("Legacy num_partitions detected, overriding partition configuration")
self.partition_mode = mode
- self.num_partitions = num_of_partitions
self.partition_size = partition_size
self.max_size_mb = max_size_mb
+ # Resolve max_concurrent_partitions.
+ # "auto" (default) → detect from Ray cluster GPU count, fall back to 1.
+ # Explicit int → use as-is.
+ raw_max_conc = ConfigAccessor.get(partition_cfg, "max_concurrent_partitions", "auto")
+ self.max_concurrent_partitions = self._resolve_max_concurrent(raw_max_conc)
+
+ # Ensure we have at least as many partitions as concurrent slots,
+ # otherwise some GPUs would sit idle.
+ if self.max_concurrent_partitions > num_of_partitions:
+ logger.info(
+ f"num_of_partitions ({num_of_partitions}) < "
+ f"max_concurrent_partitions ({self.max_concurrent_partitions}), "
+ f"raising num_of_partitions to {self.max_concurrent_partitions}"
+ )
+ num_of_partitions = self.max_concurrent_partitions
+
+ self.num_partitions = num_of_partitions
+
if mode == "manual":
logger.info(f"Manual partition mode: using {self.num_partitions} partitions")
else: # auto mode
logger.info(f"Auto partition mode: will determine optimal partitioning based on data characteristics")
logger.info(f"Fallback partition size: {self.partition_size} samples, max {self.max_size_mb} MB")
+ if self.max_concurrent_partitions > 1:
+ logger.info(
+ f"Concurrent partition processing enabled: "
+ f"max_concurrent_partitions={self.max_concurrent_partitions}"
+ )
+
+ @staticmethod
+ def _resolve_max_concurrent(raw_value) -> int:
+ """Resolve ``max_concurrent_partitions`` from its raw config value.
+
+ * ``"auto"`` (case-insensitive) → GPU count reported by ``ray.cluster_resources()`` when it exceeds 1, else 1.
+ * Any other value is converted with ``int()`` and clamped to a minimum of 1 (a non-numeric value raises).
+ """
+ if isinstance(raw_value, str) and raw_value.lower() == "auto":
+ try:
+ num_gpus = int(ray.cluster_resources().get("GPU", 0))
+ except Exception as e: # best-effort lookup: any failure is treated as "no GPUs"
+ logger.warning(f"Could not get GPU resources from Ray cluster, defaulting to 0. Error: {e}")
+ num_gpus = 0
+ if num_gpus > 1:
+ logger.info(
+ f"Auto-detected {num_gpus} GPUs in Ray cluster, " f"setting max_concurrent_partitions={num_gpus}"
+ )
+ return num_gpus
+ # No GPUs or a single GPU → 1, i.e. partitions are processed sequentially
+ return 1
+ return max(1, int(raw_value))
+
def _configure_auto_partitioning(self, dataset, ops):
"""Configure partitioning using the partition size optimizer for auto mode."""
try:
@@ -498,6 +543,10 @@ def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List):
f"{partitioning_info.total_rows} total rows"
)
+ # Branch: concurrent vs sequential partition processing
+ if self.max_concurrent_partitions > 1:
+ return self._process_partitions_concurrent(partitions, ops, partitioning_info)
+
# Process each partition separately with checkpointing
logger.info("Processing partitions with checkpointing support...")
processed_partitions = []
@@ -541,6 +590,197 @@ def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List):
# Return as RayDataset wrapper
return RayDataset(merged_dataset, cfg=self.cfg)
+ def _process_partitions_concurrent(self, partitions, ops, partitioning_info):
+ """Process partitions concurrently as Ray remote tasks.
+
+ Each non-empty partition is submitted as a Ray remote task that reloads ops from config, scopes their
+ concurrency, and processes data with its own checkpoint manager. Results are unioned in partition order.
+ NOTE(review): every task is submitted up front; max_conc only divides each op's num_proc, it does not throttle.
+ """
+ max_conc = min(self.max_concurrent_partitions, len(partitions))
+ logger.info(f"Processing {len(partitions)} partitions concurrently " f"(max_concurrent_partitions={max_conc})")
+
+ # Serialisable values extracted from self (avoid serialising the executor)
+ cfg = self.cfg
+ ckpt_enabled = self.ckpt_manager.checkpoint_enabled
+ ckpt_strategy = self.ckpt_manager.checkpoint_strategy
+ ckpt_dir = self.ckpt_manager.ckpt_dir
+ ckpt_n_ops = getattr(self.ckpt_manager, "checkpoint_n_ops", 1)
+ ckpt_op_names = getattr(self.ckpt_manager, "checkpoint_op_names", [])
+ op_fusion_enabled = getattr(cfg, "op_fusion", False)
+
+ @ray.remote(num_cpus=0) # the orchestrating task itself reserves no CPU
+ def _process_single_partition_task(
+ partition_data,
+ partition_id,
+ cfg,
+ max_concurrent_partitions,
+ ckpt_enabled,
+ ckpt_strategy,
+ ckpt_dir,
+ ckpt_n_ops,
+ ckpt_op_names,
+ op_fusion_enabled,
+ ):
+ """Ray remote task that processes one partition end-to-end and returns the materialized dataset."""
+ from loguru import logger as task_logger
+
+ from data_juicer.core.data.ray_dataset import RayDataset
+ from data_juicer.core.executor.concurrency_scoping import (
+ scope_op_concurrency,
+ )
+ from data_juicer.ops import load_ops
+ from data_juicer.ops.op_fusion import fuse_operators
+ from data_juicer.utils.ckpt_utils import RayCheckpointManager
+
+ task_logger.info(f"[Partition {partition_id}] Starting remote processing")
+
+ # Re-create ops from config to avoid serialisation issues
+ task_ops = load_ops(cfg.process)
+ if op_fusion_enabled:
+ task_ops = fuse_operators(task_ops)
+
+ # Scope concurrency and fix actor mode for each op.
+ # The remote task has no GPU, so use_cuda() returns False and
+ # ops default to task mode (model reloads per batch). Force
+ # actor mode for GPU ops so the model loads once per actor.
+ for op in task_ops:
+ if getattr(op, "num_gpus", 0) and op.num_gpus > 0:
+ op.ray_execution_mode = "actor"
+ op.num_proc = scope_op_concurrency(op, max_concurrent_partitions)
+
+ # Create local checkpoint manager
+ ckpt_manager = RayCheckpointManager(
+ ckpt_dir=ckpt_dir,
+ checkpoint_enabled=ckpt_enabled,
+ checkpoint_strategy=ckpt_strategy,
+ checkpoint_n_ops=ckpt_n_ops,
+ checkpoint_op_names=ckpt_op_names,
+ )
+
+ # Check for existing checkpoint
+ latest_checkpoint = ckpt_manager.find_latest_checkpoint(partition_id)
+
+ # If all ops are already checkpointed, load from checkpoint
+ if latest_checkpoint and latest_checkpoint[0] >= len(task_ops) - 1:
+ task_logger.info(f"[Partition {partition_id}] All ops checkpointed, " f"loading from checkpoint")
+ loaded = ckpt_manager.load_checkpoint(
+ latest_checkpoint[0],
+ latest_checkpoint[1],
+ partition_id,
+ cfg=cfg,
+ )
+ if loaded is not None:
+ return loaded.data.materialize()
+
+ # Determine resume point (also reached when the fully-checkpointed load above returned None)
+ start_op_idx = 0
+ partition_dataset = RayDataset(partition_data, cfg=cfg)
+
+ if latest_checkpoint:
+ loaded = ckpt_manager.load_checkpoint(
+ latest_checkpoint[0],
+ latest_checkpoint[1],
+ partition_id,
+ cfg=cfg,
+ )
+ if loaded is not None:
+ partition_dataset = loaded
+ start_op_idx = latest_checkpoint[0] + 1
+ task_logger.info(f"[Partition {partition_id}] Resuming from op " f"{start_op_idx}")
+
+ # Process ops one-by-one with checkpointing
+ remaining_ops = task_ops[start_op_idx:]
+ for rel_idx, op in enumerate(remaining_ops):
+ abs_idx = start_op_idx + rel_idx
+ task_logger.info(f"[Partition {partition_id}] Processing op {abs_idx}: " f"{op._name}")
+ partition_dataset = partition_dataset.process([op])
+
+ # Checkpoint if needed
+ if ckpt_manager.should_checkpoint(abs_idx, op._name):
+ partition_dataset.data = partition_dataset.data.materialize()
+ ckpt_manager.save_checkpoint(
+ partition_dataset.data,
+ abs_idx,
+ partition_id,
+ )
+
+ # Final materialize
+ partition_dataset.data = partition_dataset.data.materialize()
+ return partition_dataset.data
+
+ # Submit tasks (skip empty partitions)
+ futures = {}
+ for i, partition in enumerate(partitions):
+ # Skip empty partitions to avoid wasting GPU resources
+ try:
+ row_count = partition.count()
+ except Exception:
+ row_count = -1 # can't determine, submit anyway
+ if row_count == 0:
+ logger.info(f"Partition {i}: empty (0 rows), skipping")
+ continue
+
+ # Check if partition is fully checkpointed before submitting
+ latest_ckpt = self.ckpt_manager.find_latest_checkpoint(i)
+ if latest_ckpt and latest_ckpt[0] >= len(ops) - 1: # NOTE(review): driver-side len(ops); the task compares against its own reloaded/fused task_ops
+ logger.info(f"Partition {i}: already fully checkpointed, " f"loading from checkpoint")
+ loaded = self.ckpt_manager.load_checkpoint(latest_ckpt[0], latest_ckpt[1], i, cfg=self.cfg)
+ if loaded is not None:
+ futures[i] = loaded.data.materialize() # stored directly, not as an ObjectRef; the collect loop handles both
+ continue
+
+ self._log_event(
+ event_type=EventType.PARTITION_START,
+ message=f"Starting concurrent processing of partition " f"{i + 1}/{len(partitions)}",
+ partition_id=i,
+ )
+ futures[i] = _process_single_partition_task.remote(
+ partition,
+ i,
+ cfg,
+ max_conc,
+ ckpt_enabled,
+ ckpt_strategy,
+ ckpt_dir,
+ ckpt_n_ops,
+ ckpt_op_names,
+ op_fusion_enabled,
+ )
+
+ # Collect results in partition-index order; ray.get blocks on each submitted task in turn
+ processed_partitions = []
+ for i in sorted(futures.keys()):
+ result = futures[i]
+ if isinstance(result, ray.ObjectRef):
+ try:
+ result = ray.get(result)
+ logger.info(f"Partition {i}: completed successfully")
+ except Exception as e:
+ logger.error(f"Partition {i}: failed with error: {e}")
+ raise # the first failing partition aborts collection
+ processed_partitions.append(result)
+ self._log_event(
+ event_type=EventType.PARTITION_COMPLETE,
+ message=f"Completed concurrent processing of partition " f"{i + 1}/{len(partitions)}",
+ partition_id=i,
+ )
+
+ # Union results
+ logger.info("Merging concurrently processed partitions...")
+ if not processed_partitions:
+ logger.warning("All partitions were empty or skipped. Returning an empty dataset.")
+ return RayDataset(ray.data.from_items([]), cfg=self.cfg)
+
+ if len(processed_partitions) == 1:
+ merged_dataset = processed_partitions[0]
+ else:
+ merged_dataset = processed_partitions[0]
+ for partition in processed_partitions[1:]:
+ merged_dataset = merged_dataset.union(partition)
+
+ return RayDataset(merged_dataset, cfg=self.cfg)
+
def _process_with_convergence(self, dataset: RayDataset, ops: List, convergence_points: List[int]):
"""
Process dataset with convergence support for global operations.
@@ -954,7 +1194,14 @@ def _split_dataset_deterministic(self, dataset: RayDataset) -> tuple:
# Check for existing partitioning info (resumption case)
saved_info = self._load_partitioning_info()
- # Split the dataset
+ # Split using the dataset's natural block structure. split()
+ # distributes existing blocks round-robin, so partitions inherit
+ # multiple blocks and Ray Data's streaming executor can pipeline
+ # stages within each partition. Avoid repartition() here — it
+ # adds a costly shuffle and may reduce block count (e.g. 96 source
+ # blocks repartitioned to 32 loses parallelism). If there are
+ # fewer blocks than partitions, some partitions will be empty —
+ # that's handled downstream (empty partitions are skipped).
logger.info(f"Splitting dataset into {self.num_partitions} partitions (deterministic mode)...")
partitions = dataset.data.split(self.num_partitions)
logger.info(f"Created {len(partitions)} partitions")
@@ -974,24 +1221,16 @@ def _split_dataset_deterministic(self, dataset: RayDataset) -> tuple:
self._clear_invalid_checkpoints()
saved_info = None
- # Collect metadata for new partitions
- logger.info("Collecting partition metadata for checkpoint validation...")
- total_rows = sum(p.count() for p in partitions)
- partition_metadata = []
-
- for i, partition in enumerate(partitions):
- meta = self._collect_partition_metadata(partition, i)
- partition_metadata.append(meta)
- logger.debug(f"Partition {i}: {meta.row_count} rows, hash={meta.first_row_hash[:8]}...")
-
+ # On first run, skip expensive metadata collection (count(), take())
+ # which triggers redundant pipeline executions on lazy datasets.
+ # Save only the partition count; full metadata is not needed until
+ # resume validation.
partitioning_info = PartitioningInfo(
num_partitions=self.num_partitions,
- total_rows=total_rows,
- partitions=partition_metadata,
+ total_rows=-1, # unknown until processing completes
+ partitions=[],
deterministic=True,
)
-
- # Save partitioning info
self._save_partitioning_info(partitioning_info)
return partitions, partitioning_info
diff --git a/docs/design/parallel_partition_actor_reuse.md b/docs/design/parallel_partition_actor_reuse.md
new file mode 100644
index 0000000000..7284922467
--- /dev/null
+++ b/docs/design/parallel_partition_actor_reuse.md
@@ -0,0 +1,363 @@
+# Design Doc: Concurrent Partition Processing with GPU Scoping
+
+**Author:** Data-Juicer Team
+**Created:** 2026-03-09
+**Updated:** 2026-03-17
+**Status:** Implemented
+**Branch:** `feat/cyrusz/parallel-partition-actor-reuse`
+
+---
+
+## 1. Problem Statement
+
+### Current Behavior (Before This Change)
+
+The `PartitionedRayExecutor` processes partitions **sequentially**, creating new GPU actors for each partition:
+
+```
+Partition 1 → [Create Actors] → [Load Models] → [Process] → [Actors GC'd]
+Partition 2 → [Create Actors] → [Load Models] → [Process] → [Actors GC'd]
+Partition 3 → [Create Actors] → [Load Models] → [Process] → [Actors GC'd]
+```
+
+### Problems
+
+1. **Repeated Model Loading**: Heavy GPU models (e.g., VideoBLIP ~20GB) are loaded N times for N partitions
+2. **GPU Idle Time**: GPUs sit idle between partitions during actor teardown/creation
+3. **Poor Scalability**: Processing time scales linearly with partition count due to model loading overhead
+
+### Impact
+
+For a typical video processing pipeline with 3 GPU operators and 10 partitions:
+- Model loading time: ~60s per operator × 3 operators × 10 partitions = **30 minutes of pure overhead**
+- This overhead can exceed actual processing time for smaller datasets
+
+---
+
+## 2. Implemented Solution: Concurrent Partition Processing
+
+### Overview
+
+Instead of sequential processing with shared actor pools (originally proposed), we implemented **concurrent partition processing** where all partitions run in parallel as independent Ray remote tasks, each with its own scoped GPU actors:
+
+```
+┌──────────────────────────────────────────────────────────────────┐
+│ Concurrent Partition Processing │
+│ │
+│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
+│ │ Task P0 │ │ Task P1 │ │ Task P2 │ ... │ Task P7 │ │
+│ │ 1 GPU │ │ 1 GPU │ │ 1 GPU │ │ 1 GPU │ │
+│ │ Actor │ │ Actor │ │ Actor │ │ Actor │ │
+│ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │
+│ ↕ ↕ ↕ ↕ │
+│ GPU 0 GPU 1 GPU 2 GPU 7 │
+└──────────────────────────────────────────────────────────────────┘
+
+All partitions processed concurrently, each with its own scoped actor
+```
+
+### Why Concurrent Instead of Sequential + Actor Reuse
+
+The original design proposed sequential processing with detached shared actor pools. During implementation, we chose concurrent processing because:
+
+1. **Simpler architecture**: No need for detached actor lifecycle management, pool coordination, or cross-partition actor sharing
+2. **Better GPU utilization**: All GPUs are busy simultaneously instead of sequentially
+3. **Natural Ray fit**: Each partition is a self-contained Ray remote task — no complex orchestration
+4. **Same model loading cost**: Each GPU loads the model once per partition, but all load concurrently (~60s wall time vs. N × 60s sequential)
+5. **Maintained benefits**: Checkpointing, resume, and memory control per partition are all preserved
+
+### Key Design Principles
+
+1. **Concurrent partition processing**: All partitions run in parallel (up to `max_concurrent_partitions`)
+2. **Concurrency scoping**: Each partition's actor-mode ops get `num_proc = max(1, op.num_proc // max_concurrent_partitions)` actors
+3. **Forced actor mode**: GPU ops are set to `ray_execution_mode = "actor"` inside the remote task (where CUDA is not visible)
+4. **Per-partition checkpointing**: Each remote task manages its own checkpoint state
+5. **Resume support**: Skip completed partitions on restart
+
+---
+
+## 3. Detailed Design
+
+### 3.1 Architecture
+
+```
+┌─────────────────────────────────────────────────────────────────────────┐
+│ PartitionedRayExecutor │
+├─────────────────────────────────────────────────────────────────────────┤
+│ │
+│ ┌─────────────────────────────────────────────────────────────────┐ │
+│ │ _process_partitions_concurrent() │ │
+│ │ │ │
+│ │ 1. Extract serializable config values │ │
+│ │ 2. Submit Ray remote tasks (one per partition) │ │
+│ │ 3. Collect results, union partitions │ │
+│ └─────────────────────────────────────────────────────────────────┘ │
+│ │ │
+│ ┌───────────────┼───────────────┐ │
+│ ▼ ▼ ▼ │
+│ ┌─────────────────┐ ┌─────────────┐ ┌─────────────┐ │
+│ │ Remote Task P0 │ │ Remote Task │ │ Remote Task │ ... │
+│ │ │ │ P1 │ │ P2 │ │
+│ │ - load_ops() │ │ │ │ │ │
+│ │ - force actor │ │ (same) │ │ (same) │ │
+│ │ mode for GPU │ │ │ │ │ │
+│ │ - scope conc. │ │ │ │ │ │
+│ │ - process data │ │ │ │ │ │
+│ │ - checkpoint │ │ │ │ │ │
+│ └─────────────────┘ └─────────────┘ └─────────────┘ │
+│ │
+└─────────────────────────────────────────────────────────────────────────┘
+```
+
+### 3.2 Execution Flow
+
+```
+Phase 1: Dataset Splitting
+──────────────────────────
+
+Job Start
+ │
+ ▼
+┌──────────────────────────┐
+│ Keep source block layout │ No repartition(): avoids an extra shuffle stage
+└───────────┬──────────────┘
+ │
+ ▼
+┌──────────────────────────┐
+│ Split into N partitions │ Each partition gets ~equal rows
+└───────────┬──────────────┘
+
+
+Phase 2: Concurrent Processing
+──────────────────────────────
+ │
+ ┌───────┼───────┬───────┬───── ... ─────┐
+ ▼ ▼ ▼ ▼ ▼
+┌──────┐┌──────┐┌──────┐┌──────┐ ┌──────┐
+│ P0 ││ P1 ││ P2 ││ P3 │ │ P7 │
+│1 GPU ││1 GPU ││1 GPU ││1 GPU │ │1 GPU │
+└──┬───┘└──┬───┘└──┬───┘└──┬───┘ └──┬───┘
+ │ │ │ │ │
+ ▼ ▼ ▼ ▼ ▼
+ [Load] [Load] [Load] [Load] ... [Load] ← Models load concurrently
+ │ │ │ │ │
+ ▼ ▼ ▼ ▼ ▼
+[Process][Process][Process][Process] [Process] ← All GPUs busy
+ │ │ │ │ │
+ ▼ ▼ ▼ ▼ ▼
+ [Ckpt] [Ckpt] [Ckpt] [Ckpt] ... [Ckpt] ← Per-partition checkpoint
+
+
+Phase 3: Merge Results
+──────────────────────
+ └───────┴───────┴───────┴───── ... ─────┘
+ │
+ ▼
+ ┌──────────────────┐
+ │ Union partitions │
+ └──────────────────┘
+ │
+ ▼
+ Job End
+```
+
+### 3.3 Concurrency Scoping
+
+The critical mechanism that prevents GPU over-allocation:
+
+```python
+# Inside each remote task:
+for op in task_ops:
+ # Step 1: Force actor mode (MUST be before scope_op_concurrency)
+ if getattr(op, "num_gpus", 0) and op.num_gpus > 0:
+ op.ray_execution_mode = "actor"
+
+ # Step 2: Scope concurrency — divides num_proc by max_concurrent_partitions
+ op.num_proc = scope_op_concurrency(op, max_concurrent_partitions)
+```
+
+**Why order matters:**
+- The remote task requests no GPU for itself (`@ray.remote(num_cpus=0)`), regardless of which node it is scheduled on
+- `torch.cuda.is_available()` returns `False` in the remote task
+- Without explicitly setting `ray_execution_mode = "actor"`, `use_ray_actor()` returns `False`
+- `scope_op_concurrency()` only divides `num_proc` for actor-mode ops
+- If actor mode is not set first, `num_proc` stays at the full value (e.g., 8), causing each partition to request all 8 GPUs → deadlock
+
+**Example with 8 GPUs, 8 partitions:**
+- `num_proc` original = 8 (wants 8 GPU actors)
+- `scope_op_concurrency(op, 8)` → `8 // 8 = 1` (1 GPU actor per partition)
+- 8 partitions × 1 GPU = 8 GPUs total → fits exactly
+
+### 3.4 Remote Task Design
+
+Each partition is processed by an independent `@ray.remote(num_cpus=0)` task that:
+
+1. **Re-creates ops from config** — avoids serialization issues with GPU operator state
+2. **Forces actor mode** — sets `ray_execution_mode = "actor"` for GPU ops
+3. **Scopes concurrency** — divides `num_proc` by `max_concurrent_partitions`
+4. **Manages its own checkpoints** — creates a local `RayCheckpointManager`
+5. **Handles resume** — checks for existing checkpoints before processing
+
+The task requests `num_cpus=0` because the actual compute is done by Ray Data actors/tasks spawned within.
+
+### 3.5 Dataset Splitting
+
+```python
+# Split on the dataset's existing blocks; repartition() is deliberately avoided
+# because it adds a shuffle and can reduce the block count.
+partitions = dataset.data.split(self.num_partitions)
+```
+
+- `split(N)` distributes the existing blocks across N independent `Dataset` objects
+- `repartition(N)` is not used: it adds a costly shuffle stage and can lower the block count (e.g. 96 source blocks repartitioned to 32)
+- If there are fewer blocks than partitions, some partitions are empty; empty partitions are skipped when tasks are submitted
+
+---
+
+## 4. Configuration
+
+```yaml
+partition:
+ mode: 'auto' # 'auto' | 'manual'
+ num_of_partitions: 8 # Number of partitions
+ max_concurrent_partitions: 8 # Max partitions running in parallel (default: 'auto' = GPU count seen by Ray)
+
+checkpoint:
+ enabled: true
+ dir: './checkpoints'
+ strategy: 'per_op' # Checkpoint after each operator
+```
+
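+The `max_concurrent_partitions` parameter controls how many partitions run simultaneously and how GPU resources are divided. Setting it equal to the number of GPUs (one partition per GPU) is typical for GPU-bound workloads. The default, `auto`, uses the number of GPUs Ray reports and falls back to 1 (sequential processing) when there is at most one GPU. If `num_of_partitions` is smaller than `max_concurrent_partitions`, it is raised to match.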
+
+---
+
+## 5. Performance Comparison
+
+### Timeline: Sequential vs Concurrent
+
+**Before (Sequential, no actor reuse):**
+```
+Time ────────────────────────────────────────────────────────────────────▶
+
+P0: [Load 60s][Process 120s][GC]
+P1: [Load 60s][Process 120s][GC]
+P2: [Load 60s][Process 120s]
+
+Total: 3 × (60 + 120) = 540s
+GPU idle: ~67% of total time
+```
+
+**After (Concurrent, 8 partitions on 8 GPUs):**
+```
+Time ────────────────────────────────────────────────────────────────────▶
+
+P0: [Load 60s][Process 120s]
+P1: [Load 60s][Process 120s] ← All load concurrently
+P2: [Load 60s][Process 120s]
+...
+P7: [Load 60s][Process 120s]
+
+Total: 60 + 120 = 180s (wall time)
+GPU idle: ~0% during processing
+```
+
+### Observed Results
+
+**Setup:** 8× A100 80GB, 6000 video samples, VideoAestheticsFilter
+
+| Mode | Time | GPU Utilization |
+|------|------|-----------------|
+| Pure GPU (no partitioning) | ~1100s | 100% on all 8 GPUs |
+| Concurrent partitions (8) | ~1100-1300s | 100% on all 8 GPUs |
+| Sequential (old, deadlocked) | ∞ (deadlock) | 8/8 GPU allocated, 14+ pending |
+
+The concurrent approach matches pure GPU mode performance while adding partition-level checkpointing and resume capability.
+
+---
+
+## 6. Checkpointing and Resume
+
+### Checkpoint Structure
+
+Each remote task manages its own checkpoints:
+
+```
+checkpoints/
+├── partitioning_info.json # Partition metadata for validation
+├── partition_0/
+│ ├── op_0_video_aesthetics_filter/
+│ │ ├── data.parquet
+│ │ └── _SUCCESS
+│ └── ...
+├── partition_1/
+│ └── ...
+└── ...
+```
+
+### Resume Flow
+
+```
+Resume from Crash (Partition 2 was in progress)
+──────────────────────────────────────────────
+
+1. Load partitioning_info.json
+2. Validate current partitions match saved metadata
+3. Submit all partition tasks concurrently
+4. Each task independently:
+ - Checks its own checkpoint state
+ - Skips completed ops (loads from checkpoint)
+ - Resumes from last incomplete op
+5. Collect results and union
+```
+
+---
+
+## 7. Error Handling
+
+### Partition Task Failure
+
+If a remote task fails:
+- Other partitions continue processing independently
+- Failed partition's actors are cleaned up by Ray
+- On retry/resume, the failed partition restarts from its last checkpoint
+
+### GPU Resource Deadlock Prevention
+
+The concurrency scoping mechanism prevents deadlock by ensuring:
+- Total GPU requests across all concurrent partitions ≤ available GPUs
+- `num_proc` is divided by `max_concurrent_partitions` for actor-mode ops
+- Actor mode is set before scoping (critical ordering requirement)
+
+---
+
+## 8. Known Limitations and Future Work
+
+1. **No actor reuse across partitions**: Each partition loads models independently. For workloads dominated by model loading time, a shared actor pool approach (the original design) could reduce overhead.
+
+2. **Uneven or empty partitions**: Because `split()` reuses the dataset's existing blocks instead of calling `repartition()`, a dataset with fewer blocks than partitions yields empty partitions. These are skipped, so the corresponding GPUs stay idle.
+
+3. **Block count depends on the source**: Partitions inherit the source dataset's blocks. When the source has only about one block per partition, the entire partition is processed as a single batch by the actor. This prevents streaming output — no progress is visible until the whole partition completes.
+
+4. **`max_concurrent_partitions` tuning**: Must be ≤ available GPUs for GPU-bound workloads. Auto-detection sets it to the GPU count, but mixed CPU/GPU pipelines may benefit from different values.
+
+---
+
+## 9. Design Decision Log
+
+| Decision | Choice | Rationale |
+|----------|--------|-----------|
+| Sequential vs concurrent | Concurrent | Better GPU utilization, simpler architecture |
+| Shared actors vs per-partition | Per-partition | Avoids detached actor lifecycle complexity |
+| Repartition before split | No — split on existing blocks | Avoids the shuffle cost and keeps the source block count; empty partitions are skipped |
+| Actor mode + scoping order | Actor mode first | Required for scope_op_concurrency to work correctly |
+| Remote task num_cpus | 0 | Task is just an orchestrator; actual compute uses Ray Data actors |
+
+---
+
+## References
+
+- [Ray Actors Documentation](https://docs.ray.io/en/latest/ray-core/actors.html)
+- [Ray Data User Guide](https://docs.ray.io/en/latest/data/data.html)
+- Source: `data_juicer/core/executor/ray_executor_partitioned.py`
+- Source: `data_juicer/core/executor/concurrency_scoping.py`
diff --git a/perf-test.py b/perf-test.py
new file mode 100644
index 0000000000..81a81fb93c
--- /dev/null
+++ b/perf-test.py
@@ -0,0 +1,742 @@
+#!/usr/bin/env python3
+"""
+Simple single-operator benchmark to test data loading and Ray Data parallelism.
+Enhanced for debugging Ray/DataJuicer GPU actor initialization issues.
+"""
+
+import argparse
+import importlib
+import json
+import os
+import subprocess
+import sys
+import time
+from datetime import datetime
+
+from loguru import logger
+
+# ── Paths ─────────────────────────────────────────────────────────────────────
+DJ_CODE_PATH = "/mnt/workspace/yileiz/data-juicer"
+OUTPUT_DIR = "/mnt/workspace/yileiz/outputs/partitioned_ray/simple_workdir"
+MODEL_PATH = "/mnt/workspace/miaoxiang.zfr/models/aesthetics-predictor-v2-sac-logos-ava1-l14-linearMSE"
+DEFAULT_CAPTION_JSONL = "/mnt/workspace/miaoxiang.zfr/data/Youku-AliceMind/caption_val_abs_6k.jsonl"
+DEFAULT_VIDEO_DIR = "/mnt/workspace/shurui.ksr/Project/data/modelscope/Youku-AliceMind/videos/caption"
+# ──────────────────────────────────────────────────────────────────────────────
+
+if os.path.exists(DJ_CODE_PATH):
+ sys.path.insert(0, DJ_CODE_PATH)
+
+
+def setup_logging(log_dir=None):
+ """Setup logging to file and console."""
+ if log_dir is None:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ log_dir = os.path.join(OUTPUT_DIR, f"run_{timestamp}")
+
+ os.makedirs(log_dir, exist_ok=True)
+ log_file = os.path.join(log_dir, "benchmark.log")
+
+ logger.remove()
+
+ logger.add(
+ sys.stdout,
+ level="INFO",
+ format="{time:HH:mm:ss} | {level: <8} | {message}",
+ colorize=True,
+ )
+
+ logger.add(
+ log_file,
+ level="DEBUG",
+ format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {name}:{function}:{line} - {message}",
+ rotation="100 MB",
+ )
+
+ logger.info(f"Log file: {log_file}")
+ return log_dir, log_file
+
+
+def monitor_gpu():
+ """Print GPU utilization."""
+ try:
+ result = subprocess.run(
+ ["nvidia-smi", "--query-gpu=index,name,utilization.gpu,memory.used,memory.total", "--format=csv,noheader"],
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+ logger.info(f"GPU Status:\n{result.stdout}")
+ except Exception as e:
+ logger.warning(f"Failed to query GPU status: {e}")
+
+
+def log_ray_paths():
+ """Print likely Ray log locations for easier debugging."""
+ ray_tmp = "/tmp/ray"
+ if os.path.exists(ray_tmp):
+ logger.info(f"Ray temp dir exists: {ray_tmp}")
+ logger.info("Check Ray logs under: /tmp/ray/session_latest/logs/")
+ logger.info("Ray Data logs often under: /tmp/ray/session_latest/logs/ray-data/")
+ else:
+ logger.warning("Ray temp dir /tmp/ray not found yet")
+
+
+def prepare_jsonl_from_caption(jsonl_path, video_base_dir, num_samples=None, output_path=None):
+ """Prepare JSONL with absolute video paths."""
+ if output_path is None:
+ output_path = jsonl_path.replace(".jsonl", "_abs.jsonl")
+
+ if os.path.exists(output_path):
+ logger.info(f"Output already exists: {output_path}")
+ return output_path
+
+ count = 0
+ missing = 0
+ with open(jsonl_path, "r") as f_in, open(output_path, "w") as f_out:
+ for line in f_in:
+ if num_samples and count >= num_samples:
+ break
+ sample = json.loads(line)
+ videos = sample.get("videos", [])
+ abs_videos = [os.path.join(video_base_dir, os.path.basename(v)) for v in videos]
+ if all(os.path.exists(v) for v in abs_videos):
+ out_sample = {"videos": abs_videos, "text": sample.get("caption", "")}
+ f_out.write(json.dumps(out_sample, ensure_ascii=False) + "\n")
+ count += 1
+ else:
+ missing += 1
+
+ logger.info(f"Created {output_path} with {count} samples, skipped {missing} missing-video samples")
+ return output_path
+
+
+def split_jsonl(jsonl_path, num_shards=96):
+ """Split JSONL into shards."""
+ shard_dir = jsonl_path.replace(".jsonl", f"_sharded_{num_shards}")
+ marker = os.path.join(shard_dir, "_DONE")
+
+ if os.path.exists(marker):
+ logger.info(f"Sharded data exists: {shard_dir}")
+ return shard_dir
+
+ os.makedirs(shard_dir, exist_ok=True)
+
+ writers = [open(os.path.join(shard_dir, f"shard_{i:04d}.jsonl"), "w") for i in range(num_shards)]
+
+ count = 0
+ try:
+ with open(jsonl_path, "r") as f_in:
+ for line in f_in:
+ writers[count % num_shards].write(line)
+ count += 1
+ finally:
+ for w in writers:
+ w.close()
+
+ with open(marker, "w") as f:
+ f.write(f"{count} samples\n")
+
+ logger.info(f"Split {count} samples into {num_shards} shards")
+ return shard_dir
+
+
+def require_module(module_name, pip_hint=None):
+ """Fail fast if module is missing."""
+ try:
+ return importlib.import_module(module_name)
+ except Exception as e:
+ hint = f" Please install it first: {pip_hint}" if pip_hint else ""
+ raise RuntimeError(f"Missing required module [{module_name}].{hint}\nOriginal error: {e}") from e
+
+
+def precheck_environment(fail_fast=True):
+ """
+ Precheck environment in driver process to avoid hanging inside Ray actors.
+ """
+ logger.info("=" * 80)
+ logger.info("Prechecking environment before starting Ray actors")
+ logger.info("=" * 80)
+
+ # Basic env
+ logger.info(f"Python executable: {sys.executable}")
+ logger.info(f"Python version: {sys.version}")
+ logger.info(f'HF_ENDPOINT={os.environ.get("HF_ENDPOINT")}')
+
+ # Model path
+ if not os.path.exists(MODEL_PATH):
+ msg = f"Model path does not exist: {MODEL_PATH}"
+ if fail_fast:
+ raise FileNotFoundError(msg)
+ logger.warning(msg)
+ else:
+ logger.info(f"Model path exists: {MODEL_PATH}")
+
+ # Required modules
+ require_module("torch", "pip install torch")
+ require_module("transformers", "pip install transformers")
+ require_module("ray", "pip install ray")
+ require_module("pyarrow", "pip install pyarrow")
+
+ # Torch / CUDA visibility
+ import torch
+
+ logger.info(f"torch version: {torch.__version__}")
+ logger.info(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
+ logger.info(f"torch.cuda.device_count(): {torch.cuda.device_count()}")
+ if torch.cuda.is_available():
+ for i in range(torch.cuda.device_count()):
+ try:
+ logger.info(f"CUDA device {i}: {torch.cuda.get_device_name(i)}")
+ except Exception:
+ pass
+
+ logger.info("Environment precheck passed.")
+
+
+def init_ray(object_store_gb=300, num_gpus=8):
+ """Initialize Ray with better defaults."""
+ # Pre-import to avoid circular import issues in Ray workers
+ logger.info("Pre-importing modules to avoid fsspec issues in Ray workers...")
+ import fsspec
+ import fsspec.spec
+ import fsspec.utils # noqa: F401
+
+ try:
+ from huggingface_hub import HfFileSystem # noqa: F401
+ except ImportError:
+ pass # OK if not available
+
+ import ray
+
+ if ray.is_initialized():
+ logger.info("Ray already initialized")
+ return
+
+ # Check if there's a running Ray cluster
+ ray_address = os.environ.get("RAY_ADDRESS")
+
+ if ray_address:
+ # Connect to specified cluster
+ logger.info(f"Connecting to Ray cluster at {ray_address}...")
+ ray.init(address=ray_address)
+ logger.info("Connected to existing Ray cluster")
+ else:
+ # Start a new local Ray instance
+ logger.info(f"Starting new Ray instance with {num_gpus} GPUs, {object_store_gb}GB object store...")
+ ray.init(
+ num_gpus=num_gpus,
+ object_store_memory=object_store_gb * 1024**3,
+ )
+ logger.info(f"Ray initialized successfully")
+
+ log_ray_paths()
+
+
+def run_simple_benchmark(
+    data_path,
+    num_shards=96,
+    num_partitions=8,
+    fail_fast=True,
+    executor_type="ray",
+):
+    """Run benchmark with DataJuicer + video_aesthetics_filter.
+
+    Writes a one-operator DataJuicer config into a fresh timestamped work
+    directory under ``OUTPUT_DIR``, builds the selected executor from it, and
+    logs executor-init, processing, and total wall-clock times.
+
+    Args:
+        data_path: JSONL file or directory. A file is first sharded with
+            ``split_jsonl`` and the shard directory is used instead.
+        num_shards: Shard count used when ``data_path`` is a file.
+        num_partitions: Value written to ``partition.num_of_partitions``;
+            only used for the 'ray_partitioned' executor.
+        fail_fast: Forwarded to ``precheck_environment``.
+        executor_type: 'ray' (standard, uses all GPUs) or 'ray_partitioned' (partitioned).
+            ray_partitioned auto-detects GPU count and runs partitions concurrently.
+
+    Raises:
+        RuntimeError: If the Ray cluster reports no GPU resource.
+    """
+    import ray  # noqa: F401
+    import yaml
+
+    from data_juicer.config import init_configs
+    from data_juicer.core.executor.ray_executor import RayExecutor
+    from data_juicer.core.executor.ray_executor_partitioned import (
+        PartitionedRayExecutor,
+    )
+
+    # Environment
+    # NOTE: HF_ENDPOINT is overwritten unconditionally; the other two
+    # variables keep a value the caller has already set.
+    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+    os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
+
+    # Fail fast before actors
+    precheck_environment(fail_fast=fail_fast)
+
+    # Initialize Ray
+    init_ray(object_store_gb=300)
+
+    # Shard data
+    if os.path.isfile(data_path):
+        data_path = split_jsonl(data_path, num_shards)
+
+    # Every run gets its own work directory for config, logs and results.
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    work_dir = os.path.join(OUTPUT_DIR, f"dj_run_{timestamp}")
+    os.makedirs(work_dir, exist_ok=True)
+
+    logger.info(f"Using executor type: {executor_type}")
+
+    # Detect available GPUs from Ray cluster
+    import ray as _ray
+
+    num_gpus = int(_ray.cluster_resources().get("GPU", 0))
+    if num_gpus <= 0:
+        raise RuntimeError("No GPUs available in Ray cluster")
+    logger.info(f"Detected {num_gpus} GPUs in Ray cluster")
+
+    # Base config
+    cfg_dict = {
+        "project_name": "simple-benchmark",
+        "executor_type": executor_type,
+        "dataset_path": data_path,
+        "export_path": os.path.join(work_dir, "result.jsonl"),
+        "work_dir": work_dir,
+        "video_key": "videos",
+        "skip_op_error": False,  # fail loudly
+        "use_cache": False,
+        "open_monitor": True,
+        "debug": False,
+        "auto_op_parallelism": False,  # Disable auto calculation to use explicit num_proc
+        "process": [
+            {
+                "video_aesthetics_filter": {
+                    "hf_scorer_model": MODEL_PATH,
+                    "trust_remote_code": True,
+                    "min_score": 0.4,
+                    "max_score": 1.0,
+                    "frame_num": 9223372036854775807,  # sys.maxsize - use all frames
+                    "reduce_mode": "avg",
+                    "skip_op_error": False,  # fail loudly during debugging
+                    "batch_mode": True,
+                    "num_gpus": 1,
+                    # One process per GPU detected above, one GPU each.
+                    "num_proc": num_gpus,
+                },
+            },
+        ],
+    }
+
+    # Add partition config only for ray_partitioned executor
+    if executor_type == "ray_partitioned":
+        cfg_dict["partition"] = {
+            "mode": "manual",
+            "num_of_partitions": num_partitions,
+        }
+        cfg_dict["checkpoint"] = {
+            "enabled": False,
+        }
+
+    # The config is written to disk and parsed back through init_configs so
+    # the run goes through the regular config-file code path.
+    config_path = os.path.join(work_dir, "config.yaml")
+    with open(config_path, "w") as f:
+        yaml.dump(cfg_dict, f, allow_unicode=True, sort_keys=False)
+
+    logger.info(f"Config saved to {config_path}")
+    logger.info(f"Work dir: {work_dir}")
+    logger.info(f"Data path: {data_path}")
+    if executor_type == "ray_partitioned":
+        logger.info(f"Num partitions: {num_partitions}")
+
+    monitor_gpu()
+
+    cfg = init_configs(args=["--config", config_path])
+
+    # t0 starts before executor construction, so "Total" below includes it.
+    t0 = time.time()
+    if executor_type == "ray":
+        executor = RayExecutor(cfg)
+    else:
+        # Any value other than "ray" selects the partitioned executor.
+        executor = PartitionedRayExecutor(cfg)
+    logger.info(f"Executor init ({executor_type}): {time.time() - t0:.2f}s")
+
+    t1 = time.time()
+    try:
+        executor.run()
+    except Exception:
+        # Log the traceback and point at the Ray logs, then re-raise.
+        logger.exception("DataJuicer execution failed")
+        logger.error(f"Please inspect Ray logs under /tmp/ray/session_latest/logs/")
+        raise
+
+    logger.info(f"Processing: {time.time() - t1:.2f}s")
+    monitor_gpu()
+    logger.info(f"Total: {time.time() - t0:.2f}s")
+    logger.info(f"Output dir: {work_dir}")
+
+
+def run_ray_data_test(data_path, num_shards=96):
+ """Test raw Ray Data parallelism without DataJuicer."""
+ import ray
+
+ if os.path.isfile(data_path):
+ data_path = split_jsonl(data_path, num_shards)
+
+ init_ray(object_store_gb=100)
+
+ logger.info(f"Reading data from {data_path}")
+
+ t0 = time.time()
+ ds = ray.data.read_json(data_path)
+ count = ds.count()
+ try:
+ num_blocks = ds.num_blocks()
+ except Exception:
+ num_blocks = "unknown_before_materialize"
+ logger.info(f"Loaded dataset: {count} rows, {num_blocks} blocks")
+
+ def count_videos(row):
+ return {"video_count": len(row.get("videos", [])), "text_len": len(row.get("text", ""))}
+
+ t1 = time.time()
+ ds = ds.map(count_videos)
+ result = ds.take(5)
+ logger.info(f"Map result: {result}")
+ logger.info(f"Map time: {time.time() - t1:.2f}s")
+
+ t2 = time.time()
+ total = ds.count()
+ logger.info(f"Total rows: {total}, count time: {time.time() - t2:.2f}s")
+
+ logger.info(f"Total time: {time.time() - t0:.2f}s")
+
+
+def run_direct_gpu_test(
+ data_path,
+ num_shards=96,
+ batch_size=8,
+ gpu_concurrency=8,
+ fail_fast=True,
+):
+ """
+ Direct GPU test bypassing PartitionedRayExecutor.
+ This tests if Ray Data GPU actors work correctly.
+ """
+ import pyarrow
+ import ray
+
+ os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+ os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
+
+ # Precheck before actor creation
+ precheck_environment(fail_fast=fail_fast)
+
+ init_ray(object_store_gb=300)
+
+ logger.info("Direct GPU Test - bypassing PartitionedRayExecutor")
+ monitor_gpu()
+
+ t0 = time.time()
+ if os.path.isfile(data_path):
+ data_path = split_jsonl(data_path, num_shards)
+
+ ds = ray.data.read_json(data_path)
+ row_count = ds.count()
+ logger.info(f"Loaded {row_count} rows in {time.time() - t0:.2f}s")
+
+ def add_stats_column(table: pyarrow.Table):
+ new_column_data = [{} for _ in range(len(table))]
+ return table.append_column("__dj__stats__", [new_column_data])
+
+ ds = ds.map_batches(add_stats_column, batch_format="pyarrow")
+ logger.info("Added __dj__stats__ column")
+
+ from data_juicer.ops.filter.video_aesthetics_filter import VideoAestheticsFilter
+
+ # Create operator on driver for validation only
+ op_t0 = time.time()
+ op = VideoAestheticsFilter(
+ hf_scorer_model=MODEL_PATH,
+ trust_remote_code=True,
+ min_score=0.4,
+ max_score=1.0,
+ frame_num=9223372036854775807, # sys.maxsize - use all frames
+ reduce_mode="avg",
+ num_gpus=1,
+ )
+ logger.info(f"Operator init on driver: {time.time() - op_t0:.2f}s")
+ logger.info(f"Operator: {op._name}")
+ logger.info(f" use_cuda: {op.use_cuda()}")
+ logger.info(f" use_ray_actor: {op.use_ray_actor()}")
+ logger.info(f" num_gpus: {op.num_gpus}")
+ logger.info(f" num_proc: {op.num_proc}")
+
+ # Restrict concurrency to available GPUs
+ import torch
+
+ available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
+ if available_gpus <= 0:
+ raise RuntimeError("No CUDA GPUs visible, cannot run direct GPU test")
+
+ gpu_concurrency = min(gpu_concurrency, available_gpus)
+ logger.info(f"Using gpu_concurrency={gpu_concurrency}, batch_size={batch_size}")
+
+ # Prefer new API style: concurrency=
+ t1 = time.time()
+ logger.info("Creating Ray Data GPU actor pipeline...")
+
+ try:
+ ds = ds.map_batches(
+ VideoAestheticsFilter,
+ fn_constructor_args=op._init_args,
+ fn_constructor_kwargs=op._init_kwargs,
+ batch_size=batch_size,
+ num_cpus=1,
+ num_gpus=1,
+ concurrency=gpu_concurrency,
+ batch_format="pyarrow",
+ )
+ logger.info("Using map_batches(..., concurrency=...)")
+ except TypeError:
+ # Fallback for older Ray versions
+ from ray.data import ActorPoolStrategy
+
+ logger.warning("Ray version does not support concurrency= here, fallback to ActorPoolStrategy")
+ ds = ds.map_batches(
+ VideoAestheticsFilter,
+ fn_constructor_args=op._init_args,
+ fn_constructor_kwargs=op._init_kwargs,
+ batch_size=batch_size,
+ num_cpus=1,
+ num_gpus=1,
+ compute=ActorPoolStrategy(size=gpu_concurrency),
+ batch_format="pyarrow",
+ )
+
+ logger.info("Executing pipeline...")
+ t2 = time.time()
+ try:
+ result = ds.materialize()
+ except Exception:
+ logger.exception("Direct GPU pipeline execution failed")
+ logger.error("Please inspect /tmp/ray/session_latest/logs/")
+ raise
+
+ logger.info(f"Pipeline execution: {time.time() - t2:.2f}s")
+
+ count = result.count()
+ logger.info(f"Result: {count} rows")
+
+ monitor_gpu()
+ logger.info(f"Total time: {time.time() - t0:.2f}s")
+ logger.info(f"Pipeline setup time: {time.time() - t1:.2f}s")
+
+
+def run_direct_gpu_test_dj_match(
+    data_path,
+    num_shards=96,
+    batch_size=10,  # DJ CUDA default
+    gpu_concurrency=8,
+    fail_fast=True,
+):
+    """
+    Direct GPU test that matches the DJ pipeline as closely as possible.
+    Adds: convert_to_absolute_paths, count(), columns(), filter step.
+
+    Args:
+        data_path: JSONL file or directory; a file is sharded first with
+            ``split_jsonl``.
+        num_shards: Shard count used when ``data_path`` is a file.
+        batch_size: Batch size of the actor ``map_batches`` stage.
+        gpu_concurrency: Requested number of GPU actors; capped at the
+            number of CUDA devices torch can see.
+        fail_fast: Forwarded to ``precheck_environment``.
+
+    Raises:
+        RuntimeError: If torch sees no CUDA device.
+    """
+    from functools import partial
+
+    import pyarrow
+    import ray
+
+    # HF_ENDPOINT is overwritten unconditionally; the other two variables
+    # keep a value the caller has already set.
+    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
+    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+    os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
+
+    precheck_environment(fail_fast=fail_fast)
+    init_ray(object_store_gb=300)
+
+    logger.info("Direct GPU Test (DJ-matched pipeline)")
+    monitor_gpu()
+
+    t0 = time.time()
+    if os.path.isfile(data_path):
+        data_path = split_jsonl(data_path, num_shards)
+
+    ds = ray.data.read_json(data_path)
+
+    # --- Match DJ: count() before processing ---
+    t_count = time.time()
+    row_count = ds.count()
+    logger.info(f"count(): {row_count} rows in {time.time() - t_count:.2f}s")
+
+    # --- Match DJ: columns() ---
+    t_cols = time.time()
+    cols = ds.columns()
+    logger.info(f"columns(): {cols} in {time.time() - t_cols:.2f}s")
+
+    # --- Match DJ: convert_to_absolute_paths ---
+    # NOTE(review): when data_path is the shard directory, this is its
+    # parent directory - confirm that is the intended base for relative paths.
+    dataset_dir = os.path.dirname(data_path)
+
+    def convert_to_absolute_paths(batch, dataset_dir, path_keys):
+        # Rewrite each listed column so relative paths (plain strings or
+        # lists of strings) are joined onto dataset_dir; absolute paths and
+        # values of other types are kept unchanged.
+        for key in path_keys:
+            if key in batch.column_names:
+                col = batch.column(key)
+                new_col = []
+                for val in col.to_pylist():
+                    if isinstance(val, list):
+                        new_col.append([os.path.join(dataset_dir, p) if not os.path.isabs(p) else p for p in val])
+                    elif isinstance(val, str):
+                        new_col.append(os.path.join(dataset_dir, val) if not os.path.isabs(val) else val)
+                    else:
+                        new_col.append(val)
+                idx = batch.column_names.index(key)
+                batch = batch.set_column(idx, key, [new_col])
+        return batch
+
+    # Only convert the media columns that are actually present.
+    path_keys = [k for k in ["videos", "images", "audios"] if k in cols]
+    if path_keys:
+        ds = ds.map_batches(
+            partial(convert_to_absolute_paths, dataset_dir=dataset_dir, path_keys=path_keys),
+            batch_format="pyarrow",
+            zero_copy_batch=True,
+            batch_size=1000,
+        )
+        logger.info(f"Added convert_to_absolute_paths for keys: {path_keys}")
+
+    # --- Match DJ: add __dj__stats__ column ---
+    def add_stats_column(table: pyarrow.Table):
+        # One empty dict per row, appended as the "__dj__stats__" column.
+        new_column_data = [{} for _ in range(len(table))]
+        return table.append_column("__dj__stats__", [new_column_data])
+
+    ds = ds.map_batches(add_stats_column, batch_format="pyarrow", batch_size=1000)
+    logger.info("Added __dj__stats__ column")
+
+    # --- Match DJ: compute_stats via actor ---
+    from data_juicer.ops.filter.video_aesthetics_filter import VideoAestheticsFilter
+
+    # Driver-side instance: supplies the constructor arguments for the actor
+    # stage and the ``process`` callable used by the filter step below.
+    op = VideoAestheticsFilter(
+        hf_scorer_model=MODEL_PATH,
+        trust_remote_code=True,
+        min_score=0.4,
+        max_score=1.0,
+        frame_num=9223372036854775807,
+        reduce_mode="avg",
+        num_gpus=1,
+        batch_mode=True,
+    )
+    logger.info(f"Op: {op._name}, batch_size={batch_size}, is_batched={op.is_batched_op()}")
+
+    import torch
+
+    # Never request more GPU actors than CUDA devices visible to torch.
+    available_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
+    if available_gpus <= 0:
+        raise RuntimeError("No CUDA GPUs visible")
+    gpu_concurrency = min(gpu_concurrency, available_gpus)
+    logger.info(f"gpu_concurrency={gpu_concurrency}, batch_size={batch_size}")
+
+    # t1 marks the start of the actor stage setup; it is reported at the end
+    # as the time elapsed since this point.
+    t1 = time.time()
+    ds = ds.map_batches(
+        VideoAestheticsFilter,
+        fn_constructor_args=op._init_args,
+        fn_constructor_kwargs=op._init_kwargs,
+        batch_size=batch_size,
+        num_gpus=1,
+        concurrency=gpu_concurrency,
+        batch_format="pyarrow",
+    )
+    logger.info("Added compute_stats map_batches (actor mode)")
+
+    # --- Match DJ: filter step ---
+    def filter_batch(batch, filter_func):
+        # filter_func yields one keep/drop flag per row; rows whose flag is
+        # false are removed from the batch.
+        mask = pyarrow.array(filter_func(batch.to_pydict()))
+        return batch.filter(mask)
+
+    ds = ds.map_batches(
+        partial(filter_batch, filter_func=op.process),
+        batch_format="pyarrow",
+        zero_copy_batch=True,
+        batch_size=1000,
+    )
+    logger.info("Added filter_batch step")
+
+    # --- Execute ---
+    logger.info("Executing full DJ-matched pipeline...")
+    t2 = time.time()
+    try:
+        result = ds.materialize()
+    except Exception:
+        logger.exception("Pipeline execution failed")
+        raise
+
+    logger.info(f"Pipeline execution: {time.time() - t2:.2f}s")
+    count = result.count()
+    logger.info(f"Result: {count} rows (filtered from {row_count})")
+    monitor_gpu()
+    logger.info(f"Total time: {time.time() - t0:.2f}s")
+    logger.info(f"Pipeline time (from first map_batches): {time.time() - t1:.2f}s")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Simple benchmark")
+ parser.add_argument(
+ "--caption-jsonl",
+ type=str,
+ default=DEFAULT_CAPTION_JSONL,
+ )
+ parser.add_argument(
+ "--video-dir",
+ type=str,
+ default=DEFAULT_VIDEO_DIR,
+ )
+ parser.add_argument("--num-samples", type=int, default=6000)
+ parser.add_argument("--num-shards", type=int, default=96)
+ parser.add_argument("--partitions", type=int, default=8)
+ parser.add_argument("--batch-size", type=int, default=8)
+ parser.add_argument("--gpu-concurrency", type=int, default=8)
+ parser.add_argument("--fail-fast", action="store_true", default=True)
+ parser.add_argument("--no-fail-fast", dest="fail_fast", action="store_false")
+ parser.add_argument("--mode", type=str, choices=["ray", "dj", "gpu", "gpu-dj", "both"], default="gpu")
+ parser.add_argument(
+ "--executor",
+ type=str,
+ choices=["ray", "ray_partitioned"],
+ default="ray",
+ help='Executor type: "ray" (standard, parallel GPUs) or "ray_partitioned" (partitioned)',
+ )
+ args = parser.parse_args()
+
+ log_dir, log_file = setup_logging()
+ logger.info(f"Arguments: {args}")
+
+ jsonl_path = prepare_jsonl_from_caption(args.caption_jsonl, args.video_dir, args.num_samples)
+
+ if args.mode in ["ray", "both"]:
+ logger.info("\n" + "=" * 60)
+ logger.info("Testing Ray Data parallelism")
+ logger.info("=" * 60)
+ run_ray_data_test(jsonl_path, args.num_shards)
+
+ if args.mode in ["dj", "both"]:
+ logger.info("\n" + "=" * 60)
+ logger.info(f"Testing DataJuicer with single operator (executor={args.executor})")
+ logger.info("=" * 60)
+ run_simple_benchmark(
+ jsonl_path,
+ num_shards=args.num_shards,
+ num_partitions=args.partitions,
+ fail_fast=args.fail_fast,
+ executor_type=args.executor,
+ )
+
+ if args.mode == "gpu":
+ logger.info("\n" + "=" * 60)
+ logger.info("Testing Direct GPU (bypass PartitionedRayExecutor)")
+ logger.info("=" * 60)
+ run_direct_gpu_test(
+ jsonl_path,
+ num_shards=args.num_shards,
+ batch_size=args.batch_size,
+ gpu_concurrency=args.gpu_concurrency,
+ fail_fast=args.fail_fast,
+ )
+
+ if args.mode == "gpu-dj":
+ logger.info("\n" + "=" * 60)
+ logger.info("Testing Direct GPU (DJ-matched pipeline)")
+ logger.info("=" * 60)
+ run_direct_gpu_test_dj_match(
+ jsonl_path,
+ num_shards=args.num_shards,
+ batch_size=10, # DJ CUDA default
+ gpu_concurrency=args.gpu_concurrency,
+ fail_fast=args.fail_fast,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pyproject.toml b/pyproject.toml
index 26519cbbc7..c88badaab3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -234,6 +234,7 @@ filterwarnings = [
]
[tool.uv]
+constraint-dependencies = ["kaleido==0.2.1"]
override-dependencies = [
"opencv-python; sys_platform == 'never'",
"opencv-python-headless; sys_platform == 'never'",
diff --git a/tests/core/executor/test_ray_executor_partitioned.py b/tests/core/executor/test_ray_executor_partitioned.py
index 9c6fec206b..2054f0c2a7 100644
--- a/tests/core/executor/test_ray_executor_partitioned.py
+++ b/tests/core/executor/test_ray_executor_partitioned.py
@@ -679,5 +679,221 @@ def test_dag_node_status_transitions(self):
self.assertEqual(executor.pipeline_dag.nodes[node_id]["status"], "completed")
+class ConcurrencyScopingTest(DataJuicerTestCaseBase):
+ """Unit tests for scope_op_concurrency utility."""
+
+ def test_gpu_op_scoping(self):
+ """GPU op concurrency is divided by max_concurrent_partitions."""
+ from unittest.mock import MagicMock
+ from data_juicer.core.executor.concurrency_scoping import scope_op_concurrency
+
+ op = MagicMock()
+ op.use_ray_actor.return_value = True
+ op.num_proc = 4
+ self.assertEqual(scope_op_concurrency(op, 4), 1)
+ self.assertEqual(scope_op_concurrency(op, 2), 2)
+ self.assertEqual(scope_op_concurrency(op, 1), 4)
+
+ def test_gpu_op_scoping_floor_min_one(self):
+ """Scoped concurrency never goes below 1."""
+ from unittest.mock import MagicMock
+ from data_juicer.core.executor.concurrency_scoping import scope_op_concurrency
+
+ op = MagicMock()
+ op.use_ray_actor.return_value = True
+ op.num_proc = 2
+ self.assertEqual(scope_op_concurrency(op, 8), 1)
+
+ def test_cpu_op_unchanged(self):
+ """CPU ops (use_ray_actor=False) are not scoped."""
+ from unittest.mock import MagicMock
+ from data_juicer.core.executor.concurrency_scoping import scope_op_concurrency
+
+ op = MagicMock()
+ op.use_ray_actor.return_value = False
+ op.num_proc = 4
+ self.assertEqual(scope_op_concurrency(op, 4), 4)
+
+ def test_auto_mode_unchanged(self):
+ """Auto-mode (num_proc <= 0) is not scoped."""
+ from unittest.mock import MagicMock
+ from data_juicer.core.executor.concurrency_scoping import scope_op_concurrency
+
+ op = MagicMock()
+ op.use_ray_actor.return_value = True
+ op.num_proc = -1
+ self.assertEqual(scope_op_concurrency(op, 4), -1)
+
+ def test_none_num_proc_unchanged(self):
+ """None num_proc is not scoped."""
+ from unittest.mock import MagicMock
+ from data_juicer.core.executor.concurrency_scoping import scope_op_concurrency
+
+ op = MagicMock()
+ op.use_ray_actor.return_value = True
+ op.num_proc = None
+ self.assertIsNone(scope_op_concurrency(op, 4))
+
+ def test_resolve_max_concurrent_explicit_int(self):
+ """Explicit int values are passed through."""
+ from data_juicer.core.executor.ray_executor_partitioned import PartitionedRayExecutor
+ self.assertEqual(PartitionedRayExecutor._resolve_max_concurrent(4), 4)
+ self.assertEqual(PartitionedRayExecutor._resolve_max_concurrent(1), 1)
+ # Minimum clamp to 1
+ self.assertEqual(PartitionedRayExecutor._resolve_max_concurrent(0), 1)
+
+ def test_resolve_max_concurrent_auto(self):
+ """'auto' resolves to GPU count or 1."""
+ from data_juicer.core.executor.ray_executor_partitioned import PartitionedRayExecutor
+ result = PartitionedRayExecutor._resolve_max_concurrent("auto")
+ self.assertIsInstance(result, int)
+ self.assertGreaterEqual(result, 1)
+
+
+class ConcurrentPartitionConfigTest(DataJuicerTestCaseBase):
+ """Tests for max_concurrent_partitions config parsing."""
+
+ root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..', '..')
+
+ def setUp(self) -> None:
+ super().setUp()
+ unique_name = f'test_concurrent_cfg_{uuid.uuid4().hex[:8]}'
+ self.tmp_dir = os.path.join(self.root_path, 'tmp', unique_name)
+ os.makedirs(self.tmp_dir, exist_ok=True)
+
+ def tearDown(self) -> None:
+ super().tearDown()
+ if os.path.exists(self.tmp_dir):
+ shutil.rmtree(self.tmp_dir)
+
+ @TEST_TAG('ray')
+ def test_default_max_concurrent_partitions_auto(self):
+ """Default max_concurrent_partitions is 'auto', resolved from GPU count."""
+ cfg = init_configs([
+ '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'),
+ '--partition.mode', 'manual',
+ '--partition.num_of_partitions', '4'
+ ])
+ cfg.export_path = os.path.join(self.tmp_dir, 'test_default_conc', 'res.jsonl')
+ cfg.work_dir = os.path.join(self.tmp_dir, 'test_default_conc')
+
+ executor = PartitionedRayExecutor(cfg)
+ # Auto-resolved: matches GPU count, or 1 if no GPUs
+ import ray as _ray
+ num_gpus = int(_ray.cluster_resources().get("GPU", 0))
+ if num_gpus > 1:
+ self.assertEqual(executor.max_concurrent_partitions, num_gpus)
+ else:
+ self.assertEqual(executor.max_concurrent_partitions, 1)
+
+ @TEST_TAG('ray')
+ def test_explicit_max_concurrent_partitions(self):
+ """Explicit max_concurrent_partitions is parsed correctly."""
+ cfg = init_configs([
+ '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'),
+ '--partition.mode', 'manual',
+ '--partition.num_of_partitions', '8',
+ '--partition.max_concurrent_partitions', '4'
+ ])
+ cfg.export_path = os.path.join(self.tmp_dir, 'test_explicit_conc', 'res.jsonl')
+ cfg.work_dir = os.path.join(self.tmp_dir, 'test_explicit_conc')
+
+ executor = PartitionedRayExecutor(cfg)
+ self.assertEqual(executor.max_concurrent_partitions, 4)
+
+ @TEST_TAG('ray')
+ def test_num_partitions_inferred_from_max_concurrent(self):
+ """num_of_partitions is raised to max_concurrent_partitions when too low."""
+ cfg = init_configs([
+ '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'),
+ '--partition.mode', 'manual',
+ '--partition.num_of_partitions', '2',
+ '--partition.max_concurrent_partitions', '8'
+ ])
+ cfg.export_path = os.path.join(self.tmp_dir, 'test_infer_partitions', 'res.jsonl')
+ cfg.work_dir = os.path.join(self.tmp_dir, 'test_infer_partitions')
+
+ executor = PartitionedRayExecutor(cfg)
+ # num_partitions should be raised to 8
+ self.assertEqual(executor.num_partitions, 8)
+ self.assertEqual(executor.max_concurrent_partitions, 8)
+
+ @TEST_TAG('ray')
+ def test_num_partitions_not_lowered(self):
+ """num_of_partitions is NOT lowered when already >= max_concurrent."""
+ cfg = init_configs([
+ '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'),
+ '--partition.mode', 'manual',
+ '--partition.num_of_partitions', '16',
+ '--partition.max_concurrent_partitions', '8'
+ ])
+ cfg.export_path = os.path.join(self.tmp_dir, 'test_no_lower', 'res.jsonl')
+ cfg.work_dir = os.path.join(self.tmp_dir, 'test_no_lower')
+
+ executor = PartitionedRayExecutor(cfg)
+ self.assertEqual(executor.num_partitions, 16)
+ self.assertEqual(executor.max_concurrent_partitions, 8)
+
+ @TEST_TAG('ray')
+ def test_concurrent_execution_end2end(self):
+ """End-to-end test: concurrent partitions produce output."""
+ cfg = init_configs([
+ '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'),
+ '--partition.mode', 'manual',
+ '--partition.num_of_partitions', '2',
+ '--partition.max_concurrent_partitions', '2'
+ ])
+ cfg.export_path = os.path.join(self.tmp_dir, 'test_conc_e2e', 'res.jsonl')
+ cfg.work_dir = os.path.join(self.tmp_dir, 'test_conc_e2e')
+
+ executor = PartitionedRayExecutor(cfg)
+ executor.run()
+
+ self.assertTrue(os.path.exists(cfg.export_path))
+
+ @TEST_TAG('ray')
+ def test_concurrent_with_checkpointing(self):
+ """Concurrent execution with checkpointing enabled."""
+ cfg = init_configs([
+ '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'),
+ '--partition.mode', 'manual',
+ '--partition.num_of_partitions', '2',
+ '--partition.max_concurrent_partitions', '2',
+ '--checkpoint.enabled', 'true',
+ '--checkpoint.strategy', 'every_op'
+ ])
+ cfg.export_path = os.path.join(self.tmp_dir, 'test_conc_ckpt', 'res.jsonl')
+ cfg.work_dir = os.path.join(self.tmp_dir, 'test_conc_ckpt')
+
+ executor = PartitionedRayExecutor(cfg)
+ executor.run()
+
+ self.assertTrue(os.path.exists(cfg.export_path))
+
+ # Verify checkpoint files were created
+ checkpoint_dir = cfg.checkpoint_dir
+ if os.path.exists(checkpoint_dir):
+ checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.parquet')]
+ self.assertGreater(len(checkpoint_files), 0, "No checkpoint files were created")
+
+ @TEST_TAG('ray')
+ def test_backward_compat_sequential(self):
+ """max_concurrent_partitions=1 uses sequential path (same as before)."""
+ cfg = init_configs([
+ '--config', os.path.join(self.root_path, 'demos/process_on_ray/configs/demo-new-config.yaml'),
+ '--partition.mode', 'manual',
+ '--partition.num_of_partitions', '2',
+ '--partition.max_concurrent_partitions', '1'
+ ])
+ cfg.export_path = os.path.join(self.tmp_dir, 'test_seq_compat', 'res.jsonl')
+ cfg.work_dir = os.path.join(self.tmp_dir, 'test_seq_compat')
+
+ executor = PartitionedRayExecutor(cfg)
+ self.assertEqual(executor.max_concurrent_partitions, 1)
+ executor.run()
+
+ self.assertTrue(os.path.exists(cfg.export_path))
+
+
if __name__ == '__main__':
unittest.main()