Skip to content

Commit

Permalink
Arm2 (#2414)
Browse files: browse the repository at this point in the history
* Fix ARM v7 build / improve API

* Update stubs.py

* Fix unit tests
  • Loading branch information
hlohaus authored Nov 24, 2024
1 parent 4744d0b commit 804a80b
Show file tree
Hide file tree
Showing 12 changed files with 255 additions and 226 deletions.
2 changes: 1 addition & 1 deletion docker/Dockerfile-slim
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ RUN python -m pip install --upgrade pip \
--global-option=build_ext \
--global-option=-j8 \
pydantic==${PYDANTIC_VERSION} \
&& cat requirements.txt | xargs -n 1 pip install --no-cache-dir \
&& cat requirements-slim.txt | xargs -n 1 pip install --no-cache-dir || true \
# Remove build packages
&& pip uninstall --yes \
Cython \
Expand Down
3 changes: 1 addition & 2 deletions g4f/Provider/PollinationsAI.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ async def create_async_generator(
seed: str = None,
**kwargs
) -> AsyncResult:
if model:
model = cls.get_model(model)
model = cls.get_model(model)
if model in cls.image_models:
if prompt is None:
prompt = messages[-1]["content"]
Expand Down
1 change: 1 addition & 0 deletions g4f/Provider/needs_auth/Gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def __init__(self,
self.conversation_id = conversation_id
self.response_id = response_id
self.choice_id = choice_id

async def iter_filter_base64(response_iter: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
search_for = b'[["wrb.fr","XqA3Ic","[\\"'
end_with = b'\\'
Expand Down
115 changes: 70 additions & 45 deletions g4f/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,37 @@
import shutil

import os.path
from fastapi import FastAPI, Response, Request, UploadFile
from fastapi import FastAPI, Response, Request, UploadFile, Depends
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
from fastapi.exceptions import RequestValidationError
from fastapi.security import APIKeyHeader
from starlette.exceptions import HTTPException
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from starlette.status import (
HTTP_200_OK,
HTTP_422_UNPROCESSABLE_ENTITY,
HTTP_404_NOT_FOUND,
HTTP_401_UNAUTHORIZED,
HTTP_403_FORBIDDEN
)
from fastapi.encoders import jsonable_encoder
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import FileResponse
from pydantic import BaseModel
from typing import Union, Optional, List
from pydantic import BaseModel, Field
from typing import Union, Optional, List, Annotated

import g4f
import g4f.debug
from g4f.client import AsyncClient, ChatCompletion, convert_to_provider
from g4f.client import AsyncClient, ChatCompletion, ImagesResponse, convert_to_provider
from g4f.providers.response import BaseConversation
from g4f.client.helper import filter_none
from g4f.image import is_accepted_format, images_dir
from g4f.typing import Messages
from g4f.errors import ProviderNotFoundError
from g4f.cookies import read_cookie_files, get_cookies_dir
from g4f.Provider import ProviderType, ProviderUtils, __providers__
from g4f.gui import get_gui_app

logger = logging.getLogger(__name__)

Expand All @@ -50,6 +59,10 @@ def create_app(g4f_api_key: str = None):
api.register_authorization()
api.register_validation_exception_handler()

if AppConfig.gui:
gui_app = WSGIMiddleware(get_gui_app())
app.mount("/", gui_app)

# Read cookie files if not ignored
if not AppConfig.ignore_cookie_files:
read_cookie_files()
Expand All @@ -61,17 +74,17 @@ def create_app_debug(g4f_api_key: str = None):
return create_app(g4f_api_key)

class ChatCompletionsConfig(BaseModel):
messages: Messages
model: str
provider: Optional[str] = None
messages: Messages = Field(examples=[[{"role": "system", "content": ""}, {"role": "user", "content": ""}]])
model: str = Field(default="")
provider: Optional[str] = Field(examples=[None])
stream: bool = False
temperature: Optional[float] = None
max_tokens: Optional[int] = None
stop: Union[list[str], str, None] = None
api_key: Optional[str] = None
web_search: Optional[bool] = None
proxy: Optional[str] = None
conversation_id: str = None
temperature: Optional[float] = Field(examples=[None])
max_tokens: Optional[int] = Field(examples=[None])
stop: Union[list[str], str, None] = Field(examples=[None])
api_key: Optional[str] = Field(examples=[None])
web_search: Optional[bool] = Field(examples=[None])
proxy: Optional[str] = Field(examples=[None])
conversation_id: Optional[str] = Field(examples=[None])

class ImageGenerationConfig(BaseModel):
prompt: str
Expand Down Expand Up @@ -101,6 +114,9 @@ class ModelResponseModel(BaseModel):
created: int
owned_by: Optional[str]

class ErrorResponseModel(BaseModel):
error: str

class AppConfig:
ignored_providers: Optional[list[str]] = None
g4f_api_key: Optional[str] = None
Expand All @@ -109,6 +125,7 @@ class AppConfig:
provider: str = None
image_provider: str = None
proxy: str = None
gui: bool = False

@classmethod
def set_config(cls, **data):
Expand All @@ -129,6 +146,8 @@ def __init__(self, app: FastAPI, g4f_api_key=None) -> None:
self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key")
self.conversations: dict[str, dict[str, BaseConversation]] = {}

security = HTTPBearer(auto_error=False)

def register_authorization(self):
@self.app.middleware("http")
async def authorization(request: Request, call_next):
Expand Down Expand Up @@ -192,7 +211,7 @@ async def models() -> list[ModelResponseModel]:
} for model_id, model in model_list.items()]

@self.app.get("/v1/models/{model_name}")
async def model_info(model_name: str):
async def model_info(model_name: str) -> ModelResponseModel:
if model_name in g4f.models.ModelUtils.convert:
model_info = g4f.models.ModelUtils.convert[model_name]
return JSONResponse({
Expand All @@ -201,20 +220,20 @@ async def model_info(model_name: str):
'created': 0,
'owned_by': model_info.base_provider
})
return JSONResponse({"error": "The model does not exist."}, 404)

@self.app.post("/v1/chat/completions")
async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None):
return JSONResponse({"error": "The model does not exist."}, HTTP_404_NOT_FOUND)

@self.app.post("/v1/chat/completions", response_model=ChatCompletion)
async def chat_completions(
config: ChatCompletionsConfig,
credentials: Annotated[HTTPAuthorizationCredentials, Depends(Api.security)] = None,
provider: str = None
):
try:
config.provider = provider if config.provider is None else config.provider
if config.provider is None:
config.provider = AppConfig.provider
if config.api_key is None and request is not None:
auth_header = request.headers.get("Authorization")
if auth_header is not None:
api_key = auth_header.split(None, 1)[-1]
if api_key and api_key != "Bearer":
config.api_key = api_key
if credentials is not None:
config.api_key = credentials.credentials

conversation = return_conversation = None
if config.conversation_id is not None and config.provider is not None:
Expand Down Expand Up @@ -242,8 +261,7 @@ async def chat_completions(config: ChatCompletionsConfig, request: Request = Non
)

if not config.stream:
response: ChatCompletion = await response
return JSONResponse(response.to_json())
return await response

async def streaming():
try:
Expand All @@ -254,7 +272,7 @@ async def streaming():
self.conversations[config.conversation_id] = {}
self.conversations[config.conversation_id][config.provider] = chunk
else:
yield f"data: {json.dumps(chunk.to_json())}\n\n"
yield f"data: {chunk.json()}\n\n"
except GeneratorExit:
pass
except Exception as e:
Expand All @@ -268,15 +286,15 @@ async def streaming():
logger.exception(e)
return Response(content=format_exception(e, config), status_code=500, media_type="application/json")

@self.app.post("/v1/images/generate")
@self.app.post("/v1/images/generations")
async def generate_image(config: ImageGenerationConfig, request: Request):
if config.api_key is None:
auth_header = request.headers.get("Authorization")
if auth_header is not None:
api_key = auth_header.split(None, 1)[-1]
if api_key and api_key != "Bearer":
config.api_key = api_key
@self.app.post("/v1/images/generate", response_model=ImagesResponse)
@self.app.post("/v1/images/generations", response_model=ImagesResponse)
async def generate_image(
request: Request,
config: ImageGenerationConfig,
credentials: Annotated[HTTPAuthorizationCredentials, Depends(Api.security)] = None
):
if credentials is not None:
config.api_key = credentials.credentials
try:
response = await self.client.images.generate(
prompt=config.prompt,
Expand All @@ -291,7 +309,7 @@ async def generate_image(config: ImageGenerationConfig, request: Request):
for image in response.data:
if hasattr(image, "url") and image.url.startswith("/"):
image.url = f"{request.base_url}{image.url.lstrip('/')}"
return JSONResponse(response.to_json())
return response
except Exception as e:
logger.exception(e)
return Response(content=format_exception(e, config, True), status_code=500, media_type="application/json")
Expand Down Expand Up @@ -342,22 +360,29 @@ def upload_cookies(files: List[UploadFile]):
file.file.close()
return response_data

@self.app.get("/v1/synthesize/{provider}")
@self.app.get("/v1/synthesize/{provider}", responses={
HTTP_200_OK: {"content": {"audio/*": {}}},
HTTP_404_NOT_FOUND: {"model": ErrorResponseModel},
HTTP_422_UNPROCESSABLE_ENTITY: {"model": ErrorResponseModel},
})
async def synthesize(request: Request, provider: str):
try:
provider_handler = convert_to_provider(provider)
except ProviderNotFoundError:
return Response("Provider not found", 404)
return JSONResponse({"error": "Provider not found"}, HTTP_404_NOT_FOUND)
if not hasattr(provider_handler, "synthesize"):
return Response("Provider doesn't support synthesize", 500)
return JSONResponse({"error": "Provider doesn't support synthesize"}, HTTP_404_NOT_FOUND)
if len(request.query_params) == 0:
return Response("Missing query params", 500)
return JSONResponse({"error": "Missing query params"}, HTTP_422_UNPROCESSABLE_ENTITY)
response_data = provider_handler.synthesize({**request.query_params})
content_type = getattr(provider_handler, "synthesize_content_type", "application/octet-stream")
return StreamingResponse(response_data, media_type=content_type)

@self.app.get("/images/{filename}")
async def get_image(filename) -> FileResponse:
@self.app.get("/images/{filename}", response_class=FileResponse, responses={
HTTP_200_OK: {"content": {"image/*": {}}},
HTTP_404_NOT_FOUND: {}
})
async def get_image(filename):
target = os.path.join(images_dir, filename)

if not os.path.isfile(target):
Expand Down
4 changes: 3 additions & 1 deletion g4f/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def main():
api_parser = subparsers.add_parser("api")
api_parser.add_argument("--bind", default="0.0.0.0:1337", help="The bind string.")
api_parser.add_argument("--debug", action="store_true", help="Enable verbose logging.")
api_parser.add_argument("--gui", "-g", default=False, action="store_true", help="Add gui to the api.")
api_parser.add_argument("--model", default=None, help="Default model for chat completion. (incompatible with --reload and --workers)")
api_parser.add_argument("--provider", choices=[provider.__name__ for provider in Provider.__providers__ if provider.working],
default=None, help="Default provider for chat completion. (incompatible with --reload and --workers)")
Expand Down Expand Up @@ -48,7 +49,8 @@ def run_api_args(args):
provider=args.provider,
image_provider=args.image_provider,
proxy=args.proxy,
model=args.model
model=args.model,
gui=args.gui,
)
g4f.cookies.browsers = [g4f.cookies[browser] for browser in args.cookie_browsers]
run_api(
Expand Down
20 changes: 10 additions & 10 deletions g4f/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def iter_response(
finish_reason = "stop"

if stream:
yield ChatCompletionChunk(chunk, None, completion_id, int(time.time()))
yield ChatCompletionChunk.model_construct(chunk, None, completion_id, int(time.time()))

if finish_reason is not None:
break
Expand All @@ -83,12 +83,12 @@ def iter_response(
finish_reason = "stop" if finish_reason is None else finish_reason

if stream:
yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time()))
yield ChatCompletionChunk.model_construct(None, finish_reason, completion_id, int(time.time()))
else:
if response_format is not None and "type" in response_format:
if response_format["type"] == "json_object":
content = filter_json(content)
yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
yield ChatCompletion.model_construct(content, finish_reason, completion_id, int(time.time()))

# Synchronous iter_append_model_and_provider function
def iter_append_model_and_provider(response: ChatCompletionResponseType) -> ChatCompletionResponseType:
Expand Down Expand Up @@ -137,20 +137,20 @@ async def async_iter_response(
finish_reason = "stop"

if stream:
yield ChatCompletionChunk(chunk, None, completion_id, int(time.time()))
yield ChatCompletionChunk.model_construct(chunk, None, completion_id, int(time.time()))

if finish_reason is not None:
break

finish_reason = "stop" if finish_reason is None else finish_reason

if stream:
yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time()))
yield ChatCompletionChunk.model_construct(None, finish_reason, completion_id, int(time.time()))
else:
if response_format is not None and "type" in response_format:
if response_format["type"] == "json_object":
content = filter_json(content)
yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
yield ChatCompletion.model_construct(content, finish_reason, completion_id, int(time.time()))
finally:
if hasattr(response, 'aclose'):
await safe_aclose(response)
Expand Down Expand Up @@ -394,13 +394,13 @@ async def process_image_item(image_file: str) -> Image:
if response_format == "b64_json":
with open(os.path.join(images_dir, os.path.basename(image_file)), "rb") as file:
image_data = base64.b64encode(file.read()).decode()
return Image(url=image_file, b64_json=image_data, revised_prompt=response.alt)
return Image(url=image_file, revised_prompt=response.alt)
return Image.model_construct(url=image_file, b64_json=image_data, revised_prompt=response.alt)
return Image.model_construct(url=image_file, revised_prompt=response.alt)
images = await asyncio.gather(*[process_image_item(image) for image in images])
else:
images = [Image(url=image, revised_prompt=response.alt) for image in response.get_list()]
images = [Image.model_construct(url=image, revised_prompt=response.alt) for image in response.get_list()]
last_provider = get_last_provider(True)
return ImagesResponse(
return ImagesResponse.model_construct(
images,
model=last_provider.get("model") if model is None else model,
provider=last_provider.get("name") if provider is None else provider
Expand Down
Loading

0 comments on commit 804a80b

Please sign in to comment.