Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion transformer_engine/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ struct QuantizationConfig {
bool nvfp4_2d_quantization = false;
bool stochastic_rounding = false;
bool use_fast_math = false;
NVTETensor tile_scheduler_workspace = nullptr;
Comment thread
zhongbozhu marked this conversation as resolved.
Outdated

static constexpr size_t attr_sizes[] = {
sizeof(bool), // force_pow_2_scales
Expand All @@ -404,7 +405,8 @@ struct QuantizationConfig {
sizeof(NVTETensor), // rng_seed and offset
sizeof(bool), // nvfp4_2d_quantization
sizeof(bool), // stochastic_rounding
sizeof(bool) // use_fast_math
sizeof(bool), // use_fast_math
sizeof(NVTETensor) // tile_scheduler_workspace
};
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1125,8 +1125,9 @@ template <bool kEnableStochasticRounding, bool kEnableRHTColQuant, bool kEnableR
void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_size, TA const *A,
TB const *B, TQA *QA, TSFA *SFA,
MultiAmaxHadamardCastFusionArgs &args,
const size_t *rng_state, uint32_t sm_count,
cudaStream_t stream, int k_tile_size = 1024) {
const size_t *rng_state, uint32_t *tile_scheduler_workspace,
uint32_t sm_count, cudaStream_t stream,
int k_tile_size = 1024) {
using namespace cute;
static int constexpr SFVecSize = 16;
static int constexpr RhtTensorSize = 16;
Expand Down Expand Up @@ -1295,10 +1296,9 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

// Allocate workspace and set to zero
void *tile_scheduler_workspace = nullptr;
NVTE_CHECK_CUDA(cudaMallocAsync(&tile_scheduler_workspace, sizeof(uint32_t), stream));
NVTE_CHECK_CUDA(cudaMemsetAsync(tile_scheduler_workspace, 0, sizeof(uint32_t), stream));
// Zero the caller-provided tile scheduler workspace before launch
NVTE_CHECK_CUDA(cudaMemsetAsync(reinterpret_cast<void *>(tile_scheduler_workspace), 0,
sizeof(uint32_t), stream));

// Launch kernel
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream};
Expand All @@ -1308,8 +1308,6 @@ void group_row_col_rht_gemm_ntt_w_sfc(int packed_sequence_length, int hidden_siz
tile_scheduler_workspace, mma, rng_state);
NVTE_CHECK_CUDA(cudaGetLastError());
NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed.");

NVTE_CHECK_CUDA(cudaFreeAsync(tile_scheduler_workspace, stream));
}

} // namespace
Expand Down Expand Up @@ -1399,6 +1397,17 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tens
rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
}

uint32_t *tile_scheduler_workspace = nullptr;
NVTE_CHECK(quant_config.tile_scheduler_workspace != nullptr,
"Tile scheduler workspace must be provided.");
Tensor &tile_scheduler_workspace_tensor =
*convertNVTETensorCheck(quant_config.tile_scheduler_workspace);
NVTE_CHECK(tile_scheduler_workspace_tensor.dtype() == DType::kInt32 &&
tile_scheduler_workspace_tensor.data.shape == std::vector<size_t>{1},
"Tile scheduler workspace must be a tensor with shape [1] and dtype int32.");
tile_scheduler_workspace =
reinterpret_cast<uint32_t *>(tile_scheduler_workspace_tensor.data.dptr);

// Template arguments
using TA = cute::bfloat16_t;
using TB = cute::bfloat16_t;
Expand Down Expand Up @@ -1461,7 +1470,9 @@ void group_hadamard_transform_cast_fusion(const Tensor &input_, std::vector<Tens
/*QA=*/reinterpret_cast<TQA *>(rowwise_data_base_ptr),
/*SFA=*/reinterpret_cast<TSFA *>(rowwise_scale_inv_base_ptr),
/*args=*/kernel_args,
/*rng_state=*/rng_state, /*sm_count=*/sm_count,
/*rng_state=*/rng_state,
/*tile_scheduler_workspace=*/tile_scheduler_workspace,

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest, I would prefer a more generic name for this workspace. Handling it properly would also require a function that returns the size of the required workspace.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the API level, it is now called quant_workspace.

/*sm_count=*/sm_count,
/*stream=*/stream, /*k_tile_size=*/k_tile_size);
} else {
NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ enum NVTEQuantizationConfigAttribute {
* inconsistently between kernels.
*/
kNVTEQuantizationConfigUseFastMath = 7,
/*! Tile scheduler workspace (NVTETensor with shape [1] and dtype int32; the kernel accesses it as a single uint32_t) */
kNVTEQuantizationConfigTileSchedulerWorkspace = 8,
kNVTEQuantizationConfigNumAttributes
};

Expand Down Expand Up @@ -1009,6 +1011,12 @@ class QuantizationConfigWrapper {
&use_fast_math, sizeof(bool));
}

/*! \brief Attach the tile scheduler workspace tensor to this config.
 *
 *  Forwards the NVTETensor handle to nvte_set_quantization_config_attribute
 *  under kNVTEQuantizationConfigTileSchedulerWorkspace. Only the handle value
 *  (sizeof(NVTETensor) bytes) is passed along.
 *
 *  \param[in] tile_scheduler_workspace Handle of the workspace tensor.
 */
void set_tile_scheduler_workspace(NVTETensor tile_scheduler_workspace) {
  constexpr auto kAttr = kNVTEQuantizationConfigTileSchedulerWorkspace;
  nvte_set_quantization_config_attribute(config_, kAttr, &tile_scheduler_workspace,
                                         sizeof(tile_scheduler_workspace));
}

private:
/*! \brief Wrapped NVTEQuantizationConfig. */
NVTEQuantizationConfig config_ = nullptr;
Expand Down
6 changes: 6 additions & 0 deletions transformer_engine/common/transformer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigUseFastMath:
std::memcpy(buf, &config_.use_fast_math, attr_size);
break;
case kNVTEQuantizationConfigTileSchedulerWorkspace:
std::memcpy(buf, &config_.tile_scheduler_workspace, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
Expand Down Expand Up @@ -949,6 +952,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigUseFastMath:
std::memcpy(&config_.use_fast_math, buf, attr_size);
break;
case kNVTEQuantizationConfigTileSchedulerWorkspace:
std::memcpy(&config_.tile_scheduler_workspace, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
Expand Down
7 changes: 7 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,13 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input,
auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer.rht_matrix);

if (all_aligned_token_dim) {
// allocate a tile scheduler workspace
auto tile_scheduler_workspace_torch =
at::empty({1}, at::device(at::kCUDA).dtype(torch::kInt32));
auto nvte_tile_scheduler_workspace =
makeTransformerEngineTensor(tile_scheduler_workspace_torch);
// assign the workspace tensor
quant_config_list[0].set_tile_scheduler_workspace(nvte_tile_scheduler_workspace.data());
// call the fully-fused grouped kernel for rowwise quantization & colwise RHT quantization transpose
nvte_group_hadamard_transform_cast_fusion(
input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_output_list.data()),
Expand Down
Loading