Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: DIA-1450: Image Classification support in adala #264

Merged
merged 16 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions adala/runtimes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from .base import Runtime, AsyncRuntime
from ._openai import OpenAIChatRuntime, OpenAIVisionRuntime, AsyncOpenAIChatRuntime
from ._litellm import LiteLLMChatRuntime, AsyncLiteLLMChatRuntime
from ._openai import OpenAIChatRuntime, AsyncOpenAIChatRuntime, AsyncOpenAIVisionRuntime
from ._litellm import (
LiteLLMChatRuntime,
AsyncLiteLLMChatRuntime,
AsyncLiteLLMVisionRuntime,
)
415 changes: 254 additions & 161 deletions adala/runtimes/_litellm.py

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions adala/runtimes/_openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from ._litellm import AsyncLiteLLMChatRuntime, LiteLLMChatRuntime, LiteLLMVisionRuntime
from ._litellm import (
AsyncLiteLLMChatRuntime,
LiteLLMChatRuntime,
AsyncLiteLLMVisionRuntime,
)


# litellm already reads the OPENAI_API_KEY env var, which was the reason for this class
OpenAIChatRuntime = LiteLLMChatRuntime
AsyncOpenAIChatRuntime = AsyncLiteLLMChatRuntime
OpenAIVisionRuntime = LiteLLMVisionRuntime
AsyncOpenAIVisionRuntime = AsyncLiteLLMVisionRuntime
74 changes: 56 additions & 18 deletions adala/skills/collection/label_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@
import pandas as pd
from typing import Type, Iterator, Optional
from functools import cached_property
from collections import defaultdict
from adala.skills._base import TransformSkill
from adala.runtimes import AsyncLiteLLMVisionRuntime
from adala.runtimes._litellm import MessageChunkType
from pydantic import BaseModel, Field, model_validator

from adala.runtimes import Runtime, AsyncRuntime
from adala.utils.internal_data import InternalDataFrame

from label_studio_sdk.label_interface import LabelInterface
from label_studio_sdk.label_interface.control_tags import ControlTag
from label_studio_sdk.label_interface.control_tags import ControlTag, ObjectTag
from label_studio_sdk._extensions.label_studio_tools.core.utils.json_schema import (
json_schema_to_pydantic,
)
Expand All @@ -35,39 +38,60 @@ class LabelStudioSkill(TransformSkill):

# TODO: implement postprocessing to verify Taxonomy

@cached_property
def label_interface(self) -> LabelInterface:
return LabelInterface(self.label_config)

@cached_property
def ner_tags(self) -> Iterator[ControlTag]:
# check if the input config has NER tag (<Labels> + <Text>), and return its `from_name` and `to_name`
interface = LabelInterface(self.label_config)
for tag in interface.controls:
# NOTE: don't need to check object tag because at this point, unusable control tags should have been stripped out of the label config
if tag.tag.lower() == "labels":
control_tag_names = self.allowed_control_tags or list(
self.label_interface._controls.keys()
)
for tag_name in control_tag_names:
tag = self.label_interface.get_control(tag_name)
if tag.tag.lower() in {"labels", "hypertextlabels"}:
yield tag

@cached_property
def image_tags(self) -> Iterator[ObjectTag]:
# check if any image tags are used as input variables
object_tag_names = self.allowed_object_tags or list(
self.label_interface._objects.keys()
)
for tag_name in object_tag_names:
tag = self.label_interface.get_object(tag_name)
if tag.tag.lower() == "image":
yield tag

@model_validator(mode="after")
def validate_response_model(self):

interface = LabelInterface(self.label_config)
logger.debug(f"Read labeling config {self.label_config}")

if self.allowed_control_tags or self.allowed_object_tags:
if self.allowed_control_tags:
control_tags = {
tag: interface._controls[tag] for tag in self.allowed_control_tags
tag: self.label_interface._controls[tag]
for tag in self.allowed_control_tags
}
else:
control_tags = interface._controls
control_tags = self.label_interface._controls
if self.allowed_object_tags:
object_tags = {
tag: interface._objects[tag] for tag in self.allowed_object_tags
tag: self.label_interface._objects[tag]
for tag in self.allowed_object_tags
}
else:
object_tags = interface._objects
object_tags = self.label_interface._objects
interface = LabelInterface.create_instance(
tags={**control_tags, **object_tags}
)
logger.debug(
f"Filtered labeling config based on allowed tags {self.allowed_control_tags=} and {self.allowed_object_tags=} to {interface.config}"
)
else:
interface = self.label_interface

# NOTE: filtered label config is used for the response model, but full label config is used for the prompt, so that the model has as much context as possible.
self.field_schema = interface.to_json_schema()
Expand Down Expand Up @@ -100,14 +124,28 @@ async def aapply(
) -> InternalDataFrame:

with json_schema_to_pydantic(self.field_schema) as ResponseModel:
output = await runtime.batch_to_batch(
input,
input_template=self.input_template,
output_template="",
instructions_template=self.instructions,
response_model=ResponseModel,
)
for ner_tag in self.ner_tags():
# special handling to flag image inputs if they exist
if isinstance(runtime, AsyncLiteLLMVisionRuntime):
input_field_types = defaultdict(lambda: MessageChunkType.TEXT)
for tag in self.image_tags:
input_field_types[tag.name] = MessageChunkType.IMAGE_URL
output = await runtime.batch_to_batch(
input,
input_template=self.input_template,
output_template="",
instructions_template=self.instructions,
response_model=ResponseModel,
input_field_types=input_field_types,
)
else:
output = await runtime.batch_to_batch(
input,
input_template=self.input_template,
output_template="",
instructions_template=self.instructions,
response_model=ResponseModel,
)
for ner_tag in self.ner_tags:
input_field_name = ner_tag.objects[0].value.lstrip("$")
output_field_name = ner_tag.name
quote_string_field_name = "text"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
interactions:
- request:
body: '{"messages": [{"role": "user", "content": "Hey, how''s it going?"}], "model":
"gpt-4o-mini", "max_tokens": 1000, "seed": 47, "temperature": 0.0}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '143'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.47.1
x-stainless-arch:
- x64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- Linux
x-stainless-package-version:
- 1.47.1
x-stainless-raw-response:
- 'true'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.11.5
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAA4xSy27bMBC86yu2vPRiFbbs1o9LUPTS9NBDgz7QIBBociWxobgsuYLjBgb6G/29
fklB2bEUNAV6IcCZncHMkvcZgDBabECoRrJqvc1ff9FXTf3+zXfeVl/ryw/x46d3P1q7M/X+6k5M
koK231Dxg+qFotZbZEPuSKuAkjG5zpbz4uVyvVque6IljTbJas/5gvLWOJMX02KRT5f5bHVSN2QU
RrGB6wwA4L4/U06n8U5sYDp5QFqMUdYoNuchABHIJkTIGE1k6VhMBlKRY3R99MvnLWgyroYdWjsB
bqS7hT11z+At7UBuqeN0vYDPjeTfP39FIJeAAK1xGpi03F+MzQNWXZSpoOusPeGHc1pLtQ+0jSf+
jFfGmdiUAWUkl5JFJi969pAB3PRb6R4VFT5Q67lkukWXDGeLo50Y3mJErk4kE0s74PNi8oRbqZGl
sXG0VaGkalAPyuEJZKcNjYhs1PnvME95H3sbV/+P/UAohZ5Rlz6gNupx4WEsYPqp/xo777gPLOI+
MrZlZVyNwQdz/CeVL+caZ8VqNX21Ftkh+wMAAP//AwDs57wINQMAAA==
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8e85a915891d6208-ORD
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Tue, 26 Nov 2024 00:11:19 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=f2eAWUmcSjgkraa7rJvzhr53.Kz3y7EZniQwAmrWmHg-1732579879-1.0.1.1-FcTG.L1LC0IYeDrJNsA3S_9CqAeK8RVmE9li1oKj8OrrEOFELgjJ.wfKOQqQi8SWUsocl.oe2kGwriII9BVQ5Q;
path=/; expires=Tue, 26-Nov-24 00:41:19 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=5M11WH7821NNRxCf3t86tF5_JSGA0RXiNMeAxl1Pa4A-1732579879834-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
openai-organization:
- heartex
openai-processing-ms:
- '488'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149998994'
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_c89ae189bd037c2fdf4605f19a3115f5
status:
code: 200
message: OK
- request:
body: '{"messages": [{"role": "user", "content": "\n Given
the title of a museum painting:\nIt''s definitely not the Mona Lisa\n and the
image of the painting:\nhttps://upload.wikimedia.org/wikipedia/commons/thumb/e/ec/Mona_Lisa%2C_by_Leonardo_da_Vinci%2C_from_C2RMF_retouched.jpg/687px-Mona_Lisa%2C_by_Leonardo_da_Vinci%2C_from_C2RMF_retouched.jpg\n,\n classify
the painting as either \"Mona Lisa\" or \"Not Mona Lisa\".\n They
may or may not agree with each other. If the title and image disagree, believe
the image.\n "}], "model": "gpt-4o-mini", "max_tokens": 1000,
"seed": 47, "temperature": 0.0, "tool_choice": {"type": "function", "function":
{"name": "MyModel"}}, "tools": [{"type": "function", "function": {"name": "MyModel",
"description": "Correctly extracted `MyModel` with all the required parameters
with correct types", "parameters": {"properties": {"classification": {"description":
"Choices for image", "enum": ["Mona Lisa", "Not Mona Lisa"], "title": "Classification",
"type": "string"}}, "required": ["classification"], "type": "object"}}}]}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '1124'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- AsyncOpenAI/Python 1.47.1
x-stainless-arch:
- x64
x-stainless-async:
- async:asyncio
x-stainless-lang:
- python
x-stainless-os:
- Linux
x-stainless-package-version:
- 1.47.1
x-stainless-raw-response:
- 'true'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.11.5
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: !!binary |
H4sIAAAAAAAAA4xTUW/TMBB+z6+w7rlBSelIyNsmgRhqERpioFEUuc4l9ebYnu1IlKr/HdnpkrQU
iTxY1n33fXf3+bKPCAFeQUGAbaljrRbx9ffqC79/Xr3p3n+6edfJ5cOH34+39c0i+5YgzDxDbR6R
uRfWK6ZaLdBxJXuYGaQOvWqavZ5fZW/zPAlAqyoUntZoFy9U3HLJ43kyX8RJFqf5kb1VnKGFgvyI
CCFkH07fp6zwFxQkaIVIi9bSBqEYkggBo4SPALWWW0elg9kIMiUdSt+67ISYAE4pUTIqxFi4//aT
+2gWFaLcpJpf7+rlXf7ViXZ7e/f8RD9/5GZSr5fe6dBQ3Uk2mDTBh3hxVowQkLQN3NVuFbybnSdQ
03QtSufbhv0amPBz15xRL7mGYg0rJSlZckvXcIAT/iG6dP85scVg3Vkqjn4d44fhAYRqtFEbe+Yn
1Fxyuy0NUhvmAuuU7mv7OqECdCdvB9qoVrvSqSeUXnCepr0ejPs1otkRc8pRMSXlswtyZYWO8vC2
wzoxyrZYjdRxrWhXcTUBosnQfzdzSbsfnMvmf+RHgDHUDqtSG6w4Ox14TDPo/75/pQ0mh4bB7qzD
tqy5bNBow8PuQ63LJEuuNnWesQSiQ/QHAAD//wMAq681QQkEAAA=
headers:
CF-Cache-Status:
- DYNAMIC
CF-RAY:
- 8e85a91a5a7622f3-ORD
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Tue, 26 Nov 2024 00:11:20 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=mR.lQGByVqO3YXPOJhOAYfQSCaSh.GGUAiqvmTKYeF4-1732579880-1.0.1.1-kjoNgd4tNmz.8ile246dtkSjbL3C9pTtBxM35zH_sQENgFJuN91lWEVTAYebM_Au.qq8D_Sr1S1_DegpYxCo7A;
path=/; expires=Tue, 26-Nov-24 00:41:20 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=vJrrlSUKQKX62ERSv.300oGbNMFud1yC5ztTRDPBooA-1732579880316-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
openai-organization:
- heartex
openai-processing-ms:
- '189'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149998865'
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_5d437fcbab69225ff907cf1da14e1bb7
status:
code: 200
message: OK
version: 1
Loading
Loading