diff --git a/pixelcnn/conv_layers.py b/pixelcnn/conv_layers.py
index 1937352..de70681 100644
--- a/pixelcnn/conv_layers.py
+++ b/pixelcnn/conv_layers.py
@@ -19,7 +19,7 @@ def forward(self, x):


 class MaskedConv2d(nn.Conv2d):
-    def __init__(self, *args, mask_type, data_channels, in_spread=True, out_spread=True, **kwargs):
+    def __init__(self, *args, mask_type, data_channels, **kwargs):
         super(MaskedConv2d, self).__init__(*args, **kwargs)

         assert mask_type in ['A', 'B'], 'Invalid mask type.'
@@ -32,22 +32,8 @@ def __init__(self, *args, mask_type, data_channels, in_spread=True, out_spread=T
         mask[:, :, yc, :xc + 1] = 1

         def cmask(out_c, in_c):
-            if out_spread:
-                a = (np.arange(out_channels) % data_channels == out_c)[:, None]
-            else:
-                split = np.ceil(out_channels / 3)
-                lbound = out_c * split
-                ubound = (out_c + 1) * split
-                a = ((lbound <= np.arange(out_channels)) * (np.arange(out_channels) < ubound))[:, None]
-
-            if in_spread:
-                b = (np.arange(in_channels) % data_channels == in_c)[None, :]
-            else:
-                split = np.ceil(in_channels / 3)
-                lbound = in_c * split
-                ubound = (in_c + 1) * split
-                b = ((lbound <= np.arange(in_channels)) * (np.arange(in_channels) < ubound))[None, :]
-
+            a = (np.arange(out_channels) % data_channels == out_c)[:, None]
+            b = (np.arange(in_channels) % data_channels == in_c)[None, :]
             return a * b

         for o in range(data_channels):
diff --git a/sample.py b/sample.py
index ae6dd72..e53fca0 100644
--- a/sample.py
+++ b/sample.py
@@ -49,7 +49,7 @@ def main():

     model.load_state_dict(torch.load(cfg.model_path))

-    samples = model.sample((cfg.data_channels, cfg.height, cfg.width), cfg.count)
+    samples = model.sample((cfg.data_channels, cfg.height, cfg.width), cfg.count, device=device)

     save_samples(samples, OUTPUT_DIRNAME, OUTPUT_FILENAME)