Commit 4738bb3

feat: build hard negative for retrieval

1 parent 13d30f8 commit 4738bb3

20 files changed: +296, -148 lines changed

README.md (+2, -13)

@@ -35,7 +35,7 @@
 ![structure](./docs/source/_static/structure.png)

-**Open-retrievals** unify text embedding, retrieval, reranking and RAG. It's easy, flexible and scalable.
+**Open-retrievals** unifies text embedding, retrieval, reranking and RAG. It's easy, flexible and scalable for fine-tuning models.
 - Embedding fine-tuned through point-wise, pairwise, listwise, contrastive learning and LLM.
 - Reranking fine-tuned with Cross-Encoder, ColBERT and LLM.
 - Easily build enhanced modular RAG, integrated with Transformers, Langchain and LlamaIndex.

@@ -54,23 +54,12 @@
 ## Installation

-**Prerequisites**
-```shell
-pip install transformers
-pip install faiss-cpu  # if necessary for faiss retrieval
-pip install peft  # if necessary for LoRA training
-```
-
 **With pip**
 ```shell
+pip install transformers
 pip install open-retrievals
 ```

-**With source code**
-```shell
-python -m pip install -U git+https://github.com/LongxingTan/open-retrievals.git
-```
-

 ## Quick-start

README_zh-CN.md (+7, -18)

@@ -34,9 +34,9 @@
 ![structure](./docs/source/_static/structure.png)

-**Open-Retrievals** supports unified calling or fine-tuning of text embedding, retrieval and reranking models, making information retrieval and RAG applications easier
-- Supports the full suite of embedding fine-tuning: contrastive learning, LLM, point-wise, pairwise, listwise
-- Supports the full suite of reranking fine-tuning: cross-encoder, ColBERT, LLM
+**Open-Retrievals** unifies the calling and fine-tuning of text embedding, retrieval and reranking models, making information retrieval and RAG applications easier
+- Supports text embedding fine-tuning: contrastive learning, LLM, point-wise, pairwise, listwise
+- Supports reranking fine-tuning: cross-encoder, ColBERT, LLM
 - Supports customized, modular RAG, with easy use of fine-tuned models in Transformers, Langchain and LlamaIndex

 | Experiment | Model | Original score | Fine-tuned score | Demo code |

@@ -54,23 +54,12 @@
 ## Installation

-**Prerequisites**
-```shell
-pip install transformers
-pip install faiss  # if necessary, for retrieval
-pip install peft  # if necessary, for LoRA training
-```
-
 **Install with pip**
 ```shell
+pip install transformers
 pip install open-retrievals
 ```

-**Install from source**
-```shell
-python -m pip install -U git+https://github.com/LongxingTan/open-retrievals.git
-```
-

 ## Quick start

@@ -306,7 +295,7 @@ trainer.train()
 </details>

-<details><summary> Fine-tune the Cross-encoder reranking model </summary>
+<details><summary> Fine-tune Cross-encoder reranking </summary>

 ```python
 import os

@@ -356,7 +345,7 @@ trainer.train()
 </details>

-<details><summary> Fine-tune the ColBERT reranking model </summary>
+<details><summary> Fine-tune ColBERT reranking </summary>

 ```python
 import os

@@ -422,7 +411,7 @@ trainer.train()
 </details>

-<details><summary> Fine-tune the LLM reranking model </summary>
+<details><summary> Fine-tune LLM reranking </summary>

 ```python
 import os

docs/source/embed.rst (+25, -7)

@@ -76,6 +76,8 @@ Prepare data
 Pair wise
 ~~~~~~~~~~~~~

+If the positive and negative labels are noisy, direct point-wise cross-entropy may not be the best choice. Pairwise training only compares examples relatively, and a hinge loss with a margin can work better.
+
 .. image:: https://colab.research.google.com/assets/colab-badge.svg
    :target: https://colab.research.google.com/drive/17KXe2lnNRID-HiVvMtzQnONiO74oGs91?usp=sharing
    :alt: Open In Colab
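
For reference, a minimal sketch of the margin-based pairwise idea in the note added above; the function name and the margin value are illustrative assumptions, not this library's API:

```python
import torch
import torch.nn.functional as F


def pairwise_hinge_loss(query_emb, pos_emb, neg_emb, margin=0.5):
    """Penalize a triplet whenever the negative scores within `margin` of the positive."""
    pos_score = F.cosine_similarity(query_emb, pos_emb)  # (batch,)
    neg_score = F.cosine_similarity(query_emb, neg_emb)  # (batch,)
    return torch.relu(margin - (pos_score - neg_score)).mean()
```

Because only the relative order of positive and negative matters, mildly noisy labels shift this loss less than an absolute point-wise target would.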
@@ -188,9 +190,9 @@ Pair wise
 Point wise
 ~~~~~~~~~~~~~~~~~~

-If the positive and negative examples have some noise in label, the directly point-wise cross-entropy maybe not the best. The pair wise just compare relatively, or the hinge loss with margin could be better.
+We can also train point-wise, similar to using `tfidf` in information retrieval.

-arcface
+**arcface**

 - layer wise learning rate
 - batch size is important

@@ -202,7 +204,6 @@ List wise
 ~~~~~~~~~~~~~~~~~~


-
 3. Training skills to enhance the performance
 ----------------------------------------------

@@ -225,14 +226,31 @@ tuning the important parameters:

 Hard negative mining
-~~~~~~~~~~~~~~~~~~~~~~~~
-offline hard mining
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+- offline hard mining or online hard mining
+
+If we only have queries and positives, we can generate hard negative samples from them to enhance the retrieval performance.

-online hard mining
+The data format of the `input_file` used to generate hard negatives is `(query, positive)` or `(query, positive, negative)`.
+The `candidate_pool` corpus is a jsonl file of `{text}`.
+
+.. code-block:: shell
+
+    python -m retrievals.pipelines.build_hn \
+        --model_name_or_path BAAI/bge-base-en-v1.5 \
+        --input_file /t2_ranking.jsonl \
+        --output_file /t2_ranking_hn.jsonl \
+        --positive_key positive \
+        --negative_key negative \
+        --range_for_sampling 2-200 \
+        --negative_number 15

 Matryoshka Representation Learning
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+


 Contrastive loss
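
A minimal sketch of the offline mining idea behind `build_hn`, assuming documents are ranked by embedding similarity and negatives are sampled from a mid-rank window (matching `--range_for_sampling 2-200` and `--negative_number 15` above); the helper below is illustrative, not the pipeline's actual code:

```python
import random

import numpy as np


def mine_hard_negatives(query_emb, corpus_emb, positive_ids,
                        sample_range=(2, 200), negative_number=15):
    scores = corpus_emb @ query_emb  # relevance of every corpus doc to the query
    ranked = np.argsort(-scores)  # doc indices, best first
    lo, hi = sample_range
    # skip the very top ranks (often unlabeled positives) and drop known positives
    candidates = [int(i) for i in ranked[lo:hi] if int(i) not in positive_ids]
    return random.sample(candidates, min(negative_number, len(candidates)))
```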

docs/source/index.rst (+6)

@@ -31,6 +31,12 @@ Now you are ready, proceed with
     # install with support of evaluation
     pip install open-retrievals[eval]

+Or install from the source code
+
+.. code-block:: shell
+
+    python -m pip install -U git+https://github.com/LongxingTan/open-retrievals.git
+

 Examples
 ------------------

docs/source/rag.rst (+1, -1)

@@ -39,7 +39,7 @@ Integrated with Langchain
     rerank_model_name_or_path = "BAAI/bge-reranker-base"
     llm_model_name_or_path = "microsoft/Phi-3-mini-128k-instruct"

-    embeddings = LangchainEmbedding(model_name_or_path=embed_model_name_or_path)
+    embeddings = LangchainEmbedding(model_name_or_path=embed_model_name_or_path, model_kwargs={'pooling_method': 'mean'})
     vectordb = Vectorstore(
         persist_directory=persist_directory,
         embedding_function=embeddings,

docs/source/retrieval.rst (+2)

@@ -6,6 +6,8 @@ Retrieval
 1. Pipeline
 ----------------------------

+The retrieval method can solve the **search** or the **extreme multiclass classification** problem.
+
 generate data -> train -> eval

 pretrained encoding -> build hard negative -> train -> eval -> indexing -> retrieval

examples/0_embedding/README.md (+1, -3)

@@ -32,15 +32,13 @@
 ```

-
 Train directly using shell script, refer to the [document](https://open-retrievals.readthedocs.io/en/master/embed.html)

-## Encoder embedding
+## Transformer encoder embedding

 Refer to [the fine-tuning code](./train_pairwise.py) to train the model like

-
 ## LLM embedding

 Refer to [the fine-tuning code](./train_llm.py), to train the model like

examples/0_embedding/train_llm.py (-3)

@@ -147,9 +147,6 @@ def __getitem__(self, item):
         query = self.dataset[item]["query"] + self.tokenizer.eos_token
         pos = self.dataset[item]["pos"][0] + self.tokenizer.eos_token
         neg = self.dataset[item]["neg"][0] + self.tokenizer.eos_token
-        # pos = random.choice(self.dataset[item]["pos"])
-        # neg = random.choice(self.dataset[item]["neg"])
-
         res = {"query": query, "pos": pos, "neg": neg}
         return res

examples/scifact/evaluate.py (+2, -2)

@@ -4,7 +4,7 @@

 from datasets import load_dataset

-from retrievals.metrics import get_mrr, get_ndcg, get_recall
+from retrievals.metrics import get_fbeta, get_mrr, get_ndcg


 def transfer_index_to_id(save_path):

@@ -56,7 +56,7 @@ def transfer_index_to_id(save_path):
         qid2ranking[qid].append(pid)

     results = get_mrr(qid2positives, qid2ranking, cutoff_rank=10)
-    results.update(get_recall(qid2positives, qid2ranking, cutoff_ranks=[10]))
+    results.update(get_fbeta(qid2positives, qid2ranking, cutoff_ranks=[10]))
     results.update(get_ndcg(qid2positives, qid2ranking, cutoff_rank=10))

     print(json.dumps(results, indent=4))

src/retrievals/metrics/__init__.py (+2, -1)

@@ -1,4 +1,5 @@
-from .fbeta import get_recall
+from .fbeta import get_fbeta
+from .hit_rate import get_hit_rate
 from .map import get_map
 from .mrr import get_mrr
 from .ndcg import get_ndcg

src/retrievals/metrics/fbeta.py (+3, -1)

@@ -1,7 +1,9 @@
 from typing import Dict, List


-def get_recall(qid2positive: Dict[str, List[str]], qid2ranking: Dict[str, List[str]], cutoff_ranks: List[int] = [10]):
+def get_fbeta(
+    qid2positive: Dict[str, List[str]], qid2ranking: Dict[str, List[str]], cutoff_ranks: List[int] = [10], beta: int = 2
+):
     qid2recall = {cutoff_rank: {} for cutoff_rank in cutoff_ranks}
     num_samples = len(qid2ranking.keys())
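
For context, `beta` weights recall beta times as heavily as precision in the combined score. A minimal sketch of the standard F-beta formula (illustrative, not this module's internals):

```python
def fbeta_score(precision: float, recall: float, beta: float = 2.0) -> float:
    """Standard F-beta: recall is weighted beta times as much as precision."""
    if precision == 0.0 and recall == 0.0:
        return 0.0
    return (1 + beta**2) * precision * recall / (beta**2 * precision + recall)
```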

src/retrievals/metrics/hit_rate.py (+29, new file)

@@ -0,0 +1,29 @@
+from typing import Dict, List
+
+
+def get_hit_rate(qid2positive: Dict[str, List[str]], qid2ranking: Dict[str, List[str]], cutoff_rank: int = 10):
+    """
+    qid2positive (order doesn't matter): {qid: [pos1_doc_id, pos2_doc_id]}
+    qid2ranking (order does matter): {qid: [rank1_doc_id, rank2_doc_id, rank3_doc_id]}
+    """
+
+    def hit_rate(positives_ids: List[str], ranked_doc_ids: List[str], cutoff: int) -> float:
+        """
+        Calculate hit rate at the specified cutoff
+        """
+        hits = 0
+
+        for doc_id in ranked_doc_ids[:cutoff]:
+            if doc_id in positives_ids:
+                hits += 1
+
+        return hits / cutoff if cutoff > 0 else 0.0
+
+    qid2hr = dict()
+
+    for qid in qid2positive:
+        positives_ids = qid2positive[qid]
+        ranked_doc_ids = qid2ranking[qid]
+        qid2hr[qid] = hit_rate(positives_ids, ranked_doc_ids, cutoff_rank)
+
+    return {f"hit_rate@{cutoff_rank}": sum(qid2hr.values()) / len(qid2hr) if qid2hr else 0.0}

src/retrievals/metrics/map.py (+8, -8)

@@ -3,29 +3,29 @@

 def get_map(qid2positive: Dict[str, List[str]], qid2ranking: Dict[str, List[str]], cutoff_rank: int = 10):
     """
-    qid2positive: {qid: [pos1_doc_id, pos2_doc_id]}
-    qid2ranking: {qid: [rank1_doc_id, rank2_doc_id, rank3_doc_id]}
+    qid2positive (order doesn't matter): {qid: [pos1_doc_id, pos2_doc_id]}
+    qid2ranking (order does matter): {qid: [rank1_doc_id, rank2_doc_id, rank3_doc_id]}
     """

-    def average_precision(positives: List[str], ranked_doc_ids: List[str], cutoff: int) -> float:
+    def average_precision(positives_ids: List[str], ranked_doc_ids: List[str], cutoff: int) -> float:
         """
-        for each cut_off, calculate its precision
+        Average of the precision at each hit within the cutoff
         """
         hits = 0
         sum_precisions = 0.0

         for rank, doc_id in enumerate(ranked_doc_ids[:cutoff], start=1):
-            if doc_id in positives:
+            if doc_id in positives_ids:
                 hits += 1
                 sum_precisions += hits / rank

-        return sum_precisions / len(positives) if positives else 0.0
+        return sum_precisions / min(len(positives_ids), cutoff) if positives_ids else 0.0

     qid2map = dict()

     for qid in qid2positive:
-        positives = qid2positive[qid]
+        positives_ids = qid2positive[qid]
         ranked_doc_ids = qid2ranking[qid]
-        qid2map[qid] = average_precision(positives, ranked_doc_ids, cutoff_rank)
+        qid2map[qid] = average_precision(positives_ids, ranked_doc_ids, cutoff_rank)

     return {f"map@{cutoff_rank}": sum(qid2map.values()) / len(qid2ranking.keys())}
