diff --git a/beamsim.py b/beamsim.py index 6025263..703444f 100644 --- a/beamsim.py +++ b/beamsim.py @@ -66,13 +66,13 @@ def __init__(self, items): run_exe_time = None -def run(b="ns3", t="direct", g=10, gv=10, shuffle=False, mpi=False): +def run(b="ns3-direct", t="direct", p="tcp", g=10, gv=10, shuffle=False, mpi=False): global run_exe_time exe_time = os.stat(exe).st_mtime if run_exe_time != exe_time: run_exe_time = exe_time run_cache.clear() - key = (b, t, g, gv, shuffle, mpi) + key = (b, t, p, g, gv, shuffle, mpi) output = run_cache.get(key, None) if output is None: cmd = [ @@ -86,6 +86,8 @@ def run(b="ns3", t="direct", g=10, gv=10, shuffle=False, mpi=False): b, "-t", t, + "-p", + p, "-g", str(g), "-gv", diff --git a/cli.hpp b/cli.hpp index f34da6b..94ad430 100644 --- a/cli.hpp +++ b/cli.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -349,6 +350,10 @@ struct SimulationConfig { {Topology::GOSSIP, "gossip"}, {Topology::GRID, "grid"}, }}; + const Args::Enum enum_protocol_{{ + {beamsim::ns3_::Protocol::TCP, "tcp"}, + {beamsim::ns3_::Protocol::UDP, "udp"}, + }}; beamsim::example::RolesConfig roles_config; @@ -372,6 +377,13 @@ struct SimulationConfig { "Communication topology", enum_topology_, }; + beamsim::ns3_::Protocol protocol = beamsim::ns3_::Protocol::TCP; + Args::FlagEnum flag_protocol{ + {"-p", "--protocol"}, + protocol, + "Network protocol (TCP/UDP)", + enum_protocol_, + }; Args::FlagInt flag_group_count{{ {"-g", "--groups"}, roles_config.group_count, @@ -394,6 +406,7 @@ struct SimulationConfig { return f(flag_config_path, flag_backend, flag_topology, + flag_protocol, flag_group_count, flag_validators_per_group, flag_shuffle, @@ -424,6 +437,7 @@ struct SimulationConfig { Yaml yaml{YAML::LoadFile(config_path)}; yaml.at({"backend"}).get(backend, enum_backend_); yaml.at({"topology"}).get(topology, enum_topology_); + yaml.at({"protocol"}).get(protocol, enum_protocol_); yaml.at({"shuffle"}).get(shuffle); yaml.at({"roles", "group_count"}).get(roles_config.group_count); @@ -466,6 +480,7 @@ struct SimulationConfig { std::println("Configuration:"); std::println(" Backend: {}", enum_backend_.str(backend)); std::println(" Topology: {}", enum_topology_.str(topology)); + std::println(" Protocol: {}", enum_protocol_.str(protocol)); std::println(" Groups: {}", roles_config.group_count); std::println(" Validators per group: {}", roles_config.group_validator_count); diff --git a/example.yaml b/example.yaml index 0f0af7c..12f0014 100644 --- a/example.yaml +++ b/example.yaml @@ -1,5 +1,6 @@ backend: ns3 topology: direct +protocol: tcp # Use tcp or udp roles: group_count: 4 diff --git a/main.cpp b/main.cpp index 5e96162..d7d0716 100644 --- a/main.cpp +++ b/main.cpp @@ -597,6 +597,7 @@ void run_simulation(const SimulationConfig &config) { case SimulationConfig::Backend::NS3_DIRECT: { #ifdef ns3_FOUND beamsim::ns3_::Simulator simulator{metrics_ptr}; + simulator.setProtocol(config.protocol); if (config.backend == SimulationConfig::Backend::NS3) { simulator.routing_.initRouters(routers); } else { diff --git a/src/beamsim/ns3/protocol.hpp b/src/beamsim/ns3/protocol.hpp new file mode 100644 index 0000000..9dd3093 --- /dev/null +++ b/src/beamsim/ns3/protocol.hpp @@ -0,0 +1,8 @@ +#pragma once + +namespace beamsim::ns3_ { + enum class Protocol { + TCP, + UDP, + }; +} // namespace beamsim::ns3_ diff --git a/src/beamsim/ns3/simulator.hpp b/src/beamsim/ns3/simulator.hpp index c3cd523..17daef8 100644 --- a/src/beamsim/ns3/simulator.hpp +++ b/src/beamsim/ns3/simulator.hpp @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -178,42 +179,27 @@ namespace beamsim::ns3_ { peer_->onStart(); } - SocketPtr makeSocket() { - auto socket = ns3::Socket::CreateSocket( + ns3::InetSocketAddress peerAddress(PeerIndex peer_index) const; + PeerIndex peerIndex(const ns3::Address &address) const; + SocketPtr makeTcpSocket() { + return ns3::Socket::CreateSocket( GetNode(), ns3::TypeId::LookupByName("ns3::TcpSocketFactory")); - return socket; - } - void listen() { - tcp_listener_ = makeSocket(); - tcp_listener_->Bind(ns3::InetSocketAddress{ - ns3::Ipv4Address::GetAny(), - kPort, - }); - tcp_listener_->Listen(); - tcp_listener_->SetAcceptCallback( - ns3::MakeNullCallback(), - ns3::MakeCallback(&Application::onAccept, this)); } + void listen(); void onAccept(SocketPtr socket, const ns3::Address &address); void add(PeerIndex peer_index, SocketPtr socket) { - socket->SetRecvCallback(MakeCallback(&Application::pollRead, this)); - socket->SetSendCallback(MakeCallback(&Application::pollWrite, this)); + socket->SetRecvCallback(MakeCallback(&Application::pollReadTcp, this)); + socket->SetSendCallback(MakeCallback(&Application::pollWriteTcp, this)); tcp_socket_state_.emplace(socket, SocketState{peer_index}); } - void pollRead(SocketPtr socket); - void pollWrite(SocketPtr socket, uint32_t = 0) { - auto &state = tcp_socket_state_.at(socket); - while (not state.writing.empty()) { - size_t available = socket->GetTxAvailable(); - if (available == 0) { - break; - } - auto packet = state.writing.readPacket(available); - assert2(socket->Send(packet) == static_cast(packet->GetSize())); - } - } + void pollReadTcp(SocketPtr socket); + void pollReadUdp(SocketPtr); + void onPacket(ns3::Ptr packet, SocketState &state); + bool pollWrite(SocketPtr socket, SocketState &state); + void pollWriteTcp(SocketPtr socket, uint32_t = 0); + void pollWriteUdp(SocketPtr, uint32_t = 0); void onConnect(SocketPtr socket) { - pollWrite(socket); + pollWriteTcp(socket); } void onConnectError(SocketPtr) { abort(); @@ -221,26 +207,27 @@ namespace beamsim::ns3_ { void connect(PeerIndex peer_index); void send(PeerIndex peer_index, std::optional message_id, - const IMessage &message) { - assert2(peer_index != peer_->peer_index_); - auto &sockets = tcp_sockets_[peer_index]; - auto connected = sockets.write() != nullptr; - if (not connected) { - connect(peer_index); - } - auto &socket = sockets.write(); - auto &state = tcp_socket_state_.at(socket); - state.writing.write(message_id, message); - if (connected) { - pollWrite(socket); + const IMessage &message); + SocketState &udpSocketState(PeerIndex peer_index) { + auto it = udp_socket_state_.find(peer_index); + if (it == udp_socket_state_.end()) { + it = udp_socket_state_.emplace(peer_index, SocketState{peer_index}) + .first; } + return it->second; } + public: Simulator &simulator_; std::unique_ptr peer_; + + private: SocketPtr tcp_listener_; std::unordered_map tcp_sockets_; std::unordered_map tcp_socket_state_; + SocketPtr udp_socket_; + std::unordered_map udp_socket_state_; + std::unordered_set udp_writing_; Bytes reading_; }; @@ -252,6 +239,10 @@ namespace beamsim::ns3_ { } } + void setProtocol(Protocol protocol) { + protocol_ = protocol; + } + // ISimulator ~Simulator() override { ns3::Simulator::Destroy(); @@ -367,6 +358,7 @@ namespace beamsim::ns3_ { } IMetrics *metrics_; + Protocol protocol_ = Protocol::TCP; bool cache_messages_ = true; std::vector> applications_; Routing routing_; @@ -375,9 +367,49 @@ namespace beamsim::ns3_ { MessageDecodeFn message_decode_; }; - void Application::onAccept(SocketPtr socket, const ns3::Address &address) { - auto index = simulator_.routing_.ip_peer_index_.at( + // Application method implementations + ns3::InetSocketAddress Application::peerAddress(PeerIndex peer_index) const { + return ns3::InetSocketAddress{ + simulator_.routing_.peer_ips_.at(peer_index), + kPort, + }; + } + + PeerIndex Application::peerIndex(const ns3::Address &address) const { + return simulator_.routing_.ip_peer_index_.at( ns3::InetSocketAddress::ConvertFrom(address).GetIpv4()); + } + + void Application::listen() { + ns3::InetSocketAddress bind{ + ns3::Ipv4Address::GetAny(), + kPort, + }; + switch (simulator_.protocol_) { + case Protocol::TCP: { + tcp_listener_ = makeTcpSocket(); + assert2(tcp_listener_->Bind(bind) != -1); + assert2(tcp_listener_->Listen() != -1); + tcp_listener_->SetAcceptCallback( + ns3::MakeNullCallback(), + ns3::MakeCallback(&Application::onAccept, this)); + break; + } + case Protocol::UDP: { + udp_socket_ = ns3::Socket::CreateSocket( + GetNode(), ns3::TypeId::LookupByName("ns3::UdpSocketFactory")); + assert2(udp_socket_->Bind(bind) != -1); + udp_socket_->SetRecvCallback( + ns3::MakeCallback(&Application::pollReadUdp, this)); + udp_socket_->SetSendCallback( + ns3::MakeCallback(&Application::pollWriteUdp, this)); + break; + } + } + } + + void Application::onAccept(SocketPtr socket, const ns3::Address &address) { + auto index = peerIndex(address); auto &sockets = tcp_sockets_[index]; assert2(not sockets.in); sockets.in = socket; @@ -385,78 +417,133 @@ namespace beamsim::ns3_ { add(index, socket); } - void Application::pollRead(SocketPtr socket) { + void Application::pollReadTcp(SocketPtr socket) { auto &state = tcp_socket_state_.at(socket); while (auto packet = socket->Recv()) { - reading_.resize(packet->GetSize()); - packet->CopyData(reading_.data(), reading_.size()); - state.reading.buffer_.write(reading_); - while (true) { - auto &buffer = state.reading.buffer_; - if (state.reading.frame_.has_value()) { - auto &frame = state.reading.frame_.value(); - auto want_data = frame.header.data_size.value() - frame.data.size(); - if (want_data > 0) { - auto n = std::min(want_data, buffer.size()); - if (n == 0) { - break; - } - frame.data.resize(frame.data.size() + n); - buffer.peek(std::span{frame.data}.subspan(frame.data.size() - n)); - buffer.read(n); - continue; - } - if (frame.padding > 0) { - auto n = std::min(frame.padding, buffer.size()); - if (n == 0) { - break; - } - buffer.read(n); - frame.padding -= n; - continue; - } - auto item = std::exchange(state.reading.frame_, {}).value(); - MessagePtr message; - if (simulator_.cache_messages_ - and simulator_.isLocalPeer(state.peer_index)) { - auto node = - simulator_.messages_.extract(item.header.message_id.value()); - assert2(node); - message = std::move(node.mapped()); - } else { - MessageDecodeFrom data{item.data}; - message = simulator_.message_decode_(data); + onPacket(packet, state); + } + } + + void Application::pollReadUdp(SocketPtr) { + ns3::Address from; + while (auto packet = udp_socket_->RecvFrom(from)) { + onPacket(packet, udpSocketState(peerIndex(from))); + } + } + + void Application::onPacket(ns3::Ptr packet, SocketState &state) { + reading_.resize(packet->GetSize()); + packet->CopyData(reading_.data(), reading_.size()); + state.reading.buffer_.write(reading_); + while (true) { + auto &buffer = state.reading.buffer_; + if (state.reading.frame_.has_value()) { + auto &frame = state.reading.frame_.value(); + auto want_data = frame.header.data_size.value() - frame.data.size(); + if (want_data > 0) { + auto n = std::min(want_data, buffer.size()); + if (n == 0) { + break; } - if (simulator_.metrics_ != nullptr) { - simulator_.metrics_->onPeerReceivedMessage(state.peer_index); + frame.data.resize(frame.data.size() + n); + buffer.peek(std::span{frame.data}.subspan(frame.data.size() - n)); + buffer.read(n); + continue; + } + if (frame.padding > 0) { + auto n = std::min(frame.padding, buffer.size()); + if (n == 0) { + break; } - peer_->onMessage(state.peer_index, std::move(message)); + buffer.read(n); + frame.padding -= n; continue; } - Header header; - if (buffer.size() < header.size()) { + auto item = std::exchange(state.reading.frame_, {}).value(); + MessagePtr message; + if (simulator_.cache_messages_ + and simulator_.isLocalPeer(state.peer_index)) { + auto node = + simulator_.messages_.extract(item.header.message_id.value()); + assert2(node); + message = std::move(node.mapped()); + } else { + MessageDecodeFrom data{item.data}; + message = simulator_.message_decode_(data); + } + if (simulator_.metrics_ != nullptr) { + simulator_.metrics_->onPeerReceivedMessage(state.peer_index); + } + peer_->onMessage(state.peer_index, std::move(message)); + continue; + } + Header header; + if (buffer.size() < header.size()) { + break; + } + buffer.read(header.message_size.bytes); + buffer.read(header.message_id.bytes); + buffer.read(header.data_size.bytes); + state.reading.frame_.emplace( + header, + Bytes{}, + header.message_size.value() - header.data_size.value()); + } + } + + bool Application::pollWrite(SocketPtr socket, SocketState &state) { + auto any = false; + while (not state.writing.empty()) { + size_t available = socket->GetTxAvailable(); + if (available == 0) { + break; + } + auto packet = state.writing.readPacket(available); + int r = 0; + switch (simulator_.protocol_) { + case Protocol::TCP: { + r = socket->Send(packet) == static_cast(packet->GetSize()); + break; + } + case Protocol::UDP: { + r = socket->SendTo(packet, 0, peerAddress(state.peer_index)); break; } - buffer.read(header.message_size.bytes); - buffer.read(header.message_id.bytes); - buffer.read(header.data_size.bytes); - state.reading.frame_.emplace( - header, - Bytes{}, - header.message_size.value() - header.data_size.value()); + } + assert2(r != -1); + any = true; + } + return any; + } + + void Application::pollWriteTcp(SocketPtr socket, uint32_t) { + pollWrite(socket, tcp_socket_state_.at(socket)); + } + + void Application::pollWriteUdp(SocketPtr, uint32_t) { + auto available = true; + while (not udp_writing_.empty() and available) { + auto peer_index = *udp_writing_.begin(); + auto &state = udp_socket_state_.at(peer_index); + available = pollWrite(udp_socket_, state); + if (state.writing.empty()) { + udp_writing_.erase(peer_index); } } } void Application::connect(PeerIndex peer_index) { + if (simulator_.protocol_ == Protocol::UDP) { + return; // UDP is connectionless, no need to connect + } + auto &sockets = tcp_sockets_[peer_index]; assert2(not sockets.out); if (sockets.write()) { return; } - auto socket = makeSocket(); - socket->Connect(ns3::InetSocketAddress{ - simulator_.routing_.peer_ips_.at(peer_index), kPort}); + auto socket = makeTcpSocket(); + socket->Connect(peerAddress(peer_index)); sockets.out = socket; sockets.write_out = true; socket->SetConnectCallback( @@ -464,4 +551,33 @@ namespace beamsim::ns3_ { MakeCallback(&Application::onConnectError, this)); add(peer_index, socket); } + + void Application::send(PeerIndex peer_index, + std::optional message_id, + const IMessage &message) { + assert2(peer_index != peer_->peer_index_); + switch (simulator_.protocol_) { + case Protocol::TCP: { + auto &sockets = tcp_sockets_[peer_index]; + auto connected = sockets.write() != nullptr; + if (not connected) { + connect(peer_index); + } + auto &socket = sockets.write(); + auto &state = tcp_socket_state_.at(socket); + state.writing.write(message_id, message); + if (connected) { + pollWrite(socket, state); + } + break; + } + case Protocol::UDP: { + auto &state = udpSocketState(peer_index); + state.writing.write(message_id, message); + udp_writing_.emplace(peer_index); + pollWrite(udp_socket_, state); + break; + } + } + } } // namespace beamsim::ns3_ diff --git a/test_tcp.yaml b/test_tcp.yaml new file mode 100644 index 0000000..9cb0937 --- /dev/null +++ b/test_tcp.yaml @@ -0,0 +1,8 @@ +backend: ns3-direct +topology: gossip +protocol: tcp +roles: + group_count: 10 + group_validator_count: 256 + global_aggregator_count: 1 + group_local_aggregator_count: 1 diff --git a/test_udp.yaml b/test_udp.yaml new file mode 100644 index 0000000..2ab638f --- /dev/null +++ b/test_udp.yaml @@ -0,0 +1,8 @@ +backend: ns3-direct +topology: gossip +protocol: udp +roles: + group_count: 10 + group_validator_count: 128 + global_aggregator_count: 1 + group_local_aggregator_count: 1