[T1-3-1] GushanFall #387
Conversation
```cpp
hcMalloc(&mask_temp, seq_len_q * seq_len_kv * sizeof(float));
hcMemcpy(mask_temp, _info.mask, seq_len_q * seq_len_kv * sizeof(float), hcMemcpyHostToDevice);
```
Why does this `calculate` still need to malloc memory? In principle, the calculate stage should be pure computation; the space should already have been allocated when the workspace was set up.
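For illustration, here is a minimal sketch of the workspace pattern this comment is asking for: the scratch size is reported up front, and `calculate` only carves buffers out of a caller-provided allocation. The names (`getWorkspaceSize`, `calculate`) and the host-side `std::memcpy` standing in for `hcMemcpy` are assumptions for the sketch, not the repository's actual API.

```cpp
#include <cstddef>
#include <cstring>

// Hypothetical descriptor, sketching the workspace pattern only.
struct AttentionDescriptor {
    size_t seq_len_q;
    size_t seq_len_kv;

    // Report the scratch bytes calculate() will need, so the caller
    // allocates once (e.g. with hcMalloc) outside the hot path.
    size_t getWorkspaceSize() const {
        return seq_len_q * seq_len_kv * sizeof(float); // mask staging buffer
    }

    // calculate() only slices the preallocated workspace; it never
    // allocates device memory itself.
    int calculate(void *workspace, size_t workspace_size, const float *mask) {
        if (workspace_size < getWorkspaceSize()) {
            return -1; // caller did not provide enough scratch space
        }
        float *mask_temp = static_cast<float *>(workspace);
        // The real code would use hcMemcpy(..., hcMemcpyHostToDevice);
        // std::memcpy keeps this sketch self-contained on the host.
        std::memcpy(mask_temp, mask, getWorkspaceSize());
        // ... launch kernels that read mask_temp ...
        return 0;
    }
};
```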
```cpp
cudaMalloc(&mask_temp, seq_len_q * seq_len_kv * sizeof(float));
cudaMemcpy(mask_temp, _info.mask, seq_len_q * seq_len_kv * sizeof(float), cudaMemcpyHostToDevice);
```
(Same as above) memory is malloc'd in the calculate stage.
Then please double-check that the offsets used for both writes and reads are correct.
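To make the check concrete, here is a small self-contained sketch of the stride-based indexing in question. The stride names mirror the `launchForwardKernel` arguments (`qo_stride_b` / `qo_stride_s` / `qo_stride_n`); the helper itself is hypothetical, and the point is simply that the write side and the read side must agree on one mapping.

```cpp
#include <cassert>
#include <cstddef>

// Offset of element (b, s, n, d) in a [batch, seq, head, dim] tensor,
// assuming head_dim is contiguous (innermost stride of 1).
inline size_t elemOffset(size_t b, size_t s, size_t n, size_t d,
                         size_t stride_b, size_t stride_s, size_t stride_n) {
    return b * stride_b + s * stride_s + n * stride_n + d;
}

int main() {
    // Strides for a contiguous [2, 10, 8, 4] tensor:
    const size_t stride_n = 4;          // head_dim
    const size_t stride_s = 8 * 4;      // nums_head * head_dim
    const size_t stride_b = 10 * 8 * 4; // seq_len * nums_head * head_dim
    // If writes and reads disagree on this mapping, elements land in the
    // wrong slots even though each side's arithmetic looks consistent.
    assert(elemOffset(1, 3, 2, 1, stride_b, stride_s, stride_n) == 425);
    return 0;
}
```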
```cpp
        hcMalloc(&mask_temp, seq_len_q * seq_len_kv * sizeof(float));
        hcMemcpy(mask_temp, _info.mask, seq_len_q * seq_len_kv * sizeof(float), hcMemcpyHostToDevice);
        mask_input = mask_temp;
    } else {
        mask_input = mask;
    }
}

// ...

size_t T_r = ceil(float(seq_len_q) / B_r);
size_t T_c = ceil(float(seq_len_kv) / B_c);

// ...

auto hc_stream = reinterpret_cast<hcStream_t>(stream);

// ...

void *out, *l;
if (_info.dtype == INFINI_DTYPE_F16) {
    hcMalloc(&out, batch_size * seq_len_kv * nums_head_q * head_dim * sizeof(half));
    hcMalloc(&l, batch_size * seq_len_kv * nums_head_q * sizeof(half));
} else if (_info.dtype == INFINI_DTYPE_F32) {
    hcMalloc(&out, batch_size * seq_len_kv * nums_head_q * head_dim * sizeof(float));
    hcMalloc(&l, batch_size * seq_len_kv * nums_head_q * sizeof(float));
} else if (_info.dtype == INFINI_DTYPE_BF16) {
    hcMalloc(&out, batch_size * seq_len_kv * nums_head_q * head_dim * sizeof(__hpcc_bfloat16));
    hcMalloc(&l, batch_size * seq_len_kv * nums_head_q * sizeof(__hpcc_bfloat16));
} else {
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

// ...

CHECK_STATUS(launchForwardKernel(
    out, l, q, k, v, mask_input,
    batch_size,
    nums_head_q, nums_head_kv,
    seq_len_q, seq_len_kv,
    head_dim, group,
    B_r, B_c, T_r, T_c,
    _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n,
    _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n,
    _info.l_stride_b, _info.l_stride_s, _info.l_stride_n,
    _info.dtype,
    hc_stream));

// ...

void *grad_k_expanded, *grad_v_expanded;
if (_info.dtype == INFINI_DTYPE_F16) {
    hcMalloc(&grad_k_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(half));
    hcMalloc(&grad_v_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(half));
} else if (_info.dtype == INFINI_DTYPE_F32) {
    hcMalloc(&grad_k_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(float));
    hcMalloc(&grad_v_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(float));
} else if (_info.dtype == INFINI_DTYPE_BF16) {
    hcMalloc(&grad_k_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(__hpcc_bfloat16));
    hcMalloc(&grad_v_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(__hpcc_bfloat16));
```
(Same as above) memory is malloc'd in the calculate stage.
My implementation here is quite crude. grad_k and grad_v are accumulated iteratively, block by block. Originally I wanted to compute the final values inside the kernel itself, but different threads overwrite each other's results, so the answer comes out wrong, and I couldn't find a good solution. So I record every partial result separately and use a dedicated kernel to do the accumulation.
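For context, here is a minimal CUDA sketch of the two standard ways to resolve this write conflict: atomic accumulation directly into the shared gradient, and the expanded-buffer-plus-reduction approach this reply describes. The layout assumed for `grad_k_expanded` (the group index as the outer, contiguous reduction axis) is an assumption for the sketch, not necessarily this PR's layout.

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// Option 1: each thread adds its partial gradient atomically, so no
// expanded buffer (and no second kernel) is needed.
__global__ void accumulateAtomic(float *grad_k, const float *partial, size_t n) {
    size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        atomicAdd(&grad_k[i], partial[i]);
    }
}

// Option 2: what the reply describes. Each of the `group` query heads
// sharing a KV head writes its own slice of grad_k_expanded, and a
// separate kernel reduces over the group dimension afterwards.
// Assumed layout: grad_k_expanded[g * n + i] for group index g.
__global__ void reduceExpanded(float *grad_k, const float *grad_k_expanded,
                               size_t n, int group) {
    size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (i < n) {
        float acc = 0.0f;
        for (int g = 0; g < group; ++g) {
            acc += grad_k_expanded[static_cast<size_t>(g) * n + i];
        }
        grad_k[i] = acc;
    }
}
```

The trade-off between the two: atomics need no extra memory but make the floating-point summation order nondeterministic across runs, while the expanded buffer costs `group`× extra memory and a second kernel launch but gives reproducible sums.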
```python
((1, 10, 2, 4), (1, 10, 2, 4), 0),
((4, 10, 8, 4), (4, 10, 2, 4), 1),
```
Please add two somewhat larger test cases (at sizes that actually occur in real models).
```python
from .utils import *
from .datatypes import *
from .structs import *
from .masktypes import *
```
Add a trailing blank line at the end of the file.
Fixed.
```python
class infiniopAttentionMaskType:
    NONE = 0
    FULL = 1
    CAUSAL = 2

# ...

InfiniopAttentionMaskTypeNames = {
    infiniopAttentionMaskType.NONE: "NONE",
    infiniopAttentionMaskType.FULL: "FULL",
    infiniopAttentionMaskType.CAUSAL: "CAUSAL",
}
```
This probably isn't necessary as a separate file; just put it directly in flash_attention.py, following the latest rope implementation.
Fixed.
Completed
Implementation code for all operators in challenge T1-3-1, together with the PyTorch unit tests and GGUF test code. For the operator design documents, see the design-document PR: InfiniTensor/InfiniCore-Documentation#48
Test status
NVIDIA platform
FlashAttention
Implements FlashAttention-2, with support for GQA (see the head-mapping sketch at the end), arbitrary masks, and differing sequence lengths.
FlashAttentionBackward
Only the basic computation is implemented, and some issues remain unresolved.
METAX platform
FlashAttention
Implements FlashAttention-2, with support for GQA, arbitrary masks, and differing sequence lengths.
FlashAttentionBackward
Only the basic computation is implemented, and some issues remain unresolved.
ILUVATAR platform
FlashAttention
Implements FlashAttention-2, with support for GQA, arbitrary masks, and differing sequence lengths.
FlashAttentionBackward
Only the basic computation is implemented, and some issues remain unresolved.
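As a footnote on the GQA support mentioned above, here is a tiny sketch of the usual query-to-KV head mapping, using the head counts from the (4, 10, 8, 4) / (4, 10, 2, 4) test case. The helper is illustrative, not code from this PR.

```cpp
#include <cassert>

// With grouped-query attention, each group of
// group = nums_head_q / nums_head_kv query heads shares one KV head.
inline int kvHeadForQueryHead(int q_head, int group) {
    return q_head / group;
}

int main() {
    // 8 query heads and 2 KV heads, as in the test case above: group = 4.
    const int group = 8 / 2;
    assert(kvHeadForQueryHead(0, group) == 0); // heads 0-3 -> KV head 0
    assert(kvHeadForQueryHead(3, group) == 0);
    assert(kvHeadForQueryHead(4, group) == 1); // heads 4-7 -> KV head 1
    assert(kvHeadForQueryHead(7, group) == 1);
    return 0;
}
```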