From 3385bbe072754a3d5ab1c163f6328569b7228bae Mon Sep 17 00:00:00 2001
From: Qsqsdac <962114354@qq.com>
Date: Thu, 17 Jul 2025 11:37:00 +0800
Subject: [PATCH 1/3] build: initial version of rwkv.rs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 llama.cu/src/model/rwkv.rs | 120 +++++++++++++++++++++++++++++++++++++
 1 file changed, 120 insertions(+)
 create mode 100644 llama.cu/src/model/rwkv.rs

diff --git a/llama.cu/src/model/rwkv.rs b/llama.cu/src/model/rwkv.rs
new file mode 100644
index 00000000..71189a75
--- /dev/null
+++ b/llama.cu/src/model/rwkv.rs
@@ -0,0 +1,120 @@
+use super::GGufModel;
+use crate::utils::meta;
+use nn::{
+    Linear, Normalization, NormType, OutputHead, RWKV, RWKVBlock, Tensor, TimeMix, ChannelMix,
+};
+
+impl GGufModel<'_> {
+    /// Construct the RWKV model
+    pub fn rwkv(&self) -> nn::RWKV<Tensor<&[u8]>> {
+        let n_layer = meta![self => block_count];
+        let hidden_size = meta![self => hidden_size];
+        let vocab_size = meta![self => vocab_size];
+        let epsilon = meta![self => layer_norm_epsilon; 1e-5];
+
+        let dt_linear = self.tensors["emb.weight"].dt();
+
+        let get = |name: &str| self.tensors[name].as_deref();
+
+        let emb_weight = get("emb.weight");
+        let ln_out_weight = get("ln_out.weight");
+        let head_weight = get("head.weight");
+
+        RWKV {
+            embedding: Linear::new(
+                dt_linear,
+                [hidden_size, vocab_size],
+                emb_weight,
+                None,
+            ),
+            blks: (0..n_layer)
+                .map(|iblk| {
+                    RWKVBlock::new(
+                        Normalization {
+                            d: hidden_size,
+                            epsilon: epsilon as _,
+                            items: NormType::RmsNorm {
+                                dt: ln_out_weight.dt(),
+                                scale: get(&format!("blocks.{iblk}.ln1.weight")),
+                            },
+                        },
+                        TimeMix {
+                            k: Linear::new(
+                                dt_linear,
+                                [hidden_size, hidden_size],
+                                get(&format!("blocks.{iblk}.time_mix_k.weight")),
+                                None,
+                            ),
+                            v: Linear::new(
+                                dt_linear,
+                                [hidden_size, hidden_size],
+                                get(&format!("blocks.{iblk}.time_mix_v.weight")),
+                                None,
+                            ),
+                            r: Linear::new(
+                                dt_linear,
+                                [hidden_size, hidden_size],
+                                get(&format!("blocks.{iblk}.time_mix_r.weight")),
+                                None,
+                            ),
+                        },
+                        Normalization {
+                            d: hidden_size,
+                            epsilon: epsilon as _,
+                            items: NormType::RmsNorm {
+                                dt: ln_out_weight.dt(),
+                                scale: get(&format!("blocks.{iblk}.ln2.weight")),
+                            },
+                        },
+                        ChannelMix {
+                            k: Linear::new(
+                                dt_linear,
+                                [hidden_size, hidden_size],
+                                get(&format!("blocks.{iblk}.channel_mix_k.weight")),
+                                None,
+                            ),
+                            r: Linear::new(
+                                dt_linear,
+                                [hidden_size, hidden_size],
+                                get(&format!("blocks.{iblk}.channel_mix_r.weight")),
+                                None,
+                            ),
+                            v: Linear::new(
+                                dt_linear,
+                                [hidden_size, hidden_size],
+                                get(&format!("blocks.{iblk}.channel_mix_v.weight")),
+                                None,
+                            ),
+                        },
+                    )
+                })
+                .collect(),
+            output_head: Some(OutputHead {
+                out_norm: Normalization {
+                    d: hidden_size,
+                    epsilon: epsilon as _,
+                    items: NormType::RmsNorm {
+                        dt: ln_out_weight.dt(),
+                        scale: ln_out_weight,
+                    },
+                },
+                lm_head: Linear::new(
+                    head_weight.dt(),
+                    [vocab_size, hidden_size],
+                    head_weight,
+                    None,
+                ),
+            }),
+        }
+    }
+
+    /// Construct the state cache tensor for the RWKV model
+    pub fn rwkv_state_cache(&self) -> Tensor {
+        let dt = self.tensors["emb.weight"].dt();
+        let n_layer = meta![self => block_count];
+        let hidden_size = meta![self => hidden_size];
+
+        // The RWKV state has shape [n_layer, 2, hidden_size] (one state each for time_mix and channel_mix)
+        Tensor::from_dim_slice(dt, [n_layer, 2, hidden_size])
+    }
+}

From 1f5cba4f933c8cd721ac04c2a2ff9205605017a4 Mon Sep 17 00:00:00 2001
From: Qsqsdac <962114354@qq.com>
Date: Wed, 23 Jul 2025 21:05:11 +0800
Subject: [PATCH 2/3] fix: correct model architecture details
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 llama.cu/src/model/rwkv.rs | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/llama.cu/src/model/rwkv.rs b/llama.cu/src/model/rwkv.rs
index 71189a75..1a3ecdbc 100644
--- a/llama.cu/src/model/rwkv.rs
+++ b/llama.cu/src/model/rwkv.rs
@@ -1,7 +1,7 @@
 use super::GGufModel;
 use crate::utils::meta;
 use nn::{
-    Linear, Normalization, NormType, OutputHead, RWKV, RWKVBlock, Tensor, TimeMix, ChannelMix,
+    ChannelMix, Embedding, Linear, NormType, Normalization, OutputHead, RWKV, RWKVBlock, Table, Tensor, TimeMix,
 };
 
 impl GGufModel<'_> {
@@ -21,12 +21,15 @@ impl GGufModel<'_> {
         let head_weight = get("head.weight");
 
         RWKV {
-            embedding: Linear::new(
-                dt_linear,
-                [hidden_size, vocab_size],
-                emb_weight,
-                None,
-            ),
+            embedding: Embedding {
+                dt: emb_weight.dt(),
+                d: hidden_size,
+                wte: Table {
+                    row: vocab_size,
+                    weight: emb_weight,
+                },
+                wpe: None,
+            },
             blks: (0..n_layer)
                 .map(|iblk| {
                     RWKVBlock::new(

From 68ebfad0a5439647ecd983a54be234d250b82d3c Mon Sep 17 00:00:00 2001
From: Qsqsdac <962114354@qq.com>
Date: Wed, 30 Jul 2025 22:39:55 +0800
Subject: [PATCH 3/3] fix: remove explicit state management logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 llama.cu/src/model/rwkv.rs | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/llama.cu/src/model/rwkv.rs b/llama.cu/src/model/rwkv.rs
index 1a3ecdbc..14058362 100644
--- a/llama.cu/src/model/rwkv.rs
+++ b/llama.cu/src/model/rwkv.rs
@@ -110,14 +110,4 @@ impl GGufModel<'_> {
             }),
         }
     }
-
-    /// Construct the state cache tensor for the RWKV model
-    pub fn rwkv_state_cache(&self) -> Tensor {
-        let dt = self.tensors["emb.weight"].dt();
-        let n_layer = meta![self => block_count];
-        let hidden_size = meta![self => hidden_size];
-
-        // The RWKV state has shape [n_layer, 2, hidden_size] (one state each for time_mix and channel_mix)
-        Tensor::from_dim_slice(dt, [n_layer, 2, hidden_size])
-    }
 }
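
For context, below is a minimal, self-contained sketch of the numerically
stable RWKV-v4 WKV recurrence that the TimeMix projections (k/v/r) built in
these patches feed into. The per-channel parameters `w` (time_decay, negative
in RWKV v4) and `u` (time_first) and the helper name `wkv_step` follow the
reference RWKV kernels and are assumptions here; this is illustrative only,
not the kernel used by llama.cu.

    /// One per-channel step of the log-sum-exp-stabilized WKV recurrence
    /// (RWKV v4). The state is (a, b, p): the running weighted sum of
    /// values, the running normalizer, and the current maximum exponent.
    /// NOTE: illustrative sketch; names `w`, `u`, `wkv_step` are assumed
    /// from the reference RWKV kernels, not from this repository.
    fn wkv_step(w: f32, u: f32, k: f32, v: f32, state: &mut (f32, f32, f32)) -> f32 {
        let (a, b, p) = *state;
        // Output: blend the decayed history with the current token,
        // which receives the one-time bonus `u`.
        let q = p.max(u + k);
        let (e1, e2) = ((p - q).exp(), (u + k - q).exp());
        let out = (e1 * a + e2 * v) / (e1 * b + e2);
        // State update: decay the history by `w`, then absorb the
        // current token without the bonus.
        let q = (p + w).max(k);
        let (e1, e2) = ((p + w - q).exp(), (k - q).exp());
        *state = (e1 * a + e2 * v, e1 * b + e2, q);
        out
    }

Initialize the state to (0.0, 0.0, f32::NEG_INFINITY) and call `wkv_step`
once per token for each channel; a real kernel would vectorize this across
all hidden_size channels of every layer.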