Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bit pack of mwpf and fusion blossom decoders under multiple logical observable #873

Merged
merged 21 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion glue/sample/src/sinter/_decoding/_decoding_fusion_blossom.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ def decode_shots_bit_packed(
syndrome = fusion_blossom.SyndromePattern(syndrome_vertices=dets_sparse)
self.solver.solve(syndrome)
prediction = int(np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()]))
predictions[shot] = np.packbits(prediction, bitorder='little')
predictions[shot] = np.packbits(
np.array(list(np.binary_repr(prediction, width=self.num_obs))[::-1],dtype=np.uint8),
bitorder="little",
)
self.solver.clear()
return predictions

Expand Down
88 changes: 35 additions & 53 deletions glue/sample/src/sinter/_decoding/_decoding_mwpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def decode_shots_bit_packed(
bit_packed_detection_event_data: "np.ndarray",
) -> "np.ndarray":
num_shots = bit_packed_detection_event_data.shape[0]
predictions = np.zeros(shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8)
predictions = np.zeros(
shape=(num_shots, (self.num_obs + 7) // 8), dtype=np.uint8
)
import mwpf

for shot in range(num_shots):
Expand All @@ -58,29 +60,42 @@ def decode_shots_bit_packed(
np.bitwise_xor.reduce(self.fault_masks[self.solver.subgraph()])
)
self.solver.clear()
predictions[shot] = np.packbits(prediction, bitorder="little")
predictions[shot] = np.packbits(
np.array(
list(np.binary_repr(prediction, width=self.num_obs))[::-1],
dtype=np.uint8,
),
bitorder="little",
)
return predictions


class MwpfDecoder(Decoder):
"""Use MWPF to predict observables from detection events."""

def compile_decoder_for_dem(
def __init__(
self,
*,
dem: "stim.DetectorErrorModel",
decoder_cls: Any = None, # decoder class used to construct the MWPF decoder.
# in the Rust implementation, all of them inherits from the class of `SolverSerialPlugins`
# but just provide different plugins for optimizing the primal and/or dual solutions.
# For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only
# grows the clusters until the first valid solution appears; some more optimized solvers uses
# one or more plugins to further optimize the solution, which requires longer decoding time.
cluster_node_limit: int = 50, # The maximum number of nodes in a cluster.
cluster_node_limit: int = 50, # The maximum number of nodes in a cluster,
):
self.decoder_cls = decoder_cls
self.cluster_node_limit = cluster_node_limit
super().__init__()

def compile_decoder_for_dem(
self,
*,
dem: "stim.DetectorErrorModel",
) -> CompiledDecoder:
solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
dem,
decoder_cls=decoder_cls,
cluster_node_limit=cluster_node_limit,
decoder_cls=self.decoder_cls,
cluster_node_limit=self.cluster_node_limit,
)
return MwpfCompiledDecoder(
solver,
Expand All @@ -99,13 +114,14 @@ def decode_via_files(
dets_b8_in_path: pathlib.Path,
obs_predictions_b8_out_path: pathlib.Path,
tmp_dir: pathlib.Path,
decoder_cls: Any = None,
) -> None:
import mwpf

error_model = stim.DetectorErrorModel.from_file(dem_path)
solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
error_model, decoder_cls=decoder_cls
error_model,
decoder_cls=self.decoder_cls,
cluster_node_limit=self.cluster_node_limit,
)
num_det_bytes = math.ceil(num_dets / 8)
with open(dets_b8_in_path, "rb") as dets_in_f:
Expand Down Expand Up @@ -136,44 +152,8 @@ def decode_via_files(


class HyperUFDecoder(MwpfDecoder):
def compile_decoder_for_dem(
self, *, dem: "stim.DetectorErrorModel"
) -> CompiledDecoder:
try:
import mwpf
except ImportError as ex:
raise mwpf_import_error() from ex

return super().compile_decoder_for_dem(
dem=dem, decoder_cls=mwpf.SolverSerialUnionFind
)

def decode_via_files(
self,
*,
num_shots: int,
num_dets: int,
num_obs: int,
dem_path: pathlib.Path,
dets_b8_in_path: pathlib.Path,
obs_predictions_b8_out_path: pathlib.Path,
tmp_dir: pathlib.Path,
) -> None:
try:
import mwpf
except ImportError as ex:
raise mwpf_import_error() from ex

return super().decode_via_files(
num_shots=num_shots,
num_dets=num_dets,
num_obs=num_obs,
dem_path=dem_path,
dets_b8_in_path=dets_b8_in_path,
obs_predictions_b8_out_path=obs_predictions_b8_out_path,
tmp_dir=tmp_dir,
decoder_cls=mwpf.SolverSerialUnionFind,
)
def __init__(self):
super().__init__(decoder_cls="SolverSerialUnionFind", cluster_node_limit=0)


def iter_flatten_model(
Expand All @@ -193,16 +173,16 @@ def _helper(m: stim.DetectorErrorModel, reps: int):
_helper(instruction.body_copy(), instruction.repeat_count)
elif isinstance(instruction, stim.DemInstruction):
if instruction.type == "error":
dets: List[int] = []
frames: List[int] = []
dets: set[int] = set()
frames: set[int] = set()
t: stim.DemTarget
p = instruction.args_copy()[0]
for t in instruction.targets_copy():
if t.is_relative_detector_id():
dets.append(t.val + det_offset)
dets ^= {t.val + det_offset}
elif t.is_logical_observable_id():
frames.append(t.val)
handle_error(p, dets, frames)
frames ^= {t.val}
handle_error(p, list(dets), list(frames))
elif instruction.type == "shift_detectors":
det_offset += instruction.targets_copy()[0]
a = np.array(instruction.args_copy())
Expand Down Expand Up @@ -310,6 +290,8 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
if decoder_cls is None:
# default to the solver with highest accuracy
decoder_cls = mwpf.SolverSerialJointSingleHair
elif isinstance(decoder_cls, str):
decoder_cls = getattr(mwpf, decoder_cls)
return (
(
decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit})
Expand Down
Loading