@@ -68,7 +68,11 @@ def add_segmentation_and_position(x, data_columns, padding_token=0):
 
 def reformat_prompt(example, column, image_placeholder, model_name):
   """reformat prompt for multimodal SFT"""
-  example[column] = multimodal_utils.reformat_prompt(example[column], image_placeholder, model_name)
+  if isinstance(example["images"], list):
+    num_images = len(example["images"])
+  else:
+    num_images = 1
+  example[column] = multimodal_utils.reformat_prompt(example[column], image_placeholder, model_name, num_images)
   return example
 
 
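For illustration, a minimal sketch of the new image-count branch. The column name, placeholder string, and dummy arrays below are made up and not part of this change; the count is presumably forwarded so that multimodal_utils.reformat_prompt can expand the placeholder once per image.

    import numpy as np

    # Illustrative only: two dummy images and a prompt column.
    example = {
        "images": [np.zeros((8, 8, 3), dtype=np.uint8), np.zeros((8, 8, 3), dtype=np.uint8)],
        "prompt": "<image> <image> Describe both pictures.",
    }
    num_images = len(example["images"]) if isinstance(example["images"], list) else 1
    assert num_images == 2  # this count is now passed to multimodal_utils.reformat_prompt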
@@ -80,11 +84,19 @@ def reformat_response(example, column, model_name):
 
 def pre_process_image_sft(example, image_column, model_name):
   """pre-process image for multimodal SFT"""
-  image = multimodal_utils.convert_to_RGB(example[image_column])
-  # TODO(aireenmei, hengtaoguo): add support for different image sizes
-  image = multimodal_utils.resize_image(image, model_name)
-  image = np.array(image)
-  example[image_column] = multimodal_utils.pre_process_image(image, model_name)
+
+  def _process_image_fn(image):
+    image = multimodal_utils.convert_to_RGB(image)
+    # TODO(aireenmei, hengtaoguo): add support for different image sizes
+    image = multimodal_utils.resize_image(image, model_name)
+    image = np.array(image)
+    image = multimodal_utils.pre_process_image(image, model_name)
+    return image
+
+  if isinstance(example[image_column], list):
+    example[image_column] = [_process_image_fn(img) for img in example[image_column]]
+  else:
+    example[image_column] = _process_image_fn(example[image_column])
   return example
 
 
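The preprocessing is now factored into a per-image helper that is mapped over the list when an example holds several images. A rough self-contained sketch of that list-vs-single handling, with a toy resize standing in for the model-specific multimodal_utils calls (Pillow assumed available; sizes are arbitrary):

    from PIL import Image
    import numpy as np

    def _toy_process(image, size=(224, 224)):
      # Stand-in for _process_image_fn: RGB conversion, resize, array conversion.
      return np.asarray(image.convert("RGB").resize(size))

    imgs = [Image.new("L", (32, 48)), Image.new("RGB", (64, 64))]
    processed = [_toy_process(im) for im in imgs]      # list in -> list out
    single = _toy_process(Image.new("RGB", (16, 16)))  # single-image path unchanged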
@@ -93,7 +105,10 @@ def prepare_text_for_image_fusion(example, column_name, model_name):
   example[column_name] = multimodal_utils.prepare_text_for_image_fusion(
       example[column_name], model_name, processor_output=example["images"]
   )
-  example["images"] = example["images"].pixel_values
+  if isinstance(example["images"], list):
+    example["images"] = [image.pixel_values for image in example["images"]]
+  else:
+    example["images"] = example["images"].pixel_values
   return example
 
 
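The same list-vs-single pattern unpacks processor outputs into raw pixel_values. An illustrative sketch with stand-in objects; SimpleNamespace only mimics an object exposing a .pixel_values attribute, and the shapes are arbitrary:

    import numpy as np
    from types import SimpleNamespace

    outputs = [SimpleNamespace(pixel_values=np.zeros((3, 224, 224))) for _ in range(2)]
    if isinstance(outputs, list):
      pixel_values = [o.pixel_values for o in outputs]
    else:
      pixel_values = outputs.pixel_values
    assert len(pixel_values) == 2 and pixel_values[0].shape == (3, 224, 224)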
@@ -400,58 +415,58 @@ def map(self, element):
 
 @dataclasses.dataclass
 class PadOrTrimToMaxLength(grain.MapTransform):
-  """Pads/Trims each input to the specified length
-  and returns true_length of input
-  """
-
-  def __init__(self, max_length):
-    self.max_length = max_length
-
-  def map(self, element: dict[str, np.ndarray]):
-    """map to each element"""
-
-    def _pad(x, max_length):
-      pad_amount = max(max_length - x.shape[0], 0)
-      pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
-      return np.pad(x, pad_amount)[:max_length]
-
-    data_columns = list(element.keys())
-    for data_column in data_columns:
-      element[f"{data_column}_segmentation"] = (element[data_column] != 0).astype(np.int32)
-      element[f"{data_column}_position"] = np.arange(element[data_column].shape[0], dtype=np.int32)
-      element[f"{data_column}_true_length"] = np.array([element[data_column].shape[0]], dtype=np.int32)
-    for key, _ in element.items():
-      if "true_length" not in key:
-        element[key] = _pad(element[key], self.max_length)
-    # for data_column in data_columns:
-    #   data[f"{data_column}_true_length"] = _max_true_length(data[data_column], 0)
-    return element
-
418+ """Pads or trims each input to the specified length.
419+ And optionally add true length for the input."""
 
-@dataclasses.dataclass
-class PadToMaxLength(grain.MapTransform):
-  """Pads each input to the specified length"""
-
-  def __init__(self, max_length, pad_id):
+  def __init__(self, max_length, pad_id=0, model_name=None, add_true_length=False, max_num_images_per_example=-1):
     self.max_length = max_length
     self.pad_id = pad_id
+    self.model_name = model_name
+    self.add_true_length = add_true_length
+    self.max_num_images_per_example = max_num_images_per_example
+
+  def _pad_text(self, x, max_length, pad_id):
+    pad_amount = max(max_length - x.shape[0], 0)
+    pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
+    return np.pad(x, pad_amount, constant_values=pad_id)[: self.max_length]
+
+  def _pad_image(self, images):
+    image_offsets = multimodal_utils.get_image_offsets(self.model_name, None)
+    max_num_images = (self.max_length // image_offsets) - 1  # -1 to reserve space for at least one text token
+    if self.max_num_images_per_example > 0:
+      max_num_images = min(self.max_num_images_per_example, max_num_images)
+    image_shape = multimodal_utils.get_dummy_image_shape_for_init(self.model_name)[2:]
+    assert (
+        images.shape[0] <= max_num_images
+    ), f"Number of images {images.shape[0]} exceeds the maximum allowed {max_num_images}"
+    if images.shape[0] < max_num_images:
+      pad_size = max_num_images - images.shape[0]
+      pad_shape = (pad_size,) + image_shape
+      pad_images = np.zeros(pad_shape, dtype=images.dtype)
+      if images is not None and images.size > 0:
+        images = np.concatenate([images, pad_images], axis=0)
+      else:
+        images = pad_images
+    return images
 
   def map(self, element: dict[str, np.ndarray]):
     """map to each element"""
-
-    def _pad(x, max_length, pad_id):
-      pad_amount = max(max_length - x.shape[0], 0)
-      pad_amount = [(0, pad_amount)] + [(0, 0)] * (len(x.shape) - 1)
-      return np.pad(x, pad_amount, constant_values=pad_id)
-
     data_columns = list(element.keys())
    for data_column in data_columns:
       if data_column != "images":
         element[f"{data_column}_segmentation"] = (element[data_column] != self.pad_id).astype(np.int32)
         element[f"{data_column}_position"] = np.arange(element[data_column].shape[0], dtype=np.int32)
+        if self.add_true_length:
+          element[f"{data_column}_true_length"] = np.array([element[data_column].shape[0]], dtype=np.int32)
     for key, _ in element.items():
-      if key != "images":
-        element[key] = _pad(element[key], self.max_length, self.pad_id)
+      if key == "images":
+        if isinstance(element["images"], list):
+          assert self.model_name is not None, "model_name must be provided when padding images"
+          element["images"] = self._pad_image(np.asarray(element["images"]))
+        else:
+          element["images"] = np.asarray(element["images"])[None, ...]
+      elif "true_length" not in key:
+        element[key] = self._pad_text(element[key], self.max_length, self.pad_id)
     return element
 
 
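A usage sketch of the merged transform on a text-only element, assuming the class and its grain/numpy dependencies are importable; all values are made up. For the image path, the budget is max_length // image_offsets - 1, so a hypothetical per-image offset of 256 tokens with max_length=1024 would leave room for 3 images plus text.

    import numpy as np

    # Illustrative values only.
    transform = PadOrTrimToMaxLength(max_length=8, pad_id=0, add_true_length=True)
    element = {"inputs": np.array([5, 6, 7], dtype=np.int32)}
    out = transform.map(element)
    # out["inputs"]              -> [5, 1, 7, 0, 0, 0, 0, 0] padded with pad_id (here [5, 6, 7, 0, ...])
    # out["inputs_segmentation"] -> [1, 1, 1, 0, 0, 0, 0, 0]
    # out["inputs_position"]     -> [0, 1, 2, 0, 0, 0, 0, 0]
    # out["inputs_true_length"]  -> [3] (left unpadded)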