Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apple silicon mps support #709

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,135 @@
# Apple Silicon Mac Users

NOTE: I have submitted a merge request to move the changes in this repo to the [lstein fork of stable-diffusion](https://github.com/lstein/stable-diffusion/) because he has so many wonderful features in his fork! Another fork that I know that has up-to-date Mac Support and some really cool features is the [Birch-san stable-diffusion fork](https://github.com/Birch-san/stable-diffusion). If my pull request to lstein is accepted, I no longer plan on updating this fork with the latest changes!

Several people have gotten Stable Diffusion to work on Apple Silicon Macs using Anaconda. I've gathered up most of their instructions and put them in this fork (and readme). I haven't tested anything besides Anaconda, and I've read about issues with things like miniforge, so if you have an issue that isn't dealt with in this fork then head on over to the [Apple Silicon](https://github.com/CompVis/stable-diffusion/issues/25) issue on GitHub (that page is so long that GitHub hides most of it by default, so you need to find the hidden part and expand it to view the whole thing). This fork would not have been possible without the work done by the people on that issue.

You have to have macOS 12.3 Monterey or later. Anything earlier than that won't work.

BTW, I haven't tested any of this on Intel Macs.

How to:

```
git clone https://github.com/magnusviri/stable-diffusion.git
cd stable-diffusion
git checkout apple-silicon-mps-support

mkdir -p models/ldm/stable-diffusion-v1/
ln -s /path/to/ckpt/sd-v1-1.ckpt models/ldm/stable-diffusion-v1/model.ckpt

conda env create -f environment-mac.yaml
conda activate ldm
```

These instructions are identical to the main repo except I added environment-mac.yaml because Mac doesn't have cudatoolkit.

After you follow all the instructions and run txt2img.py you might get several errors. Here's the errors I've seen and found solutions for.

### Doesn't work anymore?

We are using PyTorch nightly, which includes support for MPS. I don't know exactly how Anaconda does updates, but I woke up one morning and Stable Diffusion crashed and I couldn't think of anything I did that would've changed anything the night before, when it worked. A day and a half later I finally got it working again. I don't know what changed overnight. PyTorch-nightly changes overnight but I'm pretty sure I didn't manually update it. Either way, things are probably going to be bumpy on Apple Silicon until PyTorch releases a firm version that we can lock to.

To manually update to the latest version of PyTorch nightly (which could fix issues), run this command.

conda install pytorch torchvision torchaudio -c pytorch-nightly

### "No module named cv2" (or some other module)

Did you remember to `conda activate ldm`? If your terminal prompt begins with "(ldm)" then you activated it. If it begins with "(base)" or something else you haven't.

If you have activated the ldm virtual environment, the problem could be that I have something installed that you don't and you'll just need to manually install it.

conda activate ldm
pip install *name*

You might also need to install Rust (I mention this again below).

### "The operator [name] is not current implemented for the MPS device." (sic)

Example error.

```
...
NotImplementedError: The operator 'aten::index.Tensor' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on [https://github.com/pytorch/pytorch/issues/77764](https://github.com/pytorch/pytorch/issues/77764). As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```

Just do what it says:

export PYTORCH_ENABLE_MPS_FALLBACK=1

### "Could not build wheels for tokenizers"

I have not seen this error because I had Rust installed on my computer before I started playing with Stable Diffusion. The fix is to install Rust.

curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

### How come `--seed` doesn't work?

> Completely reproducible results are not guaranteed across PyTorch releases, individual commits, or different platforms. Furthermore, results may not be reproducible between CPU and GPU executions, even when using identical seeds.

[PyTorch docs](https://pytorch.org/docs/stable/notes/randomness.html)

There is an [open issue](https://github.com/pytorch/pytorch/issues/78035) (as of August 2022) in pytorch regarding gradient inconsistency. I am guessing that's what is causing this.

### libiomp5.dylib error?

OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized.

There are several things you can do. First, you could use something besides Anaconda like miniforge. I read a lot of things online telling people to use something else, but I am stuck with Anaconda for other reasons.

Or you can try this.

export KMP_DUPLICATE_LIB_OK=True

Or this (which takes forever on my computer and didn't work anyway).

conda install nomkl

This error happens with Anaconda on Macs, and [nomkl](https://stackoverflow.com/questions/66224879/what-is-the-nomkl-python-package-used-for) is supposed to fix the issue (it isn't a module but a fix of some sort). [There's more suggestions](https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already-initial), like uninstalling tensorflow and reinstalling. I haven't tried them.

### Not enough memory.

This seems to be a common problem and is probably the underlying problem for a lot of symptoms (listed below). The fix is to lower your image size or to add `model.half()` right after the model is loaded. I should probably test it out. I've read that the reason this fixes problems is because it converts the model from 32-bit to 16-bit and that leaves more RAM for other things. I have no idea how that would affect the quality of the images though.

See [this issue](https://github.com/CompVis/stable-diffusion/issues/71).

### "Error: product of dimension sizes > 2**31'"

This error happens with img2img, which I haven't played with too much yet. But I know it's because your image is too big or the resolution isn't a multiple of 32x32. Because the stable-diffusion model was trained on images that were 512 x 512, it's always best to use that output size (which is the default). However, if you're using that size and you get the above error, try 256 x 256 or 512 x 256 or something as the source image.

BTW, 2**31-1 = [2,147,483,647](https://en.wikipedia.org/wiki/2,147,483,647#In_computing), which is also 32-bit signed [LONG_MAX](https://en.wikipedia.org/wiki/C_data_types) in C.

### I just got Rickrolled! Do I have a virus?

You don't have a virus. It's part of the project. Here's [Rick](https://github.com/magnusviri/stable-diffusion/blob/main/assets/rick.jpeg) and here's [the code](https://github.com/magnusviri/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/scripts/txt2img.py#L79) that swaps him in. It's a NSFW filter, which IMO, doesn't work very good (and we call this "computer vision", sheesh).

Actually, this could be happening because there's not enough RAM. You could try the `model.half()` suggestion or specify smaller output images.

### My images come out black

I haven't solved this issue. I just throw away my black images. There's a [similar issue](https://github.com/CompVis/stable-diffusion/issues/69) on CUDA GPU's where the images come out green. Maybe it's the same issue? Someone in that issue says to use "--precision full", but this fork actually disables that flag. I don't know why, someone else provided that code and I don't know what it does. Maybe the `model.half()` suggestion above would fix this issue too. I should probably test it.

### "view size is not compatible with input tensor's size and stride"

```
File "/opt/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/nn/functional.py", line 2511, in layer_norm
return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```

Update to the latest version of magnusviri/stable-diffusion. We were patching pytorch but we found a file in stable-diffusion that we could change instead. This is a 32-bit vs 16-bit problem.

### Still slow?

I changed the defaults of n_samples and n_iter to 1 so that it uses less RAM and makes less images so it will be faster the first time you use it. I don't actually know what n_samples does internally, but I know it consumes a lot more RAM. The n_iter flag just loops around the image creation code, so it shouldn't consume more RAM (it should be faster if you're going to do multiple images because the libraries and model will already be loaded--use a prompt file to get this speed boost).

These flags are the default sample and iter settings in this fork/branch:

python scripts/txt2img.py --prompt "ocean" --n_samples=1 --n_iter=1

Happy fuzzy internet image copying!

# Stable Diffusion
*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*

Expand Down
32 changes: 32 additions & 0 deletions environment-mac.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: ldm
channels:
- pytorch-nightly
- apple
- conda-forge
- defaults
dependencies:
- python=3.10.4
- pip=22.1.2
- pytorch
- torchvision
- numpy=1.23.1
- pip:
- albumentations==0.4.6
- diffusers
- opencv-python==4.6.0.66
- pudb==2019.2
- invisible-watermark
- imageio==2.9.0
- imageio-ffmpeg==0.4.2
- pytorch-lightning==1.4.2
- omegaconf==2.1.1
- test-tube>=0.7.5
- streamlit>=0.73.1
- einops==0.3.0
- torch-fidelity==0.3.0
- transformers==4.19.2
- torchmetrics==0.6.0
- kornia==0.6
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e .
10 changes: 8 additions & 2 deletions ldm/models/diffusion/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@ def __init__(self, model, schedule="linear", **kwargs):
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
if(torch.cuda.is_available()):
self.device_available = "cuda"
elif(torch.backends.mps.is_available()):
self.device_available = "mps"
else:
self.device_available = "cpu"

def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
if attr.device != torch.device(self.device_available):
attr = attr.to(torch.float32).to(torch.device(self.device_available))
setattr(self, name, attr)

def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
Expand Down
10 changes: 8 additions & 2 deletions ldm/models/diffusion/plms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@ def __init__(self, model, schedule="linear", **kwargs):
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
if(torch.cuda.is_available()):
self.device_available = "cuda"
elif(torch.backends.mps.is_available()):
self.device_available = "mps"
else:
self.device_available = "cpu"

def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
if attr.device != torch.device(self.device_available):
attr = attr.to(torch.float32).to(torch.device(self.device_available))
setattr(self, name, attr)

def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
Expand Down
1 change: 1 addition & 0 deletions ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

def _forward(self, x, context=None):
x = x.contiguous()
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
Expand Down
23 changes: 16 additions & 7 deletions ldm/modules/encoders/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test


def get_default_device_type():
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
else:
return "cpu"


class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -35,7 +44,7 @@ def forward(self, batch, key=None):

class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device=get_default_device_type()):
super().__init__()
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
Expand All @@ -52,7 +61,7 @@ def encode(self, x):

class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
def __init__(self, device=get_default_device_type(), vq_interface=True, max_length=77):
super().__init__()
from transformers import BertTokenizerFast # TODO: add to reuquirements
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
Expand Down Expand Up @@ -80,7 +89,7 @@ def decode(self, text):
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
device=get_default_device_type(),use_tokenizer=True, embedding_dropout=0.0):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
Expand Down Expand Up @@ -136,7 +145,7 @@ def encode(self, x):

class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
def __init__(self, version="openai/clip-vit-large-patch14", device=get_default_device_type(), max_length=77):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
Expand Down Expand Up @@ -166,9 +175,9 @@ class FrozenCLIPTextEmbedder(nn.Module):
"""
Uses the CLIP transformer encoder for text.
"""
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
def __init__(self, version='ViT-L/14', device=get_default_device_type(), max_length=77, n_repeat=1, normalize=True):
super().__init__()
self.model, _ = clip.load(version, jit=False, device="cpu")
self.model, _ = clip.load(version, jit=False, device=device)
self.device = device
self.max_length = max_length
self.n_repeat = n_repeat
Expand Down Expand Up @@ -202,7 +211,7 @@ def __init__(
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
device=get_default_device_type(),
antialias=False,
):
super().__init__()
Expand Down
8 changes: 7 additions & 1 deletion notebook_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,13 @@ def get_cond(mode, selected_path):
c = rearrange(c, '1 c h w -> 1 h w c')
c = 2. * c - 1.

c = c.to(torch.device("cuda"))
if(torch.cuda.is_available()):
device = torch.device("cuda")
elif(torch.backends.mps.is_available()):
device = torch.device("mps")
else:
device = torch.device("cpu")
c = c.to(device)
example["LR_image"] = c
example["image"] = c_up

Expand Down
21 changes: 21 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
numpy==1.23.1
--pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu

albumentations==0.4.6
diffusers
opencv-python==4.6.0.66
pudb==2019.2
invisible-watermark
imageio==2.9.0
imageio-ffmpeg==0.4.2
pytorch-lightning==1.4.2
omegaconf==2.1.1
test-tube>=0.7.5
streamlit>=0.73.1
einops==0.3.0
torch-fidelity==0.3.0
transformers==4.19.2
torchmetrics==0.6.0
kornia==0.6
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/openai/CLIP.git@main#egg=clip
16 changes: 13 additions & 3 deletions scripts/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

def get_device():
if(torch.cuda.is_available()):
return torch.device("cuda")
elif(torch.backends.mps.is_available()):
return torch.device("mps")
else:
return torch.device("cpu")


def chunk(it, size):
it = iter(it)
Expand All @@ -40,7 +48,7 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:")
print(u)

model.cuda()
model.to(get_device())
model.eval()
return model

Expand Down Expand Up @@ -199,7 +207,7 @@ def main():
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = get_device()
model = model.to(device)

if opt.plms:
Expand Down Expand Up @@ -241,8 +249,10 @@ def main():
print(f"target t_enc is {t_enc} steps")

precision_scope = autocast if opt.precision == "autocast" else nullcontext
if device.type == 'mps':
precision_scope = nullcontext # have to use f32 on mps
with torch.no_grad():
with precision_scope("cuda"):
with precision_scope(device.type):
with model.ema_scope():
tic = time.time()
all_samples = list()
Expand Down
7 changes: 6 additions & 1 deletion scripts/inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,12 @@ def make_batch(image, mask, device):
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
strict=False)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if(torch.cuda.is_available()):
device = torch.device("cuda")
elif(torch.backends.mps.is_available()):
device = torch.device("mps")
else:
device = torch.device("cpu")
model = model.to(device)
sampler = DDIMSampler(model)

Expand Down
Loading