Skip to content
Merged
4 changes: 4 additions & 0 deletions docs/design/manifest.md
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,10 @@ python scripts/validate_manifest.py
python scripts/validate_manifest.py --check-op SoftmaxFwdOp
```

### Generative-op carve-out

An op whose `ref_api` is `"none"` AND whose `forward()` takes zero positional args (only POSITIONAL_ONLY / POSITIONAL_OR_KEYWORD parameters count; keyword-only params after `*` do not) is a **generative op**: its output is synthesized from construction-time params alone, with no semantic tensor argument. L1 and C4 skip the `signature.inputs` vs `forward()` positional-name alignment for such entries. The manifest still carries `>= 1` entry under `signature.inputs` (a carrier scalar) to satisfy `test_every_signature_has_inputs_and_outputs`; the carrier's `dtype` is the channel for `same_as(...)` on the output. ALiBi and sinusoidal position encodings are the current cases.

## Exclusions

The manifest does NOT describe: multi-kernel execution ordering, accumulator dtypes, persistent state, tile sizes, or autotuning config.
131 changes: 103 additions & 28 deletions scripts/validate_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ def check_l1_signature(
*,
init_params: list[str] | None = None,
manifest_static_dims: dict | None = None,
is_generative: bool = False,
) -> list[str]:
"""Check that forward() params match manifest inputs + params.

Expand All @@ -555,6 +556,14 @@ def check_l1_signature(
init_params: List of parameter names from Op.__init__() (excluding 'self').
When None, treated as empty (only forward is checked).
manifest_static_dims: The signature.static_dims dict from manifest (may be None).
is_generative: When True, skip the forward()-order check against
``manifest_inputs``. Set by ``check_l1`` for entries whose
``ref_api`` is ``"none"`` and whose ``forward()`` takes zero
positional args — i.e. generative ops whose manifest inputs
are carriers (device/dtype scalars) rather than semantic
tensors, kept only because
``tests/test_ops_manifest.py::test_every_signature_has_inputs_and_outputs``
requires ``len(signature.inputs) >= 1``.

Returns:
List of error strings (empty if OK).
Expand All @@ -572,15 +581,19 @@ def check_l1_signature(
if init_params is None:
init_params = []

# 1. forward() order check: manifest inputs + forward-visible params, in order
expected = list(manifest_inputs.keys()) + [
name for name in manifest_params.keys() if name in forward_params
]
if forward_params != expected:
errors.append(
f"[signature] {op_name}: forward() params {forward_params} do not match "
f"manifest order {expected}"
)
# 1. forward() order check: manifest inputs + forward-visible params, in order.
# Generative-op carve-out: when ``is_generative`` is True the manifest
# inputs are carriers, not arguments threaded through ``forward()``, so
# there is no order to enforce.
if not is_generative:
expected = list(manifest_inputs.keys()) + [
name for name in manifest_params.keys() if name in forward_params
]
if forward_params != expected:
errors.append(
f"[signature] {op_name}: forward() params {forward_params} do not match "
f"manifest order {expected}"
)
Comment thread
lcy-seso marked this conversation as resolved.

# 2. Strict subset check: every manifest param must exist in init OR forward
code_params = set(forward_params) | set(init_params)
Expand Down Expand Up @@ -704,6 +717,37 @@ def _get_forward_params(cls) -> list[str] | None:
return None


_POSITIONAL_KINDS = (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)


def _forward_positional_params(cls) -> list[str] | None:
"""Get positional parameter names of cls.forward(), excluding 'self'.

Only POSITIONAL_ONLY / POSITIONAL_OR_KEYWORD count. KEYWORD_ONLY
params (those after ``*``) are not part of the positional tuple
that manifest ``signature.inputs`` aligns against. Used by both
the generative-op carve-out in check_l1 and the C4 forward-signature
parity check so the two stay in lockstep.
"""
try:
sig = inspect.signature(cls.forward)
return [
p for p, v in sig.parameters.items()
if p != "self" and v.kind in _POSITIONAL_KINDS
]
except (ValueError, TypeError) as exc:
# Stash exception text so callers that surface diagnostics can
# report ``exc.__class__.__name__: exc`` without changing the
# ``None`` return contract for "not inspectable".
_forward_positional_params._last_error = ( # type: ignore[attr-defined]
f"{exc.__class__.__name__}: {exc}"
)
return None


def _get_init_params(cls) -> list[str]:
"""Get explicit parameter names of cls.__init__(), excluding 'self'.

Expand Down Expand Up @@ -789,10 +833,33 @@ def check_l1(
manifest_static_dims = sig.get("static_dims")
init_params = _get_init_params(result.cls)

# Generative-op detection: ``ref_api: "none"`` + zero **positional**
# forward() args signals a kernel that synthesizes its output from
# construction-time params alone (e.g. ALiBi/Sinusoidal position
# encodings). The manifest still carries >= 1 entry under
# ``signature.inputs`` to satisfy
# ``tests/test_ops_manifest.py::test_every_signature_has_inputs_and_outputs``;
# under this carve-out the forward()-order check is skipped because
# those manifest "inputs" are carriers, not semantic tensor arguments.
# Filter to positional-only kinds so the L1 carve-out matches
# check_c4_forward_signature_parity exactly — keyword-only params
# (those after ``*``) are not aligned positionally either.
# Distinguish ``None`` (introspection failed) from ``[]`` (forward()
# has zero positional args). Only the latter qualifies for the
# generative-op carve-out; an introspection failure must NOT silently
# skip the L1 alignment check.
positional_forward_params = _forward_positional_params(result.cls)
is_generative = (
entry.get("ref_api") == "none"
and positional_forward_params is not None
and len(positional_forward_params) == 0
)

return check_l1_signature(
op_name, manifest_inputs, manifest_params, forward_params,
init_params=init_params,
manifest_static_dims=manifest_static_dims,
is_generative=is_generative,
)


Expand Down Expand Up @@ -3309,29 +3376,37 @@ def check_c4_forward_signature_parity(
return errors
expected = list(manifest_inputs.keys())

try:
py_sig = inspect.signature(cls.forward)
except (ValueError, TypeError) as exc:
positional = _forward_positional_params(cls)
if positional is None:
if warnings is not None:
warnings.append(
f"[forward] {op_name}: inspect.signature(forward) raised "
f"{exc.__class__.__name__}: {exc}"
detail = getattr(
_forward_positional_params, "_last_error", None
)
Comment thread
lcy-seso marked this conversation as resolved.
if detail:
warnings.append(
f"[forward] {op_name}: inspect.signature(forward) "
f"raised {detail}"
)
# Clear so a later call site sees only its own failure.
_forward_positional_params._last_error = None # type: ignore[attr-defined]
else:
warnings.append(
f"[forward] {op_name}: inspect.signature(forward) failed"
)
return errors

# Only POSITIONAL_ONLY / POSITIONAL_OR_KEYWORD count as positional.
# KEYWORD_ONLY params (those after ``*``) are not part of the
# positional tuple manifest ``signature.inputs`` describes.
positional: list[str] = []
for pname, p in py_sig.parameters.items():
if pname == "self":
continue
if p.kind not in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
continue
positional.append(pname)
# Generative-op carve-out: ``ref_api: "none"`` plus zero forward()
# positional args signals a kernel that synthesizes its output from
# construction-time params alone. Its manifest inputs are carriers
# kept to satisfy
# ``tests/test_ops_manifest.py::test_every_signature_has_inputs_and_outputs``;
# there is no positional argument to align against them.
#
# ``positional`` is guaranteed non-None here (introspection failure
# returned early above). Use explicit ``len(...) == 0`` so the carve-
# out cannot be confused with the introspection-failed path.
if entry.get("ref_api") == "none" and len(positional) == 0:
return errors

actual_prefix = positional[: len(expected)]
if actual_prefix != expected:
Expand Down
132 changes: 132 additions & 0 deletions tileops/manifest/elementwise_fused_gated.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# elementwise_fused_gated.yaml -- manifest entries for fused gated activation
# ops: y = activation(gate) * value, with gate/value packed as the two halves
# of a single (M, 2*N) input tensor.
#
# These are TileOPs-private fused kernels with no single ``torch.*`` symbol
# they mirror (the PyTorch composite expression for SiluAndMulFwdOp is
# ``F.silu(a) * b``, etc.), so ``ref_api`` is ``"none"`` per
# .claude/domain-rules/manifest-spec.md. Source of truth for op interfaces
# in this family. Loaded and merged with other family files by
# tileops.manifest at runtime. See docs/design/manifest.md for the schema.

SiluAndMulFwdOp:
# Composite expression: F.silu(gate) * value where x = [gate, value] along
# the last dim, x.shape == (M, 2*N). No single torch.* symbol.
ref_api: "none"
family: elementwise
status: spec-only

signature:
inputs:
x: {dtype: "float16 | bfloat16 | float32", shape: "[M, two_N]"}
outputs:
output: {dtype: "same_as(x)", shape: "[M, N]"}
shape_rules:
- "x.shape[1] == 2 * output.shape[1]"
- "x.shape[0] == output.shape[0]"

workloads:
# SwiGLU FFN intermediate (Llama-3.1-8B, hidden_dim=14336)
- {x_shape: [2048, 28672], dtypes: [float16, bfloat16], label: "llama-3.1-8b-swiglu-prefill"}
- {x_shape: [1, 28672], dtypes: [bfloat16], label: "llama-3.1-8b-swiglu-decode"}

roofline:
vars:
M: "x.shape[0]"
N: "output.shape[1]"
# FLOPs: silu(gate) = gate * sigmoid(gate); sigmoid = neg + exp + add +
# recip = 4. silu adds one mul (= 5). Final mul-by-value = 1. Total 6
# per output element. N output elements per row, M rows.
flops: "6 * M * N"
# Read 2*M*N gate+value, write M*N output.
bytes: "3 * M * N * elem_bytes"

source:
kernel: tileops/kernels/elementwise.py
kernel_map:
silu_and_mul: SiluAndMulFwdKernel
op: tileops/ops/elementwise/fused_gated.py
test: tests/ops/test_fused_gated.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false

GeluAndMulFwdOp:
# Composite expression: F.gelu(gate, approximate='none') * value where
# x = [gate, value] along the last dim, x.shape == (M, 2*N). Exact GELU
# (erf-based). No single torch.* symbol.
ref_api: "none"
family: elementwise
status: spec-only

signature:
inputs:
x: {dtype: "float16 | bfloat16 | float32", shape: "[M, two_N]"}
outputs:
output: {dtype: "same_as(x)", shape: "[M, N]"}
shape_rules:
Comment thread
lcy-seso marked this conversation as resolved.
- "x.shape[1] == 2 * output.shape[1]"
- "x.shape[0] == output.shape[0]"

workloads:
- {x_shape: [2048, 28672], dtypes: [float16, bfloat16], label: "ffn-gelu-prefill"}
- {x_shape: [1, 28672], dtypes: [bfloat16], label: "ffn-gelu-decode"}

roofline:
vars:
M: "x.shape[0]"
N: "output.shape[1]"
# FLOPs: gelu(gate) = gate * 0.5 * (1 + erf(gate / sqrt(2))).
# div(1) + erf(1) + add(1) + mul-by-half(1) + mul-by-gate(1) = 5;
# final mul-by-value = 1. Total 6 per output element.
flops: "6 * M * N"
bytes: "3 * M * N * elem_bytes"

source:
kernel: tileops/kernels/elementwise.py
kernel_map:
gelu_and_mul: GeluAndMulFwdKernel
op: tileops/ops/elementwise/fused_gated.py
test: tests/ops/test_fused_gated.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false

GeluTanhAndMulFwdOp:
# Composite expression: F.gelu(gate, approximate='tanh') * value where
# x = [gate, value] along the last dim, x.shape == (M, 2*N). tanh-based
# GELU approximation. No single torch.* symbol.
ref_api: "none"
family: elementwise
status: spec-only

signature:
inputs:
x: {dtype: "float16 | bfloat16 | float32", shape: "[M, two_N]"}
outputs:
output: {dtype: "same_as(x)", shape: "[M, N]"}
shape_rules:
Comment thread
lcy-seso marked this conversation as resolved.
- "x.shape[1] == 2 * output.shape[1]"
- "x.shape[0] == output.shape[0]"

workloads:
- {x_shape: [2048, 28672], dtypes: [float16, bfloat16], label: "ffn-gelu-tanh-prefill"}
- {x_shape: [1, 28672], dtypes: [bfloat16], label: "ffn-gelu-tanh-decode"}

roofline:
vars:
M: "x.shape[0]"
N: "output.shape[1]"
# FLOPs: gelu_tanh(g) = 0.5 * g * (1 + tanh(sqrt(2/pi) * (g + 0.044715 * g^3))).
# cube(2 muls) + mul-coeff(1) + add(1) + mul-sqrt(1) + tanh(1) + add(1)
# + mul-half(1) + mul-by-gate(1) = 9; final mul-by-value = 1. Total 10
# per output element.
flops: "10 * M * N"
bytes: "3 * M * N * elem_bytes"

source:
kernel: tileops/kernels/elementwise.py
kernel_map:
gelu_tanh_and_mul: GeluTanhAndMulFwdKernel
op: tileops/ops/elementwise/fused_gated.py
test: tests/ops/test_fused_gated.py
bench: benchmarks/ops/bench_binary_elementwise.py
bench_manifest_driven: false
Loading
Loading