Skip to content

Commit 1cdf45f

Browse files
authored
Speed up stim.Tableau.from_stabilizers another 10x (#714)
- 144 qubit case went from 5000 micros to 500 micros
- 576 qubit case went from 2500 millis to 200 millis
1 parent 4f1d217 commit 1cdf45f

File tree

2 files changed

+146
-27
lines changed

2 files changed

+146
-27
lines changed

src/stim/stabilizers/conversions.inl

+95-25
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,18 @@ Tableau<W> stabilizers_to_tableau(
558558
num_qubits = std::max(num_qubits, e.num_qubits);
559559
}
560560

561+
simd_bit_table<W> buf_xs(num_qubits, stabilizers.size());
562+
simd_bit_table<W> buf_zs(num_qubits, stabilizers.size());
563+
simd_bits<W> buf_signs(stabilizers.size());
564+
simd_bits<W> buf_workspace(stabilizers.size());
565+
for (size_t k = 0; k < stabilizers.size(); k++) {
566+
memcpy(buf_xs[k].u8, stabilizers[k].xs.u8, stabilizers[k].xs.num_u8_padded());
567+
memcpy(buf_zs[k].u8, stabilizers[k].zs.u8, stabilizers[k].zs.num_u8_padded());
568+
buf_signs[k] = stabilizers[k].sign;
569+
}
570+
buf_xs = buf_xs.transposed();
571+
buf_zs = buf_zs.transposed();
572+
561573
for (size_t k1 = 0; k1 < stabilizers.size(); k1++) {
562574
for (size_t k2 = k1 + 1; k2 < stabilizers.size(); k2++) {
563575
if (!stabilizers[k1].ref().commutes(stabilizers[k2])) {
@@ -572,38 +584,28 @@ Tableau<W> stabilizers_to_tableau(
572584
}
573585
}
574586
Circuit elimination_instructions;
575-
PauliString<W> buf(num_qubits);
576587

577588
size_t used = 0;
578-
for (const auto &e : stabilizers) {
579-
if (e.num_qubits == num_qubits) {
580-
buf = e;
581-
} else {
582-
buf.xs.clear();
583-
buf.zs.clear();
584-
memcpy(buf.xs.u8, e.xs.u8, e.xs.num_u8_padded());
585-
memcpy(buf.zs.u8, e.zs.u8, e.zs.num_u8_padded());
586-
buf.sign = e.sign;
587-
}
588-
buf.ref().do_circuit(elimination_instructions);
589-
589+
for (size_t k = 0; k < stabilizers.size(); k++) {
590590
// Find a non-identity term in the Pauli string past the region used by other stabilizers.
591591
size_t pivot;
592592
for (pivot = used; pivot < num_qubits; pivot++) {
593-
if (buf.xs[pivot] || buf.zs[pivot]) {
593+
if (buf_xs[pivot][k] || buf_zs[pivot][k]) {
594594
break;
595595
}
596596
}
597597

598598
// Check for incompatible / redundant stabilizers.
599599
if (pivot == num_qubits) {
600-
if (buf.xs.not_zero()) {
601-
throw std::invalid_argument("Some of the given stabilizers anticommute.");
600+
for (size_t q = 0; q < num_qubits; q++) {
601+
if (buf_xs[q][k]) {
602+
throw std::invalid_argument("Some of the given stabilizers anticommute.");
603+
}
602604
}
603-
if (buf.sign) {
605+
if (buf_signs[k]) {
604606
throw std::invalid_argument("Some of the given stabilizers contradict each other.");
605607
}
606-
if (!allow_redundant && buf.zs.not_zero()) {
608+
if (!allow_redundant) {
607609
throw std::invalid_argument(
608610
"Didn't specify allow_redundant=True but one of the given stabilizers is a product of the others. "
609611
"To allow redundant stabilizers, pass the argument allow_redundant=True.");
@@ -612,21 +614,86 @@ Tableau<W> stabilizers_to_tableau(
612614
}
613615

614616
// Change pivot basis to the Z axis.
615-
if (buf.xs[pivot]) {
616-
GateType g = buf.zs[pivot] ? GateType::H_YZ : GateType::H;
617+
if (buf_xs[pivot][k]) {
618+
GateType g = buf_zs[pivot][k] ? GateType::H_YZ : GateType::H;
617619
GateTarget t = GateTarget::qubit(pivot);
618620
CircuitInstruction instruction{g, {}, &t};
619621
elimination_instructions.safe_append(instruction);
620-
buf.ref().do_instruction(instruction);
622+
size_t q = pivot;
623+
switch (g) {
624+
case GateType::H_YZ:
625+
buf_xs[q] ^= buf_zs[q];
626+
buf_workspace = buf_zs[q];
627+
buf_workspace.invert_bits();
628+
buf_workspace &= buf_xs[q];
629+
buf_signs ^= buf_workspace;
630+
break;
631+
case GateType::H:
632+
buf_xs[q].swap_with(buf_zs[q]);
633+
buf_workspace = buf_zs[q];
634+
buf_workspace &= buf_xs[q];
635+
buf_signs ^= buf_workspace;
636+
break;
637+
default:
638+
throw std::invalid_argument("Unrecognized gate type.");
639+
}
621640
}
641+
622642
// Cancel other terms in Pauli string.
623643
for (size_t q = 0; q < num_qubits; q++) {
624-
int p = buf.xs[q] + buf.zs[q] * 2;
644+
int p = buf_xs[q][k] + buf_zs[q][k] * 2;
625645
if (p && q != pivot) {
626646
std::array<GateTarget, 2> targets{GateTarget::qubit(pivot), GateTarget::qubit(q)};
627-
CircuitInstruction instruction{p == 1 ? GateType::XCX : p == 2 ? GateType::XCZ : GateType::XCY, {}, targets};
647+
GateType g = p == 1 ? GateType::XCX : p == 2 ? GateType::XCZ : GateType::XCY;
648+
CircuitInstruction instruction{g, {}, targets};
628649
elimination_instructions.safe_append(instruction);
629-
buf.ref().do_instruction(instruction);
650+
size_t q1 = targets[0].qubit_value();
651+
size_t q2 = targets[1].qubit_value();
652+
simd_bits_range_ref<W> x1 = buf_xs[q1];
653+
simd_bits_range_ref<W> z1 = buf_zs[q1];
654+
simd_bits_range_ref<W> x2 = buf_xs[q2];
655+
simd_bits_range_ref<W> z2 = buf_zs[q2];
656+
switch (g) {
657+
case GateType::XCX:
658+
buf_workspace = x1;
659+
buf_workspace ^= x2;
660+
buf_workspace &= z1;
661+
buf_workspace &= z2;
662+
buf_signs ^= buf_workspace;
663+
x1 ^= z2;
664+
x2 ^= z1;
665+
break;
666+
case GateType::XCY:
667+
x1 ^= x2;
668+
x1 ^= z2;
669+
x2 ^= z1;
670+
z2 ^= z1;
671+
buf_workspace = x1;
672+
buf_workspace |= x2;
673+
buf_workspace.invert_bits();
674+
buf_workspace &= z1;
675+
buf_workspace &= z2;
676+
buf_signs ^= buf_workspace;
677+
buf_workspace = z2;
678+
buf_workspace.invert_bits();
679+
buf_workspace &= z1;
680+
buf_workspace &= x1;
681+
buf_workspace &= x2;
682+
buf_signs ^= buf_workspace;
683+
break;
684+
case GateType::XCZ:
685+
z2 ^= z1;
686+
x1 ^= x2;
687+
buf_workspace = z2;
688+
buf_workspace ^= x1;
689+
buf_workspace.invert_bits();
690+
buf_workspace &= x2;
691+
buf_workspace &= z1;
692+
buf_signs ^= buf_workspace;
693+
break;
694+
default:
695+
throw std::invalid_argument("Unrecognized gate type.");
696+
}
630697
}
631698
}
632699

@@ -635,13 +702,16 @@ Tableau<W> stabilizers_to_tableau(
635702
std::array<GateTarget, 2> targets{GateTarget::qubit(pivot), GateTarget::qubit(used)};
636703
CircuitInstruction instruction{GateType::SWAP, {}, targets};
637704
elimination_instructions.safe_append(instruction);
705+
buf_xs[pivot].swap_with(buf_xs[used]);
706+
buf_zs[pivot].swap_with(buf_zs[used]);
638707
}
639708

640709
// Fix sign.
641-
if (buf.sign) {
710+
if (buf_signs[k]) {
642711
GateTarget t = GateTarget::qubit(used);
643712
CircuitInstruction instruction{GateType::X, {}, &t};
644713
elimination_instructions.safe_append(instruction);
714+
buf_signs ^= buf_zs[used];
645715
}
646716

647717
used++;

src/stim/stabilizers/conversions.perf.cc

+51-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ BENCHMARK(independent_to_disjoint_xyz_errors) {
6868
}
6969
}
7070

71-
BENCHMARK(stabilizers_to_tableau) {
71+
BENCHMARK(stabilizers_to_tableau_144) {
7272
std::vector<std::complex<float>> offsets{
7373
{1, 0},
7474
{-1, 0},
@@ -110,7 +110,56 @@ BENCHMARK(stabilizers_to_tableau) {
110110
benchmark_go([&]() {
111111
Tableau<64> t = stabilizers_to_tableau(stabilizers, true, true, false);
112112
dep += t.xs[0].zs[0];
113-
}).goal_millis(5);
113+
}).goal_micros(500);
114+
if (dep == 99999999) {
115+
std::cout << "data dependence";
116+
}
117+
}
118+
119+
120+
BENCHMARK(stabilizers_to_tableau_576) {
121+
std::vector<std::complex<float>> offsets{
122+
{1, 0},
123+
{-1, 0},
124+
{0, 1},
125+
{0, -1},
126+
{3, 6},
127+
{-6, 3},
128+
};
129+
size_t w = 24*4;
130+
size_t h = 12*4;
131+
132+
auto normalize = [&](std::complex<float> c) -> std::complex<float> {
133+
return {fmodf(c.real() + w*10, w), fmodf(c.imag() + h*10, h)};
134+
};
135+
auto q2i = [&](std::complex<float> c) -> size_t {
136+
c = normalize(c);
137+
return (int)c.real() / 2 + c.imag() * (w / 2);
138+
};
139+
140+
std::vector<stim::PauliString<64>> stabilizers;
141+
for (size_t x = 0; x < w; x++) {
142+
for (size_t y = x % 2; y < h; y += 2) {
143+
std::complex<float> s{x % 2 ? -1.0f : +1.0f, 0.0f};
144+
std::complex<float> c{(float)x, (float)y};
145+
stim::PauliString<64> ps(w * h / 2);
146+
for (const auto &offset : offsets) {
147+
size_t i = q2i(c + offset * s);
148+
if (x % 2 == 0) {
149+
ps.xs[i] = 1;
150+
} else {
151+
ps.zs[i] = 1;
152+
}
153+
}
154+
stabilizers.push_back(ps);
155+
}
156+
}
157+
158+
size_t dep = 0;
159+
benchmark_go([&]() {
160+
Tableau<64> t = stabilizers_to_tableau(stabilizers, true, true, false);
161+
dep += t.xs[0].zs[0];
162+
}).goal_millis(200);
114163
if (dep == 99999999) {
115164
std::cout << "data dependence";
116165
}

0 commit comments

Comments
 (0)