
Commit fe14bf7

[TorchDialect] Fix bug in AtenOnes/AtenZeros/AtenFullOp fold functions (#4320)

Fix a bug in the AtenOnes/AtenZeros/AtenFullOp fold functions: explicitly qualify IntegerType/FloatType with the mlir namespace, since the unqualified names were unexpectedly resolved in the torch namespace.

Co-authored-by: [email protected]
1 parent 864c0a1 commit fe14bf7
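
The underlying issue is plain C++ name lookup: these fold functions live inside the torch dialect's namespace, which defines its own FloatType (the type of !torch.float), so an unqualified isa<FloatType> check can bind to the dialect type rather than the builtin mlir::FloatType and never match a tensor element type. The snippet below is a minimal, self-contained sketch of that shadowing hazard; the struct definitions are illustrative stand-ins, not the real torch-mlir declarations.

#include <iostream>
#include <type_traits>

namespace mlir {
struct FloatType {}; // stand-in for the builtin mlir::FloatType

namespace torch {
namespace Torch {
struct FloatType {}; // stand-in for the dialect's FloatType (!torch.float)

void demo() {
  // Unqualified lookup binds to the innermost declaration, i.e. the dialect type ...
  static_assert(std::is_same_v<FloatType, Torch::FloatType>,
                "unqualified name resolves to the dialect type");
  // ... and therefore not to the builtin type the fold function actually needs.
  static_assert(!std::is_same_v<FloatType, ::mlir::FloatType>,
                "unqualified name does not resolve to the builtin type");
  // Qualifying with the mlir namespace, as the fix does, selects the builtin type.
  static_assert(std::is_same_v<mlir::FloatType, ::mlir::FloatType>,
                "qualified name resolves to the builtin type");
  std::cout << "unqualified FloatType shadows mlir::FloatType in this scope\n";
}
} // namespace Torch
} // namespace torch
} // namespace mlir

int main() { mlir::torch::Torch::demo(); }

In the real TorchOps.cpp the same effect is achieved by spelling the checks isa<mlir::IntegerType> and isa<mlir::FloatType>, as the hunks below show.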

4 files changed: +113 −18 lines changed

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 4 additions & 4 deletions
@@ -4964,11 +4964,11 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) {
     return nullptr;
   }
   auto elementType = shapedty.getElementType();
-  if (isa<IntegerType>(elementType)) {
+  if (isa<mlir::IntegerType>(elementType)) {
     Attribute attribute = IntegerAttr::get(elementType, 1);
     return DenseElementsAttr::get(shapedty, attribute);
   }
-  if (isa<FloatType>(elementType)) {
+  if (isa<mlir::FloatType>(elementType)) {
     Attribute attribute = FloatAttr::get(elementType, 1.0);
     return DenseElementsAttr::get(shapedty, attribute);
   }
@@ -5008,7 +5008,7 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) {
     Attribute attribute = IntegerAttr::get(elementType, 0);
     return DenseElementsAttr::get(shapedty, attribute);
   }
-  if (isa<FloatType>(elementType)) {
+  if (isa<mlir::FloatType>(elementType)) {
     Attribute attribute = FloatAttr::get(elementType, 0.0);
     return DenseElementsAttr::get(shapedty, attribute);
   }
@@ -5048,7 +5048,7 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) {
       return DenseElementsAttr::get(shapedty, attribute);
     }
   }
-  if (isa<FloatType>(elementType)) {
+  if (isa<mlir::FloatType>(elementType)) {
     double value = 0.0;
     if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) {
       Attribute attribute = FloatAttr::get(elementType, value);

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 6 additions & 8 deletions
@@ -857,10 +857,9 @@ func.func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vt
 // CHECK: %[[VAL_1:.*]] = torch.constant.int 3
 // CHECK: %[[VAL_2:.*]] = torch.constant.none
 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0> : tensor<3x4xi32>}> : () -> tensor<3x4xi32>
-// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xi32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
-// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<3x4xf32>}> : () -> tensor<3x4xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
 // CHECK: }
 func.func @torch.aten.zeros$basic() -> !torch.vtensor<[3,4],f32> {
   %int4 = torch.constant.int 4
@@ -925,10 +924,9 @@ func.func @torch.aten.contiguous$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !to
 // CHECK: %[[VAL_1:.*]] = torch.constant.int 3
 // CHECK: %[[VAL_2:.*]] = torch.constant.none
 // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_0]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<3x4xi32>}> : () -> tensor<3x4xi32>
-// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<3x4xi32>) -> tensor<3x4xf32>
-// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
-// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32>
+// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<3x4xf32>}> : () -> tensor<3x4xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32>
 // CHECK: }
 func.func @torch.aten.ones$basic() -> !torch.vtensor<[3,4],f32> {
   %int4 = torch.constant.int 4

test/Dialect/Torch/canonicalize.mlir

Lines changed: 98 additions & 0 deletions
@@ -3498,3 +3498,101 @@ func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1
   torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
   return %3 : !torch.vtensor<[?],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @ttorch.aten.ones$float_fold() -> !torch.vtensor<[2,3,4],f32> {
+// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<2x3x4xf32>) : !torch.vtensor<[2,3,4],f32>
+// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],f32>
+// CHECK: }
+func.func @ttorch.aten.ones$float_fold() -> !torch.vtensor<[2,3,4],f32> {
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32>
+  return %1 : !torch.vtensor<[2,3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @ttorch.aten.ones$int_fold() -> !torch.vtensor<[2,3,4],si64> {
+// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<1> : tensor<2x3x4xsi64>) : !torch.vtensor<[2,3,4],si64>
+// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],si64>
+// CHECK: }
+func.func @ttorch.aten.ones$int_fold() -> !torch.vtensor<[2,3,4],si64> {
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],si64>
+  return %1 : !torch.vtensor<[2,3,4],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_aten_zeros$float_fold() -> !torch.vtensor<[2,3,4],f32> {
+// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<2x3x4xf32>) : !torch.vtensor<[2,3,4],f32>
+// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],f32>
+// CHECK: }
+func.func @test_aten_zeros$float_fold() -> !torch.vtensor<[2,3,4],f32> {
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.zeros %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32>
+  return %1 : !torch.vtensor<[2,3,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_aten_zeros$int_fold() -> !torch.vtensor<[2,3,4],si64> {
+// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0> : tensor<2x3x4xsi64>) : !torch.vtensor<[2,3,4],si64>
+// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],si64>
+// CHECK: }
+func.func @test_aten_zeros$int_fold() -> !torch.vtensor<[2,3,4],si64> {
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.zeros %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],si64>
+  return %1 : !torch.vtensor<[2,3,4],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.full$float_fold() -> !torch.vtensor<[2,1,4],f32> {
+// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0xFF800000> : tensor<2x1x4xf32>) : !torch.vtensor<[2,1,4],f32>
+// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,1,4],f32>
+// CHECK: }
+func.func @torch.aten.full$float_fold() -> !torch.vtensor<[2,1,4],f32> {
+  %float-Inf = torch.constant.float 0xFFF0000000000000
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int2, %int1, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.full %0, %float-Inf, %none, %none, %none, %none : !torch.list<int>, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,1,4],f32>
+  return %1 : !torch.vtensor<[2,1,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.full$int_fold() -> !torch.vtensor<[2,1,4],si64> {
+// CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0> : tensor<2x1x4xsi64>) : !torch.vtensor<[2,1,4],si64>
+// CHECK: return %[[VAL_0]] : !torch.vtensor<[2,1,4],si64>
+// CHECK: }
+func.func @torch.aten.full$int_fold() -> !torch.vtensor<[2,1,4],si64> {
+  %int-Inf = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int2, %int1, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.full %0, %int-Inf, %none, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,1,4],si64>
+  return %1 : !torch.vtensor<[2,1,4],si64>
+}

test/Dialect/Torch/fuse-quantized-ops.mlir

Lines changed: 5 additions & 6 deletions
@@ -85,25 +85,24 @@ func.func @matmul_commuting(%arg0: !torch.vtensor<[2,128,32,32],si8>) -> !torch.
 // CHECK-LABEL: func.func @mm_pad_commute
 func.func @mm_pad_commute(%arg0: !torch.vtensor<[8,8],si8>, %arg1: !torch.vtensor<[11,4],si8>) -> !torch.vtensor<[9,4],f32> {
   // CHECK-DAG: %[[cstQuart:.*]] = torch.constant.float 2.500000e-01
-  // CHECK-DAG: %[[int7:.*]] = torch.constant.int 7
-  // CHECK-DAG: %[[none:.*]] = torch.constant.none
+  // CHECK-DAG: %[[padVal:.*]] = torch.vtensor.literal(dense<8.000000e+00> : tensor<f64>) : !torch.vtensor<[],f64>
   // CHECK-DAG: %[[qMax:.*]] = torch.constant.float 1.270000e+02
   // CHECK-DAG: %[[qMin:.*]] = torch.constant.float -1.280000e+02
-  // CHECK-DAG: %[[padVal:.*]] = torch.constant.float 8.000000e+00
   // CHECK-DAG: %[[str:.*]] = torch.constant.str "constant"
   // CHECK-DAG: %[[cstHalf:.*]] = torch.constant.float 5.000000e-01
   // CHECK-DAG: %[[int0:.*]] = torch.constant.int 0
   // CHECK-DAG: %[[int1:.*]] = torch.constant.int 1
   // CHECK-DAG: %[[int2:.*]] = torch.constant.int 2
   // CHECK: %[[PadList:.*]] = torch.prim.ListConstruct %[[int1]], %[[int2]], %[[int0]], %[[int1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[EmptyList:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
-  // CHECK: %[[Rank0:.*]] = torch.aten.full %[[EmptyList]], %[[padVal]], %[[int7]], %[[none]], %[[none]], %[[none]] : !torch.list<int>, !torch.float, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f64>
-  // CHECK: %[[Clamp:.*]] = torch.aten.clamp %[[Rank0]], %[[qMin]], %[[qMax]] : !torch.vtensor<[],f64>, !torch.float, !torch.float -> !torch.vtensor<[],f64>
+  // CHECK: %[[Clamp:.*]] = torch.aten.clamp %[[padVal]], %[[qMin]], %[[qMax]] : !torch.vtensor<[],f64>, !torch.float, !torch.float -> !torch.vtensor<[],f64>
   // CHECK: %[[Item:.*]] = torch.aten.item %[[Clamp]] : !torch.vtensor<[],f64> -> !torch.float
   // CHECK: %[[NewPad:.*]] = torch.aten.pad %arg0, %[[PadList]], %[[str]], %[[Item]] : !torch.vtensor<[8,8],si8>, !torch.list<int>, !torch.str, !torch.float -> !torch.vtensor<[9,11],si8>
   // CHECK: %[[NewMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[NewPad]], %[[cstHalf]], %[[int1]] : !torch.vtensor<[9,11],si8>, !torch.float, !torch.int -> !torch.vtensor<[9,11],!torch.qint8>
   // CHECK: %[[OtherMPTQT:.*]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[cstHalf]], %[[int0]] : !torch.vtensor<[11,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[11,4],!torch.qint8>
   // CHECK: %[[MM:.*]] = torch.aten.mm %[[NewMPTQT]], %[[OtherMPTQT]] : !torch.vtensor<[9,11],!torch.qint8>, !torch.vtensor<[11,4],!torch.qint8> -> !torch.vtensor<[9,4],!torch.qint32>
+  // CHECK: %[[IR:.*]] = torch.aten.int_repr %[[MM]] : !torch.vtensor<[9,4],!torch.qint32> -> !torch.vtensor<[9,4],si32>
+  // CHECK: %[[QOUT:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[IR]], %[[cstQuart]], %[[int0]] : !torch.vtensor<[9,4],si32>, !torch.float, !torch.int -> !torch.vtensor<[9,4],!torch.qint32>
+  // CHECK: %[[OUT:.*]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[9,4],!torch.qint32> -> !torch.vtensor<[9,4],f32>
   %scale = torch.constant.float 0.5
   %false = torch.constant.bool false
   %zero = torch.constant.int 0
