|
| 1 | +# Copyright 2024 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +from collections import Counter |
| 15 | +from typing import Union |
| 16 | + |
| 17 | +import attrs |
| 18 | +import numpy as np |
| 19 | +from attrs import frozen |
| 20 | + |
| 21 | +from qualtran import ( |
| 22 | + AddControlledT, |
| 23 | + Bloq, |
| 24 | + bloq_example, |
| 25 | + BloqBuilder, |
| 26 | + BloqDocSpec, |
| 27 | + CtrlSpec, |
| 28 | + QBit, |
| 29 | + QInt, |
| 30 | + QUInt, |
| 31 | + Register, |
| 32 | + Signature, |
| 33 | + Soquet, |
| 34 | + SoquetT, |
| 35 | +) |
| 36 | +from qualtran.bloqs.arithmetic import LinearDepthHalfLessThan |
| 37 | +from qualtran.bloqs.basic_gates import CNOT, XGate |
| 38 | +from qualtran.bloqs.mcmt import MultiControlX |
| 39 | +from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator |
| 40 | +from qualtran.simulation.classical_sim import ClassicalValT |
| 41 | +from qualtran.symbolics import HasLength, is_symbolic, SymbolicInt |
| 42 | + |
| 43 | + |
| 44 | +@frozen |
| 45 | +class HasDuplicates(Bloq): |
| 46 | + r"""Given a sorted list of `l` numbers, check if it contains any duplicates. |
| 47 | +
|
| 48 | + Produces a single qubit which is `1` if there are duplicates, and `0` if all are disjoint. |
| 49 | + It compares every adjacent pair, and therefore uses `l - 1` comparisons. |
| 50 | + It then uses a single MCX on `l - 1` bits gate to compute the flag. |
| 51 | +
|
| 52 | + Args: |
| 53 | + l: number of elements in the list |
| 54 | + dtype: type of each element to store `[n]`. |
| 55 | +
|
| 56 | + Registers: |
| 57 | + xs: a list of `l` registers of `dtype`. |
| 58 | + flag: single qubit. Value is flipped if the input list has duplicates, otherwise stays same. |
| 59 | +
|
| 60 | + References: |
| 61 | + [Quartic quantum speedups for planted inference](https://arxiv.org/abs/2406.19378v1) |
| 62 | + Lemma 4.12. Eq. 122. |
| 63 | + """ |
| 64 | + |
| 65 | + l: SymbolicInt |
| 66 | + dtype: Union[QUInt, QInt] |
| 67 | + is_controlled: bool = False |
| 68 | + |
| 69 | + @property |
| 70 | + def signature(self) -> 'Signature': |
| 71 | + registers = [Register('xs', self.dtype, shape=(self.l,)), Register('flag', QBit())] |
| 72 | + if self.is_controlled: |
| 73 | + registers.append(Register('ctrl', QBit())) |
| 74 | + return Signature(registers) |
| 75 | + |
| 76 | + @property |
| 77 | + def _le_bloq(self) -> LinearDepthHalfLessThan: |
| 78 | + return LinearDepthHalfLessThan(self.dtype) |
| 79 | + |
| 80 | + def build_composite_bloq( |
| 81 | + self, bb: 'BloqBuilder', xs: 'SoquetT', flag: 'Soquet', **extra_soqs: 'SoquetT' |
| 82 | + ) -> dict[str, 'SoquetT']: |
| 83 | + assert not is_symbolic(self.l) |
| 84 | + assert isinstance(xs, np.ndarray) |
| 85 | + |
| 86 | + cs = [] |
| 87 | + oks = [] |
| 88 | + if self.is_controlled: |
| 89 | + oks = [extra_soqs.pop('ctrl')] |
| 90 | + assert not extra_soqs |
| 91 | + |
| 92 | + for i in range(1, self.l): |
| 93 | + xs[i - 1], xs[i], c, ok = bb.add(self._le_bloq, a=xs[i - 1], b=xs[i]) |
| 94 | + cs.append(c) |
| 95 | + oks.append(ok) |
| 96 | + |
| 97 | + oks, flag = bb.add(MultiControlX((1,) * len(oks)), controls=np.array(oks), target=flag) |
| 98 | + if not self.is_controlled: |
| 99 | + flag = bb.add(XGate(), q=flag) |
| 100 | + else: |
| 101 | + oks[0], flag = bb.add(CNOT(), ctrl=oks[0], target=flag) |
| 102 | + |
| 103 | + oks = list(oks) |
| 104 | + for i in reversed(range(1, self.l)): |
| 105 | + xs[i - 1], xs[i] = bb.add( |
| 106 | + self._le_bloq.adjoint(), a=xs[i - 1], b=xs[i], c=cs.pop(), target=oks.pop() |
| 107 | + ) |
| 108 | + |
| 109 | + if self.is_controlled: |
| 110 | + extra_soqs = {'ctrl': oks.pop()} |
| 111 | + assert not oks |
| 112 | + |
| 113 | + return {'xs': xs, 'flag': flag} | extra_soqs |
| 114 | + |
| 115 | + def build_call_graph(self, ssa: 'SympySymbolAllocator') -> BloqCountDictT: |
| 116 | + counts = Counter[Bloq]() |
| 117 | + |
| 118 | + counts[self._le_bloq] += self.l - 1 |
| 119 | + counts[self._le_bloq.adjoint()] += self.l - 1 |
| 120 | + |
| 121 | + n_ctrls = self.l - (1 if not self.is_controlled else 0) |
| 122 | + counts[MultiControlX(HasLength(n_ctrls))] += 1 |
| 123 | + |
| 124 | + counts[XGate() if not self.is_controlled else CNOT()] += 1 |
| 125 | + |
| 126 | + return counts |
| 127 | + |
| 128 | + def on_classical_vals(self, **vals: 'ClassicalValT') -> dict[str, 'ClassicalValT']: |
| 129 | + xs = np.asarray(vals['xs']) |
| 130 | + assert np.all(xs == np.sort(xs)) |
| 131 | + if np.any(xs[:-1] == xs[1:]): |
| 132 | + vals['flag'] ^= 1 |
| 133 | + return vals |
| 134 | + |
| 135 | + def adjoint(self) -> 'HasDuplicates': |
| 136 | + return self |
| 137 | + |
| 138 | + def get_ctrl_system(self, ctrl_spec: 'CtrlSpec') -> tuple['Bloq', 'AddControlledT']: |
| 139 | + from qualtran.bloqs.mcmt.specialized_ctrl import get_ctrl_system_1bit_cv_from_bloqs |
| 140 | + |
| 141 | + return get_ctrl_system_1bit_cv_from_bloqs( |
| 142 | + self, |
| 143 | + ctrl_spec, |
| 144 | + current_ctrl_bit=1 if self.is_controlled else None, |
| 145 | + bloq_with_ctrl=attrs.evolve(self, is_controlled=True), |
| 146 | + ctrl_reg_name='ctrl', |
| 147 | + ) |
| 148 | + |
| 149 | + |
| 150 | +@bloq_example |
| 151 | +def _has_duplicates() -> HasDuplicates: |
| 152 | + has_duplicates = HasDuplicates(4, QUInt(3)) |
| 153 | + return has_duplicates |
| 154 | + |
| 155 | + |
| 156 | +@bloq_example |
| 157 | +def _has_duplicates_symb() -> HasDuplicates: |
| 158 | + import sympy |
| 159 | + |
| 160 | + n = sympy.Symbol("n") |
| 161 | + has_duplicates_symb = HasDuplicates(4, QUInt(n)) |
| 162 | + return has_duplicates_symb |
| 163 | + |
| 164 | + |
| 165 | +@bloq_example |
| 166 | +def _has_duplicates_symb_len() -> HasDuplicates: |
| 167 | + import sympy |
| 168 | + |
| 169 | + l, n = sympy.symbols("l n") |
| 170 | + has_duplicates_symb_len = HasDuplicates(l, QUInt(n)) |
| 171 | + return has_duplicates_symb_len |
| 172 | + |
| 173 | + |
| 174 | +_HAS_DUPLICATES_DOC = BloqDocSpec( |
| 175 | + bloq_cls=HasDuplicates, |
| 176 | + examples=[_has_duplicates_symb, _has_duplicates, _has_duplicates_symb_len], |
| 177 | +) |
0 commit comments