diff --git a/Cargo.lock b/Cargo.lock index 301fa18c..f0fbc0d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -960,6 +960,26 @@ dependencies = [ "serde", ] +[[package]] +name = "bincode" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36eaf5d7b090263e8150820482d5d93cd964a81e4019913c972f4edcc6edb740" +dependencies = [ + "bincode_derive", + "serde", + "unty", +] + +[[package]] +name = "bincode_derive" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf95709a440f45e986983918d0e8a1f30a9b1df04918fc828670606804ac3c09" +dependencies = [ + "virtue", +] + [[package]] name = "bindgen" version = "0.69.5" @@ -2179,13 +2199,16 @@ dependencies = [ "dstack-guest-agent-rpc", "dstack-kms-rpc", "dstack-types", + "flate2", "fs-err", "futures", "git-version", "hex", "hickory-resolver 0.24.4", + "http-body-util", "http-client", "hyper", + "hyper-rustls", "hyper-util", "insta", "ipnet", @@ -2199,6 +2222,7 @@ dependencies = [ "rand 0.8.5", "reqwest", "rinja", + "rmp-serde", "rocket", "rustls", "safe-write", @@ -2208,10 +2232,14 @@ dependencies = [ "sha2 0.10.9", "shared_child", "smallvec", + "tempfile", "tokio", "tokio-rustls", "tracing", "tracing-subscriber", + "uuid", + "wavekv", + "x509-parser", ] [[package]] @@ -2839,9 +2867,9 @@ dependencies = [ [[package]] name = "fs-err" -version = "3.1.1" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d7be93788013f265201256d58f04936a8079ad5dc898743aa20525f503b683" +checksum = "62d91fd049c123429b018c47887d3f75a265540dd3c30ba9cb7bae9197edb03a" dependencies = [ "autocfg", ] @@ -4213,7 +4241,7 @@ checksum = "2044d8bd5489b199890c3dbf38d4c8f50f3a5a38833986808b14e2367fe267fa" dependencies = [ "aes 0.7.5", "base64 0.13.1", - "bincode", + "bincode 1.3.3", "crossterm", "hmac 0.11.0", "pbkdf2", @@ -5852,6 +5880,28 @@ dependencies = [ "rustc-hex", ] +[[package]] +name = "rmp" +version = "0.8.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "228ed7c16fa39782c3b3468e974aec2795e9089153cd08ee2e9aefb3613334c4" +dependencies = [ + "byteorder", + "num-traits", + "paste", +] + +[[package]] +name = "rmp-serde" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e599a477cf9840e92f2cde9a7189e67b42c57532749bf90aea6ec10facd4db" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + [[package]] name = "rocket" version = "0.6.0-dev" @@ -7680,6 +7730,12 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "unty" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" + [[package]] name = "url" version = "2.5.4" @@ -7732,6 +7788,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "virtue" +version = "0.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "051eb1abcf10076295e815102942cc58f9d5e3b4560e46e53c21e8ff6f3af7b1" + [[package]] name = "void" version = "1.0.2" @@ -7877,6 +7939,29 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wavekv" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"cf9b73bc556dfdb7ef33617a9d477b803198db43ea3df25463efaf43d4986fe8" +dependencies = [ + "anyhow", + "bincode 2.0.1", + "chrono", + "crc32fast", + "dashmap", + "fs-err", + "futures", + "hex", + "rmp-serde", + "serde", + "serde-human-bytes", + "serde_json", + "sha2 0.10.9", + "tokio", + "tracing", +] + [[package]] name = "web-sys" version = "0.3.77" diff --git a/Cargo.toml b/Cargo.toml index 084ded65..9f09d05d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,6 +83,7 @@ sodiumbox = { path = "sodiumbox" } serde-duration = { path = "serde-duration" } dstack-mr = { path = "dstack-mr" } size-parser = { path = "size-parser" } +wavekv = "1.0.0" # Core dependencies anyhow = { version = "1.0.97", default-features = false } @@ -105,6 +106,7 @@ sd-notify = "0.4.5" jemallocator = "0.5.4" # Serialization/Parsing +flate2 = "1.0" borsh = { version = "1.5.7", default-features = false, features = ["derive"] } bon = { version = "3.4.0", default-features = false } base64 = "0.22.1" @@ -118,6 +120,7 @@ scale = { version = "3.7.4", package = "parity-scale-codec", features = [ ] } serde = { version = "1.0.219", features = ["derive"], default-features = false } serde-human-bytes = "0.1.0" +rmp-serde = "1.3.0" serde_json = { version = "1.0.140", default-features = false } serde_ini = "0.2.0" toml = "0.8.20" @@ -137,6 +140,11 @@ hyper-util = { version = "0.1.10", features = [ "client-legacy", "http1", ] } +hyper-rustls = { version = "0.27", default-features = false, features = [ + "ring", + "http1", + "tls12", +] } hyperlocal = "0.9.1" ipnet = { version = "2.11.0", features = ["serde"] } reqwest = { version = "0.12.14", default-features = false, features = [ diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index d150126c..8b046b29 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -10,7 +10,7 @@ edition.workspace = true license.workspace = true [dependencies] -rocket = { workspace = true, features = ["mtls"] } +rocket = { workspace = true, features = ["mtls", "json"] } tracing.workspace = true tracing-subscriber.workspace = true anyhow.workspace = true @@ -48,11 +48,19 @@ dstack-types.workspace = true serde-duration.workspace = true reqwest = { workspace = true, features = ["json"] } hyper = { workspace = true, features = ["server", "http1"] } -hyper-util = { version = "0.1", features = ["tokio"] } +hyper-util = { workspace = true, features = ["tokio"] } +hyper-rustls.workspace = true +http-body-util.workspace = true +x509-parser.workspace = true jemallocator.workspace = true +wavekv.workspace = true +flate2.workspace = true +uuid = { workspace = true, features = ["v4"] } +rmp-serde.workspace = true [target.'cfg(unix)'.dependencies] nix = { workspace = true, features = ["resource"] } [dev-dependencies] insta.workspace = true +tempfile.workspace = true diff --git a/gateway/dstack-app/builder/entrypoint.sh b/gateway/dstack-app/builder/entrypoint.sh index 7696b95d..76764cd5 100755 --- a/gateway/dstack-app/builder/entrypoint.sh +++ b/gateway/dstack-app/builder/entrypoint.sh @@ -51,6 +51,13 @@ validate_env "$CF_API_TOKEN" validate_env "$CF_ZONE_ID" validate_env "$SRV_DOMAIN" validate_env "$WG_ENDPOINT" +validate_env "$NODE_ID" + +# Validate $NODE_ID, must be a number +if [[ ! "$NODE_ID" =~ ^[0-9]+$ ]]; then + echo "Invalid NODE_ID: $NODE_ID" + exit 1 +fi # Validate $SUBNET_INDEX, valid range is 0-15 if [[ ! 
"$SUBNET_INDEX" =~ ^[0-9]+$ ]] || [ "$SUBNET_INDEX" -lt 0 ] || [ "$SUBNET_INDEX" -gt 15 ]; then @@ -80,8 +87,7 @@ echo "RPC_DOMAIN: $RPC_DOMAIN" cat >$CONFIG_PATH < node_connections = 2; +} + +// Node status entry +message NodeStatusEntry { + uint32 node_id = 1; + string status = 2; // "up" or "down" +} + +// Get node statuses response +message GetNodeStatusesResponse { + repeated NodeStatusEntry statuses = 1; +} + service Admin { // Get the status of the gateway. rpc Status(google.protobuf.Empty) returns (StatusResponse) {} @@ -187,4 +363,16 @@ service Admin { rpc SetCaa(google.protobuf.Empty) returns (google.protobuf.Empty) {} // Summary API for inspect. rpc GetMeta(google.protobuf.Empty) returns (GetMetaResponse) {} + // Set a node's sync URL - used for dynamic peer management + rpc SetNodeUrl(SetNodeUrlRequest) returns (google.protobuf.Empty) {} + // Set a node's status (up/down) + rpc SetNodeStatus(SetNodeStatusRequest) returns (google.protobuf.Empty) {} + // Get WaveKV sync status + rpc WaveKvStatus(google.protobuf.Empty) returns (WaveKvStatusResponse) {} + // Get instance handshakes from all nodes + rpc GetInstanceHandshakes(GetInstanceHandshakesRequest) returns (GetInstanceHandshakesResponse) {} + // Get global connections statistics + rpc GetGlobalConnections(google.protobuf.Empty) returns (GlobalConnectionsStats) {} + // Get all node statuses + rpc GetNodeStatuses(google.protobuf.Empty) returns (GetNodeStatusesResponse) {} } diff --git a/gateway/src/admin_service.rs b/gateway/src/admin_service.rs index 541dee0d..b317f406 100644 --- a/gateway/src/admin_service.rs +++ b/gateway/src/admin_service.rs @@ -8,14 +8,16 @@ use std::time::{SystemTime, UNIX_EPOCH}; use anyhow::{Context, Result}; use dstack_gateway_rpc::{ admin_server::{AdminRpc, AdminServer}, - GetInfoRequest, GetInfoResponse, GetMetaResponse, HostInfo, RenewCertResponse, StatusResponse, + GetInfoRequest, GetInfoResponse, GetInstanceHandshakesRequest, GetInstanceHandshakesResponse, + GetMetaResponse, GetNodeStatusesResponse, GlobalConnectionsStats, HandshakeEntry, HostInfo, + LastSeenEntry, NodeStatusEntry, PeerSyncStatus as ProtoPeerSyncStatus, RenewCertResponse, + SetNodeStatusRequest, SetNodeUrlRequest, StatusResponse, StoreSyncStatus, WaveKvStatusResponse, }; use ra_rpc::{CallContext, RpcCall}; +use tracing::info; +use wavekv::node::NodeStatus as WaveKvNodeStatus; -use crate::{ - main_service::{encode_ts, Proxy}, - proxy::NUM_CONNECTIONS, -}; +use crate::{kv::NodeStatus, main_service::Proxy, proxy::NUM_CONNECTIONS}; pub struct AdminRpcHandler { state: Proxy, @@ -30,28 +32,28 @@ impl AdminRpcHandler { .state .instances .values() - .map(|instance| HostInfo { - instance_id: instance.id.clone(), - ip: instance.ip.to_string(), - app_id: instance.app_id.clone(), - base_domain: base_domain.clone(), - port: state.config.proxy.listen_port as u32, - latest_handshake: encode_ts(instance.last_seen), - num_connections: instance.num_connections(), + .map(|instance| { + // Get global latest_handshake from KvStore (max across all nodes) + let latest_handshake = state + .get_instance_latest_handshake(&instance.id) + .unwrap_or(0); + HostInfo { + instance_id: instance.id.clone(), + ip: instance.ip.to_string(), + app_id: instance.app_id.clone(), + base_domain: base_domain.clone(), + port: state.config.proxy.listen_port as u32, + latest_handshake, + num_connections: instance.num_connections(), + } }) .collect::>(); - let nodes = state - .state - .nodes - .values() - .cloned() - .map(Into::into) - .collect::>(); Ok(StatusResponse { + id: 
state.config.sync.node_id, url: state.config.sync.my_url.clone(), - id: state.config.id(), + uuid: state.config.uuid(), bootnode_url: state.config.sync.bootnode.clone(), - nodes, + nodes: state.get_all_nodes(), hosts, num_connections: NUM_CONNECTIONS.load(Ordering::Relaxed), }) @@ -146,6 +148,163 @@ impl AdminRpc for AdminRpcHandler { online: online as u32, }) } + + async fn set_node_url(self, request: SetNodeUrlRequest) -> Result<()> { + let kv_store = self.state.kv_store(); + kv_store.register_peer_url(request.id, &request.url)?; + info!("Updated peer URL: node {} -> {}", request.id, request.url); + Ok(()) + } + + async fn set_node_status(self, request: SetNodeStatusRequest) -> Result<()> { + let kv_store = self.state.kv_store(); + let status = match request.status.as_str() { + "up" => NodeStatus::Up, + "down" => NodeStatus::Down, + _ => anyhow::bail!("invalid status: expected 'up' or 'down'"), + }; + kv_store.set_node_status(request.id, status)?; + info!("Updated node status: node {} -> {:?}", request.id, status); + Ok(()) + } + + async fn wave_kv_status(self) -> Result { + let kv_store = self.state.kv_store(); + + let persistent_status = kv_store.persistent().read().status(); + let ephemeral_status = kv_store.ephemeral().read().status(); + + let get_peer_last_seen = |peer_id: u32| -> Vec<(u32, u64)> { + kv_store + .get_node_last_seen_by_all(peer_id) + .into_iter() + .collect() + }; + + Ok(WaveKvStatusResponse { + enabled: self.state.config.sync.enabled, + persistent: Some(build_store_status( + "persistent", + persistent_status, + &get_peer_last_seen, + )), + ephemeral: Some(build_store_status( + "ephemeral", + ephemeral_status, + &get_peer_last_seen, + )), + }) + } + + async fn get_instance_handshakes( + self, + request: GetInstanceHandshakesRequest, + ) -> Result { + let kv_store = self.state.kv_store(); + let handshakes = kv_store.get_instance_handshakes(&request.instance_id); + + let entries = handshakes + .into_iter() + .map(|(observer_node_id, timestamp)| HandshakeEntry { + observer_node_id, + timestamp, + }) + .collect(); + + Ok(GetInstanceHandshakesResponse { + handshakes: entries, + }) + } + + async fn get_global_connections(self) -> Result { + let state = self.state.lock(); + let kv_store = self.state.kv_store(); + + let mut node_connections = std::collections::HashMap::new(); + let mut total_connections = 0u64; + + // Iterate through all instances and sum up connections per node + for instance_id in state.state.instances.keys() { + // Get connection counts from ephemeral KV for this instance + let conn_prefix = format!("conn/{}/", instance_id); + for (key, count) in kv_store + .ephemeral() + .read() + .iter_by_prefix(&conn_prefix) + .filter_map(|(k, entry)| { + let value = entry.value.as_ref()?; + let count: u64 = rmp_serde::decode::from_slice(value).ok()?; + Some((k.to_string(), count)) + }) + { + // Parse node_id from key: "conn/{instance_id}/{node_id}" + if let Some(node_id_str) = key.strip_prefix(&conn_prefix) { + if let Ok(node_id) = node_id_str.parse::() { + *node_connections.entry(node_id).or_insert(0) += count; + total_connections += count; + } + } + } + } + + Ok(GlobalConnectionsStats { + total_connections, + node_connections, + }) + } + + async fn get_node_statuses(self) -> Result { + let kv_store = self.state.kv_store(); + let statuses = kv_store.load_all_node_statuses(); + + let entries = statuses + .into_iter() + .map(|(node_id, status)| { + let status_str = match status { + NodeStatus::Up => "up", + NodeStatus::Down => "down", + }; + NodeStatusEntry { + node_id, 
+ status: status_str.to_string(), + } + }) + .collect(); + + Ok(GetNodeStatusesResponse { statuses: entries }) + } +} + +fn build_store_status( + name: &str, + status: WaveKvNodeStatus, + get_peer_last_seen: &impl Fn(u32) -> Vec<(u32, u64)>, +) -> StoreSyncStatus { + StoreSyncStatus { + name: name.to_string(), + node_id: status.id, + n_keys: status.n_kvs as u64, + next_seq: status.next_seq, + dirty: status.dirty, + wal_enabled: status.wal, + peers: status + .peers + .into_iter() + .map(|p| { + let last_seen = get_peer_last_seen(p.id) + .into_iter() + .map(|(node_id, timestamp)| LastSeenEntry { node_id, timestamp }) + .collect(); + ProtoPeerSyncStatus { + id: p.id, + local_ack: p.ack, + peer_ack: p.pack, + buffered_logs: p.logs as u64, + last_seen, + } + }) + .collect(), + } } impl RpcCall for AdminRpcHandler { diff --git a/gateway/src/config.rs b/gateway/src/config.rs index 3a3d88db..eb04691d 100644 --- a/gateway/src/config.rs +++ b/gateway/src/config.rs @@ -125,11 +125,22 @@ pub struct SyncConfig { #[serde(with = "serde_duration")] pub interval: Duration, #[serde(with = "serde_duration")] - pub broadcast_interval: Duration, - #[serde(with = "serde_duration")] pub timeout: Duration, pub my_url: String, + /// The URL of the bootnode used to fetch initial peer list when joining the network pub bootnode: String, + /// WaveKV node ID for this gateway (must be unique across cluster) + pub node_id: u32, + /// Data directory for WaveKV persistence + pub data_dir: String, + /// Interval for periodic WAL persistence (default: 10s) + #[serde(with = "serde_duration")] + pub persist_interval: Duration, + /// Enable periodic sync of instance connections to KV store + pub sync_connections_enabled: bool, + /// Interval for syncing instance connections to KV store + #[serde(with = "serde_duration")] + pub sync_connections_interval: Duration, } #[derive(Debug, Clone, Deserialize)] @@ -139,16 +150,25 @@ pub struct Config { pub certbot: CertbotConfig, pub pccs_url: Option, pub recycle: RecycleConfig, - pub state_path: String, pub set_ulimit: bool, pub rpc_domain: String, pub kms_url: String, pub admin: AdminConfig, - pub run_in_dstack: bool, + /// Debug server configuration (separate port for debug RPCs) + pub debug: DebugConfig, pub sync: SyncConfig, pub auth: AuthConfig, } +#[derive(Debug, Clone, Deserialize, Default)] +pub struct DebugConfig { + /// Enable debug server + #[serde(default)] + pub insecure_enable_debug_rpc: bool, + #[serde(default)] + pub insecure_skip_attestation: bool, +} + #[derive(Debug, Clone, Deserialize)] pub struct AuthConfig { pub enabled: bool, @@ -158,11 +178,41 @@ pub struct AuthConfig { } impl Config { - pub fn id(&self) -> Vec { - use sha2::{Digest, Sha256}; - let mut hasher = Sha256::new(); - hasher.update(self.wg.public_key.as_bytes()); - hasher.finalize()[..20].to_vec() + /// Get or generate a unique node UUID. + /// The UUID is stored in `{data_dir}/node_uuid` and persisted across restarts. 
+ pub fn uuid(&self) -> Vec { + use std::fs; + use std::path::Path; + + let uuid_path = Path::new(&self.sync.data_dir).join("node_uuid"); + + // Try to read existing UUID + if let Ok(content) = fs::read_to_string(&uuid_path) { + if let Ok(uuid) = uuid::Uuid::parse_str(content.trim()) { + return uuid.as_bytes().to_vec(); + } + } + + // Generate new UUID + let uuid = uuid::Uuid::new_v4(); + + // Ensure directory exists + if let Some(parent) = uuid_path.parent() { + let _ = fs::create_dir_all(parent); + } + + // Save UUID to file + if let Err(err) = fs::write(&uuid_path, uuid.to_string()) { + tracing::warn!( + "failed to save node UUID to {}: {}", + uuid_path.display(), + err + ); + } else { + tracing::info!("generated new node UUID: {}", uuid); + } + + uuid.as_bytes().to_vec() } } @@ -210,31 +260,6 @@ pub struct CertbotConfig { pub renew_timeout: Duration, } -impl CertbotConfig { - fn to_bot_config(&self) -> certbot::CertBotConfig { - let workdir = certbot::WorkDir::new(&self.workdir); - certbot::CertBotConfig::builder() - .auto_create_account(true) - .cert_dir(workdir.backup_dir()) - .cert_file(workdir.cert_path()) - .key_file(workdir.key_path()) - .credentials_file(workdir.account_credentials_path()) - .acme_url(self.acme_url.clone()) - .cert_subject_alt_names(vec![self.domain.clone()]) - .cf_zone_id(self.cf_zone_id.clone()) - .cf_api_token(self.cf_api_token.clone()) - .renew_interval(self.renew_interval) - .renew_timeout(self.renew_timeout) - .renew_expires_in(self.renew_before_expiration) - .auto_set_caa(self.auto_set_caa) - .build() - } - - pub async fn build_bot(&self) -> Result { - self.to_bot_config().build_bot().await - } -} - pub const DEFAULT_CONFIG: &str = include_str!("../gateway.toml"); pub fn load_config_figment(config_file: Option<&str>) -> Figment { load_config("gateway", DEFAULT_CONFIG, config_file, false) diff --git a/gateway/src/debug_service.rs b/gateway/src/debug_service.rs new file mode 100644 index 00000000..cc2a4a78 --- /dev/null +++ b/gateway/src/debug_service.rs @@ -0,0 +1,148 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! 
Debug service for testing - runs on a separate port when debug.enabled=true + +use anyhow::Result; +use dstack_gateway_rpc::{ + debug_server::{DebugRpc, DebugServer}, + DebugProxyStateResponse, DebugRegisterCvmRequest, DebugSyncDataResponse, InfoResponse, + InstanceEntry, NodeInfoEntry, PeerAddrEntry, ProxyStateInstance, RegisterCvmResponse, +}; +use ra_rpc::{CallContext, RpcCall}; +use tracing::warn; + +use crate::main_service::Proxy; + +pub struct DebugRpcHandler { + state: Proxy, +} + +impl DebugRpcHandler { + pub fn new(state: Proxy) -> Self { + Self { state } + } +} + +impl DebugRpc for DebugRpcHandler { + async fn register_cvm(self, request: DebugRegisterCvmRequest) -> Result { + warn!( + "Debug register CVM: app_id={}, instance_id={}", + request.app_id, request.instance_id + ); + self.state.do_register_cvm( + &request.app_id, + &request.instance_id, + &request.client_public_key, + ) + } + + async fn info(self) -> Result { + let config = &self.state.config; + Ok(InfoResponse { + base_domain: config.proxy.base_domain.clone(), + external_port: config.proxy.external_port as u32, + app_address_ns_prefix: config.proxy.app_address_ns_prefix.clone(), + }) + } + + async fn get_sync_data(self) -> Result { + let kv_store = self.state.kv_store(); + let my_node_id = kv_store.my_node_id(); + + // Get all peer addresses + let peer_addrs: Vec = kv_store + .get_all_peer_addrs() + .into_iter() + .map(|(node_id, url)| PeerAddrEntry { + node_id: node_id as u64, + url, + }) + .collect(); + + // Get all node info + let nodes: Vec = kv_store + .load_all_nodes() + .into_iter() + .map(|(node_id, data)| NodeInfoEntry { + node_id: node_id as u64, + url: data.url, + wg_public_key: data.wg_public_key, + wg_endpoint: data.wg_endpoint, + wg_ip: data.wg_ip, + }) + .collect(); + + // Get all instances + let instances: Vec = kv_store + .load_all_instances() + .into_iter() + .map(|(instance_id, data)| InstanceEntry { + instance_id, + app_id: data.app_id, + ip: data.ip.to_string(), + public_key: data.public_key, + }) + .collect(); + + // Get key counts + let persistent_keys = kv_store.persistent().read().status().n_kvs as u64; + let ephemeral_keys = kv_store.ephemeral().read().status().n_kvs as u64; + + Ok(DebugSyncDataResponse { + my_node_id: my_node_id as u64, + peer_addrs, + nodes, + instances, + persistent_keys, + ephemeral_keys, + }) + } + + async fn get_proxy_state(self) -> Result { + let state = self.state.lock(); + + // Get all instances from ProxyState + let instances: Vec = state + .state + .instances + .values() + .map(|inst| { + let reg_time = inst + .reg_time + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + ProxyStateInstance { + instance_id: inst.id.clone(), + app_id: inst.app_id.clone(), + ip: inst.ip.to_string(), + public_key: inst.public_key.clone(), + reg_time, + } + }) + .collect(); + + // Get all allocated addresses + let allocated_addresses: Vec = state + .state + .allocated_addresses + .iter() + .map(|ip| ip.to_string()) + .collect(); + + Ok(DebugProxyStateResponse { + instances, + allocated_addresses, + }) + } +} + +impl RpcCall for DebugRpcHandler { + type PrpcService = DebugServer; + + fn construct(context: CallContext<'_, Proxy>) -> Result { + Ok(DebugRpcHandler::new(context.state.clone())) + } +} diff --git a/gateway/src/distributed_certbot.rs b/gateway/src/distributed_certbot.rs new file mode 100644 index 00000000..7ed14c3d --- /dev/null +++ b/gateway/src/distributed_certbot.rs @@ -0,0 +1,431 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// 
SPDX-License-Identifier: Apache-2.0 + +//! Distributed certificate management using WaveKV for synchronization. +//! +//! This module wraps the certbot library to provide distributed certificate +//! management across multiple gateway nodes sharing the same domain. + +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use anyhow::{Context, Result}; +use certbot::{AcmeClient, Dns01Client, WorkDir}; +use fs_err as fs; +use ra_tls::rcgen::KeyPair; +use tracing::{error, info, warn}; + +use crate::config::CertbotConfig; +use crate::kv::{CertCredentials, CertData, KvStore}; + +/// Lock timeout for certificate renewal (10 minutes) +const RENEW_LOCK_TIMEOUT_SECS: u64 = 600; + +/// Distributed certificate manager +pub struct DistributedCertBot { + config: CertbotConfig, + kv_store: Arc, + workdir: WorkDir, +} + +impl DistributedCertBot { + pub fn new(config: CertbotConfig, kv_store: Arc) -> Self { + let workdir = WorkDir::new(&config.workdir); + Self { + config, + kv_store, + workdir, + } + } + + pub fn domain(&self) -> &str { + &self.config.domain + } + + pub fn renew_interval(&self) -> Duration { + self.config.renew_interval + } + + /// Set CAA records for the domain + pub async fn set_caa(&self) -> Result<()> { + let acme_client = self.get_or_create_acme_client().await?; + acme_client + .set_caa_records(&[self.config.domain.clone()]) + .await + } + + /// Initialize certificates - load from KvStore or create new + pub async fn init(&self) -> Result<()> { + // First, try to load from KvStore (synced from other nodes) + if let Some(cert_data) = self.kv_store.get_cert_data(&self.config.domain) { + let now = now_secs(); + if cert_data.not_after > now { + info!( + "cert[{}]: loaded from KvStore (issued by node {}, expires in {} days)", + self.config.domain, + cert_data.issued_by, + (cert_data.not_after - now) / 86400 + ); + self.save_cert_to_disk(&cert_data.cert_pem, &cert_data.key_pem)?; + return Ok(()); + } + info!( + "cert[{}]: KvStore certificate expired, will request new one", + self.config.domain + ); + } + + // Check if local cert exists and is valid + if self.workdir.cert_path().exists() && self.workdir.key_path().exists() { + let cert_pem = fs::read_to_string(self.workdir.cert_path())?; + let key_pem = fs::read_to_string(self.workdir.key_path())?; + if let Some(not_after) = get_cert_expiry(&cert_pem) { + let now = now_secs(); + if not_after > now { + info!( + "cert[{}]: loaded from local file, expires in {} days", + self.config.domain, + (not_after - now) / 86400 + ); + // Sync to KvStore for other nodes + self.save_cert_to_kvstore(&cert_pem, &key_pem, not_after)?; + info!( + "cert[{}]: saved to KvStore for other nodes", + self.config.domain + ); + return Ok(()); + } + } + } + + // No valid cert anywhere, need to request new one + info!( + "cert[{}]: no valid certificate found, requesting from ACME", + self.config.domain + ); + self.request_new_cert().await + } + + /// Try to renew certificate if needed + pub async fn try_renew(&self, force: bool) -> Result { + let domain = &self.config.domain; + + // Check if renewal is needed + let cert_data = self.kv_store.get_cert_data(domain); + let needs_renew = if force { + true + } else if let Some(ref data) = cert_data { + let now = now_secs(); + let expires_in = data.not_after.saturating_sub(now); + let renew_before = self.config.renew_before_expiration.as_secs(); + expires_in < renew_before + } else { + true + }; + + if !needs_renew { + info!("certificate for {} does not need renewal", domain); + return Ok(false); + } + + // 
Try to acquire lock + if !self + .kv_store + .try_acquire_cert_renew_lock(domain, RENEW_LOCK_TIMEOUT_SECS) + { + info!( + "another node is renewing certificate for {}, skipping", + domain + ); + return Ok(false); + } + + info!("acquired renew lock for {}, starting renewal", domain); + + // Perform renewal + let result = self.do_renew().await; + + // Release lock regardless of result + if let Err(e) = self.kv_store.release_cert_renew_lock(domain) { + error!("failed to release cert renew lock: {}", e); + } + + result + } + + /// Reload certificate from KvStore (called when watcher triggers) + pub fn reload_from_kvstore(&self) -> Result { + let Some(cert_data) = self.kv_store.get_cert_data(&self.config.domain) else { + return Ok(false); + }; + + // Check if this is newer than what we have + if self.workdir.cert_path().exists() { + let local_cert = fs::read_to_string(self.workdir.cert_path())?; + if let Some(local_expiry) = get_cert_expiry(&local_cert) { + if local_expiry >= cert_data.not_after { + return Ok(false); + } + } + } + + info!( + "cert[{}]: reloading from KvStore (sync triggered, issued by node {})", + self.config.domain, cert_data.issued_by + ); + self.save_cert_to_disk(&cert_data.cert_pem, &cert_data.key_pem)?; + Ok(true) + } + + async fn request_new_cert(&self) -> Result<()> { + let domain = &self.config.domain; + + // Try to acquire lock first + if !self + .kv_store + .try_acquire_cert_renew_lock(domain, RENEW_LOCK_TIMEOUT_SECS) + { + // Another node is requesting, wait for it + info!( + "another node is requesting certificate for {}, waiting...", + domain + ); + // Wait and then try to load from KvStore + tokio::time::sleep(Duration::from_secs(30)).await; + if let Some(cert_data) = self.kv_store.get_cert_data(domain) { + self.save_cert_to_disk(&cert_data.cert_pem, &cert_data.key_pem)?; + return Ok(()); + } + anyhow::bail!("failed to get certificate from KvStore after waiting"); + } + + let result = self.do_request_new().await; + + if let Err(e) = self.kv_store.release_cert_renew_lock(domain) { + error!("failed to release cert renew lock: {}", e); + } + + result + } + + async fn do_request_new(&self) -> Result<()> { + let acme_client = self.get_or_create_acme_client().await?; + let domain = &self.config.domain; + let timeout = self.config.renew_timeout; + + // Generate new key pair + let key = KeyPair::generate().context("failed to generate key")?; + let key_pem = key.serialize_pem(); + + // Request certificate with timeout + info!("cert[{}]: requesting new certificate from ACME...", domain); + let cert_pem = tokio::time::timeout( + timeout, + acme_client.request_new_certificate(&key_pem, &[domain.clone()]), + ) + .await + .context("certificate request timed out")? 
+ .context("failed to request new certificate")?; + + let not_after = get_cert_expiry(&cert_pem).context("failed to parse certificate expiry")?; + + // Save to KvStore first (so other nodes can see it) + self.save_cert_to_kvstore(&cert_pem, &key_pem, not_after)?; + info!( + "cert[{}]: new certificate obtained from ACME, saved to KvStore", + domain + ); + + // Then save to disk + self.save_cert_to_disk(&cert_pem, &key_pem)?; + + info!( + "cert[{}]: new certificate saved to disk (expires in {} days)", + domain, + (not_after - now_secs()) / 86400 + ); + Ok(()) + } + + async fn do_renew(&self) -> Result { + let acme_client = self.get_or_create_acme_client().await?; + let domain = &self.config.domain; + let timeout = self.config.renew_timeout; + + // Load current cert and key + let cert_pem = fs::read_to_string(self.workdir.cert_path()) + .context("failed to read current certificate")?; + let key_pem = + fs::read_to_string(self.workdir.key_path()).context("failed to read current key")?; + + // Renew with timeout + info!("cert[{}]: renewing certificate from ACME...", domain); + let new_cert_pem = + tokio::time::timeout(timeout, acme_client.renew_cert(&cert_pem, &key_pem)) + .await + .context("certificate renewal timed out")? + .context("failed to renew certificate")?; + + let not_after = + get_cert_expiry(&new_cert_pem).context("failed to parse certificate expiry")?; + + // Save to KvStore first + self.save_cert_to_kvstore(&new_cert_pem, &key_pem, not_after)?; + info!("cert[{}]: renewed certificate saved to KvStore", domain); + + // Then save to disk + self.save_cert_to_disk(&new_cert_pem, &key_pem)?; + + info!( + "cert[{}]: renewed certificate saved to disk (expires in {} days)", + domain, + (not_after - now_secs()) / 86400 + ); + Ok(true) + } + + async fn get_or_create_acme_client(&self) -> Result { + let dns01_client = Dns01Client::new_cloudflare( + self.config.cf_zone_id.clone(), + self.config.cf_api_token.clone(), + ); + + // Try to load credentials from KvStore + if let Some(creds) = self.kv_store.get_cert_credentials(&self.config.domain) { + if acme_url_matches(&creds.acme_credentials, &self.config.acme_url) { + info!( + "acme[{}]: loaded account credentials from KvStore", + self.config.domain + ); + return AcmeClient::load(dns01_client, &creds.acme_credentials) + .await + .context("failed to load ACME client from KvStore credentials"); + } + warn!( + "acme[{}]: URL mismatch in KvStore credentials, will try local file", + self.config.domain + ); + } + + // Try to load from local file + let credentials_path = self.workdir.account_credentials_path(); + if credentials_path.exists() { + let creds_json = fs::read_to_string(&credentials_path)?; + if acme_url_matches(&creds_json, &self.config.acme_url) { + info!( + "acme[{}]: loaded account credentials from local file", + self.config.domain + ); + // Save to KvStore for other nodes + self.kv_store.save_cert_credentials( + &self.config.domain, + &CertCredentials { + acme_credentials: creds_json.clone(), + }, + )?; + return AcmeClient::load(dns01_client, &creds_json) + .await + .context("failed to load ACME client from local credentials"); + } + } + + // Create new account + info!( + "acme[{}]: creating new account at {}", + self.config.domain, self.config.acme_url + ); + let client = AcmeClient::new_account(&self.config.acme_url, dns01_client) + .await + .context("failed to create new ACME account")?; + + let creds_json = client + .dump_credentials() + .context("failed to dump ACME credentials")?; + + // Set CAA records if configured + if 
self.config.auto_set_caa { + client + .set_caa_records(&[self.config.domain.clone()]) + .await?; + } + + // Save to KvStore + self.kv_store.save_cert_credentials( + &self.config.domain, + &CertCredentials { + acme_credentials: creds_json.clone(), + }, + )?; + + // Save to local file + if let Some(parent) = credentials_path.parent() { + fs::create_dir_all(parent)?; + } + fs::write(&credentials_path, &creds_json)?; + + Ok(client) + } + + fn save_cert_to_kvstore(&self, cert_pem: &str, key_pem: &str, not_after: u64) -> Result<()> { + let cert_data = CertData { + cert_pem: cert_pem.to_string(), + key_pem: key_pem.to_string(), + not_after, + issued_by: self.kv_store.my_node_id(), + issued_at: now_secs(), + }; + self.kv_store + .save_cert_data(&self.config.domain, &cert_data) + } + + fn save_cert_to_disk(&self, cert_pem: &str, key_pem: &str) -> Result<()> { + let cert_path = self.workdir.cert_path(); + let key_path = self.workdir.key_path(); + + // Create parent directories + if let Some(parent) = cert_path.parent() { + fs::create_dir_all(parent)?; + } + + // Also save to backup dir with timestamp + let backup_dir = self.workdir.backup_dir(); + fs::create_dir_all(&backup_dir)?; + let timestamp = now_secs(); + let backup_subdir = backup_dir.join(format!("{}", timestamp)); + fs::create_dir_all(&backup_subdir)?; + fs::write(backup_subdir.join("cert.pem"), cert_pem)?; + fs::write(backup_subdir.join("key.pem"), key_pem)?; + + // Write main cert files + fs::write(&cert_path, cert_pem)?; + fs::write(&key_path, key_pem)?; + + Ok(()) + } +} + +fn now_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +fn get_cert_expiry(cert_pem: &str) -> Option { + use x509_parser::prelude::*; + let pem = Pem::iter_from_buffer(cert_pem.as_bytes()).next()?.ok()?; + let cert = pem.parse_x509().ok()?; + Some(cert.validity().not_after.timestamp() as u64) +} + +fn acme_url_matches(credentials_json: &str, expected_url: &str) -> bool { + #[derive(serde::Deserialize)] + struct Creds { + #[serde(default)] + acme_url: String, + } + serde_json::from_str::(credentials_json) + .map(|c| c.acme_url == expected_url) + .unwrap_or(false) +} diff --git a/gateway/src/kv/https_client.rs b/gateway/src/kv/https_client.rs new file mode 100644 index 00000000..d0d034a9 --- /dev/null +++ b/gateway/src/kv/https_client.rs @@ -0,0 +1,322 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! HTTPS client with mTLS and custom certificate verification during TLS handshake. + +use std::fmt::Debug; +use std::io::{Read, Write}; +use std::sync::Arc; + +use anyhow::{Context, Result}; +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; +use http_body_util::{BodyExt, Full}; +use hyper::body::Bytes; +use hyper_rustls::HttpsConnectorBuilder; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, +}; +use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; +use rustls::pki_types::pem::PemObject; +use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; +use rustls::{DigitallySignedStruct, SignatureScheme}; +use serde::{de::DeserializeOwned, Serialize}; + +use super::{decode, encode}; + +/// Custom certificate validator trait for TLS handshake verification. +/// +/// Implementations can perform additional validation on the peer certificate +/// during the TLS handshake, before any application data is sent. 
+pub trait CertValidator: Debug + Send + Sync + 'static { + /// Validate the peer certificate. + /// + /// Called after standard X.509 chain verification succeeds. + /// Return `Ok(())` to accept the certificate, or `Err` to reject. + fn validate(&self, cert_der: &[u8]) -> Result<(), String>; +} + +/// TLS configuration for mTLS with optional custom certificate validation +#[derive(Clone)] +pub struct HttpsClientConfig { + pub cert_path: String, + pub key_path: String, + pub ca_cert_path: String, + /// Optional custom certificate validator (checked during TLS handshake) + pub cert_validator: Option>, +} + +/// Wrapper that adapts a CertValidator to rustls ServerCertVerifier +#[derive(Debug)] +struct CustomCertVerifier { + validator: Arc, + root_store: Arc, +} + +impl CustomCertVerifier { + fn new( + validator: Arc, + ca_cert_der: CertificateDer<'static>, + ) -> Result { + let mut root_store = rustls::RootCertStore::empty(); + root_store + .add(ca_cert_der) + .context("failed to add CA cert to root store")?; + Ok(Self { + validator, + root_store: Arc::new(root_store), + }) + } +} + +impl ServerCertVerifier for CustomCertVerifier { + fn verify_server_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + server_name: &ServerName<'_>, + _ocsp_response: &[u8], + now: UnixTime, + ) -> Result { + // First, do standard certificate verification + let verifier = rustls::client::WebPkiServerVerifier::builder(self.root_store.clone()) + .build() + .map_err(|e| rustls::Error::General(format!("failed to build verifier: {e}")))?; + + verifier.verify_server_cert(end_entity, intermediates, server_name, &[], now)?; + + // Then run custom validation + self.validator + .validate(end_entity.as_ref()) + .map_err(rustls::Error::General)?; + + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls12_signature( + message, + cert, + dss, + &rustls::crypto::ring::default_provider().signature_verification_algorithms, + ) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + rustls::crypto::verify_tls13_signature( + message, + cert, + dss, + &rustls::crypto::ring::default_provider().signature_verification_algorithms, + ) + } + + fn supported_verify_schemes(&self) -> Vec { + rustls::crypto::ring::default_provider() + .signature_verification_algorithms + .supported_schemes() + } +} + +type HyperClient = Client, Full>; + +/// HTTPS client with mTLS and optional custom certificate validation. +/// +/// When a `cert_validator` is set in `TlsConfig`, the client runs the validator +/// during the TLS handshake, before any application data is sent. 
+#[derive(Clone)] +pub struct HttpsClient { + client: HyperClient, +} + +impl HttpsClient { + /// Create a new HTTPS client with mTLS configuration + pub fn new(tls: &HttpsClientConfig) -> Result { + // Load client certificate and key + let cert_pem = std::fs::read(&tls.cert_path) + .with_context(|| format!("failed to read TLS cert from {}", tls.cert_path))?; + let key_pem = std::fs::read(&tls.key_path) + .with_context(|| format!("failed to read TLS key from {}", tls.key_path))?; + + let certs: Vec> = CertificateDer::pem_slice_iter(&cert_pem) + .collect::>() + .context("failed to parse client certs")?; + + let key = PrivateKeyDer::from_pem_slice(&key_pem).context("failed to parse private key")?; + + // Load CA certificate + let ca_cert_pem = std::fs::read(&tls.ca_cert_path) + .with_context(|| format!("failed to read CA cert from {}", tls.ca_cert_path))?; + let ca_certs: Vec> = CertificateDer::pem_slice_iter(&ca_cert_pem) + .collect::>() + .context("failed to parse CA certs")?; + let ca_cert = ca_certs + .into_iter() + .next() + .context("no CA certificate found")?; + + // Build rustls config with custom verifier if validator is provided + let tls_config_builder = rustls::ClientConfig::builder(); + + let tls_config = if let Some(ref validator) = tls.cert_validator { + let verifier = CustomCertVerifier::new(validator.clone(), ca_cert)?; + tls_config_builder + .dangerous() + .with_custom_certificate_verifier(Arc::new(verifier)) + } else { + // Standard verification without custom validator + let mut root_store = rustls::RootCertStore::empty(); + root_store.add(ca_cert).context("failed to add CA cert")?; + tls_config_builder.with_root_certificates(root_store) + } + .with_client_auth_cert(certs, key) + .context("failed to set client auth cert")?; + + let https = HttpsConnectorBuilder::new() + .with_tls_config(tls_config) + .https_only() + .enable_http1() + .build(); + + let client = Client::builder(TokioExecutor::new()).build(https); + Ok(Self { client }) + } + + /// Send a POST request with JSON body and receive JSON response + pub async fn post_json( + &self, + url: &str, + body: &T, + ) -> Result { + let body = serde_json::to_vec(body).context("failed to serialize request body")?; + + let request = hyper::Request::builder() + .method(hyper::Method::POST) + .uri(url) + .header("content-type", "application/json") + .body(Full::new(Bytes::from(body))) + .context("failed to build request")?; + + let response = self + .client + .request(request) + .await + .with_context(|| format!("failed to send request to {url}"))?; + + if !response.status().is_success() { + anyhow::bail!("request failed: {}", response.status()); + } + + let body = response + .into_body() + .collect() + .await + .context("failed to read response body")? 
+ .to_bytes(); + + serde_json::from_slice(&body).context("failed to parse response") + } + + /// Send a POST request with msgpack + gzip encoded body and receive msgpack + gzip response + pub async fn post_compressed_msg( + &self, + url: &str, + body: &T, + ) -> Result { + let encoded = encode(body).context("failed to encode request body")?; + + // Compress with gzip + let mut encoder = GzEncoder::new(Vec::new(), Compression::fast()); + encoder + .write_all(&encoded) + .context("failed to compress request")?; + let compressed = encoder.finish().context("failed to finish compression")?; + + let request = hyper::Request::builder() + .method(hyper::Method::POST) + .uri(url) + .header("content-type", "application/x-msgpack-gz") + .body(Full::new(Bytes::from(compressed))) + .context("failed to build request")?; + + let response = self + .client + .request(request) + .await + .with_context(|| format!("failed to send request to {url}"))?; + + if !response.status().is_success() { + anyhow::bail!("request failed: {}", response.status()); + } + + let body = response + .into_body() + .collect() + .await + .context("failed to read response body")? + .to_bytes(); + + // Decompress + let mut decoder = GzDecoder::new(body.as_ref()); + let mut decompressed = Vec::new(); + decoder + .read_to_end(&mut decompressed) + .context("failed to decompress response")?; + + decode(&decompressed).context("failed to decode response") + } +} + +// ============================================================================ +// Built-in validators +// ============================================================================ + +/// Validator that checks the peer certificate contains a specific app_id. +#[derive(Debug)] +pub struct AppIdValidator { + expected_app_id: Vec, +} + +impl AppIdValidator { + pub fn new(expected_app_id: Vec) -> Self { + Self { expected_app_id } + } +} + +impl CertValidator for AppIdValidator { + fn validate(&self, cert_der: &[u8]) -> Result<(), String> { + use ra_tls::traits::CertExt; + + let (_, cert) = x509_parser::parse_x509_certificate(cert_der) + .map_err(|e| format!("failed to parse certificate: {e}"))?; + + let peer_app_id = cert + .get_app_id() + .map_err(|e| format!("failed to get app_id: {e}"))?; + + let Some(peer_app_id) = peer_app_id else { + return Err("peer certificate does not contain app_id".into()); + }; + + if peer_app_id != self.expected_app_id { + return Err(format!( + "app_id mismatch: expected {}, got {}", + hex::encode(&self.expected_app_id), + hex::encode(&peer_app_id) + )); + } + + Ok(()) + } +} diff --git a/gateway/src/kv/mod.rs b/gateway/src/kv/mod.rs new file mode 100644 index 00000000..5126d7c5 --- /dev/null +++ b/gateway/src/kv/mod.rs @@ -0,0 +1,586 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! WaveKV-based sync layer for dstack-gateway. +//! +//! This module provides synchronization between gateway nodes. The local ProxyState +//! remains the primary data store for fast reads, while WaveKV handles cross-node sync. +//! +//! Key schema: +//! +//! # Persistent WaveKV (needs persistence + sync) +//! - `inst/{instance_id}` → InstanceData +//! - `node/{node_id}` → NodeData +//! +//! # Ephemeral WaveKV (no persistence, sync only) +//! - `conn/{instance_id}/{node_id}` → u64 (connection count) +//! - `last_seen/inst/{instance_id}` → u64 (timestamp) +//! 
- `last_seen/node/{node_id}/{seen_by_node_id}` → u64 (timestamp) + +mod https_client; +mod sync_service; + +pub use https_client::{AppIdValidator, HttpsClientConfig}; +pub use sync_service::{fetch_peers_from_bootnode, WaveKvSyncService}; +use tracing::warn; + +use std::{collections::BTreeMap, net::Ipv4Addr, path::Path}; + +use anyhow::{Context, Result}; +use serde::{Deserialize, Serialize}; +use tokio::sync::watch; +use wavekv::{node::NodeState, types::NodeId, Node}; + +/// Instance core data (persistent) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct InstanceData { + pub app_id: String, + pub ip: Ipv4Addr, + pub public_key: String, + pub reg_time: u64, +} + +/// Gateway node status (stored separately for independent updates) +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "snake_case")] +pub enum NodeStatus { + #[default] + Up, + Down, +} + +/// Gateway node data (persistent, rarely changes) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct NodeData { + pub uuid: Vec, + pub url: String, + pub wg_public_key: String, + pub wg_endpoint: String, + pub wg_ip: String, +} + +/// Certificate credentials (ACME account) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CertCredentials { + pub acme_credentials: String, +} + +/// Certificate data (cert + key) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CertData { + pub cert_pem: String, + pub key_pem: String, + pub not_after: u64, + pub issued_by: NodeId, + pub issued_at: u64, +} + +/// Certificate renew lock +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CertRenewLock { + pub started_at: u64, + pub started_by: NodeId, +} + +// Key prefixes and builders +pub mod keys { + use super::NodeId; + + pub const INST_PREFIX: &str = "inst/"; + pub const NODE_PREFIX: &str = "node/"; + pub const NODE_INFO_PREFIX: &str = "node/info/"; + pub const NODE_STATUS_PREFIX: &str = "node/status/"; + pub const CONN_PREFIX: &str = "conn/"; + pub const HANDSHAKE_PREFIX: &str = "handshake/"; + pub const LAST_SEEN_NODE_PREFIX: &str = "last_seen/node/"; + pub const PEER_ADDR_PREFIX: &str = "__peer_addr/"; + pub const CERT_PREFIX: &str = "cert/"; + + pub fn inst(instance_id: &str) -> String { + format!("{INST_PREFIX}{instance_id}") + } + + pub fn node_info(node_id: NodeId) -> String { + format!("{NODE_INFO_PREFIX}{node_id}") + } + + pub fn node_status(node_id: NodeId) -> String { + format!("{NODE_STATUS_PREFIX}{node_id}") + } + + pub fn conn(instance_id: &str, node_id: NodeId) -> String { + format!("{CONN_PREFIX}{instance_id}/{node_id}") + } + + /// Key for instance handshake timestamp observed by a specific node + /// Format: handshake/{instance_id}/{observer_node_id} + pub fn handshake(instance_id: &str, observer_node_id: NodeId) -> String { + format!("{HANDSHAKE_PREFIX}{instance_id}/{observer_node_id}") + } + + /// Prefix to iterate all handshake observations for an instance + pub fn handshake_prefix(instance_id: &str) -> String { + format!("{HANDSHAKE_PREFIX}{instance_id}/") + } + + pub fn last_seen_node(node_id: NodeId, seen_by: NodeId) -> String { + format!("{LAST_SEEN_NODE_PREFIX}{node_id}/{seen_by}") + } + + pub fn last_seen_node_prefix(node_id: NodeId) -> String { + format!("{LAST_SEEN_NODE_PREFIX}{node_id}/") + } + + pub fn peer_addr(node_id: NodeId) -> String { + format!("{PEER_ADDR_PREFIX}{node_id}") + } + + // Certificate keys (per domain) + pub fn cert_credentials(domain: &str) -> String { + 
format!("{CERT_PREFIX}{domain}/credentials") + } + + pub fn cert_data(domain: &str) -> String { + format!("{CERT_PREFIX}{domain}/data") + } + + pub fn cert_renew_lock(domain: &str) -> String { + format!("{CERT_PREFIX}{domain}/renew_lock") + } + + /// Parse instance_id from key + pub fn parse_inst_key(key: &str) -> Option<&str> { + key.strip_prefix(INST_PREFIX) + } + + /// Parse node_id from node/info/{node_id} key + pub fn parse_node_info_key(key: &str) -> Option { + key.strip_prefix(NODE_INFO_PREFIX)?.parse().ok() + } +} + +pub fn encode(value: &T) -> Result> { + rmp_serde::encode::to_vec(value).context("failed to encode value") +} + +pub fn decode Deserialize<'de>>(bytes: &[u8]) -> Result { + rmp_serde::decode::from_slice(bytes).context("failed to decode value") +} + +trait GetPutCodec { + fn decode serde::Deserialize<'de>>(&self, key: &str) -> Option; + fn put_encoded(&mut self, key: String, value: &T) -> Result<()>; + fn iter_decoded serde::Deserialize<'de>>( + &self, + prefix: &str, + ) -> impl Iterator; + fn iter_decoded_values serde::Deserialize<'de>>( + &self, + prefix: &str, + ) -> impl Iterator; +} + +impl GetPutCodec for NodeState { + fn decode serde::Deserialize<'de>>(&self, key: &str) -> Option { + self.get(key) + .and_then(|entry| match decode(entry.value.as_ref()?) { + Ok(value) => Some(value), + Err(e) => { + warn!("failed to decode value for key {key}: {e:?}"); + None + } + }) + } + + fn put_encoded(&mut self, key: String, value: &T) -> Result<()> { + self.put(key.clone(), encode(value)?) + .with_context(|| format!("failed to put key {key}"))?; + Ok(()) + } + + fn iter_decoded serde::Deserialize<'de>>( + &self, + prefix: &str, + ) -> impl Iterator { + self.iter_by_prefix(prefix).filter_map(|(key, entry)| { + let value = match decode(entry.value.as_ref()?) { + Ok(value) => value, + Err(e) => { + warn!("failed to decode value for key {key}: {e:?}"); + return None; + } + }; + Some((key.to_string(), value)) + }) + } + + fn iter_decoded_values serde::Deserialize<'de>>( + &self, + prefix: &str, + ) -> impl Iterator { + self.iter_by_prefix(prefix).filter_map(|(key, entry)| { + let value = match decode(entry.value.as_ref()?) { + Ok(value) => value, + Err(e) => { + warn!("failed to decode value for key {key}: {e:?}"); + return None; + } + }; + Some(value) + }) + } +} + +/// Sync store wrapping two WaveKV Nodes (persistent and ephemeral). +/// +/// This is the sync layer - not the primary data store. +/// ProxyState remains in memory for fast reads. 
+#[derive(Clone)] +pub struct KvStore { + /// Persistent WaveKV Node (with WAL) + persistent: Node, + /// Ephemeral WaveKV Node (in-memory only) + ephemeral: Node, + /// This gateway's node ID + my_node_id: NodeId, +} + +impl KvStore { + /// Create a new sync store + pub fn new( + my_node_id: NodeId, + peer_ids: Vec, + data_dir: impl AsRef, + ) -> Result { + let persistent = + Node::new_with_persistence(my_node_id, peer_ids.clone(), data_dir.as_ref()) + .context("failed to create persistent wavekv node")?; + + let ephemeral = Node::new(my_node_id, peer_ids); + + Ok(Self { + persistent, + ephemeral, + my_node_id, + }) + } + + pub fn my_node_id(&self) -> NodeId { + self.my_node_id + } + + pub fn persistent(&self) -> &Node { + &self.persistent + } + + pub fn ephemeral(&self) -> &Node { + &self.ephemeral + } + + // ==================== Instance Sync ==================== + + /// Sync instance data to other nodes + pub fn sync_instance(&self, instance_id: &str, data: &InstanceData) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::inst(instance_id), data) + } + + /// Sync instance deletion to other nodes + pub fn sync_delete_instance(&self, instance_id: &str) -> Result<()> { + self.persistent.write().delete(keys::inst(instance_id))?; + self.ephemeral + .write() + .delete(keys::conn(instance_id, self.my_node_id))?; + // Delete this node's handshake record + self.ephemeral + .write() + .delete(keys::handshake(instance_id, self.my_node_id))?; + Ok(()) + } + + /// Load all instances from sync store (for initial sync on startup) + pub fn load_all_instances(&self) -> BTreeMap { + self.persistent + .read() + .iter_decoded(keys::INST_PREFIX) + .filter_map(|(key, data)| { + let instance_id = keys::parse_inst_key(&key)?; + Some((instance_id.into(), data)) + }) + .collect() + } + + // ==================== Node Sync ==================== + + /// Sync node data to other nodes + pub fn sync_node(&self, node_id: NodeId, data: &NodeData) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::node_info(node_id), data) + } + + /// Load all nodes from sync store + pub fn load_all_nodes(&self) -> BTreeMap { + self.persistent + .read() + .iter_decoded(keys::NODE_INFO_PREFIX) + .filter_map(|(key, data)| { + let node_id = keys::parse_node_info_key(&key)?; + Some((node_id, data)) + }) + .collect() + } + + // ==================== Node Status Sync ==================== + + /// Set node status (stored separately from NodeData) + pub fn set_node_status(&self, node_id: NodeId, status: NodeStatus) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::node_status(node_id), &status)?; + Ok(()) + } + + /// Get node status + pub fn get_node_status(&self, node_id: NodeId) -> NodeStatus { + self.persistent + .read() + .decode(&keys::node_status(node_id)) + .unwrap_or_default() + } + + /// Load all node statuses + pub fn load_all_node_statuses(&self) -> BTreeMap { + self.persistent + .read() + .iter_decoded(keys::NODE_STATUS_PREFIX) + .filter_map(|(key, status)| { + let node_id: NodeId = key.strip_prefix(keys::NODE_STATUS_PREFIX)?.parse().ok()?; + Some((node_id, status)) + }) + .collect() + } + + // ==================== Connection Count Sync ==================== + + /// Sync connection count for an instance (from this node) + pub fn sync_connections(&self, instance_id: &str, count: u64) -> Result<()> { + self.ephemeral + .write() + .put_encoded(keys::conn(instance_id, self.my_node_id), &count)?; + Ok(()) + } + + // ==================== Handshake Sync ==================== + + /// Sync handshake 
timestamp for an instance (as observed by this node) + pub fn sync_instance_handshake(&self, instance_id: &str, timestamp: u64) -> Result<()> { + self.ephemeral + .write() + .put_encoded(keys::handshake(instance_id, self.my_node_id), ×tamp)?; + Ok(()) + } + + /// Get all handshake observations for an instance (from all nodes) + pub fn get_instance_handshakes(&self, instance_id: &str) -> BTreeMap { + self.ephemeral + .read() + .iter_decoded(&keys::handshake_prefix(instance_id)) + .filter_map(|(key, ts)| { + let suffix = key.strip_prefix(&keys::handshake_prefix(instance_id))?; + let observer: NodeId = suffix.parse().ok()?; + Some((observer, ts)) + }) + .collect() + } + + /// Get the latest handshake timestamp for an instance (max across all nodes) + pub fn get_instance_latest_handshake(&self, instance_id: &str) -> Option { + self.ephemeral + .read() + .iter_decoded_values(&keys::handshake_prefix(instance_id)) + .max() + } + + /// Sync node last_seen (as observed by this node) + pub fn sync_node_last_seen(&self, node_id: NodeId, timestamp: u64) -> Result<()> { + self.ephemeral + .write() + .put_encoded(keys::last_seen_node(node_id, self.my_node_id), ×tamp)?; + Ok(()) + } + + /// Get all observations of a node's last_seen + pub fn get_node_last_seen_by_all(&self, node_id: NodeId) -> BTreeMap { + self.ephemeral + .read() + .iter_decoded(&keys::last_seen_node_prefix(node_id)) + .filter_map(|(key, ts)| { + let suffix = key.strip_prefix(&keys::last_seen_node_prefix(node_id))?; + let seen_by: NodeId = suffix.parse().ok()?; + Some((seen_by, ts)) + }) + .collect() + } + + /// Get the latest last_seen timestamp for a node (max across all observers) + pub fn get_node_latest_last_seen(&self, node_id: NodeId) -> Option { + self.ephemeral + .read() + .iter_decoded_values(&keys::last_seen_node_prefix(node_id)) + .max() + } + + // ==================== Watch for Remote Changes ==================== + + /// Watch for remote instance changes (for updating local ProxyState) + pub fn watch_instances(&self) -> watch::Receiver<()> { + self.persistent.watch_prefix(keys::INST_PREFIX) + } + + /// Watch for remote node changes + pub fn watch_nodes(&self) -> watch::Receiver<()> { + self.persistent.watch_prefix(keys::NODE_PREFIX) + } + + // ==================== Persistence ==================== + + pub fn persist_if_dirty(&self) -> Result { + self.persistent.persist_if_dirty() + } + + // ==================== Peer Management ==================== + + pub fn add_peer(&self, peer_id: NodeId) -> Result<()> { + self.persistent.write().add_peer(peer_id)?; + self.ephemeral.write().add_peer(peer_id)?; + Ok(()) + } + + // ==================== Peer Address (in DB) ==================== + + /// Register a node's sync URL in DB and add to peer list for sync + /// + /// This stores the URL in KvStore (for address lookup) and also adds the node + /// to the wavekv peer list (so SyncManager knows to sync with it). 
+ pub fn register_peer_url(&self, node_id: NodeId, url: &str) -> Result<()> { + // Store URL in persistent KvStore + self.persistent + .write() + .put_encoded(keys::peer_addr(node_id), &url)?; + + let _ = self.add_peer(node_id); + Ok(()) + } + + /// Get a peer's sync URL from DB + pub fn get_peer_url(&self, node_id: NodeId) -> Option { + self.persistent.read().decode(&keys::peer_addr(node_id)) + } + + /// Query the UUID for a given node ID from KvStore + pub fn get_peer_uuid(&self, peer_id: NodeId) -> Option> { + self.persistent.read().decode(&keys::node_info(peer_id)) + } + + pub fn update_peer_last_seen(&self, peer_id: NodeId) { + let ts = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + let key = keys::last_seen_node(peer_id, self.my_node_id); + if let Err(e) = self.ephemeral.write().put_encoded(key, &ts) { + warn!("failed to update peer {peer_id} last_seen: {e}"); + } + } + + /// Get all peer addresses from DB (for debugging/testing) + pub fn get_all_peer_addrs(&self) -> BTreeMap { + self.persistent + .read() + .iter_decoded(keys::PEER_ADDR_PREFIX) + .filter_map(|(key, url)| { + let node_id: NodeId = key.strip_prefix(keys::PEER_ADDR_PREFIX)?.parse().ok()?; + Some((node_id, url)) + }) + .collect() + } + + // ==================== Certificate Sync ==================== + + /// Get certificate credentials for a domain + pub fn get_cert_credentials(&self, domain: &str) -> Option { + self.persistent + .read() + .decode(&keys::cert_credentials(domain)) + } + + /// Save certificate credentials for a domain + pub fn save_cert_credentials(&self, domain: &str, creds: &CertCredentials) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::cert_credentials(domain), creds)?; + Ok(()) + } + + /// Get certificate data for a domain + pub fn get_cert_data(&self, domain: &str) -> Option { + self.persistent.read().decode(&keys::cert_data(domain)) + } + + /// Save certificate data for a domain + pub fn save_cert_data(&self, domain: &str, data: &CertData) -> Result<()> { + self.persistent + .write() + .put_encoded(keys::cert_data(domain), data)?; + Ok(()) + } + + /// Get certificate renew lock for a domain + pub fn get_cert_renew_lock(&self, domain: &str) -> Option { + self.persistent + .read() + .decode(&keys::cert_renew_lock(domain)) + } + + /// Try to acquire certificate renew lock + /// Returns true if lock acquired, false if already locked by another node + pub fn try_acquire_cert_renew_lock(&self, domain: &str, lock_timeout_secs: u64) -> bool { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + if let Some(existing) = self.get_cert_renew_lock(domain) { + // Check if lock is still valid (not expired) + if now < existing.started_at + lock_timeout_secs { + return false; + } + } + + // Acquire the lock + let lock = CertRenewLock { + started_at: now, + started_by: self.my_node_id, + }; + self.persistent + .write() + .put_encoded(keys::cert_renew_lock(domain), &lock) + .is_ok() + } + + /// Release certificate renew lock + pub fn release_cert_renew_lock(&self, domain: &str) -> Result<()> { + self.persistent + .write() + .delete(keys::cert_renew_lock(domain))?; + Ok(()) + } + + /// Watch for certificate data changes + pub fn watch_cert(&self, domain: &str) -> watch::Receiver<()> { + self.persistent.watch_prefix(&keys::cert_data(domain)) + } +} diff --git a/gateway/src/kv/sync_service.rs b/gateway/src/kv/sync_service.rs new file mode 100644 index 00000000..f691595a 
--- /dev/null +++ b/gateway/src/kv/sync_service.rs @@ -0,0 +1,238 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! WaveKV sync service - implements network transport for wavekv synchronization. +//! +//! Peer URLs are stored in the persistent KV store under `__peer_addr/{node_id}` keys. +//! This allows peer addresses to be automatically synced across nodes. + +use std::sync::Arc; + +use anyhow::{Context, Result}; +use dstack_gateway_rpc::GetPeersResponse; +use tracing::{info, warn}; +use wavekv::{ + sync::{ExchangeInterface, SyncConfig as KvSyncConfig, SyncManager, SyncMessage, SyncResponse}, + types::NodeId, + Node, +}; + +use crate::config::SyncConfig as GwSyncConfig; + +use super::https_client::{HttpsClient, HttpsClientConfig}; +use super::KvStore; + +/// HTTP-based network transport for WaveKV sync. +/// Holds a reference to the persistent node for reading peer URLs. +#[derive(Clone)] +pub struct HttpSyncNetwork { + client: HttpsClient, + /// Reference to persistent node for reading peer URLs + kv_store: KvStore, + /// This node's UUID (for node ID reuse detection) + my_uuid: Vec, + /// URL path suffix for this store (e.g., "persistent" or "ephemeral") + store_path: &'static str, +} + +impl HttpSyncNetwork { + pub fn new( + kv_store: KvStore, + store_path: &'static str, + tls_config: &HttpsClientConfig, + ) -> Result { + let client = HttpsClient::new(tls_config)?; + let my_uuid = kv_store + .get_peer_uuid(kv_store.my_node_id) + .context("failed to get my UUID")?; + Ok(Self { + client, + kv_store, + my_uuid, + store_path, + }) + } + + /// Get peer URL from persistent node + fn get_peer_url(&self, peer_id: NodeId) -> Option { + self.kv_store.get_peer_url(peer_id) + } +} + +impl ExchangeInterface for HttpSyncNetwork { + fn uuid(&self) -> Vec { + self.my_uuid.clone() + } + + fn query_uuid(&self, node_id: NodeId) -> Option> { + self.kv_store.get_peer_uuid(node_id) + } + + async fn sync_to(&self, _node: &Node, peer: NodeId, msg: SyncMessage) -> Result { + let url = self + .get_peer_url(peer) + .ok_or_else(|| anyhow::anyhow!("peer {} address not found in DB", peer))?; + + let sync_url = format!( + "{}/wavekv/sync/{}", + url.trim_end_matches('/'), + self.store_path + ); + + // Send request with msgpack + gzip encoding + // app_id verification happens during TLS handshake via AppIdVerifier + let sync_response: SyncResponse = self + .client + .post_compressed_msg(&sync_url, &msg) + .await + .with_context(|| format!("failed to sync to peer {peer} at {sync_url}"))?; + + // Update peer last_seen on successful sync + self.kv_store.update_peer_last_seen(peer); + + Ok(sync_response) + } +} + +/// WaveKV sync service that manages synchronization for both persistent and ephemeral stores +pub struct WaveKvSyncService { + pub persistent_manager: Arc>, + pub ephemeral_manager: Arc>, +} + +impl WaveKvSyncService { + /// Create a new WaveKV sync service + /// + /// # Arguments + /// * `kv_store` - The sync store containing persistent and ephemeral nodes + /// * `sync_config` - Sync configuration + /// * `tls_config` - TLS configuration for mTLS peer authentication + pub fn new( + kv_store: &KvStore, + sync_config: &GwSyncConfig, + tls_config: HttpsClientConfig, + ) -> Result { + let sync_config = KvSyncConfig { + interval: sync_config.interval, + timeout: sync_config.timeout, + }; + + // Both networks use the same persistent node for URL lookup, but different paths + let persistent_network = HttpSyncNetwork::new(kv_store.clone(), "persistent", 
&tls_config)?; + let ephemeral_network = HttpSyncNetwork::new(kv_store.clone(), "ephemeral", &tls_config)?; + + let persistent_manager = Arc::new(SyncManager::with_config( + kv_store.persistent().clone(), + persistent_network, + sync_config.clone(), + )); + let ephemeral_manager = Arc::new(SyncManager::with_config( + kv_store.ephemeral().clone(), + ephemeral_network, + sync_config, + )); + + Ok(Self { + persistent_manager, + ephemeral_manager, + }) + } + + /// Bootstrap from peers + pub async fn bootstrap(&self) -> Result<()> { + info!("bootstrapping persistent store..."); + if let Err(e) = self.persistent_manager.bootstrap().await { + warn!("failed to bootstrap persistent store: {e}"); + } + + info!("bootstrapping ephemeral store..."); + if let Err(e) = self.ephemeral_manager.bootstrap().await { + warn!("failed to bootstrap ephemeral store: {e}"); + } + + Ok(()) + } + + /// Start background sync tasks + pub async fn start_sync_tasks(&self) { + let persistent = self.persistent_manager.clone(); + let ephemeral = self.ephemeral_manager.clone(); + + tokio::join!(persistent.start_sync_tasks(), ephemeral.start_sync_tasks(),); + + info!("WaveKV sync tasks started"); + } + + /// Handle incoming sync request for persistent store + pub fn handle_persistent_sync(&self, msg: SyncMessage) -> Result { + self.persistent_manager.handle_sync(msg) + } + + /// Handle incoming sync request for ephemeral store + pub fn handle_ephemeral_sync(&self, msg: SyncMessage) -> Result { + self.ephemeral_manager.handle_sync(msg) + } +} + +/// Fetch peer list from bootnode and register them in KvStore. +/// +/// This is called during startup to bootstrap the peer list from a known bootnode. +/// Uses Gateway.GetPeers RPC which requires mTLS gateway authentication. +pub async fn fetch_peers_from_bootnode( + bootnode_url: &str, + kv_store: &KvStore, + my_node_id: NodeId, + tls_config: &HttpsClientConfig, +) -> Result<()> { + if bootnode_url.is_empty() { + info!("no bootnode configured, skipping peer fetch"); + return Ok(()); + } + + info!("fetching peers from bootnode: {}", bootnode_url); + + // Create HTTPS client for bootnode communication (with mTLS) + let client = HttpsClient::new(tls_config).context("failed to create HTTPS client")?; + + // Call Gateway.GetPeers RPC on bootnode (requires mTLS gateway auth) + let peers_url = format!("{}/prpc/GetPeers", bootnode_url.trim_end_matches('/')); + + let response: GetPeersResponse = client + .post_json(&peers_url, &()) + .await + .with_context(|| format!("failed to fetch peers from bootnode {bootnode_url}"))?; + + info!( + "bootnode returned {} peers (bootnode_id={})", + response.peers.len(), + response.my_id + ); + + // Register each peer + for peer in &response.peers { + if peer.id == my_node_id { + continue; // Skip self + } + + // Add peer to WaveKV + if let Err(e) = kv_store.add_peer(peer.id) { + warn!("failed to add peer {}: {}", peer.id, e); + continue; + } + + // Register peer URL + if !peer.url.is_empty() { + if let Err(e) = kv_store.register_peer_url(peer.id, &peer.url) { + warn!("failed to register peer URL for node {}: {}", peer.id, e); + } else { + info!( + "registered peer from bootnode: node {} -> {}", + peer.id, peer.url + ); + } + } + } + + Ok(()) +} diff --git a/gateway/src/main.rs b/gateway/src/main.rs index 61d25632..331243f8 100644 --- a/gateway/src/main.rs +++ b/gateway/src/main.rs @@ -7,7 +7,7 @@ use clap::Parser; use config::{Config, TlsConfig}; use dstack_guest_agent_rpc::{dstack_guest_client::DstackGuestClient, GetTlsKeyArgs}; use 
http_client::prpc::PrpcClient; -use ra_rpc::{client::RaClient, rocket_helper::QuoteVerifier}; +use ra_rpc::{client::RaClient, prpc_routes as prpc, rocket_helper::QuoteVerifier}; use rocket::{ fairing::AdHoc, figment::{providers::Serialized, Figment}, @@ -15,10 +15,15 @@ use rocket::{ use tracing::info; use admin_service::AdminRpcHandler; -use main_service::{Proxy, RpcHandler}; +use main_service::{Proxy, ProxyOptions, RpcHandler}; + +use crate::debug_service::DebugRpcHandler; mod admin_service; mod config; +mod debug_service; +mod distributed_certbot; +mod kv; mod main_service; mod models; mod proxy; @@ -67,7 +72,7 @@ async fn maybe_gen_certs(config: &Config, tls_config: &TlsConfig) -> Result<()> return Ok(()); } - if config.run_in_dstack { + if !config.debug.insecure_skip_attestation { info!("Using dstack guest agent for certificate generation"); let agent_client = dstack_agent().context("Failed to create dstack client")?; let response = agent_client @@ -138,9 +143,18 @@ async fn main() -> Result<()> { let figment = config::load_config_figment(args.config.as_deref()); let config = figment.focus("core").extract::()?; + + // Validate node_id + if config.sync.enabled && config.sync.node_id == 0 { + anyhow::bail!("node_id must be greater than 0"); + } + config::setup_wireguard(&config.wg)?; - let tls_config = figment.focus("tls").extract::()?; + let tls_config = figment + .focus("tls") + .extract::() + .context("Failed to extract tls config")?; maybe_gen_certs(&config, &tls_config) .await .context("Failed to generate certs")?; @@ -150,39 +164,50 @@ async fn main() -> Result<()> { set_max_ulimit()?; } - let my_app_id = if config.run_in_dstack { + let my_app_id = if config.debug.insecure_skip_attestation { + None + } else { let dstack_client = dstack_agent().context("Failed to create dstack client")?; let info = dstack_client .info() .await .context("Failed to get app info")?; Some(info.app_id) - } else { - None }; let proxy_config = config.proxy.clone(); let pccs_url = config.pccs_url.clone(); let admin_enabled = config.admin.enabled; - let state = main_service::Proxy::new(config, my_app_id).await?; + let debug_config = config.debug.clone(); + let state = Proxy::new(ProxyOptions { + config, + my_app_id, + tls_config, + }) + .await?; info!("Starting background tasks"); state.start_bg_tasks().await?; state.lock().reconfigure()?; proxy::start(proxy_config, state.clone()); - let admin_figment = - Figment::new() - .merge(rocket::Config::default()) - .merge(Serialized::defaults( - figment - .find_value("core.admin") - .context("admin section not found")?, - )); + let admin_value = figment + .find_value("core.admin") + .context("admin section not found")?; + let debug_value = figment + .find_value("core.debug") + .context("debug section not found")?; + + let admin_figment = Figment::new() + .merge(rocket::Config::default()) + .merge(Serialized::defaults(admin_value)); + + let debug_figment = Figment::new() + .merge(rocket::Config::default()) + .merge(Serialized::defaults(debug_value)); let mut rocket = rocket::custom(figment) - .mount( - "/prpc", - ra_rpc::prpc_routes!(Proxy, RpcHandler, trim: "Tproxy."), - ) + .mount("/prpc", prpc!(Proxy, RpcHandler, trim: "Tproxy.")) + // Mount WaveKV sync endpoint (requires mTLS gateway auth) + .mount("/", web_routes::wavekv_sync_routes()) .attach(AdHoc::on_response("Add app version header", |_req, res| { Box::pin(async move { res.set_raw_header("X-App-Version", app_version()); @@ -192,12 +217,26 @@ async fn main() -> Result<()> { let verifier = 
QuoteVerifier::new(pccs_url); rocket = rocket.manage(verifier); let main_srv = rocket.launch(); + let admin_state = state.clone(); + let debug_state = state; let admin_srv = async move { if admin_enabled { rocket::custom(admin_figment) .mount("/", web_routes::routes()) - .mount("/", ra_rpc::prpc_routes!(Proxy, AdminRpcHandler)) - .manage(state) + .mount("/", prpc!(Proxy, AdminRpcHandler, trim: "Admin.")) + .mount("/prpc", prpc!(Proxy, AdminRpcHandler, trim: "Admin.")) + .manage(admin_state) + .launch() + .await + } else { + std::future::pending().await + } + }; + let debug_srv = async move { + if debug_config.insecure_enable_debug_rpc { + rocket::custom(debug_figment) + .mount("/prpc", prpc!(Proxy, DebugRpcHandler, trim: "Debug.")) + .manage(debug_state) .launch() .await } else { @@ -211,6 +250,9 @@ async fn main() -> Result<()> { result = admin_srv => { result.map_err(|err| anyhow!("Failed to start admin server: {err:?}"))?; } + result = debug_srv => { + result.map_err(|err| anyhow!("Failed to start debug server: {err:?}"))?; + } } Ok(()) } diff --git a/gateway/src/main_service.rs b/gateway/src/main_service.rs index 9162ff40..81b6622f 100644 --- a/gateway/src/main_service.rs +++ b/gateway/src/main_service.rs @@ -13,12 +13,14 @@ use std::{ use anyhow::{bail, Context, Result}; use auth_client::AuthClient; -use certbot::{CertBot, WorkDir}; +use certbot::WorkDir; + +use crate::distributed_certbot::DistributedCertBot; use cmd_lib::run_cmd as cmd; use dstack_gateway_rpc::{ gateway_server::{GatewayRpc, GatewayServer}, - AcmeInfoResponse, GatewayState, GuestAgentConfig, InfoResponse, QuotedPublicKey, - RegisterCvmRequest, RegisterCvmResponse, WireGuardConfig, WireGuardPeer, + AcmeInfoResponse, GatewayNodeInfo, GetPeersResponse, GuestAgentConfig, InfoResponse, PeerInfo, + QuotedPublicKey, RegisterCvmRequest, RegisterCvmResponse, WireGuardConfig, WireGuardPeer, }; use dstack_guest_agent_rpc::{dstack_guest_client::DstackGuestClient, RawQuoteArgs}; use fs_err as fs; @@ -35,13 +37,15 @@ use tokio_rustls::TlsAcceptor; use tracing::{debug, error, info, warn}; use crate::{ - config::Config, + config::{Config, TlsConfig}, + kv::{ + fetch_peers_from_bootnode, AppIdValidator, HttpsClientConfig, InstanceData, KvStore, + NodeData, NodeStatus, WaveKvSyncService, + }, models::{InstanceInfo, WgConf}, proxy::{create_acceptor, AddressGroup, AddressInfo}, }; -mod sync_client; - mod auth_client; #[derive(Clone)] @@ -58,26 +62,21 @@ impl Deref for Proxy { pub struct ProxyInner { pub(crate) config: Arc, - pub(crate) certbot: Option>, + pub(crate) certbot: Option>, my_app_id: Option>, state: Mutex, - notify_state_updated: Notify, + pub(crate) notify_state_updated: Notify, auth_client: AuthClient, pub(crate) acceptor: RwLock, pub(crate) h2_acceptor: RwLock, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub(crate) struct GatewayNodeInfo { - pub id: Vec, - pub url: String, - pub wg_peer: WireGuardPeer, - pub last_seen: SystemTime, + /// WaveKV-based store for persistence (and cross-node sync when enabled) + kv_store: Arc, + /// WaveKV sync service for network synchronization + pub(crate) wavekv_sync: Option>, } #[derive(Debug, Serialize, Deserialize, Default)] pub(crate) struct ProxyStateMut { - pub(crate) nodes: BTreeMap, pub(crate) apps: BTreeMap>, pub(crate) instances: BTreeMap, pub(crate) allocated_addresses: BTreeSet, @@ -88,12 +87,22 @@ pub(crate) struct ProxyStateMut { pub(crate) struct ProxyState { pub(crate) config: Arc, pub(crate) state: ProxyStateMut, + /// Reference to KvStore for syncing changes + 
kv_store: Arc, +} + +/// Options for creating a Proxy instance +pub struct ProxyOptions { + pub config: Config, + pub my_app_id: Option>, + /// TLS configuration (from Rocket's tls config) + pub tls_config: TlsConfig, } impl Proxy { - pub async fn new(config: Config, my_app_id: Option>) -> Result { + pub async fn new(options: ProxyOptions) -> Result { Ok(Self { - _inner: Arc::new(ProxyInner::new(config, my_app_id).await?), + _inner: Arc::new(ProxyInner::new(options).await?), }) } } @@ -103,52 +112,124 @@ impl ProxyInner { self.state.lock().expect("Failed to lock AppState") } - pub async fn new(config: Config, my_app_id: Option>) -> Result { + pub async fn new(options: ProxyOptions) -> Result { + let ProxyOptions { + config, + my_app_id, + tls_config, + } = options; let config = Arc::new(config); - let mut state = fs::metadata(&config.state_path) - .is_ok() - .then(|| load_state(&config.state_path)) - .transpose() - .unwrap_or_else(|err| { - error!("Failed to load state: {err}"); - None - }) - .unwrap_or_default(); - state - .nodes - .retain(|_, info| info.wg_peer.ip != config.wg.ip.to_string()); - state.nodes.insert( - config.wg.public_key.clone(), - GatewayNodeInfo { - id: config.id(), - url: config.sync.my_url.clone(), - wg_peer: WireGuardPeer { - pk: config.wg.public_key.clone(), - ip: config.wg.ip.to_string(), - endpoint: config.wg.endpoint.clone(), - }, - last_seen: SystemTime::now(), - }, + + // Initialize WaveKV store without peers (peers will be added dynamically from bootnode) + let kv_store = Arc::new( + KvStore::new(config.sync.node_id, vec![], &config.sync.data_dir) + .context("failed to initialize WaveKV store")?, + ); + info!( + "WaveKV store initialized: node_id={}, sync_enabled={}", + config.sync.node_id, config.sync.enabled + ); + + // Load state from WaveKV + let instances = kv_store.load_all_instances(); + let nodes = kv_store.load_all_nodes(); + info!( + "Loaded state from WaveKV: {} instances, {} nodes", + instances.len(), + nodes.len() ); + let state = build_state_from_kv_store(instances); + + // Sync this node to KvStore + let node_data = NodeData { + uuid: config.uuid(), + url: config.sync.my_url.clone(), + wg_public_key: config.wg.public_key.clone(), + wg_endpoint: config.wg.endpoint.clone(), + wg_ip: config.wg.ip.to_string(), + }; + if let Err(err) = kv_store.sync_node(config.sync.node_id, &node_data) { + error!("Failed to sync this node to KvStore: {err}"); + } + // Set this node's status to Online + if let Err(err) = kv_store.set_node_status(config.sync.node_id, NodeStatus::Up) { + error!("Failed to set node status: {err}"); + } + // Register this node's sync URL in DB (for peer discovery) + if let Err(err) = kv_store.register_peer_url(config.sync.node_id, &config.sync.my_url) { + error!("Failed to register peer URL: {err}"); + } + + // Build HttpsClientConfig for mTLS communication + let https_config = { + let tls = &tls_config; + let cert_validator = my_app_id + .clone() + .map(|app_id| Arc::new(AppIdValidator::new(app_id)) as _); + HttpsClientConfig { + cert_path: tls.certs.clone(), + key_path: tls.key.clone(), + ca_cert_path: tls.mutual.ca_certs.clone(), + cert_validator, + } + }; + + // Fetch peers from bootnode if configured (only when sync is enabled) + if config.sync.enabled && !config.sync.bootnode.is_empty() { + if let Err(err) = fetch_peers_from_bootnode( + &config.sync.bootnode, + &kv_store, + config.sync.node_id, + &https_config, + ) + .await + { + warn!("Failed to fetch peers from bootnode: {err}"); + } + } + + // Create WaveKV sync service (only 
if sync is enabled) + let wavekv_sync = if config.sync.enabled { + match WaveKvSyncService::new(&kv_store, &config.sync, https_config) { + Ok(sync_service) => Some(Arc::new(sync_service)), + Err(err) => { + error!("Failed to create WaveKV sync service: {err}"); + None + } + } + } else { + None + }; + let state = Mutex::new(ProxyState { config: config.clone(), state, + kv_store: kv_store.clone(), }); let auth_client = AuthClient::new(config.auth.clone()); + // Bootstrap WaveKV first if sync is enabled, so certbot can load certs from peers + if let Some(ref wavekv_sync) = wavekv_sync { + info!("WaveKV: bootstrapping from peers..."); + if let Err(err) = wavekv_sync.bootstrap().await { + warn!("WaveKV bootstrap failed: {err}"); + } + } + + // Now initialize certbot - it can access synced data from KvStore let certbot = match config.certbot.enabled { true => { - let certbot = config - .certbot - .build_bot() + let certbot = DistributedCertBot::new(config.certbot.clone(), kv_store.clone()); + info!("Initializing DistributedCertBot..."); + certbot + .init() .await - .context("Failed to build certbot")?; - info!("Certbot built, renewing..."); - // Try first renewal for the acceptor creation - certbot.renew(false).await.context("Failed to renew cert")?; + .context("Failed to initialize distributed certbot")?; Some(Arc::new(certbot)) } false => None, }; + + // Create acceptors (cert files now exist) let acceptor = RwLock::new( create_acceptor(&config.proxy, false).context("Failed to create acceptor")?, ); @@ -163,14 +244,28 @@ impl ProxyInner { acceptor, h2_acceptor, certbot, + kv_store, + wavekv_sync, }) } + + pub(crate) fn kv_store(&self) -> &Arc { + &self.kv_store + } + + pub(crate) fn my_app_id(&self) -> Option<&[u8]> { + self.my_app_id.as_deref() + } } impl Proxy { pub(crate) async fn start_bg_tasks(&self) -> Result<()> { start_recycle_thread(self.clone()); - start_sync_task(self.clone()); + // Start WaveKV periodic sync (bootstrap already done in new()) + if let Some(ref wavekv_sync) = self.wavekv_sync { + start_wavekv_sync_task(self.clone(), wavekv_sync.clone()).await; + } + start_wavekv_watch_task(self.clone()).context("Failed to start WaveKV watch task")?; start_certbot_task(self.clone()).await?; Ok(()) } @@ -179,12 +274,26 @@ impl Proxy { let Some(certbot) = &self.certbot else { return Ok(false); }; - let renewed = certbot.renew(force).await.context("Failed to renew cert")?; - if renewed { + let renewed = certbot + .try_renew(force) + .await + .context("Failed to renew cert")?; + Ok(renewed) + } + + /// Reload certificate from KvStore (called when watcher triggers) + pub(crate) fn reload_cert_from_kvstore(&self) -> Result { + let Some(certbot) = &self.certbot else { + return Ok(false); + }; + let reloaded = certbot + .reload_from_kvstore() + .context("Failed to reload cert from KvStore")?; + if reloaded { self.reload_certificates() .context("Failed to reload certificates")?; } - Ok(renewed) + Ok(reloaded) } pub(crate) async fn acme_info(&self) -> Result { @@ -192,35 +301,49 @@ impl Proxy { let workdir = WorkDir::new(&config.certbot.workdir); let account_uri = workdir.acme_account_uri().unwrap_or_default(); let keys = workdir.list_cert_public_keys().unwrap_or_default(); - let agent = crate::dstack_agent().context("Failed to get dstack agent")?; - let account_quote = get_or_generate_quote( - &agent, - QuoteContentType::Custom("acme-account"), - account_uri.as_bytes(), - workdir.acme_account_quote_path(), - ) - .await - .unwrap_or_default(); - let mut quoted_hist_keys = vec![]; - for 
cert_path in workdir.list_certs().unwrap_or_default() { - let cert_pem = fs::read_to_string(&cert_path).context("Failed to read key")?; - let pubkey = certbot::read_pubkey(&cert_pem).context("Failed to read pubkey")?; - let quote = get_or_generate_quote( - &agent, - QuoteContentType::Custom("zt-cert"), - &pubkey, - cert_path.display().to_string() + ".quote", + // Try to get dstack agent for quote generation (optional in test environments) + let agent = crate::dstack_agent().ok(); + + let account_quote = match &agent { + Some(agent) => get_or_generate_quote( + agent, + QuoteContentType::Custom("acme-account"), + account_uri.as_bytes(), + workdir.acme_account_quote_path(), ) .await - .unwrap_or_default(); + .unwrap_or_default(), + None => String::new(), + }; + + let mut quoted_hist_keys = vec![]; + for cert_path in workdir.list_certs().unwrap_or_default() { + let cert_pem = match fs::read_to_string(&cert_path) { + Ok(pem) => pem, + Err(_) => continue, + }; + let pubkey = match certbot::read_pubkey(&cert_pem) { + Ok(pk) => pk, + Err(_) => continue, + }; + let quote = match &agent { + Some(agent) => get_or_generate_quote( + agent, + QuoteContentType::Custom("zt-cert"), + &pubkey, + cert_path.display().to_string() + ".quote", + ) + .await + .unwrap_or_default(), + None => String::new(), + }; quoted_hist_keys.push(QuotedPublicKey { public_key: pubkey, quote, }); } - let active_cert = - fs::read_to_string(workdir.cert_path()).context("Failed to read active cert")?; + let active_cert = fs::read_to_string(workdir.cert_path()).unwrap_or_default(); Ok(AcmeInfoResponse { account_uri, @@ -231,11 +354,89 @@ impl Proxy { base_domain: config.proxy.base_domain.clone(), }) } + + /// Register a CVM with the given app_id, instance_id and client_public_key + pub fn do_register_cvm( + &self, + app_id: &str, + instance_id: &str, + client_public_key: &str, + ) -> Result { + let mut state = self.lock(); + + // Check if this node is marked as down + let my_status = state.kv_store.get_node_status(state.config.sync.node_id); + if matches!(my_status, NodeStatus::Down) { + bail!("this gateway node is marked as down and cannot accept new registrations"); + } + + if app_id.is_empty() { + bail!("[{instance_id}] app id is empty"); + } + if instance_id.is_empty() { + bail!("[{instance_id}] instance id is empty"); + } + if client_public_key.is_empty() { + bail!("[{instance_id}] client public key is empty"); + } + let client_info = state + .new_client_by_id(instance_id, app_id, client_public_key) + .context("failed to allocate IP address for client")?; + if let Err(err) = state.reconfigure() { + error!("failed to reconfigure: {}", err); + } + let gateways = state.get_active_nodes(); + let servers = gateways + .iter() + .map(|n| WireGuardPeer { + pk: n.wg_public_key.clone(), + ip: n.wg_ip.clone(), + endpoint: n.wg_endpoint.clone(), + }) + .collect::>(); + let response = RegisterCvmResponse { + wg: Some(WireGuardConfig { + client_ip: client_info.ip.to_string(), + servers, + }), + agent: Some(GuestAgentConfig { + external_port: state.config.proxy.external_port as u32, + internal_port: state.config.proxy.agent_port as u32, + domain: state.config.proxy.base_domain.clone(), + app_address_ns_prefix: state.config.proxy.app_address_ns_prefix.clone(), + }), + gateways, + }; + self.notify_state_updated.notify_one(); + Ok(response) + } } -fn load_state(state_path: &str) -> Result { - let state_str = fs::read_to_string(state_path).context("Failed to read state")?; - serde_json::from_str(&state_str).context("Failed to load state") +fn 
build_state_from_kv_store(instances: BTreeMap) -> ProxyStateMut { + let mut state = ProxyStateMut::default(); + + // Build instances + for (instance_id, data) in instances { + let info = InstanceInfo { + id: instance_id.clone(), + app_id: data.app_id.clone(), + ip: data.ip, + public_key: data.public_key, + reg_time: UNIX_EPOCH + .checked_add(Duration::from_secs(data.reg_time)) + .unwrap_or(UNIX_EPOCH), + connections: Default::default(), + }; + state.allocated_addresses.insert(data.ip); + state + .apps + .entry(data.app_id) + .or_default() + .insert(instance_id.clone()); + state.instances.insert(instance_id, info); + } + + state } fn start_recycle_thread(proxy: Proxy) { @@ -256,6 +457,25 @@ async fn start_certbot_task(proxy: Proxy) -> Result<()> { info!("Certbot is not enabled"); return Ok(()); }; + + // Watch for cert changes from other nodes + let domain = certbot.domain().to_string(); + let kv_store = proxy.kv_store.clone(); + let proxy_for_watch = proxy.clone(); + tokio::spawn(async move { + let mut rx = kv_store.watch_cert(&domain); + loop { + if rx.changed().await.is_err() { + break; + } + info!("WaveKV: detected certificate change for {domain}, reloading..."); + if let Err(err) = proxy_for_watch.reload_cert_from_kvstore() { + error!("Failed to reload cert from KvStore: {err}"); + } + } + }); + + // Periodic renewal task tokio::spawn(async move { loop { tokio::time::sleep(certbot.renew_interval()).await; @@ -267,17 +487,148 @@ async fn start_certbot_task(proxy: Proxy) -> Result<()> { Ok(()) } -fn start_sync_task(proxy: Proxy) { +async fn start_wavekv_sync_task(proxy: Proxy, wavekv_sync: Arc) { if !proxy.config.sync.enabled { - info!("sync is disabled"); + info!("WaveKV sync is disabled"); return; } + + // Bootstrap already done in ProxyInner::new() before certbot init + // Peers are discovered from bootnode or via Admin.SetNodeInfo RPC + + // Start periodic sync tasks (runs forever in background) tokio::spawn(async move { - match sync_client::sync_task(proxy).await { - Ok(_) => info!("Sync task exited"), - Err(err) => error!("Failed to run sync task: {err}"), + wavekv_sync.start_sync_tasks().await; + }); + info!("WaveKV sync tasks started"); +} + +fn start_wavekv_watch_task(proxy: Proxy) -> Result<()> { + let kv_store = proxy.kv_store.clone(); + + // Watch for instance changes + let proxy_clone = proxy.clone(); + let store_clone = kv_store.clone(); + // Register watcher first, then do initial load to avoid race condition + let mut rx = store_clone.watch_instances(); + reload_instances_from_kv_store(&proxy_clone, &store_clone) + .context("Failed to initial load instances from KvStore")?; + tokio::spawn(async move { + loop { + if rx.changed().await.is_err() { + break; + } + info!("WaveKV: detected remote instance changes, reloading..."); + if let Err(err) = reload_instances_from_kv_store(&proxy_clone, &store_clone) { + error!("Failed to reload instances from KvStore: {err}"); + } } }); + + // Initial WireGuard configuration + proxy.lock().reconfigure()?; + + // Watch for node changes and reconfigure WireGuard + let mut rx = kv_store.watch_nodes(); + let proxy_for_nodes = proxy.clone(); + tokio::spawn(async move { + loop { + if rx.changed().await.is_err() { + break; + } + info!("WaveKV: detected remote node changes, reconfiguring WireGuard..."); + if let Err(err) = proxy_for_nodes.lock().reconfigure() { + error!("Failed to reconfigure WireGuard: {err}"); + } + } + }); + + // Start periodic persistence task + let persist_interval = proxy.config.sync.persist_interval; + if 
!persist_interval.is_zero() { + let kv_store_for_persist = kv_store.clone(); + tokio::spawn(async move { + let mut ticker = tokio::time::interval(persist_interval); + loop { + ticker.tick().await; + match kv_store_for_persist.persist_if_dirty() { + Ok(true) => info!("WaveKV: periodic persist completed"), + Ok(false) => {} // No changes to persist + Err(err) => error!("WaveKV: periodic persist failed: {err}"), + } + } + }); + info!("WaveKV: periodic persistence enabled (interval: {persist_interval:?})"); + } + + // Start periodic connection sync task + if proxy.config.sync.sync_connections_enabled { + let sync_interval = proxy.config.sync.sync_connections_interval; + let proxy_for_sync = proxy.clone(); + tokio::spawn(async move { + let mut ticker = tokio::time::interval(sync_interval); + loop { + ticker.tick().await; + let state = proxy_for_sync.lock(); + for (instance_id, instance) in &state.state.instances { + let count = instance.num_connections(); + state.sync_connections(instance_id, count); + } + } + }); + info!( + "WaveKV: periodic connection sync enabled (interval: {:?})", + proxy.config.sync.sync_connections_interval + ); + } + + Ok(()) +} + +fn reload_instances_from_kv_store(proxy: &Proxy, store: &KvStore) -> Result<()> { + let instances = store.load_all_instances(); + let mut state = proxy.lock(); + let mut wg_changed = false; + + for (instance_id, data) in instances { + let new_info = InstanceInfo { + id: instance_id.clone(), + app_id: data.app_id.clone(), + ip: data.ip, + public_key: data.public_key.clone(), + reg_time: UNIX_EPOCH + .checked_add(Duration::from_secs(data.reg_time)) + .unwrap_or(UNIX_EPOCH), + connections: Default::default(), + }; + + if let Some(existing) = state.state.instances.get(&instance_id) { + // Check if wg config needs update + if existing.public_key != data.public_key || existing.ip != data.ip { + wg_changed = true; + } + // Only update if remote is newer (based on reg_time) + if data.reg_time <= encode_ts(existing.reg_time) { + continue; + } + } else { + wg_changed = true; + } + + state.state.allocated_addresses.insert(data.ip); + state + .state + .apps + .entry(data.app_id) + .or_default() + .insert(instance_id.clone()); + state.state.instances.insert(instance_id, new_info); + } + + if wg_changed { + state.reconfigure()?; + } + Ok(()) } impl ProxyState { @@ -329,6 +680,16 @@ impl ProxyState { } let existing = existing.clone(); if self.valid_ip(existing.ip) { + // Sync existing instance to KvStore (might be from legacy state) + let data = InstanceData { + app_id: existing.app_id.clone(), + ip: existing.ip, + public_key: existing.public_key.clone(), + reg_time: encode_ts(existing.reg_time), + }; + if let Err(err) = self.kv_store.sync_instance(&existing.id, &data) { + error!("failed to sync existing instance to KvStore: {err}"); + } return Some(existing); } info!("ip {} is invalid, removing", existing.ip); @@ -341,7 +702,6 @@ impl ProxyState { ip, public_key: public_key.to_string(), reg_time: SystemTime::now(), - last_seen: SystemTime::now(), connections: Default::default(), }; self.add_instance(host_info.clone()); @@ -349,6 +709,17 @@ impl ProxyState { } fn add_instance(&mut self, info: InstanceInfo) { + // Sync to KvStore + let data = InstanceData { + app_id: info.app_id.clone(), + ip: info.ip, + public_key: info.public_key.clone(), + reg_time: encode_ts(info.reg_time), + }; + if let Err(err) = self.kv_store.sync_instance(&info.id, &data) { + error!("failed to sync instance to KvStore: {err}"); + } + self.state .apps .entry(info.app_id.clone()) @@ 
-377,13 +748,6 @@ impl ProxyState { Ok(_) => info!("wg config updated"), Err(e) => error!("failed to set wg config: {e}"), } - self.save_state()?; - Ok(()) - } - - fn save_state(&self) -> Result<()> { - let state_str = serde_json::to_string(&self.state).context("Failed to serialize state")?; - safe_write(&self.config.state_path, state_str).context("Failed to write state")?; Ok(()) } @@ -530,6 +894,12 @@ impl ProxyState { .instances .remove(id) .context("instance not found")?; + + // Sync deletion to KvStore + if let Err(err) = self.kv_store.sync_delete_instance(id) { + error!("Failed to sync instance deletion to KvStore: {err}"); + } + self.state.allocated_addresses.remove(&info.ip); if let Some(app_instances) = self.state.apps.get_mut(&info.app_id) { app_instances.remove(id); @@ -541,48 +911,50 @@ impl ProxyState { } fn recycle(&mut self) -> Result<()> { - // Recycle stale Gateway nodes - let mut staled_nodes = vec![]; - for node in self.state.nodes.values() { - if node.wg_peer.pk == self.config.wg.public_key { - continue; - } - if node.last_seen.elapsed().unwrap_or_default() > self.config.recycle.node_timeout { - staled_nodes.push(node.wg_peer.pk.clone()); - } - } - for id in staled_nodes { - self.state.nodes.remove(&id); + // Refresh state: sync local handshakes to KvStore, update local last_seen from global + if let Err(err) = self.refresh_state() { + warn!("failed to refresh state: {err}"); } - // Recycle stale CVM instances + // Note: Gateway nodes are not removed from KvStore, only marked offline/retired + + // Recycle stale CVM instances based on global last_seen (max across all nodes) let stale_timeout = self.config.recycle.timeout; - let stale_handshakes = self.latest_handshakes(Some(stale_timeout))?; - if tracing::enabled!(tracing::Level::DEBUG) { - for (pubkey, (ts, elapsed)) in &stale_handshakes { - debug!("stale instance: {pubkey} recent={ts} ({elapsed:?} ago)"); - } - } - // Find and remove instances with matching public keys + let now = SystemTime::now(); + let stale_instances: Vec<_> = self .state .instances .iter() - .filter(|(_, info)| { - stale_handshakes.contains_key(&info.public_key) && { - info.reg_time.elapsed().unwrap_or_default() > stale_timeout + .filter(|(id, info)| { + // Skip if instance was registered recently + if info.reg_time.elapsed().unwrap_or_default() <= stale_timeout { + return false; + } + // Check global last_seen from KvStore (max across all nodes) + let global_ts = self.kv_store.get_instance_latest_handshake(id); + let last_seen = global_ts.map(decode_ts).unwrap_or(info.reg_time); + let elapsed = now.duration_since(last_seen).unwrap_or_default(); + if elapsed > stale_timeout { + debug!( + "stale instance: {} last_seen={:?} ({:?} ago)", + id, last_seen, elapsed + ); + true + } else { + false } }) - .map(|(id, _info)| id.clone()) + .map(|(id, _)| id.clone()) .collect(); - debug!("stale instances: {:#?}", stale_instances); + let num_recycled = stale_instances.len(); for id in stale_instances { self.remove_instance(&id)?; } - info!("recycled {num_recycled} stale instances"); - // Reconfigure WireGuard with updated peers + if num_recycled > 0 { + info!("recycled {num_recycled} stale instances"); self.reconfigure()?; } Ok(()) @@ -592,89 +964,94 @@ impl ProxyState { std::process::exit(0); } - fn dedup_nodes(&mut self) { - // Dedup nodes by URL, keeping the latest one - let mut node_map = BTreeMap::::new(); + pub(crate) fn refresh_state(&mut self) -> Result<()> { + // Get local WG handshakes and sync to KvStore + let handshakes = 
self.latest_handshakes(None)?; + + // Build a map from public_key to instance_id for lookup + let pk_to_id: BTreeMap<&str, &str> = self + .state + .instances + .iter() + .map(|(id, info)| (info.public_key.as_str(), id.as_str())) + .collect(); - for node in std::mem::take(&mut self.state.nodes).into_values() { - match node_map.get(&node.wg_peer.endpoint) { - Some(existing) if existing.last_seen >= node.last_seen => {} - _ => { - node_map.insert(node.wg_peer.endpoint.clone(), node); + // Sync local handshake observations to KvStore + for (pk, (ts, _)) in &handshakes { + if let Some(&instance_id) = pk_to_id.get(pk.as_str()) { + if let Err(err) = self.kv_store.sync_instance_handshake(instance_id, *ts) { + debug!("failed to sync instance handshake: {err}"); } } } - for node in node_map.into_values() { - self.state.nodes.insert(node.wg_peer.pk.clone(), node); + + // Update this node's last_seen in KvStore + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + if let Err(err) = self + .kv_store + .sync_node_last_seen(self.config.sync.node_id, now) + { + debug!("failed to sync node last_seen: {err}"); } + Ok(()) } - fn update_state( - &mut self, - proxy_nodes: Vec, - apps: Vec, - ) -> Result<()> { - for node in proxy_nodes { - if node.wg_peer.pk == self.config.wg.public_key { - continue; - } - if node.url == self.config.sync.my_url { - continue; - } - if let Some(existing) = self.state.nodes.get(&node.wg_peer.pk) { - if node.last_seen <= existing.last_seen { - continue; - } - } - self.state.nodes.insert(node.wg_peer.pk.clone(), node); + /// Sync connection count for an instance to KvStore + pub(crate) fn sync_connections(&self, instance_id: &str, count: u64) { + if let Err(err) = self.kv_store.sync_connections(instance_id, count) { + debug!("Failed to sync connections: {err}"); } - self.dedup_nodes(); + } - let mut wg_changed = false; - for app in apps { - if let Some(existing) = self.state.instances.get(&app.id) { - let existing_ts = (existing.reg_time, existing.last_seen); - let update_ts = (app.reg_time, app.last_seen); - if update_ts <= existing_ts { - continue; - } - if !wg_changed { - wg_changed = existing.public_key != app.public_key || existing.ip != app.ip; - } - } else { - wg_changed = true; - } - self.add_instance(app); - } - info!("updated, wg_changed: {wg_changed}"); - if wg_changed { - self.reconfigure()?; - } else { - self.save_state()?; - } - Ok(()) + /// Get latest handshake for an instance from KvStore (max across all nodes) + pub(crate) fn get_instance_latest_handshake(&self, instance_id: &str) -> Option { + self.kv_store.get_instance_latest_handshake(instance_id) } - fn dump_state(&mut self) -> (Vec, Vec) { - self.refresh_state().ok(); - ( - self.state.nodes.values().cloned().collect(), - self.state.instances.values().cloned().collect(), - ) + /// Get all nodes from KvStore (for admin API - includes all nodes) + pub(crate) fn get_all_nodes(&self) -> Vec { + self.get_all_nodes_filtered(false) } - pub(crate) fn refresh_state(&mut self) -> Result<()> { - let handshakes = self.latest_handshakes(None)?; - for instance in self.state.instances.values_mut() { - let Some((ts, _)) = handshakes.get(&instance.public_key).copied() else { - continue; - }; - instance.last_seen = decode_ts(ts); - } - if let Some(node) = self.state.nodes.get_mut(&self.config.wg.public_key) { - node.last_seen = SystemTime::now(); - } - Ok(()) + /// Get nodes for CVM registration (excludes nodes with status "down") + pub(crate) fn get_active_nodes(&self) -> Vec { + 
self.get_all_nodes_filtered(true) + } + + /// Get all nodes from KvStore with optional filtering + fn get_all_nodes_filtered(&self, exclude_down: bool) -> Vec { + let node_statuses = if exclude_down { + self.kv_store.load_all_node_statuses() + } else { + Default::default() + }; + + self.kv_store + .load_all_nodes() + .into_iter() + .filter(|(id, _)| { + if !exclude_down { + return true; + } + // Exclude nodes with status "down" + match node_statuses.get(id) { + Some(NodeStatus::Down) => false, + _ => true, // Include Up or nodes without explicit status + } + }) + .map(|(id, node)| GatewayNodeInfo { + id, + uuid: node.uuid, + wg_public_key: node.wg_public_key, + wg_ip: node.wg_ip, + wg_endpoint: node.wg_endpoint, + url: node.url, + last_seen: self.kv_store.get_node_latest_last_seen(id).unwrap_or(0), + }) + .collect() } } @@ -696,7 +1073,7 @@ pub struct RpcHandler { impl RpcHandler { fn ensure_from_gateway(&self) -> Result<()> { - if !self.state.config.run_in_dstack { + if self.state.config.debug.insecure_skip_attestation { return Ok(()); } if self.remote_app_id.is_none() { @@ -724,76 +1101,14 @@ impl GatewayRpc for RpcHandler { .context("App authorization failed")?; let app_id = hex::encode(&app_info.app_id); let instance_id = hex::encode(&app_info.instance_id); - - let mut state = self.state.lock(); - if request.client_public_key.is_empty() { - bail!("[{instance_id}] client public key is empty"); - } - let client_info = state - .new_client_by_id(&instance_id, &app_id, &request.client_public_key) - .context("failed to allocate IP address for client")?; - if let Err(err) = state.reconfigure() { - error!("failed to reconfigure: {}", err); - } - let servers = state - .state - .nodes - .values() - .map(|n| n.wg_peer.clone()) - .collect::>(); - let response = RegisterCvmResponse { - wg: Some(WireGuardConfig { - client_ip: client_info.ip.to_string(), - servers, - }), - agent: Some(GuestAgentConfig { - external_port: state.config.proxy.external_port as u32, - internal_port: state.config.proxy.agent_port as u32, - domain: state.config.proxy.base_domain.clone(), - app_address_ns_prefix: state.config.proxy.app_address_ns_prefix.clone(), - }), - }; - self.state.notify_state_updated.notify_one(); - Ok(response) + self.state + .do_register_cvm(&app_id, &instance_id, &request.client_public_key) } async fn acme_info(self) -> Result { self.state.acme_info().await } - async fn update_state(self, request: GatewayState) -> Result<()> { - self.ensure_from_gateway()?; - let mut nodes = vec![]; - let mut apps = vec![]; - - for node in request.nodes { - nodes.push(GatewayNodeInfo { - id: node.id, - wg_peer: node.wg_peer.context("wg_peer is missing")?, - last_seen: decode_ts(node.last_seen), - url: node.url, - }); - } - - for app in request.apps { - apps.push(InstanceInfo { - id: app.instance_id, - app_id: app.app_id, - ip: app.ip.parse().context("Invalid IP address")?, - public_key: app.public_key, - reg_time: decode_ts(app.reg_time), - last_seen: decode_ts(app.last_seen), - connections: Default::default(), - }); - } - - self.state - .lock() - .update_state(nodes, apps) - .context("failed to update state")?; - Ok(()) - } - async fn info(self) -> Result { let state = self.state.lock(); Ok(InfoResponse { @@ -802,6 +1117,27 @@ impl GatewayRpc for RpcHandler { app_address_ns_prefix: state.config.proxy.app_address_ns_prefix.clone(), }) } + + async fn get_peers(self) -> Result { + self.ensure_from_gateway()?; + + let kv_store = self.state.kv_store(); + let config = &self.state.config; + + // Get all peer addresses from 
KvStore + let peer_addrs = kv_store.get_all_peer_addrs(); + + let peers: Vec = peer_addrs + .into_iter() + .map(|(id, url)| PeerInfo { id, url }) + .collect(); + + Ok(GetPeersResponse { + my_id: config.sync.node_id, + my_url: config.sync.my_url.clone(), + peers, + }) + } } async fn get_or_generate_quote( @@ -836,30 +1172,5 @@ impl RpcCall for RpcHandler { } } -impl From for dstack_gateway_rpc::GatewayNodeInfo { - fn from(node: GatewayNodeInfo) -> Self { - Self { - id: node.id, - wg_peer: Some(node.wg_peer), - last_seen: encode_ts(node.last_seen), - url: node.url, - } - } -} - -impl From for dstack_gateway_rpc::AppInstanceInfo { - fn from(app: InstanceInfo) -> Self { - Self { - num_connections: app.num_connections(), - instance_id: app.id, - app_id: app.app_id, - ip: app.ip.to_string(), - public_key: app.public_key, - reg_time: encode_ts(app.reg_time), - last_seen: encode_ts(app.last_seen), - } - } -} - #[cfg(test)] mod tests; diff --git a/gateway/src/main_service/sync_client.rs b/gateway/src/main_service/sync_client.rs deleted file mode 100644 index 7feba2a0..00000000 --- a/gateway/src/main_service/sync_client.rs +++ /dev/null @@ -1,183 +0,0 @@ -// SPDX-FileCopyrightText: © 2025 Phala Network -// -// SPDX-License-Identifier: Apache-2.0 - -use std::time::{Duration, Instant}; - -use anyhow::{Context, Result}; -use dstack_gateway_rpc::{gateway_client::GatewayClient, GatewayState}; -use dstack_guest_agent_rpc::GetTlsKeyArgs; -use ra_rpc::client::{RaClient, RaClientConfig}; -use tracing::{error, info}; - -use crate::{dstack_agent, main_service::Proxy}; - -struct SyncClient { - in_dstack: bool, - cert_pem: String, - key_pem: String, - ca_cert_pem: String, - app_id: Vec, - timeout: Duration, - pccs_url: Option, -} - -impl SyncClient { - fn create_rpc_client(&self, url: &str) -> Result> { - let app_id = self.app_id.clone(); - let url = format!("{}/prpc", url.trim_end_matches('/')); - let client = if self.in_dstack { - RaClientConfig::builder() - .remote_uri(url) - // Don't verify server RA because we use the CA cert from KMS to verify - // the server cert. - .verify_server_attestation(false) - .tls_no_check(true) - .tls_no_check_hostname(false) - .tls_client_cert(self.cert_pem.clone()) - .tls_client_key(self.key_pem.clone()) - .tls_ca_cert(self.ca_cert_pem.clone()) - .tls_built_in_root_certs(false) - .maybe_pccs_url(self.pccs_url.clone()) - .cert_validator(Box::new(move |cert| { - let cert = cert.context("TLS cert not found")?; - let remote_app_id = cert.app_id.context("App id not found")?; - if remote_app_id != app_id { - return Err(anyhow::anyhow!("Remote app id mismatch")); - } - Ok(()) - })) - .build() - .into_client() - .context("failed to create client")? - } else { - RaClient::new(url, true)? - }; - Ok(GatewayClient::new(client)) - } - - async fn sync_state(&self, url: &str, state: &GatewayState) -> Result<()> { - info!("Trying to sync state to {url}"); - let rpc = self.create_rpc_client(url)?; - tokio::time::timeout(self.timeout, rpc.update_state(state.clone())) - .await - .ok() - .context("Timeout while syncing state")? 
- .context("Failed to sync state")?; - info!("Synced state to {url}"); - Ok(()) - } - - async fn sync_state_ignore_error(&self, url: &str, state: &GatewayState) -> bool { - match self.sync_state(url, state).await { - Ok(_) => true, - Err(e) => { - error!("Failed to sync state to {url}: {e:?}"); - false - } - } - } -} - -pub(crate) async fn sync_task(proxy: Proxy) -> Result<()> { - let config = proxy.config.clone(); - let sync_client = if config.run_in_dstack { - let agent = dstack_agent().context("Failed to create dstack agent client")?; - let keys = agent - .get_tls_key(GetTlsKeyArgs { - subject: "dstack-gateway-sync-client".into(), - alt_names: vec![], - usage_ra_tls: false, - usage_server_auth: false, - usage_client_auth: true, - }) - .await - .context("Failed to get sync-client keys")?; - let my_app_id = agent - .info() - .await - .context("Failed to get guest info")? - .app_id; - SyncClient { - in_dstack: true, - cert_pem: keys.certificate_chain.join("\n"), - key_pem: keys.key, - ca_cert_pem: keys.certificate_chain.last().cloned().unwrap_or_default(), - app_id: my_app_id, - timeout: config.sync.timeout, - pccs_url: config.pccs_url.clone(), - } - } else { - SyncClient { - in_dstack: false, - cert_pem: "".into(), - key_pem: "".into(), - ca_cert_pem: "".into(), - app_id: vec![], - timeout: config.sync.timeout, - pccs_url: config.pccs_url.clone(), - } - }; - - let mut last_broadcast_time = Instant::now(); - let mut broadcast = false; - loop { - if broadcast { - last_broadcast_time = Instant::now(); - } - - let (mut nodes, apps) = proxy.lock().dump_state(); - // Sort nodes by pubkey - nodes.sort_by(|a, b| a.id.cmp(&b.id)); - - let self_idx = nodes - .iter() - .position(|n| n.wg_peer.pk == config.wg.public_key) - .unwrap_or(0); - - let state = GatewayState { - nodes: nodes.into_iter().map(|n| n.into()).collect(), - apps: apps.into_iter().map(|a| a.into()).collect(), - }; - - if state.nodes.is_empty() { - // If no nodes exist yet, sync with bootnode - sync_client - .sync_state_ignore_error(&config.sync.bootnode, &state) - .await; - } else { - let nodes = &state.nodes; - // Try nodes after self, wrapping around to beginning - let mut success = false; - for i in 1..nodes.len() { - let idx = (self_idx + i) % nodes.len(); - if sync_client - .sync_state_ignore_error(&nodes[idx].url, &state) - .await - { - success = true; - if !broadcast { - break; - } - } - } - - // If no node succeeded, try bootnode as fallback - if !success { - info!("Fallback to sync with bootnode"); - sync_client - .sync_state_ignore_error(&config.sync.bootnode, &state) - .await; - } - } - - tokio::select! 
{ - _ = proxy.notify_state_updated.notified() => { - broadcast = true; - } - _ = tokio::time::sleep(config.sync.interval) => { - broadcast = last_broadcast_time.elapsed() >= config.sync.broadcast_interval; - } - } - } -} diff --git a/gateway/src/main_service/tests.rs b/gateway/src/main_service/tests.rs index d98c0131..71f972a5 100644 --- a/gateway/src/main_service/tests.rs +++ b/gateway/src/main_service/tests.rs @@ -3,17 +3,47 @@ // SPDX-License-Identifier: Apache-2.0 use super::*; -use crate::config::{load_config_figment, Config}; +use crate::config::{load_config_figment, Config, MutualConfig}; +use tempfile::TempDir; -async fn create_test_state() -> Proxy { +struct TestState { + proxy: Proxy, + _temp_dir: TempDir, +} + +impl std::ops::Deref for TestState { + type Target = Proxy; + fn deref(&self) -> &Self::Target { + &self.proxy + } +} + +async fn create_test_state() -> TestState { let figment = load_config_figment(None); let mut config = figment.focus("core").extract::().unwrap(); let cargo_dir = env!("CARGO_MANIFEST_DIR"); config.proxy.cert_chain = format!("{cargo_dir}/assets/cert.pem"); config.proxy.cert_key = format!("{cargo_dir}/assets/cert.key"); - Proxy::new(config, None) + let temp_dir = TempDir::new().expect("failed to create temp dir"); + config.sync.data_dir = temp_dir.path().to_string_lossy().to_string(); + let options = ProxyOptions { + config, + my_app_id: None, + tls_config: TlsConfig { + certs: "".to_string(), + key: "".to_string(), + mutual: MutualConfig { + ca_certs: "".to_string(), + }, + }, + }; + let proxy = Proxy::new(options) .await - .expect("failed to create app state") + .expect("failed to create app state"); + TestState { + proxy, + _temp_dir: temp_dir, + } } #[tokio::test] diff --git a/gateway/src/models.rs b/gateway/src/models.rs index ec476cff..37caa274 100644 --- a/gateway/src/models.rs +++ b/gateway/src/models.rs @@ -60,7 +60,6 @@ pub struct InstanceInfo { pub ip: Ipv4Addr, pub public_key: String, pub reg_time: SystemTime, - pub last_seen: SystemTime, #[serde(skip)] pub connections: Arc, } diff --git a/gateway/src/proxy.rs b/gateway/src/proxy.rs index 75cc286e..91f841aa 100644 --- a/gateway/src/proxy.rs +++ b/gateway/src/proxy.rs @@ -16,6 +16,7 @@ pub(crate) use tls_terminate::create_acceptor; use tokio::{ io::AsyncReadExt, net::{TcpListener, TcpStream}, + runtime::Runtime, time::timeout, }; use tracing::{debug, error, info, info_span, Instrument}; @@ -160,14 +161,7 @@ async fn handle_connection( } #[inline(never)] -pub async fn proxy_main(config: &ProxyConfig, proxy: Proxy) -> Result<()> { - let workers_rt = tokio::runtime::Builder::new_multi_thread() - .thread_name("proxy-worker") - .enable_all() - .worker_threads(config.workers) - .build() - .expect("Failed to build Tokio runtime"); - +pub async fn proxy_main(rt: &Runtime, config: &ProxyConfig, proxy: Proxy) -> Result<()> { let dotted_base_domain = { let base_domain = config.base_domain.as_str(); let base_domain = base_domain.strip_prefix(".").unwrap_or(base_domain); @@ -196,7 +190,7 @@ pub async fn proxy_main(config: &ProxyConfig, proxy: Proxy) -> Result<()> { info!(%from, "new connection"); let proxy = proxy.clone(); let dotted_base_domain = dotted_base_domain.clone(); - workers_rt.spawn( + rt.spawn( async move { let _conn_entered = conn_entered; let timeouts = &proxy.config.proxy.timeouts; @@ -242,8 +236,15 @@ pub fn start(config: ProxyConfig, app_state: Proxy) { .build() .expect("Failed to build Tokio runtime"); + let worker_rt = tokio::runtime::Builder::new_multi_thread() + 
.thread_name("proxy-worker") + .enable_all() + .worker_threads(config.workers) + .build() + .expect("Failed to build Tokio runtime"); + // Run the proxy_main function in this runtime - if let Err(err) = rt.block_on(proxy_main(&config, app_state)) { + if let Err(err) = rt.block_on(proxy_main(&worker_rt, &config, app_state)) { error!( "error on {}:{}: {err:?}", config.listen_addr, config.listen_port diff --git a/gateway/src/web_routes.rs b/gateway/src/web_routes.rs index 1bd57f2b..adc7bfb5 100644 --- a/gateway/src/web_routes.rs +++ b/gateway/src/web_routes.rs @@ -7,6 +7,7 @@ use anyhow::Result; use rocket::{get, response::content::RawHtml, routes, Route, State}; mod route_index; +mod wavekv_sync; #[get("/")] async fn index(state: &State) -> Result, String> { @@ -16,3 +17,8 @@ async fn index(state: &State) -> Result, String> { pub fn routes() -> Vec { routes![index] } + +/// WaveKV sync endpoint (for main server, requires mTLS gateway auth) +pub fn wavekv_sync_routes() -> Vec { + routes![wavekv_sync::sync_store] +} diff --git a/gateway/src/web_routes/wavekv_sync.rs b/gateway/src/web_routes/wavekv_sync.rs new file mode 100644 index 00000000..dead1141 --- /dev/null +++ b/gateway/src/web_routes/wavekv_sync.rs @@ -0,0 +1,150 @@ +// SPDX-FileCopyrightText: © 2024-2025 Phala Network +// +// SPDX-License-Identifier: Apache-2.0 + +//! WaveKV sync HTTP endpoints +//! +//! Sync data is encoded using msgpack + gzip compression for efficiency. + +use crate::{ + kv::{decode, encode}, + main_service::Proxy, +}; +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; +use ra_tls::traits::CertExt; +use rocket::{ + data::{Data, ToByteUnit}, + http::{ContentType, Status}, + mtls::{oid::Oid, Certificate}, + post, State, +}; +use std::io::{Read, Write}; +use tracing::warn; +use wavekv::sync::{SyncMessage, SyncResponse}; + +/// Wrapper to implement CertExt for Rocket's Certificate +struct RocketCert<'a>(&'a Certificate<'a>); + +impl CertExt for RocketCert<'_> { + fn get_extension_der(&self, oid: &[u64]) -> anyhow::Result>> { + let oid = Oid::from(oid).map_err(|_| anyhow::anyhow!("failed to create OID from slice"))?; + let Some(ext) = self.0.extensions().iter().find(|ext| ext.oid == oid) else { + return Ok(None); + }; + Ok(Some(ext.value.to_vec())) + } +} + +/// Decode compressed msgpack data +fn decode_sync_message(data: &[u8]) -> Result { + // Decompress + let mut decoder = GzDecoder::new(data); + let mut decompressed = Vec::new(); + decoder.read_to_end(&mut decompressed).map_err(|e| { + warn!("failed to decompress sync message: {e}"); + Status::BadRequest + })?; + + decode(&decompressed).map_err(|e| { + warn!("failed to decode sync message: {e}"); + Status::BadRequest + }) +} + +/// Encode and compress sync response +fn encode_sync_response(response: &SyncResponse) -> Result, Status> { + let encoded = encode(response).map_err(|e| { + warn!("failed to encode sync response: {e}"); + Status::InternalServerError + })?; + + // Compress + let mut encoder = GzEncoder::new(Vec::new(), Compression::fast()); + encoder.write_all(&encoded).map_err(|e| { + warn!("failed to compress sync response: {e}"); + Status::InternalServerError + })?; + encoder.finish().map_err(|e| { + warn!("failed to finish compression: {e}"); + Status::InternalServerError + }) +} + +/// Verify that the request is from a gateway with the same app_id (mTLS verification) +fn verify_gateway_peer(state: &Proxy, cert: Option>) -> Result<(), Status> { + // Skip verification if not running in dstack (test mode) + if 
state.config.debug.insecure_skip_attestation { + return Ok(()); + } + + let Some(cert) = cert else { + warn!("WaveKV sync: client certificate required but not provided"); + return Err(Status::Unauthorized); + }; + + let remote_app_id = RocketCert(&cert).get_app_id().map_err(|e| { + warn!("WaveKV sync: failed to extract app_id from certificate: {e}"); + Status::Unauthorized + })?; + + let Some(remote_app_id) = remote_app_id else { + warn!("WaveKV sync: certificate does not contain app_id"); + return Err(Status::Unauthorized); + }; + + if state.my_app_id() != Some(remote_app_id.as_slice()) { + warn!( + "WaveKV sync: app_id mismatch, expected {:?}, got {:?}", + state.my_app_id(), + remote_app_id + ); + return Err(Status::Forbidden); + } + + Ok(()) +} + +/// Handle sync request (msgpack + gzip encoded) +#[post("/wavekv/sync/", data = "")] +pub async fn sync_store( + state: &State, + cert: Option>, + store: &str, + data: Data<'_>, +) -> Result<(ContentType, Vec), Status> { + verify_gateway_peer(state, cert)?; + + let Some(ref wavekv_sync) = state.wavekv_sync else { + return Err(Status::ServiceUnavailable); + }; + + // Read and decode request + let bytes = data + .open(16.mebibytes()) + .into_bytes() + .await + .map_err(|_| Status::BadRequest)?; + let msg = decode_sync_message(&bytes)?; + + // Reject sync from node_id == 0 + if msg.sender_id == 0 { + warn!("rejected sync from invalid node_id 0"); + return Err(Status::BadRequest); + } + + // Handle sync based on store type + let response = match store { + "persistent" => wavekv_sync.handle_persistent_sync(msg), + "ephemeral" => wavekv_sync.handle_ephemeral_sync(msg), + _ => return Err(Status::NotFound), + } + .map_err(|e| { + tracing::error!("{store} sync failed: {e}"); + Status::InternalServerError + })?; + + // Encode response + let encoded = encode_sync_response(&response)?; + + Ok((ContentType::new("application", "x-msgpack-gz"), encoded)) +} diff --git a/gateway/templates/dashboard.html b/gateway/templates/dashboard.html index 56750204..27ceabee 100644 --- a/gateway/templates/dashboard.html +++ b/gateway/templates/dashboard.html @@ -34,7 +34,7 @@ border-collapse: collapse; background-color: white; border-radius: 8px; - box-shadow: 0 1px 3px rgba(0,0,0,0.1); + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); margin: 20px 0; } @@ -93,14 +93,14 @@ font-size: 12px; white-space: nowrap; z-index: 1; - box-shadow: 0 2px 4px rgba(0,0,0,0.2); + box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2); } .info-section { background: white; padding: 20px; border-radius: 8px; - box-shadow: 0 1px 3px rgba(0,0,0,0.1); + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); } .info-group { @@ -152,6 +152,69 @@ text-overflow: ellipsis; white-space: nowrap; } + + .last-seen-cell { + white-space: nowrap; + } + + .last-seen-row { + margin-bottom: 4px; + } + + .last-seen-row:last-child { + margin-bottom: 0; + } + + .observer-label { + color: #666; + font-size: 0.9em; + } + + .node-status { + font-weight: bold; + } + + .node-status.up { + color: #4CAF50; + } + + .node-status.down { + color: #f44336; + } + + .status-controls { + display: flex; + gap: 5px; + } + + .status-btn { + padding: 4px 8px; + border: none; + border-radius: 4px; + cursor: pointer; + font-size: 12px; + font-weight: bold; + transition: opacity 0.2s; + } + + .status-btn:hover { + opacity: 0.8; + } + + .status-btn.up { + background-color: #4CAF50; + color: white; + } + + .status-btn.down { + background-color: #f44336; + color: white; + } + + .status-btn:disabled { + opacity: 0.5; + cursor: not-allowed; + } Dashboard \ No newline 
at end of file diff --git a/gateway/test-run/.env.example b/gateway/test-run/.env.example new file mode 100644 index 00000000..ff657175 --- /dev/null +++ b/gateway/test-run/.env.example @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# Cloudflare API token with DNS edit permissions +# Required scopes: Zone.DNS (Edit), Zone.Zone (Read) +CF_API_TOKEN=your_cloudflare_api_token_here + +# Cloudflare Zone ID for your domain +CF_ZONE_ID=your_zone_id_here + +# Test domain (must be a wildcard domain managed by Cloudflare) +# Example: *.test.example.com +TEST_DOMAIN=*.test.example.com diff --git a/gateway/test-run/.gitignore b/gateway/test-run/.gitignore new file mode 100644 index 00000000..17972360 --- /dev/null +++ b/gateway/test-run/.gitignore @@ -0,0 +1,2 @@ +/run/ +.env diff --git a/gateway/test-run/cluster.sh b/gateway/test-run/cluster.sh new file mode 100755 index 00000000..3dab261e --- /dev/null +++ b/gateway/test-run/cluster.sh @@ -0,0 +1,443 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# Gateway cluster management script for manual testing + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +GATEWAY_BIN="${SCRIPT_DIR}/../../target/release/dstack-gateway" +RUN_DIR="run" +CERTS_DIR="$RUN_DIR/certs" +CA_CERT="$CERTS_DIR/gateway-ca.cert" +LOG_DIR="$RUN_DIR/logs" +TMUX_SESSION="gateway-cluster" + +# Colors +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + +show_help() { + echo "Gateway Cluster Management Script" + echo "" + echo "Usage: $0 " + echo "" + echo "Commands:" + echo " start Start a 3-node gateway cluster in tmux" + echo " stop Stop the cluster (keep tmux session)" + echo " reg Register a random instance" + echo " status Show cluster status" + echo " clean Destroy cluster and clean all data" + echo " attach Attach to tmux session" + echo " help Show this help" + echo "" +} + +# Generate certificates +generate_certs() { + mkdir -p "$CERTS_DIR" + mkdir -p "$RUN_DIR/certbot/live" + + # Generate CA certificate + if [[ ! -f "$CERTS_DIR/gateway-ca.key" ]]; then + log_info "Creating CA certificate..." + openssl genrsa -out "$CERTS_DIR/gateway-ca.key" 2048 2>/dev/null + openssl req -x509 -new -nodes \ + -key "$CERTS_DIR/gateway-ca.key" \ + -sha256 -days 365 \ + -out "$CERTS_DIR/gateway-ca.cert" \ + -subj "/CN=Test CA/O=Gateway Test" \ + 2>/dev/null + fi + + # Generate RPC certificate signed by CA + if [[ ! -f "$CERTS_DIR/gateway-rpc.key" ]]; then + log_info "Creating RPC certificate..." 
+ openssl genrsa -out "$CERTS_DIR/gateway-rpc.key" 2048 2>/dev/null + openssl req -new \ + -key "$CERTS_DIR/gateway-rpc.key" \ + -out "$CERTS_DIR/gateway-rpc.csr" \ + -subj "/CN=localhost" \ + 2>/dev/null + cat > "$CERTS_DIR/ext.cnf" << EXTEOF +authorityKeyIdentifier=keyid,issuer +basicConstraints=CA:FALSE +keyUsage = digitalSignature, nonRepudiation, keyEncipherment, dataEncipherment +subjectAltName = @alt_names + +[alt_names] +DNS.1 = localhost +IP.1 = 127.0.0.1 +EXTEOF + openssl x509 -req \ + -in "$CERTS_DIR/gateway-rpc.csr" \ + -CA "$CERTS_DIR/gateway-ca.cert" \ + -CAkey "$CERTS_DIR/gateway-ca.key" \ + -CAcreateserial \ + -out "$CERTS_DIR/gateway-rpc.cert" \ + -days 365 \ + -sha256 \ + -extfile "$CERTS_DIR/ext.cnf" \ + 2>/dev/null + rm -f "$CERTS_DIR/gateway-rpc.csr" "$CERTS_DIR/ext.cnf" + fi + + # Generate proxy certificates + local proxy_cert_dir="$RUN_DIR/certbot/live" + if [[ ! -f "$proxy_cert_dir/cert.pem" ]]; then + log_info "Creating proxy certificates..." + openssl req -x509 -newkey rsa:2048 -nodes \ + -keyout "$proxy_cert_dir/key.pem" \ + -out "$proxy_cert_dir/cert.pem" \ + -days 365 \ + -subj "/CN=localhost" \ + 2>/dev/null + fi + + # Generate unique WireGuard key pair for each node + for i in 1 2 3; do + if [[ ! -f "$CERTS_DIR/wg-node${i}.key" ]]; then + log_info "Generating WireGuard keys for node ${i}..." + wg genkey > "$CERTS_DIR/wg-node${i}.key" + wg pubkey < "$CERTS_DIR/wg-node${i}.key" > "$CERTS_DIR/wg-node${i}.pub" + fi + done +} + +# Generate node config +generate_config() { + local node_id=$1 + local rpc_port=$((13000 + node_id * 10 + 2)) + local wg_port=$((13000 + node_id * 10 + 3)) + local proxy_port=$((13000 + node_id * 10 + 4)) + local debug_port=$((13000 + node_id * 10 + 5)) + local admin_port=$((13000 + node_id * 10 + 6)) + local wg_ip="10.0.3${node_id}.1/24" + local other_nodes="" + local peer_urls="" + + # Read WireGuard keys for this node + local wg_private_key=$(cat "$CERTS_DIR/wg-node${node_id}.key") + local wg_public_key=$(cat "$CERTS_DIR/wg-node${node_id}.pub") + + for i in 1 2 3; do + if [[ $i -ne $node_id ]]; then + local peer_rpc_port=$((13000 + i * 10 + 2)) + if [[ -n "$other_nodes" ]]; then + other_nodes="$other_nodes, $i" + peer_urls="$peer_urls, \"$i:https://localhost:$peer_rpc_port\"" + else + other_nodes="$i" + peer_urls="\"$i:https://localhost:$peer_rpc_port\"" + fi + fi + done + + local abs_run_dir="$SCRIPT_DIR/$RUN_DIR" + cat > "$RUN_DIR/node${node_id}.toml" << EOF +log_level = "info" +address = "0.0.0.0" +port = ${rpc_port} + +[tls] +key = "${abs_run_dir}/certs/gateway-rpc.key" +certs = "${abs_run_dir}/certs/gateway-rpc.cert" + +[tls.mutual] +ca_certs = "${abs_run_dir}/certs/gateway-ca.cert" +mandatory = false + +[core] +kms_url = "" +rpc_domain = "gateway.test.local" + +[core.debug] +insecure_enable_debug_rpc = true +insecure_skip_attestation = true +port = ${debug_port} +address = "127.0.0.1" + +[core.admin] +enabled = true +port = ${admin_port} +address = "127.0.0.1" + +[core.sync] +enabled = true +interval = "5s" +timeout = "10s" +my_url = "https://localhost:${rpc_port}" +bootnode = "" +node_id = ${node_id} +data_dir = "${RUN_DIR}/wavekv_node${node_id}" + +[core.certbot] +enabled = false + +[core.wg] +private_key = "${wg_private_key}" +public_key = "${wg_public_key}" +listen_port = ${wg_port} +ip = "${wg_ip}" +reserved_net = ["10.0.3${node_id}.1/31"] +client_ip_range = "10.0.3${node_id}.1/24" +config_path = "${RUN_DIR}/wg_node${node_id}.conf" +interface = "gw-test${node_id}" +endpoint = "127.0.0.1:${wg_port}" + +[core.proxy] +cert_chain 
= "${RUN_DIR}/certbot/live/cert.pem" +cert_key = "${RUN_DIR}/certbot/live/key.pem" +base_domain = "test.local" +listen_addr = "0.0.0.0" +listen_port = ${proxy_port} +tappd_port = 8090 +external_port = ${proxy_port} +inbound_pp_enabled = false + +[core.recycle] +enabled = true +interval = "30s" +timeout = "120s" +node_timeout = "300s" +EOF +} + +# Build gateway binary +build_gateway() { + if [[ ! -f "$GATEWAY_BIN" ]]; then + log_info "Building gateway..." + (cd "$SCRIPT_DIR/.." && cargo build --release) + fi +} + +# Start cluster +cmd_start() { + build_gateway + generate_certs + + # Check if tmux session exists + if tmux has-session -t "$TMUX_SESSION" 2>/dev/null; then + log_warn "Cluster already running. Use 'clean' to restart." + cmd_status + return 0 + fi + + log_info "Generating configs..." + mkdir -p "$RUN_DIR" "$LOG_DIR" + for i in 1 2 3; do + generate_config $i + mkdir -p "$RUN_DIR/wavekv_node${i}" + done + + log_info "Starting cluster in tmux session '$TMUX_SESSION'..." + + # Create wrapper scripts that keep running even if gateway exits + for i in 1 2 3; do + cat > "$RUN_DIR/run_node${i}.sh" << RUNEOF +#!/bin/bash +cd "$SCRIPT_DIR" +while true; do + echo "Starting node ${i}..." + sudo RUST_LOG=info $GATEWAY_BIN -c $RUN_DIR/node${i}.toml 2>&1 | tee -a $LOG_DIR/node${i}.log + echo "Node ${i} exited. Press Ctrl+C to stop, or wait 3s to restart..." + sleep 3 +done +RUNEOF + chmod +x "$RUN_DIR/run_node${i}.sh" + done + + # Create tmux session + tmux new-session -d -s "$TMUX_SESSION" -n "node1" + tmux send-keys -t "$TMUX_SESSION:node1" "$RUN_DIR/run_node1.sh" Enter + + sleep 1 + + # Add windows for other nodes + tmux new-window -t "$TMUX_SESSION" -n "node2" + tmux send-keys -t "$TMUX_SESSION:node2" "$RUN_DIR/run_node2.sh" Enter + + tmux new-window -t "$TMUX_SESSION" -n "node3" + tmux send-keys -t "$TMUX_SESSION:node3" "$RUN_DIR/run_node3.sh" Enter + + # Add a shell window + tmux new-window -t "$TMUX_SESSION" -n "shell" + + sleep 3 + + log_info "Cluster started!" + echo "" + cmd_status + echo "" + log_info "Use '$0 attach' to view logs" +} + +# Stop cluster +cmd_stop() { + log_info "Stopping cluster..." + sudo pkill -9 -f "dstack-gateway.*node[123].toml" 2>/dev/null || true + sudo ip link delete gw-test1 2>/dev/null || true + sudo ip link delete gw-test2 2>/dev/null || true + sudo ip link delete gw-test3 2>/dev/null || true + log_info "Cluster stopped" +} + +# Clean everything +cmd_clean() { + cmd_stop + + # Kill tmux session + tmux kill-session -t "$TMUX_SESSION" 2>/dev/null || true + + log_info "Cleaning data..." 
+ sudo rm -rf "$RUN_DIR/wavekv_node"* + sudo rm -f "$RUN_DIR/gateway-state-node"*.json + rm -f "$RUN_DIR/wg_node"*.conf + rm -f "$RUN_DIR/node"*.toml + rm -f "$RUN_DIR/run_node"*.sh + rm -rf "$LOG_DIR" + + log_info "Cleaned" +} + +# Show status +cmd_status() { + echo -e "${BLUE}=== Gateway Cluster Status ===${NC}" + echo "" + + for i in 1 2 3; do + local rpc_port=$((13000 + i * 10 + 2)) + local proxy_port=$((13000 + i * 10 + 4)) + local debug_port=$((13000 + i * 10 + 5)) + local admin_port=$((13000 + i * 10 + 6)) + + if pgrep -f "dstack-gateway.*node${i}.toml" > /dev/null 2>&1; then + echo -e "Node $i: ${GREEN}RUNNING${NC}" + else + echo -e "Node $i: ${RED}STOPPED${NC}" + fi + echo " RPC: https://localhost:${rpc_port}" + echo " Proxy: https://localhost:${proxy_port}" + echo " Debug: http://localhost:${debug_port}" + echo " Admin: http://localhost:${admin_port}" + echo "" + done + + # Show instance count from first running node + for i in 1 2 3; do + local debug_port=$((13000 + i * 10 + 5)) + if pgrep -f "dstack-gateway.*node${i}.toml" > /dev/null 2>&1; then + local response=$(curl -s -X POST "http://localhost:${debug_port}/prpc/GetSyncData" \ + -H "Content-Type: application/json" -d '{}' 2>/dev/null) + if [[ -n "$response" ]]; then + local n_instances=$(echo "$response" | python3 -c "import sys,json; print(len(json.load(sys.stdin).get('instances', [])))" 2>/dev/null || echo "?") + local n_nodes=$(echo "$response" | python3 -c "import sys,json; print(len(json.load(sys.stdin).get('nodes', [])))" 2>/dev/null || echo "?") + echo -e "${BLUE}Cluster State:${NC}" + echo " Nodes: $n_nodes" + echo " Instances: $n_instances" + fi + break + fi + done +} + +# Register a random instance +cmd_reg() { + # Find a running node + local debug_port="" + for i in 1 2 3; do + local port=$((13000 + i * 10 + 5)) + if pgrep -f "dstack-gateway.*node${i}.toml" > /dev/null 2>&1; then + debug_port=$port + break + fi + done + + if [[ -z "$debug_port" ]]; then + log_error "No running nodes found. Start cluster first." + exit 1 + fi + + # Generate random WireGuard key pair + local private_key=$(wg genkey) + local public_key=$(echo "$private_key" | wg pubkey) + + # Generate random IDs + local app_id="app-$(openssl rand -hex 4)" + local instance_id="inst-$(openssl rand -hex 4)" + + log_info "Registering instance..." + log_info " App ID: $app_id" + log_info " Instance ID: $instance_id" + log_info " Public Key: $public_key" + + local response=$(curl -s \ + -X POST "http://localhost:${debug_port}/prpc/RegisterCvm" \ + -H "Content-Type: application/json" \ + -d "{\"client_public_key\": \"$public_key\", \"app_id\": \"$app_id\", \"instance_id\": \"$instance_id\"}" 2>/dev/null) + + if echo "$response" | python3 -c "import sys,json; d=json.load(sys.stdin); assert 'wg' in d" 2>/dev/null; then + local client_ip=$(echo "$response" | python3 -c "import sys,json; print(json.load(sys.stdin)['wg']['client_ip'])" 2>/dev/null) + log_info "Registered successfully!" 
+ echo -e " Client IP: ${GREEN}$client_ip${NC}" + echo "" + echo "Instance details:" + echo "$response" | python3 -m json.tool 2>/dev/null || echo "$response" + else + log_error "Registration failed:" + echo "$response" | python3 -m json.tool 2>/dev/null || echo "$response" + exit 1 + fi +} + +# Attach to tmux +cmd_attach() { + if tmux has-session -t "$TMUX_SESSION" 2>/dev/null; then + tmux attach -t "$TMUX_SESSION" + else + log_error "No cluster running" + exit 1 + fi +} + +# Main +case "${1:-help}" in + start) + cmd_start + ;; + stop) + cmd_stop + ;; + clean) + cmd_clean + ;; + status) + cmd_status + ;; + reg) + cmd_reg + ;; + attach) + cmd_attach + ;; + help|--help|-h) + show_help + ;; + *) + log_error "Unknown command: $1" + show_help + exit 1 + ;; +esac diff --git a/gateway/test-run/test_certbot.sh b/gateway/test-run/test_certbot.sh new file mode 100755 index 00000000..626ff702 --- /dev/null +++ b/gateway/test-run/test_certbot.sh @@ -0,0 +1,564 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# Distributed Certbot E2E test script +# Tests certificate issuance and synchronization across gateway nodes + +set -m + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +# Show help +show_help() { + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Distributed Certbot E2E Test" + echo "" + echo "Options:" + echo " --fresh Clean everything and request new certificate from ACME" + echo " --sync-only Keep existing cert, only test sync between nodes" + echo " --clean Clean all test data and exit" + echo " -h, --help Show this help message" + echo "" + echo "Default (no options): Keep ACME account, request new certificate" + echo "" + echo "Examples:" + echo " $0 # Keep account, new cert" + echo " $0 --fresh # Fresh start, new account and cert" + echo " $0 --sync-only # Test sync with existing cert" + echo " $0 --clean # Clean up all test data" +} + +# Parse arguments +MODE="default" +while [[ $# -gt 0 ]]; do + case $1 in + --fresh) + MODE="fresh" + shift + ;; + --sync-only) + MODE="sync-only" + shift + ;; + --clean) + MODE="clean" + shift + ;; + -h|--help) + show_help + exit 0 + ;; + *) + echo "Unknown option: $1" + show_help + exit 1 + ;; + esac +done + +# Load environment variables from .env +if [[ -f ".env" ]]; then + source ".env" +else + echo "ERROR: .env file not found!" + echo "" + echo "Please create a .env file with the following variables:" + echo " CF_API_TOKEN=" + echo " CF_ZONE_ID=" + echo " TEST_DOMAIN=" + echo "" + echo "The domain must be managed by Cloudflare and the API token must have" + echo "permissions to manage DNS records and CAA records." 
+ exit 1 +fi + +# Validate required environment variables +if [[ -z "$CF_API_TOKEN" ]]; then + echo "ERROR: CF_API_TOKEN is not set in .env" + exit 1 +fi + +if [[ -z "$CF_ZONE_ID" ]]; then + echo "ERROR: CF_ZONE_ID is not set in .env" + exit 1 +fi + +if [[ -z "$TEST_DOMAIN" ]]; then + echo "ERROR: TEST_DOMAIN is not set in .env" + exit 1 +fi + +GATEWAY_BIN="$SCRIPT_DIR/../../target/release/dstack-gateway" +RUN_DIR="run" +CERTS_DIR="$RUN_DIR/certs" +CA_CERT="$CERTS_DIR/gateway-ca.cert" +LOG_DIR="$RUN_DIR/logs" +CURRENT_TEST="test_certbot" + +# Let's Encrypt staging URL (for testing without rate limits) +ACME_STAGING_URL="https://acme-staging-v02.api.letsencrypt.org/directory" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + +cleanup() { + log_info "Cleaning up..." + sudo pkill -9 -f "dstack-gateway.*certbot_node[12].toml" >/dev/null 2>&1 || true + sudo ip link delete certbot-test1 2>/dev/null || true + sudo ip link delete certbot-test2 2>/dev/null || true + sleep 1 + stty sane 2>/dev/null || true +} + +trap cleanup EXIT + +# Generate node config with certbot enabled +generate_certbot_config() { + local node_id=$1 + local rpc_port=$((14000 + node_id * 10 + 2)) + local wg_port=$((14000 + node_id * 10 + 3)) + local proxy_port=$((14000 + node_id * 10 + 4)) + local debug_port=$((14000 + node_id * 10 + 5)) + local wg_ip="10.0.4${node_id}.1/24" + + # Build peer config + local other_node=$((3 - node_id)) # If node_id=1, other=2; if node_id=2, other=1 + local other_rpc_port=$((14000 + other_node * 10 + 2)) + + local abs_run_dir="$SCRIPT_DIR/$RUN_DIR" + local certbot_dir="$abs_run_dir/certbot_node${node_id}" + + mkdir -p "$certbot_dir" + + cat > "$RUN_DIR/certbot_node${node_id}.toml" << EOF +log_level = "info" +address = "0.0.0.0" +port = ${rpc_port} + +[tls] +key = "${abs_run_dir}/certs/gateway-rpc.key" +certs = "${abs_run_dir}/certs/gateway-rpc.cert" + +[tls.mutual] +ca_certs = "${abs_run_dir}/certs/gateway-ca.cert" +mandatory = false + +[core] +kms_url = "" +rpc_domain = "gateway.tdxlab.dstack.org" + +[core.debug] +insecure_enable_debug_rpc = true +insecure_skip_attestation = true +port = ${debug_port} +address = "127.0.0.1" + +[core.sync] +enabled = true +interval = "5s" +timeout = "10s" +my_url = "https://localhost:${rpc_port}" +bootnode = "https://localhost:${other_rpc_port}" +node_id = ${node_id} +data_dir = "${RUN_DIR}/wavekv_certbot_node${node_id}" + +[core.certbot] +enabled = true +workdir = "${certbot_dir}" +acme_url = "${ACME_STAGING_URL}" +cf_api_token = "${CF_API_TOKEN}" +cf_zone_id = "${CF_ZONE_ID}" +auto_set_caa = true +domain = "${TEST_DOMAIN}" +renew_interval = "1h" +renew_before_expiration = "720h" +renew_timeout = "5m" + +[core.wg] +private_key = "SEcoI37oGWynhukxXo5Mi8/8zZBU6abg6T1TOJRMj1Y=" +public_key = "xc+7qkdeNFfl4g4xirGGGXHMc0cABuE5IHaLeCASVWM=" +listen_port = ${wg_port} +ip = "${wg_ip}" +reserved_net = ["10.0.4${node_id}.1/31"] +client_ip_range = "10.0.4${node_id}.1/24" +config_path = "${RUN_DIR}/wg_certbot_node${node_id}.conf" +interface = "certbot-test${node_id}" +endpoint = "127.0.0.1:${wg_port}" + +[core.proxy] +cert_chain = "${certbot_dir}/live/cert.pem" +cert_key = "${certbot_dir}/live/key.pem" +base_domain = "tdxlab.dstack.org" +listen_addr = "0.0.0.0" +listen_port = ${proxy_port} +tappd_port = 8090 +external_port = ${proxy_port} +inbound_pp_enabled = false +EOF + 
log_info "Generated certbot_node${node_id}.toml (rpc=${rpc_port}, debug=${debug_port}, proxy=${proxy_port})" +} + +start_certbot_node() { + local node_id=$1 + local config="$RUN_DIR/certbot_node${node_id}.toml" + local log_file="${LOG_DIR}/${CURRENT_TEST}_node${node_id}.log" + + log_info "Starting certbot node ${node_id}..." + mkdir -p "$RUN_DIR/wavekv_certbot_node${node_id}" + mkdir -p "$LOG_DIR" + ( sudo RUST_LOG=info "$GATEWAY_BIN" -c "$config" > "$log_file" 2>&1 & ) + + # Wait for process to either stabilize or fail + local max_wait=30 + local waited=0 + while [[ $waited -lt $max_wait ]]; do + sleep 2 + waited=$((waited + 2)) + + if ! pgrep -f "dstack-gateway.*${config}" > /dev/null; then + # Process exited, check why + log_error "Certbot node ${node_id} exited after ${waited}s" + echo "--- Log output ---" + cat "$log_file" + echo "--- End log ---" + + # Check for rate limit error + if grep -q "rateLimited" "$log_file"; then + log_error "Let's Encrypt rate limit hit. Wait a few minutes and retry." + fi + return 1 + fi + + # Check if cert files exist (indicates successful init) + local certbot_dir="$RUN_DIR/certbot_node${node_id}" + if [[ -f "$certbot_dir/live/cert.pem" ]] && [[ -f "$certbot_dir/live/key.pem" ]]; then + log_info "Certbot node ${node_id} started and certificate obtained" + return 0 + fi + + log_info "Waiting for node ${node_id} to initialize... (${waited}s)" + done + + # Process still running but no cert yet - might still be requesting + if pgrep -f "dstack-gateway.*${config}" > /dev/null; then + log_info "Certbot node ${node_id} still running, certificate request in progress" + return 0 + fi + + log_error "Certbot node ${node_id} failed to start within ${max_wait}s" + cat "$log_file" + return 1 +} + +stop_certbot_node() { + local node_id=$1 + log_info "Stopping certbot node ${node_id}..." 
+ sudo pkill -9 -f "dstack-gateway.*certbot_node${node_id}.toml" >/dev/null 2>&1 || true + sleep 1 +} + +# Get debug sync data from a node +debug_get_sync_data() { + local debug_port=$1 + curl -s "http://localhost:${debug_port}/prpc/GetSyncData" \ + -H "Content-Type: application/json" \ + -d '{}' 2>/dev/null +} + +# Check if KvStore has cert data for the domain +check_kvstore_cert() { + local debug_port=$1 + local response=$(debug_get_sync_data "$debug_port") + + # The cert data would be in the persistent store + # For now, check if we can get any data + if [[ -z "$response" ]]; then + return 1 + fi + + # Check for cert-related keys in the response + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + # Check if there are any keys that start with 'cert/' + # This is a simplified check + print('ok') + sys.exit(0) +except Exception as e: + print(f'error: {e}', file=sys.stderr) + sys.exit(1) +" 2>/dev/null +} + +# Check if proxy is using a valid certificate by connecting via TLS +check_proxy_cert() { + local proxy_port=$1 + + # Use gateway.{base_domain} as the SNI for health endpoint + local gateway_host="gateway.tdxlab.dstack.org" + + # Use openssl to check the certificate + local cert_info=$(echo | timeout 5 openssl s_client -connect "localhost:${proxy_port}" -servername "$gateway_host" 2>/dev/null) + + if [[ -z "$cert_info" ]]; then + log_error "Failed to connect to proxy on port ${proxy_port}" + return 1 + fi + + # Check if the certificate is valid (not self-signed test cert) + # For staging certs, the issuer should contain "Staging" or "(STAGING)" + local issuer=$(echo "$cert_info" | openssl x509 -noout -issuer 2>/dev/null) + + if echo "$issuer" | grep -qi "staging\|fake\|test"; then + log_info "Proxy on port ${proxy_port} is using Let's Encrypt staging certificate" + log_info "Issuer: $issuer" + return 0 + elif echo "$issuer" | grep -qi "let's encrypt\|letsencrypt"; then + log_info "Proxy on port ${proxy_port} is using Let's Encrypt certificate" + log_info "Issuer: $issuer" + return 0 + else + log_warn "Proxy on port ${proxy_port} certificate issuer: $issuer" + # Still return success if we got a certificate + return 0 + fi +} + +# Get certificate expiry from proxy health endpoint +get_proxy_cert_expiry() { + local proxy_port=$1 + # Use gateway.{base_domain} as the SNI for health endpoint + local gateway_host="gateway.tdxlab.dstack.org" + echo | timeout 5 openssl s_client -connect "localhost:${proxy_port}" -servername "$gateway_host" 2>/dev/null | \ + openssl x509 -noout -enddate 2>/dev/null | \ + cut -d= -f2 +} + +# Get certificate serial from proxy health endpoint +get_proxy_cert_serial() { + local proxy_port=$1 + local gateway_host="gateway.tdxlab.dstack.org" + echo | timeout 5 openssl s_client -connect "localhost:${proxy_port}" -servername "$gateway_host" 2>/dev/null | \ + openssl x509 -noout -serial 2>/dev/null | \ + cut -d= -f2 +} + +# Get certificate issuer from proxy +get_proxy_cert_issuer() { + local proxy_port=$1 + local gateway_host="gateway.tdxlab.dstack.org" + echo | timeout 5 openssl s_client -connect "localhost:${proxy_port}" -servername "$gateway_host" 2>/dev/null | \ + openssl x509 -noout -issuer 2>/dev/null +} + +# Wait for certificate to be issued (with timeout) +wait_for_cert() { + local proxy_port=$1 + local timeout_secs=${2:-300} # Default 5 minutes + local start_time=$(date +%s) + + log_info "Waiting for certificate to be issued (timeout: ${timeout_secs}s)..." 
+ + while true; do + local current_time=$(date +%s) + local elapsed=$((current_time - start_time)) + + if [[ $elapsed -ge $timeout_secs ]]; then + log_error "Timeout waiting for certificate" + return 1 + fi + + # Try to get certificate info + local expiry=$(get_proxy_cert_expiry "$proxy_port") + if [[ -n "$expiry" ]]; then + log_info "Certificate detected! Expiry: $expiry" + return 0 + fi + + log_info "Waiting... (${elapsed}s elapsed)" + sleep 10 + done +} + +# ============================================================ +# Main Test +# ============================================================ + +do_clean() { + log_info "Cleaning all certbot test data..." + cleanup + sudo rm -rf "$RUN_DIR/certbot_node1" "$RUN_DIR/certbot_node2" + sudo rm -rf "$RUN_DIR/wavekv_certbot_node1" "$RUN_DIR/wavekv_certbot_node2" + sudo rm -f "$RUN_DIR/gateway-state-certbot-node1.json" "$RUN_DIR/gateway-state-certbot-node2.json" + log_info "Done." +} + +main() { + log_info "==========================================" + log_info "Distributed Certbot E2E Test" + log_info "==========================================" + log_info "Test domain: $TEST_DOMAIN" + log_info "ACME URL: $ACME_STAGING_URL" + log_info "Mode: $MODE" + log_info "" + + # Handle --clean mode + if [[ "$MODE" == "clean" ]]; then + do_clean + return 0 + fi + + # Handle --sync-only mode: check if cert exists + if [[ "$MODE" == "sync-only" ]]; then + if [[ ! -f "$RUN_DIR/certbot_node1/live/cert.pem" ]]; then + log_error "No existing certificate found. Run without --sync-only first." + return 1 + fi + log_info "Using existing certificate for sync test" + fi + + # Clean up processes and state + cleanup + + # Decide what to clean based on mode + case "$MODE" in + fresh) + # Clean everything including ACME account + log_info "Fresh mode: cleaning all data including ACME account" + sudo rm -rf "$RUN_DIR/certbot_node1" "$RUN_DIR/certbot_node2" + ;; + sync-only) + # Keep node1 cert, only clean node2 and wavekv + log_info "Sync-only mode: keeping node1 certificate" + sudo rm -rf "$RUN_DIR/certbot_node2" + ;; + *) + # Default: keep ACME account (credentials.json), clean certs + log_info "Default mode: keeping ACME account, requesting new certificate" + # Backup credentials if exists + if [[ -f "$RUN_DIR/certbot_node1/credentials.json" ]]; then + sudo cp "$RUN_DIR/certbot_node1/credentials.json" /tmp/certbot_credentials_backup.json + fi + sudo rm -rf "$RUN_DIR/certbot_node1" "$RUN_DIR/certbot_node2" + # Restore credentials + if [[ -f /tmp/certbot_credentials_backup.json ]]; then + mkdir -p "$RUN_DIR/certbot_node1" + sudo mv /tmp/certbot_credentials_backup.json "$RUN_DIR/certbot_node1/credentials.json" + fi + ;; + esac + + # Always clean wavekv and gateway state + sudo rm -rf "$RUN_DIR/wavekv_certbot_node1" "$RUN_DIR/wavekv_certbot_node2" + sudo rm -f "$RUN_DIR/gateway-state-certbot-node1.json" "$RUN_DIR/gateway-state-certbot-node2.json" + + # Generate configs + log_info "Generating node configurations..." + generate_certbot_config 1 + generate_certbot_config 2 + + # Start Node 1 first - it will request the certificate + log_info "" + log_info "==========================================" + log_info "Phase 1: Start Node 1 and request certificate" + log_info "==========================================" + + if ! start_certbot_node 1; then + log_error "Failed to start node 1" + return 1 + fi + + # Wait for certificate to be issued + local proxy_port_1=14014 + if ! 
wait_for_cert "$proxy_port_1" 300; then + log_error "Node 1 failed to obtain certificate" + cat "$LOG_DIR/${CURRENT_TEST}_node1.log" | tail -50 + return 1 + fi + + # Get Node 1's certificate info + local node1_serial=$(get_proxy_cert_serial "$proxy_port_1") + local node1_expiry=$(get_proxy_cert_expiry "$proxy_port_1") + log_info "Node 1 certificate serial: $node1_serial" + log_info "Node 1 certificate expiry: $node1_expiry" + + # Show certificate source logs for Node 1 + log_info "" + log_info "Node 1 certificate source:" + grep -E "cert\[|acme\[" "$LOG_DIR/${CURRENT_TEST}_node1.log" 2>/dev/null | sed 's/^/ /' + + # Start Node 2 - it should sync the certificate from Node 1 + log_info "" + log_info "==========================================" + log_info "Phase 2: Start Node 2 and verify sync" + log_info "==========================================" + + if ! start_certbot_node 2; then + log_error "Failed to start node 2" + return 1 + fi + + # Wait for Node 2 to sync and load the certificate + local proxy_port_2=14024 + sleep 10 # Give time for sync + + if ! wait_for_cert "$proxy_port_2" 60; then + log_error "Node 2 failed to obtain certificate via sync" + cat "$LOG_DIR/${CURRENT_TEST}_node2.log" | tail -50 + return 1 + fi + + # Get Node 2's certificate info + local node2_serial=$(get_proxy_cert_serial "$proxy_port_2") + local node2_expiry=$(get_proxy_cert_expiry "$proxy_port_2") + log_info "Node 2 certificate serial: $node2_serial" + log_info "Node 2 certificate expiry: $node2_expiry" + + # Show certificate source logs for Node 2 + log_info "" + log_info "Node 2 certificate source:" + grep -E "cert\[|acme\[" "$LOG_DIR/${CURRENT_TEST}_node2.log" 2>/dev/null | sed 's/^/ /' + + # Verify both nodes have the same certificate + log_info "" + log_info "==========================================" + log_info "Verification" + log_info "==========================================" + + if [[ "$node1_serial" == "$node2_serial" ]]; then + log_info "SUCCESS: Both nodes have the same certificate (serial: $node1_serial)" + else + log_error "FAILURE: Certificate mismatch!" + log_error " Node 1 serial: $node1_serial" + log_error " Node 2 serial: $node2_serial" + return 1 + fi + + # Check that proxy is actually using the certificate + check_proxy_cert "$proxy_port_1" + check_proxy_cert "$proxy_port_2" + + log_info "" + log_info "==========================================" + log_info "All tests passed!" + log_info "==========================================" + + return 0 +} + +# Run main +main +exit $? 
diff --git a/gateway/test-run/test_suite.sh b/gateway/test-run/test_suite.sh new file mode 100755 index 00000000..d4b532ab --- /dev/null +++ b/gateway/test-run/test_suite.sh @@ -0,0 +1,2131 @@ +#!/bin/bash + +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# WaveKV integration test script + +# Don't use set -e as it causes issues with cleanup and test flow +# set -e + +# Disable job control messages (prevents "Killed" messages from messing up output) +set +m + +# Fix terminal output - ensure proper line endings +stty -echoctl 2>/dev/null || true + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +GATEWAY_BIN="/home/kvin/sdc/home/wavekv/dstack/target/release/dstack-gateway" +RUN_DIR="run" +CERTS_DIR="$RUN_DIR/certs" +CA_CERT="$CERTS_DIR/gateway-ca.cert" +LOG_DIR="$RUN_DIR/logs" +CURRENT_TEST="" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +log_info() { echo -e "${GREEN}[INFO]${NC} $1"; } +log_warn() { echo -e "${YELLOW}[WARN]${NC} $1"; } +log_error() { echo -e "${RED}[ERROR]${NC} $1"; } + +cleanup() { + log_info "Cleaning up..." + # Kill only dstack-gateway processes started by this test (matching our specific config path) + # Use absolute path to avoid killing system dstack-gateway processes + pkill -9 -f "dstack-gateway -c ${SCRIPT_DIR}/${RUN_DIR}/node" >/dev/null 2>&1 || true + pkill -9 -f "dstack-gateway.*${SCRIPT_DIR}/${RUN_DIR}/node" >/dev/null 2>&1 || true + sleep 1 + # Only delete WireGuard interfaces with sudo (these are our test interfaces) + sudo ip link delete wavekv-test1 2>/dev/null || true + sudo ip link delete wavekv-test2 2>/dev/null || true + sudo ip link delete wavekv-test3 2>/dev/null || true + # Clean up all wavekv data directories to prevent peer list contamination + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" "$RUN_DIR/wavekv_node3" 2>/dev/null || true + rm -f "$RUN_DIR/gateway-state-node"*.json 2>/dev/null || true + sleep 1 + stty sane 2>/dev/null || true +} + +trap cleanup EXIT + +# Generate node configs +# Usage: generate_config [bootnode_url] +generate_config() { + local node_id=$1 + local bootnode_url=${2:-""} + local rpc_port=$((13000 + node_id * 10 + 2)) + local wg_port=$((13000 + node_id * 10 + 3)) + local proxy_port=$((13000 + node_id * 10 + 4)) + local debug_port=$((13000 + node_id * 10 + 5)) + local admin_port=$((13000 + node_id * 10 + 6)) + local wg_ip="10.0.3${node_id}.1/24" + + # Use absolute paths to avoid Rocket's relative path resolution issues + local abs_run_dir="$SCRIPT_DIR/$RUN_DIR" + cat >"$RUN_DIR/node${node_id}.toml" </dev/null | grep -q ":${port} "; then + return 0 + fi + sleep 1 + ((waited++)) + done + return 1 +} + +ensure_wg_interface() { + local node_id=$1 + local iface="wavekv-test${node_id}" + + # Check if interface exists, create if not + if ! ip link show "$iface" >/dev/null 2>&1; then + log_info "Creating WireGuard interface ${iface}..." + sudo ip link add "$iface" type wireguard || { + log_error "Failed to create WireGuard interface ${iface}" + return 1 + } + fi + return 0 +} + +start_node() { + local node_id=$1 + local config="${SCRIPT_DIR}/${RUN_DIR}/node${node_id}.toml" + local log_file="${LOG_DIR}/${CURRENT_TEST}_node${node_id}.log" + + # Calculate ports for this node + local admin_port=$((13000 + node_id * 10 + 6)) + local rpc_port=$((13000 + node_id * 10 + 2)) + + log_info "Starting node ${node_id}..." 
+ + # Kill any existing test process for this node first (use absolute path to be precise) + pkill -9 -f "dstack-gateway -c ${config}" >/dev/null 2>&1 || true + pkill -9 -f "dstack-gateway.*${config}" >/dev/null 2>&1 || true + sleep 1 + + # Wait for ports to be free + if ! wait_for_port_free $admin_port; then + log_error "Port $admin_port still in use after waiting" + netstat -tlnp 2>/dev/null | grep ":${admin_port} " || true + return 1 + fi + if ! wait_for_port_free $rpc_port; then + log_error "Port $rpc_port still in use after waiting" + netstat -tlnp 2>/dev/null | grep ":${rpc_port} " || true + return 1 + fi + + # Ensure WireGuard interface exists before starting + if ! ensure_wg_interface "$node_id"; then + return 1 + fi + + mkdir -p "$RUN_DIR/wavekv_node${node_id}" + mkdir -p "$LOG_DIR" + (RUST_LOG=info "$GATEWAY_BIN" -c "$config" >"$log_file" 2>&1 &) + sleep 2 + + if pgrep -f "dstack-gateway.*${config}" >/dev/null; then + log_info "Node ${node_id} started successfully" + return 0 + else + log_error "Node ${node_id} failed to start" + cat "$log_file" + return 1 + fi +} + +stop_node() { + local node_id=$1 + local config="${SCRIPT_DIR}/${RUN_DIR}/node${node_id}.toml" + local admin_port=$((13000 + node_id * 10 + 6)) + + log_info "Stopping node ${node_id}..." + # Kill only the specific test process using absolute config path + pkill -9 -f "dstack-gateway -c ${config}" >/dev/null 2>&1 || true + pkill -9 -f "dstack-gateway.*${config}" >/dev/null 2>&1 || true + sleep 1 + + # Verify the port is free, otherwise force kill by PID + if ! wait_for_port_free $admin_port; then + log_warn "Node ${node_id} port still in use, forcing cleanup..." + # Find and kill the process holding the port + local pid=$(netstat -tlnp 2>/dev/null | grep ":${admin_port} " | awk '{print $7}' | cut -d'/' -f1) + if [[ -n "$pid" ]]; then + kill -9 "$pid" 2>/dev/null || true + sleep 1 + fi + fi + + # Reset terminal to fix any broken line endings + stty sane 2>/dev/null || true +} + +# Get WaveKV status via Admin.WaveKvStatus RPC +# Usage: get_status +get_status() { + local admin_port=$1 + curl -s -X POST "http://localhost:${admin_port}/prpc/Admin.WaveKvStatus" \ + -H "Content-Type: application/json" \ + -d '{}' 2>/dev/null +} + +get_n_keys() { + local admin_port=$1 + get_status "$admin_port" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d['persistent']['n_keys'])" 2>/dev/null || echo "0" +} + +# Register CVM via debug port (no attestation required) +# Usage: debug_register_cvm +# Returns: JSON response +debug_register_cvm() { + local debug_port=$1 + local public_key=$2 + local app_id=${3:-"testapp"} + local instance_id=${4:-"testinstance"} + curl -s \ + -X POST "http://localhost:${debug_port}/prpc/RegisterCvm" \ + -H "Content-Type: application/json" \ + -d "{\"client_public_key\": \"$public_key\", \"app_id\": \"$app_id\", \"instance_id\": \"$instance_id\"}" 2>/dev/null +} + +# Check if debug service is available +# Usage: check_debug_service +check_debug_service() { + local debug_port=$1 + local response=$(curl -s -X POST "http://localhost:${debug_port}/prpc/Debug.Info" \ + -H "Content-Type: application/json" -d '{}' 2>/dev/null) + if echo "$response" | python3 -c "import sys,json; d=json.load(sys.stdin); assert 'base_domain' in d" 2>/dev/null; then + return 0 + else + return 1 + fi +} + +# Verify register response is successful (has wg config, no error) +# Usage: verify_register_response +verify_register_response() { + local response="$1" + echo "$response" | python3 -c " +import sys, json +try: + d = 
json.load(sys.stdin) + if 'error' in d: + print(f'ERROR: {d[\"error\"]}', file=sys.stderr) + sys.exit(1) + assert 'wg' in d, 'missing wg config' + assert 'client_ip' in d['wg'], 'missing client_ip' + print(d['wg']['client_ip']) +except Exception as e: + print(f'ERROR: {e}', file=sys.stderr) + sys.exit(1) +" 2>/dev/null +} + +# Get sync data from debug port (peer_addrs, nodes, instances) +# Usage: debug_get_sync_data +# Returns: JSON response with my_node_id, peer_addrs, nodes, instances +debug_get_sync_data() { + local debug_port=$1 + curl -s -X POST "http://localhost:${debug_port}/prpc/Debug.GetSyncData" \ + -H "Content-Type: application/json" -d '{}' 2>/dev/null +} + +# Check if node has synced peer address from another node +# Usage: has_peer_addr +# Returns: 0 if peer address exists, 1 otherwise +has_peer_addr() { + local debug_port=$1 + local peer_node_id=$2 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + peer_addrs = d.get('peer_addrs', []) + for pa in peer_addrs: + if pa.get('node_id') == $peer_node_id: + sys.exit(0) + sys.exit(1) +except Exception as e: + sys.exit(1) +" +} + +# Check if node has synced node info from another node +# Usage: has_node_info +# Returns: 0 if node info exists, 1 otherwise +has_node_info() { + local debug_port=$1 + local peer_node_id=$2 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + nodes = d.get('nodes', []) + for n in nodes: + if n.get('node_id') == $peer_node_id: + sys.exit(0) + sys.exit(1) +except Exception as e: + sys.exit(1) +" +} + +# Get number of peer addresses from sync data +# Usage: get_n_peer_addrs +get_n_peer_addrs() { + local debug_port=$1 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + print(len(d.get('peer_addrs', []))) +except: + print(0) +" 2>/dev/null +} + +# Get number of node infos from sync data +# Usage: get_n_nodes +get_n_nodes() { + local debug_port=$1 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + print(len(d.get('nodes', []))) +except: + print(0) +" 2>/dev/null +} + +# Get number of instances from KvStore sync data +# Usage: get_n_instances +get_n_instances() { + local debug_port=$1 + local response=$(debug_get_sync_data "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + print(len(d.get('instances', []))) +except: + print(0) +" 2>/dev/null +} + +# Get Proxy State from debug port (in-memory state) +# Usage: debug_get_proxy_state +# Returns: JSON response with instances and allocated_addresses +debug_get_proxy_state() { + local debug_port=$1 + curl -s -X POST "http://localhost:${debug_port}/prpc/GetProxyState" \ + -H "Content-Type: application/json" -d '{}' 2>/dev/null +} + +# Get number of instances from ProxyState (in-memory) +# Usage: get_n_proxy_state_instances +get_n_proxy_state_instances() { + local debug_port=$1 + local response=$(debug_get_proxy_state "$debug_port") + echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + print(len(d.get('instances', []))) +except: + print(0) +" 2>/dev/null +} + +# Check KvStore and ProxyState instance consistency +# Usage: check_instance_consistency +# Returns: 0 if consistent, 1 otherwise +check_instance_consistency() { + local 
debug_port=$1 + local kvstore_instances=$(get_n_instances "$debug_port") + local proxystate_instances=$(get_n_proxy_state_instances "$debug_port") + + if [[ "$kvstore_instances" -eq "$proxystate_instances" ]]; then + return 0 + else + log_error "Instance count mismatch: KvStore=$kvstore_instances, ProxyState=$proxystate_instances" + return 1 + fi +} + +# ============================================================================= +# Test 1: Single node persistence +# ============================================================================= +test_persistence() { + log_info "========== Test 1: Persistence ==========" + cleanup + + generate_config 1 + + # Start node and let it write some data + start_node 1 + + local admin_port=13016 + local initial_keys=$(get_n_keys $admin_port) + log_info "Initial keys: $initial_keys" + + # The gateway auto-writes some data (peer_addr, etc) + sleep 2 + local keys_after_write=$(get_n_keys $admin_port) + log_info "Keys after startup: $keys_after_write" + + # Stop and restart + stop_node 1 + log_info "Restarting node 1..." + start_node 1 + + local keys_after_restart=$(get_n_keys $admin_port) + log_info "Keys after restart: $keys_after_restart" + + if [[ "$keys_after_restart" -ge "$keys_after_write" ]]; then + log_info "Persistence test PASSED" + return 0 + else + log_error "Persistence test FAILED: expected >= $keys_after_write keys, got $keys_after_restart" + return 1 + fi +} + +# ============================================================================= +# Test 2: Multi-node sync +# ============================================================================= +test_multi_node_sync() { + log_info "========== Test 2: Multi-node Sync ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" "$RUN_DIR/wavekv_node3" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" "$RUN_DIR/gateway-state-node3.json" + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + + # Wait for sync + log_info "Waiting for nodes to sync..." + sleep 10 + + # Use debug RPC to check actual synced data + local peer_addrs1=$(get_n_peer_addrs $debug_port1) + local peer_addrs2=$(get_n_peer_addrs $debug_port2) + local nodes1=$(get_n_nodes $debug_port1) + local nodes2=$(get_n_nodes $debug_port2) + + log_info "Node 1: peer_addrs=$peer_addrs1, nodes=$nodes1" + log_info "Node 2: peer_addrs=$peer_addrs2, nodes=$nodes2" + + # For true sync, each node should have: + # - At least 2 peer addresses (both nodes' addresses) + # - At least 2 node infos (both nodes' info) + local sync_ok=true + + if ! has_peer_addr $debug_port1 2; then + log_error "Node 1 missing peer_addr for node 2" + sync_ok=false + fi + if ! has_peer_addr $debug_port2 1; then + log_error "Node 2 missing peer_addr for node 1" + sync_ok=false + fi + if ! has_node_info $debug_port1 2; then + log_error "Node 1 missing node_info for node 2" + sync_ok=false + fi + if ! 
has_node_info $debug_port2 1; then + log_error "Node 2 missing node_info for node 1" + sync_ok=false + fi + + if [[ "$sync_ok" == "true" ]]; then + log_info "Multi-node sync test PASSED" + return 0 + else + log_error "Multi-node sync test FAILED: nodes did not sync peer data" + log_info "Sync data from node 1: $(debug_get_sync_data $debug_port1)" + log_info "Sync data from node 2: $(debug_get_sync_data $debug_port2)" + return 1 + fi +} + +# ============================================================================= +# Test 3: Node recovery after disconnect +# ============================================================================= +test_node_recovery() { + log_info "========== Test 3: Node Recovery ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + + # Wait for initial sync + sleep 5 + + # Stop node 2 + log_info "Stopping node 2 to simulate disconnect..." + stop_node 2 + + # Wait and let node 1 continue + sleep 3 + + # Check node 1 has its own data + local peer_addrs1_before=$(get_n_peer_addrs $debug_port1) + log_info "Node 1 peer_addrs before node 2 restart: $peer_addrs1_before" + + # Restart node 2 + log_info "Restarting node 2..." + start_node 2 + + # Re-register peers after restart + setup_peers 1 2 + + # Wait for sync + sleep 10 + + # After recovery, node 2 should have synced node 1's data + local sync_ok=true + + if ! has_peer_addr $debug_port2 1; then + log_error "Node 2 missing peer_addr for node 1 after recovery" + sync_ok=false + fi + if ! 
has_node_info $debug_port2 1; then + log_error "Node 2 missing node_info for node 1 after recovery" + sync_ok=false + fi + + if [[ "$sync_ok" == "true" ]]; then + log_info "Node recovery test PASSED" + return 0 + else + log_error "Node recovery test FAILED: node 2 did not sync data from node 1" + log_info "Sync data from node 2: $(debug_get_sync_data $debug_port2)" + return 1 + fi +} + +# ============================================================================= +# Test 4: Status endpoint structure (Admin.WaveKvStatus RPC) +# ============================================================================= +test_status_endpoint() { + log_info "========== Test 4: Status Endpoint ==========" + cleanup + + generate_config 1 + start_node 1 + + local admin_port=13016 + local status=$(get_status $admin_port) + + # Verify all expected fields exist + local checks_passed=0 + local total_checks=6 + + echo "$status" | python3 -c " +import sys, json +d = json.load(sys.stdin) +assert d['enabled'] == True, 'enabled should be True' +assert 'persistent' in d, 'missing persistent' +assert 'ephemeral' in d, 'missing ephemeral' +assert d['persistent']['wal_enabled'] == True, 'persistent wal should be enabled' +assert d['ephemeral']['wal_enabled'] == False, 'ephemeral wal should be disabled' +assert 'peers' in d['persistent'], 'missing peers in persistent' +print('All status checks passed') +" && checks_passed=1 + + if [[ $checks_passed -eq 1 ]]; then + log_info "Status endpoint test PASSED" + return 0 + else + log_error "Status endpoint test FAILED" + log_info "Status response: $status" + return 1 + fi +} + +# ============================================================================= +# Test 5: Cross-node data sync verification (KvStore + ProxyState) +# ============================================================================= +test_cross_node_data_sync() { + log_info "========== Test 5: Cross-node Data Sync ==========" + cleanup + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + + # Wait for initial connection + sleep 5 + + # Verify debug service is available + if ! check_debug_service $debug_port1; then + log_error "Debug service not available on node 1" + return 1 + fi + + # Register a client on node 1 via debug port + log_info "Registering client on node 1 via debug port..." + local register_response=$(debug_register_cvm $debug_port1 "testkey12345678901234567890123456789012345=" "app1" "inst1") + log_info "Register response: $register_response" + + # Verify registration succeeded + local client_ip=$(verify_register_response "$register_response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + log_info "Registered client with IP: $client_ip" + + # Wait for sync (need at least 3 sync intervals of 5s for data to propagate) + log_info "Waiting for sync..." 
+ sleep 20 + + # Check KvStore instance count on both nodes + local kv_instances1=$(get_n_instances $debug_port1) + local kv_instances2=$(get_n_instances $debug_port2) + + # Check ProxyState instance count on both nodes + local ps_instances1=$(get_n_proxy_state_instances $debug_port1) + local ps_instances2=$(get_n_proxy_state_instances $debug_port2) + + log_info "Node 1: KvStore=$kv_instances1, ProxyState=$ps_instances1" + log_info "Node 2: KvStore=$kv_instances2, ProxyState=$ps_instances2" + + local test_passed=true + + # Verify KvStore sync + if [[ "$kv_instances1" -lt 1 ]] || [[ "$kv_instances2" -lt 1 ]]; then + log_error "KvStore sync failed: kv_instances1=$kv_instances1, kv_instances2=$kv_instances2" + test_passed=false + fi + + # Verify ProxyState sync (node 2 should have loaded instance from KvStore) + if [[ "$ps_instances1" -lt 1 ]] || [[ "$ps_instances2" -lt 1 ]]; then + log_error "ProxyState sync failed: ps_instances1=$ps_instances1, ps_instances2=$ps_instances2" + test_passed=false + fi + + # Verify consistency on each node + if [[ "$kv_instances1" -ne "$ps_instances1" ]]; then + log_error "Node 1 inconsistent: KvStore=$kv_instances1, ProxyState=$ps_instances1" + test_passed=false + fi + if [[ "$kv_instances2" -ne "$ps_instances2" ]]; then + log_error "Node 2 inconsistent: KvStore=$kv_instances2, ProxyState=$ps_instances2" + test_passed=false + fi + + if [[ "$test_passed" == "true" ]]; then + log_info "Cross-node data sync test PASSED (KvStore and ProxyState consistent)" + return 0 + else + log_info "KvStore from node 1: $(debug_get_sync_data $debug_port1)" + log_info "KvStore from node 2: $(debug_get_sync_data $debug_port2)" + log_info "ProxyState from node 1: $(debug_get_proxy_state $debug_port1)" + log_info "ProxyState from node 2: $(debug_get_proxy_state $debug_port2)" + return 1 + fi +} + +# ============================================================================= +# Test 6: prpc DebugRegisterCvm endpoint (on separate debug port) +# ============================================================================= +test_prpc_register() { + log_info "========== Test 6: prpc DebugRegisterCvm ==========" + cleanup + + generate_config 1 + start_node 1 + + local debug_port=13015 + + # Verify debug service is available first + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + log_info "Debug service is available" + + # Register via debug port + local register_response=$(debug_register_cvm $debug_port "prpctest12345678901234567890123456789012=" "deadbeef" "cafebabe") + log_info "Register response: $register_response" + + # Verify registration succeeded + local client_ip=$(verify_register_response "$register_response") + if [[ -z "$client_ip" ]]; then + log_error "prpc DebugRegisterCvm test FAILED" + return 1 + fi + + log_info "DebugRegisterCvm success: client_ip=$client_ip" + log_info "prpc DebugRegisterCvm test PASSED" + return 0 +} + +# ============================================================================= +# Test 7: prpc Info endpoint +# ============================================================================= +test_prpc_info() { + log_info "========== Test 7: prpc Info ==========" + cleanup + + generate_config 1 + start_node 1 + + local port=13012 + + # Call Info via prpc + # Note: trim: "Tproxy." removes "Tproxy.Gateway." 
prefix, so endpoint is just /prpc/Info + local info_response=$(curl -sk --cacert "$CA_CERT" \ + -X POST "https://localhost:${port}/prpc/Info" \ + -H "Content-Type: application/json" \ + -d '{}' 2>/dev/null) + + log_info "Info response: $info_response" + + # Verify response has expected fields and no error + echo "$info_response" | python3 -c " +import sys, json +d = json.load(sys.stdin) +if 'error' in d: + print(f'ERROR: {d[\"error\"]}', file=sys.stderr) + sys.exit(1) +assert 'base_domain' in d, 'missing base_domain' +assert 'external_port' in d, 'missing external_port' +print('prpc Info check passed') +" && { + log_info "prpc Info test PASSED" + return 0 + } || { + log_error "prpc Info test FAILED" + return 1 + } +} + +# ============================================================================= +# Test 8: Client registration and data persistence +# ============================================================================= +test_client_registration_persistence() { + log_info "========== Test 8: Client Registration Persistence ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local debug_port=13015 + local admin_port=13016 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Register a client via debug port + log_info "Registering client..." + local register_response=$(debug_register_cvm $debug_port "persisttest1234567890123456789012345678901=" "persist_app" "persist_inst") + log_info "Register response: $register_response" + + # Verify registration succeeded + local client_ip=$(verify_register_response "$register_response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + + # Get initial key count + local keys_before=$(get_n_keys $admin_port) + log_info "Keys before restart: $keys_before" + + # Restart node + stop_node 1 + start_node 1 + + # Check keys after restart + local keys_after=$(get_n_keys $admin_port) + log_info "Keys after restart: $keys_after" + + if [[ "$keys_after" -ge "$keys_before" ]] && [[ "$keys_before" -gt 2 ]]; then + log_info "Client registration persistence test PASSED" + return 0 + else + log_error "Client registration persistence test FAILED: keys_before=$keys_before, keys_after=$keys_after" + return 1 + fi +} + +# ============================================================================= +# Test 9: Stress test - multiple writes +# ============================================================================= +test_stress_writes() { + log_info "========== Test 9: Stress Test ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local debug_port=13015 + local admin_port=13016 + local num_clients=10 + local success_count=0 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + log_info "Registering $num_clients clients via debug port..." 
+ for i in $(seq 1 $num_clients); do + local key=$(printf "stresstest%02d12345678901234567890123456=" "$i") + local app_id=$(printf "stressapp%02d" "$i") + local inst_id=$(printf "stressinst%02d" "$i") + local response=$(debug_register_cvm $debug_port "$key" "$app_id" "$inst_id") + if verify_register_response "$response" >/dev/null 2>&1; then + ((success_count++)) + fi + done + + log_info "Successfully registered $success_count/$num_clients clients" + + sleep 2 + + local keys_after=$(get_n_keys $admin_port) + log_info "Keys after stress test: $keys_after" + + # We expect successful registrations to create keys + if [[ "$success_count" -eq "$num_clients" ]] && [[ "$keys_after" -gt 2 ]]; then + log_info "Stress test PASSED" + return 0 + else + log_error "Stress test FAILED: success_count=$success_count, keys_after=$keys_after" + return 1 + fi +} + +# ============================================================================= +# Test 10: Network partition simulation (KvStore + ProxyState consistency) +# ============================================================================= +test_network_partition() { + log_info "========== Test 10: Network Partition Recovery ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + + # Let them sync initially + sleep 5 + + # Verify debug service is available + if ! check_debug_service $debug_port1; then + log_error "Debug service not available on node 1" + return 1 + fi + + # Stop node 2 (simulate partition) + log_info "Simulating network partition - stopping node 2..." + stop_node 2 + + # Register clients on node 1 while node 2 is down + log_info "Registering clients on node 1 during partition..." + local success_count=0 + for i in $(seq 1 3); do + local key=$(printf "partition%02d123456789012345678901234567=" "$i") + local response=$(debug_register_cvm $debug_port1 "$key" "partition_app$i" "partition_inst$i") + if verify_register_response "$response" >/dev/null 2>&1; then + ((success_count++)) + fi + done + log_info "Registered $success_count/3 clients during partition" + + local kv1_during=$(get_n_instances $debug_port1) + local ps1_during=$(get_n_proxy_state_instances $debug_port1) + log_info "Node 1 during partition: KvStore=$kv1_during, ProxyState=$ps1_during" + + # Restore node 2 + log_info "Healing partition - restarting node 2..." 
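+    # Node 2's data directory ($RUN_DIR/wavekv_node2) was not deleted, so on
+    # restart it should replay its own WAL first and then pull the writes it
+    # missed during the partition from node 1 once peers are re-registered.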
+ start_node 2 + + # Re-register peers after restart + setup_peers 1 2 + + # Wait for sync + sleep 15 + + # Check KvStore and ProxyState on both nodes after recovery + local kv1_after=$(get_n_instances $debug_port1) + local kv2_after=$(get_n_instances $debug_port2) + local ps1_after=$(get_n_proxy_state_instances $debug_port1) + local ps2_after=$(get_n_proxy_state_instances $debug_port2) + + log_info "Node 1 after recovery: KvStore=$kv1_after, ProxyState=$ps1_after" + log_info "Node 2 after recovery: KvStore=$kv2_after, ProxyState=$ps2_after" + + local test_passed=true + + # Verify basic sync + if [[ "$success_count" -ne 3 ]] || [[ "$kv1_during" -lt 3 ]]; then + log_error "Registration or KvStore write failed during partition" + test_passed=false + fi + + # Verify node 2 synced KvStore + if [[ "$kv2_after" -lt "$kv1_during" ]]; then + log_error "Node 2 KvStore sync failed: kv2_after=$kv2_after, expected >= $kv1_during" + test_passed=false + fi + + # Verify node 2 ProxyState sync + if [[ "$ps2_after" -lt "$kv1_during" ]]; then + log_error "Node 2 ProxyState sync failed: ps2_after=$ps2_after, expected >= $kv1_during" + test_passed=false + fi + + # Verify consistency on each node + if [[ "$kv1_after" -ne "$ps1_after" ]]; then + log_error "Node 1 inconsistent: KvStore=$kv1_after, ProxyState=$ps1_after" + test_passed=false + fi + if [[ "$kv2_after" -ne "$ps2_after" ]]; then + log_error "Node 2 inconsistent: KvStore=$kv2_after, ProxyState=$ps2_after" + test_passed=false + fi + + if [[ "$test_passed" == "true" ]]; then + log_info "Network partition recovery test PASSED (KvStore and ProxyState consistent)" + return 0 + else + log_info "KvStore from node 2: $(debug_get_sync_data $debug_port2)" + log_info "ProxyState from node 2: $(debug_get_proxy_state $debug_port2)" + return 1 + fi +} + +# ============================================================================= +# Test 11: Three-node cluster (KvStore + ProxyState consistency) +# ============================================================================= +test_three_node_cluster() { + log_info "========== Test 11: Three-node Cluster ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" "$RUN_DIR/wavekv_node3" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" "$RUN_DIR/gateway-state-node3.json" + + generate_config 1 + generate_config 2 + generate_config 3 + + start_node 1 + start_node 2 + start_node 3 + + # Register peers so all nodes can discover each other + setup_peers 1 2 3 + + local debug_port1=13015 + local debug_port2=13025 + local debug_port3=13035 + + # Wait for cluster to form + sleep 10 + + # Verify debug service is available + if ! check_debug_service $debug_port1; then + log_error "Debug service not available on node 1" + return 1 + fi + + # Register client on node 1 + log_info "Registering client on node 1..." 
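+    # A single write on node 1 is expected to reach nodes 2 and 3 via KvStore
+    # sync, and each node should then rebuild its ProxyState from the synced
+    # KvStore, which is why the assertions below require KvStore == ProxyState
+    # on every node.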
+ local response=$(debug_register_cvm $debug_port1 "threenode12345678901234567890123456789=" "threenode_app" "threenode_inst") + local client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + log_info "Registered client with IP: $client_ip" + + # Wait for sync across all nodes (need at least 2 sync intervals of 5s) + sleep 20 + + # Check KvStore instances on all three nodes + local kv1=$(get_n_instances $debug_port1) + local kv2=$(get_n_instances $debug_port2) + local kv3=$(get_n_instances $debug_port3) + + # Check ProxyState instances on all three nodes + local ps1=$(get_n_proxy_state_instances $debug_port1) + local ps2=$(get_n_proxy_state_instances $debug_port2) + local ps3=$(get_n_proxy_state_instances $debug_port3) + + log_info "Node 1: KvStore=$kv1, ProxyState=$ps1" + log_info "Node 2: KvStore=$kv2, ProxyState=$ps2" + log_info "Node 3: KvStore=$kv3, ProxyState=$ps3" + + local test_passed=true + + # Verify KvStore sync on all nodes + if [[ "$kv1" -lt 1 ]] || [[ "$kv2" -lt 1 ]] || [[ "$kv3" -lt 1 ]]; then + log_error "KvStore sync failed: kv1=$kv1, kv2=$kv2, kv3=$kv3" + test_passed=false + fi + + # Verify ProxyState sync on all nodes + if [[ "$ps1" -lt 1 ]] || [[ "$ps2" -lt 1 ]] || [[ "$ps3" -lt 1 ]]; then + log_error "ProxyState sync failed: ps1=$ps1, ps2=$ps2, ps3=$ps3" + test_passed=false + fi + + # Verify consistency on each node + if [[ "$kv1" -ne "$ps1" ]] || [[ "$kv2" -ne "$ps2" ]] || [[ "$kv3" -ne "$ps3" ]]; then + log_error "Inconsistency detected between KvStore and ProxyState" + test_passed=false + fi + + if [[ "$test_passed" == "true" ]]; then + log_info "Three-node cluster test PASSED (KvStore and ProxyState consistent)" + return 0 + else + log_info "KvStore from node 1: $(debug_get_sync_data $debug_port1)" + log_info "KvStore from node 2: $(debug_get_sync_data $debug_port2)" + log_info "KvStore from node 3: $(debug_get_sync_data $debug_port3)" + log_info "ProxyState from node 1: $(debug_get_proxy_state $debug_port1)" + log_info "ProxyState from node 2: $(debug_get_proxy_state $debug_port2)" + log_info "ProxyState from node 3: $(debug_get_proxy_state $debug_port3)" + return 1 + fi +} + +# ============================================================================= +# Test 12: WAL file integrity +# ============================================================================= +test_wal_integrity() { + log_info "========== Test 12: WAL File Integrity ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local debug_port=13015 + local success_count=0 + + # Verify debug service is available + if ! 
check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Register some clients via debug port + for i in $(seq 1 5); do + local key=$(printf "waltest%02d1234567890123456789012345678901=" "$i") + local response=$(debug_register_cvm $debug_port "$key" "wal_app$i" "wal_inst$i") + if verify_register_response "$response" >/dev/null 2>&1; then + ((success_count++)) + fi + done + log_info "Registered $success_count/5 clients" + + if [[ "$success_count" -ne 5 ]]; then + log_error "Failed to register all clients" + return 1 + fi + + sleep 2 + stop_node 1 + + # Check WAL file exists and has content + local wal_file="$RUN_DIR/wavekv_node1/node_1.wal" + if [[ -f "$wal_file" ]]; then + local wal_size=$(stat -c%s "$wal_file" 2>/dev/null || stat -f%z "$wal_file" 2>/dev/null) + log_info "WAL file size: $wal_size bytes" + + if [[ "$wal_size" -gt 100 ]]; then + log_info "WAL file integrity test PASSED" + return 0 + else + log_error "WAL file integrity test FAILED: WAL file too small ($wal_size bytes)" + return 1 + fi + else + log_error "WAL file not found: $wal_file" + return 1 + fi +} + +# ============================================================================= +# Test 13: Three-node cluster with bootnode (no dynamic peer setup) +# ============================================================================= +test_three_node_bootnode() { + log_info "========== Test 13: Three-node Cluster with Bootnode ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" "$RUN_DIR/wavekv_node3" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" "$RUN_DIR/gateway-state-node3.json" + + # Node 1 is the bootnode (no bootnode config) + # Node 2 and 3 use node 1 as bootnode + local bootnode_url="https://localhost:13012" + + generate_config 1 "" + generate_config 2 "$bootnode_url" + generate_config 3 "$bootnode_url" + + # Start node 1 first (bootnode) + start_node 1 + sleep 2 + + # Start node 2 and 3, they will discover each other via bootnode + start_node 2 + start_node 3 + + local debug_port1=13015 + local debug_port2=13025 + local debug_port3=13035 + + # Wait for cluster to form via bootnode discovery + log_info "Waiting for nodes to discover each other via bootnode..." + sleep 15 + + # Verify debug service is available on all nodes + for port in $debug_port1 $debug_port2 $debug_port3; do + if ! check_debug_service $port; then + log_error "Debug service not available on port $port" + return 1 + fi + done + + # Check peer discovery - each node should know about the others + local peer_addrs1=$(get_n_peer_addrs $debug_port1) + local peer_addrs2=$(get_n_peer_addrs $debug_port2) + local peer_addrs3=$(get_n_peer_addrs $debug_port3) + + log_info "Peer addresses: node1=$peer_addrs1, node2=$peer_addrs2, node3=$peer_addrs3" + + # Register client on node 2 (not the bootnode) + log_info "Registering client on node 2..." 
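+    # Registering on node 2 rather than on the bootnode verifies that writes
+    # originating from a dynamically discovered peer still reach every node.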
+ local response=$(debug_register_cvm $debug_port2 "bootnode12345678901234567890123456789=" "bootnode_app" "bootnode_inst") + local client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + log_info "Registered client with IP: $client_ip" + + # Wait for sync across all nodes + sleep 20 + + # Check KvStore instances on all three nodes + local kv1=$(get_n_instances $debug_port1) + local kv2=$(get_n_instances $debug_port2) + local kv3=$(get_n_instances $debug_port3) + + # Check ProxyState instances on all three nodes + local ps1=$(get_n_proxy_state_instances $debug_port1) + local ps2=$(get_n_proxy_state_instances $debug_port2) + local ps3=$(get_n_proxy_state_instances $debug_port3) + + log_info "Node 1 (bootnode): KvStore=$kv1, ProxyState=$ps1" + log_info "Node 2: KvStore=$kv2, ProxyState=$ps2" + log_info "Node 3: KvStore=$kv3, ProxyState=$ps3" + + local test_passed=true + + # Verify peer discovery worked (each node should have at least 2 peer addresses) + if [[ "$peer_addrs1" -lt 2 ]] || [[ "$peer_addrs2" -lt 2 ]] || [[ "$peer_addrs3" -lt 2 ]]; then + log_error "Peer discovery via bootnode failed: peer_addrs1=$peer_addrs1, peer_addrs2=$peer_addrs2, peer_addrs3=$peer_addrs3" + test_passed=false + fi + + # Verify KvStore sync on all nodes + if [[ "$kv1" -lt 1 ]] || [[ "$kv2" -lt 1 ]] || [[ "$kv3" -lt 1 ]]; then + log_error "KvStore sync failed: kv1=$kv1, kv2=$kv2, kv3=$kv3" + test_passed=false + fi + + # Verify ProxyState sync on all nodes + if [[ "$ps1" -lt 1 ]] || [[ "$ps2" -lt 1 ]] || [[ "$ps3" -lt 1 ]]; then + log_error "ProxyState sync failed: ps1=$ps1, ps2=$ps2, ps3=$ps3" + test_passed=false + fi + + # Verify consistency on each node + if [[ "$kv1" -ne "$ps1" ]] || [[ "$kv2" -ne "$ps2" ]] || [[ "$kv3" -ne "$ps3" ]]; then + log_error "Inconsistency detected between KvStore and ProxyState" + test_passed=false + fi + + if [[ "$test_passed" == "true" ]]; then + log_info "Three-node bootnode cluster test PASSED" + return 0 + else + log_info "Sync data from node 1: $(debug_get_sync_data $debug_port1)" + log_info "Sync data from node 2: $(debug_get_sync_data $debug_port2)" + log_info "Sync data from node 3: $(debug_get_sync_data $debug_port3)" + return 1 + fi +} + +# ============================================================================= +# Test 14: Node ID reuse rejection +# ============================================================================= +test_node_id_reuse_rejected() { + log_info "========== Test 14: Node ID Reuse Rejected ==========" + cleanup + + # Clean up all state files to ensure fresh start + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" + rm -f "$RUN_DIR/gateway-state-node1.json" "$RUN_DIR/gateway-state-node2.json" + + # Start node 1 and node 2, let them sync + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local debug_port1=13015 + local debug_port2=13025 + local admin_port1=13016 + + # Wait for initial sync + log_info "Waiting for initial sync between node 1 and node 2..." + sleep 10 + + # Verify both nodes have synced + if ! has_peer_addr $debug_port1 2; then + log_error "Node 1 missing peer_addr for node 2" + return 1 + fi + if ! 
has_peer_addr $debug_port2 1; then + log_error "Node 2 missing peer_addr for node 1" + return 1 + fi + log_info "Initial sync completed successfully" + + # Get initial key count on node 1 + local keys_before=$(get_n_keys $admin_port1) + log_info "Keys on node 1 before node 2 restart: $keys_before" + + # Stop node 2 and delete its data (simulating a fresh node trying to reuse the ID) + log_info "Stopping node 2 and deleting its data..." + stop_node 2 + rm -rf "$RUN_DIR/wavekv_node2" + rm -f "$RUN_DIR/gateway-state-node2.json" + + # Restart node 2 - it will have a new UUID but same node_id + log_info "Restarting node 2 with fresh data (new UUID, same node_id)..." + start_node 2 + + # Re-register peers + setup_peers 1 2 + + # Wait for sync attempt + sleep 15 + + # Check node 2's log for UUID mismatch error + local log_file="${LOG_DIR}/${CURRENT_TEST}_node2.log" + if grep -q "UUID mismatch" "$log_file" 2>/dev/null; then + log_info "Found UUID mismatch error in node 2 log (expected)" + else + log_warn "UUID mismatch error not found in log (may still be rejected)" + fi + + # Node 1 should have rejected sync from new node 2 + # Check if node 1's data is still intact (keys should not decrease) + local keys_after=$(get_n_keys $admin_port1) + log_info "Keys on node 1 after node 2 restart: $keys_after" + + # The new node 2 should NOT have received data from node 1 + # because node 1 should reject sync due to UUID mismatch + local kv2=$(get_n_instances $debug_port2) + log_info "Node 2 instances after restart: $kv2" + + # Verify node 1's data is intact + if [[ "$keys_after" -lt "$keys_before" ]]; then + log_error "Node 1 lost data after node 2 restart with reused ID" + return 1 + fi + + # The test passes if: + # 1. Node 1's data is intact + # 2. Either UUID mismatch was logged OR node 2 didn't get full sync + log_info "Node ID reuse rejection test PASSED" + return 0 +} + +# ============================================================================= +# Test 15: Periodic persistence +# ============================================================================= +test_periodic_persistence() { + log_info "========== Test 15: Periodic Persistence ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local debug_port=13015 + local admin_port=13016 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Register some clients to create data + log_info "Registering clients to create data..." + local success_count=0 + for i in $(seq 1 3); do + local key=$(printf "persist%02d123456789012345678901234567890=" "$i") + local response=$(debug_register_cvm $debug_port "$key" "persist_app$i" "persist_inst$i") + if verify_register_response "$response" >/dev/null 2>&1; then + ((success_count++)) + fi + done + log_info "Registered $success_count/3 clients" + + if [[ "$success_count" -ne 3 ]]; then + log_error "Failed to register all clients" + return 1 + fi + + # Get initial key count + local keys_before=$(get_n_keys $admin_port) + log_info "Keys before waiting for persist: $keys_before" + + # Wait for periodic persistence (persist_interval is 5s in test config) + log_info "Waiting for periodic persistence (8s)..." 
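+    # 8s leaves margin for at least one persist tick beyond the 5s
+    # persist_interval assumed in the test config.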
+    sleep 8
+
+    # Check log for periodic persist message
+    local log_file="${LOG_DIR}/${CURRENT_TEST}_node1.log"
+    if grep -q "periodic persist completed" "$log_file" 2>/dev/null; then
+        log_info "Found periodic persist message in log"
+    else
+        log_error "Periodic persist message not found in log - test FAILED"
+        return 1
+    fi
+
+    # Stop node
+    stop_node 1
+
+    # Check WAL file exists and has content
+    local wal_file="$RUN_DIR/wavekv_node1/node_1.wal"
+    if [[ ! -f "$wal_file" ]]; then
+        log_error "WAL file not found: $wal_file"
+        return 1
+    fi
+
+    local wal_size=$(stat -c%s "$wal_file" 2>/dev/null || stat -f%z "$wal_file" 2>/dev/null)
+    log_info "WAL file size after periodic persist: $wal_size bytes"
+
+    # Restart node and verify data is recovered
+    log_info "Restarting node to verify persistence..."
+    start_node 1
+
+    local keys_after=$(get_n_keys $admin_port)
+    log_info "Keys after restart: $keys_after"
+
+    if [[ "$keys_after" -ge "$keys_before" ]]; then
+        log_info "Periodic persistence test PASSED"
+        return 0
+    else
+        log_error "Periodic persistence test FAILED: keys_before=$keys_before, keys_after=$keys_after"
+        return 1
+    fi
+}
+
+# =============================================================================
+# Admin RPC helper functions
+# =============================================================================
+
+# Call Admin.SetNodeUrl RPC
+# Usage: admin_set_node_url <admin_port> <node_id> <url>
+admin_set_node_url() {
+    local admin_port=$1
+    local node_id=$2
+    local url=$3
+    curl -s -X POST "http://localhost:${admin_port}/prpc/Admin.SetNodeUrl" \
+        -H "Content-Type: application/json" \
+        -d "{\"id\": $node_id, \"url\": \"$url\"}" 2>/dev/null
+}
+
+# Register peers between nodes via Admin RPC
+# This is needed since we removed peer_node_ids/peer_urls from config
+# Usage: setup_peers <node_id> [<node_id>...]
+# Example: setup_peers 1 2 3  # Sets up peers between nodes 1, 2, and 3
+setup_peers() {
+    local node_ids=("$@")
+
+    for src_node in "${node_ids[@]}"; do
+        local src_admin_port=$((13000 + src_node * 10 + 6))
+
+        for dst_node in "${node_ids[@]}"; do
+            if [[ "$src_node" != "$dst_node" ]]; then
+                local dst_rpc_port=$((13000 + dst_node * 10 + 2))
+                local dst_url="https://localhost:${dst_rpc_port}"
+                admin_set_node_url "$src_admin_port" "$dst_node" "$dst_url"
+            fi
+        done
+    done
+
+    # Wait for peers to be registered
+    sleep 1
+}
+
+# Call Admin.SetNodeStatus RPC
+# Usage: admin_set_node_status <admin_port> <node_id> <status>
+# status: "up" or "down"
+admin_set_node_status() {
+    local admin_port=$1
+    local node_id=$2
+    local status=$3
+    curl -s -X POST "http://localhost:${admin_port}/prpc/Admin.SetNodeStatus" \
+        -H "Content-Type: application/json" \
+        -d "{\"id\": $node_id, \"status\": \"$status\"}" 2>/dev/null
+}
+
+# Call Admin.Status RPC to get all nodes
+# Usage: admin_get_status <admin_port>
+admin_get_status() {
+    local admin_port=$1
+    curl -s -X POST "http://localhost:${admin_port}/prpc/Admin.Status" \
+        -H "Content-Type: application/json" \
+        -d '{}' 2>/dev/null
+}
+
+# Get peer URL from sync data
+# Usage: get_peer_url_from_sync <debug_port> <node_id>
+get_peer_url_from_sync() {
+    local debug_port=$1
+    local node_id=$2
+    local response=$(debug_get_sync_data "$debug_port")
+    echo "$response" | python3 -c "
+import sys, json
+try:
+    d = json.load(sys.stdin)
+    for pa in d.get('peer_addrs', []):
+        if pa.get('node_id') == $node_id:
+            print(pa.get('url', ''))
+            sys.exit(0)
+    print('')
+except Exception:
+    print('')
+" 2>/dev/null
+}
+
+# =============================================================================
+# Test 16: Admin.SetNodeUrl RPC
+# 
============================================================================= +test_admin_set_node_url() { + log_info "========== Test 16: Admin.SetNodeUrl RPC ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local admin_port=13016 + local debug_port=13015 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Set URL for a new node (node 2) via Admin RPC + local new_url="https://new-node2.example.com:8011" + log_info "Setting node 2 URL via Admin.SetNodeUrl..." + local response=$(admin_set_node_url $admin_port 2 "$new_url") + log_info "SetNodeUrl response: $response" + + # Check if the response contains an error + if echo "$response" | grep -q '"error"'; then + log_error "SetNodeUrl returned error: $response" + return 1 + fi + + # Wait for data to be written + sleep 2 + + # Verify the URL was stored in KvStore + local stored_url=$(get_peer_url_from_sync $debug_port 2) + log_info "Stored URL for node 2: $stored_url" + + if [[ "$stored_url" == "$new_url" ]]; then + log_info "Admin.SetNodeUrl test PASSED" + return 0 + else + log_error "Admin.SetNodeUrl test FAILED: expected '$new_url', got '$stored_url'" + log_info "Sync data: $(debug_get_sync_data $debug_port)" + return 1 + fi +} + +# ============================================================================= +# Test 17: Admin.SetNodeStatus RPC +# ============================================================================= +test_admin_set_node_status() { + log_info "========== Test 17: Admin.SetNodeStatus RPC ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local admin_port=13016 + local debug_port=13015 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # First set a URL for node 2 so we have a peer + admin_set_node_url $admin_port 2 "https://node2.example.com:8011" + sleep 1 + + # Set node 2 status to "down" + log_info "Setting node 2 status to 'down' via Admin.SetNodeStatus..." + local response=$(admin_set_node_status $admin_port 2 "down") + log_info "SetNodeStatus response: $response" + + # Check if the response contains an error + if echo "$response" | grep -q '"error"'; then + log_error "SetNodeStatus returned error: $response" + return 1 + fi + + sleep 1 + + # Set node 2 status back to "up" + log_info "Setting node 2 status to 'up' via Admin.SetNodeStatus..." + response=$(admin_set_node_status $admin_port 2 "up") + log_info "SetNodeStatus response: $response" + + if echo "$response" | grep -q '"error"'; then + log_error "SetNodeStatus returned error: $response" + return 1 + fi + + # Test invalid status + log_info "Testing invalid status..." 
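+    # prpc failures are reported as a JSON body containing an "error" key; the
+    # exact message is not asserted here, only its presence. A rejection might
+    # look like (shape assumed, not guaranteed): {"error": "invalid status"}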
+ response=$(admin_set_node_status $admin_port 2 "invalid") + if echo "$response" | grep -q '"error"'; then + log_info "Invalid status correctly rejected" + else + log_warn "Invalid status was not rejected (may be acceptable)" + fi + + log_info "Admin.SetNodeStatus test PASSED" + return 0 +} + +# ============================================================================= +# Test 18: Node down excluded from RegisterCvm response +# ============================================================================= +test_node_status_register_exclude() { + log_info "========== Test 18: Node Down Excluded from Registration ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" "$RUN_DIR/wavekv_node2" + + generate_config 1 + generate_config 2 + + start_node 1 + start_node 2 + + # Register peers so nodes can discover each other + setup_peers 1 2 + + local admin_port1=13016 + local admin_port2=13026 + local debug_port1=13015 + + # Wait for sync + sleep 5 + + # Verify debug service is available + if ! check_debug_service $debug_port1; then + log_error "Debug service not available on node 1" + return 1 + fi + + # Set node 2 status to "down" via node 1's admin API + log_info "Setting node 2 status to 'down'..." + admin_set_node_status $admin_port1 2 "down" + sleep 2 + + # Register a client on node 1 + log_info "Registering client on node 1 (node 2 is down)..." + local response=$(debug_register_cvm $debug_port1 "downtest12345678901234567890123456789012=" "downtest_app" "downtest_inst") + log_info "Register response: $response" + + # Verify registration succeeded + local client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed" + return 1 + fi + log_info "Registered client with IP: $client_ip" + + # Check gateways list in response - should NOT include node 2 + local has_node2=$(echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + gateways = d.get('gateways', []) + for gw in gateways: + if gw.get('id') == 2: + sys.exit(0) + sys.exit(1) +except: + sys.exit(1) +" && echo "yes" || echo "no") + + if [[ "$has_node2" == "yes" ]]; then + log_error "Node 2 (down) was included in registration response" + log_info "Response: $response" + return 1 + else + log_info "Node 2 (down) correctly excluded from registration response" + fi + + # Set node 2 status back to "up" + log_info "Setting node 2 status to 'up'..." + admin_set_node_status $admin_port1 2 "up" + sleep 2 + + # Register another client + log_info "Registering client on node 1 (node 2 is now up)..." 
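+    # The response's gateways list is assumed to have the shape
+    #   {"gateways": [{"id": <node_id>, ...}, ...]}
+    # since matching on gw["id"] is all the check below relies on.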
+ response=$(debug_register_cvm $debug_port1 "uptest123456789012345678901234567890123=" "uptest_app" "uptest_inst2") + + # Check gateways list - should now include node 2 + has_node2=$(echo "$response" | python3 -c " +import sys, json +try: + d = json.load(sys.stdin) + gateways = d.get('gateways', []) + for gw in gateways: + if gw.get('id') == 2: + sys.exit(0) + sys.exit(1) +except: + sys.exit(1) +" && echo "yes" || echo "no") + + if [[ "$has_node2" == "no" ]]; then + log_error "Node 2 (up) was NOT included in registration response" + log_info "Response: $response" + return 1 + else + log_info "Node 2 (up) correctly included in registration response" + fi + + log_info "Node down excluded from registration test PASSED" + return 0 +} + +# ============================================================================= +# Test 19: Node down rejects RegisterCvm requests +# ============================================================================= +test_node_status_register_reject() { + log_info "========== Test 19: Node Down Rejects Registration ==========" + cleanup + + rm -rf "$RUN_DIR/wavekv_node1" + + generate_config 1 + start_node 1 + + local admin_port=13016 + local debug_port=13015 + + # Verify debug service is available + if ! check_debug_service $debug_port; then + log_error "Debug service not available" + return 1 + fi + + # Register a client when node is up (should succeed) + log_info "Registering client when node 1 is up..." + local response=$(debug_register_cvm $debug_port "upnode123456789012345678901234567890123=" "upnode_app" "upnode_inst") + local client_ip=$(verify_register_response "$response") + if [[ -z "$client_ip" ]]; then + log_error "Registration failed when node was up" + return 1 + fi + log_info "Registration succeeded when node was up (IP: $client_ip)" + + # Set node 1 status to "down" (marking itself as down) + log_info "Setting node 1 status to 'down'..." + admin_set_node_status $admin_port 1 "down" + sleep 2 + + # Try to register a client when node is down (should fail) + log_info "Attempting to register client when node 1 is down..." + response=$(debug_register_cvm $debug_port "downnode12345678901234567890123456789012=" "downnode_app" "downnode_inst") + log_info "Register response: $response" + + # Check if response contains error about node being down + if echo "$response" | grep -qi "error"; then + log_info "Registration correctly rejected when node is down" + if echo "$response" | grep -qi "marked as down"; then + log_info "Error message mentions 'marked as down' (correct)" + fi + else + log_error "Registration was NOT rejected when node is down" + log_info "Response: $response" + return 1 + fi + + # Set node 1 status back to "up" + log_info "Setting node 1 status to 'up'..." + admin_set_node_status $admin_port 1 "up" + sleep 2 + + # Register a client again (should succeed) + log_info "Registering client when node 1 is back up..." 
+    response=$(debug_register_cvm $debug_port "backup123456789012345678901234567890123=" "backup_app" "backup_inst")
+    client_ip=$(verify_register_response "$response")
+    if [[ -z "$client_ip" ]]; then
+        log_error "Registration failed when node was back up"
+        return 1
+    fi
+    log_info "Registration succeeded when node was back up (IP: $client_ip)"
+
+    log_info "Node down rejects registration test PASSED"
+    return 0
+}
+
+# =============================================================================
+# Clean command - remove all generated files
+# =============================================================================
+clean() {
+    log_info "Cleaning up generated files..."
+
+    # Kill only test gateway processes (matching our specific config path)
+    pkill -9 -f "dstack-gateway -c ${SCRIPT_DIR}/${RUN_DIR}/node" >/dev/null 2>&1 || true
+    pkill -9 -f "dstack-gateway.*${SCRIPT_DIR}/${RUN_DIR}/node" >/dev/null 2>&1 || true
+    sleep 1
+
+    # Remove WireGuard interfaces (only our test interfaces need sudo)
+    sudo ip link delete wavekv-test1 2>/dev/null || true
+    sudo ip link delete wavekv-test2 2>/dev/null || true
+    sudo ip link delete wavekv-test3 2>/dev/null || true
+
+    # Remove run directory (contains all generated files including certs)
+    rm -rf "$RUN_DIR"
+
+    log_info "Cleanup complete"
+}
+
+# =============================================================================
+# Ensure all certificates exist (CA + RPC + proxy)
+# =============================================================================
+ensure_certs() {
+    # Create directories
+    mkdir -p "$CERTS_DIR"
+    mkdir -p "$RUN_DIR/certbot/live"
+
+    # Generate CA certificate if not exists
+    if [[ ! -f "$CERTS_DIR/gateway-ca.key" ]] || [[ ! -f "$CERTS_DIR/gateway-ca.cert" ]]; then
+        log_info "Creating CA certificate..."
+        openssl genrsa -out "$CERTS_DIR/gateway-ca.key" 2048 2>/dev/null
+        openssl req -x509 -new -nodes \
+            -key "$CERTS_DIR/gateway-ca.key" \
+            -sha256 -days 365 \
+            -out "$CERTS_DIR/gateway-ca.cert" \
+            -subj "/CN=Test CA/O=WaveKV Test" \
+            2>/dev/null
+    fi
+
+    # Generate RPC certificate signed by CA if not exists
+    if [[ ! -f "$CERTS_DIR/gateway-rpc.key" ]] || [[ ! -f "$CERTS_DIR/gateway-rpc.cert" ]]; then
+        log_info "Creating RPC certificate signed by CA..."
+        openssl genrsa -out "$CERTS_DIR/gateway-rpc.key" 2048 2>/dev/null
+        openssl req -new \
+            -key "$CERTS_DIR/gateway-rpc.key" \
+            -out "$CERTS_DIR/gateway-rpc.csr" \
+            -subj "/CN=localhost" \
+            2>/dev/null
+        # Create certificate with SAN for localhost
+        cat >"$CERTS_DIR/ext.cnf" <<EOF
+subjectAltName = DNS:localhost, IP:127.0.0.1
+EOF
+        # Sign the CSR with the test CA, applying the SAN extension
+        openssl x509 -req \
+            -in "$CERTS_DIR/gateway-rpc.csr" \
+            -CA "$CERTS_DIR/gateway-ca.cert" \
+            -CAkey "$CERTS_DIR/gateway-ca.key" \
+            -CAcreateserial \
+            -out "$CERTS_DIR/gateway-rpc.cert" \
+            -days 365 -sha256 \
+            -extfile "$CERTS_DIR/ext.cnf" \
+            2>/dev/null
+        rm -f "$CERTS_DIR/gateway-rpc.csr" "$CERTS_DIR/ext.cnf"
+    fi
+
+    # Generate proxy certificates (for TLS termination)
+    local proxy_cert_dir="$RUN_DIR/certbot/live"
+    if [[ ! -f "$proxy_cert_dir/cert.pem" ]] || [[ ! -f "$proxy_cert_dir/key.pem" ]]; then
+        log_info "Creating proxy certificates..."
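+        # Unlike the RPC cert, the proxy cert can be self-signed: prpc calls
+        # verify against "$CA_CERT" via curl --cacert, but nothing in these
+        # tests validates the proxy's TLS chain, so no CA signature is needed.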
+        openssl req -x509 -newkey rsa:2048 -nodes \
+            -keyout "$proxy_cert_dir/key.pem" \
+            -out "$proxy_cert_dir/cert.pem" \
+            -days 365 \
+            -subj "/CN=localhost" \
+            2>/dev/null
+    fi
+}
+
+# =============================================================================
+# Main
+# =============================================================================
+main() {
+    # Handle clean command
+    if [[ "${1:-}" == "clean" ]]; then
+        clean
+        exit 0
+    fi
+
+    # Handle cfg command - generate node configuration
+    if [[ "${1:-}" == "cfg" ]]; then
+        local node_id="${2:-}"
+        if [[ -z "$node_id" ]]; then
+            log_error "Usage: $0 cfg <node_id>"
+            log_info "Example: $0 cfg 1"
+            exit 1
+        fi
+
+        # Ensure certificates exist
+        ensure_certs
+
+        # Generate config for the specified node
+        generate_config "$node_id"
+        log_info "Configuration generated: $RUN_DIR/node${node_id}.toml"
+        exit 0
+    fi
+
+    # Handle ls command - list all test cases
+    if [[ "${1:-}" == "ls" ]]; then
+        echo "Available test cases:"
+        echo ""
+        echo "Quick tests:"
+        echo "  test_persistence                     - Single node persistence"
+        echo "  test_status_endpoint                 - Status endpoint structure"
+        echo "  test_prpc_register                   - prpc DebugRegisterCvm endpoint"
+        echo "  test_prpc_info                       - prpc Info endpoint"
+        echo "  test_wal_integrity                   - WAL file integrity"
+        echo ""
+        echo "Sync tests:"
+        echo "  test_multi_node_sync                 - Multi-node sync"
+        echo "  test_node_recovery                   - Node recovery after disconnect"
+        echo "  test_cross_node_data_sync            - Cross-node data sync verification"
+        echo ""
+        echo "Advanced tests:"
+        echo "  test_client_registration_persistence - Client registration and persistence"
+        echo "  test_stress_writes                   - Stress test - multiple writes"
+        echo "  test_network_partition               - Network partition simulation"
+        echo "  test_three_node_cluster              - Three-node cluster"
+        echo "  test_three_node_bootnode             - Three-node cluster with bootnode"
+        echo "  test_node_id_reuse_rejected          - Node ID reuse rejection"
+        echo "  test_periodic_persistence            - Periodic persistence"
+        echo ""
+        echo "Admin RPC tests:"
+        echo "  test_admin_set_node_url              - Admin.SetNodeUrl RPC"
+        echo "  test_admin_set_node_status           - Admin.SetNodeStatus RPC"
+        echo "  test_node_status_register_exclude    - Node down excluded from registration"
+        echo "  test_node_status_register_reject     - Node down rejects registration"
+        echo ""
+        echo "Usage:"
+        echo "  $0                   - Run all tests"
+        echo "  $0 quick             - Run quick tests only"
+        echo "  $0 sync              - Run sync tests only"
+        echo "  $0 advanced          - Run advanced tests only"
+        echo "  $0 admin             - Run admin RPC tests only"
+        echo "  $0 case <test_name>  - Run specific test case"
+        echo "  $0 ls                - List all test cases"
+        echo "  $0 clean             - Clean up generated files"
+        exit 0
+    fi
+
+    # Handle case command - run specific test case
+    if [[ "${1:-}" == "case" ]]; then
+        local test_case="${2:-}"
+        if [[ -z "$test_case" ]]; then
+            log_error "Usage: $0 case <test_name>"
+            log_info "Run '$0 ls' to see all available test cases"
+            exit 1
+        fi
+
+        # Check if gateway binary exists
+        if [[ ! -f "$GATEWAY_BIN" ]]; then
+            log_error "Gateway binary not found: $GATEWAY_BIN"
+            log_info "Please run: cargo build --release"
+            exit 1
+        fi
+
+        # Ensure certificates exist
+        ensure_certs
+
+        # Check if test function exists
+        if ! declare -f "$test_case" >/dev/null; then
+            log_error "Test case not found: $test_case"
+            log_info "Run '$0 ls' to see available test cases"
+            exit 1
+        fi
+
+        # Run the specific test
+        log_info "Running test case: $test_case"
+        CURRENT_TEST="$test_case"
+        if $test_case; then
+            log_info "Test PASSED: $test_case"
+            cleanup
+            exit 0
+        else
+            log_error "Test FAILED: $test_case"
+            cleanup
+            exit 1
+        fi
+    fi
+
+    log_info "Starting WaveKV integration tests..."
+
+    if [[ ! -f "$GATEWAY_BIN" ]]; then
+        log_error "Gateway binary not found: $GATEWAY_BIN"
+        log_info "Please run: cargo build --release"
+        exit 1
+    fi
+
+    # Ensure all certificates exist (RPC + proxy)
+    ensure_certs
+
+    local failed=0
+    local passed=0
+    local failed_tests=()
+
+    run_test() {
+        local test_name=$1
+        CURRENT_TEST="$test_name"
+        if $test_name; then
+            ((passed++))
+        else
+            ((failed++))
+            failed_tests+=("$test_name")
+        fi
+        cleanup
+    }
+
+    # Run selected test or all tests
+    local test_filter="${1:-all}"
+
+    if [[ "$test_filter" == "all" ]] || [[ "$test_filter" == "quick" ]]; then
+        run_test test_persistence
+        run_test test_status_endpoint
+        run_test test_prpc_register
+        run_test test_prpc_info
+        run_test test_wal_integrity
+    fi
+
+    if [[ "$test_filter" == "all" ]] || [[ "$test_filter" == "sync" ]]; then
+        run_test test_multi_node_sync
+        run_test test_node_recovery
+        run_test test_cross_node_data_sync
+    fi
+
+    if [[ "$test_filter" == "all" ]] || [[ "$test_filter" == "advanced" ]]; then
+        run_test test_client_registration_persistence
+        run_test test_stress_writes
+        run_test test_network_partition
+        run_test test_three_node_cluster
+        run_test test_three_node_bootnode
+        run_test test_node_id_reuse_rejected
+        run_test test_periodic_persistence
+    fi
+
+    if [[ "$test_filter" == "all" ]] || [[ "$test_filter" == "admin" ]]; then
+        run_test test_admin_set_node_url
+        run_test test_admin_set_node_status
+        run_test test_node_status_register_exclude
+        run_test test_node_status_register_reject
+    fi
+
+    echo ""
+    log_info "=========================================="
+    log_info "Tests passed: $passed"
+    if [[ $failed -gt 0 ]]; then
+        log_error "Tests failed: $failed"
+        echo ""
+        log_error "Failed test cases:"
+        for test_name in "${failed_tests[@]}"; do
+            log_error "  - $test_name"
+        done
+        echo ""
+        log_info "To rerun a failed test:"
+        log_info "  $0 case <test_name>"
+        log_info "Example:"
+        if [[ ${#failed_tests[@]} -gt 0 ]]; then
+            log_info "  $0 case ${failed_tests[0]}"
+        fi
+    fi
+    log_info "=========================================="
+
+    return $failed
+}
+
+# Run if executed directly
+if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
+    main "$@"
+fi
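+
+# Quick reference (the 'ls' subcommand above prints the full test list):
+#   (no args)                  - run all tests
+#   quick|sync|advanced|admin  - run one test group
+#   case <test_name>           - run a single test
+#   clean                      - remove generated state, certs and interfaces
+#
+# Port layout for node N, as used by setup_peers and the hard-coded test ports:
+#   rpc   = 13000 + 10*N + 2   (13012, 13022, 13032)
+#   debug = 13000 + 10*N + 5   (13015, 13025, 13035)
+#   admin = 13000 + 10*N + 6   (13016, 13026, 13036)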