Skip to content

Commit 87a05d2

Browse files
clam004 (Carson Lam)
authored
Make together embeddings.create() into OpenAI compatible format and allow providing a safety_model to Complete.create() (#63)
* embeddings take str and list of strings this is to make the method compatible with openai * embeddings take str and list of strings this is to make the method compatible with openai * allows returning object embedding this makes the embedding API openai compatible so you can call embed.data[0].embedding * allows returning object embedding this makes the embedding API openai compatible so you can call embed.data[0].embedding * Added safety_model to Complete.create this is for the meta safety llama as a placeholder so we can use python in the demo * removed the embeddings api from readme not to be announced yet per heejin * black ruff and mypy * Added TogetherAI() class this allows for both the output = together.Complete.create( form of usage and also the client = TogetherAI() embed = client.embeddings.create( form of usage to keep the python library self consistent but also be OpenAI compatible * Added TogetherAI() class this allows for both the output = together.Complete.create( form of usage and also the client = TogetherAI() embed = client.embeddings.create( form of usage to keep the python library self consistent but also be OpenAI compatible * Added TogetherAI() class this allows for both the output = together.Complete.create( form of usage and also the client = TogetherAI() embed = client.embeddings.create( form of usage to keep the python library self consistent but also be OpenAI compatible * changed TogetherAI to Together class - added safety model to commands and changed Output to EmbeddingsOutput black ruff and mypy * changed TogetherAI to Together class - added safety model to commands and changed Output to EmbeddingsOutput black ruff and mypy --------- Co-authored-by: Carson Lam <[email protected]> Co-authored-by: Carson Lam <[email protected]>
1 parent 8ae1902 commit 87a05d2

File tree

7 files changed

+76
-22
lines changed

7 files changed

+76
-22
lines changed

README.md

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -453,16 +453,6 @@ print(output_text)
453453
Space Robots are a great way to get your kids interested in science. After all, they are the future!
454454
```
455455

456-
## Embeddings API
457-
458-
Embeddings are vector representations of sequences. You can use these vectors for measuring the overall similarity between texts. Embeddings are useful for tasks such as search and retrieval.
459-
460-
```python
461-
resp = together.Embeddings.create("embed this sentence into a single vector", model="togethercomputer/bert-base-uncased")
462-
463-
print(resp['data'][0]['embedding']) # [0.06659205, 0.07896972, 0.007910785 ........]
464-
```
465-
466456
## Colab Tutorial
467457

468458
Follow along in our Colab (Google Colaboratory) Notebook Tutorial [Example Finetuning Project](https://colab.research.google.com/drive/11DwtftycpDSgp3Z1vnV-Cy68zvkGZL4K?usp=sharing).

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "poetry.masonry.api"
44

55
[tool.poetry]
66
name = "together"
7-
version = "0.2.8"
7+
version = "0.2.9"
88
authors = [
99
"Together AI <[email protected]>"
1010
]

src/together/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import sys
33
import urllib.parse
4+
from typing import Type
45

56
from .version import VERSION
67

@@ -41,6 +42,27 @@
4142
from .models import Models
4243

4344

45+
class Together:
    """OpenAI-style client facade for the together package.

    Supports both usage forms:

        output = together.Complete.create(...)          # module-level
        client = Together()
        embed = client.embeddings.create(...)           # client-instance

    Each attribute is simply a reference to the corresponding API class,
    so ``client.<resource>.<method>(...)`` dispatches to the same
    classmethods as the module-level API.
    """

    complete: Type[Complete]
    completion: Type[Completion]
    embeddings: Type[Embeddings]
    files: Type[Files]
    finetune: Type[Finetune]
    image: Type[Image]
    models: Type[Models]

    def __init__(self) -> None:
        # Wire every API class onto the instance so calls read like
        # client.embeddings.create(...), mirroring the OpenAI client shape.
        for attr_name, api_cls in (
            ("complete", Complete),
            ("completion", Completion),
            ("embeddings", Embeddings),
            ("files", Files),
            ("finetune", Finetune),
            ("image", Image),
            ("models", Models),
        ):
            setattr(self, attr_name, api_cls)
64+
65+
4466
__all__ = [
4567
"api_key",
4668
"api_base",
@@ -63,4 +85,5 @@
6385
"MISSING_API_KEY_MESSAGE",
6486
"BACKOFF_FACTOR",
6587
"min_samples",
88+
"Together",
6689
]

src/together/commands/complete.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,13 @@ def add_parser(subparsers: argparse._SubParsersAction[argparse.ArgumentParser])
9090
action="store_true",
9191
help="temperature for the LM",
9292
)
93+
subparser.add_argument(
94+
"--safety-model",
95+
"-sm",
96+
default=None,
97+
type=str,
98+
help="The name of the safety model to use for moderation.",
99+
)
93100
subparser.set_defaults(func=_run_complete)
94101

95102

@@ -142,6 +149,7 @@ def _run_complete(args: argparse.Namespace) -> None:
142149
top_k=args.top_k,
143150
repetition_penalty=args.repetition_penalty,
144151
logprobs=args.logprobs,
152+
safety_model=args.safety_model,
145153
)
146154
except together.AuthenticationError:
147155
logger.critical(together.MISSING_API_KEY_MESSAGE)
@@ -159,6 +167,7 @@ def _run_complete(args: argparse.Namespace) -> None:
159167
top_p=args.top_p,
160168
top_k=args.top_k,
161169
repetition_penalty=args.repetition_penalty,
170+
safety_model=args.safety_model,
162171
raw=args.raw,
163172
):
164173
if not args.raw:

src/together/commands/embeddings.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import argparse
4-
import json
54

65
import together
76
from together import Embeddings
@@ -42,7 +41,7 @@ def _run_complete(args: argparse.Namespace) -> None:
4241
model=args.model,
4342
)
4443

45-
print(json.dumps(response, indent=4))
44+
print([e.embedding for e in response.data])
4645
except together.AuthenticationError:
4746
logger.critical(together.MISSING_API_KEY_MESSAGE)
4847
exit(0)

src/together/complete.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def create(
2424
logprobs: Optional[int] = None,
2525
api_key: Optional[str] = None,
2626
cast: bool = False,
27+
safety_model: Optional[str] = None,
2728
) -> Union[Dict[str, Any], TogetherResponse]:
2829
if model == "":
2930
model = together.default_text_model
@@ -38,6 +39,7 @@ def create(
3839
"stop": stop,
3940
"repetition_penalty": repetition_penalty,
4041
"logprobs": logprobs,
42+
"safety_model": safety_model,
4143
}
4244

4345
# send request
@@ -70,6 +72,7 @@ def create_streaming(
7072
raw: Optional[bool] = False,
7173
api_key: Optional[str] = None,
7274
cast: Optional[bool] = False,
75+
safety_model: Optional[str] = None,
7376
) -> Union[Iterator[str], Iterator[TogetherResponse]]:
7477
"""
7578
Prints streaming responses and returns the completed text.
@@ -88,6 +91,7 @@ def create_streaming(
8891
"stop": stop,
8992
"repetition_penalty": repetition_penalty,
9093
"stream_tokens": True,
94+
"safety_model": safety_model,
9195
}
9296

9397
# send request

src/together/embeddings.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Dict, Optional
1+
import concurrent.futures
2+
from typing import Any, Dict, List, Optional, Union
23

34
import together
45
from together.utils import create_post_request, get_logger
@@ -7,29 +8,57 @@
78
logger = get_logger(str(__name__))
89

910

11+
class DataItem:
    """One embedding result in OpenAI-compatible shape.

    Exposes the vector as ``.embedding`` so callers can write
    ``response.data[0].embedding``.
    """

    def __init__(self, embedding: List[float]) -> None:
        # The raw embedding vector for a single input text.
        self.embedding = embedding

    def __repr__(self) -> str:
        # Debug-friendly representation; absent before, instances printed
        # as an opaque <DataItem object at 0x...>.
        return f"DataItem(embedding={self.embedding!r})"
14+
15+
16+
class EmbeddingsOutput:
    """OpenAI-compatible embeddings response container.

    ``.data`` holds one DataItem per input text, in input order, so
    ``response.data[i].embedding`` mirrors the OpenAI client API.
    """

    def __init__(self, data: List["DataItem"]) -> None:
        # One DataItem per input text; order matches the order of inputs.
        self.data = data

    def __repr__(self) -> str:
        # Debug-friendly representation; absent before, instances printed
        # as an opaque <EmbeddingsOutput object at 0x...>.
        return f"EmbeddingsOutput(data={self.data!r})"
19+
20+
1021
class Embeddings:
    """Client for the Together embeddings endpoint.

    Results come back in an OpenAI-compatible shape, i.e.
    ``Embeddings.create(...).data[0].embedding``.
    """

    @classmethod
    def create(
        cls,
        input: Union[str, List[str]],
        model: Optional[str] = "",
    ) -> EmbeddingsOutput:
        """Embed a single string or a list of strings.

        Args:
            input: One text, or a list of texts, to embed.
            model: Embedding model name; falls back to
                ``together.default_embedding_model`` when empty.

        Returns:
            EmbeddingsOutput whose ``data[i].embedding`` is the vector
            for the i-th input.

        Raises:
            TypeError: If ``input`` is neither a ``str`` nor a ``list``.
            together.JSONError: If a response body is not valid JSON.
        """
        if model == "":
            model = together.default_embedding_model

        if isinstance(input, str):
            response = cls._process_input({"input": input, "model": model})
            return EmbeddingsOutput([DataItem(response["data"][0]["embedding"])])

        if isinstance(input, list):
            # The endpoint takes one text per request, so a list of inputs
            # becomes one request per item, issued concurrently.
            parameter_payloads = [{"input": item, "model": model} for item in input]
            with concurrent.futures.ThreadPoolExecutor() as executor:
                results = list(executor.map(cls._process_input, parameter_payloads))
            return EmbeddingsOutput(
                [DataItem(item["data"][0]["embedding"]) for item in results]
            )

        # Previously this fell through and implicitly returned None despite
        # the declared EmbeddingsOutput return type; fail loudly instead.
        raise TypeError(
            f"input must be a str or a list of str, not {type(input).__name__}"
        )

    @classmethod
    def _process_input(cls, parameter_payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST one payload to the embeddings endpoint and return the JSON body.

        Raises:
            together.JSONError: If the response body is not valid JSON.
        """
        # send request
        response = create_post_request(
            url=together.api_base_embeddings, json=parameter_payload
        )

        # Return the decoded JSON body as a plain dict.
        try:
            response_json = dict(response.json())
        except Exception as e:
            raise together.JSONError(e, http_status=response.status_code)

        return response_json

0 commit comments

Comments
 (0)