Commit ed55f72

refactor(embed): construct the embed model
1 parent 1b72e5e commit ed55f72

22 files changed (+212 -311 lines)
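
Across the READMEs, docs, and examples touched below, the commit replaces the in-place `model.set_train_type('pairwise', ...)` call with an explicitly constructed `PairwiseModel` wrapper around the embedding model; the optimizer, scheduler, and trainer then operate on that wrapper. The following is a minimal sketch of the updated pattern, assembled from the README and docs hunks in this commit (model name, loss, and hyperparameters are the illustrative values used there, not requirements):

```python
import torch.nn as nn
from datasets import load_dataset
from transformers import AdamW, AutoTokenizer, TrainingArguments, get_linear_schedule_with_warmup

from retrievals import AutoModelForEmbedding, PairwiseModel, RetrievalCollator, RetrievalTrainer
from retrievals.losses import InfoNCE

model_name_or_path = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
batch_size, epochs = 32, 3

train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)

# Before: model = model.set_train_type('pairwise', ...)
# After: construct the embedding model, then wrap it for pairwise training.
model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
train_model = PairwiseModel(model, loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)))

# Optimizer, scheduler, and trainer all receive the wrapper, not the raw embedding model.
optimizer = AdamW(train_model.parameters(), lr=5e-5)
num_train_steps = int(len(train_dataset) / batch_size * epochs)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps
)

training_arguments = TrainingArguments(
    output_dir='./checkpoints',
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    remove_unused_columns=False,
    logging_steps=100,
)
trainer = RetrievalTrainer(
    model=train_model,
    args=training_arguments,
    train_dataset=train_dataset,
    data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[32, 128]),
)
trainer.optimizer = optimizer
trainer.scheduler = scheduler
trainer.train()
```

Constructing the wrapper explicitly keeps the plain `AutoModelForEmbedding` usable for encoding on its own, while the pairwise training logic and loss live in `PairwiseModel`.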

README.md

+14 -14

@@ -195,7 +195,7 @@ print(response)
 import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
-from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
+from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator, PairwiseModel
 from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss

 model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
@@ -205,9 +205,9 @@ epochs: int = 3
 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
-model = model.set_train_type('pairwise')
+train_model = PairwiseModel(model)

-optimizer = AdamW(model.parameters(), lr=5e-5)
+optimizer = AdamW(train_model.parameters(), lr=5e-5)
 num_train_steps = int(len(train_dataset) / batch_size * epochs)
 scheduler = get_linear_schedule_with_warmup(
     optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps
@@ -221,7 +221,7 @@ training_arguments = TrainingArguments(
     logging_steps=100,
 )
 trainer = RetrievalTrainer(
-    model=model,
+    model=train_model,
     args=training_arguments,
     train_dataset=train_dataset,
     data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[32, 128]),
@@ -239,26 +239,26 @@ trainer.train()
 import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
-from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
+from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator, PairwiseModel
 from retrievals.losses import InfoNCE, SimCSE, TripletLoss

 def add_instructions(example):
-    example['query'] = query_instruction + example['query']
-    example['positive'] = document_instruction + example['positive']
+    example['query'] = query_instruction.format(example['query'])
+    example['positive'] = document_instruction.format(example['positive'])
     return example

 model_name_or_path: str = "Qwen/Qwen2-1.5B-Instruct"
 batch_size: int = 8
 epochs: int = 3
-query_instruction = "Retrieve relevant passages that answer the query\nQuery: "
-document_instruction = "Document: "
+query_instruction = "Retrieve relevant passages that answer the query\nQuery: {}"
+document_instruction = "Document: {}"

 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
 train_dataset = train_dataset.map(add_instructions)
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="last", use_lora=True)
-model = model.set_train_type('pairwise', loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)))
-optimizer = AdamW(model.parameters(), lr=5e-5)
+train_model = PairwiseModel(model, loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)))
+optimizer = AdamW(train_model.parameters(), lr=5e-5)
 num_train_steps = int(len(train_dataset) / batch_size * epochs)
 scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)

@@ -270,7 +270,7 @@ training_arguments = TrainingArguments(
     logging_steps=100,
 )
 trainer = RetrievalTrainer(
-    model=model,
+    model=train_model,
     args=training_arguments,
     train_dataset=train_dataset,
     data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[32, 128]),
@@ -423,8 +423,8 @@ train_dataset = RetrievalTrainDataset(
     data_name_or_path='C-MTEB/T2Reranking',
     positive_key='positive',
     negative_key='negative',
-    query_instruction='A: ',
-    document_instruction='B: ',
+    query_instruction='A: {}',
+    document_instruction='B: {}',
     dataset_split='dev',
 )
 data_collator = LLMRerankCollator(tokenizer=tokenizer, max_length=max_length, prompt=task_prompt, add_target_token='Yes')
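
The same README hunks also change `query_instruction` and `document_instruction` from plain prefixes to `str.format` templates with a `{}` placeholder, and `RetrievalTrainDataset` plus the CLI flags in the later files follow the same convention. A condensed sketch of the updated mapping, copied from the hunk above:

```python
query_instruction = "Retrieve relevant passages that answer the query\nQuery: {}"
document_instruction = "Document: {}"

def add_instructions(example):
    # The raw text is substituted into the template, so an instruction can wrap the
    # text (prefix and/or suffix) instead of only being concatenated in front of it.
    example['query'] = query_instruction.format(example['query'])
    example['positive'] = document_instruction.format(example['positive'])
    return example

# Applied row by row, as in the README example:
# train_dataset = train_dataset.map(add_instructions)
```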

README_ja-JP.md

+5 -5

@@ -190,7 +190,7 @@ print(response)
 import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
-from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
+from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator, PairwiseModel
 from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss

 model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
@@ -200,9 +200,9 @@ epochs: int = 3
 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
-model = model.set_train_type('pairwise')
+train_model = PairwiseModel(model)

-optimizer = AdamW(model.parameters(), lr=5e-5)
+optimizer = AdamW(train_model.parameters(), lr=5e-5)
 num_train_steps = int(len(train_dataset) / batch_size * epochs)
 scheduler = get_linear_schedule_with_warmup(
     optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps
@@ -216,7 +216,7 @@ training_arguments = TrainingArguments(
     logging_steps=100,
 )
 trainer = RetrievalTrainer(
-    model=model,
+    model=train_model,
     args=training_arguments,
     train_dataset=train_dataset,
     data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[32, 128]),
@@ -234,7 +234,7 @@ from retrievals import AutoModelForEmbedding
 model = AutoModelForEmbedding.from_pretrained(
     "mistralai/Mistral-7B-v0.1",
     pooling_method='last',
-    query_instruction=f'Instruct: Retrieve semantically similar text\nQuery: '
+    query_instruction=f'Instruct: Retrieve semantically similar text\nQuery: {}'
 )
 ```


README_zh-CN.md

+14 -14

@@ -201,7 +201,7 @@ print(response)
 import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
-from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
+from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator, PairwiseModel
 from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss

 model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
@@ -211,9 +211,9 @@ epochs: int = 3
 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
-model = model.set_train_type('pairwise')
+train_model = PairwiseModel(model)

-optimizer = AdamW(model.parameters(), lr=5e-5)
+optimizer = AdamW(train_model.parameters(), lr=5e-5)
 num_train_steps = int(len(train_dataset) / batch_size * epochs)
 scheduler = get_linear_schedule_with_warmup(
     optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps
@@ -227,7 +227,7 @@ training_arguments = TrainingArguments(
     logging_steps=100,
 )
 trainer = RetrievalTrainer(
-    model=model,
+    model=train_model,
     args=training_arguments,
     train_dataset=train_dataset,
     data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[32, 128]),
@@ -246,28 +246,28 @@ import os
 import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
-from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
+from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator, PairwiseModel
 from retrievals.losses import InfoNCE, SimCSE, TripletLoss
 os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

 def add_instructions(example):
-    example['query'] = query_instruction + example['query']
-    example['positive'] = document_instruction + example['positive']
+    example['query'] = query_instruction.format(example['query'])
+    example['positive'] = document_instruction.format(example['positive'])
     return example

 model_name_or_path: str = "Qwen/Qwen2-1.5B-Instruct"
 batch_size: int = 8
 epochs: int = 3
-query_instruction = "Retrieve relevant passages that answer the query\nQuery: "
-document_instruction = "Document: "
+query_instruction = "Retrieve relevant passages that answer the query\nQuery: {}"
+document_instruction = "Document: {}"

 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
 train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'positive'})
 train_dataset = train_dataset.map(add_instructions)
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="last", use_lora=True)
-model = model.set_train_type('pairwise', loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)))
-optimizer = AdamW(model.parameters(), lr=5e-5)
+train_model = PairwiseModel(model, loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)))
+optimizer = AdamW(train_model.parameters(), lr=5e-5)
 num_train_steps = int(len(train_dataset) / batch_size * epochs)
 scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)

@@ -279,7 +279,7 @@ training_arguments = TrainingArguments(
     logging_steps=100,
 )
 trainer = RetrievalTrainer(
-    model=model,
+    model=train_model,
     args=training_arguments,
     train_dataset=train_dataset,
     data_collator=RetrievalCollator(tokenizer, keys=['query', 'positive'], max_lengths=[64, 128]),
@@ -439,8 +439,8 @@ train_dataset = RetrievalTrainDataset(
     data_name_or_path='C-MTEB/T2Reranking',
     positive_key='positive',
     negative_key='negative',
-    query_instruction='A: ',
-    document_instruction='B: ',
+    query_instruction='A: {}',
+    document_instruction='B: {}',
     dataset_split='dev',
 )
 data_collator = LLMRerankCollator(

docs/source/embed.rst

+10 -7

@@ -88,31 +88,34 @@ If the positive and negative examples have some noise in label, the directly poi
 import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
-from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator
+from retrievals import AutoModelForEmbedding, RetrievalTrainer, RetrievalCollator, PairwiseModel
 from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss

 model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
-batch_size: int = 128
+batch_size: int = 32
 epochs: int = 3

 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
+train_model = PairwiseModel(model)
+
 optimizer = AdamW(model.parameters(), lr=5e-5)
-num_train_steps=int(len(train_dataset) / batch_size * epochs)
+num_train_steps = int(len(train_dataset) / batch_size * epochs)
 scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)

 training_arguments = TrainingArguments(
     output_dir='./checkpoints',
     num_train_epochs=epochs,
     per_device_train_batch_size=batch_size,
     remove_unused_columns=False,
+    logging_steps=100,
 )
 trainer = RetrievalTrainer(
-    model=model,
+    model=train_model,
     args=training_arguments,
     train_dataset=train_dataset,
-    data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[64, 128]),,
+    data_collator=RetrievalCollator(tokenizer, keys=['sentence1', 'sentence2'], max_lengths=[64, 128]),
     loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05), use_inbatch_negative=True, negatives_cross_device=False),
 )
 trainer.optimizer = optimizer
@@ -172,8 +175,8 @@ If the positive and negative examples have some noise in label, the directly poi
 --positive_key positive \
 --negative_key negative \
 --use_lora True \
---query_instruction "Retrieve the possible answer for query.\nQuery: " \
---document_instruction 'Document: ' \
+--query_instruction "Retrieve the possible answer for query.\nQuery: {}" \
+--document_instruction 'Document: {}' \
 --learning_rate 2e-4 \
 --bf16 \
 --num_train_epochs 3 \

examples/0_embedding/train_llm.py

+5 -4

@@ -244,15 +244,16 @@ def main():
         pooling_method=training_args.pooling_method,
         lora_config=lora_config,
     )
-    model = model.set_train_type(
-        "pairwise",
+
+    train_model = PairwiseModel(
+        model,
         loss_fn=TripletLoss(
             use_inbatch_negative=training_args.use_inbatch_neg,
             negatives_cross_device=training_args.negatives_cross_device,
         ),
     )

-    optimizer = get_optimizer(model, lr=5e-5, weight_decay=1e-3)
+    optimizer = get_optimizer(train_model, lr=5e-5, weight_decay=1e-3)

     lr_scheduler = get_scheduler(
         optimizer,
@@ -262,7 +263,7 @@ def main():
     )

     trainer = RetrievalTrainer(
-        model=model,
+        model=train_model,
         args=training_args,
         train_dataset=train_dataset,
         data_collator=RetrievalCollator(

examples/0_embedding/train_pairwise.py

+10 -5

@@ -9,7 +9,12 @@
     get_linear_schedule_with_warmup,
 )

-from retrievals import AutoModelForEmbedding, RetrievalCollator, RetrievalTrainer
+from retrievals import (
+    AutoModelForEmbedding,
+    PairwiseModel,
+    RetrievalCollator,
+    RetrievalTrainer,
+)
 from retrievals.losses import InfoNCE, SimCSE, TripletLoss

 model_name_or_path: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
@@ -23,9 +28,9 @@ def train():
     train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'positive'})
     tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
     model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
-    model = model.set_train_type('pairwise')
+    train_model = PairwiseModel(model)

-    optimizer = AdamW(model.parameters(), lr=5e-5)
+    optimizer = AdamW(train_model.parameters(), lr=5e-5)
     num_train_steps = int(len(train_dataset) / batch_size * epochs)
     scheduler = get_linear_schedule_with_warmup(
         optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps
@@ -38,7 +43,7 @@ def train():
         remove_unused_columns=False,
     )
     trainer = RetrievalTrainer(
-        model=model,
+        model=train_model,
         args=training_args,
         train_dataset=train_dataset,
         data_collator=RetrievalCollator(tokenizer, keys=['query', 'positive'], max_lengths=[128, 128]),
@@ -48,7 +53,7 @@ def train():
     trainer.scheduler = scheduler
     trainer.train()

-    model.save_pretrained(training_args.output_dir)
+    train_model.save_pretrained(training_args.output_dir)
     if trainer.is_world_process_zero():
         tokenizer.save_pretrained(training_args.output_dir)


examples/README_zh_CN.md

+1 -1

@@ -30,7 +30,7 @@ export HF_ENDPOINT=https://hf-mirror.com
 - [向量模型pairwise微调](./0_embedding/train_pairwise.py)
 - [decoder大模型向量模型pairwise微调](./0_embedding/train_llm.py)
   - 设置 `query_instruction`
-    - "给定一个查询和一个相关文档,检索与查询相关的文档\n查询: "
+    - "给定一个查询和一个相关文档,检索与查询相关的文档\n查询: {}"
   - 使用适当的 `pooling_method`
     - `last`
   - 由于模型尺寸较大,可能需要减少批次大小

src/retrievals/data/collator.py

-1

@@ -223,7 +223,6 @@ def __init__(
         add_target_token: str = '',
         sep_token: str = "\n",
         max_length: int = 128,
-        tokenize_args: Optional[Dict] = None,
         pad_to_multiple_of: Optional[int] = 8,
     ):
         self.tokenizer = tokenizer
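
This hunk removes the unused `tokenize_args` parameter from `LLMRerankCollator.__init__`, so callers simply stop passing it. A small construction sketch matching the reranker usage in the README hunks above; the import path mirrors those examples, and `max_length` and the task prompt are placeholder values:

```python
from transformers import AutoTokenizer

from retrievals import LLMRerankCollator  # import path as used in the README examples

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct", use_fast=False)
task_prompt = "..."  # placeholder: the real prompt is defined in the reranker examples

# `tokenize_args` is no longer accepted; the remaining keyword arguments are unchanged.
data_collator = LLMRerankCollator(
    tokenizer=tokenizer,
    max_length=512,  # placeholder value
    prompt=task_prompt,
    add_target_token='Yes',
)
```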
