Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,14 @@
gradient dialect and the `lower-gradients` compilation stage.
[(#2241)](https://github.com/PennyLaneAI/catalyst/pull/2241)

* Added support for PPRs to the :func:`~.passes.merge_rotations` pass to merge PPRs with
equivalent angles, and cancelling of PPRs with opposite angles, or angles
that sum to identity. Also supports conditions on PPRs, merging when conditions are
identical and not merging otherwise.
[(#2224)](https://github.com/PennyLaneAI/catalyst/pull/2224)
* Added support for PPRs and arbitrary angle PPRs to the :func:`~.passes.merge_rotations` pass.
This pass now merges PPRs with equivalent angles, and cancels PPRs with opposite angles, or
angles that sum to identity when the angles are known. The pass also supports conditions on PPRs,
merging when conditions are identical and not merging otherwise.
[(#2224)](https://github.com/PennyLaneAI/catalyst/pull/2224)
[(#2245)](https://github.com/PennyLaneAI/catalyst/pull/2245)
[(#2254)](https://github.com/PennyLaneAI/catalyst/pull/2254)
[(#2258)](https://github.com/PennyLaneAI/catalyst/pull/2258)


* Refactor QEC tablegen files to separate QEC operations into a new `QECOp.td` file
Expand Down
27 changes: 27 additions & 0 deletions frontend/test/pytest/test_peephole_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,33 @@ def circuit():
assert 'qec.ppr ["X", "Y", "Z"](2)' in ir_opt


@pytest.mark.usefixtures("use_capture")
def test_merge_rotation_arbitrary_angle_ppr():
"""Test that the merge_rotation pass correctly merges arbtirary angle PPRs."""

my_pipeline = [("pipe", ["quantum-compilation-stage"])]

@qml.qjit(pipelines=my_pipeline, target="mlir")
def test_merge_rotation_ppr_workflow():
@qml.transforms.merge_rotations
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x, y):
qml.PauliRot(x, pauli_word="ZY", wires=[0, 1])
qml.PauliRot(y, pauli_word="ZY", wires=[0, 1])

return circuit(2.6, 0.3)

ir = test_merge_rotation_ppr_workflow.mlir
ir_opt = test_merge_rotation_ppr_workflow.mlir_opt

assert 'transform.apply_registered_pass "merge-rotations"' in ir
assert "qec.ppr.arbitrary" in ir
assert "arith.addf" not in ir

assert "arith.addf" in ir_opt
assert 'qec.ppr.arbitrary ["Z", "Y"]' in ir_opt


def test_clifford_to_ppm():

pipe = [("pipe", ["quantum-compilation-stage"])]
Expand Down
124 changes: 102 additions & 22 deletions mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ using namespace mlir;
using namespace catalyst::quantum;
using namespace catalyst::qec;

static const mlir::StringSet<> fixedRotationsAndPhaseShiftsSet = {
static const StringSet<> fixedRotationsAndPhaseShiftsSet = {
"RX", "RY", "RZ", "PhaseShift", "CRX", "CRY", "CRZ", "ControlledPhaseShift"};
static const mlir::StringSet<> arbitraryRotationsSet = {"Rot", "CRot"};
static const StringSet<> arbitraryRotationsSet = {"Rot", "CRot"};

namespace {

// convertOpParamsToValues: helper function for extracting CustomOp parameters as mlir::Values
SmallVector<mlir::Value> convertOpParamsToValues(CustomOp &op, PatternRewriter &rewriter)
// convertOpParamsToValues: helper function for extracting CustomOp parameters as Values
SmallVector<Value> convertOpParamsToValues(CustomOp &op, PatternRewriter &rewriter)
{
SmallVector<mlir::Value> values;
SmallVector<Value> values;
auto params = op.getParams();
for (auto param : params) {
values.push_back(param);
Expand All @@ -53,7 +53,7 @@ SmallVector<mlir::Value> convertOpParamsToValues(CustomOp &op, PatternRewriter &
// getStaticValuesOrNothing: helper function for extracting Rot or CRot parameters as:
// - doubles, in case they are constant
// - std::nullopt, otherwise
std::array<std::optional<double>, 3> getStaticValuesOrNothing(const SmallVector<mlir::Value> values)
std::array<std::optional<double>, 3> getStaticValuesOrNothing(const SmallVector<Value> values)
{
assert(values.size() == 3 && "found Rot or CRot operation should have exactly 3 parameters");
auto staticValues = std::array<std::optional<double>, 3>{};
Expand Down Expand Up @@ -88,15 +88,14 @@ struct MergeRotationsRewritePattern : public OpRewritePattern<OpType> {
ValueRange parentInCtrlValues = parentOp.getInCtrlValues();

// Extract parameters of the op and its parent,
// promoting the parameters to mlir::Values if necessary
// promoting the parameters to Values if necessary
auto parentParams = convertOpParamsToValues(parentOp, rewriter);
auto params = convertOpParamsToValues(op, rewriter);

auto loc = op.getLoc();
SmallVector<mlir::Value> sumParams;
SmallVector<Value> sumParams;
for (auto [param, parentParam] : llvm::zip(params, parentParams)) {
mlir::Value sumParam =
rewriter.create<arith::AddFOp>(loc, parentParam, param).getResult();
Value sumParam = rewriter.create<arith::AddFOp>(loc, parentParam, param).getResult();
sumParams.push_back(sumParam);
}
auto mergeOp = rewriter.create<CustomOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams,
Expand All @@ -122,25 +121,25 @@ struct MergeRotationsRewritePattern : public OpRewritePattern<OpType> {
ValueRange parentInCtrlValues = parentOp.getInCtrlValues();

// Extract parameters of the op and its parent,
// promoting the parameters to mlir::Values if necessary
// promoting the parameters to Values if necessary
auto parentParams = convertOpParamsToValues(parentOp, rewriter);
auto params = convertOpParamsToValues(op, rewriter);

// Parent params are ϕ1, θ1, and ω1
// Params are ϕ2, θ2, and ω2
mlir::Value phi1 = parentParams[0];
mlir::Value theta1 = parentParams[1];
mlir::Value omega1 = parentParams[2];
mlir::Value phi2 = params[0];
mlir::Value theta2 = params[1];
mlir::Value omega2 = params[2];
Value phi1 = parentParams[0];
Value theta1 = parentParams[1];
Value omega1 = parentParams[2];
Value phi2 = params[0];
Value theta2 = params[1];
Value omega2 = params[2];

auto [phi1Opt, theta1Opt, omega1Opt] = getStaticValuesOrNothing(parentParams);
auto [phi2Opt, theta2Opt, omega2Opt] = getStaticValuesOrNothing(params);

mlir::Value phiF;
mlir::Value thetaF;
mlir::Value omegaF;
Value phiF;
Value thetaF;
Value omegaF;

// TODO: should we use an epsilon for comparing doubles here?
bool omega1IsZero = omega1Opt.has_value() && omega1Opt.value() == 0.0;
Expand Down Expand Up @@ -289,7 +288,7 @@ struct MergeRotationsRewritePattern : public OpRewritePattern<OpType> {
omegaF = rewriter.create<arith::SubFOp>(loc, alphaF, betaF);
}

auto sumParams = SmallVector<mlir::Value>{phiF, thetaF, omegaF};
auto sumParams = SmallVector<Value>{phiF, thetaF, omegaF};
auto mergeOp = rewriter.create<CustomOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams,
parentInQubits, op.getGateName(), false,
parentInCtrlQubits, parentInCtrlValues);
Expand Down Expand Up @@ -396,6 +395,86 @@ struct MergePPRRewritePattern : public OpRewritePattern<PPRotationOp> {
}
};

struct MergePPRArbitraryRewritePattern : public OpRewritePattern<PPRotationArbitraryOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(PPRotationArbitraryOp op,
PatternRewriter &rewriter) const override
{
ValueRange opInQubits = op.getInQubits();

Operation *definingOp = opInQubits[0].getDefiningOp();
if (!definingOp) {
return failure();
}

auto parentOp = dyn_cast<PPRotationArbitraryOp>(definingOp);
if (!parentOp) {
return failure();
}

// verify that parentOp is parent of all qubits
for (mlir::Value qubit : opInQubits) {
if (qubit.getDefiningOp() != parentOp) {
return failure();
}
}
ValueRange parentOpOutQubits = parentOp.getOutQubits();
if (parentOpOutQubits.size() != opInQubits.size()) {
return failure();
}

// When two rotations have permuted Pauli strings, we can still merge them, we just need to
// correctly re-map the inputs. This map stores the index of a qubit in parentOp's out
// qubits at the index it appears in op's in qubits.
SmallVector<unsigned> inverse_permutation;
for (auto qubit : opInQubits) {
inverse_permutation.push_back(cast<OpResult>(qubit).getResultNumber());
}

// check Pauli + qubit pairings
ArrayAttr opPauliProduct = op.getPauliProduct();
ArrayAttr parentOpPauliProduct = parentOp.getPauliProduct();
for (size_t i = 0; i < opInQubits.size(); i++) {
if (opPauliProduct[i] != parentOpPauliProduct[inverse_permutation[i]]) {
return failure();
}
}

// check same conditionals
mlir::Value opCondition = op.getCondition();
if (opCondition != parentOp.getCondition()) {
return failure();
}

Location loc = op.getLoc();

mlir::Value opRotation = op.getArbitraryAngle();
mlir::Value parentOpRotation = parentOp.getArbitraryAngle();
auto newAngleOp =
rewriter.create<arith::AddFOp>(loc, opRotation, parentOpRotation).getResult();

// We need to construct the Pauli string + inQubits for new op. The simplest way to ensure
// that permuted PPRs can merge correctly is to maintain output qubits order and permute
// input qubits
ValueRange parentOpInQubits = parentOp.getInQubits();
SmallVector<mlir::Value> newInQubits;
for (size_t i = 0; i < parentOpInQubits.size(); i++) {
newInQubits.push_back(parentOpInQubits[inverse_permutation[i]]);
}

auto mergeOp = rewriter.create<PPRotationArbitraryOp>(loc, parentOpOutQubits.getTypes(),
opPauliProduct, newAngleOp,
newInQubits, opCondition);

// replace and erase old ops
rewriter.replaceOp(op, mergeOp);
rewriter.eraseOp(parentOp);

return success();
}
};

struct MergeMultiRZRewritePattern : public OpRewritePattern<MultiRZOp> {
using OpRewritePattern<MultiRZOp>::OpRewritePattern;

Expand All @@ -421,7 +500,7 @@ struct MergeMultiRZRewritePattern : public OpRewritePattern<MultiRZOp> {
auto parentTheta = parentOp.getTheta();
auto theta = op.getTheta();

mlir::Value sumParam = rewriter.create<arith::AddFOp>(loc, parentTheta, theta).getResult();
Value sumParam = rewriter.create<arith::AddFOp>(loc, parentTheta, theta).getResult();

auto mergeOp = rewriter.create<MultiRZOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParam,
parentInQubits, nullptr, parentInCtrlQubits,
Expand All @@ -442,6 +521,7 @@ void populateMergeRotationsPatterns(RewritePatternSet &patterns)
patterns.add<MergeRotationsRewritePattern<CustomOp, CustomOp>>(patterns.getContext(), 1);
patterns.add<MergeMultiRZRewritePattern>(patterns.getContext(), 1);
patterns.add<MergePPRRewritePattern>(patterns.getContext(), 1);
patterns.add<MergePPRArbitraryRewritePattern>(patterns.getContext(), 1);
}

} // namespace quantum
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Quantum/Transforms/VerifyParentGateAnalysis.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
#include "Catalyst/IR/CatalystDialect.h"
#include "Quantum/IR/QuantumOps.h"

using namespace llvm;
using namespace mlir;
using namespace catalyst;

Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Quantum/Transforms/merge_rotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ struct MergeRotationsPass : impl::MergeRotationsPassBase<MergeRotationsPass> {
&getContext());
catalyst::qec::PPRotationOp::getCanonicalizationPatterns(patternsCanonicalization,
&getContext());
catalyst::qec::PPRotationArbitraryOp::getCanonicalizationPatterns(patternsCanonicalization,
&getContext());
if (failed(applyPatternsGreedily(module, std::move(patternsCanonicalization)))) {
return signalPassFailure();
}
Expand Down
Loading