Skip to content

Commit 55b0f1b

Browse files
committed
Adjust Trainings#async_create signature
to better align with `Trainings.create` and the way that the arguments are being used. Closes #408
1 parent 461ec70 commit 55b0f1b

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

replicate/training.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -307,9 +307,9 @@ def create( # type: ignore
307307

308308
async def async_create(
309309
self,
310-
model: Union[str, Tuple[str, str], "Model"],
311-
version: Union[str, Version],
312-
input: Dict[str, Any],
310+
model: Optional[Union[str, Tuple[str, str], "Model"]] = None,
311+
version: Optional[Union[str, Version]] = None,
312+
input: Optional[Dict[str, Any]] = None,
313313
**params: Unpack["Trainings.CreateTrainingParams"],
314314
) -> Training:
315315
"""
@@ -326,7 +326,15 @@ async def async_create(
326326
The training object.
327327
"""
328328

329-
url = _create_training_url_from_model_and_version(model, version)
329+
url = None
330+
331+
if model and version:
332+
url = _create_training_url_from_model_and_version(model, version)
333+
elif model is None and isinstance(version, str):
334+
url = _create_training_url_from_shorthand(version)
335+
336+
if not url:
337+
raise ValueError("model and version or shorthand version must be specified")
330338

331339
file_encoding_strategy = params.pop("file_encoding_strategy", None)
332340
if input is not None:

0 commit comments

Comments
 (0)