Skip to content

Commit 1c376c3

Browse files
Support For Roofline Analysis Of Pallas Kernels
PiperOrigin-RevId: 796850434
1 parent 47d933c commit 1c376c3

File tree

4 files changed

+96
-5
lines changed

4 files changed

+96
-5
lines changed

jaxlib/mosaic/BUILD

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ cc_library(
6060
] + glob([
6161
"dialect/tpu/transforms/*.h",
6262
]),
63-
# compatible with libtpu
63+
compatible_with = [
64+
"//buildenv/target:libtpu",
65+
"//buildenv/target:non_prod",
66+
],
6467
deps = [
6568
":pass_boilerplate",
6669
":serde",
@@ -104,7 +107,10 @@ cc_library(
104107

105108
gentbl_cc_library(
106109
name = "tpu_inc_gen",
107-
# compatible with libtpu
110+
compatible_with = [
111+
"//buildenv/target:libtpu",
112+
"//buildenv/target:non_prod",
113+
],
108114
tbl_outs = {
109115
"dialect/tpu/tpu_ops.h.inc": ["-gen-op-decls"],
110116
"dialect/tpu/tpu_ops.cc.inc": ["-gen-op-defs"],
@@ -139,7 +145,10 @@ td_library(
139145
srcs = [
140146
"dialect/tpu/tpu.td",
141147
],
142-
# compatible with libtpu
148+
compatible_with = [
149+
"//buildenv/target:libtpu",
150+
"//buildenv/target:non_prod",
151+
],
143152
deps = [
144153
"@llvm-project//mlir:BuiltinDialectTdFiles",
145154
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
@@ -263,7 +272,10 @@ filegroup(
263272
cc_library(
264273
name = "pass_boilerplate",
265274
hdrs = ["pass_boilerplate.h"],
266-
# compatible with libtpu
275+
compatible_with = [
276+
"//buildenv/target:libtpu",
277+
"//buildenv/target:non_prod",
278+
],
267279
deps = [
268280
"@llvm-project//mlir:IR",
269281
"@llvm-project//mlir:Pass",
@@ -275,7 +287,10 @@ cc_library(
275287
name = "serde",
276288
srcs = ["serde.cc"],
277289
hdrs = ["serde.h"],
278-
# compatible with libtpu
290+
compatible_with = [
291+
"//buildenv/target:libtpu",
292+
"//buildenv/target:non_prod",
293+
],
279294
deps = [
280295
"@llvm-project//llvm:Support",
281296
"@llvm-project//mlir:DataLayoutInterfaces",

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,4 +1450,11 @@ def LinalgVectorizationPass : Pass<"linalg-vectorization", "::mlir::func::FuncOp
14501450
];
14511451
}
14521452

1453+
def BasicBlockTraceInsertionPass : Pass<"basic-block-trace-insertion", "::mlir::func::FuncOp"> {
1454+
let dependentDialects = [
1455+
"::mlir::tpu::TPUDialect",
1456+
];
1457+
let constructor = "::mlir::tpu::createBasicBlockTraceInsertionPass()";
1458+
}
1459+
14531460
#endif // TPU_ATTRS

jaxlib/mosaic/dialect/tpu/tpu_dialect.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createLinalgVectorizationPass(
9898

9999
std::unique_ptr<OperationPass<func::FuncOp>> createDebugAssertInsertionPass();
100100

101+
std::unique_ptr<OperationPass<func::FuncOp>>
102+
createBasicBlockTraceInsertionPass();
103+
101104
#define GEN_PASS_DECL_MOSAICSERDEPASS
102105
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
103106

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
2+
/* Copyright 2024 The JAX Authors.
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
==============================================================================*/
13+
#include <cstdint>
14+
#include <memory>
15+
#include <deque>
16+
#include "mlir/Dialect/Func/IR/FuncOps.h"
17+
#include "mlir/IR/Block.h"
18+
#include "mlir/IR/Builders.h"
19+
#include "mlir/IR/BuiltinAttributes.h"
20+
#include "mlir/IR/Location.h"
21+
#include "mlir/IR/Region.h"
22+
#include "mlir/IR/Value.h"
23+
#include "mlir/Pass/Pass.h"
24+
#include "absl/strings/str_cat.h"
25+
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
26+
namespace mlir::tpu {
27+
#define GEN_PASS_DECL_BASICBLOCKTRACEINSERTIONPASS
28+
#define GEN_PASS_DEF_BASICBLOCKTRACEINSERTIONPASS
29+
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
30+
namespace {
31+
struct BasicBlockTraceInsertionPass
32+
: public impl::BasicBlockTraceInsertionPassBase<
33+
BasicBlockTraceInsertionPass> {
34+
void runOnOperation() override {
35+
func::FuncOp func = getOperation();
36+
std::deque<Region*> queue{&func.getBody()};
37+
int64_t block_counter = 0;
38+
Location loc = UnknownLoc::get(func.getContext());
39+
while (!queue.empty()) {
40+
Region* region = queue.front();
41+
queue.pop_front();
42+
for (auto it = region->begin(); it != region->end(); ++it) {
43+
Block& block = *it;
44+
if (block.empty()) {
45+
continue;
46+
}
47+
OpBuilder::atBlockBegin(&block).create<tpu::TraceStartOp>(
48+
loc, absl::StrCat("__block_", block_counter++), /*level=*/10);
49+
OpBuilder::atBlockTerminator(&block).create<tpu::TraceStopOp>(loc);
50+
for (Operation& op : block.without_terminator()) {
51+
for (Region &region : op.getRegions()) {
52+
if (!region.empty()) {
53+
queue.push_back(&region);
54+
}
55+
}
56+
}
57+
}
58+
}
59+
}
60+
};
61+
} // namespace
62+
std::unique_ptr<OperationPass<func::FuncOp>>
63+
createBasicBlockTraceInsertionPass() {
64+
return std::make_unique<BasicBlockTraceInsertionPass>();
65+
}
66+
} // namespace mlir::tpu

0 commit comments

Comments
 (0)