diff --git a/mlir/include/RTIO/CMakeLists.txt b/mlir/include/RTIO/CMakeLists.txt index f33061b2d8..9f57627c32 100644 --- a/mlir/include/RTIO/CMakeLists.txt +++ b/mlir/include/RTIO/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/RTIO/Transforms/CMakeLists.txt b/mlir/include/RTIO/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..7c02d3fb47 --- /dev/null +++ b/mlir/include/RTIO/Transforms/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name RTIO) +add_public_tablegen_target(MLIRRTIOPassIncGen) +add_mlir_doc(Passes RTIOPasses ./ -gen-pass-doc) diff --git a/mlir/include/RTIO/Transforms/Passes.h b/mlir/include/RTIO/Transforms/Passes.h new file mode 100644 index 0000000000..b76c90d256 --- /dev/null +++ b/mlir/include/RTIO/Transforms/Passes.h @@ -0,0 +1,29 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/Pass/Pass.h" + +#include "RTIO/IR/RTIODialect.h" + +namespace catalyst { +namespace rtio { + +#define GEN_PASS_DECL +#define GEN_PASS_REGISTRATION +#include "RTIO/Transforms/Passes.h.inc" + +} // namespace rtio +} // namespace catalyst diff --git a/mlir/include/RTIO/Transforms/Passes.td b/mlir/include/RTIO/Transforms/Passes.td new file mode 100644 index 0000000000..ea8ad1151b --- /dev/null +++ b/mlir/include/RTIO/Transforms/Passes.td @@ -0,0 +1,44 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef RTIO_PASSES +#define RTIO_PASSES + +include "mlir/Pass/PassBase.td" + +def RTIOEventToARTIQPass : Pass<"convert-rtio-event-to-artiq", "mlir::ModuleOp"> { + let summary = "Convert RTIO event-based operations directly to ARTIQ-compatible LLVM IR"; + let description = [{ + Mapping: + - rtio.pulse -> at_mu(operand) DDS config + TTL on + delay + TTL off + now_mu() + - rtio.sync -> maximum the timestamp of the input events + at_mu() + now_mu() + - rtio.channel -> channel ID + - rtio.empty -> just return now_mu() + + And this pass generates LLVM IR that directly calls ARTIQ runtime functions + }]; + + let dependentDialects = [ + "rtio::RTIODialect", + "mlir::LLVM::LLVMDialect", + "mlir::arith::ArithDialect", + "mlir::math::MathDialect", + "mlir::scf::SCFDialect", + "mlir::func::FuncDialect" + ]; +} + +#endif // RTIO_PASSES + + diff --git a/mlir/include/RTIO/Transforms/Patterns.h b/mlir/include/RTIO/Transforms/Patterns.h new file mode 100644 index 0000000000..357dafd794 --- /dev/null +++ b/mlir/include/RTIO/Transforms/Patterns.h @@ -0,0 +1,34 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "RTIO/IR/RTIOOps.h" + +namespace catalyst { +namespace rtio { +void populateRTIOToARTIQConversionPatterns(mlir::LLVMTypeConverter &typeConverter, + mlir::RewritePatternSet &patterns); +void populateRTIORewritePatterns(mlir::RewritePatternSet &patterns); +void populateRTIOSyncSimplifyPatterns(mlir::RewritePatternSet &patterns); +void populateRTIOPulseDecomposePatterns(mlir::RewritePatternSet &patterns); + +} // namespace rtio +} // namespace catalyst diff --git a/mlir/include/RegisterAllPasses.h b/mlir/include/RegisterAllPasses.h index 496cc2f593..3a8d3e3119 100644 --- a/mlir/include/RegisterAllPasses.h +++ b/mlir/include/RegisterAllPasses.h @@ -21,6 +21,7 @@ #include "Mitigation/Transforms/Passes.h" #include "QEC/Transforms/Passes.h" #include "Quantum/Transforms/Passes.h" +#include "RTIO/Transforms/Passes.h" #include "Test/Transforms/Passes.h" #include "hlo-extensions/Transforms/Passes.h" @@ -36,6 +37,7 @@ inline void registerAllPasses() mitigation::registerMitigationPasses(); qec::registerQECPasses(); quantum::registerQuantumPasses(); + rtio::registerRTIOPasses(); test::registerTestPasses(); } diff --git a/mlir/lib/Driver/CMakeLists.txt b/mlir/lib/Driver/CMakeLists.txt index 3ffd467987..c556a04a76 100644 --- a/mlir/lib/Driver/CMakeLists.txt +++ b/mlir/lib/Driver/CMakeLists.txt @@ -45,6 +45,7 @@ set(LIBS MLIRIon ion-transforms MLIRRTIO + rtio-transforms MLIRCatalystTest ${ENZYME_LIB} ) diff --git a/mlir/lib/RTIO/CMakeLists.txt b/mlir/lib/RTIO/CMakeLists.txt index f33061b2d8..9f57627c32 100644 --- a/mlir/lib/RTIO/CMakeLists.txt +++ b/mlir/lib/RTIO/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/RTIO/Transforms/ARTIQRuntimeBuilder.hpp b/mlir/lib/RTIO/Transforms/ARTIQRuntimeBuilder.hpp new file mode 100644 index 0000000000..456ce84920 --- /dev/null +++ b/mlir/lib/RTIO/Transforms/ARTIQRuntimeBuilder.hpp @@ -0,0 +1,421 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" + +#include "Catalyst/Utils/EnsureFunctionDeclaration.h" +#include "RTIO/IR/RTIOOps.h" // For ConfigAttr + +namespace catalyst { +namespace rtio { + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ARTIQ Function Names +//===----------------------------------------------------------------------===// + +namespace ARTIQFuncNames { +constexpr StringLiteral setFrequency = "__rtio_set_frequency"; +constexpr StringLiteral secToMu = "__rtio_sec_to_mu"; +constexpr StringLiteral configSpi = "__rtio_config_spi"; +constexpr StringLiteral nowMu = "now_mu"; +constexpr StringLiteral atMu = "at_mu"; +constexpr StringLiteral delayMu = "delay_mu"; +constexpr StringLiteral rtioOutput = "rtio_output"; +constexpr StringLiteral rtioInit = "rtio_init"; +constexpr StringLiteral rtioGetCounter = "rtio_get_counter"; +constexpr StringLiteral kernel = "__kernel__"; +} // namespace ARTIQFuncNames + +//===----------------------------------------------------------------------===// +// ARTIQ Hardware Configuration +//===----------------------------------------------------------------------===// + +namespace ARTIQHardwareConfig { +constexpr double nanosecondPeriod = 1e-9; +constexpr double ftwScaleFactor = 4.294967296; // 2^32 / 1e9 +constexpr double powScaleFactor = 65536.0; // 2^16 +constexpr int32_t maxAmplitude = 0x3FFF; // 14-bit max ASF +constexpr int32_t profile7Instruction = 0x15000000; // 0x0E (Profile 0) + 7 = 0x15 +constexpr int64_t initSlackDelay = 125000; +constexpr int64_t freqSetSlackDelay = 10000; // 1e-5s in mu +constexpr int32_t spiDiv = 2; // SPI divider: ARTIQ standard is div=2 for fast transfers +constexpr int32_t spiLen8 = 8; +constexpr int32_t spiLen32 = 32; +constexpr int32_t spiFlagsKeepCS = 8; // SPI_CS_POLARITY (CS low to listen) +constexpr int32_t spiFlagsReleaseCS = 10; // SPI_CS_POLARITY | SPI_END (CS high to release) +constexpr int64_t ioUpdatePulseWidth = 8; +constexpr int64_t refPeriodMu = 8; // RTIO reference period (Kasli = 8ns @ 125MHz RTIO clock) +constexpr int64_t minTTLPulseMu = 8; // Minimum TTL pulse duration to avoid 0 duration events +} // namespace ARTIQHardwareConfig + +//===----------------------------------------------------------------------===// +// ARTIQ Runtime Builder +//===----------------------------------------------------------------------===// + +/// Helper class for building ARTIQ runtime function calls. +class ARTIQRuntimeBuilder { + public: + ARTIQRuntimeBuilder(OpBuilder &builder, Operation *contextOp) + : builder(builder), contextOp(contextOp), ctx(builder.getContext()), + i32Ty(IntegerType::get(ctx, 32)), i64Ty(IntegerType::get(ctx, 64)), + f64Ty(Float64Type::get(ctx)), voidTy(LLVM::LLVMVoidType::get(ctx)) + { + } + + // Timing management + Value nowMu() + { + auto func = ensureFunc(ARTIQFuncNames::nowMu, LLVM::LLVMFunctionType::get(i64Ty, {})); + auto call = builder.create(getLoc(), func, ValueRange{}); + call.setTailCallKind(LLVM::TailCallKind::Tail); + return call.getResult(); + } + + void atMu(Value time) + { + auto func = ensureFunc(ARTIQFuncNames::atMu, LLVM::LLVMFunctionType::get(voidTy, {i64Ty})); + auto call = builder.create(getLoc(), func, ValueRange{time}); + call.setTailCallKind(LLVM::TailCallKind::Tail); + } + + void delayMu(Value duration) + { + auto func = + ensureFunc(ARTIQFuncNames::delayMu, LLVM::LLVMFunctionType::get(voidTy, {i64Ty})); + auto call = builder.create(getLoc(), func, ValueRange{duration}); + call.setCConv(LLVM::CConv::Fast); + call.setTailCallKind(LLVM::TailCallKind::Tail); + } + + // RTIO operations + void rtioOutput(Value addr, Value val) + { + auto func = ensureFunc(ARTIQFuncNames::rtioOutput, + LLVM::LLVMFunctionType::get(voidTy, {i32Ty, i32Ty})); + auto call = builder.create(getLoc(), func, ValueRange{addr, val}); + call.setTailCallKind(LLVM::TailCallKind::Tail); + } + + void rtioInit() + { + auto func = ensureFunc(ARTIQFuncNames::rtioInit, LLVM::LLVMFunctionType::get(voidTy, {})); + auto call = builder.create(getLoc(), func, ValueRange{}); + call.setCConv(LLVM::CConv::Fast); + call.setTailCallKind(LLVM::TailCallKind::Tail); + } + + Value rtioGetCounter() + { + auto func = + ensureFunc(ARTIQFuncNames::rtioGetCounter, LLVM::LLVMFunctionType::get(i64Ty, {})); + auto call = builder.create(getLoc(), func, ValueRange{}); + call.setCConv(LLVM::CConv::Fast); + call.setTailCallKind(LLVM::TailCallKind::Tail); + return call.getResult(); + } + + // Duration conversion + Value secToMu(Value durationSec) + { + ensureSecToMuFunc(); + auto func = getModule().lookupSymbol(ARTIQFuncNames::secToMu); + auto call = builder.create(getLoc(), func, ValueRange{durationSec}); + call.setCConv(LLVM::CConv::Fast); + call.setTailCallKind(LLVM::TailCallKind::Tail); + return call.getResult(); + } + + // SPI configuration + void configSpi(Value baseAddr, Value cs, Value len, Value div, Value flags) + { + ensureConfigSpiFunc(); + auto func = getModule().lookupSymbol(ARTIQFuncNames::configSpi); + auto call = + builder.create(getLoc(), func, ValueRange{baseAddr, cs, len, div, flags}); + call.setCConv(LLVM::CConv::Fast); + call.setTailCallKind(LLVM::TailCallKind::Tail); + } + + // Wait for SPI transmission to complete. + // ARTIQ formula: ref_period_mu * ((length + 1) * div + 1) + void waitForSpi(int32_t len, int32_t div) + { + int64_t duration = + ARTIQHardwareConfig::refPeriodMu * ((static_cast(len) + 1) * div + 1); + delayMu(constI64(duration)); + } + + // Frequency setting (continuous phase mode) + Value setFrequency(Value channelId, Value freqHz, Value phaseTurns, Value amplitude) + { + ensureSetFrequencyFunc(); + auto func = getModule().lookupSymbol(ARTIQFuncNames::setFrequency); + builder.create(getLoc(), func, + ValueRange{channelId, freqHz, phaseTurns, amplitude}); + return nowMu(); + } + + // TTL operations + void ttlOn(Value channelAddr) { rtioOutput(channelAddr, constI32(1)); } + + void ttlOff(Value channelAddr) { rtioOutput(channelAddr, constI32(0)); } + + // Constant creation helpers + Value constI32(int32_t val) + { + return builder.create(getLoc(), builder.getI32IntegerAttr(val)); + } + + Value constI64(int64_t val) + { + return builder.create(getLoc(), builder.getI64IntegerAttr(val)); + } + + Value constF64(double val) + { + return builder.create(getLoc(), builder.getF64FloatAttr(val)); + } + + // Accessors + Type getI32Type() const { return i32Ty; } + Type getI64Type() const { return i64Ty; } + Location getLoc() const { return contextOp->getLoc(); } + ModuleOp getModule() const { return contextOp->getParentOfType(); } + + /// Ensure all ARTIQ helper functions are defined in the module. + /// This should be called before lowering patterns that depend on these functions. + void ensureHelperFunctions() + { + ensureSecToMuFunc(); + ensureConfigSpiFunc(); + ensureSetFrequencyFunc(); + } + + private: + OpBuilder &builder; + Operation *contextOp; + MLIRContext *ctx; + Type i32Ty, i64Ty, f64Ty, voidTy; + + LLVM::LLVMFuncOp ensureFunc(StringRef name, LLVM::LLVMFunctionType funcTy) + { + PatternRewriter rewriter = PatternRewriter(builder.getContext()); + return catalyst::ensureFunctionDeclaration(rewriter, contextOp, name, funcTy); + } + + void ensureSecToMuFunc() + { + auto module = getModule(); + if (module.lookupSymbol(ARTIQFuncNames::secToMu)) { + return; + } + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto funcTy = LLVM::LLVMFunctionType::get(i64Ty, {f64Ty}); + auto func = builder.create(getLoc(), ARTIQFuncNames::secToMu, funcTy, + LLVM::Linkage::Internal); + func.setCConv(LLVM::CConv::Fast); + + Block *entry = func.addEntryBlock(builder); + builder.setInsertionPointToStart(entry); + Value durationSec = entry->getArgument(0); + + // duration_mu = round(duration_sec / 1e-9) + Value nsPerMu = constF64(ARTIQHardwareConfig::nanosecondPeriod); + Value durationNs = builder.create(getLoc(), durationSec, nsPerMu); + Value rounded = builder.create(getLoc(), durationNs); + Value result = builder.create(getLoc(), i64Ty, rounded); + builder.create(getLoc(), result); + } + + void ensureConfigSpiFunc() + { + auto module = getModule(); + if (module.lookupSymbol(ARTIQFuncNames::configSpi)) { + return; + } + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto funcTy = LLVM::LLVMFunctionType::get(voidTy, {i32Ty, i32Ty, i32Ty, i32Ty, i32Ty}); + auto func = builder.create(getLoc(), ARTIQFuncNames::configSpi, funcTy, + LLVM::Linkage::Internal); + + Block *entry = func.addEntryBlock(builder); + builder.setInsertionPointToStart(entry); + + Value baseAddr = entry->getArgument(0); + Value cs = entry->getArgument(1); + Value len = entry->getArgument(2); + Value div = entry->getArgument(3); + Value flags = entry->getArgument(4); + + // Config register address = Base | 1 + Value configAddr = builder.create(getLoc(), baseAddr, constI32(1)); + + // Pack: (CS << 24) | ((div - 2) << 16) | ((len - 1) << 8) | flags + Value csShifted = builder.create(getLoc(), cs, constI32(24)); + Value divOffset = builder.create(getLoc(), div, constI32(2)); + Value divShifted = builder.create(getLoc(), divOffset, constI32(16)); + Value lenOffset = builder.create(getLoc(), len, constI32(1)); + Value lenShifted = builder.create(getLoc(), lenOffset, constI32(8)); + + Value packed = builder.create(getLoc(), csShifted, divShifted); + packed = builder.create(getLoc(), packed, lenShifted); + packed = builder.create(getLoc(), packed, flags); + + rtioOutput(configAddr, packed); + builder.create(getLoc(), ValueRange{}); + } + + void ensureSetFrequencyFunc() + { + auto module = getModule(); + if (module.lookupSymbol(ARTIQFuncNames::setFrequency)) { + return; + } + + ensureConfigSpiFunc(); + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + auto funcTy = LLVM::LLVMFunctionType::get(voidTy, {i32Ty, f64Ty, f64Ty, f64Ty}); + auto func = builder.create(getLoc(), ARTIQFuncNames::setFrequency, funcTy, + LLVM::Linkage::Internal); + + Block *entry = func.addEntryBlock(builder); + builder.setInsertionPointToStart(entry); + + // channelId here is the DDS channel index (0, 1, 2, 3) for the Urukul + Value channelId = entry->getArgument(0); + Value freqHz = entry->getArgument(1); + Value phaseTurns = entry->getArgument(2); + Value amplitude = entry->getArgument(3); + + // Get hardware configuration from module + auto [spiBaseAddr, csBase, ioUpdateAddr] = getHardwareAddresses(module); + + // CS calculation: csBase is the chip_select for ch0 (typically 4) + // For Urukul: ch0->CS=4, ch1->CS=5, ch2->CS=6, ch3->CS=7 + // So CS = csBase + channelId + Value cs = builder.create(getLoc(), constI32(csBase), channelId); + Value spiBase = constI32(spiBaseAddr); + Value ioUpdate = constI32(ioUpdateAddr); + + // Calculate FTW: round(frequency * (2^32 / sys_clk)) + Value ftwScale = constF64(ARTIQHardwareConfig::ftwScaleFactor); + Value ftwDouble = builder.create(getLoc(), freqHz, ftwScale); + Value ftwRounded = builder.create(getLoc(), ftwDouble); + Value ftw = builder.create(getLoc(), i32Ty, ftwRounded); + + // Calculate POW: round(phaseTurns * 65536) + Value powScale = constF64(ARTIQHardwareConfig::powScaleFactor); + Value powDouble = builder.create(getLoc(), phaseTurns, powScale); + Value powRounded = builder.create(getLoc(), powDouble); + Value pow = builder.create(getLoc(), i32Ty, powRounded); + + // SPI Transfer: Write instruction to profile 7 (0x15) + configSpi(spiBase, cs, constI32(ARTIQHardwareConfig::spiLen8), + constI32(ARTIQHardwareConfig::spiDiv), + constI32(ARTIQHardwareConfig::spiFlagsKeepCS)); + delayMu(constI64(ARTIQHardwareConfig::refPeriodMu)); + rtioOutput(spiBase, constI32(ARTIQHardwareConfig::profile7Instruction)); + // Wait for SPI transmission to complete + waitForSpi(ARTIQHardwareConfig::spiLen8, ARTIQHardwareConfig::spiDiv); + + // SPI Transfer: Write amplitude + phase (high 32 bits) + // Convert amplitude (f64, 0.0~1.0) to ASF (i32, 0~0x3FFF) + configSpi(spiBase, cs, constI32(ARTIQHardwareConfig::spiLen32), + constI32(ARTIQHardwareConfig::spiDiv), + constI32(ARTIQHardwareConfig::spiFlagsKeepCS)); + delayMu(constI64(ARTIQHardwareConfig::refPeriodMu)); + Value asfScale = constF64(static_cast(ARTIQHardwareConfig::maxAmplitude)); + Value asfDouble = builder.create(getLoc(), amplitude, asfScale); + Value asfRounded = builder.create(getLoc(), asfDouble); + Value asf = builder.create(getLoc(), i32Ty, asfRounded); + Value asfShifted = builder.create(getLoc(), asf, constI32(16)); + Value ampPhase = builder.create(getLoc(), asfShifted, pow); + rtioOutput(spiBase, ampPhase); + // Wait for SPI transmission to complete + waitForSpi(ARTIQHardwareConfig::spiLen32, ARTIQHardwareConfig::spiDiv); + + // SPI Transfer: Write FTW (low 32 bits) + configSpi(spiBase, cs, constI32(ARTIQHardwareConfig::spiLen32), + constI32(ARTIQHardwareConfig::spiDiv), + constI32(ARTIQHardwareConfig::spiFlagsReleaseCS)); + delayMu(constI64(ARTIQHardwareConfig::refPeriodMu)); + rtioOutput(spiBase, ftw); + // Wait for SPI transmission to complete + waitForSpi(ARTIQHardwareConfig::spiLen32, ARTIQHardwareConfig::spiDiv); + + // IO Update pulse: Toggle IO update TTL + ttlOn(ioUpdate); + delayMu(constI64(ARTIQHardwareConfig::ioUpdatePulseWidth)); + ttlOff(ioUpdate); + + builder.create(getLoc(), ValueRange{}); + } + + /// Returns hardware addresses: (spiBaseAddr, csBase, ioUpdateAddr) + /// - spiBaseAddr: SPI RTIO address (channel << 8) + /// - csBase: Base chip_select value for ch0 (typically 4 for Urukul) + /// Other channels use csBase + channelIndex (ch1=5, ch2=6, ch3=7) + /// - ioUpdateAddr: IO update TTL RTIO address (channel << 8) + std::tuple getHardwareAddresses(ModuleOp module) + { + auto configAttr = module->getAttrOfType(ConfigAttr::getModuleAttrName()); + assert(configAttr && "rtio.config attribute not found on module"); + + auto getChannel = [&](ArrayRef path) -> int64_t { + Attribute current = configAttr; + for (StringRef key : path) { + if (auto dict = dyn_cast(current)) { + current = dict.get(key); + } + else if (auto cfg = dyn_cast(current)) { + current = cfg.get(key); + } + else { + return 0; + } + } + return cast(current).getInt(); + }; + + int64_t spiChannel = getChannel({"device_db", "spi_urukul0", "arguments", "channel"}); + // chip_select from urukul0_ch0 is the base CS (typically 4) + // ch0->CS=4, ch1->CS=5, ch2->CS=6, ch3->CS=7 + int64_t csBase = getChannel({"device_db", "urukul0_ch0", "arguments", "chip_select"}); + int64_t ioUpdateChannel = + getChannel({"device_db", "ttl_urukul0_io_update", "arguments", "channel"}); + + return {static_cast(spiChannel << 8), static_cast(csBase), + static_cast(ioUpdateChannel << 8)}; + } +}; + +} // namespace rtio +} // namespace catalyst diff --git a/mlir/lib/RTIO/Transforms/CMakeLists.txt b/mlir/lib/RTIO/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..e2ee789bcf --- /dev/null +++ b/mlir/lib/RTIO/Transforms/CMakeLists.txt @@ -0,0 +1,25 @@ +set(LIBRARY_NAME rtio-transforms) + +file(GLOB SRC + RTIOEventToARTIQ.cpp + RTIOEventToARTIQPatterns.cpp +) + +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +set(LIBS + ${dialect_libs} + ${conversion_libs} + MLIRRTIO +) + +set(DEPENDS + MLIRRTIOPassIncGen +) + +add_mlir_library(${LIBRARY_NAME} STATIC ${SRC} LINK_LIBS PRIVATE ${LIBS} DEPENDS ${DEPENDS}) +target_compile_features(${LIBRARY_NAME} PUBLIC cxx_std_20) +target_include_directories(${LIBRARY_NAME} PUBLIC + . + ${PROJECT_SOURCE_DIR}/include + ${CMAKE_BINARY_DIR}/include) diff --git a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp new file mode 100644 index 0000000000..056d4810a3 --- /dev/null +++ b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQ.cpp @@ -0,0 +1,554 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "RTIO/IR/RTIOOps.h" +#include "RTIO/Transforms/Passes.h" +#include "RTIO/Transforms/Patterns.h" + +#include "ARTIQRuntimeBuilder.hpp" +#include "Utils.hpp" + +using namespace mlir; +using namespace catalyst::rtio; + +namespace catalyst { +namespace rtio { + +#define GEN_PASS_DEF_RTIOEVENTTOARTIQPASS +#include "RTIO/Transforms/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// Type Aliases +//===----------------------------------------------------------------------===// + +using ScheduleGroupsMap = DenseMap>; +using GroupingPredicate = + std::function; + +//===----------------------------------------------------------------------===// +// Pulse Scheduling +//===----------------------------------------------------------------------===// + +class PulseScheduler { + public: + PulseScheduler(func::FuncOp funcOp, OpBuilder &builder, GroupingPredicate predicate) + : funcOp(funcOp), builder(builder), groupingPredicate(std::move(predicate)) + { + } + + ScheduleGroupsMap schedule() + { + // Collect all pulses + funcOp.walk([&](rtio::RTIOPulseOp pulse) { allPulses.push_back(pulse); }); + + // Build consumer map + for (auto pulse : allPulses) { + if (auto producer = pulse.getWait().getDefiningOp()) { + pulseConsumers[producer].insert(pulse); + } + } + + processFromEmptyOps(); + return std::move(groups); + } + + private: + func::FuncOp funcOp; + OpBuilder &builder; + GroupingPredicate groupingPredicate; + + SmallVector allPulses; + DenseMap> pulseConsumers; + DenseSet processedEvents; + DenseSet processedPulses; + ScheduleGroupsMap groups; + int nextGroupId = 0; + + SmallVector getEventConsumers(Value event) + { + SmallVector consumers; + for (Operation *user : event.getUsers()) { + auto pulse = dyn_cast(user); + if (!pulse || pulse.getWait() != event) { + continue; + } + consumers.push_back(pulse); + } + return consumers; + } + + void processFromEmptyOps() + { + std::deque worklist; + funcOp.walk([&](rtio::RTIOEmptyOp emptyOp) { worklist.push_back(emptyOp.getResult()); }); + + while (!worklist.empty()) { + Value event = worklist.front(); + worklist.pop_front(); + + // check if event has already been processed + // if not, process the event and insert it into the processed events + if (!processedEvents.insert(event).second) { + continue; + } + + SmallVector newEvents = processEvent(event); + llvm::append_range(worklist, newEvents); + } + } + + // return the next events to process + SmallVector processEvent(Value event) + { + SmallVector nextEvents; + auto consumers = getEventConsumers(event); + if (consumers.empty()) { + return nextEvents; + } + + // Group pulses by channel, respecting grouping predicate + DenseMap> channelPulses; + DenseMap channelLastPulse; + DenseMap channelBoundary; + SmallVector boundaryConsumers; + + // Initial + for (auto pulse : consumers) { + if (processedPulses.contains(pulse)) { + continue; + } + + int32_t channel = extractChannelId(pulse.getChannel()); + if (canJoinGroup(pulse, channelPulses)) { + channelPulses[channel].push_back(pulse); + channelLastPulse[channel] = pulse; + } + else { + if (!channelBoundary.count(channel)) { + channelBoundary[channel] = pulse; + } + boundaryConsumers.push_back(pulse); + } + } + + if (channelPulses.empty()) { + return nextEvents; + } + + // Extend chains on each channel + extendChannelChains(channelPulses, channelLastPulse, channelBoundary); + + // Record group + recordGroup(channelPulses); + + // Create sync and update dependencies + return createSyncAndUpdateDeps(channelPulses, channelLastPulse, channelBoundary, + boundaryConsumers); + } + + bool canJoinGroup(rtio::RTIOPulseOp cand, + const DenseMap> &channelPulses) + { + for (auto &[ch, pulses] : channelPulses) { + if (!llvm::all_of(pulses, [&](auto pulse) { return groupingPredicate(pulse, cand); })) { + return false; + } + } + return true; + } + + void extendChannelChains(DenseMap> &channelPulses, + DenseMap &channelLastPulse, + DenseMap &channelBoundary) + { + DenseSet stopped; + + while (stopped.size() < channelPulses.size()) { + for (auto &[channel, pulses] : channelPulses) { + if (stopped.contains(channel)) { + continue; + } + + auto currentPulse = channelLastPulse[channel]; + processedPulses.insert(currentPulse); + + bool foundNext = false; + for (auto user : pulseConsumers[currentPulse]) { + int32_t userChannel = extractChannelId(user.getChannel()); + if (userChannel != channel || processedPulses.contains(user)) { + continue; + } + + if (groupingPredicate(currentPulse, user)) { + channelPulses[channel].push_back(user); + channelLastPulse[channel] = user; + } + else { + channelBoundary[channel] = user; + stopped.insert(channel); + } + foundNext = true; + break; + } + + if (!foundNext) { + stopped.insert(channel); + } + } + } + } + + void recordGroup(const DenseMap> &channelPulses) + { + int groupId = nextGroupId++; + auto &groupOps = groups[groupId]; + for (auto &[_, pulses] : channelPulses) { + for (auto pulse : pulses) { + groupOps.insert(pulse.getOperation()); + } + } + } + + SmallVector + createSyncAndUpdateDeps(const DenseMap> &channelPulses, + DenseMap &channelLastPulse, + DenseMap &channelBoundary, + SmallVector &boundaryConsumers) + { + SmallVector nextEvents; + + if (channelPulses.size() > 1 && !channelBoundary.empty()) { + // Collect events to sync + SmallVector eventsToSync; + for (auto &entry : channelLastPulse) { + rtio::RTIOPulseOp pulse = entry.second; + eventsToSync.push_back(pulse.getEvent()); + } + + auto anyPulse = channelLastPulse.begin()->second; + builder.setInsertionPointAfter(anyPulse); + + auto eventType = rtio::EventType::get(builder.getContext()); + Value syncEvent = + builder.create(anyPulse.getLoc(), eventType, eventsToSync); + + // Update boundaries and consumers + for (auto &[_, pulse] : channelBoundary) { + pulse.setWait(syncEvent); + } + for (auto pulse : boundaryConsumers) { + pulse.setWait(syncEvent); + } + for (auto &entry : channelLastPulse) { + rtio::RTIOPulseOp pulse = entry.second; + for (auto user : pulseConsumers[pulse]) { + auto userChannel = extractChannelId(user.getChannel()); + if (!channelBoundary.count(userChannel) || + channelBoundary[userChannel] != user) { + if (user.getWait() == pulse.getEvent()) { + user.setWait(syncEvent); + } + } + } + } + + nextEvents.push_back(syncEvent); + } + else { + // No sync needed + for (auto &entry : channelBoundary) { + rtio::RTIOPulseOp pulse = entry.second; + nextEvents.push_back(pulse.getWait()); + } + if (!boundaryConsumers.empty() && !channelLastPulse.empty()) { + rtio::RTIOPulseOp firstPulse = channelLastPulse.begin()->second; + Value lastEvent = firstPulse.getEvent(); + for (auto pulse : boundaryConsumers) { + pulse.setWait(lastEvent); + } + nextEvents.push_back(lastEvent); + } + for (auto &entry : channelLastPulse) { + rtio::RTIOPulseOp pulse = entry.second; + for (auto *user : pulse.getEvent().getUsers()) { + if (auto syncOp = dyn_cast(user)) { + nextEvents.push_back(syncOp.getSyncEvent()); + } + } + } + } + + return nextEvents; + } +}; + +//===----------------------------------------------------------------------===// +// Frequency Decomposition +//===----------------------------------------------------------------------===// + +void decomposeFrequencyPulses(ScheduleGroupsMap &pulseGroups) +{ + if (pulseGroups.empty()) { + return; + } + + auto firstOp = pulseGroups.begin()->second.front(); + OpBuilder builder(firstOp->getContext()); + + // Track last frequency per channel (to avoid redundant frequency settings) + DenseMap channelLastFreq; + + // Sort groups by ID for deterministic processing + SmallVector *>> sortedGroups; + for (auto &entry : pulseGroups) { + sortedGroups.push_back({entry.first, &entry.second}); + } + llvm::sort(sortedGroups, [](const auto &a, const auto &b) { return a.first < b.first; }); + + for (auto &[groupId, groupOpsPtr] : sortedGroups) { + auto &groupOps = *groupOpsPtr; + if (groupOps.empty()) { + continue; + } + + // Find root pulses (pulses whose wait isn't produced by another pulse in this group) + DenseMap channelRoots; + for (auto *op : groupOps) { + auto pulse = cast(op); + Value wait = pulse.getWait(); + + bool isRoot = llvm::none_of(groupOps, [&](Operation *other) { + return cast(other).getEvent() == wait; + }); + + if (isRoot) { + Value channel = pulse.getChannel(); + if (!channelRoots.count(channel)) { + channelRoots[channel] = pulse; + } + } + } + + if (channelRoots.empty()) { + continue; + } + + // Filter to channels needing frequency change + DenseMap needsFreqSet; + for (auto &entry : channelRoots) { + Value channel = entry.first; + rtio::RTIOPulseOp pulse = entry.second; + Value freq = pulse.getFrequency(); + auto it = channelLastFreq.find(channel); + if (it == channelLastFreq.end() || it->second != freq) { + needsFreqSet[channel] = pulse; + channelLastFreq[channel] = freq; + } + } + + if (needsFreqSet.empty()) { + continue; + } + + // Collect original wait events + SmallVector originalWaits; + for (auto &entry : channelRoots) { + rtio::RTIOPulseOp pulse = entry.second; + Value wait = pulse.getWait(); + if (!llvm::is_contained(originalWaits, wait)) { + originalWaits.push_back(wait); + } + } + + // Find first root pulse (for insertion point) + rtio::RTIOPulseOp firstRoot = nullptr; + for (auto &entry : channelRoots) { + rtio::RTIOPulseOp pulse = entry.second; + if (!firstRoot || pulse->isBeforeInBlock(firstRoot)) { + firstRoot = pulse; + } + } + + builder.setInsertionPoint(firstRoot); + + // Create sync + Value chainStart = + originalWaits.size() > 1 + ? builder.create( + firstRoot.getLoc(), rtio::EventType::get(builder.getContext()), originalWaits) + : originalWaits[0]; + + // Create frequency setting chain + Value lastFreqEvent = chainStart; + for (auto &entry : needsFreqSet) { + rtio::RTIOPulseOp originalPulse = entry.second; + auto freqPulse = cast(builder.clone(*originalPulse.getOperation())); + freqPulse.setWait(lastFreqEvent); + freqPulse->setAttr("_frequency", builder.getUnitAttr()); + lastFreqEvent = freqPulse.getEvent(); + } + + // Update root pulses to wait on last frequency event + for (auto &entry : channelRoots) { + rtio::RTIOPulseOp pulse = entry.second; + pulse.setWait(lastFreqEvent); + } + } +} +} // namespace + +//===----------------------------------------------------------------------===// +// Pass Implementation +//===----------------------------------------------------------------------===// + +struct RTIOEventToARTIQPass : public impl::RTIOEventToARTIQPassBase { + using RTIOEventToARTIQPassBase::RTIOEventToARTIQPassBase; + + void runOnOperation() override + { + ModuleOp module = getOperation(); + MLIRContext *ctx = &getContext(); + OpBuilder builder(ctx); + + // Schedule pulses into groups + DenseMap pulseGroups; + module.walk([&](func::FuncOp funcOp) { + PulseScheduler scheduler(funcOp, builder, sameChannelSameFrequency); + pulseGroups[funcOp] = scheduler.schedule(); + }); + + // Simplify sync operations + { + RewritePatternSet patterns(&getContext()); + populateRTIOSyncSimplifyPatterns(patterns); + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + module.emitError("Failed during sync simplification"); + return signalPassFailure(); + } + } + + // Decompose frequency pulses + for (auto &[funcOp, groups] : pulseGroups) { + decomposeFrequencyPulses(groups); + } + + // Sort blocks to fix dominance + sortAllBlocks(module); + + // Decompose _frequency pulses into control + slack + { + RewritePatternSet patterns(&getContext()); + populateRTIOPulseDecomposePatterns(patterns); + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + module.emitError("Failed during pulse decomposition"); + return signalPassFailure(); + } + } + + // Setup device initialization + if (failed(setupKernelDevice(module, builder))) { + return signalPassFailure(); + } + + // Lowering to LLVM + if (failed(lowerToLLVM(module))) { + return signalPassFailure(); + } + } + + private: + static bool sameChannelSameFrequency(RTIOPulseOp ref, RTIOPulseOp candidate) + { + if (ref.getChannel() == candidate.getChannel()) { + return ref.getFrequency() == candidate.getFrequency(); + } + return true; + } + + static void sortAllBlocks(ModuleOp module) + { + module.walk([](func::FuncOp funcOp) { + for (auto &block : funcOp.getBody()) { + sortTopologically(&block); + } + }); + } + + LogicalResult setupKernelDevice(ModuleOp module, OpBuilder &builder) + { + auto kernelFunc = module.lookupSymbol(ARTIQFuncNames::kernel); + if (!kernelFunc) { + module.emitError("Cannot find ") << ARTIQFuncNames::kernel << " function"; + return failure(); + } + + OpBuilder::InsertionGuard guard(builder); + + // Ensure helper functions are defined in the module + ARTIQRuntimeBuilder artiq(builder, kernelFunc); + artiq.ensureHelperFunctions(); + + builder.setInsertionPointToStart(&kernelFunc.getBody().front()); + artiq.rtioInit(); + + // Set initial timeline: at_mu(rtio_get_counter() + slack) + Value counter = artiq.rtioGetCounter(); + Value slack = artiq.constI64(ARTIQHardwareConfig::initSlackDelay); + Value initialTime = builder.create(kernelFunc.getLoc(), counter, slack); + artiq.atMu(initialTime); + + return success(); + } + + LogicalResult lowerToLLVM(ModuleOp module) + { + MLIRContext *ctx = &getContext(); + LLVMTypeConverter typeConverter(ctx); + + typeConverter.addConversion( + [](rtio::ChannelType type) { return IntegerType::get(type.getContext(), 32); }); + typeConverter.addConversion( + [](rtio::EventType type) { return IntegerType::get(type.getContext(), 64); }); + + RewritePatternSet patterns(ctx); + populateRTIOToARTIQConversionPatterns(typeConverter, patterns); + + ConversionTarget target(*ctx); + target.addIllegalDialect(); + target.addLegalDialect(); + + return applyPartialConversion(module, target, std::move(patterns)); + } +}; + +} // namespace rtio +} // namespace catalyst diff --git a/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp new file mode 100644 index 0000000000..755bd20b34 --- /dev/null +++ b/mlir/lib/RTIO/Transforms/RTIOEventToARTIQPatterns.cpp @@ -0,0 +1,306 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "RTIO/IR/RTIOOps.h" +#include "RTIO/Transforms/Patterns.h" + +#include "ARTIQRuntimeBuilder.hpp" +#include "Utils.hpp" + +using namespace mlir; +using namespace catalyst::rtio; + +namespace { + +//===----------------------------------------------------------------------===// +// Conversion Patterns +//===----------------------------------------------------------------------===// + +struct PulseOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(RTIOPulseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + ARTIQRuntimeBuilder artiq(rewriter, op); + + // Set timeline position + artiq.atMu(adaptor.getWait()); + + if (op->hasAttr("_control")) { + return lowerControlPulse(op, adaptor, rewriter, artiq); + } + else if (op->hasAttr("_slack")) { + return lowerSlackPulse(op, rewriter, artiq); + } + return lowerTTLPulse(op, adaptor, rewriter, artiq); + } + + private: + LogicalResult lowerControlPulse(RTIOPulseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + ARTIQRuntimeBuilder &artiq) const + { + ModuleOp mod = op->getParentOfType(); + auto setFreqFunc = mod.lookupSymbol(ARTIQFuncNames::setFrequency); + if (!setFreqFunc) { + return op->emitError("Cannot find ") << ARTIQFuncNames::setFrequency << " function"; + } + + Value amplitude = artiq.constF64(1.0); + rewriter.create(op.getLoc(), setFreqFunc, + ValueRange{adaptor.getChannel(), adaptor.getFrequency(), + adaptor.getPhase(), amplitude}); + + Value newTime = artiq.nowMu(); + rewriter.replaceOp(op, newTime); + return success(); + } + + LogicalResult lowerSlackPulse(RTIOPulseOp op, ConversionPatternRewriter &rewriter, + ARTIQRuntimeBuilder &artiq) const + { + artiq.delayMu(artiq.constI64(ARTIQHardwareConfig::freqSetSlackDelay)); + Value newTime = artiq.nowMu(); + rewriter.replaceOp(op, newTime); + return success(); + } + + LogicalResult lowerTTLPulse(RTIOPulseOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, + ARTIQRuntimeBuilder &artiq) const + { + Value channelAddr = computeChannelDeviceAddr(rewriter, op, adaptor.getChannel()); + Value durationMu = artiq.secToMu(adaptor.getDuration()); + + // Enforce minimum pulse duration to avoid 0 duratoin events + Value minDuration = artiq.constI64(ARTIQHardwareConfig::minTTLPulseMu); + durationMu = rewriter.create(op.getLoc(), durationMu, minDuration); + + artiq.ttlOn(channelAddr); + artiq.delayMu(durationMu); + artiq.ttlOff(channelAddr); + + Value newTime = artiq.nowMu(); + rewriter.replaceOp(op, newTime); + return success(); + } +}; + +struct SyncOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(RTIOSyncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + ValueRange events = adaptor.getEvents(); + + if (events.size() == 1) { + rewriter.replaceOp(op, events[0]); + return success(); + } + + // Compute maximum timestamp + Value maxTime = events[0]; + for (size_t i = 1; i < events.size(); ++i) { + maxTime = rewriter.create(op.getLoc(), maxTime, events[i]); + } + + ARTIQRuntimeBuilder artiq(rewriter, op); + artiq.atMu(maxTime); + rewriter.replaceOp(op, artiq.nowMu()); + return success(); + } +}; + +struct EmptyOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(RTIOEmptyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + ARTIQRuntimeBuilder artiq(rewriter, op); + rewriter.replaceOp(op, artiq.nowMu()); + return success(); + } +}; + +struct ChannelOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(RTIOChannelOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + int32_t channelId = extractChannelId(op.getChannel()); + Type resultType = getTypeConverter()->convertType(op.getChannel().getType()); + Value result = rewriter.create( + op.getLoc(), rewriter.getIntegerAttr(resultType, channelId)); + rewriter.replaceOp(op, result); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Rewrite Patterns +//===----------------------------------------------------------------------===// + +/// Decomposes a pulse with _frequency attribute into control + slack pulses +struct DecomposePulsePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(RTIOPulseOp op, PatternRewriter &rewriter) const override + { + if (!op->hasAttr("_frequency")) { + return failure(); + } + + op->removeAttr("_frequency"); + Location loc = op.getLoc(); + + // Create control pulse (frequency setting) + auto controlPulse = cast(rewriter.clone(*op.getOperation())); + controlPulse->setAttr("_control", rewriter.getUnitAttr()); + + // Create slack pulse (timing delay) + auto slackPulse = cast(rewriter.clone(*op.getOperation())); + slackPulse->setAttr("_slack", rewriter.getUnitAttr()); + + // Sync both pulses + auto eventType = EventType::get(rewriter.getContext()); + Value syncEvent = rewriter.create( + loc, eventType, ValueRange{controlPulse.getEvent(), slackPulse.getEvent()}); + + rewriter.replaceOp(op, syncEvent); + return success(); + } +}; + +/// Removes redundant transitive dependencies from sync operations +struct SimplifySyncPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(RTIOSyncOp op, PatternRewriter &rewriter) const override + { + auto events = op.getEvents(); + if (events.size() <= 1) { + return failure(); + } + + // Find events that aren't reachable from other events + SmallVector requiredEvents; + for (Value event : events) { + bool isRedundant = llvm::any_of(events, [&](Value other) { + if (event == other) { + return false; + } + return canReach(event, other); + }); + if (!isRedundant) { + requiredEvents.push_back(event); + } + } + + if (requiredEvents.size() == events.size() || requiredEvents.empty()) { + return failure(); + } + + if (requiredEvents.size() == 1) { + rewriter.replaceOp(op, requiredEvents[0]); + } + else { + rewriter.replaceOpWithNewOp(op, op.getType(), requiredEvents); + } + return success(); + } + + private: + // Check if target is reachable from 'from' by traversing event dependencies. + static bool canReach(Value target, Value from) + { + if (target == from) { + return true; + } + + DenseSet visited; + SmallVector queue; + queue.push_back(from); + + while (!queue.empty()) { + Value current = queue.pop_back_val(); + + if (current == target) { + return true; + } + + if (!visited.insert(current).second) { + continue; + } + + Operation *defOp = current.getDefiningOp(); + if (!defOp) { + continue; + } + + if (auto pulse = dyn_cast(defOp)) { + queue.push_back(pulse.getWait()); + } + else if (auto sync = dyn_cast(defOp)) { + for (Value ev : sync.getEvents()) { + queue.push_back(ev); + } + } + } + return false; + } +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern Population Functions +//===----------------------------------------------------------------------===// + +namespace catalyst { +namespace rtio { + +void populateRTIOToARTIQConversionPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) +{ + patterns.add( + typeConverter, patterns.getContext()); +} + +void populateRTIORewritePatterns(RewritePatternSet &patterns) +{ + patterns.add(patterns.getContext()); +} + +void populateRTIOSyncSimplifyPatterns(RewritePatternSet &patterns) +{ + patterns.add(patterns.getContext()); +} + +void populateRTIOPulseDecomposePatterns(RewritePatternSet &patterns) +{ + patterns.add(patterns.getContext()); +} + +} // namespace rtio +} // namespace catalyst diff --git a/mlir/lib/RTIO/Transforms/Utils.hpp b/mlir/lib/RTIO/Transforms/Utils.hpp new file mode 100644 index 0000000000..b411f518c7 --- /dev/null +++ b/mlir/lib/RTIO/Transforms/Utils.hpp @@ -0,0 +1,64 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Matchers.h" + +#include "RTIO/IR/RTIOOps.h" + +namespace catalyst { +namespace rtio { + +/// Extract the static channel ID from an RTIO channel type. +inline int32_t extractChannelId(mlir::Value channelValue) +{ + auto type = mlir::cast(channelValue.getType()); + + assert(type.isStatic() && "Only static channel IDs are supported"); + return type.getChannelId().getInt(); +} + +/// Compute the device address for a given channel value. +inline mlir::Value computeChannelDeviceAddr(mlir::OpBuilder &builder, mlir::Operation *op, + mlir::Value channelValue) +{ + mlir::Location loc = op->getLoc(); + mlir::ModuleOp mod = op->getParentOfType(); + auto configAttr = mod->getAttrOfType(ConfigAttr::getModuleAttrName()); + assert(configAttr && "configAttr not found"); + + // Get base channel from config + mlir::Attribute current = configAttr; + for (llvm::StringRef key : {"device_db", "ttl_urukul0_sw0", "arguments", "channel"}) { + if (auto dict = mlir::dyn_cast(current)) { + current = dict.get(key); + } + else if (auto cfg = mlir::dyn_cast(current)) { + current = cfg.get(key); + } + } + int64_t channelBase = mlir::cast(current).getInt(); + + llvm::APInt channelIdAPInt; + assert(mlir::matchPattern(channelValue, mlir::m_ConstantInt(&channelIdAPInt)) && + "only static channels are supported"); + int64_t channelId = channelIdAPInt.getSExtValue(); + int32_t addr = static_cast((channelId + channelBase) << 8); + return builder.create(loc, builder.getI32IntegerAttr(addr)); +} + +} // namespace rtio +} // namespace catalyst diff --git a/mlir/test/RTIO/RTIOEventToARTIQ.mlir b/mlir/test/RTIO/RTIOEventToARTIQ.mlir new file mode 100644 index 0000000000..2fc6b46de0 --- /dev/null +++ b/mlir/test/RTIO/RTIOEventToARTIQ.mlir @@ -0,0 +1,140 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt %s --convert-rtio-event-to-artiq --split-input-file | FileCheck %s + +// CHECK: llvm.func @now_mu() -> i64 +// CHECK: llvm.func @at_mu(i64) +// CHECK: llvm.func @rtio_get_counter() -> i64 +// CHECK: llvm.func @rtio_init() +// CHECK: llvm.func @delay_mu(i64) +// CHECK: llvm.func internal @__rtio_set_frequency(%arg0: i32, %arg1: f64, %arg2: f64, %arg3: f64) +// CHECK: llvm.func @rtio_output(i32, i32) +// CHECK: llvm.func internal @__rtio_config_spi(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32) +// CHECK: llvm.func internal fastcc @__rtio_sec_to_mu(%arg0: f64) -> i64 + +// CHECK-LABEL: func.func @__kernel__() +// CHECK-SAME: attributes {diff_method = "parameter-shift", qnode} +module @circuit attributes {rtio.config = #rtio.config<{core_addr = "172.31.9.64", device_db = {core = {arguments = {analyzer_proxy = "core_analyzer", host = "172.31.9.64", ref_period = 1.000000e-09 : f64, satellite_cpu_targets = {"1" = "rv32g"}, target = "cortexa9"}, class = "Core", module = "artiq.coredevice.core", type = "local"}, spi_urukul0 = {arguments = {channel = 17 : i64}, class = "SPIMaster", module = "artiq.coredevice.spi2", type = "local"}, ttl_urukul0_io_update = {arguments = {channel = 18 : i64}, class = "TTLOut", module = "artiq.coredevice.ttl", type = "local"}, ttl_urukul0_sw0 = {arguments = {channel = 19 : i64}, class = "TTLOut", module = "artiq.coredevice.ttl", type = "local"}, ttl_urukul0_sw1 = {arguments = {channel = 20 : i64}, class = "TTLOut", module = "artiq.coredevice.ttl", type = "local"}, ttl_urukul0_sw2 = {arguments = {channel = 21 : i64}, class = "TTLOut", module = "artiq.coredevice.ttl", type = "local"}, ttl_urukul0_sw3 = {arguments = {channel = 22 : i64}, class = "TTLOut", module = "artiq.coredevice.ttl", type = "local"}, urukul0_ch0 = {arguments = {chip_select = 4 : i64, cpld_device = "urukul0_cpld", pll_en = 1 : i64, pll_n = 32 : i64, sw_device = "ttl_urukul0_sw0"}, class = "AD9910", module = "artiq.coredevice.ad9910", type = "local"}, urukul0_ch1 = {arguments = {chip_select = 5 : i64, cpld_device = "urukul0_cpld", pll_en = 1 : i64, pll_n = 32 : i64, sw_device = "ttl_urukul0_sw1"}, class = "AD9910", module = "artiq.coredevice.ad9910", type = "local"}, urukul0_ch2 = {arguments = {chip_select = 6 : i64, cpld_device = "urukul0_cpld", pll_en = 1 : i64, pll_n = 32 : i64, sw_device = "ttl_urukul0_sw2"}, class = "AD9910", module = "artiq.coredevice.ad9910", type = "local"}, urukul0_ch3 = {arguments = {chip_select = 7 : i64, cpld_device = "urukul0_cpld", pll_en = 1 : i64, pll_n = 32 : i64, sw_device = "ttl_urukul0_sw3"}, class = "AD9910", module = "artiq.coredevice.ad9910", type = "local"}, urukul0_cpld = {arguments = {clk_div = 0 : i64, clk_sel = 2 : i64, io_update_device = "ttl_urukul0_io_update", refclk = 125000000 : i64, spi_device = "spi_urukul0", sync_device}, class = "CPLD", module = "artiq.coredevice.urukul", type = "local"}}}>} { + memref.global "private" constant @__qubit_map_0 : memref<2xindex> = dense<[0, 1]> + func.func @__kernel__() attributes {diff_method = "parameter-shift", qnode} { + %cst = arith.constant 1.1618250000000001E-6 : f64 + %cst_0 = arith.constant 8.298750e-06 : f64 + %cst_1 = arith.constant 1.6597500000000003E-7 : f64 + %cst_2 = arith.constant 19100000.100724373 : f64 + %cst_3 = arith.constant 20900000.012723904 : f64 + %cst_4 = arith.constant 19000000.026035815 : f64 + %cst_5 = arith.constant 21000000.087412462 : f64 + %cst_6 = arith.constant 19999999.977146666 : f64 + %cst_7 = arith.constant 0.000000e+00 : f64 + + // Test rtio.empty, should initialize RTIO and return a timestamp + // CHECK: llvm.call fastcc tail @rtio_init() + // CHECK: %[[COUNTER:.*]] = llvm.call fastcc tail @rtio_get_counter() + // CHECK: %[[OFFSET:.*]] = arith.constant 125000 : i64 + // CHECK: %[[INIT_TIME:.*]] = arith.addi %[[COUNTER]], %[[OFFSET]] + // CHECK: llvm.call tail @at_mu(%[[INIT_TIME]]) + %0 = rtio.empty : !rtio.event + + // Test rtio.channel, creates channel reference + %1 = rtio.channel : !rtio.channel<"dds", [2 : i64], 2> + %3 = rtio.channel : !rtio.channel<"dds", [2 : i64], 0> + + // Test rtio.pulse with wait on empty, should set frequency and generate TTL pulse + // First pulse on channel 2 waiting on empty event + // CHECK: llvm.call tail @now_mu() + // CHECK: llvm.call tail @at_mu + // CHECK: llvm.call @__rtio_set_frequency + // CHECK: llvm.call tail @now_mu() + %2 = rtio.pulse %1 duration(%cst_1) frequency(%cst_6) phase(%cst_7) wait(%0) {offset = 0 : i64} : <"dds", [2 : i64], 2> -> !rtio.event + + // Test parallel pulses, both wait on same event + // CHECK: llvm.call tail @at_mu + // CHECK: llvm.call @__rtio_set_frequency + %4 = rtio.pulse %3 duration(%cst_1) frequency(%cst_6) phase(%cst_7) wait(%0) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + + // Test sequential pulse on same channel + // CHECK: llvm.call tail @at_mu + // CHECK: llvm.call fastcc tail @__rtio_sec_to_mu + // CHECK: llvm.call tail @rtio_output + // CHECK: llvm.call fastcc tail @delay_mu + // CHECK: llvm.call tail @rtio_output + %5 = rtio.pulse %3 duration(%cst_1) frequency(%cst_6) phase(%cst_7) wait(%4) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + + // Test rtio.sync, synchronizes multiple events using maxsi + // CHECK: arith.maxsi + // CHECK: llvm.call tail @at_mu + %6 = rtio.sync %5, %2 : !rtio.event + + // Test multiple parallel pulses after sync + %7 = rtio.pulse %3 duration(%cst_0) frequency(%cst_5) phase(%cst_7) wait(%6) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + %8 = rtio.channel : !rtio.channel<"dds", [2 : i64], 1> + %9 = rtio.pulse %8 duration(%cst_0) frequency(%cst_4) phase(%cst_7) wait(%6) {offset = 1 : i64} : <"dds", [2 : i64], 1> -> !rtio.event + %10 = rtio.pulse %1 duration(%cst_0) frequency(%cst_3) phase(%cst_7) wait(%6) {offset = 0 : i64} : <"dds", [2 : i64], 2> -> !rtio.event + %11 = rtio.channel : !rtio.channel<"dds", [2 : i64], 3> + %12 = rtio.pulse %11 duration(%cst_0) frequency(%cst_2) phase(%cst_7) wait(%6) {offset = 1 : i64} : <"dds", [2 : i64], 3> -> !rtio.event + + // Test sync with 4 events + // CHECK: arith.maxsi + // CHECK: arith.maxsi + // CHECK: arith.maxsi + // CHECK: llvm.call tail @at_mu + %13 = rtio.sync %7, %9, %10, %12 : !rtio.event + + // Final pulses after sync + %14 = rtio.pulse %3 duration(%cst) frequency(%cst_6) phase(%cst_7) wait(%13) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + %15 = rtio.pulse %1 duration(%cst) frequency(%cst_6) phase(%cst_7) wait(%13) {offset = 0 : i64} : <"dds", [2 : i64], 2> -> !rtio.event + %16 = rtio.pulse %3 duration(%cst) frequency(%cst_6) phase(%cst_7) wait(%14) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + + // CHECK: return + return + } +} + +// ----- + +// CHECK-LABEL: func.func @__kernel__() +module @simple_sequential attributes {rtio.config = #rtio.config<{core_addr = "172.31.9.64", device_db = {core = {arguments = {host = "172.31.9.64", ref_period = 1.000000e-09 : f64, target = "cortexa9"}, class = "Core", module = "artiq.coredevice.core", type = "local"}, spi_urukul0 = {arguments = {channel = 17 : i64}, class = "SPIMaster", module = "artiq.coredevice.spi2", type = "local"}, ttl_urukul0_io_update = {arguments = {channel = 18 : i64}, class = "TTLOut", module = "artiq.coredevice.ttl", type = "local"}, ttl_urukul0_sw0 = {arguments = {channel = 19 : i64}, class = "TTLOut", module = "artiq.coredevice.ttl", type = "local"}, urukul0_ch0 = {arguments = {chip_select = 4 : i64, cpld_device = "urukul0_cpld", pll_en = 1 : i64, pll_n = 32 : i64, sw_device = "ttl_urukul0_sw0"}, class = "AD9910", module = "artiq.coredevice.ad9910", type = "local"}, urukul0_cpld = {arguments = {clk_div = 0 : i64, clk_sel = 2 : i64, io_update_device = "ttl_urukul0_io_update", refclk = 125000000 : i64, spi_device = "spi_urukul0", sync_device}, class = "CPLD", module = "artiq.coredevice.urukul", type = "local"}}}>} { + memref.global "private" constant @__qubit_map_0 : memref<1xindex> = dense<0> + func.func @__kernel__() attributes {diff_method = "parameter-shift", qnode} { + %cst_dur = arith.constant 1.0e-6 : f64 + %cst_freq = arith.constant 20000000.0 : f64 + %cst_phase = arith.constant 0.0 : f64 + + // CHECK: llvm.call fastcc tail @rtio_init() + %0 = rtio.empty : !rtio.event + + %ch0 = rtio.channel : !rtio.channel<"dds", [2 : i64], 0> + + // First pulse, sets frequency and generates TTL + // CHECK: llvm.call @__rtio_set_frequency + // CHECK: llvm.call fastcc tail @__rtio_sec_to_mu + // CHECK: llvm.call tail @rtio_output + // CHECK: llvm.call fastcc tail @delay_mu + // CHECK: llvm.call tail @rtio_output + %1 = rtio.pulse %ch0 duration(%cst_dur) frequency(%cst_freq) phase(%cst_phase) wait(%0) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + + // Second pulse, sequential, waits for first + // CHECK: llvm.call tail @at_mu + // CHECK: llvm.call fastcc tail @__rtio_sec_to_mu + // CHECK: llvm.call tail @rtio_output + // CHECK: llvm.call fastcc tail @delay_mu + // CHECK: llvm.call tail @rtio_output + %2 = rtio.pulse %ch0 duration(%cst_dur) frequency(%cst_freq) phase(%cst_phase) wait(%1) {offset = 0 : i64} : <"dds", [2 : i64], 0> -> !rtio.event + + // CHECK: return + return + } +} + diff --git a/mlir/tools/quantum-opt/CMakeLists.txt b/mlir/tools/quantum-opt/CMakeLists.txt index f9d08252a1..0bafc425f7 100644 --- a/mlir/tools/quantum-opt/CMakeLists.txt +++ b/mlir/tools/quantum-opt/CMakeLists.txt @@ -26,6 +26,7 @@ set(LIBS MLIRIon ion-transforms MLIRRTIO + rtio-transforms MLIRCatalystTest MLIRCatalystUtils MLIRTestDialect