@@ -35,7 +35,7 @@
)
from ...utils.import_utils import requires
from ...video_processing_utils import BaseVideoProcessor
from ...video_utils import group_videos_by_shape, reorder_videos
from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos


if is_vision_available():
@@ -66,6 +66,7 @@ class InstructBlipVideoVideoProcessor(BaseVideoProcessor):
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
valid_kwargs = InstructBlipVideoVideoProcessorInitKwargs
model_input_names = ["pixel_values"]

@@ -75,6 +76,7 @@ def __init__(self, **kwargs: Unpack[InstructBlipVideoVideoProcessorInitKwargs]):
def _preprocess(
self,
videos: List["torch.Tensor"],
video_metadata: Union[List[VideoMetadata], List[dict]],
do_convert_rgb: bool,
do_resize: bool,
size: SizeDict,
@@ -86,10 +88,18 @@ def _preprocess(
do_pad: bool,
rescale_factor: float,
do_normalize: bool,
do_sample_frames: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
fps: Optional[int] = None,
num_frames: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
if do_sample_frames:
videos = [
self.sample_frames(video, metadata, num_frames, fps) for video, metadata in zip(videos, video_metadata)
]

# Group videos by size for batched resizing
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
resized_videos_grouped = {}
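Note on the `do_sample_frames` gate added above: each video is paired with its metadata and subsampled only when the flag is set; the `False` default preserves the previous behavior of processing every decoded frame. A minimal standalone sketch of that gating logic (the uniform sampler below is a stand-in for illustration, not the processor's actual `sample_frames`):

from typing import List, Optional

import torch


def maybe_sample(
    videos: List[torch.Tensor],
    video_metadata: List[dict],
    do_sample_frames: bool,
    num_frames: Optional[int] = None,
) -> List[torch.Tensor]:
    # Mirrors the gate in `_preprocess`: no-op unless sampling is requested.
    if not do_sample_frames:
        return videos  # backward-compatible default: keep every frame
    sampled = []
    for video, _metadata in zip(videos, video_metadata):
        total = video.shape[0]
        n = num_frames if num_frames is not None else total
        indices = torch.linspace(0, total - 1, n).long()  # stand-in uniform sampler
        sampled.append(video[indices].contiguous())
    return sampled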
62 changes: 1 addition & 61 deletions src/transformers/models/internvl/processing_internvl.py
@@ -21,7 +21,7 @@
from ...image_utils import ImageInput, concatenate_list, make_flat_list_of_images
from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...video_utils import VideoInput, VideoMetadata, load_video, make_batched_videos
from ...video_utils import VideoInput, make_batched_videos


class InternVLImagesKwargs(ImagesKwargs, total=False):
@@ -290,32 +290,6 @@ def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):

return MultiModalData(**vision_data)

def sample_indices_fn(
self, metadata: VideoMetadata, num_frames: Optional[int] = None, initial_shift: Union[bool, float, int] = True
):
"""
The function to generate indices of frames to sample from a video.

Args:
metadata (`VideoMetadata`):
`VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If None, all frames are sampled.
initial_shift (`bool`, `float` or `int`, defaults to `0`):
The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.

Returns:
`np.ndarray`: Array of frame indices to sample.
"""
num_frames = num_frames if num_frames is not None else metadata.total_num_frames

if initial_shift is True:
initial_shift = metadata.total_num_frames / num_frames / 2
indices = np.arange(initial_shift, metadata.total_num_frames, metadata.total_num_frames / num_frames).astype(
int
)
return indices

def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
@@ -336,39 +310,5 @@ def model_input_names(self):
image_processor_input_names = self.image_processor.model_input_names
return list(tokenizer_input_names) + list(image_processor_input_names)

# TODO: raushan, has to be public method under `VideoProcessorBase` when API is added
def _load_video_for_model(
self,
video: Union[str, "VideoInput"],
num_frames: Optional[int],
backend: str = "pyav",
initial_shift: bool = True,
**kwargs,
) -> np.array:
"""
Loads `video` to a numpy array.

Args:
video (`str` or `VideoInput`):
The video to convert to the numpy array format. Can be a link to video or local path.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not passed, the whole video is loaded.
backend (`str`, *optional*, defaults to `"pyav"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "pyav".
initial_shift (`bool`, *optional*, defaults to `True`):
The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.

Returns:
Tuple[`np.array`, Dict]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- Metadata dictionary.
"""

def sample_indices_fn_func(metadata, **fn_kwargs):
return self.sample_indices_fn(metadata, num_frames=num_frames, initial_shift=initial_shift, **fn_kwargs)

video, metadata = load_video(video, backend=backend, sample_indices_fn=sample_indices_fn_func)
return video, metadata


__all__ = ["InternVLProcessor"]
143 changes: 139 additions & 4 deletions src/transformers/models/internvl/video_processing_internvl.py
@@ -14,25 +14,43 @@
# limitations under the License.
"""Fast Video processor class for InternVL."""

from typing import List, Optional, Union

from ...image_processing_utils import BatchFeature
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
SizeDict,
)
from ...processing_utils import Unpack, VideosKwargs
from ...utils import (
TensorType,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
is_vision_available,
)
from ...utils.import_utils import requires
from ...video_processing_utils import (
BaseVideoProcessor,
)
from ...video_processing_utils import BaseVideoProcessor
from ...video_utils import VideoMetadata, group_videos_by_shape, reorder_videos


if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F


if is_torch_available():
import torch

if is_vision_available():
from ...image_utils import PILImageResampling


class InternVLVideoProcessorInitKwargs(VideosKwargs): ...
class InternVLVideoProcessorInitKwargs(VideosKwargs):
initial_shift: Union[bool, float, int]


@requires(backends=("torchvision",))
@@ -45,11 +63,128 @@ class InternVLVideoProcessor(BaseVideoProcessor):
do_rescale = True
do_normalize = True
do_convert_rgb = True
initial_shift = True
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
valid_kwargs = InternVLVideoProcessorInitKwargs
model_input_names = ["pixel_values_videos"]

def __init__(self, **kwargs: Unpack[InternVLVideoProcessorInitKwargs]):
super().__init__(**kwargs)

def sample_frames(
self,
video: "torch.Tensor",
metadata: Optional[Union[VideoMetadata, dict]] = None,
num_frames: Optional[int] = None,
fps: Optional[int] = None,
initial_shift: Optional[Union[bool, float, int]] = None,
):
"""
Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames.
If `fps` is passed along with metadata, `fps` frames per second are sampled uniformly. Arguments `num_frames`
and `fps` are mutually exclusive.

Args:
video (`torch.Tensor`):
Video that needs to be sampled.
metadata (`VideoMetadata`, *optional*):
Metadata of the video containing information about total duration, fps and total number of frames.
num_frames (`int`, *optional*):
Maximum number of frames to sample. Defaults to `self.num_frames`.
fps (`int`, *optional*):
Target frames to sample per second. Defaults to `self.fps`.
initial_shift (`bool`, `float` or `int`, defaults to `self.initial_shift`):
The initial shift to apply when sampling frames. If `True`, the shift is set so that frames are sampled from the middle of the video.

Returns:
torch.Tensor:
Sampled video frames.
"""
num_frames = num_frames if num_frames is not None else self.num_frames
initial_shift = initial_shift if initial_shift is not None else self.initial_shift
total_num_frames = video.shape[0]

# If num_frames is not given but fps is, calculate num_frames from fps
if num_frames is None and fps is not None:
if metadata is None:
raise ValueError(
"Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
"Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
)
num_frames = int(total_num_frames / metadata["fps"] * fps)

if initial_shift is True:
initial_shift = total_num_frames / num_frames / 2

if num_frames > total_num_frames:
raise ValueError(
f"Video can't be sampled. The `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
)

indices = torch.arange(initial_shift, total_num_frames, total_num_frames / num_frames).int()
video = video[indices].contiguous()
return video

def _preprocess(
self,
videos: List["torch.Tensor"],
video_metadata: Union[List[VideoMetadata], List[dict]],
do_convert_rgb: bool,
do_resize: bool,
size: SizeDict,
size_divisor: Optional[int],
interpolation: Optional["F.InterpolationMode"],
do_center_crop: bool,
crop_size: SizeDict,
do_rescale: bool,
do_pad: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
do_sample_frames: Optional[bool] = None,
fps: Optional[int] = None,
num_frames: Optional[int] = None,
initial_shift: Optional[Union[bool, float, int]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
if do_sample_frames:
# Sample video frames
videos = [
self.sample_frames(video, metadata, fps=fps, num_frames=num_frames, initial_shift=initial_shift)
for video, metadata in zip(videos, video_metadata)
]

# Group videos by size for batched resizing
grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
resized_videos_grouped = {}
for shape, stacked_videos in grouped_videos.items():
if do_convert_rgb:
stacked_videos = self.convert_to_rgb(stacked_videos)
if do_resize:
stacked_videos = self.resize(
stacked_videos, size=size, size_divisor=size_divisor, interpolation=interpolation
)
resized_videos_grouped[shape] = stacked_videos
resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)

# Group videos by size for further processing
# Needed in case do_resize is False, or resize returns videos with different sizes
grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
processed_videos_grouped = {}
for shape, stacked_videos in grouped_videos.items():
if do_center_crop:
stacked_videos = self.center_crop(stacked_videos, crop_size)
# Fused rescale and normalize
stacked_videos = self.rescale_and_normalize(
stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_videos_grouped[shape] = stacked_videos

processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos

return BatchFeature(data={"pixel_values_videos": processed_videos}, tensor_type=return_tensors)


__all__ = ["InternVLVideoProcessor"]
@@ -46,6 +46,7 @@ class LlavaNextVideoVideoProcessor(BaseVideoProcessor):
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
valid_kwargs = LlavaNextVideoFastVideoProcessorInitKwargs
model_input_names = ["pixel_values_videos"]

@@ -47,6 +47,7 @@ class LlavaOnevisionVideoProcessor(BaseVideoProcessor):
do_rescale = True
do_normalize = True
do_convert_rgb = True
do_sample_frames = False # Set to False for BC, recommended to set `True` in new models
valid_kwargs = LlavaOnevisionFastVideoProcessorInitKwargs
model_input_names = ["pixel_values_videos"]

@@ -154,7 +154,7 @@ def __call__(
seconds_per_chunk = output_kwargs["videos_kwargs"].pop("seconds_per_chunk")
position_id_per_seconds = output_kwargs["videos_kwargs"].pop("position_id_per_seconds")
use_audio_in_video = output_kwargs["videos_kwargs"].pop("use_audio_in_video")
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
fps = output_kwargs["videos_kwargs"].get("fps", 2.0)

if audio is not None:
output_kwargs["audio_kwargs"]["padding"] = "max_length" # Support "max_length" padding only here
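The pop-to-get switch here (and in the Qwen2.5-VL processors below) matters because `fps` has to remain in `videos_kwargs`: the processor reads it for its own bookkeeping, and the video processor presumably also needs to receive it (e.g. as a frame-sampling argument). A trivial illustration of the difference:

videos_kwargs = {"fps": 2.0, "return_tensors": "pt"}

fps = videos_kwargs.get("fps", 2.0)  # new: read the value, keep the key
assert "fps" in videos_kwargs  # still forwarded to the video processor

fps = videos_kwargs.pop("fps", 2.0)  # old: reading removed the key,
assert "fps" not in videos_kwargs  # so the video processor never received it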
5 changes: 1 addition & 4 deletions src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -928,7 +928,6 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
"padding": False,
"return_mm_token_type_ids": False,
},
"videos_kwargs": {"fps": 2.0},
}


@@ -1013,9 +1012,7 @@ def __call__(
image_grid_thw = image_inputs["image_grid_thw"]

if videos is not None:
# pop fps in advance for passing kwargs validation
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)

fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]

5 changes: 1 addition & 4 deletions src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
@@ -54,7 +54,6 @@ class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
"padding": False,
"return_mm_token_type_ids": False,
},
"videos_kwargs": {"fps": 2.0},
}


@@ -151,9 +150,7 @@ def __call__(
image_grid_thw = image_inputs["image_grid_thw"]

if videos is not None:
# pop fps in advance for passing kwargs validation
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)

fps = output_kwargs["videos_kwargs"].get("fps", 2.0)
videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
video_grid_thw = videos_inputs["video_grid_thw"]
