1
+ mod alibi;
1
2
#[ cfg( feature = "cuda" ) ]
2
3
mod compute_cap;
3
4
#[ cfg( feature = "cuda" ) ]
@@ -9,7 +10,9 @@ mod models;
9
10
use crate :: compute_cap:: { incompatible_compute_cap, COMPILE_COMPUTE_CAP , RUNTIME_COMPUTE_CAP } ;
10
11
#[ cfg( feature = "cuda" ) ]
11
12
use crate :: models:: FlashBertModel ;
12
- use crate :: models:: { BertModel , EmbeddingModel , PositionEmbeddingType , QuantBertModel } ;
13
+ use crate :: models:: {
14
+ BertModel , EmbeddingModel , JinaBertModel , PositionEmbeddingType , QuantBertModel ,
15
+ } ;
13
16
use candle:: { DType , Device } ;
14
17
use candle_nn:: VarBuilder ;
15
18
use models:: Config ;
@@ -47,8 +50,6 @@ impl CandleBackend {
47
50
48
51
let model: Box < dyn EmbeddingModel + Send > = match device {
49
52
Device :: Cpu => {
50
- tracing:: info!( "Starting Bert model on CPU" ) ;
51
-
52
53
if & dtype == "float32" || & dtype == "float16" {
53
54
let dtype = if & dtype == "float32" {
54
55
DType :: F32
@@ -70,14 +71,21 @@ impl CandleBackend {
70
71
}
71
72
. s ( ) ?;
72
73
73
- Box :: new ( BertModel :: load ( vb, & config, pool) . s ( ) ?)
74
+ if config. position_embedding_type == PositionEmbeddingType :: Alibi {
75
+ tracing:: info!( "Starting JinaBert model on CPU" ) ;
76
+ Box :: new ( JinaBertModel :: load ( vb, & config, pool) . s ( ) ?)
77
+ } else {
78
+ tracing:: info!( "Starting Bert model on CPU" ) ;
79
+ Box :: new ( BertModel :: load ( vb, & config, pool) . s ( ) ?)
80
+ }
74
81
} else if & dtype == "q6k" {
75
82
let vb = candle_transformers:: quantized_var_builder:: VarBuilder :: from_gguf (
76
83
model_path. join ( "ggml-model-q6k.bin" ) ,
77
84
)
78
85
. map_err ( |err| BackendError :: Start ( err. to_string ( ) ) ) ?;
79
86
tracing:: info!( "vb" ) ;
80
87
88
+ tracing:: info!( "Starting QuantBert model on CPU" ) ;
81
89
Box :: new ( QuantBertModel :: load ( vb, & config, pool) . s ( ) ?)
82
90
} else {
83
91
return Err ( BackendError :: Start ( format ! (
@@ -130,6 +138,9 @@ impl CandleBackend {
130
138
{
131
139
tracing:: info!( "Starting FlashBert model on Cuda" ) ;
132
140
Box :: new ( FlashBertModel :: load ( vb, & config, pool) . s ( ) ?)
141
+ } else if config. position_embedding_type == PositionEmbeddingType :: Alibi {
142
+ tracing:: info!( "Starting JinaBert model on Cuda" ) ;
143
+ Box :: new ( JinaBertModel :: load ( vb, & config, pool) . s ( ) ?)
133
144
} else {
134
145
tracing:: info!( "Starting Bert model on Cuda" ) ;
135
146
Box :: new ( BertModel :: load ( vb, & config, pool) . s ( ) ?)
0 commit comments