88#include < random>
99#include < thread>
1010#include < vector>
11- #include < iostream>
12- #include < iomanip>
1311
1412void createDeviceResource (DeviceResource *rsrc, const JiugeMeta *meta,
1513 const JiugeWeights *weights,
@@ -21,7 +19,7 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
2119 infiniopCreateHandle (&handle);
2220 infinirtStream_t stream;
2321 infinirtStreamCreate (&stream);
24-
22+ // 加载权重
2523 std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
2624 w_ffn_norm, w_ffn_gate_up, w_ffn_down;
2725 for (size_t layer = 0 ; layer < meta->nlayer ; layer++) {
@@ -43,13 +41,7 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
4341 getFFNDown (meta, weights, layer, idev, ndev));
4442 }
4543
46- // 配置内存池预分配策略
47- MemoryPool::PreallocationConfig pool_config;
48- pool_config.small_pool_size = 32 * 1024 * 1024 ; // 32MB for small allocations
49- pool_config.medium_pool_size = 64 * 1024 * 1024 ; // 64MB for medium allocations
50- pool_config.large_pool_size = 128 * 1024 * 1024 ; // 128MB for large allocations
51-
52- auto memory_pool = std::make_shared<MemoryPool>(128 * 1024 * 1024 , MemoryPool::DEFAULT_ALIGNMENT, pool_config);
44+ auto memory_pool = std::make_shared<MemoryPool>(128 * 1024 * 1024 );
5345
5446 *rsrc = DeviceResource{
5547 device,
@@ -120,36 +112,27 @@ void releaseDeviceResource(DeviceResource &res) {
120112
121113void inferDeviceBatch (const JiugeMeta &meta, DeviceResource &rsrc,
122114 uint32_t idev, uint32_t ndev,
123- const uint32_t *tokens, uint32_t ntok,
115+ const uint32_t *tokens, uint32_t ntok, // 所有req的ntokens之和
124116 const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
125117 struct KVCache **kv_caches,
126118 const float *temperature, const uint32_t *topk, const float *topp,
127119 uint32_t *output) {
128- // 推理开始前检查内存状态
129- static int inference_count = 0 ;
130- inference_count++;
131-
132- // 每10次推理检查一次碎片率,必要时进行碎片整理
133- if (inference_count % 10 == 0 && rsrc.memory_pool ->shouldDefragment ()) {
134- rsrc.memory_pool ->defragment ();
135- }
136-
137120 auto nlayer = meta.nlayer ;
138121 auto nkvh = meta.nkvh / ndev;
139122 auto nh = meta.nh / ndev;
140123 auto ngroup = nh / nkvh;
141124 // auto dctx = meta.dctx;
142125 auto dh = meta.dh ;
143- auto d = meta.d ;
144- auto dt_logits = meta.dt_logits ;
126+ auto d = meta.d ; // hidden size
127+ auto dt_logits = meta.dt_logits ; // data type
145128 auto di = meta.di / ndev;
146129 auto dvoc = meta.dvoc ;
147130 auto stream = rsrc.stream ;
148131 bool has_qkv_bias = rsrc.b_attn_qkv .size () > 0 ;
149132
150133 // Allocate buffers
151- auto logits_in = Tensor::buffer (dt_logits, {ntok, d}, rsrc.memory_pool );
152- auto logits_out = Tensor::buffer (dt_logits, {ntok, d}, rsrc.memory_pool );
134+ auto logits_in = Tensor::buffer (dt_logits, {ntok, d}, rsrc.memory_pool ); // hidden_stat
135+ auto logits_out = Tensor::buffer (dt_logits, {ntok, d}, rsrc.memory_pool ); // hidden_stat (rms)
153136 auto qkv_buf = Tensor::buffer (dt_logits, {ntok, (nh + nkvh * 2 ) * dh}, rsrc.memory_pool );
154137 auto gate_up_buf = Tensor::buffer (dt_logits, {ntok, 2 * di}, rsrc.memory_pool );
155138 auto o_buf = Tensor::buffer (dt_logits, {ntok, nh * dh}, rsrc.memory_pool );
@@ -179,7 +162,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
179162 RUN_INFINI (infinirtMemcpyAsync (logits_in->data (i * d),
180163 rsrc.w_in_embd ->data (tokens[i] * d),
181164 dsize (dt_logits) * d, INFINIRT_MEMCPY_D2D, stream));
182- }
165+ } // ids -> embed hidden
183166
184167 // Prepare operators and workspace
185168 size_t workspace_size = 0 , temp_size = 0 ;
@@ -236,21 +219,33 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
236219 size_t max_qk_size = 0 ;
237220 size_t max_seq_len = 0 ;
238221 o_buf->dimSplit (1 , {nh, dh});
222+
223+ size_t recentWindow = 16 ; // sparse attention
239224 for (uint32_t req = 0 ; req < nreq; req++) {
240225 auto past_len = req_pos[req];
241226 auto seq_len = req_lens[req];
242227 auto total_len = past_len + seq_len;
243228 auto o = o_buf->slice ({{0 , token_offset, seq_len}});
244229 auto q = qkv_buf->slice ({{0 , token_offset, seq_len}, {1 , 0 , nh}});
245230 auto k = qkv_buf->slice ({{0 , token_offset, seq_len}, {1 , nh, nkvh}});
246- // auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
231+ auto v = qkv_buf->slice ({{0 , token_offset, seq_len}, {1 , nh + nkvh, nkvh}});
247232 // kv cache tensors can share the same descriptor
248233 // [nkvh, dh, total_len]
249234 auto full_kv = kv_caches[req]->k [idev][0 ]->slice (0 , 0 , total_len)->permute ({1 , 2 , 0 });
250235 auto cache_kv = kv_caches[req]->k [idev][0 ]->slice (0 , past_len, seq_len);
251236
252- RUN_INFINI (infiniopCreateRearrangeDescriptor (rsrc.handle , &desc_kv_rearranges[req],
253- cache_kv->desc (), k->desc ()));
237+ bool prune = (past_len == 0 ) && (seq_len > recentWindow);
238+
239+ if (prune) {
240+ auto k_compressed = k->slice ({{0 , seq_len - recentWindow, recentWindow}});
241+ auto cache_kv_compressed = kv_caches[req]->k [idev][0 ]->slice (0 , past_len, recentWindow);
242+ RUN_INFINI (infiniopCreateRearrangeDescriptor (rsrc.handle , &desc_kv_rearranges[req],
243+ cache_kv_compressed->desc (), k_compressed->desc ()));
244+ } else {
245+ RUN_INFINI (infiniopCreateRearrangeDescriptor (rsrc.handle , &desc_kv_rearranges[req],
246+ cache_kv->desc (), k->desc ()));
247+ }
248+
254249
255250 // [nkvh, ngroup, seq_len, dh]
256251 q->dimSplit (1 , {nkvh, ngroup})->permute ({1 , 2 , 0 , 3 });
@@ -267,15 +262,30 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
267262 auto qk = TensorDesc::create (dt_logits, {nkvh, ngroup * seq_len, total_len});
268263 max_qk_size = std::max (max_qk_size, size_t (seq_len * total_len));
269264 max_seq_len = std::max (max_seq_len, size_t (seq_len));
270- RUN_INFINI (infiniopCreateGemmDescriptor (
271- rsrc.handle , &desc_qk_gemms[req], qk->desc (), q_t ->desc (), full_kv->desc ()));
265+
266+ if (prune) {
267+ RUN_INFINI (infiniopCreateGemmDescriptor (
268+ rsrc.handle , &desc_qk_gemms[req], qk->desc (), q_t ->desc (), (k->permute ({1 , 2 , 0 }))->desc ()));
269+ } else {
270+ RUN_INFINI (infiniopCreateGemmDescriptor (
271+ rsrc.handle , &desc_qk_gemms[req], qk->desc (), q_t ->desc (), full_kv->desc ()));
272+ }
273+
274+
272275 RUN_INFINI (infiniopGetGemmWorkspaceSize (desc_qk_gemms[req], &temp_size));
273276 workspace_size = std::max (workspace_size, temp_size);
274277
275278 // [nkvh, total_len, dh]
276279 auto full_v = kv_caches[req]->v [idev][0 ]->slice (0 , 0 , total_len)->permute ({1 , 0 , 2 });
277- RUN_INFINI (infiniopCreateGemmDescriptor (
278- rsrc.handle , &desc_attn_v_gemms[req], q_t ->desc (), qk->desc (), full_v->desc ()));
280+
281+ if (prune) {
282+ RUN_INFINI (infiniopCreateGemmDescriptor (
283+ rsrc.handle , &desc_attn_v_gemms[req], q_t ->desc (), qk->desc (), (v->permute ({1 , 0 , 2 }))->desc ()));
284+ } else {
285+ RUN_INFINI (infiniopCreateGemmDescriptor (
286+ rsrc.handle , &desc_attn_v_gemms[req], q_t ->desc (), qk->desc (), full_v->desc ()));
287+ }
288+
279289 RUN_INFINI (infiniopGetGemmWorkspaceSize (desc_attn_v_gemms[req], &temp_size));
280290 workspace_size = std::max (workspace_size, temp_size);
281291
@@ -334,7 +344,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
334344 RUN_INFINI (infiniopGetRandomSampleWorkspaceSize (desc_sample, &temp_size));
335345 workspace_size = std::max (workspace_size, temp_size);
336346 // Allocate workspace
337- std::shared_ptr<Storage> workspace_storage = Storage::createFromPool (workspace_size, rsrc.memory_pool );
347+ std::shared_ptr<Storage> workspace_storage = Storage::createFromPool (workspace_size, rsrc.memory_pool ); // 激活值所需最大的内存空间
338348 void *workspace = workspace_storage->memory ();
339349
340350 // Compute
@@ -372,35 +382,66 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
372382
373383 size_t token_offset = 0 ;
374384 for (uint32_t req = 0 ; req < nreq; req++) {
375- auto past_len = req_pos[req];
385+ auto past_len = req_pos[req]; // for kv cache
376386 auto seq_len = req_lens[req];
377387 auto o = o_buf->slice ({{0 , token_offset, seq_len}});
378388 auto q = qkv_buf->slice ({{0 , token_offset, seq_len}, {1 , 0 , nh}});
379389 auto k = qkv_buf->slice ({{0 , token_offset, seq_len}, {1 , nh, nkvh}});
380- auto v = qkv_buf->slice ({{0 , token_offset, seq_len}, {1 , nh + nkvh, nkvh}});
390+ auto v = qkv_buf->slice ({{0 , token_offset, seq_len}, {1 , nh + nkvh, nkvh}}); // 同一个req的qkv本身也储存在一起,arg2=起始位置,arg3=大小
391+ // 不同req的qkv存在一起,用token_offset来维护
392+ bool prune = (past_len == 0 ) && (seq_len > recentWindow);
381393 // self attention
382- // concat
383- RUN_INFINI (infiniopRearrange (
384- desc_kv_rearranges[req],
385- kv_caches[req]->k [idev][layer]->data (past_len * nkvh * dh),
386- k->data (), stream));
387- RUN_INFINI (infiniopRearrange (
388- desc_kv_rearranges[req],
389- kv_caches[req]->v [idev][layer]->data (past_len * nkvh * dh),
390- v->data (), stream));
391- // qk
392- RUN_INFINI (infiniopRearrange (desc_q_rearranges[req], rearrange_q_buf->data (), q->data (), stream));
393- RUN_INFINI (infiniopGemm (
394- desc_qk_gemms[req], workspace, workspace_size,
395- qk_buf->data (), rearrange_q_buf->data (), kv_caches[req]->k [idev][layer]->data (), 1 . / sqrt (dh), 0.0 , stream));
396- // softmax
397- RUN_INFINI (infiniopCausalSoftmax (
398- desc_qk_softmaxs[req], workspace, workspace_size,
399- qk_buf->data (), qk_buf->data (), stream));
400- // attn val
401- RUN_INFINI (infiniopGemm (
402- desc_attn_v_gemms[req], workspace, workspace_size,
403- attn_val_buf->data (), qk_buf->data (), kv_caches[req]->v [idev][layer]->data (), 1.0 , 0.0 , stream));
394+ if (prune) { // first prefill phase
395+ auto k_compressed = k->slice ({{0 , seq_len - recentWindow, recentWindow}});
396+ auto v_compressed = v->slice ({{0 , seq_len - recentWindow, recentWindow}});
397+ // 存入kv cache
398+ RUN_INFINI (infiniopRearrange ( // concat
399+ desc_kv_rearranges[req],
400+ kv_caches[req]->k [idev][layer]->data (past_len * nkvh * dh),
401+ k_compressed->data (), stream));
402+
403+ RUN_INFINI (infiniopRearrange (
404+ desc_kv_rearranges[req],
405+ kv_caches[req]->v [idev][layer]->data (past_len * nkvh * dh),
406+ v_compressed->data (), stream));
407+ // qk
408+ RUN_INFINI (infiniopRearrange (desc_q_rearranges[req], rearrange_q_buf->data (), q->data (), stream));
409+ RUN_INFINI (infiniopGemm (
410+ desc_qk_gemms[req], workspace, workspace_size,
411+ qk_buf->data (), rearrange_q_buf->data (), k->data (), 1 . / sqrt (dh), 0.0 , stream));
412+ // softmax
413+ RUN_INFINI (infiniopCausalSoftmax (
414+ desc_qk_softmaxs[req], workspace, workspace_size,
415+ qk_buf->data (), qk_buf->data (), stream));
416+ // attn val
417+ RUN_INFINI (infiniopGemm (
418+ desc_attn_v_gemms[req], workspace, workspace_size,
419+ attn_val_buf->data (), qk_buf->data (), v->data (), 1.0 , 0.0 , stream));
420+ } else { // decode phase
421+ RUN_INFINI (infiniopRearrange (
422+ desc_kv_rearranges[req],
423+ kv_caches[req]->k [idev][layer]->data (past_len * nkvh * dh),
424+ k->data (), stream)); // 加进kv cache
425+
426+ RUN_INFINI (infiniopRearrange (
427+ desc_kv_rearranges[req],
428+ kv_caches[req]->v [idev][layer]->data (past_len * nkvh * dh),
429+ v->data (), stream));
430+ // qk
431+ RUN_INFINI (infiniopRearrange (desc_q_rearranges[req], rearrange_q_buf->data (), q->data (), stream));
432+ RUN_INFINI (infiniopGemm (
433+ desc_qk_gemms[req], workspace, workspace_size,
434+ qk_buf->data (), rearrange_q_buf->data (), kv_caches[req]->k [idev][layer]->data (), 1 . / sqrt (dh), 0.0 , stream));
435+ // softmax
436+ RUN_INFINI (infiniopCausalSoftmax (
437+ desc_qk_softmaxs[req], workspace, workspace_size,
438+ qk_buf->data (), qk_buf->data (), stream));
439+ // attn val
440+ RUN_INFINI (infiniopGemm (
441+ desc_attn_v_gemms[req], workspace, workspace_size,
442+ attn_val_buf->data (), qk_buf->data (), kv_caches[req]->v [idev][layer]->data (), 1.0 , 0.0 , stream));
443+ }
444+
404445 // rearrange attn val
405446 RUN_INFINI (infiniopRearrange (
406447 desc_attn_v_rearranges[req],
@@ -439,7 +480,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
439480 desc_ffn_down, workspace, workspace_size,
440481 logits_in->data (), gate_buf->data (),
441482 rsrc.w_ffn_down [layer]->data (), 1.0 , idev == 0 ? 1.0 : 0.0 , stream)); // only rank 0 adds residual
442-
483+ // logits_in->data()即是下一层的输入
443484 // All_reduce if distributed
444485 if (rsrc.comm != nullptr ) {
445486 RUN_INFINI (infinicclAllReduce (
@@ -520,7 +561,7 @@ inferBatch(struct JiugeModel *model,
520561 const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
521562 struct KVCache **kv_caches,
522563 const float *temperature, const uint32_t *topk, const float *topp,
523- uint32_t *output) {
564+ uint32_t *output) { // 设置推理请求参数并启动多设备的并发推理流程。
524565 model->req .tokens = tokens;
525566 model->req .ntok = ntok;
526567 model->req .req_lens = req_lens;
@@ -532,13 +573,13 @@ inferBatch(struct JiugeModel *model,
532573 model->req .topk = topk;
533574 model->req .topp = topp;
534575
535- for (size_t idev = 0 ; idev < model->dev_ids .size (); idev++) {
576+ for (size_t idev = 0 ; idev < model->dev_ids .size (); idev++) { // 启动多设备推理
536577 std::unique_lock<std::mutex> lock (model->states [idev].mtx );
537578 model->states [idev].proceed = true ;
538579 lock.unlock ();
539- model->states [idev].cv_start .notify_one ();
580+ model->states [idev].cv_start .notify_one (); // 唤醒一个线程去执行推理任务
540581 }
541- for (size_t i = model->dev_ids .size (); i > 0 ; i--) {
582+ for (size_t i = model->dev_ids .size (); i > 0 ; i--) { // 等待推理完成
542583 auto idev = i - 1 ;
543584 std::unique_lock<std::mutex> lock (model->states [idev].mtx );
544585 model->states [idev].cv_done .wait (lock, [&] { return !(model->states [idev].proceed ); });
@@ -560,7 +601,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
560601 // Infer Loop
561602 while (true ) {
562603 std::unique_lock<std::mutex> lock (state.mtx );
563- state.cv_start .wait (lock, [&] { return state.proceed || state.exit_flag ; });
604+ state.cv_start .wait (lock, [&] { return state.proceed || state.exit_flag ; }); // 等待inferBatch函数唤醒
564605 // quit if exit_flag is set
565606 if (state.exit_flag ) {
566607 break ;
@@ -627,4 +668,4 @@ __C void destroyJiugeModel(struct JiugeModel *model) {
627668 }
628669
629670 delete model;
630- }
671+ }
0 commit comments