Skip to content

Commit 4f1d217

Browse files
authored
Make stim.Tableau.from_stabilizers faster (#713)
- Store the growing reduction as a circuit instead of as a tableau - Measured 20x faster (140ms -> 6ms) on a 144 qubit case - Measured 100x faster (15s -> 0.13s) on a 432 qubit case
1 parent 4040fd8 commit 4f1d217

File tree

3 files changed

+100
-41
lines changed

3 files changed

+100
-41
lines changed

src/stim/stabilizers/conversions.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,13 @@ Circuit stabilizer_state_vector_to_circuit(
8383
/// ignore_noise: If the circuit contains noise channels, ignore them instead of raising an exception.
8484
/// ignore_measurement: If the circuit contains measurements, ignore them instead of raising an exception.
8585
/// ignore_reset: If the circuit contains resets, ignore them instead of raising an exception.
86+
/// inverse: The last step of the implementation is to invert the tableau. Setting this argument
87+
/// to true will skip this inversion, saving time but returning the inverse tableau.
8688
///
8789
/// Returns:
8890
/// A tableau encoding the given circuit's Clifford operation.
8991
template <size_t W>
90-
Tableau<W> circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset);
92+
Tableau<W> circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset, bool inverse = false);
9193

9294
/// Simulates the given circuit and outputs a state vector.
9395
///

src/stim/stabilizers/conversions.inl

+49-40
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ std::vector<std::vector<std::complex<float>>> tableau_to_unitary(const Tableau<W
149149
}
150150

151151
template <size_t W>
152-
Tableau<W> circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset) {
152+
Tableau<W> circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ignore_measurement, bool ignore_reset, bool inverse) {
153153
Tableau<W> result(circuit.count_qubits());
154154
TableauSimulator<W> sim(std::mt19937_64(0), circuit.count_qubits());
155155

@@ -185,7 +185,10 @@ Tableau<W> circuit_to_tableau(const Circuit &circuit, bool ignore_noise, bool ig
185185
}
186186
});
187187

188-
return sim.inv_state.inverse();
188+
if (!inverse) {
189+
return sim.inv_state.inverse();
190+
}
191+
return sim.inv_state;
189192
}
190193

191194
template <size_t W>
@@ -556,7 +559,7 @@ Tableau<W> stabilizers_to_tableau(
556559
}
557560

558561
for (size_t k1 = 0; k1 < stabilizers.size(); k1++) {
559-
for (size_t k2 = 0; k2 < stabilizers.size(); k2++) {
562+
for (size_t k2 = k1 + 1; k2 < stabilizers.size(); k2++) {
560563
if (!stabilizers[k1].ref().commutes(stabilizers[k2])) {
561564
std::stringstream ss;
562565
ss << "Some of the given stabilizers anticommute.\n";
@@ -568,44 +571,39 @@ Tableau<W> stabilizers_to_tableau(
568571
}
569572
}
570573
}
571-
Tableau<W> inverted(num_qubits);
572-
573-
PauliString<W> cur(num_qubits);
574-
std::vector<size_t> targets;
575-
while (targets.size() < num_qubits) {
576-
targets.push_back(targets.size());
577-
}
578-
auto overwrite_cur_apply_recorded = [&](const PauliString<W> &e) {
579-
PauliStringRef<W> cur_ref = cur.ref();
580-
cur.xs.clear();
581-
cur.zs.clear();
582-
cur.xs.word_range_ref(0, e.xs.num_simd_words) = e.xs;
583-
cur.zs.word_range_ref(0, e.xs.num_simd_words) = e.zs;
584-
cur.sign = e.sign;
585-
inverted.apply_within(cur_ref, targets);
586-
};
574+
Circuit elimination_instructions;
575+
PauliString<W> buf(num_qubits);
587576

588577
size_t used = 0;
589578
for (const auto &e : stabilizers) {
590-
overwrite_cur_apply_recorded(e);
579+
if (e.num_qubits == num_qubits) {
580+
buf = e;
581+
} else {
582+
buf.xs.clear();
583+
buf.zs.clear();
584+
memcpy(buf.xs.u8, e.xs.u8, e.xs.num_u8_padded());
585+
memcpy(buf.zs.u8, e.zs.u8, e.zs.num_u8_padded());
586+
buf.sign = e.sign;
587+
}
588+
buf.ref().do_circuit(elimination_instructions);
591589

592590
// Find a non-identity term in the Pauli string past the region used by other stabilizers.
593591
size_t pivot;
594592
for (pivot = used; pivot < num_qubits; pivot++) {
595-
if (cur.xs[pivot] || cur.zs[pivot]) {
593+
if (buf.xs[pivot] || buf.zs[pivot]) {
596594
break;
597595
}
598596
}
599597

600598
// Check for incompatible / redundant stabilizers.
601599
if (pivot == num_qubits) {
602-
if (cur.xs.not_zero()) {
600+
if (buf.xs.not_zero()) {
603601
throw std::invalid_argument("Some of the given stabilizers anticommute.");
604602
}
605-
if (cur.sign) {
603+
if (buf.sign) {
606604
throw std::invalid_argument("Some of the given stabilizers contradict each other.");
607605
}
608-
if (!allow_redundant && cur.zs.not_zero()) {
606+
if (!allow_redundant && buf.zs.not_zero()) {
609607
throw std::invalid_argument(
610608
"Didn't specify allow_redundant=True but one of the given stabilizers is a product of the others. "
611609
"To allow redundant stabilizers, pass the argument allow_redundant=True.");
@@ -614,32 +612,36 @@ Tableau<W> stabilizers_to_tableau(
614612
}
615613

616614
// Change pivot basis to the Z axis.
617-
if (cur.xs[pivot]) {
618-
std::string name = cur.zs[pivot] ? "H_YZ" : "H_XZ";
619-
inverted.inplace_scatter_append(GATE_DATA.at(name).tableau<W>(), {pivot});
615+
if (buf.xs[pivot]) {
616+
GateType g = buf.zs[pivot] ? GateType::H_YZ : GateType::H;
617+
GateTarget t = GateTarget::qubit(pivot);
618+
CircuitInstruction instruction{g, {}, &t};
619+
elimination_instructions.safe_append(instruction);
620+
buf.ref().do_instruction(instruction);
620621
}
621622
// Cancel other terms in Pauli string.
622623
for (size_t q = 0; q < num_qubits; q++) {
623-
int p = cur.xs[q] + cur.zs[q] * 2;
624+
int p = buf.xs[q] + buf.zs[q] * 2;
624625
if (p && q != pivot) {
625-
inverted.inplace_scatter_append(
626-
GATE_DATA.at(p == 1 ? "XCX"
627-
: p == 2 ? "XCZ"
628-
: "XCY")
629-
.tableau<W>(),
630-
{pivot, q});
626+
std::array<GateTarget, 2> targets{GateTarget::qubit(pivot), GateTarget::qubit(q)};
627+
CircuitInstruction instruction{p == 1 ? GateType::XCX : p == 2 ? GateType::XCZ : GateType::XCY, {}, targets};
628+
elimination_instructions.safe_append(instruction);
629+
buf.ref().do_instruction(instruction);
631630
}
632631
}
633632

634633
// Move pivot to diagonal.
635634
if (pivot != used) {
636-
inverted.inplace_scatter_append(GATE_DATA.at("SWAP").tableau<W>(), {pivot, used});
635+
std::array<GateTarget, 2> targets{GateTarget::qubit(pivot), GateTarget::qubit(used)};
636+
CircuitInstruction instruction{GateType::SWAP, {}, targets};
637+
elimination_instructions.safe_append(instruction);
637638
}
638639

639640
// Fix sign.
640-
overwrite_cur_apply_recorded(e);
641-
if (cur.sign) {
642-
inverted.inplace_scatter_append(GATE_DATA.at("X").tableau<W>(), {used});
641+
if (buf.sign) {
642+
GateTarget t = GateTarget::qubit(used);
643+
CircuitInstruction instruction{GateType::X, {}, &t};
644+
elimination_instructions.safe_append(instruction);
643645
}
644646

645647
used++;
@@ -653,10 +655,17 @@ Tableau<W> stabilizers_to_tableau(
653655
}
654656
}
655657

658+
if (num_qubits > 0) {
659+
// Force size of resulting tableau to be correct.
660+
GateTarget t = GateTarget::qubit(num_qubits - 1);
661+
elimination_instructions.safe_append(CircuitInstruction{GateType::X, {}, &t});
662+
elimination_instructions.safe_append(CircuitInstruction{GateType::X, {}, &t});
663+
}
664+
656665
if (invert) {
657-
return inverted;
666+
return circuit_to_tableau<W>(elimination_instructions.inverse(), false, false, false, true);
658667
}
659-
return inverted.inverse();
668+
return circuit_to_tableau<W>(elimination_instructions, false, false, false, true);
660669
}
661670

662671
} // namespace stim

src/stim/stabilizers/conversions.perf.cc

+48
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,51 @@ BENCHMARK(independent_to_disjoint_xyz_errors) {
6767
std::cout << "data dependence";
6868
}
6969
}
70+
71+
BENCHMARK(stabilizers_to_tableau) {
72+
std::vector<std::complex<float>> offsets{
73+
{1, 0},
74+
{-1, 0},
75+
{0, 1},
76+
{0, -1},
77+
{3, 6},
78+
{-6, 3},
79+
};
80+
size_t w = 24;
81+
size_t h = 12;
82+
83+
auto normalize = [&](std::complex<float> c) -> std::complex<float> {
84+
return {fmodf(c.real() + w*10, w), fmodf(c.imag() + h*10, h)};
85+
};
86+
auto q2i = [&](std::complex<float> c) -> size_t {
87+
c = normalize(c);
88+
return (int)c.real() / 2 + c.imag() * (w / 2);
89+
};
90+
91+
std::vector<stim::PauliString<64>> stabilizers;
92+
for (size_t x = 0; x < w; x++) {
93+
for (size_t y = x % 2; y < h; y += 2) {
94+
std::complex<float> s{x % 2 ? -1.0f : +1.0f, 0.0f};
95+
std::complex<float> c{(float)x, (float)y};
96+
stim::PauliString<64> ps(w * h / 2);
97+
for (const auto &offset : offsets) {
98+
size_t i = q2i(c + offset * s);
99+
if (x % 2 == 0) {
100+
ps.xs[i] = 1;
101+
} else {
102+
ps.zs[i] = 1;
103+
}
104+
}
105+
stabilizers.push_back(ps);
106+
}
107+
}
108+
109+
size_t dep = 0;
110+
benchmark_go([&]() {
111+
Tableau<64> t = stabilizers_to_tableau(stabilizers, true, true, false);
112+
dep += t.xs[0].zs[0];
113+
}).goal_millis(5);
114+
if (dep == 99999999) {
115+
std::cout << "data dependence";
116+
}
117+
}

0 commit comments

Comments
 (0)