diff --git a/data_juicer/core/data/ray_dataset.py b/data_juicer/core/data/ray_dataset.py index 2d8b198565..ef353090ce 100644 --- a/data_juicer/core/data/ray_dataset.py +++ b/data_juicer/core/data/ray_dataset.py @@ -8,6 +8,7 @@ import ray from jsonargparse import Namespace from loguru import logger +from ray.data import ActorPoolStrategy from ray.data._internal.util import get_compute_strategy from data_juicer.core.data import DJDataset @@ -237,9 +238,16 @@ def process_batch_arrow(table: pyarrow.Table): try: if op.use_ray_actor(): - compute = get_compute_strategy(op.__class__, concurrency=op.num_proc) - self.data = self.data.map_batches( - op.__class__, + # Repartition right before GPU actor stage to ensure enough data blocks + # Pipeline: Read(streaming) → Repartition → GPU actors + # Note: override_num_blocks cannot be passed to map_batches for actors, + # so we repartition beforehand instead + override_num_blocks = getattr(op, 'override_num_blocks', None) + if override_num_blocks is not None: + self.data = self.data.repartition(override_num_blocks) + + compute = ActorPoolStrategy(size=op.num_proc) + map_batches_kwargs = dict( fn_args=None, fn_kwargs=None, fn_constructor_args=op._init_args, @@ -251,10 +259,10 @@ def process_batch_arrow(table: pyarrow.Table): batch_format="pyarrow", runtime_env=op.runtime_env, ) + self.data = self.data.map_batches(op.__class__, **map_batches_kwargs) else: compute = get_compute_strategy(op.process, concurrency=op.num_proc) - self.data = self.data.map_batches( - op.process, + map_batches_kwargs = dict( batch_size=batch_size, batch_format="pyarrow", num_cpus=op.num_cpus, @@ -262,6 +270,10 @@ def process_batch_arrow(table: pyarrow.Table): compute=compute, runtime_env=op.runtime_env, ) + override_num_blocks = getattr(op, 'override_num_blocks', None) + if override_num_blocks is not None: + map_batches_kwargs['override_num_blocks'] = override_num_blocks + self.data = self.data.map_batches(op.process, **map_batches_kwargs) finally: # Restore original process method if tracer and should_trace_op(tracer, op._name) and original_process: @@ -280,9 +292,16 @@ def process_batch_arrow(table: pyarrow.Table): ) cached_columns.add(Fields.stats) if op.use_ray_actor(): - compute = get_compute_strategy(op.__class__, concurrency=op.num_proc) - self.data = self.data.map_batches( - op.__class__, + # Repartition AFTER CPU preprocessing but BEFORE GPU actor stage + # Pipeline: Read(streaming) → CPU preprocessing(streaming) → Repartition → GPU actors + # This allows Read → CPU to stream freely without pipeline barriers, + # then repartition splits collapsed blocks into enough pieces for GPU utilization + override_num_blocks = getattr(op, 'override_num_blocks', None) + if override_num_blocks is not None: + self.data = self.data.repartition(override_num_blocks) + + compute = ActorPoolStrategy(size=op.num_proc) + map_batches_kwargs = dict( fn_args=None, fn_kwargs=None, fn_constructor_args=op._init_args, @@ -294,10 +313,10 @@ def process_batch_arrow(table: pyarrow.Table): batch_format="pyarrow", runtime_env=op.runtime_env, ) + self.data = self.data.map_batches(op.__class__, **map_batches_kwargs) else: compute = get_compute_strategy(op.compute_stats, concurrency=op.num_proc) - self.data = self.data.map_batches( - op.compute_stats, + map_batches_kwargs = dict( batch_size=batch_size, batch_format="pyarrow", num_cpus=op.num_cpus, @@ -305,6 +324,10 @@ def process_batch_arrow(table: pyarrow.Table): compute=compute, runtime_env=op.runtime_env, ) + override_num_blocks = getattr(op, 'override_num_blocks', None) + if override_num_blocks is not None: + map_batches_kwargs['override_num_blocks'] = override_num_blocks + self.data = self.data.map_batches(op.compute_stats, **map_batches_kwargs) if op.stats_export_path is not None: self.data.write_json(op.stats_export_path, force_ascii=False) # Wrap process method with tracer for sample-level collection diff --git a/data_juicer/core/elasticjuicer/__init__.py b/data_juicer/core/elasticjuicer/__init__.py new file mode 100644 index 0000000000..715b79fbd9 --- /dev/null +++ b/data_juicer/core/elasticjuicer/__init__.py @@ -0,0 +1,33 @@ +""" +ElasticJuicer: Adaptive Resource Scheduling for Data-Juicer + +A system that provides dynamic resource management and OOM prevention for +multimodal data processing pipelines. +""" + +__version__ = "0.1.0" + +# Core ElasticJuicer classes +from .elastic_juicer import ElasticJuicer +from .scheduler.scheduler_config import SchedulerConfig +from .scheduler.tower import Tower +from .scheduler.captain import Captain, CaptainPool +from .scheduler.micro_scheduler import MicroScheduler + + +# Lazy import for tuner (requires ray dependency) +def get_pbt_tuner(): + from .tuner.pbt_tuner import PBTTuner + return PBTTuner + + +__all__ = [ + "profiler", + "ElasticJuicer", + "SchedulerConfig", + "Tower", + "Captain", + "CaptainPool", + "MicroScheduler", + "get_pbt_tuner", +] diff --git a/data_juicer/core/elasticjuicer/elastic_juicer.py b/data_juicer/core/elasticjuicer/elastic_juicer.py new file mode 100644 index 0000000000..b2a85f3636 --- /dev/null +++ b/data_juicer/core/elasticjuicer/elastic_juicer.py @@ -0,0 +1,447 @@ +""" +ElasticJuicer: Top-level orchestrator for the complete ElasticJuicer flow. + +The complete flow consists of two phases: + OFFLINE: Ray Tune PBT to find optimal scheduling parameters + -> Outputs base_config.yaml + ONLINE: Adaptive Tower (Macro) + Captain (Micro) scheduling + -> Tower runs rebalance loop, Captains run PID+Pred micro-scheduling + -> Metrics feedback from Captains to Tower + +Usage: + # Complete flow + ej = ElasticJuicer() + ej.run(operators=[op1, op2, op3]) + + # Or step by step + ej = ElasticJuicer() + config = ej.run_offline_tuning(stage_names=["filter", "mapper"]) + ej.run_online(config=config, operators=[op1, op2, op3]) + + # Or skip offline, use existing config + ej = ElasticJuicer() + config = SchedulerConfig.from_yaml("base_config.yaml") + ej.run_online(config=config, operators=[op1, op2]) + + # Graceful shutdown + ej.stop() +""" + +import logging +from typing import Any, Callable, Dict, List, Optional + +from .scheduler.scheduler_config import SchedulerConfig +from .scheduler.tower import Tower, ClusterState +from .scheduler.captain import Captain, CaptainConfig, CaptainPool + +logger = logging.getLogger(__name__) + + +def _get_default_cluster_state() -> ClusterState: + """Create a default ClusterState based on current system resources. + + Returns: + ClusterState with detected or default resource values. + """ + try: + import psutil + + cpu_count = psutil.cpu_count(logical=True) or 4 + memory_info = psutil.virtual_memory() + total_memory_mb = memory_info.total / (1024 * 1024) + available_memory_mb = memory_info.available / (1024 * 1024) + + return ClusterState( + total_cpu_cores=cpu_count, + total_memory_mb=total_memory_mb, + total_gpu_count=0, # GPU detection requires additional libraries + available_cpu_cores=float(cpu_count), + available_memory_mb=available_memory_mb, + available_gpus=0.0, + ) + except ImportError: + # Fallback to sensible defaults if psutil not available + return ClusterState( + total_cpu_cores=4, + total_memory_mb=8192.0, + total_gpu_count=0, + available_cpu_cores=4.0, + available_memory_mb=6144.0, + available_gpus=0.0, + ) + + +class ElasticJuicer: + """ + Top-level orchestrator for the complete ElasticJuicer flow. + + ElasticJuicer combines OFFLINE (PBT hyperparameter tuning) and ONLINE + (adaptive bi-level scheduling) phases to provide automatic, + resource-efficient data processing. + + The bi-level scheduling architecture: + - Tower (macro-scheduler): Global resource allocation and rebalancing + - Captains (micro-schedulers): Per-operator batch size control with PID + + Attributes: + config: Current SchedulerConfig instance. + tower: Tower macro-scheduler instance (created during run_online). + captain_pool: CaptainPool managing per-operator Captains. + is_running: Whether the online phase is currently active. + + Example: + >>> # Complete flow with automatic tuning + >>> with ElasticJuicer() as ej: + ... ej.run(operators=[filter_op, mapper_op, dedup_op]) + + >>> # Skip tuning, use existing config + >>> ej = ElasticJuicer(config_path="base_config.yaml") + >>> ej.run_online(operators=[filter_op, mapper_op]) + >>> ej.stop() + """ + + def __init__( + self, + config: Optional[SchedulerConfig] = None, + config_path: Optional[str] = None, + cluster_state: Optional[ClusterState] = None, + ): + """Initialize ElasticJuicer. + + Args: + config: Pre-existing SchedulerConfig. If None, default config used. + config_path: Path to load config from YAML. Takes precedence over config. + cluster_state: Optional ClusterState for Tower. If None, auto-detected. + """ + # Load config from YAML if path provided, otherwise use provided config or default + if config_path is not None: + self._config = SchedulerConfig.from_yaml(config_path) + logger.info(f"Loaded config from {config_path}") + elif config is not None: + self._config = config + else: + self._config = SchedulerConfig() + + # Cluster state for Tower initialization + self._cluster_state = cluster_state + + # Runtime components (created during run_online) + self._tower: Optional[Tower] = None + self._captain_pool: Optional[CaptainPool] = None + self._captain_ids: Dict[str, str] = {} # stage_name -> captain_id mapping + + # State tracking + self._is_running: bool = False + + @property + def config(self) -> SchedulerConfig: + """Get the current SchedulerConfig.""" + return self._config + + @property + def tower(self) -> Optional[Tower]: + """Get the Tower macro-scheduler instance (None if not started).""" + return self._tower + + @property + def captain_pool(self) -> Optional[CaptainPool]: + """Get the CaptainPool managing all Captains (None if not started).""" + return self._captain_pool + + @property + def is_running(self) -> bool: + """Check if the online phase is currently running.""" + return self._is_running + + def run_offline_tuning( + self, + stage_names: Optional[List[str]] = None, + simulation_fn: Optional[Callable[[SchedulerConfig], Dict[str, float]]] = None, + num_samples: int = 8, + max_iterations: int = 50, + export_path: str = "base_config.yaml", + ) -> SchedulerConfig: + """OFFLINE Phase: Run Ray Tune PBT to find optimal params. + + Uses Population Based Training to optimize: + - PID controller parameters (kp, ki, kd) + - Memory safety buffers + - Predictor settings + - Per-stage resource allocation weights + + Args: + stage_names: Operator stage names to tune allocation weights for. + If None, an empty list is used (no per-stage weights). + simulation_fn: Custom simulation function. If None, uses default + simulation that tests batch processing with random memory. + num_samples: PBT population size (number of parallel trials). + max_iterations: Maximum training iterations per trial. + export_path: Path to save the tuned config YAML. + + Returns: + Optimized SchedulerConfig. + + Raises: + ImportError: If Ray Tune is not installed. + RuntimeError: If tuning fails. + """ + # Import locally to avoid errors when Ray is not installed + from .tuner.pbt_tuner import PBTTuner, PBTTunerConfig + + tuner_config = PBTTunerConfig( + num_samples=num_samples, + max_iterations=max_iterations, + stage_names=stage_names or [], + ) + + tuner = PBTTuner(config=tuner_config, simulation_fn=simulation_fn) + + logger.info( + f"Starting offline PBT tuning with {num_samples} samples, " + f"{max_iterations} iterations" + ) + + self._config = tuner.tune() + tuner.export_config(self._config, export_path) + + logger.info(f"Offline tuning complete. Config saved to {export_path}") + return self._config + + def run_online( + self, + operators: List[Any], + config: Optional[SchedulerConfig] = None, + ) -> None: + """ONLINE Phase: Start Adaptive Tower + Captains. + + Sets up the bi-level scheduling hierarchy: + 1. Creates Tower (macro-scheduler) with rebalance loop + 2. Creates Captains (micro-schedulers) for each operator + 3. Registers Captains with Tower + 4. Starts Tower rebalance loop + + Args: + operators: List of operator objects. Each should have a 'name' + attribute (or str representation will be used). + config: Optional config override. If None, uses self._config. + + Raises: + ValueError: If operators list is empty. + RuntimeError: If already running. + """ + # Validate inputs + if not operators: + raise ValueError("operators list cannot be empty") + + if self._is_running: + raise RuntimeError( + "ElasticJuicer is already running. Call stop() first." + ) + + # Update config if provided + if config is not None: + self._config = config + + # Get or create cluster state + cluster_state = self._cluster_state or _get_default_cluster_state() + + # Create Tower macro-scheduler + self._tower = Tower( + cluster_state=cluster_state, + target_queue_depth=100, + sla_latency_ms=5000.0, + update_interval_sec=self._config.rebalance_interval_sec, + config=self._config, + ) + + # Create CaptainPool + self._captain_pool = CaptainPool() + self._captain_ids.clear() + + # Register each operator as a stage with its own Captain + for op in operators: + op_name = getattr(op, 'name', None) + if op_name is None: + op_name = str(op) + + # Register stage in Tower (returns captain_id) + captain_id = self._tower.register_stage( + stage_name=op_name, + initial_parallelism=1, + ) + self._captain_ids[op_name] = captain_id + + # Create Captain config + captain_config = CaptainConfig( + stage_name=op_name, + initial_batch_size=self._config.initial_batch_size, + enable_micro_scheduler=self._config.enable_auto_adjust, + enable_prediction=self._config.enable_prediction, + ) + + # Create Captain via pool (adds to pool automatically) + captain = self._captain_pool.add_captain(captain_config) + + # Register Captain with Tower for metrics collection + quota broadcast + self._tower.register_captain(captain_id, captain) + + # Start Tower rebalance loop + self._tower.start() + self._is_running = True + + logger.info( + f"Online phase started with {len(operators)} operators: " + f"{[getattr(op, 'name', str(op)) for op in operators]}" + ) + + def run( + self, + operators: List[Any], + skip_offline: bool = False, + config_path: Optional[str] = None, + **offline_kwargs, + ) -> None: + """Complete flow: OFFLINE -> ONLINE. + + Runs both phases in sequence: + 1. OFFLINE: PBT tuning (unless skip_offline=True) + 2. ONLINE: Start adaptive scheduling + + Args: + operators: List of operator objects. + skip_offline: If True, skip PBT tuning and use existing config. + config_path: Path to load/save config. If skip_offline is False, + this is used as export_path. If skip_offline is True, this is + used to load an existing config. + **offline_kwargs: Additional kwargs passed to run_offline_tuning(). + Supported kwargs: simulation_fn, num_samples, max_iterations. + + Raises: + ValueError: If operators list is empty. + ImportError: If Ray not installed and skip_offline is False. + """ + if not operators: + raise ValueError("operators list cannot be empty") + + if not skip_offline: + # Extract stage names from operators + stage_names = [ + getattr(op, 'name', str(op)) for op in operators + ] + export_path = config_path or "base_config.yaml" + + self.run_offline_tuning( + stage_names=stage_names, + export_path=export_path, + **offline_kwargs, + ) + elif config_path: + # Load existing config + self._config = SchedulerConfig.from_yaml(config_path) + logger.info(f"Loaded existing config from {config_path}") + + # Start online phase + self.run_online(operators=operators) + + def stop(self) -> None: + """Graceful shutdown of all components. + + Stops the Tower rebalance loop and cleans up resources. + Safe to call even if not running. + """ + if self._tower is not None: + self._tower.stop() + + self._is_running = False + self._tower = None + self._captain_pool = None + self._captain_ids.clear() + + logger.info("ElasticJuicer stopped") + + def get_status(self) -> Dict[str, Any]: + """Get current system status. + + Returns: + Dictionary containing: + - is_running: Whether online phase is active + - config: String representation of current config + - tower_stats: Global stats from Tower (if running) + - captain_stats: Stats from all Captains (if running) + """ + status: Dict[str, Any] = { + "is_running": self._is_running, + "config": str(self._config), + } + + if self._tower is not None: + status["tower_stats"] = self._tower.get_global_stats() + + if self._captain_pool is not None: + status["captain_stats"] = self._captain_pool.get_all_stats() + + return status + + def get_captain(self, stage_name: str) -> Optional[Captain]: + """Get the Captain for a specific stage. + + Args: + stage_name: Name of the operator stage. + + Returns: + Captain instance for the stage, or None if not found. + """ + if self._captain_pool is None: + return None + return self._captain_pool.get_captain(stage_name) + + def update_config(self, **kwargs) -> None: + """Update configuration parameters dynamically. + + Only updates the config object. For runtime changes to take effect, + the Tower and Captains may need to be restarted. + + Args: + **kwargs: Configuration parameters to update. See SchedulerConfig + for available parameters. + + Example: + >>> ej.update_config(target_memory_utilization=0.9, pid_kp=0.8) + """ + from dataclasses import fields + + valid_fields = {f.name for f in fields(SchedulerConfig)} + + for key, value in kwargs.items(): + if key not in valid_fields: + logger.warning(f"Unknown config parameter ignored: {key}") + continue + setattr(self._config, key, value) + + logger.info(f"Config updated with: {kwargs}") + + def save_config(self, path: str) -> None: + """Save current configuration to a YAML file. + + Args: + path: Output file path for the YAML config. + """ + self._config.to_yaml(path) + logger.info(f"Config saved to {path}") + + # Context manager support + def __enter__(self) -> 'ElasticJuicer': + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: + """Context manager exit - ensures graceful shutdown.""" + self.stop() + return False + + def __repr__(self) -> str: + """String representation of ElasticJuicer instance.""" + return ( + f"ElasticJuicer(is_running={self._is_running}, " + f"stages={list(self._captain_ids.keys()) if self._captain_ids else []})" + ) diff --git a/data_juicer/core/elasticjuicer/predictor/__init__.py b/data_juicer/core/elasticjuicer/predictor/__init__.py new file mode 100644 index 0000000000..2faac3c81b --- /dev/null +++ b/data_juicer/core/elasticjuicer/predictor/__init__.py @@ -0,0 +1,18 @@ +""" +Memory Prediction Module + +Provides: +- Online learning models for memory prediction +- Feature extraction from data samples +- Prediction with confidence intervals +- Safety margin calculations +""" + +from .memory_predictor import MemoryPredictor, PredictionResult +from .feature_extractor import FeatureExtractor + +__all__ = [ + "MemoryPredictor", + "PredictionResult", + "FeatureExtractor", +] diff --git a/data_juicer/core/elasticjuicer/predictor/feature_extractor.py b/data_juicer/core/elasticjuicer/predictor/feature_extractor.py new file mode 100644 index 0000000000..10438ad87b --- /dev/null +++ b/data_juicer/core/elasticjuicer/predictor/feature_extractor.py @@ -0,0 +1,294 @@ +""" +Feature Extractor for Memory Prediction + +Extracts relevant features from data samples to predict memory usage: +- Text: length, num_tokens, special_chars +- Image: width, height, channels, format +- Video: resolution, frame_count, fps, duration +- Audio: sample_rate, duration, channels + +Based on Report Section 3.3 - Prediction Model +""" + +from typing import Dict, Any, List, Optional +from dataclasses import dataclass +import re + + +@dataclass +class SampleFeatures: + """Features extracted from a single sample""" + # Common features + batch_size: int = 1 + modality: str = "text" # text, image, video, audio, multimodal + + # Text features + text_length: Optional[int] = None + num_tokens: Optional[int] = None + + # Image features + image_width: Optional[int] = None + image_height: Optional[int] = None + image_channels: Optional[int] = None + num_images: Optional[int] = 0 + + # Video features + video_width: Optional[int] = None + video_height: Optional[int] = None + frame_count: Optional[int] = None + fps: Optional[float] = None + num_videos: Optional[int] = 0 + + # Audio features + audio_sample_rate: Optional[int] = None + audio_duration: Optional[float] = None + num_audios: Optional[int] = 0 + + # Derived features + total_pixels: Optional[int] = None # For images/videos + estimated_size_mb: Optional[float] = None # Rough size estimate + + def to_feature_vector(self) -> List[float]: + """Convert to numerical feature vector for ML models""" + features = [ + float(self.batch_size), + # Text + float(self.text_length or 0), + float(self.num_tokens or 0), + # Image + float(self.image_width or 0), + float(self.image_height or 0), + float(self.image_channels or 0), + float(self.num_images or 0), + # Video + float(self.video_width or 0), + float(self.video_height or 0), + float(self.frame_count or 0), + float(self.fps or 0), + float(self.num_videos or 0), + # Audio + float(self.audio_sample_rate or 0), + float(self.audio_duration or 0), + float(self.num_audios or 0), + # Derived + float(self.total_pixels or 0), + float(self.estimated_size_mb or 0), + ] + return features + + @staticmethod + def feature_names() -> List[str]: + """Get names of features in the vector""" + return [ + 'batch_size', + 'text_length', 'num_tokens', + 'image_width', 'image_height', 'image_channels', 'num_images', + 'video_width', 'video_height', 'frame_count', 'fps', 'num_videos', + 'audio_sample_rate', 'audio_duration', 'num_audios', + 'total_pixels', 'estimated_size_mb' + ] + + +class FeatureExtractor: + """ + Extracts memory-relevant features from Data-Juicer samples. + + Handles different modalities and data formats. + """ + + def __init__(self): + pass + + def extract_from_sample(self, sample: Dict[str, Any]) -> SampleFeatures: + """ + Extract features from a single sample. + + Args: + sample: Data-Juicer sample dictionary + + Returns: + SampleFeatures object + """ + features = SampleFeatures(batch_size=1) + + # Determine modality + has_text = bool('text' in sample and sample['text']) + has_images = bool('images' in sample and sample['images']) + has_videos = bool('videos' in sample and sample['videos']) + has_audios = bool('audios' in sample and sample['audios']) + + modality_count = sum([has_text, has_images, has_videos, has_audios]) + if modality_count > 1: + features.modality = "multimodal" + elif has_text: + features.modality = "text" + elif has_images: + features.modality = "image" + elif has_videos: + features.modality = "video" + elif has_audios: + features.modality = "audio" + + # Extract text features + if has_text: + text = sample['text'] + features.text_length = len(text) + # Simple tokenization (space-based) + features.num_tokens = len(text.split()) + + # Extract image features + if has_images: + images = sample['images'] + features.num_images = len(images) if isinstance(images, list) else 1 + # Try to get image metadata if available + if 'image_metadata' in sample: + meta = sample['image_metadata'] + if isinstance(meta, list) and meta: + meta = meta[0] # Use first image + features.image_width = meta.get('width') + features.image_height = meta.get('height') + features.image_channels = meta.get('channels', 3) + + if features.image_width and features.image_height: + features.total_pixels = features.image_width * features.image_height * features.num_images + + # Extract video features + if has_videos: + videos = sample['videos'] + features.num_videos = len(videos) if isinstance(videos, list) else 1 + # Try to get video metadata + if 'video_metadata' in sample: + meta = sample['video_metadata'] + if isinstance(meta, list) and meta: + meta = meta[0] # Use first video + features.video_width = meta.get('width') + features.video_height = meta.get('height') + features.frame_count = meta.get('frame_count') + features.fps = meta.get('fps') + + if features.video_width and features.video_height and features.frame_count: + features.total_pixels = (features.video_width * features.video_height * + features.frame_count * features.num_videos) + + # Extract audio features + if has_audios: + audios = sample['audios'] + features.num_audios = len(audios) if isinstance(audios, list) else 1 + if 'audio_metadata' in sample: + meta = sample['audio_metadata'] + if isinstance(meta, list) and meta: + meta = meta[0] + features.audio_sample_rate = meta.get('sample_rate') + features.audio_duration = meta.get('duration') + + # Estimate rough size in MB + features.estimated_size_mb = self._estimate_size(features) + + return features + + def extract_from_batch(self, batch: Dict[str, Any]) -> SampleFeatures: + """ + Extract features from a batched sample. + + Args: + batch: Batched data dictionary where values are lists + + Returns: + SampleFeatures object (aggregated) + """ + # Determine batch size + batch_size = 0 + for value in batch.values(): + if isinstance(value, list): + batch_size = len(value) + break + + if batch_size == 0: + # Not a batched format, treat as single + return self.extract_from_sample(batch) + + # Extract features from first sample and scale + first_sample = { + key: values[0] if isinstance(values, list) and values else values + for key, values in batch.items() + } + + features = self.extract_from_sample(first_sample) + features.batch_size = batch_size + + # Scale certain features + if features.estimated_size_mb: + features.estimated_size_mb *= batch_size + + return features + + def _estimate_size(self, features: SampleFeatures) -> float: + """ + Rough estimate of sample size in MB. + + This is a heuristic based on typical data sizes. + """ + size_mb = 0.0 + + # Text: ~1 byte per character + if features.text_length: + size_mb += features.text_length / (1024 * 1024) + + # Images: width * height * channels * bytes_per_pixel (typically 1-4) + if features.total_pixels and features.modality in ['image', 'multimodal']: + # Assume 3 bytes per pixel for RGB + size_mb += (features.total_pixels * 3) / (1024 * 1024) + + # Videos: similar but multiplied by frames + if features.total_pixels and features.modality == 'video': + # Videos in memory are often decoded to raw frames + size_mb += (features.total_pixels * 3) / (1024 * 1024) + + # Audio: sample_rate * duration * channels * bytes_per_sample + if features.audio_sample_rate and features.audio_duration: + # Assume 2 bytes per sample (16-bit), mono or stereo + channels = 2 + bytes_per_sample = 2 + size_mb += (features.audio_sample_rate * features.audio_duration * + channels * bytes_per_sample) / (1024 * 1024) + + return size_mb + + def analyze_batch_variance(self, batch: Dict[str, Any]) -> Dict[str, Any]: + """ + Analyze variance in a batch to detect skew. + + High variance indicates need for dynamic batching. + """ + if not any(isinstance(v, list) for v in batch.values()): + return {'variance': 0, 'requires_dynamic_batching': False} + + # Extract features for each sample in batch + batch_size = len(batch[next(iter(batch))]) + sizes = [] + + for i in range(batch_size): + sample = {k: (v[i] if isinstance(v, list) else v) for k, v in batch.items()} + features = self.extract_from_sample(sample) + if features.estimated_size_mb: + sizes.append(features.estimated_size_mb) + + if not sizes: + return {'variance': 0, 'requires_dynamic_batching': False} + + import numpy as np + variance = float(np.var(sizes)) + mean_size = float(np.mean(sizes)) + coef_variation = variance / mean_size if mean_size > 0 else 0 + + # High coefficient of variation suggests dynamic batching + requires_dynamic = coef_variation > 0.5 + + return { + 'variance': variance, + 'mean_size_mb': mean_size, + 'min_size_mb': float(np.min(sizes)), + 'max_size_mb': float(np.max(sizes)), + 'coef_variation': coef_variation, + 'requires_dynamic_batching': requires_dynamic, + } diff --git a/data_juicer/core/elasticjuicer/predictor/memory_predictor.py b/data_juicer/core/elasticjuicer/predictor/memory_predictor.py new file mode 100644 index 0000000000..81d8e585e1 --- /dev/null +++ b/data_juicer/core/elasticjuicer/predictor/memory_predictor.py @@ -0,0 +1,301 @@ +""" +Memory Predictor with Online Learning + +Predicts memory usage for operators based on sample features. +Uses online learning to adapt to changing data distributions. + +Based on: +- Autothrottle: Online learning for resource prediction +- Report Section 3.3: Prediction Model +""" + +from typing import Optional, List, Tuple +from dataclasses import dataclass +import numpy as np +from collections import deque + +from .feature_extractor import SampleFeatures, FeatureExtractor + + +@dataclass +class PredictionResult: + """Result of memory prediction""" + predicted_memory_mb: float + confidence_lower: float # Lower bound of confidence interval + confidence_upper: float # Upper bound of confidence interval + prediction_error_history: Optional[float] = None # Recent prediction error + + def get_safe_prediction(self, safety_margin: float = 0.9) -> float: + """ + Get conservative prediction with safety margin. + + Uses upper confidence bound to be safe. + """ + return self.confidence_upper / safety_margin + + +class MemoryPredictor: + """ + Online learning model for memory prediction. + + Features: + - Incremental learning from new observations + - Confidence intervals for predictions + - Automatic model retraining + - Handles different operator types + """ + + def __init__( + self, + op_name: str, + window_size: int = 100, + confidence_level: float = 0.95, + min_samples_for_prediction: int = 5, + ): + """ + Initialize memory predictor. + + Args: + op_name: Operator name + window_size: Number of recent samples to keep for online learning + confidence_level: Confidence level for prediction intervals (default 95%) + min_samples_for_prediction: Minimum samples needed before making predictions + """ + self.op_name = op_name + self.window_size = window_size + self.confidence_level = confidence_level + self.min_samples_for_prediction = min_samples_for_prediction + + # Online learning data + self.feature_history = deque(maxlen=window_size) + self.memory_history = deque(maxlen=window_size) + self.error_history = deque(maxlen=window_size) + + # Model parameters (online linear regression) + self.weights: Optional[np.ndarray] = None + self.intercept: float = 0.0 + + # Feature extractor + self.feature_extractor = FeatureExtractor() + + # Statistics + self.total_predictions = 0 + self.total_updates = 0 + + def observe(self, features: SampleFeatures, actual_memory_mb: float): + """ + Observe a new data point and update the model. + + This is the core of online learning - model adapts as new data arrives. + + Args: + features: Sample features + actual_memory_mb: Actual memory used + """ + feature_vec = np.array(features.to_feature_vector()) + + # Store observation + self.feature_history.append(feature_vec) + self.memory_history.append(actual_memory_mb) + self.total_updates += 1 + + # Calculate prediction error if we had a model + if self.weights is not None: + predicted = self._predict_from_vector(feature_vec) + error = abs(predicted - actual_memory_mb) + self.error_history.append(error) + + # Retrain model if we have enough samples + if len(self.feature_history) >= self.min_samples_for_prediction: + self._retrain_model() + + def predict(self, features: SampleFeatures) -> Optional[PredictionResult]: + """ + Predict memory usage for given features. + + Args: + features: Sample features + + Returns: + PredictionResult with prediction and confidence bounds, or None if not enough data + """ + if len(self.feature_history) < self.min_samples_for_prediction: + return None + + if self.weights is None: + return None + + feature_vec = np.array(features.to_feature_vector()) + predicted = self._predict_from_vector(feature_vec) + + # Calculate confidence interval based on recent errors + if self.error_history: + # Use standard deviation of recent errors + std_error = np.std(list(self.error_history)) + # For 95% confidence, use ~2 standard deviations + z_score = 1.96 if self.confidence_level == 0.95 else 2.58 + margin = z_score * std_error + + confidence_lower = max(0, predicted - margin) + confidence_upper = predicted + margin + avg_error = np.mean(list(self.error_history)) + else: + # No error history yet, use conservative estimate + confidence_lower = predicted * 0.8 + confidence_upper = predicted * 1.5 + avg_error = None + + self.total_predictions += 1 + + return PredictionResult( + predicted_memory_mb=predicted, + confidence_lower=confidence_lower, + confidence_upper=confidence_upper, + prediction_error_history=avg_error, + ) + + def predict_batch_memory( + self, + sample_features: SampleFeatures, + target_batch_size: int, + ) -> Optional[PredictionResult]: + """ + Predict memory for a specific batch size. + + Scales the prediction based on batch size. + """ + # Scale features to target batch size + scaled_features = SampleFeatures(**vars(sample_features)) + scale_factor = target_batch_size / sample_features.batch_size + + scaled_features.batch_size = target_batch_size + if scaled_features.estimated_size_mb: + scaled_features.estimated_size_mb *= scale_factor + + return self.predict(scaled_features) + + def recommend_batch_size( + self, + sample_features: SampleFeatures, + available_memory_mb: float, + safety_margin: float = 0.85, + ) -> int: + """ + Recommend safe batch size given available memory. + + Uses binary search to find maximum safe batch size. + + Args: + sample_features: Features of a single sample + available_memory_mb: Available memory in MB + safety_margin: Use this fraction of available memory (default 85%) + + Returns: + Recommended batch size + """ + target_memory = available_memory_mb * safety_margin + + # Binary search for optimal batch size + low, high = 1, 1000 + best_batch_size = 1 + + for _ in range(20): # Max 20 iterations + mid = (low + high) // 2 + prediction = self.predict_batch_memory(sample_features, mid) + + if prediction is None: + # Not enough data, return conservative estimate + return 1 + + predicted_mem = prediction.get_safe_prediction(safety_margin) + + if predicted_mem <= target_memory: + best_batch_size = mid + low = mid + 1 + else: + high = mid - 1 + + return max(1, best_batch_size) + + def _predict_from_vector(self, feature_vec: np.ndarray) -> float: + """Make prediction from feature vector""" + if self.weights is None: + return 0.0 + + prediction = np.dot(feature_vec, self.weights) + self.intercept + return max(0, prediction) # Memory can't be negative + + def _retrain_model(self): + """ + Retrain the model using recent observations. + + Uses online linear regression for efficiency. + """ + if len(self.feature_history) < self.min_samples_for_prediction: + return + + # Convert to arrays + X = np.array(list(self.feature_history)) + y = np.array(list(self.memory_history)) + + try: + # Add regularization to prevent overfitting + lambda_reg = 0.01 + n_features = X.shape[1] + + # Ridge regression: (X^T X + λI)^-1 X^T y + XtX = X.T @ X + Xty = X.T @ y + + # Add regularization + XtX_reg = XtX + lambda_reg * np.eye(n_features) + + # Solve for weights + self.weights = np.linalg.solve(XtX_reg, Xty) + + # Calculate intercept (for better fit) + self.intercept = np.mean(y - X @ self.weights) + + except np.linalg.LinAlgError: + # Singular matrix, fall back to simple mean + self.weights = np.zeros(X.shape[1]) + self.intercept = np.mean(y) + + def get_model_stats(self) -> dict: + """Get statistics about the model""" + stats = { + 'op_name': self.op_name, + 'total_updates': self.total_updates, + 'total_predictions': self.total_predictions, + 'samples_in_window': len(self.feature_history), + 'model_trained': self.weights is not None, + } + + if self.error_history: + stats['avg_prediction_error_mb'] = float(np.mean(list(self.error_history))) + stats['std_prediction_error_mb'] = float(np.std(list(self.error_history))) + + if self.memory_history: + stats['avg_memory_mb'] = float(np.mean(list(self.memory_history))) + stats['peak_memory_mb'] = float(np.max(list(self.memory_history))) + + return stats + + def export_model(self) -> dict: + """Export model parameters for serialization""" + return { + 'op_name': self.op_name, + 'weights': self.weights.tolist() if self.weights is not None else None, + 'intercept': self.intercept, + 'window_size': self.window_size, + 'total_updates': self.total_updates, + 'stats': self.get_model_stats(), + } + + def import_model(self, model_data: dict): + """Import model parameters""" + self.op_name = model_data['op_name'] + if model_data['weights'] is not None: + self.weights = np.array(model_data['weights']) + self.intercept = model_data['intercept'] + self.total_updates = model_data.get('total_updates', 0) diff --git a/data_juicer/core/elasticjuicer/profiler/__init__.py b/data_juicer/core/elasticjuicer/profiler/__init__.py new file mode 100644 index 0000000000..52d7f8aedd --- /dev/null +++ b/data_juicer/core/elasticjuicer/profiler/__init__.py @@ -0,0 +1,20 @@ +""" +Resource Profiling Module + +Provides: +- Lightweight resource monitoring for operators +- Operator Cost Signature (OCS) annotations +- Resource-throughput curve fitting +""" + +from .resource_monitor import ResourceMonitor, MonitoredOp +from .ocs_annotator import OCSAnnotator, OpCostSignature +from .profiling_store import ProfilingStore + +__all__ = [ + "ResourceMonitor", + "MonitoredOp", + "OCSAnnotator", + "OpCostSignature", + "ProfilingStore", +] diff --git a/data_juicer/core/elasticjuicer/profiler/ocs_annotator.py b/data_juicer/core/elasticjuicer/profiler/ocs_annotator.py new file mode 100644 index 0000000000..25621db4c3 --- /dev/null +++ b/data_juicer/core/elasticjuicer/profiler/ocs_annotator.py @@ -0,0 +1,287 @@ +""" +Operator Cost Signature (OCS) Annotator + +Provides semantic annotations for operators based on Alpa's operator modeling: +- Memory Locality: Device preference (CPU-Strong, GPU-Strong, Balanced) +- Transfer Cost: Data movement overhead (Low, Medium, High) +- Failure Cost: Recovery cost from OOM (Low, Medium, High) +- State-free: Whether operator can be safely retried + +Inspired by: +- Alpa's operator cost modeling +- ExoFlow's failure cost analysis +""" + +from enum import Enum +from dataclasses import dataclass, field +from typing import Dict, Optional, List +import json + + +class MemoryLocality(Enum): + """Device preference for operator execution""" + CPU_STRONG = "cpu_strong" # Strongly prefers CPU (e.g., regex, text filters) + GPU_STRONG = "gpu_strong" # Strongly prefers GPU (e.g., VLM, video decoding) + BALANCED = "balanced" # Can run efficiently on either + MIXED = "mixed" # Benefits from CPU-GPU cooperation + + +class TransferCost(Enum): + """Data movement overhead""" + LOW = "low" # < 1MB per sample (text, metadata) + MEDIUM = "medium" # 1-100MB per sample (images) + HIGH = "high" # > 100MB per sample (videos, large models) + + +class FailureCost(Enum): + """Recovery cost from failure""" + LOW = "low" # Fast retry, no state loss + MEDIUM = "medium" # Moderate retry cost + HIGH = "high" # Expensive recomputation (e.g., long video processing) + + +@dataclass +class OpCostSignature: + """ + Cost signature for an operator. + + This is the core of OCS profiling - semantic annotations that guide scheduling. + """ + op_name: str + op_type: str # filter, mapper, deduplicator, etc. + + # Core OCS attributes (based on Alpa) + memory_locality: MemoryLocality = MemoryLocality.BALANCED + transfer_cost: TransferCost = TransferCost.MEDIUM + failure_cost: FailureCost = FailureCost.MEDIUM + + # State properties (based on ExoFlow) + state_free: bool = True # Can be safely retried without side effects + deterministic: bool = True # Same input always produces same output + + # Resource preferences + preferred_batch_size: Optional[int] = None + min_memory_mb: Optional[float] = None + max_memory_mb: Optional[float] = None + + # Modality tags + handles_text: bool = False + handles_image: bool = False + handles_video: bool = False + handles_audio: bool = False + + # Additional metadata + notes: str = "" + + def to_dict(self) -> Dict: + """Convert to dictionary""" + return { + 'op_name': self.op_name, + 'op_type': self.op_type, + 'memory_locality': self.memory_locality.value, + 'transfer_cost': self.transfer_cost.value, + 'failure_cost': self.failure_cost.value, + 'state_free': self.state_free, + 'deterministic': self.deterministic, + 'preferred_batch_size': self.preferred_batch_size, + 'min_memory_mb': self.min_memory_mb, + 'max_memory_mb': self.max_memory_mb, + 'handles_text': self.handles_text, + 'handles_image': self.handles_image, + 'handles_video': self.handles_video, + 'handles_audio': self.handles_audio, + 'notes': self.notes, + } + + @classmethod + def from_dict(cls, data: Dict) -> 'OpCostSignature': + """Create from dictionary""" + return cls( + op_name=data['op_name'], + op_type=data['op_type'], + memory_locality=MemoryLocality(data.get('memory_locality', 'balanced')), + transfer_cost=TransferCost(data.get('transfer_cost', 'medium')), + failure_cost=FailureCost(data.get('failure_cost', 'medium')), + state_free=data.get('state_free', True), + deterministic=data.get('deterministic', True), + preferred_batch_size=data.get('preferred_batch_size'), + min_memory_mb=data.get('min_memory_mb'), + max_memory_mb=data.get('max_memory_mb'), + handles_text=data.get('handles_text', False), + handles_image=data.get('handles_image', False), + handles_video=data.get('handles_video', False), + handles_audio=data.get('handles_audio', False), + notes=data.get('notes', ''), + ) + + +class OCSAnnotator: + """ + Annotates operators with cost signatures. + + Provides pre-defined annotations for common Data-Juicer operators + and supports custom annotations. + """ + + def __init__(self): + self.signatures: Dict[str, OpCostSignature] = {} + self._load_default_signatures() + + def _load_default_signatures(self): + """Load default OCS signatures for common Data-Juicer operators""" + + # Text Filters - CPU Strong, Low Transfer, Low Failure + text_filter_ops = [ + 'TextLengthFilter', + 'AlphanumericFilter', + 'CharacterRepetitionFilter', + 'WordRepetitionFilter', + 'SpecialCharactersFilter', + ] + for op in text_filter_ops: + self.signatures[op] = OpCostSignature( + op_name=op, + op_type='filter', + memory_locality=MemoryLocality.CPU_STRONG, + transfer_cost=TransferCost.LOW, + failure_cost=FailureCost.LOW, + state_free=True, + deterministic=True, + handles_text=True, + notes="Lightweight text filter, CPU-bound" + ) + + # Image Operations - GPU Preferred, Medium Transfer + image_ops = [ + 'ImageFaceRatioFilter', + 'ImageAestheticFilter', + 'ImageNSFWFilter', + ] + for op in image_ops: + self.signatures[op] = OpCostSignature( + op_name=op, + op_type='filter', + memory_locality=MemoryLocality.GPU_STRONG, + transfer_cost=TransferCost.MEDIUM, + failure_cost=FailureCost.MEDIUM, + state_free=True, + deterministic=True, + handles_image=True, + notes="Image model inference, GPU-accelerated" + ) + + # Video Operations - GPU Strong, High Transfer, High Failure + video_ops = [ + 'VideoDecoder', + 'VideoCaptioning', + 'VideoActionRecognition', + ] + for op in video_ops: + self.signatures[op] = OpCostSignature( + op_name=op, + op_type='mapper', + memory_locality=MemoryLocality.GPU_STRONG, + transfer_cost=TransferCost.HIGH, + failure_cost=FailureCost.HIGH, + state_free=True, + deterministic=True, + handles_video=True, + notes="Heavy video processing, high memory requirement" + ) + + # Deduplicators - Mixed locality, variable cost + self.signatures['DocumentDeduplicator'] = OpCostSignature( + op_name='DocumentDeduplicator', + op_type='deduplicator', + memory_locality=MemoryLocality.CPU_STRONG, + transfer_cost=TransferCost.LOW, + failure_cost=FailureCost.HIGH, + state_free=False, # Maintains hash index + deterministic=True, + handles_text=True, + notes="Hash-based dedup, stateful index" + ) + + self.signatures['ImageDeduplicator'] = OpCostSignature( + op_name='ImageDeduplicator', + op_type='deduplicator', + memory_locality=MemoryLocality.MIXED, + transfer_cost=TransferCost.MEDIUM, + failure_cost=FailureCost.HIGH, + state_free=False, + deterministic=True, + handles_image=True, + notes="Image hash computation, benefits from GPU" + ) + + def annotate(self, op_name: str, signature: OpCostSignature): + """Add or update OCS signature for an operator""" + self.signatures[op_name] = signature + + def get_signature(self, op_name: str) -> Optional[OpCostSignature]: + """Get OCS signature for an operator""" + return self.signatures.get(op_name) + + def get_all_signatures(self) -> Dict[str, OpCostSignature]: + """Get all registered signatures""" + return dict(self.signatures) + + def export_to_file(self, filepath: str): + """Export signatures to JSON file""" + data = { + name: sig.to_dict() + for name, sig in self.signatures.items() + } + with open(filepath, 'w') as f: + json.dump(data, f, indent=2) + + def import_from_file(self, filepath: str): + """Import signatures from JSON file""" + with open(filepath, 'r') as f: + data = json.load(f) + + for name, sig_dict in data.items(): + self.signatures[name] = OpCostSignature.from_dict(sig_dict) + + def infer_signature(self, op_name: str, op_type: str, **hints) -> OpCostSignature: + """ + Infer OCS signature from operator name and hints. + + This provides a best-effort annotation for unknown operators. + """ + # Default values + locality = MemoryLocality.BALANCED + transfer = TransferCost.MEDIUM + failure = FailureCost.MEDIUM + + # Infer from name patterns + op_lower = op_name.lower() + + if 'video' in op_lower: + locality = MemoryLocality.GPU_STRONG + transfer = TransferCost.HIGH + failure = FailureCost.HIGH + elif 'image' in op_lower: + locality = MemoryLocality.GPU_STRONG + transfer = TransferCost.MEDIUM + elif 'text' in op_lower or 'word' in op_lower or 'character' in op_lower: + locality = MemoryLocality.CPU_STRONG + transfer = TransferCost.LOW + failure = FailureCost.LOW + + # Apply hints if provided + if 'accelerator' in hints and hints['accelerator'] == 'cuda': + locality = MemoryLocality.GPU_STRONG + + return OpCostSignature( + op_name=op_name, + op_type=op_type, + memory_locality=locality, + transfer_cost=transfer, + failure_cost=failure, + handles_text='text' in op_lower, + handles_image='image' in op_lower, + handles_video='video' in op_lower, + handles_audio='audio' in op_lower, + notes="Auto-inferred signature" + ) diff --git a/data_juicer/core/elasticjuicer/profiler/profiling_store.py b/data_juicer/core/elasticjuicer/profiler/profiling_store.py new file mode 100644 index 0000000000..7c4a598834 --- /dev/null +++ b/data_juicer/core/elasticjuicer/profiler/profiling_store.py @@ -0,0 +1,302 @@ +""" +Profiling Store + +Persistent storage and query interface for: +- Resource-throughput curves +- OCS signatures +- Historical performance data + +Supports online learning and model updating. +""" + +import json +import pickle +from pathlib import Path +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass, asdict +import numpy as np +from scipy.optimize import curve_fit + +from .resource_monitor import OpExecutionStats, ResourceSnapshot +from .ocs_annotator import OpCostSignature + + +@dataclass +class ResourceThroughputCurve: + """ + Resource-throughput relationship for an operator. + + Models T(r, b) where: + - T = throughput (samples/sec) + - r = resource allocation (memory, GPU) + - b = batch size + """ + op_name: str + # Curve parameters (fitted from data) + coefficients: Dict[str, float] + # Model type: 'linear', 'polynomial', 'power' + model_type: str = 'linear' + # Goodness of fit + r_squared: float = 0.0 + # Sample count used for fitting + n_samples: int = 0 + + def predict_throughput(self, batch_size: int, memory_mb: float) -> float: + """Predict throughput given batch size and memory""" + if self.model_type == 'linear': + # T = a * batch_size + b * memory + c + a = self.coefficients.get('batch_coef', 0) + b = self.coefficients.get('memory_coef', 0) + c = self.coefficients.get('intercept', 0) + return max(0, a * batch_size + b * memory_mb + c) + + elif self.model_type == 'power': + # T = a * batch_size^b + a = self.coefficients.get('scale', 1) + b = self.coefficients.get('power', 1) + return a * (batch_size ** b) + + return 0.0 + + def to_dict(self) -> Dict: + """Convert to dictionary""" + return asdict(self) + + @classmethod + def from_dict(cls, data: Dict) -> 'ResourceThroughputCurve': + """Create from dictionary""" + return cls(**data) + + +class ProfilingStore: + """ + Persistent store for operator profiling data. + + Provides: + - Storage and retrieval of execution stats + - Resource-throughput curve fitting + - Online model updates + - Query interface for schedulers + """ + + def __init__(self, storage_dir: str = "./elastic_juicer_profiles"): + self.storage_dir = Path(storage_dir) + self.storage_dir.mkdir(parents=True, exist_ok=True) + + # In-memory caches + self.execution_stats: Dict[str, OpExecutionStats] = {} + self.ocs_signatures: Dict[str, OpCostSignature] = {} + self.throughput_curves: Dict[str, ResourceThroughputCurve] = {} + + # Load existing data + self._load_all() + + def _load_all(self): + """Load all stored profiles""" + # Load execution stats + stats_file = self.storage_dir / "execution_stats.pkl" + if stats_file.exists(): + with open(stats_file, 'rb') as f: + self.execution_stats = pickle.load(f) + + # Load OCS signatures + ocs_file = self.storage_dir / "ocs_signatures.json" + if ocs_file.exists(): + with open(ocs_file, 'r') as f: + data = json.load(f) + self.ocs_signatures = { + name: OpCostSignature.from_dict(sig) + for name, sig in data.items() + } + + # Load throughput curves + curves_file = self.storage_dir / "throughput_curves.json" + if curves_file.exists(): + with open(curves_file, 'r') as f: + data = json.load(f) + self.throughput_curves = { + name: ResourceThroughputCurve.from_dict(curve) + for name, curve in data.items() + } + + def save_all(self): + """Persist all profiles to disk""" + # Save execution stats + stats_file = self.storage_dir / "execution_stats.pkl" + with open(stats_file, 'wb') as f: + pickle.dump(self.execution_stats, f) + + # Save OCS signatures + ocs_file = self.storage_dir / "ocs_signatures.json" + with open(ocs_file, 'w') as f: + data = { + name: sig.to_dict() + for name, sig in self.ocs_signatures.items() + } + json.dump(data, f, indent=2) + + # Save throughput curves + curves_file = self.storage_dir / "throughput_curves.json" + with open(curves_file, 'w') as f: + data = { + name: curve.to_dict() + for name, curve in self.throughput_curves.items() + } + json.dump(data, f, indent=2) + + def update_execution_stats(self, op_name: str, stats: OpExecutionStats): + """Update execution statistics for an operator""" + self.execution_stats[op_name] = stats + self._fit_throughput_curve(op_name, stats) + + def update_ocs_signature(self, op_name: str, signature: OpCostSignature): + """Update OCS signature for an operator""" + self.ocs_signatures[op_name] = signature + + def get_execution_stats(self, op_name: str) -> Optional[OpExecutionStats]: + """Get execution statistics for an operator""" + return self.execution_stats.get(op_name) + + def get_ocs_signature(self, op_name: str) -> Optional[OpCostSignature]: + """Get OCS signature for an operator""" + return self.ocs_signatures.get(op_name) + + def get_throughput_curve(self, op_name: str) -> Optional[ResourceThroughputCurve]: + """Get resource-throughput curve for an operator""" + return self.throughput_curves.get(op_name) + + def _fit_throughput_curve(self, op_name: str, stats: OpExecutionStats): + """ + Fit resource-throughput curve from execution statistics. + + Uses online learning approach (inspired by Autothrottle). + """ + if len(stats.snapshots) < 5: + # Not enough data points + return + + # Extract features and target + batch_sizes = np.array([s.batch_size for s in stats.snapshots]) + memories = np.array([s.memory_mb for s in stats.snapshots]) + throughputs = np.array([s.throughput for s in stats.snapshots]) + + # Filter out invalid data + valid_idx = throughputs > 0 + if valid_idx.sum() < 5: + return + + batch_sizes = batch_sizes[valid_idx] + memories = memories[valid_idx] + throughputs = throughputs[valid_idx] + + try: + # Try linear model first: T = a*batch + b*mem + c + X = np.column_stack([batch_sizes, memories, np.ones_like(batch_sizes)]) + coeffs, residuals, _, _ = np.linalg.lstsq(X, throughputs, rcond=None) + + # Calculate R² + ss_res = residuals[0] if len(residuals) > 0 else 0 + ss_tot = np.sum((throughputs - np.mean(throughputs)) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + + curve = ResourceThroughputCurve( + op_name=op_name, + coefficients={ + 'batch_coef': float(coeffs[0]), + 'memory_coef': float(coeffs[1]), + 'intercept': float(coeffs[2]), + }, + model_type='linear', + r_squared=float(r_squared), + n_samples=len(batch_sizes), + ) + + self.throughput_curves[op_name] = curve + + except Exception as e: + # Fitting failed, use simple mean + pass + + def predict_memory_for_batch(self, op_name: str, batch_size: int) -> Optional[float]: + """ + Predict memory usage for a given batch size. + + Based on historical data with online learning. + """ + stats = self.execution_stats.get(op_name) + if not stats or len(stats.snapshots) < 3: + return None + + # Simple linear regression: memory = a * batch_size + b + batch_sizes = np.array([s.batch_size for s in stats.snapshots]) + memories = np.array([s.memory_mb for s in stats.snapshots]) + + try: + # Fit linear model + coeffs = np.polyfit(batch_sizes, memories, deg=1) + predicted = coeffs[0] * batch_size + coeffs[1] + return float(predicted) + except Exception: + # Fall back to average + return float(np.mean(memories)) + + def get_safe_batch_size(self, op_name: str, available_memory_mb: float, + safety_margin: float = 0.9) -> int: + """ + Recommend safe batch size given available memory. + + Args: + op_name: Operator name + available_memory_mb: Available memory in MB + safety_margin: Use only this fraction of available memory (default 90%) + + Returns: + Recommended batch size + """ + stats = self.execution_stats.get(op_name) + if not stats or len(stats.snapshots) < 3: + return 1 # Conservative default + + # Find batch sizes and their memory usage + batch_sizes = np.array([s.batch_size for s in stats.snapshots]) + memories = np.array([s.memory_mb for s in stats.snapshots]) + + # Calculate memory per sample + mem_per_sample = memories / batch_sizes + avg_mem_per_sample = np.median(mem_per_sample) # Use median for robustness + + # Calculate safe batch size + target_memory = available_memory_mb * safety_margin + safe_batch = int(target_memory / avg_mem_per_sample) + + return max(1, safe_batch) + + def export_report(self, output_file: str): + """Export profiling report as markdown""" + lines = ["# ElasticJuicer Profiling Report\n"] + + lines.append("## Operator Execution Statistics\n") + for op_name, stats in sorted(self.execution_stats.items()): + lines.append(f"### {op_name}\n") + lines.append(f"- Total Samples: {stats.total_samples}") + lines.append(f"- Total Batches: {stats.total_batches}") + lines.append(f"- Avg Latency: {stats.avg_latency_ms:.2f} ms") + lines.append(f"- P95 Latency: {stats.p95_latency_ms:.2f} ms") + lines.append(f"- Avg Throughput: {stats.avg_throughput:.2f} samples/s") + lines.append(f"- Peak Memory: {stats.peak_memory_mb:.2f} MB") + if stats.peak_gpu_memory_mb: + lines.append(f"- Peak GPU Memory: {stats.peak_gpu_memory_mb:.2f} MB") + lines.append("") + + lines.append("\n## OCS Signatures\n") + for op_name, sig in sorted(self.ocs_signatures.items()): + lines.append(f"### {op_name}") + lines.append(f"- Type: {sig.op_type}") + lines.append(f"- Memory Locality: {sig.memory_locality.value}") + lines.append(f"- Transfer Cost: {sig.transfer_cost.value}") + lines.append(f"- Failure Cost: {sig.failure_cost.value}") + lines.append(f"- State Free: {sig.state_free}") + lines.append("") + + with open(output_file, 'w') as f: + f.writelines(line + '\n' for line in lines) diff --git a/data_juicer/core/elasticjuicer/profiler/resource_monitor.py b/data_juicer/core/elasticjuicer/profiler/resource_monitor.py new file mode 100644 index 0000000000..a93fc97cc1 --- /dev/null +++ b/data_juicer/core/elasticjuicer/profiler/resource_monitor.py @@ -0,0 +1,263 @@ +""" +Resource Monitor + +Lightweight monitoring for Data-Juicer operators to collect: +- Batch size +- Resource usage (CPU, GPU memory, RAM) +- Processing latency +- Throughput + +Based on Pollux-style agent monitoring. +""" + +import time +import psutil +import threading +from dataclasses import dataclass, field +from typing import Optional, Dict, List, Any +from collections import defaultdict +import numpy as np + +try: + import GPUtil + GPU_AVAILABLE = True +except ImportError: + GPU_AVAILABLE = False + + +@dataclass +class ResourceSnapshot: + """Single measurement of resource usage""" + timestamp: float + batch_size: int + # CPU metrics + cpu_percent: float + memory_mb: float + # GPU metrics (if available) + gpu_memory_mb: Optional[float] = None + gpu_utilization: Optional[float] = None + # Performance metrics + latency_ms: float = 0.0 + throughput: float = 0.0 # samples/sec + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary""" + return { + 'timestamp': self.timestamp, + 'batch_size': self.batch_size, + 'cpu_percent': self.cpu_percent, + 'memory_mb': self.memory_mb, + 'gpu_memory_mb': self.gpu_memory_mb, + 'gpu_utilization': self.gpu_utilization, + 'latency_ms': self.latency_ms, + 'throughput': self.throughput, + } + + +@dataclass +class OpExecutionStats: + """Aggregated statistics for an operator""" + op_name: str + total_samples: int = 0 + total_batches: int = 0 + avg_latency_ms: float = 0.0 + p95_latency_ms: float = 0.0 + p99_latency_ms: float = 0.0 + avg_throughput: float = 0.0 + avg_memory_mb: float = 0.0 + peak_memory_mb: float = 0.0 + avg_gpu_memory_mb: Optional[float] = None + peak_gpu_memory_mb: Optional[float] = None + snapshots: List[ResourceSnapshot] = field(default_factory=list) + + def update(self, snapshot: ResourceSnapshot): + """Update statistics with new snapshot""" + self.snapshots.append(snapshot) + self.total_samples += snapshot.batch_size + self.total_batches += 1 + + # Update averages + latencies = [s.latency_ms for s in self.snapshots] + self.avg_latency_ms = np.mean(latencies) + self.p95_latency_ms = np.percentile(latencies, 95) + self.p99_latency_ms = np.percentile(latencies, 99) + + throughputs = [s.throughput for s in self.snapshots if s.throughput > 0] + if throughputs: + self.avg_throughput = np.mean(throughputs) + + memories = [s.memory_mb for s in self.snapshots] + self.avg_memory_mb = np.mean(memories) + self.peak_memory_mb = max(memories) + + if snapshot.gpu_memory_mb is not None: + gpu_mems = [s.gpu_memory_mb for s in self.snapshots if s.gpu_memory_mb is not None] + if gpu_mems: + self.avg_gpu_memory_mb = np.mean(gpu_mems) + self.peak_gpu_memory_mb = max(gpu_mems) + + +class ResourceMonitor: + """ + Lightweight resource monitor for operators. + + Inspired by PolluxAgent - measures resource-throughput curves in real-time. + """ + + def __init__(self, enabled: bool = True): + self.enabled = enabled + self.stats_by_op: Dict[str, OpExecutionStats] = defaultdict(OpExecutionStats) + self._lock = threading.Lock() + self.process = psutil.Process() + + def measure_execution(self, op_name: str, batch_size: int): + """ + Context manager to measure operator execution. + + Usage: + with monitor.measure_execution("my_filter", batch_size=100): + # Process batch + result = op.process(batch) + """ + return ExecutionContext(self, op_name, batch_size) + + def record_snapshot(self, op_name: str, snapshot: ResourceSnapshot): + """Record a resource snapshot for an operator""" + if not self.enabled: + return + + with self._lock: + if op_name not in self.stats_by_op: + self.stats_by_op[op_name] = OpExecutionStats(op_name=op_name) + self.stats_by_op[op_name].update(snapshot) + + def get_stats(self, op_name: str) -> Optional[OpExecutionStats]: + """Get statistics for a specific operator""" + return self.stats_by_op.get(op_name) + + def get_all_stats(self) -> Dict[str, OpExecutionStats]: + """Get statistics for all operators""" + return dict(self.stats_by_op) + + def clear(self): + """Clear all collected statistics""" + with self._lock: + self.stats_by_op.clear() + + def _get_current_resources(self) -> Dict[str, Any]: + """Get current resource usage""" + cpu_percent = self.process.cpu_percent() + memory_mb = self.process.memory_info().rss / (1024 * 1024) + + gpu_memory_mb = None + gpu_utilization = None + + if GPU_AVAILABLE: + try: + gpus = GPUtil.getGPUs() + if gpus: + # Use first GPU for now + gpu = gpus[0] + gpu_memory_mb = gpu.memoryUsed + gpu_utilization = gpu.load * 100 + except Exception: + pass + + return { + 'cpu_percent': cpu_percent, + 'memory_mb': memory_mb, + 'gpu_memory_mb': gpu_memory_mb, + 'gpu_utilization': gpu_utilization, + } + + +class ExecutionContext: + """Context manager for measuring operator execution""" + + def __init__(self, monitor: ResourceMonitor, op_name: str, batch_size: int): + self.monitor = monitor + self.op_name = op_name + self.batch_size = batch_size + self.start_time = None + + def __enter__(self): + if self.monitor.enabled: + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self.monitor.enabled or self.start_time is None: + return + + # Calculate latency + end_time = time.time() + latency_s = end_time - self.start_time + latency_ms = latency_s * 1000 + + # Calculate throughput + throughput = self.batch_size / latency_s if latency_s > 0 else 0 + + # Get resource usage + resources = self.monitor._get_current_resources() + + # Create snapshot + snapshot = ResourceSnapshot( + timestamp=end_time, + batch_size=self.batch_size, + cpu_percent=resources['cpu_percent'], + memory_mb=resources['memory_mb'], + gpu_memory_mb=resources['gpu_memory_mb'], + gpu_utilization=resources['gpu_utilization'], + latency_ms=latency_ms, + throughput=throughput, + ) + + # Record snapshot + self.monitor.record_snapshot(self.op_name, snapshot) + + +class MonitoredOp: + """ + Wrapper to inject monitoring into Data-Juicer operators. + + Usage: + original_op = SomeFilter(**config) + monitored_op = MonitoredOp(original_op, monitor) + """ + + def __init__(self, operator, monitor: ResourceMonitor): + self.operator = operator + self.monitor = monitor + self.op_name = operator.__class__.__name__ + + def __getattr__(self, name): + """Delegate attribute access to wrapped operator""" + return getattr(self.operator, name) + + def process(self, *args, **kwargs): + """Wrap process method with monitoring""" + # Estimate batch size + batch_size = self._estimate_batch_size(args, kwargs) + + with self.monitor.measure_execution(self.op_name, batch_size): + return self.operator.process(*args, **kwargs) + + def compute_stats(self, *args, **kwargs): + """Wrap compute_stats method with monitoring (for filters)""" + batch_size = self._estimate_batch_size(args, kwargs) + + with self.monitor.measure_execution(f"{self.op_name}_stats", batch_size): + return self.operator.compute_stats(*args, **kwargs) + + def _estimate_batch_size(self, args, kwargs) -> int: + """Estimate batch size from arguments""" + # For single sample: return 1 + # For batched: try to extract from first argument (usually a dict/dataset) + if args: + sample = args[0] + if isinstance(sample, dict): + # Check if it's batched data + for value in sample.values(): + if isinstance(value, list): + return len(value) + return 1 diff --git a/data_juicer/core/elasticjuicer/scheduler/__init__.py b/data_juicer/core/elasticjuicer/scheduler/__init__.py new file mode 100644 index 0000000000..013666352b --- /dev/null +++ b/data_juicer/core/elasticjuicer/scheduler/__init__.py @@ -0,0 +1,22 @@ +""" +Scheduler Module + +Provides: +- Micro-Scheduler: JABAS-style PID control for batch size +- Macro-Scheduler: Tower/Captain bi-level architecture +""" + +from .micro_scheduler import MicroScheduler, PIDController, BatchSizeController +from .scheduler_config import SchedulerConfig +from .tower import Tower +from .captain import Captain, CaptainPool + +__all__ = [ + "MicroScheduler", + "PIDController", + "BatchSizeController", + "SchedulerConfig", + "Tower", + "Captain", + "CaptainPool", +] diff --git a/data_juicer/core/elasticjuicer/scheduler/captain.py b/data_juicer/core/elasticjuicer/scheduler/captain.py new file mode 100644 index 0000000000..ce485b05d6 --- /dev/null +++ b/data_juicer/core/elasticjuicer/scheduler/captain.py @@ -0,0 +1,496 @@ +""" +Captain: Local Per-Operator Scheduler for ElasticJuicer + +Based on Autothrottle's bi-level architecture, Captain is the local scheduler +that manages a single operator stage under Tower's global constraints. + +Key responsibilities: +1. Execute Micro-Scheduler (JABAS-style batch size control) within quota +2. Report metrics to Tower +3. Enforce resource quotas from Tower +4. Handle local OOM events and recovery +5. Coordinate with adjacent Captains in pipeline + +References: +- Autothrottle (NSDI 2024): Bi-level control for SLO-targeted microservices +- Report Section 5.1: Tower/Captain architecture +""" + +import time +from dataclasses import dataclass +from typing import Optional, Callable, List, Dict, TYPE_CHECKING +from collections import deque +import psutil + +if TYPE_CHECKING: + from .tower import StageMetrics + +from ..scheduler.micro_scheduler import MicroScheduler, BatchSizeController +from ..scheduler.tower import ResourceQuota, StageMetrics, TopologyMode +from ..profiler.resource_monitor import ResourceMonitor, ResourceSnapshot +from ..predictor.memory_predictor import MemoryPredictor +from ..predictor.feature_extractor import FeatureExtractor + + +@dataclass +class CaptainConfig: + """Configuration for Captain scheduler""" + stage_name: str + initial_batch_size: int = 32 + report_interval_sec: float = 1.0 # How often to report to Tower + quota_check_interval_sec: float = 0.5 # How often to check quota + enable_micro_scheduler: bool = True # Use JABAS-style control + enable_prediction: bool = True # Use memory prediction + emergency_backoff_ratio: float = 0.5 # OOM backoff ratio + + +class Captain: + """ + Local Per-Operator Scheduler (Captain from Autothrottle architecture) + + Captain manages a single operator stage, executing micro-scheduling decisions + (batch size adjustment) within the global constraints set by Tower. + + Key design: + - Tower sets "what to achieve" (target parallelism, resource quota, SLO) + - Captain decides "how to achieve it" (batch size, local optimization) + - Bi-level decoupling enables scalability and autonomy + """ + + def __init__( + self, + config: CaptainConfig, + tower_callback: Optional[Callable[[StageMetrics], None]] = None + ): + """ + Initialize Captain local scheduler + + Args: + config: Captain configuration + tower_callback: Callback function to report metrics to Tower + """ + self.config = config + self.tower_callback = tower_callback + + # Micro-scheduler for batch size control + if config.enable_micro_scheduler: + self.micro_scheduler = MicroScheduler( + initial_batch_size=config.initial_batch_size, + max_batch_size=1024, + min_batch_size=1 + ) + else: + self.micro_scheduler = None + + # Resource monitoring + self.monitor = ResourceMonitor() + + # Memory prediction + if config.enable_prediction: + self.predictor = MemoryPredictor(op_name=config.stage_name) + self.feature_extractor = FeatureExtractor() + else: + self.predictor = None + self.feature_extractor = None + + # Current resource quota from Tower + self.quota: Optional[ResourceQuota] = None + + # Current stage metrics + self.metrics = StageMetrics(stage_name=config.stage_name) + + # Queue simulation + self.queue: deque = deque() + + # Timing + self.last_report_time = time.time() + self.last_quota_check_time = time.time() + + # Processing statistics + self.samples_processed = 0 + self.total_latency_ms = 0.0 + self.latency_history = deque(maxlen=100) + self.throughput_history = deque(maxlen=100) + + # OOM tracking + self.oom_events = 0 + self.last_oom_time = 0.0 + self._total_oom_count: int = 0 # Cumulative OOM count for metrics reporting + + # Backpressure state + self._backpressure_active: bool = False + self._backpressure_slowdown: float = 0.5 # Default slowdown ratio + + # Metrics tracking fields for Tower consumption + self._recent_throughput: float = 0.0 # samples/sec from recent batches + self._recent_latency_ms: float = 0.0 # average latency from recent batches + self._current_cpu_util: float = 0.0 # latest CPU utilization + self._current_memory_util: float = 0.0 # latest memory utilization + self._current_gpu_util: float = 0.0 # latest GPU utilization + + def set_quota(self, quota: ResourceQuota): + """ + Receive resource quota from Tower + + Args: + quota: Resource allocation from Tower. May include a 'backpressure' + attribute to signal upstream throttling. + """ + self.quota = quota + + # Check for backpressure signal from Tower + if hasattr(quota, 'backpressure'): + self._backpressure_active = quota.backpressure + + # Update micro-scheduler constraints if quota changed + if self.micro_scheduler and quota.memory_quota_mb > 0: + # Adjust max batch size based on memory quota + # Rough estimate: 100MB per sample for typical multimodal data + estimated_max_batch = max(1, int(quota.memory_quota_mb / 100)) + self.micro_scheduler.controller.max_batch_size = min( + self.micro_scheduler.controller.max_batch_size, + estimated_max_batch + ) + + def enqueue_samples(self, samples: List): + """ + Add samples to processing queue + + Args: + samples: List of samples to process + """ + for sample in samples: + self.queue.append(sample) + + # Update queue depth metric + self.metrics.queue_depth = len(self.queue) + + def process_batch( + self, + operator_func: Callable, + sample_batch: Optional[List] = None + ) -> Optional[List]: + """ + Process a batch using the operator, with Captain's orchestration + + This is the core execution loop that: + 1. Gets batch size from micro-scheduler + 2. Dequeues samples + 3. Monitors execution + 4. Updates predictor and scheduler + 5. Checks quota compliance + + Args: + operator_func: The actual operator function to execute + sample_batch: Optional pre-formed batch (if None, dequeue from queue) + + Returns: + Processed results or None if queue empty + """ + start_time = time.time() + + # Get current batch size recommendation + if self.micro_scheduler: + current_batch_size = self.micro_scheduler.controller.current_batch_size + else: + current_batch_size = self.config.initial_batch_size + + # Apply backpressure throttling if active + if self._backpressure_active: + # Throttle: reduce effective batch size and add delay + current_batch_size = max(1, int(current_batch_size * self._backpressure_slowdown)) + time.sleep(0.1) # Small delay to reduce pressure on downstream + + # Dequeue samples if not provided + if sample_batch is None: + if len(self.queue) == 0: + return None + + actual_batch_size = min(current_batch_size, len(self.queue)) + sample_batch = [self.queue.popleft() for _ in range(actual_batch_size)] + else: + actual_batch_size = len(sample_batch) + + # Extract features for prediction (if enabled) + predicted_memory_mb = None + if self.predictor and self.feature_extractor and len(sample_batch) > 0: + features = self.feature_extractor.extract_from_sample(sample_batch[0]) + features.batch_size = actual_batch_size # Set batch size + prediction = self.predictor.predict(features) + if prediction: + predicted_memory_mb = prediction.predicted_memory_mb + + # Monitor execution + with self.monitor.measure_execution( + self.config.stage_name, + actual_batch_size + ): + try: + # Execute operator + results = operator_func(sample_batch) + + # Record success + self.samples_processed += actual_batch_size + + except MemoryError as e: + # OOM event - get approximate snapshot + snapshot_approx = ResourceSnapshot( + timestamp=time.time(), + batch_size=actual_batch_size, + cpu_percent=psutil.cpu_percent(), + memory_mb=psutil.virtual_memory().used / (1024 * 1024), + latency_ms=0 + ) + self._handle_oom(actual_batch_size, snapshot_approx) + raise + + # Get recorded stats + op_stats = self.monitor.get_stats(self.config.stage_name) + if op_stats and op_stats.snapshots: + snapshot = op_stats.snapshots[-1] # Get latest snapshot + # Update predictor with actual memory usage + if self.predictor and self.feature_extractor and len(sample_batch) > 0: + features = self.feature_extractor.extract_from_sample(sample_batch[0]) + self.predictor.observe(features, snapshot.memory_mb) + + # Update micro-scheduler + if self.micro_scheduler: + self.micro_scheduler.update( + actual_memory_used=snapshot.memory_mb, + sample_features=None # Already updated predictor above + ) + + # Update metrics + latency_ms = snapshot.latency_ms + self.total_latency_ms += latency_ms + self.latency_history.append(latency_ms) + + throughput = snapshot.throughput + self.throughput_history.append(throughput) + + self.metrics.avg_latency_ms = ( + sum(self.latency_history) / len(self.latency_history) + if self.latency_history else 0 + ) + self.metrics.throughput = ( + sum(self.throughput_history) / len(self.throughput_history) + if self.throughput_history else 0 + ) + self.metrics.cpu_utilization = snapshot.cpu_percent + self.metrics.memory_utilization = ( + (snapshot.memory_mb / self.quota.memory_quota_mb * 100) + if self.quota and self.quota.memory_quota_mb > 0 + else 0 + ) + self.metrics.gpu_utilization = snapshot.gpu_utilization or 0 + self.metrics.queue_depth = len(self.queue) + self.metrics.oom_count = self.oom_events + self.metrics.current_parallelism = 1 # Single-actor for now + + # Update internal metrics tracking fields for collect_metrics() + elapsed_time = time.time() - start_time + processed_count = actual_batch_size + self._recent_throughput = processed_count / elapsed_time if elapsed_time > 0 else 0 + self._recent_latency_ms = elapsed_time * 1000 / processed_count if processed_count > 0 else 0 + self._current_cpu_util = snapshot.cpu_percent + self._current_memory_util = ( + (snapshot.memory_mb / self.quota.memory_quota_mb * 100) + if self.quota and self.quota.memory_quota_mb > 0 + else 0 + ) + self._current_gpu_util = snapshot.gpu_utilization or 0 + + # Check if should report to Tower + current_time = time.time() + if current_time - self.last_report_time >= self.config.report_interval_sec: + self._report_to_tower() + self.last_report_time = current_time + + # Check quota compliance + if current_time - self.last_quota_check_time >= self.config.quota_check_interval_sec: + self._check_quota_compliance() + self.last_quota_check_time = current_time + + return results + + def _handle_oom(self, batch_size: int, snapshot: Optional[ResourceSnapshot]): + """ + Handle OOM event with emergency backoff + + Args: + batch_size: Batch size that caused OOM + snapshot: Resource snapshot at OOM time + """ + self.oom_events += 1 + self._total_oom_count += 1 # Increment cumulative OOM count for metrics + self.last_oom_time = time.time() + + # Emergency backoff + if self.micro_scheduler: + new_batch_size = max(1, batch_size // 2) + self.micro_scheduler.controller.current_batch_size = new_batch_size + self.micro_scheduler.controller.max_batch_size = batch_size + + # Update metrics + self.metrics.oom_count = self.oom_events + + def _report_to_tower(self): + """Report current metrics to Tower""" + if self.tower_callback: + # Update timestamp + self.metrics.last_update = time.time() + + # Send to Tower + self.tower_callback(self.metrics) + + def _check_quota_compliance(self): + """ + Check if current resource usage is within Tower's quota + + If exceeding quota, apply throttling + """ + if not self.quota: + return + + # Check memory quota + current_memory_mb = psutil.virtual_memory().used / (1024 * 1024) + if current_memory_mb > self.quota.memory_quota_mb: + # Exceeding memory quota, reduce batch size + if self.micro_scheduler: + reduction_ratio = self.quota.memory_quota_mb / current_memory_mb + new_batch_size = max( + 1, + int(self.micro_scheduler.controller.current_batch_size * reduction_ratio) + ) + self.micro_scheduler.controller.current_batch_size = new_batch_size + + def _get_available_memory_mb(self) -> float: + """Get available system memory in MB""" + return psutil.virtual_memory().available / (1024 * 1024) + + def get_stats(self) -> dict: + """Get Captain statistics""" + return { + 'stage_name': self.config.stage_name, + 'samples_processed': self.samples_processed, + 'queue_depth': len(self.queue), + 'current_batch_size': ( + self.micro_scheduler.controller.current_batch_size + if self.micro_scheduler else self.config.initial_batch_size + ), + 'avg_latency_ms': self.metrics.avg_latency_ms, + 'avg_throughput': self.metrics.throughput, + 'oom_events': self.oom_events, + 'quota': { + 'target_parallelism': self.quota.target_parallelism if self.quota else 1, + 'memory_quota_mb': self.quota.memory_quota_mb if self.quota else 0, + 'cpu_quota': self.quota.cpu_quota if self.quota else 0, + } if self.quota else None + } + + def collect_metrics(self) -> 'StageMetrics': + """Collect current metrics snapshot for Tower consumption. + + This is the standardized interface for Tower to pull metrics from Captain. + Returns a StageMetrics object containing the current state of this stage. + + Returns: + StageMetrics: A snapshot containing: + - stage_name: Name of this operator stage + - queue_depth: Number of pending samples in queue + - current_parallelism: Current number of actors (1 for single-actor) + - throughput: Recent throughput in samples/sec + - avg_latency_ms: Recent average processing latency + - cpu_utilization: Current CPU utilization percentage + - memory_utilization: Current memory utilization percentage + - gpu_utilization: Current GPU utilization percentage (if applicable) + - oom_count: Cumulative OOM event count + """ + # Use local import to avoid circular dependencies + from .tower import StageMetrics + + return StageMetrics( + stage_name=self.config.stage_name, + queue_depth=len(self.queue), + current_parallelism=self.metrics.current_parallelism, + throughput=self._recent_throughput if self._recent_throughput > 0 else self.metrics.throughput, + avg_latency_ms=self._recent_latency_ms if self._recent_latency_ms > 0 else self.metrics.avg_latency_ms, + cpu_utilization=self._current_cpu_util if self._current_cpu_util > 0 else self.metrics.cpu_utilization, + memory_utilization=self._current_memory_util if self._current_memory_util > 0 else self.metrics.memory_utilization, + gpu_utilization=self._current_gpu_util if self._current_gpu_util > 0 else self.metrics.gpu_utilization, + oom_count=self._total_oom_count, + last_update=time.time() + ) + + +class CaptainPool: + """ + Manages multiple Captains in a pipeline + + Coordinates execution across multiple stages, ensuring data flows + correctly and all Captains report to Tower. + """ + + def __init__(self, tower_callback: Optional[Callable[[StageMetrics], None]] = None): + """ + Initialize Captain pool + + Args: + tower_callback: Shared callback to Tower for all Captains + """ + self.tower_callback = tower_callback + self.captains: dict[str, Captain] = {} + + def add_captain(self, config: CaptainConfig) -> Captain: + """ + Add a new Captain to the pool + + Args: + config: Configuration for the Captain + + Returns: + The created Captain instance + """ + captain = Captain(config, self.tower_callback) + self.captains[config.stage_name] = captain + return captain + + def get_captain(self, stage_name: str) -> Optional[Captain]: + """Get Captain by stage name""" + return self.captains.get(stage_name) + + def set_quotas(self, quotas: dict[str, ResourceQuota]): + """ + Distribute quotas from Tower to all Captains + + Args: + quotas: Dict mapping captain_id to ResourceQuota + """ + for captain_id, quota in quotas.items(): + # Extract stage name from captain_id + stage_name = quota.captain_id.replace('captain_', '').rsplit('_', 1)[0] + + if stage_name in self.captains: + self.captains[stage_name].set_quota(quota) + + def get_all_stats(self) -> dict[str, dict]: + """Get statistics from all Captains""" + return { + name: captain.get_stats() + for name, captain in self.captains.items() + } + + def collect_all_metrics(self) -> Dict[str, 'StageMetrics']: + """Collect metrics from all managed Captains. + + This is the standardized interface for Tower to pull metrics from all + Captains in the pool at once. + + Returns: + Dict[str, StageMetrics]: A dictionary mapping captain stage names + to their current StageMetrics snapshots. + """ + return { + stage_name: captain.collect_metrics() + for stage_name, captain in self.captains.items() + } diff --git a/data_juicer/core/elasticjuicer/scheduler/micro_scheduler.py b/data_juicer/core/elasticjuicer/scheduler/micro_scheduler.py new file mode 100644 index 0000000000..b62c1a889e --- /dev/null +++ b/data_juicer/core/elasticjuicer/scheduler/micro_scheduler.py @@ -0,0 +1,543 @@ +""" +Micro-Scheduler with JABAS-style PID Control + +Implements dynamic batch size adjustment based on memory feedback. +Prevents OOM by continuously monitoring memory and adjusting batch sizes. + +Based on: +- JABAS (EuroSys 2025): Adaptive batching for heterogeneous GPUs +- Report Section 4.1: PID Control for Batch Size + +Key Features: +- PID controller for smooth batch size adjustment +- Memory pressure monitoring +- Safety thresholds and fallback strategies +- Integration with MemoryPredictor +""" + +import time +import psutil +from typing import Optional, Dict, Any, Callable +from dataclasses import dataclass +from collections import deque +import numpy as np + +try: + import GPUtil + GPU_AVAILABLE = True +except ImportError: + GPU_AVAILABLE = False + + +@dataclass +class MemoryState: + """Current memory state""" + timestamp: float + # CPU memory + total_memory_mb: float + used_memory_mb: float + available_memory_mb: float + memory_percent: float + # GPU memory (if available) + gpu_total_mb: Optional[float] = None + gpu_used_mb: Optional[float] = None + gpu_available_mb: Optional[float] = None + gpu_percent: Optional[float] = None + + def get_available_memory(self, use_gpu: bool = False) -> float: + """Get available memory in MB""" + if use_gpu and self.gpu_available_mb is not None: + return self.gpu_available_mb + return self.available_memory_mb + + +class PIDController: + """ + PID (Proportional-Integral-Derivative) Controller. + + Classic control theory algorithm used in JABAS for smooth adjustments. + + Formula: + output(t) = Kp * error(t) + Ki * Σerror + Kd * Δerror + + Where: + - error = setpoint - current_value + - Kp, Ki, Kd are tuning parameters + """ + + def __init__( + self, + kp: float = 1.0, + ki: float = 0.1, + kd: float = 0.05, + setpoint: float = 1000.0, + output_limits: tuple = (1, 1000), + ): + """ + Initialize PID controller. + + Args: + kp: Proportional gain + ki: Integral gain + kd: Derivative gain + setpoint: Target value (e.g., target available memory in MB) + output_limits: (min, max) bounds for output + """ + self.kp = kp + self.ki = ki + self.kd = kd + self.setpoint = setpoint + self.output_limits = output_limits + + # State + self.last_error = 0.0 + self.integral = 0.0 + self.last_time = None + + def update(self, current_value: float, dt: Optional[float] = None) -> float: + """ + Update PID controller with new measurement. + + Args: + current_value: Current measured value + dt: Time delta since last update (optional) + + Returns: + Control output + """ + # Calculate error + error = self.setpoint - current_value + + # Calculate time delta + current_time = time.time() + if self.last_time is None or dt is not None: + dt = dt or 0.1 # Default dt + else: + dt = current_time - self.last_time + self.last_time = current_time + + # Proportional term + p_term = self.kp * error + + # Integral term (with anti-windup) + self.integral += error * dt + # Clamp integral to prevent windup + max_integral = self.output_limits[1] / (self.ki + 1e-6) + self.integral = np.clip(self.integral, -max_integral, max_integral) + i_term = self.ki * self.integral + + # Derivative term + derivative = (error - self.last_error) / (dt + 1e-6) + d_term = self.kd * derivative + + # Calculate output + output = p_term + i_term + d_term + + # Apply output limits + output = np.clip(output, self.output_limits[0], self.output_limits[1]) + + # Update state + self.last_error = error + + return output + + def reset(self): + """Reset controller state""" + self.last_error = 0.0 + self.integral = 0.0 + self.last_time = None + + def set_setpoint(self, setpoint: float): + """Update setpoint""" + self.setpoint = setpoint + + +class BatchSizeController: + """ + Controls batch size using PID feedback based on memory pressure. + + This is the core of the micro-scheduler - it continuously monitors + memory and adjusts batch size to maximize throughput while preventing OOM. + + Strategy (from Report Section 4.1): + B_next = B_curr × (M_target / M_curr) + + Enhanced with PID for smoothness. + """ + + def __init__( + self, + initial_batch_size: int = 1, + min_batch_size: int = 1, + max_batch_size: int = 1000, + target_memory_utilization: float = 0.85, + safety_buffer_mb: float = 1000.0, + use_gpu: bool = False, + enable_prediction: bool = True, + memory_predictor = None, + ): + """ + Initialize batch size controller. + + Args: + initial_batch_size: Starting batch size + min_batch_size: Minimum allowed batch size + max_batch_size: Maximum allowed batch size + target_memory_utilization: Target memory usage (0.0-1.0) + safety_buffer_mb: Safety buffer to keep free (MB) + use_gpu: Monitor GPU memory instead of CPU + enable_prediction: Use MemoryPredictor for proactive adjustment + memory_predictor: MemoryPredictor instance + """ + self.current_batch_size = initial_batch_size + self.min_batch_size = min_batch_size + self.max_batch_size = max_batch_size + self.target_utilization = target_memory_utilization + self.safety_buffer_mb = safety_buffer_mb + self.use_gpu = use_gpu + self.enable_prediction = enable_prediction + self.memory_predictor = memory_predictor + + # PID controller for smooth adjustments + # Setpoint will be dynamically updated based on available memory + self.pid = PIDController( + kp=0.5, # Moderate proportional gain + ki=0.05, # Small integral gain + kd=0.1, # Small derivative gain + setpoint=safety_buffer_mb, + output_limits=(min_batch_size, max_batch_size), + ) + + # History + self.batch_size_history = deque(maxlen=100) + self.memory_history = deque(maxlen=100) + self.oom_events = [] + + # Statistics + self.total_adjustments = 0 + self.increase_count = 0 + self.decrease_count = 0 + + def get_memory_state(self) -> MemoryState: + """Get current memory state""" + # CPU memory + mem = psutil.virtual_memory() + total_mb = mem.total / (1024 * 1024) + used_mb = mem.used / (1024 * 1024) + available_mb = mem.available / (1024 * 1024) + percent = mem.percent + + # GPU memory + gpu_total = None + gpu_used = None + gpu_available = None + gpu_percent = None + + if self.use_gpu and GPU_AVAILABLE: + try: + gpus = GPUtil.getGPUs() + if gpus: + gpu = gpus[0] # Use first GPU + gpu_total = gpu.memoryTotal + gpu_used = gpu.memoryUsed + gpu_available = gpu.memoryFree + gpu_percent = (gpu_used / gpu_total * 100) if gpu_total > 0 else 0 + except Exception: + pass + + return MemoryState( + timestamp=time.time(), + total_memory_mb=total_mb, + used_memory_mb=used_mb, + available_memory_mb=available_mb, + memory_percent=percent, + gpu_total_mb=gpu_total, + gpu_used_mb=gpu_used, + gpu_available_mb=gpu_available, + gpu_percent=gpu_percent, + ) + + def calculate_next_batch_size( + self, + memory_state: MemoryState, + predicted_memory_per_sample: Optional[float] = None, + ) -> int: + """ + Calculate next batch size using PID control and predictions. + + Args: + memory_state: Current memory state + predicted_memory_per_sample: Predicted memory per sample (optional) + + Returns: + Recommended batch size + """ + available_mb = memory_state.get_available_memory(self.use_gpu) + + # Method 1: Direct calculation based on available memory + # More memory available -> larger batch size + if available_mb > self.safety_buffer_mb: + usable_memory = available_mb - self.safety_buffer_mb + # Scale batch size proportionally to usable memory + # Normalize to a reasonable range + memory_based_batch = int((usable_memory / 1000.0) * self.max_batch_size) + memory_based_batch = np.clip(memory_based_batch, self.min_batch_size, self.max_batch_size) + else: + # Below safety buffer, use minimum + memory_based_batch = self.min_batch_size + + # Method 2: Ratio-based adjustment (from JABAS paper) + # B_next = B_curr × (M_target / M_curr) + total_mb = memory_state.total_memory_mb if not self.use_gpu else (memory_state.gpu_total_mb or 1000) + target_used = total_mb * self.target_utilization + current_used = total_mb - available_mb + + if current_used > 0: + ratio = target_used / current_used + # Clamp ratio to prevent extreme changes + ratio = np.clip(ratio, 0.5, 2.0) + ratio_batch_size = int(self.current_batch_size * ratio) + else: + ratio_batch_size = self.current_batch_size + + # Method 3: Prediction-based adjustment (if predictor available) + prediction_batch_size = None + if self.enable_prediction and predicted_memory_per_sample: + # Calculate how many samples can fit + usable_memory = available_mb - self.safety_buffer_mb + if usable_memory > 0 and predicted_memory_per_sample > 0: + prediction_batch_size = int(usable_memory / predicted_memory_per_sample) + + # Combine methods (weighted average) + candidates = [] + weights = [] + + candidates.append(memory_based_batch) + weights.append(0.4) # 40% weight on memory-based + + candidates.append(ratio_batch_size) + weights.append(0.3) # 30% weight on ratio + + if prediction_batch_size is not None: + candidates.append(prediction_batch_size) + weights.append(0.3) # 30% weight on prediction + + # Weighted average + next_batch_size = int(np.average(candidates, weights=weights)) + + # Apply bounds + next_batch_size = np.clip(next_batch_size, self.min_batch_size, self.max_batch_size) + + # Smooth changes (avoid drastic jumps) + max_change = max(1, int(self.current_batch_size * 0.5)) # Max 50% change per step + if abs(next_batch_size - self.current_batch_size) > max_change: + if next_batch_size > self.current_batch_size: + next_batch_size = self.current_batch_size + max_change + else: + next_batch_size = self.current_batch_size - max_change + + return int(next_batch_size) + + def update_batch_size( + self, + actual_memory_used: Optional[float] = None, + predicted_memory_per_sample: Optional[float] = None, + ) -> int: + """ + Update batch size based on current memory state. + + Args: + actual_memory_used: Actual memory used by last batch (for feedback) + predicted_memory_per_sample: Predicted memory per sample + + Returns: + New batch size + """ + # Get current memory state + memory_state = self.get_memory_state() + + # Calculate next batch size + next_batch_size = self.calculate_next_batch_size( + memory_state, + predicted_memory_per_sample, + ) + + # Update statistics + if next_batch_size > self.current_batch_size: + self.increase_count += 1 + elif next_batch_size < self.current_batch_size: + self.decrease_count += 1 + + if next_batch_size != self.current_batch_size: + self.total_adjustments += 1 + + # Update current batch size + old_batch_size = self.current_batch_size + self.current_batch_size = next_batch_size + + # Record history + self.batch_size_history.append({ + 'timestamp': time.time(), + 'old_batch': old_batch_size, + 'new_batch': next_batch_size, + 'available_mb': memory_state.get_available_memory(self.use_gpu), + 'memory_percent': memory_state.memory_percent, + }) + + self.memory_history.append(memory_state) + + return next_batch_size + + def report_oom(self, batch_size: int, memory_mb: float): + """Report an OOM event to adjust strategy""" + self.oom_events.append({ + 'timestamp': time.time(), + 'batch_size': batch_size, + 'memory_mb': memory_mb, + }) + + # Emergency reduction + self.current_batch_size = max(1, batch_size // 2) + self.max_batch_size = batch_size # Don't go higher than OOM point + + # Reset PID to avoid windup + self.pid.reset() + + def get_stats(self) -> Dict[str, Any]: + """Get controller statistics""" + return { + 'current_batch_size': self.current_batch_size, + 'min_batch_size': self.min_batch_size, + 'max_batch_size': self.max_batch_size, + 'total_adjustments': self.total_adjustments, + 'increase_count': self.increase_count, + 'decrease_count': self.decrease_count, + 'oom_events': len(self.oom_events), + 'avg_batch_size': np.mean([h['new_batch'] for h in self.batch_size_history]) if self.batch_size_history else 0, + } + + +class MicroScheduler: + """ + Micro-Scheduler with JABAS-style adaptive batching. + + Orchestrates: + - Memory monitoring + - Batch size control via PID + - Memory prediction integration + - OOM prevention + + Usage: + scheduler = MicroScheduler(memory_predictor=predictor) + + for batch in data_loader: + # Get recommended batch size + batch_size = scheduler.get_batch_size() + + # Process batch + result = process(batch[:batch_size]) + + # Update scheduler with feedback + scheduler.update(actual_memory_used=memory_mb) + """ + + def __init__( + self, + memory_predictor=None, + initial_batch_size: int = 32, + min_batch_size: int = 1, + max_batch_size: int = 1000, + target_memory_utilization: float = 0.85, + safety_buffer_mb: float = 1000.0, + use_gpu: bool = False, + enable_auto_adjust: bool = True, + ): + """ + Initialize micro-scheduler. + + Args: + memory_predictor: MemoryPredictor instance + initial_batch_size: Starting batch size + min_batch_size: Minimum batch size + max_batch_size: Maximum batch size + target_memory_utilization: Target memory usage (0.0-1.0) + safety_buffer_mb: Safety buffer in MB + use_gpu: Monitor GPU memory + enable_auto_adjust: Enable automatic batch size adjustment + """ + self.memory_predictor = memory_predictor + self.enable_auto_adjust = enable_auto_adjust + + # Batch size controller + self.controller = BatchSizeController( + initial_batch_size=initial_batch_size, + min_batch_size=min_batch_size, + max_batch_size=max_batch_size, + target_memory_utilization=target_memory_utilization, + safety_buffer_mb=safety_buffer_mb, + use_gpu=use_gpu, + enable_prediction=memory_predictor is not None, + memory_predictor=memory_predictor, + ) + + # State + self.iteration = 0 + self.last_prediction = None + + def get_batch_size(self, sample_features=None) -> int: + """ + Get recommended batch size for next iteration. + + Args: + sample_features: Optional sample features for prediction + + Returns: + Recommended batch size + """ + if not self.enable_auto_adjust: + return self.controller.current_batch_size + + # Get memory prediction if available + predicted_per_sample = None + if self.memory_predictor and sample_features: + prediction = self.memory_predictor.predict(sample_features) + if prediction: + self.last_prediction = prediction + # Estimate per-sample memory + predicted_per_sample = prediction.predicted_memory_mb / sample_features.batch_size + + # Update batch size + new_batch_size = self.controller.update_batch_size( + predicted_memory_per_sample=predicted_per_sample, + ) + + self.iteration += 1 + return new_batch_size + + def update(self, actual_memory_used: float, sample_features=None): + """ + Update scheduler with feedback from actual execution. + + Args: + actual_memory_used: Actual memory used in MB + sample_features: Sample features (for predictor update) + """ + # Update memory predictor if available + if self.memory_predictor and sample_features: + self.memory_predictor.observe(sample_features, actual_memory_used) + + def report_oom(self, batch_size: int, memory_mb: float): + """Report OOM event""" + self.controller.report_oom(batch_size, memory_mb) + + def get_stats(self) -> Dict[str, Any]: + """Get scheduler statistics""" + stats = self.controller.get_stats() + stats['iteration'] = self.iteration + if self.last_prediction: + stats['last_prediction'] = { + 'predicted_mb': self.last_prediction.predicted_memory_mb, + 'confidence_lower': self.last_prediction.confidence_lower, + 'confidence_upper': self.last_prediction.confidence_upper, + } + return stats diff --git a/data_juicer/core/elasticjuicer/scheduler/scheduler_config.py b/data_juicer/core/elasticjuicer/scheduler/scheduler_config.py new file mode 100644 index 0000000000..9efd0c4de9 --- /dev/null +++ b/data_juicer/core/elasticjuicer/scheduler/scheduler_config.py @@ -0,0 +1,147 @@ +""" +Scheduler Configuration + +Centralized configuration for micro and macro schedulers. +""" + +from dataclasses import dataclass, field, asdict +from typing import Dict, Optional + +try: + import yaml + YAML_AVAILABLE = True +except ImportError: + YAML_AVAILABLE = False + + +@dataclass +class SchedulerConfig: + """Configuration for ElasticJuicer schedulers""" + + # Batch size control + initial_batch_size: int = 32 + min_batch_size: int = 1 + max_batch_size: int = 1000 + + # Memory management + target_memory_utilization: float = 0.85 # 85% utilization target + safety_buffer_mb: float = 1000.0 # 1GB safety buffer + use_gpu_memory: bool = False + + # PID tuning + pid_kp: float = 0.5 # Proportional gain + pid_ki: float = 0.05 # Integral gain + pid_kd: float = 0.1 # Derivative gain + + # Auto-adjustment + enable_auto_adjust: bool = True + enable_prediction: bool = True + + # Predictor settings + predictor_window_size: int = 100 + predictor_min_samples: int = 5 + predictor_confidence_level: float = 0.95 + + # Safety settings + max_batch_change_ratio: float = 0.5 # Max 50% change per adjustment + oom_backoff_ratio: float = 0.5 # Reduce to 50% on OOM + + # Tower macro-scheduler settings (PBT output) + rebalance_interval_sec: float = 5.0 # Tower macro-scheduler rebalance loop interval in seconds + tower_allocation_weights: Optional[Dict[str, float]] = field(default=None) # Per-stage resource allocation weights from PBT tuning + backpressure_threshold: float = 0.9 # Memory utilization threshold above which backpressure is applied + backpressure_slowdown_ratio: float = 0.5 # Factor to reduce throughput when backpressure is active + + @classmethod + def conservative(cls) -> 'SchedulerConfig': + """Conservative configuration (prioritizes safety)""" + return cls( + target_memory_utilization=0.70, + safety_buffer_mb=2000.0, + max_batch_change_ratio=0.25, + rebalance_interval_sec=10.0, + backpressure_threshold=0.8, + ) + + @classmethod + def aggressive(cls) -> 'SchedulerConfig': + """Aggressive configuration (prioritizes throughput)""" + return cls( + target_memory_utilization=0.95, + safety_buffer_mb=500.0, + max_batch_change_ratio=0.75, + rebalance_interval_sec=2.0, + backpressure_threshold=0.95, + ) + + @classmethod + def gpu(cls) -> 'SchedulerConfig': + """GPU-optimized configuration""" + return cls( + use_gpu_memory=True, + target_memory_utilization=0.90, + safety_buffer_mb=1024.0, # 1GB buffer for GPU + ) + + @classmethod + def from_yaml(cls, path: str) -> 'SchedulerConfig': + """Load config from a YAML file (the output of PBT tuning). + + Args: + path: Path to the YAML configuration file. + + Returns: + SchedulerConfig instance with values from YAML, using defaults for missing fields. + + Raises: + ImportError: If PyYAML is not installed. + FileNotFoundError: If the YAML file does not exist. + """ + if not YAML_AVAILABLE: + raise ImportError( + "PyYAML is required for YAML support. " + "Install it with: pip install pyyaml" + ) + + with open(path, 'r') as f: + data = yaml.safe_load(f) or {} + + # Filter to only include valid fields for SchedulerConfig + valid_fields = {f.name for f in cls.__dataclass_fields__.values()} + filtered_data = {k: v for k, v in data.items() if k in valid_fields} + + return cls(**filtered_data) + + def to_yaml(self, path: str) -> None: + """Export config to YAML file. + + Args: + path: Path to write the YAML configuration file. + + Raises: + ImportError: If PyYAML is not installed. + """ + if not YAML_AVAILABLE: + raise ImportError( + "PyYAML is required for YAML support. " + "Install it with: pip install pyyaml" + ) + + data = asdict(self) + + with open(path, 'w') as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False) + + def get_stage_weight(self, stage_name: str) -> float: + """Return the allocation weight for a given stage. + + Args: + stage_name: Name of the stage to get weight for. + + Returns: + The allocation weight for the stage. Returns 1.0 if tower_allocation_weights + is None or the stage is not found (equal weight). + """ + if self.tower_allocation_weights is None: + return 1.0 + return self.tower_allocation_weights.get(stage_name, 1.0) diff --git a/data_juicer/core/elasticjuicer/scheduler/tower.py b/data_juicer/core/elasticjuicer/scheduler/tower.py new file mode 100644 index 0000000000..1e03cd3ddf --- /dev/null +++ b/data_juicer/core/elasticjuicer/scheduler/tower.py @@ -0,0 +1,869 @@ +""" +Tower: Global Macro-Scheduler for ElasticJuicer + +Based on Autothrottle's bi-level architecture, Tower is the global resource allocator +that sets performance targets and resource quotas for local Captains. + +Key responsibilities: +1. Monitor global queue depth and cluster resource utilization +2. Set target parallelism for each operator stage +3. Allocate resource budgets to Captains +4. Make topology decisions (co-location vs distributed) +5. Handle global SLA guarantees + +References: +- Autothrottle (NSDI 2024): Bi-level control for SLO-targeted microservices +- Report Section 5.1: Tower/Captain architecture +""" + +import logging +import threading +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from enum import Enum +import numpy as np +from collections import deque + +if TYPE_CHECKING: + from .scheduler_config import SchedulerConfig + + +class TopologyMode(Enum): + """Topology execution mode based on transfer cost and resource availability""" + CO_LOCATION = "co_location" # Operators on same node (high transfer cost) + DISTRIBUTED = "distributed" # Operators on different nodes (high parallelism) + ADAPTIVE = "adaptive" # Let Tower decide based on current state + + +@dataclass +class StageMetrics: + """Performance metrics for an operator stage""" + stage_name: str + queue_depth: int = 0 # Number of pending samples + current_parallelism: int = 1 # Current number of actors + throughput: float = 0.0 # Samples/sec + avg_latency_ms: float = 0.0 # Average processing latency + cpu_utilization: float = 0.0 # % CPU used + memory_utilization: float = 0.0 # % Memory used + gpu_utilization: float = 0.0 # % GPU used (if applicable) + oom_count: int = 0 # Number of OOM events + last_update: float = field(default_factory=time.time) + + +@dataclass +class ResourceQuota: + """Resource allocation quota for a Captain""" + captain_id: str + target_parallelism: int # Target number of actors + cpu_quota: float # CPU cores allocated + memory_quota_mb: float # Memory budget in MB + gpu_quota: float = 0.0 # GPU cores allocated (0-1) + target_throughput: float = 0.0 # Target samples/sec (SLO) + topology_mode: TopologyMode = TopologyMode.ADAPTIVE + backpressure: bool = False # Whether upstream backpressure is active + + +@dataclass +class ClusterState: + """Global cluster resource state""" + total_cpu_cores: int + total_memory_mb: float + total_gpu_count: int + available_cpu_cores: float + available_memory_mb: float + available_gpus: float + timestamp: float = field(default_factory=time.time) + + +class Tower: + """ + Global Macro-Scheduler (Tower from Autothrottle architecture) + + Tower doesn't directly control individual actors' behavior. Instead, it: + 1. Monitors global system state (queue depths, resource utilization) + 2. Sets performance targets and resource quotas for Captains + 3. Makes high-level topology decisions + 4. Ensures cluster-wide SLA guarantees + + The bi-level design (Tower + Captain) solves the single-point bottleneck + problem of centralized schedulers, enabling high-frequency local decisions + under global constraints. + """ + + def __init__( + self, + cluster_state: ClusterState, + target_queue_depth: int = 100, + sla_latency_ms: float = 5000.0, + update_interval_sec: float = 5.0, + history_window: int = 20, + config: Optional['SchedulerConfig'] = None + ): + """ + Initialize Tower global scheduler + + Args: + cluster_state: Initial cluster resource state + target_queue_depth: Target queue depth to maintain + sla_latency_ms: SLA latency target (max allowed latency) + update_interval_sec: How often to recompute resource allocation + history_window: Window size for tracking metrics history + config: Optional SchedulerConfig for rebalance settings + """ + self.cluster = cluster_state + self.target_queue_depth = target_queue_depth + self.sla_latency_ms = sla_latency_ms + self.update_interval = update_interval_sec + self.config = config + + # Track all stages and their metrics + self.stages: Dict[str, StageMetrics] = {} + + # Track resource quotas allocated to each Captain + self.quotas: Dict[str, ResourceQuota] = {} + + # Metrics history for trend analysis + self.metrics_history: Dict[str, deque] = {} + self.history_window = history_window + + # Last allocation time + self.last_allocation_time = time.time() + + # SLA violation tracking + self.sla_violations = 0 + self.total_requests = 0 + + # Captain registry for direct metric collection and quota broadcast + self._captains: Dict[str, Any] = {} + + # Rebalance loop control + self._running: bool = False + self._rebalance_thread: Optional[threading.Thread] = None + + # Rebalance interval from config or fallback to update_interval + self.rebalance_interval: float = ( + config.rebalance_interval_sec if config else update_interval_sec + ) + + # Backpressure threshold from config + self._backpressure_threshold: float = ( + config.backpressure_threshold if config else 0.9 + ) + + # Per-stage backpressure state tracking + self._backpressure_states: Dict[str, bool] = {} + + # Track registration order for upstream detection + self._stage_order: List[str] = [] + + def register_stage(self, stage_name: str, initial_parallelism: int = 1) -> str: + """ + Register a new operator stage with Tower + + Args: + stage_name: Name of the operator stage + initial_parallelism: Initial number of actors + + Returns: + captain_id: Unique ID for the Captain managing this stage + """ + captain_id = f"captain_{stage_name}_{int(time.time())}" + + # Initialize stage metrics + self.stages[stage_name] = StageMetrics( + stage_name=stage_name, + current_parallelism=initial_parallelism + ) + + # Initialize metrics history + self.metrics_history[stage_name] = deque(maxlen=self.history_window) + + # Track stage registration order for upstream detection + if stage_name not in self._stage_order: + self._stage_order.append(stage_name) + + # Initialize backpressure state + self._backpressure_states[stage_name] = False + + # Allocate initial quota + initial_quota = self._compute_initial_quota(stage_name, initial_parallelism) + self.quotas[captain_id] = initial_quota + + return captain_id + + def update_stage_metrics(self, stage_name: str, metrics: StageMetrics): + """ + Update metrics for a stage (called by Captain) + + Args: + stage_name: Name of the stage + metrics: Latest metrics from Captain + """ + if stage_name not in self.stages: + raise ValueError(f"Stage {stage_name} not registered") + + # Update current metrics + self.stages[stage_name] = metrics + + # Add to history + self.metrics_history[stage_name].append({ + 'timestamp': metrics.last_update, + 'queue_depth': metrics.queue_depth, + 'throughput': metrics.throughput, + 'latency_ms': metrics.avg_latency_ms, + 'cpu_util': metrics.cpu_utilization, + 'memory_util': metrics.memory_utilization + }) + + # Track SLA violations + self.total_requests += 1 + if metrics.avg_latency_ms > self.sla_latency_ms: + self.sla_violations += 1 + + def allocate_resources(self) -> Dict[str, ResourceQuota]: + """ + Compute and allocate resource quotas to all Captains + + This is the core global decision-making function. It: + 1. Analyzes global bottlenecks (queue depths, latencies) + 2. Computes target parallelism for each stage + 3. Allocates CPU/GPU/memory budgets + 4. Returns updated quotas for Captains to enforce + + Returns: + Updated resource quotas for all Captains + """ + current_time = time.time() + + # Rate limit allocation updates (avoid thrashing) + if current_time - self.last_allocation_time < self.update_interval: + return self.quotas + + self.last_allocation_time = current_time + + # Identify bottleneck stages + bottlenecks = self._identify_bottlenecks() + + # Compute resource allocation strategy + for captain_id, quota in self.quotas.items(): + stage_name = self._get_stage_from_captain(captain_id) + if stage_name not in self.stages: + continue + + metrics = self.stages[stage_name] + + # Decide target parallelism based on queue depth and throughput + target_parallelism = self._compute_target_parallelism( + metrics, + is_bottleneck=(stage_name in bottlenecks) + ) + + # Allocate resources proportionally + resource_allocation = self._allocate_stage_resources( + stage_name, + target_parallelism + ) + + # Update quota + quota.target_parallelism = target_parallelism + quota.cpu_quota = resource_allocation['cpu'] + quota.memory_quota_mb = resource_allocation['memory_mb'] + quota.gpu_quota = resource_allocation['gpu'] + quota.target_throughput = self._compute_target_throughput(metrics) + quota.topology_mode = self._decide_topology(stage_name, metrics) + + return self.quotas + + def _identify_bottlenecks(self) -> List[str]: + """ + Identify bottleneck stages based on queue depth and latency + + A stage is a bottleneck if: + 1. Queue depth > target_queue_depth + 2. Latency approaching SLA limit + 3. Throughput declining over time + + Returns: + List of bottleneck stage names + """ + bottlenecks = [] + + for stage_name, metrics in self.stages.items(): + # Check queue depth + queue_pressure = metrics.queue_depth > self.target_queue_depth + + # Check latency + latency_pressure = metrics.avg_latency_ms > (self.sla_latency_ms * 0.8) + + # Check throughput trend + throughput_declining = False + if stage_name in self.metrics_history and len(self.metrics_history[stage_name]) >= 3: + recent = list(self.metrics_history[stage_name])[-3:] + throughputs = [m['throughput'] for m in recent] + if len(throughputs) >= 2: + throughput_declining = throughputs[-1] < throughputs[0] * 0.9 + + if queue_pressure or latency_pressure or throughput_declining: + bottlenecks.append(stage_name) + + return bottlenecks + + def _compute_target_parallelism( + self, + metrics: StageMetrics, + is_bottleneck: bool + ) -> int: + """ + Compute target parallelism for a stage + + Strategy: + - If bottleneck: Increase parallelism to drain queue + - If underutilized: Decrease parallelism to free resources + - Consider resource availability + + Args: + metrics: Current stage metrics + is_bottleneck: Whether this stage is a bottleneck + + Returns: + Target parallelism (number of actors) + """ + current = metrics.current_parallelism + + if is_bottleneck: + # Estimate needed parallelism to drain queue + if metrics.throughput > 0: + # Time to process queue at current throughput + queue_drain_time = metrics.queue_depth / metrics.throughput + + # If drain time > SLA, scale up + if queue_drain_time > (self.sla_latency_ms / 1000.0): + scale_factor = min(2.0, queue_drain_time / (self.sla_latency_ms / 1000.0)) + target = int(current * scale_factor) + else: + target = current + 1 # Conservative increase + else: + target = current + 1 # No throughput data, try increasing + else: + # Check if we can scale down (free resources) + if metrics.queue_depth < self.target_queue_depth * 0.5 and current > 1: + target = max(1, current - 1) + else: + target = current # Keep current level + + # Clamp to available resources + max_possible = self._estimate_max_parallelism() + target = min(target, max_possible) + + return max(1, target) # At least 1 actor + + def _allocate_stage_resources( + self, + stage_name: str, + target_parallelism: int + ) -> Dict[str, float]: + """ + Allocate CPU/GPU/memory to a stage based on target parallelism + + Args: + stage_name: Name of the stage + target_parallelism: Target number of actors + + Returns: + Resource allocation dict with 'cpu', 'memory_mb', 'gpu' + """ + # Simple proportional allocation (can be enhanced with OCS annotations) + total_stages = len(self.stages) + + if total_stages == 0: + cpu_share = self.cluster.available_cpu_cores + memory_share = self.cluster.available_memory_mb + gpu_share = self.cluster.available_gpus + else: + # Equal share for now (TODO: weight by OCS cost) + cpu_share = self.cluster.available_cpu_cores / total_stages + memory_share = self.cluster.available_memory_mb / total_stages + gpu_share = self.cluster.available_gpus / total_stages + + return { + 'cpu': cpu_share * target_parallelism, + 'memory_mb': memory_share * target_parallelism, + 'gpu': gpu_share * target_parallelism + } + + def _compute_target_throughput(self, metrics: StageMetrics) -> float: + """ + Compute target throughput to meet SLA + + Args: + metrics: Current stage metrics + + Returns: + Target throughput in samples/sec + """ + # To meet SLA, we need throughput >= queue_depth / (SLA_time - current_latency) + sla_time_sec = self.sla_latency_ms / 1000.0 + current_latency_sec = metrics.avg_latency_ms / 1000.0 + + remaining_time = max(0.1, sla_time_sec - current_latency_sec) + + if metrics.queue_depth > 0: + target = metrics.queue_depth / remaining_time + else: + target = metrics.throughput # Maintain current + + return max(1.0, target) + + def _decide_topology( + self, + stage_name: str, + metrics: StageMetrics + ) -> TopologyMode: + """ + Decide topology mode for operator placement + + Based on Report Section 5.4: + - CO_LOCATION: High transfer cost, sufficient local resources + - DISTRIBUTED: Different resource bottlenecks, ample bandwidth + + Args: + stage_name: Name of the stage + metrics: Current metrics + + Returns: + Topology mode decision + """ + # Check resource pressure + high_cpu = metrics.cpu_utilization > 80 + high_memory = metrics.memory_utilization > 80 + high_gpu = metrics.gpu_utilization > 80 + + # If single resource bottleneck, distribute to specialize + bottleneck_count = sum([high_cpu, high_memory, high_gpu]) + + if bottleneck_count >= 2: + # Multiple bottlenecks on same node -> distribute + return TopologyMode.DISTRIBUTED + elif bottleneck_count == 0: + # No pressure -> co-locate for efficiency + return TopologyMode.CO_LOCATION + else: + # Single bottleneck -> adaptive + return TopologyMode.ADAPTIVE + + def _estimate_max_parallelism(self) -> int: + """ + Estimate maximum parallelism given available resources + + Returns: + Maximum number of actors cluster can support + """ + # Conservative estimate: assume each actor needs 1 CPU + 1GB memory + cpu_limit = int(self.cluster.available_cpu_cores) + memory_limit = int(self.cluster.available_memory_mb / 1024) # 1GB per actor + + return max(1, min(cpu_limit, memory_limit)) + + def _get_stage_from_captain(self, captain_id: str) -> str: + """Extract stage name from captain ID""" + # captain_video_decoder_1234567890 -> video_decoder + parts = captain_id.split('_') + if len(parts) >= 3: + return '_'.join(parts[1:-1]) + return captain_id + + def _compute_initial_quota( + self, + stage_name: str, + parallelism: int + ) -> ResourceQuota: + """Compute initial resource quota for a new stage""" + captain_id = f"captain_{stage_name}_{int(time.time())}" + + # Equal share allocation initially + total_stages = max(1, len(self.stages)) + + return ResourceQuota( + captain_id=captain_id, + target_parallelism=parallelism, + cpu_quota=self.cluster.available_cpu_cores / total_stages, + memory_quota_mb=self.cluster.available_memory_mb / total_stages, + gpu_quota=self.cluster.available_gpus / total_stages, + target_throughput=10.0, # Default + topology_mode=TopologyMode.ADAPTIVE + ) + + def get_sla_compliance_rate(self) -> float: + """ + Calculate SLA compliance rate + + Returns: + Percentage of requests meeting SLA (0-100) + """ + if self.total_requests == 0: + return 100.0 + + return ((self.total_requests - self.sla_violations) / self.total_requests) * 100.0 + + def get_global_stats(self) -> Dict: + """Get global system statistics""" + return { + 'total_stages': len(self.stages), + 'total_parallelism': sum(q.target_parallelism for q in self.quotas.values()), + 'sla_compliance_rate': self.get_sla_compliance_rate(), + 'total_requests': self.total_requests, + 'sla_violations': self.sla_violations, + 'cluster_cpu_util': ( + (self.cluster.total_cpu_cores - self.cluster.available_cpu_cores) / + self.cluster.total_cpu_cores * 100 + ) if self.cluster.total_cpu_cores > 0 else 0, + 'cluster_memory_util': ( + (self.cluster.total_memory_mb - self.cluster.available_memory_mb) / + self.cluster.total_memory_mb * 100 + ) if self.cluster.total_memory_mb > 0 else 0 + } + + # ========== Captain Registry Methods ========== + + def register_captain(self, captain_id: str, captain: Any) -> None: + """Register a Captain instance for direct metric collection and quota broadcast. + + Args: + captain_id: Unique identifier for the captain + captain: Captain instance to register + """ + self._captains[captain_id] = captain + + def unregister_captain(self, captain_id: str) -> None: + """Remove a Captain from the registry. + + Args: + captain_id: Unique identifier of the captain to remove + """ + self._captains.pop(captain_id, None) + + # ========== Rebalance Loop Methods ========== + + def collect_all_metrics(self) -> Dict[str, StageMetrics]: + """Step 1: Collect metrics from all registered Captains. + + For each registered captain, call captain.collect_metrics() if available, + or use captain.metrics directly. Also update internal stage_metrics dict. + + Returns: + Dict mapping stage_name to StageMetrics + """ + collected_metrics: Dict[str, StageMetrics] = {} + + for captain_id, captain in self._captains.items(): + try: + # Try collect_metrics() method first, fall back to metrics attribute + if hasattr(captain, 'collect_metrics'): + metrics = captain.collect_metrics() + elif hasattr(captain, 'metrics'): + metrics = captain.metrics + else: + continue + + if metrics and hasattr(metrics, 'stage_name'): + stage_name = metrics.stage_name + collected_metrics[stage_name] = metrics + + # Update internal stage metrics + if stage_name in self.stages: + self.stages[stage_name] = metrics + except Exception as e: + logger = logging.getLogger(__name__) + logger.warning(f"Failed to collect metrics from captain {captain_id}: {e}") + + # Also include stages that have metrics but no registered captain + for stage_name, metrics in self.stages.items(): + if stage_name not in collected_metrics: + collected_metrics[stage_name] = metrics + + return collected_metrics + + def identify_bottleneck(self, metrics: Dict[str, StageMetrics]) -> Optional[str]: + """Step 2: Identify the bottleneck stage (highest queue depth / lowest throughput). + + Find the single worst bottleneck based on queue depth and throughput. + Uses existing _identify_bottlenecks() logic but returns only the worst one. + + Args: + metrics: Dict of stage metrics + + Returns: + stage_name of worst bottleneck, or None if no bottleneck + """ + if not metrics: + return None + + # Get all bottleneck candidates + bottlenecks = self._identify_bottlenecks() + + if not bottlenecks: + return None + + # Find the worst bottleneck based on score + # Score = (queue_depth / target) + (1 - throughput_ratio) + worst_bottleneck = None + worst_score = -1.0 + + for stage_name in bottlenecks: + if stage_name not in metrics: + continue + + stage_metrics = metrics[stage_name] + + # Calculate bottleneck severity score + queue_ratio = ( + stage_metrics.queue_depth / self.target_queue_depth + if self.target_queue_depth > 0 else 0 + ) + + # Consider throughput relative to what's needed + # Higher queue with lower throughput = worse bottleneck + throughput_factor = 1.0 + if stage_metrics.throughput > 0 and stage_metrics.queue_depth > 0: + # Time to drain queue at current throughput + drain_time = stage_metrics.queue_depth / stage_metrics.throughput + sla_time = self.sla_latency_ms / 1000.0 + throughput_factor = drain_time / sla_time if sla_time > 0 else 1.0 + + score = queue_ratio + throughput_factor + + if score > worst_score: + worst_score = score + worst_bottleneck = stage_name + + return worst_bottleneck + + def reallocate_resources( + self, + bottleneck: Optional[str], + metrics: Dict[str, StageMetrics] + ) -> Dict[str, ResourceQuota]: + """Step 3: Reallocate resources, increasing quota for bottleneck. + + If bottleneck exists, shift resources toward it using tower_allocation_weights + from config (via get_stage_weight). + + Args: + bottleneck: Name of bottleneck stage, or None + metrics: Current stage metrics + + Returns: + Dict mapping captain_id to new ResourceQuota + """ + # Compute weights for each stage + stage_weights: Dict[str, float] = {} + for stage_name in self.stages: + # Base weight from config + if self.config: + base_weight = self.config.get_stage_weight(stage_name) + else: + base_weight = 1.0 + + # Boost weight for bottleneck stage + if stage_name == bottleneck: + base_weight *= 1.5 # 50% boost for bottleneck + + stage_weights[stage_name] = base_weight + + # Normalize weights + total_weight = sum(stage_weights.values()) + if total_weight > 0: + for stage_name in stage_weights: + stage_weights[stage_name] /= total_weight + + # Allocate resources based on weights + for captain_id, quota in self.quotas.items(): + stage_name = self._get_stage_from_captain(captain_id) + if stage_name not in self.stages: + continue + + stage_metrics = metrics.get(stage_name, self.stages[stage_name]) + weight = stage_weights.get(stage_name, 1.0 / max(1, len(self.stages))) + + # Compute target parallelism + target_parallelism = self._compute_target_parallelism( + stage_metrics, + is_bottleneck=(stage_name == bottleneck) + ) + + # Allocate resources proportionally to weight + resource_allocation = { + 'cpu': self.cluster.available_cpu_cores * weight * target_parallelism, + 'memory_mb': self.cluster.available_memory_mb * weight * target_parallelism, + 'gpu': self.cluster.available_gpus * weight * target_parallelism + } + + # Update quota + quota.target_parallelism = target_parallelism + quota.cpu_quota = resource_allocation['cpu'] + quota.memory_quota_mb = resource_allocation['memory_mb'] + quota.gpu_quota = resource_allocation['gpu'] + quota.target_throughput = self._compute_target_throughput(stage_metrics) + quota.topology_mode = self._decide_topology(stage_name, stage_metrics) + + return self.quotas + + def apply_backpressure( + self, + bottleneck: Optional[str], + metrics: Dict[str, StageMetrics] + ) -> None: + """Step 3b: Apply backpressure to upstream stages if needed. + + If bottleneck stage memory_utilization > backpressure_threshold: + Find upstream stages (stages registered before bottleneck) + Set backpressure=True flag on their quotas + + Args: + bottleneck: Name of bottleneck stage, or None + metrics: Current stage metrics + """ + # Reset all backpressure states first + for stage_name in self._backpressure_states: + self._backpressure_states[stage_name] = False + + if not bottleneck or bottleneck not in metrics: + # No bottleneck, clear all backpressure + for quota in self.quotas.values(): + quota.backpressure = False + return + + bottleneck_metrics = metrics[bottleneck] + + # Check if memory utilization exceeds threshold + # memory_utilization is in percentage (0-100), threshold is ratio (0-1) + memory_util_ratio = bottleneck_metrics.memory_utilization / 100.0 + + if memory_util_ratio <= self._backpressure_threshold: + # Below threshold, no backpressure needed + for quota in self.quotas.values(): + quota.backpressure = False + return + + # Find bottleneck position in stage order + try: + bottleneck_idx = self._stage_order.index(bottleneck) + except ValueError: + return + + # Apply backpressure to all upstream stages (before bottleneck) + upstream_stages = set(self._stage_order[:bottleneck_idx]) + + for captain_id, quota in self.quotas.items(): + stage_name = self._get_stage_from_captain(captain_id) + if stage_name in upstream_stages: + quota.backpressure = True + self._backpressure_states[stage_name] = True + else: + quota.backpressure = False + + def broadcast_quotas(self, quotas: Dict[str, ResourceQuota]) -> None: + """Step 4: Send updated quotas to all Captains. + + For each captain in registry, call captain.set_quota(quota). + + Args: + quotas: Dict mapping captain_id to ResourceQuota + """ + logger = logging.getLogger(__name__) + + for captain_id, captain in self._captains.items(): + # Find matching quota + quota = quotas.get(captain_id) + + if quota is None: + # Try to find quota by stage name + for qid, q in quotas.items(): + if self._get_stage_from_captain(qid) == self._get_stage_from_captain(captain_id): + quota = q + break + + if quota is None: + continue + + try: + if hasattr(captain, 'set_quota'): + captain.set_quota(quota) + except Exception as e: + logger.warning(f"Failed to broadcast quota to captain {captain_id}: {e}") + + # ========== Rebalance Loop Lifecycle ========== + + def start(self) -> None: + """Start the rebalance loop in a background thread.""" + if self._running: + return + + self._running = True + self._rebalance_thread = threading.Thread( + target=self._rebalance_loop, + daemon=True, + name="Tower-Rebalance-Loop" + ) + self._rebalance_thread.start() + + def stop(self) -> None: + """Stop the rebalance loop.""" + self._running = False + if self._rebalance_thread and self._rebalance_thread.is_alive(): + self._rebalance_thread.join(timeout=self.rebalance_interval * 2) + self._rebalance_thread = None + + def _rebalance_loop(self) -> None: + """Main rebalance loop - runs periodically. + + Implements the adaptive tower macro-scheduler: + for each rebalance_interval: + 1. Collect metrics from all Captains + 2. Identify bottleneck stage (highest queue / lowest throughput) + 3. Reallocate resources (increase quota for bottleneck, apply backpressure) + 4. Broadcast new quotas to Captains + """ + logger = logging.getLogger(__name__) + logger.info(f"Tower rebalance loop started (interval={self.rebalance_interval}s)") + + while self._running: + try: + # Step 1: Collect metrics from all Captains + metrics = self.collect_all_metrics() + + if metrics: + # Step 2: Identify bottleneck stage + bottleneck = self.identify_bottleneck(metrics) + + # Step 3: Reallocate resources + new_quotas = self.reallocate_resources(bottleneck, metrics) + + # Step 3b: Apply backpressure if needed + self.apply_backpressure(bottleneck, metrics) + + # Step 4: Broadcast new quotas to Captains + self.broadcast_quotas(new_quotas) + + if bottleneck: + logger.debug(f"Rebalance: bottleneck={bottleneck}, applying adjustments") + + except Exception as e: + logger.error(f"Rebalance loop error: {e}") + + # Wait for next interval + time.sleep(self.rebalance_interval) + + logger.info("Tower rebalance loop stopped") + + # ========== Context Manager Support ========== + + def __enter__(self) -> 'Tower': + """Context manager entry - start the rebalance loop.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Context manager exit - stop the rebalance loop.""" + self.stop() diff --git a/data_juicer/core/elasticjuicer/tuner/__init__.py b/data_juicer/core/elasticjuicer/tuner/__init__.py new file mode 100644 index 0000000000..7f9ee5c56d --- /dev/null +++ b/data_juicer/core/elasticjuicer/tuner/__init__.py @@ -0,0 +1,10 @@ +""" +Tuner submodule for ElasticJuicer hyperparameter optimization. + +Provides OFFLINE Ray Tune Population Based Training (PBT) for tuning +scheduling parameters. +""" + +from .pbt_tuner import PBTTuner + +__all__ = ["PBTTuner"] diff --git a/data_juicer/core/elasticjuicer/tuner/pbt_tuner.py b/data_juicer/core/elasticjuicer/tuner/pbt_tuner.py new file mode 100644 index 0000000000..8e304934c5 --- /dev/null +++ b/data_juicer/core/elasticjuicer/tuner/pbt_tuner.py @@ -0,0 +1,407 @@ +""" +OFFLINE Phase: Ray Tune Population Based Training (PBT) +for hyperparameter optimization of ElasticJuicer scheduling parameters. + +Tunes: + - PID controller params (kp, ki, kd) + - Safety buffers (safety_buffer_mb, target_memory_utilization) + - Predictor params (predictor_window_size, predictor_confidence_level) + - Tower allocation weights (per-stage resource proportions) +Output: base_config.yaml (a SchedulerConfig serialized to YAML) + +Usage: + from data_juicer.core.elasticjuicer.tuner import PBTTuner + + config = PBTTunerConfig( + stage_names=["filter", "mapper", "deduplicator"], + num_samples=8, + max_iterations=50, + ) + tuner = PBTTuner(config) + best_config = tuner.tune() + tuner.export_config(best_config, "base_config.yaml") +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Callable, Any +import random +import numpy as np + +# Graceful handling of optional Ray dependency +try: + import ray + from ray import tune + from ray.tune.schedulers import PopulationBasedTraining + RAY_AVAILABLE = True +except ImportError: + RAY_AVAILABLE = False + ray = None + tune = None + PopulationBasedTraining = None + +from ..scheduler.scheduler_config import SchedulerConfig +from ..scheduler.micro_scheduler import MicroScheduler, BatchSizeController + + +@dataclass +class PBTTunerConfig: + """Configuration for PBT-based hyperparameter tuning. + + Attributes: + num_samples: Number of PBT population members (parallel trials). + max_iterations: Maximum training iterations per trial. + perturbation_interval: How often PBT perturbs hyperparameters. + metric: Metric to optimize (e.g., "throughput", "score"). + mode: Optimization mode - "max" to maximize, "min" to minimize. + stage_names: Operator stage names to tune allocation weights for. + resources_per_trial: Resources allocated per trial (cpu, gpu). + grace_period: Minimum iterations before stopping poor trials. + """ + num_samples: int = 8 + max_iterations: int = 50 + perturbation_interval: int = 5 + metric: str = "throughput" + mode: str = "max" + stage_names: List[str] = field(default_factory=list) + resources_per_trial: Dict[str, float] = field( + default_factory=lambda: {"cpu": 2, "gpu": 0} + ) + grace_period: int = 5 + + +class PBTTuner: + """ + Population Based Training (PBT) tuner for ElasticJuicer scheduling parameters. + + This class implements OFFLINE hyperparameter optimization using Ray Tune's PBT + scheduler. It tunes PID controller parameters, memory safety buffers, predictor + settings, and per-stage resource allocation weights. + + The tuning process simulates batch processing with the given configuration and + measures throughput and OOM rates to find optimal parameters. + + Attributes: + config: PBTTunerConfig instance with tuning settings. + simulation_fn: Callable that simulates execution and returns metrics. + + Example: + >>> tuner_config = PBTTunerConfig( + ... stage_names=["filter", "mapper"], + ... num_samples=4, + ... max_iterations=20, + ... ) + >>> tuner = PBTTuner(tuner_config) + >>> best_config = tuner.tune() + >>> tuner.export_config(best_config, "base_config.yaml") + """ + + def __init__( + self, + config: PBTTunerConfig, + simulation_fn: Optional[Callable[[SchedulerConfig], Dict[str, float]]] = None, + ): + """ + Initialize the PBT tuner. + + Args: + config: PBTTunerConfig with tuning parameters. + simulation_fn: Optional callable that takes a SchedulerConfig and returns + a dict with "throughput" and "oom_rate" keys. If None, uses default + simulation that creates a MicroScheduler and simulates batch processing. + + Raises: + ImportError: If Ray is not installed and tune() is called. + """ + self.config = config + self.simulation_fn = simulation_fn or self._default_simulation + + def _get_search_space(self) -> Dict[str, Any]: + """ + Get the Ray Tune search space for hyperparameters. + + Returns: + Dictionary mapping parameter names to Ray Tune search distributions. + + Raises: + ImportError: If Ray Tune is not available. + """ + if not RAY_AVAILABLE: + raise ImportError( + "Ray Tune is required for PBT tuning. " + "Install it with: pip install 'ray[tune]'" + ) + + search_space = { + # PID controller parameters + "pid_kp": tune.uniform(0.1, 2.0), + "pid_ki": tune.uniform(0.01, 0.2), + "pid_kd": tune.uniform(0.01, 0.5), + + # Safety and memory parameters + "safety_buffer_mb": tune.uniform(256, 4096), + "target_memory_utilization": tune.uniform(0.6, 0.95), + + # Predictor parameters + "predictor_window_size": tune.choice([50, 100, 200, 500]), + "predictor_confidence_level": tune.uniform(0.9, 0.99), + } + + # Add per-stage allocation weights + for stage_name in self.config.stage_names: + search_space[f"weight_{stage_name}"] = tune.uniform(0.1, 5.0) + + return search_space + + def _default_simulation(self, scheduler_config: SchedulerConfig) -> Dict[str, float]: + """ + Default simulation function that tests a SchedulerConfig. + + Creates a MicroScheduler with the given PID parameters and simulates + N iterations of batch processing with random memory fluctuations. + Measures simulated throughput and OOM events. + + Args: + scheduler_config: Configuration to evaluate. + + Returns: + Dictionary with "throughput" (samples/sec) and "oom_rate" (0.0-1.0). + """ + # Create MicroScheduler with config parameters + micro_scheduler = MicroScheduler( + memory_predictor=None, + initial_batch_size=scheduler_config.initial_batch_size, + min_batch_size=scheduler_config.min_batch_size, + max_batch_size=scheduler_config.max_batch_size, + target_memory_utilization=scheduler_config.target_memory_utilization, + safety_buffer_mb=scheduler_config.safety_buffer_mb, + use_gpu=scheduler_config.use_gpu_memory, + enable_auto_adjust=scheduler_config.enable_auto_adjust, + ) + + # Override PID parameters in the controller + micro_scheduler.controller.pid.kp = scheduler_config.pid_kp + micro_scheduler.controller.pid.ki = scheduler_config.pid_ki + micro_scheduler.controller.pid.kd = scheduler_config.pid_kd + + # Simulation parameters + num_iterations = 100 + total_samples_processed = 0 + oom_events = 0 + + # Simulated available memory (starts high, fluctuates) + base_memory_mb = 8000.0 # 8GB base + + for i in range(num_iterations): + # Get current batch size from scheduler + batch_size = micro_scheduler.controller.current_batch_size + + # Simulate memory usage per sample (varies randomly) + memory_per_sample = np.random.uniform(5.0, 20.0) # 5-20 MB per sample + + # Add random memory fluctuation (simulates other processes) + memory_fluctuation = np.random.uniform(-500, 500) + + # Calculate simulated memory state + simulated_used_memory = batch_size * memory_per_sample + memory_fluctuation + simulated_available = base_memory_mb - simulated_used_memory + + # Check for simulated OOM + if simulated_available < scheduler_config.safety_buffer_mb * 0.5: + oom_events += 1 + # Report OOM to scheduler + micro_scheduler.controller.report_oom(batch_size, simulated_used_memory) + # Penalize throughput for OOM + total_samples_processed += batch_size // 4 + else: + # Successful batch + total_samples_processed += batch_size + + # Update scheduler (simulates feedback loop) + micro_scheduler.controller.update_batch_size( + predicted_memory_per_sample=memory_per_sample + ) + + # Calculate metrics + throughput = total_samples_processed / num_iterations # samples per iteration + oom_rate = oom_events / num_iterations + + return { + "throughput": throughput, + "oom_rate": oom_rate, + } + + def _trial_config_to_scheduler_config(self, trial_config: Dict) -> SchedulerConfig: + """ + Convert Ray Tune trial config dict to SchedulerConfig. + + Args: + trial_config: Dictionary of hyperparameters from Ray Tune. + + Returns: + SchedulerConfig instance with the trial's hyperparameters. + """ + # Extract tower allocation weights from weight_{stage} keys + tower_weights = {} + for key, value in trial_config.items(): + if key.startswith("weight_"): + stage_name = key[7:] # Remove "weight_" prefix + tower_weights[stage_name] = value + + return SchedulerConfig( + # PID parameters + pid_kp=trial_config.get("pid_kp", 0.5), + pid_ki=trial_config.get("pid_ki", 0.05), + pid_kd=trial_config.get("pid_kd", 0.1), + + # Safety parameters + safety_buffer_mb=trial_config.get("safety_buffer_mb", 1000.0), + target_memory_utilization=trial_config.get("target_memory_utilization", 0.85), + + # Predictor parameters + predictor_window_size=int(trial_config.get("predictor_window_size", 100)), + predictor_confidence_level=trial_config.get("predictor_confidence_level", 0.95), + + # Tower allocation weights + tower_allocation_weights=tower_weights if tower_weights else None, + ) + + def _trainable(self, trial_config: Dict) -> None: + """ + Ray Tune trainable function. + + Converts trial config to SchedulerConfig, runs simulation, + and reports metrics to Ray Tune. + + Args: + trial_config: Dictionary of hyperparameters from Ray Tune. + """ + if not RAY_AVAILABLE: + raise ImportError( + "Ray Tune is required for PBT tuning. " + "Install it with: pip install 'ray[tune]'" + ) + + # Convert trial config to SchedulerConfig + scheduler_config = self._trial_config_to_scheduler_config(trial_config) + + # Run simulation + results = self.simulation_fn(scheduler_config) + + throughput = results.get("throughput", 0.0) + oom_rate = results.get("oom_rate", 1.0) + + # Calculate composite score (higher is better) + # Penalize OOM events heavily + score = throughput * (1.0 - oom_rate) + + # Report metrics to Ray Tune + tune.report( + throughput=throughput, + oom_rate=oom_rate, + score=score, + ) + + def tune(self) -> SchedulerConfig: + """ + Run PBT hyperparameter tuning. + + Sets up Ray Tune with PBT scheduler, runs the tuning process, + and returns the best configuration found. + + Returns: + SchedulerConfig with the best hyperparameters found. + + Raises: + ImportError: If Ray Tune is not installed. + RuntimeError: If tuning fails or no results are found. + """ + if not RAY_AVAILABLE: + raise ImportError( + "Ray Tune is required for PBT tuning. " + "Install it with: pip install 'ray[tune]'" + ) + + # Initialize Ray if not already initialized + if not ray.is_initialized(): + ray.init(ignore_reinit_error=True) + + # Get search space + search_space = self._get_search_space() + + # Define perturbation bounds for PBT + hyperparam_mutations = { + "pid_kp": tune.uniform(0.1, 2.0), + "pid_ki": tune.uniform(0.01, 0.2), + "pid_kd": tune.uniform(0.01, 0.5), + "safety_buffer_mb": tune.uniform(256, 4096), + "target_memory_utilization": tune.uniform(0.6, 0.95), + "predictor_window_size": [50, 100, 200, 500], + "predictor_confidence_level": tune.uniform(0.9, 0.99), + } + + # Add stage weight mutations + for stage_name in self.config.stage_names: + hyperparam_mutations[f"weight_{stage_name}"] = tune.uniform(0.1, 5.0) + + # Create PBT scheduler + pbt_scheduler = PopulationBasedTraining( + time_attr="training_iteration", + perturbation_interval=self.config.perturbation_interval, + hyperparam_mutations=hyperparam_mutations, + quantile_fraction=0.25, # Top 25% survive + resample_probability=0.25, # 25% chance to resample instead of perturb + ) + + # Run tuning + analysis = tune.run( + self._trainable, + config=search_space, + metric=self.config.metric, + mode=self.config.mode, + num_samples=self.config.num_samples, + scheduler=pbt_scheduler, + resources_per_trial=self.config.resources_per_trial, + stop={"training_iteration": self.config.max_iterations}, + verbose=1, + raise_on_failed_trial=False, + ) + + # Get best trial + best_trial = analysis.get_best_trial( + metric=self.config.metric, + mode=self.config.mode, + ) + + if best_trial is None: + raise RuntimeError( + "PBT tuning failed: no successful trials found. " + "Check simulation function and resource availability." + ) + + # Convert best config to SchedulerConfig + best_config = self._trial_config_to_scheduler_config(best_trial.config) + + return best_config + + def export_config(self, config: SchedulerConfig, path: str = "base_config.yaml") -> None: + """ + Export a SchedulerConfig to a YAML file. + + Args: + config: SchedulerConfig to export. + path: Output file path (default: "base_config.yaml"). + """ + config.to_yaml(path) + + @staticmethod + def load_config(path: str) -> SchedulerConfig: + """ + Load a SchedulerConfig from a YAML file. + + Args: + path: Path to the YAML configuration file. + + Returns: + SchedulerConfig instance loaded from the file. + """ + return SchedulerConfig.from_yaml(path) diff --git a/data_juicer/core/executor/event_logging_mixin.py b/data_juicer/core/executor/event_logging_mixin.py index c994b455ad..14c0c34e54 100644 --- a/data_juicer/core/executor/event_logging_mixin.py +++ b/data_juicer/core/executor/event_logging_mixin.py @@ -646,6 +646,16 @@ def log_op_complete( "operation_class": operation_name, } + # Ensure input_rows and output_rows are integers (they might be strings from some sources) + try: + input_rows = int(input_rows) if input_rows is not None else None + except (ValueError, TypeError): + input_rows = None + try: + output_rows = int(output_rows) if output_rows is not None else None + except (ValueError, TypeError): + output_rows = None + # Only include row counts and derived metrics if they're meaningful (non-zero or explicitly set) if input_rows is not None and input_rows > 0: metadata["input_rows"] = input_rows diff --git a/data_juicer/core/executor/ray_executor_partitioned.py b/data_juicer/core/executor/ray_executor_partitioned.py index ddcb1fc442..342b012a77 100644 --- a/data_juicer/core/executor/ray_executor_partitioned.py +++ b/data_juicer/core/executor/ray_executor_partitioned.py @@ -13,6 +13,7 @@ import os import shutil import time +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional @@ -179,9 +180,13 @@ def __init__(self, cfg: Optional[Namespace] = None): checkpoint_cfg = getattr(self.cfg, "checkpoint", None) checkpoint_dir = getattr(self.cfg, "checkpoint_dir", os.path.join(self.work_dir, "checkpoints")) + # Debug: log checkpoint_cfg type and value + logger.info(f"DEBUG: checkpoint_cfg type = {type(checkpoint_cfg)}, value = {checkpoint_cfg}") + if checkpoint_cfg: # Use ConfigAccessor to handle both dict and object configurations checkpoint_enabled = ConfigAccessor.get(checkpoint_cfg, "enabled", True) + logger.info(f"DEBUG: checkpoint_enabled from ConfigAccessor = {checkpoint_enabled}") strategy_str = ConfigAccessor.get(checkpoint_cfg, "strategy", "every_op") checkpoint_n_ops = ConfigAccessor.get(checkpoint_cfg, "n_ops", 1) checkpoint_op_names = ConfigAccessor.get(checkpoint_cfg, "op_names", []) @@ -449,11 +454,21 @@ def _run_impl(self, load_data_np: Optional[PositiveInt] = None, skip_return=Fals # Detect convergence points for global operations convergence_points = self._detect_convergence_points(self.cfg) + # Debug logging for checkpoint status + logger.info(f"DEBUG: checkpoint_enabled = {self.ckpt_manager.checkpoint_enabled}") + logger.info(f"DEBUG: convergence_points = {convergence_points}") + + # Choose processing strategy based on checkpointing and convergence points + # Fast path: when checkpointing is disabled and no convergence points, + # process without manual partitioning to let Ray Data handle parallelism if convergence_points: logger.info(f"Found convergence points at operations: {convergence_points}") final_dataset = self._process_with_convergence(dataset, ops, convergence_points) + elif not self.ckpt_manager.checkpoint_enabled: + logger.info("Checkpointing disabled, using fast path without manual partitioning") + final_dataset = self._process_without_partitioning(dataset, ops) else: - logger.info("No convergence points found, processing with simple partitioning") + logger.info("Checkpointing enabled, processing with partitioning for checkpoint support") final_dataset = self._process_with_simple_partitioning(dataset, ops) # Export final dataset @@ -482,6 +497,42 @@ def cleanup_temp_files(self): else: logger.info("No temporary files found to clean up") + def _process_without_partitioning(self, dataset: RayDataset, ops: List) -> RayDataset: + """ + Process dataset without manual partitioning. + + This is the fast path when checkpointing is disabled. + Ray Data handles parallelism automatically through map_batches with concurrency. + """ + logger.info("Processing without manual partitioning (fast path)...") + + start_time = time.time() + + # Pre-execute DAG monitoring (log operation start events) + if self.pipeline_dag: + self._pre_execute_operations_with_dag_monitoring(ops, partition_id=0) + + # Execute operations (lazy evaluation - Ray Data handles parallelism) + processed_dataset = dataset.process(ops) + + # Force materialization only at the end (required for export anyway) + logger.info("Materializing final dataset...") + processed_dataset.data = processed_dataset.data.materialize() + + duration = time.time() - start_time + logger.info(f"Processing completed in {duration:.2f}s") + + # Post-execute DAG monitoring + if self.pipeline_dag: + try: + output_rows = processed_dataset.data.count() + metrics = {"duration": duration, "input_rows": "unknown", "output_rows": output_rows} + except Exception: + metrics = {"duration": duration} + self._post_execute_operations_with_dag_monitoring(ops, partition_id=0, metrics=metrics) + + return processed_dataset + def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List): """ Process dataset with real partitioning using Ray Data's split and union. @@ -500,9 +551,10 @@ def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List): # Process each partition separately with checkpointing logger.info("Processing partitions with checkpointing support...") - processed_partitions = [] + processed_partitions = [None] * len(partitions) - for i, partition in enumerate(partitions): + def process_single_partition(i, partition): + """Helper function to process a single partition.""" logger.info(f"Processing partition {i+1}/{len(partitions)}") # Log partition start event @@ -518,9 +570,6 @@ def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List): # Apply operations with checkpointing support and DAG monitoring processed_partition = self._process_with_checkpointing(partition_dataset, i, ops) - # Store the processed partition's data - processed_partitions.append(processed_partition.data) - # Log partition completion event self._log_event( event_type=EventType.PARTITION_COMPLETE, @@ -528,6 +577,18 @@ def _process_with_simple_partitioning(self, dataset: RayDataset, ops: List): partition_id=i, ) + return i, processed_partition.data + + # Process partitions in parallel using ThreadPoolExecutor + with ThreadPoolExecutor(max_workers=len(partitions)) as executor: + futures = { + executor.submit(process_single_partition, i, partition): i + for i, partition in enumerate(partitions) + } + for future in as_completed(futures): + i, result = future.result() + processed_partitions[i] = result + # Merge all processed partitions back into a single dataset logger.info("Merging processed partitions...") if len(processed_partitions) == 1: @@ -1001,4 +1062,4 @@ def _clear_invalid_checkpoints(self) -> None: if os.path.exists(self.ckpt_manager.ckpt_dir): logger.warning(f"Clearing invalid checkpoints in {self.ckpt_manager.ckpt_dir}") shutil.rmtree(self.ckpt_manager.ckpt_dir) - os.makedirs(self.ckpt_manager.ckpt_dir, exist_ok=True) + os.makedirs(self.ckpt_manager.ckpt_dir, exist_ok=True) \ No newline at end of file diff --git a/data_juicer/core/ray_exporter.py b/data_juicer/core/ray_exporter.py index ea4b700ae9..a6b98e62f0 100644 --- a/data_juicer/core/ray_exporter.py +++ b/data_juicer/core/ray_exporter.py @@ -131,12 +131,33 @@ def _export_impl(self, dataset, export_path, columns=None): :param columns: the columns to export. :return: """ + # Debug: Log dataset info before export + logger.info(f"Starting export to: {export_path}") + + # Materialize the dataset first to get accurate count + # Ray dataset needs to be materialized before calling count() or num_blocks() + try: + dataset = dataset.materialize() + logger.info("Dataset materialized successfully") + except Exception as e: + logger.warning(f"Dataset materialize failed (may already be materialized): {e}") + + # Get row count for validation + try: + row_count = dataset.count() + logger.info(f"Dataset row count before export: {row_count}") + if row_count == 0: + logger.warning("Dataset is empty (0 rows)! Export will produce no data files.") + except Exception as e: + logger.warning(f"Could not get dataset row count: {e}") + row_count = None + # Handle empty dataset case - Ray returns None for columns() on empty datasets # Check if dataset is empty by calling columns() regardless of columns parameter cols = dataset.columns() if cols is None: # Empty dataset with unknown schema - create an empty file - logger.warning(f"Dataset is empty, creating empty export file at {export_path}") + logger.warning(f"Dataset is empty (no columns), creating empty export file at {export_path}") os.makedirs(os.path.dirname(export_path) or ".", exist_ok=True) with open(export_path, "w"): pass # Create empty file @@ -182,7 +203,28 @@ def _export_impl(self, dataset, export_path, columns=None): if not export_path.startswith("s3://"): os.makedirs(export_path, exist_ok=True) - return export_method(dataset, export_path, **export_kwargs) + result = export_method(dataset, export_path, **export_kwargs) + + # Post-export verification: check if files were actually written + if not export_path.startswith("s3://"): + if os.path.isdir(export_path): + files = os.listdir(export_path) + if files: + logger.info(f"Export verification: {len(files)} file(s) written to {export_path}") + # Log first few files for debugging + for f in files[:5]: + file_path = os.path.join(export_path, f) + file_size = os.path.getsize(file_path) + logger.info(f" - {f} ({file_size} bytes)") + if len(files) > 5: + logger.info(f" ... and {len(files) - 5} more files") + else: + logger.warning(f"Export verification FAILED: No files written to {export_path}!") + logger.warning("This may indicate the dataset was empty or an export error occurred.") + else: + logger.warning(f"Export path {export_path} is not a directory after export!") + + return result def export(self, dataset, columns=None): """ diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 4ae7d07bc9..6b4de22c39 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -402,6 +402,10 @@ def __init__(self, *args, **kwargs): self.ray_execution_mode = kwargs.get("ray_execution_mode", None) assert self.ray_execution_mode in [None, "actor", "task"] + # Override the number of output blocks for Ray Data map_batches + # (helps prevent Ray block starvation when Ray fuses/coalesces blocks) + self.override_num_blocks = kwargs.get("override_num_blocks", None) + # Local import to avoid logger being serialized in multiprocessing from loguru import logger diff --git a/data_juicer/ops/filter/video_aesthetics_filter.py b/data_juicer/ops/filter/video_aesthetics_filter.py index 4a7901d98b..c924ce825f 100644 --- a/data_juicer/ops/filter/video_aesthetics_filter.py +++ b/data_juicer/ops/filter/video_aesthetics_filter.py @@ -1,3 +1,4 @@ +import os from typing import Optional import numpy as np @@ -123,7 +124,10 @@ def __init__( trust_remote_code=trust_remote_code, ) # the original score predicted by laion-ai's scorer is within [0, 10] - self.need_normalized_by_ten = "shunk031/aesthetics-predictor" in hf_scorer_model + self.need_normalized_by_ten = ( + "shunk031/aesthetics-predictor" in hf_scorer_model + or "aesthetics-predictor" in os.path.basename(hf_scorer_model) + ) self.frame_sampling_method = frame_sampling_method self.frame_num = frame_num