
Commit bcbf148

refactor(collator): combine collator for pair and triplet
1 parent 7498bb8

16 files changed: +112 -198 lines
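Taken together, the call-site changes below replace the separate PairCollator and TripletCollator with a single RetrievalCollator that takes the dataset column names (`keys`) and a per-column `max_lengths` list. A minimal sketch of the unified usage, inferred from the constructor calls shown in the diffs; the tokenizer checkpoint and example column names are only illustrations taken from the changed files:

```python
# Sketch of the unified collator API as it appears at the call sites in this commit.
from transformers import AutoTokenizer

from retrievals import RetrievalCollator

tokenizer = AutoTokenizer.from_pretrained(
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", use_fast=False
)

# Pairwise training, previously PairCollator(tokenizer, query_max_length=..., document_max_length=...):
# one entry in `keys` per dataset column, one entry in `max_lengths` per key.
pair_collator = RetrievalCollator(tokenizer, keys=['query', 'positive'], max_lengths=[64, 128])

# Triplet training, previously TripletCollator(...): the same collator with a third column.
triplet_collator = RetrievalCollator(
    tokenizer,
    keys=['query', 'positive', 'negative'],
    max_lengths=[64, 128, 128],
)

# Either collator is passed to the trainer as before, e.g.
#   RetrievalTrainer(..., data_collator=pair_collator, ...)
```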

README.md (+4 -6)

@@ -195,15 +195,14 @@ print(response)
 import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
-from retrievals import AutoModelForEmbedding, RetrievalTrainer, PairCollator, TripletCollator
+from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
 from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss

 model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
 batch_size: int = 32
 epochs: int = 3

 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
-train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'positive'})
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
 model = model.set_train_type('pairwise')
@@ -225,7 +224,7 @@ trainer = RetrievalTrainer(
     model=model,
     args=training_arguments,
     train_dataset=train_dataset,
-    data_collator=PairCollator(tokenizer, query_max_length=32, document_max_length=128),
+    data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[32, 128]),
     loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)),
 )
 trainer.optimizer = optimizer
@@ -240,7 +239,7 @@ trainer.train()
 import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
-from retrievals import AutoModelForEmbedding, RetrievalTrainer, PairCollator, TripletCollator
+from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
 from retrievals.losses import InfoNCE, SimCSE, TripletLoss

 def add_instructions(example):
@@ -255,7 +254,6 @@ query_instruction = "Retrieve relevant passages that answer the query\nQuery: "
 document_instruction = "Document: "

 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
-train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'positive'})
 train_dataset = train_dataset.map(add_instructions)
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="last", use_lora=True)
@@ -275,7 +273,7 @@ trainer = RetrievalTrainer(
     model=model,
     args=training_arguments,
     train_dataset=train_dataset,
-    data_collator=PairCollator(tokenizer, query_max_length=64, document_max_length=128),
+    data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[32, 128]),
 )
 trainer.optimizer = optimizer
 trainer.scheduler = scheduler

README_ja-JP.md (+7 -5)

@@ -190,34 +190,36 @@ print(response)
 import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
-from retrievals import AutoModelForEmbedding, RetrievalTrainer, PairCollator, TripletCollator
+from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
 from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss

 model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
-batch_size: int = 128
+batch_size: int = 32
 epochs: int = 3

 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
-train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'document'})
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
 model = model.set_train_type('pairwise')

 optimizer = AdamW(model.parameters(), lr=5e-5)
 num_train_steps = int(len(train_dataset) / batch_size * epochs)
-scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
+scheduler = get_linear_schedule_with_warmup(
+    optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps
+)

 training_arguments = TrainingArguments(
     output_dir='./checkpoints',
     num_train_epochs=epochs,
     per_device_train_batch_size=batch_size,
     remove_unused_columns=False,
+    logging_steps=100,
 )
 trainer = RetrievalTrainer(
     model=model,
     args=training_arguments,
     train_dataset=train_dataset,
-    data_collator=PairCollator(tokenizer, query_max_length=128, document_max_length=128),
+    data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[32, 128]),
     loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)),
 )
 trainer.optimizer = optimizer

README_zh-CN.md (+4 -7)

@@ -198,20 +198,17 @@ print(response)
 <details><summary> Fine-tune the embedding model </summary>

 ```python
-import os
 import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
-from retrievals import AutoModelForEmbedding, RetrievalTrainer, PairCollator, TripletCollator
+from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
 from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss
-os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

 model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
 batch_size: int = 32
 epochs: int = 3

 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
-train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'positive'})
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
 model = model.set_train_type('pairwise')
@@ -233,7 +230,7 @@ trainer = RetrievalTrainer(
     model=model,
     args=training_arguments,
     train_dataset=train_dataset,
-    data_collator=PairCollator(tokenizer, query_max_length=32, document_max_length=128),
+    data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[32, 128]),
     loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)),
 )
 trainer.optimizer = optimizer
@@ -249,7 +246,7 @@ import os
 import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
-from retrievals import AutoModelForEmbedding, RetrievalTrainer, PairCollator, TripletCollator
+from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
 from retrievals.losses import InfoNCE, SimCSE, TripletLoss
 os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

@@ -285,7 +282,7 @@ trainer = RetrievalTrainer(
     model=model,
     args=training_arguments,
     train_dataset=train_dataset,
-    data_collator=PairCollator(tokenizer, query_max_length=64, document_max_length=128),
+    data_collator=RetrievalCollator(tokenizer, keys=['query', 'positive'], max_lengths=[64, 128]),
 )
 trainer.optimizer = optimizer
 trainer.scheduler = scheduler

docs/source/embed.rst (+4 -3)

@@ -88,15 +88,14 @@ If the positive and negative examples have some noise in label, the directly poi
 import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
-from retrievals import AutoModelForEmbedding, RetrievalTrainer, PairCollator, TripletCollator
+from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
 from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss

 model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
 batch_size: int = 128
 epochs: int = 3

 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
-train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'positive'})
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
 optimizer = AdamW(model.parameters(), lr=5e-5)
@@ -113,7 +112,7 @@ If the positive and negative examples have some noise in label, the directly poi
     model=model,
     args=training_arguments,
     train_dataset=train_dataset,
-    data_collator=PairCollator(tokenizer, query_max_length=128, document_max_length=128),
+    data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[64, 128]),
     loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)),
 )
 trainer.optimizer = optimizer
@@ -136,6 +135,7 @@ If the positive and negative examples have some noise in label, the directly poi
 --model_name_or_path $MODEL_NAME \
 --do_train \
 --data_name_or_path $TRAIN_DATA \
+--query_key query \
 --positive_key positive \
 --negative_key negative \
 --learning_rate 3e-5 \
@@ -167,6 +167,7 @@ If the positive and negative examples have some noise in label, the directly poi
 --pooling_method last \
 --do_train \
 --data_name_or_path $TRAIN_DATA \
+--query_key query \
 --positive_key positive \
 --negative_key negative \
 --use_lora True \

docs/source/quick-start.rst (+2 -2)

@@ -52,7 +52,7 @@ To further improve the retrieval performance, we can fine tune the embedding mod
 import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
-from retrievals import AutoModelForEmbedding, RetrievalTrainer, PairCollator, TripletCollator
+from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
 from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss

 model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
@@ -77,7 +77,7 @@ To further improve the retrieval performance, we can fine tune the embedding mod
     model=model,
     args=training_arguments,
     train_dataset=train_dataset,
-    data_collator=PairCollator(tokenizer, query_max_length=128, document_max_length=128),
+    data_collator=RetrievalCollator(tokenizer, keys=['query', 'document'], max_lengths=[128, 128]),
     loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)),
 )
 trainer.optimizer = optimizer

examples/0_embedding/README.md (+1 -0)

@@ -36,6 +36,7 @@ from Cython.Compiler.Options import embed

 Train directly using shell script, refer to the [document](https://open-retrievals.readthedocs.io/en/master/embed.html)

+
 ### Transformer encoder embedding

 Refer to [the fine-tuning code](./train_pairwise.py) to train the model like

examples/0_embedding/train_llm.py (+5 -3)

@@ -22,8 +22,8 @@
     AutoModelForEmbedding,
     AutoModelForRetrieval,
     PairwiseModel,
+    RetrievalCollator,
     RetrievalTrainer,
-    TripletCollator,
 )
 from retrievals.losses import InfoNCE, TripletLoss, TripletRankingLoss

@@ -259,8 +259,10 @@ def main():
         model=model,
         args=training_args,
         train_dataset=train_dataset,
-        data_collator=TripletCollator(
-            tokenizer, query_max_length=data_args.query_max_length, document_max_length=data_args.document_max_length
+        data_collator=RetrievalCollator(
+            tokenizer,
+            keys=['query', 'positive', 'negative'],
+            max_lengths=[data_args.query_max_length, data_args.document_max_length, data_args.document_max_length],
         ),
     )
     trainer.optimizer = optimizer

examples/0_embedding/train_pairwise.py (+2 -2)

@@ -9,7 +9,7 @@
     get_linear_schedule_with_warmup,
 )

-from retrievals import AutoModelForEmbedding, PairCollator, RetrievalTrainer
+from retrievals import AutoModelForEmbedding, RetrievalCollator, RetrievalTrainer
 from retrievals.losses import InfoNCE, SimCSE, TripletLoss

 model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
@@ -41,7 +41,7 @@ def train():
         model=model,
         args=training_args,
         train_dataset=train_dataset,
-        data_collator=PairCollator(tokenizer, query_max_length=128, document_max_length=128),
+        data_collator=RetrievalCollator(tokenizer, keys=['query', 'positive'], max_lengths=[128, 128]),
         loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)),
     )
     trainer.optimizer = optimizer

examples/eval/eval_retrieval2.py (+1 -2)

@@ -10,9 +10,8 @@
     AutoModelForRanking,
     ColBERT,
     LLMRanker,
-    PairCollator,
+    RetrievalCollator,
     RetrievalTrainer,
-    TripletCollator,
 )

 logger = logging.getLogger(__name__)

src/retrievals/__init__.py (+1 -2)

@@ -2,9 +2,8 @@
     ColBertCollator,
     EncodeCollator,
     LLMRerankCollator,
-    PairCollator,
     RerankCollator,
-    TripletCollator,
+    RetrievalCollator,
 )
 from .data.dataset import EncodeDataset, RerankTrainDataset, RetrievalTrainDataset
 from .models.embedding_auto import AutoModelForEmbedding, ListwiseModel, PairwiseModel

src/retrievals/data/__init__.py (+1 -2)

@@ -2,9 +2,8 @@
     ColBertCollator,
     EncodeCollator,
     LLMRerankCollator,
-    PairCollator,
     RerankCollator,
-    TripletCollator,
+    RetrievalCollator,
 )
 from .dataset import EncodeDataset, RerankTrainDataset, RetrievalTrainDataset
 from .sampler import GroupedBatchSampler, GroupSortedBatchSampler
