Skip to content

Commit

Permalink
Merge pull request #105 from ggozad/FIX-generator-mps
Browse files Browse the repository at this point in the history
Allow MPS devices to use the gradio interface
  • Loading branch information
nateraw authored Nov 1, 2022
2 parents ba12367 + 703b832 commit e80012a
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions stable_diffusion_videos/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,13 @@ def generate_input_batches(pipeline, prompts, seeds, batch_size, height, width):
noise = torch.randn(
(1, pipeline.unet.in_channels, height // 8, width // 8),
device=pipeline.device,
generator=torch.Generator(device=pipeline.device).manual_seed(seed),
generator=torch.Generator(
device="cpu" if pipeline.device.type == "mps" else pipeline.device
).manual_seed(seed),
)
embeds_batch = (
embeds if embeds_batch is None else torch.cat([embeds_batch, embeds])
)
embeds_batch = embeds if embeds_batch is None else torch.cat([embeds_batch, embeds])
noise_batch = noise if noise_batch is None else torch.cat([noise_batch, noise])
batch_is_ready = embeds_batch.shape[0] == batch_size or i + 1 == len(prompts)
if not batch_is_ready:
Expand Down

0 comments on commit e80012a

Please sign in to comment.