diff --git a/Cargo.toml b/Cargo.toml index e11f5c06c..82509112e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,6 +75,7 @@ defmt = ["dep:defmt", "heapless/defmt-03"] "socket-dhcpv4" = ["socket", "medium-ethernet", "proto-dhcpv4"] "socket-dns" = ["socket", "proto-dns"] "socket-mdns" = ["socket-dns"] +"socket-eth" = ["socket", "medium-ethernet"] # Enable Cubic TCP congestion control algorithm, and it is used as a default congestion controller. # @@ -101,7 +102,7 @@ default = [ "phy-raw_socket", "phy-tuntap_interface", "proto-ipv4", "proto-dhcpv4", "proto-ipv6", "proto-dns", "proto-ipv4-fragmentation", "proto-sixlowpan-fragmentation", - "socket-raw", "socket-icmp", "socket-udp", "socket-tcp", "socket-dhcpv4", "socket-dns", "socket-mdns", + "socket-raw", "socket-icmp", "socket-udp", "socket-tcp", "socket-dhcpv4", "socket-dns", "socket-mdns", "socket-eth", "packetmeta-id", "async", "multicast" ] diff --git a/src/iface/interface/ethernet.rs b/src/iface/interface/ethernet.rs index da2d01229..2a336c0f0 100644 --- a/src/iface/interface/ethernet.rs +++ b/src/iface/interface/ethernet.rs @@ -18,6 +18,13 @@ impl InterfaceInner { return None; } + #[cfg(feature = "socket-eth")] + let _ = self.eth_socket_filter( + sockets, + &EthernetRepr::parse(ð_frame).unwrap(), + eth_frame.payload(), + ); + match eth_frame.ethertype() { #[cfg(feature = "proto-ipv4")] EthernetProtocol::Arp => self.process_arp(self.now, ð_frame), @@ -45,12 +52,7 @@ impl InterfaceInner { } } - pub(super) fn dispatch_ethernet( - &mut self, - tx_token: Tx, - buffer_len: usize, - f: F, - ) -> Result<(), DispatchError> + pub(super) fn dispatch_ethernet(&mut self, tx_token: Tx, buffer_len: usize, f: F) where Tx: TxToken, F: FnOnce(EthernetFrame<&mut [u8]>), @@ -64,8 +66,6 @@ impl InterfaceInner { frame.set_src_addr(src_addr); f(frame); - - Ok(()) }) } } diff --git a/src/iface/interface/mod.rs b/src/iface/interface/mod.rs index 8b6fce4a7..1420cbe3f 100644 --- a/src/iface/interface/mod.rs +++ b/src/iface/interface/mod.rs @@ -743,6 +743,25 @@ impl Interface { Packet::new(ip, IpPayload::Udp(udp, dns)), ) }), + #[cfg(feature = "socket-eth")] + Socket::Eth(socket) => { + socket.dispatch(&mut self.inner, |inner, (eth_repr, payload)| { + let token = device.transmit(inner.now).ok_or_else(|| { + net_debug!("failed to transmit raw ETH: device exhausted"); + EgressError::Exhausted + })?; + inner.dispatch_ethernet(token, payload.len(), |mut frame| { + frame.set_dst_addr(eth_repr.dst_addr); + frame.set_src_addr(eth_repr.src_addr); + frame.set_ethertype(eth_repr.ethertype); + frame.payload_mut().copy_from_slice(payload); + }); + + result = PollResult::SocketStateChanged; + + Ok(()) + }) + } }; match result { @@ -897,6 +916,28 @@ impl InterfaceInner { handled_by_raw_socket } + #[cfg(feature = "socket-eth")] + fn eth_socket_filter( + &mut self, + sockets: &mut SocketSet, + eth_repr: &EthernetRepr, + eth_payload: &[u8], + ) -> bool { + let mut handled_by_eth_socket = false; + + // Pass every IP packet to all raw sockets we have registered. + for eth_socket in sockets + .items_mut() + .filter_map(|i| eth::Socket::downcast_mut(&mut i.socket)) + { + if eth_socket.accepts(eth_repr) { + eth_socket.process(self, eth_repr, eth_payload); + handled_by_eth_socket = true; + } + } + handled_by_eth_socket + } + /// Checks if an address is broadcast, taking into account ipv4 subnet-local /// broadcast addresses. pub(crate) fn is_broadcast(&self, address: &IpAddress) -> bool { @@ -934,7 +975,8 @@ impl InterfaceInner { let mut packet = ArpPacket::new_unchecked(frame.payload_mut()); arp_repr.emit(&mut packet); - }) + }); + Ok(()) } EthernetPacket::Ip(packet) => { self.dispatch_ip(tx_token, PacketMeta::default(), packet, frag) @@ -1067,17 +1109,12 @@ impl InterfaceInner { target_protocol_addr: dst_addr, }; - if let Err(e) = - self.dispatch_ethernet(tx_token, arp_repr.buffer_len(), |mut frame| { - frame.set_dst_addr(EthernetAddress::BROADCAST); - frame.set_ethertype(EthernetProtocol::Arp); + self.dispatch_ethernet(tx_token, arp_repr.buffer_len(), |mut frame| { + frame.set_dst_addr(EthernetAddress::BROADCAST); + frame.set_ethertype(EthernetProtocol::Arp); - arp_repr.emit(&mut ArpPacket::new_unchecked(frame.payload_mut())) - }) - { - net_debug!("Failed to dispatch ARP request: {:?}", e); - return Err(DispatchError::NeighborPending); - } + arp_repr.emit(&mut ArpPacket::new_unchecked(frame.payload_mut())) + }); } #[cfg(feature = "proto-ipv6")] diff --git a/src/socket/eth.rs b/src/socket/eth.rs new file mode 100644 index 000000000..a50b9c5b8 --- /dev/null +++ b/src/socket/eth.rs @@ -0,0 +1,471 @@ +use core::cmp::min; +#[cfg(feature = "async")] +use core::task::Waker; + +use crate::iface::Context; +use crate::socket::PollAt; +#[cfg(feature = "async")] +use crate::socket::WakerRegistration; + +use crate::storage::Empty; + +use crate::wire::EthernetFrame; +use crate::wire::EthernetRepr; + +/// Error returned by [`Socket::send`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum SendError { + BufferFull, +} + +impl core::fmt::Display for SendError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + SendError::BufferFull => write!(f, "buffer full"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for SendError {} + +/// Error returned by [`Socket::recv`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum RecvError { + Exhausted, + Truncated, +} + +impl core::fmt::Display for RecvError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + RecvError::Exhausted => write!(f, "exhausted"), + RecvError::Truncated => write!(f, "truncated"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for RecvError {} + +/// A Eth packet metadata. +pub type PacketMetadata = crate::storage::PacketMetadata<()>; + +/// A Eth packet ring buffer. +pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, ()>; + +pub type Ethertype = u16; + +/// A raw Ethernet socket. +/// +/// A eth socket may be bound to a specific ethertype, and owns +/// transmit and receive packet buffers. +#[derive(Debug)] +pub struct Socket<'a> { + ethertype: Option, + rx_buffer: PacketBuffer<'a>, + tx_buffer: PacketBuffer<'a>, + #[cfg(feature = "async")] + rx_waker: WakerRegistration, + #[cfg(feature = "async")] + tx_waker: WakerRegistration, +} + +impl<'a> Socket<'a> { + /// Create a raw ETH socket bound to the given ethertype, with the given buffers. + pub fn new( + ethertype: Option, + rx_buffer: PacketBuffer<'a>, + tx_buffer: PacketBuffer<'a>, + ) -> Socket<'a> { + Socket { + ethertype, + rx_buffer, + tx_buffer, + #[cfg(feature = "async")] + rx_waker: WakerRegistration::new(), + #[cfg(feature = "async")] + tx_waker: WakerRegistration::new(), + } + } + + /// Register a waker for receive operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `recv` method calls, such as receiving data, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `recv` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_recv_waker(&mut self, waker: &Waker) { + self.rx_waker.register(waker) + } + + /// Register a waker for send operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `send` method calls, such as space becoming available in the transmit + /// buffer, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `send` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_send_waker(&mut self, waker: &Waker) { + self.tx_waker.register(waker) + } + + /// Return the ethertype the socket is bound to. + #[inline] + pub fn ethertype(&self) -> Option { + self.ethertype + } + + /// Check whether the transmit buffer is full. + #[inline] + pub fn can_send(&self) -> bool { + !self.tx_buffer.is_full() + } + + /// Check whether the receive buffer is not empty. + #[inline] + pub fn can_recv(&self) -> bool { + !self.rx_buffer.is_empty() + } + + /// Return the maximum number packets the socket can receive. + #[inline] + pub fn packet_recv_capacity(&self) -> usize { + self.rx_buffer.packet_capacity() + } + + /// Return the maximum number packets the socket can transmit. + #[inline] + pub fn packet_send_capacity(&self) -> usize { + self.tx_buffer.packet_capacity() + } + + /// Return the maximum number of bytes inside the recv buffer. + #[inline] + pub fn payload_recv_capacity(&self) -> usize { + self.rx_buffer.payload_capacity() + } + + /// Return the maximum number of bytes inside the transmit buffer. + #[inline] + pub fn payload_send_capacity(&self) -> usize { + self.tx_buffer.payload_capacity() + } + + /// Enqueue a packet to send, and return a pointer to its payload. + /// + /// This function returns `Err(Error::Exhausted)` if the transmit buffer is full, + /// and `Err(Error::Truncated)` if there is not enough transmit buffer capacity + /// to ever send this packet. + /// + /// If the buffer is filled in a way that does not match the socket's + /// ethertype, the packet will be silently dropped. + pub fn send(&mut self, size: usize) -> Result<&mut [u8], SendError> { + let packet_buf = self + .tx_buffer + .enqueue(size, ()) + .map_err(|_| SendError::BufferFull)?; + + net_trace!( + "eth:{}: buffer to send {} octets", + self.ethertype.unwrap_or(0), + packet_buf.len() + ); + Ok(packet_buf) + } + + /// Enqueue a packet to be send and pass the buffer to the provided closure. + /// The closure then returns the size of the data written into the buffer. + /// + /// Also see [send](#method.send). + pub fn send_with(&mut self, max_size: usize, f: F) -> Result + where + F: FnOnce(&mut [u8]) -> usize, + { + let size = self + .tx_buffer + .enqueue_with_infallible(max_size, (), f) + .map_err(|_| SendError::BufferFull)?; + + net_trace!( + "eth:{}: buffer to send {} octets", + self.ethertype.unwrap_or(0), + size + ); + + Ok(size) + } + + /// Enqueue a packet to send, and fill it from a slice. + /// + /// See also [send](#method.send). + pub fn send_slice(&mut self, data: &[u8]) -> Result<(), SendError> { + self.send(data.len())?.copy_from_slice(data); + Ok(()) + } + + /// Dequeue a packet, and return a pointer to the payload. + /// + /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty. + pub fn recv(&mut self) -> Result<&[u8], RecvError> { + let ((), packet_buf) = self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?; + + net_trace!( + "eth:{}: receive {} buffered octets", + self.ethertype.unwrap_or(0), + packet_buf.len() + ); + Ok(packet_buf) + } + + /// Dequeue a packet, and copy the payload into the given slice. + /// + /// **Note**: when the size of the provided buffer is smaller than the size of the payload, + /// the packet is dropped and a `RecvError::Truncated` error is returned. + /// + /// See also [recv](#method.recv). + pub fn recv_slice(&mut self, data: &mut [u8]) -> Result { + let buffer = self.recv()?; + if data.len() < buffer.len() { + return Err(RecvError::Truncated); + } + + let length = min(data.len(), buffer.len()); + data[..length].copy_from_slice(&buffer[..length]); + Ok(length) + } + + /// Peek at a packet in the receive buffer and return a pointer to the + /// payload without removing the packet from the receive buffer. + /// This function otherwise behaves identically to [recv](#method.recv). + /// + /// It returns `Err(Error::Exhausted)` if the receive buffer is empty. + pub fn peek(&mut self) -> Result<&[u8], RecvError> { + let ((), packet_buf) = self.rx_buffer.peek().map_err(|_| RecvError::Exhausted)?; + + net_trace!( + "eth:{}: receive {} buffered octets", + self.ethertype.unwrap_or(0), + packet_buf.len() + ); + + Ok(packet_buf) + } + + /// Peek at a packet in the receive buffer, copy the payload into the given slice, + /// and return the amount of octets copied without removing the packet from the receive buffer. + /// This function otherwise behaves identically to [recv_slice](#method.recv_slice). + /// + /// **Note**: when the size of the provided buffer is smaller than the size of the payload, + /// no data is copied into the provided buffer and a `RecvError::Truncated` error is returned. + /// + /// See also [peek](#method.peek). + pub fn peek_slice(&mut self, data: &mut [u8]) -> Result { + let buffer = self.peek()?; + if data.len() < buffer.len() { + return Err(RecvError::Truncated); + } + + let length = min(data.len(), buffer.len()); + data[..length].copy_from_slice(&buffer[..length]); + Ok(length) + } + + /// Return the amount of octets queued in the transmit buffer. + pub fn send_queue(&self) -> usize { + self.tx_buffer.payload_bytes_count() + } + + /// Return the amount of octets queued in the receive buffer. + pub fn recv_queue(&self) -> usize { + self.rx_buffer.payload_bytes_count() + } + + pub(crate) fn accepts(&self, eth_repr: &EthernetRepr) -> bool { + match self.ethertype { + Some(e) if e == eth_repr.ethertype.into() => true, + Some(_) => false, + None => true, + } + } + + pub(crate) fn process(&mut self, _cx: &mut Context, eth_repr: &EthernetRepr, payload: &[u8]) { + debug_assert!(self.accepts(eth_repr)); + + let header_len = eth_repr.buffer_len(); + let total_len = header_len + payload.len(); + + net_trace!( + "eth:{}: receiving {} octets", + self.ethertype.unwrap_or(0), + total_len + ); + + match self.rx_buffer.enqueue(total_len, ()) { + Ok(buf) => { + let mut frame = EthernetFrame::new_checked(buf).expect("internal ethernet error"); + eth_repr.emit(&mut frame); + frame.payload_mut().copy_from_slice(payload); + } + Err(_) => net_trace!( + "eth:{}: buffer full, dropped incoming packet", + self.ethertype.unwrap_or(0) + ), + } + + #[cfg(feature = "async")] + self.rx_waker.wake(); + } + + pub(crate) fn dispatch(&mut self, cx: &mut Context, emit: F) -> Result<(), E> + where + F: FnOnce(&mut Context, (EthernetRepr, &[u8])) -> Result<(), E>, + { + let ethertype = self.ethertype; + let res = self.tx_buffer.dequeue_with(|&mut (), buffer| { + #[allow(clippy::useless_asref)] + let frame = match EthernetFrame::new_checked(buffer.as_ref()) { + Ok(x) => x, + Err(_) => { + net_trace!("eth: malformed ethernet frame in queue, dropping."); + return Ok(()); + } + }; + let eth_repr = match EthernetRepr::parse(&frame) { + Ok(r) => r, + Err(_) => { + net_trace!("eth: malformed ethernet frame in queue, dropping."); + return Ok(()); + } + }; + net_trace!("eth:{}: sending", ethertype.unwrap_or(0)); + emit(cx, (eth_repr, frame.payload())) + }); + match res { + Err(Empty) => Ok(()), + Ok(Err(e)) => Err(e), + Ok(Ok(())) => { + #[cfg(feature = "async")] + self.tx_waker.wake(); + Ok(()) + } + } + } + + pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt { + if self.tx_buffer.is_empty() { + PollAt::Ingress + } else { + PollAt::Now + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + use crate::phy::Medium; + use crate::tests::setup; + use crate::wire::ethernet::EtherType; + use crate::wire::EthernetAddress; + + fn buffer(packets: usize) -> PacketBuffer<'static> { + PacketBuffer::new(vec![PacketMetadata::EMPTY; packets], vec![0; 48 * packets]) + } + + const ETHER_TYPE: u16 = 0x1234; + + fn socket( + rx_buffer: PacketBuffer<'static>, + tx_buffer: PacketBuffer<'static>, + ) -> Socket<'static> { + Socket::new(Some(ETHER_TYPE), rx_buffer, tx_buffer) + } + + #[rustfmt::skip] + pub const PACKET_BYTES: [u8; 18] = [ + 0xaa, 0xbb, 0xcc, 0x12, 0x34, 0x56, + 0xaa, 0xbb, 0xcc, 0x78, 0x90, 0x12, + 0x12, 0x34, + 0xaa, 0x00, 0x00, 0xff, + ]; + pub const PACKET_RECEIVER: [u8; 6] = [0xaa, 0xbb, 0xcc, 0x12, 0x34, 0x56]; + pub const PACKET_SENDER: [u8; 6] = [0xaa, 0xbb, 0xcc, 0x78, 0x90, 0x12]; + pub const PACKET_PAYLOAD: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; + + #[test] + fn test_send() { + let (mut iface, _, _) = setup(Medium::Ethernet); + let cx = iface.context(); + let mut socket = socket(buffer(1), buffer(1)); + assert!(socket.can_send()); + assert_eq!(socket.send_slice(&PACKET_BYTES[..]), Ok(())); + assert_eq!(socket.send_slice(b""), Err(SendError::BufferFull)); + assert!(!socket.can_send()); + assert_eq!( + socket.dispatch(cx, |_, (eth_repr, eth_payload)| { + assert_eq!(eth_repr.ethertype, EtherType::from(ETHER_TYPE)); + assert_eq!(eth_payload, PACKET_PAYLOAD); + Err(()) + }), + Err(()) + ); + assert!(!socket.can_send()); + assert_eq!( + socket.dispatch(cx, |_, (eth_repr, eth_payload)| { + assert_eq!(eth_repr.ethertype, EtherType::from(ETHER_TYPE)); + assert_eq!(eth_payload, PACKET_PAYLOAD); + Ok::<_, ()>(()) + }), + Ok(()) + ); + assert!(socket.can_send()); + } + + #[test] + fn test_recv() { + let (mut iface, _, _) = setup(Medium::Ethernet); + let cx = iface.context(); + let mut socket = socket(buffer(1), buffer(1)); + + assert!(!socket.can_recv()); + assert_eq!(socket.recv(), Err(RecvError::Exhausted)); + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); + + let frameinfo = EthernetRepr { + src_addr: EthernetAddress::from_bytes(&PACKET_SENDER), + dst_addr: EthernetAddress::from_bytes(&PACKET_RECEIVER), + ethertype: ETHER_TYPE.into(), + }; + + assert!(socket.accepts(&frameinfo)); + socket.process(cx, &frameinfo, &PACKET_PAYLOAD); + assert!(socket.can_recv()); + + assert!(socket.accepts(&frameinfo)); + socket.process(cx, &frameinfo, &PACKET_PAYLOAD); + + assert_eq!(socket.peek(), Ok(&PACKET_BYTES[..])); + assert_eq!(socket.peek(), Ok(&PACKET_BYTES[..])); + assert_eq!(socket.recv(), Ok(&PACKET_BYTES[..])); + assert!(!socket.can_recv()); + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); + } +} diff --git a/src/socket/mod.rs b/src/socket/mod.rs index 7d48b4234..1d2951e42 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -18,6 +18,8 @@ use crate::time::Instant; pub mod dhcpv4; #[cfg(feature = "socket-dns")] pub mod dns; +#[cfg(feature = "socket-eth")] +pub mod eth; #[cfg(feature = "socket-icmp")] pub mod icmp; #[cfg(feature = "socket-raw")] @@ -69,6 +71,8 @@ pub enum Socket<'a> { Dhcpv4(dhcpv4::Socket<'a>), #[cfg(feature = "socket-dns")] Dns(dns::Socket<'a>), + #[cfg(feature = "socket-eth")] + Eth(eth::Socket<'a>), } impl<'a> Socket<'a> { @@ -86,6 +90,8 @@ impl<'a> Socket<'a> { Socket::Dhcpv4(s) => s.poll_at(cx), #[cfg(feature = "socket-dns")] Socket::Dns(s) => s.poll_at(cx), + #[cfg(feature = "socket-eth")] + Socket::Eth(s) => s.poll_at(cx), } } } @@ -139,3 +145,5 @@ from_socket!(tcp::Socket<'a>, Tcp); from_socket!(dhcpv4::Socket<'a>, Dhcpv4); #[cfg(feature = "socket-dns")] from_socket!(dns::Socket<'a>, Dns); +#[cfg(feature = "socket-eth")] +from_socket!(eth::Socket<'a>, Eth); diff --git a/src/wire/mod.rs b/src/wire/mod.rs index 478f0cffc..8fc301bd3 100644 --- a/src/wire/mod.rs +++ b/src/wire/mod.rs @@ -84,7 +84,7 @@ pub(crate) mod dhcpv4; #[cfg(feature = "proto-dns")] pub(crate) mod dns; #[cfg(feature = "medium-ethernet")] -mod ethernet; +pub(crate) mod ethernet; #[cfg(any(feature = "proto-ipv4", feature = "proto-ipv6"))] mod icmp; #[cfg(feature = "proto-ipv4")]