Skip to content

Commit

Permalink
last changes before fork
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexYFM committed Sep 18, 2024
1 parent 069c5b2 commit 8ee3d94
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions verse/stars/star_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def he_init(m):

num_epochs: int = 50 # sample number of epoch -- can play with this/set this as a hyperparameter
num_samples: int = 100 # number of samples per time step
lamb: float = 1
lamb: float = 5
batch_size: int = 5 # number of times computed at once, lower should be better but slower, min of 1

T = 14
Expand Down Expand Up @@ -196,7 +196,7 @@ def containment(points: torch.Tensor, times: torch.Tensor, bases: List[torch.Ten
### very naive way to do this, probably would want more or less equal batch sizes if not dividing equally

mu = model(times[start:end].unsqueeze(1)) # get times in right form
loss = torch.sum(mu)+lamb*torch.sum(containment(post_points[:, start:end, 1:], times[start:end], bases[start:end], centers[start:end]))/num_samples
loss = torch.log1p(torch.sum(mu))+lamb*torch.sum(containment(post_points[:, start:end, 1:], times[start:end], bases[start:end], centers[start:end]))/num_samples
loss.backward()
optimizer.step()

Expand Down

0 comments on commit 8ee3d94

Please sign in to comment.