Skip to content

Commit

Permalink
Add LiteLLM as a representation model (MaartenGr#2213)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaartenGr authored Dec 10, 2024
1 parent 50d9a49 commit 84dbf36
Show file tree
Hide file tree
Showing 27 changed files with 285 additions and 76 deletions.
2 changes: 1 addition & 1 deletion bertopic/cluster/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class BaseCluster:
```python
from bertopic import BERTopic
from bertopic.dimensionality import BaseCluster
from bertopic.cluster import BaseCluster
empty_cluster_model = BaseCluster()
Expand Down
8 changes: 8 additions & 0 deletions bertopic/representation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@
msg = "`pip install openai` \n\n"
OpenAI = NotInstalled("OpenAI", "openai", custom_msg=msg)

# LiteLLM Generator
try:
from bertopic.representation._litellm import LiteLLM
except ModuleNotFoundError:
msg = "`pip install litellm` \n\n"
LiteLLM = NotInstalled("LiteLLM", "litellm", custom_msg=msg)

# LangChain Generator
try:
from bertopic.representation._langchain import LangChain
Expand Down Expand Up @@ -63,6 +70,7 @@
"Cohere",
"OpenAI",
"LangChain",
"LiteLLM",
"LlamaCPP",
"VisualRepresentation",
]
176 changes: 176 additions & 0 deletions bertopic/representation/_litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import time
from litellm import completion
import pandas as pd
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Any
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import retry_with_exponential_backoff


DEFAULT_PROMPT = """
I have a topic that contains the following documents:
[DOCUMENTS]
The topic is described by the following keywords: [KEYWORDS]
Based on the information above, extract a short topic label in the following format:
topic: <topic label>
"""


class LiteLLM(BaseRepresentation):
"""Using the LiteLLM API to generate topic labels.
For an overview of models see:
https://docs.litellm.ai/docs/providers
Arguments:
model: Model to use. Defaults to OpenAI's "gpt-3.5-turbo".
generator_kwargs: Kwargs passed to `litellm.completion`.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
to decide where the keywords and documents need to be
inserted.
delay_in_seconds: The delay in seconds between consecutive prompts
in order to prevent RateLimitErrors.
exponential_backoff: Retry requests with a random exponential backoff.
A short sleep is used when a rate limit error is hit,
then the requests is retried. Increase the sleep length
if errors are hit until 10 unsuccesfull requests.
If True, overrides `delay_in_seconds`.
nr_docs: The number of documents to pass to LiteLLM if a prompt
with the `["DOCUMENTS"]` tag is used.
diversity: The diversity of documents to pass to LiteLLM.
Accepts values between 0 and 1. A higher
values results in passing more diverse documents
whereas lower values passes more similar documents.
Usage:
To use this, you will need to install the litellm package first:
`pip install litellm`
Then, get yourself an API key of any provider (for instance OpenAI) and use it as follows:
```python
import os
from bertopic.representation import LiteLLM
from bertopic import BERTopic
# set ENV variables
os.environ["OPENAI_API_KEY"] = "your-openai-key"
# Create your representation model
representation_model = LiteLLM(model="gpt-3.5-turbo")
# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTopic(representation_model=representation_model)
```
You can also use a custom prompt:
```python
prompt = "I have the following documents: [DOCUMENTS] \nThese documents are about the following topic: '"
representation_model = LiteLLM(model="gpt", prompt=prompt)
```
""" # noqa: D301

def __init__(
self,
model: str = "gpt-3.5-turbo",
prompt: str = None,
generator_kwargs: Mapping[str, Any] = {},
delay_in_seconds: float = None,
exponential_backoff: bool = False,
nr_docs: int = 4,
diversity: float = None,
):
self.model = model
self.prompt = prompt if prompt else DEFAULT_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.delay_in_seconds = delay_in_seconds
self.exponential_backoff = exponential_backoff
self.nr_docs = nr_docs
self.diversity = diversity

self.generator_kwargs = generator_kwargs
if self.generator_kwargs.get("model"):
self.model = generator_kwargs.get("model")
if self.generator_kwargs.get("prompt"):
del self.generator_kwargs["prompt"]

def extract_topics(
self, topic_model, documents: pd.DataFrame, c_tf_idf: csr_matrix, topics: Mapping[str, List[Tuple[str, float]]]
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract topics.
Arguments:
topic_model: A BERTopic model
documents: All input documents
c_tf_idf: The topic c-TF-IDF representation
topics: The candidate topics as calculated with c-TF-IDF
Returns:
updated_topics: Updated topic representations
"""
# Extract the top n representative documents per topic
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
)

# Generate using a (Large) Language Model
updated_topics = {}
for topic, docs in repr_docs_mappings.items():
prompt = self._create_prompt(docs, topic, topics)

# Delay
if self.delay_in_seconds:
time.sleep(self.delay_in_seconds)

messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
]
kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs}
if self.exponential_backoff:
response = chat_completions_with_backoff(**kwargs)
else:
response = completion(**kwargs)
label = response["choices"][0]["message"]["content"].strip().replace("topic: ", "")

updated_topics[topic] = [(label, 1)]

return updated_topics

def _create_prompt(self, docs, topic, topics):
keywords = list(zip(*topics[topic]))[0]

# Use the Default Chat Prompt
if self.prompt == DEFAULT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", " ".join(keywords))
prompt = self._replace_documents(prompt, docs)

# Use a custom prompt that leverages keywords, documents or both using
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", " ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs)

return prompt

@staticmethod
def _replace_documents(prompt, docs):
to_replace = ""
for doc in docs:
to_replace += f"- {doc[:255]}\n"
prompt = prompt.replace("[DOCUMENTS]", to_replace)
return prompt


def chat_completions_with_backoff(**kwargs):
return retry_with_exponential_backoff(
completion,
)(**kwargs)
7 changes: 6 additions & 1 deletion docs/algorithm/algorithm.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,12 @@ The following models are implemented in `bertopic.representation`:
* `PartOfSpeech`
* `KeyBERTInspired`
* `ZeroShotClassification`
* `TextGeneration`
* `TextGeneration` (HuggingFace)
* `Cohere`
* `OpenAI`
* `LangChain`
* `LiteLLM`
* `LlamaCPP`

!!! tip Models
There are roughly two sets of models. **First** are the non-generative set of models that you can find [here](https://maartengr.github.io/BERTopic/getting_started/representation/representation.html). These include models that focus on enhancing the keywords in the topic representations. **Second** are the generative models that attempt to label or summarize the topics instead. You can find an overview of [implemented LLMs here](https://maartengr.github.io/BERTopic/getting_started/representation/llm).
3 changes: 3 additions & 0 deletions docs/api/backends.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `Backends`

::: bertopic.backend
3 changes: 0 additions & 3 deletions docs/api/backends/base.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/backends/cohere.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/backends/openai.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/backends/word_doc.md

This file was deleted.

File renamed without changes.
3 changes: 3 additions & 0 deletions docs/api/cluster.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `BaseCluster`

::: bertopic.cluster._base.BaseCluster
File renamed without changes.
3 changes: 0 additions & 3 deletions docs/api/onlinecv.md

This file was deleted.

3 changes: 3 additions & 0 deletions docs/api/plotting.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `Plotting`

::: bertopic.plotting
3 changes: 0 additions & 3 deletions docs/api/representation/base.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/representation/cohere.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/representation/generation.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/representation/keybert.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/representation/langchain.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/representation/mmr.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/representation/openai.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/representation/pos.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/representation/zeroshot.md

This file was deleted.

3 changes: 3 additions & 0 deletions docs/api/representations.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `Representations`

::: bertopic.representation
3 changes: 3 additions & 0 deletions docs/api/vectorizers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `Vectorizers`

::: bertopic.vectorizers._online_cv.OnlineCountVectorizer
Loading

0 comments on commit 84dbf36

Please sign in to comment.