Skip to content

Commit

Permalink
Add orjson and use ORJSONResponse in place of JSONResponse (#1688)
Browse files: browse the repository at this point in the history
  • Loading branch information
michaelfeil authored Oct 17, 2024
1 parent ecb8bad commit b0facb3
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [

[project.optional-dependencies]
runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular",
"packaging", "pillow", "psutil", "pydantic", "python-multipart",
"orjson", "packaging", "pillow", "psutil", "pydantic", "python-multipart",
"torchao", "uvicorn", "uvloop", "zmq",
"outlines>=0.0.44", "modelscope"]
# xpu is not enabled in public vllm and torch whl,
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/openai_api/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing import Dict, List

from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import ORJSONResponse, StreamingResponse
from pydantic import ValidationError

try:
Expand Down Expand Up @@ -101,7 +101,7 @@ def create_error_response(
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
):
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
return JSONResponse(content=error.model_dump(), status_code=error.code)
return ORJSONResponse(content=error.model_dump(), status_code=error.code)


def create_streaming_error_response(
Expand Down
12 changes: 6 additions & 6 deletions python/sglang/srt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import uvloop
from fastapi import FastAPI, File, Form, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from fastapi.responses import ORJSONResponse, Response, StreamingResponse

from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.hf_transformers_utils import get_tokenizer
Expand Down Expand Up @@ -176,12 +176,12 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
success, message = await tokenizer_manager.update_weights(obj, request)
content = {"success": success, "message": message}
if success:
return JSONResponse(
return ORJSONResponse(
content,
status_code=HTTPStatus.OK,
)
else:
return JSONResponse(
return ORJSONResponse(
content,
status_code=HTTPStatus.BAD_REQUEST,
)
Expand Down Expand Up @@ -211,7 +211,7 @@ async def stream_results():
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return JSONResponse(
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)

Expand All @@ -226,7 +226,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return JSONResponse(
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)

Expand All @@ -241,7 +241,7 @@ async def judge_request(obj: RewardReqInput, request: Request):
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
except ValueError as e:
return JSONResponse(
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)

Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import requests
import torch
import torch.distributed as dist
from fastapi.responses import JSONResponse
from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function
Expand Down Expand Up @@ -566,7 +566,7 @@ async def authentication(request, call_next):
if request.url.path.startswith("/health"):
return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + api_key:
return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
return await call_next(request)


Expand Down

0 comments on commit b0facb3

Please sign in to comment.