12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
import logging
15
- from collections import defaultdict
16
- from typing import Callable , Dict , Sequence , Tuple , TYPE_CHECKING
15
+ from collections import Counter , defaultdict
16
+ from typing import Callable , Dict , Mapping , Sequence , Tuple , TYPE_CHECKING
17
17
18
18
import attrs
19
19
import networkx as nx
@@ -112,6 +112,12 @@ def __str__(self):
112
112
return f'{ self .gateset_name } counts'
113
113
114
114
115
+ def _mapping_to_counter (mapping : Mapping [float , int ]) -> Counter [float ]:
116
+ if isinstance (mapping , Counter ):
117
+ return mapping
118
+ return Counter (mapping )
119
+
120
+
115
121
@frozen (kw_only = True )
116
122
class GateCounts :
117
123
"""A data class of counts of the typical target gates in a compilation.
@@ -125,8 +131,17 @@ class GateCounts:
125
131
cswap : int = 0
126
132
and_bloq : int = 0
127
133
clifford : int = 0
128
- rotation : int = 0
129
134
measurement : int = 0
135
+ rotation_epsilons : Counter [float ] = field (factory = Counter , converter = _mapping_to_counter )
136
+
137
+ @property
138
+ def rotation (self ):
139
+ from qualtran .cirq_interop .t_complexity_protocol import TComplexity
140
+
141
+ return sum (
142
+ n_rotations * int (TComplexity .rotation_cost (eps ))
143
+ for eps , n_rotations in self .rotation_epsilons .items ()
144
+ )
130
145
131
146
def __add__ (self , other ):
132
147
if not isinstance (other , GateCounts ):
@@ -138,8 +153,8 @@ def __add__(self, other):
138
153
cswap = self .cswap + other .cswap ,
139
154
and_bloq = self .and_bloq + other .and_bloq ,
140
155
clifford = self .clifford + other .clifford ,
141
- rotation = self .rotation + other .rotation ,
142
156
measurement = self .measurement + other .measurement ,
157
+ rotation_epsilons = self .rotation_epsilons + other .rotation_epsilons ,
143
158
)
144
159
145
160
def __mul__ (self , other ):
@@ -149,8 +164,8 @@ def __mul__(self, other):
149
164
cswap = other * self .cswap ,
150
165
and_bloq = other * self .and_bloq ,
151
166
clifford = other * self .clifford ,
152
- rotation = other * self .rotation ,
153
167
measurement = other * self .measurement ,
168
+ rotation_epsilons = Counter ({k : other * v for k , v in self .rotation_epsilons .items ()}),
154
169
)
155
170
156
171
def __rmul__ (self , other ):
@@ -167,7 +182,13 @@ def __str__(self):
167
182
168
183
def asdict (self ):
169
184
d = attrs .asdict (self )
170
- return {k : v for k , v in d .items () if v > 0 }
185
+
186
+ def _keep (key , value ) -> bool :
187
+ if key == 'rotation_epsilons' :
188
+ return value
189
+ return value > 0
190
+
191
+ return {k : v for k , v in d .items () if _keep (k , v )}
171
192
172
193
def total_t_count (
173
194
self ,
@@ -232,6 +253,7 @@ class QECGatesCost(CostKey[GateCounts]):
232
253
233
254
def compute (self , bloq : 'Bloq' , get_callee_cost : Callable [['Bloq' ], GateCounts ]) -> GateCounts :
234
255
from qualtran .bloqs .basic_gates import TGate , Toffoli , TwoBitCSwap
256
+ from qualtran .bloqs .basic_gates .rotation import _HasEps
235
257
from qualtran .bloqs .mcmt .and_bloq import And
236
258
237
259
# T gates
@@ -257,7 +279,8 @@ def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], GateCounts])
257
279
return GateCounts (clifford = 1 )
258
280
259
281
if bloq_is_rotation (bloq ):
260
- return GateCounts (rotation = 1 )
282
+ assert isinstance (bloq , _HasEps )
283
+ return GateCounts (rotation_epsilons = {bloq .eps : 1 })
261
284
262
285
# Recursive case
263
286
totals = GateCounts ()
0 commit comments