From 50d93c04addcadcbfdcc8f4b7c8f6b55df37cb44 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Sun, 29 Jun 2025 12:29:14 +0800 Subject: [PATCH 01/36] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9Eveth=E5=92=8Cbr?= =?UTF-8?q?idge=E7=BB=93=E6=9E=84=E4=BD=93=EF=BC=8C=E5=B0=9A=E6=9C=AA?= =?UTF-8?q?=E8=AF=A6=E7=BB=86=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/bridge.rs | 101 +++++++ kernel/src/driver/net/mod.rs | 2 + kernel/src/driver/net/veth.rs | 492 ++++++++++++++++++++++++++++++++ kernel/src/net/net_core.rs | 2 +- user/apps/test-veth/.gitignore | 3 + user/apps/test-veth/Cargo.toml | 11 + user/apps/test-veth/Makefile | 56 ++++ user/apps/test-veth/src/main.rs | 217 ++++++++++++++ 8 files changed, 883 insertions(+), 1 deletion(-) create mode 100644 kernel/src/driver/net/bridge.rs create mode 100644 kernel/src/driver/net/veth.rs create mode 100644 user/apps/test-veth/.gitignore create mode 100644 user/apps/test-veth/Cargo.toml create mode 100644 user/apps/test-veth/Makefile create mode 100644 user/apps/test-veth/src/main.rs diff --git a/kernel/src/driver/net/bridge.rs b/kernel/src/driver/net/bridge.rs new file mode 100644 index 000000000..19bcf1913 --- /dev/null +++ b/kernel/src/driver/net/bridge.rs @@ -0,0 +1,101 @@ +use crate::{libs::rwlock::RwLock, time::Instant}; +use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec}; +use smoltcp::wire::EthernetAddress; + +const MAC_ENTRY_TIMEOUT: u64 = 60_000; // 60秒 + +struct MacEntry { + port: String, + last_seen: Instant, +} + +pub struct BridgePort { + name: String, + bridge_enable: Arc, +} + +pub struct Bridge { + name: String, + ports: RwLock>>, + mac_table: RwLock>, +} + +impl Bridge { + pub fn new(name: &str) -> Self { + Self { + name: name.into(), + ports: RwLock::new(BTreeMap::new()), + mac_table: RwLock::new(BTreeMap::new()), + } + } + + pub fn add_port(&self, port: Arc) { + let port_name = 
port.name(); + let port_obj = Arc::new(BridgePort { + name: port_name.clone(), + bridge_enable: port.clone(), + }); + + self.ports + .write_irqsave() + .insert(port_name.clone(), port_obj); + } + + pub fn handle_frame( + &self, + src_mac: EthernetAddress, + dst_mac: EthernetAddress, + frame: Vec, + ingress: &str, + ) { + // MAC 学习 + self.mac_table.write_irqsave().insert( + src_mac, + MacEntry { + port: ingress.into(), + last_seen: Instant::now(), + }, + ); + + let ports = self.ports.read(); + if dst_mac == EthernetAddress::BROADCAST { + // 广播 + for (name, port) in ports.iter() { + if name != ingress { + Self::transmit_to(port, &frame); + } + } + } else { + // 单播 + if let Some(out_port) = self.mac_table.read().get(&dst_mac) { + if let Some(port) = ports.get(out_port.port.as_str()) { + Self::transmit_to(port, &frame); + } + } else { + // 未知单播 → 广播 + for (name, port) in ports.iter() { + if name != ingress { + Self::transmit_to(port, &frame); + } + } + } + } + } + + fn transmit_to(port: &BridgePort, frame: &[u8]) { + port.bridge_enable.bridge_transmit(frame); + } + + pub fn sweep_mac_table(&self) { + let now = Instant::now(); + self.mac_table.write_irqsave().retain(|_mac, entry| { + now.duration_since(entry.last_seen).unwrap().total_millis() < MAC_ENTRY_TIMEOUT + }); + } +} + +pub trait BridgeEnableDevice { + fn name(&self) -> String; + fn bridge_transmit(&self, frame: &[u8]); + // fn bridge_receive(&self, frame: &[u8]) ; +} diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index 98187cd86..b9f355080 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -10,6 +10,7 @@ use crate::{ use smoltcp; use system_error::SystemError; +pub mod bridge; pub mod class; mod dma; pub mod e1000e; @@ -17,6 +18,7 @@ pub mod irq_handle; pub mod kthread; pub mod loopback; pub mod sysfs; +pub mod veth; pub mod virtio_net; bitflags! 
{ diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs new file mode 100644 index 000000000..ad8e82320 --- /dev/null +++ b/kernel/src/driver/net/veth.rs @@ -0,0 +1,492 @@ +use crate::arch::rand::rand; +use crate::driver::base::class::Class; +use crate::driver::base::device::bus::Bus; +use crate::driver::base::device::driver::Driver; +use crate::driver::base::device::{Device, DeviceCommonData, DeviceType, IdTable}; +use crate::driver::base::kobject::{ + KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState, +}; +use crate::driver::base::kset::KSet; +use crate::filesystem::kernfs::KernFSInode; +use crate::init::initcall::INITCALL_DEVICE; +use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; +use crate::libs::spinlock::{SpinLock, SpinLockGuard}; +use crate::net::{generate_iface_id, NET_DEVICES}; +use alloc::collections::VecDeque; +use alloc::fmt::Debug; +use alloc::string::{String, ToString}; +use alloc::sync::{Arc, Weak}; +use alloc::vec::Vec; +use core::cell::UnsafeCell; +use core::ops::{Deref, DerefMut}; +use smoltcp::phy::DeviceCapabilities; +use smoltcp::phy::{self, TxToken}; +use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress, IpCidr}; +use system_error::SystemError; +use unified_init::macros::unified_init; + +use super::bridge::BridgeEnableDevice; +use super::{register_netdevice, NetDeivceState, NetDeviceCommonData, Operstate}; + +use super::{Iface, IfaceCommon}; + +// const DEVICE_NAME: &str = "veth"; + +pub struct Veth { + rx_queue: VecDeque>, + peer: Option>>, +} + +impl Veth { + pub fn new() -> Self { + Veth { + rx_queue: VecDeque::new(), + peer: None, + } + } + + pub fn set_peer(&mut self, peer: Arc>) { + self.peer = Some(peer); + } + + pub fn send_to_peer(&self, data: Vec) { + if let Some(peer) = &self.peer { + peer.lock().rx_queue.push_back(data); + } + } + + pub fn recv(&mut self) -> Option> { + self.rx_queue.pop_front() + } +} + +#[derive(Clone)] +pub struct VethDriver { + pub inner: Arc>, +} 
+ +impl VethDriver { + pub fn new_pair() -> (Self, Self) { + let dev1 = Arc::new(SpinLock::new(Veth::new())); + let dev2 = Arc::new(SpinLock::new(Veth::new())); + + dev1.lock().set_peer(dev2.clone()); + dev2.lock().set_peer(dev1.clone()); + + (VethDriver { inner: dev1 }, VethDriver { inner: dev2 }) + } +} + +pub struct VethTxToken { + driver: VethDriver, +} + +impl phy::TxToken for VethTxToken { + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut buf = vec![0; len]; + let result = f(&mut buf); + self.driver.inner.lock().send_to_peer(buf); + result + } +} + +pub struct VethRxToken { + buffer: Vec, +} + +impl phy::RxToken for VethRxToken { + fn consume(self, f: F) -> R + where + F: FnOnce(&[u8]) -> R, + { + f(&self.buffer) + } +} + +#[derive(Debug)] +struct VethDriverWarpper(UnsafeCell); +unsafe impl Send for VethDriverWarpper {} +unsafe impl Sync for VethDriverWarpper {} + +impl Deref for VethDriverWarpper { + type Target = VethDriver; + fn deref(&self) -> &Self::Target { + unsafe { &*self.0.get() } + } +} + +impl DerefMut for VethDriverWarpper { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.0.get() } + } +} + +impl VethDriverWarpper { + #[allow(clippy::mut_from_ref)] + #[allow(clippy::mut_from_ref)] + fn force_get_mut(&self) -> &mut VethDriver { + unsafe { &mut *self.0.get() } + } +} + +impl phy::Device for VethDriver { + type RxToken<'a> = VethRxToken; + type TxToken<'a> = VethTxToken; + + fn capabilities(&self) -> DeviceCapabilities { + let mut caps = DeviceCapabilities::default(); + caps.max_transmission_unit = 1500; + caps.medium = smoltcp::phy::Medium::Ethernet; + caps + } + + fn receive( + &mut self, + _timestamp: smoltcp::time::Instant, + ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let mut guard = self.inner.lock(); + guard.recv().map(|buf| { + ( + VethRxToken { buffer: buf }, + VethTxToken { + driver: self.clone(), + }, + ) + }) + } + + fn transmit(&mut self, _timestamp: 
smoltcp::time::Instant) -> Option> { + Some(VethTxToken { + driver: self.clone(), + }) + } +} + +#[cast_to([sync] Iface)] +#[cast_to([sync] Device)] +#[derive(Debug)] +pub struct VethInterface { + name: String, + driver: VethDriverWarpper, + common: IfaceCommon, + inner: SpinLock, + locked_kobj_state: LockedKObjectState, +} + +#[derive(Debug)] +pub struct InnerVethInterface { + netdevice_common: NetDeviceCommonData, + device_common: DeviceCommonData, + kobj_common: KObjectCommonData, +} + +impl VethInterface { + pub fn new(name: &str, driver: VethDriver) -> Arc { + let iface_id = generate_iface_id(); + let mac = [ + 0x02, + 0x00, + 0x00, + 0x00, + (iface_id >> 8) as u8, + iface_id as u8, + ]; + let hw_addr = HardwareAddress::Ethernet(EthernetAddress(mac)); + let mut iface_config = smoltcp::iface::Config::new(hw_addr); + iface_config.random_seed = rand() as u64; + let mut iface = smoltcp::iface::Interface::new( + iface_config, + &mut driver.clone(), + crate::time::Instant::now().into(), + ); + iface.set_any_ip(true); + + let device = Arc::new(VethInterface { + name: name.to_string(), + driver: VethDriverWarpper(UnsafeCell::new(driver)), + common: IfaceCommon::new(iface_id, true, iface), + inner: SpinLock::new(InnerVethInterface { + netdevice_common: NetDeviceCommonData::default(), + device_common: DeviceCommonData::default(), + kobj_common: KObjectCommonData::default(), + }), + locked_kobj_state: LockedKObjectState::default(), + }); + + device.set_net_state(NetDeivceState::__LINK_STATE_START); + device.set_operstate(Operstate::IF_OPER_UP); + NET_DEVICES + .write_irqsave() + .insert(device.nic_id(), device.clone()); + log::debug!( + "VethInterface created, devices: {:?}", + NET_DEVICES.read().keys() + ); + register_netdevice(device.clone()).expect("register veth device failed"); + + device + } + + fn inner(&self) -> SpinLockGuard { + self.inner.lock() + } + + pub fn update_ip_addrs(&self, addr: IpAddress, cidr: u8) { + let iface = &mut self.common.smol_iface.lock(); 
+ let cidr = IpCidr::new(addr, cidr); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs.push(cidr).expect("Push ipCidr failed: full"); + }); + + // 默认路由 + iface.routes_mut().update(|routes_map| { + routes_map + .push(smoltcp::iface::Route { + cidr, + via_router: addr, + preferred_until: None, + expires_at: None, + }) + .expect("Add default ipv4 route failed: full"); + }); + } +} + +impl KObject for VethInterface { + fn as_any_ref(&self) -> &dyn core::any::Any { + self + } + fn set_inode(&self, inode: Option>) { + self.inner().kobj_common.kern_inode = inode; + } + fn inode(&self) -> Option> { + self.inner().kobj_common.kern_inode.clone() + } + fn parent(&self) -> Option> { + self.inner().kobj_common.parent.clone() + } + fn set_parent(&self, parent: Option>) { + self.inner().kobj_common.parent = parent; + } + fn kset(&self) -> Option> { + self.inner().kobj_common.kset.clone() + } + fn set_kset(&self, kset: Option>) { + self.inner().kobj_common.kset = kset; + } + fn kobj_type(&self) -> Option<&'static dyn KObjType> { + self.inner().kobj_common.kobj_type + } + fn name(&self) -> String { + self.name.clone() + } + fn set_name(&self, _name: String) {} + fn kobj_state(&self) -> RwLockReadGuard { + self.locked_kobj_state.read() + } + fn kobj_state_mut(&self) -> RwLockWriteGuard { + self.locked_kobj_state.write() + } + fn set_kobj_state(&self, state: KObjectState) { + *self.locked_kobj_state.write() = state; + } + fn set_kobj_type(&self, ktype: Option<&'static dyn KObjType>) { + self.inner().kobj_common.kobj_type = ktype; + } +} + +impl Device for VethInterface { + fn dev_type(&self) -> DeviceType { + DeviceType::Net + } + fn id_table(&self) -> IdTable { + IdTable::new(self.name.clone(), None) + } + fn bus(&self) -> Option> { + self.inner().device_common.bus.clone() + } + fn set_bus(&self, bus: Option>) { + self.inner().device_common.bus = bus; + } + fn class(&self) -> Option> { + let mut guard = self.inner(); + let r = guard.device_common.class.clone()?.upgrade(); + if 
r.is_none() { + guard.device_common.class = None; + } + r + } + fn set_class(&self, class: Option>) { + self.inner().device_common.class = class; + } + fn driver(&self) -> Option> { + let r = self.inner().device_common.driver.clone()?.upgrade(); + if r.is_none() { + self.inner().device_common.driver = None; + } + r + } + fn set_driver(&self, driver: Option>) { + self.inner().device_common.driver = driver; + } + fn is_dead(&self) -> bool { + false + } + fn can_match(&self) -> bool { + self.inner().device_common.can_match + } + fn set_can_match(&self, can_match: bool) { + self.inner().device_common.can_match = can_match; + } + fn state_synced(&self) -> bool { + true + } + fn dev_parent(&self) -> Option> { + self.inner().device_common.get_parent_weak_or_clear() + } + fn set_dev_parent(&self, parent: Option>) { + self.inner().device_common.parent = parent; + } +} + +impl Iface for VethInterface { + fn common(&self) -> &IfaceCommon { + &self.common + } + fn iface_name(&self) -> String { + self.name.clone() + } + fn mac(&self) -> EthernetAddress { + if let HardwareAddress::Ethernet(mac) = self.common.smol_iface.lock().hardware_addr() { + mac + } else { + EthernetAddress([0, 0, 0, 0, 0, 0]) + } + } + fn poll(&self) { + self.common.poll(self.driver.force_get_mut()); + } + fn addr_assign_type(&self) -> u8 { + self.inner().netdevice_common.addr_assign_type + } + fn net_device_type(&self) -> u16 { + self.inner().netdevice_common.net_device_type = 1; + self.inner().netdevice_common.net_device_type + } + fn net_state(&self) -> NetDeivceState { + self.inner().netdevice_common.state + } + fn set_net_state(&self, state: NetDeivceState) { + self.inner().netdevice_common.state |= state; + } + fn operstate(&self) -> Operstate { + self.inner().netdevice_common.operstate + } + fn set_operstate(&self, state: Operstate) { + self.inner().netdevice_common.operstate = state; + } +} + +impl BridgeEnableDevice for VethInterface { + fn bridge_transmit(&self, frame: &[u8]) { + let driver = 
self.driver.force_get_mut(); + let token = VethTxToken { + driver: driver.clone(), + }; + token.consume(frame.len(), |buf| { + buf.copy_from_slice(frame); + }); + } + fn name(&self) -> String { + self.name.clone() + } + // fn bridge_receive(&self, frame: &[u8]) { + // let driver = self.driver.force_get_mut(); + // let token = VethRxToken { + // buffer: frame.to_vec(), + // }; + // } +} + +pub fn veth_probe() { + let (drv0, drv1) = VethDriver::new_pair(); + VethInterface::new("veth0", drv0); + VethInterface::new("veth1", drv1); +} + +#[unified_init(INITCALL_DEVICE)] +pub fn veth_init() -> Result<(), SystemError> { + veth_probe(); + log::info!("Veth pair initialized."); + Ok(()) +} + +// use smoltcp::time::Instant; + +// pub fn test_veth_loop() { +// let mut dev_map = NET_DEVICES.write_irqsave(); + +// let veth0 = dev_map +// .iter() +// .find(|(_, dev)| dev.iface_name() == "veth0") +// .map(|(_, dev)| dev.clone()) +// .and_then(|dev| dev.as_any_ref().downcast_ref::()) +// .expect("veth0 not found"); + +// let veth1 = dev_map +// .iter() +// .find(|(_, dev)| dev.iface_name() == "veth1") +// .map(|(_, dev)| dev.clone()) +// .and_then(|dev| dev.as_any_ref().downcast_ref::()) +// .expect("veth1 not found"); + +// // veth0 → veth1 +// { +// let frame: &[u8] = &[ +// 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, // Destination MAC +// 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, // Source MAC +// 0x08, 0x00, // Ethertype: IPv4 +// 0xde, 0xad, 0xbe, 0xef, // Payload +// ]; +// let tx = veth0.bridge_transmit(frame); + +// veth1.poll(); // 触发接收 + +// veth1 +// .driver +// .inner +// .lock() +// .rx_queue +// .pop_front() +// .map(|buf| { +// log::info!("[veth1 recv]: {:?}", &buf); +// }) +// .unwrap_or_else(|| { +// log::warn!("[veth1 recv]: nothing received"); +// }); +// } + + // // veth1 → veth0 + // { + // let tx = veth1.bridge_transmit(Instant::now()).unwrap(); + // let _ = tx.consume(32, |buf| { + // buf[..6].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x77]); + // 
buf[6..12].copy_from_slice(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06]); + // buf[12..14].copy_from_slice(&(0x0806u16.to_be_bytes())); // ethertype: ARP + // buf[14..].copy_from_slice(&[0xca, 0xfe, 0xba, 0xbe]); + // }); + + // veth0.poll(); + + // if let Some((rx, _)) = veth0.driver.force_get_mut().receive(Instant::now()) { + // rx.consume(|buf| { + // log::info!("[veth0 recv]: {:?}", &buf); + // }); + // } else { + // log::warn!("[veth0 recv]: nothing received"); + // } + // } +// } diff --git a/kernel/src/net/net_core.rs b/kernel/src/net/net_core.rs index 88e49653e..82649d15d 100644 --- a/kernel/src/net/net_core.rs +++ b/kernel/src/net/net_core.rs @@ -43,7 +43,7 @@ fn dhcp_query() -> Result<(), SystemError> { sockets().remove(dhcp_handle); }); - const DHCP_TRY_ROUND: u8 = 100; + const DHCP_TRY_ROUND: u8 = 0; for i in 0..DHCP_TRY_ROUND { log::debug!("DHCP try round: {}", i); net_face.poll(); diff --git a/user/apps/test-veth/.gitignore b/user/apps/test-veth/.gitignore new file mode 100644 index 000000000..1ac354611 --- /dev/null +++ b/user/apps/test-veth/.gitignore @@ -0,0 +1,3 @@ +/target +Cargo.lock +/install/ \ No newline at end of file diff --git a/user/apps/test-veth/Cargo.toml b/user/apps/test-veth/Cargo.toml new file mode 100644 index 000000000..d60adbf3c --- /dev/null +++ b/user/apps/test-veth/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "test-veth" +version = "0.1.0" +edition = "2021" +description = "测试veth pair" +authors = [ "sparkzky " ] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +smoltcp = { git = "https://git.mirrors.dragonos.org.cn/DragonOS-Community/smoltcp.git", rev = "3e61c909fd540d05575068d16dc4574e196499ed", default-features = false, features = ["log", "alloc", "socket-raw", "socket-udp", "socket-tcp", "socket-icmp", "socket-dhcpv4", "socket-dns", "proto-ipv4", "proto-ipv6", "medium-ip"]} diff --git a/user/apps/test-veth/Makefile b/user/apps/test-veth/Makefile new file mode 
100644 index 000000000..7522ea16c --- /dev/null +++ b/user/apps/test-veth/Makefile @@ -0,0 +1,56 @@ +TOOLCHAIN= +RUSTFLAGS= + +ifdef DADK_CURRENT_BUILD_DIR +# 如果是在dadk中编译,那么安装到dadk的安装目录中 + INSTALL_DIR = $(DADK_CURRENT_BUILD_DIR) +else +# 如果是在本地编译,那么安装到当前目录下的install目录中 + INSTALL_DIR = ./install +endif + +ifeq ($(ARCH), x86_64) + export RUST_TARGET=x86_64-unknown-linux-musl +else ifeq ($(ARCH), riscv64) + export RUST_TARGET=riscv64gc-unknown-linux-gnu +else +# 默认为x86_86,用于本地编译 + export RUST_TARGET=x86_64-unknown-linux-musl +endif + +run: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) + +build: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) + +clean: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) + +test: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) + +doc: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) doc --target $(RUST_TARGET) + +fmt: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt + +fmt-check: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt --check + +run-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) --release + +build-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) --release + +clean-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) --release + +test-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) --release + +.PHONY: install +install: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) install --target $(RUST_TARGET) --path . 
--no-track --root $(INSTALL_DIR) --force diff --git a/user/apps/test-veth/src/main.rs b/user/apps/test-veth/src/main.rs new file mode 100644 index 000000000..74af43647 --- /dev/null +++ b/user/apps/test-veth/src/main.rs @@ -0,0 +1,217 @@ +// src/main.rs +use smoltcp::phy::{Device, DeviceCapabilities, RxToken, TxToken}; +use smoltcp::time::Instant; +use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; + +// 模拟 veth pair 中的一个端点 +pub struct VethInner { + queue: VecDeque>, + peer: Option>>, +} + +impl VethInner { + pub fn new() -> Self { + Self { + queue: VecDeque::new(), + peer: None, + } + } + + pub fn set_peer(&mut self, peer: Arc>) { + self.peer = Some(peer); + } + + pub fn send_to_peer(&self, buf: Vec) { + if let Some(peer) = &self.peer { + peer.lock().unwrap().queue.push_back(buf); + } + } + + pub fn recv(&mut self) -> Option> { + self.queue.pop_front() + } +} + +#[derive(Clone)] +pub struct VethDriver { + inner: Arc>, +} + +impl VethDriver { + pub fn new_pair() -> (Self, Self) { + let a = Arc::new(Mutex::new(VethInner::new())); + let b = Arc::new(Mutex::new(VethInner::new())); + a.lock().unwrap().set_peer(b.clone()); + b.lock().unwrap().set_peer(a.clone()); + (Self { inner: a }, Self { inner: b }) + } +} + +pub struct VethTxToken { + driver: VethDriver, +} + +impl TxToken for VethTxToken { + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut buffer = vec![0u8; len]; + let result = f(&mut buffer); + self.driver.inner.lock().unwrap().send_to_peer(buffer); + result + } +} + +pub struct VethRxToken { + buffer: Vec, +} + +impl RxToken for VethRxToken { + fn consume(self, f: F) -> R + where + F: FnOnce(&[u8]) -> R, + { + f(&self.buffer) + } +} + +impl Device for VethDriver { + type RxToken<'a> = VethRxToken; + type TxToken<'a> = VethTxToken; + + fn capabilities(&self) -> DeviceCapabilities { + let mut caps = DeviceCapabilities::default(); + caps.max_transmission_unit = 1500; + caps.medium = 
smoltcp::phy::Medium::Ethernet; + caps + } + + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let mut inner = self.inner.lock().unwrap(); + if let Some(buf) = inner.recv() { + Some(( + VethRxToken { buffer: buf }, + VethTxToken { + driver: self.clone(), + }, + )) + } else { + None + } + } + + fn transmit(&mut self, _timestamp: Instant) -> Option> { + Some(VethTxToken { + driver: self.clone(), + }) + } +} + +// fn main() { +// let (mut veth0, mut veth1) = VethDriver::new_pair(); + +// // veth0 发,veth1 收 +// println!("--- veth0 → veth1 ---"); +// if let Some(tx) = veth0.transmit(Instant::from_millis(0)) { +// tx.consume(32, |buf| { +// buf[..6].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); +// buf[6..12].copy_from_slice(&[0x11, 0x22, 0x33, 0x44, 0x55, 0x66]); +// buf[12..14].copy_from_slice(&(0x0800u16.to_be_bytes())); +// buf[14..].copy_from_slice(b"hello veth1! "); +// }); +// } + +// if let Some((rx, _tx)) = veth1.receive(Instant::from_millis(0)) { +// rx.consume(|buf| { +// println!("veth1 received: {:02x?}", buf); +// }); +// } else { +// println!("veth1 received nothing"); +// } + +// // veth1 发,veth0 收 +// println!("--- veth1 → veth0 ---"); +// if let Some(tx) = veth1.transmit(Instant::from_millis(0)) { +// tx.consume(28, |buf| { +// buf[..6].copy_from_slice(&[0xde, 0xad, 0xbe, 0xef, 0xde, 0xad]); +// buf[6..12].copy_from_slice(&[0xca, 0xfe, 0xba, 0xbe, 0xca, 0xfe]); +// buf[12..14].copy_from_slice(&(0x0806u16.to_be_bytes())); +// buf[14..].copy_from_slice(b"yo veth0! 
"); +// }); +// } + +// if let Some((rx, _tx)) = veth0.receive(Instant::from_millis(0)) { +// rx.consume(|buf| { +// println!("veth0 received: {:02x?}", buf); +// }); +// } else { +// println!("veth0 received nothing"); +// } +// } + +fn main() { + let (mut veth0, veth1) = VethDriver::new_pair(); + let (veth3, mut veth4) = VethDriver::new_pair(); + + let mut bridge = BridgeDevice::new(); + bridge.add_port(veth1.clone()); + bridge.add_port(veth3.clone()); + + // veth0 → bridge → veth1 & veth3(→ veth4) + println!("--- veth0 → bridge (→ veth1, veth3) ---"); + if let Some(tx) = veth0.transmit(Instant::from_millis(0)) { + tx.consume(32, |buf| { + buf[..6].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); // dst MAC + buf[6..12].copy_from_slice(&[0x11, 0x22, 0x33, 0x44, 0x55, 0x66]); // src MAC + buf[12..14].copy_from_slice(&(0x0800u16.to_be_bytes())); // Ethertype + buf[14..].copy_from_slice(b"hello bridge world"); // payload + + bridge.handle_frame(&veth0, &buf); + }); + } + + if let Some((rx, _tx)) = veth1.clone().receive(Instant::from_millis(0)) { + rx.consume(|buf| { + println!("veth1 received: {:02x?}", buf); + }); + } else { + println!("veth1 received nothing"); + } + + if let Some((rx, _tx)) = veth4.receive(Instant::from_millis(0)) { + rx.consume(|buf| { + println!("veth4 received: {:02x?}", buf); + }); + } else { + println!("veth4 received nothing"); + } +} + +// 网桥设备:只做广播转发(无 MAC 学习) +pub struct BridgeDevice { + pub ports: Vec, +} + +impl BridgeDevice { + pub fn new() -> Self { + BridgeDevice { ports: Vec::new() } + } + + pub fn add_port(&mut self, port: VethDriver) { + self.ports.push(port); + } + + pub fn remove_port(&mut self, port: &VethDriver) { + self.ports.retain(|p| !Arc::ptr_eq(&p.inner, &port.inner)); + } + + pub fn handle_frame(&mut self, src_if: &VethDriver, frame: &[u8]) { + for port in &self.ports { + if !Arc::ptr_eq(&port.inner, &src_if.inner) { + port.inner.lock().unwrap().send_to_peer(frame.to_vec()); + } + } + } +} From 
4450c3ca0fdb1ff20572fa8df31f64dbbc2e6b2a Mon Sep 17 00:00:00 2001 From: sparkzky Date: Wed, 2 Jul 2025 15:47:37 +0800 Subject: [PATCH 02/36] =?UTF-8?q?feat(net):=20=E5=AE=8C=E5=96=84=E4=B8=80?= =?UTF-8?q?=E4=B8=8B=E5=B7=B2=E6=9C=89=E7=9A=84bridge=E4=BB=A5=E5=8F=8Avet?= =?UTF-8?q?h=E8=AE=BE=E5=A4=87,=E5=A2=9E=E5=8A=A0=E4=B8=80=E4=BA=9B?= =?UTF-8?q?=E8=B0=83=E8=AF=95=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/bridge.rs | 180 +++++++--- kernel/src/driver/net/veth.rs | 138 +++----- kernel/src/net/net_core.rs | 2 +- kernel/src/net/socket/inet/datagram/mod.rs | 23 ++ user/apps/test-veth/src/main.rs | 393 +++++++++++---------- user/dadk/config/test_veth_0_1_0.toml | 36 ++ 6 files changed, 455 insertions(+), 317 deletions(-) create mode 100644 user/dadk/config/test_veth_0_1_0.toml diff --git a/kernel/src/driver/net/bridge.rs b/kernel/src/driver/net/bridge.rs index 19bcf1913..26c00eb03 100644 --- a/kernel/src/driver/net/bridge.rs +++ b/kernel/src/driver/net/bridge.rs @@ -1,101 +1,187 @@ -use crate::{libs::rwlock::RwLock, time::Instant}; -use alloc::{collections::BTreeMap, string::String, sync::Arc, vec::Vec}; +use crate::{ + libs::{rwlock::RwLock, spinlock::SpinLock}, + time::Instant, +}; +use alloc::{ + string::String, + sync::{Arc, Weak}, + vec::Vec, +}; +use hashbrown::HashMap; use smoltcp::wire::EthernetAddress; const MAC_ENTRY_TIMEOUT: u64 = 60_000; // 60秒 struct MacEntry { - port: String, + port: Arc, + pub(self) record: RwLock, + // 存活时间(动态学习的老化) +} + +impl MacEntry { + pub fn new(port: Arc) -> Self { + MacEntry { + port, + record: RwLock::new(MacEntryRecord { + last_seen: Instant::now(), + }), + } + } + + /// 更新最后一次被看到的时间为现在 + pub(self) fn update_last_seen(&self) { + self.record.write_irqsave().last_seen = Instant::now(); + } +} + +struct MacEntryRecord { last_seen: Instant, } +/// 代表一个加入bridge的网络接口 pub struct BridgePort { name: String, bridge_enable: 
Arc, + bridge: Weak, + // 当前接口状态?forwarding, learning, blocking? + // mac mtu信息 } +impl BridgePort { + fn new(device: Arc) -> Self { + BridgePort { + name: device.name(), + bridge_enable: device, + bridge: Weak::new(), + } + } +} + +pub struct LockedBridgePort(pub SpinLock); +unsafe impl Send for LockedBridgePort {} +unsafe impl Sync for LockedBridgePort {} + pub struct Bridge { name: String, - ports: RwLock>>, - mac_table: RwLock>, + ports: RwLock>>, + // FDB(Forwarding Database) + mac_table: RwLock>, + // 配置参数,比如aging timeout, max age, hello time, forward delay } +unsafe impl Send for Bridge {} +unsafe impl Sync for Bridge {} + impl Bridge { pub fn new(name: &str) -> Self { Self { name: name.into(), - ports: RwLock::new(BTreeMap::new()), - mac_table: RwLock::new(BTreeMap::new()), + ports: RwLock::new(Vec::new()), + mac_table: RwLock::new(HashMap::new()), } } - pub fn add_port(&self, port: Arc) { - let port_name = port.name(); - let port_obj = Arc::new(BridgePort { - name: port_name.clone(), - bridge_enable: port.clone(), - }); - - self.ports - .write_irqsave() - .insert(port_name.clone(), port_obj); + pub fn add_port(&self, port: Arc) { + self.ports.write_irqsave().push(port); } pub fn handle_frame( &self, - src_mac: EthernetAddress, + ingress_port: Arc, + frame: &[u8], dst_mac: EthernetAddress, - frame: Vec, - ingress: &str, + src_mac: EthernetAddress, ) { - // MAC 学习 - self.mac_table.write_irqsave().insert( - src_mac, - MacEntry { - port: ingress.into(), - last_seen: Instant::now(), - }, - ); + let guard = self.mac_table.write_irqsave(); + if let Some(entry) = guard.get(&src_mac) { + entry.update_last_seen(); + } else { + // MAC 学习 + self.mac_table + .write_irqsave() + .insert(src_mac, MacEntry::new(ingress_port.clone())); + } - let ports = self.ports.read(); - if dst_mac == EthernetAddress::BROADCAST { + if dst_mac.is_broadcast() { // 广播 - for (name, port) in ports.iter() { - if name != ingress { - Self::transmit_to(port, &frame); - } - } + 
self.flood(&ingress_port, frame); } else { // 单播 - if let Some(out_port) = self.mac_table.read().get(&dst_mac) { - if let Some(port) = ports.get(out_port.port.as_str()) { - Self::transmit_to(port, &frame); + if let Some(entry) = self.mac_table.read().get(&dst_mac) { + let target_port = &entry.port; + // 避免发回自己 + if !Arc::ptr_eq(target_port, &ingress_port) { + target_port + .0 + .lock() + .bridge_enable + .receive_from_bridge(frame); } } else { // 未知单播 → 广播 - for (name, port) in ports.iter() { - if name != ingress { - Self::transmit_to(port, &frame); - } - } + self.flood(&ingress_port, frame); + } + } + } + + fn flood(&self, except_port: &Arc, frame: &[u8]) { + for port in self.ports.read().iter() { + if !Arc::ptr_eq(port, except_port) { + port.0.lock().bridge_enable.receive_from_bridge(frame); } } } fn transmit_to(port: &BridgePort, frame: &[u8]) { - port.bridge_enable.bridge_transmit(frame); + port.bridge_enable.receive_from_bridge(frame); } pub fn sweep_mac_table(&self) { let now = Instant::now(); self.mac_table.write_irqsave().retain(|_mac, entry| { - now.duration_since(entry.last_seen).unwrap().total_millis() < MAC_ENTRY_TIMEOUT + now.duration_since(entry.record.read().last_seen) + .unwrap() + .total_millis() + < MAC_ENTRY_TIMEOUT }); } } +#[derive(Clone)] +pub struct BridgeDriver { + pub inner: Arc, +} + +impl BridgeDriver { + pub fn new(name: &str) -> Self { + BridgeDriver { + inner: Arc::new(Bridge::new(name)), + } + } + + pub fn add_port(&self, port: Arc) { + let bridge_port = Arc::new(LockedBridgePort(SpinLock::new(BridgePort::new(port)))); + self.inner.add_port(bridge_port.clone()); + let mut guard = bridge_port.0.lock(); + guard.bridge = Arc::downgrade(&self.inner); + } + + pub fn handle_frame(&self, ingress_port: Arc, frame: &[u8]) { + if frame.len() < 14 { + return; // 非法以太网帧 + } + + let dst_mac = EthernetAddress::from_bytes(&frame[0..6]); + let src_mac = EthernetAddress::from_bytes(&frame[6..12]); + + self.inner + .handle_frame(ingress_port, frame, 
dst_mac, src_mac); + } +} + +/// 可供桥接设备应该实现的 trait pub trait BridgeEnableDevice { fn name(&self) -> String; - fn bridge_transmit(&self, frame: &[u8]); - // fn bridge_receive(&self, frame: &[u8]) ; + fn receive_from_bridge(&self, frame: &[u8]); + fn mac_addr(&self) -> EthernetAddress; } diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index ad8e82320..9ef644671 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -33,13 +33,15 @@ use super::{Iface, IfaceCommon}; // const DEVICE_NAME: &str = "veth"; pub struct Veth { + name: String, rx_queue: VecDeque>, peer: Option>>, } impl Veth { - pub fn new() -> Self { + pub fn new(name: String) -> Self { Veth { + name, rx_queue: VecDeque::new(), peer: None, } @@ -50,14 +52,30 @@ impl Veth { } pub fn send_to_peer(&self, data: Vec) { + // log::info!("{} sending", self.name); if let Some(peer) = &self.peer { - peer.lock().rx_queue.push_back(data); + let mut peer = peer.lock(); + peer.rx_queue.push_back(data); + // log::info!( + // "{} sending data to peer {}, peer current rx_queue: {:?}", + // self.name, + // peer.name(), + // peer.rx_queue + // ); } } - pub fn recv(&mut self) -> Option> { + pub fn recv_from_peer(&mut self) -> Option> { + // log::info!( + // "{} Receiving data from peer, current rx_queue: {:?}", + // self.name, + // self.rx_queue + // ); self.rx_queue.pop_front() } + pub fn name(&self) -> &str { + &self.name + } } #[derive(Clone)] @@ -66,9 +84,9 @@ pub struct VethDriver { } impl VethDriver { - pub fn new_pair() -> (Self, Self) { - let dev1 = Arc::new(SpinLock::new(Veth::new())); - let dev2 = Arc::new(SpinLock::new(Veth::new())); + pub fn new_pair(name1: &str, name2: &str) -> (Self, Self) { + let dev1 = Arc::new(SpinLock::new(Veth::new(name1.to_string()))); + let dev2 = Arc::new(SpinLock::new(Veth::new(name2.to_string()))); dev1.lock().set_peer(dev2.clone()); dev2.lock().set_peer(dev1.clone()); @@ -148,7 +166,8 @@ impl phy::Device for VethDriver { _timestamp: 
smoltcp::time::Instant, ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { let mut guard = self.inner.lock(); - guard.recv().map(|buf| { + guard.recv_from_peer().map(|buf| { + // log::info!("VethDriver received data: {:?}", buf); ( VethRxToken { buffer: buf }, VethTxToken { @@ -221,10 +240,14 @@ impl VethInterface { NET_DEVICES .write_irqsave() .insert(device.nic_id(), device.clone()); - log::debug!( - "VethInterface created, devices: {:?}", - NET_DEVICES.read().keys() - ); + // log::debug!( + // "VethInterface created, devices: {:?}", + // NET_DEVICES + // .read() + // .values() + // .map(|d| d.iface_name()) + // .collect::>() + // ); register_netdevice(device.clone()).expect("register veth device failed"); device @@ -234,9 +257,8 @@ impl VethInterface { self.inner.lock() } - pub fn update_ip_addrs(&self, addr: IpAddress, cidr: u8) { + pub fn update_ip_addrs(&self, addr: IpAddress, cidr: IpCidr) { let iface = &mut self.common.smol_iface.lock(); - let cidr = IpCidr::new(addr, cidr); iface.update_ip_addrs(|ip_addrs| { ip_addrs.push(cidr).expect("Push ipCidr failed: full"); }); @@ -252,6 +274,8 @@ impl VethInterface { }) .expect("Add default ipv4 route failed: full"); }); + + log::info!("VethInterface {} updated IP address: {}", self.name, addr); } } @@ -391,7 +415,7 @@ impl Iface for VethInterface { } impl BridgeEnableDevice for VethInterface { - fn bridge_transmit(&self, frame: &[u8]) { + fn receive_from_bridge(&self, frame: &[u8]) { let driver = self.driver.force_get_mut(); let token = VethTxToken { driver: driver.clone(), @@ -403,6 +427,9 @@ impl BridgeEnableDevice for VethInterface { fn name(&self) -> String { self.name.clone() } + fn mac_addr(&self) -> EthernetAddress { + self.mac() + } // fn bridge_receive(&self, frame: &[u8]) { // let driver = self.driver.force_get_mut(); // let token = VethRxToken { @@ -412,9 +439,19 @@ impl BridgeEnableDevice for VethInterface { } pub fn veth_probe() { - let (drv0, drv1) = VethDriver::new_pair(); - 
VethInterface::new("veth0", drv0); - VethInterface::new("veth1", drv1); + let name1 = "veth0"; + let name2 = "veth1"; + let (drv0, drv1) = VethDriver::new_pair(name1, name2); + let iface1 = VethInterface::new(name1, drv0); + let iface2 = VethInterface::new(name2, drv1); + + let addr1 = IpAddress::v4(10, 0, 0, 1); + let cidr1 = IpCidr::new(addr1, 24); + iface1.update_ip_addrs(addr1, cidr1); + + let addr2 = IpAddress::v4(10, 0, 0, 2); + let cidr2 = IpCidr::new(addr2, 24); + iface2.update_ip_addrs(addr2, cidr2); } #[unified_init(INITCALL_DEVICE)] @@ -423,70 +460,3 @@ pub fn veth_init() -> Result<(), SystemError> { log::info!("Veth pair initialized."); Ok(()) } - -// use smoltcp::time::Instant; - -// pub fn test_veth_loop() { -// let mut dev_map = NET_DEVICES.write_irqsave(); - -// let veth0 = dev_map -// .iter() -// .find(|(_, dev)| dev.iface_name() == "veth0") -// .map(|(_, dev)| dev.clone()) -// .and_then(|dev| dev.as_any_ref().downcast_ref::()) -// .expect("veth0 not found"); - -// let veth1 = dev_map -// .iter() -// .find(|(_, dev)| dev.iface_name() == "veth1") -// .map(|(_, dev)| dev.clone()) -// .and_then(|dev| dev.as_any_ref().downcast_ref::()) -// .expect("veth1 not found"); - -// // veth0 → veth1 -// { -// let frame: &[u8] = &[ -// 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, // Destination MAC -// 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, // Source MAC -// 0x08, 0x00, // Ethertype: IPv4 -// 0xde, 0xad, 0xbe, 0xef, // Payload -// ]; -// let tx = veth0.bridge_transmit(frame); - -// veth1.poll(); // 触发接收 - -// veth1 -// .driver -// .inner -// .lock() -// .rx_queue -// .pop_front() -// .map(|buf| { -// log::info!("[veth1 recv]: {:?}", &buf); -// }) -// .unwrap_or_else(|| { -// log::warn!("[veth1 recv]: nothing received"); -// }); -// } - - // // veth1 → veth0 - // { - // let tx = veth1.bridge_transmit(Instant::now()).unwrap(); - // let _ = tx.consume(32, |buf| { - // buf[..6].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x77]); - // buf[6..12].copy_from_slice(&[0x01, 0x02, 
0x03, 0x04, 0x05, 0x06]); - // buf[12..14].copy_from_slice(&(0x0806u16.to_be_bytes())); // ethertype: ARP - // buf[14..].copy_from_slice(&[0xca, 0xfe, 0xba, 0xbe]); - // }); - - // veth0.poll(); - - // if let Some((rx, _)) = veth0.driver.force_get_mut().receive(Instant::now()) { - // rx.consume(|buf| { - // log::info!("[veth0 recv]: {:?}", &buf); - // }); - // } else { - // log::warn!("[veth0 recv]: nothing received"); - // } - // } -// } diff --git a/kernel/src/net/net_core.rs b/kernel/src/net/net_core.rs index 82649d15d..f4603b6d5 100644 --- a/kernel/src/net/net_core.rs +++ b/kernel/src/net/net_core.rs @@ -43,7 +43,7 @@ fn dhcp_query() -> Result<(), SystemError> { sockets().remove(dhcp_handle); }); - const DHCP_TRY_ROUND: u8 = 0; + const DHCP_TRY_ROUND: u8 = 10; for i in 0..DHCP_TRY_ROUND { log::debug!("DHCP try round: {}", i); net_face.poll(); diff --git a/kernel/src/net/socket/inet/datagram/mod.rs b/kernel/src/net/socket/inet/datagram/mod.rs index 337042e75..5756d70fe 100644 --- a/kernel/src/net/socket/inet/datagram/mod.rs +++ b/kernel/src/net/socket/inet/datagram/mod.rs @@ -137,6 +137,29 @@ impl UdpSocket { }; return result; } + + pub fn event(&self) -> EPollEventType { + // log::info!("UdpSocket::event"); + let mut event = EPollEventType::empty(); + match self.inner.read().as_ref().unwrap() { + UdpInner::Unbound(_) => { + event.insert(EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND); + } + UdpInner::Bound(bound) => { + let (can_recv, can_send) = + bound.with_socket(|socket| (socket.can_recv(), socket.can_send())); + + if can_recv { + event.insert(EP::EPOLLIN | EP::EPOLLRDNORM); + } + + if can_send { + event.insert(EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND); + } + } + } + return event; + } } impl Socket for UdpSocket { diff --git a/user/apps/test-veth/src/main.rs b/user/apps/test-veth/src/main.rs index 74af43647..f78f70751 100644 --- a/user/apps/test-veth/src/main.rs +++ b/user/apps/test-veth/src/main.rs @@ -1,129 +1,136 @@ -// src/main.rs -use 
smoltcp::phy::{Device, DeviceCapabilities, RxToken, TxToken}; -use smoltcp::time::Instant; -use std::collections::VecDeque; -use std::sync::{Arc, Mutex}; - -// 模拟 veth pair 中的一个端点 -pub struct VethInner { - queue: VecDeque>, - peer: Option>>, -} +// // src/main.rs +// use smoltcp::phy::{Device, DeviceCapabilities, RxToken, TxToken}; +// use smoltcp::time::Instant; +// use std::collections::VecDeque; +// use std::sync::{Arc, Mutex}; + +// // 模拟 veth pair 中的一个端点 +// pub struct VethInner { +// queue: VecDeque>, +// peer: Option>>, +// } -impl VethInner { - pub fn new() -> Self { - Self { - queue: VecDeque::new(), - peer: None, - } - } - - pub fn set_peer(&mut self, peer: Arc>) { - self.peer = Some(peer); - } - - pub fn send_to_peer(&self, buf: Vec) { - if let Some(peer) = &self.peer { - peer.lock().unwrap().queue.push_back(buf); - } - } - - pub fn recv(&mut self) -> Option> { - self.queue.pop_front() - } -} +// impl VethInner { +// pub fn new() -> Self { +// Self { +// queue: VecDeque::new(), +// peer: None, +// } +// } -#[derive(Clone)] -pub struct VethDriver { - inner: Arc>, -} +// pub fn set_peer(&mut self, peer: Arc>) { +// self.peer = Some(peer); +// } -impl VethDriver { - pub fn new_pair() -> (Self, Self) { - let a = Arc::new(Mutex::new(VethInner::new())); - let b = Arc::new(Mutex::new(VethInner::new())); - a.lock().unwrap().set_peer(b.clone()); - b.lock().unwrap().set_peer(a.clone()); - (Self { inner: a }, Self { inner: b }) - } -} +// pub fn send_to_peer(&self, buf: Vec) { +// if let Some(peer) = &self.peer { +// peer.lock().unwrap().queue.push_back(buf); +// } +// } -pub struct VethTxToken { - driver: VethDriver, -} +// pub fn recv(&mut self) -> Option> { +// self.queue.pop_front() +// } +// } -impl TxToken for VethTxToken { - fn consume(self, len: usize, f: F) -> R - where - F: FnOnce(&mut [u8]) -> R, - { - let mut buffer = vec![0u8; len]; - let result = f(&mut buffer); - self.driver.inner.lock().unwrap().send_to_peer(buffer); - result - } -} +// 
#[derive(Clone)] +// pub struct VethDriver { +// inner: Arc>, +// } -pub struct VethRxToken { - buffer: Vec, -} +// impl VethDriver { +// pub fn new_pair() -> (Self, Self) { +// let a = Arc::new(Mutex::new(VethInner::new())); +// let b = Arc::new(Mutex::new(VethInner::new())); +// a.lock().unwrap().set_peer(b.clone()); +// b.lock().unwrap().set_peer(a.clone()); +// (Self { inner: a }, Self { inner: b }) +// } +// } -impl RxToken for VethRxToken { - fn consume(self, f: F) -> R - where - F: FnOnce(&[u8]) -> R, - { - f(&self.buffer) - } -} +// pub struct VethTxToken { +// driver: VethDriver, +// } -impl Device for VethDriver { - type RxToken<'a> = VethRxToken; - type TxToken<'a> = VethTxToken; - - fn capabilities(&self) -> DeviceCapabilities { - let mut caps = DeviceCapabilities::default(); - caps.max_transmission_unit = 1500; - caps.medium = smoltcp::phy::Medium::Ethernet; - caps - } - - fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { - let mut inner = self.inner.lock().unwrap(); - if let Some(buf) = inner.recv() { - Some(( - VethRxToken { buffer: buf }, - VethTxToken { - driver: self.clone(), - }, - )) - } else { - None - } - } - - fn transmit(&mut self, _timestamp: Instant) -> Option> { - Some(VethTxToken { - driver: self.clone(), - }) - } -} +// impl TxToken for VethTxToken { +// fn consume(self, len: usize, f: F) -> R +// where +// F: FnOnce(&mut [u8]) -> R, +// { +// let mut buffer = vec![0u8; len]; +// let result = f(&mut buffer); +// self.driver.inner.lock().unwrap().send_to_peer(buffer); +// result +// } +// } + +// pub struct VethRxToken { +// buffer: Vec, +// } + +// impl RxToken for VethRxToken { +// fn consume(self, f: F) -> R +// where +// F: FnOnce(&[u8]) -> R, +// { +// f(&self.buffer) +// } +// } + +// impl Device for VethDriver { +// type RxToken<'a> = VethRxToken; +// type TxToken<'a> = VethTxToken; + +// fn capabilities(&self) -> DeviceCapabilities { +// let mut caps = DeviceCapabilities::default(); 
+// caps.max_transmission_unit = 1500; +// caps.medium = smoltcp::phy::Medium::Ethernet; +// caps +// } + +// fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { +// let mut inner = self.inner.lock().unwrap(); +// if let Some(buf) = inner.recv() { +// Some(( +// VethRxToken { buffer: buf }, +// VethTxToken { +// driver: self.clone(), +// }, +// )) +// } else { +// None +// } +// } + +// fn transmit(&mut self, _timestamp: Instant) -> Option> { +// Some(VethTxToken { +// driver: self.clone(), +// }) +// } +// } // fn main() { -// let (mut veth0, mut veth1) = VethDriver::new_pair(); +// let (mut veth0, veth1) = VethDriver::new_pair(); +// let (veth3, mut veth4) = VethDriver::new_pair(); + +// let mut bridge = BridgeDevice::new(); +// bridge.add_port(veth1.clone()); +// bridge.add_port(veth3.clone()); -// // veth0 发,veth1 收 -// println!("--- veth0 → veth1 ---"); +// // veth0 → bridge → veth1 & veth3(→ veth4) +// println!("--- veth0 → bridge (→ veth1, veth3) ---"); // if let Some(tx) = veth0.transmit(Instant::from_millis(0)) { // tx.consume(32, |buf| { -// buf[..6].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); -// buf[6..12].copy_from_slice(&[0x11, 0x22, 0x33, 0x44, 0x55, 0x66]); -// buf[12..14].copy_from_slice(&(0x0800u16.to_be_bytes())); -// buf[14..].copy_from_slice(b"hello veth1! 
"); +// buf[..6].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); // dst MAC +// buf[6..12].copy_from_slice(&[0x11, 0x22, 0x33, 0x44, 0x55, 0x66]); // src MAC +// buf[12..14].copy_from_slice(&(0x0800u16.to_be_bytes())); // Ethertype +// buf[14..].copy_from_slice(b"hello bridge world"); // payload + +// bridge.handle_frame(&veth0, &buf); // }); // } -// if let Some((rx, _tx)) = veth1.receive(Instant::from_millis(0)) { +// if let Some((rx, _tx)) = veth1.clone().receive(Instant::from_millis(0)) { // rx.consume(|buf| { // println!("veth1 received: {:02x?}", buf); // }); @@ -131,87 +138,103 @@ impl Device for VethDriver { // println!("veth1 received nothing"); // } -// // veth1 发,veth0 收 -// println!("--- veth1 → veth0 ---"); -// if let Some(tx) = veth1.transmit(Instant::from_millis(0)) { -// tx.consume(28, |buf| { -// buf[..6].copy_from_slice(&[0xde, 0xad, 0xbe, 0xef, 0xde, 0xad]); -// buf[6..12].copy_from_slice(&[0xca, 0xfe, 0xba, 0xbe, 0xca, 0xfe]); -// buf[12..14].copy_from_slice(&(0x0806u16.to_be_bytes())); -// buf[14..].copy_from_slice(b"yo veth0! 
"); -// }); -// } - -// if let Some((rx, _tx)) = veth0.receive(Instant::from_millis(0)) { +// if let Some((rx, _tx)) = veth4.receive(Instant::from_millis(0)) { // rx.consume(|buf| { -// println!("veth0 received: {:02x?}", buf); +// println!("veth4 received: {:02x?}", buf); // }); // } else { -// println!("veth0 received nothing"); +// println!("veth4 received nothing"); // } // } -fn main() { - let (mut veth0, veth1) = VethDriver::new_pair(); - let (veth3, mut veth4) = VethDriver::new_pair(); - - let mut bridge = BridgeDevice::new(); - bridge.add_port(veth1.clone()); - bridge.add_port(veth3.clone()); - - // veth0 → bridge → veth1 & veth3(→ veth4) - println!("--- veth0 → bridge (→ veth1, veth3) ---"); - if let Some(tx) = veth0.transmit(Instant::from_millis(0)) { - tx.consume(32, |buf| { - buf[..6].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); // dst MAC - buf[6..12].copy_from_slice(&[0x11, 0x22, 0x33, 0x44, 0x55, 0x66]); // src MAC - buf[12..14].copy_from_slice(&(0x0800u16.to_be_bytes())); // Ethertype - buf[14..].copy_from_slice(b"hello bridge world"); // payload - - bridge.handle_frame(&veth0, &buf); - }); - } - - if let Some((rx, _tx)) = veth1.clone().receive(Instant::from_millis(0)) { - rx.consume(|buf| { - println!("veth1 received: {:02x?}", buf); - }); - } else { - println!("veth1 received nothing"); - } - - if let Some((rx, _tx)) = veth4.receive(Instant::from_millis(0)) { - rx.consume(|buf| { - println!("veth4 received: {:02x?}", buf); - }); - } else { - println!("veth4 received nothing"); - } -} +// // 网桥设备:只做广播转发(无 MAC 学习) +// pub struct BridgeDevice { +// pub ports: Vec, +// } -// 网桥设备:只做广播转发(无 MAC 学习) -pub struct BridgeDevice { - pub ports: Vec, -} +// impl BridgeDevice { +// pub fn new() -> Self { +// BridgeDevice { ports: Vec::new() } +// } + +// pub fn add_port(&mut self, port: VethDriver) { +// self.ports.push(port); +// } + +// pub fn remove_port(&mut self, port: &VethDriver) { +// self.ports.retain(|p| !Arc::ptr_eq(&p.inner, &port.inner)); 
+// } + +// pub fn handle_frame(&mut self, src_if: &VethDriver, frame: &[u8]) { +// for port in &self.ports { +// if !Arc::ptr_eq(&port.inner, &src_if.inner) { +// port.inner.lock().unwrap().send_to_peer(frame.to_vec()); +// } +// } +// } +// } + +use std::net::UdpSocket; +use std::str; +use std::thread; +use std::time::Duration; + +fn main() -> std::io::Result<()> { + // 启动 server 线程 + let server_thread = thread::spawn(|| { + let socket = UdpSocket::bind("10.0.0.2:34254") + .expect("Failed to bind to veth1 (10.0.0.2:34254)"); + println!("[server] Listening on 10.0.0.2:34254"); + + let mut buf = [0; 1024]; + let (amt, src) = socket + .recv_from(&mut buf) + .expect("[server] Failed to receive"); + + let received_msg = str::from_utf8(&buf[..amt]).expect("Invalid UTF-8"); + + println!("[server] Received from {}: {}", src, received_msg); + + socket + .send_to(received_msg.as_bytes(), src) + .expect("[server] Failed to send back"); + println!("[server] Echoed back the message"); + }); + + // 确保 server 已启动(可根据情况适当 sleep) + thread::sleep(Duration::from_millis(200)); + + // 启动 client + let client_thread = thread::spawn(|| { + let socket = UdpSocket::bind("10.0.0.1:0").expect("Failed to bind to veth0 (10.0.0.1)"); + socket + .connect("10.0.0.2:34254") + .expect("Failed to connect to 10.0.0.2:34254"); + + let msg = "Hello from veth0!"; + socket + .send(msg.as_bytes()) + .expect("[client] Failed to send"); + + println!("[client] Sent: {}", msg); + + let mut buf = [0; 1024]; + let (amt, _src) = socket + .recv_from(&mut buf) + .expect("[client] Failed to receive"); + + let received_msg = str::from_utf8(&buf[..amt]).expect("Invalid UTF-8"); + + println!("[client] Received echo: {}", received_msg); + + assert_eq!(msg, received_msg, "[client] Mismatch in echo!"); + }); + + // 等待两个线程结束 + server_thread.join().unwrap(); + client_thread.join().unwrap(); + + println!("\n✅ Test completed: veth0 <--> veth1 UDP communication success"); -impl BridgeDevice { - pub fn new() -> Self { - 
BridgeDevice { ports: Vec::new() } - } - - pub fn add_port(&mut self, port: VethDriver) { - self.ports.push(port); - } - - pub fn remove_port(&mut self, port: &VethDriver) { - self.ports.retain(|p| !Arc::ptr_eq(&p.inner, &port.inner)); - } - - pub fn handle_frame(&mut self, src_if: &VethDriver, frame: &[u8]) { - for port in &self.ports { - if !Arc::ptr_eq(&port.inner, &src_if.inner) { - port.inner.lock().unwrap().send_to_peer(frame.to_vec()); - } - } - } + Ok(()) } diff --git a/user/dadk/config/test_veth_0_1_0.toml b/user/dadk/config/test_veth_0_1_0.toml new file mode 100644 index 000000000..f8571b15f --- /dev/null +++ b/user/dadk/config/test_veth_0_1_0.toml @@ -0,0 +1,36 @@ +# 用户程序名称 +name = "test-veth" +# 版本号 +version = "0.1.0" +# 用户程序描述信息 +description = "test for veth interface" +# (可选)默认: false 是否只构建一次,如果为true,DADK会在构建成功后,将构建结果缓存起来,下次构建时,直接使用缓存的构建结果 +build-once = false +# (可选) 默认: false 是否只安装一次,如果为true,DADK会在安装成功后,不再重复安装 +install-once = false +# 目标架构 +# 可选值:"x86_64", "aarch64", "riscv64" +target-arch = ["x86_64"] +# 任务源 +[task-source] +# 构建类型 +# 可选值:"build-from_source", "install-from-prebuilt" +type = "build-from-source" +# 构建来源 +# "build_from_source" 可选值:"git", "local", "archive" +# "install_from_prebuilt" 可选值:"local", "archive" +source = "local" +# 路径或URL +source-path = "user/apps/test-veth" +# 构建相关信息 +[build] +# (可选)构建命令 +build-command = "make install" +# 安装相关信息 +[install] +# (可选)安装到DragonOS的路径 +in-dragonos-path = "/" +# 清除相关信息 +[clean] +# (可选)清除命令 +clean-command = "make clean" From cbc2f78b3a1745a39cecfe0306b82664ccf79e64 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Sun, 6 Jul 2025 22:11:20 +0800 Subject: [PATCH 03/36] =?UTF-8?q?feat(net):=20=E5=AE=8C=E5=96=84veth?= =?UTF-8?q?=E7=BD=91=E5=8D=A1=E9=A9=B1=E5=8A=A8,=E8=83=BD=E9=80=9A?= =?UTF-8?q?=E8=BF=87=E6=B5=8B=E4=BE=8B;=E7=AE=80=E5=8D=95=E4=BF=AE?= =?UTF-8?q?=E6=94=B9vridge=E8=AE=BE=E5=A4=87,=E5=B0=9A=E6=9C=AA=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/bridge.rs | 109 +++++----- kernel/src/driver/net/mod.rs | 19 +- kernel/src/driver/net/veth.rs | 231 ++++++++++++++------- kernel/src/driver/net/virtio_net.rs | 18 +- kernel/src/net/socket/inet/datagram/mod.rs | 25 ++- user/apps/test-veth/src/main.rs | 4 +- 6 files changed, 262 insertions(+), 144 deletions(-) diff --git a/kernel/src/driver/net/bridge.rs b/kernel/src/driver/net/bridge.rs index 26c00eb03..da93a0a32 100644 --- a/kernel/src/driver/net/bridge.rs +++ b/kernel/src/driver/net/bridge.rs @@ -1,25 +1,22 @@ use crate::{ + driver::net::Iface, libs::{rwlock::RwLock, spinlock::SpinLock}, time::Instant, }; -use alloc::{ - string::String, - sync::{Arc, Weak}, - vec::Vec, -}; +use alloc::{collections::BTreeMap, string::String, sync::Arc}; use hashbrown::HashMap; use smoltcp::wire::EthernetAddress; const MAC_ENTRY_TIMEOUT: u64 = 60_000; // 60秒 struct MacEntry { - port: Arc, + port: Arc, pub(self) record: RwLock, // 存活时间(动态学习的老化) } impl MacEntry { - pub fn new(port: Arc) -> Self { + pub fn new(port: Arc) -> Self { MacEntry { port, record: RwLock::new(MacEntryRecord { @@ -39,95 +36,91 @@ struct MacEntryRecord { } /// 代表一个加入bridge的网络接口 +#[derive(Clone)] pub struct BridgePort { - name: String, bridge_enable: Arc, - bridge: Weak, + bridge: BridgeDriver, // 当前接口状态?forwarding, learning, blocking? 
// mac mtu信息 } impl BridgePort { - fn new(device: Arc) -> Self { + fn new(device: Arc, bridge: BridgeDriver) -> Self { BridgePort { - name: device.name(), bridge_enable: device, - bridge: Weak::new(), + bridge, } } -} -pub struct LockedBridgePort(pub SpinLock); -unsafe impl Send for LockedBridgePort {} -unsafe impl Sync for LockedBridgePort {} + fn mac(&self) -> EthernetAddress { + self.bridge_enable.mac() + } +} pub struct Bridge { name: String, - ports: RwLock>>, + // 端口列表,key为MAC地址 + ports: BTreeMap>, // FDB(Forwarding Database) - mac_table: RwLock>, + mac_table: HashMap, // 配置参数,比如aging timeout, max age, hello time, forward delay } -unsafe impl Send for Bridge {} -unsafe impl Sync for Bridge {} - impl Bridge { pub fn new(name: &str) -> Self { Self { name: name.into(), - ports: RwLock::new(Vec::new()), - mac_table: RwLock::new(HashMap::new()), + ports: BTreeMap::new(), + mac_table: HashMap::new(), } } - pub fn add_port(&self, port: Arc) { - self.ports.write_irqsave().push(port); + pub fn add_port(&mut self, port: Arc) { + self.ports.insert(port.mac(), port); + } + + pub fn insert_macentry(&mut self, src_mac: EthernetAddress, port: Arc) { + self.mac_table.insert(src_mac, MacEntry::new(port)); } pub fn handle_frame( - &self, - ingress_port: Arc, + &mut self, + ingress_port: Arc, frame: &[u8], dst_mac: EthernetAddress, src_mac: EthernetAddress, ) { - let guard = self.mac_table.write_irqsave(); - if let Some(entry) = guard.get(&src_mac) { + if let Some(entry) = self.mac_table.get(&src_mac) { entry.update_last_seen(); } else { // MAC 学习 - self.mac_table - .write_irqsave() - .insert(src_mac, MacEntry::new(ingress_port.clone())); + self.insert_macentry(src_mac, ingress_port.clone()); } if dst_mac.is_broadcast() { // 广播 - self.flood(&ingress_port, frame); + self.flood(ingress_port.mac(), frame); } else { // 单播 - if let Some(entry) = self.mac_table.read().get(&dst_mac) { + if let Some(entry) = self.mac_table.get(&dst_mac) { let target_port = &entry.port; // 避免发回自己 if 
!Arc::ptr_eq(target_port, &ingress_port) { - target_port - .0 - .lock() - .bridge_enable - .receive_from_bridge(frame); + Bridge::transmit_to(target_port, frame); } } else { // 未知单播 → 广播 - self.flood(&ingress_port, frame); + self.flood(ingress_port.mac(), frame); } } + + self.sweep_mac_table(); } - fn flood(&self, except_port: &Arc, frame: &[u8]) { - for port in self.ports.read().iter() { - if !Arc::ptr_eq(port, except_port) { - port.0.lock().bridge_enable.receive_from_bridge(frame); + fn flood(&self, except_mac: EthernetAddress, frame: &[u8]) { + for (mac, port) in self.ports.iter() { + if mac != &except_mac { + Bridge::transmit_to(port, frame); } } } @@ -136,9 +129,9 @@ impl Bridge { port.bridge_enable.receive_from_bridge(frame); } - pub fn sweep_mac_table(&self) { + pub fn sweep_mac_table(&mut self) { let now = Instant::now(); - self.mac_table.write_irqsave().retain(|_mac, entry| { + self.mac_table.retain(|_mac, entry| { now.duration_since(entry.record.read().last_seen) .unwrap() .total_millis() @@ -149,39 +142,43 @@ impl Bridge { #[derive(Clone)] pub struct BridgeDriver { - pub inner: Arc, + pub inner: Arc>, } impl BridgeDriver { pub fn new(name: &str) -> Self { BridgeDriver { - inner: Arc::new(Bridge::new(name)), + inner: Arc::new(SpinLock::new(Bridge::new(name))), } } pub fn add_port(&self, port: Arc) { - let bridge_port = Arc::new(LockedBridgePort(SpinLock::new(BridgePort::new(port)))); - self.inner.add_port(bridge_port.clone()); - let mut guard = bridge_port.0.lock(); - guard.bridge = Arc::downgrade(&self.inner); + let port = BridgePort::new(port, self.clone()); + + let bridge_port = Arc::new(port); + self.inner.lock().add_port(bridge_port.clone()); } - pub fn handle_frame(&self, ingress_port: Arc, frame: &[u8]) { + pub fn handle_frame(&self, ingress_port: Arc, frame: &[u8]) { if frame.len() < 14 { return; // 非法以太网帧 } let dst_mac = EthernetAddress::from_bytes(&frame[0..6]); let src_mac = EthernetAddress::from_bytes(&frame[6..12]); + //todo 
Frame::new_unchecked self.inner + .lock() .handle_frame(ingress_port, frame, dst_mac, src_mac); } } /// 可供桥接设备应该实现的 trait -pub trait BridgeEnableDevice { - fn name(&self) -> String; +pub trait BridgeEnableDevice: Iface { fn receive_from_bridge(&self, frame: &[u8]); - fn mac_addr(&self) -> EthernetAddress; + fn transmit_to_bridge(&self, frame: &[u8]) { + // 默认实现,子类可以覆盖 + self.receive_from_bridge(frame); + } } diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index b9f355080..0950f0324 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -75,14 +75,15 @@ pub trait Iface: crate::driver::base::device::Device { self.common().iface_id } - /// # `poll` - /// 用于轮询接口的状态。 + fn poll(&self); + + /// # `poll_blocking` + /// 用于在阻塞模式下轮询网卡 /// ## 参数 - /// - `sockets` :一个可变引用到 `smoltcp::iface::SocketSet`,表示要轮询的套接字集 + /// - `can_recv_fn` :一个函数指针,用于判断是否可以接收数据 /// ## 返回值 - /// - 成功返回 `Ok(())` - /// - 如果轮询失败,返回 `Err(SystemError::EAGAIN_OR_EWOULDBLOCK)`,表示需要再次尝试或者操作会阻塞 - fn poll(&self); + /// - 该函数不返回任何值,但会在满足条件时阻塞当前线程,直到可以接收数据。 + fn poll_blocking(&self, _can_recv_fn: &dyn Fn() -> bool) {} /// # `update_ip_addrs` /// 用于更新接口的 IP 地址 @@ -229,6 +230,12 @@ impl IfaceCommon { // drop sockets here to avoid deadlock drop(interface); drop(sockets); + // log::info!( + // "polling iface {}, has_events: {}, poll_at: {:?}", + // self.iface_id, + // has_events, + // poll_at + // ); use core::sync::atomic::Ordering; if let Some(instant) = poll_at { diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index 9ef644671..f496a8a7c 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -1,3 +1,6 @@ +use super::bridge::BridgeEnableDevice; +use super::{register_netdevice, NetDeivceState, NetDeviceCommonData, Operstate}; +use super::{Iface, IfaceCommon}; use crate::arch::rand::rand; use crate::driver::base::class::Class; use crate::driver::base::device::bus::Bus; @@ -11,7 +14,10 @@ use 
crate::filesystem::kernfs::KernFSInode; use crate::init::initcall::INITCALL_DEVICE; use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; use crate::libs::spinlock::{SpinLock, SpinLockGuard}; +use crate::libs::wait_queue::WaitQueue; use crate::net::{generate_iface_id, NET_DEVICES}; +use crate::process::ProcessState; +use crate::sched::SchedMode; use alloc::collections::VecDeque; use alloc::fmt::Debug; use alloc::string::{String, ToString}; @@ -25,17 +31,11 @@ use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress, IpCidr}; use system_error::SystemError; use unified_init::macros::unified_init; -use super::bridge::BridgeEnableDevice; -use super::{register_netdevice, NetDeivceState, NetDeviceCommonData, Operstate}; - -use super::{Iface, IfaceCommon}; - -// const DEVICE_NAME: &str = "veth"; - pub struct Veth { name: String, rx_queue: VecDeque>, - peer: Option>>, + /// 对端的 `VethInterface`,在完成数据发送的时候会使用到 + peer: Weak, } impl Veth { @@ -43,36 +43,29 @@ impl Veth { Veth { name, rx_queue: VecDeque::new(), - peer: None, + peer: Weak::new(), } } - pub fn set_peer(&mut self, peer: Arc>) { - self.peer = Some(peer); + pub fn set_peer_iface(&mut self, peer: &Arc) { + self.peer = Arc::downgrade(peer); } pub fn send_to_peer(&self, data: Vec) { - // log::info!("{} sending", self.name); - if let Some(peer) = &self.peer { - let mut peer = peer.lock(); - peer.rx_queue.push_back(data); - // log::info!( - // "{} sending data to peer {}, peer current rx_queue: {:?}", - // self.name, - // peer.name(), - // peer.rx_queue - // ); + if let Some(peer) = self.peer.upgrade() { + let mut peer_veth = peer.driver.force_get_mut().inner.lock_irqsave(); + peer_veth.rx_queue.push_back(data.clone()); + drop(peer_veth); + + // 唤醒对端正在等待的进程 + peer.wait_queue.wakeup(Some(ProcessState::Blocked(true))); } } pub fn recv_from_peer(&mut self) -> Option> { - // log::info!( - // "{} Receiving data from peer, current rx_queue: {:?}", - // self.name, - // self.rx_queue - // ); 
self.rx_queue.pop_front() } + pub fn name(&self) -> &str { &self.name } @@ -84,14 +77,26 @@ pub struct VethDriver { } impl VethDriver { - pub fn new_pair(name1: &str, name2: &str) -> (Self, Self) { + /// # `new_pair` + /// 创建一对虚拟以太网设备(veth pair),用于网络测试 + /// ## 参数 + /// - `name1`: 第一个设备的名称 + /// - `name2`: 第二个设备的名称 + /// ## 返回值 + /// 返回一个元组,包含两个 `VethDriver` 实例,分别对应 + /// 第一个和第二个虚拟以太网设备。 + pub fn new_pair(name1: &str, name2: &str) -> (VethDriver, VethDriver) { let dev1 = Arc::new(SpinLock::new(Veth::new(name1.to_string()))); let dev2 = Arc::new(SpinLock::new(Veth::new(name2.to_string()))); - dev1.lock().set_peer(dev2.clone()); - dev2.lock().set_peer(dev1.clone()); + let driver1 = VethDriver { inner: dev1 }; + let driver2 = VethDriver { inner: dev2 }; + + (driver1, driver2) + } - (VethDriver { inner: dev1 }, VethDriver { inner: dev2 }) + pub fn name(&self) -> String { + self.inner.lock_irqsave().name().to_string() } } @@ -106,7 +111,7 @@ impl phy::TxToken for VethTxToken { { let mut buf = vec![0; len]; let result = f(&mut buf); - self.driver.inner.lock().send_to_peer(buf); + self.driver.inner.lock_irqsave().send_to_peer(buf); result } } @@ -165,7 +170,7 @@ impl phy::Device for VethDriver { &mut self, _timestamp: smoltcp::time::Instant, ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { - let mut guard = self.inner.lock(); + let mut guard = self.inner.lock_irqsave(); guard.recv_from_peer().map(|buf| { // log::info!("VethDriver received data: {:?}", buf); ( @@ -191,20 +196,34 @@ pub struct VethInterface { name: String, driver: VethDriverWarpper, common: IfaceCommon, - inner: SpinLock, + inner: SpinLock, locked_kobj_state: LockedKObjectState, + wait_queue: WaitQueue, } #[derive(Debug)] -pub struct InnerVethInterface { +pub struct VethCommonData { netdevice_common: NetDeviceCommonData, device_common: DeviceCommonData, kobj_common: KObjectCommonData, + peer_veth: Weak, } impl VethInterface { - pub fn new(name: &str, driver: VethDriver) -> Arc { + pub fn 
has_data(&self) -> bool { + let driver = self.driver.force_get_mut(); + let inner = driver.inner.lock_irqsave(); + !inner.rx_queue.is_empty() + } + + #[allow(unused)] + pub fn peer_veth(&self) -> Arc { + self.inner.lock_irqsave().peer_veth.upgrade().unwrap() + } + + pub fn new(driver: VethDriver) -> Arc { let iface_id = generate_iface_id(); + let name = driver.name(); let mac = [ 0x02, 0x00, @@ -224,41 +243,48 @@ impl VethInterface { iface.set_any_ip(true); let device = Arc::new(VethInterface { - name: name.to_string(), + name, driver: VethDriverWarpper(UnsafeCell::new(driver)), common: IfaceCommon::new(iface_id, true, iface), - inner: SpinLock::new(InnerVethInterface { + inner: SpinLock::new(VethCommonData { netdevice_common: NetDeviceCommonData::default(), device_common: DeviceCommonData::default(), kobj_common: KObjectCommonData::default(), + peer_veth: Weak::new(), }), locked_kobj_state: LockedKObjectState::default(), + wait_queue: WaitQueue::default(), }); + device + } - device.set_net_state(NetDeivceState::__LINK_STATE_START); - device.set_operstate(Operstate::IF_OPER_UP); - NET_DEVICES - .write_irqsave() - .insert(device.nic_id(), device.clone()); - // log::debug!( - // "VethInterface created, devices: {:?}", - // NET_DEVICES - // .read() - // .values() - // .map(|d| d.iface_name()) - // .collect::>() - // ); - register_netdevice(device.clone()).expect("register veth device failed"); + pub fn set_peer_iface(&self, peer: &Arc) { + let mut inner = self.inner.lock_irqsave(); + inner.peer_veth = Arc::downgrade(peer); + self.driver.inner.lock_irqsave().set_peer_iface(peer); + } - device + pub fn new_pair(name1: &str, name2: &str) -> (Arc, Arc) { + let (driver1, driver2) = VethDriver::new_pair(name1, name2); + let iface1 = VethInterface::new(driver1); + let iface2 = VethInterface::new(driver2); + + iface1.set_peer_iface(&iface2); + iface2.set_peer_iface(&iface1); + + // log::info!( + // "is connected: {}", + // 
iface1.driver.inner.lock_irqsave().peer.upgrade().is_some() + // ); + (iface1, iface2) } - fn inner(&self) -> SpinLockGuard { - self.inner.lock() + fn inner(&self) -> SpinLockGuard { + self.inner.lock_irqsave() } pub fn update_ip_addrs(&self, addr: IpAddress, cidr: IpCidr) { - let iface = &mut self.common.smol_iface.lock(); + let iface = &mut self.common.smol_iface.lock_irqsave(); iface.update_ip_addrs(|ip_addrs| { ip_addrs.push(cidr).expect("Push ipCidr failed: full"); }); @@ -275,7 +301,7 @@ impl VethInterface { .expect("Add default ipv4 route failed: full"); }); - log::info!("VethInterface {} updated IP address: {}", self.name, addr); + // log::info!("VethInterface {} updated IP address: {}", self.name, addr); } } @@ -283,40 +309,52 @@ impl KObject for VethInterface { fn as_any_ref(&self) -> &dyn core::any::Any { self } + fn set_inode(&self, inode: Option>) { self.inner().kobj_common.kern_inode = inode; } + fn inode(&self) -> Option> { self.inner().kobj_common.kern_inode.clone() } + fn parent(&self) -> Option> { self.inner().kobj_common.parent.clone() } + fn set_parent(&self, parent: Option>) { self.inner().kobj_common.parent = parent; } + fn kset(&self) -> Option> { self.inner().kobj_common.kset.clone() } + fn set_kset(&self, kset: Option>) { self.inner().kobj_common.kset = kset; } + fn kobj_type(&self) -> Option<&'static dyn KObjType> { self.inner().kobj_common.kobj_type } + fn name(&self) -> String { self.name.clone() } + fn set_name(&self, _name: String) {} fn kobj_state(&self) -> RwLockReadGuard { self.locked_kobj_state.read() } + fn kobj_state_mut(&self) -> RwLockWriteGuard { self.locked_kobj_state.write() } + fn set_kobj_state(&self, state: KObjectState) { *self.locked_kobj_state.write() = state; } + fn set_kobj_type(&self, ktype: Option<&'static dyn KObjType>) { self.inner().kobj_common.kobj_type = ktype; } @@ -326,15 +364,19 @@ impl Device for VethInterface { fn dev_type(&self) -> DeviceType { DeviceType::Net } + fn id_table(&self) -> IdTable { 
IdTable::new(self.name.clone(), None) } + fn bus(&self) -> Option> { self.inner().device_common.bus.clone() } + fn set_bus(&self, bus: Option>) { self.inner().device_common.bus = bus; } + fn class(&self) -> Option> { let mut guard = self.inner(); let r = guard.device_common.class.clone()?.upgrade(); @@ -343,9 +385,11 @@ impl Device for VethInterface { } r } + fn set_class(&self, class: Option>) { self.inner().device_common.class = class; } + fn driver(&self) -> Option> { let r = self.inner().device_common.driver.clone()?.upgrade(); if r.is_none() { @@ -353,24 +397,31 @@ impl Device for VethInterface { } r } + fn set_driver(&self, driver: Option>) { self.inner().device_common.driver = driver; } + fn is_dead(&self) -> bool { false } + fn can_match(&self) -> bool { self.inner().device_common.can_match } + fn set_can_match(&self, can_match: bool) { self.inner().device_common.can_match = can_match; } + fn state_synced(&self) -> bool { true } + fn dev_parent(&self) -> Option> { self.inner().device_common.get_parent_weak_or_clear() } + fn set_dev_parent(&self, parent: Option>) { self.inner().device_common.parent = parent; } @@ -380,35 +431,75 @@ impl Iface for VethInterface { fn common(&self) -> &IfaceCommon { &self.common } + fn iface_name(&self) -> String { self.name.clone() } + fn mac(&self) -> EthernetAddress { - if let HardwareAddress::Ethernet(mac) = self.common.smol_iface.lock().hardware_addr() { + if let HardwareAddress::Ethernet(mac) = + self.common.smol_iface.lock_irqsave().hardware_addr() + { mac } else { EthernetAddress([0, 0, 0, 0, 0, 0]) } } + + fn poll_blocking(&self, can_stop_fn: &dyn Fn() -> bool) { + // log::info!("VethInterface {} polling block", self.name); + + loop { + // 检查是否有数据可用 + self.common.poll(self.driver.force_get_mut()); + + let has_data = self.has_data(); + + // 外部 socket 是否可以接收数据,如果是的话就可以退出loop了 + let can_stop = can_stop_fn(); + + if can_stop { + break; + } + + // 没有数据可用时,进入等待队列 + // 如果有数据可用,则直接跳出循环 + if !has_data { + let _ = 
wq_wait_event_interruptible!( + self.wait_queue, + self.has_data() || can_stop_fn(), + {} + ); + } + } + } + fn poll(&self) { + // log::info!("VethInterface {} polling normal", self.name); self.common.poll(self.driver.force_get_mut()); } + fn addr_assign_type(&self) -> u8 { self.inner().netdevice_common.addr_assign_type } + fn net_device_type(&self) -> u16 { self.inner().netdevice_common.net_device_type = 1; self.inner().netdevice_common.net_device_type } + fn net_state(&self) -> NetDeivceState { self.inner().netdevice_common.state } + fn set_net_state(&self, state: NetDeivceState) { self.inner().netdevice_common.state |= state; } + fn operstate(&self) -> Operstate { self.inner().netdevice_common.operstate } + fn set_operstate(&self, state: Operstate) { self.inner().netdevice_common.operstate = state; } @@ -424,26 +515,12 @@ impl BridgeEnableDevice for VethInterface { buf.copy_from_slice(frame); }); } - fn name(&self) -> String { - self.name.clone() - } - fn mac_addr(&self) -> EthernetAddress { - self.mac() - } - // fn bridge_receive(&self, frame: &[u8]) { - // let driver = self.driver.force_get_mut(); - // let token = VethRxToken { - // buffer: frame.to_vec(), - // }; - // } } pub fn veth_probe() { let name1 = "veth0"; let name2 = "veth1"; - let (drv0, drv1) = VethDriver::new_pair(name1, name2); - let iface1 = VethInterface::new(name1, drv0); - let iface2 = VethInterface::new(name2, drv1); + let (iface1, iface2) = VethInterface::new_pair(name1, name2); let addr1 = IpAddress::v4(10, 0, 0, 1); let cidr1 = IpCidr::new(addr1, 24); @@ -452,6 +529,16 @@ pub fn veth_probe() { let addr2 = IpAddress::v4(10, 0, 0, 2); let cidr2 = IpCidr::new(addr2, 24); iface2.update_ip_addrs(addr2, cidr2); + + let turn_on = |a: &Arc| { + a.set_net_state(NetDeivceState::__LINK_STATE_START); + a.set_operstate(Operstate::IF_OPER_UP); + NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); + register_netdevice(a.clone()).expect("register veth device failed"); + }; + + turn_on(&iface1); + 
turn_on(&iface2); } #[unified_init(INITCALL_DEVICE)] diff --git a/kernel/src/driver/net/virtio_net.rs b/kernel/src/driver/net/virtio_net.rs index 07ccce7d4..fc95d90c0 100644 --- a/kernel/src/driver/net/virtio_net.rs +++ b/kernel/src/driver/net/virtio_net.rs @@ -11,7 +11,11 @@ use alloc::{ vec::Vec, }; use log::{debug, error}; -use smoltcp::{iface, phy, wire}; +use smoltcp::{ + iface, + phy::{self, TxToken}, + wire, +}; use unified_init::macros::unified_init; use virtio_drivers::device::net::VirtIONet; @@ -29,7 +33,7 @@ use crate::{ kobject::{KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState}, kset::KSet, }, - net::register_netdevice, + net::{bridge::BridgeEnableDevice, register_netdevice}, virtio::{ irq::virtio_irq_manager, sysfs::{virtio_bus, virtio_device_manager, virtio_driver_manager}, @@ -743,6 +747,16 @@ impl KObject for VirtioInterface { } } +impl BridgeEnableDevice for VirtioInterface { + fn receive_from_bridge(&self, frame: &[u8]) { + let token = VirtioNetToken::new(self.device_inner.force_get_mut().clone(), None); + + token.consume(frame.len(), |buf| { + buf.copy_from_slice(frame); + }); + } +} + #[unified_init(INITCALL_POSTCORE)] fn virtio_net_driver_init() -> Result<(), SystemError> { let driver = VirtIONetDriver::new(); diff --git a/kernel/src/net/socket/inet/datagram/mod.rs b/kernel/src/net/socket/inet/datagram/mod.rs index 5756d70fe..a33fe5347 100644 --- a/kernel/src/net/socket/inet/datagram/mod.rs +++ b/kernel/src/net/socket/inet/datagram/mod.rs @@ -160,6 +160,22 @@ impl UdpSocket { } return event; } + + /// 这个方法会阻塞当前线程,直到有数据可读 + /// 通过 poll_blocking 来等待数据的到来 + pub(self) fn wait_for_recv(&self) { + use crate::sched::SchedMode; + let guard = self.inner.read(); + let inner = guard.as_ref(); + if let UdpInner::Bound(bound) = inner.unwrap() { + let rem = bound.inner().iface().clone(); + drop(guard); + let self_ref = self.self_ref.upgrade().unwrap().clone(); + let can_recv = move || self_ref.can_recv(); + 
rem.poll_blocking(&can_recv); + } + let _ = wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {}); + } } impl Socket for UdpSocket { @@ -225,15 +241,13 @@ impl Socket for UdpSocket { } fn recv(&self, buffer: &mut [u8], flags: PMSG) -> Result { - use crate::sched::SchedMode; - return if self.is_nonblock() || flags.contains(PMSG::DONTWAIT) { self.try_recv(buffer) } else { loop { match self.try_recv(buffer) { Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { - wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {})?; + self.wait_for_recv(); } result => break result, } @@ -248,7 +262,6 @@ impl Socket for UdpSocket { flags: PMSG, address: Option, ) -> Result<(usize, Endpoint), SystemError> { - use crate::sched::SchedMode; // could block io if let Some(endpoint) = address { self.connect(endpoint)?; @@ -260,8 +273,8 @@ impl Socket for UdpSocket { loop { match self.try_recv(buffer) { Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { - wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {})?; - log::debug!("UdpSocket::recv_from: wake up"); + self.wait_for_recv(); + log::info!("UdpSocket::recv_from: wake up"); } result => break result, } diff --git a/user/apps/test-veth/src/main.rs b/user/apps/test-veth/src/main.rs index f78f70751..7c444536e 100644 --- a/user/apps/test-veth/src/main.rs +++ b/user/apps/test-veth/src/main.rs @@ -182,8 +182,8 @@ use std::time::Duration; fn main() -> std::io::Result<()> { // 启动 server 线程 let server_thread = thread::spawn(|| { - let socket = UdpSocket::bind("10.0.0.2:34254") - .expect("Failed to bind to veth1 (10.0.0.2:34254)"); + let socket = + UdpSocket::bind("10.0.0.2:34254").expect("Failed to bind to veth1 (10.0.0.2:34254)"); println!("[server] Listening on 10.0.0.2:34254"); let mut buf = [0; 1024]; From 50f56e86e50679a87afe90f866774e756b3fa613 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Sun, 6 Jul 2025 22:12:57 +0800 Subject: [PATCH 04/36] =?UTF-8?q?feat(routing):=20=E7=AE=80=E5=8D=95?= 
=?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=B7=AF=E7=94=B1=E5=AD=90=E7=B3=BB=E7=BB=9F?= =?UTF-8?q?,=E5=B0=9A=E6=9C=AA=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/net/mod.rs | 1 + kernel/src/net/routing/mod.rs | 139 ++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+) create mode 100644 kernel/src/net/routing/mod.rs diff --git a/kernel/src/net/mod.rs b/kernel/src/net/mod.rs index 805761cc6..7150cecbe 100644 --- a/kernel/src/net/mod.rs +++ b/kernel/src/net/mod.rs @@ -10,6 +10,7 @@ use crate::{driver::net::Iface, libs::rwlock::RwLock}; pub mod net_core; pub mod posix; +pub mod routing; pub mod socket; pub mod syscall; diff --git a/kernel/src/net/routing/mod.rs b/kernel/src/net/routing/mod.rs new file mode 100644 index 000000000..c0b948165 --- /dev/null +++ b/kernel/src/net/routing/mod.rs @@ -0,0 +1,139 @@ +use crate::time::Instant; +use alloc::collections::BTreeMap; +use alloc::vec::Vec; +use smoltcp::wire::{IpAddress, IpCidr}; + +#[derive(Debug, Clone)] +pub struct NextHop { + // 出口接口编号 + pub if_index: u32, + pub via_router: IpAddress, +} + +#[derive(Debug, Clone)] +pub struct RouteEntry { + pub cidr: IpCidr, + pub next_hops: Vec, + + // None 表示永久有效 + pub prefer_until: Option, + pub expired_at: Option, +} + +#[derive(Debug)] +pub struct RouteTable { + pub table_id: u32, + pub entries: BTreeMap, +} + +impl RouteTable { + pub fn new(table_id: u32) -> Self { + RouteTable { + table_id, + entries: BTreeMap::new(), + } + } + + pub fn add_route(&mut self, cidr: IpCidr, entry: RouteEntry) { + self.entries.insert(cidr, entry); + } + + pub fn del_route(&mut self, cidr: &IpCidr) { + self.entries.remove(cidr); + } + + pub fn lookup(&self, ip: &IpAddress, now: Instant) -> Option<&NextHop> { + self.entries + .iter() + .filter(|(cidr, entry)| { + cidr.contains_addr(ip) && entry.expired_at.map_or(true, |t| now <= t) + }) + .max_by_key(|(cidr, _entry)| cidr.prefix_len()) + 
.and_then(|(_cidr, entry)| entry.next_hops.first()) + } +} + +pub struct RoutingSubsystem { + pub route_tables: Vec, + pub rules: Vec, +} + +impl RoutingSubsystem { + pub fn new() -> Self { + RoutingSubsystem { + route_tables: Vec::new(), + rules: Vec::new(), + } + } + + pub fn get_table_mut(&mut self, table_id: u32) -> Option<&mut RouteTable> { + self.route_tables + .iter_mut() + .find(|t| t.table_id == table_id) + } + + pub fn add_route_table(&mut self, table: RouteTable) { + self.route_tables.push(table); + } + + pub fn add_routing_rule(&mut self, rule: RoutingRule) { + self.rules.push(rule); + } + + pub fn lookup_route(&self, packet: &PacketMeta) -> Option<&NextHop> { + if let Some(rule) = self + .rules + .iter() + .filter(|r| r.matches(packet)) + .min_by_key(|r| r.priority) + { + return self + .route_tables + .iter() + .find(|t| t.table_id == rule.table_id) + .and_then(|t| t.lookup(&packet.dst_ip, Instant::now())); + } + None + } +} + +#[derive(Debug, Clone)] +pub struct RoutingRule { + pub from: Option, + pub tos: Option, + pub fwmark: Option, + pub table_id: u32, + // 匹配优先级,数字越小优先匹配 + pub priority: u32, +} + +pub struct PacketMeta { + pub src_ip: IpAddress, + pub dst_ip: IpAddress, + pub tos: u8, + pub fwmark: u32, +} + +impl RoutingRule { + pub fn matches(&self, packet: &PacketMeta) -> bool { + if let Some(ref from) = self.from { + if !from.contains_addr(&packet.src_ip) { + return false; + } + } + + if let Some(tos) = self.tos { + if packet.tos != tos { + return false; + } + } + + if let Some(fwmark) = self.fwmark { + if packet.fwmark != fwmark { + return false; + } + } + + true + } +} From ceb6a12c54905c7a1c04a8a349f03923c6e2dca4 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Wed, 9 Jul 2025 17:26:09 +0800 Subject: [PATCH 05/36] =?UTF-8?q?feat(veth):=20=E5=A2=9E=E5=8A=A0veth?= =?UTF-8?q?=E9=BB=98=E8=AE=A4=E5=AF=B9=E7=AB=AF=E8=B7=AF=E7=94=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky 
--- kernel/src/driver/net/veth.rs | 48 +++++++++++---- kernel/src/driver/net/virtio_net.rs | 18 +----- kernel/src/net/socket/inet/datagram/mod.rs | 4 +- user/apps/test-veth/src/main.rs | 69 ++++++++++++++++++++++ 4 files changed, 109 insertions(+), 30 deletions(-) diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index f496a8a7c..b6dea69ba 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -5,7 +5,7 @@ use crate::arch::rand::rand; use crate::driver::base::class::Class; use crate::driver::base::device::bus::Bus; use crate::driver::base::device::driver::Driver; -use crate::driver::base::device::{Device, DeviceCommonData, DeviceType, IdTable}; +use crate::driver::base::device::{self, DeviceCommonData, DeviceType, IdTable}; use crate::driver::base::kobject::{ KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState, }; @@ -26,7 +26,7 @@ use alloc::vec::Vec; use core::cell::UnsafeCell; use core::ops::{Deref, DerefMut}; use smoltcp::phy::DeviceCapabilities; -use smoltcp::phy::{self, TxToken}; +use smoltcp::phy::{self, RxToken, TxToken}; use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress, IpCidr}; use system_error::SystemError; use unified_init::macros::unified_init; @@ -58,7 +58,7 @@ impl Veth { drop(peer_veth); // 唤醒对端正在等待的进程 - peer.wait_queue.wakeup(Some(ProcessState::Blocked(true))); + peer.wake_up(); } } @@ -120,7 +120,7 @@ pub struct VethRxToken { buffer: Vec, } -impl phy::RxToken for VethRxToken { +impl RxToken for VethRxToken { fn consume(self, f: F) -> R where F: FnOnce(&[u8]) -> R, @@ -190,7 +190,7 @@ impl phy::Device for VethDriver { } #[cast_to([sync] Iface)] -#[cast_to([sync] Device)] +#[cast_to([sync] device::Device)] #[derive(Debug)] pub struct VethInterface { name: String, @@ -303,6 +303,24 @@ impl VethInterface { // log::info!("VethInterface {} updated IP address: {}", self.name, addr); } + + pub fn add_default_route_to_peer(&self, peer_ip: IpAddress) { + let iface = &mut 
self.common.smol_iface.lock_irqsave(); + iface.routes_mut().update(|routes_map| { + routes_map + .push(smoltcp::iface::Route { + cidr: IpCidr::new(IpAddress::v4(0, 0, 0, 0), 0), + via_router: peer_ip, + preferred_until: None, + expires_at: None, + }) + .expect("Add default route to peer failed"); + }); + } + + pub fn wake_up(&self) { + self.wait_queue.wakeup(Some(ProcessState::Blocked(true))); + } } impl KObject for VethInterface { @@ -360,7 +378,7 @@ impl KObject for VethInterface { } } -impl Device for VethInterface { +impl device::Device for VethInterface { fn dev_type(&self) -> DeviceType { DeviceType::Net } @@ -418,11 +436,11 @@ impl Device for VethInterface { true } - fn dev_parent(&self) -> Option> { + fn dev_parent(&self) -> Option> { self.inner().device_common.get_parent_weak_or_clear() } - fn set_dev_parent(&self, parent: Option>) { + fn set_dev_parent(&self, parent: Option>) { self.inner().device_common.parent = parent; } } @@ -517,9 +535,9 @@ impl BridgeEnableDevice for VethInterface { } } -pub fn veth_probe() { - let name1 = "veth0"; - let name2 = "veth1"; +pub fn veth_probe(name1: &str, name2: &str) -> (Arc, Arc) { + // let name1 = "veth0"; + // let name2 = "veth1"; let (iface1, iface2) = VethInterface::new_pair(name1, name2); let addr1 = IpAddress::v4(10, 0, 0, 1); @@ -530,6 +548,10 @@ pub fn veth_probe() { let cidr2 = IpCidr::new(addr2, 24); iface2.update_ip_addrs(addr2, cidr2); + // 添加默认路由 + iface1.add_default_route_to_peer(addr2); + iface2.add_default_route_to_peer(addr1); + let turn_on = |a: &Arc| { a.set_net_state(NetDeivceState::__LINK_STATE_START); a.set_operstate(Operstate::IF_OPER_UP); @@ -539,11 +561,13 @@ pub fn veth_probe() { turn_on(&iface1); turn_on(&iface2); + + (iface1, iface2) } #[unified_init(INITCALL_DEVICE)] pub fn veth_init() -> Result<(), SystemError> { - veth_probe(); + veth_probe("veth0", "veth1"); log::info!("Veth pair initialized."); Ok(()) } diff --git a/kernel/src/driver/net/virtio_net.rs 
b/kernel/src/driver/net/virtio_net.rs index fc95d90c0..07ccce7d4 100644 --- a/kernel/src/driver/net/virtio_net.rs +++ b/kernel/src/driver/net/virtio_net.rs @@ -11,11 +11,7 @@ use alloc::{ vec::Vec, }; use log::{debug, error}; -use smoltcp::{ - iface, - phy::{self, TxToken}, - wire, -}; +use smoltcp::{iface, phy, wire}; use unified_init::macros::unified_init; use virtio_drivers::device::net::VirtIONet; @@ -33,7 +29,7 @@ use crate::{ kobject::{KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState}, kset::KSet, }, - net::{bridge::BridgeEnableDevice, register_netdevice}, + net::register_netdevice, virtio::{ irq::virtio_irq_manager, sysfs::{virtio_bus, virtio_device_manager, virtio_driver_manager}, @@ -747,16 +743,6 @@ impl KObject for VirtioInterface { } } -impl BridgeEnableDevice for VirtioInterface { - fn receive_from_bridge(&self, frame: &[u8]) { - let token = VirtioNetToken::new(self.device_inner.force_get_mut().clone(), None); - - token.consume(frame.len(), |buf| { - buf.copy_from_slice(frame); - }); - } -} - #[unified_init(INITCALL_POSTCORE)] fn virtio_net_driver_init() -> Result<(), SystemError> { let driver = VirtIONetDriver::new(); diff --git a/kernel/src/net/socket/inet/datagram/mod.rs b/kernel/src/net/socket/inet/datagram/mod.rs index a33fe5347..bf0744521 100644 --- a/kernel/src/net/socket/inet/datagram/mod.rs +++ b/kernel/src/net/socket/inet/datagram/mod.rs @@ -164,7 +164,7 @@ impl UdpSocket { /// 这个方法会阻塞当前线程,直到有数据可读 /// 通过 poll_blocking 来等待数据的到来 pub(self) fn wait_for_recv(&self) { - use crate::sched::SchedMode; + // use crate::sched::SchedMode; let guard = self.inner.read(); let inner = guard.as_ref(); if let UdpInner::Bound(bound) = inner.unwrap() { @@ -174,7 +174,7 @@ impl UdpSocket { let can_recv = move || self_ref.can_recv(); rem.poll_blocking(&can_recv); } - let _ = wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {}); + // let _ = wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {}); } } diff --git 
a/user/apps/test-veth/src/main.rs b/user/apps/test-veth/src/main.rs index 7c444536e..c2a0b3e12 100644 --- a/user/apps/test-veth/src/main.rs +++ b/user/apps/test-veth/src/main.rs @@ -238,3 +238,72 @@ fn main() -> std::io::Result<()> { Ok(()) } + + +//bridge + + +// use std::net::UdpSocket; +// use std::str; +// use std::thread; +// use std::time::Duration; + +// fn main() -> std::io::Result<()> { +// // 启动 server 线程 +// let server_thread = thread::spawn(|| { +// let socket = +// UdpSocket::bind("200.0.0.2:34254").expect("Failed to bind to veth_d (200.0.0.2:34254)"); +// println!("[server] Listening on 200.0.0.2:34254"); + +// let mut buf = [0; 1024]; +// let (amt, src) = socket +// .recv_from(&mut buf) +// .expect("[server] Failed to receive"); + +// let received_msg = str::from_utf8(&buf[..amt]).expect("Invalid UTF-8"); + +// println!("[server] Received from {}: {}", src, received_msg); + +// socket +// .send_to(received_msg.as_bytes(), src) +// .expect("[server] Failed to send back"); +// println!("[server] Echoed back the message"); +// }); + +// // 确保 server 已启动(可根据情况适当 sleep) +// thread::sleep(Duration::from_millis(200)); + +// // 启动 client +// let client_thread = thread::spawn(|| { +// let socket = UdpSocket::bind("100.0.0.1:0").expect("Failed to bind to veth_a (100.0.0.1)"); +// socket +// .connect("200.0.0.2:34254") +// .expect("Failed to connect to 200.0.0.2:34254"); + +// let msg = "Hello from veth1!"; +// socket +// .send(msg.as_bytes()) +// .expect("[client] Failed to send"); + +// println!("[client] Sent: {}", msg); + +// let mut buf = [0; 1024]; +// let (amt, _src) = socket +// .recv_from(&mut buf) +// .expect("[client] Failed to receive"); + +// let received_msg = str::from_utf8(&buf[..amt]).expect("Invalid UTF-8"); + +// println!("[client] Received echo: {}", received_msg); + +// assert_eq!(msg, received_msg, "[client] Mismatch in echo!"); +// }); + +// // 等待两个线程结束 +// server_thread.join().unwrap(); +// client_thread.join().unwrap(); + +// 
println!("\n✅ Test completed: veth0 <--> veth1 UDP communication success"); + +// Ok(()) +// } From bae4463b531134bb9c828165ad29f57cc45d584a Mon Sep 17 00:00:00 2001 From: sparkzky Date: Wed, 9 Jul 2025 21:43:43 +0800 Subject: [PATCH 06/36] =?UTF-8?q?feat(socket):=20=E6=81=A2=E5=A4=8Dudp=20s?= =?UTF-8?q?ocket=E4=B8=AD=E7=9A=84wait=5Fqueue=E7=AD=89=E5=BE=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/net/socket/inet/datagram/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernel/src/net/socket/inet/datagram/mod.rs b/kernel/src/net/socket/inet/datagram/mod.rs index bf0744521..c7a425167 100644 --- a/kernel/src/net/socket/inet/datagram/mod.rs +++ b/kernel/src/net/socket/inet/datagram/mod.rs @@ -174,7 +174,7 @@ impl UdpSocket { let can_recv = move || self_ref.can_recv(); rem.poll_blocking(&can_recv); } - // let _ = wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {}); + let _ = wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {}); } } From df05c6b80402bfeb7b5fd9a29b2e8d638bf1aefe Mon Sep 17 00:00:00 2001 From: sparkzky Date: Fri, 11 Jul 2025 12:57:23 +0800 Subject: [PATCH 07/36] =?UTF-8?q?feat(net):=20=E8=A1=A5=E5=85=85bridge?= =?UTF-8?q?=E7=9A=84=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/bridge.rs | 358 +++++++++++++++++---- kernel/src/driver/net/veth.rs | 152 ++++++--- kernel/src/net/socket/inet/datagram/mod.rs | 2 +- user/apps/test-veth/src/main.rs | 2 - 4 files changed, 403 insertions(+), 111 deletions(-) diff --git a/kernel/src/driver/net/bridge.rs b/kernel/src/driver/net/bridge.rs index da93a0a32..2c4fd7dbb 100644 --- a/kernel/src/driver/net/bridge.rs +++ b/kernel/src/driver/net/bridge.rs @@ -1,24 +1,41 @@ use crate::{ - driver::net::Iface, - libs::{rwlock::RwLock, spinlock::SpinLock}, + 
driver::net::{register_netdevice, veth::VethInterface, Iface, NetDeivceState, Operstate}, + init::initcall::INITCALL_DEVICE, + libs::{rwlock::RwLock, spinlock::SpinLock, wait_queue::WaitQueue}, + net::NET_DEVICES, + process::{ + kthread::{KernelThreadClosure, KernelThreadMechanism}, + ProcessState, + }, time::Instant, }; +use alloc::boxed::Box; +use alloc::collections::VecDeque; +use alloc::sync::Weak; +use alloc::vec::Vec; use alloc::{collections::BTreeMap, string::String, sync::Arc}; +use core::{panic, sync::atomic::AtomicUsize}; use hashbrown::HashMap; -use smoltcp::wire::EthernetAddress; +use smoltcp::wire::{EthernetAddress, EthernetFrame, IpAddress, IpCidr}; +use system_error::SystemError; +use unified_init::macros::unified_init; -const MAC_ENTRY_TIMEOUT: u64 = 60_000; // 60秒 +/// MAC地址表老化时间 +const MAC_ENTRY_TIMEOUT: u64 = 300_000; // 5分钟 +pub type BridgePortId = usize; + +#[derive(Debug)] struct MacEntry { - port: Arc, + port_id: BridgePortId, pub(self) record: RwLock, // 存活时间(动态学习的老化) } impl MacEntry { - pub fn new(port: Arc) -> Self { + pub fn new(port: BridgePortId) -> Self { MacEntry { - port, + port_id: port, record: RwLock::new(MacEntryRecord { last_seen: Instant::now(), }), @@ -31,39 +48,55 @@ impl MacEntry { } } +#[derive(Debug)] struct MacEntryRecord { last_seen: Instant, } /// 代表一个加入bridge的网络接口 -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct BridgePort { - bridge_enable: Arc, - bridge: BridgeDriver, + pub id: BridgePortId, + pub(super) bridge_enable: Arc, + //先把这里直接改成driver,去除weak,忽略循环依赖 + pub(super) bridge_driver: Weak, // 当前接口状态?forwarding, learning, blocking? 
// mac mtu信息 } impl BridgePort { - fn new(device: Arc, bridge: BridgeDriver) -> Self { + fn new( + id: BridgePortId, + device: Arc, + bridge: &Arc, + ) -> Self { BridgePort { + id, bridge_enable: device, - bridge, + bridge_driver: Arc::downgrade(bridge), } } - fn mac(&self) -> EthernetAddress { - self.bridge_enable.mac() - } + // fn mac(&self) -> EthernetAddress { + // self.bridge_enable.mac() + // } } +type ReceivedFrame = (BridgePortId, Vec); + +#[derive(Debug)] pub struct Bridge { name: String, // 端口列表,key为MAC地址 - ports: BTreeMap>, + ports: BTreeMap, // FDB(Forwarding Database) mac_table: HashMap, // 配置参数,比如aging timeout, max age, hello time, forward delay + // bridge_mac: EthernetAddress, + next_port_id: AtomicUsize, + wait_queue: Arc, + + rx_buf: VecDeque, } impl Bridge { @@ -72,113 +105,310 @@ impl Bridge { name: name.into(), ports: BTreeMap::new(), mac_table: HashMap::new(), + next_port_id: AtomicUsize::new(0), + wait_queue: Arc::new(WaitQueue::default()), + rx_buf: VecDeque::new(), } } - pub fn add_port(&mut self, port: Arc) { - self.ports.insert(port.mac(), port); + fn next_port_id(&self) -> BridgePortId { + self.next_port_id + .fetch_add(1, core::sync::atomic::Ordering::Relaxed) } - pub fn insert_macentry(&mut self, src_mac: EthernetAddress, port: Arc) { - self.mac_table.insert(src_mac, MacEntry::new(port)); + pub fn add_port(&mut self, id: BridgePortId, port: BridgePort) { + self.ports.insert(id, port); } - pub fn handle_frame( - &mut self, - ingress_port: Arc, - frame: &[u8], - dst_mac: EthernetAddress, - src_mac: EthernetAddress, - ) { + pub fn remove_port(&mut self, port_id: BridgePortId) { + self.ports.remove(&port_id); + // 清理MAC地址表中与该端口相关的条目 + self.mac_table + .retain(|_mac, entry| entry.port_id != port_id); + } + + fn insert_or_update_mac_entry(&mut self, src_mac: EthernetAddress, port_id: BridgePortId) { if let Some(entry) = self.mac_table.get(&src_mac) { entry.update_last_seen(); + // 如果 MAC 地址学习到了不同的端口,需要更新 + if entry.port_id != port_id { + // 
log::info!("Bridge {}: MAC {} moved from port {} to port {}", self.name, src_mac, entry.port_id, port_id); + self.mac_table.insert(src_mac, MacEntry::new(port_id)); + } } else { - // MAC 学习 - self.insert_macentry(src_mac, ingress_port.clone()); + // log::info!("Bridge {}: Learned MAC {} on port {}", self.name, src_mac, port_id); + self.mac_table.insert(src_mac, MacEntry::new(port_id)); + } + } + + pub fn handle_frame(&mut self, ingress_port_id: BridgePortId, frame: &[u8]) { + if frame.len() < 14 { + // 使用 smoltcp 提供的最小长度 + // log::warn!("Bridge {}: Received malformed Ethernet frame (too short).", self.name); + return; } + let ether_frame = match EthernetFrame::new_checked(frame) { + Ok(f) => f, + Err(_) => { + // log::warn!("Bridge {}: Received malformed Ethernet frame.", self.name); + return; + } + }; + + let dst_mac = ether_frame.dst_addr(); + let src_mac = ether_frame.src_addr(); + + self.insert_or_update_mac_entry(src_mac, ingress_port_id); + if dst_mac.is_broadcast() { - // 广播 - self.flood(ingress_port.mac(), frame); + // 广播 这里有可能是arp请求 + self.flood(None, frame); } else { // 单播 if let Some(entry) = self.mac_table.get(&dst_mac) { - let target_port = &entry.port; + let target_port = entry.port_id; // 避免发回自己 - if !Arc::ptr_eq(target_port, &ingress_port) { - Bridge::transmit_to(target_port, frame); - } + // if target_port != ingress_port_id { + self.transmit_to_port(target_port, frame); + // } } else { // 未知单播 → 广播 - self.flood(ingress_port.mac(), frame); + log::info!("unknown unicast, flooding frame"); + self.flood(Some(ingress_port_id), frame); } } self.sweep_mac_table(); } - fn flood(&self, except_mac: EthernetAddress, frame: &[u8]) { - for (mac, port) in self.ports.iter() { - if mac != &except_mac { - Bridge::transmit_to(port, frame); + fn flood(&self, except_port_id: Option, frame: &[u8]) { + match except_port_id { + Some(except_id) => { + for (&port_id, bridge_port) in &self.ports { + if port_id != except_id { + self.transmit_to_device(bridge_port, frame); + 
} + } } + None => { + for bridge_port in self.ports.values() { + self.transmit_to_device(bridge_port, frame); + } + } + } + } + + fn transmit_to_port(&self, target_port_id: BridgePortId, frame: &[u8]) { + if let Some(device_arc) = self.ports.get(&target_port_id) { + self.transmit_to_device(device_arc, frame); + } else { + // log::warn!("Bridge {}: Attempted to transmit to non-existent port ID {}", self.name, target_port_id); } } - fn transmit_to(port: &BridgePort, frame: &[u8]) { - port.bridge_enable.receive_from_bridge(frame); + fn transmit_to_device(&self, device: &BridgePort, frame: &[u8]) { + device.bridge_enable.receive_from_bridge(frame); } pub fn sweep_mac_table(&mut self) { let now = Instant::now(); self.mac_table.retain(|_mac, entry| { now.duration_since(entry.record.read().last_seen) - .unwrap() + .unwrap_or_default() .total_millis() < MAC_ENTRY_TIMEOUT }); } + + // pub fn poll_blocking(&mut self) { + // use crate::sched::SchedMode; + // loop { + // let opt = self.rx_buf.pop_front(); + // if let Some((port_id, frame)) = opt { + // self.handle_frame(port_id, &frame); + // } else { + // log::info!("Bridge is going to sleep"); + // let _ = wq_wait_event_interruptible!(self.wait_queue, !self.rx_buf.is_empty(), {}); + // } + // } + // } } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct BridgeDriver { pub inner: Arc>, + wait_queue: Arc, } impl BridgeDriver { pub fn new(name: &str) -> Self { - BridgeDriver { - inner: Arc::new(SpinLock::new(Bridge::new(name))), - } + let inner = Arc::new(SpinLock::new(Bridge::new(name))); + let wait_queue = inner.lock().wait_queue.clone(); + + let driver = BridgeDriver { inner, wait_queue }; + + // let closure: Box i32 + Send + Sync + 'static> = Box::new(move || { + // driver_clone.poll_blocking(); + // 0 + // }); + // let closure = KernelThreadClosure::EmptyClosure((closure, ())); + // let name = format!("bridge_{}", name); + // let _pcb = KernelThreadMechanism::create_and_run(closure, name) + // .ok_or("") + // 
.expect("create bridge_poll thread failed"); + + driver } - pub fn add_port(&self, port: Arc) { - let port = BridgePort::new(port, self.clone()); + pub fn add_port(&self, port: BridgePort) { + log::info!("Adding port with id: {}", port.id); - let bridge_port = Arc::new(port); - self.inner.lock().add_port(bridge_port.clone()); + self.inner.lock().add_port(port.id, port); } - pub fn handle_frame(&self, ingress_port: Arc, frame: &[u8]) { - if frame.len() < 14 { - return; // 非法以太网帧 + pub fn remove_port(&self, port_id: BridgePortId) { + self.inner.lock().remove_port(port_id); + } + + fn poll_blocking(&self) { + use crate::sched::SchedMode; + + loop { + let mut inner = self.inner.lock_irqsave(); + + let opt = inner.rx_buf.pop_front(); + if let Some((port_id, frame)) = opt { + inner.handle_frame(port_id, &frame); + } else { + drop(inner); + log::info!("Bridge is going to sleep"); + let _ = wq_wait_event_interruptible!( + self.wait_queue, + !self.inner.lock().rx_buf.is_empty(), + {} + ); + } } + // inner.poll_blocking(); + } + + pub fn enqueue_frame(&self, port_id: BridgePortId, frame: &Vec) { + { + let mut bridge = self.inner.lock(); + log::info!("Enqueuing frame on port {}: {:?}", port_id, frame); + log::warn!("{:?}", frame); + bridge.rx_buf.push_back((port_id, frame.clone())); + } + self.wait_queue.wakeup(Some(ProcessState::Blocked(true))); + } +} + +pub struct BridgeIface { + pub driver: BridgeDriver, + self_ref: Weak, +} + +impl BridgeIface { + pub fn new(driver: BridgeDriver) -> Arc { + let name = driver.inner.lock().name.clone(); + + let iface = Arc::new_cyclic(|me| BridgeIface { + driver, + self_ref: me.clone(), + }); + let iface_clone = iface.clone(); + + // 创建一个线程来处理桥接设备的轮询 + let closure: Box i32 + Send + Sync + 'static> = Box::new(move || { + iface_clone.poll_blocking(); + 0 + }); + let closure = KernelThreadClosure::EmptyClosure((closure, ())); + let name = format!("bridge_{}", name); + let _pcb = KernelThreadMechanism::create_and_run(closure, name) + 
.ok_or("") + .expect("create bridge_poll thread failed"); + + iface + } + + pub fn add_port(&self, port_device: Arc) { + let id = self.driver.inner.lock().next_port_id(); + let port = BridgePort::new(id, port_device.clone(), &self.self_ref.upgrade().unwrap()); + + port_device.set_common_bridge_data(port.clone()); + + self.driver.add_port(port); + } - let dst_mac = EthernetAddress::from_bytes(&frame[0..6]); - let src_mac = EthernetAddress::from_bytes(&frame[6..12]); - //todo Frame::new_unchecked + pub fn remove_port(&self, port_id: BridgePortId) { + self.driver.remove_port(port_id); + } - self.inner - .lock() - .handle_frame(ingress_port, frame, dst_mac, src_mac); + pub fn poll_blocking(&self) { + self.driver.poll_blocking(); } } /// 可供桥接设备应该实现的 trait pub trait BridgeEnableDevice: Iface { fn receive_from_bridge(&self, frame: &[u8]); - fn transmit_to_bridge(&self, frame: &[u8]) { - // 默认实现,子类可以覆盖 - self.receive_from_bridge(frame); - } + // fn inner_driver(&self) -> Arc; + fn set_common_bridge_data(&self, _port: BridgePort) {} +} + +fn bridge_probe() { + let (iface1, iface2) = VethInterface::new_pair("veth_a", "veth_b"); + let (iface3, iface4) = VethInterface::new_pair("veth_c", "veth_d"); + + let addr1 = IpAddress::v4(100, 0, 0, 1); + let cidr1 = IpCidr::new(addr1, 24); + let addr2 = IpAddress::v4(100, 0, 0, 2); + let cidr2 = IpCidr::new(addr1, 24); + + let addr3 = IpAddress::v4(200, 0, 0, 1); + let cidr3 = IpCidr::new(addr3, 24); + let addr4 = IpAddress::v4(200, 0, 0, 2); + let cidr4 = IpCidr::new(addr4, 24); + + iface1.update_ip_addrs(cidr1); + iface2.update_ip_addrs(cidr2); + iface3.update_ip_addrs(cidr3); + iface4.update_ip_addrs(cidr4); + + iface1.add_default_route_to_peer(addr2); + iface2.add_default_route_to_peer(addr1); + iface3.add_default_route_to_peer(addr4); + iface4.add_default_route_to_peer(addr3); + + // iface1.add_direct_route(cidr4, addr2); + + let turn_on = |a: &Arc| { + a.set_net_state(NetDeivceState::__LINK_STATE_START); + 
a.set_operstate(Operstate::IF_OPER_UP); + NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); + register_netdevice(a.clone()).expect("register veth device failed"); + }; + + turn_on(&iface1); + turn_on(&iface2); + turn_on(&iface3); + turn_on(&iface4); + + let bridge = BridgeDriver::new("bridge0"); + let iface = BridgeIface::new(bridge); + + // BRIDGE_DEVICES.write_irqsave().push(bridge.clone()); + log::info!("Bridge device created"); + + iface.add_port(iface3); + iface.add_port(iface2); +} + +#[unified_init(INITCALL_DEVICE)] +pub fn bridge_init() -> Result<(), SystemError> { + bridge_probe(); + log::info!("bridge initialized."); + Ok(()) } diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index b6dea69ba..9119f6c7d 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -10,6 +10,7 @@ use crate::driver::base::kobject::{ KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState, }; use crate::driver::base::kset::KSet; +use crate::driver::net::bridge::BridgePort; use crate::filesystem::kernfs::KernFSInode; use crate::init::initcall::INITCALL_DEVICE; use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; @@ -26,7 +27,7 @@ use alloc::vec::Vec; use core::cell::UnsafeCell; use core::ops::{Deref, DerefMut}; use smoltcp::phy::DeviceCapabilities; -use smoltcp::phy::{self, RxToken, TxToken}; +use smoltcp::phy::{self, RxToken}; use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress, IpCidr}; use system_error::SystemError; use unified_init::macros::unified_init; @@ -51,18 +52,39 @@ impl Veth { self.peer = Arc::downgrade(peer); } - pub fn send_to_peer(&self, data: Vec) { + pub fn send_to_peer(&self, data: &Vec) { if let Some(peer) = self.peer.upgrade() { - let mut peer_veth = peer.driver.force_get_mut().inner.lock_irqsave(); - peer_veth.rx_queue.push_back(data.clone()); - drop(peer_veth); + // log::info!("Veth {} trying to send", self.name); - // 唤醒对端正在等待的进程 - peer.wake_up(); + if let 
Some(bridge_data) = peer.inner.lock().bridge_port_data.as_ref() { + // log::info!("Veth {} sending data to bridge", self.name); + Self::to_bridge(bridge_data, data); + return; + } + + Self::to_peer(&peer, data); } } + pub(self) fn to_peer(peer: &Arc, data: &[u8]) { + let mut peer_veth = peer.driver.force_get_mut().inner.lock_irqsave(); + peer_veth.rx_queue.push_back(data.to_vec()); + // log::info!("DATA RECEIVED: {:?}", peer_veth.rx_queue); + drop(peer_veth); + + // 唤醒对端正在等待的进程 + peer.wake_up(); + } + + fn to_bridge(bridge_data: &BridgePort, data: &Vec) { + if let Some(bridge_driver) = bridge_data.bridge_driver.upgrade() { + // log::info!("Veth {} sending data to bridge", self.name); + bridge_driver.driver.enqueue_frame(bridge_data.id, data); + }; + } + pub fn recv_from_peer(&mut self) -> Option> { + // log::info!("Veth {} trying to receive", self.name); self.rx_queue.pop_front() } @@ -111,7 +133,7 @@ impl phy::TxToken for VethTxToken { { let mut buf = vec![0; len]; let result = f(&mut buf); - self.driver.inner.lock_irqsave().send_to_peer(buf); + self.driver.inner.lock_irqsave().send_to_peer(&buf); result } } @@ -207,6 +229,20 @@ pub struct VethCommonData { device_common: DeviceCommonData, kobj_common: KObjectCommonData, peer_veth: Weak, + + bridge_port_data: Option, +} + +impl Default for VethCommonData { + fn default() -> Self { + VethCommonData { + netdevice_common: NetDeviceCommonData::default(), + device_common: DeviceCommonData::default(), + kobj_common: KObjectCommonData::default(), + peer_veth: Weak::new(), + bridge_port_data: None, + } + } } impl VethInterface { @@ -246,15 +282,12 @@ impl VethInterface { name, driver: VethDriverWarpper(UnsafeCell::new(driver)), common: IfaceCommon::new(iface_id, true, iface), - inner: SpinLock::new(VethCommonData { - netdevice_common: NetDeviceCommonData::default(), - device_common: DeviceCommonData::default(), - kobj_common: KObjectCommonData::default(), - peer_veth: Weak::new(), - }), + inner: 
SpinLock::new(VethCommonData::default()), locked_kobj_state: LockedKObjectState::default(), wait_queue: WaitQueue::default(), }); + + // log::info!("VethInterface {} created with ID {}", device.name, iface_id); device } @@ -272,10 +305,6 @@ impl VethInterface { iface1.set_peer_iface(&iface2); iface2.set_peer_iface(&iface1); - // log::info!( - // "is connected: {}", - // iface1.driver.inner.lock_irqsave().peer.upgrade().is_some() - // ); (iface1, iface2) } @@ -283,29 +312,37 @@ impl VethInterface { self.inner.lock_irqsave() } - pub fn update_ip_addrs(&self, addr: IpAddress, cidr: IpCidr) { + /// # `update_ip_addrs` + /// 更新虚拟以太网设备的 IP 地址 + /// ## 参数 + /// - `cidr`: 要添加的 IP 地址和子网掩码 + /// ## 描述 + /// 该方法会将指定的 IP 地址添加到虚拟以太网设备的 IP 地址列表中。 + /// 如果添加失败(例如列表已满),则会触发 panic。 + pub fn update_ip_addrs(&self, cidr: IpCidr) { let iface = &mut self.common.smol_iface.lock_irqsave(); iface.update_ip_addrs(|ip_addrs| { ip_addrs.push(cidr).expect("Push ipCidr failed: full"); }); - // 默认路由 - iface.routes_mut().update(|routes_map| { - routes_map - .push(smoltcp::iface::Route { - cidr, - via_router: addr, - preferred_until: None, - expires_at: None, - }) - .expect("Add default ipv4 route failed: full"); - }); - // log::info!("VethInterface {} updated IP address: {}", self.name, addr); } + /// # `add_default_route_to_peer` + /// 添加默认路由到对端虚拟以太网设备 + /// ## 参数 + /// - `peer_ip`: 对端设备的 IP 地址 + /// ## 描述 + /// 该方法会在当前虚拟以太网设备的路由表中 + /// 添加一条默认路由, + /// 指向对端虚拟以太网设备的 IP 地址。 + /// 如果添加失败,则会触发 panic。 + /// pub fn add_default_route_to_peer(&self, peer_ip: IpAddress) { let iface = &mut self.common.smol_iface.lock_irqsave(); + // iface.update_ip_addrs(|ip_addrs| { + // ip_addrs.push(self_cidr).expect("Push ipCidr failed: full"); + // }); iface.routes_mut().update(|routes_map| { routes_map .push(smoltcp::iface::Route { @@ -318,6 +355,20 @@ impl VethInterface { }); } + // pub fn add_direct_route(&self, cidr: IpCidr, via_router: IpAddress) { + // let iface = &mut self.common.smol_iface.lock_irqsave(); + 
// iface.routes_mut().update(|routes_map| { + // routes_map + // .push(smoltcp::iface::Route { + // cidr, + // via_router, + // preferred_until: None, + // expires_at: None, + // }) + // .expect("Add direct route failed"); + // }); + // } + pub fn wake_up(&self) { self.wait_queue.wakeup(Some(ProcessState::Blocked(true))); } @@ -465,7 +516,7 @@ impl Iface for VethInterface { } fn poll_blocking(&self, can_stop_fn: &dyn Fn() -> bool) { - // log::info!("VethInterface {} polling block", self.name); + log::info!("VethInterface {} polling block", self.name); loop { // 检查是否有数据可用 @@ -482,6 +533,7 @@ impl Iface for VethInterface { // 没有数据可用时,进入等待队列 // 如果有数据可用,则直接跳出循环 + log::info!("VethInterface {} waiting for data", self.name); if !has_data { let _ = wq_wait_event_interruptible!( self.wait_queue, @@ -493,7 +545,7 @@ impl Iface for VethInterface { } fn poll(&self) { - // log::info!("VethInterface {} polling normal", self.name); + log::info!("VethInterface {} polling normal", self.name); self.common.poll(self.driver.force_get_mut()); } @@ -525,28 +577,40 @@ impl Iface for VethInterface { impl BridgeEnableDevice for VethInterface { fn receive_from_bridge(&self, frame: &[u8]) { - let driver = self.driver.force_get_mut(); - let token = VethTxToken { - driver: driver.clone(), - }; - token.consume(frame.len(), |buf| { - buf.copy_from_slice(frame); - }); + log::info!("VethInterface {} received from bridge", self.name); + + let inner = self.inner.lock_irqsave(); + + if let Some(_data) = inner.bridge_port_data.as_ref() { + log::info!("VethInterface {} sending data to peer", self.name); + + // Veth::to_peer(&peer, frame); + self.driver + .inner + .lock_irqsave() + .rx_queue + .push_back(frame.to_vec()); + self.poll(); + } + } + + fn set_common_bridge_data(&self, port: BridgePort) { + // log::info!("Now set bridge port data for {}", self.name); + let mut inner = self.inner.lock_irqsave(); + inner.bridge_port_data = Some(port); } } pub fn veth_probe(name1: &str, name2: &str) -> (Arc, 
Arc) { - // let name1 = "veth0"; - // let name2 = "veth1"; let (iface1, iface2) = VethInterface::new_pair(name1, name2); let addr1 = IpAddress::v4(10, 0, 0, 1); let cidr1 = IpCidr::new(addr1, 24); - iface1.update_ip_addrs(addr1, cidr1); + iface1.update_ip_addrs(cidr1); let addr2 = IpAddress::v4(10, 0, 0, 2); let cidr2 = IpCidr::new(addr2, 24); - iface2.update_ip_addrs(addr2, cidr2); + iface2.update_ip_addrs(cidr2); // 添加默认路由 iface1.add_default_route_to_peer(addr2); diff --git a/kernel/src/net/socket/inet/datagram/mod.rs b/kernel/src/net/socket/inet/datagram/mod.rs index c7a425167..a33fe5347 100644 --- a/kernel/src/net/socket/inet/datagram/mod.rs +++ b/kernel/src/net/socket/inet/datagram/mod.rs @@ -164,7 +164,7 @@ impl UdpSocket { /// 这个方法会阻塞当前线程,直到有数据可读 /// 通过 poll_blocking 来等待数据的到来 pub(self) fn wait_for_recv(&self) { - // use crate::sched::SchedMode; + use crate::sched::SchedMode; let guard = self.inner.read(); let inner = guard.as_ref(); if let UdpInner::Bound(bound) = inner.unwrap() { diff --git a/user/apps/test-veth/src/main.rs b/user/apps/test-veth/src/main.rs index c2a0b3e12..d7f1ff3dd 100644 --- a/user/apps/test-veth/src/main.rs +++ b/user/apps/test-veth/src/main.rs @@ -239,10 +239,8 @@ fn main() -> std::io::Result<()> { Ok(()) } - //bridge - // use std::net::UdpSocket; // use std::str; // use std::thread; From 63c22329e0393958be35ae28e1fca9c85421f2a9 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Sat, 12 Jul 2025 21:54:17 +0800 Subject: [PATCH 08/36] =?UTF-8?q?feat(bridge):=20=E6=9B=B4=E6=94=B9?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=A8=8B=E5=BA=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/bridge.rs | 12 +-- kernel/src/driver/net/veth.rs | 24 +++--- user/apps/test-veth/src/main.rs | 146 ++++++++++++++++---------------- 3 files changed, 93 insertions(+), 89 deletions(-) diff --git a/kernel/src/driver/net/bridge.rs b/kernel/src/driver/net/bridge.rs index 
2c4fd7dbb..395e9ed67 100644 --- a/kernel/src/driver/net/bridge.rs +++ b/kernel/src/driver/net/bridge.rs @@ -14,7 +14,7 @@ use alloc::collections::VecDeque; use alloc::sync::Weak; use alloc::vec::Vec; use alloc::{collections::BTreeMap, string::String, sync::Arc}; -use core::{panic, sync::atomic::AtomicUsize}; +use core::sync::atomic::AtomicUsize; use hashbrown::HashMap; use smoltcp::wire::{EthernetAddress, EthernetFrame, IpAddress, IpCidr}; use system_error::SystemError; @@ -163,7 +163,7 @@ impl Bridge { if dst_mac.is_broadcast() { // 广播 这里有可能是arp请求 - self.flood(None, frame); + self.flood(Some(ingress_port_id), frame); } else { // 单播 if let Some(entry) = self.mac_table.get(&dst_mac) { @@ -362,14 +362,14 @@ fn bridge_probe() { let (iface1, iface2) = VethInterface::new_pair("veth_a", "veth_b"); let (iface3, iface4) = VethInterface::new_pair("veth_c", "veth_d"); - let addr1 = IpAddress::v4(100, 0, 0, 1); + let addr1 = IpAddress::v4(200, 0, 0, 1); let cidr1 = IpCidr::new(addr1, 24); - let addr2 = IpAddress::v4(100, 0, 0, 2); + let addr2 = IpAddress::v4(200, 0, 0, 2); let cidr2 = IpCidr::new(addr1, 24); - let addr3 = IpAddress::v4(200, 0, 0, 1); + let addr3 = IpAddress::v4(200, 0, 0, 3); let cidr3 = IpCidr::new(addr3, 24); - let addr4 = IpAddress::v4(200, 0, 0, 2); + let addr4 = IpAddress::v4(200, 0, 0, 4); let cidr4 = IpCidr::new(addr4, 24); iface1.update_ip_addrs(cidr1); diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index 9119f6c7d..287303d10 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -69,7 +69,8 @@ impl Veth { pub(self) fn to_peer(peer: &Arc, data: &[u8]) { let mut peer_veth = peer.driver.force_get_mut().inner.lock_irqsave(); peer_veth.rx_queue.push_back(data.to_vec()); - // log::info!("DATA RECEIVED: {:?}", peer_veth.rx_queue); + log::info!("Veth {} received data from peer", peer.name); + log::info!("DATA RECEIVED: {:?}", peer_veth.rx_queue); drop(peer_veth); // 唤醒对端正在等待的进程 @@ -578,20 +579,23 @@ impl 
Iface for VethInterface { impl BridgeEnableDevice for VethInterface { fn receive_from_bridge(&self, frame: &[u8]) { log::info!("VethInterface {} received from bridge", self.name); + let peer = self.peer_veth(); - let inner = self.inner.lock_irqsave(); + // let inner = self.inner.lock_irqsave(); - if let Some(_data) = inner.bridge_port_data.as_ref() { + if let Some(_data) = self.inner.lock_irqsave().bridge_port_data.as_ref() { log::info!("VethInterface {} sending data to peer", self.name); - // Veth::to_peer(&peer, frame); - self.driver - .inner - .lock_irqsave() - .rx_queue - .push_back(frame.to_vec()); - self.poll(); + // let peer = self.peer_veth(); + Veth::to_peer(&peer, frame); + // self.driver + // .inner + // .lock_irqsave() + // .rx_queue + // .push_back(frame.to_vec()); + // peer.poll(); } + log::info!("returning"); } fn set_common_bridge_data(&self, port: BridgePort) { diff --git a/user/apps/test-veth/src/main.rs b/user/apps/test-veth/src/main.rs index d7f1ff3dd..5a2d9e629 100644 --- a/user/apps/test-veth/src/main.rs +++ b/user/apps/test-veth/src/main.rs @@ -174,73 +174,6 @@ // } // } -use std::net::UdpSocket; -use std::str; -use std::thread; -use std::time::Duration; - -fn main() -> std::io::Result<()> { - // 启动 server 线程 - let server_thread = thread::spawn(|| { - let socket = - UdpSocket::bind("10.0.0.2:34254").expect("Failed to bind to veth1 (10.0.0.2:34254)"); - println!("[server] Listening on 10.0.0.2:34254"); - - let mut buf = [0; 1024]; - let (amt, src) = socket - .recv_from(&mut buf) - .expect("[server] Failed to receive"); - - let received_msg = str::from_utf8(&buf[..amt]).expect("Invalid UTF-8"); - - println!("[server] Received from {}: {}", src, received_msg); - - socket - .send_to(received_msg.as_bytes(), src) - .expect("[server] Failed to send back"); - println!("[server] Echoed back the message"); - }); - - // 确保 server 已启动(可根据情况适当 sleep) - thread::sleep(Duration::from_millis(200)); - - // 启动 client - let client_thread = thread::spawn(|| { - 
let socket = UdpSocket::bind("10.0.0.1:0").expect("Failed to bind to veth0 (10.0.0.1)"); - socket - .connect("10.0.0.2:34254") - .expect("Failed to connect to 10.0.0.2:34254"); - - let msg = "Hello from veth0!"; - socket - .send(msg.as_bytes()) - .expect("[client] Failed to send"); - - println!("[client] Sent: {}", msg); - - let mut buf = [0; 1024]; - let (amt, _src) = socket - .recv_from(&mut buf) - .expect("[client] Failed to receive"); - - let received_msg = str::from_utf8(&buf[..amt]).expect("Invalid UTF-8"); - - println!("[client] Received echo: {}", received_msg); - - assert_eq!(msg, received_msg, "[client] Mismatch in echo!"); - }); - - // 等待两个线程结束 - server_thread.join().unwrap(); - client_thread.join().unwrap(); - - println!("\n✅ Test completed: veth0 <--> veth1 UDP communication success"); - - Ok(()) -} - -//bridge - // use std::net::UdpSocket; // use std::str; // use std::thread; @@ -250,8 +183,8 @@ fn main() -> std::io::Result<()> { // // 启动 server 线程 // let server_thread = thread::spawn(|| { // let socket = -// UdpSocket::bind("200.0.0.2:34254").expect("Failed to bind to veth_d (200.0.0.2:34254)"); -// println!("[server] Listening on 200.0.0.2:34254"); +// UdpSocket::bind("10.0.0.2:34254").expect("Failed to bind to veth1 (10.0.0.2:34254)"); +// println!("[server] Listening on 10.0.0.2:34254"); // let mut buf = [0; 1024]; // let (amt, src) = socket @@ -273,12 +206,12 @@ fn main() -> std::io::Result<()> { // // 启动 client // let client_thread = thread::spawn(|| { -// let socket = UdpSocket::bind("100.0.0.1:0").expect("Failed to bind to veth_a (100.0.0.1)"); +// let socket = UdpSocket::bind("10.0.0.1:0").expect("Failed to bind to veth0 (10.0.0.1)"); // socket -// .connect("200.0.0.2:34254") -// .expect("Failed to connect to 200.0.0.2:34254"); +// .connect("10.0.0.2:34254") +// .expect("Failed to connect to 10.0.0.2:34254"); -// let msg = "Hello from veth1!"; +// let msg = "Hello from veth0!"; // socket // .send(msg.as_bytes()) // .expect("[client] Failed to 
send"); @@ -305,3 +238,70 @@ fn main() -> std::io::Result<()> { // Ok(()) // } + +//bridge + +use std::net::UdpSocket; +use std::str; +use std::thread; +use std::time::Duration; + +fn main() -> std::io::Result<()> { + // 启动 server 线程 + let server_thread = thread::spawn(|| { + let socket = + UdpSocket::bind("200.0.0.4:34254").expect("Failed to bind to veth_d (200.0.0.4:34254)"); + println!("[server] Listening on 200.0.0.4:34254"); + + let mut buf = [0; 1024]; + let (amt, src) = socket + .recv_from(&mut buf) + .expect("[server] Failed to receive"); + + let received_msg = str::from_utf8(&buf[..amt]).expect("Invalid UTF-8"); + + println!("[server] Received from {}: {}", src, received_msg); + + socket + .send_to(received_msg.as_bytes(), src) + .expect("[server] Failed to send back"); + println!("[server] Echoed back the message"); + }); + + // 确保 server 已启动(可根据情况适当 sleep) + thread::sleep(Duration::from_millis(200)); + + // 启动 client + let client_thread = thread::spawn(|| { + let socket = UdpSocket::bind("200.0.0.1:0").expect("Failed to bind to veth_a (200.0.0.1)"); + socket + .connect("200.0.0.4:34254") + .expect("Failed to connect to 200.0.0.4:34254"); + + let msg = "Hello from veth1!"; + socket + .send(msg.as_bytes()) + .expect("[client] Failed to send"); + + println!("[client] Sent: {}", msg); + + let mut buf = [0; 1024]; + let (amt, _src) = socket + .recv_from(&mut buf) + .expect("[client] Failed to receive"); + + let received_msg = str::from_utf8(&buf[..amt]).expect("Invalid UTF-8"); + + println!("[client] Received echo: {}", received_msg); + + assert_eq!(msg, received_msg, "[client] Mismatch in echo!"); + }); + + // 等待两个线程结束 + server_thread.join().unwrap(); + client_thread.join().unwrap(); + + println!("\n✅ Test completed: veth0 <--> veth1 UDP communication success"); + + Ok(()) +} From a1400d2f4de1733742519a61dfa8a4ad6fd635bb Mon Sep 17 00:00:00 2001 From: sparkzky Date: Sun, 13 Jul 2025 11:42:51 +0800 Subject: [PATCH 09/36] 
=?UTF-8?q?feat:=20=E9=87=8D=E5=91=BD=E5=90=8D?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=A8=8B=E5=BA=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- user/apps/{test-veth => test-veth-bridge}/.gitignore | 0 user/apps/{test-veth => test-veth-bridge}/Cargo.toml | 0 user/apps/{test-veth => test-veth-bridge}/Makefile | 0 user/apps/{test-veth => test-veth-bridge}/src/main.rs | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename user/apps/{test-veth => test-veth-bridge}/.gitignore (100%) rename user/apps/{test-veth => test-veth-bridge}/Cargo.toml (100%) rename user/apps/{test-veth => test-veth-bridge}/Makefile (100%) rename user/apps/{test-veth => test-veth-bridge}/src/main.rs (100%) diff --git a/user/apps/test-veth/.gitignore b/user/apps/test-veth-bridge/.gitignore similarity index 100% rename from user/apps/test-veth/.gitignore rename to user/apps/test-veth-bridge/.gitignore diff --git a/user/apps/test-veth/Cargo.toml b/user/apps/test-veth-bridge/Cargo.toml similarity index 100% rename from user/apps/test-veth/Cargo.toml rename to user/apps/test-veth-bridge/Cargo.toml diff --git a/user/apps/test-veth/Makefile b/user/apps/test-veth-bridge/Makefile similarity index 100% rename from user/apps/test-veth/Makefile rename to user/apps/test-veth-bridge/Makefile diff --git a/user/apps/test-veth/src/main.rs b/user/apps/test-veth-bridge/src/main.rs similarity index 100% rename from user/apps/test-veth/src/main.rs rename to user/apps/test-veth-bridge/src/main.rs From 09f2a59911f6b85e4a09b8679f3d56106ed513fa Mon Sep 17 00:00:00 2001 From: sparkzky Date: Fri, 18 Jul 2025 21:00:31 +0800 Subject: [PATCH 10/36] =?UTF-8?q?feat:=20=E6=9B=B4=E6=94=B9veth&beidge?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=A8=8B=E5=BA=8F=E7=9A=84toml?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- user/apps/test-veth-bridge/Cargo.toml | 2 +- .../{test_veth_0_1_0.toml => 
test_veth_bridge_0_1_0.toml} | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) rename user/dadk/config/{test_veth_0_1_0.toml => test_veth_bridge_0_1_0.toml} (90%) diff --git a/user/apps/test-veth-bridge/Cargo.toml b/user/apps/test-veth-bridge/Cargo.toml index d60adbf3c..77dbe7a0c 100644 --- a/user/apps/test-veth-bridge/Cargo.toml +++ b/user/apps/test-veth-bridge/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "test-veth" +name = "test-veth-bridge" version = "0.1.0" edition = "2021" description = "测试veth pair" diff --git a/user/dadk/config/test_veth_0_1_0.toml b/user/dadk/config/test_veth_bridge_0_1_0.toml similarity index 90% rename from user/dadk/config/test_veth_0_1_0.toml rename to user/dadk/config/test_veth_bridge_0_1_0.toml index f8571b15f..fc8d58738 100644 --- a/user/dadk/config/test_veth_0_1_0.toml +++ b/user/dadk/config/test_veth_bridge_0_1_0.toml @@ -1,9 +1,9 @@ # 用户程序名称 -name = "test-veth" +name = "test-veth-bridge" # 版本号 version = "0.1.0" # 用户程序描述信息 -description = "test for veth interface" +description = "test for veth and bridge" # (可选)默认: false 是否只构建一次,如果为true,DADK会在构建成功后,将构建结果缓存起来,下次构建时,直接使用缓存的构建结果 build-once = false # (可选) 默认: false 是否只安装一次,如果为true,DADK会在安装成功后,不再重复安装 @@ -21,7 +21,7 @@ type = "build-from-source" # "install_from_prebuilt" 可选值:"local", "archive" source = "local" # 路径或URL -source-path = "user/apps/test-veth" +source-path = "user/apps/test-veth-bridge" # 构建相关信息 [build] # (可选)构建命令 From 3190b105ce2152afcc9a662079e5ee594ab635ad Mon Sep 17 00:00:00 2001 From: sparkzky Date: Sun, 3 Aug 2025 16:49:59 +0800 Subject: [PATCH 11/36] =?UTF-8?q?feat:=20=E6=9A=82=E6=97=B6=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0route=5Fiface=E4=BB=A5=E5=8F=8Aroute=5Ftable?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/bridge.rs | 1 + kernel/src/driver/net/mod.rs | 1 + kernel/src/driver/net/route_iface.rs | 464 ++++++++++++++++++++++++ kernel/src/net/routing/routing_table.rs | 
61 ++++ user/apps/test-router/.gitignore | 3 + user/apps/test-router/Cargo.toml | 12 + user/apps/test-router/Makefile | 56 +++ user/apps/test-router/src/main.rs | 3 + 8 files changed, 601 insertions(+) create mode 100644 kernel/src/driver/net/route_iface.rs create mode 100644 kernel/src/net/routing/routing_table.rs create mode 100644 user/apps/test-router/.gitignore create mode 100644 user/apps/test-router/Cargo.toml create mode 100644 user/apps/test-router/Makefile create mode 100644 user/apps/test-router/src/main.rs diff --git a/kernel/src/driver/net/bridge.rs b/kernel/src/driver/net/bridge.rs index 395e9ed67..aa2681351 100644 --- a/kernel/src/driver/net/bridge.rs +++ b/kernel/src/driver/net/bridge.rs @@ -342,6 +342,7 @@ impl BridgeIface { self.driver.add_port(port); } + #[allow(unused)] pub fn remove_port(&self, port_id: BridgePortId) { self.driver.remove_port(port_id); } diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index 0950f0324..b2906de88 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -17,6 +17,7 @@ pub mod e1000e; pub mod irq_handle; pub mod kthread; pub mod loopback; +pub mod route_iface; pub mod sysfs; pub mod veth; pub mod virtio_net; diff --git a/kernel/src/driver/net/route_iface.rs b/kernel/src/driver/net/route_iface.rs new file mode 100644 index 000000000..a27d0749e --- /dev/null +++ b/kernel/src/driver/net/route_iface.rs @@ -0,0 +1,464 @@ +use super::{Iface, IfaceCommon}; +use super::{NetDeivceState, NetDeviceCommonData, Operstate}; +use crate::arch::rand::rand; +use crate::driver::base::class::Class; +use crate::driver::base::device::bus::Bus; +use crate::driver::base::device::driver::Driver; +use crate::driver::base::device::{Device, DeviceCommonData, DeviceType, IdTable}; +use crate::driver::base::kobject::{ + KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState, +}; +use crate::driver::base::kset::KSet; +use crate::filesystem::kernfs::KernFSInode; +use 
crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; +use crate::libs::spinlock::{SpinLock, SpinLockGuard}; +use crate::net::generate_iface_id; +use crate::net::routing::router::Router; +use crate::time::Instant; +use alloc::collections::VecDeque; +use alloc::fmt::Debug; +use alloc::string::{String, ToString}; +use alloc::sync::{Arc, Weak}; +use alloc::vec::Vec; +use core::cell::UnsafeCell; +use core::ops::{Deref, DerefMut}; +use smoltcp::phy::DeviceCapabilities; +use smoltcp::wire::{EthernetAddress, HardwareAddress}; +use smoltcp::{ + phy::{self}, + wire::{IpAddress, IpCidr}, +}; + +pub struct RouteRxToken { + driver: RouteDriver, + buffer: Vec, +} + +impl phy::RxToken for RouteRxToken { + fn consume(self, f: F) -> R + where + F: FnOnce(&[u8]) -> R, + { + f(self.buffer.as_slice()) + } +} + +pub struct RouteTxToken { + driver: RouteDriver, +} + +impl phy::TxToken for RouteTxToken { + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut buffer = vec![0; len]; + let result = f(buffer.as_mut_slice()); + let mut device = self.driver.inner.lock(); + device.route_transmit(buffer); + result + } +} + +pub struct Route { + name: String, + queue: VecDeque>, +} + +impl Route { + pub fn new(name: &str) -> Self { + let queue = VecDeque::new(); + Route { + name: name.to_string(), + queue, + } + } + + pub fn route_receive(&mut self) -> Vec { + let buffer = self.queue.pop_front(); + match buffer { + Some(buffer) => { + return buffer; + } + None => { + return Vec::new(); + } + } + } + + pub fn route_transmit(&mut self, buffer: Vec) { + self.queue.push_back(buffer) + } +} + +#[derive(Debug)] +struct RouteDriverWapper(UnsafeCell); +unsafe impl Send for RouteDriverWapper {} +unsafe impl Sync for RouteDriverWapper {} + +impl Deref for RouteDriverWapper { + type Target = RouteDriver; + fn deref(&self) -> &Self::Target { + unsafe { &*self.0.get() } + } +} +impl DerefMut for RouteDriverWapper { + fn deref_mut(&mut self) -> &mut Self::Target { + 
unsafe { &mut *self.0.get() } + } +} + +impl RouteDriverWapper { + #[allow(clippy::mut_from_ref)] + #[allow(clippy::mut_from_ref)] + fn force_get_mut(&self) -> &mut RouteDriver { + unsafe { &mut *self.0.get() } + } +} + +pub struct RouteDriver { + pub inner: Arc>, + pub router: Weak, +} + +impl RouteDriver { + pub fn new(name: &str) -> Self { + let inner = Arc::new(SpinLock::new(Route::new(name))); + RouteDriver { + inner, + router: Weak::default(), + } + } + + pub fn name(&self) -> String { + self.inner.lock().name.clone() + } + + pub fn attach_router(&mut self, router: Arc) { + self.router = Arc::downgrade(&router); + } +} + +impl Clone for RouteDriver { + fn clone(&self) -> Self { + RouteDriver { + inner: self.inner.clone(), + router: self.router.clone(), + } + } +} + +impl phy::Device for RouteDriver { + type RxToken<'a> + = RouteRxToken + where + Self: 'a; + type TxToken<'a> + = RouteTxToken + where + Self: 'a; + + fn capabilities(&self) -> phy::DeviceCapabilities { + let mut caps = DeviceCapabilities::default(); + caps.max_transmission_unit = 1500; + caps.medium = smoltcp::phy::Medium::Ethernet; + caps + } + + fn receive( + &mut self, + _timestamp: smoltcp::time::Instant, + ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let buffer = self.inner.lock().route_receive(); + + if let Some(router) = self.router.upgrade() { + router.recv_from_iface(buffer); + return None; + } + + if buffer.is_empty() { + return Option::None; + } + let rx = RouteRxToken { + driver: self.clone(), + buffer, + }; + let tx = RouteTxToken { + driver: self.clone(), + }; + return Option::Some((rx, tx)); + } + + fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option> { + Some(RouteTxToken { + driver: self.clone(), + }) + } +} + +#[cast_to([sync] Iface)] +#[cast_to([sync] crate::driver::base::device::Device)] +#[derive(Debug)] +pub struct RouteInterface { + name: String, + driver: RouteDriverWapper, + common: IfaceCommon, + inner: SpinLock, + locked_kobj_state: 
LockedKObjectState, +} + +#[derive(Debug)] +pub struct InnerRouteInterface { + netdevice_common: NetDeviceCommonData, + device_common: DeviceCommonData, + kobj_common: KObjectCommonData, + + router: Weak, +} + +impl Default for InnerRouteInterface { + fn default() -> Self { + InnerRouteInterface { + netdevice_common: NetDeviceCommonData::default(), + device_common: DeviceCommonData::default(), + kobj_common: KObjectCommonData::default(), + router: Weak::default(), + } + } +} + +impl RouteInterface { + pub fn new(driver: RouteDriver) -> Arc { + let iface_id = generate_iface_id(); + + let mac = [ + 0x03, + 0x00, + 0x00, + 0x00, + (iface_id >> 8) as u8, + iface_id as u8, + ]; + + let hw_addr = HardwareAddress::Ethernet(EthernetAddress(mac)); + let mut iface_config = smoltcp::iface::Config::new(hw_addr); + iface_config.random_seed = rand() as u64; + let mut iface = smoltcp::iface::Interface::new( + iface_config, + &mut driver.clone(), + crate::time::Instant::now().into(), + ); + iface.set_any_ip(true); + + Arc::new(RouteInterface { + name: driver.name(), + driver: RouteDriverWapper(UnsafeCell::new(driver)), + common: IfaceCommon::new(iface_id, false, iface), + inner: SpinLock::new(InnerRouteInterface::default()), + locked_kobj_state: LockedKObjectState::default(), + }) + } + + pub fn attach_router(&self, router: Arc) { + self.inner().router = Arc::downgrade(&router); + self.driver.force_get_mut().attach_router(router); + } + + pub fn update_ip_addrs(&self, cidr: IpCidr) { + let iface = &mut self.common.smol_iface.lock_irqsave(); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs.push(cidr).expect("Push ipCidr failed: full"); + }); + } + + pub fn add_default_route_to_peer(&self, ip: IpAddress) { + let iface = &mut self.common.smol_iface.lock_irqsave(); + + iface.routes_mut().update(|routes_map| { + routes_map + .push(smoltcp::iface::Route { + cidr: IpCidr::new(IpAddress::v4(0, 0, 0, 0), 0), + via_router: ip, + preferred_until: None, + expires_at: None, + }) + .expect("Add 
default route to peer failed"); + }); + } + + fn inner(&self) -> SpinLockGuard { + return self.inner.lock(); + } + + // pub fn send()!!! +} + +impl KObject for RouteInterface { + fn as_any_ref(&self) -> &dyn core::any::Any { + self + } + + fn set_inode(&self, inode: Option>) { + self.inner().kobj_common.kern_inode = inode; + } + + fn inode(&self) -> Option> { + self.inner().kobj_common.kern_inode.clone() + } + + fn parent(&self) -> Option> { + self.inner().kobj_common.parent.clone() + } + + fn set_parent(&self, parent: Option>) { + self.inner().kobj_common.parent = parent; + } + + fn kset(&self) -> Option> { + self.inner().kobj_common.kset.clone() + } + + fn set_kset(&self, kset: Option>) { + self.inner().kobj_common.kset = kset; + } + + fn kobj_type(&self) -> Option<&'static dyn KObjType> { + self.inner().kobj_common.kobj_type + } + + fn name(&self) -> String { + self.name.clone() + } + + fn set_name(&self, _name: String) { + // do nothing + } + + fn kobj_state(&self) -> RwLockReadGuard { + self.locked_kobj_state.read() + } + + fn kobj_state_mut(&self) -> RwLockWriteGuard { + self.locked_kobj_state.write() + } + + fn set_kobj_state(&self, state: KObjectState) { + *self.locked_kobj_state.write() = state; + } + + fn set_kobj_type(&self, ktype: Option<&'static dyn KObjType>) { + self.inner().kobj_common.kobj_type = ktype; + } +} + +impl Device for RouteInterface { + fn dev_type(&self) -> DeviceType { + DeviceType::Net + } + + fn id_table(&self) -> IdTable { + IdTable::new(self.name.clone(), None) + } + + fn bus(&self) -> Option> { + self.inner().device_common.bus.clone() + } + + fn set_bus(&self, bus: Option>) { + self.inner().device_common.bus = bus; + } + + fn class(&self) -> Option> { + let mut guard = self.inner(); + let r = guard.device_common.class.clone()?.upgrade(); + if r.is_none() { + guard.device_common.class = None; + } + + return r; + } + + fn set_class(&self, class: Option>) { + self.inner().device_common.class = class; + } + + fn driver(&self) -> 
Option> { + let r = self.inner().device_common.driver.clone()?.upgrade(); + if r.is_none() { + self.inner().device_common.driver = None; + } + + return r; + } + + fn set_driver(&self, driver: Option>) { + self.inner().device_common.driver = driver; + } + + fn is_dead(&self) -> bool { + false + } + + fn can_match(&self) -> bool { + self.inner().device_common.can_match + } + + fn set_can_match(&self, can_match: bool) { + self.inner().device_common.can_match = can_match; + } + + fn state_synced(&self) -> bool { + true + } + + fn dev_parent(&self) -> Option> { + self.inner().device_common.get_parent_weak_or_clear() + } + + fn set_dev_parent(&self, parent: Option>) { + self.inner().device_common.parent = parent; + } +} + +impl Iface for RouteInterface { + fn common(&self) -> &IfaceCommon { + &self.common + } + + fn iface_name(&self) -> String { + self.name.clone() + } + + fn mac(&self) -> smoltcp::wire::EthernetAddress { + let mac = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + smoltcp::wire::EthernetAddress(mac) + } + + fn poll(&self) { + self.common.poll(self.driver.force_get_mut()) + } + + fn addr_assign_type(&self) -> u8 { + return self.inner().netdevice_common.addr_assign_type; + } + + fn net_device_type(&self) -> u16 { + return self.inner().netdevice_common.net_device_type; + } + + fn net_state(&self) -> NetDeivceState { + return self.inner().netdevice_common.state; + } + + fn set_net_state(&self, state: NetDeivceState) { + self.inner().netdevice_common.state |= state; + } + + fn operstate(&self) -> Operstate { + return self.inner().netdevice_common.operstate; + } + + fn set_operstate(&self, state: Operstate) { + self.inner().netdevice_common.operstate = state; + } +} diff --git a/kernel/src/net/routing/routing_table.rs b/kernel/src/net/routing/routing_table.rs new file mode 100644 index 000000000..5b6d4f584 --- /dev/null +++ b/kernel/src/net/routing/routing_table.rs @@ -0,0 +1,61 @@ +use core::sync::atomic::AtomicU32; + +use crate::time::Instant; +use alloc::vec::Vec; 
+use smoltcp::wire::{IpAddress, IpCidr}; + +static DEFAULT_TABLE_ID: AtomicU32 = AtomicU32::new(0); + +fn generate_table_id() -> u32 { + DEFAULT_TABLE_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed) +} + +#[derive(Debug, Clone)] +pub struct NextHop { + // 出口接口编号 + pub if_index: usize, + pub via_router: IpAddress, +} + +#[derive(Debug, Clone)] +pub struct RouteEntry { + pub destination: IpCidr, + pub next_hop: NextHop, + + // None 表示永久有效 + pub prefer_until: Option, + pub expired_at: Option, + + /// 度量值,暂时未用到 + pub metric: u32, +} + +#[derive(Debug, Default)] +pub struct RouteTable { + pub table_id: u32, + // pub entries: BTreeMap, + entries: Vec, +} + +impl RouteTable { + pub fn new() -> Self { + RouteTable { + table_id: generate_table_id(), + entries: Vec::new(), + } + } + + pub fn add_route(&mut self, entry: RouteEntry) { + self.entries.push(entry); + self.entries + .sort_by(|a, b| b.destination.prefix_len().cmp(&a.destination.prefix_len())); + } + + /// 根据目的IP地址查找最佳匹配的路由条目(最长前缀匹配)。 + pub fn lookup_route(&self, dest_ip: IpAddress) -> Option<&RouteEntry> { + self.entries + .iter() + .filter(|entry| entry.destination.contains_addr(&dest_ip)) + .max_by_key(|entry| entry.destination.prefix_len()) // 最长前缀匹配 + } +} diff --git a/user/apps/test-router/.gitignore b/user/apps/test-router/.gitignore new file mode 100644 index 000000000..1ac354611 --- /dev/null +++ b/user/apps/test-router/.gitignore @@ -0,0 +1,3 @@ +/target +Cargo.lock +/install/ \ No newline at end of file diff --git a/user/apps/test-router/Cargo.toml b/user/apps/test-router/Cargo.toml new file mode 100644 index 000000000..79599a1ca --- /dev/null +++ b/user/apps/test-router/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "test-router" +version = "0.1.0" +edition = "2021" +description = "测试路由功能" +authors = [ "sparkzky " ] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +smoltcp = { version = "0.9", features = ["medium-ethernet"] } 
+etherparse = "0.13" \ No newline at end of file diff --git a/user/apps/test-router/Makefile b/user/apps/test-router/Makefile new file mode 100644 index 000000000..7522ea16c --- /dev/null +++ b/user/apps/test-router/Makefile @@ -0,0 +1,56 @@ +TOOLCHAIN= +RUSTFLAGS= + +ifdef DADK_CURRENT_BUILD_DIR +# 如果是在dadk中编译,那么安装到dadk的安装目录中 + INSTALL_DIR = $(DADK_CURRENT_BUILD_DIR) +else +# 如果是在本地编译,那么安装到当前目录下的install目录中 + INSTALL_DIR = ./install +endif + +ifeq ($(ARCH), x86_64) + export RUST_TARGET=x86_64-unknown-linux-musl +else ifeq ($(ARCH), riscv64) + export RUST_TARGET=riscv64gc-unknown-linux-gnu +else +# 默认为x86_86,用于本地编译 + export RUST_TARGET=x86_64-unknown-linux-musl +endif + +run: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) + +build: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) + +clean: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) + +test: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) + +doc: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) doc --target $(RUST_TARGET) + +fmt: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt + +fmt-check: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt --check + +run-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) --release + +build-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) --release + +clean-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) --release + +test-release: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) --release + +.PHONY: install +install: + RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) install --target $(RUST_TARGET) --path . 
--no-track --root $(INSTALL_DIR) --force diff --git a/user/apps/test-router/src/main.rs b/user/apps/test-router/src/main.rs new file mode 100644 index 000000000..aec1f5653 --- /dev/null +++ b/user/apps/test-router/src/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello, world DragonOS-musl"); +} From c2dbd1549012b38a0efa9a9e0ae102ef358fe4f3 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Sun, 3 Aug 2025 16:51:01 +0800 Subject: [PATCH 12/36] feat: draft router Signed-off-by: sparkzky --- kernel/src/driver/net/route_iface.rs | 162 +++++++++++-- kernel/src/net/routing/mod.rs | 298 +++++++++++++----------- kernel/src/net/routing/router.rs | 189 +++++++++++++++ kernel/src/net/routing/routing_table.rs | 43 +++- 4 files changed, 520 insertions(+), 172 deletions(-) create mode 100644 kernel/src/net/routing/router.rs diff --git a/kernel/src/driver/net/route_iface.rs b/kernel/src/driver/net/route_iface.rs index a27d0749e..1f400b541 100644 --- a/kernel/src/driver/net/route_iface.rs +++ b/kernel/src/driver/net/route_iface.rs @@ -14,13 +14,13 @@ use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; use crate::libs::spinlock::{SpinLock, SpinLockGuard}; use crate::net::generate_iface_id; use crate::net::routing::router::Router; -use crate::time::Instant; use alloc::collections::VecDeque; use alloc::fmt::Debug; use alloc::string::{String, ToString}; use alloc::sync::{Arc, Weak}; use alloc::vec::Vec; use core::cell::UnsafeCell; +use core::net::Ipv4Addr; use core::ops::{Deref, DerefMut}; use smoltcp::phy::DeviceCapabilities; use smoltcp::wire::{EthernetAddress, HardwareAddress}; @@ -29,6 +29,15 @@ use smoltcp::{ wire::{IpAddress, IpCidr}, }; +/// 路由动作 +#[derive(Debug, PartialEq)] +pub enum RoutingAction { + DeliverToLocal, // 交给本地协议栈处理 + Forwarded, // 已转发给其他接口 + Drop, // 丢弃 + Ignore, // 忽略 +} + pub struct RouteRxToken { driver: RouteDriver, buffer: Vec, @@ -39,7 +48,46 @@ impl phy::RxToken for RouteRxToken { where F: FnOnce(&[u8]) -> R, { - f(self.buffer.as_slice()) + //? 
如果buffer是L3包(来自Router注入),直接交给协议栈 + if self.is_l3_packet() { + return f(self.buffer.as_slice()); + } + + // 如果是完整的以太网帧,先让Router分析 + let routing_action = if let Some(router) = self.driver.router.upgrade() { + router.handle_received_frame(&self.driver.name(), &self.buffer) + } else { + RoutingAction::DeliverToLocal + }; + + match routing_action { + RoutingAction::DeliverToLocal => f(self.buffer.as_slice()), + _ => f(&[]), + } + } +} + +impl RouteRxToken { + pub fn is_l3_packet(&self) -> bool { + if self.buffer.len() < 20 { + return false; + } + + // 检查IPv4包特征 + let first_byte = self.buffer[0]; + let version = (first_byte >> 4) & 0x0F; + let ihl = (first_byte & 0x0F) as usize * 4; + + if version != 4 || ihl < 20 || self.buffer.len() < ihl { + return false; + } + + if self.buffer.len() >= 4 { + let total_length = u16::from_be_bytes([self.buffer[2], self.buffer[3]]) as usize; + return self.buffer.len() == total_length; + } + + false } } @@ -62,32 +110,45 @@ impl phy::TxToken for RouteTxToken { pub struct Route { name: String, - queue: VecDeque>, + rx_queue: VecDeque>, + tx_queue: VecDeque>, + l3_inject_queue: VecDeque>, } impl Route { pub fn new(name: &str) -> Self { - let queue = VecDeque::new(); Route { name: name.to_string(), - queue, + rx_queue: VecDeque::new(), + tx_queue: VecDeque::new(), + l3_inject_queue: VecDeque::new(), } } + pub fn inject_ether(&mut self, data: Vec) { + self.rx_queue.push_back(data); + } + pub fn route_receive(&mut self) -> Vec { - let buffer = self.queue.pop_front(); - match buffer { - Some(buffer) => { - return buffer; - } - None => { - return Vec::new(); - } + // 优先处理L3注入的包 + if let Some(l3_packet) = self.l3_inject_queue.pop_front() { + return l3_packet; } + + // 然后处理硬件接收的帧 + self.rx_queue.pop_front().unwrap_or_else(|| Vec::new()) } pub fn route_transmit(&mut self, buffer: Vec) { - self.queue.push_back(buffer) + self.tx_queue.push_back(buffer); + } + + pub fn pop_tx_frame(&mut self) -> Option> { + self.tx_queue.pop_front() + } + + pub fn 
inject_l3_packet(&mut self, ip_packet: Vec) { + self.l3_inject_queue.push_back(ip_packet); } } @@ -171,13 +232,13 @@ impl phy::Device for RouteDriver { ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { let buffer = self.inner.lock().route_receive(); - if let Some(router) = self.router.upgrade() { - router.recv_from_iface(buffer); - return None; - } + // if let Some(router) = self.router.upgrade() { + // router.recv_from_iface(buffer); + // return None; + // } if buffer.is_empty() { - return Option::None; + return None; } let rx = RouteRxToken { driver: self.clone(), @@ -186,7 +247,7 @@ impl phy::Device for RouteDriver { let tx = RouteTxToken { driver: self.clone(), }; - return Option::Some((rx, tx)); + Some((rx, tx)) } fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option> { @@ -214,6 +275,8 @@ pub struct InnerRouteInterface { kobj_common: KObjectCommonData, router: Weak, + // 本接口的子网信息 + subnet: Option, } impl Default for InnerRouteInterface { @@ -223,6 +286,7 @@ impl Default for InnerRouteInterface { device_common: DeviceCommonData::default(), kobj_common: KObjectCommonData::default(), router: Weak::default(), + subnet: None, } } } @@ -259,6 +323,51 @@ impl RouteInterface { }) } + // fn send_via_raw_socket(&self, ip_packet: Vec, dst_ip: Ipv4Addr) { + // let raw_rx_buffer = + // RawSocketBuffer::new(vec![raw::PacketMetadata::EMPTY; 4], vec![0; 1024]); + // let raw_tx_buffer = + // RawSocketBuffer::new(vec![raw::PacketMetadata::EMPTY; 4], vec![0; 1024]); + // let raw_socket = raw::Socket::new( + // smoltcp::wire::IpVersion::Ipv4, + // smoltcp::wire::IpProtocol::Unknown(0), // 接受所有协议 + // raw_rx_buffer, + // raw_tx_buffer, + // ); + // } + + pub fn inject_l3_packet_for_sending(&self, ip_packet: Vec) { + let mut device = self.driver.inner.lock(); + device.inject_l3_packet(ip_packet); + } + + pub fn receive_frame_from_hardware(&self, frame: Vec) { + let mut device = self.driver.inner.lock(); + device.inject_ether(frame); + } + + pub fn 
get_outgoing_frame(&self) -> Option> { + let mut device = self.driver.inner.lock(); + device.pop_tx_frame() + } + + pub fn set_subnet(&self, cidr: IpCidr) { + let mut inner = self.inner(); + inner.subnet = Some(cidr); + } + + pub fn subnet(&self) -> Option { + self.inner().subnet + } + + pub fn is_in_subnet(&self, ip: Ipv4Addr) -> bool { + if let Some(subnet) = self.subnet() { + subnet.contains_addr(&IpAddress::Ipv4(ip)) + } else { + false + } + } + pub fn attach_router(&self, router: Arc) { self.inner().router = Arc::downgrade(&router); self.driver.force_get_mut().attach_router(router); @@ -271,7 +380,7 @@ impl RouteInterface { }); } - pub fn add_default_route_to_peer(&self, ip: IpAddress) { + pub fn add_default_route(&self, ip: IpAddress) { let iface = &mut self.common.smol_iface.lock_irqsave(); iface.routes_mut().update(|routes_map| { @@ -290,6 +399,17 @@ impl RouteInterface { return self.inner.lock(); } + pub fn is_self_ip(&self, dst_ip: Ipv4Addr) -> bool { + let iface = self.common.smol_iface.lock(); + iface.ip_addrs().iter().any(|cidr| { + if let IpAddress::Ipv4(ip) = cidr.address() { + ip == dst_ip + } else { + false + } + }) + } + // pub fn send()!!! 
} diff --git a/kernel/src/net/routing/mod.rs b/kernel/src/net/routing/mod.rs index c0b948165..e7dec74ba 100644 --- a/kernel/src/net/routing/mod.rs +++ b/kernel/src/net/routing/mod.rs @@ -1,139 +1,159 @@ -use crate::time::Instant; -use alloc::collections::BTreeMap; -use alloc::vec::Vec; -use smoltcp::wire::{IpAddress, IpCidr}; - -#[derive(Debug, Clone)] -pub struct NextHop { - // 出口接口编号 - pub if_index: u32, - pub via_router: IpAddress, -} - -#[derive(Debug, Clone)] -pub struct RouteEntry { - pub cidr: IpCidr, - pub next_hops: Vec, - - // None 表示永久有效 - pub prefer_until: Option, - pub expired_at: Option, -} - -#[derive(Debug)] -pub struct RouteTable { - pub table_id: u32, - pub entries: BTreeMap, -} - -impl RouteTable { - pub fn new(table_id: u32) -> Self { - RouteTable { - table_id, - entries: BTreeMap::new(), - } - } - - pub fn add_route(&mut self, cidr: IpCidr, entry: RouteEntry) { - self.entries.insert(cidr, entry); - } - - pub fn del_route(&mut self, cidr: &IpCidr) { - self.entries.remove(cidr); - } - - pub fn lookup(&self, ip: &IpAddress, now: Instant) -> Option<&NextHop> { - self.entries - .iter() - .filter(|(cidr, entry)| { - cidr.contains_addr(ip) && entry.expired_at.map_or(true, |t| now <= t) - }) - .max_by_key(|(cidr, _entry)| cidr.prefix_len()) - .and_then(|(_cidr, entry)| entry.next_hops.first()) - } -} - -pub struct RoutingSubsystem { - pub route_tables: Vec, - pub rules: Vec, -} - -impl RoutingSubsystem { - pub fn new() -> Self { - RoutingSubsystem { - route_tables: Vec::new(), - rules: Vec::new(), - } - } - - pub fn get_table_mut(&mut self, table_id: u32) -> Option<&mut RouteTable> { - self.route_tables - .iter_mut() - .find(|t| t.table_id == table_id) - } - - pub fn add_route_table(&mut self, table: RouteTable) { - self.route_tables.push(table); - } - - pub fn add_routing_rule(&mut self, rule: RoutingRule) { - self.rules.push(rule); - } - - pub fn lookup_route(&self, packet: &PacketMeta) -> Option<&NextHop> { - if let Some(rule) = self - .rules - 
.iter() - .filter(|r| r.matches(packet)) - .min_by_key(|r| r.priority) - { - return self - .route_tables - .iter() - .find(|t| t.table_id == rule.table_id) - .and_then(|t| t.lookup(&packet.dst_ip, Instant::now())); - } - None - } -} - -#[derive(Debug, Clone)] -pub struct RoutingRule { - pub from: Option, - pub tos: Option, - pub fwmark: Option, - pub table_id: u32, - // 匹配优先级,数字越小优先匹配 - pub priority: u32, -} - -pub struct PacketMeta { - pub src_ip: IpAddress, - pub dst_ip: IpAddress, - pub tos: u8, - pub fwmark: u32, -} - -impl RoutingRule { - pub fn matches(&self, packet: &PacketMeta) -> bool { - if let Some(ref from) = self.from { - if !from.contains_addr(&packet.src_ip) { - return false; - } - } - - if let Some(tos) = self.tos { - if packet.tos != tos { - return false; - } - } - - if let Some(fwmark) = self.fwmark { - if packet.fwmark != fwmark { - return false; - } - } - - true - } -} +pub mod router; +mod routing_table; + +// #[derive(Debug)] +// pub struct RouteTable { +// pub table_id: u32, +// pub entries: BTreeMap, +// } + +// impl RouteTable { +// pub fn new(table_id: u32) -> Self { +// RouteTable { +// table_id, +// entries: BTreeMap::new(), +// } +// } + +// pub fn add_route(&mut self, cidr: IpCidr, entry: RouteEntry) { +// self.entries.insert(cidr, entry); +// } + +// pub fn del_route(&mut self, cidr: &IpCidr) { +// self.entries.remove(cidr); +// } + +// pub fn lookup(&self, ip: &IpAddress, now: Instant) -> Option<&NextHop> { +// self.entries +// .iter() +// .filter(|(cidr, entry)| { +// cidr.contains_addr(ip) && entry.expired_at.map_or(true, |t| now <= t) +// }) +// .max_by_key(|(cidr, _entry)| cidr.prefix_len()) +// .and_then(|(_cidr, entry)| entry.next_hops.first()) +// } +// } + +// pub struct RoutingSubsystem { +// pub route_tables: Vec, +// pub rules: Vec, +// } + +// impl RoutingSubsystem { +// pub fn new() -> Self { +// RoutingSubsystem { +// route_tables: Vec::new(), +// rules: Vec::new(), +// } +// } + +// pub fn get_table_mut(&mut self, 
table_id: u32) -> Option<&mut RouteTable> { +// self.route_tables +// .iter_mut() +// .find(|t| t.table_id == table_id) +// } + +// pub fn add_route_table(&mut self, table: RouteTable) { +// self.route_tables.push(table); +// } + +// pub fn add_routing_rule(&mut self, rule: RoutingRule) { +// self.rules.push(rule); +// } + +// pub fn lookup_route(&self, packet: &PacketMeta) -> Option<&NextHop> { +// if let Some(rule) = self +// .rules +// .iter() +// .filter(|r| r.matches(packet)) +// .min_by_key(|r| r.priority) +// { +// return self +// .route_tables +// .iter() +// .find(|t| t.table_id == rule.table_id) +// .and_then(|t| t.lookup(&packet.dst_ip, Instant::now())); +// } +// None +// } +// } + +// #[derive(Debug, Clone)] +// pub struct RoutingRule { +// pub from: Option, +// pub tos: Option, +// pub fwmark: Option, +// pub table_id: u32, +// // 匹配优先级,数字越小优先匹配 +// pub priority: u32, +// } + +// pub struct PacketMeta { +// pub src_ip: IpAddress, +// pub dst_ip: IpAddress, +// pub tos: u8, +// pub fwmark: u32, +// } + +// impl RoutingRule { +// pub fn matches(&self, packet: &PacketMeta) -> bool { +// if let Some(ref from) = self.from { +// if !from.contains_addr(&packet.src_ip) { +// return false; +// } +// } + +// if let Some(tos) = self.tos { +// if packet.tos != tos { +// return false; +// } +// } + +// if let Some(fwmark) = self.fwmark { +// if packet.fwmark != fwmark { +// return false; +// } +// } + +// true +// } +// } + +//?test +// pub fn router_probe(name1: &str, name2: &str) -> (Arc, Arc) { +// let (iface1, iface2) = VethInterface::new_pair(name1, name2); + +// let addr1 = IpAddress::v4(10, 0, 0, 1); +// let cidr1 = IpCidr::new(addr1, 24); +// iface1.update_ip_addrs(cidr1); + +// let addr2 = IpAddress::v4(10, 0, 0, 2); +// let cidr2 = IpCidr::new(addr2, 24); +// iface2.update_ip_addrs(cidr2); + +// // 添加默认路由 +// iface1.add_default_route_to_peer(addr2); +// iface2.add_default_route_to_peer(addr1); + +// let turn_on = |a: &Arc| { +// 
a.set_net_state(NetDeivceState::__LINK_STATE_START); +// a.set_operstate(Operstate::IF_OPER_UP); +// NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); +// register_netdevice(a.clone()).expect("register veth device failed"); +// }; + +// turn_on(&iface1); +// turn_on(&iface2); + +// (iface1, iface2) +// } + +// #[unified_init(INITCALL_DEVICE)] +// pub fn veth_init() -> Result<(), SystemError> { +// router_probe("veth0", "veth1"); +// log::info!("Veth pair initialized."); +// Ok(()) +// } + + + diff --git a/kernel/src/net/routing/router.rs b/kernel/src/net/routing/router.rs new file mode 100644 index 000000000..13da296d3 --- /dev/null +++ b/kernel/src/net/routing/router.rs @@ -0,0 +1,189 @@ +use core::net::Ipv4Addr; +use crate::driver::base::kobject::KObject; +use crate::driver::net::route_iface::RouteInterface; +use crate::driver::net::route_iface::RoutingAction; +use crate::driver::net::Iface; +use crate::libs::spinlock::SpinLock; +use crate::libs::wait_queue::WaitQueue; +use crate::net::routing::routing_table::RouteTable; +use alloc::collections::VecDeque; +use alloc::string::String; +use alloc::string::ToString; +use alloc::sync::Arc; +use alloc::sync::Weak; +use alloc::vec::Vec; +use hashbrown::HashMap; +use smoltcp::wire::EthernetFrame; +use smoltcp::wire::EthernetProtocol; +use smoltcp::wire::Ipv4Packet; + +const ROUTER_NAME: &str = "router"; + +pub struct Router { + name: String, + route_table: RouteTable, + pub interfaces: HashMap>, + self_ref: Weak, + rx_buffer: SpinLock>>, + wait_queue: WaitQueue, +} + +impl Router { + pub fn new() -> Arc { + Arc::new_cyclic(|me| Router { + name: ROUTER_NAME.to_string(), + route_table: RouteTable::new(), + interfaces: HashMap::new(), + self_ref: me.clone(), + rx_buffer: SpinLock::new(VecDeque::new()), + wait_queue: WaitQueue::default(), + }) + } + + pub fn add_interface(&mut self, iface: Arc) { + iface.attach_router(self.self_ref.upgrade().unwrap()); + self.interfaces.insert(iface.name(), iface); + } + + pub fn 
recv_from_iface(&self, data: Vec) { + let mut buffer = self.rx_buffer.lock(); + buffer.push_back(data); + } + + fn is_local_destination(&self, dst_ip: Ipv4Addr) -> bool { + for interface in self.interfaces.values() { + if interface.is_self_ip(dst_ip) { + return true; + } + } + false + } + + fn route_l3_packet(&self, from_interface: &str, ip_packet: &[u8]) -> RoutingAction { + let packet = match Ipv4Packet::new_checked(ip_packet) { + Ok(packet) => packet, + Err(_) => { + log::error!("Invalid IPv4 packet received"); + return RoutingAction::Drop; + } + }; + + let dst_ip = packet.dst_addr(); + + if packet.hop_limit() <= 1 { + log::warn!("Packet dropped due to TTL <= 1"); + return RoutingAction::Drop; + } + + if self.is_local_destination(dst_ip) { + return RoutingAction::DeliverToLocal; + } + + if let Some(route) = self.route_table.lookup_route(dst_ip) { + // 防止环路:不能从同一接口转发回去 + if route.interface.name() == from_interface { + return RoutingAction::Drop; + } + + // 转发到目标接口 + self.forward_l3_packet_to_interface(&route.interface.name(), ip_packet.to_vec()); + + RoutingAction::Forwarded + } else { + RoutingAction::Drop + } + } + + pub fn handle_received_frame(&self, interface_name: &str, frame: &[u8]) -> RoutingAction { + let eth_frame = match EthernetFrame::new_checked(frame) { + Ok(frame) => frame, + Err(_) => return RoutingAction::Drop, + }; + + let interface = match self.interfaces.get(interface_name) { + Some(iface) => iface, + None => return RoutingAction::Drop, + }; + + let mac = interface.mac(); + if eth_frame.dst_addr() != mac && !eth_frame.dst_addr().is_broadcast() { + return RoutingAction::Ignore; + } + + match eth_frame.ethertype() { + EthernetProtocol::Ipv4 => { + // IPv4包,进行路由处理 + self.route_l3_packet(interface_name, eth_frame.payload()) + } + EthernetProtocol::Arp => { + // ARP包交给本地处理 + RoutingAction::DeliverToLocal + } + _ => { + // 其他协议,暂时忽略 + RoutingAction::Ignore + } + } + } + + fn forward_l3_packet_to_interface(&self, target_interface: &str, mut 
ip_packet: Vec) { + if let Some(interface) = self.interfaces.get(target_interface) { + // 减少TTL + if ip_packet.len() >= 20 { + let mut packet = Ipv4Packet::new_unchecked(&mut ip_packet); + let new_ttl = packet.hop_limit().saturating_sub(1); + packet.set_hop_limit(new_ttl); + + // 重新计算校验和 + packet.fill_checksum(); + } + + // 将L3包注入到目标接口,让smoltcp处理路由和发送 + interface.inject_l3_packet_for_sending(ip_packet); + } + } + + pub fn poll_blocking(&self) { + use crate::sched::SchedMode; + + loop { + let mut inner = self.rx_buffer.lock_irqsave(); + + let opt = inner.pop_front(); + if let Some(frame) = opt { + // let mut frame = smoltcp::wire::EthernetFrame::new_unchecked(frame); + // log::info!("Router received frame: {:?}", frame); + + // drop(inner); + + // let mut ip_packet_bytes = frame.payload_mut(); + // let mut ipv4_packet = Ipv4Packet::new_unchecked(&mut ip_packet_bytes); + + // // 1. 递减 TTL + // let original_ttl = ipv4_packet.hop_limit(); + // if original_ttl <= 1 { + // // TTL 耗尽,数据包应该被丢弃,并可能发送 ICMP Time Exceeded 消息 + // println!("TTL reached 0, dropping packet."); + // return; + // } + // ipv4_packet.set_hop_limit(original_ttl - 1); + + // ipv4_packet.fill_checksum(); + + // let dest_ip = ipv4_packet.dst_addr(); + + // if let Some(entry) = self.route_table.lookup_route(IpAddress::Ipv4(dest_ip)) { + // //todo! 
+ // } + } else { + drop(inner); + log::info!("Router is going to sleep"); + let _ = wq_wait_event_interruptible!( + self.wait_queue, + !self.rx_buffer.lock().is_empty(), + {} + ); + } + } + } +} diff --git a/kernel/src/net/routing/routing_table.rs b/kernel/src/net/routing/routing_table.rs index 5b6d4f584..954291c6a 100644 --- a/kernel/src/net/routing/routing_table.rs +++ b/kernel/src/net/routing/routing_table.rs @@ -1,6 +1,7 @@ -use core::sync::atomic::AtomicU32; +use alloc::sync::Arc; +use core::{net::Ipv4Addr, sync::atomic::AtomicU32}; -use crate::time::Instant; +use crate::{driver::net::route_iface::RouteInterface, time::Instant}; use alloc::vec::Vec; use smoltcp::wire::{IpAddress, IpCidr}; @@ -10,17 +11,11 @@ fn generate_table_id() -> u32 { DEFAULT_TABLE_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed) } -#[derive(Debug, Clone)] -pub struct NextHop { - // 出口接口编号 - pub if_index: usize, - pub via_router: IpAddress, -} - #[derive(Debug, Clone)] pub struct RouteEntry { pub destination: IpCidr, - pub next_hop: NextHop, + pub next_hop: Option, + pub interface: Arc, // None 表示永久有效 pub prefer_until: Option, @@ -52,10 +47,34 @@ impl RouteTable { } /// 根据目的IP地址查找最佳匹配的路由条目(最长前缀匹配)。 - pub fn lookup_route(&self, dest_ip: IpAddress) -> Option<&RouteEntry> { + pub fn lookup_route(&self, dest_ip: Ipv4Addr) -> Option<&RouteEntry> { self.entries .iter() - .filter(|entry| entry.destination.contains_addr(&dest_ip)) + .filter(|entry| entry.destination.contains_addr(&IpAddress::Ipv4(dest_ip))) .max_by_key(|entry| entry.destination.prefix_len()) // 最长前缀匹配 } + + pub fn remove_route(&mut self, cidr: &IpCidr) { + self.entries.retain(|entry| entry.destination != *cidr); + } + + pub fn lookup(&self, dest_ip: &IpAddress) -> Option<(Arc, Option)> { + let mut best_match: Option<(&RouteEntry, u8)> = None; + + for entry in &self.entries { + if entry.destination.contains_addr(dest_ip) { + let current_prefix_len = entry.destination.prefix_len(); + if let Some((_, prev_prefix_len)) = 
best_match { + // If a previous match exists, check if the current one is more specific + if current_prefix_len > prev_prefix_len { + best_match = Some((entry, current_prefix_len)); + } + } else { + // First match found + best_match = Some((entry, current_prefix_len)); + } + } + } + best_match.map(|(entry, _)| (entry.interface.clone(), entry.next_hop)) + } } From 36104bc0cb200c86b51d9ab04d6e5e5e41a9fb1e Mon Sep 17 00:00:00 2001 From: sparkzky Date: Fri, 8 Aug 2025 19:06:31 +0800 Subject: [PATCH 13/36] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E7=AE=80?= =?UTF-8?q?=E5=8D=95=E7=9A=84=E8=B7=AF=E7=94=B1=E5=8A=9F=E8=83=BD,?= =?UTF-8?q?=E6=9C=AA=E8=AF=A6=E7=BB=86=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/bridge.rs | 1 - kernel/src/driver/net/route_iface.rs | 721 +++++++----------------- kernel/src/driver/net/veth.rs | 58 +- kernel/src/net/mod.rs | 1 - kernel/src/net/routing/mod.rs | 159 ------ kernel/src/net/routing/router.rs | 189 ------- kernel/src/net/routing/routing_table.rs | 80 --- 7 files changed, 252 insertions(+), 957 deletions(-) delete mode 100644 kernel/src/net/routing/mod.rs delete mode 100644 kernel/src/net/routing/router.rs delete mode 100644 kernel/src/net/routing/routing_table.rs diff --git a/kernel/src/driver/net/bridge.rs b/kernel/src/driver/net/bridge.rs index aa2681351..f905aff54 100644 --- a/kernel/src/driver/net/bridge.rs +++ b/kernel/src/driver/net/bridge.rs @@ -58,7 +58,6 @@ struct MacEntryRecord { pub struct BridgePort { pub id: BridgePortId, pub(super) bridge_enable: Arc, - //先把这里直接改成driver,去除weak,忽略循环依赖 pub(super) bridge_driver: Weak, // 当前接口状态?forwarding, learning, blocking? 
// mac mtu信息 diff --git a/kernel/src/driver/net/route_iface.rs b/kernel/src/driver/net/route_iface.rs index 1f400b541..6f9929ccd 100644 --- a/kernel/src/driver/net/route_iface.rs +++ b/kernel/src/driver/net/route_iface.rs @@ -1,584 +1,255 @@ -use super::{Iface, IfaceCommon}; -use super::{NetDeivceState, NetDeviceCommonData, Operstate}; -use crate::arch::rand::rand; -use crate::driver::base::class::Class; -use crate::driver::base::device::bus::Bus; -use crate::driver::base::device::driver::Driver; -use crate::driver::base::device::{Device, DeviceCommonData, DeviceType, IdTable}; -use crate::driver::base::kobject::{ - KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState, -}; -use crate::driver::base::kset::KSet; -use crate::filesystem::kernfs::KernFSInode; -use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; -use crate::libs::spinlock::{SpinLock, SpinLockGuard}; -use crate::net::generate_iface_id; -use crate::net::routing::router::Router; -use alloc::collections::VecDeque; -use alloc::fmt::Debug; +use crate::driver::net::Iface; +use crate::libs::rwlock::RwLock; +use alloc::collections::BTreeMap; use alloc::string::{String, ToString}; use alloc::sync::{Arc, Weak}; use alloc::vec::Vec; -use core::cell::UnsafeCell; -use core::net::Ipv4Addr; -use core::ops::{Deref, DerefMut}; -use smoltcp::phy::DeviceCapabilities; -use smoltcp::wire::{EthernetAddress, HardwareAddress}; -use smoltcp::{ - phy::{self}, - wire::{IpAddress, IpCidr}, -}; - -/// 路由动作 -#[derive(Debug, PartialEq)] -pub enum RoutingAction { - DeliverToLocal, // 交给本地协议栈处理 - Forwarded, // 已转发给其他接口 - Drop, // 丢弃 - Ignore, // 忽略 -} - -pub struct RouteRxToken { - driver: RouteDriver, - buffer: Vec, -} - -impl phy::RxToken for RouteRxToken { - fn consume(self, f: F) -> R - where - F: FnOnce(&[u8]) -> R, - { - //? 
如果buffer是L3包(来自Router注入),直接交给协议栈 - if self.is_l3_packet() { - return f(self.buffer.as_slice()); - } - - // 如果是完整的以太网帧,先让Router分析 - let routing_action = if let Some(router) = self.driver.router.upgrade() { - router.handle_received_frame(&self.driver.name(), &self.buffer) - } else { - RoutingAction::DeliverToLocal - }; - - match routing_action { - RoutingAction::DeliverToLocal => f(self.buffer.as_slice()), - _ => f(&[]), +use smoltcp::wire::{EthernetAddress, EthernetFrame, IpAddress, IpCidr, Ipv4Packet}; + +#[derive(Debug, Clone)] +pub struct RouteEntry { + /// 目标网络 + pub destination: IpCidr, + /// 下一跳地址(如果是直连网络则为None) + pub next_hop: Option, + /// 出接口 + pub interface: Weak, + /// 路由优先级(数值越小优先级越高) + pub metric: u32, + /// 路由类型 + pub route_type: RouteType, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum RouteType { + /// 直连路由 + Connected, + /// 静态路由 + Static, + /// 默认路由 + Default, +} + +impl RouteEntry { + pub fn new_connected(destination: IpCidr, interface: Arc) -> Self { + RouteEntry { + destination, + next_hop: None, + interface: Arc::downgrade(&interface), + metric: 0, + route_type: RouteType::Connected, } } -} -impl RouteRxToken { - pub fn is_l3_packet(&self) -> bool { - if self.buffer.len() < 20 { - return false; - } - - // 检查IPv4包特征 - let first_byte = self.buffer[0]; - let version = (first_byte >> 4) & 0x0F; - let ihl = (first_byte & 0x0F) as usize * 4; - - if version != 4 || ihl < 20 || self.buffer.len() < ihl { - return false; + pub fn new_static( + destination: IpCidr, + next_hop: IpAddress, + interface: Arc, + metric: u32, + ) -> Self { + RouteEntry { + destination, + next_hop: Some(next_hop), + interface: Arc::downgrade(&interface), + metric, + route_type: RouteType::Static, } + } - if self.buffer.len() >= 4 { - let total_length = u16::from_be_bytes([self.buffer[2], self.buffer[3]]) as usize; - return self.buffer.len() == total_length; + pub fn new_default(next_hop: IpAddress, interface: Arc) -> Self { + RouteEntry { + destination: 
IpCidr::new(IpAddress::v4(0, 0, 0, 0), 0), + next_hop: Some(next_hop), + interface: Arc::downgrade(&interface), + metric: 100, + route_type: RouteType::Default, } - - false } } -pub struct RouteTxToken { - driver: RouteDriver, -} - -impl phy::TxToken for RouteTxToken { - fn consume(self, len: usize, f: F) -> R - where - F: FnOnce(&mut [u8]) -> R, - { - let mut buffer = vec![0; len]; - let result = f(buffer.as_mut_slice()); - let mut device = self.driver.inner.lock(); - device.route_transmit(buffer); - result - } +/// 路由决策结果 +#[derive(Debug)] +pub struct RouteDecision { + /// 出接口 + pub interface: Arc, + /// 下一跳地址(先写在这里 + pub next_hop: IpAddress, } -pub struct Route { +#[derive(Debug)] +pub struct Router { name: String, - rx_queue: VecDeque>, - tx_queue: VecDeque>, - l3_inject_queue: VecDeque>, + /// 路由表 //todo 后面再优化LC-trie,现在先简单用一个Vec,并且应该在这上面加锁(maybe rwlock?) and 指针反而可以不加锁,在这个路由表这里加就行 + route_table: RwLock>, } -impl Route { - pub fn new(name: &str) -> Self { - Route { - name: name.to_string(), - rx_queue: VecDeque::new(), - tx_queue: VecDeque::new(), - l3_inject_queue: VecDeque::new(), +impl Router { + pub fn new(name: String) -> Self { + Self { + name, + route_table: RwLock::new(Vec::new()), } } - pub fn inject_ether(&mut self, data: Vec) { - self.rx_queue.push_back(data); - } - - pub fn route_receive(&mut self) -> Vec { - // 优先处理L3注入的包 - if let Some(l3_packet) = self.l3_inject_queue.pop_front() { - return l3_packet; + pub fn add_route(&mut self, route: RouteEntry) { + let mut guard = self.route_table.write(); + let pos = guard + .iter() + .position(|r| r.metric > route.metric) + .unwrap_or(guard.len()); + + guard.insert(pos, route); + log::info!("Router {}: Added route to routing table", self.name); + } + + pub fn remove_route(&mut self, destination: IpCidr) { + self.route_table + .write() + .retain(|route| route.destination != destination); + } + + pub fn lookup_route(&self, dest_ip: IpAddress) -> Option { + let guard = self.route_table.read(); + // 
按最长前缀匹配原则查找路由 + let best = guard + .iter() + .filter(|route| { + route.interface.strong_count() > 0 && route.destination.contains_addr(&dest_ip) + }) + .max_by_key(|route| route.destination.prefix_len()); + + if let Some(entry) = best { + if let Some(interface) = entry.interface.upgrade() { + let next_hop = entry.next_hop.unwrap_or(dest_ip); + return Some(RouteDecision { + interface, + next_hop, + }); + } } - // 然后处理硬件接收的帧 - self.rx_queue.pop_front().unwrap_or_else(|| Vec::new()) - } - - pub fn route_transmit(&mut self, buffer: Vec) { - self.tx_queue.push_back(buffer); - } - - pub fn pop_tx_frame(&mut self) -> Option> { - self.tx_queue.pop_front() + None } - pub fn inject_l3_packet(&mut self, ip_packet: Vec) { - self.l3_inject_queue.push_back(ip_packet); - } -} - -#[derive(Debug)] -struct RouteDriverWapper(UnsafeCell); -unsafe impl Send for RouteDriverWapper {} -unsafe impl Sync for RouteDriverWapper {} - -impl Deref for RouteDriverWapper { - type Target = RouteDriver; - fn deref(&self) -> &Self::Target { - unsafe { &*self.0.get() } - } -} -impl DerefMut for RouteDriverWapper { - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { &mut *self.0.get() } + /// 清理无效的路由表项(接口已经不存在的) + pub fn cleanup_routes(&mut self) { + self.route_table + .write() + .retain(|route| route.interface.strong_count() > 0); } } -impl RouteDriverWapper { - #[allow(clippy::mut_from_ref)] - #[allow(clippy::mut_from_ref)] - fn force_get_mut(&self) -> &mut RouteDriver { - unsafe { &mut *self.0.get() } - } +lazy_static! 
{ + pub static ref GLOBAL_ROUTER: Arc = Arc::new(Router::new("global_router".to_string())); } -pub struct RouteDriver { - pub inner: Arc>, - pub router: Weak, +pub fn global_router() -> Arc { + GLOBAL_ROUTER.clone() } -impl RouteDriver { - pub fn new(name: &str) -> Self { - let inner = Arc::new(SpinLock::new(Route::new(name))); - RouteDriver { - inner, - router: Weak::default(), +/// 可供路由设备应该实现的 trait +pub trait RouterEnableDevice: Iface { + //todo 这里可以直接传一个IpPacket进来?如果目前只有ipv4的话 + fn handle_routable_packet(&self, packet: &[u8]) { + if packet.len() < 14 { + return; } - } - - pub fn name(&self) -> String { - self.inner.lock().name.clone() - } - pub fn attach_router(&mut self, router: Arc) { - self.router = Arc::downgrade(&router); - } -} + let ether_frame = match EthernetFrame::new_checked(packet) { + Ok(f) => f, + Err(_) => return, + }; -impl Clone for RouteDriver { - fn clone(&self) -> Self { - RouteDriver { - inner: self.inner.clone(), - router: self.router.clone(), + // 只处理IP包(IPv4) + if ether_frame.ethertype() != smoltcp::wire::EthernetProtocol::Ipv4 { + return; } - } -} - -impl phy::Device for RouteDriver { - type RxToken<'a> - = RouteRxToken - where - Self: 'a; - type TxToken<'a> - = RouteTxToken - where - Self: 'a; - - fn capabilities(&self) -> phy::DeviceCapabilities { - let mut caps = DeviceCapabilities::default(); - caps.max_transmission_unit = 1500; - caps.medium = smoltcp::phy::Medium::Ethernet; - caps - } - - fn receive( - &mut self, - _timestamp: smoltcp::time::Instant, - ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { - let buffer = self.inner.lock().route_receive(); - - // if let Some(router) = self.router.upgrade() { - // router.recv_from_iface(buffer); - // return None; - // } - if buffer.is_empty() { - return None; - } - let rx = RouteRxToken { - driver: self.clone(), - buffer, - }; - let tx = RouteTxToken { - driver: self.clone(), + let ipv4_packet = match Ipv4Packet::new_checked(ether_frame.payload()) { + Ok(p) => p, + Err(_) => return, 
}; - Some((rx, tx)) - } - fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option> { - Some(RouteTxToken { - driver: self.clone(), - }) - } -} + let dst_ip = ipv4_packet.dst_addr(); -#[cast_to([sync] Iface)] -#[cast_to([sync] crate::driver::base::device::Device)] -#[derive(Debug)] -pub struct RouteInterface { - name: String, - driver: RouteDriverWapper, - common: IfaceCommon, - inner: SpinLock, - locked_kobj_state: LockedKObjectState, -} - -#[derive(Debug)] -pub struct InnerRouteInterface { - netdevice_common: NetDeviceCommonData, - device_common: DeviceCommonData, - kobj_common: KObjectCommonData, - - router: Weak, - // 本接口的子网信息 - subnet: Option, -} - -impl Default for InnerRouteInterface { - fn default() -> Self { - InnerRouteInterface { - netdevice_common: NetDeviceCommonData::default(), - device_common: DeviceCommonData::default(), - kobj_common: KObjectCommonData::default(), - router: Weak::default(), - subnet: None, + // 检查TTL + if ipv4_packet.hop_limit() <= 1 { + log::warn!("TTL exceeded for packet to {}", dst_ip); + return; } - } -} -impl RouteInterface { - pub fn new(driver: RouteDriver) -> Arc { - let iface_id = generate_iface_id(); - - let mac = [ - 0x03, - 0x00, - 0x00, - 0x00, - (iface_id >> 8) as u8, - iface_id as u8, - ]; - - let hw_addr = HardwareAddress::Ethernet(EthernetAddress(mac)); - let mut iface_config = smoltcp::iface::Config::new(hw_addr); - iface_config.random_seed = rand() as u64; - let mut iface = smoltcp::iface::Interface::new( - iface_config, - &mut driver.clone(), - crate::time::Instant::now().into(), - ); - iface.set_any_ip(true); - - Arc::new(RouteInterface { - name: driver.name(), - driver: RouteDriverWapper(UnsafeCell::new(driver)), - common: IfaceCommon::new(iface_id, false, iface), - inner: SpinLock::new(InnerRouteInterface::default()), - locked_kobj_state: LockedKObjectState::default(), - }) - } - - // fn send_via_raw_socket(&self, ip_packet: Vec, dst_ip: Ipv4Addr) { - // let raw_rx_buffer = - // 
RawSocketBuffer::new(vec![raw::PacketMetadata::EMPTY; 4], vec![0; 1024]); - // let raw_tx_buffer = - // RawSocketBuffer::new(vec![raw::PacketMetadata::EMPTY; 4], vec![0; 1024]); - // let raw_socket = raw::Socket::new( - // smoltcp::wire::IpVersion::Ipv4, - // smoltcp::wire::IpProtocol::Unknown(0), // 接受所有协议 - // raw_rx_buffer, - // raw_tx_buffer, - // ); - // } - - pub fn inject_l3_packet_for_sending(&self, ip_packet: Vec) { - let mut device = self.driver.inner.lock(); - device.inject_l3_packet(ip_packet); - } - - pub fn receive_frame_from_hardware(&self, frame: Vec) { - let mut device = self.driver.inner.lock(); - device.inject_ether(frame); - } - - pub fn get_outgoing_frame(&self) -> Option> { - let mut device = self.driver.inner.lock(); - device.pop_tx_frame() - } - - pub fn set_subnet(&self, cidr: IpCidr) { - let mut inner = self.inner(); - inner.subnet = Some(cidr); - } - - pub fn subnet(&self) -> Option { - self.inner().subnet - } - - pub fn is_in_subnet(&self, ip: Ipv4Addr) -> bool { - if let Some(subnet) = self.subnet() { - subnet.contains_addr(&IpAddress::Ipv4(ip)) - } else { - false + // 检查是否是发给自己的包(目标IP是否是自己的IP) + if self.is_my_ip(dst_ip.into()) { + // 交给本地协议栈处理 + log::info!("Packet destined for local interface {}", self.iface_name()); + //todo + return; } - } - - pub fn attach_router(&self, router: Arc) { - self.inner().router = Arc::downgrade(&router); - self.driver.force_get_mut().attach_router(router); - } - pub fn update_ip_addrs(&self, cidr: IpCidr) { - let iface = &mut self.common.smol_iface.lock_irqsave(); - iface.update_ip_addrs(|ip_addrs| { - ip_addrs.push(cidr).expect("Push ipCidr failed: full"); - }); - } + // 查询全局路由表//todo 加入namespace之后在这里改成每个设备所属命名空间的Router即可 + let router = global_router(); - pub fn add_default_route(&self, ip: IpAddress) { - let iface = &mut self.common.smol_iface.lock_irqsave(); - - iface.routes_mut().update(|routes_map| { - routes_map - .push(smoltcp::iface::Route { - cidr: IpCidr::new(IpAddress::v4(0, 0, 0, 0), 0), - 
via_router: ip, - preferred_until: None, - expires_at: None, - }) - .expect("Add default route to peer failed"); - }); - } - - fn inner(&self) -> SpinLockGuard { - return self.inner.lock(); - } - - pub fn is_self_ip(&self, dst_ip: Ipv4Addr) -> bool { - let iface = self.common.smol_iface.lock(); - iface.ip_addrs().iter().any(|cidr| { - if let IpAddress::Ipv4(ip) = cidr.address() { - ip == dst_ip - } else { - false + let decision = match router.lookup_route(dst_ip.into()) { + Some(d) => d, + None => { + log::warn!("No route to {}", dst_ip); + return; } - }) - } - - // pub fn send()!!! -} - -impl KObject for RouteInterface { - fn as_any_ref(&self) -> &dyn core::any::Any { - self - } - - fn set_inode(&self, inode: Option>) { - self.inner().kobj_common.kern_inode = inode; - } - - fn inode(&self) -> Option> { - self.inner().kobj_common.kern_inode.clone() - } - - fn parent(&self) -> Option> { - self.inner().kobj_common.parent.clone() - } - - fn set_parent(&self, parent: Option>) { - self.inner().kobj_common.parent = parent; - } - - fn kset(&self) -> Option> { - self.inner().kobj_common.kset.clone() - } - - fn set_kset(&self, kset: Option>) { - self.inner().kobj_common.kset = kset; - } - - fn kobj_type(&self) -> Option<&'static dyn KObjType> { - self.inner().kobj_common.kobj_type - } - - fn name(&self) -> String { - self.name.clone() - } - - fn set_name(&self, _name: String) { - // do nothing - } - - fn kobj_state(&self) -> RwLockReadGuard { - self.locked_kobj_state.read() - } - - fn kobj_state_mut(&self) -> RwLockWriteGuard { - self.locked_kobj_state.write() - } - - fn set_kobj_state(&self, state: KObjectState) { - *self.locked_kobj_state.write() = state; - } - - fn set_kobj_type(&self, ktype: Option<&'static dyn KObjType>) { - self.inner().kobj_common.kobj_type = ktype; - } -} - -impl Device for RouteInterface { - fn dev_type(&self) -> DeviceType { - DeviceType::Net - } - - fn id_table(&self) -> IdTable { - IdTable::new(self.name.clone(), None) - } - - fn bus(&self) -> 
Option> { - self.inner().device_common.bus.clone() - } - - fn set_bus(&self, bus: Option>) { - self.inner().device_common.bus = bus; - } - - fn class(&self) -> Option> { - let mut guard = self.inner(); - let r = guard.device_common.class.clone()?.upgrade(); - if r.is_none() { - guard.device_common.class = None; - } - - return r; - } + }; - fn set_class(&self, class: Option>) { - self.inner().device_common.class = class; - } + drop(router); - fn driver(&self) -> Option> { - let r = self.inner().device_common.driver.clone()?.upgrade(); - if r.is_none() { - self.inner().device_common.driver = None; + // 检查是否是从同一个接口进来又要从同一个接口出去(避免回路) + if self.iface_name() == decision.interface.iface_name() { + log::warn!("Avoiding routing loop for packet to {}", dst_ip); + return; } - return r; - } - - fn set_driver(&self, driver: Option>) { - self.inner().device_common.driver = driver; - } - - fn is_dead(&self) -> bool { - false - } - - fn can_match(&self) -> bool { - self.inner().device_common.can_match - } + // 创建修改后的IP包(递减TTL) + let modified_ip_packet = ether_frame.payload().to_vec(); + // if modified_ip_packet.len() >= 9 { + // modified_ip_packet[8] = modified_ip_packet[8].saturating_sub(1); + // //todo 这里应该重新计算IP校验和,为了简化先跳过 + // } - fn set_can_match(&self, can_match: bool) { - self.inner().device_common.can_match = can_match; - } + // 交给出接口进行发送 + decision + .interface + .route_and_send(decision.next_hop, &modified_ip_packet); - fn state_synced(&self) -> bool { - true + log::info!( + "Routed packet from {} to {} via interface {}", + self.iface_name(), + dst_ip, + decision.interface.iface_name() + ); } - fn dev_parent(&self) -> Option> { - self.inner().device_common.get_parent_weak_or_clear() - } + /// 路由器决定通过此接口发送包时调用此方法 + /// 同Linux的ndo_start_xmit() + /// + /// todo 在这里查询arp_table,找到目标IP对应的mac地址然后拼接,如果找不到的话就需要主动发送arp请求去查询mac地址了,手伸不到smoltcp内部:( + fn route_and_send(&self, next_hop: IpAddress, ip_packet: &[u8]); - fn set_dev_parent(&self, parent: Option>) { - 
self.inner().device_common.parent = parent; - } + /// 检查IP地址是否是当前接口的IP + fn is_my_ip(&self, ip: IpAddress) -> bool; } -impl Iface for RouteInterface { - fn common(&self) -> &IfaceCommon { - &self.common - } - - fn iface_name(&self) -> String { - self.name.clone() - } - - fn mac(&self) -> smoltcp::wire::EthernetAddress { - let mac = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; - smoltcp::wire::EthernetAddress(mac) - } - - fn poll(&self) { - self.common.poll(self.driver.force_get_mut()) - } - - fn addr_assign_type(&self) -> u8 { - return self.inner().netdevice_common.addr_assign_type; - } - - fn net_device_type(&self) -> u16 { - return self.inner().netdevice_common.net_device_type; - } - - fn net_state(&self) -> NetDeivceState { - return self.inner().netdevice_common.state; - } - - fn set_net_state(&self, state: NetDeivceState) { - self.inner().netdevice_common.state |= state; - } - - fn operstate(&self) -> Operstate { - return self.inner().netdevice_common.operstate; - } +/// # 每一个`RouterEnableDevice`应该有的公共数据,包含 +/// - 当前接口的arp_table,记录邻居(//todo:将网卡的发送以及处理逻辑从smoltcp中移动出来,目前只是简单为veth实现这个,因为可以直接查到对端的mac地址) +/// - 当前接口的路由器 (//todo:引入命名空间之后在这里指向当前所属命名空间的Router) +#[derive(Debug)] +pub struct RouterEnableDeviceCommon { + pub arp_table: RwLock>, + pub router: Weak, +} - fn set_operstate(&self, state: Operstate) { - self.inner().netdevice_common.operstate = state; +impl Default for RouterEnableDeviceCommon { + fn default() -> Self { + let router = global_router(); + Self { + arp_table: RwLock::new(BTreeMap::new()), + router: Arc::downgrade(&router), + } } } diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index 287303d10..eb7ef0f60 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -11,6 +11,7 @@ use crate::driver::base::kobject::{ }; use crate::driver::base::kset::KSet; use crate::driver::net::bridge::BridgePort; +use crate::driver::net::route_iface::{RouterEnableDevice, RouterEnableDeviceCommon}; use 
crate::filesystem::kernfs::KernFSInode; use crate::init::initcall::INITCALL_DEVICE; use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; @@ -28,7 +29,7 @@ use core::cell::UnsafeCell; use core::ops::{Deref, DerefMut}; use smoltcp::phy::DeviceCapabilities; use smoltcp::phy::{self, RxToken}; -use smoltcp::wire::{EthernetAddress, HardwareAddress, IpAddress, IpCidr}; +use smoltcp::wire::{EthernetAddress, EthernetFrame, HardwareAddress, IpAddress, IpCidr}; use system_error::SystemError; use unified_init::macros::unified_init; @@ -37,6 +38,7 @@ pub struct Veth { rx_queue: VecDeque>, /// 对端的 `VethInterface`,在完成数据发送的时候会使用到 peer: Weak, + self_iface_ref: Weak, } impl Veth { @@ -45,6 +47,7 @@ impl Veth { name, rx_queue: VecDeque::new(), peer: Weak::new(), + self_iface_ref: Weak::new(), } } @@ -62,6 +65,9 @@ impl Veth { return; } + // 如果是路由设备,则将数据发送到路由器 + self.to_router(data); + Self::to_peer(&peer, data); } } @@ -84,6 +90,13 @@ impl Veth { }; } + fn to_router(&self, data: &[u8]) { + if let Some(self_iface) = self.self_iface_ref.upgrade() { + let frame = EthernetFrame::new_checked(data).unwrap(); + self_iface.handle_routable_packet(frame.payload()); + } + } + pub fn recv_from_peer(&mut self) -> Option> { // log::info!("Veth {} trying to receive", self.name); self.rx_queue.pop_front() @@ -231,7 +244,10 @@ pub struct VethCommonData { kobj_common: KObjectCommonData, peer_veth: Weak, + //TODO 这里其实不用整个port,反而会导致循环引用,可以只留一个BridgeIface bridge_port_data: Option, + + router_common_data: RouterEnableDeviceCommon, } impl Default for VethCommonData { @@ -242,6 +258,7 @@ impl Default for VethCommonData { kobj_common: KObjectCommonData::default(), peer_veth: Weak::new(), bridge_port_data: None, + router_common_data: RouterEnableDeviceCommon::default(), } } } @@ -281,13 +298,15 @@ impl VethInterface { let device = Arc::new(VethInterface { name, - driver: VethDriverWarpper(UnsafeCell::new(driver)), + driver: VethDriverWarpper(UnsafeCell::new(driver.clone())), common: 
IfaceCommon::new(iface_id, true, iface), inner: SpinLock::new(VethCommonData::default()), locked_kobj_state: LockedKObjectState::default(), wait_queue: WaitQueue::default(), }); + driver.inner.lock().self_iface_ref = Arc::downgrade(&device); + // log::info!("VethInterface {} created with ID {}", device.name, iface_id); device } @@ -605,6 +624,41 @@ impl BridgeEnableDevice for VethInterface { } } +impl RouterEnableDevice for VethInterface { + fn route_and_send(&self, next_hop: IpAddress, ip_packet: &[u8]) { + log::info!( + "VethInterface {} routing packet to {}", + self.iface_name(), + next_hop + ); + + // 构造以太网帧 + let dst_mac = self.peer_veth().mac(); + let src_mac = self.mac(); + + // 以太网类型为 IPv4 + let ethertype = [0x08, 0x00]; + + let mut frame = Vec::with_capacity(14 + ip_packet.len()); + frame.extend_from_slice(&dst_mac.0); + frame.extend_from_slice(&src_mac.0); + frame.extend_from_slice(&ethertype); + frame.extend_from_slice(ip_packet); + + // 发送到对端 + self.driver + .force_get_mut() + .inner + .lock_irqsave() + .send_to_peer(&frame); + } + + fn is_my_ip(&self, ip: IpAddress) -> bool { + let iface = self.common.smol_iface.lock_irqsave(); + iface.ip_addrs().iter().any(|cidr| cidr.contains_addr(&ip)) + } +} + pub fn veth_probe(name1: &str, name2: &str) -> (Arc, Arc) { let (iface1, iface2) = VethInterface::new_pair(name1, name2); diff --git a/kernel/src/net/mod.rs b/kernel/src/net/mod.rs index 7150cecbe..805761cc6 100644 --- a/kernel/src/net/mod.rs +++ b/kernel/src/net/mod.rs @@ -10,7 +10,6 @@ use crate::{driver::net::Iface, libs::rwlock::RwLock}; pub mod net_core; pub mod posix; -pub mod routing; pub mod socket; pub mod syscall; diff --git a/kernel/src/net/routing/mod.rs b/kernel/src/net/routing/mod.rs deleted file mode 100644 index e7dec74ba..000000000 --- a/kernel/src/net/routing/mod.rs +++ /dev/null @@ -1,159 +0,0 @@ -pub mod router; -mod routing_table; - -// #[derive(Debug)] -// pub struct RouteTable { -// pub table_id: u32, -// pub entries: BTreeMap, -// } - -//
impl RouteTable { -// pub fn new(table_id: u32) -> Self { -// RouteTable { -// table_id, -// entries: BTreeMap::new(), -// } -// } - -// pub fn add_route(&mut self, cidr: IpCidr, entry: RouteEntry) { -// self.entries.insert(cidr, entry); -// } - -// pub fn del_route(&mut self, cidr: &IpCidr) { -// self.entries.remove(cidr); -// } - -// pub fn lookup(&self, ip: &IpAddress, now: Instant) -> Option<&NextHop> { -// self.entries -// .iter() -// .filter(|(cidr, entry)| { -// cidr.contains_addr(ip) && entry.expired_at.map_or(true, |t| now <= t) -// }) -// .max_by_key(|(cidr, _entry)| cidr.prefix_len()) -// .and_then(|(_cidr, entry)| entry.next_hops.first()) -// } -// } - -// pub struct RoutingSubsystem { -// pub route_tables: Vec, -// pub rules: Vec, -// } - -// impl RoutingSubsystem { -// pub fn new() -> Self { -// RoutingSubsystem { -// route_tables: Vec::new(), -// rules: Vec::new(), -// } -// } - -// pub fn get_table_mut(&mut self, table_id: u32) -> Option<&mut RouteTable> { -// self.route_tables -// .iter_mut() -// .find(|t| t.table_id == table_id) -// } - -// pub fn add_route_table(&mut self, table: RouteTable) { -// self.route_tables.push(table); -// } - -// pub fn add_routing_rule(&mut self, rule: RoutingRule) { -// self.rules.push(rule); -// } - -// pub fn lookup_route(&self, packet: &PacketMeta) -> Option<&NextHop> { -// if let Some(rule) = self -// .rules -// .iter() -// .filter(|r| r.matches(packet)) -// .min_by_key(|r| r.priority) -// { -// return self -// .route_tables -// .iter() -// .find(|t| t.table_id == rule.table_id) -// .and_then(|t| t.lookup(&packet.dst_ip, Instant::now())); -// } -// None -// } -// } - -// #[derive(Debug, Clone)] -// pub struct RoutingRule { -// pub from: Option, -// pub tos: Option, -// pub fwmark: Option, -// pub table_id: u32, -// // 匹配优先级,数字越小优先匹配 -// pub priority: u32, -// } - -// pub struct PacketMeta { -// pub src_ip: IpAddress, -// pub dst_ip: IpAddress, -// pub tos: u8, -// pub fwmark: u32, -// } - -// impl RoutingRule { 
-// pub fn matches(&self, packet: &PacketMeta) -> bool { -// if let Some(ref from) = self.from { -// if !from.contains_addr(&packet.src_ip) { -// return false; -// } -// } - -// if let Some(tos) = self.tos { -// if packet.tos != tos { -// return false; -// } -// } - -// if let Some(fwmark) = self.fwmark { -// if packet.fwmark != fwmark { -// return false; -// } -// } - -// true -// } -// } - -//?test -// pub fn router_probe(name1: &str, name2: &str) -> (Arc, Arc) { -// let (iface1, iface2) = VethInterface::new_pair(name1, name2); - -// let addr1 = IpAddress::v4(10, 0, 0, 1); -// let cidr1 = IpCidr::new(addr1, 24); -// iface1.update_ip_addrs(cidr1); - -// let addr2 = IpAddress::v4(10, 0, 0, 2); -// let cidr2 = IpCidr::new(addr2, 24); -// iface2.update_ip_addrs(cidr2); - -// // 添加默认路由 -// iface1.add_default_route_to_peer(addr2); -// iface2.add_default_route_to_peer(addr1); - -// let turn_on = |a: &Arc| { -// a.set_net_state(NetDeivceState::__LINK_STATE_START); -// a.set_operstate(Operstate::IF_OPER_UP); -// NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); -// register_netdevice(a.clone()).expect("register veth device failed"); -// }; - -// turn_on(&iface1); -// turn_on(&iface2); - -// (iface1, iface2) -// } - -// #[unified_init(INITCALL_DEVICE)] -// pub fn veth_init() -> Result<(), SystemError> { -// router_probe("veth0", "veth1"); -// log::info!("Veth pair initialized."); -// Ok(()) -// } - - - diff --git a/kernel/src/net/routing/router.rs b/kernel/src/net/routing/router.rs deleted file mode 100644 index 13da296d3..000000000 --- a/kernel/src/net/routing/router.rs +++ /dev/null @@ -1,189 +0,0 @@ -use core::net::Ipv4Addr; -use crate::driver::base::kobject::KObject; -use crate::driver::net::route_iface::RouteInterface; -use crate::driver::net::route_iface::RoutingAction; -use crate::driver::net::Iface; -use crate::libs::spinlock::SpinLock; -use crate::libs::wait_queue::WaitQueue; -use crate::net::routing::routing_table::RouteTable; -use 
alloc::collections::VecDeque; -use alloc::string::String; -use alloc::string::ToString; -use alloc::sync::Arc; -use alloc::sync::Weak; -use alloc::vec::Vec; -use hashbrown::HashMap; -use smoltcp::wire::EthernetFrame; -use smoltcp::wire::EthernetProtocol; -use smoltcp::wire::Ipv4Packet; - -const ROUTER_NAME: &str = "router"; - -pub struct Router { - name: String, - route_table: RouteTable, - pub interfaces: HashMap>, - self_ref: Weak, - rx_buffer: SpinLock>>, - wait_queue: WaitQueue, -} - -impl Router { - pub fn new() -> Arc { - Arc::new_cyclic(|me| Router { - name: ROUTER_NAME.to_string(), - route_table: RouteTable::new(), - interfaces: HashMap::new(), - self_ref: me.clone(), - rx_buffer: SpinLock::new(VecDeque::new()), - wait_queue: WaitQueue::default(), - }) - } - - pub fn add_interface(&mut self, iface: Arc) { - iface.attach_router(self.self_ref.upgrade().unwrap()); - self.interfaces.insert(iface.name(), iface); - } - - pub fn recv_from_iface(&self, data: Vec) { - let mut buffer = self.rx_buffer.lock(); - buffer.push_back(data); - } - - fn is_local_destination(&self, dst_ip: Ipv4Addr) -> bool { - for interface in self.interfaces.values() { - if interface.is_self_ip(dst_ip) { - return true; - } - } - false - } - - fn route_l3_packet(&self, from_interface: &str, ip_packet: &[u8]) -> RoutingAction { - let packet = match Ipv4Packet::new_checked(ip_packet) { - Ok(packet) => packet, - Err(_) => { - log::error!("Invalid IPv4 packet received"); - return RoutingAction::Drop; - } - }; - - let dst_ip = packet.dst_addr(); - - if packet.hop_limit() <= 1 { - log::warn!("Packet dropped due to TTL <= 1"); - return RoutingAction::Drop; - } - - if self.is_local_destination(dst_ip) { - return RoutingAction::DeliverToLocal; - } - - if let Some(route) = self.route_table.lookup_route(dst_ip) { - // 防止环路:不能从同一接口转发回去 - if route.interface.name() == from_interface { - return RoutingAction::Drop; - } - - // 转发到目标接口 - self.forward_l3_packet_to_interface(&route.interface.name(), 
ip_packet.to_vec()); - - RoutingAction::Forwarded - } else { - RoutingAction::Drop - } - } - - pub fn handle_received_frame(&self, interface_name: &str, frame: &[u8]) -> RoutingAction { - let eth_frame = match EthernetFrame::new_checked(frame) { - Ok(frame) => frame, - Err(_) => return RoutingAction::Drop, - }; - - let interface = match self.interfaces.get(interface_name) { - Some(iface) => iface, - None => return RoutingAction::Drop, - }; - - let mac = interface.mac(); - if eth_frame.dst_addr() != mac && !eth_frame.dst_addr().is_broadcast() { - return RoutingAction::Ignore; - } - - match eth_frame.ethertype() { - EthernetProtocol::Ipv4 => { - // IPv4包,进行路由处理 - self.route_l3_packet(interface_name, eth_frame.payload()) - } - EthernetProtocol::Arp => { - // ARP包交给本地处理 - RoutingAction::DeliverToLocal - } - _ => { - // 其他协议,暂时忽略 - RoutingAction::Ignore - } - } - } - - fn forward_l3_packet_to_interface(&self, target_interface: &str, mut ip_packet: Vec) { - if let Some(interface) = self.interfaces.get(target_interface) { - // 减少TTL - if ip_packet.len() >= 20 { - let mut packet = Ipv4Packet::new_unchecked(&mut ip_packet); - let new_ttl = packet.hop_limit().saturating_sub(1); - packet.set_hop_limit(new_ttl); - - // 重新计算校验和 - packet.fill_checksum(); - } - - // 将L3包注入到目标接口,让smoltcp处理路由和发送 - interface.inject_l3_packet_for_sending(ip_packet); - } - } - - pub fn poll_blocking(&self) { - use crate::sched::SchedMode; - - loop { - let mut inner = self.rx_buffer.lock_irqsave(); - - let opt = inner.pop_front(); - if let Some(frame) = opt { - // let mut frame = smoltcp::wire::EthernetFrame::new_unchecked(frame); - // log::info!("Router received frame: {:?}", frame); - - // drop(inner); - - // let mut ip_packet_bytes = frame.payload_mut(); - // let mut ipv4_packet = Ipv4Packet::new_unchecked(&mut ip_packet_bytes); - - // // 1. 
递减 TTL - // let original_ttl = ipv4_packet.hop_limit(); - // if original_ttl <= 1 { - // // TTL 耗尽,数据包应该被丢弃,并可能发送 ICMP Time Exceeded 消息 - // println!("TTL reached 0, dropping packet."); - // return; - // } - // ipv4_packet.set_hop_limit(original_ttl - 1); - - // ipv4_packet.fill_checksum(); - - // let dest_ip = ipv4_packet.dst_addr(); - - // if let Some(entry) = self.route_table.lookup_route(IpAddress::Ipv4(dest_ip)) { - // //todo! - // } - } else { - drop(inner); - log::info!("Router is going to sleep"); - let _ = wq_wait_event_interruptible!( - self.wait_queue, - !self.rx_buffer.lock().is_empty(), - {} - ); - } - } - } -} diff --git a/kernel/src/net/routing/routing_table.rs b/kernel/src/net/routing/routing_table.rs deleted file mode 100644 index 954291c6a..000000000 --- a/kernel/src/net/routing/routing_table.rs +++ /dev/null @@ -1,80 +0,0 @@ -use alloc::sync::Arc; -use core::{net::Ipv4Addr, sync::atomic::AtomicU32}; - -use crate::{driver::net::route_iface::RouteInterface, time::Instant}; -use alloc::vec::Vec; -use smoltcp::wire::{IpAddress, IpCidr}; - -static DEFAULT_TABLE_ID: AtomicU32 = AtomicU32::new(0); - -fn generate_table_id() -> u32 { - DEFAULT_TABLE_ID.fetch_add(1, core::sync::atomic::Ordering::Relaxed) -} - -#[derive(Debug, Clone)] -pub struct RouteEntry { - pub destination: IpCidr, - pub next_hop: Option, - pub interface: Arc, - - // None 表示永久有效 - pub prefer_until: Option, - pub expired_at: Option, - - /// 度量值,暂时未用到 - pub metric: u32, -} - -#[derive(Debug, Default)] -pub struct RouteTable { - pub table_id: u32, - // pub entries: BTreeMap, - entries: Vec, -} - -impl RouteTable { - pub fn new() -> Self { - RouteTable { - table_id: generate_table_id(), - entries: Vec::new(), - } - } - - pub fn add_route(&mut self, entry: RouteEntry) { - self.entries.push(entry); - self.entries - .sort_by(|a, b| b.destination.prefix_len().cmp(&a.destination.prefix_len())); - } - - /// 根据目的IP地址查找最佳匹配的路由条目(最长前缀匹配)。 - pub fn lookup_route(&self, dest_ip: Ipv4Addr) -> 
Option<&RouteEntry> { - self.entries - .iter() - .filter(|entry| entry.destination.contains_addr(&IpAddress::Ipv4(dest_ip))) - .max_by_key(|entry| entry.destination.prefix_len()) // 最长前缀匹配 - } - - pub fn remove_route(&mut self, cidr: &IpCidr) { - self.entries.retain(|entry| entry.destination != *cidr); - } - - pub fn lookup(&self, dest_ip: &IpAddress) -> Option<(Arc, Option)> { - let mut best_match: Option<(&RouteEntry, u8)> = None; - - for entry in &self.entries { - if entry.destination.contains_addr(dest_ip) { - let current_prefix_len = entry.destination.prefix_len(); - if let Some((_, prev_prefix_len)) = best_match { - // If a previous match exists, check if the current one is more specific - if current_prefix_len > prev_prefix_len { - best_match = Some((entry, current_prefix_len)); - } - } else { - // First match found - best_match = Some((entry, current_prefix_len)); - } - } - } - best_match.map(|(entry, _)| (entry.interface.clone(), entry.next_hop)) - } -} From 3e2ff8c20c8135462f381a5ce8ad541432f581eb Mon Sep 17 00:00:00 2001 From: sparkzky Date: Mon, 11 Aug 2025 21:11:18 +0800 Subject: [PATCH 14/36] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0netlink?= =?UTF-8?q?=E6=A1=86=E6=9E=B6,=E5=86=85=E6=A0=B8=E7=9B=B8=E5=BA=94?= =?UTF-8?q?=E7=9A=84=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=E4=BB=A5=E5=8F=8A?= =?UTF-8?q?=E8=AF=BB=E5=8F=96=E5=86=99=E5=85=A5=E7=94=A8=E6=88=B7=E7=A9=BA?= =?UTF-8?q?=E9=97=B4=E5=B0=9A=E6=9C=AA=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/filesystem/epoll/mod.rs | 6 + kernel/src/net/socket/endpoint.rs | 8 +- kernel/src/net/socket/mod.rs | 1 + kernel/src/net/socket/netlink/addr/mod.rs | 55 ++++ .../src/net/socket/netlink/addr/multicast.rs | 64 ++++ kernel/src/net/socket/netlink/common/bound.rs | 53 +++ kernel/src/net/socket/netlink/common/mod.rs | 171 ++++++++++ .../src/net/socket/netlink/common/unbound.rs | 86 +++++ 
.../net/socket/netlink/message/attr/mod.rs | 178 ++++++++++ .../net/socket/netlink/message/attr/noattr.rs | 38 +++ kernel/src/net/socket/netlink/message/mod.rs | 62 ++++ .../socket/netlink/message/segment/common.rs | 84 +++++ .../socket/netlink/message/segment/header.rs | 73 +++++ .../net/socket/netlink/message/segment/mod.rs | 75 +++++ kernel/src/net/socket/netlink/mod.rs | 44 +++ kernel/src/net/socket/netlink/receiver.rs | 51 +++ kernel/src/net/socket/netlink/route/bound.rs | 105 ++++++ .../net/socket/netlink/route/kernel/mod.rs | 27 ++ .../socket/netlink/route/message/attr/addr.rs | 125 +++++++ .../socket/netlink/route/message/attr/mod.rs | 5 + .../netlink/route/message/attr/route.rs | 95 ++++++ .../net/socket/netlink/route/message/mod.rs | 6 + .../netlink/route/message/segment/addr.rs | 120 +++++++ .../netlink/route/message/segment/mod.rs | 97 ++++++ .../netlink/route/message/segment/route.rs | 196 +++++++++++ kernel/src/net/socket/netlink/route/mod.rs | 7 + kernel/src/net/socket/netlink/table/mod.rs | 308 ++++++++++++++++++ .../src/net/socket/netlink/table/multicast.rs | 30 ++ .../src/net/socket/utils/datagram_common.rs | 178 ++++++++++ .../src/net/socket/{utils.rs => utils/mod.rs} | 20 +- 30 files changed, 2354 insertions(+), 14 deletions(-) create mode 100644 kernel/src/net/socket/netlink/addr/mod.rs create mode 100644 kernel/src/net/socket/netlink/addr/multicast.rs create mode 100644 kernel/src/net/socket/netlink/common/bound.rs create mode 100644 kernel/src/net/socket/netlink/common/mod.rs create mode 100644 kernel/src/net/socket/netlink/common/unbound.rs create mode 100644 kernel/src/net/socket/netlink/message/attr/mod.rs create mode 100644 kernel/src/net/socket/netlink/message/attr/noattr.rs create mode 100644 kernel/src/net/socket/netlink/message/mod.rs create mode 100644 kernel/src/net/socket/netlink/message/segment/common.rs create mode 100644 kernel/src/net/socket/netlink/message/segment/header.rs create mode 100644 
kernel/src/net/socket/netlink/message/segment/mod.rs create mode 100644 kernel/src/net/socket/netlink/mod.rs create mode 100644 kernel/src/net/socket/netlink/receiver.rs create mode 100644 kernel/src/net/socket/netlink/route/bound.rs create mode 100644 kernel/src/net/socket/netlink/route/kernel/mod.rs create mode 100644 kernel/src/net/socket/netlink/route/message/attr/addr.rs create mode 100644 kernel/src/net/socket/netlink/route/message/attr/mod.rs create mode 100644 kernel/src/net/socket/netlink/route/message/attr/route.rs create mode 100644 kernel/src/net/socket/netlink/route/message/mod.rs create mode 100644 kernel/src/net/socket/netlink/route/message/segment/addr.rs create mode 100644 kernel/src/net/socket/netlink/route/message/segment/mod.rs create mode 100644 kernel/src/net/socket/netlink/route/message/segment/route.rs create mode 100644 kernel/src/net/socket/netlink/route/mod.rs create mode 100644 kernel/src/net/socket/netlink/table/mod.rs create mode 100644 kernel/src/net/socket/netlink/table/multicast.rs create mode 100644 kernel/src/net/socket/utils/datagram_common.rs rename kernel/src/net/socket/{utils.rs => utils/mod.rs} (60%) diff --git a/kernel/src/filesystem/epoll/mod.rs b/kernel/src/filesystem/epoll/mod.rs index c581c8475..2ee4c5b89 100644 --- a/kernel/src/filesystem/epoll/mod.rs +++ b/kernel/src/filesystem/epoll/mod.rs @@ -195,3 +195,9 @@ bitflags! 
{ const EPOLL_LISTEN_CAN_ACCEPT = Self::EPOLLIN.bits | Self::EPOLLRDNORM.bits; } } + +impl EPollEventType { + pub fn filter(&self, events: &EPollEventType) -> bool { + self.intersects(*events) + } +} diff --git a/kernel/src/net/socket/endpoint.rs b/kernel/src/net/socket/endpoint.rs index b592008e1..a04151ee9 100644 --- a/kernel/src/net/socket/endpoint.rs +++ b/kernel/src/net/socket/endpoint.rs @@ -1,6 +1,9 @@ use crate::{ mm::{verify_area, VirtAddr}, - net::{posix::SockAddr, socket::unix::UnixEndpoint}, + net::{ + posix::SockAddr, + socket::{netlink::addr::NetlinkSocketAddr, unix::UnixEndpoint}, + }, }; pub use smoltcp::wire::IpEndpoint; @@ -11,7 +14,10 @@ pub enum Endpoint { LinkLayer(LinkLayerEndpoint), /// 网络层端点 Ip(IpEndpoint), + /// Unix域套接字端点 Unix(UnixEndpoint), + /// Netlink端点 + Netlink(NetlinkSocketAddr), } /// @brief 链路层端点 diff --git a/kernel/src/net/socket/mod.rs b/kernel/src/net/socket/mod.rs index 4489c4062..af8eef431 100644 --- a/kernel/src/net/socket/mod.rs +++ b/kernel/src/net/socket/mod.rs @@ -4,6 +4,7 @@ pub mod endpoint; mod family; pub mod inet; mod inode; +pub mod netlink; mod posix; pub mod unix; mod utils; diff --git a/kernel/src/net/socket/netlink/addr/mod.rs b/kernel/src/net/socket/netlink/addr/mod.rs new file mode 100644 index 000000000..a4cc14e1b --- /dev/null +++ b/kernel/src/net/socket/netlink/addr/mod.rs @@ -0,0 +1,55 @@ +use crate::net::socket::{endpoint::Endpoint, netlink::addr::multicast::GroupIdSet}; +use system_error::SystemError; + +pub(super) mod multicast; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NetlinkSocketAddr { + port: u32, + groups: GroupIdSet, +} + +impl NetlinkSocketAddr { + pub fn new(port_num: u32, groups: GroupIdSet) -> Self { + Self { + port: port_num, + groups, + } + } + + pub fn new_unspecified() -> Self { + Self { + port: 0, + groups: GroupIdSet::new_empty(), + } + } + + pub const fn port(&self) -> u32 { + self.port + } + + pub fn groups(&self) -> GroupIdSet { + self.groups + } + + pub fn 
add_groups(&mut self, groups: GroupIdSet) { + self.groups.add_groups(groups); + } +} + +impl TryFrom for NetlinkSocketAddr { + type Error = SystemError; + + fn try_from(value: Endpoint) -> Result { + match value { + Endpoint::Netlink(addr) => Ok(addr), + _ => Err(SystemError::EAFNOSUPPORT), + } + } +} + +impl From for Endpoint { + fn from(value: NetlinkSocketAddr) -> Self { + Endpoint::Netlink(value) + } +} diff --git a/kernel/src/net/socket/netlink/addr/multicast.rs b/kernel/src/net/socket/netlink/addr/multicast.rs new file mode 100644 index 000000000..237afc957 --- /dev/null +++ b/kernel/src/net/socket/netlink/addr/multicast.rs @@ -0,0 +1,64 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct GroupIdSet(u32); + +impl GroupIdSet { + pub const fn new_empty() -> Self { + Self(0) + } + + pub const fn new(groups: u32) -> Self { + Self(groups) + } + + pub const fn ids_iter(&self) -> GroupIdIter { + GroupIdIter::new(self) + } + + pub fn add_groups(&mut self, groups: GroupIdSet) { + self.0 |= groups.0; + } + + pub fn drop_groups(&mut self, groups: GroupIdSet) { + self.0 &= !groups.0; + } + + pub fn set_groups(&mut self, new_groups: u32) { + self.0 = new_groups; + } + + pub fn clear(&mut self) { + self.0 = 0; + } + + pub fn is_empty(&self) -> bool { + self.0 == 0 + } + + pub fn as_u32(&self) -> u32 { + self.0 + } +} + +pub struct GroupIdIter { + groups: u32, +} + +impl GroupIdIter { + const fn new(groups: &GroupIdSet) -> Self { + Self { groups: groups.0 } + } +} + +impl Iterator for GroupIdIter { + type Item = u32; + + fn next(&mut self) -> Option { + if self.groups > 0 { + let group_id = self.groups.trailing_zeros(); + self.groups &= self.groups - 1; + return Some(group_id); + } + + None + } +} diff --git a/kernel/src/net/socket/netlink/common/bound.rs b/kernel/src/net/socket/netlink/common/bound.rs new file mode 100644 index 000000000..0c86e6d37 --- /dev/null +++ b/kernel/src/net/socket/netlink/common/bound.rs @@ -0,0 +1,53 @@ +use 
crate::net::socket::netlink::{ + addr::{multicast::GroupIdSet, NetlinkSocketAddr}, + receiver::MessageQueue, + table::BoundHandle, +}; +use alloc::fmt::Debug; +use system_error::SystemError; + +#[derive(Debug)] +pub struct BoundNetlink { + pub(in crate::net::socket::netlink) handle: BoundHandle, + pub(in crate::net::socket::netlink) remote_addr: NetlinkSocketAddr, + pub(in crate::net::socket::netlink) receive_queue: MessageQueue, +} + +impl BoundNetlink { + pub(super) fn new(handle: BoundHandle, message_queue: MessageQueue) -> Self { + Self { + handle, + remote_addr: NetlinkSocketAddr::new_unspecified(), + receive_queue: message_queue, + } + } + + pub fn bind_common(&mut self, endpoint: &NetlinkSocketAddr) -> Result<(), SystemError> { + if endpoint.port() != self.handle.port() { + return Err(SystemError::EINVAL); + } + let groups = endpoint.groups(); + self.handle.bind_groups(groups); + + Ok(()) + } + + // pub fn check_io_events_common(&self) -> EPollEventType { + // let mut events = EPollEventType::EPOLLOUT; + + // let receive_queue = self.receive_queue.0.lock(); + // if !receive_queue.is_empty() { + // events |= EPollEventType::EPOLLIN; + // } + + // events + // } + + pub(super) fn add_groups(&mut self, groups: GroupIdSet) { + self.handle.add_groups(groups); + } + + pub(super) fn drop_groups(&mut self, groups: GroupIdSet) { + self.handle.drop_groups(groups); + } +} diff --git a/kernel/src/net/socket/netlink/common/mod.rs b/kernel/src/net/socket/netlink/common/mod.rs new file mode 100644 index 000000000..6ae6d2111 --- /dev/null +++ b/kernel/src/net/socket/netlink/common/mod.rs @@ -0,0 +1,171 @@ +use crate::{ + libs::{rwlock::RwLock, wait_queue::WaitQueue}, + net::socket::{ + endpoint::Endpoint, + netlink::{ + addr::{multicast::GroupIdSet, NetlinkSocketAddr}, + common::{bound::BoundNetlink, unbound::UnboundNetlink}, + table::SupportedNetlinkProtocol, + }, + utils::datagram_common::{select_remote_and_bind, Bound, Inner}, + Socket, + }, +}; +use alloc::sync::Arc; 
+use core::sync::atomic::AtomicBool; +use system_error::SystemError; + +pub(super) mod bound; +mod unbound; + +#[derive(Debug)] +pub struct NetlinkSocket { + inner: RwLock, BoundNetlink>>, + + is_nonblocking: AtomicBool, + wait_queue: Arc, +} + +impl NetlinkSocket

+where + BoundNetlink: Bound, +{ + pub fn new(is_nonblocking: bool) -> Arc { + let unbound = UnboundNetlink::new(); + Arc::new(Self { + inner: RwLock::new(Inner::Unbound(unbound)), + is_nonblocking: AtomicBool::new(is_nonblocking), + wait_queue: Arc::new(WaitQueue::default()), + }) + } + + fn try_send( + &self, + buf: &[u8], + to: Option, + flags: crate::net::socket::PMSG, + ) -> Result { + let send_bytes = select_remote_and_bind( + &self.inner, + to, + || { + self.inner.write().bind_ephemeral( + &NetlinkSocketAddr::new_unspecified(), + self.wait_queue.clone(), + ) + }, + |bound, remote| bound.try_send(buf, &remote, flags), + )?; + // todo pollee invalidate?? + + Ok(send_bytes) + } + + fn try_recv( + &self, + buf: &mut [u8], + flags: crate::net::socket::PMSG, + ) -> Result<(usize, Endpoint), SystemError> { + let (recv_bytes, endpoint) = self + .inner + .read() + .try_recv(buf, flags) + .map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?; + // todo self.pollee.invalidate(); + + Ok((recv_bytes, endpoint)) + } +} + +impl Socket for NetlinkSocket

+where + BoundNetlink: Bound, +{ + fn connect( + &self, + endpoint: crate::net::socket::endpoint::Endpoint, + ) -> Result<(), system_error::SystemError> { + let endpoint = endpoint.try_into()?; + + self.inner + .write() + .connect(&endpoint, self.wait_queue.clone()) + } + + fn bind( + &self, + endpoint: crate::net::socket::endpoint::Endpoint, + ) -> Result<(), system_error::SystemError> { + let endpoint = endpoint.try_into()?; + + self.inner.write().bind(&endpoint, self.wait_queue.clone()) + } + + fn send_to( + &self, + buffer: &[u8], + flags: crate::net::socket::PMSG, + address: crate::net::socket::endpoint::Endpoint, + ) -> Result { + let endpoint = address.try_into()?; + + self.try_send(buffer, Some(endpoint), flags) + } + + fn wait_queue(&self) -> &WaitQueue { + &self.wait_queue + } + + fn recv_from( + &self, + buffer: &mut [u8], + flags: crate::net::socket::PMSG, + address: Option, + ) -> Result<(usize, crate::net::socket::endpoint::Endpoint), system_error::SystemError> { + //todo 处理一下阻塞的逻辑 + + self.try_recv(buffer, flags) + } + + fn poll(&self) -> usize { + todo!() + } + + fn send_buffer_size(&self) -> usize { + log::warn!("send_buffer_size is implemented to 0"); + 0 + } + + fn recv_buffer_size(&self) -> usize { + log::warn!("recv_buffer_size is implemented to 0"); + 0 + } +} + +impl NetlinkSocket

{ + pub fn is_nonblocking(&self) -> bool { + self.is_nonblocking + .load(core::sync::atomic::Ordering::Relaxed) + } + + pub fn set_nonblocking(&self, nonblocking: bool) { + self.is_nonblocking + .store(nonblocking, core::sync::atomic::Ordering::Relaxed); + } +} + +impl Inner, BoundNetlink> { + fn add_groups(&mut self, groups: GroupIdSet) { + match self { + Inner::Bound(bound) => bound.add_groups(groups), + Inner::Unbound(unbound) => unbound.add_groups(groups), + } + } + + fn drop_groups(&mut self, groups: GroupIdSet) { + match self { + Inner::Unbound(unbound) => unbound.drop_groups(groups), + Inner::Bound(bound) => bound.drop_groups(groups), + } + } +} diff --git a/kernel/src/net/socket/netlink/common/unbound.rs b/kernel/src/net/socket/netlink/common/unbound.rs new file mode 100644 index 000000000..80f975657 --- /dev/null +++ b/kernel/src/net/socket/netlink/common/unbound.rs @@ -0,0 +1,86 @@ +use crate::{ + libs::wait_queue::WaitQueue, + net::socket::{ + netlink::{ + addr::{multicast::GroupIdSet, NetlinkSocketAddr}, + common::bound::BoundNetlink, + receiver::{MessageQueue, MessageReceiver}, + table::SupportedNetlinkProtocol, + }, + utils::datagram_common, + }, +}; +use alloc::sync::Arc; +use core::marker::PhantomData; +use system_error::SystemError; + +#[derive(Debug)] +pub struct UnboundNetlink { + groups: GroupIdSet, + phantom: PhantomData>, +} + +impl UnboundNetlink

{ + pub(super) fn new() -> Self { + Self { + groups: GroupIdSet::new_empty(), + phantom: PhantomData, + } + } + + pub(super) fn addr(&self) -> NetlinkSocketAddr { + NetlinkSocketAddr::new(0, self.groups) + } + + pub(super) fn add_groups(&mut self, groups: GroupIdSet) { + self.groups.add_groups(groups); + } + + pub(super) fn drop_groups(&mut self, groups: GroupIdSet) { + self.groups.drop_groups(groups); + } +} + +impl datagram_common::Unbound for UnboundNetlink

{ + type Endpoint = NetlinkSocketAddr; + type Bound = BoundNetlink; + + fn bind( + &mut self, + endpoint: &NetlinkSocketAddr, + wait_queue: Arc, + ) -> Result, SystemError> { + let message_queue = MessageQueue::::new(); + let bound_handle = { + let endpoint = { + let mut endpoint = *endpoint; + endpoint.add_groups(self.groups); + endpoint + }; + let receiver = MessageReceiver::new(message_queue.clone(), wait_queue); +

::bind(&endpoint, receiver)? + }; + + Ok(BoundNetlink::new(bound_handle, message_queue)) + } + + fn bind_ephemeral( + &mut self, + _remote_endpoint: &Self::Endpoint, + wait_queue: Arc, + ) -> Result, SystemError> { + let message_queue = MessageQueue::::new(); + + let bound_handle = { + let endpoint = { + let mut endpoint = NetlinkSocketAddr::new_unspecified(); + endpoint.add_groups(self.groups); + endpoint + }; + let receiver = MessageReceiver::new(message_queue.clone(), wait_queue); +

::bind(&endpoint, receiver)? + }; + + Ok(BoundNetlink::new(bound_handle, message_queue)) + } +} diff --git a/kernel/src/net/socket/netlink/message/attr/mod.rs b/kernel/src/net/socket/netlink/message/attr/mod.rs new file mode 100644 index 000000000..c743ae142 --- /dev/null +++ b/kernel/src/net/socket/netlink/message/attr/mod.rs @@ -0,0 +1,178 @@ +pub mod noattr; + +use crate::net::socket::netlink::message::NLMSG_ALIGN; +use alloc::vec::Vec; +use system_error::SystemError; + +const IS_NESTED_MASK: u16 = 1u16 << 15; +const IS_NET_BYTEORDER_MASK: u16 = 1u16 << 14; +const ATTRIBUTE_TYPE_MASK: u16 = !(IS_NESTED_MASK | IS_NET_BYTEORDER_MASK); + +/// Netlink Attribute Header +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct CAttrHeader { + len: u16, + type_: u16, +} + +impl CAttrHeader { + fn from_payload_len(type_: u16, payload_len: usize) -> Self { + let total_len = payload_len + size_of::(); + // debug_assert!(total_len <= u16::MAX as usize); + + Self { + len: total_len as u16, + type_, + } + } + + pub fn type_(&self) -> u16 { + self.type_ & ATTRIBUTE_TYPE_MASK + } + + pub fn payload_len(&self) -> usize { + self.len as usize - size_of::() + } + + pub fn total_len(&self) -> usize { + self.len as usize + } + + pub fn total_len_with_padding(&self) -> usize { + (self.len as usize).checked_add(NLMSG_ALIGN - 1).unwrap() & !(NLMSG_ALIGN - 1) + } + + pub fn padding_len(&self) -> usize { + self.total_len_with_padding() - self.total_len() + } +} + +/// Netlink Attribute +pub trait Attribute: core::fmt::Debug + Send + Sync { + fn type_(&self) -> u16; + + fn payload_as_bytes(&self) -> &[u8]; + + fn total_len_with_padding(&self) -> usize { + const DUMMY_TYPE: u16 = 0; + + CAttrHeader::from_payload_len(DUMMY_TYPE, self.payload_as_bytes().len()) + .total_len_with_padding() + } + + fn read_from_buf(header: &CAttrHeader, payload_buf: &[u8]) -> Result, SystemError> + where + Self: Sized; + + fn write_to_buf(&self, buf: &mut Vec) -> Result<(), SystemError> { + // let payload_bytes = 
self.payload_as_bytes(); + // let header = CAttrHeader { + // len: (core::mem::size_of::() + payload_bytes.len()) as u16, + // type_: self.type_(), + // }; + + // let total_len = header.len as usize; + // let padded_len = align_to(total_len, NLMSG_ALIGN); + // let padding_len = padded_len - total_len; + + // // 确保 buf 足够大 + // if buf.len() < offset + padded_len { + // buf.resize(offset + padded_len, 0); + // } + + // // 写入头部 + // let header_bytes = unsafe { + // core::slice::from_raw_parts( + // &header as *const CAttrHeader as *const u8, + // core::mem::size_of::(), + // ) + // }; + // buf[offset..offset + header_bytes.len()].copy_from_slice(header_bytes); + + // // 写入负载 + // buf[offset + header_bytes.len()..offset + header_bytes.len() + payload_bytes.len()] + // .copy_from_slice(payload_bytes); + + // // 填充部分已经在 resize 时置零,无需额外处理 + + // Ok(()) + + let payload_bytes = self.payload_as_bytes(); + let header = CAttrHeader { + len: (size_of::() + payload_bytes.len()) as u16, + type_: self.type_(), + }; + + // 写入头部 + let header_bytes = unsafe { + core::slice::from_raw_parts( + &header as *const CAttrHeader as *const u8, + size_of::(), + ) + }; + buf.extend_from_slice(header_bytes); + + // 写入负载 + buf.extend_from_slice(payload_bytes); + + // 添加对齐填充 + let total_len = header.len as usize; + let padded_len = align_to(total_len, NLMSG_ALIGN); + let padding_len = padded_len - total_len; + if padding_len > 0 { + buf.extend_from_slice(&vec![0u8; padding_len]); + } + + Ok(()) + } + + fn read_all_from_buf(buf: &[u8], mut offset: usize) -> Result, SystemError> + where + Self: Sized, + { + let mut attrs = Vec::new(); + + while offset < buf.len() { + // 检查是否有足够的字节读取属性头部 + if buf.len() - offset < size_of::() { + return Err(SystemError::EINVAL); + } + + // 读取属性头部 + let attr_header_bytes = &buf[offset..offset + size_of::()]; + let attr_header = unsafe { *(attr_header_bytes.as_ptr() as *const CAttrHeader) }; + + // 验证属性长度 + if attr_header.len < size_of::() as u16 { + return 
Err(SystemError::EINVAL); + } + + let attr_total_len = attr_header.len as usize; + if buf.len() - offset < attr_total_len { + return Err(SystemError::EINVAL); + } + + // 读取属性负载 + let payload_start = offset + size_of::(); + let payload_len = attr_total_len - size_of::(); + let payload_buf = &buf[payload_start..payload_start + payload_len]; + + // 解析属性 + if let Some(attr) = Self::read_from_buf(&attr_header, payload_buf)? { + attrs.push(attr); + } + + // 移动到下一个属性(考虑对齐) + let padded_len = align_to(attr_total_len, NLMSG_ALIGN); + offset += padded_len; + } + + Ok(attrs) + } +} + +// 辅助函数 +fn align_to(value: usize, align: usize) -> usize { + (value + align - 1) & !(align - 1) +} diff --git a/kernel/src/net/socket/netlink/message/attr/noattr.rs b/kernel/src/net/socket/netlink/message/attr/noattr.rs new file mode 100644 index 000000000..530f240d1 --- /dev/null +++ b/kernel/src/net/socket/netlink/message/attr/noattr.rs @@ -0,0 +1,38 @@ +use crate::net::socket::netlink::message::attr::Attribute; +use alloc::vec::Vec; + +#[derive(Debug)] +pub enum NoAttr {} + +impl Attribute for NoAttr { + fn type_(&self) -> u16 { + match *self {} + } + + fn payload_as_bytes(&self) -> &[u8] { + match *self {} + } + + fn read_from_buf( + header: &super::CAttrHeader, + payload_buf: &[u8], + ) -> Result, system_error::SystemError> + where + Self: Sized, + { + let payload_len = header.payload_len(); + //todo reader.skip_some(payload_len); + + Ok(None) + } + + fn read_all_from_buf( + buf: &[u8], + mut offset: usize, + ) -> Result, system_error::SystemError> + where + Self: Sized, + { + Ok(Vec::new()) + } +} diff --git a/kernel/src/net/socket/netlink/message/mod.rs b/kernel/src/net/socket/netlink/message/mod.rs new file mode 100644 index 000000000..dd929e6a7 --- /dev/null +++ b/kernel/src/net/socket/netlink/message/mod.rs @@ -0,0 +1,62 @@ +use crate::net::socket::netlink::message::segment::header::CMsgSegHdr; +use alloc::vec::Vec; +use system_error::SystemError; + +pub mod attr; +pub mod segment; + 
+#[derive(Debug)] +pub struct Message { + segments: Vec, +} + +impl Message { + pub fn new(segments: Vec) -> Self { + Self { segments } + } + + pub fn segments(&self) -> &[T] { + &self.segments + } + + pub fn segments_mut(&mut self) -> &mut [T] { + &mut self.segments + } + + pub fn read_from(reader: &[u8]) -> Result { + let segments = { + let segment = T::read_from(reader)?; + vec![segment] + }; + + Ok(Self { segments }) + } + + pub fn write_to(&self, writer: &mut [u8]) -> Result { + let total_written: usize = self + .segments + .iter() + .map(|segment| segment.write_to(writer)) + .collect::, SystemError>>()? + .iter() + .sum(); + + Ok(total_written) + } + + pub fn total_len(&self) -> usize { + self.segments + .iter() + .map(|segment| segment.header().len as usize) + .sum() + } +} + +pub trait ProtocolSegment: Sized + alloc::fmt::Debug { + fn header(&self) -> &CMsgSegHdr; + fn header_mut(&mut self) -> &mut CMsgSegHdr; + fn read_from(reader: &[u8]) -> Result; + fn write_to(&self, writer: &mut [u8]) -> Result; +} + +pub(super) const NLMSG_ALIGN: usize = 4; diff --git a/kernel/src/net/socket/netlink/message/segment/common.rs b/kernel/src/net/socket/netlink/message/segment/common.rs new file mode 100644 index 000000000..a358e0c98 --- /dev/null +++ b/kernel/src/net/socket/netlink/message/segment/common.rs @@ -0,0 +1,84 @@ +use alloc::vec::Vec; +use system_error::SystemError; + +use crate::net::socket::netlink::message::{ + attr::Attribute, + segment::{header::CMsgSegHdr, SegmentBody}, +}; + +#[derive(Debug)] +pub struct SegmentCommon { + header: CMsgSegHdr, + body: Body, + attrs: Vec, +} + +impl SegmentCommon { + pub const HEADER_LEN: usize = size_of::(); + + pub fn header(&self) -> &CMsgSegHdr { + &self.header + } + + pub fn header_mut(&mut self) -> &mut CMsgSegHdr { + &mut self.header + } + + pub fn body(&self) -> &Body { + &self.body + } + + pub fn attrs(&self) -> &Vec { + &self.attrs + } +} + +impl SegmentCommon { + pub const BODY_LEN: usize = size_of::(); + + pub 
fn new(header: CMsgSegHdr, body: Body, attrs: Vec) -> Self { + let mut res = Self { + header, + body, + attrs, + }; + res.header.len = res.total_len() as u32; + res + } + + pub fn read_from_buf(header: CMsgSegHdr, buf: &[u8]) -> Result { + let (body, remain_len) = Body::read_from_buf(&header, buf)?; + let attrs = Attr::read_all_from_buf(buf, buf.len() - remain_len)?; + + Ok(Self { + header, + body, + attrs, + }) + } + + pub fn write_to_buf(&self, buf: &mut Vec) -> Result<(), SystemError> { + if buf.len() < self.header.len as usize { + return Err(SystemError::EINVAL); + } + + self.body.write_to_buf(buf)?; + for attr in self.attrs.iter() { + attr.write_to_buf(buf)?; + } + Ok(()) + } + + pub fn total_len(&self) -> usize { + Self::HEADER_LEN + Self::BODY_LEN + self.attrs_len() + } +} + +impl SegmentCommon { + pub fn attrs_len(&self) -> usize { + self.attrs + .iter() + .map(|attr| attr.total_len_with_padding()) + .sum() + } +} diff --git a/kernel/src/net/socket/netlink/message/segment/header.rs b/kernel/src/net/socket/netlink/message/segment/header.rs new file mode 100644 index 000000000..03b33e631 --- /dev/null +++ b/kernel/src/net/socket/netlink/message/segment/header.rs @@ -0,0 +1,73 @@ +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct CMsgSegHdr { + /// Length of the message, including the header + pub len: u32, + /// Type of message content + pub type_: u16, + /// Additional flags + pub flags: u16, + /// Sequence number + pub seq: u32, + /// Sending process port ID + pub pid: u32, +} + +bitflags! { + pub struct SegHdrCommonFlags: u16 { + /// Indicates a request message + const REQUEST = 0x01; + /// Multipart message, terminated by NLMSG_DONE + const MULTI = 0x02; + /// Reply with an acknowledgment, with zero or an error code + const ACK = 0x04; + /// Echo this request + const ECHO = 0x08; + /// Dump was inconsistent due to sequence change + const DUMP_INTR = 0x10; + /// Dump was filtered as requested + const DUMP_FILTERED = 0x20; + } +} + +bitflags! 
{ + pub struct GetRequestFlags: u16 { + /// Specify the tree root + const ROOT = 0x100; + /// Return all matching results + const MATCH = 0x200; + /// Atomic get request + const ATOMIC = 0x400; + /// Combination flag for root and match + const DUMP = Self::ROOT.bits | Self::MATCH.bits; + } +} + +bitflags! { + pub struct NewRequestFlags: u16 { + /// Override existing entries + const REPLACE = 0x100; + /// Do not modify if it exists + const EXCL = 0x200; + /// Create if it does not exist + const CREATE = 0x400; + /// Add to the end of the list + const APPEND = 0x800; + } +} + +bitflags! { + pub struct DeleteRequestFlags: u16 { + /// Do not delete recursively + const NONREC = 0x100; + /// Delete multiple objects + const BULK = 0x200; + } +} + +bitflags! { + pub struct AckFlags: u16 { + const CAPPED = 0x100; + const ACK_TLVS = 0x100; + } +} diff --git a/kernel/src/net/socket/netlink/message/segment/mod.rs b/kernel/src/net/socket/netlink/message/segment/mod.rs new file mode 100644 index 000000000..803250db5 --- /dev/null +++ b/kernel/src/net/socket/netlink/message/segment/mod.rs @@ -0,0 +1,75 @@ +use system_error::SystemError; + +use crate::net::socket::netlink::message::{segment::header::CMsgSegHdr, NLMSG_ALIGN}; + +pub mod common; +pub mod header; + +#[repr(u16)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum CSegmentType { + // Standard netlink message types + NOOP = 1, + ERROR = 2, + DONE = 3, + OVERRUN = 4, + + // protocol-level types + NEWLINK = 16, + DELLINK = 17, + GETLINK = 18, + SETLINK = 19, + + NEWADDR = 20, + DELADDR = 21, + GETADDR = 22, + + NEWROUTE = 24, + DELROUTE = 25, + GETROUTE = 26, + // TODO 补充 +} + +impl TryFrom for CSegmentType { + type Error = SystemError; + + fn try_from(value: u16) -> Result { + match value { + 1 => Ok(CSegmentType::NOOP), + 2 => Ok(CSegmentType::ERROR), + 3 => Ok(CSegmentType::DONE), + 4 => Ok(CSegmentType::OVERRUN), + 16 => Ok(CSegmentType::NEWLINK), + 17 => Ok(CSegmentType::DELLINK), + 18 => 
Ok(CSegmentType::GETLINK), + 19 => Ok(CSegmentType::SETLINK), + 20 => Ok(CSegmentType::NEWADDR), + 21 => Ok(CSegmentType::DELADDR), + 22 => Ok(CSegmentType::GETADDR), + 24 => Ok(CSegmentType::NEWROUTE), + 25 => Ok(CSegmentType::DELROUTE), + 26 => Ok(CSegmentType::GETROUTE), + _ => Err(SystemError::EINVAL), + } + } +} + +pub trait SegmentBody: Sized + Clone + Copy { + type CType; + + fn read_from_buf(header: &CMsgSegHdr, buf: &[u8]) -> Result<(Self, usize), SystemError> + where + Self: Sized, + { + todo!() + } + + fn write_to_buf(&self, buf: &mut [u8]) -> Result<(), SystemError> { + todo!() + } + + fn padding_len() -> usize { + let payload_len = size_of::(); + payload_len.checked_add(NLMSG_ALIGN - 1).unwrap() & (!(NLMSG_ALIGN - 1) - payload_len) + } +} diff --git a/kernel/src/net/socket/netlink/mod.rs b/kernel/src/net/socket/netlink/mod.rs new file mode 100644 index 000000000..7d3ddb10d --- /dev/null +++ b/kernel/src/net/socket/netlink/mod.rs @@ -0,0 +1,44 @@ +use crate::net::socket::{ + family, + netlink::{route::NetlinkRouteSocket, table::StandardNetlinkProtocol}, + SocketInode, +}; +use alloc::sync::Arc; +use system_error::SystemError; + +pub mod addr; +mod common; +pub mod message; +mod receiver; +mod route; +mod table; + +pub struct Netlink; + +impl family::Family for Netlink { + fn socket( + stype: super::PSOCK, + protocol: u32, + ) -> Result, SystemError> { + match stype { + super::PSOCK::Raw | super::PSOCK::Datagram => create_netlink_socket(protocol), + _ => { + log::warn!("unsupported socket type for Netlink"); + Err(SystemError::EPROTONOSUPPORT) + } + } + } +} + +fn create_netlink_socket(protocol: u32) -> Result, SystemError> { + let nl_protocol = StandardNetlinkProtocol::try_from(protocol); + let inode = match nl_protocol { + Ok(StandardNetlinkProtocol::ROUTE) => NetlinkRouteSocket::new(false), + _ => { + log::warn!("unsupported Netlink protocol: {}", protocol); + return Err(SystemError::EPROTONOSUPPORT); + } + }; + + Ok(SocketInode::new(inode)) +} diff 
--git a/kernel/src/net/socket/netlink/receiver.rs b/kernel/src/net/socket/netlink/receiver.rs new file mode 100644 index 000000000..044dec098 --- /dev/null +++ b/kernel/src/net/socket/netlink/receiver.rs @@ -0,0 +1,51 @@ +use crate::libs::spinlock::SpinLock; +use crate::libs::wait_queue::WaitQueue; +use crate::process::ProcessState; +use alloc::collections::VecDeque; +use alloc::sync::Arc; +use system_error::SystemError; + +/// Netlink Socket 的消息队列 +#[derive(Debug)] +pub struct MessageQueue(pub Arc>>); + +impl Clone for MessageQueue { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl MessageQueue { + pub fn new() -> Self { + Self(Arc::new(SpinLock::new(VecDeque::new()))) + } + + fn enqueue(&self, message: Message) -> Result<(), SystemError> { + // FIXME: 确保消息队列不会超过最大长度 + self.0.lock().push_back(message); + Ok(()) + } +} + +/// Netlink Socket 的消息接收器,记录在全局的 Netlink Socket 表中,负责将消息压入对应的消息队列,并唤醒等待的线程 +#[derive(Debug)] +pub struct MessageReceiver { + message_queue: MessageQueue, + wait_queue: Arc, +} + +impl MessageReceiver { + pub fn new(message_queue: MessageQueue, wait_queue: Arc) -> Self { + Self { + message_queue, + wait_queue, + } + } + + pub fn enqueue_message(&self, message: Message) -> Result<(), SystemError> { + self.message_queue.enqueue(message)?; + // 唤醒等待队列中的线程 + self.wait_queue.wakeup(Some(ProcessState::Blocked(true))); + Ok(()) + } +} diff --git a/kernel/src/net/socket/netlink/route/bound.rs b/kernel/src/net/socket/netlink/route/bound.rs new file mode 100644 index 000000000..a5c220b5d --- /dev/null +++ b/kernel/src/net/socket/netlink/route/bound.rs @@ -0,0 +1,105 @@ +use crate::net::socket::{ + netlink::{ + addr::NetlinkSocketAddr, + common::bound::BoundNetlink, + message::ProtocolSegment, + route::{kernel::netlink_route_kernel, message::RouteNlMessage}, + }, + utils::datagram_common, + PMSG, +}; +use system_error::SystemError; + +impl datagram_common::Bound for BoundNetlink { + type Endpoint = NetlinkSocketAddr; + + fn bind(&mut self, 
endpoint: &Self::Endpoint) -> Result<(), SystemError> { + self.bind_common(endpoint) + } + + fn local_endpoint(&self) -> Self::Endpoint { + self.handle.addr() + } + + fn remote_endpoint(&self) -> Option { + Some(self.remote_addr) + } + + fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint) { + self.remote_addr = *endpoint; + } + + fn try_send( + &self, + buf: &[u8], + to: &Self::Endpoint, + _flags: crate::net::socket::PMSG, + ) -> Result { + if *to != NetlinkSocketAddr::new_unspecified() { + return Err(SystemError::ENOTCONN); + } + + if *to != NetlinkSocketAddr::new_unspecified() { + return Err(SystemError::ECONNREFUSED); + } + + let sum_lens = buf.len(); + + let mut nlmsg = match RouteNlMessage::read_from(buf) { + Ok(msg) => msg, + Err(e) if e == SystemError::EFAULT => { + // 说明这个时候 buf 不是一个完整的 netlink 消息 + return Err(e); + } + Err(e) => { + // 传播错误,静默处理 + log::warn!( + "netlink_send: failed to read netlink message from buffer: {:?}", + e + ); + return Ok(sum_lens); + } + }; + + let local_port = self.handle.port(); + + for segment in nlmsg.segments_mut() { + let header = segment.header_mut(); + if header.pid == 0 { + header.pid = local_port; + } + } + + netlink_route_kernel().request(&nlmsg, local_port); + + Ok(sum_lens) + } + + fn try_recv( + &self, + writer: &mut [u8], + flags: crate::net::socket::PMSG, + ) -> Result<(usize, Self::Endpoint), SystemError> { + let mut receive_queue = self.receive_queue.0.lock(); + + let Some(res) = receive_queue.front() else { + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + }; + + let len = { + let max = writer.len(); + res.total_len().min(max) + }; + + let _copied = res.write_to(writer)?; + + if !flags.contains(PMSG::PEEK) { + receive_queue.pop_front(); + } + + // todo 目前这个信息只能来自内核 + let remote = NetlinkSocketAddr::new_unspecified(); + + Ok((len, remote)) + } +} diff --git a/kernel/src/net/socket/netlink/route/kernel/mod.rs b/kernel/src/net/socket/netlink/route/kernel/mod.rs new file mode 100644 index 
000000000..fbe6bdbc5 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/kernel/mod.rs @@ -0,0 +1,27 @@ +//! # Netlink route kernel module +//! 内核对于 Netlink 路由的处理模块 + +use crate::net::socket::netlink::route::message::RouteNlMessage; +use core::marker::PhantomData; + +pub(super) struct NetlinkRouteKernelSocket { + _private: PhantomData<()>, +} + +impl NetlinkRouteKernelSocket { + const fn new() -> Self { + NetlinkRouteKernelSocket { + _private: PhantomData, + } + } + + pub(super) fn request(&self, request: &RouteNlMessage, dst_port: u32) {} +} + +/// 负责处理 Netlink 路由相关的内核模块 +/// todo net namespace 实现之后应该是每一个 namespace 都有一个独立的 NetlinkRouteKernelSocket +static NETLINK_ROUTE_KERNEL: NetlinkRouteKernelSocket = NetlinkRouteKernelSocket::new(); + +pub(super) fn netlink_route_kernel() -> &'static NetlinkRouteKernelSocket { + &NETLINK_ROUTE_KERNEL +} diff --git a/kernel/src/net/socket/netlink/route/message/attr/addr.rs b/kernel/src/net/socket/netlink/route/message/attr/addr.rs new file mode 100644 index 000000000..0381d98e7 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/attr/addr.rs @@ -0,0 +1,125 @@ +use crate::net::socket::netlink::{message::attr::Attribute, route::message::attr::IFNAME_SIZE}; +use alloc::ffi::CString; +use system_error::SystemError; + +#[derive(Debug, Clone, Copy)] +#[repr(u16)] +#[expect(non_camel_case_types)] +#[expect(clippy::upper_case_acronyms)] +enum AddrAttrClass { + UNSPEC = 0, + ADDRESS = 1, + LOCAL = 2, + LABEL = 3, + BROADCAST = 4, + ANYCAST = 5, + CACHEINFO = 6, + MULTICAST = 7, + FLAGS = 8, + RT_PRIORITY = 9, + TARGET_NETNSID = 10, +} + +impl TryFrom for AddrAttrClass { + type Error = SystemError; + + fn try_from(value: u16) -> Result { + match value { + 0 => Ok(AddrAttrClass::UNSPEC), + 1 => Ok(AddrAttrClass::ADDRESS), + 2 => Ok(AddrAttrClass::LOCAL), + 3 => Ok(AddrAttrClass::LABEL), + 4 => Ok(AddrAttrClass::BROADCAST), + 5 => Ok(AddrAttrClass::ANYCAST), + 6 => Ok(AddrAttrClass::CACHEINFO), + 7 => 
Ok(AddrAttrClass::MULTICAST), + 8 => Ok(AddrAttrClass::FLAGS), + 9 => Ok(AddrAttrClass::RT_PRIORITY), + 10 => Ok(AddrAttrClass::TARGET_NETNSID), + _ => Err(SystemError::EINVAL), + } + } +} + +#[derive(Debug)] +pub enum AddrAttr { + Address([u8; 4]), + Local([u8; 4]), + Label(CString), +} + +impl AddrAttr { + fn class(&self) -> AddrAttrClass { + match self { + AddrAttr::Address(_) => AddrAttrClass::ADDRESS, + AddrAttr::Local(_) => AddrAttrClass::LOCAL, + AddrAttr::Label(_) => AddrAttrClass::LABEL, + } + } +} + +impl Attribute for AddrAttr { + fn type_(&self) -> u16 { + self.class() as u16 + } + + fn payload_as_bytes(&self) -> &[u8] { + match self { + AddrAttr::Address(addr) => addr.as_ref(), + AddrAttr::Local(addr) => addr.as_ref(), + AddrAttr::Label(label) => label.to_bytes_with_nul(), + } + } + + fn read_from_buf( + header: &crate::net::socket::netlink::message::attr::CAttrHeader, + payload_buf: &[u8], + ) -> Result, SystemError> + where + Self: Sized, + { + let payload_len = header.payload_len(); + + let Ok(addr_class) = AddrAttrClass::try_from(header.type_()) else { + //todo 或许这里我应该返回偏移值 + //reader.skip_some(payload_len); + return Ok(None); + }; + + // 拷贝payload_buf到本地变量,避免生命周期问题 + let buf = &payload_buf[..payload_len.min(payload_buf.len())]; + + let res = match (addr_class, buf.len()) { + (AddrAttrClass::ADDRESS, 4) => { + let mut arr = [0u8; 4]; + arr.copy_from_slice(&buf[0..4]); + AddrAttr::Address(arr) + } + (AddrAttrClass::LOCAL, 4) => { + let mut arr = [0u8; 4]; + arr.copy_from_slice(&buf[0..4]); + AddrAttr::Local(arr) + } + (AddrAttrClass::LABEL, 1..=IFNAME_SIZE) => { + // 查找第一个0字节作为结尾,否则用全部 + let nul_pos = buf.iter().position(|&b| b == 0).unwrap_or(buf.len()); + let cstr = CString::new(&buf[..nul_pos]).map_err(|_| SystemError::EINVAL)?; + AddrAttr::Label(cstr) + } + (AddrAttrClass::ADDRESS | AddrAttrClass::LOCAL | AddrAttrClass::LABEL, _) => { + log::warn!( + "address attribute `{:?}` contains invalid payload", + addr_class + ); + return 
Err(SystemError::EINVAL); + } + (_, _) => { + log::warn!("address attribute `{:?}` is not supported", addr_class); + // reader.skip_some(payload_len); + return Ok(None); + } + }; + + Ok(Some(res)) + } +} diff --git a/kernel/src/net/socket/netlink/route/message/attr/mod.rs b/kernel/src/net/socket/netlink/route/message/attr/mod.rs new file mode 100644 index 000000000..d6de48bfb --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/attr/mod.rs @@ -0,0 +1,5 @@ +pub mod addr; +pub mod route; + +/// 网卡名字长度 +const IFNAME_SIZE: usize = 16; diff --git a/kernel/src/net/socket/netlink/route/message/attr/route.rs b/kernel/src/net/socket/netlink/route/message/attr/route.rs new file mode 100644 index 000000000..3ec62534b --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/attr/route.rs @@ -0,0 +1,95 @@ +use crate::net::socket::netlink::message::attr::Attribute; +use system_error::SystemError; + +/// 路由相关属性 +#[derive(Debug, Clone, Copy)] +#[repr(u16)] +enum RouteAttrClass { + UNSPEC = 0, + DST = 1, // 目标地址 + SRC = 2, // 源地址 + IIF = 3, // 输入接口 + OIF = 4, // 输出接口 + GATEWAY = 5, // 网关地址 + PRIORITY = 6, // 路由优先级 + PREFSRC = 7, // 首选源地址 + METRICS = 8, // 路由度量 + MULTIPATH = 9, // 多路径信息 + TABLE = 15, // 路由表ID +} + +impl TryFrom for RouteAttrClass { + type Error = SystemError; + + fn try_from(value: u16) -> Result { + match value { + 0 => Ok(RouteAttrClass::UNSPEC), + 1 => Ok(RouteAttrClass::DST), + 2 => Ok(RouteAttrClass::SRC), + 3 => Ok(RouteAttrClass::IIF), + 4 => Ok(RouteAttrClass::OIF), + 5 => Ok(RouteAttrClass::GATEWAY), + 6 => Ok(RouteAttrClass::PRIORITY), + 7 => Ok(RouteAttrClass::PREFSRC), + 8 => Ok(RouteAttrClass::METRICS), + 9 => Ok(RouteAttrClass::MULTIPATH), + 15 => Ok(RouteAttrClass::TABLE), + _ => Err(SystemError::EINVAL), + } + } +} + +#[derive(Debug)] +pub enum RouteAttr { + Dst([u8; 4]), // 目标地址 (IPv4) + Src([u8; 4]), // 源地址 (IPv4) + Gateway([u8; 4]), // 网关地址 (IPv4) + Oif(u32), // 输出接口索引 + Iif(u32), // 输入接口索引 + Priority(u32), // 路由优先级 + 
Prefsrc([u8; 4]), // 首选源地址 (IPv4) + Table(u32), // 路由表ID +} + +impl RouteAttr { + fn class(&self) -> RouteAttrClass { + match self { + RouteAttr::Dst(_) => RouteAttrClass::DST, + RouteAttr::Src(_) => RouteAttrClass::SRC, + RouteAttr::Gateway(_) => RouteAttrClass::GATEWAY, + RouteAttr::Oif(_) => RouteAttrClass::OIF, + RouteAttr::Iif(_) => RouteAttrClass::IIF, + RouteAttr::Priority(_) => RouteAttrClass::PRIORITY, + RouteAttr::Prefsrc(_) => RouteAttrClass::PREFSRC, + RouteAttr::Table(_) => RouteAttrClass::TABLE, + } + } +} + +impl Attribute for RouteAttr { + fn type_(&self) -> u16 { + self.class() as u16 + } + fn payload_as_bytes(&self) -> &[u8] { + // match self { + // RouteAttr::Dst(addr) + // | RouteAttr::Src(addr) + // | RouteAttr::Gateway(addr) + // | RouteAttr::Prefsrc(addr) => addr, + // RouteAttr::Oif(idx) | RouteAttr::Iif(idx) => idx.as_bytes(), + // RouteAttr::Priority(pri) => pri.as_bytes(), + // RouteAttr::Table(table) => table.as_bytes(), + // } + todo!() + } + + fn read_from_buf( + header: &crate::net::socket::netlink::message::attr::CAttrHeader, + payload_buf: &[u8], + ) -> Result, SystemError> + where + Self: Sized, + { + todo!() + } +} diff --git a/kernel/src/net/socket/netlink/route/message/mod.rs b/kernel/src/net/socket/netlink/route/message/mod.rs new file mode 100644 index 000000000..5f74d88ae --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/mod.rs @@ -0,0 +1,6 @@ +mod attr; +mod segment; + +use crate::net::socket::netlink::{message::Message, route::message::segment::RouteNlSegment}; + +pub(in crate::net::socket::netlink) type RouteNlMessage = Message; diff --git a/kernel/src/net/socket/netlink/route/message/segment/addr.rs b/kernel/src/net/socket/netlink/route/message/segment/addr.rs new file mode 100644 index 000000000..b09adaae5 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/segment/addr.rs @@ -0,0 +1,120 @@ +use core::num::NonZeroU32; + +use system_error::SystemError; + +use crate::net::socket::netlink::{ + 
message::segment::{common::SegmentCommon, SegmentBody}, + route::message::attr::addr::AddrAttr, +}; + +pub type AddrSegment = SegmentCommon; + +impl SegmentBody for AddrSegmentBody { + type CType = CIfaddrMsg; +} + +/// `ifaddrmsg` in Linux. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct CIfaddrMsg { + pub family: u8, + /// The prefix length + pub prefix_len: u8, + /// Flags + pub flags: u8, + /// Address scope + pub scope: u8, + /// Link index + pub index: u32, +} + +#[derive(Debug, Clone, Copy)] +pub struct AddrSegmentBody { + pub family: i32, + pub prefix_len: u8, + pub flags: AddrMessageFlags, + pub scope: RtScope, + pub index: Option, +} + +impl TryFrom for AddrSegmentBody { + type Error = SystemError; + + fn try_from(value: CIfaddrMsg) -> Result { + // TODO: If the attribute IFA_FLAGS exists, the flags in header should be ignored. + let flags = AddrMessageFlags::from_bits_truncate(value.flags as u32); + let scope = RtScope::try_from(value.scope as i32)?; + let index = NonZeroU32::new(value.index); + + Ok(Self { + family: value.family as i32, + prefix_len: value.prefix_len, + flags, + scope, + index, + }) + } +} + +impl From for CIfaddrMsg { + fn from(value: AddrSegmentBody) -> Self { + let index = if let Some(index) = value.index { + index.get() + } else { + 0 + }; + CIfaddrMsg { + family: value.family as u8, + prefix_len: value.prefix_len, + flags: value.flags.bits() as u8, + scope: value.scope as _, + index, + } + } +} + +bitflags! { + /// Flags in [`CIfaddrMsg`]. + pub struct AddrMessageFlags: u32 { + const SECONDARY = 0x01; + const NODAD = 0x02; + const OPTIMISTIC = 0x04; + const DADFAILED = 0x08; + const HOMEADDRESS = 0x10; + const DEPRECATED = 0x20; + const TENTATIVE = 0x40; + const PERMANENT = 0x80; + const MANAGETEMPADDR = 0x100; + const NOPREFIXROUTE = 0x200; + const MCAUTOJOIN = 0x400; + const STABLE_PRIVACY = 0x800; + } +} + +/// `rt_scope_t` in Linux. 
+#[repr(u8)] +#[derive(Debug, Clone, Copy)] +#[expect(clippy::upper_case_acronyms)] +pub enum RtScope { + UNIVERSE = 0, + // User defined values + SITE = 200, + LINK = 253, + HOST = 254, + NOWHERE = 255, +} + +impl TryFrom for RtScope { + type Error = SystemError; + + fn try_from(value: i32) -> Result { + match value { + 0 => Ok(RtScope::UNIVERSE), + 200 => Ok(RtScope::SITE), + 253 => Ok(RtScope::LINK), + 254 => Ok(RtScope::HOST), + 255 => Ok(RtScope::NOWHERE), + _ => Err(SystemError::EINVAL), + } + } +} diff --git a/kernel/src/net/socket/netlink/route/message/segment/mod.rs b/kernel/src/net/socket/netlink/route/message/segment/mod.rs new file mode 100644 index 000000000..dfeeed391 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/segment/mod.rs @@ -0,0 +1,97 @@ +pub mod addr; +pub mod route; + +use crate::net::socket::netlink::{ + message::{ + segment::{header::CMsgSegHdr, CSegmentType}, + ProtocolSegment, + }, + route::message::segment::{addr::AddrSegment, route::RouteSegment}, +}; +use alloc::vec::Vec; +use system_error::SystemError; + +#[derive(Debug)] +pub enum RouteNlSegment { + // NewLink(LinkSegment), + // GetLink(LinkSegment), + NewAddr(AddrSegment), + GetAddr(AddrSegment), + // Done(DoneSegment), + // Error(ErrorSegment), + NewRoute(RouteSegment), + DelRoute(RouteSegment), + GetRoute(RouteSegment), +} + +impl ProtocolSegment for RouteNlSegment { + fn header(&self) -> &crate::net::socket::netlink::message::segment::header::CMsgSegHdr { + match self { + RouteNlSegment::NewRoute(route_segment) + | RouteNlSegment::DelRoute(route_segment) + | RouteNlSegment::GetRoute(route_segment) => route_segment.header(), + RouteNlSegment::NewAddr(addr_segment) | RouteNlSegment::GetAddr(addr_segment) => { + addr_segment.header() + } + } + } + + fn header_mut( + &mut self, + ) -> &mut crate::net::socket::netlink::message::segment::header::CMsgSegHdr { + match self { + RouteNlSegment::NewRoute(route_segment) + | RouteNlSegment::DelRoute(route_segment) + | 
RouteNlSegment::GetRoute(route_segment) => route_segment.header_mut(), + RouteNlSegment::NewAddr(addr_segment) | RouteNlSegment::GetAddr(addr_segment) => { + addr_segment.header_mut() + } + } + } + + fn read_from(buf: &[u8]) -> Result { + if buf.len() < size_of::() { + log::warn!("the buffer is too small to read a netlink segment header"); + return Err(SystemError::EINVAL); + } + + let header = unsafe { *(buf.as_ptr() as *const CMsgSegHdr) }; + + let segment = match CSegmentType::try_from(header.type_)? { + CSegmentType::GETADDR => { + RouteNlSegment::GetAddr(AddrSegment::read_from_buf(header, buf)?) + } + CSegmentType::GETROUTE => { + RouteNlSegment::GetRoute(RouteSegment::read_from_buf(header, buf)?) + } + _ => return Err(SystemError::EINVAL), + }; + + Ok(segment) + } + + fn write_to(&self, buf: &mut [u8]) -> Result { + let mut kernel_buf: Vec = vec![]; + match self { + RouteNlSegment::NewAddr(addr_segment) => addr_segment.write_to_buf(&mut kernel_buf)?, + RouteNlSegment::NewRoute(route_segment) => { + route_segment.write_to_buf(&mut kernel_buf)? 
+ } + _ => { + log::warn!("write_to is not implemented for this segment type"); + return Err(SystemError::ENOSYS); + } + } + + let actual_len = kernel_buf.len().min(buf.len()); + let copied = if !kernel_buf.is_empty() { + buf[..actual_len].copy_from_slice(&kernel_buf[..actual_len]); + actual_len + } else { + // 如果没有数据需要写入,返回0 + 0 + }; + + Ok(copied) + } +} diff --git a/kernel/src/net/socket/netlink/route/message/segment/route.rs b/kernel/src/net/socket/netlink/route/message/segment/route.rs new file mode 100644 index 000000000..3904631df --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/segment/route.rs @@ -0,0 +1,196 @@ +use system_error::SystemError; + +use crate::net::socket::{ + netlink::{ + message::segment::{common::SegmentCommon, SegmentBody}, + route::message::attr::route::RouteAttr, + }, + AddressFamily, +}; + +pub type RouteSegment = SegmentCommon; + +impl SegmentBody for RouteSegmentBody { + type CType = CRtMsg; +} + +/// `rtmsg` in Linux +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct CRtMsg { + /// 地址族 (AF_INET/AF_INET6) + pub family: u8, + /// 目标地址前缀长度 + pub dst_len: u8, + /// 源地址前缀长度 + pub src_len: u8, + /// 服务类型/DSCP + pub tos: u8, + /// 路由表ID + pub table: u8, + /// 路由协议 + pub protocol: u8, + /// 路由作用域 + pub scope: u8, + /// 路由类型 + pub type_: u8, + /// 路由标志 + pub flags: u32, +} + +#[derive(Debug, Clone, Copy)] +pub struct RouteSegmentBody { + pub family: AddressFamily, + pub dst_len: u8, + pub src_len: u8, + pub tos: u8, + pub table: RouteTable, + pub protocol: RouteProtocol, + pub scope: RouteScope, + pub type_: RouteType, + pub flags: RouteFlags, +} + +// 定义路由相关的枚举类型 +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum RouteTable { + Unspec = 0, + Compat = 252, + Default = 253, + Main = 254, + Local = 255, +} + +impl TryFrom for RouteTable { + type Error = SystemError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(RouteTable::Unspec), + 252 => Ok(RouteTable::Compat), + 253 => Ok(RouteTable::Default), + 254 
=> Ok(RouteTable::Main), + 255 => Ok(RouteTable::Local), + _ => Err(SystemError::EINVAL), + } + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum RouteProtocol { + Unspec = 0, + Redirect = 1, + Kernel = 2, + Boot = 3, + Static = 4, + // 添加更多协议... +} + +impl TryFrom for RouteProtocol { + type Error = SystemError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(RouteProtocol::Unspec), + 1 => Ok(RouteProtocol::Redirect), + 2 => Ok(RouteProtocol::Kernel), + 3 => Ok(RouteProtocol::Boot), + 4 => Ok(RouteProtocol::Static), + _ => Err(SystemError::EINVAL), + } + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum RouteScope { + Universe = 0, + Site = 200, + Link = 253, + Host = 254, + Nowhere = 255, +} + +impl TryFrom for RouteScope { + type Error = SystemError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(RouteScope::Universe), + 200 => Ok(RouteScope::Site), + 253 => Ok(RouteScope::Link), + 254 => Ok(RouteScope::Host), + 255 => Ok(RouteScope::Nowhere), + _ => Err(SystemError::EINVAL), + } + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum RouteType { + Unspec = 0, + Unicast = 1, + Local = 2, + Broadcast = 3, + Anycast = 4, + Multicast = 5, + Blackhole = 6, + Unreachable = 7, + Prohibit = 8, +} + +impl TryFrom for RouteType { + type Error = SystemError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(RouteType::Unspec), + 1 => Ok(RouteType::Unicast), + 2 => Ok(RouteType::Local), + 3 => Ok(RouteType::Broadcast), + 4 => Ok(RouteType::Anycast), + 5 => Ok(RouteType::Multicast), + 6 => Ok(RouteType::Blackhole), + 7 => Ok(RouteType::Unreachable), + 8 => Ok(RouteType::Prohibit), + _ => Err(SystemError::EINVAL), + } + } +} + +bitflags::bitflags! 
{ + pub struct RouteFlags: u32 { + const NOTIFY = 0x100; + const CLONED = 0x200; + const EQUALIZE = 0x400; + const PREFIX = 0x800; + } +} + +impl TryFrom for RouteSegmentBody { + type Error = SystemError; + + fn try_from(value: CRtMsg) -> Result { + let family = AddressFamily::try_from(value.family as u16)?; + let table = RouteTable::try_from(value.table)?; + let protocol = RouteProtocol::try_from(value.protocol)?; + let scope = RouteScope::try_from(value.scope)?; + let type_ = RouteType::try_from(value.type_)?; + let flags = RouteFlags::from_bits_truncate(value.flags); + + Ok(Self { + family, + dst_len: value.dst_len, + src_len: value.src_len, + tos: value.tos, + table, + protocol, + scope, + type_, + flags, + }) + } +} diff --git a/kernel/src/net/socket/netlink/route/mod.rs b/kernel/src/net/socket/netlink/route/mod.rs new file mode 100644 index 000000000..20bc0eb46 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/mod.rs @@ -0,0 +1,7 @@ +use crate::net::socket::netlink::{common::NetlinkSocket, table::NetlinkRouteProtocol}; + +pub mod bound; +mod kernel; +pub mod message; + +pub type NetlinkRouteSocket = NetlinkSocket; diff --git a/kernel/src/net/socket/netlink/table/mod.rs b/kernel/src/net/socket/netlink/table/mod.rs new file mode 100644 index 000000000..6df514b3e --- /dev/null +++ b/kernel/src/net/socket/netlink/table/mod.rs @@ -0,0 +1,308 @@ +mod multicast; + +use crate::net::socket::netlink::addr::multicast::GroupIdSet; +use crate::net::socket::netlink::route::message::RouteNlMessage; +use crate::net::socket::netlink::table::multicast::MulticastMessage; +use crate::process::ProcessManager; +use crate::{libs::rand, net::socket::netlink::addr::NetlinkSocketAddr}; +use crate::{ + libs::{once::Once, rwlock::RwLock}, + net::socket::netlink::{receiver::MessageReceiver, table::multicast::MulticastGroup}, +}; +use alloc::boxed::Box; +use alloc::collections::BTreeMap; +use alloc::fmt::Debug; +use system_error::SystemError; + +static mut NETLINK_SOCKET_TABLE: 
Option = None; + +const MAX_ALLOWED_PROTOCOL_ID: u32 = 32; +const MAX_GROUPS: u32 = 32; + +struct NetlinkSocketTable { + route: RwLock>, +} + +impl NetlinkSocketTable { + pub fn new() -> Self { + Self { + route: RwLock::new(ProtocolSocketTable::new()), + } + } +} + +#[derive(Debug)] +pub struct ProtocolSocketTable { + unicast_sockets: BTreeMap>, + multicast_groups: Box<[MulticastGroup]>, +} + +impl ProtocolSocketTable { + fn new() -> Self { + let multicast_groups = (0u32..MAX_GROUPS).map(|_| MulticastGroup::new()).collect(); + Self { + unicast_sockets: BTreeMap::new(), + multicast_groups, + } + } + + fn bind( + &mut self, + socket_table: &'static RwLock>, + addr: &NetlinkSocketAddr, + receiver: MessageReceiver, + ) -> Result, SystemError> { + let port = if addr.port() != 0 { + addr.port() + } else { + let mut random_port = ProcessManager::current_pid().data() as u32; + while random_port == 0 || self.unicast_sockets.contains_key(&random_port) { + random_port = rand::soft_rand() as u32; + } + random_port + }; + + if self.unicast_sockets.contains_key(&port) { + return Err(SystemError::EADDRINUSE); + } + + self.unicast_sockets.insert(port, receiver); + + for group_id in addr.groups().ids_iter() { + let group = &mut self.multicast_groups[group_id as usize]; + group.add_member(port); + } + + Ok(BoundHandle::new(socket_table, port, addr.groups())) + } + + fn unicast(&self, dst_port: u32, message: Message) -> Result<(), SystemError> { + let Some(receiver) = self.unicast_sockets.get(&dst_port) else { + return Ok(()); + }; + receiver.enqueue_message(message) + } + + fn multicast(&self, dst_groups: GroupIdSet, message: Message) -> Result<(), SystemError> + where + Message: MulticastMessage, + { + for group_id in dst_groups.ids_iter() { + let Some(group) = self.multicast_groups.get(group_id as usize) else { + continue; + }; + for member in group.members() { + if let Some(receiver) = self.unicast_sockets.get(member) { + receiver.enqueue_message(message.clone())?; + } + } + } + 
Ok(()) + } +} + +#[derive(Debug)] +pub struct BoundHandle { + socket_table: &'static RwLock>, + port: u32, + groups: GroupIdSet, +} + +impl BoundHandle { + fn new( + socket_table: &'static RwLock>, + port: u32, + groups: GroupIdSet, + ) -> Self { + Self { + socket_table, + port, + groups, + } + } + + pub(super) const fn port(&self) -> u32 { + self.port + } + + pub(super) fn addr(&self) -> NetlinkSocketAddr { + NetlinkSocketAddr::new(self.port, self.groups) + } + + pub(super) fn add_groups(&mut self, groups: GroupIdSet) { + let mut protocol_sockets = self.socket_table.write(); + + for group_id in groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.add_member(self.port); + } + + self.groups.add_groups(groups); + } + + pub(super) fn drop_groups(&mut self, groups: GroupIdSet) { + let mut protocol_sockets = self.socket_table.write(); + + for group_id in groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.remove_member(self.port); + } + + self.groups.drop_groups(groups); + } + + pub(super) fn bind_groups(&mut self, groups: GroupIdSet) { + let mut protocol_sockets = self.socket_table.write(); + + for group_id in self.groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.remove_member(self.port); + } + + for group_id in groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.add_member(self.port); + } + + self.groups = groups; + } +} + +impl Drop for BoundHandle { + fn drop(&mut self) { + let mut protocol_sockets = self.socket_table.write(); + + protocol_sockets.unicast_sockets.remove(&self.port); + + for group_id in self.groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.remove_member(self.port); + } + } +} + +pub trait SupportedNetlinkProtocol: Debug { + type Message: 'static + Send + Debug; + + fn socket_table() -> &'static RwLock>; + 
+ fn bind( + addr: &NetlinkSocketAddr, + receiver: MessageReceiver, + ) -> Result, SystemError> { + let mut socket_table = Self::socket_table().write(); + socket_table.bind(Self::socket_table(), addr, receiver) + } + + fn unicast(dst_port: u32, message: Self::Message) -> Result<(), SystemError> { + let socket_table = Self::socket_table().read(); + socket_table.unicast(dst_port, message) + } + + fn multicast(dst_groups: GroupIdSet, message: Self::Message) -> Result<(), SystemError> + where + Self::Message: MulticastMessage, + { + let socket_table = Self::socket_table().read(); + socket_table.multicast(dst_groups, message) + } +} + +#[derive(Debug)] +pub struct NetlinkRouteProtocol; + +impl SupportedNetlinkProtocol for NetlinkRouteProtocol { + type Message = RouteNlMessage; + + fn socket_table() -> &'static RwLock> { + unsafe { &NETLINK_SOCKET_TABLE.as_ref().unwrap().route } + } +} + +pub fn init() { + let once = Once::new(); + once.call_once(|| unsafe { + NETLINK_SOCKET_TABLE = Some(NetlinkSocketTable::new()); + }); +} + +pub fn is_valid_protocol(protocol: u32) -> bool { + protocol < MAX_ALLOWED_PROTOCOL_ID +} + +#[expect(non_camel_case_types)] +#[repr(u32)] +#[derive(Debug, Clone, Copy)] +pub enum StandardNetlinkProtocol { + /// Routing/device hook + ROUTE = 0, + /// Unused number + UNUSED = 1, + /// Reserved for user mode socket protocols + USERSOCK = 2, + /// Unused number, formerly ip_queue + FIREWALL = 3, + /// Socket monitoring + SOCK_DIAG = 4, + /// Netfilter/iptables ULOG + NFLOG = 5, + /// IPsec + XFRM = 6, + /// SELinux event notifications + SELINUX = 7, + /// Open-iSCSI + ISCSI = 8, + /// Auditing + AUDIT = 9, + FIB_LOOKUP = 10, + CONNECTOR = 11, + /// Netfilter subsystem + NETFILTER = 12, + IP6_FW = 13, + /// DECnet routing messages + DNRTMSG = 14, + /// Kernel messages to userspace + KOBJECT_UEVENT = 15, + GENERIC = 16, + /// Leave room for NETLINK_DM (DM Events) + /// SCSI Transports + SCSITRANSPORT = 18, + ECRYPTFS = 19, + RDMA = 20, + /// Crypto 
layer + CRYPTO = 21, + /// SMC monitoring + SMC = 22, +} + +impl TryFrom for StandardNetlinkProtocol { + type Error = (); + + fn try_from(value: u32) -> Result { + match value { + 0 => Ok(StandardNetlinkProtocol::ROUTE), + 1 => Ok(StandardNetlinkProtocol::UNUSED), + 2 => Ok(StandardNetlinkProtocol::USERSOCK), + 3 => Ok(StandardNetlinkProtocol::FIREWALL), + 4 => Ok(StandardNetlinkProtocol::SOCK_DIAG), + 5 => Ok(StandardNetlinkProtocol::NFLOG), + 6 => Ok(StandardNetlinkProtocol::XFRM), + 7 => Ok(StandardNetlinkProtocol::SELINUX), + 8 => Ok(StandardNetlinkProtocol::ISCSI), + 9 => Ok(StandardNetlinkProtocol::AUDIT), + 10 => Ok(StandardNetlinkProtocol::FIB_LOOKUP), + 11 => Ok(StandardNetlinkProtocol::CONNECTOR), + 12 => Ok(StandardNetlinkProtocol::NETFILTER), + 13 => Ok(StandardNetlinkProtocol::IP6_FW), + 14 => Ok(StandardNetlinkProtocol::DNRTMSG), + 15 => Ok(StandardNetlinkProtocol::KOBJECT_UEVENT), + 16 => Ok(StandardNetlinkProtocol::GENERIC), + 18 => Ok(StandardNetlinkProtocol::SCSITRANSPORT), + 19 => Ok(StandardNetlinkProtocol::ECRYPTFS), + 20 => Ok(StandardNetlinkProtocol::RDMA), + 21 => Ok(StandardNetlinkProtocol::CRYPTO), + 22 => Ok(StandardNetlinkProtocol::SMC), + _ => Err(()), + } + } +} diff --git a/kernel/src/net/socket/netlink/table/multicast.rs b/kernel/src/net/socket/netlink/table/multicast.rs new file mode 100644 index 000000000..c4a7ab333 --- /dev/null +++ b/kernel/src/net/socket/netlink/table/multicast.rs @@ -0,0 +1,30 @@ +use alloc::collections::BTreeSet; + +#[derive(Debug)] +pub struct MulticastGroup { + // portnumber: u32, + members: BTreeSet, +} + +impl MulticastGroup { + pub const fn new() -> Self { + Self { + members: BTreeSet::new(), + } + } + + pub fn add_member(&mut self, port_num: u32) { + self.members.insert(port_num); + } + + pub fn remove_member(&mut self, port_num: u32) { + self.members.remove(&port_num); + } + + pub fn members(&self) -> &BTreeSet { + &self.members + } +} + +/// Uevent that can be sent to multicast groups. 
+pub trait MulticastMessage: Clone {} diff --git a/kernel/src/net/socket/utils/datagram_common.rs b/kernel/src/net/socket/utils/datagram_common.rs new file mode 100644 index 000000000..a86b73c87 --- /dev/null +++ b/kernel/src/net/socket/utils/datagram_common.rs @@ -0,0 +1,178 @@ +use crate::{ + libs::{rwlock::RwLock, wait_queue::WaitQueue}, + net::socket::PMSG, +}; +use alloc::sync::Arc; +use core::panic; +use system_error::SystemError; + +pub trait Unbound { + type Endpoint; + type Bound; + + fn bind( + &mut self, + endpoint: &Self::Endpoint, + wait_queue: Arc, + ) -> Result; + + fn bind_ephemeral( + &mut self, + endpoint: &Self::Endpoint, + wait_queue: Arc, + ) -> Result; +} + +pub trait Bound { + type Endpoint: Clone; + + fn bind(&mut self, _endpoint: &Self::Endpoint) -> Result<(), SystemError> { + Err(SystemError::EINVAL) + } + + fn local_endpoint(&self) -> Self::Endpoint; + + fn remote_endpoint(&self) -> Option; + + fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint); + + fn try_recv( + &self, + writer: &mut [u8], + flags: PMSG, + ) -> Result<(usize, Self::Endpoint), SystemError>; + + fn try_send(&self, buf: &[u8], to: &Self::Endpoint, flags: PMSG) -> Result; +} + +#[derive(Debug)] +pub enum Inner { + Unbound(UnboundSocket), + Bound(BoundSocket), +} + +impl Inner +where + UnboundSocket: Unbound, + BoundSocket: Bound, +{ + pub fn bind( + &mut self, + endpoint: &UnboundSocket::Endpoint, + wait_queue: Arc, + ) -> Result<(), SystemError> { + let unbound = match self { + Inner::Bound(bound) => return bound.bind(endpoint), + Inner::Unbound(unbound) => unbound, + }; + + let bound = unbound.bind(endpoint, wait_queue)?; + *self = Inner::Bound(bound); + + Ok(()) + } + + pub fn bind_ephemeral( + &mut self, + remote_endpoint: &UnboundSocket::Endpoint, + wait_queue: Arc, + ) -> Result<(), SystemError> { + let unbound_datagram = match self { + Inner::Unbound(unbound) => unbound, + Inner::Bound(_) => return Ok(()), + }; + + let bound = 
unbound_datagram.bind_ephemeral(remote_endpoint, wait_queue)?; + *self = Inner::Bound(bound); + + Ok(()) + } + + pub fn connect( + &mut self, + remote_endpoint: &UnboundSocket::Endpoint, + wait_queue: Arc, + ) -> Result<(), SystemError> { + self.bind_ephemeral(remote_endpoint, wait_queue)?; + + let bound = match self { + Inner::Unbound(_) => { + unreachable!( + "`bind_to_ephemeral_endpoint` succeeds so the socket cannot be unbound" + ); + } + Inner::Bound(bound_datagram) => bound_datagram, + }; + bound.set_remote_endpoint(remote_endpoint); + + Ok(()) + } + + pub fn addr(&self) -> Option { + match self { + Inner::Unbound(_) => None, + Inner::Bound(bound) => bound.remote_endpoint(), + } + } + + pub fn peer_addr(&self) -> Option { + match self { + Inner::Unbound(_) => None, + Inner::Bound(bound) => bound.remote_endpoint(), + } + } + + pub fn try_recv( + &self, + writer: &mut [u8], + flags: PMSG, + ) -> Result<(usize, UnboundSocket::Endpoint), SystemError> { + match self { + Inner::Unbound(_) => Err(SystemError::EAGAIN_OR_EWOULDBLOCK), + Inner::Bound(bound) => bound.try_recv(writer, flags), + } + } + + // try_send 在下面:) +} + +pub fn select_remote_and_bind( + inner_lock: &RwLock>, + remote: Option, + bind_ephemeral: B, + op: F, +) -> Result +where + UnboundSocket: Unbound, + BoundSocket: Bound, + B: FnOnce() -> Result<(), SystemError>, + F: FnOnce(&BoundSocket, UnboundSocket::Endpoint) -> Result, +{ + let mut inner = inner_lock.read(); + + // 这里用 loop 只是为了用 break :) + #[expect(clippy::never_loop)] + let bound = loop { + if let Inner::Bound(bound) = &*inner { + break bound; + } + + drop(inner); + bind_ephemeral()?; + + inner = inner_lock.read(); + + if let Inner::Bound(bound_datagram) = &*inner { + break bound_datagram; + } + + panic!(""); + }; + + let remote_endpoint = match remote { + Some(r) => r.clone(), + None => bound.remote_endpoint().ok_or(SystemError::EDESTADDRREQ)?, + }; + + op(bound, remote_endpoint) +} diff --git a/kernel/src/net/socket/utils.rs 
b/kernel/src/net/socket/utils/mod.rs similarity index 60% rename from kernel/src/net/socket/utils.rs rename to kernel/src/net/socket/utils/mod.rs index f3adbb3f5..96bf443b4 100644 --- a/kernel/src/net/socket/utils.rs +++ b/kernel/src/net/socket/utils/mod.rs @@ -1,6 +1,9 @@ use crate::net::socket::{ self, inet::syscall::create_inet_socket, unix::create_unix_socket, Socket, }; +pub(super) mod datagram_common; + +use crate::net::socket; use alloc::sync::Arc; use system_error::SystemError; @@ -14,19 +17,10 @@ pub fn create_socket( // log::info!("Creating socket: {:?}, {:?}, {:?}", family, socket_type, protocol); type AF = socket::AddressFamily; let inode = match family { - AF::INet => create_inet_socket( - smoltcp::wire::IpVersion::Ipv4, - socket_type, - smoltcp::wire::IpProtocol::from(protocol as u8), - is_nonblock, - )?, - AF::INet6 => create_inet_socket( - smoltcp::wire::IpVersion::Ipv6, - socket_type, - smoltcp::wire::IpProtocol::from(protocol as u8), - is_nonblock, - )?, - AF::Unix => create_unix_socket(socket_type, is_nonblock)?, + AF::INet => socket::inet::Inet::socket(socket_type, protocol)?, + // AF::INet6 => socket::inet::Inet6::socket(socket_type, protocol)?, + AF::Unix => socket::unix::Unix::socket(socket_type, protocol)?, + AF::Netlink => socket::netlink::Netlink::socket(socket_type, protocol)?, _ => { log::warn!("unsupport address family"); return Err(SystemError::EAFNOSUPPORT); From 7a8faa28123695439e7b67c2991ad6b963aa2e8d Mon Sep 17 00:00:00 2001 From: sparkzky Date: Mon, 18 Aug 2025 12:05:07 +0800 Subject: [PATCH 15/36] =?UTF-8?q?feat(netlink):=20=E5=AE=8C=E5=96=84netlin?= =?UTF-8?q?k=E7=9A=84=E8=AF=BB=E5=86=99=E9=83=A8=E5=88=86,=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0addr=E7=9A=84=E5=86=85=E6=A0=B8=E5=A4=84=E7=90=86?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/mod.rs | 13 +++ .../net/socket/netlink/message/attr/mod.rs | 83 
++++++++---------- .../net/socket/netlink/message/segment/ack.rs | 84 +++++++++++++++++++ .../socket/netlink/message/segment/common.rs | 13 +-- .../net/socket/netlink/message/segment/mod.rs | 59 +++++++++++-- .../net/socket/netlink/route/kernel/addr.rs | 83 ++++++++++++++++++ .../net/socket/netlink/route/kernel/mod.rs | 39 ++++++++- .../net/socket/netlink/route/kernel/utils.rs | 39 +++++++++ .../socket/netlink/route/message/attr/addr.rs | 1 + .../net/socket/netlink/route/message/mod.rs | 4 +- .../netlink/route/message/segment/mod.rs | 15 +++- .../netlink/route/message/segment/route.rs | 16 ++++ 12 files changed, 378 insertions(+), 71 deletions(-) create mode 100644 kernel/src/net/socket/netlink/message/segment/ack.rs create mode 100644 kernel/src/net/socket/netlink/route/kernel/addr.rs create mode 100644 kernel/src/net/socket/netlink/route/kernel/utils.rs diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index b2906de88..ffa35eb2c 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -1,5 +1,6 @@ use alloc::{fmt, vec::Vec}; use alloc::{string::String, sync::Arc}; +use core::net::Ipv4Addr; use sysfs::netdev_register_kobject; use crate::{ @@ -305,4 +306,16 @@ impl IfaceCommon { pub fn is_default_iface(&self) -> bool { self.default_iface } + + pub fn ipv4_addr(&self) -> Option { + self.smol_iface.lock().ipv4_addr() + } + + pub fn prefix_len(&self) -> Option { + self.smol_iface + .lock() + .ip_addrs() + .first() + .map(|ip_addr| ip_addr.prefix_len()) + } } diff --git a/kernel/src/net/socket/netlink/message/attr/mod.rs b/kernel/src/net/socket/netlink/message/attr/mod.rs index c743ae142..8b10afe4f 100644 --- a/kernel/src/net/socket/netlink/message/attr/mod.rs +++ b/kernel/src/net/socket/netlink/message/attr/mod.rs @@ -65,44 +65,13 @@ pub trait Attribute: core::fmt::Debug + Send + Sync { where Self: Sized; - fn write_to_buf(&self, buf: &mut Vec) -> Result<(), SystemError> { - // let payload_bytes = self.payload_as_bytes(); 
- // let header = CAttrHeader { - // len: (core::mem::size_of::() + payload_bytes.len()) as u16, - // type_: self.type_(), - // }; - - // let total_len = header.len as usize; - // let padded_len = align_to(total_len, NLMSG_ALIGN); - // let padding_len = padded_len - total_len; - - // // 确保 buf 足够大 - // if buf.len() < offset + padded_len { - // buf.resize(offset + padded_len, 0); - // } - - // // 写入头部 - // let header_bytes = unsafe { - // core::slice::from_raw_parts( - // &header as *const CAttrHeader as *const u8, - // core::mem::size_of::(), - // ) - // }; - // buf[offset..offset + header_bytes.len()].copy_from_slice(header_bytes); - - // // 写入负载 - // buf[offset + header_bytes.len()..offset + header_bytes.len() + payload_bytes.len()] - // .copy_from_slice(payload_bytes); - - // // 填充部分已经在 resize 时置零,无需额外处理 - - // Ok(()) - + fn write_to_buf(&self, buf: &mut Vec) -> Result { + let type_: u16 = self.type_(); let payload_bytes = self.payload_as_bytes(); - let header = CAttrHeader { - len: (size_of::() + payload_bytes.len()) as u16, - type_: self.type_(), - }; + let header = CAttrHeader::from_payload_len(type_, payload_bytes.len()); + let total_len = header.total_len_with_padding(); + + // let mut current_offset = offset; // 写入头部 let header_bytes = unsafe { @@ -111,29 +80,37 @@ pub trait Attribute: core::fmt::Debug + Send + Sync { size_of::(), ) }; + // buf[current_offset..current_offset + header_bytes.len()].copy_from_slice(header_bytes); buf.extend_from_slice(header_bytes); + // current_offset += header_bytes.len(); // 写入负载 + // buf[current_offset..current_offset + payload_bytes.len()].copy_from_slice(payload_bytes); buf.extend_from_slice(payload_bytes); + // current_offset += payload_bytes.len(); // 添加对齐填充 - let total_len = header.len as usize; - let padded_len = align_to(total_len, NLMSG_ALIGN); - let padding_len = padded_len - total_len; + let padding_len = header.padding_len(); if padding_len > 0 { - buf.extend_from_slice(&vec![0u8; padding_len]); + // 
buf[current_offset..current_offset + padding_len].fill(0); + buf.extend(vec![0u8; padding_len]); } - Ok(()) + Ok(total_len) } - fn read_all_from_buf(buf: &[u8], mut offset: usize) -> Result, SystemError> + fn read_all_from_buf(buf: &[u8], mut total_len: usize) -> Result, SystemError> where Self: Sized, { let mut attrs = Vec::new(); + let mut offset = 0; + + while total_len > 0 { + if total_len < size_of::() { + return Err(SystemError::EINVAL); + } - while offset < buf.len() { // 检查是否有足够的字节读取属性头部 if buf.len() - offset < size_of::() { return Err(SystemError::EINVAL); @@ -144,18 +121,21 @@ pub trait Attribute: core::fmt::Debug + Send + Sync { let attr_header = unsafe { *(attr_header_bytes.as_ptr() as *const CAttrHeader) }; // 验证属性长度 - if attr_header.len < size_of::() as u16 { + if attr_header.total_len() < size_of::() { return Err(SystemError::EINVAL); } - let attr_total_len = attr_header.len as usize; - if buf.len() - offset < attr_total_len { + total_len = total_len + .checked_sub(attr_header.total_len()) + .ok_or(SystemError::EINVAL)?; + + if buf.len() - offset < attr_header.total_len() { return Err(SystemError::EINVAL); } // 读取属性负载 let payload_start = offset + size_of::(); - let payload_len = attr_total_len - size_of::(); + let payload_len = attr_header.payload_len(); let payload_buf = &buf[payload_start..payload_start + payload_len]; // 解析属性 @@ -164,8 +144,11 @@ pub trait Attribute: core::fmt::Debug + Send + Sync { } // 移动到下一个属性(考虑对齐) - let padded_len = align_to(attr_total_len, NLMSG_ALIGN); - offset += padded_len; + let attr_total_with_padding = attr_header.total_len_with_padding(); + offset += attr_total_with_padding; + + let padding_len = total_len.min(attr_header.padding_len()); + total_len -= padding_len; } Ok(attrs) diff --git a/kernel/src/net/socket/netlink/message/segment/ack.rs b/kernel/src/net/socket/netlink/message/segment/ack.rs new file mode 100644 index 000000000..fa2dcaade --- /dev/null +++ b/kernel/src/net/socket/netlink/message/segment/ack.rs @@ 
-0,0 +1,84 @@ +use crate::net::socket::netlink::message::{ + attr::noattr::NoAttr, + segment::{ + common::SegmentCommon, + header::{CMsgSegHdr, SegHdrCommonFlags}, + CSegmentType, SegmentBody, + }, +}; +use alloc::vec::Vec; +use system_error::SystemError; + +pub type DoneSegment = SegmentCommon; + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct DoneSegmentBody { + error_code: i32, +} + +impl SegmentBody for DoneSegmentBody { + type CType = DoneSegmentBody; +} + +impl DoneSegment { + pub fn new_from_request(request_header: &CMsgSegHdr, error: Option) -> Self { + let header = CMsgSegHdr { + len: 0, + type_: CSegmentType::DONE as _, + flags: SegHdrCommonFlags::empty().bits(), + seq: request_header.seq, + pid: request_header.pid, + }; + + let body = { + let error_code = if let Some(err) = error { + err.to_posix_errno() + } else { + 0 + }; + DoneSegmentBody { error_code } + }; + + Self::new(header, body, Vec::new()) + } +} + +pub type ErrorSegment = SegmentCommon; + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct ErrorSegmentBody { + error_code: i32, + request_header: CMsgSegHdr, +} + +impl SegmentBody for ErrorSegmentBody { + type CType = ErrorSegmentBody; +} + +impl ErrorSegment { + pub fn new_from_request(request_header: &CMsgSegHdr, error: Option) -> Self { + let header = CMsgSegHdr { + len: 0, + type_: CSegmentType::ERROR as _, + flags: SegHdrCommonFlags::empty().bits(), + seq: request_header.seq, + pid: request_header.pid, + }; + + let body = { + let error_code = if let Some(err) = error { + err.to_posix_errno() + } else { + 0 + }; + ErrorSegmentBody { + error_code, + request_header: *request_header, + } + }; + + Self::new(header, body, Vec::new()) + } +} diff --git a/kernel/src/net/socket/netlink/message/segment/common.rs b/kernel/src/net/socket/netlink/message/segment/common.rs index a358e0c98..b095fbfc5 100644 --- a/kernel/src/net/socket/netlink/message/segment/common.rs +++ b/kernel/src/net/socket/netlink/message/segment/common.rs @@ -1,10 +1,9 
@@ -use alloc::vec::Vec; -use system_error::SystemError; - use crate::net::socket::netlink::message::{ attr::Attribute, segment::{header::CMsgSegHdr, SegmentBody}, }; +use alloc::vec::Vec; +use system_error::SystemError; #[derive(Debug)] pub struct SegmentCommon { @@ -47,8 +46,10 @@ impl SegmentCommon { } pub fn read_from_buf(header: CMsgSegHdr, buf: &[u8]) -> Result { - let (body, remain_len) = Body::read_from_buf(&header, buf)?; - let attrs = Attr::read_all_from_buf(buf, buf.len() - remain_len)?; + let (body, remain_len, padded_len) = Body::read_from_buf(&header, buf)?; + + let attrs_buf = &buf[padded_len..]; + let attrs = Attr::read_all_from_buf(attrs_buf, remain_len)?; Ok(Self { header, @@ -64,8 +65,10 @@ impl SegmentCommon { self.body.write_to_buf(buf)?; for attr in self.attrs.iter() { + // let cur_buf = &mut buf[offset..]; attr.write_to_buf(buf)?; } + Ok(()) } diff --git a/kernel/src/net/socket/netlink/message/segment/mod.rs b/kernel/src/net/socket/netlink/message/segment/mod.rs index 803250db5..f4fc4ce7f 100644 --- a/kernel/src/net/socket/netlink/message/segment/mod.rs +++ b/kernel/src/net/socket/netlink/message/segment/mod.rs @@ -1,7 +1,8 @@ -use system_error::SystemError; - use crate::net::socket::netlink::message::{segment::header::CMsgSegHdr, NLMSG_ALIGN}; +use alloc::vec::Vec; +use system_error::SystemError; +pub mod ack; pub mod common; pub mod header; @@ -55,21 +56,61 @@ impl TryFrom for CSegmentType { } pub trait SegmentBody: Sized + Clone + Copy { - type CType; + type CType: Copy + TryInto + From; - fn read_from_buf(header: &CMsgSegHdr, buf: &[u8]) -> Result<(Self, usize), SystemError> + fn read_from_buf(header: &CMsgSegHdr, buf: &[u8]) -> Result<(Self, usize, usize), SystemError> where Self: Sized, { - todo!() + let total_len = (header.len as usize) + .checked_sub(size_of::()) + .ok_or(SystemError::EINVAL)?; + + if buf.len() < total_len { + return Err(SystemError::EINVAL); + } + + let c_type_bytes = &buf[..size_of::()]; + let c_type = unsafe { 
*(c_type_bytes.as_ptr() as *const Self::CType) }; + + let total_len_with_padding = Self::total_len_with_padding(); + + let Ok(body) = c_type.try_into() else { + return Err(SystemError::EINVAL); + }; + + let remaining_len = total_len.saturating_sub(total_len_with_padding); + + Ok((body, remaining_len, total_len_with_padding)) } - fn write_to_buf(&self, buf: &mut [u8]) -> Result<(), SystemError> { - todo!() + fn write_to_buf(&self, buf: &mut Vec) -> Result<(), SystemError> { + let c_type = Self::CType::from(*self); + + let body_bytes = unsafe { + core::slice::from_raw_parts( + &c_type as *const Self::CType as *const u8, + size_of::(), + ) + }; + buf.extend_from_slice(body_bytes); + + // let total_len_with_padding = Self::total_len_with_padding(); + let padding_len = Self::padding_len(); + + if padding_len > 0 { + buf.extend(vec![0u8; padding_len]); + } + + Ok(()) } - fn padding_len() -> usize { + fn total_len_with_padding() -> usize { let payload_len = size_of::(); - payload_len.checked_add(NLMSG_ALIGN - 1).unwrap() & (!(NLMSG_ALIGN - 1) - payload_len) + (payload_len.checked_add(NLMSG_ALIGN - 1).unwrap() & !(NLMSG_ALIGN - 1)) - payload_len + } + + fn padding_len() -> usize { + Self::total_len_with_padding() - size_of::() } } diff --git a/kernel/src/net/socket/netlink/route/kernel/addr.rs b/kernel/src/net/socket/netlink/route/kernel/addr.rs new file mode 100644 index 000000000..974148bba --- /dev/null +++ b/kernel/src/net/socket/netlink/route/kernel/addr.rs @@ -0,0 +1,83 @@ +use crate::{ + driver::net::Iface, + net::{ + socket::{ + netlink::{ + message::segment::{ + header::{CMsgSegHdr, GetRequestFlags, SegHdrCommonFlags}, + CSegmentType, + }, + route::{ + kernel::utils::finish_response, + message::{ + attr::addr::AddrAttr, + segment::{ + addr::{AddrMessageFlags, AddrSegment, AddrSegmentBody, RtScope}, + RouteNlSegment, + }, + }, + }, + }, + AddressFamily, + }, + NET_DEVICES, + }, +}; +use alloc::ffi::CString; +use alloc::sync::Arc; +use alloc::vec::Vec; +use 
core::num::NonZeroU32; +use system_error::SystemError; + +pub(super) fn do_get_addr( + request_segment: &AddrSegment, +) -> Result, SystemError> { + let dump_all = { + let flags = GetRequestFlags::from_bits_truncate(request_segment.header().flags); + flags.contains(GetRequestFlags::DUMP) + }; + + if !dump_all { + log::error!("GetAddr request without DUMP flag is not supported yet"); + return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP); + } + + let mut responce: Vec = NET_DEVICES + .read() + .iter() + .filter_map(|(_, iface)| iface_to_new_addr(request_segment.header(), iface)) + .map(RouteNlSegment::NewAddr) + .collect(); + + finish_response(request_segment.header(), dump_all, &mut responce); + + Ok(responce) +} + +fn iface_to_new_addr(request_header: &CMsgSegHdr, iface: &Arc) -> Option { + let ipv4_addr = iface.common().ipv4_addr()?; + + let header = CMsgSegHdr { + len: 0, + type_: CSegmentType::NEWADDR as _, + flags: SegHdrCommonFlags::empty().bits(), + seq: request_header.seq, + pid: request_header.pid, + }; + + let addr_message = AddrSegmentBody { + family: AddressFamily::INet as _, + prefix_len: iface.common().prefix_len().unwrap(), + flags: AddrMessageFlags::PERMANENT, + scope: RtScope::HOST, + index: NonZeroU32::new(iface.nic_id() as u32), + }; + + let attrs = vec![ + AddrAttr::Address(ipv4_addr.octets()), + AddrAttr::Label(CString::new(iface.iface_name()).unwrap()), + AddrAttr::Local(ipv4_addr.octets()), + ]; + + Some(AddrSegment::new(header, addr_message, attrs)) +} diff --git a/kernel/src/net/socket/netlink/route/kernel/mod.rs b/kernel/src/net/socket/netlink/route/kernel/mod.rs index fbe6bdbc5..687053a22 100644 --- a/kernel/src/net/socket/netlink/route/kernel/mod.rs +++ b/kernel/src/net/socket/netlink/route/kernel/mod.rs @@ -1,9 +1,19 @@ //! # Netlink route kernel module //! 
内核对于 Netlink 路由的处理模块 -use crate::net::socket::netlink::route::message::RouteNlMessage; +use crate::net::socket::netlink::{ + message::{ + segment::{ack::ErrorSegment, CSegmentType}, + ProtocolSegment, + }, + route::message::{segment::RouteNlSegment, RouteNlMessage}, + table::{NetlinkRouteProtocol, SupportedNetlinkProtocol}, +}; use core::marker::PhantomData; +mod addr; +mod utils; + pub(super) struct NetlinkRouteKernelSocket { _private: PhantomData<()>, } @@ -15,7 +25,32 @@ impl NetlinkRouteKernelSocket { } } - pub(super) fn request(&self, request: &RouteNlMessage, dst_port: u32) {} + pub(super) fn request(&self, request: &RouteNlMessage, dst_port: u32) { + for segment in request.segments() { + let header = segment.header(); + + let seg_type = CSegmentType::try_from(header.type_).unwrap(); + let responce = match segment { + RouteNlSegment::GetAddr(request) => addr::do_get_addr(request), + RouteNlSegment::GetRoute(_new_route) => todo!(), + _ => { + log::warn!("Unsupported route request segment type: {:?}", seg_type); + todo!() + } + }; + + let responce = match responce { + Ok(segments) => RouteNlMessage::new(segments), + Err(error) => { + //todo 处理 `NetlinkMessageCommonFlags::ACK` + let err_segment = ErrorSegment::new_from_request(header, Some(error)); + RouteNlMessage::new(vec![RouteNlSegment::Error(err_segment)]) + } + }; + + NetlinkRouteProtocol::unicast(dst_port, responce).unwrap(); + } + } } /// 负责处理 Netlink 路由相关的内核模块 diff --git a/kernel/src/net/socket/netlink/route/kernel/utils.rs b/kernel/src/net/socket/netlink/route/kernel/utils.rs new file mode 100644 index 000000000..837f446a1 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/kernel/utils.rs @@ -0,0 +1,39 @@ +use crate::net::socket::netlink::{ + message::{ + segment::{ + ack::DoneSegment, + header::{CMsgSegHdr, SegHdrCommonFlags}, + }, + ProtocolSegment, + }, + route::message::segment::RouteNlSegment, +}; +use alloc::vec::Vec; + +pub fn finish_response( + request_header: &CMsgSegHdr, + dump_all: 
bool, + response_segments: &mut Vec, +) { + if !dump_all { + assert_eq!(response_segments.len(), 1); + return; + } + + append_done_segment(request_header, response_segments); + add_multi_flag(response_segments); +} + +fn append_done_segment(header: &CMsgSegHdr, response_segments: &mut Vec) { + let done_segment = DoneSegment::new_from_request(header, None); + response_segments.push(RouteNlSegment::Done(done_segment)); +} + +fn add_multi_flag(responce_segment: &mut [RouteNlSegment]) { + for segment in responce_segment.iter_mut() { + let header = segment.header_mut(); + let mut flags = SegHdrCommonFlags::from_bits_truncate(header.flags); + flags |= SegHdrCommonFlags::MULTI; + header.flags = flags.bits(); + } +} diff --git a/kernel/src/net/socket/netlink/route/message/attr/addr.rs b/kernel/src/net/socket/netlink/route/message/attr/addr.rs index 0381d98e7..999f4ef2d 100644 --- a/kernel/src/net/socket/netlink/route/message/attr/addr.rs +++ b/kernel/src/net/socket/netlink/route/message/attr/addr.rs @@ -80,6 +80,7 @@ impl Attribute for AddrAttr { { let payload_len = header.payload_len(); + // TODO: Currently, `IS_NET_BYTEORDER_MASK` and `IS_NESTED_MASK` are ignored. 
let Ok(addr_class) = AddrAttrClass::try_from(header.type_()) else { //todo 或许这里我应该返回偏移值 //reader.skip_some(payload_len); diff --git a/kernel/src/net/socket/netlink/route/message/mod.rs b/kernel/src/net/socket/netlink/route/message/mod.rs index 5f74d88ae..dc801f071 100644 --- a/kernel/src/net/socket/netlink/route/message/mod.rs +++ b/kernel/src/net/socket/netlink/route/message/mod.rs @@ -1,5 +1,5 @@ -mod attr; -mod segment; +pub(super) mod attr; +pub(super) mod segment; use crate::net::socket::netlink::{message::Message, route::message::segment::RouteNlSegment}; diff --git a/kernel/src/net/socket/netlink/route/message/segment/mod.rs b/kernel/src/net/socket/netlink/route/message/segment/mod.rs index dfeeed391..5524f51a2 100644 --- a/kernel/src/net/socket/netlink/route/message/segment/mod.rs +++ b/kernel/src/net/socket/netlink/route/message/segment/mod.rs @@ -3,7 +3,11 @@ pub mod route; use crate::net::socket::netlink::{ message::{ - segment::{header::CMsgSegHdr, CSegmentType}, + segment::{ + ack::{DoneSegment, ErrorSegment}, + header::CMsgSegHdr, + CSegmentType, + }, ProtocolSegment, }, route::message::segment::{addr::AddrSegment, route::RouteSegment}, @@ -17,8 +21,8 @@ pub enum RouteNlSegment { // GetLink(LinkSegment), NewAddr(AddrSegment), GetAddr(AddrSegment), - // Done(DoneSegment), - // Error(ErrorSegment), + Done(DoneSegment), + Error(ErrorSegment), NewRoute(RouteSegment), DelRoute(RouteSegment), GetRoute(RouteSegment), @@ -33,6 +37,8 @@ impl ProtocolSegment for RouteNlSegment { RouteNlSegment::NewAddr(addr_segment) | RouteNlSegment::GetAddr(addr_segment) => { addr_segment.header() } + RouteNlSegment::Done(done_segment) => done_segment.header(), + RouteNlSegment::Error(error_segment) => error_segment.header(), } } @@ -46,6 +52,8 @@ impl ProtocolSegment for RouteNlSegment { RouteNlSegment::NewAddr(addr_segment) | RouteNlSegment::GetAddr(addr_segment) => { addr_segment.header_mut() } + RouteNlSegment::Done(done_segment) => done_segment.header_mut(), + 
RouteNlSegment::Error(error_segment) => error_segment.header_mut(), } } @@ -71,6 +79,7 @@ impl ProtocolSegment for RouteNlSegment { } fn write_to(&self, buf: &mut [u8]) -> Result { + // 这里没有直接写入buf,而是用 Vec 来构建内核缓冲区 let mut kernel_buf: Vec = vec![]; match self { RouteNlSegment::NewAddr(addr_segment) => addr_segment.write_to_buf(&mut kernel_buf)?, diff --git a/kernel/src/net/socket/netlink/route/message/segment/route.rs b/kernel/src/net/socket/netlink/route/message/segment/route.rs index 3904631df..995081329 100644 --- a/kernel/src/net/socket/netlink/route/message/segment/route.rs +++ b/kernel/src/net/socket/netlink/route/message/segment/route.rs @@ -194,3 +194,19 @@ impl TryFrom for RouteSegmentBody { }) } } + +impl From for CRtMsg { + fn from(body: RouteSegmentBody) -> Self { + CRtMsg { + family: body.family as u8, + dst_len: body.dst_len, + src_len: body.src_len, + tos: body.tos, + table: body.table as u8, + protocol: body.protocol as u8, + scope: body.scope as u8, + type_: body.type_ as u8, + flags: body.flags.bits(), + } + } +} From bc35dddccb150721c6bb529d87cdbea7d2533bf6 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Mon, 18 Aug 2025 20:35:30 +0800 Subject: [PATCH 16/36] =?UTF-8?q?feat:=20=E7=A7=BB=E5=8A=A8routing?= =?UTF-8?q?=E7=9A=84=E4=BD=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/mod.rs | 1 - kernel/src/driver/net/veth.rs | 2 +- kernel/src/net/mod.rs | 1 + .../net/route_iface.rs => net/routing.rs} | 19 ++++++++++++++----- 4 files changed, 16 insertions(+), 7 deletions(-) rename kernel/src/{driver/net/route_iface.rs => net/routing.rs} (95%) diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index ffa35eb2c..53d38855d 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -18,7 +18,6 @@ pub mod e1000e; pub mod irq_handle; pub mod kthread; pub mod loopback; -pub mod route_iface; pub mod sysfs; pub mod 
veth; pub mod virtio_net; diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index eb7ef0f60..a04693c6d 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -11,12 +11,12 @@ use crate::driver::base::kobject::{ }; use crate::driver::base::kset::KSet; use crate::driver::net::bridge::BridgePort; -use crate::driver::net::route_iface::{RouterEnableDevice, RouterEnableDeviceCommon}; use crate::filesystem::kernfs::KernFSInode; use crate::init::initcall::INITCALL_DEVICE; use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; use crate::libs::spinlock::{SpinLock, SpinLockGuard}; use crate::libs::wait_queue::WaitQueue; +use crate::net::routing::{RouterEnableDevice, RouterEnableDeviceCommon}; use crate::net::{generate_iface_id, NET_DEVICES}; use crate::process::ProcessState; use crate::sched::SchedMode; diff --git a/kernel/src/net/mod.rs b/kernel/src/net/mod.rs index 805761cc6..7150cecbe 100644 --- a/kernel/src/net/mod.rs +++ b/kernel/src/net/mod.rs @@ -10,6 +10,7 @@ use crate::{driver::net::Iface, libs::rwlock::RwLock}; pub mod net_core; pub mod posix; +pub mod routing; pub mod socket; pub mod syscall; diff --git a/kernel/src/driver/net/route_iface.rs b/kernel/src/net/routing.rs similarity index 95% rename from kernel/src/driver/net/route_iface.rs rename to kernel/src/net/routing.rs index 6f9929ccd..9e91a07d8 100644 --- a/kernel/src/driver/net/route_iface.rs +++ b/kernel/src/net/routing.rs @@ -67,6 +67,11 @@ impl RouteEntry { } } +#[derive(Debug, Default)] +pub struct RouteTable { + pub entries: Vec, +} + /// 路由决策结果 #[derive(Debug)] pub struct RouteDecision { @@ -80,31 +85,33 @@ pub struct RouteDecision { pub struct Router { name: String, /// 路由表 //todo 后面再优化LC-trie,现在先简单用一个Vec,并且应该在这上面加锁(maybe rwlock?) 
and 指针反而可以不加锁,在这个路由表这里加就行 - route_table: RwLock>, + route_table: RwLock, } impl Router { pub fn new(name: String) -> Self { Self { name, - route_table: RwLock::new(Vec::new()), + route_table: RwLock::new(RouteTable::default()), } } pub fn add_route(&mut self, route: RouteEntry) { let mut guard = self.route_table.write(); - let pos = guard + let entries = &mut guard.entries; + let pos = entries .iter() .position(|r| r.metric > route.metric) - .unwrap_or(guard.len()); + .unwrap_or(entries.len()); - guard.insert(pos, route); + entries.insert(pos, route); log::info!("Router {}: Added route to routing table", self.name); } pub fn remove_route(&mut self, destination: IpCidr) { self.route_table .write() + .entries .retain(|route| route.destination != destination); } @@ -112,6 +119,7 @@ impl Router { let guard = self.route_table.read(); // 按最长前缀匹配原则查找路由 let best = guard + .entries .iter() .filter(|route| { route.interface.strong_count() > 0 && route.destination.contains_addr(&dest_ip) @@ -135,6 +143,7 @@ impl Router { pub fn cleanup_routes(&mut self) { self.route_table .write() + .entries .retain(|route| route.interface.strong_count() > 0); } } From d766c8b2082172d1d4e2141cb07402d25fe4c8b8 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Tue, 19 Aug 2025 14:28:30 +0800 Subject: [PATCH 17/36] =?UTF-8?q?feat:=20=E8=A1=A5=E5=85=85netlink?= =?UTF-8?q?=E7=9A=84=E9=98=BB=E5=A1=9E=E7=AD=89=E5=BE=85=E9=80=BB=E8=BE=91?= =?UTF-8?q?&&fmt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/loopback.rs | 26 +++++++------- kernel/src/driver/net/mod.rs | 2 +- kernel/src/driver/net/veth.rs | 6 ++-- kernel/src/net/socket/inet/common/port.rs | 10 +++--- kernel/src/net/socket/netlink/common/bound.rs | 27 ++++++++------- kernel/src/net/socket/netlink/common/mod.rs | 34 ++++++++++++++++--- .../src/net/socket/netlink/common/unbound.rs | 5 +++ kernel/src/net/socket/netlink/route/bound.rs | 23 ++++++++----- 
.../src/net/socket/utils/datagram_common.rs | 12 +++++++ 9 files changed, 100 insertions(+), 45 deletions(-) diff --git a/kernel/src/driver/net/loopback.rs b/kernel/src/driver/net/loopback.rs index 69358c454..ba9167a74 100644 --- a/kernel/src/driver/net/loopback.rs +++ b/kernel/src/driver/net/loopback.rs @@ -88,6 +88,7 @@ impl phy::TxToken for LoopbackTxToken { } } +#[derive(Default)] /// ## Loopback设备 /// 成员是一个队列,用来存放接受到的数据包。 /// 当使用lo发送数据包时,不会把数据包传到link层,而是直接发送到该队列,实现环回。 @@ -97,12 +98,6 @@ pub struct Loopback { } impl Loopback { - /// ## Loopback创建函数 - /// 创建lo设备 - pub fn new() -> Self { - let queue = VecDeque::new(); - Loopback { queue } - } /// ## Loopback处理接受到的数据包函数 /// Loopback接受到数据后会调用这个函数来弹出接收的数据,返回给协议栈 /// @@ -172,10 +167,10 @@ pub struct LoopbackDriver { pub inner: Arc>, } -impl LoopbackDriver { +impl Default for LoopbackDriver { /// ## LoopbackDriver创建函数 - pub fn new() -> Self { - let inner = Arc::new(SpinLock::new(Loopback::new())); + fn default() -> Self { + let inner = Arc::new(SpinLock::new(Loopback::default())); LoopbackDriver { inner } } } @@ -515,16 +510,21 @@ impl Iface for LoopbackInterface { } } +pub fn generate_loopback_iface_default() -> Arc { + let iface = LoopbackInterface::new(LoopbackDriver::default()); + // 标识网络设备已经启动 + iface.set_net_state(NetDeivceState::__LINK_STATE_START); + iface +} + pub fn loopback_probe() { loopback_driver_init(); } + /// # lo网卡设备初始化函数 /// 创建驱动和iface,初始化一个lo网卡,添加到全局NET_DEVICES中 pub fn loopback_driver_init() { - let driver = LoopbackDriver::new(); - let iface = LoopbackInterface::new(driver); - // 标识网络设备已经启动 - iface.set_net_state(NetDeivceState::__LINK_STATE_START); + let iface = generate_loopback_iface_default(); NET_DEVICES .write_irqsave() diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index 53d38855d..de8c3520e 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -196,7 +196,7 @@ impl IfaceCommon { smol_iface: SpinLock::new(iface), sockets: 
SpinLock::new(smoltcp::iface::SocketSet::new(Vec::new())), bounds: RwLock::new(Vec::new()), - port_manager: PortManager::new(), + port_manager: PortManager::default(), poll_at_ms: core::sync::atomic::AtomicU64::new(0), default_iface, } diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index a04693c6d..5ff684d27 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -328,7 +328,7 @@ impl VethInterface { (iface1, iface2) } - fn inner(&self) -> SpinLockGuard { + fn inner(&self) -> SpinLockGuard<'_, VethCommonData> { self.inner.lock_irqsave() } @@ -432,11 +432,11 @@ impl KObject for VethInterface { } fn set_name(&self, _name: String) {} - fn kobj_state(&self) -> RwLockReadGuard { + fn kobj_state(&self) -> RwLockReadGuard<'_, KObjectState> { self.locked_kobj_state.read() } - fn kobj_state_mut(&self) -> RwLockWriteGuard { + fn kobj_state_mut(&self) -> RwLockWriteGuard<'_, KObjectState> { self.locked_kobj_state.write() } diff --git a/kernel/src/net/socket/inet/common/port.rs b/kernel/src/net/socket/inet/common/port.rs index e54e2a097..fb7e0ea97 100644 --- a/kernel/src/net/socket/inet/common/port.rs +++ b/kernel/src/net/socket/inet/common/port.rs @@ -19,14 +19,16 @@ pub struct PortManager { udp_port_table: SpinLock>, } -impl PortManager { - pub fn new() -> Self { - return Self { +impl Default for PortManager { + fn default() -> Self { + Self { tcp_port_table: SpinLock::new(HashMap::new()), udp_port_table: SpinLock::new(HashMap::new()), - }; + } } +} +impl PortManager { /// @brief 自动分配一个相对应协议中未被使用的PORT,如果动态端口均已被占用,返回错误码 EADDRINUSE pub fn get_ephemeral_port(&self, socket_type: Types) -> Result { // TODO: selects non-conflict high port diff --git a/kernel/src/net/socket/netlink/common/bound.rs b/kernel/src/net/socket/netlink/common/bound.rs index 0c86e6d37..4255e2e64 100644 --- a/kernel/src/net/socket/netlink/common/bound.rs +++ b/kernel/src/net/socket/netlink/common/bound.rs @@ -1,7 +1,10 @@ -use 
crate::net::socket::netlink::{ - addr::{multicast::GroupIdSet, NetlinkSocketAddr}, - receiver::MessageQueue, - table::BoundHandle, +use crate::{ + filesystem::epoll::EPollEventType, + net::socket::netlink::{ + addr::{multicast::GroupIdSet, NetlinkSocketAddr}, + receiver::MessageQueue, + table::BoundHandle, + }, }; use alloc::fmt::Debug; use system_error::SystemError; @@ -32,16 +35,16 @@ impl BoundNetlink { Ok(()) } - // pub fn check_io_events_common(&self) -> EPollEventType { - // let mut events = EPollEventType::EPOLLOUT; + pub fn check_io_events_common(&self) -> EPollEventType { + let mut events = EPollEventType::EPOLLOUT; - // let receive_queue = self.receive_queue.0.lock(); - // if !receive_queue.is_empty() { - // events |= EPollEventType::EPOLLIN; - // } + let receive_queue = self.receive_queue.0.lock(); + if !receive_queue.is_empty() { + events |= EPollEventType::EPOLLIN; + } - // events - // } + events + } pub(super) fn add_groups(&mut self, groups: GroupIdSet) { self.handle.add_groups(groups); diff --git a/kernel/src/net/socket/netlink/common/mod.rs b/kernel/src/net/socket/netlink/common/mod.rs index 6ae6d2111..519020754 100644 --- a/kernel/src/net/socket/netlink/common/mod.rs +++ b/kernel/src/net/socket/netlink/common/mod.rs @@ -1,4 +1,5 @@ use crate::{ + filesystem::epoll::EPollEventType, libs::{rwlock::RwLock, wait_queue::WaitQueue}, net::socket::{ endpoint::Endpoint, @@ -8,7 +9,7 @@ use crate::{ table::SupportedNetlinkProtocol, }, utils::datagram_common::{select_remote_and_bind, Bound, Inner}, - Socket, + Socket, PMSG, }, }; use alloc::sync::Arc; @@ -75,6 +76,15 @@ where Ok((recv_bytes, endpoint)) } + + /// 判断当前的netlink是否可以接收数据 + /// 目前netlink只是负责接收内核消息,所以暂时不用判断是否可以发送数据 + pub fn can_recv(&self) -> bool { + self.inner + .read() + .check_io_events() + .contains(EPollEventType::EPOLLIN) + } } impl Socket for NetlinkSocket

@@ -122,13 +132,29 @@ where flags: crate::net::socket::PMSG, address: Option, ) -> Result<(usize, crate::net::socket::endpoint::Endpoint), system_error::SystemError> { - //todo 处理一下阻塞的逻辑 + use crate::sched::SchedMode; + + if let Some(addr) = address { + self.connect(addr)?; + } - self.try_recv(buffer, flags) + return if self.is_nonblocking() || flags.contains(PMSG::DONTWAIT) { + self.try_recv(buffer, flags) + } else { + loop { + match self.try_recv(buffer, flags) { + Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { + let _ = wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {}); + } + result => break result, + } + } + }; + // self.try_recv(buffer, flags) } fn poll(&self) -> usize { - todo!() + self.inner.read().check_io_events().bits() as usize } fn send_buffer_size(&self) -> usize { diff --git a/kernel/src/net/socket/netlink/common/unbound.rs b/kernel/src/net/socket/netlink/common/unbound.rs index 80f975657..1ff7e2bcb 100644 --- a/kernel/src/net/socket/netlink/common/unbound.rs +++ b/kernel/src/net/socket/netlink/common/unbound.rs @@ -1,4 +1,5 @@ use crate::{ + filesystem::epoll::EPollEventType, libs::wait_queue::WaitQueue, net::socket::{ netlink::{ @@ -83,4 +84,8 @@ impl datagram_common::Unbound for UnboundNetlink

Ok(BoundNetlink::new(bound_handle, message_queue)) } + + fn check_io_events(&self) -> EPollEventType { + EPollEventType::EPOLLOUT + } } diff --git a/kernel/src/net/socket/netlink/route/bound.rs b/kernel/src/net/socket/netlink/route/bound.rs index a5c220b5d..825ccd784 100644 --- a/kernel/src/net/socket/netlink/route/bound.rs +++ b/kernel/src/net/socket/netlink/route/bound.rs @@ -1,12 +1,15 @@ -use crate::net::socket::{ - netlink::{ - addr::NetlinkSocketAddr, - common::bound::BoundNetlink, - message::ProtocolSegment, - route::{kernel::netlink_route_kernel, message::RouteNlMessage}, +use crate::{ + filesystem::epoll::EPollEventType, + net::socket::{ + netlink::{ + addr::NetlinkSocketAddr, + common::bound::BoundNetlink, + message::ProtocolSegment, + route::{kernel::netlink_route_kernel, message::RouteNlMessage}, + }, + utils::datagram_common, + PMSG, }, - utils::datagram_common, - PMSG, }; use system_error::SystemError; @@ -102,4 +105,8 @@ impl datagram_common::Bound for BoundNetlink { Ok((len, remote)) } + + fn check_io_events(&self) -> EPollEventType { + self.check_io_events_common() + } } diff --git a/kernel/src/net/socket/utils/datagram_common.rs b/kernel/src/net/socket/utils/datagram_common.rs index a86b73c87..1b4eb4384 100644 --- a/kernel/src/net/socket/utils/datagram_common.rs +++ b/kernel/src/net/socket/utils/datagram_common.rs @@ -1,3 +1,4 @@ +use crate::filesystem::epoll::EPollEventType; use crate::{ libs::{rwlock::RwLock, wait_queue::WaitQueue}, net::socket::PMSG, @@ -21,6 +22,8 @@ pub trait Unbound { endpoint: &Self::Endpoint, wait_queue: Arc, ) -> Result; + + fn check_io_events(&self) -> EPollEventType; } pub trait Bound { @@ -43,6 +46,8 @@ pub trait Bound { ) -> Result<(usize, Self::Endpoint), SystemError>; fn try_send(&self, buf: &[u8], to: &Self::Endpoint, flags: PMSG) -> Result; + + fn check_io_events(&self) -> EPollEventType; } #[derive(Debug)] @@ -108,6 +113,13 @@ where Ok(()) } + pub fn check_io_events(&self) -> EPollEventType { + match self { + 
Inner::Unbound(unbound_datagram) => unbound_datagram.check_io_events(), + Inner::Bound(bound_datagram) => bound_datagram.check_io_events(), + } + } + pub fn addr(&self) -> Option { match self { Inner::Unbound(_) => None, From 0add58685b8f30429c79dab13afe0c230c4aef79 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Tue, 19 Aug 2025 22:18:47 +0800 Subject: [PATCH 18/36] =?UTF-8?q?feat(netns):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E7=BD=91=E7=BB=9C=E5=91=BD=E5=90=8D=E7=A9=BA=E9=97=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/bridge.rs | 7 +- kernel/src/driver/net/e1000e/e1000e_driver.rs | 16 +- kernel/src/driver/net/irq_handle.rs | 14 +- kernel/src/driver/net/kthread.rs | 47 ---- kernel/src/driver/net/loopback.rs | 14 +- kernel/src/driver/net/mod.rs | 23 +- kernel/src/driver/net/veth.rs | 7 +- kernel/src/driver/net/virtio_net.rs | 38 ++- kernel/src/net/net_core.rs | 5 +- kernel/src/net/socket/inet/common/mod.rs | 75 ++++-- kernel/src/net/socket/inet/datagram/inner.rs | 19 +- kernel/src/net/socket/inet/datagram/mod.rs | 22 +- kernel/src/net/socket/inet/stream/inner.rs | 13 +- kernel/src/net/socket/inet/stream/mod.rs | 31 ++- .../src/net/socket/utils/datagram_common.rs | 2 + kernel/src/process/namespace/mod.rs | 1 + kernel/src/process/namespace/net_namespace.rs | 238 ++++++++++++++++++ kernel/src/process/namespace/nsproxy.rs | 16 +- 18 files changed, 461 insertions(+), 127 deletions(-) delete mode 100644 kernel/src/driver/net/kthread.rs create mode 100644 kernel/src/process/namespace/net_namespace.rs diff --git a/kernel/src/driver/net/bridge.rs b/kernel/src/driver/net/bridge.rs index f905aff54..dcbd5c8a3 100644 --- a/kernel/src/driver/net/bridge.rs +++ b/kernel/src/driver/net/bridge.rs @@ -2,9 +2,9 @@ use crate::{ driver::net::{register_netdevice, veth::VethInterface, Iface, NetDeivceState, Operstate}, init::initcall::INITCALL_DEVICE, libs::{rwlock::RwLock, spinlock::SpinLock, 
wait_queue::WaitQueue}, - net::NET_DEVICES, process::{ kthread::{KernelThreadClosure, KernelThreadMechanism}, + namespace::net_namespace::INIT_NET_NAMESPACE, ProcessState, }, time::Instant, @@ -387,7 +387,10 @@ fn bridge_probe() { let turn_on = |a: &Arc| { a.set_net_state(NetDeivceState::__LINK_STATE_START); a.set_operstate(Operstate::IF_OPER_UP); - NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); + // NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); + INIT_NET_NAMESPACE.add_device(a.clone()); + a.common().set_net_namespace(INIT_NET_NAMESPACE.clone()); + register_netdevice(a.clone()).expect("register veth device failed"); }; diff --git a/kernel/src/driver/net/e1000e/e1000e_driver.rs b/kernel/src/driver/net/e1000e/e1000e_driver.rs index 98275a6c2..08211d5d4 100644 --- a/kernel/src/driver/net/e1000e/e1000e_driver.rs +++ b/kernel/src/driver/net/e1000e/e1000e_driver.rs @@ -16,7 +16,8 @@ use crate::{ rwlock::{RwLockReadGuard, RwLockWriteGuard}, spinlock::{SpinLock, SpinLockGuard}, }, - net::{generate_iface_id, NET_DEVICES}, + net::generate_iface_id, + process::namespace::net_namespace::INIT_NET_NAMESPACE, time::Instant, }; use alloc::{ @@ -200,7 +201,7 @@ impl E1000EInterface { let iface = smoltcp::iface::Interface::new(iface_config, &mut driver, Instant::now().into()); - let result = Arc::new(E1000EInterface { + let iface = Arc::new(E1000EInterface { driver: E1000EDriverWrapper(UnsafeCell::new(driver)), common: IfaceCommon::new(iface_id, false, iface), name: format!("eth{}", iface_id), @@ -212,7 +213,7 @@ impl E1000EInterface { locked_kobj_state: LockedKObjectState::default(), }); - return result; + iface } pub fn inner(&self) -> SpinLockGuard<'_, InnerE1000EInterface> { @@ -400,9 +401,12 @@ pub fn e1000e_driver_init(device: E1000EDevice) { iface.set_net_state(NetDeivceState::__LINK_STATE_START); // 将网卡的接口信息注册到全局的网卡接口信息表中 - NET_DEVICES - .write_irqsave() - .insert(iface.nic_id(), iface.clone()); + // NET_DEVICES + // .write_irqsave() + // 
.insert(iface.nic_id(), iface.clone()); + INIT_NET_NAMESPACE.add_device(iface.clone()); + iface.common.set_net_namespace(INIT_NET_NAMESPACE.clone()); + info!("e1000e driver init successfully!\tMAC: [{}]", mac); register_netdevice(iface.clone()).expect("register lo device failed"); diff --git a/kernel/src/driver/net/irq_handle.rs b/kernel/src/driver/net/irq_handle.rs index 5a6cd0db0..3fd223f12 100644 --- a/kernel/src/driver/net/irq_handle.rs +++ b/kernel/src/driver/net/irq_handle.rs @@ -1,10 +1,13 @@ use alloc::sync::Arc; use system_error::SystemError; -use crate::exception::{ - irqdata::IrqHandlerData, - irqdesc::{IrqHandler, IrqReturn}, - IrqNumber, +use crate::{ + exception::{ + irqdata::IrqHandlerData, + irqdesc::{IrqHandler, IrqReturn}, + IrqNumber, + }, + process::namespace::net_namespace::INIT_NET_NAMESPACE, }; /// 默认的网卡中断处理函数 @@ -18,7 +21,8 @@ impl IrqHandler for DefaultNetIrqHandler { _static_data: Option<&dyn IrqHandlerData>, _dynamic_data: Option>, ) -> Result { - super::kthread::wakeup_poll_thread(); + // 这里先暂时唤醒 INIT 网络命名空间的轮询线程 + INIT_NET_NAMESPACE.wakeup_poll_thread(); Ok(IrqReturn::Handled) } } diff --git a/kernel/src/driver/net/kthread.rs b/kernel/src/driver/net/kthread.rs deleted file mode 100644 index 3c634499b..000000000 --- a/kernel/src/driver/net/kthread.rs +++ /dev/null @@ -1,47 +0,0 @@ -use alloc::borrow::ToOwned; -use alloc::sync::Arc; -use unified_init::macros::unified_init; - -use crate::arch::CurrentIrqArch; -use crate::exception::InterruptArch; -use crate::init::initcall::INITCALL_SUBSYS; -use crate::net::NET_DEVICES; -use crate::process::kthread::{KernelThreadClosure, KernelThreadMechanism}; -use crate::process::{ProcessControlBlock, ProcessManager}; -use crate::sched::{schedule, SchedMode}; - -static mut NET_POLL_THREAD: Option> = None; - -#[unified_init(INITCALL_SUBSYS)] -pub fn net_poll_init() -> Result<(), system_error::SystemError> { - let closure = KernelThreadClosure::StaticEmptyClosure((&(net_poll_thread as fn() -> i32), ())); - 
let pcb = KernelThreadMechanism::create_and_run(closure, "net_poll".to_owned()) - .ok_or("") - .expect("create net_poll thread failed"); - log::info!("net_poll thread created"); - unsafe { - NET_POLL_THREAD = Some(pcb); - } - return Ok(()); -} - -fn net_poll_thread() -> i32 { - log::info!("net_poll thread started"); - loop { - for (_, iface) in NET_DEVICES.read_irqsave().iter() { - iface.poll(); - } - let irq_guard = unsafe { CurrentIrqArch::save_and_disable_irq() }; - ProcessManager::mark_sleep(true).expect("clocksource_watchdog_kthread:mark sleep failed"); - drop(irq_guard); - schedule(SchedMode::SM_NONE); - } -} - -/// 拉起线程 -pub(super) fn wakeup_poll_thread() { - if unsafe { NET_POLL_THREAD.is_none() } { - return; - } - let _ = ProcessManager::wakeup(unsafe { NET_POLL_THREAD.as_ref().unwrap() }); -} diff --git a/kernel/src/driver/net/loopback.rs b/kernel/src/driver/net/loopback.rs index ba9167a74..aafb9947e 100644 --- a/kernel/src/driver/net/loopback.rs +++ b/kernel/src/driver/net/loopback.rs @@ -11,7 +11,8 @@ use crate::filesystem::kernfs::KernFSInode; use crate::init::initcall::INITCALL_DEVICE; use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; use crate::libs::spinlock::{SpinLock, SpinLockGuard}; -use crate::net::{generate_iface_id, NET_DEVICES}; +use crate::net::generate_iface_id; +use crate::process::namespace::net_namespace::INIT_NET_NAMESPACE; use crate::time::Instant; use alloc::collections::VecDeque; use alloc::fmt::Debug; @@ -514,6 +515,9 @@ pub fn generate_loopback_iface_default() -> Arc { let iface = LoopbackInterface::new(LoopbackDriver::default()); // 标识网络设备已经启动 iface.set_net_state(NetDeivceState::__LINK_STATE_START); + + register_netdevice(iface.clone()).expect("register lo device failed"); + iface } @@ -526,11 +530,9 @@ pub fn loopback_probe() { pub fn loopback_driver_init() { let iface = generate_loopback_iface_default(); - NET_DEVICES - .write_irqsave() - .insert(iface.nic_id(), iface.clone()); - - 
register_netdevice(iface.clone()).expect("register lo device failed"); + INIT_NET_NAMESPACE.add_device(iface.clone()); + INIT_NET_NAMESPACE.set_loopback_iface(iface.clone()); + iface.common.set_net_namespace(INIT_NET_NAMESPACE.clone()); } /// ## lo网卡设备的注册函数 diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index de8c3520e..2a9abc6e6 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -1,8 +1,10 @@ +use alloc::sync::Weak; use alloc::{fmt, vec::Vec}; use alloc::{string::String, sync::Arc}; use core::net::Ipv4Addr; use sysfs::netdev_register_kobject; +use crate::process::namespace::net_namespace::NetNamespace; use crate::{ libs::{rwlock::RwLock, spinlock::SpinLock}, net::socket::inet::{common::PortManager, InetSocket}, @@ -16,7 +18,6 @@ pub mod class; mod dma; pub mod e1000e; pub mod irq_handle; -pub mod kthread; pub mod loopback; pub mod sysfs; pub mod veth; @@ -126,6 +127,14 @@ pub trait Iface: crate::driver::base::device::Device { fn operstate(&self) -> Operstate; fn set_operstate(&self, state: Operstate); + + fn net_namespace(&self) -> Option> { + self.common().net_namespace() + } + + fn set_net_namespace(&self, ns: Arc) { + self.common().set_net_namespace(ns); + } } /// 网络设备的公共数据 @@ -178,6 +187,8 @@ pub struct IfaceCommon { /// 默认网卡标识 /// TODO: 此字段设置目的是解决对bind unspecified地址的分包问题,需要在inet实现多网卡监听或路由子系统实现后移除 default_iface: bool, + /// 网络命名空间 + net_namespace: RwLock>, } impl fmt::Debug for IfaceCommon { @@ -199,6 +210,7 @@ impl IfaceCommon { port_manager: PortManager::default(), poll_at_ms: core::sync::atomic::AtomicU64::new(0), default_iface, + net_namespace: RwLock::new(Weak::new()), } } @@ -317,4 +329,13 @@ impl IfaceCommon { .first() .map(|ip_addr| ip_addr.prefix_len()) } + + pub fn net_namespace(&self) -> Option> { + self.net_namespace.read().upgrade() + } + + pub fn set_net_namespace(&self, ns: Arc) { + let mut guard = self.net_namespace.write(); + *guard = Arc::downgrade(&ns); + } } diff --git 
a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index 5ff684d27..cb4c8bf07 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -16,8 +16,9 @@ use crate::init::initcall::INITCALL_DEVICE; use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; use crate::libs::spinlock::{SpinLock, SpinLockGuard}; use crate::libs::wait_queue::WaitQueue; +use crate::net::generate_iface_id; use crate::net::routing::{RouterEnableDevice, RouterEnableDeviceCommon}; -use crate::net::{generate_iface_id, NET_DEVICES}; +use crate::process::namespace::net_namespace::INIT_NET_NAMESPACE; use crate::process::ProcessState; use crate::sched::SchedMode; use alloc::collections::VecDeque; @@ -677,7 +678,9 @@ pub fn veth_probe(name1: &str, name2: &str) -> (Arc, Arc| { a.set_net_state(NetDeivceState::__LINK_STATE_START); a.set_operstate(Operstate::IF_OPER_UP); - NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); + // NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); + INIT_NET_NAMESPACE.add_device(a.clone()); + a.common().set_net_namespace(INIT_NET_NAMESPACE.clone()); register_netdevice(a.clone()).expect("register veth device failed"); }; diff --git a/kernel/src/driver/net/virtio_net.rs b/kernel/src/driver/net/virtio_net.rs index 07ccce7d4..bd5c4b063 100644 --- a/kernel/src/driver/net/virtio_net.rs +++ b/kernel/src/driver/net/virtio_net.rs @@ -46,7 +46,8 @@ use crate::{ rwlock::{RwLockReadGuard, RwLockWriteGuard}, spinlock::{SpinLock, SpinLockGuard}, }, - net::{generate_iface_id, NET_DEVICES}, + net::generate_iface_id, + process::namespace::net_namespace::{NetNamespace, INIT_NET_NAMESPACE}, time::Instant, }; use system_error::SystemError; @@ -68,6 +69,9 @@ pub struct VirtIONetDevice { dev_id: Arc, inner: SpinLock, locked_kobj_state: LockedKObjectState, + + // 这里放netns是为了在中断到来的时候可以遍历poll当前命名空间下的网卡 + netns: Arc, } impl Debug for VirtIONetDevice { @@ -96,7 +100,11 @@ impl Debug for InnerVirtIONetDevice { } impl VirtIONetDevice { - pub fn 
new(transport: VirtIOTransport, dev_id: Arc) -> Option> { + pub fn new( + transport: VirtIOTransport, + dev_id: Arc, + netns: Arc, + ) -> Option> { // 设置中断 if let Err(err) = transport.setup_irq(dev_id.clone()) { error!("VirtIONetDevice '{dev_id:?}' setup_irq failed: {:?}", err); @@ -125,6 +133,7 @@ impl VirtIONetDevice { device_common: DeviceCommonData::default(), }), locked_kobj_state: LockedKObjectState::default(), + netns, }); // dev.set_driver(Some(Arc::downgrade(&virtio_net_driver()) as Weak)); @@ -269,8 +278,7 @@ impl Device for VirtIONetDevice { impl VirtIODevice for VirtIONetDevice { fn handle_irq(&self, _irq: IrqNumber) -> Result { - // log::debug!("try to wakeup"); - super::kthread::wakeup_poll_thread(); + self.netns.wakeup_poll_thread(); return Ok(IrqReturn::Handled); } @@ -407,7 +415,7 @@ impl VirtioInterface { let iface = iface::Interface::new(iface_config, &mut device_inner, Instant::now().into()); - let result = Arc::new(VirtioInterface { + let iface = Arc::new(VirtioInterface { device_inner: VirtIONicDeviceInnerWrapper(UnsafeCell::new(device_inner)), locked_kobj_state: LockedKObjectState::default(), iface_name: format!("eth{}", iface_id), @@ -419,7 +427,7 @@ impl VirtioInterface { }), }); - return result; + iface } fn inner(&self) -> SpinLockGuard<'_, InnerVirtIOInterface> { @@ -436,7 +444,10 @@ impl VirtioInterface { impl Drop for VirtioInterface { fn drop(&mut self) { // 从全局的网卡接口信息表中删除这个网卡的接口信息 - NET_DEVICES.write_irqsave().remove(&self.nic_id()); + // NET_DEVICES.write_irqsave().remove(&self.nic_id()); + if let Some(ns) = self.net_namespace() { + ns.remove_device(&self.nic_id()); + } } } @@ -623,7 +634,7 @@ pub fn virtio_net( dev_id: Arc, dev_parent: Option>, ) { - let virtio_net_deivce = VirtIONetDevice::new(transport, dev_id); + let virtio_net_deivce = VirtIONetDevice::new(transport, dev_id, INIT_NET_NAMESPACE.clone()); if let Some(virtio_net_deivce) = virtio_net_deivce { debug!("VirtIONetDevice '{:?}' created", virtio_net_deivce.dev_id); if 
let Some(dev_parent) = dev_parent { @@ -820,9 +831,14 @@ impl VirtIODriver for VirtIONetDriver { register_netdevice(iface.clone() as Arc)?; // 将网卡的接口信息注册到全局的网卡接口信息表中 - NET_DEVICES - .write_irqsave() - .insert(iface.nic_id(), iface.clone()); + // NET_DEVICES + // .write_irqsave() + // .insert(iface.nic_id(), iface.clone()); + INIT_NET_NAMESPACE.add_device(iface.clone()); + iface + .iface_common + .set_net_namespace(INIT_NET_NAMESPACE.clone()); + INIT_NET_NAMESPACE.set_default_iface(iface.clone()); virtio_irq_manager() .register_device(device.clone()) diff --git a/kernel/src/net/net_core.rs b/kernel/src/net/net_core.rs index f4603b6d5..5a939b2eb 100644 --- a/kernel/src/net/net_core.rs +++ b/kernel/src/net/net_core.rs @@ -3,7 +3,7 @@ use system_error::SystemError; use crate::{ driver::net::Operstate, - net::NET_DEVICES, + process::namespace::net_namespace::INIT_NET_NAMESPACE, time::{sleep::nanosleep, PosixTimeSpec}, }; @@ -12,7 +12,8 @@ pub fn net_init() -> Result<(), SystemError> { } fn dhcp_query() -> Result<(), SystemError> { - let binding = NET_DEVICES.write_irqsave(); + // let binding = NET_DEVICES.write_irqsave(); + let binding = INIT_NET_NAMESPACE.device_list_write(); let net_face = binding .iter() diff --git a/kernel/src/net/socket/inet/common/mod.rs b/kernel/src/net/socket/inet/common/mod.rs index 8b60c0718..41e846c0c 100644 --- a/kernel/src/net/socket/inet/common/mod.rs +++ b/kernel/src/net/socket/inet/common/mod.rs @@ -1,4 +1,4 @@ -use crate::net::{Iface, NET_DEVICES}; +use crate::{net::Iface, process::namespace::net_namespace::NetNamespace}; use alloc::sync::Arc; pub mod port; @@ -24,6 +24,7 @@ pub enum Types { pub struct BoundInner { handle: smoltcp::iface::SocketHandle, iface: Arc, + netns: Arc, // inner: Vec<(smoltcp::iface::SocketHandle, Arc)> // address: smoltcp::wire::IpAddress, } @@ -35,30 +36,42 @@ impl BoundInner { socket: T, // socket_type: Types, address: &smoltcp::wire::IpAddress, + netns: Arc, ) -> Result where T: 
smoltcp::socket::AnySocket<'static>, { if address.is_unspecified() { + let Some(iface) = netns.default_iface() else { + return Err(SystemError::ENODEV); + }; // 强绑VirtualIO - let iface = NET_DEVICES - .read_irqsave() - .iter() - .find_map(|(_, v)| { - if v.common().is_default_iface() { - Some(v.clone()) - } else { - None - } - }) - .expect("No default interface"); + // let iface = NET_DEVICES + // .read_irqsave() + // .iter() + // .find_map(|(_, v)| { + // if v.common().is_default_iface() { + // Some(v.clone()) + // } else { + // None + // } + // }) + // .expect("No default interface"); let handle = iface.sockets().lock().add(socket); - return Ok(Self { handle, iface }); + return Ok(Self { + handle, + iface, + netns, + }); } else { - let iface = get_iface_to_bind(address).ok_or(SystemError::ENODEV)?; + let iface = get_iface_to_bind(address, netns.clone()).ok_or(SystemError::ENODEV)?; let handle = iface.sockets().lock().add(socket); - return Ok(Self { handle, iface }); + return Ok(Self { + handle, + iface, + netns, + }); } } @@ -66,15 +79,23 @@ impl BoundInner { socket: T, // socket_type: Types, remote: smoltcp::wire::IpAddress, + netns: Arc, ) -> Result<(Self, smoltcp::wire::IpAddress), SystemError> where T: smoltcp::socket::AnySocket<'static>, { - let (iface, address) = get_ephemeral_iface(&remote); + let (iface, address) = get_ephemeral_iface(&remote, netns.clone()); // let bound_port = iface.port_manager().bind_ephemeral_port(socket_type)?; let handle = iface.sockets().lock().add(socket); // let endpoint = smoltcp::wire::IpEndpoint::new(local_addr, bound_port); - Ok((Self { handle, iface }, address)) + Ok(( + Self { + handle, + iface, + netns, + }, + address, + )) } pub fn port_manager(&self) -> &PortManager { @@ -99,14 +120,21 @@ impl BoundInner { pub fn release(&self) { self.iface.sockets().lock().remove(self.handle); } + + pub fn netns(&self) -> Arc { + self.netns.clone() + } } #[inline] -pub fn get_iface_to_bind(ip_addr: &smoltcp::wire::IpAddress) -> Option> 
{ +pub fn get_iface_to_bind( + ip_addr: &smoltcp::wire::IpAddress, + netns: Arc, +) -> Option> { // log::debug!("get_iface_to_bind: {:?}", ip_addr); // if ip_addr.is_unspecified() - crate::net::NET_DEVICES - .read_irqsave() + netns + .device_list() .iter() .find(|(_, iface)| { let guard = iface.smol_iface().lock(); @@ -121,11 +149,12 @@ pub fn get_iface_to_bind(ip_addr: &smoltcp::wire::IpAddress) -> Option, ) -> (Arc, smoltcp::wire::IpAddress) { - get_iface_to_bind(remote_ip_addr) + get_iface_to_bind(remote_ip_addr, netns.clone()) .map(|iface| (iface, *remote_ip_addr)) .or({ - let ifaces = NET_DEVICES.read_irqsave(); + let ifaces = netns.device_list(); ifaces.iter().find_map(|(_, iface)| { iface .smol_iface() @@ -137,7 +166,7 @@ fn get_ephemeral_iface( }) }) .or({ - NET_DEVICES.read_irqsave().values().next().map(|iface| { + netns.device_list().values().next().map(|iface| { ( iface.clone(), iface.smol_iface().lock().ip_addrs()[0].address(), diff --git a/kernel/src/net/socket/inet/datagram/inner.rs b/kernel/src/net/socket/inet/datagram/inner.rs index de1c3bd6f..0d2e79a56 100644 --- a/kernel/src/net/socket/inet/datagram/inner.rs +++ b/kernel/src/net/socket/inet/datagram/inner.rs @@ -1,9 +1,12 @@ +use alloc::sync::Arc; + use smoltcp; use system_error::SystemError; use crate::{ libs::spinlock::SpinLock, net::socket::inet::common::{BoundInner, Types as InetTypes}, + process::namespace::net_namespace::NetNamespace, }; pub type SmolUdpSocket = smoltcp::socket::udp::Socket<'static>; @@ -32,8 +35,12 @@ impl UnboundUdp { return Self { socket }; } - pub fn bind(self, local_endpoint: smoltcp::wire::IpEndpoint) -> Result { - let inner = BoundInner::bind(self.socket, &local_endpoint.addr)?; + pub fn bind( + self, + local_endpoint: smoltcp::wire::IpEndpoint, + netns: Arc, + ) -> Result { + let inner = BoundInner::bind(self.socket, &local_endpoint.addr, netns)?; let bind_addr = local_endpoint.addr; let bind_port = if local_endpoint.port == 0 { 
inner.port_manager().bind_ephemeral_port(InetTypes::Udp)? @@ -65,9 +72,13 @@ impl UnboundUdp { }) } - pub fn bind_ephemeral(self, remote: smoltcp::wire::IpAddress) -> Result { + pub fn bind_ephemeral( + self, + remote: smoltcp::wire::IpAddress, + netns: Arc, + ) -> Result { // let (addr, port) = (remote.addr, remote.port); - let (inner, address) = BoundInner::bind_ephemeral(self.socket, remote)?; + let (inner, address) = BoundInner::bind_ephemeral(self.socket, remote, netns)?; let bound_port = inner.port_manager().bind_ephemeral_port(InetTypes::Udp)?; let endpoint = smoltcp::wire::IpEndpoint::new(address, bound_port); Ok(BoundUdp { diff --git a/kernel/src/net/socket/inet/datagram/mod.rs b/kernel/src/net/socket/inet/datagram/mod.rs index a33fe5347..a857fcdcc 100644 --- a/kernel/src/net/socket/inet/datagram/mod.rs +++ b/kernel/src/net/socket/inet/datagram/mod.rs @@ -6,6 +6,8 @@ use crate::filesystem::epoll::EPollEventType; use crate::libs::wait_queue::WaitQueue; use crate::net::socket::common::EPollItems; use crate::net::socket::{Socket, PMSG}; +use crate::process::namespace::net_namespace::NetNamespace; +use crate::process::ProcessManager; use crate::{libs::rwlock::RwLock, net::socket::endpoint::Endpoint}; use alloc::sync::{Arc, Weak}; use core::sync::atomic::AtomicBool; @@ -24,18 +26,21 @@ pub struct UdpSocket { nonblock: AtomicBool, wait_queue: WaitQueue, self_ref: Weak, + netns: Arc, epoll_items: EPollItems, } impl UdpSocket { pub fn new(nonblock: bool) -> Arc { - return Arc::new_cyclic(|me| Self { + let netns = ProcessManager::current_netns(); + Arc::new_cyclic(|me| Self { inner: RwLock::new(Some(UdpInner::Unbound(UnboundUdp::new()))), nonblock: AtomicBool::new(nonblock), wait_queue: WaitQueue::default(), self_ref: me.clone(), + netns, epoll_items: EPollItems::default(), - }); + }) } pub fn is_nonblock(&self) -> bool { @@ -45,7 +50,7 @@ impl UdpSocket { pub fn do_bind(&self, local_endpoint: smoltcp::wire::IpEndpoint) -> Result<(), SystemError> { let mut inner = 
self.inner.write(); if let Some(UdpInner::Unbound(unbound)) = inner.take() { - let bound = unbound.bind(local_endpoint)?; + let bound = unbound.bind(local_endpoint, self.netns())?; bound .inner() @@ -62,7 +67,7 @@ impl UdpSocket { let mut inner_guard = self.inner.write(); let bound = match inner_guard.take().expect("Udp inner is None") { UdpInner::Bound(inner) => inner, - UdpInner::Unbound(inner) => inner.bind_ephemeral(remote)?, + UdpInner::Unbound(inner) => inner.bind_ephemeral(remote, self.netns())?, }; inner_guard.replace(UdpInner::Bound(bound)); return Ok(()); @@ -119,9 +124,8 @@ impl UdpSocket { let mut inner_guard = self.inner.write(); let inner = match inner_guard.take().expect("Udp Inner is None") { UdpInner::Bound(bound) => bound, - UdpInner::Unbound(unbound) => { - unbound.bind_ephemeral(to.ok_or(SystemError::EADDRNOTAVAIL)?.addr)? - } + UdpInner::Unbound(unbound) => unbound + .bind_ephemeral(to.ok_or(SystemError::EADDRNOTAVAIL)?.addr, self.netns())?, }; // size = inner.try_send(buf, to)?; inner_guard.replace(UdpInner::Bound(inner)); @@ -176,6 +180,10 @@ impl UdpSocket { } let _ = wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {}); } + + pub fn netns(&self) -> Arc { + self.netns.clone() + } } impl Socket for UdpSocket { diff --git a/kernel/src/net/socket/inet/stream/inner.rs b/kernel/src/net/socket/inet/stream/inner.rs index 7557b7e26..6d42c8de9 100644 --- a/kernel/src/net/socket/inet/stream/inner.rs +++ b/kernel/src/net/socket/inet/stream/inner.rs @@ -1,8 +1,10 @@ +use alloc::sync::Arc; use core::sync::atomic::AtomicUsize; use crate::filesystem::epoll::EPollEventType; use crate::libs::rwlock::RwLock; use crate::net::socket::{self, inet::Types}; +use crate::process::namespace::net_namespace::NetNamespace; use alloc::boxed::Box; use alloc::vec::Vec; use smoltcp; @@ -57,10 +59,11 @@ impl Init { pub(super) fn bind( self, local_endpoint: smoltcp::wire::IpEndpoint, + netns: Arc, ) -> Result { match self { Init::Unbound((socket, _)) => { - let 
bound = socket::inet::BoundInner::bind(*socket, &local_endpoint.addr)?; + let bound = socket::inet::BoundInner::bind(*socket, &local_endpoint.addr, netns)?; bound .port_manager() .bind_port(Types::Tcp, local_endpoint.port)?; @@ -77,11 +80,12 @@ impl Init { pub(super) fn bind_to_ephemeral( self, remote_endpoint: smoltcp::wire::IpEndpoint, + netns: Arc, ) -> Result<(socket::inet::BoundInner, smoltcp::wire::IpEndpoint), (Self, SystemError)> { match self { Init::Unbound((socket, ver)) => { let (bound, address) = - socket::inet::BoundInner::bind_ephemeral(*socket, remote_endpoint.addr) + socket::inet::BoundInner::bind_ephemeral(*socket, remote_endpoint.addr, netns) .map_err(|err| (Self::new(ver), err))?; let bound_port = bound .port_manager() @@ -97,9 +101,10 @@ impl Init { pub(super) fn connect( self, remote_endpoint: smoltcp::wire::IpEndpoint, + netns: Arc, ) -> Result { let (inner, local) = match self { - Init::Unbound(_) => self.bind_to_ephemeral(remote_endpoint)?, + Init::Unbound(_) => self.bind_to_ephemeral(remote_endpoint, netns)?, Init::Bound(inner) => inner, }; if local.addr.is_unspecified() { @@ -146,6 +151,7 @@ impl Init { .unwrap_or(&smoltcp::wire::IpAddress::from( smoltcp::wire::Ipv4Address::UNSPECIFIED, )), + inner.netns(), )?; inners.push(new_listen); } @@ -321,6 +327,7 @@ impl Listening { .unwrap_or(&smoltcp::wire::IpAddress::from( smoltcp::wire::Ipv4Address::UNSPECIFIED, )), + connected.netns(), )?; // swap the connected socket with the new_listen socket diff --git a/kernel/src/net/socket/inet/stream/mod.rs b/kernel/src/net/socket/inet/stream/mod.rs index ba0c3cdd5..6df588d8a 100644 --- a/kernel/src/net/socket/inet/stream/mod.rs +++ b/kernel/src/net/socket/inet/stream/mod.rs @@ -4,8 +4,13 @@ use system_error::SystemError; use crate::libs::rwlock::RwLock; use crate::libs::wait_queue::WaitQueue; +use crate::net::socket::common::shutdown::{ShutdownBit, ShutdownTemp}; use crate::net::socket::common::EPollItems; -use crate::net::socket::{common::ShutdownBit, 
endpoint::Endpoint, Socket, PMSG, PSOL}; +use crate::net::socket::endpoint::Endpoint; +use crate::net::socket::{Socket, SocketInode, PMSG, PSOL}; +use crate::process::namespace::net_namespace::NetNamespace; +use crate::process::ProcessManager; + use crate::sched::SchedMode; use smoltcp; @@ -27,11 +32,13 @@ pub struct TcpSocket { wait_queue: WaitQueue, self_ref: Weak, pollee: AtomicUsize, + netns: Arc, epoll_items: EPollItems, } impl TcpSocket { - pub fn new(nonblock: bool, ver: smoltcp::wire::IpVersion) -> Arc { + pub fn new(_nonblock: bool, ver: smoltcp::wire::IpVersion) -> Arc { + let netns = ProcessManager::current_netns(); Arc::new_cyclic(|me| Self { inner: RwLock::new(Some(inner::Inner::Init(inner::Init::new(ver)))), // shutdown: Shutdown::new(), @@ -39,11 +46,16 @@ impl TcpSocket { wait_queue: WaitQueue::default(), self_ref: me.clone(), pollee: AtomicUsize::new(0_usize), + netns, epoll_items: EPollItems::default(), }) } - pub fn new_established(inner: inner::Established, nonblock: bool) -> Arc { + pub fn new_established( + inner: inner::Established, + nonblock: bool, + netns: Arc, + ) -> Arc { Arc::new_cyclic(|me| Self { inner: RwLock::new(Some(inner::Inner::Established(inner))), // shutdown: Shutdown::new(), @@ -51,6 +63,7 @@ impl TcpSocket { wait_queue: WaitQueue::default(), self_ref: me.clone(), pollee: AtomicUsize::new((EP::EPOLLIN.bits() | EP::EPOLLOUT.bits()) as usize), + netns, epoll_items: EPollItems::default(), }) } @@ -63,7 +76,7 @@ impl TcpSocket { let mut writer = self.inner.write(); match writer.take().expect("Tcp inner::Inner is None") { inner::Inner::Init(inner) => { - let bound = inner.bind(local_endpoint)?; + let bound = inner.bind(local_endpoint, self.netns())?; if let inner::Init::Bound((ref bound, _)) = bound { bound .iface() @@ -112,7 +125,7 @@ impl TcpSocket { { inner::Inner::Listening(listening) => listening.accept().map(|(stream, remote)| { ( - TcpSocket::new_established(stream, self.is_nonblock()), + TcpSocket::new_established(stream, 
self.is_nonblock(), self.netns()), remote, ) }), @@ -129,7 +142,7 @@ impl TcpSocket { let inner = writer.take().expect("Tcp inner::Inner is None"); let (init, result) = match inner { inner::Inner::Init(init) => { - let conn_result = init.connect(remote_endpoint); + let conn_result = init.connect(remote_endpoint, self.netns()); match conn_result { Ok(connecting) => ( inner::Inner::Connecting(connecting), @@ -254,13 +267,17 @@ impl TcpSocket { #[inline] fn incoming(&self) -> bool { - EP::from_bits_truncate(self.do_poll() as u32).contains(EP::EPOLLIN) + EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLIN) } #[inline] fn do_poll(&self) -> usize { self.pollee.load(core::sync::atomic::Ordering::SeqCst) } + + pub fn netns(&self) -> Arc { + self.netns.clone() + } } impl Socket for TcpSocket { diff --git a/kernel/src/net/socket/utils/datagram_common.rs b/kernel/src/net/socket/utils/datagram_common.rs index 1b4eb4384..3ebb5e3f0 100644 --- a/kernel/src/net/socket/utils/datagram_common.rs +++ b/kernel/src/net/socket/utils/datagram_common.rs @@ -7,6 +7,8 @@ use alloc::sync::Arc; use core::panic; use system_error::SystemError; +//todo netlink和udp的操作相同,目前只是为netlink实现了下面的trait,后续为 UdpSocket实现下面的trait,提高复用性 + pub trait Unbound { type Endpoint; type Bound; diff --git a/kernel/src/process/namespace/mod.rs b/kernel/src/process/namespace/mod.rs index 0a0a78b7c..4427c7f78 100644 --- a/kernel/src/process/namespace/mod.rs +++ b/kernel/src/process/namespace/mod.rs @@ -1,4 +1,5 @@ pub mod mnt; +pub mod net_namespace; pub mod nsproxy; pub mod pid_namespace; pub mod unshare; diff --git a/kernel/src/process/namespace/net_namespace.rs b/kernel/src/process/namespace/net_namespace.rs new file mode 100644 index 000000000..d0ddebbc0 --- /dev/null +++ b/kernel/src/process/namespace/net_namespace.rs @@ -0,0 +1,238 @@ +use crate::arch::CurrentIrqArch; +use crate::driver::net::loopback::{generate_loopback_iface_default, LoopbackInterface}; +use crate::exception::InterruptArch; +use 
crate::init::initcall::INITCALL_SUBSYS; +use crate::libs::rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; +use crate::net::routing::Router; +use crate::process::fork::CloneFlags; +use crate::process::kthread::{KernelThreadClosure, KernelThreadMechanism}; +use crate::process::namespace::{NamespaceOps, NamespaceType}; +use crate::process::{ProcessControlBlock, ProcessManager}; +use crate::sched::{schedule, SchedMode}; +use crate::{ + driver::net::Iface, + libs::spinlock::SpinLock, + process::namespace::{nsproxy::NsCommon, user_namespace::UserNamespace}, +}; +use alloc::boxed::Box; +use alloc::collections::BTreeMap; +use alloc::string::{String, ToString}; +use alloc::sync::{Arc, Weak}; +use system_error::SystemError; +use unified_init::macros::unified_init; + +lazy_static! { + /// # 所有网络设备,进程,socket的初始网络命名空间 + pub static ref INIT_NET_NAMESPACE: Arc = NetNamespace::new_root(); +} + +#[unified_init(INITCALL_SUBSYS)] +pub fn root_net_namespace_thread_init() -> Result<(), SystemError> { + // 创建root网络命名空间的轮询线程 + let pcb = + NetNamespace::create_polling_thread(INIT_NET_NAMESPACE.clone(), "root_netns".to_string()); + INIT_NET_NAMESPACE.set_poll_thread(pcb); + Ok(()) +} + +#[derive(Debug)] +pub struct NetNamespace { + ns_common: NsCommon, + self_ref: Weak, + _user_ns: Arc, + inner: RwLock, + /// # 负责当前网络命名空间网卡轮询的线程 + net_poll_thread: SpinLock>>, + /// # 当前网络命名空间下所有网络接口的列表 + /// 这个列表在中断上下文会使用到,因此需要irqsave + /// 没有放在InnerNetNamespace里面,独立出来,方便管理 + device_list: RwLock>>, +} + +#[derive(Debug)] +pub struct InnerNetNamespace { + router: Arc, + /// 当前网络命名空间的loopback网卡 + loopback_iface: Option>, + default_iface: Option>, +} + +impl InnerNetNamespace { + pub fn router(&self) -> &Arc { + &self.router + } + + pub fn loopback_iface(&self) -> Option> { + self.loopback_iface.clone() + } +} + +impl NetNamespace { + pub fn new_root() -> Arc { + let inner = InnerNetNamespace { + router: Arc::new(Router::new("root_netns_router".to_string())), + loopback_iface: None, + default_iface: 
None, + }; + + let netns = Arc::new_cyclic(|self_ref| Self { + ns_common: NsCommon::new(0, NamespaceType::Net), + self_ref: self_ref.clone(), + _user_ns: super::user_namespace::INIT_USER_NAMESPACE.clone(), + inner: RwLock::new(inner), + net_poll_thread: SpinLock::new(None), + device_list: RwLock::new(BTreeMap::new()), + }); + + // Self::create_polling_thread(netns.clone(), "netns_root".to_string()); + log::info!("Initialized root net namespace"); + netns + } + + pub fn new_empty(user_ns: Arc) -> Result, SystemError> { + // 这里获取当前进程的pid,只是为了给后面创建的路由以及线程做唯一标识,没有其他意义 + let pid = ProcessManager::current_pid().0; + let loopback = generate_loopback_iface_default(); + + let inner = InnerNetNamespace { + router: Arc::new(Router::new(format!("netns_router_{}", pid))), + loopback_iface: Some(loopback), + default_iface: None, + }; + + let netns = Arc::new_cyclic(|self_ref| Self { + ns_common: NsCommon::new(0, NamespaceType::Net), + self_ref: self_ref.clone(), + _user_ns: user_ns, + inner: RwLock::new(inner), + net_poll_thread: SpinLock::new(None), + device_list: RwLock::new(BTreeMap::new()), + }); + Self::create_polling_thread(netns.clone(), format!("netns_{}", pid)); + + Ok(netns) + } + + pub(super) fn copy_net_ns( + &self, + clone_flags: &CloneFlags, + user_ns: Arc, + ) -> Result, SystemError> { + if !clone_flags.contains(CloneFlags::CLONE_NEWNET) { + return Ok(self.self_ref.upgrade().unwrap()); + } + + Self::new_empty(user_ns) + } + + pub fn device_list_write(&self) -> RwLockWriteGuard<'_, BTreeMap>> { + self.device_list.write_irqsave() + } + + pub fn device_list(&self) -> RwLockReadGuard<'_, BTreeMap>> { + self.device_list.read_irqsave() + } + + pub fn inner(&self) -> RwLockReadGuard<'_, InnerNetNamespace> { + self.inner.read() + } + + pub fn inner_mut(&self) -> RwLockWriteGuard<'_, InnerNetNamespace> { + self.inner.write() + } + + pub fn set_loopback_iface(&self, loopback: Arc) { + self.inner_mut().loopback_iface = Some(loopback); + } + + pub fn loopback_iface(&self) -> 
Option> { + self.inner().loopback_iface() + } + + pub fn set_default_iface(&self, iface: Arc) { + self.inner_mut().default_iface = Some(iface); + } + + pub fn default_iface(&self) -> Option> { + self.inner().default_iface.clone() + } + + pub fn add_device(&self, device: Arc) { + device.set_net_namespace(self.self_ref.upgrade().unwrap()); + + self.device_list + .write_irqsave() + .insert(device.nic_id(), device); + + log::info!( + "Network device added to namespace count: {:?}", + self.device_list.read_irqsave().len() + ); + } + + pub fn remove_device(&self, nic_id: &usize) { + self.device_list.write_irqsave().remove(nic_id); + } + + /// # 拉起网络命名空间的轮询线程 + pub fn wakeup_poll_thread(&self) { + if self.net_poll_thread.lock().is_none() { + return; + } + // log::info!("wakeup net_poll thread for namespace"); + let _ = ProcessManager::wakeup(self.net_poll_thread.lock().as_ref().unwrap()); + } + + /// # 网络命名空间的轮询线程 + /// 该线程会轮询当前命名空间下的所有网络接口 + /// 并调用它们的poll方法 + /// 注意: 此方法仅可在初始化当前net namespace时创建进程使用 + fn polling(&self) { + log::info!("net_poll thread started for namespace"); + loop { + for (_, iface) in self.device_list.read_irqsave().iter() { + iface.poll(); + } + let irq_guard = unsafe { CurrentIrqArch::save_and_disable_irq() }; + ProcessManager::mark_sleep(true) + .expect("clocksource_watchdog_kthread:mark sleep failed"); + drop(irq_guard); + schedule(SchedMode::SM_NONE); + } + } + + fn create_polling_thread(netns: Arc, name: String) -> Arc { + let pcb = { + let closure: Box i32 + Send + Sync> = Box::new(move || { + netns.polling(); + 0 + }); + KernelThreadClosure::EmptyClosure((closure, ())) + }; + + let pcb = KernelThreadMechanism::create_and_run(pcb, name) + .ok_or("") + .expect("create net_poll thread for net namespace failed"); + log::info!("net_poll thread created for namespace"); + pcb + } + + /// # 设置网络命名空间的轮询线程 + /// 这个方法仅可在初始化网络命名空间时调用 + fn set_poll_thread(&self, pcb: Arc) { + let mut lock = self.net_poll_thread.lock(); + *lock = Some(pcb); + } +} + +impl 
NamespaceOps for NetNamespace { + fn ns_common(&self) -> &NsCommon { + &self.ns_common + } +} + +impl ProcessManager { + pub fn current_netns() -> Arc { + Self::current_pcb().nsproxy.read().net_ns.clone() + } +} diff --git a/kernel/src/process/namespace/nsproxy.rs b/kernel/src/process/namespace/nsproxy.rs index e9b26f5eb..dfc70ed20 100644 --- a/kernel/src/process/namespace/nsproxy.rs +++ b/kernel/src/process/namespace/nsproxy.rs @@ -5,6 +5,7 @@ use crate::process::{ fork::CloneFlags, namespace::{ mnt::{root_mnt_namespace, MntNamespace}, + net_namespace::{NetNamespace, INIT_NET_NAMESPACE}, uts_namespace::{UtsNamespace, INIT_UTS_NAMESPACE}, }, ProcessControlBlock, ProcessManager, @@ -28,8 +29,11 @@ pub struct NsProxy { /// mount namespace(挂载命名空间) pub mnt_ns: Arc, pub uts_ns: Arc, + /// 网络命名空间 + pub net_ns: Arc, + // 注意,user_ns 存储在cred,不存储在nsproxy + // 其他namespace(为未来扩展预留) - // pub net_ns: Option>, // pub ipc_ns: Option>, // pub cgroup_ns: Option>, // pub time_ns: Option>, @@ -46,10 +50,12 @@ impl NsProxy { pub fn new_root() -> Arc { let root_pid_ns = super::pid_namespace::INIT_PID_NAMESPACE.clone(); let root_mnt_ns = root_mnt_namespace(); + let root_net_ns = INIT_NET_NAMESPACE.clone(); let root_uts_ns = INIT_UTS_NAMESPACE.clone(); Arc::new(Self { pid_ns_for_children: root_pid_ns, mnt_ns: root_mnt_ns, + net_ns: root_net_ns, uts_ns: root_uts_ns, }) } @@ -64,10 +70,16 @@ impl NsProxy { &self.mnt_ns } + /// 获取 net namespace + pub fn net_namespace(&self) -> &Arc { + &self.net_ns + } + pub fn clone_inner(&self) -> Self { Self { pid_ns_for_children: self.pid_ns_for_children.clone(), mnt_ns: self.mnt_ns.clone(), + net_ns: self.net_ns.clone(), uts_ns: self.uts_ns.clone(), } } @@ -147,11 +159,13 @@ pub(super) fn create_new_namespaces( .copy_pid_ns(clone_flags, user_ns.clone())?; let mnt_ns = nsproxy.mnt_ns.copy_mnt_ns(clone_flags, user_ns.clone())?; + let net_ns = nsproxy.net_ns.copy_net_ns(clone_flags, user_ns.clone())?; let uts_ns = nsproxy.uts_ns.copy_uts_ns(clone_flags, 
user_ns.clone())?; let result = NsProxy { pid_ns_for_children, mnt_ns, + net_ns, uts_ns, }; From 41b5c29a0a5622e13c95873bd91b7b469cc8f6e9 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Wed, 20 Aug 2025 12:35:41 +0800 Subject: [PATCH 19/36] =?UTF-8?q?feat(netns):=20=E5=88=A0=E9=99=A4?= =?UTF-8?q?=E5=85=A8=E5=B1=80=E8=B7=AF=E7=94=B1,=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E5=BD=93=E5=89=8Dnetns=E4=B8=8B=E7=9A=84=E8=B7=AF=E7=94=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/net/routing.rs | 26 +++++++++---------- kernel/src/process/namespace/net_namespace.rs | 4 +++ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/kernel/src/net/routing.rs b/kernel/src/net/routing.rs index 9e91a07d8..9421eb17d 100644 --- a/kernel/src/net/routing.rs +++ b/kernel/src/net/routing.rs @@ -1,7 +1,8 @@ use crate::driver::net::Iface; use crate::libs::rwlock::RwLock; +use crate::process::namespace::net_namespace::INIT_NET_NAMESPACE; use alloc::collections::BTreeMap; -use alloc::string::{String, ToString}; +use alloc::string::String; use alloc::sync::{Arc, Weak}; use alloc::vec::Vec; use smoltcp::wire::{EthernetAddress, EthernetFrame, IpAddress, IpCidr, Ipv4Packet}; @@ -148,12 +149,8 @@ impl Router { } } -lazy_static! 
{ - pub static ref GLOBAL_ROUTER: Arc = Arc::new(Router::new("global_router".to_string())); -} - -pub fn global_router() -> Arc { - GLOBAL_ROUTER.clone() +pub fn init_netns_router() -> Arc { + INIT_NET_NAMESPACE.router().clone() } /// 可供路由设备应该实现的 trait @@ -195,8 +192,8 @@ pub trait RouterEnableDevice: Iface { return; } - // 查询全局路由表//todo 加入namespace之后在这里改成每个设备所属命名空间的Router即可 - let router = global_router(); + // 查询当前网络命名空间下的路由表 + let router = self.netns_router(); let decision = match router.lookup_route(dst_ip.into()) { Some(d) => d, @@ -238,27 +235,30 @@ pub trait RouterEnableDevice: Iface { /// 同Linux的ndo_start_xmit() /// /// todo 在这里查询arp_table,找到目标IP对应的mac地址然后拼接,如果找不到的话就需要主动发送arp请求去查询mac地址了,手伸不到smoltcp内部:( + /// 后续需要将arp查询的逻辑从smoltcp中抽离出来 fn route_and_send(&self, next_hop: IpAddress, ip_packet: &[u8]); /// 检查IP地址是否是当前接口的IP fn is_my_ip(&self, ip: IpAddress) -> bool; + + fn netns_router(&self) -> Arc { + self.net_namespace() + .map_or_else(|| init_netns_router(), |ns| ns.router()) + } } /// # 每一个`RouterEnableDevice`应该有的公共数据,包含 /// - 当前接口的arp_table,记录邻居(//todo:将网卡的发送以及处理逻辑从smoltcp中移动出来,目前只是简单为veth实现这个,因为可以直接查到对端的mac地址) -/// - 当前接口的路由器 (//todo:引入命名空间之后在这里指向当前所属命名空间的Router) #[derive(Debug)] pub struct RouterEnableDeviceCommon { + /// 当前接口的邻居缓存 pub arp_table: RwLock>, - pub router: Weak, } impl Default for RouterEnableDeviceCommon { fn default() -> Self { - let router = global_router(); Self { arp_table: RwLock::new(BTreeMap::new()), - router: Arc::downgrade(&router), } } } diff --git a/kernel/src/process/namespace/net_namespace.rs b/kernel/src/process/namespace/net_namespace.rs index d0ddebbc0..a0460af1b 100644 --- a/kernel/src/process/namespace/net_namespace.rs +++ b/kernel/src/process/namespace/net_namespace.rs @@ -157,6 +157,10 @@ impl NetNamespace { self.inner().default_iface.clone() } + pub fn router(&self) -> Arc { + self.inner().router.clone() + } + pub fn add_device(&self, device: Arc) { device.set_net_namespace(self.self_ref.upgrade().unwrap()); From 
7489eedb7cc4cb9eb04e686dd1d49a0eda216c01 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Wed, 20 Aug 2025 21:35:45 +0800 Subject: [PATCH 20/36] =?UTF-8?q?feat(netlink):=20=E5=B0=86netlink=20socke?= =?UTF-8?q?t=E7=A7=BB=E5=85=A5netns=E4=B8=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/net/mod.rs | 10 +- kernel/src/net/routing.rs | 2 +- kernel/src/net/socket/netlink/common/bound.rs | 14 ++- kernel/src/net/socket/netlink/common/mod.rs | 14 ++- .../src/net/socket/netlink/common/unbound.rs | 13 ++- .../net/socket/netlink/message/attr/mod.rs | 2 +- .../net/socket/netlink/message/attr/noattr.rs | 8 +- kernel/src/net/socket/netlink/message/mod.rs | 17 +++- kernel/src/net/socket/netlink/mod.rs | 26 +++-- kernel/src/net/socket/netlink/receiver.rs | 2 +- kernel/src/net/socket/netlink/route/bound.rs | 17 +++- .../net/socket/netlink/route/kernel/addr.rs | 37 ++++--- .../net/socket/netlink/route/kernel/mod.rs | 49 ++++++--- .../netlink/route/message/attr/route.rs | 4 +- .../netlink/route/message/segment/mod.rs | 4 + kernel/src/net/socket/netlink/route/mod.rs | 8 +- kernel/src/net/socket/netlink/table/mod.rs | 99 +++++++++++++------ .../src/net/socket/utils/datagram_common.rs | 14 ++- kernel/src/process/namespace/net_namespace.rs | 33 ++++++- 19 files changed, 261 insertions(+), 112 deletions(-) diff --git a/kernel/src/net/mod.rs b/kernel/src/net/mod.rs index 7150cecbe..9b68003af 100644 --- a/kernel/src/net/mod.rs +++ b/kernel/src/net/mod.rs @@ -4,9 +4,7 @@ //! 敬请注意。 use core::sync::atomic::AtomicUsize; -use alloc::{collections::BTreeMap, sync::Arc}; - -use crate::{driver::net::Iface, libs::rwlock::RwLock}; +use crate::driver::net::Iface; pub mod net_core; pub mod posix; @@ -14,12 +12,6 @@ pub mod routing; pub mod socket; pub mod syscall; -lazy_static! 
{ - /// # 所有网络接口的列表 - /// 这个列表在中断上下文会使用到,因此需要irqsave - pub static ref NET_DEVICES: RwLock>> = RwLock::new(BTreeMap::new()); -} - /// 生成网络接口的id (全局自增) pub fn generate_iface_id() -> usize { static IFACE_ID: AtomicUsize = AtomicUsize::new(0); diff --git a/kernel/src/net/routing.rs b/kernel/src/net/routing.rs index 9421eb17d..61bd3f41a 100644 --- a/kernel/src/net/routing.rs +++ b/kernel/src/net/routing.rs @@ -243,7 +243,7 @@ pub trait RouterEnableDevice: Iface { fn netns_router(&self) -> Arc { self.net_namespace() - .map_or_else(|| init_netns_router(), |ns| ns.router()) + .map_or_else(init_netns_router, |ns| ns.router()) } } diff --git a/kernel/src/net/socket/netlink/common/bound.rs b/kernel/src/net/socket/netlink/common/bound.rs index 4255e2e64..b0a3c0e91 100644 --- a/kernel/src/net/socket/netlink/common/bound.rs +++ b/kernel/src/net/socket/netlink/common/bound.rs @@ -5,8 +5,10 @@ use crate::{ receiver::MessageQueue, table::BoundHandle, }, + process::namespace::net_namespace::NetNamespace, }; use alloc::fmt::Debug; +use alloc::sync::Arc; use system_error::SystemError; #[derive(Debug)] @@ -14,14 +16,20 @@ pub struct BoundNetlink { pub(in crate::net::socket::netlink) handle: BoundHandle, pub(in crate::net::socket::netlink) remote_addr: NetlinkSocketAddr, pub(in crate::net::socket::netlink) receive_queue: MessageQueue, + pub(in crate::net::socket::netlink) netns: Arc, } impl BoundNetlink { - pub(super) fn new(handle: BoundHandle, message_queue: MessageQueue) -> Self { + pub(super) fn new( + handle: BoundHandle, + message_queue: MessageQueue, + netns: Arc, + ) -> Self { Self { handle, remote_addr: NetlinkSocketAddr::new_unspecified(), receive_queue: message_queue, + netns, } } @@ -53,4 +61,8 @@ impl BoundNetlink { pub(super) fn drop_groups(&mut self, groups: GroupIdSet) { self.handle.drop_groups(groups); } + + pub fn netns(&self) -> Arc { + self.netns.clone() + } } diff --git a/kernel/src/net/socket/netlink/common/mod.rs b/kernel/src/net/socket/netlink/common/mod.rs index 
519020754..0359a798d 100644 --- a/kernel/src/net/socket/netlink/common/mod.rs +++ b/kernel/src/net/socket/netlink/common/mod.rs @@ -11,6 +11,7 @@ use crate::{ utils::datagram_common::{select_remote_and_bind, Bound, Inner}, Socket, PMSG, }, + process::{namespace::net_namespace::NetNamespace, ProcessManager}, }; use alloc::sync::Arc; use core::sync::atomic::AtomicBool; @@ -25,6 +26,7 @@ pub struct NetlinkSocket { is_nonblocking: AtomicBool, wait_queue: Arc, + netns: Arc, } impl NetlinkSocket

@@ -37,6 +39,7 @@ where inner: RwLock::new(Inner::Unbound(unbound)), is_nonblocking: AtomicBool::new(is_nonblocking), wait_queue: Arc::new(WaitQueue::default()), + netns: ProcessManager::current_netns(), }) } @@ -53,6 +56,7 @@ where self.inner.write().bind_ephemeral( &NetlinkSocketAddr::new_unspecified(), self.wait_queue.clone(), + self.netns(), ) }, |bound, remote| bound.try_send(buf, &remote, flags), @@ -85,6 +89,10 @@ where .check_io_events() .contains(EPollEventType::EPOLLIN) } + + pub fn netns(&self) -> Arc { + self.netns.clone() + } } impl Socket for NetlinkSocket

@@ -99,7 +107,7 @@ where self.inner .write() - .connect(&endpoint, self.wait_queue.clone()) + .connect(&endpoint, self.wait_queue.clone(), self.netns()) } fn bind( @@ -108,7 +116,9 @@ where ) -> Result<(), system_error::SystemError> { let endpoint = endpoint.try_into()?; - self.inner.write().bind(&endpoint, self.wait_queue.clone()) + self.inner + .write() + .bind(&endpoint, self.wait_queue.clone(), self.netns()) } fn send_to( diff --git a/kernel/src/net/socket/netlink/common/unbound.rs b/kernel/src/net/socket/netlink/common/unbound.rs index 1ff7e2bcb..6ca92c4e1 100644 --- a/kernel/src/net/socket/netlink/common/unbound.rs +++ b/kernel/src/net/socket/netlink/common/unbound.rs @@ -10,6 +10,7 @@ use crate::{ }, utils::datagram_common, }, + process::namespace::net_namespace::NetNamespace, }; use alloc::sync::Arc; use core::marker::PhantomData; @@ -48,8 +49,9 @@ impl datagram_common::Unbound for UnboundNetlink

fn bind( &mut self, - endpoint: &NetlinkSocketAddr, + endpoint: &Self::Endpoint, wait_queue: Arc, + netns: Arc, ) -> Result, SystemError> { let message_queue = MessageQueue::::new(); let bound_handle = { @@ -59,16 +61,17 @@ impl datagram_common::Unbound for UnboundNetlink

endpoint }; let receiver = MessageReceiver::new(message_queue.clone(), wait_queue); -

::bind(&endpoint, receiver)? +

::bind(&endpoint, receiver, netns.clone())? }; - Ok(BoundNetlink::new(bound_handle, message_queue)) + Ok(BoundNetlink::new(bound_handle, message_queue, netns)) } fn bind_ephemeral( &mut self, _remote_endpoint: &Self::Endpoint, wait_queue: Arc, + netns: Arc, ) -> Result, SystemError> { let message_queue = MessageQueue::::new(); @@ -79,10 +82,10 @@ impl datagram_common::Unbound for UnboundNetlink

endpoint }; let receiver = MessageReceiver::new(message_queue.clone(), wait_queue); -

::bind(&endpoint, receiver)? +

::bind(&endpoint, receiver, netns.clone())? }; - Ok(BoundNetlink::new(bound_handle, message_queue)) + Ok(BoundNetlink::new(bound_handle, message_queue, netns)) } fn check_io_events(&self) -> EPollEventType { diff --git a/kernel/src/net/socket/netlink/message/attr/mod.rs b/kernel/src/net/socket/netlink/message/attr/mod.rs index 8b10afe4f..a7ea25ad4 100644 --- a/kernel/src/net/socket/netlink/message/attr/mod.rs +++ b/kernel/src/net/socket/netlink/message/attr/mod.rs @@ -1,4 +1,4 @@ -pub mod noattr; +pub(super) mod noattr; use crate::net::socket::netlink::message::NLMSG_ALIGN; use alloc::vec::Vec; diff --git a/kernel/src/net/socket/netlink/message/attr/noattr.rs b/kernel/src/net/socket/netlink/message/attr/noattr.rs index 530f240d1..59f3d27a7 100644 --- a/kernel/src/net/socket/netlink/message/attr/noattr.rs +++ b/kernel/src/net/socket/netlink/message/attr/noattr.rs @@ -15,20 +15,20 @@ impl Attribute for NoAttr { fn read_from_buf( header: &super::CAttrHeader, - payload_buf: &[u8], + _payload_buf: &[u8], ) -> Result, system_error::SystemError> where Self: Sized, { - let payload_len = header.payload_len(); + let _payload_len = header.payload_len(); //todo reader.skip_some(payload_len); Ok(None) } fn read_all_from_buf( - buf: &[u8], - mut offset: usize, + _buf: &[u8], + _offset: usize, ) -> Result, system_error::SystemError> where Self: Sized, diff --git a/kernel/src/net/socket/netlink/message/mod.rs b/kernel/src/net/socket/netlink/message/mod.rs index dd929e6a7..88fb71772 100644 --- a/kernel/src/net/socket/netlink/message/mod.rs +++ b/kernel/src/net/socket/netlink/message/mod.rs @@ -1,9 +1,11 @@ -use crate::net::socket::netlink::message::segment::header::CMsgSegHdr; +use crate::net::socket::netlink::{ + message::segment::header::CMsgSegHdr, table::StandardNetlinkProtocol, +}; use alloc::vec::Vec; use system_error::SystemError; -pub mod attr; -pub mod segment; +pub(super) mod attr; +pub(super) mod segment; #[derive(Debug)] pub struct Message { @@ -50,6 +52,14 @@ impl 
Message { .map(|segment| segment.header().len as usize) .sum() } + + pub fn protocol(&self) -> StandardNetlinkProtocol { + self.segments + .first() + .map_or(StandardNetlinkProtocol::UNUSED, |segment| { + segment.protocol() + }) + } } pub trait ProtocolSegment: Sized + alloc::fmt::Debug { @@ -57,6 +67,7 @@ pub trait ProtocolSegment: Sized + alloc::fmt::Debug { fn header_mut(&mut self) -> &mut CMsgSegHdr; fn read_from(reader: &[u8]) -> Result; fn write_to(&self, writer: &mut [u8]) -> Result; + fn protocol(&self) -> StandardNetlinkProtocol; } pub(super) const NLMSG_ALIGN: usize = 4; diff --git a/kernel/src/net/socket/netlink/mod.rs b/kernel/src/net/socket/netlink/mod.rs index 7d3ddb10d..7a9df50e5 100644 --- a/kernel/src/net/socket/netlink/mod.rs +++ b/kernel/src/net/socket/netlink/mod.rs @@ -1,6 +1,9 @@ use crate::net::socket::{ family, - netlink::{route::NetlinkRouteSocket, table::StandardNetlinkProtocol}, + netlink::{ + route::NetlinkRouteSocket, + table::{is_valid_protocol, StandardNetlinkProtocol}, + }, SocketInode, }; use alloc::sync::Arc; @@ -8,10 +11,10 @@ use system_error::SystemError; pub mod addr; mod common; -pub mod message; +mod message; mod receiver; mod route; -mod table; +pub mod table; pub struct Netlink; @@ -34,9 +37,20 @@ fn create_netlink_socket(protocol: u32) -> Result, SystemError> let nl_protocol = StandardNetlinkProtocol::try_from(protocol); let inode = match nl_protocol { Ok(StandardNetlinkProtocol::ROUTE) => NetlinkRouteSocket::new(false), - _ => { - log::warn!("unsupported Netlink protocol: {}", protocol); - return Err(SystemError::EPROTONOSUPPORT); + Ok(_) => { + log::warn!( + "standard netlink families {} is not supported yet", + protocol + ); + return Err(SystemError::EAFNOSUPPORT); + } + Err(_) => { + if is_valid_protocol(protocol) { + log::error!("user-provided netlink family is not supported"); + return Err(SystemError::EPROTONOSUPPORT); + } + log::error!("invalid netlink protocol: {}", protocol); + return 
Err(SystemError::EAFNOSUPPORT); } }; diff --git a/kernel/src/net/socket/netlink/receiver.rs b/kernel/src/net/socket/netlink/receiver.rs index 044dec098..a38db9f5b 100644 --- a/kernel/src/net/socket/netlink/receiver.rs +++ b/kernel/src/net/socket/netlink/receiver.rs @@ -27,7 +27,7 @@ impl MessageQueue { } } -/// Netlink Socket 的消息接收器,记录在全局的 Netlink Socket 表中,负责将消息压入对应的消息队列,并唤醒等待的线程 +/// Netlink Socket 的消息接收器,记录在当前网络命名空间的 Netlink Socket 表中,负责将消息压入对应的消息队列,并唤醒等待的线程 #[derive(Debug)] pub struct MessageReceiver { message_queue: MessageQueue, diff --git a/kernel/src/net/socket/netlink/route/bound.rs b/kernel/src/net/socket/netlink/route/bound.rs index 825ccd784..1d6b301fd 100644 --- a/kernel/src/net/socket/netlink/route/bound.rs +++ b/kernel/src/net/socket/netlink/route/bound.rs @@ -5,7 +5,7 @@ use crate::{ addr::NetlinkSocketAddr, common::bound::BoundNetlink, message::ProtocolSegment, - route::{kernel::netlink_route_kernel, message::RouteNlMessage}, + route::{kernel::NetlinkRouteKernelSocket, message::RouteNlMessage}, }, utils::datagram_common, PMSG, @@ -73,7 +73,20 @@ impl datagram_common::Bound for BoundNetlink { } } - netlink_route_kernel().request(&nlmsg, local_port); + let Some(route_kernel) = self + .netns + .get_netlink_socket_by_protocol(nlmsg.protocol().into()) + else { + log::warn!("No route kernel socket available in net namespace"); + return Ok(sum_lens); + }; + + let route_kernel_socket = route_kernel + .as_any_ref() + .downcast_ref::() + .unwrap(); + + route_kernel_socket.request(&nlmsg, local_port, self.netns()); Ok(sum_lens) } diff --git a/kernel/src/net/socket/netlink/route/kernel/addr.rs b/kernel/src/net/socket/netlink/route/kernel/addr.rs index 974148bba..c35a75218 100644 --- a/kernel/src/net/socket/netlink/route/kernel/addr.rs +++ b/kernel/src/net/socket/netlink/route/kernel/addr.rs @@ -1,27 +1,25 @@ use crate::{ driver::net::Iface, - net::{ - socket::{ - netlink::{ - message::segment::{ - header::{CMsgSegHdr, GetRequestFlags, SegHdrCommonFlags}, - 
CSegmentType, - }, - route::{ - kernel::utils::finish_response, - message::{ - attr::addr::AddrAttr, - segment::{ - addr::{AddrMessageFlags, AddrSegment, AddrSegmentBody, RtScope}, - RouteNlSegment, - }, + net::socket::{ + netlink::{ + message::segment::{ + header::{CMsgSegHdr, GetRequestFlags, SegHdrCommonFlags}, + CSegmentType, + }, + route::{ + kernel::utils::finish_response, + message::{ + attr::addr::AddrAttr, + segment::{ + addr::{AddrMessageFlags, AddrSegment, AddrSegmentBody, RtScope}, + RouteNlSegment, }, }, }, - AddressFamily, }, - NET_DEVICES, + AddressFamily, }, + process::namespace::net_namespace::NetNamespace, }; use alloc::ffi::CString; use alloc::sync::Arc; @@ -31,6 +29,7 @@ use system_error::SystemError; pub(super) fn do_get_addr( request_segment: &AddrSegment, + netns: Arc, ) -> Result, SystemError> { let dump_all = { let flags = GetRequestFlags::from_bits_truncate(request_segment.header().flags); @@ -42,8 +41,8 @@ pub(super) fn do_get_addr( return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP); } - let mut responce: Vec = NET_DEVICES - .read() + let mut responce: Vec = netns + .device_list() .iter() .filter_map(|(_, iface)| iface_to_new_addr(request_segment.header(), iface)) .map(RouteNlSegment::NewAddr) diff --git a/kernel/src/net/socket/netlink/route/kernel/mod.rs b/kernel/src/net/socket/netlink/route/kernel/mod.rs index 687053a22..7f33342e6 100644 --- a/kernel/src/net/socket/netlink/route/kernel/mod.rs +++ b/kernel/src/net/socket/netlink/route/kernel/mod.rs @@ -1,37 +1,52 @@ //! # Netlink route kernel module //! 
内核对于 Netlink 路由的处理模块 -use crate::net::socket::netlink::{ - message::{ - segment::{ack::ErrorSegment, CSegmentType}, - ProtocolSegment, +use crate::{ + net::socket::netlink::{ + message::{ + segment::{ack::ErrorSegment, CSegmentType}, + ProtocolSegment, + }, + route::message::{segment::RouteNlSegment, RouteNlMessage}, + table::{ + NetlinkKernelSocket, NetlinkRouteProtocol, StandardNetlinkProtocol, + SupportedNetlinkProtocol, + }, }, - route::message::{segment::RouteNlSegment, RouteNlMessage}, - table::{NetlinkRouteProtocol, SupportedNetlinkProtocol}, + process::namespace::net_namespace::NetNamespace, }; +use alloc::sync::Arc; use core::marker::PhantomData; mod addr; mod utils; -pub(super) struct NetlinkRouteKernelSocket { +/// 负责处理 Netlink 路由相关的内核模块 +/// 每个 net namespace 都有一个独立的 NetlinkRouteKernelSocket +#[derive(Debug)] +pub struct NetlinkRouteKernelSocket { _private: PhantomData<()>, } impl NetlinkRouteKernelSocket { - const fn new() -> Self { + pub const fn new() -> Self { NetlinkRouteKernelSocket { _private: PhantomData, } } - pub(super) fn request(&self, request: &RouteNlMessage, dst_port: u32) { + pub(super) fn request( + &self, + request: &RouteNlMessage, + dst_port: u32, + netns: Arc, + ) { for segment in request.segments() { let header = segment.header(); let seg_type = CSegmentType::try_from(header.type_).unwrap(); let responce = match segment { - RouteNlSegment::GetAddr(request) => addr::do_get_addr(request), + RouteNlSegment::GetAddr(request) => addr::do_get_addr(request, netns.clone()), RouteNlSegment::GetRoute(_new_route) => todo!(), _ => { log::warn!("Unsupported route request segment type: {:?}", seg_type); @@ -48,15 +63,17 @@ impl NetlinkRouteKernelSocket { } }; - NetlinkRouteProtocol::unicast(dst_port, responce).unwrap(); + NetlinkRouteProtocol::unicast(dst_port, responce, netns.clone()).unwrap(); } } } -/// 负责处理 Netlink 路由相关的内核模块 -/// todo net namespace 实现之后应该是每一个 namespace 都有一个独立的 NetlinkRouteKernelSocket -static NETLINK_ROUTE_KERNEL: 
NetlinkRouteKernelSocket = NetlinkRouteKernelSocket::new(); +impl NetlinkKernelSocket for NetlinkRouteKernelSocket { + fn protocol(&self) -> StandardNetlinkProtocol { + StandardNetlinkProtocol::ROUTE + } -pub(super) fn netlink_route_kernel() -> &'static NetlinkRouteKernelSocket { - &NETLINK_ROUTE_KERNEL + fn as_any_ref(&self) -> &dyn core::any::Any { + self + } } diff --git a/kernel/src/net/socket/netlink/route/message/attr/route.rs b/kernel/src/net/socket/netlink/route/message/attr/route.rs index 3ec62534b..b393be0a5 100644 --- a/kernel/src/net/socket/netlink/route/message/attr/route.rs +++ b/kernel/src/net/socket/netlink/route/message/attr/route.rs @@ -84,8 +84,8 @@ impl Attribute for RouteAttr { } fn read_from_buf( - header: &crate::net::socket::netlink::message::attr::CAttrHeader, - payload_buf: &[u8], + _header: &crate::net::socket::netlink::message::attr::CAttrHeader, + _payload_buf: &[u8], ) -> Result, SystemError> where Self: Sized, diff --git a/kernel/src/net/socket/netlink/route/message/segment/mod.rs b/kernel/src/net/socket/netlink/route/message/segment/mod.rs index 5524f51a2..5dc6aeee8 100644 --- a/kernel/src/net/socket/netlink/route/message/segment/mod.rs +++ b/kernel/src/net/socket/netlink/route/message/segment/mod.rs @@ -103,4 +103,8 @@ impl ProtocolSegment for RouteNlSegment { Ok(copied) } + + fn protocol(&self) -> crate::net::socket::netlink::table::StandardNetlinkProtocol { + crate::net::socket::netlink::table::StandardNetlinkProtocol::ROUTE + } } diff --git a/kernel/src/net/socket/netlink/route/mod.rs b/kernel/src/net/socket/netlink/route/mod.rs index 20bc0eb46..0eafb4dff 100644 --- a/kernel/src/net/socket/netlink/route/mod.rs +++ b/kernel/src/net/socket/netlink/route/mod.rs @@ -1,7 +1,7 @@ use crate::net::socket::netlink::{common::NetlinkSocket, table::NetlinkRouteProtocol}; -pub mod bound; -mod kernel; -pub mod message; +pub(super) mod bound; +pub(super) mod kernel; +pub(super) mod message; -pub type NetlinkRouteSocket = NetlinkSocket; 
+pub(super) type NetlinkRouteSocket = NetlinkSocket; diff --git a/kernel/src/net/socket/netlink/table/mod.rs b/kernel/src/net/socket/netlink/table/mod.rs index 6df514b3e..c5bce4ae7 100644 --- a/kernel/src/net/socket/netlink/table/mod.rs +++ b/kernel/src/net/socket/netlink/table/mod.rs @@ -1,36 +1,48 @@ mod multicast; use crate::net::socket::netlink::addr::multicast::GroupIdSet; +use crate::net::socket::netlink::route::kernel::NetlinkRouteKernelSocket; use crate::net::socket::netlink::route::message::RouteNlMessage; use crate::net::socket::netlink::table::multicast::MulticastMessage; +use crate::process::namespace::net_namespace::NetNamespace; use crate::process::ProcessManager; use crate::{libs::rand, net::socket::netlink::addr::NetlinkSocketAddr}; use crate::{ - libs::{once::Once, rwlock::RwLock}, + libs::rwlock::RwLock, net::socket::netlink::{receiver::MessageReceiver, table::multicast::MulticastGroup}, }; use alloc::boxed::Box; use alloc::collections::BTreeMap; use alloc::fmt::Debug; +use alloc::sync::Arc; +use core::any::Any; +use hashbrown::HashMap; use system_error::SystemError; -static mut NETLINK_SOCKET_TABLE: Option = None; - -const MAX_ALLOWED_PROTOCOL_ID: u32 = 32; +pub const MAX_ALLOWED_PROTOCOL_ID: u32 = 32; const MAX_GROUPS: u32 = 32; -struct NetlinkSocketTable { - route: RwLock>, +#[derive(Debug)] +pub struct NetlinkSocketTable { + route: Arc>>, + // 在这里继续补充其他协议下的 socket table + // 比如 uevent: Arc>>, } -impl NetlinkSocketTable { - pub fn new() -> Self { +impl Default for NetlinkSocketTable { + fn default() -> Self { Self { - route: RwLock::new(ProtocolSocketTable::new()), + route: Arc::new(RwLock::new(ProtocolSocketTable::new())), } } } +impl NetlinkSocketTable { + pub fn route(&self) -> Arc>> { + self.route.clone() + } +} + #[derive(Debug)] pub struct ProtocolSocketTable { unicast_sockets: BTreeMap>, @@ -48,7 +60,7 @@ impl ProtocolSocketTable { fn bind( &mut self, - socket_table: &'static RwLock>, + socket_table: Arc>>, addr: &NetlinkSocketAddr, 
receiver: MessageReceiver, ) -> Result, SystemError> { @@ -103,14 +115,14 @@ impl ProtocolSocketTable { #[derive(Debug)] pub struct BoundHandle { - socket_table: &'static RwLock>, + socket_table: Arc>>, port: u32, groups: GroupIdSet, } impl BoundHandle { fn new( - socket_table: &'static RwLock>, + socket_table: Arc>>, port: u32, groups: GroupIdSet, ) -> Self { @@ -184,27 +196,37 @@ impl Drop for BoundHandle { pub trait SupportedNetlinkProtocol: Debug { type Message: 'static + Send + Debug; - fn socket_table() -> &'static RwLock>; + fn socket_table(netns: Arc) -> Arc>>; fn bind( addr: &NetlinkSocketAddr, receiver: MessageReceiver, + netns: Arc, ) -> Result, SystemError> { - let mut socket_table = Self::socket_table().write(); - socket_table.bind(Self::socket_table(), addr, receiver) + let socket_table = Self::socket_table(netns); + let mut socket_table_guard = socket_table.write(); + socket_table_guard.bind(socket_table.clone(), addr, receiver) } - fn unicast(dst_port: u32, message: Self::Message) -> Result<(), SystemError> { - let socket_table = Self::socket_table().read(); - socket_table.unicast(dst_port, message) + fn unicast( + dst_port: u32, + message: Self::Message, + netns: Arc, + ) -> Result<(), SystemError> { + Self::socket_table(netns).read().unicast(dst_port, message) } - fn multicast(dst_groups: GroupIdSet, message: Self::Message) -> Result<(), SystemError> + fn multicast( + dst_groups: GroupIdSet, + message: Self::Message, + netns: Arc, + ) -> Result<(), SystemError> where Self::Message: MulticastMessage, { - let socket_table = Self::socket_table().read(); - socket_table.multicast(dst_groups, message) + Self::socket_table(netns) + .read() + .multicast(dst_groups, message) } } @@ -214,22 +236,33 @@ pub struct NetlinkRouteProtocol; impl SupportedNetlinkProtocol for NetlinkRouteProtocol { type Message = RouteNlMessage; - fn socket_table() -> &'static RwLock> { - unsafe { &NETLINK_SOCKET_TABLE.as_ref().unwrap().route } + fn socket_table(netns: Arc) -> Arc>> 
{ + netns.netlink_socket_table().route() } } -pub fn init() { - let once = Once::new(); - once.call_once(|| unsafe { - NETLINK_SOCKET_TABLE = Some(NetlinkSocketTable::new()); - }); -} - pub fn is_valid_protocol(protocol: u32) -> bool { protocol < MAX_ALLOWED_PROTOCOL_ID } +pub trait NetlinkKernelSocket: Debug + Send + Sync { + fn protocol(&self) -> StandardNetlinkProtocol; + + /// 用于实现动态转换 + fn as_any_ref(&self) -> &dyn Any; +} + +/// 为一个网络命名空间生成支持的 Netlink 内核套接字 +pub fn generate_supported_netlink_kernel_sockets() -> HashMap> { + let mut sockets: HashMap> = + HashMap::with_capacity(MAX_ALLOWED_PROTOCOL_ID as usize); + let route_socket = Arc::new(NetlinkRouteKernelSocket::new()); + sockets.insert(route_socket.protocol().into(), route_socket); + + // Add other supported netlink kernel sockets here + sockets +} + #[expect(non_camel_case_types)] #[repr(u32)] #[derive(Debug, Clone, Copy)] @@ -275,6 +308,12 @@ pub enum StandardNetlinkProtocol { SMC = 22, } +impl From for u32 { + fn from(value: StandardNetlinkProtocol) -> Self { + value as u32 + } +} + impl TryFrom for StandardNetlinkProtocol { type Error = (); diff --git a/kernel/src/net/socket/utils/datagram_common.rs b/kernel/src/net/socket/utils/datagram_common.rs index 3ebb5e3f0..c28d8b169 100644 --- a/kernel/src/net/socket/utils/datagram_common.rs +++ b/kernel/src/net/socket/utils/datagram_common.rs @@ -1,4 +1,5 @@ use crate::filesystem::epoll::EPollEventType; +use crate::process::namespace::net_namespace::NetNamespace; use crate::{ libs::{rwlock::RwLock, wait_queue::WaitQueue}, net::socket::PMSG, @@ -17,12 +18,14 @@ pub trait Unbound { &mut self, endpoint: &Self::Endpoint, wait_queue: Arc, + netns: Arc, ) -> Result; fn bind_ephemeral( &mut self, endpoint: &Self::Endpoint, wait_queue: Arc, + netns: Arc, ) -> Result; fn check_io_events(&self) -> EPollEventType; @@ -67,13 +70,14 @@ where &mut self, endpoint: &UnboundSocket::Endpoint, wait_queue: Arc, + netns: Arc, ) -> Result<(), SystemError> { let unbound = match 
self { Inner::Bound(bound) => return bound.bind(endpoint), Inner::Unbound(unbound) => unbound, }; - let bound = unbound.bind(endpoint, wait_queue)?; + let bound = unbound.bind(endpoint, wait_queue, netns)?; *self = Inner::Bound(bound); Ok(()) @@ -83,13 +87,14 @@ where &mut self, remote_endpoint: &UnboundSocket::Endpoint, wait_queue: Arc, + netns: Arc, ) -> Result<(), SystemError> { let unbound_datagram = match self { Inner::Unbound(unbound) => unbound, Inner::Bound(_) => return Ok(()), }; - let bound = unbound_datagram.bind_ephemeral(remote_endpoint, wait_queue)?; + let bound = unbound_datagram.bind_ephemeral(remote_endpoint, wait_queue, netns)?; *self = Inner::Bound(bound); Ok(()) @@ -99,8 +104,9 @@ where &mut self, remote_endpoint: &UnboundSocket::Endpoint, wait_queue: Arc, + netns: Arc, ) -> Result<(), SystemError> { - self.bind_ephemeral(remote_endpoint, wait_queue)?; + self.bind_ephemeral(remote_endpoint, wait_queue, netns)?; let bound = match self { Inner::Unbound(_) => { @@ -125,7 +131,7 @@ where pub fn addr(&self) -> Option { match self { Inner::Unbound(_) => None, - Inner::Bound(bound) => bound.remote_endpoint(), + Inner::Bound(bound) => Some(bound.local_endpoint()), } } diff --git a/kernel/src/process/namespace/net_namespace.rs b/kernel/src/process/namespace/net_namespace.rs index a0460af1b..67ccf701b 100644 --- a/kernel/src/process/namespace/net_namespace.rs +++ b/kernel/src/process/namespace/net_namespace.rs @@ -4,6 +4,9 @@ use crate::exception::InterruptArch; use crate::init::initcall::INITCALL_SUBSYS; use crate::libs::rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use crate::net::routing::Router; +use crate::net::socket::netlink::table::{ + generate_supported_netlink_kernel_sockets, NetlinkKernelSocket, NetlinkSocketTable, +}; use crate::process::fork::CloneFlags; use crate::process::kthread::{KernelThreadClosure, KernelThreadMechanism}; use crate::process::namespace::{NamespaceOps, NamespaceType}; @@ -18,6 +21,7 @@ use alloc::boxed::Box; use 
alloc::collections::BTreeMap; use alloc::string::{String, ToString}; use alloc::sync::{Arc, Weak}; +use hashbrown::HashMap; use system_error::SystemError; use unified_init::macros::unified_init; @@ -47,6 +51,14 @@ pub struct NetNamespace { /// 这个列表在中断上下文会使用到,因此需要irqsave /// 没有放在InnerNetNamespace里面,独立出来,方便管理 device_list: RwLock>>, + + // -- Netlink -- + /// # 当前网络命名空间下的 Netlink 套接字表 + /// 负责绑定netlink套接字的接收队列,以便发送接收消息 + netlink_socket_table: NetlinkSocketTable, + /// # 当前网络命名空间下的 Netlink 内核套接字 + /// 负责接收并处理 Netlink 消息 + netlink_kernel_socket: RwLock>>, } #[derive(Debug)] @@ -54,6 +66,8 @@ pub struct InnerNetNamespace { router: Arc, /// 当前网络命名空间的loopback网卡 loopback_iface: Option>, + /// 当前网络命名空间的默认网卡 + /// 这个网卡会在没有指定网卡的情况下使用 default_iface: Option>, } @@ -78,10 +92,12 @@ impl NetNamespace { let netns = Arc::new_cyclic(|self_ref| Self { ns_common: NsCommon::new(0, NamespaceType::Net), self_ref: self_ref.clone(), - _user_ns: super::user_namespace::INIT_USER_NAMESPACE.clone(), + _user_ns: crate::process::namespace::user_namespace::INIT_USER_NAMESPACE.clone(), inner: RwLock::new(inner), net_poll_thread: SpinLock::new(None), device_list: RwLock::new(BTreeMap::new()), + netlink_socket_table: NetlinkSocketTable::default(), + netlink_kernel_socket: RwLock::new(generate_supported_netlink_kernel_sockets()), }); // Self::create_polling_thread(netns.clone(), "netns_root".to_string()); @@ -91,7 +107,7 @@ impl NetNamespace { pub fn new_empty(user_ns: Arc) -> Result, SystemError> { // 这里获取当前进程的pid,只是为了给后面创建的路由以及线程做唯一标识,没有其他意义 - let pid = ProcessManager::current_pid().0; + let pid = ProcessManager::current_pid().data(); let loopback = generate_loopback_iface_default(); let inner = InnerNetNamespace { @@ -107,6 +123,8 @@ impl NetNamespace { inner: RwLock::new(inner), net_poll_thread: SpinLock::new(None), device_list: RwLock::new(BTreeMap::new()), + netlink_socket_table: NetlinkSocketTable::default(), + netlink_kernel_socket: RwLock::new(generate_supported_netlink_kernel_sockets()), }); 
Self::create_polling_thread(netns.clone(), format!("netns_{}", pid)); @@ -161,6 +179,17 @@ impl NetNamespace { self.inner().router.clone() } + pub fn netlink_socket_table(&self) -> &NetlinkSocketTable { + &self.netlink_socket_table + } + + pub fn get_netlink_socket_by_protocol( + &self, + protocol: u32, + ) -> Option> { + self.netlink_kernel_socket.read().get(&protocol).cloned() + } + pub fn add_device(&self, device: Arc) { device.set_net_namespace(self.self_ref.upgrade().unwrap()); From e82702725c47a40b50a557436eae47c15d3379d0 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Thu, 21 Aug 2025 15:18:10 +0800 Subject: [PATCH 21/36] =?UTF-8?q?feat:=20=E5=AE=8C=E6=88=90netlink=20addr?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E7=9A=84=E6=94=AF=E6=8C=81,=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95=E7=A8=8B=E5=BA=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/bridge.rs | 2 +- kernel/src/driver/net/veth.rs | 2 +- kernel/src/net/posix.rs | 17 + kernel/src/net/socket/netlink/addr/mod.rs | 2 +- kernel/src/net/socket/netlink/common/mod.rs | 12 +- kernel/src/net/socket/netlink/message/mod.rs | 22 +- .../socket/netlink/message/segment/common.rs | 34 +- .../net/socket/netlink/message/segment/mod.rs | 12 +- kernel/src/net/socket/netlink/route/bound.rs | 6 +- .../socket/netlink/route/message/attr/addr.rs | 1 - .../netlink/route/message/segment/mod.rs | 30 +- .../src/net/socket/utils/datagram_common.rs | 7 +- kernel/src/process/namespace/net_namespace.rs | 2 +- kernel/src/process/namespace/unshare.rs | 2 +- user/apps/c_unitest/test_netlink.c | 418 ++++++++++++++++++ 15 files changed, 520 insertions(+), 49 deletions(-) create mode 100644 user/apps/c_unitest/test_netlink.c diff --git a/kernel/src/driver/net/bridge.rs b/kernel/src/driver/net/bridge.rs index dcbd5c8a3..10707e191 100644 --- a/kernel/src/driver/net/bridge.rs +++ b/kernel/src/driver/net/bridge.rs @@ -365,7 +365,7 @@ fn bridge_probe() { 
let addr1 = IpAddress::v4(200, 0, 0, 1); let cidr1 = IpCidr::new(addr1, 24); let addr2 = IpAddress::v4(200, 0, 0, 2); - let cidr2 = IpCidr::new(addr1, 24); + let cidr2 = IpCidr::new(addr2, 24); let addr3 = IpAddress::v4(200, 0, 0, 3); let cidr3 = IpCidr::new(addr3, 24); diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index cb4c8bf07..0cd213dfd 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -566,7 +566,7 @@ impl Iface for VethInterface { } fn poll(&self) { - log::info!("VethInterface {} polling normal", self.name); + // log::info!("VethInterface {} polling normal", self.name); self.common.poll(self.driver.force_get_mut()); } diff --git a/kernel/src/net/posix.rs b/kernel/src/net/posix.rs index 3bb72142e..ec2a4f3e6 100644 --- a/kernel/src/net/posix.rs +++ b/kernel/src/net/posix.rs @@ -41,6 +41,8 @@ use alloc::string::ToString; use core::ffi::CStr; use system_error::SystemError; +use crate::net::socket::netlink::addr::{GroupIdSet, NetlinkSocketAddr}; + // 参考资料: https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/netinet_in.h.html#tag_13_32 #[repr(C)] #[derive(Debug, Clone, Copy)] @@ -252,6 +254,21 @@ impl SockAddr { return Ok(Endpoint::Unix(UnixEndpoint::File(path.to_string()))); } + AddressFamily::Netlink => { + if len < addr.len()? 
{ + log::error!("len < addr.len() for Netlink"); + return Err(SystemError::EINVAL); + } + + let addr_nl: SockAddrNl = addr.addr_nl; + let nl_pid = addr_nl.nl_pid; + let nl_groups = addr_nl.nl_groups; + + Ok(Endpoint::Netlink(NetlinkSocketAddr::new( + nl_pid, + GroupIdSet::new(nl_groups), + ))) + } _ => { log::warn!("not support address family {:?}", addr.family); return Err(SystemError::EINVAL); diff --git a/kernel/src/net/socket/netlink/addr/mod.rs b/kernel/src/net/socket/netlink/addr/mod.rs index a4cc14e1b..73e08de41 100644 --- a/kernel/src/net/socket/netlink/addr/mod.rs +++ b/kernel/src/net/socket/netlink/addr/mod.rs @@ -1,7 +1,7 @@ use crate::net::socket::{endpoint::Endpoint, netlink::addr::multicast::GroupIdSet}; use system_error::SystemError; -pub(super) mod multicast; +pub mod multicast; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct NetlinkSocketAddr { diff --git a/kernel/src/net/socket/netlink/common/mod.rs b/kernel/src/net/socket/netlink/common/mod.rs index 0359a798d..9e9d4c301 100644 --- a/kernel/src/net/socket/netlink/common/mod.rs +++ b/kernel/src/net/socket/netlink/common/mod.rs @@ -142,6 +142,7 @@ where flags: crate::net::socket::PMSG, address: Option, ) -> Result<(usize, crate::net::socket::endpoint::Endpoint), system_error::SystemError> { + // log::info!("NetlinkSocket recv_from called"); use crate::sched::SchedMode; if let Some(addr) = address { @@ -168,14 +169,21 @@ where } fn send_buffer_size(&self) -> usize { - log::warn!("send_buffer_size is implemented to 0"); + // log::warn!("send_buffer_size is implemented to 0"); + // netlink sockets typically do not have a send buffer size like stream sockets. 0 } fn recv_buffer_size(&self) -> usize { - log::warn!("recv_buffer_size is implemented to 0"); + // log::warn!("recv_buffer_size is implemented to 0"); + // netlink sockets typically do not have a recv buffer size like stream sockets. 
0 } + + fn recv(&self, buffer: &mut [u8], flags: PMSG) -> Result { + let (len, _) = self.recv_from(buffer, flags, None)?; + Ok(len) + } } impl NetlinkSocket

{ diff --git a/kernel/src/net/socket/netlink/message/mod.rs b/kernel/src/net/socket/netlink/message/mod.rs index 88fb71772..1bedbd21c 100644 --- a/kernel/src/net/socket/netlink/message/mod.rs +++ b/kernel/src/net/socket/netlink/message/mod.rs @@ -35,14 +35,22 @@ impl Message { } pub fn write_to(&self, writer: &mut [u8]) -> Result { - let total_written: usize = self - .segments - .iter() - .map(|segment| segment.write_to(writer)) - .collect::, SystemError>>()? - .iter() - .sum(); + // log::info!("Message write_to"); + let mut total_written: usize = 0; + + for segment in self.segments() { + if total_written >= writer.len() { + log::warn!("Netlink write buffer is full. Some segments may be dropped."); + break; + } + + let remaining_buf = &mut writer[total_written..]; + let written = segment.write_to(remaining_buf)?; + + total_written += written; + } + // log::info!("Total written bytes: {}", total_written); Ok(total_written) } diff --git a/kernel/src/net/socket/netlink/message/segment/common.rs b/kernel/src/net/socket/netlink/message/segment/common.rs index b095fbfc5..9f6c92c17 100644 --- a/kernel/src/net/socket/netlink/message/segment/common.rs +++ b/kernel/src/net/socket/netlink/message/segment/common.rs @@ -46,6 +46,7 @@ impl SegmentCommon { } pub fn read_from_buf(header: CMsgSegHdr, buf: &[u8]) -> Result { + // log::info!("SegmentCommon try to read from buffer"); let (body, remain_len, padded_len) = Body::read_from_buf(&header, buf)?; let attrs_buf = &buf[padded_len..]; @@ -58,18 +59,41 @@ impl SegmentCommon { }) } - pub fn write_to_buf(&self, buf: &mut Vec) -> Result<(), SystemError> { + pub fn write_to_buf(&self, buf: &mut [u8]) -> Result { if buf.len() < self.header.len as usize { return Err(SystemError::EINVAL); } - self.body.write_to_buf(buf)?; + // Write header to the beginning of buf + let header_bytes = unsafe { + core::slice::from_raw_parts( + (&self.header as *const CMsgSegHdr) as *const u8, + Self::HEADER_LEN, + ) + }; + 
buf[..Self::HEADER_LEN].copy_from_slice(header_bytes); + + // 这里创建一个内核缓冲区,用来写入body和attribute,方便进行写入 + let mut kernel_buf: Vec = vec![]; + + self.body.write_to_buf(&mut kernel_buf)?; for attr in self.attrs.iter() { - // let cur_buf = &mut buf[offset..]; - attr.write_to_buf(buf)?; + attr.write_to_buf(&mut kernel_buf)?; } - Ok(()) + let actual_len = kernel_buf.len().min(buf.len()); + let payload_copied = if !kernel_buf.is_empty() { + buf[Self::HEADER_LEN..Self::HEADER_LEN + actual_len] + .copy_from_slice(&kernel_buf[..actual_len]); + // log::info!("buffer: {:?}", &buf[..actual_len]); + actual_len + } else { + // 如果没有数据需要写入,返回0 + // log::info!("No data to write to buffer"); + 0 + }; + + Ok(payload_copied + Self::HEADER_LEN) } pub fn total_len(&self) -> usize { diff --git a/kernel/src/net/socket/netlink/message/segment/mod.rs b/kernel/src/net/socket/netlink/message/segment/mod.rs index f4fc4ce7f..6d08d20c7 100644 --- a/kernel/src/net/socket/netlink/message/segment/mod.rs +++ b/kernel/src/net/socket/netlink/message/segment/mod.rs @@ -1,4 +1,6 @@ +use crate::libs::align::align_up; use crate::net::socket::netlink::message::{segment::header::CMsgSegHdr, NLMSG_ALIGN}; +use alloc::fmt::Debug; use alloc::vec::Vec; use system_error::SystemError; @@ -56,12 +58,13 @@ impl TryFrom for CSegmentType { } pub trait SegmentBody: Sized + Clone + Copy { - type CType: Copy + TryInto + From; + type CType: Copy + TryInto + From + Debug; fn read_from_buf(header: &CMsgSegHdr, buf: &[u8]) -> Result<(Self, usize, usize), SystemError> where Self: Sized, { + // log::info!("header: {:?}", header); let total_len = (header.len as usize) .checked_sub(size_of::()) .ok_or(SystemError::EINVAL)?; @@ -72,6 +75,7 @@ pub trait SegmentBody: Sized + Clone + Copy { let c_type_bytes = &buf[..size_of::()]; let c_type = unsafe { *(c_type_bytes.as_ptr() as *const Self::CType) }; + // log::info!("c_type: {:?}", c_type); let total_len_with_padding = Self::total_len_with_padding(); @@ -85,6 +89,7 @@ pub trait 
SegmentBody: Sized + Clone + Copy { } fn write_to_buf(&self, buf: &mut Vec) -> Result<(), SystemError> { + // log::info!("SegmentBody write_to_buf"); let c_type = Self::CType::from(*self); let body_bytes = unsafe { @@ -107,10 +112,11 @@ pub trait SegmentBody: Sized + Clone + Copy { fn total_len_with_padding() -> usize { let payload_len = size_of::(); - (payload_len.checked_add(NLMSG_ALIGN - 1).unwrap() & !(NLMSG_ALIGN - 1)) - payload_len + align_up(payload_len, NLMSG_ALIGN) } fn padding_len() -> usize { - Self::total_len_with_padding() - size_of::() + let payload_len = size_of::(); + Self::total_len_with_padding() - payload_len } } diff --git a/kernel/src/net/socket/netlink/route/bound.rs b/kernel/src/net/socket/netlink/route/bound.rs index 1d6b301fd..953a4a0af 100644 --- a/kernel/src/net/socket/netlink/route/bound.rs +++ b/kernel/src/net/socket/netlink/route/bound.rs @@ -42,10 +42,6 @@ impl datagram_common::Bound for BoundNetlink { return Err(SystemError::ENOTCONN); } - if *to != NetlinkSocketAddr::new_unspecified() { - return Err(SystemError::ECONNREFUSED); - } - let sum_lens = buf.len(); let mut nlmsg = match RouteNlMessage::read_from(buf) { @@ -75,7 +71,7 @@ impl datagram_common::Bound for BoundNetlink { let Some(route_kernel) = self .netns - .get_netlink_socket_by_protocol(nlmsg.protocol().into()) + .get_netlink_kernel_socket_by_protocol(nlmsg.protocol().into()) else { log::warn!("No route kernel socket available in net namespace"); return Ok(sum_lens); diff --git a/kernel/src/net/socket/netlink/route/message/attr/addr.rs b/kernel/src/net/socket/netlink/route/message/attr/addr.rs index 999f4ef2d..18d3c9c44 100644 --- a/kernel/src/net/socket/netlink/route/message/attr/addr.rs +++ b/kernel/src/net/socket/netlink/route/message/attr/addr.rs @@ -82,7 +82,6 @@ impl Attribute for AddrAttr { // TODO: Currently, `IS_NET_BYTEORDER_MASK` and `IS_NESTED_MASK` are ignored. 
let Ok(addr_class) = AddrAttrClass::try_from(header.type_()) else { - //todo 或许这里我应该返回偏移值 //reader.skip_some(payload_len); return Ok(None); }; diff --git a/kernel/src/net/socket/netlink/route/message/segment/mod.rs b/kernel/src/net/socket/netlink/route/message/segment/mod.rs index 5dc6aeee8..9c5f64e7c 100644 --- a/kernel/src/net/socket/netlink/route/message/segment/mod.rs +++ b/kernel/src/net/socket/netlink/route/message/segment/mod.rs @@ -58,19 +58,21 @@ impl ProtocolSegment for RouteNlSegment { } fn read_from(buf: &[u8]) -> Result { - if buf.len() < size_of::() { + let header_size = size_of::(); + if buf.len() < header_size { log::warn!("the buffer is too small to read a netlink segment header"); return Err(SystemError::EINVAL); } let header = unsafe { *(buf.as_ptr() as *const CMsgSegHdr) }; + let payload_buf = &buf[header_size..]; let segment = match CSegmentType::try_from(header.type_)? { CSegmentType::GETADDR => { - RouteNlSegment::GetAddr(AddrSegment::read_from_buf(header, buf)?) + RouteNlSegment::GetAddr(AddrSegment::read_from_buf(header, payload_buf)?) } CSegmentType::GETROUTE => { - RouteNlSegment::GetRoute(RouteSegment::read_from_buf(header, buf)?) + RouteNlSegment::GetRoute(RouteSegment::read_from_buf(header, payload_buf)?) } _ => return Err(SystemError::EINVAL), }; @@ -79,26 +81,16 @@ impl ProtocolSegment for RouteNlSegment { } fn write_to(&self, buf: &mut [u8]) -> Result { - // 这里没有直接写入buf,而是用 Vec 来构建内核缓冲区 - let mut kernel_buf: Vec = vec![]; - match self { - RouteNlSegment::NewAddr(addr_segment) => addr_segment.write_to_buf(&mut kernel_buf)?, - RouteNlSegment::NewRoute(route_segment) => { - route_segment.write_to_buf(&mut kernel_buf)? 
- } + // log::info!("RouteNlSegment write_to"); + let copied = match self { + RouteNlSegment::NewAddr(addr_segment) => addr_segment.write_to_buf(buf)?, + RouteNlSegment::NewRoute(route_segment) => route_segment.write_to_buf(buf)?, + RouteNlSegment::Done(done_segment) => done_segment.write_to_buf(buf)?, + RouteNlSegment::Error(error_segment) => error_segment.write_to_buf(buf)?, _ => { log::warn!("write_to is not implemented for this segment type"); return Err(SystemError::ENOSYS); } - } - - let actual_len = kernel_buf.len().min(buf.len()); - let copied = if !kernel_buf.is_empty() { - buf[..actual_len].copy_from_slice(&kernel_buf[..actual_len]); - actual_len - } else { - // 如果没有数据需要写入,返回0 - 0 }; Ok(copied) diff --git a/kernel/src/net/socket/utils/datagram_common.rs b/kernel/src/net/socket/utils/datagram_common.rs index c28d8b169..852af4aa2 100644 --- a/kernel/src/net/socket/utils/datagram_common.rs +++ b/kernel/src/net/socket/utils/datagram_common.rs @@ -4,6 +4,7 @@ use crate::{ libs::{rwlock::RwLock, wait_queue::WaitQueue}, net::socket::PMSG, }; +use alloc::fmt::Debug; use alloc::sync::Arc; use core::panic; use system_error::SystemError; @@ -11,7 +12,7 @@ use system_error::SystemError; //todo netlink和udp的操作相同,目前只是为netlink实现了下面的trait,后续为 UdpSocket实现下面的trait,提高复用性 pub trait Unbound { - type Endpoint; + type Endpoint: Debug; type Bound; fn bind( @@ -32,7 +33,7 @@ pub trait Unbound { } pub trait Bound { - type Endpoint: Clone; + type Endpoint: Clone + Debug; fn bind(&mut self, _endpoint: &Self::Endpoint) -> Result<(), SystemError> { Err(SystemError::EINVAL) @@ -80,6 +81,8 @@ where let bound = unbound.bind(endpoint, wait_queue, netns)?; *self = Inner::Bound(bound); + // log::info!("Socket bound to endpoint: {:?}", endpoint); + Ok(()) } diff --git a/kernel/src/process/namespace/net_namespace.rs b/kernel/src/process/namespace/net_namespace.rs index 67ccf701b..497a28204 100644 --- a/kernel/src/process/namespace/net_namespace.rs +++ 
b/kernel/src/process/namespace/net_namespace.rs @@ -183,7 +183,7 @@ impl NetNamespace { &self.netlink_socket_table } - pub fn get_netlink_socket_by_protocol( + pub fn get_netlink_kernel_socket_by_protocol( &self, protocol: u32, ) -> Option> { diff --git a/kernel/src/process/namespace/unshare.rs b/kernel/src/process/namespace/unshare.rs index d6167595c..b0ae76847 100644 --- a/kernel/src/process/namespace/unshare.rs +++ b/kernel/src/process/namespace/unshare.rs @@ -21,7 +21,7 @@ pub fn ksys_unshare(flags: CloneFlags) -> Result<(), SystemError> { switch_task_namespaces(¤t_pcb, new_nsproxy)?; } // TODO: 处理其他命名空间的 unshare 操作 - // CLONE_NEWNS, CLONE_FS, CLONE_FILES, CLONE_SIGHAND, CLONE_VM, CLONE_THREAD, CLONE_SYSVSEM, + // CLONE_FS, CLONE_FILES, CLONE_SIGHAND, CLONE_VM, CLONE_THREAD, CLONE_SYSVSEM, // CLONE_NEWUTS, CLONE_NEWIPC, CLONE_NEWUSER, CLONE_NEWNET, CLONE_NEWCGROUP, CLONE_NEWTIME Ok(()) diff --git a/user/apps/c_unitest/test_netlink.c b/user/apps/c_unitest/test_netlink.c new file mode 100644 index 000000000..075b795a3 --- /dev/null +++ b/user/apps/c_unitest/test_netlink.c @@ -0,0 +1,418 @@ +#define _GNU_SOURCE +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// 定义一个足够大的缓冲区来接收Netlink消息 +#define NL_BUFSIZE 8192 + +// 结构体,用于将请求消息封装起来 +struct nl_req_t { + struct nlmsghdr nlh; + struct ifaddrmsg ifa; +}; + +void parse_rtattr(struct rtattr *tb[], int max, struct rtattr *rta, int len) { + memset(tb, 0, sizeof(struct rtattr *) * (max + 1)); + while (RTA_OK(rta, len)) { + if (rta->rta_type <= max) { + tb[rta->rta_type] = rta; + } + rta = RTA_NEXT(rta, len); + } +} + +int run_netlink_test() { + int sock_fd; + struct sockaddr_nl sa_nl; + struct nl_req_t req; + + // struct iovec iov; + // struct msghdr msg; + // struct sockaddr_nl src_addr; // 用于接收发送方的地址 + + char buf[NL_BUFSIZE]; + ssize_t len; + struct nlmsghdr *nlh; + + // 创建Netlink套接字 + sock_fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); + if (sock_fd < 0) { + 
perror("socket creation failed"); + return EXIT_FAILURE; + } + + // 设置Netlink地址 + memset(&sa_nl, 0, sizeof(sa_nl)); + sa_nl.nl_family = AF_NETLINK; + sa_nl.nl_pid = getpid(); // 使用进程ID作为地址 + + // 绑定Netlink套接字 + if (bind(sock_fd, (struct sockaddr *)&sa_nl, sizeof(sa_nl)) < 0) { + perror("socket bind failed"); + close(sock_fd); + return EXIT_FAILURE; + } + + // 构建RTM_GETADDR请求消息 + memset(&req, 0, sizeof(req)); + req.nlh.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg)); + // 这是关键:设置DUMP标志以获取所有地址 + req.nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.nlh.nlmsg_type = RTM_GETADDR; + req.nlh.nlmsg_seq = 1; // 序列号,用于匹配请求和响应 + req.nlh.nlmsg_pid = getpid(); + req.ifa.ifa_family = + AF_INET; // 只请求 IPv4 地址 (也可以用 AF_UNSPEC 获取所有) + + // 内核的目标地址 + struct sockaddr_nl dest_addr; + memset(&dest_addr, 0, sizeof(dest_addr)); + dest_addr.nl_family = AF_NETLINK; + dest_addr.nl_pid = 0; // 0 for the kernel + dest_addr.nl_groups = 0; // Unicast + + // 发送请求到内核 + if (sendto(sock_fd, + &req, + req.nlh.nlmsg_len, + 0, + (struct sockaddr *)&dest_addr, + sizeof(dest_addr)) < 0) { + perror("send failed"); + close(sock_fd); + return EXIT_FAILURE; + } + + printf("Sent RTM_GETADDR request with DUMP flag.\n\n"); + + // // 准备 recvmsg 所需的结构体 + // // iovec 指向我们的主数据缓冲区 + // iov.iov_base = buf; + // iov.iov_len = sizeof(buf); + + // // msghdr 将所有部分组合在一起 + // msg.msg_name = &src_addr; // 填充发送方地址 + // msg.msg_namelen = sizeof(src_addr); + // msg.msg_iov = &iov; // 指向数据缓冲区 + // msg.msg_iovlen = 1; + // msg.msg_control = NULL; // 我们暂时不处理控制消息 + // msg.msg_controllen = 0; + + int received_messages = 0; + + // 循环接收响应 + while ((len = recv(sock_fd, buf, sizeof(buf), 0)) > 0) { + // ssize_t len = recvmsg(sock_fd, &msg, 0); + + // if (len < 0) { + // perror("recvmsg failed"); + // break; + // } + + // if (len == 0) { + // printf("EOF on netlink socket\n"); + // break; + // } + + // if (msg.msg_flags & MSG_TRUNC) { + // fprintf( + // stderr, + // "Warning: Message was truncated. 
Buffer may be too small.\n"); + // } + + + // 使用 NLMSG_OK 遍历缓冲区中可能存在的多条消息 + for (nlh = (struct nlmsghdr *)buf; NLMSG_OK(nlh, len); + nlh = NLMSG_NEXT(nlh, len)) { + + // 如果是 DUMP 结束的标志,则退出循环 + if (nlh->nlmsg_type == NLMSG_DONE) { + printf("--- End of DUMP ---\n"); + if (!received_messages) { + printf("(Received an empty list as expected)\n"); + } + close(sock_fd); + return EXIT_SUCCESS; + } + + // 如果是错误消息 + if (nlh->nlmsg_type == NLMSG_ERROR) { + struct nlmsgerr *err = (struct nlmsgerr *)NLMSG_DATA(nlh); + fprintf(stderr, + "Netlink error received: %s\n", + strerror(-err->error)); + close(sock_fd); + return EXIT_FAILURE; + } + + // 只处理我们期望的 RTM_NEWADDR 消息 + if (nlh->nlmsg_type != RTM_NEWADDR) { + printf("Received unexpected message type: %d\n", + nlh->nlmsg_type); + continue; + } + + // printf("Received message from PID: %u\n", src_addr.nl_pid); + + // 表明我们至少接收到了一条信息 + received_messages = 1; + + struct ifaddrmsg *ifa = (struct ifaddrmsg *)NLMSG_DATA(nlh); + struct rtattr *rta_tb[IFA_MAX + 1]; + + // 解析消息中的路由属性 + int rta_len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*ifa)); + parse_rtattr(rta_tb, IFA_MAX, IFA_RTA(ifa), rta_len); + + printf("Interface Index: %d, PrefixLen: %d, Scope: %d\n", + ifa->ifa_index, + ifa->ifa_prefixlen, + ifa->ifa_scope); + + char ip_addr_str[INET6_ADDRSTRLEN]; + + // 打印 IFA_LABEL (对应你的 AddrAttr::Label) + if (rta_tb[IFA_LABEL]) { + printf("\tLabel: %s\n", (char *)RTA_DATA(rta_tb[IFA_LABEL])); + } + + // 打印 IFA_ADDRESS (对应你的 AddrAttr::Address) + if (rta_tb[IFA_ADDRESS]) { + inet_ntop(ifa->ifa_family, + RTA_DATA(rta_tb[IFA_ADDRESS]), + ip_addr_str, + sizeof(ip_addr_str)); + printf("\tAddress: %s\n", ip_addr_str); + } + + // 打印 IFA_LOCAL (对应你的 AddrAttr::Local) + if (rta_tb[IFA_LOCAL]) { + inet_ntop(ifa->ifa_family, + RTA_DATA(rta_tb[IFA_LOCAL]), + ip_addr_str, + sizeof(ip_addr_str)); + printf("\tLocal: %s\n", ip_addr_str); + } + printf("----------------------------------------\n"); + } + } + + if (len < 0) { + perror("recv failed"); + } + + 
close(sock_fd); + return EXIT_SUCCESS; +} + +int main(int argc, char *argv[]) { + + printf("=========== STAGE 1: Testing in Default Network Namespace " + "===========\n"); + if (run_netlink_test() != 0) { + fprintf(stderr, "Test failed in the default namespace.\n"); + return EXIT_FAILURE; + } + + printf("\n\n=========== STAGE 2: Creating and Testing in a New Network " + "Namespace ===========\n"); + + // ** 关键步骤:创建新的网络命名空间 ** + // 这个调用会将当前进程移入一个新的、隔离的网络栈中 + if (unshare(CLONE_NEWNET) == -1) { + perror("unshare(CLONE_NEWNET) failed"); + fprintf(stderr, + "This test requires root privileges (e.g., 'sudo " + "./your_program').\n"); + return EXIT_FAILURE; + } + printf("Successfully created and entered a new network namespace.\n"); + + // 在新的命名空间中再次运行同样的测试 + if (run_netlink_test() != 0) { + fprintf(stderr, "Test failed in the new namespace.\n"); + return EXIT_FAILURE; + } + + printf("\nAll tests completed successfully.\n"); + return EXIT_SUCCESS; + + // int sock_fd; + // struct sockaddr_nl sa_nl; + // struct nl_req_t req; + + // // struct iovec iov; + // // struct msghdr msg; + // // struct sockaddr_nl src_addr; // 用于接收发送方的地址 + + // char buf[NL_BUFSIZE]; + // ssize_t len; + // struct nlmsghdr *nlh; + + // // 创建Netlink套接字 + // sock_fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); + // if (sock_fd < 0) { + // perror("socket creation failed"); + // return EXIT_FAILURE; + // } + + // // 设置Netlink地址 + // memset(&sa_nl, 0, sizeof(sa_nl)); + // sa_nl.nl_family = AF_NETLINK; + // sa_nl.nl_pid = getpid(); // 使用进程ID作为地址 + + // // 绑定Netlink套接字 + // if (bind(sock_fd, (struct sockaddr *)&sa_nl, sizeof(sa_nl)) < 0) { + // perror("socket bind failed"); + // close(sock_fd); + // return EXIT_FAILURE; + // } + + // // 构建RTM_GETADDR请求消息 + // memset(&req, 0, sizeof(req)); + // req.nlh.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg)); + // // 这是关键:设置DUMP标志以获取所有地址 + // req.nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + // req.nlh.nlmsg_type = RTM_GETADDR; + // req.nlh.nlmsg_seq = 1; // 
序列号,用于匹配请求和响应 + // req.nlh.nlmsg_pid = getpid(); + // req.ifa.ifa_family = + // AF_INET; // 只请求 IPv4 地址 (也可以用 AF_UNSPEC 获取所有) + + // // 内核的目标地址 + // struct sockaddr_nl dest_addr; + // memset(&dest_addr, 0, sizeof(dest_addr)); + // dest_addr.nl_family = AF_NETLINK; + // dest_addr.nl_pid = 0; // 0 for the kernel + // dest_addr.nl_groups = 0; // Unicast + + // // 发送请求到内核 + // if (sendto(sock_fd, + // &req, + // req.nlh.nlmsg_len, + // 0, + // (struct sockaddr *)&dest_addr, + // sizeof(dest_addr)) < 0) { + // perror("send failed"); + // close(sock_fd); + // return EXIT_FAILURE; + // } + + // printf("Sent RTM_GETADDR request with DUMP flag.\n\n"); + + // // // 准备 recvmsg 所需的结构体 + // // // iovec 指向我们的主数据缓冲区 + // // iov.iov_base = buf; + // // iov.iov_len = sizeof(buf); + + // // // msghdr 将所有部分组合在一起 + // // msg.msg_name = &src_addr; // 填充发送方地址 + // // msg.msg_namelen = sizeof(src_addr); + // // msg.msg_iov = &iov; // 指向数据缓冲区 + // // msg.msg_iovlen = 1; + // // msg.msg_control = NULL; // 我们暂时不处理控制消息 + // // msg.msg_controllen = 0; + + // // 循环接收响应 + // while ((len = recv(sock_fd, buf, sizeof(buf), 0)) > 0) { + // // ssize_t len = recvmsg(sock_fd, &msg, 0); + + // // if (len < 0) { + // // perror("recvmsg failed"); + // // break; + // // } + + // // if (len == 0) { + // // printf("EOF on netlink socket\n"); + // // break; + // // } + + // // if (msg.msg_flags & MSG_TRUNC) { + // // fprintf( + // // stderr, + // // "Warning: Message was truncated. 
Buffer may be too small.\n"); + // // } + + + // // 使用 NLMSG_OK 遍历缓冲区中可能存在的多条消息 + // for (nlh = (struct nlmsghdr *)buf; NLMSG_OK(nlh, len); + // nlh = NLMSG_NEXT(nlh, len)) { + + // // 如果是 DUMP 结束的标志,则退出循环 + // if (nlh->nlmsg_type == NLMSG_DONE) { + // printf("--- End of DUMP ---\n"); + // close(sock_fd); + // return EXIT_SUCCESS; + // } + + // // 如果是错误消息 + // if (nlh->nlmsg_type == NLMSG_ERROR) { + // struct nlmsgerr *err = (struct nlmsgerr *)NLMSG_DATA(nlh); + // fprintf(stderr, + // "Netlink error received: %s\n", + // strerror(-err->error)); + // close(sock_fd); + // return EXIT_FAILURE; + // } + + // // 只处理我们期望的 RTM_NEWADDR 消息 + // if (nlh->nlmsg_type != RTM_NEWADDR) { + // printf("Received unexpected message type: %d\n", + // nlh->nlmsg_type); + // continue; + // } + + // // printf("Received message from PID: %u\n", src_addr.nl_pid); + + // struct ifaddrmsg *ifa = (struct ifaddrmsg *)NLMSG_DATA(nlh); + // struct rtattr *rta_tb[IFA_MAX + 1]; + + // // 解析消息中的路由属性 + // int rta_len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*ifa)); + // parse_rtattr(rta_tb, IFA_MAX, IFA_RTA(ifa), rta_len); + + // printf("Interface Index: %d, PrefixLen: %d, Scope: %d\n", + // ifa->ifa_index, + // ifa->ifa_prefixlen, + // ifa->ifa_scope); + + // char ip_addr_str[INET6_ADDRSTRLEN]; + + // // 打印 IFA_LABEL (对应你的 AddrAttr::Label) + // if (rta_tb[IFA_LABEL]) { + // printf("\tLabel: %s\n", (char *)RTA_DATA(rta_tb[IFA_LABEL])); + // } + + // // 打印 IFA_ADDRESS (对应你的 AddrAttr::Address) + // if (rta_tb[IFA_ADDRESS]) { + // inet_ntop(ifa->ifa_family, + // RTA_DATA(rta_tb[IFA_ADDRESS]), + // ip_addr_str, + // sizeof(ip_addr_str)); + // printf("\tAddress: %s\n", ip_addr_str); + // } + + // // 打印 IFA_LOCAL (对应你的 AddrAttr::Local) + // if (rta_tb[IFA_LOCAL]) { + // inet_ntop(ifa->ifa_family, + // RTA_DATA(rta_tb[IFA_LOCAL]), + // ip_addr_str, + // sizeof(ip_addr_str)); + // printf("\tLocal: %s\n", ip_addr_str); + // } + // printf("----------------------------------------\n"); + // } + // } + + + // 
close(sock_fd); + // return EXIT_SUCCESS; +} \ No newline at end of file From b41404174c5e19744769b1aefcf08d538d6a0277 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Thu, 21 Aug 2025 15:31:12 +0800 Subject: [PATCH 22/36] =?UTF-8?q?feat(netlink):=20=E6=B6=88=E9=99=A4?= =?UTF-8?q?=E4=B8=80=E4=BA=9Bwarning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/net/socket/netlink/common/mod.rs | 1 + kernel/src/net/socket/netlink/common/unbound.rs | 4 ++++ kernel/src/net/socket/netlink/message/attr/mod.rs | 9 ++------- .../src/net/socket/netlink/route/message/segment/mod.rs | 1 - kernel/src/net/socket/utils/datagram_common.rs | 4 +++- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/kernel/src/net/socket/netlink/common/mod.rs b/kernel/src/net/socket/netlink/common/mod.rs index 9e9d4c301..8017592d2 100644 --- a/kernel/src/net/socket/netlink/common/mod.rs +++ b/kernel/src/net/socket/netlink/common/mod.rs @@ -198,6 +198,7 @@ impl NetlinkSocket

{ } } +// 多播消息的时候会用到,比如uevent impl Inner, BoundNetlink> { fn add_groups(&mut self, groups: GroupIdSet) { match self { diff --git a/kernel/src/net/socket/netlink/common/unbound.rs b/kernel/src/net/socket/netlink/common/unbound.rs index 6ca92c4e1..05d59c07f 100644 --- a/kernel/src/net/socket/netlink/common/unbound.rs +++ b/kernel/src/net/socket/netlink/common/unbound.rs @@ -91,4 +91,8 @@ impl datagram_common::Unbound for UnboundNetlink

fn check_io_events(&self) -> EPollEventType { EPollEventType::EPOLLOUT } + + fn local_endpoint(&self) -> Option { + Some(self.addr()) + } } diff --git a/kernel/src/net/socket/netlink/message/attr/mod.rs b/kernel/src/net/socket/netlink/message/attr/mod.rs index a7ea25ad4..8e21ef19b 100644 --- a/kernel/src/net/socket/netlink/message/attr/mod.rs +++ b/kernel/src/net/socket/netlink/message/attr/mod.rs @@ -1,6 +1,6 @@ pub(super) mod noattr; -use crate::net::socket::netlink::message::NLMSG_ALIGN; +use crate::{libs::align::align_up, net::socket::netlink::message::NLMSG_ALIGN}; use alloc::vec::Vec; use system_error::SystemError; @@ -40,7 +40,7 @@ impl CAttrHeader { } pub fn total_len_with_padding(&self) -> usize { - (self.len as usize).checked_add(NLMSG_ALIGN - 1).unwrap() & !(NLMSG_ALIGN - 1) + align_up(self.len as usize, NLMSG_ALIGN) } pub fn padding_len(&self) -> usize { @@ -154,8 +154,3 @@ pub trait Attribute: core::fmt::Debug + Send + Sync { Ok(attrs) } } - -// 辅助函数 -fn align_to(value: usize, align: usize) -> usize { - (value + align - 1) & !(align - 1) -} diff --git a/kernel/src/net/socket/netlink/route/message/segment/mod.rs b/kernel/src/net/socket/netlink/route/message/segment/mod.rs index 9c5f64e7c..ea40e2cf0 100644 --- a/kernel/src/net/socket/netlink/route/message/segment/mod.rs +++ b/kernel/src/net/socket/netlink/route/message/segment/mod.rs @@ -12,7 +12,6 @@ use crate::net::socket::netlink::{ }, route::message::segment::{addr::AddrSegment, route::RouteSegment}, }; -use alloc::vec::Vec; use system_error::SystemError; #[derive(Debug)] diff --git a/kernel/src/net/socket/utils/datagram_common.rs b/kernel/src/net/socket/utils/datagram_common.rs index 852af4aa2..ab55dda23 100644 --- a/kernel/src/net/socket/utils/datagram_common.rs +++ b/kernel/src/net/socket/utils/datagram_common.rs @@ -30,6 +30,8 @@ pub trait Unbound { ) -> Result; fn check_io_events(&self) -> EPollEventType; + + fn local_endpoint(&self) -> Option; } pub trait Bound { @@ -133,7 +135,7 @@ where pub 
fn addr(&self) -> Option { match self { - Inner::Unbound(_) => None, + Inner::Unbound(unbound) => unbound.local_endpoint(), Inner::Bound(bound) => Some(bound.local_endpoint()), } } From 538db913e2a21500d700385b765fa7243185fba7 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Thu, 21 Aug 2025 16:08:34 +0800 Subject: [PATCH 23/36] =?UTF-8?q?fix:=20=E6=96=B0=E5=BB=BAnetns=E6=97=B6?= =?UTF-8?q?=E6=8F=92=E5=85=A5loopback=E7=BD=91=E5=8D=A1=E5=88=B0=E8=AE=BE?= =?UTF-8?q?=E5=A4=87=E5=88=97=E8=A1=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/process/namespace/net_namespace.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kernel/src/process/namespace/net_namespace.rs b/kernel/src/process/namespace/net_namespace.rs index 497a28204..db903c822 100644 --- a/kernel/src/process/namespace/net_namespace.rs +++ b/kernel/src/process/namespace/net_namespace.rs @@ -112,7 +112,7 @@ impl NetNamespace { let inner = InnerNetNamespace { router: Arc::new(Router::new(format!("netns_router_{}", pid))), - loopback_iface: Some(loopback), + loopback_iface: Some(loopback.clone()), default_iface: None, }; @@ -127,6 +127,7 @@ impl NetNamespace { netlink_kernel_socket: RwLock::new(generate_supported_netlink_kernel_sockets()), }); Self::create_polling_thread(netns.clone(), format!("netns_{}", pid)); + netns.add_device(loopback); Ok(netns) } From 9f5813c7001c8ab08e1f6d04943289e7e59d037b Mon Sep 17 00:00:00 2001 From: sparkzky Date: Fri, 22 Aug 2025 11:00:03 +0800 Subject: [PATCH 24/36] =?UTF-8?q?feat:=20=E5=B0=86veth=E5=92=8Cbridge?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=A8=8B=E5=BA=8F=E6=94=B9=E7=94=A8C?= =?UTF-8?q?=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- user/apps/c_unitest/test_veth_bridge.c | 158 ++++++++++ user/apps/test-veth-bridge/.gitignore | 3 - user/apps/test-veth-bridge/Cargo.toml | 11 - 
user/apps/test-veth-bridge/Makefile | 56 ---- user/apps/test-veth-bridge/src/main.rs | 307 ------------------- user/dadk/config/test_veth_bridge_0_1_0.toml | 36 --- 6 files changed, 158 insertions(+), 413 deletions(-) create mode 100644 user/apps/c_unitest/test_veth_bridge.c delete mode 100644 user/apps/test-veth-bridge/.gitignore delete mode 100644 user/apps/test-veth-bridge/Cargo.toml delete mode 100644 user/apps/test-veth-bridge/Makefile delete mode 100644 user/apps/test-veth-bridge/src/main.rs delete mode 100644 user/dadk/config/test_veth_bridge_0_1_0.toml diff --git a/user/apps/c_unitest/test_veth_bridge.c b/user/apps/c_unitest/test_veth_bridge.c new file mode 100644 index 000000000..6ceed98d5 --- /dev/null +++ b/user/apps/c_unitest/test_veth_bridge.c @@ -0,0 +1,158 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define SERVER_IP "200.0.0.4" +#define CLIENT_IP "200.0.0.1" +#define PORT 34254 +#define BUFFER_SIZE 1024 + +// 错误处理函数 +void die(const char *message) { + perror(message); + exit(EXIT_FAILURE); +} + +// 服务器线程函数 +void *server_func(void *arg) { + int sockfd; + struct sockaddr_in server_addr, client_addr; + char buffer[BUFFER_SIZE]; + socklen_t client_len = sizeof(client_addr); + + // 1. 创建 UDP socket + if ((sockfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) { + die("[server] Failed to create socket"); + } + + // 2. 准备服务器地址结构 + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + if (inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr) <= 0) { + die("[server] Invalid server IP address"); + } + + // 3. 绑定 socket 到指定地址和端口 + if (bind(sockfd, (const struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { + die("[server] Failed to bind to " SERVER_IP); + } + printf("[server] Listening on %s:%d\n", SERVER_IP, PORT); + + // 4. 
接收数据 + ssize_t n = recvfrom(sockfd, buffer, BUFFER_SIZE, 0, (struct sockaddr *)&client_addr, &client_len); + if (n < 0) { + die("[server] Failed to receive"); + } + buffer[n] = '\0'; // 确保字符串正确终止 + + char client_ip_str[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &client_addr.sin_addr, client_ip_str, INET_ADDRSTRLEN); + printf("[server] Received from %s:%d: %s\n", client_ip_str, ntohs(client_addr.sin_port), buffer); + + // 5. 将数据回显给客户端 + if (sendto(sockfd, buffer, n, 0, (const struct sockaddr *)&client_addr, client_len) < 0) { + die("[server] Failed to send back"); + } + printf("[server] Echoed back the message\n"); + + close(sockfd); + printf("server goning to exit\n"); + return NULL; +} + +// 客户端线程函数 +void *client_func(void *arg) { + int sockfd; + struct sockaddr_in client_addr, server_addr; + char buffer[BUFFER_SIZE]; + const char *msg = "Hello from veth1!"; + + // 1. 创建 UDP socket + if ((sockfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) { + die("[client] Failed to create socket"); + } + + // 2. 准备客户端地址结构(用于绑定) + memset(&client_addr, 0, sizeof(client_addr)); + client_addr.sin_family = AF_INET; + client_addr.sin_port = htons(0); // 端口为0,由操作系统自动选择 + if (inet_pton(AF_INET, CLIENT_IP, &client_addr.sin_addr) <= 0) { + die("[client] Invalid client IP address"); + } + + // 3. 绑定 socket 到客户端地址(可选,但为了匹配 Rust 代码的行为,我们这样做) + if (bind(sockfd, (const struct sockaddr *)&client_addr, sizeof(client_addr)) < 0) { + die("[client] Failed to bind to " CLIENT_IP); + } + + // 4. 准备服务器地址结构(用于连接) + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + if (inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr) <= 0) { + die("[client] Invalid server IP address for connect"); + } + + // 5. 连接到服务器(这会使 UDP socket 记住目标地址) + if (connect(sockfd, (const struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { + die("[client] Failed to connect"); + } + + // 6. 
发送消息(因为已连接,可以使用 send 而不是 sendto) + if (send(sockfd, msg, strlen(msg), 0) < 0) { + die("[client] Failed to send"); + } + printf("[client] Sent: %s\n", msg); + + // 7. 接收回显(因为已连接,可以使用 recv 而不是 recvfrom) + ssize_t n = recv(sockfd, buffer, BUFFER_SIZE, 0); + if (n < 0) { + die("[client] Failed to receive"); + } + buffer[n] = '\0'; // 确保字符串正确终止 + + printf("[client] Received echo: %s\n", buffer); + + // 8. 验证消息是否匹配 + assert(strcmp(msg, buffer) == 0 && "[client] Mismatch in echo!"); + + close(sockfd); + printf("client goning to exit\n"); + return NULL; +} + +int main() { + pthread_t server_tid, client_tid; + + // 启动 server 线程 + if (pthread_create(&server_tid, NULL, server_func, NULL) != 0) { + die("Failed to create server thread"); + } + + // 确保 server 已启动 + usleep(200 * 1000); // 200 milliseconds + + // 启动 client 线程 + if (pthread_create(&client_tid, NULL, client_func, NULL) != 0) { + die("Failed to create client thread"); + } + + // 等待两个线程结束 + if (pthread_join(server_tid, NULL) != 0) { + die("Failed to join server thread"); + } + if (pthread_join(client_tid, NULL) != 0) { + die("Failed to join client thread"); + } + + printf("\n✅ Test completed: veth0 <--> veth1 UDP communication success\n"); + + return EXIT_SUCCESS; +} \ No newline at end of file diff --git a/user/apps/test-veth-bridge/.gitignore b/user/apps/test-veth-bridge/.gitignore deleted file mode 100644 index 1ac354611..000000000 --- a/user/apps/test-veth-bridge/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -/target -Cargo.lock -/install/ \ No newline at end of file diff --git a/user/apps/test-veth-bridge/Cargo.toml b/user/apps/test-veth-bridge/Cargo.toml deleted file mode 100644 index 77dbe7a0c..000000000 --- a/user/apps/test-veth-bridge/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[package] -name = "test-veth-bridge" -version = "0.1.0" -edition = "2021" -description = "测试veth pair" -authors = [ "sparkzky " ] - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - 
-[dependencies] -smoltcp = { git = "https://git.mirrors.dragonos.org.cn/DragonOS-Community/smoltcp.git", rev = "3e61c909fd540d05575068d16dc4574e196499ed", default-features = false, features = ["log", "alloc", "socket-raw", "socket-udp", "socket-tcp", "socket-icmp", "socket-dhcpv4", "socket-dns", "proto-ipv4", "proto-ipv6", "medium-ip"]} diff --git a/user/apps/test-veth-bridge/Makefile b/user/apps/test-veth-bridge/Makefile deleted file mode 100644 index 7522ea16c..000000000 --- a/user/apps/test-veth-bridge/Makefile +++ /dev/null @@ -1,56 +0,0 @@ -TOOLCHAIN= -RUSTFLAGS= - -ifdef DADK_CURRENT_BUILD_DIR -# 如果是在dadk中编译,那么安装到dadk的安装目录中 - INSTALL_DIR = $(DADK_CURRENT_BUILD_DIR) -else -# 如果是在本地编译,那么安装到当前目录下的install目录中 - INSTALL_DIR = ./install -endif - -ifeq ($(ARCH), x86_64) - export RUST_TARGET=x86_64-unknown-linux-musl -else ifeq ($(ARCH), riscv64) - export RUST_TARGET=riscv64gc-unknown-linux-gnu -else -# 默认为x86_86,用于本地编译 - export RUST_TARGET=x86_64-unknown-linux-musl -endif - -run: - RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) - -build: - RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) - -clean: - RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) - -test: - RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) - -doc: - RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) doc --target $(RUST_TARGET) - -fmt: - RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt - -fmt-check: - RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) fmt --check - -run-release: - RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) run --target $(RUST_TARGET) --release - -build-release: - RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) build --target $(RUST_TARGET) --release - -clean-release: - RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) --release - -test-release: - RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) test --target $(RUST_TARGET) --release - -.PHONY: install -install: - RUSTFLAGS=$(RUSTFLAGS) cargo $(TOOLCHAIN) 
install --target $(RUST_TARGET) --path . --no-track --root $(INSTALL_DIR) --force diff --git a/user/apps/test-veth-bridge/src/main.rs b/user/apps/test-veth-bridge/src/main.rs deleted file mode 100644 index 5a2d9e629..000000000 --- a/user/apps/test-veth-bridge/src/main.rs +++ /dev/null @@ -1,307 +0,0 @@ -// // src/main.rs -// use smoltcp::phy::{Device, DeviceCapabilities, RxToken, TxToken}; -// use smoltcp::time::Instant; -// use std::collections::VecDeque; -// use std::sync::{Arc, Mutex}; - -// // 模拟 veth pair 中的一个端点 -// pub struct VethInner { -// queue: VecDeque>, -// peer: Option>>, -// } - -// impl VethInner { -// pub fn new() -> Self { -// Self { -// queue: VecDeque::new(), -// peer: None, -// } -// } - -// pub fn set_peer(&mut self, peer: Arc>) { -// self.peer = Some(peer); -// } - -// pub fn send_to_peer(&self, buf: Vec) { -// if let Some(peer) = &self.peer { -// peer.lock().unwrap().queue.push_back(buf); -// } -// } - -// pub fn recv(&mut self) -> Option> { -// self.queue.pop_front() -// } -// } - -// #[derive(Clone)] -// pub struct VethDriver { -// inner: Arc>, -// } - -// impl VethDriver { -// pub fn new_pair() -> (Self, Self) { -// let a = Arc::new(Mutex::new(VethInner::new())); -// let b = Arc::new(Mutex::new(VethInner::new())); -// a.lock().unwrap().set_peer(b.clone()); -// b.lock().unwrap().set_peer(a.clone()); -// (Self { inner: a }, Self { inner: b }) -// } -// } - -// pub struct VethTxToken { -// driver: VethDriver, -// } - -// impl TxToken for VethTxToken { -// fn consume(self, len: usize, f: F) -> R -// where -// F: FnOnce(&mut [u8]) -> R, -// { -// let mut buffer = vec![0u8; len]; -// let result = f(&mut buffer); -// self.driver.inner.lock().unwrap().send_to_peer(buffer); -// result -// } -// } - -// pub struct VethRxToken { -// buffer: Vec, -// } - -// impl RxToken for VethRxToken { -// fn consume(self, f: F) -> R -// where -// F: FnOnce(&[u8]) -> R, -// { -// f(&self.buffer) -// } -// } - -// impl Device for VethDriver { -// type RxToken<'a> = 
VethRxToken; -// type TxToken<'a> = VethTxToken; - -// fn capabilities(&self) -> DeviceCapabilities { -// let mut caps = DeviceCapabilities::default(); -// caps.max_transmission_unit = 1500; -// caps.medium = smoltcp::phy::Medium::Ethernet; -// caps -// } - -// fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { -// let mut inner = self.inner.lock().unwrap(); -// if let Some(buf) = inner.recv() { -// Some(( -// VethRxToken { buffer: buf }, -// VethTxToken { -// driver: self.clone(), -// }, -// )) -// } else { -// None -// } -// } - -// fn transmit(&mut self, _timestamp: Instant) -> Option> { -// Some(VethTxToken { -// driver: self.clone(), -// }) -// } -// } - -// fn main() { -// let (mut veth0, veth1) = VethDriver::new_pair(); -// let (veth3, mut veth4) = VethDriver::new_pair(); - -// let mut bridge = BridgeDevice::new(); -// bridge.add_port(veth1.clone()); -// bridge.add_port(veth3.clone()); - -// // veth0 → bridge → veth1 & veth3(→ veth4) -// println!("--- veth0 → bridge (→ veth1, veth3) ---"); -// if let Some(tx) = veth0.transmit(Instant::from_millis(0)) { -// tx.consume(32, |buf| { -// buf[..6].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); // dst MAC -// buf[6..12].copy_from_slice(&[0x11, 0x22, 0x33, 0x44, 0x55, 0x66]); // src MAC -// buf[12..14].copy_from_slice(&(0x0800u16.to_be_bytes())); // Ethertype -// buf[14..].copy_from_slice(b"hello bridge world"); // payload - -// bridge.handle_frame(&veth0, &buf); -// }); -// } - -// if let Some((rx, _tx)) = veth1.clone().receive(Instant::from_millis(0)) { -// rx.consume(|buf| { -// println!("veth1 received: {:02x?}", buf); -// }); -// } else { -// println!("veth1 received nothing"); -// } - -// if let Some((rx, _tx)) = veth4.receive(Instant::from_millis(0)) { -// rx.consume(|buf| { -// println!("veth4 received: {:02x?}", buf); -// }); -// } else { -// println!("veth4 received nothing"); -// } -// } - -// // 网桥设备:只做广播转发(无 MAC 学习) -// pub struct BridgeDevice { -// 
pub ports: Vec, -// } - -// impl BridgeDevice { -// pub fn new() -> Self { -// BridgeDevice { ports: Vec::new() } -// } - -// pub fn add_port(&mut self, port: VethDriver) { -// self.ports.push(port); -// } - -// pub fn remove_port(&mut self, port: &VethDriver) { -// self.ports.retain(|p| !Arc::ptr_eq(&p.inner, &port.inner)); -// } - -// pub fn handle_frame(&mut self, src_if: &VethDriver, frame: &[u8]) { -// for port in &self.ports { -// if !Arc::ptr_eq(&port.inner, &src_if.inner) { -// port.inner.lock().unwrap().send_to_peer(frame.to_vec()); -// } -// } -// } -// } - -// use std::net::UdpSocket; -// use std::str; -// use std::thread; -// use std::time::Duration; - -// fn main() -> std::io::Result<()> { -// // 启动 server 线程 -// let server_thread = thread::spawn(|| { -// let socket = -// UdpSocket::bind("10.0.0.2:34254").expect("Failed to bind to veth1 (10.0.0.2:34254)"); -// println!("[server] Listening on 10.0.0.2:34254"); - -// let mut buf = [0; 1024]; -// let (amt, src) = socket -// .recv_from(&mut buf) -// .expect("[server] Failed to receive"); - -// let received_msg = str::from_utf8(&buf[..amt]).expect("Invalid UTF-8"); - -// println!("[server] Received from {}: {}", src, received_msg); - -// socket -// .send_to(received_msg.as_bytes(), src) -// .expect("[server] Failed to send back"); -// println!("[server] Echoed back the message"); -// }); - -// // 确保 server 已启动(可根据情况适当 sleep) -// thread::sleep(Duration::from_millis(200)); - -// // 启动 client -// let client_thread = thread::spawn(|| { -// let socket = UdpSocket::bind("10.0.0.1:0").expect("Failed to bind to veth0 (10.0.0.1)"); -// socket -// .connect("10.0.0.2:34254") -// .expect("Failed to connect to 10.0.0.2:34254"); - -// let msg = "Hello from veth0!"; -// socket -// .send(msg.as_bytes()) -// .expect("[client] Failed to send"); - -// println!("[client] Sent: {}", msg); - -// let mut buf = [0; 1024]; -// let (amt, _src) = socket -// .recv_from(&mut buf) -// .expect("[client] Failed to receive"); - -// let 
received_msg = str::from_utf8(&buf[..amt]).expect("Invalid UTF-8"); - -// println!("[client] Received echo: {}", received_msg); - -// assert_eq!(msg, received_msg, "[client] Mismatch in echo!"); -// }); - -// // 等待两个线程结束 -// server_thread.join().unwrap(); -// client_thread.join().unwrap(); - -// println!("\n✅ Test completed: veth0 <--> veth1 UDP communication success"); - -// Ok(()) -// } - -//bridge - -use std::net::UdpSocket; -use std::str; -use std::thread; -use std::time::Duration; - -fn main() -> std::io::Result<()> { - // 启动 server 线程 - let server_thread = thread::spawn(|| { - let socket = - UdpSocket::bind("200.0.0.4:34254").expect("Failed to bind to veth_d (200.0.0.4:34254)"); - println!("[server] Listening on 200.0.0.4:34254"); - - let mut buf = [0; 1024]; - let (amt, src) = socket - .recv_from(&mut buf) - .expect("[server] Failed to receive"); - - let received_msg = str::from_utf8(&buf[..amt]).expect("Invalid UTF-8"); - - println!("[server] Received from {}: {}", src, received_msg); - - socket - .send_to(received_msg.as_bytes(), src) - .expect("[server] Failed to send back"); - println!("[server] Echoed back the message"); - }); - - // 确保 server 已启动(可根据情况适当 sleep) - thread::sleep(Duration::from_millis(200)); - - // 启动 client - let client_thread = thread::spawn(|| { - let socket = UdpSocket::bind("200.0.0.1:0").expect("Failed to bind to veth_a (200.0.0.1)"); - socket - .connect("200.0.0.4:34254") - .expect("Failed to connect to 200.0.0.4:34254"); - - let msg = "Hello from veth1!"; - socket - .send(msg.as_bytes()) - .expect("[client] Failed to send"); - - println!("[client] Sent: {}", msg); - - let mut buf = [0; 1024]; - let (amt, _src) = socket - .recv_from(&mut buf) - .expect("[client] Failed to receive"); - - let received_msg = str::from_utf8(&buf[..amt]).expect("Invalid UTF-8"); - - println!("[client] Received echo: {}", received_msg); - - assert_eq!(msg, received_msg, "[client] Mismatch in echo!"); - }); - - // 等待两个线程结束 - server_thread.join().unwrap(); 
- client_thread.join().unwrap(); - - println!("\n✅ Test completed: veth0 <--> veth1 UDP communication success"); - - Ok(()) -} diff --git a/user/dadk/config/test_veth_bridge_0_1_0.toml b/user/dadk/config/test_veth_bridge_0_1_0.toml deleted file mode 100644 index fc8d58738..000000000 --- a/user/dadk/config/test_veth_bridge_0_1_0.toml +++ /dev/null @@ -1,36 +0,0 @@ -# 用户程序名称 -name = "test-veth-bridge" -# 版本号 -version = "0.1.0" -# 用户程序描述信息 -description = "test for veth and bridge" -# (可选)默认: false 是否只构建一次,如果为true,DADK会在构建成功后,将构建结果缓存起来,下次构建时,直接使用缓存的构建结果 -build-once = false -# (可选) 默认: false 是否只安装一次,如果为true,DADK会在安装成功后,不再重复安装 -install-once = false -# 目标架构 -# 可选值:"x86_64", "aarch64", "riscv64" -target-arch = ["x86_64"] -# 任务源 -[task-source] -# 构建类型 -# 可选值:"build-from_source", "install-from-prebuilt" -type = "build-from-source" -# 构建来源 -# "build_from_source" 可选值:"git", "local", "archive" -# "install_from_prebuilt" 可选值:"local", "archive" -source = "local" -# 路径或URL -source-path = "user/apps/test-veth-bridge" -# 构建相关信息 -[build] -# (可选)构建命令 -build-command = "make install" -# 安装相关信息 -[install] -# (可选)安装到DragonOS的路径 -in-dragonos-path = "/" -# 清除相关信息 -[clean] -# (可选)清除命令 -clean-command = "make clean" From de88f0c34b937a559626f906aeb04a4bc9c60b51 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Fri, 22 Aug 2025 11:00:43 +0800 Subject: [PATCH 25/36] =?UTF-8?q?feat(gdb):=20=E5=A2=9E=E5=8A=A0gdb=20debu?= =?UTF-8?q?g=E5=8F=AF=E9=80=89=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- .vscode/launch.json | 25 +++++++++++++++++++++++++ kernel/Cargo.toml | 1 + tools/run-qemu.sh | 1 + 3 files changed, 27 insertions(+) create mode 100644 .vscode/launch.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..4aaa20b43 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,25 @@ +{ + // 使用 IntelliSense 了解相关属性。 + // 悬停以查看现有属性的描述。 + // 欲了解更多信息,请访问: 
https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug kernel.elf", + "stopOnEntry": false, + "targetCreateCommands": ["target create ${workspaceFolder}/bin/kernel/kernel.elf"], + "program": "${workspaceFolder}/bin/kernel/kernel.elf", + "processCreateCommands": [ + "gdb-remote 127.0.0.1:1234", + "settings set target.process.follow-fork-mode child", + "continue", + ], + "args": [], + "cwd": "${workspaceFolder}", + "sourceLanguages": ["rust"], + "console": "internalConsole" + } + ] +} \ No newline at end of file diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index 30b1d6c05..32d1ef06b 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -135,3 +135,4 @@ debug = true # Controls whether the compiler passes `-g` # The release profile, used for `cargo build --release` [profile.release] debug = false +# debug = true diff --git a/tools/run-qemu.sh b/tools/run-qemu.sh index 81cf1c03b..d13864220 100755 --- a/tools/run-qemu.sh +++ b/tools/run-qemu.sh @@ -99,6 +99,7 @@ QEMU_DRIVE="id=disk,file=${QEMU_DISK_IMAGE},if=none" QEMU_ACCELARATE="" QEMU_ARGUMENT=" -no-reboot " QEMU_DEVICES="" +# QEMU_ARGUMENT+=" -S " if [ -f "${QEMU_EXT4_DISK_IMAGE}" ]; then QEMU_DRIVE+=" -drive id=ext4disk,file=${QEMU_EXT4_DISK_IMAGE},if=none,format=raw" From a723501b34e3fc01413c3083b8504b15f6a9f707 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Fri, 22 Aug 2025 12:01:45 +0800 Subject: [PATCH 26/36] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8DSockAddrIn?= =?UTF-8?q?=E7=BB=93=E6=9E=84=E4=BD=93=E4=B8=AD=E7=9A=84sin=5Faddr?= =?UTF-8?q?=E5=AD=97=E8=8A=82=E5=BA=8F=E9=97=AE=E9=A2=98=EF=BC=8C=E7=A1=AE?= =?UTF-8?q?=E4=BF=9D=E6=AD=A3=E7=A1=AE=E5=A4=84=E7=90=86IPv4=E5=9C=B0?= =?UTF-8?q?=E5=9D=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- user/apps/c_unitest/test_veth_bridge.c | 106 ++++++++++++++----------- 1 file changed, 58 insertions(+), 
48 deletions(-) diff --git a/user/apps/c_unitest/test_veth_bridge.c b/user/apps/c_unitest/test_veth_bridge.c index 6ceed98d5..fa662376a 100644 --- a/user/apps/c_unitest/test_veth_bridge.c +++ b/user/apps/c_unitest/test_veth_bridge.c @@ -1,12 +1,12 @@ +#include +#include +#include +#include #include #include #include -#include -#include #include -#include -#include -#include +#include #define SERVER_IP "200.0.0.4" #define CLIENT_IP "200.0.0.1" @@ -14,7 +14,7 @@ #define BUFFER_SIZE 1024 // 错误处理函数 -void die(const char *message) { +void handle_error_message(const char *message) { perror(message); exit(EXIT_FAILURE); } @@ -26,39 +26,57 @@ void *server_func(void *arg) { char buffer[BUFFER_SIZE]; socklen_t client_len = sizeof(client_addr); - // 1. 创建 UDP socket if ((sockfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) { - die("[server] Failed to create socket"); + handle_error_message("[server] Failed to create socket"); } - // 2. 准备服务器地址结构 memset(&server_addr, 0, sizeof(server_addr)); server_addr.sin_family = AF_INET; server_addr.sin_port = htons(PORT); if (inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr) <= 0) { - die("[server] Invalid server IP address"); + handle_error_message("[server] Invalid server IP address"); } - // 3. 绑定 socket 到指定地址和端口 - if (bind(sockfd, (const struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { - die("[server] Failed to bind to " SERVER_IP); + if (bind(sockfd, + (const struct sockaddr *)&server_addr, + sizeof(server_addr)) < 0) { + handle_error_message("[server] Failed to bind to " SERVER_IP); } printf("[server] Listening on %s:%d\n", SERVER_IP, PORT); - // 4. 
接收数据 - ssize_t n = recvfrom(sockfd, buffer, BUFFER_SIZE, 0, (struct sockaddr *)&client_addr, &client_len); + ssize_t n = recvfrom(sockfd, + buffer, + BUFFER_SIZE, + 0, + (struct sockaddr *)&client_addr, + &client_len); if (n < 0) { - die("[server] Failed to receive"); + handle_error_message("[server] Failed to receive"); } buffer[n] = '\0'; // 确保字符串正确终止 + // //debug + // unsigned char *ip_bytes = (unsigned char *)&client_addr.sin_addr.s_addr; + // printf("[DEBUG] Raw IP bytes received: %d.%d.%d.%d\n", + // ip_bytes[0], + // ip_bytes[1], + // ip_bytes[2], + // ip_bytes[3]); + char client_ip_str[INET_ADDRSTRLEN]; inet_ntop(AF_INET, &client_addr.sin_addr, client_ip_str, INET_ADDRSTRLEN); - printf("[server] Received from %s:%d: %s\n", client_ip_str, ntohs(client_addr.sin_port), buffer); - - // 5. 将数据回显给客户端 - if (sendto(sockfd, buffer, n, 0, (const struct sockaddr *)&client_addr, client_len) < 0) { - die("[server] Failed to send back"); + printf("[server] Received from %s:%d: %s\n", + client_ip_str, + ntohs(client_addr.sin_port), + buffer); + + if (sendto(sockfd, + buffer, + n, + 0, + (const struct sockaddr *)&client_addr, + client_len) < 0) { + handle_error_message("[server] Failed to send back"); } printf("[server] Echoed back the message\n"); @@ -74,53 +92,49 @@ void *client_func(void *arg) { char buffer[BUFFER_SIZE]; const char *msg = "Hello from veth1!"; - // 1. 创建 UDP socket if ((sockfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) { - die("[client] Failed to create socket"); + handle_error_message("[client] Failed to create socket"); } - // 2. 准备客户端地址结构(用于绑定) memset(&client_addr, 0, sizeof(client_addr)); client_addr.sin_family = AF_INET; client_addr.sin_port = htons(0); // 端口为0,由操作系统自动选择 if (inet_pton(AF_INET, CLIENT_IP, &client_addr.sin_addr) <= 0) { - die("[client] Invalid client IP address"); + handle_error_message("[client] Invalid client IP address"); } - // 3. 
绑定 socket 到客户端地址(可选,但为了匹配 Rust 代码的行为,我们这样做) - if (bind(sockfd, (const struct sockaddr *)&client_addr, sizeof(client_addr)) < 0) { - die("[client] Failed to bind to " CLIENT_IP); + if (bind(sockfd, + (const struct sockaddr *)&client_addr, + sizeof(client_addr)) < 0) { + handle_error_message("[client] Failed to bind to " CLIENT_IP); } - - // 4. 准备服务器地址结构(用于连接) + memset(&server_addr, 0, sizeof(server_addr)); server_addr.sin_family = AF_INET; server_addr.sin_port = htons(PORT); if (inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr) <= 0) { - die("[client] Invalid server IP address for connect"); + handle_error_message("[client] Invalid server IP address for connect"); } - // 5. 连接到服务器(这会使 UDP socket 记住目标地址) - if (connect(sockfd, (const struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) { - die("[client] Failed to connect"); + if (connect(sockfd, + (const struct sockaddr *)&server_addr, + sizeof(server_addr)) < 0) { + handle_error_message("[client] Failed to connect"); } - // 6. 发送消息(因为已连接,可以使用 send 而不是 sendto) if (send(sockfd, msg, strlen(msg), 0) < 0) { - die("[client] Failed to send"); + handle_error_message("[client] Failed to send"); } printf("[client] Sent: %s\n", msg); - // 7. 接收回显(因为已连接,可以使用 recv 而不是 recvfrom) ssize_t n = recv(sockfd, buffer, BUFFER_SIZE, 0); if (n < 0) { - die("[client] Failed to receive"); + handle_error_message("[client] Failed to receive"); } buffer[n] = '\0'; // 确保字符串正确终止 printf("[client] Received echo: %s\n", buffer); - // 8. 
验证消息是否匹配 assert(strcmp(msg, buffer) == 0 && "[client] Mismatch in echo!"); close(sockfd); @@ -131,28 +145,24 @@ void *client_func(void *arg) { int main() { pthread_t server_tid, client_tid; - // 启动 server 线程 if (pthread_create(&server_tid, NULL, server_func, NULL) != 0) { - die("Failed to create server thread"); + handle_error_message("Failed to create server thread"); } - // 确保 server 已启动 usleep(200 * 1000); // 200 milliseconds - // 启动 client 线程 if (pthread_create(&client_tid, NULL, client_func, NULL) != 0) { - die("Failed to create client thread"); + handle_error_message("Failed to create client thread"); } - // 等待两个线程结束 if (pthread_join(server_tid, NULL) != 0) { - die("Failed to join server thread"); + handle_error_message("Failed to join server thread"); } if (pthread_join(client_tid, NULL) != 0) { - die("Failed to join client thread"); + handle_error_message("Failed to join client thread"); } - printf("\n✅ Test completed: veth0 <--> veth1 UDP communication success\n"); + printf("\nTest completed: veth_a <--> veth_d UDP communication success\n"); return EXIT_SUCCESS; } \ No newline at end of file From 97fc6b9c65fbfd112a20931d969efb007859bd30 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Tue, 26 Aug 2025 16:37:37 +0800 Subject: [PATCH 27/36] =?UTF-8?q?feat:=20=E6=89=8B=E7=B3=8A=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=E8=B7=AF=E7=94=B1=E5=8A=9F=E8=83=BD,=E5=90=8E?= =?UTF-8?q?=E7=BB=AD=E9=9C=80=E8=A6=81=E6=9B=B4=E6=94=B9=E4=BA=8B=E4=BB=B6?= =?UTF-8?q?=E9=A9=B1=E5=8A=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/bridge.rs | 51 ++-- kernel/src/driver/net/e1000e/e1000e_driver.rs | 2 +- kernel/src/driver/net/loopback.rs | 2 +- kernel/src/driver/net/mod.rs | 24 +- kernel/src/driver/net/veth.rs | 221 ++++++++++++++---- kernel/src/driver/net/virtio_net.rs | 2 +- kernel/src/net/routing.rs | 175 ++++++++++---- kernel/src/process/namespace/net_namespace.rs | 43 +++- 
user/apps/c_unitest/test_router.c | 169 ++++++++++++++ user/apps/c_unitest/test_veth_bridge.c | 1 + 10 files changed, 573 insertions(+), 117 deletions(-) create mode 100644 user/apps/c_unitest/test_router.c diff --git a/kernel/src/driver/net/bridge.rs b/kernel/src/driver/net/bridge.rs index 10707e191..f3fdef585 100644 --- a/kernel/src/driver/net/bridge.rs +++ b/kernel/src/driver/net/bridge.rs @@ -58,7 +58,7 @@ struct MacEntryRecord { pub struct BridgePort { pub id: BridgePortId, pub(super) bridge_enable: Arc, - pub(super) bridge_driver: Weak, + pub(super) bridge_iface: Weak, // 当前接口状态?forwarding, learning, blocking? // mac mtu信息 } @@ -72,7 +72,7 @@ impl BridgePort { BridgePort { id, bridge_enable: device, - bridge_driver: Arc::downgrade(bridge), + bridge_iface: Arc::downgrade(bridge), } } @@ -274,20 +274,17 @@ impl BridgeDriver { use crate::sched::SchedMode; loop { - let mut inner = self.inner.lock_irqsave(); - - let opt = inner.rx_buf.pop_front(); - if let Some((port_id, frame)) = opt { + let mut inner = self.inner.lock(); + while let Some((port_id, frame)) = inner.rx_buf.pop_front() { inner.handle_frame(port_id, &frame); - } else { - drop(inner); - log::info!("Bridge is going to sleep"); - let _ = wq_wait_event_interruptible!( - self.wait_queue, - !self.inner.lock().rx_buf.is_empty(), - {} - ); } + drop(inner); + // log::info!("Bridge is going to sleep"); + let _ = wq_wait_event_interruptible!( + self.wait_queue, + !self.inner.lock().rx_buf.is_empty(), + {} + ); } // inner.poll_blocking(); } @@ -355,7 +352,27 @@ impl BridgeIface { pub trait BridgeEnableDevice: Iface { fn receive_from_bridge(&self, frame: &[u8]); // fn inner_driver(&self) -> Arc; - fn set_common_bridge_data(&self, _port: BridgePort) {} + fn set_common_bridge_data(&self, _port: BridgePort); + + // fn common_bridge_data(&self) -> Option; + // fn port_id(&self) -> Option { + // let Some(data) = self.common_bridge_data() else { + // return None; + // }; + // Some(data.id) + // } + // fn bridge(&self) 
-> Weak { + // let Some(data) = self.common_bridge_data() else { + // return Weak::default(); + // }; + // data.bridge_driver + // } +} + +#[derive(Debug, Clone)] +pub struct BridgeCommonData { + pub id: BridgePortId, + pub bridge_iface: Weak, } fn bridge_probe() { @@ -403,15 +420,15 @@ fn bridge_probe() { let iface = BridgeIface::new(bridge); // BRIDGE_DEVICES.write_irqsave().push(bridge.clone()); - log::info!("Bridge device created"); iface.add_port(iface3); iface.add_port(iface2); + log::info!("Bridge device created"); } #[unified_init(INITCALL_DEVICE)] pub fn bridge_init() -> Result<(), SystemError> { bridge_probe(); - log::info!("bridge initialized."); + // log::info!("bridge initialized."); Ok(()) } diff --git a/kernel/src/driver/net/e1000e/e1000e_driver.rs b/kernel/src/driver/net/e1000e/e1000e_driver.rs index 08211d5d4..01d258e88 100644 --- a/kernel/src/driver/net/e1000e/e1000e_driver.rs +++ b/kernel/src/driver/net/e1000e/e1000e_driver.rs @@ -305,7 +305,7 @@ impl Iface for E1000EInterface { return self.name.clone(); } - fn poll(&self) { + fn poll(&self) -> bool { self.common.poll(self.driver.force_get_mut()) } diff --git a/kernel/src/driver/net/loopback.rs b/kernel/src/driver/net/loopback.rs index aafb9947e..e5f7db59f 100644 --- a/kernel/src/driver/net/loopback.rs +++ b/kernel/src/driver/net/loopback.rs @@ -481,7 +481,7 @@ impl Iface for LoopbackInterface { smoltcp::wire::EthernetAddress(mac) } - fn poll(&self) { + fn poll(&self) -> bool { self.common.poll(self.driver.force_get_mut()) } diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index 2a9abc6e6..906f608cf 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -4,6 +4,8 @@ use alloc::{string::String, sync::Arc}; use core::net::Ipv4Addr; use sysfs::netdev_register_kobject; +use crate::libs::rwlock::RwLockReadGuard; +use crate::net::routing::RouterEnableDeviceCommon; use crate::process::namespace::net_namespace::NetNamespace; use crate::{ 
libs::{rwlock::RwLock, spinlock::SpinLock}, @@ -77,7 +79,12 @@ pub trait Iface: crate::driver::base::device::Device { self.common().iface_id } - fn poll(&self); + /// # `poll` + /// 用于轮询网卡,处理网络事件 + /// ## 返回值 + /// - `true`:表示有网络事件发生 + /// - `false`:表示没有网络事件 + fn poll(&self) -> bool; /// # `poll_blocking` /// 用于在阻塞模式下轮询网卡 @@ -189,6 +196,8 @@ pub struct IfaceCommon { default_iface: bool, /// 网络命名空间 net_namespace: RwLock>, + + router_common_data: RouterEnableDeviceCommon, } impl fmt::Debug for IfaceCommon { @@ -202,6 +211,11 @@ impl fmt::Debug for IfaceCommon { impl IfaceCommon { pub fn new(iface_id: usize, default_iface: bool, iface: smoltcp::iface::Interface) -> Self { + let router_common_data = RouterEnableDeviceCommon::default(); + router_common_data + .ip_addrs + .write() + .extend_from_slice(iface.ip_addrs()); IfaceCommon { iface_id, smol_iface: SpinLock::new(iface), @@ -211,10 +225,11 @@ impl IfaceCommon { poll_at_ms: core::sync::atomic::AtomicU64::new(0), default_iface, net_namespace: RwLock::new(Weak::new()), + router_common_data, } } - pub fn poll(&self, device: &mut D) + pub fn poll(&self, device: &mut D) -> bool where D: smoltcp::phy::Device + ?Sized, { @@ -281,6 +296,7 @@ impl IfaceCommon { // .extract_if(|closing_socket| closing_socket.is_closed()) // .collect::>(); // drop(closed_sockets); + has_events } pub fn update_ip_addrs(&self, ip_addrs: &[smoltcp::wire::IpCidr]) -> Result<(), SystemError> { @@ -322,6 +338,10 @@ impl IfaceCommon { self.smol_iface.lock().ipv4_addr() } + pub fn ip_addrs(&self) -> RwLockReadGuard<'_, Vec> { + self.router_common_data.ip_addrs.read() + } + pub fn prefix_len(&self) -> Option { self.smol_iface .lock() diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index 0cd213dfd..556e8b1c7 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -1,6 +1,6 @@ use super::bridge::BridgeEnableDevice; -use super::{register_netdevice, NetDeivceState, NetDeviceCommonData, Operstate}; use 
super::{Iface, IfaceCommon}; +use super::{NetDeivceState, NetDeviceCommonData, Operstate}; use crate::arch::rand::rand; use crate::driver::base::class::Class; use crate::driver::base::device::bus::Bus; @@ -10,16 +10,17 @@ use crate::driver::base::kobject::{ KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState, }; use crate::driver::base::kset::KSet; -use crate::driver::net::bridge::BridgePort; +use crate::driver::net::bridge::{BridgeCommonData, BridgePort}; +use crate::driver::net::register_netdevice; use crate::filesystem::kernfs::KernFSInode; use crate::init::initcall::INITCALL_DEVICE; use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; use crate::libs::spinlock::{SpinLock, SpinLockGuard}; use crate::libs::wait_queue::WaitQueue; use crate::net::generate_iface_id; -use crate::net::routing::{RouterEnableDevice, RouterEnableDeviceCommon}; -use crate::process::namespace::net_namespace::INIT_NET_NAMESPACE; -use crate::process::ProcessState; +use crate::net::routing::{RouteEntry, RouterEnableDevice}; +use crate::process::namespace::net_namespace::{NetNamespace, INIT_NET_NAMESPACE}; +use crate::process::{ProcessManager, ProcessState}; use crate::sched::SchedMode; use alloc::collections::VecDeque; use alloc::fmt::Debug; @@ -60,14 +61,17 @@ impl Veth { if let Some(peer) = self.peer.upgrade() { // log::info!("Veth {} trying to send", self.name); - if let Some(bridge_data) = peer.inner.lock().bridge_port_data.as_ref() { + if let Some(bridge_common_data) = peer.inner.lock().bridge_common_data.as_ref() { // log::info!("Veth {} sending data to bridge", self.name); - Self::to_bridge(bridge_data, data); + Self::to_bridge(bridge_common_data, data); return; } // 如果是路由设备,则将数据发送到路由器 - self.to_router(data); + if self.to_router(data) { + // log::info!("Veth {} sent data to router", self.name); + return; + } Self::to_peer(&peer, data); } @@ -77,24 +81,48 @@ impl Veth { let mut peer_veth = peer.driver.force_get_mut().inner.lock_irqsave(); 
peer_veth.rx_queue.push_back(data.to_vec()); log::info!("Veth {} received data from peer", peer.name); - log::info!("DATA RECEIVED: {:?}", peer_veth.rx_queue); + log::info!("{:?}", peer_veth.rx_queue); drop(peer_veth); // 唤醒对端正在等待的进程 peer.wake_up(); + + if let Some(ns) = peer.net_namespace() { + ns.wakeup_poll_thread(); + } } - fn to_bridge(bridge_data: &BridgePort, data: &Vec) { - if let Some(bridge_driver) = bridge_data.bridge_driver.upgrade() { - // log::info!("Veth {} sending data to bridge", self.name); - bridge_driver.driver.enqueue_frame(bridge_data.id, data); + fn to_bridge(bridge_data: &BridgeCommonData, data: &Vec) { + // log::info!("Veth {} sending data to bridge", self.name); + let Some(bridge) = bridge_data.bridge_iface.upgrade() else { + log::warn!("Bridge has been dropped"); + return; }; + bridge.driver.enqueue_frame(bridge_data.id, data) } - fn to_router(&self, data: &[u8]) { - if let Some(self_iface) = self.self_iface_ref.upgrade() { - let frame = EthernetFrame::new_checked(data).unwrap(); - self_iface.handle_routable_packet(frame.payload()); + /// 经过路由发送,返回是否发送成功 + fn to_router(&self, data: &[u8]) -> bool { + let Some(self_iface) = self.self_iface_ref.upgrade() else { + return false; + }; + + let frame: EthernetFrame<&[u8]> = EthernetFrame::new_checked(data).unwrap(); + log::info!("trying to go to router"); + match self_iface.handle_routable_packet(&frame) { + Ok(_) => { + log::info!("successfully sent to router"); + return true; + } + // 先不管错误,直接告诉外面没有经过路由发送出去 + Err(Some(err)) => { + log::error!("Router error: {:?}", err); + return false; + } + Err(_) => { + log::info!("not routed"); + return false; + } } } @@ -245,10 +273,7 @@ pub struct VethCommonData { kobj_common: KObjectCommonData, peer_veth: Weak, - //TODO 这里其实不用整个port,反而会导致循环引用,可以只留一个BridgeIface - bridge_port_data: Option, - - router_common_data: RouterEnableDeviceCommon, + bridge_common_data: Option, } impl Default for VethCommonData { @@ -258,8 +283,7 @@ impl Default for VethCommonData { 
device_common: DeviceCommonData::default(), kobj_common: KObjectCommonData::default(), peer_veth: Weak::new(), - bridge_port_data: None, - router_common_data: RouterEnableDeviceCommon::default(), + bridge_common_data: None, } } } @@ -330,7 +354,7 @@ impl VethInterface { } fn inner(&self) -> SpinLockGuard<'_, VethCommonData> { - self.inner.lock_irqsave() + self.inner.lock() } /// # `update_ip_addrs` @@ -346,6 +370,15 @@ impl VethInterface { ip_addrs.push(cidr).expect("Push ipCidr failed: full"); }); + // // 直接更新对端的arp_table + // self.inner.lock().peer_veth.upgrade().map(|peer| { + // peer.common + // .router_common_data + // .arp_table + // .write() + // .insert(cidr.address(), self.mac()) + // }); + // log::info!("VethInterface {} updated IP address: {}", self.name, addr); } @@ -565,9 +598,10 @@ impl Iface for VethInterface { } } - fn poll(&self) { + fn poll(&self) -> bool { // log::info!("VethInterface {} polling normal", self.name); - self.common.poll(self.driver.force_get_mut()); + self.common.poll(self.driver.force_get_mut()) + // self.clear_recv_buffer(); } fn addr_assign_type(&self) -> u8 { @@ -603,7 +637,16 @@ impl BridgeEnableDevice for VethInterface { // let inner = self.inner.lock_irqsave(); - if let Some(_data) = self.inner.lock_irqsave().bridge_port_data.as_ref() { + if self + .inner + .lock_irqsave() + .bridge_common_data + .as_ref() + .unwrap() + .bridge_iface + .upgrade() + .is_some() + { log::info!("VethInterface {} sending data to peer", self.name); // let peer = self.peer_veth(); @@ -621,8 +664,16 @@ impl BridgeEnableDevice for VethInterface { fn set_common_bridge_data(&self, port: BridgePort) { // log::info!("Now set bridge port data for {}", self.name); let mut inner = self.inner.lock_irqsave(); - inner.bridge_port_data = Some(port); + let data = BridgeCommonData { + id: port.id, + bridge_iface: port.bridge_iface.clone(), + }; + inner.bridge_common_data = Some(data); } + + // fn common_bridge_data(&self) -> Option { + // 
self.inner().bridge_common_data.clone() + // } } impl RouterEnableDevice for VethInterface { @@ -655,44 +706,122 @@ impl RouterEnableDevice for VethInterface { } fn is_my_ip(&self, ip: IpAddress) -> bool { - let iface = self.common.smol_iface.lock_irqsave(); - iface.ip_addrs().iter().any(|cidr| cidr.contains_addr(&ip)) + self.common + .ip_addrs() + .iter() + .any(|cidr| cidr.contains_addr(&ip)) } } -pub fn veth_probe(name1: &str, name2: &str) -> (Arc, Arc) { - let (iface1, iface2) = VethInterface::new_pair(name1, name2); +// pub fn veth_probe(name1: &str, name2: &str) -> (Arc, Arc) { +// let (iface1, iface2) = VethInterface::new_pair(name1, name2); + +// let addr1 = IpAddress::v4(10, 0, 0, 1); +// let cidr1 = IpCidr::new(addr1, 24); +// iface1.update_ip_addrs(cidr1); + +// let addr2 = IpAddress::v4(10, 0, 0, 2); +// let cidr2 = IpCidr::new(addr2, 24); +// iface2.update_ip_addrs(cidr2); + +// // 添加默认路由 +// iface1.add_default_route_to_peer(addr2); +// iface2.add_default_route_to_peer(addr1); + +// let turn_on = |a: &Arc| { +// a.set_net_state(NetDeivceState::__LINK_STATE_START); +// a.set_operstate(Operstate::IF_OPER_UP); +// // NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); +// INIT_NET_NAMESPACE.add_device(a.clone()); +// a.common().set_net_namespace(INIT_NET_NAMESPACE.clone()); +// register_netdevice(a.clone()).expect("register veth device failed"); +// }; - let addr1 = IpAddress::v4(10, 0, 0, 1); +// turn_on(&iface1); +// turn_on(&iface2); + +// (iface1, iface2) +// } + +fn veth_route_test() { + let (iface_ns1, iface_host1) = VethInterface::new_pair("veth-ns1", "veth-host1"); + let (iface_ns2, iface_host2) = VethInterface::new_pair("veth-ns2", "veth-host2"); + + let addr1 = IpAddress::v4(192, 168, 1, 1); let cidr1 = IpCidr::new(addr1, 24); - iface1.update_ip_addrs(cidr1); + iface_ns1.update_ip_addrs(cidr1); - let addr2 = IpAddress::v4(10, 0, 0, 2); + let addr2 = IpAddress::v4(192, 168, 1, 254); let cidr2 = IpCidr::new(addr2, 24); - 
iface2.update_ip_addrs(cidr2); + iface_host1.update_ip_addrs(cidr2); + + let addr3 = IpAddress::v4(192, 168, 2, 254); + let cidr3 = IpCidr::new(addr3, 24); + iface_host2.update_ip_addrs(cidr3); + + let addr4 = IpAddress::v4(192, 168, 2, 1); + let cidr4 = IpCidr::new(addr4, 24); + iface_ns2.update_ip_addrs(cidr4); // 添加默认路由 - iface1.add_default_route_to_peer(addr2); - iface2.add_default_route_to_peer(addr1); + iface_ns1.add_default_route_to_peer(addr2); + iface_host1.add_default_route_to_peer(addr1); + + iface_host2.add_default_route_to_peer(addr4); + iface_ns2.add_default_route_to_peer(addr3); - let turn_on = |a: &Arc| { + let turn_on = |a: &Arc, ns: Arc| { a.set_net_state(NetDeivceState::__LINK_STATE_START); a.set_operstate(Operstate::IF_OPER_UP); // NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); - INIT_NET_NAMESPACE.add_device(a.clone()); - a.common().set_net_namespace(INIT_NET_NAMESPACE.clone()); + ns.add_device(a.clone()); + a.common().set_net_namespace(ns.clone()); register_netdevice(a.clone()).expect("register veth device failed"); }; - turn_on(&iface1); - turn_on(&iface2); - - (iface1, iface2) + let ns1 = NetNamespace::new_empty(ProcessManager::current_user_ns()).unwrap(); + let ns2 = NetNamespace::new_empty(ProcessManager::current_user_ns()).unwrap(); + + let router_ns1 = ns1.router(); + // 任何发往 192.168.1.0/24 网络的数据包都是本地邻居,可以直接从 veth-ns1 发送。 + let dest = IpCidr::new(IpAddress::v4(192, 168, 1, 0), 24); + let route = RouteEntry::new_connected(dest, iface_ns1.clone()); + router_ns1.add_route(route); + // 任何不匹配其他路由的数据包,都应该通过 veth-ns1 接口发送给下一跳 192.168.1.254。 + let next_hop = IpAddress::v4(192, 168, 1, 254); + let route = RouteEntry::new_default(next_hop, iface_ns1.clone()); + router_ns1.add_route(route); + + let router_ns2 = ns2.router(); + // 任何发往 192.168.2.0/24 网络的数据包都是本地邻居,可以直接从 veth-ns2 发送 + let dest = IpCidr::new(IpAddress::v4(192, 168, 2, 0), 24); + let route = RouteEntry::new_connected(dest, iface_ns2.clone()); + router_ns2.add_route(route); + 
// 任何不匹配其他路由的数据包,都应该通过 veth-ns2 接口发送给下一跳 192.168.2.254 + let next_hop = IpAddress::v4(192, 168, 2, 254); + let route = RouteEntry::new_default(next_hop, iface_ns2.clone()); + router_ns2.add_route(route); + + let host_router = INIT_NET_NAMESPACE.router(); + // 任何发往 192.168.1.0/24 网络的数据包,都应该从 veth-host1 接口直接发送 + let dest = IpCidr::new(IpAddress::v4(192, 168, 1, 0), 24); + let route = RouteEntry::new_connected(dest, iface_host1.clone()); + host_router.add_route(route); + // 任何发往 192.168.2.0/24 网络的数据包,都应该从 veth-host2 接口直接发送 + let dest = IpCidr::new(IpAddress::v4(192, 168, 2, 0), 24); + let route = RouteEntry::new_connected(dest, iface_host2.clone()); + host_router.add_route(route); + + turn_on(&iface_ns1, INIT_NET_NAMESPACE.clone()); + turn_on(&iface_ns2, INIT_NET_NAMESPACE.clone()); + turn_on(&iface_host1, INIT_NET_NAMESPACE.clone()); + turn_on(&iface_host2, INIT_NET_NAMESPACE.clone()); } #[unified_init(INITCALL_DEVICE)] pub fn veth_init() -> Result<(), SystemError> { - veth_probe("veth0", "veth1"); + // veth_probe("veth0", "veth1"); + veth_route_test(); log::info!("Veth pair initialized."); Ok(()) } diff --git a/kernel/src/driver/net/virtio_net.rs b/kernel/src/driver/net/virtio_net.rs index bd5c4b063..1179796fd 100644 --- a/kernel/src/driver/net/virtio_net.rs +++ b/kernel/src/driver/net/virtio_net.rs @@ -661,7 +661,7 @@ impl Iface for VirtioInterface { return self.iface_name.clone(); } - fn poll(&self) { + fn poll(&self) -> bool { // log::debug!("VirtioInterface: poll"); self.iface_common.poll(self.device_inner.force_get_mut()) } diff --git a/kernel/src/net/routing.rs b/kernel/src/net/routing.rs index 61bd3f41a..34ee3d58a 100644 --- a/kernel/src/net/routing.rs +++ b/kernel/src/net/routing.rs @@ -1,11 +1,18 @@ use crate::driver::net::Iface; use crate::libs::rwlock::RwLock; +use crate::libs::wait_queue::WaitQueue; +use crate::process::kthread::KernelThreadClosure; +use crate::process::kthread::KernelThreadMechanism; +use 
crate::process::namespace::net_namespace::NetNamespace; use crate::process::namespace::net_namespace::INIT_NET_NAMESPACE; -use alloc::collections::BTreeMap; -use alloc::string::String; +use crate::process::ProcessState; +use alloc::boxed::Box; +use alloc::collections::VecDeque; +use alloc::string::{String, ToString}; use alloc::sync::{Arc, Weak}; use alloc::vec::Vec; -use smoltcp::wire::{EthernetAddress, EthernetFrame, IpAddress, IpCidr, Ipv4Packet}; +use smoltcp::wire::{EthernetFrame, IpAddress, IpCidr, Ipv4Packet}; +use system_error::SystemError; #[derive(Debug, Clone)] pub struct RouteEntry { @@ -85,19 +92,55 @@ pub struct RouteDecision { #[derive(Debug)] pub struct Router { name: String, - /// 路由表 //todo 后面再优化LC-trie,现在先简单用一个Vec,并且应该在这上面加锁(maybe rwlock?) and 指针反而可以不加锁,在这个路由表这里加就行 + /// 路由表 //todo 后面再优化LC-trie,现在先简单用一个Vec route_table: RwLock, + pub ns: RwLock>, + + wait_queue: WaitQueue, + rx_frames: RwLock>, } impl Router { - pub fn new(name: String) -> Self { - Self { - name, + pub fn new(name: String) -> Arc { + let router = Arc::new(Self { + name: name.clone(), route_table: RwLock::new(RouteTable::default()), - } + wait_queue: WaitQueue::default(), + rx_frames: RwLock::new(VecDeque::new()), + ns: RwLock::new(Weak::default()), + }); + + let self_clone: Arc = router.clone(); + + // 创建一个线程来处理桥接设备的轮询 + let closure: Box i32 + Send + Sync + 'static> = Box::new(move || { + self_clone.poll_blocking(); + 0 + }); + let closure = KernelThreadClosure::EmptyClosure((closure, ())); + let name = name + "_poll"; + log::info!("Creating router polling thread: {}", name); + let _pcb = KernelThreadMechanism::create_and_run(closure, name) + .ok_or("") + .expect("create router_poll thread failed"); + + // log::info!("Router polling thread created"); + router } - pub fn add_route(&mut self, route: RouteEntry) { + /// 创建一个空的Router实例,主要用于初始化网络命名空间时使用 + /// 注意: 这个Router实例不会启动轮询线程 + pub fn new_empty() -> Arc { + Arc::new(Self { + name: "empty_router".to_string(), + route_table: 
RwLock::new(RouteTable::default()), + wait_queue: WaitQueue::default(), + rx_frames: RwLock::new(VecDeque::new()), + ns: RwLock::new(Weak::default()), + }) + } + + pub fn add_route(&self, route: RouteEntry) { let mut guard = self.route_table.write(); let entries = &mut guard.entries; let pos = entries @@ -109,7 +152,7 @@ impl Router { log::info!("Router {}: Added route to routing table", self.name); } - pub fn remove_route(&mut self, destination: IpCidr) { + pub fn remove_route(&self, destination: IpCidr) { self.route_table .write() .entries @@ -147,49 +190,92 @@ impl Router { .entries .retain(|route| route.interface.strong_count() > 0); } + + pub fn enqueue_frame(&self, frame: RouterFrame) { + self.rx_frames.write().push_back(frame); + self.wait_queue.wakeup(Some(ProcessState::Blocked(true))); + } + + fn poll_blocking(&self) { + use crate::sched::SchedMode; + + loop { + self.poll(); + + log::info!("Router is going to sleep"); + let _ = + wq_wait_event_interruptible!(self.wait_queue, !self.rx_frames.read().is_empty(), { + }); + } + } + + fn poll(&self) { + let mut inner = self.rx_frames.write(); + while let Some((decision, frame)) = inner.pop_front() { + decision.interface.route_and_send(decision.next_hop, &frame); + } + log::info!("Router polled all frames"); + } } +/// 获取初始化网络命名空间下的路由表 pub fn init_netns_router() -> Arc { INIT_NET_NAMESPACE.router().clone() } /// 可供路由设备应该实现的 trait pub trait RouterEnableDevice: Iface { - //todo 这里可以直接传一个IpPacket进来?如果目前只有ipv4的话 - fn handle_routable_packet(&self, packet: &[u8]) { - if packet.len() < 14 { - return; - } - - let ether_frame = match EthernetFrame::new_checked(packet) { - Ok(f) => f, - Err(_) => return, - }; - + /// # 网卡处理可路由的包 + /// ## 参数 + /// - `packet`: 需要处理的以太网帧 + /// ## 返回值 + /// - `Ok(())`: 通过路由处理成功 + /// - `Err(None)`: 忽略非IPv4包或没有路由到达的包,告诉外界没有经过处理,应该交由网卡进行默认处理 + /// - `Err(Some(SystemError))`: 处理失败,可能是包格式错误或其他系统错误 + fn handle_routable_packet( + &self, + ether_frame: &EthernetFrame<&[u8]>, + ) -> Result<(), Option> 
{ // 只处理IP包(IPv4) if ether_frame.ethertype() != smoltcp::wire::EthernetProtocol::Ipv4 { - return; + // 忽略非IPv4包 + log::info!( + "Ignoring non-IPv4 packet on interface {}", + self.iface_name() + ); + return Err(None); } - let ipv4_packet = match Ipv4Packet::new_checked(ether_frame.payload()) { + // log::info!( + // "src_mac: {}, dst_mac: {}", + // ether_frame.src_addr(), + // ether_frame.dst_addr() + // ); + + let ipv4_packet: Ipv4Packet<&[u8]> = match Ipv4Packet::new_checked(ether_frame.payload()) { Ok(p) => p, - Err(_) => return, + Err(_) => return Err(Some(SystemError::EINVAL)), }; + // log::info!( + // "src_ip: {}, dst_ip: {}", + // ipv4_packet.src_addr(), + // ipv4_packet.dst_addr() + // ); + let dst_ip = ipv4_packet.dst_addr(); // 检查TTL if ipv4_packet.hop_limit() <= 1 { log::warn!("TTL exceeded for packet to {}", dst_ip); - return; + return Err(Some(SystemError::EINVAL)); } // 检查是否是发给自己的包(目标IP是否是自己的IP) if self.is_my_ip(dst_ip.into()) { // 交给本地协议栈处理 log::info!("Packet destined for local interface {}", self.iface_name()); - //todo - return; + return Err(None); } // 查询当前网络命名空间下的路由表 @@ -199,7 +285,7 @@ pub trait RouterEnableDevice: Iface { Some(d) => d, None => { log::warn!("No route to {}", dst_ip); - return; + return Err(None); } }; @@ -207,8 +293,12 @@ pub trait RouterEnableDevice: Iface { // 检查是否是从同一个接口进来又要从同一个接口出去(避免回路) if self.iface_name() == decision.interface.iface_name() { - log::warn!("Avoiding routing loop for packet to {}", dst_ip); - return; + log::info!( + "Ignoring packet loop from {} to {}", + self.iface_name(), + dst_ip + ); + return Err(None); } // 创建修改后的IP包(递减TTL) @@ -218,17 +308,13 @@ pub trait RouterEnableDevice: Iface { // //todo 这里应该重新计算IP校验和,为了简化先跳过 // } + let frame = (decision, modified_ip_packet); + // 交给出接口进行发送 - decision - .interface - .route_and_send(decision.next_hop, &modified_ip_packet); - - log::info!( - "Routed packet from {} to {} via interface {}", - self.iface_name(), - dst_ip, - decision.interface.iface_name() - ); + 
self.netns_router().enqueue_frame(frame); + + log::info!("Routed packet from {} to {} ", self.iface_name(), dst_ip,); + Ok(()) } /// 路由器决定通过此接口发送包时调用此方法 @@ -247,18 +333,23 @@ pub trait RouterEnableDevice: Iface { } } +pub type RouterFrame = (RouteDecision, Vec); + /// # 每一个`RouterEnableDevice`应该有的公共数据,包含 /// - 当前接口的arp_table,记录邻居(//todo:将网卡的发送以及处理逻辑从smoltcp中移动出来,目前只是简单为veth实现这个,因为可以直接查到对端的mac地址) #[derive(Debug)] pub struct RouterEnableDeviceCommon { /// 当前接口的邻居缓存 - pub arp_table: RwLock>, + // pub arp_table: RwLock>, + /// 当前接口的IP地址列表 + pub ip_addrs: RwLock>, } impl Default for RouterEnableDeviceCommon { fn default() -> Self { Self { - arp_table: RwLock::new(BTreeMap::new()), + // arp_table: RwLock::new(BTreeMap::new()), + ip_addrs: RwLock::new(Vec::new()), } } } diff --git a/kernel/src/process/namespace/net_namespace.rs b/kernel/src/process/namespace/net_namespace.rs index db903c822..f624c03dc 100644 --- a/kernel/src/process/namespace/net_namespace.rs +++ b/kernel/src/process/namespace/net_namespace.rs @@ -21,6 +21,7 @@ use alloc::boxed::Box; use alloc::collections::BTreeMap; use alloc::string::{String, ToString}; use alloc::sync::{Arc, Weak}; +use core::sync::atomic::AtomicUsize; use hashbrown::HashMap; use system_error::SystemError; use unified_init::macros::unified_init; @@ -30,15 +31,32 @@ lazy_static! 
{ pub static ref INIT_NET_NAMESPACE: Arc = NetNamespace::new_root(); } +/// # 网络命名空间计数器 +/// 用于生成唯一的网络命名空间ID +/// 每次创建新的网络命名空间时,都会增加这个计数器 +pub static mut NETNS_COUNTER: AtomicUsize = AtomicUsize::new(0); + #[unified_init(INITCALL_SUBSYS)] -pub fn root_net_namespace_thread_init() -> Result<(), SystemError> { +pub fn root_net_namespace_init() -> Result<(), SystemError> { // 创建root网络命名空间的轮询线程 let pcb = NetNamespace::create_polling_thread(INIT_NET_NAMESPACE.clone(), "root_netns".to_string()); INIT_NET_NAMESPACE.set_poll_thread(pcb); + + // 创建 router + let router = Router::new("root_netns_router".to_string()); + INIT_NET_NAMESPACE.inner_mut().router = router.clone(); + let mut guard = router.ns.write(); + *guard = INIT_NET_NAMESPACE.self_ref.clone(); + Ok(()) } +/// # 获取下一个网络命名空间计数器的值 +fn get_next_netns_counter() -> usize { + unsafe { NETNS_COUNTER.fetch_add(1, core::sync::atomic::Ordering::SeqCst) } +} + #[derive(Debug)] pub struct NetNamespace { ns_common: NsCommon, @@ -84,7 +102,8 @@ impl InnerNetNamespace { impl NetNamespace { pub fn new_root() -> Arc { let inner = InnerNetNamespace { - router: Arc::new(Router::new("root_netns_router".to_string())), + // 这里没有直接创建 router,而是留到 init 函数中创建 + router: Router::new_empty(), loopback_iface: None, default_iface: None, }; @@ -106,12 +125,11 @@ impl NetNamespace { } pub fn new_empty(user_ns: Arc) -> Result, SystemError> { - // 这里获取当前进程的pid,只是为了给后面创建的路由以及线程做唯一标识,没有其他意义 - let pid = ProcessManager::current_pid().data(); + let counter = get_next_netns_counter(); let loopback = generate_loopback_iface_default(); let inner = InnerNetNamespace { - router: Arc::new(Router::new(format!("netns_router_{}", pid))), + router: Router::new(format!("netns_router_{}", counter)), loopback_iface: Some(loopback.clone()), default_iface: None, }; @@ -126,7 +144,7 @@ impl NetNamespace { netlink_socket_table: NetlinkSocketTable::default(), netlink_kernel_socket: RwLock::new(generate_supported_netlink_kernel_sockets()), }); - 
Self::create_polling_thread(netns.clone(), format!("netns_{}", pid)); + Self::create_polling_thread(netns.clone(), format!("netns_{}", counter)); netns.add_device(loopback); Ok(netns) @@ -224,13 +242,24 @@ impl NetNamespace { fn polling(&self) { log::info!("net_poll thread started for namespace"); loop { + let mut has_work_done = false; for (_, iface) in self.device_list.read_irqsave().iter() { - iface.poll(); + // 这里检查poll的返回值,如果为 true 的话,则说明发生了网络事件,很有可能再次发生,则会再次调用poll + if iface.poll() { + has_work_done = true; + iface.poll(); + } + } + if has_work_done { + log::info!("fucking continue"); + continue; } + let irq_guard = unsafe { CurrentIrqArch::save_and_disable_irq() }; ProcessManager::mark_sleep(true) .expect("clocksource_watchdog_kthread:mark sleep failed"); drop(irq_guard); + // log::info!("net_poll thread going to sleep"); schedule(SchedMode::SM_NONE); } } diff --git a/user/apps/c_unitest/test_router.c b/user/apps/c_unitest/test_router.c new file mode 100644 index 000000000..7ac968b6c --- /dev/null +++ b/user/apps/c_unitest/test_router.c @@ -0,0 +1,169 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define SERVER_IP "192.168.2.1" +#define CLIENT_IP "192.168.1.1" +#define PORT 34254 +#define BUFFER_SIZE 1024 + +// 错误处理函数 +void handle_error_message(const char *message) { + perror(message); + exit(EXIT_FAILURE); +} + +// 服务器线程函数 +void *server_func(void *arg) { + int sockfd; + struct sockaddr_in server_addr, client_addr; + char buffer[BUFFER_SIZE]; + socklen_t client_len = sizeof(client_addr); + + if ((sockfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) { + handle_error_message("[server] Failed to create socket"); + } + + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + if (inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr) <= 0) { + handle_error_message("[server] Invalid server IP address"); + } + + if (bind(sockfd, + (const struct sockaddr 
*)&server_addr, + sizeof(server_addr)) < 0) { + handle_error_message("[server] Failed to bind to " SERVER_IP); + } + printf("[server] Listening on %s:%d\n", SERVER_IP, PORT); + + ssize_t n = recvfrom(sockfd, + buffer, + BUFFER_SIZE, + 0, + (struct sockaddr *)&client_addr, + &client_len); + if (n < 0) { + handle_error_message("[server] Failed to receive"); + } + buffer[n] = '\0'; // 确保字符串正确终止 + + // //debug + // unsigned char *ip_bytes = (unsigned char *)&client_addr.sin_addr.s_addr; + // printf("[DEBUG] Raw IP bytes received: %d.%d.%d.%d\n", + // ip_bytes[0], + // ip_bytes[1], + // ip_bytes[2], + // ip_bytes[3]); + + char client_ip_str[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &client_addr.sin_addr, client_ip_str, INET_ADDRSTRLEN); + printf("[server] Received from %s:%d: %s\n", + client_ip_str, + ntohs(client_addr.sin_port), + buffer); + + if (sendto(sockfd, + buffer, + n, + 0, + (const struct sockaddr *)&client_addr, + client_len) < 0) { + handle_error_message("[server] Failed to send back"); + } + sleep(1); + printf("[server] Echoed back the message\n"); + + close(sockfd); + printf("server goning to exit\n"); + return NULL; +} + +// 客户端线程函数 +void *client_func(void *arg) { + int sockfd; + struct sockaddr_in client_addr, server_addr; + char buffer[BUFFER_SIZE]; + const char *msg = "Hello from veth1!"; + + if ((sockfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) { + handle_error_message("[client] Failed to create socket"); + } + + memset(&client_addr, 0, sizeof(client_addr)); + client_addr.sin_family = AF_INET; + client_addr.sin_port = htons(0); // 端口为0,由操作系统自动选择 + if (inet_pton(AF_INET, CLIENT_IP, &client_addr.sin_addr) <= 0) { + handle_error_message("[client] Invalid client IP address"); + } + + if (bind(sockfd, + (const struct sockaddr *)&client_addr, + sizeof(client_addr)) < 0) { + handle_error_message("[client] Failed to bind to " CLIENT_IP); + } + + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + 
if (inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr) <= 0) { + handle_error_message("[client] Invalid server IP address for connect"); + } + + if (connect(sockfd, + (const struct sockaddr *)&server_addr, + sizeof(server_addr)) < 0) { + handle_error_message("[client] Failed to connect"); + } + + if (send(sockfd, msg, strlen(msg), 0) < 0) { + handle_error_message("[client] Failed to send"); + } + printf("[client] Sent: %s\n", msg); + + ssize_t n = recv(sockfd, buffer, BUFFER_SIZE, 0); + if (n < 0) { + handle_error_message("[client] Failed to receive"); + } + buffer[n] = '\0'; // 确保字符串正确终止 + + printf("[client] Received echo: %s\n", buffer); + + assert(strcmp(msg, buffer) == 0 && "[client] Mismatch in echo!"); + + close(sockfd); + printf("client goning to exit\n"); + return NULL; +} + +int main() { + pthread_t server_tid, client_tid; + + if (pthread_create(&server_tid, NULL, server_func, NULL) != 0) { + handle_error_message("Failed to create server thread"); + } + + usleep(200 * 1000); // 200 milliseconds + + if (pthread_create(&client_tid, NULL, client_func, NULL) != 0) { + handle_error_message("Failed to create client thread"); + } + + if (pthread_join(server_tid, NULL) != 0) { + handle_error_message("Failed to join server thread"); + } + if (pthread_join(client_tid, NULL) != 0) { + handle_error_message("Failed to join client thread"); + } + + printf("\nTest completed: veth_a <--> veth_d UDP communication success\n"); + + return EXIT_SUCCESS; +} \ No newline at end of file diff --git a/user/apps/c_unitest/test_veth_bridge.c b/user/apps/c_unitest/test_veth_bridge.c index fa662376a..e0b20720e 100644 --- a/user/apps/c_unitest/test_veth_bridge.c +++ b/user/apps/c_unitest/test_veth_bridge.c @@ -78,6 +78,7 @@ void *server_func(void *arg) { client_len) < 0) { handle_error_message("[server] Failed to send back"); } + // sleep(5); printf("[server] Echoed back the message\n"); close(sockfd); From 31c7341587e5f9c8149ac506f7c1be6b357fd3f2 Mon Sep 17 00:00:00 2001 From: 
sparkzky Date: Tue, 26 Aug 2025 16:40:24 +0800 Subject: [PATCH 28/36] =?UTF-8?q?feat(netlink):=20=E8=A1=A5=E5=85=85getlin?= =?UTF-8?q?k=E6=96=B9=E6=B3=95=E4=BB=A5=E5=8F=8A=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E7=BB=93=E6=9E=84=E4=BD=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/e1000e/e1000e_driver.rs | 26 ++- kernel/src/driver/net/loopback.rs | 23 +- kernel/src/driver/net/mod.rs | 34 ++- kernel/src/driver/net/types.rs | 84 +++++++ kernel/src/driver/net/veth.rs | 23 +- kernel/src/driver/net/virtio_net.rs | 111 ++++++++- .../net/socket/netlink/route/kernel/link.rs | 157 +++++++++++++ .../net/socket/netlink/route/kernel/mod.rs | 2 + .../socket/netlink/route/message/attr/link.rs | 219 ++++++++++++++++++ .../socket/netlink/route/message/attr/mod.rs | 16 ++ .../netlink/route/message/segment/link.rs | 120 ++++++++++ .../netlink/route/message/segment/mod.rs | 17 +- 12 files changed, 821 insertions(+), 11 deletions(-) create mode 100644 kernel/src/driver/net/types.rs create mode 100644 kernel/src/net/socket/netlink/route/kernel/link.rs create mode 100644 kernel/src/net/socket/netlink/route/message/attr/link.rs create mode 100644 kernel/src/net/socket/netlink/route/message/segment/link.rs diff --git a/kernel/src/driver/net/e1000e/e1000e_driver.rs b/kernel/src/driver/net/e1000e/e1000e_driver.rs index 01d258e88..c6c90180b 100644 --- a/kernel/src/driver/net/e1000e/e1000e_driver.rs +++ b/kernel/src/driver/net/e1000e/e1000e_driver.rs @@ -9,7 +9,8 @@ use crate::{ kobject::{KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState}, }, net::{ - register_netdevice, Iface, IfaceCommon, NetDeivceState, NetDeviceCommonData, Operstate, + register_netdevice, types::InterfaceFlags, Iface, IfaceCommon, NetDeivceState, + NetDeviceCommonData, Operstate, }, }, libs::{ @@ -201,9 +202,21 @@ impl E1000EInterface { let iface = smoltcp::iface::Interface::new(iface_config, &mut driver, 
Instant::now().into()); + let flags = InterfaceFlags::UP + | InterfaceFlags::BROADCAST + | InterfaceFlags::RUNNING + | InterfaceFlags::MULTICAST + | InterfaceFlags::LOWER_UP; + let iface = Arc::new(E1000EInterface { driver: E1000EDriverWrapper(UnsafeCell::new(driver)), - common: IfaceCommon::new(iface_id, false, iface), + common: IfaceCommon::new( + iface_id, + crate::driver::net::types::InterfaceType::EETHER, + flags, + false, + iface, + ), name: format!("eth{}", iface_id), inner: SpinLock::new(InnerE1000EInterface { netdevice_common: NetDeviceCommonData::default(), @@ -333,6 +346,15 @@ impl Iface for E1000EInterface { fn set_operstate(&self, state: Operstate) { self.inner().netdevice_common.operstate = state; } + + fn mtu(&self) -> usize { + use smoltcp::phy::Device; + + self.driver + .force_get_mut() + .capabilities() + .max_transmission_unit + } } impl KObject for E1000EInterface { diff --git a/kernel/src/driver/net/loopback.rs b/kernel/src/driver/net/loopback.rs index e5f7db59f..ff41c8971 100644 --- a/kernel/src/driver/net/loopback.rs +++ b/kernel/src/driver/net/loopback.rs @@ -7,6 +7,7 @@ use crate::driver::base::kobject::{ KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState, }; use crate::driver::base::kset::KSet; +use crate::driver::net::types::InterfaceFlags; use crate::filesystem::kernfs::KernFSInode; use crate::init::initcall::INITCALL_DEVICE; use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; @@ -321,9 +322,20 @@ impl LoopbackInterface { .expect("Add default ipv4 route failed: full"); }); + let flags = InterfaceFlags::LOOPBACK + | InterfaceFlags::UP + | InterfaceFlags::RUNNING + | InterfaceFlags::LOWER_UP; + Arc::new(LoopbackInterface { driver: LoopbackDriverWapper(UnsafeCell::new(driver)), - common: IfaceCommon::new(iface_id, false, iface), + common: IfaceCommon::new( + iface_id, + super::types::InterfaceType::LOOPBACK, + flags, + false, + iface, + ), inner: SpinLock::new(InnerLoopbackInterface { netdevice_common: 
NetDeviceCommonData::default(), device_common: DeviceCommonData::default(), @@ -509,6 +521,15 @@ impl Iface for LoopbackInterface { fn set_operstate(&self, state: Operstate) { self.inner().netdevice_common.operstate = state; } + + fn mtu(&self) -> usize { + use smoltcp::phy::Device; + + self.driver + .force_get_mut() + .capabilities() + .max_transmission_unit + } } pub fn generate_loopback_iface_default() -> Arc { diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index 906f608cf..1e27aefae 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -4,6 +4,7 @@ use alloc::{string::String, sync::Arc}; use core::net::Ipv4Addr; use sysfs::netdev_register_kobject; +use crate::driver::net::types::{InterfaceFlags, InterfaceType}; use crate::libs::rwlock::RwLockReadGuard; use crate::net::routing::RouterEnableDeviceCommon; use crate::process::namespace::net_namespace::NetNamespace; @@ -22,6 +23,7 @@ pub mod e1000e; pub mod irq_handle; pub mod loopback; pub mod sysfs; +pub mod types; pub mod veth; pub mod virtio_net; @@ -142,6 +144,16 @@ pub trait Iface: crate::driver::base::device::Device { fn set_net_namespace(&self, ns: Arc) { self.common().set_net_namespace(ns); } + + fn flags(&self) -> InterfaceFlags { + self.common().flags() + } + + fn type_(&self) -> InterfaceType { + self.common().type_() + } + + fn mtu(&self) -> usize; } /// 网络设备的公共数据 @@ -182,6 +194,8 @@ fn register_netdevice(dev: Arc) -> Result<(), SystemError> { pub struct IfaceCommon { iface_id: usize, + flags: InterfaceFlags, + type_: InterfaceType, smol_iface: SpinLock, /// 存smoltcp网卡的套接字集 sockets: SpinLock>, @@ -196,7 +210,7 @@ pub struct IfaceCommon { default_iface: bool, /// 网络命名空间 net_namespace: RwLock>, - + // 路由相关数据 router_common_data: RouterEnableDeviceCommon, } @@ -210,7 +224,13 @@ impl fmt::Debug for IfaceCommon { } impl IfaceCommon { - pub fn new(iface_id: usize, default_iface: bool, iface: smoltcp::iface::Interface) -> Self { + pub fn new( + iface_id: 
usize, + type_: InterfaceType, + flags: InterfaceFlags, + default_iface: bool, + iface: smoltcp::iface::Interface, + ) -> Self { let router_common_data = RouterEnableDeviceCommon::default(); router_common_data .ip_addrs @@ -226,6 +246,8 @@ impl IfaceCommon { default_iface, net_namespace: RwLock::new(Weak::new()), router_common_data, + flags, + type_, } } @@ -358,4 +380,12 @@ impl IfaceCommon { let mut guard = self.net_namespace.write(); *guard = Arc::downgrade(&ns); } + + pub fn flags(&self) -> InterfaceFlags { + self.flags + } + + pub fn type_(&self) -> InterfaceType { + self.type_ + } } diff --git a/kernel/src/driver/net/types.rs b/kernel/src/driver/net/types.rs new file mode 100644 index 000000000..05fbee657 --- /dev/null +++ b/kernel/src/driver/net/types.rs @@ -0,0 +1,84 @@ +use system_error::SystemError; + +/// Interface type. +/// +#[repr(u16)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)] +pub enum InterfaceType { + // Arp protocol hardware identifiers + /// from KA9Q: NET/ROM pseudo + NETROM = 0, + /// Ethernet 10Mbps + ETHER = 1, + /// Experimental Ethernet + EETHER = 2, + + // Dummy types for non ARP hardware + /// IPIP tunnel + TUNNEL = 768, + /// IP6IP6 tunnel + TUNNEL6 = 769, + /// Frame Relay Access Device + FRAD = 770, + /// SKIP vif + SKIP = 771, + /// Loopback device + LOOPBACK = 772, + /// Localtalk device + LOCALTALK = 773, + // TODO 更多类型 +} + +impl TryFrom for InterfaceType { + type Error = SystemError; + + fn try_from(value: u16) -> Result { + use num_traits::FromPrimitive; + return ::from_u16(value).ok_or(Self::Error::EINVAL); + } +} + +bitflags! { + /// Interface flags. 
+ /// + pub struct InterfaceFlags: u32 { + /// Interface is up + const UP = 1<<0; + /// Broadcast address valid + const BROADCAST = 1<<1; + /// Turn on debugging + const DEBUG = 1<<2; + /// Loopback net + const LOOPBACK = 1<<3; + /// Interface is has p-p link + const POINTOPOINT = 1<<4; + /// Avoid use of trailers + const NOTRAILERS = 1<<5; + /// Interface RFC2863 OPER_UP + const RUNNING = 1<<6; + /// No ARP protocol + const NOARP = 1<<7; + /// Receive all packets + const PROMISC = 1<<8; + /// Receive all multicast packets + const ALLMULTI = 1<<9; + /// Master of a load balancer + const MASTER = 1<<10; + /// Slave of a load balancer + const SLAVE = 1<<11; + /// Supports multicast + const MULTICAST = 1<<12; + /// Can set media type + const PORTSEL = 1<<13; + /// Auto media select active + const AUTOMEDIA = 1<<14; + /// Dialup device with changing addresses + const DYNAMIC = 1<<15; + /// Driver signals L1 up + const LOWER_UP = 1<<16; + /// Driver signals dormant + const DORMANT = 1<<17; + /// Echo sent packets + const ECHO = 1<<18; + } +} diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index 556e8b1c7..65631452f 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -12,6 +12,7 @@ use crate::driver::base::kobject::{ use crate::driver::base::kset::KSet; use crate::driver::net::bridge::{BridgeCommonData, BridgePort}; use crate::driver::net::register_netdevice; +use crate::driver::net::types::InterfaceFlags; use crate::filesystem::kernfs::KernFSInode; use crate::init::initcall::INITCALL_DEVICE; use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; @@ -321,10 +322,22 @@ impl VethInterface { ); iface.set_any_ip(true); + let flags = InterfaceFlags::BROADCAST + | InterfaceFlags::MULTICAST + | InterfaceFlags::UP + | InterfaceFlags::RUNNING + | InterfaceFlags::LOWER_UP; + let device = Arc::new(VethInterface { name, driver: VethDriverWarpper(UnsafeCell::new(driver.clone())), - common: IfaceCommon::new(iface_id, 
true, iface), + common: IfaceCommon::new( + iface_id, + super::types::InterfaceType::EETHER, + flags, + false, + iface, + ), inner: SpinLock::new(VethCommonData::default()), locked_kobj_state: LockedKObjectState::default(), wait_queue: WaitQueue::default(), @@ -628,6 +641,14 @@ impl Iface for VethInterface { fn set_operstate(&self, state: Operstate) { self.inner().netdevice_common.operstate = state; } + + fn mtu(&self) -> usize { + use smoltcp::phy::Device; + self.driver + .force_get_mut() + .capabilities() + .max_transmission_unit + } } impl BridgeEnableDevice for VethInterface { diff --git a/kernel/src/driver/net/virtio_net.rs b/kernel/src/driver/net/virtio_net.rs index 1179796fd..49f941879 100644 --- a/kernel/src/driver/net/virtio_net.rs +++ b/kernel/src/driver/net/virtio_net.rs @@ -29,7 +29,7 @@ use crate::{ kobject::{KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState}, kset::KSet, }, - net::register_netdevice, + net::{register_netdevice, types::InterfaceFlags}, virtio::{ irq::virtio_irq_manager, sysfs::{virtio_bus, virtio_device_manager, virtio_driver_manager}, @@ -398,6 +398,93 @@ pub struct VirtioInterface { locked_kobj_state: LockedKObjectState, } +// // 先手糊为virtio实现这些,后面系统要是有了其他类型网卡,这些实现就得实现成一个单独的trait +// impl VirtioInterface { +// /// 消耗token然后主动发送一个 arp 数据包 +// pub fn emit_arp(arp_repr: &ArpRepr, tx_token: VirtioNetToken) { +// let ether_repr = match arp_repr { +// ArpRepr::EthernetIpv4 { +// source_hardware_addr, +// target_hardware_addr, +// .. 
+// } => EthernetRepr { +// src_addr: *source_hardware_addr, +// dst_addr: *target_hardware_addr, +// ethertype: EthernetProtocol::Arp, +// }, +// _ => return, +// }; + +// tx_token.consume(ether_repr.buffer_len() + arp_repr.buffer_len(), |buffer| { +// let mut frame = EthernetFrame::new_unchecked(buffer); +// ether_repr.emit(&mut frame); + +// let mut pkt = ArpPacket::new_unchecked(frame.payload_mut()); +// arp_repr.emit(&mut pkt); +// }); +// } + +// /// 解析 arp 包并处理 +// pub fn process_arp(&self, arp_repr: &ArpRepr) -> Option { +// match arp_repr { +// ArpRepr::EthernetIpv4 { +// operation: ArpOperation::Reply, +// source_hardware_addr, +// source_protocol_addr, +// .. +// } => { +// if !source_hardware_addr.is_unicast() +// || !self +// .common() +// .smol_iface +// .lock() +// .context() +// .in_same_network(&IpAddress::Ipv4(*source_protocol_addr)) +// { +// return None; +// } + +// self.common().router_common_data.arp_table.write().insert( +// IpAddress::Ipv4(*source_protocol_addr), +// *source_hardware_addr, +// ); + +// None +// } +// ArpRepr::EthernetIpv4 { +// operation: ArpOperation::Request, +// source_hardware_addr, +// source_protocol_addr, +// target_protocol_addr, +// .. 
+// } => { +// if !source_hardware_addr.is_unicast() || !source_protocol_addr.x_is_unicast() { +// return None; +// } + +// if self +// .common() +// .smol_iface +// .lock() +// .context() +// .ipv4_addr() +// .is_none_or(|addr| addr != *target_protocol_addr) +// { +// return None; +// } +// Some(ArpRepr::EthernetIpv4 { +// operation: ArpOperation::Reply, +// source_hardware_addr: self.mac(), +// source_protocol_addr: *target_protocol_addr, +// target_hardware_addr: *source_hardware_addr, +// target_protocol_addr: *source_protocol_addr, +// }) +// } +// _ => None, +// } +// } +// } + #[derive(Debug)] struct InnerVirtIOInterface { kobj_common: KObjectCommonData, @@ -415,11 +502,23 @@ impl VirtioInterface { let iface = iface::Interface::new(iface_config, &mut device_inner, Instant::now().into()); + let flags = InterfaceFlags::UP + | InterfaceFlags::BROADCAST + | InterfaceFlags::RUNNING + | InterfaceFlags::MULTICAST + | InterfaceFlags::LOWER_UP; + let iface = Arc::new(VirtioInterface { device_inner: VirtIONicDeviceInnerWrapper(UnsafeCell::new(device_inner)), locked_kobj_state: LockedKObjectState::default(), iface_name: format!("eth{}", iface_id), - iface_common: super::IfaceCommon::new(iface_id, true, iface), + iface_common: super::IfaceCommon::new( + iface_id, + crate::driver::net::types::InterfaceType::EETHER, + flags, + true, + iface, + ), inner: SpinLock::new(InnerVirtIOInterface { kobj_common: KObjectCommonData::default(), device_common: DeviceCommonData::default(), @@ -694,6 +793,14 @@ impl Iface for VirtioInterface { fn set_operstate(&self, state: Operstate) { self.inner().netdevice_common.operstate = state; } + + fn mtu(&self) -> usize { + use smoltcp::phy::Device; + self.device_inner + .force_get_mut() + .capabilities() + .max_transmission_unit + } } impl KObject for VirtioInterface { diff --git a/kernel/src/net/socket/netlink/route/kernel/link.rs b/kernel/src/net/socket/netlink/route/kernel/link.rs new file mode 100644 index 000000000..67f7132c3 --- 
/dev/null +++ b/kernel/src/net/socket/netlink/route/kernel/link.rs @@ -0,0 +1,157 @@ +use crate::{ + driver::net::{types::InterfaceType, Iface}, + net::socket::{ + netlink::{ + message::segment::{ + header::{CMsgSegHdr, GetRequestFlags, SegHdrCommonFlags}, + CSegmentType, + }, + route::{ + kernel::utils::finish_response, + message::{ + attr::link::LinkAttr, + segment::{ + link::{LinkMessageFlags, LinkSegment, LinkSegmentBody}, + RouteNlSegment, + }, + }, + }, + }, + AddressFamily, + }, + process::namespace::net_namespace::NetNamespace, +}; +use alloc::ffi::CString; +use alloc::sync::Arc; +use alloc::vec::Vec; +use core::num::NonZero; +use system_error::SystemError; + +pub(super) fn do_get_link( + request_segment: &LinkSegment, + netns: Arc, +) -> Result, SystemError> { + let filter_by = FilterBy::from_requset(request_segment)?; + + let mut responce: Vec = netns + .device_list() + .iter() + .filter(|(_, iface)| match &filter_by { + FilterBy::Index(index) => *index == iface.nic_id() as u32, + FilterBy::Name(name) => *name == iface.name(), + FilterBy::Dump => true, + }) + .map(|(_, iface)| iface_to_new_link(request_segment.header(), iface)) + .map(RouteNlSegment::NewLink) + .collect(); + + let dump_all = matches!(filter_by, FilterBy::Dump); + + if !dump_all && responce.is_empty() { + log::error!("no such device"); + return Err(SystemError::ENODEV); + } + + finish_response(request_segment.header(), dump_all, &mut responce); + + Ok(responce) +} + +enum FilterBy<'a> { + Index(u32), + Name(&'a str), + Dump, +} + +impl<'a> FilterBy<'a> { + fn from_requset(request_segment: &'a LinkSegment) -> Result { + let dump_all = { + let flags = GetRequestFlags::from_bits_truncate(request_segment.header().flags); + flags.contains(GetRequestFlags::DUMP) + }; + if dump_all { + validate_dumplink_request(request_segment.body())?; + return Ok(Self::Dump); + } + + validate_getlink_request(request_segment.body())?; + + if let Some(required_index) = request_segment.body().index { + return 
Ok(Self::Index(required_index.get())); + } + + let required_name = request_segment.attrs().iter().find_map(|attr| { + if let LinkAttr::Name(name) = attr { + Some(name.to_str().ok()?) + } else { + None + } + }); + + if let Some(name) = required_name { + return Ok(Self::Name(name)); + } + + log::error!("either interface name or index should be specified for non-dump mode"); + Err(SystemError::EINVAL) + } +} + +fn validate_getlink_request(body: &LinkSegmentBody) -> Result<(), SystemError> { + if !body.flags.is_empty() + || body.type_ != InterfaceType::NETROM + || body.pad.is_some() + || !body.change.is_empty() + { + log::error!("the flags or the type is not valid"); + return Err(SystemError::EINVAL); + } + + Ok(()) +} + +fn validate_dumplink_request(body: &LinkSegmentBody) -> Result<(), SystemError> { + // . + if !body.flags.is_empty() + || body.type_ != InterfaceType::NETROM + || body.pad.is_some() + || !body.change.is_empty() + { + log::error!("the flags or the type is not valid"); + return Err(SystemError::EINVAL); + } + + // . 
+ if body.index.is_some() { + log::error!("filtering by interface index is not valid for link dumps"); + return Err(SystemError::EINVAL); + } + + Ok(()) +} + +fn iface_to_new_link(request_header: &CMsgSegHdr, iface: &Arc) -> LinkSegment { + let header = CMsgSegHdr { + len: 0, + type_: CSegmentType::NEWLINK as _, + flags: SegHdrCommonFlags::empty().bits(), + seq: request_header.seq, + pid: request_header.pid, + }; + + let link_message = LinkSegmentBody { + family: AddressFamily::Unspecified, + type_: iface.type_(), + index: NonZero::new(iface.nic_id() as u32), + flags: iface.flags(), + change: LinkMessageFlags::empty(), + pad: None, + }; + + let attrs = vec![ + LinkAttr::Name(CString::new(iface.name()).unwrap()), + LinkAttr::Mtu(iface.mtu() as u32), + ]; + + LinkSegment::new(header, link_message, attrs) +} diff --git a/kernel/src/net/socket/netlink/route/kernel/mod.rs b/kernel/src/net/socket/netlink/route/kernel/mod.rs index 7f33342e6..c06dc048a 100644 --- a/kernel/src/net/socket/netlink/route/kernel/mod.rs +++ b/kernel/src/net/socket/netlink/route/kernel/mod.rs @@ -19,6 +19,7 @@ use alloc::sync::Arc; use core::marker::PhantomData; mod addr; +mod link; mod utils; /// 负责处理 Netlink 路由相关的内核模块 @@ -47,6 +48,7 @@ impl NetlinkRouteKernelSocket { let seg_type = CSegmentType::try_from(header.type_).unwrap(); let responce = match segment { RouteNlSegment::GetAddr(request) => addr::do_get_addr(request, netns.clone()), + RouteNlSegment::GetLink(request) => link::do_get_link(request, netns.clone()), RouteNlSegment::GetRoute(_new_route) => todo!(), _ => { log::warn!("Unsupported route request segment type: {:?}", seg_type); diff --git a/kernel/src/net/socket/netlink/route/message/attr/link.rs b/kernel/src/net/socket/netlink/route/message/attr/link.rs new file mode 100644 index 000000000..49224a4ce --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/attr/link.rs @@ -0,0 +1,219 @@ +use crate::net::socket::netlink::message::attr::Attribute; +use 
crate::net::socket::netlink::message::attr::CAttrHeader; +use crate::net::socket::netlink::route::message::attr::convert_one_from_raw_buf; +use crate::net::socket::netlink::route::message::attr::IFNAME_SIZE; +use alloc::ffi::CString; +use system_error::SystemError; + +#[derive(Debug, Clone, Copy, FromPrimitive, ToPrimitive)] +#[repr(u16)] +#[expect(non_camel_case_types)] +#[expect(clippy::upper_case_acronyms)] +enum LinkAttrClass { + UNSPEC = 0, + ADDRESS = 1, + BROADCAST = 2, + IFNAME = 3, + MTU = 4, + LINK = 5, + QDISC = 6, + STATS = 7, + COST = 8, + PRIORITY = 9, + MASTER = 10, + /// Wireless Extension event + WIRELESS = 11, + /// Protocol specific information for a link + PROTINFO = 12, + TXQLEN = 13, + MAP = 14, + WEIGHT = 15, + OPERSTATE = 16, + LINKMODE = 17, + LINKINFO = 18, + NET_NS_PID = 19, + IFALIAS = 20, + /// Number of VFs if device is SR-IOV PF + NUM_VF = 21, + VFINFO_LIST = 22, + STATS64 = 23, + VF_PORTS = 24, + PORT_SELF = 25, + AF_SPEC = 26, + /// Group the device belongs to + GROUP = 27, + NET_NS_FD = 28, + /// Extended info mask, VFs, etc. 
+ EXT_MASK = 29, + /// Promiscuity count: > 0 means acts PROMISC + PROMISCUITY = 30, + NUM_TX_QUEUES = 31, + NUM_RX_QUEUES = 32, + CARRIER = 33, + PHYS_PORT_ID = 34, + CARRIER_CHANGES = 35, + PHYS_SWITCH_ID = 36, + LINK_NETNSID = 37, + PHYS_PORT_NAME = 38, + PROTO_DOWN = 39, + GSO_MAX_SEGS = 40, + GSO_MAX_SIZE = 41, + PAD = 42, + XDP = 43, + EVENT = 44, + NEW_NETNSID = 45, + IF_NETNSID = 46, + CARRIER_UP_COUNT = 47, + CARRIER_DOWN_COUNT = 48, + NEW_IFINDEX = 49, + MIN_MTU = 50, + MAX_MTU = 51, + PROP_LIST = 52, + /// Alternative ifname + ALT_IFNAME = 53, + PERM_ADDRESS = 54, + PROTO_DOWN_REASON = 55, + PARENT_DEV_NAME = 56, + PARENT_DEV_BUS_NAME = 57, +} + +impl TryFrom for LinkAttrClass { + type Error = SystemError; + + fn try_from(value: u16) -> Result { + use num_traits::FromPrimitive; + return ::from_u16(value).ok_or(Self::Error::EINVAL); + } +} + +#[derive(Debug)] +pub enum LinkAttr { + Name(CString), + Mtu(u32), + TxqLen(u32), + LinkMode(u8), + ExtMask(RtExtFilter), +} + +impl LinkAttr { + fn class(&self) -> LinkAttrClass { + match self { + LinkAttr::Name(_) => LinkAttrClass::IFNAME, + LinkAttr::Mtu(_) => LinkAttrClass::MTU, + LinkAttr::TxqLen(_) => LinkAttrClass::TXQLEN, + LinkAttr::LinkMode(_) => LinkAttrClass::LINKMODE, + LinkAttr::ExtMask(_) => LinkAttrClass::EXT_MASK, + } + } +} + +// #[derive(Debug)] +// pub enum LinkInfoAttr{ +// Kind(CString), +// Data(Vec), +// } + +// #[derive(Debug)] +// pub enum LinkInfoDataAttr{ +// VlanId(u16), + +// } + +impl Attribute for LinkAttr { + fn type_(&self) -> u16 { + self.class() as u16 + } + + fn payload_as_bytes(&self) -> &[u8] { + match self { + LinkAttr::Name(name) => name.as_bytes_with_nul(), + LinkAttr::Mtu(mtu) => unsafe { + core::slice::from_raw_parts(mtu as *const u32 as *const u8, 4) + }, + LinkAttr::TxqLen(txq_len) => unsafe { + core::slice::from_raw_parts(txq_len as *const u32 as *const u8, 4) + }, + LinkAttr::LinkMode(link_mode) => unsafe { + core::slice::from_raw_parts(link_mode as *const u8, 1) + }, + 
LinkAttr::ExtMask(ext_filter) => { + let bits = ext_filter.bits(); + unsafe { core::slice::from_raw_parts(&bits as *const u32 as *const u8, 4) } + } + } + } + + fn read_from_buf(header: &CAttrHeader, buf: &[u8]) -> Result, SystemError> + where + Self: Sized, + { + let payload_len = header.payload_len(); + + // TODO: Currently, `IS_NET_BYTEORDER_MASK` and `IS_NESTED_MASK` are ignored. + let Ok(class) = LinkAttrClass::try_from(header.type_()) else { + // reader.skip_some(payload_len); + return Ok(None); + }; + + let res = match (class, payload_len) { + (LinkAttrClass::IFNAME, 1..=IFNAME_SIZE) => { + let nul_pos = buf.iter().position(|&b| b == 0).unwrap_or(buf.len()); + let cstr = CString::new(&buf[..nul_pos]).map_err(|_| SystemError::EINVAL)?; + Self::Name(cstr) + } + (LinkAttrClass::MTU, 4) => { + let data = convert_one_from_raw_buf::(buf)?; + Self::Mtu(*data) + } + (LinkAttrClass::TXQLEN, 4) => { + let data = convert_one_from_raw_buf::(buf)?; + Self::TxqLen(*data) + } + (LinkAttrClass::LINKMODE, 1) => { + let data = convert_one_from_raw_buf::(buf)?; + Self::LinkMode(*data) + } + (LinkAttrClass::EXT_MASK, 4) => { + const { assert!(size_of::() == 4) }; + Self::ExtMask(*convert_one_from_raw_buf::(buf)?) + } + + ( + LinkAttrClass::IFNAME + | LinkAttrClass::MTU + | LinkAttrClass::TXQLEN + | LinkAttrClass::LINKMODE + | LinkAttrClass::EXT_MASK, + _, + ) => { + log::warn!("link attribute `{:?}` contains invalid payload", class); + return Err(SystemError::EINVAL); + } + + (_, _) => { + log::warn!("link attribute `{:?}` is not supported", class); + // reader.skip_some(payload_len); + return Ok(None); + } + }; + + Ok(Some(res)) + } +} + +bitflags! { + /// New extended info filters for [`NlLinkAttr::ExtMask`]. + /// + /// Reference: . 
+ #[repr(C)] + pub struct RtExtFilter: u32 { + const VF = 1 << 0; + const BRVLAN = 1 << 1; + const BRVLAN_COMPRESSED = 1 << 2; + const SKIP_STATS = 1 << 3; + const MRP = 1 << 4; + const CFM_CONFIG = 1 << 5; + const CFM_STATUS = 1 << 6; + const MST = 1 << 7; + } +} diff --git a/kernel/src/net/socket/netlink/route/message/attr/mod.rs b/kernel/src/net/socket/netlink/route/message/attr/mod.rs index d6de48bfb..9c082971c 100644 --- a/kernel/src/net/socket/netlink/route/message/attr/mod.rs +++ b/kernel/src/net/socket/netlink/route/message/attr/mod.rs @@ -1,5 +1,21 @@ +use core::slice::from_raw_parts; +use system_error::SystemError; + pub mod addr; +pub mod link; pub mod route; /// 网卡名字长度 const IFNAME_SIZE: usize = 16; + +pub(super) fn convert_one_from_raw_buf(src: &[u8]) -> Result<&T, SystemError> { + log::info!("convert_one_from_raw_buf: src.len() = {}", src.len()); + if core::mem::size_of::() > src.len() { + return Err(SystemError::EINVAL); + } + let byte_buffer: &[u8] = &src[..core::mem::size_of::()]; + + let chunks = unsafe { from_raw_parts(byte_buffer.as_ptr() as *const T, 1) }; + let data = &chunks[0]; + return Ok(data); +} diff --git a/kernel/src/net/socket/netlink/route/message/segment/link.rs b/kernel/src/net/socket/netlink/route/message/segment/link.rs new file mode 100644 index 000000000..ad0aa1dd1 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/segment/link.rs @@ -0,0 +1,120 @@ +use crate::{ + driver::net::types::{InterfaceFlags, InterfaceType}, + net::socket::{ + netlink::{ + message::segment::{common::SegmentCommon, SegmentBody}, + route::message::attr::link::LinkAttr, + }, + AddressFamily, + }, +}; +use core::num::NonZeroU32; +use system_error::SystemError; + +pub type LinkSegment = SegmentCommon; + +impl SegmentBody for LinkSegmentBody { + type CType = CIfinfoMsg; +} + +/// `ifinfomsg` +/// . 
+#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct CIfinfoMsg { + /// AF_UNSPEC + pub family: u8, + /// Padding byte + pub pad: u8, + /// Device type + pub type_: u16, + /// Interface index + pub index: u32, + /// Device flags + pub flags: u32, + /// Change mask + pub change: u32, +} + +#[derive(Debug, Clone, Copy)] +pub struct LinkSegmentBody { + pub family: AddressFamily, + pub type_: InterfaceType, + pub index: Option, + pub flags: InterfaceFlags, + pub change: LinkMessageFlags, + pub pad: Option, // Must be 0 +} + +impl TryFrom for LinkSegmentBody { + type Error = SystemError; + + fn try_from(value: CIfinfoMsg) -> Result { + let family = AddressFamily::try_from(value.family as u16)?; + let type_ = InterfaceType::try_from(value.type_)?; + let index = NonZeroU32::new(value.index); + let flags = InterfaceFlags::from_bits_truncate(value.flags); + let change = LinkMessageFlags::from_bits_truncate(value.change); + let pad = if value.pad > 0 { Some(value.pad) } else { None }; + + Ok(Self { + family, + type_, + index, + flags, + change, + pad, + }) + } +} + +impl From for CIfinfoMsg { + fn from(value: LinkSegmentBody) -> Self { + CIfinfoMsg { + family: value.family as _, + pad: 0u8, + type_: value.type_ as _, + index: value.index.map(NonZeroU32::get).unwrap_or(0), + flags: value.flags.bits(), + change: value.change.bits(), + } + } +} + +bitflags! { + /// Flags in [`CIfinfoMsg`]. 
+ pub struct LinkMessageFlags: u32 { + // sysfs + const IFF_UP = 1<<0; + // volatile + const IFF_BROADCAST = 1<<1; + // sysfs + const IFF_DEBUG = 1<<2; + // volatile + const IFF_LOOPBACK = 1<<3; + // volatile + const IFF_POINTOPOINT = 1<<4; + // sysfs + const IFF_NOTRAILERS = 1<<5; + // volatile + const IFF_RUNNING = 1<<6; + // sysfs + const IFF_NOARP = 1<<7; + // sysfs + const IFF_PROMISC = 1<<8; + // sysfs + const IFF_ALLMULTI = 1<<9; + // volatile + const IFF_MASTER = 1<<10; + // volatile + const IFF_SLAVE = 1<<11; + // sysfs + const IFF_MULTICAST = 1<<12; + // sysfs + const IFF_PORTSEL = 1<<13; + // sysfs + const IFF_AUTOMEDIA = 1<<14; + // sysfs + const IFF_DYNAMIC = 1<<15; + } +} diff --git a/kernel/src/net/socket/netlink/route/message/segment/mod.rs b/kernel/src/net/socket/netlink/route/message/segment/mod.rs index ea40e2cf0..f72f558a6 100644 --- a/kernel/src/net/socket/netlink/route/message/segment/mod.rs +++ b/kernel/src/net/socket/netlink/route/message/segment/mod.rs @@ -1,4 +1,5 @@ pub mod addr; +pub mod link; pub mod route; use crate::net::socket::netlink::{ @@ -10,14 +11,14 @@ use crate::net::socket::netlink::{ }, ProtocolSegment, }, - route::message::segment::{addr::AddrSegment, route::RouteSegment}, + route::message::segment::{addr::AddrSegment, link::LinkSegment, route::RouteSegment}, }; use system_error::SystemError; #[derive(Debug)] pub enum RouteNlSegment { - // NewLink(LinkSegment), - // GetLink(LinkSegment), + NewLink(LinkSegment), + GetLink(LinkSegment), NewAddr(AddrSegment), GetAddr(AddrSegment), Done(DoneSegment), @@ -36,6 +37,9 @@ impl ProtocolSegment for RouteNlSegment { RouteNlSegment::NewAddr(addr_segment) | RouteNlSegment::GetAddr(addr_segment) => { addr_segment.header() } + RouteNlSegment::NewLink(link_segment) | RouteNlSegment::GetLink(link_segment) => { + link_segment.header() + } RouteNlSegment::Done(done_segment) => done_segment.header(), RouteNlSegment::Error(error_segment) => error_segment.header(), } @@ -51,6 +55,9 @@ impl 
ProtocolSegment for RouteNlSegment { RouteNlSegment::NewAddr(addr_segment) | RouteNlSegment::GetAddr(addr_segment) => { addr_segment.header_mut() } + RouteNlSegment::NewLink(link_segment) | RouteNlSegment::GetLink(link_segment) => { + link_segment.header_mut() + } RouteNlSegment::Done(done_segment) => done_segment.header_mut(), RouteNlSegment::Error(error_segment) => error_segment.header_mut(), } @@ -73,6 +80,9 @@ impl ProtocolSegment for RouteNlSegment { CSegmentType::GETROUTE => { RouteNlSegment::GetRoute(RouteSegment::read_from_buf(header, payload_buf)?) } + CSegmentType::GETLINK => { + RouteNlSegment::GetLink(LinkSegment::read_from_buf(header, payload_buf)?) + } _ => return Err(SystemError::EINVAL), }; @@ -84,6 +94,7 @@ impl ProtocolSegment for RouteNlSegment { let copied = match self { RouteNlSegment::NewAddr(addr_segment) => addr_segment.write_to_buf(buf)?, RouteNlSegment::NewRoute(route_segment) => route_segment.write_to_buf(buf)?, + RouteNlSegment::NewLink(link_segment) => link_segment.write_to_buf(buf)?, RouteNlSegment::Done(done_segment) => done_segment.write_to_buf(buf)?, RouteNlSegment::Error(error_segment) => error_segment.write_to_buf(buf)?, _ => { From b95e3b28336fa668b68323d7402f99def3106992 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Fri, 29 Aug 2025 10:42:48 +0800 Subject: [PATCH 29/36] Refactor network driver interfaces and introduce NAPI support - Removed the default_iface parameter. - Introduced a new NAPI module to manage network polling and scheduling. - Updated the Iface trait to include a napi_struct method for NAPI support. - Modified Veth network interfaces to integrate with the new NAPI structure. - Refactored the Router implementation to remove unnecessary polling threads and wait queues. - Updated NetNamespace to manage a list of bridge devices. - Cleaned up various unused methods and comments across network-related files. 
Signed-off-by: sparkzky --- kernel/src/Makefile | 2 +- kernel/src/driver/net/bridge.rs | 230 +++++---------- kernel/src/driver/net/e1000e/e1000e_driver.rs | 1 - kernel/src/driver/net/loopback.rs | 1 - kernel/src/driver/net/mod.rs | 25 +- kernel/src/driver/net/napi.rs | 273 ++++++++++++++++++ kernel/src/driver/net/veth.rs | 182 ++++-------- kernel/src/driver/net/virtio_net.rs | 1 - kernel/src/net/net_core.rs | 2 +- kernel/src/net/routing.rs | 73 +---- kernel/src/net/socket/inet/datagram/mod.rs | 25 +- .../netlink/route/message/segment/link.rs | 1 + kernel/src/process/namespace/net_namespace.rs | 37 ++- 13 files changed, 451 insertions(+), 402 deletions(-) create mode 100644 kernel/src/driver/net/napi.rs diff --git a/kernel/src/Makefile b/kernel/src/Makefile index 6dd238e17..6ae2d4f6a 100644 --- a/kernel/src/Makefile +++ b/kernel/src/Makefile @@ -1,6 +1,6 @@ SUBDIR_ROOTS := . DIRS := . $(shell find $(SUBDIR_ROOTS) -type d) -GARBAGE_PATTERNS := *.o *.s~ *.s *.S~ *.c~ *.h~ kernel +GARBAGE_PATTERNS := *.o *.s~ *.s *.S~ *.c~ *.h~ GARBAGE := $(foreach DIR,$(DIRS),$(addprefix $(DIR)/,$(GARBAGE_PATTERNS))) DIR_LIB=libs diff --git a/kernel/src/driver/net/bridge.rs b/kernel/src/driver/net/bridge.rs index f3fdef585..0f888809d 100644 --- a/kernel/src/driver/net/bridge.rs +++ b/kernel/src/driver/net/bridge.rs @@ -1,18 +1,15 @@ use crate::{ - driver::net::{register_netdevice, veth::VethInterface, Iface, NetDeivceState, Operstate}, - init::initcall::INITCALL_DEVICE, - libs::{rwlock::RwLock, spinlock::SpinLock, wait_queue::WaitQueue}, - process::{ - kthread::{KernelThreadClosure, KernelThreadMechanism}, - namespace::net_namespace::INIT_NET_NAMESPACE, - ProcessState, + driver::net::{ + napi::napi_schedule, register_netdevice, veth::VethInterface, Iface, NetDeivceState, + Operstate, }, + init::initcall::INITCALL_DEVICE, + libs::{rwlock::RwLock, spinlock::SpinLock}, + process::namespace::net_namespace::{NetNamespace, INIT_NET_NAMESPACE}, time::Instant, }; -use alloc::boxed::Box; -use 
alloc::collections::VecDeque; +use alloc::string::ToString; use alloc::sync::Weak; -use alloc::vec::Vec; use alloc::{collections::BTreeMap, string::String, sync::Arc}; use core::sync::atomic::AtomicUsize; use hashbrown::HashMap; @@ -58,7 +55,7 @@ struct MacEntryRecord { pub struct BridgePort { pub id: BridgePortId, pub(super) bridge_enable: Arc, - pub(super) bridge_iface: Weak, + pub(super) bridge_driver_ref: Weak, // 当前接口状态?forwarding, learning, blocking? // mac mtu信息 } @@ -67,21 +64,19 @@ impl BridgePort { fn new( id: BridgePortId, device: Arc, - bridge: &Arc, + bridge: &Arc, ) -> Self { - BridgePort { + let port = BridgePort { id, - bridge_enable: device, - bridge_iface: Arc::downgrade(bridge), - } - } + bridge_enable: device.clone(), + bridge_driver_ref: Arc::downgrade(bridge), + }; - // fn mac(&self) -> EthernetAddress { - // self.bridge_enable.mac() - // } -} + device.set_common_bridge_data(&port); -type ReceivedFrame = (BridgePortId, Vec); + port + } +} #[derive(Debug)] pub struct Bridge { @@ -92,10 +87,6 @@ pub struct Bridge { mac_table: HashMap, // 配置参数,比如aging timeout, max age, hello time, forward delay // bridge_mac: EthernetAddress, - next_port_id: AtomicUsize, - wait_queue: Arc, - - rx_buf: VecDeque, } impl Bridge { @@ -104,17 +95,9 @@ impl Bridge { name: name.into(), ports: BTreeMap::new(), mac_table: HashMap::new(), - next_port_id: AtomicUsize::new(0), - wait_queue: Arc::new(WaitQueue::default()), - rx_buf: VecDeque::new(), } } - fn next_port_id(&self) -> BridgePortId { - self.next_port_id - .fetch_add(1, core::sync::atomic::Ordering::Relaxed) - } - pub fn add_port(&mut self, id: BridgePortId, port: BridgePort) { self.ports.insert(id, port); } @@ -208,6 +191,9 @@ impl Bridge { fn transmit_to_device(&self, device: &BridgePort, frame: &[u8]) { device.bridge_enable.receive_from_bridge(frame); + if let Some(napi) = device.bridge_enable.napi_struct() { + napi_schedule(napi); + } } pub fn sweep_mac_table(&mut self) { @@ -220,147 +206,89 @@ impl Bridge { 
}); } - // pub fn poll_blocking(&mut self) { - // use crate::sched::SchedMode; - // loop { - // let opt = self.rx_buf.pop_front(); - // if let Some((port_id, frame)) = opt { - // self.handle_frame(port_id, &frame); - // } else { - // log::info!("Bridge is going to sleep"); - // let _ = wq_wait_event_interruptible!(self.wait_queue, !self.rx_buf.is_empty(), {}); - // } - // } - // } + pub fn name(&self) -> &str { + &self.name + } } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct BridgeDriver { - pub inner: Arc>, - wait_queue: Arc, + pub inner: SpinLock, + pub netns: RwLock>, + self_ref: Weak, + next_port_id: AtomicUsize, } impl BridgeDriver { - pub fn new(name: &str) -> Self { - let inner = Arc::new(SpinLock::new(Bridge::new(name))); - let wait_queue = inner.lock().wait_queue.clone(); - - let driver = BridgeDriver { inner, wait_queue }; - - // let closure: Box i32 + Send + Sync + 'static> = Box::new(move || { - // driver_clone.poll_blocking(); - // 0 - // }); - // let closure = KernelThreadClosure::EmptyClosure((closure, ())); - // let name = format!("bridge_{}", name); - // let _pcb = KernelThreadMechanism::create_and_run(closure, name) - // .ok_or("") - // .expect("create bridge_poll thread failed"); - - driver - } - - pub fn add_port(&self, port: BridgePort) { - log::info!("Adding port with id: {}", port.id); - - self.inner.lock().add_port(port.id, port); + pub fn new(name: &str) -> Arc { + Arc::new_cyclic(|self_ref| BridgeDriver { + inner: SpinLock::new(Bridge::new(name)), + netns: RwLock::new(Weak::new()), + self_ref: self_ref.clone(), + next_port_id: AtomicUsize::new(0), + }) } - pub fn remove_port(&self, port_id: BridgePortId) { - self.inner.lock().remove_port(port_id); + fn next_port_id(&self) -> BridgePortId { + self.next_port_id + .fetch_add(1, core::sync::atomic::Ordering::Relaxed) } - fn poll_blocking(&self) { - use crate::sched::SchedMode; - - loop { - let mut inner = self.inner.lock(); - while let Some((port_id, frame)) = inner.rx_buf.pop_front() 
{ - inner.handle_frame(port_id, &frame); + pub fn add_device(&self, device: Arc) { + if let Some(netns) = self.netns() { + if !Arc::ptr_eq( + &netns, + &device.net_namespace().unwrap_or(INIT_NET_NAMESPACE.clone()), + ) { + log::warn!("Port and bridge are in different net namespaces"); + return; } - drop(inner); - // log::info!("Bridge is going to sleep"); - let _ = wq_wait_event_interruptible!( - self.wait_queue, - !self.inner.lock().rx_buf.is_empty(), - {} - ); } - // inner.poll_blocking(); - } + let port = BridgePort::new( + self.next_port_id(), + device.clone(), + &self.self_ref.upgrade().unwrap(), + ); + log::info!("Adding port with id: {}", port.id); - pub fn enqueue_frame(&self, port_id: BridgePortId, frame: &Vec) { - { - let mut bridge = self.inner.lock(); - log::info!("Enqueuing frame on port {}: {:?}", port_id, frame); - log::warn!("{:?}", frame); - bridge.rx_buf.push_back((port_id, frame.clone())); - } - self.wait_queue.wakeup(Some(ProcessState::Blocked(true))); + self.inner.lock().add_port(port.id, port); } -} - -pub struct BridgeIface { - pub driver: BridgeDriver, - self_ref: Weak, -} -impl BridgeIface { - pub fn new(driver: BridgeDriver) -> Arc { - let name = driver.inner.lock().name.clone(); - - let iface = Arc::new_cyclic(|me| BridgeIface { - driver, - self_ref: me.clone(), - }); - let iface_clone = iface.clone(); - - // 创建一个线程来处理桥接设备的轮询 - let closure: Box i32 + Send + Sync + 'static> = Box::new(move || { - iface_clone.poll_blocking(); - 0 - }); - let closure = KernelThreadClosure::EmptyClosure((closure, ())); - let name = format!("bridge_{}", name); - let _pcb = KernelThreadMechanism::create_and_run(closure, name) - .ok_or("") - .expect("create bridge_poll thread failed"); - - iface + pub fn remove_device(&self, device: Arc) { + let Some(common_data) = device.common_bridge_data() else { + log::warn!("Device is not part of any bridge"); + return; + }; + self.inner.lock().remove_port(common_data.id); } - pub fn add_port(&self, port_device: Arc) { - 
let id = self.driver.inner.lock().next_port_id(); - let port = BridgePort::new(id, port_device.clone(), &self.self_ref.upgrade().unwrap()); - - port_device.set_common_bridge_data(port.clone()); + pub fn handle_frame(&self, ingress_port_id: BridgePortId, frame: &[u8]) { + self.inner.lock().handle_frame(ingress_port_id, frame); + } - self.driver.add_port(port); + pub fn name(&self) -> String { + self.inner.lock().name().to_string() } - #[allow(unused)] - pub fn remove_port(&self, port_id: BridgePortId) { - self.driver.remove_port(port_id); + pub fn set_netns(&self, netns: &Arc) { + *self.netns.write() = Arc::downgrade(netns); } - pub fn poll_blocking(&self) { - self.driver.poll_blocking(); + pub fn netns(&self) -> Option> { + self.netns.read().upgrade() } } /// 可供桥接设备应该实现的 trait pub trait BridgeEnableDevice: Iface { + /// 接收来自桥的数据帧 fn receive_from_bridge(&self, frame: &[u8]); - // fn inner_driver(&self) -> Arc; - fn set_common_bridge_data(&self, _port: BridgePort); - // fn common_bridge_data(&self) -> Option; - // fn port_id(&self) -> Option { - // let Some(data) = self.common_bridge_data() else { - // return None; - // }; - // Some(data.id) - // } + /// 设置桥接相关的公共数据 + fn set_common_bridge_data(&self, _port: &BridgePort); + + /// 获取桥接相关的公共数据 + fn common_bridge_data(&self) -> Option; // fn bridge(&self) -> Weak { // let Some(data) = self.common_bridge_data() else { // return Weak::default(); @@ -372,7 +300,7 @@ pub trait BridgeEnableDevice: Iface { #[derive(Debug, Clone)] pub struct BridgeCommonData { pub id: BridgePortId, - pub bridge_iface: Weak, + pub bridge_driver_ref: Weak, } fn bridge_probe() { @@ -417,12 +345,12 @@ fn bridge_probe() { turn_on(&iface4); let bridge = BridgeDriver::new("bridge0"); - let iface = BridgeIface::new(bridge); + bridge.set_netns(&INIT_NET_NAMESPACE); + INIT_NET_NAMESPACE.insert_bridge(bridge.clone()); - // BRIDGE_DEVICES.write_irqsave().push(bridge.clone()); + bridge.add_device(iface3); + bridge.add_device(iface2); - 
iface.add_port(iface3); - iface.add_port(iface2); log::info!("Bridge device created"); } diff --git a/kernel/src/driver/net/e1000e/e1000e_driver.rs b/kernel/src/driver/net/e1000e/e1000e_driver.rs index c6c90180b..8a448e3db 100644 --- a/kernel/src/driver/net/e1000e/e1000e_driver.rs +++ b/kernel/src/driver/net/e1000e/e1000e_driver.rs @@ -214,7 +214,6 @@ impl E1000EInterface { iface_id, crate::driver::net::types::InterfaceType::EETHER, flags, - false, iface, ), name: format!("eth{}", iface_id), diff --git a/kernel/src/driver/net/loopback.rs b/kernel/src/driver/net/loopback.rs index ff41c8971..a08901b24 100644 --- a/kernel/src/driver/net/loopback.rs +++ b/kernel/src/driver/net/loopback.rs @@ -333,7 +333,6 @@ impl LoopbackInterface { iface_id, super::types::InterfaceType::LOOPBACK, flags, - false, iface, ), inner: SpinLock::new(InnerLoopbackInterface { diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index 1e27aefae..5c2b81dc1 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -22,6 +22,7 @@ mod dma; pub mod e1000e; pub mod irq_handle; pub mod loopback; +pub mod napi; pub mod sysfs; pub mod types; pub mod veth; @@ -88,14 +89,6 @@ pub trait Iface: crate::driver::base::device::Device { /// - `false`:表示没有网络事件 fn poll(&self) -> bool; - /// # `poll_blocking` - /// 用于在阻塞模式下轮询网卡 - /// ## 参数 - /// - `can_recv_fn` :一个函数指针,用于判断是否可以接收数据 - /// ## 返回值 - /// - 该函数不返回任何值,但会在满足条件时阻塞当前线程,直到可以接收数据。 - fn poll_blocking(&self, _can_recv_fn: &dyn Fn() -> bool) {} - /// # `update_ip_addrs` /// 用于更新接口的 IP 地址 /// ## 参数 @@ -154,6 +147,12 @@ pub trait Iface: crate::driver::base::device::Device { } fn mtu(&self) -> usize; + + /// # 获取当前iface的napi结构体 + /// 默认返回None,表示不支持napi + fn napi_struct(&self) -> Option> { + None + } } /// 网络设备的公共数据 @@ -205,9 +204,6 @@ pub struct IfaceCommon { port_manager: PortManager, /// 下次轮询的时间 poll_at_ms: core::sync::atomic::AtomicU64, - /// 默认网卡标识 - /// TODO: 此字段设置目的是解决对bind unspecified地址的分包问题,需要在inet实现多网卡监听或路由子系统实现后移除 
- default_iface: bool, /// 网络命名空间 net_namespace: RwLock>, // 路由相关数据 @@ -228,7 +224,6 @@ impl IfaceCommon { iface_id: usize, type_: InterfaceType, flags: InterfaceFlags, - default_iface: bool, iface: smoltcp::iface::Interface, ) -> Self { let router_common_data = RouterEnableDeviceCommon::default(); @@ -243,7 +238,6 @@ impl IfaceCommon { bounds: RwLock::new(Vec::new()), port_manager: PortManager::default(), poll_at_ms: core::sync::atomic::AtomicU64::new(0), - default_iface, net_namespace: RwLock::new(Weak::new()), router_common_data, flags, @@ -351,11 +345,6 @@ impl IfaceCommon { } } - // TODO: 需要在inet实现多网卡监听或路由子系统实现后移除 - pub fn is_default_iface(&self) -> bool { - self.default_iface - } - pub fn ipv4_addr(&self) -> Option { self.smol_iface.lock().ipv4_addr() } diff --git a/kernel/src/driver/net/napi.rs b/kernel/src/driver/net/napi.rs new file mode 100644 index 000000000..6541c89d7 --- /dev/null +++ b/kernel/src/driver/net/napi.rs @@ -0,0 +1,273 @@ +use crate::driver::net::Iface; +use crate::init::initcall::INITCALL_SUBSYS; +use crate::libs::spinlock::{SpinLock, SpinLockGuard}; +use crate::libs::wait_queue::WaitQueue; +use crate::process::kthread::{KernelThreadClosure, KernelThreadMechanism}; +use crate::process::ProcessState; +use alloc::boxed::Box; +use alloc::string::ToString; +use alloc::sync::{Arc, Weak}; +use alloc::vec::Vec; +use core::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use system_error::SystemError; +use unified_init::macros::unified_init; + +lazy_static! 
{ + //todo 按照软中断的做法,这里应该是每个CPU一个列表,但目前只实现单CPU版本 + static ref GLOBAL_NAPI_MANAGER: Arc = + NapiManager::new(); +} + +/// # NAPI 结构体 +/// +/// https://elixir.bootlin.com/linux/v6.13/source/include/linux/netdevice.h#L359 +#[derive(Debug)] +pub struct NapiStruct { + /// NAPI实例状态 + pub state: AtomicU32, + /// NAPI实例权重,表示每次轮询时处理的最大包数 + pub weight: usize, + /// 唯一id + pub napi_id: usize, + /// 指向所属网卡的弱引用 + pub net_device: Weak, +} + +impl NapiStruct { + pub fn new(net_device: Arc, weight: usize) -> Arc { + Arc::new(Self { + state: AtomicU32::new(NapiState::empty().bits()), + weight, + napi_id: net_device.nic_id(), + net_device: Arc::downgrade(&net_device), + }) + } + + pub fn poll(&self) -> bool { + // log::info!("NAPI instance {} polling", self.napi_id); + // 获取网卡的强引用 + if let Some(iface) = self.net_device.upgrade() { + // 这里的weight原意是此次执行可以处理的包,如果超过了这个数就交给专门的内核线程(ksoftirqd)继续处理 + // 但目前我们就是在相当于ksoftirqd里面处理,如果在weight之内发现没数据包被处理了,在直接返回 + // 如果超过weight,返回true,表示还有工作没做完,会在下一次轮询继续处理 + // 因此语义是相同的 + for _ in 0..self.weight { + if !iface.poll() { + return false; + } + } + } else { + log::error!( + "NAPI instance {}: associated net device is gone", + self.napi_id + ); + } + + true + } +} + +bitflags! { + /// # NAPI状态标志 + /// + /// https://elixir.bootlin.com/linux/v6.13/source/include/linux/netdevice.h#L398 + pub struct NapiState:u32{ + /// Poll is scheduled. 这是最核心的状态,表示NAPI实例已被调度, + /// 存在于某个CPU的poll_list中等待处理。 + const SCHED = 1 << 0; + /// Missed a poll. 如果在NAPI实例被调度后但在实际处理前又有新的数据到达, + const MISSED = 1 << 1; + /// Disable pending. NAPI正在被禁用,不应再被调度。 + const DISABLE = 1 << 2; + const NPSVC = 1 << 3; + /// NAPI added to system lists. 表示NAPI实例已注册到设备中。 + const LISTED = 1 << 4; + const NO_BUSY_POLL = 1 << 5; + const IN_BUSY_POLL = 1 << 6; + const PREFER_BUSY_POLL = 1 << 7; + /// The poll is performed inside its own thread. 
+ /// 一个可选的高级功能,表示此NAPI由专用内核线程处理。 + const THREADED = 1 << 8; + const SCHED_THREADED = 1 << 9; + } +} + +#[inline(never)] +#[unified_init(INITCALL_SUBSYS)] +pub fn napi_init() -> Result<(), SystemError> { + // 软中断做法 + // let napi_handler = Arc::new(NapiSoftirq::default()); + // softirq_vectors() + // .register_softirq(SoftirqNumber::NetReceive, napi_handler) + // .expect("Failed to register napi softirq"); + + // 软中断的方式无法唤醒 :( + // 使用一个专门的内核线程来处理NAPI轮询,模拟软中断的行为,相当于ksoftirq :) + + let closure: Box i32 + Send + Sync + 'static> = Box::new(move || { + net_rx_action(); + 0 + }); + let closure = KernelThreadClosure::EmptyClosure((closure, ())); + let name = "napi_handler".to_string(); + let _pcb = KernelThreadMechanism::create_and_run(closure, name) + .ok_or("") + .expect("create napi_handler thread failed"); + + log::info!("napi initialized successfully"); + Ok(()) +} + +fn net_rx_action() { + use crate::sched::SchedMode; + + loop { + // 这里直接将全局的NAPI管理器的napi_list取出,清空全局的列表,避免占用锁时间过长 + let mut inner = GLOBAL_NAPI_MANAGER.inner(); + let mut poll_list = inner.napi_list.clone(); + inner.napi_list.clear(); + drop(inner); + + // log::info!("NAPI softirq processing {} instances", poll_list.len()); + + let size = poll_list.len(); + // 如果此时长度为0,则让当前进程休眠,等待被唤醒 + if size == 0 { + GLOBAL_NAPI_MANAGER + .inner() + .has_pending_signal + .store(false, Ordering::SeqCst); + } + + for _ in 0..size { + let Some(napi) = poll_list.pop() else { + break; + }; + + let has_work_left = napi.poll(); + log::info!("yes"); + + if has_work_left { + poll_list.push(napi); + } else { + napi_complete(napi); + } + } + + // log::info!("napi softirq iteration complete") + + // 在这种情况下,poll_list 中仍然有待处理的 NAPI 实例,压回队列,等待下一次唤醒时处理 + if !poll_list.is_empty() { + GLOBAL_NAPI_MANAGER.inner().napi_list.extend(poll_list); + } + + let _ = wq_wait_event_interruptible!( + GLOBAL_NAPI_MANAGER.wait_queue(), + GLOBAL_NAPI_MANAGER + .inner() + .has_pending_signal + .load(Ordering::SeqCst), + {} + ); + } +} + +/// 
标记这个napi任务已经完成 +pub fn napi_complete(napi: Arc) { + napi.state + .fetch_and(!NapiState::SCHED.bits(), Ordering::SeqCst); +} + +/// 标记这个napi任务加入处理队列,已被调度 +pub fn napi_schedule(napi: Arc) { + let current_state = NapiState::from_bits_truncate( + napi.state + .fetch_or(NapiState::SCHED.bits(), Ordering::SeqCst), + ); + + if !current_state.contains(NapiState::SCHED) { + let new_state = current_state.union(NapiState::SCHED); + // log::info!("NAPI instance {} scheduled", napi.napi_id); + napi.state.store(new_state.bits(), Ordering::SeqCst); + } + + let mut inner = GLOBAL_NAPI_MANAGER.inner(); + inner.napi_list.push(napi); + inner.has_pending_signal.store(true, Ordering::SeqCst); + + GLOBAL_NAPI_MANAGER.wakeup(); + + // softirq_vectors().raise_softirq(SoftirqNumber::NetReceive); +} + +pub struct NapiManager { + inner: SpinLock, + wait_queue: WaitQueue, +} + +impl NapiManager { + pub fn new() -> Arc { + let inner = SpinLock::new(NapiManagerInner { + has_pending_signal: AtomicBool::new(false), + napi_list: Vec::new(), + }); + Arc::new(Self { + inner, + wait_queue: WaitQueue::default(), + }) + } + + pub fn inner(&self) -> SpinLockGuard<'_, NapiManagerInner> { + self.inner.lock() + } + + pub fn wait_queue(&self) -> &WaitQueue { + &self.wait_queue + } + + pub fn wakeup(&self) { + self.wait_queue.wakeup(Some(ProcessState::Blocked(true))); + } +} + +pub struct NapiManagerInner { + has_pending_signal: AtomicBool, + napi_list: Vec>, +} + +// 下面的是软中断的做法,无法唤醒,做个记录 + +// #[derive(Debug)] +// pub struct NapiSoftirq { +// running: AtomicBool, +// } + +// impl Default for NapiSoftirq { +// fn default() -> Self { +// Self { +// running: AtomicBool::new(false), +// } +// } +// } + +// impl SoftirqVec for NapiSoftirq { +// fn run(&self) { +// log::info!("NAPI softirq running"); +// if self +// .running +// .compare_exchange( +// false, +// true, +// core::sync::atomic::Ordering::SeqCst, +// core::sync::atomic::Ordering::SeqCst, +// ) +// .is_ok() +// { +// net_rx_action(); +// self.running 
+// .store(false, core::sync::atomic::Ordering::SeqCst); +// } else { +// log::warn!("NAPI softirq is already running"); +// } +// } +// } diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index 65631452f..0b0e2639c 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -11,18 +11,17 @@ use crate::driver::base::kobject::{ }; use crate::driver::base::kset::KSet; use crate::driver::net::bridge::{BridgeCommonData, BridgePort}; +use crate::driver::net::napi::{napi_schedule, NapiStruct}; use crate::driver::net::register_netdevice; use crate::driver::net::types::InterfaceFlags; use crate::filesystem::kernfs::KernFSInode; use crate::init::initcall::INITCALL_DEVICE; use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; use crate::libs::spinlock::{SpinLock, SpinLockGuard}; -use crate::libs::wait_queue::WaitQueue; use crate::net::generate_iface_id; use crate::net::routing::{RouteEntry, RouterEnableDevice}; use crate::process::namespace::net_namespace::{NetNamespace, INIT_NET_NAMESPACE}; -use crate::process::{ProcessManager, ProcessState}; -use crate::sched::SchedMode; +use crate::process::ProcessManager; use alloc::collections::VecDeque; use alloc::fmt::Debug; use alloc::string::{String, ToString}; @@ -58,48 +57,31 @@ impl Veth { self.peer = Arc::downgrade(peer); } - pub fn send_to_peer(&self, data: &Vec) { + pub fn send_to_peer(&self, data: &[u8]) { if let Some(peer) = self.peer.upgrade() { // log::info!("Veth {} trying to send", self.name); - if let Some(bridge_common_data) = peer.inner.lock().bridge_common_data.as_ref() { - // log::info!("Veth {} sending data to bridge", self.name); - Self::to_bridge(bridge_common_data, data); - return; - } - - // 如果是路由设备,则将数据发送到路由器 - if self.to_router(data) { - // log::info!("Veth {} sent data to router", self.name); - return; - } - Self::to_peer(&peer, data); } } pub(self) fn to_peer(peer: &Arc, data: &[u8]) { - let mut peer_veth = 
peer.driver.force_get_mut().inner.lock_irqsave(); + let mut peer_veth = peer.driver.inner.lock(); peer_veth.rx_queue.push_back(data.to_vec()); log::info!("Veth {} received data from peer", peer.name); log::info!("{:?}", peer_veth.rx_queue); drop(peer_veth); - // 唤醒对端正在等待的进程 - peer.wake_up(); - - if let Some(ns) = peer.net_namespace() { - ns.wakeup_poll_thread(); - } + napi_schedule(peer.napi_struct()); } - fn to_bridge(bridge_data: &BridgeCommonData, data: &Vec) { + fn to_bridge(bridge_data: &BridgeCommonData, data: &[u8]) { // log::info!("Veth {} sending data to bridge", self.name); - let Some(bridge) = bridge_data.bridge_iface.upgrade() else { + let Some(bridge) = bridge_data.bridge_driver_ref.upgrade() else { log::warn!("Bridge has been dropped"); return; }; - bridge.driver.enqueue_frame(bridge_data.id, data) + bridge.handle_frame(bridge_data.id, data); } /// 经过路由发送,返回是否发送成功 @@ -129,7 +111,28 @@ impl Veth { pub fn recv_from_peer(&mut self) -> Option> { // log::info!("Veth {} trying to receive", self.name); - self.rx_queue.pop_front() + let data = self.rx_queue.pop_front()?; + + if let Some(bridge_common_data) = self + .self_iface_ref + .upgrade() + .unwrap() + .inner + .lock() + .bridge_common_data + .as_ref() + { + // log::info!("Veth {} sending data to bridge", self.name); + Self::to_bridge(bridge_common_data, &data); + return None; + } + + // 说明获取的包发给进入路由了,无须返回 + if self.to_router(&data) { + return None; + } + + Some(data) } pub fn name(&self) -> &str { @@ -162,7 +165,7 @@ impl VethDriver { } pub fn name(&self) -> String { - self.inner.lock_irqsave().name().to_string() + self.inner.lock().name().to_string() } } @@ -177,7 +180,7 @@ impl phy::TxToken for VethTxToken { { let mut buf = vec![0; len]; let result = f(&mut buf); - self.driver.inner.lock_irqsave().send_to_peer(&buf); + self.driver.inner.lock().send_to_peer(&buf); result } } @@ -236,7 +239,7 @@ impl phy::Device for VethDriver { &mut self, _timestamp: smoltcp::time::Instant, ) -> 
Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { - let mut guard = self.inner.lock_irqsave(); + let mut guard = self.inner.lock(); guard.recv_from_peer().map(|buf| { // log::info!("VethDriver received data: {:?}", buf); ( @@ -264,7 +267,7 @@ pub struct VethInterface { common: IfaceCommon, inner: SpinLock, locked_kobj_state: LockedKObjectState, - wait_queue: WaitQueue, + napi_struct: SpinLock>>, } #[derive(Debug)] @@ -290,15 +293,8 @@ impl Default for VethCommonData { } impl VethInterface { - pub fn has_data(&self) -> bool { - let driver = self.driver.force_get_mut(); - let inner = driver.inner.lock_irqsave(); - !inner.rx_queue.is_empty() - } - - #[allow(unused)] pub fn peer_veth(&self) -> Arc { - self.inner.lock_irqsave().peer_veth.upgrade().unwrap() + self.inner.lock().peer_veth.upgrade().unwrap() } pub fn new(driver: VethDriver) -> Arc { @@ -331,17 +327,13 @@ impl VethInterface { let device = Arc::new(VethInterface { name, driver: VethDriverWarpper(UnsafeCell::new(driver.clone())), - common: IfaceCommon::new( - iface_id, - super::types::InterfaceType::EETHER, - flags, - false, - iface, - ), + common: IfaceCommon::new(iface_id, super::types::InterfaceType::EETHER, flags, iface), inner: SpinLock::new(VethCommonData::default()), locked_kobj_state: LockedKObjectState::default(), - wait_queue: WaitQueue::default(), + napi_struct: SpinLock::new(None), }); + let napi_struct = NapiStruct::new(device.clone(), 10); + *device.napi_struct.lock() = Some(napi_struct); driver.inner.lock().self_iface_ref = Arc::downgrade(&device); @@ -350,9 +342,9 @@ impl VethInterface { } pub fn set_peer_iface(&self, peer: &Arc) { - let mut inner = self.inner.lock_irqsave(); + let mut inner = self.inner.lock(); inner.peer_veth = Arc::downgrade(peer); - self.driver.inner.lock_irqsave().set_peer_iface(peer); + self.driver.inner.lock().set_peer_iface(peer); } pub fn new_pair(name1: &str, name2: &str) -> (Arc, Arc) { @@ -382,6 +374,7 @@ impl VethInterface { iface.update_ip_addrs(|ip_addrs| { 
ip_addrs.push(cidr).expect("Push ipCidr failed: full"); }); + self.common.router_common_data.ip_addrs.write().push(cidr); // // 直接更新对端的arp_table // self.inner.lock().peer_veth.upgrade().map(|peer| { @@ -436,8 +429,8 @@ impl VethInterface { // }); // } - pub fn wake_up(&self) { - self.wait_queue.wakeup(Some(ProcessState::Blocked(true))); + pub fn napi_struct(&self) -> Arc { + self.napi_struct.lock().as_ref().unwrap().clone() } } @@ -582,35 +575,6 @@ impl Iface for VethInterface { } } - fn poll_blocking(&self, can_stop_fn: &dyn Fn() -> bool) { - log::info!("VethInterface {} polling block", self.name); - - loop { - // 检查是否有数据可用 - self.common.poll(self.driver.force_get_mut()); - - let has_data = self.has_data(); - - // 外部 socket 是否可以接收数据,如果是的话就可以退出loop了 - let can_stop = can_stop_fn(); - - if can_stop { - break; - } - - // 没有数据可用时,进入等待队列 - // 如果有数据可用,则直接跳出循环 - log::info!("VethInterface {} waiting for data", self.name); - if !has_data { - let _ = wq_wait_event_interruptible!( - self.wait_queue, - self.has_data() || can_stop_fn(), - {} - ); - } - } - } - fn poll(&self) -> bool { // log::info!("VethInterface {} polling normal", self.name); self.common.poll(self.driver.force_get_mut()) @@ -649,6 +613,10 @@ impl Iface for VethInterface { .capabilities() .max_transmission_unit } + + fn napi_struct(&self) -> Option> { + Some(self.napi_struct()) + } } impl BridgeEnableDevice for VethInterface { @@ -660,11 +628,11 @@ impl BridgeEnableDevice for VethInterface { if self .inner - .lock_irqsave() + .lock() .bridge_common_data .as_ref() .unwrap() - .bridge_iface + .bridge_driver_ref .upgrade() .is_some() { @@ -682,23 +650,23 @@ impl BridgeEnableDevice for VethInterface { log::info!("returning"); } - fn set_common_bridge_data(&self, port: BridgePort) { + fn set_common_bridge_data(&self, port: &BridgePort) { // log::info!("Now set bridge port data for {}", self.name); - let mut inner = self.inner.lock_irqsave(); + let mut inner = self.inner.lock(); let data = BridgeCommonData { id: 
port.id, - bridge_iface: port.bridge_iface.clone(), + bridge_driver_ref: port.bridge_driver_ref.clone(), }; inner.bridge_common_data = Some(data); } - // fn common_bridge_data(&self) -> Option { - // self.inner().bridge_common_data.clone() - // } + fn common_bridge_data(&self) -> Option { + self.inner().bridge_common_data.clone() + } } impl RouterEnableDevice for VethInterface { - fn route_and_send(&self, next_hop: IpAddress, ip_packet: &[u8]) { + fn route_and_send(&self, next_hop: &IpAddress, ip_packet: &[u8]) { log::info!( "VethInterface {} routing packet to {}", self.iface_name(), @@ -719,11 +687,7 @@ impl RouterEnableDevice for VethInterface { frame.extend_from_slice(ip_packet); // 发送到对端 - self.driver - .force_get_mut() - .inner - .lock_irqsave() - .send_to_peer(&frame); + self.driver.inner.lock().send_to_peer(&frame); } fn is_my_ip(&self, ip: IpAddress) -> bool { @@ -734,36 +698,6 @@ impl RouterEnableDevice for VethInterface { } } -// pub fn veth_probe(name1: &str, name2: &str) -> (Arc, Arc) { -// let (iface1, iface2) = VethInterface::new_pair(name1, name2); - -// let addr1 = IpAddress::v4(10, 0, 0, 1); -// let cidr1 = IpCidr::new(addr1, 24); -// iface1.update_ip_addrs(cidr1); - -// let addr2 = IpAddress::v4(10, 0, 0, 2); -// let cidr2 = IpCidr::new(addr2, 24); -// iface2.update_ip_addrs(cidr2); - -// // 添加默认路由 -// iface1.add_default_route_to_peer(addr2); -// iface2.add_default_route_to_peer(addr1); - -// let turn_on = |a: &Arc| { -// a.set_net_state(NetDeivceState::__LINK_STATE_START); -// a.set_operstate(Operstate::IF_OPER_UP); -// // NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); -// INIT_NET_NAMESPACE.add_device(a.clone()); -// a.common().set_net_namespace(INIT_NET_NAMESPACE.clone()); -// register_netdevice(a.clone()).expect("register veth device failed"); -// }; - -// turn_on(&iface1); -// turn_on(&iface2); - -// (iface1, iface2) -// } - fn veth_route_test() { let (iface_ns1, iface_host1) = VethInterface::new_pair("veth-ns1", "veth-host1"); 
let (iface_ns2, iface_host2) = VethInterface::new_pair("veth-ns2", "veth-host2"); diff --git a/kernel/src/driver/net/virtio_net.rs b/kernel/src/driver/net/virtio_net.rs index 49f941879..42b3d0d69 100644 --- a/kernel/src/driver/net/virtio_net.rs +++ b/kernel/src/driver/net/virtio_net.rs @@ -516,7 +516,6 @@ impl VirtioInterface { iface_id, crate::driver::net::types::InterfaceType::EETHER, flags, - true, iface, ), inner: SpinLock::new(InnerVirtIOInterface { diff --git a/kernel/src/net/net_core.rs b/kernel/src/net/net_core.rs index 5a939b2eb..4b99deb34 100644 --- a/kernel/src/net/net_core.rs +++ b/kernel/src/net/net_core.rs @@ -13,7 +13,7 @@ pub fn net_init() -> Result<(), SystemError> { fn dhcp_query() -> Result<(), SystemError> { // let binding = NET_DEVICES.write_irqsave(); - let binding = INIT_NET_NAMESPACE.device_list_write(); + let binding = INIT_NET_NAMESPACE.device_list_mut(); let net_face = binding .iter() diff --git a/kernel/src/net/routing.rs b/kernel/src/net/routing.rs index 34ee3d58a..ab75f48a0 100644 --- a/kernel/src/net/routing.rs +++ b/kernel/src/net/routing.rs @@ -1,13 +1,7 @@ use crate::driver::net::Iface; use crate::libs::rwlock::RwLock; -use crate::libs::wait_queue::WaitQueue; -use crate::process::kthread::KernelThreadClosure; -use crate::process::kthread::KernelThreadMechanism; use crate::process::namespace::net_namespace::NetNamespace; use crate::process::namespace::net_namespace::INIT_NET_NAMESPACE; -use crate::process::ProcessState; -use alloc::boxed::Box; -use alloc::collections::VecDeque; use alloc::string::{String, ToString}; use alloc::sync::{Arc, Weak}; use alloc::vec::Vec; @@ -95,37 +89,15 @@ pub struct Router { /// 路由表 //todo 后面再优化LC-trie,现在先简单用一个Vec route_table: RwLock, pub ns: RwLock>, - - wait_queue: WaitQueue, - rx_frames: RwLock>, } impl Router { pub fn new(name: String) -> Arc { - let router = Arc::new(Self { + Arc::new(Self { name: name.clone(), route_table: RwLock::new(RouteTable::default()), - wait_queue: WaitQueue::default(), - 
rx_frames: RwLock::new(VecDeque::new()), ns: RwLock::new(Weak::default()), - }); - - let self_clone: Arc = router.clone(); - - // 创建一个线程来处理桥接设备的轮询 - let closure: Box i32 + Send + Sync + 'static> = Box::new(move || { - self_clone.poll_blocking(); - 0 - }); - let closure = KernelThreadClosure::EmptyClosure((closure, ())); - let name = name + "_poll"; - log::info!("Creating router polling thread: {}", name); - let _pcb = KernelThreadMechanism::create_and_run(closure, name) - .ok_or("") - .expect("create router_poll thread failed"); - - // log::info!("Router polling thread created"); - router + }) } /// 创建一个空的Router实例,主要用于初始化网络命名空间时使用 @@ -134,8 +106,6 @@ impl Router { Arc::new(Self { name: "empty_router".to_string(), route_table: RwLock::new(RouteTable::default()), - wait_queue: WaitQueue::default(), - rx_frames: RwLock::new(VecDeque::new()), ns: RwLock::new(Weak::default()), }) } @@ -190,32 +160,6 @@ impl Router { .entries .retain(|route| route.interface.strong_count() > 0); } - - pub fn enqueue_frame(&self, frame: RouterFrame) { - self.rx_frames.write().push_back(frame); - self.wait_queue.wakeup(Some(ProcessState::Blocked(true))); - } - - fn poll_blocking(&self) { - use crate::sched::SchedMode; - - loop { - self.poll(); - - log::info!("Router is going to sleep"); - let _ = - wq_wait_event_interruptible!(self.wait_queue, !self.rx_frames.read().is_empty(), { - }); - } - } - - fn poll(&self) { - let mut inner = self.rx_frames.write(); - while let Some((decision, frame)) = inner.pop_front() { - decision.interface.route_and_send(decision.next_hop, &frame); - } - log::info!("Router polled all frames"); - } } /// 获取初始化网络命名空间下的路由表 @@ -308,10 +252,11 @@ pub trait RouterEnableDevice: Iface { // //todo 这里应该重新计算IP校验和,为了简化先跳过 // } - let frame = (decision, modified_ip_packet); - // 交给出接口进行发送 - self.netns_router().enqueue_frame(frame); + let next_hop = &decision.next_hop; + decision + .interface + .route_and_send(next_hop, &modified_ip_packet); log::info!("Routed packet from {} to 
{} ", self.iface_name(), dst_ip,); Ok(()) @@ -322,7 +267,7 @@ pub trait RouterEnableDevice: Iface { /// /// todo 在这里查询arp_table,找到目标IP对应的mac地址然后拼接,如果找不到的话就需要主动发送arp请求去查询mac地址了,手伸不到smoltcp内部:( /// 后续需要将arp查询的逻辑从smoltcp中抽离出来 - fn route_and_send(&self, next_hop: IpAddress, ip_packet: &[u8]); + fn route_and_send(&self, next_hop: &IpAddress, ip_packet: &[u8]); /// 检查IP地址是否是当前接口的IP fn is_my_ip(&self, ip: IpAddress) -> bool; @@ -333,15 +278,13 @@ pub trait RouterEnableDevice: Iface { } } -pub type RouterFrame = (RouteDecision, Vec); - /// # 每一个`RouterEnableDevice`应该有的公共数据,包含 /// - 当前接口的arp_table,记录邻居(//todo:将网卡的发送以及处理逻辑从smoltcp中移动出来,目前只是简单为veth实现这个,因为可以直接查到对端的mac地址) #[derive(Debug)] pub struct RouterEnableDeviceCommon { /// 当前接口的邻居缓存 // pub arp_table: RwLock>, - /// 当前接口的IP地址列表 + /// 当前接口的IP地址列表(因为如果直接通过smoltcp获取ip的话可能导致死锁,因此则这里维护一份) pub ip_addrs: RwLock>, } diff --git a/kernel/src/net/socket/inet/datagram/mod.rs b/kernel/src/net/socket/inet/datagram/mod.rs index a857fcdcc..16dc94503 100644 --- a/kernel/src/net/socket/inet/datagram/mod.rs +++ b/kernel/src/net/socket/inet/datagram/mod.rs @@ -165,22 +165,6 @@ impl UdpSocket { return event; } - /// 这个方法会阻塞当前线程,直到有数据可读 - /// 通过 poll_blocking 来等待数据的到来 - pub(self) fn wait_for_recv(&self) { - use crate::sched::SchedMode; - let guard = self.inner.read(); - let inner = guard.as_ref(); - if let UdpInner::Bound(bound) = inner.unwrap() { - let rem = bound.inner().iface().clone(); - drop(guard); - let self_ref = self.self_ref.upgrade().unwrap().clone(); - let can_recv = move || self_ref.can_recv(); - rem.poll_blocking(&can_recv); - } - let _ = wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {}); - } - pub fn netns(&self) -> Arc { self.netns.clone() } @@ -249,13 +233,15 @@ impl Socket for UdpSocket { } fn recv(&self, buffer: &mut [u8], flags: PMSG) -> Result { + use crate::sched::SchedMode; + return if self.is_nonblock() || flags.contains(PMSG::DONTWAIT) { self.try_recv(buffer) } else { loop { match self.try_recv(buffer) 
{ Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { - self.wait_for_recv(); + wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {})?; } result => break result, } @@ -270,6 +256,7 @@ impl Socket for UdpSocket { flags: PMSG, address: Option, ) -> Result<(usize, Endpoint), SystemError> { + use crate::sched::SchedMode; // could block io if let Some(endpoint) = address { self.connect(endpoint)?; @@ -281,8 +268,8 @@ impl Socket for UdpSocket { loop { match self.try_recv(buffer) { Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { - self.wait_for_recv(); - log::info!("UdpSocket::recv_from: wake up"); + wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {})?; + log::debug!("UdpSocket::recv_from: wake up"); } result => break result, } diff --git a/kernel/src/net/socket/netlink/route/message/segment/link.rs b/kernel/src/net/socket/netlink/route/message/segment/link.rs index ad0aa1dd1..c386712d5 100644 --- a/kernel/src/net/socket/netlink/route/message/segment/link.rs +++ b/kernel/src/net/socket/netlink/route/message/segment/link.rs @@ -83,6 +83,7 @@ impl From for CIfinfoMsg { bitflags! { /// Flags in [`CIfinfoMsg`]. + /// . 
pub struct LinkMessageFlags: u32 { // sysfs const IFF_UP = 1<<0; diff --git a/kernel/src/process/namespace/net_namespace.rs b/kernel/src/process/namespace/net_namespace.rs index f624c03dc..2f21a95f3 100644 --- a/kernel/src/process/namespace/net_namespace.rs +++ b/kernel/src/process/namespace/net_namespace.rs @@ -1,4 +1,5 @@ use crate::arch::CurrentIrqArch; +use crate::driver::net::bridge::BridgeDriver; use crate::driver::net::loopback::{generate_loopback_iface_default, LoopbackInterface}; use crate::exception::InterruptArch; use crate::init::initcall::INITCALL_SUBSYS; @@ -69,6 +70,8 @@ pub struct NetNamespace { /// 这个列表在中断上下文会使用到,因此需要irqsave /// 没有放在InnerNetNamespace里面,独立出来,方便管理 device_list: RwLock>>, + ///当前网络命名空间下的桥接设备列表 + bridge_list: RwLock>>, // -- Netlink -- /// # 当前网络命名空间下的 Netlink 套接字表 @@ -115,6 +118,7 @@ impl NetNamespace { inner: RwLock::new(inner), net_poll_thread: SpinLock::new(None), device_list: RwLock::new(BTreeMap::new()), + bridge_list: RwLock::new(BTreeMap::new()), netlink_socket_table: NetlinkSocketTable::default(), netlink_kernel_socket: RwLock::new(generate_supported_netlink_kernel_sockets()), }); @@ -141,6 +145,7 @@ impl NetNamespace { inner: RwLock::new(inner), net_poll_thread: SpinLock::new(None), device_list: RwLock::new(BTreeMap::new()), + bridge_list: RwLock::new(BTreeMap::new()), netlink_socket_table: NetlinkSocketTable::default(), netlink_kernel_socket: RwLock::new(generate_supported_netlink_kernel_sockets()), }); @@ -162,12 +167,12 @@ impl NetNamespace { Self::new_empty(user_ns) } - pub fn device_list_write(&self) -> RwLockWriteGuard<'_, BTreeMap>> { - self.device_list.write_irqsave() + pub fn device_list_mut(&self) -> RwLockWriteGuard<'_, BTreeMap>> { + self.device_list.write() } pub fn device_list(&self) -> RwLockReadGuard<'_, BTreeMap>> { - self.device_list.read_irqsave() + self.device_list.read() } pub fn inner(&self) -> RwLockReadGuard<'_, InnerNetNamespace> { @@ -212,18 +217,20 @@ impl NetNamespace { pub fn add_device(&self, 
device: Arc) { device.set_net_namespace(self.self_ref.upgrade().unwrap()); - self.device_list - .write_irqsave() - .insert(device.nic_id(), device); + self.device_list_mut().insert(device.nic_id(), device); log::info!( "Network device added to namespace count: {:?}", - self.device_list.read_irqsave().len() + self.device_list().len() ); } pub fn remove_device(&self, nic_id: &usize) { - self.device_list.write_irqsave().remove(nic_id); + self.device_list_mut().remove(nic_id); + } + + pub fn insert_bridge(&self, bridge: Arc) { + self.bridge_list.write().insert(bridge.name(), bridge); } /// # 拉起网络命名空间的轮询线程 @@ -242,22 +249,12 @@ impl NetNamespace { fn polling(&self) { log::info!("net_poll thread started for namespace"); loop { - let mut has_work_done = false; for (_, iface) in self.device_list.read_irqsave().iter() { - // 这里检查poll的返回值,如果为 true 的话,则说明发生了网络事件,很有可能再次发生,则会再次调用poll - if iface.poll() { - has_work_done = true; - iface.poll(); - } - } - if has_work_done { - log::info!("fucking continue"); - continue; + iface.poll(); } let irq_guard = unsafe { CurrentIrqArch::save_and_disable_irq() }; - ProcessManager::mark_sleep(true) - .expect("clocksource_watchdog_kthread:mark sleep failed"); + ProcessManager::mark_sleep(true).expect("netns_poll_kthread:mark sleep failed"); drop(irq_guard); // log::info!("net_poll thread going to sleep"); schedule(SchedMode::SM_NONE); From 29423ca2c67319ec20d221811fa457d02636c432 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Fri, 29 Aug 2025 19:18:34 +0800 Subject: [PATCH 30/36] =?UTF-8?q?feat:=20=E5=B0=86virtio=E7=BD=91=E5=8D=A1?= =?UTF-8?q?=E7=9A=84=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=E7=A7=BB=E5=8A=A8?= =?UTF-8?q?=E8=BF=9Bksoftirqd=E4=B8=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/mod.rs | 8 +++-- kernel/src/driver/net/napi.rs | 11 ++---- kernel/src/driver/net/veth.rs | 19 ++++------ kernel/src/driver/net/virtio_net.rs | 56 
++++++++++++++++++++++------- 4 files changed, 59 insertions(+), 35 deletions(-) diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index 5c2b81dc1..0d49228c1 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -4,6 +4,7 @@ use alloc::{string::String, sync::Arc}; use core::net::Ipv4Addr; use sysfs::netdev_register_kobject; +use crate::driver::net::napi::NapiStruct; use crate::driver::net::types::{InterfaceFlags, InterfaceType}; use crate::libs::rwlock::RwLockReadGuard; use crate::net::routing::RouterEnableDeviceCommon; @@ -151,7 +152,7 @@ pub trait Iface: crate::driver::base::device::Device { /// # 获取当前iface的napi结构体 /// 默认返回None,表示不支持napi fn napi_struct(&self) -> Option> { - None + self.common().napi_struct.read().clone() } } @@ -206,8 +207,10 @@ pub struct IfaceCommon { poll_at_ms: core::sync::atomic::AtomicU64, /// 网络命名空间 net_namespace: RwLock>, - // 路由相关数据 + /// 路由相关数据 router_common_data: RouterEnableDeviceCommon, + /// NAPI 结构体 + napi_struct: RwLock>>, } impl fmt::Debug for IfaceCommon { @@ -242,6 +245,7 @@ impl IfaceCommon { router_common_data, flags, type_, + napi_struct: RwLock::new(None), } } diff --git a/kernel/src/driver/net/napi.rs b/kernel/src/driver/net/napi.rs index 6541c89d7..1321ef930 100644 --- a/kernel/src/driver/net/napi.rs +++ b/kernel/src/driver/net/napi.rs @@ -130,22 +130,17 @@ fn net_rx_action() { // log::info!("NAPI softirq processing {} instances", poll_list.len()); - let size = poll_list.len(); // 如果此时长度为0,则让当前进程休眠,等待被唤醒 - if size == 0 { + if poll_list.is_empty() { GLOBAL_NAPI_MANAGER .inner() .has_pending_signal .store(false, Ordering::SeqCst); } - for _ in 0..size { - let Some(napi) = poll_list.pop() else { - break; - }; - + while let Some(napi) = poll_list.pop() { let has_work_left = napi.poll(); - log::info!("yes"); + // log::info!("yes"); if has_work_left { poll_list.push(napi); diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index 0b0e2639c..334f41014 100644 
--- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -72,7 +72,12 @@ impl Veth { log::info!("{:?}", peer_veth.rx_queue); drop(peer_veth); - napi_schedule(peer.napi_struct()); + let Some(napi) = peer.napi_struct() else { + log::error!("Veth {} has no napi_struct", peer.name); + return; + }; + + napi_schedule(napi); } fn to_bridge(bridge_data: &BridgeCommonData, data: &[u8]) { @@ -267,7 +272,6 @@ pub struct VethInterface { common: IfaceCommon, inner: SpinLock, locked_kobj_state: LockedKObjectState, - napi_struct: SpinLock>>, } #[derive(Debug)] @@ -330,10 +334,9 @@ impl VethInterface { common: IfaceCommon::new(iface_id, super::types::InterfaceType::EETHER, flags, iface), inner: SpinLock::new(VethCommonData::default()), locked_kobj_state: LockedKObjectState::default(), - napi_struct: SpinLock::new(None), }); let napi_struct = NapiStruct::new(device.clone(), 10); - *device.napi_struct.lock() = Some(napi_struct); + *device.common.napi_struct.write() = Some(napi_struct); driver.inner.lock().self_iface_ref = Arc::downgrade(&device); @@ -428,10 +431,6 @@ impl VethInterface { // .expect("Add direct route failed"); // }); // } - - pub fn napi_struct(&self) -> Arc { - self.napi_struct.lock().as_ref().unwrap().clone() - } } impl KObject for VethInterface { @@ -613,10 +612,6 @@ impl Iface for VethInterface { .capabilities() .max_transmission_unit } - - fn napi_struct(&self) -> Option> { - Some(self.napi_struct()) - } } impl BridgeEnableDevice for VethInterface { diff --git a/kernel/src/driver/net/virtio_net.rs b/kernel/src/driver/net/virtio_net.rs index 42b3d0d69..d5642596b 100644 --- a/kernel/src/driver/net/virtio_net.rs +++ b/kernel/src/driver/net/virtio_net.rs @@ -29,7 +29,11 @@ use crate::{ kobject::{KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState}, kset::KSet, }, - net::{register_netdevice, types::InterfaceFlags}, + net::{ + napi::{napi_schedule, NapiStruct}, + register_netdevice, + types::InterfaceFlags, + }, virtio::{ 
irq::virtio_irq_manager, sysfs::{virtio_bus, virtio_device_manager, virtio_driver_manager}, @@ -43,11 +47,11 @@ use crate::{ filesystem::{kernfs::KernFSInode, sysfs::AttributeGroup}, init::initcall::INITCALL_POSTCORE, libs::{ - rwlock::{RwLockReadGuard, RwLockWriteGuard}, + rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}, spinlock::{SpinLock, SpinLockGuard}, }, net::generate_iface_id, - process::namespace::net_namespace::{NetNamespace, INIT_NET_NAMESPACE}, + process::namespace::net_namespace::INIT_NET_NAMESPACE, time::Instant, }; use system_error::SystemError; @@ -70,8 +74,8 @@ pub struct VirtIONetDevice { inner: SpinLock, locked_kobj_state: LockedKObjectState, - // 这里放netns是为了在中断到来的时候可以遍历poll当前命名空间下的网卡 - netns: Arc, + // 指向对应的interface + iface_ref: RwLock>, } impl Debug for VirtIONetDevice { @@ -100,11 +104,7 @@ impl Debug for InnerVirtIONetDevice { } impl VirtIONetDevice { - pub fn new( - transport: VirtIOTransport, - dev_id: Arc, - netns: Arc, - ) -> Option> { + pub fn new(transport: VirtIOTransport, dev_id: Arc) -> Option> { // 设置中断 if let Err(err) = transport.setup_irq(dev_id.clone()) { error!("VirtIONetDevice '{dev_id:?}' setup_irq failed: {:?}", err); @@ -133,7 +133,7 @@ impl VirtIONetDevice { device_common: DeviceCommonData::default(), }), locked_kobj_state: LockedKObjectState::default(), - netns, + iface_ref: RwLock::new(Weak::new()), }); // dev.set_driver(Some(Arc::downgrade(&virtio_net_driver()) as Weak)); @@ -144,6 +144,14 @@ impl VirtIONetDevice { fn inner(&self) -> SpinLockGuard<'_, InnerVirtIONetDevice> { return self.inner.lock(); } + + pub fn set_iface(&self, iface: &Arc) { + *self.iface_ref.write() = Arc::downgrade(iface); + } + + pub fn iface(&self) -> Option> { + self.iface_ref.read().upgrade() + } } impl KObject for VirtIONetDevice { @@ -278,7 +286,22 @@ impl Device for VirtIONetDevice { impl VirtIODevice for VirtIONetDevice { fn handle_irq(&self, _irq: IrqNumber) -> Result { - self.netns.wakeup_poll_thread(); + let Some(iface) = 
self.iface() else { + error!( + "VirtIONetDevice '{:?}' has no associated iface to handle irq", + self.dev_id.id() + ); + return Ok(IrqReturn::NotHandled); + }; + + let Some(napi) = iface.napi_struct() else { + log::error!("Virtio net device {} has no napi_struct", iface.name()); + return Ok(IrqReturn::NotHandled); + }; + + napi_schedule(napi); + + // self.netns.wakeup_poll_thread(); return Ok(IrqReturn::Handled); } @@ -525,6 +548,10 @@ impl VirtioInterface { }), }); + // 设置napi struct + let napi_struct = NapiStruct::new(iface.clone(), 10); + *iface.common().napi_struct.write() = Some(napi_struct); + iface } @@ -732,7 +759,7 @@ pub fn virtio_net( dev_id: Arc, dev_parent: Option>, ) { - let virtio_net_deivce = VirtIONetDevice::new(transport, dev_id, INIT_NET_NAMESPACE.clone()); + let virtio_net_deivce = VirtIONetDevice::new(transport, dev_id); if let Some(virtio_net_deivce) = virtio_net_deivce { debug!("VirtIONetDevice '{:?}' created", virtio_net_deivce.dev_id); if let Some(dev_parent) = dev_parent { @@ -936,6 +963,9 @@ impl VirtIODriver for VirtIONetDriver { // 在sysfs中注册iface register_netdevice(iface.clone() as Arc)?; + // 将virtio_net_device和iface关联起来 + virtio_net_device.set_iface(&iface); + // 将网卡的接口信息注册到全局的网卡接口信息表中 // NET_DEVICES // .write_irqsave() From 0e90cbe9875c1467b73aa6dfea887e423781f23f Mon Sep 17 00:00:00 2001 From: sparkzky Date: Sat, 30 Aug 2025 18:19:59 +0800 Subject: [PATCH 31/36] =?UTF-8?q?feat(netlink):=20=E6=9A=82=E6=97=B6?= =?UTF-8?q?=E4=B8=BA=E5=A4=9A=E6=92=AD=E6=B6=88=E6=81=AF=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?allow=20unused,=E6=B6=88=E9=99=A4warning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/filesystem/vfs/mod.rs | 10 ++++++++++ kernel/src/net/socket/netlink/common/mod.rs | 3 +++ kernel/src/net/socket/netlink/table/mod.rs | 2 ++ 3 files changed, 15 insertions(+) diff --git a/kernel/src/filesystem/vfs/mod.rs b/kernel/src/filesystem/vfs/mod.rs index 
b22d8d14e..73f66cb3f 100644 --- a/kernel/src/filesystem/vfs/mod.rs +++ b/kernel/src/filesystem/vfs/mod.rs @@ -657,6 +657,16 @@ pub trait IndexNode: Any + Sync + Send + Debug + CastFromSync { Err(SystemError::ENOSYS) } + /// # 将当前Inode转换为Socket类型 + /// 如果当前Inode不是Socket类型,则返回None + /// + /// # 注意 + /// 这个方法已经为dyn Socket实现, + /// 所以如果可以确定当前`dyn IndexNode`是`dyn Socket`类型,则可以直接调用此方法进行转换 + fn as_socket(&self) -> Option<&dyn Socket> { + None + } + /// @brief 按文件名获取扩展属性 /// /// @param name 属性名称 diff --git a/kernel/src/net/socket/netlink/common/mod.rs b/kernel/src/net/socket/netlink/common/mod.rs index 8017592d2..3f7f412ea 100644 --- a/kernel/src/net/socket/netlink/common/mod.rs +++ b/kernel/src/net/socket/netlink/common/mod.rs @@ -192,6 +192,7 @@ impl NetlinkSocket

{ .load(core::sync::atomic::Ordering::Relaxed) } + #[allow(unused)] pub fn set_nonblocking(&self, nonblocking: bool) { self.is_nonblocking .store(nonblocking, core::sync::atomic::Ordering::Relaxed); @@ -200,6 +201,7 @@ impl NetlinkSocket

{ // 多播消息的时候会用到,比如uevent impl Inner, BoundNetlink> { + #[allow(unused)] fn add_groups(&mut self, groups: GroupIdSet) { match self { Inner::Bound(bound) => bound.add_groups(groups), @@ -207,6 +209,7 @@ impl Inner, BoundNetlink unbound.drop_groups(groups), diff --git a/kernel/src/net/socket/netlink/table/mod.rs b/kernel/src/net/socket/netlink/table/mod.rs index c5bce4ae7..adc4aab70 100644 --- a/kernel/src/net/socket/netlink/table/mod.rs +++ b/kernel/src/net/socket/netlink/table/mod.rs @@ -216,6 +216,8 @@ pub trait SupportedNetlinkProtocol: Debug { Self::socket_table(netns).read().unicast(dst_port, message) } + //todo 多播消息用 + #[allow(unused)] fn multicast( dst_groups: GroupIdSet, message: Self::Message, From af3ede8fdafd54ffea495f1a0b10b9a5396e2060 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Sun, 31 Aug 2025 16:13:34 +0800 Subject: [PATCH 32/36] =?UTF-8?q?feat(nat):=20=E5=AE=9E=E7=8E=B0SNAT?= =?UTF-8?q?=E5=92=8CDNAT?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/mod.rs | 4 + kernel/src/driver/net/veth.rs | 21 +- kernel/src/net/routing.rs | 298 ------------------ kernel/src/net/routing/mod.rs | 489 ++++++++++++++++++++++++++++++ kernel/src/net/routing/nat.rs | 413 +++++++++++++++++++++++++ user/apps/c_unitest/test_router.c | 5 +- 6 files changed, 928 insertions(+), 302 deletions(-) delete mode 100644 kernel/src/net/routing.rs create mode 100644 kernel/src/net/routing/mod.rs create mode 100644 kernel/src/net/routing/nat.rs diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index 0d49228c1..3753a8909 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -154,6 +154,10 @@ pub trait Iface: crate::driver::base::device::Device { fn napi_struct(&self) -> Option> { self.common().napi_struct.read().clone() } + + fn router_common(&self) -> &RouterEnableDeviceCommon { + &self.common().router_common_data + } } /// 网络设备的公共数据 diff --git 
a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index 334f41014..773375324 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -19,7 +19,7 @@ use crate::init::initcall::INITCALL_DEVICE; use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; use crate::libs::spinlock::{SpinLock, SpinLockGuard}; use crate::net::generate_iface_id; -use crate::net::routing::{RouteEntry, RouterEnableDevice}; +use crate::net::routing::{DnatRule, RouteEntry, RouterEnableDevice, SnatRule}; use crate::process::namespace::net_namespace::{NetNamespace, INIT_NET_NAMESPACE}; use crate::process::ProcessManager; use alloc::collections::VecDeque; @@ -709,7 +709,7 @@ fn veth_route_test() { let cidr3 = IpCidr::new(addr3, 24); iface_host2.update_ip_addrs(cidr3); - let addr4 = IpAddress::v4(192, 168, 2, 1); + let addr4 = IpAddress::v4(192, 168, 2, 3); let cidr4 = IpCidr::new(addr4, 24); iface_ns2.update_ip_addrs(cidr4); @@ -766,6 +766,23 @@ fn veth_route_test() { turn_on(&iface_ns2, INIT_NET_NAMESPACE.clone()); turn_on(&iface_host1, INIT_NET_NAMESPACE.clone()); turn_on(&iface_host2, INIT_NET_NAMESPACE.clone()); + + let snat_rules = vec![SnatRule { + // 匹配所有来自 192.168.1.0/24 网络的流量 + source_cidr: "192.168.1.0/24".parse().unwrap(), + // 将源地址转换为 192.168.2.254(hardcode) + nat_ip: IpAddress::v4(192, 168, 2, 254), + }]; + + let dnat_rules = vec![DnatRule { + external_addr: IpAddress::v4(192, 168, 2, 1), + internal_addr: IpAddress::v4(192, 168, 2, 3), + internal_port: None, + external_port: None, + }]; + + host_router.nat_tracker().update_snat_rules(snat_rules); + host_router.nat_tracker().update_dnat_rules(dnat_rules); } #[unified_init(INITCALL_DEVICE)] diff --git a/kernel/src/net/routing.rs b/kernel/src/net/routing.rs deleted file mode 100644 index ab75f48a0..000000000 --- a/kernel/src/net/routing.rs +++ /dev/null @@ -1,298 +0,0 @@ -use crate::driver::net::Iface; -use crate::libs::rwlock::RwLock; -use 
crate::process::namespace::net_namespace::NetNamespace; -use crate::process::namespace::net_namespace::INIT_NET_NAMESPACE; -use alloc::string::{String, ToString}; -use alloc::sync::{Arc, Weak}; -use alloc::vec::Vec; -use smoltcp::wire::{EthernetFrame, IpAddress, IpCidr, Ipv4Packet}; -use system_error::SystemError; - -#[derive(Debug, Clone)] -pub struct RouteEntry { - /// 目标网络 - pub destination: IpCidr, - /// 下一跳地址(如果是直连网络则为None) - pub next_hop: Option, - /// 出接口 - pub interface: Weak, - /// 路由优先级(数值越小优先级越高) - pub metric: u32, - /// 路由类型 - pub route_type: RouteType, -} - -#[derive(Debug, Clone, PartialEq)] -pub enum RouteType { - /// 直连路由 - Connected, - /// 静态路由 - Static, - /// 默认路由 - Default, -} - -impl RouteEntry { - pub fn new_connected(destination: IpCidr, interface: Arc) -> Self { - RouteEntry { - destination, - next_hop: None, - interface: Arc::downgrade(&interface), - metric: 0, - route_type: RouteType::Connected, - } - } - - pub fn new_static( - destination: IpCidr, - next_hop: IpAddress, - interface: Arc, - metric: u32, - ) -> Self { - RouteEntry { - destination, - next_hop: Some(next_hop), - interface: Arc::downgrade(&interface), - metric, - route_type: RouteType::Static, - } - } - - pub fn new_default(next_hop: IpAddress, interface: Arc) -> Self { - RouteEntry { - destination: IpCidr::new(IpAddress::v4(0, 0, 0, 0), 0), - next_hop: Some(next_hop), - interface: Arc::downgrade(&interface), - metric: 100, - route_type: RouteType::Default, - } - } -} - -#[derive(Debug, Default)] -pub struct RouteTable { - pub entries: Vec, -} - -/// 路由决策结果 -#[derive(Debug)] -pub struct RouteDecision { - /// 出接口 - pub interface: Arc, - /// 下一跳地址(先写在这里 - pub next_hop: IpAddress, -} - -#[derive(Debug)] -pub struct Router { - name: String, - /// 路由表 //todo 后面再优化LC-trie,现在先简单用一个Vec - route_table: RwLock, - pub ns: RwLock>, -} - -impl Router { - pub fn new(name: String) -> Arc { - Arc::new(Self { - name: name.clone(), - route_table: RwLock::new(RouteTable::default()), - ns: 
RwLock::new(Weak::default()), - }) - } - - /// 创建一个空的Router实例,主要用于初始化网络命名空间时使用 - /// 注意: 这个Router实例不会启动轮询线程 - pub fn new_empty() -> Arc { - Arc::new(Self { - name: "empty_router".to_string(), - route_table: RwLock::new(RouteTable::default()), - ns: RwLock::new(Weak::default()), - }) - } - - pub fn add_route(&self, route: RouteEntry) { - let mut guard = self.route_table.write(); - let entries = &mut guard.entries; - let pos = entries - .iter() - .position(|r| r.metric > route.metric) - .unwrap_or(entries.len()); - - entries.insert(pos, route); - log::info!("Router {}: Added route to routing table", self.name); - } - - pub fn remove_route(&self, destination: IpCidr) { - self.route_table - .write() - .entries - .retain(|route| route.destination != destination); - } - - pub fn lookup_route(&self, dest_ip: IpAddress) -> Option { - let guard = self.route_table.read(); - // 按最长前缀匹配原则查找路由 - let best = guard - .entries - .iter() - .filter(|route| { - route.interface.strong_count() > 0 && route.destination.contains_addr(&dest_ip) - }) - .max_by_key(|route| route.destination.prefix_len()); - - if let Some(entry) = best { - if let Some(interface) = entry.interface.upgrade() { - let next_hop = entry.next_hop.unwrap_or(dest_ip); - return Some(RouteDecision { - interface, - next_hop, - }); - } - } - - None - } - - /// 清理无效的路由表项(接口已经不存在的) - pub fn cleanup_routes(&mut self) { - self.route_table - .write() - .entries - .retain(|route| route.interface.strong_count() > 0); - } -} - -/// 获取初始化网络命名空间下的路由表 -pub fn init_netns_router() -> Arc { - INIT_NET_NAMESPACE.router().clone() -} - -/// 可供路由设备应该实现的 trait -pub trait RouterEnableDevice: Iface { - /// # 网卡处理可路由的包 - /// ## 参数 - /// - `packet`: 需要处理的以太网帧 - /// ## 返回值 - /// - `Ok(())`: 通过路由处理成功 - /// - `Err(None)`: 忽略非IPv4包或没有路由到达的包,告诉外界没有经过处理,应该交由网卡进行默认处理 - /// - `Err(Some(SystemError))`: 处理失败,可能是包格式错误或其他系统错误 - fn handle_routable_packet( - &self, - ether_frame: &EthernetFrame<&[u8]>, - ) -> Result<(), Option> { - // 只处理IP包(IPv4) - if 
ether_frame.ethertype() != smoltcp::wire::EthernetProtocol::Ipv4 { - // 忽略非IPv4包 - log::info!( - "Ignoring non-IPv4 packet on interface {}", - self.iface_name() - ); - return Err(None); - } - - // log::info!( - // "src_mac: {}, dst_mac: {}", - // ether_frame.src_addr(), - // ether_frame.dst_addr() - // ); - - let ipv4_packet: Ipv4Packet<&[u8]> = match Ipv4Packet::new_checked(ether_frame.payload()) { - Ok(p) => p, - Err(_) => return Err(Some(SystemError::EINVAL)), - }; - - // log::info!( - // "src_ip: {}, dst_ip: {}", - // ipv4_packet.src_addr(), - // ipv4_packet.dst_addr() - // ); - - let dst_ip = ipv4_packet.dst_addr(); - - // 检查TTL - if ipv4_packet.hop_limit() <= 1 { - log::warn!("TTL exceeded for packet to {}", dst_ip); - return Err(Some(SystemError::EINVAL)); - } - - // 检查是否是发给自己的包(目标IP是否是自己的IP) - if self.is_my_ip(dst_ip.into()) { - // 交给本地协议栈处理 - log::info!("Packet destined for local interface {}", self.iface_name()); - return Err(None); - } - - // 查询当前网络命名空间下的路由表 - let router = self.netns_router(); - - let decision = match router.lookup_route(dst_ip.into()) { - Some(d) => d, - None => { - log::warn!("No route to {}", dst_ip); - return Err(None); - } - }; - - drop(router); - - // 检查是否是从同一个接口进来又要从同一个接口出去(避免回路) - if self.iface_name() == decision.interface.iface_name() { - log::info!( - "Ignoring packet loop from {} to {}", - self.iface_name(), - dst_ip - ); - return Err(None); - } - - // 创建修改后的IP包(递减TTL) - let modified_ip_packet = ether_frame.payload().to_vec(); - // if modified_ip_packet.len() >= 9 { - // modified_ip_packet[8] = modified_ip_packet[8].saturating_sub(1); - // //todo 这里应该重新计算IP校验和,为了简化先跳过 - // } - - // 交给出接口进行发送 - let next_hop = &decision.next_hop; - decision - .interface - .route_and_send(next_hop, &modified_ip_packet); - - log::info!("Routed packet from {} to {} ", self.iface_name(), dst_ip,); - Ok(()) - } - - /// 路由器决定通过此接口发送包时调用此方法 - /// 同Linux的ndo_start_xmit() - /// - /// todo 
在这里查询arp_table,找到目标IP对应的mac地址然后拼接,如果找不到的话就需要主动发送arp请求去查询mac地址了,手伸不到smoltcp内部:( - /// 后续需要将arp查询的逻辑从smoltcp中抽离出来 - fn route_and_send(&self, next_hop: &IpAddress, ip_packet: &[u8]); - - /// 检查IP地址是否是当前接口的IP - fn is_my_ip(&self, ip: IpAddress) -> bool; - - fn netns_router(&self) -> Arc { - self.net_namespace() - .map_or_else(init_netns_router, |ns| ns.router()) - } -} - -/// # 每一个`RouterEnableDevice`应该有的公共数据,包含 -/// - 当前接口的arp_table,记录邻居(//todo:将网卡的发送以及处理逻辑从smoltcp中移动出来,目前只是简单为veth实现这个,因为可以直接查到对端的mac地址) -#[derive(Debug)] -pub struct RouterEnableDeviceCommon { - /// 当前接口的邻居缓存 - // pub arp_table: RwLock>, - /// 当前接口的IP地址列表(因为如果直接通过smoltcp获取ip的话可能导致死锁,因此则这里维护一份) - pub ip_addrs: RwLock>, -} - -impl Default for RouterEnableDeviceCommon { - fn default() -> Self { - Self { - // arp_table: RwLock::new(BTreeMap::new()), - ip_addrs: RwLock::new(Vec::new()), - } - } -} diff --git a/kernel/src/net/routing/mod.rs b/kernel/src/net/routing/mod.rs new file mode 100644 index 000000000..5da0fc9e9 --- /dev/null +++ b/kernel/src/net/routing/mod.rs @@ -0,0 +1,489 @@ +use crate::driver::net::Iface; +use crate::libs::rwlock::RwLock; +use crate::net::routing::nat::ConnTracker; +use crate::net::routing::nat::DnatPolicy; +use crate::net::routing::nat::FiveTuple; +use crate::net::routing::nat::NatPktStatus; +use crate::net::routing::nat::NatPolicy; +use crate::net::routing::nat::SnatPolicy; +use crate::process::namespace::net_namespace::NetNamespace; +use crate::process::namespace::net_namespace::INIT_NET_NAMESPACE; +use alloc::string::{String, ToString}; +use alloc::sync::{Arc, Weak}; +use alloc::vec::Vec; +use core::net::Ipv4Addr; +use smoltcp::wire::{EthernetFrame, IpAddress, IpCidr, Ipv4Packet}; +use system_error::SystemError; + +mod nat; + +pub use nat::{DnatRule, SnatRule}; + +#[derive(Debug, Clone)] +pub struct RouteEntry { + /// 目标网络 + pub destination: IpCidr, + /// 下一跳地址(如果是直连网络则为None) + pub next_hop: Option, + /// 出接口 + pub interface: Weak, + /// 路由优先级(数值越小优先级越高) + pub metric: u32, + 
/// 路由类型 + pub route_type: RouteType, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum RouteType { + /// 直连路由 + Connected, + /// 静态路由 + Static, + /// 默认路由 + Default, +} + +impl RouteEntry { + pub fn new_connected(destination: IpCidr, interface: Arc) -> Self { + RouteEntry { + destination, + next_hop: None, + interface: Arc::downgrade(&interface), + metric: 0, + route_type: RouteType::Connected, + } + } + + pub fn new_static( + destination: IpCidr, + next_hop: IpAddress, + interface: Arc, + metric: u32, + ) -> Self { + RouteEntry { + destination, + next_hop: Some(next_hop), + interface: Arc::downgrade(&interface), + metric, + route_type: RouteType::Static, + } + } + + pub fn new_default(next_hop: IpAddress, interface: Arc) -> Self { + RouteEntry { + destination: IpCidr::new(IpAddress::v4(0, 0, 0, 0), 0), + next_hop: Some(next_hop), + interface: Arc::downgrade(&interface), + metric: 100, + route_type: RouteType::Default, + } + } +} + +#[derive(Debug, Default)] +pub struct RouteTable { + pub entries: Vec, +} + +/// 路由决策结果 +#[derive(Debug)] +pub struct RouteDecision { + /// 出接口 + pub interface: Arc, + /// 下一跳地址(先写在这里 + pub next_hop: IpAddress, +} + +#[derive(Debug)] +pub struct Router { + name: String, + /// 路由表 //todo 后面再优化LC-trie,现在先简单用一个Vec + route_table: RwLock, + pub(self) nat_tracker: Arc, + pub ns: RwLock>, +} + +impl Router { + pub fn new(name: String) -> Arc { + Arc::new(Self { + name: name.clone(), + route_table: RwLock::new(RouteTable::default()), + nat_tracker: Arc::new(ConnTracker::default()), + ns: RwLock::new(Weak::default()), + }) + } + + /// 创建一个空的Router实例,主要用于初始化网络命名空间时使用 + /// 注意: 这个Router实例不会启动轮询线程 + pub fn new_empty() -> Arc { + Arc::new(Self { + name: "empty_router".to_string(), + route_table: RwLock::new(RouteTable::default()), + ns: RwLock::new(Weak::default()), + nat_tracker: Arc::new(ConnTracker::default()), + }) + } + + pub fn add_route(&self, route: RouteEntry) { + let mut guard = self.route_table.write(); + let entries = &mut 
guard.entries; + let pos = entries + .iter() + .position(|r| r.metric > route.metric) + .unwrap_or(entries.len()); + + entries.insert(pos, route); + log::info!("Router {}: Added route to routing table", self.name); + } + + pub fn remove_route(&self, destination: IpCidr) { + self.route_table + .write() + .entries + .retain(|route| route.destination != destination); + } + + pub fn lookup_route(&self, dest_ip: IpAddress) -> Option { + let guard = self.route_table.read(); + // 按最长前缀匹配原则查找路由 + let best = guard + .entries + .iter() + .filter(|route| { + route.interface.strong_count() > 0 && route.destination.contains_addr(&dest_ip) + }) + .max_by_key(|route| route.destination.prefix_len()); + + if let Some(entry) = best { + if let Some(interface) = entry.interface.upgrade() { + let next_hop = entry.next_hop.unwrap_or(dest_ip); + return Some(RouteDecision { + interface, + next_hop, + }); + } + } + + None + } + + /// 清理无效的路由表项(接口已经不存在的) + pub fn cleanup_routes(&mut self) { + self.route_table + .write() + .entries + .retain(|route| route.interface.strong_count() > 0); + } + + pub fn nat_tracker(&self) -> Arc { + self.nat_tracker.clone() + } +} + +/// 获取初始化网络命名空间下的路由表 +pub fn init_netns_router() -> Arc { + INIT_NET_NAMESPACE.router().clone() +} + +/// 可供路由设备应该实现的 trait +pub trait RouterEnableDevice: Iface { + /// # 网卡处理可路由的包 + /// ## 参数 + /// - `packet`: 需要处理的以太网帧 + /// ## 返回值 + /// - `Ok(())`: 通过路由处理成功 + /// - `Err(None)`: 忽略非IPv4包或没有路由到达的包,告诉外界没有经过处理,应该交由网卡进行默认处理 + /// - `Err(Some(SystemError))`: 处理失败,可能是包格式错误或其他系统错误 + fn handle_routable_packet( + &self, + ether_frame: &EthernetFrame<&[u8]>, + ) -> Result<(), Option> { + match ether_frame.ethertype() { + smoltcp::wire::EthernetProtocol::Ipv4 => { + // 获取IPv4包的可变引用 + let mut payload_mut = ether_frame.payload().to_vec(); + let mut ipv4_packet_mut = + Ipv4Packet::new_checked(&mut payload_mut).map_err(|e| { + log::warn!("Invalid IPv4 packet: {:?}", e); + Some(SystemError::EINVAL) + })?; + + let maybe_tuple = 
FiveTuple::extract_from_ipv4_packet( + &Ipv4Packet::new_checked(ether_frame.payload()).unwrap(), + ); + + // === PRE-ROUTING HOOK === + + let pkt_status = self.pre_routing_hook(&maybe_tuple, &mut ipv4_packet_mut); + ipv4_packet_mut.fill_checksum(); + + // === PRE-ROUTING HOOK END === + + let dst_ip = ipv4_packet_mut.dst_addr(); + + // 检查TTL + if ipv4_packet_mut.hop_limit() <= 1 { + log::warn!("TTL exceeded for packet to {}", dst_ip); + return Err(Some(SystemError::EINVAL)); + } + + // 检查是否是发给自己的包(目标IP是否是自己的IP) + if self.is_my_ip(dst_ip.into()) { + // 交给本地协议栈处理 + log::info!("Packet destined for local interface {}", self.iface_name()); + return Err(None); + } + + // 查询当前网络命名空间下的路由表 + let router = self.netns_router(); + + let decision = match router.lookup_route(dst_ip.into()) { + Some(d) => d, + None => { + log::warn!("No route to {}", dst_ip); + return Err(None); + } + }; + + drop(router); + + // === POST-ROUTING HOOK === + + let decision_src_ip = decision.interface.common().ipv4_addr().unwrap(); + self.post_routing_hook( + &maybe_tuple, + &decision_src_ip, + &mut ipv4_packet_mut, + &pkt_status, + ); + ipv4_packet_mut.fill_checksum(); + + // === POST-ROUTING HOOK END === + + // 检查是否是从同一个接口进来又要从同一个接口出去(避免回路) + if self.iface_name() == decision.interface.iface_name() { + log::info!( + "Ignoring packet loop from {} to {}", + self.iface_name(), + dst_ip + ); + return Err(None); + } + + // 创建修改后的IP包(递减TTL) + sub_ttl_ipv4(&mut ipv4_packet_mut); + ipv4_packet_mut.fill_checksum(); + + // 交给出接口进行发送 + let next_hop = &decision.next_hop; + decision + .interface + .route_and_send(next_hop, ipv4_packet_mut.as_ref()); + + log::info!("Routed packet from {} to {} ", self.iface_name(), dst_ip); + Ok(()) + } + smoltcp::wire::EthernetProtocol::Arp => { + // 忽略ARP包 + log::info!( + "Ignoring non-IPv4 packet on interface {}", + self.iface_name() + ); + Err(None) + } + smoltcp::wire::EthernetProtocol::Ipv6 => { + log::warn!("IPv6 is not supported yet, ignoring packet"); + Err(None) + } + _ 
=> { + log::warn!( + "Unknown ethertype {:?}, ignoring packet", + ether_frame.ethertype() + ); + Err(None) + } + } + } + + fn pre_routing_hook( + &self, + tuple: &Option, + ipv4_packet_mut: &mut Ipv4Packet<&mut Vec>, + ) -> NatPktStatus { + let Some(tuple) = tuple else { + return NatPktStatus::Untouched; + }; + + let tracker = self.netns_router().nat_tracker(); + + if let Some((new_dst_ip, new_dst_port)) = tracker.snat.lock().process_return_traffic(tuple) + { + log::info!( + "Reverse SNAT: Translating {}:{} to {}:{}", + tuple.src_addr, + tuple.src_port, + new_dst_ip, + new_dst_port + ); + + SnatPolicy::update_dst( + tuple.src_addr, + new_dst_ip, + new_dst_port, + tuple.protocol, + ipv4_packet_mut, + ); + + let new_tuple = FiveTuple { + dst_addr: new_dst_ip, + dst_port: new_dst_port, + src_addr: tuple.src_addr, + src_port: tuple.src_port, + protocol: tuple.protocol, + }; + + return NatPktStatus::ReverseSnat(new_tuple); + } + + let mut dnat_guard = tracker.dnat.lock(); + if let Some((new_dst_ip, new_dst_port)) = dnat_guard.process_new_connection(tuple) { + log::info!( + "DNAT: Translating {}:{} to {}:{}", + tuple.dst_addr, + tuple.dst_port, + new_dst_ip, + new_dst_port + ); + + DnatPolicy::update_dst( + tuple.src_addr, + new_dst_ip, + new_dst_port, + tuple.protocol, + ipv4_packet_mut, + ); + + let new_tuple = FiveTuple { + dst_addr: new_dst_ip, + dst_port: new_dst_port, + src_addr: tuple.src_addr, + src_port: tuple.src_port, + protocol: tuple.protocol, + }; + + return NatPktStatus::NewDnat(new_tuple); + } + + return NatPktStatus::Untouched; + } + + fn post_routing_hook( + &self, + tuple: &Option, + _decision_src_ip: &Ipv4Addr, + ipv4_packet_mut: &mut Ipv4Packet<&mut Vec>, + pkt_status: &NatPktStatus, + ) { + let tuple = match pkt_status { + NatPktStatus::ReverseSnat(t) => t, + NatPktStatus::NewDnat(t) => t, + NatPktStatus::Untouched => { + let Some(tuple) = tuple else { + return; + }; + tuple + } + }; + + let tracker = self.netns_router().nat_tracker(); + + if let 
Some((new_src_ip, new_src_port)) = tracker.dnat.lock().process_return_traffic(tuple) + { + log::info!( + "Reverse DNAT: Translating src {}:{} -> {}:{}", + tuple.src_addr, + tuple.src_port, + new_src_ip, + new_src_port + ); + + DnatPolicy::update_src( + tuple.dst_addr, + new_src_ip, + new_src_port, + tuple.protocol, + ipv4_packet_mut, + ); + + return; + } + + let mut snat_guard = tracker.snat.lock(); + if let Some((new_src_ip, new_src_port)) = snat_guard.process_new_connection(tuple) { + // log::info!( + // "SNAT: Translating {}:{} -> {}:{}", + // tuple.src_addr, + // tuple.src_port, + // new_src_ip, + // new_src_port + // ); + + //TODO 应该加一个判断snat,可以支持直接改成出口接口的ip + // // 修改源IP地址 + // let new_src_ip: IpAddress = if let IpAddress::Ipv4(new_src_ip) = new_src_ip { + // new_src_ip.into() + // } else { + // (*decision_src_ip).into() + // }; + + SnatPolicy::update_src( + tuple.dst_addr, + new_src_ip, + new_src_port, + tuple.protocol, + ipv4_packet_mut, + ); + + return; + } + } + + /// 路由器决定通过此接口发送包时调用此方法 + /// 同Linux的ndo_start_xmit() + /// + /// todo 在这里查询arp_table,找到目标IP对应的mac地址然后拼接,如果找不到的话就需要主动发送arp请求去查询mac地址了,手伸不到smoltcp内部:( + /// 后续需要将arp查询的逻辑从smoltcp中抽离出来 + fn route_and_send(&self, next_hop: &IpAddress, ip_packet: &[u8]); + + /// 检查IP地址是否是当前接口的IP + fn is_my_ip(&self, ip: IpAddress) -> bool; + + fn netns_router(&self) -> Arc { + self.net_namespace() + .map_or_else(init_netns_router, |ns| ns.router()) + } +} + +fn sub_ttl_ipv4(ipv4_packet: &mut Ipv4Packet<&mut Vec>) { + let new_ttl = ipv4_packet.hop_limit().saturating_sub(1); + ipv4_packet.set_hop_limit(new_ttl); +} + +/// # 每一个`RouterEnableDevice`应该有的公共数据,包含 +/// - 当前接口的arp_table,记录邻居(//todo:将网卡的发送以及处理逻辑从smoltcp中移动出来,目前只是简单为veth实现这个,因为可以直接查到对端的mac地址) +#[derive(Debug)] +pub struct RouterEnableDeviceCommon { + /// 当前接口的邻居缓存 + // pub arp_table: RwLock>, + /// 当前接口的IP地址列表(因为如果直接通过smoltcp获取ip的话可能导致死锁,因此则这里维护一份) + pub ip_addrs: RwLock>, +} + +impl Default for RouterEnableDeviceCommon { + fn default() -> Self { + Self { + // 
arp_table: RwLock::new(BTreeMap::new()), + ip_addrs: RwLock::new(Vec::new()), + } + } +} diff --git a/kernel/src/net/routing/nat.rs b/kernel/src/net/routing/nat.rs new file mode 100644 index 000000000..1aa898978 --- /dev/null +++ b/kernel/src/net/routing/nat.rs @@ -0,0 +1,413 @@ +use core::marker::PhantomData; + +use crate::libs::spinlock::SpinLock; +use crate::time::Duration; +use crate::time::Instant; +use alloc::fmt::Debug; +use alloc::vec::Vec; +use hashbrown::HashMap; +use smoltcp::wire::{IpAddress, IpCidr, Ipv4Packet}; + +pub(super) trait NatMapping: Debug + Clone + Copy { + fn last_seen(&self) -> Instant; + fn update_last_seen(&mut self, time: Instant); +} + +impl NatMapping for SnatMapping { + fn last_seen(&self) -> Instant { + self.last_seen + } + fn update_last_seen(&mut self, now: Instant) { + self.last_seen = now; + } +} + +impl NatMapping for DnatMapping { + fn last_seen(&self) -> Instant { + self.last_seen + } + fn update_last_seen(&mut self, now: Instant) { + self.last_seen = now; + } +} + +pub(super) trait NatPolicy { + type Rule: Debug + Clone; + type Mapping: NatMapping + Send + Sync; + + fn translate(rule: &Self::Rule, original: &FiveTuple) -> (FiveTuple, Self::Mapping); + fn find_matching_rule(rules: &[Self::Rule], tuple: &FiveTuple) -> Option; + fn get_translation_for_return_traffic(mapping: &Self::Mapping) -> (IpAddress, u16); + + fn update_src( + dst_ip: IpAddress, + new_src_ip: IpAddress, + new_src_port: u16, + protocol: Protocol, + ipv4_packet_mut: &mut Ipv4Packet<&mut Vec>, + ) { + // 修改源IP地址 + let IpAddress::Ipv4(new_src_ip_v4) = new_src_ip else { + return; + }; + ipv4_packet_mut.set_src_addr(new_src_ip_v4); + + let payload_mut = ipv4_packet_mut.payload_mut(); + match protocol { + Protocol::Tcp => { + let mut tcp_packet = smoltcp::wire::TcpPacket::new_checked(payload_mut).unwrap(); + tcp_packet.set_src_port(new_src_port); + // 重新计算TCP校验和 + tcp_packet.fill_checksum(&new_src_ip, &dst_ip); + } + Protocol::Udp => { + let mut udp_packet = 
smoltcp::wire::UdpPacket::new_checked(payload_mut).unwrap(); + udp_packet.set_src_port(new_src_port); + // 重新计算UDP校验和 + udp_packet.fill_checksum(&new_src_ip, &dst_ip); + } + _ => {} + } + } + + fn update_dst( + src_ip: IpAddress, + new_dst_ip: IpAddress, + new_dst_port: u16, + protocol: Protocol, + ipv4_packet_mut: &mut Ipv4Packet<&mut Vec>, + ) { + let IpAddress::Ipv4(new_dst_ip_v4) = new_dst_ip else { + return; + }; + ipv4_packet_mut.set_dst_addr(new_dst_ip_v4); + + let payload_mut = ipv4_packet_mut.payload_mut(); + match protocol { + Protocol::Tcp => { + let mut tcp_packet = smoltcp::wire::TcpPacket::new_checked(payload_mut).unwrap(); + tcp_packet.set_dst_port(new_dst_port); + // 重新计算TCP校验和 + tcp_packet.fill_checksum(&src_ip, &new_dst_ip); + } + Protocol::Udp => { + let mut udp_packet = smoltcp::wire::UdpPacket::new_checked(payload_mut).unwrap(); + udp_packet.set_dst_port(new_dst_port); + // 重新计算UDP校验和 + udp_packet.fill_checksum(&src_ip, &new_dst_ip); + } + _ => {} + } + } +} + +#[derive(Debug)] +pub(super) struct NatTracker { + rules: Vec, + mappings: HashMap, + reverse_mappings: HashMap, + policy_marker: PhantomData

, +} + +impl Default for NatTracker

{ + fn default() -> Self { + Self { + rules: Vec::new(), + mappings: HashMap::new(), + reverse_mappings: HashMap::new(), + policy_marker: PhantomData, + } + } +} + +impl NatTracker

{ + pub fn update_rules(&mut self, rules: Vec) { + self.rules = rules; + } + + pub fn cleanup_expired(&mut self, now: Instant) { + // 收集需要移除的键 + let expired_keys: Vec = self + .mappings + .iter() + .filter(|(_, mapping)| { + now.duration_since(mapping.last_seen()).unwrap() > Duration::from_secs(300) + }) + .map(|(key, _)| *key) + .collect(); + + for key in expired_keys { + self.mappings.remove(&key); + // 注意:反向映射表的清理比较复杂,因为多个主映射可能共享一个反向键(如果实现端口复用) + // 或者需要遍历 reverse_mappings 找到 value 为 key 的条目并删除。 + // 为简单起见,这里只清理主表。更健壮的实现需要双向链接或引用计数。 + log::info!("Cleaned up expired connection for key: {:?}", key); + } + } + + pub fn insert_rule(&mut self, rule: P::Rule) { + self.rules.push(rule); + } + + pub fn process_new_connection(&mut self, tuple: &FiveTuple) -> Option<(IpAddress, u16)> { + // let rules = self.rules.lock(); + let matching_rule = P::find_matching_rule(&self.rules, tuple)?; + + let (translated_tuple, new_mapping) = P::translate(&matching_rule, tuple); + + self.mappings.insert(*tuple, new_mapping); + self.reverse_mappings + .insert(translated_tuple.reverse(), *tuple); + + // log::info!( + // "Created new NAT mapping. Original: {:?}, Translated: {:?}", + // tuple, + // translated_tuple + // ); + + // 返回转换后的地址和端口信息,用于修改数据包 + // 注意:这里需要区分是修改源地址还是目的地址,取决于调用者 (SNAT vs DNAT) + // SNAT返回新的src_ip/port, DNAT返回新的dst_ip/port. 
+ // `translated_tuple` 包含了所有信息,我们返回对应的部分。 + if translated_tuple.src_addr != tuple.src_addr { + Some((translated_tuple.src_addr, translated_tuple.src_port)) + } else { + Some((translated_tuple.dst_addr, translated_tuple.dst_port)) + } + } + + /// 处理返回流量 + pub fn process_return_traffic(&mut self, tuple: &FiveTuple) -> Option<(IpAddress, u16)> { + if let Some(original_key) = self.reverse_mappings.get(tuple) { + if let Some(mapping) = self.mappings.get_mut(original_key) { + mapping.update_last_seen(Instant::now()); + return Some(P::get_translation_for_return_traffic(mapping)); + } + } + None + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SnatPolicy; + +impl NatPolicy for SnatPolicy { + type Rule = SnatRule; + type Mapping = SnatMapping; + + fn find_matching_rule(rules: &[Self::Rule], tuple: &FiveTuple) -> Option { + rules + .iter() + .find(|rule| rule.source_cidr.contains_addr(&tuple.src_addr)) + .cloned() + } + + fn translate(rule: &Self::Rule, original_tuple: &FiveTuple) -> (FiveTuple, Self::Mapping) { + let translated_tuple = FiveTuple { + src_addr: rule.nat_ip, + src_port: original_tuple.src_port, // 简化端口处理 + ..*original_tuple + }; + + let mapping = SnatMapping { + original: *original_tuple, + translated: translated_tuple, + last_seen: Instant::now(), + }; + + (translated_tuple, mapping) + } + + fn get_translation_for_return_traffic(mapping: &Self::Mapping) -> (IpAddress, u16) { + // 返回流量需要修改目的地址为原始客户端地址 + (mapping.original.src_addr, mapping.original.src_port) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct DnatPolicy; + +impl NatPolicy for DnatPolicy { + type Rule = DnatRule; + type Mapping = DnatMapping; + + fn find_matching_rule(rules: &[Self::Rule], tuple: &FiveTuple) -> Option { + rules + .iter() + .find(|rule| { + if tuple.dst_addr != rule.external_addr { + return false; + } + + match rule.external_port { + Some(port) => tuple.dst_port == port, + None => true, + } + }) + .cloned() + } + + fn translate(rule: &Self::Rule, original: &FiveTuple) -> 
(FiveTuple, Self::Mapping) { + let new_internal_port = match rule.internal_port { + Some(port) => port, + None => original.dst_port, + }; + + let translated_tuple = FiveTuple { + dst_addr: rule.internal_addr, + dst_port: new_internal_port, + ..*original + }; + + let mapping = DnatMapping { + from_client: *original, + to_server: translated_tuple, + last_seen: Instant::now(), + }; + + (translated_tuple, mapping) + } + + fn get_translation_for_return_traffic(mapping: &Self::Mapping) -> (IpAddress, u16) { + // 返回流量需要修改源地址为原始客户端地址 + (mapping.from_client.dst_addr, mapping.from_client.dst_port) + } +} + +#[derive(Debug)] +pub struct ConnTracker { + pub(super) snat: SpinLock>, + pub(super) dnat: SpinLock>, +} + +impl ConnTracker { + pub fn cleanup_expired(&self, now: Instant) { + self.snat.lock().cleanup_expired(now); + self.dnat.lock().cleanup_expired(now); + } + + pub fn update_snat_rules(&self, rules: Vec) { + self.snat.lock().update_rules(rules); + } + + pub fn update_dnat_rules(&self, rules: Vec) { + self.dnat.lock().update_rules(rules); + } +} + +impl Default for ConnTracker { + fn default() -> Self { + Self { + snat: SpinLock::new(NatTracker::::default()), + dnat: SpinLock::new(NatTracker::::default()), + } + } +} + +#[derive(Debug, PartialEq, Eq)] +pub enum NatPktStatus { + Untouched, + ReverseSnat(FiveTuple), + NewDnat(FiveTuple), +} + +// SNAT 规则:匹配来自某个源地址段的流量,并将其转换为指定的公网IP +#[derive(Debug, Clone)] +pub struct SnatRule { + pub source_cidr: IpCidr, + pub nat_ip: IpAddress, +} + +/// SNAT 映射:记录一个连接的原始元组和转换后的元组 +#[derive(Debug, Clone, Copy)] +pub struct SnatMapping { + pub original: FiveTuple, + pub translated: FiveTuple, + pub last_seen: Instant, +} + +#[derive(Debug, Clone)] +pub struct DnatRule { + pub external_addr: IpAddress, + pub external_port: Option, + pub internal_addr: IpAddress, + pub internal_port: Option, +} + +#[derive(Debug, Clone, Copy)] +pub struct DnatMapping { + // The original tuple from the external client's perspective + pub from_client: 
FiveTuple, + // The tuple after DNAT, as seen by the internal server + pub to_server: FiveTuple, + pub last_seen: Instant, +} + +/// 五元组结构体,用于唯一标识一个网络连接 +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +pub struct FiveTuple { + pub src_addr: IpAddress, + pub dst_addr: IpAddress, + pub src_port: u16, + pub dst_port: u16, + pub protocol: Protocol, +} + +impl FiveTuple { + pub fn extract_from_ipv4_packet(packet: &Ipv4Packet<&[u8]>) -> Option { + let src_addr = packet.src_addr().into(); + let dst_addr = packet.dst_addr().into(); + let protocol = Protocol::from_bits_truncate(packet.next_header().into()); + + match protocol { + Protocol::Tcp => { + let tcp_packet = smoltcp::wire::TcpPacket::new_checked(packet.payload()).ok()?; + Some(FiveTuple { + protocol, + src_addr, + src_port: tcp_packet.src_port(), + dst_addr, + dst_port: tcp_packet.dst_port(), + }) + } + Protocol::Udp => { + let udp_packet = smoltcp::wire::UdpPacket::new_checked(packet.payload()).ok()?; + Some(FiveTuple { + protocol, + src_addr, + src_port: udp_packet.src_port(), + dst_addr, + dst_port: udp_packet.dst_port(), + }) + } + _ => None, + } + } + + pub fn reverse(&self) -> Self { + Self { + src_addr: self.dst_addr, + dst_addr: self.src_addr, + src_port: self.dst_port, + dst_port: self.src_port, + protocol: self.protocol, + } + } +} + +bitflags! 
{ + pub struct Protocol: u8 { + const HopByHop = 0x00; + const Icmp = 0x01; + const Igmp = 0x02; + const Tcp = 0x06; + const Udp = 0x11; + const Ipv6Route = 0x2b; + const Ipv6Frag = 0x2c; + const IpSecEsp = 0x32; + const IpSecAh = 0x33; + const Icmpv6 = 0x3a; + const Ipv6NoNxt = 0x3b; + const Ipv6Opts = 0x3c; + } +} diff --git a/user/apps/c_unitest/test_router.c b/user/apps/c_unitest/test_router.c index 7ac968b6c..d49bcc674 100644 --- a/user/apps/c_unitest/test_router.c +++ b/user/apps/c_unitest/test_router.c @@ -8,7 +8,8 @@ #include #include -#define SERVER_IP "192.168.2.1" +#define SERVER_IP "192.168.2.3" +#define FAKE_SERVER_IP "192.168.2.1" #define CLIENT_IP "192.168.1.1" #define PORT 34254 #define BUFFER_SIZE 1024 @@ -113,7 +114,7 @@ void *client_func(void *arg) { memset(&server_addr, 0, sizeof(server_addr)); server_addr.sin_family = AF_INET; server_addr.sin_port = htons(PORT); - if (inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr) <= 0) { + if (inet_pton(AF_INET, FAKE_SERVER_IP, &server_addr.sin_addr) <= 0) { handle_error_message("[client] Invalid server IP address for connect"); } From 9ae44335dff6be5d87315f77feba257d5d928b9a Mon Sep 17 00:00:00 2001 From: sparkzky Date: Mon, 1 Sep 2025 21:35:16 +0800 Subject: [PATCH 33/36] =?UTF-8?q?feat(epoll):=20=E6=9B=B4=E6=94=B9epoll?= =?UTF-8?q?=E5=94=A4=E9=86=92=E5=88=A4=E6=96=AD=E7=9A=84=E9=80=BB=E8=BE=91?= =?UTF-8?q?,=E6=94=AF=E6=8C=81socket=E5=8A=A0=E5=85=A5epoll?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/driver/net/veth.rs | 77 ++++++- kernel/src/filesystem/epoll/event_poll.rs | 1 + kernel/src/net/routing/mod.rs | 10 +- kernel/src/net/socket/inet/mod.rs | 1 + kernel/src/net/socket/inet/stream/inner.rs | 1 + kernel/src/net/socket/inet/stream/mod.rs | 3 +- user/apps/c_unitest/test_epoll_socket.c | 223 +++++++++++++++++++++ 7 files changed, 301 insertions(+), 15 deletions(-) create mode 100644 
user/apps/c_unitest/test_epoll_socket.c diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs index 773375324..95a13718e 100644 --- a/kernel/src/driver/net/veth.rs +++ b/kernel/src/driver/net/veth.rs @@ -68,8 +68,40 @@ impl Veth { pub(self) fn to_peer(peer: &Arc, data: &[u8]) { let mut peer_veth = peer.driver.inner.lock(); peer_veth.rx_queue.push_back(data.to_vec()); - log::info!("Veth {} received data from peer", peer.name); - log::info!("{:?}", peer_veth.rx_queue); + + // { + // let ether = EthernetFrame::new_checked(data).unwrap(); + // if ether.ethertype() == smoltcp::wire::EthernetProtocol::Ipv4 { + // if let Some(ipv4_packet) = + // smoltcp::wire::Ipv4Packet::new_checked(ether.payload()).ok() + // { + // log::info!( + // "Veth {} sending IPv4 packet to peer: {} -> {}", + // peer.name, + // ipv4_packet.src_addr(), + // ipv4_packet.dst_addr() + // ); + // } + // } else if ether.ethertype() == smoltcp::wire::EthernetProtocol::Ipv6 { + // if let Some(ipv6_packet) = + // smoltcp::wire::Ipv6Packet::new_checked(ether.payload()).ok() + // { + // log::info!( + // "Veth {} sending IPv6 packet to peer: {} -> {}", + // peer.name, + // ipv6_packet.src_addr(), + // ipv6_packet.dst_addr() + // ); + // } + // } else { + // log::info!( + // "Veth {} sending non-IP packet to peer: ethertype={:?}", + // peer.name, + // ether.ethertype() + // ); + // } + // } + drop(peer_veth); let Some(napi) = peer.napi_struct() else { @@ -96,20 +128,20 @@ impl Veth { }; let frame: EthernetFrame<&[u8]> = EthernetFrame::new_checked(data).unwrap(); - log::info!("trying to go to router"); + // log::info!("trying to go to router"); match self_iface.handle_routable_packet(&frame) { Ok(_) => { - log::info!("successfully sent to router"); - return true; + // log::info!("successfully sent to router"); + true } // 先不管错误,直接告诉外面没有经过路由发送出去 Err(Some(err)) => { log::error!("Router error: {:?}", err); - return false; + false } Err(_) => { - log::info!("not routed"); - return false; + // 
log::info!("not routed"); + false } } } @@ -785,9 +817,36 @@ fn veth_route_test() { host_router.nat_tracker().update_dnat_rules(dnat_rules); } +fn veth_epoll_test() { + let (iface1, iface2) = VethInterface::new_pair("veth1", "veth2"); + + let addr1 = IpAddress::v4(111, 111, 11, 1); + let cidr1 = IpCidr::new(addr1, 24); + iface1.update_ip_addrs(cidr1); + + let addr2 = IpAddress::v4(111, 111, 11, 2); + let cidr2 = IpCidr::new(addr2, 24); + iface2.update_ip_addrs(cidr2); + + iface1.add_default_route_to_peer(addr2); + iface2.add_default_route_to_peer(addr1); + + let turn_on = |a: &Arc, ns: Arc| { + a.set_net_state(NetDeivceState::__LINK_STATE_START); + a.set_operstate(Operstate::IF_OPER_UP); + // NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); + ns.add_device(a.clone()); + a.common().set_net_namespace(ns.clone()); + register_netdevice(a.clone()).expect("register veth device failed"); + }; + + turn_on(&iface1, INIT_NET_NAMESPACE.clone()); + turn_on(&iface2, INIT_NET_NAMESPACE.clone()); +} + #[unified_init(INITCALL_DEVICE)] pub fn veth_init() -> Result<(), SystemError> { - // veth_probe("veth0", "veth1"); + veth_epoll_test(); veth_route_test(); log::info!("Veth pair initialized."); Ok(()) diff --git a/kernel/src/filesystem/epoll/event_poll.rs b/kernel/src/filesystem/epoll/event_poll.rs index d5075242f..01dc11f7f 100644 --- a/kernel/src/filesystem/epoll/event_poll.rs +++ b/kernel/src/filesystem/epoll/event_poll.rs @@ -663,6 +663,7 @@ impl EventPoll { epoll_guard.ep_add_ready(epitem.clone()); if epoll_guard.ep_has_waiter() { + // log::info!("wakeup epoll waiters"); if ep_events.contains(EPollEventType::EPOLLEXCLUSIVE) && !pollflags.contains(EPollEventType::POLLFREE) { diff --git a/kernel/src/net/routing/mod.rs b/kernel/src/net/routing/mod.rs index 5da0fc9e9..4845d01e7 100644 --- a/kernel/src/net/routing/mod.rs +++ b/kernel/src/net/routing/mod.rs @@ -230,7 +230,7 @@ pub trait RouterEnableDevice: Iface { // 检查是否是发给自己的包(目标IP是否是自己的IP) if self.is_my_ip(dst_ip.into()) 
{ // 交给本地协议栈处理 - log::info!("Packet destined for local interface {}", self.iface_name()); + // log::info!("Packet destined for local interface {}", self.iface_name()); return Err(None); } @@ -285,10 +285,10 @@ pub trait RouterEnableDevice: Iface { } smoltcp::wire::EthernetProtocol::Arp => { // 忽略ARP包 - log::info!( - "Ignoring non-IPv4 packet on interface {}", - self.iface_name() - ); + // log::info!( + // "Ignoring non-IPv4 packet on interface {}", + // self.iface_name() + // ); Err(None) } smoltcp::wire::EthernetProtocol::Ipv6 => { diff --git a/kernel/src/net/socket/inet/mod.rs b/kernel/src/net/socket/inet/mod.rs index aec4e66f0..8090ba625 100644 --- a/kernel/src/net/socket/inet/mod.rs +++ b/kernel/src/net/socket/inet/mod.rs @@ -40,6 +40,7 @@ pub trait InetSocket: Socket { fn on_iface_events(&self); fn notify(&self) { + // log::info!("InetSocket::notify"); self.on_iface_events(); let _ = EventPoll::wakeup_epoll(self.epoll_items().as_ref(), self.check_io_event()); } diff --git a/kernel/src/net/socket/inet/stream/inner.rs b/kernel/src/net/socket/inet/stream/inner.rs index 6d42c8de9..779bc78bd 100644 --- a/kernel/src/net/socket/inet/stream/inner.rs +++ b/kernel/src/net/socket/inet/stream/inner.rs @@ -338,6 +338,7 @@ impl Listening { } pub fn update_io_events(&self, pollee: &AtomicUsize) { + // log::info!("Listening::update_io_events"); let position = self.inners.iter().position(|inner| { inner.with::(|socket| socket.is_active()) }); diff --git a/kernel/src/net/socket/inet/stream/mod.rs b/kernel/src/net/socket/inet/stream/mod.rs index 6df588d8a..92a09ee1f 100644 --- a/kernel/src/net/socket/inet/stream/mod.rs +++ b/kernel/src/net/socket/inet/stream/mod.rs @@ -37,7 +37,7 @@ pub struct TcpSocket { } impl TcpSocket { - pub fn new(_nonblock: bool, ver: smoltcp::wire::IpVersion) -> Arc { + pub fn new(nonblock: bool, ver: smoltcp::wire::IpVersion) -> Arc { let netns = ProcessManager::current_netns(); Arc::new_cyclic(|me| Self { inner: 
RwLock::new(Some(inner::Inner::Init(inner::Init::new(ver)))), @@ -192,6 +192,7 @@ impl TcpSocket { writer.replace(inner); drop(writer); + // log::info!("TcpSocket::finish_connect: {:?}", result); result } diff --git a/user/apps/c_unitest/test_epoll_socket.c b/user/apps/c_unitest/test_epoll_socket.c new file mode 100644 index 000000000..9d7d02ad1 --- /dev/null +++ b/user/apps/c_unitest/test_epoll_socket.c @@ -0,0 +1,223 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define SERVER_IP "111.111.11.1" +#define CLIENT_IP "111.111.11.2" +#define PORT 8888 +#define MAX_EVENTS 10 +#define BUFFER_SIZE 1024 + +// 函数声明 +void server_process(); +void client_process(); + +int main() { + pid_t pid = fork(); + + if (pid < 0) { + perror("fork failed"); + exit(EXIT_FAILURE); + } else if (pid == 0) { + // 子进程作为客户端 + // 等待一秒,确保服务器已启动 + sleep(1); + client_process(); + } else { + // 父进程作为服务器 + server_process(); + } + + return 0; +} + + +// 服务器进程逻辑 +void server_process() { + printf("[Server] Starting server process...\n"); + + int listen_sock, conn_sock, epoll_fd; + struct sockaddr_in server_addr, client_addr; + socklen_t client_len = sizeof(client_addr); + struct epoll_event ev, events[MAX_EVENTS]; + char buffer[BUFFER_SIZE]; + int data_processed = 0; // 添加标志位,标记是否已处理过数据 + + if ((listen_sock = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0)) < 0) { + perror("[Server] socket creation failed"); + exit(EXIT_FAILURE); + } + + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + if (inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr) <= 0) { + perror("[Server] inet_pton failed"); + exit(EXIT_FAILURE); + } + + if (bind(listen_sock, + (struct sockaddr *)&server_addr, + sizeof(server_addr)) < 0) { + perror("[Server] bind failed"); + exit(EXIT_FAILURE); + } + + if (listen(listen_sock, 1) < 0) { + perror("[Server] listen failed"); + 
exit(EXIT_FAILURE); + } + printf("[Server] Listening on %s:%d\n", SERVER_IP, PORT); + + if ((epoll_fd = epoll_create1(0)) < 0) { + perror("[Server] epoll_create1 failed"); + exit(EXIT_FAILURE); + } + + ev.events = EPOLLIN; + ev.data.fd = listen_sock; + if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, listen_sock, &ev) < 0) { + perror("[Server] epoll_ctl: listen_sock failed"); + exit(EXIT_FAILURE); + } + printf("Adding listening socket %d to epoll\n", listen_sock); + + while (!data_processed) { + int nfds = epoll_wait(epoll_fd, events, MAX_EVENTS, -1); + if (nfds < 0) { + perror("[Server] epoll_wait failed"); + exit(EXIT_FAILURE); + } + printf("Fuck epoll_wait returned %d\n", nfds); + + for (int n = 0; n < nfds; ++n) { + if (events[n].data.fd == listen_sock) { + printf("trying to accept new connection...\n"); + while (1) { + conn_sock = accept(listen_sock, + (struct sockaddr *)&client_addr, + &client_len); + if (conn_sock < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + printf("All incoming connections have been " + "processed.\n"); + break; + } else { + perror("accept error"); + exit(EXIT_FAILURE); + break; + } + } + ev.events = EPOLLIN | EPOLLET; + ev.data.fd = conn_sock; + if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, conn_sock, &ev) < + 0) { + perror("[Server] epoll_ctl: conn_sock failed"); + exit(EXIT_FAILURE); + } + printf("[Server] Accepted connection from %s:%d\n", + inet_ntoa(client_addr.sin_addr), + ntohs(client_addr.sin_port)); + } + } else { + printf("[Server] handling client data...\n"); + int client_fd = events[n].data.fd; + int nread = read(client_fd, buffer, BUFFER_SIZE); + if (nread == -1) { + if (errno != EAGAIN) { + perror("[Server] read error"); + close(client_fd); + } + } else if (nread == 0) { + printf("[Server] Client disconnected.\n"); + close(client_fd); + } else { + buffer[nread] = '\0'; + printf("[Server] Received from client: %s\n", buffer); + write(client_fd, buffer, nread); + printf("[Server] Echoed data back to client. 
Server will " + "now exit.\n"); + data_processed = 1; // 设置退出标志 + sleep(3); + + close(client_fd); + break; + } + } + } + } + + printf("[Server] Server process completed.\n"); + close(listen_sock); + close(epoll_fd); +} + +// 客户端进程逻辑 +void client_process() { + printf("[Client] Starting client process...\n"); + + int sock = 0; + struct sockaddr_in client_bind_addr, server_addr; + char buffer[BUFFER_SIZE] = {0}; + + if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + perror("[Client] socket creation failed"); + exit(EXIT_FAILURE); + } + + memset(&client_bind_addr, 0, sizeof(client_bind_addr)); + client_bind_addr.sin_family = AF_INET; + client_bind_addr.sin_port = htons(7777); + if (inet_pton(AF_INET, CLIENT_IP, &client_bind_addr.sin_addr) <= 0) { + perror("[Client] inet_pton for bind failed"); + exit(EXIT_FAILURE); + } + + if (bind(sock, + (struct sockaddr *)&client_bind_addr, + sizeof(client_bind_addr)) < 0) { + perror("[Client] bind failed"); + exit(EXIT_FAILURE); + } + printf("[Client] Bound to IP %s\n", CLIENT_IP); + + + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + if (inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr) <= 0) { + perror("[Client] inet_pton for connect failed"); + exit(EXIT_FAILURE); + } + + if (connect(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < + 0) { + perror("[Client] connect failed"); + exit(EXIT_FAILURE); + } + printf("[Client] Connected to server %s:%d\n", SERVER_IP, PORT); + + + const char *message = "Hello from client"; + write(sock, message, strlen(message)); + printf("[Client] Sent: %s\n", message); + sleep(1); + + int valread = read(sock, buffer, BUFFER_SIZE); + if (valread > 0) { + buffer[valread] = '\0'; + printf("[Client] Received: %s\n", buffer); + } + + printf("[Client] Client process completed.\n"); + close(sock); +} \ No newline at end of file From 3c67759f9d1ce4b1203a14fa5600206cc4cd2f9b Mon Sep 17 00:00:00 2001 From: sparkzky Date: 
Mon, 1 Sep 2025 22:14:03 +0800 Subject: [PATCH 34/36] =?UTF-8?q?feat:=20=E4=BF=AE=E6=94=B9test=5Fbind,?= =?UTF-8?q?=E9=98=B2=E6=AD=A2=E7=88=86=E5=86=85=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/net/socket/inet/syscall.rs | 2 +- user/apps/c_unitest/test_bind.c | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/kernel/src/net/socket/inet/syscall.rs b/kernel/src/net/socket/inet/syscall.rs index d363feb44..1b0d168ec 100644 --- a/kernel/src/net/socket/inet/syscall.rs +++ b/kernel/src/net/socket/inet/syscall.rs @@ -25,7 +25,7 @@ pub fn create_inet_socket( }, PSOCK::Stream => match protocol { IpProtocol::HopByHop | IpProtocol::Tcp => { - log::debug!("create tcp socket"); + // log::debug!("create tcp socket"); return Ok(TcpSocket::new(is_nonblock, version)); } _ => { diff --git a/user/apps/c_unitest/test_bind.c b/user/apps/c_unitest/test_bind.c index 420b597e5..6a6f3e1f2 100644 --- a/user/apps/c_unitest/test_bind.c +++ b/user/apps/c_unitest/test_bind.c @@ -139,6 +139,9 @@ void test_all_ports() { } count++; + + close(tcp_fd); + if (count>=100) break; } printf("===TEST 10===\n"); printf("count: %d\n", count); From 22a51be6bcfa6aeae5cbf6412f99228a63250f6d Mon Sep 17 00:00:00 2001 From: sparkzky Date: Sat, 6 Sep 2025 08:46:31 +0000 Subject: [PATCH 35/36] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E4=B8=80?= =?UTF-8?q?=E4=B8=AA=E8=B7=AF=E7=94=B1todo=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/net/routing/mod.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/kernel/src/net/routing/mod.rs b/kernel/src/net/routing/mod.rs index 4845d01e7..205dcbc9b 100644 --- a/kernel/src/net/routing/mod.rs +++ b/kernel/src/net/routing/mod.rs @@ -229,6 +229,8 @@ pub trait RouterEnableDevice: Iface { // 检查是否是发给自己的包(目标IP是否是自己的IP) if self.is_my_ip(dst_ip.into()) { + // todo 
按照linux的逻辑,只要包的目标ip在当前网络命名空间里面,就直接进入本地协议栈处理 + // todo 但是我们的操作系统中每个接口都是独立的,并没有统一处理和分发(socket),所有这里必须将包放到对应iface的接收队列里面 // 交给本地协议栈处理 // log::info!("Packet destined for local interface {}", self.iface_name()); return Err(None); @@ -456,6 +458,7 @@ pub trait RouterEnableDevice: Iface { fn route_and_send(&self, next_hop: &IpAddress, ip_packet: &[u8]); /// 检查IP地址是否是当前接口的IP + /// todo 这里实现有误,不应该判断是否当前接口的IP,而是应该判断是否是当前网络命名空间的IP,然脏 fn is_my_ip(&self, ip: IpAddress) -> bool; fn netns_router(&self) -> Arc { From e2fa4a587fe1ef377ae1aede376dfc3cabfa3e57 Mon Sep 17 00:00:00 2001 From: sparkzky Date: Fri, 12 Sep 2025 06:58:45 +0000 Subject: [PATCH 36/36] =?UTF-8?q?fix:=20rebase=E4=B8=BB=E7=BA=BF=E4=B9=8B?= =?UTF-8?q?=E5=90=8E=E4=BF=AE=E6=94=B9=E5=86=B2=E7=AA=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sparkzky --- kernel/src/filesystem/vfs/mod.rs | 24 ++++--- kernel/src/net/posix.rs | 21 +++++-- kernel/src/net/routing/mod.rs | 16 ++--- kernel/src/net/socket/base.rs | 7 --- kernel/src/net/socket/inet/datagram/mod.rs | 23 ------- kernel/src/net/socket/inet/stream/mod.rs | 7 +-- kernel/src/net/socket/inode.rs | 4 ++ kernel/src/net/socket/netlink/common/mod.rs | 53 +++++++++++++++- kernel/src/net/socket/netlink/mod.rs | 70 ++++++++++----------- kernel/src/net/socket/utils/mod.rs | 26 +++++--- kernel/src/net/syscall.rs | 63 +++++++++++++------ kernel/src/process/mod.rs | 25 +++++--- user/apps/c_unitest/test_epoll_socket.c | 1 - 13 files changed, 207 insertions(+), 133 deletions(-) diff --git a/kernel/src/filesystem/vfs/mod.rs b/kernel/src/filesystem/vfs/mod.rs index 73f66cb3f..f6ff006d9 100644 --- a/kernel/src/filesystem/vfs/mod.rs +++ b/kernel/src/filesystem/vfs/mod.rs @@ -26,6 +26,7 @@ use crate::{ spinlock::{SpinLock, SpinLockGuard}, }, mm::{fault::PageFaultMessage, VmFaultReason}, + net::socket::Socket, process::ProcessManager, time::PosixTimeSpec, }; @@ -657,16 +658,6 @@ pub trait IndexNode: Any + Sync + Send + 
Debug + CastFromSync { Err(SystemError::ENOSYS) } - /// # 将当前Inode转换为Socket类型 - /// 如果当前Inode不是Socket类型,则返回None - /// - /// # 注意 - /// 这个方法已经为dyn Socket实现, - /// 所以如果可以确定当前`dyn IndexNode`是`dyn Socket`类型,则可以直接调用此方法进行转换 - fn as_socket(&self) -> Option<&dyn Socket> { - None - } - /// @brief 按文件名获取扩展属性 /// /// @param name 属性名称 @@ -697,6 +688,19 @@ pub trait IndexNode: Any + Sync + Send + Debug + CastFromSync { ); return Err(SystemError::ENOSYS); } + + /// # 将当前Inode转换为 Socket 引用 + /// + /// # 返回值 + /// - Some(&dyn Socket): 当前Inode是Socket类型,返回其引用 + /// - None: 当前Inode不是Socket类型 + /// + /// # 注意 + /// 这个方法已经为dyn Socket实现, + /// 所以如果可以确定当前`dyn IndexNode`是`dyn Socket`类型,则可以直接调用此方法进行转换 + fn as_socket(&self) -> Option<&dyn Socket> { + None + } } impl DowncastArc for dyn IndexNode { diff --git a/kernel/src/net/posix.rs b/kernel/src/net/posix.rs index ec2a4f3e6..b1892f5b5 100644 --- a/kernel/src/net/posix.rs +++ b/kernel/src/net/posix.rs @@ -36,13 +36,12 @@ impl PosixArgsSocketType { } use super::socket::{endpoint::Endpoint, AddressFamily}; +use crate::net::socket::netlink::addr::{multicast::GroupIdSet, NetlinkSocketAddr}; use crate::net::socket::unix::UnixEndpoint; use alloc::string::ToString; use core::ffi::CStr; use system_error::SystemError; -use crate::net::socket::netlink::addr::{GroupIdSet, NetlinkSocketAddr}; - // 参考资料: https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/netinet_in.h.html#tag_13_32 #[repr(C)] #[derive(Debug, Clone, Copy)] @@ -105,8 +104,8 @@ impl From for SockAddr { smoltcp::wire::IpAddress::Ipv4(ipv4_addr) => Self { addr_in: SockAddrIn { sin_family: AddressFamily::INet as u16, - sin_port: value.port, - sin_addr: ipv4_addr.to_bits(), + sin_port: value.port.to_be(), + sin_addr: ipv4_addr.to_bits().to_be(), sin_zero: Default::default(), }, }, @@ -148,12 +147,26 @@ impl From for SockAddr { } } +impl From for SockAddr { + fn from(value: NetlinkSocketAddr) -> Self { + SockAddr { + addr_nl: SockAddrNl { + nl_family: AddressFamily::Netlink, + nl_pad: 0, 
+ nl_pid: value.port(), + nl_groups: value.groups().as_u32(), + }, + } + } +} + impl From for SockAddr { fn from(value: Endpoint) -> Self { match value { Endpoint::LinkLayer(_link_layer_endpoint) => todo!(), Endpoint::Ip(endpoint) => Self::from(endpoint), Endpoint::Unix(unix_endpoint) => Self::from(unix_endpoint), + Endpoint::Netlink(netlink_addr) => Self::from(netlink_addr), } } } diff --git a/kernel/src/net/routing/mod.rs b/kernel/src/net/routing/mod.rs index 205dcbc9b..5c46e0e0c 100644 --- a/kernel/src/net/routing/mod.rs +++ b/kernel/src/net/routing/mod.rs @@ -230,7 +230,7 @@ pub trait RouterEnableDevice: Iface { // 检查是否是发给自己的包(目标IP是否是自己的IP) if self.is_my_ip(dst_ip.into()) { // todo 按照linux的逻辑,只要包的目标ip在当前网络命名空间里面,就直接进入本地协议栈处理 - // todo 但是我们的操作系统中每个接口都是独立的,并没有统一处理和分发(socket),所有这里必须将包放到对应iface的接收队列里面 + // todo 但是我们的操作系统中每个接口都是独立的,并没有统一处理和分发(socket),所有这里必须将包放到对应iface的接收队列里面 // 交给本地协议栈处理 // log::info!("Packet destined for local interface {}", self.iface_name()); return Err(None); @@ -422,13 +422,13 @@ pub trait RouterEnableDevice: Iface { let mut snat_guard = tracker.snat.lock(); if let Some((new_src_ip, new_src_port)) = snat_guard.process_new_connection(tuple) { - // log::info!( - // "SNAT: Translating {}:{} -> {}:{}", - // tuple.src_addr, - // tuple.src_port, - // new_src_ip, - // new_src_port - // ); + log::info!( + "SNAT: Translating {}:{} -> {}:{}", + tuple.src_addr, + tuple.src_port, + new_src_ip, + new_src_port + ); //TODO 应该加一个判断snat,可以支持直接改成出口接口的ip // // 修改源IP地址 diff --git a/kernel/src/net/socket/base.rs b/kernel/src/net/socket/base.rs index 105530182..4620ed2ff 100644 --- a/kernel/src/net/socket/base.rs +++ b/kernel/src/net/socket/base.rs @@ -138,11 +138,4 @@ pub trait Socket: PollableInode + IndexNode { fn write(&self, buffer: &[u8]) -> Result { self.send(buffer, PMSG::empty()) } - - fn into_socket(this: Arc) -> Option> - where - Self: Sized, - { - Some(this) - } } diff --git a/kernel/src/net/socket/inet/datagram/mod.rs 
b/kernel/src/net/socket/inet/datagram/mod.rs index 16dc94503..45980e977 100644 --- a/kernel/src/net/socket/inet/datagram/mod.rs +++ b/kernel/src/net/socket/inet/datagram/mod.rs @@ -142,29 +142,6 @@ impl UdpSocket { return result; } - pub fn event(&self) -> EPollEventType { - // log::info!("UdpSocket::event"); - let mut event = EPollEventType::empty(); - match self.inner.read().as_ref().unwrap() { - UdpInner::Unbound(_) => { - event.insert(EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND); - } - UdpInner::Bound(bound) => { - let (can_recv, can_send) = - bound.with_socket(|socket| (socket.can_recv(), socket.can_send())); - - if can_recv { - event.insert(EP::EPOLLIN | EP::EPOLLRDNORM); - } - - if can_send { - event.insert(EP::EPOLLOUT | EP::EPOLLWRNORM | EP::EPOLLWRBAND); - } - } - } - return event; - } - pub fn netns(&self) -> Arc { self.netns.clone() } diff --git a/kernel/src/net/socket/inet/stream/mod.rs b/kernel/src/net/socket/inet/stream/mod.rs index 92a09ee1f..c1466f5a2 100644 --- a/kernel/src/net/socket/inet/stream/mod.rs +++ b/kernel/src/net/socket/inet/stream/mod.rs @@ -4,13 +4,10 @@ use system_error::SystemError; use crate::libs::rwlock::RwLock; use crate::libs::wait_queue::WaitQueue; -use crate::net::socket::common::shutdown::{ShutdownBit, ShutdownTemp}; use crate::net::socket::common::EPollItems; -use crate::net::socket::endpoint::Endpoint; -use crate::net::socket::{Socket, SocketInode, PMSG, PSOL}; +use crate::net::socket::{common::ShutdownBit, endpoint::Endpoint, Socket, PMSG, PSOL}; use crate::process::namespace::net_namespace::NetNamespace; use crate::process::ProcessManager; - use crate::sched::SchedMode; use smoltcp; @@ -268,7 +265,7 @@ impl TcpSocket { #[inline] fn incoming(&self) -> bool { - EP::from_bits_truncate(self.poll() as u32).contains(EP::EPOLLIN) + EP::from_bits_truncate(self.do_poll() as u32).contains(EP::EPOLLIN) } #[inline] diff --git a/kernel/src/net/socket/inode.rs b/kernel/src/net/socket/inode.rs index 37802006b..3601ce667 100644 --- 
a/kernel/src/net/socket/inode.rs +++ b/kernel/src/net/socket/inode.rs @@ -68,6 +68,10 @@ impl IndexNode for T { ModeType::from_bits_truncate(0o755), )) } + + fn as_socket(&self) -> Option<&dyn Socket> { + Some(self) + } } impl PollableInode for T { diff --git a/kernel/src/net/socket/netlink/common/mod.rs b/kernel/src/net/socket/netlink/common/mod.rs index 3f7f412ea..ed2e88381 100644 --- a/kernel/src/net/socket/netlink/common/mod.rs +++ b/kernel/src/net/socket/netlink/common/mod.rs @@ -90,6 +90,10 @@ where .contains(EPollEventType::EPOLLIN) } + pub fn do_poll(&self) -> usize { + self.inner.read().check_io_events().bits() as usize + } + pub fn netns(&self) -> Arc { self.netns.clone() } @@ -164,8 +168,8 @@ where // self.try_recv(buffer, flags) } - fn poll(&self) -> usize { - self.inner.read().check_io_events().bits() as usize + fn check_io_event(&self) -> crate::filesystem::epoll::EPollEventType { + EPollEventType::from_bits_truncate(self.do_poll() as u32) } fn send_buffer_size(&self) -> usize { @@ -184,6 +188,51 @@ where let (len, _) = self.recv_from(buffer, flags, None)?; Ok(len) } + + fn recv_msg( + &self, + _msg: &mut crate::net::posix::MsgHdr, + _flags: PMSG, + ) -> Result { + todo!("implement recv_msg for netlink socket"); + } + + fn send(&self, buffer: &[u8], flags: PMSG) -> Result { + self.try_send(buffer, None, flags) + } + + fn send_msg( + &self, + _msg: &crate::net::posix::MsgHdr, + _flags: PMSG, + ) -> Result { + todo!("implement send_msg for netlink socket"); + } + + fn do_close(&self) -> Result<(), SystemError> { + //TODO close the socket properly + Ok(()) + } + + fn epoll_items(&self) -> &crate::net::socket::common::EPollItems { + todo!("implement epoll_items for netlink socket"); + } + + fn local_endpoint(&self) -> Result { + if let Some(addr) = self.inner.read().addr() { + Ok(addr.into()) + } else { + Err(SystemError::ENOTCONN) + } + } + + fn remote_endpoint(&self) -> Result { + if let Some(addr) = self.inner.read().peer_addr() { + Ok(addr.into()) + } 
else { + Err(SystemError::ENOTCONN) + } + } } impl NetlinkSocket

{ diff --git a/kernel/src/net/socket/netlink/mod.rs b/kernel/src/net/socket/netlink/mod.rs index 7a9df50e5..59382aec3 100644 --- a/kernel/src/net/socket/netlink/mod.rs +++ b/kernel/src/net/socket/netlink/mod.rs @@ -1,10 +1,9 @@ use crate::net::socket::{ - family, netlink::{ route::NetlinkRouteSocket, table::{is_valid_protocol, StandardNetlinkProtocol}, }, - SocketInode, + Socket, PSOCK, }; use alloc::sync::Arc; use system_error::SystemError; @@ -16,43 +15,38 @@ mod receiver; mod route; pub mod table; -pub struct Netlink; +pub fn create_netlink_socket( + socket_type: PSOCK, + protocol: u32, + _is_nonblock: bool, +) -> Result, SystemError> { + match socket_type { + super::PSOCK::Raw | super::PSOCK::Datagram => { + let nl_protocol = StandardNetlinkProtocol::try_from(protocol); + let inode = match nl_protocol { + Ok(StandardNetlinkProtocol::ROUTE) => NetlinkRouteSocket::new(false), + Ok(_) => { + log::warn!( + "standard netlink families {} is not supported yet", + protocol + ); + return Err(SystemError::EAFNOSUPPORT); + } + Err(_) => { + if is_valid_protocol(protocol) { + log::error!("user-provided netlink family is not supported"); + return Err(SystemError::EPROTONOSUPPORT); + } + log::error!("invalid netlink protocol: {}", protocol); + return Err(SystemError::EAFNOSUPPORT); + } + }; -impl family::Family for Netlink { - fn socket( - stype: super::PSOCK, - protocol: u32, - ) -> Result, SystemError> { - match stype { - super::PSOCK::Raw | super::PSOCK::Datagram => create_netlink_socket(protocol), - _ => { - log::warn!("unsupported socket type for Netlink"); - Err(SystemError::EPROTONOSUPPORT) - } + Ok(inode) } - } -} - -fn create_netlink_socket(protocol: u32) -> Result, SystemError> { - let nl_protocol = StandardNetlinkProtocol::try_from(protocol); - let inode = match nl_protocol { - Ok(StandardNetlinkProtocol::ROUTE) => NetlinkRouteSocket::new(false), - Ok(_) => { - log::warn!( - "standard netlink families {} is not supported yet", - protocol - ); - return 
Err(SystemError::EAFNOSUPPORT); + _ => { + log::warn!("unsupported socket type for Netlink"); + Err(SystemError::EPROTONOSUPPORT) } - Err(_) => { - if is_valid_protocol(protocol) { - log::error!("user-provided netlink family is not supported"); - return Err(SystemError::EPROTONOSUPPORT); - } - log::error!("invalid netlink protocol: {}", protocol); - return Err(SystemError::EAFNOSUPPORT); - } - }; - - Ok(SocketInode::new(inode)) + } } diff --git a/kernel/src/net/socket/utils/mod.rs b/kernel/src/net/socket/utils/mod.rs index 96bf443b4..b1825319b 100644 --- a/kernel/src/net/socket/utils/mod.rs +++ b/kernel/src/net/socket/utils/mod.rs @@ -1,9 +1,9 @@ -use crate::net::socket::{ - self, inet::syscall::create_inet_socket, unix::create_unix_socket, Socket, -}; pub(super) mod datagram_common; -use crate::net::socket; +use crate::net::socket::{ + self, inet::syscall::create_inet_socket, netlink::create_netlink_socket, + unix::create_unix_socket, Socket, +}; use alloc::sync::Arc; use system_error::SystemError; @@ -17,10 +17,20 @@ pub fn create_socket( // log::info!("Creating socket: {:?}, {:?}, {:?}", family, socket_type, protocol); type AF = socket::AddressFamily; let inode = match family { - AF::INet => socket::inet::Inet::socket(socket_type, protocol)?, - // AF::INet6 => socket::inet::Inet6::socket(socket_type, protocol)?, - AF::Unix => socket::unix::Unix::socket(socket_type, protocol)?, - AF::Netlink => socket::netlink::Netlink::socket(socket_type, protocol)?, + AF::INet => create_inet_socket( + smoltcp::wire::IpVersion::Ipv4, + socket_type, + smoltcp::wire::IpProtocol::from(protocol as u8), + is_nonblock, + )?, + AF::INet6 => create_inet_socket( + smoltcp::wire::IpVersion::Ipv6, + socket_type, + smoltcp::wire::IpProtocol::from(protocol as u8), + is_nonblock, + )?, + AF::Unix => create_unix_socket(socket_type, is_nonblock)?, + AF::Netlink => create_netlink_socket(socket_type, protocol, is_nonblock)?, _ => { log::warn!("unsupport address family"); return 
Err(SystemError::EAFNOSUPPORT); diff --git a/kernel/src/net/syscall.rs b/kernel/src/net/syscall.rs index 7e7f0eee9..35dc8054e 100644 --- a/kernel/src/net/syscall.rs +++ b/kernel/src/net/syscall.rs @@ -124,9 +124,13 @@ impl Syscall { optval: &[u8], ) -> Result { let sol = socket::PSOL::try_from(level as u32)?; - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; + let socket = ProcessManager::current_pcb().get_socket_inode(fd as i32)?; log::debug!("setsockopt: level = {:?} ", sol); - return socket.set_option(sol, optname, optval).map(|_| 0); + return socket + .as_socket() + .unwrap() + .set_option(sol, optname, optval) + .map(|_| 0); } /// @brief sys_getsockopt系统调用的实际执行函数 @@ -147,7 +151,8 @@ impl Syscall { ) -> Result { // 获取socket let optval = optval as *mut u32; - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; + let socket_inode = ProcessManager::current_pcb().get_socket_inode(fd as i32)?; + let socket = socket_inode.as_socket().unwrap(); use socket::{PSO, PSOL}; @@ -207,8 +212,11 @@ impl Syscall { /// @return 成功返回0,失败返回错误码 pub fn connect(fd: usize, addr: *const SockAddr, addrlen: u32) -> Result { let endpoint: Endpoint = SockAddr::to_endpoint(addr, addrlen)?; - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; - socket.connect(endpoint)?; + ProcessManager::current_pcb() + .get_socket_inode(fd as i32)? + .as_socket() + .unwrap() + .connect(endpoint)?; Ok(0) } @@ -228,9 +236,11 @@ impl Syscall { // addrlen // ); let endpoint: Endpoint = SockAddr::to_endpoint(addr, addrlen)?; - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; - // log::debug!("bind: socket={:?}", socket); - socket.bind(endpoint)?; + ProcessManager::current_pcb() + .get_socket_inode(fd as i32)? 
+ .as_socket() + .unwrap() + .bind(endpoint)?; Ok(0) } @@ -258,7 +268,8 @@ impl Syscall { let flags = socket::PMSG::from_bits_truncate(flags); - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; + let socket_inode = ProcessManager::current_pcb().get_socket_inode(fd as i32)?; + let socket = socket_inode.as_socket().unwrap(); if let Some(endpoint) = endpoint { return socket.send_to(buf, flags, endpoint); @@ -283,7 +294,8 @@ impl Syscall { addr: *mut SockAddr, addr_len: *mut u32, ) -> Result { - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; + let socket_inode = ProcessManager::current_pcb().get_socket_inode(fd as i32)?; + let socket = socket_inode.as_socket().unwrap(); let flags = socket::PMSG::from_bits_truncate(flags); @@ -320,7 +332,8 @@ impl Syscall { let iovs = unsafe { IoVecs::from_user(msg.msg_iov, msg.msg_iovlen, true)? }; let (buf, recv_size) = { - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; + let socket_inode = ProcessManager::current_pcb().get_socket_inode(fd as i32)?; + let socket = socket_inode.as_socket().unwrap(); let flags = socket::PMSG::from_bits_truncate(flags); @@ -344,8 +357,12 @@ impl Syscall { /// /// @return 成功返回0,失败返回错误码 pub fn listen(fd: usize, backlog: usize) -> Result { - let inode = ProcessManager::current_pcb().get_socket(fd as i32)?; - inode.listen(backlog).map(|_| 0) + ProcessManager::current_pcb() + .get_socket_inode(fd as i32)? + .as_socket() + .unwrap() + .listen(backlog) + .map(|_| 0) } /// @brief sys_shutdown系统调用的实际执行函数 @@ -355,8 +372,12 @@ impl Syscall { /// /// @return 成功返回0,失败返回错误码 pub fn shutdown(fd: usize, how: usize) -> Result { - let inode = ProcessManager::current_pcb().get_socket(fd as i32)?; - inode.shutdown(ShutdownBit::try_from(how)?).map(|()| 0) + ProcessManager::current_pcb() + .get_socket_inode(fd as i32)? + .as_socket() + .unwrap() + .shutdown(ShutdownBit::try_from(how)?) 
+ .map(|()| 0) } /// @brief sys_accept系统调用的实际执行函数 @@ -413,7 +434,9 @@ impl Syscall { ) -> Result { let (new_socket, remote_endpoint) = { ProcessManager::current_pcb() - .get_socket(fd as i32)? + .get_socket_inode(fd as i32)? + .as_socket() + .unwrap() .accept()? }; @@ -459,7 +482,9 @@ impl Syscall { return Err(SystemError::EINVAL); } ProcessManager::current_pcb() - .get_socket(fd as i32)? + .get_socket_inode(fd as i32)? + .as_socket() + .unwrap() .local_endpoint()? .write_to_user(addr, addrlen)?; return Ok(0); @@ -482,7 +507,9 @@ impl Syscall { } ProcessManager::current_pcb() - .get_socket(fd as i32)? + .get_socket_inode(fd as i32)? + .as_socket() + .unwrap() .remote_endpoint()? .write_to_user(addr, addrlen)?; diff --git a/kernel/src/process/mod.rs b/kernel/src/process/mod.rs index e0fcacb6d..32892449d 100644 --- a/kernel/src/process/mod.rs +++ b/kernel/src/process/mod.rs @@ -14,7 +14,6 @@ use alloc::{ }; use cred::INIT_CRED; use hashbrown::HashMap; -use intertrait::cast::CastArc; use log::{debug, error, info, warn}; use pid::{alloc_pid, Pid, PidLink, PidType}; use process_group::Pgid; @@ -55,7 +54,6 @@ use crate::{ ucontext::AddressSpace, PhysAddr, VirtAddr, }, - net::socket::Socket, sched::{ DequeueFlag, EnqueueFlag, OnRq, SchedMode, WakeupFlags, __schedule, completion::Completion, cpu_rq, fair::FairSchedEntity, prio::MAX_PRIO, @@ -1069,7 +1067,7 @@ impl ProcessControlBlock { self.task_tgid_vnr().unwrap() == RawPid(1) } - /// 根据文件描述符序号,获取socket对象的Arc指针 + /// 根据文件描述符序号,获取socket相应的IndexNode的Arc指针 /// /// this is a helper function /// @@ -1079,8 +1077,13 @@ impl ProcessControlBlock { /// /// ## 返回值 /// - /// Option(&mut Box) socket对象的可变引用. 
如果文件描述符不是socket,那么返回None - pub fn get_socket(&self, fd: i32) -> Result, SystemError> { + /// 如果fd对应的文件是一个socket,那么返回这个socket相应的IndexNode的Arc指针,否则返回错误码 + /// + /// # 注意 + /// 因为底层的Socket中可能包含泛型,经过类型擦除转换成Arc的时候内部的泛型信息会丢失; + /// 因此这里返回Arc,可在外部直接通过 `as_socket()` 转换成 `Option<&dyn Socket>`; + /// 因为内部已经经过检查,因此在外部可以直接 `unwarp` 来获取 `&dyn Socket` + pub fn get_socket_inode(&self, fd: i32) -> Result, SystemError> { let f = ProcessManager::current_pcb() .fd_table() .read() @@ -1093,11 +1096,15 @@ impl ProcessControlBlock { if f.file_type() != FileType::Socket { return Err(SystemError::EBADF); } + + let inode = f.inode(); // log::info!("get_socket: fd {} is a socket", fd); - f.inode().cast::().map_err(|_| { - log::error!("get_socket: fd {} is not a socket", fd); - SystemError::EBADF - }) + if let Some(_sock) = inode.as_socket() { + // log::info!("{:?}", sock); + return Ok(inode); + } + + Err(SystemError::EBADF) } /// 当前进程退出时,让初始进程收养所有子进程 diff --git a/user/apps/c_unitest/test_epoll_socket.c b/user/apps/c_unitest/test_epoll_socket.c index 9d7d02ad1..7b46f8699 100644 --- a/user/apps/c_unitest/test_epoll_socket.c +++ b/user/apps/c_unitest/test_epoll_socket.c @@ -97,7 +97,6 @@ void server_process() { perror("[Server] epoll_wait failed"); exit(EXIT_FAILURE); } - printf("Fuck epoll_wait returned %d\n", nfds); for (int n = 0; n < nfds; ++n) { if (events[n].data.fd == listen_sock) {