@@ -382,6 +382,10 @@ def pyav_decode(
382382 num_clips_uniform = 10 ,
383383 target_fps = 30 ,
384384 use_offset = False ,
385+ modalities = ("visual" ,),
386+ max_spatial_scale = 0 ,
387+ min_delta = - math .inf ,
388+ max_delta = math .inf ,
385389):
386390 """
387391 Convert the video from its original fps to the target_fps. If the video
@@ -419,38 +423,69 @@ def pyav_decode(
419423 # If failed to fetch the decoding information, decode the entire video.
420424 decode_all_video = True
421425 video_start_pts , video_end_pts = 0 , math .inf
426+ start_end_delta_time = None
427+
428+ frames = None
429+ if container .streams .video :
430+ video_frames , max_pts = pyav_decode_stream (
431+ container ,
432+ video_start_pts ,
433+ video_end_pts ,
434+ container .streams .video [0 ],
435+ {"video" : 0 },
436+ )
437+ container .close ()
438+
439+ frames = [frame .to_rgb ().to_ndarray () for frame in video_frames ]
440+ frames = torch .as_tensor (np .stack (frames ))
441+ frames_out = [frames ]
442+
422443 else :
423444 # Perform selective decoding.
424445 decode_all_video = False
425- clip_size = np .maximum (
426- 1.0 , np .ceil (sampling_rate * (num_frames - 1 ) / target_fps * fps )
427- )
428- start_idx , end_idx , fraction = get_start_end_idx (
446+ clip_sizes = [
447+ np .maximum (
448+ 1.0 ,
449+ np .ceil (
450+ sampling_rate [i ] * (num_frames [i ] - 1 ) / target_fps * fps
451+ ),
452+ )
453+ for i in range (len (sampling_rate ))
454+ ]
455+ start_end_delta_time = get_multiple_start_end_idx (
429456 frames_length ,
430- clip_size ,
457+ clip_sizes ,
431458 clip_idx ,
432459 num_clips_uniform ,
433- use_offset = use_offset ,
434- )
435- timebase = duration / frames_length
436- video_start_pts = int (start_idx * timebase )
437- video_end_pts = int (end_idx * timebase )
438-
439- frames = None
440- # If video stream was found, fetch video frames from the video.
441- if container .streams .video :
442- video_frames , max_pts = pyav_decode_stream (
443- container ,
444- video_start_pts ,
445- video_end_pts ,
446- container .streams .video [0 ],
447- {"video" : 0 },
460+ min_delta = min_delta ,
461+ max_delta = max_delta ,
448462 )
463+ frames_out = [None ] * len (num_frames )
464+ for k in range (len (num_frames )):
465+ start_idx = start_end_delta_time [k , 0 ]
466+ end_idx = start_end_delta_time [k , 1 ]
467+ timebase = duration / frames_length
468+ video_start_pts = int (start_idx * timebase )
469+ video_end_pts = int (end_idx * timebase )
470+
471+ frames = None
472+ # If video stream was found, fetch video frames from the video.
473+ if container .streams .video :
474+ video_frames , max_pts = pyav_decode_stream (
475+ container ,
476+ video_start_pts ,
477+ video_end_pts ,
478+ container .streams .video [0 ],
479+ {"video" : 0 },
480+ )
481+
482+ frames = [frame .to_rgb ().to_ndarray () for frame in video_frames ]
483+ frames = torch .as_tensor (np .stack (frames ))
484+
485+ frames_out [k ] = frames
449486 container .close ()
450487
451- frames = [frame .to_rgb ().to_ndarray () for frame in video_frames ]
452- frames = torch .as_tensor (np .stack (frames ))
453- return frames , fps , decode_all_video
488+ return frames_out , fps , decode_all_video , start_end_delta_time
454489
455490
456491def decode (
@@ -510,17 +545,20 @@ def decode(
510545 ) # clips come temporally ordered from decoder
511546 try :
512547 if backend == "pyav" :
513- assert (
514- min_delta == - math .inf and max_delta == math .inf
515- ), "delta sampling not supported in pyav"
516- frames_decoded , fps , decode_all_video = pyav_decode (
548+ assert min_delta == - math .inf and max_delta == math .inf , \
549+ "delta sampling not supported in pyav"
550+ frames_decoded , fps , decode_all_video , start_end_delta_time = pyav_decode (
517551 container ,
518552 sampling_rate ,
519553 num_frames ,
520554 clip_idx ,
521555 num_clips_uniform ,
522556 target_fps ,
523557 use_offset = use_offset ,
558+ modalities = ("visual" ,),
559+ max_spatial_scale = max_spatial_scale ,
560+ min_delta = min_delta ,
561+ max_delta = max_delta ,
524562 )
525563 elif backend == "torchvision" :
526564 (
@@ -558,12 +596,12 @@ def decode(
558596 frames_decoded = [frames_decoded ]
559597 num_decoded = len (frames_decoded )
560598 clip_sizes = [
561- np .maximum (
562- 1.0 ,
563- sampling_rate [i ] * num_frames [i ] / target_fps * fps
564- )
565- for i in range (len (sampling_rate ))
566- ]
599+ np .maximum (
600+ 1.0 ,
601+ sampling_rate [i ] * num_frames [i ] / target_fps * fps
602+ )
603+ for i in range (len (sampling_rate ))
604+ ]
567605
568606 if decode_all_video : # full video was decoded (not trimmed yet)
569607 assert num_decoded == 1 and start_end_delta_time is None
0 commit comments