Skip to content

Commit

Permalink
Merge pull request #153 from spacemeshos/151-max-retries
Browse files Browse the repository at this point in the history
Add `--max-retries` to the post service
  • Loading branch information
poszu authored Nov 23, 2023
2 parents 2ed5c49 + c1efa57 commit c420c42
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 36 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ jobs:
toolchain: stable

steps:
- uses: arduino/setup-protoc@v2
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
- uses: actions/checkout@v3
with:
submodules: true
Expand All @@ -64,8 +67,8 @@ jobs:
unzip -j OpenCL-SDK-v2023.04.17-Win-x64.zip OpenCL-SDK-v2023.04.17-Win-x64/lib/OpenCL.lib
- uses: Swatinem/rust-cache@v2

- name: Test post crate
run: cargo test --all-features --release
- name: Tests
run: cargo test --all-features --release -p post-rs -p certifier -p service
env:
RUSTFLAGS: ${{ matrix.rustflags }}
# https://github.com/tevador/RandomX/issues/262
Expand Down
2 changes: 1 addition & 1 deletion certifier/tests/test_certify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async fn test_certificate_post_proof() {
// Spawn the certifier service
let signer = SigningKey::generate(&mut rand::rngs::OsRng);
let app = certifier::certifier::new(cfg, init_cfg, signer);
let server = axum::Server::bind(&"127.0.0.1:0".parse().unwrap()).serve(app.into_make_service());
let server = axum::Server::bind(&([127, 0, 0, 1], 0).into()).serve(app.into_make_service());
let addr = server.local_addr();
tokio::spawn(server);

Expand Down
45 changes: 33 additions & 12 deletions service/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ pub mod spacemesh_v1 {

pub struct ServiceClient<S: PostService> {
endpoint: Endpoint,
reconnect_interval: Duration,
service: S,
}

Expand Down Expand Up @@ -69,7 +68,6 @@ impl<T: PostService + ?Sized> PostService for std::sync::Arc<T> {
impl<S: PostService> ServiceClient<S> {
pub fn new(
address: String,
reconnect_interval: Duration,
tls: Option<(Option<String>, Certificate, Identity)>,
service: S,
) -> eyre::Result<Self> {
Expand All @@ -96,28 +94,37 @@ impl<S: PostService> ServiceClient<S> {
None => endpoint,
};

Ok(Self {
endpoint,
reconnect_interval,
service,
})
Ok(Self { endpoint, service })
}

pub async fn run(mut self) -> eyre::Result<()> {
pub async fn run(
mut self,
max_retries: Option<usize>,
reconnect_interval: Duration,
) -> eyre::Result<()> {
loop {
let mut attempt = 1;
let client = loop {
log::debug!("connecting to the node on {}", self.endpoint.uri());
log::debug!(
"connecting to the node on {} (attempt {})",
self.endpoint.uri(),
attempt
);
match PostServiceClient::connect(self.endpoint.clone()).await {
Ok(client) => break client,
Err(e) => {
log::info!("could not connect to the node: {e}");
sleep(self.reconnect_interval).await;
if let Some(max) = max_retries {
eyre::ensure!(attempt <= max, "max retries ({max}) reached");
}
sleep(reconnect_interval).await;
}
}
attempt += 1;
};
let res = self.register_and_serve(client).await;
log::info!("disconnected: {res:?}");
sleep(self.reconnect_interval).await;
sleep(reconnect_interval).await;
}
}

Expand Down Expand Up @@ -267,6 +274,8 @@ fn convert_metadata(meta: PostMetadata) -> spacemesh_v1::Metadata {

#[cfg(test)]
mod tests {
use std::time::Duration;

use tonic::transport::{Certificate, Identity};

#[test]
Expand All @@ -275,7 +284,6 @@ mod tests {
let client_crt = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
super::ServiceClient::new(
"https://localhost:1234".to_string(),
Default::default(),
Some((
None,
Certificate::from_pem(crt.serialize_pem().unwrap()),
Expand All @@ -288,4 +296,17 @@ mod tests {
)
.unwrap();
}

/// `run` must return an error, instead of retrying forever, once the
/// configured retry limit is exceeded.
#[tokio::test]
async fn gives_up_after_max_retries() {
    // NOTE(review): presumably nothing listens on localhost:1234, so every
    // connection attempt fails -- confirm for the CI environment.
    let service = super::MockPostService::new();
    let client =
        super::ServiceClient::new("http://localhost:1234".to_string(), None, service).unwrap();

    // A 1 ms reconnect interval keeps the test fast.
    let err = client
        .run(Some(2), Duration::from_millis(1))
        .await
        .unwrap_err();
    assert_eq!(err.to_string(), "max retries (2) reached");
}
}
45 changes: 37 additions & 8 deletions service/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ struct Cli {
/// time to wait before reconnecting to the node
#[arg(long, default_value = "5", value_parser = |secs: &str| secs.parse().map(Duration::from_secs))]
reconnect_interval_s: Duration,
/// Maximum number of retries to connect to the node.
/// The default is infinite.
#[arg(long)]
max_retries: Option<usize>,

#[command(flatten, next_help_heading = "POST configuration")]
post_config: PostConfig,
Expand Down Expand Up @@ -219,14 +223,28 @@ async fn main() -> eyre::Result<()> {
None
};

let client = client::ServiceClient::new(args.address, args.reconnect_interval_s, tls, service)?;
let client_handle = tokio::spawn(client.run());
let client = client::ServiceClient::new(args.address, tls, service)?;
let client_handle = tokio::spawn(client.run(args.max_retries, args.reconnect_interval_s));

if let Some(pid) = args.watch_pid {
tokio::task::spawn_blocking(move || watch_pid(pid, Duration::from_secs(1))).await?;
Ok(())
} else {
client_handle.await?
tokio::select! {
Some(err) = watch_pid_if_needed(args.watch_pid) => {
log::info!("PID watcher exited: {err:?}");
return Ok(())

Check warning on line 232 in service/src/main.rs

View workflow job for this annotation

GitHub Actions / clippy

unneeded `return` statement

warning: unneeded `return` statement --> service/src/main.rs:232:13 | 232 | return Ok(()) | ^^^^^^^^^^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_return = note: `#[warn(clippy::needless_return)]` on by default help: remove `return` | 232 | Ok(()) |
}
err = client_handle => {
return err.unwrap();

Check warning on line 235 in service/src/main.rs

View workflow job for this annotation

GitHub Actions / clippy

unneeded `return` statement

warning: unneeded `return` statement --> service/src/main.rs:235:13 | 235 | return err.unwrap(); | ^^^^^^^^^^^^^^^^^^^ | = help: for further information visit https://rust-lang.github.io/rust-clippy/master/index.html#needless_return
}
}
}

/// Runs the PID watcher when a PID was supplied.
///
/// Resolves to `None` right away when `pid` is `None`. Otherwise `watch_pid`
/// is started through `tokio::task::spawn_blocking` with a one second
/// interval, and the outcome of joining that task is returned inside `Some`.
async fn watch_pid_if_needed(
    pid: Option<Pid>,
) -> Option<std::result::Result<(), tokio::task::JoinError>> {
    // Nothing to watch: bail out with `None`.
    let pid = pid?;
    let watcher = tokio::task::spawn_blocking(move || watch_pid(pid, Duration::from_secs(1)));
    Some(watcher.await)
}

Expand Down Expand Up @@ -254,7 +272,18 @@ fn watch_pid(pid: Pid, interval: Duration) {
mod tests {
use std::process::Command;

use sysinfo::PidExt;
use sysinfo::{Pid, PidExt};

/// Covers both outcomes of `super::watch_pid_if_needed`: no PID and a PID.
#[tokio::test]
async fn watch_pid_if_needed() {
    // Without a PID there is nothing to watch, so the result is `None`.
    let skipped = super::watch_pid_if_needed(None).await;
    assert!(skipped.is_none());

    // With a PID the watcher task must be spawned and joined successfully.
    // NOTE(review): presumably PID 0 makes `watch_pid` return promptly -- confirm.
    let joined = super::watch_pid_if_needed(Some(Pid::from(0))).await;
    joined.expect("should be some").expect("should be OK");
}

#[tokio::test]
async fn watching_pid_zombie() {
Expand Down
8 changes: 1 addition & 7 deletions service/tests/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,7 @@ impl TestServer {
where
S: PostService,
{
ServiceClient::new(
format!("http://{}", self.addr),
std::time::Duration::from_secs(1),
None,
service,
)
.unwrap()
ServiceClient::new(format!("http://{}", self.addr), None, service).unwrap()
}

pub async fn generate_proof(
Expand Down
12 changes: 6 additions & 6 deletions service/tests/test_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use server::{TestNodeRequest, TestServer};
async fn test_registers() {
let mut test_server = TestServer::new().await;
let client = test_server.create_client(Arc::new(MockPostService::new()));
let client_handle = tokio::spawn(client.run());
let client_handle = tokio::spawn(client.run(None, std::time::Duration::from_secs(1)));

// Check if client registered
test_server.connected.recv().await.unwrap();
Expand All @@ -45,7 +45,7 @@ async fn test_gen_proof_in_progress() {
.returning(|_| Ok(ProofGenState::InProgress));
let service = Arc::new(service);
let client = test_server.create_client(service.clone());
let client_handle = tokio::spawn(client.run());
let client_handle = tokio::spawn(client.run(None, std::time::Duration::from_secs(1)));

let connected = test_server.connected.recv().await.unwrap();
let response = TestServer::generate_proof(&connected, vec![0xCA; 32]).await;
Expand Down Expand Up @@ -74,7 +74,7 @@ async fn test_gen_proof_failed() {

let service = Arc::new(service);
let client = test_server.create_client(service.clone());
let client_handle = tokio::spawn(client.run());
let client_handle = tokio::spawn(client.run(None, std::time::Duration::from_secs(1)));

let connected = test_server.connected.recv().await.unwrap();
let response = TestServer::generate_proof(&connected, vec![0xCA; 32]).await;
Expand Down Expand Up @@ -137,7 +137,7 @@ async fn test_gen_proof_finished() {

let service = Arc::new(service);
let client = test_server.create_client(service.clone());
let client_handle = tokio::spawn(client.run());
let client_handle = tokio::spawn(client.run(None, std::time::Duration::from_secs(1)));

let connected = test_server.connected.recv().await.unwrap();

Expand Down Expand Up @@ -191,7 +191,7 @@ async fn test_broken_request_no_kind() {

let service = Arc::new(service);
let client = test_server.create_client(service.clone());
let client_handle = tokio::spawn(client.run());
let client_handle = tokio::spawn(client.run(None, std::time::Duration::from_secs(1)));

let connected = test_server.connected.recv().await.unwrap();

Expand Down Expand Up @@ -262,7 +262,7 @@ async fn test_get_metadata(#[case] vrf_difficulty: Option<[u8; 32]>) {
.unwrap();

let client = test_server.create_client(Arc::new(service));
let client_handle = tokio::spawn(client.run());
let client_handle = tokio::spawn(client.run(None, std::time::Duration::from_secs(1)));
let connected = test_server.connected.recv().await.unwrap();

let response = TestServer::request_metadata(&connected).await;
Expand Down

0 comments on commit c420c42

Please sign in to comment.