Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Giving error on shape changes #1

Open
CypherpunkSamurai opened this issue Dec 10, 2022 · 0 comments
Open

Giving error on shape changes #1

CypherpunkSamurai opened this issue Dec 10, 2022 · 0 comments

Comments

@CypherpunkSamurai
Copy link

Cannot Generate Tensor

Hi I tried to use the code from your repo to generate a tensor. Any Idea why it might not be working?

"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
import k_diffusion as K
import torch
import torch.nn as nn


class CFGDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        return uncond + (cond - uncond) * cond_scale


class KSampler(object):
    def __init__(self, model, schedule='lms', device='cuda', **kwargs):
        super().__init__()
        self.model = K.external.CompVisDenoiser(model)
        self.schedule = schedule
        self.device = device

        def forward(self, x, sigma, uncond, cond, cond_scale):
            x_in = torch.cat([x] * 2)
            sigma_in = torch.cat([sigma] * 2)
            cond_in = torch.cat([uncond, cond])
            uncond, cond = self.inner_model(
                x_in, sigma_in, cond=cond_in
            ).chunk(2)
            return uncond + (cond - uncond) * cond_scale

    # most of these arguments are ignored and are only present for compatibility with
    # other samples
    @torch.no_grad()
    def sample(
        self,
        S,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        mask=None,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):

        sigmas = self.model.get_sigmas(S)
        if x_T:
            x = x_T
        else:
            x = (
                torch.randn([batch_size, *shape], device=self.device)
                * sigmas[0]
            )   # for GPU draw
        model_wrap_cfg = CFGDenoiser(self.model)
        extra_args = {
            'cond': conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': unconditional_guidance_scale,
        }
        return (
            K.sampling.__dict__[f'sample_{self.schedule}'](
                model_wrap_cfg, x, sigmas, extra_args=extra_args
            ),
            None,
        )

# Run
# model is custom ckpt loaded with config
steps = 28
# shape = [4, 328 // 8, 328 // 8]

a = KSampler(model, "euler_ancestral")

# shape = [4, 384 // 8, 384 // 8] 
shape = [4, 600 // 8, 600 // 8] # error

c = model.get_learned_conditioning([prompt])
uc = model.get_learned_conditioning(1 * ["lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"])
sample, _ = a.sample(28, 1, shape, conditioning=c, unconditional_conditioning=uc)

LOG

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_23/3561072134.py in <module>
      6 c = model.get_learned_conditioning([prompt])
      7 uc = model.get_learned_conditioning(1 * ["lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"])
----> 8 sample, _ = a.sample(28, 1, shape, conditioning=c, unconditional_conditioning=uc)

/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     25         def decorate_context(*args, **kwargs):
     26             with self.clone():
---> 27                 return func(*args, **kwargs)
     28         return cast(F, decorate_context)
     29 

/tmp/ipykernel_23/2775818800.py in sample(self, S, batch_size, shape, conditioning, callback, normals_sequence, img_callback, quantize_x0, eta, mask, x0, temperature, noise_dropout, score_corrector, corrector_kwargs, verbose, x_T, log_every_t, unconditional_guidance_scale, unconditional_conditioning, **kwargs)
     86         return (
     87             K.sampling.__dict__[f'sample_{self.schedule}'](
---> 88                 model_wrap_cfg, x, sigmas, extra_args=extra_args
     89             ),
     90             None,

/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     25         def decorate_context(*args, **kwargs):
     26             with self.clone():
---> 27                 return func(*args, **kwargs)
     28         return cast(F, decorate_context)
     29 

/opt/conda/lib/python3.7/site-packages/k_diffusion/sampling.py in sample_euler_ancestral(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
    136 
    137 
--> 138 @torch.no_grad()
    139 def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    140     """Ancestral sampling with Euler method steps."""

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     25         def decorate_context(*args, **kwargs):
     26             with self.clone():
---> 27                 return func(*args, **kwargs)
     28         return cast(F, decorate_context)
     29 

/tmp/ipykernel_23/2775818800.py in forward(self, x, sigma, uncond, cond, cond_scale)
     18         sigma_two = torch.cat([sigma] * 2)
     19         cond_full = torch.cat([uncond, cond])
---> 20         uncond, cond = self.inner_model(x_two, sigma_two, cond=cond_full).chunk(2)
     21         x_0 = uncond + (cond - uncond) * cond_scale
     22         if self.thresholder is not None:

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/lib/python3.7/site-packages/k_diffusion/external.py in forward(self, input, sigma, **kwargs)
    110     def forward(self, input, sigma, **kwargs):
    111         c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
--> 112         eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
    113         return input + eps * c_out
    114 

/opt/conda/lib/python3.7/site-packages/k_diffusion/external.py in get_eps(self, *args, **kwargs)
    136 
    137     def get_eps(self, *args, **kwargs):
--> 138         return self.inner_model.apply_model(*args, **kwargs)
    139 
    140 

/kaggle/working/code/ldm/models/diffusion/ddpm.py in apply_model(self, x_noisy, t, cond, return_ids)
    985 
    986         else:
--> 987             x_recon = self.model(x_noisy, t, **cond)
    988 
    989         if isinstance(x_recon, tuple) and not return_ids:

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/kaggle/working/code/ldm/models/diffusion/ddpm.py in forward(self, x, t, c_concat, c_crossattn)
   1408         elif self.conditioning_key == 'crossattn':
   1409             cc = torch.cat(c_crossattn, 1)
-> 1410             out = self.diffusion_model(x, t, context=cc)
   1411         elif self.conditioning_key == 'hybrid':
   1412             xc = torch.cat([x] + c_concat, dim=1)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/kaggle/working/code/ldm/modules/diffusionmodules/openaimodel.py in forward(self, x, timesteps, context, y, **kwargs)
    734         h = self.middle_block(h, emb, context)
    735         for module in self.output_blocks:
--> 736             h = th.cat([h, hs.pop()], dim=1)
    737             h = module(h, emb, context)
    738         h = h.type(x.dtype)
Sizes of tensors must match except in dimension 1. Expected size 20 but got size 19 for tensor number 1 in the list.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant