diff --git a/README.md b/README.md index 9da7670..2417294 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ I_op = nkf.infidelity.InfidelityOperator(phi, U=U, U_dagger=U, is_unitary=True, # Create the driver optimizer = nk.optimizer.Sgd(learning_rate=0.01) -te = nkf.driver.infidelity_optimizer.InfidelityOptimizer(phi, U, psi, optimizer, U_dagger=U, is_unitary=True, cv_coeff=-0.5) +te = nkf.driver.infidelity_optimizer.InfidelityOptimizer(phi, optimizer, U=U, U_dagger=U, variational_state=psi, is_unitary=True, cv_coeff=-0.5) # Run the driver te.run(n_iter=100) diff --git a/netket_fidelity/driver/infidelity_optimizer.py b/netket_fidelity/driver/infidelity_optimizer.py index 8983890..4c22af3 100644 --- a/netket_fidelity/driver/infidelity_optimizer.py +++ b/netket_fidelity/driver/infidelity_optimizer.py @@ -1,48 +1,140 @@ +from typing import Optional from netket.stats import Stats from netket.driver.abstract_variational_driver import AbstractVariationalDriver -from .infidelity_optimizer_common import info +from netket.optimizer import ( + identity_preconditioner, + PreconditionerT, +) + from netket_fidelity.infidelity import InfidelityOperator +from .infidelity_optimizer_common import info + class InfidelityOptimizer(AbstractVariationalDriver): def __init__( self, target_state, - U, - vstate, optimizer, + *, + variational_state, + U=None, U_dagger=None, - sr=None, + preconditioner: PreconditionerT = identity_preconditioner, is_unitary=False, - cv_coeff=None, + cv_coeff=-0.5, ): - super().__init__(vstate, optimizer, minimized_quantity_name="Infidelity") + r""" + Constructs a driver training the state to match the target state. + + The target state is either `math`:\ket{\psi}` or `math`:\hat{U}\ket{\psi}` + depending on the provided inputs. + + Operator I_op computing the infidelity I among two variational states |ψ⟩ and |Φ⟩ as: + + .. math:: + + I = 1 - |⟨ψ|Φ⟩|^2 / ⟨ψ|ψ⟩ ⟨Φ|Φ⟩ = 1 - ⟨ψ|I_op|ψ⟩ / ⟨ψ|ψ⟩ + + where: + + .. math:: + + I_op = |Φ⟩⟨Φ| / ⟨Φ|Φ⟩ + + The state |Φ⟩ can be an autonomous state |Φ⟩ =|ϕ⟩ or an operator U applied to it, namely + |Φ⟩ = U|ϕ⟩. I_op is defined by the state |ϕ⟩ (called target) and, possibly, by the operator U. + If U is not passed, it is assumed |Φ⟩ =|ϕ⟩. + + The Monte Carlo estimator of I is: + + ..math:: + + I = \mathbb{E}_{χ}[ I_loc(σ,η) ] = \mathbb{E}_{χ}[ ⟨σ|Φ⟩ ⟨η|ψ⟩ / ⟨σ|ψ⟩ ⟨η|Φ⟩ ] + + where χ(σ, η) = |Ψ(σ)|^2 |Φ(η)|^2 / ⟨ψ|ψ⟩ ⟨Φ|Φ⟩. In practice, since I is a real quantity, Re{I_loc(σ,η)} + is used. This estimator can be utilized both when |Φ⟩ =|ϕ⟩ and when |Φ⟩ = U|ϕ⟩, with U a (unitary or + non-unitary) operator. In the second case, we have to sample from U|ϕ⟩ and this is implemented in + the function :ref:`jax.:ref:`InfidelityUPsi`. This works only with the operators provdided in the package. + We remark that sampling from U|ϕ⟩ requires to compute connected elements of U and so is more expensive + than sampling from an autonomous state. The choice of this estimator is specified by passing + `sample_Upsi=True`, while the flag argument `is_unitary` indicates whether U is unitary or not. + + If U is unitary, the following alternative estimator can be used: + + ..math:: + + I = \mathbb{E}_{χ'}[ I_loc(σ,η) ] = \mathbb{E}_{χ}[ ⟨σ|U|ϕ⟩ ⟨η|ψ⟩ / ⟨σ|U^{\dagger}|ψ⟩ ⟨η|ϕ⟩ ]. + + where χ'(σ, η) = |Ψ(σ)|^2 |ϕ(η)|^2 / ⟨ψ|ψ⟩ ⟨ϕ|ϕ⟩. This estimator is more efficient since it does not + require to sample from U|ϕ⟩, but only from |ϕ⟩. This choice of the estimator is the default and it works only + with `is_unitary==True` (besides `sample_Upsi=False`). When |Φ⟩ = |ϕ⟩ the two estimators coincides. + + To reduce the variance of the estimator, the Control Variates (CV) method can be applied. This consists + in modifying the estimator into: + + ..math:: + + I_loc^{CV} = Re{I_loc(σ,η)} - c (|1 - I_loc(σ,η)^2| - 1) + + where c ∈ \mathbb{R}. The constant c is chosen to minimize the variance of I_loc^{CV} as: + + ..math:: + + c* = Cov_{χ}[ |1-I_loc|^2, Re{1-I_loc}] / Var_{χ}[ |1-I_loc|^2 ], + + where Cov[..., ...] indicates the covariance and Var[...] the variance. In the relevant limit + |Ψ⟩ →|Φ⟩, we have c*→-1/2. The value -1/2 is adopted as default value for c in the infidelity + estimator. To not apply CV, set c=0. + + Args: + target_state: target variational state |ϕ⟩. + optimizer: the optimizer to use to use (from optax) + variational_state: the variational state to train + U: operator U. + U_dagger: dagger operator U^{\dagger}. + cv_coeff: Control Variates coefficient c. + is_unitary: flag specifiying the unitarity of U. If True with `sample_Upsi=False`, the second estimator is used. + dtype: The dtype of the output of expectation value and gradient. + sample_Upsi: flag specifiying whether to sample from |ϕ⟩ or from U|ϕ⟩. If False with `is_unitary=False`, an error occurs. + preconditioner: Determines which preconditioner to use for the loss gradient. + This must be a tuple of `(object, solver)` as documented in the section + `preconditioners` in the documentation. The standard preconditioner + included with NetKet is Stochastic Reconfiguration. By default, no + preconditioner is used and the bare gradient is passed to the optimizer. + """ + super().__init__( + variational_state, optimizer, minimized_quantity_name="Infidelity" + ) + + self._cv = cv_coeff + + self.preconditioner = preconditioner - self.sr = sr self._I_op = InfidelityOperator( - target_state, U=U, U_dagger=U, is_unitary=True, cv_coeff=-1 / 2 + target_state, U=U, U_dagger=U, is_unitary=True, cv_coeff=cv_coeff ) def _forward_and_backward(self): self.state.reset() self._I_op.target.reset() - I_stats, I_grad = self.state.expect_and_grad(self._I_op) - - # TODO - self._loss_stats = I_stats - self._loss_grad = I_grad + self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._I_op) - if self.sr is not None: - self._S = self.state.quantum_geometric_tensor(self.sr) - self._dp = self._S(self._loss_grad) - else: - self._dp = self._loss_grad + # if it's the identity it does + self._dp = self.preconditioner(self.state, self._loss_grad, self.step_count) return self._dp + @property + def cv(self) -> Optional[float]: + """ + Return the coefficient for the Control Variates + """ + return self._cv + @property def infidelity(self) -> Stats: """ @@ -51,6 +143,36 @@ def infidelity(self) -> Stats: """ return self._loss_stats + @property + def preconditioner(self): + """ + The preconditioner used to modify the gradient. + + This is a function with the following signature + + .. code-block:: python + + precondtioner(vstate: VariationalState, + grad: PyTree, + step: Optional[Scalar] = None) + + Where the first argument is a variational state, the second argument + is the PyTree of the gradient to precondition and the last optional + argument is the step, used to change some parameters along the + optimisation. + + Often, this is taken to be :func:`nk.optimizer.SR`. If it is set to + `None`, then the identity is used. + """ + return self._preconditioner + + @preconditioner.setter + def preconditioner(self, val: Optional[PreconditionerT]): + if val is None: + val = identity_preconditioner + + self._preconditioner = val + def __repr__(self): return ( "InfidelityOptimiser(" @@ -69,6 +191,3 @@ def info(self, depth=0): ] ] return "\n{}".format(" " * 3 * (depth + 1)).join([str(self)] + lines) - - def info(self): - pass diff --git a/netket_fidelity/infidelity/logic.py b/netket_fidelity/infidelity/logic.py index efdc26c..9105012 100644 --- a/netket_fidelity/infidelity/logic.py +++ b/netket_fidelity/infidelity/logic.py @@ -1,9 +1,16 @@ from typing import Optional from netket.operator import AbstractOperator, Adjoint -from netket.vqs import VariationalState, ExactState +from netket.vqs import VariationalState from netket.utils.types import DType +import netket + +if hasattr(netket.vqs, "FullSumState"): + from netket.vqs import FullSumState +else: + from netket.vqs import ExactState as FullSumState + from .overlap import InfidelityOperatorStandard, InfidelityUPsi from .overlap_U import InfidelityOperatorUPsi @@ -120,7 +127,7 @@ def InfidelityOperator( "use operators coming from `netket_fidelity`." ) - if isinstance(target, ExactState): + if isinstance(target, FullSumState): return InfidelityOperatorUPsi( U, target, diff --git a/netket_fidelity/infidelity/overlap/exact.py b/netket_fidelity/infidelity/overlap/exact.py index 8784042..1d85a22 100644 --- a/netket_fidelity/infidelity/overlap/exact.py +++ b/netket_fidelity/infidelity/overlap/exact.py @@ -5,21 +5,29 @@ from netket import jax as nkjax from netket.utils.dispatch import TrueT -from netket.vqs import ExactState, expect, expect_and_grad +from netket.vqs import expect, expect_and_grad from netket.utils import mpi from netket.stats import Stats +# support future netket +import netket + +if hasattr(netket.vqs, "FullSumState"): + from netket.vqs import FullSumState +else: + from netket.vqs import ExactState as FullSumState + from .operator import InfidelityOperatorStandard @expect.dispatch -def infidelity(vstate: ExactState, op: InfidelityOperatorStandard): +def infidelity(vstate: FullSumState, op: InfidelityOperatorStandard): if op.hilbert != vstate.hilbert: raise TypeError("Hilbert spaces should match") - if not isinstance(op.target, ExactState): + if not isinstance(op.target, FullSumState): raise TypeError("Can only compute infidelity of exact states.") - return infidelity_sampling_ExactState( + return infidelity_sampling_FullSumState( vstate._apply_fun, vstate.parameters, vstate.model_state, @@ -30,8 +38,8 @@ def infidelity(vstate: ExactState, op: InfidelityOperatorStandard): @expect_and_grad.dispatch -def infidelity( - vstate: ExactState, +def infidelity( # noqa: F811 + vstate: FullSumState, op: InfidelityOperatorStandard, use_covariance: TrueT, *, @@ -39,10 +47,10 @@ def infidelity( ): if op.hilbert != vstate.hilbert: raise TypeError("Hilbert spaces should match") - if not isinstance(op.target, ExactState): + if not isinstance(op.target, FullSumState): raise TypeError("Can only compute infidelity of exact states.") - return infidelity_sampling_ExactState( + return infidelity_sampling_FullSumState( vstate._apply_fun, vstate.parameters, vstate.model_state, @@ -53,7 +61,7 @@ def infidelity( @partial(jax.jit, static_argnames=("afun", "return_grad")) -def infidelity_sampling_ExactState( +def infidelity_sampling_FullSumState( afun, params, model_state, diff --git a/netket_fidelity/infidelity/overlap/expect.py b/netket_fidelity/infidelity/overlap/expect.py index 089e82e..bbde732 100644 --- a/netket_fidelity/infidelity/overlap/expect.py +++ b/netket_fidelity/infidelity/overlap/expect.py @@ -33,7 +33,7 @@ def infidelity(vstate: MCState, op: InfidelityOperatorStandard): @expect_and_grad.dispatch -def infidelity( +def infidelity( # noqa: F811 vstate: MCState, op: InfidelityOperatorStandard, use_covariance: TrueT, diff --git a/netket_fidelity/infidelity/overlap/operator.py b/netket_fidelity/infidelity/overlap/operator.py index b21ea5e..37e716f 100644 --- a/netket_fidelity/infidelity/overlap/operator.py +++ b/netket_fidelity/infidelity/overlap/operator.py @@ -5,7 +5,16 @@ from netket.operator import AbstractOperator from netket.utils.types import DType from netket.utils.numbers import is_scalar -from netket.vqs import VariationalState, ExactState, MCState +from netket.vqs import VariationalState, MCState + +# support future netket +import netket + +if hasattr(netket.vqs, "FullSumState"): + from netket.vqs import FullSumState +else: + from netket.vqs import ExactState as FullSumState + from netket_fidelity.utils.sampling_Ustate import _logpsi_U @@ -29,7 +38,7 @@ def __init__( if (not is_scalar(cv_coeff)) or jnp.iscomplex(cv_coeff): raise TypeError("`cv_coeff` should be a real scalar number or None.") - if isinstance(target, ExactState): + if isinstance(target, FullSumState): cv_coeff = None self._target = target diff --git a/netket_fidelity/infidelity/overlap_U/exact.py b/netket_fidelity/infidelity/overlap_U/exact.py index fb1e0db..1e1e683 100644 --- a/netket_fidelity/infidelity/overlap_U/exact.py +++ b/netket_fidelity/infidelity/overlap_U/exact.py @@ -3,10 +3,18 @@ from netket import jax as nkjax from netket.utils.dispatch import TrueT -from netket.vqs import ExactState, expect, expect_and_grad +from netket.vqs import expect, expect_and_grad from netket.utils import mpi from netket.stats import Stats +# support future netket +import netket + +if hasattr(netket.vqs, "FullSumState"): + from netket.vqs import FullSumState +else: + from netket.vqs import ExactState as FullSumState + from .operator import InfidelityOperatorUPsi @@ -15,13 +23,13 @@ def sparsify(U): @expect.dispatch -def infidelity(vstate: ExactState, op: InfidelityOperatorUPsi): +def infidelity(vstate: FullSumState, op: InfidelityOperatorUPsi): if op.hilbert != vstate.hilbert: raise TypeError("Hilbert spaces should match") - if not isinstance(op.target, ExactState): + if not isinstance(op.target, FullSumState): raise TypeError("Can only compute infidelity of exact states.") - return infidelity_sampling_ExactState( + return infidelity_sampling_FullSumState( vstate._apply_fun, vstate.parameters, vstate.model_state, @@ -33,8 +41,8 @@ def infidelity(vstate: ExactState, op: InfidelityOperatorUPsi): @expect_and_grad.dispatch -def infidelity( - vstate: ExactState, +def infidelity( # noqa: F811 + vstate: FullSumState, op: InfidelityOperatorUPsi, use_covariance: TrueT, *, @@ -42,10 +50,10 @@ def infidelity( ): if op.hilbert != vstate.hilbert: raise TypeError("Hilbert spaces should match") - if not isinstance(op.target, ExactState): + if not isinstance(op.target, FullSumState): raise TypeError("Can only compute infidelity of exact states.") - return infidelity_sampling_ExactState( + return infidelity_sampling_FullSumState( vstate._apply_fun, vstate.parameters, vstate.model_state, @@ -56,7 +64,7 @@ def infidelity( ) -def infidelity_sampling_ExactState( +def infidelity_sampling_FullSumState( afun, params, model_state, diff --git a/netket_fidelity/infidelity/overlap_U/expect.py b/netket_fidelity/infidelity/overlap_U/expect.py index b1632d2..4a7c508 100644 --- a/netket_fidelity/infidelity/overlap_U/expect.py +++ b/netket_fidelity/infidelity/overlap_U/expect.py @@ -42,7 +42,7 @@ def infidelity(vstate: MCState, op: InfidelityOperatorUPsi): @expect_and_grad.dispatch -def infidelity( +def infidelity( # noqa: F811 vstate: MCState, op: InfidelityOperatorUPsi, use_covariance: TrueT, diff --git a/netket_fidelity/infidelity/overlap_U/operator.py b/netket_fidelity/infidelity/overlap_U/operator.py index 7023018..751d483 100644 --- a/netket_fidelity/infidelity/overlap_U/operator.py +++ b/netket_fidelity/infidelity/overlap_U/operator.py @@ -4,7 +4,15 @@ from netket.operator import AbstractOperator from netket.utils.types import DType from netket.utils.numbers import is_scalar -from netket.vqs import VariationalState, ExactState +from netket.vqs import VariationalState + +# support future netket +import netket + +if hasattr(netket.vqs, "FullSumState"): + from netket.vqs import FullSumState +else: + from netket.vqs import ExactState as FullSumState class InfidelityOperatorUPsi(AbstractOperator): @@ -23,7 +31,7 @@ def __init__( if not isinstance(state, VariationalState): raise TypeError("The first argument should be a variational state.") - if not is_unitary and not isinstance(state, ExactState): + if not is_unitary and not isinstance(state, FullSumState): raise ValueError( "Only works with unitary gates. If the gate is non unitary" " then you must sample from it. Use a different operator." @@ -35,7 +43,7 @@ def __init__( if (not is_scalar(cv_coeff)) or jnp.iscomplex(cv_coeff): raise TypeError("`cv_coeff` should be a real scalar number or None.") - if isinstance(state, ExactState): + if isinstance(state, FullSumState): cv_coeff = None self._target = state diff --git a/netket_fidelity/operator/singlequbit_gates.py b/netket_fidelity/operator/singlequbit_gates.py index b78c8d2..0570351 100644 --- a/netket_fidelity/operator/singlequbit_gates.py +++ b/netket_fidelity/operator/singlequbit_gates.py @@ -151,7 +151,7 @@ def get_conns_and_mels_Ry(sigma, idx, angle): @nk.vqs.get_local_kernel_arguments.dispatch -def get_local_kernel_arguments(vstate: nk.vqs.MCState, op: Ry): +def get_local_kernel_arguments(vstate: nk.vqs.MCState, op: Ry): # noqa: F811 sigma = vstate.samples conns, mels = get_conns_and_mels_Ry( sigma.reshape(-1, vstate.hilbert.size), op.idx, op.angle @@ -228,7 +228,7 @@ def get_conns_and_mels_Hadamard(sigma, idx): @nk.vqs.get_local_kernel_arguments.dispatch -def get_local_kernel_arguments(vstate: nk.vqs.MCState, op: Hadamard): +def get_local_kernel_arguments(vstate: nk.vqs.MCState, op: Hadamard): # noqa: F811 sigma = vstate.samples conns, mels = get_conns_and_mels_Hadamard( sigma.reshape(-1, vstate.hilbert.size), diff --git a/test/_finite_diff.py b/test/_finite_diff.py index 96226b7..1f75e78 100644 --- a/test/_finite_diff.py +++ b/test/_finite_diff.py @@ -1,10 +1,10 @@ import numpy as np -import jax import jax.numpy as jnp import netket as nk + def central_diff_grad(func, x, eps, *args, dtype=None): if dtype is None: dtype = x.dtype diff --git a/test/_infidelity_exact.py b/test/_infidelity_exact.py index 8d69820..643e358 100644 --- a/test/_infidelity_exact.py +++ b/test/_infidelity_exact.py @@ -1,5 +1,4 @@ import jax.numpy as jnp -import netket as nk def _infidelity_exact(params_new, vstate, U): @@ -10,12 +9,12 @@ def _infidelity_exact(params_new, vstate, U): state_new = vstate.to_array() vstate.parameters = params_old - if U is None: - return 1 - jnp.absolute(state_new.conj().T @ state_old)**2 / ( + if U is None: + return 1 - jnp.absolute(state_new.conj().T @ state_old) ** 2 / ( (state_new.conj().T @ state_new) * (state_old.conj().T @ state_old) ) - - else: - return 1 - jnp.absolute(state_new.conj().T @ U.to_sparse() @ state_old)**2 / ( + + else: + return 1 - jnp.absolute(state_new.conj().T @ U.to_sparse() @ state_old) ** 2 / ( (state_new.conj().T @ state_new) * (state_old.conj().T @ state_old) ) diff --git a/test/test_infidelity_operator.py b/test/test_infidelity_operator.py index b82f78b..d67e077 100644 --- a/test/test_infidelity_operator.py +++ b/test/test_infidelity_operator.py @@ -101,7 +101,7 @@ def _infidelity_exact_fun(params, vstate, U): @pytest.mark.parametrize("sample_Upsi", [False, True]) @pytest.mark.parametrize("is_identity", [False, True]) -def test_ExactState(sample_Upsi, is_identity): +def test_FullSumState(sample_Upsi, is_identity): vs_t, vs, vs_exact_t, vs_exact, _U, _U_dag = _setup() if is_identity is False: diff --git a/test/test_inverserot.py b/test/test_inverserot.py index 912e941..e722105 100644 --- a/test/test_inverserot.py +++ b/test/test_inverserot.py @@ -1,9 +1,6 @@ -import pytest -from pytest import approx import netket as nk import jax.numpy as jnp import numpy as np -import scipy import netket_fidelity as nkf