You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Introduces python/neura_ops.py as the Python frontend interface for Neura hardware primitives.
Programmers can explicitly call custom ops (e.g., torch.ops.neura.gather()) to annotate hardware-specific memory access patterns that standard PyTorch ops cannot express.
Custom ops behave identically to standard PyTorch operations at runtime, but are preserved as-is through torch_mlir tracing so that the Neura compiler backend can recognize and lower them to hardware primitives.
Example: neura::gather in hash encoding
hash_encode kernel uses indirect addressing (level_embeddings[indices]) for hash-table lookups. After torch_mlir tracing, this becomes aten.index.Tensor, which loses the gather semantics needed by CGRA.
With the custom op, programmers write torch.ops.neura.gather(table, indices) instead. torch_mlir preserves it as torch.operator "neura.gather", giving the compiler a clear signal.
Refactored trilinear_interpolation to batch all 8 corner indices into a single gather per level (from 8 calls down to 1).
Changes
python/neura_ops.py (new): Registers neura::gather with CPU and Meta implementations, compatible with PyTorch 2.1+.
neura.gather appears exactly 2 times in output MLIR (one per level).
I don't see MLIR output?
MLIR output is generated by running compile_hash_encode.py
PyTorch-side correctness verified (semantically equivalent to the original implementation).
The correctness is based on "semantic", instead of execution, right?
Current correctness verification is semantic-level. Maybe we could register neura.gather in the Neura interpreter, which will enable execution-level correctness verification at the MLIR layer?
neura.gather appears exactly 2 times in output MLIR (one per level).
I don't see MLIR output?
MLIR output is generated by running compile_hash_encode.py
PyTorch-side correctness verified (semantically equivalent to the original implementation).
The correctness is based on "semantic", instead of execution, right?
Current correctness verification is semantic-level. Maybe we could register neura.gather in the Neura interpreter, which will enable execution-level correctness verification at the MLIR layer?
Then can we have a lit test to run the .py and check its output?
neura.gather in the Neura interpreter
We can do this later (another PR once your design/infra is almost done).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
python/neura_ops.pyas the Python frontend interface for Neura hardware primitives.torch.ops.neura.gather()) to annotate hardware-specific memory access patterns that standard PyTorch ops cannot express.Example:
neura::gatherin hash encodinghash_encodekernel uses indirect addressing (level_embeddings[indices]) for hash-table lookups. After torch_mlir tracing, this becomesaten.index.Tensor, which loses the gather semantics needed by CGRA.torch.ops.neura.gather(table, indices)instead. torch_mlir preserves it astorch.operator "neura.gather", giving the compiler a clear signal.trilinear_interpolationto batch all 8 corner indices into a single gather per level (from 8 calls down to 1).Changes
python/neura_ops.py(new): Registersneura::gatherwith CPU and Meta implementations, compatible with PyTorch 2.1+.nerf_kernels.py(new): Importsneura_ops, replaces fancy indexing withtorch.ops.neura.gather(), two-phase trilinear interpolation.compile_hash_encode.py(new): Compilation script that exports clean ~930-line Torch Dialect MLIR withneura.gatherpreserved.Verification
neura.gatherappears exactly 2 times in output MLIR (one per level).aten.index.Tensorfully eliminated.