From 2e23faa991a44981289597034dd8b936cd310168 Mon Sep 17 00:00:00 2001 From: tangzhiyi11 Date: Mon, 7 Jul 2025 03:15:14 +0000 Subject: [PATCH 1/5] [ascend] support lora --- dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py | 21 + .../AtbGraph/codegen/atb_infer_param.py | 13 + .../dicp/vendor/AtbGraph/codegen/atb_op.py | 35 +- .../AtbGraph/codegen/runtime/CMakeLists.txt | 1 - .../vendor/AtbGraph/codegen/runtime/model.cpp | 5 + .../runtime/ops/aclnn_ops/cat_operation.cpp | 2 +- .../ops/aclnn_ops/inplace_add_operation.cpp | 71 +++ .../ops/aclnn_ops/inplace_add_operation.h | 24 + .../aclnn_ops/split_with_size_operation.cpp | 61 +- .../ops/aclnn_ops/split_with_size_operation.h | 4 +- .../ops/custom_ops/fused_lora_operation.cpp | 565 ++++++++++++++++++ .../ops/custom_ops/fused_lora_operation.h | 68 +++ .../codegen/runtime/utils/operation_util.h | 2 + .../graph/dicp/vendor/AtbGraph/conversion.py | 87 ++- dlinfer/ops/llm.py | 69 +++ dlinfer/vendor/ascend/torch_npu_ops.py | 55 ++ 16 files changed, 1068 insertions(+), 15 deletions(-) create mode 100644 dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/inplace_add_operation.cpp create mode 100644 dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/inplace_add_operation.h create mode 100644 dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.cpp create mode 100644 dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.h diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py b/dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py index 8619eeac..0de70b3b 100644 --- a/dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py +++ b/dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py @@ -825,3 +825,24 @@ def __init__(self): def infer_result(self, x, dim, keep_dim, dtype, ascend_dtype): return x.sum(dim, keep_dim=keep_dim, dtype=dtype) + + +class CustomFusedLora(Operator): + def __init__(self): + super().__init__("CustomFusedLora") + + def infer_result( + self, x, lora_a, lora_b, scaling, ranks, seq_lens, adapter_ids, dtype + ): + M, K = x.shape + N = lora_b.size(1) + output = torch.empty((M, N), dtype=x.dtype, device=x.device) + return output, output + + +class AclNnInplaceAdd(Operator): + def __init__(self): + super().__init__("AclNnInplaceAdd") + + def infer_result(self, a, b, dtype): + return a + b diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_infer_param.py b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_infer_param.py index 81d61ac8..a612be2c 100644 --- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_infer_param.py +++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_infer_param.py @@ -729,6 +729,19 @@ class AclNnReduceSumParam: dtype: str = "FLOAT" +@dataclass +class CustomFusedLoraParam: + name: str = "" + dtype: str = "FLOAT" + + +@dataclass +class AclNnInplaceAddParam: + name: str = "" + alpha: float = 1.0 + dtype: str = "FLOAT" + + def custom_asdict_factory(data): def convert_value(obj): if isinstance(obj, IntEnum): diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py index b45e3961..4cb10bc0 100644 --- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py +++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py @@ -570,7 +570,7 @@ def SplitWithSize(name, x, sizes, dim): op = Operation(name, "AclNnSplitWithSizeOperation") param = infer_param.SplitParam() param.splitDim = dim - param.splitSizes = sizes + param.splitSizes = [str(s) for s in sizes] op.set_param(param) op.set_input([x]) for idx, _ in 
enumerate(sizes):
@@ -1246,3 +1246,36 @@ def AclNnReduceSum(name, x, dim, keep_dim, dtype, ascend_dtype):
         op.set_param(param)
         op.set_output([name])
         return op
+
+    def CustomFusedLora(
+        name, x, lora_a, lora_b, scaling, ranks, seq_lens, adapter_ids, dtype
+    ):
+        op = Operation(name, "CustomFusedLoraOperation")
+        param = infer_param.CustomFusedLoraParam()
+        param.name = name
+        param.dtype = get_ascend_dtype(dtype)
+        seq_lens_cpu = seq_lens
+        op.set_input(
+            [x, lora_a, lora_b, scaling, ranks, seq_lens, adapter_ids, seq_lens_cpu]
+        )
+        op.set_param(param)
+        op.set_output([f"{name}__0", f"{name}__1", f"{name}__2"])
+
+        op.has_host_inputs = True
+        op.host_inputs.append(ranks)
+        op.host_inputs.append(adapter_ids)
+        op.host_inputs.append(seq_lens_cpu)
+        return op
+
+    def AclNnInplaceAdd(name, a, b, dtype):
+        op = Operation(name, "AclNnInplaceAddOperation")
+        param = infer_param.AclNnInplaceAddParam()
+        param.name = name
+        param.dtype = get_ascend_dtype(dtype)
+        op.set_input([a, b])
+        op.set_param(param)
+        op.set_output([name])
+        op.has_inplace_output = True
+        op.add_inplace_output(0, 0)
+        return op
diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/CMakeLists.txt b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/CMakeLists.txt
index c5f7b3c5..aa18de95 100644
--- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/CMakeLists.txt
+++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/CMakeLists.txt
@@ -19,7 +19,6 @@ set(COMPILE_OPTIONS
     -Wno-unused-variable
     -Wno-unused-parameter
     -Wno-attributes
-    -D_GLIBCXX_USE_CXX11_ABI=0
 )
 
 set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O2")
diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/model.cpp b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/model.cpp
index 1fd2a6c8..1956217e 100644
--- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/model.cpp
+++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/model.cpp
@@ -313,6 +313,7 @@ atb::Status Model::ExecuteNode(int nodeId) {
     if (st != 0) {
         DICP_LOG(ERROR) << "execute node[" << nodeId << "] fail, error code: " << st;
     }
+    DICP_LOG(INFO) << modelId_ << " execute node[" << nodeId << "] end";
     return st;
 }
 
@@ -424,6 +425,10 @@ void Model::CreateGraphOperation(const nlohmann::json& paramJson, Node& node) {
     graph_param.internalTensorNum = internalNames.size();
     graph_param.nodes.resize(nodeSize);
 
+    if (paramJson.contains("name")) {
+        graph_param.name = paramJson["name"].get<std::string>();
+    }
+
     // graph local tensor ids
     std::unordered_map<std::string, int> graph_tensor_ids;
     int tensorCount = 0;
diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/cat_operation.cpp b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/cat_operation.cpp
index ed73e87d..c4fc512a 100644
--- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/cat_operation.cpp
+++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/cat_operation.cpp
@@ -39,7 +39,7 @@ atb::Status AclNnCatOperation::InferShape(const atb::SVector<atb::TensorDesc>& i
     outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype;
 
     int64_t concatDimSize = 0;
-    int64_t dim = this->concatDim > 0 ? this->concatDim : inTensorDescs.at(0).shape.dimNum + this->concatDim;
+    int64_t dim = this->concatDim >= 0 ? this->concatDim : inTensorDescs.at(0).shape.dimNum + this->concatDim;
     for (size_t i = 0; i < inTensorDescs.size(); ++i) {
         concatDimSize += inTensorDescs.at(i).shape.dims[dim];
     }
diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/inplace_add_operation.cpp b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/inplace_add_operation.cpp
new file mode 100644
index 00000000..ea64f898
--- /dev/null
+++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/inplace_add_operation.cpp
@@ -0,0 +1,71 @@
+#include "inplace_add_operation.h"
+
+#include
+
+#include "aclnnop/aclnn_add.h"
+#include "utils/log.h"
+#include "utils/misc.h"
+
+namespace dicp {
+
+const int NUM1 = 1;
+const int NUM2 = 2;
+
+AclNnInplaceAddOperation::AclNnInplaceAddOperation(const std::string& name, float alpha, const std::string& dtype) : AclNnOperation(name) {
+    alpha_ = DICPScalar(alpha, dtype);
+    aclAlpha_ = aclCreateScalar(alpha_.getValuePtr(), alpha_.getDataType());
+}
+
+AclNnInplaceAddOperation::~AclNnInplaceAddOperation() {
+    if (aclAlpha_ != nullptr) {
+        aclDestroyScalar(aclAlpha_);
+    }
+}
+
+atb::Status AclNnInplaceAddOperation::InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs, atb::SVector<atb::TensorDesc>& outTensorDescs) const {
+    outTensorDescs.at(0).format = inTensorDescs.at(0).format;
+    outTensorDescs.at(0).shape.dimNum = inTensorDescs.at(0).shape.dimNum;
+    outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype;
+    for (size_t i = 0; i < outTensorDescs.at(0).shape.dimNum; ++i) {
+        outTensorDescs.at(0).shape.dims[i] = inTensorDescs.at(0).shape.dims[i];
+    }
+    return 0;
+}
+
+uint32_t AclNnInplaceAddOperation::GetInputNum() const { return NUM2; }
+
+uint32_t AclNnInplaceAddOperation::GetOutputNum() const { return NUM1; }
+
+int AclNnInplaceAddOperation::SetAclNnWorkspaceExecutor(uint64_t& workspaceSize) {
+    int ret = aclnnInplaceAddGetWorkspaceSize(aclInTensors_.at(0).tensor, aclInTensors_.at(1).tensor, aclAlpha_, &workspaceSize, &aclExecutor_);
+    DICP_LOG(INFO) << opName_ << " aclnnInplaceAddGetWorkspaceSize end, ret:" << ret << ", workspaceSize:" << workspaceSize << ", aclExecutor:" << aclExecutor_;
+    return ret;
+}
+
+int AclNnInplaceAddOperation::CallAclExecute(uint8_t* workspace, uint64_t workspaceSize, aclOpExecutor* aclExecutor, aclrtStream stream) {
+    int ret = aclnnInplaceAdd(workspace, workspaceSize, aclExecutor, stream);
+    DICP_LOG(INFO) << opName_ << " aclnnInplaceAdd end, ret:" << ret;
+    return ret;
+}
+
+atb::Operation* AclNnInplaceAddOperationCreate(const nlohmann::json& paramJson) {
+    std::string opName;
+    float alpha = 1.0;
+    std::string dtype;
+    if (paramJson.contains("name")) {
+        opName = paramJson["name"].get<std::string>();
+    }
+    if (paramJson.contains("alpha")) {
+        alpha = paramJson["alpha"].get<float>();
+    }
+    if (paramJson.contains("dtype")) {
+        dtype = paramJson["dtype"].get<std::string>();
+    }
+    DICP_LOG(INFO) << "AclNnInplaceAddOperation: name: " << opName;
+    atb::Operation* op = new AclNnInplaceAddOperation(opName, alpha, dtype);
+    return op;
+}
+
+REGISTER_OPERATION(AclNnInplaceAddOperation, AclNnInplaceAddOperationCreate);
+
+}  // namespace dicp
diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/inplace_add_operation.h b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/inplace_add_operation.h
new file mode 100644
index 00000000..55056413
--- /dev/null
+++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/inplace_add_operation.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include "acl_nn_operation.h"
+#include "utils/scalar.h"
+
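+// Usage note (inferred from this patch, not from upstream ATB docs): the op
+// computes `a = a + alpha * b` in place through aclnnInplaceAdd, so output 0
+// aliases input 0; the Python builder records this via add_inplace_output(0, 0)
+// and serializes the scale as "alpha" (default 1.0) in AclNnInplaceAddParam.
+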
+namespace dicp {
+class AclNnInplaceAddOperation : public AclNnOperation {
+public:
+    explicit AclNnInplaceAddOperation(const std::string& name, float alpha, const std::string& dtype);
+    ~AclNnInplaceAddOperation() override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs, atb::SVector<atb::TensorDesc>& outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    DICPScalar alpha_;
+    aclScalar* aclAlpha_ = nullptr;
+
+    std::string dtype_;
+    int SetAclNnWorkspaceExecutor(uint64_t& workspaceSize) override;
+    int CallAclExecute(uint8_t* workspace, uint64_t workspaceSize, aclOpExecutor* aclExecutor, aclrtStream stream) override;
+};
+
+}  // namespace dicp
diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/split_with_size_operation.cpp b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/split_with_size_operation.cpp
index 8fa66d21..ec81064d 100644
--- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/split_with_size_operation.cpp
+++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/split_with_size_operation.cpp
@@ -12,14 +12,29 @@
 
 #include "acl/acl.h"
 #include "aclnnop/aclnn_split_with_size.h"
+#include "utils/common.h"
+#include "utils/global_dict.h"
 #include "utils/log.h"
 
 namespace dicp {
 
 const int NUM1 = 1;
 
-AclNnSplitWithSizeOperation::AclNnSplitWithSizeOperation(const std::string& name, int64_t splitDim, std::vector<int64_t> splitSizes)
-    : AclNnOperation(name), splitDim_(splitDim), splitSizes_(std::move(splitSizes)) {}
+AclNnSplitWithSizeOperation::AclNnSplitWithSizeOperation(const std::string& name, int64_t splitDim, std::vector<std::string> splitSizes)
+    : AclNnOperation(name), splitDim_(splitDim) {
+    splitSizes_.resize(splitSizes.size());
+    for (size_t i = 0; i < splitSizes_.size(); ++i) {
+        bool isDynamic = !std::isdigit(splitSizes[i][0]);
+        if (isDynamic) {
+            dynamicSplitSizesMap_[i] = splitSizes[i];
+        } else {
+            splitSizes_[i] = std::stol(splitSizes[i]);
+        }
+    }
+    if (dynamicSplitSizesMap_.size() == 0) {
+        aclSplitSizes_ = aclCreateIntArray(splitSizes_.data(), splitSizes_.size());
+    }
+}
 
 AclNnSplitWithSizeOperation::~AclNnSplitWithSizeOperation() {}
 
@@ -32,8 +47,9 @@ atb::Status AclNnSplitWithSizeOperation::InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs, atb::SVector<atb::TensorDesc>& outTensorDescs) const {
     int64_t splitDim = splitDim_ >= 0 ? splitDim_ : inputDimNum + splitDim_;
+
+    auto& globalDict = GetGlobalDictData();
     for (size_t i = 0; i < splitSizes_.size(); ++i) {
         auto& outputTensorDesc = outTensorDescs.at(i);
         outputTensorDesc.format = inputFormat;
@@ -41,8 +57,21 @@ atb::Status AclNnSplitWithSizeOperation::InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs, atb::SVector<atb::TensorDesc>& outTensorDescs) const {
+        int64_t targetDimValue = 0;
+        auto dynamicSize = dynamicSplitSizesMap_.find(i);
+        if (dynamicSize != dynamicSplitSizesMap_.end()) {
+            auto it = globalDict.find(dynamicSize->second);
+            if (it != globalDict.end()) {
+                targetDimValue = static_cast<int64_t>(it->second);
+            } else {
+                DICP_LOG(ERROR) << "Cannot find key " << dynamicSize->second << " in global_dict";
+            }
+        } else {
+            targetDimValue = splitSizes_[i];
+        }
+
         for (size_t j = 0; j < inputDimNum; ++j) {
-            outputDims[j] = (j != splitDim) ? inputDims[j] : splitSizes_[i];
+            outputDims[j] = (j != splitDim) ? inputDims[j] : targetDimValue;
         }
     }
 
@@ -62,8 +91,24 @@ int AclNnSplitWithSizeOperation::SetAclNnWorkspaceExecutor(uint64_t& workspaceSi
         tmp[i] = aclOutTensors_.at(i).tensor;
     }
     aclTensorList* tensorList = aclCreateTensorList(tmp.data(), tmp.size());
-    aclIntArray* sizes = aclCreateIntArray(splitSizes_.data(), splitSizes_.size());
-    int ret = aclnnSplitWithSizeGetWorkspaceSize(aclInTensors_.at(0).tensor, sizes, splitDim_, tensorList, &workspaceSize, &aclExecutor_);
+
+    if (dynamicSplitSizesMap_.size() > 0) {
+        auto& globalDict = GetGlobalDictData();
+        for (auto& [key, value] : dynamicSplitSizesMap_) {
+            auto it = globalDict.find(value);
+            if (it != globalDict.end()) {
+                splitSizes_[key] = static_cast<int64_t>(it->second);
+            } else {
+                DICP_LOG(ERROR) << "Cannot find key " << value << " in global dict";
+            }
+        }
+        if (aclSplitSizes_ != nullptr) {
+            aclDestroyIntArray(aclSplitSizes_);
+            aclSplitSizes_ = nullptr;
+        }
+        aclSplitSizes_ = aclCreateIntArray(splitSizes_.data(), splitSizes_.size());
+    }
+    int ret = aclnnSplitWithSizeGetWorkspaceSize(aclInTensors_.at(0).tensor, aclSplitSizes_, splitDim_, tensorList, &workspaceSize, &aclExecutor_);
 
     DICP_LOG(INFO) << opName_ << " aclnnSplitWithSizeGetWorkspaceSize end, ret:" << ret << ", workspaceSize:" << workspaceSize << ", aclExecutor:" << aclExecutor_;
 
@@ -80,7 +125,7 @@ int AclNnSplitWithSizeOperation::CallAclExecute(uint8_t* workspace, uint64_t wor
 atb::Operation* AclNnSplitWithSizeOperationCreate(const nlohmann::json& paramJson) {
     std::string opName;
     int64_t splitDim;
-    std::vector<int64_t> splitSizes;
+    std::vector<std::string> splitSizes;
     if (paramJson.contains("name")) {
         opName = paramJson["name"].get<std::string>();
     }
@@ -88,7 +133,7 @@ atb::Operation* AclNnSplitWithSizeOperationCreate(const nlohmann::json& paramJso
         splitDim = paramJson["splitDim"].get<int64_t>();
     }
     if (paramJson.contains("splitSizes")) {
-        splitSizes = paramJson["splitSizes"].get<std::vector<int64_t>>();
+        splitSizes = paramJson["splitSizes"].get<std::vector<std::string>>();
     }
     DICP_LOG(INFO) << "AclNnSplitWithSizeOperation: name: " << opName;
     atb::Operation* op = new AclNnSplitWithSizeOperation(opName, splitDim, splitSizes);
diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/split_with_size_operation.h b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/split_with_size_operation.h
index e9749dc0..95b8dc2f 100644
--- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/split_with_size_operation.h
+++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/aclnn_ops/split_with_size_operation.h
@@ -6,7 +6,7 @@ namespace dicp {
 
 class AclNnSplitWithSizeOperation : public AclNnOperation {
 public:
-    explicit AclNnSplitWithSizeOperation(const std::string& name, int64_t splitDim, std::vector<int64_t> splitSizes);
+    explicit AclNnSplitWithSizeOperation(const std::string& name, int64_t splitDim, std::vector<std::string> splitSizes);
     ~AclNnSplitWithSizeOperation() override;
     atb::Status InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs, atb::SVector<atb::TensorDesc>& outTensorDescs) const override;
     uint32_t GetInputNum() const override;
@@ -15,6 +15,8 @@ class AclNnSplitWithSizeOperation : public AclNnOperation {
 private:
     int64_t splitDim_;
     std::vector<int64_t> splitSizes_;
+    std::unordered_map<size_t, std::string> dynamicSplitSizesMap_;
+    aclIntArray* aclSplitSizes_ = nullptr;
     int SetAclNnWorkspaceExecutor(uint64_t& workspaceSize) override;
     int CallAclExecute(uint8_t* workspace, uint64_t workspaceSize, aclOpExecutor* aclExecutor, aclrtStream stream) override;
 };
diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.cpp b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.cpp
new file mode 100644
index 00000000..48ea6350
--- /dev/null
+++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.cpp
@@ -0,0 +1,565 @@
+#include "fused_lora_operation.h"
+
+#include
+#include
+#include
+
+#include "aclnnop/aclnn_mul.h"
+#include "aclnnop/aclnn_grouped_matmul_v4.h"
+#include "aclnnop/aclnn_permute.h"
+#include "ops/operation_creator.h"
+#include "third_party/acl/inc/acl/acl_base.h"
+#include "utils/common.h"
+#include "utils/log.h"
+#include "utils/scalar.h"
+
+#include
+
+namespace dicp {
+
+const int NUM1 = 1;
+const int NUM2 = 2;
+const int NUM3 = 3;
+const int NUM8 = 8;
+
+CustomFusedLoraOperation::CustomFusedLoraOperation(const std::string& name, const std::string& dtype) : opName_(name), dtype_(dtype) {}
+
+CustomFusedLoraOperation::~CustomFusedLoraOperation() {}
+
+std::string CustomFusedLoraOperation::GetName() const { return opName_; }
+
+atb::Status CustomFusedLoraOperation::InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs, atb::SVector<atb::TensorDesc>& outTensorDescs) const {
+    DICP_LOG(INFO) << opName_ << " infer shape start";
+    const auto totalLen = inTensorDescs.at(0).shape.dims[0];
+    const auto totalRanks = inTensorDescs.at(1).shape.dims[0];
+    const auto loraBDim = inTensorDescs.at(2).shape.dims[1];
+
+    // Main output tensor
+    outTensorDescs.at(0).shape.dimNum = 2;
+    outTensorDescs.at(0).shape.dims[0] = totalLen;
+    outTensorDescs.at(0).shape.dims[1] = loraBDim;
+    outTensorDescs.at(0).format = inTensorDescs.at(0).format;
+    outTensorDescs.at(0).dtype = inTensorDescs.at(0).dtype;
+
+    // Internal gemm(x, lora_a) output
+    outTensorDescs.at(1).shape.dimNum = 2;
+    outTensorDescs.at(1).shape.dims[0] = totalLen;
+    outTensorDescs.at(1).shape.dims[1] = totalRanks * totalLen;  // assume totalRanks is the max rank
+    outTensorDescs.at(1).format = inTensorDescs.at(0).format;
+    outTensorDescs.at(1).dtype = inTensorDescs.at(0).dtype;
+
+    // Internal lora_a transpose output
+    outTensorDescs.at(2).shape.dimNum = 2;
+    outTensorDescs.at(2).shape.dims[0] = inTensorDescs.at(1).shape.dims[0];
+    outTensorDescs.at(2).shape.dims[1] = inTensorDescs.at(1).shape.dims[1];
+    outTensorDescs.at(2).format = inTensorDescs.at(1).format;
+    outTensorDescs.at(2).dtype = inTensorDescs.at(1).dtype;
+    return 0;
+}
+
+uint32_t CustomFusedLoraOperation::GetInputNum() const { return NUM8; }
+
+uint32_t CustomFusedLoraOperation::GetOutputNum() const { return NUM3; }
+
+AclNnTensor CustomFusedLoraOperation::CreateTensor(const atb::Tensor& atbTensor) {
+    AclNnTensor aclNnTensor;
+    aclNnTensor.atbTensor = atbTensor;
+    return aclNnTensor;
+}
+
+int CustomFusedLoraOperation::CreateAclTensors(const atb::VariantPack& variantPack) {
+    DICP_LOG(INFO) << opName_ << " CreateAclTensor start";
+
+    const size_t inTensorCount = variantPack.inTensors.size();
+    const size_t outTensorCount = variantPack.outTensors.size();
+
+    aclInTensors_.resize(inTensorCount);
+    aclOutTensors_.resize(outTensorCount);
+
+    for (size_t i = 0; i < inTensorCount; ++i) {
+        aclInTensors_[i] = CreateTensor(variantPack.inTensors.at(i));
+    }
+
+    for (size_t i = 0; i < outTensorCount; ++i) {
+        aclOutTensors_[i] = CreateTensor(variantPack.outTensors.at(i));
+    }
+
+    DICP_LOG(INFO) << opName_ << " CreateAclTensor end";
+    return 0;
+}
+
+void CustomFusedLoraOperation::ClearAclScalrs() {
+    for (auto* scalar : aclScalingScalar_) {
+        if (scalar != nullptr) {
+            aclDestroyScalar(scalar);
+        }
+    }
+    aclScalingScalar_.clear();
+}
+
+void CustomFusedLoraOperation::ClearInternal() {
+    ClearAclScalrs();
+    aclWeightA_.clear();
+    aclWeightB_.clear();
+    aclWeightATranspose_.clear();
+    weightA_.clear();
+    weightB_.clear();
+    weightATranspose_.clear();
+
+    aclScalingInput_.clear();
+    scalingInput_.clear();
+    aclScalingWeight_.clear();
+    scalingWeight_.clear();
+
+    aclScalingWorkspace_.clear();
+    aclScalingExecutor_.clear();
+}
+
+// Helper function to create weight tensor
+atb::Tensor CustomFusedLoraOperation::CreateWeightTensor(const atb::Tensor& baseTensor, int64_t rank, int64_t dim, uint64_t offset) {
+    atb::Tensor weightTensor;
+    weightTensor.desc.dtype = baseTensor.desc.dtype;
+    weightTensor.desc.format = baseTensor.desc.format;
+    weightTensor.desc.shape.dimNum = baseTensor.desc.shape.dimNum;
+    weightTensor.desc.shape.dims[0] = rank;
+    weightTensor.desc.shape.dims[1] = dim;
+    weightTensor.dataSize = atb::Utils::GetTensorSize(weightTensor.desc);
+    weightTensor.deviceData = static_cast<char*>(baseTensor.deviceData) + offset;
+    return weightTensor;
+}
+
+// Helper function to calculate offset for weight tensors
+uint64_t CustomFusedLoraOperation::CalculateWeightOffset(const std::vector<int32_t>& ranksVec, size_t adapterId, uint64_t tensorSizePerRank) {
+    uint64_t offset = 0;
+    for (size_t j = 0; j < adapterId; ++j) {
+        offset += tensorSizePerRank * static_cast<uint64_t>(ranksVec[j]);
+    }
+    return offset;
+}
+
+int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_t& workspaceSize, atb::Context* context) {
+    DICP_LOG(INFO) << opName_ << " Setup start";
+
+    DICP_CHECK_RET(CreateAclTensors(variantPack));
+
+    // Create tensors for non-host tensors
+    const std::unordered_set<size_t> hostTensorIndices{3, 4, 6};
+    for (size_t i = 0; i < aclInTensors_.size(); ++i) {
+        if (hostTensorIndices.find(i) == hostTensorIndices.end()) {
+            aclInTensors_.at(i).CreateTensor(opName_);
+        }
+    }
+
+    for (auto& outTensor : aclOutTensors_) {
+        outTensor.CreateTensor(opName_);
+    }
+
+    // Input tensor mapping:
+    // 0: x (input tensor)
+    // 1: lora_a (LoRA A weights)
+    // 2: lora_b (LoRA B weights)
+    // 3: scaling
+    // 4: ranks (host tensor)
+    // 5: seq_lens
+    // 6: adapter_ids (host tensor)
+    // 7: seq_lens_cpu (host tensor)
+
+    // Extract host data
+    const int32_t* ranksPtr = static_cast<const int32_t*>(variantPack.inTensors.at(4).hostData);
+    const int32_t* adapterIdsPtr = static_cast<const int32_t*>(variantPack.inTensors.at(6).hostData);
+    const int32_t* seqLensPtr = static_cast<const int32_t*>(variantPack.inTensors.at(7).hostData);
+
+    const size_t ranksCount = variantPack.inTensors.at(4).desc.shape.dims[0];
+    const size_t adapterIdsCount = variantPack.inTensors.at(6).desc.shape.dims[0];
+    const size_t seqLensCount = variantPack.inTensors.at(7).desc.shape.dims[0];
+
+    std::vector<int32_t> ranksVec(ranksPtr, ranksPtr + ranksCount);
+    std::vector<int32_t> adapterIdsVec(adapterIdsPtr, adapterIdsPtr + adapterIdsCount);
+    std::vector<int32_t> seqLensVec(seqLensPtr, seqLensPtr + seqLensCount);
+
+    ranksVec[0] = 1;
+
+    const int64_t loraADim = variantPack.inTensors.at(1).desc.shape.dims[1];
+    const int64_t loraBDim = variantPack.inTensors.at(2).desc.shape.dims[1];
+
+    ClearInternal();
+
+    // Pre-allocate vectors to avoid reallocations
+    weightA_.reserve(adapterIdsVec.size());
+    weightATranspose_.reserve(adapterIdsVec.size());
+    weightB_.reserve(adapterIdsVec.size());
+
+    aclWeightA_.reserve(adapterIdsVec.size());
+    aclWeightB_.reserve(adapterIdsVec.size());
+    aclWeightATranspose_.reserve(adapterIdsVec.size());
+
+    scalingWeight_.reserve(adapterIdsVec.size());
+    scalingInput_.reserve(adapterIdsVec.size());
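+    // Layout assumed by the offset math below (inferred from conversion.py in
+    // this patch): lora_a and lora_b are packed along dim 0 as
+    // concat([zero_row, W_adapter0, W_adapter1, ...]), so adapter k starts
+    // sum(ranks[0:k]) rows into the buffer; ranksVec[0] = 1 above accounts for
+    // the single zero row that conversion.py prepends for rank-0 requests.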
+    aclScalingWeight_.reserve(adapterIdsVec.size());
+    aclScalingInput_.reserve(adapterIdsVec.size());
+
+
+    bool singleInfer = adapterIdsVec.size() == 1;
+    int32_t totalRanks = 0;
+
+    // Create weight tensors for each adapter
+    for (size_t i = 0; i < adapterIdsVec.size(); ++i) {
+        const int32_t adapterId = adapterIdsVec[i];
+        const int32_t rank = ranksVec[adapterId];
+        totalRanks += rank;
+
+        // Create LoRA A weight tensor
+        atb::Tensor weightA;
+        weightA.desc.dtype = variantPack.inTensors.at(1).desc.dtype;
+        weightA.desc.format = variantPack.inTensors.at(1).desc.format;
+        if (singleInfer) {
+            weightA.desc.shape.dimNum = 3;
+            weightA.desc.shape.dims[0] = 1;
+            weightA.desc.shape.dims[1] = rank;
+            weightA.desc.shape.dims[2] = loraADim;
+
+        } else {
+            weightA.desc.shape.dimNum = 2;
+            weightA.desc.shape.dims[0] = rank;
+            weightA.desc.shape.dims[1] = loraADim;
+        }
+        const uint64_t weightASize = atb::Utils::GetTensorSize(weightA.desc);
+        const uint64_t loraASizePerRank = weightASize / rank;
+        const uint64_t offsetA = CalculateWeightOffset(ranksVec, adapterId, loraASizePerRank);
+        weightA.deviceData = static_cast<char*>(variantPack.inTensors.at(1).deviceData) + offsetA;
+
+        auto aclnnWeightA = CreateTensor(weightA);
+        aclnnWeightA.CreateTensor(opName_);
+        aclWeightA_.push_back(aclnnWeightA);
+
+        atb::Tensor weightATranspose;
+        weightATranspose.desc.dtype = variantPack.inTensors.at(1).desc.dtype;
+        weightATranspose.desc.format = variantPack.inTensors.at(1).desc.format;
+        if (singleInfer) {
+            weightATranspose.desc.shape.dimNum = 3;
+            weightATranspose.desc.shape.dims[0] = 1;
+            weightATranspose.desc.shape.dims[1] = loraADim;
+            weightATranspose.desc.shape.dims[2] = rank;
+        } else {
+            weightATranspose.desc.shape.dimNum = 2;
+            weightATranspose.desc.shape.dims[0] = loraADim;
+            weightATranspose.desc.shape.dims[1] = rank;
+        }
+        weightATranspose.deviceData = static_cast<char*>(variantPack.outTensors.at(2).deviceData) + offsetA;
+
+        auto aclnnWeightATranspose = CreateTensor(weightATranspose);
+        aclnnWeightATranspose.CreateTensor(opName_);
+        aclWeightATranspose_.push_back(aclnnWeightATranspose);
+
+        weightATransposeIdMap_[adapterId] = aclWeightATranspose_.size() - 1;
+
+        // Create LoRA B weight tensor
+        atb::Tensor weightB;
+        weightB.desc.dtype = variantPack.inTensors.at(2).desc.dtype;
+        weightB.desc.format = variantPack.inTensors.at(2).desc.format;
+        if (singleInfer) {
+            weightB.desc.shape.dimNum = 3;
+            weightB.desc.shape.dims[0] = 1;
+            weightB.desc.shape.dims[1] = rank;
+            weightB.desc.shape.dims[2] = loraBDim;
+        } else {
+            weightB.desc.shape.dimNum = 2;
+            weightB.desc.shape.dims[0] = rank;
+            weightB.desc.shape.dims[1] = loraBDim;
+        }
+        const uint64_t weightBSize = atb::Utils::GetTensorSize(weightB.desc);
+        const uint64_t loraBSizePerRank = weightBSize / rank;
+        const uint64_t offsetB = CalculateWeightOffset(ranksVec, adapterId, loraBSizePerRank);
+        weightB.deviceData = static_cast<char*>(variantPack.inTensors.at(2).deviceData) + offsetB;
+
+        auto aclnnWeightB = CreateTensor(weightB);
+        aclnnWeightB.CreateTensor(opName_);
+        aclWeightB_.push_back(aclnnWeightB);
+    }
+
+    // transpose weight A
+    std::vector<int64_t> permuteDims;
+    if (singleInfer) {
+        permuteDims = {0, 2, 1};
+    } else {
+        permuteDims = {1, 0};
+    }
+    aclIntArray *permuteDimsArray = aclCreateIntArray(permuteDims.data(), permuteDims.size());
+    for (const auto& [adapterId, weightATransposeIndex] : weightATransposeIdMap_) {
+        aclWeightAPermuteExecutor_[adapterId] = nullptr;
+        aclWeightAPermuteWorkspace_[adapterId] = 0;
+
+        auto& weightA = aclWeightA_[weightATransposeIndex];
+        auto& weightATranspose = aclWeightATranspose_[weightATransposeIndex];
+
+
+        int ret = aclnnPermuteGetWorkspaceSize(weightA.tensor,
+                                               permuteDimsArray,
+                                               weightATranspose.tensor,
+                                               &aclWeightAPermuteWorkspace_[adapterId],
+                                               &aclWeightAPermuteExecutor_[adapterId]);
+        DICP_LOG(INFO) << opName_ << " aclnnPermuteGetWorkspaceSize size[" << adapterId << "]: " << aclWeightAPermuteWorkspace_[adapterId] << ", ret: " << ret;
+    }
+
+    // Setup grouped matrix multiplication
+    DICP_LOG(INFO) << opName_ << " Setting up grouped matrix multiplication";
+
+    // Create input tensor list
+    std::vector<aclTensor*> xTmp;
+    if (singleInfer) {
+        xTmp = {aclInTensors_.at(0).tensor};
+    } else {
+        xTmp.reserve(seqLensVec.size());
+        // split input by seq len
+        for (size_t i = 0; i < seqLensVec.size(); ++i) {
+            atb::Tensor slicedInput;
+            slicedInput.desc.dtype = aclInTensors_.at(0).atbTensor.desc.dtype;
+            slicedInput.desc.format = aclInTensors_.at(0).atbTensor.desc.format;
+            slicedInput.desc.shape.dimNum = aclInTensors_.at(0).atbTensor.desc.shape.dimNum;
+            slicedInput.desc.shape.dims[0] = seqLensVec[i];
+            slicedInput.desc.shape.dims[1] = aclInTensors_.at(0).atbTensor.desc.shape.dims[1];
+            slicedInput.dataSize = atb::Utils::GetTensorSize(slicedInput.desc);
+
+            auto offset = CalculateWeightOffset(seqLensVec, i, slicedInput.dataSize / seqLensVec[i]);
+            slicedInput.deviceData = static_cast<char*>(aclInTensors_.at(0).atbTensor.deviceData) + offset;
+            auto aclnnSlicedInput = CreateTensor(slicedInput);
+            aclnnSlicedInput.CreateTensor(opName_);
+            xTmp.push_back(aclnnSlicedInput.tensor);
+        }
+    }
+    aclTensorList* xTensorList = aclCreateTensorList(xTmp.data(), xTmp.size());
+    if (!xTensorList) {
+        DICP_LOG(ERROR) << opName_ << " Failed to create x tensor list";
+        return -1;
+    }
+
+    // Create weight tensor lists
+    std::vector<aclTensor*> weightTmpA;
+    std::vector<aclTensor*> weightTmpB;
+    weightTmpA.reserve(aclWeightATranspose_.size());
+    weightTmpB.reserve(aclWeightB_.size());
+
+    for (const auto& weight : aclWeightATranspose_) {
+        weightTmpA.push_back(weight.tensor);
+    }
+    for (const auto& weight : aclWeightB_) {
+        weightTmpB.push_back(weight.tensor);
+    }
+
+    aclTensorList* weightTensorListA = aclCreateTensorList(weightTmpA.data(), weightTmpA.size());
+    aclTensorList* weightTensorListB = aclCreateTensorList(weightTmpB.data(), weightTmpB.size());
+
+    if (!weightTensorListA || !weightTensorListB) {
+        DICP_LOG(ERROR) << opName_ << " Failed to create weight tensor lists";
+        return -1;
+    }
+
+    // Create output tensor lists
+    // slice aclOutTensors_.at(1)
+    std::vector<aclTensor*> loraATmp;
+    if (singleInfer) {
+        atb::Tensor loraASliceOutput;
+        loraASliceOutput.desc.dtype = aclOutTensors_.at(1).atbTensor.desc.dtype;
+        loraASliceOutput.desc.format = aclOutTensors_.at(1).atbTensor.desc.format;
+        loraASliceOutput.desc.shape.dimNum = aclOutTensors_.at(1).atbTensor.desc.shape.dimNum;
+        loraASliceOutput.desc.shape.dims[0] = aclOutTensors_.at(1).atbTensor.desc.shape.dims[0];
+        loraASliceOutput.desc.shape.dims[1] = totalRanks / adapterIdsVec.size();
+        loraASliceOutput.dataSize = atb::Utils::GetTensorSize(loraASliceOutput.desc);
+        loraASliceOutput.deviceData = aclOutTensors_.at(1).atbTensor.deviceData;
+        auto aclnnLoraASliceOutput = CreateTensor(loraASliceOutput);
+        aclnnLoraASliceOutput.CreateTensor(opName_);
+        loraATmp = {aclnnLoraASliceOutput.tensor};
+    } else {
+        loraATmp.reserve(seqLensVec.size());
+        // split input by seq len
+        for (size_t i = 0; i < seqLensVec.size(); ++i) {
+            atb::Tensor slicedOutput;
+            slicedOutput.desc.dtype = aclOutTensors_.at(1).atbTensor.desc.dtype;
+            slicedOutput.desc.format = aclOutTensors_.at(1).atbTensor.desc.format;
+            slicedOutput.desc.shape.dimNum = aclOutTensors_.at(1).atbTensor.desc.shape.dimNum;
+            slicedOutput.desc.shape.dims[0] = seqLensVec[i];
+            slicedOutput.desc.shape.dims[1] = ranksVec[adapterIdsVec[i]];
+            slicedOutput.dataSize = atb::Utils::GetTensorSize(slicedOutput.desc);
+
+            auto offset = CalculateWeightOffset(seqLensVec, i, slicedOutput.dataSize / seqLensVec[i]);
+            slicedOutput.deviceData = static_cast<char*>(aclOutTensors_.at(1).atbTensor.deviceData) + offset;
+            auto aclnnSlicedOutput = CreateTensor(slicedOutput);
+            aclnnSlicedOutput.CreateTensor(opName_);
+            loraATmp.push_back(aclnnSlicedOutput.tensor);
+        }
+    }
+
+    std::vector<aclTensor*> loraBTmp{aclOutTensors_.at(0).tensor};
+
+    aclTensorList* loraAOutTensorList = aclCreateTensorList(loraATmp.data(), loraATmp.size());
+    aclTensorList* loraBOutTensorList = aclCreateTensorList(loraBTmp.data(), loraBTmp.size());
+
+    if (!loraAOutTensorList || !loraBOutTensorList) {
+        DICP_LOG(ERROR) << opName_ << " Failed to create output tensor lists";
+        return -1;
+    }
+
+    // Setup LoRA A grouped matrix multiplication
+    int ret = aclnnGroupedMatmulV4GetWorkspaceSize(xTensorList,                                          // x
+                                                   weightTensorListA,                                    // weight
+                                                   nullptr,                                              // biasOptional
+                                                   nullptr,                                              // scaleOptional
+                                                   nullptr,                                              // offsetOptional
+                                                   nullptr,                                              // antiquantScaleOptional
+                                                   nullptr,                                              // antiquantOffsetOptional
+                                                   nullptr,                                              // perTokenScaleOptional
+                                                   singleInfer ? aclInTensors_.at(5).tensor : nullptr,   // groupListOptional
+                                                   nullptr,                                              // activationInputOptional
+                                                   nullptr,                                              // activationQuantScaleOptional
+                                                   nullptr,                                              // activationQuantOffsetOptional
+                                                   singleInfer ? 2 : 0,                                  // splitItem
+                                                   singleInfer ? 0 : -1,                                 // groupType
+                                                   1,                                                    // groupListType
+                                                   0,                                                    // actType
+                                                   loraAOutTensorList,                                   // out
+                                                   nullptr,                                              // activationFeatureOutOptional
+                                                   nullptr,                                              // dynQuantScaleOutOptional
+                                                   &loraAGroupedGemmWorkspace_,
+                                                   &aclLoraAGroupedGemmExecutor_);
+    DICP_LOG(INFO) << opName_ << " LoRA A grouped matmul workspace size: " << loraAGroupedGemmWorkspace_ << ", ret: " << ret;
+
+    // Setup LoRA B grouped matrix multiplication
+    ret = aclnnGroupedMatmulV4GetWorkspaceSize(loraAOutTensorList,           // x
+                                               weightTensorListB,            // weight
+                                               nullptr,                      // biasOptional
+                                               nullptr,                      // scaleOptional
+                                               nullptr,                      // offsetOptional
+                                               nullptr,                      // antiquantScaleOptional
+                                               nullptr,                      // antiquantOffsetOptional
+                                               nullptr,                      // perTokenScaleOptional
+                                               aclInTensors_.at(5).tensor,   // groupListOptional
+                                               nullptr,                      // activationInputOptional
+                                               nullptr,                      // activationQuantScaleOptional
+                                               nullptr,                      // activationQuantOffsetOptional
+                                               2,                            // splitItem
+                                               0,                            // groupType
+                                               1,                            // groupListType
+                                               0,                            // actType
+                                               loraBOutTensorList,           // out
+                                               nullptr,                      // activationFeatureOutOptional
+                                               nullptr,                      // dynQuantScaleOutOptional
+                                               &loraBGroupedGemmWorkspace_,
+                                               &aclLoraBGroupedGemmExecutor_);
+    DICP_LOG(INFO) << opName_ << " LoRA B grouped matmul workspace size: " << loraBGroupedGemmWorkspace_ << ", ret: " << ret;
+
+    // Setup scaling operations
+    aclScalingWorkspace_.resize(adapterIdsVec.size());
+    aclScalingExecutor_.resize(adapterIdsVec.size());
+
+    for (size_t i = 0; i < adapterIdsVec.size(); ++i) {
+        const int32_t adapterId = adapterIdsVec[i];
+        const auto& inputAtbTensor = aclOutTensors_.at(0).atbTensor;
+        const auto& scalingAtbTensor = aclInTensors_.at(3).atbTensor;
+
+        // Create slice tensor for scaling
+        atb::Tensor input;
+        input.desc.dtype = inputAtbTensor.desc.dtype;
+        input.desc.format = inputAtbTensor.desc.format;
+        input.desc.shape.dimNum = inputAtbTensor.desc.shape.dimNum;
+        input.desc.shape.dims[0] = seqLensVec[i];
+        input.desc.shape.dims[1] = loraBDim;
+        input.dataSize = atb::Utils::GetTensorSize(input.desc);
+
+        uint64_t offset = 0;
+        for (size_t j = 0; j < i; ++j) {
+            offset += loraBDim * static_cast<uint64_t>(seqLensVec[j]);
+        }
+        input.deviceData = static_cast<char*>(inputAtbTensor.deviceData) + offset;
+
+        scalingInput_.push_back(input);
+
+        // create slice tensor for scaling weight
+        atb::Tensor weight;
+        weight.desc.dtype = scalingAtbTensor.desc.dtype;
+        weight.desc.format = scalingAtbTensor.desc.format;
+        weight.desc.shape.dimNum = 1;
+        weight.desc.shape.dims[0] = 1;
+        weight.dataSize = atb::Utils::GetTensorSize(weight.desc);
+
+        const uint64_t weight_offset = weight.dataSize * adapterId;
+        weight.deviceData = static_cast<char*>(scalingAtbTensor.deviceData) + weight_offset;
+
+        scalingWeight_.push_back(weight);
+
+        auto aclnnScalingInput = CreateTensor(input);
+        aclnnScalingInput.CreateTensor(opName_);
+        aclScalingInput_.push_back(aclnnScalingInput);
+
+        auto aclnnScalingWeight = CreateTensor(weight);
+        aclnnScalingWeight.CreateTensor(opName_);
+        aclScalingWeight_.push_back(aclnnScalingWeight);
+
+        ret = aclnnInplaceMulGetWorkspaceSize(aclScalingInput_.back().tensor,
+                                              aclScalingWeight_.back().tensor,
+                                              &aclScalingWorkspace_[i],
+                                              &aclScalingExecutor_[i]);
+        DICP_LOG(INFO) << opName_ << " Scaling workspace size[" << i << "]: " << aclScalingWorkspace_[i] << ", ret: " << ret;
+    }
+
+    // Calculate total workspace size
+    const uint64_t scalingMaxValue = aclScalingWorkspace_.empty() ? 0 : *std::max_element(aclScalingWorkspace_.begin(), aclScalingWorkspace_.end());
+    workspaceSize = std::max({loraAGroupedGemmWorkspace_, loraBGroupedGemmWorkspace_, scalingMaxValue});
+
+    DICP_LOG(INFO) << opName_ << " Setup completed, total workspace size: " << workspaceSize;
+    return ret;
+}
+
+int CustomFusedLoraOperation::Execute(const atb::VariantPack& variantPack, uint8_t* workspace, uint64_t workspaceSize, atb::Context* context) {
+    DICP_LOG(INFO) << opName_ << " execute start";
+    if (!context) {
+        DICP_LOG(ERROR) << opName_ << " execute fail, context param is null";
+        return atb::ERROR_INVALID_PARAM;
+    }
+
+    aclrtStream stream = context->GetExecuteStream();
+    if (!stream) {
+        DICP_LOG(ERROR) << opName_ << " execute fail, execute stream in context is null";
+        return atb::ERROR_INVALID_PARAM;
+    }
+
+    // transpose weightA
+    for (const auto& [adapterId, weightATransposeIndex] : weightATransposeIdMap_) {
+        int ret = aclnnPermute(workspace, aclWeightAPermuteWorkspace_[adapterId], aclWeightAPermuteExecutor_[adapterId], stream);
+        DICP_LOG(INFO) << opName_ << " aclnnPermute completed, ret: " << ret;
+    }
+
+    // Execute LoRA A grouped matrix multiplication
+    int ret = aclnnGroupedMatmulV4(workspace, loraAGroupedGemmWorkspace_, aclLoraAGroupedGemmExecutor_, stream);
+    DICP_LOG(INFO) << opName_ << " LoRA A grouped matmul completed, ret: " << ret;
+
+    // Execute LoRA B grouped matrix multiplication
+    ret = aclnnGroupedMatmulV4(workspace, loraBGroupedGemmWorkspace_, aclLoraBGroupedGemmExecutor_, stream);
+    DICP_LOG(INFO) << opName_ << " LoRA B grouped matmul completed, ret: " << ret;
+
+    // Execute scaling operations
+    for (size_t i = 0; i < aclScalingExecutor_.size(); ++i) {
+        ret = aclnnInplaceMul(workspace, aclScalingWorkspace_[i], aclScalingExecutor_[i], stream);
+        DICP_LOG(INFO) << opName_ << " Scaling operation[" << i << "] completed, ret: " << ret;
+    }
+
+    DICP_LOG(INFO) << opName_ << " execute end";
+    return 0;
+}
+
+atb::Operation* CustomFusedLoraOperationCreate(const nlohmann::json& paramJson) {
+    std::string opName;
+    std::string dtype;
+    if (paramJson.contains("name")) {
+        opName = paramJson["name"].get<std::string>();
+    }
+    if (paramJson.contains("dtype")) {
+        dtype = paramJson["dtype"].get<std::string>();
+    }
+    DICP_LOG(INFO) << "CustomFusedLoraOperationCreate: name: " << opName << " dtype:" << dtype;
+    atb::Operation* op = new CustomFusedLoraOperation(opName, dtype);
+    return op;
+}
+
+REGISTER_OPERATION(CustomFusedLoraOperation, CustomFusedLoraOperationCreate);
+
+}  // namespace dicp
diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.h b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.h
new file mode 100644
index 00000000..03fc0d9a
--- /dev/null
+++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.h
@@ -0,0 +1,68 @@
+#pragma once
+
+#include "ops/aclnn_ops/acl_nn_operation.h"
+#include
+#include
+
+namespace dicp {
+
+class CustomFusedLoraOperation : public atb::Operation {
+public:
+    explicit CustomFusedLoraOperation(const std::string& name, const std::string& dtype);
+    ~CustomFusedLoraOperation() override;
+
+    std::string GetName() const override;
+    atb::Status Setup(const atb::VariantPack& variantPack, uint64_t& workspaceSize, atb::Context* context) override;
+    atb::Status Execute(const atb::VariantPack& variantPack, uint8_t* workspace, uint64_t workspaceSize, atb::Context* context) override;
+    atb::Status InferShape(const atb::SVector<atb::TensorDesc>& inTensorDescs, atb::SVector<atb::TensorDesc>& outTensorDescs) const override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+private:
+    atb::SVector<AclNnTensor> aclInTensors_;
+    atb::SVector<AclNnTensor> aclOutTensors_;
+
+    AclNnTensor CreateTensor(const atb::Tensor& atbTensor);
+    int CreateAclTensors(const atb::VariantPack& variantPack);
+    void ClearAclScalrs();
+    void ClearInternal();
+
+    // Helper functions for weight tensor creation and offset calculation
+    atb::Tensor CreateWeightTensor(const atb::Tensor& baseTensor, int64_t rank, int64_t dim, uint64_t offset);
+    uint64_t CalculateWeightOffset(const std::vector<int32_t>& ranksVec, size_t adapterId, uint64_t tensorSizePerRank);
+
+private:
+    std::string opName_;
+    std::string dtype_;
+    std::vector<aclScalar*> aclScalingScalar_;
+
+    std::vector<atb::Tensor> weightA_;
+    std::vector<atb::Tensor> weightB_;
+    std::vector<atb::Tensor> weightATranspose_;
+
+    std::vector<AclNnTensor> aclWeightA_;
+    std::vector<AclNnTensor> aclWeightB_;
+    std::vector<AclNnTensor> aclWeightATranspose_;
+
+    // adapterId, weightA_index
+    std::unordered_map<int32_t, size_t> weightATransposeIdMap_;
+    std::unordered_map<int32_t, aclOpExecutor*> aclWeightAPermuteExecutor_;
+    std::unordered_map<int32_t, uint64_t> aclWeightAPermuteWorkspace_;
+
+    uint64_t loraAGroupedGemmWorkspace_ = 0;
+    uint64_t loraBGroupedGemmWorkspace_ = 0;
+
+    aclOpExecutor* aclLoraAGroupedGemmExecutor_ = nullptr;
+    aclOpExecutor* aclLoraBGroupedGemmExecutor_ = nullptr;
+
+    std::vector<atb::Tensor> scalingWeight_;
+    std::vector<atb::Tensor> scalingInput_;
+    std::vector<AclNnTensor> aclScalingInput_;
+    std::vector<AclNnTensor> aclScalingWeight_;
+
+    std::vector<uint64_t> aclScalingWorkspace_;
+    std::vector<aclOpExecutor*> aclScalingExecutor_;
+
+};
+
+}  // namespace dicp
diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/utils/operation_util.h b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/utils/operation_util.h
index b2e3612d..8ba10925 100644
--- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/utils/operation_util.h
+++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/utils/operation_util.h
@@ -1,6 +1,8 @@
 #pragma once
 
 #include
+#include
+#include
 
 namespace dicp {
 #define CREATE_OPERATION(param, operation) \
diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/conversion.py
b/dlinfer/graph/dicp/vendor/AtbGraph/conversion.py index 3acf085b..e1551b1c 100644 --- a/dlinfer/graph/dicp/vendor/AtbGraph/conversion.py +++ b/dlinfer/graph/dicp/vendor/AtbGraph/conversion.py @@ -812,11 +812,11 @@ def aten_new_empty(self, x, size, pin_memory=False): def aten_slice_scatter(self, x, data, dim=0, start=None, end=None, step=1): return self.get_proxy(atb_op.SliceScatter, (x, data, dim, start, end, step)) - @register_conversion(torch.ops.dlinfer.dynamic_quant.default) + @register_conversion("torch.ops.dlinfer.dynamic_quant.default") def dlinfer_dynamic_quant(self, x, quant_dtype, quant_granularity): return self.get_proxy(atb_op.AclNnDynamicQuant, (x, quant_dtype)) - @register_conversion(torch.ops.dlinfer.linear_w8a8.default) + @register_conversion("torch.ops.dlinfer.linear_w8a8.default") def dlinfer_linear_w8a8( self, x, y, rms_scale, linear_scale, out_type, quant_dtype, bias ): @@ -861,10 +861,91 @@ def aten_topk(self, x, k, dim=-1, largest=True, sorted=True): assert largest == True return self.get_proxy(atb_op.Sort, (x, k)) - @register_conversion(torch.ops.dlinfer.transdata.default) + @register_conversion("torch.ops.dlinfer.transdata.default") def dlinfer_transdata(self, x, transdata_type): return self.get_proxy(atb_op.Transdata, (x, transdata_type)) + @register_conversion("torch.ops.dlinfer.fused_lora.default") + def dlinfer_fused_lora( + self, + x, + lora_a, + lora_b, + scaling, + rank_start, + ranks, + seq_start, + seq_lens, + adapter_ids, + max_rank, + max_seqlen, + slice_start, + slice_stop, + slice_step, + output, + ): + assert slice_step is None, "slice_step is not supported yet" + total_len = x.node.meta["val"].shape[0] + out_dim = lora_b.node.meta["val"].shape[1] + + output_need_view = False + if output is not None: + assert ( + output.node.meta["val"].shape[-1] == out_dim + ), "lora output slice is not supported yet." + if len(output.node.meta["val"].shape) == 3: + output = self.get_proxy( + atb_op.View, (output, list(map(str, [total_len, out_dim]))) + ) + output_need_view = True + + # concat 0 rank + lora_a_dim = lora_a.node.meta["val"].shape[1] + rank0_zero_a = self.get_proxy( + atb_op.Zeros, + ( + list(map(str, [1, lora_a_dim])), + lora_b.node.meta["val"].dtype, + [1, lora_a_dim], + ), + ) + rank0_zero_b = self.get_proxy( + atb_op.Zeros, + ( + list(map(str, [1, out_dim])), + lora_b.node.meta["val"].dtype, + [1, out_dim], + ), + ) + concated_lora_a = self.get_proxy(atb_op.Concat, ([rank0_zero_a, lora_a], 0)) + concated_lora_b = self.get_proxy(atb_op.Concat, ([rank0_zero_b, lora_b], 0)) + + casted_scaling = self.get_proxy( + atb_op.Cast, (scaling, lora_b.node.meta["val"].dtype) + ) + + fused_lora = self.get_proxy( + atb_op.CustomFusedLora, + ( + x, + concated_lora_a, + concated_lora_b, + casted_scaling, + ranks, + seq_lens, + adapter_ids, + lora_b.node.meta["val"].dtype, + ), + ) + out = self.get_proxy(atb_op.GetItem, (fused_lora, 0)) + if output is not None: + out = self.get_proxy(atb_op.Add, (output, out)) + if output_need_view: + out = self.get_proxy( + atb_op.View, (out, list(map(str, [1, total_len, out_dim]))) + ) + return out + class ViewSymIntTransformer(torch.fx.Transformer): def call_function(self, target, args, kwargs): diff --git a/dlinfer/ops/llm.py b/dlinfer/ops/llm.py index 904df29d..6e069abd 100644 --- a/dlinfer/ops/llm.py +++ b/dlinfer/ops/llm.py @@ -27,6 +27,7 @@ "rms_norm_w8a8", "add_rms_norm_w8a8", "transdata", + "fused_lora", ] @@ -837,3 +838,71 @@ def transdata( Tensor : A tensor in target format. 
""" return vendor_ops_registry["transdata"](hidden_states, transdata_type) + + +def fused_lora_impl_abstract_func( + input: Tensor, + lora_a: Tensor, + lora_b: Tensor, + scaling: Tensor, + rank_start: Tensor, + ranks: Tensor, + seq_start: Tensor, + seq_lens: Tensor, + adapter_ids: Tensor, + max_rank: int, + max_seqlen: int, + slice_start: int, + slice_stop: int, + slice_step: Optional[int], + output: Optional[Tensor], +) -> Tensor: + M, K = input.shape + N = lora_b.size(1) + if output is None: + output = input.new_empty((M, N), dtype=input.dtype, device=input.device) + return output + + +@register_custom_op( + "dlinfer::fused_lora", + impl_abstract_func=fused_lora_impl_abstract_func, + default_value={"output": None}, +) +def fused_lora( + input: Tensor, + lora_a: Tensor, + lora_b: Tensor, + scaling: Tensor, + rank_start: Tensor, + ranks: Tensor, + seq_start: Tensor, + seq_lens: Tensor, + adapter_ids: Tensor, + max_rank: int, + max_seqlen: int, + slice_start: int, + slice_stop: int, + slice_step: Optional[int], + output: Optional[Tensor], +) -> Tensor: + """ + Fused lora. + """ + return vendor_ops_registry["fused_lora"]( + input, + lora_a, + lora_b, + scaling, + rank_start, + ranks, + seq_start, + seq_lens, + adapter_ids, + max_rank, + max_seqlen, + slice_start, + slice_stop, + slice_step, + output, + ) diff --git a/dlinfer/vendor/ascend/torch_npu_ops.py b/dlinfer/vendor/ascend/torch_npu_ops.py index b9035d86..bf30cb24 100644 --- a/dlinfer/vendor/ascend/torch_npu_ops.py +++ b/dlinfer/vendor/ascend/torch_npu_ops.py @@ -20,6 +20,7 @@ "weight_quant_matmul", "fused_moe", "linear", + "fused_lora", ] @@ -547,3 +548,57 @@ def transdata( transdata_type: int, ): raise NotImplementedError("transdata in eager mode is not implemented yet!") + + +@register_ops(vendor_ops_registry) +def fused_lora( + input: Tensor, + lora_a: Tensor, + lora_b: Tensor, + scaling: Tensor, + rank_start: Tensor, + ranks: Tensor, + seq_start: Tensor, + seq_lens: Tensor, + adapter_ids: Tensor, + max_rank: int, + max_seqlen: int, + slice_start: int, + slice_stop: int, + slice_step: Optional[int], + output: Optional[Tensor], +) -> Tensor: + total_len = input.size(0) + out_dim = lora_b.size(1) + + if output is not None: + output = output.contiguous() + base_slice = slice(slice_start, slice_stop, slice_step) + output = output[..., base_slice] + output = output.flatten(0, -2) + else: + output = torch.zeros(total_len, out_dim, dtype=input.dtype, device=input.device) + + num_seqs = seq_lens.size(0) + for i in range(num_seqs): + adapter_id = adapter_ids[i].item() + rank = ranks[adapter_id].item() + offset = rank_start[adapter_id].item() + scale = scaling[adapter_id].item() + + start_idx = seq_start[i].item() + seq_len = seq_lens[i].item() + + if rank == 0: + continue + + seq_input = input[start_idx : start_idx + seq_len] + A = lora_a[offset : offset + rank] + B = lora_b[offset : offset + rank] + + intermediate = torch.matmul(seq_input, A.t()) + lora_out = torch.matmul(intermediate, B) + lora_out = lora_out * scale + + output[start_idx : start_idx + seq_len] += lora_out + return output From 2a4cdec3264cde58c2d9643ebc6fd9771d98ef52 Mon Sep 17 00:00:00 2001 From: tangzhiyi11 Date: Wed, 2 Jul 2025 03:33:12 +0000 Subject: [PATCH 2/5] [ascend] update torch version --- cmake/ascend.cmake | 7 +++--- .../graph/dicp/dynamo_bridge/torch_version.py | 7 ++++++ dlinfer/vendor/ascend/CMakeLists.txt | 12 ++++++---- .../ascend/csrc/torch_npu_symbol_fix.cpp | 24 ++++++++++++++----- requirements/ascend/torch.txt | 8 ++++--- 5 files changed, 42 
insertions(+), 16 deletions(-) diff --git a/cmake/ascend.cmake b/cmake/ascend.cmake index bd6363a3..aecbbd11 100644 --- a/cmake/ascend.cmake +++ b/cmake/ascend.cmake @@ -17,10 +17,11 @@ execute_process( ) execute_process( - COMMAND python -c "import torch; from packaging import version; \ + COMMAND python -c "import torch; from packaging import version; \ torch_version = version.parse(torch.__version__).base_version; \ - print('1' if version.parse(torch_version) > version.parse('2.3.1') else '0', end='')" - OUTPUT_VARIABLE Torch_npu_VERSION_HIGHER_THAN_231 + print(torch_version, end='')" + OUTPUT_VARIABLE TORCH_VERSION + OUTPUT_STRIP_TRAILING_WHITESPACE ) find_package(Torch REQUIRED) diff --git a/dlinfer/graph/dicp/dynamo_bridge/torch_version.py b/dlinfer/graph/dicp/dynamo_bridge/torch_version.py index c2ce1d12..cb71c313 100644 --- a/dlinfer/graph/dicp/dynamo_bridge/torch_version.py +++ b/dlinfer/graph/dicp/dynamo_bridge/torch_version.py @@ -8,6 +8,9 @@ is_torch_220 = False is_torch_231 = False is_torch_251 = False +is_torch_260 = False +is_torch_271 = False + if torch_version.startswith("2.0"): is_torch_200 = True @@ -19,6 +22,10 @@ is_torch_231 = True elif torch_version.startswith("2.5.1"): is_torch_251 = True +elif torch_version.startswith("2.6.0"): + is_torch_260 = True +elif torch_version.startswith("2.7.1"): + is_torch_271 = True else: raise ValueError(f"unsupported dicp torch version: {torch.__version__}") diff --git a/dlinfer/vendor/ascend/CMakeLists.txt b/dlinfer/vendor/ascend/CMakeLists.txt index e99b7d1e..d77ccdd8 100644 --- a/dlinfer/vendor/ascend/CMakeLists.txt +++ b/dlinfer/vendor/ascend/CMakeLists.txt @@ -9,12 +9,9 @@ set(CSRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/csrc/moe_gating_topk_softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/csrc/op_api_common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/csrc/torch_npu_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/torch_npu_symbol_fix.cpp ) -if("${Torch_npu_VERSION_HIGHER_THAN_231}" STREQUAL "1") - list(APPEND CSRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/csrc/torch_npu_symbol_fix.cpp) -endif() - add_library( ${PROJECT_NAME} SHARED ${CSRC_FILES} @@ -30,6 +27,13 @@ target_compile_definitions( GLIBCXX_USE_CXX11_ABI=${_GLIBCXX_USE_CXX11_ABI} ) +if(TORCH_VERSION STREQUAL "2.7.1") + target_compile_definitions( + ${PROJECT_NAME} PUBLIC + TORCH_VERSION_2_7_1 + ) +endif() + target_include_directories( ${PROJECT_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} diff --git a/dlinfer/vendor/ascend/csrc/torch_npu_symbol_fix.cpp b/dlinfer/vendor/ascend/csrc/torch_npu_symbol_fix.cpp index 4ad0f8a1..98b57b08 100644 --- a/dlinfer/vendor/ascend/csrc/torch_npu_symbol_fix.cpp +++ b/dlinfer/vendor/ascend/csrc/torch_npu_symbol_fix.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -12,7 +13,6 @@ namespace acl { // These functions are reimplemented to handle the missing symbol issue in // torch-npu >= 2.3.1. If these functions are called, it indicates an environment // setup issue and the program should terminate - aclError AclrtPeekAtLastError(aclrtLastErrLevel flag) { throw std::runtime_error( "Dlinfer AclrtPeekAtLastError should not be called. " @@ -21,24 +21,36 @@ aclError AclrtPeekAtLastError(aclrtLastErrLevel flag) { } } // namespace acl -bool checkUceErrAndRepair() { +void record_mem_hbm_ecc_error() { + throw std::runtime_error( + "Dlinfer record_mem_hbm_ecc_error should not be called. 
" + "Please check your environment setup."); +} + +#if !defined(TORCH_VERSION_2_7_1) +bool checkUceErrAndRepair(bool check_error, std::string& err_msg) { throw std::runtime_error( "Dlinfer checkUceErrAndRepair should not be called. " "Please check your environment setup."); return false; } +#endif -void checkUceErrAndRepair(bool tf, std::string& str) { +#if defined(TORCH_VERSION_2_7_1) +namespace option { +bool OptionsManager::ShouldPrintLessError() { throw std::runtime_error( - "Dlinfer checkUceErrAndRepair should not be called. " + "Dlinfer record_mem_hbm_ecc_error should not be called. " "Please check your environment setup."); - return; } +} // namespace option -void record_mem_hbm_ecc_error() { +std::string handleDeviceError(int errorCode) { throw std::runtime_error( "Dlinfer record_mem_hbm_ecc_error should not be called. " "Please check your environment setup."); + return ""; } +#endif } // namespace c10_npu diff --git a/requirements/ascend/torch.txt b/requirements/ascend/torch.txt index dce6e168..c6ed7a11 100644 --- a/requirements/ascend/torch.txt +++ b/requirements/ascend/torch.txt @@ -1,5 +1,7 @@ -torch==2.3.1 -torchvision==0.18.1 -torch-npu==2.3.1 +# Supported torch versions: 2.3.1, 2.5.1, 2.6.0, 2.7.1 +# Please install one of the supported versions manually +torch>=2.3.1,<2.8.0 +torch-npu>=2.3.1,<2.8.0 +torchvision>=0.18.1,<0.23.0 numpy<2.0.0 pyyaml From 77c7e1f76fd8377f3be48189afa26aec45b1ac99 Mon Sep 17 00:00:00 2001 From: tangzhiyi11 Date: Mon, 7 Jul 2025 03:33:32 +0000 Subject: [PATCH 3/5] add compile option when cxx_abi_0 --- cmake/FindATB.cmake | 4 ++++ .../dicp/vendor/AtbGraph/codegen/runtime/CMakeLists.txt | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/cmake/FindATB.cmake b/cmake/FindATB.cmake index a82c7fcd..497fc30e 100644 --- a/cmake/FindATB.cmake +++ b/cmake/FindATB.cmake @@ -8,6 +8,10 @@ else() CACHE STRING "atb toolkit default home") endif() +# Extract cxx_abi version from ATB_HOME_PATH (last path component) +get_filename_component(ATB_CXX_ABI_VERSION ${ATB_HOME_PATH} NAME) +message(STATUS "ATB_CXX_ABI_VERSION: ${ATB_CXX_ABI_VERSION}") + # Include directories. 
find_path(ATB_INCLUDE_DIRS NAMES atb/atb_infer.h diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/CMakeLists.txt b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/CMakeLists.txt index aa18de95..e9a453d2 100644 --- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/CMakeLists.txt +++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/CMakeLists.txt @@ -21,6 +21,11 @@ set(COMPILE_OPTIONS -Wno-attributes ) +# Add CXX11 ABI flag based on ATB_CXX_ABI_VERSION +if(ATB_CXX_ABI_VERSION STREQUAL "cxx_abi_0") + list(APPEND COMPILE_OPTIONS -D_GLIBCXX_USE_CXX11_ABI=0) +endif() + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O2") add_library(dicp_model SHARED ${SOURCES}) From f64cd1904fe097e79639b12b2d599e468d62d926 Mon Sep 17 00:00:00 2001 From: tangzhiyi11 Date: Mon, 7 Jul 2025 03:35:31 +0000 Subject: [PATCH 4/5] code format --- .../ops/custom_ops/fused_lora_operation.cpp | 127 +++++++++--------- .../ops/custom_ops/fused_lora_operation.h | 8 +- 2 files changed, 64 insertions(+), 71 deletions(-) diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.cpp b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.cpp index 48ea6350..1c3bd8a2 100644 --- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.cpp +++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.cpp @@ -1,11 +1,13 @@ #include "fused_lora_operation.h" +#include + +#include #include #include -#include -#include "aclnnop/aclnn_mul.h" #include "aclnnop/aclnn_grouped_matmul_v4.h" +#include "aclnnop/aclnn_mul.h" #include "aclnnop/aclnn_permute.h" #include "ops/operation_creator.h" #include "third_party/acl/inc/acl/acl_base.h" @@ -13,8 +15,6 @@ #include "utils/log.h" #include "utils/scalar.h" -#include - namespace dicp { const int NUM1 = 1; @@ -72,10 +72,10 @@ int CustomFusedLoraOperation::CreateAclTensors(const atb::VariantPack& variantPa const size_t inTensorCount = variantPack.inTensors.size(); const size_t outTensorCount = variantPack.outTensors.size(); - + aclInTensors_.resize(inTensorCount); aclOutTensors_.resize(outTensorCount); - + for (size_t i = 0; i < inTensorCount; ++i) { aclInTensors_[i] = CreateTensor(variantPack.inTensors.at(i)); } @@ -104,7 +104,7 @@ void CustomFusedLoraOperation::ClearInternal() { aclWeightATranspose_.clear(); weightA_.clear(); weightB_.clear(); - weightATranspose_.clear(); + weightATranspose_.clear(); aclScalingInput_.clear(); scalingInput_.clear(); @@ -183,7 +183,7 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_ const int64_t loraBDim = variantPack.inTensors.at(2).desc.shape.dims[1]; ClearInternal(); - + // Pre-allocate vectors to avoid reallocations weightA_.reserve(adapterIdsVec.size()); weightATranspose_.reserve(adapterIdsVec.size()); @@ -198,7 +198,6 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_ aclScalingWeight_.reserve(adapterIdsVec.size()); aclScalingInput_.reserve(adapterIdsVec.size()); - bool singleInfer = adapterIdsVec.size() == 1; int32_t totalRanks = 0; @@ -284,7 +283,7 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_ } else { permuteDims = {1, 0}; } - aclIntArray *permuteDimsArray = aclCreateIntArray(permuteDims.data(), permuteDims.size()); + aclIntArray* permuteDimsArray = aclCreateIntArray(permuteDims.data(), permuteDims.size()); for (const auto& [adapterId, weightATransposeIndex] : weightATransposeIdMap_) { 
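+        // Why the permute (inferred from the surrounding code): the packed
+        // LoRA A blocks are stored as [rank, K], while GroupedMatmulV4 consumes
+        // them as [K, rank] ([1, K, rank] in the single-adapter case), so each
+        // adapter gets one aclnnPermute pass prepared here.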
aclWeightAPermuteExecutor_[adapterId] = nullptr; aclWeightAPermuteWorkspace_[adapterId] = 0; @@ -292,18 +291,14 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_ auto& weightA = aclWeightA_[weightATransposeIndex]; auto& weightATranspose = aclWeightATranspose_[weightATransposeIndex]; - - int ret = aclnnPermuteGetWorkspaceSize(weightA.tensor, - permuteDimsArray, - weightATranspose.tensor, - &aclWeightAPermuteWorkspace_[adapterId], - &aclWeightAPermuteExecutor_[adapterId]); + int ret = aclnnPermuteGetWorkspaceSize( + weightA.tensor, permuteDimsArray, weightATranspose.tensor, &aclWeightAPermuteWorkspace_[adapterId], &aclWeightAPermuteExecutor_[adapterId]); DICP_LOG(INFO) << opName_ << " aclnnPermuteGetWorkspaceSize size[" << adapterId << "]: " << aclWeightAPermuteWorkspace_[adapterId] << ", ret: " << ret; } // Setup grouped matrix multiplication DICP_LOG(INFO) << opName_ << " Setting up grouped matrix multiplication"; - + // Create input tensor list std::vector xTmp; if (singleInfer) { @@ -317,7 +312,7 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_ slicedInput.desc.format = aclInTensors_.at(0).atbTensor.desc.format; slicedInput.desc.shape.dimNum = aclInTensors_.at(0).atbTensor.desc.shape.dimNum; slicedInput.desc.shape.dims[0] = seqLensVec[i]; - slicedInput.desc.shape.dims[1] = aclInTensors_.at(0).atbTensor.desc.shape.dims[1]; + slicedInput.desc.shape.dims[1] = aclInTensors_.at(0).atbTensor.desc.shape.dims[1]; slicedInput.dataSize = atb::Utils::GetTensorSize(slicedInput.desc); auto offset = CalculateWeightOffset(seqLensVec, i, slicedInput.dataSize / seqLensVec[i]); @@ -338,14 +333,14 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_ std::vector weightTmpB; weightTmpA.reserve(aclWeightATranspose_.size()); weightTmpB.reserve(aclWeightB_.size()); - + for (const auto& weight : aclWeightATranspose_) { weightTmpA.push_back(weight.tensor); } for (const auto& weight : aclWeightB_) { weightTmpB.push_back(weight.tensor); } - + aclTensorList* weightTensorListA = aclCreateTensorList(weightTmpA.data(), weightTmpA.size()); aclTensorList* weightTensorListB = aclCreateTensorList(weightTmpB.data(), weightTmpB.size()); @@ -363,7 +358,7 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_ loraASliceOutput.desc.format = aclOutTensors_.at(1).atbTensor.desc.format; loraASliceOutput.desc.shape.dimNum = aclOutTensors_.at(1).atbTensor.desc.shape.dimNum; loraASliceOutput.desc.shape.dims[0] = aclOutTensors_.at(1).atbTensor.desc.shape.dims[0]; - loraASliceOutput.desc.shape.dims[1] = totalRanks / adapterIdsVec.size(); + loraASliceOutput.desc.shape.dims[1] = totalRanks / adapterIdsVec.size(); loraASliceOutput.dataSize = atb::Utils::GetTensorSize(loraASliceOutput.desc); loraASliceOutput.deviceData = aclOutTensors_.at(1).atbTensor.deviceData; auto aclnnLoraASliceOutput = CreateTensor(loraASliceOutput); @@ -378,7 +373,7 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_ slicedOutput.desc.format = aclOutTensors_.at(1).atbTensor.desc.format; slicedOutput.desc.shape.dimNum = aclOutTensors_.at(1).atbTensor.desc.shape.dimNum; slicedOutput.desc.shape.dims[0] = seqLensVec[i]; - slicedOutput.desc.shape.dims[1] = ranksVec[adapterIdsVec[i]]; + slicedOutput.desc.shape.dims[1] = ranksVec[adapterIdsVec[i]]; slicedOutput.dataSize = atb::Utils::GetTensorSize(slicedOutput.desc); auto offset = CalculateWeightOffset(seqLensVec, i, slicedOutput.dataSize / seqLensVec[i]); @@ 
-398,59 +393,59 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_ DICP_LOG(ERROR) << opName_ << " Failed to create output tensor lists"; return -1; } - + // Setup LoRA A grouped matrix multiplication - int ret = aclnnGroupedMatmulV4GetWorkspaceSize(xTensorList, // x - weightTensorListA, // weight - nullptr, // biasOptional - nullptr, // scaleOptional - nullptr, // offsetOptional - nullptr, // antiquantScaleOptional - nullptr, // antiquantOffsetOptional - nullptr, // perTokenScaleOptional - singleInfer ? aclInTensors_.at(5).tensor : nullptr, // groupListOptional - nullptr, // activationInputOptional - nullptr, // activationQuantScaleOptional - nullptr, // activationQuantOffsetOptional - singleInfer ? 2 : 0, // splitItem - singleInfer ? 0 : -1, // groupType - 1, // groupListType - 0, // actType - loraAOutTensorList, // out - nullptr, // activationFeatureOutOptional - nullptr, // dynQuantScaleOutOptional + int ret = aclnnGroupedMatmulV4GetWorkspaceSize(xTensorList, // x + weightTensorListA, // weight + nullptr, // biasOptional + nullptr, // scaleOptional + nullptr, // offsetOptional + nullptr, // antiquantScaleOptional + nullptr, // antiquantOffsetOptional + nullptr, // perTokenScaleOptional + singleInfer ? aclInTensors_.at(5).tensor : nullptr, // groupListOptional + nullptr, // activationInputOptional + nullptr, // activationQuantScaleOptional + nullptr, // activationQuantOffsetOptional + singleInfer ? 2 : 0, // splitItem + singleInfer ? 0 : -1, // groupType + 1, // groupListType + 0, // actType + loraAOutTensorList, // out + nullptr, // activationFeatureOutOptional + nullptr, // dynQuantScaleOutOptional &loraAGroupedGemmWorkspace_, &aclLoraAGroupedGemmExecutor_); DICP_LOG(INFO) << opName_ << " LoRA A grouped matmul workspace size: " << loraAGroupedGemmWorkspace_ << ", ret: " << ret; // Setup LoRA B grouped matrix multiplication - ret = aclnnGroupedMatmulV4GetWorkspaceSize(loraAOutTensorList, // x - weightTensorListB, // weight - nullptr, // biasOptional - nullptr, // scaleOptional - nullptr, // offsetOptional - nullptr, // antiquantScaleOptional - nullptr, // antiquantOffsetOptional - nullptr, // perTokenScaleOptional - aclInTensors_.at(5).tensor, // groupListOptional - nullptr, // activationInputOptional - nullptr, // activationQuantScaleOptional - nullptr, // activationQuantOffsetOptional - 2, // splitItem - 0, // groupType - 1, // groupListType - 0, // actType - loraBOutTensorList, // out - nullptr, // activationFeatureOutOptional - nullptr, // dynQuantScaleOutOptional + ret = aclnnGroupedMatmulV4GetWorkspaceSize(loraAOutTensorList, // x + weightTensorListB, // weight + nullptr, // biasOptional + nullptr, // scaleOptional + nullptr, // offsetOptional + nullptr, // antiquantScaleOptional + nullptr, // antiquantOffsetOptional + nullptr, // perTokenScaleOptional + aclInTensors_.at(5).tensor, // groupListOptional + nullptr, // activationInputOptional + nullptr, // activationQuantScaleOptional + nullptr, // activationQuantOffsetOptional + 2, // splitItem + 0, // groupType + 1, // groupListType + 0, // actType + loraBOutTensorList, // out + nullptr, // activationFeatureOutOptional + nullptr, // dynQuantScaleOutOptional &loraBGroupedGemmWorkspace_, &aclLoraBGroupedGemmExecutor_); DICP_LOG(INFO) << opName_ << " LoRA B grouped matmul workspace size: " << loraBGroupedGemmWorkspace_ << ", ret: " << ret; - + // Setup scaling operations aclScalingWorkspace_.resize(adapterIdsVec.size()); aclScalingExecutor_.resize(adapterIdsVec.size()); - + for (size_t i = 0; i < 
adapterIdsVec.size(); ++i) { const int32_t adapterId = adapterIdsVec[i]; const auto& inputAtbTensor = aclOutTensors_.at(0).atbTensor; @@ -494,10 +489,8 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_ aclnnScalingWeight.CreateTensor(opName_); aclScalingWeight_.push_back(aclnnScalingWeight); - ret = aclnnInplaceMulGetWorkspaceSize(aclScalingInput_.back().tensor, - aclScalingWeight_.back().tensor, - &aclScalingWorkspace_[i], - &aclScalingExecutor_[i]); + ret = + aclnnInplaceMulGetWorkspaceSize(aclScalingInput_.back().tensor, aclScalingWeight_.back().tensor, &aclScalingWorkspace_[i], &aclScalingExecutor_[i]); DICP_LOG(INFO) << opName_ << " Scaling workspace size[" << i << "]: " << aclScalingWorkspace_[i] << ", ret: " << ret; } diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.h b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.h index 03fc0d9a..1534e4fc 100644 --- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.h +++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.h @@ -1,8 +1,9 @@ #pragma once -#include "ops/aclnn_ops/acl_nn_operation.h" -#include #include +#include + +#include "ops/aclnn_ops/acl_nn_operation.h" namespace dicp { @@ -26,7 +27,7 @@ class CustomFusedLoraOperation : public atb::Operation { int CreateAclTensors(const atb::VariantPack& variantPack); void ClearAclScalrs(); void ClearInternal(); - + // Helper functions for weight tensor creation and offset calculation atb::Tensor CreateWeightTensor(const atb::Tensor& baseTensor, int64_t rank, int64_t dim, uint64_t offset); uint64_t CalculateWeightOffset(const std::vector& ranksVec, size_t adapterId, uint64_t tensorSizePerRank); @@ -62,7 +63,6 @@ class CustomFusedLoraOperation : public atb::Operation { std::vector aclScalingWorkspace_; std::vector aclScalingExecutor_; - }; } // namespace dicp From e264503fc70cbb44b7d0542b102aa5dc3db54623 Mon Sep 17 00:00:00 2001 From: tangzhiyi11 Date: Tue, 8 Jul 2025 03:14:13 +0000 Subject: [PATCH 5/5] update codes --- dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py | 8 ++++++- .../dicp/vendor/AtbGraph/codegen/atb_op.py | 1 - .../ops/custom_ops/fused_lora_operation.cpp | 22 ------------------- .../ops/custom_ops/fused_lora_operation.h | 6 ----- 4 files changed, 7 insertions(+), 30 deletions(-) diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py b/dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py index 0de70b3b..677be300 100644 --- a/dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py +++ b/dlinfer/graph/dicp/vendor/AtbGraph/atb_op.py @@ -835,9 +835,15 @@ def infer_result( self, x, lora_a, lora_b, scaling, ranks, seq_lens, adapter_ids, dtype ): M, K = x.shape + ranks = lora_a.size(0) N = lora_b.size(1) output = torch.empty((M, N), dtype=x.dtype, device=x.device) - return output, output + # assume totalRank is the max rank + internal_output_x_lora_a = torch.empty( + (M, ranks * M), dtype=x.dtype, device=x.device + ) + internal_lora_a_transpose = torch.empty_like(lora_a) + return output, internal_output_x_lora_a, internal_lora_a_transpose class AclNnInplaceAdd(Operator): diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py index 4cb10bc0..ea1084be 100644 --- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py +++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/atb_op.py @@ -1251,7 +1251,6 @@ def CustomFusedLora( name, x, lora_a, lora_b,
scaling, ranks, seq_lens, adapter_ids, dtype ): op = Operation(name, "CustomFusedLoraOperation") - # TODO: add param param = infer_param.CustomFusedLoraParam() param.name = name param.dtype = get_ascend_dtype(dtype) diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.cpp b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.cpp index 1c3bd8a2..50ee8a27 100644 --- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.cpp +++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.cpp @@ -102,9 +102,6 @@ void CustomFusedLoraOperation::ClearInternal() { aclWeightA_.clear(); aclWeightB_.clear(); aclWeightATranspose_.clear(); - weightA_.clear(); - weightB_.clear(); - weightATranspose_.clear(); aclScalingInput_.clear(); scalingInput_.clear(); @@ -115,19 +112,6 @@ void CustomFusedLoraOperation::ClearInternal() { aclScalingExecutor_.clear(); } -// Helper function to create weight tensor -atb::Tensor CustomFusedLoraOperation::CreateWeightTensor(const atb::Tensor& baseTensor, int64_t rank, int64_t dim, uint64_t offset) { - atb::Tensor weightTensor; - weightTensor.desc.dtype = baseTensor.desc.dtype; - weightTensor.desc.format = baseTensor.desc.format; - weightTensor.desc.shape.dimNum = baseTensor.desc.shape.dimNum; - weightTensor.desc.shape.dims[0] = rank; - weightTensor.desc.shape.dims[1] = dim; - weightTensor.dataSize = atb::Utils::GetTensorSize(weightTensor.desc); - weightTensor.deviceData = static_cast(baseTensor.deviceData) + offset; - return weightTensor; -} - // Helper function to calculate offset for weight tensors uint64_t CustomFusedLoraOperation::CalculateWeightOffset(const std::vector& ranksVec, size_t adapterId, uint64_t tensorSizePerRank) { uint64_t offset = 0; @@ -183,12 +167,6 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_ const int64_t loraBDim = variantPack.inTensors.at(2).desc.shape.dims[1]; ClearInternal(); - - // Pre-allocate vectors to avoid reallocations - weightA_.reserve(adapterIdsVec.size()); - weightATranspose_.reserve(adapterIdsVec.size()); - weightB_.reserve(adapterIdsVec.size()); - aclWeightA_.reserve(adapterIdsVec.size()); aclWeightB_.reserve(adapterIdsVec.size()); aclWeightATranspose_.reserve(adapterIdsVec.size()); diff --git a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.h b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.h index 1534e4fc..0bfed49c 100644 --- a/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.h +++ b/dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/ops/custom_ops/fused_lora_operation.h @@ -28,8 +28,6 @@ class CustomFusedLoraOperation : public atb::Operation { void ClearAclScalrs(); void ClearInternal(); - // Helper functions for weight tensor creation and offset calculation - atb::Tensor CreateWeightTensor(const atb::Tensor& baseTensor, int64_t rank, int64_t dim, uint64_t offset); uint64_t CalculateWeightOffset(const std::vector& ranksVec, size_t adapterId, uint64_t tensorSizePerRank); private: @@ -37,10 +35,6 @@ class CustomFusedLoraOperation : public atb::Operation { std::string dtype_; std::vector aclScalingScalar_; - std::vector weightA_; - std::vector weightB_; - std::vector weightATranspose_; - std::vector aclWeightA_; std::vector aclWeightB_; std::vector aclWeightATranspose_;
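
For reviewers: the computation that CustomFusedLoraOperation assembles out of aclnnPermute, two aclnnGroupedMatmulV4 calls, and aclnnInplaceMul corresponds to per-request LoRA — slice the tokens of each request, multiply by that request's adapter A weights (transposed), scale the low-rank activation, then multiply by adapter B. The eager sketch below is illustrative only; the function name and the stacked-weight layout (lora_a as [sum(ranks), K], lora_b as [sum(ranks), N], one scaling factor per adapter) are assumptions inferred from the shapes used in this patch, not the runtime API:

    import torch

    def fused_lora_reference(x, lora_a, lora_b, scaling, ranks, seq_lens, adapter_ids):
        """Hypothetical eager-mode reference for CustomFusedLora.

        Assumed shapes: x [M, K]; lora_a [sum(ranks), K]; lora_b [sum(ranks), N];
        ranks/seq_lens/adapter_ids are Python lists describing adapter sizes,
        per-request token counts, and per-request adapter selection.
        """
        # Row offset of each adapter inside the stacked A/B weights.
        offsets = [0]
        for r in ranks[:-1]:
            offsets.append(offsets[-1] + r)
        outputs, tok = [], 0
        for sl, aid in zip(seq_lens, adapter_ids):
            xs = x[tok : tok + sl]                                # tokens of this request
            a = lora_a[offsets[aid] : offsets[aid] + ranks[aid]]  # [r, K]
            b = lora_b[offsets[aid] : offsets[aid] + ranks[aid]]  # [r, N]
            # Mirrors the runtime stages: permute A, grouped matmul (x @ A^T),
            # in-place scaling of the low-rank activation, grouped matmul with B.
            outputs.append((xs @ a.T) * scaling[aid] @ b)
            tok += sl
        return torch.cat(outputs, dim=0)

Under this reading, the two extra tensors returned by infer_result (internal_output_x_lora_a and internal_lora_a_transpose) appear to exist only so the graph compiler can pre-allocate the low-rank intermediate and the transposed A buffer; the first output is the user-visible [M, N] result.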