Skip to content

Commit

Permalink
feat(backend): support ollama llm; support ollama, jina, cohere embed…
Browse files Browse the repository at this point in the history
…ding (pingcap#247)
  • Loading branch information
wd0517 authored Aug 27, 2024
1 parent 4c65b13 commit a4e56ee
Show file tree
Hide file tree
Showing 21 changed files with 299 additions and 68 deletions.
10 changes: 8 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@ TIDB_PASSWORD=
TIDB_DATABASE=
TIDB_SSL=true

# *** DO NOT CHANGE BELOW CONFIGURATIONS UNLESS YOU KNOW WHAT YOU ARE DOING
DSP_CACHEBOOL=false
# CAUTION: Do not change EMBEDDING_DIMS after initializing the database.
# Changing the embedding dimensions requires recreating the database and tables.
# The default EMBEDDING_DIMS and EMBEDDING_MAX_TOKENS are set for the OpenAI text-embedding-3-small model.
# If using a different embedding model, adjust these values according to the model's specifications.
# For example:
# maidalun1020/bce-embedding-base_v1: EMBEDDING_DIMS=768 EMBEDDING_MAX_TOKENS=512
EMBEDDING_DIMS=1536
EMBEDDING_MAX_TOKENS=8191
76 changes: 76 additions & 0 deletions backend/app/alembic/versions/00534dc350db_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""empty message
Revision ID: 00534dc350db
Revises: 10f36e8a25c4
Create Date: 2024-08-26 12:46:00.203425
"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql

# revision identifiers, used by Alembic.
revision = "00534dc350db"
down_revision = "10f36e8a25c4"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"embedding_models",
"provider",
existing_type=mysql.ENUM("OPENAI"),
type_=sa.String(length=32),
existing_nullable=False,
)
op.alter_column(
"llms",
"provider",
existing_type=mysql.ENUM(
"OPENAI", "GEMINI", "ANTHROPIC_VERTEX", "OPENAI_LIKE", "BEDROCK"
),
type_=sa.String(length=32),
existing_nullable=False,
)
op.alter_column(
"reranker_models",
"provider",
existing_type=mysql.ENUM("JINA", "COHERE", "BAISHENG"),
type_=sa.String(length=32),
existing_nullable=False,
)
op.execute("UPDATE embedding_models SET provider = lower(provider)")
op.execute("UPDATE llms SET provider = lower(provider)")
op.execute("UPDATE reranker_models SET provider = lower(provider)")
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"reranker_models",
"provider",
existing_type=sa.String(length=32),
type_=mysql.ENUM("JINA", "COHERE", "BAISHENG"),
existing_nullable=False,
)
op.alter_column(
"llms",
"provider",
existing_type=sa.String(length=32),
type_=mysql.ENUM(
"OPENAI", "GEMINI", "ANTHROPIC_VERTEX", "OPENAI_LIKE", "BEDROCK"
),
existing_nullable=False,
)
op.alter_column(
"embedding_models",
"provider",
existing_type=sa.String(length=32),
type_=mysql.ENUM("OPENAI"),
existing_nullable=False,
)
# ### end Alembic commands ###
13 changes: 7 additions & 6 deletions backend/app/alembic/versions/2fc10c21bf88_.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import sqlmodel.sql.sqltypes
from tidb_vector.sqlalchemy import VectorType
from sqlalchemy.dialects import mysql
from app.core.config import settings

# revision identifiers, used by Alembic.
revision = "2fc10c21bf88"
Expand Down Expand Up @@ -98,13 +99,13 @@ def upgrade():
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"description_vec",
VectorType(dim=1536),
VectorType(dim=settings.EMBEDDING_DIMS),
nullable=True,
comment="hnsw(distance=cosine)",
),
sa.Column(
"meta_vec",
VectorType(dim=1536),
VectorType(dim=settings.EMBEDDING_DIMS),
nullable=True,
comment="hnsw(distance=cosine)",
),
Expand All @@ -116,14 +117,14 @@ def upgrade():
sa.Column("query", sa.Text(), nullable=True),
sa.Column(
"query_vec",
VectorType(dim=1536),
VectorType(dim=settings.EMBEDDING_DIMS),
nullable=True,
comment="hnsw(distance=cosine)",
),
sa.Column("value", sa.Text(), nullable=True),
sa.Column(
"value_vec",
VectorType(dim=1536),
VectorType(dim=settings.EMBEDDING_DIMS),
nullable=True,
comment="hnsw(distance=cosine)",
),
Expand Down Expand Up @@ -289,7 +290,7 @@ def upgrade():
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column(
"embedding",
VectorType(dim=1536),
VectorType(dim=settings.EMBEDDING_DIMS),
nullable=True,
comment="hnsw(distance=cosine)",
),
Expand Down Expand Up @@ -329,7 +330,7 @@ def upgrade():
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"description_vec",
VectorType(dim=1536),
VectorType(dim=settings.EMBEDDING_DIMS),
nullable=True,
comment="hnsw(distance=cosine)",
),
Expand Down
2 changes: 1 addition & 1 deletion backend/app/api/admin_routes/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_embedding_model(
credentials=db_embed_model.credentials,
)
embedding = embed_model.get_query_embedding("Hello, world!")
expected_length = settings.EMBEDDOMG_DIMS
expected_length = settings.EMBEDDING_DIMS
if len(embedding) != expected_length:
raise ValueError(
f"Currently we only support {expected_length} dims embedding, got {len(embedding)} dims."
Expand Down
10 changes: 8 additions & 2 deletions backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,14 @@ def server_host(self) -> str:

COMPLIED_INTENT_ANALYSIS_PROGRAM_PATH: str | None = None

# Currently, we only support 1536 dims for the embedding model
EMBEDDOMG_DIMS: int = 1536
# CAUTION: Do not change EMBEDDING_DIMS after initializing the database.
# Changing the embedding dimensions requires recreating the database and tables.
# The default EMBEDDING_DIMS and EMBEDDING_MAX_TOKENS are set for the OpenAI text-embedding-3-small model.
# If using a different embedding model, adjust these values according to the model's specifications.
# For example:
# maidalun1020/bce-embedding-base_v1: EMBEDDING_DIMS=768 EMBEDDING_MAX_TOKENS=512
EMBEDDING_DIMS: int = 1536
EMBEDDING_MAX_TOKENS: int = 8191

@computed_field # type: ignore[misc]
@property
Expand Down
5 changes: 4 additions & 1 deletion backend/app/models/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tidb_vector.sqlalchemy import VectorType
from llama_index.core.schema import TextNode

from app.core.config import settings
from .base import UpdatableBaseModel, UUIDBaseModel


Expand All @@ -27,7 +28,9 @@ class Chunk(UUIDBaseModel, UpdatableBaseModel, table=True):
text: str = Field(sa_column=Column(Text))
meta: dict | list = Field(default={}, sa_column=Column(JSON))
embedding: Any = Field(
sa_column=Column(VectorType(1536), comment="hnsw(distance=cosine)")
sa_column=Column(
VectorType(settings.EMBEDDING_DIMS), comment="hnsw(distance=cosine)"
)
)
document_id: int = Field(foreign_key="documents.id", nullable=True)
document: "Document" = SQLRelationship(
Expand Down
4 changes: 2 additions & 2 deletions backend/app/models/embed_model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Optional, Any

from sqlmodel import Field, Column, JSON
from sqlmodel import Field, Column, JSON, String

from .base import UpdatableBaseModel, AESEncryptedColumn
from app.types import EmbeddingProvider


class BaseEmbeddingModel(UpdatableBaseModel):
name: str = Field(max_length=64)
provider: EmbeddingProvider
provider: EmbeddingProvider = Field(sa_column=Column(String(32), nullable=False))
model: str = Field(max_length=256)
config: dict | list | None = Field(sa_column=Column(JSON), default={})
is_default: bool = Field(default=False)
Expand Down
14 changes: 11 additions & 3 deletions backend/app/models/knowledge_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
)
from tidb_vector.sqlalchemy import VectorType

from app.core.config import settings


class EntityType(str, enum.Enum):
original = "original"
Expand All @@ -34,10 +36,14 @@ class EntityBase(SQLModel):
class Entity(EntityBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
description_vec: Any = Field(
sa_column=Column(VectorType(1536), comment="hnsw(distance=cosine)")
sa_column=Column(
VectorType(settings.EMBEDDING_DIMS), comment="hnsw(distance=cosine)"
)
)
meta_vec: Any = Field(
sa_column=Column(VectorType(1536), comment="hnsw(distance=cosine)")
sa_column=Column(
VectorType(settings.EMBEDDING_DIMS), comment="hnsw(distance=cosine)"
)
)

__tablename__ = "entities"
Expand Down Expand Up @@ -70,7 +76,9 @@ class RelationshipBase(SQLModel):
class Relationship(RelationshipBase, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
description_vec: Any = Field(
sa_column=Column(VectorType(1536), comment="hnsw(distance=cosine)")
sa_column=Column(
VectorType(settings.EMBEDDING_DIMS), comment="hnsw(distance=cosine)"
)
)
source_entity: Entity = SQLModelRelationship(
sa_relationship_kwargs={
Expand Down
4 changes: 2 additions & 2 deletions backend/app/models/llm.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Optional, Any

from sqlmodel import Field, Column, JSON
from sqlmodel import Field, Column, JSON, String

from .base import UpdatableBaseModel, AESEncryptedColumn
from app.types import LLMProvider


class BaseLLM(UpdatableBaseModel):
name: str = Field(max_length=64)
provider: LLMProvider
provider: LLMProvider = Field(sa_column=Column(String(32), nullable=False))
model: str = Field(max_length=256)
config: dict | list | None = Field(sa_column=Column(JSON), default={})
is_default: bool = Field(default=False)
Expand Down
4 changes: 2 additions & 2 deletions backend/app/models/reranker_model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from typing import Optional, Any

from sqlmodel import Field, Column, JSON
from sqlmodel import Field, Column, JSON, String

from .base import UpdatableBaseModel, AESEncryptedColumn
from app.types import RerankerProvider


class BaseRerankerModel(UpdatableBaseModel):
name: str = Field(max_length=64)
provider: RerankerProvider
provider: RerankerProvider = Field(sa_column=Column(String(32), nullable=False))
model: str = Field(max_length=256)
top_n: int = Field(default=10)
config: dict | list | None = Field(sa_column=Column(JSON), default={})
Expand Down
10 changes: 8 additions & 2 deletions backend/app/models/semantic_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,22 @@
)
from tidb_vector.sqlalchemy import VectorType

from app.core.config import settings


class SemanticCache(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
query: str = Field(sa_column=Column(Text))
query_vec: Any = Field(
sa_column=Column(VectorType(1536), comment="hnsw(distance=cosine)")
sa_column=Column(
VectorType(settings.EMBEDDING_DIMS), comment="hnsw(distance=cosine)"
)
)
value: str = Field(sa_column=Column(Text))
value_vec: Any = Field(
sa_column=Column(VectorType(1536), comment="hnsw(distance=cosine)")
sa_column=Column(
VectorType(settings.EMBEDDING_DIMS), comment="hnsw(distance=cosine)"
)
)
meta: List | Dict = Field(default={}, sa_column=Column(JSON))
created_at: datetime = Field(
Expand Down
9 changes: 7 additions & 2 deletions backend/app/rag/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from app.rag.node_parser import MarkdownNodeParser
from app.rag.vector_store.tidb_vector_store import TiDBVectorStore
from app.rag.chat_config import get_default_embedding_model
from app.core.config import settings
from app.models import (
Document as DBDocument,
Chunk as DBChunk,
Expand Down Expand Up @@ -45,9 +46,13 @@ def build_vector_index_from_document(
if db_document.mime_type.lower() == "text/markdown":
# spliter = MarkdownNodeParser()
# TODO: FIX MarkdownNodeParser
spliter = SentenceSplitter()
spliter = SentenceSplitter(
chunk_size=settings.EMBEDDING_MAX_TOKENS,
)
else:
spliter = SentenceSplitter()
spliter = SentenceSplitter(
chunk_size=settings.EMBEDDING_MAX_TOKENS,
)

_transformations = [
spliter,
Expand Down
24 changes: 24 additions & 0 deletions backend/app/rag/chat_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
from llama_index.llms.openai_like import OpenAILike
from llama_index.llms.gemini import Gemini
from llama_index.llms.bedrock import Bedrock
from llama_index.llms.ollama import Ollama
from llama_index.core.llms.llm import LLM
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.embeddings.jinaai import JinaEmbedding
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.postprocessor.jinaai_rerank import JinaRerank
from llama_index.postprocessor.cohere_rerank import CohereRerank
from sqlmodel import Session, select
Expand Down Expand Up @@ -208,6 +212,10 @@ def get_llm(
if "max_tokens" not in config:
config.update(max_tokens=4096)
return AnthropicVertex(model=model, credentials=google_creds, **config)
case LLMProvider.OLLAMA:
config.setdefault("request_timeout", 60 * 5)
config.setdefault("context_window", 4096)
return Ollama(model=model, **config)
case _:
raise ValueError(f"Got unknown LLM provider: {provider}")

Expand Down Expand Up @@ -241,6 +249,22 @@ def get_embedding_model(
api_key=credentials,
**config,
)
case EmbeddingProvider.JINA:
return JinaEmbedding(
model=model,
api_key=credentials,
**config,
)
case EmbeddingProvider.COHERE:
return CohereEmbedding(
model_name=model,
cohere_api_key=credentials,
)
case EmbeddingProvider.OLLAMA:
return OllamaEmbedding(
model_name=model,
**config,
)
case _:
raise ValueError(f"Got unknown embedding provider: {provider}")

Expand Down
Loading

0 comments on commit a4e56ee

Please sign in to comment.