Skip to content

Commit 74c0a9f

Browse files
committed
[fix] client sdk type and req in aggregator and key server
1 parent 0bc3fd9 commit 74c0a9f

8 files changed

Lines changed: 50 additions & 15 deletions

File tree

crates/key-server/src/aggregator/server.rs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,10 @@ fn default_key_server_version_requirement() -> VersionReq {
6464
VersionReq::parse(">=0.5.14").expect("Failed to parse default key server version requirement")
6565
}
6666

67+
fn default_rust_sdk_version_requirement() -> VersionReq {
68+
VersionReq::parse(">=0.0.0").expect("Failed to parse default Rust SDK version requirement")
69+
}
70+
6771
/// Configuration file format for aggregator server.
6872
#[derive(Clone, Deserialize)]
6973
struct AggregatorOptions {
@@ -79,6 +83,10 @@ struct AggregatorOptions {
7983
#[serde(default = "default_ts_sdk_version_requirement")]
8084
ts_sdk_version_requirement: VersionReq,
8185

86+
/// The minimum version of the SDK that is required to use this aggregator.
87+
#[serde(default = "default_rust_sdk_version_requirement")]
88+
rust_sdk_version_requirement: VersionReq,
89+
8290
/// The minimum version of the key server that is required by this aggregator.
8391
#[serde(default = "default_key_server_version_requirement")]
8492
key_server_version_requirement: VersionReq,
@@ -112,16 +120,20 @@ impl AppState {
112120
sdk_type: Option<&HeaderValue>,
113121
) -> Result<(), InternalError> {
114122
let version = Version::parse(version).map_err(|_| InvalidSDKVersion)?;
115-
let sdk_type = ClientSdkType::from_header(sdk_type.and_then(|v| v.to_str().ok()));
123+
let sdk_type = ClientSdkType::from_header(sdk_type.and_then(|v| v.to_str().ok()))?;
116124

117125
match sdk_type {
118126
ClientSdkType::TypeScript => {
119127
if !self.options.ts_sdk_version_requirement.matches(&version) {
120128
return Err(DeprecatedSDKVersion);
121129
}
122130
}
123-
_ => {
124-
// TODO: Add support for other SDK types.
131+
ClientSdkType::Rust => {
132+
if !self.options.rust_sdk_version_requirement.matches(&version) {
133+
return Err(DeprecatedSDKVersion);
134+
}
135+
}
136+
ClientSdkType::Aggregator => {
125137
return Err(InvalidSDKType);
126138
}
127139
}
@@ -570,6 +582,7 @@ mod tests {
570582
node_url: None,
571583
key_server_object_id: Address::from([0u8; 32]),
572584
ts_sdk_version_requirement: VersionReq::parse(">=0.9.0").unwrap(),
585+
rust_sdk_version_requirement: VersionReq::parse(">=0.0.0").unwrap(),
573586
key_server_version_requirement: VersionReq::parse(">=0.5.14").unwrap(),
574587
};
575588
let grpc_client = SuiGrpcClient::new(options.node_url()).unwrap();
@@ -607,6 +620,7 @@ mod tests {
607620
let (request, _, _) = create_test_fetch_key_request(&mut thread_rng());
608621

609622
let mut headers = HeaderMap::new();
623+
headers.insert(HEADER_CLIENT_SDK_TYPE, "typescript".parse().unwrap());
610624
headers.insert(HEADER_CLIENT_SDK_VERSION, "0.3.0".parse().unwrap()); // Too old
611625
let result = handle_fetch_key(State(state), headers, Json(request)).await;
612626

@@ -638,6 +652,7 @@ mod tests {
638652
let (request, _, _) = create_test_fetch_key_request(&mut thread_rng());
639653

640654
let mut headers = HeaderMap::new();
655+
headers.insert(HEADER_CLIENT_SDK_TYPE, "typescript".parse().unwrap());
641656
headers.insert(HEADER_CLIENT_SDK_VERSION, "0.9.6".parse().unwrap());
642657
let result = handle_fetch_key(State(state), headers, Json(request)).await;
643658

@@ -669,6 +684,7 @@ mod tests {
669684
let (request, _, _) = create_test_fetch_key_request(&mut thread_rng());
670685

671686
let mut headers = HeaderMap::new();
687+
headers.insert(HEADER_CLIENT_SDK_TYPE, "typescript".parse().unwrap());
672688
headers.insert(HEADER_CLIENT_SDK_VERSION, "0.9.6".parse().unwrap());
673689
let result = handle_fetch_key(State(state), headers, Json(request)).await;
674690
let response = result.unwrap().into_response();
@@ -732,6 +748,7 @@ mod tests {
732748

733749
// Call handle_fetch_key and check majority error.
734750
let mut headers = HeaderMap::new();
751+
headers.insert(HEADER_CLIENT_SDK_TYPE, "typescript".parse().unwrap());
735752
headers.insert(HEADER_CLIENT_SDK_VERSION, "0.9.6".parse().unwrap());
736753
let result = handle_fetch_key(State(state), headers, Json(request)).await;
737754
match result {

crates/key-server/src/common.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ use axum::response::Response;
66
use serde::{Deserialize, Serialize};
77
use sui_types::base_types::ObjectID;
88

9+
use crate::errors::InternalError;
10+
911
/// Network configuration.
1012
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
1113
pub enum Network {
@@ -50,6 +52,9 @@ pub const SDK_TYPE_AGGREGATOR: &str = "aggregator";
5052
/// SDK type value for TypeScript clients.
5153
pub const SDK_TYPE_TYPESCRIPT: &str = "typescript";
5254

55+
/// SDK type value for Rust clients.
56+
pub const SDK_TYPE_RUST: &str = "rust";
57+
5358
/// Get the git version.
5459
/// Based on https://github.com/MystenLabs/walrus/blob/7e282a681e6530ae4073210b33cac915fab439fa/crates/walrus-service/src/common/utils.rs#L69
5560
#[macro_export]
@@ -80,16 +85,16 @@ macro_rules! git_version {
8085
pub enum ClientSdkType {
8186
Aggregator,
8287
TypeScript,
83-
Other,
88+
Rust,
8489
}
8590

8691
impl ClientSdkType {
87-
pub fn from_header(header_value: Option<&str>) -> Self {
92+
pub fn from_header(header_value: Option<&str>) -> Result<ClientSdkType, InternalError> {
8893
match header_value {
89-
Some(SDK_TYPE_AGGREGATOR) => ClientSdkType::Aggregator,
90-
Some(SDK_TYPE_TYPESCRIPT) => ClientSdkType::TypeScript,
91-
Some(_) => ClientSdkType::Other,
92-
None => ClientSdkType::TypeScript, // Default to TypeScript for backward compatibility
94+
Some(SDK_TYPE_AGGREGATOR) => Ok(ClientSdkType::Aggregator),
95+
Some(SDK_TYPE_TYPESCRIPT) => Ok(ClientSdkType::TypeScript),
96+
Some(SDK_TYPE_RUST) => Ok(ClientSdkType::Rust),
97+
_ => Err(InternalError::InvalidSDKType),
9398
}
9499
}
95100
}

crates/key-server/src/key_server_options.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ pub struct KeyServerOptions {
125125
#[serde(default = "default_ts_sdk_version_requirement")]
126126
pub ts_sdk_version_requirement: VersionReq,
127127

128+
/// The minimum version of the Rust SDK that is required to use this key server.
129+
#[serde(default = "default_rust_sdk_version_requirement")]
130+
pub rust_sdk_version_requirement: VersionReq,
131+
128132
/// The minimum version of the aggregator that is required to use this key server.
129133
#[serde(default = "default_aggregator_version_requirement")]
130134
pub aggregator_version_requirement: VersionReq,
@@ -184,6 +188,7 @@ impl KeyServerOptions {
184188
node_url: None,
185189
ts_sdk_version_requirement: default_ts_sdk_version_requirement(),
186190
aggregator_version_requirement: default_aggregator_version_requirement(),
191+
rust_sdk_version_requirement: default_rust_sdk_version_requirement(),
187192
server_mode: ServerMode::Open {
188193
key_server_object_id,
189194
},
@@ -203,6 +208,7 @@ impl KeyServerOptions {
203208
node_url: None,
204209
ts_sdk_version_requirement: default_ts_sdk_version_requirement(),
205210
aggregator_version_requirement: default_aggregator_version_requirement(),
211+
rust_sdk_version_requirement: default_rust_sdk_version_requirement(),
206212
server_mode: ServerMode::Open {
207213
key_server_object_id: ObjectID::random(),
208214
},
@@ -336,6 +342,9 @@ fn default_aggregator_version_requirement() -> VersionReq {
336342
VersionReq::parse(">=1000.0.0").expect("Failed to parse default aggregator version requirement")
337343
}
338344

345+
fn default_rust_sdk_version_requirement() -> VersionReq {
346+
VersionReq::parse(">=0.0.0").expect("Failed to parse default Rust SDK version requirement")
347+
}
339348
#[test]
340349
fn test_parse_open_config() {
341350
use std::str::FromStr;

crates/key-server/src/server.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use crate::types::{IbeMasterKey, MasterKeyPOP, Network};
1616
use crate::InternalError::DeprecatedSDKVersion;
1717
use anyhow::{Context, Result};
1818
use axum::extract::{Query, Request};
19-
use axum::http::HeaderMap;
19+
use axum::http::{HeaderMap, HeaderValue};
2020
use axum::middleware::{from_fn_with_state, map_response, Next};
2121
use axum::response::Response;
2222
use axum::routing::{get, post};
@@ -849,14 +849,15 @@ impl MyState {
849849
fn validate_sdk_version(
850850
&self,
851851
version_string: &str,
852-
sdk_type: ClientSdkType,
852+
sdk_type: Option<&HeaderValue>,
853853
) -> Result<(), InternalError> {
854+
let sdk_type = ClientSdkType::from_header(sdk_type.and_then(|t| t.to_str().ok()))?;
854855
let version = Version::parse(version_string).map_err(|_| InvalidSDKVersion)?;
855856

856857
let requirement = match sdk_type {
857858
ClientSdkType::Aggregator => &self.server.options.aggregator_version_requirement,
858859
ClientSdkType::TypeScript => &self.server.options.ts_sdk_version_requirement,
859-
ClientSdkType::Other => return Ok(()),
860+
ClientSdkType::Rust => &self.server.options.rust_sdk_version_requirement,
860861
};
861862

862863
if !requirement.matches(&version) {
@@ -888,8 +889,6 @@ async fn handle_request_headers(
888889
request.headers().get("Client-Target-Api-Version")
889890
);
890891

891-
let sdk_type = ClientSdkType::from_header(sdk_type.and_then(|t| t.to_str().ok()));
892-
893892
version
894893
.ok_or(MissingRequiredHeader(HEADER_CLIENT_SDK_VERSION.to_string()))
895894
.and_then(|v| v.to_str().map_err(|_| InvalidSDKVersion))

crates/key-server/src/tests/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ impl SealTestCluster {
183183
rgp_update_interval: Duration::from_secs(60),
184184
ts_sdk_version_requirement: VersionReq::from_str(">=0.4.6").unwrap(),
185185
aggregator_version_requirement: VersionReq::from_str(">=0.5.15").unwrap(),
186+
rust_sdk_version_requirement: VersionReq::from_str(">=0.0.0").unwrap(),
186187
allowed_staleness,
187188
session_key_ttl_max: from_mins(30),
188189
rpc_config: RpcConfig::default(),

crates/key-server/src/tests/server.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ async fn test_service() {
151151
"http://{addr}/v1/service?service_id={}",
152152
key_server_object_id.as_str()
153153
))
154+
.header(HEADER_CLIENT_SDK_TYPE, "typescript")
154155
.header(HEADER_CLIENT_SDK_VERSION, "0.3.0") // Too old (requires >=0.4.6)
155156
.body(Body::empty())
156157
.unwrap(),
@@ -196,6 +197,7 @@ async fn test_service() {
196197
"http://{addr}/v1/service?service_id={}",
197198
key_server_object_id.as_str()
198199
))
200+
.header(HEADER_CLIENT_SDK_TYPE, "typescript")
199201
.header(HEADER_CLIENT_SDK_VERSION, "0.4.11")
200202
.body(Body::empty())
201203
.unwrap(),
@@ -343,6 +345,7 @@ async fn test_fetch_key() {
343345
Request::builder()
344346
.uri(format!("http://{addr}/v1/fetch_key",))
345347
.method("POST")
348+
.header(HEADER_CLIENT_SDK_TYPE, "typescript")
346349
.header(HEADER_CLIENT_SDK_VERSION, "0.4.11")
347350
.header("Content-Type", "application/json")
348351
.body(Body::from(json!(request).to_string()))

crates/key-server/src/tests/test_utils.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ pub(crate) async fn create_test_server(
4949
metrics_host_port: 0,
5050
rgp_update_interval: Duration::from_secs(60),
5151
ts_sdk_version_requirement: VersionReq::from_str(">=0.4.6").unwrap(),
52+
rust_sdk_version_requirement: VersionReq::from_str(">=0.0.0").unwrap(),
5253
aggregator_version_requirement: VersionReq::from_str(">=0.5.15").unwrap(),
5354
allowed_staleness: Duration::from_secs(120),
5455
session_key_ttl_max: from_mins(30),

crates/seal-cli/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ async fn main() -> FastCryptoResult<()> {
575575
match client
576576
.post(format!("{}/v1/fetch_key", server.url))
577577
.header("Client-Sdk-Type", "rust")
578-
.header("Client-Sdk-Version", "1.0.0")
578+
.header("Client-Sdk-Version", "0.0.0")
579579
.header("Content-Type", "application/json")
580580
.body(Body::from(
581581
request.to_json_string().expect("should not fail"),

0 commit comments

Comments
 (0)