Skip to content

Commit

Permalink
Add overlooked overload information on torchlib functions (#919)
Browse files Browse the repository at this point in the history
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #920
* __->__ #919

Resolves most reported missing overloads from
#865
  • Loading branch information
BowenBao authored Jul 25, 2023
1 parent 241260f commit f24797f
Showing 1 changed file with 25 additions and 25 deletions.
50 changes: 25 additions & 25 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def aten_acosh(self: TFloat) -> TFloat:
return op.Acosh(self)


@torch_op("aten::add")
@torch_op(("aten::add", "aten::add.Tensor"))
def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
# TODO(microsoft/onnxruntime#15977): Improve fp16 precision
Expand Down Expand Up @@ -1235,7 +1235,7 @@ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
return op.SplitToSequence(self, list_split, axis=dim)


@torch_op("aten::clamp", trace_only=True)
@torch_op(("aten::clamp", "aten::clamp.Tensor"), trace_only=True)
def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None) -> TReal:
"""clamp(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor"""
clamped = self
Expand Down Expand Up @@ -2184,7 +2184,7 @@ def aten_dist(self: TensorType, other: TensorType, p: float = 2.0) -> TensorType
raise NotImplementedError()


@torch_op("aten::div")
@torch_op(("aten::div", "aten::div.Tensor"))
def aten_div(self: TFloat, other: TFloat) -> TFloat:
"""div.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -2299,7 +2299,7 @@ def aten_embedding_sparse_backward(
raise NotImplementedError()


@torch_op("aten::empty")
@torch_op(("aten::empty", "aten::empty.memory_format"))
def aten_empty(size: IntType, dtype: int = FLOAT.dtype) -> TTensor: # type: ignore[type-var]
# empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

Expand Down Expand Up @@ -2353,7 +2353,7 @@ def aten_empty_strided(
return op.Expand(zero, size)


@torch_op("aten::eq")
@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar"))
def aten_eq(self: TTensor, other: TTensor) -> BOOL:
"""eq.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -2563,7 +2563,7 @@ def aten_feature_dropout(input: TensorType, p: float, train: bool) -> TensorType
raise NotImplementedError()


@torch_op("aten::fill")
@torch_op(("aten::fill", "aten::fill.Tensor"))
def aten_fill(self: TTensor, value: TTensor) -> TTensor:
"""fill.Tensor(Tensor self, Tensor value) -> Tensor"""

Expand Down Expand Up @@ -2748,7 +2748,7 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::ge")
@torch_op(("aten::ge", "aten::ge.Tensor", "aten::ge.Scalar"))
def aten_ge(self: TReal, other: TReal) -> BOOL:
"""ge.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -2905,7 +2905,7 @@ def aten_gru_cell(
raise NotImplementedError()


@torch_op("aten::gt")
@torch_op(("aten::gt", "aten::gt.Scalar"))
def aten_gt(self: TReal, other: TReal) -> BOOL:
"""gt.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -3595,7 +3595,7 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::le")
@torch_op(("aten::le", "aten::le.Tensor"))
def aten_le(self: TReal, other: TReal) -> BOOL:
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -3884,7 +3884,7 @@ def aten_lstm_mps_backward(
raise NotImplementedError()


@torch_op("aten::lt")
@torch_op(("aten::lt", "aten::lt.Scalar"))
def aten_lt(self: TReal, other: TReal) -> BOOL:
"""lt.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -3957,7 +3957,7 @@ def aten_margin_ranking_loss(
raise NotImplementedError()


@torch_op("aten::masked_fill")
@torch_op(("aten::masked_fill", "aten::masked_fill.Scalar", "aten::masked_fill.Tensor"))
def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
"""masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor"""
# NOTE: Do not attempt to cast `mask` to BOOL because mask should not take any other types.
Expand Down Expand Up @@ -4462,15 +4462,15 @@ def aten_msort(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::mul")
@torch_op(("aten::mul", "aten::mul.Tensor"))
def aten_mul(self: TReal, other: TReal) -> TReal:
"""mul.Tensor(Tensor self, Tensor other) -> Tensor"""
# FIXME(titaiwang): get rid of this when we have type_promotion
other = op.CastLike(other, self)
return op.Mul(self, other)


@torch_op("aten::mul")
@torch_op(("aten::mul", "aten::mul.Tensor"))
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
"""ONNX Mul doesn't support Boolean, so use And as an equivalent operator."""

Expand Down Expand Up @@ -4883,7 +4883,7 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType:
raise NotImplementedError()


@torch_op("aten::ne")
@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor"))
def aten_ne(self: TReal, other: TReal) -> BOOL:
"""ne.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -5223,7 +5223,7 @@ def aten_positive(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::pow")
@torch_op(("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar"))
def aten_pow(self: TReal, exponent: TTensor) -> TReal:
"""pow(Tensor self, Tensor exponent) -> Tensor"""

Expand Down Expand Up @@ -5756,7 +5756,7 @@ def aten_rsqrt(self: TFloatOrBFloat16) -> TFloatOrBFloat16:
return op.Reciprocal(op.Sqrt(self))


@torch_op("aten::rsub")
@torch_op(("aten::rsub", "aten::rsub.Scalar"))
def aten_rsub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
# FIXME(titaiwang): get rid of this when we have type_promotion
Expand Down Expand Up @@ -5785,7 +5785,7 @@ def aten_scatter_add(
return op.ScatterElements(self, index, src, axis=dim, reduction="add")


@torch_op("aten::scatter_reduce", trace_only=True)
@torch_op(("aten::scatter_reduce", "aten::scatter_reduce.two"), trace_only=True)
def aten_scatter_reduce(
self: TReal,
dim: int, # we have to use int here because ScatterElements() will use this attribute
Expand Down Expand Up @@ -5855,7 +5855,7 @@ def aten_segment_reduce(
raise NotImplementedError()


@torch_op("aten::select")
@torch_op(("aten::select", "aten::select.int"))
def aten_select(self: TTensor, dim: int, index: int) -> TTensor:
"""select(Tensor self, int dim, int index) -> Tensor"""

Expand Down Expand Up @@ -5935,7 +5935,7 @@ def aten_sinh(self: TFloat) -> TFloat:
return op.Sinh(self)


@torch_op("aten::slice", trace_only=True)
@torch_op(("aten::slice", "aten::slice.Tensor"), trace_only=True)
def aten_slice(
self: TTensor,
dim: int = 0,
Expand Down Expand Up @@ -6081,7 +6081,7 @@ def aten_sparse_mask(self: TensorType, mask: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::split")
@torch_op(("aten::split", "aten::split.Tensor"))
def aten_split(self: TTensor, split_size: INT64, dim: int = 0) -> TTensor:
"""split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]"""

Expand Down Expand Up @@ -6309,7 +6309,7 @@ def aten_stft(
return result


@torch_op("aten::sub")
@torch_op(("aten::sub", "aten::sub.Tensor"))
def aten_sub(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
"""sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
alpha = op.CastLike(alpha, other)
Expand All @@ -6324,7 +6324,7 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1.0) -> Te
raise NotImplementedError()


@torch_op("aten::sum", trace_only=True)
@torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True)
def aten_sum_dim_IntList(
self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1
) -> TReal:
Expand Down Expand Up @@ -6634,7 +6634,7 @@ def aten_trace_backward(grad: TensorType, sizes: INT64) -> TensorType:
raise NotImplementedError()


@torch_op("aten::transpose", trace_only=True)
@torch_op(("aten::transpose", "aten::transpose.int"), trace_only=True)
def aten_transpose(self, dim0: int, dim1: int):
"""transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"""

Expand Down Expand Up @@ -6729,7 +6729,7 @@ def aten_type_as(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::unbind")
@torch_op(("aten::unbind", "aten::unbind.int"))
def aten_unbind(self: TTensor, dim: int = 0) -> Sequence[TTensor]:
"""unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]"""

Expand Down Expand Up @@ -7082,7 +7082,7 @@ def aten_vstack(tensors: Sequence[TTensor]) -> TTensor:
return op.ConcatFromSequence(tensors, axis=0)


@torch_op("aten::where")
@torch_op(("aten::where", "aten::where.self"))
def aten_where(condition: BOOL, self: TTensor, other: TTensor) -> TTensor:
"""where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"""

Expand Down

0 comments on commit f24797f

Please sign in to comment.