
Conversation


@kimm240 kimm240 commented Aug 27, 2025

This commit extends the make_fused_bias_activation_pattern function to support
PyTorch frontend's specific IR generation pattern for convolution operations
with bias. When PyTorch models with bias=True are converted to Relax IR, the
frontend generates a conv2d -> reshape -> add -> relu sequence instead of the
simpler conv2d -> add -> relu pattern that existing fusion logic expected.

The key changes include:

  1. Add an allow_reshape parameter to make_fused_bias_activation_pattern in both
    dpl/pattern.py and backend/patterns.py, defaulting to False to maintain
    backward compatibility.

  2. When allow_reshape=True, the pattern matcher recognizes and fuses the
    complete conv2d -> reshape -> add -> relu sequence into a single composite
    function, eliminating intermediate tensor allocations and kernel launch
    overhead (a usage sketch follows this list).

  3. The original pattern (allow_reshape=False) only fuses conv2d -> add -> relu,
    leaving the reshape operation outside the fused function, which results in
    suboptimal performance for PyTorch-originated models.
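
For reference, a minimal usage sketch of the proposed flag, assuming the existing `with_bias`/`activation` parameters of `make_fused_bias_activation_pattern`; `fuse_conv2d_bias_relu` is a hypothetical helper and the final API may differ:

```python
import tvm
from tvm import relax
from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern

# Existing behavior: matches only conv2d -> add -> relu.
default_pattern = make_fused_bias_activation_pattern(
    "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu"
)

# Proposed behavior: also match the bias reshape emitted by the PyTorch
# frontend, i.e. conv2d -> reshape -> add -> relu.
pytorch_pattern = make_fused_bias_activation_pattern(
    "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=True
)


def fuse_conv2d_bias_relu(mod: tvm.IRModule) -> tvm.IRModule:
    """Group every match into a composite function named after the pattern."""
    return relax.transform.FuseOpsByPattern(
        [("dnnl.conv2d_reshape_add_relu", pytorch_pattern)]
    )(mod)
```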

This enhancement enables more efficient operator fusion for PyTorch models,
reducing memory usage and improving execution performance by capturing the
complete computation pattern in a single fused kernel. The implementation
maintains full backward compatibility while extending support for PyTorch
frontend's specific IR generation patterns.

Comprehensive tests are added to verify the fusion behavior with both old and
new patterns, ensuring correctness across different convolution types (Conv1d,
Conv2d, Conv3d) and validating that fusion only occurs when appropriate
conditions are met.

kim hyun gyu added 4 commits July 29, 2025 19:45
… fusion

This PR introduces operator fusion for the `conv2d` -> `reshape` -> `add` -> `relu` sequence commonly found in deep learning models (e.g., the convolution + bias + activation pattern in PyTorch). This optimization improves performance and efficiency by reducing kernel launch overhead and memory traffic.

1.  **Performance Improvement:**
    * **Reduced Kernel Launch Overhead:** Previously, `conv2d`, `reshape`, `add`, and `relu` each required separate kernel calls. By fusing these four operations into a single, unified DNNL kernel (e.g., `dnnl_fused_conv2d_bias_relu`), the overhead from multiple kernel launches is significantly reduced. This is evident from `src/runtime/contrib/dnnl/dnnl.cc:154-158`, where all operations are handled by a single `execute` call.
    * **Decreased Memory Bandwidth Consumption:** Intermediate results of individual operations (e.g., `conv_out`, `bias_add`) traditionally required costly memory write-backs and reads. Fusion allows these intermediate values to be processed directly in registers or cache, reducing unnecessary memory accesses, and thus decreasing memory bandwidth usage and overall execution time.

2.  **Increased Efficiency:**
    * **Leveraging Compiler Optimizations:** By utilizing TVM's `FuseOpsByPattern` and `MergeCompositeFunctions` passes, this change generates a composite operation optimized for specific backends (like DNNL). This ensures that common patterns from frontends like PyTorch are automatically recognized within the TVM graph and mapped to high-performance fused kernels provided by libraries like DNNL.
    * **Simplified IR Module:** The compiler's intermediate representation (IR) becomes less complex as multiple operation nodes are condensed into a single composite node, which simplifies subsequent optimization and code generation stages.

This fusion is achieved through a two-stage transformation within the TVM Relax framework:

1.  **Pattern Recognition and Composite Function Creation (`FuseConv2dReshapeAddRelu` Pass):**
    * The `FuseConv2dReshapeAddRelu` class, registered as a `tvm.transform.module_pass`, transforms the `IRModule`.
    * The `_conv2d_reshape_add_relu_pattern()` helper function defines the specific sequence `conv2d` -> `reshape` (applied to the bias) -> `add` -> `relu` using TVM's dataflow pattern language (DPL): the input tensors (`data`, `weight`, `bias`, `shape`) are matched with `wildcard()` and the operation sequence is identified with `is_op()` (a condensed sketch of both stages follows this list).
    * The `relax.transform.FuseOpsByPattern` pass identifies this pattern in the input `IRModule`. Upon detection, the operation sequence is encapsulated into a new Relax function with `{"Composite": "dnnl.conv2d_reshape_add_relu", "Primitive": True}` attributes, marking it as a logical "composite" unit.

2.  **Composite Function Merging and Codegen Attribute Assignment (`MergeCompositeFunctions` Pass):**
    * Following the `FuseConv2dReshapeAddRelu` pass, the `MergeCompositeFunctions` pass is applied via `tvm.ir.transform.Sequential`.
    * This pass identifies functions marked with the `Composite` attribute and transforms them into external functions bearing the `{"Codegen": "dnnl"}` attribute. This `Codegen` attribute indicates that the composite operation should be offloaded to a specific TVM backend, such as DNNL.
    * Consequently, during graph execution, the fused function with the `Codegen` attribute will be mapped and executed by an optimized, single DNNL kernel, for instance, `dnnl_fused_conv2d_bias_relu` (defined in `src/runtime/contrib/dnnl/dnnl.cc:199-207`).
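
A condensed sketch of the two stages above, reusing the helper and composite names from this description; treat it as an outline of the pass pipeline rather than the exact patch contents (`partition_for_dnnl` is a hypothetical wrapper):

```python
import tvm
from tvm import relax
from tvm.relax.dpl.pattern import is_op, wildcard


def _conv2d_reshape_add_relu_pattern():
    # conv2d -> reshape (applied to the bias) -> add -> relu
    data, weight = wildcard(), wildcard()
    bias, shape = wildcard(), wildcard()
    conv = is_op("relax.nn.conv2d")(data, weight)
    reshaped_bias = is_op("relax.reshape")(bias, shape)
    return is_op("relax.nn.relu")(is_op("relax.add")(conv, reshaped_bias))


def partition_for_dnnl(mod):
    # Stage 1 wraps each match into a function tagged with the Composite
    # attribute; stage 2 merges those functions and assigns Codegen="dnnl",
    # taken from the "dnnl." prefix of the pattern name.
    seq = tvm.ir.transform.Sequential(
        [
            relax.transform.FuseOpsByPattern(
                [("dnnl.conv2d_reshape_add_relu", _conv2d_reshape_add_relu_pattern())]
            ),
            relax.transform.MergeCompositeFunctions(),
        ]
    )
    return seq(mod)
```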

This implementation successfully enables the fusion of the `conv2d + reshape + add + relu` pattern. This ensures that common convolution + bias + activation patterns originating from frontends like PyTorch are now fully optimized and executed as a single, highly efficient DNNL kernel within TVM.

---

To verify this fusion, you can directly run the specific test case:

```bash
python tests/python/relax/test_conv2d_reshape_add_relu.py
```
…yTorch Conv2d-Reshape-Add-ReLU fusion

This commit extends the existing make_fused_bias_activation_pattern function to
handle PyTorch frontend's specific IR generation pattern, following the design
principle of enhancing general fusion capabilities rather than creating
specialized solutions. When PyTorch models with bias=True are converted to
Relax IR, the frontend generates a conv2d -> reshape -> add -> relu sequence
that was previously not recognized by the standard fusion patterns.

The key enhancement adds an allow_reshape parameter to make_fused_bias_activation_pattern
in both python/tvm/relax/dpl/pattern.py and python/tvm/relax/backend/patterns.py,
enabling the existing fusion infrastructure to handle a broader range of patterns.
This approach ensures that other frontends and use cases can benefit from the
same enhancement, rather than creating PyTorch-specific fusion logic.

When allow_reshape=True, the pattern matcher recognizes the complete
conv2d -> reshape -> add -> relu sequence and fuses it into a single composite
function. The original behavior (allow_reshape=False) is preserved as the
default, maintaining full backward compatibility while extending the system's
capabilities to handle frontend-specific IR patterns.
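
A rough sketch of how the optional reshape could be threaded through the dpl/pattern.py helper, based on the behavior described above; the real signature (and the backend/patterns.py variant, which also returns match annotations) may differ in detail:

```python
from tvm.relax.dpl.pattern import is_op, wildcard


def make_fused_bias_activation_pattern(
    op_name, with_bias=False, activation=None, allow_reshape=False
):
    # Binary op such as "relax.nn.conv2d" applied to two wildcard inputs.
    lhs, rhs = wildcard(), wildcard()
    out = is_op(op_name)(lhs, rhs)

    if with_bias:
        bias = wildcard()
        if allow_reshape:
            # The PyTorch frontend reshapes the 1-D bias (e.g. to (1, C, 1, 1))
            # before the add, so match that reshape as part of the pattern.
            bias = is_op("relax.reshape")(bias, wildcard())
        out = is_op("relax.add")(out, bias)

    if activation:
        out = is_op(activation)(out)
    return out
```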

This enhancement demonstrates how extending existing fusion infrastructure can
address specific frontend requirements while maintaining generality. The
solution reduces memory usage and kernel launch overhead for PyTorch models
by eliminating intermediate tensor allocations, while keeping the door open
for other frontends to leverage the same pattern matching improvements.

Comprehensive tests are added in tests/python/relax/test_fuse_pytorch_conv2d_bias_pattern.py
to validate the fusion behavior with both old and new patterns, ensuring
correctness across different convolution types and confirming that fusion
only occurs when appropriate conditions are met.
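
A minimal test sketch in the spirit of that file, assuming the proposed allow_reshape flag is available; the hand-written `ConvBiasRelu` module below stands in for the IR the PyTorch frontend would emit for `Conv2d(..., bias=True)` followed by `ReLU`:

```python
import tvm
from tvm import relax
from tvm.script import relax as R
from tvm.relax.dpl.pattern import make_fused_bias_activation_pattern


@tvm.script.ir_module
class ConvBiasRelu:
    @R.function
    def main(
        data: R.Tensor((1, 3, 8, 8), "float32"),
        weight: R.Tensor((16, 3, 3, 3), "float32"),
        bias: R.Tensor((16,), "float32"),
    ):
        with R.dataflow():
            conv = R.nn.conv2d(data, weight)
            # The PyTorch frontend reshapes the bias before adding it.
            reshaped = R.reshape(bias, R.shape([1, 16, 1, 1]))
            biased = R.add(conv, reshaped)
            out = R.nn.relu(biased)
            R.output(out)
        return out


def test_pytorch_style_bias_is_fused():
    pattern = make_fused_bias_activation_pattern(
        "relax.nn.conv2d", with_bias=True, activation="relax.nn.relu", allow_reshape=True
    )
    mod = relax.transform.FuseOpsByPattern(
        [("dnnl.conv2d_reshape_add_relu", pattern)]
    )(ConvBiasRelu)
    # The whole conv2d -> reshape -> add -> relu chain should now live in a
    # single Composite subfunction next to "main".
    fused = [
        gv.name_hint
        for gv, fn in mod.functions.items()
        if isinstance(fn, relax.Function)
        and fn.attrs is not None
        and "Composite" in fn.attrs
    ]
    assert fused, "expected a fused composite function"
```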
… fusion

@kimm240 kimm240 (Author) commented Aug 27, 2025

@yongwww
Hello. Sorry for the duplicate PR (duplicate of #18173); I'm still inexperienced with sending PRs.
This PR extends FuseOps to handle the operator fusion issue.
