-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtasks.py
63 lines (53 loc) · 1.83 KB
/
tasks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from functools import partial
from typing import Optional, Tuple

import torch

import dynamical_systems as dslib
# Registry mapping each task name to a pair of ints. For the flip-flop tasks
# both entries equal the channel count ``d`` bound by the matching
# ``flip_flopN`` partial below.
# NOTE(review): presumably (input_dim, output_dim) — TODO confirm; note that
# fit_ds always returns an input ``u`` whose last dimension is 1.
tasks = {
    **{f"flip_flop{width}": (width, width) for width in (1, 2, 3)},
    "double_well": (1, 1),
    "limit_cycle": (2, 2),
}
def flip_flop(d: int,
              timesteps: int = 1000,
              n: int = 5000,
              p: float = 0.2,
              ) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate (x, y) data for the d-channel flip-flop task.

    Each input entry is independently a pulse with probability ``p``. Each
    pulse is +1 or -1 with equal probability. On every channel the target
    holds the most recent nonzero input seen on that channel, and is 0
    before the first pulse.

    Args:
        d: Number of independent channels.
        timesteps: Sequence length.
        n: Number of sequences.
        p: Per-entry pulse probability; should lie in [0, 1].

    Returns:
        ``(x, y)``, two float tensors of shape ``(n, timesteps, d)``.
        ``x`` takes values in {-1, 0, 1}. ``y`` is ``x`` forward-filled along
        the time axis (dim 1).
    """
    # Pulse locations: 1 where an input arrives, 0 elsewhere.
    x = torch.bernoulli(torch.ones(n, timesteps, d) * p)
    # Flip the sign of about half of the pulses so x ends up in {-1, 0, 1}.
    x = x - 2 * torch.bernoulli(x * 0.5)
    y = torch.zeros(n, timesteps, d)
    # Forward-fill: keep the latest nonzero input per sequence and channel.
    cur = torch.zeros(n, d)
    for t in range(timesteps):
        cur = torch.where(x[:, t] != 0, x[:, t], cur)
        y[:, t] = cur
    return x, y
# Fixed-width flip-flop tasks with 1, 2 and 3 channels.
# NOTE: ``d`` is bound by keyword and is flip_flop's first positional
# parameter, so ``timesteps`` must also be passed by keyword here: a
# positional first argument would collide with ``d``.
flip_flop1, flip_flop2, flip_flop3 = (
    partial(flip_flop, d=width) for width in (1, 2, 3)
)
def fit_ds(timesteps: int,
           ds: str,
           n: int = 5000,
           clip_val: Optional[float] = None,
           ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Simulate ``n`` trajectories of a named dynamical system from random
    initial conditions.

    Args:
        timesteps: Number of steps passed to ``dslib.simulate_ds``. It is
            also the time length of the returned zero input ``u``.
        ds: Key into ``dslib.available_ds``. The ``dslib`` attribute of the
            same name is used as the dynamics function.
        n: Number of initial conditions to draw.
        clip_val: If given, the standard-normal initial conditions are
            clamped to ``[-clip_val, clip_val]``.

    Returns:
        ``(u, trajectories, x0s)``:
        ``u`` is an all-zero tensor of shape ``(n, timesteps, 1)``.
        ``trajectories`` is the output of ``dslib.simulate_ds``.
        ``x0s`` holds the ``(n, dims)`` initial conditions.
        Any trajectory containing a NaN is dropped together with its rows in
        ``u`` and ``x0s``, so the leading dimension may be smaller than ``n``.

    Raises:
        ValueError: If ``ds`` is not a key of ``dslib.available_ds``.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if ds not in dslib.available_ds:
        raise ValueError(
            f"Unknown dynamical system {ds!r}; expected one of "
            f"{sorted(dslib.available_ds)}"
        )
    dims = dslib.available_ds[ds]
    ds_fn = getattr(dslib, ds)
    x0s = torch.randn((n, dims))
    if clip_val is not None:
        x0s = torch.clamp(x0s, -clip_val, clip_val)
    # NOTE(review): assumes trajectories has one row per initial condition on
    # dim 0 — TODO confirm against dslib.simulate_ds.
    trajectories = dslib.simulate_ds(x0s, timesteps, ds_fn)
    u = torch.zeros((n, timesteps, 1))
    # One flag per trajectory: True if any entry (any step, any dim) is NaN.
    # Flattening the trailing dims makes this independent of trajectory rank.
    nan_rows = trajectories.isnan().flatten(start_dim=1).any(dim=1)
    if nan_rows.any():
        print(f"Warning: NaNs in {nan_rows.sum()} trajectories, removing them")
        keep = ~nan_rows
        x0s = x0s[keep]
        u = u[keep]
        trajectories = trajectories[keep]
    return u, trajectories, x0s
# Dynamical-system tasks: fit_ds with the system name bound by keyword.
# ``timesteps`` is fit_ds's first positional parameter, so these accept it
# positionally, e.g. ``double_well(1000)``.
double_well, limit_cycle = (
    partial(fit_ds, ds=system) for system in ("double_well", "limit_cycle")
)