diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index c8d70b185..b5c0b3001 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -373,6 +373,39 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); + // Dequantization for GGML. + ops.def( + "ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? " + "dtype) -> Tensor"); + ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize); + + // mmvq kernel for GGML. + ops.def( + "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"); + ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8); + + // mmq kernel for GGML. + ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"); + ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8); + + // mmq kernel for GGML (MoE). + ops.def( + "ggml_moe_a8(Tensor X, Tensor W, " + "Tensor sorted_token_ids, Tensor expert_ids, " + "Tensor num_tokens_post_padded, int type, " + "SymInt row, SymInt top_k, SymInt tokens) -> Tensor"); + ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8); + + // mmvq kernel for GGML (MoE). + ops.def( + "ggml_moe_a8_vec(Tensor X, Tensor W, " + "Tensor topk_ids, int top_k, " + "int type, SymInt row, SymInt tokens) -> Tensor"); + ops.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec); + + ops.def("ggml_moe_get_block_size(int type) -> int"); + ops.impl("ggml_moe_get_block_size", &ggml_moe_get_block_size); + // ┌---------- Not supported for Metax -----------┐ // Compute FP8 quantized tensor for given scaling factor. // Supports per-tensor, per-channel, per-token, and arbitrary 2D group