diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index fb31215db..4ffe882bc 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -18,7 +18,7 @@ def __init__(self, model, schedule="linear", **kwargs): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): + if torch.cuda.is_available() and attr.device != torch.device("cuda"): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) @@ -238,4 +238,4 @@ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unco x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) - return x_dec \ No newline at end of file + return x_dec diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 78eeb1003..e25002ede 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -7,7 +7,6 @@ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like - class PLMSSampler(object): def __init__(self, model, schedule="linear", **kwargs): super().__init__() @@ -17,7 +16,7 @@ def __init__(self, model, schedule="linear", **kwargs): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): + if torch.cuda.is_available() and attr.device != torch.device("cuda"): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index ededbe43e..514cf681e 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -35,7 +35,7 @@ def forward(self, batch, key=None): class TransformerEmbedder(AbstractEncoder): """Some transformer encoder layers""" - def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device='cuda' if torch.cuda.is_available() else 'cpu'): super().__init__() self.device = device self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, @@ -52,7 +52,7 @@ def encode(self, x): class BERTTokenizer(AbstractEncoder): """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" - def __init__(self, device="cuda", vq_interface=True, max_length=77): + def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu', vq_interface=True, max_length=77): super().__init__() from transformers import BertTokenizerFast # TODO: add to reuquirements self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") @@ -80,7 +80,7 @@ def decode(self, text): class BERTEmbedder(AbstractEncoder): """Uses the BERT tokenizr model and add some transformer encoder layers""" def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, - device="cuda",use_tokenizer=True, embedding_dropout=0.0): + device='cuda' if torch.cuda.is_available() else 'cpu', use_tokenizer=True, embedding_dropout=0.0): super().__init__() self.use_tknz_fn = use_tokenizer if self.use_tknz_fn: @@ -136,7 +136,7 @@ def encode(self, x): class FrozenCLIPEmbedder(AbstractEncoder): """Uses the CLIP transformer encoder for text (from Hugging Face)""" - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): + def __init__(self, version="openai/clip-vit-large-patch14", device='cuda' if torch.cuda.is_available() else 'cpu', max_length=77): super().__init__() self.tokenizer = CLIPTokenizer.from_pretrained(version) self.transformer = CLIPTextModel.from_pretrained(version) @@ -166,7 +166,7 @@ class FrozenCLIPTextEmbedder(nn.Module): """ Uses the CLIP transformer encoder for text. """ - def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + def __init__(self, version='ViT-L/14', device='cuda' if torch.cuda.is_available() else 'cpu', max_length=77, n_repeat=1, normalize=True): super().__init__() self.model, _ = clip.load(version, jit=False, device="cpu") self.device = device @@ -231,4 +231,4 @@ def forward(self, x): if __name__ == "__main__": from ldm.util import count_params model = FrozenCLIPEmbedder() - count_params(model, verbose=True) \ No newline at end of file + count_params(model, verbose=True) diff --git a/scripts/img2img.py b/scripts/img2img.py index 421e2151d..285fef736 100644 --- a/scripts/img2img.py +++ b/scripts/img2img.py @@ -40,7 +40,8 @@ def load_model_from_config(config, ckpt, verbose=False): print("unexpected keys:") print(u) - model.cuda() + if torch.cuda.is_available(): + model.cuda() model.eval() return model diff --git a/scripts/txt2img.py b/scripts/txt2img.py index bc3864043..04595b861 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -61,7 +61,8 @@ def load_model_from_config(config, ckpt, verbose=False): print("unexpected keys:") print(u) - model.cuda() + if torch.cuda.is_available(): + model.cuda() model.eval() return model