Always construct memref value in storage op
yizhuoz004 committed Dec 11, 2024
1 parent 85d1d7d commit 1f4d8ac
Showing 6 changed files with 23 additions and 19 deletions.
3 changes: 0 additions & 3 deletions tripy/tests/frontend/trace/ops/test_storage.py
@@ -34,23 +34,20 @@ def test_from_memref(self, device):
module = np if device == "cpu" else cp
data = memref.create_memref_view(module.ones((2, 2), dtype=module.float32))
storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data)
- assert storage.has_memref is True
assert storage.dtype == tp.float32
assert storage.shape == (2, 2)
assert storage.device.kind == device

def test_from_list(self):
data = [[1.0, 2.0], [3.0, 4.0]]
storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data)
- assert storage.has_memref is False
assert storage.dtype == tp.float32
assert storage.shape == (2, 2)
assert storage.device.kind == "gpu"

def test_empty_list(self):
data = [[]]
storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data, dtype=tp.float16)
- assert storage.has_memref is True
assert storage.dtype == tp.float16
assert storage.shape == (1, 0)
assert storage.device.kind == "gpu"
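
The shape assertions above ((2, 2) for the nested list and (1, 0) for [[]]) rely on the shape being inferred from the nesting of the input list. The real helper is tripy's get_shape utility, whose implementation is not shown in this diff, so the following is only an illustrative sketch of how such inference can work:

from collections.abc import Sequence
from typing import Any, List


def infer_shape(data: Any) -> List[int]:
    # Descend into the first element of each nesting level:
    # [[1.0, 2.0], [3.0, 4.0]] -> [2, 2], and [[]] -> [1, 0].
    shape = []
    while isinstance(data, Sequence) and not isinstance(data, (str, bytes)):
        shape.append(len(data))
        if not data:
            break
        data = data[0]
    return shape


assert infer_shape([[1.0, 2.0], [3.0, 4.0]]) == [2, 2]
assert infer_shape([[]]) == [1, 0]
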
4 changes: 2 additions & 2 deletions tripy/tests/frontend/trace/test_trace.py
@@ -95,8 +95,8 @@ def test_str(self):
str(trace)
== dedent(
"""
- a = storage(data=[0], shape=(1,), dtype=int32, device=gpu:0)
- b = storage(data=[1], shape=(1,), dtype=int32, device=gpu:0)
+ a = storage(shape=(1,), dtype=int32, device=gpu:0)
+ b = storage(shape=(1,), dtype=int32, device=gpu:0)
c = a + b
outputs:
c: [shape=([-1]), dtype=(int32), loc=(gpu:0)]
1 change: 1 addition & 0 deletions tripy/tripy/backend/mlir/memref.py
@@ -21,6 +21,7 @@
import mlir_tensorrt.runtime.api as runtime

from tripy.backend.mlir import utils as mlir_utils
+ from tripy.common import datatype
from tripy.common import device as tp_device
from tripy.utils import raise_error

10 changes: 9 additions & 1 deletion tripy/tripy/common/utils.py
@@ -16,7 +16,7 @@
#

import array
- from typing import Any, List, Sequence
+ from typing import Any, List, Sequence, Tuple

import tripy.common.datatype
from tripy.common.exception import raise_error
@@ -68,6 +68,14 @@ def convert_list_to_array(values: List[Any], dtype: str) -> bytes:

return array.array(TYPE_TO_FORMAT[dtype], values)

+ def get_array_supported_types() -> Tuple["tripy.common.datatype"]:
+ return (
+ tripy.common.datatype.bool,
+ tripy.common.datatype.int32,
+ tripy.common.datatype.int64,
+ tripy.common.datatype.float32,
+ )


def is_empty(data: Sequence) -> bool:
return isinstance(data, Sequence) and all(map(is_empty, data))
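
The new get_array_supported_types helper lists the dtypes that convert_list_to_array can pack via the standard-library array module. A rough sketch of that packing step follows; the format codes in the TYPE_TO_FORMAT table below are assumptions (the real table lives elsewhere in tripy/common/utils.py and is not shown in this diff), but the mechanism is the same:

import array
from typing import Any, List

# Assumed format codes for the dtypes returned by get_array_supported_types();
# the actual TYPE_TO_FORMAT mapping in tripy may differ.
TYPE_TO_FORMAT = {"bool": "b", "int32": "i", "int64": "q", "float32": "f"}


def convert_list_to_array(values: List[Any], dtype: str) -> array.array:
    # Pack a flat Python list into a contiguous typed buffer.
    return array.array(TYPE_TO_FORMAT[dtype], values)


buf = convert_list_to_array([1.0, 2.0, 3.0, 4.0], dtype="float32")
print(len(buf), buf.itemsize)  # 4 elements, 4 bytes each
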
5 changes: 3 additions & 2 deletions tripy/tripy/frontend/tensor.py
@@ -26,6 +26,7 @@
from tripy import export, utils
from tripy.backend.mlir import memref
from tripy.common import datatype
+ from tripy.common import utils as common_utils
from tripy.common.exception import raise_error, str_from_stack_info
from tripy.frontend.ops.registry import TENSOR_METHOD_REGISTRY
from tripy.frontend.trace.ops import Storage
@@ -150,7 +151,7 @@ def raw_init(
data = memref.create_memref_view(data)
Storage.build_internal([], [instance.trace_tensor], data)
else:
- Storage.build_internal([], [instance.trace_tensor], data, dtype, device)
+ Storage.build_internal([], [instance.trace_tensor], data, None, device)
# TODO(#155): Remove this hack:
instance.trace_tensor.device = utils.default(device, instance.trace_tensor.device)

@@ -201,7 +202,7 @@ def device(self):
return self.trace_tensor.device

def eval(self) -> runtime.MemRefValue:
- if isinstance(self.trace_tensor.producer, Storage) and self.trace_tensor.producer.has_memref:
+ if isinstance(self.trace_tensor.producer, Storage):
# Exit early if the tensor has already been evaluated.
# This happens before the imports below so we don't incur extra overhead.
return self.trace_tensor.producer.data
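
The eval change above works because, after this commit, a Storage producer always carries an already-materialized memref, so the separate has_memref guard is redundant. A minimal sketch of the early-exit pattern, using stand-in classes rather than the real tripy Tensor/Storage types:

from dataclasses import dataclass
from typing import Any


@dataclass
class StubStorage:
    data: Any  # stands in for a runtime.MemRefValue


@dataclass
class StubTraceTensor:
    producer: Any


def eval_tensor(trace_tensor: StubTraceTensor) -> Any:
    # A Storage producer means the value is already materialized:
    # return its memref directly instead of compiling and executing a trace.
    if isinstance(trace_tensor.producer, StubStorage):
        return trace_tensor.producer.data
    raise NotImplementedError("non-Storage producers would be compiled and executed here")


print(eval_tensor(StubTraceTensor(producer=StubStorage(data="memref<2x2xf32>"))))
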
19 changes: 8 additions & 11 deletions tripy/tripy/frontend/trace/ops/storage.py
@@ -34,7 +34,7 @@
@dataclass(repr=False)
class Storage(BaseTraceOp):

- data: Union[runtime.MemRefValue, Sequence[numbers.Number]]
+ data: runtime.MemRefValue
shape: Sequence[int]
dtype: type
device: tp_device
@@ -56,30 +56,27 @@ def __init__(
self.device = tp_device.create_directly(
"gpu" if data.address_space == runtime.PointerType.device else "cpu", 0
)
- self.has_memref = True
elif common_utils.is_empty(data):
# special case: empty tensor
self.dtype = utils.default(dtype, datatype.float32)
self.shape = tuple(utils.get_shape(data))
self.data = memref.create_memref(shape=self.shape, dtype=self.dtype)
self.device = utils.default(device, tp_device.create_directly("gpu", 0))
- self.has_memref = True
else:
- # If the input was a sequence, we need to copy it so that we don't take changes made
- # to the list after the Storage op was constructed.
- self.data = copy.copy(data)
self.dtype = dtype if dtype else common_utils.get_element_type(data)
self.shape = tuple(utils.get_shape(data))
+ self.data = memref.create_memref(
+ shape=self.shape,
+ dtype=self.dtype,
+ array=common_utils.convert_list_to_array(utils.flatten_list(data), dtype=self.dtype),
+ )
self.device = utils.default(device, tp_device.create_directly("gpu", 0))
- self.has_memref = False

self.outputs[0].shape = list(self.shape)

def str_skip_fields(self) -> Set[str]:
- # skip data if i) it is a MemRefValue or ii) its volume exceeds threshold
- if not isinstance(self.data, Sequence) or utils.should_omit_constant_in_str(self.shape):
- return {"data"}
- return set()
+ # skip data since it is always a memref value
+ return {"data"}

def __eq__(self, other) -> bool:
return self.data == other.data if isinstance(other, Storage) else False
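
Taken together, every constructor path in Storage.__init__ now ends with a memref in self.data: memref views are used as-is, empty lists become empty memrefs, and non-empty lists are flattened, packed into a typed buffer, and wrapped in a memref. A self-contained sketch of that branching, with a stand-in create_memref instead of tripy's MLIR-TensorRT-backed one (names and types below are illustrative, not the real tripy API):

import array
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class StubMemref:
    # Stand-in for runtime.MemRefValue.
    shape: Tuple[int, ...]
    dtype: str
    buffer: Optional[array.array] = None


def create_memref(shape, dtype, data=None) -> StubMemref:
    return StubMemref(shape=tuple(shape), dtype=dtype, buffer=data)


def is_empty(data: Any) -> bool:
    return isinstance(data, Sequence) and all(map(is_empty, data))


def get_shape(data: Any) -> List[int]:
    # Same idea as the shape-inference sketch earlier in this page.
    shape = []
    while isinstance(data, Sequence):
        shape.append(len(data))
        if not data:
            break
        data = data[0]
    return shape


def flatten(data: Sequence) -> List[Any]:
    out = []
    for item in data:
        if isinstance(item, Sequence):
            out.extend(flatten(item))
        else:
            out.append(item)
    return out


def build_storage_data(data: Any, dtype: Optional[str] = None) -> StubMemref:
    if isinstance(data, StubMemref):
        return data  # already a memref (view): use as-is
    if is_empty(data):
        # Empty tensor: allocate an empty memref with a default or explicit dtype.
        return create_memref(get_shape(data), dtype or "float32")
    # Non-empty list: flatten, pack into a typed buffer, and wrap in a memref.
    # (The real code selects the array format code from the dtype.)
    packed = array.array("f", flatten(data))
    return create_memref(get_shape(data), dtype or "float32", data=packed)


print(build_storage_data([[1.0, 2.0], [3.0, 4.0]]))  # StubMemref(shape=(2, 2), dtype='float32', ...)
print(build_storage_data([[]], dtype="float16"))     # StubMemref(shape=(1, 0), dtype='float16', buffer=None)
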
