Skip to content

Commit 2c54589

Browse files
authored
feat(cli): add GPU count requests (#1812)
* feat(gpu)!: add resource requirements
  BREAKING CHANGE: SandboxSpec.gpu and DriverSandboxSpec.gpu were replaced with resource_requirements.gpu, changing protobuf field 9 from a bool to a message for both public and driver APIs.
  Signed-off-by: Evan Lezar <elezar@nvidia.com>
* refactor(gpu): pass requirements through sandbox create
  Pass the coupled GPU requirement object through the CLI sandbox_create boundary instead of splitting presence and count into separate arguments.
  Signed-off-by: Evan Lezar <elezar@nvidia.com>
* refactor(gpu): pass requirements to timeout message
  Pass ResourceRequirements into the provisioning timeout message helper so GPU hints are derived from the same nested request object used to create the sandbox.
  Signed-off-by: Evan Lezar <elezar@nvidia.com>
* refactor(gpu): pass driver requirements through helpers
  Thread Option<GpuResourceRequirements> through driver validation and rendering helpers instead of splitting GPU presence and count into separate arguments.
  Signed-off-by: Evan Lezar <elezar@nvidia.com>
* fix(gpu): validate exact device requests
  Require exact driver GPU device lists to be tied to a GPU request, allow a single exact device to use the default countless request, and require explicit matching counts for multi-device lists.
  Signed-off-by: Evan Lezar <elezar@nvidia.com>
---------
Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent 8e831f3 commit 2c54589

22 files changed

Lines changed: 2035 additions & 618 deletions

File tree

architecture/compute-runtimes.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,9 @@ through the driver configuration. The Helm chart defaults sandbox agents to
5555
`Unconfined` so runtime/default AppArmor profiles do not block supervisor
5656
network namespace setup on AppArmor-enabled nodes.
5757

58-
GPU requests enter the driver layer through `SandboxSpec.gpu` and
59-
`SandboxSpec.gpu_device`. Docker and Podman map default GPU requests to one
60-
concrete NVIDIA CDI device when individual CDI devices are available, use
61-
`nvidia.com/gpu=all` only for WSL2/all-only compatibility, and pass explicit
62-
driver-native device IDs through.
58+
Resource requirements enter the driver layer through `SandboxSpec.resource_requirements`. This includes a set of GPU requirements, where a user
59+
can request a specific number of GPUs or the driver-specific default behaviour.
60+
For all in-tree drivers, the default behaviour is equivalent to selecting a single GPU.
6361

6462
VM runtime state paths are derived only from driver-validated sandbox IDs
6563
matching `[A-Za-z0-9._-]{1,128}`. The gateway-owned VM driver socket uses a
@@ -98,7 +96,10 @@ users.
9896
Custom sandbox images must include the agent runtime and any system
9997
dependencies, but they should not need to include the gateway. GPU-capable
10098
images must include the user-space libraries required by the workload. The
101-
runtime still owns GPU device injection.
99+
runtime still owns GPU device injection. GPU requests are explicit, and can be
100+
refined with a driver-native device identifier or requested count; the gateway
101+
validates the request shape and each runtime enforces the GPU allocation modes it
102+
supports.
102103

103104
## Deployment Shape
104105

crates/openshell-cli/src/main.rs

Lines changed: 166 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use openshell_bootstrap::{
1919
use openshell_cli::completers;
2020
use openshell_cli::run;
2121
use openshell_cli::tls::TlsOptions;
22+
use openshell_core::proto::GpuResourceRequirements;
2223

2324
/// Resolved gateway context: name + gateway endpoint.
2425
struct GatewayContext {
@@ -28,6 +29,21 @@ struct GatewayContext {
2829
endpoint: String,
2930
}
3031

32+
/// Parsed value of the `sandbox create --gpu` flag, as produced by `parse_gpu_request`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
33+
enum GpuCliRequest {
34+
/// `--gpu` was given without a COUNT; converts to a GPU requirement with `count: None`.
DriverDefault,
35+
/// `--gpu COUNT` was given; `parse_gpu_request` only builds this with a non-zero count.
Count(u32),
36+
}
37+
38+
/// Converts the CLI-level GPU request into the `GpuResourceRequirements` proto message.
impl From<GpuCliRequest> for GpuResourceRequirements {
39+
fn from(gpu: GpuCliRequest) -> Self {
40+
match gpu {
41+
// An explicit count is forwarded unchanged.
GpuCliRequest::Count(count) => Self { count: Some(count) },
42+
// No count given: leave `count` unset rather than picking a number here.
GpuCliRequest::DriverDefault => Self { count: None },
43+
}
44+
}
45+
}
46+
3147
/// Resolve the gateway name to a [`GatewayContext`] with the gateway endpoint.
3248
///
3349
/// Resolution priority:
@@ -109,6 +125,21 @@ fn resolve_gateway(
109125
})
110126
}
111127

128+
/// Value parser for the `--gpu` flag.
///
/// Returns `GpuCliRequest::DriverDefault` for an empty string and
/// `GpuCliRequest::Count(n)` for a string that parses as a non-zero `u32`.
///
/// # Errors
///
/// Returns a message string when `value` is non-empty and does not parse as a
/// `u32`, or when it parses to `0`.
fn parse_gpu_request(value: &str) -> std::result::Result<GpuCliRequest, String> {
129+
// The `--gpu` argument declares `default_missing_value = ""`, so an empty
// string is what arrives here when the flag is passed without a COUNT.
if value.is_empty() {
130+
return Ok(GpuCliRequest::DriverDefault);
131+
}
132+
133+
// Any `u32` parse failure (non-numeric text, negative numbers, overflow)
// is reported with the same message.
let count = value
134+
.parse::<u32>()
135+
.map_err(|_| "GPU count must be a positive integer".to_string())?;
// Zero parses successfully as a `u32`, so it is rejected explicitly.
if count == 0 {
137+
return Err("GPU count must be greater than 0".to_string());
138+
}
139+
140+
Ok(GpuCliRequest::Count(count))
141+
}
142+
112143
fn resolve_gateway_name(gateway_flag: &Option<String>) -> Option<String> {
113144
gateway_flag
114145
.clone()
@@ -1227,8 +1258,11 @@ enum SandboxCommands {
12271258
editor: Option<CliEditor>,
12281259

12291260
/// Request GPU resources for the sandbox.
1230-
#[arg(long)]
1231-
gpu: bool,
1261+
///
1262+
/// Omit COUNT for the driver's default GPU selection, or pass COUNT
1263+
/// to request a specific number of GPUs.
1264+
#[arg(long, num_args = 0..=1, value_name = "COUNT", default_missing_value = "", value_parser = parse_gpu_request)]
1265+
gpu: Option<GpuCliRequest>,
12321266

12331267
/// CPU limit for the sandbox (for example: 500m, 1, 2.5).
12341268
#[arg(long)]
@@ -2636,6 +2670,7 @@ async fn main() -> Result<()> {
26362670
.map(|s| openshell_core::forward::ForwardSpec::parse(&s))
26372671
.transpose()?;
26382672
let keep = keep || !no_keep || editor.is_some() || forward.is_some();
2673+
let gpu_requirements: Option<GpuResourceRequirements> = gpu.map(Into::into);
26392674

26402675
let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?;
26412676
let endpoint = &ctx.endpoint;
@@ -2648,7 +2683,7 @@ async fn main() -> Result<()> {
26482683
&ctx.name,
26492684
&upload_specs,
26502685
keep,
2651-
gpu,
2686+
gpu_requirements,
26522687
cpu.as_deref(),
26532688
memory.as_deref(),
26542689
driver_config_json.as_deref(),
@@ -3648,6 +3683,27 @@ mod tests {
36483683
});
36493684
}
36503685

3686+
#[test]
3687+
fn gpu_cli_request_option_maps_absent_gpu_to_no_requirements() {
3688+
let gpu: Option<GpuResourceRequirements> = Option::<GpuCliRequest>::None.map(Into::into);
3689+
3690+
assert_eq!(gpu, None);
3691+
}
3692+
3693+
#[test]
3694+
fn gpu_cli_request_driver_default_converts_to_requirements() {
3695+
let gpu = GpuResourceRequirements::from(GpuCliRequest::DriverDefault);
3696+
3697+
assert_eq!(gpu.count, None);
3698+
}
3699+
3700+
#[test]
3701+
fn gpu_cli_request_count_converts_to_requirements() {
3702+
let gpu = GpuResourceRequirements::from(GpuCliRequest::Count(2));
3703+
3704+
assert_eq!(gpu.count, Some(2));
3705+
}
3706+
36513707
#[test]
36523708
fn apply_auth_uses_stored_token() {
36533709
let tmp = tempfile::tempdir().unwrap();
@@ -4529,6 +4585,113 @@ mod tests {
45294585
}
45304586
}
45314587

4588+
#[test]
4589+
fn sandbox_create_gpu_parses_driver_default() {
4590+
let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu"])
4591+
.expect("sandbox create --gpu should parse");
4592+
4593+
match cli.command {
4594+
Some(Commands::Sandbox {
4595+
command: Some(SandboxCommands::Create { gpu, .. }),
4596+
..
4597+
}) => {
4598+
assert_eq!(gpu, Some(GpuCliRequest::DriverDefault));
4599+
}
4600+
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
4601+
}
4602+
}
4603+
4604+
#[test]
4605+
fn sandbox_create_gpu_count_parses_from_gpu_flag() {
4606+
let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "2"])
4607+
.expect("sandbox create --gpu 2 should parse");
4608+
4609+
match cli.command {
4610+
Some(Commands::Sandbox {
4611+
command: Some(SandboxCommands::Create { gpu, .. }),
4612+
..
4613+
}) => {
4614+
assert_eq!(gpu, Some(GpuCliRequest::Count(2)));
4615+
}
4616+
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
4617+
}
4618+
}
4619+
4620+
#[test]
4621+
fn sandbox_create_gpu_driver_default_allows_trailing_command() {
4622+
let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "--", "claude"])
4623+
.expect("sandbox create --gpu -- claude should parse");
4624+
4625+
match cli.command {
4626+
Some(Commands::Sandbox {
4627+
command: Some(SandboxCommands::Create { gpu, command, .. }),
4628+
..
4629+
}) => {
4630+
assert_eq!(gpu, Some(GpuCliRequest::DriverDefault));
4631+
assert_eq!(command, vec!["claude".to_string()]);
4632+
}
4633+
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
4634+
}
4635+
}
4636+
4637+
#[test]
4638+
fn sandbox_create_gpu_count_allows_trailing_command() {
4639+
let cli = Cli::try_parse_from([
4640+
"openshell",
4641+
"sandbox",
4642+
"create",
4643+
"--gpu",
4644+
"2",
4645+
"--",
4646+
"claude",
4647+
])
4648+
.expect("sandbox create --gpu 2 -- claude should parse");
4649+
4650+
match cli.command {
4651+
Some(Commands::Sandbox {
4652+
command: Some(SandboxCommands::Create { gpu, command, .. }),
4653+
..
4654+
}) => {
4655+
assert_eq!(gpu, Some(GpuCliRequest::Count(2)));
4656+
assert_eq!(command, vec!["claude".to_string()]);
4657+
}
4658+
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
4659+
}
4660+
}
4661+
4662+
#[test]
4663+
fn sandbox_create_gpu_count_rejects_zero() {
4664+
let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "0"]);
4665+
4666+
assert!(result.is_err(), "sandbox create --gpu 0 should be rejected");
4667+
}
4668+
4669+
#[test]
4670+
fn sandbox_create_gpu_count_accepts_equals_syntax() {
4671+
let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu=2"])
4672+
.expect("sandbox create --gpu=2 should parse");
4673+
4674+
match cli.command {
4675+
Some(Commands::Sandbox {
4676+
command: Some(SandboxCommands::Create { gpu, .. }),
4677+
..
4678+
}) => {
4679+
assert_eq!(gpu, Some(GpuCliRequest::Count(2)));
4680+
}
4681+
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
4682+
}
4683+
}
4684+
4685+
#[test]
4686+
fn sandbox_create_gpu_count_rejects_non_integer() {
4687+
let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "many"]);
4688+
4689+
assert!(
4690+
result.is_err(),
4691+
"sandbox create --gpu many should be rejected"
4692+
);
4693+
}
4694+
45324695
#[test]
45334696
fn service_expose_accepts_positional_target_port_and_service() {
45344697
let cli = Cli::try_parse_from([

crates/openshell-cli/src/run.rs

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,13 @@ use openshell_core::proto::{
4242
GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest,
4343
GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRefreshStatusRequest,
4444
GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest,
45-
GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, HealthRequest,
46-
ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest,
47-
ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest,
48-
ListSandboxesRequest, ListServicesRequest, PlatformEvent, PolicySource, PolicyStatus, Provider,
49-
ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, ProviderProfile,
50-
ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest,
45+
GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, GpuResourceRequirements,
46+
HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest,
47+
ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest,
48+
ListSandboxProvidersRequest, ListSandboxesRequest, ListServicesRequest, PlatformEvent,
49+
PolicySource, PolicyStatus, Provider, ProviderCredentialRefreshStatus,
50+
ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileDiagnostic,
51+
ProviderProfileImportItem, RejectDraftChunkRequest, ResourceRequirements,
5152
RevokeSshSessionRequest, RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy,
5253
SandboxSpec, SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest,
5354
SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget,
@@ -123,7 +124,7 @@ fn ready_false_condition_message(
123124

124125
fn provisioning_timeout_message(
125126
timeout_secs: u64,
126-
requested_gpu: bool,
127+
resource_requirements: Option<&ResourceRequirements>,
127128
condition_message: Option<&str>,
128129
) -> String {
129130
let mut message = format!("sandbox provisioning timed out after {timeout_secs}s");
@@ -133,7 +134,7 @@ fn provisioning_timeout_message(
133134
message.push_str(condition_message);
134135
}
135136

136-
if requested_gpu {
137+
if resource_requirements.is_some_and(|requirements| requirements.gpu.is_some()) {
137138
message.push_str(
138139
". Hint: this may be because the available GPU is already in use by another sandbox.",
139140
);
@@ -1753,7 +1754,7 @@ pub async fn sandbox_create(
17531754
gateway_name: &str,
17541755
uploads: &[(String, Option<String>, bool)],
17551756
keep: bool,
1756-
gpu: bool,
1757+
gpu_requirements: Option<GpuResourceRequirements>,
17571758
cpu: Option<&str>,
17581759
memory: Option<&str>,
17591760
driver_config_json: Option<&str>,
@@ -1809,8 +1810,6 @@ pub async fn sandbox_create(
18091810
}
18101811
None => None,
18111812
};
1812-
let requested_gpu = gpu;
1813-
18141813
let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?;
18151814
let inferred_types: Vec<String> = if providers_v2_enabled {
18161815
Vec::new()
@@ -1842,9 +1841,11 @@ pub async fn sandbox_create(
18421841
None
18431842
};
18441843

1844+
let resource_requirements = gpu_requirements.map(|gpu| ResourceRequirements { gpu: Some(gpu) });
1845+
18451846
let request = CreateSandboxRequest {
18461847
spec: Some(SandboxSpec {
1847-
gpu: requested_gpu,
1848+
resource_requirements,
18481849
environment: environment.clone(),
18491850
policy,
18501851
providers: configured_providers,
@@ -1989,7 +1990,7 @@ pub async fn sandbox_create(
19891990
if remaining.is_zero() {
19901991
let timeout_message = provisioning_timeout_message(
19911992
provision_timeout.as_secs(),
1992-
requested_gpu,
1993+
resource_requirements.as_ref(),
19931994
last_condition_message.as_deref(),
19941995
);
19951996
if let Some(d) = display.as_mut() {
@@ -2008,7 +2009,7 @@ pub async fn sandbox_create(
20082009
// Timeout fired — the stream was idle for too long.
20092010
let timeout_message = provisioning_timeout_message(
20102011
provision_timeout.as_secs(),
2011-
requested_gpu,
2012+
resource_requirements.as_ref(),
20122013
last_condition_message.as_deref(),
20132014
);
20142015
if let Some(d) = display.as_mut() {
@@ -7776,9 +7777,10 @@ mod tests {
77767777
PROGRESS_STEP_STARTING_SANDBOX,
77777778
};
77787779
use openshell_core::proto::{
7779-
Provider, ProviderCredentialRefresh, ProviderCredentialRefreshStatus,
7780-
ProviderCredentialRefreshStrategy, ProviderCredentialTokenGrant, ProviderProfile,
7781-
ProviderProfileCredential, SandboxCondition, SandboxStatus, datamodel::v1::ObjectMeta,
7780+
GpuResourceRequirements, Provider, ProviderCredentialRefresh,
7781+
ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy,
7782+
ProviderCredentialTokenGrant, ProviderProfile, ProviderProfileCredential,
7783+
ResourceRequirements, SandboxCondition, SandboxStatus, datamodel::v1::ObjectMeta,
77827784
};
77837785

77847786
struct EnvVarGuard {
@@ -8482,9 +8484,12 @@ mod tests {
84828484

84838485
#[test]
84848486
fn provisioning_timeout_message_includes_condition_and_gpu_hint() {
8487+
let resource_requirements = ResourceRequirements {
8488+
gpu: Some(GpuResourceRequirements { count: None }),
8489+
};
84858490
let message = provisioning_timeout_message(
84868491
120,
8487-
true,
8492+
Some(&resource_requirements),
84888493
Some("DependenciesNotReady: Pod exists with phase: Pending; Service Exists"),
84898494
);
84908495

@@ -8495,7 +8500,15 @@ mod tests {
84958500

84968501
#[test]
84978502
fn provisioning_timeout_message_omits_gpu_hint_for_non_gpu_requests() {
8498-
let message = provisioning_timeout_message(120, false, None);
8503+
let message = provisioning_timeout_message(120, None, None);
8504+
8505+
assert_eq!(message, "sandbox provisioning timed out after 120s");
8506+
}
8507+
8508+
#[test]
8509+
fn provisioning_timeout_message_omits_gpu_hint_without_gpu_requirements() {
8510+
let resource_requirements = ResourceRequirements { gpu: None };
8511+
let message = provisioning_timeout_message(120, Some(&resource_requirements), None);
84998512

85008513
assert_eq!(message, "sandbox provisioning timed out after 120s");
85018514
}

0 commit comments

Comments
 (0)