diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 9f61520510cc..ec0256c51a1b 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -60,7 +60,10 @@ cc_library( ] + glob([ "dialect/tpu/transforms/*.h", ]), - # compatible with libtpu + compatible_with = [ + "//buildenv/target:libtpu", + "//buildenv/target:non_prod", + ], deps = [ ":pass_boilerplate", ":serde", @@ -104,7 +107,10 @@ cc_library( gentbl_cc_library( name = "tpu_inc_gen", - # compatible with libtpu + compatible_with = [ + "//buildenv/target:libtpu", + "//buildenv/target:non_prod", + ], tbl_outs = { "dialect/tpu/tpu_ops.h.inc": ["-gen-op-decls"], "dialect/tpu/tpu_ops.cc.inc": ["-gen-op-defs"], @@ -139,7 +145,10 @@ td_library( srcs = [ "dialect/tpu/tpu.td", ], - # compatible with libtpu + compatible_with = [ + "//buildenv/target:libtpu", + "//buildenv/target:non_prod", + ], deps = [ "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", @@ -263,7 +272,10 @@ filegroup( cc_library( name = "pass_boilerplate", hdrs = ["pass_boilerplate.h"], - # compatible with libtpu + compatible_with = [ + "//buildenv/target:libtpu", + "//buildenv/target:non_prod", + ], deps = [ "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -275,7 +287,10 @@ cc_library( name = "serde", srcs = ["serde.cc"], hdrs = ["serde.h"], - # compatible with libtpu + compatible_with = [ + "//buildenv/target:libtpu", + "//buildenv/target:non_prod", + ], deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:DataLayoutInterfaces", diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 6cca63ed1d95..0ecfce6076b6 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -1450,4 +1450,11 @@ def LinalgVectorizationPass : Pass<"linalg-vectorization", "::mlir::func::FuncOp ]; } +def BasicBlockTraceInsertionPass : Pass<"basic-block-trace-insertion", "::mlir::func::FuncOp"> { + let dependentDialects = [ + "::mlir::tpu::TPUDialect", + ]; + let constructor = "::mlir::tpu::createBasicBlockTraceInsertionPass()"; +} + #endif // TPU_ATTRS diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 489dadf612f3..aab5493fac87 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -98,6 +98,9 @@ std::unique_ptr> createLinalgVectorizationPass( std::unique_ptr> createDebugAssertInsertionPass(); +std::unique_ptr> +createBasicBlockTraceInsertionPass(); + #define GEN_PASS_DECL_MOSAICSERDEPASS #include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" diff --git a/jaxlib/mosaic/dialect/tpu/transforms/basic_block_trace_insertion.cc b/jaxlib/mosaic/dialect/tpu/transforms/basic_block_trace_insertion.cc new file mode 100644 index 000000000000..5d11775900b2 --- /dev/null +++ b/jaxlib/mosaic/dialect/tpu/transforms/basic_block_trace_insertion.cc @@ -0,0 +1,66 @@ + +/* Copyright 2024 The JAX Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "absl/strings/str_cat.h" +#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h" +namespace mlir::tpu { +#define GEN_PASS_DECL_BASICBLOCKTRACEINSERTIONPASS +#define GEN_PASS_DEF_BASICBLOCKTRACEINSERTIONPASS +#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc" +namespace { +struct BasicBlockTraceInsertionPass + : public impl::BasicBlockTraceInsertionPassBase< + BasicBlockTraceInsertionPass> { + void runOnOperation() override { + func::FuncOp func = getOperation(); + std::deque queue{&func.getBody()}; + int64_t block_counter = 0; + Location loc = UnknownLoc::get(func.getContext()); + while (!queue.empty()) { + Region* region = queue.front(); + queue.pop_front(); + for (auto it = region->begin(); it != region->end(); ++it) { + Block& block = *it; + if (block.empty()) { + continue; + } + OpBuilder::atBlockBegin(&block).create( + loc, absl::StrCat("__block_", block_counter++), /*level=*/10); + OpBuilder::atBlockTerminator(&block).create(loc); + for (Operation& op : block.without_terminator()) { + for (Region ®ion : op.getRegions()) { + if (!region.empty()) { + queue.push_back(®ion); + } + } + } + } + } + } +}; +} // namespace +std::unique_ptr> +createBasicBlockTraceInsertionPass() { + return std::make_unique(); +} +} // namespace mlir::tpu \ No newline at end of file