@@ -3498,3 +3498,101 @@ func.func @torch.symbolic_int$canonicalize(%arg0: !torch.vtensor<[?],f32>, %arg1
torch.bind_symbolic_shape %3, [%0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
return %3 : !torch.vtensor<[?],f32>
}
+
+ // -----
+
+ // CHECK-LABEL: func.func @torch.aten.ones$float_fold() -> !torch.vtensor<[2,3,4],f32> {
+ // CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<2x3x4xf32>) : !torch.vtensor<[2,3,4],f32>
+ // CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],f32>
+ // CHECK: }
+ func.func @torch.aten.ones$float_fold() -> !torch.vtensor<[2,3,4],f32> {
+ %int2 = torch.constant.int 2
+ %int3 = torch.constant.int 3
+ %int4 = torch.constant.int 4
+ %none = torch.constant.none
+ %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32>
+ return %1 : !torch.vtensor<[2,3,4],f32>
+ }
+
+ // -----
+
+ // CHECK-LABEL: func.func @torch.aten.ones$int_fold() -> !torch.vtensor<[2,3,4],si64> {
+ // CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<1> : tensor<2x3x4xsi64>) : !torch.vtensor<[2,3,4],si64>
+ // CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],si64>
+ // CHECK: }
+ func.func @torch.aten.ones$int_fold() -> !torch.vtensor<[2,3,4],si64> {
+ %int2 = torch.constant.int 2
+ %int3 = torch.constant.int 3
+ %int4 = torch.constant.int 4
+ %none = torch.constant.none
+ %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1 = torch.aten.ones %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],si64>
+ return %1 : !torch.vtensor<[2,3,4],si64>
+ }
+
+ // -----
+
+ // CHECK-LABEL: func.func @torch.aten.zeros$float_fold() -> !torch.vtensor<[2,3,4],f32> {
+ // CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<2x3x4xf32>) : !torch.vtensor<[2,3,4],f32>
+ // CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],f32>
+ // CHECK: }
+ func.func @torch.aten.zeros$float_fold() -> !torch.vtensor<[2,3,4],f32> {
+ %int2 = torch.constant.int 2
+ %int3 = torch.constant.int 3
+ %int4 = torch.constant.int 4
+ %none = torch.constant.none
+ %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1 = torch.aten.zeros %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32>
+ return %1 : !torch.vtensor<[2,3,4],f32>
+ }
+
+ // -----
+
+ // CHECK-LABEL: func.func @torch.aten.zeros$int_fold() -> !torch.vtensor<[2,3,4],si64> {
+ // CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0> : tensor<2x3x4xsi64>) : !torch.vtensor<[2,3,4],si64>
+ // CHECK: return %[[VAL_0]] : !torch.vtensor<[2,3,4],si64>
+ // CHECK: }
+ func.func @torch.aten.zeros$int_fold() -> !torch.vtensor<[2,3,4],si64> {
+ %int2 = torch.constant.int 2
+ %int3 = torch.constant.int 3
+ %int4 = torch.constant.int 4
+ %none = torch.constant.none
+ %0 = torch.prim.ListConstruct %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1 = torch.aten.zeros %0, %none, %none, %none, %none : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],si64>
+ return %1 : !torch.vtensor<[2,3,4],si64>
+ }
+
+ // -----
+
+ // CHECK-LABEL: func.func @torch.aten.full$float_fold() -> !torch.vtensor<[2,1,4],f32> {
+ // CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0xFF800000> : tensor<2x1x4xf32>) : !torch.vtensor<[2,1,4],f32>
+ // CHECK: return %[[VAL_0]] : !torch.vtensor<[2,1,4],f32>
+ // CHECK: }
+ func.func @torch.aten.full$float_fold() -> !torch.vtensor<[2,1,4],f32> {
+ %float-Inf = torch.constant.float 0xFFF0000000000000
+ %int2 = torch.constant.int 2
+ %int1 = torch.constant.int 1
+ %int4 = torch.constant.int 4
+ %none = torch.constant.none
+ %0 = torch.prim.ListConstruct %int2, %int1, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1 = torch.aten.full %0, %float-Inf, %none, %none, %none, %none : !torch.list<int>, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,1,4],f32>
+ return %1 : !torch.vtensor<[2,1,4],f32>
+ }
+
+ // -----
+
+ // CHECK-LABEL: func.func @torch.aten.full$int_fold() -> !torch.vtensor<[2,1,4],si64> {
+ // CHECK: %[[VAL_0:.*]] = torch.vtensor.literal(dense<0> : tensor<2x1x4xsi64>) : !torch.vtensor<[2,1,4],si64>
+ // CHECK: return %[[VAL_0]] : !torch.vtensor<[2,1,4],si64>
+ // CHECK: }
+ func.func @torch.aten.full$int_fold() -> !torch.vtensor<[2,1,4],si64> {
+ %int0 = torch.constant.int 0
+ %int2 = torch.constant.int 2
+ %int1 = torch.constant.int 1
+ %int4 = torch.constant.int 4
+ %none = torch.constant.none
+ %0 = torch.prim.ListConstruct %int2, %int1, %int4 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %1 = torch.aten.full %0, %int0, %none, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,1,4],si64>
+ return %1 : !torch.vtensor<[2,1,4],si64>
+ }
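
The RUN header for this test file sits outside the hunk above. These cases rely on the standard MLIR lit setup: each // ----- marker splits the file into an independent test case, which requires running the opt tool with --split-input-file. A typical invocation, assuming the file's usual torch-mlir-opt driver (the real RUN line may carry additional flags), looks like:

// RUN: torch-mlir-opt %s -canonicalize --split-input-file | FileCheck %s

Each CHECK-LABEL/CHECK group then asserts that canonicalization folds the corresponding torch.aten.ones, torch.aten.zeros, or torch.aten.full call into a single torch.vtensor.literal holding the expected splat value.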