From 93203ef46ab8fef35a2c0a7d9e469150f337bc41 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sun, 29 Sep 2024 20:37:38 +0100 Subject: [PATCH 1/3] Add tags to API endpoints --- endpoints/Kobold/router.py | 24 +++++++++--- endpoints/OAI/router.py | 8 ++-- endpoints/core/router.py | 74 ++++++++++++++++++++++++++++-------- endpoints/core/types/tags.py | 13 +++++++ 4 files changed, 94 insertions(+), 25 deletions(-) create mode 100644 endpoints/core/types/tags.py diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py index 310a3809..c8e45bb4 100644 --- a/endpoints/Kobold/router.py +++ b/endpoints/Kobold/router.py @@ -6,6 +6,7 @@ from common.auth import check_api_key from common.model import check_model_container from common.utils import unwrap +from endpoints.core.types.tags import Tags from endpoints.core.utils.model import get_current_model from endpoints.Kobold.types.generation import ( AbortRequest, @@ -46,6 +47,7 @@ def setup(): @kai_router.post( "/generate", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: response = await get_generation(data, request) @@ -56,6 +58,7 @@ async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: @extra_kai_router.post( "/generate/stream", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse: response = EventSourceResponse(stream_generation(data, request), ping=maxsize) @@ -66,6 +69,7 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateRe @extra_kai_router.post( "/abort", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def abort_generate(data: AbortRequest) -> AbortResponse: response = await abort_generation(data.genkey) @@ -76,10 +80,12 @@ async def abort_generate(data: AbortRequest) -> AbortResponse: @extra_kai_router.get( "/generate/check", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) @extra_kai_router.post( "/generate/check", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: response = await generation_status(data.genkey) @@ -88,7 +94,9 @@ async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: @kai_router.get( - "/model", dependencies=[Depends(check_api_key), Depends(check_model_container)] + "/model", + dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def current_model() -> CurrentModelResponse: """Fetches the current model and who owns it.""" @@ -100,6 +108,7 @@ async def current_model() -> CurrentModelResponse: @extra_kai_router.post( "/tokencount", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse: raw_tokens = model.container.encode_tokens(data.prompt) @@ -110,14 +119,17 @@ async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse: @kai_router.get( "/config/max_length", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) @kai_router.get( "/config/max_context_length", dependencies=[Depends(check_api_key), 
Depends(check_model_container)], + tags=[Tags.Kobold], ) @extra_kai_router.get( "/true_max_context_length", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Kobold], ) async def get_max_length() -> MaxLengthResponse: """Fetches the max length of the model.""" @@ -126,35 +138,35 @@ async def get_max_length() -> MaxLengthResponse: return {"value": max_length} -@kai_router.get("/info/version") +@kai_router.get("/info/version", tags=[Tags.Kobold]) async def get_version(): """Impersonate KAI United.""" return {"result": "1.2.5"} -@extra_kai_router.get("/version") +@extra_kai_router.get("/version", tags=[Tags.Kobold]) async def get_extra_version(): """Impersonate Koboldcpp.""" return {"result": "KoboldCpp", "version": "1.71"} -@kai_router.get("/config/soft_prompts_list") +@kai_router.get("/config/soft_prompts_list", tags=[Tags.Kobold]) async def get_available_softprompts(): """Used for KAI compliance.""" return {"values": []} -@kai_router.get("/config/soft_prompt") +@kai_router.get("/config/soft_prompt", tags=[Tags.Kobold]) async def get_current_softprompt(): """Used for KAI compliance.""" return {"value": ""} -@kai_router.put("/config/soft_prompt") +@kai_router.put("/config/soft_prompt", tags=[Tags.Kobold]) async def set_current_softprompt(): """Used for KAI compliance.""" diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index b6a44c98..094aab62 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -25,6 +25,7 @@ stream_generate_completion, ) from endpoints.OAI.utils.embeddings import get_embeddings +from endpoints.core.types.tags import Tags api_name = "OAI" @@ -41,8 +42,7 @@ def setup(): # Completions endpoint @router.post( - "/v1/completions", - dependencies=[Depends(check_api_key)], + "/v1/completions", dependencies=[Depends(check_api_key)], tags=[Tags.OpenAI] ) async def completion_request( request: Request, data: CompletionRequest @@ -96,8 +96,7 @@ async def completion_request( # Chat completions endpoint @router.post( - "/v1/chat/completions", - dependencies=[Depends(check_api_key)], + "/v1/chat/completions", dependencies=[Depends(check_api_key)], tags=[Tags.OpenAI] ) async def chat_completion_request( request: Request, data: ChatCompletionRequest @@ -156,6 +155,7 @@ async def chat_completion_request( @router.post( "/v1/embeddings", dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], + tags=[Tags.OpenAI], ) async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse: embeddings_task = asyncio.create_task(get_embeddings(data, request)) diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 2c60cd77..deb17661 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -28,6 +28,7 @@ SamplerOverrideListResponse, SamplerOverrideSwitchRequest, ) +from endpoints.core.types.tags import Tags from endpoints.core.types.template import TemplateList, TemplateSwitchRequest from endpoints.core.types.token import ( TokenDecodeRequest, @@ -48,7 +49,7 @@ # Healthcheck endpoint -@router.get("/health") +@router.get("/health", tags=[Tags.Core]) async def healthcheck(response: Response) -> HealthCheckResponse: """Get the current service health status""" healthy, issues = await HealthManager.is_service_healthy() @@ -62,8 +63,12 @@ async def healthcheck(response: Response) -> HealthCheckResponse: # Model list endpoint -@router.get("/v1/models", dependencies=[Depends(check_api_key)]) -@router.get("/v1/model/list", dependencies=[Depends(check_api_key)]) +@router.get( + 
"/v1/models", + dependencies=[Depends(check_api_key)], + tags=[Tags.OpenAI, Tags.List], +) +@router.get("/v1/model/list", dependencies=[Depends(check_api_key)], tags=[Tags.List]) async def list_models(request: Request) -> ModelList: """ Lists all models in the model directory. @@ -91,6 +96,7 @@ async def list_models(request: Request) -> ModelList: @router.get( "/v1/model", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.List], ) async def current_model() -> ModelCard: """Returns the currently loaded model.""" @@ -98,7 +104,11 @@ async def current_model() -> ModelCard: return get_current_model() -@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/model/draft/list", + dependencies=[Depends(check_api_key)], + tags=[Tags.List], +) async def list_draft_models(request: Request) -> ModelList: """ Lists all draft models in the model directory. @@ -118,7 +128,9 @@ async def list_draft_models(request: Request) -> ModelList: # Load model endpoint -@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) +@router.post( + "/v1/model/load", dependencies=[Depends(check_admin_key)], tags=[Tags.Admin] +) async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: """Loads a model into the model container. This returns an SSE stream.""" @@ -163,13 +175,14 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse: @router.post( "/v1/model/unload", dependencies=[Depends(check_admin_key), Depends(check_model_container)], + tags=[Tags.Admin], ) async def unload_model(): """Unloads the currently loaded model.""" await model.unload_model(skip_wait=True) -@router.post("/v1/download", dependencies=[Depends(check_admin_key)]) +@router.post("/v1/download", dependencies=[Depends(check_admin_key)], tags=[Tags.Admin]) async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse: """Downloads a model from HuggingFace.""" @@ -191,8 +204,8 @@ async def download_model(request: Request, data: DownloadRequest) -> DownloadRes # Lora list endpoint -@router.get("/v1/loras", dependencies=[Depends(check_api_key)]) -@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) +@router.get("/v1/loras", dependencies=[Depends(check_api_key)], tags=[Tags.List]) +@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)], tags=[Tags.List]) async def list_all_loras(request: Request) -> LoraList: """ Lists all LoRAs in the lora directory. 
@@ -213,6 +226,7 @@ async def list_all_loras(request: Request) -> LoraList: @router.get( "/v1/lora", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.List], ) async def active_loras() -> LoraList: """Returns the currently loaded loras.""" @@ -224,6 +238,7 @@ async def active_loras() -> LoraList: @router.post( "/v1/lora/load", dependencies=[Depends(check_admin_key), Depends(check_model_container)], + tags=[Tags.Admin], ) async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse: """Loads a LoRA into the model container.""" @@ -259,6 +274,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse: @router.post( "/v1/lora/unload", dependencies=[Depends(check_admin_key), Depends(check_model_container)], + tags=[Tags.Admin], ) async def unload_loras(): """Unloads the currently loaded loras.""" @@ -266,7 +282,11 @@ async def unload_loras(): await model.unload_loras() -@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/model/embedding/list", + dependencies=[Depends(check_api_key)], + tags=[Tags.List], +) async def list_embedding_models(request: Request) -> ModelList: """ Lists all embedding models in the model directory. @@ -288,6 +308,7 @@ async def list_embedding_models(request: Request) -> ModelList: @router.get( "/v1/model/embedding", dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], + tags=[Tags.List], ) async def get_embedding_model() -> ModelCard: """Returns the currently loaded embedding model.""" @@ -296,7 +317,11 @@ async def get_embedding_model() -> ModelCard: return models.data[0] -@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)]) +@router.post( + "/v1/model/embedding/load", + dependencies=[Depends(check_admin_key)], + tags=[Tags.Admin], +) async def load_embedding_model( request: Request, data: EmbeddingModelLoadRequest ) -> ModelLoadResponse: @@ -343,6 +368,7 @@ async def load_embedding_model( @router.post( "/v1/model/embedding/unload", dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)], + tags=[Tags.Admin], ) async def unload_embedding_model(): """Unloads the current embedding model.""" @@ -354,6 +380,7 @@ async def unload_embedding_model(): @router.post( "/v1/token/encode", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Tokenisation], ) async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: """Encodes a string or chat completion messages into tokens.""" @@ -384,6 +411,7 @@ async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: @router.post( "/v1/token/decode", dependencies=[Depends(check_api_key), Depends(check_model_container)], + tags=[Tags.Tokenisation], ) async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: """Decodes tokens into a string.""" @@ -394,7 +422,9 @@ async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: return response -@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/auth/permission", dependencies=[Depends(check_api_key)], tags=[Tags.Auth] +) async def key_permission(request: Request) -> AuthPermissionResponse: """ Gets the access level/permission of a provided key in headers. 
@@ -414,8 +444,10 @@ async def key_permission(request: Request) -> AuthPermissionResponse: raise HTTPException(400, error_message) from exc -@router.get("/v1/templates", dependencies=[Depends(check_api_key)]) -@router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) +@router.get("/v1/templates", dependencies=[Depends(check_api_key)], tags=[Tags.List]) +@router.get( + "/v1/template/list", dependencies=[Depends(check_api_key)], tags=[Tags.List] +) async def list_templates(request: Request) -> TemplateList: """ Get a list of all templates. @@ -437,6 +469,7 @@ async def list_templates(request: Request) -> TemplateList: @router.post( "/v1/template/switch", dependencies=[Depends(check_admin_key), Depends(check_model_container)], + tags=[Tags.Admin], ) async def switch_template(data: TemplateSwitchRequest): """Switch the currently loaded template.""" @@ -464,6 +497,7 @@ async def switch_template(data: TemplateSwitchRequest): @router.post( "/v1/template/unload", dependencies=[Depends(check_admin_key), Depends(check_model_container)], + tags=[Tags.Admin], ) async def unload_template(): """Unloads the currently selected template""" @@ -472,8 +506,16 @@ async def unload_template(): # Sampler override endpoints -@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) -@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) +@router.get( + "/v1/sampling/overrides", + dependencies=[Depends(check_api_key)], + tags=[Tags.List], +) +@router.get( + "/v1/sampling/override/list", + dependencies=[Depends(check_api_key)], + tags=[Tags.List], +) async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse: """ List all currently applied sampler overrides. @@ -494,6 +536,7 @@ async def list_sampler_overrides(request: Request) -> SamplerOverrideListRespons @router.post( "/v1/sampling/override/switch", dependencies=[Depends(check_admin_key)], + tags=[Tags.Admin], ) async def switch_sampler_override(data: SamplerOverrideSwitchRequest): """Switch the currently loaded override preset""" @@ -523,6 +566,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest): @router.post( "/v1/sampling/override/unload", dependencies=[Depends(check_admin_key)], + tags=[Tags.Admin], ) async def unload_sampler_override(): """Unloads the currently selected override preset""" diff --git a/endpoints/core/types/tags.py b/endpoints/core/types/tags.py new file mode 100644 index 00000000..0d785885 --- /dev/null +++ b/endpoints/core/types/tags.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class Tags(str, Enum): + """OpenAPI endpoint groups.""" + + OpenAI = "OpenAI" + Kobold = "Kobold" + Admin = "Admin" + List = "List" + Tokenisation = "Tokenisation" + Core = "Core" + Auth = "Auth" From 3e1ef555dc69422fd608ac43999b22a471a7eaad Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sun, 29 Sep 2024 21:05:29 +0100 Subject: [PATCH 2/3] add docstrings to all endpoints --- endpoints/Kobold/router.py | 11 ++++++++--- endpoints/OAI/router.py | 6 +++++- endpoints/core/router.py | 2 ++ 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py index c8e45bb4..9d4c5f58 100644 --- a/endpoints/Kobold/router.py +++ b/endpoints/Kobold/router.py @@ -26,14 +26,14 @@ api_name = "KoboldAI" -router = APIRouter(prefix="/api") +router = APIRouter(prefix="/api", tags=[Tags.Kobold]) urls = { "Generation": "http://{host}:{port}/api/v1/generate", "Streaming":
"http://{host}:{port}/api/extra/generate/stream", } -kai_router = APIRouter() -extra_kai_router = APIRouter() +kai_router = APIRouter(tags=[Tags.Kobold]) +extra_kai_router = APIRouter(tags=[Tags.Kobold]) def setup(): @@ -50,6 +50,7 @@ def setup(): tags=[Tags.Kobold], ) async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: + """Generate a response to a prompt.""" response = await get_generation(data, request) return response @@ -61,6 +62,7 @@ async def generate(request: Request, data: GenerateRequest) -> GenerateResponse: tags=[Tags.Kobold], ) async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse: + """Stream the chat response to a prompt.""" response = EventSourceResponse(stream_generation(data, request), ping=maxsize) return response @@ -72,6 +74,7 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateRe tags=[Tags.Kobold], ) async def abort_generate(data: AbortRequest) -> AbortResponse: + """Aborts a generation from the cache.""" response = await abort_generation(data.genkey) return response @@ -88,6 +91,7 @@ async def abort_generate(data: AbortRequest) -> AbortResponse: tags=[Tags.Kobold], ) async def check_generate(data: CheckGenerateRequest) -> GenerateResponse: + """Fetches the status of a generation from the cache.""" response = await generation_status(data.genkey) return response @@ -111,6 +115,7 @@ async def current_model() -> CurrentModelResponse: tags=[Tags.Kobold], ) async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse: + """Get the number of tokens in a given prompt.""" raw_tokens = model.container.encode_tokens(data.prompt) tokens = unwrap(raw_tokens, []) return TokenCountResponse(value=len(tokens), ids=tokens) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 094aab62..2143c173 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -29,7 +29,7 @@ api_name = "OAI" -router = APIRouter() +router = APIRouter(tags=[Tags.OpenAI]) urls = { "Completions": "http://{host}:{port}/v1/completions", "Chat completions": "http://{host}:{port}/v1/chat/completions", @@ -158,6 +158,10 @@ async def chat_completion_request( tags=[Tags.OpenAI], ) async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse: + """Generate Text embeddings for a given text input. + + Requires Infinity embed to be installed and an embedding model to be loaded. 
+ """ embeddings_task = asyncio.create_task(get_embeddings(data, request)) response = await run_with_request_disconnect( request, diff --git a/endpoints/core/router.py b/endpoints/core/router.py index deb17661..715ba74a 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -325,6 +325,8 @@ async def get_embedding_model() -> ModelCard: async def load_embedding_model( request: Request, data: EmbeddingModelLoadRequest ) -> ModelLoadResponse: + """Loads an embedding model.""" + # Verify request parameters if not data.name: error_message = handle_request_error( From 677d14fee8a50f3dc9114b19755065aef16d895c Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sun, 29 Sep 2024 21:22:00 +0100 Subject: [PATCH 3/3] add description for model fields --- endpoints/core/types/auth.py | 4 +-- endpoints/core/types/download.py | 49 +++++++++++++++++++++++++------- endpoints/core/types/template.py | 2 +- endpoints/core/types/token.py | 22 +++++++------- 4 files changed, 53 insertions(+), 24 deletions(-) diff --git a/endpoints/core/types/auth.py b/endpoints/core/types/auth.py index b8f3aa2e..070ac88e 100644 --- a/endpoints/core/types/auth.py +++ b/endpoints/core/types/auth.py @@ -1,7 +1,7 @@ """Types for auth requests.""" -from pydantic import BaseModel +from pydantic import BaseModel, Field class AuthPermissionResponse(BaseModel): - permission: str + permission: str = Field(description="The permission level of the API key") diff --git a/endpoints/core/types/download.py b/endpoints/core/types/download.py index cf49501f..a5d73702 100644 --- a/endpoints/core/types/download.py +++ b/endpoints/core/types/download.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import List, Optional +from typing import List, Literal, Optional def _generate_include_list(): @@ -9,18 +9,45 @@ def _generate_include_list(): class DownloadRequest(BaseModel): """Parameters for a HuggingFace repo download.""" - repo_id: str - repo_type: str = "model" - folder_name: Optional[str] = None - revision: Optional[str] = None - token: Optional[str] = None - include: List[str] = Field(default_factory=_generate_include_list) - exclude: List[str] = Field(default_factory=list) - chunk_limit: Optional[int] = None - timeout: Optional[int] = None + repo_id: str = Field( + description="The repo ID to download from", + examples=[ + "royallab/TinyLlama-1.1B-2T-exl2", + "royallab/LLaMA2-13B-TiefighterLR-exl2", + "turboderp/Llama-3.1-8B-Instruct-exl2", + ], + ) + repo_type: Literal["model", "lora"] = Field("model", description="The model type") + folder_name: Optional[str] = Field( + default=None, + description="The folder name to save the repo to " + + "(this is used to load the model)", + ) + revision: Optional[str] = Field( + default=None, description="The revision to download from" + ) + token: Optional[str] = Field( + default=None, + description="The HuggingFace API token to use, " + + "required for private/gated repos", + ) + include: List[str] = Field( + default_factory=_generate_include_list, + description="A list of file patterns to include in the download", + ) + exclude: List[str] = Field( + default_factory=list, + description="A list of file patterns to exclude from the download", + ) + chunk_limit: Optional[int] = Field( + None, description="The maximum chunk size to download in bytes" + ) + timeout: Optional[int] = Field( + None, description="The timeout for the download in seconds" + ) class DownloadResponse(BaseModel): """Response for a download 
request.""" - download_path: str + download_path: str = Field(description="The path to the downloaded repo") diff --git a/endpoints/core/types/template.py b/endpoints/core/types/template.py index d72d6210..c0932912 100644 --- a/endpoints/core/types/template.py +++ b/endpoints/core/types/template.py @@ -12,4 +12,4 @@ class TemplateList(BaseModel): class TemplateSwitchRequest(BaseModel): """Request to switch a template.""" - name: str + name: str = Field(description="The name of the template to switch to") diff --git a/endpoints/core/types/token.py b/endpoints/core/types/token.py index 945adbf5..87f3f022 100644 --- a/endpoints/core/types/token.py +++ b/endpoints/core/types/token.py @@ -1,15 +1,17 @@ """Tokenization types""" -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing import Dict, List, Union class CommonTokenRequest(BaseModel): """Represents a common tokenization request.""" - add_bos_token: bool = True - encode_special_tokens: bool = True - decode_special_tokens: bool = True + add_bos_token: bool = Field( + True, description="Add the BOS (beginning of sequence) token" + ) + encode_special_tokens: bool = Field(True, description="Encode special tokens") + decode_special_tokens: bool = Field(True, description="Decode special tokens") def get_params(self): """Get the parameters for tokenization.""" @@ -23,29 +25,29 @@ def get_params(self): class TokenEncodeRequest(CommonTokenRequest): """Represents a tokenization request.""" - text: Union[str, List[Dict[str, str]]] + text: Union[str, List[Dict[str, str]]] = Field(description="The string to encode") class TokenEncodeResponse(BaseModel): """Represents a tokenization response.""" - tokens: List[int] - length: int + tokens: List[int] = Field(description="The tokens") + length: int = Field(description="The length of the tokens") class TokenDecodeRequest(CommonTokenRequest): """ " Represents a detokenization request.""" - tokens: List[int] + tokens: List[int] = Field(description="The string to encode") class TokenDecodeResponse(BaseModel): """Represents a detokenization response.""" - text: str + text: str = Field(description="The decoded text") class TokenCountResponse(BaseModel): """Represents a token count response.""" - length: int + length: int = Field(description="The length of the text")