From 00d5e8e87debd482c036b40aca00ce7290174132 Mon Sep 17 00:00:00 2001 From: ruben Date: Sat, 25 Jan 2025 23:09:57 +0100 Subject: [PATCH] progress on new types --- h3/src/client/builder.rs | 6 +- h3/src/client/connection.rs | 3 + h3/src/connection.rs | 17 ++++- h3/src/error2/codes.rs | 6 ++ h3/src/error2/error.rs | 119 ++++++++++++++++++++++++++------ h3/src/error2/internal_error.rs | 12 ++-- h3/src/error2/mod.rs | 1 + h3/src/error2/traits.rs | 61 +++++++++++++--- h3/src/quic.rs | 71 +++++++++++++++++++ h3/src/server/builder.rs | 6 +- h3/src/shared_state.rs | 18 +++-- 11 files changed, 266 insertions(+), 54 deletions(-) diff --git a/h3/src/client/builder.rs b/h3/src/client/builder.rs index 1dfcd06e..024ef26d 100644 --- a/h3/src/client/builder.rs +++ b/h3/src/client/builder.rs @@ -14,6 +14,7 @@ use crate::{ connection::{ConnectionInner, SharedStateRef}, error::Error, quic::{self}, + shared_state::SharedState2, }; use super::connection::{Connection, SendRequest}; @@ -107,14 +108,15 @@ impl Builder { B: Buf, { let open = quic.opener(); - let conn_state = SharedStateRef::default(); + let conn_state = Arc::new(SharedState2::default()); let conn_waker = Some(future::poll_fn(|cx| Poll::Ready(cx.waker().clone())).await); let inner = ConnectionInner::new(quic, conn_state.clone(), self.config).await?; let send_request = SendRequest { open, - conn_state, + conn_state: todo!(), + conn_state2: conn_state, conn_waker, max_field_section_size: self.config.settings.max_field_section_size, sender_count: Arc::new(AtomicUsize::new(1)), diff --git a/h3/src/client/connection.rs b/h3/src/client/connection.rs index 807763e0..ab49eb9b 100644 --- a/h3/src/client/connection.rs +++ b/h3/src/client/connection.rs @@ -21,6 +21,7 @@ use crate::{ proto::{frame::Frame, headers::Header, push::PushId}, qpack, quic::{self, StreamId}, + shared_state::SharedState2, stream::{self, BufRecvStream}, }; @@ -110,6 +111,7 @@ where { pub(super) open: T, pub(super) conn_state: SharedStateRef, + pub(super) conn_state2: Arc, pub(super) max_field_section_size: u64, // maximum size for a header we receive // counts instances of SendRequest to close the connection when the last is dropped. pub(super) sender_count: Arc, @@ -219,6 +221,7 @@ where .fetch_add(1, std::sync::atomic::Ordering::Release); Self { + conn_state2: self.conn_state2.clone(), open: self.open.clone(), conn_state: self.conn_state.clone(), max_field_section_size: self.max_field_section_size, diff --git a/h3/src/connection.rs b/h3/src/connection.rs index e44e15dc..39d905a0 100644 --- a/h3/src/connection.rs +++ b/h3/src/connection.rs @@ -26,6 +26,7 @@ use crate::{ }, qpack, quic::{self, SendStream}, + shared_state::{ConnectionState2, SharedState2}, stream::{self, AcceptRecvStream, AcceptedRecvStream, BufRecvStream, UniStreamHeader}, webtransport::SessionId, }; @@ -106,6 +107,7 @@ where C: quic::Connection, B: Buf, { + pub(super) shared2: Arc, pub(super) shared: SharedStateRef, /// TODO: breaking encapsulation just to see if we can get this to work, will fix before merging pub conn: C, @@ -149,6 +151,16 @@ where pub(crate) error_sender: UnboundedSender<(Code, &'static str)>, } +impl ConnectionState2 for ConnectionInner +where + C: quic::Connection, + B: Buf, +{ + fn shared_state(&self) -> &SharedState2 { + &self.shared2 + } +} + enum GreaseStatus where S: SendStream, @@ -240,7 +252,7 @@ where /// Initiates the connection and opens a control stream #[cfg_attr(feature = "tracing", instrument(skip_all, level = "trace"))] - pub async fn new(mut conn: C, shared: SharedStateRef, config: Config) -> Result { + pub async fn new(mut conn: C, shared2: Arc, config: Config) -> Result { //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2 //# Endpoints SHOULD create the HTTP control stream as well as the //# unidirectional streams required by mandatory extensions (such as the @@ -261,7 +273,8 @@ where //# sender MUST NOT close the control stream, and the receiver MUST NOT //# request that the sender close the control stream. let mut conn_inner = Self { - shared, + shared2, + shared: todo!(), conn, control_send: control_send.map_err(Error::transport_err)?, control_recv: None, diff --git a/h3/src/error2/codes.rs b/h3/src/error2/codes.rs index 974f12cf..10657730 100644 --- a/h3/src/error2/codes.rs +++ b/h3/src/error2/codes.rs @@ -138,3 +138,9 @@ impl From for u64 { code.code } } + +impl From for NewCode { + fn from(code: u64) -> NewCode { + NewCode { code } + } +} diff --git a/h3/src/error2/error.rs b/h3/src/error2/error.rs index 43d3d953..c1bf7f59 100644 --- a/h3/src/error2/error.rs +++ b/h3/src/error2/error.rs @@ -2,9 +2,9 @@ use std::sync::Arc; -use crate::quic; +use crate::quic::{self, ConnectionErrorIncoming, StreamErrorIncoming}; -use super::codes::NewCode; +use super::{codes::NewCode, internal_error::{InternalConnectionError, InternalRequestStreamError}}; /// This enum represents wether the error occurred on the local or remote side of the connection #[derive(Debug, Clone)] @@ -19,7 +19,7 @@ pub enum ConnectionError { /// Error returned by the quic layer /// I might be an quic error or the remote h3 connection closed the connection with an error #[non_exhaustive] - Remote(Arc), + Remote(ConnectionErrorIncoming), /// Timeout occurred #[non_exhaustive] Timeout, @@ -42,6 +42,24 @@ pub enum LocalError { Closing, } +impl From<&InternalConnectionError> for LocalError { + fn from(err: &InternalConnectionError) -> Self { + LocalError::Application { + code: err.code, + reason: err.message, + } + } +} + +impl From<&InternalRequestStreamError> for LocalError { + fn from(err: &InternalRequestStreamError) -> Self { + LocalError::Application { + code: err.code, + reason: err.message, + } + } +} + /// This enum represents a stream error #[derive(Debug, Clone)] #[non_exhaustive] @@ -54,26 +72,44 @@ pub enum StreamError { /// The error reason reason: &'static str, }, + /// Stream was Reset by the peer + RemoteReset { + /// Reset code received from the peer + code: NewCode, + }, /// The error occurred on the connection #[non_exhaustive] ConnectionError(ConnectionError), } +impl From for StreamError { + fn from(err: StreamErrorIncoming) -> Self { + match err { + StreamErrorIncoming::ConnectionErrorIncoming { connection_error } => { + StreamError::ConnectionError(connection_error.into()) + } + StreamErrorIncoming::StreamReset { error_code } => StreamError::RemoteReset { + code: error_code.into(), + }, + } + } +} + +impl From for ConnectionError { + fn from(value: ConnectionErrorIncoming) -> Self { + return match value { + ConnectionErrorIncoming::Timeout => ConnectionError::Timeout, + error => ConnectionError::Remote(error), + }; + } +} + /// This enum represents a stream error #[derive(Debug, Clone)] #[non_exhaustive] pub enum ServerStreamError { - /// The error occurred on the stream - #[non_exhaustive] - StreamError { - /// The error code - code: NewCode, - /// The error reason - reason: &'static str, - }, - /// The error occurred on the connection - #[non_exhaustive] - ConnectionError(ConnectionError), + /// A Stream error occurred + General(StreamError), #[non_exhaustive] /// The received header block is too big /// The Request has been answered with a 431 Request Header Fields Too Large @@ -85,13 +121,48 @@ pub enum ServerStreamError { }, } +#[derive(Debug, Clone)] +#[non_exhaustive] +/// This enum represents a stream error +/// +/// This type will be returned by the client in case of an error on the stream methods +pub enum ClientStreamError { + /// A Stream error occurred + General(StreamError), + #[non_exhaustive] + /// The request cannot be sent because the header block larger then permitted by the server + HeaderTooBig { + /// The actual size of the header block + actual_size: u64, + /// The maximum size of the header block + max_size: u64, + }, +} + +impl std::fmt::Display for ClientStreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ClientStreamError::General(stream_error) => { + write!(f, "Stream error: {}", stream_error) + } + ClientStreamError::HeaderTooBig { + actual_size, + max_size, + } => write!( + f, + "Request cannot be sent because the header is lager than permitted by the server: permitted size: {}, max size: {}", + actual_size, max_size + ), + } + } +} + impl std::fmt::Display for ServerStreamError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ServerStreamError::StreamError { code, reason } => { - write!(f, "Stream error: {:?} - {}", code, reason) + ServerStreamError::General(stream_error) => { + write!(f, "Stream error: {}", stream_error) } - ServerStreamError::ConnectionError(err) => write!(f, "Connection error: {}", err), ServerStreamError::HeaderTooBig { actual_size, max_size, @@ -108,12 +179,13 @@ impl std::error::Error for ServerStreamError {} impl From for ServerStreamError { fn from(err: StreamError) -> Self { - match err { - StreamError::StreamError { code, reason } => { - ServerStreamError::StreamError { code, reason } - } - StreamError::ConnectionError(err) => ServerStreamError::ConnectionError(err), - } + ServerStreamError::General(err) + } +} + +impl From for ClientStreamError { + fn from(err: StreamError) -> Self { + ClientStreamError::General(err) } } @@ -136,6 +208,7 @@ impl std::fmt::Display for StreamError { write!(f, "Stream error: {:?} - {}", code, reason) } StreamError::ConnectionError(err) => write!(f, "Connection error: {}", err), + StreamError::RemoteReset { code } => write!(f, "Remote reset: {:?}", code), } } } diff --git a/h3/src/error2/internal_error.rs b/h3/src/error2/internal_error.rs index 52d0edd4..2d6a6715 100644 --- a/h3/src/error2/internal_error.rs +++ b/h3/src/error2/internal_error.rs @@ -1,5 +1,7 @@ use crate::error::Code; +use super::codes::NewCode; + /// This enum determines if the error affects the connection or the stream /// /// The end goal is to make all stream errors from the spec potential connection errors if users of h3 decide to treat them as such @@ -20,11 +22,11 @@ pub enum ErrorScope { #[derive(Debug, Clone, Hash)] pub struct InternalRequestStreamError { /// The error scope - scope: ErrorScope, + pub(crate) scope: ErrorScope, /// The error code - code: Code, + pub(crate) code: NewCode, /// The error message - message: &'static str, + pub(crate) message: &'static str, } /// This error type represents an internal error type, which is used @@ -36,7 +38,7 @@ pub struct InternalRequestStreamError { #[derive(Debug, Clone, Hash)] pub struct InternalConnectionError { /// The error code - code: Code, + pub(super) code: NewCode, /// The error message - message: &'static str, + pub(super) message: &'static str, } diff --git a/h3/src/error2/mod.rs b/h3/src/error2/mod.rs index 317cebac..348aeb7b 100644 --- a/h3/src/error2/mod.rs +++ b/h3/src/error2/mod.rs @@ -18,3 +18,4 @@ pub(crate) mod internal_error; mod error; pub use error::{ConnectionError, LocalError, ServerStreamError, StreamError}; +pub use codes::NewCode; diff --git a/h3/src/error2/traits.rs b/h3/src/error2/traits.rs index 2aae6f28..169dc3a1 100644 --- a/h3/src/error2/traits.rs +++ b/h3/src/error2/traits.rs @@ -1,31 +1,72 @@ //! Defines error traits +use tokio::sync::oneshot::error; + use crate::{config::Config, shared_state::ConnectionState2}; -use super::{codes::NewCode, internal_error::InternalRequestStreamError, ConnectionError}; +use super::{ + codes::NewCode, + internal_error::{ErrorScope, InternalConnectionError, InternalRequestStreamError}, + ConnectionError, StreamError, +}; /// This trait is implemented for all types which can close the connection pub(crate) trait CloseConnection: ConnectionState2 { /// Close the connection fn handle_connection_error( &mut self, - internal_error: InternalRequestStreamError, + internal_error: InternalConnectionError, ) -> ConnectionError { - //self.maybe_conn_error(error) - todo!() + return if let Some(error) = self.get_conn_error() { + error + } else { + let error = ConnectionError::Local { + error: (&internal_error).into(), + }; + self.set_conn_error(error.clone()); + self.close_connection(&internal_error.code, internal_error.message); + error + }; } - fn close_connection>(code: &NewCode, reason: T) -> (); + fn close_connection>(&mut self, code: &NewCode, reason: T) -> (); } -pub(crate) trait CloseStream: CloseConnection { +pub(crate) trait CloseStream: ConnectionState2 { fn handle_stream_error( &mut self, internal_error: InternalRequestStreamError, - config: &Config, - ) -> ConnectionError { - todo!() + ) -> StreamError { + return if let Some(error) = self.get_conn_error() { + // If the connection is already in an error state, return the error + StreamError::ConnectionError(error) + } else { + match internal_error.scope { + ErrorScope::Connection => { + // If the error affects the connection, close the connection + let conn_error = ConnectionError::Local { + error: (&internal_error).into(), + }; + + self.set_conn_error(conn_error.clone()); + let error = StreamError::ConnectionError(conn_error); + self.close_connection(&internal_error.code, internal_error.message); + error + } + ErrorScope::Stream => { + // If the error affects the stream, close the stream + self.close_stream(&internal_error.code, internal_error.message); + + let error = StreamError::StreamError { + code: internal_error.code, + reason: internal_error.message, + }; + error + } + } + }; } - fn close_stream() -> (); + fn close_connection>(&mut self, code: &NewCode, reason: T) -> (); + fn close_stream>(&mut self, code: &NewCode, reason: T) -> (); } diff --git a/h3/src/quic.rs b/h3/src/quic.rs index 39e7217a..530ae73a 100644 --- a/h3/src/quic.rs +++ b/h3/src/quic.rs @@ -3,6 +3,7 @@ //! This module includes traits and types meant to allow being generic over any //! QUIC implementation. +use std::fmt::Display; use std::task::{self, Poll}; use bytes::Buf; @@ -10,6 +11,76 @@ use bytes::Buf; pub use crate::proto::stream::{InvalidStreamId, StreamId}; pub use crate::stream::WriteBuf; + +/// Error type to communicate that the quic connection was closed +/// +/// This is used by to implement the quic abstraction traits +#[derive(Debug, Clone)] +pub enum ConnectionErrorIncoming { + /// Error from the http3 layer + ApplicationClose { + /// http3 error code + error_code: u64, + }, + /// Quic connection timeout + Timeout, + /// Quic error + ConnectionClosed { + /// quic error code + error_code: u64, + }, +} + +/// Error type to communicate that the stream was closed +/// +/// This is used by to implement the quic abstraction traits +#[derive(Debug, Clone)] +pub enum StreamErrorIncoming { + /// Stream is closed because the whole connection is closed + ConnectionErrorIncoming { + /// Connection error + connection_error: ConnectionErrorIncoming, + }, + /// Stream was closed by the peer + StreamReset { + /// Error code sent by the peer + error_code: u64, + }, +} + +impl std::error::Error for StreamErrorIncoming {} + +impl Display for StreamErrorIncoming { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // display enum with fields + match self { + StreamErrorIncoming::ConnectionErrorIncoming { connection_error } => { + write!(f, "ConnectionError: {}", connection_error) + } + StreamErrorIncoming::StreamReset { error_code } => { + write!(f, "StreamClosed: {}", error_code) + } + } + } +} + +impl Display for ConnectionErrorIncoming { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // display enum with fields + match self { + ConnectionErrorIncoming::ApplicationClose { error_code } => { + write!(f, "ApplicationClose: {}", error_code) + } + ConnectionErrorIncoming::Timeout => write!(f, "Timeout"), + ConnectionErrorIncoming::ConnectionClosed { error_code } => { + write!(f, "ConnectionClosed: {}", error_code) + } + } + } +} + +impl std::error::Error for ConnectionErrorIncoming {} + // Unresolved questions: // // - Should the `poll_` methods be `Pin<&mut Self>`? diff --git a/h3/src/server/builder.rs b/h3/src/server/builder.rs index 1c7cd264..ae4970c8 100644 --- a/h3/src/server/builder.rs +++ b/h3/src/server/builder.rs @@ -21,7 +21,7 @@ //! } //! ``` -use std::{collections::HashSet, result::Result}; +use std::{collections::HashSet, result::Result, sync::Arc}; use bytes::Buf; @@ -32,6 +32,7 @@ use crate::{ connection::{ConnectionInner, SharedStateRef}, error::Error, quic::{self}, + shared_state::SharedState2, }; use super::connection::Connection; @@ -129,7 +130,8 @@ impl Builder { { let (sender, receiver) = mpsc::unbounded_channel(); Ok(Connection { - inner: ConnectionInner::new(conn, SharedStateRef::default(), self.config).await?, + inner: ConnectionInner::new(conn, Arc::new(SharedState2::default()), self.config) + .await?, max_field_section_size: self.config.settings.max_field_section_size, request_end_send: sender, request_end_recv: receiver, diff --git a/h3/src/shared_state.rs b/h3/src/shared_state.rs index 02885353..faec4ac0 100644 --- a/h3/src/shared_state.rs +++ b/h3/src/shared_state.rs @@ -4,7 +4,7 @@ use std::sync::{atomic::AtomicBool, Arc, OnceLock}; use crate::{config::Settings, error2::ConnectionError}; -#[derive(Debug, Clone)] +#[derive(Debug)] /// This struct represents the shared state of the h3 connection and the stream structs pub(crate) struct SharedState2 { /// The settings, sent by the peer @@ -12,7 +12,7 @@ pub(crate) struct SharedState2 { /// The connection error connection_error: OnceLock, /// The connection is closing - closing: Arc, + closing: AtomicBool, } impl Default for SharedState2 { @@ -20,7 +20,7 @@ impl Default for SharedState2 { Self { settings: OnceLock::new(), connection_error: OnceLock::new(), - closing: Arc::new(AtomicBool::new(false)), + closing: AtomicBool::new(false), } } } @@ -29,16 +29,14 @@ impl Default for SharedState2 { pub trait ConnectionState2 { /// Get the shared state fn shared_state(&self) -> &SharedState2; - /// Get the Error - fn maybe_conn_error(&self, error: ConnectionError) -> ConnectionError { - self.shared_state() - .connection_error - .get_or_init(|| error) - .clone() - } + /// Get the connection error fn get_conn_error(&self) -> Option { self.shared_state().connection_error.get().cloned() } + /// Set the connection error + fn set_conn_error(&self, error: ConnectionError) { + self.shared_state().connection_error.set(error); + } /// Get the settings fn settings(&self) -> Option<&Settings> { self.shared_state().settings.get()