Skip to content

Commit 9811a6a

Browse files
authored
Fix bit pack of mwpf and fusion blossom decoders under multiple logical observable (#873)
This PR fixed two bugs in MWPF decoder ## 1. Supporting decomposed detector error model While MWPF expects a decoding hypergraph, the input detector error model from sinter is by default decomposed. The decomposed DEM may contain the same detector or logical observable multiple times, which is not considered by the previous implementation. The previous implementation assumes that each detector and logical observable only appears once, thus, I used ```python frames: List[int] = [] ... frames.append(t.val) ``` However, this no longer works if the same frame appears in multiple decomposed parts. In this case, the DEM actually means that "the hyperedge contributes to the logical observable iff count(frame) % 2 == 1". This is fixed by ```python frames: set[int] = set() ... frames ^= { t.val } ``` ## 2. Supporting multiple logical observables Although a previous [PR #864](#864) has fixed the panic issue when multiple logical observables are encountered, the returned value is actually problematic and causes significantly higher logical error rate. The previous implementation converts a `int` typed bitmask to a bitpacked value using `np.packbits(prediction, bitorder="little")`. However, this doesn't work for more than one logical observables. For example, if I define an observable using `OBSERVABLE_INCLUDE(2) ...`, supposedly the bitpacked value should be `[4]` because $1<<2 = 4$. However, `np.packbits(4, bitorder="little") = [1]`, which is incorrect. The correct procedure is first generate the binary representation with `self.num_obs` bits using `np.binary_repr(prediction, width=self.num_obs)`, in this case, `'100'`, and then revert the order of the bits to `['0', '0', '1']`, and then run the packbits which gives us the correct value `[4]`. The full code is below: ```python predictions[shot] = np.packbits( np.array(list(np.binary_repr(prediction, width=self.num_obs))[::-1],dtype=np.uint8), bitorder="little", ) ```
1 parent 6afad14 commit 9811a6a

File tree

2 files changed

+39
-54
lines changed

2 files changed

+39
-54
lines changed

glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ def decode_shots_bit_packed(
3131
syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse)
3232
self.solver.solve(syndrome)
3333
prediction = int(np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]))
34-
predictions[shot] = np.packbits(prediction, bitorder='little')
34+
predictions[shot] = np.packbits(
35+
np.array(list(np.binary_repr(prediction, width=self.num_obs))[::-1],dtype=np.uint8),
36+
bitorder="little",
37+
)
3538
self.solver.clear()
3639
return predictions
3740

glue/sample/src/sinter/_decoding/_decoding_mwpf.py

+35-53
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ def decode_shots_bit_packed(
3838
bit_packed_detection_event_data: "np.ndarray",
3939
) -> "np.ndarray":
4040
num_shots = bit_packed_detection_event_data.shape[0]
41-
predictions = np.zeros(shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8)
41+
predictions = np.zeros(
42+
shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8
43+
)
4244
import mwpf
4345

4446
for shot in range(num_shots):
@@ -58,29 +60,42 @@ def decode_shots_bit_packed(
5860
np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()])
5961
)
6062
self.solver.clear()
61-
predictions[shot] = np.packbits(prediction, bitorder="little")
63+
predictions[shot] = np.packbits(
64+
np.array(
65+
list(np.binary_repr(prediction, width=self.num_obs))[::-1],
66+
dtype=np.uint8,
67+
),
68+
bitorder="little",
69+
)
6270
return predictions
6371

6472

6573
class MwpfDecoder(Decoder):
6674
"""Use MWPF to predict observables from detection events."""
6775

68-
def compile_decoder_for_dem(
76+
def __init__(
6977
self,
70-
*,
71-
dem: "stim.DetectorErrorModel",
7278
decoder_cls: Any = None, # decoder class used to construct the MWPF decoder.
7379
# in the Rust implementation, all of them inherits from the class of `SolverSerialPlugins`
7480
# but just provide different plugins for optimizing the primal and/or dual solutions.
7581
# For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only
7682
# grows the clusters until the first valid solution appears; some more optimized solvers uses
7783
# one or more plugins to further optimize the solution, which requires longer decoding time.
78-
cluster_node_limit: int = 50, # The maximum number of nodes in a cluster.
84+
cluster_node_limit: int = 50, # The maximum number of nodes in a cluster,
85+
):
86+
self.decoder_cls = decoder_cls
87+
self.cluster_node_limit = cluster_node_limit
88+
super().__init__()
89+
90+
def compile_decoder_for_dem(
91+
self,
92+
*,
93+
dem: "stim.DetectorErrorModel",
7994
) -> CompiledDecoder:
8095
solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
8196
dem,
82-
decoder_cls=decoder_cls,
83-
cluster_node_limit=cluster_node_limit,
97+
decoder_cls=self.decoder_cls,
98+
cluster_node_limit=self.cluster_node_limit,
8499
)
85100
return MwpfCompiledDecoder(
86101
solver,
@@ -99,13 +114,14 @@ def decode_via_files(
99114
dets_b8_in_path: pathlib.Path,
100115
obs_predictions_b8_out_path: pathlib.Path,
101116
tmp_dir: pathlib.Path,
102-
decoder_cls: Any = None,
103117
) -> None:
104118
import mwpf
105119

106120
error_model = stim.DetectorErrorModel.from_file(dem_path)
107121
solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
108-
error_model, decoder_cls=decoder_cls
122+
error_model,
123+
decoder_cls=self.decoder_cls,
124+
cluster_node_limit=self.cluster_node_limit,
109125
)
110126
num_det_bytes = math.ceil(num_dets / 8)
111127
with open(dets_b8_in_path, "rb") as dets_in_f:
@@ -136,44 +152,8 @@ def decode_via_files(
136152

137153

138154
class HyperUFDecoder(MwpfDecoder):
139-
def compile_decoder_for_dem(
140-
self, *, dem: "stim.DetectorErrorModel"
141-
) -> CompiledDecoder:
142-
try:
143-
import mwpf
144-
except ImportError as ex:
145-
raise mwpf_import_error() from ex
146-
147-
return super().compile_decoder_for_dem(
148-
dem=dem, decoder_cls=mwpf.SolverSerialUnionFind
149-
)
150-
151-
def decode_via_files(
152-
self,
153-
*,
154-
num_shots: int,
155-
num_dets: int,
156-
num_obs: int,
157-
dem_path: pathlib.Path,
158-
dets_b8_in_path: pathlib.Path,
159-
obs_predictions_b8_out_path: pathlib.Path,
160-
tmp_dir: pathlib.Path,
161-
) -> None:
162-
try:
163-
import mwpf
164-
except ImportError as ex:
165-
raise mwpf_import_error() from ex
166-
167-
return super().decode_via_files(
168-
num_shots=num_shots,
169-
num_dets=num_dets,
170-
num_obs=num_obs,
171-
dem_path=dem_path,
172-
dets_b8_in_path=dets_b8_in_path,
173-
obs_predictions_b8_out_path=obs_predictions_b8_out_path,
174-
tmp_dir=tmp_dir,
175-
decoder_cls=mwpf.SolverSerialUnionFind,
176-
)
155+
def __init__(self):
156+
super().__init__(decoder_cls="SolverSerialUnionFind", cluster_node_limit=0)
177157

178158

179159
def iter_flatten_model(
@@ -193,16 +173,16 @@ def _helper(m: stim.DetectorErrorModel, reps: int):
193173
_helper(instruction.body_copy(), instruction.repeat_count)
194174
elif isinstance(instruction, stim.DemInstruction):
195175
if instruction.type == "error":
196-
dets: List[int] = []
197-
frames: List[int] = []
176+
dets: set[int] = set()
177+
frames: set[int] = set()
198178
t: stim.DemTarget
199179
p = instruction.args_copy()[0]
200180
for t in instruction.targets_copy():
201181
if t.is_relative_detector_id():
202-
dets.append(t.val + det_offset)
182+
dets ^= {t.val + det_offset}
203183
elif t.is_logical_observable_id():
204-
frames.append(t.val)
205-
handle_error(p, dets, frames)
184+
frames ^= {t.val}
185+
handle_error(p, list(dets), list(frames))
206186
elif instruction.type == "shift_detectors":
207187
det_offset += instruction.targets_copy()[0]
208188
a = np.array(instruction.args_copy())
@@ -310,6 +290,8 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
310290
if decoder_cls is None:
311291
# default to the solver with highest accuracy
312292
decoder_cls = mwpf.SolverSerialJointSingleHair
293+
elif isinstance(decoder_cls, str):
294+
decoder_cls = getattr(mwpf, decoder_cls)
313295
return (
314296
(
315297
decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit})

0 commit comments

Comments
 (0)