diff --git a/aiak_training_llm/data/chat_templete.py b/aiak_training_llm/data/chat_templete.py index 82262cc..722adfe 100644 --- a/aiak_training_llm/data/chat_templete.py +++ b/aiak_training_llm/data/chat_templete.py @@ -11,6 +11,7 @@ """ import re +import os from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import TYPE_CHECKING, Type, Dict, List, Optional, Sequence, Set, Tuple, Union @@ -133,6 +134,9 @@ def encode_multiturn( messages = messages[1:] encoded_messages = self._encode(tokenizer, messages, system) + # ZXW: missing "/n" + if int(os.environ.get("FILL_TOKEN_198",0))==1: + encoded_messages[-1].append(198) return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] def encode_oneturn( diff --git a/aiak_training_llm/data/multimodal/flavors/packed_captioning.py b/aiak_training_llm/data/multimodal/flavors/packed_captioning.py index ec1f6ee..f89226a 100644 --- a/aiak_training_llm/data/multimodal/flavors/packed_captioning.py +++ b/aiak_training_llm/data/multimodal/flavors/packed_captioning.py @@ -9,6 +9,20 @@ class PackedCaptioningSample(Sample): """Sample type for packed captioning.""" # sample_id: str - images: List[torch.Tensor] - prompts: Optional[List[str]] - captions: List[str] \ No newline at end of file + images: Union[ + str, # A single image path, e.g., 'img001.jpg' + torch.Tensor, # A single image tensor + List[str], # A list of image paths, e.g., ['imgs001.jpg', 'imgs002.png'] + List[List[str]], # A nested list of image paths, e.g., [['imgs001.tif'], [], ['imgs005.png']] + List[torch.Tensor], # A list of image tensors + List[List[torch.Tensor]] + ] + prompts: Union[ + Optional[List[str]], + List[List[str]] + ] + captions: Union[ + List[str], + List[List[str]] + ] + \ No newline at end of file diff --git a/aiak_training_llm/data/multimodal/task_encoder.py b/aiak_training_llm/data/multimodal/task_encoder.py index ee3bd69..9c721e9 100644 --- a/aiak_training_llm/data/multimodal/task_encoder.py +++ b/aiak_training_llm/data/multimodal/task_encoder.py @@ -188,6 +188,39 @@ def encode_sample(self, sample: Union[CaptioningSample, OCRSample, VQASample, Si context=sample.prompts[idx] ) l_Qwen2VLImageTaskSample.append(self.encode_vqa4packing(cur_capsample)) + elif int(os.environ.get("OFFLINE_PACKING_BMR",0))==1: + def convert_to_messages(cur_prompt, cur_caption): + """ + {cur_prompt, cur_caption}---> messages + """ + + if len(cur_prompt) != len(cur_caption): + raise ValueError("cur_prompt & cur_caption have different lengths") + + messages = [] + for prompt, caption in zip(cur_prompt, cur_caption): + messages.append({ + "content": prompt, + "role": "user" + }) + + messages.append({ + "content": caption, + "role": "assistant" + }) + + return messages + cur_capsample = MultiMixQASample( + __key__=f"{sample.__key__}.img{idx:03d}_jpg", + __restore_key__=sample.__restore_key__, + __subflavor__='BMR', + __subflavors__=sample.__subflavors__, + messages=convert_to_messages(sample.prompts[idx], sample.captions[idx]), + video=None, + system=None, + image=[sample.images[idx]] if sample.images[idx] else None + ) + l_Qwen2VLImageTaskSample.append(self.encode_multi_mix_qa(cur_capsample)) else: cur_capsample = CaptioningSample( __key__=f"{sample.__key__}.img{idx:03d}_jpg", diff --git a/tools/data_preprocess/offline_packing/configs/s1_config_MMR_sft_780k.yaml b/tools/data_preprocess/offline_packing/configs/s1_config_MMR_sft_780k.yaml new file mode 100644 index 0000000..1eba647 --- /dev/null +++ 
b/tools/data_preprocess/offline_packing/configs/s1_config_MMR_sft_780k.yaml @@ -0,0 +1,57 @@ +# Data path configuration +data: + # Directory of data samples + directory: "/data_1/llava_next_raw_full/split_json_files/" + # Temporary file for storing paired filenames + output_base: "base_name_v4_MR_sft_780k_8k.txt" + # Final output file (includes token length information) + output_token: "token_info_MR_sft_780k_8k.txt" + +# Model path +model: + checkpoint: "/vlm/xiangan/pretrain_models/rice_vl/rice_vl_rice_300m_qwen2.5_7b_adapter_v1_fixed_tokenizer_huggingface" + +sample: + # Maximum length for training data + max_len: 8192 + # Whether to remove one token (The current data preprocessing step is missing one token; this flag is used to decide if alignment is needed) + # Used in conjunction with the environment variable FILL_TOKEN_198: + # false: FILL_TOKEN_198=1, Achieves faster convergence + # true: FILL_TOKEN_198=0 + del_one_token: false + # Decides the parsing method + task_type: sft + max_prompt: null + max_answer: null + +# Image processing parameters +image: + baidu_resolution: 1600 # baidu code's limit parameter (null) + min_pixels: 3136 # 4*28*28 + max_pixels: 4014080 # 5120*28*28(4014080,8192) + # Maximum aspect ratio limit (images exceeding this value will be filtered out) + max_aspect_ratio: 200 + +# Parallel processing parameters +processing: + # Number of samples each process handles + chunk_size: 5000 + # Merge parameter (sorting), merge N stage0 files into 1 stage1 file + stage1_merge_chunk: 20 + n_workers: 64 + # Minimum number of threads in the thread pool + min_workers: 10 + # Maximum number of threads in the thread pool + max_workers: 32 + # Timeout setting (based on data size, 1M data estimated to take 45 minutes (2700s)) + time_out: 20000 + +# Logging and temporary file configuration +logging: + # Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + level: "INFO" + # Log file path + file: "./logs/s1_processing_MR_sft_780k_8k.log" + # Whether to use /dev/shm as the temporary directory + use_shm: false + diff --git a/tools/data_preprocess/offline_packing/convert_packedsample_to_wds.py b/tools/data_preprocess/offline_packing/convert_packedsample_to_wds.py index 64afb52..710ae14 100644 --- a/tools/data_preprocess/offline_packing/convert_packedsample_to_wds.py +++ b/tools/data_preprocess/offline_packing/convert_packedsample_to_wds.py @@ -6,9 +6,10 @@ ├── ps_00000000.img000.jpg ├── ps_00000000.img001.jpg ├── ps_00000000.json + ... 
-JSON 格式:
+JSON format (pretrain):
 {
   "images": ["img000.jpg", "img001.jpg", ...],
   "prompt": ["描述", "what about", ""],
@@ -17,9 +18,20 @@
 一条 json + 对应若干 jpg = 1 条 tar 记录
 """
-######-----------------------------------------######
-######-----------------------------------------######
-######-----------------------------------------######
+"""
+
+JSON format (multi-round + blended data for sft):
+{
+    "images": [["img000.jpg"], ["img001.jpg"], []],
+    "prompt": [["描述这幅图"], ["what about this fig"], ["How are you?", "I am fine too."]],
+    "captions": [["stri"], ["str2"], ["I am fine and you?", "Have a nice day"]]
+}
+One JSON + its associated JPGs (zero or more) = one tar record
+"""
+
+# #####-----------------------------------------######
+# #####-----------------------------------------######
+# #####-----------------------------------------######
 
 import argparse
 import uuid
@@ -57,8 +69,8 @@ def sample_loader_template(media: str=None):
         "def part_filter(part: str) -> bool:",
         "    return True",
     ])
-
-### ZXW
+
+# ## ZXW
 def sample_loader_template_caption(media=None):
     """适配整条多图 captioning 的 loader"""
@@ -107,7 +119,25 @@ def construct_sample_caption(args, entry):
     sample["json"] = json.dumps(payload, ensure_ascii=False).encode("utf-8")
     return sample
 
-### ZXW
+def construct_bmr_sample(args, entry):
+    """Package multi-round & blended data into a webdataset sample."""
+    sample = {"__key__": entry["id"]}
+    for idx, img_name in enumerate(entry["images"]):
+        # print(img_name)
+        img_path = os.path.join(args.image_dir, f"{entry['id']}.{img_name[0]}") if img_name else None
+        if img_name:
+            with open(img_path, "rb") as f:
+                sample[f"img{idx}.jpg"] = f.read()
+    payload = {
+        "prompts": entry["prompts"],
+        "captions": entry["captions"],
+        "images": entry["images"]
+    }
+
+    sample["json"] = json.dumps(payload, ensure_ascii=False).encode("utf-8")
+    return sample
+
+# ## ZXW
 def construct_sample(args, vision, path, entry):
     """ construct webdataset sample """
@@ -139,7 +169,7 @@ def convert_to_wds(args):
     tar = os.path.join(args.output_dir, 'pretrain-%06d.tar')
 
     if args.mode == "caption_pack":
-        # 新模式
+        # New mode 1: whole-sample multi-image captioning
         with wds.ShardWriter(tar, maxcount=args.maxcount, maxsize=args.maxsize) as sink:
             for entry in tqdm(stream_samples_caption(args.json_file)):
                 sample=construct_sample_caption(args, entry)
@@ -150,7 +180,21 @@
         write_config(EPath(args.output_dir).absolute(), args.media,
                      template_func=sample_loader_template_caption,
-                     class_name="PackedCaptioningSample")
+                     class_name="PackedCaptioningSample")
+    elif args.mode == "bmr_pack":
+        # New mode 2: blended multi-round (bmr) SFT data
+        with wds.ShardWriter(tar, maxcount=args.maxcount, maxsize=args.maxsize) as sink:
+            for entry in tqdm(stream_samples_caption(args.json_file)):
+                sample=construct_bmr_sample(args, entry)
+                # print(sample.keys())
+                sink.write(sample)
+                # break
+                # sink.write(construct_sample_caption(args.image_dir, entry))
+
+        write_config(EPath(args.output_dir).absolute(), args.media,
+                     template_func=sample_loader_template_caption,
+                     class_name="PackedCaptioningSample")
+        pass
 
     print(f"Dataset successfully converted to wds")
@@ -201,7 +245,7 @@ def _add_arguments(parser: argparse.ArgumentParser):
     group.add_argument('--columns_messages', type=str, default="messages", help='Column name for messages')
     # 新增模式选择
     group.add_argument('--mode', type=str,
-                       choices=["chat", "caption_pack"],
+                       choices=["chat", "caption_pack", "bmr_pack"],
                        default="chat",
                        help="chat=旧格式(单图对话); caption_pack=新格式(整条多图caption)")
     return parser
diff --git a/tools/data_preprocess/offline_packing/hashbacket.py b/tools/data_preprocess/offline_packing/hashbacket.py
index 
6379098..52a60e4 100644 --- a/tools/data_preprocess/offline_packing/hashbacket.py +++ b/tools/data_preprocess/offline_packing/hashbacket.py @@ -2108,9 +2108,14 @@ class PackingTracker: def __init__(self, processor): self.processor = processor self.history = [] + self.snapshots = [] # 添加状态快照(2025.09.12) def track_packing(self, strategy_name: str, **kwargs): """记录一次装箱操作""" + + # 保存状态快照 (2025.09.12) + self.save_current_state() + before_state = self.processor.check_hash_buckets_state() # 支持返回详细统计(如 total_attempts),否则只返回箱子列表 result = getattr(self.processor, strategy_name)(**kwargs) @@ -2147,6 +2152,33 @@ def print_summary(self): print(f"成功率: {rate:.1%} ({op['boxes_count']}/{op['total_attempts']})") else: print(f"成功率: N/A") + + # (2025.09.12) + def save_current_state(self): + """保存当前状态快照""" + snapshot = { + 'hash_buckets': {k: arr.copy() for k, arr in self.processor.hash_buckets.items()}, + 'timestamp': time.time() + } + self.snapshots.append(snapshot) # 这里已经完成了追加操作 + print("save checkpoint........") + # return snapshot + + # (2025.09.12) + # + def restore_state(self, index: int): + """恢复到指定操作前的状态""" + if 0 <= index < len(self.snapshots): + snapshot = self.snapshots[index] + self.processor.hash_buckets = { + k: arr.copy() for k, arr in snapshot['hash_buckets'].items() + } + # 清理该索引之后的所有状态和历史记录 + self.snapshots = self.snapshots[:index+1] + self.history = self.history[:index] + print(f"已恢复到操作 {index} 之前的状态") + return True + return False # # 使用示例 diff --git a/tools/data_preprocess/offline_packing/s1_get_tokenlens_v4-sft.py b/tools/data_preprocess/offline_packing/s1_get_tokenlens_v4-sft.py new file mode 100644 index 0000000..9057a78 --- /dev/null +++ b/tools/data_preprocess/offline_packing/s1_get_tokenlens_v4-sft.py @@ -0,0 +1,682 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Usage +python s1_get_tokenlens_v4-sft.py --config ./configs/s1_config_MMR_sft_780k.yaml +""" + +import os +import json +import orjson +import threading +import logging +import psutil +import tempfile +import queue +import yaml +import argparse +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor, as_completed +from heapq import merge +from PIL import Image +from jinja2 import Template +from transformers import AutoProcessor +from transformers import BitsAndBytesConfig +from qwen_vl_utils import fetch_image +from queue import Empty +import multiprocessing +from multiprocessing import Pool, Manager, Value + +# Declare a global, cross-process counter (defined in the main module to be inherited by child processes). 
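# For reference, a minimal commented sketch of the cross-process counter pattern this script relies
# on (hypothetical toy example; it assumes the fork start method so worker processes inherit the
# module-level Value):
#
#     from multiprocessing import Pool, Value
#
#     counter = Value('i', 0)            # 'i' = shared signed int
#
#     def work(n):
#         with counter.get_lock():       # atomic update under the Value's lock
#             counter.value += n
#
#     if __name__ == "__main__":
#         with Pool(4) as pool:
#             pool.map(work, [1, 2, 3, 4])
#         print(counter.value)           # -> 10 with the fork start method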
+global_total_counter = None + +# ✅ Parse command-line arguments +parser = argparse.ArgumentParser(description="Token Length Processor") +parser.add_argument("--config", type=str, default="config.yaml", help="Path to config.yaml") +parser.add_argument("--log-level", type=str, default=None, + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Override log level from config") +args = parser.parse_args() + +# ✅ Load configuration file +CONFIG_PATH = Path(args.config) +if not CONFIG_PATH.exists(): + raise FileNotFoundError(f"Configuration file does not exist: {CONFIG_PATH}") +with open(CONFIG_PATH, 'r', encoding='utf-8') as f: + cfg = yaml.safe_load(f) + +# ✅ Read parameters from configuration file, override existing constants +MAX_TOKEN_LEN = cfg['sample']['max_len'] +task_type = cfg['sample']['task_type'] +DEL_ONE_TOKEN = cfg['sample']['del_one_token'] + +DEFAULT_DIRECTORY = Path(cfg['data']['directory']) +OUTPUT_FILE = Path(cfg['data']['output_base']) +TOKEN_INFO_FILE = Path(cfg['data']['output_token']) +CKPT_DIR = cfg['model']['checkpoint'] +MIN_PIXELS = cfg['image']['min_pixels'] +MAX_PIXELS = cfg['image']['max_pixels'] +image_resolution = cfg['image']['baidu_resolution'] +TIME_OUT = cfg['processing']['time_out'] +# 归并参数(仅两级:stage0 → stage1) +STAGE1_CHUNK = cfg['processing']['stage1_merge_chunk'] +chunk_size = cfg['processing']['chunk_size'] +n_workers = cfg['processing']['n_workers'] +MIN_WORKERS = cfg['processing']['min_workers'] +MAX_WORKERS = cfg['processing']['max_workers'] +use_shm = cfg['logging']['use_shm'] +log_level = cfg['logging']['level'] +log_file = cfg['logging']['file'] +if args.log_level: + log_level = args.log_level.upper() + +# ✅ Configure logging - detailed record of data flow and merge process +file_handler = logging.FileHandler( + log_file, + delay=True, + encoding='utf-8' +) +stream_handler = logging.StreamHandler() + +logging.basicConfig( + level=log_level, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[file_handler, stream_handler] +) +logger = logging.getLogger(__name__) + +EXTENSIONS = (".json", ".jpg") + + +temp_dir = '/dev/shm' if use_shm else None # None 表示使用系统默认临时目录 + +def count_lines(file_path): + """ Count valid lines in file (non-empty and contain delimiter)""" + if not os.path.exists(file_path) or os.path.getsize(file_path) == 0: + return 0 + try: + with open(file_path, 'r', encoding='utf-8') as f: + return sum(1 for line in f if line.strip() and ':' in line.strip()) + except Exception as e: + logger.error(f"❌ Error counting lines for {file_path}: {str(e)}") + return 0 + +def find_paired_files(directory): + directory = Path(directory) + files = os.listdir(directory) + json_set = {f[:-5] for f in files if f.lower().endswith('.json')} + img_set = {f[:-4] for f in files if f.lower().endswith(('.jpg', '.jpeg'))} + paired = json_set & img_set + logger.info(f"Found {len(paired)} file pairs.") + return paired + +def find_valid_files(fname_json, rel_img_path): + from s1_mr_sft_data_proc_indcoding import split_json_file + valid_names = split_json_file( + fname_json, + rel_img_path, + chunk_dim=2000, + m=8 + ) + return valid_names + +def find_valid_json(directory): + directory = Path(directory) + files = os.listdir(directory) + json_set = {f[:-5] for f in files if f.lower().endswith('.json')} + logger.info(f"Found {len(json_set)} JSON files.") + return json_set + +def write_base_names_to_file(base_names, output_file): + """ Write paired file names to output file""" + try: + content = "\n".join(sorted(base_names)) + "\n" + with 
open(output_file, 'w', encoding='utf-8') as f:
+            f.write(content)
+        logger.info(f"ℹ️ Wrote {len(base_names)} paired filenames to {output_file}")
+    except Exception as e:
+        logger.error(f"❌ Error writing to {output_file}: {str(e)}")
+        raise
+
+
+def read_lines_in_chunks(file_path, chunk_size):
+    """ Read file content in chunks; each chunk contains up to chunk_size lines"""
+    file_path = Path(file_path)
+    if not file_path.exists():
+        raise FileNotFoundError(f"{file_path} does not exist.")
+
+    with open(file_path, 'r', encoding='utf-8') as f:
+        while True:
+            chunk = [line.strip() for _, line in zip(range(chunk_size), f) if line.strip()]
+            if not chunk:
+                break
+            logger.info(f"ℹ️ Read data chunk containing {len(chunk)} samples.")
+            yield chunk
+
+
+# ✅ Precompile template for efficiency
+"""
+Todo:
+    1) put into .yaml
+    2) Add support for user-defined processing functions beyond "jinja2+processor"
+"""
+if task_type=="pretrain":
+    CAP_TEMPLATE = Template("<|vision_start|><|image_pad|><|vision_end|>{{ captions[0].content }}<|im_end|>")
+elif task_type=="sft":
+    chat_template = """{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{{ message['content'] | replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>') }}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"""
+    CAP_TEMPLATE = Template(chat_template)
+    pass
+
+def process_sample(json_path, img_path, processor):
+    """ Process a single sample and return a tuple (token_len, file name)"""
+    try:
+        if not Path(json_path).exists():
+            raise FileNotFoundError(f"❌ JSON file does not exist: {json_path}")
+        # if not Path(img_path).exists():
+        #     raise FileNotFoundError(f"❌ Image file does not exist: {img_path}")
+
+        # Read and render JSON content
+        with open(json_path, 'r', encoding='utf-8') as f:
+            json_data = json.load(f)
+        # with open(json_path, 'rb') as f:
+        #     json_data = orjson.loads(f.read())
+        if task_type=="pretrain":
+            txt_input = CAP_TEMPLATE.render(captions=json_data['captions'])
+        elif task_type=="sft":
+            # txt_input = CAP_TEMPLATE.render(json_data)
+            txt_input = CAP_TEMPLATE.render(json_data, tokenize=False, add_generation_prompt=False)
+        if img_path=="_____.jpg":
+            img_input = None
+        else:
+            def baidu_img_proc(image, image_resolution):
+                image = Image.open(image)
+                if max(image.width, image.height) > image_resolution:
+                    resize_factor = image_resolution / max(image.width, image.height)
+                    width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+                    image = image.resize((width, height), resample=Image.NEAREST)
+
+                return image
+
+            if image_resolution:
+                img_path = baidu_img_proc(img_path, image_resolution)
+
+
+            img_input = fetch_image({
+                'type': 'image',
+                'image': img_path,
+                "min_pixels": MIN_PIXELS,
+                "max_pixels": MAX_PIXELS,
+            })
+            # print(img_input)
+        # Calculate token number
+        base_name = Path(json_path).stem
+        inputs = processor(
+            text=[txt_input],
+            images=img_input,
+            videos=None,
+            padding=True,
+            return_tensors="pt",
+        )
+        # print(inputs["input_ids"])
+        # print(inputs["input_ids"].shape)
+        return (inputs["input_ids"].shape[1], base_name)
+
+    except Exception as e:
+        return (None, f"❌ Failed to process sample [{Path(json_path).stem}]: {str(e)}")
+
+
+def get_adaptive_workers(min_workers=20, max_workers=96):
+    """Dynamically adjust the number of 
threads based on system load""" + try: + cpu_usage = psutil.cpu_percent(interval=0.5) + mem_usage = psutil.virtual_memory().percent + if cpu_usage > 80 or mem_usage > 85: + adjusted = max(min_workers, max_workers // 2) + logger.info(f"High system load, adjusting thread count to {adjusted} (CPU: {cpu_usage}%, Memory: {mem_usage}%)") + return adjusted + return max_workers + except Exception as e: + logger.warning(f"System load check failed, falling back to {max_workers} threads: {str(e)}") + return max_workers + +gt_maxlen=0 +def merge_files_by_token(input_files, output_file, max_token=MAX_TOKEN_LEN): + """Merge multiple sorted files by token_len, filter out lines > max_token, return (output_path, line_count)""" + if not input_files: + logger.warning("⚠️ No files to merge") + return (None, 0) + + # Validate input files and count total lines + valid_files = [] + total_lines = 0 + for f in input_files: + line_count = count_lines(f) + if line_count > 0: + valid_files.append(f) + total_lines += line_count + logger.debug(f"ℹ️ Merging file {os.path.basename(f)} with {line_count} entries.") + else: + logger.warning(f"⚠️ Skipping empty or invalid file: {os.path.basename(f)}") + + if not valid_files: + return (None, 0) + + # Define a sorting key (sorted by the token_len integer) + def sort_key(line): + # _, token_str = line.strip().split(':', 1) + token_str = line.strip().split(':')[-1] + return int(token_str) + + try: + with open(output_file, 'w', encoding='utf-8') as out_f: + # Create iterator for all files. + iterators = [] + file_handles = [] + for fpath in valid_files: + try: + fh = open(fpath, 'r', encoding='utf-8') + file_handles.append(fh) + iterators.append(((sort_key(line), line) for line in fh)) + except Exception as e: + logger.error(f"❌ 打开文件 {os.path.basename(fpath)} 失败: {str(e)}") + + # Merge sort and write, filtering out lines with token count > max_token (other conditions can be added later) + filtered_max_len = 0 + for _, line in merge(*iterators, key=lambda x: x[0]): + token_str = line.strip().split(':')[-1] + if int(token_str) <= max_token: + out_f.write(line) + else: + logger.warning(f"⚠️ Token length: {token_str} > {max_token}: filtered out!") + filtered_max_len+=1 + gt_maxlen + + # Close all file handles + for fh in file_handles: + try: + fh.close() + except Exception as e: + logger.warning(f"⚠️ 关闭文件 {fh.name} 失败: {str(e)}") + + # Verify output file integrity + output_lines = count_lines(output_file)+filtered_max_len + if output_lines != total_lines: # Filter out lines with token count > max_token + logger.error(f"❌ Merge data loss! {total_lines} lines in, {output_lines} lines out. Deleted bad file.") + if os.path.exists(output_file): + os.remove(output_file) + return (None, 0) + else: + logger.info(f"✅ 📊 Merge successful. Input: {total_lines} lines, Output: {output_lines-filtered_max_len} lines (token ≤ {max_token}).") + + return (output_file, output_lines-filtered_max_len) + except Exception as e: + logger.error(f"❌ File merge failed: {str(e)}") + if os.path.exists(output_file): + try: + os.remove(output_file) + except Exception as e: + logger.warning(f"⚠️ Failed to delete the corrupted file {output_file}: {str(e)}") + return (None, 0) + + +def stage1_merger(input_queue, chunk_size, stage1_files, stop_event): + """ + Fixed version of stage1 merging threads + - Ensure all stage0 files are merged, including the last batch with fewer than 10 files + - Resolve thread timeout and data loss issues + """ + buffer = [] + batch_counter = 0 + logger.info(f"💡 Stage1 merge thread started. 
Merging every {chunk_size} stage0 files.") + + try: + # Loop condition: the queue has files, or the buffer has files, or no stop signal is received. + while (not input_queue.empty()) or buffer or (not stop_event.is_set()): + # Fetch files from the queue (with timeout to prevent permanent blocking) + if not input_queue.empty(): + try: + file_path = input_queue.get(timeout=1) # Timeout after 1 second to avoid permanent blocking. + buffer.append(file_path) + input_queue.task_done() + logger.debug(f"ℹ️ Stage1 received file {os.path.basename(file_path)}, buffer: {len(buffer)}/{chunk_size}") + + # If the buffer has enough files, execute the merge + if len(buffer) >= chunk_size: + batch_counter += 1 + merged_file = tempfile.NamedTemporaryFile( + mode='w', delete=False, + prefix=f"stage1_batch{batch_counter:03d}_", + encoding='utf-8', + dir=temp_dir + ).name + + # 执行合并 + merged_path, line_count = merge_files_by_token(buffer, merged_file) + if merged_path and line_count > 0: + stage1_files.append(merged_path) + logger.info(f"📊 Stage1 batch {batch_counter} done: {os.path.basename(merged_path)} ({line_count} lines, {len(buffer)} files merged).") + else: + logger.warning(f"⚠️ Skipping Stage1 batch {batch_counter} due to merge failure.") + + # Clear the buffer after successful merge + buffer = [] + except Empty: + continue # Continue the loop if the queue is empty. + except Exception as e: + logger.error(f"❌ Stage1 error while processing file: {str(e)}", exc_info=True) + else: + # If the queue is empty, check if we need to force merge remaining files. + if buffer and stop_event.is_set(): + # If the stop signal is received and the buffer has files, force merge. + batch_counter += 1 + merged_file = tempfile.NamedTemporaryFile( + mode='w', delete=False, + prefix=f"stage1_remaining_batch{batch_counter:03d}_", + encoding='utf-8', + dir=temp_dir + ).name + + merged_path, line_count = merge_files_by_token(buffer, merged_file) + if merged_path and line_count > 0: + stage1_files.append(merged_path) + logger.info(f"📊 Stage1 remaining files merged: {os.path.basename(merged_path)} with {line_count} entries from {len(buffer)} files.") + else: + logger.warning(f"⚠️ Skipping Stage1 remaining batch due to merge failure.") + buffer = [] + else: + # Sleep briefly to reduce CPU usage + threading.Event().wait(0.5) + + # Final check: Ensure the buffer is empty (to prevent omissions) + if buffer: + logger.error(f"❌ Stage1 thread exited with {len(buffer)} files in buffer unprocessed! Data loss may occur.") + + except Exception as e: + logger.error(f"❌ Stage1 thread exception exit: {str(e)}", exc_info=True) + finally: + logger.info(f"📊 Stage1 thread exit, {len(stage1_files)} files generated.") + +# Processing function for each process (responsible for handling a large chunk) +def process_chunk(args): + """ + Processing logic for each process: handles a large chunk of data, with parallel processing using multiple threads. + + Args: + args: A tuple containing chunk data, processor configuration, and queues for inter-process communication. 
+ """ + # Get the global counter from the global variable, not from the arguments + global global_total_counter + + chunk_idx, chunk, ckpt_dir, min_pixels, max_pixels, stage0_queue = args + processor = None + processed_count = 0 # Record the number of valid samples processed by the current process + + try: + # Each process initializes its own processor (processors cannot be shared between processes) + # quant_config = BitsAndBytesConfig(load_in_4bit=True) + processor = AutoProcessor.from_pretrained( + ckpt_dir, + min_pixels=min_pixels, + max_pixels=max_pixels, + trust_remote_code=True, + use_fast=False + ) + # Generate the list of file paths for the current chunk + full_paths = [] + for fn in chunk: + cur_json = str(DEFAULT_DIRECTORY / f"{fn}.json") + # logger.info(f"👉 Process {multiprocessing.current_process().name} json file: {cur_json}.....{type(cur_json)}") + if f"{fn}.json".startswith("__img--output_"): + cur_img = "_____.jpg" + # cur_img = str(DEFAULT_DIRECTORY / f"{cur_img}") + else: + with open(cur_json, 'r', encoding='utf-8') as f: + data = json.load(f) + cur_img = data['images'][0] + cur_img = str(DEFAULT_DIRECTORY / f"{cur_img}") + full_paths.append(cur_json) + full_paths.append(cur_img) + # print(f"--------------cur_json:{cur_json}, cur_img:{cur_img}-------------------") + + + n_samples = len(chunk) + logger.info(f"👉 Process {multiprocessing.current_process().name} starts processing chunk {chunk_idx} with {n_samples} samples") + + # Process each sample in the chunk using a thread pool (reuse threads) + n_workers = get_adaptive_workers(min_workers=MIN_WORKERS, max_workers=MAX_WORKERS) # Reduce the number of workers per process + chunk_results = [] + with ThreadPoolExecutor( + max_workers=n_workers, + thread_name_prefix=f"proc-{multiprocessing.current_process().pid}-thread" + ) as executor: + tasks = [ + executor.submit( + process_sample, + full_paths[idx*2], + full_paths[idx*2+1], + processor + ) for idx in range(n_samples) + ] + + # Collect the results of each thread task + for future in as_completed(tasks): + try: + token_len, name = future.result() + if DEL_ONE_TOKEN: + token_len += 1 + if token_len is not None: + chunk_results.append((token_len, name)) + processed_count += 1 # Record the number of valid samples processed + else: + logger.warning(name) + except Exception as e: + logger.error(f"❌ Process {multiprocessing.current_process().name} thread task error: {str(e)}") + + # Write the results to the stage0 file and put it into the cross-process queue + if chunk_results: + chunk_results_sorted = sorted(chunk_results, key=lambda x: x[0]) + with tempfile.NamedTemporaryFile( + mode='w+', delete=False, + prefix=f"stage0_chunk{chunk_idx:03d}_", + encoding='utf-8', + dir=temp_dir + ) as f: + stage0_file = f.name + for token_len, name in chunk_results_sorted: + f.write(f"{name}:{token_len}\n") + + line_count = count_lines(stage0_file) + stage0_queue.put(stage0_file) + + proc_status = "🟢" if processed_count==n_samples else "🟡" + logger.info(f"{proc_status} 进程 {multiprocessing.current_process().name} 完成块 {chunk_idx},有效样本 {processed_count}/{n_samples}") + + # 【Key】Accumulate total data volume across processes (using Value atomic operations) + with global_total_counter.get_lock(): + global_total_counter.value += processed_count + + return stage0_file # Return the path of the generated stage0 file for subsequent cleanup + + except Exception as e: + logger.error(f"❌ Process {multiprocessing.current_process().name} failed: {str(e)}") + finally: + if processor: + del processor + return 
None + + +### +def main(): + global global_total_counter # Reference the global counter + processor = None # Model processor instance + stage0_files = [] # Record all stage0 files (for verification and cleanup) + stage1_files = [] # Record all stage1 files (for final merging) + + try: + + logger.info(f"💡 --------------Start the data processing flow--------------") + + # 1. Find paired files and write to a temporary file (samples where the JSON and JPG file names are the same) + # base_names = find_paired_files(DEFAULT_DIRECTORY) # DEFAULT_DIRECTORY is the location for storing raw data (JPG and JSON files) + base_names = find_valid_json(DEFAULT_DIRECTORY) + total_original = len(base_names) # Total number of original samples + logger.info(f"👉 Found {total_original} pairs of original sample files") + if total_original == 0: + logger.warning("⚠️ No original samples found, exiting the program") + return + # Write the paired file names to a file for subsequent chunk reading + write_base_names_to_file(base_names, OUTPUT_FILE) + + # 2. Initialize the cross-process queue (for passing stage0 file paths to the merging thread) + manager = Manager() # Process-sharing queue requires Manager + stage0_queue = manager.Queue() + stop_event = manager.Event() # Cross-process stop signal + + # Cross-process counter for counting the total number of processed samples (initial value: 0) + global_total_counter = Value('i', 0) # 'i' indicates the integer type, used for inter-process sharing. + + # 3. Start the stage1 merging thread (daemon thread) + stage1_thread = threading.Thread( + target=stage1_merger, + args=(stage0_queue, STAGE1_CHUNK, stage1_files, stop_event), + daemon=True + ) + stage1_thread.start() + logger.info("💡 stage1 merging thread has started") + + # 4. Process data and generate stage0 files (each chunk is processed and sorted individually) + # n_workers = 96 #get_adaptive_workers() + + # 4.1 Read all data chunks (for distribution to multiple processes) + # chunk_size = chunk_size # The size of each large chunk processed by each process (adjust based on memory) + all_chunks = list(read_lines_in_chunks(OUTPUT_FILE, chunk_size)) + total_chunks = len(all_chunks) + n_processes = min(multiprocessing.cpu_count(), total_chunks) + logger.info(f"👉 Split into {total_chunks} chunks, launching {n_processes} processes.") + + # 4.2 Prepare process pool arguments (including model configuration, queue, etc.) + process_args = [ + ( + idx + 1, # chunk index + chunk, # chunk data + CKPT_DIR, # model path + MIN_PIXELS, + MAX_PIXELS, + stage0_queue, # cross-process queue for stage0 files + ) for idx, chunk in enumerate(all_chunks) + ] + + # 4.3 Launch process pool (number of processes recommended to set to 1~2 times the number of CPU cores) + with Pool(processes=n_processes) as process_pool: + # Process all large chunks in parallel. + # stage0_files = process_pool.map(process_chunk, process_args) + result = process_pool.map_async(process_chunk, process_args) + try: + stage0_files = result.get(timeout=TIME_OUT) # Set timeout for process completion (in seconds) + except multiprocessing.TimeoutError: + logger.error("❌ Some processes timed out, force termination") + process_pool.terminate() + + # Filter out empty results. 
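# (Aside: each stage0 file produced above holds one "base_name:token_len" pair per line, already
# sorted by token_len; that is what lets merge_files_by_token lean on heapq.merge over the
# per-chunk streams. A tiny sketch with made-up names and lengths:
#
#     from heapq import merge
#     a = [("s0001", 312), ("s0007", 905)]
#     b = [("s0003", 128), ("s0042", 4096)]
#     list(merge(a, b, key=lambda x: x[1]))
#     # -> [('s0003', 128), ('s0001', 312), ('s0007', 905), ('s0042', 4096)]
# )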
+ stage0_files = [f for f in stage0_files if f is not None] + logger.info(f"✅ All processes completed, generated {len(stage0_files)} stage0 files.") + # Statistics + total_processed = global_total_counter.value # Directly get from global variable # Total processed samples + logger.info(f"👉 Total original samples: {total_original}, Valid processed samples: {total_processed}") + + # Data integrity check + if total_processed != total_original: + logger.warning(f"❌ Data integrity check failed! Original {total_original} samples, processed {total_processed} samples, difference {total_original - total_processed} samples.") + else: + logger.info("✅ Data integrity check passed, all samples were processed successfully.") + + # 5. Wait for all stage0 files to be processed (ensure all files are merged) + # Wait for all stage0 files to be processed in the queue. + logger.info("🔄 Waiting for stage0 queue to process all files...") + stage0_queue.join() # Block until all stage0 files are consumed + logger.info("💡 All stage0 files have been processed and merged.") + + # Send a stop signal to the stage1 threads to force processing of remaining files. + logger.info("💡 Sending stop signal to stage1 thread to force processing remaining files...") + stop_event.set() + + timeout_counter = 0 + while stage1_thread.is_alive() and timeout_counter < 60: + logger.debug(f"🔄 Waiting for stage1 thread to complete ({timeout_counter}/60 seconds)...") + threading.Event().wait(1) # Wait for 1 second to retry + timeout_counter += 1 + + if stage1_thread.is_alive(): + logger.warning("⚠️ Stage1 thread did not exit on timeout. Anomaly suspected (force-merge of remaining files was attempted).") + else: + logger.info("💡 Stage1 thread has exited normally.") + + # Verify that the number of stage1 files matches (1 stage1 file is merged from every 10 stage0 files; batches with fewer than 10 stage0 files also count as 1 stage1 file) + expected_stage1_count = (len(stage0_files) + STAGE1_CHUNK - 1) // STAGE1_CHUNK + if len(stage1_files) != expected_stage1_count: + logger.warning(f"⚠️ ℹ️ Stage1 file count mismatch! Expected {expected_stage1_count} files, but got {len(stage1_files)} files.") + else: + logger.info(f"✅ Stage1 file count verification passed: {len(stage1_files)} files.") + + # 6. Finally, merge all stage1 files into token_info_1.txt. + if not stage1_files: + logger.warning("⚠️ No stage1 files were generated. Please check if the intermediate processing steps encountered any errors.") + return + + # Count the total data volume of stage1 files. + stage1_total = sum(count_lines(f) for f in stage1_files) + logger.info(f"ℹ️ Starting final merge: {len(stage1_files)} stage1 files, total records: {stage1_total}.") + + # Merge into the final file. + final_path, final_lines = merge_files_by_token(stage1_files, TOKEN_INFO_FILE) + + if final_path and final_lines > 0: + logger.info(f"✅ Final result file generated: {TOKEN_INFO_FILE}, total records: {final_lines}.") + # Verify the total data volume. + if final_lines != total_processed: + logger.error(f"❌ Data volume mismatch! Processed {total_processed} records, but final file contains {final_lines} records.") + else: + logger.info("✅💡 Data volume verification passed, all records have been successfully written to the final file.") + else: + logger.error("❌ Final file merge failed.") + + # Verify the final file again after merge. 
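# Worked example of the stage1 count check above (illustrative numbers only): with
# STAGE1_CHUNK = 20, 45 stage0 files are expected to collapse into ceil(45 / 20) = 3 stage1 files:
#
#     stage0_count = 45
#     STAGE1_CHUNK = 20
#     expected = (stage0_count + STAGE1_CHUNK - 1) // STAGE1_CHUNK   # -> 3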
+ if os.path.exists(TOKEN_INFO_FILE): + final_count = count_lines(TOKEN_INFO_FILE) + logger.info(f"ℹ️ Final result file contains {final_count} records.") + if final_count != total_processed: + logger.error(f"❌ Final file data incomplete! Processed {total_processed} records, but final file contains {final_count} records.") + else: + logger.info("✅ Final file data integrity verification passed.") + + except Exception as e: + logger.error(f"❌ Critical failure in main process: {str(e)}", exc_info=True) + finally: + # Clean up resources. + if processor: + del processor + + # Ensure the stop signal is triggered. + stop_event.set() + + if stage1_thread and stage1_thread.is_alive(): + stage1_thread.join(timeout=2) + + # Wait for the final file to be written. + threading.Event().wait(2) + + # Clean up temporary files (keep the final file). + all_temp_files = stage0_files + stage1_files + for fpath in all_temp_files: + if fpath != str(TOKEN_INFO_FILE) and os.path.exists(fpath): + try: + os.remove(fpath) + logger.debug(f"✅ Temporary file cleaned: {os.path.basename(fpath)}") + except Exception as e: + logger.warning(f"❌ Failed to clean temporary file {os.path.basename(fpath)}: {str(e)}") + + logger.info("✅ Program execution completed.") + + +if __name__ == "__main__": + main() diff --git a/tools/data_preprocess/offline_packing/s1_mmr_sft_data_proc_indcoding.py b/tools/data_preprocess/offline_packing/s1_mmr_sft_data_proc_indcoding.py new file mode 100644 index 0000000..88643e1 --- /dev/null +++ b/tools/data_preprocess/offline_packing/s1_mmr_sft_data_proc_indcoding.py @@ -0,0 +1,215 @@ +import json +import os +import shutil +import logging +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import Manager, cpu_count, Process +from tqdm import tqdm + +# 1)Assign unique numeric IDs to the __img--output for QA data, separate from the numbering for VQA data. + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger() + +# ---------- tool ---------- +def extract_filename_without_ext(image_path: str) -> str: + return os.path.splitext(os.path.basename(image_path))[0] + + + +# ---------------------------- patch 1 ---------------------------- +# feat: Add a thread-safe counter for duplicate names +from collections import defaultdict +import re +import threading + +def _unique_filename(name: str, name_counter, name_lock) -> str: + base, ext = os.path.splitext(name) + with name_lock: + # Use .get() to avoid KeyError. + cnt = name_counter.get(name, 0) + name_counter[name] = cnt + 1 + if cnt == 0: + return name + return f"{base}_{cnt}{ext}" + +# ----------------------------------------------------------------- + + + +# ---------- processing single data item ---------- +def _process_single_item(args): + """ + Thread-level: Process a single data item. + Parameters are packed into a tuple for easy submission to ThreadPoolExecutor. + """ + # item, base_dir, output_dir, rel_img_path, no_img_indices = args + (item, base_dir, output_dir, rel_img_path, no_img_indices, + name_counter, name_lock) = args # patch 6 + + # ---------- Organize the original image paths. 
---------- + original_image_paths = [] + if item.get("images"): + original_image_paths = item["images"] if isinstance(item["images"], list) else [item["images"]] + else: + item["images"] = [] + + if rel_img_path: + original_image_paths = [ + os.path.normpath(os.path.join(base_dir, rel_img_path, p)) + for p in original_image_paths + ] + else: + original_image_paths = [ + os.path.normpath(os.path.join(base_dir, p)) + for p in original_image_paths + ] + + # ---------- This script renames all images to a consistent format and then copies them to the output folder. ---------- + new_image_basenames = [] + for src_path in original_image_paths: + if not os.path.exists(src_path): + logger.warning(f"IMG not found: {src_path}") + continue + old_name = os.path.basename(src_path) + # new_name = _unique_filename(old_name) + new_name = _unique_filename(old_name, name_counter, name_lock) + new_image_basenames.append(new_name) + + dst_path = os.path.join(output_dir, new_name) + try: + shutil.copy2(src_path, dst_path) + except Exception as e: + logger.error(f"Image copy failed: {src_path} -> {dst_path} | {e}") + + # ---------- Update the JSON with new image basenames ---------- + item["images"] = new_image_basenames + + + #--------------patch 001---------- + # ✨ New:If no images exist, return None directly. + if original_image_paths and not new_image_basenames: + logger.info(f"Skip item with no valid images: {item.get('id', item['_orig_index'])}") + return None + #--------------patch 001 end---------- + + # ---------- Generate a filename for the JSON file ---------- + if new_image_basenames: + json_name_root = os.path.splitext(new_image_basenames[0])[0] + else: + idx_in_no_img = no_img_indices.index(item['_orig_index']) + json_name_root = f"__img--output_{idx_in_no_img:08d}" + + # json_name = _unique_filename(json_name_root + ".json") + json_name = _unique_filename(json_name_root + ".json", name_counter, name_lock) + json_path = os.path.join(output_dir, json_name) + try: + with open(json_path, 'w', encoding='utf-8') as f: + json.dump(item, f, indent=4, ensure_ascii=False) + except Exception as e: + logger.error(f"JSON write failed: {json_path} | {e}") + + return os.path.splitext(json_name)[0] + +# ---------- Process-level ---------- +def _worker_process(job_queue, result_list, base_dir, output_dir, + rel_img_path, m, no_img_indices, + name_counter, name_lock): # <-- patch4 + while True: + try: + chunk = job_queue.get_nowait() + except: + break + + logger.info(f"[PID {os.getpid()}] Processing chunk with {len(chunk)} items.") + # Construct the argument list + arg_list = [(item, base_dir, output_dir, rel_img_path, no_img_indices, name_counter, name_lock) + for item in chunk] + + valid_names = [] + with ThreadPoolExecutor(max_workers=m) as pool: + for fut in tqdm(pool.map(_process_single_item, arg_list), + total=len(arg_list), + desc=f"PID-{os.getpid()}", + leave=False): + if fut is not None: # ✨ Filter out None items. 
patch 002 + valid_names.append(fut) + result_list.extend(valid_names) + +# ---------- Main entry ---------- +def split_json_file(fin_name, rel_img_path=None, *, chunk_dim=1000, m=8): + # read json + try: + with open(fin_name, 'r', encoding='utf-8') as f: + data = json.load(f) + except Exception as e: + logger.error(f"JSON read failed: {e}") + return set() + + if not isinstance(data, list): + logger.error("Expected JSON root to be an array") + return set() + + # Log original indices & gather indices missing images + for i, item in enumerate(data): + item['_orig_index'] = i + no_img_indices = [i for i, item in enumerate(data) if not item.get("images")] + + # Prepare directories + base_dir = os.path.dirname(os.path.abspath(fin_name)) + output_dir = os.path.join(base_dir, "split_json_files") + # output_dir = os.path.join(base_dir, "split_json_fs") + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + os.makedirs(output_dir, exist_ok=True) + + # Split data into chunks + total = len(data) + num_chunks = (total + chunk_dim - 1) // chunk_dim + chunks = [data[i * chunk_dim:(i + 1) * chunk_dim] for i in range(num_chunks)] + + max_workers = min(num_chunks, cpu_count()) + logger.info(f"[JOB] Total: {total} | Chunks: {num_chunks} | Procs: {max_workers} | Threads/Proc: {m}") + + + with Manager() as manager: + job_queue = manager.Queue() + for c in chunks: + job_queue.put(c) + + result_list = manager.list() + name_counter = manager.dict() # <-- new patch2 + name_lock = manager.Lock() # <-- new patch2 + + processes = [ + Process(target=_worker_process, + args=(job_queue, result_list, base_dir, + output_dir, rel_img_path, m, no_img_indices, + name_counter, name_lock)) # <-- new patch3 + for _ in range(max_workers) + ] + for p in processes: + p.start() + for p in processes: + p.join() + + all_valid_names = set(result_list) + + logger.info(f"[JOB] All {total} items processed. 
Valid JSON files: {len(all_valid_names)}")
+    return all_valid_names
+
+# ---------- script ----------
+if __name__ == "__main__":
+    # f_json = "/vlm/data/llava_next_500/sampled_data.json"
+    f_json = "/data_1/llava_next_raw_full/megatron_format_780k.json"
+    rel_img = "images"
+    res = split_json_file(
+        f_json,
+        rel_img,
+        chunk_dim=2000,
+        m=8
+    )
+    print(f"Generated {len(res)} files.")
\ No newline at end of file
diff --git a/tools/data_preprocess/offline_packing/s2_prepare_rawsamples-mmr_sft_780k-8k-fast.py b/tools/data_preprocess/offline_packing/s2_prepare_rawsamples-mmr_sft_780k-8k-fast.py
new file mode 100644
index 0000000..22cdec9
--- /dev/null
+++ b/tools/data_preprocess/offline_packing/s2_prepare_rawsamples-mmr_sft_780k-8k-fast.py
@@ -0,0 +1,625 @@
+# ### All the code lives in one place; running just this block is enough
+# Step1:
+# python -u s2_prepare_rawsamples-emova.py 2>&1 | tee s2_proc.log
+# python -u s2_prepare_rawsamples-llava_vqa.py 2>&1 | tee s2_proc_llava.log
+# python -u s2_prepare_rawsamples-vqa_1000k.py 2>&1 | tee ./logs/s2_proc_vqa_1000k.log
+# python -u s2_prepare_rawsamples-vqa_1000k-16k.py 2>&1 | tee ./logs/s2_proc_vqa_1000k-16k.log
+# python -u s2_prepare_rawsamples-vqa_5500k-16k.py 2>&1 | tee ./logs/s2_proc_vqa_5500k-16k.log
+# python -u s2_prepare_rawsamples-vqa_5500k-16k-fast.py 2>&1 | tee ./logs/s2_proc_vqa_5500k-16k-fast.log
+# python -u s2_prepare_rawsamples-vqa_pretrain_5M-8k-fast.py 2>&1 | tee ./logs/s2_proc_vqa_pretrain_5M-8k-fast.log
+
+# python -u s2_prepare_rawsamples-mr_sft_780k-8k-fast.py 2>&1 | tee ./logs/s2_prepare_rawsamples-mr_sft_780k-8k-fast.log
+
+import bisect
+import os
+import re
+import json
+import sys
+import shutil
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+# ### Parameter configuration
+# target_directory = "/workspace/test/packing"  # where the final data is stored
+
+current_file = Path(__file__).resolve()
+target_directory = current_file.parent
+newDir = "raw_packing_data_mr_sft_780k-8k-fast"  # where data is staged before conversion to webdataset (jpg, json)
+SRC_DIR_IMGS = "/data_1/llava_next_raw_full/split_json_files"   # location of the raw img data
+SRC_DIR_JSONS = "/data_1/llava_next_raw_full/split_json_files"  # location of the raw json data
+SRC_DST_EXTENSIONS = ("jpg", "json")
+f_toklens_originalsample = os.path.join(target_directory, "token_info_MR_sft_780k_8k.txt")
+PACKED_LENGTH = 8192
+dst_dir = os.path.join(target_directory, newDir)
+MAX_WORKERS = 96  # thread-pool size (tune for CPU cores and I/O throughput)
+
+
+"""
+task_type settings:
+    sft: pretrain data in VQA format
+    pretrain: pretrain data in caption format
+    bmr: SFT on blended, multi-round dialogue data
+"""
+task_type = "bmr"
+
+
+
+f_TEST=False  # test mode: produce only a small amount of sample output
+n_packed_samples=400  # for testing: how many packed samples to emit
+
+# PROMPTS = # Creating a list of the provided English prompts
+PROMPTS = [
+    "What about this picture?",
+    "Please provide a vivid description of the image.",
+    "Please depict the image in words.",
+    "Could you please transcribe the image into a descriptive paragraph?",
+    "What is the content of this figure?",
+    "What do you see here?",
+    "Tell me about this image.",
+    "What's going on in this artwork?",
+    "What is depicted in this painting?",
+    "What is the subject matter here?",
+    "What can you make out in this picture?",
+    "What's the main thing shown in this image?",
+    "What's the gist of this artwork?",
+    "What's the essence of this figure?",
+    "What's the general idea here?",
+    "What does this image show?",
+    "What's the core element in this painting?",
+    "What's the overview of this scene?",
+    "What's the primary focus of this artwork?",
+    "What's the fundamental subject matter?",
+    "What's the general view presented?",
+    "What's the main impression given by this picture?",
+    "What's the central theme shown?",
+    "What's the overall presentation here?",
+    "What's the key element you notice?",
+    "What's the fundamental concept in this image?",
+    "What's the overall content?",
+    "What's the main thing you get from this?",
+    "What's the general subject?",
+    "What's the core idea conveyed?",
+    "What's the basic representation?",
+    "What's the main point of this figure?"
+]
+
+import random
+
+def find_long_file_pairs(directory, length_threshold=62):
+    """
+    Find the image files among long (img, json) file pairs and return their full names (with image extension).
+
+    Args:
+        directory: directory to scan
+        length_threshold: file-name length threshold, default 62
+
+    Returns:
+        list of matching image file names (with extensions)
+    """
+    import os
+    from collections import defaultdict
+    # Map each file-name stem to all of its full file names
+    file_parts = defaultdict(list)
+    image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp')
+
+    try:
+        # Walk all files in the directory, grouped by file-name stem
+        for filename in os.listdir(directory):
+            name_part, ext = os.path.splitext(filename)
+            ext = ext.lower()
+            # Only look at image and json files
+            if ext in ('.json',) + image_extensions:
+                file_parts[name_part].append(filename)
+
+        # Collect the image files that qualify
+        long_image_files = []
+        for name_part, filenames in file_parts.items():
+            # Check the name length and that the (img, json) pair is complete
+            if (len(name_part) > length_threshold and
+                any(f.endswith('.json') for f in filenames) and
+                any(f.lower().endswith(image_extensions) for f in filenames)):
+
+                # Only add the image files
+                for filename in filenames:
+                    if filename.lower().endswith(image_extensions):
+                        long_image_files.append(filename)
+
+        return long_image_files
+
+    except FileNotFoundError:
+        # print(f"Error: directory '{directory}' does not exist")
+        return []
+    except PermissionError:
+        # print(f"Error: no permission to access directory '{directory}'")
+        return []
+    except Exception as e:
+        # print(f"Error while processing the directory: {str(e)}")
+        return []
+
+
+# res_long_img_names = find_long_file_pairs(SRC_DIR_JSONS)
+
+def filter_filenames(filenames, prefix, exclude_suffix):
+    """
+    Select file names that start with the given prefix and do not end with the given suffix.
+
+    Args:
+        filenames: list of file names
+        prefix: required file-name prefix (e.g. "james-tissot")
+        exclude_suffix: file suffix to exclude (e.g. "json")
+
+    Returns:
+        list of matching file names
+    """
+    # Escape special characters in the prefix so the regex matches it literally
+    escaped_prefix = re.escape(prefix)
+    # Build the regex pattern
+    pattern = rf'^{escaped_prefix}(?!.*\.{exclude_suffix}$).*$'
+
+    # Compile the regex
+    regex = re.compile(pattern)
+
+    # Keep only the matching file names
+    return [filename for filename in filenames if regex.match(filename)]
+
+def get_random_prompts(prompts, n):
+    if n > len(prompts):
+        # duplicates allowed
+        return random.choices(prompts, k=n)
+    else:
+        # no duplicates
+        return random.sample(prompts, n)
+
+# Global variable - stored as a tuple (immutable, cheaper to share)
+BASE_NAMES = []  # starts empty; later replaced by a tuple of all sample names in the original dataset, already sorted by token length
+
+def search_for_fit(numbers: List[int], capacity: int) -> int:
+    """Finds the index of largest number that fits into the knapsack with the given capacity."""
+    index = bisect.bisect(numbers, capacity)
+    return -1 if index == 0 else (index - 1)
+
+def greedy_knapsack(numbers: List[int], capacity: int) -> 
Tuple[List[List[int]], List[List[int]]]: + r"""Implement efficient greedy algorithm with binary search for the knapsack problem. + 参数 + ---- + numbers : List[int] + 物品大小列表,可以为任意顺序(这里是升序输入进来的) + capacity : int + 背包容量 + + 返回 + ---- + Tuple[List[List[int]], List[List[int]]] + 第一个列表:每个背包里的物品大小 + 第二个列表:每个背包里物品对应的原始下标 + + """ + # 保存原始索引,与输入的numbers一一对应 + indexed_numbers = [(val, idx) for idx, val in enumerate(numbers)] + # 由于输入已排序,直接使用即可(保持与原逻辑一致的处理方式) + knapsacks = [] + index_knapsacks = [] + iii = int(0) + while indexed_numbers: + current_knapsack = [] + current_indices = [] + remaining_capacity = capacity + + while True: + # 提取当前数值列表用于查找(保持升序) + current_values = [val for val, idx in indexed_numbers] + index = search_for_fit(current_values, remaining_capacity) + if index == -1: + break # 没有可放入当前背包的物品了 + + # 取出找到的物品及其原始索引 + val, idx = indexed_numbers.pop(index) + remaining_capacity -= val + current_knapsack.append(val) + current_indices.append(idx) + + if iii%1000==0: + print(f"---------第{iii}个pack----------") + print(f"{current_knapsack}--->{sum(current_knapsack)}") + print(current_indices) + print(f"\n") + iii+=1 + knapsacks.append(tuple(current_knapsack)) + index_knapsacks.append(tuple(current_indices)) + + return tuple(knapsacks), tuple(index_knapsacks) # 使用了 tuple + +def extract_content(json_file): + try: + # 打开并加载JSON文件 + with open(json_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + if task_type=="sft": + try: + user_content = next(msg["content"] for msg in data["messages"] if msg["role"] == "assistant") + return user_content + except Exception as e: + pass + # 提取content内容 + # 假设captions数组至少有一个元素 + elif task_type=="pretrain": + if data.get('captions') and len(data['captions']) > 0: + return data['captions'][0].get('content', "") + else: + assert 0, "未找到有效的caption内容" + # return "未找到有效的caption内容" + + except FileNotFoundError: + return f"错误:文件 {json_file} 不存在" + except json.JSONDecodeError: + return f"错误:文件 {json_file} 不是有效的JSON格式" + except Exception as e: + return f"提取过程中发生错误:{str(e)}" + +def extract_prompt(json_file): + try: + # 打开并加载JSON文件 + with open(json_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + # 提取助手回复 + assistant_content = next(msg["content"] for msg in data["messages"] if msg["role"] == "user") + return assistant_content + + # # 提取图片路径(可选) + # image_path = data["images"][0] if data["images"] else None + + except FileNotFoundError: + return f"错误:文件 {json_file} 不存在" + except json.JSONDecodeError: + return f"错误:文件 {json_file} 不是有效的JSON格式" + except Exception as e: + return f"提取过程中发生错误:{str(e)}" + +def extract_img_prompt_content(json_file: str) -> Tuple[List[str], List[str], List[str]]: + try: + # 打开并加载 JSON 文件 + with open(json_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + # 1)images + imgs = data.get("images", []) + if not imgs: + images = [] + else: + images = [os.path.join(SRC_DIR_IMGS,imgs[0])] + + messages = data.get("messages", []) + + assistant_contents = [ + msg["content"] + for msg in messages + if isinstance(msg, dict) and msg.get("role") == "assistant" and "content" in msg + ] + + user_contents = [ + msg["content"] + for msg in messages + if isinstance(msg, dict) and msg.get("role") == "user" and "content" in msg + ] + + return images, user_contents, assistant_contents + + except FileNotFoundError: + return f"错误:文件 {json_file} 不存在" + except json.JSONDecodeError: + return f"错误:文件 {json_file} 不是有效的JSON格式" + except Exception as e: + return f"提取过程中发生错误:{str(e)}" + +def prepare_dirs(target_dir, new_dir): + os.chdir(target_dir) + 
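# A quick, commented illustration of greedy_knapsack above (toy numbers only; token lengths must
# be passed in ascending order, and the returned indices point back into that input list):
#
#     bins, idx_bins = greedy_knapsack([1000, 2500, 3000, 4000, 7000], capacity=8192)
#     # bins     -> ((7000, 1000), (4000, 3000), (2500,))
#     # idx_bins -> ((4, 0), (3, 2), (1,))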
print(f"--------change to directory {target_dir}--------") + # 创建新目录 + if not os.path.exists(new_dir): + os.makedirs(new_dir) + print(f"Directory '{new_dir}' created.") + else: + print(f"Directory '{new_dir}' already exists.") + + +def dataset_tokinfo_generator(f_name): + """ + 数据集token信息生成器,逐行读取并解析文件内容 + + 参数: + f_name (str): 包含token信息的文件路径 + + 生成: + tuple: (base_name, token_len) - 解析后的基础文件名和token长度 + """ + try: + with open(f_name, 'r', encoding='utf-8') as f: + for line in f: + # 跳过空行 + stripped_line = line.strip() + if not stripped_line: + continue + + # 按冒号分割并验证格式 + parts = stripped_line.split(':') + if len(parts) == 2: + base_name = parts[0].strip() + token_len_str = parts[1].strip() + + try: + token_len = int(token_len_str) + yield (base_name, token_len) + except ValueError: + print( + f"警告: 无法将 '{token_len_str}' 转换为整数,已跳过该行", + file=sys.stderr + ) + continue + + except FileNotFoundError: + print(f"错误: 文件 '{f_name}' 不存在", file=sys.stderr) + return + except Exception as e: + print(f"处理文件时发生错误: {str(e)}", file=sys.stderr) + return + + +class TokenInfoReader: + """ + Token信息读取器 + + 支持分批读取、全量读取和断点续读功能,适用于处理包含token信息的文本文件。 + 文件格式要求: 每行一条记录,格式为 "base_name: token_len" + """ + + def __init__(self, f_name): + """ + 初始化读取器 + + 参数: + f_name (str): 包含token信息的文件路径 + """ + self.f_name = f_name + self.generator = dataset_tokinfo_generator(f_name) + self._current_position = 0 # 记录已读取的记录数量 + + def read(self, count=None): + """ + 读取记录 + + 参数: + count (int, optional): 要读取的记录数量,默认为None(读取全部剩余记录) + + 返回: + tuple: (base_names列表, token_lens列表, 实际读取数量) + """ + base_names = [] + token_lens = [] + read_count = 0 + + # 循环读取直到达到指定数量或文件结束 + while True: + # 检查是否已达到指定读取数量 + if count is not None and read_count >= count: + break + + try: + # 从生成器获取下一条记录 + base_name, token_len = next(self.generator) + base_names.append(base_name) + token_lens.append(token_len) + read_count += 1 + self._current_position += 1 + + except StopIteration: + # 已读取到文件末尾 + break + + return base_names, token_lens, read_count + + def get_current_position(self): + """ + 获取当前读取位置 + + 返回: + int: 已读取的记录总数 + """ + return self._current_position + + +def process_knapsack(s1, idx_knapsack, dst_dir): + """ + 处理单个 packing 数据 + + 参数: + s1: 当前处理组的索引 + idx_knapsack: 背包中包含的索引列表 + dst_dir: 目标目录路径 + """ + # global BASE_NAMES + + packed_imgs, packed_caps = [], [] # 单个 packed-sample 的构成 + + # 获取基础文件名 + # packed_b_names = (BASE_NAMES[idx] for idx in idx_knapsack) + packed_b_names = (idx["name"] for idx in idx_knapsack) + + # 构建源文件信息 + if task_type == "pretrain": + packed_info = ( + (os.path.join(SRC_DIR_IMGS, f"{b_name}.{SRC_DST_EXTENSIONS[0]}"), + extract_content(os.path.join(SRC_DIR_JSONS, f"{b_name}.{SRC_DST_EXTENSIONS[1]}"))) + for b_name in packed_b_names + ) + elif task_type == "sft": + packed_info = ( + (os.path.join(SRC_DIR_IMGS, f"{b_name}.{SRC_DST_EXTENSIONS[0]}"), + extract_content(os.path.join(SRC_DIR_JSONS, f"{b_name}.{SRC_DST_EXTENSIONS[1]}")), + extract_prompt(os.path.join(SRC_DIR_JSONS, f"{b_name}.{SRC_DST_EXTENSIONS[1]}"))) + for b_name in packed_b_names + ) + elif task_type == "bmr": + packed_info = ( + extract_img_prompt_content(os.path.join(SRC_DIR_JSONS, f"{b_name}.{SRC_DST_EXTENSIONS[1]}")) + for b_name in packed_b_names + ) + + # 目标JSON文件路径 + json_dst = os.path.join(dst_dir, f"ps_{s1:08d}.{SRC_DST_EXTENSIONS[1]}") + + # 处理每张图片和对应的描述 + if task_type=="pretrain": + for s2, (img_src, cap_src) in enumerate(packed_info): + # 目标图片路径 + img_name_dst = f"ps_{s1:08d}.img{s2:03d}.{SRC_DST_EXTENSIONS[0]}" + # img_name_dst = 
f"img{s2:03d}.{SRC_DST_EXTENSIONS[0]}" # 看后面具体需求决定使用哪一个 + img_dst = os.path.join(dst_dir, img_name_dst) + + # 收集信息 + # packed_imgs.append(img_name_dst) + packed_imgs.append(f"img{s2:03d}.{SRC_DST_EXTENSIONS[0]}") + packed_caps.append(cap_src) + + # 复制图片 + shutil.copyfile(img_src, img_dst) + # 此处也可以调用大模型来生成 提问(对于 纯 captioning 数据) + selected_prompts = get_random_prompts(PROMPTS, len(packed_imgs)) + elif task_type=="sft": + selected_prompts = [] + for s2, (img_src, cap_src, prompt_src) in enumerate(packed_info): + # 目标图片路径 + img_name_dst = f"ps_{s1:08d}.img{s2:03d}.{SRC_DST_EXTENSIONS[0]}" + # img_name_dst = f"img{s2:03d}.{SRC_DST_EXTENSIONS[0]}" # 看后面具体需求决定使用哪一个 + img_dst = os.path.join(dst_dir, img_name_dst) + + # 收集信息 + # packed_imgs.append(img_name_dst) + packed_imgs.append(f"img{s2:03d}.{SRC_DST_EXTENSIONS[0]}") + packed_caps.append(cap_src) + + # 复制图片 + shutil.copyfile(img_src, img_dst) + + # prompts + selected_prompts.append(prompt_src) + pass + elif task_type=="bmr": + selected_prompts = [] + for s2, (img_src, prompt_src, cap_src) in enumerate(packed_info): + if not img_src: + packed_imgs.append([]) + else: + # 目标图片路径 + name, ext = os.path.splitext(img_src[0]) + img_name_dst = f"ps_{s1:08d}.img{s2:03d}{ext}" + img_dst = os.path.join(dst_dir, img_name_dst) + + # 复制图片 + shutil.copyfile(img_src[0], img_dst) + + # 收集 image 信息 + packed_imgs.append([f"img{s2:03d}{ext}"]) + # cnt_imgs += 1 + + # 收集其它信息 + packed_caps.append(cap_src) + selected_prompts.append(prompt_src) + pass + + # 生成JSON文件 + json_data = { + "images": packed_imgs, + "captions": packed_caps, + "prompts": selected_prompts + } + # print(packed_imgs) + + try: + with open(json_dst, 'w', encoding='utf-8') as f: + json.dump(json_data, f, indent=4, ensure_ascii=False) + # json.dump(json_data, f) + except Exception as e: + print(f"线程 {threading.current_thread().name} 生成JSON文件 {json_dst} 失败: {str(e)}") + return s1 + + +if __name__ == "__main__": + ## 1. 创建工作目录 + print("Step1-----------------已创建工作环境-----------------Start") + prepare_dirs(target_directory, newDir) + print("Step1-----------------已创建工作环境-----------------Stop\n\n") + + ## 2. 获取原始数据集信息(没有处理之前) + # 可以用于构建多个 pool,分块 packing(read的参数决定 packing cache size) + print("Step2-----------------读取原ds的 tokenlen 信息-----------------Start") + info_reader = TokenInfoReader(f_toklens_originalsample) + base_names, token_lens, n_count = info_reader.read() + + # global BASE_NAMES + BASE_NAMES=tuple(base_names) + print(f"已读取{n_count}条数据") + # print(BASE_NAMES) + print("Step2-----------------读取原ds的 tokenlen 信息-----------------Stop\n\n") + + # 3. packing分组 + #调用 packing-group 进行分组 + print("Step3-----------------packing 分组-----------------Start") + # knapsacks, idx_knapsacks= greedy_knapsack(token_lens, PACKED_LENGTH) + # print(idx_knapsacks[10]) + # print(knapsacks[10]) + import pickle + def load_bin_boxes(file_path: str): + """ + 加载单步装箱结果 + """ + with open(file_path, 'rb') as f: + bin_boxes = pickle.load(f) + print(f"已加载装箱结果: {file_path}") + return bin_boxes + + # bin_boxs = load_bin_boxes("./s2_ckpt/bins_boxs_8k.pkl") + bin_boxs = load_bin_boxes("./s2_ckpt/bins_boxs_mr_sft_8k.pkl") + + # total_knapsacks = len(idx_knapsacks) + total_knapsacks = len(bin_boxs) + + print(f"原始数据----{n_count}----条,packing后变为----{total_knapsacks}----条") + print("Step3-----------------packing 分组-----------------Stop\n\n") + + print("Step4----------------- 开始构建新数据集 -----------------Start") + print(f"开始处理 {total_knapsacks} 组数据,使用 {MAX_WORKERS} 个线程") + + #4. 
Process all packs with a thread pool
+    with ThreadPoolExecutor(max_workers=MAX_WORKERS, thread_name_prefix="PackThread") as executor:
+        # Submit all tasks
+        if f_TEST:
+            futures = {
+                executor.submit(process_knapsack, s1, idx_knapsack, dst_dir): s1
+                for s1, idx_knapsack in enumerate(bin_boxs[0:n_packed_samples])
+            }
+        else:
+            futures = {
+                executor.submit(process_knapsack, s1, idx_knapsack, dst_dir): s1
+                for s1, idx_knapsack in enumerate(bin_boxs)
+            }
+
+        # tqdm tracks completed tasks automatically
+        from tqdm import tqdm
+        tty = open(os.devnull, 'w') if os.name == 'nt' else open('/dev/tty', 'w')
+        for future in tqdm(as_completed(futures),
+                           total=len(futures),
+                           desc="Packing progress",
+                           unit="pack",
+                           file=tty
+                           ):
+            try:
+                future.result()
+            except Exception as e:
+                s1 = futures[future]
+                print(f"\nError while processing pack {s1}: {e}")
+
+    print("Step4-----------------Successful!!!!---- New dataset built -----------------Stop")
diff --git a/tools/data_preprocess/offline_packing/s3_test_mmr_sft_780k-8k.sh b/tools/data_preprocess/offline_packing/s3_test_mmr_sft_780k-8k.sh
new file mode 100644
index 0000000..4b1dc48
--- /dev/null
+++ b/tools/data_preprocess/offline_packing/s3_test_mmr_sft_780k-8k.sh
@@ -0,0 +1,6 @@
+# Make adjustments according to the actual data.
+OUT_WDS_DIR='/vlm/data/offline_paclking_datasets/bmr_sft_780k-8k'
+IN_SAMPLE_DIR='/workspace/data4packing/RiceVL/data_procs/raw_packing_data_mr_sft_780k-8k-fast'
+PY_EXE="/workspace/AIAK-Training-LLM/tools/data_preprocess/convert_packedsample_to_wds.py"
+
+python -u ${PY_EXE} --output_dir ${OUT_WDS_DIR} --json_file ${IN_SAMPLE_DIR} --video_dir ${IN_SAMPLE_DIR} --image_dir ${IN_SAMPLE_DIR} --mode bmr_pack --maxcount 5000 2>&1 | tee ./logs/s3_proc_mr_sft_780k-8k.log
\ No newline at end of file
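# Optional sanity check after step 3 (illustrative sketch only; the shard pattern follows the
# 'pretrain-%06d.tar' naming used by convert_packedsample_to_wds.py, and the output directory is
# the one configured in s3_test_mmr_sft_780k-8k.sh above - adjust both to your environment):
import glob
import json
import webdataset as wds

# Count packed samples and verify that prompts and captions stay aligned per record.
shards = sorted(glob.glob("/vlm/data/offline_paclking_datasets/bmr_sft_780k-8k/pretrain-*.tar"))
n_samples = 0
for sample in wds.WebDataset(shards):
    payload = json.loads(sample["json"])
    assert len(payload["prompts"]) == len(payload["captions"]), sample["__key__"]
    n_samples += 1
print(f"{n_samples} packed samples across {len(shards)} shards")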