@@ -574,74 +574,79 @@ def next_rand(self):
 
     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
         while True:
-            sources = self.list_data_dict[i]
-            if isinstance(i, int):
-                sources = [sources]
-            assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
-            if 'image' in sources[0]:
-                image_file = self.list_data_dict[i]['image']
-
-                image_folder = self.data_args.image_folder
-                processor = self.data_args.image_processor
-                from pathlib import Path
-                #if not Path(os.path.join(image_folder, image_file)).exists():
-                #    i = self.next_rand()
-                #    continue
-                if isinstance(image_file, list):
-                    # Multiple Images as Input
-                    try:
-                        image = [Image.open(os.path.join(image_folder, imfile)).convert('RGB') for imfile in image_file]
-                    except Exception as ex:
-                        print(ex)
-                        i = self.next_rand()
-                        continue
-                    if self.data_args.image_aspect_ratio == 'pad':
-                        image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
-                        image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+            try:
+                sources = self.list_data_dict[i]
+                if isinstance(i, int):
+                    sources = [sources]
+                assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
+                if 'image' in sources[0]:
+                    image_file = self.list_data_dict[i]['image']
+
+                    image_folder = self.data_args.image_folder
+                    processor = self.data_args.image_processor
+                    from pathlib import Path
+                    #if not Path(os.path.join(image_folder, image_file)).exists():
+                    #    i = self.next_rand()
+                    #    continue
+                    if isinstance(image_file, list):
+                        # Multiple Images as Input
+                        try:
+                            image = [Image.open(os.path.join(image_folder, imfile)).convert('RGB') for imfile in image_file]
+                        except Exception as ex:
+                            print(ex)
+                            i = self.next_rand()
+                            continue
+                        if self.data_args.image_aspect_ratio == 'pad':
+                            image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
+                            image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+                        else:
+                            image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+                    elif os.path.join(image_folder, image_file).endswith("mp4"):
+                        # Video as Input
+                        image = load_video(os.path.join(image_folder, image_file))
+                        if self.data_args.image_aspect_ratio == 'pad':
+                            image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
+                            image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+                        else:
+                            image = processor.preprocess(image, return_tensors='pt')['pixel_values']
                     else:
-                        image = processor.preprocess(image, return_tensors='pt')['pixel_values']
-                elif os.path.join(image_folder, image_file).endswith("mp4"):
-                    # Video as Input
-                    image = load_video(os.path.join(image_folder, image_file))
-                    if self.data_args.image_aspect_ratio == 'pad':
-                        image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
-                        image = processor.preprocess(image, return_tensors='pt')['pixel_values']
-                    else:
-                        image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+                        try:
+                            image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+                        except Exception as ex:
+                            print(ex)
+                            i = self.next_rand()
+                            continue
+                        if self.data_args.image_aspect_ratio == 'pad':
+                            image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
+                            image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+                        else:
+                            image = processor.preprocess(image, return_tensors='pt')['pixel_values']
+                    sources = preprocess_multimodal(
+                        copy.deepcopy([e["conversations"] for e in sources]),
+                        self.data_args)
                 else:
-                    try:
-                        image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
-                    except Exception as ex:
-                        print(ex)
-                        i = self.next_rand()
-                        continue
-                    if self.data_args.image_aspect_ratio == 'pad':
-                        image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
-                        image = processor.preprocess(image, return_tensors='pt')['pixel_values']
-                    else:
-                        image = processor.preprocess(image, return_tensors='pt')['pixel_values']
-                sources = preprocess_multimodal(
-                    copy.deepcopy([e["conversations"] for e in sources]),
-                    self.data_args)
-            else:
-
-                sources = copy.deepcopy([e["conversations"] for e in sources])
-            data_dict = preprocess(
-                sources,
-                self.tokenizer,
-                has_image=('image' in self.list_data_dict[i]))
-            if isinstance(i, int):
-                data_dict = dict(input_ids=data_dict["input_ids"][0],
-                                 labels=data_dict["labels"][0])
-
-            # image exist in the data
-            if 'image' in self.list_data_dict[i]:
-                data_dict['image'] = image
-            elif self.data_args.is_multimodal:
-                # image does not exist in the data, but the model is multimodal
-                crop_size = self.data_args.image_processor.crop_size
-                data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
-            return data_dict
+
+                    sources = copy.deepcopy([e["conversations"] for e in sources])
+                data_dict = preprocess(
+                    sources,
+                    self.tokenizer,
+                    has_image=('image' in self.list_data_dict[i]))
+                if isinstance(i, int):
+                    data_dict = dict(input_ids=data_dict["input_ids"][0],
+                                     labels=data_dict["labels"][0])
+
+                # image exist in the data
+                if 'image' in self.list_data_dict[i]:
+                    data_dict['image'] = image
+                elif self.data_args.is_multimodal:
+                    # image does not exist in the data, but the model is multimodal
+                    crop_size = self.data_args.image_processor.crop_size
+                    data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
+                return data_dict
+            except Exception as ex:
+                print(ex)
+                i = self.next_rand()
+                continue
 
 
 @dataclass
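In effect, the commit wraps the entire body of __getitem__ in a try/except so that any failure while building a sample (unreadable image, missing file, bad annotation, tokenization error) resamples a new index via next_rand() instead of crashing the DataLoader worker. The sketch below is a minimal, self-contained illustration of that retry pattern only; the SafeDataset, _load, and _next_rand names are hypothetical and not code from this repository, and the real next_rand() may differ.

import random
from torch.utils.data import Dataset

class SafeDataset(Dataset):
    # Illustrative only: retry with a random index when a sample fails to load,
    # mirroring the try/except + next_rand() pattern introduced in the diff above.
    def __init__(self, records):
        self.records = records

    def __len__(self):
        return len(self.records)

    def _next_rand(self):
        # Hypothetical counterpart of next_rand(): pick another index uniformly at random.
        return random.randint(0, len(self.records) - 1)

    def _load(self, record):
        # Placeholder for the real image loading / preprocessing / tokenization.
        return record

    def __getitem__(self, i):
        while True:
            try:
                return self._load(self.records[i])  # may raise on corrupt or missing data
            except Exception as ex:
                print(ex)                           # log the failure
                i = self._next_rand()               # fall back to a different sample

One trade-off of this design: a bad sample is silently replaced by a random one, so persistent data problems only surface as printed exceptions rather than as training failures.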