Skip to content

Commit 76fc949

Browse files
SDK and CLI for model upload (#366)
- API and CLI for model upload
- Cleaning up the response
1 parent 0cf374c commit 76fc949

File tree

4 files changed

+250
-3
lines changed

4 files changed

+250
-3
lines changed

src/together/cli/api/models.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from tabulate import tabulate
55

66
from together import Together
7-
from together.types.models import ModelObject
7+
from together.types.models import ModelObject, ModelUploadResponse
88

99

1010
@click.group()
@@ -53,3 +53,81 @@ def list(ctx: click.Context, type: str | None, json: bool) -> None:
5353
click.echo(json_lib.dumps(display_list, indent=2))
5454
else:
5555
click.echo(tabulate(display_list, headers="keys", tablefmt="plain"))
56+
57+
58+
@models.command()
@click.option(
    "--model-name",
    required=True,
    help="The name to give to your uploaded model",
)
@click.option(
    "--model-source",
    required=True,
    help="The source location of the model (Hugging Face repo or S3 path)",
)
@click.option(
    "--model-type",
    type=click.Choice(["model", "adapter"]),
    default="model",
    help="Whether the model is a full model or an adapter",
)
@click.option(
    "--hf-token",
    help="Hugging Face token (if uploading from Hugging Face)",
)
@click.option(
    "--description",
    help="A description of your model",
)
@click.option(
    "--base-model",
    help="The base model to use for an adapter if setting it to run against a serverless pool. Only used for model_type 'adapter'.",
)
@click.option(
    "--lora-model",
    help="The lora pool to use for an adapter if setting it to run against, say, a dedicated pool. Only used for model_type 'adapter'.",
)
@click.option(
    "--json",
    is_flag=True,
    help="Output in JSON format",
)
@click.pass_context
def upload(
    ctx: click.Context,
    model_name: str,
    model_source: str,
    model_type: str,
    hf_token: str | None,
    description: str | None,
    base_model: str | None,
    lora_model: str | None,
    json: bool,
) -> None:
    """Upload a custom model or adapter from Hugging Face or S3.

    Forwards all options to ``client.models.upload`` and prints either the
    raw JSON response (``--json``) or a human-readable summary of the
    created upload job.
    """
    client: Together = ctx.obj

    response: ModelUploadResponse = client.models.upload(
        model_name=model_name,
        model_source=model_source,
        model_type=model_type,
        hf_token=hf_token,
        description=description,
        base_model=base_model,
        lora_model=lora_model,
    )

    if json:
        click.echo(json_lib.dumps(response.model_dump(), indent=2))
    else:
        # Fixed: this was an f-string with no placeholders (ruff F541).
        click.echo("Model upload job created successfully!")
        # Each field is optional in the API response; only print what is set.
        if response.job_id:
            click.echo(f"Job ID: {response.job_id}")
        if response.model_name:
            click.echo(f"Model Name: {response.model_name}")
        if response.model_id:
            click.echo(f"Model ID: {response.model_id}")
        if response.model_source:
            click.echo(f"Model Source: {response.model_source}")
        click.echo(f"Message: {response.message}")

src/together/resources/models.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from together.together_response import TogetherResponse
77
from together.types import (
88
ModelObject,
9+
ModelUploadRequest,
10+
ModelUploadResponse,
911
TogetherClient,
1012
TogetherRequest,
1113
)
@@ -85,6 +87,64 @@ def list(
8587

8688
return models
8789

90+
def upload(
    self,
    *,
    model_name: str,
    model_source: str,
    model_type: str = "model",
    hf_token: str | None = None,
    description: str | None = None,
    base_model: str | None = None,
    lora_model: str | None = None,
) -> ModelUploadResponse:
    """
    Upload a custom model or adapter from Hugging Face or S3.

    Args:
        model_name (str): The name to give to your uploaded model
        model_source (str): The source location of the model (Hugging Face repo or S3 path)
        model_type (str, optional): Whether the model is a full model or an adapter. Defaults to "model".
        hf_token (str, optional): Hugging Face token (if uploading from Hugging Face)
        description (str, optional): A description of your model
        base_model (str, optional): The base model to use for an adapter if setting it to run against a serverless pool. Only used for model_type "adapter".
        lora_model (str, optional): The lora pool to use for an adapter if setting it to run against, say, a dedicated pool. Only used for model_type "adapter".

    Returns:
        ModelUploadResponse: Object containing upload job information
    """
    requestor = api_requestor.APIRequestor(
        client=self._client,
    )

    # Required fields are always sent; optional fields are included
    # only when the caller provided a value.
    payload = {
        "model_name": model_name,
        "model_source": model_source,
        "model_type": model_type,
    }
    optional_fields = {
        "hf_token": hf_token,
        "description": description,
        "base_model": base_model,
        "lora_model": lora_model,
    }
    payload.update({key: value for key, value in optional_fields.items() if value is not None})

    response, _, _ = requestor.request(
        options=TogetherRequest(
            method="POST",
            url="models",
            params=payload,
        ),
        stream=False,
    )

    assert isinstance(response, TogetherResponse)

    return ModelUploadResponse.from_api_response(response.data)
147+
88148

89149
class AsyncModels(ModelsBase):
90150
async def list(
@@ -132,3 +192,61 @@ async def list(
132192
models.sort(key=lambda x: x.id.lower())
133193

134194
return models
195+
196+
async def upload(
    self,
    *,
    model_name: str,
    model_source: str,
    model_type: str = "model",
    hf_token: str | None = None,
    description: str | None = None,
    base_model: str | None = None,
    lora_model: str | None = None,
) -> ModelUploadResponse:
    """
    Upload a custom model or adapter from Hugging Face or S3.

    Args:
        model_name (str): The name to give to your uploaded model
        model_source (str): The source location of the model (Hugging Face repo or S3 path)
        model_type (str, optional): Whether the model is a full model or an adapter. Defaults to "model".
        hf_token (str, optional): Hugging Face token (if uploading from Hugging Face)
        description (str, optional): A description of your model
        base_model (str, optional): The base model to use for an adapter if setting it to run against a serverless pool. Only used for model_type "adapter".
        lora_model (str, optional): The lora pool to use for an adapter if setting it to run against, say, a dedicated pool. Only used for model_type "adapter".

    Returns:
        ModelUploadResponse: Object containing upload job information
    """
    requestor = api_requestor.APIRequestor(
        client=self._client,
    )

    # Required fields are always sent; optional fields are included
    # only when the caller provided a value.
    payload = {
        "model_name": model_name,
        "model_source": model_source,
        "model_type": model_type,
    }
    optional_fields = {
        "hf_token": hf_token,
        "description": description,
        "base_model": base_model,
        "lora_model": lora_model,
    }
    payload.update({key: value for key, value in optional_fields.items() if value is not None})

    response, _, _ = await requestor.arequest(
        options=TogetherRequest(
            method="POST",
            url="models",
            params=payload,
        ),
        stream=False,
    )

    assert isinstance(response, TogetherResponse)

    return ModelUploadResponse.from_api_response(response.data)

src/together/types/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
TrainingType,
6060
)
6161
from together.types.images import ImageRequest, ImageResponse
62-
from together.types.models import ModelObject
62+
from together.types.models import ModelObject, ModelUploadRequest, ModelUploadResponse
6363
from together.types.rerank import RerankRequest, RerankResponse
6464
from together.types.batch import BatchJob, BatchJobStatus, BatchEndpoint
6565
from together.types.evaluation import (
@@ -110,6 +110,8 @@
110110
"ImageRequest",
111111
"ImageResponse",
112112
"ModelObject",
113+
"ModelUploadRequest",
114+
"ModelUploadResponse",
113115
"TrainingType",
114116
"FullTrainingType",
115117
"LoRATrainingType",

src/together/types/models.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from enum import Enum
4-
from typing import Literal
4+
from typing import Any, Dict, Literal, Optional
55

66
from together.types.abstract import BaseModel
77
from together.types.common import ObjectType
@@ -44,3 +44,52 @@ class ModelObject(BaseModel):
4444
license: str | None = None
4545
context_length: int | None = None
4646
pricing: PricingObject
47+
48+
49+
class ModelUploadRequest(BaseModel):
    """Request payload for uploading a custom model or adapter."""

    # Name to register the uploaded model under.
    model_name: str
    # Source location: Hugging Face repo or S3 path.
    model_source: str
    # Full model vs. adapter (e.g. LoRA) upload.
    model_type: Literal["model", "adapter"] = "model"
    # Hugging Face access token, required only for private HF sources.
    hf_token: str | None = None
    # Free-form model description.
    description: str | None = None
    # Base model for an adapter targeting a serverless pool (adapter only).
    base_model: str | None = None
    # LoRA pool for an adapter targeting a dedicated pool (adapter only).
    lora_model: str | None = None
57+
58+
59+
class ModelUploadResponse(BaseModel):
    """Response returned by the model-upload endpoint."""

    # Identifier of the asynchronous upload job, when the API reports one.
    job_id: str | None = None
    # Name the model was registered under.
    model_name: str | None = None
    # Identifier assigned to the model.
    model_id: str | None = None
    # Echo of the source location the model was uploaded from.
    model_source: str | None = None
    # Human-readable status message; always present (defaults to "").
    message: str

    @classmethod
    def from_api_response(cls, response_data: Dict[str, Any]) -> "ModelUploadResponse":
        """Create ModelUploadResponse from API response, handling both flat and nested structures.

        Some API responses carry the job fields at the top level, others nest
        them under a non-null "data" key; the four job fields are read from
        whichever dict applies. "message" is always read from the top level.
        """
        # Prefer the nested payload when "data" is present and not null;
        # otherwise fall back to the flat top-level structure. (Deduplicated
        # from two branches that repeated the same four lookups.)
        source = response_data.get("data")
        if source is None:
            source = response_data

        fields = ("job_id", "model_name", "model_id", "model_source")
        return cls(
            message=response_data.get("message", ""),
            **{name: source.get(name) for name in fields},
        )

0 commit comments

Comments
 (0)