@@ -201,7 +201,7 @@ print(response)
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
- from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
+ from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator, PairwiseModel
from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss

model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
@@ -211,9 +211,9 @@ epochs: int = 3
train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
- model = model.set_train_type('pairwise')
+ train_model = PairwiseModel(model)

- optimizer = AdamW(model.parameters(), lr=5e-5)
+ optimizer = AdamW(train_model.parameters(), lr=5e-5)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps
@@ -227,7 +227,7 @@ training_arguments = TrainingArguments(
    logging_steps=100,
)
trainer = RetrievalTrainer(
-     model=model,
+     model=train_model,
    args=training_arguments,
    train_dataset=train_dataset,
    data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[32, 128]),
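
# A minimal follow-up sketch (not shown in this change): once the trainer above is built
# around the PairwiseModel wrapper, training would be launched as usual. This assumes
# RetrievalTrainer follows the standard transformers Trainer API (train / save_model) and
# that AutoModelForEmbedding exposes an encode() helper; the output path and the sample
# sentences below are made up for illustration.
trainer.train()
trainer.save_model("outputs/mpnet-stsb-pairwise")  # hypothetical output directory

# The wrapped base model still produces sentence embeddings after fine-tuning.
embeddings = model.encode(["如何更换花呗绑定银行卡", "花呗更改绑定银行卡"])
print(embeddings.shape)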
@@ -246,28 +246,28 @@ import os
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
- from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
+ from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator, PairwiseModel
from retrievals.losses import InfoNCE, SimCSE, TripletLoss
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

def add_instructions(example):
-     example['query'] = query_instruction + example['query']
-     example['positive'] = document_instruction + example['positive']
+     example['query'] = query_instruction.format(example['query'])
+     example['positive'] = document_instruction.format(example['positive'])
    return example

model_name_or_path: str = "Qwen/Qwen2-1.5B-Instruct"
batch_size: int = 8
epochs: int = 3
- query_instruction = "Retrieve relevant passages that answer the query\nQuery: "
- document_instruction = "Document: "
+ query_instruction = "Retrieve relevant passages that answer the query\nQuery: {}"
+ document_instruction = "Document: {}"

train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'positive'})
train_dataset = train_dataset.map(add_instructions)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="last", use_lora=True)
- model = model.set_train_type('pairwise', loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)))
- optimizer = AdamW(model.parameters(), lr=5e-5)
+ train_model = PairwiseModel(model, loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)))
+ optimizer = AdamW(train_model.parameters(), lr=5e-5)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)

@@ -279,7 +279,7 @@ training_arguments = TrainingArguments(
    logging_steps=100,
)
trainer = RetrievalTrainer(
-     model=model,
+     model=train_model,
    args=training_arguments,
    train_dataset=train_dataset,
    data_collator=RetrievalCollator(tokenizer, keys=['query', 'positive'], max_lengths=[64, 128]),
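
# A minimal sketch of the new instruction templates: query_instruction and
# document_instruction are now str.format templates, so the text is substituted into the
# {} placeholder instead of being appended to a bare prefix. The sample query is made up.
query_instruction = "Retrieve relevant passages that answer the query\nQuery: {}"
print(query_instruction.format("如何更换花呗绑定银行卡"))
# Retrieve relevant passages that answer the query
# Query: 如何更换花呗绑定银行卡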
@@ -439,8 +439,8 @@ train_dataset = RetrievalTrainDataset(
    data_name_or_path='C-MTEB/T2Reranking',
    positive_key='positive',
    negative_key='negative',
-     query_instruction='A: ',
-     document_instruction='B: ',
+     query_instruction='A: {}',
+     document_instruction='B: {}',
    dataset_split='dev',
)
data_collator = LLMRerankCollator(