Skip to content

Commit 61149a9

Browse files
authored
Fix circuit.explain_dem_errors not supporting all gates (#704)
- Fix a nasty miscomputed allocation size om `stim::MonotonicBuffer` - Add `stim::inplace_xor_sort` C++ helper method - Add support for `MXX`, `MYY`, `MZZ`, `HERALDED_ERASE`, `HERALDED_PAULI_CHANNEL_1`, `MPAD` to `stim::ErrorMatcher` - Add a unit test verifying `stim::ErrorMatcher` supports all gates Fixes #697
1 parent f11b4f9 commit 61149a9

9 files changed

+341
-48
lines changed

src/stim/mem/monotonic_buffer.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ struct MonotonicBuffer {
155155
return;
156156
}
157157

158-
size_t alloc_count = std::max(min_required, cur.size() << 1);
158+
size_t alloc_count = std::max(min_required + tail.size(), cur.size() << 1);
159159
if (cur.ptr_start != nullptr) {
160160
old_areas.push_back(cur);
161161
}

src/stim/mem/monotonic_buffer.test.cc

+13-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ TEST(pointer_range, equality) {
3939
ASSERT_NE(r1, r2);
4040
}
4141

42-
TEST(monotonic_buffer, x) {
42+
TEST(monotonic_buffer, append_tail) {
4343
MonotonicBuffer<int> buf;
4444
for (size_t k = 0; k < 100; k++) {
4545
buf.append_tail(k);
@@ -51,3 +51,15 @@ TEST(monotonic_buffer, x) {
5151
ASSERT_EQ(rng[k], k);
5252
}
5353
}
54+
55+
TEST(monotonic_buffer, ensure_available) {
56+
MonotonicBuffer<int> buf;
57+
buf.append_tail(std::vector<int>{1, 2, 3, 4});
58+
buf.append_tail(std::vector<int>{5, 6});
59+
buf.append_tail(std::vector<int>{7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
60+
61+
SpanRef<const int> rng = buf.commit_tail();
62+
std::vector<int> expected{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
63+
SpanRef<const int> v = expected;
64+
ASSERT_EQ(rng, v);
65+
}

src/stim/mem/sparse_xor_vec.h

+17
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,23 @@ inline T *xor_merge_sort(SpanRef<const T> sorted_in1, SpanRef<const T> sorted_in
6060
return out;
6161
}
6262

63+
template <typename T>
64+
inline SpanRef<T> inplace_xor_sort(SpanRef<T> items) {
65+
std::sort(items.begin(), items.end());
66+
size_t new_size = 0;
67+
for (size_t k = 0; k < items.size(); k++) {
68+
if (new_size > 0 && items[k] == items[new_size - 1]) {
69+
new_size--;
70+
} else {
71+
if (k != new_size) {
72+
std::swap(items[new_size], items[k]);
73+
}
74+
new_size++;
75+
}
76+
}
77+
return items.sub(0, new_size);
78+
}
79+
6380
template <typename T>
6481
bool is_subset_of_sorted(SpanRef<const T> subset, SpanRef<const T> superset) {
6582
const T *p_sub = subset.ptr_start;

src/stim/mem/sparse_xor_vec.test.cc

+19
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,22 @@ TEST(sparse_xor_vec, contains) {
171171
ASSERT_FALSE((SparseXorVec<uint32_t>{{}}).contains(0));
172172
ASSERT_FALSE((SparseXorVec<uint32_t>{{1}}).contains(0));
173173
}
174+
175+
TEST(sparse_xor_vec, inplace_xor_sort) {
176+
auto f = [](std::vector<int> v) -> std::vector<int> {
177+
SpanRef<int> s = v;
178+
auto r = inplace_xor_sort(s);
179+
v.resize(r.size());
180+
return v;
181+
};
182+
ASSERT_EQ(f({}), (std::vector<int>({})));
183+
ASSERT_EQ(f({5}), (std::vector<int>({5})));
184+
ASSERT_EQ(f({5, 5}), (std::vector<int>({})));
185+
ASSERT_EQ(f({5, 5, 5}), (std::vector<int>({5})));
186+
ASSERT_EQ(f({5, 5, 5, 5}), (std::vector<int>({})));
187+
ASSERT_EQ(f({5, 4, 5, 5}), (std::vector<int>({4, 5})));
188+
ASSERT_EQ(f({4, 5, 5, 5}), (std::vector<int>({4, 5})));
189+
ASSERT_EQ(f({5, 5, 5, 4}), (std::vector<int>({4, 5})));
190+
ASSERT_EQ(f({4, 5, 5, 4}), (std::vector<int>({})));
191+
ASSERT_EQ(f({3, 5, 5, 4}), (std::vector<int>({3, 4})));
192+
}

src/stim/simulators/error_analyzer.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ void ErrorAnalyzer::undo_DEPOLARIZE2(const CircuitInstruction &dat) {
837837

838838
void ErrorAnalyzer::undo_ELSE_CORRELATED_ERROR(const CircuitInstruction &dat) {
839839
if (accumulate_errors) {
840-
throw std::invalid_argument("Failed to analyze ELSE_CORRELATED_ERROR" + dat.str());
840+
throw std::invalid_argument("Failed to analyze ELSE_CORRELATED_ERROR: " + dat.str());
841841
}
842842
}
843843

src/stim/simulators/error_matcher.cc

+144-45
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,7 @@ ErrorMatcher::ErrorMatcher(
5959
}
6060
}
6161

62-
void ErrorMatcher::err_atom(const CircuitInstruction &effect) {
63-
assert(error_analyzer.error_class_probabilities.empty());
64-
error_analyzer.undo_gate(effect);
65-
if (error_analyzer.error_class_probabilities.empty()) {
66-
/// Maybe there were no detectors or observables nearby? Or the noise probability was zero?
67-
return;
68-
}
69-
70-
assert(error_analyzer.error_class_probabilities.size() == 1);
71-
SpanRef<const DemTarget> dem_error_terms = error_analyzer.error_class_probabilities.begin()->first;
62+
void ErrorMatcher::add_dem_error_terms(SpanRef<const DemTarget> dem_error_terms) {
7263
auto entry = output_map.find(dem_error_terms);
7364
if (!dem_error_terms.empty() && (allow_adding_new_dem_errors_to_output_map || entry != output_map.end())) {
7465
// We have a desired match! Record it.
@@ -88,6 +79,19 @@ void ErrorMatcher::err_atom(const CircuitInstruction &effect) {
8879
out[0] = std::move(new_loc);
8980
}
9081
}
82+
}
83+
84+
void ErrorMatcher::err_atom(const CircuitInstruction &effect) {
85+
assert(error_analyzer.error_class_probabilities.empty());
86+
error_analyzer.undo_gate(effect);
87+
if (error_analyzer.error_class_probabilities.empty()) {
88+
/// Maybe there were no detectors or observables nearby? Or the noise probability was zero?
89+
return;
90+
}
91+
92+
assert(error_analyzer.error_class_probabilities.size() == 1);
93+
SpanRef<const DemTarget> dem_error_terms = error_analyzer.error_class_probabilities.begin()->first;
94+
add_dem_error_terms(dem_error_terms);
9195

9296
// Restore the pristine state.
9397
error_analyzer.mono_buf.clear();
@@ -128,6 +132,58 @@ void ErrorMatcher::err_xyz(const CircuitInstruction &op, uint32_t target_flags)
128132
}
129133
}
130134

135+
void ErrorMatcher::err_heralded_pauli_channel_1(const CircuitInstruction &op) {
136+
assert(op.args.size() == 4);
137+
for (size_t k = op.targets.size(); k--;) {
138+
auto q = op.targets[k].qubit_value();
139+
cur_loc.instruction_targets.target_range_start = k;
140+
cur_loc.instruction_targets.target_range_end = k + 1;
141+
142+
cur_loc.flipped_measurement.measurement_record_index = error_analyzer.tracker.num_measurements_in_past - 1;
143+
SpanRef<const DemTarget> herald_symptoms = error_analyzer.tracker.rec_bits[error_analyzer.tracker.num_measurements_in_past - 1].range();
144+
SpanRef<const DemTarget> x_symptoms = error_analyzer.tracker.zs[q].range();
145+
SpanRef<const DemTarget> z_symptoms = error_analyzer.tracker.xs[q].range();
146+
if (op.args[0] != 0) {
147+
add_dem_error_terms(herald_symptoms);
148+
}
149+
if (op.args[1] != 0) {
150+
error_analyzer.mono_buf.append_tail(herald_symptoms);
151+
error_analyzer.mono_buf.append_tail(x_symptoms);
152+
error_analyzer.mono_buf.tail = inplace_xor_sort(error_analyzer.mono_buf.tail);
153+
resolve_paulis_into(&op.targets[k], TARGET_PAULI_X_BIT, cur_loc.flipped_pauli_product);
154+
add_dem_error_terms(error_analyzer.mono_buf.tail);
155+
cur_loc.flipped_pauli_product.clear();
156+
error_analyzer.mono_buf.discard_tail();
157+
}
158+
if (op.args[2] != 0) {
159+
error_analyzer.mono_buf.append_tail(herald_symptoms);
160+
error_analyzer.mono_buf.append_tail(x_symptoms);
161+
error_analyzer.mono_buf.append_tail(z_symptoms);
162+
error_analyzer.mono_buf.tail = inplace_xor_sort(error_analyzer.mono_buf.tail);
163+
resolve_paulis_into(&op.targets[k], TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT, cur_loc.flipped_pauli_product);
164+
add_dem_error_terms(error_analyzer.mono_buf.tail);
165+
cur_loc.flipped_pauli_product.clear();
166+
error_analyzer.mono_buf.discard_tail();
167+
}
168+
if (op.args[3] != 0) {
169+
error_analyzer.mono_buf.append_tail(herald_symptoms);
170+
error_analyzer.mono_buf.append_tail(z_symptoms);
171+
error_analyzer.mono_buf.tail = inplace_xor_sort(error_analyzer.mono_buf.tail);
172+
resolve_paulis_into(&op.targets[k], TARGET_PAULI_Z_BIT, cur_loc.flipped_pauli_product);
173+
add_dem_error_terms(error_analyzer.mono_buf.tail);
174+
cur_loc.flipped_pauli_product.clear();
175+
error_analyzer.mono_buf.discard_tail();
176+
}
177+
cur_loc.flipped_measurement.measurement_record_index = UINT64_MAX;
178+
179+
assert(error_analyzer.error_class_probabilities.empty());
180+
error_analyzer.tracker.undo_gate(op);
181+
error_analyzer.mono_buf.clear();
182+
error_analyzer.error_class_probabilities.clear();
183+
error_analyzer.flushed_reversed_model.clear();
184+
}
185+
}
186+
131187
void ErrorMatcher::err_pauli_channel_1(const CircuitInstruction &op) {
132188
const auto &a = op.args;
133189
const auto &t = op.targets;
@@ -187,12 +243,17 @@ void ErrorMatcher::err_m(const CircuitInstruction &op, uint32_t obs_mask) {
187243
const auto &t = op.targets;
188244
const auto &a = op.args;
189245

246+
bool q2 = GATE_DATA[op.gate_type].flags & GATE_TARGETS_PAIRS;
190247
size_t end = t.size();
191248
while (end > 0) {
192249
size_t start = end - 1;
193250
while (start > 0 && t[start - 1].is_combiner()) {
194251
start -= std::min(start, size_t{2});
195252
}
253+
if (q2) {
254+
start--;
255+
}
256+
196257

197258
SpanRef<const GateTarget> slice{t.begin() + start, t.begin() + end};
198259

@@ -227,48 +288,86 @@ void ErrorMatcher::rev_process_instruction(const CircuitInstruction &op) {
227288
entry->second.push_back(d);
228289
}
229290
}
291+
return;
230292
} else if (op.gate_type == GateType::SHIFT_COORDS) {
231293
error_analyzer.undo_SHIFT_COORDS(op);
232294
for (size_t k = 0; k < op.args.size(); k++) {
233295
cur_coord_offset[k] -= op.args[k];
234296
}
297+
return;
235298
} else if (!(flags & (GATE_IS_NOISY | GATE_PRODUCES_RESULTS))) {
236299
error_analyzer.undo_gate(op);
237-
} else if (op.gate_type == GateType::E || op.gate_type == GateType::ELSE_CORRELATED_ERROR) {
238-
cur_loc.instruction_targets.target_range_start = 0;
239-
cur_loc.instruction_targets.target_range_end = op.targets.size();
240-
resolve_paulis_into(op.targets, 0, cur_loc.flipped_pauli_product);
241-
err_atom(op);
242-
cur_loc.flipped_pauli_product.clear();
243-
} else if (op.gate_type == GateType::X_ERROR) {
244-
err_xyz(op, TARGET_PAULI_X_BIT);
245-
} else if (op.gate_type == GateType::Y_ERROR) {
246-
err_xyz(op, TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT);
247-
} else if (op.gate_type == GateType::Z_ERROR) {
248-
err_xyz(op, TARGET_PAULI_Z_BIT);
249-
} else if (op.gate_type == GateType::PAULI_CHANNEL_1) {
250-
err_pauli_channel_1(op);
251-
} else if (op.gate_type == GateType::DEPOLARIZE1) {
252-
float p = op.args[0];
253-
std::array<double, 3> spread{p, p, p};
254-
err_pauli_channel_1({op.gate_type, spread, op.targets});
255-
} else if (op.gate_type == GateType::PAULI_CHANNEL_2) {
256-
err_pauli_channel_2(op);
257-
} else if (op.gate_type == GateType::DEPOLARIZE2) {
258-
float p = op.args[0];
259-
std::array<double, 15> spread{p, p, p, p, p, p, p, p, p, p, p, p, p, p, p};
260-
err_pauli_channel_2({op.gate_type, spread, op.targets});
261-
} else if (op.gate_type == GateType::MPP) {
262-
err_m(op, 0);
263-
} else if (op.gate_type == GateType::MX || op.gate_type == GateType::MRX) {
264-
err_m(op, TARGET_PAULI_X_BIT);
265-
} else if (op.gate_type == GateType::MY || op.gate_type == GateType::MRY) {
266-
err_m(op, TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT);
267-
} else if (op.gate_type == GateType::M || op.gate_type == GateType::MR) {
268-
err_m(op, TARGET_PAULI_Z_BIT);
269-
} else {
270-
throw std::invalid_argument(
271-
"Not implemented in ErrorMatcher::rev_process_instruction: " + std::string(GATE_DATA[op.gate_type].name));
300+
return;
301+
}
302+
switch (op.gate_type) {
303+
case GateType::MPAD:
304+
error_analyzer.undo_gate(op);
305+
break;
306+
case GateType::E:
307+
case GateType::ELSE_CORRELATED_ERROR: {
308+
cur_loc.instruction_targets.target_range_start = 0;
309+
cur_loc.instruction_targets.target_range_end = op.targets.size();
310+
resolve_paulis_into(op.targets, 0, cur_loc.flipped_pauli_product);
311+
CircuitInstruction op2 = op;
312+
op2.gate_type = GateType::E;
313+
err_atom(op2);
314+
cur_loc.flipped_pauli_product.clear();
315+
break;
316+
} case GateType::X_ERROR:
317+
err_xyz(op, TARGET_PAULI_X_BIT);
318+
break;
319+
case GateType::Y_ERROR:
320+
err_xyz(op, TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT);
321+
break;
322+
case GateType::Z_ERROR:
323+
err_xyz(op, TARGET_PAULI_Z_BIT);
324+
break;
325+
case GateType::PAULI_CHANNEL_1:
326+
err_pauli_channel_1(op);
327+
break;
328+
case GateType::HERALDED_PAULI_CHANNEL_1:
329+
err_heralded_pauli_channel_1(op);
330+
break;
331+
case GateType::HERALDED_ERASE: {
332+
float p = op.args[0] / 4;
333+
std::array<double, 4> spread{p, p, p, p};
334+
err_heralded_pauli_channel_1({op.gate_type, spread, op.targets});
335+
break;
336+
} case GateType::DEPOLARIZE1: {
337+
float p = op.args[0];
338+
std::array<double, 3> spread{p, p, p};
339+
err_pauli_channel_1({op.gate_type, spread, op.targets});
340+
break;
341+
} case GateType::PAULI_CHANNEL_2:
342+
err_pauli_channel_2(op);
343+
break;
344+
case GateType::DEPOLARIZE2: {
345+
float p = op.args[0];
346+
std::array<double, 15> spread{p, p, p, p, p, p, p, p, p, p, p, p, p, p, p};
347+
err_pauli_channel_2({op.gate_type, spread, op.targets});
348+
break;
349+
}
350+
case GateType::MPP:
351+
err_m(op, 0);
352+
break;
353+
case GateType::MX:
354+
case GateType::MRX:
355+
case GateType::MXX:
356+
err_m(op, TARGET_PAULI_X_BIT);
357+
break;
358+
case GateType::MY:
359+
case GateType::MRY:
360+
case GateType::MYY:
361+
err_m(op, TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT);
362+
break;
363+
case GateType::M:
364+
case GateType::MR:
365+
case GateType::MZZ:
366+
err_m(op, TARGET_PAULI_Z_BIT);
367+
break;
368+
default:
369+
throw std::invalid_argument(
370+
"Not implemented in ErrorMatcher::rev_process_instruction: " + std::string(GATE_DATA[op.gate_type].name));
272371
}
273372
}
274373

src/stim/simulators/error_matcher.h

+4
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ struct ErrorMatcher {
7878
void err_atom(const CircuitInstruction &effect);
7979
/// Processes operations with X, Y, Z errors on each target.
8080
void err_pauli_channel_1(const CircuitInstruction &op);
81+
/// Processes operations with M, X, Y, Z errors on each target.
82+
void err_heralded_pauli_channel_1(const CircuitInstruction &op);
8183
/// Processes operations with 15 two-qubit Pauli product errors on each target pair.
8284
void err_pauli_channel_2(const CircuitInstruction &op);
8385
/// Processes measurement operations.
@@ -88,6 +90,8 @@ struct ErrorMatcher {
8890
void rev_process_instruction(const CircuitInstruction &op);
8991
/// Processes entire circuits.
9092
void rev_process_circuit(uint64_t reps, const Circuit &block);
93+
94+
void add_dem_error_terms(SpanRef<const DemTarget> dem_error_terms);
9195
};
9296

9397
} // namespace stim

0 commit comments

Comments
 (0)