Skip to content

Commit

Permalink
Merge pull request #95 from nateraw/fix-weight-calc
Browse files Browse the repository at this point in the history
Fix audio interpolation weight calculation
  • Loading branch information
nateraw authored Oct 23, 2022
2 parents d2ac212 + 7885ffc commit bad063a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 41 deletions.
6 changes: 5 additions & 1 deletion stable_diffusion_videos/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def generate_images(
prompt,
batch_size=1,
num_batches=1,
seeds=None,
num_inference_steps=50,
guidance_scale=7.5,
output_dir="./images",
Expand All @@ -120,6 +121,7 @@ def generate_images(
prompt (str): The prompt to use for the image generation.
batch_size (int, *optional*, defaults to 1): The batch size to use for image generation.
num_batches (int, *optional*, defaults to 1): The number of batches to generate.
seeds (list[int], *optional*): The seeds to use for the image generation.
num_inference_steps (int, *optional*, defaults to 50): The number of inference steps to take.
guidance_scale (float, *optional*, defaults to 7.5): The guidance scale to use for image generation.
output_dir (str, *optional*, defaults to "./images"): The output directory to save the images to.
Expand All @@ -145,7 +147,9 @@ def generate_images(
prompt_config_path = save_path / "prompt_config.json"

num_images = batch_size * num_batches
seeds = [random.choice(list(range(0, 9999999))) for _ in range(num_images)]
seeds = seeds or [random.choice(list(range(0, 9999999))) for _ in range(num_images)]
if len(seeds) != num_images:
raise ValueError("Number of seeds must be equal to batch_size * num_batches.")

if upsample:
if getattr(pipeline, 'upsampler', None) is None:
Expand Down
69 changes: 29 additions & 40 deletions stable_diffusion_videos/stable_diffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,53 +25,35 @@

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

def get_spec_norm(wav, sr, n_mels=512, hop_length=704):
"""Obtain maximum value for each time-frame in Mel Spectrogram,
and normalize between 0 and 1

Borrowed from lucid sonic dreams repo. In there, they programmatically determine hop length
but I really didn't understand what was going on so I removed it and hard coded the output.
"""
def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=1.0, smooth=0.0):
y, sr = librosa.load(audio_filepath, offset=offset, duration=duration)

# Generate Mel Spectrogram
spec_raw = librosa.feature.melspectrogram(y=wav, sr=sr, n_mels=n_mels, hop_length=hop_length)
# librosa.stft hardcoded defaults...
# n_fft defaults to 2048
# hop length is win_length // 4
# win_length defaults to n_fft
D = librosa.stft(y, n_fft=2048, hop_length=2048 // 4, win_length=2048)

# Obtain maximum value per time-frame
spec_max = np.amax(spec_raw, axis=0)
# Extract percussive elements
D_harmonic, D_percussive = librosa.decompose.hpss(D, margin=margin)
y_percussive = librosa.istft(D_percussive, length=len(y))

# Normalize all values between 0 and 1
# Get normalized melspectrogram
spec_raw = librosa.feature.melspectrogram(y=y_percussive, sr=sr)
spec_max = np.amax(spec_raw, axis=0)
spec_norm = (spec_max - np.min(spec_max)) / np.ptp(spec_max)

return spec_norm

# Resize cumsum of spec norm to our desired number of interpolation frames
x_norm = np.linspace(0, spec_norm.shape[-1], spec_norm.shape[-1])
y_norm = np.cumsum(spec_norm)
y_norm /= y_norm[-1]
x_resize = np.linspace(0, y_norm.shape[-1], int(duration*fps))

def get_timesteps_arr(audio_filepath, offset, duration, fps=30, margin=(1.0, 5.0)):
"""Get the array that will be used to determine how much to interpolate between images.
Normally, this is just a linspace between 0 and 1 for the number of frames to generate. In this case,
we want to use the amplitude of the audio to determine how much to interpolate between images.
So, here we:
1. Load the audio file
2. Split the audio into harmonic and percussive components
3. Get the normalized amplitude of the percussive component, resized to the number of frames
4. Get the cumulative sum of the amplitude array
5. Normalize the cumulative sum between 0 and 1
6. Return the array
I honestly have no clue what I'm doing here. Suggestions welcome.
"""
y, sr = librosa.load(audio_filepath, offset=offset, duration=duration)
wav_harmonic, wav_percussive = librosa.effects.hpss(y, margin=margin)
T = np.interp(x_resize, x_norm, y_norm)

# Apparently n_mels is supposed to be input shape but I don't think it matters here?
frame_duration = int(sr / fps)
wav_norm = get_spec_norm(wav_percussive, sr, n_mels=512, hop_length=frame_duration)
amplitude_arr = np.resize(wav_norm, int(duration * fps))
T = np.cumsum(amplitude_arr)
T /= T[-1]
T[0] = 0.0
return T
# Apply smoothing
return T * (1 - smooth) + np.linspace(0.0, 1.0, T.shape[0]) * smooth


def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
Expand Down Expand Up @@ -604,6 +586,8 @@ def walk(
resume: Optional[bool] = False,
audio_filepath: str = None,
audio_start_sec: Optional[Union[int, float]] = None,
margin: Optional[float] = 1.0,
smooth: Optional[float] = 0.0,
):
"""Generate a video from a sequence of prompts and seeds. Optionally, add audio to the
video to interpolate to the intensity of the audio.
Expand Down Expand Up @@ -650,6 +634,10 @@ def walk(
Optional path to an audio file to influence the interpolation rate.
audio_start_sec (Optional[Union[int, float]], *optional*, defaults to 0):
Global start time of the provided audio_filepath.
margin (Optional[float], *optional*, defaults to 1.0):
Margin from librosa hpss to use for audio interpolation.
smooth (Optional[float], *optional*, defaults to 0.0):
Smoothness of the audio interpolation. 1.0 means linear interpolation.
This function will create sub directories for each prompt and seed pair.
Expand Down Expand Up @@ -789,7 +777,8 @@ def walk(
offset=audio_offset,
duration=audio_duration,
fps=fps,
margin=(1.0, 5.0),
margin=margin,
smooth=smooth,
)
if audio_filepath
else None,
Expand Down

0 comments on commit bad063a

Please sign in to comment.