
Commit

fixed issue with batching, should work properly now
AlexYFM committed Sep 18, 2024
1 parent f145b2e commit 069c5b2
Showing 3 changed files with 12 additions and 86 deletions.
Binary file removed verse/stars/model_weights.pth
72 changes: 0 additions & 72 deletions verse/stars/nn_results.csv

This file was deleted.

26 changes: 12 additions & 14 deletions verse/stars/star_nn.py
@@ -108,10 +108,10 @@ def he_init(m):
 optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
 scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
 
-num_epochs = 50 # sample number of epoch -- can play with this/set this as a hyperparameter
-num_samples = 100 # number of samples per time step
-lamb = 1
-batch_size = 1 # number of times computed at once, lower should be better but slower
+num_epochs: int = 50 # sample number of epoch -- can play with this/set this as a hyperparameter
+num_samples: int = 100 # number of samples per time step
+lamb: float = 1
+batch_size: int = 5 # number of times computed at once, lower should be better but slower, min of 1
 
 T = 14
 ts = 0.2
@@ -190,15 +190,13 @@ def containment(points: torch.Tensor, times: torch.Tensor, bases: List[torch.Ten
 post_points = torch.tensor(post_points).to(device)
 
 for i in range(len(times)//batch_size+1):
-    batch_times: torch.Tensor
-    if i==len(times)/batch_size:
-        batch_times = times[i*batch_size:]
-    else:
-        batch_times = times[i*batch_size:(i+1)*batch_size]
+    start = i*batch_size
+    end = (i+1)*batch_size
 
     ### very naive way to do this, probably would want more or less equal batch sizes if not dividing equally
 
-    mu = model(times.unsqueeze(1)) # get times in right form
-    loss = (torch.log(1+torch.sum(mu))+lamb*torch.sum(containment(post_points[:, :, 1:], times, bases, centers))/num_samples)
+    mu = model(times[start:end].unsqueeze(1)) # get times in right form
+    loss = torch.sum(mu)+lamb*torch.sum(containment(post_points[:, start:end, 1:], times[start:end], bases[start:end], centers[start:end]))/num_samples
     loss.backward()
     optimizer.step()
@@ -213,12 +213,12 @@ def containment(points: torch.Tensor, times: torch.Tensor, bases: List[torch.Ten
     t = torch.tensor([times[i]], dtype=torch.float32).to(device)
     mu = model(t)
     cont = lambda p, i: torch.linalg.vector_norm(torch.relu(C@torch.linalg.inv(bases[i])@(p-centers[i])-mu*g))
-    loss = torch.log(1+mu) + lamb*torch.sum(torch.stack([cont(point, i) for point in post_points[:, i, 1:]]))/len(post_points[:,i,1:])
+    loss = mu + lamb*torch.sum(torch.stack([cont(point, i) for point in post_points[:, i, 1:]]))/len(post_points[:,i,1:])
     # loss = (1-lamb)*mu + lamb*torch.sum(torch.stack([cont(point, i) for point in post_points[:, i, 1:]]))/len(post_points[:,i,1:])
     print(f'loss: {loss.item():.4f}, mu: {mu.item():.4f}, time: {t.item():.1f}')
     losses += loss.item()
 mu = model(times.unsqueeze(1)) # get times in right form
-other_loss = (torch.log(1+torch.sum(mu))+lamb*torch.sum(containment(post_points[:, :, 1:], times, bases, centers)))/(num_samples)
+other_loss = torch.sum(mu)+lamb*torch.sum(containment(post_points[:, :, 1:], times, bases, centers))/(num_samples)
 print(f'Losses: {losses:.4f}, ..., other loss {other_loss:.5f}')
 
 # test the new model
@@ -261,7 +261,7 @@ def containment(points: torch.Tensor, times: torch.Tensor, bases: List[torch.Ten
 
 for i in range(len(times)):
     # mu, center = model(test[i])[0].detach().numpy(), model(test[i])[1:].detach().numpy()
-    stars.append(StarSet(centers[i], bases[i], C.numpy(), model(test[i]).detach().numpy()*g.numpy()))
+    stars.append(StarSet(centers[i], bases[i], C.numpy(), torch.relu(model(test[i])).detach().numpy()*g.numpy()))
     points = torch.tensor(post_points[:, i, 1:])
     contain = torch.sum(torch.stack([cont(point, i) == 0 for point in points]))
     percent_contained.append(contain/(num_samples*10)*100)
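The cont lambda used for evaluation measures how far a sampled point falls outside a star set: it maps the point into the star's basis coordinates and takes the norm of the positive part of C @ alpha - mu * g, so a result of zero means the point satisfies every constraint. The toy check below illustrates this for a single axis-aligned star; the center, basis, C, g, and mu values are made up for the example.

# Illustrative containment check for one star set, mirroring the cont lambda above;
# all constants here are invented for the example.
import torch

center = torch.tensor([1.0, 2.0])
basis = torch.eye(2)                         # identity basis: star coordinates are offsets from the center
C = torch.tensor([[ 1.0,  0.0],
                  [-1.0,  0.0],
                  [ 0.0,  1.0],
                  [ 0.0, -1.0]])             # box constraints C @ alpha <= mu * g
g = torch.ones(4)
mu = torch.tensor(0.5)                       # the network's predicted scaling of g

def violation(p: torch.Tensor) -> torch.Tensor:
    alpha = torch.linalg.inv(basis) @ (p - center)           # point expressed in star coordinates
    return torch.linalg.vector_norm(torch.relu(C @ alpha - mu * g))

print(violation(torch.tensor([1.2, 2.3])))   # tensor(0.)     -> inside the scaled star
print(violation(torch.tensor([2.0, 2.0])))   # tensor(0.5000) -> outside; the positive part measures the overshoot

During training, the sum of these violations over the sampled post points (weighted by lamb and divided by num_samples) pushes mu up until the samples are contained, while the sum of mu in the loss keeps the predicted stars from growing more than necessary; the torch.relu added around model(test[i]) in the last hunk keeps the scaling non-negative when the learned stars are finally constructed.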
