
Add MHC (Multi-Head Computation) kernels for Ascend A5 #104

Open
furljq wants to merge 1 commit into hw-native-sys:main from furljq:mhc-kernels

furljq commented Apr 29, 2026

15 PTO-ISA kernels implementing the full MHC forward/backward pass from DeepSeek TileKernels, generated via PTO-DSL + ptoas pipeline.

Kernels: expand_to_mhc, head_compute_mix, pre_split_mixes, pre_apply_mix, pre_norm_fn, fn_normw_merge, post,
sinkhorn_normalize (fwd + bwd for each).

Tested on Ascend950PR (dav-3510) with CANN 9.0.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>

gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces 15 manual PTO-ISA kernels for the Multi-Head Computation (MHC) architecture, optimized for Ascend A5 hardware, along with a CMake build system and a host-side test suite. The reviewer provided several actionable suggestions to improve the code quality, including adding error handling for ACL API calls, making hardcoded paths configurable, and optimizing the build process to prevent redundant compilation. Additionally, the feedback included specific code suggestions to use sizeof for memory allocations and to implement a more robust element-wise verification loop for floating-point results instead of using memcmp.

Comment on lines +48 to +52
aclInit(nullptr);
aclrtSetDevice(0);
void *stream = nullptr;
aclrtCreateStream(&stream);

high

The return values of ACL initialization and runtime functions (aclInit, aclrtSetDevice, aclrtCreateStream) are not checked for errors. If any of these calls fail, the program will continue and likely crash when attempting to use the device or stream. It is recommended to add error handling for all ACL API calls.
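One way to act on this comment is a small checking wrapper around each call. The sketch below is only illustrative: `check_acl`, `CHECK_ACL`, and the two `fake_acl_*` stand-ins are hypothetical names that are not part of the PR, and it assumes ACL calls return an integer status where 0 means success (verify this against the CANN headers in use). The stand-ins exist only so the sketch compiles without CANN installed.

```cpp
#include <stdexcept>
#include <string>

// Assumption: ACL runtime calls return an integer status and 0 means success.
// Check this against the ACL headers shipped with your CANN version.
constexpr int kAclSuccess = 0;

// Hypothetical helper: throws with the failing expression, its status code,
// and the source location, so a failed call cannot be silently ignored.
inline void check_acl(int status, const char *expr, const char *file, int line) {
    if (status != kAclSuccess) {
        throw std::runtime_error(std::string(expr) + " failed with status " +
                                 std::to_string(status) + " at " + file + ":" +
                                 std::to_string(line));
    }
}

// Wraps a call so its text (#call) appears in the error message.
#define CHECK_ACL(call) check_acl((call), #call, __FILE__, __LINE__)

// Stand-ins for real ACL calls (hypothetical, for illustration only).
inline int fake_acl_ok() { return 0; }
inline int fake_acl_fail() { return 1; }  // arbitrary nonzero status
```

In the test from this PR, the same macro would wrap the existing calls, for example `CHECK_ACL(aclInit(nullptr));`, `CHECK_ACL(aclrtSetDevice(0));`, and `CHECK_ACL(aclrtCreateStream(&stream));`. Throwing is one design choice; returning early with a nonzero exit code would work equally well for a standalone test binary.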

set(ASCEND_HOME_PATH $ENV{ASCEND_HOME_PATH})
endif()

set(ASCEND_DRIVER_PATH /usr/local/Ascend/driver)

medium

The path to the Ascend driver is hardcoded to /usr/local/Ascend/driver. This can cause build failures on environments where the driver is installed in a different location. It is recommended to use an environment variable or a CMake cache variable to allow users to specify the path.

Comment on lines +54 to +57
foreach(KERNEL ${MHC_KERNELS})
  add_library(${KERNEL}_kernel SHARED ${KERNEL}.cpp)
  target_compile_options(${KERNEL}_kernel PRIVATE ${CMAKE_CCE_COMPILE_OPTIONS} --npu-arch=dav-3510 -DMEMORY_BASE)
endforeach()

medium

The kernels are being compiled into individual shared libraries here, but expand_to_mhc_fwd.cpp is also included and compiled again within the mhc_caller library (via caller.cpp). This redundant compilation increases build time and can lead to symbol conflicts. Consider compiling the kernels once and linking them to the caller or the test executable.

Comment on lines +66 to +68
aclrtMalloc(&d_x, x_elems * 2, ACL_MEM_MALLOC_HUGE_FIRST);
aclrtMalloc(&d_out, out_elems * 2, ACL_MEM_MALLOC_HUGE_FIRST);
aclrtMemcpy(d_x, x_elems * 2, h_x.data(), x_elems * 2, ACL_MEMCPY_HOST_TO_DEVICE);

medium

The size calculation uses a hardcoded factor of 2 to represent the size of a bf16 element. Using sizeof(uint16_t) would be more descriptive and less prone to errors if the data type changes.

Suggested change
- aclrtMalloc(&d_x, x_elems * 2, ACL_MEM_MALLOC_HUGE_FIRST);
- aclrtMalloc(&d_out, out_elems * 2, ACL_MEM_MALLOC_HUGE_FIRST);
- aclrtMemcpy(d_x, x_elems * 2, h_x.data(), x_elems * 2, ACL_MEMCPY_HOST_TO_DEVICE);
+ aclrtMalloc(&d_x, x_elems * sizeof(uint16_t), ACL_MEM_MALLOC_HUGE_FIRST);
+ aclrtMalloc(&d_out, out_elems * sizeof(uint16_t), ACL_MEM_MALLOC_HUGE_FIRST);
+ aclrtMemcpy(d_x, x_elems * sizeof(uint16_t), h_x.data(), x_elems * sizeof(uint16_t), ACL_MEMCPY_HOST_TO_DEVICE);
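A further step in the same direction is to name the element type once and derive every byte count from it. This is a minimal sketch, assuming bf16 values are held on the host as raw 16-bit patterns; `bf16_bits` and `bytes_of` are hypothetical names, not identifiers from the PR.

```cpp
#include <cstddef>
#include <cstdint>

// Assumption: bf16 values are stored on the host as raw 16-bit patterns.
using bf16_bits = std::uint16_t;

// Hypothetical helper: byte size of a buffer holding `count` elements of T.
// Changing the element type then only requires changing the template argument.
template <typename T>
constexpr std::size_t bytes_of(std::size_t count) {
    return count * sizeof(T);
}

// Compile-time sanity check: each bf16 element occupies two bytes.
static_assert(bytes_of<bf16_bits>(4) == 8, "bf16 is two bytes per element");
```

With this in place, a call such as the allocation for `d_x` could pass `bytes_of<bf16_bits>(x_elems)` as its size argument instead of repeating the multiplication at every call site.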

aclrtMemcpy(h_out.data(), out_elems * 2, d_out, out_elems * 2, ACL_MEMCPY_DEVICE_TO_HOST);

/* verify */
bool pass = (memcmp(h_out.data(), h_golden.data(), out_elems * 2) == 0);

medium

Using memcmp for verifying floating-point results is not robust. While it might pass for simple bit-exact operations like broadcast, it will fail for kernels involving arithmetic where small precision differences can occur. It is better to iterate through the results and compare them individually, which also allows for better error reporting in case of a mismatch.

Suggested change
- bool pass = (memcmp(h_out.data(), h_golden.data(), out_elems * 2) == 0);
+ bool pass = true;
+ for (size_t i = 0; i < h_out.size(); ++i) {
+     if (h_out[i] != h_golden[i]) {
+         pass = false;
+         printf("Mismatch at index %zu: expected %u, got %u\n", i, h_golden[i], h_out[i]);
+         break;
+     }
+ }
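The loop suggested above still compares raw bit patterns, so it remains a bit-exact check. For kernels that do arithmetic, the comparison the comment describes would decode each bf16 value and allow a tolerance. The following is a sketch under stated assumptions: the buffers hold raw bf16 bit patterns as `uint16_t`, `count_mismatches` is a hypothetical helper, and the `atol`/`rtol` defaults are illustrative values rather than thresholds validated for these kernels.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// bf16 is the upper 16 bits of an IEEE-754 binary32, so widening is a shift.
inline float bf16_to_float(std::uint16_t bits) {
    std::uint32_t wide = static_cast<std::uint32_t>(bits) << 16;
    float value;
    std::memcpy(&value, &wide, sizeof(value));
    return value;
}

// Hypothetical helper: counts elements that differ beyond atol + rtol * |want|.
// Reports at most `max_report` mismatches instead of stopping at the first one.
inline std::size_t count_mismatches(const std::vector<std::uint16_t> &out,
                                    const std::vector<std::uint16_t> &golden,
                                    float atol = 1e-2f, float rtol = 1e-2f,
                                    std::size_t max_report = 8) {
    if (out.size() != golden.size()) {
        std::printf("Size mismatch: expected %zu, got %zu\n", golden.size(), out.size());
        return std::max(out.size(), golden.size());
    }
    std::size_t mismatches = 0;
    for (std::size_t i = 0; i < out.size(); ++i) {
        const float got = bf16_to_float(out[i]);
        const float want = bf16_to_float(golden[i]);
        const bool both_nan = std::isnan(got) && std::isnan(want);
        // `got == want` also covers matching infinities, where the difference is NaN.
        const bool close = both_nan || got == want ||
                           std::fabs(got - want) <= atol + rtol * std::fabs(want);
        if (!close) {
            if (mismatches < max_report) {
                std::printf("Mismatch at index %zu: expected %g, got %g\n", i, want, got);
            }
            ++mismatches;
        }
    }
    return mismatches;
}
```

A test would then set `pass = (count_mismatches(h_out, h_golden) == 0)`. For pure data-movement kernels such as a broadcast, passing `atol = 0` and `rtol = 0` keeps the check exact on values while still reporting where the outputs diverge.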

zhoubot (Collaborator) commented May 8, 2026

Triage review (2026-05-08): this is a large feature PR (21 files, ~16k additions) and needs a dedicated maintainer review rather than a quick merge. I did verify that the GitHub PR patch applies cleanly to current main and does not raise whitespace warnings in my triage check.

Requested before merge:

  • CI/pre-commit should run on the branch; currently no checks are reported;
  • please add exact validation commands/results for the 15 kernels, including which CANN/device version and whether forward and backward outputs were numerically checked;
  • consider splitting if possible: directory/index plumbing first, then generated/ported MHC kernels, then benchmark/validation evidence;
  • ensure any generated PTO-DSL output is intentionally committed and not rebuild-only output.

No merge-blocking conflict found in the triage pass, but the size and hardware dependence make this high-risk without checks and explicit validation artifacts.
