@@ -1,20 +1,20 @@
 #include "fused_lora_operation.h"
 
+#include <atb/utils.h>
+
+#include <algorithm>
 #include <cstdint>
 #include <unordered_set>
-#include <algorithm>
 
-#include "aclnnop/aclnn_mul.h"
 #include "aclnnop/aclnn_grouped_matmul_v4.h"
+#include "aclnnop/aclnn_mul.h"
 #include "aclnnop/aclnn_permute.h"
 #include "ops/operation_creator.h"
 #include "third_party/acl/inc/acl/acl_base.h"
 #include "utils/common.h"
 #include "utils/log.h"
 #include "utils/scalar.h"
 
-#include <atb/utils.h>
-
 namespace dicp {
 
 const int NUM1 = 1;
@@ -72,10 +72,10 @@ int CustomFusedLoraOperation::CreateAclTensors(const atb::VariantPack& variantPa
 
     const size_t inTensorCount = variantPack.inTensors.size();
     const size_t outTensorCount = variantPack.outTensors.size();
-
+
     aclInTensors_.resize(inTensorCount);
     aclOutTensors_.resize(outTensorCount);
-
+
     for (size_t i = 0; i < inTensorCount; ++i) {
         aclInTensors_[i] = CreateTensor(variantPack.inTensors.at(i));
     }
@@ -104,7 +104,7 @@ void CustomFusedLoraOperation::ClearInternal() {
     aclWeightATranspose_.clear();
     weightA_.clear();
     weightB_.clear();
-    weightATranspose_.clear();
+    weightATranspose_.clear();
 
     aclScalingInput_.clear();
     scalingInput_.clear();
@@ -183,7 +183,7 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_
     const int64_t loraBDim = variantPack.inTensors.at(2).desc.shape.dims[1];
 
     ClearInternal();
-
+
     // Pre-allocate vectors to avoid reallocations
     weightA_.reserve(adapterIdsVec.size());
     weightATranspose_.reserve(adapterIdsVec.size());
@@ -198,7 +198,6 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_
     aclScalingWeight_.reserve(adapterIdsVec.size());
     aclScalingInput_.reserve(adapterIdsVec.size());
 
-
     bool singleInfer = adapterIdsVec.size() == 1;
     int32_t totalRanks = 0;
 
@@ -284,26 +283,22 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_
     } else {
         permuteDims = {1, 0};
     }
-    aclIntArray * permuteDimsArray = aclCreateIntArray(permuteDims.data(), permuteDims.size());
+    aclIntArray* permuteDimsArray = aclCreateIntArray(permuteDims.data(), permuteDims.size());
     for (const auto& [adapterId, weightATransposeIndex] : weightATransposeIdMap_) {
         aclWeightAPermuteExecutor_[adapterId] = nullptr;
         aclWeightAPermuteWorkspace_[adapterId] = 0;
 
         auto& weightA = aclWeightA_[weightATransposeIndex];
         auto& weightATranspose = aclWeightATranspose_[weightATransposeIndex];
 
-
-        int ret = aclnnPermuteGetWorkspaceSize(weightA.tensor,
-                                               permuteDimsArray,
-                                               weightATranspose.tensor,
-                                               &aclWeightAPermuteWorkspace_[adapterId],
-                                               &aclWeightAPermuteExecutor_[adapterId]);
+        int ret = aclnnPermuteGetWorkspaceSize(
+            weightA.tensor, permuteDimsArray, weightATranspose.tensor, &aclWeightAPermuteWorkspace_[adapterId], &aclWeightAPermuteExecutor_[adapterId]);
         DICP_LOG(INFO) << opName_ << " aclnnPermuteGetWorkspaceSize size[" << adapterId << "]: " << aclWeightAPermuteWorkspace_[adapterId] << ", ret: " << ret;
     }
 
     // Setup grouped matrix multiplication
     DICP_LOG(INFO) << opName_ << " Setting up grouped matrix multiplication";
-
+
     // Create input tensor list
     std::vector<aclTensor*> xTmp;
     if (singleInfer) {
@@ -317,7 +312,7 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_
             slicedInput.desc.format = aclInTensors_.at(0).atbTensor.desc.format;
             slicedInput.desc.shape.dimNum = aclInTensors_.at(0).atbTensor.desc.shape.dimNum;
             slicedInput.desc.shape.dims[0] = seqLensVec[i];
-            slicedInput.desc.shape.dims[1] = aclInTensors_.at(0).atbTensor.desc.shape.dims[1];
+            slicedInput.desc.shape.dims[1] = aclInTensors_.at(0).atbTensor.desc.shape.dims[1];
             slicedInput.dataSize = atb::Utils::GetTensorSize(slicedInput.desc);
 
             auto offset = CalculateWeightOffset(seqLensVec, i, slicedInput.dataSize / seqLensVec[i]);
@@ -338,14 +333,14 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_
     std::vector<aclTensor*> weightTmpB;
     weightTmpA.reserve(aclWeightATranspose_.size());
     weightTmpB.reserve(aclWeightB_.size());
-
+
     for (const auto& weight : aclWeightATranspose_) {
         weightTmpA.push_back(weight.tensor);
     }
     for (const auto& weight : aclWeightB_) {
         weightTmpB.push_back(weight.tensor);
     }
-
+
     aclTensorList* weightTensorListA = aclCreateTensorList(weightTmpA.data(), weightTmpA.size());
     aclTensorList* weightTensorListB = aclCreateTensorList(weightTmpB.data(), weightTmpB.size());
 
@@ -363,7 +358,7 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_
         loraASliceOutput.desc.format = aclOutTensors_.at(1).atbTensor.desc.format;
         loraASliceOutput.desc.shape.dimNum = aclOutTensors_.at(1).atbTensor.desc.shape.dimNum;
         loraASliceOutput.desc.shape.dims[0] = aclOutTensors_.at(1).atbTensor.desc.shape.dims[0];
-        loraASliceOutput.desc.shape.dims[1] = totalRanks / adapterIdsVec.size();
+        loraASliceOutput.desc.shape.dims[1] = totalRanks / adapterIdsVec.size();
         loraASliceOutput.dataSize = atb::Utils::GetTensorSize(loraASliceOutput.desc);
         loraASliceOutput.deviceData = aclOutTensors_.at(1).atbTensor.deviceData;
         auto aclnnLoraASliceOutput = CreateTensor(loraASliceOutput);
@@ -378,7 +373,7 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_
             slicedOutput.desc.format = aclOutTensors_.at(1).atbTensor.desc.format;
             slicedOutput.desc.shape.dimNum = aclOutTensors_.at(1).atbTensor.desc.shape.dimNum;
             slicedOutput.desc.shape.dims[0] = seqLensVec[i];
-            slicedOutput.desc.shape.dims[1] = ranksVec[adapterIdsVec[i]];
+            slicedOutput.desc.shape.dims[1] = ranksVec[adapterIdsVec[i]];
             slicedOutput.dataSize = atb::Utils::GetTensorSize(slicedOutput.desc);
 
             auto offset = CalculateWeightOffset(seqLensVec, i, slicedOutput.dataSize / seqLensVec[i]);
@@ -398,59 +393,59 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_
         DICP_LOG(ERROR) << opName_ << " Failed to create output tensor lists";
         return -1;
     }
-
+
     // Setup LoRA A grouped matrix multiplication
-    int ret = aclnnGroupedMatmulV4GetWorkspaceSize(xTensorList,  // x
-                                                   weightTensorListA,  // weight
-                                                   nullptr,  // biasOptional
-                                                   nullptr,  // scaleOptional
-                                                   nullptr,  // offsetOptional
-                                                   nullptr,  // antiquantScaleOptional
-                                                   nullptr,  // antiquantOffsetOptional
-                                                   nullptr,  // perTokenScaleOptional
-                                                   singleInfer ? aclInTensors_.at(5).tensor : nullptr,  // groupListOptional
-                                                   nullptr,  // activationInputOptional
-                                                   nullptr,  // activationQuantScaleOptional
-                                                   nullptr,  // activationQuantOffsetOptional
-                                                   singleInfer ? 2 : 0,  // splitItem
-                                                   singleInfer ? 0 : -1,  // groupType
-                                                   1,  // groupListType
-                                                   0,  // actType
-                                                   loraAOutTensorList,  // out
-                                                   nullptr,  // activationFeatureOutOptional
-                                                   nullptr,  // dynQuantScaleOutOptional
+    int ret = aclnnGroupedMatmulV4GetWorkspaceSize(xTensorList,                                         // x
+                                                   weightTensorListA,                                   // weight
+                                                   nullptr,                                             // biasOptional
+                                                   nullptr,                                             // scaleOptional
+                                                   nullptr,                                             // offsetOptional
+                                                   nullptr,                                             // antiquantScaleOptional
+                                                   nullptr,                                             // antiquantOffsetOptional
+                                                   nullptr,                                             // perTokenScaleOptional
+                                                   singleInfer ? aclInTensors_.at(5).tensor : nullptr,  // groupListOptional
+                                                   nullptr,                                             // activationInputOptional
+                                                   nullptr,                                             // activationQuantScaleOptional
+                                                   nullptr,                                             // activationQuantOffsetOptional
+                                                   singleInfer ? 2 : 0,                                 // splitItem
+                                                   singleInfer ? 0 : -1,                                // groupType
+                                                   1,                                                   // groupListType
+                                                   0,                                                   // actType
+                                                   loraAOutTensorList,                                  // out
+                                                   nullptr,                                             // activationFeatureOutOptional
+                                                   nullptr,                                             // dynQuantScaleOutOptional
                                                    &loraAGroupedGemmWorkspace_,
                                                    &aclLoraAGroupedGemmExecutor_);
     DICP_LOG(INFO) << opName_ << " LoRA A grouped matmul workspace size: " << loraAGroupedGemmWorkspace_ << ", ret: " << ret;
 
     // Setup LoRA B grouped matrix multiplication
-    ret = aclnnGroupedMatmulV4GetWorkspaceSize(loraAOutTensorList,  // x
-                                               weightTensorListB,  // weight
-                                               nullptr,  // biasOptional
-                                               nullptr,  // scaleOptional
-                                               nullptr,  // offsetOptional
-                                               nullptr,  // antiquantScaleOptional
-                                               nullptr,  // antiquantOffsetOptional
-                                               nullptr,  // perTokenScaleOptional
-                                               aclInTensors_.at(5).tensor,  // groupListOptional
-                                               nullptr,  // activationInputOptional
-                                               nullptr,  // activationQuantScaleOptional
-                                               nullptr,  // activationQuantOffsetOptional
-                                               2,  // splitItem
-                                               0,  // groupType
-                                               1,  // groupListType
-                                               0,  // actType
-                                               loraBOutTensorList,  // out
-                                               nullptr,  // activationFeatureOutOptional
-                                               nullptr,  // dynQuantScaleOutOptional
+    ret = aclnnGroupedMatmulV4GetWorkspaceSize(loraAOutTensorList,          // x
+                                               weightTensorListB,           // weight
+                                               nullptr,                     // biasOptional
+                                               nullptr,                     // scaleOptional
+                                               nullptr,                     // offsetOptional
+                                               nullptr,                     // antiquantScaleOptional
+                                               nullptr,                     // antiquantOffsetOptional
+                                               nullptr,                     // perTokenScaleOptional
+                                               aclInTensors_.at(5).tensor,  // groupListOptional
+                                               nullptr,                     // activationInputOptional
+                                               nullptr,                     // activationQuantScaleOptional
+                                               nullptr,                     // activationQuantOffsetOptional
+                                               2,                           // splitItem
+                                               0,                           // groupType
+                                               1,                           // groupListType
+                                               0,                           // actType
+                                               loraBOutTensorList,          // out
+                                               nullptr,                     // activationFeatureOutOptional
+                                               nullptr,                     // dynQuantScaleOutOptional
                                                &loraBGroupedGemmWorkspace_,
                                                &aclLoraBGroupedGemmExecutor_);
     DICP_LOG(INFO) << opName_ << " LoRA B grouped matmul workspace size: " << loraBGroupedGemmWorkspace_ << ", ret: " << ret;
-
+
     // Setup scaling operations
     aclScalingWorkspace_.resize(adapterIdsVec.size());
     aclScalingExecutor_.resize(adapterIdsVec.size());
-
+
     for (size_t i = 0; i < adapterIdsVec.size(); ++i) {
         const int32_t adapterId = adapterIdsVec[i];
         const auto& inputAtbTensor = aclOutTensors_.at(0).atbTensor;
@@ -494,10 +489,8 @@ int CustomFusedLoraOperation::Setup(const atb::VariantPack& variantPack, uint64_
         aclnnScalingWeight.CreateTensor(opName_);
         aclScalingWeight_.push_back(aclnnScalingWeight);
 
-        ret = aclnnInplaceMulGetWorkspaceSize(aclScalingInput_.back().tensor,
-                                              aclScalingWeight_.back().tensor,
-                                              &aclScalingWorkspace_[i],
-                                              &aclScalingExecutor_[i]);
+        ret =
+            aclnnInplaceMulGetWorkspaceSize(aclScalingInput_.back().tensor, aclScalingWeight_.back().tensor, &aclScalingWorkspace_[i], &aclScalingExecutor_[i]);
         DICP_LOG(INFO) << opName_ << " Scaling workspace size[" << i << "]: " << aclScalingWorkspace_[i] << ", ret: " << ret;
     }
 