Skip to content

Commit 7f22cb7

Browse files
authored
fix: allowed_cuda_versions was incorrectly typed as string (#418)
This corrects some of the changes from #375
1 parent 1a0976f commit 7f22cb7

File tree

2 files changed

+22
-6
lines changed

2 files changed

+22
-6
lines changed

runpod/api/ctl_commands.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,13 +316,28 @@ def create_endpoint(
316316
workers_min: int = 0,
317317
workers_max: int = 3,
318318
flashboot=False,
319-
allowed_cuda_versions: str = "12.1,12.2,12.3,12.4,12.5",
319+
allowed_cuda_versions: list = None,
320320
gpu_count: int = 1,
321321
):
322322
"""
323323
Create an endpoint
324324
325-
:param allowed_cuda_versions: Comma-separated string of allowed CUDA versions (e.g., "12.4,12.5").
325+
:param name: the name of the endpoint
326+
:param template_id: the id of the template to use for the endpoint
327+
:param gpu_ids: the ids of the GPUs to use for the endpoint
328+
:param network_volume_id: the id of the network volume to use for the endpoint
329+
:param locations: the locations to use for the endpoint
330+
:param idle_timeout: the idle timeout for the endpoint
331+
:param scaler_type: the scaler type for the endpoint
332+
:param scaler_value: the scaler value for the endpoint
333+
:param workers_min: the minimum number of workers for the endpoint
334+
:param workers_max: the maximum number of workers for the endpoint
335+
:param allowed_cuda_versions: Comma-separated list of allowed CUDA versions (e.g., ["12.4", "12.5"]).
336+
:param gpu_count: the number of GPUs to use for the endpoint
337+
338+
:example:
339+
340+
>>> endpoint_id = runpod.create_endpoint("test", "template_id")
326341
"""
327342
raw_response = run_graphql_query(
328343
endpoint_mutations.generate_endpoint_mutation(

runpod/api/mutations/endpoints.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def generate_endpoint_mutation(
1515
workers_min: int = 0,
1616
workers_max: int = 3,
1717
flashboot=False,
18-
allowed_cuda_versions: str = "12.1,12.2,12.3,12.4,12.5",
18+
allowed_cuda_versions: list = None,
1919
gpu_count: int = None,
2020
):
2121
"""Generate a string for a GraphQL mutation to create a new endpoint."""
@@ -46,9 +46,10 @@ def generate_endpoint_mutation(
4646
input_fields.append(f"workersMin: {workers_min}")
4747
input_fields.append(f"workersMax: {workers_max}")
4848

49-
if allowed_cuda_versions is not None:
50-
input_fields.append(f'allowedCudaVersions: "{allowed_cuda_versions}"')
51-
49+
if allowed_cuda_versions:
50+
cuda_versions = ", ".join(f'"{v}"' for v in allowed_cuda_versions)
51+
input_fields.append(f"allowedCudaVersions: [{cuda_versions}]")
52+
5253
if gpu_count is not None:
5354
input_fields.append(f"gpuCount: {gpu_count}")
5455

0 commit comments

Comments
 (0)