diff --git a/src/interface/comm.rs b/src/interface/comm.rs index 23cf655a..67c57360 100644 --- a/src/interface/comm.rs +++ b/src/interface/comm.rs @@ -639,3 +639,10 @@ pub fn kernel_select( return result; } + +pub fn get_loopback_path(port: u16) -> String { + let mut path = String::from("tmp/loopback"); + path.push_str(&port.to_string()); + + path +} \ No newline at end of file diff --git a/src/safeposix/syscalls/net_calls.rs b/src/safeposix/syscalls/net_calls.rs index 90f738d5..3eea7aab 100644 --- a/src/safeposix/syscalls/net_calls.rs +++ b/src/safeposix/syscalls/net_calls.rs @@ -56,6 +56,8 @@ use super::net_constants::*; use super::sys_constants::*; use crate::interface; use crate::interface::errnos::{syscall_error, Errno}; +use crate::interface::GenIpaddr; +use crate::interface::GenSockaddr; use crate::safeposix::cage::{FileDescriptor::*, *}; use crate::safeposix::filesystem::*; use crate::safeposix::net::*; @@ -445,7 +447,32 @@ impl Cage { let res = match sockhandle.domain { AF_UNIX => self.bind_inner_socket_unix(sockhandle, &mut newsockaddr), AF_INET | AF_INET6 => { - self.bind_inner_socket_inet(sockhandle, &mut newsockaddr, prereserved) + let loopback_addr = u32::from_ne_bytes([127, 0, 0, 7]); + match newsockaddr.addr() { + interface::GenIpaddr::V4(addr) => { + if addr.s_addr == loopback_addr { + // this is the loopback address, and we need to fake it into a domain socket + + let path = interface::get_loopback_path(newsockaddr.port()); + newsockaddr = GenSockaddr::Unix(interface::new_sockaddr_unix(AF_UNIX as u16, path.as_bytes())); + + sockhandle.domain = AF_UNIX; + + // self.bind_inner_socket_inet(sockhandle, &mut newsockaddr, prereserved) + let ret = self.bind_inner_socket_unix(sockhandle, &mut newsockaddr); + + ret + + } else { + // not loopback adddress + self.bind_inner_socket_inet(sockhandle, &mut newsockaddr, prereserved) + } + }, + _ => { + // we do not want to handle ipv6 loopback for now + self.bind_inner_socket_inet(sockhandle, &mut newsockaddr, prereserved) + } + } } _ => { return syscall_error(Errno::EINVAL, "bind", "Unsupported domain provided"); @@ -890,6 +917,7 @@ impl Cage { //by other threads accessing other fields let sock_tmp = sockfdobj.handle.clone(); let mut sockhandle = sock_tmp.write(); + //Possible address families are Unix, V4, V6 //Error occurs if remoteaddr's address family does not match //the domain of the socket pointed to by fd @@ -901,12 +929,27 @@ impl Cage { ); } + let mut remoteaddr_final = remoteaddr.clone(); + + if remoteaddr.get_family() == AF_INET as u16 { + let loopback_addr = u32::from_ne_bytes([127, 0, 0, 7]); + let addr = remoteaddr.addr(); + if let interface::GenIpaddr::V4(addr) = addr { + if addr.s_addr == loopback_addr { + sockhandle.domain = AF_UNIX; + let path = interface::get_loopback_path(remoteaddr.port()); + + remoteaddr_final = GenSockaddr::Unix(interface::new_sockaddr_unix(AF_UNIX as u16, path.as_bytes())); + } + } + } + match sockhandle.protocol { IPPROTO_UDP => { - return self.connect_udp(&mut *sockhandle, sockfdobj, remoteaddr) + return self.connect_udp(&mut *sockhandle, sockfdobj, &remoteaddr_final) } IPPROTO_TCP => { - return self.connect_tcp(&mut *sockhandle, sockfdobj, remoteaddr) + return self.connect_tcp(&mut *sockhandle, sockfdobj, &remoteaddr_final) } _ => { return syscall_error( diff --git a/src/tests/ipc_tests.rs b/src/tests/ipc_tests.rs index 514cac18..804dbf17 100644 --- a/src/tests/ipc_tests.rs +++ b/src/tests/ipc_tests.rs @@ -352,6 +352,94 @@ pub mod ipc_tests { lindrustfinalize(); } + #[test] + pub fn ut_lind_ipc_loopback_socket() { + // acquiring a lock on TESTMUTEX prevents other tests from running concurrently, + // and also performs clean env setup + let _thelock = setup::lock_and_init(); + + let cage = interface::cagetable_getref(1); + + let serversockfd = cage.socket_syscall(AF_INET, SOCK_STREAM, 0); + + // create a INET address + let port: u16 = generate_random_port(); + + let sockaddr = interface::SockaddrV4 { + sin_family: AF_INET as u16, + sin_port: port.to_le(), + sin_addr: interface::V4Addr { + s_addr: u32::from_ne_bytes([127, 0, 0, 7]), + }, + padding: 0, + }; + let socket = interface::GenSockaddr::V4(sockaddr); //127.0.0.7 from bytes above + + assert_eq!(cage.bind_syscall(serversockfd, &socket), 0); + assert_eq!(cage.listen_syscall(serversockfd, 4), 0); + + let barrier = interface::RustRfc::new(std::sync::Barrier::new(2)); + let barrier_clone = barrier.clone(); + + assert_eq!(cage.fork_syscall(2), 0); // used for pipe thread + + // client 1 connects to the server to send and recv data + let threadclient = interface::helper_thread(move || { + let cage2 = interface::cagetable_getref(2); + // assert_eq!(cage2.close_syscall(serversockfd), 0); + + let clientsockfd = cage2.socket_syscall(AF_INET, SOCK_STREAM, 0); + + // connect to server + assert_eq!(cage2.connect_syscall(clientsockfd, &socket), 0); + + // send message to server + assert_eq!(cage2.send_syscall(clientsockfd, str2cbuf("test"), 4, 0), 4); + + interface::sleep(interface::RustDuration::from_millis(1)); + + // receive message from server + let mut buf = sizecbuf(4); + assert_eq!(cage2.recv_syscall(clientsockfd, buf.as_mut_ptr(), 4, 0), 4); + assert_eq!(cbuf2str(&buf), "test"); + + assert_eq!(cage2.close_syscall(clientsockfd), 0); + + barrier_clone.wait(); + + cage2.exit_syscall(EXIT_SUCCESS); + }); + + let mut sockgarbage = + interface::GenSockaddr::V4(interface::SockaddrV4::default()); + + let sockfd = cage.accept_syscall(serversockfd as i32, &mut sockgarbage); + assert!(sockfd > 0); + + let mut buf = sizecbuf(4); + let mut recvresult: i32; + loop { + // receive message from peer + recvresult = cage.recv_syscall(sockfd as i32, buf.as_mut_ptr(), 4, 0); + if recvresult != -libc::EINTR { + break; // if the error was EINTR, retry the + // syscall + } + } + + assert!(cbuf2str(&buf) == "test"); + + // send message to server + assert_eq!(cage.send_syscall(sockfd, str2cbuf("test"), 4, 0), 4); + + barrier.wait(); + + threadclient.join().unwrap(); + + assert_eq!(cage.exit_syscall(EXIT_SUCCESS), EXIT_SUCCESS); + lindrustfinalize(); + } + // support for retrying writes in case the system doesn't write all bytes at // once #[test]