diff --git a/pytorch/model/net.py b/pytorch/model/net.py index 1173d70..83435b7 100644 --- a/pytorch/model/net.py +++ b/pytorch/model/net.py @@ -256,14 +256,14 @@ def __init__(self, in_channels, act=F.elu, norm=None, opt=None): if self.opt.pretrain_network is False: if self.opt.mask_type == 'rect': self.netD = GlobalLocalDiscriminator(3, cnum=opt.d_cnum, act=act, - g_fc_channels=16 * 16 * opt.d_cnum * 4, - l_fc_channels=8 * 8 * opt.d_cnum * 4, + g_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4, + l_fc_channels=opt.mask_shapes[0]//16*opt.mask_shapes[1]//16*opt.d_cnum*4, spectral_norm=self.opt.spectral_norm).cuda() else: self.netD = GlobalLocalDiscriminator(3, cnum=opt.d_cnum, act=act, spectral_norm=self.opt.spectral_norm, - g_fc_channels=16 * 16 * opt.d_cnum * 4, - l_fc_channels=16 * 16 * opt.d_cnum * 4).cuda() + g_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4, + l_fc_channels=opt.img_shapes[0]//16*opt.img_shapes[1]//16*opt.d_cnum*4).cuda() init_weights(self.netD) self.optimizer_D = torch.optim.Adam(filter(lambda x: x.requires_grad, self.netD.parameters()), lr=opt.lr, betas=(0.5, 0.9))