Skip to content

Commit

Permalink
Allow symbolic scalars in LinearDict (quantumlib#7003)
Browse files Browse the repository at this point in the history
* Parameterize LinearDict

* Add protocol handlers

* fix protocol handlers

* fix protocol handlers, add test

* format

* tests

* Y, Z gates, parameterize test

* mypy

* Revert changes to Scalar, and just use TParamValComplex everywhere. Add tests.

* nits
  • Loading branch information
daxfohl authored Feb 4, 2025
1 parent 3f67923 commit 5c198ce
Show file tree
Hide file tree
Showing 5 changed files with 197 additions and 37 deletions.
21 changes: 11 additions & 10 deletions cirq-core/cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,12 @@ def controlled(
return result

def _pauli_expansion_(self) -> value.LinearDict[str]:
if protocols.is_parameterized(self) or self._dimension != 2:
if self._dimension != 2:
return NotImplemented
phase = 1j ** (2 * self._exponent * (self._global_shift + 0.5))
angle = np.pi * self._exponent / 2
return value.LinearDict({'I': phase * np.cos(angle), 'X': -1j * phase * np.sin(angle)})
angle = _pi(self._exponent) * self._exponent / 2
lib = sympy if protocols.is_parameterized(self) else np
return value.LinearDict({'I': phase * lib.cos(angle), 'X': -1j * phase * lib.sin(angle)})

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
Expand Down Expand Up @@ -464,11 +465,10 @@ def _trace_distance_bound_(self) -> Optional[float]:
return abs(np.sin(self._exponent * 0.5 * np.pi))

def _pauli_expansion_(self) -> value.LinearDict[str]:
if protocols.is_parameterized(self):
return NotImplemented
phase = 1j ** (2 * self._exponent * (self._global_shift + 0.5))
angle = np.pi * self._exponent / 2
return value.LinearDict({'I': phase * np.cos(angle), 'Y': -1j * phase * np.sin(angle)})
angle = _pi(self._exponent) * self._exponent / 2
lib = sympy if protocols.is_parameterized(self) else np
return value.LinearDict({'I': phase * lib.cos(angle), 'Y': -1j * phase * lib.sin(angle)})

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
Expand Down Expand Up @@ -764,11 +764,12 @@ def _trace_distance_bound_(self) -> Optional[float]:
return abs(np.sin(self._exponent * 0.5 * np.pi))

def _pauli_expansion_(self) -> value.LinearDict[str]:
if protocols.is_parameterized(self) or self._dimension != 2:
if self._dimension != 2:
return NotImplemented
phase = 1j ** (2 * self._exponent * (self._global_shift + 0.5))
angle = np.pi * self._exponent / 2
return value.LinearDict({'I': phase * np.cos(angle), 'Z': -1j * phase * np.sin(angle)})
angle = _pi(self._exponent) * self._exponent / 2
lib = sympy if protocols.is_parameterized(self) else np
return value.LinearDict({'I': phase * lib.cos(angle), 'Z': -1j * phase * lib.sin(angle)})

def _phase_by_(self, phase_turns: float, qubit_index: int):
return self
Expand Down
10 changes: 10 additions & 0 deletions cirq-core/cirq/ops/common_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,3 +1300,13 @@ def test_wrong_dims():

with pytest.raises(ValueError, match='Wrong shape'):
_ = cirq.Z.on(cirq.LineQid(0, dimension=3))


@pytest.mark.parametrize('gate_type', [cirq.XPowGate, cirq.YPowGate, cirq.ZPowGate])
@pytest.mark.parametrize('exponent', [sympy.Symbol('s'), sympy.Symbol('s') * 2])
def test_parameterized_pauli_expansion(gate_type, exponent):
gate = gate_type(exponent=exponent)
pauli = cirq.pauli_expansion(gate)
gate_resolved = cirq.resolve_parameters(gate, {'s': 0.5})
pauli_resolved = cirq.resolve_parameters(pauli, {'s': 0.5})
assert cirq.approx_eq(pauli_resolved, cirq.pauli_expansion(gate_resolved))
102 changes: 80 additions & 22 deletions cirq-core/cirq/value/linear_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Linear combination represented as mapping of things to coefficients."""

from typing import (
AbstractSet,
Any,
Callable,
Dict,
Expand All @@ -28,21 +29,43 @@
Optional,
overload,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
ValuesView,
)
from typing_extensions import Self

import numpy as np
import sympy
from cirq import protocols

if TYPE_CHECKING:
import cirq

Scalar = Union[complex, np.number]
TVector = TypeVar('TVector')

TDefault = TypeVar('TDefault')


def _format_coefficient(format_spec: str, coefficient: Scalar) -> str:
class _SympyPrinter(sympy.printing.str.StrPrinter):
def __init__(self, format_spec: str):
super().__init__()
self._format_spec = format_spec

def _print(self, expr, **kwargs):
if expr.is_complex:
coefficient = complex(expr)
s = _format_coefficient(self._format_spec, coefficient)
return s[1:-1] if s.startswith('(') else s
return super()._print(expr, **kwargs)


def _format_coefficient(format_spec: str, coefficient: 'cirq.TParamValComplex') -> str:
if isinstance(coefficient, sympy.Basic):
printer = _SympyPrinter(format_spec)
return printer.doprint(coefficient)
coefficient = complex(coefficient)
real_str = f'{coefficient.real:{format_spec}}'
imag_str = f'{coefficient.imag:{format_spec}}'
Expand All @@ -59,7 +82,7 @@ def _format_coefficient(format_spec: str, coefficient: Scalar) -> str:
return f'({real_str}+{imag_str}j)'


def _format_term(format_spec: str, vector: TVector, coefficient: Scalar) -> str:
def _format_term(format_spec: str, vector: TVector, coefficient: 'cirq.TParamValComplex') -> str:
coefficient_str = _format_coefficient(format_spec, coefficient)
if not coefficient_str:
return coefficient_str
Expand All @@ -69,7 +92,7 @@ def _format_term(format_spec: str, vector: TVector, coefficient: Scalar) -> str:
return '+' + result


def _format_terms(terms: Iterable[Tuple[TVector, Scalar]], format_spec: str):
def _format_terms(terms: Iterable[Tuple[TVector, 'cirq.TParamValComplex']], format_spec: str):
formatted_terms = [_format_term(format_spec, vector, coeff) for vector, coeff in terms]
s = ''.join(formatted_terms)
if not s:
Expand All @@ -79,7 +102,7 @@ def _format_terms(terms: Iterable[Tuple[TVector, Scalar]], format_spec: str):
return s


class LinearDict(Generic[TVector], MutableMapping[TVector, Scalar]):
class LinearDict(Generic[TVector], MutableMapping[TVector, 'cirq.TParamValComplex']):
"""Represents linear combination of things.
LinearDict implements the basic linear algebraic operations of vector
Expand All @@ -96,7 +119,7 @@ class LinearDict(Generic[TVector], MutableMapping[TVector, Scalar]):

def __init__(
self,
terms: Optional[Mapping[TVector, Scalar]] = None,
terms: Optional[Mapping[TVector, 'cirq.TParamValComplex']] = None,
validator: Optional[Callable[[TVector], bool]] = None,
) -> None:
"""Initializes linear combination from a collection of terms.
Expand All @@ -112,21 +135,30 @@ def __init__(
"""
self._has_validator = validator is not None
self._is_valid = validator or (lambda x: True)
self._terms: Dict[TVector, Scalar] = {}
self._terms: Dict[TVector, 'cirq.TParamValComplex'] = {}
if terms is not None:
self.update(terms)

@classmethod
def fromkeys(cls, vectors, coefficient=0):
return LinearDict(dict.fromkeys(vectors, complex(coefficient)))
return LinearDict(
dict.fromkeys(
vectors,
coefficient if isinstance(coefficient, sympy.Basic) else complex(coefficient),
)
)

def _check_vector_valid(self, vector: TVector) -> None:
if not self._is_valid(vector):
raise ValueError(f'{vector} is not compatible with linear combination {self}')

def clean(self, *, atol: float = 1e-9) -> Self:
"""Remove terms with coefficients of absolute value atol or less."""
negligible = [v for v, c in self._terms.items() if abs(complex(c)) <= atol]
negligible = [
v
for v, c in self._terms.items()
if not isinstance(c, sympy.Basic) and abs(complex(c)) <= atol
]
for v in negligible:
del self._terms[v]
return self
Expand All @@ -139,40 +171,50 @@ def keys(self) -> KeysView[TVector]:
snapshot = self.copy().clean(atol=0)
return snapshot._terms.keys()

def values(self) -> ValuesView[Scalar]:
def values(self) -> ValuesView['cirq.TParamValComplex']:
snapshot = self.copy().clean(atol=0)
return snapshot._terms.values()

def items(self) -> ItemsView[TVector, Scalar]:
def items(self) -> ItemsView[TVector, 'cirq.TParamValComplex']:
snapshot = self.copy().clean(atol=0)
return snapshot._terms.items()

# pylint: disable=function-redefined
@overload
def update(self, other: Mapping[TVector, Scalar], **kwargs: Scalar) -> None:
def update(
self, other: Mapping[TVector, 'cirq.TParamValComplex'], **kwargs: 'cirq.TParamValComplex'
) -> None:
pass

@overload
def update(self, other: Iterable[Tuple[TVector, Scalar]], **kwargs: Scalar) -> None:
def update(
self,
other: Iterable[Tuple[TVector, 'cirq.TParamValComplex']],
**kwargs: 'cirq.TParamValComplex',
) -> None:
pass

@overload
def update(self, *args: Any, **kwargs: Scalar) -> None:
def update(self, *args: Any, **kwargs: 'cirq.TParamValComplex') -> None:
pass

def update(self, *args, **kwargs):
terms = dict()
terms.update(*args, **kwargs)
for vector, coefficient in terms.items():
if isinstance(coefficient, sympy.Basic):
coefficient = sympy.simplify(coefficient)
if coefficient.is_complex:
coefficient = complex(coefficient)
self[vector] = coefficient
self.clean(atol=0)

@overload
def get(self, vector: TVector) -> Scalar:
def get(self, vector: TVector) -> 'cirq.TParamValComplex':
pass

@overload
def get(self, vector: TVector, default: TDefault) -> Union[Scalar, TDefault]:
def get(self, vector: TVector, default: TDefault) -> Union['cirq.TParamValComplex', TDefault]:
pass

def get(self, vector, default=0):
Expand All @@ -185,10 +227,10 @@ def get(self, vector, default=0):
def __contains__(self, vector: Any) -> bool:
return vector in self._terms and self._terms[vector] != 0

def __getitem__(self, vector: TVector) -> Scalar:
def __getitem__(self, vector: TVector) -> 'cirq.TParamValComplex':
return self._terms.get(vector, 0)

def __setitem__(self, vector: TVector, coefficient: Scalar) -> None:
def __setitem__(self, vector: TVector, coefficient: 'cirq.TParamValComplex') -> None:
self._check_vector_valid(vector)
if coefficient != 0:
self._terms[vector] = coefficient
Expand Down Expand Up @@ -236,21 +278,21 @@ def __neg__(self) -> Self:
factory = type(self)
return factory({v: -c for v, c in self.items()})

def __imul__(self, a: Scalar) -> Self:
def __imul__(self, a: 'cirq.TParamValComplex') -> Self:
for vector in self:
self._terms[vector] *= a
self.clean(atol=0)
return self

def __mul__(self, a: Scalar) -> Self:
def __mul__(self, a: 'cirq.TParamValComplex') -> Self:
result = self.copy()
result *= a
return result
return result.copy()

def __rmul__(self, a: Scalar) -> Self: # type: ignore
def __rmul__(self, a: 'cirq.TParamValComplex') -> Self:
return self.__mul__(a)

def __truediv__(self, a: Scalar) -> Self:
def __truediv__(self, a: 'cirq.TParamValComplex') -> Self:
return self.__mul__(1 / a)

def __bool__(self) -> bool:
Expand Down Expand Up @@ -320,3 +362,19 @@ def _json_dict_(self) -> Dict[Any, Any]:
@classmethod
def _from_json_dict_(cls, keys, values, **kwargs):
return cls(terms=dict(zip(keys, values)))

def _is_parameterized_(self) -> bool:
return any(protocols.is_parameterized(v) for v in self._terms.values())

def _parameter_names_(self) -> AbstractSet[str]:
return set(name for v in self._terms.values() for name in protocols.parameter_names(v))

def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'LinearDict':
result = self.copy()
result.update(
{
k: protocols.resolve_parameters(v, resolver, recursive)
for k, v in self._terms.items()
}
)
return result
Loading

0 comments on commit 5c198ce

Please sign in to comment.