[Optimization][Operator] Implement and enable Conv2d-Reshape-Add-ReLU fusion #18240
Open
kimm240 wants to merge 4 commits into apache:main from kimm240:conv2d-reshape-add-relu-fusion
Conversation
… fusion

This PR introduces an operator fusion for the `conv2d` -> `reshape` -> `add` -> `relu` sequence, commonly found in deep learning models (e.g., the convolution + bias + activation pattern in PyTorch). The optimization improves performance and efficiency by reducing kernel launch overhead and memory traffic.

1. **Performance Improvement:**
   * **Reduced Kernel Launch Overhead:** Previously, `conv2d`, `reshape`, `add`, and `relu` each required a separate kernel call. Fusing these four operations into a single DNNL kernel (e.g., `dnnl_fused_conv2d_bias_relu`) significantly reduces launch overhead; see `src/runtime/contrib/dnnl/dnnl.cc:154-158`, where all operations are handled by a single `execute` call.
   * **Decreased Memory Bandwidth Consumption:** Intermediate results of the individual operations (e.g., `conv_out`, `bias_add`) previously required costly memory write-backs and reads. With fusion, these intermediate values stay in registers or cache, reducing unnecessary memory accesses and therefore memory bandwidth usage and overall execution time.

2. **Increased Efficiency:**
   * **Leveraging Compiler Optimizations:** By using TVM's `FuseOpsByPattern` and `MergeCompositeFunctions` passes, this change generates a composite operation optimized for specific backends such as DNNL. Common patterns coming from frontends like PyTorch are automatically recognized in the TVM graph and mapped to high-performance fused kernels provided by libraries like DNNL.
   * **Simplified IR Module:** The Intermediate Representation (IR) becomes less complex because multiple operation nodes are condensed into a single composite node, which simplifies subsequent optimization and code generation stages.

The fusion is achieved through a two-stage transformation within the TVM Relax framework (a sketch of the pipeline follows below):

1. **Pattern Recognition and Composite Function Creation (`FuseConv2dReshapeAddRelu` Pass):**
   * The `FuseConv2dReshapeAddRelu` class, registered as a `tvm.transform.module_pass`, transforms the `IRModule`.
   * The `_conv2d_reshape_add_relu_pattern()` helper defines the sequence `conv2d` -> `reshape` (applied to the bias) -> `add` -> `relu` using TVM's Declarative Pattern Language (DPL): the input tensors (`data`, `weight`, `bias`, `shape`) are matched with `wildcard()`, and the operation sequence is identified with `is_op()`.
   * The `relax.transform.FuseOpsByPattern` pass identifies this pattern in the input `IRModule`. Upon detection, the operation sequence is encapsulated into a new Relax function with the attributes `{"Composite": "dnnl.conv2d_reshape_add_relu", "Primitive": True}`, marking it as a logical "composite" unit.

2. **Composite Function Merging and Codegen Attribute Assignment (`MergeCompositeFunctions` Pass):**
   * Following the `FuseConv2dReshapeAddRelu` pass, the `MergeCompositeFunctions` pass is applied via `tvm.ir.transform.Sequential`.
   * This pass identifies functions marked with the `Composite` attribute and transforms them into external functions bearing the `{"Codegen": "dnnl"}` attribute, which indicates that the composite operation should be offloaded to a specific TVM backend such as DNNL.
   * Consequently, during graph execution, the fused function with the `Codegen` attribute is mapped to and executed by a single optimized DNNL kernel, for instance `dnnl_fused_conv2d_bias_relu` (defined in `src/runtime/contrib/dnnl/dnnl.cc:199-207`).

This implementation enables fusion of the `conv2d + reshape + add + relu` pattern, so the common convolution + bias + activation sequence originating from frontends like PyTorch is now optimized and executed as a single, highly efficient DNNL kernel within TVM.

---

To verify this fusion, you can directly run the specific test case:

`python tests/python/relax/test_conv2d_reshape_add_relu.py`
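For reference, the two-stage pipeline described above corresponds roughly to the following Python sketch. The pass, helper, and composite names (`FuseConv2dReshapeAddRelu`, `_conv2d_reshape_add_relu_pattern`, `dnnl.conv2d_reshape_add_relu`) are taken from this description; the exact code in the diff may differ.

```python
import tvm
from tvm import relax
from tvm.relax.dpl import is_op, wildcard


def _conv2d_reshape_add_relu_pattern():
    """DPL pattern for conv2d -> reshape(bias) -> add -> relu."""
    data = wildcard()
    weight = wildcard()
    bias = wildcard()
    shape = wildcard()

    conv = is_op("relax.nn.conv2d")(data, weight)
    # The PyTorch frontend reshapes the 1-D bias (e.g. to (1, C, 1, 1)) before the add.
    reshaped_bias = is_op("relax.reshape")(bias, shape)
    add = is_op("relax.add")(conv, reshaped_bias)
    return is_op("relax.nn.relu")(add)


@tvm.transform.module_pass(opt_level=0, name="FuseConv2dReshapeAddRelu")
class FuseConv2dReshapeAddRelu:
    """Stage 1: extract matched sequences into composite functions."""

    def transform_module(self, mod, _ctx):
        # FuseOpsByPattern tags the extracted functions with
        # {"Composite": "dnnl.conv2d_reshape_add_relu", "Primitive": True}.
        return relax.transform.FuseOpsByPattern(
            [("dnnl.conv2d_reshape_add_relu", _conv2d_reshape_add_relu_pattern())]
        )(mod)


# Stage 2: MergeCompositeFunctions lifts the composite functions into external
# functions carrying {"Codegen": "dnnl"} so they can be offloaded to DNNL.
pipeline = tvm.ir.transform.Sequential(
    [FuseConv2dReshapeAddRelu(), relax.transform.MergeCompositeFunctions()]
)
```

Splitting the work this way mirrors the usual Relax BYOC flow: `FuseOpsByPattern` only annotates composite regions, and `MergeCompositeFunctions` decides what is handed off to the `dnnl` codegen.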
…Add-ReLU fusion"

This reverts commit c014994.
…yTorch Conv2d-Reshape-Add-ReLU fusion

This commit extends the existing make_fused_bias_activation_pattern function to handle PyTorch frontend's specific IR generation pattern, following the design principle of enhancing general fusion capabilities rather than creating specialized solutions. When PyTorch models with bias=True are converted to Relax IR, the frontend generates a conv2d -> reshape -> add -> relu sequence that was previously not recognized by the standard fusion patterns.

The key enhancement adds an allow_reshape parameter to make_fused_bias_activation_pattern in both python/tvm/relax/dpl/pattern.py and python/tvm/relax/backend/patterns.py, enabling the existing fusion infrastructure to handle a broader range of patterns. This approach ensures that other frontends and use cases can benefit from the same enhancement, rather than creating PyTorch-specific fusion logic. When allow_reshape=True, the pattern matcher recognizes the complete conv2d -> reshape -> add -> relu sequence and fuses it into a single composite function. The original behavior (allow_reshape=False) is preserved as the default, maintaining full backward compatibility while extending the system's capabilities to handle frontend-specific IR patterns.

This enhancement demonstrates how extending existing fusion infrastructure can address specific frontend requirements while maintaining generality. The solution reduces memory usage and kernel launch overhead for PyTorch models by eliminating intermediate tensor allocations, while keeping the door open for other frontends to leverage the same pattern matching improvements.

Comprehensive tests are added in tests/python/relax/test_fuse_pytorch_conv2d_bias_pattern.py to validate the fusion behavior with both old and new patterns, ensuring correctness across different convolution types and confirming that fusion only occurs when appropriate conditions are met.
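As an illustration of this commit, the allow_reshape extension could look roughly like the sketch below for python/tvm/relax/dpl/pattern.py. The surrounding helper body is an approximation of the existing upstream helper, not the PR's exact diff, and the flag name comes from the commit message.

```python
from typing import Optional

from tvm.relax.dpl.pattern import DFPattern, is_op, wildcard


def make_fused_bias_activation_pattern(
    op_name: str,
    with_bias: bool = False,
    activation: Optional[str] = None,
    allow_reshape: bool = False,  # flag added by this PR; default keeps old behavior
) -> DFPattern:
    """Match op -> [reshape(bias)] -> add -> activation."""
    lhs = wildcard()
    rhs = wildcard()
    out = is_op(op_name)(lhs, rhs)

    if with_bias:
        bias = wildcard()
        if allow_reshape:
            # Optionally match the reshape the PyTorch frontend inserts on the
            # bias (e.g. reshaping a 1-D bias to (1, C, 1, 1)) before the add.
            bias = is_op("relax.reshape")(bias, wildcard())
        out = is_op("relax.add")(out, bias)

    if activation:
        out = is_op(activation)(out)

    return out
```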
… fusion

This commit extends the make_fused_bias_activation_pattern function to support the PyTorch frontend's specific IR generation pattern for convolution operations with bias. When PyTorch models with bias=True are converted to Relax IR, the frontend generates a conv2d -> reshape -> add -> relu sequence instead of the simpler conv2d -> add -> relu pattern that the existing fusion logic expected.

The key changes include:

1. Add an allow_reshape parameter to make_fused_bias_activation_pattern in both dpl/pattern.py and backend/patterns.py, with a default value of False to maintain backward compatibility.

2. When allow_reshape=True, the pattern matcher now recognizes and fuses the complete conv2d -> reshape -> add -> relu sequence into a single composite function (see the usage sketch after this message), eliminating intermediate tensor allocations and kernel launch overhead.

3. The original pattern (allow_reshape=False) only fuses conv2d -> add -> relu, leaving the reshape operation outside the fused function, which results in suboptimal performance for PyTorch-originated models.

This enhancement enables more efficient operator fusion for PyTorch models, reducing memory usage and improving execution performance by capturing the complete computation pattern in a single fused kernel. The implementation maintains full backward compatibility while extending support for the PyTorch frontend's specific IR generation patterns.

Comprehensive tests are added to verify the fusion behavior with both the old and new patterns, ensuring correctness across different convolution types (Conv1d, Conv2d, Conv3d) and validating that fusion only occurs when appropriate conditions are met.
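A minimal end-to-end usage sketch, assuming the allow_reshape flag from this PR and a TVM/PyTorch combination where tvm.relax.frontend.torch.from_exported_program is available (older setups use from_fx); the composite name dnnl.conv2d_bias_relu and the model below are illustrative only.

```python
import torch
from torch.export import export

from tvm import relax
from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern
from tvm.relax.frontend.torch import from_exported_program


class ConvBiasReLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, bias=True)

    def forward(self, x):
        return torch.relu(self.conv(x))


# PyTorch -> Relax: with bias=True this yields conv2d -> reshape -> add -> relu.
exported = export(ConvBiasReLU().eval(), (torch.randn(1, 3, 32, 32),))
mod = from_exported_program(exported)

# Fuse the whole sequence, including the bias reshape, into one composite function.
pattern = make_fused_bias_activation_pattern(
    "relax.nn.conv2d",
    with_bias=True,
    activation="relax.nn.relu",
    allow_reshape=True,  # the flag added by this PR
)
mod = relax.transform.FuseOpsByPattern([("dnnl.conv2d_bias_relu", pattern)])(mod)

# Expect a sub-function annotated with {"Composite": "dnnl.conv2d_bias_relu"}.
mod.show()
```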