diff --git a/utils.py b/utils.py index 54ab8357..422680f3 100755 --- a/utils.py +++ b/utils.py @@ -46,7 +46,7 @@ def log(self, losses=None, images=None): sys.stdout.write('%s: %.4f | ' % (loss_name, self.losses[loss_name]/self.batch)) batches_done = self.batches_epoch*(self.epoch - 1) + self.batch - batches_left = self.batches_epoch*(self.n_epochs - self.epoch) + self.batches_epoch - self.batch + batches_left = self.batches_epoch*(self.n_epochs - self.epoch) + self.batches_epoch - self.batch sys.stdout.write('ETA: %s' % (datetime.timedelta(seconds=batches_left*self.mean_period/batches_done))) # Draw images @@ -61,7 +61,7 @@ def log(self, losses=None, images=None): # Plot losses for loss_name, loss in self.losses.items(): if loss_name not in self.loss_windows: - self.loss_windows[loss_name] = self.viz.line(X=np.array([self.epoch]), Y=np.array([loss/self.batch]), + self.loss_windows[loss_name] = self.viz.line(X=np.array([self.epoch]), Y=np.array([loss/self.batch]), opts={'xlabel': 'epochs', 'ylabel': loss_name, 'title': loss_name}) else: self.viz.line(X=np.array([self.epoch]), Y=np.array([loss/self.batch]), win=self.loss_windows[loss_name], update='append') @@ -74,7 +74,7 @@ def log(self, losses=None, images=None): else: self.batch += 1 - + class ReplayBuffer(): def __init__(self, max_size=50): @@ -111,8 +111,7 @@ def step(self, epoch): def weights_init_normal(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: - torch.nn.init.normal(m.weight.data, 0.0, 0.02) + torch.nn.init.normal_(m.weight.data, 0.0, 0.02) elif classname.find('BatchNorm2d') != -1: - torch.nn.init.normal(m.weight.data, 1.0, 0.02) + torch.nn.init.normal_(m.weight.data, 1.0, 0.02) torch.nn.init.constant(m.bias.data, 0.0) -