Add MHC (Multi-Head Computation) kernels for Ascend A5 #104
15 PTO-ISA kernels implementing the full MHC forward/backward pass from DeepSeek TileKernels, generated via PTO-DSL + ptoas pipeline.

Kernels: expand_to_mhc, head_compute_mix, pre_split_mixes, pre_apply_mix, pre_norm_fn, fn_normw_merge, post, sinkhorn_normalize (fwd + bwd for each).

Tested on Ascend950PR (dav-3510) with CANN 9.0.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
Code Review
This pull request introduces 15 manual PTO-ISA kernels for the Multi-Head Computation (MHC) architecture, optimized for Ascend A5 hardware, along with a CMake build system and a host-side test suite. The reviewer provided several actionable suggestions to improve the code quality, including adding error handling for ACL API calls, making hardcoded paths configurable, and optimizing the build process to prevent redundant compilation. Additionally, the feedback included specific code suggestions to use sizeof for memory allocations and to implement a more robust element-wise verification loop for floating-point results instead of using memcmp.
    aclInit(nullptr);
    aclrtSetDevice(0);
    void *stream = nullptr;
    aclrtCreateStream(&stream);

The return values of ACL initialization and runtime functions (aclInit, aclrtSetDevice, aclrtCreateStream) are not checked for errors. If any of these calls fail, the program will continue and likely crash when attempting to use the device or stream. It is recommended to add error handling for all ACL API calls.
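One way to act on this is a small checking wrapper around each call. The following is a minimal sketch, not the PR's code: `acl_status_t`, `kAclSuccess`, `check_acl`, `CHECK_ACL`, and the two `fake_*` stubs are hypothetical names introduced here so the example compiles without CANN. It assumes the ACL status type is an integer with 0 meaning success; check the installed ACL headers before relying on that.

```cpp
#include <cassert>
#include <sstream>
#include <stdexcept>
#include <string>

// Stand-in for the ACL status type so the sketch builds without CANN.
// Assumption: the real status type is an integer and 0 means success.
using acl_status_t = int;
constexpr acl_status_t kAclSuccess = 0;

// Throws with the failing expression and source location on a non-success status.
inline void check_acl(acl_status_t status, const char *expr, const char *file, int line) {
    if (status != kAclSuccess) {
        std::ostringstream msg;
        msg << expr << " failed with code " << status << " at " << file << ":" << line;
        throw std::runtime_error(msg.str());
    }
}

// Wraps a call so the expression text and location are captured automatically.
#define CHECK_ACL(call) check_acl((call), #call, __FILE__, __LINE__)

// Hypothetical stubs standing in for calls such as aclInit / aclrtSetDevice.
inline acl_status_t fake_init_ok() { return 0; }
inline acl_status_t fake_set_device_fail() { return 1; }  // arbitrary non-zero code

// Runs a two-step setup and reports whether it stopped on an error.
inline bool setup_fails() {
    try {
        CHECK_ACL(fake_init_ok());
        CHECK_ACL(fake_set_device_fail());  // throws here, so later steps never run
    } catch (const std::runtime_error &) {
        return true;
    }
    return false;
}
```

In the test host, the same pattern would wrap the real calls, for example `CHECK_ACL(aclrtSetDevice(0));`, so a failed device setup stops the run with a readable message instead of crashing later.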
    set(ASCEND_HOME_PATH $ENV{ASCEND_HOME_PATH})
    endif()

    set(ASCEND_DRIVER_PATH /usr/local/Ascend/driver)

    foreach(KERNEL ${MHC_KERNELS})
      add_library(${KERNEL}_kernel SHARED ${KERNEL}.cpp)
      target_compile_options(${KERNEL}_kernel PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} --npu-arch=dav-3510 -DMEMORY_BASE)
    endforeach()

The kernels are being compiled into individual shared libraries here, but expand_to_mhc_fwd.cpp is also included and compiled again within the mhc_caller library (via caller.cpp). This redundant compilation increases build time and can lead to symbol conflicts. Consider compiling the kernels once and linking them to the caller or the test executable.
    aclrtMalloc(&d_x, x_elems * 2, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(&d_out, out_elems * 2, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMemcpy(d_x, x_elems * 2, h_x.data(), x_elems * 2, ACL_MEMCPY_HOST_TO_DEVICE);

The size calculation uses a hardcoded factor of 2 to represent the size of a bf16 element. Using sizeof(uint16_t) would be more descriptive and less prone to errors if the data type changes.
Suggested change:

    -aclrtMalloc(&d_x, x_elems * 2, ACL_MEM_MALLOC_HUGE_FIRST);
    -aclrtMalloc(&d_out, out_elems * 2, ACL_MEM_MALLOC_HUGE_FIRST);
    -aclrtMemcpy(d_x, x_elems * 2, h_x.data(), x_elems * 2, ACL_MEMCPY_HOST_TO_DEVICE);
    +aclrtMalloc(&d_x, x_elems * sizeof(uint16_t), ACL_MEM_MALLOC_HUGE_FIRST);
    +aclrtMalloc(&d_out, out_elems * sizeof(uint16_t), ACL_MEM_MALLOC_HUGE_FIRST);
    +aclrtMemcpy(d_x, x_elems * sizeof(uint16_t), h_x.data(), x_elems * sizeof(uint16_t), ACL_MEMCPY_HOST_TO_DEVICE);

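The `sizeof` suggestion can be taken one step further by naming the element type once and deriving every byte count from it. This is a sketch under the assumption that the host keeps bf16 values as raw 16-bit patterns; `bf16_bits`, `bytes_for`, and `payload_bytes` are hypothetical helpers and do not appear in the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumption: bf16 values are carried on the host as raw 16-bit patterns.
using bf16_bits = std::uint16_t;

// Byte size of a buffer holding `count` elements of type T.
template <typename T>
constexpr std::size_t bytes_for(std::size_t count) {
    return count * sizeof(T);
}

// Byte size of a host vector's payload, derived from its own element type.
template <typename T>
std::size_t payload_bytes(const std::vector<T> &v) {
    return v.size() * sizeof(T);
}
```

With these, an allocation size would read `bytes_for<bf16_bits>(x_elems)` and a copy size `payload_bytes(h_x)`, so changing the element type means changing only the alias.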
    aclrtMemcpy(h_out.data(), out_elems * 2, d_out, out_elems * 2, ACL_MEMCPY_DEVICE_TO_HOST);

    /* verify */
    bool pass = (memcmp(h_out.data(), h_golden.data(), out_elems * 2) == 0);

Using memcmp for verifying floating-point results is not robust. While it might pass for simple bit-exact operations like broadcast, it will fail for kernels involving arithmetic where small precision differences can occur. It is better to iterate through the results and compare them individually, which also allows for better error reporting in case of a mismatch.
Suggested change:

    -bool pass = (memcmp(h_out.data(), h_golden.data(), out_elems * 2) == 0);
    +bool pass = true;
    +for (size_t i = 0; i < h_out.size(); ++i) {
    +    if (h_out[i] != h_golden[i]) {
    +        pass = false;
    +        printf("Mismatch at index %zu: expected %u, got %u\n", i, h_golden[i], h_out[i]);
    +        break;
    +    }
    +}

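The suggested loop improves error reporting but still compares bit patterns exactly. For kernels that do arithmetic, a tolerance-based check may be more appropriate. The sketch below assumes the host buffers are vectors of `uint16_t` holding bf16 bit patterns, which the `* 2` sizes and `%u` format suggest but the source does not state. `bf16_to_float` and `count_mismatches` are hypothetical helpers, and the tolerance defaults are illustrative rather than tuned for these kernels.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Widen a raw bf16 pattern to float: bf16 is the upper 16 bits of an IEEE-754 binary32.
inline float bf16_to_float(std::uint16_t bits) {
    const std::uint32_t wide = static_cast<std::uint32_t>(bits) << 16;
    float out;
    std::memcpy(&out, &wide, sizeof(out));
    return out;
}

// Counts elements outside |got - want| <= atol + rtol * |want|.
// Prints at most `max_report` mismatches; the tolerances are illustrative defaults.
inline std::size_t count_mismatches(const std::vector<std::uint16_t> &got,
                                    const std::vector<std::uint16_t> &want,
                                    float rtol = 1e-2f, float atol = 1e-3f,
                                    std::size_t max_report = 8) {
    if (got.size() != want.size()) {
        std::printf("Size mismatch: expected %zu elements, got %zu\n", want.size(), got.size());
        return want.size() > got.size() ? want.size() : got.size();
    }
    std::size_t bad = 0;
    for (std::size_t i = 0; i < got.size(); ++i) {
        const float g = bf16_to_float(got[i]);
        const float w = bf16_to_float(want[i]);
        // Identical bit patterns always match (covers equal infinities and NaNs).
        // The tolerance test applies only to finite values, so inf never matches a finite value.
        const bool ok = (got[i] == want[i]) ||
                        (std::isfinite(g) && std::isfinite(w) &&
                         std::fabs(g - w) <= atol + rtol * std::fabs(w));
        if (!ok) {
            if (bad < max_report) {
                std::printf("Mismatch at index %zu: expected %g, got %g\n", i, w, g);
            }
            ++bad;
        }
    }
    return bad;
}
```

In the test host this would replace the `memcmp` line with something like `bool pass = (count_mismatches(h_out, h_golden) == 0);`. For bit-exact kernels such as a broadcast, passing `rtol = 0` and `atol = 0` reduces the check to exact equality for finite values.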
Triage review (2026-05-08): this is a large feature PR (21 files, ~16k additions) and needs a dedicated maintainer review rather than a quick merge. I did verify that the GitHub PR patch applies cleanly to current Requested before merge:
No merge-blocking conflict found in the triage pass, but the size and hardware dependence make this high-risk without checks and explicit validation artifacts.