Skip to content

Commit

Permalink
Add multiple images support (#2478)
Browse files Browse the repository at this point in the history
* Add multiple images support

* Add multiple images support in gui

* Support multiple images in legacy client and in the api
Fix some model names in provider model list

* Fix unittests

* Add vision and providers docs
  • Loading branch information
hlohaus authored Dec 13, 2024
1 parent bb9132b commit 335c971
Show file tree
Hide file tree
Showing 26 changed files with 1,008 additions and 326 deletions.
575 changes: 575 additions & 0 deletions docs/providers.md

Large diffs are not rendered by default.

83 changes: 83 additions & 0 deletions docs/vision.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
## Vision Support in Chat Completion

This document shows how to use vision support (image input) in chat completions, through either the HTTP API or the Python client. Each approach is illustrated with an example.

### Example with the API

To use vision support in chat completion with the API, follow the example below:

```python
import requests
import json
from g4f.image import to_data_uri
from g4f.requests.raise_for_status import raise_for_status

url = "http://localhost:8080/v1/chat/completions"
body = {
"model": "",
"provider": "Copilot",
"messages": [
{"role": "user", "content": "what are on this image?"}
],
"images": [
["data:image/jpeg;base64,...", "cat.jpeg"]
]
}
response = requests.post(url, json=body, headers={"g4f-api-key": "secret"})
raise_for_status(response)
print(response.json())
```

In this example:
- `url` is the endpoint for the chat completion API.
- `body` contains the model, provider, messages, and images.
- `messages` is a list of message objects with roles and content.
- `images` is a list of `[image, filename]` pairs, where the image is a string in Data URI format and the filename is optional.
- `response` stores the API response.

### Example with the Client

To use vision support in chat completion with the client, follow the example below:

```python
import g4f
import g4f.Provider

def chat_completion(prompt):
client = g4f.Client(provider=g4f.Provider.Blackbox)
images = [
[open("docs/images/waterfall.jpeg", "rb"), "waterfall.jpeg"],
[open("docs/images/cat.webp", "rb"), "cat.webp"]
]
response = client.chat.completions.create([{"content": prompt, "role": "user"}], "", images=images)
print(response.choices[0].message.content)

prompt = "what are on this images?"
chat_completion(prompt)
```

Example output:

```
**Image 1**
* A waterfall with a rainbow
* Lush greenery surrounding the waterfall
* A stream flowing from the waterfall
**Image 2**
* A white cat with blue eyes
* A bird perched on a window sill
* Sunlight streaming through the window
```

In this example:
- `client` initializes a new client with the specified provider.
- `images` is a list of `[image, filename]` pairs; here each image is a file opened in binary mode, and the filename is optional.
- `response` stores the response from the client.
- The `chat_completion` function prints the chat completion output.

### Notes

- Multiple images can be sent in a single request. Each image is given as a pair: the image data (in Data URI format when using the API) and an optional filename.
- The client supports bytes, IO objects, and PIL images as input.
- Ensure you use a provider that supports vision and multiple images.
141 changes: 85 additions & 56 deletions etc/tool/readme_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from g4f import models, ChatCompletion
from g4f.providers.types import BaseRetryProvider, ProviderType
from etc.testing._providers import get_providers
from g4f.providers.base_provider import ProviderModelMixin
from g4f.Provider import __providers__
from g4f.models import _all_models
from g4f import debug

debug.logging = True
Expand Down Expand Up @@ -35,53 +37,76 @@ def test_async_list(providers: list[ProviderType]):
return responses

def print_providers():

providers = get_providers()
providers = [provider for provider in __providers__ if provider.working]
responses = test_async_list(providers)

for type in ("GPT-4", "GPT-3.5", "Other"):
lines = [
lines = []
for type in ("Free", "Auth"):
lines += [
"",
f"### {type}",
f"## {type}",
"",
"| Website | Provider | GPT-3.5 | GPT-4 | Stream | Status | Auth |",
"| ------ | ------- | ------- | ----- | ------ | ------ | ---- |",
]
for is_working in (True, False):
for idx, _provider in enumerate(providers):
if is_working != _provider.working:
continue
do_continue = False
if type == "GPT-4" and _provider.supports_gpt_4:
do_continue = True
elif type == "GPT-3.5" and not _provider.supports_gpt_4 and _provider.supports_gpt_35_turbo:
do_continue = True
elif type == "Other" and not _provider.supports_gpt_4 and not _provider.supports_gpt_35_turbo:
do_continue = True
if not do_continue:
continue
for idx, _provider in enumerate(providers):
do_continue = False
if type == "Auth" and _provider.needs_auth:
do_continue = True
elif type == "Free" and not _provider.needs_auth:
do_continue = True
if not do_continue:
continue

lines.append(
f"### {getattr(_provider, 'label', _provider.__name__)}",
)
provider_name = f"`g4f.Provider.{_provider.__name__}`"
lines.append(f"| Provider | {provider_name} |")
lines.append("| -------- | ---- |")

if _provider.url:
netloc = urlparse(_provider.url).netloc.replace("www.", "")
website = f"[{netloc}]({_provider.url})"

provider_name = f"`g4f.Provider.{_provider.__name__}`"

has_gpt_35 = "✔️" if _provider.supports_gpt_35_turbo else "❌"
has_gpt_4 = "✔️" if _provider.supports_gpt_4 else "❌"
stream = "✔️" if _provider.supports_stream else "❌"
if _provider.working:
else:
website = "❌"

message_history = "✔️" if _provider.supports_message_history else "❌"
system = "✔️" if _provider.supports_system_message else "❌"
stream = "✔️" if _provider.supports_stream else "❌"
if _provider.working:
status = '![Active](https://img.shields.io/badge/Active-brightgreen)'
if responses[idx]:
status = '![Active](https://img.shields.io/badge/Active-brightgreen)'
if responses[idx]:
status = '![Active](https://img.shields.io/badge/Active-brightgreen)'
else:
status = '![Unknown](https://img.shields.io/badge/Unknown-grey)'
else:
status = '![Inactive](https://img.shields.io/badge/Inactive-red)'
auth = "✔️" if _provider.needs_auth else "❌"

lines.append(
f"| {website} | {provider_name} | {has_gpt_35} | {has_gpt_4} | {stream} | {status} | {auth} |"
)
print("\n".join(lines))
status = '![Unknown](https://img.shields.io/badge/Unknown-grey)'
else:
status = '![Inactive](https://img.shields.io/badge/Inactive-red)'
auth = "✔️" if _provider.needs_auth else "❌"

lines.append(f"| **Website** | {website} | \n| **Status** | {status} |")

if issubclass(_provider, ProviderModelMixin):
try:
all_models = _provider.get_models()
models = [model for model in _all_models if model in all_models or model in _provider.model_aliases]
image_models = _provider.image_models
if image_models:
for alias, name in _provider.model_aliases.items():
if alias in _all_models and name in image_models:
image_models.append(alias)
image_models = [model for model in image_models if model in _all_models]
if image_models:
models = [model for model in models if model not in image_models]
if models:
lines.append(f"| **Models** | {', '.join(models)} ({len(all_models)})|")
if image_models:
lines.append(f"| **Image Models (Image Generation)** | {', '.join(image_models)} |")
if hasattr(_provider, "vision_models"):
lines.append(f"| **Vision (Image Upload)** | ✔️ |")
except:
pass

lines.append(f"| **Authentication** | {auth} | \n| **Streaming** | {stream} |")
lines.append(f"| **System message** | {system} | \n| **Message history** | {message_history} |")
return lines

def print_models():
base_provider_names = {
Expand Down Expand Up @@ -123,30 +148,34 @@ def print_models():

lines.append(f"| {name} | {base_provider} | {provider_name} | {website} |")

print("\n".join(lines))
return lines

def print_image_models():
lines = [
"| Label | Provider | Image Model | Vision Model | Website |",
"| ----- | -------- | ----------- | ------------ | ------- |",
]
from g4f.gui.server.api import Api
for image_model in Api.get_image_models():
provider_url = image_model["url"]
for provider in [provider for provider in __providers__ if provider.working and getattr(provider, "image_models", None) or getattr(provider, "vision_models", None)]:
provider_url = provider.url if provider.url else "❌"
netloc = urlparse(provider_url).netloc.replace("www.", "")
website = f"[{netloc}]({provider_url})"
label = image_model["provider"] if image_model["label"] is None else image_model["label"]
if image_model["image_model"] is None:
image_model["image_model"] = "❌"
if image_model["vision_model"] is None:
image_model["vision_model"] = "❌"
lines.append(f'| {label} | `g4f.Provider.{image_model["provider"]}` | {image_model["image_model"]}| {image_model["vision_model"]} | {website} |')
label = getattr(provider, "label", provider.__name__)
if provider.image_models:
image_models = ", ".join([model for model in provider.image_models if model in _all_models])
else:
image_models = "❌"
if hasattr(provider, "vision_models"):
vision_models = "✔️"
else:
vision_models = "❌"
lines.append(f'| {label} | `g4f.Provider.{provider.__name__}` | {image_models}| {vision_models} | {website} |')

print("\n".join(lines))
return lines

if __name__ == "__main__":
#print_providers()
#print("\n", "-" * 50, "\n")
#print_models()
print("\n", "-" * 50, "\n")
print_image_models()
with open("docs/providers.md", "w") as f:
f.write("\n".join(print_providers()))
f.write(f"\n{'-' * 50} \n")
#f.write("\n".join(print_models()))
#f.write(f"\n{'-' * 50} \n")
f.write("\n".join(print_image_models()))
18 changes: 12 additions & 6 deletions etc/unittest/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from g4f.models import __models__
from g4f.providers.base_provider import BaseProvider, ProviderModelMixin
from g4f.models import Model
from g4f.errors import MissingRequirementsError, MissingAuthError

class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
cache: dict = {}
Expand All @@ -13,11 +13,17 @@ async def test_provider_has_model(self):
for model, providers in __models__.values():
for provider in providers:
if issubclass(provider, ProviderModelMixin):
if model.name not in provider.model_aliases:
await asyncio.wait_for(self.provider_has_model(provider, model), 10)
if model.name in provider.model_aliases:
model_name = provider.model_aliases[model.name]
else:
model_name = model.name
await asyncio.wait_for(self.provider_has_model(provider, model_name), 10)

async def provider_has_model(self, provider: Type[BaseProvider], model: Model):
async def provider_has_model(self, provider: Type[BaseProvider], model: str):
if provider.__name__ not in self.cache:
self.cache[provider.__name__] = provider.get_models()
try:
self.cache[provider.__name__] = provider.get_models()
except (MissingRequirementsError, MissingAuthError):
return
if self.cache[provider.__name__]:
self.assertIn(model.name, self.cache[provider.__name__], provider.__name__)
self.assertIn(model, self.cache[provider.__name__], provider.__name__)
8 changes: 4 additions & 4 deletions g4f/Provider/Blackbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import json
from pathlib import Path

from ..typing import AsyncResult, Messages, ImageType
from ..typing import AsyncResult, Messages, ImagesType
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import ImageResponse, to_data_uri
from ..cookies import get_cookies_dir
Expand Down Expand Up @@ -197,8 +197,7 @@ async def create_async_generator(
prompt: str = None,
proxy: str = None,
web_search: bool = False,
image: ImageType = None,
image_name: str = None,
images: ImagesType = None,
top_p: float = 0.9,
temperature: float = 0.5,
max_tokens: int = 1024,
Expand All @@ -212,13 +211,14 @@ async def create_async_generator(

messages = [{"id": message_id, "content": formatted_message, "role": "user"}]

if image is not None:
if images is not None:
messages[-1]['data'] = {
"imagesData": [
{
"filePath": f"MultipleFiles/{image_name}",
"contents": to_data_uri(image)
}
for image, image_name in images
],
"fileText": "",
"title": ""
Expand Down
12 changes: 7 additions & 5 deletions g4f/Provider/Blackbox2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
import asyncio
from aiohttp import ClientSession
from typing import Union, AsyncGenerator
from typing import AsyncGenerator

from ..typing import AsyncResult, Messages
from ..image import ImageResponse
Expand Down Expand Up @@ -37,12 +37,15 @@ async def create_async_generator(
max_retries: int = 3,
delay: int = 1,
**kwargs
) -> AsyncGenerator:
) -> AsyncResult:
if not model:
model = cls.default_model
if model in cls.chat_models:
async for result in cls._generate_text(model, messages, proxy, max_retries, delay):
yield result
elif model in cls.image_models:
async for result in cls._generate_image(model, messages, proxy):
prompt = messages[-1]["content"] if prompt is None else prompt
async for result in cls._generate_image(model, prompt, proxy):
yield result
else:
raise ValueError(f"Unsupported model: {model}")
Expand Down Expand Up @@ -87,14 +90,13 @@ async def _generate_text(
async def _generate_image(
cls,
model: str,
messages: Messages,
prompt: str,
proxy: str = None
) -> AsyncGenerator:
headers = cls._get_headers()
api_endpoint = cls.api_endpoints[model]

async with ClientSession(headers=headers) as session:
prompt = messages[-1]["content"]
data = {
"query": prompt
}
Expand Down
Loading

0 comments on commit 335c971

Please sign in to comment.