|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -from typing import List, Union |
| 5 | +import base64 |
| 6 | +from typing import List, Union, cast |
6 | 7 | from typing_extensions import Literal |
7 | 8 |
|
8 | 9 | from ..types import CreateEmbeddingResponse, embedding_create_params |
9 | 10 | from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven |
10 | | -from .._utils import maybe_transform |
| 11 | +from .._utils import is_given, maybe_transform |
| 12 | +from .._extras import numpy as np |
| 13 | +from .._extras import has_numpy |
11 | 14 | from .._resource import SyncAPIResource, AsyncAPIResource |
12 | 15 | from .._base_client import make_request_options |
13 | 16 |
|
@@ -61,23 +64,40 @@ def create( |
61 | 64 |
|
62 | 65 | timeout: Override the client-level default timeout for this request, in seconds |
63 | 66 | """ |
64 | | - return self._post( |
| 67 | + params = { |
| 68 | + "input": input, |
| 69 | + "model": model, |
| 70 | + "user": user, |
| 71 | + "encoding_format": encoding_format, |
| 72 | + } |
| 73 | + if not is_given(encoding_format) and has_numpy(): |
| 74 | + params["encoding_format"] = "base64" |
| 75 | + |
| 76 | + response = self._post( |
65 | 77 | "/embeddings", |
66 | | - body=maybe_transform( |
67 | | - { |
68 | | - "input": input, |
69 | | - "model": model, |
70 | | - "encoding_format": encoding_format, |
71 | | - "user": user, |
72 | | - }, |
73 | | - embedding_create_params.EmbeddingCreateParams, |
74 | | - ), |
| 78 | + body=maybe_transform(params, embedding_create_params.EmbeddingCreateParams), |
75 | 79 | options=make_request_options( |
76 | 80 | extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout |
77 | 81 | ), |
78 | 82 | cast_to=CreateEmbeddingResponse, |
79 | 83 | ) |
80 | 84 |
|
| 85 | + if is_given(encoding_format): |
| 86 | + # don't modify the response object if a user explicitly asked for a format |
| 87 | + return response |
| 88 | + |
| 89 | + for embedding in response.data: |
| 90 | + data = cast(object, embedding.embedding) |
| 91 | + if not isinstance(data, str): |
| 92 | + # numpy is not installed / base64 optimisation isn't enabled for this model yet |
| 93 | + continue |
| 94 | + |
| 95 | + embedding.embedding = np.frombuffer( # type: ignore[no-untyped-call] |
| 96 | + base64.b64decode(data), dtype="float32" |
| 97 | + ).tolist() |
| 98 | + |
| 99 | + return response |
| 100 | + |
81 | 101 |
|
82 | 102 | class AsyncEmbeddings(AsyncAPIResource): |
83 | 103 | async def create( |
@@ -126,19 +146,36 @@ async def create( |
126 | 146 |
|
127 | 147 | timeout: Override the client-level default timeout for this request, in seconds |
128 | 148 | """ |
129 | | - return await self._post( |
| 149 | + params = { |
| 150 | + "input": input, |
| 151 | + "model": model, |
| 152 | + "user": user, |
| 153 | + "encoding_format": encoding_format, |
| 154 | + } |
| 155 | + if not is_given(encoding_format) and has_numpy(): |
| 156 | + params["encoding_format"] = "base64" |
| 157 | + |
| 158 | + response = await self._post( |
130 | 159 | "/embeddings", |
131 | | - body=maybe_transform( |
132 | | - { |
133 | | - "input": input, |
134 | | - "model": model, |
135 | | - "encoding_format": encoding_format, |
136 | | - "user": user, |
137 | | - }, |
138 | | - embedding_create_params.EmbeddingCreateParams, |
139 | | - ), |
| 160 | + body=maybe_transform(params, embedding_create_params.EmbeddingCreateParams), |
140 | 161 | options=make_request_options( |
141 | 162 | extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout |
142 | 163 | ), |
143 | 164 | cast_to=CreateEmbeddingResponse, |
144 | 165 | ) |
| 166 | + |
| 167 | + if is_given(encoding_format): |
| 168 | + # don't modify the response object if a user explicitly asked for a format |
| 169 | + return response |
| 170 | + |
| 171 | + for embedding in response.data: |
| 172 | + data = cast(object, embedding.embedding) |
| 173 | + if not isinstance(data, str): |
| 174 | + # numpy is not installed / base64 optimisation isn't enabled for this model yet |
| 175 | + continue |
| 176 | + |
| 177 | + embedding.embedding = np.frombuffer( # type: ignore[no-untyped-call] |
| 178 | + base64.b64decode(data), dtype="float32" |
| 179 | + ).tolist() |
| 180 | + |
| 181 | + return response |
0 commit comments