diff --git a/tileops/manifest/elementwise_binary.yaml b/tileops/manifest/elementwise_binary.yaml
index 186b6300..7e46b9a8 100644
--- a/tileops/manifest/elementwise_binary.yaml
+++ b/tileops/manifest/elementwise_binary.yaml
@@ -19,13 +19,13 @@ PreluFwdOp:
shape_rules:
# PyTorch prelu: weight is either scalar or per-channel along dim 1
# for inputs with ndim >= 2; 1-D inputs accept scalar weight only.
- - "weight.ndim == 0 or (weight.ndim == 1 and (weight.shape[0] == 1 or (input.ndim >= 2 and weight.shape[0] == input.shape[1])))"
- - "output.shape == input.shape"
+ - "weight.ndim == 0 or (weight.ndim == 1 and (weight.shape[0] == 1 or (input.ndim >= 2 and weight.shape[0] == input.shape[1])))"
+ - "output.shape == input.shape"
workloads:
# PReLU CNN feature map (per-channel weight)
- - {input_shape: [16, 256, 56, 56], weight_shape: [256], dtypes: [float16, bfloat16], label: "cnn-feat-per-channel"}
- - {input_shape: [16, 512, 28, 28], weight_shape: [512], dtypes: [float16, bfloat16], label: "cnn-feat-per-channel-deep"}
+ - {input_shape: [16, 256, 56, 56], weight_shape: [256], dtypes: [float16, bfloat16], label: "cnn-feat-per-channel"}
+ - {input_shape: [16, 512, 28, 28], weight_shape: [512], dtypes: [float16, bfloat16], label: "cnn-feat-per-channel-deep"}
roofline:
vars:
@@ -38,7 +38,7 @@ PreluFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/prelu.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -72,12 +72,12 @@ MaskedFillFwdOp:
shape_rules:
# Out-of-place masked_fill returns the bidirectional broadcast of
# input and mask; value is 0-dim.
- - "value.shape == ()"
- - "output.shape == broadcast_shapes(input.shape, mask.shape)"
+ - "value.shape == ()"
+ - "output.shape == broadcast_shapes(input.shape, mask.shape)"
workloads:
- - {input_shape: [4096, 4096], mask_shape: [4096, 4096], value_shape: [], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {input_shape: [16384, 16384], mask_shape: [16384, 16384], value_shape: [], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {input_shape: [4096, 4096], mask_shape: [4096, 4096], value_shape: [], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {input_shape: [16384, 16384], mask_shape: [16384, 16384], value_shape: [], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
func: "tileops.perf.formulas.masked_fill_fwd_roofline"
@@ -86,7 +86,7 @@ MaskedFillFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
masked_fill_tensor_value: MaskedFillTensorValueFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/masked_fill.py
test: tests/ops/test_special_elementwise.py
bench: benchmarks/ops/bench_independent_elementwise.py
bench_manifest_driven: false
@@ -112,11 +112,11 @@ MaskedFillScalarFwdOp:
# Out-of-place masked_fill returns the bidirectional broadcast of
# input and mask; out shape follows that broadcast (verified against
# ``torch.Tensor.masked_fill`` — input may also be expanded up).
- - "output.shape == broadcast_shapes(input.shape, mask.shape)"
+ - "output.shape == broadcast_shapes(input.shape, mask.shape)"
workloads:
- - {input_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {input_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {input_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {input_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
# Func mode: shared with MaskedFillFwdOp (Tensor-value primary). See
@@ -128,7 +128,7 @@ MaskedFillScalarFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
masked_fill: MaskedFillFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/masked_fill.py
test: tests/ops/test_special_elementwise.py
bench: benchmarks/ops/bench_independent_elementwise.py
bench_manifest_driven: false
@@ -148,7 +148,7 @@ AddFwdOp:
alpha: {type: "int | float", default: 1}
shape_rules:
# Output follows PyTorch broadcasting; numel uses the broadcast shape.
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -158,7 +158,7 @@ AddFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/arithmetic.py
test: tests/ops/test_binary_arith.py
bench: benchmarks/ops/bench_binary_arith.py
bench_manifest_driven: false
@@ -177,7 +177,7 @@ SubFwdOp:
params:
alpha: {type: "int | float", default: 1}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -187,7 +187,7 @@ SubFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/arithmetic.py
test: tests/ops/test_binary_arith.py
bench: benchmarks/ops/bench_binary_arith.py
bench_manifest_driven: false
@@ -204,7 +204,7 @@ MulFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -214,7 +214,7 @@ MulFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/arithmetic.py
test: tests/ops/test_binary_arith.py
bench: benchmarks/ops/bench_binary_arith.py
bench_manifest_driven: false
@@ -233,8 +233,8 @@ DivFwdOp:
params:
rounding_mode: {type: "str | None", default: null}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
- - "rounding_mode is None or rounding_mode in ('trunc', 'floor')"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "rounding_mode is None or rounding_mode in ('trunc', 'floor')"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -244,7 +244,7 @@ DivFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/arithmetic.py
test: tests/ops/test_binary_arith.py
bench: benchmarks/ops/bench_binary_arith.py
bench_manifest_driven: false
@@ -261,7 +261,7 @@ RemainderFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -271,7 +271,7 @@ RemainderFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/arithmetic.py
test: tests/ops/test_binary_arith.py
bench: benchmarks/ops/bench_binary_arith.py
bench_manifest_driven: false
@@ -288,7 +288,7 @@ PowFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, exponent.shape)"
+ - "output.shape == broadcast_shapes(input.shape, exponent.shape)"
workloads:
- {input_shape: [2048, 4096], exponent_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -298,7 +298,7 @@ PowFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/arithmetic.py
test: tests/ops/test_binary_arith.py
bench: benchmarks/ops/bench_binary_arith.py
bench_manifest_driven: false
@@ -315,7 +315,7 @@ FloorDivideFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -325,7 +325,7 @@ FloorDivideFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/arithmetic.py
test: tests/ops/test_binary_arith.py
bench: benchmarks/ops/bench_binary_arith.py
bench_manifest_driven: false
@@ -347,7 +347,7 @@ LerpFwdOp:
params:
weight: {type: float, default: 0.5}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, end.shape)"
+ - "output.shape == broadcast_shapes(input.shape, end.shape)"
workloads:
- {input_shape: [2048, 4096], end_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -357,7 +357,7 @@ LerpFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/arithmetic.py
test: tests/ops/test_binary_arith.py
bench: benchmarks/ops/bench_binary_arith.py
bench_manifest_driven: false
@@ -374,7 +374,7 @@ MaximumFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -384,7 +384,7 @@ MaximumFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/arithmetic.py
test: tests/ops/test_binary_arith.py
bench: benchmarks/ops/bench_binary_arith.py
bench_manifest_driven: false
@@ -401,7 +401,7 @@ MinimumFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -411,7 +411,7 @@ MinimumFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/arithmetic.py
test: tests/ops/test_binary_arith.py
bench: benchmarks/ops/bench_binary_arith.py
bench_manifest_driven: false
@@ -432,7 +432,7 @@ EqFwdOp:
outputs:
output: {dtype: "bool"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -442,7 +442,7 @@ EqFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/comparison.py
test: tests/ops/test_comparison.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false
@@ -459,7 +459,7 @@ NeFwdOp:
outputs:
output: {dtype: "bool"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -469,7 +469,7 @@ NeFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/comparison.py
test: tests/ops/test_comparison.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false
@@ -486,7 +486,7 @@ GtFwdOp:
outputs:
output: {dtype: "bool"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -496,7 +496,7 @@ GtFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/comparison.py
test: tests/ops/test_comparison.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false
@@ -513,7 +513,7 @@ LtFwdOp:
outputs:
output: {dtype: "bool"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -523,7 +523,7 @@ LtFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/comparison.py
test: tests/ops/test_comparison.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false
@@ -540,7 +540,7 @@ GeFwdOp:
outputs:
output: {dtype: "bool"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -550,7 +550,7 @@ GeFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/comparison.py
test: tests/ops/test_comparison.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false
@@ -567,7 +567,7 @@ LeFwdOp:
outputs:
output: {dtype: "bool"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [float16, bfloat16, float32], label: hidden-state-prefill}
@@ -577,7 +577,7 @@ LeFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/comparison.py
test: tests/ops/test_comparison.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false
@@ -598,7 +598,7 @@ LogicalAndFwdOp:
outputs:
output: {dtype: "bool"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [bool, float16, bfloat16, float32], label: hidden-state-prefill}
@@ -608,7 +608,7 @@ LogicalAndFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/logical.py
test: tests/ops/test_logical.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false
@@ -625,7 +625,7 @@ LogicalOrFwdOp:
outputs:
output: {dtype: "bool"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [bool, float16, bfloat16, float32], label: hidden-state-prefill}
@@ -635,7 +635,7 @@ LogicalOrFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/logical.py
test: tests/ops/test_logical.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false
@@ -656,7 +656,7 @@ BitwiseAndFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [bool, int32, int64], label: hidden-state-prefill}
@@ -666,7 +666,7 @@ BitwiseAndFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/bitwise.py
test: tests/ops/test_bitwise.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false
@@ -683,7 +683,7 @@ BitwiseOrFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [bool, int32, int64], label: hidden-state-prefill}
@@ -693,7 +693,7 @@ BitwiseOrFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/bitwise.py
test: tests/ops/test_bitwise.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false
@@ -710,7 +710,7 @@ BitwiseXorFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(input.shape, other.shape)"
workloads:
- {input_shape: [2048, 4096], other_shape: [2048, 4096], dtypes: [bool, int32, int64], label: hidden-state-prefill}
@@ -720,7 +720,7 @@ BitwiseXorFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/bitwise.py
test: tests/ops/test_bitwise.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false
diff --git a/tileops/manifest/elementwise_multi_input.yaml b/tileops/manifest/elementwise_multi_input.yaml
index a3600f71..c097fa81 100644
--- a/tileops/manifest/elementwise_multi_input.yaml
+++ b/tileops/manifest/elementwise_multi_input.yaml
@@ -20,11 +20,11 @@ WhereFwdOp:
shape_rules:
# PyTorch's torch.where broadcasts condition/input/other together;
# the output shape is the broadcast of all three.
- - "output.shape == broadcast_shapes(condition.shape, input.shape, other.shape)"
+ - "output.shape == broadcast_shapes(condition.shape, input.shape, other.shape)"
workloads:
- - {input_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {input_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {input_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {input_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
# Mixed-dtype op (bool condition + float input/other) — inline mode cannot
@@ -37,7 +37,7 @@ WhereFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
where: WhereFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/where.py
test: tests/ops/test_special_elementwise.py
bench: benchmarks/ops/bench_independent_elementwise.py
bench_manifest_driven: false
@@ -61,11 +61,11 @@ LerpTensorFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, end.shape, weight.shape)"
+ - "output.shape == broadcast_shapes(input.shape, end.shape, weight.shape)"
workloads:
- - {input_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {input_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {input_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {input_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
func: "tileops.perf.formulas.lerp_tensor_fwd_roofline"
@@ -74,7 +74,7 @@ LerpTensorFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
lerp_tensor: LerpTensorFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/arithmetic.py
test: tests/ops/test_binary_arith.py
bench: benchmarks/ops/bench_binary_arith.py
bench_manifest_driven: false
diff --git a/tileops/manifest/elementwise_unary_activation.yaml b/tileops/manifest/elementwise_unary_activation.yaml
index 6ab45460..ece8f304 100644
--- a/tileops/manifest/elementwise_unary_activation.yaml
+++ b/tileops/manifest/elementwise_unary_activation.yaml
@@ -20,13 +20,13 @@ ReluFwdOp:
params:
inplace: {type: bool, default: false}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
# Hidden-state activation (Llama-3.1-8B prefill)
- - {input_shape: [2048, 4096], dtypes: [float16, bfloat16], label: "hidden-state-prefill"}
+ - {input_shape: [2048, 4096], dtypes: [float16, bfloat16], label: "hidden-state-prefill"}
# Decode (single token)
- - {input_shape: [1, 4096], dtypes: [bfloat16], label: "hidden-state-decode"}
+ - {input_shape: [1, 4096], dtypes: [bfloat16], label: "hidden-state-decode"}
roofline:
vars:
@@ -38,7 +38,7 @@ ReluFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/activations.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -56,13 +56,13 @@ GeluFwdOp:
params:
approximate: {type: str, default: "none"}
shape_rules:
- - "approximate in ('none', 'tanh')"
- - "output.shape == input.shape"
+ - "approximate in ('none', 'tanh')"
+ - "output.shape == input.shape"
workloads:
# Llama-3.1-8B FFN intermediate (hidden_dim=14336)
- - {input_shape: [2048, 14336], dtypes: [float16, bfloat16], label: "llama-3.1-8b-ffn-prefill"}
- - {input_shape: [1, 14336], dtypes: [bfloat16], label: "llama-3.1-8b-ffn-decode"}
+ - {input_shape: [2048, 14336], dtypes: [float16, bfloat16], label: "llama-3.1-8b-ffn-prefill"}
+ - {input_shape: [1, 14336], dtypes: [bfloat16], label: "llama-3.1-8b-ffn-decode"}
roofline:
vars:
@@ -73,7 +73,7 @@ GeluFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/activations.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -91,12 +91,12 @@ SiluFwdOp:
params:
inplace: {type: bool, default: false}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
# Llama-3.1-8B SwiGLU FFN intermediate
- - {input_shape: [2048, 14336], dtypes: [float16, bfloat16], label: "llama-3.1-8b-ffn-prefill"}
- - {input_shape: [1, 14336], dtypes: [bfloat16], label: "llama-3.1-8b-ffn-decode"}
+ - {input_shape: [2048, 14336], dtypes: [float16, bfloat16], label: "llama-3.1-8b-ffn-prefill"}
+ - {input_shape: [1, 14336], dtypes: [bfloat16], label: "llama-3.1-8b-ffn-decode"}
roofline:
vars:
@@ -107,7 +107,7 @@ SiluFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/activations.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -125,12 +125,12 @@ HardswishFwdOp:
params:
inplace: {type: bool, default: false}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
# Mobile-style activation map (NHW C, e.g. MobileNetV3 stage)
- - {input_shape: [32, 96, 56, 56], dtypes: [float16, bfloat16], label: "mbv3-stage2"}
- - {input_shape: [32, 240, 28, 28], dtypes: [float16, bfloat16], label: "mbv3-stage3"}
+ - {input_shape: [32, 96, 56, 56], dtypes: [float16, bfloat16], label: "mbv3-stage2"}
+ - {input_shape: [32, 240, 28, 28], dtypes: [float16, bfloat16], label: "mbv3-stage3"}
roofline:
vars:
@@ -141,7 +141,7 @@ HardswishFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/activations.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -159,12 +159,12 @@ HardsigmoidFwdOp:
params:
inplace: {type: bool, default: false}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
# SE-block gating (B, C, 1, 1)
- - {input_shape: [32, 240, 1, 1], dtypes: [float16, bfloat16], label: "mbv3-se-gate"}
- - {input_shape: [32, 960, 1, 1], dtypes: [float16, bfloat16], label: "mbv3-se-gate-deep"}
+ - {input_shape: [32, 240, 1, 1], dtypes: [float16, bfloat16], label: "mbv3-se-gate"}
+ - {input_shape: [32, 960, 1, 1], dtypes: [float16, bfloat16], label: "mbv3-se-gate-deep"}
roofline:
vars:
@@ -175,7 +175,7 @@ HardsigmoidFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/activations.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -193,12 +193,12 @@ MishFwdOp:
params:
inplace: {type: bool, default: false}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
# YOLO-style feature map activation
- - {input_shape: [16, 256, 80, 80], dtypes: [float16, bfloat16], label: "yolo-p3"}
- - {input_shape: [16, 512, 40, 40], dtypes: [float16, bfloat16], label: "yolo-p4"}
+ - {input_shape: [16, 256, 80, 80], dtypes: [float16, bfloat16], label: "yolo-p3"}
+ - {input_shape: [16, 512, 40, 40], dtypes: [float16, bfloat16], label: "yolo-p4"}
roofline:
vars:
@@ -209,7 +209,7 @@ MishFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/activations.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -227,12 +227,12 @@ SeluFwdOp:
params:
inplace: {type: bool, default: false}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
# SNN-style fully connected activation
- - {input_shape: [2048, 4096], dtypes: [float16, bfloat16], label: "snn-fc"}
- - {input_shape: [2048, 8192], dtypes: [float16, bfloat16], label: "snn-fc-wide"}
+ - {input_shape: [2048, 4096], dtypes: [float16, bfloat16], label: "snn-fc"}
+ - {input_shape: [2048, 8192], dtypes: [float16, bfloat16], label: "snn-fc-wide"}
roofline:
vars:
@@ -243,7 +243,7 @@ SeluFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/activations.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -262,12 +262,12 @@ LeakyReluFwdOp:
negative_slope: {type: float, default: 0.01}
inplace: {type: bool, default: false}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
# GAN feature map activation
- - {input_shape: [16, 256, 64, 64], dtypes: [float16, bfloat16], label: "gan-feat"}
- - {input_shape: [16, 512, 32, 32], dtypes: [float16, bfloat16], label: "gan-feat-deep"}
+ - {input_shape: [16, 256, 64, 64], dtypes: [float16, bfloat16], label: "gan-feat"}
+ - {input_shape: [16, 512, 32, 32], dtypes: [float16, bfloat16], label: "gan-feat-deep"}
roofline:
vars:
@@ -278,7 +278,7 @@ LeakyReluFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/activations.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -297,12 +297,12 @@ EluFwdOp:
alpha: {type: float, default: 1.0}
inplace: {type: bool, default: false}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
# MLP hidden activation
- - {input_shape: [2048, 4096], dtypes: [float16, bfloat16], label: "mlp-hidden"}
- - {input_shape: [2048, 8192], dtypes: [float16, bfloat16], label: "mlp-hidden-wide"}
+ - {input_shape: [2048, 4096], dtypes: [float16, bfloat16], label: "mlp-hidden"}
+ - {input_shape: [2048, 8192], dtypes: [float16, bfloat16], label: "mlp-hidden-wide"}
roofline:
vars:
@@ -313,7 +313,7 @@ EluFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/activations.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -333,13 +333,13 @@ HardtanhFwdOp:
max_val: {type: float, default: 1.0}
inplace: {type: bool, default: false}
shape_rules:
- - "min_val <= max_val"
- - "output.shape == input.shape"
+ - "min_val <= max_val"
+ - "output.shape == input.shape"
workloads:
# Quantization-friendly bounded activation
- - {input_shape: [2048, 4096], dtypes: [float16, bfloat16], label: "bounded-hidden"}
- - {input_shape: [16, 256, 56, 56], dtypes: [float16, bfloat16], label: "bounded-conv-feat"}
+ - {input_shape: [2048, 4096], dtypes: [float16, bfloat16], label: "bounded-hidden"}
+ - {input_shape: [16, 256, 56, 56], dtypes: [float16, bfloat16], label: "bounded-conv-feat"}
roofline:
vars:
@@ -350,7 +350,7 @@ HardtanhFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/activations.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -369,12 +369,12 @@ SoftplusFwdOp:
beta: {type: float, default: 1.0}
threshold: {type: float, default: 20.0}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
# Distribution-modeling MLP activation
- - {input_shape: [2048, 4096], dtypes: [float16, bfloat16], label: "mlp-hidden"}
- - {input_shape: [2048, 8192], dtypes: [float16, bfloat16], label: "mlp-hidden-wide"}
+ - {input_shape: [2048, 4096], dtypes: [float16, bfloat16], label: "mlp-hidden"}
+ - {input_shape: [2048, 8192], dtypes: [float16, bfloat16], label: "mlp-hidden-wide"}
roofline:
vars:
@@ -385,7 +385,7 @@ SoftplusFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/activations.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -411,11 +411,11 @@ ClampFwdOp:
output: {dtype: "same_as(input)"}
shape_rules:
# PyTorch broadcasts input/min/max together for tensor-bound clamp.
- - "output.shape == broadcast_shapes(input.shape, min.shape, max.shape)"
+ - "output.shape == broadcast_shapes(input.shape, min.shape, max.shape)"
workloads:
- - {input_shape: [4096, 4096], min_shape: [4096, 4096], max_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {input_shape: [16384, 16384], min_shape: [16384, 16384], max_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {input_shape: [4096, 4096], min_shape: [4096, 4096], max_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {input_shape: [16384, 16384], min_shape: [16384, 16384], max_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
# Func mode: post-broadcast N_total uses broadcast_shapes which is
@@ -427,7 +427,7 @@ ClampFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
clamp_tensor: ClampTensorFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/clamp.py
test: tests/ops/test_special_elementwise.py
bench: benchmarks/ops/bench_independent_elementwise.py
bench_manifest_driven: false
@@ -450,11 +450,11 @@ ClampScalarFwdOp:
min: {type: "Number | None", default: null}
max: {type: "Number | None", default: null}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {input_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {input_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {input_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {input_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -470,7 +470,7 @@ ClampScalarFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
clamp: ClampFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/clamp.py
test: tests/ops/test_special_elementwise.py
bench: benchmarks/ops/bench_independent_elementwise.py
bench_manifest_driven: false
@@ -490,11 +490,11 @@ ClampMinFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, min.shape)"
+ - "output.shape == broadcast_shapes(input.shape, min.shape)"
workloads:
- - {input_shape: [4096, 4096], min_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {input_shape: [16384, 16384], min_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {input_shape: [4096, 4096], min_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {input_shape: [16384, 16384], min_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
# Func mode: see ClampFwdOp.roofline for rationale (broadcast_shapes
@@ -505,7 +505,7 @@ ClampMinFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
clamp_tensor: ClampTensorFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/clamp.py
test: tests/ops/test_special_elementwise.py
bench: benchmarks/ops/bench_independent_elementwise.py
bench_manifest_driven: false
@@ -525,11 +525,11 @@ ClampMaxFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == broadcast_shapes(input.shape, max.shape)"
+ - "output.shape == broadcast_shapes(input.shape, max.shape)"
workloads:
- - {input_shape: [4096, 4096], max_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {input_shape: [16384, 16384], max_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {input_shape: [4096, 4096], max_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {input_shape: [16384, 16384], max_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
# Func mode: see ClampFwdOp.roofline for rationale (broadcast_shapes
@@ -540,7 +540,7 @@ ClampMaxFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
clamp_tensor: ClampTensorFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/clamp.py
test: tests/ops/test_special_elementwise.py
bench: benchmarks/ops/bench_independent_elementwise.py
bench_manifest_driven: false
@@ -560,11 +560,11 @@ NanToNumFwdOp:
posinf: {type: "float | None", default: null}
neginf: {type: "float | None", default: null}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {input_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {input_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {input_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {input_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -575,7 +575,7 @@ NanToNumFwdOp:
source:
kernel: tileops/kernels/elementwise.py
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/nan_to_num.py
test: tests/ops/test_special_elementwise.py
bench: benchmarks/ops/bench_independent_elementwise.py
bench_manifest_driven: false
diff --git a/tileops/manifest/elementwise_unary_math.yaml b/tileops/manifest/elementwise_unary_math.yaml
index effcbc56..f328d107 100644
--- a/tileops/manifest/elementwise_unary_math.yaml
+++ b/tileops/manifest/elementwise_unary_math.yaml
@@ -20,11 +20,11 @@ ExpFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -36,7 +36,7 @@ ExpFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
exp: ExpFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -52,11 +52,11 @@ LogFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -68,7 +68,7 @@ LogFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
log: LogFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -84,11 +84,11 @@ SqrtFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -100,7 +100,7 @@ SqrtFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
sqrt: SqrtFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -116,11 +116,11 @@ RsqrtFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -132,7 +132,7 @@ RsqrtFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
rsqrt: RsqrtFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -148,11 +148,11 @@ AbsFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -164,7 +164,7 @@ AbsFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
abs: AbsFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -180,11 +180,11 @@ NegFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -196,7 +196,7 @@ NegFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
neg: NegFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -214,11 +214,11 @@ ReciprocalFwdOp:
# inputs preserve their dtype.
output: {dtype: "promote_int_to_float(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -230,7 +230,7 @@ ReciprocalFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
reciprocal: ReciprocalFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -246,11 +246,11 @@ SignFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -263,7 +263,7 @@ SignFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
sign: SignFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -279,11 +279,11 @@ SinFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -295,7 +295,7 @@ SinFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
sin: SinFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -311,11 +311,11 @@ CosFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -327,7 +327,7 @@ CosFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
cos: CosFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -343,11 +343,11 @@ FloorFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -359,7 +359,7 @@ FloorFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
floor: FloorFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -375,11 +375,11 @@ CeilFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -391,7 +391,7 @@ CeilFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
ceil: CeilFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -409,11 +409,11 @@ RoundFwdOp:
params:
decimals: {type: int, default: 0}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -425,7 +425,7 @@ RoundFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
round: RoundFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -441,11 +441,11 @@ TruncFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -457,7 +457,7 @@ TruncFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
trunc: TruncFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -473,11 +473,11 @@ ErfFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -489,7 +489,7 @@ ErfFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
erf: ErfFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -505,11 +505,11 @@ Log1pFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -522,7 +522,7 @@ Log1pFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
log1p: Log1pFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -538,11 +538,11 @@ Expm1FwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -555,7 +555,7 @@ Expm1FwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
expm1: Expm1FwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/math_unary.py
test: tests/ops/test_unary_math.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -571,11 +571,11 @@ SigmoidFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -588,7 +588,7 @@ SigmoidFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
sigmoid: SigmoidFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/activations.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -604,11 +604,11 @@ TanhFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -621,7 +621,7 @@ TanhFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
tanh: TanhFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/activations.py
test: tests/ops/test_activation.py
bench: benchmarks/ops/bench_activation.py
bench_manifest_driven: false
@@ -637,11 +637,11 @@ LogicalNotFwdOp:
outputs:
output: {dtype: "bool"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [bool, float16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [bool], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [bool, float16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [bool], label: "elementwise-256M"}
roofline:
vars:
@@ -654,7 +654,7 @@ LogicalNotFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
logical_not: LogicalNotFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/logical.py
test: tests/ops/test_logical.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -670,11 +670,11 @@ BitwiseNotFwdOp:
outputs:
output: {dtype: "same_as(input)"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [int32, int64], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [int32], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [int32, int64], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [int32], label: "elementwise-256M"}
roofline:
vars:
@@ -686,7 +686,7 @@ BitwiseNotFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
bitwise_not: BitwiseNotFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/bitwise.py
test: tests/ops/test_bitwise.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -702,11 +702,11 @@ IsnanFwdOp:
outputs:
output: {dtype: "bool"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -719,7 +719,7 @@ IsnanFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
isnan: IsnanFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/predicates.py
test: tests/ops/test_special_elementwise.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -735,11 +735,11 @@ IsinfFwdOp:
outputs:
output: {dtype: "bool"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -752,7 +752,7 @@ IsinfFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
isinf: IsinfFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/predicates.py
test: tests/ops/test_special_elementwise.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
@@ -768,11 +768,11 @@ IsfiniteFwdOp:
outputs:
output: {dtype: "bool"}
shape_rules:
- - "output.shape == input.shape"
+ - "output.shape == input.shape"
workloads:
- - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
- - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
+ - {x_shape: [4096, 4096], dtypes: [float16, bfloat16, float32], label: "elementwise-16M"}
+ - {x_shape: [16384, 16384], dtypes: [float16, bfloat16], label: "elementwise-256M"}
roofline:
vars:
@@ -785,7 +785,7 @@ IsfiniteFwdOp:
kernel: tileops/kernels/elementwise.py
kernel_map:
isfinite: IsfiniteFwdKernel
- op: tileops/ops/elementwise.py
+ op: tileops/ops/elementwise/predicates.py
test: tests/ops/test_special_elementwise.py
bench: benchmarks/ops/bench_unary_elementwise.py
bench_manifest_driven: false
diff --git a/tileops/ops/elementwise.py b/tileops/ops/elementwise.py
deleted file mode 100644
index d96107ae..00000000
--- a/tileops/ops/elementwise.py
+++ /dev/null
@@ -1,3415 +0,0 @@
-"""Elementwise op templates and broadcast utility.
-
-Three Op template base classes:
-- UnaryOp: wraps UnaryKernel with reshape/flatten
-- BinaryOp: wraps BinaryKernel with broadcast coalescing
-- FusedGatedOp: wraps FusedGatedKernel with (M, 2N) layout
-
-torch.compile support:
-- All 66 concrete ops are registered via @torch.library.custom_op at module load time
-- Three factory functions (_register_unary_custom_op, _register_binary_custom_op,
- _register_fused_gated_custom_op) register every op; instances are looked up at
- runtime via _OP_REGISTRY keyed by id(instance)
-
-Utility:
-- coalesce_broadcast_dims: reduces N-dim broadcast to minimal effective dims
-"""
-
-import inspect
-import math
-import weakref
-from math import prod
-from typing import Callable, Dict, List, Optional
-
-import torch
-
-from tileops.kernels.elementwise import (
- AbsFwdKernel,
- AddFwdKernel,
- AlibiFwdKernel,
- BitwiseAndFwdKernel,
- BitwiseNotFwdKernel,
- BitwiseOrFwdKernel,
- BitwiseXorFwdKernel,
- CeilFwdKernel,
- ClampFwdKernel,
- ClampTensorFwdKernel,
- CosFwdKernel,
- DivFwdKernel,
- EluFwdKernel,
- EqFwdKernel,
- ErfFwdKernel,
- ExpFwdKernel,
- Expm1FwdKernel,
- FloorDivideFwdKernel,
- FloorFwdKernel,
- GeFwdKernel,
- GeluAndMulFwdKernel,
- GeluFwdKernel,
- GeluTanhAndMulFwdKernel,
- GeluTanhFwdKernel,
- GtFwdKernel,
- HardsigmoidFwdKernel,
- HardswishFwdKernel,
- HardtanhFwdKernel,
- IsfiniteFwdKernel,
- IsinfFwdKernel,
- IsnanFwdKernel,
- LeakyReluFwdKernel,
- LeFwdKernel,
- LerpFwdKernel,
- LerpTensorFwdKernel,
- Log1pFwdKernel,
- LogFwdKernel,
- LogicalAndFwdKernel,
- LogicalNotFwdKernel,
- LogicalOrFwdKernel,
- LtFwdKernel,
- MaskedFillFwdKernel,
- MaskedFillTensorValueFwdKernel,
- MaximumFwdKernel,
- MinimumFwdKernel,
- MishFwdKernel,
- MulFwdKernel,
- NanToNumFwdKernel,
- NeFwdKernel,
- NegFwdKernel,
- PowFwdKernel,
- PreluFwdKernel,
- ReciprocalFwdKernel,
- ReluFwdKernel,
- RemainderFwdKernel,
- RoundFwdKernel,
- RsqrtFwdKernel,
- SeluFwdKernel,
- SigmoidFwdKernel,
- SignFwdKernel,
- SiluAndMulFwdKernel,
- SiluFwdKernel,
- SinFwdKernel,
- SinusoidalFwdKernel,
- SoftplusFwdKernel,
- SqrtFwdKernel,
- SubFwdKernel,
- TanhFwdKernel,
- TruncFwdKernel,
- WhereFwdKernel,
-)
-from tileops.kernels.kernel_base import Kernel
-
-from .op_base import Op
-
-# ---------------------------------------------------------------------------
-# torch.compile registration factories
-#
-# Each factory creates a @torch.library.custom_op + register_fake pair.
-# Instances register themselves in _OP_REGISTRY keyed by integer id.
-# The custom_op receives this key and looks up the instance to call the
-# pre-built tilelang kernel. The key is a plain int so dynamo can trace
-# through forward() without hitting unsupported Python side-effects.
-# ---------------------------------------------------------------------------
-
-_OP_REGISTRY: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
-
-_FP8_NONSAT_OUTPUT_DTYPES = {
- torch.float8_e5m2: torch.float16,
-}
-
-def _effective_scalar_kernel_dtype(dtype: torch.dtype) -> torch.dtype:
- """Return the dtype used when scalar literals are materialized in kernels."""
- return _FP8_NONSAT_OUTPUT_DTYPES.get(dtype, dtype)
-
-
-def _validate_scalar_param_repr(
- param_name: str, value: float, dtype: torch.dtype, op_name: str,
-) -> None:
- """Reject scalar params that cannot be represented in the user dtype.
-
- Validation targets the *user-facing* ``dtype`` rather than the
- intermediate ``_effective_scalar_kernel_dtype(dtype)``. For fp8
- dtypes the kernel runs in fp16 to preserve Inf/NaN, but a value that
- only fits in fp16 would surface as ``+/-Inf`` after the final fp8
- post-cast. Validating against the user dtype keeps explicit
- replacements finite end-to-end.
- """
- if not isinstance(value, (int, float)):
- raise TypeError(f"{op_name} expected scalar {param_name} to be int/float, got {type(value)}")
-
- finfo = torch.finfo(dtype)
- value_f64 = float(value)
- if math.isnan(value_f64):
- return
- if math.isinf(value_f64):
- raise ValueError(
- f"{op_name} received {param_name}={value!r}, but {param_name} must be finite and "
- f"representable in dtype {dtype}"
- )
- if not (finfo.min <= value_f64 <= finfo.max):
- raise ValueError(
- f"{op_name} received {param_name}={value!r}, which is not representable in "
- f"dtype {dtype} (valid finite range: "
- f"[{finfo.min}, {finfo.max}])"
- )
-
-
-def _register_unary_custom_op(op_cls, output_dtype_override=None):
- """Register a unary elementwise op for torch.compile.
-
- Args:
- op_cls: The Op subclass to register (must have ``_op_name``).
- output_dtype_override: If set, the output dtype (e.g. torch.bool for predicates).
- """
- op_name = op_cls._op_name
-
- @torch.library.custom_op(f"top::elementwise_unary_{op_name}", mutates_args=())
- def _wrapped(x: torch.Tensor, instance_key: int) -> torch.Tensor:
- instance = _OP_REGISTRY[instance_key]
- return instance._eager_forward(x)
-
- @_wrapped.register_fake
- def _(x: torch.Tensor, instance_key: int) -> torch.Tensor:
- out_dtype = output_dtype_override if output_dtype_override is not None else x.dtype
- return torch.empty_like(x, dtype=out_dtype)
-
- op_cls._wrapped = _wrapped
-
-
-def _register_unary_inplace_custom_op(op_cls):
- """Register the ``inplace=True`` companion for a unary activation op.
-
- The kernel writes into a fresh buffer; this wrapper copies the result
- back into ``x`` and returns ``x`` so the caller sees ``y is x`` and
- ``x`` carries the activation output. The custom op is registered with
- ``mutates_args=("x",)`` so ``torch.compile`` traces the mutation
- correctly. Sets ``op_cls._wrapped_inplace`` for ``forward()`` to
- dispatch through.
- """
- op_name = op_cls._op_name
-
- @torch.library.custom_op(
- f"top::elementwise_unary_{op_name}_inplace", mutates_args=("x",),
- )
- def _wrapped_inplace(x: torch.Tensor, instance_key: int) -> None:
- instance = _OP_REGISTRY[instance_key]
- result = instance._eager_forward(x)
- x.copy_(result.reshape(x.shape))
-
- op_cls._wrapped_inplace = _wrapped_inplace
-
-
-def _register_binary_custom_op(op_cls, output_bool: bool = False):
- """Register a binary elementwise op for torch.compile.
-
- Args:
- op_cls: The Op subclass to register.
- output_bool: If True, output dtype is torch.bool (for comparison/logical ops).
- """
- op_name = op_cls._op_name
-
- @torch.library.custom_op(f"top::elementwise_binary_{op_name}", mutates_args=())
- def _wrapped(
- a: torch.Tensor,
- b: torch.Tensor,
- out_shape: List[int],
- instance_key: int,
- ) -> torch.Tensor:
- instance = _OP_REGISTRY[instance_key]
- return instance._eager_forward(a, b)
-
- @_wrapped.register_fake
- def _(
- a: torch.Tensor,
- b: torch.Tensor,
- out_shape: List[int],
- instance_key: int,
- ) -> torch.Tensor:
- out_dtype = torch.bool if output_bool else a.dtype
- return a.new_empty(out_shape, dtype=out_dtype)
-
- op_cls._wrapped = _wrapped
-
-
-def _register_prelu_custom_op(op_cls):
- """Register a PReLU-style op (x, weight -> y) for torch.compile."""
- op_name = op_cls._op_name
-
- @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
- def _wrapped(
- x: torch.Tensor,
- weight: torch.Tensor,
- instance_key: int,
- ) -> torch.Tensor:
- instance = _OP_REGISTRY[instance_key]
- return instance._eager_forward(x, weight)
-
- @_wrapped.register_fake
- def _(
- x: torch.Tensor,
- weight: torch.Tensor,
- instance_key: int,
- ) -> torch.Tensor:
- return torch.empty_like(x)
-
- op_cls._wrapped = _wrapped
-
-
-def _register_where_custom_op(op_cls):
- """Register a where-style op (cond, x, y -> out) for torch.compile.
-
- The fake function computes the broadcast output shape from
- ``cond`` / ``x`` / ``y`` so that ``torch.compile(fullgraph=True)``
- works for both same-shape and broadcasting inputs.
- """
- op_name = op_cls._op_name
-
- @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
- def _wrapped(
- cond: torch.Tensor,
- x: torch.Tensor,
- y: torch.Tensor,
- instance_key: int,
- ) -> torch.Tensor:
- instance = _OP_REGISTRY[instance_key]
- return instance._eager_forward(cond, x, y)
-
- @_wrapped.register_fake
- def _(
- cond: torch.Tensor,
- x: torch.Tensor,
- y: torch.Tensor,
- instance_key: int,
- ) -> torch.Tensor:
- out_shape = torch.broadcast_shapes(cond.shape, x.shape, y.shape)
- return x.new_empty(out_shape)
-
- op_cls._wrapped = _wrapped
-
-
-def _register_lerp_tensor_custom_op(op_cls):
- """Register a Tensor-weight lerp op (input, end, weight -> out).
-
- The fake function computes the broadcast output shape from ``input`` /
- ``end`` / ``weight`` so that ``torch.compile(fullgraph=True)`` works
- for both same-shape and broadcasting inputs. Registered under a
- distinct ``_tensor`` namespace to avoid colliding with the scalar
- ``LerpFwdOp`` (which bakes ``weight`` at construction time and uses
- the binary registration path).
- """
- op_name = op_cls._op_name
-
- @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
- def _wrapped(
- input: torch.Tensor, # noqa: A002 — manifest-aligned PyTorch param name
- end: torch.Tensor,
- weight: torch.Tensor,
- instance_key: int,
- ) -> torch.Tensor:
- instance = _OP_REGISTRY[instance_key]
- return instance._eager_forward(input, end, weight)
-
- @_wrapped.register_fake
- def _(
- input: torch.Tensor, # noqa: A002
- end: torch.Tensor,
- weight: torch.Tensor,
- instance_key: int,
- ) -> torch.Tensor:
- out_shape = torch.broadcast_shapes(input.shape, end.shape, weight.shape)
- return input.new_empty(out_shape)
-
- op_cls._wrapped = _wrapped
-
-
-def _register_masked_fill_custom_op(op_cls):
- """Register a masked-fill-style op (x, mask -> y) for torch.compile.
-
- The fake function computes the bidirectional broadcast output shape
- of ``x`` and ``mask`` so ``torch.compile(fullgraph=True)`` works for
- both same-shape and broadcasting inputs.
- """
- op_name = op_cls._op_name
-
- @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
- def _wrapped(
- x: torch.Tensor,
- mask: torch.Tensor,
- instance_key: int,
- ) -> torch.Tensor:
- instance = _OP_REGISTRY[instance_key]
- return instance._eager_forward(x, mask)
-
- @_wrapped.register_fake
- def _(
- x: torch.Tensor,
- mask: torch.Tensor,
- instance_key: int,
- ) -> torch.Tensor:
- out_shape = torch.broadcast_shapes(x.shape, mask.shape)
- return x.new_empty(out_shape)
-
- op_cls._wrapped = _wrapped
-
-
-def _register_masked_fill_tensor_value_custom_op(op_cls):
- """Register a masked-fill (Tensor value) op (input, mask, value -> out).
-
- The fake function computes the broadcast output shape of ``input`` and
- ``mask`` (``value`` is a 0-dim Tensor). Registered under a distinct
- namespace from the scalar masked_fill variant to avoid collision.
- """
- op_name = op_cls._op_name
-
- @torch.library.custom_op(
- f"top::elementwise_{op_name}_tensor_value", mutates_args=(),
- )
- def _wrapped(
- input: torch.Tensor, # noqa: A002
- mask: torch.Tensor,
- value: torch.Tensor,
- instance_key: int,
- ) -> torch.Tensor:
- instance = _OP_REGISTRY[instance_key]
- return instance._eager_forward(input, mask, value)
-
- @_wrapped.register_fake
- def _(
- input: torch.Tensor, # noqa: A002
- mask: torch.Tensor,
- value: torch.Tensor,
- instance_key: int,
- ) -> torch.Tensor:
- out_shape = torch.broadcast_shapes(input.shape, mask.shape)
- return input.new_empty(out_shape)
-
- op_cls._wrapped = _wrapped
-
-
-def _register_clamp_tensor_custom_op(op_cls):
- """Register a Tensor-bound clamp op (input, min?, max? -> out).
-
- ``min`` and ``max`` are each ``Optional[Tensor]``; the schema is
- inferred by ``torch.library.custom_op`` from the ``Optional[torch.Tensor]``
- annotations, producing ``Tensor? min, Tensor? max`` in the underlying
- custom-op schema. The fake function computes the broadcast output
- shape of all non-``None`` operands so ``torch.compile(fullgraph=True)``
- works for both same-shape and broadcasting inputs. Registered under
- a distinct ``_tensor`` namespace from the scalar-bound clamp variant.
- """
- op_name = op_cls._op_name
-
- @torch.library.custom_op(
- f"top::elementwise_{op_name}_tensor", mutates_args=(),
- )
- def _wrapped(
- input: torch.Tensor, # noqa: A002
- min: Optional[torch.Tensor], # noqa: A002
- max: Optional[torch.Tensor], # noqa: A002
- instance_key: int,
- ) -> torch.Tensor:
- instance = _OP_REGISTRY[instance_key]
- return instance._eager_forward(input, min, max)
-
- @_wrapped.register_fake
- def _(
- input: torch.Tensor, # noqa: A002
- min: Optional[torch.Tensor], # noqa: A002
- max: Optional[torch.Tensor], # noqa: A002
- instance_key: int,
- ) -> torch.Tensor:
- shapes = [input.shape]
- if min is not None:
- shapes.append(min.shape)
- if max is not None:
- shapes.append(max.shape)
- out_shape = torch.broadcast_shapes(*shapes)
- return input.new_empty(out_shape)
-
- op_cls._wrapped = _wrapped
-
-
-def _register_clamp_min_custom_op(op_cls):
- """Register single-bound Tensor lower-clamp (input, min -> out)."""
- op_name = op_cls._op_name
-
- @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
- def _wrapped(
- input: torch.Tensor, # noqa: A002
- min: torch.Tensor, # noqa: A002
- instance_key: int,
- ) -> torch.Tensor:
- instance = _OP_REGISTRY[instance_key]
- return instance._eager_forward(input, min)
-
- @_wrapped.register_fake
- def _(
- input: torch.Tensor, # noqa: A002
- min: torch.Tensor, # noqa: A002
- instance_key: int,
- ) -> torch.Tensor:
- out_shape = torch.broadcast_shapes(input.shape, min.shape)
- return input.new_empty(out_shape)
-
- op_cls._wrapped = _wrapped
-
-
-def _register_clamp_max_custom_op(op_cls):
- """Register single-bound Tensor upper-clamp (input, max -> out)."""
- op_name = op_cls._op_name
-
- @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
- def _wrapped(
- input: torch.Tensor, # noqa: A002
- max: torch.Tensor, # noqa: A002
- instance_key: int,
- ) -> torch.Tensor:
- instance = _OP_REGISTRY[instance_key]
- return instance._eager_forward(input, max)
-
- @_wrapped.register_fake
- def _(
- input: torch.Tensor, # noqa: A002
- max: torch.Tensor, # noqa: A002
- instance_key: int,
- ) -> torch.Tensor:
- out_shape = torch.broadcast_shapes(input.shape, max.shape)
- return input.new_empty(out_shape)
-
- op_cls._wrapped = _wrapped
-
-
-def _register_generative_custom_op(op_cls, out_shape_fn):
- """Register a generative op (no tensor input -> out) for torch.compile.
-
- A scalar ``device_carrier`` tensor is passed so that ``register_fake``
- can derive the correct device and dtype from a real tensor reference,
- which is required by the torch.compile tracing infrastructure.
-
- Args:
- op_cls: The Op subclass to register.
- out_shape_fn: Callable(carrier, num_a, num_b) -> Tensor returning
- the output metadata so register_fake can produce the right shape.
- """
- op_name = op_cls._op_name
-
- @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
- def _wrapped(
- device_carrier: torch.Tensor,
- num_a: int,
- num_b: int,
- instance_key: int,
- ) -> torch.Tensor:
- instance = _OP_REGISTRY[instance_key]
- return instance._eager_forward()
-
- @_wrapped.register_fake
- def _(
- device_carrier: torch.Tensor,
- num_a: int,
- num_b: int,
- instance_key: int,
- ) -> torch.Tensor:
- return out_shape_fn(device_carrier, num_a, num_b)
-
- op_cls._wrapped = _wrapped
-
-
-def _register_fused_gated_custom_op(op_cls):
- """Register a fused gated elementwise op for torch.compile.
-
- Args:
- op_cls: The Op subclass to register.
- """
- op_name = op_cls._op_name
-
- @torch.library.custom_op(f"top::elementwise_fused_gated_{op_name}", mutates_args=())
- def _wrapped(
- x: torch.Tensor,
- M: int,
- N: int,
- instance_key: int,
- ) -> torch.Tensor:
- instance = _OP_REGISTRY[instance_key]
- return instance._eager_forward(x)
-
- @_wrapped.register_fake
- def _(
- x: torch.Tensor,
- M: int,
- N: int,
- instance_key: int,
- ) -> torch.Tensor:
- return x.new_empty((M, N), dtype=x.dtype)
-
- op_cls._wrapped = _wrapped
-
-
-__all__ = [
- "coalesce_broadcast_dims",
- "UnaryOp",
- "BinaryOp",
- "FusedGatedOp",
- # Unary
- "ReluFwdOp",
- # Binary arithmetic
- "AddFwdOp",
- "SubFwdOp",
- "MulFwdOp",
- "DivFwdOp",
- "RemainderFwdOp",
- "PowFwdOp",
- "FloorDivideFwdOp",
- "LerpFwdOp",
- "MaximumFwdOp",
- "MinimumFwdOp",
- # Comparison (output bool)
- "EqFwdOp",
- "NeFwdOp",
- "GtFwdOp",
- "LtFwdOp",
- "GeFwdOp",
- "LeFwdOp",
- # Logical (output bool)
- "LogicalAndFwdOp",
- "LogicalOrFwdOp",
- # Bitwise
- "BitwiseAndFwdOp",
- "BitwiseOrFwdOp",
- "BitwiseXorFwdOp",
- # Fused gated
- "SiluAndMulFwdOp",
- "GeluAndMulFwdOp",
- "GeluTanhAndMulFwdOp",
- # --- math (17) ---
- "AbsFwdOp",
- "CeilFwdOp",
- "CosFwdOp",
- "ErfFwdOp",
- "ExpFwdOp",
- "Expm1FwdOp",
- "FloorFwdOp",
- "Log1pFwdOp",
- "LogFwdOp",
- "NegFwdOp",
- "ReciprocalFwdOp",
- "RoundFwdOp",
- "RsqrtFwdOp",
- "SignFwdOp",
- "SinFwdOp",
- "SqrtFwdOp",
- "TruncFwdOp",
- # --- activations (8) ---
- "GeluFwdOp",
- "HardsigmoidFwdOp",
- "HardswishFwdOp",
- "MishFwdOp",
- "SeluFwdOp",
- "SigmoidFwdOp",
- "SiluFwdOp",
- "TanhFwdOp",
- # --- logical (1) ---
- "LogicalNotFwdOp",
- # --- bitwise (1) ---
- "BitwiseNotFwdOp",
- # --- special predicates (3) ---
- "IsfiniteFwdOp",
- "IsinfFwdOp",
- "IsnanFwdOp",
- # --- independent (custom-signature, 11) ---
- "LeakyReluFwdOp",
- "EluFwdOp",
- "HardtanhFwdOp",
- "SoftplusFwdOp",
- "PreluFwdOp",
- "WhereFwdOp",
- "LerpTensorFwdOp",
- "ClampFwdOp",
- "ClampScalarFwdOp",
- "ClampMinFwdOp",
- "ClampMaxFwdOp",
- "MaskedFillFwdOp",
- "MaskedFillScalarFwdOp",
- "NanToNumFwdOp",
- "AlibiFwdOp",
- "SinusoidalFwdOp",
-]
-
-
-def coalesce_broadcast_dims(a_shape, b_shape):
- """Coalesce N-dim broadcast into minimal effective dimensions.
-
- Merges adjacent dimensions that have the same broadcast behaviour
- (both real or both broadcast) to minimise the number of divmod
- operations inside the kernel loop.
-
- Args:
- a_shape: Shape tuple of input a.
- b_shape: Shape tuple of input b.
-
- Returns:
- Tuple of (out_shape, coalesced_shape, a_strides, b_strides) where
- strides use 0 for broadcast dimensions.
- """
- # Normalise scalar (0-dim) inputs to 1-dim with size 1
- if len(a_shape) == 0:
- a_shape = (1,)
- if len(b_shape) == 0:
- b_shape = (1,)
-
- out_shape = torch.broadcast_shapes(a_shape, b_shape)
- ndim = len(out_shape)
- a_pad = (1,) * (ndim - len(a_shape)) + tuple(a_shape)
- b_pad = (1,) * (ndim - len(b_shape)) + tuple(b_shape)
-
- def _make_strides(padded_shape):
- strides = [1] * ndim
- for i in range(ndim - 2, -1, -1):
- strides[i] = strides[i + 1] * padded_shape[i + 1]
- # Only zero strides for genuinely broadcast dims (size-1 expanded to >1)
- return [
- 0 if padded_shape[i] == 1 and out_shape[i] > 1 else strides[i]
- for i in range(ndim)
- ]
-
- a_raw = _make_strides(a_pad)
- b_raw = _make_strides(b_pad)
-
- # Coalesce adjacent dims with compatible broadcast patterns
- groups = [(out_shape[0], a_raw[0], b_raw[0])]
- for i in range(1, ndim):
- prev_out, prev_as, prev_bs = groups[-1]
- a_can = (a_raw[i] == 0 and prev_as == 0) or (
- a_raw[i] != 0 and prev_as == a_raw[i] * out_shape[i]
- )
- b_can = (b_raw[i] == 0 and prev_bs == 0) or (
- b_raw[i] != 0 and prev_bs == b_raw[i] * out_shape[i]
- )
- if a_can and b_can:
- groups[-1] = (prev_out * out_shape[i], a_raw[i], b_raw[i])
- else:
- groups.append((out_shape[i], a_raw[i], b_raw[i]))
-
- # Remove trivial size-1 groups (unless all trivial)
- groups = [g for g in groups if g[0] > 1] or [(1, 0, 0)]
- coalesced_shape = tuple(g[0] for g in groups)
- a_strides = tuple(g[1] for g in groups)
- b_strides = tuple(g[2] for g in groups)
- return out_shape, coalesced_shape, a_strides, b_strides
-
-
-def _apply_fp8_post_cast(result: torch.Tensor, kernel) -> torch.Tensor:
- """Apply fp8 output cast if the kernel requires it.
-
- For e5m2 dtypes the kernel produces fp16 output to preserve Inf/NaN;
- this helper performs the final non-saturating cast via PyTorch.
- """
- fp8_out = getattr(kernel, "_fp8_output_dtype", None)
- if fp8_out is not None:
- return result.to(fp8_out)
- return result
-
-
-_FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)
-
-
-def _is_fp8(dtype: torch.dtype) -> bool:
- """Return True iff ``dtype`` is one of the supported fp8 dtypes."""
- return dtype in _FP8_DTYPES
-
-
-def _fp8_compute_dtype(dtype: torch.dtype) -> torch.dtype:
- """Return the compute dtype used to emulate fp8 elementwise fallbacks.
-
- PyTorch's CUDA backend does not implement ``clamp``/``maximum``/
- ``minimum``/``masked_fill_`` on Float8 tensors (raises NotImplementedError
- on ``clamp_cuda`` / ``max_elementwise_cuda`` / ``min_elementwise_cuda`` /
- ``masked_fill_``). Both e4m3fn (finite range ±448) and e5m2 (finite range
- ±57344) fit in fp16, so we upcast to fp16, run the op, and cast back. The
- final cast preserves Inf/NaN for e5m2 (PyTorch's fp16->e5m2 conversion is
- non-saturating) and saturates for e4m3fn (matching PyTorch's default
- fp16->e4m3fn behaviour).
- """
- if not _is_fp8(dtype):
- raise ValueError(f"_fp8_compute_dtype expects an fp8 dtype, got {dtype}")
- return torch.float16
-
-
-class UnaryOp(Op):
- """Template base class for unary elementwise ops.
-
- Subclass must set ``kernel_cls`` and ``_op_name``.
- Subclass should also set ``_wrapped`` via ``_register_unary_custom_op``
- to enable torch.compile support.
-
- Args:
- N_total: Total number of elements (flattened).
- dtype: Torch dtype.
- strategy: Kernel strategy override.
- kernel_map: Optional kernel dispatch override.
- tune: Whether to autotune.
- """
-
- kernel_cls: type
- _op_name: str
- _wrapped = None # Set by _register_unary_custom_op at class definition
- # Per-element FLOP count, matching the manifest's ``roofline.flops``
- # coefficient on ``N``. Subclasses override when the op is more than one
- # arithmetic op per element (e.g. ``sigmoid`` ≈ 4, ``tanh`` ≈ 5). The
- # base class default of 1 covers the common ``flops: "N"`` entries.
- FLOPS_PER_ELEM: int = 1
-
- def __init__(
- self,
- N_total: int,
- dtype: torch.dtype,
- strategy: Optional[str] = None,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- self.N_total = N_total
- self.dtype = dtype
- self.strategy = strategy
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map[self._op_name](
- N_total, dtype, strategy=strategy, tune=tune,
- )
- # Use _fp8_output_dtype (the final dtype after Op-layer post-cast)
- # rather than kernel.output_dtype (which is fp16 for e5m2).
- fp8_out = getattr(self.kernel, "_fp8_output_dtype", None)
- self.output_dtype = fp8_out or getattr(self.kernel, "output_dtype", dtype)
- # Register in global registry for torch.compile dispatch
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self) -> Dict[str, Kernel]:
- return {self._op_name: self.kernel_cls}
-
- @property
- def total_memory(self) -> float:
- """Read x + write y."""
- return self.N_total * (self.dtype.itemsize + self.output_dtype.itemsize)
-
- def eval_roofline(self) -> tuple[int, int]:
- """Return ``(flops, bytes)`` for this unary elementwise op instance.
-
- Mirrors the elementwise_unary_math manifest roofline:
- ``flops = FLOPS_PER_ELEM * N`` and
- ``bytes = N * input_elem_bytes + N * output_elem_bytes``. Subclasses
- whose manifest entry uses a higher coefficient (e.g. ``sigmoid`` →
- ``4 * N``, ``tanh`` → ``5 * N``) override ``FLOPS_PER_ELEM``. For ops
- whose output dtype matches the input (e.g. ``neg``, ``abs``), bytes
- collapse to ``2 * N * elem_bytes``; for ops with a smaller output
- dtype (e.g. ``isnan`` / ``isinf`` / ``isfinite`` / ``logical_not`` →
- bool), ``self.output_dtype.itemsize`` already captures the difference.
- """
- return self.FLOPS_PER_ELEM * self.N_total, int(self.total_memory)
-
- def _eager_forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002
- """Direct kernel call for use inside custom_op implementation."""
- orig_shape = input.shape
- flat = input.contiguous().reshape(-1)
- result = self.kernel(flat).reshape(orig_shape)
- # For e5m2: kernel produces fp16 to preserve Inf/NaN;
- # cast to e5m2 here using PyTorch's non-saturating conversion.
- return _apply_fp8_post_cast(result, self.kernel)
-
- def _validate_input(self, input: torch.Tensor) -> None: # noqa: A002
- """Validate input tensor against the op's dtype / numel contract."""
- if not input.is_cuda:
- raise ValueError("Input must be a CUDA tensor")
- if input.dtype != self.dtype:
- raise ValueError(
- f"Expected input.dtype {self.dtype}, got {input.dtype}"
- )
- if input.numel() != self.N_total:
- raise ValueError(
- f"Expected {self.N_total} elements, got {input.numel()}"
- )
-
- def forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002
- self._validate_input(input)
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(input, self._instance_key)
- return self._eager_forward(input)
-
-
-class BinaryOp(Op):
- """Template base class for binary elementwise ops with broadcast.
-
- Subclass must set ``kernel_cls`` and ``_op_name``.
- Subclass should also set ``_wrapped`` via ``_register_binary_custom_op``
- to enable torch.compile support.
-
- Args:
- a_shape: Shape of input a.
- b_shape: Shape of input b.
- dtype: Torch dtype.
- strategy: Kernel strategy override.
- kernel_map: Optional kernel dispatch override.
- tune: Whether to autotune.
- """
-
- kernel_cls: type
- _op_name: str
- _wrapped = None # Set by _register_binary_custom_op at class definition
- # Subclasses may set ``_other_name`` to a manifest-aligned parameter
- # name (e.g. ``"exponent"`` for ``PowFwdOp``, ``"end"`` for
- # ``LerpFwdOp``); the L1 signature check sees the renamed parameter
- # via ``__init_subclass__`` rebinding ``forward.__signature__``.
- _other_name: str = "other"
-
- def __init_subclass__(cls, **kwargs):
- super().__init_subclass__(**kwargs)
- other_name = cls.__dict__.get("_other_name")
- if other_name is None or other_name == "other":
- return
- base_forward = cls.forward
- try:
- sig = inspect.signature(base_forward)
- except (ValueError, TypeError):
- return
- new_params = [
- p.replace(name=other_name) if p.name == "other" else p
- for p in sig.parameters.values()
- ]
- new_sig = sig.replace(parameters=new_params)
-
- def forward(self, *args, **kwargs):
- if other_name in kwargs:
- kwargs["other"] = kwargs.pop(other_name)
- return base_forward(self, *args, **kwargs)
-
- forward.__signature__ = new_sig
- forward.__name__ = "forward"
- forward.__qualname__ = f"{cls.__qualname__}.forward"
- cls.forward = forward
-
- def __init__(
- self,
- a_shape: tuple,
- b_shape: tuple,
- dtype: torch.dtype,
- strategy: Optional[str] = None,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- kernel_supported = self.kernel_cls.SUPPORTED_DTYPES
- if kernel_supported is not None and dtype not in kernel_supported:
- names = ", ".join(str(dt) for dt in kernel_supported)
- raise ValueError(
- f"{self._op_name} does not support dtype {dtype}. "
- f"Supported: [{names}]"
- )
- self.dtype = dtype
- self.a_shape = tuple(a_shape)
- self.b_shape = tuple(b_shape)
- self.strategy = strategy
- out_shape, coalesced_shape, a_strides, b_strides = coalesce_broadcast_dims(
- a_shape, b_shape,
- )
- self.out_shape = out_shape
- self._out_shape_list = list(out_shape) # cached for custom_op hot path
- self.N_total = prod(out_shape)
- self.a_numel = prod(a_shape)
- self.b_numel = prod(b_shape)
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map[self._op_name](
- self.N_total, dtype, coalesced_shape, a_strides, b_strides,
- self.a_numel, self.b_numel, strategy=strategy, tune=tune,
- )
- # Register in global registry for torch.compile dispatch
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self) -> Dict[str, Kernel]:
- return {self._op_name: self.kernel_cls}
-
- @property
- def total_memory(self) -> float:
- """Read a + read b + write y."""
- in_elem = self.dtype.itemsize
- fp8_out = getattr(self.kernel, "_fp8_output_dtype", None)
- out_elem = fp8_out.itemsize if fp8_out is not None else in_elem
- return (self.a_numel + self.b_numel) * in_elem + self.N_total * out_elem
-
- def _eager_forward(
- self,
- input: torch.Tensor, # noqa: A002 — manifest-aligned PyTorch param name
- other: torch.Tensor,
- ) -> torch.Tensor:
- """Direct kernel call for use inside custom_op implementation."""
- result = self.kernel(
- input.contiguous().view(-1), other.contiguous().view(-1),
- ).reshape(self.out_shape)
- return _apply_fp8_post_cast(result, self.kernel)
-
- def forward(
- self,
- input: torch.Tensor, # noqa: A002 — manifest-aligned PyTorch param name
- other: torch.Tensor,
- ) -> torch.Tensor:
- a_name = getattr(self, "_input_name", "input")
- b_name = getattr(self, "_other_name", "other")
- if not input.is_cuda or not other.is_cuda:
- raise ValueError("Inputs must be CUDA tensors")
- if input.dtype != self.dtype:
- raise ValueError(f"Expected {a_name}.dtype {self.dtype}, got {input.dtype}")
- if other.dtype != self.dtype:
- raise ValueError(f"Expected {b_name}.dtype {self.dtype}, got {other.dtype}")
- if input.numel() != self.a_numel:
- raise ValueError(
- f"Expected {a_name} to have {self.a_numel} elements, got {input.numel()}"
- )
- if other.numel() != self.b_numel:
- raise ValueError(
- f"Expected {b_name} to have {self.b_numel} elements, got {other.numel()}"
- )
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(input, other, self._out_shape_list, self._instance_key)
- return self._eager_forward(input, other)
-
-
-class FusedGatedOp(Op):
- """Template base class for fused gated elementwise ops.
-
- Input: x of shape (M, 2*N). gate = x[:, :N], value = x[:, N:].
- Output: y = activation(gate) * value, shape (M, N).
-
- Subclass must set ``kernel_cls`` and ``_op_name``.
- Subclass should also set ``_wrapped`` via ``_register_fused_gated_custom_op``
- to enable torch.compile support.
-
- Args:
- M: Number of rows.
- N: Half column dim (output width).
- dtype: Torch dtype.
- strategy: Kernel strategy override.
- kernel_map: Optional kernel dispatch override.
- tune: Whether to autotune.
- """
-
- kernel_cls: type
- _op_name: str
- _wrapped = None # Set by _register_fused_gated_custom_op at class definition
-
- def __init__(
- self,
- M: int,
- N: int,
- dtype: torch.dtype,
- strategy: Optional[str] = None,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- supported = self.kernel_cls.SUPPORTED_DTYPES
- if supported is not None and dtype not in supported:
- names = ", ".join(str(dt) for dt in supported)
- raise ValueError(
- f"{self._op_name} does not support dtype {dtype}. "
- f"Supported: [{names}]"
- )
- self.M = M
- self.N = N
- self.dtype = dtype
- self.strategy = strategy
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map[self._op_name](
- M, N, dtype, strategy=strategy, tune=tune,
- )
- # Register in global registry for torch.compile dispatch
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self) -> Dict[str, Kernel]:
- return {self._op_name: self.kernel_cls}
-
- @property
- def total_memory(self) -> float:
- """Read x (M*2N) + write y (M*N)."""
- in_elem = self.dtype.itemsize
- fp8_out = getattr(self.kernel, "_fp8_output_dtype", None)
- out_elem = fp8_out.itemsize if fp8_out is not None else in_elem
- return self.M * 2 * self.N * in_elem + self.M * self.N * out_elem
-
- def _eager_forward(self, x: torch.Tensor) -> torch.Tensor:
- """Direct kernel call for use inside custom_op implementation."""
- x = x.contiguous()
- result = self.kernel(x)
- return _apply_fp8_post_cast(result, self.kernel)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- if not x.is_cuda:
- raise ValueError("Input must be a CUDA tensor")
- if x.dtype != self.dtype:
- raise ValueError(f"Expected x.dtype {self.dtype}, got {x.dtype}")
- if x.shape != (self.M, 2 * self.N):
- raise ValueError(
- f"Expected shape ({self.M}, {2 * self.N}), got {tuple(x.shape)}"
- )
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(x, self.M, self.N, self._instance_key)
- return self._eager_forward(x)
-
-
-# ---------------------------------------------------------------------------
-# Concrete op subclasses
-# ---------------------------------------------------------------------------
-
-
-class _UnaryActivationMixin:
- """Shared ``forward`` / inplace dispatch for unary activation Ops.
-
- The ten unary activation Ops (six param-free: ReLU, SiLU, HardSwish,
- HardSigmoid, Mish, SELU; four parametric: LeakyReLU, ELU, Hardtanh,
- Softplus) share an identical ``forward`` template:
-
- 1. validate ``input`` against the op's ``dtype`` / ``N_total`` contract,
- 2. when ``self.inplace`` is true, dispatch through ``_wrapped_inplace``
- (registered with ``mutates_args=("x",)`` so ``torch.compile`` traces
- the mutation correctly) and return the original ``input`` so callers
- see ``y is x``,
- 3. otherwise dispatch through the standard ``_wrapped`` custom op or
- fall back to ``_eager_forward``.
-
- Concrete classes provide ``_validate_input`` and ``_eager_forward``
- (both inherited from ``UnaryOp``) plus ``self.inplace`` /
- ``self._instance_key`` state. Leaves that do not expose ``inplace``
- in their signature (e.g. Softplus) simply default ``self.inplace`` to
- ``False`` via ``_finalize_init``.
- """
-
- # Set by ``_register_unary_inplace_custom_op`` for leaves that
- # declare ``inplace`` in their manifest signature. Stays ``None``
- # when the leaf does not support inplace (e.g. Softplus, or a
- # test-only subclass that skipped registration).
- _wrapped_inplace = None
-
- def forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002
- self._validate_input(input)
- if self.inplace:
- wrapped_inplace = type(self)._wrapped_inplace
- if wrapped_inplace is not None:
- wrapped_inplace(input, self._instance_key)
- return input
- # No inplace custom op registered (e.g. test-only subclass);
- # fall back to direct mutation via the eager path.
- result = self._eager_forward(input)
- input.copy_(result.reshape(input.shape))
- return input
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(input, self._instance_key)
- return self._eager_forward(input)
-
-
-class _ParamFreeActivationOp(_UnaryActivationMixin, UnaryOp):
- """Shared base for the param-free activation Op group.
-
- Centralizes the canonical constructor used by activations whose only
- manifest-declared parameter is ``inplace`` (ReLU, SiLU, HardSwish,
- HardSigmoid, Mish, SELU). Each leaf only declares its op-specific
- class fields (``_op_name``, ``kernel_cls``, ``FLOPS_PER_ELEM``,
- docstring); ``forward``/``_eager_forward`` come from
- ``_UnaryActivationMixin`` / ``UnaryOp``.
- """
-
- def __init__(
- self,
- N_total: int,
- dtype: torch.dtype,
- *,
- strategy: Optional[str] = None,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- inplace: bool = False,
- ):
- super().__init__(
- N_total, dtype, strategy=strategy, kernel_map=kernel_map, tune=tune,
- )
- self.inplace = inplace
-
-
-class _ParametricActivationOp(_UnaryActivationMixin, UnaryOp):
- """Shared base for the parametric activation Op group.
-
- Used by activations that take one or more scalar construction-time
- parameters (LeakyReLU, ELU, Hardtanh, Softplus). Leaves own their
- ``__init__`` (scalar parameter names and defaults vary per leaf):
- each leaf validates its scalars, populates ``self.`` for
- introspection, instantiates ``self.kernel`` with typed kwargs, and
- registers itself with ``_OP_REGISTRY`` via the
- ``_finalize_init`` helper. ``UnaryOp.__init__`` is intentionally
- bypassed; ``_finalize_init`` performs the equivalent state setup.
-
- Leaves that declare ``inplace`` in the manifest signature accept it
- in ``__init__`` and pass it to ``_finalize_init``. ``forward`` and
- ``_eager_forward`` are inherited from the mixin and ``UnaryOp``.
- """
-
- def _finalize_init(
- self,
- N_total: int,
- dtype: torch.dtype,
- kernel: Kernel,
- *,
- inplace: bool = False,
- ) -> None:
- """Record the leaf-built kernel and wire shared base state.
-
- The leaf has already called ``self.dispatch_kernel(kernel_map)``
- and instantiated its kernel directly with typed kwargs. This
- helper records the kernel on ``self`` and runs the
- ``_OP_REGISTRY`` registration shared by every parametric leaf.
- """
- self.N_total = N_total
- self.dtype = dtype
- self.inplace = inplace
- self.kernel = kernel
- # Mirror ``UnaryOp.__init__``: surface ``output_dtype`` so callers
- # and ``total_memory`` can reason about FP8 post-casts. Parametric
- # activations do not currently declare an FP8 path, so the common
- # branch returns ``self.dtype``; the lookup is kept for parity.
- fp8_out = getattr(self.kernel, "_fp8_output_dtype", None)
- self.output_dtype = fp8_out or getattr(self.kernel, "output_dtype", dtype)
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
-
-class ReluFwdOp(_ParamFreeActivationOp):
- """ReLU activation: y = max(x, 0)."""
-
- _op_name = "relu"
- kernel_cls = ReluFwdKernel
- # Manifest: flops = "2 * N" (compare + select per element).
- FLOPS_PER_ELEM = 2
-
-
-class _AlphaScaledBinaryOp(BinaryOp):
- """Shared base for ops that take a scalar ``alpha`` multiplier on ``other``.
-
- PyTorch ``torch.add(input, other, alpha=1)`` and ``torch.sub(input,
- other, alpha=1)`` scale ``other`` by ``alpha`` before the binary op.
- The current kernel only honors the manifest-declared default
- (``alpha == 1``); non-default ``alpha`` values raise
- ``NotImplementedError`` until a kernel-side scalar multiplier lands.
- The leading ``*`` makes ``alpha`` and the existing
- ``strategy`` / ``kernel_map`` / ``tune`` parameters keyword-only;
- only the positional triplet ``(a_shape, b_shape, dtype)`` is shared
- with ``BinaryOp``.
- """
-
- def __init__(
- self,
- a_shape: tuple,
- b_shape: tuple,
- dtype: torch.dtype,
- *,
- alpha: int | float = 1,
- strategy: Optional[str] = None,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- super().__init__(
- a_shape, b_shape, dtype, strategy=strategy,
- kernel_map=kernel_map, tune=tune,
- )
- self.alpha = alpha
-
-
-class AddFwdOp(_AlphaScaledBinaryOp):
- """Element-wise addition with broadcast: y = input + alpha * other.
-
- Conforms to ``torch.add(input, other, *, alpha=1)``. Only ``alpha == 1``
- dispatches to the kernel; non-default ``alpha`` raises
- ``NotImplementedError`` until a kernel-side scalar multiplier lands
- (tracked in a follow-up issue).
- """
-
- _op_name = "add"
- kernel_cls = AddFwdKernel
-
- def forward(
- self,
- input: torch.Tensor, # noqa: A002
- other: torch.Tensor,
- ) -> torch.Tensor:
- if self.alpha != 1:
- raise NotImplementedError(
- "AddFwdOp(alpha != 1) is not yet implemented; the current "
- "kernel only honors alpha == 1. A follow-up "
- "issue tracks the kernel work."
- )
- return super().forward(input, other)
-
-
-class SubFwdOp(_AlphaScaledBinaryOp):
- """Element-wise subtraction with broadcast: y = input - alpha * other.
-
- Conforms to ``torch.sub(input, other, *, alpha=1)``. Only ``alpha == 1``
- dispatches to the kernel; non-default ``alpha`` raises
- ``NotImplementedError`` until a kernel-side scalar multiplier lands
- (tracked in a follow-up issue).
- """
-
- _op_name = "sub"
- kernel_cls = SubFwdKernel
-
- def forward(
- self,
- input: torch.Tensor, # noqa: A002
- other: torch.Tensor,
- ) -> torch.Tensor:
- if self.alpha != 1:
- raise NotImplementedError(
- "SubFwdOp(alpha != 1) is not yet implemented; the current "
- "kernel only honors alpha == 1. A follow-up "
- "issue tracks the kernel work."
- )
- return super().forward(input, other)
-
-
-class MulFwdOp(BinaryOp):
- """Element-wise multiplication with broadcast: y = input * other."""
-
- _op_name = "mul"
- kernel_cls = MulFwdKernel
-
-
-class DivFwdOp(BinaryOp):
- """Element-wise division with broadcast: y = input / other.
-
- Conforms to ``torch.div(input, other, *, rounding_mode=None)``.
- ``rounding_mode`` accepts ``None`` (true division), ``"trunc"``
- (truncation toward zero), or ``"floor"`` (floor division). Only
- ``rounding_mode is None`` dispatches to the kernel; the trunc /
- floor variants raise ``NotImplementedError`` until a rounded-divide
- kernel lands (tracked in a follow-up issue). The leading ``*``
- makes ``rounding_mode`` and the existing ``strategy`` /
- ``kernel_map`` / ``tune`` parameters keyword-only; only the
- positional triplet ``(a_shape, b_shape, dtype)`` is shared with
- ``BinaryOp``.
- """
-
- _op_name = "div"
- kernel_cls = DivFwdKernel
-
- def __init__(
- self,
- a_shape: tuple,
- b_shape: tuple,
- dtype: torch.dtype,
- *,
- rounding_mode: Optional[str] = None,
- strategy: Optional[str] = None,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- if rounding_mode is not None and rounding_mode not in ("trunc", "floor"):
- raise ValueError(
- f"DivFwdOp received rounding_mode={rounding_mode!r}; "
- "manifest allows None, 'trunc', or 'floor'"
- )
- super().__init__(
- a_shape, b_shape, dtype, strategy=strategy,
- kernel_map=kernel_map, tune=tune,
- )
- self.rounding_mode = rounding_mode
-
- def forward(
- self,
- input: torch.Tensor, # noqa: A002
- other: torch.Tensor,
- ) -> torch.Tensor:
- if self.rounding_mode is not None:
- raise NotImplementedError(
- f"DivFwdOp(rounding_mode={self.rounding_mode!r}) is not yet "
- "implemented; the current kernel only honors rounding_mode is "
- "None. A follow-up issue tracks the kernel work."
- )
- return super().forward(input, other)
-
-
-class RemainderFwdOp(BinaryOp):
- """Element-wise remainder with broadcast: y = a % b."""
-
- _op_name = "remainder"
- kernel_cls = RemainderFwdKernel
-
-
-class PowFwdOp(BinaryOp):
- """Element-wise power with broadcast: y = input ** exponent.
-
- Conforms to ``torch.pow(input, exponent)``: the second operand carries
- the manifest-declared name ``exponent`` rather than the generic
- ``other`` so the L1 signature check matches the manifest.
- """
-
- _op_name = "pow"
- kernel_cls = PowFwdKernel
- _other_name = "exponent"
-
-
-class FloorDivideFwdOp(BinaryOp):
- """Element-wise floor division with broadcast: y = floor(a / b)."""
-
- _op_name = "floor_divide"
- kernel_cls = FloorDivideFwdKernel
-
-
-class LerpFwdOp(BinaryOp):
- """Element-wise lerp with broadcast: y = a + weight * (b - a).
-
- Unlike ``torch.lerp(a, b, weight)`` where weight is a runtime parameter,
- here weight is a **construction-time constant** baked into the compiled
- kernel. This enables compile-time folding but means a new Op instance is
- needed for each distinct weight value.
-
- Args:
- a_shape: Shape of input a.
- b_shape: Shape of input b.
- dtype: Torch dtype.
- weight: Scalar interpolation weight, fixed at construction (default 0.5).
- strategy: Kernel strategy override.
- kernel_map: Optional kernel dispatch override.
- tune: Whether to autotune.
- """
-
- _op_name = "lerp"
- kernel_cls = LerpFwdKernel
- _other_name = "end"
-
- def __init__(
- self,
- a_shape: tuple,
- b_shape: tuple,
- dtype: torch.dtype,
- weight: float = 0.5,
- strategy: Optional[str] = None,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- supported = self.kernel_cls.SUPPORTED_DTYPES
- if supported is not None and dtype not in supported:
- names = ", ".join(str(dt) for dt in supported)
- raise ValueError(
- f"{self._op_name} does not support dtype {dtype}. "
- f"Supported: [{names}]"
- )
- self.dtype = dtype
- self.a_shape = tuple(a_shape)
- self.b_shape = tuple(b_shape)
- self.strategy = strategy
- self._weight = weight
- out_shape, coalesced_shape, a_strides, b_strides = coalesce_broadcast_dims(
- a_shape, b_shape,
- )
- self.out_shape = out_shape
- self._out_shape_list = list(out_shape) # cached for custom_op hot path
- self.N_total = prod(out_shape)
- self.a_numel = prod(a_shape)
- self.b_numel = prod(b_shape)
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map[self._op_name](
- self.N_total, dtype, coalesced_shape, a_strides, b_strides,
- self.a_numel, self.b_numel, strategy=strategy, tune=tune,
- weight=weight,
- )
- # Register in global registry for torch.compile dispatch
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
-
-class MaximumFwdOp(BinaryOp):
- """Element-wise maximum with broadcast: y = max(a, b)."""
-
- _op_name = "maximum"
- kernel_cls = MaximumFwdKernel
-
-
-class MinimumFwdOp(BinaryOp):
- """Element-wise minimum with broadcast: y = min(a, b)."""
-
- _op_name = "minimum"
- kernel_cls = MinimumFwdKernel
-
-
-# ---------------------------------------------------------------------------
-# Comparison op subclasses (output bool)
-# ---------------------------------------------------------------------------
-#
-# Kernels produce int8 (1/0) because TileLang cannot vectorize bool.
-# The Op forward() casts to torch.bool after the kernel call.
-
-
-class _BoolOutputBinaryOp(BinaryOp):
- """Binary op base whose kernel emits int8 (1/0) and whose Op output is bool.
-
- TileLang cannot vectorize bool, so the kernel produces int8. The Op
- casts to ``torch.bool`` after the kernel call. ``register_fake``
- already declares ``torch.bool`` as the output dtype, so the
- ``torch.compile`` path stays consistent.
- """
-
- def _eager_forward(
- self,
- input: torch.Tensor, # noqa: A002 — manifest-aligned PyTorch param name
- other: torch.Tensor,
- ) -> torch.Tensor:
- result = super()._eager_forward(input, other)
- return result.to(torch.bool)
-
-
-class EqFwdOp(_BoolOutputBinaryOp):
- """Element-wise equality with broadcast: y = (a == b)."""
-
- _op_name = "eq"
- kernel_cls = EqFwdKernel
-
-
-class NeFwdOp(_BoolOutputBinaryOp):
- """Element-wise not-equal with broadcast: y = (a != b)."""
-
- _op_name = "ne"
- kernel_cls = NeFwdKernel
-
-
-class GtFwdOp(_BoolOutputBinaryOp):
- """Element-wise greater-than with broadcast: y = (a > b)."""
-
- _op_name = "gt"
- kernel_cls = GtFwdKernel
-
-
-class LtFwdOp(_BoolOutputBinaryOp):
- """Element-wise less-than with broadcast: y = (a < b)."""
-
- _op_name = "lt"
- kernel_cls = LtFwdKernel
-
-
-class GeFwdOp(_BoolOutputBinaryOp):
- """Element-wise greater-equal with broadcast: y = (a >= b)."""
-
- _op_name = "ge"
- kernel_cls = GeFwdKernel
-
-
-class LeFwdOp(_BoolOutputBinaryOp):
- """Element-wise less-equal with broadcast: y = (a <= b)."""
-
- _op_name = "le"
- kernel_cls = LeFwdKernel
-
-
-# ---------------------------------------------------------------------------
-# Logical op subclasses (output bool)
-# ---------------------------------------------------------------------------
-
-
-class LogicalAndFwdOp(_BoolOutputBinaryOp):
- """Element-wise logical AND with broadcast using non-zero truthiness."""
-
- _op_name = "logical_and"
- kernel_cls = LogicalAndFwdKernel
-
-
-class LogicalOrFwdOp(_BoolOutputBinaryOp):
- """Element-wise logical OR with broadcast using non-zero truthiness."""
-
- _op_name = "logical_or"
- kernel_cls = LogicalOrFwdKernel
-
-
-# ---------------------------------------------------------------------------
-# Bitwise op subclasses
-# ---------------------------------------------------------------------------
-
-
-class BitwiseAndFwdOp(BinaryOp):
- """Element-wise bitwise AND with broadcast: y = a & b."""
-
- _op_name = "bitwise_and"
- kernel_cls = BitwiseAndFwdKernel
-
-
-class BitwiseOrFwdOp(BinaryOp):
- """Element-wise bitwise OR with broadcast: y = a | b."""
-
- _op_name = "bitwise_or"
- kernel_cls = BitwiseOrFwdKernel
-
-
-class BitwiseXorFwdOp(BinaryOp):
- """Element-wise bitwise XOR with broadcast: y = a ^ b."""
-
- _op_name = "bitwise_xor"
- kernel_cls = BitwiseXorFwdKernel
-
-
-# ---------------------------------------------------------------------------
-# Fused gated op subclasses
-# ---------------------------------------------------------------------------
-
-
-class SiluAndMulFwdOp(FusedGatedOp):
- """SiLU-and-Mul: y = silu(gate) * value."""
-
- _op_name = "silu_and_mul"
- kernel_cls = SiluAndMulFwdKernel
-
-
-class GeluAndMulFwdOp(FusedGatedOp):
- """GELU-and-Mul: y = gelu(gate) * value (exact GELU)."""
-
- _op_name = "gelu_and_mul"
- kernel_cls = GeluAndMulFwdKernel
-
-
-class GeluTanhAndMulFwdOp(FusedGatedOp):
- """GELU-Tanh-and-Mul: y = gelu_tanh(gate) * value (tanh approximation)."""
-
- _op_name = "gelu_tanh_and_mul"
- kernel_cls = GeluTanhAndMulFwdKernel
-
-
-# ---------------------------------------------------------------------------
-# Unary math ops (17)
-# ---------------------------------------------------------------------------
-
-
-class ExpFwdOp(UnaryOp):
- """Element-wise exp(x)."""
-
- _op_name = "exp"
- kernel_cls = ExpFwdKernel
-
-
-class LogFwdOp(UnaryOp):
- """Element-wise log(x)."""
-
- _op_name = "log"
- kernel_cls = LogFwdKernel
-
-
-class SqrtFwdOp(UnaryOp):
- """Element-wise sqrt(x)."""
-
- _op_name = "sqrt"
- kernel_cls = SqrtFwdKernel
-
-
-class RsqrtFwdOp(UnaryOp):
- """Element-wise 1/sqrt(x)."""
-
- _op_name = "rsqrt"
- kernel_cls = RsqrtFwdKernel
-
-
-_MANIFEST_INT_DTYPES = (
- torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
-)
-
-
-def _int_identity(input: torch.Tensor) -> torch.Tensor:
- return input.clone()
-
-
-class _IntIdentityUnaryOp(UnaryOp):
- """Base for unary ops whose manifest declares integer dtypes but whose
- kernel is float-only.
-
- Several manifest entries (floor / ceil / round / trunc, abs / neg / sign,
- isnan / isinf / isfinite) declare both integer and floating-point input
- dtypes, while the underlying ``*FwdKernel`` classes are float-only
- (``FloatUnaryKernel``). For integer inputs we short-circuit at the op
- layer: skip kernel construction in ``__init__`` and route through
- ``_int_handler`` in ``_eager_forward``.
-
- Subclasses override ``_int_handler`` (default = identity = ``input.clone()``)
- and ``_int_output_dtype`` (default = same as input) to express the
- appropriate integer semantics — e.g. ``torch.abs`` for ``AbsFwdOp``,
- constant-False ``torch.bool`` for ``IsnanFwdOp``.
-
- The short-circuit is restricted to the integer dtypes declared in the
- manifest. Other non-float dtypes (bool, complex) are not in the
- contract and fall through to ``UnaryOp.__init__``, which raises via the
- kernel's dtype check.
- """
-
- _int_handler: Callable[[torch.Tensor], torch.Tensor] = staticmethod(
- _int_identity)
- _int_output_dtype: Optional[torch.dtype] = None
- # Subclasses may extend the fallback dtype set when the manifest
- # signature includes additional non-float dtypes (e.g. torch.bool for
- # the is{nan,inf,finite} predicates).
- _fallback_dtypes: tuple = _MANIFEST_INT_DTYPES
-
- def __init__(
- self,
- N_total: int,
- dtype: torch.dtype,
- strategy: Optional[str] = None,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- if dtype in type(self)._fallback_dtypes:
- self.N_total = N_total
- self.dtype = dtype
- self.strategy = strategy
- # The float-only kernel cannot be instantiated for an integer
- # dtype, so the kernel itself stays unconstructed. The kernel_map
- # is still installed through the shared validate-and-install path
- # so a user-supplied override is arch-checked identically to the
- # auto-discovered map on the float path.
- self._install_kernel_map(kernel_map)
- self.kernel = None
- self.output_dtype = (
- type(self)._int_output_dtype
- if type(self)._int_output_dtype is not None
- else dtype
- )
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
- return
- super().__init__(
- N_total, dtype, strategy=strategy, kernel_map=kernel_map, tune=tune,
- )
-
- def _eager_forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002
- if self.kernel is None:
- return type(self)._int_handler(input)
- return super()._eager_forward(input)
-
-
-class AbsFwdOp(_IntIdentityUnaryOp):
- """Element-wise |x|."""
-
- _op_name = "abs"
- kernel_cls = AbsFwdKernel
- _int_handler = staticmethod(torch.abs)
-
-
-class NegFwdOp(_IntIdentityUnaryOp):
- """Element-wise -x."""
-
- _op_name = "neg"
- kernel_cls = NegFwdKernel
- _int_handler = staticmethod(torch.neg)
-
-
-class ReciprocalFwdOp(UnaryOp):
- """Element-wise 1/x.
-
- Mirrors ``torch.reciprocal`` int-input promotion: integral dtypes
- (uint8 / int8 / int16 / int32 / int64) are cast to float32 before the
- float kernel runs, and the op's ``output_dtype`` is float32 in that
- case. Floating inputs (float16 / bfloat16 / float32) follow the
- standard same-dtype path.
- """
-
- _op_name = "reciprocal"
- kernel_cls = ReciprocalFwdKernel
-
- def __init__(
- self,
- N_total: int,
- dtype: torch.dtype,
- strategy: Optional[str] = None,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- if dtype in _MANIFEST_INT_DTYPES:
- # Build the kernel against the promoted compute dtype (float32)
- # so the float-only ReciprocalFwdKernel can run, then restore
- # the user-declared dtype on ``self.dtype`` so metadata and
- # ``eval_roofline`` reflect the real I/O contract: integer
- # input bytes + float32 output bytes. ``self.output_dtype``
- # stays float32 (set by the kernel) per the manifest's
- # ``promote_int_to_float`` contract.
- super().__init__(
- N_total, torch.float32, strategy=strategy,
- kernel_map=kernel_map, tune=tune,
- )
- self.dtype = dtype
- else:
- super().__init__(
- N_total, dtype, strategy=strategy,
- kernel_map=kernel_map, tune=tune,
- )
-
- def forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002
- if self.dtype in _MANIFEST_INT_DTYPES:
- self._validate_input(input)
- promoted = input.to(torch.float32)
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(promoted, self._instance_key)
- return self._eager_forward(promoted)
- return super().forward(input)
-
-
-class SignFwdOp(_IntIdentityUnaryOp):
- """Element-wise sign(x): -1, 0, or +1."""
-
- _op_name = "sign"
- kernel_cls = SignFwdKernel
- # Manifest: flops = "2 * N" (two compares + selects per element).
- FLOPS_PER_ELEM = 2
- _int_handler = staticmethod(torch.sign)
-
-
-class SinFwdOp(UnaryOp):
- """Element-wise sin(x)."""
-
- _op_name = "sin"
- kernel_cls = SinFwdKernel
-
-
-class CosFwdOp(UnaryOp):
- """Element-wise cos(x)."""
-
- _op_name = "cos"
- kernel_cls = CosFwdKernel
-
-
-class FloorFwdOp(_IntIdentityUnaryOp):
- """Element-wise floor(x)."""
-
- _op_name = "floor"
- kernel_cls = FloorFwdKernel
-
-
-class CeilFwdOp(_IntIdentityUnaryOp):
- """Element-wise ceil(x)."""
-
- _op_name = "ceil"
- kernel_cls = CeilFwdKernel
-
-
-class RoundFwdOp(_IntIdentityUnaryOp):
- """Element-wise round(x) to ``decimals`` decimal places.
-
- The underlying kernel performs banker's round-to-nearest-integer, matching
- ``torch.round`` for ``decimals=0``. Non-zero ``decimals`` is supported at
- the op layer via the standard decomposition:
- ``round(x, decimals=k) == round(x * 10**k) / 10**k``.
-
- Args:
- N_total: Total number of elements (flattened).
- dtype: Torch dtype.
- strategy: Kernel strategy override.
- kernel_map: Optional kernel dispatch override.
- tune: Whether to autotune.
- """
-
- _op_name = "round"
- kernel_cls = RoundFwdKernel
-
- def forward( # noqa: A002
- self, input: torch.Tensor, decimals: int = 0,
- ) -> torch.Tensor:
- if decimals == 0:
- return super().forward(input)
- # Non-zero decimals path still owes the same input contract as the
- # ``decimals=0`` fast path (UnaryOp.forward). Run the shared validator
- # before any fp32 arithmetic so a CPU tensor / wrong dtype / wrong
- # numel cannot silently bypass the checks.
- self._validate_input(input)
- # Integer dtypes are no-ops regardless of decimals (rounding an int
- # produces the same int). Match the float-path identity contract.
- if self.dtype in _MANIFEST_INT_DTYPES:
- return input.clone()
- # Run through fp32 so low-precision inputs (fp16/bf16) cannot overflow
- # when ``torch.round`` internally scales by ``10**decimals`` — e.g.
- # ``100 * 10**4 = 1e6`` exceeds fp16 max (~65504). The single down-cast
- # at the end restores the op's contract dtype. The manifest's
- # ``kernel_map`` continues to describe the round-to-nearest-integer
- # kernel that handles the ``decimals=0`` fast path above.
- return torch.round(input.float(), decimals=decimals).to(self.dtype)
-
-
-class TruncFwdOp(_IntIdentityUnaryOp):
- """Element-wise trunc(x)."""
-
- _op_name = "trunc"
- kernel_cls = TruncFwdKernel
-
-
-class ErfFwdOp(UnaryOp):
- """Element-wise erf(x)."""
-
- _op_name = "erf"
- kernel_cls = ErfFwdKernel
-
-
-class Log1pFwdOp(UnaryOp):
- """Element-wise log(1 + x)."""
-
- _op_name = "log1p"
- kernel_cls = Log1pFwdKernel
- # Manifest: flops = "2 * N" (1 add + 1 log).
- FLOPS_PER_ELEM = 2
-
-
-class Expm1FwdOp(UnaryOp):
- """Element-wise exp(x) - 1."""
-
- _op_name = "expm1"
- kernel_cls = Expm1FwdKernel
- # Manifest: flops = "2 * N" (1 exp + 1 sub).
- FLOPS_PER_ELEM = 2
-
-
-# ---------------------------------------------------------------------------
-# Activation ops (8)
-# ---------------------------------------------------------------------------
-
-
-class _GeluApproximateBase(UnaryOp):
- """Intermediate base that resolves the manifest ``approximate`` field.
-
- Validates the ``approximate`` argument against the manifest's allowed
- values (``'none'`` / ``'tanh'``), records it on ``self.approximate``
- for introspection, and then delegates to ``UnaryOp.__init__``. The
- ``default_kernel_map`` of the leaf op picks the kernel implementation
- from ``self.approximate``.
- """
-
- def __init__(
- self,
- N_total: int,
- dtype: torch.dtype,
- *,
- approximate: str = "none",
- strategy: Optional[str] = None,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- if approximate not in ("none", "tanh"):
- raise ValueError(
- f"{type(self).__name__}: approximate must be 'none' or "
- f"'tanh', got {approximate!r}"
- )
- self.approximate = approximate
- super().__init__(
- N_total, dtype, strategy=strategy, kernel_map=kernel_map, tune=tune,
- )
-
-
-class GeluFwdOp(_GeluApproximateBase):
- """Element-wise GELU honoring the manifest ``approximate`` contract.
-
- Args:
- N_total: Number of elements (flattened input).
- dtype: Torch dtype.
- approximate: Approximation mode. ``'none'`` (default) routes to
- the erf-based ``GeluFwdKernel``. ``'tanh'`` routes to
- ``GeluTanhFwdKernel`` (the fused tanh approximation
- ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``).
- strategy: Optional kernel strategy override.
- kernel_map: Optional kernel dispatch override.
- tune: Whether to autotune the kernel.
- """
-
- _op_name = "gelu"
- kernel_cls = GeluFwdKernel
- # Manifest: flops = "8 * N" (erf-based: mul + erf + add + mul + mul ≈ 8;
- # tanh approximation is similar order, see manifest comment).
- FLOPS_PER_ELEM = 8
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- @property
- def default_kernel_map(self) -> Dict[str, Kernel]:
- kernel_cls = (
- GeluTanhFwdKernel if self.approximate == "tanh" else GeluFwdKernel
- )
- return {self._op_name: kernel_cls}
-
-
-class SiluFwdOp(_ParamFreeActivationOp):
- """Element-wise SiLU (Swish): y = x * sigmoid(x)."""
-
- _op_name = "silu"
- kernel_cls = SiluFwdKernel
- # Manifest: flops = "4 * N" (sigmoid + multiply).
- FLOPS_PER_ELEM = 4
-
-
-class SigmoidFwdOp(UnaryOp):
- """Element-wise sigmoid(x)."""
-
- _op_name = "sigmoid"
- kernel_cls = SigmoidFwdKernel
- # Manifest: flops = "4 * N" (sigmoid(x) = 1 / (1 + exp(-x)) ≈ 4 ops/elem).
- FLOPS_PER_ELEM = 4
-
-
-class TanhFwdOp(UnaryOp):
- """Element-wise tanh(x)."""
-
- _op_name = "tanh"
- kernel_cls = TanhFwdKernel
- # Manifest: flops = "5 * N" (tanh(x) = 2 * sigmoid(2x) - 1 ≈ 5 ops/elem).
- FLOPS_PER_ELEM = 5
-
-
-class HardswishFwdOp(_ParamFreeActivationOp):
- """Element-wise HardSwish: y = x * clamp(x + 3, 0, 6) / 6."""
-
- _op_name = "hardswish"
- kernel_cls = HardswishFwdKernel
- # Manifest: flops = "7 * N" (add + clamp(2 cmp+2 sel) + mul + div).
- FLOPS_PER_ELEM = 7
-
-
-class HardsigmoidFwdOp(_ParamFreeActivationOp):
- """Element-wise HardSigmoid: y = clamp(x + 3, 0, 6) / 6."""
-
- _op_name = "hardsigmoid"
- kernel_cls = HardsigmoidFwdKernel
- # Manifest: flops = "6 * N" (add + clamp(2 cmp+2 sel) + div).
- FLOPS_PER_ELEM = 6
-
-
-class MishFwdOp(_ParamFreeActivationOp):
- """Element-wise Mish: y = x * tanh(softplus(x))."""
-
- _op_name = "mish"
- kernel_cls = MishFwdKernel
- # Manifest: flops = "7 * N" (softplus + tanh + mul).
- FLOPS_PER_ELEM = 7
-
-
-class SeluFwdOp(_ParamFreeActivationOp):
- """Element-wise SELU activation."""
-
- _op_name = "selu"
- kernel_cls = SeluFwdKernel
- # Manifest: flops = "5 * N" (branch + exp/sub/mul + lambda mul).
- FLOPS_PER_ELEM = 5
-
-
-# ---------------------------------------------------------------------------
-# Logical op (1)
-# ---------------------------------------------------------------------------
-
-
-class LogicalNotFwdOp(UnaryOp):
- """Element-wise logical NOT with bool output."""
-
- _op_name = "logical_not"
- kernel_cls = LogicalNotFwdKernel
-
-
-# ---------------------------------------------------------------------------
-# Bitwise op (1)
-# ---------------------------------------------------------------------------
-
-
-class BitwiseNotFwdOp(UnaryOp):
- """Element-wise bitwise NOT (~x) for bool/integer inputs."""
-
- _op_name = "bitwise_not"
- kernel_cls = BitwiseNotFwdKernel
-
-
-# ---------------------------------------------------------------------------
-# Special predicate ops (3)
-# ---------------------------------------------------------------------------
-
-
-def _int_all_false(input: torch.Tensor) -> torch.Tensor:
- return torch.zeros(input.shape, dtype=torch.bool, device=input.device)
-
-
-def _int_all_true(input: torch.Tensor) -> torch.Tensor:
- return torch.ones(input.shape, dtype=torch.bool, device=input.device)
-
-
-_PREDICATE_FALLBACK_DTYPES = _MANIFEST_INT_DTYPES + (torch.bool,)
-
-
-class IsnanFwdOp(_IntIdentityUnaryOp):
- """Element-wise isnan with bool output.
-
- Always False on integer / bool input (no NaN representation in those
- dtypes).
- """
-
- _op_name = "isnan"
- kernel_cls = IsnanFwdKernel
- _int_handler = staticmethod(_int_all_false)
- _int_output_dtype = torch.bool
- _fallback_dtypes = _PREDICATE_FALLBACK_DTYPES
-
-
-class IsinfFwdOp(_IntIdentityUnaryOp):
- """Element-wise isinf with bool output.
-
- Always False on integer / bool input (no Inf representation in those
- dtypes).
- """
-
- _op_name = "isinf"
- kernel_cls = IsinfFwdKernel
- _int_handler = staticmethod(_int_all_false)
- _int_output_dtype = torch.bool
- _fallback_dtypes = _PREDICATE_FALLBACK_DTYPES
-
-
-class IsfiniteFwdOp(_IntIdentityUnaryOp):
- """Element-wise isfinite with bool output.
-
- Always True on integer / bool input (every value in those dtypes is
- finite).
- """
-
- _op_name = "isfinite"
- kernel_cls = IsfiniteFwdKernel
- _int_handler = staticmethod(_int_all_true)
- _int_output_dtype = torch.bool
- _fallback_dtypes = _PREDICATE_FALLBACK_DTYPES
-
-
-# ---------------------------------------------------------------------------
-# Independent (custom-signature) op classes (11)
-# ---------------------------------------------------------------------------
-
-
-class LeakyReluFwdOp(_ParametricActivationOp):
- """Leaky ReLU: y = x if x > 0 else negative_slope * x.
-
- Args:
- N_total: Total number of elements (flattened).
- dtype: Torch dtype.
- negative_slope: Slope for negative inputs (default 0.01).
- inplace: When True, copy the result back into ``input`` and
- return ``input`` (preserving tensor identity). The kernel
- still computes into a fresh buffer; only the user-visible
- tensor is mutated, mirroring ``torch.nn.functional.leaky_relu``.
- kernel_map: Optional kernel dispatch override.
- tune: Whether to autotune the kernel.
- """
-
- _op_name = "leaky_relu"
- _wrapped = None
- # Manifest: flops = "3 * N" (compare + mul + select).
- FLOPS_PER_ELEM = 3
-
- def __init__(
- self,
- N_total: int,
- dtype: torch.dtype,
- negative_slope: float = 0.01,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- inplace: bool = False,
- ):
- _validate_scalar_param_repr("negative_slope", negative_slope, dtype, self._op_name)
- self.negative_slope = negative_slope
- self.dispatch_kernel(kernel_map)
- kernel = self.kernel_map[self._op_name](
- N_total, dtype, negative_slope=negative_slope, tune=tune,
- )
- self._finalize_init(N_total, dtype, kernel, inplace=inplace)
-
- @property
- def default_kernel_map(self):
- return {"leaky_relu": LeakyReluFwdKernel}
-
-
-class EluFwdOp(_ParametricActivationOp):
- """ELU: y = x if x > 0 else alpha * (exp(x) - 1).
-
- Args:
- N_total: Total number of elements (flattened).
- dtype: Torch dtype.
- alpha: Scale for the negative part (default 1.0).
- inplace: When True, copy the result back into ``input`` and
- return ``input`` (preserving tensor identity).
- kernel_map: Optional kernel dispatch override.
- tune: Whether to autotune the kernel.
- """
-
- _op_name = "elu"
- _wrapped = None
- # Manifest: flops = "5 * N" (compare + (exp + sub + mul) + branch select).
- FLOPS_PER_ELEM = 5
-
- def __init__(
- self,
- N_total: int,
- dtype: torch.dtype,
- alpha: float = 1.0,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- inplace: bool = False,
- ):
- _validate_scalar_param_repr("alpha", alpha, dtype, self._op_name)
- self.alpha = alpha
- self.dispatch_kernel(kernel_map)
- kernel = self.kernel_map[self._op_name](
- N_total, dtype, alpha=alpha, tune=tune,
- )
- self._finalize_init(N_total, dtype, kernel, inplace=inplace)
-
- @property
- def default_kernel_map(self):
- return {"elu": EluFwdKernel}
-
-
-class HardtanhFwdOp(_ParametricActivationOp):
- """Hardtanh: y = clamp(x, min_val, max_val).
-
- Args:
- N_total: Total number of elements (flattened).
- dtype: Torch dtype.
- min_val: Lower bound (default -1.0).
- max_val: Upper bound (default 1.0).
- inplace: When True, copy the result back into ``input`` and
- return ``input`` (preserving tensor identity).
- kernel_map: Optional kernel dispatch override.
- tune: Whether to autotune the kernel.
- """
-
- _op_name = "hardtanh"
- _wrapped = None
- # Manifest: flops = "4 * N" (2 compares + 2 selects per element).
- FLOPS_PER_ELEM = 4
-
- def __init__(
- self,
- N_total: int,
- dtype: torch.dtype,
- min_val: float = -1.0,
- max_val: float = 1.0,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- inplace: bool = False,
- ):
- _validate_scalar_param_repr("min_val", min_val, dtype, self._op_name)
- _validate_scalar_param_repr("max_val", max_val, dtype, self._op_name)
- self.min_val = min_val
- self.max_val = max_val
- self.dispatch_kernel(kernel_map)
- kernel = self.kernel_map[self._op_name](
- N_total, dtype, min_val=min_val, max_val=max_val, tune=tune,
- )
- self._finalize_init(N_total, dtype, kernel, inplace=inplace)
-
- @property
- def default_kernel_map(self):
- return {"hardtanh": HardtanhFwdKernel}
-
-
-class SoftplusFwdOp(_ParametricActivationOp):
- """Softplus: y = log(1 + exp(x*beta))/beta if x*beta <= threshold else x.
-
- Args:
- N_total: Total number of elements (flattened).
- dtype: Torch dtype.
- beta: Scaling factor (default 1.0).
- threshold: Linear regime threshold (default 20.0).
- kernel_map: Optional kernel dispatch override.
- tune: Whether to autotune the kernel.
- """
-
- _op_name = "softplus"
- _wrapped = None
- # Manifest: flops = "7 * N" (mul + exp + add + log + div + compare + select).
- FLOPS_PER_ELEM = 7
-
- def __init__(
- self,
- N_total: int,
- dtype: torch.dtype,
- beta: float = 1.0,
- threshold: float = 20.0,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- _validate_scalar_param_repr("beta", beta, dtype, self._op_name)
- _validate_scalar_param_repr("threshold", threshold, dtype, self._op_name)
- self.beta = beta
- self.threshold = threshold
- self.dispatch_kernel(kernel_map)
- kernel = self.kernel_map[self._op_name](
- N_total, dtype, beta=beta, threshold=threshold, tune=tune,
- )
- # Softplus does not expose ``inplace`` to callers; default to False.
- self._finalize_init(N_total, dtype, kernel, inplace=False)
-
- @property
- def default_kernel_map(self):
- return {"softplus": SoftplusFwdKernel}
-
-
-class PreluFwdOp(Op):
- """PReLU: y = x if x > 0 else weight[channel] * x.
-
- Channel dimension follows PyTorch convention: dimension 1 for inputs
- with ndim >= 2, dimension 0 for 1-D inputs.
-
- Args:
- shape: Shape of the input tensor (must have a channel dimension).
- dtype: Torch dtype.
- num_channels: Number of channels (weight length).
- kernel_map: Optional dispatch override mapping kernel keys to
- ``Kernel`` subclasses. Falls back to ``default_kernel_map``.
- """
-
- _op_name = "prelu"
- _wrapped = None
-
- def __init__(
- self,
- shape: tuple,
- dtype: torch.dtype,
- num_channels: int,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- ):
- self.shape = shape
- self.dtype = dtype
- self.num_channels = num_channels
- N_total = prod(shape)
- self.N_total = N_total
- # PyTorch PReLU: channel dim is 1 for ndim>=2, else 0
- inner_size = (prod(shape[2:]) if len(shape) > 2 else 1) if len(shape) >= 2 else 1
- self.inner_size = inner_size
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map[self._op_name](N_total, num_channels, inner_size, dtype)
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self):
- return {"prelu": PreluFwdKernel}
-
- def _eager_forward(
- self,
- input: torch.Tensor, # noqa: A002 — manifest-aligned PyTorch param name
- weight: torch.Tensor,
- ) -> torch.Tensor:
- orig_shape = input.shape
- result = self.kernel(
- input.contiguous().reshape(-1), weight.contiguous().reshape(-1),
- ).reshape(orig_shape)
- return _apply_fp8_post_cast(result, self.kernel)
-
- def forward(
- self,
- input: torch.Tensor, # noqa: A002 — manifest-aligned PyTorch param name
- weight: torch.Tensor,
- ) -> torch.Tensor:
- if not input.is_cuda:
- raise ValueError("Input must be a CUDA tensor")
- if input.dtype != self.dtype:
- raise ValueError(f"Expected input.dtype {self.dtype}, got {input.dtype}")
- if input.numel() != self.N_total:
- raise ValueError(f"Expected {self.N_total} elements, got {input.numel()}")
- # ``weight`` is part of the manifest contract; validate device,
- # dtype, and length so a malformed weight fails fast at the op
- # boundary instead of corrupting the kernel.
- if not weight.is_cuda:
- raise ValueError("Weight must be a CUDA tensor")
- if weight.dtype != self.dtype:
- raise ValueError(
- f"Expected weight.dtype {self.dtype}, got {weight.dtype}"
- )
- if weight.numel() != self.num_channels:
- raise ValueError(
- f"Expected weight to have {self.num_channels} elements, "
- f"got {weight.numel()}"
- )
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(input, weight, self._instance_key)
- return self._eager_forward(input, weight)
-
-
-class WhereFwdOp(Op):
- """Where: out = condition ? input : other (with full PyTorch broadcasting).
-
- Conforms to ``torch.where(condition, input, other)``: ``condition`` is a
- bool tensor and ``input`` / ``other`` may broadcast with each other and
- with ``condition`` to produce the output. The Op layer expands all
- three inputs to the broadcast shape and dispatches the existing flat
- where kernel on ``N_total = product(broadcast_shape)`` elements.
-
- Args:
- condition: Shape of the condition tensor (any shape broadcastable
- with ``input`` / ``other``).
- input: Shape of the value-when-true tensor.
- other: Shape of the value-when-false tensor.
- dtype: Torch dtype for ``input`` / ``other``.
- kernel_map: Optional dispatch override mapping kernel keys to
- ``Kernel`` subclasses. Falls back to ``default_kernel_map``.
- """
-
- _op_name = "where"
- _wrapped = None
-
- # Manifest declares ``input`` / ``other`` dtype as
- # ``float16 | bfloat16 | float32``. fp8 dtypes are not in the contract;
- # reject them at the op-layer signature so the impl matches the manifest.
- _SUPPORTED_DTYPES = (torch.float16, torch.bfloat16, torch.float32)
-
- def __init__(
- self,
- condition: tuple,
- input: tuple, # noqa: A002 — manifest-aligned PyTorch param name
- other: tuple,
- dtype: torch.dtype,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- ):
- if dtype not in self._SUPPORTED_DTYPES:
- names = ", ".join(str(dt) for dt in self._SUPPORTED_DTYPES)
- raise ValueError(
- f"WhereFwdOp does not support dtype {dtype}. "
- f"Supported: [{names}]"
- )
- self.condition_shape = tuple(condition)
- self.input_shape = tuple(input)
- self.other_shape = tuple(other)
- self.dtype = dtype
- self.out_shape = tuple(
- torch.broadcast_shapes(self.condition_shape, self.input_shape, self.other_shape)
- )
- self.N_total = prod(self.out_shape) if self.out_shape else 1
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map[self._op_name](self.N_total, dtype)
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self):
- return {"where": WhereFwdKernel}
-
- @staticmethod
- def _expand_flat(t: torch.Tensor, target_shape: tuple) -> torch.Tensor:
- """Expand ``t`` to ``target_shape`` and return a contiguous flat view."""
- if tuple(t.shape) != tuple(target_shape):
- t = t.expand(target_shape)
- return t.contiguous().view(-1)
-
- def _eager_forward(
- self, condition: torch.Tensor, input: torch.Tensor, other: torch.Tensor, # noqa: A002
- ) -> torch.Tensor:
- out_shape = self.out_shape if self.out_shape else (1,)
- cond_b = condition if condition.dtype == torch.bool else condition.bool()
- cond_flat = self._expand_flat(cond_b, out_shape).view(torch.uint8)
- x_flat = self._expand_flat(input, out_shape)
- y_flat = self._expand_flat(other, out_shape)
- result = self.kernel(cond_flat, x_flat, y_flat).view(out_shape if self.out_shape else ())
- return result
-
- def forward(
- self, condition: torch.Tensor, input: torch.Tensor, other: torch.Tensor, # noqa: A002
- ) -> torch.Tensor:
- if not (condition.is_cuda and input.is_cuda and other.is_cuda):
- raise ValueError("Inputs must be CUDA tensors")
- if condition.dtype != torch.bool:
- raise ValueError(
- f"Expected condition.dtype torch.bool, got {condition.dtype}"
- )
- if input.dtype != self.dtype:
- raise ValueError(f"Expected input.dtype {self.dtype}, got {input.dtype}")
- if other.dtype != self.dtype:
- raise ValueError(f"Expected other.dtype {self.dtype}, got {other.dtype}")
- if tuple(condition.shape) != self.condition_shape:
- raise ValueError(
- f"Expected condition.shape {self.condition_shape}, got {tuple(condition.shape)}"
- )
- if tuple(input.shape) != self.input_shape:
- raise ValueError(
- f"Expected input.shape {self.input_shape}, got {tuple(input.shape)}"
- )
- if tuple(other.shape) != self.other_shape:
- raise ValueError(
- f"Expected other.shape {self.other_shape}, got {tuple(other.shape)}"
- )
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(condition, input, other, self._instance_key)
- return self._eager_forward(condition, input, other)
-
-
-class LerpTensorFwdOp(Op):
- """Tensor-weight lerp: out = input + weight * (end - input).
-
- Conforms to the Tensor-weight overload of ``torch.lerp`` —
- ``torch.lerp(input, end, weight: Tensor)`` where ``weight`` is a
- Tensor that broadcasts together with ``input`` and ``end`` to the
- output shape. The Op layer expands the three inputs to the broadcast
- shape and dispatches the flat ``LerpTensorFwdKernel`` on
- ``N_total = product(broadcast_shape)`` elements. The scalar-weight
- overload is handled separately by ``LerpFwdOp``.
-
- Args:
- input: Shape of the start tensor.
- end: Shape of the end tensor.
- weight: Shape of the per-element weight tensor.
- dtype: Torch dtype for all three operands.
- """
-
- _op_name = "lerp_tensor"
- _wrapped = None
-
- # Manifest declares all three operands as ``float16 | bfloat16 | float32``;
- # fp8 dtypes are rejected at the op-layer signature so the impl matches
- # the manifest contract (the kernel also rejects fp8 independently).
- _SUPPORTED_DTYPES = (torch.float16, torch.bfloat16, torch.float32)
-
- def __init__(
- self,
- *,
- input: tuple, # noqa: A002 — manifest-aligned PyTorch param name
- end: tuple,
- weight: tuple,
- dtype: torch.dtype,
- strategy: Optional[str] = None,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- if dtype not in self._SUPPORTED_DTYPES:
- names = ", ".join(str(dt) for dt in self._SUPPORTED_DTYPES)
- raise ValueError(
- f"LerpTensorFwdOp does not support dtype {dtype}. "
- f"Supported: [{names}]"
- )
- self.input_shape = tuple(input)
- self.end_shape = tuple(end)
- self.weight_shape = tuple(weight)
- self.dtype = dtype
- self.strategy = strategy
- self.out_shape = tuple(
- torch.broadcast_shapes(
- self.input_shape, self.end_shape, self.weight_shape,
- )
- )
- self.N_total = prod(self.out_shape) if self.out_shape else 1
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map[self._op_name](
- self.N_total, dtype, tune=tune,
- )
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self) -> Dict[str, Kernel]:
- return {"lerp_tensor": LerpTensorFwdKernel}
-
- @staticmethod
- def _expand_flat(t: torch.Tensor, target_shape: tuple) -> torch.Tensor:
- """Expand ``t`` to ``target_shape`` and return a contiguous flat view."""
- if tuple(t.shape) != tuple(target_shape):
- t = t.expand(target_shape)
- return t.contiguous().view(-1)
-
- def _eager_forward(
- self,
- input: torch.Tensor, # noqa: A002 — manifest-aligned PyTorch param name
- end: torch.Tensor,
- weight: torch.Tensor,
- ) -> torch.Tensor:
- out_shape = self.out_shape if self.out_shape else (1,)
- a_flat = self._expand_flat(input, out_shape)
- b_flat = self._expand_flat(end, out_shape)
- w_flat = self._expand_flat(weight, out_shape)
- result = self.kernel(a_flat, b_flat, w_flat)
- return result.view(self.out_shape if self.out_shape else ())
-
- def forward(
- self,
- input: torch.Tensor, # noqa: A002 — manifest-aligned PyTorch param name
- end: torch.Tensor,
- weight: torch.Tensor,
- ) -> torch.Tensor:
- if not (input.is_cuda and end.is_cuda and weight.is_cuda):
- raise ValueError("Inputs must be CUDA tensors")
- for name, t, expected in [
- ("input", input, self.input_shape),
- ("end", end, self.end_shape),
- ("weight", weight, self.weight_shape),
- ]:
- if t.dtype != self.dtype:
- raise ValueError(
- f"Expected {name}.dtype {self.dtype}, got {t.dtype}"
- )
- if tuple(t.shape) != expected:
- raise ValueError(
- f"Expected {name}.shape {expected}, got {tuple(t.shape)}"
- )
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(input, end, weight, self._instance_key)
- return self._eager_forward(input, end, weight)
-
-
-class _ClampTensorBase(Op):
- """Shared infrastructure for Tensor-bound clamp variants (broadcasting)."""
-
- _wrapped = None
-
- @staticmethod
- def _expand_flat(t: torch.Tensor, target_shape: tuple) -> torch.Tensor:
- if tuple(t.shape) != tuple(target_shape):
- t = t.expand(target_shape)
- return t.contiguous().view(-1)
-
-
-class ClampFwdOp(_ClampTensorBase):
- """Clamp with Tensor lower and/or upper bounds (broadcasting).
-
- Conforms to ``torch.clamp(input, min, max)`` where ``min`` and ``max``
- are each either a Tensor or ``None``. At least one of the two bounds
- must be a Tensor. All Tensor operands broadcast together. The
- primary spec entry in ``tileops/manifest/`` covers the both-Tensor
- form; the mixed Tensor/``None`` cases are runtime-equivalent to
- ``ClampMinFwdOp`` / ``ClampMaxFwdOp`` and are accepted here so callers
- can mirror PyTorch's ``torch.clamp`` API directly.
-
- Args:
- input: Shape of the input tensor.
- min: Shape of the lower-bound tensor, or ``None`` for no lower bound.
- max: Shape of the upper-bound tensor, or ``None`` for no upper bound.
- dtype: Torch dtype for all operands.
-
- Raises:
- ValueError: If both ``min`` and ``max`` are ``None``.
- """
-
- _op_name = "clamp"
- _wrapped = None
-
- def __init__(
- self,
- input: tuple, # noqa: A002 — manifest-aligned PyTorch param name
- min: Optional[tuple] = None, # noqa: A002 — manifest-aligned PyTorch param name
- max: Optional[tuple] = None, # noqa: A002 — manifest-aligned PyTorch param name
- dtype: torch.dtype = torch.float32,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- if min is None and max is None:
- raise ValueError(
- "ClampFwdOp requires at least one of `min` or `max` to be a "
- "Tensor shape; both None is not a valid clamp."
- )
- self.input_shape = tuple(input)
- self.min_shape = None if min is None else tuple(min)
- self.max_shape = None if max is None else tuple(max)
- self.dtype = dtype
- broadcast_args = [self.input_shape]
- if self.min_shape is not None:
- broadcast_args.append(self.min_shape)
- if self.max_shape is not None:
- broadcast_args.append(self.max_shape)
- self.out_shape = tuple(torch.broadcast_shapes(*broadcast_args))
- self.N_total = prod(self.out_shape) if self.out_shape else 1
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map["clamp_tensor"](
- self.N_total, dtype,
- has_min=self.min_shape is not None,
- has_max=self.max_shape is not None,
- tune=tune,
- )
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self):
- return {"clamp_tensor": ClampTensorFwdKernel}
-
- def _eager_forward(
- self,
- input: torch.Tensor, # noqa: A002
- min: Optional[torch.Tensor] = None, # noqa: A002
- max: Optional[torch.Tensor] = None, # noqa: A002
- ) -> torch.Tensor:
- # Broadcast all operands to ``out_shape`` and dispatch the
- # TileLang Tensor-bound clamp kernel. The kernel branches on
- # ``has_min`` / ``has_max`` at build time, so this single Op
- # class also covers the mixed Tensor/None cases.
- out_shape = self.out_shape if self.out_shape else (1,)
- x_flat = self._expand_flat(input, out_shape)
- lo_flat = None if min is None else self._expand_flat(min, out_shape)
- hi_flat = None if max is None else self._expand_flat(max, out_shape)
- result = self.kernel(x_flat, lo_flat, hi_flat)
- return result.view(self.out_shape if self.out_shape else ())
-
- def forward(
- self,
- input: torch.Tensor, # noqa: A002
- min: Optional[torch.Tensor] = None, # noqa: A002
- max: Optional[torch.Tensor] = None, # noqa: A002
- ) -> torch.Tensor:
- # Validate that the runtime None / Tensor pattern matches what
- # __init__ was configured for — the broadcast shape and the
- # presence of each bound is baked in at construction.
- if (min is None) != (self.min_shape is None):
- raise ValueError(
- f"min was {'None' if self.min_shape is None else 'a Tensor shape'} at "
- f"__init__ but {'None' if min is None else 'a Tensor'} at forward()"
- )
- if (max is None) != (self.max_shape is None):
- raise ValueError(
- f"max was {'None' if self.max_shape is None else 'a Tensor shape'} at "
- f"__init__ but {'None' if max is None else 'a Tensor'} at forward()"
- )
- tensors = [("input", input, self.input_shape)]
- if min is not None:
- tensors.append(("min", min, self.min_shape))
- if max is not None:
- tensors.append(("max", max, self.max_shape))
- for _, t, _ in tensors:
- if not t.is_cuda:
- raise ValueError("Inputs must be CUDA tensors")
- for name, t, expected in tensors:
- if t.dtype != self.dtype:
- raise ValueError(f"Expected {name}.dtype {self.dtype}, got {t.dtype}")
- if tuple(t.shape) != expected:
- raise ValueError(
- f"Expected {name}.shape {expected}, got {tuple(t.shape)}"
- )
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(input, min, max, self._instance_key)
- return self._eager_forward(input, min, max)
-
-
-class ClampMinFwdOp(_ClampTensorBase):
- """Single-bound Tensor lower clamp (``torch.clamp_min``).
-
- Args:
- input: Shape of the input tensor.
- min: Shape of the lower-bound tensor.
- dtype: Torch dtype.
- """
-
- _op_name = "clamp_min"
- _wrapped = None
-
- def __init__(
- self,
- input: tuple, # noqa: A002
- min: tuple, # noqa: A002
- dtype: torch.dtype,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- self.input_shape = tuple(input)
- self.min_shape = tuple(min)
- self.dtype = dtype
- self.out_shape = tuple(torch.broadcast_shapes(self.input_shape, self.min_shape))
- self.N_total = prod(self.out_shape) if self.out_shape else 1
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map["clamp_tensor"](
- self.N_total, dtype, has_min=True, has_max=False, tune=tune,
- )
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self):
- return {"clamp_tensor": ClampTensorFwdKernel}
-
- def _eager_forward(
- self, input: torch.Tensor, min: torch.Tensor, # noqa: A002
- ) -> torch.Tensor:
- # Broadcast input/min to out_shape and dispatch the TileLang
- # min-only Tensor-bound clamp kernel.
- out_shape = self.out_shape if self.out_shape else (1,)
- x_flat = self._expand_flat(input, out_shape)
- lo_flat = self._expand_flat(min, out_shape)
- result = self.kernel(x_flat, lo_flat, None)
- return result.view(self.out_shape if self.out_shape else ())
-
- def forward(
- self, input: torch.Tensor, min: torch.Tensor, # noqa: A002
- ) -> torch.Tensor:
- if not (input.is_cuda and min.is_cuda):
- raise ValueError("Inputs must be CUDA tensors")
- for name, t, expected in [
- ("input", input, self.input_shape),
- ("min", min, self.min_shape),
- ]:
- if t.dtype != self.dtype:
- raise ValueError(f"Expected {name}.dtype {self.dtype}, got {t.dtype}")
- if tuple(t.shape) != expected:
- raise ValueError(
- f"Expected {name}.shape {expected}, got {tuple(t.shape)}"
- )
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(input, min, self._instance_key)
- return self._eager_forward(input, min)
-
-
-class ClampMaxFwdOp(_ClampTensorBase):
- """Single-bound Tensor upper clamp (``torch.clamp_max``).
-
- Args:
- input: Shape of the input tensor.
- max: Shape of the upper-bound tensor.
- dtype: Torch dtype.
- """
-
- _op_name = "clamp_max"
- _wrapped = None
-
- def __init__(
- self,
- input: tuple, # noqa: A002
- max: tuple, # noqa: A002
- dtype: torch.dtype,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- self.input_shape = tuple(input)
- self.max_shape = tuple(max)
- self.dtype = dtype
- self.out_shape = tuple(torch.broadcast_shapes(self.input_shape, self.max_shape))
- self.N_total = prod(self.out_shape) if self.out_shape else 1
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map["clamp_tensor"](
- self.N_total, dtype, has_min=False, has_max=True, tune=tune,
- )
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self):
- return {"clamp_tensor": ClampTensorFwdKernel}
-
- def _eager_forward(
- self, input: torch.Tensor, max: torch.Tensor, # noqa: A002
- ) -> torch.Tensor:
- # Broadcast input/max to out_shape and dispatch the TileLang
- # max-only Tensor-bound clamp kernel.
- out_shape = self.out_shape if self.out_shape else (1,)
- x_flat = self._expand_flat(input, out_shape)
- hi_flat = self._expand_flat(max, out_shape)
- result = self.kernel(x_flat, None, hi_flat)
- return result.view(self.out_shape if self.out_shape else ())
-
- def forward(
- self, input: torch.Tensor, max: torch.Tensor, # noqa: A002
- ) -> torch.Tensor:
- if not (input.is_cuda and max.is_cuda):
- raise ValueError("Inputs must be CUDA tensors")
- for name, t, expected in [
- ("input", input, self.input_shape),
- ("max", max, self.max_shape),
- ]:
- if t.dtype != self.dtype:
- raise ValueError(f"Expected {name}.dtype {self.dtype}, got {t.dtype}")
- if tuple(t.shape) != expected:
- raise ValueError(
- f"Expected {name}.shape {expected}, got {tuple(t.shape)}"
- )
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(input, max, self._instance_key)
- return self._eager_forward(input, max)
-
-
-class ClampScalarFwdOp(Op):
- """Scalar-bound clamp (``torch.clamp(input, min: Number|None, max: Number|None)``).
-
- Args:
- input: Shape of the input tensor.
- min: Lower bound (Number or None).
- max: Upper bound (Number or None).
- dtype: Torch dtype.
- """
-
- _op_name = "clamp"
- _wrapped = None
-
- def __init__(
- self,
- input: tuple, # noqa: A002
- min: Optional[float] = None, # noqa: A002
- max: Optional[float] = None, # noqa: A002
- dtype: torch.dtype = torch.float32,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- if min is None and max is None:
- raise ValueError(
- "ClampScalarFwdOp requires at least one of `min` or `max` to be a "
- "Number; both None is not a valid clamp."
- )
- if min is not None:
- _validate_scalar_param_repr("min", min, dtype, self._op_name)
- if max is not None:
- _validate_scalar_param_repr("max", max, dtype, self._op_name)
- self.input_shape = tuple(input)
- self.N_total = prod(self.input_shape) if self.input_shape else 1
- self.dtype = dtype
- self.min = min
- self.max = max
- # Backwards-compat aliases for legacy callers.
- self.min_val = min
- self.max_val = max
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map["clamp"](
- self.N_total, dtype, min_val=min, max_val=max, tune=tune,
- )
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self):
- return {"clamp": ClampFwdKernel}
-
- def _eager_forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002
- orig_shape = input.shape
- result = self.kernel(input.contiguous().reshape(-1)).reshape(orig_shape)
- return _apply_fp8_post_cast(result, self.kernel)
-
- def forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002
- if not input.is_cuda:
- raise ValueError("Input must be a CUDA tensor")
- if input.dtype != self.dtype:
- raise ValueError(f"Expected input.dtype {self.dtype}, got {input.dtype}")
- if tuple(input.shape) != self.input_shape:
- raise ValueError(
- f"Expected input.shape {self.input_shape}, got {tuple(input.shape)}"
- )
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(input, self._instance_key)
- return self._eager_forward(input)
-
-
-class MaskedFillFwdOp(Op):
- """MaskedFill with 0-dim Tensor value (``torch.Tensor.masked_fill(mask, value: Tensor)``).
-
- Output shape is the bidirectional broadcast of ``input`` and ``mask``;
- ``value`` must be a 0-dim Tensor. The Op expands ``input`` and ``mask``
- to the broadcast shape and dispatches the existing flat scalar kernel
- using ``value.item()`` as the fill literal — this keeps the
- fast vectorized kernel path while satisfying the manifest's Tensor-value
- contract (the kernel reads ``value`` once at forward time, which is
- consistent with the 0-dim semantics).
-
- Args:
- input: Shape of the input tensor.
- mask: Shape of the mask tensor (bool).
- value: Shape of the value tensor (must be ``()`` per the manifest).
- dtype: Torch dtype for ``input`` / ``value``.
- kernel_map: Optional dispatch override mapping kernel keys to
- ``Kernel`` subclasses. Falls back to ``default_kernel_map``.
- """
-
- _op_name = "masked_fill"
- _wrapped = None
-
- def __init__(
- self,
- input: tuple, # noqa: A002
- mask: tuple,
- value: tuple,
- dtype: torch.dtype,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- ):
- if tuple(value) != ():
- raise ValueError(
- f"MaskedFillFwdOp requires a 0-dim value Tensor; got shape {tuple(value)}"
- )
- self.input_shape = tuple(input)
- self.mask_shape = tuple(mask)
- self.value_shape = tuple(value)
- self.dtype = dtype
- self.out_shape = tuple(torch.broadcast_shapes(self.input_shape, self.mask_shape))
- self.N_total = prod(self.out_shape) if self.out_shape else 1
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map["masked_fill_tensor_value"](self.N_total, dtype)
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self):
- return {"masked_fill_tensor_value": MaskedFillTensorValueFwdKernel}
-
- @staticmethod
- def _expand_flat(t: torch.Tensor, target_shape: tuple) -> torch.Tensor:
- if tuple(t.shape) != tuple(target_shape):
- t = t.expand(target_shape)
- return t.contiguous().view(-1)
-
- def _eager_forward(
- self, input: torch.Tensor, mask: torch.Tensor, value: torch.Tensor, # noqa: A002
- ) -> torch.Tensor:
- # Broadcast input/mask to out_shape, pack mask as uint8, reshape
- # the 0-dim value to (1,), and dispatch the TileLang kernel.
- out_shape = self.out_shape if self.out_shape else (1,)
- x_flat = self._expand_flat(input, out_shape)
- mask_b = mask if mask.dtype == torch.bool else mask.bool()
- mask_flat = self._expand_flat(mask_b, out_shape).view(torch.uint8)
- value_1d = value.contiguous().view(1)
- result = self.kernel(x_flat, mask_flat, value_1d)
- return result.view(self.out_shape if self.out_shape else ())
-
- def forward(
- self, input: torch.Tensor, mask: torch.Tensor, value: torch.Tensor, # noqa: A002
- ) -> torch.Tensor:
- if not (input.is_cuda and mask.is_cuda and value.is_cuda):
- raise ValueError("Inputs must be CUDA tensors")
- if input.dtype != self.dtype:
- raise ValueError(f"Expected input.dtype {self.dtype}, got {input.dtype}")
- if mask.dtype != torch.bool:
- raise ValueError(f"Expected mask.dtype torch.bool, got {mask.dtype}")
- if value.dtype != self.dtype:
- raise ValueError(f"Expected value.dtype {self.dtype}, got {value.dtype}")
- if tuple(input.shape) != self.input_shape:
- raise ValueError(
- f"Expected input.shape {self.input_shape}, got {tuple(input.shape)}"
- )
- if tuple(mask.shape) != self.mask_shape:
- raise ValueError(
- f"Expected mask.shape {self.mask_shape}, got {tuple(mask.shape)}"
- )
- if tuple(value.shape) != ():
- raise ValueError(f"Expected value.shape (), got {tuple(value.shape)}")
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(input, mask, value, self._instance_key)
- return self._eager_forward(input, mask, value)
-
-
-class MaskedFillScalarFwdOp(Op):
- """MaskedFill with Number (scalar) value.
-
- Conforms to ``torch.Tensor.masked_fill(mask, value: Number)``. Output
- shape follows the bidirectional broadcast of ``input`` and ``mask``.
-
- The manifest declares the PyTorch dtype union (``bool | uint8 |
- int8 | int16 | int32 | int64 | float16 | bfloat16 | float32``). The
- current TileLang kernel only supports float dtypes; integer and
- bool dtypes are rejected at construction time with ``ValueError``
- until a real int / bool kernel lands (tracked in a follow-up issue).
-
- Args:
- input: Shape of the input tensor.
- mask: Shape of the mask tensor (bool).
- value: Scalar fill value (bool / int / float). Range-validated
- against ``dtype``.
- dtype: Torch dtype. Must be a kernel-supported floating-point
- dtype.
- kernel_map: Optional dispatch override mapping kernel keys to
- ``Kernel`` subclasses. Falls back to ``default_kernel_map``.
- """
-
- _op_name = "masked_fill"
- _wrapped = None
-
- def __init__(
- self,
- input: tuple, # noqa: A002
- mask: tuple,
- value: bool | int | float = 0,
- dtype: torch.dtype = torch.float32,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- ):
- kernel_supported = MaskedFillFwdKernel.SUPPORTED_DTYPES
- if kernel_supported is not None and dtype not in kernel_supported:
- names = ", ".join(str(dt) for dt in kernel_supported)
- raise ValueError(
- f"{self._op_name} does not support dtype {dtype}. "
- f"Supported: [{names}]"
- )
- self.input_shape = tuple(input)
- self.mask_shape = tuple(mask)
- self.dtype = dtype
- self.value = value
- # Backwards-compat alias.
- self.fill_value = value
- self.out_shape = tuple(torch.broadcast_shapes(self.input_shape, self.mask_shape))
- self.N_total = prod(self.out_shape) if self.out_shape else 1
- # The kernel is always built on the broadcast (output) flat size.
- # When input/mask already match out_shape, this is a no-op expand;
- # otherwise the Op layer broadcasts both before dispatch.
- self._needs_broadcast = (
- self.input_shape != self.out_shape or self.mask_shape != self.out_shape
- )
- _validate_scalar_param_repr("value", value, dtype, self._op_name)
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map["masked_fill"](self.N_total, dtype, value)
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self):
- return {"masked_fill": MaskedFillFwdKernel}
-
- @staticmethod
- def _expand_flat(t: torch.Tensor, target_shape: tuple) -> torch.Tensor:
- if tuple(t.shape) != tuple(target_shape):
- t = t.expand(target_shape)
- return t.contiguous().view(-1)
-
- def _eager_forward(self, input: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: # noqa: A002
- out_shape = self.out_shape if self.out_shape else (1,)
- x_flat = self._expand_flat(input, out_shape)
- mask_b = mask if mask.dtype == torch.bool else mask.bool()
- mask_flat = self._expand_flat(mask_b, out_shape).view(torch.uint8)
- result = self.kernel(x_flat, mask_flat).view(self.out_shape if self.out_shape else ())
- return _apply_fp8_post_cast(result, self.kernel)
-
- def forward(self, input: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: # noqa: A002
- if not input.is_cuda:
- raise ValueError("Input must be a CUDA tensor")
- if input.dtype != self.dtype:
- raise ValueError(f"Expected input.dtype {self.dtype}, got {input.dtype}")
- if tuple(input.shape) != self.input_shape:
- raise ValueError(
- f"Expected input.shape {self.input_shape}, got {tuple(input.shape)}"
- )
- if not mask.is_cuda:
- raise ValueError("Mask must be a CUDA tensor")
- if mask.dtype != torch.bool:
- raise ValueError(f"Expected mask.dtype torch.bool, got {mask.dtype}")
- if tuple(mask.shape) != self.mask_shape:
- raise ValueError(
- f"Expected mask.shape {self.mask_shape}, got {tuple(mask.shape)}"
- )
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(input, mask, self._instance_key)
- return self._eager_forward(input, mask)
-
-
-class NanToNumFwdOp(Op):
- """NanToNum: replace NaN, +Inf, -Inf with specified values.
-
- Args:
- N_total: Total number of elements (flattened).
- dtype: Torch dtype.
- nan: Replacement for NaN (default 0.0).
- posinf: Replacement for +Inf. Manifest default ``None`` resolves
- to the largest finite value representable in the user-facing
- ``dtype`` (matches ``torch.nan_to_num``). Explicit values
- must also be representable in ``dtype`` end-to-end; values
- that fit only in the kernel's intermediate dtype (e.g. fp16
- for fp8_e5m2) are rejected so the post-cast cannot resurface
- them as Inf.
- neginf: Replacement for -Inf. Manifest default ``None`` resolves
- to the smallest (most negative) finite value representable
- in the user-facing ``dtype``.
- kernel_map: Optional kernel dispatch override.
- tune: Whether to autotune the kernel.
- """
-
- _op_name = "nan_to_num"
- _wrapped = None
-
- def __init__(
- self,
- N_total: int,
- dtype: torch.dtype,
- nan: float = 0.0,
- posinf: Optional[float] = None,
- neginf: Optional[float] = None,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- tune: bool = False,
- ):
- # The manifest default ``None`` resolves to the *final*
- # user-facing dtype's max / min, not ``+/-inf``: the kernel runs
- # in ``output_dtype`` (fp16 for e5m2 to preserve Inf/NaN) and
- # _clamp_to_dtype_range targets that intermediate, so forwarding
- # ``+inf`` would resolve to fp16's 65504.0 and then surface as
- # ``+Inf`` after the e5m2 post-cast (e5m2 max is 57344.0).
- # Picking ``torch.finfo(dtype).max`` here keeps the replacement
- # value finite end-to-end and matches ``torch.nan_to_num``
- # semantics (replace Inf with the dtype's max finite value).
- _validate_scalar_param_repr("nan", nan, dtype, self._op_name)
- if posinf is None:
- kernel_posinf = torch.finfo(dtype).max
- else:
- _validate_scalar_param_repr("posinf", posinf, dtype, self._op_name)
- kernel_posinf = posinf
- if neginf is None:
- kernel_neginf = torch.finfo(dtype).min
- else:
- _validate_scalar_param_repr("neginf", neginf, dtype, self._op_name)
- kernel_neginf = neginf
- self.N_total = N_total
- self.dtype = dtype
- self.nan = nan
- self.posinf = posinf
- self.neginf = neginf
- self.dispatch_kernel(kernel_map)
- # Pass replacement values positionally; the kernel constructor's
- # internal parameter naming is encapsulated below the Op layer.
- self.kernel = self.kernel_map["nan_to_num"](
- N_total, dtype, nan, kernel_posinf, kernel_neginf, tune=tune,
- )
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self):
- return {"nan_to_num": NanToNumFwdKernel}
-
- def _eager_forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002
- orig_shape = input.shape
- result = self.kernel(input.contiguous().reshape(-1)).reshape(orig_shape)
- return _apply_fp8_post_cast(result, self.kernel)
-
- def forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002
- if not input.is_cuda:
- raise ValueError("Input must be a CUDA tensor")
- if input.dtype != self.dtype:
- raise ValueError(f"Expected input.dtype {self.dtype}, got {input.dtype}")
- if input.numel() != self.N_total:
- raise ValueError(f"Expected {self.N_total} elements, got {input.numel()}")
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(input, self._instance_key)
- return self._eager_forward(input)
-
-
-class AlibiFwdOp(Op):
- """ALiBi position encoding: bias[h, i, j] = -slope_h * |i - j|.
-
- Generates the full (num_heads, seq_len, seq_len) bias tensor.
-
- Args:
- seq_len: Sequence length.
- num_heads: Number of attention heads.
- dtype: Torch dtype.
- kernel_map: Optional dispatch override mapping kernel keys to
- ``Kernel`` subclasses. Falls back to ``default_kernel_map``.
- """
-
- _op_name = "alibi"
- _wrapped = None
-
- def __init__(
- self,
- seq_len: int,
- num_heads: int,
- dtype: torch.dtype,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- ):
- self.seq_len = seq_len
- self.num_heads = num_heads
- self.dtype = dtype
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map[self._op_name](seq_len, num_heads, dtype)
- # Scalar tensor used as device/dtype carrier for torch.compile tracing
- self._device_carrier = torch.empty((), dtype=dtype, device="cuda")
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self):
- return {"alibi": AlibiFwdKernel}
-
- def _eager_forward(self) -> torch.Tensor:
- out = self.kernel()
- result = out.reshape(self.num_heads, self.seq_len, self.seq_len)
- return _apply_fp8_post_cast(result, self.kernel)
-
- def forward(self) -> torch.Tensor:
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(
- self._device_carrier,
- self.num_heads, self.seq_len,
- self._instance_key,
- )
- return self._eager_forward()
-
-
-class SinusoidalFwdOp(Op):
- """Sinusoidal positional encoding from "Attention Is All You Need".
-
- Generates the full (seq_len, d_model) encoding tensor.
-
- Args:
- seq_len: Sequence length.
- d_model: Model dimension.
- dtype: Torch dtype.
- kernel_map: Optional dispatch override mapping kernel keys to
- ``Kernel`` subclasses. Falls back to ``default_kernel_map``.
- """
-
- _op_name = "sinusoidal"
- _wrapped = None
-
- def __init__(
- self,
- seq_len: int,
- d_model: int,
- dtype: torch.dtype,
- *,
- kernel_map: Optional[Dict[str, Kernel]] = None,
- ):
- self.seq_len = seq_len
- self.d_model = d_model
- self.dtype = dtype
- self.dispatch_kernel(kernel_map)
- self.kernel = self.kernel_map[self._op_name](seq_len, d_model, dtype)
- # Scalar tensor used as device/dtype carrier for torch.compile tracing
- self._device_carrier = torch.empty((), dtype=dtype, device="cuda")
- self._instance_key = id(self)
- _OP_REGISTRY[self._instance_key] = self
-
- @property
- def default_kernel_map(self):
- return {"sinusoidal": SinusoidalFwdKernel}
-
- def _eager_forward(self) -> torch.Tensor:
- out = self.kernel()
- result = out.reshape(self.seq_len, self.d_model)
- return _apply_fp8_post_cast(result, self.kernel)
-
- def forward(self) -> torch.Tensor:
- wrapped = type(self)._wrapped
- if wrapped is not None:
- return wrapped(
- self._device_carrier,
- self.seq_len, self.d_model,
- self._instance_key,
- )
- return self._eager_forward()
-
-
-# ---------------------------------------------------------------------------
-# torch.compile registration for all 66 concrete ops
-# ---------------------------------------------------------------------------
-
-# --- Unary ops: float-preserving output (1 + 17 + 8 + 1 = 27 ops) ---
-for _cls in [
- ReluFwdOp,
- # math (17)
- ExpFwdOp, LogFwdOp, SqrtFwdOp, RsqrtFwdOp, AbsFwdOp, NegFwdOp, ReciprocalFwdOp, SignFwdOp,
- SinFwdOp, CosFwdOp, FloorFwdOp, CeilFwdOp, RoundFwdOp, TruncFwdOp, ErfFwdOp, Log1pFwdOp, Expm1FwdOp,
- # activations (8)
- GeluFwdOp, SiluFwdOp, SigmoidFwdOp, TanhFwdOp, HardswishFwdOp, HardsigmoidFwdOp, MishFwdOp, SeluFwdOp,
- # bitwise (1) -- output same dtype as input
- BitwiseNotFwdOp,
-]:
- _register_unary_custom_op(_cls)
-
-# --- Unary ops: bool output (4 ops) ---
-for _cls in [LogicalNotFwdOp, IsnanFwdOp, IsinfFwdOp, IsfiniteFwdOp]:
- _register_unary_custom_op(_cls, output_dtype_override=torch.bool)
-
-# --- Binary ops: same-dtype output (10 + 3 = 13 ops) ---
-for _cls in [
- # arithmetic (10)
- AddFwdOp, SubFwdOp, MulFwdOp, DivFwdOp, RemainderFwdOp, PowFwdOp, FloorDivideFwdOp,
- LerpFwdOp, MaximumFwdOp, MinimumFwdOp,
- # bitwise (3)
- BitwiseAndFwdOp, BitwiseOrFwdOp, BitwiseXorFwdOp,
-]:
- _register_binary_custom_op(_cls)
-
-# --- Binary ops: bool output (comparison 6 + logical 2 = 8 ops) ---
-for _cls in [
- EqFwdOp, NeFwdOp, GtFwdOp, LtFwdOp, GeFwdOp, LeFwdOp,
- LogicalAndFwdOp, LogicalOrFwdOp,
-]:
- _register_binary_custom_op(_cls, output_bool=True)
-
-# --- Fused gated ops (3 ops) ---
-for _cls in [SiluAndMulFwdOp, GeluAndMulFwdOp, GeluTanhAndMulFwdOp]:
- _register_fused_gated_custom_op(_cls)
-
-# --- Independent unary-like ops (6 ops: x -> y with baked params) ---
-# ClampScalarFwdOp is the scalar-bound clamp (single-tensor input + min/max
-# baked into __init__). The Tensor-bound ClampFwdOp / ClampMinFwdOp /
-# ClampMaxFwdOp variants register their own multi-input custom_ops below.
-for _cls in [
- LeakyReluFwdOp, EluFwdOp, HardtanhFwdOp, SoftplusFwdOp, ClampScalarFwdOp,
- NanToNumFwdOp,
-]:
- _register_unary_custom_op(_cls)
-
-# --- Inplace companions for activations declaring ``inplace`` ---
-# Each leaf below has ``inplace`` in its manifest signature. Register a
-# parallel ``_wrapped_inplace`` custom op with ``mutates_args=("x",)``
-# so ``forward(input)`` with ``self.inplace=True`` traces correctly
-# under ``torch.compile``.
-for _cls in [
- ReluFwdOp, SiluFwdOp, HardswishFwdOp, HardsigmoidFwdOp, MishFwdOp,
- SeluFwdOp, LeakyReluFwdOp, EluFwdOp, HardtanhFwdOp,
-]:
- _register_unary_inplace_custom_op(_cls)
-
-# --- PReLU op (1 op: x, weight -> y) ---
-_register_prelu_custom_op(PreluFwdOp)
-
-# --- Tensor-bound clamp variants (3 ops: multi-tensor inputs -> out) ---
-# Registered under distinct custom_op namespaces from ClampScalarFwdOp:
-# ``top::elementwise_clamp_tensor`` (Optional Tensor min/max),
-# ``top::elementwise_clamp_min`` and ``top::elementwise_clamp_max`` for the
-# single-bound variants. register_fake is broadcast-aware so
-# torch.compile(fullgraph=True) traces correctly for both same-shape and
-# broadcasting inputs.
-_register_clamp_tensor_custom_op(ClampFwdOp)
-_register_clamp_min_custom_op(ClampMinFwdOp)
-_register_clamp_max_custom_op(ClampMaxFwdOp)
-
-# --- MaskedFill variants (input, mask[, value] -> out) ---
-# Both register broadcast-aware fake functions so torch.compile(fullgraph=True)
-# works for same-shape and broadcasting inputs. The Tensor-value variant is
-# registered under a distinct ``_tensor_value`` namespace to avoid colliding
-# with the scalar variant's ``top::elementwise_masked_fill``.
-_register_masked_fill_custom_op(MaskedFillScalarFwdOp)
-_register_masked_fill_tensor_value_custom_op(MaskedFillFwdOp)
-
-# --- Where op (1 op: cond, x, y -> out) ---
-# The fake function is broadcast-aware so torch.compile(fullgraph=True)
-# traces correctly for both same-shape and broadcasting inputs.
-_register_where_custom_op(WhereFwdOp)
-
-# --- Tensor-weight lerp (1 op: input, end, weight -> out) ---
-# Registered under ``top::elementwise_lerp_tensor`` to avoid colliding with
-# the scalar ``LerpFwdOp``'s ``top::elementwise_binary_lerp`` namespace. The fake
-# function is broadcast-aware so torch.compile(fullgraph=True) traces
-# correctly for both same-shape and broadcasting inputs.
-_register_lerp_tensor_custom_op(LerpTensorFwdOp)
-
-# --- Generative ops (2 ops: no tensor input -> out) ---
-_register_generative_custom_op(
- AlibiFwdOp,
- out_shape_fn=lambda carrier, num_heads, seq_len: carrier.new_empty(
- (num_heads, seq_len, seq_len),
- ),
-)
-_register_generative_custom_op(
- SinusoidalFwdOp,
- out_shape_fn=lambda carrier, seq_len, d_model: carrier.new_empty(
- (seq_len, d_model),
- ),
-)
-
-# Clean up loop variable
-del _cls
diff --git a/tileops/ops/elementwise/__init__.py b/tileops/ops/elementwise/__init__.py
new file mode 100644
index 00000000..c3becccb
--- /dev/null
+++ b/tileops/ops/elementwise/__init__.py
@@ -0,0 +1,294 @@
+"""Elementwise op package.
+
+Re-exports every public symbol previously provided by the monolithic
+``tileops/ops/elementwise.py`` module so that
+``from tileops.ops.elementwise import <name>`` continues to work.
+
+Concrete ops are organised one cluster per leaf module
+(``arithmetic.py``, ``activations.py``, ``clamp.py``, ...). Umbrella
+template classes (``UnaryOp`` / ``BinaryOp`` / ``FusedGatedOp``) and the
+shared registration / broadcast infrastructure live in ``_base.py``.
+
+Concrete ops register their ``torch.library.custom_op`` wrappers at
+package import time via the registration loops at the bottom of this
+module.
+"""
+
+import torch as _torch
+
+from ._base import (
+ BinaryOp,
+ FusedGatedOp,
+ UnaryOp,
+ _register_binary_custom_op,
+ _register_clamp_max_custom_op,
+ _register_clamp_min_custom_op,
+ _register_clamp_tensor_custom_op,
+ _register_fused_gated_custom_op,
+ _register_generative_custom_op,
+ _register_lerp_tensor_custom_op,
+ _register_masked_fill_custom_op,
+ _register_masked_fill_tensor_value_custom_op,
+ _register_prelu_custom_op,
+ _register_unary_custom_op,
+ _register_unary_inplace_custom_op,
+ _register_where_custom_op,
+ coalesce_broadcast_dims,
+)
+from .activations import (
+ EluFwdOp,
+ GeluFwdOp,
+ HardsigmoidFwdOp,
+ HardswishFwdOp,
+ HardtanhFwdOp,
+ LeakyReluFwdOp,
+ MishFwdOp,
+ ReluFwdOp,
+ SeluFwdOp,
+ SigmoidFwdOp,
+ SiluFwdOp,
+ SoftplusFwdOp,
+ TanhFwdOp,
+)
+from .alibi import AlibiFwdOp
+from .arithmetic import (
+ AddFwdOp,
+ DivFwdOp,
+ FloorDivideFwdOp,
+ LerpFwdOp,
+ LerpTensorFwdOp,
+ MaximumFwdOp,
+ MinimumFwdOp,
+ MulFwdOp,
+ PowFwdOp,
+ RemainderFwdOp,
+ SubFwdOp,
+)
+from .bitwise import (
+ BitwiseAndFwdOp,
+ BitwiseNotFwdOp,
+ BitwiseOrFwdOp,
+ BitwiseXorFwdOp,
+)
+from .clamp import ClampFwdOp, ClampMaxFwdOp, ClampMinFwdOp, ClampScalarFwdOp
+from .comparison import EqFwdOp, GeFwdOp, GtFwdOp, LeFwdOp, LtFwdOp, NeFwdOp
+from .fused_gated import GeluAndMulFwdOp, GeluTanhAndMulFwdOp, SiluAndMulFwdOp
+from .logical import LogicalAndFwdOp, LogicalNotFwdOp, LogicalOrFwdOp
+from .masked_fill import MaskedFillFwdOp, MaskedFillScalarFwdOp
+from .math_unary import (
+ AbsFwdOp,
+ CeilFwdOp,
+ CosFwdOp,
+ ErfFwdOp,
+ ExpFwdOp,
+ Expm1FwdOp,
+ FloorFwdOp,
+ Log1pFwdOp,
+ LogFwdOp,
+ NegFwdOp,
+ ReciprocalFwdOp,
+ RoundFwdOp,
+ RsqrtFwdOp,
+ SignFwdOp,
+ SinFwdOp,
+ SqrtFwdOp,
+ TruncFwdOp,
+)
+from .nan_to_num import NanToNumFwdOp
+from .predicates import IsfiniteFwdOp, IsinfFwdOp, IsnanFwdOp
+from .prelu import PreluFwdOp
+from .sinusoidal import SinusoidalFwdOp
+from .where import WhereFwdOp
+
+__all__ = [
+ "AbsFwdOp",
+ "AddFwdOp",
+ "AlibiFwdOp",
+ "BinaryOp",
+ "BitwiseAndFwdOp",
+ "BitwiseNotFwdOp",
+ "BitwiseOrFwdOp",
+ "BitwiseXorFwdOp",
+ "CeilFwdOp",
+ "ClampFwdOp",
+ "ClampMaxFwdOp",
+ "ClampMinFwdOp",
+ "ClampScalarFwdOp",
+ "CosFwdOp",
+ "DivFwdOp",
+ "EluFwdOp",
+ "EqFwdOp",
+ "ErfFwdOp",
+ "ExpFwdOp",
+ "Expm1FwdOp",
+ "FloorDivideFwdOp",
+ "FloorFwdOp",
+ "FusedGatedOp",
+ "GeFwdOp",
+ "GeluAndMulFwdOp",
+ "GeluFwdOp",
+ "GeluTanhAndMulFwdOp",
+ "GtFwdOp",
+ "HardsigmoidFwdOp",
+ "HardswishFwdOp",
+ "HardtanhFwdOp",
+ "IsfiniteFwdOp",
+ "IsinfFwdOp",
+ "IsnanFwdOp",
+ "LeFwdOp",
+ "LeakyReluFwdOp",
+ "LerpFwdOp",
+ "LerpTensorFwdOp",
+ "Log1pFwdOp",
+ "LogFwdOp",
+ "LogicalAndFwdOp",
+ "LogicalNotFwdOp",
+ "LogicalOrFwdOp",
+ "LtFwdOp",
+ "MaskedFillFwdOp",
+ "MaskedFillScalarFwdOp",
+ "MaximumFwdOp",
+ "MinimumFwdOp",
+ "MishFwdOp",
+ "MulFwdOp",
+ "NanToNumFwdOp",
+ "NeFwdOp",
+ "NegFwdOp",
+ "PowFwdOp",
+ "PreluFwdOp",
+ "ReciprocalFwdOp",
+ "ReluFwdOp",
+ "RemainderFwdOp",
+ "RoundFwdOp",
+ "RsqrtFwdOp",
+ "SeluFwdOp",
+ "SigmoidFwdOp",
+ "SignFwdOp",
+ "SiluAndMulFwdOp",
+ "SiluFwdOp",
+ "SinFwdOp",
+ "SinusoidalFwdOp",
+ "SoftplusFwdOp",
+ "SqrtFwdOp",
+ "SubFwdOp",
+ "TanhFwdOp",
+ "TruncFwdOp",
+ "UnaryOp",
+ "WhereFwdOp",
+ "coalesce_broadcast_dims",
+]
+
+
+# ---------------------------------------------------------------------------
+# torch.compile registration for all 71 concrete ops
+# ---------------------------------------------------------------------------
+
+# --- Unary ops: float-preserving output (1 + 17 + 8 + 1 = 27 ops) ---
+for _cls in [
+ ReluFwdOp,
+ # math (17)
+ ExpFwdOp, LogFwdOp, SqrtFwdOp, RsqrtFwdOp, AbsFwdOp, NegFwdOp, ReciprocalFwdOp, SignFwdOp,
+ SinFwdOp, CosFwdOp, FloorFwdOp, CeilFwdOp, RoundFwdOp, TruncFwdOp, ErfFwdOp, Log1pFwdOp, Expm1FwdOp,
+ # activations (8)
+ GeluFwdOp, SiluFwdOp, SigmoidFwdOp, TanhFwdOp, HardswishFwdOp, HardsigmoidFwdOp, MishFwdOp, SeluFwdOp,
+ # bitwise (1) -- output same dtype as input
+ BitwiseNotFwdOp,
+]:
+ _register_unary_custom_op(_cls)
+
+# --- Unary ops: bool output (4 ops) ---
+for _cls in [LogicalNotFwdOp, IsnanFwdOp, IsinfFwdOp, IsfiniteFwdOp]:
+ _register_unary_custom_op(_cls, output_dtype_override=_torch.bool)
+
+# --- Binary ops: same-dtype output (10 + 3 = 13 ops) ---
+for _cls in [
+ # arithmetic (10)
+ AddFwdOp, SubFwdOp, MulFwdOp, DivFwdOp, RemainderFwdOp, PowFwdOp, FloorDivideFwdOp,
+ LerpFwdOp, MaximumFwdOp, MinimumFwdOp,
+ # bitwise (3)
+ BitwiseAndFwdOp, BitwiseOrFwdOp, BitwiseXorFwdOp,
+]:
+ _register_binary_custom_op(_cls)
+
+# --- Binary ops: bool output (comparison 6 + logical 2 = 8 ops) ---
+for _cls in [
+ EqFwdOp, NeFwdOp, GtFwdOp, LtFwdOp, GeFwdOp, LeFwdOp,
+ LogicalAndFwdOp, LogicalOrFwdOp,
+]:
+ _register_binary_custom_op(_cls, output_bool=True)
+
+# --- Fused gated ops (3 ops) ---
+for _cls in [SiluAndMulFwdOp, GeluAndMulFwdOp, GeluTanhAndMulFwdOp]:
+ _register_fused_gated_custom_op(_cls)
+
+# --- Independent unary-like ops (6 ops: x -> y with baked params) ---
+# ClampScalarFwdOp is the scalar-bound clamp (single-tensor input + min/max
+# baked into __init__). The Tensor-bound ClampFwdOp / ClampMinFwdOp /
+# ClampMaxFwdOp variants register their own multi-input custom_ops below.
+for _cls in [
+ LeakyReluFwdOp, EluFwdOp, HardtanhFwdOp, SoftplusFwdOp, ClampScalarFwdOp,
+ NanToNumFwdOp,
+]:
+ _register_unary_custom_op(_cls)
+
+# --- Inplace companions for activations declaring ``inplace`` ---
+# Each leaf below has ``inplace`` in its manifest signature. Register a
+# parallel ``_wrapped_inplace`` custom op with ``mutates_args=("x",)``
+# so ``forward(input)`` with ``self.inplace=True`` traces correctly
+# under ``torch.compile``.
+for _cls in [
+ ReluFwdOp, SiluFwdOp, HardswishFwdOp, HardsigmoidFwdOp, MishFwdOp,
+ SeluFwdOp, LeakyReluFwdOp, EluFwdOp, HardtanhFwdOp,
+]:
+ _register_unary_inplace_custom_op(_cls)
+
+# --- PReLU op (1 op: x, weight -> y) ---
+_register_prelu_custom_op(PreluFwdOp)
+
+# --- Tensor-bound clamp variants (3 ops: multi-tensor inputs -> out) ---
+# Registered under distinct custom_op namespaces from ClampScalarFwdOp:
+# ``top::elementwise_clamp_tensor`` (Optional Tensor min/max),
+# ``top::elementwise_clamp_min`` and ``top::elementwise_clamp_max`` for the
+# single-bound variants. register_fake is broadcast-aware so
+# torch.compile(fullgraph=True) traces correctly for both same-shape and
+# broadcasting inputs.
+_register_clamp_tensor_custom_op(ClampFwdOp)
+_register_clamp_min_custom_op(ClampMinFwdOp)
+_register_clamp_max_custom_op(ClampMaxFwdOp)
+
+# --- MaskedFill variants (input, mask[, value] -> out) ---
+# Both register broadcast-aware fake functions so torch.compile(fullgraph=True)
+# works for same-shape and broadcasting inputs. The Tensor-value variant is
+# registered under a distinct ``_tensor_value`` namespace to avoid colliding
+# with the scalar variant's ``top::elementwise_masked_fill``.
+_register_masked_fill_custom_op(MaskedFillScalarFwdOp)
+_register_masked_fill_tensor_value_custom_op(MaskedFillFwdOp)
+
+# --- Where op (1 op: cond, x, y -> out) ---
+# The fake function is broadcast-aware so torch.compile(fullgraph=True)
+# traces correctly for both same-shape and broadcasting inputs.
+_register_where_custom_op(WhereFwdOp)
+
+# --- Tensor-weight lerp (1 op: input, end, weight -> out) ---
+# Registered under ``top::elementwise_lerp_tensor`` to avoid colliding with
+# the scalar ``LerpFwdOp``'s ``top::elementwise_binary_lerp`` namespace. The fake
+# function is broadcast-aware so torch.compile(fullgraph=True) traces
+# correctly for both same-shape and broadcasting inputs.
+_register_lerp_tensor_custom_op(LerpTensorFwdOp)
+
+# --- Generative ops (2 ops: no tensor input -> out) ---
+_register_generative_custom_op(
+ AlibiFwdOp,
+ out_shape_fn=lambda carrier, num_heads, seq_len: carrier.new_empty(
+ (num_heads, seq_len, seq_len),
+ ),
+)
+_register_generative_custom_op(
+ SinusoidalFwdOp,
+ out_shape_fn=lambda carrier, seq_len, d_model: carrier.new_empty(
+ (seq_len, d_model),
+ ),
+)
+
+# Clean up loop variable
+del _cls
diff --git a/tileops/ops/elementwise/_base.py b/tileops/ops/elementwise/_base.py
new file mode 100644
index 00000000..445b49eb
--- /dev/null
+++ b/tileops/ops/elementwise/_base.py
@@ -0,0 +1,1197 @@
+"""Elementwise op infrastructure: umbrella bases, helpers, registration factories.
+
+Three umbrella Op base classes:
+- UnaryOp: wraps UnaryKernel with reshape/flatten
+- BinaryOp: wraps BinaryKernel with broadcast coalescing
+- FusedGatedOp: wraps FusedGatedKernel with (M, 2N) layout
+
+torch.compile support:
+- Concrete ops are registered via @torch.library.custom_op at package load time
+- A family of ``_register_*_custom_op`` factory functions (unary, binary,
+  fused-gated, prelu, where, clamp, masked-fill, generative, ...) registers
+  every op; instances are looked up at runtime via _OP_REGISTRY keyed by id(instance)
+
+Utility:
+- coalesce_broadcast_dims: reduces N-dim broadcast to minimal effective dims
+"""
+
+import inspect
+import math
+import weakref
+from math import prod
+from typing import Callable, Dict, List, Optional
+
+import torch
+
+from tileops.kernels.kernel_base import Kernel
+
+from ..op_base import Op
+
+# ---------------------------------------------------------------------------
+# torch.compile registration factories
+#
+# Each factory creates a @torch.library.custom_op + register_fake pair.
+# Instances register themselves in _OP_REGISTRY keyed by integer id.
+# The custom_op receives this key and looks up the instance to call the
+# pre-built tilelang kernel. The key is a plain int so dynamo can trace
+# through forward() without hitting unsupported Python side-effects.
+# ---------------------------------------------------------------------------
+
+_OP_REGISTRY: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
+
+_FP8_NONSAT_OUTPUT_DTYPES = {
+ torch.float8_e5m2: torch.float16,
+}
+
+def _effective_scalar_kernel_dtype(dtype: torch.dtype) -> torch.dtype:
+ """Return the dtype used when scalar literals are materialized in kernels."""
+ return _FP8_NONSAT_OUTPUT_DTYPES.get(dtype, dtype)
+
+
+def _validate_scalar_param_repr(
+ param_name: str, value: float, dtype: torch.dtype, op_name: str,
+) -> None:
+ """Reject scalar params that cannot be represented in the user dtype.
+
+ Validation targets the *user-facing* ``dtype`` rather than the
+ intermediate ``_effective_scalar_kernel_dtype(dtype)``. For fp8
+ dtypes the kernel runs in fp16 to preserve Inf/NaN, but a value that
+ only fits in fp16 would surface as ``+/-Inf`` after the final fp8
+ post-cast. Validating against the user dtype keeps explicit
+ replacements finite end-to-end.
+ """
+ if not isinstance(value, (int, float)):
+ raise TypeError(f"{op_name} expected scalar {param_name} to be int/float, got {type(value)}")
+
+ finfo = torch.finfo(dtype)
+ value_f64 = float(value)
+ if math.isnan(value_f64):
+ return
+ if math.isinf(value_f64):
+ raise ValueError(
+ f"{op_name} received {param_name}={value!r}, but {param_name} must be finite and "
+ f"representable in dtype {dtype}"
+ )
+ if not (finfo.min <= value_f64 <= finfo.max):
+ raise ValueError(
+ f"{op_name} received {param_name}={value!r}, which is not representable in "
+ f"dtype {dtype} (valid finite range: "
+ f"[{finfo.min}, {finfo.max}])"
+ )
+
+
+def _register_unary_custom_op(op_cls, output_dtype_override=None):
+ """Register a unary elementwise op for torch.compile.
+
+ Args:
+ op_cls: The Op subclass to register (must have ``_op_name``).
+ output_dtype_override: If set, the output dtype (e.g. torch.bool for predicates).
+ """
+ op_name = op_cls._op_name
+
+ @torch.library.custom_op(f"top::elementwise_unary_{op_name}", mutates_args=())
+ def _wrapped(x: torch.Tensor, instance_key: int) -> torch.Tensor:
+ instance = _OP_REGISTRY[instance_key]
+ return instance._eager_forward(x)
+
+ @_wrapped.register_fake
+ def _(x: torch.Tensor, instance_key: int) -> torch.Tensor:
+ out_dtype = output_dtype_override if output_dtype_override is not None else x.dtype
+ return torch.empty_like(x, dtype=out_dtype)
+
+ op_cls._wrapped = _wrapped
+
+
+def _register_unary_inplace_custom_op(op_cls):
+ """Register the ``inplace=True`` companion for a unary activation op.
+
+ The kernel writes into a fresh buffer; this wrapper copies the result
+ back into ``x`` and returns ``x`` so the caller sees ``y is x`` and
+ ``x`` carries the activation output. The custom op is registered with
+ ``mutates_args=("x",)`` so ``torch.compile`` traces the mutation
+ correctly. Sets ``op_cls._wrapped_inplace`` for ``forward()`` to
+ dispatch through.
+ """
+ op_name = op_cls._op_name
+
+ @torch.library.custom_op(
+ f"top::elementwise_unary_{op_name}_inplace", mutates_args=("x",),
+ )
+ def _wrapped_inplace(x: torch.Tensor, instance_key: int) -> None:
+ instance = _OP_REGISTRY[instance_key]
+ result = instance._eager_forward(x)
+ x.copy_(result.reshape(x.shape))
+
+ op_cls._wrapped_inplace = _wrapped_inplace
+
+
+def _register_binary_custom_op(op_cls, output_bool: bool = False):
+ """Register a binary elementwise op for torch.compile.
+
+ Args:
+ op_cls: The Op subclass to register.
+ output_bool: If True, output dtype is torch.bool (for comparison/logical ops).
+ """
+ op_name = op_cls._op_name
+
+ @torch.library.custom_op(f"top::elementwise_binary_{op_name}", mutates_args=())
+ def _wrapped(
+ a: torch.Tensor,
+ b: torch.Tensor,
+ out_shape: List[int],
+ instance_key: int,
+ ) -> torch.Tensor:
+ instance = _OP_REGISTRY[instance_key]
+ return instance._eager_forward(a, b)
+
+ @_wrapped.register_fake
+ def _(
+ a: torch.Tensor,
+ b: torch.Tensor,
+ out_shape: List[int],
+ instance_key: int,
+ ) -> torch.Tensor:
+ out_dtype = torch.bool if output_bool else a.dtype
+ return a.new_empty(out_shape, dtype=out_dtype)
+
+ op_cls._wrapped = _wrapped
+
+
+def _register_prelu_custom_op(op_cls):
+ """Register a PReLU-style op (x, weight -> y) for torch.compile."""
+ op_name = op_cls._op_name
+
+ @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
+ def _wrapped(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ instance_key: int,
+ ) -> torch.Tensor:
+ instance = _OP_REGISTRY[instance_key]
+ return instance._eager_forward(x, weight)
+
+ @_wrapped.register_fake
+ def _(
+ x: torch.Tensor,
+ weight: torch.Tensor,
+ instance_key: int,
+ ) -> torch.Tensor:
+ return torch.empty_like(x)
+
+ op_cls._wrapped = _wrapped
+
+
+def _register_where_custom_op(op_cls):
+ """Register a where-style op (cond, x, y -> out) for torch.compile.
+
+ The fake function computes the broadcast output shape from
+ ``cond`` / ``x`` / ``y`` so that ``torch.compile(fullgraph=True)``
+ works for both same-shape and broadcasting inputs.
+ """
+ op_name = op_cls._op_name
+
+ @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
+ def _wrapped(
+ cond: torch.Tensor,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ instance_key: int,
+ ) -> torch.Tensor:
+ instance = _OP_REGISTRY[instance_key]
+ return instance._eager_forward(cond, x, y)
+
+ @_wrapped.register_fake
+ def _(
+ cond: torch.Tensor,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ instance_key: int,
+ ) -> torch.Tensor:
+ out_shape = torch.broadcast_shapes(cond.shape, x.shape, y.shape)
+ return x.new_empty(out_shape)
+
+ op_cls._wrapped = _wrapped
+
+
+def _register_lerp_tensor_custom_op(op_cls):
+ """Register a Tensor-weight lerp op (input, end, weight -> out).
+
+ The fake function computes the broadcast output shape from ``input`` /
+ ``end`` / ``weight`` so that ``torch.compile(fullgraph=True)`` works
+ for both same-shape and broadcasting inputs. Registered under a
+ distinct ``_tensor`` namespace to avoid colliding with the scalar
+ ``LerpFwdOp`` (which bakes ``weight`` at construction time and uses
+ the binary registration path).
+ """
+ op_name = op_cls._op_name
+
+ @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
+ def _wrapped(
+ input: torch.Tensor, # noqa: A002 — manifest-aligned PyTorch param name
+ end: torch.Tensor,
+ weight: torch.Tensor,
+ instance_key: int,
+ ) -> torch.Tensor:
+ instance = _OP_REGISTRY[instance_key]
+ return instance._eager_forward(input, end, weight)
+
+ @_wrapped.register_fake
+ def _(
+ input: torch.Tensor, # noqa: A002
+ end: torch.Tensor,
+ weight: torch.Tensor,
+ instance_key: int,
+ ) -> torch.Tensor:
+ out_shape = torch.broadcast_shapes(input.shape, end.shape, weight.shape)
+ return input.new_empty(out_shape)
+
+ op_cls._wrapped = _wrapped
+
+
+def _register_masked_fill_custom_op(op_cls):
+ """Register a masked-fill-style op (x, mask -> y) for torch.compile.
+
+ The fake function computes the bidirectional broadcast output shape
+ of ``x`` and ``mask`` so ``torch.compile(fullgraph=True)`` works for
+ both same-shape and broadcasting inputs.
+ """
+ op_name = op_cls._op_name
+
+ @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
+ def _wrapped(
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ instance_key: int,
+ ) -> torch.Tensor:
+ instance = _OP_REGISTRY[instance_key]
+ return instance._eager_forward(x, mask)
+
+ @_wrapped.register_fake
+ def _(
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ instance_key: int,
+ ) -> torch.Tensor:
+ out_shape = torch.broadcast_shapes(x.shape, mask.shape)
+ return x.new_empty(out_shape)
+
+ op_cls._wrapped = _wrapped
+
+
+def _register_masked_fill_tensor_value_custom_op(op_cls):
+ """Register a masked-fill (Tensor value) op (input, mask, value -> out).
+
+ The fake function computes the broadcast output shape of ``input`` and
+ ``mask`` (``value`` is a 0-dim Tensor). Registered under a distinct
+ namespace from the scalar masked_fill variant to avoid collision.
+ """
+ op_name = op_cls._op_name
+
+ @torch.library.custom_op(
+ f"top::elementwise_{op_name}_tensor_value", mutates_args=(),
+ )
+ def _wrapped(
+ input: torch.Tensor, # noqa: A002
+ mask: torch.Tensor,
+ value: torch.Tensor,
+ instance_key: int,
+ ) -> torch.Tensor:
+ instance = _OP_REGISTRY[instance_key]
+ return instance._eager_forward(input, mask, value)
+
+ @_wrapped.register_fake
+ def _(
+ input: torch.Tensor, # noqa: A002
+ mask: torch.Tensor,
+ value: torch.Tensor,
+ instance_key: int,
+ ) -> torch.Tensor:
+ out_shape = torch.broadcast_shapes(input.shape, mask.shape)
+ return input.new_empty(out_shape)
+
+ op_cls._wrapped = _wrapped
+
+
+def _register_clamp_tensor_custom_op(op_cls):
+ """Register a Tensor-bound clamp op (input, min?, max? -> out).
+
+ ``min`` and ``max`` are each ``Optional[Tensor]``; the schema is
+ inferred by ``torch.library.custom_op`` from the ``Optional[torch.Tensor]``
+ annotations, producing ``Tensor? min, Tensor? max`` in the underlying
+ custom-op schema. The fake function computes the broadcast output
+ shape of all non-``None`` operands so ``torch.compile(fullgraph=True)``
+ works for both same-shape and broadcasting inputs. Registered under
+ a distinct ``_tensor`` namespace from the scalar-bound clamp variant.
+ """
+ op_name = op_cls._op_name
+
+ @torch.library.custom_op(
+ f"top::elementwise_{op_name}_tensor", mutates_args=(),
+ )
+ def _wrapped(
+ input: torch.Tensor, # noqa: A002
+ min: Optional[torch.Tensor], # noqa: A002
+ max: Optional[torch.Tensor], # noqa: A002
+ instance_key: int,
+ ) -> torch.Tensor:
+ instance = _OP_REGISTRY[instance_key]
+ return instance._eager_forward(input, min, max)
+
+ @_wrapped.register_fake
+ def _(
+ input: torch.Tensor, # noqa: A002
+ min: Optional[torch.Tensor], # noqa: A002
+ max: Optional[torch.Tensor], # noqa: A002
+ instance_key: int,
+ ) -> torch.Tensor:
+ shapes = [input.shape]
+ if min is not None:
+ shapes.append(min.shape)
+ if max is not None:
+ shapes.append(max.shape)
+ out_shape = torch.broadcast_shapes(*shapes)
+ return input.new_empty(out_shape)
+
+ op_cls._wrapped = _wrapped
+
+
+def _register_clamp_min_custom_op(op_cls):
+ """Register single-bound Tensor lower-clamp (input, min -> out)."""
+ op_name = op_cls._op_name
+
+ @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
+ def _wrapped(
+ input: torch.Tensor, # noqa: A002
+ min: torch.Tensor, # noqa: A002
+ instance_key: int,
+ ) -> torch.Tensor:
+ instance = _OP_REGISTRY[instance_key]
+ return instance._eager_forward(input, min)
+
+ @_wrapped.register_fake
+ def _(
+ input: torch.Tensor, # noqa: A002
+ min: torch.Tensor, # noqa: A002
+ instance_key: int,
+ ) -> torch.Tensor:
+ out_shape = torch.broadcast_shapes(input.shape, min.shape)
+ return input.new_empty(out_shape)
+
+ op_cls._wrapped = _wrapped
+
+
+def _register_clamp_max_custom_op(op_cls):
+ """Register single-bound Tensor upper-clamp (input, max -> out)."""
+ op_name = op_cls._op_name
+
+ @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
+ def _wrapped(
+ input: torch.Tensor, # noqa: A002
+ max: torch.Tensor, # noqa: A002
+ instance_key: int,
+ ) -> torch.Tensor:
+ instance = _OP_REGISTRY[instance_key]
+ return instance._eager_forward(input, max)
+
+ @_wrapped.register_fake
+ def _(
+ input: torch.Tensor, # noqa: A002
+ max: torch.Tensor, # noqa: A002
+ instance_key: int,
+ ) -> torch.Tensor:
+ out_shape = torch.broadcast_shapes(input.shape, max.shape)
+ return input.new_empty(out_shape)
+
+ op_cls._wrapped = _wrapped
+
+
+def _register_generative_custom_op(op_cls, out_shape_fn):
+ """Register a generative op (no tensor input -> out) for torch.compile.
+
+ A scalar ``device_carrier`` tensor is passed so that ``register_fake``
+ can derive the correct device and dtype from a real tensor reference,
+ which is required by the torch.compile tracing infrastructure.
+
+ Args:
+ op_cls: The Op subclass to register.
+ out_shape_fn: Callable(carrier, num_a, num_b) -> Tensor returning
+ the output metadata so register_fake can produce the right shape.
+ """
+ op_name = op_cls._op_name
+
+ @torch.library.custom_op(f"top::elementwise_{op_name}", mutates_args=())
+ def _wrapped(
+ device_carrier: torch.Tensor,
+ num_a: int,
+ num_b: int,
+ instance_key: int,
+ ) -> torch.Tensor:
+ instance = _OP_REGISTRY[instance_key]
+ return instance._eager_forward()
+
+ @_wrapped.register_fake
+ def _(
+ device_carrier: torch.Tensor,
+ num_a: int,
+ num_b: int,
+ instance_key: int,
+ ) -> torch.Tensor:
+ return out_shape_fn(device_carrier, num_a, num_b)
+
+ op_cls._wrapped = _wrapped
+
+
+def _register_fused_gated_custom_op(op_cls):
+ """Register a fused gated elementwise op for torch.compile.
+
+ Args:
+ op_cls: The Op subclass to register.
+ """
+ op_name = op_cls._op_name
+
+ @torch.library.custom_op(f"top::elementwise_fused_gated_{op_name}", mutates_args=())
+ def _wrapped(
+ x: torch.Tensor,
+ M: int,
+ N: int,
+ instance_key: int,
+ ) -> torch.Tensor:
+ instance = _OP_REGISTRY[instance_key]
+ return instance._eager_forward(x)
+
+ @_wrapped.register_fake
+ def _(
+ x: torch.Tensor,
+ M: int,
+ N: int,
+ instance_key: int,
+ ) -> torch.Tensor:
+ return x.new_empty((M, N), dtype=x.dtype)
+
+ op_cls._wrapped = _wrapped
+
+
+def coalesce_broadcast_dims(a_shape, b_shape):
+ """Coalesce N-dim broadcast into minimal effective dimensions.
+
+ Merges adjacent dimensions that have the same broadcast behaviour
+ (both real or both broadcast) to minimise the number of divmod
+ operations inside the kernel loop.
+
+ Args:
+ a_shape: Shape tuple of input a.
+ b_shape: Shape tuple of input b.
+
+ Returns:
+ Tuple of (out_shape, coalesced_shape, a_strides, b_strides) where
+ strides use 0 for broadcast dimensions.
+ """
+ # Normalise scalar (0-dim) inputs to 1-dim with size 1
+ if len(a_shape) == 0:
+ a_shape = (1,)
+ if len(b_shape) == 0:
+ b_shape = (1,)
+
+ out_shape = torch.broadcast_shapes(a_shape, b_shape)
+ ndim = len(out_shape)
+ a_pad = (1,) * (ndim - len(a_shape)) + tuple(a_shape)
+ b_pad = (1,) * (ndim - len(b_shape)) + tuple(b_shape)
+
+ def _make_strides(padded_shape):
+ strides = [1] * ndim
+ for i in range(ndim - 2, -1, -1):
+ strides[i] = strides[i + 1] * padded_shape[i + 1]
+ # Only zero strides for genuinely broadcast dims (size-1 expanded to >1)
+ return [
+ 0 if padded_shape[i] == 1 and out_shape[i] > 1 else strides[i]
+ for i in range(ndim)
+ ]
+
+ a_raw = _make_strides(a_pad)
+ b_raw = _make_strides(b_pad)
+
+ # Coalesce adjacent dims with compatible broadcast patterns
+ groups = [(out_shape[0], a_raw[0], b_raw[0])]
+ for i in range(1, ndim):
+ prev_out, prev_as, prev_bs = groups[-1]
+ a_can = (a_raw[i] == 0 and prev_as == 0) or (
+ a_raw[i] != 0 and prev_as == a_raw[i] * out_shape[i]
+ )
+ b_can = (b_raw[i] == 0 and prev_bs == 0) or (
+ b_raw[i] != 0 and prev_bs == b_raw[i] * out_shape[i]
+ )
+ if a_can and b_can:
+ groups[-1] = (prev_out * out_shape[i], a_raw[i], b_raw[i])
+ else:
+ groups.append((out_shape[i], a_raw[i], b_raw[i]))
+
+ # Remove trivial size-1 groups (unless all trivial)
+ groups = [g for g in groups if g[0] > 1] or [(1, 0, 0)]
+ coalesced_shape = tuple(g[0] for g in groups)
+ a_strides = tuple(g[1] for g in groups)
+ b_strides = tuple(g[2] for g in groups)
+ return out_shape, coalesced_shape, a_strides, b_strides
+
+
+def _apply_fp8_post_cast(result: torch.Tensor, kernel) -> torch.Tensor:
+ """Apply fp8 output cast if the kernel requires it.
+
+ For e5m2 dtypes the kernel produces fp16 output to preserve Inf/NaN;
+ this helper performs the final non-saturating cast via PyTorch.
+ """
+ fp8_out = getattr(kernel, "_fp8_output_dtype", None)
+ if fp8_out is not None:
+ return result.to(fp8_out)
+ return result
+
+
+_FP8_DTYPES = (torch.float8_e4m3fn, torch.float8_e5m2)
+
+
+def _is_fp8(dtype: torch.dtype) -> bool:
+ """Return True iff ``dtype`` is one of the supported fp8 dtypes."""
+ return dtype in _FP8_DTYPES
+
+
+def _fp8_compute_dtype(dtype: torch.dtype) -> torch.dtype:
+ """Return the compute dtype used to emulate fp8 elementwise fallbacks.
+
+ PyTorch's CUDA backend does not implement ``clamp``/``maximum``/
+ ``minimum``/``masked_fill_`` on Float8 tensors (raises NotImplementedError
+ on ``clamp_cuda`` / ``max_elementwise_cuda`` / ``min_elementwise_cuda`` /
+ ``masked_fill_``). Both e4m3fn (finite range ±448) and e5m2 (finite range
+ ±57344) fit in fp16, so we upcast to fp16, run the op, and cast back. The
+ final cast preserves Inf/NaN for e5m2 (PyTorch's fp16->e5m2 conversion is
+ non-saturating) and saturates for e4m3fn (matching PyTorch's default
+ fp16->e4m3fn behaviour).
+ """
+ if not _is_fp8(dtype):
+ raise ValueError(f"_fp8_compute_dtype expects an fp8 dtype, got {dtype}")
+ return torch.float16
+
+
+class UnaryOp(Op):
+ """Template base class for unary elementwise ops.
+
+ Subclass must set ``kernel_cls`` and ``_op_name``.
+ Subclass should also set ``_wrapped`` via ``_register_unary_custom_op``
+ to enable torch.compile support.
+
+ Args:
+ N_total: Total number of elements (flattened).
+ dtype: Torch dtype.
+ strategy: Kernel strategy override.
+ kernel_map: Optional kernel dispatch override.
+ tune: Whether to autotune.
+ """
+
+ kernel_cls: type
+ _op_name: str
+ _wrapped = None # Set by _register_unary_custom_op at class definition
+ # Per-element FLOP count, matching the manifest's ``roofline.flops``
+ # coefficient on ``N``. Subclasses override when the op is more than one
+ # arithmetic op per element (e.g. ``sigmoid`` ≈ 4, ``tanh`` ≈ 5). The
+ # base class default of 1 covers the common ``flops: "N"`` entries.
+ FLOPS_PER_ELEM: int = 1
+
+ def __init__(
+ self,
+ N_total: int,
+ dtype: torch.dtype,
+ strategy: Optional[str] = None,
+ kernel_map: Optional[Dict[str, Kernel]] = None,
+ tune: bool = False,
+ ):
+ self.N_total = N_total
+ self.dtype = dtype
+ self.strategy = strategy
+ self.dispatch_kernel(kernel_map)
+ self.kernel = self.kernel_map[self._op_name](
+ N_total, dtype, strategy=strategy, tune=tune,
+ )
+ # Use _fp8_output_dtype (the final dtype after Op-layer post-cast)
+ # rather than kernel.output_dtype (which is fp16 for e5m2).
+ fp8_out = getattr(self.kernel, "_fp8_output_dtype", None)
+ self.output_dtype = fp8_out or getattr(self.kernel, "output_dtype", dtype)
+ # Register in global registry for torch.compile dispatch
+ self._instance_key = id(self)
+ _OP_REGISTRY[self._instance_key] = self
+
+ @property
+ def default_kernel_map(self) -> Dict[str, Kernel]:
+ return {self._op_name: self.kernel_cls}
+
+ @property
+ def total_memory(self) -> float:
+ """Read x + write y."""
+ return self.N_total * (self.dtype.itemsize + self.output_dtype.itemsize)
+
+ def eval_roofline(self) -> tuple[int, int]:
+ """Return ``(flops, bytes)`` for this unary elementwise op instance.
+
+ Mirrors the elementwise_unary_math manifest roofline:
+ ``flops = FLOPS_PER_ELEM * N`` and
+ ``bytes = N * input_elem_bytes + N * output_elem_bytes``. Subclasses
+ whose manifest entry uses a higher coefficient (e.g. ``sigmoid`` →
+ ``4 * N``, ``tanh`` → ``5 * N``) override ``FLOPS_PER_ELEM``. For ops
+ whose output dtype matches the input (e.g. ``neg``, ``abs``), bytes
+ collapse to ``2 * N * elem_bytes``; for ops with a smaller output
+ dtype (e.g. ``isnan`` / ``isinf`` / ``isfinite`` / ``logical_not`` →
+ bool), ``self.output_dtype.itemsize`` already captures the difference.
+ """
+ return self.FLOPS_PER_ELEM * self.N_total, int(self.total_memory)
+
+ def _eager_forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002
+ """Direct kernel call for use inside custom_op implementation."""
+ orig_shape = input.shape
+ flat = input.contiguous().reshape(-1)
+ result = self.kernel(flat).reshape(orig_shape)
+ # For e5m2: kernel produces fp16 to preserve Inf/NaN;
+ # cast to e5m2 here using PyTorch's non-saturating conversion.
+ return _apply_fp8_post_cast(result, self.kernel)
+
+ def _validate_input(self, input: torch.Tensor) -> None: # noqa: A002
+ """Validate input tensor against the op's dtype / numel contract."""
+ if not input.is_cuda:
+ raise ValueError("Input must be a CUDA tensor")
+ if input.dtype != self.dtype:
+ raise ValueError(
+ f"Expected input.dtype {self.dtype}, got {input.dtype}"
+ )
+ if input.numel() != self.N_total:
+ raise ValueError(
+ f"Expected {self.N_total} elements, got {input.numel()}"
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002
+ self._validate_input(input)
+ wrapped = type(self)._wrapped
+ if wrapped is not None:
+ return wrapped(input, self._instance_key)
+ return self._eager_forward(input)
+
+
+class BinaryOp(Op):
+    """Template base class for binary elementwise ops with broadcast.
+
+    Subclass must set ``kernel_cls`` and ``_op_name``.
+    Subclass should also set ``_wrapped`` via ``_register_binary_custom_op``
+    to enable torch.compile support.
+
+    Args:
+        a_shape: Shape of input a.
+        b_shape: Shape of input b.
+        dtype: Torch dtype.
+        strategy: Kernel strategy override.
+        kernel_map: Optional kernel dispatch override.
+        tune: Whether to autotune.
+    """
+
+    kernel_cls: type
+    _op_name: str
+    _wrapped = None  # Set by _register_binary_custom_op at class definition
+    # Subclasses may set ``_other_name`` to a manifest-aligned parameter
+    # name (e.g. ``"exponent"`` for ``PowFwdOp``, ``"end"`` for
+    # ``LerpFwdOp``); the L1 signature check sees the renamed parameter
+    # via ``__init_subclass__`` rebinding ``forward.__signature__``.
+    _other_name: str = "other"
+
+    def __init_subclass__(cls, **kwargs):
+        # Rebind the subclass's ``forward`` so its advertised signature
+        # (and keyword spelling) uses the manifest-aligned name instead
+        # of the generic ``other``.
+        super().__init_subclass__(**kwargs)
+        other_name = cls.__dict__.get("_other_name")
+        if other_name is None or other_name == "other":
+            return
+        base_forward = cls.forward
+        try:
+            sig = inspect.signature(base_forward)
+        except (ValueError, TypeError):
+            return
+        new_params = [
+            p.replace(name=other_name) if p.name == "other" else p
+            for p in sig.parameters.values()
+        ]
+        new_sig = sig.replace(parameters=new_params)
+
+        # NOTE: only the *keyword* spelling is translated here; callers
+        # that pass the second operand positionally bypass the rename,
+        # which is harmless because positional order is unchanged.
+        def forward(self, *args, **kwargs):
+            if other_name in kwargs:
+                kwargs["other"] = kwargs.pop(other_name)
+            return base_forward(self, *args, **kwargs)
+
+        forward.__signature__ = new_sig
+        forward.__name__ = "forward"
+        forward.__qualname__ = f"{cls.__qualname__}.forward"
+        # NOTE(review): ``base_forward.__doc__`` is not copied onto the
+        # wrapper; copy it if help()/doc tooling needs it.
+        cls.forward = forward
+
+    def __init__(
+        self,
+        a_shape: tuple,
+        b_shape: tuple,
+        dtype: torch.dtype,
+        strategy: Optional[str] = None,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+    ):
+        # Reject unsupported dtypes before any state is built.
+        kernel_supported = self.kernel_cls.SUPPORTED_DTYPES
+        if kernel_supported is not None and dtype not in kernel_supported:
+            names = ", ".join(str(dt) for dt in kernel_supported)
+            raise ValueError(
+                f"{self._op_name} does not support dtype {dtype}. "
+                f"Supported: [{names}]"
+            )
+        self.dtype = dtype
+        self.a_shape = tuple(a_shape)
+        self.b_shape = tuple(b_shape)
+        self.strategy = strategy
+        # Collapse adjacent broadcast-compatible dims so the kernel sees
+        # the smallest equivalent shape/stride description.
+        out_shape, coalesced_shape, a_strides, b_strides = coalesce_broadcast_dims(
+            a_shape, b_shape,
+        )
+        self.out_shape = out_shape
+        self._out_shape_list = list(out_shape)  # cached for custom_op hot path
+        self.N_total = prod(out_shape)
+        self.a_numel = prod(a_shape)
+        self.b_numel = prod(b_shape)
+        self.dispatch_kernel(kernel_map)
+        self.kernel = self.kernel_map[self._op_name](
+            self.N_total, dtype, coalesced_shape, a_strides, b_strides,
+            self.a_numel, self.b_numel, strategy=strategy, tune=tune,
+        )
+        # Register in global registry for torch.compile dispatch
+        # The registry holds a strong reference: instances stay alive for
+        # the process lifetime, so id(self) keys are never reused while
+        # registered.
+        self._instance_key = id(self)
+        _OP_REGISTRY[self._instance_key] = self
+
+    @property
+    def default_kernel_map(self) -> Dict[str, Kernel]:
+        return {self._op_name: self.kernel_cls}
+
+    @property
+    def total_memory(self) -> float:
+        """Read a + read b + write y."""
+        in_elem = self.dtype.itemsize
+        # FP8 paths write a narrower output element than the input dtype.
+        fp8_out = getattr(self.kernel, "_fp8_output_dtype", None)
+        out_elem = fp8_out.itemsize if fp8_out is not None else in_elem
+        return (self.a_numel + self.b_numel) * in_elem + self.N_total * out_elem
+
+    def _eager_forward(
+        self,
+        input: torch.Tensor, # noqa: A002 — manifest-aligned PyTorch param name
+        other: torch.Tensor,
+    ) -> torch.Tensor:
+        """Direct kernel call for use inside custom_op implementation."""
+        result = self.kernel(
+            input.contiguous().view(-1), other.contiguous().view(-1),
+        ).reshape(self.out_shape)
+        return _apply_fp8_post_cast(result, self.kernel)
+
+    def forward(
+        self,
+        input: torch.Tensor, # noqa: A002 — manifest-aligned PyTorch param name
+        other: torch.Tensor,
+    ) -> torch.Tensor:
+        """Validate both operands, then dispatch to custom op or eager path.
+
+        Raises:
+            ValueError: On non-CUDA inputs, dtype mismatch, or element
+                count mismatch against the construction-time shapes.
+        """
+        a_name = getattr(self, "_input_name", "input")
+        b_name = getattr(self, "_other_name", "other")
+        if not input.is_cuda or not other.is_cuda:
+            raise ValueError("Inputs must be CUDA tensors")
+        if input.dtype != self.dtype:
+            raise ValueError(f"Expected {a_name}.dtype {self.dtype}, got {input.dtype}")
+        if other.dtype != self.dtype:
+            raise ValueError(f"Expected {b_name}.dtype {self.dtype}, got {other.dtype}")
+        # Only element counts are validated — the kernel consumes
+        # flattened views with strides precomputed from a_shape/b_shape.
+        # NOTE(review): a same-numel, differently-shaped tensor (e.g. a
+        # transpose) passes this check; confirm callers guarantee layout.
+        if input.numel() != self.a_numel:
+            raise ValueError(
+                f"Expected {a_name} to have {self.a_numel} elements, got {input.numel()}"
+            )
+        if other.numel() != self.b_numel:
+            raise ValueError(
+                f"Expected {b_name} to have {self.b_numel} elements, got {other.numel()}"
+            )
+        wrapped = type(self)._wrapped
+        if wrapped is not None:
+            return wrapped(input, other, self._out_shape_list, self._instance_key)
+        return self._eager_forward(input, other)
+
+
+class FusedGatedOp(Op):
+    """Template base class for fused gated elementwise ops.
+
+    Input: x of shape (M, 2*N). gate = x[:, :N], value = x[:, N:].
+    Output: y = activation(gate) * value, shape (M, N).
+
+    Subclass must set ``kernel_cls`` and ``_op_name``.
+    Subclass should also set ``_wrapped`` via ``_register_fused_gated_custom_op``
+    to enable torch.compile support.
+
+    Args:
+        M: Number of rows.
+        N: Half column dim (output width).
+        dtype: Torch dtype.
+        strategy: Kernel strategy override.
+        kernel_map: Optional kernel dispatch override.
+        tune: Whether to autotune.
+    """
+
+    kernel_cls: type
+    _op_name: str
+    _wrapped = None  # Set by _register_fused_gated_custom_op at class definition
+
+    def __init__(
+        self,
+        M: int,
+        N: int,
+        dtype: torch.dtype,
+        strategy: Optional[str] = None,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+    ):
+        # Reject unsupported dtypes before any state is built.
+        supported = self.kernel_cls.SUPPORTED_DTYPES
+        if supported is not None and dtype not in supported:
+            names = ", ".join(str(dt) for dt in supported)
+            raise ValueError(
+                f"{self._op_name} does not support dtype {dtype}. "
+                f"Supported: [{names}]"
+            )
+        self.M = M
+        self.N = N
+        self.dtype = dtype
+        self.strategy = strategy
+        self.dispatch_kernel(kernel_map)
+        self.kernel = self.kernel_map[self._op_name](
+            M, N, dtype, strategy=strategy, tune=tune,
+        )
+        # Register in global registry for torch.compile dispatch
+        self._instance_key = id(self)
+        _OP_REGISTRY[self._instance_key] = self
+
+    @property
+    def default_kernel_map(self) -> Dict[str, Kernel]:
+        return {self._op_name: self.kernel_cls}
+
+    @property
+    def total_memory(self) -> float:
+        """Read x (M*2N) + write y (M*N)."""
+        in_elem = self.dtype.itemsize
+        # FP8 paths write a narrower output element than the input dtype.
+        fp8_out = getattr(self.kernel, "_fp8_output_dtype", None)
+        out_elem = fp8_out.itemsize if fp8_out is not None else in_elem
+        return self.M * 2 * self.N * in_elem + self.M * self.N * out_elem
+
+    def _eager_forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Direct kernel call for use inside custom_op implementation."""
+        # contiguous() copies only when x is non-contiguous.
+        x = x.contiguous()
+        result = self.kernel(x)
+        return _apply_fp8_post_cast(result, self.kernel)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Validate ``x`` then dispatch to custom op or eager path."""
+        if not x.is_cuda:
+            raise ValueError("Input must be a CUDA tensor")
+        if x.dtype != self.dtype:
+            raise ValueError(f"Expected x.dtype {self.dtype}, got {x.dtype}")
+        # Exact shape match required — unlike BinaryOp there is no
+        # broadcasting for fused gated ops.
+        if x.shape != (self.M, 2 * self.N):
+            raise ValueError(
+                f"Expected shape ({self.M}, {2 * self.N}), got {tuple(x.shape)}"
+            )
+        wrapped = type(self)._wrapped
+        if wrapped is not None:
+            return wrapped(x, self.M, self.N, self._instance_key)
+        return self._eager_forward(x)
+
+
+# ---------------------------------------------------------------------------
+# Intermediate (private) base classes shared by leaf op modules
+# ---------------------------------------------------------------------------
+
+
+class _UnaryActivationMixin:
+    """Shared ``forward`` / inplace dispatch for unary activation Ops.
+
+    The ten unary activation Ops (six param-free: ReLU, SiLU, HardSwish,
+    HardSigmoid, Mish, SELU; four parametric: LeakyReLU, ELU, Hardtanh,
+    Softplus) share an identical ``forward`` template:
+
+    1. validate ``input`` against the op's ``dtype`` / ``N_total`` contract,
+    2. when ``self.inplace`` is true, dispatch through ``_wrapped_inplace``
+       (registered with ``mutates_args=("x",)`` so ``torch.compile`` traces
+       the mutation correctly) and return the original ``input`` so callers
+       see ``y is x``,
+    3. otherwise dispatch through the standard ``_wrapped`` custom op or
+       fall back to ``_eager_forward``.
+
+    Concrete classes provide ``_validate_input`` and ``_eager_forward``
+    (both inherited from ``UnaryOp``) plus ``self.inplace`` /
+    ``self._instance_key`` state. Leaves that do not expose ``inplace``
+    in their signature (e.g. Softplus) simply default ``self.inplace`` to
+    ``False`` via ``_finalize_init``.
+    """
+
+    # Set by ``_register_unary_inplace_custom_op`` for leaves that
+    # declare ``inplace`` in their manifest signature. Stays ``None``
+    # when the leaf does not support inplace (e.g. Softplus, or a
+    # test-only subclass that skipped registration).
+    _wrapped_inplace = None
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002
+        self._validate_input(input)
+        if self.inplace:
+            wrapped_inplace = type(self)._wrapped_inplace
+            if wrapped_inplace is not None:
+                wrapped_inplace(input, self._instance_key)
+                return input
+            # No inplace custom op registered (e.g. test-only subclass);
+            # fall back to direct mutation via the eager path.
+            # Note: this path computes out-of-place, then copies back —
+            # peak memory is 2x the inplace custom-op path.
+            result = self._eager_forward(input)
+            input.copy_(result.reshape(input.shape))
+            return input
+        wrapped = type(self)._wrapped
+        if wrapped is not None:
+            return wrapped(input, self._instance_key)
+        return self._eager_forward(input)
+
+
+class _ParamFreeActivationOp(_UnaryActivationMixin, UnaryOp):
+    """Shared base for the param-free activation Op group.
+
+    Centralizes the canonical constructor used by activations whose only
+    manifest-declared parameter is ``inplace`` (ReLU, SiLU, HardSwish,
+    HardSigmoid, Mish, SELU). Each leaf only declares its op-specific
+    class fields (``_op_name``, ``kernel_cls``, ``FLOPS_PER_ELEM``,
+    docstring); ``forward``/``_eager_forward`` come from
+    ``_UnaryActivationMixin`` / ``UnaryOp``.
+
+    Args:
+        N_total: Number of elements (flattened input).
+        dtype: Torch dtype.
+        strategy: Optional kernel strategy override.
+        kernel_map: Optional kernel dispatch override.
+        tune: Whether to autotune the kernel.
+        inplace: When True, ``forward`` mutates and returns its input.
+    """
+
+    def __init__(
+        self,
+        N_total: int,
+        dtype: torch.dtype,
+        *,
+        strategy: Optional[str] = None,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+        inplace: bool = False,
+    ):
+        # ``inplace`` is recorded after the base constructor so it cannot
+        # be clobbered by ``UnaryOp.__init__``.
+        super().__init__(
+            N_total, dtype, strategy=strategy, kernel_map=kernel_map, tune=tune,
+        )
+        self.inplace = inplace
+
+
+class _ParametricActivationOp(_UnaryActivationMixin, UnaryOp):
+ """Shared base for the parametric activation Op group.
+
+ Used by activations that take one or more scalar construction-time
+ parameters (LeakyReLU, ELU, Hardtanh, Softplus). Leaves own their
+ ``__init__`` (scalar parameter names and defaults vary per leaf):
+ each leaf validates its scalars, populates ``self.`` for
+ introspection, instantiates ``self.kernel`` with typed kwargs, and
+ registers itself with ``_OP_REGISTRY`` via the
+ ``_finalize_init`` helper. ``UnaryOp.__init__`` is intentionally
+ bypassed; ``_finalize_init`` performs the equivalent state setup.
+
+ Leaves that declare ``inplace`` in the manifest signature accept it
+ in ``__init__`` and pass it to ``_finalize_init``. ``forward`` and
+ ``_eager_forward`` are inherited from the mixin and ``UnaryOp``.
+ """
+
+ def _finalize_init(
+ self,
+ N_total: int,
+ dtype: torch.dtype,
+ kernel: Kernel,
+ *,
+ inplace: bool = False,
+ ) -> None:
+ """Record the leaf-built kernel and wire shared base state.
+
+ The leaf has already called ``self.dispatch_kernel(kernel_map)``
+ and instantiated its kernel directly with typed kwargs. This
+ helper records the kernel on ``self`` and runs the
+ ``_OP_REGISTRY`` registration shared by every parametric leaf.
+ """
+ self.N_total = N_total
+ self.dtype = dtype
+ self.inplace = inplace
+ self.kernel = kernel
+ # Mirror ``UnaryOp.__init__``: surface ``output_dtype`` so callers
+ # and ``total_memory`` can reason about FP8 post-casts. Parametric
+ # activations do not currently declare an FP8 path, so the common
+ # branch returns ``self.dtype``; the lookup is kept for parity.
+ fp8_out = getattr(self.kernel, "_fp8_output_dtype", None)
+ self.output_dtype = fp8_out or getattr(self.kernel, "output_dtype", dtype)
+ self._instance_key = id(self)
+ _OP_REGISTRY[self._instance_key] = self
+
+
+class _AlphaScaledBinaryOp(BinaryOp):
+    """Shared base for ops that take a scalar ``alpha`` multiplier on ``other``.
+
+    PyTorch ``torch.add(input, other, alpha=1)`` and ``torch.sub(input,
+    other, alpha=1)`` scale ``other`` by ``alpha`` before the binary op.
+    The current kernel only honors the manifest-declared default
+    (``alpha == 1``); non-default ``alpha`` values raise
+    ``NotImplementedError`` until a kernel-side scalar multiplier lands.
+    The leading ``*`` makes ``alpha`` and the existing
+    ``strategy`` / ``kernel_map`` / ``tune`` parameters keyword-only;
+    only the positional triplet ``(a_shape, b_shape, dtype)`` is shared
+    with ``BinaryOp``.
+    """
+
+    def __init__(
+        self,
+        a_shape: tuple,
+        b_shape: tuple,
+        dtype: torch.dtype,
+        *,
+        alpha: int | float = 1,
+        strategy: Optional[str] = None,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+    ):
+        super().__init__(
+            a_shape, b_shape, dtype, strategy=strategy,
+            kernel_map=kernel_map, tune=tune,
+        )
+        # ``BinaryOp.__init__`` never consumes ``alpha``; the leaf
+        # ``forward`` implementations enforce the ``alpha == 1``
+        # restriction at call time.
+        self.alpha = alpha
+
+
+class _BoolOutputBinaryOp(BinaryOp):
+    """Binary op base whose kernel emits int8 (1/0) and whose Op output is bool.
+
+    TileLang cannot vectorize bool, so the kernel produces int8. The Op
+    casts to ``torch.bool`` after the kernel call. ``register_fake``
+    already declares ``torch.bool`` as the output dtype, so the
+    ``torch.compile`` path stays consistent.
+    """
+
+    def _eager_forward(
+        self,
+        input: torch.Tensor, # noqa: A002 — manifest-aligned PyTorch param name
+        other: torch.Tensor,
+    ) -> torch.Tensor:
+        # Run the int8 kernel, then cast the 1/0 buffer to the public
+        # bool dtype (allocates a new tensor).
+        result = super()._eager_forward(input, other)
+        return result.to(torch.bool)
+
+
+# Integer dtypes that the manifest declares for float-only kernels; see
+# ``_IntIdentityUnaryOp`` for the op-layer short-circuit.
+_MANIFEST_INT_DTYPES = (
+    torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
+)
+
+
+# Integer fallback handlers used as ``_int_handler`` on
+# ``_IntIdentityUnaryOp`` subclasses: identity for ops that leave
+# integers unchanged, constant-False / constant-True for predicates
+# (integers are never NaN/Inf and always finite).
+def _int_identity(input: torch.Tensor) -> torch.Tensor: # noqa: A002
+    return input.clone()
+
+
+def _int_all_false(input: torch.Tensor) -> torch.Tensor: # noqa: A002
+    return torch.zeros(input.shape, dtype=torch.bool, device=input.device)
+
+
+def _int_all_true(input: torch.Tensor) -> torch.Tensor: # noqa: A002
+    return torch.ones(input.shape, dtype=torch.bool, device=input.device)
+
+
+# Predicate ops also accept ``torch.bool`` input on the fallback path.
+_PREDICATE_FALLBACK_DTYPES = _MANIFEST_INT_DTYPES + (torch.bool,)
+
+
+class _IntIdentityUnaryOp(UnaryOp):
+    """Base for unary ops whose manifest declares integer dtypes but whose
+    kernel is float-only.
+
+    Several manifest entries (floor / ceil / round / trunc, abs / neg / sign,
+    isnan / isinf / isfinite) declare both integer and floating-point input
+    dtypes, while the underlying ``*FwdKernel`` classes are float-only
+    (``FloatUnaryKernel``). For integer inputs we short-circuit at the op
+    layer: skip kernel construction in ``__init__`` and route through
+    ``_int_handler`` in ``_eager_forward``.
+
+    Subclasses override ``_int_handler`` (default = identity = ``input.clone()``)
+    and ``_int_output_dtype`` (default = same as input) to express the
+    appropriate integer semantics — e.g. ``torch.abs`` for ``AbsFwdOp``,
+    constant-False ``torch.bool`` for ``IsnanFwdOp``.
+
+    The short-circuit is restricted to the integer dtypes declared in the
+    manifest. Other non-float dtypes (bool, complex) are not in the
+    contract and fall through to ``UnaryOp.__init__``, which raises via the
+    kernel's dtype check.
+    """
+
+    # staticmethod wrapping keeps the plain function from binding as an
+    # instance method when looked up through the class.
+    _int_handler: Callable[[torch.Tensor], torch.Tensor] = staticmethod(
+        _int_identity)
+    _int_output_dtype: Optional[torch.dtype] = None
+    # Subclasses may extend the fallback dtype set when the manifest
+    # signature includes additional non-float dtypes (e.g. torch.bool for
+    # the is{nan,inf,finite} predicates).
+    _fallback_dtypes: tuple = _MANIFEST_INT_DTYPES
+
+    def __init__(
+        self,
+        N_total: int,
+        dtype: torch.dtype,
+        strategy: Optional[str] = None,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+    ):
+        if dtype in type(self)._fallback_dtypes:
+            self.N_total = N_total
+            self.dtype = dtype
+            self.strategy = strategy
+            # The float-only kernel cannot be instantiated for an integer
+            # dtype, so the kernel itself stays unconstructed. The kernel_map
+            # is still installed through the shared validate-and-install path
+            # so a user-supplied override is arch-checked identically to the
+            # auto-discovered map on the float path.
+            # NOTE(review): other op families call ``dispatch_kernel`` here;
+            # confirm ``_install_kernel_map`` is the intended validate-only
+            # variant of that path.
+            self._install_kernel_map(kernel_map)
+            self.kernel = None
+            self.output_dtype = (
+                type(self)._int_output_dtype
+                if type(self)._int_output_dtype is not None
+                else dtype
+            )
+            self._instance_key = id(self)
+            _OP_REGISTRY[self._instance_key] = self
+            return
+        super().__init__(
+            N_total, dtype, strategy=strategy, kernel_map=kernel_map, tune=tune,
+        )
+
+    def _eager_forward(self, input: torch.Tensor) -> torch.Tensor: # noqa: A002
+        # ``kernel is None`` marks the integer fallback path established
+        # in ``__init__``.
+        if self.kernel is None:
+            return type(self)._int_handler(input)
+        return super()._eager_forward(input)
+
+
+class _GeluApproximateBase(UnaryOp):
+    """Intermediate base that resolves the manifest ``approximate`` field.
+
+    Validates the ``approximate`` argument against the manifest's allowed
+    values (``'none'`` / ``'tanh'``), records it on ``self.approximate``
+    for introspection, and then delegates to ``UnaryOp.__init__``. The
+    ``default_kernel_map`` of the leaf op picks the kernel implementation
+    from ``self.approximate``.
+    """
+
+    def __init__(
+        self,
+        N_total: int,
+        dtype: torch.dtype,
+        *,
+        approximate: str = "none",
+        strategy: Optional[str] = None,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+    ):
+        if approximate not in ("none", "tanh"):
+            raise ValueError(
+                f"{type(self).__name__}: approximate must be 'none' or "
+                f"'tanh', got {approximate!r}"
+            )
+        # Must be recorded before ``super().__init__`` runs: kernel
+        # dispatch in ``UnaryOp.__init__`` reads ``self.approximate``
+        # through the leaf's ``default_kernel_map``.
+        self.approximate = approximate
+        super().__init__(
+            N_total, dtype, strategy=strategy, kernel_map=kernel_map, tune=tune,
+        )
+
+
+class _ClampTensorBase(Op):
+    """Shared infrastructure for Tensor-bound clamp variants (broadcasting)."""
+
+    _wrapped = None
+
+    @staticmethod
+    def _expand_flat(t: torch.Tensor, target_shape: tuple) -> torch.Tensor:
+        """Broadcast ``t`` to ``target_shape`` and return a flat 1-D view.
+
+        ``expand`` is a zero-copy view; the subsequent ``contiguous``
+        materializes the broadcast (or a copy of a non-contiguous input)
+        only when needed.
+        """
+        if tuple(t.shape) != tuple(target_shape):
+            t = t.expand(target_shape)
+        return t.contiguous().view(-1)
diff --git a/tileops/ops/elementwise/activations.py b/tileops/ops/elementwise/activations.py
new file mode 100644
index 00000000..a9c1b5b7
--- /dev/null
+++ b/tileops/ops/elementwise/activations.py
@@ -0,0 +1,307 @@
+"""Activation elementwise ops (ReLU + parametric/param-free families)."""
+
+from typing import Dict, Optional
+
+import torch
+
+from tileops.kernels.elementwise import (
+ EluFwdKernel,
+ GeluFwdKernel,
+ GeluTanhFwdKernel,
+ HardsigmoidFwdKernel,
+ HardswishFwdKernel,
+ HardtanhFwdKernel,
+ LeakyReluFwdKernel,
+ MishFwdKernel,
+ ReluFwdKernel,
+ SeluFwdKernel,
+ SigmoidFwdKernel,
+ SiluFwdKernel,
+ SoftplusFwdKernel,
+ TanhFwdKernel,
+)
+from tileops.kernels.kernel_base import Kernel
+
+from ._base import (
+ UnaryOp,
+ _GeluApproximateBase,
+ _ParametricActivationOp,
+ _ParamFreeActivationOp,
+ _validate_scalar_param_repr,
+)
+
+
+class ReluFwdOp(_ParamFreeActivationOp):
+    """ReLU activation: y = max(x, 0)."""
+
+    # Constructor (with ``inplace``) and forward dispatch are inherited
+    # from _ParamFreeActivationOp / _UnaryActivationMixin.
+    _op_name = "relu"
+    kernel_cls = ReluFwdKernel
+    # Manifest: flops = "2 * N" (compare + select per element).
+    FLOPS_PER_ELEM = 2
+
+
+class GeluFwdOp(_GeluApproximateBase):
+ """Element-wise GELU honoring the manifest ``approximate`` contract.
+
+ Args:
+ N_total: Number of elements (flattened input).
+ dtype: Torch dtype.
+ approximate: Approximation mode. ``'none'`` (default) routes to
+ the erf-based ``GeluFwdKernel``. ``'tanh'`` routes to
+ ``GeluTanhFwdKernel`` (the fused tanh approximation
+ ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``).
+ strategy: Optional kernel strategy override.
+ kernel_map: Optional kernel dispatch override.
+ tune: Whether to autotune the kernel.
+ """
+
+ _op_name = "gelu"
+ kernel_cls = GeluFwdKernel
+ # Manifest: flops = "8 * N" (erf-based: mul + erf + add + mul + mul ≈ 8;
+ # tanh approximation is similar order, see manifest comment).
+ FLOPS_PER_ELEM = 8
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ @property
+ def default_kernel_map(self) -> Dict[str, Kernel]:
+ kernel_cls = (
+ GeluTanhFwdKernel if self.approximate == "tanh" else GeluFwdKernel
+ )
+ return {self._op_name: kernel_cls}
+
+
+class SiluFwdOp(_ParamFreeActivationOp):
+    """Element-wise SiLU (Swish): y = x * sigmoid(x)."""
+
+    # Constructor (with ``inplace``) and forward dispatch are inherited
+    # from _ParamFreeActivationOp / _UnaryActivationMixin.
+    _op_name = "silu"
+    kernel_cls = SiluFwdKernel
+    # Manifest: flops = "4 * N" (sigmoid + multiply).
+    FLOPS_PER_ELEM = 4
+
+
+class SigmoidFwdOp(UnaryOp):
+    """Element-wise sigmoid(x)."""
+
+    # Subclasses UnaryOp directly (no ``inplace`` parameter), so the
+    # plain UnaryOp constructor and forward are used.
+    _op_name = "sigmoid"
+    kernel_cls = SigmoidFwdKernel
+    # Manifest: flops = "4 * N" (sigmoid(x) = 1 / (1 + exp(-x)) ≈ 4 ops/elem).
+    FLOPS_PER_ELEM = 4
+
+
+class TanhFwdOp(UnaryOp):
+    """Element-wise tanh(x)."""
+
+    # Subclasses UnaryOp directly (no ``inplace`` parameter), so the
+    # plain UnaryOp constructor and forward are used.
+    _op_name = "tanh"
+    kernel_cls = TanhFwdKernel
+    # Manifest: flops = "5 * N" (tanh(x) = 2 * sigmoid(2x) - 1 ≈ 5 ops/elem).
+    FLOPS_PER_ELEM = 5
+
+
+class HardswishFwdOp(_ParamFreeActivationOp):
+    """Element-wise HardSwish: y = x * clamp(x + 3, 0, 6) / 6."""
+
+    # Constructor (with ``inplace``) and forward dispatch are inherited
+    # from _ParamFreeActivationOp / _UnaryActivationMixin.
+    _op_name = "hardswish"
+    kernel_cls = HardswishFwdKernel
+    # Manifest: flops = "7 * N" (add + clamp(2 cmp+2 sel) + mul + div).
+    FLOPS_PER_ELEM = 7
+
+
+class HardsigmoidFwdOp(_ParamFreeActivationOp):
+    """Element-wise HardSigmoid: y = clamp(x + 3, 0, 6) / 6."""
+
+    # Constructor (with ``inplace``) and forward dispatch are inherited
+    # from _ParamFreeActivationOp / _UnaryActivationMixin.
+    _op_name = "hardsigmoid"
+    kernel_cls = HardsigmoidFwdKernel
+    # Manifest: flops = "6 * N" (add + clamp(2 cmp+2 sel) + div).
+    FLOPS_PER_ELEM = 6
+
+
+class MishFwdOp(_ParamFreeActivationOp):
+    """Element-wise Mish: y = x * tanh(softplus(x))."""
+
+    # Constructor (with ``inplace``) and forward dispatch are inherited
+    # from _ParamFreeActivationOp / _UnaryActivationMixin.
+    _op_name = "mish"
+    kernel_cls = MishFwdKernel
+    # Manifest: flops = "7 * N" (softplus + tanh + mul).
+    FLOPS_PER_ELEM = 7
+
+
+class SeluFwdOp(_ParamFreeActivationOp):
+    """Element-wise SELU activation."""
+
+    # Constructor (with ``inplace``) and forward dispatch are inherited
+    # from _ParamFreeActivationOp / _UnaryActivationMixin.
+    _op_name = "selu"
+    kernel_cls = SeluFwdKernel
+    # Manifest: flops = "5 * N" (branch + exp/sub/mul + lambda mul).
+    FLOPS_PER_ELEM = 5
+
+
+class LeakyReluFwdOp(_ParametricActivationOp):
+    """Leaky ReLU: y = x if x > 0 else negative_slope * x.
+
+    Args:
+        N_total: Total number of elements (flattened).
+        dtype: Torch dtype.
+        negative_slope: Slope for negative inputs (default 0.01).
+        inplace: When True, copy the result back into ``input`` and
+            return ``input`` (preserving tensor identity). The kernel
+            still computes into a fresh buffer; only the user-visible
+            tensor is mutated, mirroring ``torch.nn.functional.leaky_relu``.
+        kernel_map: Optional kernel dispatch override.
+        tune: Whether to autotune the kernel.
+    """
+
+    _op_name = "leaky_relu"
+    _wrapped = None
+    # Manifest: flops = "3 * N" (compare + mul + select).
+    FLOPS_PER_ELEM = 3
+
+    def __init__(
+        self,
+        N_total: int,
+        dtype: torch.dtype,
+        negative_slope: float = 0.01,
+        *,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+        inplace: bool = False,
+    ):
+        # Check the scalar is representable in the target dtype before
+        # building any kernel state.
+        _validate_scalar_param_repr("negative_slope", negative_slope, dtype, self._op_name)
+        self.negative_slope = negative_slope
+        # ``dispatch_kernel`` installs ``self.kernel_map`` (honoring a
+        # caller override) before the kernel is instantiated below.
+        self.dispatch_kernel(kernel_map)
+        kernel = self.kernel_map[self._op_name](
+            N_total, dtype, negative_slope=negative_slope, tune=tune,
+        )
+        # ``_finalize_init`` records shared base state and registers the
+        # op; ``UnaryOp.__init__`` is intentionally bypassed.
+        self._finalize_init(N_total, dtype, kernel, inplace=inplace)
+
+    @property
+    def default_kernel_map(self):
+        return {"leaky_relu": LeakyReluFwdKernel}
+
+
+class EluFwdOp(_ParametricActivationOp):
+    """ELU: y = x if x > 0 else alpha * (exp(x) - 1).
+
+    Args:
+        N_total: Total number of elements (flattened).
+        dtype: Torch dtype.
+        alpha: Scale for the negative part (default 1.0).
+        inplace: When True, copy the result back into ``input`` and
+            return ``input`` (preserving tensor identity).
+        kernel_map: Optional kernel dispatch override.
+        tune: Whether to autotune the kernel.
+    """
+
+    _op_name = "elu"
+    _wrapped = None
+    # Manifest: flops = "5 * N" (compare + (exp + sub + mul) + branch select).
+    FLOPS_PER_ELEM = 5
+
+    def __init__(
+        self,
+        N_total: int,
+        dtype: torch.dtype,
+        alpha: float = 1.0,
+        *,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+        inplace: bool = False,
+    ):
+        # Check the scalar is representable in the target dtype before
+        # building any kernel state.
+        _validate_scalar_param_repr("alpha", alpha, dtype, self._op_name)
+        self.alpha = alpha
+        # ``dispatch_kernel`` installs ``self.kernel_map`` (honoring a
+        # caller override) before the kernel is instantiated below.
+        self.dispatch_kernel(kernel_map)
+        kernel = self.kernel_map[self._op_name](
+            N_total, dtype, alpha=alpha, tune=tune,
+        )
+        # ``_finalize_init`` records shared base state and registers the
+        # op; ``UnaryOp.__init__`` is intentionally bypassed.
+        self._finalize_init(N_total, dtype, kernel, inplace=inplace)
+
+    @property
+    def default_kernel_map(self):
+        return {"elu": EluFwdKernel}
+
+
+class HardtanhFwdOp(_ParametricActivationOp):
+    """Hardtanh: y = clamp(x, min_val, max_val).
+
+    Args:
+        N_total: Total number of elements (flattened).
+        dtype: Torch dtype.
+        min_val: Lower bound (default -1.0).
+        max_val: Upper bound (default 1.0).
+        inplace: When True, copy the result back into ``input`` and
+            return ``input`` (preserving tensor identity).
+        kernel_map: Optional kernel dispatch override.
+        tune: Whether to autotune the kernel.
+    """
+
+    _op_name = "hardtanh"
+    _wrapped = None
+    # Manifest: flops = "4 * N" (2 compares + 2 selects per element).
+    FLOPS_PER_ELEM = 4
+
+    def __init__(
+        self,
+        N_total: int,
+        dtype: torch.dtype,
+        min_val: float = -1.0,
+        max_val: float = 1.0,
+        *,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+        inplace: bool = False,
+    ):
+        # Check both scalars are representable in the target dtype
+        # before building any kernel state.
+        # NOTE(review): min_val <= max_val is not validated here —
+        # confirm the kernel tolerates an inverted range.
+        _validate_scalar_param_repr("min_val", min_val, dtype, self._op_name)
+        _validate_scalar_param_repr("max_val", max_val, dtype, self._op_name)
+        self.min_val = min_val
+        self.max_val = max_val
+        # ``dispatch_kernel`` installs ``self.kernel_map`` (honoring a
+        # caller override) before the kernel is instantiated below.
+        self.dispatch_kernel(kernel_map)
+        kernel = self.kernel_map[self._op_name](
+            N_total, dtype, min_val=min_val, max_val=max_val, tune=tune,
+        )
+        # ``_finalize_init`` records shared base state and registers the
+        # op; ``UnaryOp.__init__`` is intentionally bypassed.
+        self._finalize_init(N_total, dtype, kernel, inplace=inplace)
+
+    @property
+    def default_kernel_map(self):
+        return {"hardtanh": HardtanhFwdKernel}
+
+
+class SoftplusFwdOp(_ParametricActivationOp):
+    """Softplus: y = log(1 + exp(x*beta))/beta if x*beta <= threshold else x.
+
+    Args:
+        N_total: Total number of elements (flattened).
+        dtype: Torch dtype.
+        beta: Scaling factor (default 1.0).
+        threshold: Linear regime threshold (default 20.0).
+        kernel_map: Optional kernel dispatch override.
+        tune: Whether to autotune the kernel.
+    """
+
+    _op_name = "softplus"
+    _wrapped = None
+    # Manifest: flops = "7 * N" (mul + exp + add + log + div + compare + select).
+    FLOPS_PER_ELEM = 7
+
+    def __init__(
+        self,
+        N_total: int,
+        dtype: torch.dtype,
+        beta: float = 1.0,
+        threshold: float = 20.0,
+        *,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+    ):
+        # Check both scalars are representable in the target dtype
+        # before building any kernel state.
+        _validate_scalar_param_repr("beta", beta, dtype, self._op_name)
+        _validate_scalar_param_repr("threshold", threshold, dtype, self._op_name)
+        self.beta = beta
+        self.threshold = threshold
+        # ``dispatch_kernel`` installs ``self.kernel_map`` (honoring a
+        # caller override) before the kernel is instantiated below.
+        self.dispatch_kernel(kernel_map)
+        kernel = self.kernel_map[self._op_name](
+            N_total, dtype, beta=beta, threshold=threshold, tune=tune,
+        )
+        # Softplus does not expose ``inplace`` to callers; default to False.
+        self._finalize_init(N_total, dtype, kernel, inplace=False)
+
+    @property
+    def default_kernel_map(self):
+        return {"softplus": SoftplusFwdKernel}
diff --git a/tileops/ops/elementwise/alibi.py b/tileops/ops/elementwise/alibi.py
new file mode 100644
index 00000000..02977586
--- /dev/null
+++ b/tileops/ops/elementwise/alibi.py
@@ -0,0 +1,65 @@
+"""ALiBi position-encoding generative op."""
+
+from typing import Dict, Optional
+
+import torch
+
+from tileops.kernels.elementwise import AlibiFwdKernel
+from tileops.kernels.kernel_base import Kernel
+
+from ..op_base import Op
+from ._base import _OP_REGISTRY, _apply_fp8_post_cast
+
+
+class AlibiFwdOp(Op):
+    """ALiBi position encoding: bias[h, i, j] = -slope_h * |i - j|.
+
+    Generates the full (num_heads, seq_len, seq_len) bias tensor.
+
+    Args:
+        seq_len: Sequence length.
+        num_heads: Number of attention heads.
+        dtype: Torch dtype.
+        kernel_map: Optional dispatch override mapping kernel keys to
+            ``Kernel`` subclasses. Falls back to ``default_kernel_map``.
+    """
+
+    _op_name = "alibi"
+    _wrapped = None
+
+    def __init__(
+        self,
+        seq_len: int,
+        num_heads: int,
+        dtype: torch.dtype,
+        *,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+    ):
+        self.seq_len = seq_len
+        self.num_heads = num_heads
+        self.dtype = dtype
+        self.dispatch_kernel(kernel_map)
+        self.kernel = self.kernel_map[self._op_name](seq_len, num_heads, dtype)
+        # Scalar tensor used as device/dtype carrier for torch.compile tracing
+        # NOTE(review): hardcodes the default "cuda" device; callers on a
+        # non-default GPU may need a device argument — confirm.
+        self._device_carrier = torch.empty((), dtype=dtype, device="cuda")
+        self._instance_key = id(self)
+        _OP_REGISTRY[self._instance_key] = self
+
+    @property
+    def default_kernel_map(self):
+        return {"alibi": AlibiFwdKernel}
+
+    def _eager_forward(self) -> torch.Tensor:
+        # Generative op: no tensor inputs — the kernel materializes the
+        # flat bias buffer, reshaped to (num_heads, seq_len, seq_len).
+        out = self.kernel()
+        result = out.reshape(self.num_heads, self.seq_len, self.seq_len)
+        return _apply_fp8_post_cast(result, self.kernel)
+
+    def forward(self) -> torch.Tensor:
+        """Produce the ALiBi bias tensor via custom op or eager path."""
+        wrapped = type(self)._wrapped
+        if wrapped is not None:
+            return wrapped(
+                self._device_carrier,
+                self.num_heads, self.seq_len,
+                self._instance_key,
+            )
+        return self._eager_forward()
diff --git a/tileops/ops/elementwise/arithmetic.py b/tileops/ops/elementwise/arithmetic.py
new file mode 100644
index 00000000..02d1c65f
--- /dev/null
+++ b/tileops/ops/elementwise/arithmetic.py
@@ -0,0 +1,358 @@
+"""Binary arithmetic elementwise ops with broadcasting."""
+
+from math import prod
+from typing import Dict, Optional
+
+import torch
+
+from tileops.kernels.elementwise import (
+ AddFwdKernel,
+ DivFwdKernel,
+ FloorDivideFwdKernel,
+ LerpFwdKernel,
+ LerpTensorFwdKernel,
+ MaximumFwdKernel,
+ MinimumFwdKernel,
+ MulFwdKernel,
+ PowFwdKernel,
+ RemainderFwdKernel,
+ SubFwdKernel,
+)
+from tileops.kernels.kernel_base import Kernel
+
+from ..op_base import Op
+from ._base import (
+ _OP_REGISTRY,
+ BinaryOp,
+ _AlphaScaledBinaryOp,
+ coalesce_broadcast_dims,
+)
+
+
+class AddFwdOp(_AlphaScaledBinaryOp):
+    """Element-wise addition with broadcast: y = input + alpha * other.
+
+    Conforms to ``torch.add(input, other, *, alpha=1)``. Only ``alpha == 1``
+    dispatches to the kernel; non-default ``alpha`` raises
+    ``NotImplementedError`` until a kernel-side scalar multiplier lands
+    (tracked in a follow-up issue).
+    """
+
+    _op_name = "add"
+    kernel_cls = AddFwdKernel
+
+    def forward(
+        self,
+        input: torch.Tensor,  # noqa: A002
+        other: torch.Tensor,
+    ) -> torch.Tensor:
+        # ``self.alpha`` is presumably stored by _AlphaScaledBinaryOp.__init__
+        # (not visible in this file); guard on every call rather than at
+        # construction so the error surfaces at dispatch time.
+        if self.alpha != 1:
+            raise NotImplementedError(
+                "AddFwdOp(alpha != 1) is not yet implemented; the current "
+                "kernel only honors alpha == 1. A follow-up "
+                "issue tracks the kernel work."
+            )
+        return super().forward(input, other)
+
+
+class SubFwdOp(_AlphaScaledBinaryOp):
+    """Element-wise subtraction with broadcast: y = input - alpha * other.
+
+    Conforms to ``torch.sub(input, other, *, alpha=1)``. Only ``alpha == 1``
+    dispatches to the kernel; non-default ``alpha`` raises
+    ``NotImplementedError`` until a kernel-side scalar multiplier lands
+    (tracked in a follow-up issue).
+    """
+
+    _op_name = "sub"
+    kernel_cls = SubFwdKernel
+
+    def forward(
+        self,
+        input: torch.Tensor,  # noqa: A002
+        other: torch.Tensor,
+    ) -> torch.Tensor:
+        # Same alpha gate as AddFwdOp: only the default scale reaches the kernel.
+        if self.alpha != 1:
+            raise NotImplementedError(
+                "SubFwdOp(alpha != 1) is not yet implemented; the current "
+                "kernel only honors alpha == 1. A follow-up "
+                "issue tracks the kernel work."
+            )
+        return super().forward(input, other)
+
+
+class MulFwdOp(BinaryOp):
+    """Element-wise multiplication with broadcast: y = input * other."""
+
+    # BinaryOp (._base, not shown here) supplies __init__/forward; this
+    # subclass only binds the manifest name and kernel class.
+    _op_name = "mul"
+    kernel_cls = MulFwdKernel
+
+
+class DivFwdOp(BinaryOp):
+    """Element-wise division with broadcast: y = input / other.
+
+    Conforms to ``torch.div(input, other, *, rounding_mode=None)``.
+    ``rounding_mode`` accepts ``None`` (true division), ``"trunc"``
+    (truncation toward zero), or ``"floor"`` (floor division). Only
+    ``rounding_mode is None`` dispatches to the kernel; the trunc /
+    floor variants raise ``NotImplementedError`` until a rounded-divide
+    kernel lands (tracked in a follow-up issue). The leading ``*``
+    makes ``rounding_mode`` and the existing ``strategy`` /
+    ``kernel_map`` / ``tune`` parameters keyword-only; only the
+    positional triplet ``(a_shape, b_shape, dtype)`` is shared with
+    ``BinaryOp``.
+    """
+
+    _op_name = "div"
+    kernel_cls = DivFwdKernel
+
+    def __init__(
+        self,
+        a_shape: tuple,
+        b_shape: tuple,
+        dtype: torch.dtype,
+        *,
+        rounding_mode: Optional[str] = None,
+        strategy: Optional[str] = None,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+    ):
+        # Reject unknown modes eagerly (ValueError) at construction;
+        # known-but-unimplemented modes fail later in forward() with
+        # NotImplementedError so the two error classes stay distinct.
+        if rounding_mode is not None and rounding_mode not in ("trunc", "floor"):
+            raise ValueError(
+                f"DivFwdOp received rounding_mode={rounding_mode!r}; "
+                "manifest allows None, 'trunc', or 'floor'"
+            )
+        super().__init__(
+            a_shape, b_shape, dtype, strategy=strategy,
+            kernel_map=kernel_map, tune=tune,
+        )
+        self.rounding_mode = rounding_mode
+
+    def forward(
+        self,
+        input: torch.Tensor,  # noqa: A002
+        other: torch.Tensor,
+    ) -> torch.Tensor:
+        if self.rounding_mode is not None:
+            raise NotImplementedError(
+                f"DivFwdOp(rounding_mode={self.rounding_mode!r}) is not yet "
+                "implemented; the current kernel only honors rounding_mode is "
+                "None. A follow-up issue tracks the kernel work."
+            )
+        return super().forward(input, other)
+
+
+class RemainderFwdOp(BinaryOp):
+    """Element-wise remainder with broadcast: y = a % b."""
+
+    # Thin kernel binding; all dispatch logic lives in BinaryOp (._base).
+    _op_name = "remainder"
+    kernel_cls = RemainderFwdKernel
+
+
+class PowFwdOp(BinaryOp):
+    """Element-wise power with broadcast: y = input ** exponent.
+
+    Conforms to ``torch.pow(input, exponent)``: the second operand carries
+    the manifest-declared name ``exponent`` rather than the generic
+    ``other`` so the L1 signature check matches the manifest.
+    """
+
+    _op_name = "pow"
+    kernel_cls = PowFwdKernel
+    # Renames the second forward() parameter — presumably consumed by
+    # BinaryOp's signature machinery (not visible here).
+    _other_name = "exponent"
+
+
+class FloorDivideFwdOp(BinaryOp):
+    """Element-wise floor division with broadcast: y = floor(a / b)."""
+
+    # Thin kernel binding; all dispatch logic lives in BinaryOp (._base).
+    _op_name = "floor_divide"
+    kernel_cls = FloorDivideFwdKernel
+
+
+class LerpFwdOp(BinaryOp):
+    """Element-wise lerp with broadcast: y = a + weight * (b - a).
+
+    Unlike ``torch.lerp(a, b, weight)`` where weight is a runtime parameter,
+    here weight is a **construction-time constant** baked into the compiled
+    kernel. This enables compile-time folding but means a new Op instance is
+    needed for each distinct weight value.
+
+    Args:
+        a_shape: Shape of input a.
+        b_shape: Shape of input b.
+        dtype: Torch dtype.
+        weight: Scalar interpolation weight, fixed at construction (default 0.5).
+        strategy: Kernel strategy override.
+        kernel_map: Optional kernel dispatch override.
+        tune: Whether to autotune.
+    """
+
+    _op_name = "lerp"
+    kernel_cls = LerpFwdKernel
+    _other_name = "end"
+
+    def __init__(
+        self,
+        a_shape: tuple,
+        b_shape: tuple,
+        dtype: torch.dtype,
+        weight: float = 0.5,
+        strategy: Optional[str] = None,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+    ):
+        # NOTE(review): this duplicates BinaryOp.__init__'s setup (dtype
+        # check, shape coalescing, kernel build, registry insert) instead of
+        # calling super(), presumably so ``weight`` can be threaded to the
+        # kernel constructor — keep in sync with BinaryOp if its setup changes.
+        supported = self.kernel_cls.SUPPORTED_DTYPES
+        if supported is not None and dtype not in supported:
+            names = ", ".join(str(dt) for dt in supported)
+            raise ValueError(
+                f"{self._op_name} does not support dtype {dtype}. "
+                f"Supported: [{names}]"
+            )
+        self.dtype = dtype
+        self.a_shape = tuple(a_shape)
+        self.b_shape = tuple(b_shape)
+        self.strategy = strategy
+        self._weight = weight
+        out_shape, coalesced_shape, a_strides, b_strides = coalesce_broadcast_dims(
+            a_shape, b_shape,
+        )
+        self.out_shape = out_shape
+        self._out_shape_list = list(out_shape)  # cached for custom_op hot path
+        self.N_total = prod(out_shape)
+        self.a_numel = prod(a_shape)
+        self.b_numel = prod(b_shape)
+        self.dispatch_kernel(kernel_map)
+        self.kernel = self.kernel_map[self._op_name](
+            self.N_total, dtype, coalesced_shape, a_strides, b_strides,
+            self.a_numel, self.b_numel, strategy=strategy, tune=tune,
+            weight=weight,
+        )
+        # Register in global registry for torch.compile dispatch
+        self._instance_key = id(self)
+        _OP_REGISTRY[self._instance_key] = self
+
+
+class MaximumFwdOp(BinaryOp):
+    """Element-wise maximum with broadcast: y = max(a, b)."""
+
+    # Thin kernel binding; all dispatch logic lives in BinaryOp (._base).
+    _op_name = "maximum"
+    kernel_cls = MaximumFwdKernel
+
+
+class MinimumFwdOp(BinaryOp):
+    """Element-wise minimum with broadcast: y = min(a, b)."""
+
+    # Thin kernel binding; all dispatch logic lives in BinaryOp (._base).
+    _op_name = "minimum"
+    kernel_cls = MinimumFwdKernel
+
+
+class LerpTensorFwdOp(Op):
+    """Tensor-weight lerp: out = input + weight * (end - input).
+
+    Conforms to the Tensor-weight overload of ``torch.lerp`` —
+    ``torch.lerp(input, end, weight: Tensor)`` where ``weight`` is a
+    Tensor that broadcasts together with ``input`` and ``end`` to the
+    output shape. The Op layer expands the three inputs to the broadcast
+    shape and dispatches the flat ``LerpTensorFwdKernel`` on
+    ``N_total = product(broadcast_shape)`` elements. The scalar-weight
+    overload is handled separately by ``LerpFwdOp``.
+
+    Args:
+        input: Shape of the start tensor.
+        end: Shape of the end tensor.
+        weight: Shape of the per-element weight tensor.
+        dtype: Torch dtype for all three operands.
+    """
+
+    _op_name = "lerp_tensor"
+    # Presumably replaced with a torch.compile-able wrapper by external
+    # registration code; forward() prefers it when non-None. TODO confirm.
+    _wrapped = None
+
+    # Manifest declares all three operands as ``float16 | bfloat16 | float32``;
+    # fp8 dtypes are rejected at the op-layer signature so the impl matches
+    # the manifest contract (the kernel also rejects fp8 independently).
+    _SUPPORTED_DTYPES = (torch.float16, torch.bfloat16, torch.float32)
+
+    def __init__(
+        self,
+        *,
+        input: tuple,  # noqa: A002 — manifest-aligned PyTorch param name
+        end: tuple,
+        weight: tuple,
+        dtype: torch.dtype,
+        strategy: Optional[str] = None,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+    ):
+        if dtype not in self._SUPPORTED_DTYPES:
+            names = ", ".join(str(dt) for dt in self._SUPPORTED_DTYPES)
+            raise ValueError(
+                f"LerpTensorFwdOp does not support dtype {dtype}. "
+                f"Supported: [{names}]"
+            )
+        self.input_shape = tuple(input)
+        self.end_shape = tuple(end)
+        self.weight_shape = tuple(weight)
+        self.dtype = dtype
+        self.strategy = strategy
+        # Three-way broadcast fixes the output shape at construction time.
+        self.out_shape = tuple(
+            torch.broadcast_shapes(
+                self.input_shape, self.end_shape, self.weight_shape,
+            )
+        )
+        # 0-dim broadcast result () still launches a 1-element kernel.
+        self.N_total = prod(self.out_shape) if self.out_shape else 1
+        self.dispatch_kernel(kernel_map)
+        self.kernel = self.kernel_map[self._op_name](
+            self.N_total, dtype, tune=tune,
+        )
+        self._instance_key = id(self)
+        _OP_REGISTRY[self._instance_key] = self
+
+    @property
+    def default_kernel_map(self) -> Dict[str, Kernel]:
+        return {"lerp_tensor": LerpTensorFwdKernel}
+
+    @staticmethod
+    def _expand_flat(t: torch.Tensor, target_shape: tuple) -> torch.Tensor:
+        """Expand ``t`` to ``target_shape`` and return a contiguous flat view."""
+        # ``contiguous()`` materializes a copy only when the expand actually
+        # broadcast (stride-0 views cannot be flattened in place).
+        if tuple(t.shape) != tuple(target_shape):
+            t = t.expand(target_shape)
+        return t.contiguous().view(-1)
+
+    def _eager_forward(
+        self,
+        input: torch.Tensor,  # noqa: A002 — manifest-aligned PyTorch param name
+        end: torch.Tensor,
+        weight: torch.Tensor,
+    ) -> torch.Tensor:
+        # Flatten all operands to the broadcast shape, run the flat kernel,
+        # then restore the broadcast shape (0-dim round-trips via (1,)).
+        out_shape = self.out_shape if self.out_shape else (1,)
+        a_flat = self._expand_flat(input, out_shape)
+        b_flat = self._expand_flat(end, out_shape)
+        w_flat = self._expand_flat(weight, out_shape)
+        result = self.kernel(a_flat, b_flat, w_flat)
+        return result.view(self.out_shape if self.out_shape else ())
+
+    def forward(
+        self,
+        input: torch.Tensor,  # noqa: A002 — manifest-aligned PyTorch param name
+        end: torch.Tensor,
+        weight: torch.Tensor,
+    ) -> torch.Tensor:
+        # Strict validation: dtype and exact construction-time shapes — the
+        # kernel was built for a fixed N_total, so no runtime re-broadcast.
+        if not (input.is_cuda and end.is_cuda and weight.is_cuda):
+            raise ValueError("Inputs must be CUDA tensors")
+        for name, t, expected in [
+            ("input", input, self.input_shape),
+            ("end", end, self.end_shape),
+            ("weight", weight, self.weight_shape),
+        ]:
+            if t.dtype != self.dtype:
+                raise ValueError(
+                    f"Expected {name}.dtype {self.dtype}, got {t.dtype}"
+                )
+            if tuple(t.shape) != expected:
+                raise ValueError(
+                    f"Expected {name}.shape {expected}, got {tuple(t.shape)}"
+                )
+        wrapped = type(self)._wrapped
+        if wrapped is not None:
+            return wrapped(input, end, weight, self._instance_key)
+        return self._eager_forward(input, end, weight)
diff --git a/tileops/ops/elementwise/bitwise.py b/tileops/ops/elementwise/bitwise.py
new file mode 100644
index 00000000..0a589062
--- /dev/null
+++ b/tileops/ops/elementwise/bitwise.py
@@ -0,0 +1,38 @@
+"""Element-wise bitwise ops."""
+
+from tileops.kernels.elementwise import (
+ BitwiseAndFwdKernel,
+ BitwiseNotFwdKernel,
+ BitwiseOrFwdKernel,
+ BitwiseXorFwdKernel,
+)
+
+from ._base import BinaryOp, UnaryOp
+
+
+class BitwiseAndFwdOp(BinaryOp):
+    """Element-wise bitwise AND with broadcast: y = a & b."""
+
+    # Thin kernel binding; all dispatch logic lives in BinaryOp (._base).
+    _op_name = "bitwise_and"
+    kernel_cls = BitwiseAndFwdKernel
+
+
+class BitwiseOrFwdOp(BinaryOp):
+    """Element-wise bitwise OR with broadcast: y = a | b."""
+
+    # Thin kernel binding; all dispatch logic lives in BinaryOp (._base).
+    _op_name = "bitwise_or"
+    kernel_cls = BitwiseOrFwdKernel
+
+
+class BitwiseXorFwdOp(BinaryOp):
+    """Element-wise bitwise XOR with broadcast: y = a ^ b."""
+
+    # Thin kernel binding; all dispatch logic lives in BinaryOp (._base).
+    _op_name = "bitwise_xor"
+    kernel_cls = BitwiseXorFwdKernel
+
+
+class BitwiseNotFwdOp(UnaryOp):
+    """Element-wise bitwise NOT (~x) for bool/integer inputs."""
+
+    # Thin kernel binding; all dispatch logic lives in UnaryOp (._base).
+    _op_name = "bitwise_not"
+    kernel_cls = BitwiseNotFwdKernel
diff --git a/tileops/ops/elementwise/clamp.py b/tileops/ops/elementwise/clamp.py
new file mode 100644
index 00000000..fb226528
--- /dev/null
+++ b/tileops/ops/elementwise/clamp.py
@@ -0,0 +1,347 @@
+"""Clamp ops (Tensor-bound and scalar-bound variants)."""
+
+from math import prod
+from typing import Dict, Optional
+
+import torch
+
+from tileops.kernels.elementwise import ClampFwdKernel, ClampTensorFwdKernel
+from tileops.kernels.kernel_base import Kernel
+
+from ..op_base import Op
+from ._base import (
+ _OP_REGISTRY,
+ _apply_fp8_post_cast,
+ _ClampTensorBase,
+ _validate_scalar_param_repr,
+)
+
+
+class ClampFwdOp(_ClampTensorBase):
+    """Clamp with Tensor lower and/or upper bounds (broadcasting).
+
+    Conforms to ``torch.clamp(input, min, max)`` where ``min`` and ``max``
+    are each either a Tensor or ``None``. At least one of the two bounds
+    must be a Tensor. All Tensor operands broadcast together. The
+    primary spec entry in ``tileops/manifest/`` covers the both-Tensor
+    form; the mixed Tensor/``None`` cases are runtime-equivalent to
+    ``ClampMinFwdOp`` / ``ClampMaxFwdOp`` and are accepted here so callers
+    can mirror PyTorch's ``torch.clamp`` API directly.
+
+    Args:
+        input: Shape of the input tensor.
+        min: Shape of the lower-bound tensor, or ``None`` for no lower bound.
+        max: Shape of the upper-bound tensor, or ``None`` for no upper bound.
+        dtype: Torch dtype for all operands.
+
+    Raises:
+        ValueError: If both ``min`` and ``max`` are ``None``.
+    """
+
+    _op_name = "clamp"
+    # Presumably replaced with a compiled wrapper by external registration;
+    # forward() prefers it when non-None. TODO confirm.
+    _wrapped = None
+
+    def __init__(
+        self,
+        input: tuple,  # noqa: A002 — manifest-aligned PyTorch param name
+        min: Optional[tuple] = None,  # noqa: A002 — manifest-aligned PyTorch param name
+        max: Optional[tuple] = None,  # noqa: A002 — manifest-aligned PyTorch param name
+        dtype: torch.dtype = torch.float32,
+        *,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+    ):
+        if min is None and max is None:
+            raise ValueError(
+                "ClampFwdOp requires at least one of `min` or `max` to be a "
+                "Tensor shape; both None is not a valid clamp."
+            )
+        self.input_shape = tuple(input)
+        self.min_shape = None if min is None else tuple(min)
+        self.max_shape = None if max is None else tuple(max)
+        self.dtype = dtype
+        # Broadcast only over the bounds that are actually present.
+        broadcast_args = [self.input_shape]
+        if self.min_shape is not None:
+            broadcast_args.append(self.min_shape)
+        if self.max_shape is not None:
+            broadcast_args.append(self.max_shape)
+        self.out_shape = tuple(torch.broadcast_shapes(*broadcast_args))
+        self.N_total = prod(self.out_shape) if self.out_shape else 1
+        self.dispatch_kernel(kernel_map)
+        # has_min / has_max are baked into the kernel build, so the bound
+        # pattern is fixed for the lifetime of this op instance.
+        self.kernel = self.kernel_map["clamp_tensor"](
+            self.N_total, dtype,
+            has_min=self.min_shape is not None,
+            has_max=self.max_shape is not None,
+            tune=tune,
+        )
+        self._instance_key = id(self)
+        _OP_REGISTRY[self._instance_key] = self
+
+    @property
+    def default_kernel_map(self):
+        return {"clamp_tensor": ClampTensorFwdKernel}
+
+    def _eager_forward(
+        self,
+        input: torch.Tensor,  # noqa: A002
+        min: Optional[torch.Tensor] = None,  # noqa: A002
+        max: Optional[torch.Tensor] = None,  # noqa: A002
+    ) -> torch.Tensor:
+        # Broadcast all operands to ``out_shape`` and dispatch the
+        # TileLang Tensor-bound clamp kernel. The kernel branches on
+        # ``has_min`` / ``has_max`` at build time, so this single Op
+        # class also covers the mixed Tensor/None cases.
+        out_shape = self.out_shape if self.out_shape else (1,)
+        x_flat = self._expand_flat(input, out_shape)
+        lo_flat = None if min is None else self._expand_flat(min, out_shape)
+        hi_flat = None if max is None else self._expand_flat(max, out_shape)
+        result = self.kernel(x_flat, lo_flat, hi_flat)
+        return result.view(self.out_shape if self.out_shape else ())
+
+    def forward(
+        self,
+        input: torch.Tensor,  # noqa: A002
+        min: Optional[torch.Tensor] = None,  # noqa: A002
+        max: Optional[torch.Tensor] = None,  # noqa: A002
+    ) -> torch.Tensor:
+        # Validate that the runtime None / Tensor pattern matches what
+        # __init__ was configured for — the broadcast shape and the
+        # presence of each bound is baked in at construction.
+        if (min is None) != (self.min_shape is None):
+            raise ValueError(
+                f"min was {'None' if self.min_shape is None else 'a Tensor shape'} at "
+                f"__init__ but {'None' if min is None else 'a Tensor'} at forward()"
+            )
+        if (max is None) != (self.max_shape is None):
+            raise ValueError(
+                f"max was {'None' if self.max_shape is None else 'a Tensor shape'} at "
+                f"__init__ but {'None' if max is None else 'a Tensor'} at forward()"
+            )
+        tensors = [("input", input, self.input_shape)]
+        if min is not None:
+            tensors.append(("min", min, self.min_shape))
+        if max is not None:
+            tensors.append(("max", max, self.max_shape))
+        for _, t, _ in tensors:
+            if not t.is_cuda:
+                raise ValueError("Inputs must be CUDA tensors")
+        for name, t, expected in tensors:
+            if t.dtype != self.dtype:
+                raise ValueError(f"Expected {name}.dtype {self.dtype}, got {t.dtype}")
+            if tuple(t.shape) != expected:
+                raise ValueError(
+                    f"Expected {name}.shape {expected}, got {tuple(t.shape)}"
+                )
+        wrapped = type(self)._wrapped
+        if wrapped is not None:
+            return wrapped(input, min, max, self._instance_key)
+        return self._eager_forward(input, min, max)
+
+
+class ClampMinFwdOp(_ClampTensorBase):
+    """Single-bound Tensor lower clamp (``torch.clamp_min``).
+
+    Args:
+        input: Shape of the input tensor.
+        min: Shape of the lower-bound tensor.
+        dtype: Torch dtype.
+    """
+
+    _op_name = "clamp_min"
+    # Presumably replaced with a compiled wrapper by external registration;
+    # forward() prefers it when non-None. TODO confirm.
+    _wrapped = None
+
+    def __init__(
+        self,
+        input: tuple,  # noqa: A002
+        min: tuple,  # noqa: A002
+        dtype: torch.dtype,
+        *,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+    ):
+        self.input_shape = tuple(input)
+        self.min_shape = tuple(min)
+        self.dtype = dtype
+        self.out_shape = tuple(torch.broadcast_shapes(self.input_shape, self.min_shape))
+        self.N_total = prod(self.out_shape) if self.out_shape else 1
+        self.dispatch_kernel(kernel_map)
+        # Reuses the general Tensor-bound clamp kernel, built min-only.
+        self.kernel = self.kernel_map["clamp_tensor"](
+            self.N_total, dtype, has_min=True, has_max=False, tune=tune,
+        )
+        self._instance_key = id(self)
+        _OP_REGISTRY[self._instance_key] = self
+
+    @property
+    def default_kernel_map(self):
+        return {"clamp_tensor": ClampTensorFwdKernel}
+
+    def _eager_forward(
+        self, input: torch.Tensor, min: torch.Tensor,  # noqa: A002
+    ) -> torch.Tensor:
+        # Broadcast input/min to out_shape and dispatch the TileLang
+        # min-only Tensor-bound clamp kernel.
+        out_shape = self.out_shape if self.out_shape else (1,)
+        x_flat = self._expand_flat(input, out_shape)
+        lo_flat = self._expand_flat(min, out_shape)
+        result = self.kernel(x_flat, lo_flat, None)
+        return result.view(self.out_shape if self.out_shape else ())
+
+    def forward(
+        self, input: torch.Tensor, min: torch.Tensor,  # noqa: A002
+    ) -> torch.Tensor:
+        # Validate device, dtype, and the exact construction-time shapes.
+        if not (input.is_cuda and min.is_cuda):
+            raise ValueError("Inputs must be CUDA tensors")
+        for name, t, expected in [
+            ("input", input, self.input_shape),
+            ("min", min, self.min_shape),
+        ]:
+            if t.dtype != self.dtype:
+                raise ValueError(f"Expected {name}.dtype {self.dtype}, got {t.dtype}")
+            if tuple(t.shape) != expected:
+                raise ValueError(
+                    f"Expected {name}.shape {expected}, got {tuple(t.shape)}"
+                )
+        wrapped = type(self)._wrapped
+        if wrapped is not None:
+            return wrapped(input, min, self._instance_key)
+        return self._eager_forward(input, min)
+
+
+class ClampMaxFwdOp(_ClampTensorBase):
+    """Single-bound Tensor upper clamp (``torch.clamp_max``).
+
+    Args:
+        input: Shape of the input tensor.
+        max: Shape of the upper-bound tensor.
+        dtype: Torch dtype.
+    """
+
+    _op_name = "clamp_max"
+    # Presumably replaced with a compiled wrapper by external registration;
+    # forward() prefers it when non-None. TODO confirm.
+    _wrapped = None
+
+    def __init__(
+        self,
+        input: tuple,  # noqa: A002
+        max: tuple,  # noqa: A002
+        dtype: torch.dtype,
+        *,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+    ):
+        self.input_shape = tuple(input)
+        self.max_shape = tuple(max)
+        self.dtype = dtype
+        self.out_shape = tuple(torch.broadcast_shapes(self.input_shape, self.max_shape))
+        self.N_total = prod(self.out_shape) if self.out_shape else 1
+        self.dispatch_kernel(kernel_map)
+        # Reuses the general Tensor-bound clamp kernel, built max-only.
+        self.kernel = self.kernel_map["clamp_tensor"](
+            self.N_total, dtype, has_min=False, has_max=True, tune=tune,
+        )
+        self._instance_key = id(self)
+        _OP_REGISTRY[self._instance_key] = self
+
+    @property
+    def default_kernel_map(self):
+        return {"clamp_tensor": ClampTensorFwdKernel}
+
+    def _eager_forward(
+        self, input: torch.Tensor, max: torch.Tensor,  # noqa: A002
+    ) -> torch.Tensor:
+        # Broadcast input/max to out_shape and dispatch the TileLang
+        # max-only Tensor-bound clamp kernel.
+        out_shape = self.out_shape if self.out_shape else (1,)
+        x_flat = self._expand_flat(input, out_shape)
+        hi_flat = self._expand_flat(max, out_shape)
+        result = self.kernel(x_flat, None, hi_flat)
+        return result.view(self.out_shape if self.out_shape else ())
+
+    def forward(
+        self, input: torch.Tensor, max: torch.Tensor,  # noqa: A002
+    ) -> torch.Tensor:
+        # Validate device, dtype, and the exact construction-time shapes.
+        if not (input.is_cuda and max.is_cuda):
+            raise ValueError("Inputs must be CUDA tensors")
+        for name, t, expected in [
+            ("input", input, self.input_shape),
+            ("max", max, self.max_shape),
+        ]:
+            if t.dtype != self.dtype:
+                raise ValueError(f"Expected {name}.dtype {self.dtype}, got {t.dtype}")
+            if tuple(t.shape) != expected:
+                raise ValueError(
+                    f"Expected {name}.shape {expected}, got {tuple(t.shape)}"
+                )
+        wrapped = type(self)._wrapped
+        if wrapped is not None:
+            return wrapped(input, max, self._instance_key)
+        return self._eager_forward(input, max)
+
+
+class ClampScalarFwdOp(Op):
+    """Scalar-bound clamp (``torch.clamp(input, min: Number|None, max: Number|None)``).
+
+    Args:
+        input: Shape of the input tensor.
+        min: Lower bound (Number or None).
+        max: Upper bound (Number or None).
+        dtype: Torch dtype.
+    """
+
+    # NOTE(review): shares ``_op_name = "clamp"`` with the Tensor-bound
+    # ClampFwdOp above; the kernel_map keys differ ("clamp" vs
+    # "clamp_tensor"), but confirm the shared manifest name is intended.
+    _op_name = "clamp"
+    _wrapped = None
+
+    def __init__(
+        self,
+        input: tuple,  # noqa: A002
+        min: Optional[float] = None,  # noqa: A002
+        max: Optional[float] = None,  # noqa: A002
+        dtype: torch.dtype = torch.float32,
+        *,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+        tune: bool = False,
+    ):
+        if min is None and max is None:
+            raise ValueError(
+                "ClampScalarFwdOp requires at least one of `min` or `max` to be a "
+                "Number; both None is not a valid clamp."
+            )
+        # Range-check each scalar bound against the target dtype.
+        if min is not None:
+            _validate_scalar_param_repr("min", min, dtype, self._op_name)
+        if max is not None:
+            _validate_scalar_param_repr("max", max, dtype, self._op_name)
+        self.input_shape = tuple(input)
+        self.N_total = prod(self.input_shape) if self.input_shape else 1
+        self.dtype = dtype
+        self.min = min
+        self.max = max
+        # Backwards-compat aliases for legacy callers.
+        self.min_val = min
+        self.max_val = max
+        self.dispatch_kernel(kernel_map)
+        # Scalar bounds are baked into the kernel at build time.
+        self.kernel = self.kernel_map["clamp"](
+            self.N_total, dtype, min_val=min, max_val=max, tune=tune,
+        )
+        self._instance_key = id(self)
+        _OP_REGISTRY[self._instance_key] = self
+
+    @property
+    def default_kernel_map(self):
+        return {"clamp": ClampFwdKernel}
+
+    def _eager_forward(self, input: torch.Tensor) -> torch.Tensor:  # noqa: A002
+        # Flatten, clamp, restore the original shape, then fp8 post-cast.
+        orig_shape = input.shape
+        result = self.kernel(input.contiguous().reshape(-1)).reshape(orig_shape)
+        return _apply_fp8_post_cast(result, self.kernel)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:  # noqa: A002
+        if not input.is_cuda:
+            raise ValueError("Input must be a CUDA tensor")
+        if input.dtype != self.dtype:
+            raise ValueError(f"Expected input.dtype {self.dtype}, got {input.dtype}")
+        if tuple(input.shape) != self.input_shape:
+            raise ValueError(
+                f"Expected input.shape {self.input_shape}, got {tuple(input.shape)}"
+            )
+        wrapped = type(self)._wrapped
+        if wrapped is not None:
+            return wrapped(input, self._instance_key)
+        return self._eager_forward(input)
diff --git a/tileops/ops/elementwise/comparison.py b/tileops/ops/elementwise/comparison.py
new file mode 100644
index 00000000..a2863b51
--- /dev/null
+++ b/tileops/ops/elementwise/comparison.py
@@ -0,0 +1,58 @@
+"""Element-wise comparison ops (output bool).
+
+Kernels produce int8 (1/0) because TileLang cannot vectorize bool.
+The Op forward() casts to torch.bool after the kernel call.
+"""
+
+from tileops.kernels.elementwise import (
+ EqFwdKernel,
+ GeFwdKernel,
+ GtFwdKernel,
+ LeFwdKernel,
+ LtFwdKernel,
+ NeFwdKernel,
+)
+
+from ._base import _BoolOutputBinaryOp
+
+
+class EqFwdOp(_BoolOutputBinaryOp):
+    """Element-wise equality with broadcast: y = (a == b)."""
+
+    # Kernel emits int8 0/1 (TileLang cannot vectorize bool); the base
+    # class casts to torch.bool after the kernel call.
+    _op_name = "eq"
+    kernel_cls = EqFwdKernel
+
+
+class NeFwdOp(_BoolOutputBinaryOp):
+    """Element-wise not-equal with broadcast: y = (a != b)."""
+
+    # int8 kernel output, cast to bool by _BoolOutputBinaryOp.
+    _op_name = "ne"
+    kernel_cls = NeFwdKernel
+
+
+class GtFwdOp(_BoolOutputBinaryOp):
+    """Element-wise greater-than with broadcast: y = (a > b)."""
+
+    # int8 kernel output, cast to bool by _BoolOutputBinaryOp.
+    _op_name = "gt"
+    kernel_cls = GtFwdKernel
+
+
+class LtFwdOp(_BoolOutputBinaryOp):
+    """Element-wise less-than with broadcast: y = (a < b)."""
+
+    # int8 kernel output, cast to bool by _BoolOutputBinaryOp.
+    _op_name = "lt"
+    kernel_cls = LtFwdKernel
+
+
+class GeFwdOp(_BoolOutputBinaryOp):
+    """Element-wise greater-equal with broadcast: y = (a >= b)."""
+
+    # int8 kernel output, cast to bool by _BoolOutputBinaryOp.
+    _op_name = "ge"
+    kernel_cls = GeFwdKernel
+
+
+class LeFwdOp(_BoolOutputBinaryOp):
+    """Element-wise less-equal with broadcast: y = (a <= b)."""
+
+    # int8 kernel output, cast to bool by _BoolOutputBinaryOp.
+    _op_name = "le"
+    kernel_cls = LeFwdKernel
diff --git a/tileops/ops/elementwise/fused_gated.py b/tileops/ops/elementwise/fused_gated.py
new file mode 100644
index 00000000..592a4a3a
--- /dev/null
+++ b/tileops/ops/elementwise/fused_gated.py
@@ -0,0 +1,30 @@
+"""Fused gated elementwise ops: y = activation(gate) * value."""
+
+from tileops.kernels.elementwise import (
+ GeluAndMulFwdKernel,
+ GeluTanhAndMulFwdKernel,
+ SiluAndMulFwdKernel,
+)
+
+from ._base import FusedGatedOp
+
+
+class SiluAndMulFwdOp(FusedGatedOp):
+    """SiLU-and-Mul: y = silu(gate) * value."""
+
+    # FusedGatedOp (._base, not shown here) supplies the gate/value plumbing.
+    _op_name = "silu_and_mul"
+    kernel_cls = SiluAndMulFwdKernel
+
+
+class GeluAndMulFwdOp(FusedGatedOp):
+    """GELU-and-Mul: y = gelu(gate) * value (exact GELU)."""
+
+    # FusedGatedOp (._base, not shown here) supplies the gate/value plumbing.
+    _op_name = "gelu_and_mul"
+    kernel_cls = GeluAndMulFwdKernel
+
+
+class GeluTanhAndMulFwdOp(FusedGatedOp):
+    """GELU-Tanh-and-Mul: y = gelu_tanh(gate) * value (tanh approximation)."""
+
+    # FusedGatedOp (._base, not shown here) supplies the gate/value plumbing.
+    _op_name = "gelu_tanh_and_mul"
+    kernel_cls = GeluTanhAndMulFwdKernel
diff --git a/tileops/ops/elementwise/logical.py b/tileops/ops/elementwise/logical.py
new file mode 100644
index 00000000..0a3f4d4f
--- /dev/null
+++ b/tileops/ops/elementwise/logical.py
@@ -0,0 +1,30 @@
+"""Element-wise logical ops (output bool)."""
+
+from tileops.kernels.elementwise import (
+ LogicalAndFwdKernel,
+ LogicalNotFwdKernel,
+ LogicalOrFwdKernel,
+)
+
+from ._base import UnaryOp, _BoolOutputBinaryOp
+
+
+class LogicalAndFwdOp(_BoolOutputBinaryOp):
+    """Element-wise logical AND with broadcast using non-zero truthiness."""
+
+    # Bool output produced by the _BoolOutputBinaryOp post-cast.
+    _op_name = "logical_and"
+    kernel_cls = LogicalAndFwdKernel
+
+
+class LogicalOrFwdOp(_BoolOutputBinaryOp):
+    """Element-wise logical OR with broadcast using non-zero truthiness."""
+
+    # Bool output produced by the _BoolOutputBinaryOp post-cast.
+    _op_name = "logical_or"
+    kernel_cls = LogicalOrFwdKernel
+
+
+class LogicalNotFwdOp(UnaryOp):
+    """Element-wise logical NOT with bool output."""
+
+    # Unary op: dispatch plumbing lives in UnaryOp (._base, not shown here).
+    _op_name = "logical_not"
+    kernel_cls = LogicalNotFwdKernel
diff --git a/tileops/ops/elementwise/masked_fill.py b/tileops/ops/elementwise/masked_fill.py
new file mode 100644
index 00000000..8510ecfb
--- /dev/null
+++ b/tileops/ops/elementwise/masked_fill.py
@@ -0,0 +1,215 @@
+"""MaskedFill ops (Tensor-value and scalar-value variants)."""
+
+from math import prod
+from typing import Dict, Optional
+
+import torch
+
+from tileops.kernels.elementwise import (
+ MaskedFillFwdKernel,
+ MaskedFillTensorValueFwdKernel,
+)
+from tileops.kernels.kernel_base import Kernel
+
+from ..op_base import Op
+from ._base import _OP_REGISTRY, _apply_fp8_post_cast, _validate_scalar_param_repr
+
+
+class MaskedFillFwdOp(Op):
+    """MaskedFill with 0-dim Tensor value (``torch.Tensor.masked_fill(mask, value: Tensor)``).
+
+    Output shape is the bidirectional broadcast of ``input`` and ``mask``;
+    ``value`` must be a 0-dim Tensor. The Op expands ``input`` and ``mask``
+    to the broadcast shape and dispatches the existing flat scalar kernel
+    using ``value.item()`` as the fill literal — this keeps the
+    fast vectorized kernel path while satisfying the manifest's Tensor-value
+    contract (the kernel reads ``value`` once at forward time, which is
+    consistent with the 0-dim semantics).
+
+    Args:
+        input: Shape of the input tensor.
+        mask: Shape of the mask tensor (bool).
+        value: Shape of the value tensor (must be ``()`` per the manifest).
+        dtype: Torch dtype for ``input`` / ``value``.
+        kernel_map: Optional dispatch override mapping kernel keys to
+            ``Kernel`` subclasses. Falls back to ``default_kernel_map``.
+    """
+
+    _op_name = "masked_fill"
+    # Presumably replaced with a compiled wrapper by external registration;
+    # forward() prefers it when non-None. TODO confirm.
+    _wrapped = None
+
+    def __init__(
+        self,
+        input: tuple,  # noqa: A002
+        mask: tuple,
+        value: tuple,
+        dtype: torch.dtype,
+        *,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+    ):
+        # The manifest pins value.shape == (); enforce it up front.
+        if tuple(value) != ():
+            raise ValueError(
+                f"MaskedFillFwdOp requires a 0-dim value Tensor; got shape {tuple(value)}"
+            )
+        self.input_shape = tuple(input)
+        self.mask_shape = tuple(mask)
+        self.value_shape = tuple(value)
+        self.dtype = dtype
+        self.out_shape = tuple(torch.broadcast_shapes(self.input_shape, self.mask_shape))
+        self.N_total = prod(self.out_shape) if self.out_shape else 1
+        self.dispatch_kernel(kernel_map)
+        self.kernel = self.kernel_map["masked_fill_tensor_value"](self.N_total, dtype)
+        self._instance_key = id(self)
+        _OP_REGISTRY[self._instance_key] = self
+
+    @property
+    def default_kernel_map(self):
+        return {"masked_fill_tensor_value": MaskedFillTensorValueFwdKernel}
+
+    @staticmethod
+    def _expand_flat(t: torch.Tensor, target_shape: tuple) -> torch.Tensor:
+        # Expand to the broadcast shape (copying only if needed) and flatten.
+        if tuple(t.shape) != tuple(target_shape):
+            t = t.expand(target_shape)
+        return t.contiguous().view(-1)
+
+    def _eager_forward(
+        self, input: torch.Tensor, mask: torch.Tensor, value: torch.Tensor,  # noqa: A002
+    ) -> torch.Tensor:
+        # Broadcast input/mask to out_shape, pack mask as uint8, reshape
+        # the 0-dim value to (1,), and dispatch the TileLang kernel.
+        out_shape = self.out_shape if self.out_shape else (1,)
+        x_flat = self._expand_flat(input, out_shape)
+        mask_b = mask if mask.dtype == torch.bool else mask.bool()
+        mask_flat = self._expand_flat(mask_b, out_shape).view(torch.uint8)
+        value_1d = value.contiguous().view(1)
+        result = self.kernel(x_flat, mask_flat, value_1d)
+        return result.view(self.out_shape if self.out_shape else ())
+
+    def forward(
+        self, input: torch.Tensor, mask: torch.Tensor, value: torch.Tensor,  # noqa: A002
+    ) -> torch.Tensor:
+        # Strict device/dtype/shape validation against construction-time state.
+        if not (input.is_cuda and mask.is_cuda and value.is_cuda):
+            raise ValueError("Inputs must be CUDA tensors")
+        if input.dtype != self.dtype:
+            raise ValueError(f"Expected input.dtype {self.dtype}, got {input.dtype}")
+        if mask.dtype != torch.bool:
+            raise ValueError(f"Expected mask.dtype torch.bool, got {mask.dtype}")
+        if value.dtype != self.dtype:
+            raise ValueError(f"Expected value.dtype {self.dtype}, got {value.dtype}")
+        if tuple(input.shape) != self.input_shape:
+            raise ValueError(
+                f"Expected input.shape {self.input_shape}, got {tuple(input.shape)}"
+            )
+        if tuple(mask.shape) != self.mask_shape:
+            raise ValueError(
+                f"Expected mask.shape {self.mask_shape}, got {tuple(mask.shape)}"
+            )
+        if tuple(value.shape) != ():
+            raise ValueError(f"Expected value.shape (), got {tuple(value.shape)}")
+        wrapped = type(self)._wrapped
+        if wrapped is not None:
+            return wrapped(input, mask, value, self._instance_key)
+        return self._eager_forward(input, mask, value)
+
+
+class MaskedFillScalarFwdOp(Op):
+    """MaskedFill with Number (scalar) value.
+
+    Conforms to ``torch.Tensor.masked_fill(mask, value: Number)``. Output
+    shape follows the bidirectional broadcast of ``input`` and ``mask``.
+
+    The manifest declares the PyTorch dtype union (``bool | uint8 |
+    int8 | int16 | int32 | int64 | float16 | bfloat16 | float32``). The
+    current TileLang kernel only supports float dtypes; integer and
+    bool dtypes are rejected at construction time with ``ValueError``
+    until a real int / bool kernel lands (tracked in a follow-up issue).
+
+    Args:
+        input: Shape of the input tensor.
+        mask: Shape of the mask tensor (bool).
+        value: Scalar fill value (bool / int / float). Range-validated
+            against ``dtype``.
+        dtype: Torch dtype. Must be a kernel-supported floating-point
+            dtype.
+        kernel_map: Optional dispatch override mapping kernel keys to
+            ``Kernel`` subclasses. Falls back to ``default_kernel_map``.
+    """
+
+    _op_name = "masked_fill"
+    # Presumably replaced with a compiled wrapper by external registration;
+    # forward() prefers it when non-None. TODO confirm.
+    _wrapped = None
+
+    def __init__(
+        self,
+        input: tuple,  # noqa: A002
+        mask: tuple,
+        value: bool | int | float = 0,
+        dtype: torch.dtype = torch.float32,
+        *,
+        kernel_map: Optional[Dict[str, Kernel]] = None,
+    ):
+        kernel_supported = MaskedFillFwdKernel.SUPPORTED_DTYPES
+        if kernel_supported is not None and dtype not in kernel_supported:
+            names = ", ".join(str(dt) for dt in kernel_supported)
+            raise ValueError(
+                f"{self._op_name} does not support dtype {dtype}. "
+                f"Supported: [{names}]"
+            )
+        self.input_shape = tuple(input)
+        self.mask_shape = tuple(mask)
+        self.dtype = dtype
+        self.value = value
+        # Backwards-compat alias.
+        self.fill_value = value
+        self.out_shape = tuple(torch.broadcast_shapes(self.input_shape, self.mask_shape))
+        self.N_total = prod(self.out_shape) if self.out_shape else 1
+        # The kernel is always built on the broadcast (output) flat size.
+        # When input/mask already match out_shape, this is a no-op expand;
+        # otherwise the Op layer broadcasts both before dispatch.
+        # NOTE(review): _needs_broadcast is not read anywhere in this file —
+        # presumably consumed by the compiled wrapper; verify or remove.
+        self._needs_broadcast = (
+            self.input_shape != self.out_shape or self.mask_shape != self.out_shape
+        )
+        _validate_scalar_param_repr("value", value, dtype, self._op_name)
+        self.dispatch_kernel(kernel_map)
+        self.kernel = self.kernel_map["masked_fill"](self.N_total, dtype, value)
+        self._instance_key = id(self)
+        _OP_REGISTRY[self._instance_key] = self
+
+    @property
+    def default_kernel_map(self):
+        return {"masked_fill": MaskedFillFwdKernel}
+
+    @staticmethod
+    def _expand_flat(t: torch.Tensor, target_shape: tuple) -> torch.Tensor:
+        # Expand to the broadcast shape (copying only if needed) and flatten.
+        if tuple(t.shape) != tuple(target_shape):
+            t = t.expand(target_shape)
+        return t.contiguous().view(-1)
+
+    def _eager_forward(self, input: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:  # noqa: A002
+        # Mask is passed to the kernel as uint8 (bool is not vectorizable).
+        out_shape = self.out_shape if self.out_shape else (1,)
+        x_flat = self._expand_flat(input, out_shape)
+        mask_b = mask if mask.dtype == torch.bool else mask.bool()
+        mask_flat = self._expand_flat(mask_b, out_shape).view(torch.uint8)
+        result = self.kernel(x_flat, mask_flat).view(self.out_shape if self.out_shape else ())
+        return _apply_fp8_post_cast(result, self.kernel)
+
+    def forward(self, input: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:  # noqa: A002
+        if not input.is_cuda:
+            raise ValueError("Input must be a CUDA tensor")
+        if input.dtype != self.dtype:
+            raise ValueError(f"Expected input.dtype {self.dtype}, got {input.dtype}")
+        if tuple(input.shape) != self.input_shape:
+            raise ValueError(
+                f"Expected input.shape {self.input_shape}, got {tuple(input.shape)}"
+            )
+        if not mask.is_cuda:
+            raise ValueError("Mask must be a CUDA tensor")
+        if mask.dtype != torch.bool:
+            raise ValueError(f"Expected mask.dtype torch.bool, got {mask.dtype}")
+        if tuple(mask.shape) != self.mask_shape:
+            raise ValueError(
+                f"Expected mask.shape {self.mask_shape}, got {tuple(mask.shape)}"
+            )
+        wrapped = type(self)._wrapped
+        if wrapped is not None:
+            return wrapped(input, mask, self._instance_key)
+        return self._eager_forward(input, mask)
diff --git a/tileops/ops/elementwise/math_unary.py b/tileops/ops/elementwise/math_unary.py
new file mode 100644
index 00000000..346f80ce
--- /dev/null
+++ b/tileops/ops/elementwise/math_unary.py
@@ -0,0 +1,235 @@
+"""Unary math elementwise ops (exp/log/sqrt/abs/neg/round/etc.)."""
+
+from typing import Dict, Optional
+
+import torch
+
+from tileops.kernels.elementwise import (
+ AbsFwdKernel,
+ CeilFwdKernel,
+ CosFwdKernel,
+ ErfFwdKernel,
+ ExpFwdKernel,
+ Expm1FwdKernel,
+ FloorFwdKernel,
+ Log1pFwdKernel,
+ LogFwdKernel,
+ NegFwdKernel,
+ ReciprocalFwdKernel,
+ RoundFwdKernel,
+ RsqrtFwdKernel,
+ SignFwdKernel,
+ SinFwdKernel,
+ SqrtFwdKernel,
+ TruncFwdKernel,
+)
+from tileops.kernels.kernel_base import Kernel
+
+from ._base import _MANIFEST_INT_DTYPES, UnaryOp, _IntIdentityUnaryOp
+
+
class ExpFwdOp(UnaryOp):
    """Forward op computing e**x for every element."""

    kernel_cls = ExpFwdKernel
    _op_name = "exp"
+
+
class LogFwdOp(UnaryOp):
    """Forward op computing the natural logarithm element-wise."""

    kernel_cls = LogFwdKernel
    _op_name = "log"
+
+
class SqrtFwdOp(UnaryOp):
    """Forward op computing the square root element-wise."""

    kernel_cls = SqrtFwdKernel
    _op_name = "sqrt"
+
+
class RsqrtFwdOp(UnaryOp):
    """Forward op computing the reciprocal square root, 1/sqrt(x), element-wise."""

    kernel_cls = RsqrtFwdKernel
    _op_name = "rsqrt"
+
+
class AbsFwdOp(_IntIdentityUnaryOp):
    """Forward op computing the absolute value element-wise."""

    kernel_cls = AbsFwdKernel
    _op_name = "abs"
    # Integer inputs bypass the float kernel and use torch.abs directly.
    _int_handler = staticmethod(torch.abs)
+
+
class NegFwdOp(_IntIdentityUnaryOp):
    """Forward op computing arithmetic negation, -x, element-wise."""

    kernel_cls = NegFwdKernel
    _op_name = "neg"
    # Integer inputs bypass the float kernel and use torch.neg directly.
    _int_handler = staticmethod(torch.neg)
+
+
class ReciprocalFwdOp(UnaryOp):
    """Element-wise 1/x.

    Mirrors ``torch.reciprocal`` int-input promotion: integral dtypes
    (uint8 / int8 / int16 / int32 / int64) are cast to float32 before the
    float kernel runs, and the op's ``output_dtype`` is float32 in that
    case. Floating inputs (float16 / bfloat16 / float32) follow the
    standard same-dtype path.
    """

    _op_name = "reciprocal"
    kernel_cls = ReciprocalFwdKernel

    def __init__(
        self,
        N_total: int,
        dtype: torch.dtype,
        strategy: Optional[str] = None,
        kernel_map: Optional[Dict[str, Kernel]] = None,
        tune: bool = False,
    ):
        promote = dtype in _MANIFEST_INT_DTYPES
        # For integral inputs, build the kernel against the promoted compute
        # dtype (float32) so the float-only ReciprocalFwdKernel can run.
        compute_dtype = torch.float32 if promote else dtype
        super().__init__(
            N_total, compute_dtype, strategy=strategy,
            kernel_map=kernel_map, tune=tune,
        )
        if promote:
            # Restore the user-declared dtype on ``self.dtype`` so metadata
            # and ``eval_roofline`` reflect the real I/O contract: integer
            # input bytes + float32 output bytes. ``self.output_dtype``
            # remains float32 (set by the kernel) per the manifest's
            # ``promote_int_to_float`` contract.
            self.dtype = dtype

    def forward(self, input: torch.Tensor) -> torch.Tensor:  # noqa: A002
        if self.dtype not in _MANIFEST_INT_DTYPES:
            # Floating inputs: standard same-dtype path.
            return super().forward(input)
        # Integral path: validate against the declared int dtype, promote
        # to float32, then run the (wrapped or eager) float kernel.
        self._validate_input(input)
        promoted = input.to(torch.float32)
        wrapped = type(self)._wrapped
        if wrapped is not None:
            return wrapped(promoted, self._instance_key)
        return self._eager_forward(promoted)
+
+
class SignFwdOp(_IntIdentityUnaryOp):
    """Forward op computing sign(x) in {-1, 0, +1} element-wise."""

    kernel_cls = SignFwdKernel
    _op_name = "sign"
    # Integer inputs bypass the float kernel and use torch.sign directly.
    _int_handler = staticmethod(torch.sign)
    # Roofline model counts two compares + selects per element
    # (manifest: flops = "2 * N").
    FLOPS_PER_ELEM = 2
+
+
class SinFwdOp(UnaryOp):
    """Forward op computing the sine element-wise."""

    kernel_cls = SinFwdKernel
    _op_name = "sin"
+
+
class CosFwdOp(UnaryOp):
    """Forward op computing the cosine element-wise."""

    kernel_cls = CosFwdKernel
    _op_name = "cos"
+
+
class FloorFwdOp(_IntIdentityUnaryOp):
    """Forward op computing floor(x) element-wise (identity on integer dtypes)."""

    kernel_cls = FloorFwdKernel
    _op_name = "floor"
+
+
class CeilFwdOp(_IntIdentityUnaryOp):
    """Forward op computing ceil(x) element-wise (identity on integer dtypes)."""

    kernel_cls = CeilFwdKernel
    _op_name = "ceil"
+
+
class RoundFwdOp(_IntIdentityUnaryOp):
    """Element-wise round(x) to ``decimals`` decimal places.

    The kernel itself implements banker's round-to-nearest-integer,
    identical to ``torch.round`` with ``decimals=0``. Non-zero ``decimals``
    is handled at the op layer via the standard decomposition
    ``round(x, decimals=k) == round(x * 10**k) / 10**k``.

    Args:
        N_total: Total number of elements (flattened).
        dtype: Torch dtype.
        strategy: Kernel strategy override.
        kernel_map: Optional kernel dispatch override.
        tune: Whether to autotune.
    """

    kernel_cls = RoundFwdKernel
    _op_name = "round"

    def forward(  # noqa: A002
        self, input: torch.Tensor, decimals: int = 0,
    ) -> torch.Tensor:
        # Fast path: decimals=0 maps directly onto the kernel.
        if decimals == 0:
            return super().forward(input)
        # The non-zero-decimals path owes the same input contract as the
        # fast path (UnaryOp.forward): run the shared validator before any
        # fp32 arithmetic so a CPU tensor / wrong dtype / wrong numel
        # cannot silently slip through.
        self._validate_input(input)
        # Rounding an integer is the identity regardless of ``decimals``;
        # mirror the float-path identity contract.
        if self.dtype in _MANIFEST_INT_DTYPES:
            return input.clone()
        # Compute in fp32 so low-precision inputs (fp16/bf16) cannot
        # overflow when ``torch.round`` internally scales by
        # ``10**decimals`` — e.g. ``100 * 10**4 = 1e6`` exceeds fp16 max
        # (~65504). A single final down-cast restores the contract dtype.
        rounded = torch.round(input.float(), decimals=decimals)
        return rounded.to(self.dtype)
+
+
class TruncFwdOp(_IntIdentityUnaryOp):
    """Forward op truncating toward zero element-wise (identity on integer dtypes)."""

    kernel_cls = TruncFwdKernel
    _op_name = "trunc"
+
+
class ErfFwdOp(UnaryOp):
    """Forward op computing the Gauss error function erf(x) element-wise."""

    kernel_cls = ErfFwdKernel
    _op_name = "erf"
+
+
class Log1pFwdOp(UnaryOp):
    """Forward op computing log(1 + x) element-wise."""

    kernel_cls = Log1pFwdKernel
    _op_name = "log1p"
    # Roofline model counts one add plus one log (manifest: flops = "2 * N").
    FLOPS_PER_ELEM = 2
+
+
class Expm1FwdOp(UnaryOp):
    """Forward op computing exp(x) - 1 element-wise."""

    kernel_cls = Expm1FwdKernel
    _op_name = "expm1"
    # Roofline model counts one exp plus one subtract (manifest: flops = "2 * N").
    FLOPS_PER_ELEM = 2
diff --git a/tileops/ops/elementwise/nan_to_num.py b/tileops/ops/elementwise/nan_to_num.py
new file mode 100644
index 00000000..5762ad00
--- /dev/null
+++ b/tileops/ops/elementwise/nan_to_num.py
@@ -0,0 +1,102 @@
+"""NanToNum op: replace NaN, +Inf, -Inf with specified values."""
+
+from typing import Dict, Optional
+
+import torch
+
+from tileops.kernels.elementwise import NanToNumFwdKernel
+from tileops.kernels.kernel_base import Kernel
+
+from ..op_base import Op
+from ._base import _OP_REGISTRY, _apply_fp8_post_cast, _validate_scalar_param_repr
+
+
class NanToNumFwdOp(Op):
    """NanToNum: replace NaN, +Inf, -Inf with specified values.

    Args:
        N_total: Total number of elements (flattened).
        dtype: Torch dtype.
        nan: Replacement for NaN (default 0.0).
        posinf: Replacement for +Inf. The manifest default ``None``
            resolves to the largest finite value representable in the
            user-facing ``dtype`` (matches ``torch.nan_to_num``).
            Explicit values must be representable in ``dtype``
            end-to-end; values that only fit the kernel's intermediate
            dtype (e.g. fp16 for fp8_e5m2) are rejected so the post-cast
            cannot resurface them as Inf.
        neginf: Replacement for -Inf. The manifest default ``None``
            resolves to the smallest (most negative) finite value
            representable in the user-facing ``dtype``.
        kernel_map: Optional kernel dispatch override.
        tune: Whether to autotune the kernel.
    """

    _op_name = "nan_to_num"
    _wrapped = None

    def __init__(
        self,
        N_total: int,
        dtype: torch.dtype,
        nan: float = 0.0,
        posinf: Optional[float] = None,
        neginf: Optional[float] = None,
        *,
        kernel_map: Optional[Dict[str, Kernel]] = None,
        tune: bool = False,
    ):
        # Explicit replacement values must survive the round trip through
        # the user-facing dtype; reject anything that does not.
        _validate_scalar_param_repr("nan", nan, dtype, self._op_name)
        if posinf is not None:
            _validate_scalar_param_repr("posinf", posinf, dtype, self._op_name)
        if neginf is not None:
            _validate_scalar_param_repr("neginf", neginf, dtype, self._op_name)
        # ``None`` resolves to the *final* user-facing dtype's max / min
        # finite value, never ``+/-inf``: the kernel runs in
        # ``output_dtype`` (fp16 for e5m2 to preserve Inf/NaN) and
        # _clamp_to_dtype_range targets that intermediate, so forwarding
        # ``+inf`` would resolve to fp16's 65504.0 and then surface as
        # ``+Inf`` after the e5m2 post-cast (e5m2 max is 57344.0).
        # ``torch.finfo(dtype)`` keeps the replacement finite end-to-end
        # and matches ``torch.nan_to_num`` semantics. Computed lazily so
        # the call only happens when a default is actually needed.
        kernel_posinf = torch.finfo(dtype).max if posinf is None else posinf
        kernel_neginf = torch.finfo(dtype).min if neginf is None else neginf
        self.N_total = N_total
        self.dtype = dtype
        self.nan = nan
        self.posinf = posinf
        self.neginf = neginf
        self.dispatch_kernel(kernel_map)
        # Replacement values are passed positionally; the kernel
        # constructor's internal parameter naming stays encapsulated below
        # the Op layer.
        self.kernel = self.kernel_map["nan_to_num"](
            N_total, dtype, nan, kernel_posinf, kernel_neginf, tune=tune,
        )
        self._instance_key = id(self)
        _OP_REGISTRY[self._instance_key] = self

    @property
    def default_kernel_map(self):
        """Default dispatch table: op key -> forward kernel class."""
        return {"nan_to_num": NanToNumFwdKernel}

    def _eager_forward(self, input: torch.Tensor) -> torch.Tensor:  # noqa: A002
        shape = input.shape
        flat = input.contiguous().reshape(-1)
        result = self.kernel(flat).reshape(shape)
        return _apply_fp8_post_cast(result, self.kernel)

    def forward(self, input: torch.Tensor) -> torch.Tensor:  # noqa: A002
        if not input.is_cuda:
            raise ValueError("Input must be a CUDA tensor")
        if input.dtype != self.dtype:
            raise ValueError(f"Expected input.dtype {self.dtype}, got {input.dtype}")
        if input.numel() != self.N_total:
            raise ValueError(f"Expected {self.N_total} elements, got {input.numel()}")
        wrapped = type(self)._wrapped
        if wrapped is not None:
            return wrapped(input, self._instance_key)
        return self._eager_forward(input)
diff --git a/tileops/ops/elementwise/predicates.py b/tileops/ops/elementwise/predicates.py
new file mode 100644
index 00000000..02b625b8
--- /dev/null
+++ b/tileops/ops/elementwise/predicates.py
@@ -0,0 +1,58 @@
+"""Special predicate ops: isnan, isinf, isfinite (output bool)."""
+
+import torch
+
+from tileops.kernels.elementwise import (
+ IsfiniteFwdKernel,
+ IsinfFwdKernel,
+ IsnanFwdKernel,
+)
+
+from ._base import (
+ _PREDICATE_FALLBACK_DTYPES,
+ _int_all_false,
+ _int_all_true,
+ _IntIdentityUnaryOp,
+)
+
+
class IsnanFwdOp(_IntIdentityUnaryOp):
    """Element-wise NaN test producing a bool tensor.

    Integer / bool dtypes cannot represent NaN, so those inputs
    short-circuit to an all-False result.
    """

    kernel_cls = IsnanFwdKernel
    _op_name = "isnan"
    _int_output_dtype = torch.bool
    _int_handler = staticmethod(_int_all_false)
    _fallback_dtypes = _PREDICATE_FALLBACK_DTYPES
+
+
class IsinfFwdOp(_IntIdentityUnaryOp):
    """Element-wise infinity test producing a bool tensor.

    Integer / bool dtypes cannot represent Inf, so those inputs
    short-circuit to an all-False result.
    """

    kernel_cls = IsinfFwdKernel
    _op_name = "isinf"
    _int_output_dtype = torch.bool
    _int_handler = staticmethod(_int_all_false)
    _fallback_dtypes = _PREDICATE_FALLBACK_DTYPES
+
+
class IsfiniteFwdOp(_IntIdentityUnaryOp):
    """Element-wise finiteness test producing a bool tensor.

    Every value an integer / bool dtype can hold is finite, so those
    inputs short-circuit to an all-True result.
    """

    kernel_cls = IsfiniteFwdKernel
    _op_name = "isfinite"
    _int_output_dtype = torch.bool
    _int_handler = staticmethod(_int_all_true)
    _fallback_dtypes = _PREDICATE_FALLBACK_DTYPES
diff --git a/tileops/ops/elementwise/prelu.py b/tileops/ops/elementwise/prelu.py
new file mode 100644
index 00000000..1b89e975
--- /dev/null
+++ b/tileops/ops/elementwise/prelu.py
@@ -0,0 +1,96 @@
+"""PReLU op: y = x if x > 0 else weight[channel] * x."""
+
+from math import prod
+from typing import Dict, Optional
+
+import torch
+
+from tileops.kernels.elementwise import PreluFwdKernel
+from tileops.kernels.kernel_base import Kernel
+
+from ..op_base import Op
+from ._base import _OP_REGISTRY, _apply_fp8_post_cast
+
+
class PreluFwdOp(Op):
    """PReLU: y = x if x > 0 else weight[channel] * x.

    Channel dimension follows PyTorch convention: dimension 1 for inputs
    with ndim >= 2, dimension 0 for 1-D inputs.

    Args:
        shape: Shape of the input tensor (must have a channel dimension).
        dtype: Torch dtype.
        num_channels: Number of channels (weight length).
        kernel_map: Optional dispatch override mapping kernel keys to
            ``Kernel`` subclasses. Falls back to ``default_kernel_map``.
    """

    _op_name = "prelu"
    _wrapped = None

    def __init__(
        self,
        shape: tuple,
        dtype: torch.dtype,
        num_channels: int,
        *,
        kernel_map: Optional[Dict[str, Kernel]] = None,
    ):
        # Normalize to a tuple so forward() can compare against
        # ``input.shape`` directly.
        self.shape = tuple(shape)
        self.dtype = dtype
        self.num_channels = num_channels
        N_total = prod(shape)
        self.N_total = N_total
        # PyTorch PReLU: channel dim is 1 for ndim>=2, else 0. The kernel
        # recovers the channel index from a flat offset via ``inner_size``
        # (number of elements under one channel slice).
        inner_size = (prod(shape[2:]) if len(shape) > 2 else 1) if len(shape) >= 2 else 1
        self.inner_size = inner_size
        self.dispatch_kernel(kernel_map)
        self.kernel = self.kernel_map[self._op_name](N_total, num_channels, inner_size, dtype)
        self._instance_key = id(self)
        _OP_REGISTRY[self._instance_key] = self

    @property
    def default_kernel_map(self):
        """Default dispatch table: op key -> forward kernel class."""
        return {"prelu": PreluFwdKernel}

    def _eager_forward(
        self,
        input: torch.Tensor,  # noqa: A002 — manifest-aligned PyTorch param name
        weight: torch.Tensor,
    ) -> torch.Tensor:
        orig_shape = input.shape
        result = self.kernel(
            input.contiguous().reshape(-1), weight.contiguous().reshape(-1),
        ).reshape(orig_shape)
        return _apply_fp8_post_cast(result, self.kernel)

    def forward(
        self,
        input: torch.Tensor,  # noqa: A002 — manifest-aligned PyTorch param name
        weight: torch.Tensor,
    ) -> torch.Tensor:
        if not input.is_cuda:
            raise ValueError("Input must be a CUDA tensor")
        if input.dtype != self.dtype:
            raise ValueError(f"Expected input.dtype {self.dtype}, got {input.dtype}")
        # Validate the full shape, not just numel: the kernel's channel
        # decoding is baked to ``self.inner_size`` at construction, so a
        # same-numel tensor with a different layout would silently apply
        # the wrong per-channel weights. This matches the exact-shape
        # checks used by the sibling masked_fill / where ops.
        if tuple(input.shape) != self.shape:
            raise ValueError(
                f"Expected input.shape {self.shape}, got {tuple(input.shape)}"
            )
        # ``weight`` is part of the manifest contract; validate device,
        # dtype, and length so a malformed weight fails fast at the op
        # boundary instead of corrupting the kernel.
        if not weight.is_cuda:
            raise ValueError("Weight must be a CUDA tensor")
        if weight.dtype != self.dtype:
            raise ValueError(
                f"Expected weight.dtype {self.dtype}, got {weight.dtype}"
            )
        if weight.numel() != self.num_channels:
            raise ValueError(
                f"Expected weight to have {self.num_channels} elements, "
                f"got {weight.numel()}"
            )
        wrapped = type(self)._wrapped
        if wrapped is not None:
            return wrapped(input, weight, self._instance_key)
        return self._eager_forward(input, weight)
diff --git a/tileops/ops/elementwise/sinusoidal.py b/tileops/ops/elementwise/sinusoidal.py
new file mode 100644
index 00000000..3e3baab9
--- /dev/null
+++ b/tileops/ops/elementwise/sinusoidal.py
@@ -0,0 +1,65 @@
+"""Sinusoidal positional encoding generative op."""
+
+from typing import Dict, Optional
+
+import torch
+
+from tileops.kernels.elementwise import SinusoidalFwdKernel
+from tileops.kernels.kernel_base import Kernel
+
+from ..op_base import Op
+from ._base import _OP_REGISTRY, _apply_fp8_post_cast
+
+
class SinusoidalFwdOp(Op):
    """Sinusoidal positional encoding from "Attention Is All You Need".

    Generates the full (seq_len, d_model) encoding tensor.

    Args:
        seq_len: Sequence length.
        d_model: Model dimension.
        dtype: Torch dtype.
        kernel_map: Optional dispatch override mapping kernel keys to
            ``Kernel`` subclasses. Falls back to ``default_kernel_map``.
    """

    _op_name = "sinusoidal"
    _wrapped = None

    def __init__(
        self,
        seq_len: int,
        d_model: int,
        dtype: torch.dtype,
        *,
        kernel_map: Optional[Dict[str, Kernel]] = None,
    ):
        self.seq_len = seq_len
        self.d_model = d_model
        self.dtype = dtype
        self.dispatch_kernel(kernel_map)
        self.kernel = self.kernel_map[self._op_name](seq_len, d_model, dtype)
        # 0-dim tensor whose only job is to carry device/dtype through
        # torch.compile tracing — the op itself takes no tensor inputs.
        self._device_carrier = torch.empty((), dtype=dtype, device="cuda")
        self._instance_key = id(self)
        _OP_REGISTRY[self._instance_key] = self

    @property
    def default_kernel_map(self):
        """Default dispatch table: op key -> forward kernel class."""
        return {"sinusoidal": SinusoidalFwdKernel}

    def _eager_forward(self) -> torch.Tensor:
        flat = self.kernel()
        encoded = flat.reshape(self.seq_len, self.d_model)
        return _apply_fp8_post_cast(encoded, self.kernel)

    def forward(self) -> torch.Tensor:
        wrapped = type(self)._wrapped
        if wrapped is None:
            return self._eager_forward()
        return wrapped(
            self._device_carrier,
            self.seq_len, self.d_model,
            self._instance_key,
        )
diff --git a/tileops/ops/elementwise/where.py b/tileops/ops/elementwise/where.py
new file mode 100644
index 00000000..2c58b70a
--- /dev/null
+++ b/tileops/ops/elementwise/where.py
@@ -0,0 +1,120 @@
+"""Where op: out = condition ? input : other (with broadcasting)."""
+
+from math import prod
+from typing import Dict, Optional
+
+import torch
+
+from tileops.kernels.elementwise import WhereFwdKernel
+from tileops.kernels.kernel_base import Kernel
+
+from ..op_base import Op
+from ._base import _OP_REGISTRY
+
+
class WhereFwdOp(Op):
    """Where: out = condition ? input : other (with full PyTorch broadcasting).

    Conforms to ``torch.where(condition, input, other)``: ``condition`` is
    a bool tensor, and ``input`` / ``other`` may broadcast with each other
    and with ``condition`` to form the output. The op layer expands all
    three tensors to the broadcast shape and runs the existing flat where
    kernel over ``N_total = product(broadcast_shape)`` elements.

    Args:
        condition: Shape of the condition tensor (any shape broadcastable
            with ``input`` / ``other``).
        input: Shape of the value-when-true tensor.
        other: Shape of the value-when-false tensor.
        dtype: Torch dtype for ``input`` / ``other``.
        kernel_map: Optional dispatch override mapping kernel keys to
            ``Kernel`` subclasses. Falls back to ``default_kernel_map``.
    """

    _op_name = "where"
    _wrapped = None

    # Manifest contract for ``input`` / ``other`` dtype:
    # float16 | bfloat16 | float32. fp8 dtypes are outside the contract and
    # are rejected at the op-layer signature so the implementation matches
    # the manifest.
    _SUPPORTED_DTYPES = (torch.float16, torch.bfloat16, torch.float32)

    def __init__(
        self,
        condition: tuple,
        input: tuple,  # noqa: A002 — manifest-aligned PyTorch param name
        other: tuple,
        dtype: torch.dtype,
        *,
        kernel_map: Optional[Dict[str, Kernel]] = None,
    ):
        if dtype not in self._SUPPORTED_DTYPES:
            names = ", ".join(str(dt) for dt in self._SUPPORTED_DTYPES)
            raise ValueError(
                f"WhereFwdOp does not support dtype {dtype}. "
                f"Supported: [{names}]"
            )
        self.condition_shape = tuple(condition)
        self.input_shape = tuple(input)
        self.other_shape = tuple(other)
        self.dtype = dtype
        self.out_shape = tuple(
            torch.broadcast_shapes(self.condition_shape, self.input_shape, self.other_shape)
        )
        self.N_total = prod(self.out_shape) if self.out_shape else 1
        self.dispatch_kernel(kernel_map)
        self.kernel = self.kernel_map[self._op_name](self.N_total, dtype)
        self._instance_key = id(self)
        _OP_REGISTRY[self._instance_key] = self

    @property
    def default_kernel_map(self):
        """Default dispatch table: op key -> forward kernel class."""
        return {"where": WhereFwdKernel}

    @staticmethod
    def _expand_flat(t: torch.Tensor, target_shape: tuple) -> torch.Tensor:
        """Broadcast ``t`` to ``target_shape`` and return a contiguous 1-D view."""
        if tuple(target_shape) != tuple(t.shape):
            t = t.expand(target_shape)
        return t.contiguous().view(-1)

    def _eager_forward(
        self, condition: torch.Tensor, input: torch.Tensor, other: torch.Tensor,  # noqa: A002
    ) -> torch.Tensor:
        # A 0-dim output flows through a 1-element flat buffer and is
        # viewed back to () at the end.
        target = self.out_shape if self.out_shape else (1,)
        cond = condition if condition.dtype == torch.bool else condition.bool()
        flat_cond = self._expand_flat(cond, target).view(torch.uint8)
        flat_x = self._expand_flat(input, target)
        flat_y = self._expand_flat(other, target)
        out = self.kernel(flat_cond, flat_x, flat_y)
        return out.view(target if self.out_shape else ())

    def forward(
        self, condition: torch.Tensor, input: torch.Tensor, other: torch.Tensor,  # noqa: A002
    ) -> torch.Tensor:
        if not (condition.is_cuda and input.is_cuda and other.is_cuda):
            raise ValueError("Inputs must be CUDA tensors")
        if condition.dtype != torch.bool:
            raise ValueError(
                f"Expected condition.dtype torch.bool, got {condition.dtype}"
            )
        for name, t in (("input", input), ("other", other)):
            if t.dtype != self.dtype:
                raise ValueError(f"Expected {name}.dtype {self.dtype}, got {t.dtype}")
        for name, t, want in (
            ("condition", condition, self.condition_shape),
            ("input", input, self.input_shape),
            ("other", other, self.other_shape),
        ):
            if tuple(t.shape) != want:
                raise ValueError(
                    f"Expected {name}.shape {want}, got {tuple(t.shape)}"
                )
        wrapped = type(self)._wrapped
        if wrapped is not None:
            return wrapped(condition, input, other, self._instance_key)
        return self._eager_forward(condition, input, other)