Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: advanced filters for feedbacks and chats admin api #525

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
28 changes: 10 additions & 18 deletions backend/app/api/admin_routes/feedback.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from fastapi import APIRouter, Depends
from typing import Annotated

from fastapi import APIRouter, Depends, Query
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlmodel import select

from app.api.deps import SessionDep, CurrentSuperuserDep
from app.models import Feedback, AdminFeedbackPublic
from app.models import AdminFeedbackPublic, FeedbackFilters
from app.repositories import feedback_repo

router = APIRouter()

Expand All @@ -13,20 +14,11 @@
def list_feedbacks(
session: SessionDep,
user: CurrentSuperuserDep,
filters: Annotated[FeedbackFilters, Query()],
params: Params = Depends(),
) -> Page[AdminFeedbackPublic]:
return paginate(
session,
select(Feedback).order_by(Feedback.created_at.desc()),
params,
transformer=lambda items: [
AdminFeedbackPublic(
**item.model_dump(),
chat_title=item.chat.title,
chat_origin=item.chat.origin,
chat_message_content=item.chat_message.content,
user_email=item.user.email if item.user else None,
)
for item in items
],
return feedback_repo.paginate(
session=session,
filters=filters,
params=params,
)
4 changes: 4 additions & 0 deletions backend/app/api/admin_routes/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class EmbeddingModelDescriptor(EmbeddingModelItem):

class UserDescriptor(BaseModel):
id: UUID
email: str


class KnowledgeBaseDescriptor(BaseModel):
Expand All @@ -37,6 +38,9 @@ class ChatEngineDescriptor(BaseModel):
name: str
is_default: bool

class ChatOriginDescriptor(BaseModel):
id: UUID
origin: str

class ChatEngineBasedRetrieveRequest(BaseModel):
query: str
Expand Down
43 changes: 42 additions & 1 deletion backend/app/api/admin_routes/stats.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
from datetime import date
from pydantic import BaseModel
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate

from sqlmodel import select

from app.api.deps import CurrentSuperuserDep, SessionDep
from app.repositories import chat_repo
from app.models import Feedback

from app.repositories import chat_repo
from app.api.admin_routes.models import (
ChatOriginDescriptor
)


router = APIRouter()

Expand Down Expand Up @@ -31,3 +42,33 @@ def chat_origin_trend(
) -> ChatStats:
stats = chat_repo.chat_trend_by_origin(session, start_date, end_date)
return ChatStats(start_date=start_date, end_date=end_date, values=stats)

@router.get("/admin/stats/chats/origins")
sszgwdk marked this conversation as resolved.
Show resolved Hide resolved
def list_chat_origins(
session: SessionDep,
user: CurrentSuperuserDep,
params: Params = Depends(),
) -> list[ChatOriginDescriptor]:
chat_origins = []
# chats = session.exec(select(Chat.origin, Chat.id).order_by(Chat.created_at.desc()))
for chat in chat_repo.list_chat_origins(session):
chat_origins.append(
ChatOriginDescriptor(
id=chat.id,
origin=chat.origin,
)
)
return chat_origins


@router.get("/admin/stats/feedbacks/origins")
def list_feedback_origins(
session: SessionDep,
user: CurrentSuperuserDep,
params: Params = Depends(),
) -> Page[str]:
return paginate(
session,
select(Feedback.origin).distinct().order_by(Feedback.origin.asc()),
params,
)
37 changes: 37 additions & 0 deletions backend/app/api/admin_routes/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from fastapi import APIRouter, Depends
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate
from sqlmodel import select, col

from app.api.deps import SessionDep, CurrentSuperuserDep
from app.models import User

from app.api.admin_routes.models import (
UserDescriptor,
)

router = APIRouter()


@router.get("/admin/users/search")
def search_users(
session: SessionDep,
user: CurrentSuperuserDep,
search: str | None = None,
params: Params = Depends(),
) -> Page[UserDescriptor]:
query = select(User).order_by(User.id)
if search:
query = query.where(col(User.email).contains(search))
return paginate(
session,
query,
params,
transformer=lambda items: [
UserDescriptor(
id=item.id,
email=item.email,
)
for item in items
],
)
2 changes: 2 additions & 0 deletions backend/app/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
stats as admin_stats,
semantic_cache as admin_semantic_cache,
langfuse as admin_langfuse,
user as admin_user,
)
from app.api.admin_routes.evaluation import (
evaluation_task as admin_evaluation_task,
Expand Down Expand Up @@ -79,6 +80,7 @@
api_router.include_router(
admin_evaluation_dataset.router, tags=["admin/evaluation/dataset"]
)
api_router.include_router(admin_user.router, tags=["admin/user"])

api_router.include_router(
fastapi_users.get_auth_router(auth_backend), prefix="/auth", tags=["auth"]
Expand Down
9 changes: 5 additions & 4 deletions backend/app/api/routes/chat.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import logging
from uuid import UUID
from typing import List, Optional
from typing import List, Optional, Annotated
from http import HTTPStatus

from pydantic import (
BaseModel,
field_validator,
)
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request, Query
from fastapi.responses import StreamingResponse
from fastapi_pagination import Params, Page

from app.api.deps import SessionDep, OptionalUserDep, CurrentUserDep
from app.repositories import chat_repo
from app.models import Chat, ChatUpdate
from app.models import Chat, ChatUpdate, ChatFilters
from app.rag.chat import (
ChatService,
ChatEvent,
Expand Down Expand Up @@ -145,10 +145,11 @@ def list_chats(
request: Request,
session: SessionDep,
user: OptionalUserDep,
filters: Annotated[ChatFilters, Query()],
params: Params = Depends(),
) -> Page[Chat]:
browser_id = request.state.browser_id
return chat_repo.paginate(session, user, browser_id, params)
return chat_repo.paginate(session, user, browser_id, filters, params)


@router.get("/chats/{chat_id}")
Expand Down
3 changes: 2 additions & 1 deletion backend/app/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
Feedback,
FeedbackType,
AdminFeedbackPublic,
FeedbackFilters,
)
from .semantic_cache import SemanticCache
from .staff_action_log import StaffActionLog
from .chat_engine import ChatEngine, ChatEngineUpdate
from .chat import Chat, ChatUpdate, ChatVisibility
from .chat import Chat, ChatUpdate, ChatVisibility, ChatFilters
from .chat_message import ChatMessage
from .document import Document, DocIndexTaskStatus
from .chunk import Chunk, KgIndexStatus
Expand Down
9 changes: 9 additions & 0 deletions backend/app/models/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,12 @@ class Chat(UUIDBaseModel, UpdatableBaseModel, table=True):
class ChatUpdate(BaseModel):
title: Optional[str] = None
visibility: Optional[ChatVisibility] = None

class ChatFilters(BaseModel):
created_at_start: Optional[datetime] = None
created_at_end: Optional[datetime] = None
updated_at_start: Optional[datetime] = None
updated_at_end: Optional[datetime] = None
chat_origin: Optional[str] = None
# user_id: Optional[UUID] = None # no use now
engine_id: Optional[int] = None
10 changes: 10 additions & 0 deletions backend/app/models/feedback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import enum
from uuid import UUID
from typing import Optional
from pydantic import BaseModel
from datetime import datetime

from sqlmodel import (
Field,
Expand Down Expand Up @@ -63,3 +65,11 @@ class AdminFeedbackPublic(BaseFeedback):
chat_message_content: str
user_id: Optional[UUID]
user_email: Optional[str]

class FeedbackFilters(BaseModel):
created_at_start: Optional[datetime] = None
created_at_end: Optional[datetime] = None
feedback_origin: Optional[str] = None
chat_id: Optional[UUID] = None
feedback_type: Optional[FeedbackType] = None
user_id: Optional[UUID] = None
1 change: 1 addition & 0 deletions backend/app/repositories/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .chunk import chunk_repo
from .data_source import data_source_repo
from .knowledge_base import knowledge_base_repo
from .feedback import feedback_repo
30 changes: 27 additions & 3 deletions backend/app/repositories/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from datetime import datetime, UTC, date, timedelta
from collections import defaultdict

from sqlmodel import select, Session, or_, func, case, desc
from sqlmodel import select, Session, or_, func, case, desc, col
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate

from app.models import Chat, User, ChatMessage, ChatUpdate
from app.models import Chat, User, ChatMessage, ChatUpdate, ChatFilters
from app.repositories.base_repo import BaseRepo


Expand All @@ -20,6 +20,7 @@ def paginate(
session: Session,
user: User | None,
browser_id: str | None,
filters: ChatFilters,
params: Params | None = Params(),
) -> Page[Chat]:
query = select(Chat).where(Chat.deleted_at == None)
Expand All @@ -30,6 +31,23 @@ def paginate(
)
else:
query = query.where(Chat.browser_id == browser_id, Chat.user_id == None)

# filters
if filters.created_at_start:
query = query.where(Chat.created_at >= filters.created_at_start)
if filters.created_at_end:
query = query.where(Chat.created_at <= filters.created_at_end)
if filters.updated_at_start:
query = query.where(Chat.updated_at >= filters.updated_at_start)
if filters.updated_at_end:
query = query.where(Chat.updated_at <= filters.updated_at_end)
if filters.chat_origin:
query = query.where(col(Chat.origin).contains(filters.chat_origin))
# if filters.user_id:
# query = query.where(Chat.user_id == filters.user_id)
if filters.engine_id:
query = query.where(Chat.engine_id == filters.engine_id)

query = query.order_by(Chat.created_at.desc())
return paginate(session, query, params)

Expand Down Expand Up @@ -202,6 +220,12 @@ def chat_trend_by_origin(

stats.sort(key=lambda x: x["date"])
return stats


def list_chat_origins(self, session: Session):
return session.exec(
select(Chat.origin, Chat.id)
sszgwdk marked this conversation as resolved.
Show resolved Hide resolved
.distinct()
.order_by(Chat.created_at.desc())
)

chat_repo = ChatRepo()
49 changes: 49 additions & 0 deletions backend/app/repositories/feedback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from sqlmodel import select, Session, col
from fastapi_pagination import Params, Page
from fastapi_pagination.ext.sqlmodel import paginate

from app.models import Feedback, AdminFeedbackPublic, FeedbackFilters
from app.repositories.base_repo import BaseRepo

class FeedbackRepo(BaseRepo):
model_cls = Feedback

def paginate(
self,
session: Session,
filters: FeedbackFilters,
params: Params | None = Params(),
) -> Page[AdminFeedbackPublic]:
# build the select statement via conditions
stmt = select(Feedback)
if filters.created_at_start:
stmt = stmt.where(Feedback.created_at >= filters.created_at_start)
if filters.created_at_end:
stmt = stmt.where(Feedback.created_at <= filters.created_at_end)
if filters.feedback_origin:
stmt = stmt.where(col(Feedback.origin).contains(filters.feedback_origin))
if filters.chat_id:
stmt = stmt.where(Feedback.chat_id == filters.chat_id)
if filters.feedback_type:
stmt = stmt.where(Feedback.feedback_type == filters.feedback_type)
if filters.user_id:
stmt = stmt.where(Feedback.user_id == filters.user_id)

stmt = stmt.order_by(Feedback.created_at.desc())
return paginate(
session,
stmt,
params,
transformer=lambda items: [
AdminFeedbackPublic(
**item.model_dump(),
chat_title=item.chat.title,
chat_origin=item.chat.origin,
chat_message_content=item.chat_message.content,
user_email=item.user.email if item.user else None,
)
for item in items
],
)

feedback_repo = FeedbackRepo()
Loading