Skip to content

Commit 52360d7

Browse files
authored
feat: add the rag example @LongxingTan (#15)
1 parent a67c731 commit 52360d7

File tree

7 files changed

+130
-12
lines changed

7 files changed

+130
-12
lines changed

examples/rag_langchain.py

+43-6
Original file line number · Diff line number · Diff line change
@@ -1,14 +1,26 @@
1+
import torch
2+
import transformers
3+
from langchain.chains import RetrievalQA
4+
from langchain.llms import HuggingFacePipeline
15
from langchain.retrievers import ContextualCompressionRetriever
26
from langchain.text_splitter import RecursiveCharacterTextSplitter
37
from langchain_community.document_loaders import PyPDFLoader
48
from langchain_community.vectorstores import FAISS
59
from langchain_community.vectorstores.utils import DistanceStrategy
10+
from transformers import AutoModelForCausalLM, AutoTokenizer
11+
from transformers.generation import GenerationConfig
612

7-
from retrievals import AutoModelForEmbedding, RerankModel
8-
from retrievals.tools import LangchainReranker, RagFeature
13+
from retrievals.tools.langchain import LangchainEmbedding, LangchainReranker, RagFeature
914

10-
embed_model = AutoModelForEmbedding(model_name_or_path='')
11-
rerank_model = LangchainReranker(model_name_or_path='', top_n=5, device='cuda')
15+
16+
class CFG:
    """Configuration constants for the RAG example pipeline."""

    # Embedding model used to build the retrieval index.
    retrieval_model = 'BAAI/bge-large-zh'
    # Reranker checkpoint path; empty here — must be filled in before use.
    rerank_model = ''
    # Chat LLM used for answer generation.
    llm_model = 'Qwen/Qwen-7B-Chat'
20+
21+
22+
embed_model = LangchainEmbedding(model_name_or_path=CFG.retrieval_model)
23+
rerank_model = LangchainReranker(model_name_or_path=CFG.rerank_model, top_n=5, device='cuda')
1224

1325

1426
documents = PyPDFLoader("llama.pdf").load()
@@ -19,5 +31,30 @@
1931
search_type="similarity", search_kwargs={"score_threshold": 0.3, "k": 10}
2032
)
2133

22-
compression_retriever = ContextualCompressionRetriever(base_compressor=rerank_model, base_retriever=retriever)
23-
response = compression_retriever.get_relevant_documents("What is Llama 2?")
34+
# compression_retriever = ContextualCompressionRetriever(base_compressor=rerank_model, base_retriever=retriever)
35+
# response = compression_retriever.get_relevant_documents("What is Llama 2?")
36+
37+
38+
tokenizer = AutoTokenizer.from_pretrained(CFG.llm_model, trust_remote_code=True)
39+
max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
40+
n_gpus = torch.cuda.device_count()
41+
max_memory = {i: max_memory for i in range(n_gpus)}
42+
model = AutoModelForCausalLM.from_pretrained(
43+
CFG.llm_model, device_map='auto', load_in_4bit=True, max_memory=max_memory, trust_remote_code=True, fp16=True
44+
)
45+
model = model.eval()
46+
model.generation_config = GenerationConfig.from_pretrained(CFG.llm_model, trust_remote_code=True)
47+
48+
query_pipeline = transformers.pipeline(
49+
"text-generation",
50+
model=model,
51+
tokenizer=tokenizer,
52+
torch_dtype=torch.float16,
53+
device_map="auto",
54+
)
55+
56+
llm = HuggingFacePipeline(pipeline=query_pipeline)
57+
58+
qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, verbose=True)
59+
60+
qa.run('你看了这篇文章后有何感性?')

src/retrievals/models/embedding_auto.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
2+
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
33

44
import faiss
55
import numpy as np
@@ -60,6 +60,9 @@ class AutoModelForEmbedding(nn.Module):
6060
from the Hugging Face Hub with that name.
6161
"""
6262

63+
encode_kwargs: Dict[str, Any] = dict()
64+
show_progress: bool = False
65+
6366
def __init__(
6467
self,
6568
model_name_or_path: str,
@@ -184,7 +187,17 @@ def forward_from_loader(self, inputs):
184187
return embeddings
185188

186189
def forward_from_text(self, texts):
    """Tokenize raw text, append EOS, pad into a batch, and embed it.

    Args:
        texts: a string or list of strings to embed.

    Returns:
        Whatever ``forward_from_loader`` produces for the padded batch
        (the model's embeddings).
    """
    # Tokenize without padding first so EOS can be appended per sequence.
    batch_dict = self.tokenizer(
        texts,
        max_length=self.max_length,
        return_attention_mask=False,
        padding=False,
        truncation=True,
    )
    # Append EOS to every sequence before padding.
    # NOTE(review): assumes the tokenizer defines eos_token_id — confirm.
    batch_dict["input_ids"] = [input_ids + [self.tokenizer.eos_token_id] for input_ids in batch_dict["input_ids"]]
    batch_dict = self.tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors="pt")
    # Not every tokenizer emits token_type_ids; pop with a default instead of
    # raising KeyError (the original unconditional pop crashes for such models).
    batch_dict.pop("token_type_ids", None)
    return self.forward_from_loader(batch_dict)
188201

189202
def encode(
190203
self,
@@ -197,7 +210,7 @@ def encode(
197210
device: str = None,
198211
normalize_embeddings: bool = False,
199212
):
200-
if isinstance(inputs, DataLoader):
213+
if isinstance(inputs, (BatchEncoding, Dict)):
201214
return self.encode_from_loader(
202215
loader=inputs,
203216
batch_size=batch_size,
@@ -208,7 +221,7 @@ def encode(
208221
device=device,
209222
normalize_embeddings=normalize_embeddings,
210223
)
211-
elif isinstance(inputs, (str, Iterable)):
224+
elif isinstance(inputs, (str, List, Tuple)):
212225
return self.encode_from_text(
213226
sentences=inputs,
214227
batch_size=batch_size,
@@ -219,6 +232,17 @@ def encode(
219232
device=device,
220233
normalize_embeddings=normalize_embeddings,
221234
)
235+
else:
236+
raise ValueError
237+
238+
def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Embed a batch of documents with the underlying HuggingFace model.

    Args:
        texts: the document strings to embed.

    Returns:
        One embedding vector (list of floats) per input document.
    """
    vectors = self.encode(texts, show_progress_bar=self.show_progress, **self.encode_kwargs)
    return vectors.tolist()
242+
243+
def embed_query(self, text: str) -> List[float]:
    """Embed a single query string.

    Delegates to :meth:`embed_documents` with a one-element batch and
    returns the single resulting vector.
    """
    (vector,) = self.embed_documents([text])
    return vector
222246

223247
def encode_from_loader(
224248
self,

src/retrievals/models/rag.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from pathlib import Path
2+
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union
3+
4+
from transformers import AutoModel
5+
6+
7+
class RAG(object):
    """Skeleton of a retrieval-augmented-generation pipeline.

    Only the construction entry points are sketched; indexing/search
    methods are placeholders that currently do nothing.
    """

    def __init__(self):
        # No eager state: instances are normally created through the
        # from_pretrained / from_index alternate constructors.
        pass

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, Path],
        n_gpu: int = -1,
        verbose: int = 1,
        index_root: Optional[str] = None,
    ):
        """Build a RAG instance backed by a pretrained transformer.

        Args:
            model_name_or_path: HuggingFace hub name or local checkpoint path.
            n_gpu: number of GPUs to use (-1 = auto); currently unused.
            verbose: verbosity level; currently unused.
            index_root: optional root directory for indexes; currently unused.
        """
        instance = cls()
        # Fix: transformers' AutoModel cannot be instantiated directly
        # (AutoModel() raises); the supported path is from_pretrained.
        instance.model = AutoModel.from_pretrained(model_name_or_path)
        return instance

    @classmethod
    def from_index(cls, index_path: Union[str, Path], n_gpu: int = -1, verbose: int = 1):
        """Build a RAG instance from an existing on-disk index.

        Args:
            index_path: path of the previously built index.
            n_gpu: number of GPUs to use (-1 = auto); currently unused.
            verbose: verbosity level; currently unused.
        """
        instance = cls()
        index_path = Path(index_path)
        # NOTE(review): AutoModel() raises in transformers — the model name
        # should be recovered from the index metadata and loaded with
        # AutoModel.from_pretrained; left as-is pending that metadata format.
        instance.model = AutoModel()

        return instance

    def add_to_index(self):
        # Placeholder: incremental index update not implemented yet.
        return

    def encode(self):
        # Placeholder: document/query encoding not implemented yet.
        return

    def index(self):
        # Placeholder: index construction not implemented yet.
        return

    def search(self):
        # Placeholder: retrieval not implemented yet.
        return
42+
43+
44+
class Generator(object):
    """Placeholder for the answer-generation component of the RAG pipeline."""

    def __init__(self):
        # Intentionally empty: generation logic not implemented yet.
        pass

src/retrievals/tools/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
from src.retrievals.tools.langchain import LangchainReranker, RagFeature
2-
from src.retrievals.tools.llama_index import LlamaIndexReranker
+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
class CorpusProcessor(object):
    """Placeholder for corpus preprocessing (splitting/cleaning) utilities."""

    def __init__(self):
        # Intentionally empty: processing logic not implemented yet.
        pass

src/retrievals/tools/langchain.py

+7
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,18 @@
99
MarkdownTextSplitter,
1010
)
1111
from langchain_core.documents import Document
12+
from langchain_core.embeddings import Embeddings
1213
from langchain_core.pydantic_v1 import Extra, root_validator
1314

15+
from src.retrievals.models.embedding_auto import AutoModelForEmbedding
1416
from src.retrievals.models.rerank import RerankModel
1517

1618

19+
class LangchainEmbedding(AutoModelForEmbedding, Embeddings):
    """Adapter exposing AutoModelForEmbedding through LangChain's Embeddings interface."""

    def __init__(self, **kwargs):
        # Forward every construction argument unchanged to AutoModelForEmbedding.
        super().__init__(**kwargs)
22+
23+
1724
class LangchainReranker(BaseDocumentCompressor):
1825
class Config:
1926
"""Configuration for this pydantic object."""

tests/test_models/test_embedding_auto.py

+3
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ def test_encode_from_text(self):
149149
assert emb.shape == (3, 384)
150150
# assert abs(np.sum(emb) - 32.14627) < 0.001
151151

152+
def test_forward_from_text(self):
153+
pass
154+
152155

153156
class PairwiseModelTest(TestCase, ModelTesterMixin):
154157
def setUp(self) -> None:

0 commit comments

Comments (0)