@@ -12,12 +12,12 @@ use crate::compute_cap::{
 };
 use crate::models::{
     BertConfig, BertModel, DistilBertConfig, DistilBertModel, GTEConfig, JinaBertModel,
-    JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig,
+    JinaCodeBertModel, MistralConfig, Model, NomicBertModel, NomicConfig, Qwen2Config,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
     FlashBertModel, FlashDistilBertModel, FlashGTEModel, FlashJinaBertModel,
-    FlashJinaCodeBertModel, FlashMistralModel, FlashNomicBertModel,
+    FlashJinaCodeBertModel, FlashMistralModel, FlashNomicBertModel, FlashQwen2Model,
 };
 use anyhow::Context;
 use candle::{DType, Device};
@@ -59,6 +59,7 @@ enum Config {
     Mistral(MistralConfig),
     #[serde(rename = "new")]
     Gte(GTEConfig),
+    Qwen2(Qwen2Config),
 }

 pub struct CandleBackend {
@@ -221,6 +222,10 @@ impl CandleBackend {
                 "GTE is only supported on Cuda devices in fp16 with flash attention enabled"
                     .to_string(),
             )),
+            (Config::Qwen2(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
+                "Qwen2 is only supported on Cuda devices in fp16 with flash attention enabled"
+                    .to_string(),
+            )),
             #[cfg(feature = "cuda")]
             (Config::Bert(config), Device::Cuda(_)) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
@@ -342,14 +347,25 @@ impl CandleBackend {
             #[cfg(feature = "cuda")]
             (Config::Gte(config), Device::Cuda(_)) => {
                 if dtype != DType::F16
-                    || !cfg!(feature = "flash-attn")
-                    || get_runtime_compute_cap().unwrap() < 80
+                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                 {
-                    return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
+                    return Err(BackendError::Start("GTE is only supported on Cuda devices in fp16 with flash attention enabled".to_string()));
                 }
                 tracing::info!("Starting FlashGTE model on {:?}", device);
                 Ok(Box::new(FlashGTEModel::load(vb, &config, model_type).s()?))
             }
+            #[cfg(feature = "cuda")]
+            (Config::Qwen2(config), Device::Cuda(_)) => {
+                if dtype != DType::F16
+                    || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
+                {
+                    return Err(BackendError::Start("Qwen2 is only supported on Cuda devices in fp16 with flash attention v2 enabled".to_string()));
+                }
+                tracing::info!("Starting FlashQwen2 model on {:?}", device);
+                Ok(Box::new(
+                    FlashQwen2Model::load(vb, &config, model_type).s()?,
+                ))
+            }
         };

         Ok(Self {
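Context for the new `Config::Qwen2(Qwen2Config)` variant: `Config` is deserialized from a model's `config.json`, and the variant chosen there is what routes to the `FlashQwen2Model` match arm above. Below is a minimal, self-contained sketch of that dispatch pattern, assuming an internally tagged serde enum keyed on `model_type`; the serde attributes, the `Qwen2Config` fields, and the sample values are illustrative assumptions, not copied from the repository.

```rust
use serde::Deserialize;

// Hypothetical stand-in for the real per-model config struct;
// unknown JSON fields are ignored since `deny_unknown_fields` is not set.
#[derive(Debug, Deserialize)]
struct Qwen2Config {
    hidden_size: usize,
    num_attention_heads: usize,
}

// Assumed shape of the tagged enum: serde consumes the `model_type`
// field of config.json to pick the variant, then deserializes the
// remaining fields into that variant's config struct.
#[derive(Debug, Deserialize)]
#[serde(tag = "model_type", rename_all = "lowercase")]
enum Config {
    Qwen2(Qwen2Config),
}

fn main() -> Result<(), serde_json::Error> {
    let raw = r#"{"model_type": "qwen2", "hidden_size": 3584, "num_attention_heads": 28}"#;
    let config: Config = serde_json::from_str(raw)?;
    // Matching on the variant is the dispatch step the backend performs
    // when it decides which model implementation to load.
    match config {
        Config::Qwen2(c) => println!("loaded qwen2 config: {c:?}"),
    }
    Ok(())
}
```

Under this pattern, supporting a new architecture is a matter of adding an enum variant plus the corresponding match arms, which is exactly what the hunks above do for Qwen2.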