Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 63 additions & 28 deletions qualtran/_infra/composite_bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
Mapping,
Optional,
overload,
Protocol,
Sequence,
Set,
Tuple,
TYPE_CHECKING,
TypeGuard,
TypeVar,
Union,
)
Expand All @@ -55,13 +57,30 @@
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics import SymbolicInt

# NDArrays must be bound to np.generic
_SoquetType = TypeVar('_SoquetType', bound=np.generic)

SoquetT = Union[Soquet, NDArray[_SoquetType]]
"""A `Soquet` or array of soquets."""
class SoquetT(Protocol):
"""Either a Soquet or an array thereof.

SoquetInT = Union[Soquet, NDArray[_SoquetType], Sequence[Soquet]]
To narrow objects of this type, use `BloqBuilder.is_single(soq)` and/or
`BloqBuilder.is_ndarray(soqs)`.

Example:
>>> soq_or_soqs: SoquetT
... if BloqBuilder.is_ndarray(soq_or_soqs):
... first_soq = soq_or_soqs.reshape(-1).item(0)
... else:
... # Note: `.item()` raises if not a single item.
... first_soq = soq_or_soqs.item()

"""

@property
def shape(self) -> Tuple[int, ...]: ...

def item(self, *args) -> Soquet: ...


SoquetInT = Union[SoquetT, Sequence[SoquetT]]
"""A soquet or array-like of soquets.

This type alias is used for input argument to parts of the library that are more
Expand Down Expand Up @@ -693,9 +712,10 @@ def _flatten_soquet_collection(vals: Iterable[SoquetT]) -> List[Soquet]:
"""
soqvals = []
for soq_or_arr in vals:
if isinstance(soq_or_arr, Soquet):
soqvals.append(soq_or_arr)
if BloqBuilder.is_single(soq_or_arr):
soqvals.append(soq_or_arr.item())
else:
assert BloqBuilder.is_ndarray(soq_or_arr)
soqvals.extend(soq_or_arr.reshape(-1))
return soqvals

Expand Down Expand Up @@ -802,13 +822,10 @@ def _process_soquets(
unchecked_names.remove(reg.name) # so we can check for surplus arguments.

for li in reg.all_idxs():
idxed_soq = in_soq[li]
assert isinstance(idxed_soq, Soquet), idxed_soq
idxed_soq = in_soq[li].item()
func(idxed_soq, reg, li)
if not check_dtypes_consistent(idxed_soq.reg.dtype, reg.dtype):
extra_str = (
f"{idxed_soq.reg.name}: {idxed_soq.reg.dtype} vs {reg.name}: {reg.dtype}"
)
if not check_dtypes_consistent(idxed_soq.dtype, reg.dtype):
extra_str = f"{idxed_soq.reg.name}: {idxed_soq.dtype} vs {reg.name}: {reg.dtype}"
raise BloqError(
f"{debug_str} register dtypes are not consistent {extra_str}."
) from None
Expand Down Expand Up @@ -838,9 +855,9 @@ def _map_soqs(
# First: flatten out any numpy arrays
flat_soq_map: Dict[Soquet, Soquet] = {}
for old_soqs, new_soqs in soq_map:
if isinstance(old_soqs, Soquet):
assert isinstance(new_soqs, Soquet), new_soqs
flat_soq_map[old_soqs] = new_soqs
if BloqBuilder.is_single(old_soqs):
assert BloqBuilder.is_single(new_soqs), new_soqs
flat_soq_map[old_soqs] = new_soqs.item()
continue

assert isinstance(old_soqs, np.ndarray), old_soqs
Expand All @@ -858,9 +875,9 @@ def _map_soq(soq: Soquet) -> Soquet:
vmap = np.vectorize(_map_soq, otypes=[object])

def _map_soqs(soqs: SoquetT) -> SoquetT:
if isinstance(soqs, Soquet):
return _map_soq(soqs)
return vmap(soqs)
if BloqBuilder.is_ndarray(soqs):
return vmap(soqs)
return _map_soq(soqs.item())

return {name: _map_soqs(soqs) for name, soqs in soqs.items()}

Expand Down Expand Up @@ -1061,6 +1078,24 @@ def from_signature(

return bb, initial_soqs

@staticmethod
def is_single(x: 'SoquetT') -> TypeGuard['Soquet']:
"""Returns True if `x` is a single soquet (not an ndarray of them).

This doesn't use stringent runtime type checking; it uses the SoquetT protocol
for "duck typing".
"""
return x.shape == ()

@staticmethod
def is_ndarray(x: 'SoquetT') -> TypeGuard['NDArray']:
"""Returns True if `x` is an ndarray of soquets (not a single one).

This doesn't use stringent runtime type checking; it uses the SoquetT protocol
for "duck typing".
"""
return x.shape != ()

@staticmethod
def map_soqs(
soqs: Dict[str, SoquetT], soq_map: Iterable[Tuple[SoquetT, SoquetT]]
Expand Down Expand Up @@ -1265,8 +1300,7 @@ def add_from(self, bloq: Bloq, **in_soqs: SoquetInT) -> Tuple[SoquetT, ...]:
cbloq = bloq.decompose_bloq()

for k, v in in_soqs.items():
if not isinstance(v, Soquet):
in_soqs[k] = np.asarray(v)
in_soqs[k] = np.asarray(v)

# Initial mapping of LeftDangle according to user-provided in_soqs.
soq_map: List[Tuple[SoquetT, SoquetT]] = [
Expand Down Expand Up @@ -1306,12 +1340,13 @@ def finalize(self, **final_soqs: SoquetT) -> CompositeBloq:

def _infer_reg(name: str, soq: SoquetT) -> Register:
"""Go from Soquet -> register, but use a specific name for the register."""
if isinstance(soq, Soquet):
return Register(name=name, dtype=soq.reg.dtype, side=Side.RIGHT)
if BloqBuilder.is_single(soq):
return Register(name=name, dtype=soq.dtype, side=Side.RIGHT)
assert BloqBuilder.is_ndarray(soq)

# Get info from 0th soquet in an ndarray.
return Register(
name=name, dtype=soq.reshape(-1)[0].reg.dtype, shape=soq.shape, side=Side.RIGHT
name=name, dtype=soq.reshape(-1).item(0).dtype, shape=soq.shape, side=Side.RIGHT
)

right_reg_names = [reg.name for reg in self._regs if reg.side & Side.RIGHT]
Expand Down Expand Up @@ -1358,10 +1393,10 @@ def allocate(
def free(self, soq: Soquet, dirty: bool = False) -> None:
from qualtran.bloqs.bookkeeping import Free

if not isinstance(soq, Soquet):
if not BloqBuilder.is_single(soq):
raise ValueError("`free` expects a single Soquet to free.")

qdtype = soq.reg.dtype
qdtype = soq.dtype
if not isinstance(qdtype, QDType):
raise ValueError("`free` can only free quantum registers.")

Expand All @@ -1371,10 +1406,10 @@ def split(self, soq: Soquet) -> NDArray[Soquet]: # type: ignore[type-var]
"""Add a Split bloq to split up a register."""
from qualtran.bloqs.bookkeeping import Split

if not isinstance(soq, Soquet):
if not BloqBuilder.is_single(soq):
raise ValueError("`split` expects a single Soquet to split.")

qdtype = soq.reg.dtype
qdtype = soq.dtype
if not isinstance(qdtype, QDType):
raise ValueError("`split` can only split quantum registers.")

Expand Down
37 changes: 36 additions & 1 deletion qualtran/_infra/composite_bloq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
# limitations under the License.

from functools import cached_property
from typing import Dict, List, Tuple
from typing import cast, Dict, List, Tuple

import attrs
import networkx as nx
import numpy as np
import pytest
import sympy
from numpy.typing import NDArray
from typing_extensions import assert_type

import qualtran.testing as qlt_testing
from qualtran import (
Expand Down Expand Up @@ -643,6 +644,40 @@ def test_get_soquet():
_ = _get_soquet(binst=binst, reg_name='in', right=True, binst_graph=binst_graph)


def test_can_tell_individual_from_ndsoquet():
s1 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(0,))
s2 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(1,))
s3 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(2,))
s4 = Soquet(cast(BloqInstance, None), Register('test', QBit(), shape=(4,)), idx=(3,))

# A ndarray of soquet objects should be SoquetT and we can tell by checking its shape.
ndsoq: SoquetT = np.array([s1, s2, s3, s4])
assert_type(ndsoq, SoquetT)
assert ndsoq.shape
assert ndsoq.shape == (4,)
assert ndsoq.item(2) == s3
with pytest.raises(ValueError, match=r'scalar'):
_ = ndsoq.item()

# A single soquet is still a valid SoquetT, and it has a false-y shape.
single_soq: SoquetT = s1
assert_type(single_soq, SoquetT)
assert not single_soq.shape
assert single_soq.shape == ()
single_soq_unwarp = single_soq.item()
assert single_soq_unwarp == s1

# A single soquet wrapped in a 0-dim ndarray is ok if you call `item()`.
single_soq2: SoquetT = np.asarray(s1)
assert_type(single_soq2, SoquetT)
assert not single_soq2.shape
assert single_soq2.shape == ()
single_soq2_unwrap = single_soq2.item()
assert hash(single_soq2_unwrap) == hash(s1)
assert single_soq2_unwrap == s1
assert isinstance(single_soq2_unwrap, Soquet)


@pytest.mark.notebook
def test_notebook():
qlt_testing.execute_notebook('composite_bloq')
16 changes: 15 additions & 1 deletion qualtran/_infra/quantum_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from attrs import field, frozen

if TYPE_CHECKING:
from qualtran import Bloq, Register
from qualtran import Bloq, BloqBuilder, QCDType, Register


@frozen
Expand Down Expand Up @@ -103,6 +103,20 @@ def _check_idx(self, attribute, value):
for i, shape in zip(value, self.reg.shape):
if i >= shape:
raise ValueError(f"Bad index {i} for {self.reg}.")
return value

@property
def dtype(self) -> 'QCDType':
return self.reg.dtype

@property
def shape(self) -> Tuple[int, ...]:
return ()

def item(self, *args) -> 'Soquet':
if args:
raise ValueError("Tried to index into a single soquet.")
return self

def pretty(self) -> str:
label = self.reg.name
Expand Down
3 changes: 3 additions & 0 deletions qualtran/_infra/quantum_graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def test_soquet():
assert soq.idx == ()
assert soq.pretty() == 'x'

assert soq.item() == soq
assert soq.dtype == QAny(10)


def test_soquet_idxed():
binst = BloqInstance(TestTwoBitOp(), i=0)
Expand Down
3 changes: 1 addition & 2 deletions qualtran/bloqs/basic_gates/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,7 @@ def signature(self) -> 'Signature':
def build_composite_bloq(self, bb: 'BloqBuilder', q: 'SoquetT') -> Dict[str, 'SoquetT']:
from qualtran.bloqs.mcmt import And

q1, q2 = q # type: ignore
(q1, q2), anc = bb.add(And(), ctrl=[q1, q2])
(q1, q2), anc = bb.add(And(), ctrl=q)
anc = bb.add(ZPowGate(self.exponent, eps=self.eps), q=anc)
(q1, q2) = bb.add(And().adjoint(), ctrl=[q1, q2], target=anc)
return {'q': np.array([q1, q2])}
Expand Down
3 changes: 1 addition & 2 deletions qualtran/bloqs/block_encoding/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,7 @@ def build_composite_bloq(
if is_symbolic(self.system_bitsize) or is_symbolic(self.row_oracle.num_nonzero):
raise DecomposeTypeError(f"Cannot decompose symbolic {self=}")

assert not isinstance(ancilla, np.ndarray)
ancilla_bits = bb.split(ancilla)
ancilla_bits = bb.split(ancilla.item())
q, l = ancilla_bits[0], bb.join(ancilla_bits[1:])

l = bb.add(self.diffusion, target=l)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ def _reshape_reg(
"""
# np.prod(()) returns a float (1.0), so take int
size = int(np.prod(out_shape))
if isinstance(in_reg, np.ndarray):
if BloqBuilder.is_ndarray(in_reg):
# split an array of bitsize qubits into flat list of qubits
split_qubits = bb.split(bb.join(np.concatenate([bb.split(x) for x in in_reg.ravel()])))
else:
split_qubits = bb.split(in_reg)
split_qubits = bb.split(in_reg.item())
merged_qubits = np.array(
[bb.join(split_qubits[i * bitsize : (i + 1) * bitsize]) for i in range(size)]
)
Expand Down
5 changes: 2 additions & 3 deletions qualtran/bloqs/chemistry/trotter/grid_ham/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
QAny,
Register,
Signature,
Soquet,
SoquetT,
)
from qualtran._infra.data_types import BQUInt
Expand Down Expand Up @@ -93,7 +92,7 @@ def wire_symbol(
def build_composite_bloq(
self, bb: BloqBuilder, *, system_i: SoquetT, system_j: SoquetT
) -> Dict[str, SoquetT]:
if isinstance(system_i, Soquet) or isinstance(system_j, Soquet):
if not (BloqBuilder.is_ndarray(system_i) and BloqBuilder.is_ndarray(system_j)):
raise ValueError("system_i and system_j must be numpy arrays of Soquet")
# compute r_i - r_j
# r_i + (-r_j), in practice we need to flip the sign bit, but this is just 3 cliffords.
Expand Down Expand Up @@ -227,7 +226,7 @@ def wire_symbol(
return super().wire_symbol(reg, idx)

def build_composite_bloq(self, bb: BloqBuilder, *, system: SoquetT) -> Dict[str, SoquetT]:
if isinstance(system, Soquet):
if not BloqBuilder.is_ndarray(system):
raise ValueError("system must be a numpy array of Soquet")
bitsize = (self.num_grid - 1).bit_length() + 1
ij_pairs = np.triu_indices(self.num_elec, k=1)
Expand Down
4 changes: 2 additions & 2 deletions qualtran/bloqs/data_loading/qroam_clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,8 @@ def build_composite_bloq(self, bb: 'BloqBuilder', **soqs: 'SoquetT') -> Dict[str
# Construct and return dictionary of final soquets.
soqs |= {reg.name: soq for reg, soq in zip(self.control_registers, ctrl)}
soqs |= {reg.name: soq for reg, soq in zip(self.selection_registers, selection)}
soqs |= {reg.name: soq.flat[1:] for reg, soq in zip(self.junk_registers, qrom_targets)} # type: ignore[union-attr]
soqs |= {reg.name: soq.flat[0] for reg, soq in zip(self.target_registers, qrom_targets)} # type: ignore[union-attr]
soqs |= {reg.name: soq.flat[1:] for reg, soq in zip(self.junk_registers, qrom_targets)} # type: ignore[attr-defined]
soqs |= {reg.name: soq.flat[0] for reg, soq in zip(self.target_registers, qrom_targets)} # type: ignore[attr-defined]
return soqs

def on_classical_vals(
Expand Down
9 changes: 3 additions & 6 deletions qualtran/bloqs/for_testing/with_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,13 @@
# limitations under the License.

from functools import cached_property
from typing import Dict, TYPE_CHECKING
from typing import Dict

from attrs import frozen

from qualtran import Bloq, BloqBuilder, Signature, Soquet
from qualtran import Bloq, BloqBuilder, Signature, SoquetT
from qualtran.bloqs.for_testing.atom import TestAtom

if TYPE_CHECKING:
from qualtran import SoquetT


@frozen
class TestSerialCombo(Bloq):
Expand All @@ -47,7 +44,7 @@ def signature(self) -> Signature:
return Signature.build(reg=3)

def build_composite_bloq(self, bb: 'BloqBuilder', reg: 'SoquetT') -> Dict[str, 'SoquetT']:
assert isinstance(reg, Soquet)
assert BloqBuilder.is_single(reg)
reg = bb.split(reg)
for i in range(len(reg)):
reg[i] = bb.add(TestAtom(), q=reg[i])
Expand Down
Loading