diff --git a/CHANGELOG.md b/CHANGELOG.md index 91cf031f82..3b1dfddcc6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel ### Added +- ✨ Add a `merge-single-qubit-rotation-gates` pass for merging consecutive rotation gates using quaternions ([#1407]) ([**@J4MMlE**]) - ✨ Add support for multi-controlled gates to ZX package ([#1380]) ([**@keefehuang**]) - ✨ Add Sampler and Estimator Primitives to the QDMI-Qiskit Interface ([#1507]) ([**@marcelwa**]) - ✨ Add conversions between `jeff` and QCO ([#1479], [#1548], [#1565], [#1637]) ([**@denialhaag**]) @@ -396,6 +397,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool [#1413]: https://github.com/munich-quantum-toolkit/core/pull/1413 [#1412]: https://github.com/munich-quantum-toolkit/core/pull/1412 [#1411]: https://github.com/munich-quantum-toolkit/core/pull/1411 +[#1407]: https://github.com/munich-quantum-toolkit/core/pull/1407 [#1406]: https://github.com/munich-quantum-toolkit/core/pull/1406 [#1403]: https://github.com/munich-quantum-toolkit/core/pull/1403 [#1402]: https://github.com/munich-quantum-toolkit/core/pull/1402 @@ -554,6 +556,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool [**@Ectras**]: https://github.com/Ectras [**@simon1hofmann**]: https://github.com/simon1hofmann [**@keefehuang**]: https://github.com/keefehuang +[**@J4MMlE**]: https://github.com/J4MMlE diff --git a/mlir/include/mlir/Compiler/CompilerPipeline.h b/mlir/include/mlir/Compiler/CompilerPipeline.h index f43bac99d7..d15585827f 100644 --- a/mlir/include/mlir/Compiler/CompilerPipeline.h +++ b/mlir/include/mlir/Compiler/CompilerPipeline.h @@ -40,6 +40,9 @@ struct QuantumCompilerConfig { /// Print IR after each stage bool printIRAfterAllStages = false; + + /// Disable quaternion-based single-qubit rotation gate merging + bool disableMergeSingleQubitRotationGates = false; }; /** diff --git a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td index 5d37e25612..edca59797e 100644 --- a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td @@ -11,6 +11,35 @@ include "mlir/Pass/PassBase.td" +def MergeSingleQubitRotationGates + : Pass<"merge-single-qubit-rotation-gates", "mlir::ModuleOp"> { + let dependentDialects = ["mlir::qco::QCODialect", + "::mlir::arith::ArithDialect", + "::mlir::math::MathDialect"]; + let summary = "Merge rotation gates using quaternion-based fusion"; + let description = [{ + Merges consecutive single-qubit rotation gates acting on the same qubit into a single equivalent U gate, reducing circuit depth and gate count. + + Supported gate types: `rx`, `ry`, `rz`, `p`, `r`, `u2`, `u`. + + The pass greedily collects the longest possible chain of consecutive mergeable gates. + Each gate is converted to a unit quaternion: + + - `rx`, `ry`, `rz`, `p`: single-axis rotations via half-angle formulas. + - `r(theta, phi)`: rotation by `theta` around axis `(cos(phi), sin(phi), 0)`. + - `u2(phi, lambda) = u(pi / 2, phi, lambda)`. + - `u(theta, phi, lambda)`: ZYZ decomposition `rz(phi) * ry(theta) * rz(lambda)`, each factor converted to a quaternion and merged via the Hamilton product. + + The gates are then folded one by one via the Hamilton product into a single quaternion, which is decomposed back into ZYZ Euler angles and emitted as a single `UOp`, representing the same rotation as the chain of single gates. + The global phase of each gate is tracked alongside and combined together. + + The emitted `UOp` is defined by $U = \exp [i (\phi + \lambda) / 2] R_z (\phi) R_y (\theta) R_z (\lambda)$. + Each merge emits a `GPhaseOp` carrying the accumulated input phase of the chain. + Because the synthesized `UOp` introduces an additional intrinsic phase of $(\phi + \lambda) / 2$, the `GPhaseOp` must compensate for it. + This applies even to chains composed entirely of $\mathrm{SU} (2)$ gates (`rx`, `ry`, `rz`, `r`) because the synthesis into a `UOp` still produces the intrinsic phase term. + }]; +} + //===----------------------------------------------------------------------===// // Transpilation Passes //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Compiler/CMakeLists.txt b/mlir/lib/Compiler/CMakeLists.txt index 66afb3515a..1dc45532fa 100644 --- a/mlir/lib/Compiler/CMakeLists.txt +++ b/mlir/lib/Compiler/CMakeLists.txt @@ -20,6 +20,7 @@ add_mlir_library( MLIRQCToQCO MLIRQCOToQC MLIRQCToQIR + MLIRQCOTransforms MQT::MLIRSupport) mqt_mlir_target_use_project_options(MQTCompilerPipeline) diff --git a/mlir/lib/Compiler/CompilerPipeline.cpp b/mlir/lib/Compiler/CompilerPipeline.cpp index 0bab524e53..c66196a7d2 100644 --- a/mlir/lib/Compiler/CompilerPipeline.cpp +++ b/mlir/lib/Compiler/CompilerPipeline.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/QCOToQC/QCOToQC.h" #include "mlir/Conversion/QCToQCO/QCToQCO.h" #include "mlir/Conversion/QCToQIR/QCToQIR.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" #include "mlir/Support/Passes.h" #include "mlir/Support/PrettyPrinting.h" @@ -136,9 +137,11 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, } } // Stage 5: Optimization passes - // TODO: Add optimization passes - if (failed( - runStage([&](PassManager& pm) { populateQCOCleanupPipeline(pm); }))) { + if (failed(runStage([&](PassManager& pm) { + if (!config_.disableMergeSingleQubitRotationGates) { + pm.addPass(qco::createMergeSingleQubitRotationGates()); + } + }))) { return failure(); } if (record != nullptr && config_.recordIntermediates) { diff --git a/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt b/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt index 2268584167..fc6ee74b9d 100644 --- a/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt @@ -15,6 +15,8 @@ add_mlir_library( PRIVATE MLIRQCODialect MLIRQCOUtils + MLIRArithDialect + MLIRMathDialect DEPENDS MLIRQCOTransformsIncGen) diff --git a/mlir/lib/Dialect/QCO/Transforms/Optimizations/MergeSingleQubitRotationGates.cpp b/mlir/lib/Dialect/QCO/Transforms/Optimizations/MergeSingleQubitRotationGates.cpp new file mode 100644 index 0000000000..fd4b44f291 --- /dev/null +++ b/mlir/lib/Dialect/QCO/Transforms/Optimizations/MergeSingleQubitRotationGates.cpp @@ -0,0 +1,683 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QCO/IR/QCOInterfaces.h" +#include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace mlir::qco { + +#define GEN_PASS_DEF_MERGESINGLEQUBITROTATIONGATES +#include "mlir/Dialect/QCO/Transforms/Passes.h.inc" + +namespace { + +/** + * @brief Pattern that merges consecutive rotation gates using quaternion + * multiplication. + */ +struct MergeSingleQubitRotationGatesPattern final + : OpInterfaceRewritePattern { + explicit MergeSingleQubitRotationGatesPattern(MLIRContext* context) + : OpInterfaceRewritePattern(context) {} + + /// Quaternion representation (w + xi + yj + zk) using MLIR Values. + struct Quaternion { + Value w; + Value x; + Value y; + Value z; + }; + + /// Axis of a single-axis rotation gate. + enum class RotationAxis : std::uint8_t { X, Y, Z }; + + /// Cached frequently-used constant Values. + struct Constants { + Value negOne; + Value zero; + Value one; + Value two; + Value eps; + Value pi; + }; + + /// Euler-angle triple for a U gate (theta, phi, lambda). + struct UOpAngles { + Value theta; + Value phi; + Value lambda; + }; + + /// Returns whether an operation is considered mergeable + static bool isMergeable(Operation* op) { + return isa(op); + } + + /// Checks if two gates a and b are mergeable via quaternion-based merging. + [[nodiscard]] static bool areQuaternionMergeable(Operation& a, Operation& b) { + return isMergeable(&a) && isMergeable(&b); + } + + /** + * @brief Returns the rotation axis for an RXOp, RYOp, or RZOp. + * + * @param op The operation to query + * @return The rotation axis, or std::nullopt if the operation is not + * RXOp, RYOp, or RZOp. + */ + static std::optional getRotationAxis(Operation* op) { + return llvm::TypeSwitch>(op) + .Case([](auto) { return RotationAxis::X; }) + .Case([](auto) { return RotationAxis::Y; }) + .Case([](auto) { return RotationAxis::Z; }) + .Default([](auto) { return std::nullopt; }); + } + + /** + * @brief Creates shared f64 arithmetic constants used throughout the pass. + * + * These constants are created once and reused across quaternion construction, + * Hamilton product, and Euler angle extraction to avoid redundant ops in the + * generated IR. + * + * @param loc Source location for the created operations + * @param rewriter Pattern rewriter for creating new operations + * @return A Constants struct with all pre-built constant ops + */ + static Constants createConstants(Location loc, PatternRewriter& rewriter) { + return { + .negOne = arith::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64Type(), APFloat(-1.0)), + .zero = arith::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64Type(), APFloat(0.0)), + .one = arith::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64Type(), APFloat(1.0)), + .two = arith::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64Type(), APFloat(2.0)), + // Tolerance for gimbal-lock detection in quaternion-to-Euler + // conversion. Value from reference implementation: + // https://github.com/evbernardes/quaternion_to_euler/blob/main/euler_from_quat.py + .eps = arith::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64Type(), APFloat(1e-12)), + .pi = arith::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64Type(), APFloat(std::numbers::pi)), + }; + } + + /** + * @brief Normalizes an angle to the range [-PI, PI]. + * + * Uses floor-based modular arithmetic: + * normalize(a) = a - floor((a + π) / 2π) * 2π + * + * @param angle The angle value to normalize + * @param loc Source location for the created operations + * @param constants Pre-created arithmetic constants + * @param rewriter Pattern rewriter for creating new operations + * @return The normalized angle value + */ + static Value normalizeAngle(Value angle, Location loc, + const Constants& constants, + PatternRewriter& rewriter) { + auto twoPi = + arith::MulFOp::create(rewriter, loc, constants.two, constants.pi); + auto shifted = arith::AddFOp::create(rewriter, loc, angle, constants.pi); + auto divided = arith::DivFOp::create(rewriter, loc, shifted, twoPi); + auto floored = math::FloorOp::create(rewriter, loc, divided); + auto multiple = arith::MulFOp::create(rewriter, loc, floored, twoPi); + return arith::SubFOp::create(rewriter, loc, angle, multiple); + } + + /** + * @brief Converts a single-axis rotation to quaternion representation. + * + * Uses half-angle formulas: + * RX(a) = Q(cos(a/2), sin(a/2), 0, 0) + * RY(a) = Q(cos(a/2), 0, sin(a/2), 0) + * RZ(a) = Q(cos(a/2), 0, 0, sin(a/2)) + * + * @see + * https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles + * @param angle The rotation angle + * @param axis The rotation axis (X, Y, or Z) + * @param loc Location in the IR + * @param constants Pre-created arithmetic constants + * @param rewriter Pattern rewriter for creating new operations + * @return Quaternion representing the rotation + */ + static Quaternion createAxisQuaternion(Value angle, RotationAxis axis, + Location loc, + const Constants& constants, + PatternRewriter& rewriter) { + auto half = arith::DivFOp::create(rewriter, loc, angle, constants.two); + // cos(angle/2) + auto cos = math::CosOp::create(rewriter, loc, half); + // sin(angle/2) + auto sin = math::SinOp::create(rewriter, loc, half); + + switch (axis) { + case RotationAxis::X: + return {.w = cos, .x = sin, .y = constants.zero, .z = constants.zero}; + case RotationAxis::Y: + return {.w = cos, .x = constants.zero, .y = sin, .z = constants.zero}; + case RotationAxis::Z: + return {.w = cos, .x = constants.zero, .y = constants.zero, .z = sin}; + } + } + + /** + * @brief Converts a ZYZ Euler angle decomposition to quaternion. + * + * U(theta, phi, lambda) uses ZYZ decomposition: RZ(lambda) -> RY(theta) -> + * RZ(phi). + * + * When composing rotations, quaternion multiplication follows matrix + * multiplication order (right-to-left), which is the reverse of the + * application sequence: + * Sequential application: RZ(lambda), then RY(theta), then RZ(phi) + * Quaternion product: qPhi * qTheta * qLambda + * + * @note U is defined as P(phi)*RY(theta)*P(lambda), which equals + * e^{i*(phi+lambda)/2} * RZ(phi)*RY(theta)*RZ(lambda). + * Since quaternions represent SU(2), this pass works with the SU(2) part + * RZ(phi)*RY(theta)*RZ(lambda) and tracks the factored-out global phase + * (phi+lambda)/2 separately via globalPhaseOf. + * + * @param theta The Y-rotation angle + * @param phi The first Z-rotation angle + * @param lambda The second Z-rotation angle + * @param loc Location in the IR + * @param constants Pre-created arithmetic constants + * @param rewriter Pattern rewriter for creating new operations + * @return Quaternion representing the ZYZ rotation + */ + static Quaternion quaternionFromZYZ(Value theta, Value phi, Value lambda, + Location loc, const Constants& constants, + PatternRewriter& rewriter) { + auto qTheta = + createAxisQuaternion(theta, RotationAxis::Y, loc, constants, rewriter); + auto qPhi = + createAxisQuaternion(phi, RotationAxis::Z, loc, constants, rewriter); + auto qLambda = + createAxisQuaternion(lambda, RotationAxis::Z, loc, constants, rewriter); + + // qPhi * qTheta * qLambda (multiplication in reverse order!) + auto temp = hamiltonProduct(qPhi, qTheta, loc, rewriter); + return hamiltonProduct(temp, qLambda, loc, rewriter); + } + + /** + * @brief Converts a UOp to quaternion representation. + * + * U(theta, phi, lambda) is decomposed via ZYZ Euler angles. + * + * @note Global phase is discarded; see quaternionFromZYZ for details. + * + * @param op The UOp to convert + * @param constants Pre-created arithmetic constants + * @param rewriter Pattern rewriter for creating new operations + * @return Quaternion representing the UOp + */ + static Quaternion quaternionFromUOp(UOp op, const Constants& constants, + PatternRewriter& rewriter) { + return quaternionFromZYZ(op.getParameter(0), op.getParameter(1), + op.getParameter(2), op->getLoc(), constants, + rewriter); + } + + /** + * @brief Converts a U2Op to quaternion representation. + * + * U2(phi, lambda) = U(pi / 2, phi, lambda), using ZYZ decomposition with + * theta = pi / 2. + * + * @note Global phase is discarded; see quaternionFromZYZ for details. + * + * @param op The U2Op to convert + * @param constants Pre-created arithmetic constants + * @param rewriter Pattern rewriter for creating new operations + * @return Quaternion representing the U2Op + */ + static Quaternion quaternionFromU2Op(U2Op op, const Constants& constants, + PatternRewriter& rewriter) { + auto loc = op->getLoc(); + auto piHalf = + arith::DivFOp::create(rewriter, loc, constants.pi, constants.two); + return quaternionFromZYZ(piHalf, op.getParameter(0), op.getParameter(1), + loc, constants, rewriter); + } + + /** + * @brief Converts an ROp to quaternion representation. + * + * R(theta, phi) represents a rotation by theta around axis + * (cos(phi), sin(phi), 0) in the XY plane: + * Q(cos(theta / 2), sin(theta / 2) * cos(phi), sin(theta / 2) * sin(phi), 0) + * + * @param op The ROp to convert + * @param constants Pre-created arithmetic constants + * @param rewriter Pattern rewriter for creating new operations + * @return Quaternion representing the ROp + */ + static Quaternion quaternionFromROp(ROp op, const Constants& constants, + PatternRewriter& rewriter) { + auto loc = op->getLoc(); + auto theta = op.getParameter(0); + auto phi = op.getParameter(1); + + auto halfTheta = arith::DivFOp::create(rewriter, loc, theta, constants.two); + auto cosHalf = math::CosOp::create(rewriter, loc, halfTheta); + auto sinHalf = math::SinOp::create(rewriter, loc, halfTheta); + auto cosPhi = math::CosOp::create(rewriter, loc, phi); + auto sinPhi = math::SinOp::create(rewriter, loc, phi); + + auto x = arith::MulFOp::create(rewriter, loc, sinHalf, cosPhi); + auto y = arith::MulFOp::create(rewriter, loc, sinHalf, sinPhi); + + return {.w = cosHalf, .x = x, .y = y, .z = constants.zero}; + } + + /** + * @brief Converts a rotation gate to quaternion representation. + * + * @note Global phase is discarded; see quaternionFromZYZ for details. + * + * @param op The rotation gate to convert (RXOp, RYOp, RZOp, POp, ROp, U2Op, + * UOp) + * @param constants Pre-created arithmetic constants + * @param rewriter Pattern rewriter for creating new operations + * @return Quaternion representing the rotation gate + */ + static Quaternion quaternionFromRotation(UnitaryOpInterface op, + const Constants& constants, + PatternRewriter& rewriter) { + // Single-axis rotations (RX, RY, RZ, P) share the same conversion pattern + if (auto axis = getRotationAxis(op.getOperation())) { + return createAxisQuaternion(op.getParameter(0), *axis, op->getLoc(), + constants, rewriter); + } + + // Multi-parameter gates each need their own conversion + return llvm::TypeSwitch(op.getOperation()) + .Case( + [&](ROp o) { return quaternionFromROp(o, constants, rewriter); }) + .Case( + [&](U2Op o) { return quaternionFromU2Op(o, constants, rewriter); }) + .Case( + [&](UOp o) { return quaternionFromUOp(o, constants, rewriter); }) + .Default([](auto) -> Quaternion { + llvm_unreachable("Unsupported operation type"); + }); + } + + /** + * @brief Returns the global phase contribution of a rotation gate. + * + * Rotation gates can be factored as U = e^{i * phase} * SU(2), where SU(2) + * is the quaternion-representable part and phase is the global phase. This + * function returns the global phase for each gate type: + * + * - RX, RY, RZ, R -> none (already SU(2), no global phase) + * - P(theta) -> theta / 2 (P = e^{i * theta / 2} * RZ(theta)) + * - U(theta, phi, lambda) -> (phi + lambda) / 2 + * - U2(phi, lambda) -> (phi + lambda) / 2 + * + * @param op The rotation gate to query + * @param constants Pre-created arithmetic constants + * @param loc Source location for created operations + * @param rewriter Pattern rewriter for creating new operations + * @return The global phase as a Value, or std::nullopt for SU(2) gates + */ + static std::optional globalPhaseOf(UnitaryOpInterface op, + const Constants& constants, + Location loc, + PatternRewriter& rewriter) { + return llvm::TypeSwitch>(op.getOperation()) + .Case( + [&](auto) -> std::optional { return std::nullopt; }) + .Case([&](auto) -> std::optional { + return arith::DivFOp::create(rewriter, loc, op.getParameter(0), + constants.two); + }) + .Case([&](auto) -> std::optional { + // phi is at different indexes for UOp and U2Op + auto phiIdx = isa(op.getOperation()) ? 1U : 0U; + auto sum = + arith::AddFOp::create(rewriter, loc, op.getParameter(phiIdx), + op.getParameter(phiIdx + 1)); + return arith::DivFOp::create(rewriter, loc, sum, constants.two); + }) + .Default([](auto) -> std::optional { + llvm_unreachable("Unsupported operation type"); + }); + } + + /** + * @brief Checks if this op is the start of a mergeable chain. + * + * A chain start is a mergeable op whose qubit input does NOT come from + * a chain-compatible predecessor. This ensures the greedy rewriter only + * triggers the rewrite at chain heads, building the maximal chain in one + * shot regardless of worklist order. + * + * @param op The operation to check + * @return True if this op is the start of a chain + */ + static bool isChainStart(UnitaryOpInterface op) { + if (!isMergeable(op.getOperation())) { + return false; + } + auto input = op.getInputQubit(0); + auto* defOp = input.getDefiningOp(); + return defOp == nullptr || + !areQuaternionMergeable(*defOp, *op.getOperation()); + } + + /** + * @brief Collects a chain of consecutive mergeable gates. + * + * Walks forward via single-use SSA edges. Breaks when the next operation is + * not considered as mergeable. + * + * @param start The chain head (must satisfy isChainStart) + * @return The chain of operations in circuit order (first applied to last) + */ + static SmallVector + collectChain(UnitaryOpInterface start) { + SmallVector chain = {start}; + auto current = start; + while (true) { + auto* userOp = *current->getUsers().begin(); + if (!areQuaternionMergeable(*current.getOperation(), *userOp)) { + break; + } + current = chain.emplace_back(cast(userOp)); + } + return chain; + } + + /** + * @brief Computes the Hamilton product of two quaternions (q1 * q2). + * + * For q1 = w1 + x1*i + y1*j + z1*k and q2 = w2 + x2*i + y2*j + z2*k: + * + * q1 * q2 = (w1w2 - x1x2 - y1y2 - z1z2) + * + (w1x2 + x1w2 + y1z2 - z1y2) * i + * + (w1y2 - x1z2 + y1w2 + z1x2) * j + * + (w1z2 + x1y2 - y1x2 + z1w2) * k + * + * @see https://en.wikipedia.org/wiki/Quaternion#Hamilton_product + * @param q1 The first quaternion + * @param q2 The second quaternion + * @param loc Location in the IR + * @param rewriter Pattern rewriter for creating new operations + * @return Product quaternion + */ + static Quaternion hamiltonProduct(Quaternion q1, Quaternion q2, Location loc, + PatternRewriter& rewriter) { + // wRes = w1w2 - x1x2 - y1y2 - z1z2 + auto w1w2 = arith::MulFOp::create(rewriter, loc, q1.w, q2.w); + auto x1x2 = arith::MulFOp::create(rewriter, loc, q1.x, q2.x); + auto y1y2 = arith::MulFOp::create(rewriter, loc, q1.y, q2.y); + auto z1z2 = arith::MulFOp::create(rewriter, loc, q1.z, q2.z); + auto wTemp1 = arith::SubFOp::create(rewriter, loc, w1w2, x1x2); + auto wTemp2 = arith::SubFOp::create(rewriter, loc, wTemp1, y1y2); + auto wRes = arith::SubFOp::create(rewriter, loc, wTemp2, z1z2); + + // xRes = w1x2 + x1w2 + y1z2 - z1y2 + auto w1x2 = arith::MulFOp::create(rewriter, loc, q1.w, q2.x); + auto x1w2 = arith::MulFOp::create(rewriter, loc, q1.x, q2.w); + auto y1z2 = arith::MulFOp::create(rewriter, loc, q1.y, q2.z); + auto z1y2 = arith::MulFOp::create(rewriter, loc, q1.z, q2.y); + auto xTemp1 = arith::AddFOp::create(rewriter, loc, w1x2, x1w2); + auto xTemp2 = arith::AddFOp::create(rewriter, loc, xTemp1, y1z2); + auto xRes = arith::SubFOp::create(rewriter, loc, xTemp2, z1y2); + + // yRes = w1y2 - x1z2 + y1w2 + z1x2 + auto w1y2 = arith::MulFOp::create(rewriter, loc, q1.w, q2.y); + auto x1z2 = arith::MulFOp::create(rewriter, loc, q1.x, q2.z); + auto y1w2 = arith::MulFOp::create(rewriter, loc, q1.y, q2.w); + auto z1x2 = arith::MulFOp::create(rewriter, loc, q1.z, q2.x); + auto yTemp1 = arith::SubFOp::create(rewriter, loc, w1y2, x1z2); + auto yTemp2 = arith::AddFOp::create(rewriter, loc, yTemp1, y1w2); + auto yRes = arith::AddFOp::create(rewriter, loc, yTemp2, z1x2); + + // zRes = w1z2 + x1y2 - y1x2 + z1w2 + auto w1z2 = arith::MulFOp::create(rewriter, loc, q1.w, q2.z); + auto x1y2 = arith::MulFOp::create(rewriter, loc, q1.x, q2.y); + auto y1x2 = arith::MulFOp::create(rewriter, loc, q1.y, q2.x); + auto z1w2 = arith::MulFOp::create(rewriter, loc, q1.z, q2.w); + auto zTemp1 = arith::AddFOp::create(rewriter, loc, w1z2, x1y2); + auto zTemp2 = arith::SubFOp::create(rewriter, loc, zTemp1, y1x2); + auto zRes = arith::AddFOp::create(rewriter, loc, zTemp2, z1w2); + + return {.w = wRes, .x = xRes, .y = yRes, .z = zRes}; + } + + /** + * @brief Extracts ZYZ Euler angles from a unit quaternion. + * + * For unit quaternion q = w + x * i + y * j + z * k, extracts UOp parameters: + * + * - alpha = atan2(z, w) + atan2(-x, y) + * - beta = acos(2 * (w^2 + z^2) - 1) + * - gamma = atan2(z, w) - atan2(-x, y) + * + * Based on Bernardes & Viollet (2022), simplified for unit quaternions and + * proper ZYZ Euler angles (Chapter 3.3): + * https://doi.org/10.1371/journal.pone.0276302 + * + * Reference implementation: + * https://github.com/evbernardes/quaternion_to_euler + * SymPy also implements this paper: + * https://docs.sympy.org/latest/modules/algebras.html#sympy.algebras.Quaternion.to_euler + * + * @note Floating-point errors may accumulate when merging many gates. + * @param q The quaternion to convert + * @param loc Source location for the created operations + * @param constants Pre-created arithmetic constants + * @param rewriter Pattern rewriter for creating new operations + * @return UOpAngles {theta, phi, lambda} suitable for UOp::create + */ + static UOpAngles anglesFromQuaternion(Quaternion q, Location loc, + const Constants& constants, + PatternRewriter& rewriter) { + // Calculate angle beta (for y-rotation) + // beta = acos(2 * (w^2 + z^2) - 1) + // NOTE: the term (2 * (w^2 + z^2) - 1) is clamped to [-1, 1], + // otherwise acos could produce NaN. + auto ww = arith::MulFOp::create(rewriter, loc, q.w, q.w); + auto zz = arith::MulFOp::create(rewriter, loc, q.z, q.z); + auto bTemp1 = arith::AddFOp::create(rewriter, loc, ww, zz); + auto bTemp2 = arith::MulFOp::create(rewriter, loc, constants.two, bTemp1); + auto bTemp3 = arith::SubFOp::create(rewriter, loc, bTemp2, constants.one); + auto clampedLow = + arith::MaximumFOp::create(rewriter, loc, bTemp3, constants.negOne); + auto clamped = + arith::MinimumFOp::create(rewriter, loc, clampedLow, constants.one); + auto beta = math::AcosOp::create(rewriter, loc, clamped); + + // intermediates to check for gimbal lock (|beta| and |beta - PI|) + auto absBeta = math::AbsFOp::create(rewriter, loc, beta); + auto betaMinusPi = arith::SubFOp::create(rewriter, loc, beta, constants.pi); + auto absBetaMinusPi = math::AbsFOp::create(rewriter, loc, betaMinusPi); + + // safe1 = beta not within boundary eps around 0: + // |beta| >= eps + auto safe1 = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE, + absBeta, constants.eps); + // safe2 = beta not within boundary eps around PI: |beta - PI| >= eps + auto safe2 = arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE, + absBetaMinusPi, constants.eps); + // is safe (not in gimbal lock) when both hold (safe1 AND safe2) + auto safe = arith::AndIOp::create(rewriter, loc, safe1, safe2); + + // intermediate angles for z-rotations alpha and gamma + // theta+ = atan2(z, w) + // theta- = atan2(-x, y) + auto xMinus = arith::NegFOp::create(rewriter, loc, q.x); + auto thetaPlus = math::Atan2Op::create(rewriter, loc, q.z, q.w); + auto thetaMinus = math::Atan2Op::create(rewriter, loc, xMinus, q.y); + + // intermediate angles for gimbal lock cases + // twoTheta+ = 2 * theta+ + // twoTheta- = 2 * theta- + auto twoThetaPlus = + arith::MulFOp::create(rewriter, loc, constants.two, thetaPlus); + auto twoThetaMinus = + arith::MulFOp::create(rewriter, loc, constants.two, thetaMinus); + + // Safe Case (no gimbal lock): + // alphaSafe = theta+ + theta- + // gammaSafe = theta+ - theta- + auto alphaSafe = + arith::AddFOp::create(rewriter, loc, thetaPlus, thetaMinus); + auto gammaSafe = + arith::SubFOp::create(rewriter, loc, thetaPlus, thetaMinus); + + // Unsafe Case (gimbal lock): + // when beta = 0 then alpha = 2 * (atan2(z, w)) + // when beta = PI then alpha = 2 * (atan2(-x, y)) + // gamma is set to zero in both cases + auto alphaUnsafe = arith::SelectOp::create(rewriter, loc, safe1, + twoThetaMinus, twoThetaPlus); + + // choose correct alpha and gamma whether safe or not + auto alpha = + arith::SelectOp::create(rewriter, loc, safe, alphaSafe, alphaUnsafe); + auto gamma = + arith::SelectOp::create(rewriter, loc, safe, gammaSafe, constants.zero); + + // normalize alpha and gamma to [-PI, PI] since they are sums/differences + // of atan2 results and can exceed that range + auto alphaNorm = normalizeAngle(alpha, loc, constants, rewriter); + auto gammaNorm = normalizeAngle(gamma, loc, constants, rewriter); + + return {.theta = beta.getResult(), .phi = alphaNorm, .lambda = gammaNorm}; + } + + /** + * @brief Matches and merges a chain of consecutive rotation gates. + * + * Detects the full chain of mergeable operations, folds their quaternions + * via Hamilton product, and emits a single UOp. + * + * @param op The operation to match (only chain heads trigger the rewrite) + * @param rewriter Pattern rewriter for applying transformations + * @return success() if operations were merged, failure() otherwise + */ + LogicalResult matchAndRewrite(UnitaryOpInterface op, + PatternRewriter& rewriter) const override { + if (!isChainStart(op)) { + return failure(); + } + + auto chain = collectChain(op); + if (chain.size() < 2) { + return failure(); + } + + // Emit all helper ops at the chain tail so the merged UOp is placed + // adjacent to the last gate it replaces. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(chain.back().getOperation()); + + auto loc = op->getLoc(); + auto constants = createConstants(loc, rewriter); + + // Initialize accumulators from the first operation + auto qAccum = quaternionFromRotation(chain.front(), constants, rewriter); + auto phaseAccum = globalPhaseOf(chain.front(), constants, loc, rewriter); + + // Fold remaining operations via Hamilton product + for (auto chainOp : llvm::drop_begin(chain)) { + auto qi = quaternionFromRotation(chainOp, constants, rewriter); + qAccum = hamiltonProduct(qi, qAccum, loc, rewriter); + + if (auto phase = globalPhaseOf(chainOp, constants, loc, rewriter)) { + phaseAccum = phaseAccum ? Value(arith::AddFOp::create( + rewriter, loc, *phaseAccum, *phase)) + : phase; + } + + // Bypass each tail operation + rewriter.replaceOp(chainOp, chainOp.getInputQubit(0)); + } + + // Extract Euler angles from merged quaternion + auto [theta, phi, lambda] = + anglesFromQuaternion(qAccum, loc, constants, rewriter); + + // Emit global phase correction: + // The synthesized UOp carries an intrinsic phase + // outPhase = (phi + lambda) / 2 that must always be compensated. + // correction = totalInputPhase - outPhase + auto phiPlusLambda = arith::AddFOp::create(rewriter, loc, phi, lambda); + auto outPhase = + arith::DivFOp::create(rewriter, loc, phiPlusLambda, constants.two); + auto inputPhase = phaseAccum.value_or(constants.zero); + auto correction = + arith::SubFOp::create(rewriter, loc, inputPhase, outPhase); + GPhaseOp::create(rewriter, loc, correction.getResult()); + + // Replace the head operation with the merged UOp + rewriter.replaceOpWithNewOp( + chain.front(), chain.front().getInputQubit(0), theta, phi, lambda); + + return success(); + } +}; + +/** + * @brief Pass that merges consecutive rotation gates using quaternion + * multiplication. + */ +struct MergeSingleQubitRotationGates final + : impl::MergeSingleQubitRotationGatesBase { + using impl::MergeSingleQubitRotationGatesBase< + MergeSingleQubitRotationGates>::MergeSingleQubitRotationGatesBase; + +protected: + void runOnOperation() override { + auto op = getOperation(); + auto* ctx = &getContext(); + + RewritePatternSet patterns(ctx); + patterns.add(patterns.getContext()); + + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace mlir::qco diff --git a/mlir/tools/mqt-cc/mqt-cc.cpp b/mlir/tools/mqt-cc/mqt-cc.cpp index 76848b6594..1a49fea64d 100644 --- a/mlir/tools/mqt-cc/mqt-cc.cpp +++ b/mlir/tools/mqt-cc/mqt-cc.cpp @@ -71,6 +71,11 @@ static cl::opt cl::desc("Print IR after each compiler stage"), cl::init(false)); +static cl::opt disableMergeSingleQubitRotationGates( + "disable-merge-single-qubit-rotation-gates", + cl::desc("Disable quaternion-based single-qubit rotation gate merging"), + cl::init(false)); + /** * @brief Load and parse a .qasm file */ @@ -165,6 +170,8 @@ int main(int argc, char** argv) { config.enableTiming = enableTiming; config.enableStatistics = enableStatistics; config.printIRAfterAllStages = printIRAfterAllStages; + config.disableMergeSingleQubitRotationGates = + disableMergeSingleQubitRotationGates; // Run the compilation pipeline CompilationRecord record; diff --git a/mlir/unittests/Compiler/test_compiler_pipeline.cpp b/mlir/unittests/Compiler/test_compiler_pipeline.cpp index 9cde9043a7..0549b1f4ab 100644 --- a/mlir/unittests/Compiler/test_compiler_pipeline.cpp +++ b/mlir/unittests/Compiler/test_compiler_pipeline.cpp @@ -117,9 +117,12 @@ class CompilerPipelineTest } static void runPipeline(const mlir::ModuleOp module, const bool convertToQIR, + const bool disableMergeSingleQubitRotationGates, mlir::CompilationRecord& record) { mlir::QuantumCompilerConfig config; config.convertToQIR = convertToQIR; + config.disableMergeSingleQubitRotationGates = + disableMergeSingleQubitRotationGates; config.recordIntermediates = true; config.printIRAfterAllStages = true; @@ -163,7 +166,7 @@ TEST_P(CompilerPipelineTest, EndToEndPipeline) { EXPECT_TRUE(mlir::verify(*module).succeeded()); mlir::CompilationRecord record; - runPipeline(module.get(), testCase.convertToQIR, record); + runPipeline(module.get(), testCase.convertToQIR, false, record); ASSERT_TRUE(testCase.qcReferenceBuilder); auto qcReference = buildQCReference(testCase.qcReferenceBuilder); @@ -189,6 +192,33 @@ TEST_P(CompilerPipelineTest, EndToEndPipeline) { } } +/** + * @brief Test: Rotation merging pass is invoked during the optimization stage + * + * @details + * The merged U gate parameters are computed via floating-point arithmetic + * that is not bit-identical across platforms, so we cannot use + * verifyAllStages with hardcoded expected values. Instead, we run the + * pipeline once with the pass enabled and compare afterQCOCanon against + * afterOptimization to verify the pass transformed the IR. + * Correctness of the pass is tested in a dedicated test. + */ +TEST_F(CompilerPipelineTest, RotationGateMergingPass) { + auto module = mlir::qc::QCProgramBuilder::build( + context.get(), [&](mlir::qc::QCProgramBuilder& b) { + auto q = b.allocQubit(); + b.rz(1.0, q); + b.rx(1.0, q); + }); + ASSERT_TRUE(module); + + mlir::CompilationRecord record; + runPipeline(module.get(), false, false, record); + + // The outputs must differ, proving the pass ran and transformed the IR + EXPECT_NE(record.afterQCOCanon, record.afterOptimization); +} + INSTANTIATE_TEST_SUITE_P( QuantumComputationPipelineProgramsTest, CompilerPipelineTest, testing::Values( diff --git a/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt index 30ddc4dc38..9f9b03449d 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt +++ b/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt @@ -7,3 +7,4 @@ # Licensed under the MIT License add_subdirectory(Mapping) +add_subdirectory(Optimizations) diff --git a/mlir/unittests/Dialect/QCO/Transforms/Mapping/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/Mapping/CMakeLists.txt index 1ced00dc34..405159075a 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/Mapping/CMakeLists.txt +++ b/mlir/unittests/Dialect/QCO/Transforms/Mapping/CMakeLists.txt @@ -11,8 +11,8 @@ add_executable(${target_name} test_mapping.cpp) target_link_libraries( ${target_name} - PRIVATE MLIRParser - GTest::gtest_main + PRIVATE GTest::gtest_main + MLIRParser MLIRQCProgramBuilder MLIRQCOProgramBuilder MLIRQCOUtils diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt new file mode 100644 index 0000000000..73606c2efb --- /dev/null +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt @@ -0,0 +1,26 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(target_name mqt-core-mlir-unittest-optimizations) +add_executable(${target_name} test_qco_merge_single_qubit_rotation.cpp) + +target_link_libraries( + ${target_name} + PRIVATE GTest::gtest_main + MLIRQCOProgramBuilder + MLIRQCOTransforms + MLIRQCOUtils + MLIRParser + MLIRIR + MLIRPass + MLIRSupport + LLVMSupport) + +mqt_mlir_configure_unittest_target(${target_name}) + +gtest_discover_tests(${target_name} PROPERTIES LABELS mqt-mlir-unittests DISCOVERY_TIMEOUT 60) diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/compute_expected_merge_single_qubit_rotation.py b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/compute_expected_merge_single_qubit_rotation.py new file mode 100644 index 0000000000..16f8818cda --- /dev/null +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/compute_expected_merge_single_qubit_rotation.py @@ -0,0 +1,208 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +"""Compute expected U gate parameters and global phase for the merge-single-qubit-rotation pass. + +Reference script for test_qco_merge_single_qubit_rotation.cpp. +Uses SymPy quaternion algebra to produce ground-truth values. +""" + +import math + +from sympy import N, Quaternion, cos, pi, sin + + +def r_gate(theta: float, phi: float) -> Quaternion: + """Return the SU(2) quaternion for an R(theta, phi) gate.""" + return Quaternion(cos(theta / 2), sin(theta / 2) * cos(phi), sin(theta / 2) * sin(phi), 0) + + +def u2_gate(phi: float, lam: float) -> Quaternion: + """Return the SU(2) quaternion for a U2(phi, lambda) gate.""" + return Quaternion.from_euler([phi, pi / 2, lam], "ZYZ") + + +def normalize_angle(a: float) -> float: + """Normalize angle to [-pi, pi], matching the pass's normalizeAngle. + + Returns: + Angle in the range [-pi, pi]. + """ + two_pi = 2 * math.pi + return a - math.floor((a + math.pi) / two_pi) * two_pi + + +def angles_from_quaternion(w: float, x: float, y: float, z: float) -> tuple[float, float, float]: + """ZYZ Euler angles from quaternion, matching anglesFromQuaternion in the pass. + + Returns: + Tuple (theta, phi, lambda) in ZYZ convention. + """ + eps = 1e-12 + + # Clamp before acos to guard against floating-point drift outside [-1, 1] + arg = 2 * (w * w + z * z) - 1 + arg = max(-1.0, min(1.0, arg)) + beta = math.acos(arg) + + abs_beta = abs(beta) + abs_beta_minus_pi = abs(beta - math.pi) + + safe1 = abs_beta >= eps # not near 0 + safe2 = abs_beta_minus_pi >= eps # not near pi + safe = safe1 and safe2 + + theta_plus = math.atan2(z, w) + theta_minus = math.atan2(-x, y) + + if safe: + alpha = theta_plus + theta_minus + gamma = theta_plus - theta_minus + elif not safe1: + # beta near 0 + alpha = 2 * theta_plus + gamma = 0.0 + else: + # beta near pi + alpha = 2 * theta_minus + gamma = 0.0 + + alpha = normalize_angle(alpha) + gamma = normalize_angle(gamma) + + # U gate convention: theta=beta, phi=alpha, lambda=gamma + return beta, alpha, gamma + + +def global_phase(gate_type: str, *angles: float) -> float: + """Return the global phase contribution of a gate. + + U = e^{i*phase} * SU(2), this returns 'phase'. + + Returns: + Global phase in radians. + + Raises: + ValueError: If gate_type is not a recognized gate. + """ + if gate_type in {"RX", "RY", "RZ", "R"}: + return 0.0 + if gate_type == "P": + return angles[0] / 2 + if gate_type == "U": + # U(theta, phi, lambda): phase = (phi + lambda) / 2 + _theta, phi, lam = angles + return (phi + lam) / 2 + if gate_type == "U2": + # U2(phi, lambda): phase = (phi + lambda) / 2 + phi, lam = angles + return (phi + lam) / 2 + msg = f"Unknown gate type: {gate_type!r}" + raise ValueError(msg) + + +def output_phase(phi: float, lam: float) -> float: + """Return the intrinsic phase of the synthesized U(theta, phi, lambda). + + Returns: + Phase in radians. + """ + return (phi + lam) / 2 + + +def gphase_correction(input_phase: float, phi: float, lam: float) -> float: + """Return the GPhaseOp correction = total_input_phase - output_UOp_phase. + + Returns: + Phase correction in radians. + """ + return input_phase - output_phase(phi, lam) + + +# ---- Helper to compute merge + gphase for a chain ---- +def compute_merge(chain: list[tuple]) -> tuple[float, float, float, float]: + """Merge a chain of gates into a single U gate with global phase. + + chain: list of (gate_type, quaternion, *angles). + + Uses our own Euler extraction that matches the C++ pass exactly: + no quaternion sign normalization, same atan2/acos/clamp logic, + same gimbal-lock handling, same angle normalization. + + Returns: + Tuple (theta, phi, lambda, gphase) all as floats. + """ + _, q0, *a0 = chain[0] + q = q0 + total_input_phase = global_phase(chain[0][0], *a0) + + for entry in chain[1:]: + gt, qi, *ai = entry + q = qi.mul(q) # Hamilton product in circuit order + total_input_phase += global_phase(gt, *ai) + + # Extract Euler angles matching the pass (no sign normalization) + w, x, y, z = float(N(q.a)), float(N(q.b)), float(N(q.c)), float(N(q.d)) + theta, phi, lam = angles_from_quaternion(w, x, y, z) + + corr = gphase_correction(total_input_phase, phi, lam) + + return theta, phi, lam, float(N(corr)) + + +# ---- Build gates ---- +rx = Quaternion.from_euler([1, 0, 0], "xyz") +ry = Quaternion.from_euler([0, 1, 0], "xyz") +rz = Quaternion.from_euler([0, 0, 1], "xyz") +mx = Quaternion.from_euler([-1, 0, 0], "xyz") +my = Quaternion.from_euler([0, -1, 0], "xyz") +mz = Quaternion.from_euler([0, 0, -1], "xyz") +px = Quaternion.from_euler([pi, 0, 0], "xyz") +py = Quaternion.from_euler([0, pi, 0], "xyz") +pz = Quaternion.from_euler([0, 0, pi], "xyz") +smallx = Quaternion.from_euler([0.001, 0, 0], "xyz") +smally = Quaternion.from_euler([0, 0.001, 0], "xyz") + +# P gate has same SU(2) quaternion as RZ +p1 = Quaternion.from_euler([0, 0, 1], "xyz") # P(1) same rotation as RZ(1) + +u1 = Quaternion.from_euler([2, 1, 3], "ZYZ") # U(1,2,3) +u2 = Quaternion.from_euler([5, 4, 6], "ZYZ") # U(4,5,6) + +u2_12 = u2_gate(1, 2) +u2_34 = u2_gate(3, 4) + +r12 = r_gate(1, 2) +r34 = r_gate(3, 4) +r11 = r_gate(1, 1) + +cases = [ + ("RX+RX", [("RX", rx), ("RX", rx)]), + ("RX+RY", [("RX", rx), ("RY", ry)]), + ("RX+RZ", [("RX", rx), ("RZ", rz)]), + ("RY+RX", [("RY", ry), ("RX", rx)]), + ("RY+RY", [("RY", ry), ("RY", ry)]), + ("RY+RZ", [("RY", ry), ("RZ", rz)]), + ("RZ+RX", [("RZ", rz), ("RX", rx)]), + ("RZ+RY", [("RZ", rz), ("RY", ry)]), + ("RZ+RZ", [("RZ", rz), ("RZ", rz)]), + ("U+U", [("U", u1, 1.0, 2.0, 3.0), ("U", u2, 4.0, 5.0, 6.0)]), + ("P+RX", [("P", p1, 1.0), ("RX", rx)]), + ("R+R", [("R", r12, 1.0, 2.0), ("R", r34, 3.0, 4.0)]), + ("U2+U2", [("U2", u2_12, 1.0, 2.0), ("U2", u2_34, 3.0, 4.0)]), + ("RZ+RY+RX pi", [("RZ", pz), ("RY", py), ("RX", px)]), + ("RY+RZ+RZ-+RY-", [("RY", ry), ("RZ", rz), ("RZ", mz), ("RY", my)]), + ("small RX+RY", [("RX", smallx), ("RY", smally)]), + ("RX(pi)+RY(pi)", [("RX", px), ("RY", py)]), + ("R+R same", [("R", r11, 1.0, 1.0), ("R", r11, 1.0, 1.0)]), +] + +if __name__ == "__main__": + for name, chain in cases: + theta, phi, lam, gphase = compute_merge(chain) + print(f"{name}: U({theta}, {phi}, {lam}) gphase={gphase}") diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_merge_single_qubit_rotation.cpp b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_merge_single_qubit_rotation.cpp new file mode 100644 index 0000000000..5ab06957f0 --- /dev/null +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_merge_single_qubit_rotation.cpp @@ -0,0 +1,810 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" +#include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace { + +using namespace mlir; +using namespace mlir::qco; + +/// A constant for the value of \f$\pi\f$. +constexpr double PI = std::numbers::pi; + +class MergeSingleQubitRotationGatesTest : public ::testing::Test { +protected: + MLIRContext context; + QCOProgramBuilder builder; + OwningOpRef module; + + enum class GateType : std::uint8_t { RX, RY, RZ, P, R, U2, U }; + /** + * @brief Struct to easily construct a rotation gate inline. + * opName uses the getOperationName() mnemonic. + */ + struct RotationGate { + GateType type; + llvm::SmallVector angles; + }; + + MergeSingleQubitRotationGatesTest() : builder(&context) {} + + void SetUp() override { + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + + builder.initialize(); + } + + /** + * @brief Counts the amount of operations the current module/circuit contains. + */ + template int countOps() { + int count = 0; + module->walk([&](OpTy) { ++count; }); + return count; + } + + /** + * @brief Extract constant floating point value from a mlir::Value + */ + static std::optional toDouble(mlir::Value v) { + if (auto constOp = v.getDefiningOp()) { + if (auto floatAttr = + mlir::dyn_cast(constOp.getValue())) { + return floatAttr.getValueAsDouble(); + } + } + return std::nullopt; + } + + /** + * @brief Find the first occurrence of a u-gate in the current module and get + * the numeric value of its parameters. This assumes that parameters are + * constant and can be extracted. + */ + std::optional> getUGateParams() { + UOp uOp = nullptr; + module->walk([&](UOp op) { + uOp = op; + // stop after finding first UOp + return mlir::WalkResult::interrupt(); + }); + + if (!uOp) { + return std::nullopt; + } + + auto theta = toDouble(uOp.getTheta()); + auto phi = toDouble(uOp.getPhi()); + auto lambda = toDouble(uOp.getLambda()); + + if (!theta || !phi || !lambda) { + return std::nullopt; + } + + return std::make_tuple(*theta, *phi, *lambda); + } + + /** + * @brief Gets the first u-gate of a module and tests whether its angle + * parameters are equal to the expected ones. + */ + void expectUGateParams(double expectedTheta, double expectedPhi, + double expectedLambda, double tolerance = 1e-8) { + auto params = getUGateParams(); + ASSERT_TRUE(params.has_value()); + + auto [theta, phi, lambda] = *params; + EXPECT_NEAR(theta, expectedTheta, tolerance); + EXPECT_NEAR(phi, expectedPhi, tolerance); + EXPECT_NEAR(lambda, expectedLambda, tolerance); + } + + /** + * @brief Find the first occurrence of a gphase op in the current module and + * get the numeric value of its parameter. + */ + std::optional getGPhaseParam() { + GPhaseOp gOp = nullptr; + module->walk([&](GPhaseOp op) { + gOp = op; + return mlir::WalkResult::interrupt(); + }); + + if (!gOp) { + return std::nullopt; + } + + return toDouble(gOp.getParameter(0)); + } + + /** + * @brief Gets the first gphase op of a module and tests whether its angle + * parameter is equal to the expected one. + */ + void expectGPhaseParam(double expected, double tolerance = 1e-8) { + auto param = getGPhaseParam(); + ASSERT_TRUE(param.has_value()); + EXPECT_NEAR(*param, expected, tolerance); + } + + Value buildRotations(llvm::ArrayRef rotations, Value& q) { + auto qubit = q; + + for (const auto& gate : rotations) { + switch (gate.type) { + case GateType::RX: + assert(gate.angles.size() == 1 && "RXOp requires 1 angle parameter"); + qubit = builder.rx(gate.angles[0], qubit); + break; + case GateType::RY: + assert(gate.angles.size() == 1 && "RYOp requires 1 angle parameter"); + qubit = builder.ry(gate.angles[0], qubit); + break; + case GateType::RZ: + assert(gate.angles.size() == 1 && "RZOp requires 1 angle parameter"); + qubit = builder.rz(gate.angles[0], qubit); + break; + case GateType::P: + assert(gate.angles.size() == 1 && "POp requires 1 angle parameter"); + qubit = builder.p(gate.angles[0], qubit); + break; + case GateType::R: + assert(gate.angles.size() == 2 && "ROp requires 2 angle parameters"); + qubit = builder.r(gate.angles[0], gate.angles[1], qubit); + break; + case GateType::U2: + assert(gate.angles.size() == 2 && "U2Op requires 2 angle parameters"); + qubit = builder.u2(gate.angles[0], gate.angles[1], qubit); + break; + case GateType::U: + assert(gate.angles.size() == 3 && "UOp requires 3 angle parameters"); + qubit = + builder.u(gate.angles[0], gate.angles[1], gate.angles[2], qubit); + break; + } + } + + return qubit; + } + + /** + * @brief Takes a list of rotation gates (rx, ry, rz and u) and uses the + * builder api to build a small quantum circuit, where a qubit is fed through + * all rotations in the list. + */ + LogicalResult testGateMerge(llvm::ArrayRef rotations) { + auto q = builder.allocQubitRegister(1); + + buildRotations(rotations, q[0]); + + module = builder.finalize(); + return runMergePass(module.get()); + } + + /** + * @brief Adds the mergeRotationGates Pass to the current context and runs it. + */ + static LogicalResult runMergePass(ModuleOp module) { + PassManager pm(module.getContext()); + pm.addPass(qco::createMergeSingleQubitRotationGates()); + return pm.run(module); + } +}; + +} // namespace + +// Note: All expected values are computed using the reference script +// compute_expected_merge_single_qubit_rotation.py in this directory, which uses +// SymPy's quaternion algebra: +// https://docs.sympy.org/latest/modules/algebras.html#module-sympy.algebras.Quaternion + +// ################################################## +// # Two Gate Merging Tests +// ################################################## + +/** + * @brief Test: RX->RX should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRXRXGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RX, .angles = {1.}}, + {.type = GateType::RX, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: RX->RY should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRXRYGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RX, .angles = {1.}}, + {.type = GateType::RY, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + expectUGateParams(1.27455578230629, -1.07542903757622, 0.495367289218673); + expectGPhaseParam(0.290030874178775); +} + +/** + * @brief Test: RX->RZ should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRXRZGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RX, .angles = {1.}}, + {.type = GateType::RZ, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + expectUGateParams(1., -0.570796326794897, 1.57079632679490); + expectGPhaseParam(-0.5); +} + +/** + * @brief Test: RY->RX should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRYRXGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RY, .angles = {1.}}, + {.type = GateType::RX, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + expectUGateParams(1.27455578230629, -0.495367289218673, 1.07542903757622); + expectGPhaseParam(-0.290030874178775); +} + +/** + * @brief Test: RY->RY should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRYRYGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RY, .angles = {1.}}, + {.type = GateType::RY, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: RY->RZ should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRYRZGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RY, .angles = {1.}}, + {.type = GateType::RZ, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + expectUGateParams(1., 1., 0.); + expectGPhaseParam(-0.5); +} + +/** + * @brief Test: RZ->RX should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRZRXGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RZ, .angles = {1.}}, + {.type = GateType::RX, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + expectUGateParams(1., -1.57079632679490, 2.57079632679490); + expectGPhaseParam(-0.5); +} + +/** + * @brief Test: RZ->RY should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRZRYGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RZ, .angles = {1.}}, + {.type = GateType::RY, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + expectUGateParams(1., 0., 1.); + expectGPhaseParam(-0.5); +} + +/** + * @brief Test: RZ->RZ should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRZRZGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RZ, .angles = {1.}}, + {.type = GateType::RZ, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: U->U should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeUUGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::U, .angles = {1., 2., 3.}}, + {.type = GateType::U, .angles = {4., 5., 6.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 1); + expectUGateParams(2.03289042623884, 0.663830775701153, 0.849231441867857); + expectGPhaseParam(7.243468891215494); +} + +/** + * @brief Test: U->RX should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeURXGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::U, .angles = {1., 2., 3.}}, + {.type = GateType::RX, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: U->RY should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeURYGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::U, .angles = {1., 2., 3.}}, + {.type = GateType::RY, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: U->RZ should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeURZGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::U, .angles = {1., 2., 3.}}, + {.type = GateType::RZ, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: RX->U should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRXUGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RX, .angles = {1.}}, + {.type = GateType::U, .angles = {1., 2., 3.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: RY->U should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRYUGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RY, .angles = {1.}}, + {.type = GateType::U, .angles = {1., 2., 3.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: RZ->U should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRZUGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RZ, .angles = {1.}}, + {.type = GateType::U, .angles = {1., 2., 3.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 1); +} +/** + * @brief Test: P->RX should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergePRXGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::P, .angles = {1.}}, + {.type = GateType::RX, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + expectUGateParams(1., -1.57079632679490, 2.57079632679490); + expectGPhaseParam(1.11022302462516e-16); +} + +/** + * @brief Test: P->RY should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergePRYGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::P, .angles = {1.}}, + {.type = GateType::RY, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: P->U should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergePUGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::P, .angles = {1.}}, + {.type = GateType::U, .angles = {1., 2., 3.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: R->RX should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRRXGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::R, .angles = {1., 1.}}, + {.type = GateType::RX, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: P->P should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergePPGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::P, .angles = {1.}}, + {.type = GateType::P, .angles = {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: R->R should merge into a single U gate (same multi-parameter + * type always uses quaternion merge) + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeRRGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::R, .angles = {1., 2.}}, + {.type = GateType::R, .angles = {3., 4.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + expectUGateParams(2.07770669385131, 1.36334275733332, 2.85969871348886); + expectGPhaseParam(-2.1115207354110845); +} + +/** + * @brief Test: U2->U should merge into a single U gate + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeU2UGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::U2, .angles = {1., 2.}}, + {.type = GateType::U, .angles = {1., 2., 3.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: U2->U2 should merge into a single U gate (same multi-parameter + * type always uses quaternion merge) + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeU2U2Gates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::U2, .angles = {1., 2.}}, + {.type = GateType::U2, .angles = {3., 4.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + expectUGateParams(1.85840734641021, 1.42920367320511, 0.429203673205103); + expectGPhaseParam(4.070796326794897); +} + +// ################################################## +// # Not Merging Tests +// ################################################## + +/** + * @brief Test: single RX should not convert to U + */ +TEST_F(MergeSingleQubitRotationGatesTest, noMergeSingleRXGate) { + ASSERT_TRUE( + testGateMerge({{.type = GateType::RX, .angles = {1.}}}).succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: single RY should not convert to U + */ +TEST_F(MergeSingleQubitRotationGatesTest, noMergeSingleRYGate) { + ASSERT_TRUE( + testGateMerge({{.type = GateType::RY, .angles = {1.}}}).succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: single RZ should not convert to U + */ +TEST_F(MergeSingleQubitRotationGatesTest, noMergeSingleRZGate) { + ASSERT_TRUE( + testGateMerge({{.type = GateType::RZ, .angles = {1.}}}).succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: Gates on different qubits should not merge + */ +TEST_F(MergeSingleQubitRotationGatesTest, dontMergeGatesFromDifferentQubits) { + auto q = builder.allocQubitRegister(2); + + builder.rx(1.0, q[0]); + builder.ry(1.0, q[1]); + module = builder.finalize(); + + ASSERT_TRUE(runMergePass(module.get()).succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: Non-consecutive gates should not merge + */ +TEST_F(MergeSingleQubitRotationGatesTest, dontMergeNonConsecutiveGates) { + auto q = builder.allocQubitRegister(1); + + auto q1 = builder.rx(1.0, q[0]); + auto q2 = builder.h(q1); + builder.ry(1.0, q2); + + module = builder.finalize(); + + ASSERT_TRUE(runMergePass(module.get()).succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); +} + +// ################################################## +// # Greedy Merging Tests +// ################################################## + +/** + * @brief Test: Many gates should greedily merge into one U + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeManyGates) { + ASSERT_TRUE(testGateMerge({{.type = GateType::U, .angles = {1., 2., 3.}}, + {.type = GateType::RX, .angles = {1.}}, + {.type = GateType::RY, .angles = {2.}}, + {.type = GateType::RZ, .angles = {3.}}, + {.type = GateType::U, .angles = {4., 5., 6.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: Many gates with one unmergeable in between should merge into two + * U with the unmergeable in between. + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeManyWithUnmergeable) { + auto reg = builder.allocQubitRegister(1); + auto q = reg[0]; + q = buildRotations({{.type = GateType::U, .angles = {1., 2., 3.}}, + {.type = GateType::RX, .angles = {1.}}, + {.type = GateType::RY, .angles = {2.}}, + {.type = GateType::RZ, .angles = {3.}}}, + q); + q = builder.h(q); + q = buildRotations({{.type = GateType::RZ, .angles = {4.}}, + {.type = GateType::RY, .angles = {5.}}, + {.type = GateType::RX, .angles = {6.}}, + {.type = GateType::U, .angles = {4., 5., 6.}}}, + q); + + module = builder.finalize(); + + ASSERT_TRUE(runMergePass(module.get()).succeeded()); + EXPECT_EQ(countOps(), 2); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 2); +} + +// ################################################## +// # Special Cases Tests +// ################################################## + +/** + * @brief Test: Consecutive gates with another gate in between should merge + */ +TEST_F(MergeSingleQubitRotationGatesTest, mergeConsecutiveWithGateInBetween) { + auto q = builder.allocQubitRegister(2); + + auto q1 = builder.rx(1.0, q[0]); + builder.h(q[1]); + builder.ry(1.0, q1); + + module = builder.finalize(); + + ASSERT_TRUE(runMergePass(module.get()).succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 1); +} + +// ################################################## +// # Numerical Correctness +// ################################################## + +/** + * @brief Test: RZ(PI)->RY(PI)->RX(PI) should merge into U(0, 0, 0) + */ +TEST_F(MergeSingleQubitRotationGatesTest, numericalRotationIdentity) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RZ, .angles = {PI}}, + {.type = GateType::RY, .angles = {PI}}, + {.type = GateType::RX, .angles = {PI}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + expectUGateParams(0., 0., 0.); + expectGPhaseParam(0.); +} + +/** + * @brief Test: RY(1)->RZ(1)->RZ(-1)->RY(-1) should merge into U(0, 0, 0) + */ +TEST_F(MergeSingleQubitRotationGatesTest, numericalRotationIdentity2) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RY, .angles = {1}}, + {.type = GateType::RZ, .angles = {1}}, + {.type = GateType::RZ, .angles = {-1}}, + {.type = GateType::RY, .angles = {-1}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + expectUGateParams(0., 0., 0.); + expectGPhaseParam(0.); +} + +/** + * @brief Test: RX(0.001)->RY(0.001) should merge into U(0.00141421344452194, + * -0.785398413397490, 0.785397913397407) + */ +TEST_F(MergeSingleQubitRotationGatesTest, numericalSmallAngles) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RX, .angles = {0.001}}, + {.type = GateType::RY, .angles = {0.001}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + expectUGateParams(0.00141421344452194, -0.785398413397490, 0.785397913397407); + expectGPhaseParam(2.50000041668308e-7); +} + +/** + * @brief Test: RX(PI)->RY(PI) should merge into U(0, -PI, 0.) + */ +TEST_F(MergeSingleQubitRotationGatesTest, numericalGimbalLock) { + ASSERT_TRUE(testGateMerge({{.type = GateType::RX, .angles = {PI}}, + {.type = GateType::RY, .angles = {PI}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + expectUGateParams(0., -PI, 0.); + expectGPhaseParam(1.57079632679490); +} + +/** + * @brief Test: R(1,1)->R(1,1) (same axis) should merge into U(2.00000000000000, + * -0.570796326794897, 0.570796326794897) + */ +TEST_F(MergeSingleQubitRotationGatesTest, numericalAccuracyRRSameAxis) { + ASSERT_TRUE(testGateMerge({{.type = GateType::R, .angles = {1., 1.}}, + {.type = GateType::R, .angles = {1., 1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + expectUGateParams(2., -0.570796326794897, 0.570796326794897); + expectGPhaseParam(0.0); +} + +/** + * @brief Test: U(0, -2.0360075460227076, 0)->U(0, 4.157656961105587, 0) should + * not produce NaN. These specific numbers would produce NaN if acos parameter + * would not be clamped to [-1, 1] + */ +TEST_F(MergeSingleQubitRotationGatesTest, numericalAcosClampingPreventsNaN) { + ASSERT_TRUE(testGateMerge( + {{.type = GateType::U, .angles = {0, -2.0360075460227076, 0}}, + {.type = GateType::U, .angles = {0, 4.157656961105587, 0}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 1); + + auto params = getUGateParams(); + ASSERT_TRUE(params.has_value()); + + auto [theta, phi, lambda] = *params; + EXPECT_FALSE(std::isnan(theta)); + EXPECT_FALSE(std::isnan(phi)); + EXPECT_FALSE(std::isnan(lambda)); + + auto gphase = getGPhaseParam(); + ASSERT_TRUE(gphase.has_value()); + EXPECT_FALSE(std::isnan(*gphase)); +} diff --git a/pyproject.toml b/pyproject.toml index 586456c41f..762f3aca0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -233,6 +233,7 @@ known-first-party = ["mqt.core"] "test/python/**" = ["T20", "INP001"] "docs/**" = ["T20", "INP001"] "noxfile.py" = ["T20", "TID251", "PLC0415"] +"mlir/unittests/**/*.py" = ["INP001", "T201"] "*.pyi" = ["D418", "E501", "PYI021"] "*.ipynb" = [ "D", # pydocstyle