Commit 3540965

add ability to compress time in the unet3d by specifying temporal_strides for Unet3D
lucidrains committed Jan 24, 2023
1 parent 4cf311f commit 3540965
Showing 2 changed files with 85 additions and 21 deletions.
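For context, the new temporal_strides argument takes one stride per resolution level of the Unet3D: the frame axis is compressed by each level's stride on the way down and restored with a temporal pixel shuffle on the way up. A minimal usage sketch, assuming the Unet3D constructor arguments shown in the repository README (the stride values here are hypothetical):

    from imagen_pytorch import Unet3D

    # one temporal stride per resolution level (matching dim_mults);
    # a scalar is broadcast to every level via cast_tuple
    unet = Unet3D(
        dim = 64,
        dim_mults = (1, 2, 4, 8),
        temporal_strides = (1, 2, 2, 2)
    )

    # the strides multiply to 8, so videos fed to this unet must have a frame
    # count divisible by 8 (enforced by the new assert in forward, unless
    # time is being ignored)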
104 changes: 84 additions & 20 deletions imagen_pytorch/imagen_video.py
@@ -1,5 +1,7 @@
 import math
 import copy
+import operator
+import functools
 from typing import List
 from tqdm.auto import tqdm
 from functools import partial, wraps
@@ -11,7 +13,7 @@
 import torch.nn.functional as F
 from torch import nn, einsum
 
-from einops import rearrange, repeat, reduce
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 from einops_exts import rearrange_many, repeat_many, check_shape
 from einops_exts.torch import EinopsToAndFrom
@@ -31,6 +33,9 @@ def first(arr, d = None):
         return d
     return arr[0]
 
+def divisible_by(numer, denom):
+    return (numer % denom) == 0
+
 def maybe(fn):
     @wraps(fn)
     def inner(x):
@@ -412,6 +417,8 @@ def __init__(
         causal = False,
         context_dim = None,
         cosine_sim_attn = False,
+        rel_pos_bias = False,
+        rel_pos_bias_mlp_depth = 2,
         init_zero = False
     ):
         super().__init__()
@@ -421,6 +428,8 @@ def __init__(
         self.cosine_sim_attn = cosine_sim_attn
         self.cosine_sim_scale = 16 if cosine_sim_attn else 1
 
+        self.rel_pos_bias = DynamicPositionBias(dim = dim, heads = heads, depth = rel_pos_bias_mlp_depth) if rel_pos_bias else None
+
         self.heads = heads
         inner_dim = dim_head * heads
 
@@ -482,6 +491,9 @@ def forward(
 
         # relative positional encoding (T5 style)
 
+        if not exists(attn_bias) and exists(self.rel_pos_bias):
+            attn_bias = self.rel_pos_bias(n, device = device, dtype = q.dtype)
+
         if exists(attn_bias):
             null_attn_bias = repeat(self.null_attn_bias, 'h -> h n 1', n = n)
             attn_bias = torch.cat((null_attn_bias, attn_bias), dim = -1)
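The rel_pos_bias flag above lets each temporal Attention own its relative positional bias (a DynamicPositionBias, defined elsewhere in this file) instead of receiving a shared attn_bias from the caller. The general idea, as a self-contained sketch rather than the file's exact implementation: a small MLP maps each relative offset i - j to one additive bias per attention head.

    import torch
    from torch import nn
    from einops import rearrange

    class RelPosBiasMLP(nn.Module):
        # hypothetical stand-in for DynamicPositionBias: offset -> per-head bias
        def __init__(self, dim, heads, depth = 2):
            super().__init__()
            layers = [nn.Linear(1, dim), nn.SiLU()]
            for _ in range(depth - 1):
                layers += [nn.Linear(dim, dim), nn.SiLU()]
            layers += [nn.Linear(dim, heads)]
            self.mlp = nn.Sequential(*layers)

        def forward(self, n, device, dtype):
            pos = torch.arange(n, device = device, dtype = dtype)
            rel = rearrange(pos, 'i -> i 1 1') - rearrange(pos, 'j -> 1 j 1')  # (n, n, 1) offsets
            return rearrange(self.mlp(rel), 'i j h -> h i j')                  # (heads, n, n) bias

    bias = RelPosBiasMLP(dim = 16, heads = 8)(4, device = 'cpu', dtype = torch.float32)
    assert bias.shape == (8, 4, 4)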
@@ -587,6 +599,49 @@ def Downsample(dim, dim_out = None):
         Conv2d(dim * 4, dim_out, 1)
     )
 
+# temporal up and downsamples
+
+class TemporalPixelShuffleUpsample(nn.Module):
+    def __init__(self, dim, dim_out = None, stride = 2):
+        super().__init__()
+        self.stride = stride
+        dim_out = default(dim_out, dim)
+        conv = nn.Conv1d(dim, dim_out * stride, 1)
+
+        self.net = nn.Sequential(
+            conv,
+            nn.SiLU()
+        )
+
+        self.pixel_shuffle = Rearrange('b (c r) n -> b c (n r)', r = stride)
+
+        self.init_conv_(conv)
+
+    def init_conv_(self, conv):
+        o, i, f = conv.weight.shape
+        conv_weight = torch.empty(o // self.stride, i, f)
+        nn.init.kaiming_uniform_(conv_weight)
+        conv_weight = repeat(conv_weight, 'o ... -> (o r) ...', r = self.stride)
+
+        conv.weight.data.copy_(conv_weight)
+        nn.init.zeros_(conv.bias.data)
+
+    def forward(self, x):
+        b, c, f, h, w = x.shape
+        x = rearrange(x, 'b c f h w -> (b h w) c f')
+        out = self.net(x)
+        out = self.pixel_shuffle(out)
+        return rearrange(out, '(b h w) c f -> b c f h w', h = h, w = w)
+
+def TemporalDownsample(dim, dim_out = None, stride = 2):
+    dim_out = default(dim_out, dim)
+    return nn.Sequential(
+        Rearrange('b c (f p) h w -> b (c p) f h w', p = stride),
+        Conv2d(dim * stride, dim_out, 1)
+    )
+
+# positional embedding
+
 class SinusoidalPosEmb(nn.Module):
     def __init__(self, dim):
         super().__init__()
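The shape arithmetic of TemporalPixelShuffleUpsample, checked standalone with plain torch and einops (mirroring the class above, minus the conv initialization): space is folded into the batch, a pointwise Conv1d widens channels by the stride, and the extra channels are shuffled into the frame axis.

    import torch
    from torch import nn
    from einops import rearrange

    b, c, f, h, w, stride = 1, 8, 4, 16, 16, 2
    x = torch.randn(b, c, f, h, w)

    # fold space into batch, widen channels, shuffle channels into time
    flat = rearrange(x, 'b c f h w -> (b h w) c f')
    out = nn.Conv1d(c, c * stride, 1)(flat)
    out = rearrange(out, 'b (c r) n -> b c (n r)', r = stride)
    out = rearrange(out, '(b h w) c f -> b c f h w', h = h, w = w)
    assert out.shape == (b, c, f * stride, h, w)  # frames doubled, space untouched

Because init_conv_ repeats one kaiming-initialized weight across the stride groups, the shuffled-out frames start as identical copies of each other, presumably the usual ICNR-style guard against the checkerboard artifacts pixel-shuffle layers are prone to, applied here along time. TemporalDownsample makes the inverse trade: it packs groups of stride consecutive frames into channels and projects back down with a pointwise conv.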
@@ -1117,7 +1172,8 @@ def __init__(
         num_time_tokens = 2,
         learned_sinu_pos_emb_dim = 16,
         out_dim = None,
-        dim_mults=(1, 2, 4, 8),
+        dim_mults = (1, 2, 4, 8),
+        temporal_strides = 1,
         cond_images_channels = 0,
         channels = 3,
         channels_out = None,
@@ -1295,11 +1351,7 @@ def __init__(
         temporal_peg_padding = (0, 0, 0, 0, 2, 0) if time_causal_attn else (0, 0, 0, 0, 1, 1)
         temporal_peg = lambda dim: Residual(nn.Sequential(Pad(temporal_peg_padding), nn.Conv3d(dim, dim, (3, 1, 1), groups = dim)))
 
-        temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', '(b h w) f c', Residual(Attention(dim, **{**attn_kwargs, 'causal': time_causal_attn, 'init_zero': True})))
-
-        # temporal attention relative positional encoding
-
-        self.time_rel_pos_bias = DynamicPositionBias(dim = dim * 2, heads = attn_heads, depth = time_rel_pos_bias_depth)
+        temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', '(b h w) f c', Residual(Attention(dim, **{**attn_kwargs, 'causal': time_causal_attn, 'init_zero': True, 'rel_pos_bias': True})))
 
         # resnet block klass
 
@@ -1314,6 +1366,11 @@ def __init__(
 
         assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
 
+        # temporal downsample config
+
+        temporal_strides = cast_tuple(temporal_strides, num_layers)
+        self.total_temporal_divisor = functools.reduce(operator.mul, temporal_strides, 1)
+
         # downsample klass
 
         downsample_klass = Downsample
@@ -1338,14 +1395,14 @@ def __init__(
         self.ups = nn.ModuleList([])
         num_resolutions = len(in_out)
 
-        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns]
+        layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, temporal_strides]
         reversed_layer_params = list(map(reversed, layer_params))
 
         # downsampling layers
 
         skip_connect_dims = [] # keep track of skip connection dimensions
 
-        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn) in enumerate(zip(in_out, *layer_params)):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(in_out, *layer_params)):
             is_last = ind >= (num_resolutions - 1)
 
             layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
@@ -1378,6 +1435,7 @@ def __init__(
                 transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                 temporal_peg(current_dim),
                 temporal_attn(current_dim),
+                TemporalDownsample(current_dim, stride = temporal_stride) if temporal_stride > 1 else None,
                 post_downsample
             ]))
 
@@ -1399,7 +1457,7 @@ def __init__(
 
         upsample_fmap_dims = []
 
-        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
+        for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, temporal_stride) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
             is_last = ind == (len(in_out) - 1)
             layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
             layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
@@ -1415,6 +1473,7 @@ def __init__(
                 transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
                 temporal_peg(dim_out),
                 temporal_attn(dim_out),
+                TemporalPixelShuffleUpsample(dim_out, stride = temporal_stride) if temporal_stride > 1 else None,
                 upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
             ]))
 
@@ -1541,6 +1600,8 @@ def forward(
 
         batch_size, frames, device, dtype = x.shape[0], x.shape[2], x.device, x.dtype
 
+        assert ignore_time or divisible_by(frames, self.total_temporal_divisor), f'number of input frames must be divisible by {self.total_temporal_divisor}'
+
         # add self conditioning if needed
 
         if self.self_cond:
@@ -1564,10 +1625,6 @@ def forward(
             cond_images = resize_video_to(cond_images, x.shape[-1])
             x = torch.cat((cond_images, x), dim = 1)
 
-        # get time relative positions
-
-        time_attn_bias = self.time_rel_pos_bias(frames, device = device, dtype = dtype)
-
         # ignoring time in pseudo 3d resnet blocks
 
         conv_kwargs = dict(
@@ -1580,7 +1637,7 @@ def forward(
 
         if not ignore_time:
             x = self.init_temporal_peg(x)
-            x = self.init_temporal_attn(x, attn_bias = time_attn_bias)
+            x = self.init_temporal_attn(x)
 
         # init conv residual
 
@@ -1687,7 +1744,7 @@ def forward(
 
         hiddens = []
 
-        for pre_downsample, init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, post_downsample in self.downs:
+        for pre_downsample, init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, temporal_downsample, post_downsample in self.downs:
            if exists(pre_downsample):
                x = pre_downsample(x)
 
@@ -1701,10 +1758,13 @@ def forward(
 
             if not ignore_time:
                 x = temporal_peg(x)
-                x = temporal_attn(x, attn_bias = time_attn_bias)
+                x = temporal_attn(x)
 
             hiddens.append(x)
 
+            if exists(temporal_downsample) and not ignore_time:
+                x = temporal_downsample(x)
+
             if exists(post_downsample):
                 x = post_downsample(x)
 
@@ -1715,15 +1775,18 @@ def forward(
 
         if not ignore_time:
             x = self.mid_temporal_peg(x)
-            x = self.mid_temporal_attn(x, attn_bias = time_attn_bias)
+            x = self.mid_temporal_attn(x)
 
         x = self.mid_block2(x, t, c, **conv_kwargs)
 
         add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)
 
         up_hiddens = []
 
-        for init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, upsample in self.ups:
+        for init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, temporal_upsample, upsample in self.ups:
+            if exists(temporal_upsample) and not ignore_time:
+                x = temporal_upsample(x)
+
             x = add_skip_connection(x)
             x = init_block(x, t, c, **conv_kwargs)
 
@@ -1735,9 +1798,10 @@ def forward(
 
             if not ignore_time:
                 x = temporal_peg(x)
-                x = temporal_attn(x, attn_bias = time_attn_bias)
+                x = temporal_attn(x)
 
             up_hiddens.append(x.contiguous())
+
             x = upsample(x)
 
         # whether to combine all feature maps from upsample blocks
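Note the ordering in the two forward loops above: on the way down, each skip connection is stored in hiddens before temporal_downsample runs, and on the way up, temporal_upsample runs before add_skip_connection, so the frame count of x always matches the stored skip tensor it is concatenated with.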
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.18.11'
+__version__ = '1.19.0'
