33
44import logging
55import math
6- import numpy as np
76import random
7+
8+ import numpy as np
89import torch
910import torchvision .io as io
1011
@@ -84,7 +85,7 @@ def get_multiple_start_end_idx(
8485 num_clips_uniform ,
8586 min_delta = 0 ,
8687 max_delta = math .inf ,
87- use_offset = False ,
88+ use_offset = False
8889):
8990 """
9091 Sample a clip of size clip_size from a video of size video_size and
@@ -114,7 +115,7 @@ def sample_clips(
114115 min_delta = 0 ,
115116 max_delta = math .inf ,
116117 num_retries = 100 ,
117- use_offset = False ,
118+ use_offset = False
118119 ):
119120 se_inds = np .empty ((0 , 2 ))
120121 dt = np .empty ((0 ))
@@ -125,15 +126,13 @@ def sample_clips(
125126 if clip_idx == - 1 :
126127 # Random temporal sampling.
127128 start_idx = random .uniform (0 , max_start )
128- else : # Uniformly sample the clip with the given index.
129+ else : # Uniformly sample the clip with the given index.
129130 if use_offset :
130131 if num_clips_uniform == 1 :
131132 # Take the center clip if num_clips is 1.
132133 start_idx = math .floor (max_start / 2 )
133134 else :
134- start_idx = clip_idx * math .floor (
135- max_start / (num_clips_uniform - 1 )
136- )
135+ start_idx = clip_idx * math .floor (max_start / (num_clips_uniform - 1 ))
137136 else :
138137 start_idx = max_start * clip_idx / num_clips_uniform
139138
@@ -304,7 +303,10 @@ def torchvision_decode(
304303 decode_all_video = False # try selective decoding
305304
306305 clip_sizes = [
307- np .maximum (1.0 , sampling_rate [i ] * num_frames [i ] / target_fps * fps )
306+ np .maximum (
307+ 1.0 ,
308+ sampling_rate [i ] * num_frames [i ] / target_fps * fps
309+ )
308310 for i in range (len (sampling_rate ))
309311 ]
310312 start_end_delta_time = get_multiple_start_end_idx (
@@ -381,6 +383,10 @@ def pyav_decode(
381383 num_clips_uniform = 10 ,
382384 target_fps = 30 ,
383385 use_offset = False ,
386+ modalities = ("visual" ,),
387+ max_spatial_scale = 0 ,
388+ min_delta = - math .inf ,
389+ max_delta = math .inf ,
384390):
385391 """
386392 Convert the video from its original fps to the target_fps. If the video
@@ -418,38 +424,69 @@ def pyav_decode(
418424 # If failed to fetch the decoding information, decode the entire video.
419425 decode_all_video = True
420426 video_start_pts , video_end_pts = 0 , math .inf
427+ start_end_delta_time = None
428+
429+ frames = None
430+ if container .streams .video :
431+ video_frames , max_pts = pyav_decode_stream (
432+ container ,
433+ video_start_pts ,
434+ video_end_pts ,
435+ container .streams .video [0 ],
436+ {"video" : 0 },
437+ )
438+ container .close ()
439+
440+ frames = [frame .to_rgb ().to_ndarray () for frame in video_frames ]
441+ frames = torch .as_tensor (np .stack (frames ))
442+ frames_out = [frames ]
443+
421444 else :
422445 # Perform selective decoding.
423446 decode_all_video = False
424- clip_size = np .maximum (
425- 1.0 , np .ceil (sampling_rate * (num_frames - 1 ) / target_fps * fps )
426- )
427- start_idx , end_idx , fraction = get_start_end_idx (
447+ clip_sizes = [
448+ np .maximum (
449+ 1.0 ,
450+ np .ceil (
451+ sampling_rate [i ] * (num_frames [i ] - 1 ) / target_fps * fps
452+ ),
453+ )
454+ for i in range (len (sampling_rate ))
455+ ]
456+ start_end_delta_time = get_multiple_start_end_idx (
428457 frames_length ,
429- clip_size ,
458+ clip_sizes ,
430459 clip_idx ,
431460 num_clips_uniform ,
432- use_offset = use_offset ,
433- )
434- timebase = duration / frames_length
435- video_start_pts = int (start_idx * timebase )
436- video_end_pts = int (end_idx * timebase )
437-
438- frames = None
439- # If video stream was found, fetch video frames from the video.
440- if container .streams .video :
441- video_frames , max_pts = pyav_decode_stream (
442- container ,
443- video_start_pts ,
444- video_end_pts ,
445- container .streams .video [0 ],
446- {"video" : 0 },
461+ min_delta = min_delta ,
462+ max_delta = max_delta ,
447463 )
464+ frames_out = [None ] * len (num_frames )
465+ for k in range (len (num_frames )):
466+ start_idx = start_end_delta_time [k , 0 ]
467+ end_idx = start_end_delta_time [k , 1 ]
468+ timebase = duration / frames_length
469+ video_start_pts = int (start_idx * timebase )
470+ video_end_pts = int (end_idx * timebase )
471+
472+ frames = None
473+ # If video stream was found, fetch video frames from the video.
474+ if container .streams .video :
475+ video_frames , max_pts = pyav_decode_stream (
476+ container ,
477+ video_start_pts ,
478+ video_end_pts ,
479+ container .streams .video [0 ],
480+ {"video" : 0 },
481+ )
482+
483+ frames = [frame .to_rgb ().to_ndarray () for frame in video_frames ]
484+ frames = torch .as_tensor (np .stack (frames ))
485+
486+ frames_out [k ] = frames
448487 container .close ()
449488
450- frames = [frame .to_rgb ().to_ndarray () for frame in video_frames ]
451- frames = torch .as_tensor (np .stack (frames ))
452- return frames , fps , decode_all_video
489+ return frames_out , fps , decode_all_video , start_end_delta_time
453490
454491
455492def decode (
@@ -509,17 +546,20 @@ def decode(
509546 ) # clips come temporally ordered from decoder
510547 try :
511548 if backend == "pyav" :
512- assert (
513- min_delta == - math .inf and max_delta == math .inf
514- ), "delta sampling not supported in pyav"
515- frames_decoded , fps , decode_all_video = pyav_decode (
549+ assert min_delta == - math .inf and max_delta == math .inf , \
550+ "delta sampling not supported in pyav"
551+ frames_decoded , fps , decode_all_video , start_end_delta_time = pyav_decode (
516552 container ,
517553 sampling_rate ,
518554 num_frames ,
519555 clip_idx ,
520556 num_clips_uniform ,
521557 target_fps ,
522558 use_offset = use_offset ,
559+ modalities = ("visual" ,),
560+ max_spatial_scale = max_spatial_scale ,
561+ min_delta = min_delta ,
562+ max_delta = max_delta ,
523563 )
524564 elif backend == "torchvision" :
525565 (
@@ -557,7 +597,10 @@ def decode(
557597 frames_decoded = [frames_decoded ]
558598 num_decoded = len (frames_decoded )
559599 clip_sizes = [
560- np .maximum (1.0 , sampling_rate [i ] * num_frames [i ] / target_fps * fps )
600+ np .maximum (
601+ 1.0 ,
602+ sampling_rate [i ] * num_frames [i ] / target_fps * fps
603+ )
561604 for i in range (len (sampling_rate ))
562605 ]
563606
@@ -621,4 +664,4 @@ def decode(
621664 for i in range (num_decode )
622665 )
623666
624- return frames_out , start_end_delta_time , time_diff_aug
667+ return frames_out , start_end_delta_time , time_diff_aug
0 commit comments