@@ -25,3 +25,39 @@ y = (x^2 + δ)^(-1/2) * w * x
25251 Output:
2626
2727- ** Y(heterogeneous) - T** : 输出张量。形状与 ` X ` 相同。
28+
29+ ## Attention
30+
31+ ### Summary
32+
33+ Multi-head Self Attention 的封装形式,用于 transformer 模型。
34+
35+ 支持使用 kv cache,使用条件由输入和属性综合决定。有以下 种情况:
36+
37+ | 序号 | 输入数量 | ` max_seq_len ` | 使用 kv cache | 输出数量 | cache s 维度 | 备注
38+ |:-:|:-:|:-----:|:-------:|:-:|:------------------------:|:-
39+ | 0 | 3 | 0 | none | 1 | - |
40+ | 1 | 3 | S > 0 | init | 3 | ` S ` | ` assert(S >= seq_len) `
41+ | 2 | 4 | 0 | inplace | 3 | ` past_seq_len + seq_len ` | ` past_seq_len ` 必须是常量
42+ | 3 | 4 | S > 0 | inplace | 3 | ` S ` | ` assert(S >= past_seq_len + seq_len) `
43+ | 4 | 6 | 0 | copy | 3 | ` past_seq_len + seq_len ` | ` past_seq_len ` 必须是常量
44+ | 5 | 6 | S > 0 | copy | 3 | ` S ` | ` assert(S >= past_seq_len + seq_len) `
45+
46+ ### Attributes
47+
48+ - ** max_seq_len - INT** (default is ` 0 ` ): 最大序列长度,用于初始化 kv cache。
49+
50+ ### Inputs
51+
52+ - ** query(heterogeneous) - T** : 形状为 ` N x n_head x seq_len x head_dim ` 。
53+ - ** key(heterogeneous) - T** : 形状为 ` N x n_kv_head x seq_len x head_dim ` 。
54+ - ** value(heterogeneous) - T** : 形状为 ` N x n_kv_head x seq_len x head_dim ` 。
55+ - ** past_seq_len(optional) -int64** : 要连接的历史序列长度,必须为标量。不使用 kv cache 时留空。
56+ - ** k_cache(optional, heterogeneous) -T** : k 缓存的初始值,形状为 ` N x n_kv_head x s x head_dim ` ,` s ` 为不小于 ` past_seq_len ` 的任意值。不使用或不重置 kv cache 时留空。
57+ - ** v_cache(optional, heterogeneous) -T** : v 缓存的初始值,形状为 ` N x n_kv_head x s x head_dim ` ,` s ` 为不小于 ` past_seq_len ` 的任意值。不使用或不重置 kv cache 时留空。
58+
59+ ### Outputs
60+
61+ - ** output(heterogeneous) - T** : 形状与 ` query ` 相同。
62+ - ** k_cache(optional, heterogeneous) - T** : 形状为 ` N x n_kv_head x s x head_dim ` 。` s ` 的值根据 ` Summary ` 的描述计算。
63+ - ** v_cache(optional, heterogeneous) - T** : 形状为 ` N x n_kv_head x s x head_dim ` 。` s ` 的值根据 ` Summary ` 的描述计算。
0 commit comments