Skip to content

feat(irpc-iroh): make it easy to do a manual connection loop with authentication #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
277 changes: 277 additions & 0 deletions irpc-iroh/examples/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,277 @@
//! This example demonstrates a few things:
//! * Using irpc with a cloneable server struct instead of with an actor loop
//! * Manually implementing the connection loop
//! * Authenticating peers

use anyhow::Result;
use iroh::{protocol::Router, Endpoint};

use self::storage::{StorageClient, StorageServer};

#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
println!("Remote use");
remote().await?;
Ok(())
}

async fn remote() -> Result<()> {
let (server_router, server_addr) = {
let endpoint = Endpoint::builder().discovery_n0().bind().await?;
let server = StorageServer::new("secret".to_string());
let router = Router::builder(endpoint.clone())
.accept(StorageServer::ALPN, server.clone())
.spawn();
let addr = endpoint.node_addr().await?;
(router, addr)
};

// correct authentication
let client_endpoint = Endpoint::builder().bind().await?;
let api = StorageClient::connect(client_endpoint, server_addr.clone());
api.auth("secret").await?;
api.set("hello".to_string(), "world".to_string()).await?;
api.set("goodbye".to_string(), "world".to_string()).await?;
let value = api.get("hello".to_string()).await?;
println!("value = {:?}", value);
let mut list = api.list().await?;
while let Some(value) = list.recv().await? {
println!("list value = {:?}", value);
}

// invalid authentication
let client_endpoint = Endpoint::builder().bind().await?;
let api = StorageClient::connect(client_endpoint, server_addr.clone());
assert!(api.auth("bad").await.is_err());
assert!(api.get("hello".to_string()).await.is_err());

// no authentication
let client_endpoint = Endpoint::builder().bind().await?;
let api = StorageClient::connect(client_endpoint, server_addr);
assert!(api.get("hello".to_string()).await.is_err());

drop(server_router);
Ok(())
}

mod storage {
//! Implementation of our storage service.
//!
//! The only `pub` item is [`StorageApi`], everything else is private.

use std::{
collections::BTreeMap,
sync::{Arc, Mutex},
};

use anyhow::Result;
use iroh::{
endpoint::{Connection, RecvStream, SendStream},
protocol::ProtocolHandler,
Endpoint,
};
use irpc::{
channel::{oneshot, spsc},
Client, Service, WithChannels,
};
// Import the macro
use irpc_derive::rpc_requests;
use irpc_iroh::{read_request, IrohRemoteConnection};
use serde::{Deserialize, Serialize};
use tracing::info;

const ALPN: &[u8] = b"storage-api/0";

/// A simple storage service, just to try it out
#[derive(Debug, Clone, Copy)]
struct StorageService;

impl Service for StorageService {}

#[derive(Debug, Serialize, Deserialize)]
struct Auth {
token: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct Get {
key: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct List;

#[derive(Debug, Serialize, Deserialize)]
struct Set {
key: String,
value: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct SetMany;

// Use the macro to generate both the StorageProtocol and StorageMessage enums
// plus implement Channels for each type
#[rpc_requests(StorageService, message = StorageMessage)]
#[derive(Serialize, Deserialize)]
enum StorageProtocol {
#[rpc(tx=oneshot::Sender<Result<(), String>>)]
Auth(Auth),
#[rpc(tx=oneshot::Sender<Option<String>>)]
Get(Get),
#[rpc(tx=oneshot::Sender<()>)]
Set(Set),
#[rpc(tx=oneshot::Sender<u64>, rx=spsc::Receiver<(String, String)>)]
SetMany(SetMany),
#[rpc(tx=spsc::Sender<String>)]
List(List),
}

#[derive(Debug, Clone)]
pub struct StorageServer {
state: Arc<Mutex<BTreeMap<String, String>>>,
auth_token: String,
}

impl ProtocolHandler for StorageServer {
fn accept(&self, conn: Connection) -> n0_future::future::Boxed<Result<()>> {
let this = self.clone();
Box::pin(async move {
let mut authed = false;
while let Some((msg, rx, tx)) = read_request(&conn).await? {
let msg_with_channels = upcast_message(msg, rx, tx);
match msg_with_channels {
StorageMessage::Auth(msg) => {
let WithChannels { inner, tx, .. } = msg;
if authed {
conn.close(1u32.into(), b"invalid message");
break;
} else if inner.token != this.auth_token {
conn.close(1u32.into(), b"permission denied");
break;
} else {
authed = true;
tx.send(Ok(())).await.ok();
}
}
msg_with_channels @ _ => {
if !authed {
conn.close(1u32.into(), b"permission denied");
break;
} else {
this.handle_authenticated(msg_with_channels).await;
}
}
}
}
conn.closed().await;
Ok(())
})
}
}

fn upcast_message(msg: StorageProtocol, rx: RecvStream, tx: SendStream) -> StorageMessage {
match msg {
StorageProtocol::Auth(msg) => WithChannels::from((msg, tx, rx)).into(),
StorageProtocol::Get(msg) => WithChannels::from((msg, tx, rx)).into(),
StorageProtocol::Set(msg) => WithChannels::from((msg, tx, rx)).into(),
StorageProtocol::SetMany(msg) => WithChannels::from((msg, tx, rx)).into(),
StorageProtocol::List(msg) => WithChannels::from((msg, tx, rx)).into(),
}
}

impl StorageServer {
pub const ALPN: &[u8] = ALPN;

pub fn new(auth_token: String) -> Self {
Self {
state: Default::default(),
auth_token,
}
}

async fn handle_authenticated(&self, msg: StorageMessage) {
match msg {
StorageMessage::Auth(_) => unreachable!("handled in ProtocolHandler::accept"),
StorageMessage::Get(get) => {
info!("get {:?}", get);
let WithChannels { tx, inner, .. } = get;
let res = self.state.lock().unwrap().get(&inner.key).cloned();
tx.send(res).await.ok();
}
StorageMessage::Set(set) => {
info!("set {:?}", set);
let WithChannels { tx, inner, .. } = set;
self.state.lock().unwrap().insert(inner.key, inner.value);
tx.send(()).await.ok();
}
StorageMessage::SetMany(list) => {
let WithChannels { tx, mut rx, .. } = list;
let mut i = 0;
while let Ok(Some((key, value))) = rx.recv().await {
let mut state = self.state.lock().unwrap();
state.insert(key, value);
i += 1;
}
tx.send(i).await.ok();
}
StorageMessage::List(list) => {
info!("list {:?}", list);
let WithChannels { mut tx, .. } = list;
let values = {
let state = self.state.lock().unwrap();
// TODO: use async lock to not clone here.
let values: Vec<_> = state
.iter()
.map(|(key, value)| format!("{key}={value}"))
.collect();
values
};
for value in values {
if tx.send(value).await.is_err() {
break;
}
}
}
}
}
}

pub struct StorageClient {
inner: Client<StorageMessage, StorageProtocol, StorageService>,
}

impl StorageClient {
pub const ALPN: &[u8] = ALPN;

pub fn connect(endpoint: Endpoint, addr: impl Into<iroh::NodeAddr>) -> StorageClient {
let conn = IrohRemoteConnection::new(endpoint, addr.into(), Self::ALPN.to_vec());
StorageClient {
inner: Client::boxed(conn),
}
}

pub async fn auth(&self, token: &str) -> Result<(), anyhow::Error> {
self.inner
.rpc(Auth {
token: token.to_string(),
})
.await?
.map_err(|err| anyhow::anyhow!(err))
}

pub async fn get(&self, key: String) -> Result<Option<String>, irpc::Error> {
self.inner.rpc(Get { key }).await
}

pub async fn list(&self) -> Result<spsc::Receiver<String>, irpc::Error> {
self.inner.server_streaming(List, 10).await
}

pub async fn set(&self, key: String, value: String) -> Result<(), irpc::Error> {
let msg = Set { key, value };
self.inner.rpc(msg).await
}
}
}
62 changes: 38 additions & 24 deletions irpc-iroh/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,35 +128,49 @@ pub async fn handle_connection<R: DeserializeOwned + 'static>(
handler: Handler<R>,
) -> io::Result<()> {
loop {
let (send, mut recv) = match connection.accept_bi().await {
Ok((s, r)) => (s, r),
Err(ConnectionError::ApplicationClosed(cause))
if cause.error_code.into_inner() == 0 =>
{
trace!("remote side closed connection {cause:?}");
return Ok(());
}
Err(cause) => {
warn!("failed to accept bi stream {cause:?}");
return Err(cause.into());
}
let Some((msg, rx, tx)) = read_request(&connection).await? else {
return Ok(());
};
let size = recv
.read_varint_u64()
.await?
.ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?;
let mut buf = vec![0; size as usize];
recv.read_exact(&mut buf)
.await
.map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
let msg: R = postcard::from_bytes(&buf)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let rx = recv;
let tx = send;
handler(msg, rx, tx).await?;
}
}

/// Reads a single request from the connection.
///
/// This accepts a bi-directional stream from the connection and reads and parses the request.
///
/// Returns the parsed request and the stream pair if reading and parsing the request succeeded.
/// Returns None if the remote closed the connection with error code `0`.
/// Returns an error for all other failure cases.
pub async fn read_request<R: DeserializeOwned + 'static>(
connection: &Connection,
) -> std::io::Result<Option<(R, RecvStream, SendStream)>> {
let (send, mut recv) = match connection.accept_bi().await {
Ok((s, r)) => (s, r),
Err(ConnectionError::ApplicationClosed(cause)) if cause.error_code.into_inner() == 0 => {
trace!("remote side closed connection {cause:?}");
return Ok(None);
}
Err(cause) => {
warn!("failed to accept bi stream {cause:?}");
return Err(cause.into());
}
};
let size = recv
.read_varint_u64()
.await?
.ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?;
let mut buf = vec![0; size as usize];
recv.read_exact(&mut buf)
.await
.map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
let msg: R =
postcard::from_bytes(&buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let rx = recv;
let tx = send;
Ok(Some((msg, rx, tx)))
}

/// Utility function to listen for incoming connections and handle them with the provided handler
pub async fn listen<R: DeserializeOwned + 'static>(endpoint: iroh::Endpoint, handler: Handler<R>) {
let mut request_id = 0u64;
Expand Down
Loading