From 22b2e25c61f3d4384d4533ae3da2cd180570e9b0 Mon Sep 17 00:00:00 2001 From: Mohammed Boky Date: Mon, 4 Sep 2023 21:37:54 +0200 Subject: [PATCH 01/17] Adapted Hamadard gate to use local_states. Added black pre commit file. --- .pre-commit-config.yaml | 12 +++ netket_fidelity/operator/singlequbit_gates.py | 77 +++++++++++-------- 2 files changed, 58 insertions(+), 31 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a605170 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,12 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.0.282 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + + - repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black \ No newline at end of file diff --git a/netket_fidelity/operator/singlequbit_gates.py b/netket_fidelity/operator/singlequbit_gates.py index 93b8aa2..e09adeb 100644 --- a/netket_fidelity/operator/singlequbit_gates.py +++ b/netket_fidelity/operator/singlequbit_gates.py @@ -14,14 +14,14 @@ @register_pytree_node_class class Rx(DiscreteJaxOperator): def __init__(self, hi, idx, angle): - if not isinstance(hi, Spin): - raise TypeError("""The Hilbert space used by Rx must be a `Spin` space. - - This limitation could be lifted by 'fixing' the method - `get_conn_and_mels` to work with arbitrary hilbert spaces, which - should be relatively straightforward to do, but we have not done so - yet. - """) + # if not isinstance(hi, Spin): + # raise TypeError("""The Hilbert space used by Rx must be a `Spin` space. + # + # This limitation could be lifted by 'fixing' the method + # `get_conn_and_mels` to work with arbitrary hilbert spaces, which + # should be relatively straightforward to do, but we have not done so + # yet. + # """) super().__init__(hi) self._idx = idx self._angle = angle @@ -54,13 +54,16 @@ def __eq__(self, o): return False def tree_flatten(self): - children = (self.angle, ) - aux_data = (self.hilbert, self.idx, ) + children = (self.angle,) + aux_data = ( + self.hilbert, + self.idx, + ) return (children, aux_data) @classmethod def tree_unflatten(cls, aux_data, children): - angle, = children + (angle,) = children return cls(*aux_data, angle) @property @@ -108,13 +111,15 @@ def get_conns_and_mels_Rx(sigma, idx, angle): class Ry(DiscreteJaxOperator): def __init__(self, hi, idx, angle): if not isinstance(hi, Spin): - raise TypeError("""The Hilbert space used by Rx must be a `Spin` space. + raise TypeError( + """The Hilbert space used by Rx must be a `Spin` space. - This limitation could be lifted by 'fixing' the method + This limitation could be lifted by 'fixing' the method `get_conn_and_mels` to work with arbitrary hilbert spaces, which should be relatively straightforward to do, but we have not done so yet. - """) + """ + ) super().__init__(hi) self._idx = idx @@ -152,13 +157,16 @@ def __eq__(self, o): return False def tree_flatten(self): - children = (self.angle, ) - aux_data = (self.hilbert, self.idx, ) + children = (self.angle,) + aux_data = ( + self.hilbert, + self.idx, + ) return (children, aux_data) @classmethod def tree_unflatten(cls, aux_data, children): - angle, = children + (angle,) = children return cls(*aux_data, angle) @jax.jit @@ -204,15 +212,15 @@ def get_conns_and_mels_Ry(sigma, idx, angle): @register_pytree_node_class class Hadamard(DiscreteJaxOperator): def __init__(self, hi, idx): - if not isinstance(hi, Spin): - raise TypeError("""The Hilbert space used by Rx must be a `Spin` space. - - This limitation could be lifted by 'fixing' the method - `get_conn_and_mels` to work with arbitrary hilbert spaces, which - should be relatively straightforward to do, but we have not done so - yet. - """) - + # if not isinstance(hi, Spin): + # raise TypeError("""The Hilbert space used by Rx must be a `Spin` space. + # + # This limitation could be lifted by 'fixing' the method + # `get_conn_and_mels` to work with arbitrary hilbert spaces, which + # should be relatively straightforward to do, but we have not done so + # yet. + # """) + self._local_states = hi.local_states super().__init__(hi) self._idx = idx @@ -258,7 +266,7 @@ def max_conn_size(self) -> int: @jax.jit def get_conn_padded(self, x): xr = x.reshape(-1, x.shape[-1]) - xp, mels = get_conns_and_mels_Hadamard(xr, self.idx) + xp, mels = get_conns_and_mels_Hadamard(xr, self.idx, self._local_states) xp = xp.reshape(x.shape[:-1] + xp.shape[-2:]) mels = mels.reshape(x.shape[:-1] + mels.shape[-1:]) return xp, mels @@ -275,15 +283,22 @@ def get_conn_flattened(self, x, sections): return xp, mels -@partial(jax.vmap, in_axes=(0, None), out_axes=(0, 0)) -def get_conns_and_mels_Hadamard(sigma, idx): +@partial(jax.vmap, in_axes=(0, None, [None, None]), out_axes=(0, 0)) +def get_conns_and_mels_Hadamard(sigma, idx, local_states): assert sigma.ndim == 1 + state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) + state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) + cons = jnp.tile(sigma, (2, 1)) - cons = cons.at[1, idx].set(-cons.at[1, idx].get()) + current_state = sigma[idx] + flipped_state = jnp.where(current_state == state_0, state_1, state_0) + cons = cons.at[1, idx].set(flipped_state) mels = jnp.zeros(2, dtype=float) mels = mels.at[1].set(1 / jnp.sqrt(2)) - mels = mels.at[0].set(((-1) ** ((cons.at[0, idx].get() + 1) / 2)) / jnp.sqrt(2)) + state_value = cons.at[0, idx].get() + mels_value = jnp.where(state_value == local_states[0], 1, -1) / jnp.sqrt(2) + mels = mels.at[0].set(mels_value) return cons, mels From 33a2503f1fc30552e174ab9eba75d39553acbe32 Mon Sep 17 00:00:00 2001 From: mboky Date: Tue, 5 Sep 2023 15:05:05 +0200 Subject: [PATCH 02/17] Updated rx and ry to work with local states. --- netket_fidelity/operator/singlequbit_gates.py | 65 ++++++++----------- 1 file changed, 27 insertions(+), 38 deletions(-) diff --git a/netket_fidelity/operator/singlequbit_gates.py b/netket_fidelity/operator/singlequbit_gates.py index e09adeb..2162f16 100644 --- a/netket_fidelity/operator/singlequbit_gates.py +++ b/netket_fidelity/operator/singlequbit_gates.py @@ -14,15 +14,9 @@ @register_pytree_node_class class Rx(DiscreteJaxOperator): def __init__(self, hi, idx, angle): - # if not isinstance(hi, Spin): - # raise TypeError("""The Hilbert space used by Rx must be a `Spin` space. - # - # This limitation could be lifted by 'fixing' the method - # `get_conn_and_mels` to work with arbitrary hilbert spaces, which - # should be relatively straightforward to do, but we have not done so - # yet. - # """) + super().__init__(hi) + self._local_states = hi.local_states self._idx = idx self._angle = angle @@ -95,33 +89,29 @@ def to_local_operator(self): @partial(jax.vmap, in_axes=(0, None, None), out_axes=(0, 0)) -def get_conns_and_mels_Rx(sigma, idx, angle): +def get_conns_and_mels_Rx(sigma, idx, angle, local_states): assert sigma.ndim == 1 - conns = jnp.tile(sigma, (2, 1)) - conns = conns.at[1, idx].set(-conns.at[1, idx].get()) + state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) + state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) + + cons = jnp.tile(sigma, (2, 1)) + current_state = sigma[idx] + flipped_state = jnp.where(current_state == state_0, state_1, state_0) + cons = cons.at[1, idx].set(flipped_state) + mels = jnp.zeros(2, dtype=complex) mels = mels.at[0].set(jnp.cos(angle / 2)) mels = mels.at[1].set(-1j * jnp.sin(angle / 2)) - return conns, mels + return cons, mels @register_pytree_node_class class Ry(DiscreteJaxOperator): def __init__(self, hi, idx, angle): - if not isinstance(hi, Spin): - raise TypeError( - """The Hilbert space used by Rx must be a `Spin` space. - - This limitation could be lifted by 'fixing' the method - `get_conn_and_mels` to work with arbitrary hilbert spaces, which - should be relatively straightforward to do, but we have not done so - yet. - """ - ) - super().__init__(hi) + self._local_states = hi.local_states self._idx = idx self._angle = angle @@ -172,7 +162,7 @@ def tree_unflatten(cls, aux_data, children): @jax.jit def get_conn_padded(self, x): xr = x.reshape(-1, x.shape[-1]) - xp, mels = get_conns_and_mels_Ry(xr, self.idx, self.angle) + xp, mels = get_conns_and_mels_Ry(xr, self.idx, self.angle, self._lo) xp = xp.reshape(x.shape[:-1] + xp.shape[-2:]) mels = mels.reshape(x.shape[:-1] + mels.shape[-1:]) return xp, mels @@ -193,33 +183,32 @@ def to_local_operator(self): return ctheta + 1j * stheta * spin.sigmay(self.hilbert, self.idx) -@partial(jax.vmap, in_axes=(0, None, None), out_axes=(0, 0)) -def get_conns_and_mels_Ry(sigma, idx, angle): +@partial(jax.vmap, in_axes=(0, None, None, [None, None]), out_axes=(0, 0)) +def get_conns_and_mels_Ry(sigma, idx, angle, local_states): assert sigma.ndim == 1 - conns = jnp.tile(sigma, (2, 1)) - conns = conns.at[1, idx].set(-conns.at[1, idx].get()) + state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) + state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) + + cons = jnp.tile(sigma, (2, 1)) + current_state = sigma[idx] + flipped_state = jnp.where(current_state == state_0, state_1, state_0) + cons = cons.at[1, idx].set(flipped_state) mels = jnp.zeros(2, dtype=complex) mels = mels.at[0].set(jnp.cos(angle / 2)) + phase_factor = jnp.where(cons.at[0, idx].get() == local_states[0],1,-1) mels = mels.at[1].set( - (-1) ** ((conns.at[0, idx].get() + 1) / 2) * jnp.sin(angle / 2) + phase_factor * jnp.sin(angle / 2) ) - return conns, mels + return cons, mels @register_pytree_node_class class Hadamard(DiscreteJaxOperator): def __init__(self, hi, idx): - # if not isinstance(hi, Spin): - # raise TypeError("""The Hilbert space used by Rx must be a `Spin` space. - # - # This limitation could be lifted by 'fixing' the method - # `get_conn_and_mels` to work with arbitrary hilbert spaces, which - # should be relatively straightforward to do, but we have not done so - # yet. - # """) + self._local_states = hi.local_states super().__init__(hi) self._idx = idx From 9f35b777f19dc7479fdb1b42fbe6ee54f13906e0 Mon Sep 17 00:00:00 2001 From: mboky Date: Tue, 5 Sep 2023 15:12:48 +0200 Subject: [PATCH 03/17] Ran black code formatter --- netket_fidelity/operator/singlequbit_gates.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/netket_fidelity/operator/singlequbit_gates.py b/netket_fidelity/operator/singlequbit_gates.py index 2162f16..5e4b629 100644 --- a/netket_fidelity/operator/singlequbit_gates.py +++ b/netket_fidelity/operator/singlequbit_gates.py @@ -1,13 +1,8 @@ from functools import partial - import numpy as np - import jax import jax.numpy as jnp - from jax.tree_util import register_pytree_node_class - -from netket.hilbert import Spin from netket.operator import DiscreteJaxOperator, spin @@ -197,10 +192,8 @@ def get_conns_and_mels_Ry(sigma, idx, angle, local_states): mels = jnp.zeros(2, dtype=complex) mels = mels.at[0].set(jnp.cos(angle / 2)) - phase_factor = jnp.where(cons.at[0, idx].get() == local_states[0],1,-1) - mels = mels.at[1].set( - phase_factor * jnp.sin(angle / 2) - ) + phase_factor = jnp.where(cons.at[0, idx].get() == local_states[0], 1, -1) + mels = mels.at[1].set(phase_factor * jnp.sin(angle / 2)) return cons, mels From ab7768f0df8bd366948cf252ff33211830e3fc85 Mon Sep 17 00:00:00 2001 From: mboky Date: Tue, 5 Sep 2023 15:40:09 +0200 Subject: [PATCH 04/17] Fixed typo's in local_state argument --- netket_fidelity/operator/singlequbit_gates.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/netket_fidelity/operator/singlequbit_gates.py b/netket_fidelity/operator/singlequbit_gates.py index 5e4b629..10f1075 100644 --- a/netket_fidelity/operator/singlequbit_gates.py +++ b/netket_fidelity/operator/singlequbit_gates.py @@ -62,7 +62,7 @@ def max_conn_size(self) -> int: @jax.jit def get_conn_padded(self, x): xr = x.reshape(-1, x.shape[-1]) - xp, mels = get_conns_and_mels_Rx(xr, self.idx, self.angle) + xp, mels = get_conns_and_mels_Rx(xr, self.idx, self.angle, self._local_states) xp = xp.reshape(x.shape[:-1] + xp.shape[-2:]) mels = mels.reshape(x.shape[:-1] + mels.shape[-1:]) return xp, mels @@ -83,7 +83,7 @@ def to_local_operator(self): return ctheta - 1j * stheta * spin.sigmax(self.hilbert, self.idx) -@partial(jax.vmap, in_axes=(0, None, None), out_axes=(0, 0)) +@partial(jax.vmap, in_axes=(0, None, None, [None, None]), out_axes=(0, 0)) def get_conns_and_mels_Rx(sigma, idx, angle, local_states): assert sigma.ndim == 1 @@ -157,7 +157,7 @@ def tree_unflatten(cls, aux_data, children): @jax.jit def get_conn_padded(self, x): xr = x.reshape(-1, x.shape[-1]) - xp, mels = get_conns_and_mels_Ry(xr, self.idx, self.angle, self._lo) + xp, mels = get_conns_and_mels_Ry(xr, self.idx, self.angle, self._local_states) xp = xp.reshape(x.shape[:-1] + xp.shape[-2:]) mels = mels.reshape(x.shape[:-1] + mels.shape[-1:]) return xp, mels From 149c64ab0a2f7b8dfb132da64a828ef78fd6d248 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Tue, 5 Sep 2023 19:57:45 +0200 Subject: [PATCH 05/17] cleanup ruff --- .github/workflows/formatting_check.yaml | 2 +- examples/onespin/onespin_rotation.py | 1 - examples/twospins_GHZ/twospins_GHZ.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/formatting_check.yaml b/.github/workflows/formatting_check.yaml index c173873..37a2261 100644 --- a/.github/workflows/formatting_check.yaml +++ b/.github/workflows/formatting_check.yaml @@ -46,4 +46,4 @@ jobs: with: version: 0.0.287 args: --config pyproject.toml - src: netket test Examples + src: netket test examples diff --git a/examples/onespin/onespin_rotation.py b/examples/onespin/onespin_rotation.py index a0f9baa..3f42e94 100644 --- a/examples/onespin/onespin_rotation.py +++ b/examples/onespin/onespin_rotation.py @@ -1,7 +1,6 @@ import netket as nk import netket_fidelity as nkf import jax.numpy as jnp -import scipy import matplotlib.pyplot as plt import flax diff --git a/examples/twospins_GHZ/twospins_GHZ.py b/examples/twospins_GHZ/twospins_GHZ.py index 5c952ab..2655ad6 100644 --- a/examples/twospins_GHZ/twospins_GHZ.py +++ b/examples/twospins_GHZ/twospins_GHZ.py @@ -1,7 +1,6 @@ import netket as nk import netket_fidelity as nkf import jax.numpy as jnp -import scipy import matplotlib.pyplot as plt import flax From cd61fd51d980cad14de3b432078141b1e4d07fcc Mon Sep 17 00:00:00 2001 From: Mohammed Boky Date: Thu, 14 Sep 2023 19:15:28 +0200 Subject: [PATCH 06/17] Implemented small changes according to feedback Filippo. Will include operator test after returning from break. --- .pre-commit-config.yaml | 4 +-- netket_fidelity/operator/singlequbit_gates.py | 28 +++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a605170..10b87a1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.0.282 + rev: v0.0.287 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 23.7.0 hooks: - id: black \ No newline at end of file diff --git a/netket_fidelity/operator/singlequbit_gates.py b/netket_fidelity/operator/singlequbit_gates.py index b185905..49349d2 100644 --- a/netket_fidelity/operator/singlequbit_gates.py +++ b/netket_fidelity/operator/singlequbit_gates.py @@ -10,7 +10,7 @@ class Rx(DiscreteJaxOperator): def __init__(self, hi, idx, angle): super().__init__(hi) - self._local_states = hi.local_states + self._local_states = jnp.asarray(hi.local_states) self._idx = idx self._angle = angle @@ -89,23 +89,23 @@ def get_conns_and_mels_Rx(sigma, idx, angle, local_states): state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) - cons = jnp.tile(sigma, (2, 1)) + conns = jnp.tile(sigma, (2, 1)) current_state = sigma[idx] flipped_state = jnp.where(current_state == state_0, state_1, state_0) - cons = cons.at[1, idx].set(flipped_state) + conns = conns.at[1, idx].set(flipped_state) mels = jnp.zeros(2, dtype=complex) mels = mels.at[0].set(jnp.cos(angle / 2)) mels = mels.at[1].set(-1j * jnp.sin(angle / 2)) - return cons, mels + return conns, mels @register_pytree_node_class class Ry(DiscreteJaxOperator): def __init__(self, hi, idx, angle): super().__init__(hi) - self._local_states = hi.local_states + self._local_states = jnp.asarray(hi.local_states) self._idx = idx self._angle = angle @@ -184,24 +184,24 @@ def get_conns_and_mels_Ry(sigma, idx, angle, local_states): state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) - cons = jnp.tile(sigma, (2, 1)) + conns = jnp.tile(sigma, (2, 1)) current_state = sigma[idx] flipped_state = jnp.where(current_state == state_0, state_1, state_0) - cons = cons.at[1, idx].set(flipped_state) + conns = conns.at[1, idx].set(flipped_state) mels = jnp.zeros(2, dtype=complex) mels = mels.at[0].set(jnp.cos(angle / 2)) - phase_factor = jnp.where(cons.at[0, idx].get() == local_states[0], 1, -1) + phase_factor = jnp.where(conns.at[0, idx].get() == local_states[0], 1, -1) mels = mels.at[1].set(phase_factor * jnp.sin(angle / 2)) - return cons, mels + return conns, mels @register_pytree_node_class class Hadamard(DiscreteJaxOperator): def __init__(self, hi, idx): - self._local_states = hi.local_states super().__init__(hi) + self._local_states = jnp.asarray(hi.local_states) self._idx = idx @property @@ -270,15 +270,15 @@ def get_conns_and_mels_Hadamard(sigma, idx, local_states): state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) - cons = jnp.tile(sigma, (2, 1)) + conns = jnp.tile(sigma, (2, 1)) current_state = sigma[idx] flipped_state = jnp.where(current_state == state_0, state_1, state_0) - cons = cons.at[1, idx].set(flipped_state) + conns = conns.at[1, idx].set(flipped_state) mels = jnp.zeros(2, dtype=float) mels = mels.at[1].set(1 / jnp.sqrt(2)) - state_value = cons.at[0, idx].get() + state_value = conns.at[0, idx].get() mels_value = jnp.where(state_value == local_states[0], 1, -1) / jnp.sqrt(2) mels = mels.at[0].set(mels_value) - return cons, mels + return conns, mels From 4ee4f843218ab3dbd60343d6c21810c43cc77c2c Mon Sep 17 00:00:00 2001 From: Mohammed Boky Date: Mon, 4 Sep 2023 21:37:54 +0200 Subject: [PATCH 07/17] Adapted Hamadard gate to use local_states. Added black pre commit file. --- .pre-commit-config.yaml | 12 ++++ netket_fidelity/operator/singlequbit_gates.py | 55 ++++++++++--------- 2 files changed, 41 insertions(+), 26 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a605170 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,12 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.0.282 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] + + - repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black \ No newline at end of file diff --git a/netket_fidelity/operator/singlequbit_gates.py b/netket_fidelity/operator/singlequbit_gates.py index 12b5f22..e09adeb 100644 --- a/netket_fidelity/operator/singlequbit_gates.py +++ b/netket_fidelity/operator/singlequbit_gates.py @@ -14,16 +14,14 @@ @register_pytree_node_class class Rx(DiscreteJaxOperator): def __init__(self, hi, idx, angle): - if not isinstance(hi, Spin): - raise TypeError( - """The Hilbert space used by Rx must be a `Spin` space. - - This limitation could be lifted by 'fixing' the method - `get_conn_and_mels` to work with arbitrary hilbert spaces, which - should be relatively straightforward to do, but we have not done so - yet. - """ - ) + # if not isinstance(hi, Spin): + # raise TypeError("""The Hilbert space used by Rx must be a `Spin` space. + # + # This limitation could be lifted by 'fixing' the method + # `get_conn_and_mels` to work with arbitrary hilbert spaces, which + # should be relatively straightforward to do, but we have not done so + # yet. + # """) super().__init__(hi) self._idx = idx self._angle = angle @@ -214,17 +212,15 @@ def get_conns_and_mels_Ry(sigma, idx, angle): @register_pytree_node_class class Hadamard(DiscreteJaxOperator): def __init__(self, hi, idx): - if not isinstance(hi, Spin): - raise TypeError( - """The Hilbert space used by Rx must be a `Spin` space. - - This limitation could be lifted by 'fixing' the method - `get_conn_and_mels` to work with arbitrary hilbert spaces, which - should be relatively straightforward to do, but we have not done so - yet. - """ - ) - + # if not isinstance(hi, Spin): + # raise TypeError("""The Hilbert space used by Rx must be a `Spin` space. + # + # This limitation could be lifted by 'fixing' the method + # `get_conn_and_mels` to work with arbitrary hilbert spaces, which + # should be relatively straightforward to do, but we have not done so + # yet. + # """) + self._local_states = hi.local_states super().__init__(hi) self._idx = idx @@ -270,7 +266,7 @@ def max_conn_size(self) -> int: @jax.jit def get_conn_padded(self, x): xr = x.reshape(-1, x.shape[-1]) - xp, mels = get_conns_and_mels_Hadamard(xr, self.idx) + xp, mels = get_conns_and_mels_Hadamard(xr, self.idx, self._local_states) xp = xp.reshape(x.shape[:-1] + xp.shape[-2:]) mels = mels.reshape(x.shape[:-1] + mels.shape[-1:]) return xp, mels @@ -287,15 +283,22 @@ def get_conn_flattened(self, x, sections): return xp, mels -@partial(jax.vmap, in_axes=(0, None), out_axes=(0, 0)) -def get_conns_and_mels_Hadamard(sigma, idx): +@partial(jax.vmap, in_axes=(0, None, [None, None]), out_axes=(0, 0)) +def get_conns_and_mels_Hadamard(sigma, idx, local_states): assert sigma.ndim == 1 + state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) + state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) + cons = jnp.tile(sigma, (2, 1)) - cons = cons.at[1, idx].set(-cons.at[1, idx].get()) + current_state = sigma[idx] + flipped_state = jnp.where(current_state == state_0, state_1, state_0) + cons = cons.at[1, idx].set(flipped_state) mels = jnp.zeros(2, dtype=float) mels = mels.at[1].set(1 / jnp.sqrt(2)) - mels = mels.at[0].set(((-1) ** ((cons.at[0, idx].get() + 1) / 2)) / jnp.sqrt(2)) + state_value = cons.at[0, idx].get() + mels_value = jnp.where(state_value == local_states[0], 1, -1) / jnp.sqrt(2) + mels = mels.at[0].set(mels_value) return cons, mels From 0d55caa7b954fbee6de4c6619db66dbf8eefda78 Mon Sep 17 00:00:00 2001 From: mboky Date: Tue, 5 Sep 2023 15:05:05 +0200 Subject: [PATCH 08/17] Updated rx and ry to work with local states. --- netket_fidelity/operator/singlequbit_gates.py | 65 ++++++++----------- 1 file changed, 27 insertions(+), 38 deletions(-) diff --git a/netket_fidelity/operator/singlequbit_gates.py b/netket_fidelity/operator/singlequbit_gates.py index e09adeb..2162f16 100644 --- a/netket_fidelity/operator/singlequbit_gates.py +++ b/netket_fidelity/operator/singlequbit_gates.py @@ -14,15 +14,9 @@ @register_pytree_node_class class Rx(DiscreteJaxOperator): def __init__(self, hi, idx, angle): - # if not isinstance(hi, Spin): - # raise TypeError("""The Hilbert space used by Rx must be a `Spin` space. - # - # This limitation could be lifted by 'fixing' the method - # `get_conn_and_mels` to work with arbitrary hilbert spaces, which - # should be relatively straightforward to do, but we have not done so - # yet. - # """) + super().__init__(hi) + self._local_states = hi.local_states self._idx = idx self._angle = angle @@ -95,33 +89,29 @@ def to_local_operator(self): @partial(jax.vmap, in_axes=(0, None, None), out_axes=(0, 0)) -def get_conns_and_mels_Rx(sigma, idx, angle): +def get_conns_and_mels_Rx(sigma, idx, angle, local_states): assert sigma.ndim == 1 - conns = jnp.tile(sigma, (2, 1)) - conns = conns.at[1, idx].set(-conns.at[1, idx].get()) + state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) + state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) + + cons = jnp.tile(sigma, (2, 1)) + current_state = sigma[idx] + flipped_state = jnp.where(current_state == state_0, state_1, state_0) + cons = cons.at[1, idx].set(flipped_state) + mels = jnp.zeros(2, dtype=complex) mels = mels.at[0].set(jnp.cos(angle / 2)) mels = mels.at[1].set(-1j * jnp.sin(angle / 2)) - return conns, mels + return cons, mels @register_pytree_node_class class Ry(DiscreteJaxOperator): def __init__(self, hi, idx, angle): - if not isinstance(hi, Spin): - raise TypeError( - """The Hilbert space used by Rx must be a `Spin` space. - - This limitation could be lifted by 'fixing' the method - `get_conn_and_mels` to work with arbitrary hilbert spaces, which - should be relatively straightforward to do, but we have not done so - yet. - """ - ) - super().__init__(hi) + self._local_states = hi.local_states self._idx = idx self._angle = angle @@ -172,7 +162,7 @@ def tree_unflatten(cls, aux_data, children): @jax.jit def get_conn_padded(self, x): xr = x.reshape(-1, x.shape[-1]) - xp, mels = get_conns_and_mels_Ry(xr, self.idx, self.angle) + xp, mels = get_conns_and_mels_Ry(xr, self.idx, self.angle, self._lo) xp = xp.reshape(x.shape[:-1] + xp.shape[-2:]) mels = mels.reshape(x.shape[:-1] + mels.shape[-1:]) return xp, mels @@ -193,33 +183,32 @@ def to_local_operator(self): return ctheta + 1j * stheta * spin.sigmay(self.hilbert, self.idx) -@partial(jax.vmap, in_axes=(0, None, None), out_axes=(0, 0)) -def get_conns_and_mels_Ry(sigma, idx, angle): +@partial(jax.vmap, in_axes=(0, None, None, [None, None]), out_axes=(0, 0)) +def get_conns_and_mels_Ry(sigma, idx, angle, local_states): assert sigma.ndim == 1 - conns = jnp.tile(sigma, (2, 1)) - conns = conns.at[1, idx].set(-conns.at[1, idx].get()) + state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) + state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) + + cons = jnp.tile(sigma, (2, 1)) + current_state = sigma[idx] + flipped_state = jnp.where(current_state == state_0, state_1, state_0) + cons = cons.at[1, idx].set(flipped_state) mels = jnp.zeros(2, dtype=complex) mels = mels.at[0].set(jnp.cos(angle / 2)) + phase_factor = jnp.where(cons.at[0, idx].get() == local_states[0],1,-1) mels = mels.at[1].set( - (-1) ** ((conns.at[0, idx].get() + 1) / 2) * jnp.sin(angle / 2) + phase_factor * jnp.sin(angle / 2) ) - return conns, mels + return cons, mels @register_pytree_node_class class Hadamard(DiscreteJaxOperator): def __init__(self, hi, idx): - # if not isinstance(hi, Spin): - # raise TypeError("""The Hilbert space used by Rx must be a `Spin` space. - # - # This limitation could be lifted by 'fixing' the method - # `get_conn_and_mels` to work with arbitrary hilbert spaces, which - # should be relatively straightforward to do, but we have not done so - # yet. - # """) + self._local_states = hi.local_states super().__init__(hi) self._idx = idx From 5e3d8835585c7807678e2a6b95842b2b2d94df95 Mon Sep 17 00:00:00 2001 From: mboky Date: Tue, 5 Sep 2023 15:12:48 +0200 Subject: [PATCH 09/17] Ran black code formatter --- netket_fidelity/operator/singlequbit_gates.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/netket_fidelity/operator/singlequbit_gates.py b/netket_fidelity/operator/singlequbit_gates.py index 2162f16..5e4b629 100644 --- a/netket_fidelity/operator/singlequbit_gates.py +++ b/netket_fidelity/operator/singlequbit_gates.py @@ -1,13 +1,8 @@ from functools import partial - import numpy as np - import jax import jax.numpy as jnp - from jax.tree_util import register_pytree_node_class - -from netket.hilbert import Spin from netket.operator import DiscreteJaxOperator, spin @@ -197,10 +192,8 @@ def get_conns_and_mels_Ry(sigma, idx, angle, local_states): mels = jnp.zeros(2, dtype=complex) mels = mels.at[0].set(jnp.cos(angle / 2)) - phase_factor = jnp.where(cons.at[0, idx].get() == local_states[0],1,-1) - mels = mels.at[1].set( - phase_factor * jnp.sin(angle / 2) - ) + phase_factor = jnp.where(cons.at[0, idx].get() == local_states[0], 1, -1) + mels = mels.at[1].set(phase_factor * jnp.sin(angle / 2)) return cons, mels From 5cfdc64d7ce0ceaea66fe1271aa6cc74740e834f Mon Sep 17 00:00:00 2001 From: mboky Date: Tue, 5 Sep 2023 15:40:09 +0200 Subject: [PATCH 10/17] Fixed typo's in local_state argument --- netket_fidelity/operator/singlequbit_gates.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/netket_fidelity/operator/singlequbit_gates.py b/netket_fidelity/operator/singlequbit_gates.py index 5e4b629..10f1075 100644 --- a/netket_fidelity/operator/singlequbit_gates.py +++ b/netket_fidelity/operator/singlequbit_gates.py @@ -62,7 +62,7 @@ def max_conn_size(self) -> int: @jax.jit def get_conn_padded(self, x): xr = x.reshape(-1, x.shape[-1]) - xp, mels = get_conns_and_mels_Rx(xr, self.idx, self.angle) + xp, mels = get_conns_and_mels_Rx(xr, self.idx, self.angle, self._local_states) xp = xp.reshape(x.shape[:-1] + xp.shape[-2:]) mels = mels.reshape(x.shape[:-1] + mels.shape[-1:]) return xp, mels @@ -83,7 +83,7 @@ def to_local_operator(self): return ctheta - 1j * stheta * spin.sigmax(self.hilbert, self.idx) -@partial(jax.vmap, in_axes=(0, None, None), out_axes=(0, 0)) +@partial(jax.vmap, in_axes=(0, None, None, [None, None]), out_axes=(0, 0)) def get_conns_and_mels_Rx(sigma, idx, angle, local_states): assert sigma.ndim == 1 @@ -157,7 +157,7 @@ def tree_unflatten(cls, aux_data, children): @jax.jit def get_conn_padded(self, x): xr = x.reshape(-1, x.shape[-1]) - xp, mels = get_conns_and_mels_Ry(xr, self.idx, self.angle, self._lo) + xp, mels = get_conns_and_mels_Ry(xr, self.idx, self.angle, self._local_states) xp = xp.reshape(x.shape[:-1] + xp.shape[-2:]) mels = mels.reshape(x.shape[:-1] + mels.shape[-1:]) return xp, mels From 0da79bbf9e46b585798ca7ffbdd48c6cee70b0d5 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Tue, 5 Sep 2023 19:57:45 +0200 Subject: [PATCH 11/17] cleanup ruff --- .github/workflows/formatting_check.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/formatting_check.yaml b/.github/workflows/formatting_check.yaml index c173873..37a2261 100644 --- a/.github/workflows/formatting_check.yaml +++ b/.github/workflows/formatting_check.yaml @@ -46,4 +46,4 @@ jobs: with: version: 0.0.287 args: --config pyproject.toml - src: netket test Examples + src: netket test examples From 06af8be25f41efaa56cf4e922c5e5ea4cb8cab7e Mon Sep 17 00:00:00 2001 From: Mohammed Boky Date: Thu, 14 Sep 2023 19:15:28 +0200 Subject: [PATCH 12/17] Implemented small changes according to feedback Filippo. Will include operator test after returning from break. --- .pre-commit-config.yaml | 4 +-- netket_fidelity/operator/singlequbit_gates.py | 30 +++++++++---------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a605170..10b87a1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.0.282 + rev: v0.0.287 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 23.7.0 hooks: - id: black \ No newline at end of file diff --git a/netket_fidelity/operator/singlequbit_gates.py b/netket_fidelity/operator/singlequbit_gates.py index 10f1075..49349d2 100644 --- a/netket_fidelity/operator/singlequbit_gates.py +++ b/netket_fidelity/operator/singlequbit_gates.py @@ -9,9 +9,8 @@ @register_pytree_node_class class Rx(DiscreteJaxOperator): def __init__(self, hi, idx, angle): - super().__init__(hi) - self._local_states = hi.local_states + self._local_states = jnp.asarray(hi.local_states) self._idx = idx self._angle = angle @@ -90,23 +89,23 @@ def get_conns_and_mels_Rx(sigma, idx, angle, local_states): state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) - cons = jnp.tile(sigma, (2, 1)) + conns = jnp.tile(sigma, (2, 1)) current_state = sigma[idx] flipped_state = jnp.where(current_state == state_0, state_1, state_0) - cons = cons.at[1, idx].set(flipped_state) + conns = conns.at[1, idx].set(flipped_state) mels = jnp.zeros(2, dtype=complex) mels = mels.at[0].set(jnp.cos(angle / 2)) mels = mels.at[1].set(-1j * jnp.sin(angle / 2)) - return cons, mels + return conns, mels @register_pytree_node_class class Ry(DiscreteJaxOperator): def __init__(self, hi, idx, angle): super().__init__(hi) - self._local_states = hi.local_states + self._local_states = jnp.asarray(hi.local_states) self._idx = idx self._angle = angle @@ -185,25 +184,24 @@ def get_conns_and_mels_Ry(sigma, idx, angle, local_states): state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) - cons = jnp.tile(sigma, (2, 1)) + conns = jnp.tile(sigma, (2, 1)) current_state = sigma[idx] flipped_state = jnp.where(current_state == state_0, state_1, state_0) - cons = cons.at[1, idx].set(flipped_state) + conns = conns.at[1, idx].set(flipped_state) mels = jnp.zeros(2, dtype=complex) mels = mels.at[0].set(jnp.cos(angle / 2)) - phase_factor = jnp.where(cons.at[0, idx].get() == local_states[0], 1, -1) + phase_factor = jnp.where(conns.at[0, idx].get() == local_states[0], 1, -1) mels = mels.at[1].set(phase_factor * jnp.sin(angle / 2)) - return cons, mels + return conns, mels @register_pytree_node_class class Hadamard(DiscreteJaxOperator): def __init__(self, hi, idx): - - self._local_states = hi.local_states super().__init__(hi) + self._local_states = jnp.asarray(hi.local_states) self._idx = idx @property @@ -272,15 +270,15 @@ def get_conns_and_mels_Hadamard(sigma, idx, local_states): state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) - cons = jnp.tile(sigma, (2, 1)) + conns = jnp.tile(sigma, (2, 1)) current_state = sigma[idx] flipped_state = jnp.where(current_state == state_0, state_1, state_0) - cons = cons.at[1, idx].set(flipped_state) + conns = conns.at[1, idx].set(flipped_state) mels = jnp.zeros(2, dtype=float) mels = mels.at[1].set(1 / jnp.sqrt(2)) - state_value = cons.at[0, idx].get() + state_value = conns.at[0, idx].get() mels_value = jnp.where(state_value == local_states[0], 1, -1) / jnp.sqrt(2) mels = mels.at[0].set(mels_value) - return cons, mels + return conns, mels From 6f7e8648581a53c97ec53fddeca39e0b338e709a Mon Sep 17 00:00:00 2001 From: Mohammed Boky Date: Mon, 30 Oct 2023 15:14:09 +0100 Subject: [PATCH 13/17] Implemented get_conns_and_mels_... with local_states from hilbert space to support both spin and qubit spaces. Implemented additional test. --- netket_fidelity/operator/singlequbit_gates.py | 89 ++++++++----------- test/test_operator.py | 42 ++++++++- 2 files changed, 79 insertions(+), 52 deletions(-) diff --git a/netket_fidelity/operator/singlequbit_gates.py b/netket_fidelity/operator/singlequbit_gates.py index 05c142c..84ce800 100644 --- a/netket_fidelity/operator/singlequbit_gates.py +++ b/netket_fidelity/operator/singlequbit_gates.py @@ -7,24 +7,14 @@ from jax.tree_util import register_pytree_node_class -from netket.hilbert import Spin from netket.operator import DiscreteJaxOperator, spin @register_pytree_node_class class Rx(DiscreteJaxOperator): def __init__(self, hi, idx, angle): - if not isinstance(hi, Spin): - raise TypeError( - """The Hilbert space used by Rx must be a `Spin` space. - - This limitation could be lifted by 'fixing' the method - `get_conn_and_mels` to work with arbitrary hilbert spaces, which - should be relatively straightforward to do, but we have not done so - yet. - """ - ) super().__init__(hi) + self._local_states = jnp.asarray(hi.local_states) self._idx = idx self._angle = angle @@ -75,7 +65,7 @@ def max_conn_size(self) -> int: @jax.jit def get_conn_padded(self, x): xr = x.reshape(-1, x.shape[-1]) - xp, mels = get_conns_and_mels_Rx(xr, self.idx, self.angle) + xp, mels = get_conns_and_mels_Rx(xr, self.idx, self.angle, self._local_states) xp = xp.reshape(x.shape[:-1] + xp.shape[-2:]) mels = mels.reshape(x.shape[:-1] + mels.shape[-1:]) return xp, mels @@ -96,12 +86,18 @@ def to_local_operator(self): return ctheta - 1j * stheta * spin.sigmax(self.hilbert, self.idx) -@partial(jax.vmap, in_axes=(0, None, None), out_axes=(0, 0)) -def get_conns_and_mels_Rx(sigma, idx, angle): +@partial(jax.vmap, in_axes=(0, None, None, None), out_axes=(0, 0)) +def get_conns_and_mels_Rx(sigma, idx, angle, local_states): assert sigma.ndim == 1 + state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) + state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) + conns = jnp.tile(sigma, (2, 1)) - conns = conns.at[1, idx].set(-conns.at[1, idx].get()) + current_state = sigma[idx] + flipped_state = jnp.where(current_state == state_0, state_1, state_0) + conns = conns.at[1, idx].set(flipped_state) + mels = jnp.zeros(2, dtype=complex) mels = mels.at[0].set(jnp.cos(angle / 2)) mels = mels.at[1].set(-1j * jnp.sin(angle / 2)) @@ -112,18 +108,8 @@ def get_conns_and_mels_Rx(sigma, idx, angle): @register_pytree_node_class class Ry(DiscreteJaxOperator): def __init__(self, hi, idx, angle): - if not isinstance(hi, Spin): - raise TypeError( - """The Hilbert space used by Ry must be a `Spin` space. - - This limitation could be lifted by 'fixing' the method - `get_conn_and_mels` to work with arbitrary hilbert spaces, which - should be relatively straightforward to do, but we have not done so - yet. - """ - ) - super().__init__(hi) + self._local_states = jnp.asarray(hi.local_states) self._idx = idx self._angle = angle @@ -174,7 +160,7 @@ def tree_unflatten(cls, aux_data, children): @jax.jit def get_conn_padded(self, x): xr = x.reshape(-1, x.shape[-1]) - xp, mels = get_conns_and_mels_Ry(xr, self.idx, self.angle) + xp, mels = get_conns_and_mels_Ry(xr, self.idx, self.angle, self._local_states) xp = xp.reshape(x.shape[:-1] + xp.shape[-2:]) mels = mels.reshape(x.shape[:-1] + mels.shape[-1:]) return xp, mels @@ -195,18 +181,22 @@ def to_local_operator(self): return ctheta + 1j * stheta * spin.sigmay(self.hilbert, self.idx) -@partial(jax.vmap, in_axes=(0, None, None), out_axes=(0, 0)) -def get_conns_and_mels_Ry(sigma, idx, angle): +@partial(jax.vmap, in_axes=(0, None, None, None), out_axes=(0, 0)) +def get_conns_and_mels_Ry(sigma, idx, angle, local_states): assert sigma.ndim == 1 + state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) + state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) + conns = jnp.tile(sigma, (2, 1)) - conns = conns.at[1, idx].set(-conns.at[1, idx].get()) + current_state = sigma[idx] + flipped_state = jnp.where(current_state == state_0, state_1, state_0) + conns = conns.at[1, idx].set(flipped_state) mels = jnp.zeros(2, dtype=complex) mels = mels.at[0].set(jnp.cos(angle / 2)) - mels = mels.at[1].set( - (-1) ** ((conns.at[0, idx].get() + 1) / 2) * jnp.sin(angle / 2) - ) + phase_factor = jnp.where(conns.at[0, idx].get() == local_states[0], 1, -1) + mels = mels.at[1].set(phase_factor * jnp.sin(angle / 2)) return conns, mels @@ -214,18 +204,8 @@ def get_conns_and_mels_Ry(sigma, idx, angle): @register_pytree_node_class class Hadamard(DiscreteJaxOperator): def __init__(self, hi, idx): - if not isinstance(hi, Spin): - raise TypeError( - """The Hilbert space used by Hadamard must be a `Spin` space. - - This limitation could be lifted by 'fixing' the method - `get_conn_and_mels` to work with arbitrary hilbert spaces, which - should be relatively straightforward to do, but we have not done so - yet. - """ - ) - super().__init__(hi) + self._local_states = jnp.asarray(hi.local_states) self._idx = idx @property @@ -270,7 +250,7 @@ def max_conn_size(self) -> int: @jax.jit def get_conn_padded(self, x): xr = x.reshape(-1, x.shape[-1]) - xp, mels = get_conns_and_mels_Hadamard(xr, self.idx) + xp, mels = get_conns_and_mels_Hadamard(xr, self.idx, self._local_states) xp = xp.reshape(x.shape[:-1] + xp.shape[-2:]) mels = mels.reshape(x.shape[:-1] + mels.shape[-1:]) return xp, mels @@ -287,15 +267,22 @@ def get_conn_flattened(self, x, sections): return xp, mels -@partial(jax.vmap, in_axes=(0, None), out_axes=(0, 0)) -def get_conns_and_mels_Hadamard(sigma, idx): +@partial(jax.vmap, in_axes=(0, None, None), out_axes=(0, 0)) +def get_conns_and_mels_Hadamard(sigma, idx, local_states): assert sigma.ndim == 1 - cons = jnp.tile(sigma, (2, 1)) - cons = cons.at[1, idx].set(-cons.at[1, idx].get()) + state_0 = jnp.asarray(local_states[0], dtype=sigma.dtype) + state_1 = jnp.asarray(local_states[1], dtype=sigma.dtype) + + conns = jnp.tile(sigma, (2, 1)) + current_state = sigma[idx] + flipped_state = jnp.where(current_state == state_0, state_1, state_0) + conns = conns.at[1, idx].set(flipped_state) mels = jnp.zeros(2, dtype=float) mels = mels.at[1].set(1 / jnp.sqrt(2)) - mels = mels.at[0].set(((-1) ** ((cons.at[0, idx].get() + 1) / 2)) / jnp.sqrt(2)) + state_value = conns.at[0, idx].get() + mels_value = jnp.where(state_value == local_states[0], 1, -1) / jnp.sqrt(2) + mels = mels.at[0].set(mels_value) - return cons, mels + return conns, mels diff --git a/test/test_operator.py b/test/test_operator.py index 8940fcc..27eb06e 100644 --- a/test/test_operator.py +++ b/test/test_operator.py @@ -1,7 +1,8 @@ import pytest import netket as nk import numpy as np - +from netket_fidelity.operator import singlequbit_gates as sg +from jax import numpy as jnp import netket_fidelity as nkf @@ -26,3 +27,42 @@ def test_operator_dense_and_conversion(operator): np.testing.assert_allclose(op_dense, op_local_dense) assert operator.hilbert == operator.to_local_operator().hilbert + + +def test_get_conns(): + hi_spin = nk.hilbert.Spin(s=0.5, N=3) + hi_qubit = nk.hilbert.Qubit(N=3) + + local_state_spin = hi_spin.local_states + local_state_qubit = hi_qubit.local_states + + sigma_4_qubit = hi_qubit.numbers_to_states(2) + sigma_7_qubit = hi_qubit.numbers_to_states(7) + sigma_4_spin = hi_spin.numbers_to_states(2) + sigma_7_spin = hi_spin.numbers_to_states(7) + + sigma_qubit = jnp.array([sigma_4_qubit, sigma_7_qubit]) + sigma_spin = jnp.array([sigma_4_spin, sigma_7_spin]) + + conns_rx_qubit, _ = sg.get_conns_and_mels_Rx(sigma_qubit, 0, 0, local_state_qubit) + conns_ry_qubit, _ = sg.get_conns_and_mels_Ry(sigma_qubit, 0, 0, local_state_qubit) + conns_h_qubit, _ = sg.get_conns_and_mels_Hadamard(sigma_qubit, 0, local_state_qubit) + + conns_rx_spin, _ = sg.get_conns_and_mels_Rx(sigma_spin, 0, 0, local_state_spin) + conns_ry_spin, _ = sg.get_conns_and_mels_Ry(sigma_spin, 0, 0, local_state_spin) + conns_h_spin, _ = sg.get_conns_and_mels_Hadamard(sigma_spin, 0, local_state_spin) + + values_check_qubit = jnp.array( + [[[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]], [[1.0, 1.0, 1.0], [0.0, 1.0, 1.0]]] + ) + + values_check_spin = jnp.array( + [[[-1.0, 1.0, -1.0], [1.0, 1.0, -1.0]], [[1.0, 1.0, 1.0], [-1.0, 1.0, 1.0]]] + ) + + assert (conns_rx_qubit == values_check_qubit).all() + assert (conns_ry_qubit == values_check_qubit).all() + assert (conns_h_qubit == values_check_qubit).all() + assert (conns_rx_spin == values_check_spin).all() + assert (conns_ry_spin == values_check_spin).all() + assert (conns_h_spin == values_check_spin).all() From 749f2b5b178a8309eb783fba9160748f0a184967 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Tue, 31 Oct 2023 18:44:33 +0100 Subject: [PATCH 14/17] bump black ruff versions --- .github/workflows/formatting_check.yaml | 4 ++-- .pre-commit-config.yaml | 6 +++--- pyproject.toml | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/formatting_check.yaml b/.github/workflows/formatting_check.yaml index 37a2261..43ca0cd 100644 --- a/.github/workflows/formatting_check.yaml +++ b/.github/workflows/formatting_check.yaml @@ -25,7 +25,7 @@ jobs: - name: Pip install python dependencies run: | python -m pip install --upgrade pip - pip install -v black==23.7.0 + pip install -v black==23.10.1 - name: Black Code Formatter run: black --check --diff --color . @@ -44,6 +44,6 @@ jobs: - name: Set up Python 3.10 uses: chartboost/ruff-action@v1 with: - version: 0.0.287 + version: 0.1.3 args: --config pyproject.toml src: netket test examples diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 10b87a1..48c15d4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.0.287 + rev: v0.1.3 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 23.10.1 hooks: - - id: black \ No newline at end of file + - id: black diff --git a/pyproject.toml b/pyproject.toml index e0b9b66..f486148 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,8 +20,8 @@ dev = [ "pytest-json-report>=1.3", "coverage>=5", "pre-commit>=2.7", - "black==23.7.0", - "ruff==0.0.287", + "black==23.10.1", + "ruff==0.1.3", "wheel", "build", "qutip", From 99f2f35cb3e982bdc1d18213191ba7d6802ab939 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Tue, 31 Oct 2023 18:44:39 +0100 Subject: [PATCH 15/17] format with black --- .../dynamics_with_measurements/RBM_Jastrow_measurement.py | 1 - examples/state_learning/state_learning.py | 8 +++----- netket_fidelity/driver/infidelity_optimizer.py | 4 ++-- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/examples/dynamics_with_measurements/RBM_Jastrow_measurement.py b/examples/dynamics_with_measurements/RBM_Jastrow_measurement.py index 8881083..d80f2cf 100644 --- a/examples/dynamics_with_measurements/RBM_Jastrow_measurement.py +++ b/examples/dynamics_with_measurements/RBM_Jastrow_measurement.py @@ -15,7 +15,6 @@ @partial(jax.vmap, in_axes=(0, None, None, None), out_axes=(0)) def spwf(sigma, orbital_up, orbital_down, N): - return 0.5 * (1 + sigma) * orbital_up + 0.5 * (1 - sigma) * orbital_down diff --git a/examples/state_learning/state_learning.py b/examples/state_learning/state_learning.py index 32ae714..598a3f5 100644 --- a/examples/state_learning/state_learning.py +++ b/examples/state_learning/state_learning.py @@ -39,13 +39,11 @@ # or just use the Infidelity optimisation driver optimizer = nk.optimizer.Adam() -driver = nkf.driver.InfidelityOptimizer( - vs_target, optimizer, variational_state=vs -) +driver = nkf.driver.InfidelityOptimizer(vs_target, optimizer, variational_state=vs) log = nk.logging.RuntimeLog() driver.run(300, out=log) plt.ion() -plt.semilogy(log.data['Infidelity'].iters, log.data['Infidelity']) -plt.show() \ No newline at end of file +plt.semilogy(log.data["Infidelity"].iters, log.data["Infidelity"]) +plt.show() diff --git a/netket_fidelity/driver/infidelity_optimizer.py b/netket_fidelity/driver/infidelity_optimizer.py index e034213..8a56e66 100644 --- a/netket_fidelity/driver/infidelity_optimizer.py +++ b/netket_fidelity/driver/infidelity_optimizer.py @@ -39,7 +39,7 @@ def __init__( U_dagger=None, preconditioner: PreconditionerT = identity_preconditioner, is_unitary=False, - sample_Upsi=False, + sample_Upsi=False, cv_coeff=-0.5, ): r""" @@ -135,7 +135,7 @@ def __init__( U_dagger=U_dagger, is_unitary=is_unitary, cv_coeff=cv_coeff, - sample_Upsi=sample_Upsi, + sample_Upsi=sample_Upsi, ) def _forward_and_backward(self): From f32dc1375cbf0213d9fef22fc4a340954fdf2902 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Tue, 31 Oct 2023 18:47:38 +0100 Subject: [PATCH 16/17] fixes --- .github/workflows/formatting_check.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/formatting_check.yaml b/.github/workflows/formatting_check.yaml index 43ca0cd..f51069a 100644 --- a/.github/workflows/formatting_check.yaml +++ b/.github/workflows/formatting_check.yaml @@ -46,4 +46,4 @@ jobs: with: version: 0.1.3 args: --config pyproject.toml - src: netket test examples + src: netket_fidelity test examples From 8d0044c91bff83fd22434250d1193ae9f059a858 Mon Sep 17 00:00:00 2001 From: Mohammed Boky Date: Tue, 5 Dec 2023 13:36:00 +0100 Subject: [PATCH 17/17] Implemented feedback on conns test: -better naming convention -np.testing.assert_allclose instead of assert Addition -matrix element check. --- test/test_operator.py | 82 ++++++++++++++++++++++++++++++++----------- 1 file changed, 62 insertions(+), 20 deletions(-) diff --git a/test/test_operator.py b/test/test_operator.py index 316c316..9ce3dbc 100644 --- a/test/test_operator.py +++ b/test/test_operator.py @@ -3,7 +3,6 @@ import numpy as np from netket_fidelity.operator import singlequbit_gates as sg from jax import numpy as jnp - import netket_fidelity as nkf @@ -29,40 +28,83 @@ def test_operator_dense_and_conversion(operator): assert operator.hilbert == operator.to_local_operator().hilbert -def test_get_conns(): +def test_get_conns_and_mels(): hi_spin = nk.hilbert.Spin(s=0.5, N=3) hi_qubit = nk.hilbert.Qubit(N=3) local_state_spin = hi_spin.local_states local_state_qubit = hi_qubit.local_states - sigma_4_qubit = hi_qubit.numbers_to_states(2) + sigma_2_qubit = hi_qubit.numbers_to_states(2) sigma_7_qubit = hi_qubit.numbers_to_states(7) - sigma_4_spin = hi_spin.numbers_to_states(2) + sigma_2_spin = hi_spin.numbers_to_states(2) sigma_7_spin = hi_spin.numbers_to_states(7) - sigma_qubit = jnp.array([sigma_4_qubit, sigma_7_qubit]) - sigma_spin = jnp.array([sigma_4_spin, sigma_7_spin]) + sigma_qubit = jnp.array([sigma_2_qubit, sigma_7_qubit]) + sigma_spin = jnp.array([sigma_2_spin, sigma_7_spin]) - conns_rx_qubit, _ = sg.get_conns_and_mels_Rx(sigma_qubit, 0, 0, local_state_qubit) - conns_ry_qubit, _ = sg.get_conns_and_mels_Ry(sigma_qubit, 0, 0, local_state_qubit) - conns_h_qubit, _ = sg.get_conns_and_mels_Hadamard(sigma_qubit, 0, local_state_qubit) + conns_rx_qubit, mels_rx_qubit = sg.get_conns_and_mels_Rx( + sigma_qubit, 0, np.pi / 2, local_state_qubit + ) + conns_ry_qubit, mels_ry_qubit = sg.get_conns_and_mels_Ry( + sigma_qubit, 0, np.pi / 2, local_state_qubit + ) + conns_h_qubit, mels_h_qubit = sg.get_conns_and_mels_Hadamard( + sigma_qubit, 0, local_state_qubit + ) - conns_rx_spin, _ = sg.get_conns_and_mels_Rx(sigma_spin, 0, 0, local_state_spin) - conns_ry_spin, _ = sg.get_conns_and_mels_Ry(sigma_spin, 0, 0, local_state_spin) - conns_h_spin, _ = sg.get_conns_and_mels_Hadamard(sigma_spin, 0, local_state_spin) + conns_rx_spin, mels_rx_spin = sg.get_conns_and_mels_Rx( + sigma_spin, 0, np.pi / 2, local_state_spin + ) + conns_ry_spin, mels_ry_spin = sg.get_conns_and_mels_Ry( + sigma_spin, 0, np.pi / 2, local_state_spin + ) + conns_h_spin, mels_h_spin = sg.get_conns_and_mels_Hadamard( + sigma_spin, 0, local_state_spin + ) - values_check_qubit = jnp.array( + conns_check_qubit = jnp.array( [[[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]], [[1.0, 1.0, 1.0], [0.0, 1.0, 1.0]]] ) - values_check_spin = jnp.array( + conns_check_spin = jnp.array( [[[-1.0, 1.0, -1.0], [1.0, 1.0, -1.0]], [[1.0, 1.0, 1.0], [-1.0, 1.0, 1.0]]] ) - assert (conns_rx_qubit == values_check_qubit).all() - assert (conns_ry_qubit == values_check_qubit).all() - assert (conns_h_qubit == values_check_qubit).all() - assert (conns_rx_spin == values_check_spin).all() - assert (conns_ry_spin == values_check_spin).all() - assert (conns_h_spin == values_check_spin).all() + mels_check_qubit_rx = jnp.array( + [[0.70710678 + 0.0j, 0.0 - 0.70710678j], [0.70710678 + 0.0j, 0.0 - 0.70710678j]] + ) + mels_check_qubit_ry = jnp.array( + [ + [0.70710678 + 0.0j, 0.70710678 + 0.0j], + [0.70710678 + 0.0j, -0.70710678 + 0.0j], + ] + ) + mels_check_qubit_h = jnp.array( + [[0.70710678, 0.70710678], [-0.70710678, 0.70710678]] + ) + + mels_check_spin_rx = jnp.array( + [[0.70710678 + 0.0j, 0.0 - 0.70710678j], [0.70710678 + 0.0j, 0.0 - 0.70710678j]] + ) + mels_check_spin_ry = jnp.array( + [ + [0.70710678 + 0.0j, 0.70710678 + 0.0j], + [0.70710678 + 0.0j, -0.70710678 + 0.0j], + ] + ) + mels_check_spin_h = jnp.array([[0.70710678, 0.70710678], [-0.70710678, 0.70710678]]) + + np.testing.assert_allclose(conns_rx_qubit, conns_check_qubit) + np.testing.assert_allclose(conns_ry_qubit, conns_check_qubit) + np.testing.assert_allclose(conns_h_qubit, conns_check_qubit) + np.testing.assert_allclose(conns_rx_spin, conns_check_spin) + np.testing.assert_allclose(conns_ry_spin, conns_check_spin) + np.testing.assert_allclose(conns_h_spin, conns_check_spin) + + np.testing.assert_allclose(mels_rx_qubit, mels_check_qubit_rx) + np.testing.assert_allclose(mels_ry_qubit, mels_check_qubit_ry) + np.testing.assert_allclose(mels_h_qubit, mels_check_qubit_h) + np.testing.assert_allclose(mels_rx_spin, mels_check_spin_rx) + np.testing.assert_allclose(mels_ry_spin, mels_check_spin_ry) + np.testing.assert_allclose(mels_h_spin, mels_check_spin_h)