Skip to content

Commit

Permalink
removed data channels from optional parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
anordertoreclaim committed Aug 10, 2019
1 parent be3a306 commit 62f81e4
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 14 deletions.
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ optional arguments:
Kernel size of causal convolution
--hidden-ksize HIDDEN_KSIZE
Kernel size of hidden layers convolutions
--data-channels DATA_CHANNELS
Number of data channels
--color-levels COLOR_LEVELS
Number of levels to quantisize value of each channel
of each pixel into
Expand Down Expand Up @@ -127,8 +125,6 @@ optional arguments:
Kernel size of causal convolution
--hidden-ksize HIDDEN_KSIZE
Kernel size of hidden layers convolutions
--data-channels DATA_CHANNELS
Number of data channels
--color-levels COLOR_LEVELS
Number of levels to quantisize value of each channel
of each pixel into
Expand Down
14 changes: 8 additions & 6 deletions pixelcnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,18 @@ class PixelCNN(nn.Module):
def __init__(self, cfg):
super(PixelCNN, self).__init__()

DATA_CHANNELS = 3

self.hidden_fmaps = cfg.hidden_fmaps
self.color_levels = cfg.color_levels

self.causal_conv = CausalBlock(cfg.data_channels,
self.causal_conv = CausalBlock(DATA_CHANNELS,
cfg.hidden_fmaps,
cfg.causal_ksize,
data_channels=cfg.data_channels)
data_channels=DATA_CHANNELS)

self.hidden_conv = nn.Sequential(
*[GatedBlock(cfg.hidden_fmaps, cfg.hidden_fmaps, cfg.hidden_ksize, cfg.data_channels) for _ in range(cfg.hidden_layers)]
*[GatedBlock(cfg.hidden_fmaps, cfg.hidden_fmaps, cfg.hidden_ksize, DATA_CHANNELS) for _ in range(cfg.hidden_layers)]
)

self.label_embedding = nn.Embedding(10, self.hidden_fmaps)
Expand All @@ -138,13 +140,13 @@ def __init__(self, cfg):
cfg.out_hidden_fmaps,
(1, 1),
mask_type='B',
data_channels=cfg.data_channels)
data_channels=DATA_CHANNELS)

self.out_conv = MaskedConv2d(cfg.out_hidden_fmaps,
cfg.data_channels * cfg.color_levels,
DATA_CHANNELS * cfg.color_levels,
(1, 1),
mask_type='B',
data_channels=cfg.data_channels)
data_channels=DATA_CHANNELS)

def forward(self, image, label):
count, data_channels, height, width = image.size()
Expand Down
2 changes: 0 additions & 2 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ def main():
parser.add_argument('--hidden-ksize', type=int, default=7,
help='Kernel size of hidden layers convolutions')

parser.add_argument('--data-channels', type=int, default=3,
help='Number of data channels')
parser.add_argument('--color-levels', type=int, default=2,
help='Number of levels to quantisize value of each channel of each pixel into')

Expand Down
2 changes: 0 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ def main():
parser.add_argument('--hidden-ksize', type=int, default=7,
help='Kernel size of hidden layers convolutions')

parser.add_argument('--data-channels', type=int, default=3,
help='Number of data channels')
parser.add_argument('--color-levels', type=int, default=2,
help='Number of levels to quantisize value of each channel of each pixel into')

Expand Down

0 comments on commit 62f81e4

Please sign in to comment.