diff --git a/cirq-google/cirq_google/serialization/circuit_serializer.py b/cirq-google/cirq_google/serialization/circuit_serializer.py index 491932b7d26..d0c43c9e727 100644 --- a/cirq-google/cirq_google/serialization/circuit_serializer.py +++ b/cirq-google/cirq_google/serialization/circuit_serializer.py @@ -56,15 +56,23 @@ class CircuitSerializer(serializer.Serializer): serialization of duplicate operations as entries in the constant table. This flag will soon become the default and disappear as soon as deserialization of this field is deployed. + op_serializer: Optional custom serializer for serializing unknown gates. + op_deserializer: Optional custom deserializer for deserializing unknown gates. """ def __init__( - self, USE_CONSTANTS_TABLE_FOR_MOMENTS=False, USE_CONSTANTS_TABLE_FOR_OPERATIONS=False + self, + USE_CONSTANTS_TABLE_FOR_MOMENTS=False, + USE_CONSTANTS_TABLE_FOR_OPERATIONS=False, + op_serializer: Optional[op_serializer.OpSerializer] = None, + op_deserializer: Optional[op_deserializer.OpDeserializer] = None, ): """Construct the circuit serializer object.""" super().__init__(gate_set_name=_SERIALIZER_NAME) self.use_constants_table_for_moments = USE_CONSTANTS_TABLE_FOR_MOMENTS self.use_constants_table_for_operations = USE_CONSTANTS_TABLE_FOR_OPERATIONS + self.op_serializer = op_serializer + self.op_deserializer = op_deserializer def serialize( self, @@ -144,6 +152,37 @@ def _serialize_circuit( moment_proto.operation_indices.append(op_index) else: op_pb = v2.program_pb2.Operation() + if self.op_serializer and self.op_serializer.can_serialize_operation(op): + self.op_serializer.to_proto( + op, + op_pb, + arg_function_language=arg_function_language, + constants=constants, + raw_constants=raw_constants, + ) + else: + self._serialize_gate_op( + op, + op_pb, + arg_function_language=arg_function_language, + constants=constants, + raw_constants=raw_constants, + ) + constants.append(v2.program_pb2.Constant(operation_value=op_pb)) + op_index = len(constants) - 1 + raw_constants[op] = op_index + moment_proto.operation_indices.append(op_index) + else: + op_pb = moment_proto.operations.add() + if self.op_serializer and self.op_serializer.can_serialize_operation(op): + self.op_serializer.to_proto( + op, + op_pb, + arg_function_language=arg_function_language, + constants=constants, + raw_constants=raw_constants, + ) + else: self._serialize_gate_op( op, op_pb, @@ -151,19 +190,6 @@ def _serialize_circuit( constants=constants, raw_constants=raw_constants, ) - constants.append(v2.program_pb2.Constant(operation_value=op_pb)) - op_index = len(constants) - 1 - raw_constants[op] = op_index - moment_proto.operation_indices.append(op_index) - else: - op_pb = moment_proto.operations.add() - self._serialize_gate_op( - op, - op_pb, - arg_function_language=arg_function_language, - constants=constants, - raw_constants=raw_constants, - ) if self.use_constants_table_for_moments: # Add this moment to the constants table @@ -469,14 +495,23 @@ def deserialize(self, proto: v2.program_pb2.Program) -> cirq.Circuit: elif which_const == 'qubit': deserialized_constants.append(v2.qubit_from_proto_id(constant.qubit.id)) elif which_const == 'operation_value': - deserialized_constants.append( - self._deserialize_gate_op( + if self.op_deserializer and self.op_deserializer.can_deserialize_proto( + constant.operation_value + ): + op_pb = self.op_deserializer.from_proto( constant.operation_value, arg_function_language=arg_func_language, constants=proto.constants, deserialized_constants=deserialized_constants, ) - ) + else: + op_pb = self._deserialize_gate_op( + constant.operation_value, + arg_function_language=arg_func_language, + constants=proto.constants, + deserialized_constants=deserialized_constants, + ) + deserialized_constants.append(op_pb) elif which_const == 'moment_value': deserialized_constants.append( self._deserialize_moment( @@ -541,12 +576,20 @@ def _deserialize_moment( ) -> cirq.Moment: moment_ops = [] for op in moment_proto.operations: - gate_op = self._deserialize_gate_op( - op, - arg_function_language=arg_function_language, - constants=constants, - deserialized_constants=deserialized_constants, - ) + if self.op_deserializer and self.op_deserializer.can_deserialize_proto(op): + gate_op = self.op_deserializer.from_proto( + op, + arg_function_language=arg_function_language, + constants=constants, + deserialized_constants=deserialized_constants, + ) + else: + gate_op = self._deserialize_gate_op( + op, + arg_function_language=arg_function_language, + constants=constants, + deserialized_constants=deserialized_constants, + ) if op.tag_indices: tags = [ deserialized_constants[tag_index] diff --git a/cirq-google/cirq_google/serialization/circuit_serializer_test.py b/cirq-google/cirq_google/serialization/circuit_serializer_test.py index 3ae1e097a9f..d2c06e78b07 100644 --- a/cirq-google/cirq_google/serialization/circuit_serializer_test.py +++ b/cirq-google/cirq_google/serialization/circuit_serializer_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Any, Dict, List, Optional import pytest import numpy as np @@ -25,6 +25,8 @@ import cirq_google as cg from cirq_google.api import v2 from cirq_google.serialization.circuit_serializer import _SERIALIZER_NAME +from cirq_google.serialization.op_deserializer import OpDeserializer +from cirq_google.serialization.op_serializer import OpSerializer class FakeDevice(cirq.Device): @@ -856,6 +858,7 @@ def test_circuit_with_tag(tag): assert nc[0].operations[0].tags == (tag,) +@pytest.mark.filterwarnings('ignore:Unrecognized Tag .*DingDongTag') def test_unknown_tag_is_ignored(): class DingDongTag: pass @@ -866,6 +869,7 @@ class DingDongTag: assert cirq.Circuit(cirq.X(cirq.q(0))) == nc +@pytest.mark.filterwarnings('ignore:Unknown tag msg=phase_match') def test_unrecognized_tag_is_ignored(): op_tag = v2.program_pb2.Operation() op_tag.xpowgate.exponent.float_value = 1.0 @@ -917,3 +921,90 @@ def test_circuit_with_units(): ) msg = cg.CIRCUIT_SERIALIZER.serialize(c) assert c == cg.CIRCUIT_SERIALIZER.deserialize(msg) + + +class BingBongGate(cirq.Gate): + + def __init__(self, param: float): + self.param = param + + def _num_qubits_(self) -> int: + return 1 + + +class BingBongSerializer(OpSerializer): + """Describes how to serialize CircuitOperations.""" + + def can_serialize_operation(self, op): + return isinstance(op.gate, BingBongGate) + + def to_proto( + self, + op: cirq.CircuitOperation, + msg: Optional[v2.program_pb2.CircuitOperation] = None, + *, + arg_function_language: Optional[str] = '', + constants: List[v2.program_pb2.Constant], + raw_constants: Dict[Any, int], + ) -> v2.program_pb2.CircuitOperation: + assert isinstance(op.gate, BingBongGate) + if msg is None: + msg = v2.program_pb2.Operation() # pragma: nocover + msg.internalgate.name = 'bingbong' + msg.internalgate.module = 'test' + msg.internalgate.num_qubits = 1 + msg.internalgate.gate_args['param'].arg_value.float_value = op.gate.param + + for qubit in op.qubits: + if qubit not in raw_constants: + constants.append( + v2.program_pb2.Constant( + qubit=v2.program_pb2.Qubit(id=v2.qubit_to_proto_id(qubit)) + ) + ) + raw_constants[qubit] = len(constants) - 1 + msg.qubit_constant_index.append(raw_constants[qubit]) + return msg + + +class BingBongDeserializer(OpDeserializer): + """Describes how to serialize CircuitOperations.""" + + def can_deserialize_proto(self, proto): + return ( + isinstance(proto, v2.program_pb2.Operation) + and proto.WhichOneof("gate_value") == "internalgate" + and proto.internalgate.name == 'bingbong' + and proto.internalgate.module == 'test' + ) + + def from_proto( + self, + proto: v2.program_pb2.Operation, + *, + arg_function_language: str = '', + constants: List[v2.program_pb2.Constant], + deserialized_constants: List[Any], + ) -> cirq.Operation: + return BingBongGate(param=proto.internalgate.gate_args["param"].arg_value.float_value).on( + deserialized_constants[proto.qubit_constant_index[0]] + ) + + +@pytest.mark.parametrize('use_constants_table', [True, False]) +def test_custom_serializer(use_constants_table: bool): + c = cirq.Circuit(BingBongGate(param=2.5)(cirq.q(0, 0))) + serializer = cg.CircuitSerializer( + USE_CONSTANTS_TABLE_FOR_MOMENTS=use_constants_table, + USE_CONSTANTS_TABLE_FOR_OPERATIONS=use_constants_table, + op_serializer=BingBongSerializer(), + op_deserializer=BingBongDeserializer(), + ) + msg = serializer.serialize(c) + deserialized_circuit = serializer.deserialize(msg) + moment = deserialized_circuit[0] + assert len(moment) == 1 + op = moment[cirq.q(0, 0)] + assert isinstance(op.gate, BingBongGate) + assert op.gate.param == 2.5 + assert op.qubits == (cirq.q(0, 0),) diff --git a/cirq-google/cirq_google/serialization/op_deserializer.py b/cirq-google/cirq_google/serialization/op_deserializer.py index 44dec4f09d7..ba46cf6db90 100644 --- a/cirq-google/cirq_google/serialization/op_deserializer.py +++ b/cirq-google/cirq_google/serialization/op_deserializer.py @@ -26,18 +26,12 @@ class OpDeserializer(abc.ABC): """Generic supertype for operation deserializers. Each operation deserializer describes how to deserialize operation protos - with a particular `serialized_id` to a specific type of Cirq operation. + to a specific type of Cirq operation. """ - @property @abc.abstractmethod - def serialized_id(self) -> str: - """Returns the string identifier for the accepted serialized objects. - - This ID denotes the serialization format this deserializer consumes. For - example, one of the common deserializers converts objects with the id - 'xy' into PhasedXPowGates. - """ + def can_deserialize_proto(self, proto) -> bool: + """Whether the given operation can be serialized by this serializer.""" @abc.abstractmethod def from_proto( @@ -66,9 +60,8 @@ def from_proto( class CircuitOpDeserializer(OpDeserializer): """Describes how to serialize CircuitOperations.""" - @property - def serialized_id(self): - return 'circuit' + def can_deserialize_proto(self, proto): + return isinstance(proto, v2.program_pb2.CircuitOperation) # pragma: nocover def from_proto( self, diff --git a/cirq-google/cirq_google/serialization/op_serializer.py b/cirq-google/cirq_google/serialization/op_serializer.py index 4cd8fea0604..0ee5d213575 100644 --- a/cirq-google/cirq_google/serialization/op_serializer.py +++ b/cirq-google/cirq_google/serialization/op_serializer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar +from typing import Any, Dict, List, Optional, Union import numbers import abc @@ -23,9 +23,6 @@ from cirq_google.api import v2 from cirq_google.serialization.arg_func_langs import arg_to_proto -# Type for variables that are subclasses of ops.Gate. -Gate = TypeVar('Gate', bound=cirq.Gate) - class OpSerializer(abc.ABC): """Generic supertype for operation serializers. @@ -35,25 +32,6 @@ class OpSerializer(abc.ABC): may serialize to the same format. """ - @property - @abc.abstractmethod - def internal_type(self) -> Type: - """Returns the type that the operation contains. - - For GateOperations, this is the gate type. - For CircuitOperations, this is FrozenCircuit. - """ - - @property - @abc.abstractmethod - def serialized_id(self) -> str: - """Returns the string identifier for the resulting serialized object. - - This ID denotes the serialization format this serializer produces. For - example, one of the common serializers assigns the id 'xy' to XPowGates, - as they serialize into a format also used by YPowGates. - """ - @abc.abstractmethod def to_proto( self, @@ -63,7 +41,7 @@ def to_proto( arg_function_language: Optional[str] = '', constants: List[v2.program_pb2.Constant], raw_constants: Dict[Any, int], - ) -> Optional[v2.program_pb2.CircuitOperation]: + ) -> Optional[Union[v2.program_pb2.CircuitOperation, v2.program_pb2.Operation]]: """Converts op to proto using this serializer. If self.can_serialize_operation(op) == false, this should return None. @@ -83,33 +61,16 @@ def to_proto( the returned object. """ - @property @abc.abstractmethod - def can_serialize_predicate(self) -> Callable[[cirq.Operation], bool]: - """The method used to determine if this can serialize an operation. - - Depending on the serializer, additional checks may be required. - """ - def can_serialize_operation(self, op: cirq.Operation) -> bool: """Whether the given operation can be serialized by this serializer.""" - return self.can_serialize_predicate(op) class CircuitOpSerializer(OpSerializer): """Describes how to serialize CircuitOperations.""" - @property - def internal_type(self): - return cirq.FrozenCircuit - - @property - def serialized_id(self): - return 'circuit' - - @property - def can_serialize_predicate(self): - return lambda op: isinstance(op.untagged, cirq.CircuitOperation) + def can_serialize_operation(self, op: cirq.Operation): + return isinstance(op.untagged, cirq.CircuitOperation) def to_proto( self, diff --git a/cirq-google/cirq_google/serialization/op_serializer_test.py b/cirq-google/cirq_google/serialization/op_serializer_test.py index 0faf4fffe42..788ba4ceb7c 100644 --- a/cirq-google/cirq_google/serialization/op_serializer_test.py +++ b/cirq-google/cirq_google/serialization/op_serializer_test.py @@ -49,12 +49,6 @@ def default_circuit(): ) -def test_circuit_op_serializer_properties(): - serializer = cg.CircuitOpSerializer() - assert serializer.internal_type == cirq.FrozenCircuit - assert serializer.serialized_id == 'circuit' - - def test_can_serialize_circuit_op(): serializer = cg.CircuitOpSerializer() assert serializer.can_serialize_operation(cirq.CircuitOperation(default_circuit()))