diff --git a/Cargo.lock b/Cargo.lock
index 01c8d6f4..8d50b6d9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -595,6 +595,18 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "candle-moe"
+version = "0.0.1"
+source = "git+https://github.com/kozistr/candle-moe?rev=990ac1f42248dd441c51c9b5bcb73c5b77c03f99#990ac1f42248dd441c51c9b5bcb73c5b77c03f99"
+dependencies = [
+ "anyhow",
+ "bindgen_cuda",
+ "candle-core",
+ "cudarc",
+ "half",
+]
+
 [[package]]
 name = "candle-nn"
 version = "0.8.4"
@@ -4502,6 +4514,7 @@ dependencies = [
  "candle-flash-attn",
  "candle-flash-attn-v1",
  "candle-layer-norm",
+ "candle-moe",
  "candle-nn",
  "candle-rotary",
  "candle-transformers",
diff --git a/Cargo.toml b/Cargo.toml
index 1220c9ad..d6feae96 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -42,7 +42,7 @@ serde_json = "1.0"
 thiserror = "1.0"
 rand = "0.9"
 serial_test = "2.0.0"
-cudarc = { version = "0.13", features =["cuda-12020"], default-features = false }
+cudarc = { version = "0.13", features = ["cuda-12020"], default-features = false }
 intel-mkl-src = { version = "0.8", default-features = false }
 candle = { version = "0.8", package = "candle-core" }
 candle-nn = { version = "0.8" }
@@ -52,10 +52,11 @@ candle-cublaslt = { version = "0.0.1" }
 candle-layer-norm = { version = "0.0.1" }
 candle-rotary = { version = "0.0.1" }
 candle-flash-attn-v1 = { version = "0.0.1" }
+candle-moe = { git = "https://github.com/kozistr/candle-moe", rev = "990ac1f42248dd441c51c9b5bcb73c5b77c03f99" }
 half = { version = "2.3.1", features = ["num-traits"] }
 
 [patch.crates-io]
-cudarc = { git = "https://github.com/Narsil/cudarc" , rev = "8b4f18b4bcd5e4b1a9daf40abc3a2e27f83f06e9"}
+cudarc = { git = "https://github.com/Narsil/cudarc" , rev = "8b4f18b4bcd5e4b1a9daf40abc3a2e27f83f06e9" }
 candle = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-core" }
 candle-nn = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-nn" }
 candle-transformers = { git = "https://github.com/huggingface/candle", rev = "6381023982251959a2c9bab7378b3013304e192b", package = "candle-transformers" }
diff --git a/backends/candle/Cargo.toml b/backends/candle/Cargo.toml
index 73d0f417..4de054ee 100644
--- a/backends/candle/Cargo.toml
+++ b/backends/candle/Cargo.toml
@@ -17,6 +17,7 @@ candle-flash-attn-v1 = { workspace = true, optional = true }
 candle-cublaslt = { workspace = true, optional = true }
 candle-layer-norm = { workspace = true, optional = true }
 candle-rotary = { workspace = true, optional = true }
+candle-moe = { workspace = true, optional = true }
 nohash-hasher = { workspace = true }
 text-embeddings-backend-core = { path = "../core" }
 tracing = { workspace = true }
@@ -41,6 +42,6 @@ anyhow = { version = "1", features = ["backtrace"] }
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
 metal = ["candle/metal", "candle-nn/metal"]
 mkl = ["dep:intel-mkl-src", "candle/_mkl"]
-cuda = ["candle/_cuda", "candle-nn/_cuda", "dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary"]
+cuda = ["candle/_cuda", "candle-nn/_cuda", "dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary", "dep:candle-moe"]
 flash-attn-v1 = ["dep:candle-flash-attn-v1", "cuda"]
 flash-attn = ["dep:candle-flash-attn", "cuda"]
diff --git a/backends/candle/src/models/nomic.rs b/backends/candle/src/models/nomic.rs
index 8748db38..4956eede 100644
--- a/backends/candle/src/models/nomic.rs
+++ b/backends/candle/src/models/nomic.rs
@@ -3,6 +3,8 @@ use crate::layers::{
 };
 use crate::models::Model;
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
+#[cfg(feature = "cuda")]
+use candle_moe;
 use candle_nn::{Embedding, VarBuilder};
 use candle_transformers::models::deepseek2::{BincountOp, NonZeroOp, TopKLastDimOp, TopKOutput};
 use serde::Deserialize;
@@ -239,6 +241,55 @@ impl NomicRouter {
     }
 }
 
+#[cfg(feature = "cuda")]
+pub struct NomicFusedRouter {
+    layer: Linear,
+    top_k: usize,
+
+    span: tracing::Span,
+}
+
+#[cfg(feature = "cuda")]
+impl NomicFusedRouter {
+    pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
+        let num_experts = config.num_experts.unwrap();
+        let top_k = config.moe_top_k.unwrap();
+
+        let layer_weight = vb.pp("layer").get((num_experts, config.n_embd), "weight")?;
+        let layer = Linear::new(layer_weight, None, None);
+
+        Ok(Self {
+            layer,
+            top_k,
+            span: tracing::span!(tracing::Level::TRACE, "router"),
+        })
+    }
+
+    pub fn forward(&self, hidden_states: &Tensor) -> Result<(Tensor, Tensor)> {
+        let _enter = self.span.enter();
+
+        let device = hidden_states.device();
+
+        let weights = hidden_states.reshape(((), hidden_states.dim(D::Minus1)?))?;
+        let weights = self.layer.forward(&weights)?.to_dtype(DType::F32)?;
+
+        let (seq_len, _) = weights.shape().dims2()?;
+
+        let topk_weight = Tensor::zeros((seq_len, self.top_k), DType::F32, device)?;
+        let topk_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?;
+        let token_expert_indices = Tensor::zeros((seq_len, self.top_k), DType::U32, device)?;
+
+        candle_moe::apply_topk_softmax_inplace(
+            &weights,
+            &topk_weight,
+            &topk_indices,
+            &token_expert_indices,
+        )?;
+
+        Ok((topk_weight, topk_indices))
+    }
+}
+
 pub struct NomicExpertMLP {
     w1: Tensor,
     w2: Tensor,
@@ -363,6 +414,95 @@ impl NomicExperts {
     }
 }
 
+#[cfg(feature = "cuda")]
+pub struct NomicFusedExperts {
+    gate_weight: Tensor,
+    up_weight: Tensor,
+    bias: Tensor,
+    fused_moe: candle_moe::FusedMoeForward,
+
+    span: tracing::Span,
+}
+
+#[cfg(feature = "cuda")]
+impl NomicFusedExperts {
+    pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
+        let hidden_size = config.n_embd;
+        let ffn_hidden_size = config.n_inner;
+        let num_experts = config.num_experts.unwrap();
+        let top_k = config.moe_top_k.unwrap();
+        let activation = config.activation_function.clone();
+
+        let gate_weight = vb
+            .pp("mlp")
+            .get((num_experts * ffn_hidden_size, hidden_size), "w1")?
+            .reshape((num_experts, ffn_hidden_size, hidden_size))?
+            .permute((0, 2, 1))?
+            .contiguous()?;
+        let up_weight = vb
+            .pp("mlp")
+            .get((num_experts * ffn_hidden_size, hidden_size), "w2")?
+            .reshape((num_experts, ffn_hidden_size, hidden_size))?
+            .permute((0, 2, 1))?
+            .contiguous()?;
+
+        let bias = vb.get((config.n_embd,), "bias")?;
+
+        let moe_act = match activation {
+            HiddenAct::Silu => candle_moe::Activation::Silu,
+            HiddenAct::Gelu => candle_moe::Activation::Gelu,
+            HiddenAct::Relu => candle_moe::Activation::Relu,
+            _ => candle::bail!("not supported activation type"),
+        };
+
+        let fused_moe = candle_moe::FusedMoeForward::new(num_experts, top_k, moe_act);
+
+        Ok(Self {
+            gate_weight,
+            up_weight,
+            bias,
+            fused_moe,
+            span: tracing::span!(tracing::Level::TRACE, "experts"),
+        })
+    }
+
+    pub fn forward(
+        &self,
+        hidden_states: &Tensor,
+        top_weights: &Tensor,
+        top_experts: &Tensor,
+    ) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let dims = hidden_states.dims();
+        let ndim = dims.len();
+
+        let (bs, seq_len, hidden_size) = match ndim {
+            3 => (dims[0], dims[1], dims[2]),
+            2 => (1, dims[0], dims[1]),
+            _ => unreachable!(),
+        };
+
+        let hidden_states = hidden_states.reshape(((), hidden_size))?;
+
+        let mut out = self.fused_moe.forward(
+            &hidden_states,
+            &self.gate_weight,
+            &self.up_weight,
+            None,
+            &top_weights,
+            &top_experts,
+            1_u32, // Nomic MoE
+        )?;
+
+        if ndim == 3 {
+            out = out.reshape((bs, seq_len, hidden_size))?;
+        }
+
+        out.broadcast_add(&self.bias)
+    }
+}
+
 pub struct NomicMoELayer {
     router: NomicRouter,
     experts: NomicExperts,
@@ -392,8 +532,41 @@ impl NomicMoELayer {
     }
 }
 
+#[cfg(feature = "cuda")]
+pub struct NomicFusedMoELayer {
+    router: NomicFusedRouter,
+    experts: NomicFusedExperts,
+
+    span: tracing::Span,
+}
+
+#[cfg(feature = "cuda")]
+impl NomicFusedMoELayer {
+    pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
+        let router = NomicFusedRouter::load(vb.pp("router"), config)?;
+        let experts = NomicFusedExperts::load(vb.pp("experts"), config)?;
+
+        Ok(Self {
+            router,
+            experts,
+            span: tracing::span!(tracing::Level::TRACE, "moe"),
+        })
+    }
+
+    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+
+        let (top_weights, top_experts) = self.router.forward(hidden_states)?;
+
+        self.experts
+            .forward(hidden_states, &top_weights, &top_experts)
+    }
+}
+
 pub enum NomicMLP {
     MoE(NomicMoELayer),
+    #[cfg(feature = "cuda")]
+    FusedMoE(NomicFusedMoELayer),
     GatedMLP(NomicBertGatedMLP),
     Mlp(NomicBertMLP),
 }
@@ -403,7 +576,14 @@ impl NomicMLP {
         let use_moe = matches!(config.moe_every_n_layers, Some(n) if n > 0 && index % n == 1);
 
         if use_moe {
-            Ok(Self::MoE(NomicMoELayer::load(vb, config)?))
+            #[cfg(feature = "cuda")]
+            {
+                Ok(Self::FusedMoE(NomicFusedMoELayer::load(vb, config)?))
+            }
+            #[cfg(not(feature = "cuda"))]
+            {
+                Ok(Self::MoE(NomicMoELayer::load(vb, config)?))
+            }
         } else if config.activation_function == HiddenAct::Gelu {
             Ok(Self::Mlp(NomicBertMLP::load(vb, config)?))
         } else {
@@ -414,6 +594,8 @@ pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         match self {
            Self::MoE(layer) => layer.forward(hidden_states),
+            #[cfg(feature = "cuda")]
+            Self::FusedMoE(layer) => layer.forward(hidden_states),
             Self::GatedMLP(layer) => layer.forward(hidden_states),
             Self::Mlp(layer) => layer.forward(hidden_states),
         }
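
For reference, the fused CUDA path added above reduces to two candle-moe calls: apply_topk_softmax_inplace for the router and FusedMoeForward::forward for the experts. Below is a minimal free-standing sketch of that flow, assuming the candle-moe API exactly as it is exercised in this diff; fused_nomic_moe is a hypothetical helper for illustration (and the hard-coded Gelu activation stands in for the config-driven mapping), not code from the patch.

// Sketch only: mirrors NomicFusedRouter + NomicFusedExperts from the diff above.
use candle::{DType, Result, Tensor};

fn fused_nomic_moe(
    hidden_states: &Tensor, // (seq_len, hidden_size), on a CUDA device
    router_logits: &Tensor, // (seq_len, num_experts), f32
    gate_weight: &Tensor,   // (num_experts, hidden_size, ffn_hidden_size), as produced in load()
    up_weight: &Tensor,     // (num_experts, hidden_size, ffn_hidden_size)
    num_experts: usize,
    top_k: usize,
) -> Result<Tensor> {
    let device = hidden_states.device();
    let (seq_len, _) = router_logits.shape().dims2()?;

    // Router step: output buffers are filled in place by the fused top-k softmax kernel.
    let top_weights = Tensor::zeros((seq_len, top_k), DType::F32, device)?;
    let top_experts = Tensor::zeros((seq_len, top_k), DType::U32, device)?;
    let token_expert_indices = Tensor::zeros((seq_len, top_k), DType::U32, device)?;
    candle_moe::apply_topk_softmax_inplace(
        router_logits,
        &top_weights,
        &top_experts,
        &token_expert_indices,
    )?;

    // Expert step: one fused kernel over the selected experts; the trailing `1_u32`
    // matches the "Nomic MoE" variant flag passed in NomicFusedExperts::forward.
    let fused_moe =
        candle_moe::FusedMoeForward::new(num_experts, top_k, candle_moe::Activation::Gelu);
    fused_moe.forward(
        hidden_states,
        gate_weight,
        up_weight,
        None,
        &top_weights,
        &top_experts,
        1_u32,
    )
}

In the patch itself this logic stays split across NomicFusedRouter and NomicFusedExperts so the per-expert bias add and the 2D/3D reshape handling remain with the expert weights.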