diff --git a/.pylintrc b/.pylintrc index 2de27ccb..1d386bde 100644 --- a/.pylintrc +++ b/.pylintrc @@ -162,6 +162,7 @@ disable=abstract-method, arguments-out-of-order, consider-using-in, invalid-unary-operand-type, + unnecessary-lambda-assignment, [REPORTS] diff --git a/pyscfad/_src/scipy/linalg.py b/pyscfad/_src/scipy/linalg.py index 51e970b3..e69de29b 100644 --- a/pyscfad/_src/scipy/linalg.py +++ b/pyscfad/_src/scipy/linalg.py @@ -1,133 +0,0 @@ -from functools import partial -import numpy -import scipy -import scipy.linalg -from jax import numpy as np -from jax import scipy as jax_scipy -from pyscfad.ops import custom_jvp, jit - -# default threshold for degenerate eigenvalues -DEG_THRESH = 1e-9 - -# pylint: disable = redefined-builtin -def eigh(a, b=None, lower=True, eigvals_only=False, overwrite_a=False, - overwrite_b=False, turbo=True, eigvals=None, type=1, - check_finite=True, subset_by_index=None, subset_by_value=None, - driver=None, deg_thresh=DEG_THRESH): - if overwrite_a is True or overwrite_b is True: - raise NotImplementedError('Overwritting a or b is not implemeneted.') - if type != 1: - raise NotImplementedError('Only the type=1 case of eigh is implemented.') - if not(eigvals is None and subset_by_index is None and subset_by_value is None): - raise NotImplementedError('Subset of eigen values is not implemented.') - - a = 0.5 * (a + a.T.conj()) - if b is not None: - b = 0.5 * (b + b.T.conj()) - - w, v = _eigh(a, b, deg_thresh=deg_thresh) - - if eigvals_only: - return w - else: - return w, v - -@partial(custom_jvp, nondiff_argnums=(2,)) -def _eigh(a, b, deg_thresh=DEG_THRESH): - w, v = scipy.linalg.eigh(a, b=b) - w = np.asarray(w, dtype=float) - return w, v - -@_eigh.defjvp -def _eigh_jvp(deg_thresh, primals, tangents): - a, b = primals - at, bt = tangents - w, v = _eigh(a, b, deg_thresh) - - eji = w[None, :] - w[:, None] - idx = numpy.asarray(abs(eji) <= deg_thresh, dtype=bool) - eji = eji.at[idx].set(1e200) - eji = eji.at[numpy.diag_indices_from(eji)].set(1) - Fmat = 1 / eji - numpy.eye(a.shape[-1]) - if b is None: - dw, dv = _eigh_jvp_jitted_nob(v, Fmat, at) - else: - bmask = numpy.zeros(a.shape) - bmask[idx] = 1 - dw, dv = _eigh_jvp_jitted(w, v, Fmat, at, bt, bmask) - return (w, v), (dw, dv) - -@jit -def _eigh_jvp_jitted(w, v, Fmat, at, bt, bmask): - vt_at_v = np.dot(v.conj().T, np.dot(at, v)) - vt_bt_v = np.dot(v.conj().T, np.dot(bt, v)) - vt_bt_v_w = np.dot(vt_bt_v, np.diag(w)) - da_minus_ds = vt_at_v - vt_bt_v_w - dw = np.diag(da_minus_ds).real - - dv = np.dot(v, np.multiply(Fmat, da_minus_ds) - np.multiply(bmask, vt_bt_v) * .5) - return dw, dv - -@jit -def _eigh_jvp_jitted_nob(v, Fmat, at): - vt_at_v = np.dot(v.conj().T, np.dot(at, v)) - dw = np.diag(vt_at_v).real - dv = np.dot(v, np.multiply(Fmat, vt_at_v)) - return dw, dv - - -def svd(a, full_matrices=True, compute_uv=True, - overwrite_a=False, check_finite=True, - lapack_driver='gesdd'): - if not full_matrices or not compute_uv: - return jax_scipy.linalg.svd(a, - full_matrices=full_matrices, - compute_uv=compute_uv) - else: - return _svd(a) - -@custom_jvp -def _svd(a): - return jax_scipy.linalg.svd(a) - -@_svd.defjvp -def _svd_jvp(primals, tangents): - A, = primals - dA, = tangents - if np.iscomplexobj(A): - raise NotImplementedError - - m, n = A.shape - if m > n: - raise NotImplementedError('Use svd(A.conj().T) instead.') - - U, s, Vt = _svd(A) - Ut = U.conj().T - V = Vt.conj().T - s_dim = s[None, :] - - dS = Ut @ dA @ V - ds = np.diagonal(dS, 0, -2, -1).real - - s_diffs = (s_dim + s_dim.T) * (s_dim - s_dim.T) - s_diffs_zeros = (s_diffs == 0).astype(s_diffs.dtype) - F = 1. / (s_diffs + s_diffs_zeros) - s_diffs_zeros - - dP1 = dS[:,:m] - dP2 = dS[:,m:] - dSS = dP1 * s_dim - SdS = s_dim.T * dP1 - - dU = U @ (F * (dSS + dSS.conj().T)) - dD1 = F * (SdS + SdS.conj().T) - - s_zeros = (s == 0).astype(s.dtype) - s_inv = 1. / (s + s_zeros) - s_zeros - dD2 = s_inv[:,None] * dP2 - - dV = np.zeros_like(V) - dV = dV.at[:m,:m].set(dD1) - dV = dV.at[:m,m:].set(-dD2) - dV = dV.at[m:,:m].set(dD2.conj().T) - dV = V @ dV - return (U, s, Vt), (dU, ds, dV.conj().T) diff --git a/pyscfad/fci/fci_slow.py b/pyscfad/fci/fci_slow.py index feb9d8f9..2faf0117 100644 --- a/pyscfad/fci/fci_slow.py +++ b/pyscfad/fci/fci_slow.py @@ -2,7 +2,7 @@ from pyscf.fci import cistring from pyscfad import numpy as np from pyscfad import ops -from pyscfad.ops import vmap, stop_grad +from pyscfad.ops import vmap, to_numpy from pyscfad.lib.linalg_helper import davidson from pyscfad.gto import mole from pyscfad import ao2mo @@ -82,16 +82,12 @@ def body(mo_ia, mo_ib, ida, idb): val = np.linalg.det(sij_a) * np.linalg.det(sij_b) return val - res = 0. + res = 0 for ia in range(na1): mo_ia = mo_a1[:,locs_a1[ia]] for ib in range(nb1): mo_ib = mo_b1[:,locs_b1[ib]] val = vmap(body, (None,None,0,0), signature='(i),(j)->()')(mo_ia, mo_ib, idxa, idxb) - #val = [] - #for i in range(len(idxa)): - # val.append(body(mo_ia, mo_ib, idxa[i], idxb[i])) - #val = np.asarray(val) res += ci1[ia,ib] * (val * ci2.ravel()).sum() return res @@ -126,16 +122,19 @@ def contract_2e(eri, fcivec, norb, nelec, opt=None): fcinew = ops.index_add(fcinew, ops.index[:,str1], sign * t1[a,i,:,str0]) return fcinew.reshape(fcivec.shape) - def absorb_h1e(h1e, eri, norb, nelec, fac=1): if not isinstance(nelec, (int, np.integer)): - nelec = sum(nelec) + nelec = np.sum(nelec) + assert nelec > 0 + if eri.size != norb**4: - h2e = ao2mo.restore(1, eri.copy(), norb) + h2e = ao2mo.restore(1, eri, norb) else: - h2e = eri.copy().reshape(norb,norb,norb,norb) - f1e = h1e - np.einsum('jiik->jk', h2e) * .5 - f1e = f1e * (1./(nelec+1e-100)) + h2e = eri.reshape([norb,]*4) + + f1e = h1e - np.einsum('jiik->jk', h2e) * .5 + f1e *= 1. / nelec + for k in range(norb): h2e = ops.index_add(h2e, ops.index[k,k,:,:], f1e) h2e = ops.index_add(h2e, ops.index[:,:,k,k], f1e) @@ -153,7 +152,7 @@ def make_hdiag(h1e, eri, norb, nelec, opt=None): if eri.size != norb**4: eri = ao2mo.restore(1, eri, norb) else: - eri = eri.reshape(norb,norb,norb,norb) + eri = eri.reshape([norb,]*4) diagj = np.einsum('iijj->ij', eri) diagk = np.einsum('ijji->ij', eri) hdiag = [] @@ -173,10 +172,10 @@ def kernel(h1e, eri, norb, nelec, ecore=0, nroots=1): hdiag = make_hdiag(h1e, eri, norb, nelec) try: from pyscf.fci.direct_spin1 import pspace - addrs, h0 = pspace(stop_grad(h1e), stop_grad(eri), - norb, nelec, stop_grad(hdiag), nroots) - # pylint: disable=bare-except - except: + addrs, _ = pspace(to_numpy(h1e), to_numpy(eri), + norb, nelec, to_numpy(hdiag), nroots) + # pylint: disable=broad-exception-caught + except Exception: addrs = numpy.argsort(hdiag)[:nroots] ci0 = [] for addr in addrs: @@ -187,7 +186,9 @@ def kernel(h1e, eri, norb, nelec, ecore=0, nroots=1): def hop(c): hc = contract_2e(h2e, c, norb, nelec) return hc.ravel() - # pylint: disable=unnecessary-lambda-assignment + precond = lambda x, e, *args: x/(hdiag-e+1e-4) + e, c = davidson(hop, ci0, precond, nroots=nroots) return e+ecore, c +