Skip to content

Commit ea25f18

Browse files
committed
Refactored canonicalization patterns to use CompilerUtils::{isScaType, hasScaType, isConstant} for consistency and correctness; replaced arith ops with DAPHNE dialect ops to preserve rewrite applicability. Fixed type handling in rewrites, preserved transpose flags in MatMulOp, and revised/extended both IR- and script-level tests while removing unrelated changes.
1 parent c6f7563 commit ea25f18

27 files changed

Lines changed: 928 additions & 6 deletions

src/ir/daphneir/Canonicalize.cpp

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,213 @@
1919
#include "mlir/Support/LogicalResult.h"
2020
#include <compiler/utils/CompilerUtils.h>
2121

22+
/**
 * @brief Canonicalization rules for sumAll:
 *  1) sumAll(ewAdd(X, Y))       -> ewAdd(sumAll(X), sumAll(Y))
 *  2) sumAll(transpose(X))      -> sumAll(X)
 *  3) sumAll(lambda * X)        -> lambda * sumAll(X)        (lambda a scalar)
 *  4) sumAll(diagVector(X @ Y)) -> sumAll(X * transpose(Y))  (trace rewrite)
 *
 * @param op the sumAll op to canonicalize
 * @param rewriter the pattern rewriter used to create/replace ops
 * @return success iff the IR was rewritten, failure otherwise
 */
mlir::LogicalResult mlir::daphne::AllAggSumOp::canonicalize(mlir::daphne::AllAggSumOp op,
                                                            mlir::PatternRewriter &rewriter) {
    mlir::Value input = op.getOperand();
    mlir::Location location = op.getLoc();
    mlir::Type result_type = op.getResult().getType();
    // Intermediate values get the unknown type; later type inference refines it.
    auto unknownType = mlir::daphne::UnknownType::get(rewriter.getContext());

    // Rule 1: sumAll(ewAdd(X, Y)) to ewAdd(sumAll(X), sumAll(Y))
    if (auto addOp = input.getDefiningOp<mlir::daphne::EwAddOp>()) {
        // Only applicable if both inputs are matrices (ewAdd also accepts scalars).
        if (!addOp.getLhs().getType().isa<mlir::daphne::MatrixType>() ||
            !addOp.getRhs().getType().isa<mlir::daphne::MatrixType>()) {
            return mlir::failure();
        }

        // Sum each operand individually, then add the two scalar sums.
        mlir::Value lSum = rewriter.create<mlir::daphne::AllAggSumOp>(location, unknownType, addOp.getLhs());
        mlir::Value rSum = rewriter.create<mlir::daphne::AllAggSumOp>(location, unknownType, addOp.getRhs());
        mlir::Value scalar_add = rewriter.create<mlir::daphne::EwAddOp>(location, result_type, lSum, rSum);

        rewriter.replaceOp(op, scalar_add);
        return mlir::success();
    } // Rule 2: sumAll(transpose(X)) to sumAll(X)
    else if (auto transOp = input.getDefiningOp<mlir::daphne::TransposeOp>()) {
        mlir::Value input_tr = transOp.getArg();

        // Only applicable to matrices.
        if (!input_tr.getType().isa<mlir::daphne::MatrixType>()) {
            return mlir::failure();
        }

        // Summation is invariant under transposition, so drop the transpose.
        mlir::Value simplf_sumOftranspose =
            rewriter.create<mlir::daphne::AllAggSumOp>(location, result_type, input_tr);
        rewriter.replaceOp(op, simplf_sumOftranspose);
        return mlir::success();
    } // Rule 3: sum(lambda * X) -> lambda * sum(X)
    else if (auto lambdaMul = input.getDefiningOp<mlir::daphne::EwMulOp>()) {
        mlir::Value left_o = lambdaMul.getLhs();
        mlir::Value right_o = lambdaMul.getRhs();

        mlir::Value scalarOperand;
        mlir::Value matrixOperand;

        bool lhsIsSca = CompilerUtils::hasScaType(left_o);
        bool rhsIsSca = CompilerUtils::hasScaType(right_o);

        // Use .getType() only for matrix detection.
        bool lhsIsMatrix = left_o.getType().isa<mlir::daphne::MatrixType>();
        bool rhsIsMatrix = right_o.getType().isa<mlir::daphne::MatrixType>();

        // Exactly one side must be a scalar and the other a matrix.
        if (lhsIsSca && rhsIsMatrix) {
            scalarOperand = left_o;
            matrixOperand = right_o;
        } else if (rhsIsSca && lhsIsMatrix) {
            scalarOperand = right_o;
            matrixOperand = left_o;
        } else {
            return mlir::failure(); // Unsupported combination
        }

        mlir::Value innerSum = rewriter.create<mlir::daphne::AllAggSumOp>(location, unknownType, matrixOperand);
        mlir::Value newMul = rewriter.create<mlir::daphne::EwMulOp>(location, result_type, scalarOperand, innerSum);
        rewriter.replaceOp(op, newMul);
        // BUGFIX: must report success after mutating the IR. Previously control
        // fell through to the final failure(), which breaks the rewrite
        // driver's contract (failure promises that nothing was changed).
        return mlir::success();
    } // Rule 4: trace(X @ Y) = sum(diagVector(X @ Y)) -> sum(X * transpose(Y))
    else if (auto diagVec = input.getDefiningOp<mlir::daphne::DiagVectorOp>()) {
        mlir::Value input_dV = diagVec.getOperand(); // should be the result of a matMul
        if (auto matMul = input_dV.getDefiningOp<mlir::daphne::MatMulOp>()) {
            mlir::Value lhs = matMul.getLhs();
            mlir::Value rhs = matMul.getRhs();

            if (!lhs.getType().isa<mlir::daphne::MatrixType>() || !rhs.getType().isa<mlir::daphne::MatrixType>()) {
                return mlir::failure();
            }

            // The trace of X @ Y equals the sum of the element-wise product of
            // X and transpose(Y), which avoids the full matrix multiplication.
            mlir::Value t_rhs = rewriter.create<mlir::daphne::TransposeOp>(location, unknownType, rhs);
            mlir::Value ewMul_m = rewriter.create<mlir::daphne::EwMulOp>(location, unknownType, lhs, t_rhs);
            mlir::Value simplifiedSum = rewriter.create<mlir::daphne::AllAggSumOp>(location, result_type, ewMul_m);

            rewriter.replaceOp(op, simplifiedSum);
            return mlir::success();
        }
    }

    // No rule matched; the IR is unchanged.
    return mlir::failure();
}
107+
108+
/**
109+
* @brief Canonicalizes:
110+
1)(X%*%Y)[7,3] → X[7,]%*%Y[,3]
111+
112+
*/
113+
mlir::LogicalResult mlir::daphne::SliceColOp::canonicalize(mlir::daphne::SliceColOp op,
114+
mlir::PatternRewriter &rewriter) {
115+
mlir::Value input = op.getOperand(0);
116+
mlir::Location location = op.getLoc();
117+
mlir::Type result_type = op.getResult().getType();
118+
auto unknownType = mlir::daphne::UnknownType::get(rewriter.getContext());
119+
120+
auto sliceRowOp = input.getDefiningOp<mlir::daphne::SliceRowOp>();
121+
if (!sliceRowOp) {
122+
return mlir::failure();
123+
}
124+
125+
auto matMulOp = sliceRowOp.getOperand(0).getDefiningOp<mlir::daphne::MatMulOp>();
126+
if (!matMulOp) {
127+
return mlir::failure();
128+
}
129+
130+
// matrices
131+
mlir::Value X = matMulOp.getLhs();
132+
mlir::Value Y = matMulOp.getRhs();
133+
134+
// lower-upper bounds for rows
135+
mlir::Value row_l = sliceRowOp.getOperand(1);
136+
mlir::Value row_u = sliceRowOp.getOperand(2);
137+
138+
// lower-upper bounds for columns
139+
mlir::Value col_l = op.getOperand(1);
140+
mlir::Value col_u = op.getOperand(2);
141+
142+
// to check if a matrix is transposed
143+
mlir::Value t_X = matMulOp.getOperand(2);
144+
mlir::Value t_Y = matMulOp.getOperand(3);
145+
146+
bool isTransposedX = CompilerUtils::isConstant<bool>(t_X).second;
147+
bool isTransposedY = CompilerUtils::isConstant<bool>(t_Y).second;
148+
149+
mlir::Value row;
150+
mlir::Value col;
151+
152+
if (isTransposedX && isTransposedY) {
153+
row = rewriter.create<mlir::daphne::SliceColOp>(location, unknownType, X, row_l, row_u);
154+
col = rewriter.create<mlir::daphne::SliceRowOp>(location, unknownType, Y, col_l, col_u);
155+
} else if (!isTransposedX && isTransposedY) {
156+
row = rewriter.create<mlir::daphne::SliceRowOp>(location, unknownType, X, row_l, row_u);
157+
col = rewriter.create<mlir::daphne::SliceRowOp>(location, unknownType, Y, col_l, col_u);
158+
} else if ((isTransposedX && !isTransposedY)) {
159+
row = rewriter.create<mlir::daphne::SliceColOp>(location, unknownType, X, row_l, row_u);
160+
col = rewriter.create<mlir::daphne::SliceColOp>(location, unknownType, Y, col_l, col_u);
161+
} else if (!isTransposedX && !isTransposedY) {
162+
row = rewriter.create<mlir::daphne::SliceRowOp>(location, unknownType, X, row_l, row_u);
163+
col = rewriter.create<mlir::daphne::SliceColOp>(location, unknownType, Y, col_l, col_u);
164+
} else {
165+
return mlir::failure();
166+
}
167+
168+
auto newMatMul = rewriter.create<mlir::daphne::MatMulOp>(location, result_type, row, col, t_X, t_Y);
169+
rewriter.replaceOp(op, newMatMul.getResult());
170+
return mlir::success();
171+
}
172+
173+
/** @brief Canonicalizes:
174+
1)X[a:b, c:d] = Y -> X=Y if dims(X) = dims(Y)
175+
//only for matrices with matching element types
176+
*/
177+
mlir::LogicalResult mlir::daphne::InsertRowOp::canonicalize(mlir::daphne::InsertRowOp op,
178+
mlir::PatternRewriter &rewriter) {
179+
mlir::Location location = op.getLoc();
180+
mlir::Type result_type = op.getResult().getType();
181+
182+
auto insertCol = op.getIns().getDefiningOp<mlir::daphne::InsertColOp>();
183+
if (!insertCol) {
184+
return mlir::failure();
185+
}
186+
187+
auto sliceRow = insertCol.getArg().getDefiningOp<mlir::daphne::SliceRowOp>();
188+
if (!sliceRow) {
189+
return mlir::failure();
190+
}
191+
192+
mlir::Value sliceInput = sliceRow.getSource(); // X
193+
mlir::Value insertColInput = insertCol.getIns(); // Y
194+
if (!sliceInput.getType().isa<mlir::daphne::MatrixType>() ||
195+
!insertColInput.getType().isa<mlir::daphne::MatrixType>()) {
196+
return mlir::failure();
197+
}
198+
199+
auto sliceType = sliceInput.getType().dyn_cast<mlir::daphne::MatrixType>();
200+
auto insertColInputType = insertColInput.getType().dyn_cast<mlir::daphne::MatrixType>();
201+
auto opResultType = op.getResult().getType().dyn_cast<mlir::daphne::MatrixType>();
202+
203+
if (!sliceType || !insertColInputType || !opResultType) {
204+
return mlir::failure();
205+
}
206+
207+
if (sliceType.getElementType() != insertColInputType.getElementType()) {
208+
return mlir::failure();
209+
}
210+
211+
int64_t numRows_X = sliceType.getNumRows();
212+
int64_t numCols_X = sliceType.getNumCols();
213+
int64_t numRows_Y = insertColInputType.getNumRows();
214+
int64_t numCols_Y = insertColInputType.getNumCols();
215+
216+
if (numRows_X == -1 || numCols_X == -1 || numRows_Y == -1 || numCols_Y == -1) {
217+
return mlir::failure();
218+
}
219+
220+
if (numRows_X != numRows_Y || numCols_X != numCols_Y) {
221+
return mlir::failure();
222+
}
223+
224+
auto renamed = rewriter.create<mlir::daphne::RenameOp>(location, result_type, insertColInput);
225+
rewriter.replaceOp(op, renamed.getResult());
226+
return mlir::success();
227+
}
228+
22229
mlir::LogicalResult mlir::daphne::VectorizedPipelineOp::canonicalize(mlir::daphne::VectorizedPipelineOp op,
23230
mlir::PatternRewriter &rewriter) {
24231
// // Find duplicate inputs

src/ir/daphneir/DaphneOps.td

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ def Daphne_MatMulOp : Daphne_Op<"matMul", [
228228
DataTypeMat, ValueTypeFromArgs,
229229
DeclareOpInterfaceMethods<InferShapeOpInterface>,
230230
DeclareOpInterfaceMethods<InferSparsityOpInterface>, CUDASupport, FPGAOPENCLSupport,
231-
CastFirstTwoArgsToResType, NoMemoryEffect
231+
CastFirstTwoArgsToResType, NoMemoryEffect,
232+
Pure
232233
]> {
233234
let arguments = (ins MatrixOf<[NumScalar]>:$lhs, MatrixOf<[NumScalar]>:$rhs, BoolScalar:$transa, BoolScalar:$transb);
234235
let results = (outs MatrixOf<[NumScalar]>:$res);
@@ -498,7 +499,9 @@ class Daphne_AllAggOp<string name, Type scalarType, list<Trait> traits = []>
498499
let results = (outs scalarType:$res);
499500
}
500501

501-
def Daphne_AllAggSumOp : Daphne_AllAggOp<"sumAll", NumScalar, [ValueTypeFromFirstArg]>;
502+
def Daphne_AllAggSumOp : Daphne_AllAggOp<"sumAll", NumScalar, [ValueTypeFromFirstArg]>{
503+
let hasCanonicalizeMethod = 1;
504+
}
502505
def Daphne_AllAggMinOp : Daphne_AllAggOp<"minAll", NumScalar, [ValueTypeFromFirstArg]>;
503506
def Daphne_AllAggMaxOp : Daphne_AllAggOp<"maxAll", NumScalar, [ValueTypeFromFirstArg]>;
504507
def Daphne_AllAggMeanOp : Daphne_AllAggOp<"meanAll", NumScalar, [ValueTypeFromArgsFP]>;
@@ -649,7 +652,8 @@ def Daphne_ExtractRowOp : Daphne_Op<"extractRow", [
649652

650653
def Daphne_SliceRowOp : Daphne_Op<"sliceRow", [
651654
TypeFromFirstArg,
652-
DeclareOpInterfaceMethods<InferShapeOpInterface>
655+
DeclareOpInterfaceMethods<InferShapeOpInterface>,
656+
Pure
653657
]> {
654658
let summary = "Copies the specified rows from the argument to the result.";
655659

@@ -704,6 +708,7 @@ def Daphne_SliceColOp : Daphne_Op<"sliceCol", [
704708

705709
let arguments = (ins MatrixOrFrame:$source, SI64:$lowerIncl, SI64:$upperExcl);
706710
let results = (outs MatrixOrFrame:$res);
711+
let hasCanonicalizeMethod = 1;
707712
}
708713

709714
// TODO Create combined InsertOp (see #238).
@@ -714,11 +719,13 @@ def Daphne_InsertRowOp : Daphne_Op<"insertRow", [
714719
]> {
715720
let arguments = (ins MatrixOrFrame:$arg, MatrixOrFrame:$ins, SI64:$rowLowerIncl, SI64:$rowUpperExcl);
716721
let results = (outs MatrixOrFrame:$res);
722+
let hasCanonicalizeMethod = 1;
717723
}
718724

719725
def Daphne_InsertColOp : Daphne_Op<"insertCol", [
720726
TypeFromFirstArg, // this is debatable
721-
ShapeFromArg
727+
ShapeFromArg,
728+
Pure
722729
]> {
723730
let arguments = (ins MatrixOrFrame:$arg, MatrixOrFrame:$ins, SI64:$colLowerIncl, SI64:$colUpperExcl);
724731
let results = (outs MatrixOrFrame:$res);
@@ -989,7 +996,8 @@ def Daphne_SoftmaxOp : Daphne_Op<"Softmax", [ DataTypeFromFirstArg, ValueTypeFro
989996
// ****************************************************************************
990997

991998
def Daphne_DiagVectorOp : Daphne_Op<"diagVector", [
992-
TypeFromFirstArg, NumRowsFromArg, OneCol
999+
TypeFromFirstArg, NumRowsFromArg, OneCol,
1000+
Pure
9931001
]> {
9941002
let arguments = (ins MatrixOrU:$arg);
9951003
let results = (outs MatrixOrU:$res);

test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ include_directories(${PROJECT_SOURCE_DIR}/thirdparty/catch2) # for "catch.hpp"
1818
set(TEST_SOURCES
1919
run_tests.h
2020
run_tests.cpp
21-
21+
22+
api/cli/expressions/SimplificationTest.cpp
2223
api/cli/algorithms/AlgorithmsTest.cpp
2324
api/cli/algorithms/DecisionTreeRandomForestTest.cpp
2425
api/cli/config/ConfigTest.cpp
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Script-level tests for the simplification/canonicalization rewrites: each
// case runs one or more .daphne scripts and compares their output to a
// reference file in the same directory.
#include <api/cli/Utils.h>

#include <tags.h>

#include <catch.hpp>

#include <sstream>
#include <string>

// Directory containing the <name>_<i>.daphne scripts and reference outputs.
const std::string dirPath = "test/api/cli/expressions/";

// Registers one TEST_CASE that checks scripts name_1.daphne .. name_<count>.daphne
// against their reference outputs via compareDaphneToRefSimple.
#define MAKE_TEST_CASE(name, count) \
    TEST_CASE(name, TAG_REWRITE) { \
        for (unsigned i = 1; i <= count; i++) { \
            DYNAMIC_SECTION(name "_" << i << ".daphne") { compareDaphneToRefSimple(dirPath, name, i); } \
        } \
    }

// One test per canonicalization rule added in this change.
MAKE_TEST_CASE("simplf_sumEwadd", 1)
MAKE_TEST_CASE("simplf_sumTranspose", 1)
MAKE_TEST_CASE("simplf_sumMulLambda", 1)
MAKE_TEST_CASE("simplf_sumTrace", 1)
MAKE_TEST_CASE("simplf_mmSlice", 1)
MAKE_TEST_CASE("simplf_dynInsert", 1)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
// Exercises the insertRow/insertCol canonicalization:
// X[a:b, c:d] = Y is simplified to Y when dims(X) == dims(Y).

// 5x5 input matrix with values 1..25.
m1 = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0,
14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0
];
m1 = reshape(m1, 5, 5);

// Full-range insertion into a same-shaped matrix: rewrite applies, m2 == m1.
m2 = fill(0.0, 5, 5);
m2 [0:5, 0:5] = m1;
print(m1);
print(m2);

// Partial insertion into a larger matrix: dims differ, rewrite must NOT apply.
m3 = fill(0.0, 7, 8);
m3 [0:5, 0:5] = m1;
print(m3);
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
DenseMatrix(5x5, double)
2+
1 2 3 4 5
3+
6 7 8 9 10
4+
11 12 13 14 15
5+
16 17 18 19 20
6+
21 22 23 24 25
7+
DenseMatrix(5x5, double)
8+
1 2 3 4 5
9+
6 7 8 9 10
10+
11 12 13 14 15
11+
16 17 18 19 20
12+
21 22 23 24 25
13+
DenseMatrix(7x8, double)
14+
1 2 3 4 5 0 0 0
15+
6 7 8 9 10 0 0 0
16+
11 12 13 14 15 0 0 0
17+
16 17 18 19 20 0 0 0
18+
21 22 23 24 25 0 0 0
19+
0 0 0 0 0 0 0 0
20+
0 0 0 0 0 0 0 0

0 commit comments

Comments
 (0)