diff --git a/README.md b/README.md index 2837b3e..803f29a 100644 --- a/README.md +++ b/README.md @@ -82,8 +82,6 @@ optional arguments: Kernel size of causal convolution --hidden-ksize HIDDEN_KSIZE Kernel size of hidden layers convolutions - --data-channels DATA_CHANNELS - Number of data channels --color-levels COLOR_LEVELS Number of levels to quantisize value of each channel of each pixel into @@ -127,8 +125,6 @@ optional arguments: Kernel size of causal convolution --hidden-ksize HIDDEN_KSIZE Kernel size of hidden layers convolutions - --data-channels DATA_CHANNELS - Number of data channels --color-levels COLOR_LEVELS Number of levels to quantisize value of each channel of each pixel into diff --git a/pixelcnn/model.py b/pixelcnn/model.py index 1ff8eda..c41e42a 100644 --- a/pixelcnn/model.py +++ b/pixelcnn/model.py @@ -120,16 +120,18 @@ class PixelCNN(nn.Module): def __init__(self, cfg): super(PixelCNN, self).__init__() + DATA_CHANNELS = 3 + self.hidden_fmaps = cfg.hidden_fmaps self.color_levels = cfg.color_levels - self.causal_conv = CausalBlock(cfg.data_channels, + self.causal_conv = CausalBlock(DATA_CHANNELS, cfg.hidden_fmaps, cfg.causal_ksize, - data_channels=cfg.data_channels) + data_channels=DATA_CHANNELS) self.hidden_conv = nn.Sequential( - *[GatedBlock(cfg.hidden_fmaps, cfg.hidden_fmaps, cfg.hidden_ksize, cfg.data_channels) for _ in range(cfg.hidden_layers)] + *[GatedBlock(cfg.hidden_fmaps, cfg.hidden_fmaps, cfg.hidden_ksize, DATA_CHANNELS) for _ in range(cfg.hidden_layers)] ) self.label_embedding = nn.Embedding(10, self.hidden_fmaps) @@ -138,13 +140,13 @@ def __init__(self, cfg): cfg.out_hidden_fmaps, (1, 1), mask_type='B', - data_channels=cfg.data_channels) + data_channels=DATA_CHANNELS) self.out_conv = MaskedConv2d(cfg.out_hidden_fmaps, - cfg.data_channels * cfg.color_levels, + DATA_CHANNELS * cfg.color_levels, (1, 1), mask_type='B', - data_channels=cfg.data_channels) + data_channels=DATA_CHANNELS) def forward(self, image, label): count, data_channels, height, width = image.size() diff --git a/sample.py b/sample.py index d3425dd..08a45d6 100644 --- a/sample.py +++ b/sample.py @@ -16,8 +16,6 @@ def main(): parser.add_argument('--hidden-ksize', type=int, default=7, help='Kernel size of hidden layers convolutions') - parser.add_argument('--data-channels', type=int, default=3, - help='Number of data channels') parser.add_argument('--color-levels', type=int, default=2, help='Number of levels to quantisize value of each channel of each pixel into') diff --git a/train.py b/train.py index 763856b..7205647 100644 --- a/train.py +++ b/train.py @@ -87,8 +87,6 @@ def main(): parser.add_argument('--hidden-ksize', type=int, default=7, help='Kernel size of hidden layers convolutions') - parser.add_argument('--data-channels', type=int, default=3, - help='Number of data channels') parser.add_argument('--color-levels', type=int, default=2, help='Number of levels to quantisize value of each channel of each pixel into')