Skip to content

Commit d8cdfc9

Browse files
committed
Fix plugins
1 parent c4da6ff commit d8cdfc9

File tree

53 files changed

+530
-2282
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+530
-2282
lines changed

.github/workflows/ci_code.yml

-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ jobs:
6464
python -m pip install plugins/mongodb
6565
python -m pip install plugins/openai
6666
python -m pip install plugins/ibis
67-
python -m pip install plugins/sqlalchemy
6867
6968
7069
- name: Lint and type-check

.github/workflows/ci_templates.yaml

-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ jobs:
136136
run: |
137137
export SUPERDUPER_TEMPLATE=${{ matrix.template }}
138138
export SUPERDUPER_DATA_BACKEND='mongomock://test_db'
139-
cd superduper/templates && ln -s ../../templates/* . && cd ../../
140139
pytest test/integration/template/test_template.py -s
141140
env:
142141
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

plugins/anthropic/superduper_anthropic/model.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import dataclasses as dc
2+
import os
23
import typing as t
34

45
import anthropic
56
from anthropic import APIConnectionError, APIError, APIStatusError, APITimeoutError
67
from superduper.base.query_dataset import QueryDataset
78
from superduper.components.model import APIBaseModel
89
from superduper.misc.retry import Retry
9-
from superduper.misc.utils import format_prompt, get_key
10+
from superduper.misc.utils import format_prompt
1011

1112
retry = Retry(
1213
exception_types=(APIConnectionError, APIError, APIStatusError, APITimeoutError)
@@ -34,7 +35,7 @@ def init(self, db=None):
3435
:param db: The database to use.
3536
"""
3637
self.client = anthropic.Anthropic(
37-
api_key=get_key(KEY_NAME), **self.client_kwargs
38+
api_key=os.environ[KEY_NAME], **self.client_kwargs
3839
)
3940
super().init(db=db)
4041

plugins/cohere/plugin_test/test_model_cohere.py

-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ def test_embed_one():
2121
embed = CohereEmbed(identifier='embed-english-v2.0')
2222
resp = embed.predict('Hello world')
2323

24-
assert len(resp) == embed.shape[0]
2524
assert isinstance(resp, list)
2625
assert all(isinstance(x, float) for x in resp)
2726

@@ -35,7 +34,6 @@ def test_embed_batch():
3534
resp = embed.predict_batches(['Hello', 'world'])
3635

3736
assert len(resp) == 2
38-
assert len(resp[0]) == embed.shape[0]
3937
assert isinstance(resp[0], list)
4038
assert all(isinstance(x, float) for x in resp[0])
4139

plugins/cohere/superduper_cohere/model.py

+5-13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import dataclasses as dc
2+
import os
23
import typing as t
34

45
import cohere
@@ -7,7 +8,7 @@
78
from superduper.base.query_dataset import QueryDataset
89
from superduper.components.model import APIBaseModel
910
from superduper.misc.retry import Retry
10-
from superduper.misc.utils import format_prompt, get_key
11+
from superduper.misc.utils import format_prompt
1112

1213
retry = Retry(exception_types=(CohereAPIError, CohereConnectionError))
1314

@@ -42,30 +43,21 @@ class CohereEmbed(Cohere):
4243
4344
"""
4445

45-
shapes: t.ClassVar[t.Dict] = {'embed-english-v2.0': (4096,)}
46-
shape: t.Optional[t.Sequence[int]] = None
4746
batch_size: int = 100
48-
signature: str = 'singleton'
49-
50-
def postinit(self):
51-
"""Post-initialization method."""
52-
if self.shape is None:
53-
self.shape = self.shapes[self.identifier]
54-
return super().postinit()
5547

5648
@retry
5749
def predict(self, X: str):
5850
"""Predict the embedding of a single text.
5951
6052
:param X: The text to predict the embedding of.
6153
"""
62-
client = cohere.Client(get_key(KEY_NAME), **self.client_kwargs)
54+
client = cohere.Client(os.environ[KEY_NAME], **self.client_kwargs)
6355
e = client.embed(texts=[X], model=self.identifier, **self.predict_kwargs)
6456
return e.embeddings[0]
6557

6658
@retry
6759
def _predict_a_batch(self, texts: t.List[str]):
68-
client = cohere.Client(get_key(KEY_NAME), **self.client_kwargs)
60+
client = cohere.Client(os.environ[KEY_NAME], **self.client_kwargs)
6961
out = client.embed(texts=texts, model=self.identifier, **self.predict_kwargs)
7062
return [r for r in out.embeddings]
7163

@@ -111,7 +103,7 @@ def predict(self, prompt: str, context: t.Optional[t.List[str]] = None):
111103
"""
112104
if context is not None:
113105
prompt = format_prompt(prompt, self.prompt, context=context)
114-
client = cohere.Client(get_key(KEY_NAME), **self.client_kwargs)
106+
client = cohere.Client(os.environ[KEY_NAME], **self.client_kwargs)
115107
resp = client.generate(
116108
prompt=prompt, model=self.identifier, **self.predict_kwargs
117109
)

plugins/ibis/plugin_test/config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ auto_schema: false
33
force_apply: true
44
json_native: false
55
datatype_presets:
6-
vector: superduper.components.datatype.Array
6+
vector: superduper.base.datatype.Array

plugins/ibis/pyproject.toml

+1-2
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,12 @@ dependencies = [
2929
"ibis-framework[sqlite]>=9.0.1,<10.0.0",
3030
"click",
3131
'pandas',
32-
"sqlalchemy"
32+
"sqlalchemy>=1.4.0",
3333
]
3434

3535
[project.optional-dependencies]
3636
test = [
3737
# Annotation plugin dependencies will be installed in CI
38-
# :CI: plugins/sqlalchemy
3938
]
4039

4140
[project.urls]

plugins/jina/superduper_jina/client.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
import os
12
from typing import List, Optional
23

34
import aiohttp
45
import requests
56
from aiohttp import ClientConnectionError, ClientResponseError
67
from requests.exceptions import HTTPError
78
from superduper.misc.retry import Retry
8-
from superduper.misc.utils import get_key
99

1010
JINA_API_URL: str = "https://api.jina.ai/v1/embeddings"
1111
KEY_NAME = 'JINA_API_KEY'
@@ -34,7 +34,7 @@ def __init__(
3434
# if the user does not provide the API key,
3535
# check if it is set in the environment variable
3636
if api_key is None:
37-
api_key = get_key(KEY_NAME)
37+
api_key = os.environ[KEY_NAME]
3838

3939
self.model_name = model_name
4040
self._session = requests.Session()

plugins/openai/plugin_test/test_model_openai.py

-72
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
OpenAIAudioTranslation,
1212
OpenAIChatCompletion,
1313
OpenAIEmbedding,
14-
OpenAIImageCreation,
15-
OpenAIImageEdit,
1614
_available_models,
1715
)
1816

@@ -122,76 +120,6 @@ def test_batch_chat():
122120
assert isinstance(resp[0], str)
123121

124122

125-
@vcr.use_cassette()
126-
def test_create_url():
127-
e = OpenAIImageCreation(
128-
identifier='dall-e',
129-
prompt='a close up, studio photographic portrait of a {context}',
130-
response_format='url',
131-
)
132-
resp = e.predict('cat')
133-
134-
# PNG 8-byte signature
135-
assert resp[0:16] == PNG_BYTE_SIGNATURE
136-
137-
138-
@vcr.use_cassette()
139-
def test_create_url_batch():
140-
e = OpenAIImageCreation(
141-
identifier='dall-e',
142-
prompt='a close up, studio photographic portrait of a',
143-
response_format='url',
144-
)
145-
resp = e.predict_batches(['cat', 'dog'])
146-
147-
for img in resp:
148-
# PNG 8-byte signature
149-
assert img[0:16] == PNG_BYTE_SIGNATURE
150-
151-
152-
@vcr.use_cassette()
153-
def test_edit_url():
154-
e = OpenAIImageEdit(
155-
identifier='dall-e',
156-
prompt='A celebration party at the launch of {context}',
157-
response_format='url',
158-
)
159-
with open('test/material/data/rickroll.png', 'rb') as f:
160-
buffer = io.BytesIO(f.read())
161-
resp = e.predict(buffer, context=['superduper'])
162-
buffer.close()
163-
164-
# PNG 8-byte signature
165-
assert resp[0:16] == PNG_BYTE_SIGNATURE
166-
167-
168-
@vcr.use_cassette()
169-
def test_edit_url_batch():
170-
e = OpenAIImageEdit(
171-
identifier='dall-e',
172-
prompt='A celebration party at the launch of superduper',
173-
response_format='url',
174-
)
175-
with open('test/material/data/rickroll.png', 'rb') as f:
176-
buffer_one = io.BytesIO(f.read())
177-
with open('test/material/data/rickroll.png', 'rb') as f:
178-
buffer_two = io.BytesIO(f.read())
179-
180-
resp = e.predict_batches(
181-
[
182-
((buffer_one,), {}),
183-
((buffer_two,), {}),
184-
]
185-
)
186-
187-
buffer_one.close()
188-
buffer_two.close()
189-
190-
for img in resp:
191-
# PNG 8-byte signature
192-
assert img[0:16] == PNG_BYTE_SIGNATURE
193-
194-
195123
@vcr.use_cassette()
196124
def test_transcribe():
197125
with open('test/material/data/test.wav', 'rb') as f:

0 commit comments

Comments
 (0)