From 09d75a6a04471ddabb62eb0d31dd430fe3bf7bf8 Mon Sep 17 00:00:00 2001 From: pranavm Date: Fri, 31 Jan 2025 14:21:38 -0800 Subject: [PATCH] Refactors trace operations to be more self-contained, separates frontend/trace tensors more cleanly - Refactors Trace operation so that it reports how many outputs it generates instead of requiring the caller to know. The trace op is now also responsible for creating its own output trace tensors. Additionally, `build`/`build_internal` have been removed, meaning the trace does *not* create frontend tensors anymore. Frontend tensors no longer create trace tensors directly but instead only interface with ops and wrap their outputs as needed. - Consolidates and renames some frontend Tensor constructors to better reflect their purpose. For example, `create_directly` -> `fast_init`. - Temporarily removes the "how to add ops" guide. A new version of this will be written once we have switched to the TRT dialect, which will signficantly affect how ops are added. --- tripy/CONTRIBUTING.md | 3 - .../how-to-add-new-ops.md | 351 ------------------ tripy/notebooks/resnet50.ipynb | 4 +- tripy/nvtripy/backend/api/executable.py | 2 +- tripy/nvtripy/backend/mlir/executor.py | 2 +- tripy/nvtripy/common/device.py | 2 +- tripy/nvtripy/frontend/dimension_size.py | 9 - .../frontend/ops/binary_elementwise.py | 54 +-- tripy/nvtripy/frontend/ops/cast.py | 4 +- tripy/nvtripy/frontend/ops/concatenate.py | 3 +- tripy/nvtripy/frontend/ops/convolution.py | 5 +- tripy/nvtripy/frontend/ops/copy.py | 3 +- tripy/nvtripy/frontend/ops/dequantize.py | 2 +- tripy/nvtripy/frontend/ops/expand.py | 2 +- tripy/nvtripy/frontend/ops/fill.py | 6 +- tripy/nvtripy/frontend/ops/flip.py | 3 +- tripy/nvtripy/frontend/ops/gather.py | 3 +- tripy/nvtripy/frontend/ops/iota.py | 4 +- tripy/nvtripy/frontend/ops/matmul.py | 3 +- tripy/nvtripy/frontend/ops/pad.py | 3 +- tripy/nvtripy/frontend/ops/permute.py | 3 +- tripy/nvtripy/frontend/ops/plugin.py | 11 +- tripy/nvtripy/frontend/ops/pooling.py | 4 +- tripy/nvtripy/frontend/ops/quantize.py | 2 +- tripy/nvtripy/frontend/ops/reduce.py | 5 +- tripy/nvtripy/frontend/ops/reshape.py | 2 +- tripy/nvtripy/frontend/ops/resize.py | 6 +- tripy/nvtripy/frontend/ops/shape.py | 6 +- tripy/nvtripy/frontend/ops/slice.py | 3 +- tripy/nvtripy/frontend/ops/split.py | 5 +- tripy/nvtripy/frontend/ops/squeeze.py | 3 +- .../nvtripy/frontend/ops/unary_elementwise.py | 17 +- tripy/nvtripy/frontend/ops/utils.py | 37 +- tripy/nvtripy/frontend/ops/where.py | 3 +- tripy/nvtripy/frontend/tensor.py | 97 ++--- tripy/nvtripy/trace/ops/base.py | 90 ++--- tripy/nvtripy/trace/ops/fill.py | 2 +- tripy/nvtripy/trace/ops/iota.py | 2 +- tripy/nvtripy/trace/ops/plugin.py | 3 + tripy/nvtripy/trace/ops/split.py | 18 +- tripy/nvtripy/trace/ops/storage.py | 18 +- tripy/nvtripy/trace/tensor.py | 12 +- tripy/nvtripy/utils/function_registry.py | 2 +- tripy/nvtripy/utils/wrappers.py | 2 +- tripy/tests/backend/api/test_compile.py | 4 +- tripy/tests/frontend/ops/test_slice.py | 13 - tripy/tests/frontend/test_tensor.py | 28 +- tripy/tests/integration/test_iota.py | 2 +- tripy/tests/trace/ops/test_storage.py | 16 +- tripy/tests/wrappers/test_interface.py | 16 +- 50 files changed, 262 insertions(+), 638 deletions(-) delete mode 100644 tripy/docs/post0_developer_guides/how-to-add-new-ops.md diff --git a/tripy/CONTRIBUTING.md b/tripy/CONTRIBUTING.md index 9614e3c0b..b590421e2 100644 --- a/tripy/CONTRIBUTING.md +++ b/tripy/CONTRIBUTING.md @@ -74,9 +74,6 @@ We've written developer guides to help you understand the codebase: [architecture](https://nvidia.github.io/TensorRT-Incubator/post0_developer_guides/architecture.html) documentation. -- If you need to add a new operation, refer to - [this guide](https://nvidia.github.io/TensorRT-Incubator/post0_developer_guides/how-to-add-new-ops.html). - ### Tests diff --git a/tripy/docs/post0_developer_guides/how-to-add-new-ops.md b/tripy/docs/post0_developer_guides/how-to-add-new-ops.md deleted file mode 100644 index 6d8a8d5db..000000000 --- a/tripy/docs/post0_developer_guides/how-to-add-new-ops.md +++ /dev/null @@ -1,351 +0,0 @@ -# Adding New Operators - -*You may find it helpful to read the [architecture](project:./architecture.md) documentation* - *before you start reading this guide.* - -Adding new operators to Tripy typically involves making changes in the frontend as well -as in the `FlatIR`. In some cases, the frontend operator can be expressed in terms of existing -`FlatIR` operators, in which case you only need to make changes in the frontend. - -Let's take a look at an example of how you might add an `Iota` operator to Tripy. -So that it doesn't clash with Tripy's actual `Iota` implementation, we'll call it -`Theta` instead. - - - - - - -## Implementation - -### `FlatIR` Operator - -The `FlatIR` operator is usually the most challenging aspect of implementing operators -in Tripy. The good news is that you might not even need to do this if the low-level operators -you need already exist in the `FlatIR`. And if you do, then it'll only get easier after this! - -We'll start by adding a new file under [`nvtripy/flat_ir/ops`](source:/nvtripy/flat_ir/ops/) called -`theta.py`; see the inline comments for explanations of what's happening: - -```py -# doc: no-eval -from dataclasses import dataclass - -from mlir_tensorrt.compiler import ir -from mlir_tensorrt.compiler.dialects import stablehlo - -from nvtripy.flat_ir.ops.base import BaseFlatIROp - - -# Every `FlatIR` operator is implemented as a `dataclass` so that the base -# class can automatically implement several methods by inspecting the child -# class fields at runtime. The `repr=False` is important because the default -# `__repr__` method generated by `dataclass` will be extremely verbose and -# makes interactive debugging more difficult. -@dataclass(repr=False) -class ThetaOp(BaseFlatIROp): - dim: int - - # `to_mlir()` is the trickiest bit. As the name implies, the method is - # meant to lower the `FlatIR` operator into MLIR. To figure out which - # MLIR operators to use, refer to the 'MLIR Python API Guide' - # (linked below). - def to_mlir(self, operands): - out_type = self.outputs[0].to_mlir() - theta_dim = ir.IntegerAttr.get(type=ir.IntegerType.get_signless(64), value=self.dim) - output = stablehlo.DynamicIotaOp(result=out_type, output_shape=operands[0], iota_dimension=theta_dim) - return [output] -``` - -Links: -- [MLIR Python API Guide](project:./mlir-dialect-python-apis.md) - - -### Exposing The Operator - -One of the principles we follow when writing submodules is that other submodules should -not need to reach into the internals of a submodule to retrieve something they need. - -For example, a class which needs to import `ThetaOp` does not need to know where exactly -within the `flat_ir.ops` module the `ThetaOp` lives - it should be able to just import it -from the submodule. - -To make this possible, we need to import the `ThetaOp` into the `flat_ir.ops` submodule. -We can do so by adding the following line into -[`nvtripy/flat_ir/ops/__init__.py`](source:/nvtripy/flat_ir/ops/__init__.py): - - - -```py -# doc: no-eval -from nvtripy.flat_ir.ops.theta import ThetaOp -``` - - - - - - -```py -# doc: no-eval -import nvtripy.flat_ir.ops -nvtripy.flat_ir.ops.ThetaOp = ThetaOp -``` - - - -## `Trace` Operator And The Public API - -Now that we have a `FlatIR` operator, we can implement a `Trace` operator that will use it -along with a public API function. Let's create a new file under -[`nvtripy/trace/ops`](source:/nvtripy/trace/ops/) called `theta.py`. - -### `Trace` Operator - -First, we'll implement the `Trace` operator itself: - -```py -# doc: no-eval -from dataclasses import dataclass -from typing import Tuple - -from nvtripy import utils -from nvtripy.common import datatype, device -from nvtripy.common.exception import raise_error -from nvtripy.trace.ops.base import BaseTraceOp -import nvtripy.trace.ops.utils as op_utils - - -# Just like with `FlatIR` operators, all `Trace` operators are implemented -# as `dataclass`es. As before, we want `repr=False` here. -@dataclass(repr=False) -class Theta(BaseTraceOp): - # Notice that we do *not* need to define a constructor and can rely on - # the default implementation provided by `dataclass`. - dim: int - dtype: datatype.dtype - - # `infer_rank()` populates the rank of the output `TraceTensor`s. - # Here we use one of the predefined policies to set the output rank - # to the same as the shape (i.e. the length) of the shape operand. - infer_rank = op_utils.InferRankPolicies.same_as_shape_of_shape_input() - - # *Optional* `infer_dtypes()` populates the data types of the - # output `TraceTensor`s. The default implementation copies the input - # data types if they are all the same, so you may not need to implement - # this. - def infer_dtypes(self): - self.outputs[0].dtype = self.dtype - - # *Optional* `infer_devices()` populates the devices of the - # output `TraceTensor`s. The default implementation copies the input - # devices if they are all the same, so you may not need to implement - # this either. - def infer_devices(self): - self.outputs[0].device = device("gpu") - - # `to_flat_ir()` translates the `Trace` operator to a subgraph of - # one or more `FlatIR` operators. In our case, it's just a 1:1 - # mapping to the `ThetaOp` we created earlier. - def to_flat_ir(self, inputs, outputs): - # Note that we import the `FlatIR` operator within the function - # call - this is to avoid circular dependencies. - from nvtripy.flat_ir.ops import ThetaOp - import nvtripy.trace.ops.utils as op_utils - - # This code may look a bit confusing; for more details, look at the - # 'FlatIR section in the architecture document' (linked below). - ThetaOp.build(inputs, outputs, dim=self.dim) -``` - -Links: -- [FlatIR section in the architecture document](project:./architecture.md#lowering-to-flatir) - - -### Public API - -Next, we can define the public interface. Since our public interface maps 1:1 with the `Trace` -operator we just implemented and does not require weights, we'll add it in the same file. - -If our API required a composition of multiple `Trace` operators, then we would instead implement -it under [`frontend/ops/`](source:/nvtripy/frontend/ops). - -If it required weights (i.e. inputs that are expected to always be constant), then we would implement -it as a `nvtripy.Module` under [`frontend/module`](source:/nvtripy/frontend/module). - -```py -# doc: no-eval -from nvtripy import export -from nvtripy.utils import wrappers -from nvtripy.types import ShapeLike - -# We can use the `export.public_api()` decorator to automatically export this -# function into the top-level module. This means it will be accessible as -# `nvtripy.theta`. -# -# This decorator also controls how the API is exposed in the documentation - -# the `document_under` option determines where in the documentation hierarchy -# this API will show up. -# -# If we needed to provide any special autodoc options, we could use the -# `autodoc_options` parameter. -@export.public_api(document_under="tensor_operations") - -# We can use the `wrappers.interface` decorator to specify constraints on -# inputs and perform transformations on them, like automatically converting -# compatible arguments (e.g., `TensorLike` or `ShapeLike`s) into tensors. -# We will aim to include most constraints and transformations in this decorator -# so as to avoid layering too many decorators. -@wrappers.interface(convert_to_tensors=True) -def theta(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "nvtripy.Tensor": - # For any public facing interfaces, we have documentation requirements which - # you can read about in the 'Docs README' (linked below). The docstring - # we've implemented here adheres to all of these requirements. Non-compliant - # docstrings will, in most cases, cause test failures; however, you should - # still manually ensure you're writing high-quality docstrings. - # - # The examples in docstrings are run as part of our tests, so you should - # also add assertions to make sure things are functionally correct. In this - # case, we check that the `output` we create in the code example is what we - # expect. - """ - Fills an output tensor with consecutive values starting from zero - along the given dimension. - - Args: - shape: The desired shape. - dim: Dimension along which to perform the theta operation. - This cannot exceed the rank of the specified shape. - dtype: The desired data type. - - Returns: - A tensor of shape ``shape`` and data type ``dtype``. - - .. code-block:: python - :linenos: - - output = tp.theta([3]) - - assert np.array_equal( - cp.from_dlpack(output).get(), np.arange(0, 3, dtype=np.float32) - ) - """ - - # Next we build the trace operator. The `build()` function is also - # responsible for constructing the output frontend Tensors. All of the - # arguments that follow the inputs are forwarded directly to the - # constructor of the `Trace` operator. - return Theta.build([shape], dim, dtype) - -``` - - - - -```py -# doc: no-eval -import nvtripy -nvtripy.theta = theta -``` - - - -Links: -- [Docs README](source:/docs/README.md#docstrings) - - -### Exposing The Operator - -Similarly to the `FlatIR` operator, we need to import `Theta` into the -`trace.ops` submodule. We can do so by adding the following line into -[`nvtripy/trace/ops/__init__.py`](source:/nvtripy/trace/ops/__init__.py): - - - -```py -# doc: no-eval -from nvtripy.trace.ops.theta import Theta, theta -``` - - - - -```py -# doc: no-eval -import nvtripy.trace.ops -nvtripy.trace.ops.Theta = Theta -nvtripy.trace.ops.theta = theta -``` - - -## Testing - -Now that we've implemented our operator, let's write tests for it. The structure of the -[`tests/`](source:/tests/) directory mirrors that of the [`nvtripy/`](source:/nvtripy/) directory -(you can read more about that [here](source:/tests/README.md)). We need to test both the `FlatIR` -and `Trace` operators. - - -### Testing The Trace Operator And Public API - -Since we implemented our `Trace` operator and public API in -[`nvtripy/trace/ops`](source:/nvtripy/trace/ops/), we'll add the test under -[`tests/trace/ops`](source:/tests/trace/ops/). -Create a new file there called `test_theta.py`: - - -```py -# doc: no-eval -import nvtripy as tp -from tests import helper -from nvtripy.trace.ops import Theta - - -class TestTheta: - # This ensures that the public API function creates a frontend `Tensor` - # and populates it with the right `Trace` operator. - def test_op_func(self): - a = tp.theta([2, 3]) - assert isinstance(a, tp.Tensor) - assert isinstance(a.trace_tensor.producer, Theta) - - # You should also include negative tests for anything that is expected to - # fail. In our case, we just have `test_invalid_dim`, - # which ensures that we emit an error if the `dim` parameter is outside - # the allowed range. - def test_invalid_dim(self): - with helper.raises(tp.TripyException, match="iota dimension cannot go beyond the output rank"): - tp.theta([2, 3], dim=3).eval() -``` - - -### Integration Tests - -The code examples in the docstring of the public API serve as good sanity integration tests. -However, you should still add separate integration tests to get better coverage. - -Our docstring covers the 1D case, so let's add an integration test to cover the multidimensional case. -Create a new file called `test_theta.py` under [`tests/integration`](source:/tests/integration/): - -```py -# doc: no-eval -import numpy as np -import cupy as cp - -import nvtripy as tp - - -def test_multi_dimensional(): - output = tp.theta([2, 3], dim=1) - expected = tp.Tensor([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]], dtype=tp.float32) - - assert tp.equal(output, expected) -``` - -## Done! - -If you've reached this point, you have successfully added a new operation to -Tripy. Congratulations! - - diff --git a/tripy/notebooks/resnet50.ipynb b/tripy/notebooks/resnet50.ipynb index adbe1aa92..beba595cc 100644 --- a/tripy/notebooks/resnet50.ipynb +++ b/tripy/notebooks/resnet50.ipynb @@ -23,7 +23,7 @@ "metadata": {}, "outputs": [], "source": [ - "!python3 -m pip install nvtripy -f https://nvidia.github.io/TensorRT-Incubator/packages.html" + "%pip install nvtripy -f https://nvidia.github.io/TensorRT-Incubator/packages.html" ] }, { @@ -39,7 +39,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install \"datasets==2.18.0\" \"matplotlib>=3.9.1\" \"pillow>=9.4.0\" \"transformers==4.46.2\" \"torch>=2.3.1\"" + "%pip install \"datasets==2.18.0\" \"matplotlib>=3.9.1\" \"pillow>=9.4.0\" \"transformers==4.46.2\" \"torch>=2.3.1\"" ] }, { diff --git a/tripy/nvtripy/backend/api/executable.py b/tripy/nvtripy/backend/api/executable.py index 64d8e3671..aaf817102 100644 --- a/tripy/nvtripy/backend/api/executable.py +++ b/tripy/nvtripy/backend/api/executable.py @@ -229,7 +229,7 @@ def add(a, b): raise - output_tensors = [Tensor.create_directly(output, fetch_stack_info=False) for output in executor_outputs] + output_tensors = [Tensor.fast_init(output) for output in executor_outputs] if len(output_tensors) == 1: output_tensors = output_tensors[0] return output_tensors diff --git a/tripy/nvtripy/backend/mlir/executor.py b/tripy/nvtripy/backend/mlir/executor.py index abaf4d3ec..2a0a5e5ac 100644 --- a/tripy/nvtripy/backend/mlir/executor.py +++ b/tripy/nvtripy/backend/mlir/executor.py @@ -91,7 +91,7 @@ def _get_output_tensor_info(self, outputs_runtime_shape, output_devices): output_device = output_devices[index] if not output_device: - output_device = device.create_directly( + output_device = device.fast_init( "gpu" if memref.address_space == runtime.PointerType.device else "cpu", 0 ) diff --git a/tripy/nvtripy/common/device.py b/tripy/nvtripy/common/device.py index 418105152..22738c8b6 100644 --- a/tripy/nvtripy/common/device.py +++ b/tripy/nvtripy/common/device.py @@ -84,7 +84,7 @@ def __init__(self, device: str) -> None: # Not putting a docstring so it's not exported. Takes a device name and index directly, sets without validation. @staticmethod - def create_directly(kind: str, index: int) -> "tp.device": + def fast_init(kind: str, index: int) -> "tp.device": instance = device.__new__(device) instance.kind = kind instance.index = index diff --git a/tripy/nvtripy/frontend/dimension_size.py b/tripy/nvtripy/frontend/dimension_size.py index 92f02b1de..e857eac11 100644 --- a/tripy/nvtripy/frontend/dimension_size.py +++ b/tripy/nvtripy/frontend/dimension_size.py @@ -36,15 +36,6 @@ def __init__(self, data: int, name: Optional[str] = None) -> None: """ super().__init__(data=data, dtype=int32, name=name) - # Internal use only, leave undocumented so it's not exported. - # Creates a DimensionSize with data without checking (so None is permitted, which we do not want in the public constructor) - # and no overhead from the dispatch system. - @staticmethod - def create_directly(data: Optional[int], name: Optional[str] = None) -> "nvtripy.DimensionSize": - instance = DimensionSize.__new__(DimensionSize) - Tensor.raw_init(instance, data=data, dtype=int32, name=name) - return instance - def __int__(self) -> int: return self.tolist() diff --git a/tripy/nvtripy/frontend/ops/binary_elementwise.py b/tripy/nvtripy/frontend/ops/binary_elementwise.py index 558bb71e3..76d0aa835 100644 --- a/tripy/nvtripy/frontend/ops/binary_elementwise.py +++ b/tripy/nvtripy/frontend/ops/binary_elementwise.py @@ -17,11 +17,11 @@ from nvtripy import export -from nvtripy.types import TensorLike -from nvtripy.utils import wrappers - +from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.binary_elementwise import BinaryElementwise, Comparison +from nvtripy.types import TensorLike +from nvtripy.utils import wrappers @register_tensor_method("__add__") @@ -53,7 +53,7 @@ def __add__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([3, 5])) """ - return BinaryElementwise.build([self, other], BinaryElementwise.Kind.SUM) + return op_utils.create_op(BinaryElementwise, [self, other], BinaryElementwise.Kind.SUM) @register_tensor_method("__sub__") @@ -83,7 +83,7 @@ def __sub__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([1, 1])) """ - return BinaryElementwise.build([self, other], BinaryElementwise.Kind.SUB) + return op_utils.create_op(BinaryElementwise, [self, other], BinaryElementwise.Kind.SUB) @register_tensor_method("__rsub__") @@ -113,7 +113,7 @@ def __rsub__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([0, -1])) """ - return BinaryElementwise.build([other, self], BinaryElementwise.Kind.SUB) + return op_utils.create_op(BinaryElementwise, [other, self], BinaryElementwise.Kind.SUB) @register_tensor_method("__pow__") @@ -143,7 +143,7 @@ def __pow__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([1, 8])) """ - return BinaryElementwise.build([self, other], BinaryElementwise.Kind.POW) + return op_utils.create_op(BinaryElementwise, [self, other], BinaryElementwise.Kind.POW) @register_tensor_method("__rpow__") @@ -173,7 +173,7 @@ def __rpow__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([4.0, 8.0])) """ - return BinaryElementwise.build([other, self], BinaryElementwise.Kind.POW) + return op_utils.create_op(BinaryElementwise, [other, self], BinaryElementwise.Kind.POW) @register_tensor_method("__mul__") @@ -205,7 +205,7 @@ def __mul__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([2.0, 6.0])) """ - return BinaryElementwise.build([self, other], BinaryElementwise.Kind.MUL) + return op_utils.create_op(BinaryElementwise, [self, other], BinaryElementwise.Kind.MUL) @register_tensor_method("__truediv__") @@ -235,7 +235,7 @@ def __truediv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([2.0, 2.0])) """ - return BinaryElementwise.build([self, other], BinaryElementwise.Kind.DIV) + return op_utils.create_op(BinaryElementwise, [self, other], BinaryElementwise.Kind.DIV) @register_tensor_method("__rtruediv__") @@ -265,7 +265,7 @@ def __rtruediv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([3.0, 2.0])) """ - return BinaryElementwise.build([other, self], BinaryElementwise.Kind.DIV) + return op_utils.create_op(BinaryElementwise, [other, self], BinaryElementwise.Kind.DIV) @register_tensor_method("__floordiv__") @@ -298,9 +298,11 @@ def __floordiv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": from nvtripy.common.datatype import int32 from nvtripy.frontend.ops.cast import cast - return cast(cast(BinaryElementwise.build([self, other], BinaryElementwise.Kind.DIV), int32), self.dtype) + return cast( + cast(op_utils.create_op(BinaryElementwise, [self, other], BinaryElementwise.Kind.DIV), int32), self.dtype + ) # Use the below code when https://github.com/NVIDIA/TensorRT-Incubator/issues/208 is fixed - # return BinaryElementwise.build([self, other], BinaryElementwise.Kind.FLOOR_DIV) + # return op_utils.create_op(BinaryElementwise, [self, other], BinaryElementwise.Kind.FLOOR_DIV) @register_tensor_method("__rfloordiv__") @@ -333,9 +335,11 @@ def __rfloordiv__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor" from nvtripy.common.datatype import int32 from nvtripy.frontend.ops.cast import cast - return cast(cast(BinaryElementwise.build([other, self], BinaryElementwise.Kind.DIV), int32), self.dtype) + return cast( + cast(op_utils.create_op(BinaryElementwise, [other, self], BinaryElementwise.Kind.DIV), int32), self.dtype + ) # Use the below code when https://github.com/NVIDIA/TensorRT-Incubator/issues/208 is fixed - # return BinaryElementwise.build([other, self], BinaryElementwise.Kind.FLOOR_DIV) + # return op_utils.create_op(BinaryElementwise, [other, self], BinaryElementwise.Kind.FLOOR_DIV) @register_tensor_method("__mod__") @@ -365,7 +369,7 @@ def __mod__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([1.0, 2.0])) """ - return BinaryElementwise.build([self, other], BinaryElementwise.Kind.MOD) + return op_utils.create_op(BinaryElementwise, [self, other], BinaryElementwise.Kind.MOD) @register_tensor_method("__rmod__") @@ -394,7 +398,7 @@ def __rmod__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([2.0, 2.0])) """ - return BinaryElementwise.build([other, self], BinaryElementwise.Kind.MOD) + return op_utils.create_op(BinaryElementwise, [other, self], BinaryElementwise.Kind.MOD) @export.public_api(document_under="operations/functions") @@ -423,7 +427,7 @@ def maximum(lhs: "nvtripy.Tensor", rhs: "nvtripy.Tensor") -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([2.0, 6.0])) """ - return BinaryElementwise.build([lhs, rhs], BinaryElementwise.Kind.MAXIMUM) + return op_utils.create_op(BinaryElementwise, [lhs, rhs], BinaryElementwise.Kind.MAXIMUM) @export.public_api(document_under="operations/functions") @@ -452,7 +456,7 @@ def minimum(lhs: "nvtripy.Tensor", rhs: "nvtripy.Tensor") -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([1.0, 3.0])) """ - return BinaryElementwise.build([lhs, rhs], BinaryElementwise.Kind.MINIMUM) + return op_utils.create_op(BinaryElementwise, [lhs, rhs], BinaryElementwise.Kind.MINIMUM) @register_tensor_method("__lt__") @@ -485,7 +489,7 @@ def __lt__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert output.tolist() == [True, False] """ - return Comparison.build([self, other], Comparison.Kind.LESS) + return op_utils.create_op(Comparison, [self, other], Comparison.Kind.LESS) @register_tensor_method("__le__") @@ -518,7 +522,7 @@ def __le__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert output.tolist() == [True, False] """ - return Comparison.build([self, other], Comparison.Kind.LESS_EQUAL) + return op_utils.create_op(Comparison, [self, other], Comparison.Kind.LESS_EQUAL) @register_tensor_method("__eq__") @@ -551,7 +555,7 @@ def __eq__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert output.tolist() == [True, False] """ - return Comparison.build([self, other], Comparison.Kind.EQUAL) + return op_utils.create_op(Comparison, [self, other], Comparison.Kind.EQUAL) @register_tensor_method("__ne__") @@ -584,7 +588,7 @@ def __ne__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert output.tolist() == [True, False] """ - return Comparison.build([self, other], Comparison.Kind.NOT_EQUAL) + return op_utils.create_op(Comparison, [self, other], Comparison.Kind.NOT_EQUAL) @register_tensor_method("__ge__") @@ -617,7 +621,7 @@ def __ge__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert output.tolist() == [True, False] """ - return Comparison.build([self, other], Comparison.Kind.GREATER_EQUAL) + return op_utils.create_op(Comparison, [self, other], Comparison.Kind.GREATER_EQUAL) @register_tensor_method("__gt__") @@ -650,4 +654,4 @@ def __gt__(self: "nvtripy.Tensor", other: TensorLike) -> "nvtripy.Tensor": assert output.tolist() == [True, False] """ - return Comparison.build([self, other], Comparison.Kind.GREATER) + return op_utils.create_op(Comparison, [self, other], Comparison.Kind.GREATER) diff --git a/tripy/nvtripy/frontend/ops/cast.py b/tripy/nvtripy/frontend/ops/cast.py index 19e56ffae..8968d624c 100644 --- a/tripy/nvtripy/frontend/ops/cast.py +++ b/tripy/nvtripy/frontend/ops/cast.py @@ -85,6 +85,6 @@ def cast(input: "nvtripy.Tensor", dtype: "nvtripy.dtype") -> "nvtripy.Tensor": if op_utils.is_quantized_dtype(dtype) and dtype != int8: if input.dtype != float32: - input = Cast.build([input], float32) + input = op_utils.create_op(Cast, [input], float32) return quantize(input, 1.0, dtype) - return Cast.build([input], dtype) + return op_utils.create_op(Cast, [input], dtype) diff --git a/tripy/nvtripy/frontend/ops/concatenate.py b/tripy/nvtripy/frontend/ops/concatenate.py index eea4e3f27..6d647c367 100644 --- a/tripy/nvtripy/frontend/ops/concatenate.py +++ b/tripy/nvtripy/frontend/ops/concatenate.py @@ -19,6 +19,7 @@ from nvtripy import export from nvtripy.common.exception import raise_error +from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.concatenate import Concatenate from nvtripy.utils import wrappers @@ -57,4 +58,4 @@ def concatenate(tensors: Sequence["nvtripy.Tensor"], dim: int) -> "nvtripy.Tenso if len(tensors) == 1: return tensors[0] - return Concatenate.build(list(tensors), dim) + return op_utils.create_op(Concatenate, list(tensors), dim) diff --git a/tripy/nvtripy/frontend/ops/convolution.py b/tripy/nvtripy/frontend/ops/convolution.py index 52333551c..124f84c14 100644 --- a/tripy/nvtripy/frontend/ops/convolution.py +++ b/tripy/nvtripy/frontend/ops/convolution.py @@ -17,8 +17,9 @@ from collections.abc import Sequence -from nvtripy.utils import wrappers +from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.convolution import Convolution +from nvtripy.utils import wrappers @wrappers.interface( @@ -36,4 +37,4 @@ def convolution( lhs_dilation: Sequence[int], rhs_dilation: Sequence[int], ): - return Convolution.build([input, weight], padding, stride, groups, lhs_dilation, rhs_dilation) + return op_utils.create_op(Convolution, [input, weight], padding, stride, groups, lhs_dilation, rhs_dilation) diff --git a/tripy/nvtripy/frontend/ops/copy.py b/tripy/nvtripy/frontend/ops/copy.py index 0a3eba456..414e87779 100644 --- a/tripy/nvtripy/frontend/ops/copy.py +++ b/tripy/nvtripy/frontend/ops/copy.py @@ -16,6 +16,7 @@ # from nvtripy import export +from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.copy import Copy from nvtripy.utils import wrappers @@ -48,4 +49,4 @@ def copy(input: "nvtripy.Tensor", device: "nvtripy.device") -> "nvtripy.Tensor": assert output.trace_tensor.producer.device.kind == "cpu" """ - return Copy.build([input], device) + return op_utils.create_op(Copy, [input], device) diff --git a/tripy/nvtripy/frontend/ops/dequantize.py b/tripy/nvtripy/frontend/ops/dequantize.py index 17f4c059c..47eed4b56 100644 --- a/tripy/nvtripy/frontend/ops/dequantize.py +++ b/tripy/nvtripy/frontend/ops/dequantize.py @@ -107,4 +107,4 @@ def dequantize( op_utils.check_qdq_args(input, scale, dtype, dim, False) # See the note in quantize.py on why we don't just use frontend ops here. - return Dequantize.build([input, scale], dtype, dim) + return op_utils.create_op(Dequantize, [input, scale], dtype, dim) diff --git a/tripy/nvtripy/frontend/ops/expand.py b/tripy/nvtripy/frontend/ops/expand.py index 5d1c5332a..89bdb9edc 100644 --- a/tripy/nvtripy/frontend/ops/expand.py +++ b/tripy/nvtripy/frontend/ops/expand.py @@ -91,4 +91,4 @@ def expand(input: "nvtripy.Tensor", sizes: ShapeLike) -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.broadcast_to(cp.from_dlpack(input).get(), (3, 1, 1))) """ - return Expand.build([input, sizes]) + return op_utils.create_op(Expand, [input, sizes]) diff --git a/tripy/nvtripy/frontend/ops/fill.py b/tripy/nvtripy/frontend/ops/fill.py index eb608c075..b7bd4cb76 100644 --- a/tripy/nvtripy/frontend/ops/fill.py +++ b/tripy/nvtripy/frontend/ops/fill.py @@ -52,7 +52,7 @@ def full(shape: ShapeLike, value: TensorLike, dtype: "nvtripy.dtype" = datatype. assert np.array_equal(cp.from_dlpack(output).get(), np.full([2, 3], 2, dtype=np.float32)) """ - return Fill.build([shape, value], dtype=dtype) + return op_utils.create_op(Fill, [shape, value], dtype=dtype) @export.public_api(document_under="operations/initializers") @@ -85,6 +85,6 @@ def full_like(input: "nvtripy.Tensor", value: TensorLike, dtype: Optional["nvtri assert np.array_equal(cp.from_dlpack(output).get(), np.array([[2, 2], [2, 2]], dtype=np.float32)) """ - return Fill.build( - [op_utils.tensor_from_shape_like(input.shape), value], dtype=utils.utils.default(dtype, input.dtype) + return op_utils.create_op( + Fill, [op_utils.tensor_from_shape_like(input.shape), value], dtype=utils.utils.default(dtype, input.dtype) ) diff --git a/tripy/nvtripy/frontend/ops/flip.py b/tripy/nvtripy/frontend/ops/flip.py index 1a84a8119..5bc3e7df2 100644 --- a/tripy/nvtripy/frontend/ops/flip.py +++ b/tripy/nvtripy/frontend/ops/flip.py @@ -19,6 +19,7 @@ from nvtripy import export, utils from nvtripy.common.exception import raise_error +from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.flip import Flip from nvtripy.utils import wrappers @@ -84,4 +85,4 @@ def flip(input: "nvtripy.Tensor", dims: Optional[Union[int, Sequence[int]]] = No dims[i] = corrected_dim encountered.add(corrected_dim) - return Flip.build([input], dims=dims) + return op_utils.create_op(Flip, [input], dims=dims) diff --git a/tripy/nvtripy/frontend/ops/gather.py b/tripy/nvtripy/frontend/ops/gather.py index 379041ffc..9103d943f 100644 --- a/tripy/nvtripy/frontend/ops/gather.py +++ b/tripy/nvtripy/frontend/ops/gather.py @@ -17,6 +17,7 @@ from nvtripy import export +from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.gather import Gather from nvtripy.utils import wrappers @@ -52,4 +53,4 @@ def gather(input: "nvtripy.Tensor", dim: int, index: "nvtripy.Tensor") -> "nvtri assert np.array_equal(cp.from_dlpack(output).get(), np.take(cp.from_dlpack(data).get(), cp.from_dlpack(indices).get(), axis=1)) """ - return Gather.build([input, index], dim) + return op_utils.create_op(Gather, [input, index], dim) diff --git a/tripy/nvtripy/frontend/ops/iota.py b/tripy/nvtripy/frontend/ops/iota.py index 5a7b7552d..f3569aefd 100644 --- a/tripy/nvtripy/frontend/ops/iota.py +++ b/tripy/nvtripy/frontend/ops/iota.py @@ -31,10 +31,10 @@ def iota_impl(shape: "nvtripy.Tensor", dim: int, dtype: datatype.dtype, output_r # Allocate a float32 tensor and cast the output to dtype. # `tensorrt.linspace` op result #0 must be 0D/1D/2D/3D/4D/5D/6D/7D/8D tensor of 32-bit float or 32-bit signless integer values. if dtype not in (datatype.float32, datatype.int32, datatype.int64): - result = Iota.build([shape], dim, output_rank, datatype.float32) + result = op_utils.create_op(Iota, [shape], dim, output_rank, datatype.float32) return cast(result, dtype) - return Iota.build([shape], dim, output_rank, dtype) + return op_utils.create_op(Iota, [shape], dim, output_rank, dtype) @export.public_api(document_under="operations/initializers") diff --git a/tripy/nvtripy/frontend/ops/matmul.py b/tripy/nvtripy/frontend/ops/matmul.py index 403773935..db7f4121b 100644 --- a/tripy/nvtripy/frontend/ops/matmul.py +++ b/tripy/nvtripy/frontend/ops/matmul.py @@ -16,6 +16,7 @@ # from nvtripy.common.exception import raise_error +from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.matmul import MatrixMultiplication from nvtripy.utils import wrappers @@ -100,4 +101,4 @@ def get_batch_indices(rank): "rhs": contracting_dim[1], } - return MatrixMultiplication.build([self, other], contracting_dim, batching_dim) + return op_utils.create_op(MatrixMultiplication, [self, other], contracting_dim, batching_dim) diff --git a/tripy/nvtripy/frontend/ops/pad.py b/tripy/nvtripy/frontend/ops/pad.py index c25a1fea1..c935662cb 100644 --- a/tripy/nvtripy/frontend/ops/pad.py +++ b/tripy/nvtripy/frontend/ops/pad.py @@ -73,7 +73,8 @@ def pad( ) padding_low, padding_high = list(zip(*pad)) - return Pad.build( + return op_utils.create_op( + Pad, [ input, op_utils.tensor_from_shape_like(padding_low), diff --git a/tripy/nvtripy/frontend/ops/permute.py b/tripy/nvtripy/frontend/ops/permute.py index 2abcbb9a5..87ea2ecfa 100644 --- a/tripy/nvtripy/frontend/ops/permute.py +++ b/tripy/nvtripy/frontend/ops/permute.py @@ -19,6 +19,7 @@ from nvtripy import export from nvtripy.common.exception import raise_error +from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.permute import Permute from nvtripy.utils import wrappers @@ -66,4 +67,4 @@ def permute(input: "nvtripy.Tensor", perm: Sequence[int]) -> "nvtripy.Tensor": ], ) - return Permute.build([input], perm) + return op_utils.create_op(Permute, [input], perm) diff --git a/tripy/nvtripy/frontend/ops/plugin.py b/tripy/nvtripy/frontend/ops/plugin.py index 9d5c81d56..c140a6734 100644 --- a/tripy/nvtripy/frontend/ops/plugin.py +++ b/tripy/nvtripy/frontend/ops/plugin.py @@ -18,6 +18,7 @@ from typing import List, Sequence, Tuple, Union from nvtripy import export +from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.plugin import Plugin @@ -65,4 +66,12 @@ def plugin( assert tp.allclose(out,tp.gelu(inp)) """ - return Plugin.build(inputs, name, version, namespace, output_info, kwargs, num_outputs=len(output_info)) + return op_utils.create_op( + Plugin, + inputs, + name, + version, + namespace, + output_info, + kwargs, + ) diff --git a/tripy/nvtripy/frontend/ops/pooling.py b/tripy/nvtripy/frontend/ops/pooling.py index 0894db5b8..ce1b31903 100644 --- a/tripy/nvtripy/frontend/ops/pooling.py +++ b/tripy/nvtripy/frontend/ops/pooling.py @@ -79,7 +79,7 @@ def maxpool( stride = utils.utils.default(stride, [1] * spatial_dims) padding = utils.utils.default(padding, [(0, 0)] * spatial_dims) - return Pooling.build([input], Pooling.Kind.MAX, kernel_dims, stride, padding) + return op_utils.create_op(Pooling, [input], Pooling.Kind.MAX, kernel_dims, stride, padding) @export.public_api(document_under="operations/functions") @@ -138,4 +138,4 @@ def avgpool( stride = utils.utils.default(stride, [1] * spatial_dims) padding = utils.utils.default(padding, [(0, 0)] * spatial_dims) - return Pooling.build([input], Pooling.Kind.AVG, kernel_dims, stride, padding) + return op_utils.create_op(Pooling, [input], Pooling.Kind.AVG, kernel_dims, stride, padding) diff --git a/tripy/nvtripy/frontend/ops/quantize.py b/tripy/nvtripy/frontend/ops/quantize.py index 85670a766..1b679cc84 100644 --- a/tripy/nvtripy/frontend/ops/quantize.py +++ b/tripy/nvtripy/frontend/ops/quantize.py @@ -114,4 +114,4 @@ def quantize( # This is implemented using a special trace op instead of a combination of frontend ops # so that it shows up in the trace and can more easily be pattern matched (by defining our # own trace op, we have finer control over the generated MLIR). - return Quantize.build([input, scale], dtype, dim) + return op_utils.create_op(Quantize, [input, scale], dtype, dim) diff --git a/tripy/nvtripy/frontend/ops/reduce.py b/tripy/nvtripy/frontend/ops/reduce.py index 0e9da651f..be53baf1f 100644 --- a/tripy/nvtripy/frontend/ops/reduce.py +++ b/tripy/nvtripy/frontend/ops/reduce.py @@ -20,6 +20,7 @@ from nvtripy import export from nvtripy.common import datatype +from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.reduce import ArgMinMax, Reduce from nvtripy.utils import wrappers from nvtripy.utils.utils import make_list @@ -35,7 +36,7 @@ def _reduce_impl(input: "nvtripy.Tensor", kind: Reduce.Kind, dim: Union[int, Seq from nvtripy.frontend.ops.reshape import reshape from nvtripy.frontend.ops.unsqueeze import unsqueeze - out = Reduce.build([input], adjust_dim(dim, input.rank), kind) + out = op_utils.create_op(Reduce, [input], adjust_dim(dim, input.rank), kind) if keepdim: if dim is None: out = reshape(out, (1,) * input.rank) @@ -316,7 +317,7 @@ def _arg_min_max_impl(tensor: "nvtripy.Tensor", kind: ArgMinMax.Kind, dim: Optio if dim is None: tensor = reshape(tensor, (-1,)) indices = iota_like(tensor, dim if dim else 0, datatype.int32) - out = ArgMinMax.build([tensor, indices], adjust_dim(dim, tensor.rank), kind) + out = op_utils.create_op(ArgMinMax, [tensor, indices], adjust_dim(dim, tensor.rank), kind) if keepdim: if dim is None: out = reshape(out, (1,) * original_rank) diff --git a/tripy/nvtripy/frontend/ops/reshape.py b/tripy/nvtripy/frontend/ops/reshape.py index 802b7c683..df37bd6dd 100644 --- a/tripy/nvtripy/frontend/ops/reshape.py +++ b/tripy/nvtripy/frontend/ops/reshape.py @@ -71,4 +71,4 @@ def reshape(input: "nvtripy.Tensor", shape: ShapeLike) -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.reshape(cp.from_dlpack(input).get(), (1, 6))) """ - return Reshape.build([input, shape], None) + return op_utils.create_op(Reshape, [input, shape], None) diff --git a/tripy/nvtripy/frontend/ops/resize.py b/tripy/nvtripy/frontend/ops/resize.py index 28e36438f..23a0354b8 100644 --- a/tripy/nvtripy/frontend/ops/resize.py +++ b/tripy/nvtripy/frontend/ops/resize.py @@ -73,7 +73,7 @@ def resize( assert torch.allclose(torch.from_dlpack(output).to("cpu"), expected) """ _check_mode(mode, align_corners) - return Resize.build([input, output_shape], mode, scales=None, align_corners=align_corners) + return op_utils.create_op(Resize, [input, output_shape], mode, scales=None, align_corners=align_corners) @export.public_api(document_under="operations/functions") @@ -113,4 +113,6 @@ def resize( """ _check_mode(mode, align_corners) - return Resize.build([input, op_utils.tensor_from_shape_like(input.shape)], mode, scales, align_corners) + return op_utils.create_op( + Resize, [input, op_utils.tensor_from_shape_like(input.shape)], mode, scales, align_corners + ) diff --git a/tripy/nvtripy/frontend/ops/shape.py b/tripy/nvtripy/frontend/ops/shape.py index 55e95c352..271970931 100644 --- a/tripy/nvtripy/frontend/ops/shape.py +++ b/tripy/nvtripy/frontend/ops/shape.py @@ -17,6 +17,7 @@ from nvtripy.common.datatype import DATA_TYPES +from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.shape import GetDimensionSize from nvtripy.types import ShapeLike @@ -48,4 +49,7 @@ def shape(self: "nvtripy.Tensor") -> ShapeLike: if all(dim >= 0 for dim in self.trace_tensor.shape) and not self.trace_tensor.is_compile_tracer: return self.trace_tensor.shape - return [GetDimensionSize.build([self], dim=index, always_cast_to_dimension_size=True) for index in range(self.rank)] + return [ + op_utils.create_op(GetDimensionSize, [self], dim=index, always_cast_to_dimension_size=True) + for index in range(self.rank) + ] diff --git a/tripy/nvtripy/frontend/ops/slice.py b/tripy/nvtripy/frontend/ops/slice.py index 16e0ea1d9..da91339af 100644 --- a/tripy/nvtripy/frontend/ops/slice.py +++ b/tripy/nvtripy/frontend/ops/slice.py @@ -19,6 +19,7 @@ from nvtripy import utils from nvtripy.common.exception import raise_error +from nvtripy.frontend.ops import utils as op_utils from nvtripy.frontend.ops._registry import register_tensor_method from nvtripy.trace.ops.slice import Slice from nvtripy.types import TensorLike @@ -207,4 +208,4 @@ def find_frame_index(arg): if len(candidates) == 1: source_info.column_range = candidates[0] - return Slice.build(inputs=[tensor, *slice_params]) + return op_utils.create_op(Slice, inputs=[tensor, *slice_params]) diff --git a/tripy/nvtripy/frontend/ops/split.py b/tripy/nvtripy/frontend/ops/split.py index 4c21ceaf0..1632f400a 100644 --- a/tripy/nvtripy/frontend/ops/split.py +++ b/tripy/nvtripy/frontend/ops/split.py @@ -19,6 +19,7 @@ from nvtripy import export from nvtripy.common.exception import raise_error +from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.split import Split from nvtripy.utils import wrappers @@ -90,7 +91,6 @@ def split( if isinstance(indices_or_sections, int): if indices_or_sections <= 0: raise_error(f"Number of sections argument must be positive, but given {indices_or_sections}") - num_outputs = indices_or_sections else: if not indices_or_sections: raise_error("Split indices must not be empty") @@ -99,6 +99,5 @@ def split( if last and index < last: raise_error(f"Split indices must be given in ascending order, but given {indices_or_sections}") last = index - num_outputs = len(indices_or_sections) + 1 # add 1 because of the last split - return Split.build(inputs=[input], indices_or_sections=indices_or_sections, dim=dim, num_outputs=num_outputs) + return op_utils.create_op(Split, inputs=[input], indices_or_sections=indices_or_sections, dim=dim) diff --git a/tripy/nvtripy/frontend/ops/squeeze.py b/tripy/nvtripy/frontend/ops/squeeze.py index ae7324604..e2746d565 100644 --- a/tripy/nvtripy/frontend/ops/squeeze.py +++ b/tripy/nvtripy/frontend/ops/squeeze.py @@ -15,6 +15,7 @@ from typing import Sequence, Union from nvtripy import export, utils +from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.squeeze import Squeeze from nvtripy.utils import wrappers @@ -65,4 +66,4 @@ def squeeze(input: "nvtripy.Tensor", dims: Union[Sequence[int], int]) -> "nvtrip assert np.array_equal(cp.from_dlpack(output).get(), np.squeeze(cp.from_dlpack(input).get(), (0, 2))) """ - return Squeeze.build([input], utils.utils.make_tuple(dims)) + return op_utils.create_op(Squeeze, [input], utils.utils.make_tuple(dims)) diff --git a/tripy/nvtripy/frontend/ops/unary_elementwise.py b/tripy/nvtripy/frontend/ops/unary_elementwise.py index c7be62620..41eb630ba 100644 --- a/tripy/nvtripy/frontend/ops/unary_elementwise.py +++ b/tripy/nvtripy/frontend/ops/unary_elementwise.py @@ -17,6 +17,7 @@ from nvtripy import export +from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.unary_elementwise import UnaryElementwise from nvtripy.utils import wrappers @@ -46,7 +47,7 @@ def exp(input: "nvtripy.Tensor") -> "nvtripy.Tensor": assert tp.allclose(output, tp.Tensor(np.exp(cp.from_dlpack(input).get()))) """ - return UnaryElementwise.build([input], UnaryElementwise.Kind.EXP) + return op_utils.create_op(UnaryElementwise, [input], UnaryElementwise.Kind.EXP) @export.public_api(document_under="operations/functions") @@ -72,7 +73,7 @@ def tanh(input: "nvtripy.Tensor") -> "nvtripy.Tensor": assert tp.allclose(output, tp.Tensor(np.tanh(cp.from_dlpack(input).get()))) """ - return UnaryElementwise.build([input], UnaryElementwise.Kind.TANH) + return op_utils.create_op(UnaryElementwise, [input], UnaryElementwise.Kind.TANH) @export.public_api(document_under="operations/functions") @@ -98,7 +99,7 @@ def sin(input: "nvtripy.Tensor") -> "nvtripy.Tensor": assert tp.allclose(output, tp.Tensor(np.sin(cp.from_dlpack(input).get()))) """ - return UnaryElementwise.build([input], UnaryElementwise.Kind.SINE) + return op_utils.create_op(UnaryElementwise, [input], UnaryElementwise.Kind.SINE) @export.public_api(document_under="operations/functions") @@ -124,7 +125,7 @@ def cos(input: "nvtripy.Tensor") -> "nvtripy.Tensor": assert tp.allclose(output, tp.Tensor(np.cos(cp.from_dlpack(input).get()))) """ - return UnaryElementwise.build([input], UnaryElementwise.Kind.COSINE) + return op_utils.create_op(UnaryElementwise, [input], UnaryElementwise.Kind.COSINE) @export.public_api(document_under="operations/functions") @@ -150,7 +151,7 @@ def rsqrt(input: "nvtripy.Tensor") -> "nvtripy.Tensor": assert tp.allclose(output, tp.Tensor(1.0 / np.sqrt(cp.from_dlpack(input).get()))) """ - return UnaryElementwise.build([input], UnaryElementwise.Kind.RSQRT) + return op_utils.create_op(UnaryElementwise, [input], UnaryElementwise.Kind.RSQRT) @export.public_api(document_under="operations/functions") @@ -176,7 +177,7 @@ def sqrt(input: "nvtripy.Tensor") -> "nvtripy.Tensor": assert tp.allclose(output, tp.Tensor(np.sqrt(cp.from_dlpack(input).get()))) """ - return UnaryElementwise.build([input], UnaryElementwise.Kind.SQRT) + return op_utils.create_op(UnaryElementwise, [input], UnaryElementwise.Kind.SQRT) @export.public_api(document_under="operations/functions") @@ -202,7 +203,7 @@ def log(input: "nvtripy.Tensor") -> "nvtripy.Tensor": assert tp.allclose(output, tp.Tensor(np.log(cp.from_dlpack(input).get()))) """ - return UnaryElementwise.build([input], UnaryElementwise.Kind.LOG) + return op_utils.create_op(UnaryElementwise, [input], UnaryElementwise.Kind.LOG) @export.public_api(document_under="operations/functions") @@ -228,4 +229,4 @@ def abs(input: "nvtripy.Tensor") -> "nvtripy.Tensor": assert np.array_equal(cp.from_dlpack(output).get(), np.array([1, 2], dtype=np.float32)) """ - return UnaryElementwise.build([input], UnaryElementwise.Kind.ABS) + return op_utils.create_op(UnaryElementwise, [input], UnaryElementwise.Kind.ABS) diff --git a/tripy/nvtripy/frontend/ops/utils.py b/tripy/nvtripy/frontend/ops/utils.py index 30a3f935b..461cd3cae 100644 --- a/tripy/nvtripy/frontend/ops/utils.py +++ b/tripy/nvtripy/frontend/ops/utils.py @@ -17,9 +17,38 @@ import nvtripy.common.datatype as tp_dtype +from nvtripy.common.datatype import int32 from nvtripy.common.exception import raise_error +# Creates a Trace operation from the provided frontend tensors and wraps its outputs in frontend Tensors. +def create_op(OpType, inputs, *args, always_cast_to_dimension_size=False, **kwargs): + from nvtripy.frontend.dimension_size import DimensionSize + from nvtripy.frontend.tensor import Tensor + + # Operations that operate on only DimensionSize inputs will always yield a DimensionSize. + # For any mixed operations, DimensionSize must be casted up to Tensor. + all_inputs_are_dimension_size = all(isinstance(inp, DimensionSize) for inp in inputs) + + def should_cast_to_dimension_size(out): + return always_cast_to_dimension_size or (all_inputs_are_dimension_size and out.dtype == int32 and out.rank == 0) + + STACK_DEPTH_OF_FROM_TRACE_TENSOR = 4 # Stack depth from API function calls + op = OpType([inp.trace_tensor for inp in inputs], *args, **kwargs) + outputs = [ + ( + DimensionSize.from_trace_tensor(out, include_code_index=STACK_DEPTH_OF_FROM_TRACE_TENSOR) + if should_cast_to_dimension_size(out) + else Tensor.from_trace_tensor(out, include_code_index=STACK_DEPTH_OF_FROM_TRACE_TENSOR) + ) + for out in op.outputs + ] + + if len(outputs) == 1: + return outputs[0] + return outputs + + def is_minus_one(arg): # Avoid doing an == with a Tensor return isinstance(arg, int) and arg == -1 @@ -28,12 +57,12 @@ def is_minus_one(arg): def tensor_from_shape_like(arg: "nvtripy.ShapeLike") -> "nvtripy.Tensor": from nvtripy.common.datatype import int32 from nvtripy.frontend.dimension_size import DimensionSize - from nvtripy.frontend.tensor import Tensor from nvtripy.frontend.ops.concatenate import concatenate from nvtripy.frontend.ops.reshape import Reshape + from nvtripy.frontend.tensor import Tensor if not arg: - return Tensor.create_directly([], dtype=int32) + return Tensor([], dtype=int32) concat_tensors = [] @@ -45,7 +74,7 @@ def empty_buffer(): if not int_buffer: return - concat_tensors.append(Tensor.create_directly(int_buffer, dtype=int32)) + concat_tensors.append(Tensor(int_buffer, dtype=int32)) int_buffer.clear() for elem in arg: @@ -53,7 +82,7 @@ def empty_buffer(): empty_buffer() # NOTE: We cannot use the reshape API here since it would lead to an # infinite loop when attempting to convert the shape input to a tensor. - concat_tensors.append(Reshape.build([elem, Tensor.create_directly([1])], 1)) + concat_tensors.append(create_op(Reshape, [elem, Tensor([1])], 1)) else: int_buffer.append(elem) diff --git a/tripy/nvtripy/frontend/ops/where.py b/tripy/nvtripy/frontend/ops/where.py index e0fb5a0a6..c175049ad 100644 --- a/tripy/nvtripy/frontend/ops/where.py +++ b/tripy/nvtripy/frontend/ops/where.py @@ -18,6 +18,7 @@ import numbers from nvtripy import export +from nvtripy.frontend.ops import utils as op_utils from nvtripy.trace.ops.where import Where from nvtripy.utils import wrappers @@ -57,7 +58,7 @@ def where(condition: "nvtripy.Tensor", input: "nvtripy.Tensor", other: "nvtripy. assert np.array_equal(cp.from_dlpack(output).get(), np.array([[1, 0], [1, 1]], dtype=np.float32)) """ - return Where.build([condition, input, other]) + return op_utils.create_op(Where, [condition, input, other]) @export.public_api(document_under="operations/functions") diff --git a/tripy/nvtripy/frontend/tensor.py b/tripy/nvtripy/frontend/tensor.py index 6c2bbf4b2..47d0eabfe 100644 --- a/tripy/nvtripy/frontend/tensor.py +++ b/tripy/nvtripy/frontend/tensor.py @@ -29,12 +29,6 @@ from nvtripy.frontend.ops._registry import TENSOR_METHOD_REGISTRY from nvtripy.logging.logger import logger from nvtripy.trace.ops.storage import Storage -from nvtripy.trace.tensor import TraceTensor -from nvtripy.utils.stack_info import StackInfo - -# We include code for everything above the `BaseTraceOp.build` function, which is called at most -# this many stack frames above the constructor. -STACK_DEPTH_OF_BUILD = 5 class TensorMeta(type): @@ -64,19 +58,11 @@ class Tensor(metaclass=TensorMeta): A tensor is a multi-dimensional array that contains elements of a uniform data type. """ - _COUNT = 0 - # This field communicates to NumPy that it should allow our right-side operator overloads (e.g. __radd__) to take # precedence over its own left-side overloads (e.g. __add__). This will ensure that an expression of the form # ` Tensor` will return a Tensor and not a NumPy array. __array_priority__ = 10000 - @classmethod - def _get_unique_name(cls): - name = f"t{cls._COUNT}" - cls._COUNT += 1 - return name - def __init__( self, data: Any, @@ -102,59 +88,42 @@ def __init__( """ # We use None internally but users should not be permitted to do it assert data is not None, "Data argument to Tensor must not be None" - Tensor.raw_init(self, data, dtype, device, name, fetch_stack_info) + self._stack_info = utils.stack_info.StackInfo([]) - # Left undocumented because this should only be used internally. - # Produces a new instance of a Tensor but avoids calling into the function registry, unlike the normal constructor. - @staticmethod - def create_directly( - data: Any, - dtype: Optional["nvtripy.dtype"] = None, - device: Optional["nvtripy.device"] = None, - name: Optional[str] = None, - fetch_stack_info: bool = True, - ): - instance = Tensor.__new__(Tensor) - Tensor.raw_init(instance, data, dtype, device, name, fetch_stack_info) - return instance - - # No docstring because this should be used only internally. Handles the logic for initializing a new instance. - # We separate this from __init__ because __init__ calls into the registry and rejects None values, which we use internally. - @staticmethod - def raw_init( - instance: Any, - data: Any, - dtype: Optional["nvtripy.dtype"] = None, - device: Optional["nvtripy.device"] = None, - name: Optional[str] = None, - fetch_stack_info: bool = True, - ): - stack_info = StackInfo([]) + storage = Storage(data, device=device if not hasattr(data, "__dlpack__") else None) + self.trace_tensor = storage.outputs[0] + self.trace_tensor.name = utils.utils.default(name, self.trace_tensor.name) if fetch_stack_info: - stack_info = utils.stack_info.get_stack_info(include_code_index=STACK_DEPTH_OF_BUILD) - - name = name if name is not None else Tensor._get_unique_name() - - instance.trace_tensor = TraceTensor(name, stack_info, dtype=None, device=device, producer=None, shape=None) - - # Note: It is important that we are able to call the Tensor constructor with no arguments - # since this is used internally. - if data is None: - return - - Storage.build_internal( - [], [instance.trace_tensor], data, device=device if not hasattr(data, "__dlpack__") else None - ) + # TODO (pranavm): Figure out the right stack depth + self.stack_info = utils.stack_info.get_stack_info(include_code_index=1) # TODO(#155): Remove this hack: - instance.trace_tensor.device = utils.utils.default(device, instance.trace_tensor.device) + self.trace_tensor.device = utils.utils.default(device, self.trace_tensor.device) # Explicit cast if necessary # TODO(#155): Add copy as well when host allocation is fixed - if dtype is not None and dtype != instance.trace_tensor.dtype: + if dtype is not None and dtype != self.trace_tensor.dtype: from nvtripy.frontend.ops.cast import cast - instance.trace_tensor = cast(instance, dtype=dtype).trace_tensor + self.trace_tensor = cast(self, dtype=dtype).trace_tensor + + # Left undocumented because these should only be used internally. + @classmethod + def from_trace_tensor(cls, trace_tensor, include_code_index=2): + instance = cls.__new__(cls) + instance.trace_tensor = trace_tensor + # TODO (pranavm): Figure out what stack depth to use here? + instance.stack_info = utils.stack_info.get_stack_info(include_code_index=include_code_index) + return instance + + # Faster constructor that bypasses things like function registry type checks and fetching stack info. + @staticmethod + def fast_init(data: Any): + instance = Tensor.__new__(Tensor) + storage = Storage(data) + instance.trace_tensor = storage.outputs[0] + instance.stack_info = utils.stack_info.StackInfo([]) + return instance def __getattr__(self, name: str): import nvtripy as tp @@ -173,10 +142,11 @@ def name(self, new_name): @property def stack_info(self): - return self.trace_tensor.stack_info + return self._stack_info @stack_info.setter def stack_info(self, new_stack_info): + self._stack_info = new_stack_info self.trace_tensor.stack_info = new_stack_info @property @@ -231,7 +201,14 @@ def eval(self) -> runtime.MemRefValue: assert len(data) == 1, "Expects only one output from mlir_tensorrt.compiler executor" data = data[0] - Storage.build_internal([], [self.trace_tensor], data) + storage = Storage(data) + # Need to carry forward `is_compile_tracer`: + storage.outputs[0].is_compile_tracer = self.trace_tensor.is_compile_tracer + + # Rebind this tensor, but be sure to preserve stack information: + self.trace_tensor = storage.outputs[0] + self.trace_tensor.stack_info = self.stack_info + # TODO(#155): Remove this hack of overriding the device type. self.trace_tensor.device = output_devices[0] diff --git a/tripy/nvtripy/trace/ops/base.py b/tripy/nvtripy/trace/ops/base.py index 219253c80..20f974060 100644 --- a/tripy/nvtripy/trace/ops/base.py +++ b/tripy/nvtripy/trace/ops/base.py @@ -16,10 +16,21 @@ # import abc -from dataclasses import dataclass -from typing import List, Optional, Set, Union +from dataclasses import dataclass, field +from typing import List, Set from nvtripy import utils +from nvtripy.trace.tensor import TraceTensor + +_COUNT = 0 + + +def _get_unique_name(): + global _COUNT + + name = f"t{_COUNT}" + _COUNT += 1 + return name @dataclass(repr=False) @@ -34,72 +45,25 @@ class BaseTraceOp(abc.ABC): inputs: List["TraceTensor"] """The input tensors of this operation""" - outputs: List["TraceTensor"] + outputs: List["TraceTensor"] = field(init=False) """The output tensors of this operation""" - @classmethod - def build_internal( - cls, inputs: List["TraceTensor"], outputs: List["TraceTensor"], *args, **kwargs - ) -> "BaseTraceOp": - """ - Builds a Trace operation and binds it to the provided input and output trace tensors. - - *args and **kwargs are passed along to the trace operation's constructor. - """ - op = cls(inputs, outputs, *args, **kwargs) - - is_compile_tracer = any(inp.is_compile_tracer for inp in inputs) - for out in op.outputs: - out.producer = op - out.is_compile_tracer |= is_compile_tracer - - op.infer_dtypes() - op.infer_rank() - op.infer_devices() - return op - - @classmethod - def build( - cls, inputs: List["Tensor"], *args, num_outputs=1, always_cast_to_dimension_size=False, **kwargs - ) -> Union["Tensor", List["Tensor"]]: - """ - Builds a trace operation and binds its inputs to the trace tensors corresponding to the - frontend tensors provided in `inputs` and creates `num_outputs` new frontend tensors for the - outputs, whose trace tensors are bound to the outputs of the trace operation. + def __post_init__(self): + is_compile_tracer = any(inp.is_compile_tracer for inp in self.inputs) + self.outputs = [ + TraceTensor(_get_unique_name(), producer=self, is_compile_tracer=is_compile_tracer) + for _ in range(self.get_num_outputs()) + ] - *args and **kwargs are passed along to the trace operation's constructor. + self.infer_dtypes() + self.infer_rank() + self.infer_devices() - `num_outputs=1` is treated as a special case that will return the output tensor directly instead - of returning a list of output tensors. + def get_num_outputs(self) -> int: """ - - from nvtripy.common.datatype import int32 - from nvtripy.frontend.dimension_size import DimensionSize - from nvtripy.frontend.tensor import Tensor - - # NOTE: If you change the stack depth where the tensors are constructed, update STACK_DEPTH_OF_BUILD in - # the Tensor constructor! - outputs = [Tensor.create_directly(None) for _ in range(num_outputs)] - - inp_trace_tensors = [inp.trace_tensor for inp in inputs] - out_trace_tensors = [out.trace_tensor for out in outputs] - cls.build_internal(inp_trace_tensors, out_trace_tensors, *args, **kwargs) - - # Operations that operate on only DimensionSize inputs will always yield a DimensionSize. - # For any mixed operations, DimensionSize must be casted up to Tensor. - all_inputs_are_dimension_size = all(isinstance(inp, DimensionSize) for inp in inputs) - for index, out in enumerate(outputs): - if always_cast_to_dimension_size or ( - all_inputs_are_dimension_size and out.dtype == int32 and out.rank == 0 - ): - dim_size = DimensionSize.create_directly(None) - dim_size.trace_tensor = out.trace_tensor - dim_size.stack_info = out.stack_info - outputs[index] = dim_size - - if num_outputs == 1: - return outputs[0] - return outputs + The number of output produced by this trace operation. + """ + return 1 @abc.abstractmethod def infer_rank(self): diff --git a/tripy/nvtripy/trace/ops/fill.py b/tripy/nvtripy/trace/ops/fill.py index dbdf9a063..8d775cba5 100644 --- a/tripy/nvtripy/trace/ops/fill.py +++ b/tripy/nvtripy/trace/ops/fill.py @@ -34,7 +34,7 @@ def infer_dtypes(self): def infer_devices(self): from nvtripy.common import device - self.outputs[0].device = device.create_directly("gpu", 0) + self.outputs[0].device = device.fast_init("gpu", 0) def to_flat_ir(self, inputs, outputs): from nvtripy.flat_ir.ops import ConvertOp, DynamicBroadcastOp diff --git a/tripy/nvtripy/trace/ops/iota.py b/tripy/nvtripy/trace/ops/iota.py index f7fadcba3..2e500de6b 100644 --- a/tripy/nvtripy/trace/ops/iota.py +++ b/tripy/nvtripy/trace/ops/iota.py @@ -36,7 +36,7 @@ def infer_dtypes(self): def infer_devices(self): from nvtripy.common import device - self.outputs[0].device = device.create_directly("gpu", 0) + self.outputs[0].device = device.fast_init("gpu", 0) def to_flat_ir(self, inputs, outputs): from nvtripy.flat_ir.ops import DynamicIotaOp diff --git a/tripy/nvtripy/trace/ops/plugin.py b/tripy/nvtripy/trace/ops/plugin.py index 673bd529e..ab1494b17 100644 --- a/tripy/nvtripy/trace/ops/plugin.py +++ b/tripy/nvtripy/trace/ops/plugin.py @@ -29,6 +29,9 @@ class Plugin(BaseTraceOp): output_info: List[Tuple[int, "nvtripy.dtype"]] creator_params: Dict[str, Any] + def get_num_outputs(self): + return len(self.output_info) + def infer_dtypes(self): for out, (_, dtype) in zip(self.outputs, self.output_info): out.dtype = dtype diff --git a/tripy/nvtripy/trace/ops/split.py b/tripy/nvtripy/trace/ops/split.py index aea5763a9..0e0cb4628 100644 --- a/tripy/nvtripy/trace/ops/split.py +++ b/tripy/nvtripy/trace/ops/split.py @@ -28,7 +28,7 @@ class Split(BaseTraceOp): indices_or_sections: Union[int, Sequence[int]] dim: int - def num_outputs(self): + def get_num_outputs(self): if isinstance(self.indices_or_sections, int): return self.indices_or_sections else: @@ -36,16 +36,16 @@ def num_outputs(self): return len(self.indices_or_sections) + 1 def infer_rank(self): - for i in range(self.num_outputs()): - self.outputs[i].rank = self.inputs[0].rank + for out in self.outputs: + out.rank = self.inputs[0].rank def infer_devices(self): - for i in range(self.num_outputs()): - self.outputs[i].device = self.inputs[0].device + for out in self.outputs: + out.device = self.inputs[0].device def infer_dtypes(self): - for i in range(self.num_outputs()): - self.outputs[i].dtype = self.inputs[0].dtype + for out in self.outputs: + out.dtype = self.inputs[0].dtype # gets input_tensor[..., :, :, start_idx: end_idx, :, :, ...], with the start and end slice only at the axis dimension def build_slice_of_target_dim(self, input_tensor, input_shape, device, start_idx, end_idx, output_tensor): @@ -108,7 +108,7 @@ def to_flat_ir(self, inputs, outputs): [axis_dim, op_utils.add_constant_tensor_from_list([self.indices_or_sections], device=device)], [section_size_tensor], ) - for i in range(self.num_outputs()): + for i in range(self.get_num_outputs()): with FlatIRTensor.context([f"compute indices of split {i}"]): # i*section_size section_i_start_tensor = FlatIRTensor.build( @@ -168,5 +168,5 @@ def __str__(self) -> str: if field.name not in skip_fields ] - outputs_string = ", ".join([self.outputs[i].name for i in range(self.num_outputs())]) + outputs_string = ", ".join([self.outputs[i].name for i in range(self.get_num_outputs())]) return f"{outputs_string} = {self.__class__.__name__.lower()}({', '.join([inp.name for inp in self.inputs] + args)})" diff --git a/tripy/nvtripy/trace/ops/storage.py b/tripy/nvtripy/trace/ops/storage.py index d15b8b519..5f08be923 100644 --- a/tripy/nvtripy/trace/ops/storage.py +++ b/tripy/nvtripy/trace/ops/storage.py @@ -16,7 +16,7 @@ # from dataclasses import dataclass -from typing import List, Sequence, Set, Any +from typing import Optional, Sequence, Set, Any import mlir_tensorrt.runtime.api as runtime @@ -40,13 +40,9 @@ class Storage(BaseTraceOp): def __init__( self, - inputs: List["Tensor"], - outputs: List["Tensor"], data: Any, - device: tp_device = None, + device: Optional[tp_device] = None, ) -> None: - super().__init__(inputs, outputs) - original_data = data # Handle if data is dlpacked but not memref yet @@ -57,9 +53,7 @@ def __init__( self.data = data self.dtype = mlir_utils.convert_runtime_dtype_to_tripy_dtype(self.data.dtype) self.shape = tuple(data.shape) - self.device = tp_device.create_directly( - "gpu" if data.address_space == runtime.PointerType.device else "cpu", 0 - ) + self.device = tp_device.fast_init("gpu" if data.address_space == runtime.PointerType.device else "cpu", 0) else: if common_utils.is_empty(data): self.dtype = datatype.float32 @@ -73,12 +67,14 @@ def __init__( dtype=self.dtype, array=data_array, ) - self.device = utils.utils.default(device, tp_device.create_directly("gpu", 0)) + self.device = utils.utils.default(device, tp_device.fast_init("gpu", 0)) # Set data_str only for objects that won't be treated as Trace inputs if not utils.utils.should_lift_storage_op_as_input(self.shape): self.data_str = str(original_data) # TODO (#448): Fix floating point str representation + # Parent constructor will run rank/type inference, so we need to run it after setting the fields above. + super().__init__([]) self.outputs[0].shape = list(self.shape) def str_skip_fields(self) -> Set[str]: @@ -98,7 +94,7 @@ def infer_dtypes(self): def infer_devices(self): # TODO(#155): Fix allocation on host - self.outputs[0].device = tp_device.create_directly("gpu", 0) + self.outputs[0].device = tp_device.fast_init("gpu", 0) def to_flat_ir(self, inputs, outputs): from nvtripy.flat_ir.ops import ConstantOp diff --git a/tripy/nvtripy/trace/tensor.py b/tripy/nvtripy/trace/tensor.py index 57ecb5d1f..ef4d8a9be 100644 --- a/tripy/nvtripy/trace/tensor.py +++ b/tripy/nvtripy/trace/tensor.py @@ -15,7 +15,7 @@ # limitations under the License. # -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import List, Optional from nvtripy.utils.stack_info import StackInfo @@ -28,11 +28,11 @@ class TraceTensor: """ name: str - stack_info: StackInfo - dtype: "nvtripy.common.dtype" - device: "nvtripy.common.device" producer: "BaseTraceOp" - shape: List[int] + dtype: "nvtripy.common.dtype" = field(default=None, init=False) + device: "nvtripy.common.device" = field(default=None, init=False) + shape: List[int] = field(default=None, init=False) + stack_info: StackInfo = field(default_factory=lambda: StackInfo([]), init=False) """ Indicates the shape of the tensor. Unknown dimensions are indicated by -1. Generally, the shape will only be known for shape tensors. @@ -42,7 +42,7 @@ class TraceTensor: is_compile_tracer: bool = False # Stack information for the point at which this tensor was evaluated if it was. # This is useful in the compiler to disallow evaluation during tracing. - eval_stack_info: Optional[StackInfo] = None + eval_stack_info: Optional[StackInfo] = field(default=None, init=False) def __str__(self) -> str: return ( diff --git a/tripy/nvtripy/utils/function_registry.py b/tripy/nvtripy/utils/function_registry.py index 0f124edd7..345c232c4 100644 --- a/tripy/nvtripy/utils/function_registry.py +++ b/tripy/nvtripy/utils/function_registry.py @@ -351,7 +351,7 @@ def impl(func): self[key] = func # For classes, we apply the wrapper to all methods. elif inspect.isclass(func): - # Ignore properties and functions not defined in the class (we will use the presence of a docstring as a proxy for that). + # Ignore non-public properties and functions and those not defined in the class (we will use the presence of a docstring as a proxy for that). # It does not suffice to check just that the method is inherited because some decorators like @dataclass add methods # that are not documented or annotated and do not use inheritance to do so. for name, member in inspect.getmembers( diff --git a/tripy/nvtripy/utils/wrappers.py b/tripy/nvtripy/utils/wrappers.py index 4ce4e903e..fdb070651 100644 --- a/tripy/nvtripy/utils/wrappers.py +++ b/tripy/nvtripy/utils/wrappers.py @@ -184,7 +184,7 @@ def add_arg(arg): # Python integers can always be cast to the most restrictive type, which is DimensionSize in Tripy. # DimensionSize can always be cast up to Tensor if needed, but the reverse is not true. # NOTE: We do not use isinstance here because bool is a subclass of int. - arg = DimensionSize.create_directly(arg) if type(arg) is int else Tensor.create_directly(arg) + arg = DimensionSize(arg) if type(arg) is int else Tensor(arg) _add_column_info( arg, diff --git a/tripy/tests/backend/api/test_compile.py b/tripy/tests/backend/api/test_compile.py index 1cf889353..9f87ccd8e 100644 --- a/tripy/tests/backend/api/test_compile.py +++ b/tripy/tests/backend/api/test_compile.py @@ -15,13 +15,11 @@ # limitations under the License. # import cupy as cp +import nvtripy as tp import pytest from tests import helper from tests.backend.api.conftest import * -import nvtripy as tp -from nvtripy.trace.ops.storage import Storage - class TestCompile: # TODO (#246): Verify that it's actually compiling somehow here and below. diff --git a/tripy/tests/frontend/ops/test_slice.py b/tripy/tests/frontend/ops/test_slice.py index 8e7910913..0fb157e0f 100644 --- a/tripy/tests/frontend/ops/test_slice.py +++ b/tripy/tests/frontend/ops/test_slice.py @@ -25,7 +25,6 @@ class TestSlice: def test_slice_of_inline_output(self): a = tp.Tensor([1, 2, 3, 4]) # The start and stop params use clamp bound, but the step parameter doesn't. - # The result is that the stack traces for the slice params are of different lengths. s = (a + a)[3:4:] assert isinstance(s, tp.Tensor) assert isinstance(s.trace_tensor.producer, Slice) @@ -38,18 +37,6 @@ def test_slice_of_inline_output(self): assert any(frame.function == "clamp_bound" for frame in slice_inputs[1].stack_info) assert not any(frame.function == "clamp_bound" for frame in slice_inputs[2].stack_info) - # Consequently, the frame corresponding to the caller is at different depths. - def index_of_caller(trace_input): - for i, frame in enumerate(trace_input.stack_info): - if frame.function == TestSlice.test_slice_of_inline_output.__name__: - return i - return -1 - - caller_idxs = [index_of_caller(inp) for inp in slice_inputs] - assert all(idx != -1 for idx in caller_idxs) - assert caller_idxs[0] == caller_idxs[1] - assert caller_idxs[2] != caller_idxs[1] - def test_incorrect_index_size(self): with helper.raises( tp.TripyException, diff --git a/tripy/tests/frontend/test_tensor.py b/tripy/tests/frontend/test_tensor.py index 821deae27..459c999bc 100644 --- a/tripy/tests/frontend/test_tensor.py +++ b/tripy/tests/frontend/test_tensor.py @@ -108,14 +108,14 @@ def test_dtype_printing(self, dtype): # In this test we only check the two innermost stack frames since beyond that it's all pytest code. @pytest.mark.parametrize( - "build_func,expected_line_number", + "build_func,expected_line_number,expected_func", [ - (lambda: tp.Tensor([1, 1, 1]), sys._getframe().f_lineno), - (lambda: tp.ones((3,)), sys._getframe().f_lineno), + (lambda: tp.Tensor([1, 1, 1]), sys._getframe().f_lineno, tp.Tensor.__init__), + (lambda: tp.ones((3,)), sys._getframe().f_lineno, tp.Tensor.from_trace_tensor), ], ids=["constructor", "op"], ) - def test_stack_info_is_populated(self, build_func, expected_line_number): + def test_stack_info_is_populated(self, build_func, expected_line_number, expected_func): a = build_func() a.stack_info.fetch_source_code() @@ -124,7 +124,7 @@ def test_stack_info_is_populated(self, build_func, expected_line_number): file=inspect.getsourcefile(tp.Tensor), # We don't check line number within tp.Tensor because it's difficult to determine. line=a.stack_info[0].line, - function=tp.Tensor.raw_init.__name__, + function=expected_func.__name__, code=None, _dispatch_target="", column_range=(25, 30) if sys.version_info >= (3, 11) else None, @@ -167,7 +167,6 @@ def test_dlpack_torch(self, kind): assert torch.equal(a_torch, torch.from_dlpack(tp.Tensor(a_torch))) def test_stack_depth_sanity(self): - # Makes sure STACK_DEPTH_OF_BUILD is correct a = tp.ones((2, 3)) a.stack_info.fetch_source_code() @@ -277,7 +276,6 @@ def test_tolist(self, tensor, expected): @pytest.mark.parametrize( "tensor", [ - tp.Tensor([1, 2, 3]), tp.ones((2, 2)), tp.Tensor([1, 2, 3]) + tp.Tensor([4, 5, 6]), # This case should trigger datatype conversions. @@ -291,14 +289,14 @@ def test_tolist(self, tensor, expected): (tp.Tensor([[1], [2], [3]]) + tp.Tensor([[4], [5], [6]]))[0], ], ) - def test_stack_depth_of_build(self, tensor): + def test_stack_depth_of_create_op(self, tensor): tensor.stack_info.fetch_source_code() - # Ensure that we do not include code for any frame until after the caller of `Tensor.build` - build_caller = len(tensor.stack_info) + # Ensure that we do not include code for any frame until after the caller of `op_utils.create_op` + create_op_caller = len(tensor.stack_info) for index, source_info in enumerate(tensor.stack_info): - if source_info.function == "build": - build_caller = index + 1 + if source_info.function == "create_op": + create_op_caller = index + 1 break for index, source_info in enumerate(tensor.stack_info): @@ -307,9 +305,9 @@ def test_stack_depth_of_build(self, tensor): assert source_info.code is not None break - # We should include code starting one frame past the *caller* of `build`, i.e. we - # should not see a call to `build` in the code stack trace we display. - if index > build_caller: + # We should include code starting one frame past the *caller* of `create_op`, i.e. we + # should not see a call to `create_op` in the code stack trace we display. + if index > create_op_caller: assert source_info.code is not None else: assert source_info.code is None diff --git a/tripy/tests/integration/test_iota.py b/tripy/tests/integration/test_iota.py index b59a7ecc2..5044f347e 100644 --- a/tripy/tests/integration/test_iota.py +++ b/tripy/tests/integration/test_iota.py @@ -84,7 +84,7 @@ def test_negative_no_casting(self, dtype): # TODO: update the 'match' error msg when MLIR-TRT fixes dtype constraint a = tp.ones((2, 2)) - out = Iota.build([op_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype) + out = op_utils.create_op(Iota, [op_utils.tensor_from_shape_like(a.shape)], dim=0, output_rank=2, dtype=dtype) exception_str = "InternalError: failed to run compilation" with helper.raises( diff --git a/tripy/tests/trace/ops/test_storage.py b/tripy/tests/trace/ops/test_storage.py index c15cd6bcc..f0e496a40 100644 --- a/tripy/tests/trace/ops/test_storage.py +++ b/tripy/tests/trace/ops/test_storage.py @@ -31,7 +31,7 @@ class TestStorage: def test_from_small_memref(self, device): module = np if device == "cpu" else cp data = memref.create_memref_view(module.ones((2, 2), dtype=module.float32)) - storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data) + storage = Storage(data) assert storage.dtype == tp.float32 assert storage.shape == (2, 2) assert storage.device.kind == device @@ -39,7 +39,7 @@ def test_from_small_memref(self, device): def test_from_large_memref(self): data = memref.create_memref_view(cp.ones((2, STORAGE_OP_CACHE_VOLUME_THRESHOLD), dtype=cp.float32)) - storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data) + storage = Storage(data) assert storage.dtype == tp.float32 assert storage.shape == (2, STORAGE_OP_CACHE_VOLUME_THRESHOLD) assert storage.device.kind == "gpu" @@ -51,7 +51,7 @@ def test_from_dlpack_int(self, dtype): tripy_dtype = tp.int64 if dtype == "int64" else tp.int32 data = cp.ones((2, 2), dtype=cp_dtype) - storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data) + storage = Storage(data) assert storage.dtype == tripy_dtype assert storage.shape == (2, 2) assert storage.device.kind == "gpu" @@ -63,7 +63,7 @@ def test_from_dlpack_float(self, dtype): tripy_dtype = tp.float16 if dtype == "float16" else tp.float32 data = cp.ones((2, 2), dtype=cp_dtype) - storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data) + storage = Storage(data) assert storage.dtype == tripy_dtype assert storage.shape == (2, 2) assert storage.device.kind == "gpu" @@ -72,7 +72,7 @@ def test_from_dlpack_float(self, dtype): def test_from_large_input_shape(self): shape = (1, STORAGE_OP_CACHE_VOLUME_THRESHOLD + 10) data = cp.ones(shape, dtype=cp.float32) - storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data) + storage = Storage(data) assert storage.dtype == tp.float32 assert storage.shape == shape assert storage.device.kind == "gpu" @@ -80,7 +80,7 @@ def test_from_large_input_shape(self): def test_from_list_int(self): data = [[1, 2], [3, 4]] - storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data) + storage = Storage(data) assert storage.dtype == tp.int32 assert storage.shape == (2, 2) assert storage.device.kind == "gpu" @@ -88,7 +88,7 @@ def test_from_list_int(self): def test_from_list_float(self): data = [[1.0, 2.0], [3.0, 4.0]] - storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data) + storage = Storage(data) assert storage.dtype == tp.float32 assert storage.shape == (2, 2) assert storage.device.kind == "gpu" @@ -96,7 +96,7 @@ def test_from_list_float(self): def test_empty_list(self): data = [[]] - storage = Storage([], [TraceTensor("test", None, None, None, None, None)], data) + storage = Storage(data) assert storage.dtype == tp.float32 assert storage.shape == (1, 0) assert storage.device.kind == "gpu" diff --git a/tripy/tests/wrappers/test_interface.py b/tripy/tests/wrappers/test_interface.py index 40456ad52..ebd07f99e 100755 --- a/tripy/tests/wrappers/test_interface.py +++ b/tripy/tests/wrappers/test_interface.py @@ -173,6 +173,7 @@ def cast_to_bool(arg0, arg1): "__rmul__": (lambda self, other: cast_to_bool(self, other) * other), "__rtruediv__": (lambda self, other: self / other), "shape": (lambda self: self.shape), + "__getitem__": (lambda self, index: self[index]), } if func_name in SPECIAL_FUNCS: @@ -253,6 +254,9 @@ def test_raises_on_mismatched_sequence_dtypes(self): sequence_func([tp.ones((2, 2), dtype=tp.float32), tp.ones((2, 2), dtype=tp.int32)]) +STACK_DEPTH_OF_CALLER = 5 + + class TestTensorConversion: def test_no_effect_on_non_tensor_likes(self): @wrappers.interface(convert_to_tensors=True) @@ -273,7 +277,7 @@ def func(a: tp.types.TensorLike): a = func(1.0) assert isinstance(a, tp.Tensor) - assert a.stack_info[4].column_range == (17, 20) + assert a.stack_info[STACK_DEPTH_OF_CALLER].column_range == (17, 20) def test_converts_to_dimension_size(self): # The decorator should convert to DimensionSizes when possible. @@ -317,7 +321,7 @@ def func(a: tp.types.TensorLike): a = func(a=1.0) assert isinstance(a, tp.Tensor) - assert a.stack_info[4].column_range == (17, 22) + assert a.stack_info[STACK_DEPTH_OF_CALLER].column_range == (17, 22) def test_multiple_args(self): @wrappers.interface(convert_to_tensors=True) @@ -327,10 +331,10 @@ def func(a: tp.types.TensorLike, b: tp.types.TensorLike): a, b = func(1.0, 2.0) assert isinstance(a, tp.Tensor) - assert a.stack_info[4].column_range == (20, 23) + assert a.stack_info[STACK_DEPTH_OF_CALLER].column_range == (20, 23) assert isinstance(b, tp.Tensor) - assert b.stack_info[4].column_range == (25, 28) + assert b.stack_info[STACK_DEPTH_OF_CALLER].column_range == (25, 28) def test_args_out_of_order(self): @wrappers.interface(convert_to_tensors=True) @@ -340,11 +344,11 @@ def func(a: tp.types.TensorLike, b: tp.types.TensorLike): a, b = func(b=1.0, a=2.0) assert isinstance(a, tp.Tensor) - assert a.stack_info[4].column_range == (27, 32) + assert a.stack_info[STACK_DEPTH_OF_CALLER].column_range == (27, 32) assert a.tolist() == 2.0 assert isinstance(b, tp.Tensor) - assert b.stack_info[4].column_range == (20, 25) + assert b.stack_info[STACK_DEPTH_OF_CALLER].column_range == (20, 25) assert b.tolist() == 1.0 def test_cast_dtype(self):