diff --git a/README_this_pr.md b/README_this_pr.md new file mode 100644 index 0000000000..64bbd541c2 --- /dev/null +++ b/README_this_pr.md @@ -0,0 +1,30 @@ +# Data-Juicer-HumanVbench-ops + +This is the operator contribution page for the paper: **HumanVBench: Probing Human-Centric Video Understanding in MLLMs with Automatically Synthesized Benchmarks (CVPR'26)**. + +## Related Operator Documentation Locations + +* **Example Recipe:** `demos/video_humanvbench_simple/analyzer.yaml` +* **Operator Definition:** `data_juicer/config/config_all.yaml` + +## Quick Start + +As HumanVBench operators involve modifications to external repositories, these adjusted repositories are currently stored in: +`thirdparty/humanvbench_models` + +To use these operators, you can choose: + +1. **Manual Mode:** Follow the instructions in `thirdparty/humanvbench_models/README.md` to manually complete the `git clone` and `.diff` patch merging, then run: + +```shell +dj-process --config demos/video_humanvbench_simple/analyzer.yaml + +``` + +2. **Automatic Mode (Recommended):** Start running directly: + +```shell +dj-process --config demos/video_humanvbench_simple/analyzer.yaml + +``` +The relevant operators already cover the logic for automatic `git clone` and `merge diff`, making manual intervention non-essential. diff --git a/README_this_pr_CH.md b/README_this_pr_CH.md new file mode 100644 index 0000000000..8c987e066d --- /dev/null +++ b/README_this_pr_CH.md @@ -0,0 +1,29 @@ +# Data-Juicer-HumanVbench-ops + +这是论文:**HumanVBench: Probing Human-Centric Video Understanding in MLLMs with Automatically Synthesized Benchmarks (CVPR'26)** 的算子贡献页。 + +## 相关算子介绍文件位置 + +* **范例 Recipe:** `demos/video_humanvbench_simple/analyzer.yaml` +* **算子定义:** `data_juicer/config/config_all.yaml` + +## 快速开始 + +由于 HumanVBench 算子涉及外部仓库的修改,这些经过调整的仓库目前存储在: +`thirdparty/humanvbench_models` + +为了使用这些算子,你可以选择: + +1. **手动模式:** 按照 `thirdparty/humanvbench_models/README.md` 下的指引手动完成 `git clone` 和 `.diff` 补丁合并,然后运行: +```shell +dj-process --config demos/video_humanvbench_simple/analyzer.yaml + +``` + + +2. **自动模式(推荐):** 直接开始运行: +```shell +dj-process --config demos/video_humanvbench_simple/analyzer.yaml + +``` +我们在相关算子已经涵盖了自动 `git clone` 和 `merge diff` 的逻辑,手动干预是非必须的。 diff --git a/data_juicer/config/config_all.yaml b/data_juicer/config/config_all.yaml index 30d6b5f71b..38faca9d92 100644 --- a/data_juicer/config/config_all.yaml +++ b/data_juicer/config/config_all.yaml @@ -199,19 +199,6 @@ process: model_params: {} # Parameters for initializing the API model. sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95} - expand_macro_mapper: # expand macro definitions in Latex text. - - latex_figure_context_extractor_mapper: # Extract figures and their citing context from LaTeX source. - citation_commands: ['\ref', '\cref', '\Cref', '\autoref'] # LaTeX reference commands to search for citing paragraphs. - paragraph_separator: '\n\n' # Pattern for splitting LaTeX text into paragraphs. - caption_key: 'caption' # Output field name for the figure caption. - label_key: 'label' # Output field name for the LaTeX label. - context_key: 'citing_paragraphs' # Output field name for citing paragraphs. - parent_caption_key: 'parent_caption' # Output field name for the parent figure's caption (subfigures only). - parent_label_key: 'parent_label' # Output field name for the parent figure's label (for grouping subfigures). - - latex_merge_tex_mapper: # Extract and concatenate all .tex files from a compressed LaTeX project archive. - compressed_file_key: 'compressed_file' # Field storing the archive path. - separator: '\n\n' # Separator between concatenated .tex files. - max_file_size: 52428800 # 50 MB; skip .tex entries larger than this (zip bomb protection). - max_total_size: 104857600 # 100 MB; cumulative limit for all extracted .tex content. - extract_entity_attribute_mapper: # Extract attributes for given entities from the text. api_model: 'gpt-4o' # API model name. query_entities: ["孙悟空", "猪八戒"] # Entity list to be queried. @@ -698,7 +685,61 @@ process: save_visualization_dir: None # The path for saving visualization results. - whitespace_normalization_mapper: # normalize different kinds of whitespaces to English whitespace. + +# When use HumanVBench mapper, keep_stats_in_res_ds should be set true + + - video_human_tracks_extraction_mapper: # Get the body and face trajectory bounding box of people in one shot of the video. To ensure correctness, it should be applied after video_split_by_scene_mapper + face_track_bbox_path: /your_path/bounding_box_track # The storage location of the bounding box tracks of the characters in the video + mem_required: '10GB' + + # video_human_tracks_face_demographic_mapper should be operated after video_human_tracks_extraction_mapper. + - video_human_tracks_face_demographic_mapper: # Get the facial demographics of each person based on the results of video_human_tracks_extraction_mapper + original_data_save_path: your_path/bounding_box_track # The location where the specific results of each frame's detection are stored + detect_interval: 5 + + # video_audio_detect_age_gender_mapper should be operated after video_tagging_from_audio_mapper. + - video_audio_detect_age_gender_mapper: # If the audio is speech, classify the gender and age of the speech + hf_audio_mapper: 'audeering/wav2vec2-large-robust-24-ft-age-gender' # Huggingface model name for speech age and gender classification + mem_required: '7GB' + + # video_captioning_from_human_tracks_mapper should be operated after video_human_tracks_extraction_mapper. + - video_captioning_from_human_tracks_mapper: # Based on the results of video_human_tracks_extraction_mapper, focus on the single person in the video for captioning + video_describe_model_path: DAMO-NLP-SG/VideoLLaMA3-7B # model path to VideoLLaMA3-7B + trust_remote_code: true + temp_video_path: ./temp_video_path # Used to store temporary videos that will be removed finally. + mem_required: '25GB' + + # video_captioning_face_attribute_emotion_mapper should be operated after video_human_tracks_extraction_mapper. + - video_captioning_face_attribute_emotion_mapper: # Based on the results of video_human_tracks_extraction_mapper, focus on judging the gender, age, and race of a single person in the video + face_track_query: Please only describe the appearance and facial emotions of the person in the video in detail. Don't mention the background. Less than 80 words. + trust_remote_code: true + cropping_face_video_temp_path: ./temp_video_path # Used to store temporary videos + video_describe_model_path: DAMO-NLP-SG/VideoLLaMA3-7B # Huggingface model DAMO-NLP-SG/VideoLLaMA3-7B + mem_required: '25GB' + + # video_active_speaker_detect_mapper must be operated after video_tagging_from_audio_mapper and video_human_tracks_extraction_mapper. + - video_active_speaker_detect_mapper: # Based on the results of video_human_tracks_extraction_mapper, determine whether each person is an active speaker + temp_save_path: ./temp_path # Used to store temporary videos + active_threshold: 15 # Higher values are stricter, reducing false positives from noise but potentially increasing missed detections + mem_required: '10GB' + + - video_audio_ASR_mapper: # Automatic speech recognition from video speech + model_dir_ASR: 'FunAudioLLM/SenseVoiceSmall' # Huggingface model FunAudioLLM/SenseVoiceSmall + mem_required: '20GB' + + - video_audio_speech_emotion_mapper: # Speech emotion recognition from video speech + model_dir_emo: 'FunAudioLLM/SenseVoiceSmall' # Huggingface model FunAudioLLM/SenseVoiceSmall + mem_required: '20GB' + + + + # Filter ops + - video_face_ratio_filter: # Filter to retain human-centric videos + threshold: 0.65 # The lower limit of the ratio of frames with faces to the total number of video frames + detect_interval: 4 + any_or_all: any + - alphanumeric_filter: # filter text with alphabet/numeric ratio out of specific range. tokenization: false # whether to count the ratio of alphanumeric to the total number of tokens. min_ratio: 0.0 # the min ratio of filter range diff --git a/data_juicer/core/data/load_strategy.py b/data_juicer/core/data/load_strategy.py index 63a293b5d3..cdd6417ed2 100644 --- a/data_juicer/core/data/load_strategy.py +++ b/data_juicer/core/data/load_strategy.py @@ -631,7 +631,7 @@ def load_data(self, **kwargs): # Use ray.data functions directly with PyArrow filesystem support # Ray's read functions support filesystem parameter via PyArrow - if data_format in {"json", "jsonl", "json.gz", "jsonl.gz", "json.zst", "jsonl.zst"}: + if data_format in {"json", "jsonl"}: # For JSON, we need to use read_json_stream with filesystem from data_juicer.core.data.ray_dataset import read_json_stream diff --git a/data_juicer/core/data/ray_dataset.py b/data_juicer/core/data/ray_dataset.py index 2d8b198565..e5bd739e6b 100644 --- a/data_juicer/core/data/ray_dataset.py +++ b/data_juicer/core/data/ray_dataset.py @@ -355,7 +355,7 @@ def count(self) -> int: @classmethod def read(cls, data_format: str, paths: Union[str, List[str]]) -> RayDataset: - if data_format in {"json", "jsonl", "json.gz", "jsonl.gz", "json.zst", "jsonl.zst"}: + if data_format in {"json", "jsonl"}: return RayDataset.read_json(paths) elif data_format == "webdataset": return RayDataset.read_webdataset(paths) @@ -453,7 +453,7 @@ def read_json_stream( include_paths: bool = False, ignore_missing_paths: bool = False, shuffle: Union[Literal["files"], None] = None, - file_extensions: Optional[List[str]] = ["json", "jsonl", "json.gz", "jsonl.gz", "json.zst", "jsonl.zst"], + file_extensions: Optional[List[str]] = ["json", "jsonl"], concurrency: Optional[int] = None, override_num_blocks: Optional[int] = None, **arrow_json_args, diff --git a/data_juicer/format/json_formatter.py b/data_juicer/format/json_formatter.py index b97a606f1b..25525c9044 100644 --- a/data_juicer/format/json_formatter.py +++ b/data_juicer/format/json_formatter.py @@ -6,10 +6,10 @@ class JsonFormatter(LocalFormatter): """ The class is used to load and format json-type files. - Default suffixes is `['.json', '.jsonl', '.json.gz', '.jsonl.gz', '.json.zst', '.jsonl.zst']` + Default suffixes is `['.json', '.jsonl', '.jsonl.zst']` """ - SUFFIXES = [".json", ".jsonl", ".json.gz", ".jsonl.gz", ".json.zst", ".jsonl.zst"] + SUFFIXES = [".json", ".jsonl", ".jsonl.zst"] def __init__(self, dataset_path, suffixes=None, **kwargs): """ diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py index 2825ed1c01..5996c888fa 100644 --- a/data_juicer/ops/filter/__init__.py +++ b/data_juicer/ops/filter/__init__.py @@ -56,6 +56,7 @@ from .video_watermark_filter import VideoWatermarkFilter from .word_repetition_filter import WordRepetitionFilter from .words_num_filter import WordsNumFilter +from .video_face_ratio_filter import VideoFaceRatioFilter __all__ = [ "AlphanumericFilter", @@ -114,6 +115,7 @@ "WordRepetitionFilter", "WordsNumFilter", "GeneralFieldFilter", + "VideoFaceRatioFilter" ] NON_STATS_FILTERS = [ diff --git a/data_juicer/ops/filter/image_face_count_filter.py b/data_juicer/ops/filter/image_face_count_filter.py index a520c121b2..31dabbaf69 100644 --- a/data_juicer/ops/filter/image_face_count_filter.py +++ b/data_juicer/ops/filter/image_face_count_filter.py @@ -67,8 +67,10 @@ def __init__( self.min_face_count = min_face_count self.max_face_count = max_face_count - self.extra_kwargs = self._default_kwargs.copy() - self.extra_kwargs.update((k, v) for k, v in kwargs.items() if k in self.extra_kwargs) + self.extra_kwargs = self._default_kwargs + for key in kwargs: + if key in self.extra_kwargs: + self.extra_kwargs[key] = kwargs[key] if any_or_all not in ["any", "all"]: raise ValueError(f"Keep strategy [{any_or_all}] is not supported. " f'Can only be one of ["any", "all"].') @@ -96,10 +98,13 @@ def compute_stats_single(self, sample, context=False): # count the number of detected faces in each image face_counts = {} - for key, image in images.items(): - dets = detect_faces(image, model, **self.extra_kwargs) - face_counts[key] = len(dets) - logger.debug(f"face counts: {face_counts}") + try: + for key, image in images.items(): + dets = detect_faces(image, model, **self.extra_kwargs) + face_counts[key] = len(dets) + logger.debug(f"face counts: {face_counts}") + except Exception as e: + logger.exception(e) sample[Fields.stats][StatsKeys.face_counts] = [face_counts[key] for key in loaded_image_keys] return sample diff --git a/data_juicer/ops/filter/image_face_ratio_filter.py b/data_juicer/ops/filter/image_face_ratio_filter.py index 2ba21135e8..9a9a935ec6 100644 --- a/data_juicer/ops/filter/image_face_ratio_filter.py +++ b/data_juicer/ops/filter/image_face_ratio_filter.py @@ -67,8 +67,10 @@ def __init__( self.min_ratio = min_ratio self.max_ratio = max_ratio - self.extra_kwargs = self._default_kwargs.copy() - self.extra_kwargs.update((k, v) for k, v in kwargs.items() if k in self.extra_kwargs) + self.extra_kwargs = self._default_kwargs + for key in kwargs: + if key in self.extra_kwargs: + self.extra_kwargs[key] = kwargs[key] if any_or_all not in ["any", "all"]: raise ValueError(f"Keep strategy [{any_or_all}] is not supported. " f'Can only be one of ["any", "all"].') diff --git a/data_juicer/ops/filter/video_face_ratio_filter.py b/data_juicer/ops/filter/video_face_ratio_filter.py new file mode 100644 index 0000000000..62465fbe84 --- /dev/null +++ b/data_juicer/ops/filter/video_face_ratio_filter.py @@ -0,0 +1,143 @@ +import av +import numpy as np +from data_juicer.utils.constant import Fields, StatsKeys +from data_juicer.utils.mm_utils import (load_data_with_context, load_video, + pil_to_opencv, pil_to_opencv, process_each_frame) +from ..base_op import OPERATORS, Filter +from ..op_fusion import LOADED_VIDEOS +from ..op_fusion import INTER_SAMPLED_FRAMES + +import psutil +import gc,os + + +import cv2,dlib +from PIL import ImageFilter + +OP_NAME = 'video_face_ratio_filter' +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) + +class VideoFaceRatioFilter(Filter): + """ + Keep data samples whose videos' durations are within a specified range. + + Source: This operator is a part of HumanVBench (CVPR 2026). + """ + + def __init__(self, + threshold: float = 0.8, + detect_interval: int = 1, + any_or_all: str = 'all', + *args, + **kwargs): + """ + Initialization method. + + :param any_or_all: keep this sample with 'any' or 'all' strategy of + all videos. 'any': keep this sample if any videos meet the + condition. 'all': keep this sample only if all videos meet the + condition. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.threshold = threshold + + if any_or_all not in ['any', 'all']: + raise ValueError(f'Keep strategy [{any_or_all}] is not supported. ' + f'Can only be one of ["any", "all"].') + self.any = (any_or_all == 'any') + + # Initialize face detector + self.detector = dlib.get_frontal_face_detector() + + + self.detect_interval = detect_interval + + + def compute_stats_single(self, sample, rank=None, context=False): + # check if it's computed already + if StatsKeys.video_face_exist in sample[Fields.stats]: + return sample + + # load videos + loaded_video_keys = sample[self.video_key] + video_faces_ratio = {} + + # face_detect_S3FD = get_model(self.detector_key, rank=rank) + + process = psutil.Process(os.getpid()) + # memory_before = process.memory_info().rss / 1024 ** 2 # MB + + + for video_key in loaded_video_keys: + try: + with av.open(video_key) as container: + # getting video stream + video_stream = next(s for s in container.streams if s.type == 'video') + # iterate over the video frame and detect faces + frame_counter = 0 + total_frames = 0 + frames_with_face = 0 + detect_num = 0 + for packet in container.demux(video_stream): + try: + for frame in packet.decode(): + total_frames += 1 + frame_counter += 1 + + if frame_counter % self.detect_interval == 0: + detect_num = detect_num + 1 + img = frame.to_image() + image = pil_to_opencv(img) + # imageNumpy = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + # faces = face_detect_S3FD.detect_faces(imageNumpy, conf_th=0.9, scales=[0.25]) + faces = self.detector(image) + if len(faces) > 0: + frames_with_face += 1 + except Exception as e: + print(f"Frame decoding error in video {video_key}: {e}") + frames_with_face = 0 + detect_num = 0 + + # calculate the proportion of the number of face frames + if detect_num > 0: + face_ratio = frames_with_face / detect_num + else: + face_ratio = 0.0 + video_faces_ratio[video_key] = face_ratio + except av.AVError as e: + print(f"Error opening video {video_key}: {e}") + video_faces_ratio[video_key] = 0.0 + finally: + container.close() + + video_faces_ratio[video_key] = face_ratio + + # get video faces ratio + sample[Fields.stats][StatsKeys.video_face_exist] = [ + video_faces_ratio[video_key] for video_key in sample[self.video_key] + ] + + memory_after = process.memory_info().rss / 1024 ** 2 # MB + print(f"Memory Usage: {memory_after:.2f} MB") + + gc.collect() + + return sample + + def process_single(self, sample): + video_faces_ratio = sample[Fields.stats][StatsKeys.video_face_exist] + keep_bools = np.array([ + duration >= self.threshold + for duration in video_faces_ratio + ]) + if len(keep_bools) <= 0: + return True + + # different strategies + if self.any: + return keep_bools.any() + else: + return keep_bools.all() diff --git a/data_juicer/ops/filter/video_motion_score_filter.py b/data_juicer/ops/filter/video_motion_score_filter.py index c88430e7e9..5aa002bd0e 100644 --- a/data_juicer/ops/filter/video_motion_score_filter.py +++ b/data_juicer/ops/filter/video_motion_score_filter.py @@ -115,8 +115,10 @@ def __init__( self.divisible = divisible self.relative = relative - self.extra_kwargs = self._default_kwargs.copy() - self.extra_kwargs.update((k, v) for k, v in kwargs.items() if k in self.extra_kwargs) + self.extra_kwargs = self._default_kwargs + for key in kwargs: + if key in self.extra_kwargs: + self.extra_kwargs[key] = kwargs[key] if any_or_all not in ["any", "all"]: raise ValueError(f"Keep strategy [{any_or_all}] is not supported. " f'Can only be one of ["any", "all"].') diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 9be5b8accd..009f66574b 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -47,8 +47,6 @@ from .imgdiff_difference_caption_generator_mapper import ( Difference_Caption_Generator_Mapper, ) -from .latex_figure_context_extractor_mapper import LatexFigureContextExtractorMapper -from .latex_merge_tex_mapper import LatexMergeTexMapper from .mllm_mapper import MllmMapper from .nlpaug_en_mapper import NlpaugEnMapper from .nlpcda_zh_mapper import NlpcdaZhMapper @@ -114,6 +112,16 @@ from .video_undistort_mapper import VideoUndistortMapper from .video_whole_body_pose_estimation_mapper import VideoWholeBodyPoseEstimationMapper from .whitespace_normalization_mapper import WhitespaceNormalizationMapper +from .video_human_tracks_extraction_mapper import VideoHumanTracksExtractionMapper +from .video_active_speaker_detect_mapper import VideoActiveSpeakerDetectMapper +from .video_audio_detect_age_gender_mapper import VideoAudioDetectAgeGenderMapper +from .video_audio_ASR_mapper import VideoAudioASRMapper +from .video_audio_speech_emotion_mapper import VideoAudioSpeechEmotionMapper +from .video_captioning_face_attribute_emotion_mapper import VideoCaptioningFaceAttributeEmotionMapper +from .video_captioning_from_human_tracks_mapper import VideoCaptioningFromHumanTracksMapper +from .video_captioning_face_attribute_emotion_mapper import VideoCaptioningFaceAttributeEmotionMapper +from .video_human_tracks_face_demographic_mapper import VideoHumantrackFaceDemographicMapper + __all__ = [ "AudioAddGaussianNoiseMapper", @@ -161,8 +169,6 @@ "ImageSegmentMapper", "ImageTaggingMapper", "ImageTaggingVLMMapper", - "LatexFigureContextExtractorMapper", - "LatexMergeTexMapper", "MllmMapper", "NlpaugEnMapper", "NlpcdaZhMapper", @@ -221,4 +227,13 @@ "VideoUndistortMapper", "VideoWholeBodyPoseEstimationMapper", "WhitespaceNormalizationMapper", + "VideoHumanTracksExtractionMapper", + "VideoActiveSpeakerDetectMapper", + 'VideoAudioDetectAgeGenderMapper', + 'VideoAudioASRMapper', + 'VideoCaptioningFaceAttributeEmotionMapper', + 'VideoCaptioningFromHumanTracksMapper', + 'VideoCaptioningFaceAttributeEmotionMapper', + 'VideoHumantrackFaceDemographicMapper', + 'VideoAudioSpeechEmotionMapper' ] diff --git a/data_juicer/ops/mapper/image_face_blur_mapper.py b/data_juicer/ops/mapper/image_face_blur_mapper.py index 54b00c269a..29dbe3526e 100644 --- a/data_juicer/ops/mapper/image_face_blur_mapper.py +++ b/data_juicer/ops/mapper/image_face_blur_mapper.py @@ -83,8 +83,10 @@ def __init__( self.blur_type = blur_type self.radius = radius - self.extra_kwargs = self._default_kwargs.copy() - self.extra_kwargs.update((k, v) for k, v in kwargs.items() if k in self.extra_kwargs) + self.extra_kwargs = self._default_kwargs + for key in kwargs: + if key in self.extra_kwargs: + self.extra_kwargs[key] = kwargs[key] self.model_key = prepare_model(model_type="opencv_classifier", model_path=cv_classifier) self.save_dir = save_dir diff --git a/data_juicer/ops/mapper/latex_figure_context_extractor_mapper.py b/data_juicer/ops/mapper/latex_figure_context_extractor_mapper.py deleted file mode 100644 index b82a10dc3a..0000000000 --- a/data_juicer/ops/mapper/latex_figure_context_extractor_mapper.py +++ /dev/null @@ -1,564 +0,0 @@ -from dataclasses import dataclass, field -from typing import Dict, List, Optional - -import regex as re -from loguru import logger - -from ..base_op import OPERATORS, Mapper - -OP_NAME = "latex_figure_context_extractor_mapper" - - -@dataclass -class SubFigure: - """A subfigure within a figure environment.""" - - caption: str = "" - label: str = "" - image_paths: List[str] = field(default_factory=list) - - -@dataclass -class Figure: - """A top-level figure/figure* environment.""" - - caption: str = "" - label: str = "" - image_paths: List[str] = field(default_factory=list) - sub_figures: List[SubFigure] = field(default_factory=list) - - -@OPERATORS.register_module(OP_NAME) -class LatexFigureContextExtractorMapper(Mapper): - """Extracts figures and their citing context from LaTeX source. - - This operator parses figure environments from a paper's LaTeX - source, extracts each figure's caption, label, and image path(s), - and finds the prose paragraphs that cite each figure. It fans out - one paper row into N figure rows (one per figure or subfigure). - **Samples that contain no figures with images are dropped from - the output.** - - Supported figure environments: figure, figure*, wrapfigure, - subfigure (environment), \\subfigure (command), - \\subfloat (command, subfig package). - Supported caption commands: \\caption, \\caption*, - \\subcaption, \\captionof{figure}. - - Figures without \\includegraphics are skipped. Subfigures - inherit citing paragraphs from their parent figure's label. - - Output fields (in addition to all input fields): - - - ```` (default ``images``, inherited from base - class): list of image paths from ``\\includegraphics``. - - ```` (default ``caption``): figure caption text. - - ```` (default ``label``): LaTeX label string. - - ```` (default ``citing_paragraphs``): list of - paragraphs that cite this figure. - - ```` (default ``parent_caption``): parent - figure caption (subfigures only; empty for standalone figures). - - ```` (default ``parent_label``): parent - figure label (subfigures only; empty for standalone figures). - - Note: this operator expects the full LaTeX source as a single - string. It does **not** resolve ``\\input`` or ``\\include`` - directives. If your documents span multiple ``.tex`` files, - concatenate them into a single text field before applying this - mapper. - """ - - _batched_op = True - - # Recursive nested braces pattern via the ``regex`` module. - # Matches balanced ``{...}`` content at arbitrary nesting depth. - # ``(?P...)`` defines a named group for a single balanced brace - # pair; ``(?&B)`` recurses into it, so - # ``\caption{A \textbf{B \emph{C \footnote{D \cite{E}}}}}`` - # is matched correctly regardless of depth. - _NESTED_BRACES = r"(?:[^{}]|(?P\{(?:[^{}]|(?&B))*\}))*" - - # LaTeX environments stripped from prose before searching for - # citing paragraphs. Only float / display / verbatim environments - # are listed — structural ones (document, section, itemize, …) - # are kept so that prose inside them is still searchable. - _STRIPPED_ENVS = ( - "figure", - r"figure\*", - "wrapfigure", - "table", - r"table\*", - "tabular", - r"tabular\*", - "equation", - r"equation\*", - "align", - r"align\*", - "alignat", - r"alignat\*", - "gather", - r"gather\*", - "multline", - r"multline\*", - "flalign", - r"flalign\*", - "algorithm", - r"algorithm\*", - "lstlisting", - "verbatim", - "minted", - ) - - def __init__( - self, - citation_commands: Optional[List[str]] = None, - paragraph_separator: str = "\n\n", - caption_key: str = "caption", - label_key: str = "label", - context_key: str = "citing_paragraphs", - parent_caption_key: str = "parent_caption", - parent_label_key: str = "parent_label", - *args, - **kwargs, - ): - """ - Initialization method. - - :param citation_commands: LaTeX reference commands to search - for when finding citing paragraphs. Defaults to - ['\\ref', '\\cref', '\\Cref', '\\autoref']. - Comma-separated label lists (e.g. ``\\cref{fig:a,fig:b}``) - are handled automatically. - :param paragraph_separator: Pattern for splitting LaTeX text - into paragraphs. Defaults to '\\n\\n'. - :param caption_key: Output field name for the figure caption. - :param label_key: Output field name for the LaTeX label. - :param context_key: Output field name for citing paragraphs. - :param parent_caption_key: Output field name for the parent - figure's caption. For subfigures this carries the parent - figure environment's caption; for standalone figures it - is an empty string. - :param parent_label_key: Output field name for the parent - figure's label. Useful for grouping subfigures that - belong to the same figure environment. Empty string for - standalone figures. - :param args: extra args - :param kwargs: extra args. Notably ``text_key`` (default - ``'text'``) controls which input field contains the LaTeX - source, and ``image_key`` (default ``'images'``) controls - the output field name for extracted image paths. Both - are inherited from the base ``OP`` class. - """ - super().__init__(*args, **kwargs) - if citation_commands is None: - citation_commands = [ - r"\ref", - r"\cref", - r"\Cref", - r"\autoref", - ] - self.citation_commands = citation_commands - self.paragraph_separator = paragraph_separator - - # Pre-build the citation command alternation once (it never - # changes after init). Individual label patterns are cached - # lazily in _citation_pattern_cache. - cmd_names = [cmd.lstrip("\\") for cmd in citation_commands] - self._cite_cmd_alt = "|".join(re.escape(c) for c in cmd_names) - self._citation_pattern_cache: Dict[str, re.Pattern] = {} - self.caption_key = caption_key - self.label_key = label_key - self.context_key = context_key - self.parent_caption_key = parent_caption_key - self.parent_label_key = parent_label_key - - self._compile_patterns() - - def _compile_patterns(self): - """Compile all regex patterns used by the parser. - - Called once from ``__init__``. Separated for readability — - the patterns are non-trivial and benefit from being grouped - together away from the parameter-handling logic. - """ - nb = self._NESTED_BRACES - - # -- Figure environments ------------------------------------------ - # figure, figure*, wrapfigure (with optional {pos}{width} args). - # Named group + backreference so \begin{X} only matches \end{X}. - self._figure_env_pattern = re.compile( - r"\\begin\{(?Pfigure\*?|wrapfigure)\}" - r"(?:\[[^\]]*\])*" # skip optional [...] args - r"(?:\{[^}]*\})*" # skip mandatory {...} args - r".*?" - r"\\end\{(?P=fig_env)\}", - re.DOTALL, - ) - - # -- Subfigure environments --------------------------------------- - # \begin{subfigure}[pos]{width}...\end{subfigure} - self._subfigure_env_pattern = re.compile( - r"\\begin\{subfigure\}" - r"(?:\[[^\]]*\])*" # optional [pos] arg - r"(?:\{[^}]*\})*" # mandatory {width} arg - r".*?" - r"\\end\{subfigure\}", - re.DOTALL, - ) - - # -- Subfigure / subfloat commands -------------------------------- - # \subfigure[caption]{content} or \subfloat[caption]{content} - self._subfigure_cmd_pattern = re.compile(r"\\(?:subfigure|subfloat)\[([^\]]*)\]" r"\s*" r"\{(" + nb + r")\}") - # \subfigure{content} or \subfloat{content} (no optional caption) - self._subfig_cmd_nocaption_pattern = re.compile( - r"\\(?:subfigure|subfloat)" r"(?!\[)" r"\s*" r"\{(" + nb + r")\}" - ) - - # -- Caption commands --------------------------------------------- - self._caption_pattern = re.compile(r"\\caption\*?(?:\[[^\]]*\])?\{(" + nb + r")\}") - self._subcaption_pattern = re.compile(r"\\subcaption(?:\[[^\]]*\])?\{(" + nb + r")\}") - self._captionof_pattern = re.compile(r"\\captionof\{figure\}(?:\[[^\]]*\])?\{(" + nb + r")\}") - # \captionof{table}{...} — used to detect table minipages inside - # figure environments so they can be excluded from figure output. - self._captionof_table_pattern = re.compile(r"\\captionof\{table\}(?:\[[^\]]*\])?\{(" + nb + r")\}") - - # -- Minipage environments ---------------------------------------- - # Matches \begin{minipage}[pos]{width}...\end{minipage}. - self._minipage_pattern = re.compile( - r"\\begin\{minipage\}" r"(?:\[[^\]]*\])*" r"(?:\{[^}]*\})*" r"(.*?)" r"\\end\{minipage\}", - re.DOTALL, - ) - - # -- Label and includegraphics ------------------------------------ - self._label_pattern = re.compile(r"\\label\{([^}]+)\}") - self._includegraphics_pattern = re.compile(r"\\includegraphics(?:\[[^\]]*\])?\{([^}]+)\}") - - # -- Environment stripping ---------------------------------------- - # Removes float/display/verbatim environments so that - # citing-paragraph search only sees prose text. - env_alt = "|".join(self._STRIPPED_ENVS) - self._env_strip_pattern = re.compile( - r"\\begin\{(" + env_alt + r")\}" r".*?" r"\\end\{\1\}", - re.DOTALL, - ) - - def _extract_caption(self, text): - """Extract caption text from a LaTeX fragment. - - Tries \\caption (including \\caption*), \\subcaption, and - \\captionof{figure}. Returns the first match's content, - or ''. - """ - for pattern in (self._caption_pattern, self._subcaption_pattern, self._captionof_pattern): - m = pattern.search(text) - if m: - return m.group(1).strip() - return "" - - def _extract_label(self, text): - """Extract the first \\label{...} value from a LaTeX - fragment.""" - m = self._label_pattern.search(text) - return m.group(1).strip() if m else "" - - def _extract_image_paths(self, text): - """Extract all \\includegraphics paths from a LaTeX - fragment.""" - return [m.group(1).strip() for m in self._includegraphics_pattern.finditer(text)] - - def _build_subfigure(self, caption, content): - """Build a SubFigure from an explicit caption and content text. - - :param caption: the caption string (already extracted). - :param content: LaTeX fragment to extract label and image - paths from. - :return: a SubFigure instance. - """ - return SubFigure( - caption=caption, - label=self._extract_label(content), - image_paths=self._extract_image_paths(content), - ) - - def _is_table_minipage(self, text): - """Check if a minipage contains a \\captionof{table} command, - indicating it holds a table rather than a figure. - - :param text: the minipage body text. - :return: True if the minipage is a table. - """ - return bool(self._captionof_table_pattern.search(text)) - - def _parse_figure_env(self, fig_text): - """Parse a figure/figure*/wrapfigure environment block. - - Handles \\begin{subfigure} environments, - \\subfigure[caption]{content} commands (older subfigure - package), and \\subfloat[caption]{content} commands (subfig - package). Commands without the optional [caption] argument - are also supported. - - Also handles the minipage pattern where a single figure - environment contains multiple \\begin{minipage} blocks, each - with its own \\caption, \\label, and \\includegraphics. - Minipages that contain \\captionof{table} are skipped. - - :param fig_text: the full text of a figure environment. - :return: a Figure object, a list of Figure objects, or None - if it has no images. - """ - # Check for \begin{subfigure}...\end{subfigure} environments - subfig_env_matches = list(self._subfigure_env_pattern.finditer(fig_text)) - # Check for \subfigure[caption]{} / \subfloat[caption]{} commands - subfig_cmd_matches = list(self._subfigure_cmd_pattern.finditer(fig_text)) - # Check for \subfigure{} / \subfloat{} commands (no caption) - subfig_nocap_matches = list(self._subfig_cmd_nocaption_pattern.finditer(fig_text)) - - has_subfigures = bool(subfig_env_matches or subfig_cmd_matches or subfig_nocap_matches) - - if has_subfigures: - # Normalise every subfigure variant into (caption, content) - # so we can parse them in a single pass. - caption_content_pairs = [] - for m in subfig_env_matches: - # \begin{subfigure}...\end{subfigure}: caption is - # inside the environment body, content is the whole match - caption_content_pairs.append((self._extract_caption(m.group(0)), m.group(0))) - for m in subfig_cmd_matches: - # \subfigure[caption]{content} / \subfloat[caption]{content} - caption_content_pairs.append((m.group(1).strip(), m.group(2))) - for m in subfig_nocap_matches: - # \subfigure{content} / \subfloat{content} (no caption) - caption_content_pairs.append(("", m.group(1))) - - sub_figures = [] - for caption, content in caption_content_pairs: - sf = self._build_subfigure(caption, content) - if sf.image_paths: - sub_figures.append(sf) - - # Extract parent caption/label from text outside - # all subfigure/subfloat blocks - text_outside = fig_text - all_matches = sorted( - subfig_env_matches + subfig_cmd_matches + subfig_nocap_matches, - key=lambda m: m.start(), - reverse=True, - ) - for m in all_matches: - text_outside = text_outside[: m.start()] + text_outside[m.end() :] - - if not sub_figures: - return None - - return Figure( - caption=self._extract_caption(text_outside), - label=self._extract_label(text_outside), - image_paths=[], - sub_figures=sub_figures, - ) - else: - # No subfigures — check for the minipage pattern. - # When a figure environment contains multiple minipages - # with separate \caption/\label pairs, each minipage is - # treated as an independent figure. Minipages holding - # \captionof{table} are skipped (they are tables). - minipage_matches = list(self._minipage_pattern.finditer(fig_text)) - if len(minipage_matches) >= 2: - figures = [] - for mp in minipage_matches: - mp_body = mp.group(1) - # Skip table minipages - if self._is_table_minipage(mp_body): - continue - image_paths = self._extract_image_paths(mp_body) - if not image_paths: - continue - figures.append( - Figure( - caption=self._extract_caption(mp_body), - label=self._extract_label(mp_body), - image_paths=image_paths, - sub_figures=[], - ) - ) - return figures if figures else None - - # Single figure (no subfigures, no multi-minipage) - image_paths = self._extract_image_paths(fig_text) - if not image_paths: - return None - - return Figure( - caption=self._extract_caption(fig_text), - label=self._extract_label(fig_text), - image_paths=image_paths, - sub_figures=[], - ) - - def _parse_figures(self, latex_source): - """Parse all figure environments from a LaTeX source. - - :param latex_source: full LaTeX document source. - :return: a list of Figure objects. - """ - figures = [] - for m in self._figure_env_pattern.finditer(latex_source): - result = self._parse_figure_env(m.group(0)) - if result is None: - continue - if isinstance(result, list): - figures.extend(result) - else: - figures.append(result) - return figures - - def _prepare_paragraphs(self, latex_source): - """Strip float/display environments and split into clean - paragraphs. - - :param latex_source: full LaTeX document source. - :return: a list of non-empty paragraph strings. - """ - stripped = self._env_strip_pattern.sub("", latex_source) - return [p.strip() for p in stripped.split(self.paragraph_separator) if p.strip()] - - def _get_citation_pattern(self, label): - """Return a compiled regex that matches any citation command - referencing *label*. Results are cached per label. - - Handles comma-separated label lists such as - ``\\cref{fig:a,fig:b}``. The label must appear as a - complete entry (bounded by ``{``, ``,``, or ``}``) so that - e.g. label ``fig:a`` does not false-match ``fig:ab``. - - :param label: the LaTeX label string to search for. - :return: a compiled regex pattern (cached). - """ - pat = self._citation_pattern_cache.get(label) - if pat is not None: - return pat - # Match \cmd{...label...} where label appears as a - # complete entry in a comma-separated list. - # (?:[^},]*,\s*)* — zero or more preceding entries - # LABEL — the target label - # \s*(?:,[^}]*)? — optional trailing entries - pat = re.compile( - r"\\(?:" + self._cite_cmd_alt + r")\{" r"(?:[^},]*,\s*)*" + re.escape(label) + r"\s*(?:,[^}]*)?\}" - ) - self._citation_pattern_cache[label] = pat - return pat - - def _find_citing_paragraphs(self, label, paragraphs): - """Find paragraphs that cite a given label. - - :param label: the LaTeX label to search for. - :param paragraphs: list of paragraph strings to search in. - :return: a list of paragraph strings that cite the label, - or [] if label is empty. - """ - if not label: - return [] - cite_pattern = self._get_citation_pattern(label) - return [p for p in paragraphs if cite_pattern.search(p)] - - def _append_output_row(self, output_samples, samples, idx, input_keys, *, fig, citing_paragraphs, parent=None): - """Append one output row to the output_samples dict. - - :param output_samples: accumulator dict of lists. - :param samples: the original input batch. - :param idx: index of the current sample in the batch. - :param input_keys: keys to copy from the original sample. - :param fig: a Figure or SubFigure whose caption, label, - and image_paths are emitted. - :param citing_paragraphs: list of citing paragraph strings. - :param parent: optional parent Figure for subfigure rows. - When provided, its caption and label are emitted as - parent_caption / parent_label; otherwise empty strings. - """ - # Keys that are explicitly set below must be skipped during - # the input-copy loop to avoid double-appending (e.g. when - # the input batch already contains an ``images`` column). - output_only_keys = { - self.image_key, - self.caption_key, - self.label_key, - self.context_key, - self.parent_caption_key, - self.parent_label_key, - } - for k in input_keys: - if k not in output_only_keys: - output_samples[k].append(samples[k][idx]) - output_samples[self.caption_key].append(fig.caption) - output_samples[self.image_key].append(fig.image_paths) - output_samples[self.label_key].append(fig.label) - output_samples[self.context_key].append(citing_paragraphs) - output_samples[self.parent_caption_key].append(parent.caption if parent else "") - output_samples[self.parent_label_key].append(parent.label if parent else "") - - def process_batched(self, samples): - input_keys = samples.keys() - num_samples = len(samples[next(iter(input_keys))]) - output_keys = input_keys | { - self.caption_key, - self.label_key, - self.context_key, - self.image_key, - self.parent_caption_key, - self.parent_label_key, - } - output_samples: Dict[str, list] = {key: [] for key in output_keys} - - for i in range(num_samples): - latex_source = samples[self.text_key][i] - - # Parse figures - figures = self._parse_figures(latex_source) - - if not figures: - logger.warning( - f"No figures with images found in sample {i} " - f"(batch size {num_samples}). " - f"Sample will be dropped from output." - ) - continue - - # Prepare cleaned paragraphs once per paper - paragraphs = self._prepare_paragraphs(latex_source) - - # Fan out - for fig in figures: - if fig.sub_figures: - # Parent-level citing paragraphs - parent_citing = self._find_citing_paragraphs(fig.label, paragraphs) - - for sf in fig.sub_figures: - # Subfigure-specific citing paragraphs - sf_citing = self._find_citing_paragraphs(sf.label, paragraphs) - # Merge parent + subfigure, deduplicated, - # preserving order - merged = list(dict.fromkeys(parent_citing + sf_citing)) - - self._append_output_row( - output_samples, - samples, - i, - input_keys, - fig=sf, - citing_paragraphs=merged, - parent=fig, - ) - else: - # Single figure (leaf) — no parent - citing = self._find_citing_paragraphs(fig.label, paragraphs) - self._append_output_row( - output_samples, - samples, - i, - input_keys, - fig=fig, - citing_paragraphs=citing, - ) - - return output_samples diff --git a/data_juicer/ops/mapper/latex_merge_tex_mapper.py b/data_juicer/ops/mapper/latex_merge_tex_mapper.py deleted file mode 100644 index da4696a5d6..0000000000 --- a/data_juicer/ops/mapper/latex_merge_tex_mapper.py +++ /dev/null @@ -1,185 +0,0 @@ -import tarfile -import zipfile - -from loguru import logger - -from ..base_op import OPERATORS, Mapper - -OP_NAME = "latex_merge_tex_mapper" - - -@OPERATORS.register_module(OP_NAME) -class LatexMergeTexMapper(Mapper): - """Extracts and concatenates all ``.tex`` files from a compressed - LaTeX project archive into a single text field. - - Supported archive formats: ``.tar``, ``.tar.gz`` / ``.tgz``, - and ``.zip``. Plain ``.gz`` (single-file gzip) is **not** - supported because gzip archives carry no filename metadata, - making it impossible to verify that the content is actually a - ``.tex`` file. - - All ``.tex`` files found inside the archive are read in-memory and - joined with a configurable separator. No ordering or - deduplication is applied. - - This operator is typically placed before LaTeX-processing operators - such as ``remove_comments_mapper``, ``expand_macro_mapper``, or - ``latex_figure_context_extractor_mapper``.""" - - def __init__( - self, - compressed_file_key: str = "compressed_file", - separator: str = "\n\n", - max_file_size: int = 50 * 1024 * 1024, - max_total_size: int = 100 * 1024 * 1024, - *args, - **kwargs, - ): - """ - Initialization method. - - :param compressed_file_key: Field name that stores the archive - file path. - :param separator: String used to join the contents of multiple - ``.tex`` files. - :param max_file_size: Maximum allowed uncompressed size in bytes - for a single ``.tex`` entry inside the archive. Entries - exceeding this limit are skipped with a warning. Set to - ``None`` or ``0`` to disable the check. - :param max_total_size: Maximum allowed cumulative size in bytes - for all extracted ``.tex`` content combined. Once this - limit is reached, remaining files in the archive are - skipped with a warning. Set to ``None`` or ``0`` to - disable the check. - :param args: extra args - :param kwargs: extra args - """ - super().__init__(*args, **kwargs) - self.compressed_file_key = compressed_file_key - self.separator = separator - self.max_file_size = max_file_size or 0 - self.max_total_size = max_total_size or 0 - - def _extract_tex_contents(self, archive_path: str): - """Return a list of decoded ``.tex`` file contents from - *archive_path*. Dispatches by file extension to the - appropriate reader.""" - path_lower = archive_path.lower() - - try: - if path_lower.endswith(".zip"): - return self._read_zip(archive_path, self.max_file_size, self.max_total_size) - elif path_lower.endswith((".tar.gz", ".tgz", ".tar")): - return self._read_tar(archive_path, self.max_file_size, self.max_total_size) - else: - logger.warning( - f"Unsupported archive format: {archive_path}. " f"Supported formats: .tar, .tar.gz, .tgz, .zip" - ) - return [] - except Exception: - logger.exception(f"Failed to read archive {archive_path}") - return [] - - @staticmethod - def _read_tar(archive_path: str, max_file_size: int = 0, max_total_size: int = 0): - contents = [] - total_bytes = 0 - with tarfile.open(archive_path, "r:*") as tf: - for member in tf: - if not member.isfile(): - continue - if not member.name.endswith(".tex"): - continue - if max_file_size and member.size > max_file_size: - logger.warning( - f"Skipping {member.name} in {archive_path}: " - f"declared size {member.size} bytes exceeds " - f"limit of {max_file_size} bytes" - ) - continue - # Use declared header size to bail before reading. - if max_total_size and (total_bytes + member.size) > max_total_size: - logger.warning( - f"Cumulative extracted size would exceed limit " - f"of {max_total_size} bytes in {archive_path}. " - f"Skipping remaining files." - ) - break - raw = tf.extractfile(member) - if raw is None: - continue - raw_bytes = raw.read() - if max_file_size and len(raw_bytes) > max_file_size: - logger.warning( - f"Skipping {member.name} in {archive_path}: " - f"actual size {len(raw_bytes)} bytes exceeds " - f"limit of {max_file_size} bytes" - ) - continue - total_bytes += len(raw_bytes) - if max_total_size and total_bytes > max_total_size: - logger.warning( - f"Cumulative extracted size {total_bytes} bytes " - f"exceeds limit of {max_total_size} bytes in " - f"{archive_path}. Skipping remaining files." - ) - break - contents.append(raw_bytes.decode("utf-8", errors="replace")) - return contents - - @staticmethod - def _read_zip(archive_path: str, max_file_size: int = 0, max_total_size: int = 0): - contents = [] - total_bytes = 0 - with zipfile.ZipFile(archive_path) as zf: - for name in zf.namelist(): - if not name.endswith(".tex"): - continue - info = zf.getinfo(name) - if max_file_size and info.file_size > max_file_size: - logger.warning( - f"Skipping {name} in {archive_path}: " - f"declared size {info.file_size} bytes exceeds " - f"limit of {max_file_size} bytes" - ) - continue - # Use declared header size to bail before reading. - if max_total_size and (total_bytes + info.file_size) > max_total_size: - logger.warning( - f"Cumulative extracted size would exceed limit " - f"of {max_total_size} bytes in {archive_path}. " - f"Skipping remaining files." - ) - break - raw_bytes = zf.read(name) - if max_file_size and len(raw_bytes) > max_file_size: - logger.warning( - f"Skipping {name} in {archive_path}: " - f"actual size {len(raw_bytes)} bytes exceeds " - f"limit of {max_file_size} bytes" - ) - continue - total_bytes += len(raw_bytes) - if max_total_size and total_bytes > max_total_size: - logger.warning( - f"Cumulative extracted size {total_bytes} bytes " - f"exceeds limit of {max_total_size} bytes in " - f"{archive_path}. Skipping remaining files." - ) - break - contents.append(raw_bytes.decode("utf-8", errors="replace")) - return contents - - def process_single(self, sample): - if self.compressed_file_key not in sample: - raise ValueError( - f"Compressed file key '{self.compressed_file_key}' " - f"not found in sample. " - f"Available keys: {list(sample.keys())}" - ) - - path = sample[self.compressed_file_key] - tex_contents = self._extract_tex_contents(path) - sample[self.text_key] = self.separator.join(tex_contents) - return sample diff --git a/data_juicer/ops/mapper/nlpaug_en_mapper.py b/data_juicer/ops/mapper/nlpaug_en_mapper.py index 6e3436737c..6e5a45f4f7 100644 --- a/data_juicer/ops/mapper/nlpaug_en_mapper.py +++ b/data_juicer/ops/mapper/nlpaug_en_mapper.py @@ -1,3 +1,5 @@ +from copy import deepcopy + from loguru import logger from pydantic import PositiveInt @@ -144,30 +146,25 @@ def process_batched(self, samples): else: return {key: [] for key in samples} - res_samples = {key: [] for key in samples} - # process each sample in the batch - for idx, text_to_aug in enumerate(samples[self.text_key]): - # get augmented texts for this sample - if self.sequential: - aug_texts = self.aug.augment(text_to_aug, n=self.aug_num) - else: - aug_texts = [] - for aug_method in self.aug: - aug_texts += aug_method.augment(text_to_aug, n=self.aug_num) - - if not isinstance(aug_texts, list): - aug_texts = [aug_texts] - - # collect texts for this sample - if self.keep_original_sample: - sample_texts = [text_to_aug] + aug_texts - else: - sample_texts = aug_texts - res_samples[self.text_key] += sample_texts - - # replicate other fields to match - for key in samples: - if key != self.text_key: - res_samples[key] += [samples[key][idx]] * len(sample_texts) + texts_to_aug = samples[self.text_key][0] # batch_size = 1 + res_samples = deepcopy(samples) + # get augmented texts + if self.sequential: + aug_texts = self.aug.augment(texts_to_aug, n=self.aug_num) + else: + # apply each aug method to generate several augmented texts + aug_texts = [] + for aug_method in self.aug: + aug_texts += aug_method.augment(texts_to_aug, n=self.aug_num) + + # add augmented samples to the batch with other replicate fields + if self.keep_original_sample: + res_samples[self.text_key] += aug_texts + else: + res_samples[self.text_key] = aug_texts + # add other replicate fields + for key in res_samples: + if key != self.text_key: + res_samples[key] = res_samples[key] * len(res_samples[self.text_key]) return res_samples diff --git a/data_juicer/ops/mapper/video_active_speaker_detect_mapper.py b/data_juicer/ops/mapper/video_active_speaker_detect_mapper.py new file mode 100644 index 0000000000..cb6af6c583 --- /dev/null +++ b/data_juicer/ops/mapper/video_active_speaker_detect_mapper.py @@ -0,0 +1,247 @@ +from data_juicer.utils.ASD_mapper_utils import get_video_array_cv2,evaluate_network, \ + crop_video_with_facetrack, longest_continuous_actives + +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS +from data_juicer.utils.model_utils import get_model, prepare_model +import gc,os +from loguru import logger + +OP_NAME = 'video_active_speaker_detect_mapper' + +import torch +import sys +sys.path.append('./thirdparty/humanvbench_models/Light-ASD') +from data_juicer.utils.constant import Fields, MetaKeys +import tempfile +import shutil, pickle +from shutil import rmtree +import os, subprocess +import tqdm, glob +# from model.faceDetector.s3fd import S3FD + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoActiveSpeakerDetectMapper(Mapper): + _accelerator = 'cuda' + _batched_op = True + + """ + Detect active speakers in a video by analyzing visual face tracks and + audio signals, including consistency checks for gender and age. + + Source: This operator is a part of HumanVBench (CVPR 2026). + """ + + _default_kwargs = {'upsample_num_times': 0} + + def __init__(self, + temp_save_path: str = './temp_path', + Light_ASD_model_path: str = './thirdparty/humanvbench_models/Light-ASD/weight/finetuning_TalkSet.model', + active_threshold: int = 15, + active_speaker_flag: str = MetaKeys.active_speaker_flag, + *args, + **kwargs): + """ + Initialization method. + + :param blur_type: + """ + kwargs.setdefault('mem_required', '10GB') + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + self.active_threshold = active_threshold + + self.temp_save_path = temp_save_path + + # Initialize ASD model + self.ASD_model_key = prepare_model(model_type='Light_ASD', + pretrained_model_name_or_path=Light_ASD_model_path) + + self.active_speaker_flag = active_speaker_flag + + def active_speaker_detection_revise(self, active_score,is_child_descrip,speech_audio,face_gender): + speech_child = speech_audio['child'][0] + speech_male = speech_audio['male'][0] + speech_female = speech_audio['female'][0] + if speech_male > speech_female: + speech_gender = 'Man' + speech_gender_confidence = speech_male + else: + speech_gender = 'Woman' + speech_gender_confidence = speech_female + + if 'No' in is_child_descrip or 'no' in is_child_descrip: + is_child_apperance = False + else: + is_child_apperance = True + + if speech_child < 0.1: + is_child_voice = False + elif speech_audio['Age'][0]<=12: + is_child_voice = True + else: + is_child_voice = 'Not Sure' + + # Consistency detection: only perform false positive detection on positive samples + if active_score>self.active_threshold: + speak_active = True + # age consistency test: + if not is_child_voice == 'Not Sure': + if is_child_apperance == is_child_voice: + # gender consistency test + if speech_gender_confidence > 0.85 and float(face_gender[1]) > 0.85: + if not speech_gender == face_gender[0]: + speak_active = False + else: + speak_active = False + return speak_active + else: + return False + + + def process_single(self, sample, rank=None): + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.source_file] = [] + return sample + + if Fields.meta not in sample: + sample[Fields.meta] = {} + + meta = sample.get(Fields.meta, {}) + + # Core dependencies: Both human tracks and audio tags are required for ASD + has_tracks = MetaKeys.human_track_data_path in meta + has_audio_tags = MetaKeys.video_audio_tags in meta + + if not (has_tracks and has_audio_tags): + missing = [] + if not has_tracks: missing.append(MetaKeys.human_track_data_path) + if not has_audio_tags: missing.append(MetaKeys.video_audio_tags) + + logger.warning( + f"[{OP_NAME}] Skip sample: Missing mandatory keys {missing}. " + f"video_active_speaker_detect_mapper must be operated after video_tagging_from_audio_mapper. " + f"video_active_speaker_detect_mapper must be operated after video_human_tracks_extraction_mapper. " + f"Please ensure prior Mappers are executed correctly." + ) + return sample + + # Optional dependencies for 'revise' function + has_audio_attr = MetaKeys.audio_speech_attribute in meta + has_face_attr = MetaKeys.video_facetrack_attribute_demographic in meta + has_child_attr = MetaKeys.video_track_is_child in meta + + revise_available = (has_audio_attr and has_face_attr and has_child_attr) + if not revise_available: + logger.info( + f"[{OP_NAME}] Some metadata missing. Running in 'Basic Mode' without consistency detection. " + f"To enable full consistency detection, ensure these OPs are executed: video_audio_detect_age_gender_mapper, video_humantrack_face_demographic_mapper and video_captioning_from_human_tracks_mapper." + ) + + loaded_video_keys = sample[self.video_key] + + if revise_available: + audio_speech_attribute = sample[Fields.meta][MetaKeys.audio_speech_attribute] + face_demographic = sample[Fields.meta][MetaKeys.video_facetrack_attribute_demographic][0] + child_flag = sample[Fields.meta][MetaKeys.video_track_is_child][0] + + Total_result = [] + + temp_dir = tempfile.mkdtemp(dir=self.temp_save_path) + pyaviPath = os.path.join(temp_dir, 'pyavi') + pyframesPath = os.path.join(temp_dir, 'pyframes') + pyworkPath = os.path.join(temp_dir, 'pywork') + pycropPath = os.path.join(temp_dir, 'pycrop') + if os.path.exists(temp_dir): + rmtree(temp_dir) + + audio_tag = sample[Fields.meta][MetaKeys.video_audio_tags] + asd_detection_model = get_model(self.ASD_model_key, rank=rank) + + for id_out,video_key in enumerate(loaded_video_keys): + os.makedirs(pyaviPath, exist_ok = False) # The path for the input video, input audio, output video + os.makedirs(pyframesPath, exist_ok = False) # Save all the video frames + os.makedirs(pyworkPath, exist_ok = False) # Save the results in this process by the pckl method + os.makedirs(pycropPath, exist_ok = False) # Save the detected face clips (audio+video) in this process + + # Extract audio + audio_is_empty = False + audioFilePath = os.path.join(pyaviPath, 'audio.wav') + command = ("ffmpeg -y -i '%s' -qscale:a 0 -ac 1 -vn -threads %d -ar 16000 %s -loglevel panic" % \ + (video_key, 10, audioFilePath)) + if audio_tag[id_out] == "EMPTY": + audio_is_empty = True + else: + subprocess.call(command, shell=True, stdout=None) + + + video_array = get_video_array_cv2(video_key) + + def load_pkl(file_path): + with open(file_path, 'rb') as file: + return pickle.load(file) + # get allTracks + allTracks = [load_pkl(item['bbox_path']) for item in sample[Fields.meta][MetaKeys.human_track_data_path][id_out]] + + # Face clips cropping + for ii, track in tqdm.tqdm(enumerate(allTracks), total = len(allTracks)): + result = crop_video_with_facetrack(video_array, track, os.path.join(pycropPath, '%05d' % ii), audioFilePath, audio_is_empty) + if not result: + raise ValueError("something wrong with crop_video_with_facetrack.") + + # Active Speaker Detection + if audio_tag[id_out] == 'Speech': + files = glob.glob("%s/*.avi"%pycropPath) + files.sort() + try: + scores = evaluate_network(files, asd_detection_model, pycropPath) + except: + scores = [[-10000]]* len(allTracks) + + else: + scores = [[-10000]]* len(allTracks) + + for id in range(len(scores)): + allTracks[id]['active_scores'] = scores[id] + + update_track = allTracks + # for validation + # visualization(vidTracks, scores, video_array, pyaviPath) + + shutil.rmtree(temp_dir) + + speak_flag_for_tracks_in_a_video = [] + for track_idx,track_i in enumerate(update_track): + active_count = longest_continuous_actives(track_i['active_scores']) + + if revise_available: + try: + audio_attri = audio_speech_attribute[id_out][0] + is_child_descrip = child_flag[id_out][track_idx][0] + face_gender = face_demographic[id_out][track_idx]['gender'] + flag = self.active_speaker_detection_revise(active_count, is_child_descrip, audio_attri, face_gender) + except: + if active_count>self.active_threshold: + flag = True + else: + flag = False + else: + if active_count>self.active_threshold: + flag = True + else: + flag = False + speak_flag_for_tracks_in_a_video.append(flag) + + + Total_result.append(speak_flag_for_tracks_in_a_video) + torch.cuda.empty_cache() + + sample[Fields.meta][self.active_speaker_flag] = Total_result + + gc.collect() + torch.cuda.empty_cache() + + return sample diff --git a/data_juicer/ops/mapper/video_audio_ASR_mapper.py b/data_juicer/ops/mapper/video_audio_ASR_mapper.py new file mode 100644 index 0000000000..c1fe84b56f --- /dev/null +++ b/data_juicer/ops/mapper/video_audio_ASR_mapper.py @@ -0,0 +1,113 @@ +import librosa +from data_juicer.utils.mm_utils import extract_audio_from_video +from data_juicer.utils.model_utils import get_model, prepare_model +from ..base_op import OPERATORS, Mapper +import gc +from data_juicer.utils.constant import Fields, MetaKeys + +OP_NAME = 'video_audio_ASR_mapper' + +import torch,funasr +torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +class VideoAudioASRMapper(Mapper): + """Mapper to generate video tags from audio streams extracted by video + using the Audio Spectrogram Transformer. + + Source: This operator is a part of HumanVBench (CVPR 2026). + """ + _accelerator = 'cuda' + _batched_op = True + + def __init__(self, + model_dir_ASR = 'FunAudioLLM/SenseVoiceSmall', + speech_ASR: str = MetaKeys.speech_ASR, + *args, + **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + kwargs.setdefault('mem_required', '20GB') + super().__init__(*args, **kwargs) + self._batched_op = True + self._model_sampling_rate = 16000 + self.model_dir_ASR = model_dir_ASR + + self.model_key = prepare_model( + model_type='SenseVoiceSmall', + pretrained_model_name_or_path=model_dir_ASR, + ) + + self.speech_ASR = speech_ASR + + def process_single(self, sample, rank=None): + # check if it's generated already + if MetaKeys.speech_emotion in sample[Fields.meta]: + return sample + + if Fields.meta not in sample: + sample[Fields.meta] = {} + + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.source_file] = [] + return sample + + if not MetaKeys.video_audio_tags in sample[Fields.meta]: + raise ValueError("video_audio_ASR_mapper must be operated after video_tagging_from_audio_mapper.") + + + # load video paths + loaded_video_keys = sample[self.video_key] + audio_tags = sample[Fields.meta][MetaKeys.video_audio_tags] + + ASR_model, kwargs1= get_model(self.model_key, rank=rank) + + # model, feature_extractor = get_model(self.model_key, rank=rank) + video_audio_tags = [] + + for id,video_path in enumerate(loaded_video_keys): + if audio_tags[id] == 'Speech': + # only extract audio data and sr for index 0 for now + ys, srs, valid_indexes = extract_audio_from_video( + video_path, stream_indexes=[0]) + if len(valid_indexes) == 0: + # there is no valid audio streams. Skip! + video_audio_tags.append(self._no_audio_label) + continue + + # inference + y = ys[0] + sr = srs[0] + # check if it meets the sampling rate condition of the model + if sr != self._model_sampling_rate: + y = librosa.resample(y, + orig_sr=sr, + target_sr=self._model_sampling_rate) + sr = self._model_sampling_rate + + inputs = torch.tensor(y).to(next(ASR_model.parameters()).device) + with torch.no_grad(): + output_ASR_emo = ASR_model.inference( + data_in=inputs, + language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech" + use_itn=False, + **kwargs1, + ) + + # Example of output_ASR_emo[0][0]['text']: + # "<|en|><|NEUTRAL|><|Speech|> Hello, world." + # The split logic extracts language (en) and the clean text (Hello, world.) + video_audio_tags.append({'language':output_ASR_emo[0][0]['text'].split('<|',1)[-1].split('|>')[0], 'asr': output_ASR_emo[0][0]['text'].split('|>',4)[-1]}) + else: + video_audio_tags.append({'language': '', 'asr': ''}) + + sample[Fields.meta][self.speech_ASR] = video_audio_tags + # gc.collect() + # torch.cuda.empty_cache() + return sample diff --git a/data_juicer/ops/mapper/video_audio_detect_age_gender_mapper.py b/data_juicer/ops/mapper/video_audio_detect_age_gender_mapper.py new file mode 100644 index 0000000000..2f084a9ba6 --- /dev/null +++ b/data_juicer/ops/mapper/video_audio_detect_age_gender_mapper.py @@ -0,0 +1,103 @@ +import librosa +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.mm_utils import extract_audio_from_video +from thirdparty.humanvbench_models.audio_code.wav2vec_age_gender import process_func,AgeGenderModel +from ..base_op import OPERATORS, Mapper +from data_juicer.utils.model_utils import get_model, prepare_model + +NAME = 'video_audio_detect_age_gender_mapper' +CHECK_PKGS = [ + 'transformers', 'transformers_stream_generator', 'einops', 'accelerate', + 'tiktoken' +] + +from data_juicer.utils.model_utils import get_model, prepare_model + + + +@OPERATORS.register_module(NAME) +class VideoAudioDetectAgeGenderMapper(Mapper): + """ + Detect age and gender (male, female, child) from video audio signals using a pretrained wav2vec2 model. + + Source: This operator is a part of HumanVBench (CVPR 2026). + """ + _accelerator = 'cuda' + _batched_op = True + + def __init__(self, + hf_audio_mapper: str = None, + tag_field_name: str = MetaKeys.audio_speech_attribute, + *args, **kwargs): + """ + Initialization method. + + :param keep_original_sample: whether to keep the original sample. If + it's set to False, there will be only captioned sample in the + final datasets and the original sample will be removed. It's True + in default. + :param args: extra args + :param kwargs: extra args + """ + kwargs.setdefault('mem_required', '7GB') + super().__init__(*args, **kwargs) + self._model_sampling_rate = 16000 + + self._hf_summarizer = hf_audio_mapper if hf_audio_mapper else 'audeering/wav2vec2-large-robust-24-ft-age-gender' # noqa: E501 + self.model_key = prepare_model( + model_type='wav2vec2_age_gender', + pretrained_model_name_or_path=self._hf_summarizer, + ) + self.tag_field_name = tag_field_name + + def process_single(self, sample, rank=None): + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + return [] + + if Fields.meta not in sample: + sample[Fields.meta] = {} + + if not MetaKeys.video_audio_tags in sample[Fields.meta]: + raise ValueError("video_audio_detect_age_gender_mapper must be operated after video_tagging_from_audio_mapper.") + + # get paths of all video(s) + loaded_video_keys = sample[self.video_key] + audio_tag = sample[Fields.meta][MetaKeys.video_audio_tags] + + Total_result = [] + # get models + model, processor = get_model(self.model_key, rank, self.use_cuda()) + + for i,video in enumerate(loaded_video_keys): + audio_tag_this = audio_tag[i] + if not audio_tag_this == 'Speech': + Total_result.append([]) + else: + ys, srs, valid_indexes = extract_audio_from_video( + video, stream_indexes=[0]) + if len(valid_indexes) == 0: + # there is no valid audio streams. Skip! + Total_result.append([]) + continue + + # inference + y = ys[0] + sr = srs[0] + # check if it meets the sampling rate condition of the model + if sr != self._model_sampling_rate: + y = librosa.resample(y, + orig_sr=sr, + target_sr=self._model_sampling_rate) + sr = self._model_sampling_rate + + Age_female_male_child = process_func(y, sr, processor, model, device=model.device)[0] + Age_female_male_child_dict = {} + Age_female_male_child_dict['Age'] = [int(Age_female_male_child[0]*100)] + Age_female_male_child_dict['female'] = [Age_female_male_child[1]] + Age_female_male_child_dict['male'] = [Age_female_male_child[2]] + Age_female_male_child_dict['child'] = [Age_female_male_child[3]] + Total_result.append([Age_female_male_child_dict]) + + sample[Fields.meta][self.tag_field_name] = Total_result + return sample diff --git a/data_juicer/ops/mapper/video_audio_speech_emotion_mapper.py b/data_juicer/ops/mapper/video_audio_speech_emotion_mapper.py new file mode 100644 index 0000000000..c0a0e3a189 --- /dev/null +++ b/data_juicer/ops/mapper/video_audio_speech_emotion_mapper.py @@ -0,0 +1,109 @@ +import librosa +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.mm_utils import extract_audio_from_video +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, Mapper +import gc + +OP_NAME = 'video_audio_speech_emotion_mapper' + +import torch,funasr +torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +class VideoAudioSpeechEmotionMapper(Mapper): + """Mapper to generate video tags from audio streams extracted by video + using the Audio Spectrogram Transformer. + + Source: This operator is a part of HumanVBench (CVPR 2026). + """ + _accelerator = 'cuda' + _batched_op = True + + def __init__(self, + model_dir_emo='FunAudioLLM/SenseVoiceSmall', + speech_Emo: str = MetaKeys.speech_emotion, + *args, + **kwargs): + """ + Initialization method. + + :param args: extra args + :param kwargs: extra args + """ + kwargs.setdefault('mem_required', '20GB') + super().__init__(*args, **kwargs) + self._batched_op = True + self._model_sampling_rate = 16000 + self.model_dir_emo = model_dir_emo + + self.model_key = prepare_model( + model_type='SenseVoiceSmall', + pretrained_model_name_or_path=self.model_dir_emo, + ) + + self.speech_Emo = speech_Emo + + def process_single(self, sample, rank=None): + # check if it's generated already + if MetaKeys.speech_emotion in sample[Fields.meta]: + return sample + + if Fields.meta not in sample: + sample[Fields.meta] = {} + + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.source_file] = [] + return sample + + if not MetaKeys.video_audio_tags in sample[Fields.meta]: + raise ValueError("video_active_speaker_mapper must be operated after video_tagging_from_audio_mapper.") + + + # load video paths + loaded_video_keys = sample[self.video_key] + audio_tags = sample[Fields.meta][MetaKeys.video_audio_tags] + + Emo_model, kwargs1= get_model(self.model_key, rank=rank) + + video_audio_tags = [] + for id,video_path in enumerate(loaded_video_keys): + if audio_tags[id] == 'Speech': + # only extract audio data and sr for index 0 for now + ys, srs, valid_indexes = extract_audio_from_video( + video_path, stream_indexes=[0]) + if len(valid_indexes) == 0: + # there is no valid audio streams. Skip! + video_audio_tags.append(self._no_audio_label) + continue + + # inference + y = ys[0] + sr = srs[0] + # check if it meets the sampling rate condition of the model + if sr != self._model_sampling_rate: + y = librosa.resample(y, + orig_sr=sr, + target_sr=self._model_sampling_rate) + sr = self._model_sampling_rate + + inputs = torch.tensor(y).to(next(Emo_model.parameters()).device) + with torch.no_grad(): + output_emo = Emo_model.inference( + data_in=inputs, + language="auto", # "zn", "en", "yue", "ja", "ko", "nospeech" + use_itn=False, + **kwargs1, + ) + + video_audio_tags.append(output_emo[0][0]['text'].split('<|',2)[-1].split('|>')[0]) + else: + video_audio_tags.append('') + + sample[Fields.meta][self.speech_Emo] = video_audio_tags + gc.collect() + torch.cuda.empty_cache() + return sample diff --git a/data_juicer/ops/mapper/video_captioning_face_attribute_emotion_mapper.py b/data_juicer/ops/mapper/video_captioning_face_attribute_emotion_mapper.py new file mode 100644 index 0000000000..39b07f3569 --- /dev/null +++ b/data_juicer/ops/mapper/video_captioning_face_attribute_emotion_mapper.py @@ -0,0 +1,165 @@ +import numpy as np +from data_juicer.utils.constant import Fields, MetaKeys +from data_juicer.utils.model_utils import get_model, prepare_model +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS +from data_juicer.utils.ASD_mapper_utils import get_video_array_cv2 +import gc + +OP_NAME = 'video_captioning_face_attribute_emotion_mapper' + +import torch, os, tempfile, shutil +from shutil import rmtree +import pickle, copy, cv2 +import transformers # noqa: F401 + +# avoid hanging when calling clip in multiprocessing +# torch.set_num_threads(1) +import sys + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoCaptioningFaceAttributeEmotionMapper(Mapper): + _accelerator = 'cuda' + _batched_op = True + + def __init__( + self, + face_track_query: str = "Please describe the person's facial expression, tell me the person's emotion through the video, like Happiness, Excitement, Love, Gratitude, Relief, Pride, Anger, Sadness, Fear, Guilt, Shame, Disgust, Surprise, Confusion, Curiosity, Boredom ...", + trust_remote_code: bool = False, + cropping_face_video_temp_path = './temp_video_path', + video_describe_model_path: str = 'DAMO-NLP-SG/VideoLLaMA3-7B', + video_facetrack_attribute_emotion: str = MetaKeys.video_facetrack_attribute_emotion, + *args, + **kwargs + ): + """ + Initialization method. + + :param hf_video_blip: video-blip model name on huggingface + to generate caption + + Source: This operator is a part of HumanVBench (CVPR 2026). + """ + kwargs.setdefault('mem_required', '40GB') + super().__init__(*args, **kwargs) + + self._batched_op = True + self._accelerator = 'cuda' + self.context_param = 0.8 + + # self.pre_query_prompt = "The provided image arranges keyframes from a video in a grid view, keyframes are separated with white bands. " + self.query = face_track_query + self.cropping_face_video_temp_path = cropping_face_video_temp_path + + self.video_describe_model_path = video_describe_model_path if video_describe_model_path else 'DAMO-NLP-SG/VideoLLaMA3-7B' + self.model_key = prepare_model( + model_type='huggingface', + pretrained_model_name_or_path=video_describe_model_path, + trust_remote_code=trust_remote_code, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2" + ) + + self.video_facetrack_attribute_emotion = video_facetrack_attribute_emotion + + + + def process_single(self, samples, rank=None): + + if not MetaKeys.human_track_data_path in samples[Fields.meta]: + raise ValueError("video_captioning_face_attribute_emotion_mapper must be operated after video_human_tracks_extraction_mapper.") + + if Fields.meta not in samples: + samples[Fields.meta] = {} + + Total_information = [] + video_samples = samples[Fields.meta][MetaKeys.human_track_data_path] + loaded_video_keys = samples[self.video_key] + + cropping_face_video_temp_path = tempfile.mkdtemp(dir=self.cropping_face_video_temp_path) + if os.path.exists(cropping_face_video_temp_path): + rmtree(cropping_face_video_temp_path) + + os.makedirs(cropping_face_video_temp_path, exist_ok = False) + model, processor = get_model(self.model_key, rank, self.use_cuda()) + for vedio_id,ASD_attribute_all_tracks_for_one_video in enumerate(video_samples): + if len(ASD_attribute_all_tracks_for_one_video) == 0: + Total_information.append([]) + continue + + description_for_each_track = [] + video_array = get_video_array_cv2(loaded_video_keys[vedio_id]) + for track_id,tracks_now in enumerate(ASD_attribute_all_tracks_for_one_video): + cs = self.context_param + + with open(tracks_now['bbox_path'], 'rb') as f: + bbox_data = pickle.load(f) + xys_bbox = bbox_data['xys_bbox'] + track_frame = bbox_data['frame'] + + face_video_out_path = os.path.join(cropping_face_video_temp_path, loaded_video_keys[vedio_id].split('/')[-1][:-4] + '__' + str(track_id) + '.mp4') + + + num_total_frames = len(track_frame) + + target_write_fps = 25 if num_total_frames > 25 else max(1, num_total_frames-1) + + vOut = cv2.VideoWriter( + face_video_out_path, + cv2.VideoWriter_fourcc(*'XVID'), + target_write_fps, + (224,224) + ) + + start_frame_id_in = 0 + start_frame_id_out = track_frame[start_frame_id_in] # tag + while start_frame_id_in + 1 25 else max(1, num_total_frames-1) + + vOut = cv2.VideoWriter( + human_video_out_path, + cv2.VideoWriter_fourcc(*'mp4v'), + target_write_fps, + (wide_max + 2, height_max + 2) + ) + + while start_frame_id_in 0: + if all(element < 0 for element in human_bbox['x1']): + return False + human_bbox['x1'] = detect_and_mark_anomalies(human_bbox['x1'], window_size=30, std_multiplier=10) + human_bbox['x1'] = update_negative_ones(human_bbox['x1']) + if (np.array(human_bbox['y1'])<0).sum() > 0: + human_bbox['y1'] = detect_and_mark_anomalies(human_bbox['y1'], window_size=30, std_multiplier=10) + human_bbox['y1'] = update_negative_ones(human_bbox['y1']) + if (np.array(human_bbox['x2'])<0).sum() > 0: + human_bbox['x2'] = detect_and_mark_anomalies(human_bbox['x2'], window_size=30, std_multiplier=10) + human_bbox['x2'] = update_negative_ones(human_bbox['x2']) + if (np.array(human_bbox['y2'])<0).sum() > 0: + human_bbox['y2'] = detect_and_mark_anomalies(human_bbox['y2'], window_size=30, std_multiplier=10) + human_bbox['y2'] = update_negative_ones(human_bbox['y2']) + human_bbox['x1'] = signal.medfilt(human_bbox['x1'], kernel_size=5).tolist() + human_bbox['y1'] = signal.medfilt(human_bbox['y1'], kernel_size=5).tolist() + human_bbox['x2'] = signal.medfilt(human_bbox['x2'], kernel_size=5).tolist() + human_bbox['y2'] = signal.medfilt(human_bbox['y2'], kernel_size=5).tolist() + + return {'track':track, 'proc_track':dets, 'human_bbox':human_bbox} + + def process_single(self, sample, rank=None): + # there is no video in this sample + if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.source_file] = [] + return sample + + if Fields.meta not in sample: + sample[Fields.meta] = {} + + loaded_video_keys = sample[self.video_key] + + Total_result = [] + min_people_in_video = [] + + face_detect_S3FD = get_model(self.face_detect_S3FD_model_key, rank, self.use_cuda()) + human_detection_model = get_model(self.human_detection_model_key, rank, self.use_cuda()) + + for id_out,video_key in enumerate(loaded_video_keys): + # Scene detection for the video frames + scene = scene_detect(video_key) + + video_array = get_video_array_cv2(video_key) + + # Face detection for the video frames + faces = inference_video(video_array, face_detect_S3FD) + + # Face tracking + allTracks, vidTracks = [], [] + minTrack = 10 + for shot in scene: + if shot[1].frame_num - shot[0].frame_num >= minTrack: # Discard the shot frames less than minTrack frames + allTracks.extend(track_shot(faces[shot[0].frame_num:shot[1].frame_num])) # 'frames' to present this tracks' timestep, 'bbox' presents the location of the faces + + # Get face and human tracks + for ii, track in tqdm.tqdm(enumerate(allTracks), total = len(allTracks)): + result = self.get_face_and_human_tracks(video_array, track, human_detection_model) + if result: + vidTracks.append(result) + # merge + people_num_atleast, update_track = post_merge(vidTracks,video_array) + + for i in range(len(update_track)): + save_bbox_name = os.path.join(self.face_track_bbox_path, video_key.split("/")[-1][:-4] +'_'+str(i)+'.pkl') + xy_bbox = update_track[i]['track']['bbox'] + xys_bbox = update_track[i]['proc_track'] + xy_human_bbox = update_track[i]['human_bbox'] + frames = update_track[i]['track']['frame'] + bbox_dict = {'frame':frames, 'xy_bbox':xy_bbox, 'xys_bbox':xys_bbox, 'xy_human_bbox':xy_human_bbox} + f_save = open(save_bbox_name, 'wb') + pickle.dump(bbox_dict, f_save) + f_save.close() + del update_track[i]['human_bbox'] + del update_track[i]['proc_track'] + del update_track[i]['track'] + update_track[i]['bbox_path'] = save_bbox_name + + + Total_result.append(update_track) + min_people_in_video.append(people_num_atleast) + torch.cuda.empty_cache() + + sample[Fields.meta][self.tag_field_name_human_track_path] = Total_result + sample[Fields.meta][self.tag_field_name_people_num] = min_people_in_video + + gc.collect() + torch.cuda.empty_cache() + + return sample diff --git a/data_juicer/ops/mapper/video_human_tracks_face_demographic_mapper.py b/data_juicer/ops/mapper/video_human_tracks_face_demographic_mapper.py new file mode 100644 index 0000000000..d327d20356 --- /dev/null +++ b/data_juicer/ops/mapper/video_human_tracks_face_demographic_mapper.py @@ -0,0 +1,207 @@ +import numpy as np +from data_juicer.utils.constant import Fields, MetaKeys +from deepface import DeepFace +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS +from data_juicer.utils.ASD_mapper_utils import get_video_array_cv2 +import gc + +OP_NAME = 'video_human_tracks_face_demographic_mapper' + +import torch, os +import pickle + +# avoid hanging when calling clip in multiprocessing +torch.set_num_threads(1) + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoHumantrackFaceDemographicMapper(Mapper): + """ + Mapper to generate samples whose captions are generated based on + a video-to-text model and sampled video frame. + + Source: This operator is a part of HumanVBench (CVPR 2026). + """ + + def __init__( + self, + original_data_save_path = '', + detect_interval: int = 5, + tag_field_name: str = MetaKeys.video_facetrack_attribute_demographic, + *args, + **kwargs + ): + """ + Initialization method. + + :param hf_video_blip: video-blip model name on huggingface + to generate caption + """ + super().__init__(*args, **kwargs) + + self.interval = detect_interval + self.original_data_save_path = original_data_save_path + self.tag_field_name = tag_field_name + + def process_single(self, samples, rank=None, context=False): + if not MetaKeys.human_track_data_path in samples[Fields.meta]: + raise ValueError("video_human_tracks_face_demographic_mapper must be operated after video_human_tracks_extraction_mapper.") + + if Fields.meta not in samples: + samples[Fields.meta] = {} + + Total_information = [] + video_samples = samples[Fields.meta][MetaKeys.human_track_data_path] + loaded_video_keys = samples[self.video_key] + + for vedio_id,ASD_attribute_all_tracks_for_one_video in enumerate(video_samples): + if len(ASD_attribute_all_tracks_for_one_video) == 0: + Total_information.append([]) + continue + description_for_each_track = [] + video_array = get_video_array_cv2(loaded_video_keys[vedio_id]) + for track_id,tracks_now in enumerate(ASD_attribute_all_tracks_for_one_video): + face_attribute_dict_with_framestamp = {} + + bbox_path = tracks_now['bbox_path'] + with open(bbox_path, 'rb') as f: + bbox_data = pickle.load(f) + xys_bbox = bbox_data['xys_bbox'] + track_frame = bbox_data['frame'] + + + total_len = len(track_frame) + if total_len > 75: + interval = int(total_len/15) + else: + interval = self.interval + + + start_frame_id_in = 0 + start_frame_id_out = track_frame[start_frame_id_in] # tag + cs = 0.5 + while start_frame_id_in + interval iouThres and iou > max_iou: + best_match = face + max_iou = iou + else: + break + + if best_match is not None: + track.append(best_match) + frameFaces.remove(best_match) + + if track == []: + break + elif len(track) > minTrack: + frameNum = np.array([ f['frame'] for f in track ]) + bboxes = np.array([np.array(f['bbox']) for f in track]) + frameI = np.arange(frameNum[0],frameNum[-1]+1) + bboxesI = [] + for ij in range(0,4): + interpfn = interp1d(frameNum, bboxes[:,ij]) + bboxesI.append(interpfn(frameI)) + bboxesI = np.stack(bboxesI, axis=1) + if max(np.mean(bboxesI[:,2]-bboxesI[:,0]), np.mean(bboxesI[:,3]-bboxesI[:,1])) > 1: + tracks.append({'frame':frameI,'bbox':bboxesI}) + return tracks + + +def find_human_bounding_box(face_bbox, human_bboxes): + head_x1, head_y1, head_x2, head_y2 = face_bbox + head_center_x = (head_x1 + head_x2)/2 + + candidate_bboxes = [] + + for human_bbox in human_bboxes: + human_x1, human_y1, human_x2, human_y2 = human_bbox + + if (human_x1 <= head_x1 and head_x2 <= human_x2) and (human_y1 <= head_y1 and head_y2 <= human_y2): + candidate_bboxes.append(human_bbox) + + if not candidate_bboxes: + return () + + # Select the human body bounding box with the smallest distance between (x1 + x2) / 2 and (x1 + x2) / 2 of face_bbox + closest_bbox = min(candidate_bboxes, key=lambda bbox: (((bbox[0] + bbox[2]) / 2) - head_center_x)**2 + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])) + + return closest_bbox + +def update_negative_ones(values): + n = len(values) + i = 0 + + while i < n: + if values[i] == -1: + # Find the nearest number on the left + left_index = i - 1 + while left_index >= 0 and values[left_index] == -1: + left_index -= 1 + + # Find the nearest number on the right + right_index = i + 1 + while right_index < n and values[right_index] == -1: + right_index += 1 + + # Update the value of -1 + if left_index >= 0 and right_index < n: + left_value = values[left_index] + right_value = values[right_index] + values[i] = (left_value + right_value) / 2 + elif left_index >= 0: + values[i] = values[left_index] + elif right_index < n: + values[i] = values[right_index] + else: + raise ValueError("Unable to find valid values ​​on both the left and right to update -1 at index {i}") + i += 1 + + return values + + +def detect_and_mark_anomalies(data, window_size=7, std_multiplier=2): + data = np.array(data) + result = data.copy() + + for i in range(len(data)): + if data[i] > 0: + start = max(0, i - window_size) + end = min(len(data), i + window_size + 1) + neighbors = data[start:end] + + neighbors = np.delete(neighbors, np.where(neighbors == data[i])) + + positive_neighbors = neighbors[neighbors > 0] + + if len(positive_neighbors) < 2: + continue + + mean = np.mean(positive_neighbors) + std = np.std(positive_neighbors) + + if abs(data[i] - mean) > std * std_multiplier: + result[i] = -1 + + return result + + +def crop_video_with_facetrack(video_array, track, cropFile, audioFilePath,is_empty=False): + if is_empty: + return True + + dets = track['xys_bbox'] + # CPU: crop the face clips + vOut = cv2.VideoWriter(cropFile + 't.avi', cv2.VideoWriter_fourcc(*'XVID'), 25, (224,224))# Write video + + for fidx, frame in enumerate(track['frame']): + cs = 0.4 + bs = dets['s'][fidx] # Detection box size + bsi = int(bs * (1 + 2 * cs)) # Pad videos by this amount + image = video_array[frame] + frame = numpy.pad(image, ((bsi,bsi), (bsi,bsi), (0, 0)), 'constant', constant_values=(110, 110)) + my = dets['y'][fidx] + bsi # BBox center Y + mx = dets['x'][fidx] + bsi # BBox center X + face = frame[int(my-bs):int(my+bs*(1+2*cs)),int(mx-bs*(1+cs)):int(mx+bs*(1+cs))] + vOut.write(cv2.resize(face, (224, 224))) + audioTmp = cropFile + '.wav' + audioStart = (track['frame'][0]) / 25 + audioEnd = (track['frame'][-1]+1) / 25 + vOut.release() + command = ("ffmpeg -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 -threads %d -ss %.3f -to %.3f %s -loglevel panic" % \ + (audioFilePath, 10, audioStart, audioEnd, audioTmp)) + output = subprocess.call(command, shell=True, stdout=None) # Crop audio file + _, audio = wavfile.read(audioTmp) + command = ("ffmpeg -y -i %st.avi -i %s -threads %d -c:v copy -c:a copy %s.avi -loglevel panic" % \ + (cropFile, audioTmp, 10, cropFile)) # Combine audio and video file + output = subprocess.call(command, shell=True, stdout=None) + os.remove(cropFile + 't.avi') + return True + + +def evaluate_network(files, s, pycropPath): + # GPU: active speaker detection by pretrained model + allScores = [] + # durationSet = {1,2,4,6} # To make the result more reliable + durationSet = {1,1,1,2,2,2,3,3,4,5,6} # Use this line can get more reliable result + for file in tqdm.tqdm(files, total = len(files)): + fileName = os.path.splitext(file.split('/')[-1])[0] # Load audio and video + _, audio = wavfile.read(os.path.join(pycropPath, fileName + '.wav')) + if len(audio) == 0: + scores = numpy.array([-5]) + allScores.append(allScore) + continue + + audioFeature = python_speech_features.mfcc(audio, 16000, numcep = 13, winlen = 0.025, winstep = 0.010) + + video = cv2.VideoCapture(os.path.join(pycropPath, fileName + '.avi')) + videoFeature = [] + while video.isOpened(): + ret, frames = video.read() + if ret == True: + face = cv2.cvtColor(frames, cv2.COLOR_BGR2GRAY) + face = cv2.resize(face, (224,224)) + face = face[int(112-(112/2)):int(112+(112/2)), int(112-(112/2)):int(112+(112/2))] + videoFeature.append(face) + else: + break + video.release() + videoFeature = np.array(videoFeature) + length = min((audioFeature.shape[0] - audioFeature.shape[0] % 4) / 100, videoFeature.shape[0]) + audioFeature = audioFeature[:int(round(length * 100)),:] + videoFeature = videoFeature[:int(round(length * 25)),:,:] + allScore = [] # Evaluation use model + for duration in durationSet: + batchSize = int(math.ceil(length / duration)) + scores = [] + with torch.no_grad(): + for i in range(batchSize): + inputA = torch.FloatTensor(audioFeature[i * duration * 100:(i+1) * duration * 100,:]).unsqueeze(0).to(next(s.parameters()).device) + inputV = torch.FloatTensor(videoFeature[i * duration * 25: (i+1) * duration * 25,:,:]).unsqueeze(0).to(next(s.parameters()).device) + embedA = s.model.forward_audio_frontend(inputA) + embedV = s.model.forward_visual_frontend(inputV) + out = s.model.forward_audio_visual_backend(embedA, embedV) + score = s.lossAV.forward(out, labels = None) + scores.extend(score) + del inputA + del inputV + del embedA + del embedV + allScore.append(scores) + allScore = numpy.round((numpy.mean(numpy.array(allScore), axis = 0)), 1).astype(float) + allScores.append(allScore) + return allScores + + +def visualization(tracks, scores, video_array, pyaviPath): + # CPU: visulize the result for video format + + faces = [[] for i in range(video_array.shape[0])] + for tidx, track in enumerate(tracks): + score = scores[tidx] + for fidx, frame in enumerate(track['track']['frame'].tolist()): + s = score[max(fidx - 2, 0): min(fidx + 3, len(score) - 1)] # average smoothing + s = numpy.mean(s) + faces[frame].append({'track':tidx, 'score':float(s),'s':track['proc_track']['s'][fidx], 'x':track['proc_track']['x'][fidx], 'y':track['proc_track']['y'][fidx]}) + firstImage = video_array[0] + fw = firstImage.shape[1] + fh = firstImage.shape[0] + vOut = cv2.VideoWriter(os.path.join(pyaviPath, 'video_only.avi'), cv2.VideoWriter_fourcc(*'XVID'), 25, (fw,fh)) + colorDict = {0: 0, 1: 255} + for fidx in tqdm.tqdm(range(video_array.shape[0])): + image = video_array[fidx] + for face in faces[fidx]: + clr = colorDict[int((face['score'] >= 0))] + txt = round(face['score'], 1) + cv2.rectangle(image, (int(face['x']-face['s']), int(face['y']-face['s'])), (int(face['x']+face['s']), int(face['y']+face['s'])),(0,clr,255-clr),10) + cv2.putText(image,'%s'%(txt), (int(face['x']-face['s']), int(face['y']-face['s'])), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0,clr,255-clr),5) + vOut.write(image) + vOut.release() + command = ("ffmpeg -y -i %s -i %s -threads %d -c:v copy -c:a copy %s -loglevel panic" % \ + (os.path.join(pyaviPath, 'video_only.avi'), os.path.join(pyaviPath, 'audio.wav'), \ + 10, os.path.join(pyaviPath,'video_out.avi'))) + output = subprocess.call(command, shell=True, stdout=None) + +def calculate_good_matches(matches, ratio=0.75): + good_matches = [] + for m, n in matches: + if m.distance < ratio * n.distance: + good_matches.append(m) + return len(good_matches) + +def find_max_intersection_and_remaining_dicts(dicts): + if not dicts: + return [], [] + + track_frames = [d['track']['frame'] for d in dicts] + + all_elements = set() + for frame in track_frames: + all_elements.update(frame) + + max_combination_indices = [] + max_intersection = set() + + for elem in all_elements: + current_combination_indices = [] + current_intersection = set([elem]) + + for i, frame in enumerate(track_frames): + if elem in frame: + current_combination_indices.append(i) + current_intersection.intersection_update(frame) + + if len(current_combination_indices) > len(max_combination_indices): + max_combination_indices = current_combination_indices + max_intersection = current_intersection + + max_combination = [dicts[i] for i in max_combination_indices] + remaining_dicts = [d for i, d in enumerate(dicts) if i not in max_combination_indices] + + return max_combination, remaining_dicts + +def get_faces_array(frame,s,x,y): + cs = 0.4 + bs = s # Detection box size + bsi = int(bs * (1 + 2 * cs)) # Pad videos by this amount + image = frame + frame = np.pad(image, ((bsi,bsi), (bsi,bsi), (0, 0)), 'constant', constant_values=(110, 110)) + my = y + bsi # BBox center Y + mx = x + bsi # BBox center X + face = frame[int(my-bs):int(my+bs*(1+2*cs)),int(mx-bs*(1+cs)):int(mx+bs*(1+cs))] + return face + + +def order_track_distance(track1,track2,video_array): + # Get the last face frame of track1 and the first face frame of track2 + track1_end_frame = video_array[track1['track']['frame'][-1]] + track1_s = track1['proc_track']['s'][-1] + track1_x = track1['proc_track']['x'][-1] + track1_y = track1['proc_track']['y'][-1] + track1_end_face_array = get_faces_array(track1_end_frame,track1_s,track1_x,track1_y) + + track2_start_frame = video_array[track2['track']['frame'][0]] + track2_s = track2['proc_track']['s'][0] + track2_x = track2['proc_track']['x'][0] + track2_y = track2['proc_track']['y'][0] + track2_strat_face_array = get_faces_array(track2_start_frame,track2_s,track2_x,track2_y) + + # Calculate the area overlap ratio + track1_bbox = track1['track']['bbox'][-1] + track2_bbox = track2['track']['bbox'][0] + iou = bb_intersection_over_union(track1_bbox, track2_bbox) + if iou <= 0.2: + distance_iou = 10000 + else: + distance_iou = math.exp(-5*iou) + + normalized_distance = 0 + + # face_id distance (with facenet) + result = DeepFace.verify(track1_end_face_array, track2_strat_face_array, model_name='Facenet', detector_backend = 'skip') + facenet_distance = result['distance'] + if facenet_distance > 0.85: + facenet_distance = facenet_distance + 10000 + + distance = 2*distance_iou + normalized_distance + facenet_distance + + return distance + +def update_remain(remaining_dicts, pop_item): + updated_dicts = [item for item in remaining_dicts if item['track']['bbox'].shape != pop_item['track']['bbox'].shape or (item['track']['bbox'] != pop_item['track']['bbox']).any()] + return updated_dicts + +def order_merge_tracks(track1,track2): + new_track = {} + new_track['proc_track'] = {} + new_track['proc_track']['x'] = track1['proc_track']['x'] + track2['proc_track']['x'] + new_track['proc_track']['y'] = track1['proc_track']['y'] + track2['proc_track']['y'] + new_track['proc_track']['s'] = track1['proc_track']['s'] + track2['proc_track']['s'] + new_track['human_bbox'] = {} + new_track['human_bbox']['x1'] = track1['human_bbox']['x1'] + track2['human_bbox']['x1'] + new_track['human_bbox']['y1'] = track1['human_bbox']['y1'] + track2['human_bbox']['y1'] + new_track['human_bbox']['x2'] = track1['human_bbox']['x2'] + track2['human_bbox']['x2'] + new_track['human_bbox']['y2'] = track1['human_bbox']['y2'] + track2['human_bbox']['y2'] + + new_track['track'] = {} + for key in list(track1['track'].keys()): + object1 = track1['track'][key] + object2 = track2['track'][key] + if isinstance(object1, np.ndarray): + new_track['track'][key] = np.concatenate((object1, object2)) + elif isinstance(object1, list): + new_track['track'][key] = object1 + object2 + else: + raise('new data type') + + return new_track + +def post_merge(vidTracks,video_array): + # Find the maximum overlapping tracks as the initial anchor + anchor_combination, remaining_dicts = find_max_intersection_and_remaining_dicts(vidTracks) + end_frame = video_array.shape[0] + continue_flag = np.ones((len(anchor_combination),2)) + max_iteration = 10 + iteration_count = 0 + while iteration_count0: + for track_ind in range(len(anchor_combination)): + track = anchor_combination[track_ind] + # Try to extend forward + if continue_flag[track_ind][0]: + if track['track']['frame'][0] == 0: + continue_flag[track_ind][0] = 0 + else: + # Find the candidate that is connected to it and is in the front row + possible_prior_tracks = [] + for checktrack in remaining_dicts: + if checktrack['track']['frame'][-1]+1 == track['track']['frame'][0] or checktrack['track']['frame'][-1]+2 == track['track']['frame'][0]: + possible_prior_tracks.append(checktrack) + # If it is not zero, then check the calculated distance + if len(possible_prior_tracks)>0: + distance_score_list = [] + for possible_prior_track in possible_prior_tracks: + distance_score_list.append(order_track_distance(possible_prior_track, track, video_array)) + distance_score_array = np.array(distance_score_list) + if min(distance_score_array) < 10000: + min_index = np.argmin(distance_score_array) + new_anchor = order_merge_tracks(possible_prior_tracks[min_index], track) + # update_anchor() + anchor_combination[track_ind] = new_anchor + track = new_anchor + remaining_dicts = update_remain(remaining_dicts, possible_prior_tracks[min_index]) + else: + continue_flag[track_ind][0] = 0 + else: + continue_flag[track_ind][0] = 0 + # Try to extend backwards + if continue_flag[track_ind][1]: + if track['track']['frame'][-1] == end_frame: + continue_flag[track_ind][0] = 0 + else: + # Find the candidate that is connected to it and in front of it + possible_after_tracks = [] + for checktrack in remaining_dicts: + if checktrack['track']['frame'][0]-1 == track['track']['frame'][-1] or checktrack['track']['frame'][0]-2 == track['track']['frame'][-1]: + possible_after_tracks.append(checktrack) + # If it is not zero, then check the calculated distance + if len(possible_after_tracks)>0: + distance_score_list = [] + for possible_after_track in possible_after_tracks: + distance_score_list.append(order_track_distance(track, possible_after_track, video_array)) + distance_score_array = np.array(distance_score_list) + if min(distance_score_array) < 10000: + min_index = np.argmin(distance_score_array) + new_anchor = order_merge_tracks(track, possible_after_tracks[min_index]) + # update_anchor() + anchor_combination[track_ind] = new_anchor + remaining_dicts = update_remain(remaining_dicts, possible_after_tracks[min_index]) + else: + continue_flag[track_ind][1] = 0 + else: + continue_flag[track_ind][1] = 0 + + final_tracks = anchor_combination + remaining_dicts + if len(final_tracks) > 5: + sorted_tracks = sorted(final_tracks, key=lambda x: len(x['track']['frame']), reverse=True) + top_tracks = sorted_tracks[:5] + else: + top_tracks = final_tracks + # return len(anchor_combination), top_5_tracks + returntracks = [] + for item in top_tracks: + if len(item['track']['frame'])>15: + returntracks.append(item) + return len(anchor_combination), returntracks + + +def longest_continuous_actives(arr): + max_length = 0 + current_length = 0 + + for num in arr: + if num > 0: + current_length += 1 + if current_length > max_length: + max_length = current_length + else: + current_length = 0 + + return max_length + +import pickle +import moviepy as mp + +def annotate_video_with_bounding_boxes_with_audio(video_path, q_human_video_track_bbox, output_path): + bbox_path = q_human_video_track_bbox['bbox_path'] + frame_indices = q_human_video_track_bbox['track']['frame'] + video_array = get_video_array_cv2(video_path) + + with open(bbox_path, 'rb') as f: + bbox_data = pickle.load(f) + xy_bbox = bbox_data['xy_bbox'] + + # Get video dimensions and frame rate + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) # Get original video frame rate + num_frames, height, width, channels = video_array.shape + assert channels == 3, "Input video must have 3 channels (BGR)." + + # Initialize video writer + fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for mp4 + temp_video_path = output_path.split('.')[0] + 'temp.mp4' + out = cv2.VideoWriter(temp_video_path, fourcc, fps, (width, height)) # Use original FPS + + # Annotate video frames with bounding boxes + for i in range(num_frames): + frame = video_array[i] + if i in frame_indices: + idx = frame_indices.index(i) + x1, y1, x2, y2 = xy_bbox[idx] + # Draw bounding box + thickness = max(int((x2 - x1) / 40), 2) + cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), thickness) + + # Write frame to temporary video + out.write(frame) + + out.release() + cap.release() # Release the video capture object + + # Load original video and audio + original_video = mp.VideoFileClip(video_path) + annotated_video = mp.VideoFileClip(temp_video_path) + + # Combine annotated video with original audio, ensuring alignment + final_video = annotated_video.set_audio(original_video.audio) + + # Write the final output video with audio + final_video.write_videofile(output_path, codec='libx264', audio_codec='aac', fps=fps) + + # Clean up temporary video file + annotated_video.close() + original_video.close() + + # Optionally, remove the temporary video file + import os + if os.path.exists(temp_video_path): + os.remove(temp_video_path) + + return output_path + +def annotate_video_with_bounding_boxes_withText_with_audio(video_path, q_human_video_track_bbox, output_path, numbers): + bbox_path = q_human_video_track_bbox['bbox_path'] + frame_indices = q_human_video_track_bbox['track']['frame'] + video_array = get_video_array_cv2(video_path) + + with open(bbox_path, 'rb') as f: + bbox_data = pickle.load(f) + xy_bbox = bbox_data['xy_bbox'] + + # Get video dimensions and frame rate + cap = cv2.VideoCapture(video_path) + fps = cap.get(cv2.CAP_PROP_FPS) # Get original video frame rate + num_frames, height, width, channels = video_array.shape + assert channels == 3, "Input video must have 3 channels (BGR)." + + # Initialize video writer + fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for mp4 + temp_video_path = output_path.split('.')[0] + 'temp.mp4' + out = cv2.VideoWriter(temp_video_path, fourcc, fps, (width, height)) # Use original FPS + + # Annotate video frames with bounding boxes + for i in range(num_frames): + frame = video_array[i] + if i in frame_indices: + idx = frame_indices.index(i) + x1, y1, x2, y2 = xy_bbox[idx] + # Draw bounding box + thickness = max(int((x2 - x1) / 40), 2) + cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), thickness) + # Put the number in the top-left corner of the bounding box + cv2.putText(frame, numbers, (int(x1) + 10, int(y1) + 35), cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 0, 255), 3) + + # Write frame to temporary video + out.write(frame) + + out.release() + cap.release() # Release the video capture object + + # Load original video and audio + original_video = mp.VideoFileClip(video_path) + annotated_video = mp.VideoFileClip(temp_video_path) + + # Combine annotated video with original audio, ensuring alignment + final_video = annotated_video.set_audio(original_video.audio) + + # Write the final output video with audio + final_video.write_videofile(output_path, codec='libx264', audio_codec='aac', fps=fps) + + # Clean up temporary video file + annotated_video.close() + original_video.close() + + # Optionally, remove the temporary video file + import os + if os.path.exists(temp_video_path): + os.remove(temp_video_path) + + return output_path + + +def annotate_video_with_bounding_boxes(video_array, frame_indices, bounding_boxes, output_path): + """ + Annotates specified frames in the video with bounding boxes and saves the result to a new video file. + + :param video_array: Input video as a numpy array with shape (num_frames, height, width, channels). + :param frame_indices: List of frame indices to annotate. + :param bounding_boxes: Array of bounding box coordinates with shape (num_frames_to_annotate, 4), where each bounding box is (x, y, w, h). + :param output_path: Path to save the output video. + """ + # Get video dimensions + num_frames, height, width, channels = video_array.shape + assert channels == 3, "Input video must have 3 channels (BGR)." + + # Initialize video writer + fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for mp4 + out = cv2.VideoWriter(output_path, fourcc, 30.0, (width, height)) + + # option 1: keep all video + for i in range(num_frames): + frame = video_array[i] + if i in frame_indices: + idx = frame_indices.index(i) + x1, y1, x2, y2 = bounding_boxes[idx] + # Draw bounding box + thinkness = max(int((x2-x1)/40),2) + cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), thinkness) + + # Write frame to output video + out.write(frame) + + # option 2:crap + # for in_id, out_id in enumerate(frame_indices): + # frame = video_array[out_id] + # x1, y1, x2, y2 = bounding_boxes[in_id] + # # Draw bounding box + # cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 5) + # # Write frame to output video + # out.write(frame) + + out.release() + return output_path + + +def crop_from_array(frame_before_crop, coords): + x1, y1, x2, y2 = coords + cropped_frame = frame_before_crop[y1:y2, x1:x2] + return cropped_frame \ No newline at end of file diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 73fd3c93e3..308569ce55 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -39,6 +39,19 @@ class BatchMetaKeys(object): class MetaKeys(object): # === text related tags === + + # # humanvbench related tags (CVPR'26) + active_speaker_flag = 'active_speaker_flag' + audio_speech_attribute = 'audio_speech_attribute' + speech_ASR = 'speech_ASR' + speech_emotion = 'speech_emotion' + video_facetrack_attribute_demographic = 'video_facetrack_attribute_demographic' + video_facetrack_attribute_emotion = 'video_facetrack_attribute_emotion' + track_video_caption = 'track_video_caption' + video_track_is_child = 'video_track_is_child' + human_track_data_path = 'human_track_data_path' + number_people_in_video = 'number_people_in_video' + # # sentiment dialog_sentiment_intensity = "dialog_sentiment_intensity" dialog_sentiment_intensity_analysis = "dialog_sentiment_intensity_analysis" @@ -314,6 +327,9 @@ class StatsKeysConstant(object): # general-field-filter general_field_filter_condition = "general_field_filter_condition" + # video-face-ratio + video_face_exist = 'video_face_exist' + class StatsKeys(object, metaclass=StatsKeysMeta): _constants_class = StatsKeysConstant diff --git a/data_juicer/utils/file_utils.py b/data_juicer/utils/file_utils.py index 625eb0f631..689fa0ac61 100644 --- a/data_juicer/utils/file_utils.py +++ b/data_juicer/utils/file_utils.py @@ -11,7 +11,6 @@ import aiohttp import pandas as pd -from datasets.utils.extract import GzipExtractor from datasets.utils.extract import ZstdExtractor as Extractor from data_juicer.utils.common_utils import dict_to_hash @@ -113,12 +112,6 @@ def find_files_with_suffix( # just like '.jsonl.zst' file_suffixes = [suffix.lower() for suffix in file.suffixes] suffix = "".join(file_suffixes[-2:]) - elif GzipExtractor.is_extractable(file): - # support gzip-format file - # and use the last 2 sub-suffixes as the final suffix - # just like '.jsonl.gz' - file_suffixes = [suffix.lower() for suffix in file.suffixes] - suffix = "".join(file_suffixes[-2:]) if not suffixes or (suffix in suffixes): if suffix not in file_dict: diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 40336da076..684f4dc11a 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -1643,6 +1643,252 @@ def _download_model(local_dir): return estimator +def prepare_SenseVoiceSmall_model(pretrained_model_name_or_path, **model_params): + """ + Prepare and load light sharegpt4video. + + :param model_name: input model name. + """ + base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../thirdparty/humanvbench_models")) + diff_file = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../thirdparty/humanvbench_models/SenseVoice_changes.diff")) + repo_url = "https://github.com/FunAudioLLM/SenseVoice.git" + + final_dir = os.path.join(base_dir, 'SenseVoice') + + if not os.path.exists(final_dir): + print(f"Starting direct clone from {repo_url}...") + try: + os.makedirs(base_dir, exist_ok=True) + subprocess.run(["git", "clone", "--depth", "1", repo_url, final_dir], check=True) + + if os.path.exists(diff_file): + print(f"Applying patch: {diff_file}") + subprocess.run(["git", "-C", final_dir, "apply", diff_file], check=True) + else: + print(f"Warning: Patch file not found at {diff_file}") + + except (subprocess.CalledProcessError, Exception) as e: + print(f"Operation failed: {e}") + if os.path.exists(final_dir): + shutil.rmtree(final_dir) + return + else: + print(f"Directory {final_dir} already exists.") + + from thirdparty.humanvbench_models.SenseVoice.model import SenseVoiceSmall + + logger.info('Loading ASR_model model...') + ASR_Emo_model, kwargs1 = SenseVoiceSmall.from_pretrained(model=pretrained_model_name_or_path) + + ASR_Emo_model.eval() + return ASR_Emo_model, kwargs1 + +def prepare_light_asd_model( + pretrained_model_name_or_path='weight/finetuning_TalkSet.model', **model_params): + """ + Prepare and load light asd model. + + :param model_name: input model name. + """ + base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../thirdparty/humanvbench_models")) + pretrained_model_name_or_path = os.path.join(base_dir, 'Light-ASD/weight/finetuning_TalkSet.model') + + logger.info('Loading light_asd model...') + from ASD import ASD + model = ASD() + model.loadParameters(pretrained_model_name_or_path) + model.eval() + return model + +import subprocess +import shutil + +def prepare_YOLOv8_human_model( + pretrained_model_name_or_path='./thirdparty/humanvbench_models/YOLOv8_human/weights/best.pt', **model_params): + """ + Prepare and load light YOLOv8_human. + + :param model_name: input model name. + """ + base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../thirdparty/humanvbench_models")) + diff_file = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../thirdparty/humanvbench_models/YOLOv8_human_changes.diff")) + repo_url = "https://github.com/jahongir7174/YOLOv8-human.git" + + old_name = "YOLOv8-human" + new_name = "YOLOv8_human" + + target_dir = os.path.join(base_dir, old_name) + final_dir = os.path.join(base_dir, new_name) + + if not os.path.exists(final_dir): + print(f"Starting direct clone from {repo_url}...") + try: + os.makedirs(base_dir, exist_ok=True) + + subprocess.run(["git", "clone", "--depth", "1", repo_url, target_dir], check=True) + + print(f"Renaming {old_name} to {new_name}...") + os.rename(target_dir, final_dir) + + if os.path.exists(diff_file): + print(f"Applying patch: {diff_file}") + subprocess.run(["git", "-C", final_dir, "apply", diff_file], check=True) + print("Setup completed successfully.") + else: + print(f"Warning: Patch file not found at {diff_file}") + + except (subprocess.CalledProcessError, Exception) as e: + print(f"Operation failed: {e}") + for d in [target_dir, final_dir]: + if os.path.exists(d): + shutil.rmtree(d) + print("Cleanup finished.") + else: + print(f"Directory {final_dir} already exists. Skipping setup.") + + + logger.info('Loading YOLOv8_human model...') + pretrained_model_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../thirdparty/humanvbench_models/YOLOv8_human/weights/best.pt")) + human_detection_model = torch.load(pretrained_model_path, weights_only=False)['model'].float() + human_detection_model.half() + human_detection_model.eval() + return human_detection_model + +# import sys +# sys.path.append("../thirdparty/humanvbench_models/Light-ASD") +def prepare_face_detect_S3FD_model(model_path=None, **model_params): + """ + Prepare and load light asd model. + + :param model_name: input model name. + """ + base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../thirdparty/humanvbench_models")) + diff_file = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../thirdparty/humanvbench_models/Light-ASD_changes.diff")) + repo_url = "https://github.com/Junhua-Liao/Light-ASD.git" + + final_dir = os.path.join(base_dir, 'Light-ASD') + model_save_dir = os.path.join(final_dir, "model/faceDetector/s3fd") + + if not os.path.exists(final_dir): + print(f"Starting direct clone from {repo_url}...") + try: + os.makedirs(base_dir, exist_ok=True) + subprocess.run(["git", "clone", "--depth", "1", repo_url, final_dir], check=True) + + if os.path.exists(diff_file): + print(f"Applying patch: {diff_file}") + subprocess.run(["git", "-C", final_dir, "apply", diff_file], check=True) + else: + print(f"Warning: Patch file not found at {diff_file}") + + except (subprocess.CalledProcessError, Exception) as e: + print(f"Operation failed: {e}") + if os.path.exists(final_dir): + shutil.rmtree(final_dir) + return + else: + print(f"Directory {final_dir} already exists.") + + def download_sfd_model(target_dir): + model_url = "https://huggingface.co/lithiumice/syncnet/resolve/main/sfd_face.pth" + mirror_model_url = "https://hf-mirror.com/lithiumice/syncnet/resolve/main/sfd_face.pth" + + target_path = os.path.join(target_dir, "sfd_face.pth") + + if os.path.exists(target_path): + print(f"Model file {target_path} already exists. Skipping download.") + return + + print(f"Downloading sfd_face.pth to {target_dir}...") + os.makedirs(target_dir, exist_ok=True) + + try: + subprocess.run(["wget", "-c", model_url, "-O", target_path], check=True) + print("Model download completed successfully.") + except subprocess.CalledProcessError as e: + try: + subprocess.run(["wget", "-c", mirror_model_url, "-O", target_path], check=True) + print("Model download completed successfully.") + except subprocess.CalledProcessError as e: + print(f"Failed to download model: {e}") + if os.path.exists(target_path): + os.remove(target_path) + + download_sfd_model(model_save_dir) + print("Setup and model preparation completed.") + + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../thirdparty/humanvbench_models/Light-ASD"))) + + logger.info('Loading face_detect_S3FD_model model...') + from model.faceDetector.s3fd import S3FD + model = S3FD() + return model + + +import torch +import torch.nn as nn +from transformers import Wav2Vec2Processor +from transformers.models.wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Model, + Wav2Vec2PreTrainedModel, +) +def prepare_wav2vec2_age_gender_model(pretrained_model_name_or_path = 'audeering/wav2vec2-large-robust-24-ft-age-gender', **model_params): + + class ModelHead(nn.Module): + r"""Classification head.""" + + def __init__(self, config, num_labels): + + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.final_dropout) + self.out_proj = nn.Linear(config.hidden_size, num_labels) + + def forward(self, features, **kwargs): + + x = features + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + + return x + + + class AgeGenderModel(Wav2Vec2PreTrainedModel): + r"""Speech emotion classifier.""" + + def __init__(self, config): + + super().__init__(config) + + self.config = config + self.wav2vec2 = Wav2Vec2Model(config) + self.age = ModelHead(config, 1) + self.gender = ModelHead(config, 3) + self.init_weights() + + def forward( + self, + input_values, + ): + + outputs = self.wav2vec2(input_values) + hidden_states = outputs[0] + hidden_states = torch.mean(hidden_states, dim=1) + logits_age = self.age(hidden_states) + logits_gender = torch.softmax(self.gender(hidden_states), dim=1) + + return hidden_states, logits_age, logits_gender + + processor = Wav2Vec2Processor.from_pretrained(pretrained_model_name_or_path) + model = AgeGenderModel.from_pretrained(pretrained_model_name_or_path) + return model, processor + + + MODEL_FUNCTION_MAPPING = { "api": prepare_api_model, "deepcalib": prepare_deepcalib_model, @@ -1671,6 +1917,11 @@ def _download_model(local_dir): "embedding": prepare_embedding_model, "sam_3d_body": prepare_sam_3d_body_model, "mmlab": prepare_mmlab_model, + 'Light_ASD': prepare_light_asd_model, + 'SenseVoiceSmall': prepare_SenseVoiceSmall_model, + 'YOLOv8_human': prepare_YOLOv8_human_model, + 'face_detect_S3FD': prepare_face_detect_S3FD_model, + 'wav2vec2_age_gender': prepare_wav2vec2_age_gender_model } _MODELS_WITHOUT_FILE_LOCK = {"fasttext", "fastsam", "kenlm", "nltk", "recognizeAnything", "sentencepiece", "spacy"} diff --git a/demos/data/demo-dataset-videos2.jsonl b/demos/data/demo-dataset-videos2.jsonl new file mode 100644 index 0000000000..df0ea65b8a --- /dev/null +++ b/demos/data/demo-dataset-videos2.jsonl @@ -0,0 +1,6 @@ +{"videos":["../../tests/ops/data/video10.mp4"], "text": ""} +{"videos":["../../tests/ops/data/video11.mp4"], "text": ""} +{"videos":["../../tests/ops/data/video12.mp4"], "text": ""} +{"videos":["../../tests/ops/data/video13.mp4"], "text": ""} +{"videos":["../../tests/ops/data/video14.mp4"], "text": ""} +{"videos":["../../tests/ops/data/video15.mp4"], "text": ""} diff --git a/demos/video_humanvbench_simple/analyzer.yaml b/demos/video_humanvbench_simple/analyzer.yaml new file mode 100644 index 0000000000..bf265a35f6 --- /dev/null +++ b/demos/video_humanvbench_simple/analyzer.yaml @@ -0,0 +1,72 @@ +# Process config example for dataset + +# global parameters +project_name: 'demo-analyzer' +dataset_path: './demos/data/demo-dataset-videos2.jsonl' # path to your dataset directory or file +np: 4 # number of subprocess to process your dataset + +export_path: './outputs/demo-analyzer/demo-analyzer-result.jsonl' + # number of subprocess to process your dataset + # Note: currently, we support specify only ONE key for each op, for cases requiring multiple keys, users can specify the op multiple times. We will only use the first key of `text_keys` when you set multiple keys. +open_tracer: false # whether to open the tracer to trace the changes during process. It might take more time when opening tracer + +# for multimodal data processing +video_key: 'videos' # key name of field to store the list of sample video paths. +video_special_token: '<__dj__video>' # the special token that represents a video in the text. In default, it's "<__dj__video>". You can specify your own special token according to your input dataset. +eoc_special_token: '<|__dj__eoc|>' # the special token that represents the end of a chunk in the text. In default, it's "<|__dj__eoc|>". You can specify your own special token according to your input dataset. + +keep_stats_in_res_ds: true + +# process schedule +# a list of several process operators with their arguments +process: + - video_face_ratio_filter: # Filter to retain human-centric videos + threshold: 0.2 # The lower limit of the ratio of frames with faces to the total number of video frames + detect_interval: 4 + any_or_all: any + + - video_human_tracks_extraction_mapper: # Get the body and face trajectory bounding box of people in one shot of the video. To ensure correctness, it should be applied after video_split_by_scene_mapper + original_data_save_path: your_path/bounding_box_track # The location where the specific results of each frame's detection are stored + detect_interval: 5 + + - video_human_tracks_face_demographic_mapper: # Get the facial demographics of each person based on the results of video_human_tracks_extraction_mapper + original_data_save_path: your_path/bounding_box_track # The location where the specific results of each frame's detection are stored + detect_interval: 5 + + + - video_tagging_from_audio_mapper: # Mapper to generate video tags from audio streams extracted from the video. + hf_ast: 'MIT/ast-finetuned-audioset-10-10-0.4593' # Huggingface model name for the audio classification model. + tag_field_name: 'video_audio_tags' # the field name to store the tags. It's "video_audio_tags" in default. + memory: '500MB' + + - video_audio_detect_age_gender_mapper: # If the audio is speech, classify the gender and age of the speech + hf_audio_mapper: 'audeering/wav2vec2-large-robust-24-ft-age-gender' # Huggingface model name for speech age and gender classification + mem_required: '7GB' + + - video_captioning_from_human_tracks_mapper: # Based on the results of video_human_tracks_extraction_mapper, focus on the single person in the video for captioning + video_describe_model_path: DAMO-NLP-SG/VideoLLaMA3-7B # model path to VideoLLaMA3-7B + trust_remote_code: true + temp_video_path: your_path # Used to store temporary videos that will be removed finally. + mem_required: '25GB' + + - video_captioning_face_attribute_emotion_mapper: # Based on the results of video_human_tracks_extraction_mapper, focus on judging the gender, age, and race of a single person in the video + face_track_query: Please only describe the appearance and facial emotions of the person in the video in detail. Don't mention the background. Less than 80 words. + trust_remote_code: true + cropping_face_video_temp_path: your_path # Used to store temporary videos + video_describe_model_path: DAMO-NLP-SG/VideoLLaMA3-7B # Huggingface model DAMO-NLP-SG/VideoLLaMA3-7B + mem_required: '25GB' + + - video_active_speaker_detect_mapper: # Based on the results of video_human_tracks_extraction_mapper, determine whether each person is an active speaker + temp_save_path: your_path # Used to store temporary videos + active_threshold: 15 # Higher values are stricter, reducing false positives from noise but potentially increasing missed detections + mem_required: '10GB' + + + - video_audio_ASR_mapper: # Automatic speech recognition from video speech + model_dir_ASR: 'FunAudioLLM/SenseVoiceSmall' # Huggingface model FunAudioLLM/SenseVoiceSmall + mem_required: '20GB' + + - video_audio_speech_emotion_mapper: # Speech emotion recognition from video speech + model_dir_emo: 'FunAudioLLM/SenseVoiceSmall' # Huggingface model FunAudioLLM/SenseVoiceSmall + mem_required: '20GB' + diff --git a/docs/Cache.md b/docs/Cache.md deleted file mode 100644 index a6d758a336..0000000000 --- a/docs/Cache.md +++ /dev/null @@ -1,276 +0,0 @@ -# Cache Management - -This document describes DataJuicer's cache management system, including HuggingFace dataset caching, cache directory configuration, cache compression, and temporary storage. - -## Overview - -DataJuicer provides a caching mechanism based on HuggingFace Datasets to avoid redundant computation. When enabled, each operator generates cache files with a unique **fingerprint** based on: -- The fingerprint of the input data -- The operator name and parameters -- The hash of the processing function - -That is: same input + same operator configuration = same fingerprint = cache hit. Therefore, re-running the same pipeline on the same data will skip already-computed steps. For more details, please refer to our paper: [Data-Juicer: A One-Stop Data Processing System for Large Language Models](https://arxiv.org/abs/2309.02033). - -The cache system also provides: -- **Configurable cache directories** via environment variables or config options -- **Cache compression** to reduce disk usage for large-scale datasets -- **Temporary storage** for intermediate files in non-cache mode -- **Fine-grained cache control** via context managers and decorators - -## Configuration - -### Basic Cache Settings - -```yaml -use_cache: true # Enable/disable HuggingFace dataset caching -ds_cache_dir: null # Custom cache directory (overrides HF_DATASETS_CACHE) -cache_compress: null # Compression method: 'gzip', 'zstd', 'lz4', or null -temp_dir: null # Temp directory for intermediate files when cache is disabled -``` - -### Command Line - -```bash -# Enable caching (default) -dj-process --config config.yaml --use_cache true - -# Disable caching -dj-process --config config.yaml --use_cache false - -# Enable cache compression -dj-process --config config.yaml --cache_compress zstd - -# Custom cache directory -dj-process --config config.yaml --ds_cache_dir /fast-storage/dj-cache -``` - -## Cache Directory Structure - -DataJuicer organizes cache files in a hierarchical directory structure controlled by environment variables: - -``` -~/.cache/ # CACHE_HOME (default) -└── data_juicer/ # DATA_JUICER_CACHE_HOME - ├── assets/ # DATA_JUICER_ASSETS_CACHE - │ └── (extracted frames, stopwords, flagged words, etc.) - └── models/ # DATA_JUICER_MODELS_CACHE - └── (downloaded model files) -``` - -### Environment Variables - -| Variable | Default | Description | -|----------|---------|-------------| -| `CACHE_HOME` | `~/.cache` | Root cache directory | -| `DATA_JUICER_CACHE_HOME` | `$CACHE_HOME/data_juicer` | DataJuicer cache root | -| `DATA_JUICER_ASSETS_CACHE` | `$DATA_JUICER_CACHE_HOME/assets` | Assets cache (frames, word lists, etc.) | -| `DATA_JUICER_MODELS_CACHE` | `$DATA_JUICER_CACHE_HOME/models` | Downloaded models cache | -| `DATA_JUICER_EXTERNAL_MODELS_HOME` | `None` | External models directory | - -Override defaults by setting environment variables: - -```bash -export DATA_JUICER_CACHE_HOME=/data/dj-cache -export DATA_JUICER_MODELS_CACHE=/models/dj-models -dj-process --config config.yaml -``` - -## Cache Compression - -For large-scale datasets (tens of GB or more), cache files can consume significant disk space. Cache compression reduces storage requirements by compressing intermediate cache files after each operator completes. - -### Supported Algorithms - -| Algorithm | Library | Speed | Compression Ratio | Recommended For | -|-----------|---------|-------|-------------------|-----------------| -| `zstd` | zstandard | Fast | High | General use (default) | -| `lz4` | lz4 | Fastest | Moderate | Speed-critical workloads | -| `gzip` | gzip | Slow | High | Compatibility needs | - -### Configuration - -```yaml -use_cache: true -cache_compress: zstd # Enable zstd compression -``` - -```bash -dj-process --config config.yaml --cache_compress zstd -``` - -### Multi-Process Compression - -Cache compression supports parallel processing. The number of compression worker processes is controlled by the `np` parameter: - -```yaml -np: 4 # Number of parallel workers (also used for compression) -cache_compress: zstd -``` - -## Cache Control API - -### DatasetCacheControl - -A context manager to temporarily enable or disable HuggingFace dataset caching within a specific scope: - -```python -from data_juicer.utils.cache_utils import DatasetCacheControl - -# Temporarily disable caching -with DatasetCacheControl(on=False): - # Operations here will not use cache - result = dataset.map(my_function) - -# Temporarily enable caching -with DatasetCacheControl(on=True): - # Operations here will use cache - result = dataset.map(my_function) -``` - -### dataset_cache_control Decorator - -A decorator for functions that need to control cache state: - -```python -from data_juicer.utils.cache_utils import dataset_cache_control - -@dataset_cache_control(on=False) -def process_without_cache(dataset): - return dataset.map(my_function) -``` - -### CompressionOff - -A context manager to temporarily disable cache compression: - -```python -from data_juicer.utils.compress import CompressionOff - -with CompressionOff(): - # Cache compression is disabled in this scope - result = dataset.map(my_function) -``` - -### CompressManager - -Low-level API for manual compression/decompression: - -```python -from data_juicer.utils.compress import CompressManager - -manager = CompressManager(compressor_format="zstd") - -# Compress a file -manager.compress("input.arrow", "input.arrow.zstd") - -# Decompress a file -manager.decompress("input.arrow.zstd", "input.arrow") -``` - -### CacheCompressManager - -High-level API for managing HuggingFace dataset cache compression: - -```python -from data_juicer.utils.compress import CacheCompressManager - -manager = CacheCompressManager(compressor_format="zstd") - -# Compress previous dataset's cache files -manager.compress(prev_ds=previous_dataset, this_ds=current_dataset, num_proc=4) - -# Decompress cache files for a dataset -manager.decompress(ds=dataset, num_proc=4) - -# Clean up all compressed cache files -manager.cleanup_cache_files(ds=dataset) -``` - -## Cache vs Checkpoint - -Cache and checkpoint are mutually exclusive — enabling checkpoint automatically disables cache: - -| Feature | Cache | Checkpoint | -|---------|-------|------------| -| **Purpose** | Accelerate repeated runs with same configuration | Fault recovery and resumption | -| **Granularity** | Per-operator result | Full dataset snapshot | -| **Storage Location** | HuggingFace cache directory | Work directory | -| **Recovery Method** | Automatic (hash-based) | Manual (config-based) | -| **Compression** | Supported (`cache_compress`) | Not applicable | -| **Scenario** | Iterative development, parameter tuning | Long-running production tasks | - -```yaml -# Cache mode (default) -use_cache: true -use_checkpoint: false - -# Checkpoint mode (cache auto-disabled) -use_cache: true # Will be overridden to false -use_checkpoint: true -``` - -## Disabling Cache and Temporary Directory - -When `use_cache: false` or checkpoint mode is enabled (`use_checkpoint: true`), HuggingFace dataset caching is fully disabled. In this mode, DataJuicer writes intermediate files produced during operator processing to a temporary directory, and cleans them up automatically when processing completes. The `temp_dir` parameter controls where these intermediate files are stored. - -### Behavior - -- **Defaults to `null`**: When not set, the operating system determines the temporary directory location (typically `/tmp`), equivalent to Python's `tempfile.gettempdir()`. -- **Takes effect automatically when cache is disabled**: Once caching is disabled, `temp_dir` is applied as the global temporary directory for the entire process via Python's `tempfile.tempdir`, affecting all temporary files created through `tempfile` in the process. -- **Cache compression is automatically disabled**: When caching is disabled, `cache_compress` is automatically ignored and reset to `null`. - -### Configuration - -```yaml -use_cache: false -temp_dir: /data/dj-temp # Custom temp directory; null means system default -``` - -```bash -dj-process --config config.yaml --use_cache false --temp_dir /data/dj-temp -``` - -### Safety Notes - -> **Set `temp_dir` with caution — an unsafe path can cause unexpected program behavior.** - -- **Do not point to critical system directories** (e.g., `/`, `/usr`, `/etc`). Automatic cleanup of temporary files may accidentally delete important files. -- **Do not point to directories containing important data**. Temporary file writes and cleanup operations may conflict with existing files. -- **Ensure sufficient disk space**: When cache is disabled, intermediate files are written and deleted dynamically during processing. Peak disk usage is approximately equal to the output size of a single operator. -- **The directory is created automatically if it does not exist**: DataJuicer calls `os.makedirs` to create the specified path if it is missing. -- **`temp_dir` affects the entire process's `tempfile` behavior**: Because it sets the global `tempfile.tempdir` variable, this setting influences all components in the process that rely on `tempfile`, including third-party libraries. - -## Performance Considerations - -### When to Enable Cache - -- **Enable**: For iterative development where you frequently re-run pipelines with minor changes -- **Enable**: When operators are computationally expensive and you want to skip already-computed steps -- **Disable**: For one-shot processing to avoid disk overhead - -### When to Enable Compression - -- **Enable**: When dataset size exceeds tens of GB and disk space is limited -- **Enable** `zstd`: For the best balance of speed and compression ratio -- **Enable** `lz4`: When compression speed is critical -- **Disable**: When disk space is abundant and you want maximum processing speed - -## Troubleshooting - -**Cache files consuming too much disk space:** -```bash -# Check cache directory size -du -sh ~/.cache/data_juicer/ - -# Enable compression -dj-process --config config.yaml --cache_compress zstd -``` - -**Stale cache causing unexpected results:** -```bash -# Clear HuggingFace dataset cache -rm -rf ~/.cache/huggingface/datasets/ - -# Or specify a fresh cache directory -dj-process --config config.yaml --ds_cache_dir /tmp/fresh-cache -``` diff --git a/docs/Cache_ZH.md b/docs/Cache_ZH.md deleted file mode 100644 index ebb114513a..0000000000 --- a/docs/Cache_ZH.md +++ /dev/null @@ -1,275 +0,0 @@ -# 缓存管理 - -本文档描述 DataJuicer 的缓存管理系统,包括 HuggingFace 数据集缓存、缓存目录配置、缓存压缩和临时存储。 - -## 概述 - -DataJuicer 提供了基于 HuggingFace Datasets 的缓存机制来避免重复计算。启用后,每个算子根据以下内容生成具有唯一 **fingerprint(指纹)** 的缓存文件: -- 输入数据的指纹 -- 算子名称和参数 -- 处理函数的哈希值 - -即:相同输入 + 相同算子配置 = 相同指纹 = 缓存命中。因此在相同数据上重新运行相同的管道时会跳过已计算的步骤。更多细节请参考我们的论文:[Data-Juicer: A One-Stop Data Processing System for Large Language Models](https://arxiv.org/abs/2309.02033) 。 - -缓存系统还提供: -- **可配置的缓存目录**,通过环境变量或配置选项设置 -- **缓存压缩**,减少大规模数据集的磁盘占用 -- **临时存储**,用于非缓存模式下的中间文件 -- **细粒度缓存控制**,通过上下文管理器和装饰器实现 - -## 配置 - -### 基本缓存设置 - -```yaml -use_cache: true # 启用/禁用 HuggingFace 数据集缓存 -ds_cache_dir: null # 自定义缓存目录(覆盖 HF_DATASETS_CACHE) -cache_compress: null # 压缩方法:'gzip'、'zstd'、'lz4' 或 null -temp_dir: null # 缓存禁用时的中间文件临时目录 -``` - -### 命令行 - -```bash -# 启用缓存(默认) -dj-process --config config.yaml --use_cache true - -# 禁用缓存 -dj-process --config config.yaml --use_cache false - -# 启用缓存压缩 -dj-process --config config.yaml --cache_compress zstd - -# 自定义缓存目录 -dj-process --config config.yaml --ds_cache_dir /fast-storage/dj-cache -``` - -## 缓存目录结构 - -DataJuicer 通过环境变量控制的层级目录结构来组织缓存文件: - -``` -~/.cache/ # CACHE_HOME(默认) -└── data_juicer/ # DATA_JUICER_CACHE_HOME - ├── assets/ # DATA_JUICER_ASSETS_CACHE - │ └── (提取的帧、停用词、标记词等) - └── models/ # DATA_JUICER_MODELS_CACHE - └── (下载的模型文件) -``` - -### 环境变量 - -| 变量 | 默认值 | 描述 | -|------|--------|------| -| `CACHE_HOME` | `~/.cache` | 根缓存目录 | -| `DATA_JUICER_CACHE_HOME` | `$CACHE_HOME/data_juicer` | DataJuicer 缓存根目录 | -| `DATA_JUICER_ASSETS_CACHE` | `$DATA_JUICER_CACHE_HOME/assets` | 资产缓存(帧、词表等) | -| `DATA_JUICER_MODELS_CACHE` | `$DATA_JUICER_CACHE_HOME/models` | 下载的模型缓存 | -| `DATA_JUICER_EXTERNAL_MODELS_HOME` | `None` | 外部模型目录 | - -通过设置环境变量覆盖默认值: - -```bash -export DATA_JUICER_CACHE_HOME=/data/dj-cache -export DATA_JUICER_MODELS_CACHE=/models/dj-models -dj-process --config config.yaml -``` - -## 缓存压缩 - -对于大规模数据集(数十 GB 或更大),缓存文件可能占用大量磁盘空间。缓存压缩通过在每个算子完成后压缩中间缓存文件来减少存储需求。 - -### 支持的算法 - -| 算法 | 依赖库 | 速度 | 压缩率 | 推荐场景 | -|------|--------|------|--------|----------| -| `zstd` | zstandard | 快 | 高 | 通用场景(默认) | -| `lz4` | lz4 | 最快 | 中等 | 速度敏感的工作负载 | -| `gzip` | gzip | 慢 | 高 | 需要兼容性的场景 | - -### 配置 - -```yaml -use_cache: true -cache_compress: zstd # 启用 zstd 压缩 -``` - -```bash -dj-process --config config.yaml --cache_compress zstd -``` - -### 多进程压缩 - -缓存压缩支持并行处理。压缩工作进程数由 `np` 参数控制: - -```yaml -np: 4 # 并行工作进程数(也用于压缩) -cache_compress: zstd -``` - -## 缓存控制 API - -### DatasetCacheControl - -上下文管理器,用于在特定范围内临时启用或禁用 HuggingFace 数据集缓存: - -```python -from data_juicer.utils.cache_utils import DatasetCacheControl - -# 临时禁用缓存 -with DatasetCacheControl(on=False): - # 此处的操作不会使用缓存 - result = dataset.map(my_function) - -# 临时启用缓存 -with DatasetCacheControl(on=True): - # 此处的操作会使用缓存 - result = dataset.map(my_function) -``` - -### dataset_cache_control 装饰器 - -用于需要控制缓存状态的函数的装饰器: - -```python -from data_juicer.utils.cache_utils import dataset_cache_control - -@dataset_cache_control(on=False) -def process_without_cache(dataset): - return dataset.map(my_function) -``` - -### CompressionOff - -上下文管理器,用于临时禁用缓存压缩: - -```python -from data_juicer.utils.compress import CompressionOff - -with CompressionOff(): - # 此范围内缓存压缩被禁用 - result = dataset.map(my_function) -``` - -### CompressManager - -底层 API,用于手动压缩/解压: - -```python -from data_juicer.utils.compress import CompressManager - -manager = CompressManager(compressor_format="zstd") - -# 压缩文件 -manager.compress("input.arrow", "input.arrow.zstd") - -# 解压文件 -manager.decompress("input.arrow.zstd", "input.arrow") -``` - -### CacheCompressManager - -高层 API,用于管理 HuggingFace 数据集缓存压缩: - -```python -from data_juicer.utils.compress import CacheCompressManager - -manager = CacheCompressManager(compressor_format="zstd") - -# 压缩前一个数据集的缓存文件 -manager.compress(prev_ds=previous_dataset, this_ds=current_dataset, num_proc=4) - -# 解压数据集的缓存文件 -manager.decompress(ds=dataset, num_proc=4) - -# 清理所有压缩的缓存文件 -manager.cleanup_cache_files(ds=dataset) -``` - -## 缓存与检查点 - -缓存和检查点是互斥的——启用检查点会自动禁用缓存: - -| 特性 | 缓存 | 检查点 | -|------|------|--------| -| **用途** | 加速相同配置的重复运行 | 故障恢复和断点续跑 | -| **粒度** | 每个算子的结果 | 完整数据集快照 | -| **存储位置** | HuggingFace 缓存目录 | 工作目录 | -| **恢复方式** | 自动(基于哈希) | 手动(基于配置) | -| **压缩** | 支持(`cache_compress`) | 不适用 | -| **场景** | 迭代开发、参数调优 | 长时间运行的生产任务 | - -```yaml -# 缓存模式(默认) -use_cache: true -use_checkpoint: false - -# 检查点模式(缓存自动禁用) -use_cache: true # 将被覆盖为 false -use_checkpoint: true -``` - -## 缓存禁用与临时目录 - -当 `use_cache: false` 或启用检查点(`use_checkpoint: true`)时,HuggingFace 数据集缓存会被完全禁用。此时,DataJuicer 会将算子处理过程中产生的中间文件写入一个临时目录,并在处理完成后自动清理。`temp_dir` 参数用于指定这些中间文件的存放位置。 - -### 行为说明 - -- **默认值为 `null`**:此时由操作系统决定临时目录位置(通常为 `/tmp`),等价于 Python 的 `tempfile.gettempdir()` 返回值。 -- **禁用缓存时自动生效**:只要缓存被禁用,`temp_dir` 就会被设置为 Python `tempfile` 模块的全局临时目录(`tempfile.tempdir`),影响整个进程中所有通过 `tempfile` 创建的临时文件。 -- **缓存压缩自动禁用**:禁用缓存后,`cache_compress` 配置会被自动忽略并置为 `null`。 - -### 配置示例 - -```yaml -use_cache: false -temp_dir: /data/dj-temp # 指定临时目录,null 则由系统决定 -``` - -```bash -dj-process --config config.yaml --use_cache false --temp_dir /data/dj-temp -``` - -### 安全须知 - -> **请谨慎设置 `temp_dir`,错误的路径可能导致不可预期的程序行为。** - -- **不要指向系统关键目录**(如 `/`、`/usr`、`/etc`),因为临时文件的自动清理可能误删重要文件。 -- **不要指向已有重要数据的目录**,临时文件写入和清理操作可能与现有文件产生冲突。 -- **确保目录有足够的磁盘空间**:禁用缓存时,中间文件会在处理过程中动态写入和删除,峰值占用约等于单个算子输出的数据量。 -- **目录不存在时会自动创建**:若指定路径不存在,DataJuicer 会自动调用 `os.makedirs` 创建该目录。 - -## 性能考虑 - -### 何时启用缓存 - -- **启用**:迭代开发中频繁重新运行管道且仅有少量更改时 -- **启用**:算子计算成本高,希望跳过已计算步骤时 -- **禁用**:一次性处理,避免磁盘开销 - -### 何时启用压缩 - -- **启用**:数据集大小超过数十 GB 且磁盘空间有限时 -- **启用** `zstd`:速度和压缩率的最佳平衡 -- **启用** `lz4`:压缩速度至关重要时 -- **禁用**:磁盘空间充足且希望最大化处理速度时 - -## 故障排除 - -**缓存文件占用过多磁盘空间:** -```bash -# 检查缓存目录大小 -du -sh ~/.cache/data_juicer/ - -# 启用压缩 -dj-process --config config.yaml --cache_compress zstd -``` - -**过期缓存导致意外结果:** -```bash -# 清除 HuggingFace 数据集缓存 -rm -rf ~/.cache/huggingface/datasets/ - -# 或指定一个新的缓存目录 -dj-process --config config.yaml --ds_cache_dir /tmp/fresh-cache -``` \ No newline at end of file diff --git a/docs/Export.md b/docs/Export.md deleted file mode 100644 index 144966f680..0000000000 --- a/docs/Export.md +++ /dev/null @@ -1,282 +0,0 @@ -# Dataset Export - -This document describes how DataJuicer exports processed datasets, including supported formats, sharding, parallel export, S3 export, and stats/hash management. - -## Overview - -After processing, DataJuicer exports the result dataset to disk using the `Exporter` (default mode) or `RayExporter` (Ray mode). The export system supports: - -- **Multiple output formats** — JSONL, JSON, Parquet, and more in Ray mode -- **Shard export** — split large datasets into multiple files by size -- **Parallel export** — speed up single-file export with multiprocessing -- **S3 export** — write results directly to Amazon S3 or S3-compatible storage -- **Stats and hash management** — control which intermediate fields are kept in the output - -## Configuration - -### Basic Settings - -```yaml -export_path: ./outputs/result.jsonl # Output file path (required) -export_type: jsonl # Format type (auto-detected from path if omitted) -export_shard_size: 0 # Shard size in bytes (0 = single file) -export_in_parallel: false # Parallel export for single-file mode -keep_stats_in_res_ds: false # Keep computed stats in output -keep_hashes_in_res_ds: false # Keep computed hashes in output -export_extra_args: {} # Additional format-specific arguments -export_aws_credentials: null # For S3 export, see S3 section for details -``` - -### Command Line - -```bash -# Basic export -dj-process --config config.yaml --export_path ./outputs/result.jsonl - -# Export as Parquet -dj-process --config config.yaml --export_path ./outputs/result.parquet - -# Export with sharding (256MB per shard) -dj-process --config config.yaml --export_shard_size 268435456 - -# Keep stats in output -dj-process --config config.yaml --keep_stats_in_res_ds true -``` - -## Supported Formats - -### Default Mode (Exporter) - -| Format | Suffix | Description | -|--------|--------|-------------| -| JSONL | `.jsonl` | JSON Lines — one JSON object per line (default) | -| JSON | `.json` | Standard JSON array | -| Parquet | `.parquet` | Columnar format, efficient for large datasets | - -### Ray Mode (RayExporter) - -| Format | Suffix | Description | -|--------|--------|-------------| -| JSONL | `.jsonl` | JSON Lines | -| JSON | `.json` | Standard JSON | -| Parquet | `.parquet` | Columnar format | -| CSV | `.csv` | Comma-separated values | -| TFRecords | `.tfrecords` | TensorFlow record format | -| WebDataset | `webdataset` | WebDataset tar-based format | -| Lance | `.lance` | Lance columnar format | - -## Shard Export - -For large datasets, split the output into multiple shard files based on size: - -```yaml -export_path: ./outputs/result.jsonl -export_shard_size: 268435456 # 256 MB per shard -``` - -This produces files like: -``` -outputs/ -├── result-00-of-04.jsonl -├── result-01-of-04.jsonl -├── result-02-of-04.jsonl -└── result-03-of-04.jsonl -``` - -**How shard size is calculated:** -1. The total dataset size in bytes is estimated -2. Number of shards = `ceil(dataset_bytes / export_shard_size)` -3. The dataset is split into contiguous shards -4. Each shard is exported in parallel using multiprocessing - -**Recommended shard sizes:** - -| Dataset Size | Recommended Shard Size | Notes | -|-------------|----------------------|-------| -| < 1 GB | 0 (single file) | No need to shard | -| 1-10 GB | 256 MB - 512 MB | Good balance | -| 10-100 GB | 512 MB - 1 GB | Fewer files | -| > 100 GB | 1 GB - 10 GB | Avoid too many shards | - -Shard sizes below 1 MiB or above 1 TiB will trigger warnings. - -## Parallel Export - -For single-file export (`export_shard_size: 0`), enable parallel writing to speed up the process: - -```yaml -export_path: ./outputs/result.jsonl -export_shard_size: 0 -export_in_parallel: true -np: 4 # Number of parallel processes -``` - -**Important**: Parallel export can sometimes be **slower** than sequential export due to IO blocking, especially for very large datasets. If you observe this, set `export_in_parallel: false`. - -When `export_shard_size > 0`, shards are always exported in parallel regardless of this setting. - -## S3 Export - -Export results directly to Amazon S3 or S3-compatible storage. - -### Default Mode - -```yaml -export_path: "s3://my-bucket/outputs/result.jsonl" -export_aws_credentials: - aws_access_key_id: "AKIA..." - aws_secret_access_key: "secret..." - aws_region: "us-east-1" - endpoint_url: "https://s3.example.com" # Optional: for S3-compatible storage -``` - -The default exporter uses HuggingFace's `storage_options` with `fsspec`/`s3fs` for S3 access. - -### Ray Mode - -```yaml -export_path: "s3://my-bucket/outputs/result.jsonl" -export_extra_args: - aws_access_key_id: "AKIA..." - aws_secret_access_key: "secret..." - aws_region: "us-east-1" -``` - -The Ray exporter uses PyArrow's S3 filesystem for S3 access. - -### S3 with Sharding - -When using S3 with shard export, shard files are written directly to S3: - -```yaml -export_path: "s3://my-bucket/outputs/result.jsonl" -export_shard_size: 268435456 -export_aws_credentials: - aws_access_key_id: "AKIA..." - aws_secret_access_key: "secret..." -``` - -This produces S3 objects like: -``` -s3://my-bucket/outputs/result-00-of-04.jsonl -s3://my-bucket/outputs/result-01-of-04.jsonl -... -``` - -### Credential Resolution - -AWS credentials are resolved in priority order: -1. `export_aws_credentials` config (default mode) or `export_extra_args` (Ray mode) -2. Environment variables (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`) -3. Default credential chain (IAM role, `~/.aws/credentials`) - -## Stats and Hash Management - -During processing, DataJuicer computes intermediate fields: -- **Stats** (`__dj__stats__`, `__dj__meta__`): computed by Filter operators -- **Hashes** (`__dj__hash__`, `__dj__minhash__`, `__dj__simhash__`, etc.): computed by Deduplicator operators - -By default, these fields are **removed** from the exported dataset. To keep them: - -```yaml -keep_stats_in_res_ds: true # Keep stats and meta fields -keep_hashes_in_res_ds: true # Keep hash fields -``` - -### Stats Export - -Regardless of `keep_stats_in_res_ds`, DataJuicer always exports a separate stats file alongside the main dataset: - -``` -outputs/ -├── result.jsonl # Main dataset (stats removed by default) -└── result_stats.jsonl # Stats-only file (always exported) -``` - -The stats file contains only the `__dj__stats__` and `__dj__meta__` columns. - -## WebDataset Export (Ray Mode) - -In Ray mode, you can export to WebDataset format with custom field mapping: - -```yaml -export_path: ./outputs/webdataset -export_type: webdataset -export_extra_args: - field_mapping: - txt: "text" - png: "images" - json: "metadata" -``` - -## API Reference - -### Exporter (Default Mode) - -```python -from data_juicer.core.exporter import Exporter - -exporter = Exporter( - export_path="./outputs/result.jsonl", - export_type="jsonl", - export_shard_size=0, - export_in_parallel=True, - num_proc=4, - keep_stats_in_res_ds=False, - keep_hashes_in_res_ds=False, -) - -exporter.export(dataset) -``` - -### RayExporter (Ray Mode) - -```python -from data_juicer.core.ray_exporter import RayExporter - -exporter = RayExporter( - export_path="./outputs/result.jsonl", - export_type="jsonl", - export_shard_size=268435456, - keep_stats_in_res_ds=False, - keep_hashes_in_res_ds=False, -) - -exporter.export(ray_dataset) -``` - -## Troubleshooting - -**Export format not supported:** -```bash -# Check supported formats -# Default mode: jsonl, json, parquet -# Ray mode: jsonl, json, parquet, csv, tfrecords, webdataset, lance -``` - -**Parallel export is slower than expected:** -```yaml -# Disable parallel export -export_in_parallel: false -``` - -**S3 export fails with permission error:** -```bash -# Verify credentials -aws s3 ls s3://your-bucket/ - -# Check that export_aws_credentials is configured -``` - -**Too many shard files generated:** -```yaml -# Increase shard size -export_shard_size: 1073741824 # 1 GB -``` - -**Stats missing from exported dataset:** -```yaml -# Keep stats in the result dataset -keep_stats_in_res_ds: true -# Or check the separate stats file: result_stats.jsonl -``` diff --git a/docs/Export_ZH.md b/docs/Export_ZH.md deleted file mode 100644 index 8f04dd30fb..0000000000 --- a/docs/Export_ZH.md +++ /dev/null @@ -1,282 +0,0 @@ -# 数据集导出 - -本文档描述 DataJuicer 如何导出处理后的数据集,包括支持的格式、分片、并行导出、S3 导出以及统计信息/哈希管理。 - -## 概述 - -处理完成后,DataJuicer 使用 `Exporter`(默认模式)或 `RayExporter`(Ray 模式)将结果数据集导出到磁盘。导出系统支持: - -- **多种输出格式** — JSONL、JSON、Parquet,Ray 模式下支持更多格式 -- **分片导出** — 按大小将大型数据集拆分为多个文件 -- **并行导出** — 使用多进程加速单文件导出 -- **S3 导出** — 将结果直接写入 Amazon S3 或 S3 兼容存储 -- **统计信息和哈希管理** — 控制输出中保留哪些中间字段 - -## 配置 - -### 基本设置 - -```yaml -export_path: ./outputs/result.jsonl # 输出文件路径(必需) -export_type: jsonl # 格式类型(省略时从路径自动检测) -export_shard_size: 0 # 分片大小(字节),0 = 单文件 -export_in_parallel: false # 单文件模式下的并行导出 -keep_stats_in_res_ds: false # 在输出中保留计算的统计信息 -keep_hashes_in_res_ds: false # 在输出中保留计算的哈希值 -export_extra_args: {} # 额外的格式特定参数 -export_aws_credentials: null # S3 导出专用,详见 S3 导出章节 -``` - -### 命令行 - -```bash -# 基本导出 -dj-process --config config.yaml --export_path ./outputs/result.jsonl - -# 导出为 Parquet -dj-process --config config.yaml --export_path ./outputs/result.parquet - -# 分片导出(每片 256MB) -dj-process --config config.yaml --export_shard_size 268435456 - -# 在输出中保留统计信息 -dj-process --config config.yaml --keep_stats_in_res_ds true -``` - -## 支持的格式 - -### 默认模式(Exporter) - -| 格式 | 后缀 | 描述 | -|------|------|------| -| JSONL | `.jsonl` | JSON Lines — 每行一个 JSON 对象(默认) | -| JSON | `.json` | 标准 JSON 数组 | -| Parquet | `.parquet` | 列式格式,适合大型数据集 | - -### Ray 模式(RayExporter) - -| 格式 | 后缀 | 描述 | -|------|------|------| -| JSONL | `.jsonl` | JSON Lines | -| JSON | `.json` | 标准 JSON | -| Parquet | `.parquet` | 列式格式 | -| CSV | `.csv` | 逗号分隔值 | -| TFRecords | `.tfrecords` | TensorFlow 记录格式 | -| WebDataset | `webdataset` | WebDataset tar 格式 | -| Lance | `.lance` | Lance 列式格式 | - -## 分片导出 - -对于大型数据集,按大小将输出拆分为多个分片文件: - -```yaml -export_path: ./outputs/result.jsonl -export_shard_size: 268435456 # 每片 256 MB -``` - -生成的文件如下: -``` -outputs/ -├── result-00-of-04.jsonl -├── result-01-of-04.jsonl -├── result-02-of-04.jsonl -└── result-03-of-04.jsonl -``` - -**分片大小计算方式:** -1. 估算数据集的总字节大小 -2. 分片数 = `ceil(dataset_bytes / export_shard_size)` -3. 数据集被拆分为连续的分片 -4. 每个分片使用多进程并行导出 - -**推荐的分片大小:** - -| 数据集大小 | 推荐分片大小 | 说明 | -|-----------|-------------|------| -| < 1 GB | 0(单文件) | 无需分片 | -| 1-10 GB | 256 MB - 512 MB | 良好平衡 | -| 10-100 GB | 512 MB - 1 GB | 更少文件 | -| > 100 GB | 1 GB - 10 GB | 避免过多分片 | - -分片大小低于 1 MiB 或高于 1 TiB 将触发警告。 - -## 并行导出 - -对于单文件导出(`export_shard_size: 0`),启用并行写入以加速导出过程: - -```yaml -export_path: ./outputs/result.jsonl -export_shard_size: 0 -export_in_parallel: true -np: 4 # 并行进程数 -``` - -**重要提示**:并行导出有时可能比顺序导出**更慢**,因为 IO 阻塞,特别是对于非常大的数据集。如果观察到这种情况,请设置 `export_in_parallel: false`。 - -当 `export_shard_size > 0` 时,无论此设置如何,分片始终并行导出。 - -## S3 导出 - -将结果直接导出到 Amazon S3 或 S3 兼容存储。 - -### 默认模式 - -```yaml -export_path: "s3://my-bucket/outputs/result.jsonl" -export_aws_credentials: - aws_access_key_id: "AKIA..." - aws_secret_access_key: "secret..." - aws_region: "us-east-1" - endpoint_url: "https://s3.example.com" # 可选:用于 S3 兼容存储 -``` - -默认导出器使用 HuggingFace 的 `storage_options` 配合 `fsspec`/`s3fs` 进行 S3 访问。 - -### Ray 模式 - -```yaml -export_path: "s3://my-bucket/outputs/result.jsonl" -export_extra_args: - aws_access_key_id: "AKIA..." - aws_secret_access_key: "secret..." - aws_region: "us-east-1" -``` - -Ray 导出器使用 PyArrow 的 S3 文件系统进行 S3 访问。 - -### S3 分片导出 - -使用 S3 进行分片导出时,分片文件直接写入 S3: - -```yaml -export_path: "s3://my-bucket/outputs/result.jsonl" -export_shard_size: 268435456 -export_aws_credentials: - aws_access_key_id: "AKIA..." - aws_secret_access_key: "secret..." -``` - -生成的 S3 对象如下: -``` -s3://my-bucket/outputs/result-00-of-04.jsonl -s3://my-bucket/outputs/result-01-of-04.jsonl -... -``` - -### 凭证解析 - -AWS 凭证按以下优先级解析: -1. `export_aws_credentials` 配置(默认模式)或 `export_extra_args`(Ray 模式) -2. 环境变量(`AWS_ACCESS_KEY_ID`、`AWS_SECRET_ACCESS_KEY`) -3. 默认凭证链(IAM 角色、`~/.aws/credentials`) - -## 统计信息和哈希管理 - -在处理过程中,DataJuicer 会计算中间字段: -- **统计信息**(`__dj__stats__`、`__dj__meta__`):由 Filter 算子计算 -- **哈希值**(`__dj__hash__`、`__dj__minhash__`、`__dj__simhash__` 等):由 Deduplicator 算子计算 - -默认情况下,这些字段会从导出的数据集中**移除**。要保留它们: - -```yaml -keep_stats_in_res_ds: true # 保留统计信息和元数据字段 -keep_hashes_in_res_ds: true # 保留哈希字段 -``` - -### 统计信息导出 - -无论 `keep_stats_in_res_ds` 如何设置,DataJuicer 始终会在主数据集旁边导出一个单独的统计信息文件: - -``` -outputs/ -├── result.jsonl # 主数据集(默认移除统计信息) -└── result_stats.jsonl # 仅统计信息文件(始终导出) -``` - -统计信息文件仅包含 `__dj__stats__` 和 `__dj__meta__` 列。 - -## WebDataset 导出(Ray 模式) - -在 Ray 模式下,可以使用自定义字段映射导出为 WebDataset 格式: - -```yaml -export_path: ./outputs/webdataset -export_type: webdataset -export_extra_args: - field_mapping: - txt: "text" - png: "images" - json: "metadata" -``` - -## API 参考 - -### Exporter(默认模式) - -```python -from data_juicer.core.exporter import Exporter - -exporter = Exporter( - export_path="./outputs/result.jsonl", - export_type="jsonl", - export_shard_size=0, - export_in_parallel=True, - num_proc=4, - keep_stats_in_res_ds=False, - keep_hashes_in_res_ds=False, -) - -exporter.export(dataset) -``` - -### RayExporter(Ray 模式) - -```python -from data_juicer.core.ray_exporter import RayExporter - -exporter = RayExporter( - export_path="./outputs/result.jsonl", - export_type="jsonl", - export_shard_size=268435456, - keep_stats_in_res_ds=False, - keep_hashes_in_res_ds=False, -) - -exporter.export(ray_dataset) -``` - -## 故障排除 - -**导出格式不支持:** -```bash -# 检查支持的格式 -# 默认模式:jsonl, json, parquet -# Ray 模式:jsonl, json, parquet, csv, tfrecords, webdataset, lance -``` - -**并行导出比预期慢:** -```yaml -# 禁用并行导出 -export_in_parallel: false -``` - -**S3 导出权限错误:** -```bash -# 验证凭证 -aws s3 ls s3://your-bucket/ - -# 检查 export_aws_credentials 是否已配置 -``` - -**生成的分片文件过多:** -```yaml -# 增大分片大小 -export_shard_size: 1073741824 # 1 GB -``` - -**导出的数据集中缺少统计信息:** -```yaml -# 在结果数据集中保留统计信息 -keep_stats_in_res_ds: true -# 或检查单独的统计信息文件:result_stats.jsonl -``` diff --git a/docs/Operators.md b/docs/Operators.md index 2418e8e0af..7a7918ec3b 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -46,7 +46,7 @@ Data-Juicer 中的算子分为以下 8 种类型。 | [filter](#filter) | 56 | Filters out low-quality samples. 过滤低质量样本。 | | [formatter](#formatter) | 8 | Discovers, loads, and canonicalizes source data. 发现、加载、规范化原始数据。 | | [grouper](#grouper) | 3 | Group samples to batched samples. 将样本分组,每一组组成一个批量样本。 | -| [mapper](#mapper) | 105 | Edits and transforms samples. 对数据样本进行编辑和转换。 | +| [mapper](#mapper) | 103 | Edits and transforms samples. 对数据样本进行编辑和转换。 | | [pipeline](#pipeline) | 3 | Applies dataset-level processing; both input and output are datasets. 执行数据集级别的操作,输入和输出均为完整数据集。 | | [selector](#selector) | 5 | Selects top samples based on ranking. 基于排序选取高质量样本。 | @@ -224,8 +224,6 @@ All the specific operators are listed below, each featured with several capabili | image_tagging_vlm_mapper | 🔮Multimodal 🚀GPU 🔗API 🌊vLLM 🟡Beta | Mapper to generates image tags. 映射器生成图像标签。 | [info](operators/mapper/image_tagging_vlm_mapper.md) | - | | imgdiff_difference_area_generator_mapper | 🚀GPU 🟡Beta | Generates and filters bounding boxes for image pairs based on similarity, segmentation, and text matching. 根据相似性、分割和文本匹配生成和过滤图像对的边界框。 | [info](operators/mapper/imgdiff_difference_area_generator_mapper.md) | [ImgDiff](https://arxiv.org/abs/2408.04594) | | imgdiff_difference_caption_generator_mapper | 🚀GPU 🟡Beta | Generates difference captions for bounding box regions in two images. 为两个图像中的边界框区域生成差异字幕。 | [info](operators/mapper/imgdiff_difference_caption_generator_mapper.md) | [ImgDiff](https://arxiv.org/abs/2408.04594) | -| latex_figure_context_extractor_mapper | 🔮Multimodal 💻CPU 🟡Beta | A subfigure within a figure environment. 图环境中的子图。 | [info](operators/mapper/latex_figure_context_extractor_mapper.md) | - | -| latex_merge_tex_mapper | 🔤Text 💻CPU 🟡Beta | Extracts and concatenates all ``.tex`` files from a compressed LaTeX project archive into a single text field. 将压缩的LaTeX项目存档中的所有 “.tex'' 文件提取并连接到单个文本字段中。 | [info](operators/mapper/latex_merge_tex_mapper.md) | - | | mllm_mapper | 🔮Multimodal 🚀GPU 🧩HF 🟢Stable | Mapper to use MLLMs for visual question answering tasks. Mapper使用MLLMs进行视觉问答任务。 | [info](operators/mapper/mllm_mapper.md) | - | | nlpaug_en_mapper | 🔤Text 💻CPU 🟢Stable | Augments English text samples using various methods from the nlpaug library. 使用nlpaug库中的各种方法增强英语文本样本。 | [info](operators/mapper/nlpaug_en_mapper.md) | - | | nlpcda_zh_mapper | 🔤Text 💻CPU 🟢Stable | Augments Chinese text samples using the nlpcda library. 使用nlpcda库扩充中文文本样本。 | [info](operators/mapper/nlpcda_zh_mapper.md) | - | diff --git a/docs/Tracing.md b/docs/Tracing.md deleted file mode 100644 index 9c366b3ea2..0000000000 --- a/docs/Tracing.md +++ /dev/null @@ -1,250 +0,0 @@ -# Data Tracing - -This document describes DataJuicer's tracing system for tracking sample-level changes during data processing. - -## Overview - -The Tracer records how each operator modifies, filters, or deduplicates individual samples in the processing pipeline. This is useful for: - -- **Debugging** — Understanding why specific samples were modified or removed -- **Quality Assurance** — Verifying operators are working as expected -- **Auditing** — Maintaining records of data transformations - -## Configuration - -### Basic Settings - -```yaml -open_tracer: false # Enable/disable tracing -op_list_to_trace: [] # List of operators to trace (empty = all operators) -trace_num: 10 # Maximum number of samples to collect per operator -trace_keys: [] # Additional fields to include in trace output -``` - -### Command Line - -```bash -# Enable tracing for all operators -dj-process --config config.yaml --open_tracer true - -# Trace only specific operators -dj-process --config config.yaml --open_tracer true \ - --op_list_to_trace clean_email_mapper,words_num_filter - -# Collect more samples per operator -dj-process --config config.yaml --open_tracer true --trace_num 50 - -# Include additional fields in trace output -dj-process --config config.yaml --open_tracer true \ - --trace_keys sample_id,source_file -``` - -## Output Structure - -Trace results are stored in the `trace/` subdirectory of the work directory: - -``` -{work_dir}/ -└── trace/ - ├── sample_trace-clean_email_mapper.jsonl - ├── sample_trace-words_num_filter.jsonl - ├── duplicate-document_deduplicator.jsonl - └── ... -``` - -Each trace file is in JSONL format (one JSON object per line), with content varying by operator type. - -## Traced Operator Types - -### Mapper Tracing - -For Mapper operators, the Tracer records samples where text content changes. Each record contains: - -| Field | Description | -|-------|-------------| -| `original_text` | Text before Mapper processing | -| `processed_text` | Text after Mapper processing | -| *trace_keys fields* | Values corresponding to configured `trace_keys` | - -Example output (`sample_trace-clean_email_mapper.jsonl`): -```json -{"original_text":"Contact us at user@example.com for details.","processed_text":"Contact us at for details."} -{"original_text": "Email: admin@test.org", "processed_text": "Email: "} -``` - -Only samples with actual text changes are collected; unchanged samples are skipped. - -### Filter Tracing - -For Filter operators, the Tracer records samples that are **filtered out** (removed). Each record contains the complete sample data. - -Example output (`sample_trace-words_num_filter.jsonl`): -```json -{"text": "Too short.", "__dj__stats__": {"words_num": 2}} -{"text": "Also brief.", "__dj__stats__": {"words_num": 2}} -``` - -Only samples that fail the filter are collected; samples passing the filter are skipped. - -### Deduplicator Tracing - -For Deduplicator operators, the Tracer records pairs of near-duplicate samples. Each record contains: - -| Field | Description | -|-------|-------------| -| `dup1` | First sample in the duplicate pair | -| `dup2` | Second sample in the duplicate pair | - -Example output (`duplicate-document_deduplicator.jsonl`): -```json -{"dup1": "This is a duplicate text.", "dup2": "This is a duplicate text."} -``` - -## Sample Collection Behavior - -The Tracer uses an efficient **sample-level collection** approach: - -1. Each operator collects at most `trace_num` samples during processing -2. Collection stops early once enough samples are gathered -3. In default mode, collection is **thread-safe** using multiprocess locks -4. In Ray mode, each Worker has its own Tracer instance (no locking needed) - -This design minimizes performance overhead — the Tracer does not compare the entire dataset, but captures changes in real-time during processing. - -## trace_keys - -The `trace_keys` option allows including additional fields from original samples in the trace output. This is useful for identifying which samples were affected: - -```yaml -open_tracer: true -trace_keys: - - sample_id - - source_file -``` - -With this configuration, Mapper trace entries will include: -```json -{ - "sample_id": "doc_00042", - "source_file": "corpus_part1.jsonl", - "original_text": "Original content...", - "processed_text": "Processed content..." -} -``` - -## API Reference - -### Tracer (Default Mode) - -```python -from data_juicer.core.tracer import Tracer - -tracer = Tracer( - work_dir="./outputs", - op_list_to_trace=["clean_email_mapper", "words_num_filter"], - show_num=10, - trace_keys=["sample_id"] -) - -# Check if an operator should be traced -tracer.should_trace_op("clean_email_mapper") # True - -# Check if enough samples have been collected -tracer.is_collection_complete("clean_email_mapper") # False - -# Collect Mapper sample -tracer.collect_mapper_sample( - op_name="clean_email_mapper", - original_sample={"text": "Email: a@b.com"}, - processed_sample={"text": "Email: "}, - text_key="text" -) - -# Collect Filter sample -tracer.collect_filter_sample( - op_name="words_num_filter", - sample={"text": "too short"}, - should_keep=False -) -``` - -### RayTracer (Distributed Mode) - -```python -from data_juicer.core.tracer.ray_tracer import RayTracer - -# RayTracer is a Ray Actor — created via Ray -tracer = RayTracer.remote( - work_dir="./outputs", - op_list_to_trace=None, # Trace all operators - show_num=10, - trace_keys=["sample_id"] -) - -# Remote method calls -ray.get(tracer.collect_mapper_sample.remote( - op_name="clean_email_mapper", - original_sample={"text": "Email: a@b.com"}, - processed_sample={"text": "Email: "}, - text_key="text" -)) - -# Finalize and export all trace results -ray.get(tracer.finalize_traces.remote()) -``` - -### Helper Functions - -The `data_juicer.core.tracer` module provides mode-agnostic helper functions: - -```python -from data_juicer.core.tracer import ( - should_trace_op, - check_tracer_collect_complete, - collect_for_mapper, - collect_for_filter, -) - -# These functions automatically handle default mode and Ray mode -should_trace_op(tracer_instance, "clean_email_mapper") -check_tracer_collect_complete(tracer_instance, "clean_email_mapper") -collect_for_mapper(tracer_instance, "op_name", original, processed, "text") -collect_for_filter(tracer_instance, "op_name", sample, should_keep=False) -``` - -## Performance Considerations - -### Overhead - -- When `trace_num` is small (default: 10), the additional overhead of tracing is minimal -- Once an operator has collected `trace_num` samples, no further collection occurs -- The main cost is comparing original and processed text in Mappers - -### Recommendations - -| Scenario | Recommended Settings | -|----------|----------------------| -| Development/Debugging | `open_tracer: true`, `trace_num: 10-50` | -| Production Runs | `open_tracer: false` | -| Auditing Specific Operators | `open_tracer: true`, `op_list_to_trace: [specific operators]` | -| Large-scale Tracing | `open_tracer: true`, `trace_num: 100`, specify `op_list_to_trace` | - -## Troubleshooting - -**No trace files generated:** -```bash -# Verify tracer is enabled -grep "open_tracer" config.yaml - -# Check if trace directory exists -ls -la ./outputs/{work_dir}/trace/ -``` - -**Trace files are empty:** -- For Mapper: The operator may not have modified any samples -- For Filter: The operator may not have filtered out any samples -- Check logs for warnings like "Datasets before and after op [X] are all the same" - -**Too few samples in trace files:** -- Increase `trace_num` to collect more samples -- There may be fewer than `trace_num` changed/filtered samples in the dataset diff --git a/docs/Tracing_ZH.md b/docs/Tracing_ZH.md deleted file mode 100644 index ac44fd4ec2..0000000000 --- a/docs/Tracing_ZH.md +++ /dev/null @@ -1,250 +0,0 @@ -# 数据追踪 - -本文档描述 DataJuicer 的追踪系统,用于跟踪数据处理过程中样本级别的变化。 - -## 概述 - -Tracer 记录每个算子在处理管道中如何修改、过滤或去重各个样本。这对以下场景非常有用: - -- **调试** — 理解特定样本为何被修改或删除 -- **质量保证** — 验证算子是否按预期工作 -- **审计** — 维护数据转换的记录 - -## 配置 - -### 基本设置 - -```yaml -open_tracer: false # 启用/禁用追踪 -op_list_to_trace: [] # 要追踪的算子列表(空 = 所有算子) -trace_num: 10 # 每个算子最多收集的样本数 -trace_keys: [] # 追踪输出中包含的额外字段 -``` - -### 命令行 - -```bash -# 启用所有算子的追踪 -dj-process --config config.yaml --open_tracer true - -# 仅追踪特定算子 -dj-process --config config.yaml --open_tracer true \ - --op_list_to_trace clean_email_mapper,words_num_filter - -# 每个算子收集更多样本 -dj-process --config config.yaml --open_tracer true --trace_num 50 - -# 在追踪输出中包含额外字段 -dj-process --config config.yaml --open_tracer true \ - --trace_keys sample_id,source_file -``` - -## 输出结构 - -追踪结果存储在工作目录的 `trace/` 子目录中: - -``` -{work_dir}/ -└── trace/ - ├── sample_trace-clean_email_mapper.jsonl - ├── sample_trace-words_num_filter.jsonl - ├── duplicate-document_deduplicator.jsonl - └── ... -``` - -每个追踪文件为 JSONL 格式(每行一个 JSON 对象),内容因算子类型而异。 - -## 追踪的算子类型 - -### Mapper 追踪 - -对于 Mapper 算子,Tracer 记录文本内容发生变化的样本。每条记录包含: - -| 字段 | 描述 | -|------|------| -| `original_text` | Mapper 处理前的文本 | -| `processed_text` | Mapper 处理后的文本 | -| *trace_keys 字段* | 配置的 `trace_keys` 对应的值 | - -输出示例(`sample_trace-clean_email_mapper.jsonl`): -```json -{"original_text":"联系我们 user@example.com 获取详情。","processed_text":"联系我们 获取详情。"} -{"original_text": "邮箱:admin@test.org", "processed_text": "邮箱:"} -``` - -仅收集文本实际发生变化的样本,未变化的样本会被跳过。 - -### Filter 追踪 - -对于 Filter 算子,Tracer 记录被**过滤掉**(删除)的样本。每条记录包含完整的样本数据。 - -输出示例(`sample_trace-words_num_filter.jsonl`): -```json -{"text": "Too short.", "__dj__stats__": {"words_num": 2}} -{"text": "Also brief.", "__dj__stats__": {"words_num": 2}} -``` - -仅收集被过滤掉的样本,通过过滤器的样本会被跳过。 - -### Deduplicator 追踪 - -对于 Deduplicator 算子,Tracer 记录近似重复的样本对。每条记录包含: - -| 字段 | 描述 | -|------|------| -| `dup1` | 重复对中的第一个样本 | -| `dup2` | 重复对中的第二个样本 | - -输出示例(`duplicate-document_deduplicator.jsonl`): -```json -{"dup1": "这是一段重复的文本。", "dup2": "这是一段重复的文本。"} -``` - -## 样本收集行为 - -Tracer 使用高效的**样本级收集**方式: - -1. 每个算子在处理过程中最多收集 `trace_num` 个样本 -2. 收集到足够样本后提前停止 -3. 在默认模式下,收集是**线程安全**的,使用多进程锁 -4. 在 Ray 模式下,每个 Worker 有自己的 Tracer 实例(无需加锁) - -这种设计最大限度地减少了性能开销——Tracer 不会比较整个数据集,而是在处理过程中实时捕获变化。 - -## trace_keys - -`trace_keys` 选项允许在追踪输出中包含原始样本的额外字段。这对于识别哪些样本受到影响非常有用: - -```yaml -open_tracer: true -trace_keys: - - sample_id - - source_file -``` - -使用此配置,Mapper 追踪条目将包含: -```json -{ - "sample_id": "doc_00042", - "source_file": "corpus_part1.jsonl", - "original_text": "原始内容...", - "processed_text": "处理后的内容..." -} -``` - -## API 参考 - -### Tracer(默认模式) - -```python -from data_juicer.core.tracer import Tracer - -tracer = Tracer( - work_dir="./outputs", - op_list_to_trace=["clean_email_mapper", "words_num_filter"], - show_num=10, - trace_keys=["sample_id"] -) - -# 检查某个算子是否需要追踪 -tracer.should_trace_op("clean_email_mapper") # True - -# 检查是否已收集足够的样本 -tracer.is_collection_complete("clean_email_mapper") # False - -# 收集 Mapper 样本 -tracer.collect_mapper_sample( - op_name="clean_email_mapper", - original_sample={"text": "邮箱:a@b.com"}, - processed_sample={"text": "邮箱:"}, - text_key="text" -) - -# 收集 Filter 样本 -tracer.collect_filter_sample( - op_name="words_num_filter", - sample={"text": "太短"}, - should_keep=False -) -``` - -### RayTracer(分布式模式) - -```python -from data_juicer.core.tracer.ray_tracer import RayTracer - -# RayTracer 是一个 Ray Actor — 通过 Ray 创建 -tracer = RayTracer.remote( - work_dir="./outputs", - op_list_to_trace=None, # 追踪所有算子 - show_num=10, - trace_keys=["sample_id"] -) - -# 远程方法调用 -ray.get(tracer.collect_mapper_sample.remote( - op_name="clean_email_mapper", - original_sample={"text": "邮箱:a@b.com"}, - processed_sample={"text": "邮箱:"}, - text_key="text" -)) - -# 最终化并导出所有追踪结果 -ray.get(tracer.finalize_traces.remote()) -``` - -### 辅助函数 - -`data_juicer.core.tracer` 模块提供了模式无关的辅助函数: - -```python -from data_juicer.core.tracer import ( - should_trace_op, - check_tracer_collect_complete, - collect_for_mapper, - collect_for_filter, -) - -# 这些函数自动处理默认模式和 Ray 模式 -should_trace_op(tracer_instance, "clean_email_mapper") -check_tracer_collect_complete(tracer_instance, "clean_email_mapper") -collect_for_mapper(tracer_instance, "op_name", original, processed, "text") -collect_for_filter(tracer_instance, "op_name", sample, should_keep=False) -``` - -## 性能考虑 - -### 开销 - -- 当 `trace_num` 较小时(默认:10),追踪的额外开销极小 -- 一旦某个算子收集了 `trace_num` 个样本,就不再进行进一步收集 -- 主要成本是 Mapper 中原始文本与处理后文本的比较 - -### 建议 - -| 场景 | 推荐设置 | -|------|----------| -| 开发/调试 | `open_tracer: true`,`trace_num: 10-50` | -| 生产运行 | `open_tracer: false` | -| 审计特定算子 | `open_tracer: true`,`op_list_to_trace: [特定算子]` | -| 大规模追踪 | `open_tracer: true`,`trace_num: 100`,指定 `op_list_to_trace` | - -## 故障排除 - -**没有生成追踪文件:** -```bash -# 验证追踪器是否启用 -grep "open_tracer" config.yaml - -# 检查追踪目录是否存在 -ls -la ./outputs/{work_dir}/trace/ -``` - -**追踪文件为空:** -- 对于 Mapper:算子可能没有修改任何样本 -- 对于 Filter:算子可能没有过滤掉任何样本 -- 检查日志中是否有类似 "Datasets before and after op [X] are all the same" 的警告 - -**追踪文件中样本太少:** -- 增加 `trace_num` 以收集更多样本 -- 数据集中变化/过滤的样本可能少于 `trace_num` diff --git a/docs/operators/mapper/latex_figure_context_extractor_mapper.md b/docs/operators/mapper/latex_figure_context_extractor_mapper.md deleted file mode 100644 index a9a062b554..0000000000 --- a/docs/operators/mapper/latex_figure_context_extractor_mapper.md +++ /dev/null @@ -1,90 +0,0 @@ -# latex_figure_context_extractor_mapper - -Extracts figures and their citing context from LaTeX source. - -This operator parses figure environments from a paper's LaTeX source, extracts each figure's caption, label, and image path(s), and finds the prose paragraphs that cite each figure. It fans out one paper row into N figure rows (one per figure or subfigure). **Samples that contain no figures with images are dropped from the output.** Supported figure environments: `figure`, `figure*`, `wrapfigure`, `subfigure` (environment), `\subfigure` (command), `\subfloat` (command, subfig package). Supported caption commands: `\caption`, `\caption*`, `\subcaption`, `\captionof{figure}`. Figures without `\includegraphics` are skipped. Subfigures inherit citing paragraphs from their parent figure's label. When building citing paragraphs, float/display environments (figures, tables, tabulars, equations, algorithms, etc.) are stripped so only prose text is searched. - -> **Note:** This operator expects the full LaTeX source as a single string. It does **not** resolve `\input` or `\include` directives. If your documents span multiple `.tex` files, concatenate them into a single text field before applying this mapper. - -从LaTeX源码中提取图片及其引用上下文。 - -该算子解析论文LaTeX源码中的figure环境,提取每个图片的标题、标签和图片路径,并找到引用该图片的段落文本。它将一行论文数据展开为N行图片数据(每个图片或子图一行)。**不包含带图片的figure环境的样本将被丢弃。** 支持的图片环境:`figure`、`figure*`、`wrapfigure`、`subfigure`(环境)、`\subfigure`(命令)、`\subfloat`(命令,subfig宏包)。支持的标题命令:`\caption`、`\caption*`、`\subcaption`、`\captionof{figure}`。没有`\includegraphics`的图片会被跳过。子图会继承父图标签的引用段落。构建引用段落时,浮动/展示环境(图片、表格、公式、算法等)会被去除,仅在正文文本中搜索。 - -> **注意:** 该算子要求完整的LaTeX源码作为单个字符串输入。它**不会**解析`\input`或`\include`指令。如果您的文档分散在多个`.tex`文件中,请在使用此算子之前将它们合并到一个文本字段中。 - -Type 算子类型: **mapper** - -Tags 标签: cpu, text - -## 🔧 Parameter Configuration 参数配置 -| name 参数名 | type 类型 | default 默认值 | desc 说明 | -|--------|------|--------|------| -| `citation_commands` | `list` | `['\ref', '\cref', '\Cref', '\autoref']` | LaTeX reference commands to search for when finding citing paragraphs. | -| `paragraph_separator` | `str` | `'\n\n'` | Pattern for splitting LaTeX text into paragraphs. | -| `caption_key` | `str` | `'caption'` | Output field name for the figure caption. | -| `label_key` | `str` | `'label'` | Output field name for the LaTeX label. | -| `context_key` | `str` | `'citing_paragraphs'` | Output field name for citing paragraphs. | -| `parent_caption_key` | `str` | `'parent_caption'` | Output field name for the parent figure's caption. For subfigures this carries the parent figure environment's caption; empty for standalone figures. | -| `parent_label_key` | `str` | `'parent_label'` | Output field name for the parent figure's label. Useful for grouping subfigures that belong to the same figure environment; empty for standalone figures. | -| `args` | | `''` | extra args | -| `kwargs` | | `''` | extra args | - -## 📤 Output Fields 输出字段 - -In addition to all input fields, each output row contains: - -除所有输入字段外,每行输出还包含: - -| field 字段 | type 类型 | desc 说明 | -|-------|------|------| -| `images` (or custom `image_key`) | `list[str]` | Image paths from `\includegraphics`. 从`\includegraphics`提取的图片路径。 | -| `caption` (or custom `caption_key`) | `str` | Figure caption text. 图片标题文本。 | -| `label` (or custom `label_key`) | `str` | LaTeX label string. LaTeX标签字符串。 | -| `citing_paragraphs` (or custom `context_key`) | `list[str]` | Paragraphs that cite this figure. 引用该图片的段落。 | -| `parent_caption` (or custom `parent_caption_key`) | `str` | Parent figure caption (subfigures only; empty for standalone). 父图标题(仅子图;独立图为空)。 | -| `parent_label` (or custom `parent_label_key`) | `str` | Parent figure label (subfigures only; empty for standalone). 父图标签(仅子图;独立图为空)。 | - -## 📊 Effect demonstration 效果演示 -### test_single_figure -```python -LatexFigureContextExtractorMapper() -``` - -#### 📥 input data 输入数据 - -A LaTeX document with a single figure: - -```latex -\begin{document} -Some intro text. - -As shown in \ref{fig:arch}, the architecture is novel. - -\begin{figure} -\centering -\includegraphics[width=0.8\linewidth]{img/arch.pdf} -\caption{Overall architecture} -\label{fig:arch} -\end{figure} -\end{document} -``` - -#### 📤 output data 输出数据 - -One row is produced: -- `caption`: `"Overall architecture"` -- `label`: `"fig:arch"` -- `images`: `["img/arch.pdf"]` -- `citing_paragraphs`: `["As shown in \\ref{fig:arch}, the architecture is novel."]` -- `parent_caption`: `""` (standalone figure, no parent) -- `parent_label`: `""` (standalone figure, no parent) - -#### ✨ explanation 解释 -The operator extracts the figure environment, parses its caption, label, and image path, then searches the document paragraphs (with float/display/tabular environments stripped) for any paragraph containing `\ref{fig:arch}`. The matching paragraph is returned as the citing context. - -算子提取figure环境,解析其标题、标签和图片路径,然后在文档段落中(去除浮动/展示/表格环境后)搜索包含`\ref{fig:arch}`的段落。匹配的段落作为引用上下文返回。 - -## 🔗 related links 相关链接 -- [source code 源代码](../../../data_juicer/ops/mapper/latex_figure_context_extractor_mapper.py) -- [unit test 单元测试](../../../tests/ops/mapper/test_latex_figure_context_extractor_mapper.py) -- [Return operator list 返回算子列表](../../Operators.md) diff --git a/docs/operators/mapper/latex_merge_tex_mapper.md b/docs/operators/mapper/latex_merge_tex_mapper.md deleted file mode 100644 index 65ed1642f6..0000000000 --- a/docs/operators/mapper/latex_merge_tex_mapper.md +++ /dev/null @@ -1,31 +0,0 @@ -# latex_merge_tex_mapper - -Extracts and concatenates all `.tex` files from a compressed LaTeX project archive into a single text field. - -Supported archive formats: `.tar`, `.tar.gz` / `.tgz`, and `.zip`. Plain `.gz` (single-file gzip) is not supported because gzip archives carry no filename metadata, making it impossible to verify that the content is actually a `.tex` file. All `.tex` files found inside the archive are read in-memory and joined with a configurable separator. No ordering or deduplication is applied. - -This operator is typically placed before LaTeX-processing operators such as `remove_comments_mapper`, `expand_macro_mapper`, or `latex_figure_context_extractor_mapper`. - -从压缩的 LaTeX 项目归档文件中提取并拼接所有 `.tex` 文件到一个文本字段中。 - -支持的归档格式:`.tar`、`.tar.gz` / `.tgz` 以及 `.zip`。不支持单独的 `.gz`(单文件 gzip),因为 gzip 格式不包含文件名元数据,无法验证内容是否为 `.tex` 文件。归档中所有 `.tex` 文件会被读入内存,并使用可配置的分隔符拼接。不会进行排序或去重。 - -该算子通常放置在 LaTeX 处理算子(如 `remove_comments_mapper`、`expand_macro_mapper` 或 `latex_figure_context_extractor_mapper`)之前。 - -Type 算子类型: **mapper** - -Tags 标签: cpu, text - -## 🔧 Parameter Configuration 参数配置 - -| name 参数名 | type 类型 | default 默认值 | desc 说明 | -|--------|------|--------|------| -| `compressed_file_key` | `str` | `'compressed_file'` | Field name that stores the archive file path. 存储归档文件路径的字段名。 | -| `separator` | `str` | `'\n\n'` | String used to join the contents of multiple `.tex` files. 用于拼接多个 `.tex` 文件内容的分隔符。 | -| `max_file_size` | `int` | `52428800` (50 MB) | Maximum allowed uncompressed size in bytes for a single `.tex` entry inside the archive. Entries exceeding this limit are skipped with a warning. Set to `0` to disable the check. 单个 `.tex` 条目允许的最大解压大小(字节)。超过此限制的条目将被跳过并输出警告。设为 `0` 可禁用检查。 | - -## 🔗 Related links 相关链接 - -- [source code 源代码](../../../data_juicer/ops/mapper/latex_merge_tex_mapper.py) -- [unit test 单元测试](../../../tests/ops/mapper/test_latex_merge_tex_mapper.py) -- [Return operator list 返回算子列表](../../Operators.md) diff --git a/tests/format/test_json_formatter.py b/tests/format/test_json_formatter.py index 1628529188..826c4f76d0 100644 --- a/tests/format/test_json_formatter.py +++ b/tests/format/test_json_formatter.py @@ -1,68 +1,32 @@ import os import unittest -import gzip -import tempfile -import shutil from data_juicer.format.json_formatter import JsonFormatter from data_juicer.format.load import load_formatter from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -try: - import zstandard as zstd # type: ignore - - HAS_ZSTD = True -except Exception: - zstd = None - HAS_ZSTD = False - class JsonFormatterTest(DataJuicerTestCaseBase): def setUp(self): super().setUp() - self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data", "structured") - self._file = os.path.join(self._path, "demo-dataset.jsonl") + self._path = os.path.join(os.path.dirname(os.path.realpath(__file__)), + 'data', 'structured') + self._file = os.path.join(self._path, 'demo-dataset.jsonl') print(self._file) - # create compressed variants for testing - # create a temp directory to hold generated compressed files - self._temp_dir = tempfile.mkdtemp() - with open(self._file, "rb") as f: - raw = f.read() - - # .jsonl.gz - self._jsonl_gz = os.path.join(self._temp_dir, "demo-dataset.jsonl.gz") - with gzip.open(self._jsonl_gz, "wb") as f: - f.write(raw) - - # .json.gz (same content, different suffix) - self._json_gz = os.path.join(self._temp_dir, "demo-dataset.json.gz") - with gzip.open(self._json_gz, "wb") as f: - f.write(raw) - - # .json.zst and .jsonl.zst if zstandard available - if HAS_ZSTD: - self._jsonl_zst = os.path.join(self._temp_dir, "demo-dataset.jsonl.zst") - self._json_zst = os.path.join(self._temp_dir, "demo-dataset.json.zst") - cctx = zstd.ZstdCompressor() - compressed = cctx.compress(raw) - with open(self._jsonl_zst, "wb") as f: - f.write(compressed) - with open(self._json_zst, "wb") as f: - f.write(compressed) def test_json_file(self): formatter = JsonFormatter(self._file) ds = formatter.load_dataset() self.assertEqual(len(ds), 6) - self.assertEqual(list(ds.features.keys()), ["text", "meta"]) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) def test_json_path(self): formatter = JsonFormatter(self._path) ds = formatter.load_dataset() self.assertEqual(len(ds), 6) - self.assertEqual(list(ds.features.keys()), ["text", "meta"]) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) def test_load_formatter_with_file(self): """Test load_formatter with a direct file path""" @@ -70,48 +34,16 @@ def test_load_formatter_with_file(self): self.assertIsInstance(formatter, JsonFormatter) ds = formatter.load_dataset() self.assertEqual(len(ds), 6) - self.assertEqual(list(ds.features.keys()), ["text", "meta"]) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) def test_load_formatter_with_specified_suffix(self): """Test load_formatter with specified suffixes""" - formatter = load_formatter(self._path, suffixes=[".jsonl"]) + formatter = load_formatter(self._path, suffixes=['.jsonl']) self.assertIsInstance(formatter, JsonFormatter) ds = formatter.load_dataset() self.assertEqual(len(ds), 6) - self.assertEqual(list(ds.features.keys()), ["text", "meta"]) - - def tearDown(self): - # cleanup temp dir and files - if hasattr(self, "_temp_dir") and os.path.exists(self._temp_dir): - shutil.rmtree(self._temp_dir) - super().tearDown() - - def test_jsonl_gz_file(self): - formatter = JsonFormatter(self._jsonl_gz) - ds = formatter.load_dataset() - self.assertEqual(len(ds), 6) - self.assertEqual(list(ds.features.keys()), ["text", "meta"]) - - def test_json_gz_file(self): - formatter = JsonFormatter(self._json_gz) - ds = formatter.load_dataset() - self.assertEqual(len(ds), 6) - self.assertEqual(list(ds.features.keys()), ["text", "meta"]) - - @unittest.skipUnless(HAS_ZSTD, "zstandard not installed") - def test_json_zst_file(self): - formatter = JsonFormatter(self._json_zst) - ds = formatter.load_dataset() - self.assertEqual(len(ds), 6) - self.assertEqual(list(ds.features.keys()), ["text", "meta"]) - - @unittest.skipUnless(HAS_ZSTD, "zstandard not installed") - def test_jsonl_zst_file(self): - formatter = JsonFormatter(self._jsonl_zst) - ds = formatter.load_dataset() - self.assertEqual(len(ds), 6) - self.assertEqual(list(ds.features.keys()), ["text", "meta"]) + self.assertEqual(list(ds.features.keys()), ['text', 'meta']) -if __name__ == "__main__": - unittest.main() +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/ops/data/video13.mp4 b/tests/ops/data/video13.mp4 new file mode 100755 index 0000000000..348986e61b Binary files /dev/null and b/tests/ops/data/video13.mp4 differ diff --git a/tests/ops/data/video14.mp4 b/tests/ops/data/video14.mp4 new file mode 100755 index 0000000000..9135af68d1 Binary files /dev/null and b/tests/ops/data/video14.mp4 differ diff --git a/tests/ops/data/video15.mp4 b/tests/ops/data/video15.mp4 new file mode 100755 index 0000000000..5125ab8892 Binary files /dev/null and b/tests/ops/data/video15.mp4 differ diff --git a/tests/ops/mapper/test_latex_figure_context_extractor_mapper.py b/tests/ops/mapper/test_latex_figure_context_extractor_mapper.py deleted file mode 100644 index 1f752b67bd..0000000000 --- a/tests/ops/mapper/test_latex_figure_context_extractor_mapper.py +++ /dev/null @@ -1,556 +0,0 @@ -import unittest - -from data_juicer.core.data import NestedDataset as Dataset -from data_juicer.ops.mapper.latex_figure_context_extractor_mapper import ( - LatexFigureContextExtractorMapper, -) -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase - - -class LatexFigureContextExtractorMapperTest(DataJuicerTestCaseBase): - - def setUp(self): - super().setUp() - self.op = LatexFigureContextExtractorMapper() - - def _run_mapper(self, samples): - """Helper: run the batched mapper on a list of dicts.""" - dataset = Dataset.from_list(samples) - dataset = dataset.map( - self.op.process, batch_size=len(samples), - ) - return dataset.to_list() - - # ------------------------------------------------------------------ - # 1. Single figure with caption, label, and \includegraphics - # ------------------------------------------------------------------ - def test_single_figure(self): - latex = ( - '\\begin{document}\n' - 'Some intro text.\n\n' - 'As shown in \\ref{fig:arch}, the architecture is novel.\n\n' - '\\begin{figure}\n' - '\\centering\n' - '\\includegraphics[width=0.8\\linewidth]{img/arch.pdf}\n' - '\\caption{Overall architecture}\n' - '\\label{fig:arch}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 1) - self.assertEqual(results[0]['caption'], 'Overall architecture') - self.assertEqual(results[0]['label'], 'fig:arch') - self.assertEqual(results[0]['images'], ['img/arch.pdf']) - self.assertEqual(len(results[0]['citing_paragraphs']), 1) - self.assertIn('\\ref{fig:arch}', - results[0]['citing_paragraphs'][0]) - # Standalone figure: parent fields are empty - self.assertEqual(results[0]['parent_caption'], '') - self.assertEqual(results[0]['parent_label'], '') - - # ------------------------------------------------------------------ - # 2. Figure with \begin{subfigure} environments (modern subcaption) - # ------------------------------------------------------------------ - def test_subfigure_environments(self): - latex = ( - '\\begin{document}\n' - 'See \\cref{fig:main} for details.\n\n' - 'Also \\ref{fig:sub_b} is interesting.\n\n' - '\\begin{figure}\n' - '\\centering\n' - '\\begin{subfigure}\n' - '\\includegraphics{img/a.png}\n' - '\\caption{Sub A}\n' - '\\label{fig:sub_a}\n' - '\\end{subfigure}\n' - '\\begin{subfigure}\n' - '\\includegraphics{img/b.png}\n' - '\\caption{Sub B}\n' - '\\label{fig:sub_b}\n' - '\\end{subfigure}\n' - '\\caption{Main caption}\n' - '\\label{fig:main}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - # Should produce 2 rows (one per subfigure) - self.assertEqual(len(results), 2) - - # Sub A - self.assertEqual(results[0]['caption'], 'Sub A') - self.assertEqual(results[0]['label'], 'fig:sub_a') - self.assertEqual(results[0]['images'], ['img/a.png']) - # Sub A inherits parent-level \cref{fig:main} citation - self.assertTrue( - any('\\cref{fig:main}' in p - for p in results[0]['citing_paragraphs']) - ) - # Sub A carries parent info - self.assertEqual(results[0]['parent_caption'], 'Main caption') - self.assertEqual(results[0]['parent_label'], 'fig:main') - - # Sub B - self.assertEqual(results[1]['caption'], 'Sub B') - self.assertEqual(results[1]['label'], 'fig:sub_b') - self.assertEqual(results[1]['images'], ['img/b.png']) - # Sub B has both parent citation and its own \ref{fig:sub_b} - contexts_b = results[1]['citing_paragraphs'] - self.assertTrue( - any('\\cref{fig:main}' in p for p in contexts_b) - ) - self.assertTrue( - any('\\ref{fig:sub_b}' in p for p in contexts_b) - ) - # Sub B also carries same parent info - self.assertEqual(results[1]['parent_caption'], 'Main caption') - self.assertEqual(results[1]['parent_label'], 'fig:main') - - # ------------------------------------------------------------------ - # 3. Figure with \subfigure[]{} commands (older subfig package) - # ------------------------------------------------------------------ - def test_subfigure_commands(self): - latex = ( - '\\begin{document}\n' - 'Refer to \\ref{fig:old}.\n\n' - '\\begin{figure}\n' - '\\centering\n' - '\\subfigure[Caption X]{\n' - ' \\includegraphics{img/x.pdf}\n' - ' \\label{fig:x}\n' - '}\n' - '\\subfigure[Caption Y]{\n' - ' \\includegraphics{img/y.pdf}\n' - ' \\label{fig:y}\n' - '}\n' - '\\caption{Old style}\n' - '\\label{fig:old}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 2) - self.assertEqual(results[0]['caption'], 'Caption X') - self.assertEqual(results[0]['images'], ['img/x.pdf']) - self.assertEqual(results[0]['parent_caption'], 'Old style') - self.assertEqual(results[0]['parent_label'], 'fig:old') - self.assertEqual(results[1]['caption'], 'Caption Y') - self.assertEqual(results[1]['images'], ['img/y.pdf']) - self.assertEqual(results[1]['parent_caption'], 'Old style') - self.assertEqual(results[1]['parent_label'], 'fig:old') - - # ------------------------------------------------------------------ - # 4. Figure without \includegraphics is skipped - # ------------------------------------------------------------------ - def test_figure_without_images_skipped(self): - latex = ( - '\\begin{document}\n' - '\\begin{figure}\n' - '\\caption{No image here}\n' - '\\label{fig:empty}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 0) - - # ------------------------------------------------------------------ - # 5. No figures at all — sample is dropped - # ------------------------------------------------------------------ - def test_no_figures_drops_sample(self): - latex = ( - '\\begin{document}\n' - 'Just some text, no figures.\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 0) - - # ------------------------------------------------------------------ - # 6. Citation via comma-separated \cref{fig:a,fig:b} - # ------------------------------------------------------------------ - def test_comma_separated_cref(self): - latex = ( - '\\begin{document}\n' - 'See \\cref{fig:a,fig:b} for comparison.\n\n' - '\\begin{figure}\n' - '\\includegraphics{img/a.png}\n' - '\\caption{Figure A}\n' - '\\label{fig:a}\n' - '\\end{figure}\n' - '\\begin{figure}\n' - '\\includegraphics{img/b.png}\n' - '\\caption{Figure B}\n' - '\\label{fig:b}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 2) - # Both figures should find the paragraph with \cref{fig:a,fig:b} - for r in results: - self.assertTrue( - any('\\cref{fig:a,fig:b}' in p - for p in r['citing_paragraphs']) - ) - - # ------------------------------------------------------------------ - # 7. Multiple figures in one document - # ------------------------------------------------------------------ - def test_multiple_figures(self): - latex = ( - '\\begin{document}\n' - '\\begin{figure}\n' - '\\includegraphics{img/one.png}\n' - '\\caption{First}\n' - '\\label{fig:one}\n' - '\\end{figure}\n' - '\\begin{figure}\n' - '\\includegraphics{img/two.png}\n' - '\\caption{Second}\n' - '\\label{fig:two}\n' - '\\end{figure}\n' - '\\begin{figure*}\n' - '\\includegraphics{img/three.png}\n' - '\\caption{Third wide}\n' - '\\label{fig:three}\n' - '\\end{figure*}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 3) - captions = [r['caption'] for r in results] - self.assertEqual(captions, ['First', 'Second', 'Third wide']) - - # ------------------------------------------------------------------ - # 8. figure* environment is recognized - # ------------------------------------------------------------------ - def test_figure_star(self): - latex = ( - '\\begin{document}\n' - '\\begin{figure*}\n' - '\\includegraphics{img/wide.png}\n' - '\\caption{Wide figure}\n' - '\\label{fig:wide}\n' - '\\end{figure*}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 1) - self.assertEqual(results[0]['caption'], 'Wide figure') - - # ------------------------------------------------------------------ - # 9. wrapfigure environment - # ------------------------------------------------------------------ - def test_wrapfigure(self): - latex = ( - '\\begin{document}\n' - '\\begin{wrapfigure}{r}{0.5\\textwidth}\n' - '\\includegraphics{img/wrap.png}\n' - '\\caption{Wrapped}\n' - '\\label{fig:wrap}\n' - '\\end{wrapfigure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 1) - self.assertEqual(results[0]['caption'], 'Wrapped') - self.assertEqual(results[0]['label'], 'fig:wrap') - - # ------------------------------------------------------------------ - # 10. Nested caption braces (e.g. \textbf{...} inside caption) - # ------------------------------------------------------------------ - def test_nested_caption_braces(self): - latex = ( - '\\begin{document}\n' - '\\begin{figure}\n' - '\\includegraphics{img/nest.png}\n' - '\\caption{A \\textbf{bold \\emph{italic}} caption}\n' - '\\label{fig:nest}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 1) - self.assertIn('\\textbf{bold \\emph{italic}}', - results[0]['caption']) - - # ------------------------------------------------------------------ - # 10b. Deeply nested caption braces (5+ levels) - # ------------------------------------------------------------------ - def test_deeply_nested_caption_braces(self): - latex = ( - '\\begin{document}\n' - '\\begin{figure}\n' - '\\includegraphics{img/deep.png}\n' - '\\caption{A \\textbf{B \\emph{C \\footnote{D \\cite{E}}}}}\n' - '\\label{fig:deep}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 1) - self.assertIn( - '\\textbf{B \\emph{C \\footnote{D \\cite{E}}}}', - results[0]['caption'], - ) - - # ------------------------------------------------------------------ - # 11. \captionof{figure}{...} is recognized - # ------------------------------------------------------------------ - def test_captionof(self): - latex = ( - '\\begin{document}\n' - '\\begin{figure}\n' - '\\includegraphics{img/cof.png}\n' - '\\captionof{figure}{Caption via captionof}\n' - '\\label{fig:cof}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 1) - self.assertEqual(results[0]['caption'], - 'Caption via captionof') - - # ------------------------------------------------------------------ - # 12. \autoref citation command - # ------------------------------------------------------------------ - def test_autoref(self): - latex = ( - '\\begin{document}\n' - 'See \\autoref{fig:auto} for details.\n\n' - '\\begin{figure}\n' - '\\includegraphics{img/auto.png}\n' - '\\caption{Auto}\n' - '\\label{fig:auto}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 1) - self.assertEqual(len(results[0]['citing_paragraphs']), 1) - - # ------------------------------------------------------------------ - # 13. Label fig:a must not false-match fig:ab - # ------------------------------------------------------------------ - def test_label_boundary(self): - latex = ( - '\\begin{document}\n' - 'See \\ref{fig:ab} for details.\n\n' - '\\begin{figure}\n' - '\\includegraphics{img/a.png}\n' - '\\caption{A}\n' - '\\label{fig:a}\n' - '\\end{figure}\n' - '\\begin{figure}\n' - '\\includegraphics{img/ab.png}\n' - '\\caption{AB}\n' - '\\label{fig:ab}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 2) - # fig:a should NOT match the paragraph citing \ref{fig:ab} - fig_a = [r for r in results if r['label'] == 'fig:a'][0] - self.assertEqual(fig_a['citing_paragraphs'], []) - # fig:ab should match - fig_ab = [r for r in results if r['label'] == 'fig:ab'][0] - self.assertEqual(len(fig_ab['citing_paragraphs']), 1) - - # ------------------------------------------------------------------ - # 14. Custom output keys - # ------------------------------------------------------------------ - def test_custom_keys(self): - op = LatexFigureContextExtractorMapper( - caption_key='fig_caption', - label_key='fig_label', - context_key='fig_context', - ) - latex = ( - '\\begin{document}\n' - '\\begin{figure}\n' - '\\includegraphics{img/custom.png}\n' - '\\caption{Custom}\n' - '\\label{fig:custom}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - dataset = Dataset.from_list([{'text': latex}]) - dataset = dataset.map(op.process, batch_size=1) - results = dataset.to_list() - self.assertEqual(len(results), 1) - self.assertIn('fig_caption', results[0]) - self.assertIn('fig_label', results[0]) - self.assertIn('fig_context', results[0]) - self.assertEqual(results[0]['fig_caption'], 'Custom') - - # ------------------------------------------------------------------ - # 15. Multiple samples in one batch (fan-out + drop) - # ------------------------------------------------------------------ - def test_batch_mixed(self): - latex_with_fig = ( - '\\begin{document}\n' - '\\begin{figure}\n' - '\\includegraphics{img/ok.png}\n' - '\\caption{OK}\n' - '\\label{fig:ok}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - latex_no_fig = ( - '\\begin{document}\n' - 'No figures here.\n' - '\\end{document}\n' - ) - results = self._run_mapper([ - {'text': latex_with_fig}, - {'text': latex_no_fig}, - ]) - # Only the first sample produces output - self.assertEqual(len(results), 1) - self.assertEqual(results[0]['caption'], 'OK') - - - # ------------------------------------------------------------------ - # 16. \caption[short]{long} — optional short caption - # ------------------------------------------------------------------ - def test_caption_with_short_form(self): - latex = ( - '\\begin{document}\n' - '\\begin{figure}\n' - '\\includegraphics{img/short.png}\n' - '\\caption[Short form]{Long detailed caption}\n' - '\\label{fig:short}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 1) - self.assertEqual(results[0]['caption'], 'Long detailed caption') - - # ------------------------------------------------------------------ - # 17. \subcaption[short]{long} — optional short caption - # ------------------------------------------------------------------ - def test_subcaption_with_short_form(self): - latex = ( - '\\begin{document}\n' - '\\begin{figure}\n' - '\\centering\n' - '\\begin{subfigure}\n' - '\\includegraphics{img/sc.png}\n' - '\\subcaption[Short sub]{Long subcaption text}\n' - '\\label{fig:sc}\n' - '\\end{subfigure}\n' - '\\caption{Parent}\n' - '\\label{fig:parent_sc}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 1) - self.assertEqual(results[0]['caption'], 'Long subcaption text') - self.assertEqual(results[0]['parent_caption'], 'Parent') - - # ------------------------------------------------------------------ - # 18. \captionof{figure}[short]{long} — optional short caption - # ------------------------------------------------------------------ - def test_captionof_with_short_form(self): - latex = ( - '\\begin{document}\n' - '\\begin{figure}\n' - '\\includegraphics{img/cof2.png}\n' - '\\captionof{figure}[Short cof]{Long captionof text}\n' - '\\label{fig:cof2}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 1) - self.assertEqual(results[0]['caption'], 'Long captionof text') - - - # ------------------------------------------------------------------ - # 19. \subfloat[caption]{content} — subfig package - # ------------------------------------------------------------------ - def test_subfloat_with_caption(self): - latex = ( - '\\begin{document}\n' - 'See \\ref{fig:sf_parent} for comparison.\n\n' - '\\begin{figure}\n' - '\\centering\n' - '\\subfloat[Float A]{\n' - ' \\includegraphics{img/fa.png}\n' - ' \\label{fig:fa}\n' - '}\n' - '\\subfloat[Float B]{\n' - ' \\includegraphics{img/fb.png}\n' - ' \\label{fig:fb}\n' - '}\n' - '\\caption{Subfloat parent}\n' - '\\label{fig:sf_parent}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 2) - self.assertEqual(results[0]['caption'], 'Float A') - self.assertEqual(results[0]['images'], ['img/fa.png']) - self.assertEqual(results[0]['parent_caption'], 'Subfloat parent') - self.assertEqual(results[0]['parent_label'], 'fig:sf_parent') - self.assertEqual(results[1]['caption'], 'Float B') - self.assertEqual(results[1]['images'], ['img/fb.png']) - self.assertEqual(results[1]['parent_caption'], 'Subfloat parent') - # Parent citation inherited - self.assertTrue( - any('\\ref{fig:sf_parent}' in p - for p in results[0]['citing_paragraphs']) - ) - - # ------------------------------------------------------------------ - # 20. \subfloat{content} — without optional caption - # ------------------------------------------------------------------ - def test_subfloat_without_caption(self): - latex = ( - '\\begin{document}\n' - '\\begin{figure}\n' - '\\centering\n' - '\\subfloat{\n' - ' \\includegraphics{img/nc.png}\n' - ' \\label{fig:nc}\n' - '}\n' - '\\caption{No-caption subfloats}\n' - '\\label{fig:nc_parent}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 1) - self.assertEqual(results[0]['caption'], '') - self.assertEqual(results[0]['images'], ['img/nc.png']) - self.assertEqual(results[0]['parent_caption'], - 'No-caption subfloats') - - # ------------------------------------------------------------------ - # 21. \caption*{...} — unnumbered caption - # ------------------------------------------------------------------ - def test_caption_star(self): - latex = ( - '\\begin{document}\n' - '\\begin{figure}\n' - '\\includegraphics{img/star.png}\n' - '\\caption*{Unnumbered caption text}\n' - '\\label{fig:star}\n' - '\\end{figure}\n' - '\\end{document}\n' - ) - results = self._run_mapper([{'text': latex}]) - self.assertEqual(len(results), 1) - self.assertEqual(results[0]['caption'], - 'Unnumbered caption text') - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/ops/mapper/test_latex_merge_tex_mapper.py b/tests/ops/mapper/test_latex_merge_tex_mapper.py deleted file mode 100644 index 151539051f..0000000000 --- a/tests/ops/mapper/test_latex_merge_tex_mapper.py +++ /dev/null @@ -1,225 +0,0 @@ -import io -import os -import shutil -import tarfile -import tempfile -import unittest -import zipfile - -from data_juicer.core.data import NestedDataset as Dataset -from data_juicer.ops.mapper.latex_merge_tex_mapper import LatexMergeTexMapper -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase - -TEX_A = r"""\documentclass{article} -\begin{document} -Hello from main. -\end{document} -""" - -TEX_B = r"""\section{Intro} -Some intro text. -""" - -TEX_C = r"""\section{Method} -Some method text. -""" - - -class LatexMergeTexMapperTest(DataJuicerTestCaseBase): - - def setUp(self): - self._tmpdir = tempfile.mkdtemp() - - def tearDown(self): - shutil.rmtree(self._tmpdir, ignore_errors=True) - - def _make_tar_gz(self, files, name="proj.tar.gz"): - """Create a tar.gz with str values (UTF-8 encoded) or raw bytes.""" - path = os.path.join(self._tmpdir, name) - with tarfile.open(path, "w:gz") as tf: - for fname, content in files.items(): - data = content.encode("utf-8") \ - if isinstance(content, str) else content - info = tarfile.TarInfo(name=fname) - info.size = len(data) - tf.addfile(info, io.BytesIO(data)) - return path - - def _make_tar(self, files, name="proj.tar"): - """Create a plain (uncompressed) tar archive.""" - path = os.path.join(self._tmpdir, name) - with tarfile.open(path, "w") as tf: - for fname, content in files.items(): - data = content.encode("utf-8") \ - if isinstance(content, str) else content - info = tarfile.TarInfo(name=fname) - info.size = len(data) - tf.addfile(info, io.BytesIO(data)) - return path - - def _make_zip(self, files, name="proj.zip"): - """Create a zip with str values (UTF-8 encoded) or raw bytes.""" - path = os.path.join(self._tmpdir, name) - with zipfile.ZipFile(path, "w") as zf: - for fname, content in files.items(): - zf.writestr(fname, content) - return path - - def _run(self, samples, op): - dataset = Dataset.from_list(samples) - dataset = dataset.map(op.process) - return list(dataset) - - def _assert_joined_in_either_order(self, result, part_a, part_b, sep): - """Assert *result* equals part_a + sep + part_b (in either order).""" - option1 = part_a + sep + part_b - option2 = part_b + sep + part_a - self.assertTrue( - result == option1 or result == option2, - f"Result does not match either ordering.\n" - f"Got:\n{result!r}\n" - f"Expected one of:\n{option1!r}\n--- or ---\n{option2!r}" - ) - - def _sample(self, archive_path): - """Build a sample dict with separate compressed_file and text keys.""" - return {"compressed_file": archive_path, "text": ""} - - def test_tar_gz_multiple_tex(self): - archive = self._make_tar_gz({ - "main.tex": TEX_A, - "intro.tex": TEX_B, - }) - results = self._run([self._sample(archive)], LatexMergeTexMapper()) - self._assert_joined_in_either_order( - results[0]["text"], TEX_A, TEX_B, "\n\n") - - def test_zip_multiple_tex(self): - archive = self._make_zip({ - "main.tex": TEX_A, - "method.tex": TEX_C, - }) - results = self._run([self._sample(archive)], LatexMergeTexMapper()) - self._assert_joined_in_either_order( - results[0]["text"], TEX_A, TEX_C, "\n\n") - - def test_plain_tar_multiple_tex(self): - archive = self._make_tar({ - "main.tex": TEX_A, - "intro.tex": TEX_B, - }) - results = self._run([self._sample(archive)], LatexMergeTexMapper()) - self._assert_joined_in_either_order( - results[0]["text"], TEX_A, TEX_B, "\n\n") - - def test_tgz_multiple_tex(self): - archive = self._make_tar_gz({ - "main.tex": TEX_A, - "method.tex": TEX_C, - }, name="proj.tgz") - results = self._run([self._sample(archive)], LatexMergeTexMapper()) - self._assert_joined_in_either_order( - results[0]["text"], TEX_A, TEX_C, "\n\n") - - def test_unsupported_extension(self): - path = os.path.join(self._tmpdir, "paper.gz") - with open(path, "wb") as f: - f.write(b"dummy") - results = self._run([self._sample(path)], LatexMergeTexMapper()) - self.assertEqual(results[0]["text"], "") - - def test_no_tex_in_archive(self): - archive = self._make_tar_gz({ - "readme.md": "# Hello", - "fig.png": "fake-png-bytes", - }) - results = self._run([self._sample(archive)], LatexMergeTexMapper()) - self.assertEqual(results[0]["text"], "") - - def test_custom_separator(self): - archive = self._make_tar_gz({ - "a.tex": TEX_A, - "b.tex": TEX_B, - }) - sep = "\n%%% FILE BOUNDARY %%%\n" - op = LatexMergeTexMapper(separator=sep) - results = self._run([self._sample(archive)], op) - self._assert_joined_in_either_order( - results[0]["text"], TEX_A, TEX_B, sep) - - def test_custom_compressed_file_key(self): - archive = self._make_tar_gz({ - "main.tex": TEX_A, - }) - samples = [{"text": "", "archive_path": archive}] - op = LatexMergeTexMapper(compressed_file_key="archive_path") - results = self._run(samples, op) - self.assertIn(TEX_A.strip(), results[0]["text"]) - - def test_multiple_samples(self): - archive1 = self._make_tar_gz( - {"a.tex": TEX_A}, name="p1.tar.gz") - archive2 = self._make_zip( - {"b.tex": TEX_B}, name="p2.zip") - samples = [self._sample(archive1), self._sample(archive2)] - results = self._run(samples, LatexMergeTexMapper()) - self.assertIn(TEX_A.strip(), results[0]["text"]) - self.assertIn(TEX_B.strip(), results[1]["text"]) - - def test_invalid_path(self): - results = self._run( - [self._sample("/nonexistent/path/foo.tar.gz")], - LatexMergeTexMapper()) - self.assertEqual(results[0]["text"], "") - - def test_latin1_encoding_tar(self): - latin1_tex = b"\\section{R\xe9sum\xe9}\nCaf\xe9 na\xefve \xfcber.\n" - archive = self._make_tar_gz( - {"paper.tex": latin1_tex}, name="latin1.tar.gz") - results = self._run([self._sample(archive)], LatexMergeTexMapper()) - text = results[0]["text"] - self.assertIn("\\section{R", text) - self.assertIn("\ufffd", text, - "Non-UTF-8 bytes should be replaced with U+FFFD") - self.assertNotIn("\xe9", text, - "Raw Latin-1 chars must not survive as-is") - - def test_latin1_encoding_zip(self): - latin1_tex = b"\\begin{document}\nStra\xdfe na\xefve.\n\\end{document}\n" - archive = self._make_zip( - {"doc.tex": latin1_tex}, name="latin1.zip") - results = self._run([self._sample(archive)], LatexMergeTexMapper()) - text = results[0]["text"] - self.assertIn("\\begin{document}", text) - self.assertIn("\ufffd", text, - "Non-UTF-8 bytes should be replaced with U+FFFD") - self.assertNotIn("\xdf", text, - "Raw Latin-1 chars must not survive as-is") - - def test_max_file_size_tar(self): - small = "ok" - big = "x" * 200 - archive = self._make_tar_gz({ - "small.tex": small, - "big.tex": big, - }, name="size_limit.tar.gz") - op = LatexMergeTexMapper(max_file_size=100) - results = self._run([self._sample(archive)], op) - self.assertIn(small, results[0]["text"]) - self.assertNotIn(big, results[0]["text"]) - - def test_max_file_size_zip(self): - small = "ok" - big = "x" * 200 - archive = self._make_zip({ - "small.tex": small, - "big.tex": big, - }, name="size_limit.zip") - op = LatexMergeTexMapper(max_file_size=100) - results = self._run([self._sample(archive)], op) - self.assertIn(small, results[0]["text"]) - self.assertNotIn(big, results[0]["text"]) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/ops/mapper/test_nlpaug_en_mapper_batch_bug.py b/tests/ops/mapper/test_nlpaug_en_mapper_batch_bug.py deleted file mode 100644 index 7c650f14ae..0000000000 --- a/tests/ops/mapper/test_nlpaug_en_mapper_batch_bug.py +++ /dev/null @@ -1,104 +0,0 @@ -# Reproduction test for batch processing bug in NlpaugEnMapper -# Bug: samples[self.text_key][0] only processes the first sample in a batch. -# In production (batch_size=1000), 999 out of 1000 samples are never augmented. - -import unittest - -from data_juicer.ops.mapper.nlpaug_en_mapper import NlpaugEnMapper - - -class NlpaugEnMapperBatchBugTest(unittest.TestCase): - """Demonstrate that process_batched only augments the first sample.""" - - def test_batch_only_augments_first_sample(self): - """With batch_size > 1, only the first sample should be augmented - (before the fix). After the fix, all samples should be augmented.""" - op = NlpaugEnMapper( - sequential=False, - aug_num=1, - keep_original_sample=True, - delete_random_word=True, - ) - - # Simulate a batch of 3 samples (dict-of-lists format) - samples = { - 'text': [ - 'The quick brown fox jumps over the lazy dog', - 'Machine learning is transforming the world today', - 'Natural language processing enables computers to understand text', - ], - 'meta': ['meta1', 'meta2', 'meta3'], - } - - result = op.process_batched(samples) - - # With 3 input samples, 1 aug method, aug_num=1, keep_original=True: - # Each sample should produce 1 original + 1 augmented = 2 texts per sample - # Total expected: 3 originals + 3 augmented = 6 - num_texts = len(result['text']) - num_metas = len(result['meta']) - - # Assert that ALL 3 original texts are present in the output - for original_text in samples['text']: - self.assertIn(original_text, result['text'], - f"Original text missing from output: {original_text}") - - # Assert correct total count: 3 originals + 3 augmented = 6 - self.assertEqual(num_texts, 6, - f"Expected 6 texts (3 original + 3 augmented), got {num_texts}") - self.assertEqual(num_metas, num_texts, - f"Meta count ({num_metas}) should match text count ({num_texts})") - - def test_batch_without_keep_original(self): - """Without keeping originals, all samples should still be augmented.""" - op = NlpaugEnMapper( - sequential=False, - aug_num=1, - keep_original_sample=False, - delete_random_word=True, - ) - - samples = { - 'text': [ - 'The quick brown fox jumps over the lazy dog', - 'Machine learning is transforming the world today', - 'Natural language processing enables computers to understand text', - ], - 'meta': ['meta1', 'meta2', 'meta3'], - } - - result = op.process_batched(samples) - num_texts = len(result['text']) - - # Should have exactly 3 augmented texts (one per input sample) - self.assertEqual(num_texts, 3, - f"Expected 3 augmented texts, got {num_texts}") - - def test_batch_sequential_mode(self): - """Sequential mode should also process all samples in the batch.""" - op = NlpaugEnMapper( - sequential=True, - aug_num=2, - keep_original_sample=True, - delete_random_word=True, - swap_random_char=True, - ) - - samples = { - 'text': [ - 'The quick brown fox jumps over the lazy dog', - 'Machine learning is transforming the world today', - ], - 'meta': ['meta1', 'meta2'], - } - - result = op.process_batched(samples) - num_texts = len(result['text']) - - # 2 originals + 2 samples * 2 aug_num = 6 - self.assertEqual(num_texts, 6, - f"Expected 6 texts (2 original + 4 augmented), got {num_texts}") - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/utils/test_file_utils.py b/tests/utils/test_file_utils.py index dd6844037d..721870ba2c 100644 --- a/tests/utils/test_file_utils.py +++ b/tests/utils/test_file_utils.py @@ -1,56 +1,55 @@ import os import unittest import regex as re -import gzip from data_juicer.utils.file_utils import ( - find_files_with_suffix, - is_absolute_path, - add_suffix_to_filename, - create_directory_if_not_exists, - transfer_filename, - copy_data, + find_files_with_suffix, is_absolute_path, + add_suffix_to_filename, create_directory_if_not_exists, transfer_filename, + copy_data ) from data_juicer.utils.mm_utils import Fields from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase - class FileUtilsTest(DataJuicerTestCaseBase): def setUp(self) -> None: super().setUp() - self.temp_output_path = "tmp/test_file_utils/" + self.temp_output_path = 'tmp/test_file_utils/' os.makedirs(self.temp_output_path) def tearDown(self): if os.path.exists(self.temp_output_path): - os.system(f"rm -rf {self.temp_output_path}") + os.system(f'rm -rf {self.temp_output_path}') super().tearDown() def test_find_files_with_suffix(self): # prepare test files - fn_list = ["test1.txt", "test2.txt", "test3.md"] + fn_list = ['test1.txt', 'test2.txt', 'test3.md'] for fn in fn_list: - with open(os.path.join(self.temp_output_path, fn), "w") as f: + with open(os.path.join(self.temp_output_path, fn), 'w') as f: f.write(fn) - self.assertEqual( - find_files_with_suffix(os.path.join(self.temp_output_path, "test1.txt")), - {".txt": [os.path.join(self.temp_output_path, "test1.txt")]}, - ) + self.assertEqual(find_files_with_suffix(os.path.join(self.temp_output_path, 'test1.txt')), + {'.txt': [os.path.join(self.temp_output_path, 'test1.txt')]}) result = find_files_with_suffix(self.temp_output_path) expected = { - ".txt": sorted([os.path.join(self.temp_output_path, "test1.txt"), os.path.join(self.temp_output_path, "test2.txt")]), - ".md": [os.path.join(self.temp_output_path, "test3.md")], + '.txt': sorted([ + os.path.join(self.temp_output_path, 'test1.txt'), + os.path.join(self.temp_output_path, 'test2.txt') + ]), + '.md': [os.path.join(self.temp_output_path, 'test3.md')] } for suffix in result: result[suffix] = sorted(result[suffix]) self.assertEqual(result, expected) - result_txt = find_files_with_suffix(self.temp_output_path, "txt") + result_txt = find_files_with_suffix(self.temp_output_path, 'txt') expected_txt = { - ".txt": sorted([os.path.join(self.temp_output_path, "test1.txt"), os.path.join(self.temp_output_path, "test2.txt")]) + '.txt': sorted([ + os.path.join(self.temp_output_path, 'test1.txt'), + os.path.join(self.temp_output_path, 'test2.txt') + ]) } for suffix in result_txt: result_txt[suffix] = sorted(result_txt[suffix]) @@ -61,10 +60,10 @@ def test_is_absolute_path(self): self.assertTrue(is_absolute_path(os.path.abspath(self.temp_output_path))) def test_add_suffix_to_filename(self): - self.assertEqual(add_suffix_to_filename("test.txt", "_suffix"), "test_suffix.txt") - self.assertEqual(add_suffix_to_filename("test.txt", ""), "test.txt") - self.assertEqual(add_suffix_to_filename("test", "_suffix"), "test_suffix") - self.assertEqual(add_suffix_to_filename(".git", "_suffix"), ".git_suffix") + self.assertEqual(add_suffix_to_filename('test.txt', '_suffix'), 'test_suffix.txt') + self.assertEqual(add_suffix_to_filename('test.txt', ''), 'test.txt') + self.assertEqual(add_suffix_to_filename('test', '_suffix'), 'test_suffix') + self.assertEqual(add_suffix_to_filename('.git', '_suffix'), '.git_suffix') def test_create_directory_if_not_exists(self): self.assertTrue(os.path.exists(self.temp_output_path)) @@ -77,82 +76,55 @@ def test_create_directory_if_not_exists(self): def test_transfer_filename(self): # test existing file - with open(os.path.join(self.temp_output_path, "abc.jpg"), "w") as f: - f.write("test") + with open(os.path.join(self.temp_output_path, 'abc.jpg'), 'w') as f: + f.write('test') self.assertTrue( re.match( - os.path.join(self.temp_output_path, Fields.multimodal_data_output_dir, "op1", "abc__dj_hash_#(.*?)#.jpg"), - transfer_filename(os.path.join(self.temp_output_path, "abc.jpg"), "op1"), - ) - ) + os.path.join(self.temp_output_path, Fields.multimodal_data_output_dir, 'op1', 'abc__dj_hash_#(.*?)#.jpg'), + transfer_filename(os.path.join(self.temp_output_path, 'abc.jpg'), 'op1'))) # test non-existing file self.assertTrue( re.match( - os.path.join(self.temp_output_path, "non-existing.jpg"), - transfer_filename(os.path.join(self.temp_output_path, "non-existing.jpg"), "op1"), - ) - ) + os.path.join(self.temp_output_path, 'non-existing.jpg'), + transfer_filename(os.path.join(self.temp_output_path, 'non-existing.jpg'), 'op1'))) # test save_dir self.temp_output_path = os.path.abspath(self.temp_output_path) self.assertTrue( re.match( - os.path.join(self.temp_output_path, "tmp_save_dir", "abc__dj_hash_#(.*?)#.jpg"), - transfer_filename( - os.path.join(self.temp_output_path, "abc.jpg"), - "op1", - save_dir=os.path.join(self.temp_output_path, "tmp_save_dir"), - ), - ) - ) + os.path.join(self.temp_output_path, 'tmp_save_dir', 'abc__dj_hash_#(.*?)#.jpg'), + transfer_filename(os.path.join(self.temp_output_path, 'abc.jpg'), 'op1', + save_dir=os.path.join(self.temp_output_path, 'tmp_save_dir')))) # test env dir try: - ori_env_dir = os.environ.get("DJ_PRODUCED_DATA_DIR", None) - test_env_dir = os.path.join(self.temp_output_path, "tmp_env_dir") - os.environ["DJ_PRODUCED_DATA_DIR"] = test_env_dir + ori_env_dir = os.environ.get('DJ_PRODUCED_DATA_DIR', None) + test_env_dir = os.path.join(self.temp_output_path, 'tmp_env_dir') + os.environ['DJ_PRODUCED_DATA_DIR'] = test_env_dir - transfer_filename(os.path.join(self.temp_output_path, "abc.jpg"), "op1") + transfer_filename(os.path.join(self.temp_output_path, 'abc.jpg'), 'op1') self.assertTrue( re.match( - os.path.join(test_env_dir, "op1", "abc__dj_hash_#(.*?)#.jpg"), - transfer_filename(os.path.join(self.temp_output_path, "abc.jpg"), "op1"), - ) - ) + os.path.join(test_env_dir, 'op1', 'abc__dj_hash_#(.*?)#.jpg'), + transfer_filename(os.path.join(self.temp_output_path, 'abc.jpg'), 'op1'))) finally: if ori_env_dir: - os.environ["DJ_PRODUCED_DATA_DIR"] = ori_env_dir - elif "DJ_PRODUCED_DATA_DIR" in os.environ: - del os.environ["DJ_PRODUCED_DATA_DIR"] + os.environ['DJ_PRODUCED_DATA_DIR'] = ori_env_dir + elif 'DJ_PRODUCED_DATA_DIR' in os.environ: + del os.environ['DJ_PRODUCED_DATA_DIR'] def test_copy_data(self): - tgt_fn = "test.txt" - ori_dir = os.path.join(self.temp_output_path, "test1") - tgt_dir = os.path.join(self.temp_output_path, "test2") + tgt_fn = 'test.txt' + ori_dir = os.path.join(self.temp_output_path, 'test1') + tgt_dir = os.path.join(self.temp_output_path, 'test2') self.assertFalse(copy_data(ori_dir, tgt_dir, tgt_fn)) os.makedirs(ori_dir, exist_ok=True) - with open(os.path.join(ori_dir, tgt_fn), "w") as f: - f.write("test") + with open(os.path.join(ori_dir, tgt_fn), 'w') as f: + f.write('test') self.assertTrue(copy_data(ori_dir, tgt_dir, tgt_fn)) self.assertTrue(os.path.exists(os.path.join(tgt_dir, tgt_fn))) - def test_find_files_with_suffix_gzip(self): - # create a gzip compressed jsonl file and ensure it is detected as '.jsonl.gz' - content = '{"text": "gzip test"}\n' - gz_path = os.path.join(self.temp_output_path, "demo-dataset.jsonl.gz") - with gzip.open(gz_path, "wb") as f: - f.write(content.encode("utf-8")) - - result = find_files_with_suffix(self.temp_output_path) - - # normalize lists for comparison - for suffix in result: - result[suffix] = sorted(result[suffix]) - - self.assertIn(".jsonl.gz", result) - self.assertEqual(result[".jsonl.gz"], [gz_path]) - -if __name__ == "__main__": +if __name__ == '__main__': unittest.main() diff --git a/thirdparty/humanvbench_models/.gitmodules b/thirdparty/humanvbench_models/.gitmodules new file mode 100644 index 0000000000..b7a42be124 --- /dev/null +++ b/thirdparty/humanvbench_models/.gitmodules @@ -0,0 +1,12 @@ +[submodule "YOLOv8_human"] + path = YOLOv8_human + url = https://github.com/jahongir7174/YOLOv8-human.git + # commit_id = 8f8a65e +[submodule "Light-ASD"] + path = Light-ASD + url = https://github.com/Junhua-Liao/Light-ASD.git + # commit_id = e4f33e1 +[submodule "SenseVoice"] + path = SenseVoice + url = https://github.com/FunAudioLLM/SenseVoice.git + # commit_id = 771252c diff --git a/thirdparty/humanvbench_models/Light-ASD_changes.diff b/thirdparty/humanvbench_models/Light-ASD_changes.diff new file mode 100644 index 0000000000..e549de9450 --- /dev/null +++ b/thirdparty/humanvbench_models/Light-ASD_changes.diff @@ -0,0 +1,27 @@ +diff --git a/model/faceDetector/s3fd/__init__.py b/model/faceDetector/s3fd/__init__.py +index 943292a..a029f3d 100644 +--- a/model/faceDetector/s3fd/__init__.py ++++ b/model/faceDetector/s3fd/__init__.py +@@ -6,7 +6,7 @@ from torchvision import transforms + from .nets import S3FDNet + from .box_utils import nms_ + +-PATH_WEIGHT = 'model/faceDetector/s3fd/sfd_face.pth' ++PATH_WEIGHT = './thirdparty/humanvbench_models/Light-ASD/model/faceDetector/s3fd/sfd_face.pth' + if os.path.isfile(PATH_WEIGHT) == False: + Link = "1KafnHz7ccT-3IyddBsL5yi2xGtxAKypt" + cmd = "gdown --id %s -O %s"%(Link, PATH_WEIGHT) + +diff --git a/model/faceDetector/s3fd/box_utils.py b/model/faceDetector/s3fd/box_utils.py +index 0779bcd..1bf4be2 100644 +--- a/model/faceDetector/s3fd/box_utils.py ++++ b/model/faceDetector/s3fd/box_utils.py +@@ -35,7 +35,7 @@ def nms_(dets, thresh): + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + +- return np.array(keep).astype(np.int) ++ return np.array(keep).astype(int) + + + def decode(loc, priors, variances): diff --git a/thirdparty/humanvbench_models/README.md b/thirdparty/humanvbench_models/README.md new file mode 100644 index 0000000000..c5fea18580 --- /dev/null +++ b/thirdparty/humanvbench_models/README.md @@ -0,0 +1,38 @@ +# HumanVBench Models Setup + +This directory manages the external models and adapters required for HumanVBench. Please follow the instructions below to ensure all components are correctly initialized and patched. + +## 1. Environment Initialization +These models are managed as Git submodules. To clone the contents, run the following command from the project root: + +```bash +git submodule update --init --recursive + +``` + +## 2. Applying Custom Patches + +We maintain custom modifications for these models via `.diff` files. You **must** apply these patches after initializing the submodules to ensure the pipeline functions correctly: + +```bash +# Navigate to this directory +cd thirdparty/humanvbench_models + +# Apply patch to YOLOv8-human +cd YOLOv8_human && git apply ../YOLOv8_human_changes.diff && cd .. + +# Apply patch to Light-ASD +cd Light-ASD && git apply ../Light-ASD_changes.diff && cd .. + +# Apply patch to SenseVoice +cd SenseVoice && git apply ../SenseVoice_changes.diff && cd .. + +``` + +## 3. External Weights Download + +The following weight file must be downloaded manually and placed in the specific directory: + +| Model | File | Target Path | Download Source | +| --- | --- | --- | --- | +| **Light-ASD** | `sfd_face.pth` | `./Light-ASD/model/faceDetector/s3fd/` | [HuggingFace - SyncNet](https://huggingface.co/lithiumice/syncnet/tree/main) | diff --git a/thirdparty/humanvbench_models/SenseVoice_changes.diff b/thirdparty/humanvbench_models/SenseVoice_changes.diff new file mode 100644 index 0000000000..f7e14be7ec --- /dev/null +++ b/thirdparty/humanvbench_models/SenseVoice_changes.diff @@ -0,0 +1,15 @@ +diff --git a/model.py b/model.py +index a89defd..11b1285 100644 +--- a/model.py ++++ b/model.py +@@ -13,7 +13,9 @@ from funasr.train_utils.device_funcs import force_gatherable + from funasr.losses.label_smoothing_loss import LabelSmoothingLoss + from funasr.metrics.compute_acc import compute_accuracy, th_accuracy + from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank +-from utils.ctc_alignment import ctc_forced_align ++import sys,os ++sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) ++from SenseVoice.utils.ctc_alignment import ctc_forced_align + + class SinusoidalPositionEncoder(torch.nn.Module): + """ """ diff --git a/thirdparty/humanvbench_models/YOLOv8_human_changes.diff b/thirdparty/humanvbench_models/YOLOv8_human_changes.diff new file mode 100644 index 0000000000..d7b9229a81 --- /dev/null +++ b/thirdparty/humanvbench_models/YOLOv8_human_changes.diff @@ -0,0 +1,133 @@ +diff --git a/dj.py b/dj.py +new file mode 100644 +index 0000000..ce25877 +--- /dev/null ++++ b/dj.py +@@ -0,0 +1,111 @@ ++import sys,os ++import warnings ++from argparse import ArgumentParser ++ ++import numpy ++import torch ++sys.path.append(os.path.dirname(os.path.abspath(__file__))) ++from nets import nn ++from util import non_max_suppression ++ ++warnings.filterwarnings("ignore") ++ ++ ++@torch.no_grad() ++def demo(img_array, model): ++ import cv2 ++ ++ frame = img_array ++ image = frame.copy() ++ shape = image.shape[:2] ++ ++ r = 640 / max(shape[0], shape[1]) ++ if r != 1: ++ resample = cv2.INTER_LINEAR if r > 1 else cv2.INTER_AREA ++ image = cv2.resize(image, dsize=(int(shape[1] * r), int(shape[0] * r)), interpolation=resample) ++ height, width = image.shape[:2] ++ ++ # Scale ratio (new / old) ++ r = min(1.0, 640 / height, 640 / width) ++ ++ # Compute padding ++ pad = int(round(width * r)), int(round(height * r)) ++ w = numpy.mod((640 - pad[0]), 32) / 2 ++ h = numpy.mod((640 - pad[1]), 32) / 2 ++ ++ if (width, height) != pad: # resize ++ image = cv2.resize(image, pad, interpolation=cv2.INTER_LINEAR) ++ top, bottom = int(round(h - 0.1)), int(round(h + 0.1)) ++ left, right = int(round(w - 0.1)), int(round(w + 0.1)) ++ image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT) # add border ++ ++ # Convert HWC to CHW, BGR to RGB ++ x = image.transpose((2, 0, 1))[::-1] ++ x = numpy.ascontiguousarray(x) ++ x = torch.from_numpy(x) ++ x = x.unsqueeze(dim=0) ++ x = x.to(next(model.parameters()).device) ++ x = x.half() ++ x = x / 255 ++ # Inference ++ outputs = model(x) ++ # NMS ++ outputs = non_max_suppression(outputs, 0.25, 0.7) ++ final_output_box_list = [] ++ for output in outputs: ++ output[:, [0, 2]] -= w # x padding ++ output[:, [1, 3]] -= h # y padding ++ output[:, :4] /= min(height / shape[0], width / shape[1]) ++ ++ output[:, 0].clamp_(0, shape[1]) # x1 ++ output[:, 1].clamp_(0, shape[0]) # y1 ++ output[:, 2].clamp_(0, shape[1]) # x2 ++ output[:, 3].clamp_(0, shape[0]) # y2 ++ ++ for box in output: ++ box = box.cpu().numpy() ++ x1, y1, x2, y2, score, index = box ++ final_output_box_list.append((x1, y1, x2, y2)) ++ # cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2) ++ del x ++ return final_output_box_list ++ ++ ++ ++def profile(args, params): ++ model = nn.yolo_v8_n(len(params['names'])) ++ shape = (1, 3, args.input_size, args.input_size) ++ ++ model.eval() ++ model(torch.zeros(shape)) ++ params = sum(p.numel() for p in model.parameters()) ++ if args.local_rank == 0: ++ print(f'Number of parameters: {int(params)}') ++ ++ ++def human_detect(img_array): ++ parser = ArgumentParser() ++ parser.add_argument('--input-size', default=640, type=int) ++ parser.add_argument('--local_rank', default=0, type=int) ++ ++ args = parser.parse_args() ++ ++ args.local_rank = int(os.getenv('LOCAL_RANK', 0)) ++ args.world_size = int(os.getenv('WORLD_SIZE', 1)) ++ args.distributed = int(os.getenv('WORLD_SIZE', 1)) > 1 ++ ++ if args.distributed: ++ torch.cuda.set_device(device=args.local_rank) ++ torch.distributed.init_process_group(backend='nccl', init_method='env://') ++ ++ if args.local_rank == 0: ++ if not os.path.exists('weights'): ++ os.makedirs('weights') ++ ++ profile(args, img_array) ++ ++ demo(args,img_array) ++ ++ ++if __name__ == "__main__": ++ main() +diff --git a/nets/nn.py b/nets/nn.py +index 66aec47..0dd5ee4 100644 +--- a/nets/nn.py ++++ b/nets/nn.py +@@ -1,8 +1,8 @@ + import math +- ++import sys,os + import torch +- +-from utils.util import make_anchors ++sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'utils'))) ++from util import make_anchors + + + def fuse_conv(conv, norm): diff --git a/thirdparty/humanvbench_models/audio_code/wav2vec_age_gender.py b/thirdparty/humanvbench_models/audio_code/wav2vec_age_gender.py new file mode 100644 index 0000000000..3d4fee61a3 --- /dev/null +++ b/thirdparty/humanvbench_models/audio_code/wav2vec_age_gender.py @@ -0,0 +1,112 @@ +import numpy as np +import torch +import torch.nn as nn +from transformers import Wav2Vec2Processor +from transformers.models.wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Model, + Wav2Vec2PreTrainedModel, +) + + +class ModelHead(nn.Module): + r"""Classification head.""" + + def __init__(self, config, num_labels): + + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.final_dropout) + self.out_proj = nn.Linear(config.hidden_size, num_labels) + + def forward(self, features, **kwargs): + + x = features + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + + return x + + +class AgeGenderModel(Wav2Vec2PreTrainedModel): + r"""Speech emotion classifier.""" + + def __init__(self, config): + + super().__init__(config) + + self.config = config + self.wav2vec2 = Wav2Vec2Model(config) + self.age = ModelHead(config, 1) + self.gender = ModelHead(config, 3) + self.init_weights() + + def forward( + self, + input_values, + ): + + outputs = self.wav2vec2(input_values) + hidden_states = outputs[0] + hidden_states = torch.mean(hidden_states, dim=1) + logits_age = self.age(hidden_states) + logits_gender = torch.softmax(self.gender(hidden_states), dim=1) + + return hidden_states, logits_age, logits_gender + + + +# load model from hub +# device = 'cpu' +# model_name = '/mnt1/daoyuan_mm/wav2vec2-large-robust-24-ft-age-gender' +# processor = Wav2Vec2Processor.from_pretrained(model_name) +# model = AgeGenderModel.from_pretrained(model_name) + +# dummy signal +# sampling_rate = 16000 +# signal = np.zeros((1, sampling_rate), dtype=np.float32) + + +def process_func( + x: np.ndarray, + sampling_rate: int, + processor, + model, + device, + embeddings: bool = False, +) -> np.ndarray: + r"""Predict age and gender or extract embeddings from raw audio signal.""" + + # run through processor to normalize signal + # always returns a batch, so we just get the first entry + # then we put it on the device + y = processor(x, sampling_rate=sampling_rate) + y = y['input_values'][0] + y = y.reshape(1, -1) + y = torch.from_numpy(y).to(device) + + # run through model + with torch.no_grad(): + y = model(y) + if embeddings: + y = y[0] + else: + y = torch.hstack([y[1], y[2]]) + + # convert to numpy + y = y.detach().cpu().numpy() + + return y + + +# print(process_func(signal, sampling_rate)) +# # Age female male child +# # [[ 0.33793038 0.2715511 0.2275236 0.5009253 ]] + +# print(process_func(signal, sampling_rate, embeddings=True)) +# Pooled hidden states of last transformer layer +# [[ 0.024444 0.0508722 0.04930823 ... 0.07247854 -0.0697901 +# -0.0170537 ]]