Skip to content

Commit

Permalink
Removed in_spread and out_spread because of their incorrectness
Browse files Browse the repository at this point in the history
  • Loading branch information
anordertoreclaim committed Aug 4, 2019
1 parent 0224738 commit 8416548
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 18 deletions.
20 changes: 3 additions & 17 deletions pixelcnn/conv_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def forward(self, x):


class MaskedConv2d(nn.Conv2d):
def __init__(self, *args, mask_type, data_channels, in_spread=True, out_spread=True, **kwargs):
def __init__(self, *args, mask_type, data_channels, **kwargs):
super(MaskedConv2d, self).__init__(*args, **kwargs)

assert mask_type in ['A', 'B'], 'Invalid mask type.'
Expand All @@ -32,22 +32,8 @@ def __init__(self, *args, mask_type, data_channels, in_spread=True, out_spread=T
mask[:, :, yc, :xc + 1] = 1

def cmask(out_c, in_c):
if out_spread:
a = (np.arange(out_channels) % data_channels == out_c)[:, None]
else:
split = np.ceil(out_channels / 3)
lbound = out_c * split
ubound = (out_c + 1) * split
a = ((lbound <= np.arange(out_channels)) * (np.arange(out_channels) < ubound))[:, None]

if in_spread:
b = (np.arange(in_channels) % data_channels == in_c)[None, :]
else:
split = np.ceil(in_channels / 3)
lbound = in_c * split
ubound = (in_c + 1) * split
b = ((lbound <= np.arange(in_channels)) * (np.arange(in_channels) < ubound))[None, :]

a = (np.arange(out_channels) % data_channels == out_c)[:, None]
b = (np.arange(in_channels) % data_channels == in_c)[None, :]
return a * b

for o in range(data_channels):
Expand Down
2 changes: 1 addition & 1 deletion sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def main():

model.load_state_dict(torch.load(cfg.model_path))

samples = model.sample((cfg.data_channels, cfg.height, cfg.width), cfg.count)
samples = model.sample((cfg.data_channels, cfg.height, cfg.width), cfg.count, device=device)
save_samples(samples, OUTPUT_DIRNAME, OUTPUT_FILENAME)


Expand Down

0 comments on commit 8416548

Please sign in to comment.