Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minor update #47

Merged
merged 11 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ disable=abstract-method,
arguments-out-of-order,
consider-using-in,
invalid-unary-operand-type,
unnecessary-lambda-assignment,


[REPORTS]
Expand Down
133 changes: 0 additions & 133 deletions pyscfad/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
@@ -1,133 +0,0 @@
from functools import partial
import numpy
import scipy
import scipy.linalg
from jax import numpy as np
from jax import scipy as jax_scipy
from pyscfad.ops import custom_jvp, jit

# default threshold for degenerate eigenvalues
DEG_THRESH = 1e-9

# pylint: disable = redefined-builtin
def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False,
overwrite_b=False, turbo=True, eigvals=None, type=1,
check_finite=True, subset_by_index=None, subset_by_value=None,
driver=None, deg_thresh=DEG_THRESH):
if overwrite_a is True or overwrite_b is True:
raise NotImplementedError('Overwritting a or b is not implemeneted.')
if type != 1:
raise NotImplementedError('Only the type=1 case of eigh is implemented.')
if not(eigvals is None and subset_by_index is None and subset_by_value is None):
raise NotImplementedError('Subset of eigen values is not implemented.')

a = 0.5 * (a + a.T.conj())
if b is not None:
b = 0.5 * (b + b.T.conj())

w, v = _eigh(a, b, deg_thresh=deg_thresh)

if eigvals_only:
return w
else:
return w, v

@partial(custom_jvp, nondiff_argnums=(2,))
def _eigh(a, b, deg_thresh=DEG_THRESH):
w, v = scipy.linalg.eigh(a, b=b)
w = np.asarray(w, dtype=float)
return w, v

@_eigh.defjvp
def _eigh_jvp(deg_thresh, primals, tangents):
a, b = primals
at, bt = tangents
w, v = _eigh(a, b, deg_thresh)

eji = w[None, :] - w[:, None]
idx = numpy.asarray(abs(eji) <= deg_thresh, dtype=bool)
eji = eji.at[idx].set(1e200)
eji = eji.at[numpy.diag_indices_from(eji)].set(1)
Fmat = 1 / eji - numpy.eye(a.shape[-1])
if b is None:
dw, dv = _eigh_jvp_jitted_nob(v, Fmat, at)
else:
bmask = numpy.zeros(a.shape)
bmask[idx] = 1
dw, dv = _eigh_jvp_jitted(w, v, Fmat, at, bt, bmask)
return (w, v), (dw, dv)

@jit
def _eigh_jvp_jitted(w, v, Fmat, at, bt, bmask):
vt_at_v = np.dot(v.conj().T, np.dot(at, v))
vt_bt_v = np.dot(v.conj().T, np.dot(bt, v))
vt_bt_v_w = np.dot(vt_bt_v, np.diag(w))
da_minus_ds = vt_at_v - vt_bt_v_w
dw = np.diag(da_minus_ds).real

dv = np.dot(v, np.multiply(Fmat, da_minus_ds) - np.multiply(bmask, vt_bt_v) * .5)
return dw, dv

@jit
def _eigh_jvp_jitted_nob(v, Fmat, at):
vt_at_v = np.dot(v.conj().T, np.dot(at, v))
dw = np.diag(vt_at_v).real
dv = np.dot(v, np.multiply(Fmat, vt_at_v))
return dw, dv


def svd(a, full_matrices=True, compute_uv=True,
overwrite_a=False, check_finite=True,
lapack_driver='gesdd'):
if not full_matrices or not compute_uv:
return jax_scipy.linalg.svd(a,
full_matrices=full_matrices,
compute_uv=compute_uv)
else:
return _svd(a)

@custom_jvp
def _svd(a):
return jax_scipy.linalg.svd(a)

@_svd.defjvp
def _svd_jvp(primals, tangents):
A, = primals
dA, = tangents
if np.iscomplexobj(A):
raise NotImplementedError

m, n = A.shape
if m > n:
raise NotImplementedError('Use svd(A.conj().T) instead.')

U, s, Vt = _svd(A)
Ut = U.conj().T
V = Vt.conj().T
s_dim = s[None, :]

dS = Ut @ dA @ V
ds = np.diagonal(dS, 0, -2, -1).real

s_diffs = (s_dim + s_dim.T) * (s_dim - s_dim.T)
s_diffs_zeros = (s_diffs == 0).astype(s_diffs.dtype)
F = 1. / (s_diffs + s_diffs_zeros) - s_diffs_zeros

dP1 = dS[:,:m]
dP2 = dS[:,m:]
dSS = dP1 * s_dim
SdS = s_dim.T * dP1

dU = U @ (F * (dSS + dSS.conj().T))
dD1 = F * (SdS + SdS.conj().T)

s_zeros = (s == 0).astype(s.dtype)
s_inv = 1. / (s + s_zeros) - s_zeros
dD2 = s_inv[:,None] * dP2

dV = np.zeros_like(V)
dV = dV.at[:m,:m].set(dD1)
dV = dV.at[:m,m:].set(-dD2)
dV = dV.at[m:,:m].set(dD2.conj().T)
dV = V @ dV
return (U, s, Vt), (dU, ds, dV.conj().T)
37 changes: 19 additions & 18 deletions pyscfad/fci/fci_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pyscf.fci import cistring
from pyscfad import numpy as np
from pyscfad import ops
from pyscfad.ops import vmap, stop_grad
from pyscfad.ops import vmap, to_numpy
from pyscfad.lib.linalg_helper import davidson
from pyscfad.gto import mole
from pyscfad import ao2mo
Expand Down Expand Up @@ -82,16 +82,12 @@ def body(mo_ia, mo_ib, ida, idb):
val = np.linalg.det(sij_a) * np.linalg.det(sij_b)
return val

res = 0.
res = 0
for ia in range(na1):
mo_ia = mo_a1[:,locs_a1[ia]]
for ib in range(nb1):
mo_ib = mo_b1[:,locs_b1[ib]]
val = vmap(body, (None,None,0,0), signature='(i),(j)->()')(mo_ia, mo_ib, idxa, idxb)
#val = []
#for i in range(len(idxa)):
# val.append(body(mo_ia, mo_ib, idxa[i], idxb[i]))
#val = np.asarray(val)
res += ci1[ia,ib] * (val * ci2.ravel()).sum()
return res

Expand Down Expand Up @@ -126,16 +122,19 @@ def contract_2e(eri, fcivec, norb, nelec, opt=None):
fcinew = ops.index_add(fcinew, ops.index[:,str1], sign * t1[a,i,:,str0])
return fcinew.reshape(fcivec.shape)


def absorb_h1e(h1e, eri, norb, nelec, fac=1):
if not isinstance(nelec, (int, np.integer)):
nelec = sum(nelec)
nelec = np.sum(nelec)
assert nelec > 0

if eri.size != norb**4:
h2e = ao2mo.restore(1, eri.copy(), norb)
h2e = ao2mo.restore(1, eri, norb)
else:
h2e = eri.copy().reshape(norb,norb,norb,norb)
f1e = h1e - np.einsum('jiik->jk', h2e) * .5
f1e = f1e * (1./(nelec+1e-100))
h2e = eri.reshape([norb,]*4)

f1e = h1e - np.einsum('jiik->jk', h2e) * .5
f1e *= 1. / nelec

for k in range(norb):
h2e = ops.index_add(h2e, ops.index[k,k,:,:], f1e)
h2e = ops.index_add(h2e, ops.index[:,:,k,k], f1e)
Expand All @@ -153,7 +152,7 @@ def make_hdiag(h1e, eri, norb, nelec, opt=None):
if eri.size != norb**4:
eri = ao2mo.restore(1, eri, norb)
else:
eri = eri.reshape(norb,norb,norb,norb)
eri = eri.reshape([norb,]*4)
diagj = np.einsum('iijj->ij', eri)
diagk = np.einsum('ijji->ij', eri)
hdiag = []
Expand All @@ -173,10 +172,10 @@ def kernel(h1e, eri, norb, nelec, ecore=0, nroots=1):
hdiag = make_hdiag(h1e, eri, norb, nelec)
try:
from pyscf.fci.direct_spin1 import pspace
addrs, h0 = pspace(stop_grad(h1e), stop_grad(eri),
norb, nelec, stop_grad(hdiag), nroots)
# pylint: disable=bare-except
except:
addrs, _ = pspace(to_numpy(h1e), to_numpy(eri),
norb, nelec, to_numpy(hdiag), nroots)
# pylint: disable=broad-exception-caught
except Exception:
addrs = numpy.argsort(hdiag)[:nroots]
ci0 = []
for addr in addrs:
Expand All @@ -187,7 +186,9 @@ def kernel(h1e, eri, norb, nelec, ecore=0, nroots=1):
def hop(c):
hc = contract_2e(h2e, c, norb, nelec)
return hc.ravel()
# pylint: disable=unnecessary-lambda-assignment

precond = lambda x, e, *args: x/(hdiag-e+1e-4)

e, c = davidson(hop, ci0, precond, nroots=nroots)
return e+ecore, c

Loading