
Docs #214
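Summary (not part of the original PR description): this PR tags every route in the KoboldAI and OAI routers with an OpenAPI tag (Tags.Kobold or Tags.OpenAI) and adds docstrings to the handlers that lacked them, so the generated API documentation can group and describe each endpoint.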

endpoints/Kobold/router.py (26 additions, 9 deletions)
@@ -6,6 +6,7 @@
 from common.auth import check_api_key
 from common.model import check_model_container
 from common.utils import unwrap
+from endpoints.core.types.tags import Tags
 from endpoints.core.utils.model import get_current_model
 from endpoints.Kobold.types.generation import (
     AbortRequest,
@@ -25,14 +26,14 @@


 api_name = "KoboldAI"
-router = APIRouter(prefix="/api")
+router = APIRouter(prefix="/api", tags=[Tags.Kobold])
 urls = {
     "Generation": "http://{host}:{port}/api/v1/generate",
     "Streaming": "http://{host}:{port}/api/extra/generate/stream",
 }

-kai_router = APIRouter()
-extra_kai_router = APIRouter()
+kai_router = APIRouter(tags=[Tags.Kobold])
+extra_kai_router = APIRouter(tags=[Tags.Kobold])


 def setup():
@@ -46,8 +47,10 @@ def setup():
 @kai_router.post(
     "/generate",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    tags=[Tags.Kobold],
 )
 async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
+    """Generate a response to a prompt."""
     response = await get_generation(data, request)

     return response
@@ -56,8 +59,10 @@ async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
 @extra_kai_router.post(
     "/generate/stream",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    tags=[Tags.Kobold],
 )
 async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse:
+    """Stream the chat response to a prompt."""
     response = EventSourceResponse(stream_generation(data, request), ping=maxsize)

     return response
@@ -66,8 +71,10 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse:
 @extra_kai_router.post(
     "/abort",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    tags=[Tags.Kobold],
 )
 async def abort_generate(data: AbortRequest) -> AbortResponse:
+    """Aborts a generation from the cache."""
     response = await abort_generation(data.genkey)

     return response
@@ -76,19 +83,24 @@ async def abort_generate(data: AbortRequest) -> AbortResponse:
 @extra_kai_router.get(
     "/generate/check",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    tags=[Tags.Kobold],
 )
 @extra_kai_router.post(
     "/generate/check",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    tags=[Tags.Kobold],
 )
 async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
+    """Fetches the status of a generation from the cache."""
     response = await generation_status(data.genkey)

     return response


 @kai_router.get(
-    "/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
+    "/model",
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    tags=[Tags.Kobold],
 )
 async def current_model() -> CurrentModelResponse:
     """Fetches the current model and who owns it."""
@@ -100,8 +112,10 @@ async def current_model() -> CurrentModelResponse:
 @extra_kai_router.post(
     "/tokencount",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    tags=[Tags.Kobold],
 )
 async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
+    """Get the number of tokens in a given prompt."""
     raw_tokens = model.container.encode_tokens(data.prompt)
     tokens = unwrap(raw_tokens, [])
     return TokenCountResponse(value=len(tokens), ids=tokens)
@@ -110,14 +124,17 @@ async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
 @kai_router.get(
     "/config/max_length",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    tags=[Tags.Kobold],
 )
 @kai_router.get(
     "/config/max_context_length",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    tags=[Tags.Kobold],
 )
 @extra_kai_router.get(
     "/true_max_context_length",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    tags=[Tags.Kobold],
 )
 async def get_max_length() -> MaxLengthResponse:
     """Fetches the max length of the model."""
@@ -126,35 +143,35 @@ async def get_max_length() -> MaxLengthResponse:
return {"value": max_length}


@kai_router.get("/info/version")
@kai_router.get("/info/version", tags=[Tags.Kobold])
async def get_version():
"""Impersonate KAI United."""

return {"result": "1.2.5"}


@extra_kai_router.get("/version")
@extra_kai_router.get("/version", tags=[Tags.Kobold])
async def get_extra_version():
"""Impersonate Koboldcpp."""

return {"result": "KoboldCpp", "version": "1.71"}


@kai_router.get("/config/soft_prompts_list")
@kai_router.get("/config/soft_prompts_list", tags=[Tags.Kobold])
async def get_available_softprompts():
"""Used for KAI compliance."""

return {"values": []}


@kai_router.get("/config/soft_prompt")
@kai_router.get("/config/soft_prompt", tags=[Tags.Kobold])
async def get_current_softprompt():
"""Used for KAI compliance."""

return {"value": ""}


@kai_router.put("/config/soft_prompt")
@kai_router.put("/config/soft_prompt", tags=[Tags.Kobold])
async def set_current_softprompt():
"""Used for KAI compliance."""

Expand Down
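Both files import Tags from endpoints.core.types.tags, a module that is not part of this diff. A minimal sketch of what it presumably contains, based only on the two members referenced here (the real module may define more tags or different values):

# endpoints/core/types/tags.py -- hypothetical reconstruction, not shown in this PR
from enum import Enum


class Tags(str, Enum):
    """OpenAPI tag names used to group endpoints in the generated docs."""

    # Only these two members appear in the diff; anything else in the
    # real module cannot be inferred from this page.
    Kobold = "Kobold"
    OpenAI = "OpenAI"

Deriving from both str and Enum is the usual FastAPI pattern, since it keeps the tag values JSON-serializable when the OpenAPI schema is rendered.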
endpoints/OAI/router.py (9 additions, 5 deletions)
@@ -25,10 +25,11 @@
     stream_generate_completion,
 )
 from endpoints.OAI.utils.embeddings import get_embeddings
+from endpoints.core.types.tags import Tags


 api_name = "OAI"
-router = APIRouter()
+router = APIRouter(tags=[Tags.OpenAI])
 urls = {
     "Completions": "http://{host}:{port}/v1/completions",
     "Chat completions": "http://{host}:{port}/v1/chat/completions",
@@ -41,8 +42,7 @@ def setup():

 # Completions endpoint
 @router.post(
-    "/v1/completions",
-    dependencies=[Depends(check_api_key)],
+    "/v1/completions", dependencies=[Depends(check_api_key)], tags=[Tags.OpenAI]
 )
 async def completion_request(
     request: Request, data: CompletionRequest
@@ -96,8 +96,7 @@ async def completion_request(

 # Chat completions endpoint
 @router.post(
-    "/v1/chat/completions",
-    dependencies=[Depends(check_api_key)],
+    "/v1/chat/completions", dependencies=[Depends(check_api_key)], tags=[Tags.OpenAI]
 )
 async def chat_completion_request(
     request: Request, data: ChatCompletionRequest
@@ -156,8 +155,13 @@ async def chat_completion_request(
 @router.post(
     "/v1/embeddings",
     dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
+    tags=[Tags.OpenAI],
 )
 async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
+    """Generate Text embeddings for a given text input.
+
+    Requires Infinity embed to be installed and an embedding model to be loaded.
+    """
     embeddings_task = asyncio.create_task(get_embeddings(data, request))
     response = await run_with_request_disconnect(
         request,
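For context on what the tags= argument does: FastAPI merges router-level tags into each operation's tag list in the OpenAPI schema, and Swagger UI groups endpoints by those tags. A self-contained sketch of that behavior (none of this code is from the PR; the Tags enum shape is the same assumption as above):

from enum import Enum

from fastapi import APIRouter, FastAPI


class Tags(str, Enum):  # assumed shape, mirroring the enum used in the diff
    Kobold = "Kobold"


router = APIRouter(prefix="/api", tags=[Tags.Kobold])


@router.get("/v1/model")
async def current_model():
    """Fetches the current model and who owns it."""
    return {"result": "example"}


app = FastAPI()
app.include_router(router)

# Router-level tags are inherited by every operation on the router:
schema = app.openapi()
assert schema["paths"]["/api/v1/model"]["get"]["tags"] == ["Kobold"]

Because router-level and route-level tags are concatenated, the per-route tags=[Tags.Kobold] and tags=[Tags.OpenAI] entries in this PR restate what the routers already declare at construction time.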