@@ -19,42 +19,50 @@ class MatMulTileConstraint(TileConstraint):
1919
@staticmethod
def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
    """Add the MatMul shape relations between A, B and the output to the tiler model.

    The dimensions of the three tensors are first registered as model variables.
    The following equalities are then added, in this order:

    1. every leading (batch) dimension of A equals the output dimension it lines
       up with when both shapes are aligned on their last dimension;
    2. the same for every leading (batch) dimension of B;
    3. the output's second-to-last dimension equals A's first matrix dimension;
    4. the output's last dimension equals B's second matrix dimension;
    5. A's second matrix dimension equals B's first matrix dimension.

    ``parseDict`` is read for the tensor names stored under ``'A'``, ``'B'`` and
    ``'data_out'`` and for the ``'transA'`` / ``'transB'`` flags. A truthy flag
    swaps which of the operand's last two dimensions counts as its first and
    which as its second matrix dimension.

    Returns the ``tilerModel`` object that was passed in.
    """
    operandA = ctxt.lookup(name = parseDict['A'])
    operandB = ctxt.lookup(name = parseDict['B'])
    result = ctxt.lookup(name = parseDict['data_out'])

    # Add I/O dimensions to the model as variables
    for tensor in (operandA, operandB, result):
        tilerModel.addTensorDimToModel(ctxt, tensor.name)

    def dimVar(tensor, idx):
        # Model variable attached to dimension `idx` of `tensor`.
        return tilerModel.getTensorDimVar(tensorName = tensor.name, dimIdx = idx)

    def matrixDimVars(tensor, transposed):
        # (first, second) matrix-dimension variables of `tensor`; the two
        # trailing dimensions trade places when the operand is transposed.
        rank = len(tensor.shape)
        firstIdx, secondIdx = (rank - 1, rank - 2) if transposed else (rank - 2, rank - 1)
        return dimVar(tensor, firstIdx), dimVar(tensor, secondIdx)

    aFirst, aSecond = matrixDimVars(operandA, parseDict['transA'])
    bFirst, bSecond = matrixDimVars(operandB, parseDict['transB'])

    resultRank = len(result.shape)
    resultFirst = dimVar(result, resultRank - 2)
    resultSecond = dimVar(result, resultRank - 1)

    # Map batch dims between each operand (A first, then B) and the output.
    # The offset lines the operand's dimensions up with the trailing
    # dimensions of the output when the ranks differ.
    for operand in (operandA, operandB):
        operandRank = len(operand.shape)
        offset = resultRank - operandRank
        for batchIdx in range(operandRank - 2):
            operandVar = dimVar(operand, batchIdx)
            resultVar = dimVar(result, offset + batchIdx)
            tilerModel.addConstraint(resultVar == operandVar)

    # Output matrix dims follow A's first and B's second matrix dimension
    tilerModel.addConstraint(resultFirst == aFirst)
    tilerModel.addConstraint(resultSecond == bSecond)

    # The contracted dimension must agree between A and B
    tilerModel.addConstraint(aSecond == bFirst)

    return tilerModel
@@ -65,14 +73,19 @@ def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkCo
6573 bufferA = ctxt .lookup (name = parseDict ['A' ])
6674 bufferB = ctxt .lookup (name = parseDict ['B' ])
6775
68- tensorsShapeLen = len (bufferA .shape )
69-
70- ASecondDimVar = tilerModel .getTensorDimVar (tensorName = bufferA .name ,
71- dimIdx = (tensorsShapeLen - 1 ) - parseDict ['transA' ])
72- BFirstDimVar = tilerModel .getTensorDimVar (tensorName = bufferB .name ,
73- dimIdx = (tensorsShapeLen - 2 ) + parseDict ['transB' ])
74- BSecondDimVar = tilerModel .getTensorDimVar (tensorName = bufferB .name ,
75- dimIdx = (tensorsShapeLen - 1 ) - parseDict ['transB' ])
76+ rankA = len (bufferA .shape )
77+ if not parseDict ['transA' ]:
78+ _ , secondDimIdxA = rankA - 2 , rankA - 1
79+ else :
80+ _ , secondDimIdxA = rankA - 1 , rankA - 2
81+ ASecondDimVar = tilerModel .getTensorDimVar (tensorName = bufferA .name , dimIdx = secondDimIdxA )
82+
83+ rankB = len (bufferB .shape )
84+ if not parseDict ['transB' ]:
85+ firstDimIdxB , _ = rankB - 2 , rankB - 1
86+ else :
87+ firstDimIdxB , _ = rankB - 1 , rankB - 2
88+ BFirstDimVar = tilerModel .getTensorDimVar (tensorName = bufferB .name , dimIdx = firstDimIdxB )
7689
7790 # VIC: We don't want to deal with intermediate results between kernel calls
7891 tilerModel .addConstraint (ASecondDimVar == parseDict ['N' ])
0 commit comments