Skip to content

Commit 997ad6e

Browse files
titaiwangms authored and justinchuby committed
Fix Op(unflatten) (#2070)
The op was failing and not traced.
1 parent 946a940 commit 997ad6e

File tree

2 files changed

+15
-19
lines changed

2 files changed

+15
-19
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8432,16 +8432,16 @@ def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
84328432
return op.SplitToSequence(self, split_sizes, axis=dim, keepdims=False)
84338433

84348434

8435-
@torch_op("aten::unflatten.int")
8436-
def aten_unflatten(self: TReal, dim: INT64, sizes: INT64):
8435+
@torch_op("aten::unflatten.int", trace_only=True)
8436+
def aten_unflatten(self: TReal, dim: int, sizes: Sequence[INT64]):
84378437
"""unflatten(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)"""
84388438

84398439
self_size = op.Shape(self)
84408440

84418441
# PyTorch accepts negative dim as reversed counting
8442-
self_rank = op.Size(self_size)
8443-
dim = self_rank + dim
8444-
dim = dim % self_rank
8442+
self_rank = len(self.shape)
8443+
if dim < 0:
8444+
dim = self_rank + dim
84458445

84468446
head_start_idx = op.Constant(value_ints=[0])
84478447
head_end_idx = op.Reshape(dim, op.Constant(value_ints=[1]))
@@ -8451,8 +8451,16 @@ def aten_unflatten(self: TReal, dim: INT64, sizes: INT64):
84518451
tail_end_idx = op.Constant(value_ints=[_INT64_MAX])
84528452
tail_part_rank = op.Slice(self_size, tail_start_idx, tail_end_idx)
84538453

8454-
final_shape = op.Concat(head_part_rank, sizes, tail_part_rank, axis=0)
8454+
sizes = [op.Reshape(size, op.Constant(value_ints=[1])) for size in sizes]
84558455

8456+
# corner case 1: head part is None
8457+
if dim == 0:
8458+
final_shape = op.Concat(*sizes, tail_part_rank, axis=0)
8459+
# corner case 2: tail part is None
8460+
elif dim == self_rank - 1:
8461+
final_shape = op.Concat(head_part_rank, *sizes, axis=0)
8462+
else:
8463+
final_shape = op.Concat(head_part_rank, *sizes, tail_part_rank, axis=0)
84568464
return op.Reshape(self, final_shape)
84578465

84588466

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -429,13 +429,6 @@ def _sum_input_wrangler(
429429
return args, kwargs
430430

431431

432-
def _unflatten_input_wrangler(
433-
args: list[Any], kwargs: dict[str, Any]
434-
) -> tuple[list[Any], dict[str, Any]]:
435-
args[1] = np.array(args[1], dtype=np.int64)
436-
return args, kwargs
437-
438-
439432
def _where_input_wrangler(
440433
args: list[Any], kwargs: dict[str, Any]
441434
) -> tuple[list[Any], dict[str, Any]]:
@@ -1471,14 +1464,9 @@ def _where_input_wrangler(
14711464
TorchLibOpInfo(
14721465
"unflatten",
14731466
core_ops.aten_unflatten,
1474-
input_wrangler=_unflatten_input_wrangler,
1475-
)
1476-
.xfail(
1467+
).xfail(
14771468
matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
14781469
reason="fixme: Logic not implemented for size 0 inputs in op.Reshape",
1479-
)
1480-
.xfail(
1481-
reason="fixme: https://github.com/pytorch/pytorch/issues/146336",
14821470
),
14831471
TorchLibOpInfo("unfold", core_ops.aten_unfold),
14841472
TorchLibOpInfo("ops.aten.unfold", core_ops.aten_unfold),

0 commit comments

Comments (0)