@@ -195,15 +195,14 @@ print(response)
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
- from retrievals import AutoModelForEmbedding, RetrievalTrainer, PairCollator, TripletCollator
+ from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss

model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
batch_size: int = 32
epochs: int = 3

train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
- train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'positive'})
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
model = model.set_train_type('pairwise')
@@ -225,7 +224,7 @@ trainer = RetrievalTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
-   data_collator=PairCollator(tokenizer, query_max_length=32, document_max_length=128),
+   data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[32, 128]),
    loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)),
)
trainer.optimizer = optimizer
@@ -240,7 +239,7 @@ trainer.train()
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
- from retrievals import AutoModelForEmbedding, RetrievalTrainer, PairCollator, TripletCollator
+ from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
from retrievals.losses import InfoNCE, SimCSE, TripletLoss

def add_instructions(example):
@@ -255,7 +254,6 @@ query_instruction = "Retrieve relevant passages that answer the query\nQuery: "
document_instruction = "Document: "

train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
- train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'positive'})
train_dataset = train_dataset.map(add_instructions)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="last", use_lora=True)
@@ -275,7 +273,7 @@ trainer = RetrievalTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
-   data_collator=PairCollator(tokenizer, query_max_length=64, document_max_length=128),
+   data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[32, 128]),
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
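
For reference, the migrated collator call in isolation — a minimal sketch based only on the RetrievalCollator signature visible in the hunks above: instead of renaming dataset columns to 'query'/'positive' and passing query_max_length/document_max_length to PairCollator, the single RetrievalCollator takes the raw column names and a per-column list of max lengths.

# Minimal sketch, assuming the RetrievalCollator signature shown in the diff above.
from transformers import AutoTokenizer
from retrievals import RetrievalCollator

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
data_collator = RetrievalCollator(
    tokenizer,
    keys=['sentence1', 'sentence2'],  # dataset columns used as-is; no rename_columns step needed
    max_lengths=[32, 128],            # max token length per key (query, document)
)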