Skip to content

Commit

Permalink
Remove [float,int] unions from type declarations (quantumlib#7042)
Browse files Browse the repository at this point in the history
* Clean up redundant complex type unions

* format

* fix numpy ufunc, one mypy err

* remove float/int unions

* Fix merge

* Remove test

* lint
  • Loading branch information
daxfohl authored Feb 7, 2025
1 parent 7f46121 commit e2de439
Show file tree
Hide file tree
Showing 18 changed files with 59 additions and 96 deletions.
2 changes: 1 addition & 1 deletion cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def __eq__(self, other) -> bool:
and all(m0 == m1 for m0, m1 in zip(self.moments, other.moments))
)

def _approx_eq_(self, other: Any, atol: Union[int, float]) -> bool:
def _approx_eq_(self, other: Any, atol: float) -> bool:
"""See `cirq.protocols.SupportsApproximateEquality`."""
if not isinstance(other, AbstractCircuit):
return NotImplemented
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/circuits/moment.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def __eq__(self, other) -> bool:

return self is other or self._sorted_operations_() == other._sorted_operations_()

def _approx_eq_(self, other: Any, atol: Union[int, float]) -> bool:
def _approx_eq_(self, other: Any, atol: float) -> bool:
"""See `cirq.protocols.SupportsApproximateEquality`."""
if not isinstance(other, type(self)):
return NotImplemented
Expand Down
58 changes: 13 additions & 45 deletions cirq-core/cirq/circuits/text_diagram_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)

import numpy as np
Expand All @@ -45,23 +44,11 @@

_HorizontalLine = NamedTuple(
'_HorizontalLine',
[
('y', Union[int, float]),
('x1', Union[int, float]),
('x2', Union[int, float]),
('emphasize', bool),
('doubled', bool),
],
[('y', float), ('x1', float), ('x2', float), ('emphasize', bool), ('doubled', bool)],
)
_VerticalLine = NamedTuple(
'_VerticalLine',
[
('x', Union[int, float]),
('y1', Union[int, float]),
('y2', Union[int, float]),
('emphasize', bool),
('doubled', bool),
],
[('x', float), ('y1', float), ('y2', float), ('emphasize', bool), ('doubled', bool)],
)
_DiagramText = NamedTuple('_DiagramText', [('text', str), ('transposed_text', str)])

Expand Down Expand Up @@ -99,10 +86,10 @@ def __init__(
self.vertical_lines: List[_VerticalLine] = (
[] if vertical_lines is None else list(vertical_lines)
)
self.horizontal_padding: Dict[int, Union[int, float]] = (
self.horizontal_padding: Dict[int, float] = (
dict() if horizontal_padding is None else dict(horizontal_padding)
)
self.vertical_padding: Dict[int, Union[int, float]] = (
self.vertical_padding: Dict[int, float] = (
dict() if vertical_padding is None else dict(vertical_padding)
)

Expand Down Expand Up @@ -171,24 +158,14 @@ def grid_line(
raise ValueError("Line is neither horizontal nor vertical")

def vertical_line(
self,
x: Union[int, float],
y1: Union[int, float],
y2: Union[int, float],
emphasize: bool = False,
doubled: bool = False,
self, x: float, y1: float, y2: float, emphasize: bool = False, doubled: bool = False
) -> None:
"""Adds a line from (x, y1) to (x, y2)."""
y1, y2 = sorted([y1, y2])
self.vertical_lines.append(_VerticalLine(x, y1, y2, emphasize, doubled))

def horizontal_line(
self,
y: Union[int, float],
x1: Union[int, float],
x2: Union[int, float],
emphasize: bool = False,
doubled: bool = False,
self, y: float, x1: float, x2: float, emphasize: bool = False, doubled: bool = False
) -> None:
"""Adds a line from (x1, y) to (x2, y)."""
x1, x2 = sorted([x1, x2])
Expand Down Expand Up @@ -228,26 +205,21 @@ def height(self) -> int:
max_y = max(max_y, v.y1, v.y2)
return 1 + int(max_y)

def force_horizontal_padding_after(self, index: int, padding: Union[int, float]) -> None:
def force_horizontal_padding_after(self, index: int, padding: float) -> None:
"""Change the padding after the given column."""
self.horizontal_padding[index] = padding

def force_vertical_padding_after(self, index: int, padding: Union[int, float]) -> None:
def force_vertical_padding_after(self, index: int, padding: float) -> None:
"""Change the padding after the given row."""
self.vertical_padding[index] = padding

def _transform_coordinates(
self,
func: Callable[
[Union[int, float], Union[int, float]], Tuple[Union[int, float], Union[int, float]]
],
) -> None:
def _transform_coordinates(self, func: Callable[[float, float], Tuple[float, float]]) -> None:
"""Helper method to transformer either row or column coordinates."""

def func_x(x: Union[int, float]) -> Union[int, float]:
def func_x(x: float) -> float:
return func(x, 0)[0]

def func_y(y: Union[int, float]) -> Union[int, float]:
def func_y(y: float) -> float:
return func(0, y)[1]

self.entries = {
Expand All @@ -271,19 +243,15 @@ def func_y(y: Union[int, float]) -> Union[int, float]:
def insert_empty_columns(self, x: int, amount: int = 1) -> None:
"""Insert a number of columns after the given column."""

def transform_columns(
column: Union[int, float], row: Union[int, float]
) -> Tuple[Union[int, float], Union[int, float]]:
def transform_columns(column: float, row: float) -> Tuple[float, float]:
return column + (amount if column >= x else 0), row

self._transform_coordinates(transform_columns)

def insert_empty_rows(self, y: int, amount: int = 1) -> None:
"""Insert a number of rows after the given row."""

def transform_rows(
column: Union[int, float], row: Union[int, float]
) -> Tuple[Union[int, float], Union[int, float]]:
def transform_rows(column: float, row: float) -> Tuple[float, float]:
return column, row + (amount if row >= y else 0)

self._transform_coordinates(transform_rows)
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/linalg/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ class EntangledStateError(ValueError):


def partial_trace_of_state_vector_as_mixture(
state_vector: np.ndarray, keep_indices: List[int], *, atol: Union[int, float] = 1e-8
state_vector: np.ndarray, keep_indices: List[int], *, atol: float = 1e-8
) -> Tuple[Tuple[float, np.ndarray], ...]:
"""Returns a mixture representing a state vector with only some qubits kept.
Expand Down Expand Up @@ -481,7 +481,7 @@ def sub_state_vector(
keep_indices: List[int],
*,
default: np.ndarray = RaiseValueErrorIfNotProvided,
atol: Union[int, float] = 1e-6,
atol: float = 1e-6,
) -> np.ndarray:
r"""Attempts to factor a state vector into two parts and return one of them.
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/clifford_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def _to_phased_xz_gate(self) -> phased_x_z_gate.PhasedXZGate:
z = -0.5 if x_to_flip else 0.5
return phased_x_z_gate.PhasedXZGate(x_exponent=x, z_exponent=z, axis_phase_exponent=a)

def __pow__(self, exponent: Union[float, int]) -> 'SingleQubitCliffordGate':
def __pow__(self, exponent: float) -> 'SingleQubitCliffordGate':
if int(exponent) == exponent:
# The single qubit Clifford gates are a group of size 24
ret_gate = super().__pow__(int(exponent) % 24)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/dense_pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool)
def __pos__(self):
return self

def __pow__(self, power: Union[int, float]) -> Union[NotImplementedType, Self]:
def __pow__(self, power: float) -> Union[NotImplementedType, Self]:
concrete_class = type(self)
if isinstance(power, int):
i_group = [1, +1j, -1, -1j]
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/gate_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def _qasm_(self, args: 'protocols.QasmArgs') -> Optional[str]:
return protocols.qasm(self.gate, args=args, qubits=self.qubits, default=None)

def _equal_up_to_global_phase_(
self, other: Any, atol: Union[int, float] = 1e-8
self, other: Any, atol: float = 1e-8
) -> Union[NotImplementedType, bool]:
if not isinstance(other, type(self)):
return NotImplemented
Expand Down
20 changes: 10 additions & 10 deletions cirq-core/cirq/ops/pauli_string_phasor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def __init__(
pauli_string: ps.PauliString,
qubits: Optional[Sequence['cirq.Qid']] = None,
*,
exponent_neg: Union[int, float, sympy.Expr] = 1,
exponent_pos: Union[int, float, sympy.Expr] = 0,
exponent_neg: 'cirq.TParamVal' = 1,
exponent_pos: 'cirq.TParamVal' = 0,
) -> None:
"""Initializes the operation.
Expand Down Expand Up @@ -112,12 +112,12 @@ def gate(self) -> 'cirq.PauliStringPhasorGate':
return cast(PauliStringPhasorGate, self._gate)

@property
def exponent_neg(self) -> Union[int, float, sympy.Expr]:
def exponent_neg(self) -> 'cirq.TParamVal':
"""The negative exponent."""
return self.gate.exponent_neg

@property
def exponent_pos(self) -> Union[int, float, sympy.Expr]:
def exponent_pos(self) -> 'cirq.TParamVal':
"""The positive exponent."""
return self.gate.exponent_pos

Expand All @@ -127,7 +127,7 @@ def pauli_string(self) -> 'cirq.PauliString':
return self._pauli_string

@property
def exponent_relative(self) -> Union[int, float, sympy.Expr]:
def exponent_relative(self) -> 'cirq.TParamVal':
"""The relative exponent between negative and positive exponents."""
return self.gate.exponent_relative

Expand Down Expand Up @@ -278,8 +278,8 @@ def __init__(
self,
dense_pauli_string: dps.DensePauliString,
*,
exponent_neg: Union[int, float, sympy.Expr] = 1,
exponent_pos: Union[int, float, sympy.Expr] = 0,
exponent_neg: 'cirq.TParamVal' = 1,
exponent_pos: 'cirq.TParamVal' = 0,
) -> None:
"""Initializes the PauliStringPhasorGate.
Expand Down Expand Up @@ -309,17 +309,17 @@ def __init__(
self._exponent_pos = value.canonicalize_half_turns(exponent_pos)

@property
def exponent_relative(self) -> Union[int, float, sympy.Expr]:
def exponent_relative(self) -> 'cirq.TParamVal':
"""The relative exponent between negative and positive exponents."""
return value.canonicalize_half_turns(self.exponent_neg - self.exponent_pos)

@property
def exponent_neg(self) -> Union[int, float, sympy.Expr]:
def exponent_neg(self) -> 'cirq.TParamVal':
"""The negative exponent."""
return self._exponent_neg

@property
def exponent_pos(self) -> Union[int, float, sympy.Expr]:
def exponent_pos(self) -> 'cirq.TParamVal':
"""The positive exponent."""
return self._exponent_pos

Expand Down
5 changes: 2 additions & 3 deletions cirq-core/cirq/ops/pauli_sum_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterator, Tuple, Union, TYPE_CHECKING
from typing import Any, Iterator, Tuple, TYPE_CHECKING

import numpy as np
import sympy

from cirq import linalg, protocols, value, _compat
from cirq.ops import linear_combinations, pauli_string_phasor
Expand Down Expand Up @@ -45,7 +44,7 @@ class returns an operation which is equivalent to
def __init__(
self,
pauli_sum_like: 'cirq.PauliSumLike',
exponent: Union[int, float, sympy.Expr] = 1,
exponent: 'cirq.TParamVal' = 1,
atol: float = 1e-8,
):
pauli_sum = linear_combinations.PauliSum.wrap(pauli_sum_like)
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/phased_x_z_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def _decompose_(self, qubits: Sequence['cirq.Qid']) -> Iterator['cirq.OP_TREE']:
yield ops.X(q) ** self._x_exponent
yield ops.Z(q) ** (self._axis_phase_exponent + self._z_exponent)

def __pow__(self, exponent: Union[float, int]) -> 'PhasedXZGate':
def __pow__(self, exponent: float) -> 'PhasedXZGate':
if exponent == 1:
return self
if exponent == -1:
Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/ops/raw_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def _qid_shape_(self) -> Tuple[int, ...]:
raise NotImplementedError

def _equal_up_to_global_phase_(
self, other: Any, atol: Union[int, float] = 1e-8
self, other: Any, atol: float = 1e-8
) -> Union[NotImplementedType, bool]:
"""Default fallback for gates that do not implement this protocol."""
try:
Expand Down Expand Up @@ -997,7 +997,7 @@ def _qasm_(self, args: 'protocols.QasmArgs') -> Optional[str]:
return protocols.qasm(self.sub_operation, args=args, default=None)

def _equal_up_to_global_phase_(
self, other: Any, atol: Union[int, float] = 1e-8
self, other: Any, atol: float = 1e-8
) -> Union[NotImplementedType, bool]:
return protocols.equal_up_to_global_phase(self.sub_operation, other, atol=atol)

Expand Down
12 changes: 5 additions & 7 deletions cirq-core/cirq/ops/wait_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import AbstractSet, Any, Dict, Optional, Tuple, TYPE_CHECKING, Union

import sympy
from typing import AbstractSet, Any, Dict, Optional, Tuple, TYPE_CHECKING

from cirq import value, protocols
from cirq.ops import raw_types
Expand Down Expand Up @@ -136,10 +134,10 @@ def _value_equality_values_(self) -> Any:
def wait(
*target: 'cirq.Qid',
duration: 'cirq.DURATION_LIKE' = None,
picos: Union[int, float, sympy.Expr] = 0,
nanos: Union[int, float, sympy.Expr] = 0,
micros: Union[int, float, sympy.Expr] = 0,
millis: Union[int, float, sympy.Expr] = 0,
picos: 'cirq.TParamVal' = 0,
nanos: 'cirq.TParamVal' = 0,
micros: 'cirq.TParamVal' = 0,
millis: 'cirq.TParamVal' = 0,
) -> raw_types.Operation:
"""Creates a WaitGate applied to all the given qubits.
Expand Down
10 changes: 5 additions & 5 deletions cirq-core/cirq/protocols/approximate_equality_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Union, Iterable
from typing import Any, Iterable
from fractions import Fraction
from decimal import Decimal

Expand All @@ -29,7 +29,7 @@ class SupportsApproximateEquality(Protocol):
"""Object which can be compared approximately."""

@doc_private
def _approx_eq_(self, other: Any, *, atol: Union[int, float]) -> bool:
def _approx_eq_(self, other: Any, *, atol: float) -> bool:
"""Approximate comparator.
Types implementing this protocol define their own logic for approximate
Expand All @@ -47,7 +47,7 @@ def _approx_eq_(self, other: Any, *, atol: Union[int, float]) -> bool:
"""


def approx_eq(val: Any, other: Any, *, atol: Union[int, float] = 1e-8) -> bool:
def approx_eq(val: Any, other: Any, *, atol: float = 1e-8) -> bool:
"""Approximately compares two objects.
If `val` implements SupportsApproxEquality protocol then it is invoked and
Expand Down Expand Up @@ -120,7 +120,7 @@ def approx_eq(val: Any, other: Any, *, atol: Union[int, float] = 1e-8) -> bool:
return val == other


def _approx_eq_iterables(val: Iterable, other: Iterable, *, atol: Union[int, float]) -> bool:
def _approx_eq_iterables(val: Iterable, other: Iterable, *, atol: float) -> bool:
"""Iterates over arguments and calls approx_eq recursively.
Types of `val` and `other` does not necessarily needs to match each other.
Expand Down Expand Up @@ -161,7 +161,7 @@ def _approx_eq_iterables(val: Iterable, other: Iterable, *, atol: Union[int, flo
return True


def _isclose(a: Any, b: Any, *, atol: Union[int, float]) -> bool:
def _isclose(a: Any, b: Any, *, atol: float) -> bool:
"""Convenience wrapper around np.isclose."""

# support casting some standard numeric types
Expand Down
Loading

0 comments on commit e2de439

Please sign in to comment.