backends/candle/src/models/modernbert.rs (75 changes: 39 additions & 36 deletions)

--- a/backends/candle/src/models/modernbert.rs
+++ b/backends/candle/src/models/modernbert.rs
@@ -466,7 +466,7 @@ pub struct ModernBertModel {
     pool: Pool,
     classifier: Option<Box<dyn ClassificationHead + Send>>,
 
-    local_attention: usize,
+    window_size: usize,
     global_inv_freqs: Tensor,
     local_inv_freqs: Tensor,
     rotary_dim: usize,
@@ -541,7 +541,7 @@ impl ModernBertModel {
             final_norm,
             pool,
             classifier,
-            local_attention: config.local_attention,
+            window_size: config.local_attention / 2,
             global_inv_freqs,
             local_inv_freqs,
             rotary_dim: attention_head_size,
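
A note on the halving: config.local_attention spans the whole attention window, while the mask construction below works with a per-side radius, hence window_size = config.local_attention / 2. A standalone sketch of the arithmetic, assuming an illustrative local_attention of 128 (the real value comes from the model config):

    fn main() {
        let local_attention = 128usize; // hypothetical; read from config in practice
        let window_size = local_attention / 2; // 64 positions per side

        // Band bounds for a query at position i, mirroring get_window_mask below.
        let (i, seq_len) = (100usize, 512usize);
        let start = i.saturating_sub(window_size);
        let end = (i + window_size + 1).min(seq_len);
        assert_eq!((start, end), (36, 165)); // i attends to keys in [36, 165)
    }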
@@ -577,32 +577,47 @@ impl ModernBertModel {
         Ok(extended_attention_mask)
     }
 
-    fn get_local_attention_mask(&self, attention_mask: &Tensor) -> Result<Tensor> {
-        let dev = attention_mask.device();
-        let attention_mask = attention_mask
-            .to_device(&Device::Cpu)?
-            .to_dtype(DType::U8)?;
-
-        let mask_shape = attention_mask.shape();
-        let (_, _, seq_len, _) = mask_shape.dims4()?;
-
-        let rows = Tensor::arange(0, seq_len as i64, attention_mask.device())?.unsqueeze(0)?;
-        let rows = rows.broadcast_as((seq_len, seq_len))?;
-
-        let distance = (&rows - &rows.t()?)?.abs()?;
-
-        let window_size = (self.local_attention / 2) as i64;
-        let window_mask = distance
-            .le(window_size)?
-            .unsqueeze(0)?
-            .unsqueeze(0)?
-            .broadcast_as(mask_shape)?;
-
-        let zero_tensor = Tensor::zeros_like(&attention_mask)?;
-        let local_attention_mask = attention_mask.where_cond(&window_mask, &zero_tensor)?;
-        let local_attention_mask = local_attention_mask.to_device(dev)?;
-
-        Ok(local_attention_mask)
+    fn get_window_mask(&self, seq_len: usize) -> Result<Tensor> {
+        let mut inverted_window_mask = vec![0.0_f32; seq_len * seq_len];
+
+        for i in 0..seq_len {
+            let start = i.saturating_sub(self.window_size);
+            let end = (i + self.window_size + 1).min(seq_len);
+
+            inverted_window_mask[(i * seq_len)..(i * seq_len + start)].fill(1.0);
+            inverted_window_mask[(i * seq_len + end)..((i + 1) * seq_len)].fill(1.0);
+        }
+
+        let inverted_window_mask =
+            Tensor::from_slice(&inverted_window_mask, (seq_len, seq_len), &self.device)?;
+
+        Ok(inverted_window_mask)
+    }
+
+    fn get_attention_mask(
+        &self,
+        attention_mask: Option<&Tensor>,
+        input_shape: &(usize, usize),
+    ) -> Result<(Tensor, Tensor)> {
+        let global_attention_mask = self
+            .get_global_attention_mask(attention_mask, input_shape)?
+            .to_dtype(self.dtype)?;
+
+        let min_value = match self.dtype {
+            DType::F32 => f32::MIN as f64,
+            _ => -65504.0, // f16 minimum value
+        };
+
+        let global_attention_mask = ((1.0 - global_attention_mask)? * min_value)?;
+        let global_attention_mask = global_attention_mask.to_dtype(self.dtype)?;
+
+        let seq_len = global_attention_mask.dim(2)?;
+        let window_mask = self.get_window_mask(seq_len)?;
+        let window_mask = (window_mask * min_value)?.to_dtype(self.dtype)?;
+
+        let local_attention_mask = global_attention_mask.broadcast_add(&window_mask)?;
+
+        Ok((global_attention_mask, local_attention_mask))
     }
 
     fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
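
get_window_mask builds the band on the host as a flat Vec, writing 1.0 at every position outside the window; the mask is inverted so that multiplying by min_value afterwards turns exactly the out-of-window entries into large negative biases. A self-contained sketch of the same loop, assuming seq_len = 5 and window_size = 1, with no candle dependency:

    fn main() {
        let (seq_len, window_size) = (5usize, 1usize);
        let mut mask = vec![0.0_f32; seq_len * seq_len];
        for i in 0..seq_len {
            let start = i.saturating_sub(window_size);
            let end = (i + window_size + 1).min(seq_len);
            // Block everything left and right of the band around i.
            mask[i * seq_len..i * seq_len + start].fill(1.0);
            mask[i * seq_len + end..(i + 1) * seq_len].fill(1.0);
        }
        for row in mask.chunks(seq_len) {
            println!("{row:?}");
        }
        // Row i keeps (0.0) only columns i-1..=i+1:
        // [0.0, 0.0, 1.0, 1.0, 1.0]
        // [0.0, 0.0, 0.0, 1.0, 1.0]
        // [1.0, 0.0, 0.0, 0.0, 1.0]
        // [1.0, 1.0, 0.0, 0.0, 0.0]
        // [1.0, 1.0, 1.0, 0.0, 0.0]
    }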
@@ -675,20 +690,8 @@ impl ModernBertModel {
         let mut input_lengths =
             Tensor::from_vec(input_lengths, (batch_size, 1), &self.device)?.to_dtype(self.dtype)?;
 
-        let global_attention_mask = self
-            .get_global_attention_mask(attention_mask.as_ref(), &shape)?
-            .to_dtype(self.dtype)?;
-        let local_attention_mask = self
-            .get_local_attention_mask(&global_attention_mask)?
-            .to_dtype(self.dtype)?;
-
-        let min_value = match self.dtype {
-            DType::F32 => f32::MIN as f64,
-            _ => -65504.0, // f16 minimum value
-        };
-
-        let global_attention_mask = ((1.0 - global_attention_mask)? * min_value)?;
-        let local_attention_mask = ((1.0 - local_attention_mask)? * min_value)?;
+        let (global_attention_mask, local_attention_mask) =
+            self.get_attention_mask(attention_mask.as_ref(), &shape)?;
 
         let global_rotary_cache =
             get_cos_sin(max_length, &self.global_inv_freqs, self.dtype, true)?;
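
The two masks are combined additively in get_attention_mask: each holds 0.0 where attention is allowed and min_value where it is blocked, so broadcast_add yields a local mask that blocks a key if either the padding mask or the window band does (entries blocked by both overflow toward -inf, which softmax maps to 0 just the same). A scalar sketch of that algebra with made-up scores, independent of the candle API:

    fn main() {
        let min_value = f32::MIN;
        // One query row over three keys: key 1 is padding, key 2 falls
        // outside the sliding window.
        let global = [0.0_f32, min_value, 0.0];
        let window = [0.0_f32, 0.0, min_value];
        let local: Vec<f32> = global.iter().zip(window).map(|(g, w)| g + w).collect();

        // Masked softmax over illustrative attention scores.
        let scores = [1.0_f32, 2.0, 3.0];
        let logits: Vec<f32> = scores.iter().zip(&local).map(|(s, m)| s + m).collect();
        let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.iter().map(|l| (l - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        let probs: Vec<f32> = exps.iter().map(|e| e / sum).collect();
        println!("{probs:?}"); // ~[1.0, 0.0, 0.0]: only the visible key gets weight
    }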