-
Notifications
You must be signed in to change notification settings - Fork 35
GEMM reference computation offload #392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
ad748da
11e090b
9006224
3ecea7f
cafee59
86fbbac
54de3db
311ddfe
306e432
445e64f
462945f
e32fb3d
7bf8adb
e11e400
325ece6
f2f386c
5945897
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -51,11 +51,248 @@ using TShape = std::vector<size_t>; | |||||||||||||||||||||||||||||||||||||
| } // namespace | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| float ref_gelu(float x){ | ||||||||||||||||||||||||||||||||||||||
| __device__ __host__ __forceinline__ float ref_gelu(float x){ | ||||||||||||||||||||||||||||||||||||||
| float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); | ||||||||||||||||||||||||||||||||||||||
| return x * cdf; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| template <typename A_Type, typename B_Type, typename Bias_Type, | ||||||||||||||||||||||||||||||||||||||
| typename Gelu_Type, typename D_Type> | ||||||||||||||||||||||||||||||||||||||
| __global__ void compute_ref_kernel( | ||||||||||||||||||||||||||||||||||||||
| const A_Type* __restrict__ a_data, | ||||||||||||||||||||||||||||||||||||||
| const B_Type* __restrict__ b_data, | ||||||||||||||||||||||||||||||||||||||
| float a_scale_inv_scalar, // used when mxfp8 == false | ||||||||||||||||||||||||||||||||||||||
| float b_scale_inv_scalar, | ||||||||||||||||||||||||||||||||||||||
| const fp8e8m0* __restrict__ a_scale_inv_mxfp8, // used when mxfp8 == true | ||||||||||||||||||||||||||||||||||||||
| const fp8e8m0* __restrict__ b_scale_inv_mxfp8, | ||||||||||||||||||||||||||||||||||||||
| const Bias_Type* __restrict__ bias_data, | ||||||||||||||||||||||||||||||||||||||
| float d_scale, | ||||||||||||||||||||||||||||||||||||||
| size_t m, size_t k, size_t n, | ||||||||||||||||||||||||||||||||||||||
| D_Type* __restrict__ d_data, | ||||||||||||||||||||||||||||||||||||||
| float* __restrict__ d_amax, | ||||||||||||||||||||||||||||||||||||||
| Gelu_Type* __restrict__ gelu_data, | ||||||||||||||||||||||||||||||||||||||
| bool transa, | ||||||||||||||||||||||||||||||||||||||
| bool transb, | ||||||||||||||||||||||||||||||||||||||
| bool is_fp8_output) | ||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||
| const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; | ||||||||||||||||||||||||||||||||||||||
| const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const bool in_range = (ii < m) && (jj < n); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| float val = 0.0f; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (in_range) { | ||||||||||||||||||||||||||||||||||||||
| for (size_t kk = 0; kk < k; ++kk) { | ||||||||||||||||||||||||||||||||||||||
| const size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii); | ||||||||||||||||||||||||||||||||||||||
| const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| float a_scale_inv_val = a_scale_inv_scalar; | ||||||||||||||||||||||||||||||||||||||
| float b_scale_inv_val = b_scale_inv_scalar; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (a_scale_inv_mxfp8) { | ||||||||||||||||||||||||||||||||||||||
| const size_t a_scale_idx = | ||||||||||||||||||||||||||||||||||||||
| transa ? (a_idx / 32) : ((kk / 32) * m + ii); | ||||||||||||||||||||||||||||||||||||||
| const size_t b_scale_idx = | ||||||||||||||||||||||||||||||||||||||
| transb ? ((kk / 32) * n + jj) : (b_idx / 32); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const float a_byte = static_cast<float>(a_scale_inv_mxfp8[a_scale_idx]); | ||||||||||||||||||||||||||||||||||||||
| const float b_byte = static_cast<float>(b_scale_inv_mxfp8[b_scale_idx]); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| a_scale_inv_val = exp2f(a_byte - 127.0f); | ||||||||||||||||||||||||||||||||||||||
| b_scale_inv_val = exp2f(b_byte - 127.0f); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const float a_val = static_cast<float>(a_data[a_idx]); | ||||||||||||||||||||||||||||||||||||||
| const float b_val = static_cast<float>(b_data[b_idx]); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (bias_data) { | ||||||||||||||||||||||||||||||||||||||
| val += static_cast<float>(bias_data[ii]); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (gelu_data) { | ||||||||||||||||||||||||||||||||||||||
| gelu_data[ii + jj * m] = static_cast<Gelu_Type>(val); | ||||||||||||||||||||||||||||||||||||||
| val = ref_gelu(val); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const float scaled = val * d_scale; | ||||||||||||||||||||||||||||||||||||||
| d_data[ii + jj * m] = static_cast<D_Type>(scaled); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // Blockwise reduction for amax | ||||||||||||||||||||||||||||||||||||||
| if (is_fp8_output && d_amax) { | ||||||||||||||||||||||||||||||||||||||
| const int tid = threadIdx.y * blockDim.x + threadIdx.x; | ||||||||||||||||||||||||||||||||||||||
| const int nthreads = blockDim.x * blockDim.y; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| extern __shared__ float s_amax[]; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // Out-of-range threads contribute 0 | ||||||||||||||||||||||||||||||||||||||
| s_amax[tid] = in_range ? fabsf(val) : 0.0f; | ||||||||||||||||||||||||||||||||||||||
| __syncthreads(); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| for (int offset = nthreads / 2; offset > 0; offset /= 2) { | ||||||||||||||||||||||||||||||||||||||
| if (tid < offset) { | ||||||||||||||||||||||||||||||||||||||
| s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| __syncthreads(); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (tid == 0) { | ||||||||||||||||||||||||||||||||||||||
| const float block_max = s_amax[0]; | ||||||||||||||||||||||||||||||||||||||
| atomicMax(d_amax, block_max); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // Common implementation used by both tensor-wise and MXFP8 frontends | ||||||||||||||||||||||||||||||||||||||
| template <typename A_Type, typename B_Type, typename Bias_Type, | ||||||||||||||||||||||||||||||||||||||
| typename Gelu_Type, typename D_Type> | ||||||||||||||||||||||||||||||||||||||
| static void compute_ref_impl( | ||||||||||||||||||||||||||||||||||||||
| const A_Type* a_data, | ||||||||||||||||||||||||||||||||||||||
| const B_Type* b_data, | ||||||||||||||||||||||||||||||||||||||
| float a_scale_inv_scalar, // used when mxfp8 == false | ||||||||||||||||||||||||||||||||||||||
| float b_scale_inv_scalar, | ||||||||||||||||||||||||||||||||||||||
| const fp8e8m0* a_scale_inv_mxfp8, // used when mxfp8 == true | ||||||||||||||||||||||||||||||||||||||
| const fp8e8m0* b_scale_inv_mxfp8, | ||||||||||||||||||||||||||||||||||||||
| const Bias_Type* bias_data, | ||||||||||||||||||||||||||||||||||||||
| float d_scale, | ||||||||||||||||||||||||||||||||||||||
| size_t m, size_t k, size_t n, | ||||||||||||||||||||||||||||||||||||||
| D_Type* d_data, | ||||||||||||||||||||||||||||||||||||||
| float* d_amax_host, | ||||||||||||||||||||||||||||||||||||||
| Gelu_Type* gelu_data, | ||||||||||||||||||||||||||||||||||||||
| bool transa, | ||||||||||||||||||||||||||||||||||||||
| bool transb) | ||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||
| using transformer_engine::DType; | ||||||||||||||||||||||||||||||||||||||
| using ::TypeInfo; | ||||||||||||||||||||||||||||||||||||||
| using ::isFp8Type; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const bool use_mxfp8 = (a_scale_inv_mxfp8 != nullptr); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const DType dtype = TypeInfo<D_Type>::dtype; | ||||||||||||||||||||||||||||||||||||||
| const bool is_fp8_output = isFp8Type(dtype); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const size_t lenA = m * k; | ||||||||||||||||||||||||||||||||||||||
| const size_t lenB = k * n; | ||||||||||||||||||||||||||||||||||||||
| const size_t lenD = m * n; | ||||||||||||||||||||||||||||||||||||||
| const size_t lenBias = m; | ||||||||||||||||||||||||||||||||||||||
| const size_t lenGelu = m * n; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const size_t lenA_scale = use_mxfp8 ? (lenA + 31) / 32 : 0; | ||||||||||||||||||||||||||||||||||||||
|
wangye805 marked this conversation as resolved.
Outdated
|
||||||||||||||||||||||||||||||||||||||
| const size_t lenB_scale = use_mxfp8 ? (lenB + 31) / 32 : 0; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| A_Type* dA = nullptr; | ||||||||||||||||||||||||||||||||||||||
| B_Type* dB = nullptr; | ||||||||||||||||||||||||||||||||||||||
| Bias_Type* dBias = nullptr; | ||||||||||||||||||||||||||||||||||||||
| D_Type* dD = nullptr; | ||||||||||||||||||||||||||||||||||||||
| Gelu_Type* dGelu = nullptr; | ||||||||||||||||||||||||||||||||||||||
| float* dAmax = nullptr; | ||||||||||||||||||||||||||||||||||||||
| fp8e8m0* dA_scale = nullptr; | ||||||||||||||||||||||||||||||||||||||
| fp8e8m0* dB_scale = nullptr; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // Allocations and H2D transfers | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMalloc(&dA, lenA * sizeof(A_Type))); | ||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can adapt existing test tensor classes ( TransformerEngine/tests/cpp/test_common.cu Line 226 in 669b556
In fact, we can change the api of reference computing by taking directly const tensor& therefore we don't need to re-allocate the input and do one extra copy
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you think of 3ecea7f? This also merges the mxfp8/non-mxfp8 paths.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for consolidating with existing apis in test_common.cu. In fact, I still see some cudaMalloc and cudaFree, which can be replaced by using existing test tensor class apis.
TransformerEngine/tests/cpp/test_common.cu Lines 321 to 335 in 2bc74c8
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I replaced the remaining raw allocations in the reference path with
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. Yeah, it indeed saved some cudaMalloc/cudaFrees. How about we put the RefD instantiation inside PerformTest, and pass the Tensor RefD (including its RefAmax D) and RefPreGeluOut to run_reference directly (instead of std::unique_ptr<D_Type[]>& ref_D, float* ref_amax_d, std::unique_ptr<Gelu_Type[]>& ref_pre_gelu_out). Then this can save some ref cpu ptr allocation.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you think of 325ece6? |
||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMalloc(&dB, lenB * sizeof(B_Type))); | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMalloc(&dD, lenD * sizeof(D_Type))); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemcpy( | ||||||||||||||||||||||||||||||||||||||
| dA, a_data, lenA * sizeof(A_Type), cudaMemcpyHostToDevice)); | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemcpy( | ||||||||||||||||||||||||||||||||||||||
| dB, b_data, lenB * sizeof(B_Type), cudaMemcpyHostToDevice)); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (bias_data) { | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMalloc(&dBias, lenBias * sizeof(Bias_Type))); | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemcpy( | ||||||||||||||||||||||||||||||||||||||
| dBias, bias_data, lenBias * sizeof(Bias_Type), | ||||||||||||||||||||||||||||||||||||||
| cudaMemcpyHostToDevice)); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (gelu_data) { | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMalloc(&dGelu, lenGelu * sizeof(Gelu_Type))); | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemset(dGelu, 0, lenGelu * sizeof(Gelu_Type))); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (use_mxfp8) { | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMalloc(&dA_scale, lenA_scale * sizeof(fp8e8m0))); | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMalloc(&dB_scale, lenB_scale * sizeof(fp8e8m0))); | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemcpy( | ||||||||||||||||||||||||||||||||||||||
| dA_scale, a_scale_inv_mxfp8, lenA_scale * sizeof(fp8e8m0), | ||||||||||||||||||||||||||||||||||||||
| cudaMemcpyHostToDevice)); | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemcpy( | ||||||||||||||||||||||||||||||||||||||
| dB_scale, b_scale_inv_mxfp8, lenB_scale * sizeof(fp8e8m0), | ||||||||||||||||||||||||||||||||||||||
| cudaMemcpyHostToDevice)); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (is_fp8_output && d_amax_host) { | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMalloc(&dAmax, sizeof(float))); | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemset(dAmax, 0, sizeof(float))); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // Kernel launch | ||||||||||||||||||||||||||||||||||||||
| dim3 block(16, 16); | ||||||||||||||||||||||||||||||||||||||
|
alextmagro marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||||||
| dim3 grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| const int nthreads = block.x * block.y; | ||||||||||||||||||||||||||||||||||||||
| size_t shmem_bytes = nthreads * sizeof(float); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| compute_ref_kernel<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type> | ||||||||||||||||||||||||||||||||||||||
| <<<grid, block, shmem_bytes, 0>>>( | ||||||||||||||||||||||||||||||||||||||
| dA, | ||||||||||||||||||||||||||||||||||||||
| dB, | ||||||||||||||||||||||||||||||||||||||
| a_scale_inv_scalar, | ||||||||||||||||||||||||||||||||||||||
| b_scale_inv_scalar, | ||||||||||||||||||||||||||||||||||||||
| dA_scale, | ||||||||||||||||||||||||||||||||||||||
| dB_scale, | ||||||||||||||||||||||||||||||||||||||
| dBias, | ||||||||||||||||||||||||||||||||||||||
| d_scale, | ||||||||||||||||||||||||||||||||||||||
| m, k, n, | ||||||||||||||||||||||||||||||||||||||
| dD, | ||||||||||||||||||||||||||||||||||||||
| dAmax, | ||||||||||||||||||||||||||||||||||||||
| dGelu, | ||||||||||||||||||||||||||||||||||||||
| transa, | ||||||||||||||||||||||||||||||||||||||
| transb, | ||||||||||||||||||||||||||||||||||||||
| is_fp8_output); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaGetLastError()); | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaDeviceSynchronize()); | ||||||||||||||||||||||||||||||||||||||
|
wangye805 marked this conversation as resolved.
Outdated
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // D2H copies | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemcpy( | ||||||||||||||||||||||||||||||||||||||
| d_data, dD, lenD * sizeof(D_Type), cudaMemcpyDeviceToHost)); | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (gelu_data) { | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemcpy( | ||||||||||||||||||||||||||||||||||||||
| gelu_data, dGelu, lenGelu * sizeof(Gelu_Type), | ||||||||||||||||||||||||||||||||||||||
| cudaMemcpyDeviceToHost)); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if (is_fp8_output && d_amax_host) { | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaMemcpy( | ||||||||||||||||||||||||||||||||||||||
| d_amax_host, dAmax, sizeof(float), cudaMemcpyDeviceToHost)); | ||||||||||||||||||||||||||||||||||||||
| } else if (d_amax_host) { | ||||||||||||||||||||||||||||||||||||||
| *d_amax_host = 0.0f; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| // cleanup | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaFree(dA)); | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaFree(dB)); | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaFree(dD)); | ||||||||||||||||||||||||||||||||||||||
| if (dBias) | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaFree(dBias)); | ||||||||||||||||||||||||||||||||||||||
| if (dGelu) | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaFree(dGelu)); | ||||||||||||||||||||||||||||||||||||||
| if (dAmax) | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaFree(dAmax)); | ||||||||||||||||||||||||||||||||||||||
| if (dA_scale) | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaFree(dA_scale)); | ||||||||||||||||||||||||||||||||||||||
| if (dB_scale) | ||||||||||||||||||||||||||||||||||||||
| NVTE_CHECK_CUDA(cudaFree(dB_scale)); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| template <typename A_Type, typename B_Type, typename Bias_Type, typename Gelu_Type, typename D_Type> | ||||||||||||||||||||||||||||||||||||||
| void compute_ref( | ||||||||||||||||||||||||||||||||||||||
| const A_Type* a_data, | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -71,36 +308,21 @@ void compute_ref( | |||||||||||||||||||||||||||||||||||||
| bool transa, | ||||||||||||||||||||||||||||||||||||||
| bool transb){ | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| float ref_d_amax = 0; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| #pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread) | ||||||||||||||||||||||||||||||||||||||
| for(size_t ii = 0; ii < m; ii++){ | ||||||||||||||||||||||||||||||||||||||
| for(size_t jj = 0; jj < n; jj++){ | ||||||||||||||||||||||||||||||||||||||
| float val = 0; | ||||||||||||||||||||||||||||||||||||||
| for(size_t kk = 0; kk < k; kk++){ | ||||||||||||||||||||||||||||||||||||||
| float a_val = transa ? a_data[kk + ii*k] : a_data[ii + kk*m]; | ||||||||||||||||||||||||||||||||||||||
| float b_val = transb ? b_data[jj + kk*n] : b_data[kk + jj*k]; | ||||||||||||||||||||||||||||||||||||||
| val += a_scale_inv*a_val*b_scale_inv*b_val; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| if(bias_data){ | ||||||||||||||||||||||||||||||||||||||
| val += (float)bias_data[ii]; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| if(ref_gelu_data){ | ||||||||||||||||||||||||||||||||||||||
| ref_gelu_data[ii + jj*m] = (Gelu_Type)(val); | ||||||||||||||||||||||||||||||||||||||
| val = ref_gelu(val); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| ref_d_data[ii+jj*m] = (D_Type)(val*d_scale); | ||||||||||||||||||||||||||||||||||||||
| // update ref_d_amax if in fp8 | ||||||||||||||||||||||||||||||||||||||
| DType dtype = TypeInfo<D_Type>::dtype; | ||||||||||||||||||||||||||||||||||||||
| if(isFp8Type(dtype)){ | ||||||||||||||||||||||||||||||||||||||
| ref_d_amax = std::max(ref_d_amax, std::fabs(val)); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| if (ref_d_amax_ptr) | ||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||
| *ref_d_amax_ptr = ref_d_amax; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| compute_ref_impl<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>( | ||||||||||||||||||||||||||||||||||||||
| a_data, | ||||||||||||||||||||||||||||||||||||||
| b_data, | ||||||||||||||||||||||||||||||||||||||
| /*a_scale_inv_scalar=*/a_scale_inv, | ||||||||||||||||||||||||||||||||||||||
| /*b_scale_inv_scalar=*/b_scale_inv, | ||||||||||||||||||||||||||||||||||||||
| /*a_scale_inv_mxfp8=*/nullptr, | ||||||||||||||||||||||||||||||||||||||
| /*b_scale_inv_mxfp8=*/nullptr, | ||||||||||||||||||||||||||||||||||||||
| bias_data, | ||||||||||||||||||||||||||||||||||||||
| d_scale, | ||||||||||||||||||||||||||||||||||||||
| m, k, n, | ||||||||||||||||||||||||||||||||||||||
| ref_d_data, | ||||||||||||||||||||||||||||||||||||||
| ref_d_amax_ptr, | ||||||||||||||||||||||||||||||||||||||
| ref_gelu_data, | ||||||||||||||||||||||||||||||||||||||
| transa, | ||||||||||||||||||||||||||||||||||||||
| transb); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| template <typename A_Type, typename B_Type, typename Bias_Type, typename Gelu_Type, typename D_Type> | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -118,38 +340,21 @@ void compute_mxfp8_ref( | |||||||||||||||||||||||||||||||||||||
| bool transa, | ||||||||||||||||||||||||||||||||||||||
| bool transb){ | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| float ref_d_amax = 0; | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| #pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread) | ||||||||||||||||||||||||||||||||||||||
| for(size_t ii = 0; ii < m; ii++){ | ||||||||||||||||||||||||||||||||||||||
| for(size_t jj = 0; jj < n; jj++){ | ||||||||||||||||||||||||||||||||||||||
| float val = 0; | ||||||||||||||||||||||||||||||||||||||
| for(size_t kk = 0; kk < k; kk++){ | ||||||||||||||||||||||||||||||||||||||
| size_t a_idx = transa ? (ii*k + kk) : (kk*m + ii); | ||||||||||||||||||||||||||||||||||||||
| size_t b_idx = transb ? (kk*n + jj) : (jj*k + kk); | ||||||||||||||||||||||||||||||||||||||
| float a_scale_inv_val = std::exp2f(a_scale_inv_data[transa ? a_idx/32 : (kk/32 * m + ii)] - 127); | ||||||||||||||||||||||||||||||||||||||
| float b_scale_inv_val = std::exp2f(b_scale_inv_data[transb ? (kk/32 * n + jj) : b_idx/32] - 127); | ||||||||||||||||||||||||||||||||||||||
| val += a_scale_inv_val * (float)a_data[a_idx] * b_scale_inv_val * (float)b_data[b_idx]; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| if(bias_data){ | ||||||||||||||||||||||||||||||||||||||
| val += (float)bias_data[ii]; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| if(ref_gelu_data){ | ||||||||||||||||||||||||||||||||||||||
| ref_gelu_data[ii + jj*m] = (Gelu_Type)(val); | ||||||||||||||||||||||||||||||||||||||
| val = ref_gelu(val); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| ref_d_data[ii+jj*m] = (D_Type)(val*d_scale); | ||||||||||||||||||||||||||||||||||||||
| // update ref_d_amax if in fp8 | ||||||||||||||||||||||||||||||||||||||
| DType dtype = TypeInfo<D_Type>::dtype; | ||||||||||||||||||||||||||||||||||||||
| if(isFp8Type(dtype)){ | ||||||||||||||||||||||||||||||||||||||
| ref_d_amax = std::max(ref_d_amax, std::fabs(val)); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| if (ref_d_amax_ptr) | ||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||
| *ref_d_amax_ptr = ref_d_amax; | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| compute_ref_impl<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>( | ||||||||||||||||||||||||||||||||||||||
| a_data, | ||||||||||||||||||||||||||||||||||||||
| b_data, | ||||||||||||||||||||||||||||||||||||||
| /*a_scale_inv_scalar=*/1.0f, | ||||||||||||||||||||||||||||||||||||||
| /*b_scale_inv_scalar=*/1.0f, | ||||||||||||||||||||||||||||||||||||||
| /*a_scale_inv_mxfp8=*/a_scale_inv_data, | ||||||||||||||||||||||||||||||||||||||
| /*b_scale_inv_mxfp8=*/b_scale_inv_data, | ||||||||||||||||||||||||||||||||||||||
| bias_data, | ||||||||||||||||||||||||||||||||||||||
| d_scale, | ||||||||||||||||||||||||||||||||||||||
| m, k, n, | ||||||||||||||||||||||||||||||||||||||
| ref_d_data, | ||||||||||||||||||||||||||||||||||||||
| ref_d_amax_ptr, | ||||||||||||||||||||||||||||||||||||||
| ref_gelu_data, | ||||||||||||||||||||||||||||||||||||||
| transa, | ||||||||||||||||||||||||||||||||||||||
| transb); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| template <typename Type> | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -371,7 +576,7 @@ void performTest(const TestParams& params) { | |||||||||||||||||||||||||||||||||||||
| pre_gelu_out.to_cpu(); | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| //perform the gemm in CPU | ||||||||||||||||||||||||||||||||||||||
| //perform the reference gemm on GPU | ||||||||||||||||||||||||||||||||||||||
| std::unique_ptr<D_Type[]> ref_D = std::make_unique<D_Type[]>(params.m*params.n); | ||||||||||||||||||||||||||||||||||||||
| std::unique_ptr<Gelu_Type[]> ref_pre_gelu_out; | ||||||||||||||||||||||||||||||||||||||
| if(params.use_gelu){ | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.