diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..4aaa20b43 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,25 @@ +{ + // 使用 IntelliSense 了解相关属性。 + // 悬停以查看现有属性的描述。 + // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug kernel.elf", + "stopOnEntry": false, + "targetCreateCommands": ["target create ${workspaceFolder}/bin/kernel/kernel.elf"], + "program": "${workspaceFolder}/bin/kernel/kernel.elf", + "processCreateCommands": [ + "gdb-remote 127.0.0.1:1234", + "settings set target.process.follow-fork-mode child", + "continue", + ], + "args": [], + "cwd": "${workspaceFolder}", + "sourceLanguages": ["rust"], + "console": "internalConsole" + } + ] +} \ No newline at end of file diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index 30b1d6c05..32d1ef06b 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -135,3 +135,4 @@ debug = true # Controls whether the compiler passes `-g` # The release profile, used for `cargo build --release` [profile.release] debug = false +# debug = true diff --git a/kernel/src/Makefile b/kernel/src/Makefile index 6dd238e17..6ae2d4f6a 100644 --- a/kernel/src/Makefile +++ b/kernel/src/Makefile @@ -1,6 +1,6 @@ SUBDIR_ROOTS := . DIRS := . 
$(shell find $(SUBDIR_ROOTS) -type d) -GARBAGE_PATTERNS := *.o *.s~ *.s *.S~ *.c~ *.h~ kernel +GARBAGE_PATTERNS := *.o *.s~ *.s *.S~ *.c~ *.h~ GARBAGE := $(foreach DIR,$(DIRS),$(addprefix $(DIR)/,$(GARBAGE_PATTERNS))) DIR_LIB=libs diff --git a/kernel/src/driver/net/bridge.rs b/kernel/src/driver/net/bridge.rs new file mode 100644 index 000000000..0f888809d --- /dev/null +++ b/kernel/src/driver/net/bridge.rs @@ -0,0 +1,362 @@ +use crate::{ + driver::net::{ + napi::napi_schedule, register_netdevice, veth::VethInterface, Iface, NetDeivceState, + Operstate, + }, + init::initcall::INITCALL_DEVICE, + libs::{rwlock::RwLock, spinlock::SpinLock}, + process::namespace::net_namespace::{NetNamespace, INIT_NET_NAMESPACE}, + time::Instant, +}; +use alloc::string::ToString; +use alloc::sync::Weak; +use alloc::{collections::BTreeMap, string::String, sync::Arc}; +use core::sync::atomic::AtomicUsize; +use hashbrown::HashMap; +use smoltcp::wire::{EthernetAddress, EthernetFrame, IpAddress, IpCidr}; +use system_error::SystemError; +use unified_init::macros::unified_init; + +/// MAC地址表老化时间 +const MAC_ENTRY_TIMEOUT: u64 = 300_000; // 5分钟 + +pub type BridgePortId = usize; + +#[derive(Debug)] +struct MacEntry { + port_id: BridgePortId, + pub(self) record: RwLock, + // 存活时间(动态学习的老化) +} + +impl MacEntry { + pub fn new(port: BridgePortId) -> Self { + MacEntry { + port_id: port, + record: RwLock::new(MacEntryRecord { + last_seen: Instant::now(), + }), + } + } + + /// 更新最后一次被看到的时间为现在 + pub(self) fn update_last_seen(&self) { + self.record.write_irqsave().last_seen = Instant::now(); + } +} + +#[derive(Debug)] +struct MacEntryRecord { + last_seen: Instant, +} + +/// 代表一个加入bridge的网络接口 +#[derive(Debug, Clone)] +pub struct BridgePort { + pub id: BridgePortId, + pub(super) bridge_enable: Arc, + pub(super) bridge_driver_ref: Weak, + // 当前接口状态?forwarding, learning, blocking? 
+ // mac mtu信息 +} + +impl BridgePort { + fn new( + id: BridgePortId, + device: Arc, + bridge: &Arc, + ) -> Self { + let port = BridgePort { + id, + bridge_enable: device.clone(), + bridge_driver_ref: Arc::downgrade(bridge), + }; + + device.set_common_bridge_data(&port); + + port + } +} + +#[derive(Debug)] +pub struct Bridge { + name: String, + // 端口列表,key为MAC地址 + ports: BTreeMap, + // FDB(Forwarding Database) + mac_table: HashMap, + // 配置参数,比如aging timeout, max age, hello time, forward delay + // bridge_mac: EthernetAddress, +} + +impl Bridge { + pub fn new(name: &str) -> Self { + Self { + name: name.into(), + ports: BTreeMap::new(), + mac_table: HashMap::new(), + } + } + + pub fn add_port(&mut self, id: BridgePortId, port: BridgePort) { + self.ports.insert(id, port); + } + + pub fn remove_port(&mut self, port_id: BridgePortId) { + self.ports.remove(&port_id); + // 清理MAC地址表中与该端口相关的条目 + self.mac_table + .retain(|_mac, entry| entry.port_id != port_id); + } + + fn insert_or_update_mac_entry(&mut self, src_mac: EthernetAddress, port_id: BridgePortId) { + if let Some(entry) = self.mac_table.get(&src_mac) { + entry.update_last_seen(); + // 如果 MAC 地址学习到了不同的端口,需要更新 + if entry.port_id != port_id { + // log::info!("Bridge {}: MAC {} moved from port {} to port {}", self.name, src_mac, entry.port_id, port_id); + self.mac_table.insert(src_mac, MacEntry::new(port_id)); + } + } else { + // log::info!("Bridge {}: Learned MAC {} on port {}", self.name, src_mac, port_id); + self.mac_table.insert(src_mac, MacEntry::new(port_id)); + } + } + + pub fn handle_frame(&mut self, ingress_port_id: BridgePortId, frame: &[u8]) { + if frame.len() < 14 { + // 使用 smoltcp 提供的最小长度 + // log::warn!("Bridge {}: Received malformed Ethernet frame (too short).", self.name); + return; + } + + let ether_frame = match EthernetFrame::new_checked(frame) { + Ok(f) => f, + Err(_) => { + // log::warn!("Bridge {}: Received malformed Ethernet frame.", self.name); + return; + } + }; + + let dst_mac = 
ether_frame.dst_addr(); + let src_mac = ether_frame.src_addr(); + + self.insert_or_update_mac_entry(src_mac, ingress_port_id); + + if dst_mac.is_broadcast() { + // 广播 这里有可能是arp请求 + self.flood(Some(ingress_port_id), frame); + } else { + // 单播 + if let Some(entry) = self.mac_table.get(&dst_mac) { + let target_port = entry.port_id; + // 避免发回自己 + // if target_port != ingress_port_id { + self.transmit_to_port(target_port, frame); + // } + } else { + // 未知单播 → 广播 + log::info!("unknown unicast, flooding frame"); + self.flood(Some(ingress_port_id), frame); + } + } + + self.sweep_mac_table(); + } + + fn flood(&self, except_port_id: Option, frame: &[u8]) { + match except_port_id { + Some(except_id) => { + for (&port_id, bridge_port) in &self.ports { + if port_id != except_id { + self.transmit_to_device(bridge_port, frame); + } + } + } + None => { + for bridge_port in self.ports.values() { + self.transmit_to_device(bridge_port, frame); + } + } + } + } + + fn transmit_to_port(&self, target_port_id: BridgePortId, frame: &[u8]) { + if let Some(device_arc) = self.ports.get(&target_port_id) { + self.transmit_to_device(device_arc, frame); + } else { + // log::warn!("Bridge {}: Attempted to transmit to non-existent port ID {}", self.name, target_port_id); + } + } + + fn transmit_to_device(&self, device: &BridgePort, frame: &[u8]) { + device.bridge_enable.receive_from_bridge(frame); + if let Some(napi) = device.bridge_enable.napi_struct() { + napi_schedule(napi); + } + } + + pub fn sweep_mac_table(&mut self) { + let now = Instant::now(); + self.mac_table.retain(|_mac, entry| { + now.duration_since(entry.record.read().last_seen) + .unwrap_or_default() + .total_millis() + < MAC_ENTRY_TIMEOUT + }); + } + + pub fn name(&self) -> &str { + &self.name + } +} + +#[derive(Debug)] +pub struct BridgeDriver { + pub inner: SpinLock, + pub netns: RwLock>, + self_ref: Weak, + next_port_id: AtomicUsize, +} + +impl BridgeDriver { + pub fn new(name: &str) -> Arc { + Arc::new_cyclic(|self_ref| 
BridgeDriver { + inner: SpinLock::new(Bridge::new(name)), + netns: RwLock::new(Weak::new()), + self_ref: self_ref.clone(), + next_port_id: AtomicUsize::new(0), + }) + } + + fn next_port_id(&self) -> BridgePortId { + self.next_port_id + .fetch_add(1, core::sync::atomic::Ordering::Relaxed) + } + + pub fn add_device(&self, device: Arc) { + if let Some(netns) = self.netns() { + if !Arc::ptr_eq( + &netns, + &device.net_namespace().unwrap_or(INIT_NET_NAMESPACE.clone()), + ) { + log::warn!("Port and bridge are in different net namespaces"); + return; + } + } + let port = BridgePort::new( + self.next_port_id(), + device.clone(), + &self.self_ref.upgrade().unwrap(), + ); + log::info!("Adding port with id: {}", port.id); + + self.inner.lock().add_port(port.id, port); + } + + pub fn remove_device(&self, device: Arc) { + let Some(common_data) = device.common_bridge_data() else { + log::warn!("Device is not part of any bridge"); + return; + }; + self.inner.lock().remove_port(common_data.id); + } + + pub fn handle_frame(&self, ingress_port_id: BridgePortId, frame: &[u8]) { + self.inner.lock().handle_frame(ingress_port_id, frame); + } + + pub fn name(&self) -> String { + self.inner.lock().name().to_string() + } + + pub fn set_netns(&self, netns: &Arc) { + *self.netns.write() = Arc::downgrade(netns); + } + + pub fn netns(&self) -> Option> { + self.netns.read().upgrade() + } +} + +/// 可供桥接设备应该实现的 trait +pub trait BridgeEnableDevice: Iface { + /// 接收来自桥的数据帧 + fn receive_from_bridge(&self, frame: &[u8]); + + /// 设置桥接相关的公共数据 + fn set_common_bridge_data(&self, _port: &BridgePort); + + /// 获取桥接相关的公共数据 + fn common_bridge_data(&self) -> Option; + // fn bridge(&self) -> Weak { + // let Some(data) = self.common_bridge_data() else { + // return Weak::default(); + // }; + // data.bridge_driver + // } +} + +#[derive(Debug, Clone)] +pub struct BridgeCommonData { + pub id: BridgePortId, + pub bridge_driver_ref: Weak, +} + +fn bridge_probe() { + let (iface1, iface2) = 
VethInterface::new_pair("veth_a", "veth_b"); + let (iface3, iface4) = VethInterface::new_pair("veth_c", "veth_d"); + + let addr1 = IpAddress::v4(200, 0, 0, 1); + let cidr1 = IpCidr::new(addr1, 24); + let addr2 = IpAddress::v4(200, 0, 0, 2); + let cidr2 = IpCidr::new(addr2, 24); + + let addr3 = IpAddress::v4(200, 0, 0, 3); + let cidr3 = IpCidr::new(addr3, 24); + let addr4 = IpAddress::v4(200, 0, 0, 4); + let cidr4 = IpCidr::new(addr4, 24); + + iface1.update_ip_addrs(cidr1); + iface2.update_ip_addrs(cidr2); + iface3.update_ip_addrs(cidr3); + iface4.update_ip_addrs(cidr4); + + iface1.add_default_route_to_peer(addr2); + iface2.add_default_route_to_peer(addr1); + iface3.add_default_route_to_peer(addr4); + iface4.add_default_route_to_peer(addr3); + + // iface1.add_direct_route(cidr4, addr2); + + let turn_on = |a: &Arc| { + a.set_net_state(NetDeivceState::__LINK_STATE_START); + a.set_operstate(Operstate::IF_OPER_UP); + // NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); + INIT_NET_NAMESPACE.add_device(a.clone()); + a.common().set_net_namespace(INIT_NET_NAMESPACE.clone()); + + register_netdevice(a.clone()).expect("register veth device failed"); + }; + + turn_on(&iface1); + turn_on(&iface2); + turn_on(&iface3); + turn_on(&iface4); + + let bridge = BridgeDriver::new("bridge0"); + bridge.set_netns(&INIT_NET_NAMESPACE); + INIT_NET_NAMESPACE.insert_bridge(bridge.clone()); + + bridge.add_device(iface3); + bridge.add_device(iface2); + + log::info!("Bridge device created"); +} + +#[unified_init(INITCALL_DEVICE)] +pub fn bridge_init() -> Result<(), SystemError> { + bridge_probe(); + // log::info!("bridge initialized."); + Ok(()) +} diff --git a/kernel/src/driver/net/e1000e/e1000e_driver.rs b/kernel/src/driver/net/e1000e/e1000e_driver.rs index 98275a6c2..8a448e3db 100644 --- a/kernel/src/driver/net/e1000e/e1000e_driver.rs +++ b/kernel/src/driver/net/e1000e/e1000e_driver.rs @@ -9,14 +9,16 @@ use crate::{ kobject::{KObjType, KObject, KObjectCommonData, KObjectState, 
LockedKObjectState}, }, net::{ - register_netdevice, Iface, IfaceCommon, NetDeivceState, NetDeviceCommonData, Operstate, + register_netdevice, types::InterfaceFlags, Iface, IfaceCommon, NetDeivceState, + NetDeviceCommonData, Operstate, }, }, libs::{ rwlock::{RwLockReadGuard, RwLockWriteGuard}, spinlock::{SpinLock, SpinLockGuard}, }, - net::{generate_iface_id, NET_DEVICES}, + net::generate_iface_id, + process::namespace::net_namespace::INIT_NET_NAMESPACE, time::Instant, }; use alloc::{ @@ -200,9 +202,20 @@ impl E1000EInterface { let iface = smoltcp::iface::Interface::new(iface_config, &mut driver, Instant::now().into()); - let result = Arc::new(E1000EInterface { + let flags = InterfaceFlags::UP + | InterfaceFlags::BROADCAST + | InterfaceFlags::RUNNING + | InterfaceFlags::MULTICAST + | InterfaceFlags::LOWER_UP; + + let iface = Arc::new(E1000EInterface { driver: E1000EDriverWrapper(UnsafeCell::new(driver)), - common: IfaceCommon::new(iface_id, false, iface), + common: IfaceCommon::new( + iface_id, + crate::driver::net::types::InterfaceType::EETHER, + flags, + iface, + ), name: format!("eth{}", iface_id), inner: SpinLock::new(InnerE1000EInterface { netdevice_common: NetDeviceCommonData::default(), @@ -212,7 +225,7 @@ impl E1000EInterface { locked_kobj_state: LockedKObjectState::default(), }); - return result; + iface } pub fn inner(&self) -> SpinLockGuard<'_, InnerE1000EInterface> { @@ -304,7 +317,7 @@ impl Iface for E1000EInterface { return self.name.clone(); } - fn poll(&self) { + fn poll(&self) -> bool { self.common.poll(self.driver.force_get_mut()) } @@ -332,6 +345,15 @@ impl Iface for E1000EInterface { fn set_operstate(&self, state: Operstate) { self.inner().netdevice_common.operstate = state; } + + fn mtu(&self) -> usize { + use smoltcp::phy::Device; + + self.driver + .force_get_mut() + .capabilities() + .max_transmission_unit + } } impl KObject for E1000EInterface { @@ -400,9 +422,12 @@ pub fn e1000e_driver_init(device: E1000EDevice) { 
iface.set_net_state(NetDeivceState::__LINK_STATE_START); // 将网卡的接口信息注册到全局的网卡接口信息表中 - NET_DEVICES - .write_irqsave() - .insert(iface.nic_id(), iface.clone()); + // NET_DEVICES + // .write_irqsave() + // .insert(iface.nic_id(), iface.clone()); + INIT_NET_NAMESPACE.add_device(iface.clone()); + iface.common.set_net_namespace(INIT_NET_NAMESPACE.clone()); + info!("e1000e driver init successfully!\tMAC: [{}]", mac); register_netdevice(iface.clone()).expect("register lo device failed"); diff --git a/kernel/src/driver/net/irq_handle.rs b/kernel/src/driver/net/irq_handle.rs index 5a6cd0db0..3fd223f12 100644 --- a/kernel/src/driver/net/irq_handle.rs +++ b/kernel/src/driver/net/irq_handle.rs @@ -1,10 +1,13 @@ use alloc::sync::Arc; use system_error::SystemError; -use crate::exception::{ - irqdata::IrqHandlerData, - irqdesc::{IrqHandler, IrqReturn}, - IrqNumber, +use crate::{ + exception::{ + irqdata::IrqHandlerData, + irqdesc::{IrqHandler, IrqReturn}, + IrqNumber, + }, + process::namespace::net_namespace::INIT_NET_NAMESPACE, }; /// 默认的网卡中断处理函数 @@ -18,7 +21,8 @@ impl IrqHandler for DefaultNetIrqHandler { _static_data: Option<&dyn IrqHandlerData>, _dynamic_data: Option>, ) -> Result { - super::kthread::wakeup_poll_thread(); + // 这里先暂时唤醒 INIT 网络命名空间的轮询线程 + INIT_NET_NAMESPACE.wakeup_poll_thread(); Ok(IrqReturn::Handled) } } diff --git a/kernel/src/driver/net/kthread.rs b/kernel/src/driver/net/kthread.rs deleted file mode 100644 index 3c634499b..000000000 --- a/kernel/src/driver/net/kthread.rs +++ /dev/null @@ -1,47 +0,0 @@ -use alloc::borrow::ToOwned; -use alloc::sync::Arc; -use unified_init::macros::unified_init; - -use crate::arch::CurrentIrqArch; -use crate::exception::InterruptArch; -use crate::init::initcall::INITCALL_SUBSYS; -use crate::net::NET_DEVICES; -use crate::process::kthread::{KernelThreadClosure, KernelThreadMechanism}; -use crate::process::{ProcessControlBlock, ProcessManager}; -use crate::sched::{schedule, SchedMode}; - -static mut NET_POLL_THREAD: Option> = None; 
- -#[unified_init(INITCALL_SUBSYS)] -pub fn net_poll_init() -> Result<(), system_error::SystemError> { - let closure = KernelThreadClosure::StaticEmptyClosure((&(net_poll_thread as fn() -> i32), ())); - let pcb = KernelThreadMechanism::create_and_run(closure, "net_poll".to_owned()) - .ok_or("") - .expect("create net_poll thread failed"); - log::info!("net_poll thread created"); - unsafe { - NET_POLL_THREAD = Some(pcb); - } - return Ok(()); -} - -fn net_poll_thread() -> i32 { - log::info!("net_poll thread started"); - loop { - for (_, iface) in NET_DEVICES.read_irqsave().iter() { - iface.poll(); - } - let irq_guard = unsafe { CurrentIrqArch::save_and_disable_irq() }; - ProcessManager::mark_sleep(true).expect("clocksource_watchdog_kthread:mark sleep failed"); - drop(irq_guard); - schedule(SchedMode::SM_NONE); - } -} - -/// 拉起线程 -pub(super) fn wakeup_poll_thread() { - if unsafe { NET_POLL_THREAD.is_none() } { - return; - } - let _ = ProcessManager::wakeup(unsafe { NET_POLL_THREAD.as_ref().unwrap() }); -} diff --git a/kernel/src/driver/net/loopback.rs b/kernel/src/driver/net/loopback.rs index 69358c454..a08901b24 100644 --- a/kernel/src/driver/net/loopback.rs +++ b/kernel/src/driver/net/loopback.rs @@ -7,11 +7,13 @@ use crate::driver::base::kobject::{ KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState, }; use crate::driver::base::kset::KSet; +use crate::driver::net::types::InterfaceFlags; use crate::filesystem::kernfs::KernFSInode; use crate::init::initcall::INITCALL_DEVICE; use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; use crate::libs::spinlock::{SpinLock, SpinLockGuard}; -use crate::net::{generate_iface_id, NET_DEVICES}; +use crate::net::generate_iface_id; +use crate::process::namespace::net_namespace::INIT_NET_NAMESPACE; use crate::time::Instant; use alloc::collections::VecDeque; use alloc::fmt::Debug; @@ -88,6 +90,7 @@ impl phy::TxToken for LoopbackTxToken { } } +#[derive(Default)] /// ## Loopback设备 /// 成员是一个队列,用来存放接受到的数据包。 /// 
当使用lo发送数据包时,不会把数据包传到link层,而是直接发送到该队列,实现环回。 @@ -97,12 +100,6 @@ pub struct Loopback { } impl Loopback { - /// ## Loopback创建函数 - /// 创建lo设备 - pub fn new() -> Self { - let queue = VecDeque::new(); - Loopback { queue } - } /// ## Loopback处理接受到的数据包函数 /// Loopback接受到数据后会调用这个函数来弹出接收的数据,返回给协议栈 /// @@ -172,10 +169,10 @@ pub struct LoopbackDriver { pub inner: Arc>, } -impl LoopbackDriver { +impl Default for LoopbackDriver { /// ## LoopbackDriver创建函数 - pub fn new() -> Self { - let inner = Arc::new(SpinLock::new(Loopback::new())); + fn default() -> Self { + let inner = Arc::new(SpinLock::new(Loopback::default())); LoopbackDriver { inner } } } @@ -325,9 +322,19 @@ impl LoopbackInterface { .expect("Add default ipv4 route failed: full"); }); + let flags = InterfaceFlags::LOOPBACK + | InterfaceFlags::UP + | InterfaceFlags::RUNNING + | InterfaceFlags::LOWER_UP; + Arc::new(LoopbackInterface { driver: LoopbackDriverWapper(UnsafeCell::new(driver)), - common: IfaceCommon::new(iface_id, false, iface), + common: IfaceCommon::new( + iface_id, + super::types::InterfaceType::LOOPBACK, + flags, + iface, + ), inner: SpinLock::new(InnerLoopbackInterface { netdevice_common: NetDeviceCommonData::default(), device_common: DeviceCommonData::default(), @@ -485,7 +492,7 @@ impl Iface for LoopbackInterface { smoltcp::wire::EthernetAddress(mac) } - fn poll(&self) { + fn poll(&self) -> bool { self.common.poll(self.driver.force_get_mut()) } @@ -513,24 +520,39 @@ impl Iface for LoopbackInterface { fn set_operstate(&self, state: Operstate) { self.inner().netdevice_common.operstate = state; } + + fn mtu(&self) -> usize { + use smoltcp::phy::Device; + + self.driver + .force_get_mut() + .capabilities() + .max_transmission_unit + } +} + +pub fn generate_loopback_iface_default() -> Arc { + let iface = LoopbackInterface::new(LoopbackDriver::default()); + // 标识网络设备已经启动 + iface.set_net_state(NetDeivceState::__LINK_STATE_START); + + register_netdevice(iface.clone()).expect("register lo device failed"); + + iface } 
pub fn loopback_probe() { loopback_driver_init(); } + /// # lo网卡设备初始化函数 /// 创建驱动和iface,初始化一个lo网卡,添加到全局NET_DEVICES中 pub fn loopback_driver_init() { - let driver = LoopbackDriver::new(); - let iface = LoopbackInterface::new(driver); - // 标识网络设备已经启动 - iface.set_net_state(NetDeivceState::__LINK_STATE_START); + let iface = generate_loopback_iface_default(); - NET_DEVICES - .write_irqsave() - .insert(iface.nic_id(), iface.clone()); - - register_netdevice(iface.clone()).expect("register lo device failed"); + INIT_NET_NAMESPACE.add_device(iface.clone()); + INIT_NET_NAMESPACE.set_loopback_iface(iface.clone()); + iface.common.set_net_namespace(INIT_NET_NAMESPACE.clone()); } /// ## lo网卡设备的注册函数 diff --git a/kernel/src/driver/net/mod.rs b/kernel/src/driver/net/mod.rs index 98187cd86..3753a8909 100644 --- a/kernel/src/driver/net/mod.rs +++ b/kernel/src/driver/net/mod.rs @@ -1,7 +1,14 @@ +use alloc::sync::Weak; use alloc::{fmt, vec::Vec}; use alloc::{string::String, sync::Arc}; +use core::net::Ipv4Addr; use sysfs::netdev_register_kobject; +use crate::driver::net::napi::NapiStruct; +use crate::driver::net::types::{InterfaceFlags, InterfaceType}; +use crate::libs::rwlock::RwLockReadGuard; +use crate::net::routing::RouterEnableDeviceCommon; +use crate::process::namespace::net_namespace::NetNamespace; use crate::{ libs::{rwlock::RwLock, spinlock::SpinLock}, net::socket::inet::{common::PortManager, InetSocket}, @@ -10,13 +17,16 @@ use crate::{ use smoltcp; use system_error::SystemError; +pub mod bridge; pub mod class; mod dma; pub mod e1000e; pub mod irq_handle; -pub mod kthread; pub mod loopback; +pub mod napi; pub mod sysfs; +pub mod types; +pub mod veth; pub mod virtio_net; bitflags! 
{ @@ -74,13 +84,11 @@ pub trait Iface: crate::driver::base::device::Device { } /// # `poll` - /// 用于轮询接口的状态。 - /// ## 参数 - /// - `sockets` :一个可变引用到 `smoltcp::iface::SocketSet`,表示要轮询的套接字集 + /// 用于轮询网卡,处理网络事件 /// ## 返回值 - /// - 成功返回 `Ok(())` - /// - 如果轮询失败,返回 `Err(SystemError::EAGAIN_OR_EWOULDBLOCK)`,表示需要再次尝试或者操作会阻塞 - fn poll(&self); + /// - `true`:表示有网络事件发生 + /// - `false`:表示没有网络事件 + fn poll(&self) -> bool; /// # `update_ip_addrs` /// 用于更新接口的 IP 地址 @@ -122,6 +130,34 @@ pub trait Iface: crate::driver::base::device::Device { fn operstate(&self) -> Operstate; fn set_operstate(&self, state: Operstate); + + fn net_namespace(&self) -> Option> { + self.common().net_namespace() + } + + fn set_net_namespace(&self, ns: Arc) { + self.common().set_net_namespace(ns); + } + + fn flags(&self) -> InterfaceFlags { + self.common().flags() + } + + fn type_(&self) -> InterfaceType { + self.common().type_() + } + + fn mtu(&self) -> usize; + + /// # 获取当前iface的napi结构体 + /// 默认返回None,表示不支持napi + fn napi_struct(&self) -> Option> { + self.common().napi_struct.read().clone() + } + + fn router_common(&self) -> &RouterEnableDeviceCommon { + &self.common().router_common_data + } } /// 网络设备的公共数据 @@ -162,6 +198,8 @@ fn register_netdevice(dev: Arc) -> Result<(), SystemError> { pub struct IfaceCommon { iface_id: usize, + flags: InterfaceFlags, + type_: InterfaceType, smol_iface: SpinLock, /// 存smoltcp网卡的套接字集 sockets: SpinLock>, @@ -171,9 +209,12 @@ pub struct IfaceCommon { port_manager: PortManager, /// 下次轮询的时间 poll_at_ms: core::sync::atomic::AtomicU64, - /// 默认网卡标识 - /// TODO: 此字段设置目的是解决对bind unspecified地址的分包问题,需要在inet实现多网卡监听或路由子系统实现后移除 - default_iface: bool, + /// 网络命名空间 + net_namespace: RwLock>, + /// 路由相关数据 + router_common_data: RouterEnableDeviceCommon, + /// NAPI 结构体 + napi_struct: RwLock>>, } impl fmt::Debug for IfaceCommon { @@ -186,19 +227,33 @@ impl fmt::Debug for IfaceCommon { } impl IfaceCommon { - pub fn new(iface_id: usize, default_iface: bool, iface: smoltcp::iface::Interface) -> Self 
{ + pub fn new( + iface_id: usize, + type_: InterfaceType, + flags: InterfaceFlags, + iface: smoltcp::iface::Interface, + ) -> Self { + let router_common_data = RouterEnableDeviceCommon::default(); + router_common_data + .ip_addrs + .write() + .extend_from_slice(iface.ip_addrs()); IfaceCommon { iface_id, smol_iface: SpinLock::new(iface), sockets: SpinLock::new(smoltcp::iface::SocketSet::new(Vec::new())), bounds: RwLock::new(Vec::new()), - port_manager: PortManager::new(), + port_manager: PortManager::default(), poll_at_ms: core::sync::atomic::AtomicU64::new(0), - default_iface, + net_namespace: RwLock::new(Weak::new()), + router_common_data, + flags, + type_, + napi_struct: RwLock::new(None), } } - pub fn poll(&self, device: &mut D) + pub fn poll(&self, device: &mut D) -> bool where D: smoltcp::phy::Device + ?Sized, { @@ -227,6 +282,12 @@ impl IfaceCommon { // drop sockets here to avoid deadlock drop(interface); drop(sockets); + // log::info!( + // "polling iface {}, has_events: {}, poll_at: {:?}", + // self.iface_id, + // has_events, + // poll_at + // ); use core::sync::atomic::Ordering; if let Some(instant) = poll_at { @@ -259,6 +320,7 @@ impl IfaceCommon { // .extract_if(|closing_socket| closing_socket.is_closed()) // .collect::>(); // drop(closed_sockets); + has_events } pub fn update_ip_addrs(&self, ip_addrs: &[smoltcp::wire::IpCidr]) -> Result<(), SystemError> { @@ -291,8 +353,36 @@ impl IfaceCommon { } } - // TODO: 需要在inet实现多网卡监听或路由子系统实现后移除 - pub fn is_default_iface(&self) -> bool { - self.default_iface + pub fn ipv4_addr(&self) -> Option { + self.smol_iface.lock().ipv4_addr() + } + + pub fn ip_addrs(&self) -> RwLockReadGuard<'_, Vec> { + self.router_common_data.ip_addrs.read() + } + + pub fn prefix_len(&self) -> Option { + self.smol_iface + .lock() + .ip_addrs() + .first() + .map(|ip_addr| ip_addr.prefix_len()) + } + + pub fn net_namespace(&self) -> Option> { + self.net_namespace.read().upgrade() + } + + pub fn set_net_namespace(&self, ns: Arc) { + let mut 
guard = self.net_namespace.write(); + *guard = Arc::downgrade(&ns); + } + + pub fn flags(&self) -> InterfaceFlags { + self.flags + } + + pub fn type_(&self) -> InterfaceType { + self.type_ } } diff --git a/kernel/src/driver/net/napi.rs b/kernel/src/driver/net/napi.rs new file mode 100644 index 000000000..1321ef930 --- /dev/null +++ b/kernel/src/driver/net/napi.rs @@ -0,0 +1,268 @@ +use crate::driver::net::Iface; +use crate::init::initcall::INITCALL_SUBSYS; +use crate::libs::spinlock::{SpinLock, SpinLockGuard}; +use crate::libs::wait_queue::WaitQueue; +use crate::process::kthread::{KernelThreadClosure, KernelThreadMechanism}; +use crate::process::ProcessState; +use alloc::boxed::Box; +use alloc::string::ToString; +use alloc::sync::{Arc, Weak}; +use alloc::vec::Vec; +use core::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use system_error::SystemError; +use unified_init::macros::unified_init; + +lazy_static! { + //todo 按照软中断的做法,这里应该是每个CPU一个列表,但目前只实现单CPU版本 + static ref GLOBAL_NAPI_MANAGER: Arc = + NapiManager::new(); +} + +/// # NAPI 结构体 +/// +/// https://elixir.bootlin.com/linux/v6.13/source/include/linux/netdevice.h#L359 +#[derive(Debug)] +pub struct NapiStruct { + /// NAPI实例状态 + pub state: AtomicU32, + /// NAPI实例权重,表示每次轮询时处理的最大包数 + pub weight: usize, + /// 唯一id + pub napi_id: usize, + /// 指向所属网卡的弱引用 + pub net_device: Weak, +} + +impl NapiStruct { + pub fn new(net_device: Arc, weight: usize) -> Arc { + Arc::new(Self { + state: AtomicU32::new(NapiState::empty().bits()), + weight, + napi_id: net_device.nic_id(), + net_device: Arc::downgrade(&net_device), + }) + } + + pub fn poll(&self) -> bool { + // log::info!("NAPI instance {} polling", self.napi_id); + // 获取网卡的强引用 + if let Some(iface) = self.net_device.upgrade() { + // 这里的weight原意是此次执行可以处理的包,如果超过了这个数就交给专门的内核线程(ksoftirqd)继续处理 + // 但目前我们就是在相当于ksoftirqd里面处理,如果在weight之内发现没数据包被处理了,在直接返回 + // 如果超过weight,返回true,表示还有工作没做完,会在下一次轮询继续处理 + // 因此语义是相同的 + for _ in 0..self.weight { + if !iface.poll() { + return false; + } + } 
+ } else { + log::error!( + "NAPI instance {}: associated net device is gone", + self.napi_id + ); + } + + true + } +} + +bitflags! { + /// # NAPI状态标志 + /// + /// https://elixir.bootlin.com/linux/v6.13/source/include/linux/netdevice.h#L398 + pub struct NapiState:u32{ + /// Poll is scheduled. 这是最核心的状态,表示NAPI实例已被调度, + /// 存在于某个CPU的poll_list中等待处理。 + const SCHED = 1 << 0; + /// Missed a poll. 如果在NAPI实例被调度后但在实际处理前又有新的数据到达, + const MISSED = 1 << 1; + /// Disable pending. NAPI正在被禁用,不应再被调度。 + const DISABLE = 1 << 2; + const NPSVC = 1 << 3; + /// NAPI added to system lists. 表示NAPI实例已注册到设备中。 + const LISTED = 1 << 4; + const NO_BUSY_POLL = 1 << 5; + const IN_BUSY_POLL = 1 << 6; + const PREFER_BUSY_POLL = 1 << 7; + /// The poll is performed inside its own thread. + /// 一个可选的高级功能,表示此NAPI由专用内核线程处理。 + const THREADED = 1 << 8; + const SCHED_THREADED = 1 << 9; + } +} + +#[inline(never)] +#[unified_init(INITCALL_SUBSYS)] +pub fn napi_init() -> Result<(), SystemError> { + // 软中断做法 + // let napi_handler = Arc::new(NapiSoftirq::default()); + // softirq_vectors() + // .register_softirq(SoftirqNumber::NetReceive, napi_handler) + // .expect("Failed to register napi softirq"); + + // 软中断的方式无法唤醒 :( + // 使用一个专门的内核线程来处理NAPI轮询,模拟软中断的行为,相当于ksoftirq :) + + let closure: Box i32 + Send + Sync + 'static> = Box::new(move || { + net_rx_action(); + 0 + }); + let closure = KernelThreadClosure::EmptyClosure((closure, ())); + let name = "napi_handler".to_string(); + let _pcb = KernelThreadMechanism::create_and_run(closure, name) + .ok_or("") + .expect("create napi_handler thread failed"); + + log::info!("napi initialized successfully"); + Ok(()) +} + +fn net_rx_action() { + use crate::sched::SchedMode; + + loop { + // 这里直接将全局的NAPI管理器的napi_list取出,清空全局的列表,避免占用锁时间过长 + let mut inner = GLOBAL_NAPI_MANAGER.inner(); + let mut poll_list = inner.napi_list.clone(); + inner.napi_list.clear(); + drop(inner); + + // log::info!("NAPI softirq processing {} instances", poll_list.len()); + + // 如果此时长度为0,则让当前进程休眠,等待被唤醒 + 
if poll_list.is_empty() { + GLOBAL_NAPI_MANAGER + .inner() + .has_pending_signal + .store(false, Ordering::SeqCst); + } + + while let Some(napi) = poll_list.pop() { + let has_work_left = napi.poll(); + // log::info!("yes"); + + if has_work_left { + poll_list.push(napi); + } else { + napi_complete(napi); + } + } + + // log::info!("napi softirq iteration complete") + + // 在这种情况下,poll_list 中仍然有待处理的 NAPI 实例,压回队列,等待下一次唤醒时处理 + if !poll_list.is_empty() { + GLOBAL_NAPI_MANAGER.inner().napi_list.extend(poll_list); + } + + let _ = wq_wait_event_interruptible!( + GLOBAL_NAPI_MANAGER.wait_queue(), + GLOBAL_NAPI_MANAGER + .inner() + .has_pending_signal + .load(Ordering::SeqCst), + {} + ); + } +} + +/// 标记这个napi任务已经完成 +pub fn napi_complete(napi: Arc) { + napi.state + .fetch_and(!NapiState::SCHED.bits(), Ordering::SeqCst); +} + +/// 标记这个napi任务加入处理队列,已被调度 +pub fn napi_schedule(napi: Arc) { + let current_state = NapiState::from_bits_truncate( + napi.state + .fetch_or(NapiState::SCHED.bits(), Ordering::SeqCst), + ); + + if !current_state.contains(NapiState::SCHED) { + let new_state = current_state.union(NapiState::SCHED); + // log::info!("NAPI instance {} scheduled", napi.napi_id); + napi.state.store(new_state.bits(), Ordering::SeqCst); + } + + let mut inner = GLOBAL_NAPI_MANAGER.inner(); + inner.napi_list.push(napi); + inner.has_pending_signal.store(true, Ordering::SeqCst); + + GLOBAL_NAPI_MANAGER.wakeup(); + + // softirq_vectors().raise_softirq(SoftirqNumber::NetReceive); +} + +pub struct NapiManager { + inner: SpinLock, + wait_queue: WaitQueue, +} + +impl NapiManager { + pub fn new() -> Arc { + let inner = SpinLock::new(NapiManagerInner { + has_pending_signal: AtomicBool::new(false), + napi_list: Vec::new(), + }); + Arc::new(Self { + inner, + wait_queue: WaitQueue::default(), + }) + } + + pub fn inner(&self) -> SpinLockGuard<'_, NapiManagerInner> { + self.inner.lock() + } + + pub fn wait_queue(&self) -> &WaitQueue { + &self.wait_queue + } + + pub fn wakeup(&self) { + 
self.wait_queue.wakeup(Some(ProcessState::Blocked(true))); + } +} + +pub struct NapiManagerInner { + has_pending_signal: AtomicBool, + napi_list: Vec>, +} + +// 下面的是软中断的做法,无法唤醒,做个记录 + +// #[derive(Debug)] +// pub struct NapiSoftirq { +// running: AtomicBool, +// } + +// impl Default for NapiSoftirq { +// fn default() -> Self { +// Self { +// running: AtomicBool::new(false), +// } +// } +// } + +// impl SoftirqVec for NapiSoftirq { +// fn run(&self) { +// log::info!("NAPI softirq running"); +// if self +// .running +// .compare_exchange( +// false, +// true, +// core::sync::atomic::Ordering::SeqCst, +// core::sync::atomic::Ordering::SeqCst, +// ) +// .is_ok() +// { +// net_rx_action(); +// self.running +// .store(false, core::sync::atomic::Ordering::SeqCst); +// } else { +// log::warn!("NAPI softirq is already running"); +// } +// } +// } diff --git a/kernel/src/driver/net/types.rs b/kernel/src/driver/net/types.rs new file mode 100644 index 000000000..05fbee657 --- /dev/null +++ b/kernel/src/driver/net/types.rs @@ -0,0 +1,84 @@ +use system_error::SystemError; + +/// Interface type. +/// +#[repr(u16)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, FromPrimitive, ToPrimitive)] +pub enum InterfaceType { + // Arp protocol hardware identifiers + /// from KA9Q: NET/ROM pseudo + NETROM = 0, + /// Ethernet 10Mbps + ETHER = 1, + /// Experimental Ethernet + EETHER = 2, + + // Dummy types for non ARP hardware + /// IPIP tunnel + TUNNEL = 768, + /// IP6IP6 tunnel + TUNNEL6 = 769, + /// Frame Relay Access Device + FRAD = 770, + /// SKIP vif + SKIP = 771, + /// Loopback device + LOOPBACK = 772, + /// Localtalk device + LOCALTALK = 773, + // TODO 更多类型 +} + +impl TryFrom for InterfaceType { + type Error = SystemError; + + fn try_from(value: u16) -> Result { + use num_traits::FromPrimitive; + return ::from_u16(value).ok_or(Self::Error::EINVAL); + } +} + +bitflags! { + /// Interface flags. 
+ /// + pub struct InterfaceFlags: u32 { + /// Interface is up + const UP = 1<<0; + /// Broadcast address valid + const BROADCAST = 1<<1; + /// Turn on debugging + const DEBUG = 1<<2; + /// Loopback net + const LOOPBACK = 1<<3; + /// Interface is has p-p link + const POINTOPOINT = 1<<4; + /// Avoid use of trailers + const NOTRAILERS = 1<<5; + /// Interface RFC2863 OPER_UP + const RUNNING = 1<<6; + /// No ARP protocol + const NOARP = 1<<7; + /// Receive all packets + const PROMISC = 1<<8; + /// Receive all multicast packets + const ALLMULTI = 1<<9; + /// Master of a load balancer + const MASTER = 1<<10; + /// Slave of a load balancer + const SLAVE = 1<<11; + /// Supports multicast + const MULTICAST = 1<<12; + /// Can set media type + const PORTSEL = 1<<13; + /// Auto media select active + const AUTOMEDIA = 1<<14; + /// Dialup device with changing addresses + const DYNAMIC = 1<<15; + /// Driver signals L1 up + const LOWER_UP = 1<<16; + /// Driver signals dormant + const DORMANT = 1<<17; + /// Echo sent packets + const ECHO = 1<<18; + } +} diff --git a/kernel/src/driver/net/veth.rs b/kernel/src/driver/net/veth.rs new file mode 100644 index 000000000..95a13718e --- /dev/null +++ b/kernel/src/driver/net/veth.rs @@ -0,0 +1,853 @@ +use super::bridge::BridgeEnableDevice; +use super::{Iface, IfaceCommon}; +use super::{NetDeivceState, NetDeviceCommonData, Operstate}; +use crate::arch::rand::rand; +use crate::driver::base::class::Class; +use crate::driver::base::device::bus::Bus; +use crate::driver::base::device::driver::Driver; +use crate::driver::base::device::{self, DeviceCommonData, DeviceType, IdTable}; +use crate::driver::base::kobject::{ + KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState, +}; +use crate::driver::base::kset::KSet; +use crate::driver::net::bridge::{BridgeCommonData, BridgePort}; +use crate::driver::net::napi::{napi_schedule, NapiStruct}; +use crate::driver::net::register_netdevice; +use crate::driver::net::types::InterfaceFlags; 
+use crate::filesystem::kernfs::KernFSInode; +use crate::init::initcall::INITCALL_DEVICE; +use crate::libs::rwlock::{RwLockReadGuard, RwLockWriteGuard}; +use crate::libs::spinlock::{SpinLock, SpinLockGuard}; +use crate::net::generate_iface_id; +use crate::net::routing::{DnatRule, RouteEntry, RouterEnableDevice, SnatRule}; +use crate::process::namespace::net_namespace::{NetNamespace, INIT_NET_NAMESPACE}; +use crate::process::ProcessManager; +use alloc::collections::VecDeque; +use alloc::fmt::Debug; +use alloc::string::{String, ToString}; +use alloc::sync::{Arc, Weak}; +use alloc::vec::Vec; +use core::cell::UnsafeCell; +use core::ops::{Deref, DerefMut}; +use smoltcp::phy::DeviceCapabilities; +use smoltcp::phy::{self, RxToken}; +use smoltcp::wire::{EthernetAddress, EthernetFrame, HardwareAddress, IpAddress, IpCidr}; +use system_error::SystemError; +use unified_init::macros::unified_init; + +pub struct Veth { + name: String, + rx_queue: VecDeque>, + /// 对端的 `VethInterface`,在完成数据发送的时候会使用到 + peer: Weak, + self_iface_ref: Weak, +} + +impl Veth { + pub fn new(name: String) -> Self { + Veth { + name, + rx_queue: VecDeque::new(), + peer: Weak::new(), + self_iface_ref: Weak::new(), + } + } + + pub fn set_peer_iface(&mut self, peer: &Arc) { + self.peer = Arc::downgrade(peer); + } + + pub fn send_to_peer(&self, data: &[u8]) { + if let Some(peer) = self.peer.upgrade() { + // log::info!("Veth {} trying to send", self.name); + + Self::to_peer(&peer, data); + } + } + + pub(self) fn to_peer(peer: &Arc, data: &[u8]) { + let mut peer_veth = peer.driver.inner.lock(); + peer_veth.rx_queue.push_back(data.to_vec()); + + // { + // let ether = EthernetFrame::new_checked(data).unwrap(); + // if ether.ethertype() == smoltcp::wire::EthernetProtocol::Ipv4 { + // if let Some(ipv4_packet) = + // smoltcp::wire::Ipv4Packet::new_checked(ether.payload()).ok() + // { + // log::info!( + // "Veth {} sending IPv4 packet to peer: {} -> {}", + // peer.name, + // ipv4_packet.src_addr(), + // 
ipv4_packet.dst_addr() + // ); + // } + // } else if ether.ethertype() == smoltcp::wire::EthernetProtocol::Ipv6 { + // if let Some(ipv6_packet) = + // smoltcp::wire::Ipv6Packet::new_checked(ether.payload()).ok() + // { + // log::info!( + // "Veth {} sending IPv6 packet to peer: {} -> {}", + // peer.name, + // ipv6_packet.src_addr(), + // ipv6_packet.dst_addr() + // ); + // } + // } else { + // log::info!( + // "Veth {} sending non-IP packet to peer: ethertype={:?}", + // peer.name, + // ether.ethertype() + // ); + // } + // } + + drop(peer_veth); + + let Some(napi) = peer.napi_struct() else { + log::error!("Veth {} has no napi_struct", peer.name); + return; + }; + + napi_schedule(napi); + } + + fn to_bridge(bridge_data: &BridgeCommonData, data: &[u8]) { + // log::info!("Veth {} sending data to bridge", self.name); + let Some(bridge) = bridge_data.bridge_driver_ref.upgrade() else { + log::warn!("Bridge has been dropped"); + return; + }; + bridge.handle_frame(bridge_data.id, data); + } + + /// 经过路由发送,返回是否发送成功 + fn to_router(&self, data: &[u8]) -> bool { + let Some(self_iface) = self.self_iface_ref.upgrade() else { + return false; + }; + + let frame: EthernetFrame<&[u8]> = EthernetFrame::new_checked(data).unwrap(); + // log::info!("trying to go to router"); + match self_iface.handle_routable_packet(&frame) { + Ok(_) => { + // log::info!("successfully sent to router"); + true + } + // 先不管错误,直接告诉外面没有经过路由发送出去 + Err(Some(err)) => { + log::error!("Router error: {:?}", err); + false + } + Err(_) => { + // log::info!("not routed"); + false + } + } + } + + pub fn recv_from_peer(&mut self) -> Option> { + // log::info!("Veth {} trying to receive", self.name); + let data = self.rx_queue.pop_front()?; + + if let Some(bridge_common_data) = self + .self_iface_ref + .upgrade() + .unwrap() + .inner + .lock() + .bridge_common_data + .as_ref() + { + // log::info!("Veth {} sending data to bridge", self.name); + Self::to_bridge(bridge_common_data, &data); + return None; + } + + // 
说明获取的包发给进入路由了,无须返回 + if self.to_router(&data) { + return None; + } + + Some(data) + } + + pub fn name(&self) -> &str { + &self.name + } +} + +#[derive(Clone)] +pub struct VethDriver { + pub inner: Arc>, +} + +impl VethDriver { + /// # `new_pair` + /// 创建一对虚拟以太网设备(veth pair),用于网络测试 + /// ## 参数 + /// - `name1`: 第一个设备的名称 + /// - `name2`: 第二个设备的名称 + /// ## 返回值 + /// 返回一个元组,包含两个 `VethDriver` 实例,分别对应 + /// 第一个和第二个虚拟以太网设备。 + pub fn new_pair(name1: &str, name2: &str) -> (VethDriver, VethDriver) { + let dev1 = Arc::new(SpinLock::new(Veth::new(name1.to_string()))); + let dev2 = Arc::new(SpinLock::new(Veth::new(name2.to_string()))); + + let driver1 = VethDriver { inner: dev1 }; + let driver2 = VethDriver { inner: dev2 }; + + (driver1, driver2) + } + + pub fn name(&self) -> String { + self.inner.lock().name().to_string() + } +} + +pub struct VethTxToken { + driver: VethDriver, +} + +impl phy::TxToken for VethTxToken { + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut buf = vec![0; len]; + let result = f(&mut buf); + self.driver.inner.lock().send_to_peer(&buf); + result + } +} + +pub struct VethRxToken { + buffer: Vec, +} + +impl RxToken for VethRxToken { + fn consume(self, f: F) -> R + where + F: FnOnce(&[u8]) -> R, + { + f(&self.buffer) + } +} + +#[derive(Debug)] +struct VethDriverWarpper(UnsafeCell); +unsafe impl Send for VethDriverWarpper {} +unsafe impl Sync for VethDriverWarpper {} + +impl Deref for VethDriverWarpper { + type Target = VethDriver; + fn deref(&self) -> &Self::Target { + unsafe { &*self.0.get() } + } +} + +impl DerefMut for VethDriverWarpper { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.0.get() } + } +} + +impl VethDriverWarpper { + #[allow(clippy::mut_from_ref)] + #[allow(clippy::mut_from_ref)] + fn force_get_mut(&self) -> &mut VethDriver { + unsafe { &mut *self.0.get() } + } +} + +impl phy::Device for VethDriver { + type RxToken<'a> = VethRxToken; + type TxToken<'a> = VethTxToken; + + fn 
capabilities(&self) -> DeviceCapabilities { + let mut caps = DeviceCapabilities::default(); + caps.max_transmission_unit = 1500; + caps.medium = smoltcp::phy::Medium::Ethernet; + caps + } + + fn receive( + &mut self, + _timestamp: smoltcp::time::Instant, + ) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let mut guard = self.inner.lock(); + guard.recv_from_peer().map(|buf| { + // log::info!("VethDriver received data: {:?}", buf); + ( + VethRxToken { buffer: buf }, + VethTxToken { + driver: self.clone(), + }, + ) + }) + } + + fn transmit(&mut self, _timestamp: smoltcp::time::Instant) -> Option> { + Some(VethTxToken { + driver: self.clone(), + }) + } +} + +#[cast_to([sync] Iface)] +#[cast_to([sync] device::Device)] +#[derive(Debug)] +pub struct VethInterface { + name: String, + driver: VethDriverWarpper, + common: IfaceCommon, + inner: SpinLock, + locked_kobj_state: LockedKObjectState, +} + +#[derive(Debug)] +pub struct VethCommonData { + netdevice_common: NetDeviceCommonData, + device_common: DeviceCommonData, + kobj_common: KObjectCommonData, + peer_veth: Weak, + + bridge_common_data: Option, +} + +impl Default for VethCommonData { + fn default() -> Self { + VethCommonData { + netdevice_common: NetDeviceCommonData::default(), + device_common: DeviceCommonData::default(), + kobj_common: KObjectCommonData::default(), + peer_veth: Weak::new(), + bridge_common_data: None, + } + } +} + +impl VethInterface { + pub fn peer_veth(&self) -> Arc { + self.inner.lock().peer_veth.upgrade().unwrap() + } + + pub fn new(driver: VethDriver) -> Arc { + let iface_id = generate_iface_id(); + let name = driver.name(); + let mac = [ + 0x02, + 0x00, + 0x00, + 0x00, + (iface_id >> 8) as u8, + iface_id as u8, + ]; + let hw_addr = HardwareAddress::Ethernet(EthernetAddress(mac)); + let mut iface_config = smoltcp::iface::Config::new(hw_addr); + iface_config.random_seed = rand() as u64; + let mut iface = smoltcp::iface::Interface::new( + iface_config, + &mut driver.clone(), + 
crate::time::Instant::now().into(), + ); + iface.set_any_ip(true); + + let flags = InterfaceFlags::BROADCAST + | InterfaceFlags::MULTICAST + | InterfaceFlags::UP + | InterfaceFlags::RUNNING + | InterfaceFlags::LOWER_UP; + + let device = Arc::new(VethInterface { + name, + driver: VethDriverWarpper(UnsafeCell::new(driver.clone())), + common: IfaceCommon::new(iface_id, super::types::InterfaceType::EETHER, flags, iface), + inner: SpinLock::new(VethCommonData::default()), + locked_kobj_state: LockedKObjectState::default(), + }); + let napi_struct = NapiStruct::new(device.clone(), 10); + *device.common.napi_struct.write() = Some(napi_struct); + + driver.inner.lock().self_iface_ref = Arc::downgrade(&device); + + // log::info!("VethInterface {} created with ID {}", device.name, iface_id); + device + } + + pub fn set_peer_iface(&self, peer: &Arc) { + let mut inner = self.inner.lock(); + inner.peer_veth = Arc::downgrade(peer); + self.driver.inner.lock().set_peer_iface(peer); + } + + pub fn new_pair(name1: &str, name2: &str) -> (Arc, Arc) { + let (driver1, driver2) = VethDriver::new_pair(name1, name2); + let iface1 = VethInterface::new(driver1); + let iface2 = VethInterface::new(driver2); + + iface1.set_peer_iface(&iface2); + iface2.set_peer_iface(&iface1); + + (iface1, iface2) + } + + fn inner(&self) -> SpinLockGuard<'_, VethCommonData> { + self.inner.lock() + } + + /// # `update_ip_addrs` + /// 更新虚拟以太网设备的 IP 地址 + /// ## 参数 + /// - `cidr`: 要添加的 IP 地址和子网掩码 + /// ## 描述 + /// 该方法会将指定的 IP 地址添加到虚拟以太网设备的 IP 地址列表中。 + /// 如果添加失败(例如列表已满),则会触发 panic。 + pub fn update_ip_addrs(&self, cidr: IpCidr) { + let iface = &mut self.common.smol_iface.lock_irqsave(); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs.push(cidr).expect("Push ipCidr failed: full"); + }); + self.common.router_common_data.ip_addrs.write().push(cidr); + + // // 直接更新对端的arp_table + // self.inner.lock().peer_veth.upgrade().map(|peer| { + // peer.common + // .router_common_data + // .arp_table + // .write() + // 
.insert(cidr.address(), self.mac()) + // }); + + // log::info!("VethInterface {} updated IP address: {}", self.name, addr); + } + + /// # `add_default_route_to_peer` + /// 添加默认路由到对端虚拟以太网设备 + /// ## 参数 + /// - `peer_ip`: 对端设备的 IP 地址 + /// ## 描述 + /// 该方法会在当前虚拟以太网设备的路由表中 + /// 添加一条默认路由, + /// 指向对端虚拟以太网设备的 IP 地址。 + /// 如果添加失败,则会触发 panic。 + /// + pub fn add_default_route_to_peer(&self, peer_ip: IpAddress) { + let iface = &mut self.common.smol_iface.lock_irqsave(); + // iface.update_ip_addrs(|ip_addrs| { + // ip_addrs.push(self_cidr).expect("Push ipCidr failed: full"); + // }); + iface.routes_mut().update(|routes_map| { + routes_map + .push(smoltcp::iface::Route { + cidr: IpCidr::new(IpAddress::v4(0, 0, 0, 0), 0), + via_router: peer_ip, + preferred_until: None, + expires_at: None, + }) + .expect("Add default route to peer failed"); + }); + } + + // pub fn add_direct_route(&self, cidr: IpCidr, via_router: IpAddress) { + // let iface = &mut self.common.smol_iface.lock_irqsave(); + // iface.routes_mut().update(|routes_map| { + // routes_map + // .push(smoltcp::iface::Route { + // cidr, + // via_router, + // preferred_until: None, + // expires_at: None, + // }) + // .expect("Add direct route failed"); + // }); + // } +} + +impl KObject for VethInterface { + fn as_any_ref(&self) -> &dyn core::any::Any { + self + } + + fn set_inode(&self, inode: Option>) { + self.inner().kobj_common.kern_inode = inode; + } + + fn inode(&self) -> Option> { + self.inner().kobj_common.kern_inode.clone() + } + + fn parent(&self) -> Option> { + self.inner().kobj_common.parent.clone() + } + + fn set_parent(&self, parent: Option>) { + self.inner().kobj_common.parent = parent; + } + + fn kset(&self) -> Option> { + self.inner().kobj_common.kset.clone() + } + + fn set_kset(&self, kset: Option>) { + self.inner().kobj_common.kset = kset; + } + + fn kobj_type(&self) -> Option<&'static dyn KObjType> { + self.inner().kobj_common.kobj_type + } + + fn name(&self) -> String { + self.name.clone() + } + + fn 
set_name(&self, _name: String) {} + fn kobj_state(&self) -> RwLockReadGuard<'_, KObjectState> { + self.locked_kobj_state.read() + } + + fn kobj_state_mut(&self) -> RwLockWriteGuard<'_, KObjectState> { + self.locked_kobj_state.write() + } + + fn set_kobj_state(&self, state: KObjectState) { + *self.locked_kobj_state.write() = state; + } + + fn set_kobj_type(&self, ktype: Option<&'static dyn KObjType>) { + self.inner().kobj_common.kobj_type = ktype; + } +} + +impl device::Device for VethInterface { + fn dev_type(&self) -> DeviceType { + DeviceType::Net + } + + fn id_table(&self) -> IdTable { + IdTable::new(self.name.clone(), None) + } + + fn bus(&self) -> Option> { + self.inner().device_common.bus.clone() + } + + fn set_bus(&self, bus: Option>) { + self.inner().device_common.bus = bus; + } + + fn class(&self) -> Option> { + let mut guard = self.inner(); + let r = guard.device_common.class.clone()?.upgrade(); + if r.is_none() { + guard.device_common.class = None; + } + r + } + + fn set_class(&self, class: Option>) { + self.inner().device_common.class = class; + } + + fn driver(&self) -> Option> { + let r = self.inner().device_common.driver.clone()?.upgrade(); + if r.is_none() { + self.inner().device_common.driver = None; + } + r + } + + fn set_driver(&self, driver: Option>) { + self.inner().device_common.driver = driver; + } + + fn is_dead(&self) -> bool { + false + } + + fn can_match(&self) -> bool { + self.inner().device_common.can_match + } + + fn set_can_match(&self, can_match: bool) { + self.inner().device_common.can_match = can_match; + } + + fn state_synced(&self) -> bool { + true + } + + fn dev_parent(&self) -> Option> { + self.inner().device_common.get_parent_weak_or_clear() + } + + fn set_dev_parent(&self, parent: Option>) { + self.inner().device_common.parent = parent; + } +} + +impl Iface for VethInterface { + fn common(&self) -> &IfaceCommon { + &self.common + } + + fn iface_name(&self) -> String { + self.name.clone() + } + + fn mac(&self) -> 
EthernetAddress { + if let HardwareAddress::Ethernet(mac) = + self.common.smol_iface.lock_irqsave().hardware_addr() + { + mac + } else { + EthernetAddress([0, 0, 0, 0, 0, 0]) + } + } + + fn poll(&self) -> bool { + // log::info!("VethInterface {} polling normal", self.name); + self.common.poll(self.driver.force_get_mut()) + // self.clear_recv_buffer(); + } + + fn addr_assign_type(&self) -> u8 { + self.inner().netdevice_common.addr_assign_type + } + + fn net_device_type(&self) -> u16 { + self.inner().netdevice_common.net_device_type = 1; + self.inner().netdevice_common.net_device_type + } + + fn net_state(&self) -> NetDeivceState { + self.inner().netdevice_common.state + } + + fn set_net_state(&self, state: NetDeivceState) { + self.inner().netdevice_common.state |= state; + } + + fn operstate(&self) -> Operstate { + self.inner().netdevice_common.operstate + } + + fn set_operstate(&self, state: Operstate) { + self.inner().netdevice_common.operstate = state; + } + + fn mtu(&self) -> usize { + use smoltcp::phy::Device; + self.driver + .force_get_mut() + .capabilities() + .max_transmission_unit + } +} + +impl BridgeEnableDevice for VethInterface { + fn receive_from_bridge(&self, frame: &[u8]) { + log::info!("VethInterface {} received from bridge", self.name); + let peer = self.peer_veth(); + + // let inner = self.inner.lock_irqsave(); + + if self + .inner + .lock() + .bridge_common_data + .as_ref() + .unwrap() + .bridge_driver_ref + .upgrade() + .is_some() + { + log::info!("VethInterface {} sending data to peer", self.name); + + // let peer = self.peer_veth(); + Veth::to_peer(&peer, frame); + // self.driver + // .inner + // .lock_irqsave() + // .rx_queue + // .push_back(frame.to_vec()); + // peer.poll(); + } + log::info!("returning"); + } + + fn set_common_bridge_data(&self, port: &BridgePort) { + // log::info!("Now set bridge port data for {}", self.name); + let mut inner = self.inner.lock(); + let data = BridgeCommonData { + id: port.id, + bridge_driver_ref: 
port.bridge_driver_ref.clone(), + }; + inner.bridge_common_data = Some(data); + } + + fn common_bridge_data(&self) -> Option { + self.inner().bridge_common_data.clone() + } +} + +impl RouterEnableDevice for VethInterface { + fn route_and_send(&self, next_hop: &IpAddress, ip_packet: &[u8]) { + log::info!( + "VethInterface {} routing packet to {}", + self.iface_name(), + next_hop + ); + + // 构造以太网帧 + let dst_mac = self.peer_veth().mac(); + let src_mac = self.mac(); + + // 以太网类型为 IPv4 + let ethertype = [0x08, 0x00]; + + let mut frame = Vec::with_capacity(14 + ip_packet.len()); + frame.extend_from_slice(&dst_mac.0); + frame.extend_from_slice(&src_mac.0); + frame.extend_from_slice(ðertype); + frame.extend_from_slice(ip_packet); + + // 发送到对端 + self.driver.inner.lock().send_to_peer(&frame); + } + + fn is_my_ip(&self, ip: IpAddress) -> bool { + self.common + .ip_addrs() + .iter() + .any(|cidr| cidr.contains_addr(&ip)) + } +} + +fn veth_route_test() { + let (iface_ns1, iface_host1) = VethInterface::new_pair("veth-ns1", "veth-host1"); + let (iface_ns2, iface_host2) = VethInterface::new_pair("veth-ns2", "veth-host2"); + + let addr1 = IpAddress::v4(192, 168, 1, 1); + let cidr1 = IpCidr::new(addr1, 24); + iface_ns1.update_ip_addrs(cidr1); + + let addr2 = IpAddress::v4(192, 168, 1, 254); + let cidr2 = IpCidr::new(addr2, 24); + iface_host1.update_ip_addrs(cidr2); + + let addr3 = IpAddress::v4(192, 168, 2, 254); + let cidr3 = IpCidr::new(addr3, 24); + iface_host2.update_ip_addrs(cidr3); + + let addr4 = IpAddress::v4(192, 168, 2, 3); + let cidr4 = IpCidr::new(addr4, 24); + iface_ns2.update_ip_addrs(cidr4); + + // 添加默认路由 + iface_ns1.add_default_route_to_peer(addr2); + iface_host1.add_default_route_to_peer(addr1); + + iface_host2.add_default_route_to_peer(addr4); + iface_ns2.add_default_route_to_peer(addr3); + + let turn_on = |a: &Arc, ns: Arc| { + a.set_net_state(NetDeivceState::__LINK_STATE_START); + a.set_operstate(Operstate::IF_OPER_UP); + // 
NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); + ns.add_device(a.clone()); + a.common().set_net_namespace(ns.clone()); + register_netdevice(a.clone()).expect("register veth device failed"); + }; + + let ns1 = NetNamespace::new_empty(ProcessManager::current_user_ns()).unwrap(); + let ns2 = NetNamespace::new_empty(ProcessManager::current_user_ns()).unwrap(); + + let router_ns1 = ns1.router(); + // 任何发往 192.168.1.0/24 网络的数据包都是本地邻居,可以直接从 veth-ns1 发送。 + let dest = IpCidr::new(IpAddress::v4(192, 168, 1, 0), 24); + let route = RouteEntry::new_connected(dest, iface_ns1.clone()); + router_ns1.add_route(route); + // 任何不匹配其他路由的数据包,都应该通过 veth-ns1 接口发送给下一跳 192.168.1.254。 + let next_hop = IpAddress::v4(192, 168, 1, 254); + let route = RouteEntry::new_default(next_hop, iface_ns1.clone()); + router_ns1.add_route(route); + + let router_ns2 = ns2.router(); + // 任何发往 192.168.2.0/24 网络的数据包都是本地邻居,可以直接从 veth-ns2 发送 + let dest = IpCidr::new(IpAddress::v4(192, 168, 2, 0), 24); + let route = RouteEntry::new_connected(dest, iface_ns2.clone()); + router_ns2.add_route(route); + // 任何不匹配其他路由的数据包,都应该通过 veth-ns2 接口发送给下一跳 192.168.2.254 + let next_hop = IpAddress::v4(192, 168, 2, 254); + let route = RouteEntry::new_default(next_hop, iface_ns2.clone()); + router_ns2.add_route(route); + + let host_router = INIT_NET_NAMESPACE.router(); + // 任何发往 192.168.1.0/24 网络的数据包,都应该从 veth-host1 接口直接发送 + let dest = IpCidr::new(IpAddress::v4(192, 168, 1, 0), 24); + let route = RouteEntry::new_connected(dest, iface_host1.clone()); + host_router.add_route(route); + // 任何发往 192.168.2.0/24 网络的数据包,都应该从 veth-host2 接口直接发送 + let dest = IpCidr::new(IpAddress::v4(192, 168, 2, 0), 24); + let route = RouteEntry::new_connected(dest, iface_host2.clone()); + host_router.add_route(route); + + turn_on(&iface_ns1, INIT_NET_NAMESPACE.clone()); + turn_on(&iface_ns2, INIT_NET_NAMESPACE.clone()); + turn_on(&iface_host1, INIT_NET_NAMESPACE.clone()); + turn_on(&iface_host2, INIT_NET_NAMESPACE.clone()); + + let snat_rules = 
vec![SnatRule { + // 匹配所有来自 192.168.1.0/24 网络的流量 + source_cidr: "192.168.1.0/24".parse().unwrap(), + // 将源地址转换为 192.168.2.254(hardcode) + nat_ip: IpAddress::v4(192, 168, 2, 254), + }]; + + let dnat_rules = vec![DnatRule { + external_addr: IpAddress::v4(192, 168, 2, 1), + internal_addr: IpAddress::v4(192, 168, 2, 3), + internal_port: None, + external_port: None, + }]; + + host_router.nat_tracker().update_snat_rules(snat_rules); + host_router.nat_tracker().update_dnat_rules(dnat_rules); +} + +fn veth_epoll_test() { + let (iface1, iface2) = VethInterface::new_pair("veth1", "veth2"); + + let addr1 = IpAddress::v4(111, 111, 11, 1); + let cidr1 = IpCidr::new(addr1, 24); + iface1.update_ip_addrs(cidr1); + + let addr2 = IpAddress::v4(111, 111, 11, 2); + let cidr2 = IpCidr::new(addr2, 24); + iface2.update_ip_addrs(cidr2); + + iface1.add_default_route_to_peer(addr2); + iface2.add_default_route_to_peer(addr1); + + let turn_on = |a: &Arc, ns: Arc| { + a.set_net_state(NetDeivceState::__LINK_STATE_START); + a.set_operstate(Operstate::IF_OPER_UP); + // NET_DEVICES.write_irqsave().insert(a.nic_id(), a.clone()); + ns.add_device(a.clone()); + a.common().set_net_namespace(ns.clone()); + register_netdevice(a.clone()).expect("register veth device failed"); + }; + + turn_on(&iface1, INIT_NET_NAMESPACE.clone()); + turn_on(&iface2, INIT_NET_NAMESPACE.clone()); +} + +#[unified_init(INITCALL_DEVICE)] +pub fn veth_init() -> Result<(), SystemError> { + veth_epoll_test(); + veth_route_test(); + log::info!("Veth pair initialized."); + Ok(()) +} diff --git a/kernel/src/driver/net/virtio_net.rs b/kernel/src/driver/net/virtio_net.rs index 07ccce7d4..d5642596b 100644 --- a/kernel/src/driver/net/virtio_net.rs +++ b/kernel/src/driver/net/virtio_net.rs @@ -29,7 +29,11 @@ use crate::{ kobject::{KObjType, KObject, KObjectCommonData, KObjectState, LockedKObjectState}, kset::KSet, }, - net::register_netdevice, + net::{ + napi::{napi_schedule, NapiStruct}, + register_netdevice, + types::InterfaceFlags, + 
}, virtio::{ irq::virtio_irq_manager, sysfs::{virtio_bus, virtio_device_manager, virtio_driver_manager}, @@ -43,10 +47,11 @@ use crate::{ filesystem::{kernfs::KernFSInode, sysfs::AttributeGroup}, init::initcall::INITCALL_POSTCORE, libs::{ - rwlock::{RwLockReadGuard, RwLockWriteGuard}, + rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}, spinlock::{SpinLock, SpinLockGuard}, }, - net::{generate_iface_id, NET_DEVICES}, + net::generate_iface_id, + process::namespace::net_namespace::INIT_NET_NAMESPACE, time::Instant, }; use system_error::SystemError; @@ -68,6 +73,9 @@ pub struct VirtIONetDevice { dev_id: Arc, inner: SpinLock, locked_kobj_state: LockedKObjectState, + + // 指向对应的interface + iface_ref: RwLock>, } impl Debug for VirtIONetDevice { @@ -125,6 +133,7 @@ impl VirtIONetDevice { device_common: DeviceCommonData::default(), }), locked_kobj_state: LockedKObjectState::default(), + iface_ref: RwLock::new(Weak::new()), }); // dev.set_driver(Some(Arc::downgrade(&virtio_net_driver()) as Weak)); @@ -135,6 +144,14 @@ impl VirtIONetDevice { fn inner(&self) -> SpinLockGuard<'_, InnerVirtIONetDevice> { return self.inner.lock(); } + + pub fn set_iface(&self, iface: &Arc) { + *self.iface_ref.write() = Arc::downgrade(iface); + } + + pub fn iface(&self) -> Option> { + self.iface_ref.read().upgrade() + } } impl KObject for VirtIONetDevice { @@ -269,8 +286,22 @@ impl Device for VirtIONetDevice { impl VirtIODevice for VirtIONetDevice { fn handle_irq(&self, _irq: IrqNumber) -> Result { - // log::debug!("try to wakeup"); - super::kthread::wakeup_poll_thread(); + let Some(iface) = self.iface() else { + error!( + "VirtIONetDevice '{:?}' has no associated iface to handle irq", + self.dev_id.id() + ); + return Ok(IrqReturn::NotHandled); + }; + + let Some(napi) = iface.napi_struct() else { + log::error!("Virtio net device {} has no napi_struct", iface.name()); + return Ok(IrqReturn::NotHandled); + }; + + napi_schedule(napi); + + // self.netns.wakeup_poll_thread(); return 
Ok(IrqReturn::Handled); } @@ -390,6 +421,93 @@ pub struct VirtioInterface { locked_kobj_state: LockedKObjectState, } +// // 先手糊为virtio实现这些,后面系统要是有了其他类型网卡,这些实现就得实现成一个单独的trait +// impl VirtioInterface { +// /// 消耗token然后主动发送一个 arp 数据包 +// pub fn emit_arp(arp_repr: &ArpRepr, tx_token: VirtioNetToken) { +// let ether_repr = match arp_repr { +// ArpRepr::EthernetIpv4 { +// source_hardware_addr, +// target_hardware_addr, +// .. +// } => EthernetRepr { +// src_addr: *source_hardware_addr, +// dst_addr: *target_hardware_addr, +// ethertype: EthernetProtocol::Arp, +// }, +// _ => return, +// }; + +// tx_token.consume(ether_repr.buffer_len() + arp_repr.buffer_len(), |buffer| { +// let mut frame = EthernetFrame::new_unchecked(buffer); +// ether_repr.emit(&mut frame); + +// let mut pkt = ArpPacket::new_unchecked(frame.payload_mut()); +// arp_repr.emit(&mut pkt); +// }); +// } + +// /// 解析 arp 包并处理 +// pub fn process_arp(&self, arp_repr: &ArpRepr) -> Option { +// match arp_repr { +// ArpRepr::EthernetIpv4 { +// operation: ArpOperation::Reply, +// source_hardware_addr, +// source_protocol_addr, +// .. +// } => { +// if !source_hardware_addr.is_unicast() +// || !self +// .common() +// .smol_iface +// .lock() +// .context() +// .in_same_network(&IpAddress::Ipv4(*source_protocol_addr)) +// { +// return None; +// } + +// self.common().router_common_data.arp_table.write().insert( +// IpAddress::Ipv4(*source_protocol_addr), +// *source_hardware_addr, +// ); + +// None +// } +// ArpRepr::EthernetIpv4 { +// operation: ArpOperation::Request, +// source_hardware_addr, +// source_protocol_addr, +// target_protocol_addr, +// .. 
+// } => { +// if !source_hardware_addr.is_unicast() || !source_protocol_addr.x_is_unicast() { +// return None; +// } + +// if self +// .common() +// .smol_iface +// .lock() +// .context() +// .ipv4_addr() +// .is_none_or(|addr| addr != *target_protocol_addr) +// { +// return None; +// } +// Some(ArpRepr::EthernetIpv4 { +// operation: ArpOperation::Reply, +// source_hardware_addr: self.mac(), +// source_protocol_addr: *target_protocol_addr, +// target_hardware_addr: *source_hardware_addr, +// target_protocol_addr: *source_protocol_addr, +// }) +// } +// _ => None, +// } +// } +// } + #[derive(Debug)] struct InnerVirtIOInterface { kobj_common: KObjectCommonData, @@ -407,11 +525,22 @@ impl VirtioInterface { let iface = iface::Interface::new(iface_config, &mut device_inner, Instant::now().into()); - let result = Arc::new(VirtioInterface { + let flags = InterfaceFlags::UP + | InterfaceFlags::BROADCAST + | InterfaceFlags::RUNNING + | InterfaceFlags::MULTICAST + | InterfaceFlags::LOWER_UP; + + let iface = Arc::new(VirtioInterface { device_inner: VirtIONicDeviceInnerWrapper(UnsafeCell::new(device_inner)), locked_kobj_state: LockedKObjectState::default(), iface_name: format!("eth{}", iface_id), - iface_common: super::IfaceCommon::new(iface_id, true, iface), + iface_common: super::IfaceCommon::new( + iface_id, + crate::driver::net::types::InterfaceType::EETHER, + flags, + iface, + ), inner: SpinLock::new(InnerVirtIOInterface { kobj_common: KObjectCommonData::default(), device_common: DeviceCommonData::default(), @@ -419,7 +548,11 @@ impl VirtioInterface { }), }); - return result; + // 设置napi struct + let napi_struct = NapiStruct::new(iface.clone(), 10); + *iface.common().napi_struct.write() = Some(napi_struct); + + iface } fn inner(&self) -> SpinLockGuard<'_, InnerVirtIOInterface> { @@ -436,7 +569,10 @@ impl VirtioInterface { impl Drop for VirtioInterface { fn drop(&mut self) { // 从全局的网卡接口信息表中删除这个网卡的接口信息 - NET_DEVICES.write_irqsave().remove(&self.nic_id()); + // 
NET_DEVICES.write_irqsave().remove(&self.nic_id()); + if let Some(ns) = self.net_namespace() { + ns.remove_device(&self.nic_id()); + } } } @@ -650,7 +786,7 @@ impl Iface for VirtioInterface { return self.iface_name.clone(); } - fn poll(&self) { + fn poll(&self) -> bool { // log::debug!("VirtioInterface: poll"); self.iface_common.poll(self.device_inner.force_get_mut()) } @@ -683,6 +819,14 @@ impl Iface for VirtioInterface { fn set_operstate(&self, state: Operstate) { self.inner().netdevice_common.operstate = state; } + + fn mtu(&self) -> usize { + use smoltcp::phy::Device; + self.device_inner + .force_get_mut() + .capabilities() + .max_transmission_unit + } } impl KObject for VirtioInterface { @@ -819,10 +963,18 @@ impl VirtIODriver for VirtIONetDriver { // 在sysfs中注册iface register_netdevice(iface.clone() as Arc)?; + // 将virtio_net_device和iface关联起来 + virtio_net_device.set_iface(&iface); + // 将网卡的接口信息注册到全局的网卡接口信息表中 - NET_DEVICES - .write_irqsave() - .insert(iface.nic_id(), iface.clone()); + // NET_DEVICES + // .write_irqsave() + // .insert(iface.nic_id(), iface.clone()); + INIT_NET_NAMESPACE.add_device(iface.clone()); + iface + .iface_common + .set_net_namespace(INIT_NET_NAMESPACE.clone()); + INIT_NET_NAMESPACE.set_default_iface(iface.clone()); virtio_irq_manager() .register_device(device.clone()) diff --git a/kernel/src/filesystem/epoll/event_poll.rs b/kernel/src/filesystem/epoll/event_poll.rs index d5075242f..01dc11f7f 100644 --- a/kernel/src/filesystem/epoll/event_poll.rs +++ b/kernel/src/filesystem/epoll/event_poll.rs @@ -663,6 +663,7 @@ impl EventPoll { epoll_guard.ep_add_ready(epitem.clone()); if epoll_guard.ep_has_waiter() { + // log::info!("wakeup epoll waiters"); if ep_events.contains(EPollEventType::EPOLLEXCLUSIVE) && !pollflags.contains(EPollEventType::POLLFREE) { diff --git a/kernel/src/filesystem/epoll/mod.rs b/kernel/src/filesystem/epoll/mod.rs index c581c8475..2ee4c5b89 100644 --- a/kernel/src/filesystem/epoll/mod.rs +++ 
b/kernel/src/filesystem/epoll/mod.rs @@ -195,3 +195,9 @@ bitflags! { const EPOLL_LISTEN_CAN_ACCEPT = Self::EPOLLIN.bits | Self::EPOLLRDNORM.bits; } } + +impl EPollEventType { + pub fn filter(&self, events: &EPollEventType) -> bool { + self.intersects(*events) + } +} diff --git a/kernel/src/filesystem/vfs/mod.rs b/kernel/src/filesystem/vfs/mod.rs index b22d8d14e..f6ff006d9 100644 --- a/kernel/src/filesystem/vfs/mod.rs +++ b/kernel/src/filesystem/vfs/mod.rs @@ -26,6 +26,7 @@ use crate::{ spinlock::{SpinLock, SpinLockGuard}, }, mm::{fault::PageFaultMessage, VmFaultReason}, + net::socket::Socket, process::ProcessManager, time::PosixTimeSpec, }; @@ -687,6 +688,19 @@ pub trait IndexNode: Any + Sync + Send + Debug + CastFromSync { ); return Err(SystemError::ENOSYS); } + + /// # 将当前Inode转换为 Socket 引用 + /// + /// # 返回值 + /// - Some(&dyn Socket): 当前Inode是Socket类型,返回其引用 + /// - None: 当前Inode不是Socket类型 + /// + /// # 注意 + /// 这个方法已经为dyn Socket实现, + /// 所以如果可以确定当前`dyn IndexNode`是`dyn Socket`类型,则可以直接调用此方法进行转换 + fn as_socket(&self) -> Option<&dyn Socket> { + None + } } impl DowncastArc for dyn IndexNode { diff --git a/kernel/src/net/mod.rs b/kernel/src/net/mod.rs index 805761cc6..9b68003af 100644 --- a/kernel/src/net/mod.rs +++ b/kernel/src/net/mod.rs @@ -4,21 +4,14 @@ //! 敬请注意。 use core::sync::atomic::AtomicUsize; -use alloc::{collections::BTreeMap, sync::Arc}; - -use crate::{driver::net::Iface, libs::rwlock::RwLock}; +use crate::driver::net::Iface; pub mod net_core; pub mod posix; +pub mod routing; pub mod socket; pub mod syscall; -lazy_static! 
{ - /// # 所有网络接口的列表 - /// 这个列表在中断上下文会使用到,因此需要irqsave - pub static ref NET_DEVICES: RwLock>> = RwLock::new(BTreeMap::new()); -} - /// 生成网络接口的id (全局自增) pub fn generate_iface_id() -> usize { static IFACE_ID: AtomicUsize = AtomicUsize::new(0); diff --git a/kernel/src/net/net_core.rs b/kernel/src/net/net_core.rs index 88e49653e..4b99deb34 100644 --- a/kernel/src/net/net_core.rs +++ b/kernel/src/net/net_core.rs @@ -3,7 +3,7 @@ use system_error::SystemError; use crate::{ driver::net::Operstate, - net::NET_DEVICES, + process::namespace::net_namespace::INIT_NET_NAMESPACE, time::{sleep::nanosleep, PosixTimeSpec}, }; @@ -12,7 +12,8 @@ pub fn net_init() -> Result<(), SystemError> { } fn dhcp_query() -> Result<(), SystemError> { - let binding = NET_DEVICES.write_irqsave(); + // let binding = NET_DEVICES.write_irqsave(); + let binding = INIT_NET_NAMESPACE.device_list_mut(); let net_face = binding .iter() @@ -43,7 +44,7 @@ fn dhcp_query() -> Result<(), SystemError> { sockets().remove(dhcp_handle); }); - const DHCP_TRY_ROUND: u8 = 100; + const DHCP_TRY_ROUND: u8 = 10; for i in 0..DHCP_TRY_ROUND { log::debug!("DHCP try round: {}", i); net_face.poll(); diff --git a/kernel/src/net/posix.rs b/kernel/src/net/posix.rs index 3bb72142e..b1892f5b5 100644 --- a/kernel/src/net/posix.rs +++ b/kernel/src/net/posix.rs @@ -36,6 +36,7 @@ impl PosixArgsSocketType { } use super::socket::{endpoint::Endpoint, AddressFamily}; +use crate::net::socket::netlink::addr::{multicast::GroupIdSet, NetlinkSocketAddr}; use crate::net::socket::unix::UnixEndpoint; use alloc::string::ToString; use core::ffi::CStr; @@ -103,8 +104,8 @@ impl From for SockAddr { smoltcp::wire::IpAddress::Ipv4(ipv4_addr) => Self { addr_in: SockAddrIn { sin_family: AddressFamily::INet as u16, - sin_port: value.port, - sin_addr: ipv4_addr.to_bits(), + sin_port: value.port.to_be(), + sin_addr: ipv4_addr.to_bits().to_be(), sin_zero: Default::default(), }, }, @@ -146,12 +147,26 @@ impl From for SockAddr { } } +impl From for SockAddr { + fn 
from(value: NetlinkSocketAddr) -> Self { + SockAddr { + addr_nl: SockAddrNl { + nl_family: AddressFamily::Netlink, + nl_pad: 0, + nl_pid: value.port(), + nl_groups: value.groups().as_u32(), + }, + } + } +} + impl From for SockAddr { fn from(value: Endpoint) -> Self { match value { Endpoint::LinkLayer(_link_layer_endpoint) => todo!(), Endpoint::Ip(endpoint) => Self::from(endpoint), Endpoint::Unix(unix_endpoint) => Self::from(unix_endpoint), + Endpoint::Netlink(netlink_addr) => Self::from(netlink_addr), } } } @@ -252,6 +267,21 @@ impl SockAddr { return Ok(Endpoint::Unix(UnixEndpoint::File(path.to_string()))); } + AddressFamily::Netlink => { + if len < addr.len()? { + log::error!("len < addr.len() for Netlink"); + return Err(SystemError::EINVAL); + } + + let addr_nl: SockAddrNl = addr.addr_nl; + let nl_pid = addr_nl.nl_pid; + let nl_groups = addr_nl.nl_groups; + + Ok(Endpoint::Netlink(NetlinkSocketAddr::new( + nl_pid, + GroupIdSet::new(nl_groups), + ))) + } _ => { log::warn!("not support address family {:?}", addr.family); return Err(SystemError::EINVAL); diff --git a/kernel/src/net/routing/mod.rs b/kernel/src/net/routing/mod.rs new file mode 100644 index 000000000..5c46e0e0c --- /dev/null +++ b/kernel/src/net/routing/mod.rs @@ -0,0 +1,492 @@ +use crate::driver::net::Iface; +use crate::libs::rwlock::RwLock; +use crate::net::routing::nat::ConnTracker; +use crate::net::routing::nat::DnatPolicy; +use crate::net::routing::nat::FiveTuple; +use crate::net::routing::nat::NatPktStatus; +use crate::net::routing::nat::NatPolicy; +use crate::net::routing::nat::SnatPolicy; +use crate::process::namespace::net_namespace::NetNamespace; +use crate::process::namespace::net_namespace::INIT_NET_NAMESPACE; +use alloc::string::{String, ToString}; +use alloc::sync::{Arc, Weak}; +use alloc::vec::Vec; +use core::net::Ipv4Addr; +use smoltcp::wire::{EthernetFrame, IpAddress, IpCidr, Ipv4Packet}; +use system_error::SystemError; + +mod nat; + +pub use nat::{DnatRule, SnatRule}; + 
+#[derive(Debug, Clone)] +pub struct RouteEntry { + /// 目标网络 + pub destination: IpCidr, + /// 下一跳地址(如果是直连网络则为None) + pub next_hop: Option, + /// 出接口 + pub interface: Weak, + /// 路由优先级(数值越小优先级越高) + pub metric: u32, + /// 路由类型 + pub route_type: RouteType, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum RouteType { + /// 直连路由 + Connected, + /// 静态路由 + Static, + /// 默认路由 + Default, +} + +impl RouteEntry { + pub fn new_connected(destination: IpCidr, interface: Arc) -> Self { + RouteEntry { + destination, + next_hop: None, + interface: Arc::downgrade(&interface), + metric: 0, + route_type: RouteType::Connected, + } + } + + pub fn new_static( + destination: IpCidr, + next_hop: IpAddress, + interface: Arc, + metric: u32, + ) -> Self { + RouteEntry { + destination, + next_hop: Some(next_hop), + interface: Arc::downgrade(&interface), + metric, + route_type: RouteType::Static, + } + } + + pub fn new_default(next_hop: IpAddress, interface: Arc) -> Self { + RouteEntry { + destination: IpCidr::new(IpAddress::v4(0, 0, 0, 0), 0), + next_hop: Some(next_hop), + interface: Arc::downgrade(&interface), + metric: 100, + route_type: RouteType::Default, + } + } +} + +#[derive(Debug, Default)] +pub struct RouteTable { + pub entries: Vec, +} + +/// 路由决策结果 +#[derive(Debug)] +pub struct RouteDecision { + /// 出接口 + pub interface: Arc, + /// 下一跳地址(先写在这里 + pub next_hop: IpAddress, +} + +#[derive(Debug)] +pub struct Router { + name: String, + /// 路由表 //todo 后面再优化LC-trie,现在先简单用一个Vec + route_table: RwLock, + pub(self) nat_tracker: Arc, + pub ns: RwLock>, +} + +impl Router { + pub fn new(name: String) -> Arc { + Arc::new(Self { + name: name.clone(), + route_table: RwLock::new(RouteTable::default()), + nat_tracker: Arc::new(ConnTracker::default()), + ns: RwLock::new(Weak::default()), + }) + } + + /// 创建一个空的Router实例,主要用于初始化网络命名空间时使用 + /// 注意: 这个Router实例不会启动轮询线程 + pub fn new_empty() -> Arc { + Arc::new(Self { + name: "empty_router".to_string(), + route_table: RwLock::new(RouteTable::default()), + 
ns: RwLock::new(Weak::default()), + nat_tracker: Arc::new(ConnTracker::default()), + }) + } + + pub fn add_route(&self, route: RouteEntry) { + let mut guard = self.route_table.write(); + let entries = &mut guard.entries; + let pos = entries + .iter() + .position(|r| r.metric > route.metric) + .unwrap_or(entries.len()); + + entries.insert(pos, route); + log::info!("Router {}: Added route to routing table", self.name); + } + + pub fn remove_route(&self, destination: IpCidr) { + self.route_table + .write() + .entries + .retain(|route| route.destination != destination); + } + + pub fn lookup_route(&self, dest_ip: IpAddress) -> Option { + let guard = self.route_table.read(); + // 按最长前缀匹配原则查找路由 + let best = guard + .entries + .iter() + .filter(|route| { + route.interface.strong_count() > 0 && route.destination.contains_addr(&dest_ip) + }) + .max_by_key(|route| route.destination.prefix_len()); + + if let Some(entry) = best { + if let Some(interface) = entry.interface.upgrade() { + let next_hop = entry.next_hop.unwrap_or(dest_ip); + return Some(RouteDecision { + interface, + next_hop, + }); + } + } + + None + } + + /// 清理无效的路由表项(接口已经不存在的) + pub fn cleanup_routes(&mut self) { + self.route_table + .write() + .entries + .retain(|route| route.interface.strong_count() > 0); + } + + pub fn nat_tracker(&self) -> Arc { + self.nat_tracker.clone() + } +} + +/// 获取初始化网络命名空间下的路由表 +pub fn init_netns_router() -> Arc { + INIT_NET_NAMESPACE.router().clone() +} + +/// 可供路由设备应该实现的 trait +pub trait RouterEnableDevice: Iface { + /// # 网卡处理可路由的包 + /// ## 参数 + /// - `packet`: 需要处理的以太网帧 + /// ## 返回值 + /// - `Ok(())`: 通过路由处理成功 + /// - `Err(None)`: 忽略非IPv4包或没有路由到达的包,告诉外界没有经过处理,应该交由网卡进行默认处理 + /// - `Err(Some(SystemError))`: 处理失败,可能是包格式错误或其他系统错误 + fn handle_routable_packet( + &self, + ether_frame: &EthernetFrame<&[u8]>, + ) -> Result<(), Option> { + match ether_frame.ethertype() { + smoltcp::wire::EthernetProtocol::Ipv4 => { + // 获取IPv4包的可变引用 + let mut payload_mut = ether_frame.payload().to_vec(); + 
let mut ipv4_packet_mut = + Ipv4Packet::new_checked(&mut payload_mut).map_err(|e| { + log::warn!("Invalid IPv4 packet: {:?}", e); + Some(SystemError::EINVAL) + })?; + + let maybe_tuple = FiveTuple::extract_from_ipv4_packet( + &Ipv4Packet::new_checked(ether_frame.payload()).unwrap(), + ); + + // === PRE-ROUTING HOOK === + + let pkt_status = self.pre_routing_hook(&maybe_tuple, &mut ipv4_packet_mut); + ipv4_packet_mut.fill_checksum(); + + // === PRE-ROUTING HOOK END === + + let dst_ip = ipv4_packet_mut.dst_addr(); + + // 检查TTL + if ipv4_packet_mut.hop_limit() <= 1 { + log::warn!("TTL exceeded for packet to {}", dst_ip); + return Err(Some(SystemError::EINVAL)); + } + + // 检查是否是发给自己的包(目标IP是否是自己的IP) + if self.is_my_ip(dst_ip.into()) { + // todo 按照linux的逻辑,只要包的目标ip在当前网络命名空间里面,就直接进入本地协议栈处理 + // todo 但是我们的操作系统中每个接口都是独立的,并没有统一处理和分发(socket),所有这里必须将包放到对应iface的接收队列里面 + // 交给本地协议栈处理 + // log::info!("Packet destined for local interface {}", self.iface_name()); + return Err(None); + } + + // 查询当前网络命名空间下的路由表 + let router = self.netns_router(); + + let decision = match router.lookup_route(dst_ip.into()) { + Some(d) => d, + None => { + log::warn!("No route to {}", dst_ip); + return Err(None); + } + }; + + drop(router); + + // === POST-ROUTING HOOK === + + let decision_src_ip = decision.interface.common().ipv4_addr().unwrap(); + self.post_routing_hook( + &maybe_tuple, + &decision_src_ip, + &mut ipv4_packet_mut, + &pkt_status, + ); + ipv4_packet_mut.fill_checksum(); + + // === POST-ROUTING HOOK END === + + // 检查是否是从同一个接口进来又要从同一个接口出去(避免回路) + if self.iface_name() == decision.interface.iface_name() { + log::info!( + "Ignoring packet loop from {} to {}", + self.iface_name(), + dst_ip + ); + return Err(None); + } + + // 创建修改后的IP包(递减TTL) + sub_ttl_ipv4(&mut ipv4_packet_mut); + ipv4_packet_mut.fill_checksum(); + + // 交给出接口进行发送 + let next_hop = &decision.next_hop; + decision + .interface + .route_and_send(next_hop, ipv4_packet_mut.as_ref()); + + log::info!("Routed packet from {} to {} ", 
self.iface_name(), dst_ip); + Ok(()) + } + smoltcp::wire::EthernetProtocol::Arp => { + // 忽略ARP包 + // log::info!( + // "Ignoring non-IPv4 packet on interface {}", + // self.iface_name() + // ); + Err(None) + } + smoltcp::wire::EthernetProtocol::Ipv6 => { + log::warn!("IPv6 is not supported yet, ignoring packet"); + Err(None) + } + _ => { + log::warn!( + "Unknown ethertype {:?}, ignoring packet", + ether_frame.ethertype() + ); + Err(None) + } + } + } + + fn pre_routing_hook( + &self, + tuple: &Option, + ipv4_packet_mut: &mut Ipv4Packet<&mut Vec>, + ) -> NatPktStatus { + let Some(tuple) = tuple else { + return NatPktStatus::Untouched; + }; + + let tracker = self.netns_router().nat_tracker(); + + if let Some((new_dst_ip, new_dst_port)) = tracker.snat.lock().process_return_traffic(tuple) + { + log::info!( + "Reverse SNAT: Translating {}:{} to {}:{}", + tuple.src_addr, + tuple.src_port, + new_dst_ip, + new_dst_port + ); + + SnatPolicy::update_dst( + tuple.src_addr, + new_dst_ip, + new_dst_port, + tuple.protocol, + ipv4_packet_mut, + ); + + let new_tuple = FiveTuple { + dst_addr: new_dst_ip, + dst_port: new_dst_port, + src_addr: tuple.src_addr, + src_port: tuple.src_port, + protocol: tuple.protocol, + }; + + return NatPktStatus::ReverseSnat(new_tuple); + } + + let mut dnat_guard = tracker.dnat.lock(); + if let Some((new_dst_ip, new_dst_port)) = dnat_guard.process_new_connection(tuple) { + log::info!( + "DNAT: Translating {}:{} to {}:{}", + tuple.dst_addr, + tuple.dst_port, + new_dst_ip, + new_dst_port + ); + + DnatPolicy::update_dst( + tuple.src_addr, + new_dst_ip, + new_dst_port, + tuple.protocol, + ipv4_packet_mut, + ); + + let new_tuple = FiveTuple { + dst_addr: new_dst_ip, + dst_port: new_dst_port, + src_addr: tuple.src_addr, + src_port: tuple.src_port, + protocol: tuple.protocol, + }; + + return NatPktStatus::NewDnat(new_tuple); + } + + return NatPktStatus::Untouched; + } + + fn post_routing_hook( + &self, + tuple: &Option, + _decision_src_ip: &Ipv4Addr, + 
ipv4_packet_mut: &mut Ipv4Packet<&mut Vec>, + pkt_status: &NatPktStatus, + ) { + let tuple = match pkt_status { + NatPktStatus::ReverseSnat(t) => t, + NatPktStatus::NewDnat(t) => t, + NatPktStatus::Untouched => { + let Some(tuple) = tuple else { + return; + }; + tuple + } + }; + + let tracker = self.netns_router().nat_tracker(); + + if let Some((new_src_ip, new_src_port)) = tracker.dnat.lock().process_return_traffic(tuple) + { + log::info!( + "Reverse DNAT: Translating src {}:{} -> {}:{}", + tuple.src_addr, + tuple.src_port, + new_src_ip, + new_src_port + ); + + DnatPolicy::update_src( + tuple.dst_addr, + new_src_ip, + new_src_port, + tuple.protocol, + ipv4_packet_mut, + ); + + return; + } + + let mut snat_guard = tracker.snat.lock(); + if let Some((new_src_ip, new_src_port)) = snat_guard.process_new_connection(tuple) { + log::info!( + "SNAT: Translating {}:{} -> {}:{}", + tuple.src_addr, + tuple.src_port, + new_src_ip, + new_src_port + ); + + //TODO 应该加一个判断snat,可以支持直接改成出口接口的ip + // // 修改源IP地址 + // let new_src_ip: IpAddress = if let IpAddress::Ipv4(new_src_ip) = new_src_ip { + // new_src_ip.into() + // } else { + // (*decision_src_ip).into() + // }; + + SnatPolicy::update_src( + tuple.dst_addr, + new_src_ip, + new_src_port, + tuple.protocol, + ipv4_packet_mut, + ); + + return; + } + } + + /// 路由器决定通过此接口发送包时调用此方法 + /// 同Linux的ndo_start_xmit() + /// + /// todo 在这里查询arp_table,找到目标IP对应的mac地址然后拼接,如果找不到的话就需要主动发送arp请求去查询mac地址了,手伸不到smoltcp内部:( + /// 后续需要将arp查询的逻辑从smoltcp中抽离出来 + fn route_and_send(&self, next_hop: &IpAddress, ip_packet: &[u8]); + + /// 检查IP地址是否是当前接口的IP + /// todo 这里实现有误,不应该判断是否当前接口的IP,而是应该判断是否是当前网络命名空间的IP,然脏 + fn is_my_ip(&self, ip: IpAddress) -> bool; + + fn netns_router(&self) -> Arc { + self.net_namespace() + .map_or_else(init_netns_router, |ns| ns.router()) + } +} + +fn sub_ttl_ipv4(ipv4_packet: &mut Ipv4Packet<&mut Vec>) { + let new_ttl = ipv4_packet.hop_limit().saturating_sub(1); + ipv4_packet.set_hop_limit(new_ttl); +} + +/// # 
每一个`RouterEnableDevice`应该有的公共数据,包含 +/// - 当前接口的arp_table,记录邻居(//todo:将网卡的发送以及处理逻辑从smoltcp中移动出来,目前只是简单为veth实现这个,因为可以直接查到对端的mac地址) +#[derive(Debug)] +pub struct RouterEnableDeviceCommon { + /// 当前接口的邻居缓存 + // pub arp_table: RwLock>, + /// 当前接口的IP地址列表(因为如果直接通过smoltcp获取ip的话可能导致死锁,因此则这里维护一份) + pub ip_addrs: RwLock>, +} + +impl Default for RouterEnableDeviceCommon { + fn default() -> Self { + Self { + // arp_table: RwLock::new(BTreeMap::new()), + ip_addrs: RwLock::new(Vec::new()), + } + } +} diff --git a/kernel/src/net/routing/nat.rs b/kernel/src/net/routing/nat.rs new file mode 100644 index 000000000..1aa898978 --- /dev/null +++ b/kernel/src/net/routing/nat.rs @@ -0,0 +1,413 @@ +use core::marker::PhantomData; + +use crate::libs::spinlock::SpinLock; +use crate::time::Duration; +use crate::time::Instant; +use alloc::fmt::Debug; +use alloc::vec::Vec; +use hashbrown::HashMap; +use smoltcp::wire::{IpAddress, IpCidr, Ipv4Packet}; + +pub(super) trait NatMapping: Debug + Clone + Copy { + fn last_seen(&self) -> Instant; + fn update_last_seen(&mut self, time: Instant); +} + +impl NatMapping for SnatMapping { + fn last_seen(&self) -> Instant { + self.last_seen + } + fn update_last_seen(&mut self, now: Instant) { + self.last_seen = now; + } +} + +impl NatMapping for DnatMapping { + fn last_seen(&self) -> Instant { + self.last_seen + } + fn update_last_seen(&mut self, now: Instant) { + self.last_seen = now; + } +} + +pub(super) trait NatPolicy { + type Rule: Debug + Clone; + type Mapping: NatMapping + Send + Sync; + + fn translate(rule: &Self::Rule, original: &FiveTuple) -> (FiveTuple, Self::Mapping); + fn find_matching_rule(rules: &[Self::Rule], tuple: &FiveTuple) -> Option; + fn get_translation_for_return_traffic(mapping: &Self::Mapping) -> (IpAddress, u16); + + fn update_src( + dst_ip: IpAddress, + new_src_ip: IpAddress, + new_src_port: u16, + protocol: Protocol, + ipv4_packet_mut: &mut Ipv4Packet<&mut Vec>, + ) { + // 修改源IP地址 + let IpAddress::Ipv4(new_src_ip_v4) = new_src_ip 
else { + return; + }; + ipv4_packet_mut.set_src_addr(new_src_ip_v4); + + let payload_mut = ipv4_packet_mut.payload_mut(); + match protocol { + Protocol::Tcp => { + let mut tcp_packet = smoltcp::wire::TcpPacket::new_checked(payload_mut).unwrap(); + tcp_packet.set_src_port(new_src_port); + // 重新计算TCP校验和 + tcp_packet.fill_checksum(&new_src_ip, &dst_ip); + } + Protocol::Udp => { + let mut udp_packet = smoltcp::wire::UdpPacket::new_checked(payload_mut).unwrap(); + udp_packet.set_src_port(new_src_port); + // 重新计算UDP校验和 + udp_packet.fill_checksum(&new_src_ip, &dst_ip); + } + _ => {} + } + } + + fn update_dst( + src_ip: IpAddress, + new_dst_ip: IpAddress, + new_dst_port: u16, + protocol: Protocol, + ipv4_packet_mut: &mut Ipv4Packet<&mut Vec>, + ) { + let IpAddress::Ipv4(new_dst_ip_v4) = new_dst_ip else { + return; + }; + ipv4_packet_mut.set_dst_addr(new_dst_ip_v4); + + let payload_mut = ipv4_packet_mut.payload_mut(); + match protocol { + Protocol::Tcp => { + let mut tcp_packet = smoltcp::wire::TcpPacket::new_checked(payload_mut).unwrap(); + tcp_packet.set_dst_port(new_dst_port); + // 重新计算TCP校验和 + tcp_packet.fill_checksum(&src_ip, &new_dst_ip); + } + Protocol::Udp => { + let mut udp_packet = smoltcp::wire::UdpPacket::new_checked(payload_mut).unwrap(); + udp_packet.set_dst_port(new_dst_port); + // 重新计算UDP校验和 + udp_packet.fill_checksum(&src_ip, &new_dst_ip); + } + _ => {} + } + } +} + +#[derive(Debug)] +pub(super) struct NatTracker { + rules: Vec, + mappings: HashMap, + reverse_mappings: HashMap, + policy_marker: PhantomData

, +} + +impl Default for NatTracker

{ + fn default() -> Self { + Self { + rules: Vec::new(), + mappings: HashMap::new(), + reverse_mappings: HashMap::new(), + policy_marker: PhantomData, + } + } +} + +impl NatTracker

{ + pub fn update_rules(&mut self, rules: Vec) { + self.rules = rules; + } + + pub fn cleanup_expired(&mut self, now: Instant) { + // 收集需要移除的键 + let expired_keys: Vec = self + .mappings + .iter() + .filter(|(_, mapping)| { + now.duration_since(mapping.last_seen()).unwrap() > Duration::from_secs(300) + }) + .map(|(key, _)| *key) + .collect(); + + for key in expired_keys { + self.mappings.remove(&key); + // 注意:反向映射表的清理比较复杂,因为多个主映射可能共享一个反向键(如果实现端口复用) + // 或者需要遍历 reverse_mappings 找到 value 为 key 的条目并删除。 + // 为简单起见,这里只清理主表。更健壮的实现需要双向链接或引用计数。 + log::info!("Cleaned up expired connection for key: {:?}", key); + } + } + + pub fn insert_rule(&mut self, rule: P::Rule) { + self.rules.push(rule); + } + + pub fn process_new_connection(&mut self, tuple: &FiveTuple) -> Option<(IpAddress, u16)> { + // let rules = self.rules.lock(); + let matching_rule = P::find_matching_rule(&self.rules, tuple)?; + + let (translated_tuple, new_mapping) = P::translate(&matching_rule, tuple); + + self.mappings.insert(*tuple, new_mapping); + self.reverse_mappings + .insert(translated_tuple.reverse(), *tuple); + + // log::info!( + // "Created new NAT mapping. Original: {:?}, Translated: {:?}", + // tuple, + // translated_tuple + // ); + + // 返回转换后的地址和端口信息,用于修改数据包 + // 注意:这里需要区分是修改源地址还是目的地址,取决于调用者 (SNAT vs DNAT) + // SNAT返回新的src_ip/port, DNAT返回新的dst_ip/port. 
+ // `translated_tuple` 包含了所有信息,我们返回对应的部分。 + if translated_tuple.src_addr != tuple.src_addr { + Some((translated_tuple.src_addr, translated_tuple.src_port)) + } else { + Some((translated_tuple.dst_addr, translated_tuple.dst_port)) + } + } + + /// 处理返回流量 + pub fn process_return_traffic(&mut self, tuple: &FiveTuple) -> Option<(IpAddress, u16)> { + if let Some(original_key) = self.reverse_mappings.get(tuple) { + if let Some(mapping) = self.mappings.get_mut(original_key) { + mapping.update_last_seen(Instant::now()); + return Some(P::get_translation_for_return_traffic(mapping)); + } + } + None + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SnatPolicy; + +impl NatPolicy for SnatPolicy { + type Rule = SnatRule; + type Mapping = SnatMapping; + + fn find_matching_rule(rules: &[Self::Rule], tuple: &FiveTuple) -> Option { + rules + .iter() + .find(|rule| rule.source_cidr.contains_addr(&tuple.src_addr)) + .cloned() + } + + fn translate(rule: &Self::Rule, original_tuple: &FiveTuple) -> (FiveTuple, Self::Mapping) { + let translated_tuple = FiveTuple { + src_addr: rule.nat_ip, + src_port: original_tuple.src_port, // 简化端口处理 + ..*original_tuple + }; + + let mapping = SnatMapping { + original: *original_tuple, + translated: translated_tuple, + last_seen: Instant::now(), + }; + + (translated_tuple, mapping) + } + + fn get_translation_for_return_traffic(mapping: &Self::Mapping) -> (IpAddress, u16) { + // 返回流量需要修改目的地址为原始客户端地址 + (mapping.original.src_addr, mapping.original.src_port) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct DnatPolicy; + +impl NatPolicy for DnatPolicy { + type Rule = DnatRule; + type Mapping = DnatMapping; + + fn find_matching_rule(rules: &[Self::Rule], tuple: &FiveTuple) -> Option { + rules + .iter() + .find(|rule| { + if tuple.dst_addr != rule.external_addr { + return false; + } + + match rule.external_port { + Some(port) => tuple.dst_port == port, + None => true, + } + }) + .cloned() + } + + fn translate(rule: &Self::Rule, original: &FiveTuple) -> 
(FiveTuple, Self::Mapping) { + let new_internal_port = match rule.internal_port { + Some(port) => port, + None => original.dst_port, + }; + + let translated_tuple = FiveTuple { + dst_addr: rule.internal_addr, + dst_port: new_internal_port, + ..*original + }; + + let mapping = DnatMapping { + from_client: *original, + to_server: translated_tuple, + last_seen: Instant::now(), + }; + + (translated_tuple, mapping) + } + + fn get_translation_for_return_traffic(mapping: &Self::Mapping) -> (IpAddress, u16) { + // 返回流量需要修改源地址为原始客户端地址 + (mapping.from_client.dst_addr, mapping.from_client.dst_port) + } +} + +#[derive(Debug)] +pub struct ConnTracker { + pub(super) snat: SpinLock>, + pub(super) dnat: SpinLock>, +} + +impl ConnTracker { + pub fn cleanup_expired(&self, now: Instant) { + self.snat.lock().cleanup_expired(now); + self.dnat.lock().cleanup_expired(now); + } + + pub fn update_snat_rules(&self, rules: Vec) { + self.snat.lock().update_rules(rules); + } + + pub fn update_dnat_rules(&self, rules: Vec) { + self.dnat.lock().update_rules(rules); + } +} + +impl Default for ConnTracker { + fn default() -> Self { + Self { + snat: SpinLock::new(NatTracker::::default()), + dnat: SpinLock::new(NatTracker::::default()), + } + } +} + +#[derive(Debug, PartialEq, Eq)] +pub enum NatPktStatus { + Untouched, + ReverseSnat(FiveTuple), + NewDnat(FiveTuple), +} + +// SNAT 规则:匹配来自某个源地址段的流量,并将其转换为指定的公网IP +#[derive(Debug, Clone)] +pub struct SnatRule { + pub source_cidr: IpCidr, + pub nat_ip: IpAddress, +} + +/// SNAT 映射:记录一个连接的原始元组和转换后的元组 +#[derive(Debug, Clone, Copy)] +pub struct SnatMapping { + pub original: FiveTuple, + pub translated: FiveTuple, + pub last_seen: Instant, +} + +#[derive(Debug, Clone)] +pub struct DnatRule { + pub external_addr: IpAddress, + pub external_port: Option, + pub internal_addr: IpAddress, + pub internal_port: Option, +} + +#[derive(Debug, Clone, Copy)] +pub struct DnatMapping { + // The original tuple from the external client's perspective + pub from_client: 
FiveTuple, + // The tuple after DNAT, as seen by the internal server + pub to_server: FiveTuple, + pub last_seen: Instant, +} + +/// 五元组结构体,用于唯一标识一个网络连接 +#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] +pub struct FiveTuple { + pub src_addr: IpAddress, + pub dst_addr: IpAddress, + pub src_port: u16, + pub dst_port: u16, + pub protocol: Protocol, +} + +impl FiveTuple { + pub fn extract_from_ipv4_packet(packet: &Ipv4Packet<&[u8]>) -> Option { + let src_addr = packet.src_addr().into(); + let dst_addr = packet.dst_addr().into(); + let protocol = Protocol::from_bits_truncate(packet.next_header().into()); + + match protocol { + Protocol::Tcp => { + let tcp_packet = smoltcp::wire::TcpPacket::new_checked(packet.payload()).ok()?; + Some(FiveTuple { + protocol, + src_addr, + src_port: tcp_packet.src_port(), + dst_addr, + dst_port: tcp_packet.dst_port(), + }) + } + Protocol::Udp => { + let udp_packet = smoltcp::wire::UdpPacket::new_checked(packet.payload()).ok()?; + Some(FiveTuple { + protocol, + src_addr, + src_port: udp_packet.src_port(), + dst_addr, + dst_port: udp_packet.dst_port(), + }) + } + _ => None, + } + } + + pub fn reverse(&self) -> Self { + Self { + src_addr: self.dst_addr, + dst_addr: self.src_addr, + src_port: self.dst_port, + dst_port: self.src_port, + protocol: self.protocol, + } + } +} + +bitflags! 
{ + pub struct Protocol: u8 { + const HopByHop = 0x00; + const Icmp = 0x01; + const Igmp = 0x02; + const Tcp = 0x06; + const Udp = 0x11; + const Ipv6Route = 0x2b; + const Ipv6Frag = 0x2c; + const IpSecEsp = 0x32; + const IpSecAh = 0x33; + const Icmpv6 = 0x3a; + const Ipv6NoNxt = 0x3b; + const Ipv6Opts = 0x3c; + } +} diff --git a/kernel/src/net/socket/base.rs b/kernel/src/net/socket/base.rs index 105530182..4620ed2ff 100644 --- a/kernel/src/net/socket/base.rs +++ b/kernel/src/net/socket/base.rs @@ -138,11 +138,4 @@ pub trait Socket: PollableInode + IndexNode { fn write(&self, buffer: &[u8]) -> Result { self.send(buffer, PMSG::empty()) } - - fn into_socket(this: Arc) -> Option> - where - Self: Sized, - { - Some(this) - } } diff --git a/kernel/src/net/socket/endpoint.rs b/kernel/src/net/socket/endpoint.rs index b592008e1..a04151ee9 100644 --- a/kernel/src/net/socket/endpoint.rs +++ b/kernel/src/net/socket/endpoint.rs @@ -1,6 +1,9 @@ use crate::{ mm::{verify_area, VirtAddr}, - net::{posix::SockAddr, socket::unix::UnixEndpoint}, + net::{ + posix::SockAddr, + socket::{netlink::addr::NetlinkSocketAddr, unix::UnixEndpoint}, + }, }; pub use smoltcp::wire::IpEndpoint; @@ -11,7 +14,10 @@ pub enum Endpoint { LinkLayer(LinkLayerEndpoint), /// 网络层端点 Ip(IpEndpoint), + /// Unix域套接字端点 Unix(UnixEndpoint), + /// Netlink端点 + Netlink(NetlinkSocketAddr), } /// @brief 链路层端点 diff --git a/kernel/src/net/socket/inet/common/mod.rs b/kernel/src/net/socket/inet/common/mod.rs index 8b60c0718..41e846c0c 100644 --- a/kernel/src/net/socket/inet/common/mod.rs +++ b/kernel/src/net/socket/inet/common/mod.rs @@ -1,4 +1,4 @@ -use crate::net::{Iface, NET_DEVICES}; +use crate::{net::Iface, process::namespace::net_namespace::NetNamespace}; use alloc::sync::Arc; pub mod port; @@ -24,6 +24,7 @@ pub enum Types { pub struct BoundInner { handle: smoltcp::iface::SocketHandle, iface: Arc, + netns: Arc, // inner: Vec<(smoltcp::iface::SocketHandle, Arc)> // address: smoltcp::wire::IpAddress, } @@ -35,30 +36,42 @@ 
impl BoundInner { socket: T, // socket_type: Types, address: &smoltcp::wire::IpAddress, + netns: Arc, ) -> Result where T: smoltcp::socket::AnySocket<'static>, { if address.is_unspecified() { + let Some(iface) = netns.default_iface() else { + return Err(SystemError::ENODEV); + }; // 强绑VirtualIO - let iface = NET_DEVICES - .read_irqsave() - .iter() - .find_map(|(_, v)| { - if v.common().is_default_iface() { - Some(v.clone()) - } else { - None - } - }) - .expect("No default interface"); + // let iface = NET_DEVICES + // .read_irqsave() + // .iter() + // .find_map(|(_, v)| { + // if v.common().is_default_iface() { + // Some(v.clone()) + // } else { + // None + // } + // }) + // .expect("No default interface"); let handle = iface.sockets().lock().add(socket); - return Ok(Self { handle, iface }); + return Ok(Self { + handle, + iface, + netns, + }); } else { - let iface = get_iface_to_bind(address).ok_or(SystemError::ENODEV)?; + let iface = get_iface_to_bind(address, netns.clone()).ok_or(SystemError::ENODEV)?; let handle = iface.sockets().lock().add(socket); - return Ok(Self { handle, iface }); + return Ok(Self { + handle, + iface, + netns, + }); } } @@ -66,15 +79,23 @@ impl BoundInner { socket: T, // socket_type: Types, remote: smoltcp::wire::IpAddress, + netns: Arc, ) -> Result<(Self, smoltcp::wire::IpAddress), SystemError> where T: smoltcp::socket::AnySocket<'static>, { - let (iface, address) = get_ephemeral_iface(&remote); + let (iface, address) = get_ephemeral_iface(&remote, netns.clone()); // let bound_port = iface.port_manager().bind_ephemeral_port(socket_type)?; let handle = iface.sockets().lock().add(socket); // let endpoint = smoltcp::wire::IpEndpoint::new(local_addr, bound_port); - Ok((Self { handle, iface }, address)) + Ok(( + Self { + handle, + iface, + netns, + }, + address, + )) } pub fn port_manager(&self) -> &PortManager { @@ -99,14 +120,21 @@ impl BoundInner { pub fn release(&self) { self.iface.sockets().lock().remove(self.handle); } + + pub fn 
netns(&self) -> Arc { + self.netns.clone() + } } #[inline] -pub fn get_iface_to_bind(ip_addr: &smoltcp::wire::IpAddress) -> Option> { +pub fn get_iface_to_bind( + ip_addr: &smoltcp::wire::IpAddress, + netns: Arc, +) -> Option> { // log::debug!("get_iface_to_bind: {:?}", ip_addr); // if ip_addr.is_unspecified() - crate::net::NET_DEVICES - .read_irqsave() + netns + .device_list() .iter() .find(|(_, iface)| { let guard = iface.smol_iface().lock(); @@ -121,11 +149,12 @@ pub fn get_iface_to_bind(ip_addr: &smoltcp::wire::IpAddress) -> Option, ) -> (Arc, smoltcp::wire::IpAddress) { - get_iface_to_bind(remote_ip_addr) + get_iface_to_bind(remote_ip_addr, netns.clone()) .map(|iface| (iface, *remote_ip_addr)) .or({ - let ifaces = NET_DEVICES.read_irqsave(); + let ifaces = netns.device_list(); ifaces.iter().find_map(|(_, iface)| { iface .smol_iface() @@ -137,7 +166,7 @@ fn get_ephemeral_iface( }) }) .or({ - NET_DEVICES.read_irqsave().values().next().map(|iface| { + netns.device_list().values().next().map(|iface| { ( iface.clone(), iface.smol_iface().lock().ip_addrs()[0].address(), diff --git a/kernel/src/net/socket/inet/common/port.rs b/kernel/src/net/socket/inet/common/port.rs index e54e2a097..fb7e0ea97 100644 --- a/kernel/src/net/socket/inet/common/port.rs +++ b/kernel/src/net/socket/inet/common/port.rs @@ -19,14 +19,16 @@ pub struct PortManager { udp_port_table: SpinLock>, } -impl PortManager { - pub fn new() -> Self { - return Self { +impl Default for PortManager { + fn default() -> Self { + Self { tcp_port_table: SpinLock::new(HashMap::new()), udp_port_table: SpinLock::new(HashMap::new()), - }; + } } +} +impl PortManager { /// @brief 自动分配一个相对应协议中未被使用的PORT,如果动态端口均已被占用,返回错误码 EADDRINUSE pub fn get_ephemeral_port(&self, socket_type: Types) -> Result { // TODO: selects non-conflict high port diff --git a/kernel/src/net/socket/inet/datagram/inner.rs b/kernel/src/net/socket/inet/datagram/inner.rs index de1c3bd6f..0d2e79a56 100644 --- 
a/kernel/src/net/socket/inet/datagram/inner.rs +++ b/kernel/src/net/socket/inet/datagram/inner.rs @@ -1,9 +1,12 @@ +use alloc::sync::Arc; + use smoltcp; use system_error::SystemError; use crate::{ libs::spinlock::SpinLock, net::socket::inet::common::{BoundInner, Types as InetTypes}, + process::namespace::net_namespace::NetNamespace, }; pub type SmolUdpSocket = smoltcp::socket::udp::Socket<'static>; @@ -32,8 +35,12 @@ impl UnboundUdp { return Self { socket }; } - pub fn bind(self, local_endpoint: smoltcp::wire::IpEndpoint) -> Result { - let inner = BoundInner::bind(self.socket, &local_endpoint.addr)?; + pub fn bind( + self, + local_endpoint: smoltcp::wire::IpEndpoint, + netns: Arc, + ) -> Result { + let inner = BoundInner::bind(self.socket, &local_endpoint.addr, netns)?; let bind_addr = local_endpoint.addr; let bind_port = if local_endpoint.port == 0 { inner.port_manager().bind_ephemeral_port(InetTypes::Udp)? @@ -65,9 +72,13 @@ impl UnboundUdp { }) } - pub fn bind_ephemeral(self, remote: smoltcp::wire::IpAddress) -> Result { + pub fn bind_ephemeral( + self, + remote: smoltcp::wire::IpAddress, + netns: Arc, + ) -> Result { // let (addr, port) = (remote.addr, remote.port); - let (inner, address) = BoundInner::bind_ephemeral(self.socket, remote)?; + let (inner, address) = BoundInner::bind_ephemeral(self.socket, remote, netns)?; let bound_port = inner.port_manager().bind_ephemeral_port(InetTypes::Udp)?; let endpoint = smoltcp::wire::IpEndpoint::new(address, bound_port); Ok(BoundUdp { diff --git a/kernel/src/net/socket/inet/datagram/mod.rs b/kernel/src/net/socket/inet/datagram/mod.rs index 337042e75..45980e977 100644 --- a/kernel/src/net/socket/inet/datagram/mod.rs +++ b/kernel/src/net/socket/inet/datagram/mod.rs @@ -6,6 +6,8 @@ use crate::filesystem::epoll::EPollEventType; use crate::libs::wait_queue::WaitQueue; use crate::net::socket::common::EPollItems; use crate::net::socket::{Socket, PMSG}; +use crate::process::namespace::net_namespace::NetNamespace; +use 
crate::process::ProcessManager; use crate::{libs::rwlock::RwLock, net::socket::endpoint::Endpoint}; use alloc::sync::{Arc, Weak}; use core::sync::atomic::AtomicBool; @@ -24,18 +26,21 @@ pub struct UdpSocket { nonblock: AtomicBool, wait_queue: WaitQueue, self_ref: Weak, + netns: Arc, epoll_items: EPollItems, } impl UdpSocket { pub fn new(nonblock: bool) -> Arc { - return Arc::new_cyclic(|me| Self { + let netns = ProcessManager::current_netns(); + Arc::new_cyclic(|me| Self { inner: RwLock::new(Some(UdpInner::Unbound(UnboundUdp::new()))), nonblock: AtomicBool::new(nonblock), wait_queue: WaitQueue::default(), self_ref: me.clone(), + netns, epoll_items: EPollItems::default(), - }); + }) } pub fn is_nonblock(&self) -> bool { @@ -45,7 +50,7 @@ impl UdpSocket { pub fn do_bind(&self, local_endpoint: smoltcp::wire::IpEndpoint) -> Result<(), SystemError> { let mut inner = self.inner.write(); if let Some(UdpInner::Unbound(unbound)) = inner.take() { - let bound = unbound.bind(local_endpoint)?; + let bound = unbound.bind(local_endpoint, self.netns())?; bound .inner() @@ -62,7 +67,7 @@ impl UdpSocket { let mut inner_guard = self.inner.write(); let bound = match inner_guard.take().expect("Udp inner is None") { UdpInner::Bound(inner) => inner, - UdpInner::Unbound(inner) => inner.bind_ephemeral(remote)?, + UdpInner::Unbound(inner) => inner.bind_ephemeral(remote, self.netns())?, }; inner_guard.replace(UdpInner::Bound(bound)); return Ok(()); @@ -119,9 +124,8 @@ impl UdpSocket { let mut inner_guard = self.inner.write(); let inner = match inner_guard.take().expect("Udp Inner is None") { UdpInner::Bound(bound) => bound, - UdpInner::Unbound(unbound) => { - unbound.bind_ephemeral(to.ok_or(SystemError::EADDRNOTAVAIL)?.addr)? 
- } + UdpInner::Unbound(unbound) => unbound + .bind_ephemeral(to.ok_or(SystemError::EADDRNOTAVAIL)?.addr, self.netns())?, }; // size = inner.try_send(buf, to)?; inner_guard.replace(UdpInner::Bound(inner)); @@ -137,6 +141,10 @@ impl UdpSocket { }; return result; } + + pub fn netns(&self) -> Arc { + self.netns.clone() + } } impl Socket for UdpSocket { diff --git a/kernel/src/net/socket/inet/mod.rs b/kernel/src/net/socket/inet/mod.rs index aec4e66f0..8090ba625 100644 --- a/kernel/src/net/socket/inet/mod.rs +++ b/kernel/src/net/socket/inet/mod.rs @@ -40,6 +40,7 @@ pub trait InetSocket: Socket { fn on_iface_events(&self); fn notify(&self) { + // log::info!("InetSocket::notify"); self.on_iface_events(); let _ = EventPoll::wakeup_epoll(self.epoll_items().as_ref(), self.check_io_event()); } diff --git a/kernel/src/net/socket/inet/stream/inner.rs b/kernel/src/net/socket/inet/stream/inner.rs index 7557b7e26..779bc78bd 100644 --- a/kernel/src/net/socket/inet/stream/inner.rs +++ b/kernel/src/net/socket/inet/stream/inner.rs @@ -1,8 +1,10 @@ +use alloc::sync::Arc; use core::sync::atomic::AtomicUsize; use crate::filesystem::epoll::EPollEventType; use crate::libs::rwlock::RwLock; use crate::net::socket::{self, inet::Types}; +use crate::process::namespace::net_namespace::NetNamespace; use alloc::boxed::Box; use alloc::vec::Vec; use smoltcp; @@ -57,10 +59,11 @@ impl Init { pub(super) fn bind( self, local_endpoint: smoltcp::wire::IpEndpoint, + netns: Arc, ) -> Result { match self { Init::Unbound((socket, _)) => { - let bound = socket::inet::BoundInner::bind(*socket, &local_endpoint.addr)?; + let bound = socket::inet::BoundInner::bind(*socket, &local_endpoint.addr, netns)?; bound .port_manager() .bind_port(Types::Tcp, local_endpoint.port)?; @@ -77,11 +80,12 @@ impl Init { pub(super) fn bind_to_ephemeral( self, remote_endpoint: smoltcp::wire::IpEndpoint, + netns: Arc, ) -> Result<(socket::inet::BoundInner, smoltcp::wire::IpEndpoint), (Self, SystemError)> { match self { 
Init::Unbound((socket, ver)) => { let (bound, address) = - socket::inet::BoundInner::bind_ephemeral(*socket, remote_endpoint.addr) + socket::inet::BoundInner::bind_ephemeral(*socket, remote_endpoint.addr, netns) .map_err(|err| (Self::new(ver), err))?; let bound_port = bound .port_manager() @@ -97,9 +101,10 @@ impl Init { pub(super) fn connect( self, remote_endpoint: smoltcp::wire::IpEndpoint, + netns: Arc, ) -> Result { let (inner, local) = match self { - Init::Unbound(_) => self.bind_to_ephemeral(remote_endpoint)?, + Init::Unbound(_) => self.bind_to_ephemeral(remote_endpoint, netns)?, Init::Bound(inner) => inner, }; if local.addr.is_unspecified() { @@ -146,6 +151,7 @@ impl Init { .unwrap_or(&smoltcp::wire::IpAddress::from( smoltcp::wire::Ipv4Address::UNSPECIFIED, )), + inner.netns(), )?; inners.push(new_listen); } @@ -321,6 +327,7 @@ impl Listening { .unwrap_or(&smoltcp::wire::IpAddress::from( smoltcp::wire::Ipv4Address::UNSPECIFIED, )), + connected.netns(), )?; // swap the connected socket with the new_listen socket @@ -331,6 +338,7 @@ impl Listening { } pub fn update_io_events(&self, pollee: &AtomicUsize) { + // log::info!("Listening::update_io_events"); let position = self.inners.iter().position(|inner| { inner.with::(|socket| socket.is_active()) }); diff --git a/kernel/src/net/socket/inet/stream/mod.rs b/kernel/src/net/socket/inet/stream/mod.rs index ba0c3cdd5..c1466f5a2 100644 --- a/kernel/src/net/socket/inet/stream/mod.rs +++ b/kernel/src/net/socket/inet/stream/mod.rs @@ -6,6 +6,8 @@ use crate::libs::rwlock::RwLock; use crate::libs::wait_queue::WaitQueue; use crate::net::socket::common::EPollItems; use crate::net::socket::{common::ShutdownBit, endpoint::Endpoint, Socket, PMSG, PSOL}; +use crate::process::namespace::net_namespace::NetNamespace; +use crate::process::ProcessManager; use crate::sched::SchedMode; use smoltcp; @@ -27,11 +29,13 @@ pub struct TcpSocket { wait_queue: WaitQueue, self_ref: Weak, pollee: AtomicUsize, + netns: Arc, epoll_items: 
EPollItems, } impl TcpSocket { pub fn new(nonblock: bool, ver: smoltcp::wire::IpVersion) -> Arc { + let netns = ProcessManager::current_netns(); Arc::new_cyclic(|me| Self { inner: RwLock::new(Some(inner::Inner::Init(inner::Init::new(ver)))), // shutdown: Shutdown::new(), @@ -39,11 +43,16 @@ impl TcpSocket { wait_queue: WaitQueue::default(), self_ref: me.clone(), pollee: AtomicUsize::new(0_usize), + netns, epoll_items: EPollItems::default(), }) } - pub fn new_established(inner: inner::Established, nonblock: bool) -> Arc { + pub fn new_established( + inner: inner::Established, + nonblock: bool, + netns: Arc, + ) -> Arc { Arc::new_cyclic(|me| Self { inner: RwLock::new(Some(inner::Inner::Established(inner))), // shutdown: Shutdown::new(), @@ -51,6 +60,7 @@ impl TcpSocket { wait_queue: WaitQueue::default(), self_ref: me.clone(), pollee: AtomicUsize::new((EP::EPOLLIN.bits() | EP::EPOLLOUT.bits()) as usize), + netns, epoll_items: EPollItems::default(), }) } @@ -63,7 +73,7 @@ impl TcpSocket { let mut writer = self.inner.write(); match writer.take().expect("Tcp inner::Inner is None") { inner::Inner::Init(inner) => { - let bound = inner.bind(local_endpoint)?; + let bound = inner.bind(local_endpoint, self.netns())?; if let inner::Init::Bound((ref bound, _)) = bound { bound .iface() @@ -112,7 +122,7 @@ impl TcpSocket { { inner::Inner::Listening(listening) => listening.accept().map(|(stream, remote)| { ( - TcpSocket::new_established(stream, self.is_nonblock()), + TcpSocket::new_established(stream, self.is_nonblock(), self.netns()), remote, ) }), @@ -129,7 +139,7 @@ impl TcpSocket { let inner = writer.take().expect("Tcp inner::Inner is None"); let (init, result) = match inner { inner::Inner::Init(init) => { - let conn_result = init.connect(remote_endpoint); + let conn_result = init.connect(remote_endpoint, self.netns()); match conn_result { Ok(connecting) => ( inner::Inner::Connecting(connecting), @@ -179,6 +189,7 @@ impl TcpSocket { writer.replace(inner); drop(writer); + // 
log::info!("TcpSocket::finish_connect: {:?}", result); result } @@ -261,6 +272,10 @@ impl TcpSocket { fn do_poll(&self) -> usize { self.pollee.load(core::sync::atomic::Ordering::SeqCst) } + + pub fn netns(&self) -> Arc { + self.netns.clone() + } } impl Socket for TcpSocket { diff --git a/kernel/src/net/socket/inet/syscall.rs b/kernel/src/net/socket/inet/syscall.rs index d363feb44..1b0d168ec 100644 --- a/kernel/src/net/socket/inet/syscall.rs +++ b/kernel/src/net/socket/inet/syscall.rs @@ -25,7 +25,7 @@ pub fn create_inet_socket( }, PSOCK::Stream => match protocol { IpProtocol::HopByHop | IpProtocol::Tcp => { - log::debug!("create tcp socket"); + // log::debug!("create tcp socket"); return Ok(TcpSocket::new(is_nonblock, version)); } _ => { diff --git a/kernel/src/net/socket/inode.rs b/kernel/src/net/socket/inode.rs index 37802006b..3601ce667 100644 --- a/kernel/src/net/socket/inode.rs +++ b/kernel/src/net/socket/inode.rs @@ -68,6 +68,10 @@ impl IndexNode for T { ModeType::from_bits_truncate(0o755), )) } + + fn as_socket(&self) -> Option<&dyn Socket> { + Some(self) + } } impl PollableInode for T { diff --git a/kernel/src/net/socket/mod.rs b/kernel/src/net/socket/mod.rs index 4489c4062..af8eef431 100644 --- a/kernel/src/net/socket/mod.rs +++ b/kernel/src/net/socket/mod.rs @@ -4,6 +4,7 @@ pub mod endpoint; mod family; pub mod inet; mod inode; +pub mod netlink; mod posix; pub mod unix; mod utils; diff --git a/kernel/src/net/socket/netlink/addr/mod.rs b/kernel/src/net/socket/netlink/addr/mod.rs new file mode 100644 index 000000000..73e08de41 --- /dev/null +++ b/kernel/src/net/socket/netlink/addr/mod.rs @@ -0,0 +1,55 @@ +use crate::net::socket::{endpoint::Endpoint, netlink::addr::multicast::GroupIdSet}; +use system_error::SystemError; + +pub mod multicast; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct NetlinkSocketAddr { + port: u32, + groups: GroupIdSet, +} + +impl NetlinkSocketAddr { + pub fn new(port_num: u32, groups: GroupIdSet) -> Self { + Self { + port: 
port_num, + groups, + } + } + + pub fn new_unspecified() -> Self { + Self { + port: 0, + groups: GroupIdSet::new_empty(), + } + } + + pub const fn port(&self) -> u32 { + self.port + } + + pub fn groups(&self) -> GroupIdSet { + self.groups + } + + pub fn add_groups(&mut self, groups: GroupIdSet) { + self.groups.add_groups(groups); + } +} + +impl TryFrom for NetlinkSocketAddr { + type Error = SystemError; + + fn try_from(value: Endpoint) -> Result { + match value { + Endpoint::Netlink(addr) => Ok(addr), + _ => Err(SystemError::EAFNOSUPPORT), + } + } +} + +impl From for Endpoint { + fn from(value: NetlinkSocketAddr) -> Self { + Endpoint::Netlink(value) + } +} diff --git a/kernel/src/net/socket/netlink/addr/multicast.rs b/kernel/src/net/socket/netlink/addr/multicast.rs new file mode 100644 index 000000000..237afc957 --- /dev/null +++ b/kernel/src/net/socket/netlink/addr/multicast.rs @@ -0,0 +1,64 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct GroupIdSet(u32); + +impl GroupIdSet { + pub const fn new_empty() -> Self { + Self(0) + } + + pub const fn new(groups: u32) -> Self { + Self(groups) + } + + pub const fn ids_iter(&self) -> GroupIdIter { + GroupIdIter::new(self) + } + + pub fn add_groups(&mut self, groups: GroupIdSet) { + self.0 |= groups.0; + } + + pub fn drop_groups(&mut self, groups: GroupIdSet) { + self.0 &= !groups.0; + } + + pub fn set_groups(&mut self, new_groups: u32) { + self.0 = new_groups; + } + + pub fn clear(&mut self) { + self.0 = 0; + } + + pub fn is_empty(&self) -> bool { + self.0 == 0 + } + + pub fn as_u32(&self) -> u32 { + self.0 + } +} + +pub struct GroupIdIter { + groups: u32, +} + +impl GroupIdIter { + const fn new(groups: &GroupIdSet) -> Self { + Self { groups: groups.0 } + } +} + +impl Iterator for GroupIdIter { + type Item = u32; + + fn next(&mut self) -> Option { + if self.groups > 0 { + let group_id = self.groups.trailing_zeros(); + self.groups &= self.groups - 1; + return Some(group_id); + } + + None + } +} diff --git 
a/kernel/src/net/socket/netlink/common/bound.rs b/kernel/src/net/socket/netlink/common/bound.rs new file mode 100644 index 000000000..b0a3c0e91 --- /dev/null +++ b/kernel/src/net/socket/netlink/common/bound.rs @@ -0,0 +1,68 @@ +use crate::{ + filesystem::epoll::EPollEventType, + net::socket::netlink::{ + addr::{multicast::GroupIdSet, NetlinkSocketAddr}, + receiver::MessageQueue, + table::BoundHandle, + }, + process::namespace::net_namespace::NetNamespace, +}; +use alloc::fmt::Debug; +use alloc::sync::Arc; +use system_error::SystemError; + +#[derive(Debug)] +pub struct BoundNetlink { + pub(in crate::net::socket::netlink) handle: BoundHandle, + pub(in crate::net::socket::netlink) remote_addr: NetlinkSocketAddr, + pub(in crate::net::socket::netlink) receive_queue: MessageQueue, + pub(in crate::net::socket::netlink) netns: Arc, +} + +impl BoundNetlink { + pub(super) fn new( + handle: BoundHandle, + message_queue: MessageQueue, + netns: Arc, + ) -> Self { + Self { + handle, + remote_addr: NetlinkSocketAddr::new_unspecified(), + receive_queue: message_queue, + netns, + } + } + + pub fn bind_common(&mut self, endpoint: &NetlinkSocketAddr) -> Result<(), SystemError> { + if endpoint.port() != self.handle.port() { + return Err(SystemError::EINVAL); + } + let groups = endpoint.groups(); + self.handle.bind_groups(groups); + + Ok(()) + } + + pub fn check_io_events_common(&self) -> EPollEventType { + let mut events = EPollEventType::EPOLLOUT; + + let receive_queue = self.receive_queue.0.lock(); + if !receive_queue.is_empty() { + events |= EPollEventType::EPOLLIN; + } + + events + } + + pub(super) fn add_groups(&mut self, groups: GroupIdSet) { + self.handle.add_groups(groups); + } + + pub(super) fn drop_groups(&mut self, groups: GroupIdSet) { + self.handle.drop_groups(groups); + } + + pub fn netns(&self) -> Arc { + self.netns.clone() + } +} diff --git a/kernel/src/net/socket/netlink/common/mod.rs b/kernel/src/net/socket/netlink/common/mod.rs new file mode 100644 index 
000000000..ed2e88381 --- /dev/null +++ b/kernel/src/net/socket/netlink/common/mod.rs @@ -0,0 +1,268 @@ +use crate::{ + filesystem::epoll::EPollEventType, + libs::{rwlock::RwLock, wait_queue::WaitQueue}, + net::socket::{ + endpoint::Endpoint, + netlink::{ + addr::{multicast::GroupIdSet, NetlinkSocketAddr}, + common::{bound::BoundNetlink, unbound::UnboundNetlink}, + table::SupportedNetlinkProtocol, + }, + utils::datagram_common::{select_remote_and_bind, Bound, Inner}, + Socket, PMSG, + }, + process::{namespace::net_namespace::NetNamespace, ProcessManager}, +}; +use alloc::sync::Arc; +use core::sync::atomic::AtomicBool; +use system_error::SystemError; + +pub(super) mod bound; +mod unbound; + +#[derive(Debug)] +pub struct NetlinkSocket { + inner: RwLock, BoundNetlink>>, + + is_nonblocking: AtomicBool, + wait_queue: Arc, + netns: Arc, +} + +impl NetlinkSocket

+where + BoundNetlink: Bound, +{ + pub fn new(is_nonblocking: bool) -> Arc { + let unbound = UnboundNetlink::new(); + Arc::new(Self { + inner: RwLock::new(Inner::Unbound(unbound)), + is_nonblocking: AtomicBool::new(is_nonblocking), + wait_queue: Arc::new(WaitQueue::default()), + netns: ProcessManager::current_netns(), + }) + } + + fn try_send( + &self, + buf: &[u8], + to: Option, + flags: crate::net::socket::PMSG, + ) -> Result { + let send_bytes = select_remote_and_bind( + &self.inner, + to, + || { + self.inner.write().bind_ephemeral( + &NetlinkSocketAddr::new_unspecified(), + self.wait_queue.clone(), + self.netns(), + ) + }, + |bound, remote| bound.try_send(buf, &remote, flags), + )?; + // todo pollee invalidate?? + + Ok(send_bytes) + } + + fn try_recv( + &self, + buf: &mut [u8], + flags: crate::net::socket::PMSG, + ) -> Result<(usize, Endpoint), SystemError> { + let (recv_bytes, endpoint) = self + .inner + .read() + .try_recv(buf, flags) + .map(|(recv_bytes, remote_endpoint)| (recv_bytes, remote_endpoint.into()))?; + // todo self.pollee.invalidate(); + + Ok((recv_bytes, endpoint)) + } + + /// 判断当前的netlink是否可以接收数据 + /// 目前netlink只是负责接收内核消息,所以暂时不用判断是否可以发送数据 + pub fn can_recv(&self) -> bool { + self.inner + .read() + .check_io_events() + .contains(EPollEventType::EPOLLIN) + } + + pub fn do_poll(&self) -> usize { + self.inner.read().check_io_events().bits() as usize + } + + pub fn netns(&self) -> Arc { + self.netns.clone() + } +} + +impl Socket for NetlinkSocket

+where + BoundNetlink: Bound, +{ + fn connect( + &self, + endpoint: crate::net::socket::endpoint::Endpoint, + ) -> Result<(), system_error::SystemError> { + let endpoint = endpoint.try_into()?; + + self.inner + .write() + .connect(&endpoint, self.wait_queue.clone(), self.netns()) + } + + fn bind( + &self, + endpoint: crate::net::socket::endpoint::Endpoint, + ) -> Result<(), system_error::SystemError> { + let endpoint = endpoint.try_into()?; + + self.inner + .write() + .bind(&endpoint, self.wait_queue.clone(), self.netns()) + } + + fn send_to( + &self, + buffer: &[u8], + flags: crate::net::socket::PMSG, + address: crate::net::socket::endpoint::Endpoint, + ) -> Result { + let endpoint = address.try_into()?; + + self.try_send(buffer, Some(endpoint), flags) + } + + fn wait_queue(&self) -> &WaitQueue { + &self.wait_queue + } + + fn recv_from( + &self, + buffer: &mut [u8], + flags: crate::net::socket::PMSG, + address: Option, + ) -> Result<(usize, crate::net::socket::endpoint::Endpoint), system_error::SystemError> { + // log::info!("NetlinkSocket recv_from called"); + use crate::sched::SchedMode; + + if let Some(addr) = address { + self.connect(addr)?; + } + + return if self.is_nonblocking() || flags.contains(PMSG::DONTWAIT) { + self.try_recv(buffer, flags) + } else { + loop { + match self.try_recv(buffer, flags) { + Err(SystemError::EAGAIN_OR_EWOULDBLOCK) => { + let _ = wq_wait_event_interruptible!(self.wait_queue, self.can_recv(), {}); + } + result => break result, + } + } + }; + // self.try_recv(buffer, flags) + } + + fn check_io_event(&self) -> crate::filesystem::epoll::EPollEventType { + EPollEventType::from_bits_truncate(self.do_poll() as u32) + } + + fn send_buffer_size(&self) -> usize { + // log::warn!("send_buffer_size is implemented to 0"); + // netlink sockets typically do not have a send buffer size like stream sockets. 
+ 0 + } + + fn recv_buffer_size(&self) -> usize { + // log::warn!("recv_buffer_size is implemented to 0"); + // netlink sockets typically do not have a recv buffer size like stream sockets. + 0 + } + + fn recv(&self, buffer: &mut [u8], flags: PMSG) -> Result { + let (len, _) = self.recv_from(buffer, flags, None)?; + Ok(len) + } + + fn recv_msg( + &self, + _msg: &mut crate::net::posix::MsgHdr, + _flags: PMSG, + ) -> Result { + todo!("implement recv_msg for netlink socket"); + } + + fn send(&self, buffer: &[u8], flags: PMSG) -> Result { + self.try_send(buffer, None, flags) + } + + fn send_msg( + &self, + _msg: &crate::net::posix::MsgHdr, + _flags: PMSG, + ) -> Result { + todo!("implement send_msg for netlink socket"); + } + + fn do_close(&self) -> Result<(), SystemError> { + //TODO close the socket properly + Ok(()) + } + + fn epoll_items(&self) -> &crate::net::socket::common::EPollItems { + todo!("implement epoll_items for netlink socket"); + } + + fn local_endpoint(&self) -> Result { + if let Some(addr) = self.inner.read().addr() { + Ok(addr.into()) + } else { + Err(SystemError::ENOTCONN) + } + } + + fn remote_endpoint(&self) -> Result { + if let Some(addr) = self.inner.read().peer_addr() { + Ok(addr.into()) + } else { + Err(SystemError::ENOTCONN) + } + } +} + +impl NetlinkSocket

{ + pub fn is_nonblocking(&self) -> bool { + self.is_nonblocking + .load(core::sync::atomic::Ordering::Relaxed) + } + + #[allow(unused)] + pub fn set_nonblocking(&self, nonblocking: bool) { + self.is_nonblocking + .store(nonblocking, core::sync::atomic::Ordering::Relaxed); + } +} + +// 多播消息的时候会用到,比如uevent +impl Inner, BoundNetlink> { + #[allow(unused)] + fn add_groups(&mut self, groups: GroupIdSet) { + match self { + Inner::Bound(bound) => bound.add_groups(groups), + Inner::Unbound(unbound) => unbound.add_groups(groups), + } + } + + #[allow(unused)] + fn drop_groups(&mut self, groups: GroupIdSet) { + match self { + Inner::Unbound(unbound) => unbound.drop_groups(groups), + Inner::Bound(bound) => bound.drop_groups(groups), + } + } +} diff --git a/kernel/src/net/socket/netlink/common/unbound.rs b/kernel/src/net/socket/netlink/common/unbound.rs new file mode 100644 index 000000000..05d59c07f --- /dev/null +++ b/kernel/src/net/socket/netlink/common/unbound.rs @@ -0,0 +1,98 @@ +use crate::{ + filesystem::epoll::EPollEventType, + libs::wait_queue::WaitQueue, + net::socket::{ + netlink::{ + addr::{multicast::GroupIdSet, NetlinkSocketAddr}, + common::bound::BoundNetlink, + receiver::{MessageQueue, MessageReceiver}, + table::SupportedNetlinkProtocol, + }, + utils::datagram_common, + }, + process::namespace::net_namespace::NetNamespace, +}; +use alloc::sync::Arc; +use core::marker::PhantomData; +use system_error::SystemError; + +#[derive(Debug)] +pub struct UnboundNetlink { + groups: GroupIdSet, + phantom: PhantomData>, +} + +impl UnboundNetlink

{ + pub(super) fn new() -> Self { + Self { + groups: GroupIdSet::new_empty(), + phantom: PhantomData, + } + } + + pub(super) fn addr(&self) -> NetlinkSocketAddr { + NetlinkSocketAddr::new(0, self.groups) + } + + pub(super) fn add_groups(&mut self, groups: GroupIdSet) { + self.groups.add_groups(groups); + } + + pub(super) fn drop_groups(&mut self, groups: GroupIdSet) { + self.groups.drop_groups(groups); + } +} + +impl datagram_common::Unbound for UnboundNetlink

{ + type Endpoint = NetlinkSocketAddr; + type Bound = BoundNetlink; + + fn bind( + &mut self, + endpoint: &Self::Endpoint, + wait_queue: Arc, + netns: Arc, + ) -> Result, SystemError> { + let message_queue = MessageQueue::::new(); + let bound_handle = { + let endpoint = { + let mut endpoint = *endpoint; + endpoint.add_groups(self.groups); + endpoint + }; + let receiver = MessageReceiver::new(message_queue.clone(), wait_queue); +

::bind(&endpoint, receiver, netns.clone())? + }; + + Ok(BoundNetlink::new(bound_handle, message_queue, netns)) + } + + fn bind_ephemeral( + &mut self, + _remote_endpoint: &Self::Endpoint, + wait_queue: Arc, + netns: Arc, + ) -> Result, SystemError> { + let message_queue = MessageQueue::::new(); + + let bound_handle = { + let endpoint = { + let mut endpoint = NetlinkSocketAddr::new_unspecified(); + endpoint.add_groups(self.groups); + endpoint + }; + let receiver = MessageReceiver::new(message_queue.clone(), wait_queue); +

::bind(&endpoint, receiver, netns.clone())? + }; + + Ok(BoundNetlink::new(bound_handle, message_queue, netns)) + } + + fn check_io_events(&self) -> EPollEventType { + EPollEventType::EPOLLOUT + } + + fn local_endpoint(&self) -> Option { + Some(self.addr()) + } +} diff --git a/kernel/src/net/socket/netlink/message/attr/mod.rs b/kernel/src/net/socket/netlink/message/attr/mod.rs new file mode 100644 index 000000000..8e21ef19b --- /dev/null +++ b/kernel/src/net/socket/netlink/message/attr/mod.rs @@ -0,0 +1,156 @@ +pub(super) mod noattr; + +use crate::{libs::align::align_up, net::socket::netlink::message::NLMSG_ALIGN}; +use alloc::vec::Vec; +use system_error::SystemError; + +const IS_NESTED_MASK: u16 = 1u16 << 15; +const IS_NET_BYTEORDER_MASK: u16 = 1u16 << 14; +const ATTRIBUTE_TYPE_MASK: u16 = !(IS_NESTED_MASK | IS_NET_BYTEORDER_MASK); + +/// Netlink Attribute Header +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct CAttrHeader { + len: u16, + type_: u16, +} + +impl CAttrHeader { + fn from_payload_len(type_: u16, payload_len: usize) -> Self { + let total_len = payload_len + size_of::(); + // debug_assert!(total_len <= u16::MAX as usize); + + Self { + len: total_len as u16, + type_, + } + } + + pub fn type_(&self) -> u16 { + self.type_ & ATTRIBUTE_TYPE_MASK + } + + pub fn payload_len(&self) -> usize { + self.len as usize - size_of::() + } + + pub fn total_len(&self) -> usize { + self.len as usize + } + + pub fn total_len_with_padding(&self) -> usize { + align_up(self.len as usize, NLMSG_ALIGN) + } + + pub fn padding_len(&self) -> usize { + self.total_len_with_padding() - self.total_len() + } +} + +/// Netlink Attribute +pub trait Attribute: core::fmt::Debug + Send + Sync { + fn type_(&self) -> u16; + + fn payload_as_bytes(&self) -> &[u8]; + + fn total_len_with_padding(&self) -> usize { + const DUMMY_TYPE: u16 = 0; + + CAttrHeader::from_payload_len(DUMMY_TYPE, self.payload_as_bytes().len()) + .total_len_with_padding() + } + + fn read_from_buf(header: &CAttrHeader, 
payload_buf: &[u8]) -> Result, SystemError> + where + Self: Sized; + + fn write_to_buf(&self, buf: &mut Vec) -> Result { + let type_: u16 = self.type_(); + let payload_bytes = self.payload_as_bytes(); + let header = CAttrHeader::from_payload_len(type_, payload_bytes.len()); + let total_len = header.total_len_with_padding(); + + // let mut current_offset = offset; + + // 写入头部 + let header_bytes = unsafe { + core::slice::from_raw_parts( + &header as *const CAttrHeader as *const u8, + size_of::(), + ) + }; + // buf[current_offset..current_offset + header_bytes.len()].copy_from_slice(header_bytes); + buf.extend_from_slice(header_bytes); + // current_offset += header_bytes.len(); + + // 写入负载 + // buf[current_offset..current_offset + payload_bytes.len()].copy_from_slice(payload_bytes); + buf.extend_from_slice(payload_bytes); + // current_offset += payload_bytes.len(); + + // 添加对齐填充 + let padding_len = header.padding_len(); + if padding_len > 0 { + // buf[current_offset..current_offset + padding_len].fill(0); + buf.extend(vec![0u8; padding_len]); + } + + Ok(total_len) + } + + fn read_all_from_buf(buf: &[u8], mut total_len: usize) -> Result, SystemError> + where + Self: Sized, + { + let mut attrs = Vec::new(); + let mut offset = 0; + + while total_len > 0 { + if total_len < size_of::() { + return Err(SystemError::EINVAL); + } + + // 检查是否有足够的字节读取属性头部 + if buf.len() - offset < size_of::() { + return Err(SystemError::EINVAL); + } + + // 读取属性头部 + let attr_header_bytes = &buf[offset..offset + size_of::()]; + let attr_header = unsafe { *(attr_header_bytes.as_ptr() as *const CAttrHeader) }; + + // 验证属性长度 + if attr_header.total_len() < size_of::() { + return Err(SystemError::EINVAL); + } + + total_len = total_len + .checked_sub(attr_header.total_len()) + .ok_or(SystemError::EINVAL)?; + + if buf.len() - offset < attr_header.total_len() { + return Err(SystemError::EINVAL); + } + + // 读取属性负载 + let payload_start = offset + size_of::(); + let payload_len = attr_header.payload_len(); + 
let payload_buf = &buf[payload_start..payload_start + payload_len]; + + // 解析属性 + if let Some(attr) = Self::read_from_buf(&attr_header, payload_buf)? { + attrs.push(attr); + } + + // 移动到下一个属性(考虑对齐) + let attr_total_with_padding = attr_header.total_len_with_padding(); + offset += attr_total_with_padding; + + let padding_len = total_len.min(attr_header.padding_len()); + total_len -= padding_len; + } + + Ok(attrs) + } +} diff --git a/kernel/src/net/socket/netlink/message/attr/noattr.rs b/kernel/src/net/socket/netlink/message/attr/noattr.rs new file mode 100644 index 000000000..59f3d27a7 --- /dev/null +++ b/kernel/src/net/socket/netlink/message/attr/noattr.rs @@ -0,0 +1,38 @@ +use crate::net::socket::netlink::message::attr::Attribute; +use alloc::vec::Vec; + +#[derive(Debug)] +pub enum NoAttr {} + +impl Attribute for NoAttr { + fn type_(&self) -> u16 { + match *self {} + } + + fn payload_as_bytes(&self) -> &[u8] { + match *self {} + } + + fn read_from_buf( + header: &super::CAttrHeader, + _payload_buf: &[u8], + ) -> Result, system_error::SystemError> + where + Self: Sized, + { + let _payload_len = header.payload_len(); + //todo reader.skip_some(payload_len); + + Ok(None) + } + + fn read_all_from_buf( + _buf: &[u8], + _offset: usize, + ) -> Result, system_error::SystemError> + where + Self: Sized, + { + Ok(Vec::new()) + } +} diff --git a/kernel/src/net/socket/netlink/message/mod.rs b/kernel/src/net/socket/netlink/message/mod.rs new file mode 100644 index 000000000..1bedbd21c --- /dev/null +++ b/kernel/src/net/socket/netlink/message/mod.rs @@ -0,0 +1,81 @@ +use crate::net::socket::netlink::{ + message::segment::header::CMsgSegHdr, table::StandardNetlinkProtocol, +}; +use alloc::vec::Vec; +use system_error::SystemError; + +pub(super) mod attr; +pub(super) mod segment; + +#[derive(Debug)] +pub struct Message { + segments: Vec, +} + +impl Message { + pub fn new(segments: Vec) -> Self { + Self { segments } + } + + pub fn segments(&self) -> &[T] { + &self.segments + } + + pub 
fn segments_mut(&mut self) -> &mut [T] { + &mut self.segments + } + + pub fn read_from(reader: &[u8]) -> Result { + let segments = { + let segment = T::read_from(reader)?; + vec![segment] + }; + + Ok(Self { segments }) + } + + pub fn write_to(&self, writer: &mut [u8]) -> Result { + // log::info!("Message write_to"); + let mut total_written: usize = 0; + + for segment in self.segments() { + if total_written >= writer.len() { + log::warn!("Netlink write buffer is full. Some segments may be dropped."); + break; + } + + let remaining_buf = &mut writer[total_written..]; + let written = segment.write_to(remaining_buf)?; + + total_written += written; + } + + // log::info!("Total written bytes: {}", total_written); + Ok(total_written) + } + + pub fn total_len(&self) -> usize { + self.segments + .iter() + .map(|segment| segment.header().len as usize) + .sum() + } + + pub fn protocol(&self) -> StandardNetlinkProtocol { + self.segments + .first() + .map_or(StandardNetlinkProtocol::UNUSED, |segment| { + segment.protocol() + }) + } +} + +pub trait ProtocolSegment: Sized + alloc::fmt::Debug { + fn header(&self) -> &CMsgSegHdr; + fn header_mut(&mut self) -> &mut CMsgSegHdr; + fn read_from(reader: &[u8]) -> Result; + fn write_to(&self, writer: &mut [u8]) -> Result; + fn protocol(&self) -> StandardNetlinkProtocol; +} + +pub(super) const NLMSG_ALIGN: usize = 4; diff --git a/kernel/src/net/socket/netlink/message/segment/ack.rs b/kernel/src/net/socket/netlink/message/segment/ack.rs new file mode 100644 index 000000000..fa2dcaade --- /dev/null +++ b/kernel/src/net/socket/netlink/message/segment/ack.rs @@ -0,0 +1,84 @@ +use crate::net::socket::netlink::message::{ + attr::noattr::NoAttr, + segment::{ + common::SegmentCommon, + header::{CMsgSegHdr, SegHdrCommonFlags}, + CSegmentType, SegmentBody, + }, +}; +use alloc::vec::Vec; +use system_error::SystemError; + +pub type DoneSegment = SegmentCommon; + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct DoneSegmentBody { + error_code: 
i32, +} + +impl SegmentBody for DoneSegmentBody { + type CType = DoneSegmentBody; +} + +impl DoneSegment { + pub fn new_from_request(request_header: &CMsgSegHdr, error: Option) -> Self { + let header = CMsgSegHdr { + len: 0, + type_: CSegmentType::DONE as _, + flags: SegHdrCommonFlags::empty().bits(), + seq: request_header.seq, + pid: request_header.pid, + }; + + let body = { + let error_code = if let Some(err) = error { + err.to_posix_errno() + } else { + 0 + }; + DoneSegmentBody { error_code } + }; + + Self::new(header, body, Vec::new()) + } +} + +pub type ErrorSegment = SegmentCommon; + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct ErrorSegmentBody { + error_code: i32, + request_header: CMsgSegHdr, +} + +impl SegmentBody for ErrorSegmentBody { + type CType = ErrorSegmentBody; +} + +impl ErrorSegment { + pub fn new_from_request(request_header: &CMsgSegHdr, error: Option) -> Self { + let header = CMsgSegHdr { + len: 0, + type_: CSegmentType::ERROR as _, + flags: SegHdrCommonFlags::empty().bits(), + seq: request_header.seq, + pid: request_header.pid, + }; + + let body = { + let error_code = if let Some(err) = error { + err.to_posix_errno() + } else { + 0 + }; + ErrorSegmentBody { + error_code, + request_header: *request_header, + } + }; + + Self::new(header, body, Vec::new()) + } +} diff --git a/kernel/src/net/socket/netlink/message/segment/common.rs b/kernel/src/net/socket/netlink/message/segment/common.rs new file mode 100644 index 000000000..9f6c92c17 --- /dev/null +++ b/kernel/src/net/socket/netlink/message/segment/common.rs @@ -0,0 +1,111 @@ +use crate::net::socket::netlink::message::{ + attr::Attribute, + segment::{header::CMsgSegHdr, SegmentBody}, +}; +use alloc::vec::Vec; +use system_error::SystemError; + +#[derive(Debug)] +pub struct SegmentCommon { + header: CMsgSegHdr, + body: Body, + attrs: Vec, +} + +impl SegmentCommon { + pub const HEADER_LEN: usize = size_of::(); + + pub fn header(&self) -> &CMsgSegHdr { + &self.header + } + + pub fn 
header_mut(&mut self) -> &mut CMsgSegHdr { + &mut self.header + } + + pub fn body(&self) -> &Body { + &self.body + } + + pub fn attrs(&self) -> &Vec { + &self.attrs + } +} + +impl SegmentCommon { + pub const BODY_LEN: usize = size_of::(); + + pub fn new(header: CMsgSegHdr, body: Body, attrs: Vec) -> Self { + let mut res = Self { + header, + body, + attrs, + }; + res.header.len = res.total_len() as u32; + res + } + + pub fn read_from_buf(header: CMsgSegHdr, buf: &[u8]) -> Result { + // log::info!("SegmentCommon try to read from buffer"); + let (body, remain_len, padded_len) = Body::read_from_buf(&header, buf)?; + + let attrs_buf = &buf[padded_len..]; + let attrs = Attr::read_all_from_buf(attrs_buf, remain_len)?; + + Ok(Self { + header, + body, + attrs, + }) + } + + pub fn write_to_buf(&self, buf: &mut [u8]) -> Result { + if buf.len() < self.header.len as usize { + return Err(SystemError::EINVAL); + } + + // Write header to the beginning of buf + let header_bytes = unsafe { + core::slice::from_raw_parts( + (&self.header as *const CMsgSegHdr) as *const u8, + Self::HEADER_LEN, + ) + }; + buf[..Self::HEADER_LEN].copy_from_slice(header_bytes); + + // 这里创建一个内核缓冲区,用来写入body和attribute,方便进行写入 + let mut kernel_buf: Vec = vec![]; + + self.body.write_to_buf(&mut kernel_buf)?; + for attr in self.attrs.iter() { + attr.write_to_buf(&mut kernel_buf)?; + } + + let actual_len = kernel_buf.len().min(buf.len()); + let payload_copied = if !kernel_buf.is_empty() { + buf[Self::HEADER_LEN..Self::HEADER_LEN + actual_len] + .copy_from_slice(&kernel_buf[..actual_len]); + // log::info!("buffer: {:?}", &buf[..actual_len]); + actual_len + } else { + // 如果没有数据需要写入,返回0 + // log::info!("No data to write to buffer"); + 0 + }; + + Ok(payload_copied + Self::HEADER_LEN) + } + + pub fn total_len(&self) -> usize { + Self::HEADER_LEN + Self::BODY_LEN + self.attrs_len() + } +} + +impl SegmentCommon { + pub fn attrs_len(&self) -> usize { + self.attrs + .iter() + .map(|attr| attr.total_len_with_padding()) + 
.sum() + } +} diff --git a/kernel/src/net/socket/netlink/message/segment/header.rs b/kernel/src/net/socket/netlink/message/segment/header.rs new file mode 100644 index 000000000..03b33e631 --- /dev/null +++ b/kernel/src/net/socket/netlink/message/segment/header.rs @@ -0,0 +1,73 @@ +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct CMsgSegHdr { + /// Length of the message, including the header + pub len: u32, + /// Type of message content + pub type_: u16, + /// Additional flags + pub flags: u16, + /// Sequence number + pub seq: u32, + /// Sending process port ID + pub pid: u32, +} + +bitflags! { + pub struct SegHdrCommonFlags: u16 { + /// Indicates a request message + const REQUEST = 0x01; + /// Multipart message, terminated by NLMSG_DONE + const MULTI = 0x02; + /// Reply with an acknowledgment, with zero or an error code + const ACK = 0x04; + /// Echo this request + const ECHO = 0x08; + /// Dump was inconsistent due to sequence change + const DUMP_INTR = 0x10; + /// Dump was filtered as requested + const DUMP_FILTERED = 0x20; + } +} + +bitflags! { + pub struct GetRequestFlags: u16 { + /// Specify the tree root + const ROOT = 0x100; + /// Return all matching results + const MATCH = 0x200; + /// Atomic get request + const ATOMIC = 0x400; + /// Combination flag for root and match + const DUMP = Self::ROOT.bits | Self::MATCH.bits; + } +} + +bitflags! { + pub struct NewRequestFlags: u16 { + /// Override existing entries + const REPLACE = 0x100; + /// Do not modify if it exists + const EXCL = 0x200; + /// Create if it does not exist + const CREATE = 0x400; + /// Add to the end of the list + const APPEND = 0x800; + } +} + +bitflags! { + pub struct DeleteRequestFlags: u16 { + /// Do not delete recursively + const NONREC = 0x100; + /// Delete multiple objects + const BULK = 0x200; + } +} + +bitflags! 
{ + pub struct AckFlags: u16 { + const CAPPED = 0x100; + const ACK_TLVS = 0x100; + } +} diff --git a/kernel/src/net/socket/netlink/message/segment/mod.rs b/kernel/src/net/socket/netlink/message/segment/mod.rs new file mode 100644 index 000000000..6d08d20c7 --- /dev/null +++ b/kernel/src/net/socket/netlink/message/segment/mod.rs @@ -0,0 +1,122 @@ +use crate::libs::align::align_up; +use crate::net::socket::netlink::message::{segment::header::CMsgSegHdr, NLMSG_ALIGN}; +use alloc::fmt::Debug; +use alloc::vec::Vec; +use system_error::SystemError; + +pub mod ack; +pub mod common; +pub mod header; + +#[repr(u16)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum CSegmentType { + // Standard netlink message types + NOOP = 1, + ERROR = 2, + DONE = 3, + OVERRUN = 4, + + // protocol-level types + NEWLINK = 16, + DELLINK = 17, + GETLINK = 18, + SETLINK = 19, + + NEWADDR = 20, + DELADDR = 21, + GETADDR = 22, + + NEWROUTE = 24, + DELROUTE = 25, + GETROUTE = 26, + // TODO 补充 +} + +impl TryFrom for CSegmentType { + type Error = SystemError; + + fn try_from(value: u16) -> Result { + match value { + 1 => Ok(CSegmentType::NOOP), + 2 => Ok(CSegmentType::ERROR), + 3 => Ok(CSegmentType::DONE), + 4 => Ok(CSegmentType::OVERRUN), + 16 => Ok(CSegmentType::NEWLINK), + 17 => Ok(CSegmentType::DELLINK), + 18 => Ok(CSegmentType::GETLINK), + 19 => Ok(CSegmentType::SETLINK), + 20 => Ok(CSegmentType::NEWADDR), + 21 => Ok(CSegmentType::DELADDR), + 22 => Ok(CSegmentType::GETADDR), + 24 => Ok(CSegmentType::NEWROUTE), + 25 => Ok(CSegmentType::DELROUTE), + 26 => Ok(CSegmentType::GETROUTE), + _ => Err(SystemError::EINVAL), + } + } +} + +pub trait SegmentBody: Sized + Clone + Copy { + type CType: Copy + TryInto + From + Debug; + + fn read_from_buf(header: &CMsgSegHdr, buf: &[u8]) -> Result<(Self, usize, usize), SystemError> + where + Self: Sized, + { + // log::info!("header: {:?}", header); + let total_len = (header.len as usize) + .checked_sub(size_of::()) + 
.ok_or(SystemError::EINVAL)?; + + if buf.len() < total_len { + return Err(SystemError::EINVAL); + } + + let c_type_bytes = &buf[..size_of::()]; + let c_type = unsafe { *(c_type_bytes.as_ptr() as *const Self::CType) }; + // log::info!("c_type: {:?}", c_type); + + let total_len_with_padding = Self::total_len_with_padding(); + + let Ok(body) = c_type.try_into() else { + return Err(SystemError::EINVAL); + }; + + let remaining_len = total_len.saturating_sub(total_len_with_padding); + + Ok((body, remaining_len, total_len_with_padding)) + } + + fn write_to_buf(&self, buf: &mut Vec) -> Result<(), SystemError> { + // log::info!("SegmentBody write_to_buf"); + let c_type = Self::CType::from(*self); + + let body_bytes = unsafe { + core::slice::from_raw_parts( + &c_type as *const Self::CType as *const u8, + size_of::(), + ) + }; + buf.extend_from_slice(body_bytes); + + // let total_len_with_padding = Self::total_len_with_padding(); + let padding_len = Self::padding_len(); + + if padding_len > 0 { + buf.extend(vec![0u8; padding_len]); + } + + Ok(()) + } + + fn total_len_with_padding() -> usize { + let payload_len = size_of::(); + align_up(payload_len, NLMSG_ALIGN) + } + + fn padding_len() -> usize { + let payload_len = size_of::(); + Self::total_len_with_padding() - payload_len + } +} diff --git a/kernel/src/net/socket/netlink/mod.rs b/kernel/src/net/socket/netlink/mod.rs new file mode 100644 index 000000000..59382aec3 --- /dev/null +++ b/kernel/src/net/socket/netlink/mod.rs @@ -0,0 +1,52 @@ +use crate::net::socket::{ + netlink::{ + route::NetlinkRouteSocket, + table::{is_valid_protocol, StandardNetlinkProtocol}, + }, + Socket, PSOCK, +}; +use alloc::sync::Arc; +use system_error::SystemError; + +pub mod addr; +mod common; +mod message; +mod receiver; +mod route; +pub mod table; + +pub fn create_netlink_socket( + socket_type: PSOCK, + protocol: u32, + _is_nonblock: bool, +) -> Result, SystemError> { + match socket_type { + super::PSOCK::Raw | super::PSOCK::Datagram => { + let 
nl_protocol = StandardNetlinkProtocol::try_from(protocol); + let inode = match nl_protocol { + Ok(StandardNetlinkProtocol::ROUTE) => NetlinkRouteSocket::new(false), + Ok(_) => { + log::warn!( + "standard netlink families {} is not supported yet", + protocol + ); + return Err(SystemError::EAFNOSUPPORT); + } + Err(_) => { + if is_valid_protocol(protocol) { + log::error!("user-provided netlink family is not supported"); + return Err(SystemError::EPROTONOSUPPORT); + } + log::error!("invalid netlink protocol: {}", protocol); + return Err(SystemError::EAFNOSUPPORT); + } + }; + + Ok(inode) + } + _ => { + log::warn!("unsupported socket type for Netlink"); + Err(SystemError::EPROTONOSUPPORT) + } + } +} diff --git a/kernel/src/net/socket/netlink/receiver.rs b/kernel/src/net/socket/netlink/receiver.rs new file mode 100644 index 000000000..a38db9f5b --- /dev/null +++ b/kernel/src/net/socket/netlink/receiver.rs @@ -0,0 +1,51 @@ +use crate::libs::spinlock::SpinLock; +use crate::libs::wait_queue::WaitQueue; +use crate::process::ProcessState; +use alloc::collections::VecDeque; +use alloc::sync::Arc; +use system_error::SystemError; + +/// Netlink Socket 的消息队列 +#[derive(Debug)] +pub struct MessageQueue(pub Arc>>); + +impl Clone for MessageQueue { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl MessageQueue { + pub fn new() -> Self { + Self(Arc::new(SpinLock::new(VecDeque::new()))) + } + + fn enqueue(&self, message: Message) -> Result<(), SystemError> { + // FIXME: 确保消息队列不会超过最大长度 + self.0.lock().push_back(message); + Ok(()) + } +} + +/// Netlink Socket 的消息接收器,记录在当前网络命名空间的 Netlink Socket 表中,负责将消息压入对应的消息队列,并唤醒等待的线程 +#[derive(Debug)] +pub struct MessageReceiver { + message_queue: MessageQueue, + wait_queue: Arc, +} + +impl MessageReceiver { + pub fn new(message_queue: MessageQueue, wait_queue: Arc) -> Self { + Self { + message_queue, + wait_queue, + } + } + + pub fn enqueue_message(&self, message: Message) -> Result<(), SystemError> { + 
self.message_queue.enqueue(message)?; + // 唤醒等待队列中的线程 + self.wait_queue.wakeup(Some(ProcessState::Blocked(true))); + Ok(()) + } +} diff --git a/kernel/src/net/socket/netlink/route/bound.rs b/kernel/src/net/socket/netlink/route/bound.rs new file mode 100644 index 000000000..953a4a0af --- /dev/null +++ b/kernel/src/net/socket/netlink/route/bound.rs @@ -0,0 +1,121 @@ +use crate::{ + filesystem::epoll::EPollEventType, + net::socket::{ + netlink::{ + addr::NetlinkSocketAddr, + common::bound::BoundNetlink, + message::ProtocolSegment, + route::{kernel::NetlinkRouteKernelSocket, message::RouteNlMessage}, + }, + utils::datagram_common, + PMSG, + }, +}; +use system_error::SystemError; + +impl datagram_common::Bound for BoundNetlink { + type Endpoint = NetlinkSocketAddr; + + fn bind(&mut self, endpoint: &Self::Endpoint) -> Result<(), SystemError> { + self.bind_common(endpoint) + } + + fn local_endpoint(&self) -> Self::Endpoint { + self.handle.addr() + } + + fn remote_endpoint(&self) -> Option { + Some(self.remote_addr) + } + + fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint) { + self.remote_addr = *endpoint; + } + + fn try_send( + &self, + buf: &[u8], + to: &Self::Endpoint, + _flags: crate::net::socket::PMSG, + ) -> Result { + if *to != NetlinkSocketAddr::new_unspecified() { + return Err(SystemError::ENOTCONN); + } + + let sum_lens = buf.len(); + + let mut nlmsg = match RouteNlMessage::read_from(buf) { + Ok(msg) => msg, + Err(e) if e == SystemError::EFAULT => { + // 说明这个时候 buf 不是一个完整的 netlink 消息 + return Err(e); + } + Err(e) => { + // 传播错误,静默处理 + log::warn!( + "netlink_send: failed to read netlink message from buffer: {:?}", + e + ); + return Ok(sum_lens); + } + }; + + let local_port = self.handle.port(); + + for segment in nlmsg.segments_mut() { + let header = segment.header_mut(); + if header.pid == 0 { + header.pid = local_port; + } + } + + let Some(route_kernel) = self + .netns + .get_netlink_kernel_socket_by_protocol(nlmsg.protocol().into()) + else { + 
log::warn!("No route kernel socket available in net namespace"); + return Ok(sum_lens); + }; + + let route_kernel_socket = route_kernel + .as_any_ref() + .downcast_ref::() + .unwrap(); + + route_kernel_socket.request(&nlmsg, local_port, self.netns()); + + Ok(sum_lens) + } + + fn try_recv( + &self, + writer: &mut [u8], + flags: crate::net::socket::PMSG, + ) -> Result<(usize, Self::Endpoint), SystemError> { + let mut receive_queue = self.receive_queue.0.lock(); + + let Some(res) = receive_queue.front() else { + return Err(SystemError::EAGAIN_OR_EWOULDBLOCK); + }; + + let len = { + let max = writer.len(); + res.total_len().min(max) + }; + + let _copied = res.write_to(writer)?; + + if !flags.contains(PMSG::PEEK) { + receive_queue.pop_front(); + } + + // todo 目前这个信息只能来自内核 + let remote = NetlinkSocketAddr::new_unspecified(); + + Ok((len, remote)) + } + + fn check_io_events(&self) -> EPollEventType { + self.check_io_events_common() + } +} diff --git a/kernel/src/net/socket/netlink/route/kernel/addr.rs b/kernel/src/net/socket/netlink/route/kernel/addr.rs new file mode 100644 index 000000000..c35a75218 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/kernel/addr.rs @@ -0,0 +1,82 @@ +use crate::{ + driver::net::Iface, + net::socket::{ + netlink::{ + message::segment::{ + header::{CMsgSegHdr, GetRequestFlags, SegHdrCommonFlags}, + CSegmentType, + }, + route::{ + kernel::utils::finish_response, + message::{ + attr::addr::AddrAttr, + segment::{ + addr::{AddrMessageFlags, AddrSegment, AddrSegmentBody, RtScope}, + RouteNlSegment, + }, + }, + }, + }, + AddressFamily, + }, + process::namespace::net_namespace::NetNamespace, +}; +use alloc::ffi::CString; +use alloc::sync::Arc; +use alloc::vec::Vec; +use core::num::NonZeroU32; +use system_error::SystemError; + +pub(super) fn do_get_addr( + request_segment: &AddrSegment, + netns: Arc, +) -> Result, SystemError> { + let dump_all = { + let flags = GetRequestFlags::from_bits_truncate(request_segment.header().flags); + 
flags.contains(GetRequestFlags::DUMP) + }; + + if !dump_all { + log::error!("GetAddr request without DUMP flag is not supported yet"); + return Err(SystemError::EOPNOTSUPP_OR_ENOTSUP); + } + + let mut responce: Vec = netns + .device_list() + .iter() + .filter_map(|(_, iface)| iface_to_new_addr(request_segment.header(), iface)) + .map(RouteNlSegment::NewAddr) + .collect(); + + finish_response(request_segment.header(), dump_all, &mut responce); + + Ok(responce) +} + +fn iface_to_new_addr(request_header: &CMsgSegHdr, iface: &Arc) -> Option { + let ipv4_addr = iface.common().ipv4_addr()?; + + let header = CMsgSegHdr { + len: 0, + type_: CSegmentType::NEWADDR as _, + flags: SegHdrCommonFlags::empty().bits(), + seq: request_header.seq, + pid: request_header.pid, + }; + + let addr_message = AddrSegmentBody { + family: AddressFamily::INet as _, + prefix_len: iface.common().prefix_len().unwrap(), + flags: AddrMessageFlags::PERMANENT, + scope: RtScope::HOST, + index: NonZeroU32::new(iface.nic_id() as u32), + }; + + let attrs = vec![ + AddrAttr::Address(ipv4_addr.octets()), + AddrAttr::Label(CString::new(iface.iface_name()).unwrap()), + AddrAttr::Local(ipv4_addr.octets()), + ]; + + Some(AddrSegment::new(header, addr_message, attrs)) +} diff --git a/kernel/src/net/socket/netlink/route/kernel/link.rs b/kernel/src/net/socket/netlink/route/kernel/link.rs new file mode 100644 index 000000000..67f7132c3 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/kernel/link.rs @@ -0,0 +1,157 @@ +use crate::{ + driver::net::{types::InterfaceType, Iface}, + net::socket::{ + netlink::{ + message::segment::{ + header::{CMsgSegHdr, GetRequestFlags, SegHdrCommonFlags}, + CSegmentType, + }, + route::{ + kernel::utils::finish_response, + message::{ + attr::link::LinkAttr, + segment::{ + link::{LinkMessageFlags, LinkSegment, LinkSegmentBody}, + RouteNlSegment, + }, + }, + }, + }, + AddressFamily, + }, + process::namespace::net_namespace::NetNamespace, +}; +use alloc::ffi::CString; +use 
alloc::sync::Arc; +use alloc::vec::Vec; +use core::num::NonZero; +use system_error::SystemError; + +pub(super) fn do_get_link( + request_segment: &LinkSegment, + netns: Arc, +) -> Result, SystemError> { + let filter_by = FilterBy::from_requset(request_segment)?; + + let mut responce: Vec = netns + .device_list() + .iter() + .filter(|(_, iface)| match &filter_by { + FilterBy::Index(index) => *index == iface.nic_id() as u32, + FilterBy::Name(name) => *name == iface.name(), + FilterBy::Dump => true, + }) + .map(|(_, iface)| iface_to_new_link(request_segment.header(), iface)) + .map(RouteNlSegment::NewLink) + .collect(); + + let dump_all = matches!(filter_by, FilterBy::Dump); + + if !dump_all && responce.is_empty() { + log::error!("no such device"); + return Err(SystemError::ENODEV); + } + + finish_response(request_segment.header(), dump_all, &mut responce); + + Ok(responce) +} + +enum FilterBy<'a> { + Index(u32), + Name(&'a str), + Dump, +} + +impl<'a> FilterBy<'a> { + fn from_requset(request_segment: &'a LinkSegment) -> Result { + let dump_all = { + let flags = GetRequestFlags::from_bits_truncate(request_segment.header().flags); + flags.contains(GetRequestFlags::DUMP) + }; + if dump_all { + validate_dumplink_request(request_segment.body())?; + return Ok(Self::Dump); + } + + validate_getlink_request(request_segment.body())?; + + if let Some(required_index) = request_segment.body().index { + return Ok(Self::Index(required_index.get())); + } + + let required_name = request_segment.attrs().iter().find_map(|attr| { + if let LinkAttr::Name(name) = attr { + Some(name.to_str().ok()?) 
+ } else { + None + } + }); + + if let Some(name) = required_name { + return Ok(Self::Name(name)); + } + + log::error!("either interface name or index should be specified for non-dump mode"); + Err(SystemError::EINVAL) + } +} + +fn validate_getlink_request(body: &LinkSegmentBody) -> Result<(), SystemError> { + if !body.flags.is_empty() + || body.type_ != InterfaceType::NETROM + || body.pad.is_some() + || !body.change.is_empty() + { + log::error!("the flags or the type is not valid"); + return Err(SystemError::EINVAL); + } + + Ok(()) +} + +fn validate_dumplink_request(body: &LinkSegmentBody) -> Result<(), SystemError> { + // . + if !body.flags.is_empty() + || body.type_ != InterfaceType::NETROM + || body.pad.is_some() + || !body.change.is_empty() + { + log::error!("the flags or the type is not valid"); + return Err(SystemError::EINVAL); + } + + // . + if body.index.is_some() { + log::error!("filtering by interface index is not valid for link dumps"); + return Err(SystemError::EINVAL); + } + + Ok(()) +} + +fn iface_to_new_link(request_header: &CMsgSegHdr, iface: &Arc) -> LinkSegment { + let header = CMsgSegHdr { + len: 0, + type_: CSegmentType::NEWLINK as _, + flags: SegHdrCommonFlags::empty().bits(), + seq: request_header.seq, + pid: request_header.pid, + }; + + let link_message = LinkSegmentBody { + family: AddressFamily::Unspecified, + type_: iface.type_(), + index: NonZero::new(iface.nic_id() as u32), + flags: iface.flags(), + change: LinkMessageFlags::empty(), + pad: None, + }; + + let attrs = vec![ + LinkAttr::Name(CString::new(iface.name()).unwrap()), + LinkAttr::Mtu(iface.mtu() as u32), + ]; + + LinkSegment::new(header, link_message, attrs) +} diff --git a/kernel/src/net/socket/netlink/route/kernel/mod.rs b/kernel/src/net/socket/netlink/route/kernel/mod.rs new file mode 100644 index 000000000..c06dc048a --- /dev/null +++ b/kernel/src/net/socket/netlink/route/kernel/mod.rs @@ -0,0 +1,81 @@ +//! # Netlink route kernel module +//! 
内核对于 Netlink 路由的处理模块 + +use crate::{ + net::socket::netlink::{ + message::{ + segment::{ack::ErrorSegment, CSegmentType}, + ProtocolSegment, + }, + route::message::{segment::RouteNlSegment, RouteNlMessage}, + table::{ + NetlinkKernelSocket, NetlinkRouteProtocol, StandardNetlinkProtocol, + SupportedNetlinkProtocol, + }, + }, + process::namespace::net_namespace::NetNamespace, +}; +use alloc::sync::Arc; +use core::marker::PhantomData; + +mod addr; +mod link; +mod utils; + +/// 负责处理 Netlink 路由相关的内核模块 +/// 每个 net namespace 都有一个独立的 NetlinkRouteKernelSocket +#[derive(Debug)] +pub struct NetlinkRouteKernelSocket { + _private: PhantomData<()>, +} + +impl NetlinkRouteKernelSocket { + pub const fn new() -> Self { + NetlinkRouteKernelSocket { + _private: PhantomData, + } + } + + pub(super) fn request( + &self, + request: &RouteNlMessage, + dst_port: u32, + netns: Arc, + ) { + for segment in request.segments() { + let header = segment.header(); + + let seg_type = CSegmentType::try_from(header.type_).unwrap(); + let responce = match segment { + RouteNlSegment::GetAddr(request) => addr::do_get_addr(request, netns.clone()), + RouteNlSegment::GetLink(request) => link::do_get_link(request, netns.clone()), + RouteNlSegment::GetRoute(_new_route) => todo!(), + _ => { + log::warn!("Unsupported route request segment type: {:?}", seg_type); + todo!() + } + }; + + let responce = match responce { + Ok(segments) => RouteNlMessage::new(segments), + Err(error) => { + //todo 处理 `NetlinkMessageCommonFlags::ACK` + let err_segment = ErrorSegment::new_from_request(header, Some(error)); + RouteNlMessage::new(vec![RouteNlSegment::Error(err_segment)]) + } + }; + + NetlinkRouteProtocol::unicast(dst_port, responce, netns.clone()).unwrap(); + } + } +} + +impl NetlinkKernelSocket for NetlinkRouteKernelSocket { + fn protocol(&self) -> StandardNetlinkProtocol { + StandardNetlinkProtocol::ROUTE + } + + fn as_any_ref(&self) -> &dyn core::any::Any { + self + } +} diff --git 
a/kernel/src/net/socket/netlink/route/kernel/utils.rs b/kernel/src/net/socket/netlink/route/kernel/utils.rs new file mode 100644 index 000000000..837f446a1 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/kernel/utils.rs @@ -0,0 +1,39 @@ +use crate::net::socket::netlink::{ + message::{ + segment::{ + ack::DoneSegment, + header::{CMsgSegHdr, SegHdrCommonFlags}, + }, + ProtocolSegment, + }, + route::message::segment::RouteNlSegment, +}; +use alloc::vec::Vec; + +pub fn finish_response( + request_header: &CMsgSegHdr, + dump_all: bool, + response_segments: &mut Vec, +) { + if !dump_all { + assert_eq!(response_segments.len(), 1); + return; + } + + append_done_segment(request_header, response_segments); + add_multi_flag(response_segments); +} + +fn append_done_segment(header: &CMsgSegHdr, response_segments: &mut Vec) { + let done_segment = DoneSegment::new_from_request(header, None); + response_segments.push(RouteNlSegment::Done(done_segment)); +} + +fn add_multi_flag(responce_segment: &mut [RouteNlSegment]) { + for segment in responce_segment.iter_mut() { + let header = segment.header_mut(); + let mut flags = SegHdrCommonFlags::from_bits_truncate(header.flags); + flags |= SegHdrCommonFlags::MULTI; + header.flags = flags.bits(); + } +} diff --git a/kernel/src/net/socket/netlink/route/message/attr/addr.rs b/kernel/src/net/socket/netlink/route/message/attr/addr.rs new file mode 100644 index 000000000..18d3c9c44 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/attr/addr.rs @@ -0,0 +1,125 @@ +use crate::net::socket::netlink::{message::attr::Attribute, route::message::attr::IFNAME_SIZE}; +use alloc::ffi::CString; +use system_error::SystemError; + +#[derive(Debug, Clone, Copy)] +#[repr(u16)] +#[expect(non_camel_case_types)] +#[expect(clippy::upper_case_acronyms)] +enum AddrAttrClass { + UNSPEC = 0, + ADDRESS = 1, + LOCAL = 2, + LABEL = 3, + BROADCAST = 4, + ANYCAST = 5, + CACHEINFO = 6, + MULTICAST = 7, + FLAGS = 8, + RT_PRIORITY = 9, + TARGET_NETNSID = 
10, +} + +impl TryFrom for AddrAttrClass { + type Error = SystemError; + + fn try_from(value: u16) -> Result { + match value { + 0 => Ok(AddrAttrClass::UNSPEC), + 1 => Ok(AddrAttrClass::ADDRESS), + 2 => Ok(AddrAttrClass::LOCAL), + 3 => Ok(AddrAttrClass::LABEL), + 4 => Ok(AddrAttrClass::BROADCAST), + 5 => Ok(AddrAttrClass::ANYCAST), + 6 => Ok(AddrAttrClass::CACHEINFO), + 7 => Ok(AddrAttrClass::MULTICAST), + 8 => Ok(AddrAttrClass::FLAGS), + 9 => Ok(AddrAttrClass::RT_PRIORITY), + 10 => Ok(AddrAttrClass::TARGET_NETNSID), + _ => Err(SystemError::EINVAL), + } + } +} + +#[derive(Debug)] +pub enum AddrAttr { + Address([u8; 4]), + Local([u8; 4]), + Label(CString), +} + +impl AddrAttr { + fn class(&self) -> AddrAttrClass { + match self { + AddrAttr::Address(_) => AddrAttrClass::ADDRESS, + AddrAttr::Local(_) => AddrAttrClass::LOCAL, + AddrAttr::Label(_) => AddrAttrClass::LABEL, + } + } +} + +impl Attribute for AddrAttr { + fn type_(&self) -> u16 { + self.class() as u16 + } + + fn payload_as_bytes(&self) -> &[u8] { + match self { + AddrAttr::Address(addr) => addr.as_ref(), + AddrAttr::Local(addr) => addr.as_ref(), + AddrAttr::Label(label) => label.to_bytes_with_nul(), + } + } + + fn read_from_buf( + header: &crate::net::socket::netlink::message::attr::CAttrHeader, + payload_buf: &[u8], + ) -> Result, SystemError> + where + Self: Sized, + { + let payload_len = header.payload_len(); + + // TODO: Currently, `IS_NET_BYTEORDER_MASK` and `IS_NESTED_MASK` are ignored. 
+ let Ok(addr_class) = AddrAttrClass::try_from(header.type_()) else { + //reader.skip_some(payload_len); + return Ok(None); + }; + + // 拷贝payload_buf到本地变量,避免生命周期问题 + let buf = &payload_buf[..payload_len.min(payload_buf.len())]; + + let res = match (addr_class, buf.len()) { + (AddrAttrClass::ADDRESS, 4) => { + let mut arr = [0u8; 4]; + arr.copy_from_slice(&buf[0..4]); + AddrAttr::Address(arr) + } + (AddrAttrClass::LOCAL, 4) => { + let mut arr = [0u8; 4]; + arr.copy_from_slice(&buf[0..4]); + AddrAttr::Local(arr) + } + (AddrAttrClass::LABEL, 1..=IFNAME_SIZE) => { + // 查找第一个0字节作为结尾,否则用全部 + let nul_pos = buf.iter().position(|&b| b == 0).unwrap_or(buf.len()); + let cstr = CString::new(&buf[..nul_pos]).map_err(|_| SystemError::EINVAL)?; + AddrAttr::Label(cstr) + } + (AddrAttrClass::ADDRESS | AddrAttrClass::LOCAL | AddrAttrClass::LABEL, _) => { + log::warn!( + "address attribute `{:?}` contains invalid payload", + addr_class + ); + return Err(SystemError::EINVAL); + } + (_, _) => { + log::warn!("address attribute `{:?}` is not supported", addr_class); + // reader.skip_some(payload_len); + return Ok(None); + } + }; + + Ok(Some(res)) + } +} diff --git a/kernel/src/net/socket/netlink/route/message/attr/link.rs b/kernel/src/net/socket/netlink/route/message/attr/link.rs new file mode 100644 index 000000000..49224a4ce --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/attr/link.rs @@ -0,0 +1,219 @@ +use crate::net::socket::netlink::message::attr::Attribute; +use crate::net::socket::netlink::message::attr::CAttrHeader; +use crate::net::socket::netlink::route::message::attr::convert_one_from_raw_buf; +use crate::net::socket::netlink::route::message::attr::IFNAME_SIZE; +use alloc::ffi::CString; +use system_error::SystemError; + +#[derive(Debug, Clone, Copy, FromPrimitive, ToPrimitive)] +#[repr(u16)] +#[expect(non_camel_case_types)] +#[expect(clippy::upper_case_acronyms)] +enum LinkAttrClass { + UNSPEC = 0, + ADDRESS = 1, + BROADCAST = 2, + IFNAME = 3, + MTU = 4, + LINK 
= 5, + QDISC = 6, + STATS = 7, + COST = 8, + PRIORITY = 9, + MASTER = 10, + /// Wireless Extension event + WIRELESS = 11, + /// Protocol specific information for a link + PROTINFO = 12, + TXQLEN = 13, + MAP = 14, + WEIGHT = 15, + OPERSTATE = 16, + LINKMODE = 17, + LINKINFO = 18, + NET_NS_PID = 19, + IFALIAS = 20, + /// Number of VFs if device is SR-IOV PF + NUM_VF = 21, + VFINFO_LIST = 22, + STATS64 = 23, + VF_PORTS = 24, + PORT_SELF = 25, + AF_SPEC = 26, + /// Group the device belongs to + GROUP = 27, + NET_NS_FD = 28, + /// Extended info mask, VFs, etc. + EXT_MASK = 29, + /// Promiscuity count: > 0 means acts PROMISC + PROMISCUITY = 30, + NUM_TX_QUEUES = 31, + NUM_RX_QUEUES = 32, + CARRIER = 33, + PHYS_PORT_ID = 34, + CARRIER_CHANGES = 35, + PHYS_SWITCH_ID = 36, + LINK_NETNSID = 37, + PHYS_PORT_NAME = 38, + PROTO_DOWN = 39, + GSO_MAX_SEGS = 40, + GSO_MAX_SIZE = 41, + PAD = 42, + XDP = 43, + EVENT = 44, + NEW_NETNSID = 45, + IF_NETNSID = 46, + CARRIER_UP_COUNT = 47, + CARRIER_DOWN_COUNT = 48, + NEW_IFINDEX = 49, + MIN_MTU = 50, + MAX_MTU = 51, + PROP_LIST = 52, + /// Alternative ifname + ALT_IFNAME = 53, + PERM_ADDRESS = 54, + PROTO_DOWN_REASON = 55, + PARENT_DEV_NAME = 56, + PARENT_DEV_BUS_NAME = 57, +} + +impl TryFrom for LinkAttrClass { + type Error = SystemError; + + fn try_from(value: u16) -> Result { + use num_traits::FromPrimitive; + return ::from_u16(value).ok_or(Self::Error::EINVAL); + } +} + +#[derive(Debug)] +pub enum LinkAttr { + Name(CString), + Mtu(u32), + TxqLen(u32), + LinkMode(u8), + ExtMask(RtExtFilter), +} + +impl LinkAttr { + fn class(&self) -> LinkAttrClass { + match self { + LinkAttr::Name(_) => LinkAttrClass::IFNAME, + LinkAttr::Mtu(_) => LinkAttrClass::MTU, + LinkAttr::TxqLen(_) => LinkAttrClass::TXQLEN, + LinkAttr::LinkMode(_) => LinkAttrClass::LINKMODE, + LinkAttr::ExtMask(_) => LinkAttrClass::EXT_MASK, + } + } +} + +// #[derive(Debug)] +// pub enum LinkInfoAttr{ +// Kind(CString), +// Data(Vec), +// } + +// #[derive(Debug)] +// pub enum 
LinkInfoDataAttr{ +// VlanId(u16), + +// } + +impl Attribute for LinkAttr { + fn type_(&self) -> u16 { + self.class() as u16 + } + + fn payload_as_bytes(&self) -> &[u8] { + match self { + LinkAttr::Name(name) => name.as_bytes_with_nul(), + LinkAttr::Mtu(mtu) => unsafe { + core::slice::from_raw_parts(mtu as *const u32 as *const u8, 4) + }, + LinkAttr::TxqLen(txq_len) => unsafe { + core::slice::from_raw_parts(txq_len as *const u32 as *const u8, 4) + }, + LinkAttr::LinkMode(link_mode) => unsafe { + core::slice::from_raw_parts(link_mode as *const u8, 1) + }, + LinkAttr::ExtMask(ext_filter) => { + let bits = ext_filter.bits(); + unsafe { core::slice::from_raw_parts(&bits as *const u32 as *const u8, 4) } + } + } + } + + fn read_from_buf(header: &CAttrHeader, buf: &[u8]) -> Result, SystemError> + where + Self: Sized, + { + let payload_len = header.payload_len(); + + // TODO: Currently, `IS_NET_BYTEORDER_MASK` and `IS_NESTED_MASK` are ignored. + let Ok(class) = LinkAttrClass::try_from(header.type_()) else { + // reader.skip_some(payload_len); + return Ok(None); + }; + + let res = match (class, payload_len) { + (LinkAttrClass::IFNAME, 1..=IFNAME_SIZE) => { + let nul_pos = buf.iter().position(|&b| b == 0).unwrap_or(buf.len()); + let cstr = CString::new(&buf[..nul_pos]).map_err(|_| SystemError::EINVAL)?; + Self::Name(cstr) + } + (LinkAttrClass::MTU, 4) => { + let data = convert_one_from_raw_buf::(buf)?; + Self::Mtu(*data) + } + (LinkAttrClass::TXQLEN, 4) => { + let data = convert_one_from_raw_buf::(buf)?; + Self::TxqLen(*data) + } + (LinkAttrClass::LINKMODE, 1) => { + let data = convert_one_from_raw_buf::(buf)?; + Self::LinkMode(*data) + } + (LinkAttrClass::EXT_MASK, 4) => { + const { assert!(size_of::() == 4) }; + Self::ExtMask(*convert_one_from_raw_buf::(buf)?) 
+ } + + ( + LinkAttrClass::IFNAME + | LinkAttrClass::MTU + | LinkAttrClass::TXQLEN + | LinkAttrClass::LINKMODE + | LinkAttrClass::EXT_MASK, + _, + ) => { + log::warn!("link attribute `{:?}` contains invalid payload", class); + return Err(SystemError::EINVAL); + } + + (_, _) => { + log::warn!("link attribute `{:?}` is not supported", class); + // reader.skip_some(payload_len); + return Ok(None); + } + }; + + Ok(Some(res)) + } +} + +bitflags! { + /// New extended info filters for [`NlLinkAttr::ExtMask`]. + /// + /// Reference: . + #[repr(C)] + pub struct RtExtFilter: u32 { + const VF = 1 << 0; + const BRVLAN = 1 << 1; + const BRVLAN_COMPRESSED = 1 << 2; + const SKIP_STATS = 1 << 3; + const MRP = 1 << 4; + const CFM_CONFIG = 1 << 5; + const CFM_STATUS = 1 << 6; + const MST = 1 << 7; + } +} diff --git a/kernel/src/net/socket/netlink/route/message/attr/mod.rs b/kernel/src/net/socket/netlink/route/message/attr/mod.rs new file mode 100644 index 000000000..9c082971c --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/attr/mod.rs @@ -0,0 +1,21 @@ +use core::slice::from_raw_parts; +use system_error::SystemError; + +pub mod addr; +pub mod link; +pub mod route; + +/// 网卡名字长度 +const IFNAME_SIZE: usize = 16; + +pub(super) fn convert_one_from_raw_buf(src: &[u8]) -> Result<&T, SystemError> { + log::info!("convert_one_from_raw_buf: src.len() = {}", src.len()); + if core::mem::size_of::() > src.len() { + return Err(SystemError::EINVAL); + } + let byte_buffer: &[u8] = &src[..core::mem::size_of::()]; + + let chunks = unsafe { from_raw_parts(byte_buffer.as_ptr() as *const T, 1) }; + let data = &chunks[0]; + return Ok(data); +} diff --git a/kernel/src/net/socket/netlink/route/message/attr/route.rs b/kernel/src/net/socket/netlink/route/message/attr/route.rs new file mode 100644 index 000000000..b393be0a5 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/attr/route.rs @@ -0,0 +1,95 @@ +use crate::net::socket::netlink::message::attr::Attribute; +use 
system_error::SystemError; + +/// 路由相关属性 +#[derive(Debug, Clone, Copy)] +#[repr(u16)] +enum RouteAttrClass { + UNSPEC = 0, + DST = 1, // 目标地址 + SRC = 2, // 源地址 + IIF = 3, // 输入接口 + OIF = 4, // 输出接口 + GATEWAY = 5, // 网关地址 + PRIORITY = 6, // 路由优先级 + PREFSRC = 7, // 首选源地址 + METRICS = 8, // 路由度量 + MULTIPATH = 9, // 多路径信息 + TABLE = 15, // 路由表ID +} + +impl TryFrom for RouteAttrClass { + type Error = SystemError; + + fn try_from(value: u16) -> Result { + match value { + 0 => Ok(RouteAttrClass::UNSPEC), + 1 => Ok(RouteAttrClass::DST), + 2 => Ok(RouteAttrClass::SRC), + 3 => Ok(RouteAttrClass::IIF), + 4 => Ok(RouteAttrClass::OIF), + 5 => Ok(RouteAttrClass::GATEWAY), + 6 => Ok(RouteAttrClass::PRIORITY), + 7 => Ok(RouteAttrClass::PREFSRC), + 8 => Ok(RouteAttrClass::METRICS), + 9 => Ok(RouteAttrClass::MULTIPATH), + 15 => Ok(RouteAttrClass::TABLE), + _ => Err(SystemError::EINVAL), + } + } +} + +#[derive(Debug)] +pub enum RouteAttr { + Dst([u8; 4]), // 目标地址 (IPv4) + Src([u8; 4]), // 源地址 (IPv4) + Gateway([u8; 4]), // 网关地址 (IPv4) + Oif(u32), // 输出接口索引 + Iif(u32), // 输入接口索引 + Priority(u32), // 路由优先级 + Prefsrc([u8; 4]), // 首选源地址 (IPv4) + Table(u32), // 路由表ID +} + +impl RouteAttr { + fn class(&self) -> RouteAttrClass { + match self { + RouteAttr::Dst(_) => RouteAttrClass::DST, + RouteAttr::Src(_) => RouteAttrClass::SRC, + RouteAttr::Gateway(_) => RouteAttrClass::GATEWAY, + RouteAttr::Oif(_) => RouteAttrClass::OIF, + RouteAttr::Iif(_) => RouteAttrClass::IIF, + RouteAttr::Priority(_) => RouteAttrClass::PRIORITY, + RouteAttr::Prefsrc(_) => RouteAttrClass::PREFSRC, + RouteAttr::Table(_) => RouteAttrClass::TABLE, + } + } +} + +impl Attribute for RouteAttr { + fn type_(&self) -> u16 { + self.class() as u16 + } + fn payload_as_bytes(&self) -> &[u8] { + // match self { + // RouteAttr::Dst(addr) + // | RouteAttr::Src(addr) + // | RouteAttr::Gateway(addr) + // | RouteAttr::Prefsrc(addr) => addr, + // RouteAttr::Oif(idx) | RouteAttr::Iif(idx) => idx.as_bytes(), + // RouteAttr::Priority(pri) => 
pri.as_bytes(), + // RouteAttr::Table(table) => table.as_bytes(), + // } + todo!() + } + + fn read_from_buf( + _header: &crate::net::socket::netlink::message::attr::CAttrHeader, + _payload_buf: &[u8], + ) -> Result, SystemError> + where + Self: Sized, + { + todo!() + } +} diff --git a/kernel/src/net/socket/netlink/route/message/mod.rs b/kernel/src/net/socket/netlink/route/message/mod.rs new file mode 100644 index 000000000..dc801f071 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/mod.rs @@ -0,0 +1,6 @@ +pub(super) mod attr; +pub(super) mod segment; + +use crate::net::socket::netlink::{message::Message, route::message::segment::RouteNlSegment}; + +pub(in crate::net::socket::netlink) type RouteNlMessage = Message; diff --git a/kernel/src/net/socket/netlink/route/message/segment/addr.rs b/kernel/src/net/socket/netlink/route/message/segment/addr.rs new file mode 100644 index 000000000..b09adaae5 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/segment/addr.rs @@ -0,0 +1,120 @@ +use core::num::NonZeroU32; + +use system_error::SystemError; + +use crate::net::socket::netlink::{ + message::segment::{common::SegmentCommon, SegmentBody}, + route::message::attr::addr::AddrAttr, +}; + +pub type AddrSegment = SegmentCommon; + +impl SegmentBody for AddrSegmentBody { + type CType = CIfaddrMsg; +} + +/// `ifaddrmsg` in Linux. +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct CIfaddrMsg { + pub family: u8, + /// The prefix length + pub prefix_len: u8, + /// Flags + pub flags: u8, + /// Address scope + pub scope: u8, + /// Link index + pub index: u32, +} + +#[derive(Debug, Clone, Copy)] +pub struct AddrSegmentBody { + pub family: i32, + pub prefix_len: u8, + pub flags: AddrMessageFlags, + pub scope: RtScope, + pub index: Option, +} + +impl TryFrom for AddrSegmentBody { + type Error = SystemError; + + fn try_from(value: CIfaddrMsg) -> Result { + // TODO: If the attribute IFA_FLAGS exists, the flags in header should be ignored. 
+ let flags = AddrMessageFlags::from_bits_truncate(value.flags as u32); + let scope = RtScope::try_from(value.scope as i32)?; + let index = NonZeroU32::new(value.index); + + Ok(Self { + family: value.family as i32, + prefix_len: value.prefix_len, + flags, + scope, + index, + }) + } +} + +impl From for CIfaddrMsg { + fn from(value: AddrSegmentBody) -> Self { + let index = if let Some(index) = value.index { + index.get() + } else { + 0 + }; + CIfaddrMsg { + family: value.family as u8, + prefix_len: value.prefix_len, + flags: value.flags.bits() as u8, + scope: value.scope as _, + index, + } + } +} + +bitflags! { + /// Flags in [`CIfaddrMsg`]. + pub struct AddrMessageFlags: u32 { + const SECONDARY = 0x01; + const NODAD = 0x02; + const OPTIMISTIC = 0x04; + const DADFAILED = 0x08; + const HOMEADDRESS = 0x10; + const DEPRECATED = 0x20; + const TENTATIVE = 0x40; + const PERMANENT = 0x80; + const MANAGETEMPADDR = 0x100; + const NOPREFIXROUTE = 0x200; + const MCAUTOJOIN = 0x400; + const STABLE_PRIVACY = 0x800; + } +} + +/// `rt_scope_t` in Linux. 
+#[repr(u8)] +#[derive(Debug, Clone, Copy)] +#[expect(clippy::upper_case_acronyms)] +pub enum RtScope { + UNIVERSE = 0, + // User defined values + SITE = 200, + LINK = 253, + HOST = 254, + NOWHERE = 255, +} + +impl TryFrom for RtScope { + type Error = SystemError; + + fn try_from(value: i32) -> Result { + match value { + 0 => Ok(RtScope::UNIVERSE), + 200 => Ok(RtScope::SITE), + 253 => Ok(RtScope::LINK), + 254 => Ok(RtScope::HOST), + 255 => Ok(RtScope::NOWHERE), + _ => Err(SystemError::EINVAL), + } + } +} diff --git a/kernel/src/net/socket/netlink/route/message/segment/link.rs b/kernel/src/net/socket/netlink/route/message/segment/link.rs new file mode 100644 index 000000000..c386712d5 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/segment/link.rs @@ -0,0 +1,121 @@ +use crate::{ + driver::net::types::{InterfaceFlags, InterfaceType}, + net::socket::{ + netlink::{ + message::segment::{common::SegmentCommon, SegmentBody}, + route::message::attr::link::LinkAttr, + }, + AddressFamily, + }, +}; +use core::num::NonZeroU32; +use system_error::SystemError; + +pub type LinkSegment = SegmentCommon; + +impl SegmentBody for LinkSegmentBody { + type CType = CIfinfoMsg; +} + +/// `ifinfomsg` +/// . 
+#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct CIfinfoMsg { + /// AF_UNSPEC + pub family: u8, + /// Padding byte + pub pad: u8, + /// Device type + pub type_: u16, + /// Interface index + pub index: u32, + /// Device flags + pub flags: u32, + /// Change mask + pub change: u32, +} + +#[derive(Debug, Clone, Copy)] +pub struct LinkSegmentBody { + pub family: AddressFamily, + pub type_: InterfaceType, + pub index: Option, + pub flags: InterfaceFlags, + pub change: LinkMessageFlags, + pub pad: Option, // Must be 0 +} + +impl TryFrom for LinkSegmentBody { + type Error = SystemError; + + fn try_from(value: CIfinfoMsg) -> Result { + let family = AddressFamily::try_from(value.family as u16)?; + let type_ = InterfaceType::try_from(value.type_)?; + let index = NonZeroU32::new(value.index); + let flags = InterfaceFlags::from_bits_truncate(value.flags); + let change = LinkMessageFlags::from_bits_truncate(value.change); + let pad = if value.pad > 0 { Some(value.pad) } else { None }; + + Ok(Self { + family, + type_, + index, + flags, + change, + pad, + }) + } +} + +impl From for CIfinfoMsg { + fn from(value: LinkSegmentBody) -> Self { + CIfinfoMsg { + family: value.family as _, + pad: 0u8, + type_: value.type_ as _, + index: value.index.map(NonZeroU32::get).unwrap_or(0), + flags: value.flags.bits(), + change: value.change.bits(), + } + } +} + +bitflags! { + /// Flags in [`CIfinfoMsg`]. + /// . 
+ pub struct LinkMessageFlags: u32 { + // sysfs + const IFF_UP = 1<<0; + // volatile + const IFF_BROADCAST = 1<<1; + // sysfs + const IFF_DEBUG = 1<<2; + // volatile + const IFF_LOOPBACK = 1<<3; + // volatile + const IFF_POINTOPOINT = 1<<4; + // sysfs + const IFF_NOTRAILERS = 1<<5; + // volatile + const IFF_RUNNING = 1<<6; + // sysfs + const IFF_NOARP = 1<<7; + // sysfs + const IFF_PROMISC = 1<<8; + // sysfs + const IFF_ALLMULTI = 1<<9; + // volatile + const IFF_MASTER = 1<<10; + // volatile + const IFF_SLAVE = 1<<11; + // sysfs + const IFF_MULTICAST = 1<<12; + // sysfs + const IFF_PORTSEL = 1<<13; + // sysfs + const IFF_AUTOMEDIA = 1<<14; + // sysfs + const IFF_DYNAMIC = 1<<15; + } +} diff --git a/kernel/src/net/socket/netlink/route/message/segment/mod.rs b/kernel/src/net/socket/netlink/route/message/segment/mod.rs new file mode 100644 index 000000000..f72f558a6 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/segment/mod.rs @@ -0,0 +1,112 @@ +pub mod addr; +pub mod link; +pub mod route; + +use crate::net::socket::netlink::{ + message::{ + segment::{ + ack::{DoneSegment, ErrorSegment}, + header::CMsgSegHdr, + CSegmentType, + }, + ProtocolSegment, + }, + route::message::segment::{addr::AddrSegment, link::LinkSegment, route::RouteSegment}, +}; +use system_error::SystemError; + +#[derive(Debug)] +pub enum RouteNlSegment { + NewLink(LinkSegment), + GetLink(LinkSegment), + NewAddr(AddrSegment), + GetAddr(AddrSegment), + Done(DoneSegment), + Error(ErrorSegment), + NewRoute(RouteSegment), + DelRoute(RouteSegment), + GetRoute(RouteSegment), +} + +impl ProtocolSegment for RouteNlSegment { + fn header(&self) -> &crate::net::socket::netlink::message::segment::header::CMsgSegHdr { + match self { + RouteNlSegment::NewRoute(route_segment) + | RouteNlSegment::DelRoute(route_segment) + | RouteNlSegment::GetRoute(route_segment) => route_segment.header(), + RouteNlSegment::NewAddr(addr_segment) | RouteNlSegment::GetAddr(addr_segment) => { + addr_segment.header() + } 
+ RouteNlSegment::NewLink(link_segment) | RouteNlSegment::GetLink(link_segment) => { + link_segment.header() + } + RouteNlSegment::Done(done_segment) => done_segment.header(), + RouteNlSegment::Error(error_segment) => error_segment.header(), + } + } + + fn header_mut( + &mut self, + ) -> &mut crate::net::socket::netlink::message::segment::header::CMsgSegHdr { + match self { + RouteNlSegment::NewRoute(route_segment) + | RouteNlSegment::DelRoute(route_segment) + | RouteNlSegment::GetRoute(route_segment) => route_segment.header_mut(), + RouteNlSegment::NewAddr(addr_segment) | RouteNlSegment::GetAddr(addr_segment) => { + addr_segment.header_mut() + } + RouteNlSegment::NewLink(link_segment) | RouteNlSegment::GetLink(link_segment) => { + link_segment.header_mut() + } + RouteNlSegment::Done(done_segment) => done_segment.header_mut(), + RouteNlSegment::Error(error_segment) => error_segment.header_mut(), + } + } + + fn read_from(buf: &[u8]) -> Result { + let header_size = size_of::(); + if buf.len() < header_size { + log::warn!("the buffer is too small to read a netlink segment header"); + return Err(SystemError::EINVAL); + } + + let header = unsafe { *(buf.as_ptr() as *const CMsgSegHdr) }; + let payload_buf = &buf[header_size..]; + + let segment = match CSegmentType::try_from(header.type_)? { + CSegmentType::GETADDR => { + RouteNlSegment::GetAddr(AddrSegment::read_from_buf(header, payload_buf)?) + } + CSegmentType::GETROUTE => { + RouteNlSegment::GetRoute(RouteSegment::read_from_buf(header, payload_buf)?) + } + CSegmentType::GETLINK => { + RouteNlSegment::GetLink(LinkSegment::read_from_buf(header, payload_buf)?) 
+ } + _ => return Err(SystemError::EINVAL), + }; + + Ok(segment) + } + + fn write_to(&self, buf: &mut [u8]) -> Result { + // log::info!("RouteNlSegment write_to"); + let copied = match self { + RouteNlSegment::NewAddr(addr_segment) => addr_segment.write_to_buf(buf)?, + RouteNlSegment::NewRoute(route_segment) => route_segment.write_to_buf(buf)?, + RouteNlSegment::NewLink(link_segment) => link_segment.write_to_buf(buf)?, + RouteNlSegment::Done(done_segment) => done_segment.write_to_buf(buf)?, + RouteNlSegment::Error(error_segment) => error_segment.write_to_buf(buf)?, + _ => { + log::warn!("write_to is not implemented for this segment type"); + return Err(SystemError::ENOSYS); + } + }; + + Ok(copied) + } + + fn protocol(&self) -> crate::net::socket::netlink::table::StandardNetlinkProtocol { + crate::net::socket::netlink::table::StandardNetlinkProtocol::ROUTE + } +} diff --git a/kernel/src/net/socket/netlink/route/message/segment/route.rs b/kernel/src/net/socket/netlink/route/message/segment/route.rs new file mode 100644 index 000000000..995081329 --- /dev/null +++ b/kernel/src/net/socket/netlink/route/message/segment/route.rs @@ -0,0 +1,212 @@ +use system_error::SystemError; + +use crate::net::socket::{ + netlink::{ + message::segment::{common::SegmentCommon, SegmentBody}, + route::message::attr::route::RouteAttr, + }, + AddressFamily, +}; + +pub type RouteSegment = SegmentCommon; + +impl SegmentBody for RouteSegmentBody { + type CType = CRtMsg; +} + +/// `rtmsg` in Linux +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub struct CRtMsg { + /// 地址族 (AF_INET/AF_INET6) + pub family: u8, + /// 目标地址前缀长度 + pub dst_len: u8, + /// 源地址前缀长度 + pub src_len: u8, + /// 服务类型/DSCP + pub tos: u8, + /// 路由表ID + pub table: u8, + /// 路由协议 + pub protocol: u8, + /// 路由作用域 + pub scope: u8, + /// 路由类型 + pub type_: u8, + /// 路由标志 + pub flags: u32, +} + +#[derive(Debug, Clone, Copy)] +pub struct RouteSegmentBody { + pub family: AddressFamily, + pub dst_len: u8, + pub src_len: u8, + pub tos: u8, + 
pub table: RouteTable, + pub protocol: RouteProtocol, + pub scope: RouteScope, + pub type_: RouteType, + pub flags: RouteFlags, +} + +// 定义路由相关的枚举类型 +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum RouteTable { + Unspec = 0, + Compat = 252, + Default = 253, + Main = 254, + Local = 255, +} + +impl TryFrom for RouteTable { + type Error = SystemError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(RouteTable::Unspec), + 252 => Ok(RouteTable::Compat), + 253 => Ok(RouteTable::Default), + 254 => Ok(RouteTable::Main), + 255 => Ok(RouteTable::Local), + _ => Err(SystemError::EINVAL), + } + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum RouteProtocol { + Unspec = 0, + Redirect = 1, + Kernel = 2, + Boot = 3, + Static = 4, + // 添加更多协议... +} + +impl TryFrom for RouteProtocol { + type Error = SystemError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(RouteProtocol::Unspec), + 1 => Ok(RouteProtocol::Redirect), + 2 => Ok(RouteProtocol::Kernel), + 3 => Ok(RouteProtocol::Boot), + 4 => Ok(RouteProtocol::Static), + _ => Err(SystemError::EINVAL), + } + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum RouteScope { + Universe = 0, + Site = 200, + Link = 253, + Host = 254, + Nowhere = 255, +} + +impl TryFrom for RouteScope { + type Error = SystemError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(RouteScope::Universe), + 200 => Ok(RouteScope::Site), + 253 => Ok(RouteScope::Link), + 254 => Ok(RouteScope::Host), + 255 => Ok(RouteScope::Nowhere), + _ => Err(SystemError::EINVAL), + } + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(u8)] +pub enum RouteType { + Unspec = 0, + Unicast = 1, + Local = 2, + Broadcast = 3, + Anycast = 4, + Multicast = 5, + Blackhole = 6, + Unreachable = 7, + Prohibit = 8, +} + +impl TryFrom for RouteType { + type Error = SystemError; + + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(RouteType::Unspec), + 1 => Ok(RouteType::Unicast), + 2 => Ok(RouteType::Local), + 3 
=> Ok(RouteType::Broadcast), + 4 => Ok(RouteType::Anycast), + 5 => Ok(RouteType::Multicast), + 6 => Ok(RouteType::Blackhole), + 7 => Ok(RouteType::Unreachable), + 8 => Ok(RouteType::Prohibit), + _ => Err(SystemError::EINVAL), + } + } +} + +bitflags::bitflags! { + pub struct RouteFlags: u32 { + const NOTIFY = 0x100; + const CLONED = 0x200; + const EQUALIZE = 0x400; + const PREFIX = 0x800; + } +} + +impl TryFrom for RouteSegmentBody { + type Error = SystemError; + + fn try_from(value: CRtMsg) -> Result { + let family = AddressFamily::try_from(value.family as u16)?; + let table = RouteTable::try_from(value.table)?; + let protocol = RouteProtocol::try_from(value.protocol)?; + let scope = RouteScope::try_from(value.scope)?; + let type_ = RouteType::try_from(value.type_)?; + let flags = RouteFlags::from_bits_truncate(value.flags); + + Ok(Self { + family, + dst_len: value.dst_len, + src_len: value.src_len, + tos: value.tos, + table, + protocol, + scope, + type_, + flags, + }) + } +} + +impl From for CRtMsg { + fn from(body: RouteSegmentBody) -> Self { + CRtMsg { + family: body.family as u8, + dst_len: body.dst_len, + src_len: body.src_len, + tos: body.tos, + table: body.table as u8, + protocol: body.protocol as u8, + scope: body.scope as u8, + type_: body.type_ as u8, + flags: body.flags.bits(), + } + } +} diff --git a/kernel/src/net/socket/netlink/route/mod.rs b/kernel/src/net/socket/netlink/route/mod.rs new file mode 100644 index 000000000..0eafb4dff --- /dev/null +++ b/kernel/src/net/socket/netlink/route/mod.rs @@ -0,0 +1,7 @@ +use crate::net::socket::netlink::{common::NetlinkSocket, table::NetlinkRouteProtocol}; + +pub(super) mod bound; +pub(super) mod kernel; +pub(super) mod message; + +pub(super) type NetlinkRouteSocket = NetlinkSocket; diff --git a/kernel/src/net/socket/netlink/table/mod.rs b/kernel/src/net/socket/netlink/table/mod.rs new file mode 100644 index 000000000..adc4aab70 --- /dev/null +++ b/kernel/src/net/socket/netlink/table/mod.rs @@ -0,0 +1,349 @@ 
+mod multicast; + +use crate::net::socket::netlink::addr::multicast::GroupIdSet; +use crate::net::socket::netlink::route::kernel::NetlinkRouteKernelSocket; +use crate::net::socket::netlink::route::message::RouteNlMessage; +use crate::net::socket::netlink::table::multicast::MulticastMessage; +use crate::process::namespace::net_namespace::NetNamespace; +use crate::process::ProcessManager; +use crate::{libs::rand, net::socket::netlink::addr::NetlinkSocketAddr}; +use crate::{ + libs::rwlock::RwLock, + net::socket::netlink::{receiver::MessageReceiver, table::multicast::MulticastGroup}, +}; +use alloc::boxed::Box; +use alloc::collections::BTreeMap; +use alloc::fmt::Debug; +use alloc::sync::Arc; +use core::any::Any; +use hashbrown::HashMap; +use system_error::SystemError; + +pub const MAX_ALLOWED_PROTOCOL_ID: u32 = 32; +const MAX_GROUPS: u32 = 32; + +#[derive(Debug)] +pub struct NetlinkSocketTable { + route: Arc>>, + // 在这里继续补充其他协议下的 socket table + // 比如 uevent: Arc>>, +} + +impl Default for NetlinkSocketTable { + fn default() -> Self { + Self { + route: Arc::new(RwLock::new(ProtocolSocketTable::new())), + } + } +} + +impl NetlinkSocketTable { + pub fn route(&self) -> Arc>> { + self.route.clone() + } +} + +#[derive(Debug)] +pub struct ProtocolSocketTable { + unicast_sockets: BTreeMap>, + multicast_groups: Box<[MulticastGroup]>, +} + +impl ProtocolSocketTable { + fn new() -> Self { + let multicast_groups = (0u32..MAX_GROUPS).map(|_| MulticastGroup::new()).collect(); + Self { + unicast_sockets: BTreeMap::new(), + multicast_groups, + } + } + + fn bind( + &mut self, + socket_table: Arc>>, + addr: &NetlinkSocketAddr, + receiver: MessageReceiver, + ) -> Result, SystemError> { + let port = if addr.port() != 0 { + addr.port() + } else { + let mut random_port = ProcessManager::current_pid().data() as u32; + while random_port == 0 || self.unicast_sockets.contains_key(&random_port) { + random_port = rand::soft_rand() as u32; + } + random_port + }; + + if 
self.unicast_sockets.contains_key(&port) { + return Err(SystemError::EADDRINUSE); + } + + self.unicast_sockets.insert(port, receiver); + + for group_id in addr.groups().ids_iter() { + let group = &mut self.multicast_groups[group_id as usize]; + group.add_member(port); + } + + Ok(BoundHandle::new(socket_table, port, addr.groups())) + } + + fn unicast(&self, dst_port: u32, message: Message) -> Result<(), SystemError> { + let Some(receiver) = self.unicast_sockets.get(&dst_port) else { + return Ok(()); + }; + receiver.enqueue_message(message) + } + + fn multicast(&self, dst_groups: GroupIdSet, message: Message) -> Result<(), SystemError> + where + Message: MulticastMessage, + { + for group_id in dst_groups.ids_iter() { + let Some(group) = self.multicast_groups.get(group_id as usize) else { + continue; + }; + for member in group.members() { + if let Some(receiver) = self.unicast_sockets.get(member) { + receiver.enqueue_message(message.clone())?; + } + } + } + Ok(()) + } +} + +#[derive(Debug)] +pub struct BoundHandle { + socket_table: Arc>>, + port: u32, + groups: GroupIdSet, +} + +impl BoundHandle { + fn new( + socket_table: Arc>>, + port: u32, + groups: GroupIdSet, + ) -> Self { + Self { + socket_table, + port, + groups, + } + } + + pub(super) const fn port(&self) -> u32 { + self.port + } + + pub(super) fn addr(&self) -> NetlinkSocketAddr { + NetlinkSocketAddr::new(self.port, self.groups) + } + + pub(super) fn add_groups(&mut self, groups: GroupIdSet) { + let mut protocol_sockets = self.socket_table.write(); + + for group_id in groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.add_member(self.port); + } + + self.groups.add_groups(groups); + } + + pub(super) fn drop_groups(&mut self, groups: GroupIdSet) { + let mut protocol_sockets = self.socket_table.write(); + + for group_id in groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.remove_member(self.port); + } + + 
self.groups.drop_groups(groups); + } + + pub(super) fn bind_groups(&mut self, groups: GroupIdSet) { + let mut protocol_sockets = self.socket_table.write(); + + for group_id in self.groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.remove_member(self.port); + } + + for group_id in groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.add_member(self.port); + } + + self.groups = groups; + } +} + +impl Drop for BoundHandle { + fn drop(&mut self) { + let mut protocol_sockets = self.socket_table.write(); + + protocol_sockets.unicast_sockets.remove(&self.port); + + for group_id in self.groups.ids_iter() { + let group = &mut protocol_sockets.multicast_groups[group_id as usize]; + group.remove_member(self.port); + } + } +} + +pub trait SupportedNetlinkProtocol: Debug { + type Message: 'static + Send + Debug; + + fn socket_table(netns: Arc) -> Arc>>; + + fn bind( + addr: &NetlinkSocketAddr, + receiver: MessageReceiver, + netns: Arc, + ) -> Result, SystemError> { + let socket_table = Self::socket_table(netns); + let mut socket_table_guard = socket_table.write(); + socket_table_guard.bind(socket_table.clone(), addr, receiver) + } + + fn unicast( + dst_port: u32, + message: Self::Message, + netns: Arc, + ) -> Result<(), SystemError> { + Self::socket_table(netns).read().unicast(dst_port, message) + } + + //todo 多播消息用 + #[allow(unused)] + fn multicast( + dst_groups: GroupIdSet, + message: Self::Message, + netns: Arc, + ) -> Result<(), SystemError> + where + Self::Message: MulticastMessage, + { + Self::socket_table(netns) + .read() + .multicast(dst_groups, message) + } +} + +#[derive(Debug)] +pub struct NetlinkRouteProtocol; + +impl SupportedNetlinkProtocol for NetlinkRouteProtocol { + type Message = RouteNlMessage; + + fn socket_table(netns: Arc) -> Arc>> { + netns.netlink_socket_table().route() + } +} + +pub fn is_valid_protocol(protocol: u32) -> bool { + protocol < 
MAX_ALLOWED_PROTOCOL_ID +} + +pub trait NetlinkKernelSocket: Debug + Send + Sync { + fn protocol(&self) -> StandardNetlinkProtocol; + + /// 用于实现动态转换 + fn as_any_ref(&self) -> &dyn Any; +} + +/// 为一个网络命名空间生成支持的 Netlink 内核套接字 +pub fn generate_supported_netlink_kernel_sockets() -> HashMap> { + let mut sockets: HashMap> = + HashMap::with_capacity(MAX_ALLOWED_PROTOCOL_ID as usize); + let route_socket = Arc::new(NetlinkRouteKernelSocket::new()); + sockets.insert(route_socket.protocol().into(), route_socket); + + // Add other supported netlink kernel sockets here + sockets +} + +#[expect(non_camel_case_types)] +#[repr(u32)] +#[derive(Debug, Clone, Copy)] +pub enum StandardNetlinkProtocol { + /// Routing/device hook + ROUTE = 0, + /// Unused number + UNUSED = 1, + /// Reserved for user mode socket protocols + USERSOCK = 2, + /// Unused number, formerly ip_queue + FIREWALL = 3, + /// Socket monitoring + SOCK_DIAG = 4, + /// Netfilter/iptables ULOG + NFLOG = 5, + /// IPsec + XFRM = 6, + /// SELinux event notifications + SELINUX = 7, + /// Open-iSCSI + ISCSI = 8, + /// Auditing + AUDIT = 9, + FIB_LOOKUP = 10, + CONNECTOR = 11, + /// Netfilter subsystem + NETFILTER = 12, + IP6_FW = 13, + /// DECnet routing messages + DNRTMSG = 14, + /// Kernel messages to userspace + KOBJECT_UEVENT = 15, + GENERIC = 16, + /// Leave room for NETLINK_DM (DM Events) + /// SCSI Transports + SCSITRANSPORT = 18, + ECRYPTFS = 19, + RDMA = 20, + /// Crypto layer + CRYPTO = 21, + /// SMC monitoring + SMC = 22, +} + +impl From for u32 { + fn from(value: StandardNetlinkProtocol) -> Self { + value as u32 + } +} + +impl TryFrom for StandardNetlinkProtocol { + type Error = (); + + fn try_from(value: u32) -> Result { + match value { + 0 => Ok(StandardNetlinkProtocol::ROUTE), + 1 => Ok(StandardNetlinkProtocol::UNUSED), + 2 => Ok(StandardNetlinkProtocol::USERSOCK), + 3 => Ok(StandardNetlinkProtocol::FIREWALL), + 4 => Ok(StandardNetlinkProtocol::SOCK_DIAG), + 5 => Ok(StandardNetlinkProtocol::NFLOG), + 6 => 
Ok(StandardNetlinkProtocol::XFRM), + 7 => Ok(StandardNetlinkProtocol::SELINUX), + 8 => Ok(StandardNetlinkProtocol::ISCSI), + 9 => Ok(StandardNetlinkProtocol::AUDIT), + 10 => Ok(StandardNetlinkProtocol::FIB_LOOKUP), + 11 => Ok(StandardNetlinkProtocol::CONNECTOR), + 12 => Ok(StandardNetlinkProtocol::NETFILTER), + 13 => Ok(StandardNetlinkProtocol::IP6_FW), + 14 => Ok(StandardNetlinkProtocol::DNRTMSG), + 15 => Ok(StandardNetlinkProtocol::KOBJECT_UEVENT), + 16 => Ok(StandardNetlinkProtocol::GENERIC), + 18 => Ok(StandardNetlinkProtocol::SCSITRANSPORT), + 19 => Ok(StandardNetlinkProtocol::ECRYPTFS), + 20 => Ok(StandardNetlinkProtocol::RDMA), + 21 => Ok(StandardNetlinkProtocol::CRYPTO), + 22 => Ok(StandardNetlinkProtocol::SMC), + _ => Err(()), + } + } +} diff --git a/kernel/src/net/socket/netlink/table/multicast.rs b/kernel/src/net/socket/netlink/table/multicast.rs new file mode 100644 index 000000000..c4a7ab333 --- /dev/null +++ b/kernel/src/net/socket/netlink/table/multicast.rs @@ -0,0 +1,30 @@ +use alloc::collections::BTreeSet; + +#[derive(Debug)] +pub struct MulticastGroup { + // portnumber: u32, + members: BTreeSet, +} + +impl MulticastGroup { + pub const fn new() -> Self { + Self { + members: BTreeSet::new(), + } + } + + pub fn add_member(&mut self, port_num: u32) { + self.members.insert(port_num); + } + + pub fn remove_member(&mut self, port_num: u32) { + self.members.remove(&port_num); + } + + pub fn members(&self) -> &BTreeSet { + &self.members + } +} + +/// Uevent that can be sent to multicast groups. 
+pub trait MulticastMessage: Clone {} diff --git a/kernel/src/net/socket/utils/datagram_common.rs b/kernel/src/net/socket/utils/datagram_common.rs new file mode 100644 index 000000000..ab55dda23 --- /dev/null +++ b/kernel/src/net/socket/utils/datagram_common.rs @@ -0,0 +1,203 @@ +use crate::filesystem::epoll::EPollEventType; +use crate::process::namespace::net_namespace::NetNamespace; +use crate::{ + libs::{rwlock::RwLock, wait_queue::WaitQueue}, + net::socket::PMSG, +}; +use alloc::fmt::Debug; +use alloc::sync::Arc; +use core::panic; +use system_error::SystemError; + +//todo netlink和udp的操作相同,目前只是为netlink实现了下面的trait,后续为 UdpSocket实现下面的trait,提高复用性 + +pub trait Unbound { + type Endpoint: Debug; + type Bound; + + fn bind( + &mut self, + endpoint: &Self::Endpoint, + wait_queue: Arc, + netns: Arc, + ) -> Result; + + fn bind_ephemeral( + &mut self, + endpoint: &Self::Endpoint, + wait_queue: Arc, + netns: Arc, + ) -> Result; + + fn check_io_events(&self) -> EPollEventType; + + fn local_endpoint(&self) -> Option; +} + +pub trait Bound { + type Endpoint: Clone + Debug; + + fn bind(&mut self, _endpoint: &Self::Endpoint) -> Result<(), SystemError> { + Err(SystemError::EINVAL) + } + + fn local_endpoint(&self) -> Self::Endpoint; + + fn remote_endpoint(&self) -> Option; + + fn set_remote_endpoint(&mut self, endpoint: &Self::Endpoint); + + fn try_recv( + &self, + writer: &mut [u8], + flags: PMSG, + ) -> Result<(usize, Self::Endpoint), SystemError>; + + fn try_send(&self, buf: &[u8], to: &Self::Endpoint, flags: PMSG) -> Result; + + fn check_io_events(&self) -> EPollEventType; +} + +#[derive(Debug)] +pub enum Inner { + Unbound(UnboundSocket), + Bound(BoundSocket), +} + +impl Inner +where + UnboundSocket: Unbound, + BoundSocket: Bound, +{ + pub fn bind( + &mut self, + endpoint: &UnboundSocket::Endpoint, + wait_queue: Arc, + netns: Arc, + ) -> Result<(), SystemError> { + let unbound = match self { + Inner::Bound(bound) => return bound.bind(endpoint), + Inner::Unbound(unbound) => 
unbound, + }; + + let bound = unbound.bind(endpoint, wait_queue, netns)?; + *self = Inner::Bound(bound); + + // log::info!("Socket bound to endpoint: {:?}", endpoint); + + Ok(()) + } + + pub fn bind_ephemeral( + &mut self, + remote_endpoint: &UnboundSocket::Endpoint, + wait_queue: Arc, + netns: Arc, + ) -> Result<(), SystemError> { + let unbound_datagram = match self { + Inner::Unbound(unbound) => unbound, + Inner::Bound(_) => return Ok(()), + }; + + let bound = unbound_datagram.bind_ephemeral(remote_endpoint, wait_queue, netns)?; + *self = Inner::Bound(bound); + + Ok(()) + } + + pub fn connect( + &mut self, + remote_endpoint: &UnboundSocket::Endpoint, + wait_queue: Arc, + netns: Arc, + ) -> Result<(), SystemError> { + self.bind_ephemeral(remote_endpoint, wait_queue, netns)?; + + let bound = match self { + Inner::Unbound(_) => { + unreachable!( + "`bind_to_ephemeral_endpoint` succeeds so the socket cannot be unbound" + ); + } + Inner::Bound(bound_datagram) => bound_datagram, + }; + bound.set_remote_endpoint(remote_endpoint); + + Ok(()) + } + + pub fn check_io_events(&self) -> EPollEventType { + match self { + Inner::Unbound(unbound_datagram) => unbound_datagram.check_io_events(), + Inner::Bound(bound_datagram) => bound_datagram.check_io_events(), + } + } + + pub fn addr(&self) -> Option { + match self { + Inner::Unbound(unbound) => unbound.local_endpoint(), + Inner::Bound(bound) => Some(bound.local_endpoint()), + } + } + + pub fn peer_addr(&self) -> Option { + match self { + Inner::Unbound(_) => None, + Inner::Bound(bound) => bound.remote_endpoint(), + } + } + + pub fn try_recv( + &self, + writer: &mut [u8], + flags: PMSG, + ) -> Result<(usize, UnboundSocket::Endpoint), SystemError> { + match self { + Inner::Unbound(_) => Err(SystemError::EAGAIN_OR_EWOULDBLOCK), + Inner::Bound(bound) => bound.try_recv(writer, flags), + } + } + + // try_send 在下面:) +} + +pub fn select_remote_and_bind( + inner_lock: &RwLock>, + remote: Option, + bind_ephemeral: B, + op: F, +) -> 
Result +where + UnboundSocket: Unbound, + BoundSocket: Bound, + B: FnOnce() -> Result<(), SystemError>, + F: FnOnce(&BoundSocket, UnboundSocket::Endpoint) -> Result, +{ + let mut inner = inner_lock.read(); + + // 这里用 loop 只是为了用 break :) + #[expect(clippy::never_loop)] + let bound = loop { + if let Inner::Bound(bound) = &*inner { + break bound; + } + + drop(inner); + bind_ephemeral()?; + + inner = inner_lock.read(); + + if let Inner::Bound(bound_datagram) = &*inner { + break bound_datagram; + } + + panic!(""); + }; + + let remote_endpoint = match remote { + Some(r) => r.clone(), + None => bound.remote_endpoint().ok_or(SystemError::EDESTADDRREQ)?, + }; + + op(bound, remote_endpoint) +} diff --git a/kernel/src/net/socket/utils.rs b/kernel/src/net/socket/utils/mod.rs similarity index 83% rename from kernel/src/net/socket/utils.rs rename to kernel/src/net/socket/utils/mod.rs index f3adbb3f5..b1825319b 100644 --- a/kernel/src/net/socket/utils.rs +++ b/kernel/src/net/socket/utils/mod.rs @@ -1,5 +1,8 @@ +pub(super) mod datagram_common; + use crate::net::socket::{ - self, inet::syscall::create_inet_socket, unix::create_unix_socket, Socket, + self, inet::syscall::create_inet_socket, netlink::create_netlink_socket, + unix::create_unix_socket, Socket, }; use alloc::sync::Arc; use system_error::SystemError; @@ -27,6 +30,7 @@ pub fn create_socket( is_nonblock, )?, AF::Unix => create_unix_socket(socket_type, is_nonblock)?, + AF::Netlink => create_netlink_socket(socket_type, protocol, is_nonblock)?, _ => { log::warn!("unsupport address family"); return Err(SystemError::EAFNOSUPPORT); diff --git a/kernel/src/net/syscall.rs b/kernel/src/net/syscall.rs index 7e7f0eee9..35dc8054e 100644 --- a/kernel/src/net/syscall.rs +++ b/kernel/src/net/syscall.rs @@ -124,9 +124,13 @@ impl Syscall { optval: &[u8], ) -> Result { let sol = socket::PSOL::try_from(level as u32)?; - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; + let socket = 
ProcessManager::current_pcb().get_socket_inode(fd as i32)?; log::debug!("setsockopt: level = {:?} ", sol); - return socket.set_option(sol, optname, optval).map(|_| 0); + return socket + .as_socket() + .unwrap() + .set_option(sol, optname, optval) + .map(|_| 0); } /// @brief sys_getsockopt系统调用的实际执行函数 @@ -147,7 +151,8 @@ impl Syscall { ) -> Result { // 获取socket let optval = optval as *mut u32; - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; + let socket_inode = ProcessManager::current_pcb().get_socket_inode(fd as i32)?; + let socket = socket_inode.as_socket().unwrap(); use socket::{PSO, PSOL}; @@ -207,8 +212,11 @@ impl Syscall { /// @return 成功返回0,失败返回错误码 pub fn connect(fd: usize, addr: *const SockAddr, addrlen: u32) -> Result { let endpoint: Endpoint = SockAddr::to_endpoint(addr, addrlen)?; - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; - socket.connect(endpoint)?; + ProcessManager::current_pcb() + .get_socket_inode(fd as i32)? + .as_socket() + .unwrap() + .connect(endpoint)?; Ok(0) } @@ -228,9 +236,11 @@ impl Syscall { // addrlen // ); let endpoint: Endpoint = SockAddr::to_endpoint(addr, addrlen)?; - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; - // log::debug!("bind: socket={:?}", socket); - socket.bind(endpoint)?; + ProcessManager::current_pcb() + .get_socket_inode(fd as i32)? 
+ .as_socket() + .unwrap() + .bind(endpoint)?; Ok(0) } @@ -258,7 +268,8 @@ impl Syscall { let flags = socket::PMSG::from_bits_truncate(flags); - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; + let socket_inode = ProcessManager::current_pcb().get_socket_inode(fd as i32)?; + let socket = socket_inode.as_socket().unwrap(); if let Some(endpoint) = endpoint { return socket.send_to(buf, flags, endpoint); @@ -283,7 +294,8 @@ impl Syscall { addr: *mut SockAddr, addr_len: *mut u32, ) -> Result { - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; + let socket_inode = ProcessManager::current_pcb().get_socket_inode(fd as i32)?; + let socket = socket_inode.as_socket().unwrap(); let flags = socket::PMSG::from_bits_truncate(flags); @@ -320,7 +332,8 @@ impl Syscall { let iovs = unsafe { IoVecs::from_user(msg.msg_iov, msg.msg_iovlen, true)? }; let (buf, recv_size) = { - let socket = ProcessManager::current_pcb().get_socket(fd as i32)?; + let socket_inode = ProcessManager::current_pcb().get_socket_inode(fd as i32)?; + let socket = socket_inode.as_socket().unwrap(); let flags = socket::PMSG::from_bits_truncate(flags); @@ -344,8 +357,12 @@ impl Syscall { /// /// @return 成功返回0,失败返回错误码 pub fn listen(fd: usize, backlog: usize) -> Result { - let inode = ProcessManager::current_pcb().get_socket(fd as i32)?; - inode.listen(backlog).map(|_| 0) + ProcessManager::current_pcb() + .get_socket_inode(fd as i32)? + .as_socket() + .unwrap() + .listen(backlog) + .map(|_| 0) } /// @brief sys_shutdown系统调用的实际执行函数 @@ -355,8 +372,12 @@ impl Syscall { /// /// @return 成功返回0,失败返回错误码 pub fn shutdown(fd: usize, how: usize) -> Result { - let inode = ProcessManager::current_pcb().get_socket(fd as i32)?; - inode.shutdown(ShutdownBit::try_from(how)?).map(|()| 0) + ProcessManager::current_pcb() + .get_socket_inode(fd as i32)? + .as_socket() + .unwrap() + .shutdown(ShutdownBit::try_from(how)?) 
+ .map(|()| 0) } /// @brief sys_accept系统调用的实际执行函数 @@ -413,7 +434,9 @@ impl Syscall { ) -> Result { let (new_socket, remote_endpoint) = { ProcessManager::current_pcb() - .get_socket(fd as i32)? + .get_socket_inode(fd as i32)? + .as_socket() + .unwrap() .accept()? }; @@ -459,7 +482,9 @@ impl Syscall { return Err(SystemError::EINVAL); } ProcessManager::current_pcb() - .get_socket(fd as i32)? + .get_socket_inode(fd as i32)? + .as_socket() + .unwrap() .local_endpoint()? .write_to_user(addr, addrlen)?; return Ok(0); @@ -482,7 +507,9 @@ impl Syscall { } ProcessManager::current_pcb() - .get_socket(fd as i32)? + .get_socket_inode(fd as i32)? + .as_socket() + .unwrap() .remote_endpoint()? .write_to_user(addr, addrlen)?; diff --git a/kernel/src/process/mod.rs b/kernel/src/process/mod.rs index e0fcacb6d..32892449d 100644 --- a/kernel/src/process/mod.rs +++ b/kernel/src/process/mod.rs @@ -14,7 +14,6 @@ use alloc::{ }; use cred::INIT_CRED; use hashbrown::HashMap; -use intertrait::cast::CastArc; use log::{debug, error, info, warn}; use pid::{alloc_pid, Pid, PidLink, PidType}; use process_group::Pgid; @@ -55,7 +54,6 @@ use crate::{ ucontext::AddressSpace, PhysAddr, VirtAddr, }, - net::socket::Socket, sched::{ DequeueFlag, EnqueueFlag, OnRq, SchedMode, WakeupFlags, __schedule, completion::Completion, cpu_rq, fair::FairSchedEntity, prio::MAX_PRIO, @@ -1069,7 +1067,7 @@ impl ProcessControlBlock { self.task_tgid_vnr().unwrap() == RawPid(1) } - /// 根据文件描述符序号,获取socket对象的Arc指针 + /// 根据文件描述符序号,获取socket相应的IndexNode的Arc指针 /// /// this is a helper function /// @@ -1079,8 +1077,13 @@ impl ProcessControlBlock { /// /// ## 返回值 /// - /// Option(&mut Box) socket对象的可变引用. 
如果文件描述符不是socket,那么返回None - pub fn get_socket(&self, fd: i32) -> Result, SystemError> { + /// 如果fd对应的文件是一个socket,那么返回这个socket相应的IndexNode的Arc指针,否则返回错误码 + /// + /// # 注意 + /// 因为底层的Socket中可能包含泛型,经过类型擦除转换成Arc的时候内部的泛型信息会丢失; + /// 因此这里返回Arc,可在外部直接通过 `as_socket()` 转换成 `Option<&dyn Socket>`; + /// 因为内部已经经过检查,因此在外部可以直接 `unwarp` 来获取 `&dyn Socket` + pub fn get_socket_inode(&self, fd: i32) -> Result, SystemError> { let f = ProcessManager::current_pcb() .fd_table() .read() @@ -1093,11 +1096,15 @@ impl ProcessControlBlock { if f.file_type() != FileType::Socket { return Err(SystemError::EBADF); } + + let inode = f.inode(); // log::info!("get_socket: fd {} is a socket", fd); - f.inode().cast::().map_err(|_| { - log::error!("get_socket: fd {} is not a socket", fd); - SystemError::EBADF - }) + if let Some(_sock) = inode.as_socket() { + // log::info!("{:?}", sock); + return Ok(inode); + } + + Err(SystemError::EBADF) } /// 当前进程退出时,让初始进程收养所有子进程 diff --git a/kernel/src/process/namespace/mod.rs b/kernel/src/process/namespace/mod.rs index 0a0a78b7c..4427c7f78 100644 --- a/kernel/src/process/namespace/mod.rs +++ b/kernel/src/process/namespace/mod.rs @@ -1,4 +1,5 @@ pub mod mnt; +pub mod net_namespace; pub mod nsproxy; pub mod pid_namespace; pub mod unshare; diff --git a/kernel/src/process/namespace/net_namespace.rs b/kernel/src/process/namespace/net_namespace.rs new file mode 100644 index 000000000..2f21a95f3 --- /dev/null +++ b/kernel/src/process/namespace/net_namespace.rs @@ -0,0 +1,298 @@ +use crate::arch::CurrentIrqArch; +use crate::driver::net::bridge::BridgeDriver; +use crate::driver::net::loopback::{generate_loopback_iface_default, LoopbackInterface}; +use crate::exception::InterruptArch; +use crate::init::initcall::INITCALL_SUBSYS; +use crate::libs::rwlock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; +use crate::net::routing::Router; +use crate::net::socket::netlink::table::{ + generate_supported_netlink_kernel_sockets, NetlinkKernelSocket, NetlinkSocketTable, +}; +use 
crate::process::fork::CloneFlags; +use crate::process::kthread::{KernelThreadClosure, KernelThreadMechanism}; +use crate::process::namespace::{NamespaceOps, NamespaceType}; +use crate::process::{ProcessControlBlock, ProcessManager}; +use crate::sched::{schedule, SchedMode}; +use crate::{ + driver::net::Iface, + libs::spinlock::SpinLock, + process::namespace::{nsproxy::NsCommon, user_namespace::UserNamespace}, +}; +use alloc::boxed::Box; +use alloc::collections::BTreeMap; +use alloc::string::{String, ToString}; +use alloc::sync::{Arc, Weak}; +use core::sync::atomic::AtomicUsize; +use hashbrown::HashMap; +use system_error::SystemError; +use unified_init::macros::unified_init; + +lazy_static! { + /// # 所有网络设备,进程,socket的初始网络命名空间 + pub static ref INIT_NET_NAMESPACE: Arc = NetNamespace::new_root(); +} + +/// # 网络命名空间计数器 +/// 用于生成唯一的网络命名空间ID +/// 每次创建新的网络命名空间时,都会增加这个计数器 +pub static mut NETNS_COUNTER: AtomicUsize = AtomicUsize::new(0); + +#[unified_init(INITCALL_SUBSYS)] +pub fn root_net_namespace_init() -> Result<(), SystemError> { + // 创建root网络命名空间的轮询线程 + let pcb = + NetNamespace::create_polling_thread(INIT_NET_NAMESPACE.clone(), "root_netns".to_string()); + INIT_NET_NAMESPACE.set_poll_thread(pcb); + + // 创建 router + let router = Router::new("root_netns_router".to_string()); + INIT_NET_NAMESPACE.inner_mut().router = router.clone(); + let mut guard = router.ns.write(); + *guard = INIT_NET_NAMESPACE.self_ref.clone(); + + Ok(()) +} + +/// # 获取下一个网络命名空间计数器的值 +fn get_next_netns_counter() -> usize { + unsafe { NETNS_COUNTER.fetch_add(1, core::sync::atomic::Ordering::SeqCst) } +} + +#[derive(Debug)] +pub struct NetNamespace { + ns_common: NsCommon, + self_ref: Weak, + _user_ns: Arc, + inner: RwLock, + /// # 负责当前网络命名空间网卡轮询的线程 + net_poll_thread: SpinLock>>, + /// # 当前网络命名空间下所有网络接口的列表 + /// 这个列表在中断上下文会使用到,因此需要irqsave + /// 没有放在InnerNetNamespace里面,独立出来,方便管理 + device_list: RwLock>>, + ///当前网络命名空间下的桥接设备列表 + bridge_list: RwLock>>, + + // -- Netlink -- + /// # 当前网络命名空间下的 Netlink 套接字表 + 
/// 负责绑定netlink套接字的接收队列,以便发送接收消息 + netlink_socket_table: NetlinkSocketTable, + /// # 当前网络命名空间下的 Netlink 内核套接字 + /// 负责接收并处理 Netlink 消息 + netlink_kernel_socket: RwLock>>, +} + +#[derive(Debug)] +pub struct InnerNetNamespace { + router: Arc, + /// 当前网络命名空间的loopback网卡 + loopback_iface: Option>, + /// 当前网络命名空间的默认网卡 + /// 这个网卡会在没有指定网卡的情况下使用 + default_iface: Option>, +} + +impl InnerNetNamespace { + pub fn router(&self) -> &Arc { + &self.router + } + + pub fn loopback_iface(&self) -> Option> { + self.loopback_iface.clone() + } +} + +impl NetNamespace { + pub fn new_root() -> Arc { + let inner = InnerNetNamespace { + // 这里没有直接创建 router,而是留到 init 函数中创建 + router: Router::new_empty(), + loopback_iface: None, + default_iface: None, + }; + + let netns = Arc::new_cyclic(|self_ref| Self { + ns_common: NsCommon::new(0, NamespaceType::Net), + self_ref: self_ref.clone(), + _user_ns: crate::process::namespace::user_namespace::INIT_USER_NAMESPACE.clone(), + inner: RwLock::new(inner), + net_poll_thread: SpinLock::new(None), + device_list: RwLock::new(BTreeMap::new()), + bridge_list: RwLock::new(BTreeMap::new()), + netlink_socket_table: NetlinkSocketTable::default(), + netlink_kernel_socket: RwLock::new(generate_supported_netlink_kernel_sockets()), + }); + + // Self::create_polling_thread(netns.clone(), "netns_root".to_string()); + log::info!("Initialized root net namespace"); + netns + } + + pub fn new_empty(user_ns: Arc) -> Result, SystemError> { + let counter = get_next_netns_counter(); + let loopback = generate_loopback_iface_default(); + + let inner = InnerNetNamespace { + router: Router::new(format!("netns_router_{}", counter)), + loopback_iface: Some(loopback.clone()), + default_iface: None, + }; + + let netns = Arc::new_cyclic(|self_ref| Self { + ns_common: NsCommon::new(0, NamespaceType::Net), + self_ref: self_ref.clone(), + _user_ns: user_ns, + inner: RwLock::new(inner), + net_poll_thread: SpinLock::new(None), + device_list: RwLock::new(BTreeMap::new()), + bridge_list: 
RwLock::new(BTreeMap::new()), + netlink_socket_table: NetlinkSocketTable::default(), + netlink_kernel_socket: RwLock::new(generate_supported_netlink_kernel_sockets()), + }); + Self::create_polling_thread(netns.clone(), format!("netns_{}", counter)); + netns.add_device(loopback); + + Ok(netns) + } + + pub(super) fn copy_net_ns( + &self, + clone_flags: &CloneFlags, + user_ns: Arc, + ) -> Result, SystemError> { + if !clone_flags.contains(CloneFlags::CLONE_NEWNET) { + return Ok(self.self_ref.upgrade().unwrap()); + } + + Self::new_empty(user_ns) + } + + pub fn device_list_mut(&self) -> RwLockWriteGuard<'_, BTreeMap>> { + self.device_list.write() + } + + pub fn device_list(&self) -> RwLockReadGuard<'_, BTreeMap>> { + self.device_list.read() + } + + pub fn inner(&self) -> RwLockReadGuard<'_, InnerNetNamespace> { + self.inner.read() + } + + pub fn inner_mut(&self) -> RwLockWriteGuard<'_, InnerNetNamespace> { + self.inner.write() + } + + pub fn set_loopback_iface(&self, loopback: Arc) { + self.inner_mut().loopback_iface = Some(loopback); + } + + pub fn loopback_iface(&self) -> Option> { + self.inner().loopback_iface() + } + + pub fn set_default_iface(&self, iface: Arc) { + self.inner_mut().default_iface = Some(iface); + } + + pub fn default_iface(&self) -> Option> { + self.inner().default_iface.clone() + } + + pub fn router(&self) -> Arc { + self.inner().router.clone() + } + + pub fn netlink_socket_table(&self) -> &NetlinkSocketTable { + &self.netlink_socket_table + } + + pub fn get_netlink_kernel_socket_by_protocol( + &self, + protocol: u32, + ) -> Option> { + self.netlink_kernel_socket.read().get(&protocol).cloned() + } + + pub fn add_device(&self, device: Arc) { + device.set_net_namespace(self.self_ref.upgrade().unwrap()); + + self.device_list_mut().insert(device.nic_id(), device); + + log::info!( + "Network device added to namespace count: {:?}", + self.device_list().len() + ); + } + + pub fn remove_device(&self, nic_id: &usize) { + 
self.device_list_mut().remove(nic_id); + } + + pub fn insert_bridge(&self, bridge: Arc) { + self.bridge_list.write().insert(bridge.name(), bridge); + } + + /// # 拉起网络命名空间的轮询线程 + pub fn wakeup_poll_thread(&self) { + if self.net_poll_thread.lock().is_none() { + return; + } + // log::info!("wakeup net_poll thread for namespace"); + let _ = ProcessManager::wakeup(self.net_poll_thread.lock().as_ref().unwrap()); + } + + /// # 网络命名空间的轮询线程 + /// 该线程会轮询当前命名空间下的所有网络接口 + /// 并调用它们的poll方法 + /// 注意: 此方法仅可在初始化当前net namespace时创建进程使用 + fn polling(&self) { + log::info!("net_poll thread started for namespace"); + loop { + for (_, iface) in self.device_list.read_irqsave().iter() { + iface.poll(); + } + + let irq_guard = unsafe { CurrentIrqArch::save_and_disable_irq() }; + ProcessManager::mark_sleep(true).expect("netns_poll_kthread:mark sleep failed"); + drop(irq_guard); + // log::info!("net_poll thread going to sleep"); + schedule(SchedMode::SM_NONE); + } + } + + fn create_polling_thread(netns: Arc, name: String) -> Arc { + let pcb = { + let closure: Box i32 + Send + Sync> = Box::new(move || { + netns.polling(); + 0 + }); + KernelThreadClosure::EmptyClosure((closure, ())) + }; + + let pcb = KernelThreadMechanism::create_and_run(pcb, name) + .ok_or("") + .expect("create net_poll thread for net namespace failed"); + log::info!("net_poll thread created for namespace"); + pcb + } + + /// # 设置网络命名空间的轮询线程 + /// 这个方法仅可在初始化网络命名空间时调用 + fn set_poll_thread(&self, pcb: Arc) { + let mut lock = self.net_poll_thread.lock(); + *lock = Some(pcb); + } +} + +impl NamespaceOps for NetNamespace { + fn ns_common(&self) -> &NsCommon { + &self.ns_common + } +} + +impl ProcessManager { + pub fn current_netns() -> Arc { + Self::current_pcb().nsproxy.read().net_ns.clone() + } +} diff --git a/kernel/src/process/namespace/nsproxy.rs b/kernel/src/process/namespace/nsproxy.rs index e9b26f5eb..dfc70ed20 100644 --- a/kernel/src/process/namespace/nsproxy.rs +++ b/kernel/src/process/namespace/nsproxy.rs @@ -5,6 +5,7 
@@ use crate::process::{ fork::CloneFlags, namespace::{ mnt::{root_mnt_namespace, MntNamespace}, + net_namespace::{NetNamespace, INIT_NET_NAMESPACE}, uts_namespace::{UtsNamespace, INIT_UTS_NAMESPACE}, }, ProcessControlBlock, ProcessManager, @@ -28,8 +29,11 @@ pub struct NsProxy { /// mount namespace(挂载命名空间) pub mnt_ns: Arc, pub uts_ns: Arc, + /// 网络命名空间 + pub net_ns: Arc, + // 注意,user_ns 存储在cred,不存储在nsproxy + // 其他namespace(为未来扩展预留) - // pub net_ns: Option>, // pub ipc_ns: Option>, // pub cgroup_ns: Option>, // pub time_ns: Option>, @@ -46,10 +50,12 @@ impl NsProxy { pub fn new_root() -> Arc { let root_pid_ns = super::pid_namespace::INIT_PID_NAMESPACE.clone(); let root_mnt_ns = root_mnt_namespace(); + let root_net_ns = INIT_NET_NAMESPACE.clone(); let root_uts_ns = INIT_UTS_NAMESPACE.clone(); Arc::new(Self { pid_ns_for_children: root_pid_ns, mnt_ns: root_mnt_ns, + net_ns: root_net_ns, uts_ns: root_uts_ns, }) } @@ -64,10 +70,16 @@ impl NsProxy { &self.mnt_ns } + /// 获取 net namespace + pub fn net_namespace(&self) -> &Arc { + &self.net_ns + } + pub fn clone_inner(&self) -> Self { Self { pid_ns_for_children: self.pid_ns_for_children.clone(), mnt_ns: self.mnt_ns.clone(), + net_ns: self.net_ns.clone(), uts_ns: self.uts_ns.clone(), } } @@ -147,11 +159,13 @@ pub(super) fn create_new_namespaces( .copy_pid_ns(clone_flags, user_ns.clone())?; let mnt_ns = nsproxy.mnt_ns.copy_mnt_ns(clone_flags, user_ns.clone())?; + let net_ns = nsproxy.net_ns.copy_net_ns(clone_flags, user_ns.clone())?; let uts_ns = nsproxy.uts_ns.copy_uts_ns(clone_flags, user_ns.clone())?; let result = NsProxy { pid_ns_for_children, mnt_ns, + net_ns, uts_ns, }; diff --git a/kernel/src/process/namespace/unshare.rs b/kernel/src/process/namespace/unshare.rs index d6167595c..b0ae76847 100644 --- a/kernel/src/process/namespace/unshare.rs +++ b/kernel/src/process/namespace/unshare.rs @@ -21,7 +21,7 @@ pub fn ksys_unshare(flags: CloneFlags) -> Result<(), SystemError> { switch_task_namespaces(¤t_pcb, new_nsproxy)?; } 
// TODO: 处理其他命名空间的 unshare 操作 - // CLONE_NEWNS, CLONE_FS, CLONE_FILES, CLONE_SIGHAND, CLONE_VM, CLONE_THREAD, CLONE_SYSVSEM, + // CLONE_FS, CLONE_FILES, CLONE_SIGHAND, CLONE_VM, CLONE_THREAD, CLONE_SYSVSEM, // CLONE_NEWUTS, CLONE_NEWIPC, CLONE_NEWUSER, CLONE_NEWNET, CLONE_NEWCGROUP, CLONE_NEWTIME Ok(()) diff --git a/tools/run-qemu.sh b/tools/run-qemu.sh index 81cf1c03b..d13864220 100755 --- a/tools/run-qemu.sh +++ b/tools/run-qemu.sh @@ -99,6 +99,7 @@ QEMU_DRIVE="id=disk,file=${QEMU_DISK_IMAGE},if=none" QEMU_ACCELARATE="" QEMU_ARGUMENT=" -no-reboot " QEMU_DEVICES="" +# QEMU_ARGUMENT+=" -S " if [ -f "${QEMU_EXT4_DISK_IMAGE}" ]; then QEMU_DRIVE+=" -drive id=ext4disk,file=${QEMU_EXT4_DISK_IMAGE},if=none,format=raw" diff --git a/user/apps/c_unitest/test_bind.c b/user/apps/c_unitest/test_bind.c index 420b597e5..6a6f3e1f2 100644 --- a/user/apps/c_unitest/test_bind.c +++ b/user/apps/c_unitest/test_bind.c @@ -139,6 +139,9 @@ void test_all_ports() { } count++; + + close(tcp_fd); + if (count>=100) break; } printf("===TEST 10===\n"); printf("count: %d\n", count); diff --git a/user/apps/c_unitest/test_epoll_socket.c b/user/apps/c_unitest/test_epoll_socket.c new file mode 100644 index 000000000..7b46f8699 --- /dev/null +++ b/user/apps/c_unitest/test_epoll_socket.c @@ -0,0 +1,222 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define SERVER_IP "111.111.11.1" +#define CLIENT_IP "111.111.11.2" +#define PORT 8888 +#define MAX_EVENTS 10 +#define BUFFER_SIZE 1024 + +// 函数声明 +void server_process(); +void client_process(); + +int main() { + pid_t pid = fork(); + + if (pid < 0) { + perror("fork failed"); + exit(EXIT_FAILURE); + } else if (pid == 0) { + // 子进程作为客户端 + // 等待一秒,确保服务器已启动 + sleep(1); + client_process(); + } else { + // 父进程作为服务器 + server_process(); + } + + return 0; +} + + +// 服务器进程逻辑 +void server_process() { + printf("[Server] Starting server process...\n"); + + int listen_sock, conn_sock, 
epoll_fd; + struct sockaddr_in server_addr, client_addr; + socklen_t client_len = sizeof(client_addr); + struct epoll_event ev, events[MAX_EVENTS]; + char buffer[BUFFER_SIZE]; + int data_processed = 0; // 添加标志位,标记是否已处理过数据 + + if ((listen_sock = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0)) < 0) { + perror("[Server] socket creation failed"); + exit(EXIT_FAILURE); + } + + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + if (inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr) <= 0) { + perror("[Server] inet_pton failed"); + exit(EXIT_FAILURE); + } + + if (bind(listen_sock, + (struct sockaddr *)&server_addr, + sizeof(server_addr)) < 0) { + perror("[Server] bind failed"); + exit(EXIT_FAILURE); + } + + if (listen(listen_sock, 1) < 0) { + perror("[Server] listen failed"); + exit(EXIT_FAILURE); + } + printf("[Server] Listening on %s:%d\n", SERVER_IP, PORT); + + if ((epoll_fd = epoll_create1(0)) < 0) { + perror("[Server] epoll_create1 failed"); + exit(EXIT_FAILURE); + } + + ev.events = EPOLLIN; + ev.data.fd = listen_sock; + if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, listen_sock, &ev) < 0) { + perror("[Server] epoll_ctl: listen_sock failed"); + exit(EXIT_FAILURE); + } + printf("Adding listening socket %d to epoll\n", listen_sock); + + while (!data_processed) { + int nfds = epoll_wait(epoll_fd, events, MAX_EVENTS, -1); + if (nfds < 0) { + perror("[Server] epoll_wait failed"); + exit(EXIT_FAILURE); + } + + for (int n = 0; n < nfds; ++n) { + if (events[n].data.fd == listen_sock) { + printf("trying to accept new connection...\n"); + while (1) { + conn_sock = accept(listen_sock, + (struct sockaddr *)&client_addr, + &client_len); + if (conn_sock < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + printf("All incoming connections have been " + "processed.\n"); + break; + } else { + perror("accept error"); + exit(EXIT_FAILURE); + break; + } + } + ev.events = EPOLLIN | EPOLLET; + ev.data.fd = conn_sock; + 
if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, conn_sock, &ev) < + 0) { + perror("[Server] epoll_ctl: conn_sock failed"); + exit(EXIT_FAILURE); + } + printf("[Server] Accepted connection from %s:%d\n", + inet_ntoa(client_addr.sin_addr), + ntohs(client_addr.sin_port)); + } + } else { + printf("[Server] handling client data...\n"); + int client_fd = events[n].data.fd; + int nread = read(client_fd, buffer, BUFFER_SIZE); + if (nread == -1) { + if (errno != EAGAIN) { + perror("[Server] read error"); + close(client_fd); + } + } else if (nread == 0) { + printf("[Server] Client disconnected.\n"); + close(client_fd); + } else { + buffer[nread] = '\0'; + printf("[Server] Received from client: %s\n", buffer); + write(client_fd, buffer, nread); + printf("[Server] Echoed data back to client. Server will " + "now exit.\n"); + data_processed = 1; // 设置退出标志 + sleep(3); + + close(client_fd); + break; + } + } + } + } + + printf("[Server] Server process completed.\n"); + close(listen_sock); + close(epoll_fd); +} + +// 客户端进程逻辑 +void client_process() { + printf("[Client] Starting client process...\n"); + + int sock = 0; + struct sockaddr_in client_bind_addr, server_addr; + char buffer[BUFFER_SIZE] = {0}; + + if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0) { + perror("[Client] socket creation failed"); + exit(EXIT_FAILURE); + } + + memset(&client_bind_addr, 0, sizeof(client_bind_addr)); + client_bind_addr.sin_family = AF_INET; + client_bind_addr.sin_port = htons(7777); + if (inet_pton(AF_INET, CLIENT_IP, &client_bind_addr.sin_addr) <= 0) { + perror("[Client] inet_pton for bind failed"); + exit(EXIT_FAILURE); + } + + if (bind(sock, + (struct sockaddr *)&client_bind_addr, + sizeof(client_bind_addr)) < 0) { + perror("[Client] bind failed"); + exit(EXIT_FAILURE); + } + printf("[Client] Bound to IP %s\n", CLIENT_IP); + + + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + if (inet_pton(AF_INET, SERVER_IP, 
&server_addr.sin_addr) <= 0) { + perror("[Client] inet_pton for connect failed"); + exit(EXIT_FAILURE); + } + + if (connect(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < + 0) { + perror("[Client] connect failed"); + exit(EXIT_FAILURE); + } + printf("[Client] Connected to server %s:%d\n", SERVER_IP, PORT); + + + const char *message = "Hello from client"; + write(sock, message, strlen(message)); + printf("[Client] Sent: %s\n", message); + sleep(1); + + int valread = read(sock, buffer, BUFFER_SIZE); + if (valread > 0) { + buffer[valread] = '\0'; + printf("[Client] Received: %s\n", buffer); + } + + printf("[Client] Client process completed.\n"); + close(sock); +} \ No newline at end of file diff --git a/user/apps/c_unitest/test_netlink.c b/user/apps/c_unitest/test_netlink.c new file mode 100644 index 000000000..075b795a3 --- /dev/null +++ b/user/apps/c_unitest/test_netlink.c @@ -0,0 +1,418 @@ +#define _GNU_SOURCE +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// 定义一个足够大的缓冲区来接收Netlink消息 +#define NL_BUFSIZE 8192 + +// 结构体,用于将请求消息封装起来 +struct nl_req_t { + struct nlmsghdr nlh; + struct ifaddrmsg ifa; +}; + +void parse_rtattr(struct rtattr *tb[], int max, struct rtattr *rta, int len) { + memset(tb, 0, sizeof(struct rtattr *) * (max + 1)); + while (RTA_OK(rta, len)) { + if (rta->rta_type <= max) { + tb[rta->rta_type] = rta; + } + rta = RTA_NEXT(rta, len); + } +} + +int run_netlink_test() { + int sock_fd; + struct sockaddr_nl sa_nl; + struct nl_req_t req; + + // struct iovec iov; + // struct msghdr msg; + // struct sockaddr_nl src_addr; // 用于接收发送方的地址 + + char buf[NL_BUFSIZE]; + ssize_t len; + struct nlmsghdr *nlh; + + // 创建Netlink套接字 + sock_fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); + if (sock_fd < 0) { + perror("socket creation failed"); + return EXIT_FAILURE; + } + + // 设置Netlink地址 + memset(&sa_nl, 0, sizeof(sa_nl)); + sa_nl.nl_family = AF_NETLINK; + sa_nl.nl_pid = getpid(); // 使用进程ID作为地址 + + // 
绑定Netlink套接字 + if (bind(sock_fd, (struct sockaddr *)&sa_nl, sizeof(sa_nl)) < 0) { + perror("socket bind failed"); + close(sock_fd); + return EXIT_FAILURE; + } + + // 构建RTM_GETADDR请求消息 + memset(&req, 0, sizeof(req)); + req.nlh.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg)); + // 这是关键:设置DUMP标志以获取所有地址 + req.nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.nlh.nlmsg_type = RTM_GETADDR; + req.nlh.nlmsg_seq = 1; // 序列号,用于匹配请求和响应 + req.nlh.nlmsg_pid = getpid(); + req.ifa.ifa_family = + AF_INET; // 只请求 IPv4 地址 (也可以用 AF_UNSPEC 获取所有) + + // 内核的目标地址 + struct sockaddr_nl dest_addr; + memset(&dest_addr, 0, sizeof(dest_addr)); + dest_addr.nl_family = AF_NETLINK; + dest_addr.nl_pid = 0; // 0 for the kernel + dest_addr.nl_groups = 0; // Unicast + + // 发送请求到内核 + if (sendto(sock_fd, + &req, + req.nlh.nlmsg_len, + 0, + (struct sockaddr *)&dest_addr, + sizeof(dest_addr)) < 0) { + perror("send failed"); + close(sock_fd); + return EXIT_FAILURE; + } + + printf("Sent RTM_GETADDR request with DUMP flag.\n\n"); + + // // 准备 recvmsg 所需的结构体 + // // iovec 指向我们的主数据缓冲区 + // iov.iov_base = buf; + // iov.iov_len = sizeof(buf); + + // // msghdr 将所有部分组合在一起 + // msg.msg_name = &src_addr; // 填充发送方地址 + // msg.msg_namelen = sizeof(src_addr); + // msg.msg_iov = &iov; // 指向数据缓冲区 + // msg.msg_iovlen = 1; + // msg.msg_control = NULL; // 我们暂时不处理控制消息 + // msg.msg_controllen = 0; + + int received_messages = 0; + + // 循环接收响应 + while ((len = recv(sock_fd, buf, sizeof(buf), 0)) > 0) { + // ssize_t len = recvmsg(sock_fd, &msg, 0); + + // if (len < 0) { + // perror("recvmsg failed"); + // break; + // } + + // if (len == 0) { + // printf("EOF on netlink socket\n"); + // break; + // } + + // if (msg.msg_flags & MSG_TRUNC) { + // fprintf( + // stderr, + // "Warning: Message was truncated. 
Buffer may be too small.\n"); + // } + + + // 使用 NLMSG_OK 遍历缓冲区中可能存在的多条消息 + for (nlh = (struct nlmsghdr *)buf; NLMSG_OK(nlh, len); + nlh = NLMSG_NEXT(nlh, len)) { + + // 如果是 DUMP 结束的标志,则退出循环 + if (nlh->nlmsg_type == NLMSG_DONE) { + printf("--- End of DUMP ---\n"); + if (!received_messages) { + printf("(Received an empty list as expected)\n"); + } + close(sock_fd); + return EXIT_SUCCESS; + } + + // 如果是错误消息 + if (nlh->nlmsg_type == NLMSG_ERROR) { + struct nlmsgerr *err = (struct nlmsgerr *)NLMSG_DATA(nlh); + fprintf(stderr, + "Netlink error received: %s\n", + strerror(-err->error)); + close(sock_fd); + return EXIT_FAILURE; + } + + // 只处理我们期望的 RTM_NEWADDR 消息 + if (nlh->nlmsg_type != RTM_NEWADDR) { + printf("Received unexpected message type: %d\n", + nlh->nlmsg_type); + continue; + } + + // printf("Received message from PID: %u\n", src_addr.nl_pid); + + // 表明我们至少接收到了一条信息 + received_messages = 1; + + struct ifaddrmsg *ifa = (struct ifaddrmsg *)NLMSG_DATA(nlh); + struct rtattr *rta_tb[IFA_MAX + 1]; + + // 解析消息中的路由属性 + int rta_len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*ifa)); + parse_rtattr(rta_tb, IFA_MAX, IFA_RTA(ifa), rta_len); + + printf("Interface Index: %d, PrefixLen: %d, Scope: %d\n", + ifa->ifa_index, + ifa->ifa_prefixlen, + ifa->ifa_scope); + + char ip_addr_str[INET6_ADDRSTRLEN]; + + // 打印 IFA_LABEL (对应你的 AddrAttr::Label) + if (rta_tb[IFA_LABEL]) { + printf("\tLabel: %s\n", (char *)RTA_DATA(rta_tb[IFA_LABEL])); + } + + // 打印 IFA_ADDRESS (对应你的 AddrAttr::Address) + if (rta_tb[IFA_ADDRESS]) { + inet_ntop(ifa->ifa_family, + RTA_DATA(rta_tb[IFA_ADDRESS]), + ip_addr_str, + sizeof(ip_addr_str)); + printf("\tAddress: %s\n", ip_addr_str); + } + + // 打印 IFA_LOCAL (对应你的 AddrAttr::Local) + if (rta_tb[IFA_LOCAL]) { + inet_ntop(ifa->ifa_family, + RTA_DATA(rta_tb[IFA_LOCAL]), + ip_addr_str, + sizeof(ip_addr_str)); + printf("\tLocal: %s\n", ip_addr_str); + } + printf("----------------------------------------\n"); + } + } + + if (len < 0) { + perror("recv failed"); + } + + 
close(sock_fd); + return EXIT_SUCCESS; +} + +int main(int argc, char *argv[]) { + + printf("=========== STAGE 1: Testing in Default Network Namespace " + "===========\n"); + if (run_netlink_test() != 0) { + fprintf(stderr, "Test failed in the default namespace.\n"); + return EXIT_FAILURE; + } + + printf("\n\n=========== STAGE 2: Creating and Testing in a New Network " + "Namespace ===========\n"); + + // ** 关键步骤:创建新的网络命名空间 ** + // 这个调用会将当前进程移入一个新的、隔离的网络栈中 + if (unshare(CLONE_NEWNET) == -1) { + perror("unshare(CLONE_NEWNET) failed"); + fprintf(stderr, + "This test requires root privileges (e.g., 'sudo " + "./your_program').\n"); + return EXIT_FAILURE; + } + printf("Successfully created and entered a new network namespace.\n"); + + // 在新的命名空间中再次运行同样的测试 + if (run_netlink_test() != 0) { + fprintf(stderr, "Test failed in the new namespace.\n"); + return EXIT_FAILURE; + } + + printf("\nAll tests completed successfully.\n"); + return EXIT_SUCCESS; + + // int sock_fd; + // struct sockaddr_nl sa_nl; + // struct nl_req_t req; + + // // struct iovec iov; + // // struct msghdr msg; + // // struct sockaddr_nl src_addr; // 用于接收发送方的地址 + + // char buf[NL_BUFSIZE]; + // ssize_t len; + // struct nlmsghdr *nlh; + + // // 创建Netlink套接字 + // sock_fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE); + // if (sock_fd < 0) { + // perror("socket creation failed"); + // return EXIT_FAILURE; + // } + + // // 设置Netlink地址 + // memset(&sa_nl, 0, sizeof(sa_nl)); + // sa_nl.nl_family = AF_NETLINK; + // sa_nl.nl_pid = getpid(); // 使用进程ID作为地址 + + // // 绑定Netlink套接字 + // if (bind(sock_fd, (struct sockaddr *)&sa_nl, sizeof(sa_nl)) < 0) { + // perror("socket bind failed"); + // close(sock_fd); + // return EXIT_FAILURE; + // } + + // // 构建RTM_GETADDR请求消息 + // memset(&req, 0, sizeof(req)); + // req.nlh.nlmsg_len = NLMSG_LENGTH(sizeof(struct ifaddrmsg)); + // // 这是关键:设置DUMP标志以获取所有地址 + // req.nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + // req.nlh.nlmsg_type = RTM_GETADDR; + // req.nlh.nlmsg_seq = 1; // 
序列号,用于匹配请求和响应 + // req.nlh.nlmsg_pid = getpid(); + // req.ifa.ifa_family = + // AF_INET; // 只请求 IPv4 地址 (也可以用 AF_UNSPEC 获取所有) + + // // 内核的目标地址 + // struct sockaddr_nl dest_addr; + // memset(&dest_addr, 0, sizeof(dest_addr)); + // dest_addr.nl_family = AF_NETLINK; + // dest_addr.nl_pid = 0; // 0 for the kernel + // dest_addr.nl_groups = 0; // Unicast + + // // 发送请求到内核 + // if (sendto(sock_fd, + // &req, + // req.nlh.nlmsg_len, + // 0, + // (struct sockaddr *)&dest_addr, + // sizeof(dest_addr)) < 0) { + // perror("send failed"); + // close(sock_fd); + // return EXIT_FAILURE; + // } + + // printf("Sent RTM_GETADDR request with DUMP flag.\n\n"); + + // // // 准备 recvmsg 所需的结构体 + // // // iovec 指向我们的主数据缓冲区 + // // iov.iov_base = buf; + // // iov.iov_len = sizeof(buf); + + // // // msghdr 将所有部分组合在一起 + // // msg.msg_name = &src_addr; // 填充发送方地址 + // // msg.msg_namelen = sizeof(src_addr); + // // msg.msg_iov = &iov; // 指向数据缓冲区 + // // msg.msg_iovlen = 1; + // // msg.msg_control = NULL; // 我们暂时不处理控制消息 + // // msg.msg_controllen = 0; + + // // 循环接收响应 + // while ((len = recv(sock_fd, buf, sizeof(buf), 0)) > 0) { + // // ssize_t len = recvmsg(sock_fd, &msg, 0); + + // // if (len < 0) { + // // perror("recvmsg failed"); + // // break; + // // } + + // // if (len == 0) { + // // printf("EOF on netlink socket\n"); + // // break; + // // } + + // // if (msg.msg_flags & MSG_TRUNC) { + // // fprintf( + // // stderr, + // // "Warning: Message was truncated. 
Buffer may be too small.\n"); + // // } + + + // // 使用 NLMSG_OK 遍历缓冲区中可能存在的多条消息 + // for (nlh = (struct nlmsghdr *)buf; NLMSG_OK(nlh, len); + // nlh = NLMSG_NEXT(nlh, len)) { + + // // 如果是 DUMP 结束的标志,则退出循环 + // if (nlh->nlmsg_type == NLMSG_DONE) { + // printf("--- End of DUMP ---\n"); + // close(sock_fd); + // return EXIT_SUCCESS; + // } + + // // 如果是错误消息 + // if (nlh->nlmsg_type == NLMSG_ERROR) { + // struct nlmsgerr *err = (struct nlmsgerr *)NLMSG_DATA(nlh); + // fprintf(stderr, + // "Netlink error received: %s\n", + // strerror(-err->error)); + // close(sock_fd); + // return EXIT_FAILURE; + // } + + // // 只处理我们期望的 RTM_NEWADDR 消息 + // if (nlh->nlmsg_type != RTM_NEWADDR) { + // printf("Received unexpected message type: %d\n", + // nlh->nlmsg_type); + // continue; + // } + + // // printf("Received message from PID: %u\n", src_addr.nl_pid); + + // struct ifaddrmsg *ifa = (struct ifaddrmsg *)NLMSG_DATA(nlh); + // struct rtattr *rta_tb[IFA_MAX + 1]; + + // // 解析消息中的路由属性 + // int rta_len = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(*ifa)); + // parse_rtattr(rta_tb, IFA_MAX, IFA_RTA(ifa), rta_len); + + // printf("Interface Index: %d, PrefixLen: %d, Scope: %d\n", + // ifa->ifa_index, + // ifa->ifa_prefixlen, + // ifa->ifa_scope); + + // char ip_addr_str[INET6_ADDRSTRLEN]; + + // // 打印 IFA_LABEL (对应你的 AddrAttr::Label) + // if (rta_tb[IFA_LABEL]) { + // printf("\tLabel: %s\n", (char *)RTA_DATA(rta_tb[IFA_LABEL])); + // } + + // // 打印 IFA_ADDRESS (对应你的 AddrAttr::Address) + // if (rta_tb[IFA_ADDRESS]) { + // inet_ntop(ifa->ifa_family, + // RTA_DATA(rta_tb[IFA_ADDRESS]), + // ip_addr_str, + // sizeof(ip_addr_str)); + // printf("\tAddress: %s\n", ip_addr_str); + // } + + // // 打印 IFA_LOCAL (对应你的 AddrAttr::Local) + // if (rta_tb[IFA_LOCAL]) { + // inet_ntop(ifa->ifa_family, + // RTA_DATA(rta_tb[IFA_LOCAL]), + // ip_addr_str, + // sizeof(ip_addr_str)); + // printf("\tLocal: %s\n", ip_addr_str); + // } + // printf("----------------------------------------\n"); + // } + // } + + + // 
close(sock_fd); + // return EXIT_SUCCESS; +} \ No newline at end of file diff --git a/user/apps/c_unitest/test_router.c b/user/apps/c_unitest/test_router.c new file mode 100644 index 000000000..d49bcc674 --- /dev/null +++ b/user/apps/c_unitest/test_router.c @@ -0,0 +1,170 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define SERVER_IP "192.168.2.3" +#define FAKE_SERVER_IP "192.168.2.1" +#define CLIENT_IP "192.168.1.1" +#define PORT 34254 +#define BUFFER_SIZE 1024 + +// 错误处理函数 +void handle_error_message(const char *message) { + perror(message); + exit(EXIT_FAILURE); +} + +// 服务器线程函数 +void *server_func(void *arg) { + int sockfd; + struct sockaddr_in server_addr, client_addr; + char buffer[BUFFER_SIZE]; + socklen_t client_len = sizeof(client_addr); + + if ((sockfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) { + handle_error_message("[server] Failed to create socket"); + } + + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + if (inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr) <= 0) { + handle_error_message("[server] Invalid server IP address"); + } + + if (bind(sockfd, + (const struct sockaddr *)&server_addr, + sizeof(server_addr)) < 0) { + handle_error_message("[server] Failed to bind to " SERVER_IP); + } + printf("[server] Listening on %s:%d\n", SERVER_IP, PORT); + + ssize_t n = recvfrom(sockfd, + buffer, + BUFFER_SIZE, + 0, + (struct sockaddr *)&client_addr, + &client_len); + if (n < 0) { + handle_error_message("[server] Failed to receive"); + } + buffer[n] = '\0'; // 确保字符串正确终止 + + // //debug + // unsigned char *ip_bytes = (unsigned char *)&client_addr.sin_addr.s_addr; + // printf("[DEBUG] Raw IP bytes received: %d.%d.%d.%d\n", + // ip_bytes[0], + // ip_bytes[1], + // ip_bytes[2], + // ip_bytes[3]); + + char client_ip_str[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &client_addr.sin_addr, client_ip_str, INET_ADDRSTRLEN); + printf("[server] Received 
from %s:%d: %s\n", + client_ip_str, + ntohs(client_addr.sin_port), + buffer); + + if (sendto(sockfd, + buffer, + n, + 0, + (const struct sockaddr *)&client_addr, + client_len) < 0) { + handle_error_message("[server] Failed to send back"); + } + sleep(1); + printf("[server] Echoed back the message\n"); + + close(sockfd); + printf("server goning to exit\n"); + return NULL; +} + +// 客户端线程函数 +void *client_func(void *arg) { + int sockfd; + struct sockaddr_in client_addr, server_addr; + char buffer[BUFFER_SIZE]; + const char *msg = "Hello from veth1!"; + + if ((sockfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) { + handle_error_message("[client] Failed to create socket"); + } + + memset(&client_addr, 0, sizeof(client_addr)); + client_addr.sin_family = AF_INET; + client_addr.sin_port = htons(0); // 端口为0,由操作系统自动选择 + if (inet_pton(AF_INET, CLIENT_IP, &client_addr.sin_addr) <= 0) { + handle_error_message("[client] Invalid client IP address"); + } + + if (bind(sockfd, + (const struct sockaddr *)&client_addr, + sizeof(client_addr)) < 0) { + handle_error_message("[client] Failed to bind to " CLIENT_IP); + } + + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + if (inet_pton(AF_INET, FAKE_SERVER_IP, &server_addr.sin_addr) <= 0) { + handle_error_message("[client] Invalid server IP address for connect"); + } + + if (connect(sockfd, + (const struct sockaddr *)&server_addr, + sizeof(server_addr)) < 0) { + handle_error_message("[client] Failed to connect"); + } + + if (send(sockfd, msg, strlen(msg), 0) < 0) { + handle_error_message("[client] Failed to send"); + } + printf("[client] Sent: %s\n", msg); + + ssize_t n = recv(sockfd, buffer, BUFFER_SIZE, 0); + if (n < 0) { + handle_error_message("[client] Failed to receive"); + } + buffer[n] = '\0'; // 确保字符串正确终止 + + printf("[client] Received echo: %s\n", buffer); + + assert(strcmp(msg, buffer) == 0 && "[client] Mismatch in echo!"); + + close(sockfd); + printf("client 
goning to exit\n"); + return NULL; +} + +int main() { + pthread_t server_tid, client_tid; + + if (pthread_create(&server_tid, NULL, server_func, NULL) != 0) { + handle_error_message("Failed to create server thread"); + } + + usleep(200 * 1000); // 200 milliseconds + + if (pthread_create(&client_tid, NULL, client_func, NULL) != 0) { + handle_error_message("Failed to create client thread"); + } + + if (pthread_join(server_tid, NULL) != 0) { + handle_error_message("Failed to join server thread"); + } + if (pthread_join(client_tid, NULL) != 0) { + handle_error_message("Failed to join client thread"); + } + + printf("\nTest completed: veth_a <--> veth_d UDP communication success\n"); + + return EXIT_SUCCESS; +} \ No newline at end of file diff --git a/user/apps/c_unitest/test_veth_bridge.c b/user/apps/c_unitest/test_veth_bridge.c new file mode 100644 index 000000000..e0b20720e --- /dev/null +++ b/user/apps/c_unitest/test_veth_bridge.c @@ -0,0 +1,169 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define SERVER_IP "200.0.0.4" +#define CLIENT_IP "200.0.0.1" +#define PORT 34254 +#define BUFFER_SIZE 1024 + +// 错误处理函数 +void handle_error_message(const char *message) { + perror(message); + exit(EXIT_FAILURE); +} + +// 服务器线程函数 +void *server_func(void *arg) { + int sockfd; + struct sockaddr_in server_addr, client_addr; + char buffer[BUFFER_SIZE]; + socklen_t client_len = sizeof(client_addr); + + if ((sockfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) { + handle_error_message("[server] Failed to create socket"); + } + + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + if (inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr) <= 0) { + handle_error_message("[server] Invalid server IP address"); + } + + if (bind(sockfd, + (const struct sockaddr *)&server_addr, + sizeof(server_addr)) < 0) { + handle_error_message("[server] Failed to bind to " SERVER_IP); + } + 
printf("[server] Listening on %s:%d\n", SERVER_IP, PORT); + + ssize_t n = recvfrom(sockfd, + buffer, + BUFFER_SIZE, + 0, + (struct sockaddr *)&client_addr, + &client_len); + if (n < 0) { + handle_error_message("[server] Failed to receive"); + } + buffer[n] = '\0'; // 确保字符串正确终止 + + // //debug + // unsigned char *ip_bytes = (unsigned char *)&client_addr.sin_addr.s_addr; + // printf("[DEBUG] Raw IP bytes received: %d.%d.%d.%d\n", + // ip_bytes[0], + // ip_bytes[1], + // ip_bytes[2], + // ip_bytes[3]); + + char client_ip_str[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &client_addr.sin_addr, client_ip_str, INET_ADDRSTRLEN); + printf("[server] Received from %s:%d: %s\n", + client_ip_str, + ntohs(client_addr.sin_port), + buffer); + + if (sendto(sockfd, + buffer, + n, + 0, + (const struct sockaddr *)&client_addr, + client_len) < 0) { + handle_error_message("[server] Failed to send back"); + } + // sleep(5); + printf("[server] Echoed back the message\n"); + + close(sockfd); + printf("server goning to exit\n"); + return NULL; +} + +// 客户端线程函数 +void *client_func(void *arg) { + int sockfd; + struct sockaddr_in client_addr, server_addr; + char buffer[BUFFER_SIZE]; + const char *msg = "Hello from veth1!"; + + if ((sockfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) { + handle_error_message("[client] Failed to create socket"); + } + + memset(&client_addr, 0, sizeof(client_addr)); + client_addr.sin_family = AF_INET; + client_addr.sin_port = htons(0); // 端口为0,由操作系统自动选择 + if (inet_pton(AF_INET, CLIENT_IP, &client_addr.sin_addr) <= 0) { + handle_error_message("[client] Invalid client IP address"); + } + + if (bind(sockfd, + (const struct sockaddr *)&client_addr, + sizeof(client_addr)) < 0) { + handle_error_message("[client] Failed to bind to " CLIENT_IP); + } + + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + if (inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr) <= 0) { + handle_error_message("[client] Invalid server 
IP address for connect"); + } + + if (connect(sockfd, + (const struct sockaddr *)&server_addr, + sizeof(server_addr)) < 0) { + handle_error_message("[client] Failed to connect"); + } + + if (send(sockfd, msg, strlen(msg), 0) < 0) { + handle_error_message("[client] Failed to send"); + } + printf("[client] Sent: %s\n", msg); + + ssize_t n = recv(sockfd, buffer, BUFFER_SIZE, 0); + if (n < 0) { + handle_error_message("[client] Failed to receive"); + } + buffer[n] = '\0'; // 确保字符串正确终止 + + printf("[client] Received echo: %s\n", buffer); + + assert(strcmp(msg, buffer) == 0 && "[client] Mismatch in echo!"); + + close(sockfd); + printf("client goning to exit\n"); + return NULL; +} + +int main() { + pthread_t server_tid, client_tid; + + if (pthread_create(&server_tid, NULL, server_func, NULL) != 0) { + handle_error_message("Failed to create server thread"); + } + + usleep(200 * 1000); // 200 milliseconds + + if (pthread_create(&client_tid, NULL, client_func, NULL) != 0) { + handle_error_message("Failed to create client thread"); + } + + if (pthread_join(server_tid, NULL) != 0) { + handle_error_message("Failed to join server thread"); + } + if (pthread_join(client_tid, NULL) != 0) { + handle_error_message("Failed to join client thread"); + } + + printf("\nTest completed: veth_a <--> veth_d UDP communication success\n"); + + return EXIT_SUCCESS; +} \ No newline at end of file diff --git a/user/apps/test-router/.gitignore b/user/apps/test-router/.gitignore new file mode 100644 index 000000000..1ac354611 --- /dev/null +++ b/user/apps/test-router/.gitignore @@ -0,0 +1,3 @@ +/target +Cargo.lock +/install/ \ No newline at end of file diff --git a/user/apps/test-router/Cargo.toml b/user/apps/test-router/Cargo.toml new file mode 100644 index 000000000..79599a1ca --- /dev/null +++ b/user/apps/test-router/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "test-router" +version = "0.1.0" +edition = "2021" +description = "测试路由功能" +authors = [ "sparkzky " ] + +# See more keys and their 
# Build helper for the test-router application: thin wrappers around cargo.
#
# TOOLCHAIN  extra argument placed right after `cargo` (empty by default)
# RUSTFLAGS  flags forwarded to rustc through the environment (empty by default)
# ARCH       selects the Rust target triple (x86_64 or riscv64)
TOOLCHAIN=
RUSTFLAGS=

ifdef DADK_CURRENT_BUILD_DIR
# Building under dadk: install into dadk's build directory.
    INSTALL_DIR = $(DADK_CURRENT_BUILD_DIR)
else
# Building locally: install into ./install next to this Makefile.
    INSTALL_DIR = ./install
endif

ifeq ($(ARCH), x86_64)
    export RUST_TARGET=x86_64-unknown-linux-musl
else ifeq ($(ARCH), riscv64)
    export RUST_TARGET=riscv64gc-unknown-linux-gnu
else
# Default to x86_64, used for local builds.
    export RUST_TARGET=x86_64-unknown-linux-musl
endif

# None of these targets produces a file of the same name. Declaring all of
# them phony keeps a stray `build`, `doc` or `test` file or directory from
# making make skip the recipe.
.PHONY: run build clean test doc fmt fmt-check \
        run-release build-release clean-release test-release install

# RUSTFLAGS is quoted in every recipe so that a value containing spaces is
# passed to cargo as one environment variable instead of being split by the
# shell.
run:
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(TOOLCHAIN) run --target $(RUST_TARGET)

build:
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(TOOLCHAIN) build --target $(RUST_TARGET)

clean:
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(TOOLCHAIN) clean --target $(RUST_TARGET)

test:
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(TOOLCHAIN) test --target $(RUST_TARGET)

doc:
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(TOOLCHAIN) doc --target $(RUST_TARGET)

fmt:
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(TOOLCHAIN) fmt

fmt-check:
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(TOOLCHAIN) fmt --check

run-release:
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(TOOLCHAIN) run --target $(RUST_TARGET) --release

build-release:
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(TOOLCHAIN) build --target $(RUST_TARGET) --release

clean-release:
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(TOOLCHAIN) clean --target $(RUST_TARGET) --release

test-release:
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(TOOLCHAIN) test --target $(RUST_TARGET) --release

install:
	RUSTFLAGS="$(RUSTFLAGS)" cargo $(TOOLCHAIN) install --target $(RUST_TARGET) --path . --no-track --root $(INSTALL_DIR) --force