Skip to content

Commit

Permalink
Allow ability to plug in custom (de)serializers for cirq_google protos (
Browse files Browse the repository at this point in the history
quantumlib#7059)

* Allow ability to plug in custom (de)serializers for cirq_google protos

- This will allow users to plug in custom serializers and deserializers,
which can parse gates before falling back to the default.
- This enables internal libraries to parse and deserialize non-public
gates, tags, and operations.

* Fix coverage and get rid of unneeded junk.

* Address comments.

* Flip warnings.
  • Loading branch information
dstrain115 authored Feb 16, 2025
1 parent 2e71950 commit ca6ceb3
Show file tree
Hide file tree
Showing 5 changed files with 167 additions and 85 deletions.
89 changes: 66 additions & 23 deletions cirq-google/cirq_google/serialization/circuit_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,23 @@ class CircuitSerializer(serializer.Serializer):
serialization of duplicate operations as entries in the constant table.
This flag will soon become the default and disappear as soon as
deserialization of this field is deployed.
op_serializer: Optional custom serializer for serializing unknown gates.
op_deserializer: Optional custom deserializer for deserializing unknown gates.
"""

def __init__(
self, USE_CONSTANTS_TABLE_FOR_MOMENTS=False, USE_CONSTANTS_TABLE_FOR_OPERATIONS=False
self,
USE_CONSTANTS_TABLE_FOR_MOMENTS=False,
USE_CONSTANTS_TABLE_FOR_OPERATIONS=False,
op_serializer: Optional[op_serializer.OpSerializer] = None,
op_deserializer: Optional[op_deserializer.OpDeserializer] = None,
):
"""Construct the circuit serializer object."""
super().__init__(gate_set_name=_SERIALIZER_NAME)
self.use_constants_table_for_moments = USE_CONSTANTS_TABLE_FOR_MOMENTS
self.use_constants_table_for_operations = USE_CONSTANTS_TABLE_FOR_OPERATIONS
self.op_serializer = op_serializer
self.op_deserializer = op_deserializer

def serialize(
self,
Expand Down Expand Up @@ -144,26 +152,44 @@ def _serialize_circuit(
moment_proto.operation_indices.append(op_index)
else:
op_pb = v2.program_pb2.Operation()
if self.op_serializer and self.op_serializer.can_serialize_operation(op):
self.op_serializer.to_proto(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)
else:
self._serialize_gate_op(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)
constants.append(v2.program_pb2.Constant(operation_value=op_pb))
op_index = len(constants) - 1
raw_constants[op] = op_index
moment_proto.operation_indices.append(op_index)
else:
op_pb = moment_proto.operations.add()
if self.op_serializer and self.op_serializer.can_serialize_operation(op):
self.op_serializer.to_proto(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)
else:
self._serialize_gate_op(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)
constants.append(v2.program_pb2.Constant(operation_value=op_pb))
op_index = len(constants) - 1
raw_constants[op] = op_index
moment_proto.operation_indices.append(op_index)
else:
op_pb = moment_proto.operations.add()
self._serialize_gate_op(
op,
op_pb,
arg_function_language=arg_function_language,
constants=constants,
raw_constants=raw_constants,
)

if self.use_constants_table_for_moments:
# Add this moment to the constants table
Expand Down Expand Up @@ -469,14 +495,23 @@ def deserialize(self, proto: v2.program_pb2.Program) -> cirq.Circuit:
elif which_const == 'qubit':
deserialized_constants.append(v2.qubit_from_proto_id(constant.qubit.id))
elif which_const == 'operation_value':
deserialized_constants.append(
self._deserialize_gate_op(
if self.op_deserializer and self.op_deserializer.can_deserialize_proto(
constant.operation_value
):
op_pb = self.op_deserializer.from_proto(
constant.operation_value,
arg_function_language=arg_func_language,
constants=proto.constants,
deserialized_constants=deserialized_constants,
)
)
else:
op_pb = self._deserialize_gate_op(
constant.operation_value,
arg_function_language=arg_func_language,
constants=proto.constants,
deserialized_constants=deserialized_constants,
)
deserialized_constants.append(op_pb)
elif which_const == 'moment_value':
deserialized_constants.append(
self._deserialize_moment(
Expand Down Expand Up @@ -541,12 +576,20 @@ def _deserialize_moment(
) -> cirq.Moment:
moment_ops = []
for op in moment_proto.operations:
gate_op = self._deserialize_gate_op(
op,
arg_function_language=arg_function_language,
constants=constants,
deserialized_constants=deserialized_constants,
)
if self.op_deserializer and self.op_deserializer.can_deserialize_proto(op):
gate_op = self.op_deserializer.from_proto(
op,
arg_function_language=arg_function_language,
constants=constants,
deserialized_constants=deserialized_constants,
)
else:
gate_op = self._deserialize_gate_op(
op,
arg_function_language=arg_function_language,
constants=constants,
deserialized_constants=deserialized_constants,
)
if op.tag_indices:
tags = [
deserialized_constants[tag_index]
Expand Down
93 changes: 92 additions & 1 deletion cirq-google/cirq_google/serialization/circuit_serializer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Any, Dict, List, Optional
import pytest

import numpy as np
Expand All @@ -25,6 +25,8 @@
import cirq_google as cg
from cirq_google.api import v2
from cirq_google.serialization.circuit_serializer import _SERIALIZER_NAME
from cirq_google.serialization.op_deserializer import OpDeserializer
from cirq_google.serialization.op_serializer import OpSerializer


class FakeDevice(cirq.Device):
Expand Down Expand Up @@ -856,6 +858,7 @@ def test_circuit_with_tag(tag):
assert nc[0].operations[0].tags == (tag,)


@pytest.mark.filterwarnings('ignore:Unrecognized Tag .*DingDongTag')
def test_unknown_tag_is_ignored():
class DingDongTag:
pass
Expand All @@ -866,6 +869,7 @@ class DingDongTag:
assert cirq.Circuit(cirq.X(cirq.q(0))) == nc


@pytest.mark.filterwarnings('ignore:Unknown tag msg=phase_match')
def test_unrecognized_tag_is_ignored():
op_tag = v2.program_pb2.Operation()
op_tag.xpowgate.exponent.float_value = 1.0
Expand Down Expand Up @@ -917,3 +921,90 @@ def test_circuit_with_units():
)
msg = cg.CIRCUIT_SERIALIZER.serialize(c)
assert c == cg.CIRCUIT_SERIALIZER.deserialize(msg)


class BingBongGate(cirq.Gate):

def __init__(self, param: float):
self.param = param

def _num_qubits_(self) -> int:
return 1


class BingBongSerializer(OpSerializer):
"""Describes how to serialize CircuitOperations."""

def can_serialize_operation(self, op):
return isinstance(op.gate, BingBongGate)

def to_proto(
self,
op: cirq.CircuitOperation,
msg: Optional[v2.program_pb2.CircuitOperation] = None,
*,
arg_function_language: Optional[str] = '',
constants: List[v2.program_pb2.Constant],
raw_constants: Dict[Any, int],
) -> v2.program_pb2.CircuitOperation:
assert isinstance(op.gate, BingBongGate)
if msg is None:
msg = v2.program_pb2.Operation() # pragma: nocover
msg.internalgate.name = 'bingbong'
msg.internalgate.module = 'test'
msg.internalgate.num_qubits = 1
msg.internalgate.gate_args['param'].arg_value.float_value = op.gate.param

for qubit in op.qubits:
if qubit not in raw_constants:
constants.append(
v2.program_pb2.Constant(
qubit=v2.program_pb2.Qubit(id=v2.qubit_to_proto_id(qubit))
)
)
raw_constants[qubit] = len(constants) - 1
msg.qubit_constant_index.append(raw_constants[qubit])
return msg


class BingBongDeserializer(OpDeserializer):
"""Describes how to serialize CircuitOperations."""

def can_deserialize_proto(self, proto):
return (
isinstance(proto, v2.program_pb2.Operation)
and proto.WhichOneof("gate_value") == "internalgate"
and proto.internalgate.name == 'bingbong'
and proto.internalgate.module == 'test'
)

def from_proto(
self,
proto: v2.program_pb2.Operation,
*,
arg_function_language: str = '',
constants: List[v2.program_pb2.Constant],
deserialized_constants: List[Any],
) -> cirq.Operation:
return BingBongGate(param=proto.internalgate.gate_args["param"].arg_value.float_value).on(
deserialized_constants[proto.qubit_constant_index[0]]
)


@pytest.mark.parametrize('use_constants_table', [True, False])
def test_custom_serializer(use_constants_table: bool):
c = cirq.Circuit(BingBongGate(param=2.5)(cirq.q(0, 0)))
serializer = cg.CircuitSerializer(
USE_CONSTANTS_TABLE_FOR_MOMENTS=use_constants_table,
USE_CONSTANTS_TABLE_FOR_OPERATIONS=use_constants_table,
op_serializer=BingBongSerializer(),
op_deserializer=BingBongDeserializer(),
)
msg = serializer.serialize(c)
deserialized_circuit = serializer.deserialize(msg)
moment = deserialized_circuit[0]
assert len(moment) == 1
op = moment[cirq.q(0, 0)]
assert isinstance(op.gate, BingBongGate)
assert op.gate.param == 2.5
assert op.qubits == (cirq.q(0, 0),)
17 changes: 5 additions & 12 deletions cirq-google/cirq_google/serialization/op_deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,12 @@ class OpDeserializer(abc.ABC):
"""Generic supertype for operation deserializers.
Each operation deserializer describes how to deserialize operation protos
with a particular `serialized_id` to a specific type of Cirq operation.
to a specific type of Cirq operation.
"""

@property
@abc.abstractmethod
def serialized_id(self) -> str:
"""Returns the string identifier for the accepted serialized objects.
This ID denotes the serialization format this deserializer consumes. For
example, one of the common deserializers converts objects with the id
'xy' into PhasedXPowGates.
"""
def can_deserialize_proto(self, proto) -> bool:
"""Whether the given operation can be serialized by this serializer."""

@abc.abstractmethod
def from_proto(
Expand Down Expand Up @@ -66,9 +60,8 @@ def from_proto(
class CircuitOpDeserializer(OpDeserializer):
"""Describes how to serialize CircuitOperations."""

@property
def serialized_id(self):
return 'circuit'
def can_deserialize_proto(self, proto):
return isinstance(proto, v2.program_pb2.CircuitOperation) # pragma: nocover

def from_proto(
self,
Expand Down
47 changes: 4 additions & 43 deletions cirq-google/cirq_google/serialization/op_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List, Optional, Type, TypeVar
from typing import Any, Dict, List, Optional, Union
import numbers

import abc
Expand All @@ -23,9 +23,6 @@
from cirq_google.api import v2
from cirq_google.serialization.arg_func_langs import arg_to_proto

# Type for variables that are subclasses of ops.Gate.
Gate = TypeVar('Gate', bound=cirq.Gate)


class OpSerializer(abc.ABC):
"""Generic supertype for operation serializers.
Expand All @@ -35,25 +32,6 @@ class OpSerializer(abc.ABC):
may serialize to the same format.
"""

@property
@abc.abstractmethod
def internal_type(self) -> Type:
"""Returns the type that the operation contains.
For GateOperations, this is the gate type.
For CircuitOperations, this is FrozenCircuit.
"""

@property
@abc.abstractmethod
def serialized_id(self) -> str:
"""Returns the string identifier for the resulting serialized object.
This ID denotes the serialization format this serializer produces. For
example, one of the common serializers assigns the id 'xy' to XPowGates,
as they serialize into a format also used by YPowGates.
"""

@abc.abstractmethod
def to_proto(
self,
Expand All @@ -63,7 +41,7 @@ def to_proto(
arg_function_language: Optional[str] = '',
constants: List[v2.program_pb2.Constant],
raw_constants: Dict[Any, int],
) -> Optional[v2.program_pb2.CircuitOperation]:
) -> Optional[Union[v2.program_pb2.CircuitOperation, v2.program_pb2.Operation]]:
"""Converts op to proto using this serializer.
If self.can_serialize_operation(op) == false, this should return None.
Expand All @@ -83,33 +61,16 @@ def to_proto(
the returned object.
"""

@property
@abc.abstractmethod
def can_serialize_predicate(self) -> Callable[[cirq.Operation], bool]:
"""The method used to determine if this can serialize an operation.
Depending on the serializer, additional checks may be required.
"""

def can_serialize_operation(self, op: cirq.Operation) -> bool:
"""Whether the given operation can be serialized by this serializer."""
return self.can_serialize_predicate(op)


class CircuitOpSerializer(OpSerializer):
"""Describes how to serialize CircuitOperations."""

@property
def internal_type(self):
return cirq.FrozenCircuit

@property
def serialized_id(self):
return 'circuit'

@property
def can_serialize_predicate(self):
return lambda op: isinstance(op.untagged, cirq.CircuitOperation)
def can_serialize_operation(self, op: cirq.Operation):
return isinstance(op.untagged, cirq.CircuitOperation)

def to_proto(
self,
Expand Down
6 changes: 0 additions & 6 deletions cirq-google/cirq_google/serialization/op_serializer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,6 @@ def default_circuit():
)


def test_circuit_op_serializer_properties():
serializer = cg.CircuitOpSerializer()
assert serializer.internal_type == cirq.FrozenCircuit
assert serializer.serialized_id == 'circuit'


def test_can_serialize_circuit_op():
serializer = cg.CircuitOpSerializer()
assert serializer.can_serialize_operation(cirq.CircuitOperation(default_circuit()))
Expand Down

0 comments on commit ca6ceb3

Please sign in to comment.