diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index d2eb1a831c..d785e6a3c0 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -226,13 +226,35 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { } // select a backend for fused attention -NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { +NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend_v2(const NVTEFusedAttnConfig *cfg) { using namespace transformer_engine; + NVTE_CHECK(cfg != nullptr, "NVTEFusedAttnConfig pointer must not be null."); + NVTE_CHECK(cfg->struct_size >= sizeof(NVTEFusedAttnConfig), + "NVTEFusedAttnConfig::struct_size is smaller than the library expects; " + "did you forget NVTE_FUSED_ATTN_CONFIG_INIT?"); + + // Bind config fields to the local names that the original implementation + // below was written against, so the body can stay as-is. + const bool is_training = cfg->is_training; + const NVTEDType q_dtype = cfg->q_dtype; + const NVTEDType kv_dtype = cfg->kv_dtype; + const NVTE_QKV_Layout qkv_layout = cfg->qkv_layout; + const NVTE_Bias_Type bias_type = cfg->bias_type; + const NVTE_Mask_Type attn_mask_type = cfg->attn_mask_type; + const NVTE_Softmax_Type softmax_type = cfg->softmax_type; + const float dropout = cfg->dropout; + const size_t num_attn_heads = cfg->num_attn_heads; + const size_t num_gqa_groups = cfg->num_gqa_groups; + const size_t max_seqlen_q = cfg->max_seqlen_q; + const size_t max_seqlen_kv = cfg->max_seqlen_kv; + const size_t head_dim_qk = cfg->head_dim_qk; + const size_t head_dim_v = cfg->head_dim_v; + const int64_t window_size_left = cfg->window_size_left; + const int64_t window_size_right = cfg->window_size_right; + const bool return_max_logit = cfg->return_max_logit; + const bool cuda_graph = cfg->cuda_graph; + const bool deterministic = cfg->deterministic; + NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); @@ -525,38 +547,81 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( return backend; } +// Deprecated: forwards to nvte_get_fused_attn_backend_v2. +NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { + NVTEFusedAttnConfig cfg = NVTE_FUSED_ATTN_CONFIG_INIT; + cfg.qkv_layout = qkv_layout; + cfg.bias_type = bias_type; + cfg.attn_mask_type = attn_mask_type; + cfg.softmax_type = softmax_type; + cfg.dropout = dropout; + cfg.max_seqlen_q = max_seqlen_q; + cfg.max_seqlen_kv = max_seqlen_kv; + cfg.window_size_left = window_size_left; + cfg.window_size_right = window_size_right; + cfg.cuda_graph = cuda_graph; + cfg.q_dtype = q_dtype; + cfg.kv_dtype = kv_dtype; + cfg.num_attn_heads = num_attn_heads; + cfg.num_gqa_groups = num_gqa_groups; + cfg.head_dim_qk = head_dim_qk; + cfg.head_dim_v = head_dim_v; + cfg.is_training = is_training; + cfg.return_max_logit = return_max_logit; + cfg.deterministic = deterministic; + return nvte_get_fused_attn_backend_v2(&cfg); +} + // NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd_v2(const NVTEFusedAttnFwdParams *params) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_page_table_k = convertNVTETensorCheck(page_table_k); - const Tensor *input_page_table_v = convertNVTETensorCheck(page_table_v); - const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_K = convertNVTETensorCheck(K); - const Tensor *input_V = convertNVTETensorCheck(V); - const Tensor *input_Bias = convertNVTETensorCheck(Bias); - const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); - Tensor *input_output_S = convertNVTETensorCheck(S); - Tensor *output_O = convertNVTETensorCheck(O); - Tensor *wkspace = convertNVTETensor(workspace); + NVTE_CHECK(params != nullptr, "NVTEFusedAttnFwdParams pointer must not be null."); + NVTE_CHECK(params->struct_size >= sizeof(NVTEFusedAttnFwdParams), + "NVTEFusedAttnFwdParams::struct_size is smaller than the library expects; " + "did you forget NVTE_FUSED_ATTN_FWD_PARAMS_INIT?"); + + // Bind struct fields to the local names that the original implementation + // below was written against. + const NVTE_QKV_Layout qkv_layout = params->qkv_layout; + const NVTE_QKV_Format o_format = params->o_format; + const NVTE_QKV_Format qkv_scale_inv_format = params->qkv_scale_inv_format; + const NVTE_Bias_Type bias_type = params->bias_type; + const NVTE_Mask_Type attn_mask_type = params->attn_mask_type; + const NVTE_Softmax_Type softmax_type = params->softmax_type; + const size_t max_seqlen_q = params->max_seqlen_q; + const size_t max_seqlen_kv = params->max_seqlen_kv; + const float attn_scale = params->attn_scale; + const float dropout = params->dropout; + const int64_t window_size_left = params->window_size_left; + const int64_t window_size_right = params->window_size_right; + const bool bottom_right_diagonal = params->bottom_right_diagonal; + const bool is_training = params->is_training; + const bool return_max_logit = params->return_max_logit; + const bool cuda_graph = params->cuda_graph; + NVTETensorPack *Aux_CTX_Tensors = params->Aux_CTX_Tensors; + cudaStream_t stream = params->stream; + + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(params->cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(params->cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(params->cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(params->cu_seqlens_kv_padded); + const Tensor *input_page_table_k = convertNVTETensorCheck(params->page_table_k); + const Tensor *input_page_table_v = convertNVTETensorCheck(params->page_table_v); + const Tensor *input_rng_state = convertNVTETensorCheck(params->rng_state); + const Tensor *input_Q = convertNVTETensorCheck(params->Q); + const Tensor *input_K = convertNVTETensorCheck(params->K); + const Tensor *input_V = convertNVTETensorCheck(params->V); + const Tensor *input_Bias = convertNVTETensorCheck(params->Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(params->SoftmaxOffset); + Tensor *input_output_S = convertNVTETensorCheck(params->S); + Tensor *output_O = convertNVTETensorCheck(params->O); + Tensor *wkspace = convertNVTETensor(params->workspace); NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); @@ -607,11 +672,54 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); + const NVTEDType O_type = static_cast(output_O->data.dtype); + + size_t bias_batch_size = 0, bias_num_heads = 0, bias_seqlen_q = 0, bias_seqlen_kv = 0; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && + input_Bias->data.dptr != nullptr && input_Bias->data.shape.size() >= 4) { + bias_batch_size = input_Bias->data.shape[0]; + bias_num_heads = input_Bias->data.shape[1]; + bias_seqlen_q = input_Bias->data.shape[2]; + bias_seqlen_kv = input_Bias->data.shape[3]; + } - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_logit, cuda_graph, false); + NVTEFusedAttnConfig backend_cfg = NVTE_FUSED_ATTN_CONFIG_INIT; + backend_cfg.qkv_layout = qkv_layout; + backend_cfg.o_format = o_format; + backend_cfg.qkv_scale_inv_format = qkv_scale_inv_format; + backend_cfg.bias_type = bias_type; + backend_cfg.attn_mask_type = attn_mask_type; + backend_cfg.softmax_type = softmax_type; + backend_cfg.attn_scale = attn_scale; + backend_cfg.dropout = dropout; + backend_cfg.max_seqlen_q = max_seqlen_q; + backend_cfg.max_seqlen_kv = max_seqlen_kv; + backend_cfg.window_size_left = window_size_left; + backend_cfg.window_size_right = window_size_right; + backend_cfg.bottom_right_diagonal = bottom_right_diagonal; + backend_cfg.cuda_graph = cuda_graph; + backend_cfg.q_dtype = Q_type; + backend_cfg.kv_dtype = KV_type; + backend_cfg.o_dtype = O_type; + backend_cfg.batch_size = b; + backend_cfg.num_attn_heads = h_q; + backend_cfg.num_gqa_groups = h_kv; + backend_cfg.head_dim_qk = d_qk; + backend_cfg.head_dim_v = d_v; + backend_cfg.num_pages_k = num_pages_k; + backend_cfg.num_pages_v = num_pages_v; + backend_cfg.page_size_k = page_size_k; + backend_cfg.page_size_v = page_size_v; + backend_cfg.max_pages_per_seq_k = max_pages_per_seq_k; + backend_cfg.max_pages_per_seq_v = max_pages_per_seq_v; + backend_cfg.bias_batch_size = bias_batch_size; + backend_cfg.bias_num_heads = bias_num_heads; + backend_cfg.bias_seqlen_q = bias_seqlen_q; + backend_cfg.bias_seqlen_kv = bias_seqlen_kv; + backend_cfg.is_training = is_training; + backend_cfg.return_max_logit = return_max_logit; + backend_cfg.deterministic = false; + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend_v2(&backend_cfg); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { fused_attn_arbitrary_seqlen_fwd( @@ -633,41 +741,107 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } } -// NVTE fused attention BWD with separate Q, K and V -void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + +// Deprecated: forwards to nvte_fused_attn_fwd_v2. +void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, - size_t max_seqlen_kv, float attn_scale, float dropout, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, - NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, - bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { + NVTE_QKV_Format qkv_scale_inv_format, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) { + NVTEFusedAttnFwdParams params = NVTE_FUSED_ATTN_FWD_PARAMS_INIT; + params.Q = Q; + params.K = K; + params.V = V; + params.Bias = Bias; + params.SoftmaxOffset = SoftmaxOffset; + params.cu_seqlens_q = cu_seqlens_q; + params.cu_seqlens_kv = cu_seqlens_kv; + params.cu_seqlens_q_padded = cu_seqlens_q_padded; + params.cu_seqlens_kv_padded = cu_seqlens_kv_padded; + params.page_table_k = page_table_k; + params.page_table_v = page_table_v; + params.rng_state = rng_state; + params.S = S; + params.O = O; + params.Aux_CTX_Tensors = Aux_CTX_Tensors; + params.max_seqlen_q = max_seqlen_q; + params.max_seqlen_kv = max_seqlen_kv; + params.qkv_layout = qkv_layout; + params.o_format = o_format; + params.qkv_scale_inv_format = qkv_scale_inv_format; + params.bias_type = bias_type; + params.attn_mask_type = attn_mask_type; + params.softmax_type = softmax_type; + params.attn_scale = attn_scale; + params.dropout = dropout; + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.bottom_right_diagonal = bottom_right_diagonal; + params.is_training = is_training; + params.return_max_logit = return_max_logit; + params.cuda_graph = cuda_graph; + params.workspace = workspace; + params.stream = stream; + nvte_fused_attn_fwd_v2(¶ms); +} +// NVTE fused attention BWD with separate Q, K and V +void nvte_fused_attn_bwd_v2(const NVTEFusedAttnBwdParams *params) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; - const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); - const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv); - const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded); - const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded); - const Tensor *input_Q = convertNVTETensorCheck(Q); - const Tensor *input_K = convertNVTETensorCheck(K); - const Tensor *input_V = convertNVTETensorCheck(V); - const Tensor *input_O = convertNVTETensorCheck(O); - const Tensor *input_dO = convertNVTETensorCheck(dO); - const Tensor *input_S = convertNVTETensorCheck(S); - Tensor *input_output_dP = convertNVTETensorCheck(dP); - Tensor *output_dQ = convertNVTETensorCheck(dQ); - Tensor *output_dK = convertNVTETensorCheck(dK); - Tensor *output_dV = convertNVTETensorCheck(dV); - Tensor *output_dBias = convertNVTETensorCheck(dBias); - Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); - Tensor *wkspace = convertNVTETensor(workspace); + NVTE_CHECK(params != nullptr, "NVTEFusedAttnBwdParams pointer must not be null."); + NVTE_CHECK(params->struct_size >= sizeof(NVTEFusedAttnBwdParams), + "NVTEFusedAttnBwdParams::struct_size is smaller than the library expects; " + "did you forget NVTE_FUSED_ATTN_BWD_PARAMS_INIT?"); + + // Bind struct fields to the local names that the original implementation + // below was written against. + const NVTE_QKV_Layout qkv_layout = params->qkv_layout; + const NVTE_QKV_Layout dqkv_layout = params->dqkv_layout; + const NVTE_QKV_Format o_format = params->o_format; + const NVTE_QKV_Format do_format = params->do_format; + const NVTE_QKV_Format qkv_scale_inv_format = params->qkv_scale_inv_format; + const NVTE_QKV_Format do_scale_inv_format = params->do_scale_inv_format; + const NVTE_Bias_Type bias_type = params->bias_type; + const NVTE_Mask_Type attn_mask_type = params->attn_mask_type; + const NVTE_Softmax_Type softmax_type = params->softmax_type; + const size_t max_seqlen_q = params->max_seqlen_q; + const size_t max_seqlen_kv = params->max_seqlen_kv; + const float attn_scale = params->attn_scale; + const float dropout = params->dropout; + const int64_t window_size_left = params->window_size_left; + const int64_t window_size_right = params->window_size_right; + const bool bottom_right_diagonal = params->bottom_right_diagonal; + const bool deterministic = params->deterministic; + const bool cuda_graph = params->cuda_graph; + const NVTETensorPack *Aux_CTX_Tensors = params->Aux_CTX_Tensors; + cudaStream_t stream = params->stream; + + const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(params->cu_seqlens_q); + const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(params->cu_seqlens_kv); + const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(params->cu_seqlens_q_padded); + const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(params->cu_seqlens_kv_padded); + const Tensor *input_Q = convertNVTETensorCheck(params->Q); + const Tensor *input_K = convertNVTETensorCheck(params->K); + const Tensor *input_V = convertNVTETensorCheck(params->V); + const Tensor *input_O = convertNVTETensorCheck(params->O); + const Tensor *input_dO = convertNVTETensorCheck(params->dO); + const Tensor *input_S = convertNVTETensorCheck(params->S); + Tensor *input_output_dP = convertNVTETensorCheck(params->dP); + Tensor *output_dQ = convertNVTETensorCheck(params->dQ); + Tensor *output_dK = convertNVTETensorCheck(params->dK); + Tensor *output_dV = convertNVTETensorCheck(params->dV); + Tensor *output_dBias = convertNVTETensorCheck(params->dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(params->dSoftmaxOffset); + Tensor *wkspace = convertNVTETensor(params->workspace); NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); @@ -688,11 +862,55 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); + const NVTEDType O_type = static_cast(input_O->data.dtype); + const NVTEDType dO_type = static_cast(input_dO->data.dtype); + const NVTEDType dQKV_type = static_cast(output_dQ->data.dtype); + + size_t bias_batch_size = 0, bias_num_heads = 0, bias_seqlen_q = 0, bias_seqlen_kv = 0; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && + output_dBias->data.dptr != nullptr && output_dBias->data.shape.size() >= 4) { + bias_batch_size = output_dBias->data.shape[0]; + bias_num_heads = output_dBias->data.shape[1]; + bias_seqlen_q = output_dBias->data.shape[2]; + bias_seqlen_kv = output_dBias->data.shape[3]; + } - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, - cuda_graph, deterministic); + NVTEFusedAttnConfig backend_cfg = NVTE_FUSED_ATTN_CONFIG_INIT; + backend_cfg.qkv_layout = qkv_layout; + backend_cfg.o_format = o_format; + backend_cfg.do_format = do_format; + backend_cfg.dqkv_layout = dqkv_layout; + backend_cfg.qkv_scale_inv_format = qkv_scale_inv_format; + backend_cfg.do_scale_inv_format = do_scale_inv_format; + backend_cfg.bias_type = bias_type; + backend_cfg.attn_mask_type = attn_mask_type; + backend_cfg.softmax_type = softmax_type; + backend_cfg.attn_scale = attn_scale; + backend_cfg.dropout = dropout; + backend_cfg.max_seqlen_q = max_seqlen_q; + backend_cfg.max_seqlen_kv = max_seqlen_kv; + backend_cfg.window_size_left = window_size_left; + backend_cfg.window_size_right = window_size_right; + backend_cfg.bottom_right_diagonal = bottom_right_diagonal; + backend_cfg.cuda_graph = cuda_graph; + backend_cfg.q_dtype = Q_type; + backend_cfg.kv_dtype = KV_type; + backend_cfg.o_dtype = O_type; + backend_cfg.do_dtype = dO_type; + backend_cfg.dqkv_dtype = dQKV_type; + backend_cfg.batch_size = b; + backend_cfg.num_attn_heads = h_q; + backend_cfg.num_gqa_groups = h_kv; + backend_cfg.head_dim_qk = d_qk; + backend_cfg.head_dim_v = d_v; + backend_cfg.bias_batch_size = bias_batch_size; + backend_cfg.bias_num_heads = bias_num_heads; + backend_cfg.bias_seqlen_q = bias_seqlen_q; + backend_cfg.bias_seqlen_kv = bias_seqlen_kv; + backend_cfg.is_training = true; + backend_cfg.return_max_logit = false; + backend_cfg.deterministic = deterministic; + NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend_v2(&backend_cfg); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { size_t i = 0; @@ -738,6 +956,63 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } } +// Deprecated: forwards to nvte_fused_attn_bwd_v2. +void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, + const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, + size_t max_seqlen_kv, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, + NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, + NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { + NVTEFusedAttnBwdParams params = NVTE_FUSED_ATTN_BWD_PARAMS_INIT; + params.Q = Q; + params.K = K; + params.V = V; + params.O = O; + params.dO = dO; + params.S = S; + params.cu_seqlens_q = cu_seqlens_q; + params.cu_seqlens_kv = cu_seqlens_kv; + params.cu_seqlens_q_padded = cu_seqlens_q_padded; + params.cu_seqlens_kv_padded = cu_seqlens_kv_padded; + params.Aux_CTX_Tensors = Aux_CTX_Tensors; + params.dP = dP; + params.dQ = dQ; + params.dK = dK; + params.dV = dV; + params.dBias = dBias; + params.dSoftmaxOffset = dSoftmaxOffset; + params.max_seqlen_q = max_seqlen_q; + params.max_seqlen_kv = max_seqlen_kv; + params.qkv_layout = qkv_layout; + params.dqkv_layout = dqkv_layout; + params.o_format = o_format; + params.do_format = do_format; + params.qkv_scale_inv_format = qkv_scale_inv_format; + params.do_scale_inv_format = do_scale_inv_format; + params.bias_type = bias_type; + params.attn_mask_type = attn_mask_type; + params.softmax_type = softmax_type; + params.attn_scale = attn_scale; + params.dropout = dropout; + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.bottom_right_diagonal = bottom_right_diagonal; + params.deterministic = deterministic; + params.cuda_graph = cuda_graph; + params.workspace = workspace; + params.stream = stream; + nvte_fused_attn_bwd_v2(¶ms); +} + uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len, cudaStream_t stream) { NVTE_API_CALL(nvte_get_runtime_num_segments); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 6df7ad35c8..7345cd18d6 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -114,44 +114,37 @@ void fused_attn_arbitrary_seqlen_fwd_impl( bool generate_stats = true; // Always return stats try { FADescriptor_v1 descriptor{ - b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, - num_pages_k, - num_pages_v, - page_size_k, - page_size_v, - max_pages_per_seq_k, - max_pages_per_seq_v, - bias_b, - bias_h, - bias_sq, - bias_skv, - scaling_factor, - is_training, - dropout_probability, - qkv_layout, - o_format, - NVTE_QKV_Format_NOT_SET, - NVTE_QKV_Layout_NOT_SET, - NVTE_QKV_Format_NOT_SET, - NVTE_QKV_Format_NOT_SET, - bias_type, - mask_type, - softmax_type, - window_size_left, - window_size_right, - bottom_right_diagonal, - true, - tensorType, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET, - return_max_logit, + .b = b, + .h = h, + .hg = hg, + .s_q = s_q, + .s_kv = s_kv, + .d_qk = d_qk, + .d_v = d_v, + .num_pages_k = num_pages_k, + .num_pages_v = num_pages_v, + .page_size_k = page_size_k, + .page_size_v = page_size_v, + .max_pages_per_seq_k = max_pages_per_seq_k, + .max_pages_per_seq_v = max_pages_per_seq_v, + .bias_b = bias_b, + .bias_h = bias_h, + .bias_sq = bias_sq, + .bias_skv = bias_skv, + .attnScale = scaling_factor, + .isTraining = is_training, + .dropoutProbability = dropout_probability, + .qkv_layout = qkv_layout, + .o_format = o_format, + .bias_type = bias_type, + .mask_type = mask_type, + .softmax_type = softmax_type, + .window_size_left = window_size_left, + .window_size_right = window_size_right, + .bottom_right_diagonal = bottom_right_diagonal, + .deterministic = true, + .qkv_tensor_type = tensorType, + .return_max_logit = return_max_logit, }; namespace fe = cudnn_frontend; @@ -617,44 +610,32 @@ void fused_attn_arbitrary_seqlen_bwd_impl( try { FADescriptor_v1 descriptor{ - b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, - 0, - 0, - 0, - 0, - 0, - 0, - bias_b, - bias_h, - bias_sq, - bias_skv, - scaling_factor, - true, - dropout_probability, - qkv_layout, - o_format, - do_format, - dqkv_layout, - NVTE_QKV_Format_NOT_SET, - NVTE_QKV_Format_NOT_SET, - bias_type, - mask_type, - softmax_type, - window_size_left, - window_size_right, - bottom_right_diagonal, - deterministic, - tensorType, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET, - false, + .b = b, + .h = h, + .hg = hg, + .s_q = s_q, + .s_kv = s_kv, + .d_qk = d_qk, + .d_v = d_v, + .bias_b = bias_b, + .bias_h = bias_h, + .bias_sq = bias_sq, + .bias_skv = bias_skv, + .attnScale = scaling_factor, + .isTraining = true, + .dropoutProbability = dropout_probability, + .qkv_layout = qkv_layout, + .o_format = o_format, + .do_format = do_format, + .dqkv_layout = dqkv_layout, + .bias_type = bias_type, + .mask_type = mask_type, + .softmax_type = softmax_type, + .window_size_left = window_size_left, + .window_size_right = window_size_right, + .bottom_right_diagonal = bottom_right_diagonal, + .deterministic = deterministic, + .qkv_tensor_type = tensorType, }; namespace fe = cudnn_frontend; diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index eab1ae02e6..11b53c6fc4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -61,44 +61,34 @@ void fused_attn_fp8_fwd_impl( "MXFP8 fused attention requires cuDNN 9.21.0 or later!"); try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, - 0, - 0, - 0, - 0, - 0, - 0, - bias_b, - bias_h, - bias_sq, - bias_skv, - scaling_factor, - is_training, - dropout_probability, - qkv_layout, - o_format, - NVTE_QKV_Format_NOT_SET, - NVTE_QKV_Layout_NOT_SET, - qkv_scale_inv_format, - NVTE_QKV_Format_NOT_SET, - bias_type, - mask_type, - softmax_type, - window_size_left, - window_size_right, - bottom_right_diagonal, - true, - qkv_tensor_type, - o_tensor_type, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET, - false}; + FADescriptor_v1 descriptor{ + .b = b, + .h = h, + .hg = hg, + .s_q = s_q, + .s_kv = s_kv, + .d_qk = d_qk, + .d_v = d_v, + .bias_b = bias_b, + .bias_h = bias_h, + .bias_sq = bias_sq, + .bias_skv = bias_skv, + .attnScale = scaling_factor, + .isTraining = is_training, + .dropoutProbability = dropout_probability, + .qkv_layout = qkv_layout, + .o_format = o_format, + .qkv_scale_inv_format = qkv_scale_inv_format, + .bias_type = bias_type, + .mask_type = mask_type, + .softmax_type = softmax_type, + .window_size_left = window_size_left, + .window_size_right = window_size_right, + .bottom_right_diagonal = bottom_right_diagonal, + .deterministic = true, + .qkv_tensor_type = qkv_tensor_type, + .o_tensor_type = o_tensor_type, + }; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -501,44 +491,39 @@ void fused_attn_fp8_bwd_impl( o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, - 0, - 0, - 0, - 0, - 0, - 0, - bias_b, - bias_h, - bias_sq, - bias_skv, - scaling_factor, - true, - dropout_probability, - qkv_layout, - o_format, - do_format, - dqkv_layout, - qkv_scale_inv_format, - do_scale_inv_format, - bias_type, - mask_type, - softmax_type, - window_size_left, - window_size_right, - bottom_right_diagonal, - deterministic, - qkv_tensor_type, - o_tensor_type, - do_tensor_type, - dqkv_tensor_type, - false}; + FADescriptor_v1 descriptor{ + .b = b, + .h = h, + .hg = hg, + .s_q = s_q, + .s_kv = s_kv, + .d_qk = d_qk, + .d_v = d_v, + .bias_b = bias_b, + .bias_h = bias_h, + .bias_sq = bias_sq, + .bias_skv = bias_skv, + .attnScale = scaling_factor, + .isTraining = true, + .dropoutProbability = dropout_probability, + .qkv_layout = qkv_layout, + .o_format = o_format, + .do_format = do_format, + .dqkv_layout = dqkv_layout, + .qkv_scale_inv_format = qkv_scale_inv_format, + .do_scale_inv_format = do_scale_inv_format, + .bias_type = bias_type, + .mask_type = mask_type, + .softmax_type = softmax_type, + .window_size_left = window_size_left, + .window_size_right = window_size_right, + .bottom_right_diagonal = bottom_right_diagonal, + .deterministic = deterministic, + .qkv_tensor_type = qkv_tensor_type, + .o_tensor_type = o_tensor_type, + .do_tensor_type = do_tensor_type, + .dqkv_tensor_type = dqkv_tensor_type, + }; namespace fe = cudnn_frontend; using graph_and_tensors = diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 41656062a4..de1a5a85b9 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -13,6 +13,7 @@ #include #include +#include #include "../common.h" #include "transformer_engine/fused_attn.h" @@ -273,64 +274,63 @@ struct FADescriptor { } }; +// Cache key for cuDNN graph plans. All members must participate in equality / ordering; +// adding a new field that affects graph identity is a single-line change here: +// declare it with a sensible default, then add it once to `as_tuple()`. struct FADescriptor_v1 { - std::int64_t b; - std::int64_t h; - std::int64_t hg; - std::int64_t s_q; - std::int64_t s_kv; - std::int64_t d_qk; - std::int64_t d_v; - std::int64_t num_pages_k; - std::int64_t num_pages_v; - std::int64_t page_size_k; - std::int64_t page_size_v; - std::int64_t max_pages_per_seq_k; - std::int64_t max_pages_per_seq_v; - std::int64_t bias_b; - std::int64_t bias_h; - std::int64_t bias_sq; - std::int64_t bias_skv; - float attnScale; - bool isTraining; - float dropoutProbability; - NVTE_QKV_Layout qkv_layout; - NVTE_QKV_Format o_format; - NVTE_QKV_Format do_format; - NVTE_QKV_Layout dqkv_layout; - NVTE_QKV_Format qkv_scale_inv_format; - NVTE_QKV_Format do_scale_inv_format; - NVTE_Bias_Type bias_type; - NVTE_Mask_Type mask_type; - NVTE_Softmax_Type softmax_type; - std::int64_t window_size_left; - std::int64_t window_size_right; - bool bottom_right_diagonal; - bool deterministic; - cudnn_frontend::DataType_t qkv_tensor_type; - cudnn_frontend::DataType_t o_tensor_type; - cudnn_frontend::DataType_t do_tensor_type; - cudnn_frontend::DataType_t dqkv_tensor_type; - bool return_max_logit; - - bool operator<(const FADescriptor_v1 &rhs) const { + std::int64_t b = 0; + std::int64_t h = 0; + std::int64_t hg = 0; + std::int64_t s_q = 0; + std::int64_t s_kv = 0; + std::int64_t d_qk = 0; + std::int64_t d_v = 0; + std::int64_t num_pages_k = 0; + std::int64_t num_pages_v = 0; + std::int64_t page_size_k = 0; + std::int64_t page_size_v = 0; + std::int64_t max_pages_per_seq_k = 0; + std::int64_t max_pages_per_seq_v = 0; + std::int64_t bias_b = 0; + std::int64_t bias_h = 0; + std::int64_t bias_sq = 0; + std::int64_t bias_skv = 0; + float attnScale = 0.0f; + bool isTraining = false; + float dropoutProbability = 0.0f; + NVTE_QKV_Layout qkv_layout = NVTE_QKV_Layout_NOT_SET; + NVTE_QKV_Format o_format = NVTE_QKV_Format_NOT_SET; + NVTE_QKV_Format do_format = NVTE_QKV_Format_NOT_SET; + NVTE_QKV_Layout dqkv_layout = NVTE_QKV_Layout_NOT_SET; + NVTE_QKV_Format qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET; + NVTE_QKV_Format do_scale_inv_format = NVTE_QKV_Format_NOT_SET; + NVTE_Bias_Type bias_type = NVTE_NO_BIAS; + NVTE_Mask_Type mask_type = NVTE_NO_MASK; + NVTE_Softmax_Type softmax_type = NVTE_VANILLA_SOFTMAX; + std::int64_t window_size_left = 0; + std::int64_t window_size_right = 0; + bool bottom_right_diagonal = false; + bool deterministic = false; + cudnn_frontend::DataType_t qkv_tensor_type = cudnn_frontend::DataType_t::NOT_SET; + cudnn_frontend::DataType_t o_tensor_type = cudnn_frontend::DataType_t::NOT_SET; + cudnn_frontend::DataType_t do_tensor_type = cudnn_frontend::DataType_t::NOT_SET; + cudnn_frontend::DataType_t dqkv_tensor_type = cudnn_frontend::DataType_t::NOT_SET; + bool return_max_logit = false; + + // Single source of truth for which fields participate in equality / ordering. + // When you add a new field above, append it here exactly once. + auto as_tuple() const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq, bias_skv, attnScale, isTraining, dropoutProbability, qkv_layout, o_format, - do_format, dqkv_layout, qkv_scale_inv_format, do_scale_inv_format, mask_type, - softmax_type, window_size_left, window_size_right, bottom_right_diagonal, - deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, - dqkv_tensor_type, return_max_logit) < - std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, - rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, - rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv, - rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.qkv_layout, - rhs.o_format, rhs.do_format, rhs.dqkv_layout, rhs.qkv_scale_inv_format, - rhs.do_scale_inv_format, rhs.mask_type, rhs.softmax_type, rhs.window_size_left, - rhs.window_size_right, rhs.bottom_right_diagonal, rhs.deterministic, - rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, - rhs.dqkv_tensor_type, rhs.return_max_logit); + do_format, dqkv_layout, qkv_scale_inv_format, do_scale_inv_format, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, qkv_tensor_type, o_tensor_type, + do_tensor_type, dqkv_tensor_type, return_max_logit); } + + bool operator<(const FADescriptor_v1 &rhs) const { return as_tuple() < rhs.as_tuple(); } + bool operator==(const FADescriptor_v1 &rhs) const { return as_tuple() == rhs.as_tuple(); } }; __global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index d9d2786623..abc666d332 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -194,7 +194,123 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); */ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); +/*! \struct NVTEFusedAttnConfig + * \brief Inputs to nvte_get_fused_attn_backend_v2(). + * + * Holds algorithm/policy fields plus tensor-derived shape and dtype metadata + * needed to determine which fused attention backend supports a given + * configuration, and to uniquely identify the cuDNN execution plan that will + * be cached for it. Field values are passed directly so backend support can + * be probed without allocating any tensors. + * + * Distinct from NVTEFusedAttnFwdParams / NVTEFusedAttnBwdParams, which + * additionally bind tensors and execution context for an actual call. + * + * Field naming follows snake_case throughout. Direction-only fields are + * grouped together: callers querying for forward should leave bwd-only fields + * (do_format, dqkv_layout, do_scale_inv_format, do_dtype, dqkv_dtype, + * deterministic) at their defaults, and vice versa. + * + * Versioning rules: + * - struct_size MUST be set to sizeof(NVTEFusedAttnConfig) by the caller + * (use NVTE_FUSED_ATTN_CONFIG_INIT). + * - New fields may only be appended at the end; existing fields are never + * reordered, removed, or resized. The library reads only fields that are + * in range according to struct_size and uses safe defaults otherwise. + */ +typedef struct NVTEFusedAttnConfig { + size_t struct_size; /*!< MUST equal sizeof(NVTEFusedAttnConfig). */ + uint32_t reserved0; /*!< Padding for layout stability; set to 0. */ + uint32_t reserved1; /*!< Padding for layout stability; set to 0. */ + + /* ---- algorithm / policy ---- */ + NVTE_QKV_Layout qkv_layout; /*!< QKV tensors' layout. */ + NVTE_QKV_Format o_format; /*!< Output O tensor format. */ + NVTE_QKV_Format do_format; /*!< Output-grad dO tensor format (bwd). */ + NVTE_QKV_Layout dqkv_layout; /*!< Gradient dQKV tensor layout (bwd). */ + NVTE_QKV_Format qkv_scale_inv_format; /*!< QKV scale_inv tensor format (FP8). */ + NVTE_QKV_Format do_scale_inv_format; /*!< dO scale_inv tensor format (FP8 bwd). */ + NVTE_Bias_Type bias_type; /*!< Attention bias type. */ + NVTE_Mask_Type attn_mask_type; /*!< Attention mask type. */ + NVTE_Softmax_Type softmax_type; /*!< Attention softmax type. */ + float attn_scale; /*!< Pre-softmax attention scale factor. */ + float dropout; /*!< Dropout probability. */ + size_t max_seqlen_q; /*!< Max sequence length for Q. */ + size_t max_seqlen_kv; /*!< Max sequence length for K, V. */ + int64_t window_size_left; /*!< Sliding window size (left half); -1 = unlimited. */ + int64_t window_size_right; /*!< Sliding window size (right half); -1 = unlimited. */ + bool bottom_right_diagonal; /*!< Whether causal mask aligns to the bottom-right diagonal. */ + bool cuda_graph; /*!< Whether CUDA graph capture is enabled. */ + + /* ---- tensor-derived metadata (passed as values; no tensor required) ---- */ + NVTEDType q_dtype; /*!< Data type of Tensor Q. */ + NVTEDType kv_dtype; /*!< Data type of Tensors K, V. */ + NVTEDType o_dtype; /*!< Data type of Tensor O. */ + NVTEDType do_dtype; /*!< Data type of Tensor dO (bwd). */ + NVTEDType dqkv_dtype; /*!< Data type of Tensors dQ, dK, dV (bwd). */ + size_t batch_size; /*!< Batch size. */ + size_t num_attn_heads; /*!< Number of heads in Q. */ + size_t num_gqa_groups; /*!< Number of heads in K, V. */ + size_t head_dim_qk; /*!< Head dimension of Q, K. */ + size_t head_dim_v; /*!< Head dimension of V. */ + + /* paged KV cache shape (set 0 when not using paged attention). */ + size_t num_pages_k; /*!< Total number of K cache pages. */ + size_t num_pages_v; /*!< Total number of V cache pages. */ + size_t page_size_k; /*!< Tokens per K cache page. */ + size_t page_size_v; /*!< Tokens per V cache page. */ + size_t max_pages_per_seq_k; /*!< Max K pages per sequence in the batch. */ + size_t max_pages_per_seq_v; /*!< Max V pages per sequence in the batch. */ + + /* attention-bias broadcast shape (set 0 when not using a bias tensor). */ + size_t bias_batch_size; /*!< Bias broadcast dim for batch. */ + size_t bias_num_heads; /*!< Bias broadcast dim for heads. */ + size_t bias_seqlen_q; /*!< Bias broadcast dim for Q sequence length. */ + size_t bias_seqlen_kv; /*!< Bias broadcast dim for K/V sequence length. */ + + /* ---- direction-affecting behavior flags ---- */ + bool is_training; /*!< Whether the model is in training mode. */ + bool return_max_logit; /*!< Whether to produce Max along with Stats (fwd-only). */ + bool deterministic; /*!< Whether determinism is required (bwd-only). */ + + /* ---- Future fields appended here only. ---- */ +} NVTEFusedAttnConfig; + +/*! \brief Default-initialize an NVTEFusedAttnConfig. + * + * Sets struct_size and the categorical fields (layouts, formats, masks, + * window sizes) to safe NOT_SET / no-op defaults. Numeric and tensor-derived + * fields, paged-KV shape, bias broadcast shape, and direction flags all + * default to zero/false; callers must set the fields relevant to their query. + */ +#define NVTE_FUSED_ATTN_CONFIG_INIT \ + { \ + .struct_size = sizeof(NVTEFusedAttnConfig), \ + .qkv_layout = NVTE_QKV_Layout_NOT_SET, \ + .o_format = NVTE_QKV_Format_NOT_SET, \ + .do_format = NVTE_QKV_Format_NOT_SET, \ + .dqkv_layout = NVTE_QKV_Layout_NOT_SET, \ + .qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET, \ + .do_scale_inv_format = NVTE_QKV_Format_NOT_SET, \ + .bias_type = NVTE_NO_BIAS, \ + .attn_mask_type = NVTE_NO_MASK, \ + .softmax_type = NVTE_VANILLA_SOFTMAX, \ + .window_size_left = -1, \ + .window_size_right = -1, \ + } + /*! \brief Get fused attention backend based on input parameters. + * + * \param[in] cfg Algorithm/policy + tensor metadata describing the + * attention configuration to query. + * + * \return Backend able to execute this configuration, or NVTE_No_Backend if none. + */ +NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend_v2(const NVTEFusedAttnConfig *cfg); + +/*! \brief Get fused attention backend based on input parameters (deprecated). + * + * This has been deprecated in favor of nvte_get_fused_attn_backend_v2. * * \param[in] is_training Whether the model is in training mode. * \param[in] q_dtype The data type of Tensor Q. @@ -223,7 +339,124 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); +/*! \struct NVTEFusedAttnFwdParams + * \brief All inputs and configuration for nvte_fused_attn_fwd_v2(). + * + * Bundles the tensor bindings, algorithm configuration, behavior flags, and + * execution context for one forward call. Tensors that do not apply to a given + * call (e.g. Bias when bias_type == NVTE_NO_BIAS) may be left as nullptr. + * + * For semantics of cu_seqlens_q_padded / cu_seqlens_kv_padded with THD layouts, + * see the notes on nvte_fused_attn_fwd(). + * + * Versioning rules: + * - struct_size MUST be set to sizeof(NVTEFusedAttnFwdParams) by the caller + * (use NVTE_FUSED_ATTN_FWD_PARAMS_INIT). + * - New fields may only be appended at the end; existing fields are never + * reordered, removed, or resized. The library reads only fields that are + * in range according to struct_size and uses safe defaults otherwise. + */ +typedef struct NVTEFusedAttnFwdParams { + size_t struct_size; /*!< MUST equal sizeof(NVTEFusedAttnFwdParams). */ + uint32_t reserved0; /*!< Padding for layout stability; set to 0. */ + uint32_t reserved1; /*!< Padding for layout stability; set to 0. */ + + /* ---- input tensors ---- */ + NVTETensor Q; /*!< The Q tensor. */ + NVTETensor K; /*!< The K tensor. */ + NVTETensor V; /*!< The V tensor. */ + NVTETensor Bias; /*!< The Bias tensor. */ + NVTETensor SoftmaxOffset; /*!< The SoftmaxOffset tensor. */ + NVTETensor cu_seqlens_q; /*!< Cumulative sequence lengths for Q, [batch_size + 1]. */ + NVTETensor cu_seqlens_kv; /*!< Cumulative sequence lengths for K and V, [batch_size + 1]. */ + NVTETensor cu_seqlens_q_padded; /*!< Cumulative sequence offsets for Q, [batch_size + 1]. */ + NVTETensor cu_seqlens_kv_padded; /*!< Cumulative sequence offsets for KV, [batch_size + 1]. */ + NVTETensor page_table_k; /*!< Page table for K cache, [batch_size, max_pages_per_seq_k]. */ + NVTETensor page_table_v; /*!< Page table for V cache, [batch_size, max_pages_per_seq_v]. */ + NVTETensor rng_state; /*!< Seed and offset of CUDA random number generator. */ + + /* ---- output / inout tensors ---- */ + NVTETensor S; /*!< The S tensor (in/out). */ + NVTETensor O; /*!< The output O tensor. */ + NVTETensorPack *Aux_CTX_Tensors; /*!< Auxiliary output tensors when training, + e.g. softmax stats, optional Max, rng_state. */ + + /* ---- sizes ---- */ + size_t max_seqlen_q; /*!< Max sequence length for Q; + it may be >= max(seqlen_q_i) for i = 0, ..., batch_size - 1. */ + size_t max_seqlen_kv; /*!< Max sequence length for K and V; + it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size - 1. */ + + /* ---- algorithm config ---- */ + NVTE_QKV_Layout qkv_layout; /*!< QKV tensors' layout. */ + NVTE_QKV_Format o_format; /*!< Output format. */ + NVTE_QKV_Format qkv_scale_inv_format; /*!< Format of scale-inverse tensors for QKV; + NVTE_QKV_Format_NOT_SET = infer from qkv_layout. */ + NVTE_Bias_Type bias_type; /*!< Bias type. */ + NVTE_Mask_Type attn_mask_type; /*!< Attention mask type. */ + NVTE_Softmax_Type softmax_type; /*!< Attention softmax type. */ + + /* ---- numerics / windowing ---- */ + float attn_scale; /*!< Scaling factor for Q * K.T. */ + float dropout; /*!< Dropout probability. */ + int64_t window_size_left; /*!< Sliding window size (left half); -1 = unlimited. */ + int64_t window_size_right; /*!< Sliding window size (right half); -1 = unlimited. */ + bool bottom_right_diagonal; /*!< Whether to align sliding window and ALiBi diagonal + to the bottom right corner of the softmax matrix. */ + + /* ---- behavior ---- */ + bool is_training; /*!< Whether this is in training mode or inference. */ + bool return_max_logit; /*!< Whether to produce Max along with Stats. */ + bool cuda_graph; /*!< Whether CUDA graph capture is enabled. */ + + /* ---- execution ---- */ + NVTETensor workspace; /*!< Workspace tensor. */ + cudaStream_t stream; /*!< CUDA stream used for this operation. */ + + /* ---- Future fields appended here only. ---- */ +} NVTEFusedAttnFwdParams; + +/*! \brief Default-initialize an NVTEFusedAttnFwdParams. + * + * Sets struct_size and all enums/scalars to safe defaults. Tensors and the + * Aux_CTX_Tensors pack default to nullptr; callers must set the fields + * relevant to their call. + */ +#define NVTE_FUSED_ATTN_FWD_PARAMS_INIT \ + { \ + .struct_size = sizeof(NVTEFusedAttnFwdParams), \ + .qkv_layout = NVTE_QKV_Layout_NOT_SET, \ + .o_format = NVTE_QKV_Format_NOT_SET, \ + .qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET, \ + .bias_type = NVTE_NO_BIAS, \ + .attn_mask_type = NVTE_NO_MASK, \ + .softmax_type = NVTE_VANILLA_SOFTMAX, \ + .attn_scale = 1.0f, \ + .window_size_left = -1, \ + .window_size_right = -1, \ + } + /*! \brief Compute dot product attention with separate Q, K and V. + * + * Computes: + * - P = Q * Transpose(K) + Bias + * - S = ScaleMaskSoftmax(P) + * - D = Dropout(S) + * - O = D * Transpose(V) + * + * See the notes on nvte_fused_attn_fwd() for cu_seqlens_q_padded / + * cu_seqlens_kv_padded semantics with THD layouts. + * + * \param[in,out] params All inputs, configuration, and execution context. + * Output tensors and the Aux_CTX_Tensors pack are + * written through pointers the caller has supplied + * in the struct. + */ +void nvte_fused_attn_fwd_v2(const NVTEFusedAttnFwdParams *params); + +/*! \brief Compute dot product attention with separate Q, K and V (deprecated). + * + * This has been deprecated in favor of nvte_fused_attn_fwd_v2. * * Computes: * - P = Q * Transpose(K) + Bias @@ -297,7 +530,126 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream); +/*! \struct NVTEFusedAttnBwdParams + * \brief All inputs and configuration for nvte_fused_attn_bwd_v2(). + * + * Bundles the tensor bindings, algorithm configuration, behavior flags, and + * execution context for one backward call. Auxiliary tensors saved by the + * forward pass (e.g. softmax stats, RNG state, saved Bias / SoftmaxOffset) + * are accessed via Aux_CTX_Tensors and are NOT separately set as fields here. + * + * For semantics of cu_seqlens_q_padded / cu_seqlens_kv_padded with THD layouts, + * see the notes on nvte_fused_attn_bwd(). + * + * Versioning rules: + * - struct_size MUST be set to sizeof(NVTEFusedAttnBwdParams) by the caller + * (use NVTE_FUSED_ATTN_BWD_PARAMS_INIT). + * - New fields may only be appended at the end; existing fields are never + * reordered, removed, or resized. The library reads only fields that are + * in range according to struct_size and uses safe defaults otherwise. + */ +typedef struct NVTEFusedAttnBwdParams { + size_t struct_size; /*!< MUST equal sizeof(NVTEFusedAttnBwdParams). */ + uint32_t reserved0; /*!< Padding for layout stability; set to 0. */ + uint32_t reserved1; /*!< Padding for layout stability; set to 0. */ + + /* ---- input tensors ---- */ + NVTETensor Q; /*!< The Q tensor. */ + NVTETensor K; /*!< The K tensor. */ + NVTETensor V; /*!< The V tensor. */ + NVTETensor O; /*!< The O tensor from forward. */ + NVTETensor dO; /*!< The gradient of the O tensor. */ + NVTETensor S; /*!< The S tensor. */ + NVTETensor cu_seqlens_q; /*!< Cumulative sequence lengths for Q, [batch_size + 1]. */ + NVTETensor cu_seqlens_kv; /*!< Cumulative sequence lengths for K and V, [batch_size + 1]. */ + NVTETensor cu_seqlens_q_padded; /*!< Cumulative sequence offsets for Q, [batch_size + 1]. */ + NVTETensor cu_seqlens_kv_padded; /*!< Cumulative sequence offsets for KV, [batch_size + 1]. */ + const NVTETensorPack *Aux_CTX_Tensors; /*!< Auxiliary tensors from forward context, + e.g. softmax stats, optional Max, rng_state. */ + + /* ---- output / inout tensors ---- */ + NVTETensor dP; /*!< The gradient of the P tensor (in/out). */ + NVTETensor dQ; /*!< The gradient of the Q tensor. */ + NVTETensor dK; /*!< The gradient of the K tensor. */ + NVTETensor dV; /*!< The gradient of the V tensor. */ + NVTETensor dBias; /*!< The gradient of the Bias tensor. */ + NVTETensor dSoftmaxOffset; /*!< The gradient of the SoftmaxOffset tensor. */ + + /* ---- sizes ---- */ + size_t max_seqlen_q; /*!< Max sequence length for Q; + it may be >= max(seqlen_q_i) for i = 0, ..., batch_size - 1. */ + size_t max_seqlen_kv; /*!< Max sequence length for K and V; + it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size - 1. */ + + /* ---- algorithm config ---- */ + NVTE_QKV_Layout qkv_layout; /*!< QKV tensors' layout. */ + NVTE_QKV_Layout dqkv_layout; /*!< QKV gradient tensors' layout. */ + NVTE_QKV_Format o_format; /*!< Output format. */ + NVTE_QKV_Format do_format; /*!< Output gradient's format. */ + NVTE_QKV_Format qkv_scale_inv_format; /*!< Format of scale-inverse tensors for QKV; + NVTE_QKV_Format_NOT_SET = infer from qkv_layout. */ + NVTE_QKV_Format do_scale_inv_format; /*!< Format of scale-inverse tensors for dO; + NVTE_QKV_Format_NOT_SET = infer from output layout. */ + NVTE_Bias_Type bias_type; /*!< Bias type. */ + NVTE_Mask_Type attn_mask_type; /*!< Attention mask type. */ + NVTE_Softmax_Type softmax_type; /*!< Attention softmax type. */ + + /* ---- numerics / windowing ---- */ + float attn_scale; /*!< Scaling factor for Q * K.T. */ + float dropout; /*!< Dropout probability. */ + int64_t window_size_left; /*!< Sliding window size (left half); -1 = unlimited. */ + int64_t window_size_right; /*!< Sliding window size (right half); -1 = unlimited. */ + bool bottom_right_diagonal; /*!< Whether to align sliding window and ALiBi diagonal + to the bottom right corner of the softmax matrix. */ + + /* ---- behavior ---- */ + bool deterministic; /*!< Whether to execute with deterministic behaviours. */ + bool cuda_graph; /*!< Whether CUDA graph capture is enabled. */ + + /* ---- execution ---- */ + NVTETensor workspace; /*!< Workspace tensor. */ + cudaStream_t stream; /*!< CUDA stream used for this operation. */ + + /* ---- Future fields appended here only. ---- */ +} NVTEFusedAttnBwdParams; + +/*! \brief Default-initialize an NVTEFusedAttnBwdParams. + * + * Sets struct_size and all enums/scalars to safe defaults. Tensors and the + * Aux_CTX_Tensors pack default to nullptr; callers must set the fields + * relevant to their call. + */ +#define NVTE_FUSED_ATTN_BWD_PARAMS_INIT \ + { \ + .struct_size = sizeof(NVTEFusedAttnBwdParams), \ + .qkv_layout = NVTE_QKV_Layout_NOT_SET, \ + .dqkv_layout = NVTE_QKV_Layout_NOT_SET, \ + .o_format = NVTE_QKV_Format_NOT_SET, \ + .do_format = NVTE_QKV_Format_NOT_SET, \ + .qkv_scale_inv_format = NVTE_QKV_Format_NOT_SET, \ + .do_scale_inv_format = NVTE_QKV_Format_NOT_SET, \ + .bias_type = NVTE_NO_BIAS, \ + .attn_mask_type = NVTE_NO_MASK, \ + .softmax_type = NVTE_VANILLA_SOFTMAX, \ + .attn_scale = 1.0f, \ + .window_size_left = -1, \ + .window_size_right = -1, \ + } + /*! \brief Compute the backward of the dot product attention with separate Q, K and V. + * + * See the notes on nvte_fused_attn_bwd() for cu_seqlens_q_padded / + * cu_seqlens_kv_padded semantics with THD layouts. + * + * \param[in,out] params All inputs, configuration, and execution context. + * Output tensors are written through pointers the + * caller has supplied in the struct. + */ +void nvte_fused_attn_bwd_v2(const NVTEFusedAttnBwdParams *params); + +/*! \brief Compute the backward of the dot product attention with separate Q, K and V (deprecated). + * + * This has been deprecated in favor of nvte_fused_attn_bwd_v2. * * Notes: *