I initialised the model below and started training it, but sldj comes out as nan right from the first iteration, which makes the NLL loss nan as well. Can someone help me understand why this is happening? I am not sure how the sldj values are computed internally, so I have not been able to pin down the problem myself.
This is the code I am using.
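From the RealNVP/Glow papers, my rough understanding is that each affine coupling block should contribute the sum of its predicted log-scales to sldj, roughly like the sketch below, but I don't know how closely FrEIA's GLOWCouplingBlock follows this (the names here are made up for illustration):

    # Rough sketch of where sldj comes from in one affine coupling layer,
    # based on the RealNVP/Glow papers; FrEIA's internals may differ.
    import torch

    def affine_coupling_forward(x1, x2, subnet):
        # The subnet predicts a log-scale s and a shift t from one half of the input.
        s, t = subnet(x1).chunk(2, dim=1)
        y2 = x2 * torch.exp(s) + t
        # The log-determinant of this transform's Jacobian is just the sum of the
        # log-scales, so a single nan/inf in s poisons sldj for that batch entry.
        log_det = s.flatten(1).sum(dim=1)   # shape: (batch_size,)
        return x1, y2, log_det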
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

import FrEIA.framework as Ff
import FrEIA.modules as Fm


class FlowModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_node = Ff.InputNode(2048, 160, 160)
        self.coupling_block_1 = Ff.Node(self.input_node, Fm.GLOWCouplingBlock, {'subnet_constructor': self.subnet_cnn, 'clamp': 2.0}, name=f'coupling_{0}')
        self.permutation_block_1 = Ff.Node(self.coupling_block_1, Fm.PermuteRandom, {'seed': 0}, name=f'permute_{0}')
        self.coupling_block_2 = Ff.Node(self.permutation_block_1, Fm.GLOWCouplingBlock, {'subnet_constructor': self.subnet_cnn, 'clamp': 2.0}, name=f'coupling_{1}')
        self.permutation_block_2 = Ff.Node(self.coupling_block_2, Fm.PermuteRandom, {'seed': 1}, name=f'permute_{1}')
        self.output_node = Ff.OutputNode(self.permutation_block_2, name='output')
        self.flow_model = Ff.GraphINN([self.input_node, self.coupling_block_1, self.permutation_block_1, self.coupling_block_2, self.permutation_block_2, self.output_node])

    def subnet_cnn(self, in_channels, out_channels):
        return flow_CNN(in_channels, out_channels)

    def forward(self, x):
        return self.flow_model(x)

    def NLLLoss(self, z, sldj):
        """Negative log-likelihood loss assuming an isotropic Gaussian prior with unit variance.

        Args:
            z (torch.Tensor): Output of the flow model, shape (batch_size, num_channels, height, width).
            sldj (torch.Tensor): Sum of log-determinants of the Jacobian, shape (batch_size,).

        Returns:
            nll (torch.Tensor): Mean negative log-likelihood, a scalar.
        """
        prior_ll = -0.5 * (z ** 2 + torch.log(2 * torch.tensor(float(np.pi))))
        prior_ll = prior_ll.flatten(1).sum(-1) - torch.log(torch.tensor(256.0)) * torch.prod(torch.tensor(z.size()[1:]))
        # Add the sum of log Jacobian determinants to get the log-likelihood of x
        ll = prior_ll + sldj
        # Average over the batch and normalise by the number of dimensions
        nll = -ll.mean() / torch.prod(torch.tensor(z.size()[1:]))
        return nll
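# Quick sanity check of NLLLoss on its own (a hypothetical snippet, run in
# float32 on small tensors): with z ~ N(0, 1) and sldj = 0 the loss should come
# out near 0.5 * (1 + log(2 * pi)) + log(256) ≈ 6.96 per dimension.
#   z = torch.randn(4, 8, 16, 16)
#   sldj = torch.zeros(4)
#   print(FlowModel.NLLLoss(None, z, sldj).item())   # roughly 6.96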
class flow_CNN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 512, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(512, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 512, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(512, out_channels, kernel_size=3, padding=1)
        self.init_weights()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return x

    def init_weights(self):
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
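# Note: calling initialize_weights after construction re-initialises every Conv2d
# with Xavier uniform, overriding the std=0.02 init that flow_CNN does in __init__.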
flow_model = FlowModel().to(device).half()
initialize_weights(flow_model)
optimizer = torch.optim.Adam(flow_model.parameters(), lr=1e-4)

# Training code
losses = []
num_epochs = 50
iteration = 0
nll_loss = 0.0
for i in range(num_epochs):
    for inputs in tqdm(train_features_loader):
        if inputs is None:
            continue
        iteration += 1
        optimizer.zero_grad()
        inputs = inputs.view(-1, 2048, 160, 160).to(device)
        z, sldj = flow_model(inputs)
        if iteration == 1:
            print(f"{z.max()=}\n{sldj.max()=}\n")
        nll_loss = flow_model.NLLLoss(z, sldj)
        if iteration == 1:
            print(f"{nll_loss.item()=}\n")
        nll_loss.backward()
        optimizer.step()
        if iteration % 50 == 0:
            losses.append(nll_loss.item())
    print(f'Epoch [{i+1}/{num_epochs}], Loss: {nll_loss.item():.4f}')
# Output from the first iteration:
z.max()=tensor(108.6250, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>)
sldj.max()=tensor(nan, device='cuda:0', grad_fn=<MaxBackward1>)
nll_loss.item()=nan
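In case it helps narrow things down, I was thinking of registering forward hooks to find the first submodule whose output already contains a nan or inf, along these lines (a hypothetical snippet I have not run; it only uses standard PyTorch hooks, nothing FrEIA-specific):

    import torch

    def find_nonfinite_modules(model, x):
        # Register a forward hook on every submodule and record the names of the
        # modules whose output already contains a nan or inf.
        hits = []

        def tensors(obj):
            # Walk arbitrarily nested tuples/lists and yield the tensors inside.
            if torch.is_tensor(obj):
                yield obj
            elif isinstance(obj, (tuple, list)):
                for item in obj:
                    yield from tensors(item)

        def make_hook(name):
            def hook(module, inputs, output):
                if any(not torch.isfinite(t).all() for t in tensors(output)):
                    hits.append(name)
            return hook

        handles = [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]
        try:
            with torch.no_grad():
                model(x)
        finally:
            for h in handles:
                h.remove()
        return hits

Does that look like a reasonable way to locate where sldj first becomes nan, or is there a better way to inspect the per-block Jacobian terms?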