diff --git a/python/pyproject.toml b/python/pyproject.toml index 116ba9c9849..800ce0837e2 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ [project.optional-dependencies] runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", - "packaging", "pillow", "psutil", "pydantic", "python-multipart", + "orjson", "packaging", "pillow", "psutil", "pydantic", "python-multipart", "torchao", "uvicorn", "uvloop", "zmq", "outlines>=0.0.44", "modelscope"] # xpu is not enabled in public vllm and torch whl, diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 4e70546dfdd..b4727dfd7b6 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -25,7 +25,7 @@ from typing import Dict, List from fastapi import HTTPException, Request, UploadFile -from fastapi.responses import JSONResponse, StreamingResponse +from fastapi.responses import ORJSONResponse, StreamingResponse from pydantic import ValidationError try: @@ -101,7 +101,7 @@ def create_error_response( status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, ): error = ErrorResponse(message=message, type=err_type, code=status_code.value) - return JSONResponse(content=error.model_dump(), status_code=error.code) + return ORJSONResponse(content=error.model_dump(), status_code=error.code) def create_streaming_error_response( diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index f0b7abbe3e9..8f851c757f1 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -40,7 +40,7 @@ import uvloop from fastapi import FastAPI, File, Form, Request, UploadFile from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, Response, StreamingResponse +from fastapi.responses import ORJSONResponse, Response, StreamingResponse from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.srt.hf_transformers_utils import get_tokenizer @@ -176,12 +176,12 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request): success, message = await tokenizer_manager.update_weights(obj, request) content = {"success": success, "message": message} if success: - return JSONResponse( + return ORJSONResponse( content, status_code=HTTPStatus.OK, ) else: - return JSONResponse( + return ORJSONResponse( content, status_code=HTTPStatus.BAD_REQUEST, ) @@ -211,7 +211,7 @@ async def stream_results(): ret = await tokenizer_manager.generate_request(obj, request).__anext__() return ret except ValueError as e: - return JSONResponse( + return ORJSONResponse( {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST ) @@ -226,7 +226,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request): ret = await tokenizer_manager.generate_request(obj, request).__anext__() return ret except ValueError as e: - return JSONResponse( + return ORJSONResponse( {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST ) @@ -241,7 +241,7 @@ async def judge_request(obj: RewardReqInput, request: Request): ret = await tokenizer_manager.generate_request(obj, request).__anext__() return ret except ValueError as e: - return JSONResponse( + return ORJSONResponse( {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index c5aecb73938..0f5401c8e9e 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -35,7 +35,7 @@ import requests import torch import torch.distributed as dist -from fastapi.responses import JSONResponse +from fastapi.responses import ORJSONResponse from packaging import version as pkg_version from torch import nn from torch.profiler import ProfilerActivity, profile, record_function @@ -566,7 +566,7 @@ async def authentication(request, call_next): if request.url.path.startswith("/health"): return await call_next(request) if request.headers.get("Authorization") != "Bearer " + api_key: - return JSONResponse(content={"error": "Unauthorized"}, status_code=401) + return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401) return await call_next(request)