diff --git a/tripy/tests/frontend/trace/ops/test_storage.py b/tripy/tests/frontend/trace/ops/test_storage.py index 90f8ba1ef..991934482 100644 --- a/tripy/tests/frontend/trace/ops/test_storage.py +++ b/tripy/tests/frontend/trace/ops/test_storage.py @@ -47,8 +47,8 @@ def test_from_list(self): def test_empty_list(self): data = [[]] - storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data, dtype=tp.float16) - assert storage.dtype == tp.float16 + storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data) + assert storage.dtype == tp.float32 assert storage.shape == (1, 0) assert storage.device.kind == "gpu" diff --git a/tripy/tripy/backend/mlir/memref.py b/tripy/tripy/backend/mlir/memref.py index e63132776..9baba2362 100644 --- a/tripy/tripy/backend/mlir/memref.py +++ b/tripy/tripy/backend/mlir/memref.py @@ -21,7 +21,6 @@ import mlir_tensorrt.runtime.api as runtime from tripy.backend.mlir import utils as mlir_utils -from tripy.common import datatype from tripy.common import device as tp_device from tripy.utils import raise_error diff --git a/tripy/tripy/frontend/tensor.py b/tripy/tripy/frontend/tensor.py index 104a3f23a..88d0b488d 100644 --- a/tripy/tripy/frontend/tensor.py +++ b/tripy/tripy/frontend/tensor.py @@ -26,7 +26,6 @@ from tripy import export, utils from tripy.backend.mlir import memref from tripy.common import datatype -from tripy.common import utils as common_utils from tripy.common.exception import raise_error, str_from_stack_info from tripy.frontend.ops.registry import TENSOR_METHOD_REGISTRY from tripy.frontend.trace.ops import Storage @@ -151,7 +150,7 @@ def raw_init( data = memref.create_memref_view(data) Storage.build_internal([], [instance.trace_tensor], data) else: - Storage.build_internal([], [instance.trace_tensor], data, None, device) + Storage.build_internal([], [instance.trace_tensor], data, device) # TODO(#155): Remove this hack: instance.trace_tensor.device = utils.default(device, instance.trace_tensor.device) diff --git a/tripy/tripy/frontend/trace/ops/storage.py b/tripy/tripy/frontend/trace/ops/storage.py index 9293ece28..9823e89d8 100644 --- a/tripy/tripy/frontend/trace/ops/storage.py +++ b/tripy/tripy/frontend/trace/ops/storage.py @@ -44,7 +44,6 @@ def __init__( inputs: List["Tensor"], outputs: List["Tensor"], data: Union[runtime.MemRefValue, Sequence[numbers.Number]], - dtype: datatype = None, device: tp_device = None, ) -> None: super().__init__(inputs, outputs) @@ -56,19 +55,18 @@ def __init__( self.device = tp_device.create_directly( "gpu" if data.address_space == runtime.PointerType.device else "cpu", 0 ) - elif common_utils.is_empty(data): - # special case: empty tensor - self.dtype = utils.default(dtype, datatype.float32) - self.shape = tuple(utils.get_shape(data)) - self.data = memref.create_memref(shape=self.shape, dtype=self.dtype) - self.device = utils.default(device, tp_device.create_directly("gpu", 0)) else: - self.dtype = dtype if dtype else common_utils.get_element_type(data) + if common_utils.is_empty(data): + self.dtype = datatype.float32 + data_array = None + else: + self.dtype = common_utils.get_element_type(data) + data_array = common_utils.convert_list_to_array(utils.flatten_list(data), dtype=self.dtype) self.shape = tuple(utils.get_shape(data)) self.data = memref.create_memref( shape=self.shape, dtype=self.dtype, - array=common_utils.convert_list_to_array(utils.flatten_list(data), dtype=self.dtype), + array=data_array, ) self.device = utils.default(device, tp_device.create_directly("gpu", 0))