4 changes: 4 additions & 0 deletions aiak_training_llm/data/chat_templete.py
@@ -11,6 +11,7 @@
"""

import re
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Type, Dict, List, Optional, Sequence, Set, Tuple, Union
@@ -133,6 +134,9 @@ def encode_multiturn(
messages = messages[1:]

encoded_messages = self._encode(tokenizer, messages, system)
# ZXW: upstream preprocessing drops the trailing "\n" (token 198); re-append it when enabled
if int(os.environ.get("FILL_TOKEN_198", 0)) == 1:
encoded_messages[-1].append(198)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

def encode_oneturn(
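A quick hedged check of the assumption behind FILL_TOKEN_198 — that id 198 decodes to "\n" in the Qwen2-style tokenizer in use (the checkpoint path below is a placeholder):

import os
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("/path/to/qwen2_tokenizer")  # placeholder path
print(repr(tok.decode([198])))  # expected: '\n' for GPT-2/Qwen2-style byte-level BPE vocabularies

os.environ["FILL_TOKEN_198"] = "1"  # opt in to the trailing-newline fix above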
20 changes: 17 additions & 3 deletions aiak_training_llm/data/multimodal/flavors/packed_captioning.py
@@ -9,6 +9,20 @@
class PackedCaptioningSample(Sample):
"""Sample type for packed captioning."""
# sample_id: str
images: List[torch.Tensor]
prompts: Optional[List[str]]
captions: List[str]
images: Union[
str, # A single image path, e.g., 'img001.jpg'
torch.Tensor, # A single image tensor
List[str], # A list of image paths, e.g., ['imgs001.jpg', 'imgs002.png']
List[List[str]], # A nested list of image paths, e.g., [['imgs001.tif'], [], ['imgs005.png']]
List[torch.Tensor], # A list of image tensors
List[List[torch.Tensor]]
]
prompts: Union[
Optional[List[str]],
List[List[str]]
]
captions: Union[
List[str],
List[List[str]]
]
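The widened annotations accept both the flat (pretrain) and nested (multi-round) layouts. A minimal shape sketch with hypothetical data, mirroring the BMR JSON format documented later in this PR:

from typing import List

images: List[List[str]] = [["img000.jpg"], [], ["img002.png"]]  # one (possibly empty) image list per record
prompts: List[List[str]] = [["describe this fig"], ["what about this fig"], ["How are you?", "I am fine too."]]
captions: List[List[str]] = [["str1"], ["str2"], ["I am fine and you?", "Have a nice day"]]

# Invariant the task encoder relies on: the three fields stay index-aligned.
assert len(images) == len(prompts) == len(captions)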

33 changes: 33 additions & 0 deletions aiak_training_llm/data/multimodal/task_encoder.py
@@ -188,6 +188,39 @@ def encode_sample(self, sample: Union[CaptioningSample, OCRSample, VQASample, Si
context=sample.prompts[idx]
)
l_Qwen2VLImageTaskSample.append(self.encode_vqa4packing(cur_capsample))
elif int(os.environ.get("OFFLINE_PACKING_BMR", 0)) == 1:
def convert_to_messages(cur_prompt, cur_caption):
"""Convert parallel (cur_prompt, cur_caption) lists into a chat messages list."""

if len(cur_prompt) != len(cur_caption):
raise ValueError("cur_prompt & cur_caption have different lengths")

messages = []
for prompt, caption in zip(cur_prompt, cur_caption):
messages.append({
"content": prompt,
"role": "user"
})

messages.append({
"content": caption,
"role": "assistant"
})

return messages
cur_capsample = MultiMixQASample(
__key__=f"{sample.__key__}.img{idx:03d}_jpg",
__restore_key__=sample.__restore_key__,
__subflavor__='BMR',
__subflavors__=sample.__subflavors__,
messages=convert_to_messages(sample.prompts[idx], sample.captions[idx]),
video=None,
system=None,
image=[sample.images[idx]] if sample.images[idx] else None
)
l_Qwen2VLImageTaskSample.append(self.encode_multi_mix_qa(cur_capsample))
else:
cur_capsample = CaptioningSample(
__key__=f"{sample.__key__}.img{idx:03d}_jpg",
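For reference (illustration only — the helper is nested inside encode_sample), the message layout convert_to_messages produces for a hypothetical single-turn record:

convert_to_messages(["How are you?"], ["I am fine, and you?"])
# -> [{'content': 'How are you?', 'role': 'user'},
#     {'content': 'I am fine, and you?', 'role': 'assistant'}]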
@@ -0,0 +1,57 @@
# Data path configuration
data:
# Directory of data samples
directory: "/data_1/llava_next_raw_full/split_json_files/"
# Temporary file for storing paired filenames
output_base: "base_name_v4_MR_sft_780k_8k.txt"
# Final output file (includes token length information)
output_token: "token_info_MR_sft_780k_8k.txt"

# Model path
model:
checkpoint: "/vlm/xiangan/pretrain_models/rice_vl/rice_vl_rice_300m_qwen2.5_7b_adapter_v1_fixed_tokenizer_huggingface"

sample:
# Maximum length for training data
max_len: 8192
# Whether to drop one token (the current preprocessing step is missing one token; this flag decides whether alignment is needed)
# Used together with the FILL_TOKEN_198 environment variable:
#   false: set FILL_TOKEN_198=1 (achieves faster convergence)
#   true:  set FILL_TOKEN_198=0
del_one_token: false
# Decides the parsing method
task_type: sft
max_prompt: null
max_answer: null

# Image processing parameters
image:
baidu_resolution: 1600 # baidu code's limit parameter (null)
min_pixels: 3136 # 4*28*28
max_pixels: 4014080 # 5120*28*28 = 4014080 (used with max_len 8192)
# Maximum aspect ratio limit (images exceeding this value will be filtered out)
max_aspect_ratio: 200

# Parallel processing parameters
processing:
# Number of samples each process handles
chunk_size: 5000
# Merge parameter (sorting), merge N stage0 files into 1 stage1 file
stage1_merge_chunk: 20
n_workers: 64
# Minimum number of threads in the thread pool
min_workers: 10
# Maximum number of threads in the thread pool
max_workers: 32
# Timeout in seconds (scales with data size; 1M samples estimated at ~45 minutes, i.e. 2700 s)
time_out: 20000

# Logging and temporary file configuration
logging:
# Log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
level: "INFO"
# Log file path
file: "./logs/s1_processing_MR_sft_780k_8k.log"
# Whether to use /dev/shm as the temporary directory
use_shm: false
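
A minimal sketch (assuming PyYAML and a hypothetical config filename) of wiring del_one_token to the FILL_TOKEN_198 environment variable as the comments above prescribe:

import os
import yaml  # PyYAML assumed

with open("s1_config_MR_sft_780k_8k.yaml") as f:  # hypothetical filename
    cfg = yaml.safe_load(f)

# del_one_token=false pairs with FILL_TOKEN_198=1, del_one_token=true with 0.
os.environ["FILL_TOKEN_198"] = "0" if cfg["sample"]["del_one_token"] else "1"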

@@ -6,9 +6,10 @@
├── ps_00000000.img000.jpg
├── ps_00000000.img001.jpg
├── ps_00000000.json

...

JSON format:
JSON format (pretrain):
{
"images": ["img000.jpg", "img001.jpg", ...],
"prompt": ["描述", "what about", ""],
@@ -17,9 +18,20 @@
One JSON + its corresponding JPGs = 1 tar record
"""

######-----------------------------------------######
######-----------------------------------------######
######-----------------------------------------######
"""

JSON format (multi-round + blended data for SFT):
{
"images": [["img000.jpg"], ["img001.jpg"], []],
"prompt": [["描述这幅图"], ["what about this fig"], ["How are you?", "I am fine too."]],
"captions": [["stri"], ["str2"], ["I am fine and you?", "Have a nice day"]]
}
One JSON + its corresponding JPGs (may be 0) = 1 tar record
"""

# #####-----------------------------------------######
# #####-----------------------------------------######
# #####-----------------------------------------######

import argparse
import uuid
@@ -57,8 +69,8 @@ def sample_loader_template(media: str=None):
"def part_filter(part: str) -> bool:",
" return True",
])
### ZXW

# ## ZXW

def sample_loader_template_caption(media=None):
"""适配整条多图 captioning 的 loader"""
@@ -107,7 +119,25 @@ def construct_sample_caption(args, entry):
sample["json"] = json.dumps(payload, ensure_ascii=False).encode("utf-8")
return sample

### ZXW
def construct_bmr_sample(args, entry):
"""Package a multi-round & blended-data record into a webdataset sample"""
sample = {"__key__": entry["id"]}
for idx, img_name in enumerate(entry["images"]):
img_path = os.path.join(args.image_dir, f"{entry['id']}.{img_name[0]}") if img_name else None
if img_name:
with open(img_path, "rb") as f:
sample[f"img{idx}.jpg"] = f.read()
payload = {
"prompts": entry["prompts"],
"captions": entry["captions"],
"images": entry["images"]
}

sample["json"] = json.dumps(payload, ensure_ascii=False).encode("utf-8")
return sample

# ## ZXW

def construct_sample(args, vision, path, entry):
""" construct webdataset sample """
@@ -139,7 +169,7 @@ def convert_to_wds(args):

tar = os.path.join(args.output_dir, 'pretrain-%06d.tar')
if args.mode == "caption_pack":
# New mode
# New mode 1
with wds.ShardWriter(tar, maxcount=args.maxcount, maxsize=args.maxsize) as sink:
for entry in tqdm(stream_samples_caption(args.json_file)):
sample = construct_sample_caption(args, entry)
@@ -150,7 +180,21 @@

write_config(EPath(args.output_dir).absolute(), args.media,
template_func=sample_loader_template_caption,
class_name="PackedCaptioningSample")
class_name="PackedCaptioningSample")
elif args.mode == "bmr_pack":
# New mode 2
with wds.ShardWriter(tar, maxcount=args.maxcount, maxsize=args.maxsize) as sink:
for entry in tqdm(stream_samples_caption(args.json_file)):
sample = construct_bmr_sample(args, entry)
sink.write(sample)

write_config(EPath(args.output_dir).absolute(), args.media,
template_func=sample_loader_template_caption,
class_name="PackedCaptioningSample")
print("Dataset successfully converted to wds")


@@ -201,7 +245,7 @@ def _add_arguments(parser: argparse.ArgumentParser):
group.add_argument('--columns_messages', type=str, default="messages", help='Column name for messages')
# Added mode selection
group.add_argument('--mode', type=str,
choices=["chat", "caption_pack"],
choices=["chat", "caption_pack", "bmr_pack"],
default="chat",
help="chat=旧格式(单图对话); caption_pack=新格式(整条多图caption)")
return parser
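A hedged example invocation for the new mode (script name and paths are hypothetical; the flags mirror the attributes used in convert_to_wds):

python convert_to_webdataset.py \
    --mode bmr_pack \
    --json_file /data/bmr_records.json \
    --image_dir /data/wds_src \
    --output_dir /data/wds_out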
32 changes: 32 additions & 0 deletions tools/data_preprocess/offline_packing/hashbacket.py
@@ -2108,9 +2108,14 @@ class PackingTracker:
def __init__(self, processor):
self.processor = processor
self.history = []
self.snapshots = []  # state snapshots added (2025.09.12)

def track_packing(self, strategy_name: str, **kwargs):
"""记录一次装箱操作"""

# 保存状态快照 (2025.09.12)
self.save_current_state()

before_state = self.processor.check_hash_buckets_state()
# Supports returning detailed stats (e.g. total_attempts); otherwise only the list of boxes is returned
result = getattr(self.processor, strategy_name)(**kwargs)
@@ -2147,6 +2152,33 @@ def print_summary(self):
print(f"成功率: {rate:.1%} ({op['boxes_count']}/{op['total_attempts']})")
else:
print(f"成功率: N/A")

# (2025.09.12)
def save_current_state(self):
"""保存当前状态快照"""
snapshot = {
'hash_buckets': {k: arr.copy() for k, arr in self.processor.hash_buckets.items()},
'timestamp': time.time()
}
self.snapshots.append(snapshot)  # the append itself completes the save
print("save checkpoint........")

# (2025.09.12)
def restore_state(self, index: int):
"""恢复到指定操作前的状态"""
if 0 <= index < len(self.snapshots):
snapshot = self.snapshots[index]
self.processor.hash_buckets = {
k: arr.copy() for k, arr in snapshot['hash_buckets'].items()
}
# Drop all snapshots and history entries recorded after this index
self.snapshots = self.snapshots[:index+1]
self.history = self.history[:index]
print(f"Restored to the state before operation {index}")
return True
return False


# # Usage example
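# # A hedged sketch; the processor object and strategy names below are hypothetical:
# tracker = PackingTracker(processor)        # processor exposes hash_buckets and packing strategies
# tracker.track_packing("greedy_pack")       # snapshot 0 is saved before the operation runs
# tracker.track_packing("best_fit_pack")     # snapshot 1
# tracker.restore_state(1)                   # roll hash_buckets back to just before operation 1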