Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions cli.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ struct SimulationConfig {
GOSSIP,
GRID,
};
enum class Protocol {
TCP,
UDP,
};

const Args::Enum<Backend> enum_backend_{{
{Backend::DELAY, "delay"},
Expand All @@ -349,6 +353,10 @@ struct SimulationConfig {
{Topology::GOSSIP, "gossip"},
{Topology::GRID, "grid"},
}};
const Args::Enum<Protocol> enum_protocol_{{
{Protocol::TCP, "tcp"},
{Protocol::UDP, "udp"},
}};

beamsim::example::RolesConfig roles_config;

Expand All @@ -372,6 +380,13 @@ struct SimulationConfig {
"Communication topology",
enum_topology_,
};
Protocol protocol = Protocol::TCP;
Args::FlagEnum<decltype(protocol)> flag_protocol{
{"-p", "--protocol"},
protocol,
"Network protocol (TCP/UDP)",
enum_protocol_,
};
Args::FlagInt<beamsim::example::GroupIndex> flag_group_count{{
{"-g", "--groups"},
roles_config.group_count,
Expand All @@ -394,6 +409,7 @@ struct SimulationConfig {
return f(flag_config_path,
flag_backend,
flag_topology,
flag_protocol,
flag_group_count,
flag_validators_per_group,
flag_shuffle,
Expand Down Expand Up @@ -424,6 +440,7 @@ struct SimulationConfig {
Yaml yaml{YAML::LoadFile(config_path)};
yaml.at({"backend"}).get(backend, enum_backend_);
yaml.at({"topology"}).get(topology, enum_topology_);
yaml.at({"protocol"}).get(protocol, enum_protocol_);
yaml.at({"shuffle"}).get(shuffle);

yaml.at({"roles", "group_count"}).get(roles_config.group_count);
Expand Down Expand Up @@ -466,6 +483,7 @@ struct SimulationConfig {
std::println("Configuration:");
std::println(" Backend: {}", enum_backend_.str(backend));
std::println(" Topology: {}", enum_topology_.str(topology));
std::println(" Protocol: {}", enum_protocol_.str(protocol));
std::println(" Groups: {}", roles_config.group_count);
std::println(" Validators per group: {}",
roles_config.group_validator_count);
Expand Down
1 change: 1 addition & 0 deletions example.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
backend: ns3
topology: direct
protocol: tcp # Use tcp or udp

roles:
group_count: 4
Expand Down
3 changes: 3 additions & 0 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,9 @@ void run_simulation(const SimulationConfig &config) {
case SimulationConfig::Backend::NS3_DIRECT: {
#ifdef ns3_FOUND
beamsim::ns3_::Simulator simulator{metrics_ptr};
simulator.setProtocol(config.protocol == SimulationConfig::Protocol::UDP ?
beamsim::ns3_::Protocol::UDP :
beamsim::ns3_::Protocol::TCP);
if (config.backend == SimulationConfig::Backend::NS3) {
simulator.routing_.initRouters(routers);
} else {
Expand Down
202 changes: 155 additions & 47 deletions src/beamsim/ns3/simulator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <ns3/applications-module.h>
#include <ns3/core-module.h>
#include <ns3/internet-module.h>
#include <ns3/mpi-interface.h>
#include <ns3/network-module.h>

Expand All @@ -16,6 +17,11 @@ namespace beamsim::ns3_ {

constexpr uint16_t kPort = 10000;

enum class Protocol {
TCP,
UDP,
};

inline ns3::Time timeToNs3(Time time) {
static_assert(std::is_same_v<Time, std::chrono::microseconds>);
return ns3::MicroSeconds(time.count());
Expand Down Expand Up @@ -178,22 +184,9 @@ namespace beamsim::ns3_ {
peer_->onStart();
}

SocketPtr makeSocket() {
auto socket = ns3::Socket::CreateSocket(
GetNode(), ns3::TypeId::LookupByName("ns3::TcpSocketFactory"));
return socket;
}
void listen() {
tcp_listener_ = makeSocket();
tcp_listener_->Bind(ns3::InetSocketAddress{
ns3::Ipv4Address::GetAny(),
kPort,
});
tcp_listener_->Listen();
tcp_listener_->SetAcceptCallback(
ns3::MakeNullCallback<bool, SocketPtr, const ns3::Address &>(),
ns3::MakeCallback(&Application::onAccept, this));
}
SocketPtr makeSocket();
void listen();
void onUdpReceive(SocketPtr socket);
void onAccept(SocketPtr socket, const ns3::Address &address);
void add(PeerIndex peer_index, SocketPtr socket) {
socket->SetRecvCallback(MakeCallback(&Application::pollRead, this));
Expand Down Expand Up @@ -221,23 +214,21 @@ namespace beamsim::ns3_ {
void connect(PeerIndex peer_index);
void send(PeerIndex peer_index,
std::optional<MessageId> message_id,
const IMessage &message) {
assert2(peer_index != peer_->peer_index_);
auto &sockets = tcp_sockets_[peer_index];
auto connected = sockets.write() != nullptr;
if (not connected) {
connect(peer_index);
}
auto &socket = sockets.write();
auto &state = tcp_socket_state_.at(socket);
state.writing.write(message_id, message);
if (connected) {
pollWrite(socket);
}
}
const IMessage &message);

private:
void sendTcp(PeerIndex peer_index,
std::optional<MessageId> message_id,
const IMessage &message);

void sendUdp(PeerIndex peer_index,
std::optional<MessageId> message_id,
const IMessage &message);

public:
Simulator &simulator_;
std::unique_ptr<IPeer> peer_;
private:
SocketPtr tcp_listener_;
std::unordered_map<PeerIndex, SocketInOut> tcp_sockets_;
std::unordered_map<SocketPtr, SocketState> tcp_socket_state_;
Expand All @@ -252,6 +243,10 @@ namespace beamsim::ns3_ {
}
}

void setProtocol(Protocol protocol) {
protocol_ = protocol;
}

// ISimulator
~Simulator() override {
ns3::Simulator::Destroy();
Expand Down Expand Up @@ -367,6 +362,7 @@ namespace beamsim::ns3_ {
}

IMetrics *metrics_;
Protocol protocol_ = Protocol::TCP;
bool cache_messages_ = true;
std::vector<ns3::Ptr<Application>> applications_;
Routing routing_;
Expand All @@ -375,6 +371,105 @@ namespace beamsim::ns3_ {
MessageDecodeFn message_decode_;
};

// Application method implementations
inline SocketPtr Application::makeSocket() {
const char* socketFactory = simulator_.protocol_ == Protocol::UDP ?
"ns3::UdpSocketFactory" : "ns3::TcpSocketFactory";
auto socket = ns3::Socket::CreateSocket(
GetNode(), ns3::TypeId::LookupByName(socketFactory));
return socket;
}

inline void Application::listen() {
if (simulator_.protocol_ == Protocol::UDP) {
tcp_listener_ = makeSocket();
tcp_listener_->Bind(ns3::InetSocketAddress{
ns3::Ipv4Address::GetAny(),
kPort,
});
tcp_listener_->SetRecvCallback(ns3::MakeCallback(&Application::onUdpReceive, this));
} else {
tcp_listener_ = makeSocket();
tcp_listener_->Bind(ns3::InetSocketAddress{
ns3::Ipv4Address::GetAny(),
kPort,
});
tcp_listener_->Listen();
tcp_listener_->SetAcceptCallback(
ns3::MakeNullCallback<bool, SocketPtr, const ns3::Address &>(),
ns3::MakeCallback(&Application::onAccept, this));
}
}

inline void Application::send(PeerIndex peer_index,
std::optional<MessageId> message_id,
const IMessage &message) {
assert2(peer_index != peer_->peer_index_);

if (simulator_.protocol_ == Protocol::UDP) {
sendUdp(peer_index, message_id, message);
} else {
sendTcp(peer_index, message_id, message);
}
}

inline void Application::sendTcp(PeerIndex peer_index,
std::optional<MessageId> message_id,
const IMessage &message) {
auto &sockets = tcp_sockets_[peer_index];
auto connected = sockets.write() != nullptr;
if (not connected) {
connect(peer_index);
}
auto &socket = sockets.write();
auto &state = tcp_socket_state_.at(socket);
state.writing.write(message_id, message);
if (connected) {
pollWrite(socket);
}
}

inline void Application::sendUdp(PeerIndex peer_index,
std::optional<MessageId> message_id,
const IMessage &message) {
// For UDP, create a simple packet with the serialized message
Bytes data;
MessageEncodeTo encode_data{[&data](BytesIn part) {
data.insert(data.end(), part.begin(), part.end());
}};
message.encode(encode_data);

auto packet = ns3::Create<ns3::Packet>(data.data(), data.size());

// Send directly to the peer
auto socket = makeSocket();
socket->Connect(ns3::InetSocketAddress{
simulator_.routing_.peer_ips_.at(peer_index), kPort});
socket->Send(packet);
socket->Close();
}

inline void Application::connect(PeerIndex peer_index) {
if (simulator_.protocol_ == Protocol::UDP) {
return; // UDP is connectionless, no need to connect
}

auto &sockets = tcp_sockets_[peer_index];
assert2(not sockets.out);
if (sockets.write()) {
return;
}
auto socket = makeSocket();
socket->Connect(ns3::InetSocketAddress{
simulator_.routing_.peer_ips_.at(peer_index), kPort});
sockets.out = socket;
sockets.write_out = true;
socket->SetConnectCallback(
MakeCallback(&Application::onConnect, this),
MakeCallback(&Application::onConnectError, this));
add(peer_index, socket);
}

void Application::onAccept(SocketPtr socket, const ns3::Address &address) {
auto index = simulator_.routing_.ip_peer_index_.at(
ns3::InetSocketAddress::ConvertFrom(address).GetIpv4());
Expand All @@ -385,6 +480,36 @@ namespace beamsim::ns3_ {
add(index, socket);
}

void Application::onUdpReceive(SocketPtr socket) {
ns3::Address from_address;
auto packet = socket->RecvFrom(from_address);
if (!packet) {
return;
}

auto from_ip = ns3::InetSocketAddress::ConvertFrom(from_address).GetIpv4();
auto it = simulator_.routing_.ip_peer_index_.find(from_ip);
if (it == simulator_.routing_.ip_peer_index_.end()) {
return; // Unknown sender
}
auto from_peer = it->second;

// Extract message data
auto size = packet->GetSize();
Bytes data(size);
packet->CopyData(data.data(), size);

// Decode message
MessageDecodeFrom decode_data{data};
auto message = simulator_.message_decode_(decode_data);

if (simulator_.metrics_ != nullptr) {
simulator_.metrics_->onPeerReceivedMessage(peer_->peer_index_);
}

peer_->onMessage(from_peer, std::move(message));
}

void Application::pollRead(SocketPtr socket) {
auto &state = tcp_socket_state_.at(socket);
while (auto packet = socket->Recv()) {
Expand Down Expand Up @@ -447,21 +572,4 @@ namespace beamsim::ns3_ {
}
}
}

void Application::connect(PeerIndex peer_index) {
auto &sockets = tcp_sockets_[peer_index];
assert2(not sockets.out);
if (sockets.write()) {
return;
}
auto socket = makeSocket();
socket->Connect(ns3::InetSocketAddress{
simulator_.routing_.peer_ips_.at(peer_index), kPort});
sockets.out = socket;
sockets.write_out = true;
socket->SetConnectCallback(
MakeCallback(&Application::onConnect, this),
MakeCallback(&Application::onConnectError, this));
add(peer_index, socket);
}
} // namespace beamsim::ns3_
8 changes: 8 additions & 0 deletions test_tcp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
backend: ns3-direct
topology: gossip
protocol: tcp
roles:
group_count: 10
group_validator_count: 256
global_aggregator_count: 1
group_local_aggregator_count: 1
8 changes: 8 additions & 0 deletions test_udp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
backend: ns3-direct
topology: gossip
protocol: udp
roles:
group_count: 10
group_validator_count: 128
global_aggregator_count: 1
group_local_aggregator_count: 1