Skip to content

Commit 97da07c

Browse files
committed
Refactor MatMulTileConstraint
1 parent 6976c52 commit 97da07c

File tree

1 file changed

+50
-37
lines changed

1 file changed

+50
-37
lines changed

Deeploy/Targets/PULPOpen/TileConstraints/MatMulTileConstraint.py

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,42 +19,50 @@ class MatMulTileConstraint(TileConstraint):
1919

2020
@staticmethod
2121
def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
22-
23-
# Get to-be-tiled tensor's buffers
2422
bufferA = ctxt.lookup(name = parseDict['A'])
2523
bufferB = ctxt.lookup(name = parseDict['B'])
26-
outputBuffer = ctxt.lookup(name = parseDict['data_out'])
24+
bufferOut = ctxt.lookup(name = parseDict['data_out'])
2725

2826
# Add I/O dimensions to the model as variables
29-
for _buffer in [bufferA, bufferB, outputBuffer]:
30-
tilerModel.addTensorDimToModel(ctxt, _buffer.name)
31-
32-
tensorsShapeLen = len(bufferA.shape)
33-
34-
AFirstDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name,
35-
dimIdx = (tensorsShapeLen - 2) + int(parseDict['transA']))
36-
ASecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name,
37-
dimIdx = (tensorsShapeLen - 1) - int(parseDict['transA']))
38-
BFirstDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name,
39-
dimIdx = (tensorsShapeLen - 2) + int(parseDict['transB']))
40-
BSecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name,
41-
dimIdx = (tensorsShapeLen - 1) - int(parseDict['transB']))
42-
outputFirstDimVar = tilerModel.getTensorDimVar(tensorName = outputBuffer.name, dimIdx = (tensorsShapeLen - 2))
43-
outputSecondDimVar = tilerModel.getTensorDimVar(tensorName = outputBuffer.name, dimIdx = (tensorsShapeLen - 1))
44-
45-
# Map output dims to inputs dims
46-
for idx in range(tensorsShapeLen - 2):
47-
tilerModel.addConstraint(
48-
tilerModel.getTensorDimVar(tensorName = outputBuffer.name, dimIdx = idx) == tilerModel.getTensorDimVar(
49-
tensorName = bufferA.name, dimIdx = idx))
50-
tilerModel.addConstraint(
51-
tilerModel.getTensorDimVar(tensorName = outputBuffer.name, dimIdx = idx) == tilerModel.getTensorDimVar(
52-
tensorName = bufferB.name, dimIdx = idx))
27+
for buff in [bufferA, bufferB, bufferOut]:
28+
tilerModel.addTensorDimToModel(ctxt, buff.name)
29+
30+
rankA = len(bufferA.shape)
31+
if not parseDict['transA']:
32+
firstDimIdxA, secondDimIdxA = rankA - 2, rankA - 1
33+
else:
34+
firstDimIdxA, secondDimIdxA = rankA - 1, rankA - 2
35+
AFirstDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = firstDimIdxA)
36+
ASecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = secondDimIdxA)
37+
38+
rankB = len(bufferB.shape)
39+
if not parseDict['transB']:
40+
firstDimIdxB, secondDimIdxB = rankB - 2, rankB - 1
41+
else:
42+
firstDimIdxB, secondDimIdxB = rankB - 1, rankB - 2
43+
BFirstDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name, dimIdx = firstDimIdxB)
44+
BSecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name, dimIdx = secondDimIdxB)
45+
46+
rankOut = len(bufferOut.shape)
47+
outputFirstDimVar = tilerModel.getTensorDimVar(tensorName = bufferOut.name, dimIdx = rankOut - 2)
48+
outputSecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferOut.name, dimIdx = rankOut - 1)
49+
50+
# Map batch dims between A and output
51+
batchDimsA = rankA - 2
52+
for dimIdx in range(batchDimsA):
53+
varA = tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = dimIdx)
54+
varOut = tilerModel.getTensorDimVar(tensorName = bufferOut.name, dimIdx = (rankOut - rankA) + dimIdx)
55+
tilerModel.addConstraint(varOut == varA)
56+
57+
# Map batch dims between B and output
58+
batchDimsB = rankB - 2
59+
for dimIdx in range(batchDimsB):
60+
varB = tilerModel.getTensorDimVar(tensorName = bufferB.name, dimIdx = dimIdx)
61+
varOut = tilerModel.getTensorDimVar(tensorName = bufferOut.name, dimIdx = (rankOut - rankB) + dimIdx)
62+
tilerModel.addConstraint(varOut == varB)
5363

5464
tilerModel.addConstraint(outputFirstDimVar == AFirstDimVar)
5565
tilerModel.addConstraint(outputSecondDimVar == BSecondDimVar)
56-
57-
# Add GEMM Geometrical constraints
5866
tilerModel.addConstraint(ASecondDimVar == BFirstDimVar)
5967

6068
return tilerModel
@@ -65,14 +73,19 @@ def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkCo
6573
bufferA = ctxt.lookup(name = parseDict['A'])
6674
bufferB = ctxt.lookup(name = parseDict['B'])
6775

68-
tensorsShapeLen = len(bufferA.shape)
69-
70-
ASecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name,
71-
dimIdx = (tensorsShapeLen - 1) - parseDict['transA'])
72-
BFirstDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name,
73-
dimIdx = (tensorsShapeLen - 2) + parseDict['transB'])
74-
BSecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name,
75-
dimIdx = (tensorsShapeLen - 1) - parseDict['transB'])
76+
rankA = len(bufferA.shape)
77+
if not parseDict['transA']:
78+
_, secondDimIdxA = rankA - 2, rankA - 1
79+
else:
80+
_, secondDimIdxA = rankA - 1, rankA - 2
81+
ASecondDimVar = tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = secondDimIdxA)
82+
83+
rankB = len(bufferB.shape)
84+
if not parseDict['transB']:
85+
firstDimIdxB, _ = rankB - 2, rankB - 1
86+
else:
87+
firstDimIdxB, _ = rankB - 1, rankB - 2
88+
BFirstDimVar = tilerModel.getTensorDimVar(tensorName = bufferB.name, dimIdx = firstDimIdxB)
7689

7790
# VIC: We don't want to deal with intermediate results between kernel calls
7891
tilerModel.addConstraint(ASecondDimVar == parseDict['N'])

0 commit comments

Comments
 (0)