Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate transfer format on negotiate #93

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions src/signalrclient/case_insensitive_comparison_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,62 @@ namespace signalr

return true;
}

bool operator()(const char* s1, const char* s2) const
{
auto length = std::char_traits<char>::length(s1);
if (length != std::char_traits<char>::length(s2))
{
return false;
}

for (unsigned i = 0; i < length; ++i)
{
if (std::toupper(s1[i]) != std::toupper(s2[i]))
{
return false;
}
}

return true;
}

bool operator()(const std::string& s1, const char* s2) const
{
if (s1.length() != std::char_traits<char>::length(s2))
{
return false;
}

for (unsigned i = 0; i < s1.size(); ++i)
{
if (std::toupper(s1[i]) != std::toupper(s2[i]))
{
return false;
}
}

return true;
}

bool operator()(const char* s1, const std::string& s2) const
{
auto length = std::char_traits<char>::length(s1);
if (length != s2.length())
{
return false;
}

for (unsigned i = 0; i < length; ++i)
{
if (std::toupper(s1[i]) != std::toupper(s2[i]))
{
return false;
}
}

return true;
}
};

struct case_insensitive_hash
Expand Down
34 changes: 28 additions & 6 deletions src/signalrclient/connection_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ namespace signalr
std::function<void()> mFunc;
};

void connection_impl::start(std::function<void(std::exception_ptr)> callback) noexcept
void connection_impl::start(transfer_format transfer_format, std::function<void(std::exception_ptr)> callback) noexcept
{
{
std::lock_guard<std::mutex> lock(m_stop_lock);
Expand All @@ -147,6 +147,8 @@ namespace signalr
return;
}

m_transfer_format = transfer_format;

// there should not be any active transport at this point
assert(!m_transport);

Expand Down Expand Up @@ -306,6 +308,19 @@ namespace signalr
start_negotiate_internal(url, 0, transport_started);
}

const char* get_transfer_format(transfer_format transfer_format)
{
switch (transfer_format)
{
case transfer_format::text:
return "text";
case transfer_format::binary:
return "binary";
}
assert(false);
return "";
}

void connection_impl::start_negotiate_internal(const std::string& url, int redirect_count, std::function<void(std::shared_ptr<transport> transport, std::exception_ptr)> transport_started)
{
if (m_disconnect_cts->is_canceled())
Expand Down Expand Up @@ -381,6 +396,15 @@ namespace signalr
if (comparer(availableTransport.transport, "WebSockets"))
{
foundWebsockets = true;

auto transfer_format = get_transfer_format(connection->m_transfer_format);
if (std::find_if(availableTransport.transfer_formats.begin(), availableTransport.transfer_formats.end(),
[comparer, transfer_format](const std::string& s) { return comparer(s, transfer_format); }) == availableTransport.transfer_formats.end())
{
transport_started(nullptr, std::make_exception_ptr(signalr_exception(std::string("The server does not support WebSockets with the requested transfer format '").append(transfer_format).append("'."))));
return;
}

break;
}
}
Expand All @@ -391,8 +415,6 @@ namespace signalr
return;
}

// TODO: use transfer format

if (token->is_canceled())
{
transport_started(nullptr, std::make_exception_ptr(canceled_exception()));
Expand Down Expand Up @@ -499,7 +521,7 @@ namespace signalr
auto query_string = "id=" + m_connection_token;
auto connect_url = url_builder::build_connect(url, transport->get_transport_type(), query_string);

transport->start(connect_url, [callback, logger](std::exception_ptr exception)
transport->start(connect_url, m_transfer_format, [callback, logger](std::exception_ptr exception)
mutable {
try
{
Expand Down Expand Up @@ -558,7 +580,7 @@ namespace signalr
}
}

void connection_impl::send(const std::string& data, transfer_format transfer_format, std::function<void(std::exception_ptr)> callback) noexcept
void connection_impl::send(const std::string& data, std::function<void(std::exception_ptr)> callback) noexcept
{
// To prevent an (unlikely) condition where the transport is nulled out after we checked the connection_state
// and before sending data we store the pointer in the local variable. In this case `send()` will throw but
Expand All @@ -581,7 +603,7 @@ namespace signalr
logger.log(trace_level::info, std::string("sending data: ").append(data));
}

transport->send(data, transfer_format, [logger, callback](std::exception_ptr exception)
transport->send(data, [logger, callback](std::exception_ptr exception)
mutable {
try
{
Expand Down
5 changes: 3 additions & 2 deletions src/signalrclient/connection_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ namespace signalr

~connection_impl();

void start(std::function<void(std::exception_ptr)> callback) noexcept;
void send(const std::string &data, transfer_format transfer_format, std::function<void(std::exception_ptr)> callback) noexcept;
void start(transfer_format transfer_format, std::function<void(std::exception_ptr)> callback) noexcept;
void send(const std::string &data, std::function<void(std::exception_ptr)> callback) noexcept;
void stop(std::function<void(std::exception_ptr)> callback, std::exception_ptr exception) noexcept;

connection_state get_connection_state() const noexcept;
Expand All @@ -60,6 +60,7 @@ namespace signalr
std::function<void(std::string&&)> m_message_received;
std::function<void(std::exception_ptr)> m_disconnected;
signalr_client_config m_signalr_client_config;
transfer_format m_transfer_format;

std::shared_ptr<cancellation_token_source> m_disconnect_cts;
std::mutex m_stop_lock;
Expand Down
9 changes: 5 additions & 4 deletions src/signalrclient/hub_connection_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ namespace signalr
m_disconnect_cts = std::make_shared<cancellation_token_source>();
m_handshakeReceived = false;
std::weak_ptr<hub_connection_impl> weak_connection = shared_from_this();
m_connection->start([weak_connection, callback](std::exception_ptr start_exception)
auto transfer_format = m_protocol->transfer_format();
m_connection->start(transfer_format, [weak_connection, callback](std::exception_ptr start_exception)
{
auto connection = weak_connection.lock();
if (!connection)
Expand Down Expand Up @@ -242,7 +243,7 @@ namespace signalr
return true;
});

connection->m_connection->send(handshake_request, connection->m_protocol->transfer_format(),
connection->m_connection->send(handshake_request,
[handle_handshake, handshake_request_done, handshake_request_lock](std::exception_ptr exception)
{
{
Expand Down Expand Up @@ -481,7 +482,7 @@ namespace signalr
// weak_ptr prevents a circular dependency leading to memory leak and other problems
auto weak_hub_connection = std::weak_ptr<hub_connection_impl>(shared_from_this());

m_connection->send(message, m_protocol->transfer_format(), [set_completion, set_exception, weak_hub_connection, callback_id](std::exception_ptr exception)
m_connection->send(message, [set_completion, set_exception, weak_hub_connection, callback_id](std::exception_ptr exception)
{
if (exception)
{
Expand Down Expand Up @@ -572,7 +573,7 @@ namespace signalr
std::weak_ptr<hub_connection_impl> weak_connection = connection;
connection->m_connection->send(
connection->m_cached_ping,
connection->m_protocol->transfer_format(), [weak_connection](std::exception_ptr exception)
[weak_connection](std::exception_ptr exception)
{
auto connection = weak_connection.lock();
if (connection)
Expand Down
1 change: 0 additions & 1 deletion src/signalrclient/json_hub_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ namespace signalr
// TODO: other message types
default:
// Future protocol changes can add message types, old clients can ignore them
// TODO: null
break;
}
#pragma warning (pop)
Expand Down
4 changes: 2 additions & 2 deletions src/signalrclient/transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ namespace signalr

virtual ~transport();

virtual void start(const std::string& url, std::function<void(std::exception_ptr)> callback) noexcept = 0;
virtual void start(const std::string& url, transfer_format transfer_format, std::function<void(std::exception_ptr)> callback) noexcept = 0;
virtual void stop(std::function<void(std::exception_ptr)> callback) noexcept = 0;
virtual void on_close(std::function<void(std::exception_ptr)> callback) = 0;

virtual void send(const std::string& payload, signalr::transfer_format transfer_format, std::function<void(std::exception_ptr)> callback) noexcept = 0;
virtual void send(const std::string& payload, std::function<void(std::exception_ptr)> callback) noexcept = 0;

virtual void on_receive(std::function<void(std::string&&, std::exception_ptr)> callback) = 0;

Expand Down
7 changes: 4 additions & 3 deletions src/signalrclient/websocket_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ namespace signalr
}
}

void websocket_transport::start(const std::string& url, std::function<void(std::exception_ptr)> callback) noexcept
void websocket_transport::start(const std::string& url, transfer_format transfer_format, std::function<void(std::exception_ptr)> callback) noexcept
{
signalr::uri uri(url);
assert(uri.scheme() == "ws" || uri.scheme() == "wss");
Expand All @@ -213,6 +213,7 @@ namespace signalr
std::string("[websocket transport] connecting to: ")
.append(url));

m_transfer_format = transfer_format;
auto websocket_client = m_websocket_client_factory(m_signalr_client_config);

{
Expand Down Expand Up @@ -327,9 +328,9 @@ namespace signalr
m_process_response_callback = callback;
}

void websocket_transport::send(const std::string& payload, transfer_format transfer_format, std::function<void(std::exception_ptr)> callback) noexcept
void websocket_transport::send(const std::string& payload, std::function<void(std::exception_ptr)> callback) noexcept
{
safe_get_websocket_client()->send(payload, transfer_format, [callback](std::exception_ptr exception)
safe_get_websocket_client()->send(payload, m_transfer_format, [callback](std::exception_ptr exception)
{
if (exception != nullptr)
{
Expand Down
5 changes: 3 additions & 2 deletions src/signalrclient/websocket_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ namespace signalr

transport_type get_transport_type() const noexcept override;

void start(const std::string& url, std::function<void(std::exception_ptr)> callback) noexcept override;
void start(const std::string& url, transfer_format transfer_format, std::function<void(std::exception_ptr)> callback) noexcept override;
void stop(std::function<void(std::exception_ptr)> callback) noexcept override;
void on_close(std::function<void(std::exception_ptr)> callback) override;

void send(const std::string& payload, transfer_format transfer_format, std::function<void(std::exception_ptr)> callback) noexcept override;
void send(const std::string& payload, std::function<void(std::exception_ptr)> callback) noexcept override;

void on_receive(std::function<void(std::string&&, std::exception_ptr)>) override;

Expand All @@ -47,6 +47,7 @@ namespace signalr

bool m_disconnected;
std::shared_ptr<cancellation_token_source> m_receive_loop_task;
transfer_format m_transfer_format;

void receive_loop();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,37 @@ TEST(case_insensitive_equals_functor, basic_comparison_tests)
{
case_insensitive_equals case_insensitive_compare;

// (const char *, const char *)
ASSERT_TRUE(case_insensitive_compare("", ""));
ASSERT_TRUE(case_insensitive_compare("abc", "ABC"));
ASSERT_TRUE(case_insensitive_compare("abc123!@", "ABC123!@"));

ASSERT_FALSE(case_insensitive_compare("abc", "ABCD"));
ASSERT_FALSE(case_insensitive_compare("abce", "ABCD"));

// (std::string, std::string)
ASSERT_TRUE(case_insensitive_compare(std::string(""), std::string("")));
ASSERT_TRUE(case_insensitive_compare(std::string("abc"), std::string("ABC")));
ASSERT_TRUE(case_insensitive_compare(std::string("abc123!@"), std::string("ABC123!@")));

ASSERT_FALSE(case_insensitive_compare(std::string("abc"), std::string("ABCD")));
ASSERT_FALSE(case_insensitive_compare(std::string("abce"), std::string("ABCD")));

// (std::string, const char *)
ASSERT_TRUE(case_insensitive_compare(std::string(""), ""));
ASSERT_TRUE(case_insensitive_compare(std::string("abc"), "ABC"));
ASSERT_TRUE(case_insensitive_compare(std::string("abc123!@"), "ABC123!@"));

ASSERT_FALSE(case_insensitive_compare(std::string("abc"), "ABCD"));
ASSERT_FALSE(case_insensitive_compare(std::string("abce"), "ABCD"));

// (const char *, std::string)
ASSERT_TRUE(case_insensitive_compare("", std::string("")));
ASSERT_TRUE(case_insensitive_compare("abc", std::string("ABC")));
ASSERT_TRUE(case_insensitive_compare("abc123!@", std::string("ABC123!@")));

ASSERT_FALSE(case_insensitive_compare("abc", std::string("ABCD")));
ASSERT_FALSE(case_insensitive_compare("abce", std::string("ABCD")));
}

TEST(case_insensitive_hash_functor, basic_hash_tests)
Expand Down
Loading