-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathvmc.py
executable file
·334 lines (270 loc) · 9.03 KB
/
vmc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
#!/usr/bin/env python3
import time
from functools import partial
import jax
import netket as nk
import numpy as np
import optax
from netket.jax import dtype_real
from netket.optimizer import identity_preconditioner
from netket.optimizer.qgt import QGTOnTheFly
from args import args
from ham import HeisenbergTriangular, Triangular
from models import MPS, MPSRNN1D, MPSRNN2D, TensorRNN2D, TensorRNNCmpr2D
from models.symmetry import symmetrize_spins
from readers import (
convert_variables,
try_load_enlarge,
try_load_hierarchical,
try_load_itensors,
)
from sampler import MPSDirectSampler
from utils import init_out_dir, tree_size_real_nonzero, try_load_variables
def get_ham(*, _args=None):
    """Build the Hilbert space and Hamiltonian described by the configuration.

    Args:
        _args: Configuration namespace; the global ``args`` is used when this
            is falsy.

    Returns:
        ``(hilbert, H)``: a spin-1/2 Hilbert space with one site per graph
        node, and the Hamiltonian operator acting on it.

    Raises:
        ValueError: for an unrecognised ``boundary``, ``J`` or ``ham`` value.
        AssertionError: when the selected options are mutually inconsistent.
    """
    cfg = _args if _args else args

    # Boundary condition -> periodicity flag.
    if cfg.boundary not in ("open", "peri"):
        raise ValueError(f"Unknown boundary: {cfg.boundary}")
    pbc = cfg.boundary == "peri"

    # Lattice: "*tri" models use the custom triangular graph, all others a
    # hypercube of dimension ``ham_dim``.
    triangular = cfg.ham.endswith("tri")
    if not triangular:
        graph = nk.graph.Hypercube(length=cfg.L, n_dim=cfg.ham_dim, pbc=pbc)
    else:
        assert cfg.ham_dim == 2
        if pbc and cfg.sign == "mars":
            # The periodic "mars" sign rule requires an even linear size.
            assert cfg.L % 2 == 0
        graph = Triangular(cfg.L, pbc)
    hilbert = nk.hilbert.Spin(s=1 / 2, N=graph.n_nodes)

    # Coupling sign: "afm" -> +1, "fm" -> -1.
    if cfg.J not in ("afm", "fm"):
        raise ValueError(f"Unknown J: {cfg.J}")
    J = 1 if cfg.J == "afm" else -1

    if cfg.ham == "ising":
        assert cfg.sign == "none"
        return hilbert, nk.operator.IsingJax(
            hilbert=hilbert, graph=graph, J=J, h=cfg.h
        )

    if not cfg.ham.startswith("heis"):
        raise ValueError(f"Unknown ham: {cfg.ham}")

    # Heisenberg models take no external field.
    assert not cfg.h
    if triangular:
        H = HeisenbergTriangular(
            hilbert=hilbert, graph=graph, J=J, sign_rule=cfg.sign
        )
    else:
        # The stock NetKet operator only takes a boolean sign rule.
        H = nk.operator.Heisenberg(
            hilbert=hilbert,
            graph=graph,
            J=J,
            sign_rule=(cfg.sign == "mars"),
        )
    return hilbert, H
def get_net(hilbert, *, _args=None):
    """Instantiate the variational model selected by ``net`` / ``net_dim``.

    Every architecture receives the same keyword arguments: the Hilbert space
    plus the hyper-parameters listed below, read from the configuration. The
    pair ``(net, net_dim)`` only decides which class is constructed.

    Args:
        hilbert: Hilbert space the model is defined on.
        _args: Configuration namespace; the global ``args`` is used when this
            is falsy.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: for an unknown ``net``, or an unsupported ``net_dim`` for
            ``mps_rnn``.
        AssertionError: when ``net_dim`` does not match an architecture that
            only exists in one dimension.
    """
    cfg = _args if _args else args

    hyper_names = (
        "bond_dim",
        "zero_mag",
        "refl_sym",
        "affine",
        "no_phase",
        "no_w_phase",
        "cond_psi",
        "reorder_type",
        "reorder_dim",
        "dtype",
    )
    net_kwargs = {"hilbert": hilbert}
    net_kwargs.update((name, getattr(cfg, name)) for name in hyper_names)

    kind = cfg.net
    if kind == "mps_rnn":
        # The only architecture available in both 1D and 2D.
        dim = cfg.net_dim
        if dim == 1:
            net_cls = MPSRNN1D
        elif dim == 2:
            net_cls = MPSRNN2D
        else:
            raise ValueError(f"Unknown net_dim: {cfg.net_dim}")
    elif kind == "mps":
        assert cfg.net_dim == 1
        net_cls = MPS
    elif kind == "tensor_rnn":
        assert cfg.net_dim == 2
        net_cls = TensorRNN2D
    elif kind == "tensor_rnn_cmpr":
        assert cfg.net_dim == 2
        net_cls = TensorRNNCmpr2D
    else:
        raise ValueError(f"Unknown net: {cfg.net}")

    return net_cls(**net_kwargs)
def get_sampler(hilbert, *, _args=None):
    """Create the ``MPSDirectSampler`` for ``hilbert``.

    Args:
        hilbert: Hilbert space to sample from.
        _args: Configuration namespace; the global ``args`` is used when this
            is falsy.

    Returns:
        An ``MPSDirectSampler``. Its dtype is ``dtype_real`` of the configured
        model dtype. ``symmetrize_spins`` is passed as ``symmetrize_fun`` only
        when ``refl_sym`` is set.
    """
    cfg = _args if _args else args
    real_dtype = dtype_real(cfg.dtype)
    symmetrizer = symmetrize_spins if cfg.refl_sym else None
    return MPSDirectSampler(hilbert, dtype=real_dtype, symmetrize_fun=symmetrizer)
def get_vstate(sampler, model, variables, *, _args=None, n_samples=None):
    """Wrap ``sampler`` and ``model`` in a NetKet ``MCState``.

    Args:
        sampler: Sampler used to draw configurations.
        model: Variational model.
        variables: Initial model variables (may be ``None``).
        _args: Configuration namespace; the global ``args`` is used when this
            is falsy.
        n_samples: Number of samples; falls back to ``batch_size`` when falsy.

    Returns:
        The ``nk.vqs.MCState``, configured with the ``chunk_size`` and ``seed``
        from the configuration.
    """
    cfg = _args if _args else args
    sample_count = n_samples if n_samples else cfg.batch_size
    return nk.vqs.MCState(
        sampler,
        model,
        n_samples=sample_count,
        chunk_size=cfg.chunk_size,
        variables=variables,
        seed=cfg.seed,
    )
def get_optimizer(*, _args=None):
    """Build the optimizer and preconditioner for the configured run.

    Two families are supported:

    * ``rk12`` / ``rk23``: an adaptive Runge-Kutta integrator from
      ``netket.experimental.dynamics``. No preconditioner is used, so the
      second element of the result is ``None``.
    * anything else: an optax chain made of optional global-norm gradient
      clipping, Adam scaling (only for ``adam``), and a linearly warmed-up
      learning rate. The chain can be restricted to a subset of parameters
      (``train_only``) and/or split into real and imaginary parts
      (``split_complex``). It is paired with an SR preconditioner (``sr``) or
      the identity preconditioner.

    Args:
        _args: Configuration namespace; the global ``args`` is used when this
            is falsy.

    Returns:
        ``(optimizer, preconditioner)``. For the RK family, ``optimizer`` is
        the integrator and ``preconditioner`` is ``None``.

    Raises:
        ValueError: for an unknown ``rk*`` optimizer name.
        AssertionError: when options incompatible with the chosen optimizer
            are set.
    """
    if not _args:
        _args = args

    if _args.optimizer.startswith("rk"):
        # These options only apply to the optax path.
        assert not _args.split_complex
        assert not _args.grad_clip
        assert not _args.train_only
        # Validate the name before importing netket.experimental, so a typo
        # fails fast with a clear message.
        if _args.optimizer not in ("rk12", "rk23"):
            raise ValueError(f"Unknown optimizer: {_args.optimizer}")
        from netket import experimental as nkx

        if _args.optimizer == "rk12":
            Integrator = nkx.dynamics.RK12
        else:
            Integrator = nkx.dynamics.RK23
        integrator = Integrator(dt=_args.lr, adaptive=True, rtol=1e-3, atol=1e-3)
        # Return a 2-tuple like the optax path. Callers unpack
        # ``optimizer, preconditioner`` and expect ``None`` as the
        # preconditioner here; a 3-tuple would make that unpacking fail.
        return integrator, None

    # Clip gradients after preconditioner
    chain = []
    if _args.grad_clip:
        chain.append(optax.clip_by_global_norm(_args.grad_clip))
    if _args.optimizer == "adam":
        chain.append(optax.scale_by_adam())
    # Linear warm-up from 1e-6 to the target rate over the first 10% of steps.
    # Read from ``_args`` (not the global ``args``) so an explicitly passed
    # config is honoured. Keep at least one transition step: optax turns
    # transition_steps <= 0 into a constant schedule stuck at init_value.
    lr = optax.linear_schedule(
        init_value=1e-6,
        end_value=_args.lr,
        transition_steps=max(1, _args.max_step // 10),
    )
    chain.append(optax.scale_by_learning_rate(lr))
    optimizer = optax.chain(*chain)

    if _args.train_only:
        names = _args.train_only.split(",")
        # Label True -> apply the optimizer, label False -> zero the update.
        transforms = {True: optimizer, False: optax.set_to_zero()}

        def map_nested_fn(fn):
            """Map ``fn(key, leaf)`` over a nested dict, preserving its structure."""

            def map_fn(d):
                return {
                    k: map_fn(v) if isinstance(v, dict) else fn(k, v)
                    for k, v in d.items()
                }

            return map_fn

        @map_nested_fn
        def label_fn(k, v):
            # A leaf is trained iff its own key is one of the requested names.
            return k in names

        optimizer = optax.multi_transform(transforms, label_fn)

    if _args.split_complex:
        optimizer = optax.contrib.split_real_and_imaginary(optimizer)

    if _args.optimizer == "sr":
        solver = partial(jax.scipy.sparse.linalg.cg, tol=1e-7, atol=1e-7, maxiter=10)
        preconditioner = nk.optimizer.SR(
            qgt=QGTOnTheFly(), solver=solver, diag_shift=_args.diag_shift
        )
    else:
        # diag_shift only has a meaning for SR.
        assert not _args.diag_shift
        preconditioner = identity_preconditioner
    return optimizer, preconditioner
def get_vmc(H, vstate, optimizer, preconditioner, *, _args=None):
    """Build the driver that optimizes ``vstate`` against ``H``, plus a logger.

    For ``rk*`` optimizers the driver is an imaginary-time ``nkx.TDVP`` and
    ``optimizer`` is used as its integrator. Otherwise it is a plain
    ``nk.VMC`` with the given optimizer and preconditioner.

    Args:
        H: Hamiltonian operator.
        vstate: Variational state to optimize.
        optimizer: optax optimizer, or the integrator for ``rk*`` runs.
        preconditioner: Gradient preconditioner; must be ``None`` for ``rk*``.
        _args: Configuration namespace; the global ``args`` is used when this
            is falsy.

    Returns:
        ``(vmc, logger)``. ``logger`` is an ``nk.logging.JsonLog`` that
        writes to ``log_filename``.
    """
    if not _args:
        _args = args

    if _args.optimizer.startswith("rk"):
        assert preconditioner is None
        from netket import experimental as nkx

        solver = partial(jax.scipy.sparse.linalg.cg, tol=1e-7, atol=1e-7, maxiter=10)
        vmc = nkx.TDVP(
            H,
            variational_state=vstate,
            integrator=optimizer,
            propagation_type="imag",
            qgt=QGTOnTheFly(diag_shift=_args.diag_shift),
            linear_solver=solver,
            error_norm="qgt",
        )
    else:
        vmc = nk.VMC(
            H,
            variational_state=vstate,
            optimizer=optimizer,
            preconditioner=preconditioner,
        )

    # Log and checkpoint parameters about 100 times per run. Clamp to at least
    # 1: for max_step < 100 the integer division gives 0. JsonLog uses these
    # values as periods, and a period of 0 breaks its periodic checks.
    log_every = max(1, _args.max_step // 100)
    logger = nk.logging.JsonLog(
        _args.log_filename,
        "w",
        save_params_every=log_every,
        write_every=log_every,
    )
    return vmc, logger
def try_load_variables_init(model, *, _args=None):
    """Look for initial variables in the output directory.

    The candidate files are tried in a fixed order, each with its own loader.
    The first loader that returns something other than ``None`` wins, and its
    result is passed through ``convert_variables``.

    Args:
        model: Model passed to the loaders that need it.
        _args: Configuration namespace; the global ``args`` is used when this
            is falsy.

    Returns:
        The converted variables, or ``None`` if no candidate file was loaded.
    """
    cfg = _args if _args else args
    loaders = (
        ("init.mpack", try_load_variables),
        ("init_hi.mpack", try_load_hierarchical),
        ("init_el.mpack", try_load_enlarge),
        ("init.hdf5", try_load_itensors),
    )
    for basename, loader in loaders:
        # Plain string concatenation: full_out_dir presumably ends with a
        # path separator (TODO confirm against init_out_dir).
        path = cfg.full_out_dir + basename
        # The plain loader only takes the path; the others also need the
        # model and the configuration.
        if loader is try_load_variables:
            loaded = loader(path)
        else:
            loaded = loader(path, model, cfg)
        if loaded is None:
            continue
        print(f"Found {path}")
        return convert_variables(loaded, cfg)

    print(f"Variables not found in {cfg.full_out_dir}")
    return None
def try_load_variables_out(*, _args=None):
    """Load previously saved output variables from the output directory.

    ``out_ema.mpack`` is preferred over ``out.mpack``. The first file that
    ``try_load_variables`` loads is passed through ``convert_variables``.

    Args:
        _args: Configuration namespace; the global ``args`` is used when this
            is falsy.

    Returns:
        The converted variables, or ``None`` if neither file was loaded.
    """
    cfg = _args if _args else args
    for basename in ("out_ema.mpack", "out.mpack"):
        path = cfg.full_out_dir + basename
        loaded = try_load_variables(path)
        if loaded is None:
            continue
        print(f"Found {path}")
        return convert_variables(loaded, cfg)

    print(f"Variables not found in {cfg.full_out_dir}")
    return None
def main():
    """Run one optimization from the global ``args`` and report the energy.

    The steps are: set up the Hamiltonian, model, sampler and variational
    state; run the VMC (or Runge-Kutta) driver for ``max_step`` steps; then
    print the wall-clock time and an energy estimate computed with
    ``estim_size`` samples.
    """
    init_out_dir()
    print(args.log_filename)

    hilbert, H = get_ham()
    model = get_net(hilbert)
    variables = try_load_variables_init(model)
    sampler = get_sampler(hilbert)
    vstate = get_vstate(sampler, model, variables)
    print("n_params", tree_size_real_nonzero(vstate.parameters))

    optimizer, preconditioner = get_optimizer()
    vmc, logger = get_vmc(H, vstate, optimizer, preconditioner)

    print("start_time", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    t0 = time.time()
    if args.optimizer.startswith("rk"):
        # Integrate up to T = lr * max_step, stopping at max_step + 1 evenly
        # spaced times.
        t_max = args.lr * args.max_step
        run_kwargs = {
            "T": t_max,
            "tstops": np.linspace(0, t_max, args.max_step + 1),
        }
    else:
        run_kwargs = {"n_iter": args.max_step}
    vmc.run(out=logger, show_progress=args.show_progress, **run_kwargs)
    print("used_time", time.time() - t0)

    # Final energy estimate with its own sample budget.
    vstate.n_samples = args.estim_size
    print("energy", vstate.expect(H))
# Run the optimization only when executed as a script, not on import.
if __name__ == "__main__":
    main()