Skip to content

Commit b155d3f

Browse files
committed
合并wty & hds
1 parent d188d53 commit b155d3f

File tree

2 files changed

+107
-64
lines changed

scripts/infer_task.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def kvcache(self):
2626

2727
def next(self, out_token):
2828
self._kv_cache.update_tokens(self.tokens, self.pos)
29+
recentWindow = 16
30+
self.pos += min(len(self.tokens), recentWindow)
2931

3032
self.pos += len(self.tokens)
3133
if out_token == None or out_token in self.end_tokens:

src/models/jiuge/jiuge.cpp

Lines changed: 105 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
#include <random>
99
#include <thread>
1010
#include <vector>
11-
#include <iostream>
12-
#include <iomanip>
1311

1412
void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
1513
const JiugeWeights *weights,
@@ -21,7 +19,7 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
2119
infiniopCreateHandle(&handle);
2220
infinirtStream_t stream;
2321
infinirtStreamCreate(&stream);
24-
22+
// 加载权重
2523
std::vector<std::shared_ptr<Tensor>> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out,
2624
w_ffn_norm, w_ffn_gate_up, w_ffn_down;
2725
for (size_t layer = 0; layer < meta->nlayer; layer++) {
@@ -43,13 +41,7 @@ void createDeviceResource(DeviceResource *rsrc, const JiugeMeta *meta,
4341
getFFNDown(meta, weights, layer, idev, ndev));
4442
}
4543

46-
// 配置内存池预分配策略
47-
MemoryPool::PreallocationConfig pool_config;
48-
pool_config.small_pool_size = 32 * 1024 * 1024; // 32MB for small allocations
49-
pool_config.medium_pool_size = 64 * 1024 * 1024; // 64MB for medium allocations
50-
pool_config.large_pool_size = 128 * 1024 * 1024; // 128MB for large allocations
51-
52-
auto memory_pool = std::make_shared<MemoryPool>(128 * 1024 * 1024, MemoryPool::DEFAULT_ALIGNMENT, pool_config);
44+
auto memory_pool = std::make_shared<MemoryPool>(128 * 1024 * 1024);
5345

5446
*rsrc = DeviceResource{
5547
device,
@@ -120,36 +112,27 @@ void releaseDeviceResource(DeviceResource &res) {
120112

121113
void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
122114
uint32_t idev, uint32_t ndev,
123-
const uint32_t *tokens, uint32_t ntok,
115+
const uint32_t *tokens, uint32_t ntok, //所有req的ntokens之和
124116
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
125117
struct KVCache **kv_caches,
126118
const float *temperature, const uint32_t *topk, const float *topp,
127119
uint32_t *output) {
128-
// 推理开始前检查内存状态
129-
static int inference_count = 0;
130-
inference_count++;
131-
132-
// 每10次推理检查一次碎片率,必要时进行碎片整理
133-
if (inference_count % 10 == 0 && rsrc.memory_pool->shouldDefragment()) {
134-
rsrc.memory_pool->defragment();
135-
}
136-
137120
auto nlayer = meta.nlayer;
138121
auto nkvh = meta.nkvh / ndev;
139122
auto nh = meta.nh / ndev;
140123
auto ngroup = nh / nkvh;
141124
// auto dctx = meta.dctx;
142125
auto dh = meta.dh;
143-
auto d = meta.d;
144-
auto dt_logits = meta.dt_logits;
126+
auto d = meta.d; //hidden size
127+
auto dt_logits = meta.dt_logits; //data type
145128
auto di = meta.di / ndev;
146129
auto dvoc = meta.dvoc;
147130
auto stream = rsrc.stream;
148131
bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0;
149132

150133
// Allocate buffers
151-
auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
152-
auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
134+
auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); //hidden_stat
135+
auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); //hidden_stat (rms)
153136
auto qkv_buf = Tensor::buffer(dt_logits, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool);
154137
auto gate_up_buf = Tensor::buffer(dt_logits, {ntok, 2 * di}, rsrc.memory_pool);
155138
auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, rsrc.memory_pool);
@@ -179,7 +162,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
179162
RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d),
180163
rsrc.w_in_embd->data(tokens[i] * d),
181164
dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream));
182-
}
165+
} // ids -> embed hidden
183166

184167
// Prepare operators and workspace
185168
size_t workspace_size = 0, temp_size = 0;
@@ -236,21 +219,33 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
236219
size_t max_qk_size = 0;
237220
size_t max_seq_len = 0;
238221
o_buf->dimSplit(1, {nh, dh});
222+
223+
size_t recentWindow = 16; //sparse attention
239224
for (uint32_t req = 0; req < nreq; req++) {
240225
auto past_len = req_pos[req];
241226
auto seq_len = req_lens[req];
242227
auto total_len = past_len + seq_len;
243228
auto o = o_buf->slice({{0, token_offset, seq_len}});
244229
auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}});
245230
auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
246-
// auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
231+
auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
247232
// kv cache tensors can share the same descriptor
248233
// [nkvh, dh, total_len]
249234
auto full_kv = kv_caches[req]->k[idev][0]->slice(0, 0, total_len)->permute({1, 2, 0});
250235
auto cache_kv = kv_caches[req]->k[idev][0]->slice(0, past_len, seq_len);
251236

252-
RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_kv_rearranges[req],
253-
cache_kv->desc(), k->desc()));
237+
bool prune = (past_len == 0) && (seq_len > recentWindow);
238+
239+
if (prune) {
240+
auto k_compressed = k->slice({{0, seq_len - recentWindow, recentWindow}});
241+
auto cache_kv_compressed = kv_caches[req]->k[idev][0]->slice(0, past_len, recentWindow);
242+
RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_kv_rearranges[req],
243+
cache_kv_compressed->desc(), k_compressed->desc()));
244+
} else {
245+
RUN_INFINI(infiniopCreateRearrangeDescriptor(rsrc.handle, &desc_kv_rearranges[req],
246+
cache_kv->desc(), k->desc()));
247+
}
248+
254249

255250
// [nkvh, ngroup, seq_len, dh]
256251
q->dimSplit(1, {nkvh, ngroup})->permute({1, 2, 0, 3});
@@ -267,15 +262,30 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
267262
auto qk = TensorDesc::create(dt_logits, {nkvh, ngroup * seq_len, total_len});
268263
max_qk_size = std::max(max_qk_size, size_t(seq_len * total_len));
269264
max_seq_len = std::max(max_seq_len, size_t(seq_len));
270-
RUN_INFINI(infiniopCreateGemmDescriptor(
271-
rsrc.handle, &desc_qk_gemms[req], qk->desc(), q_t->desc(), full_kv->desc()));
265+
266+
if (prune) {
267+
RUN_INFINI(infiniopCreateGemmDescriptor(
268+
rsrc.handle, &desc_qk_gemms[req], qk->desc(), q_t->desc(), (k->permute({1, 2, 0}))->desc()));
269+
} else {
270+
RUN_INFINI(infiniopCreateGemmDescriptor(
271+
rsrc.handle, &desc_qk_gemms[req], qk->desc(), q_t->desc(), full_kv->desc()));
272+
}
273+
274+
272275
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_qk_gemms[req], &temp_size));
273276
workspace_size = std::max(workspace_size, temp_size);
274277

275278
// [nkvh, total_len, dh]
276279
auto full_v = kv_caches[req]->v[idev][0]->slice(0, 0, total_len)->permute({1, 0, 2});
277-
RUN_INFINI(infiniopCreateGemmDescriptor(
278-
rsrc.handle, &desc_attn_v_gemms[req], q_t->desc(), qk->desc(), full_v->desc()));
280+
281+
if (prune) {
282+
RUN_INFINI(infiniopCreateGemmDescriptor(
283+
rsrc.handle, &desc_attn_v_gemms[req], q_t->desc(), qk->desc(), (v->permute({1, 0, 2}))->desc()));
284+
} else {
285+
RUN_INFINI(infiniopCreateGemmDescriptor(
286+
rsrc.handle, &desc_attn_v_gemms[req], q_t->desc(), qk->desc(), full_v->desc()));
287+
}
288+
279289
RUN_INFINI(infiniopGetGemmWorkspaceSize(desc_attn_v_gemms[req], &temp_size));
280290
workspace_size = std::max(workspace_size, temp_size);
281291

@@ -334,7 +344,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
334344
RUN_INFINI(infiniopGetRandomSampleWorkspaceSize(desc_sample, &temp_size));
335345
workspace_size = std::max(workspace_size, temp_size);
336346
// Allocate workspace
337-
std::shared_ptr<Storage> workspace_storage = Storage::createFromPool(workspace_size, rsrc.memory_pool);
347+
std::shared_ptr<Storage> workspace_storage = Storage::createFromPool(workspace_size, rsrc.memory_pool); //激活值所需最大的内存空间
338348
void *workspace = workspace_storage->memory();
339349

340350
// Compute
@@ -372,35 +382,66 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
372382

373383
size_t token_offset = 0;
374384
for (uint32_t req = 0; req < nreq; req++) {
375-
auto past_len = req_pos[req];
385+
auto past_len = req_pos[req]; //for kv cache
376386
auto seq_len = req_lens[req];
377387
auto o = o_buf->slice({{0, token_offset, seq_len}});
378388
auto q = qkv_buf->slice({{0, token_offset, seq_len}, {1, 0, nh}});
379389
auto k = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh, nkvh}});
380-
auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}});
390+
auto v = qkv_buf->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); //同一个req的qkv本身也储存在一起,arg2=起始位置,arg3=大小
391+
//不同req的qkv存在一起,用token_offset来维护
392+
bool prune = (past_len == 0) && (seq_len > recentWindow);
381393
// self attention
382-
// concat
383-
RUN_INFINI(infiniopRearrange(
384-
desc_kv_rearranges[req],
385-
kv_caches[req]->k[idev][layer]->data(past_len * nkvh * dh),
386-
k->data(), stream));
387-
RUN_INFINI(infiniopRearrange(
388-
desc_kv_rearranges[req],
389-
kv_caches[req]->v[idev][layer]->data(past_len * nkvh * dh),
390-
v->data(), stream));
391-
// qk
392-
RUN_INFINI(infiniopRearrange(desc_q_rearranges[req], rearrange_q_buf->data(), q->data(), stream));
393-
RUN_INFINI(infiniopGemm(
394-
desc_qk_gemms[req], workspace, workspace_size,
395-
qk_buf->data(), rearrange_q_buf->data(), kv_caches[req]->k[idev][layer]->data(), 1. / sqrt(dh), 0.0, stream));
396-
// softmax
397-
RUN_INFINI(infiniopCausalSoftmax(
398-
desc_qk_softmaxs[req], workspace, workspace_size,
399-
qk_buf->data(), qk_buf->data(), stream));
400-
// attn val
401-
RUN_INFINI(infiniopGemm(
402-
desc_attn_v_gemms[req], workspace, workspace_size,
403-
attn_val_buf->data(), qk_buf->data(), kv_caches[req]->v[idev][layer]->data(), 1.0, 0.0, stream));
394+
if (prune) { // first prefill phase
395+
auto k_compressed = k->slice({{0, seq_len - recentWindow, recentWindow}});
396+
auto v_compressed = v->slice({{0, seq_len - recentWindow, recentWindow}});
397+
//存入kv cache
398+
RUN_INFINI(infiniopRearrange( // concat
399+
desc_kv_rearranges[req],
400+
kv_caches[req]->k[idev][layer]->data(past_len * nkvh * dh),
401+
k_compressed->data(), stream));
402+
403+
RUN_INFINI(infiniopRearrange(
404+
desc_kv_rearranges[req],
405+
kv_caches[req]->v[idev][layer]->data(past_len * nkvh * dh),
406+
v_compressed->data(), stream));
407+
// qk
408+
RUN_INFINI(infiniopRearrange(desc_q_rearranges[req], rearrange_q_buf->data(), q->data(), stream));
409+
RUN_INFINI(infiniopGemm(
410+
desc_qk_gemms[req], workspace, workspace_size,
411+
qk_buf->data(), rearrange_q_buf->data(), k->data(), 1. / sqrt(dh), 0.0, stream));
412+
// softmax
413+
RUN_INFINI(infiniopCausalSoftmax(
414+
desc_qk_softmaxs[req], workspace, workspace_size,
415+
qk_buf->data(), qk_buf->data(), stream));
416+
// attn val
417+
RUN_INFINI(infiniopGemm(
418+
desc_attn_v_gemms[req], workspace, workspace_size,
419+
attn_val_buf->data(), qk_buf->data(), v->data(), 1.0, 0.0, stream));
420+
} else { // decode phase
421+
RUN_INFINI(infiniopRearrange(
422+
desc_kv_rearranges[req],
423+
kv_caches[req]->k[idev][layer]->data(past_len * nkvh * dh),
424+
k->data(), stream)); //加进kv cache
425+
426+
RUN_INFINI(infiniopRearrange(
427+
desc_kv_rearranges[req],
428+
kv_caches[req]->v[idev][layer]->data(past_len * nkvh * dh),
429+
v->data(), stream));
430+
// qk
431+
RUN_INFINI(infiniopRearrange(desc_q_rearranges[req], rearrange_q_buf->data(), q->data(), stream));
432+
RUN_INFINI(infiniopGemm(
433+
desc_qk_gemms[req], workspace, workspace_size,
434+
qk_buf->data(), rearrange_q_buf->data(), kv_caches[req]->k[idev][layer]->data(), 1. / sqrt(dh), 0.0, stream));
435+
// softmax
436+
RUN_INFINI(infiniopCausalSoftmax(
437+
desc_qk_softmaxs[req], workspace, workspace_size,
438+
qk_buf->data(), qk_buf->data(), stream));
439+
// attn val
440+
RUN_INFINI(infiniopGemm(
441+
desc_attn_v_gemms[req], workspace, workspace_size,
442+
attn_val_buf->data(), qk_buf->data(), kv_caches[req]->v[idev][layer]->data(), 1.0, 0.0, stream));
443+
}
444+
404445
// rearrange attn val
405446
RUN_INFINI(infiniopRearrange(
406447
desc_attn_v_rearranges[req],
@@ -439,7 +480,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
439480
desc_ffn_down, workspace, workspace_size,
440481
logits_in->data(), gate_buf->data(),
441482
rsrc.w_ffn_down[layer]->data(), 1.0, idev == 0 ? 1.0 : 0.0, stream)); // only rank 0 adds residual
442-
483+
// logits_in->data()即是下一层的输入
443484
// All_reduce if distributed
444485
if (rsrc.comm != nullptr) {
445486
RUN_INFINI(infinicclAllReduce(
@@ -520,7 +561,7 @@ inferBatch(struct JiugeModel *model,
520561
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
521562
struct KVCache **kv_caches,
522563
const float *temperature, const uint32_t *topk, const float *topp,
523-
uint32_t *output) {
564+
uint32_t *output) { //设置推理请求参数并启动多设备的并发推理流程。
524565
model->req.tokens = tokens;
525566
model->req.ntok = ntok;
526567
model->req.req_lens = req_lens;
@@ -532,13 +573,13 @@ inferBatch(struct JiugeModel *model,
532573
model->req.topk = topk;
533574
model->req.topp = topp;
534575

535-
for (size_t idev = 0; idev < model->dev_ids.size(); idev++) {
576+
for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { //启动多设备推理
536577
std::unique_lock<std::mutex> lock(model->states[idev].mtx);
537578
model->states[idev].proceed = true;
538579
lock.unlock();
539-
model->states[idev].cv_start.notify_one();
580+
model->states[idev].cv_start.notify_one(); //唤醒一个线程去执行推理任务
540581
}
541-
for (size_t i = model->dev_ids.size(); i > 0; i--) {
582+
for (size_t i = model->dev_ids.size(); i > 0; i--) { //等待推理完成
542583
auto idev = i - 1;
543584
std::unique_lock<std::mutex> lock(model->states[idev].mtx);
544585
model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); });
@@ -560,7 +601,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, DeviceReso
560601
// Infer Loop
561602
while (true) {
562603
std::unique_lock<std::mutex> lock(state.mtx);
563-
state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; });
604+
state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; }); //等待inferBatch函数唤醒
564605
// quit if exit_flag is set
565606
if (state.exit_flag) {
566607
break;
@@ -627,4 +668,4 @@ __C void destroyJiugeModel(struct JiugeModel *model) {
627668
}
628669

629670
delete model;
630-
}
671+
}

0 commit comments

Comments
 (0)