
Commit 13d30f8

docs: update the fine-tuning examples
1 parent: d922bf8

14 files changed: +381 −217 lines

README.md

+98 −19
@@ -36,9 +36,9 @@
 ![structure](./docs/source/_static/structure.png)
 
 **Open-retrievals** unify text embedding, retrieval, reranking and RAG. It's easy, flexible and scalable.
-- Embedding fine-tuned through point-wise, pairwise, listwise, contrastive learning, and LLM.
-- Reranking fine-tuned with Cross Encoder, ColBERT, and LLM.
-- Easily build enhanced modular RAG, integrated with Transformers, Langchain, and LlamaIndex.
+- Embedding fine-tuned through point-wise, pairwise, listwise, contrastive learning and LLM.
+- Reranking fine-tuned with Cross-Encoder, ColBERT and LLM.
+- Easily build enhanced modular RAG, integrated with Transformers, Langchain and LlamaIndex.
 
 | Experiment | Model | Original | Finetuned | Demo |
 |-------------------------------|------------------------|----------|-----------|-------------------------------------------------------------------------------------|
@@ -48,7 +48,7 @@
 | **rerank** colbert | bge-m3 | 0.657 | **0.695** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QVtqhQ080ZMltXoJyODMmvEQYI6oo5kO?usp=sharing) |
 | **rerank** LLM (LoRA) | bge-reranker-v2-gemma | 0.637 | **0.706** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1fzq1iV7-f8hNKFnjMmpVhVxadqPb9IXk?usp=sharing) |
 
-* The metrics is MAP in 10% eval [t2-reranking data](https://huggingface.co/datasets/C-MTEB/T2Reranking).
+* The eval metrics is MAP in 10% [t2-reranking data](https://huggingface.co/datasets/C-MTEB/T2Reranking).
 * Read [more examples](./examples)
 
 
@@ -76,7 +76,7 @@ python -m pip install -U git+https://github.com/LongxingTan/open-retrievals.git
 
 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-WBMisdWLeHUKlzJ2DrREXY_kSV8vjP3?usp=sharing)
 
-<details><summary> Embeddings from pretrained weights </summary>
+<details><summary> Embedding from pretrained weights </summary>
 
 ```python
 from retrievals import AutoModelForEmbedding
@@ -89,7 +89,7 @@ sentences = [
 ]
 model_name_or_path = 'intfloat/e5-base-v2'
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
-embeddings = model.encode(sentences, normalize_embeddings=True, convert_to_tensor=True)
+embeddings = model.encode(sentences, normalize_embeddings=True)
 scores = (embeddings[:2] @ embeddings[2:].T) * 100
 print(scores.tolist())
 ```
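A quick aside on the change above: with `normalize_embeddings=True` the returned vectors are unit-length, so the plain matrix product on the next line is already a (scaled) cosine similarity, whether `encode` returns tensors or NumPy arrays. A minimal, library-independent check with toy vectors (not real embeddings):

```python
# Toy vectors standing in for embeddings; after L2 normalization the
# matrix product matches the cosine similarity used in the README's score line.
import numpy as np

raw = np.array([[3.0, 4.0], [1.0, 0.0], [0.0, 2.0], [5.0, 12.0]])   # made-up "embeddings"
emb = raw / np.linalg.norm(raw, axis=1, keepdims=True)               # what normalize_embeddings=True does conceptually

scores = (emb[:2] @ emb[2:].T) * 100                                  # same pattern as the snippet above
cosine = np.array([[np.dot(q, d) / (np.linalg.norm(q) * np.linalg.norm(d))
                    for d in raw[2:]] for q in raw[:2]]) * 100
assert np.allclose(scores, cosine)                                    # dot of unit vectors == cosine
print(scores.tolist())
```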
@@ -103,7 +103,7 @@ from retrievals import AutoModelForEmbedding, AutoModelForRetrieval
 sentences = ['A dog is chasing car.', 'A man is playing a guitar.']
 model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
 index_path = './database/faiss/faiss.index'
-model = AutoModelForEmbedding.from_pretrained(model_name_or_path)
+model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method='mean')
 model.build_index(sentences, index_path=index_path)
 
 query_embed = model.encode("He plays guitar.")
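For readers unfamiliar with the index step above, a generic FAISS sketch of what a cosine-style vector index amounts to; this is not the internals of `build_index`, and the dimension and vectors are placeholders:

```python
# Generic FAISS usage (not open-retrievals internals): an inner-product index over
# L2-normalized vectors behaves as a cosine-similarity index and can be saved to disk.
import faiss
import numpy as np

dim = 384                                              # placeholder embedding size
doc_embeds = np.random.rand(2, dim).astype('float32')
doc_embeds /= np.linalg.norm(doc_embeds, axis=1, keepdims=True)

index = faiss.IndexFlatIP(dim)                         # inner product == cosine on unit vectors
index.add(doc_embeds)
faiss.write_index(index, 'faiss.index')

query = np.random.rand(1, dim).astype('float32')
query /= np.linalg.norm(query, axis=1, keepdims=True)
scores, ids = index.search(query, 1)                   # top-1 document per query
print(scores, ids)
```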
@@ -216,7 +216,7 @@ epochs: int = 3
 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
 train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'positive'})
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
-model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="cls")
+model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
 model = model.set_train_type('pairwise')
 
 optimizer = AdamW(model.parameters(), lr=5e-5)
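The hunk above switches the pooling from "cls" to "mean". As a generic illustration (not the library's exact implementation), "cls" keeps the first token's hidden state while "mean" averages the hidden states of non-padding tokens:

```python
# Generic pooling sketch with toy tensors: shapes only, no real model involved.
import torch

hidden_states = torch.randn(2, 5, 8)            # (batch, seq_len, hidden_dim)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

cls_pooled = hidden_states[:, 0]                                     # "cls": first token
mask = attention_mask.unsqueeze(-1).float()
mean_pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)    # "mean": masked average
print(cls_pooled.shape, mean_pooled.shape)                           # both (2, 8)
```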
@@ -252,14 +252,22 @@ import torch.nn as nn
 from datasets import load_dataset
 from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, TrainingArguments
 from retrievals import AutoModelForEmbedding, RetrievalTrainer, PairCollator, TripletCollator
-from retrievals.losses import ArcFaceAdaptiveMarginLoss, InfoNCE, SimCSE, TripletLoss
+from retrievals.losses import InfoNCE, SimCSE, TripletLoss
+
+def add_instructions(example):
+    example['query'] = query_instruction + example['query']
+    example['positive'] = document_instruction + example['positive']
+    return example
 
 model_name_or_path: str = "Qwen/Qwen2-1.5B-Instruct"
 batch_size: int = 8
 epochs: int = 3
+query_instruction = "Retrieve relevant passages that answer the query\nQuery: "
+document_instruction = "Document: "
 
 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
 train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'positive'})
+train_dataset = train_dataset.map(add_instructions)
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="last", use_lora=True)
 model = model.set_train_type('pairwise', loss_fn=InfoNCE(nn.CrossEntropyLoss(label_smoothing=0.05)))
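To see what the newly added `map(add_instructions)` step does, a small self-contained example with a toy two-column dataset (the strings below are made up; the real data is STS-B):

```python
# Toy dataset illustrating the added instruction-prefixing step; each query/positive
# string simply gets the corresponding instruction prepended.
from datasets import Dataset

query_instruction = "Retrieve relevant passages that answer the query\nQuery: "
document_instruction = "Document: "

def add_instructions(example):
    example['query'] = query_instruction + example['query']
    example['positive'] = document_instruction + example['positive']
    return example

toy = Dataset.from_dict({
    'query': ['what is retrieval augmented generation?'],
    'positive': ['RAG augments a generator with retrieved passages.'],
})
toy = toy.map(add_instructions)
print(toy[0]['query'])      # instruction-prefixed query
print(toy[0]['positive'])   # "Document: " + original passage
```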
@@ -272,6 +280,7 @@ training_arguments = TrainingArguments(
     num_train_epochs=epochs,
     per_device_train_batch_size=batch_size,
     remove_unused_columns=False,
+    logging_steps=100,
 )
 trainer = RetrievalTrainer(
     model=model,
@@ -291,25 +300,32 @@ trainer.train()
 from transformers import AutoTokenizer, TrainingArguments, get_cosine_schedule_with_warmup, AdamW
 from retrievals import RerankCollator, AutoModelForRanking, RerankTrainer, RerankTrainDataset
 
-model_name_or_path: str = "microsoft/deberta-v3-base"
+model_name_or_path: str = "BAAI/bge-reranker-base"
 max_length: int = 128
 learning_rate: float = 3e-5
 batch_size: int = 4
 epochs: int = 3
+output_dir: str = "./checkpoints"
 
-train_dataset = RerankTrainDataset('./t2rank.json', positive_key='pos', negative_key='neg')
+train_dataset = RerankTrainDataset("C-MTEB/T2Reranking", positive_key="positive", negative_key="negative", dataset_split='dev')
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForRanking.from_pretrained(model_name_or_path)
 optimizer = AdamW(model.parameters(), lr=learning_rate)
 num_train_steps = int(len(train_dataset) / batch_size * epochs)
-scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
+scheduler = get_cosine_schedule_with_warmup(
+    optimizer,
+    num_warmup_steps=0.05 * num_train_steps,
+    num_training_steps=num_train_steps,
+)
 
 training_args = TrainingArguments(
     learning_rate=learning_rate,
     per_device_train_batch_size=batch_size,
     num_train_epochs=epochs,
-    output_dir='./checkpoints',
+    output_dir=output_dir,
     remove_unused_columns=False,
+    logging_steps=100,
+    report_to="none",
 )
 trainer = RerankTrainer(
     model=model,
@@ -348,9 +364,7 @@ epochs: int = 3
 colbert_dim: int = 1024
 output_dir: str = './checkpoints'
 
-train_dataset = RetrievalTrainDataset(
-    'C-MTEB/T2Reranking', positive_key='positive', negative_key='negative', dataset_split='dev'
-)
+train_dataset = RetrievalTrainDataset('C-MTEB/T2Reranking', positive_key='positive', negative_key='negative', dataset_split='dev')
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 data_collator = ColBertCollator(
     tokenizer,
@@ -367,9 +381,7 @@ model = ColBERT.from_pretrained(
 
 optimizer = AdamW(model.parameters(), lr=learning_rate)
 num_train_steps = int(len(train_dataset) / batch_size * epochs)
-scheduler = get_cosine_schedule_with_warmup(
-    optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps
-)
+scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
 
 training_args = TrainingArguments(
     learning_rate=learning_rate,
@@ -394,7 +406,74 @@ trainer.train()
 <details><summary> Fine-tune LLM reranking </summary>
 
 ```python
+from transformers import (
+    AdamW,
+    AutoTokenizer,
+    TrainingArguments,
+    get_cosine_schedule_with_warmup,
+)
 
+from retrievals import (
+    LLMRanker,
+    LLMRerankCollator,
+    RerankTrainer,
+    RetrievalTrainDataset,
+)
+from retrievals.losses import TokenLoss
+
+model_name_or_path: str = "Qwen/Qwen2-1.5B-Instruct"
+max_length: int = 512
+learning_rate: float = 3e-5
+batch_size: int = 8
+epochs: int = 3
+task_prompt: str = (
+    """Given a query A and a passage B, determine whether the passage contains an answer to the query"""
+    """by providing a prediction of either 'Yes' or 'No'."""
+)
+
+tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
+train_dataset = RetrievalTrainDataset(
+    data_name_or_path='C-MTEB/T2Reranking',
+    positive_key='positive',
+    negative_key='negative',
+    query_instruction='A: ',
+    document_instruction='B: ',
+    dataset_split='dev',
+)
+data_collator = LLMRerankCollator(tokenizer=tokenizer, max_length=max_length, prompt=task_prompt, add_target_token='Yes')
+token_index = tokenizer('Yes', add_special_tokens=False)['input_ids'][-1]
+model = LLMRanker.from_pretrained(
+    model_name_or_path,
+    causal_lm=True,
+    use_fp16=True,
+    loss_fn=TokenLoss(token_index=token_index),
+    use_lora=True,
+)
+
+optimizer = AdamW(model.parameters(), lr=learning_rate)
+num_train_steps = int(len(train_dataset) / batch_size * epochs)
+scheduler = get_cosine_schedule_with_warmup(
+    optimizer,
+    num_warmup_steps=0.05 * num_train_steps,
+    num_training_steps=num_train_steps,
+)
+
+training_args = TrainingArguments(
+    learning_rate=learning_rate,
+    per_device_train_batch_size=batch_size,
+    num_train_epochs=epochs,
+    output_dir="./checkpoints",
+    remove_unused_columns=False,
+)
+trainer = RerankTrainer(
+    model=model,
+    args=training_args,
+    train_dataset=train_dataset,
+    data_collator=data_collator,
+)
+trainer.optimizer = optimizer
+trainer.scheduler = scheduler
+trainer.train()
 ```
 </details>
 
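For context on the new LLM-reranking example: the collator appends a 'Yes' target and `TokenLoss` supervises that single token, so in this style of reranking the relevance score is conventionally read from the model's probability of 'Yes' at the last position. A conceptual sketch with made-up logits (the general technique, not open-retrievals internals):

```python
# Toy logits standing in for a causal LM's output at the final position of
# "task_prompt + A: query + B: document"; relevance == probability of the 'Yes' token.
import torch

token_index = 3                                                # stands in for tokenizer('Yes', ...)['input_ids'][-1]
last_position_logits = torch.tensor([0.2, -1.0, 0.5, 2.3, 0.0, -0.7, 1.1, 0.4])

probs = torch.softmax(last_position_logits, dim=-1)
relevance_score = probs[token_index].item()                    # higher => the model leans towards 'Yes'
print(round(relevance_score, 4))
```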
README_ja-JP.md

+18 −8
@@ -79,7 +79,7 @@ sentences = [
 ]
 model_name_or_path = 'intfloat/e5-base-v2'
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
-embeddings = model.encode(sentences, normalize_embeddings=True, convert_to_tensor=True)
+embeddings = model.encode(sentences, normalize_embeddings=True)
 scores = (embeddings[:2] @ embeddings[2:].T) * 100
 print(scores.tolist())
 ```
@@ -91,7 +91,7 @@ from retrievals import AutoModelForEmbedding, AutoModelForRetrieval
 sentences = ['A dog is chasing car.', 'A man is playing a guitar.']
 model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
 index_path = './database/faiss/faiss.index'
-model = AutoModelForEmbedding.from_pretrained(model_name_or_path)
+model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method='mean')
 model.build_index(sentences, index_path=index_path)
 
 query_embed = model.encode("He plays guitar.")
@@ -199,8 +199,9 @@ epochs: int = 3
 train_dataset = load_dataset('shibing624/nli_zh', 'STS-B')['train']
 train_dataset = train_dataset.rename_columns({'sentence1': 'query', 'sentence2': 'document'})
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
-model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="cls")
-# model = model.set_train_type('pointwise') # 'pointwise', 'pairwise', 'listwise'
+model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
+model = model.set_train_type('pairwise')
+
 optimizer = AdamW(model.parameters(), lr=5e-5)
 num_train_steps = int(len(train_dataset) / batch_size * epochs)
 scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
@@ -240,25 +241,34 @@ model = AutoModelForEmbedding.from_pretrained(
 from transformers import AutoTokenizer, TrainingArguments, get_cosine_schedule_with_warmup, AdamW
 from retrievals import RerankCollator, AutoModelForRanking, RerankTrainer, RerankTrainDataset
 
-model_name_or_path: str = "microsoft/deberta-v3-base"
+model_name_or_path: str = "BAAI/bge-reranker-base"
 max_length: int = 128
 learning_rate: float = 3e-5
 batch_size: int = 4
 epochs: int = 3
+output_dir: str = "./checkpoints"
 
-train_dataset = RerankTrainDataset('./t2rank.json', positive_key='pos', negative_key='neg')
+train_dataset = RerankTrainDataset(
+    "C-MTEB/T2Reranking", positive_key="positive", negative_key="negative", dataset_split='dev'
+)
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
 model = AutoModelForRanking.from_pretrained(model_name_or_path)
 optimizer = AdamW(model.parameters(), lr=learning_rate)
 num_train_steps = int(len(train_dataset) / batch_size * epochs)
-scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
+scheduler = get_cosine_schedule_with_warmup(
+    optimizer,
+    num_warmup_steps=0.05 * num_train_steps,
+    num_training_steps=num_train_steps,
+)
 
 training_args = TrainingArguments(
     learning_rate=learning_rate,
     per_device_train_batch_size=batch_size,
     num_train_epochs=epochs,
-    output_dir='./checkpoints',
+    output_dir=output_dir,
     remove_unused_columns=False,
+    logging_steps=100,
+    report_to="none",
 )
 trainer = RerankTrainer(
     model=model,