-
Notifications
You must be signed in to change notification settings - Fork 757
CPU Overhead Optimizations #2559
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
93ee022
06338bc
50de9cd
5fee841
4c79ac7
62b88e1
99494d7
b157f85
2a7b627
b61a6a8
938651e
88dfdbd
1526eea
5809dcc
b3bd748
30fecf2
1b0d497
138b7bf
eec1e86
8169d9c
6fefaf2
a5feaf9
285dbff
afb2f23
3919cb8
fd36424
4668133
5a00652
739bbad
e8042c1
1d323d7
da7fbf5
e2c7435
f4e2492
beada36
06a72a2
5d21db2
1dfd6fe
8c8dd20
c1acd62
7f35b0b
ca177ae
1538fd9
710b581
8a57a75
7e4f093
8604b69
de44954
cc50745
f2e9a5d
0d75c3e
3d9f673
53e8e4e
c746abd
9c922f5
88b782a
5562cbe
14adf1a
1e28aa8
c651d65
e07b5b3
3f2da29
06ac237
24a8f3d
853ddd5
3db390d
aaf5347
0bf040f
b7d9693
15165b7
369afeb
cb73444
c7bb5ce
f934261
1843f02
89d8d82
63509e6
a77195a
4e92a46
73e4d1d
8282aca
e52a12d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -129,7 +129,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla | |
| ret.Atype = A.data.dtype; | ||
| ret.A_scale_inv = A.scale_inv.dptr; | ||
| ret.lda = is_A_transposed ? k : m; | ||
| if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) { | ||
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); | ||
| if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) { | ||
| // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. | ||
| if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { | ||
| ret.A = A.columnwise_data.dptr; | ||
|
|
@@ -140,7 +141,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla | |
| } else { | ||
| NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); | ||
| } | ||
| } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) { | ||
| } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) { | ||
| // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed | ||
| // data with the mirrored transpose-flag if we don't have row-wise data. | ||
| NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype), | ||
|
|
@@ -220,7 +221,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla | |
| ret.Btype = B.data.dtype; | ||
| ret.B_scale_inv = B.scale_inv.dptr; | ||
| ret.ldb = is_B_transposed ? n : k; | ||
| if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { | ||
| int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); | ||
|
vthumbe1503 marked this conversation as resolved.
Outdated
vthumbe1503 marked this conversation as resolved.
Outdated
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
[P0] Variable […] |
||
| if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) { | ||
| // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. | ||
| if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { | ||
| ret.B = B.columnwise_data.dptr; | ||
|
|
@@ -231,7 +233,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla | |
| } else { | ||
| NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); | ||
| } | ||
| } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) { | ||
| } else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) { | ||
| // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed | ||
| // data with the mirrored transpose-flag if we don't have row-wise data. | ||
| NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype), | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { | |
| } | ||
|
|
||
| int nvte_is_non_tn_fp8_gemm_supported() { | ||
| int num_devices = transformer_engine::cuda::num_devices(); | ||
| static int num_devices = transformer_engine::cuda::num_devices(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Making […] If […] Consider:
```cpp
static int num_devices = transformer_engine::cuda::num_devices();
```
This is initialized once, but […] A safer approach might be:
```cpp
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::vector<std::once_flag> flags;
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
cache.resize(num_devices, -1);
flags.resize(num_devices);
});
```
Or simply document that the device count must not change during the application's lifetime.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Critical: Potential out-of-bounds access with static num_devices
Making […] This will cause out-of-bounds access on lines 968 and 975:
```cpp
std::call_once(flags[device_id], ...); // OOB if device_id >= num_devices
return cache[device_id]; // OOB if device_id >= num_devices
```
Impact: Undefined behavior, potential crashes, memory corruption.
Fix: Query device count each time, or add bounds checking:
```cpp
int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::vector<std::once_flag> flags;
static std::mutex resize_mutex;
int device_id = transformer_engine::cuda::current_device();
{
std::lock_guard<std::mutex> lock(resize_mutex);
if (device_id >= cache.size()) {
cache.resize(device_id + 1, -1);
flags.resize(device_id + 1);
}
}
std::call_once(flags[device_id], [&]() { ... });
```
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Potential race condition in static variable initialization
Making […] Scenario: […]
However, this is likely safe in practice because: […]
But for correctness, consider wrapping in the existing […]
```cpp
static int num_devices = []() {
return transformer_engine::cuda::num_devices();
}();
```
Or initialize it within the […] |
||
| static std::vector<int> cache(num_devices, -1); | ||
| static std::vector<std::once_flag> flags(num_devices); | ||
|
vthumbe1503 marked this conversation as resolved.
Outdated
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Critical bug: Potential out-of-bounds access when […] Making […]
Fix: Either: […]
|
||
| int device_id = transformer_engine::cuda::current_device(); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.