
Commit 8fcbac1

Document known issues (#29)
This adds a markdown file documenting some known issues. The three included here account for the vast majority of the errors we encounter when attempting to import novel `nn.Module` instances, and each comes with a minimal reproducible example. Most of the remaining errors (at least in the tests I have sampled) stem from unimplemented ops, which can be handled either by including decompositions or by implementing the ops in upstream torch-mlir.

Note: I included a small tweak to the importer, adding the ability to convert `None` to the appropriate `!torch.none` in our `TypeSubclassMap`, because 1) the missing conversion obscures the real issue in one of these cases and 2) it is probably something we want there anyway.

---------

Co-authored-by: brucekimrokcmu <[email protected]>
1 parent 2fd61ed commit 8fcbac1

2 files changed: +63 −2 lines changed

python/shark_turbine/dynamo/importer.py (+4 −2)
```diff
@@ -7,6 +7,7 @@
 import logging
 import operator
 import re
+from types import NoneType
 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
 
 from iree.compiler.ir import (
@@ -490,7 +491,7 @@ def _import_list_argument(self, loc: Location, arg):
         operand_type = type(operand)
         if not isinstance(operand, arg_type):
             raise TypeError(
-                f"Lists with multiple types are not supported, got: {arg_type}, {operand_type}"
+                f"Heterogeneous lists are not supported: expected {arg_type}, got {operand_type}"
             )
 
         if isinstance(operand, torch.fx.Node):
@@ -588,7 +589,7 @@ def _make_constant_op(
 
 LITERAL_CONVERTER_MAP = TypeSubclassMap()
 LITERAL_CONVERTER_MAP.map(
-    type(None),
+    NoneType,
     lambda arg, gni, cc: Operation.create(
         "torch.constant.none", results=[cc.torch_none_type]
     ).result,
@@ -654,6 +655,7 @@ def _make_constant_op(
     float: "!torch.float",
     str: "!torch.str",
     bool: "!torch.bool",
+    NoneType: "!torch.none",
 }
 
 # AOT-autograd sometimes falsely emit tensor version op with scalar arguments.
```
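For context on how the `NoneType` registration takes effect, here is a minimal sketch of a subclass-aware type map in the spirit of `TypeSubclassMap` (the real class in importer.py may be implemented differently; this toy version only illustrates MRO-based lookup):

```py
from types import NoneType  # Python 3.10+

class TypeSubclassMapSketch:
    """Toy stand-in for TypeSubclassMap: resolves entries along the MRO."""

    def __init__(self):
        self._map = {}

    def map(self, ty, value):
        self._map[ty] = value

    def lookup(self, ty):
        # Walk the method resolution order so subclass instances also match.
        for base in ty.__mro__:
            if base in self._map:
                return self._map[base]
        return None

converters = TypeSubclassMapSketch()
converters.map(NoneType, lambda arg: "torch.constant.none")

# With NoneType registered, a literal None resolves to a converter instead
# of falling through and masking the underlying import error.
assert converters.lookup(type(None))(None) == "torch.constant.none"
```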

python/shark_turbine/known_issues.md (new file, +59)

# Known Issues in SHARK-Turbine

## Handling lists of optional types

```py
from torch import nn

class foomod(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x):
        return self.up(x)
```
```
# occurring in importer -> _import_list_argument
compiler_fn raised TypeError: Heterogeneous lists are not supported: expected <class 'NoneType'>, got <class 'torch.fx.node.Node'>
```
An example is attempting to import `nn.Upsample`. This module internally calls `F.interpolate`, which eventually
calls `aten.index.Tensor`, whose [second argument](https://github.com/llvm/torch-mlir/blob/50f5b658b6dc50f664d78c89c403149b064fb59b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td#L7389C46-L7389C46) is an
optional list of tensors. When indices for some dimensions are omitted in favor of `None`, we get the error above.
In reality these values should carry an `AnyTorchOptionalTensorType`, so we need a way to assign optional types when
importing lists in this scenario; a sketch of that direction follows.

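As a rough illustration of the direction a fix could take (this is not the actual `_import_list_argument` implementation; the helper name and return shape here are invented for the sketch), the homogeneity check could tolerate `None` entries and flag the list's element type as optional:

```py
from typing import Optional, Sequence, Tuple

def check_list_element_type(operands: Sequence) -> Tuple[Optional[type], bool]:
    """Return (element_type, is_optional) for a list argument.

    None entries no longer count as a second element type; instead they mark
    the list as optional, so each element could be imported with an optional
    torch type rather than tripping the TypeError above.
    """
    element_type: Optional[type] = None
    is_optional = False
    for operand in operands:
        if operand is None:
            is_optional = True
            continue
        if element_type is None:
            element_type = type(operand)
        elif not isinstance(operand, element_type):
            raise TypeError(
                f"Heterogeneous lists are not supported: "
                f"expected {element_type}, got {type(operand)}"
            )
    return element_type, is_optional
```
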
## Dealing with functional variants of Torch Ops

```py
import torch.nn.functional as F

def forward(self, x):
    return F.max_pool2d(x, 8)
```
```
# occurring in importer -> _import_list_argument
compiler_fn raised IndexError: list index out of range
```

Currently, we have trouble with functional variants of torch operations that do not define meaningful defaults for
their arguments. Two common operations where this arises are `F.avg_pool2d` and `F.max_pool2d`.
Taking `max_pool2d` as an example, the [functional version](https://pytorch.org/docs/stable/generated/torch.nn.functional.max_pool2d.html) defaults to `stride=None` (which reaches the importer as an empty list),
but the intended default is `stride=kernel_size`. The issue does not occur with the corresponding `nn.Module` wrapper `MaxPool2d`, because
it [manually sets the intended default value](https://pytorch.org/docs/stable/_modules/torch/nn/modules/pooling.html#_MaxPoolNd). The same issue is at play in `avg_pool2d`. The comparison below makes the difference concrete.

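A small eager-mode comparison (illustrative only, using standard PyTorch API rather than importer code): the module wrapper resolves `stride` to `kernel_size` before dispatch, while the functional form carries `stride=None` into the traced call.

```py
import torch
import torch.nn.functional as F
from torch import nn

x = torch.randn(1, 1, 8, 8)

# nn.MaxPool2d eagerly substitutes stride=kernel_size when stride is None...
via_module = nn.MaxPool2d(kernel_size=2)(x)
# ...while F.max_pool2d leaves stride=None for downstream code to interpret.
via_functional = F.max_pool2d(x, kernel_size=2)

assert torch.equal(via_module, via_functional)  # same math in eager mode
```
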
## Ephemeral Tensor objects from `aten.lift_fresh_copy`

```py
def forward(self, x, y):
    x[y == 1] = 2
```
```
# in importer -> import_argument
torch._dynamo.exc.BackendCompilerFailed: compiler_fn raised KeyError: (_tensor_constant0, 0)
```
This error arises from an odd case in FX graph generation: the graph module produced for our code contains a node
`_tensor_constant0 = self._tensor_constant0` with no traceable origin within the graph, so the lookup for the
appropriate `MlirValue` in the importer's `_v` table fails. This consistently occurs when the graph generates an
intermediate `aten.lift_fresh_copy`, as in the boolean-indexing example above.
The same error occurs in the `expectedFailure` test cases for `list(tensor_data)` and `tensor_data.tolist()`.

There is an existing PyTorch issue tracking this problem in the `aot-eager` backend: https://github.com/pytorch/pytorch/issues/105327.
It arises because this particular op is not handled in the PyTorch dispatch logic and is instead suppressed [here](https://github.com/pytorch/pytorch/blob/ddf36c82b83b2db3be7ce7a85d4aea3507c9d7ef/torch/_dispatch/python.py#L108).
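
One way to surface the offending node is to dump the FX graph that dynamo hands to the backend. This is a sketch under the assumption of a recent PyTorch with `torch.compile`; the backend function and names are ad hoc, written only for inspection:

```py
import torch

def assign_by_mask(x, y):
    x[y == 1] = 2
    return x

def dump_graph_backend(gm: torch.fx.GraphModule, example_inputs):
    # The printed graph should contain a get_attr node of the form
    # `_tensor_constant0 = self._tensor_constant0` feeding lift_fresh_copy.
    print(gm.graph)
    return gm.forward  # fall back to eager execution after printing

compiled = torch.compile(assign_by_mask, backend=dump_graph_backend)
compiled(torch.zeros(4), torch.tensor([0, 1, 0, 1]))
```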
