
I am getting 'sldj' as 'nan' #179

@NikhilMank

I have initialised a model and started training it. As soon as training begins, sldj comes out as 'nan', which makes the NLL loss 'nan' as well. Can someone help me understand why this is happening? I don't know how the sldj values are calculated, so I am unable to identify the problem.
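
From what I can tell, each GLOW-style affine coupling block contributes the per-sample sum of its log scale factors to sldj, something like the sketch below (my own rough reconstruction of the idea, not FrEIA's actual code; the exact soft-clamping formula is an assumption):

import torch

def affine_coupling_forward(x1, x2, subnet, clamp=2.0):
    # subnet maps x1 to twice the channels of x2: scale logits s and shift t
    s, t = subnet(x1).chunk(2, dim=1)
    # soft-clamp the scale logits so exp(s) stays bounded; I assume FrEIA's
    # 'clamp' argument does something similar
    s = clamp * torch.tanh(s / clamp)
    y2 = x2 * torch.exp(s) + t
    # per-sample log-determinant: the sum of the log scales
    sldj = s.flatten(1).sum(dim=1)
    return y2, sldj

If that is right, any non-finite value coming out of the subnet would propagate straight into sldj.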

This is the code I have used:

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

import FrEIA.framework as Ff
import FrEIA.modules as Fm

class FlowModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_node = Ff.InputNode(2048, 160, 160)
        self.coupling_block_1 = Ff.Node(self.input_node, Fm.GLOWCouplingBlock, {'subnet_constructor': self.subnet_cnn, 'clamp': 2.0}, name='coupling_0')
        self.permutation_block_1 = Ff.Node(self.coupling_block_1, Fm.PermuteRandom, {'seed': 0}, name='permute_0')
        self.coupling_block_2 = Ff.Node(self.permutation_block_1, Fm.GLOWCouplingBlock, {'subnet_constructor': self.subnet_cnn, 'clamp': 2.0}, name='coupling_1')
        self.permutation_block_2 = Ff.Node(self.coupling_block_2, Fm.PermuteRandom, {'seed': 1}, name='permute_1')
        self.output_node = Ff.OutputNode(self.permutation_block_2, name='output')
        self.flow_model = Ff.GraphINN([self.input_node, self.coupling_block_1, self.permutation_block_1, self.coupling_block_2, self.permutation_block_2, self.output_node])
        
    def subnet_cnn(self, in_channels, out_channels):
        return flow_CNN(in_channels, out_channels)

    def forward(self, x):
        return self.flow_model(x)
    
    def NLLLoss(self, z, sldj):
        """Negative log-likelihood loss assuming isotropic gaussian with unit norm.
        Args:
             z (torch.Tensor): Output tensor from the flow model. The shape of the tensor is (batch_size, num_channels, height, width).
             sldj (torch.Tensor): Sum of log-determinants of the Jacobian matrix. The shape of the tensor is (batch_size,).
             
        Returns:
            nll (torch.Tensor): The mean negative log-likelihood. The shape of the tensor is ([]).
        """
        prior_ll = -0.5 * (z ** 2 + torch.log(2 * torch.tensor(float(np.pi))))
        prior_ll = prior_ll.flatten(1).sum(-1) - torch.log(torch.tensor(256.0)) * torch.prod(torch.tensor(z.size()[1:]))
        
        # Calculate the log-likelihood by adding the log Jacobian determinant
        ll = prior_ll + sldj
        
        # Calculate the mean negative log-likelihood loss
        nll = -ll.mean() / torch.prod(torch.tensor(z.size()[1:]))       # Taking average over the batch size and the dimensions of the tensor
        return nll

class flow_CNN(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 512, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(512, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 512, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(512, out_channels, kernel_size=3, padding=1)
        
        self.init_weights()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return x    
    
    def init_weights(self):
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
flow_model = FlowModel().to(device).half()
initialize_weights(flow_model)
optimizer = torch.optim.Adam(flow_model.parameters(), lr=1e-4)
# Training Code
losses = []
num_epochs = 50
iteration = 0
nll_loss = 0.0
for i in range(num_epochs):
    for inputs in tqdm(train_features_loader):
        if inputs is None:
            continue
        
        iteration += 1
        
        optimizer.zero_grad()
        inputs = inputs.view(-1, 2048, 160, 160).to(device)
        z, sldj = flow_model(inputs)
        
        if iteration==1:
            print(f"{z.max()=}\n{sldj.max()=}\n")
        nll_loss = flow_model.NLLLoss(z, sldj)
        if iteration==1:
            print(f"{nll_loss.item()=}\n")
            
        nll_loss.backward()
        optimizer.step()
        
        if iteration % 50 == 0:
            losses.append(nll_loss.item())
    print(f'Epoch [{i+1}/{num_epochs}], Loss: {nll_loss.item():.4f}')
# output:
z.max()=tensor(108.6250, device='cuda:0', dtype=torch.float16, grad_fn=<MaxBackward1>)
sldj.max()=tensor(nan, device='cuda:0', grad_fn=<MaxBackward1>)

nll_loss.item()=nan
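
To narrow down where the nan first appears, I can register forward hooks that flag any module producing a non-finite output (a quick diagnostic sketch of my own, not part of FrEIA):

import torch

def register_nan_hooks(model):
    # report every module whose forward output contains nan/inf
    def make_hook(name):
        def hook(module, inputs, output):
            outs = output if isinstance(output, tuple) else (output,)
            for o in outs:
                if torch.is_tensor(o) and not torch.isfinite(o).all():
                    print(f'non-finite output in {name} ({module.__class__.__name__})')
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))

Registering the hooks once before the training loop (register_nan_hooks(flow_model)) is enough; the first message printed should point at the block where the nan originates.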
