diff --git a/wgp.py b/wgp.py
index afdcd57c8..1053b4977 100644
--- a/wgp.py
+++ b/wgp.py
@@ -630,7 +630,8 @@ def ret():
     if image_start is not None and not isinstance(image_start, list): image_start = [image_start]
     outpainting_dims = get_outpainting_dims(video_guide_outpainting)
-    if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"):
+    fit_canvas = inputs["fit_canvas"]
+    if fit_canvas == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"):
         gr.Info("Output Resolution Cropping will be not used for this Generation as it is not compatible with Video Outpainting")
     if not model_def.get("motion_amplitude", False): motion_amplitude = 1.
@@ -4636,7 +4637,7 @@ def upsample_frames(frame):
             return resize_lanczos(frame, h, w, method).unsqueeze(1)
         sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False, max_workers=get_default_workers(), in_place=True), dim=1)
         frames_to_upsample = None
-    return sample 
+    return sample
 
 def any_audio_track(model_type):
@@ -4724,23 +4725,24 @@ def edit_video(
 
     if mode == "edit_postprocessing":
         if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0 or film_grain_intensity > 0:
-            send_cmd("progress", [0, get_latest_status(state,"Upsampling" if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0 else "Adding Film Grain" )])
             sample = get_resampled_video(video_source, 0, max_source_video_frames, fps)
             sample = sample.float().div_(127.5).sub_(1.).permute(-1,0,1,2)
             frames_count = sample.shape[1]
+            if len(spatial_upsampling) > 0:
+                send_cmd("progress", [0, get_latest_status(state,"Spatial Upsampling")])
+                sample = perform_spatial_upsampling(sample, spatial_upsampling)
+                configs["spatial_upsampling"] = spatial_upsampling
+            output_fps = round(fps)
             if len(temporal_upsampling) > 0:
+                send_cmd("progress", [0, get_latest_status(state,"Temporal Upsampling")])
                 sample, previous_last_frame, output_fps = perform_temporal_upsampling(sample, None, temporal_upsampling, fps)
                 configs["temporal_upsampling"] = temporal_upsampling
-                frames_count = sample.shape[1]
-
-
-            if len(spatial_upsampling) > 0:
-                sample = perform_spatial_upsampling(sample, spatial_upsampling )
-                configs["spatial_upsampling"] = spatial_upsampling
+                frames_count = sample.shape[1]
 
             if film_grain_intensity > 0:
+                send_cmd("progress", [0, get_latest_status(state,"Film Grain")])
                 from postprocessing.film_grain import add_film_grain
                 sample = add_film_grain(sample, film_grain_intensity, film_grain_saturation)
                 configs["film_grain_intensity"] = film_grain_intensity
@@ -4755,8 +4757,8 @@ def edit_video(
         tmp_path = None
         any_change = False
         if sample != None:
-            video_path =get_available_filename(save_path, video_source, "_tmp") if any_mmaudio or has_already_audio else get_available_filename(save_path, video_source, "_post")
-            save_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type= server_config.get("video_output_codec", None), container=server_config.get("video_container", "mp4"))
+            video_path = get_available_filename(save_path, video_source, "_tmp") if any_mmaudio or has_already_audio else get_available_filename(save_path, video_source, "_post")
+            save_video(tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type= server_config.get("video_output_codec", None), container=server_config.get("video_container", "mp4"))
             if any_mmaudio or has_already_audio:
                 tmp_path = video_path
             any_change = True
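The edit_postprocessing hunks above depend on get_resampled_video handing back channels-last frames, which the div/sub/permute chain converts to the (C, F, H, W) float layout in [-1, 1] that the upsamplers expect. A minimal sketch of that conversion, assuming the loader yields (F, H, W, C) uint8 frames (a toy tensor stands in for the real video):

    import torch

    # Assumed input layout: (F, H, W, C) uint8 frames in [0, 255];
    # dividing by 127.5 and subtracting 1 maps them into [-1, 1], and the
    # permute moves channels first: (F, H, W, C) -> (C, F, H, W).
    frames = torch.randint(0, 256, (8, 4, 6, 3), dtype=torch.uint8)  # toy clip
    sample = frames.float().div_(127.5).sub_(1.).permute(-1, 0, 1, 2)
    assert sample.shape == (3, 8, 4, 6)
    assert sample.min() >= -1.0 and sample.max() <= 1.0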
@@ -5096,6 +5098,7 @@ def generate_video(
     prompt,
     negative_prompt,
     resolution,
+    fit_canvas,
     video_length,
     batch_size,
     seed,
@@ -5194,6 +5197,38 @@ def remove_temp_filenames(temp_filenames_list):
         if temp_filename!= None and os.path.isfile(temp_filename):
             os.remove(temp_filename)
 
+    # --- CONTINUING VIDEO VALIDATION CHECK ---
+    # Before starting generation, ensure that if we are continuing a video ("V" or "L")
+    # the final result will match the source video's resolution and FPS.
+    if video_source and any_letters(image_prompt_type, "VL"):
+        # 1. Determine the upscaling multipliers
+        s_mult = 1.0
+        if "lanczos1.5" in spatial_upsampling: s_mult = 1.5
+        elif "lanczos2" in spatial_upsampling or "vae2" in spatial_upsampling: s_mult = 2.0
+        elif "vae1" in spatial_upsampling: s_mult = 0.5
+
+        t_mult = 1
+        if "rife2" in temporal_upsampling: t_mult = 2
+        elif "rife4" in temporal_upsampling: t_mult = 4
+
+        # 2. Calculate the expected target properties (model resolution * multiplier)
+        base_w, base_h = resolution.split("x")
+        base_w, base_h = int(base_w), int(base_h)
+        expected_w = base_w * s_mult
+        expected_h = base_h * s_mult
+
+        base_model_type_check = get_base_model_type(model_type)
+        base_fps = get_computed_fps(force_fps, base_model_type_check, video_guide, video_source)
+        expected_fps = base_fps * t_mult
+
+        # 3. Get the source properties and compare with a tolerance (16 px for resolution, 0.1 for FPS)
+        src_fps, src_w, src_h, _ = get_video_info(video_source)
+        if abs(src_w - expected_w) > 16 or abs(src_h - expected_h) > 16:
+            raise gr.Error(f"Resolution Mismatch: Source is {src_w}x{src_h}, but result will be approx {int(expected_w)}x{int(expected_h)}. Please adjust Resolution or Spatial Upsampling to match the source.")
+        if abs(src_fps - expected_fps) > 0.1:
+            raise gr.Error(f"FPS Mismatch: Source is {src_fps:.2f} fps, but result will be {expected_fps:.2f} fps. Please adjust Default Model FPS or Temporal Upsampling to match the source.")
+    # -----------------------------------
+
     global wan_model, offloadobj, reload_needed
     gen = get_gen_info(state)
     torch.set_grad_enabled(False)
@@ -5270,7 +5305,7 @@ def remove_temp_filenames(temp_filenames_list):
             return
 
     width, height = resolution.split("x")
-    width, height = int(width) // block_size * block_size, int(height) // block_size * block_size 
+    width, height = int(width) // block_size * block_size, int(height) // block_size * block_size
     default_image_size = (height, width)
 
     if slg_switch == 0:
@@ -5429,7 +5464,6 @@ def remove_temp_filenames(temp_filenames_list):
     any_background_ref = 2 if model_def.get("all_image_refs_are_background_ref", False) else 1
 
     outpainting_dims = get_outpainting_dims(video_guide_outpainting)
-    fit_canvas = server_config.get("fit_canvas", 0)
    fit_crop = fit_canvas == 2
     if fit_crop and outpainting_dims is not None:
         fit_crop = False
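The validation block added to generate_video above predicts the final geometry before any sampling happens: model resolution times the spatial multiplier, model FPS times the temporal multiplier. A self-contained sketch of the same arithmetic (the function name is made up; the multiplier tables mirror the hunk):

    def expected_output_props(resolution: str, spatial_upsampling: str,
                              temporal_upsampling: str, base_fps: float):
        """Predict output (width, height, fps), mirroring the pre-generation check."""
        s_mult = 1.0
        if "lanczos1.5" in spatial_upsampling: s_mult = 1.5
        elif "lanczos2" in spatial_upsampling or "vae2" in spatial_upsampling: s_mult = 2.0
        elif "vae1" in spatial_upsampling: s_mult = 0.5

        t_mult = 4 if "rife4" in temporal_upsampling else 2 if "rife2" in temporal_upsampling else 1

        w, h = (int(v) for v in resolution.split("x"))
        return w * s_mult, h * s_mult, base_fps * t_mult

    # e.g. an 832x480 model with 2x Lanczos and RIFE x2 at 16 fps:
    assert expected_output_props("832x480", "lanczos2", "rife2", 16) == (1664, 960, 32)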
@@ -5628,13 +5662,13 @@ def remove_temp_filenames(temp_filenames_list):
                     image_start_tensor = convert_image_to_tensor(image_start_tensor)
                     pre_video_guide = prefix_video = image_start_tensor.unsqueeze(1)
                 else:
-                    prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size = block_size )
-                    prefix_video = prefix_video.permute(3, 0, 1, 2)
-                    prefix_video = prefix_video.float().div_(127.5).sub_(1.) # c, f, h, w
+                    prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size = block_size )
+                    prefix_video = prefix_video.permute(3, 0, 1, 2)
+                    prefix_video = prefix_video.float().div_(127.5).sub_(1.) # c, f, h, w
                 if fit_crop or "L" in image_prompt_type:
                     refresh_preview["video_source"] = convert_tensor_to_image(prefix_video, 0)
                 new_height, new_width = prefix_video.shape[-2:]
-                pre_video_guide = prefix_video[:, -reuse_frames:] 
+                pre_video_guide = prefix_video[:, -reuse_frames:]
                 pre_video_frame = convert_tensor_to_image(prefix_video[:, -1])
                 source_video_overlap_frames_count = pre_video_guide.shape[1]
                 source_video_frames_count = prefix_video.shape[1]
@@ -6078,6 +6112,7 @@ def set_header_text(txt):
                         sample = sample[:, post_decode_pre_trim:]
                     if gen.get("extra_windows",0) > 0:
                         sliding_window = True
+
                     if sliding_window :
                         # guide_start_frame = guide_end_frame
                         guide_start_frame += current_video_length
@@ -6098,14 +6133,12 @@ def set_header_text(txt):
                     else:
                         pre_video_guide = sample[:, -reuse_frames:].clone()
-
-                    if prefix_video != None and window_no == 1 :
-                        if prefix_video.shape[1] > 1:
-                            # remove sliding window overlapped frames at the beginning of the generation
-                            sample = torch.cat([ prefix_video, sample[: , source_video_overlap_frames_count:]], dim = 1)
-                        else:
-                            # remove source video overlapped frames at the beginning of the generation if there is only a start frame
-                            sample = torch.cat([ prefix_video[:, :-source_video_overlap_frames_count], sample], dim = 1)
+                    if prefix_video != None and window_no == 1:
+                        # If continue (last) video is active, the high-res original frames are concatenated later.
+                        if not any_letters(image_prompt_type, "VL"):
+                            # remove prefix video overlapped frames at the beginning of the generation
+                            sample = torch.cat([prefix_video[:, :-source_video_overlap_frames_count], sample], dim = 1)
+                        guide_start_frame -= source_video_overlap_frames_count
 
                     if generated_audio is not None:
                         generated_audio = truncate_audio(
@@ -6133,26 +6166,47 @@ def set_header_text(txt):
                     full_generated_audio = generated_audio if full_generated_audio is None else np.concatenate([full_generated_audio, generated_audio], axis=0)
                     output_new_audio_data = full_generated_audio
-
-                    if len(temporal_upsampling) > 0 or len(spatial_upsampling) > 0 and not "vae2" in spatial_upsampling:
-                        send_cmd("progress", [0, get_latest_status(state,"Upsampling")])
+
+                    if len(spatial_upsampling) > 0:
+                        send_cmd("progress", [0, get_latest_status(state,"Spatial Upsampling")])
+                        sample = perform_spatial_upsampling(sample, spatial_upsampling)
 
                     output_fps = fps
                     if len(temporal_upsampling) > 0:
+                        send_cmd("progress", [0, get_latest_status(state,"Temporal Upsampling")])
                         sample, previous_last_frame, output_fps = perform_temporal_upsampling(sample, previous_last_frame if sliding_window and window_no > 1 else None, temporal_upsampling, fps)
 
-                    if len(spatial_upsampling) > 0:
-                        sample = perform_spatial_upsampling(sample, spatial_upsampling )
+                    if any_letters(image_prompt_type, "VL"):
+                        src_fps, src_w, src_h, _ = get_video_info(video_source)
+                        sample_height, sample_width = sample.shape[-2:]
+
+                        # Resize sample to match the source's resolution exactly if they differ
+                        if src_h != sample_height or src_w != sample_width:
+                            send_cmd("progress", [0, get_latest_status(state,"Resizing Video")])
+                            # Permute to (F, C, H, W) for torch.nn.functional.interpolate
+                            sample = sample.permute(1, 0, 2, 3)
+                            sample = torch.nn.functional.interpolate(
+                                sample,
+                                size=[src_h, src_w],
+                                mode='bilinear',
+                                align_corners=False
+                            )
+                            # Permute back to (C, F, H, W)
+                            sample = sample.permute(1, 0, 2, 3)
+
                     if film_grain_intensity> 0:
+                        send_cmd("progress", [0, get_latest_status(state,"Film Grain")])
                         from postprocessing.film_grain import add_film_grain
-                        sample = add_film_grain(sample, film_grain_intensity, film_grain_saturation)
-                    if sliding_window :
+                        sample = add_film_grain(sample, film_grain_intensity, film_grain_saturation)
+
+                    if sliding_window:
+                        send_cmd("progress", [0, get_latest_status(state,"Sliding Window")])
                         if frames_already_processed == None:
                             frames_already_processed = sample
                         else:
                             sample = torch.cat([frames_already_processed, sample], dim=1)
                         frames_already_processed = sample
 
+                    send_cmd("progress", [0, get_latest_status(state,"Saving")])
                     time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
                     save_prompt = original_prompts[0]
                     if audio_only:
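torch.nn.functional.interpolate operates on (N, C, H, W) batches, which is why the resize hunk above permutes the (C, F, H, W) clip so that frames act as the batch dimension, then permutes back. A toy round trip with made-up sizes:

    import torch
    import torch.nn.functional as F

    sample = torch.randn(3, 5, 480, 832)      # (C, F, H, W) generated clip
    src_h, src_w = 720, 1280                  # source video geometry

    frames = sample.permute(1, 0, 2, 3)       # (F, C, H, W): frames become the batch
    frames = F.interpolate(frames, size=(src_h, src_w),
                           mode="bilinear", align_corners=False)
    sample = frames.permute(1, 0, 2, 3)       # back to (C, F, H, W)
    assert sample.shape == (3, 5, 720, 1280)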
@@ -6230,9 +6284,20 @@ def set_header_text(txt):
                     else:
                         save_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type= server_config.get("video_output_codec", None), container= container)
-
+
+                    # If continue (last) video is chosen
+                    if any_letters(image_prompt_type, "VL"):
+                        send_cmd("progress", [0, get_latest_status(state,"Merging Videos")])
+                        # Append "_Preview" to the file name of the newly generated sample
+                        base, ext = os.path.splitext(video_path)
+                        video_preview_path = base + "_Preview" + ext
+                        os.rename(video_path, video_preview_path)
+                        # Merge the source video with the newly saved one via ffmpeg, which saves a lot of RAM compared to torch.cat
+                        combine_videos(video_source, video_preview_path, video_path, trim_end_frames1=source_video_overlap_frames_count, trim_start_frames2=1, fps=output_fps)
+                        os.remove(video_preview_path)
                     end_time = time.time()
+                    send_cmd("progress", [0, get_latest_status(state,"Adding Metadata")])
                     inputs.pop("send_cmd")
                     inputs.pop("task")
                     inputs.pop("mode")
@@ -6316,6 +6381,48 @@ def set_header_text(txt):
         remove_temp_filenames(temp_filenames_list)
 
 
+def combine_videos(
+        video1_path,
+        video2_path,
+        output_path,
+        trim_end_frames1=0,
+        trim_start_frames2=0,
+        fps=16,
+        vcodec='libx264',
+        crf=10,
+        preset='veryfast',
+        audio_bitrate='192k'):
+
+    import ffmpeg
+
+    # Calculate the trim durations in seconds
+    trim_end_seconds1 = trim_end_frames1 / fps
+    trim_start_seconds2 = trim_start_frames2 / fps
+
+    probe1 = ffmpeg.probe(video1_path)
+    probe2 = ffmpeg.probe(video2_path)
+    duration1 = float(probe1['streams'][0]['duration'])
+
+    # Trim the end of the first and the beginning of the second video, then join them
+    input1 = ffmpeg.input(video1_path, ss=0, t=duration1 - trim_end_seconds1)
+    input2 = ffmpeg.input(video2_path, ss=trim_start_seconds2)
+    v1 = input1.video
+    v2 = input2.video
+    joined_video = ffmpeg.concat(v1, v2, v=1, a=0)
+
+    # Keep audio only if both videos have an audio stream
+    has_audio1 = any(s['codec_type'] == 'audio' for s in probe1['streams'])
+    has_audio2 = any(s['codec_type'] == 'audio' for s in probe2['streams'])
+    if has_audio1 and has_audio2:
+        a1 = input1.audio
+        a2 = input2.audio
+        joined_audio = ffmpeg.concat(a1, a2, v=0, a=1)
+        output = ffmpeg.output(joined_video, joined_audio, output_path, vcodec=vcodec, crf=crf, preset=preset, audio_bitrate=audio_bitrate)
+    else:
+        output = ffmpeg.output(joined_video, output_path, vcodec=vcodec, crf=crf, preset=preset)
+
+    ffmpeg.run(output, overwrite_output=True, quiet=True)
+
+
 def prepare_generate_video(state):
     if state.get("validate_success",0) != 1:
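Because combine_videos trims and concatenates inside ffmpeg, the high-resolution source never has to be decoded into a torch tensor, which is the RAM saving the call-site comment refers to. A usage sketch matching that call site (paths and frame counts are made up):

    # Drop the last 8 frames of the source (the overlap that was re-generated),
    # skip the first frame of the new clip, and join the rest at 16 fps.
    combine_videos(
        "outputs/source.mp4",               # hypothetical paths
        "outputs/new_window_Preview.mp4",
        "outputs/merged.mp4",
        trim_end_frames1=8,
        trim_start_frames2=1,
        fps=16,
    )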
@@ -7882,6 +7989,7 @@ def save_inputs(
     prompt,
     negative_prompt,
     resolution,
+    fit_canvas,
     video_length,
     batch_size,
     seed,
@@ -9165,17 +9273,11 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
                                 pace = gr.Slider( 0.2, 1, value=ui_get("pace"), step=0.01, label="Pace", show_reset_button= False)
 
                         with gr.Row(visible=not audio_only) as resolution_row:
-                            fit_canvas = server_config.get("fit_canvas", 0)
-                            if fit_canvas == 1:
-                                label = "Outer Box Resolution (one dimension may be less to preserve video W/H ratio)"
-                            elif fit_canvas == 2:
-                                label = "Output Resolution (Input Images wil be Cropped if the W/H ratio is different)"
-                            else:
-                                label = "Resolution Budget (Pixels will be reallocated to preserve Inputs W/H ratio)"
                             current_resolution_choice = ui_get("resolution") if update_form or last_resolution is None else last_resolution
                             model_resolutions = model_def.get("resolutions", None)
                             resolution_choices, current_resolution_choice = get_resolution_choices(current_resolution_choice, model_resolutions)
                             available_groups, selected_group_resolutions, selected_group = group_resolutions(model_def,resolution_choices, current_resolution_choice)
+                            current_fit_canvas = ui_get("fit_canvas", 0)
                             resolution_group = gr.Dropdown(
                                 choices = available_groups,
                                 value= selected_group,
@@ -9184,8 +9286,17 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
                             resolution = gr.Dropdown(
                                 choices = selected_group_resolutions,
                                 value= current_resolution_choice,
-                                label= label,
-                                scale = 5
+                                label= "Format",
+                                scale = 2
+                            )
+                            fit_canvas = gr.Dropdown(
+                                choices=[("Resolution Budget (Pixels will be reallocated to preserve Inputs W/H ratio)", 0),
+                                         ("Outer Box Resolution (one dimension may be smaller to preserve the video W/H ratio)", 1),
+                                         ("Output Resolution (Input Images will be Cropped if the W/H ratio is different)", 2)],
+                                value= current_fit_canvas,
+                                label="Fit Canvas",
+                                scale = 3,
+                                interactive = True
                             )
                         with gr.Row(visible= not audio_only) as number_frames_row:
                             batch_label = model_def.get("batch_size_label", "Number of Images to Generate")
@@ -9577,7 +9688,8 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
             force_fps_choices += [("Source Video fps", "source")]
         force_fps_choices += [
             ("15", "15"),
-            ("16", "16"), 
+            ("16", "16"),
+            ("20", "20"),
             ("23", "23"),
             ("24", "24"),
             ("25", "25"),
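The three Fit Canvas choices describe three different resize intents. A hypothetical illustration of what each mode aims for, not the actual wgp.py implementation (which also snaps dimensions to block_size):

    import math

    def fit_dims(src_w: int, src_h: int, canvas_w: int, canvas_h: int, mode: int):
        """Illustrative only. 0 = keep the source aspect ratio within the same
        pixel budget, 1 = fit inside the canvas box, 2 = exact canvas size
        (inputs get cropped to match)."""
        ratio = src_w / src_h
        if mode == 0:                      # same pixel count, source aspect ratio
            budget = canvas_w * canvas_h
            h = math.sqrt(budget / ratio)
            return round(ratio * h), round(h)
        if mode == 1:                      # largest size that fits the box
            scale = min(canvas_w / src_w, canvas_h / src_h)
            return round(src_w * scale), round(src_h * scale)
        return canvas_w, canvas_h          # mode 2: exact output, crop to fit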