|
1 | 1 | import dataclasses as dc
|
| 2 | +import os |
2 | 3 | import typing as t
|
3 | 4 |
|
4 | 5 | import cohere
|
|
7 | 8 | from superduper.base.query_dataset import QueryDataset
|
8 | 9 | from superduper.components.model import APIBaseModel
|
9 | 10 | from superduper.misc.retry import Retry
|
10 |
| -from superduper.misc.utils import format_prompt, get_key |
| 11 | +from superduper.misc.utils import format_prompt |
11 | 12 |
|
12 | 13 | retry = Retry(exception_types=(CohereAPIError, CohereConnectionError))
|
13 | 14 |
|
@@ -42,30 +43,21 @@ class CohereEmbed(Cohere):
|
42 | 43 |
|
43 | 44 | """
|
44 | 45 |
|
45 |
| - shapes: t.ClassVar[t.Dict] = {'embed-english-v2.0': (4096,)} |
46 |
| - shape: t.Optional[t.Sequence[int]] = None |
47 | 46 | batch_size: int = 100
|
48 |
| - signature: str = 'singleton' |
49 |
| - |
50 |
| - def postinit(self): |
51 |
| - """Post-initialization method.""" |
52 |
| - if self.shape is None: |
53 |
| - self.shape = self.shapes[self.identifier] |
54 |
| - return super().postinit() |
55 | 47 |
|
56 | 48 | @retry
|
57 | 49 | def predict(self, X: str):
|
58 | 50 | """Predict the embedding of a single text.
|
59 | 51 |
|
60 | 52 | :param X: The text to predict the embedding of.
|
61 | 53 | """
|
62 |
| - client = cohere.Client(get_key(KEY_NAME), **self.client_kwargs) |
| 54 | + client = cohere.Client(os.environ[KEY_NAME], **self.client_kwargs) |
63 | 55 | e = client.embed(texts=[X], model=self.identifier, **self.predict_kwargs)
|
64 | 56 | return e.embeddings[0]
|
65 | 57 |
|
66 | 58 | @retry
|
67 | 59 | def _predict_a_batch(self, texts: t.List[str]):
|
68 |
| - client = cohere.Client(get_key(KEY_NAME), **self.client_kwargs) |
| 60 | + client = cohere.Client(os.environ[KEY_NAME], **self.client_kwargs) |
69 | 61 | out = client.embed(texts=texts, model=self.identifier, **self.predict_kwargs)
|
70 | 62 | return [r for r in out.embeddings]
|
71 | 63 |
|
@@ -111,7 +103,7 @@ def predict(self, prompt: str, context: t.Optional[t.List[str]] = None):
|
111 | 103 | """
|
112 | 104 | if context is not None:
|
113 | 105 | prompt = format_prompt(prompt, self.prompt, context=context)
|
114 |
| - client = cohere.Client(get_key(KEY_NAME), **self.client_kwargs) |
| 106 | + client = cohere.Client(os.environ[KEY_NAME], **self.client_kwargs) |
115 | 107 | resp = client.generate(
|
116 | 108 | prompt=prompt, model=self.identifier, **self.predict_kwargs
|
117 | 109 | )
|
|
0 commit comments