Skip to content

Commit

Permalink
feat(backend): support config reranker (pingcap#211)
Browse files Browse the repository at this point in the history
  • Loading branch information
wd0517 authored Aug 8, 2024
1 parent 62de210 commit 9761d4a
Show file tree
Hide file tree
Showing 17 changed files with 413 additions and 32 deletions.
4 changes: 0 additions & 4 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,5 @@ TIDB_USER=
TIDB_PASSWORD=
TIDB_DATABASE=

# Replace with your own Jina AI API key
# You can get one from https://jina.ai/reranker/
JINAAI_API_KEY=

# *** DO NOT CHANGE BELOW CONFIGURATIONS UNLESS YOU KNOW WHAT YOU ARE DOING
DSP_CACHEBOOL=false
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ A conversational search tool based on GraphRAG (Knowledge Graph) that built on t
>
> 1. Set up a [TiDB Serverless cluster](https://docs.pingcap.com/tidbcloud/tidb-cloud-quickstart).
> 2. Install [Docker Compose](https://docs.docker.com/compose/install/).
> 3. Jina AI API key, get one from [Jina AI](https://jina.ai/reranker/).
1. Clone the repository:

Expand All @@ -52,7 +51,6 @@ A conversational search tool based on GraphRAG (Knowledge Graph) that built on t

Replace the following placeholders with your own values:
- `SECRET_KEY`: you can generate a random secret key using `python3 -c "import secrets; print(secrets.token_urlsafe(32))"`
- `JINAAI_API_KEY`: get one from [Jina AI](https://jina.ai/reranker/)
- `TIDB_HOST`, `TIDB_USER`, `TIDB_PASSWORD` and `TIDB_DATABASE`: get them from your [TiDB Serverless cluster](https://tidbcloud.com/)

- Note: TiDB Serverless will provide a default database name called `test`, if you want to use another database name, you need to create a new database in the TiDB Serverless console.
Expand Down
20 changes: 12 additions & 8 deletions backend/app/alembic/versions/bd17a4ebccc5_.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,34 @@
Create Date: 2024-08-08 01:20:42.069228
"""

from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from tidb_vector.sqlalchemy import VectorType


# revision identifiers, used by Alembic.
revision = 'bd17a4ebccc5'
down_revision = 'a8c79553c9f6'
revision = "bd17a4ebccc5"
down_revision = "a8c79553c9f6"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('data_sources', sa.Column('deleted_at', sa.DateTime(), nullable=True))
op.drop_index('source_uri', table_name='documents')
op.add_column('relationships', sa.Column('chunk_id', sqlmodel.sql.sqltypes.GUID(), nullable=True))
op.add_column("data_sources", sa.Column("deleted_at", sa.DateTime(), nullable=True))
op.drop_index("source_uri", table_name="documents")
op.add_column(
"relationships",
sa.Column("chunk_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('relationships', 'chunk_id')
op.create_index('source_uri', 'documents', ['source_uri'], unique=True)
op.drop_column('data_sources', 'deleted_at')
op.drop_column("relationships", "chunk_id")
op.create_index("source_uri", "documents", ["source_uri"], unique=True)
op.drop_column("data_sources", "deleted_at")
# ### end Alembic commands ###
66 changes: 66 additions & 0 deletions backend/app/alembic/versions/e32f1e546eec_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""empty message
Revision ID: e32f1e546eec
Revises: bd17a4ebccc5
Create Date: 2024-08-08 03:55:14.042290
"""

from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes
from tidb_vector.sqlalchemy import VectorType
from app.models.base import AESEncryptedColumn


# revision identifiers, used by Alembic.
revision = "e32f1e546eec"
down_revision = "bd17a4ebccc5"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"reranker_models",
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(length=64), nullable=False),
sa.Column(
"provider",
sa.Enum("JINA", "COHERE", name="rerankerprovider"),
nullable=False,
),
sa.Column(
"model", sqlmodel.sql.sqltypes.AutoString(length=256), nullable=False
),
sa.Column("top_n", sa.Integer(), nullable=False),
sa.Column("config", sa.JSON(), nullable=True),
sa.Column("is_default", sa.Boolean(), nullable=False),
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("credentials", AESEncryptedColumn(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.add_column("chat_engines", sa.Column("reranker_id", sa.Integer(), nullable=True))
op.create_foreign_key(
None, "chat_engines", "reranker_models", ["reranker_id"], ["id"]
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat_engines", "reranker_id")
op.drop_table("reranker_models")
# ### end Alembic commands ###
121 changes: 119 additions & 2 deletions backend/app/api/admin_routes/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,26 @@
from fastapi_pagination.ext.sqlmodel import paginate
from sqlmodel import select, update
from langfuse import Langfuse
from llama_index.core.schema import NodeWithScore, TextNode

from app.core.config import settings
from app.api.deps import SessionDep, CurrentSuperuserDep
from app.rag.llm_option import admin_llm_options, LLMOption
from app.rag.embed_model_option import admin_embed_model_options, EmbeddingModelOption
from app.rag.chat_config import get_llm, get_embedding_model
from app.rag.reranker_model_option import (
admin_reranker_model_options,
RerankerModelOption,
)
from app.rag.chat_config import get_llm, get_embedding_model, get_reranker_model
from app.models import (
ChatEngine,
LLM,
AdminLLM,
EmbeddingModel,
AdminEmbeddingModel,
RerankerModel,
AdminRerankerModel,
)
from app.site_settings import SiteSetting

router = APIRouter()

Expand Down Expand Up @@ -213,3 +219,114 @@ def test_langfuse(
success = False
error = str(e)
return LangfuseTestResult(success=success, error=error)


@router.get("/admin/reranker-models/options")
def get_reranker_model_options(
user: CurrentSuperuserDep,
) -> List[RerankerModelOption]:
return admin_reranker_model_options


@router.post("/admin/reranker-models/test")
def test_reranker_model(
db_reranker_model: RerankerModel,
user: CurrentSuperuserDep,
) -> LLMTestResult:
try:
reranker = get_reranker_model(
provider=db_reranker_model.provider,
model=db_reranker_model.model,
# for testing purpose, we only rerank 2 nodes
top_n=2,
config=db_reranker_model.config,
credentials=db_reranker_model.credentials,
)
nodes = reranker.postprocess_nodes(
nodes=[
NodeWithScore(
node=TextNode(
text="TiDB is a distributed SQL database.",
),
score=0.8,
),
NodeWithScore(
node=TextNode(
text="TiDB is compatible with MySQL protocol.",
),
score=0.6,
),
NodeWithScore(
node=TextNode(
text="TiFlash is a columnar storage engine.",
),
score=0.4,
),
],
query_str="What is TiDB?",
)
success = True
error = ""
except Exception as e:
success = False
error = str(e)
return LLMTestResult(success=success, error=error)


@router.get("/admin/reranker-models")
def list_reranker_models(
session: SessionDep,
user: CurrentSuperuserDep,
params: Params = Depends(),
) -> Page[AdminRerankerModel]:
return paginate(
session,
select(RerankerModel).order_by(RerankerModel.created_at.desc()),
params,
)


@router.post("/admin/reranker-models")
def create_reranker_model(
reranker_model: RerankerModel,
session: SessionDep,
user: CurrentSuperuserDep,
) -> AdminRerankerModel:
session.add(reranker_model)
session.commit()
session.refresh(reranker_model)
return reranker_model


@router.get("/admin/reranker-models/{reranker_model_id}")
def get_reranker_model_detail(
reranker_model_id: int,
session: SessionDep,
user: CurrentSuperuserDep,
) -> AdminRerankerModel:
reranker_model = session.get(RerankerModel, reranker_model_id)
if reranker_model is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Reranker model not found"
)
return reranker_model


@router.delete("/admin/reranker-models/{reranker_model_id}")
def delete_reranker_model(
reranker_id: int,
session: SessionDep,
user: CurrentSuperuserDep,
):
reranker_model = session.get(RerankerModel, reranker_id)
if reranker_model is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Reranker model not found"
)
session.exec(
update(ChatEngine)
.where(ChatEngine.reranker_id == reranker_id)
.values(reranker_id=None)
)
session.delete(reranker_model)
session.commit()
11 changes: 5 additions & 6 deletions backend/app/api/routes/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from app.api.deps import SessionDep
from app.site_settings import SiteSetting
from app.rag.chat import check_rag_required_config
from app.rag.chat import check_rag_required_config, check_rag_optional_config

router = APIRouter()

Expand All @@ -28,6 +28,7 @@ class RequiredConfigStatus(BaseModel):

class OptionalConfigStatus(BaseModel):
langfuse: bool
default_reranker: bool


class SystemConfigStatusResponse(BaseModel):
Expand All @@ -40,17 +41,15 @@ def system_bootstrap_status(session: SessionDep) -> SystemConfigStatusResponse:
has_default_llm, has_default_embedding_model, has_datasource = (
check_rag_required_config(session)
)
langfuse, default_reranker = check_rag_optional_config(session)
return SystemConfigStatusResponse(
required=RequiredConfigStatus(
default_llm=has_default_llm,
default_embedding_model=has_default_embedding_model,
datasource=has_datasource,
),
optional=OptionalConfigStatus(
langfuse=bool(
SiteSetting.langfuse_host
and SiteSetting.langfuse_secret_key
and SiteSetting.langfuse_public_key
)
langfuse=langfuse,
default_reranker=default_reranker,
),
)
1 change: 1 addition & 0 deletions backend/app/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from .data_source import DataSource, DataSourceType
from .llm import LLM, AdminLLM
from .embed_model import EmbeddingModel, AdminEmbeddingModel
from .reranker_model import RerankerModel, AdminRerankerModel
6 changes: 6 additions & 0 deletions backend/app/models/chat_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class ChatEngine(UpdatableBaseModel, table=True):
"foreign_keys": "ChatEngine.fast_llm_id",
},
)
reranker_id: Optional[int] = Field(foreign_key="reranker_models.id", nullable=True)
reranker: "RerankerModel" = SQLRelationship(
sa_relationship_kwargs={
"foreign_keys": "ChatEngine.reranker_id",
},
)
is_default: bool = Field(default=False)
deleted_at: Optional[datetime] = Field(default=None, sa_column=Column(DateTime))

Expand Down
2 changes: 1 addition & 1 deletion backend/app/models/knowledge_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class RelationshipBase(SQLModel):
source_entity_id: int = Field(foreign_key="entities.id")
target_entity_id: int = Field(foreign_key="entities.id")
last_modified_at: Optional[datetime] = Field(sa_column=Column(DateTime))
chunk_id: Optional[UUID] = Field(default=None)


class Relationship(RelationshipBase, table=True):
Expand All @@ -82,7 +83,6 @@ class Relationship(RelationshipBase, table=True):
"lazy": "joined",
},
)
chunk_id: UUID = Field(nullable=True)

__tablename__ = "relationships"

Expand Down
26 changes: 26 additions & 0 deletions backend/app/models/reranker_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Optional, Any

from sqlmodel import Field, Column, JSON

from .base import UpdatableBaseModel, AESEncryptedColumn
from app.types import RerankerProvider


class BaseRerankerModel(UpdatableBaseModel):
name: str = Field(max_length=64)
provider: RerankerProvider
model: str = Field(max_length=256)
top_n: int = Field(default=10)
config: dict | list | None = Field(sa_column=Column(JSON), default={})
is_default: bool = Field(default=False)


class RerankerModel(BaseRerankerModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
credentials: Any = Field(sa_column=Column(AESEncryptedColumn, nullable=True))

__tablename__ = "reranker_models"


class AdminRerankerModel(BaseRerankerModel):
id: int
Loading

0 comments on commit 9761d4a

Please sign in to comment.