Skip to content

Commit

Permalink
Fix/socket unit test (#33)
Browse files Browse the repository at this point in the history
* Rename callback functions and fix some bugs

* Add new constructor and setters to the socket implementation, additionally fix error when passing the responsibility of raw pointers to smart ptrs.

* Rename thread and callback handles in SocketDefines.h

* Adjust socket unit test and add new check when receiving data by the receive callback
  • Loading branch information
FlorianFrank authored Jan 15, 2023
1 parent 656e714 commit 258a26b
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 73 deletions.
4 changes: 2 additions & 2 deletions Additional/ctlib/SocketDefines.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ struct PIL_SOCKET
PIL_BOOL m_IsConnected;
PIL_ErrorHandle m_ErrorHandle;

ThreadHandle *m_callbackThreadHandle;
ThreadHandle *m_ReceiveCallbackThreadHandle;
ReceiveThreadCallbackArgC *m_callbackThreadArg;
PIL_BOOL m_callbackActive;
PIL_BOOL m_ReceiveCallback;

} typedef PIL_SOCKET;

Expand Down
6 changes: 1 addition & 5 deletions Communication/include/ctlib/Socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,7 @@ PIL_ERROR_CODE PIL_SOCKET_RegisterReceiveCallbackFunction(PIL_SOCKET *socketRet,
void (*callback)(PIL_SOCKET* socket, uint8_t *buffer, uint32_t len,
void *), void *additional);

PIL_ERROR_CODE PIL_SOCKET_UnregisterCallbackFunction(PIL_SOCKET *socketRet);
//#endif // PIL_THREADING



PIL_ERROR_CODE PIL_SOCKET_UnregisterReceiveCallbackFunction(PIL_SOCKET *socketRet);

/**
* @}
Expand Down
61 changes: 34 additions & 27 deletions Communication/include/ctlib/Socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,29 @@ namespace PIL
class Socket
{
public:

struct ReceiveCallbackArg {
explicit ReceiveCallbackArg(std::function<void(std::shared_ptr<PIL::Socket>&, std::string&)>& c): m_ReceiveCallback(c){
}
std::function<void(std::shared_ptr<PIL::Socket>&, std::string&)> &m_ReceiveCallback;
std::shared_ptr<PIL::Socket> m_Socket = {};
};

/**
* @brief Workaround to pass std::functions to C-acceptCallback function.
*/
struct ThreadAcceptArg
{
/** Old C-threading function. */
AcceptThreadArgC argC = {};
/** Function pointer to C++ function returning PIL::Socket object. */
std::function<void(std::unique_ptr<PIL::Socket> &)> acceptCallback = {};
};

Socket();
Socket(TransportProtocol transportProtocol, InternetProtocol internetProtocol, const std::string &address,
int port, uint16_t timeoutInMS);
Socket(std::unique_ptr<PIL_SOCKET> socket, std::string &ip, uint16_t port);
Socket(std::shared_ptr<PIL_SOCKET> &socket, std::string &ip, uint16_t port);
~Socket();

PIL_ERROR_CODE Bind(PIL_BOOL reuse);
Expand All @@ -44,35 +64,20 @@ namespace PIL
PIL_ERROR_CODE Send(std::string &message);
PIL_ERROR_CODE SendTo(std::string &destAddr, int port, uint8_t *buffer, int *bufferLen);

std::string GetSenderIP();
PIL_ERROR_CODE GetInterfaceInfos(InterfaceInfoList *interfaceInfos);
PIL_BOOL IsOpen();
TransportProtocol GetTransportProtocol() const { return m_TransportProtocol; }
InternetProtocol GetInternetProtocol() const { return m_InternetProtocol; }

PIL_ERROR_CODE CreateServerSocket(std::function<void(std::unique_ptr<PIL::Socket>&)> &receiveCallback);
PIL_ERROR_CODE ConnectToServer(std::string &ipAddr, int destPort, std::function<void(std::unique_ptr<Socket>& , std::string &)> &receiveCallback);


struct ReceiveCallbackArg {
explicit ReceiveCallbackArg(std::function<void(std::unique_ptr<PIL::Socket>&, std::string&)>& c): m_ReceiveCallback(c){
}
std::function<void(std::unique_ptr<PIL::Socket>&, std::string&)> &m_ReceiveCallback;
std::unique_ptr<PIL::Socket> m_Socket = {};
};
PIL_ERROR_CODE ConnectToServer(std::string &ipAddr, int destPort, std::function<void(std::shared_ptr<Socket>& , std::string &)> &receiveCallback);

PIL_ERROR_CODE RegisterReceiveCallbackFunction(ReceiveCallbackArg& additionalArg);
PIL_ERROR_CODE UnregisterCallbackFunction();
PIL_ERROR_CODE UnregisterAllCallbackFunctions();

/**
* @brief Workaround to pass std::functions to C-acceptCallback function.
*/
struct ThreadAcceptArg {
/** Old C-threading function. */
AcceptThreadArgC argC = {};
/** Function pointer to C++ function returning PIL::Socket object. */
std::function<void(std::unique_ptr<PIL::Socket>&)> acceptCallback = {};
};
std::string GetSenderIPAddress();
PIL_ERROR_CODE GetInterfaceInfos(InterfaceInfoList *interfaceInfos);
PIL_BOOL IsOpen();
TransportProtocol GetTransportProtocol() const { return m_TransportProtocol; }
InternetProtocol GetInternetProtocol() const { return m_InternetProtocol; }
void setPort(uint16_t mPort);
void setIPAddress(const std::string &mIpAddress);
void setSocketHandle(const std::shared_ptr<PIL_SOCKET> &mCSocketHandle);

private:
uint16_t m_Port;
Expand All @@ -81,10 +86,12 @@ namespace PIL
InternetProtocol m_InternetProtocol;
uint16_t m_TimeoutInMS;

std::unique_ptr<PIL_SOCKET> m_CSocketHandle;

std::shared_ptr<PIL_SOCKET> m_CSocketHandle;
std::vector<std::unique_ptr<PIL_SOCKET>> m_SocketList;
// std::unique_ptr<ThreadAcceptArg> m_ThreadArg;
std::unique_ptr<PIL::Threading<ThreadAcceptArg>> m_AcceptThread;
std::unique_ptr<ReceiveCallbackArg> m_ReceiveCallback;
PIL_ERROR_CODE RegisterAcceptCallback(std::function<void(std::unique_ptr<PIL::Socket>&)> &f);
};

Expand Down
43 changes: 22 additions & 21 deletions Communication/src/Socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,11 @@ PIL_ERROR_CODE PIL_SOCKET_Close(PIL_SOCKET *socketRet)
if (!socketRet)
return PIL_INVALID_ARGUMENTS;

if(socketRet->m_callbackActive == TRUE)
PIL_SOCKET_UnregisterCallbackFunction(socketRet);
if(!socketRet->m_IsOpen)
return PIL_NO_ERROR;

if(socketRet->m_ReceiveCallback == TRUE)
PIL_SOCKET_UnregisterReceiveCallbackFunction(socketRet);

#ifndef embedded
if (socketRet->m_IsOpen)
Expand Down Expand Up @@ -435,13 +438,13 @@ PIL_ERROR_CODE PIL_SOCKET_Receive(PIL_SOCKET *socketRet, uint8_t *buffer, uint32
*/
void* PIL_ReceiveThreadFunction(void *handle)
{
assert(handle);
assert(handle);
ReceiveThreadCallbackArgC *arg = (ReceiveThreadCallbackArgC*) handle;
assert(arg->socket && arg->receiveCallback);

uint8_t buffer[DEFAULT_SOCK_BUFF_SIZE];
uint32_t len = DEFAULT_SOCK_BUFF_SIZE;
while(arg->socket->m_callbackActive)
while(arg->socket->m_ReceiveCallback)
{
memset(buffer, 0, DEFAULT_SOCK_BUFF_SIZE);
PIL_ERROR_CODE ret = PIL_SOCKET_WaitTillDataAvail(arg->socket, DEFAULT_TIMEOUT_MS);
Expand All @@ -455,9 +458,7 @@ void* PIL_ReceiveThreadFunction(void *handle)
{
ret = PIL_SOCKET_Receive(arg->socket, buffer, &len, 0);
if(ret == PIL_NO_ERROR)
{
arg->receiveCallback(arg->socket, buffer, len, arg->additionalArg);
}
}
}
return arg;
Expand All @@ -481,20 +482,20 @@ PIL_ERROR_CODE PIL_SOCKET_RegisterReceiveCallbackFunction(PIL_SOCKET *socketRet,
socketRet->m_callbackThreadArg ->receiveCallback = callback;
socketRet->m_callbackThreadArg ->additionalArg = additional;

socketRet->m_callbackThreadHandle = malloc(sizeof(ThreadHandle));
socketRet->m_callbackThreadHandle->m_ThreadArgument.m_ThreadArgument = (void*)&socketRet->m_callbackThreadArg ;
socketRet->m_callbackThreadHandle->m_ThreadArgument.m_ThreadFunction = PIL_ReceiveThreadFunction;
PIL_ERROR_CODE createSockRet = PIL_THREADING_CreateThread(socketRet->m_callbackThreadHandle, PIL_ReceiveThreadFunction, (void*)socketRet->m_callbackThreadArg);
socketRet->m_ReceiveCallbackThreadHandle = malloc(sizeof(ThreadHandle));
socketRet->m_ReceiveCallbackThreadHandle->m_ThreadArgument.m_ThreadArgument = (void*)&socketRet->m_callbackThreadArg ;
socketRet->m_ReceiveCallbackThreadHandle->m_ThreadArgument.m_ThreadFunction = PIL_ReceiveThreadFunction;
PIL_ERROR_CODE createSockRet = PIL_THREADING_CreateThread(socketRet->m_ReceiveCallbackThreadHandle, PIL_ReceiveThreadFunction, (void*)socketRet->m_callbackThreadArg);
if (createSockRet == PIL_NO_ERROR)
{
socketRet->m_callbackActive = TRUE;
createSockRet = PIL_THREADING_RunThread(socketRet->m_callbackThreadHandle, FALSE);
socketRet->m_ReceiveCallback = TRUE;
createSockRet = PIL_THREADING_RunThread(socketRet->m_ReceiveCallbackThreadHandle, FALSE);
}

if (createSockRet != PIL_NO_ERROR)
{
free(socketRet->m_callbackThreadHandle);
socketRet->m_callbackThreadHandle = NULL;
free(socketRet->m_ReceiveCallbackThreadHandle);
socketRet->m_ReceiveCallbackThreadHandle = NULL;
free(socketRet->m_callbackThreadArg);
socketRet->m_callbackThreadArg = NULL;
return createSockRet;
Expand All @@ -508,19 +509,19 @@ PIL_ERROR_CODE PIL_SOCKET_RegisterReceiveCallbackFunction(PIL_SOCKET *socketRet,
* @param socketRet socket for which the callback was registered.
* @return PIL_NO_ERROR on success.
*/
PIL_ERROR_CODE PIL_SOCKET_UnregisterCallbackFunction(PIL_SOCKET *socketRet)
PIL_ERROR_CODE PIL_SOCKET_UnregisterReceiveCallbackFunction(PIL_SOCKET *socketRet)
{
if (!socketRet)
return PIL_INVALID_ARGUMENTS;

if (socketRet->m_callbackActive)
if (socketRet->m_ReceiveCallback)
{
socketRet->m_callbackActive = FALSE;
if (!socketRet->m_callbackThreadHandle)
socketRet->m_ReceiveCallback = FALSE;
if (!socketRet->m_ReceiveCallbackThreadHandle)
return PIL_INVALID_ARGUMENTS;
PIL_ERROR_CODE ret = PIL_THREADING_JoinThread(socketRet->m_callbackThreadHandle, NULL);
free(socketRet->m_callbackThreadHandle);
socketRet->m_callbackThreadHandle = NULL;
PIL_ERROR_CODE ret = PIL_THREADING_JoinThread(socketRet->m_ReceiveCallbackThreadHandle, NULL);
free(socketRet->m_ReceiveCallbackThreadHandle);
socketRet->m_ReceiveCallbackThreadHandle = NULL;
free(socketRet->m_callbackThreadArg);
socketRet->m_callbackThreadArg = NULL;
if (ret != PIL_NO_ERROR)
Expand Down
74 changes: 58 additions & 16 deletions Communication/src/Socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@ extern "C" {

namespace PIL
{
Socket::Socket(std::unique_ptr<PIL_SOCKET> socket, std::string &ip, uint16_t port) :
m_CSocketHandle(std::move(socket)), m_IPAddress(ip), m_Port(port),
Socket::Socket(): m_IPAddress(""), m_Port(0),
m_TransportProtocol(TCP), m_InternetProtocol(IPv4), m_TimeoutInMS(0){

}

Socket::Socket(std::shared_ptr<PIL_SOCKET> &socket, std::string &ip, uint16_t port) :
m_CSocketHandle(socket), m_IPAddress(ip), m_Port(port),
m_TransportProtocol(TCP), m_InternetProtocol(IPv4), m_TimeoutInMS(0){
}

Expand All @@ -40,7 +45,7 @@ namespace PIL
}

PIL_ERROR_CODE Socket::Disconnect(){
auto retCode = UnregisterCallbackFunction();
auto retCode = UnregisterAllCallbackFunctions();
if(retCode != PIL_NO_ERROR){
#ifdef PIL_EXCEPTION_HANDLING
throw PIL::Exception(retCode, __FILENAME__, __LINE__);
Expand Down Expand Up @@ -164,7 +169,7 @@ namespace PIL
return ret;
}

std::string Socket::GetSenderIP(){
std::string Socket::GetSenderIPAddress(){
const char *senderIP = PIL_SOCKET_GetSenderIP(m_CSocketHandle.get());
#ifdef PIL_EXCEPTION_HANDLING
if(!senderIP)
Expand All @@ -182,11 +187,15 @@ namespace PIL

char ipAddr[MAX_IP_LEN];

std::unique_ptr<PIL::Socket> socketPtr = std::make_unique<PIL::Socket>();
do {
if(!arg->argC.socket->m_IsOpen){
return arg.get();
}
std::unique_ptr<PIL_SOCKET> retHandle = std::make_unique<PIL_SOCKET>();
auto retHandle = std::make_shared<PIL_SOCKET>();
retHandle->m_IsOpen = TRUE;
retHandle->m_ReceiveCallback = FALSE;

int ret = PIL_SOCKET_Accept(arg->argC.socket, ipAddr, retHandle.get());
if(ret == PIL_INTERFACE_CLOSED)
return arg.get();
Expand All @@ -199,7 +208,11 @@ namespace PIL
retHandle->m_IsOpen = TRUE;
retHandle->m_IsConnected = TRUE;
std::string ipStr = ipAddr;
std::unique_ptr<PIL::Socket> socketPtr = std::make_unique<PIL::Socket>(std::move(retHandle), ipStr, retHandle->m_port);

socketPtr->setSocketHandle(retHandle);
socketPtr->setIPAddress(ipStr);
socketPtr->setPort(retHandle->m_port);

arg->acceptCallback(socketPtr);
}while(arg->argC.socket->m_IsOpen);
return arg.get();
Expand Down Expand Up @@ -235,7 +248,7 @@ namespace PIL
}

PIL_ERROR_CODE
Socket::ConnectToServer(std::string &ipAddr, int destPort, std::function<void(std::unique_ptr<Socket>& , std::string &)> &receiveCallback){
Socket::ConnectToServer(std::string &ipAddr, int destPort, std::function<void(std::shared_ptr<Socket>& , std::string &)> &receiveCallback){
auto functionPtr = [](PIL_SOCKET* socket, uint8_t *buffer, uint32_t bufferLen, void* additionalArg){
if(!additionalArg){
#ifdef PIL_EXCEPTION_HANDLING
Expand All @@ -244,16 +257,30 @@ namespace PIL
return;
}
std::string ip = socket->m_IPAddress;
std::unique_ptr<PIL_SOCKET> s = std::unique_ptr<PIL_SOCKET>(socket);
std::unique_ptr<PIL::Socket> socketCXX = std::make_unique<PIL::Socket>(std::move(s), ip, socket->m_port);

// Create new socket object now managed by shared_ptr
auto sock = new PIL_SOCKET;
*sock = *socket;

std::shared_ptr<PIL_SOCKET> s = std::shared_ptr<PIL_SOCKET>(sock);

// Set old socket to closed!
socket->m_IsOpen = FALSE;
socket->m_IsConnected = FALSE;

s->m_IsOpen = TRUE;
s->m_ReceiveCallback = FALSE;

auto socketCXX = std::make_shared<PIL::Socket>(s, ip, socket->m_port);
std::string value = std::string((char *)buffer, bufferLen);
auto *arg = reinterpret_cast<ReceiveCallbackArg*>(additionalArg);
arg->m_ReceiveCallback(socketCXX, value);
auto *callbackFunction = reinterpret_cast<std::unique_ptr<ReceiveCallbackArg>*>(additionalArg);
(*callbackFunction)->m_ReceiveCallback(socketCXX, value);
};

m_ReceiveCallback = std::move(std::make_unique<ReceiveCallbackArg>(receiveCallback));
auto ret = PIL_SOCKET_ConnectToServer(m_CSocketHandle.get(), ipAddr.c_str(),
m_Port, destPort, m_TimeoutInMS, functionPtr,
&receiveCallback);
&m_ReceiveCallback);
#ifdef PIL_EXCEPTION_HANDLING
if(ret != PIL_NO_ERROR)
throw PIL::Exception(ret, __FILENAME__, __LINE__);
Expand All @@ -268,8 +295,8 @@ namespace PIL
std::string ip = socket->m_IPAddress;
std::string value = std::string((char *)buffer, bufferLen);
auto *arg = reinterpret_cast<ReceiveCallbackArg*>(additionalArg);
auto s = std::unique_ptr<PIL_SOCKET>(socket);
arg->m_Socket = std::make_unique<PIL::Socket>(std::move(s), ip, socket->m_port);
auto s = std::shared_ptr<PIL_SOCKET>(socket);
arg->m_Socket = std::make_shared<PIL::Socket>(s, ip, socket->m_port);
arg->m_ReceiveCallback(arg->m_Socket, value);
};

Expand All @@ -281,10 +308,10 @@ namespace PIL
return ret;
}

PIL_ERROR_CODE Socket::UnregisterCallbackFunction(){
PIL_ERROR_CODE Socket::UnregisterAllCallbackFunctions(){
if(m_CSocketHandle == nullptr)
return PIL_NO_ERROR;
auto ret = PIL_SOCKET_UnregisterCallbackFunction(m_CSocketHandle.get());
auto ret = PIL_SOCKET_UnregisterReceiveCallbackFunction(m_CSocketHandle.get());
#ifdef PIL_EXCEPTION_HANDLING
if(ret != PIL_NO_ERROR)
throw PIL::Exception(ret, __FILENAME__, __LINE__);
Expand All @@ -302,5 +329,20 @@ namespace PIL
return ret;
}

void Socket::setPort(uint16_t mPort)
{
m_Port = mPort;
}

void Socket::setIPAddress(const std::string &mIpAddress)
{
m_IPAddress = mIpAddress;
}

void Socket::setSocketHandle(const std::shared_ptr<PIL_SOCKET> &mCSocketHandle)
{
m_CSocketHandle = mCSocketHandle;
}

}
#endif // PIL_CXX
6 changes: 4 additions & 2 deletions UnitTesting/SocketUnitTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ void ReceiveHandlerC(PIL_SOCKET *socket, uint8_t* buffer, uint32_t len, void*)
}


void ReceiveHandlerCPP(std::unique_ptr<PIL::Socket> &socket, std::string& buffer)
void ReceiveHandlerCPP(std::shared_ptr<PIL::Socket> &socket, std::string& buffer)
{
strncpy(recvBuff, (char*)buffer.c_str(), buffer.length());
EXPECT_TRUE(strcmp(recvBuff, loram_ipsum.c_str()) == 0);
}


Expand Down Expand Up @@ -88,8 +89,9 @@ TEST(SocketTest_CPP, SimpleSocketTest)
EXPECT_EQ(ret, PIL_NO_ERROR);
PIL::Socket clientSock(TCP, IPv4, "localhost", 14003, 1000);
std::string destIP = "127.0.0.1";
std::function<void(std::unique_ptr<PIL::Socket>& , std::string &)> callbackFunc = ReceiveHandlerCPP;
std::function<void(std::shared_ptr<PIL::Socket>& , std::string &)> callbackFunc = ReceiveHandlerCPP;
ret = clientSock.ConnectToServer(destIP, 14002, callbackFunc);
std::this_thread::sleep_for(std::chrono::microseconds(10000));
EXPECT_EQ(ret, PIL_NO_ERROR);
}

Expand Down

0 comments on commit 258a26b

Please sign in to comment.