From ba7f7109c38c3805800283cdb9d79cd7c4a3294f Mon Sep 17 00:00:00 2001 From: yiwang Date: Wed, 8 Jan 2020 22:39:39 +0800 Subject: [PATCH] Update net.py --- pytorch/model/net.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch/model/net.py b/pytorch/model/net.py index 1173d70..83435b7 100644 --- a/pytorch/model/net.py +++ b/pytorch/model/net.py @@ -256,14 +256,14 @@ def __init__(self, in_channels, act=F.elu, norm=None, opt=None): if self.opt.pretrain_network is False: if self.opt.mask_type == 'rect': self.netD = GlobalLocalDiscriminator(3, cnum=opt.d_cnum, act=act, - g_fc_channels=16 * 16 * opt.d_cnum * 4, - l_fc_channels=8 * 8 * opt.d_cnum * 4, + g_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4, + l_fc_channels=opt.mask_shapes[0]//16*opt.mask_shapes[1]//16*opt.d_cnum*4, spectral_norm=self.opt.spectral_norm).cuda() else: self.netD = GlobalLocalDiscriminator(3, cnum=opt.d_cnum, act=act, spectral_norm=self.opt.spectral_norm, - g_fc_channels=16 * 16 * opt.d_cnum * 4, - l_fc_channels=16 * 16 * opt.d_cnum * 4).cuda() + g_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4, + l_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4).cuda() init_weights(self.netD) self.optimizer_D = torch.optim.Adam(filter(lambda x: x.requires_grad, self.netD.parameters()), lr=opt.lr, betas=(0.5, 0.9))