Skip to content

Commit 9c5918e

Browse files
authored
Add allowed CUDA versions parameter to endpoint creation (#375)
* Add allowed CUDA versions parameter to endpoint creation
* Add gpu_count parameter to endpoint creation functions
1 parent 7912c20 commit 9c5918e

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

runpod/api/ctl_commands.py

Lines changed: 5 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -302,24 +302,13 @@ def create_endpoint(
302302
workers_min: int = 0,
303303
workers_max: int = 3,
304304
flashboot=False,
305+
allowed_cuda_versions: str = "12.1,12.2,12.3,12.4,12.5",
306+
gpu_count: int = 1,
305307
):
306308
"""
307309
Create an endpoint
308310
309-
:param name: the name of the endpoint
310-
:param template_id: the id of the template to use for the endpoint
311-
:param gpu_ids: the ids of the GPUs to use for the endpoint
312-
:param network_volume_id: the id of the network volume to use for the endpoint
313-
:param locations: the locations to use for the endpoint
314-
:param idle_timeout: the idle timeout for the endpoint
315-
:param scaler_type: the scaler type for the endpoint
316-
:param scaler_value: the scaler value for the endpoint
317-
:param workers_min: the minimum number of workers for the endpoint
318-
:param workers_max: the maximum number of workers for the endpoint
319-
320-
:example:
321-
322-
>>> endpoint_id = runpod.create_endpoint("test", "template_id")
311+
:param allowed_cuda_versions: Comma-separated string of allowed CUDA versions (e.g., "12.4,12.5").
323312
"""
324313
raw_response = run_graphql_query(
325314
endpoint_mutations.generate_endpoint_mutation(
@@ -334,6 +323,8 @@ def create_endpoint(
334323
workers_min,
335324
workers_max,
336325
flashboot,
326+
allowed_cuda_versions,
327+
gpu_count
337328
)
338329
)
339330

runpod/api/mutations/endpoints.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@ def generate_endpoint_mutation(
1515
workers_min: int = 0,
1616
workers_max: int = 3,
1717
flashboot=False,
18+
allowed_cuda_versions: str = "12.1,12.2,12.3,12.4,12.5",
19+
gpu_count: int = None,
1820
):
1921
"""Generate a string for a GraphQL mutation to create a new endpoint."""
2022
input_fields = []
@@ -44,6 +46,12 @@ def generate_endpoint_mutation(
4446
input_fields.append(f"workersMin: {workers_min}")
4547
input_fields.append(f"workersMax: {workers_max}")
4648

49+
if allowed_cuda_versions is not None:
50+
input_fields.append(f'allowedCudaVersions: "{allowed_cuda_versions}"')
51+
52+
if gpu_count is not None:
53+
input_fields.append(f"gpuCount: {gpu_count}")
54+
4755
# Format the input fields into a string
4856
input_fields_string = ", ".join(input_fields)
4957

@@ -65,11 +73,14 @@ def generate_endpoint_mutation(
6573
scalerValue
6674
workersMin
6775
workersMax
76+
allowedCudaVersions
77+
gpuCount
6878
}}
6979
}}
7080
"""
7181

7282

83+
7384
def update_endpoint_template_mutation(endpoint_id: str, template_id: str):
7485
"""Generate a string for a GraphQL mutation to update an existing endpoint's template."""
7586
input_fields = []

0 commit comments

Comments
 (0)