diff --git a/guided_diffusion/script_util.py b/guided_diffusion/script_util.py
index 21a03e917..ab0e958db 100644
--- a/guided_diffusion/script_util.py
+++ b/guided_diffusion/script_util.py
@@ -97,6 +97,7 @@ def create_model_and_diffusion(
     use_fp16,
     use_new_attention_order,
     use_neighborhood_attention,
+    use_torch_sdp_attention,
 ):
     model = create_model(
         image_size,
@@ -116,6 +117,7 @@ def create_model_and_diffusion(
         use_fp16=use_fp16,
         use_new_attention_order=use_new_attention_order,
         use_neighborhood_attention=use_neighborhood_attention,
+        use_torch_sdp_attention=use_torch_sdp_attention,
     )
     diffusion = create_gaussian_diffusion(
         steps=diffusion_steps,
@@ -148,6 +150,7 @@ def create_model(
     use_fp16=False,
     use_new_attention_order=False,
     use_neighborhood_attention=False,
+    use_torch_sdp_attention=False,
 ):
     if channel_mult == "":
         if image_size == 512:
@@ -186,6 +189,7 @@ def create_model(
         resblock_updown=resblock_updown,
         use_new_attention_order=use_new_attention_order,
         use_neighborhood_attention=use_neighborhood_attention,
+        use_torch_sdp_attention=use_torch_sdp_attention,
     )
 
 
diff --git a/guided_diffusion/unet.py b/guided_diffusion/unet.py
index 25da85285..0e7a86ef5 100644
--- a/guided_diffusion/unet.py
+++ b/guided_diffusion/unet.py
@@ -6,6 +6,7 @@
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
 
 from .fp16_util import convert_module_to_f16, convert_module_to_f32
 from .nn import (
@@ -274,6 +275,7 @@ def __init__(
         use_new_attention_order=False,
         spatial=None,
         use_neighborhood_attention=False,
+        use_torch_sdp_attention=False,
     ):
         super().__init__()
         self.channels = channels
@@ -287,12 +289,14 @@ def __init__(
         self.use_checkpoint = use_checkpoint
         self.norm = normalization(channels)
         self.qkv = conv_nd(1, channels, channels * 3, 1)
+        if use_neighborhood_attention and use_torch_sdp_attention:
+            raise ValueError('Cannot satisfy both use_neighborhood_attention:True and use_torch_sdp_attention:True')
         if use_new_attention_order:
             # split qkv before split heads
-            self.attention = QKVAttention(self.num_heads, spatial, use_neighborhood_attention)
+            self.attention = QKVAttention(self.num_heads, spatial, use_neighborhood_attention, use_torch_sdp_attention)
         else:
             # split heads before split qkv
-            self.attention = QKVAttentionLegacy(self.num_heads, spatial, use_neighborhood_attention)
+            self.attention = QKVAttentionLegacy(self.num_heads, spatial, use_neighborhood_attention, use_torch_sdp_attention)
 
         self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
 
@@ -333,11 +337,14 @@ class QKVAttentionLegacy(nn.Module):
     A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
     """
 
-    def __init__(self, n_heads, spatial=None, use_neighborhood_attention=False):
+    def __init__(self, n_heads, spatial=None, use_neighborhood_attention=False, use_torch_sdp_attention=False):
         super().__init__()
         self.n_heads = n_heads
         self.spatial = tuple(spatial)
         self.use_neighborhood_attention = use_neighborhood_attention
+        self.use_torch_sdp_attention = use_torch_sdp_attention
+        if use_neighborhood_attention and use_torch_sdp_attention:
+            raise ValueError('Cannot satisfy both use_neighborhood_attention:True and use_torch_sdp_attention:True')
         self.neighborhood_mask_cache = {}
 
     def get_neighborhood_mask(self, spatial, device):
@@ -359,6 +366,11 @@ def forward(self, qkv, spatial=None):
         bs, width, length = qkv.shape
         assert width % (3 * self.n_heads) == 0
         ch = width // (3 * self.n_heads)
+        if self.use_torch_sdp_attention:
+            q, k, v = rearrange(qkv, "n (h p c) t -> p n h t c", p=3, c=ch).contiguous().unbind()
+            a = th.nn.functional.scaled_dot_product_attention(q, k, v)
+            a = rearrange(a, 'n h t c -> n (h c) t')
+            return a
         q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
         scale = 1 / math.sqrt(math.sqrt(ch))
         weight = th.einsum(
@@ -380,11 +392,14 @@ class QKVAttention(nn.Module):
     A module which performs QKV attention and splits in a different order.
     """
 
-    def __init__(self, n_heads, spatial=None, use_neighborhood_attention=False):
+    def __init__(self, n_heads, spatial=None, use_neighborhood_attention=False, use_torch_sdp_attention=False):
         super().__init__()
         self.n_heads = n_heads
         self.spatial = tuple(spatial)
         self.use_neighborhood_attention = use_neighborhood_attention
+        self.use_torch_sdp_attention = use_torch_sdp_attention
+        if use_neighborhood_attention and use_torch_sdp_attention:
+            raise ValueError('Cannot satisfy both use_neighborhood_attention:True and use_torch_sdp_attention:True')
         self.neighborhood_mask_cache = {}
 
     def get_neighborhood_mask(self, spatial, device):
@@ -406,6 +421,11 @@ def forward(self, qkv, spatial=None):
         bs, width, length = qkv.shape
         assert width % (3 * self.n_heads) == 0
         ch = width // (3 * self.n_heads)
+        if self.use_torch_sdp_attention:
+            q, k, v = rearrange(qkv, "n (p h c) t -> p n h t c", p=3, c=ch).contiguous().unbind()
+            a = th.nn.functional.scaled_dot_product_attention(q, k, v)
+            a = rearrange(a, 'n h t c -> n (h c) t')
+            return a
         q, k, v = qkv.chunk(3, dim=1)
         scale = 1 / math.sqrt(math.sqrt(ch))
         weight = th.einsum(
@@ -477,6 +497,7 @@ def __init__(
         resblock_updown=False,
         use_new_attention_order=False,
         use_neighborhood_attention=False,
+        use_torch_sdp_attention=False,
     ):
         super().__init__()
 
@@ -499,6 +520,9 @@ def __init__(
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
         self.use_neighborhood_attention = use_neighborhood_attention
+        self.use_torch_sdp_attention = use_torch_sdp_attention
+        if use_neighborhood_attention and use_torch_sdp_attention:
+            raise ValueError('Cannot satisfy both use_neighborhood_attention:True and use_torch_sdp_attention:True')
 
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
@@ -541,6 +565,7 @@ def __init__(
                         use_new_attention_order=use_new_attention_order,
                         spatial=(image_size // ds, image_size // ds),
                         use_neighborhood_attention=use_neighborhood_attention,
+                        use_torch_sdp_attention=use_torch_sdp_attention,
                     )
                 )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -588,6 +613,7 @@ def __init__(
                 use_new_attention_order=use_new_attention_order,
                 spatial=(image_size // ds, image_size // ds),
                 use_neighborhood_attention=use_neighborhood_attention,
+                use_torch_sdp_attention=use_torch_sdp_attention,
             ),
             ResBlock(
                 ch,
@@ -626,6 +652,7 @@ def __init__(
                         use_new_attention_order=use_new_attention_order,
                         spatial=(image_size // ds, image_size // ds),
                         use_neighborhood_attention=use_neighborhood_attention,
+                        use_torch_sdp_attention=use_torch_sdp_attention,
                     )
                 )
                 if level and i == num_res_blocks:
@@ -748,6 +775,7 @@ def __init__(
         use_new_attention_order=False,
         pool="adaptive",
         use_neighborhood_attention=False,
+        use_torch_sdp_attention=False,
     ):
         super().__init__()
 
@@ -768,6 +796,9 @@ def __init__(
         self.num_head_channels = num_head_channels
         self.num_heads_upsample = num_heads_upsample
         self.use_neighborhood_attention = use_neighborhood_attention
+        self.use_torch_sdp_attention = use_torch_sdp_attention
+        if use_neighborhood_attention and use_torch_sdp_attention:
+            raise ValueError('Cannot satisfy both use_neighborhood_attention:True and use_torch_sdp_attention:True')
 
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
@@ -807,6 +838,7 @@ def __init__(
                         use_new_attention_order=use_new_attention_order,
                         spatial=(image_size // ds, image_size // ds),
                         use_neighborhood_attention=use_neighborhood_attention,
+                        use_torch_sdp_attention=use_torch_sdp_attention,
                     )
                 )
                 self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -846,6 +878,7 @@ def __init__(
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
                 use_neighborhood_attention=use_neighborhood_attention,
+                use_torch_sdp_attention=use_torch_sdp_attention,
             ),
             AttentionBlock(
                 ch,
diff --git a/setup.py b/setup.py
index dae93bdcc..63858d01e 100644
--- a/setup.py
+++ b/setup.py
@@ -4,5 +4,5 @@
     name="guided-diffusion",
     packages=["guided_diffusion"],
     py_modules=["guided_diffusion"],
-    install_requires=["blobfile>=1.0.5", "torch", "tqdm"],
+    install_requires=["blobfile>=1.0.5", "torch", "tqdm", "einops"],
 )
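Note (not part of the patch): the SDPA fast path can be sanity-checked against the original einsum attention with a small standalone script. The sketch below is an assumption-laden example, not repo code: it assumes PyTorch >= 2.0 (for torch.nn.functional.scaled_dot_product_attention), the einops dependency added above, and arbitrary toy tensor sizes; it reproduces the legacy "n (h p c) t" packing used by QKVAttentionLegacy without importing the repository.

    import math
    import torch as th
    from einops import rearrange

    bs, n_heads, ch, length = 2, 4, 64, 10 * 10   # toy sizes, chosen arbitrarily
    qkv = th.randn(bs, n_heads * 3 * ch, length)  # legacy packing: n (h p c) t

    # Original path: split heads first, then q/k/v, with the 1/sqrt(sqrt(ch))
    # factor applied to both q and k.
    q, k, v = qkv.reshape(bs * n_heads, ch * 3, length).split(ch, dim=1)
    scale = 1 / math.sqrt(math.sqrt(ch))
    weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
    weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
    ref = th.einsum("bts,bcs->bct", weight, v).reshape(bs, -1, length)

    # Patched path: unpack to (batch, heads, tokens, channels) and let SDPA
    # apply its default 1/sqrt(head_dim) scaling.
    q, k, v = rearrange(qkv, "n (h p c) t -> p n h t c", p=3, c=ch).contiguous().unbind()
    out = th.nn.functional.scaled_dot_product_attention(q, k, v)
    out = rearrange(out, "n h t c -> n (h c) t")

    print(th.allclose(ref, out, atol=1e-5))       # expected: True (within tolerance)

The same check carries over to QKVAttention by swapping the rearrange pattern to "n (p h c) t", the layout implied by qkv.chunk(3, dim=1); SDPA's default 1/sqrt(head_dim) scale is equivalent to the two 1/sqrt(sqrt(ch)) factors in the einsum path, which is why the fast path drops the explicit scale.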