Skip to content

Commit af4904e

Browse files
committed
maxwell ft
1 parent ca43543 commit af4904e

File tree

4 files changed

+113
-66
lines changed

4 files changed

+113
-66
lines changed

playground/data/ft/maxwell/test_split_1.json

+1
Large diffs are not rendered by default.

playground/data/ft/maxwell/train_split_1.json

+1
Large diffs are not rendered by default.

q_align/train/train_mem.py

+71-66
Original file line numberDiff line numberDiff line change
@@ -574,74 +574,79 @@ def next_rand(self):
574574

575575
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
    """Return one tokenized training sample, resampling on any failure.

    Loads the sample at index ``i`` from ``self.list_data_dict``. If the
    sample references an image (single path, list of paths, or an ``mp4``
    video), the visual input is loaded, optionally padded to a square, and
    run through the image processor. Conversations are tokenized via
    ``preprocess`` / ``preprocess_multimodal``.

    Any exception during loading/preprocessing (corrupt file, missing
    path, ...) is printed and a different random index is drawn via
    ``self.next_rand()``; the loop retries until a sample succeeds.

    Returns:
        dict with ``input_ids`` and ``labels`` tensors, plus ``image``
        (the processed pixel values, or an all-zero placeholder of shape
        (3, crop_h, crop_w) for text-only samples in a multimodal model).
    """
    while True:
        # NOTE(review): this broad except also swallows the assert below and
        # any programming error, silently resampling instead of failing —
        # kept because the training loop relies on skip-and-retry for bad
        # samples, but worth narrowing to I/O errors eventually.
        try:
            sources = self.list_data_dict[i]
            if isinstance(i, int):
                sources = [sources]
            assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
            if 'image' in sources[0]:
                image_file = self.list_data_dict[i]['image']

                image_folder = self.data_args.image_folder
                processor = self.data_args.image_processor
                if isinstance(image_file, list):
                    # Multiple Images as Input
                    try:
                        image = [Image.open(os.path.join(image_folder, imfile)).convert('RGB') for imfile in image_file]
                    except Exception as ex:
                        print(ex)
                        i = self.next_rand()
                        continue
                    if self.data_args.image_aspect_ratio == 'pad':
                        # Pad each frame to a square using the processor's mean color.
                        image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
                    image = processor.preprocess(image, return_tensors='pt')['pixel_values']
                elif os.path.join(image_folder, image_file).endswith("mp4"):
                    # Video as Input: decode frames, then treat them like an image list.
                    image = load_video(os.path.join(image_folder, image_file))
                    if self.data_args.image_aspect_ratio == 'pad':
                        image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
                    image = processor.preprocess(image, return_tensors='pt')['pixel_values']
                else:
                    # Single image as input.
                    try:
                        image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
                    except Exception as ex:
                        print(ex)
                        i = self.next_rand()
                        continue
                    if self.data_args.image_aspect_ratio == 'pad':
                        image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
                    image = processor.preprocess(image, return_tensors='pt')['pixel_values']
                sources = preprocess_multimodal(
                    copy.deepcopy([e["conversations"] for e in sources]),
                    self.data_args)
            else:
                sources = copy.deepcopy([e["conversations"] for e in sources])
            data_dict = preprocess(
                sources,
                self.tokenizer,
                has_image=('image' in self.list_data_dict[i]))
            if isinstance(i, int):
                # Unwrap the batch dimension added by preprocess.
                data_dict = dict(input_ids=data_dict["input_ids"][0],
                                 labels=data_dict["labels"][0])

            # Image exists in the data.
            if 'image' in self.list_data_dict[i]:
                data_dict['image'] = image
            elif self.data_args.is_multimodal:
                # Image does not exist in the data, but the model is multimodal:
                # supply an all-zero placeholder so collation shapes stay uniform.
                crop_size = self.data_args.image_processor.crop_size
                data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
            return data_dict
        except Exception as ex:
            print(ex)
            i = self.next_rand()
            continue
645650

646651

647652
@dataclass

scripts/maxwell-officialsplit-lora.sh

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#!/bin/bash
# LoRA fine-tuning of q-future/one-align on the MaxWell official split(s).
# Iterates over split indices (currently just split 1) and launches one
# deepspeed training run per split, writing to ./q-align-maxwell-lora-$i.
LOAD='q-future/one-align'

for i in $(seq 1 1)
do
    echo "Split $i"
    DATA_FILE="playground/data/ft/maxwell/train_split_$i.json"
    deepspeed --master_port 25801 q_align/train/train_mem.py \
        --deepspeed ./scripts/zero3.json \
        --lora_enable True --visual_abstractor_lr 2e-5 \
        --model_name_or_path "$LOAD" \
        --version v1 \
        --data_path "$DATA_FILE" \
        --image_folder ../datasets/MaxWell \
        --image_aspect_ratio pad \
        --group_by_modality_length True \
        --bf16 True \
        --output_dir "./q-align-maxwell-lora-$i" \
        --num_train_epochs 5 \
        --per_device_train_batch_size 4 \
        --per_device_eval_batch_size 4 \
        --gradient_accumulation_steps 8 \
        --evaluation_strategy "no" \
        --save_strategy "steps" \
        --save_steps 800 \
        --save_total_limit 3 \
        --learning_rate 2e-4 \
        --weight_decay 0. \
        --warmup_ratio 0.03 \
        --lr_scheduler_type "cosine" \
        --logging_steps 1 \
        --tf32 True \
        --model_max_length 2048 \
        --gradient_checkpointing True \
        --tune_visual_abstractor True \
        --freeze_vision_model False \
        --dataloader_num_workers 4 \
        --lazy_preprocess True \
        --report_to wandb
done

0 commit comments

Comments
 (0)