-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathazureopenai.py
343 lines (289 loc) · 12.2 KB
/
azureopenai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
import os
from typing import Any, Callable, Dict, List, Optional
from pydantic.v1 import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type
from redisvl.utils.vectorize.base import BaseVectorizer
# ignore that openai isn't imported
# mypy: disable-error-code="name-defined"
class AzureOpenAITextVectorizer(BaseVectorizer):
    """The AzureOpenAITextVectorizer class utilizes AzureOpenAI's API to generate
    embeddings for text data.

    This vectorizer is designed to interact with AzureOpenAI's embeddings API,
    requiring an API key, an AzureOpenAI deployment endpoint and API version.
    These values can be provided directly in the `api_config` dictionary with
    the parameters 'azure_endpoint', 'api_version' and 'api_key' or through the
    environment variables 'AZURE_OPENAI_ENDPOINT', 'OPENAI_API_VERSION', and
    'AZURE_OPENAI_API_KEY'. The deployment can likewise be provided as
    'azure_deployment' or via the 'AZURE_OPENAI_DEPLOYMENT' environment variable.
    Users must obtain these values from the 'Keys and Endpoints' section in
    their Azure OpenAI service.
    Additionally, the `openai` python client must be installed with
    `pip install openai>=1.13.0`.

    The vectorizer supports both synchronous and asynchronous operations,
    allowing for batch processing of texts and flexibility in handling
    preprocessing tasks.

    .. code-block:: python

        # Synchronous embedding of a single text
        vectorizer = AzureOpenAITextVectorizer(
            model="text-embedding-ada-002",
            api_config={
                "api_key": "your_api_key",  # OR set AZURE_OPENAI_API_KEY in your env
                "api_version": "your_api_version",  # OR set OPENAI_API_VERSION in your env
                "azure_endpoint": "your_azure_endpoint",  # OR set AZURE_OPENAI_ENDPOINT in your env
            }
        )
        embedding = vectorizer.embed("Hello, world!")

        # Asynchronous batch embedding of multiple texts
        embeddings = await vectorizer.aembed_many(
            ["Hello, world!", "How are you?"],
            batch_size=2
        )

    """

    # Sync and async AzureOpenAI clients, created in _initialize_clients().
    _client: Any = PrivateAttr()
    _aclient: Any = PrivateAttr()

    def __init__(
        self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None
    ):
        """Initialize the AzureOpenAI vectorizer.

        Args:
            model (str): Deployment to use for embedding. Must be the
                'Deployment name' not the 'Model name'. Defaults to
                'text-embedding-ada-002'.
            api_config (Optional[Dict], optional): Dictionary containing the
                API key, API version, Azure endpoint, Azure deployment, and any
                other API options. Defaults to None.

        Raises:
            ImportError: If the openai library is not installed.
            ValueError: If the AzureOpenAI API key, version, endpoint, or
                deployment are not provided.
        """
        self._initialize_clients(api_config)
        # Dims are discovered empirically by embedding a probe string.
        super().__init__(model=model, dims=self._set_model_dims(model))

    def _initialize_clients(self, api_config: Optional[Dict]):
        """
        Setup the OpenAI clients using the provided API key or an
        environment variable.
        """
        if api_config is None:
            api_config = {}

        # Dynamic import of the openai module
        try:
            from openai import AsyncAzureOpenAI, AzureOpenAI
        except ImportError as e:
            raise ImportError(
                "AzureOpenAI vectorizer requires the openai library. "
                "Please install with `pip install openai`"
            ) from e

        # Fetch each setting from api_config, falling back to the environment.
        # FIX: the previous `api_config.pop(key) if api_config else os.getenv(...)`
        # never consulted the environment when api_config was a non-empty dict,
        # and raised an unhandled KeyError (instead of the documented ValueError)
        # when the key was absent from it.
        azure_endpoint = api_config.pop("azure_endpoint", None) or os.getenv(
            "AZURE_OPENAI_ENDPOINT"
        )
        if not azure_endpoint:
            raise ValueError(
                "AzureOpenAI API endpoint is required. "
                "Provide it in api_config or set the AZURE_OPENAI_ENDPOINT "
                "environment variable."
            )

        api_version = api_config.pop("api_version", None) or os.getenv(
            "OPENAI_API_VERSION"
        )
        if not api_version:
            raise ValueError(
                "AzureOpenAI API version is required. "
                "Provide it in api_config or set the OPENAI_API_VERSION "
                "environment variable."
            )

        api_key = api_config.pop("api_key", None) or os.getenv(
            "AZURE_OPENAI_API_KEY"
        )
        if not api_key:
            raise ValueError(
                "AzureOpenAI API key is required. "
                "Provide it in api_config or set the AZURE_OPENAI_API_KEY "
                "environment variable."
            )

        azure_deployment = api_config.pop("azure_deployment", None) or os.getenv(
            "AZURE_OPENAI_DEPLOYMENT"
        )
        if not azure_deployment:
            raise ValueError(
                "AzureOpenAI API deployment is required. "
                "Provide it in api_config or set the AZURE_OPENAI_DEPLOYMENT "
                "environment variable."
            )

        # FIX: azure_deployment was previously validated and then silently
        # discarded — it is now forwarded to both clients.
        self._client = AzureOpenAI(
            api_key=api_key,
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            **api_config,
        )
        self._aclient = AsyncAzureOpenAI(
            api_key=api_key,
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            **api_config,
        )

    def _set_model_dims(self, model: str) -> int:
        """Probe the deployment with a tiny request to learn the embedding dims.

        Raises:
            ValueError: If the API response is malformed or the request fails.
        """
        try:
            embedding = (
                self._client.embeddings.create(input=["dimension test"], model=model)
                .data[0]
                .embedding
            )
        except (KeyError, IndexError) as ke:
            raise ValueError(
                f"Unexpected response from the AzureOpenAI API: {str(ke)}"
            ) from ke
        except Exception as e:  # pylint: disable=broad-except
            # fall back (TODO get more specific)
            raise ValueError(
                f"Error setting embedding model dimensions: {str(e)}"
            ) from e
        return len(embedding)

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    def embed_many(
        self,
        texts: List[str],
        preprocess: Optional[Callable] = None,
        batch_size: int = 10,
        as_buffer: bool = False,
        **kwargs,
    ) -> List[List[float]]:
        """Embed many chunks of texts using the AzureOpenAI API.

        Args:
            texts (List[str]): List of text chunks to embed.
            preprocess (Optional[Callable], optional): Optional preprocessing
                callable to perform before vectorization. Defaults to None.
            batch_size (int, optional): Batch size of texts to use when creating
                embeddings. Defaults to 10.
            as_buffer (bool, optional): Whether to convert the raw embedding
                to a byte string. Defaults to False.

        Returns:
            List[List[float]]: List of embeddings.

        Raises:
            TypeError: If the wrong input type is passed in for the text.
        """
        if not isinstance(texts, list):
            raise TypeError("Must pass in a list of str values to embed.")
        if len(texts) > 0 and not isinstance(texts[0], str):
            raise TypeError("Must pass in a list of str values to embed.")

        dtype = kwargs.pop("dtype", None)

        embeddings: List = []
        for batch in self.batchify(texts, batch_size, preprocess):
            response = self._client.embeddings.create(input=batch, model=self.model)
            embeddings += [
                self._process_embedding(r.embedding, as_buffer, dtype)
                for r in response.data
            ]
        return embeddings

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    def embed(
        self,
        text: str,
        preprocess: Optional[Callable] = None,
        as_buffer: bool = False,
        **kwargs,
    ) -> List[float]:
        """Embed a chunk of text using the AzureOpenAI API.

        Args:
            text (str): Chunk of text to embed.
            preprocess (Optional[Callable], optional): Optional preprocessing
                callable to perform before vectorization. Defaults to None.
            as_buffer (bool, optional): Whether to convert the raw embedding
                to a byte string. Defaults to False.

        Returns:
            List[float]: Embedding.

        Raises:
            TypeError: If the wrong input type is passed in for the text.
        """
        if not isinstance(text, str):
            raise TypeError("Must pass in a str value to embed.")

        if preprocess:
            text = preprocess(text)

        dtype = kwargs.pop("dtype", None)

        result = self._client.embeddings.create(input=[text], model=self.model)
        return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    async def aembed_many(
        self,
        texts: List[str],
        preprocess: Optional[Callable] = None,
        batch_size: int = 1000,
        as_buffer: bool = False,
        **kwargs,
    ) -> List[List[float]]:
        """Asynchronously embed many chunks of texts using the AzureOpenAI API.

        Args:
            texts (List[str]): List of text chunks to embed.
            preprocess (Optional[Callable], optional): Optional preprocessing
                callable to perform before vectorization. Defaults to None.
            batch_size (int, optional): Batch size of texts to use when creating
                embeddings. Defaults to 1000.
            as_buffer (bool, optional): Whether to convert the raw embedding
                to a byte string. Defaults to False.

        Returns:
            List[List[float]]: List of embeddings.

        Raises:
            TypeError: If the wrong input type is passed in for the text.
        """
        if not isinstance(texts, list):
            raise TypeError("Must pass in a list of str values to embed.")
        if len(texts) > 0 and not isinstance(texts[0], str):
            raise TypeError("Must pass in a list of str values to embed.")

        dtype = kwargs.pop("dtype", None)

        embeddings: List = []
        for batch in self.batchify(texts, batch_size, preprocess):
            response = await self._aclient.embeddings.create(
                input=batch, model=self.model
            )
            embeddings += [
                self._process_embedding(r.embedding, as_buffer, dtype)
                for r in response.data
            ]
        return embeddings

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    async def aembed(
        self,
        text: str,
        preprocess: Optional[Callable] = None,
        as_buffer: bool = False,
        **kwargs,
    ) -> List[float]:
        """Asynchronously embed a chunk of text using the AzureOpenAI API.

        Args:
            text (str): Chunk of text to embed.
            preprocess (Optional[Callable], optional): Optional preprocessing
                callable to perform before vectorization. Defaults to None.
            as_buffer (bool, optional): Whether to convert the raw embedding
                to a byte string. Defaults to False.

        Returns:
            List[float]: Embedding.

        Raises:
            TypeError: If the wrong input type is passed in for the text.
        """
        if not isinstance(text, str):
            raise TypeError("Must pass in a str value to embed.")

        if preprocess:
            text = preprocess(text)

        dtype = kwargs.pop("dtype", None)

        result = await self._aclient.embeddings.create(input=[text], model=self.model)
        return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

    @property
    def type(self) -> str:
        # Identifier used by the vectorizer registry/serialization.
        return "azure_openai"