diff --git a/include/cassandra.h b/include/cassandra.h
index 63341bb3c..1b50508cc 100644
--- a/include/cassandra.h
+++ b/include/cassandra.h
@@ -4602,6 +4602,12 @@ cass_ssl_add_trusted_cert_n(CassSsl* ssl,
* common name or one of its subject alternative names. This implies the
* certificate is also present. Hostname resolution must also be enabled.
*
+ * Notes:
+ * - CASS_SSL_VERIFY_PEER_IDENTITY and CASS_SSL_VERIFY_PEER_IDENTITY_DNS are
+ * mutually exclusive options.
+ * - The certificate Common Name is only checked against the IP address or
+ * hostname if there are no Subject Alternative Names in the certificate.
+ *
* Default: CASS_SSL_VERIFY_PEER_CERT
*
* @public @memberof CassSsl
diff --git a/src/address.cpp b/src/address.cpp
index 752bd9ae3..c2bb5bf15 100644
--- a/src/address.cpp
+++ b/src/address.cpp
@@ -75,8 +75,9 @@ Address::Address(const uint8_t* address, uint8_t address_length, int port)
}
}
-Address::Address(const struct sockaddr* addr)
- : family_(UNRESOLVED)
+Address::Address(const struct sockaddr* addr, const String& server_name)
+ : server_name_(server_name)
+ , family_(UNRESOLVED)
, port_(0) {
if (addr->sa_family == AF_INET) {
const struct sockaddr_in* addr_in = reinterpret_cast(addr);
diff --git a/src/address.hpp b/src/address.hpp
index c74c6253a..87c508e0a 100644
--- a/src/address.hpp
+++ b/src/address.hpp
@@ -73,7 +73,7 @@ class Address : public Allocated {
Address(const Address& other, const String& server_name);
Address(const String& hostname_or_address, int port, const String& server_name = String());
Address(const uint8_t* address, uint8_t address_length, int port);
- Address(const struct sockaddr* addr);
+ Address(const struct sockaddr* addr, const String& server_name);
bool equals(const Address& other, bool with_port = true) const;
diff --git a/src/client_insights.cpp b/src/client_insights.cpp
index 351c9c656..d4ad3cdbc 100644
--- a/src/client_insights.cpp
+++ b/src/client_insights.cpp
@@ -635,7 +635,7 @@ class StartupMessageHandler : public RefCounted {
new MultiResolver(bind_callback(&StartupMessageHandler::on_resolve, this)));
}
resolver->resolve(connection_->loop(), contact_point.hostname_or_address(), port,
- config_.resolve_timeout_ms());
+ config_.resolve_timeout_ms(), contact_point.server_name());
}
}
@@ -668,7 +668,8 @@ class StartupMessageHandler : public RefCounted {
Address::SocketStorage name;
int namelen = sizeof(name);
if (uv_tcp_getsockname(tcp, name.addr(), &namelen) == 0) {
- Address address(name.addr());
+ // Pass a blank server name as this is a temporary address.
+ Address address(name.addr(), String());
if (address.is_valid_and_resolved()) {
return address.to_string();
}
diff --git a/src/cluster_config.cpp b/src/cluster_config.cpp
index 9fac5b326..c80a69756 100644
--- a/src/cluster_config.cpp
+++ b/src/cluster_config.cpp
@@ -131,7 +131,8 @@ CassError cass_cluster_set_contact_points_n(CassCluster* cluster, const char* co
explode(String(contact_points, contact_points_length), exploded);
for (Vector::const_iterator it = exploded.begin(), end = exploded.end(); it != end;
++it) {
- cluster->config().contact_points().push_back(Address(*it, -1));
+ // Treat the address string as the server name.
+ cluster->config().contact_points().push_back(Address(*it, -1, *it));
}
}
return CASS_OK;
diff --git a/src/cluster_metadata_resolver.cpp b/src/cluster_metadata_resolver.cpp
index 78ef0c70d..bfe10e382 100644
--- a/src/cluster_metadata_resolver.cpp
+++ b/src/cluster_metadata_resolver.cpp
@@ -39,13 +39,13 @@ class DefaultClusterMetadataResolver : public ClusterMetadataResolver {
int port = it->port() <= 0 ? port_ : it->port();
if (it->is_resolved()) {
- resolved_contact_points_.push_back(Address(it->hostname_or_address(), port));
+ resolved_contact_points_.push_back(Address(it->hostname_or_address(), port, it->server_name()));
} else {
if (!resolver_) {
resolver_.reset(
new MultiResolver(bind_callback(&DefaultClusterMetadataResolver::on_resolve, this)));
}
- resolver_->resolve(loop, it->hostname_or_address(), port, resolve_timeout_ms_);
+ resolver_->resolve(loop, it->hostname_or_address(), port, resolve_timeout_ms_, it->server_name());
}
}
diff --git a/src/resolver.hpp b/src/resolver.hpp
index 2895d81ef..c21663bf2 100644
--- a/src/resolver.hpp
+++ b/src/resolver.hpp
@@ -47,8 +47,9 @@ class Resolver : public RefCounted {
SUCCESS
};
- Resolver(const String& hostname, int port, const Callback& callback)
+ Resolver(const String& hostname, int port, const Callback& callback, const String& server_name)
: hostname_(hostname)
+ , server_name_(server_name)
, port_(port)
, status_(NEW)
, callback_(callback) {
@@ -139,7 +140,7 @@ class Resolver : public RefCounted {
bool init_addresses(struct addrinfo* res) {
bool status = false;
do {
- Address address(res->ai_addr);
+ Address address(res->ai_addr, server_name_);
if (address.is_valid_and_resolved()) {
addresses_.push_back(address);
status = true;
@@ -153,6 +154,7 @@ class Resolver : public RefCounted {
uv_getaddrinfo_t req_;
Timer timer_;
String hostname_;
+ String server_name_;
int port_;
Status status_;
int uv_status_;
@@ -175,10 +177,11 @@ class MultiResolver : public RefCounted {
const Resolver::Vec& resolvers() { return resolvers_; }
void resolve(uv_loop_t* loop, const String& host, int port, uint64_t timeout,
- struct addrinfo* hints = NULL) {
+ const String& server_name, struct addrinfo* hints = NULL) {
inc_ref();
Resolver::Ptr resolver(
- new Resolver(host, port, bind_callback(&MultiResolver::on_resolve, this)));
+ new Resolver(host, port, bind_callback(&MultiResolver::on_resolve, this),
+ server_name));
resolver->resolve(loop, timeout, hints);
resolvers_.push_back(resolver);
remaining_++;
diff --git a/src/socket_connector.cpp b/src/socket_connector.cpp
index 3d396a4df..6ce5c1d1b 100644
--- a/src/socket_connector.cpp
+++ b/src/socket_connector.cpp
@@ -120,11 +120,11 @@ void SocketConnector::connect(uv_loop_t* loop) {
hostname_ = address_.hostname_or_address();
resolver_.reset(new Resolver(hostname_, address_.port(),
- bind_callback(&SocketConnector::on_resolve, this)));
+ bind_callback(&SocketConnector::on_resolve, this),
+ address_.server_name()));
resolver_->resolve(loop, settings_.resolve_timeout_ms);
} else {
resolved_address_ = address_;
-
if (settings_.hostname_resolution_enabled) { // Run hostname resolution then connect.
name_resolver_.reset(
new NameResolver(address_, bind_callback(&SocketConnector::on_name_resolve, this)));
diff --git a/src/ssl/ssl_openssl_impl.cpp b/src/ssl/ssl_openssl_impl.cpp
index 3b1124378..3c69b90f9 100644
--- a/src/ssl/ssl_openssl_impl.cpp
+++ b/src/ssl/ssl_openssl_impl.cpp
@@ -228,22 +228,6 @@ static int SSL_CTX_use_certificate_chain_bio(SSL_CTX* ctx, BIO* in) {
return ret;
}
-static X509* load_cert(const char* cert, size_t cert_size) {
- BIO* bio = BIO_new_mem_buf(const_cast(cert), cert_size);
- if (bio == NULL) {
- return NULL;
- }
-
- X509* x509 = PEM_read_bio_X509(bio, NULL, pem_password_callback, NULL);
- if (x509 == NULL) {
- ssl_log_errors("Unable to load certificate");
- }
-
- BIO_free_all(bio);
-
- return x509;
-}
-
static EVP_PKEY* load_key(const char* key, size_t key_size, const char* password) {
BIO* bio = BIO_new_mem_buf(const_cast(key), key_size);
if (bio == NULL) {
@@ -489,8 +473,8 @@ void OpenSslSession::verify() {
return;
}
} else if (verify_flags_ &
- CASS_SSL_VERIFY_PEER_IDENTITY_DNS) { // Match using hostnames (including wildcards)
- switch (OpenSslVerifyIdentity::match_dns(peer_cert, hostname_)) {
+ CASS_SSL_VERIFY_PEER_IDENTITY_DNS) { // Match using the server name (including wildcards)
+ switch (OpenSslVerifyIdentity::match_dns(peer_cert, sni_server_name_)) {
case OpenSslVerifyIdentity::MATCH:
// Success
break;
@@ -556,13 +540,30 @@ SslSession* OpenSslContext::create_session(const Address& address, const String&
}
CassError OpenSslContext::add_trusted_cert(const char* cert, size_t cert_length) {
- X509* x509 = load_cert(cert, cert_length);
- if (x509 == NULL) {
+ BIO* bio = BIO_new_mem_buf(const_cast(cert), cert_length);
+ if (bio == NULL) {
return CASS_ERROR_SSL_INVALID_CERT;
}
- X509_STORE_add_cert(trusted_store_, x509);
- X509_free(x509);
+ int num_certs = 0;
+
+ // Iterate over the bio, reading out as many certificates as possible.
+ for (X509* cert = PEM_read_bio_X509(bio, NULL, pem_password_callback, NULL);
+ cert != NULL;
+ cert = PEM_read_bio_X509(bio, NULL, pem_password_callback, NULL))
+ {
+ X509_STORE_add_cert(trusted_store_, cert);
+ X509_free(cert);
+ num_certs++;
+ }
+
+ BIO_free_all(bio);
+
+ // If no certificates were read from the bio, that is an error.
+ if (num_certs == 0) {
+ ssl_log_errors("Unable to load certificate(s)");
+ return CASS_ERROR_SSL_INVALID_CERT;
+ }
return CASS_OK;
}
diff --git a/tests/src/unit/tests/test_address.cpp b/tests/src/unit/tests/test_address.cpp
index cae21eb25..ee4655523 100644
--- a/tests/src/unit/tests/test_address.cpp
+++ b/tests/src/unit/tests/test_address.cpp
@@ -17,9 +17,11 @@
#include
#include "address.hpp"
+#include "string.hpp"
using datastax::internal::core::Address;
using datastax::internal::core::AddressSet;
+using datastax::String;
TEST(AddressUnitTest, FromString) {
EXPECT_TRUE(Address("127.0.0.1", 9042).is_resolved());
@@ -64,14 +66,14 @@ TEST(AddressUnitTest, CompareIPv6) {
TEST(AddressUnitTest, ToSockAddrIPv4) {
Address expected("127.0.0.1", 9042);
Address::SocketStorage storage;
- Address actual(expected.to_sockaddr(&storage));
+ Address actual(expected.to_sockaddr(&storage), String());
EXPECT_EQ(expected, actual);
}
TEST(AddressUnitTest, ToSockAddrIPv6) {
Address expected("::1", 9042);
Address::SocketStorage storage;
- Address actual(expected.to_sockaddr(&storage));
+ Address actual(expected.to_sockaddr(&storage), String());
EXPECT_EQ(expected, actual);
}
diff --git a/tests/src/unit/tests/test_resolver.cpp b/tests/src/unit/tests/test_resolver.cpp
index 2efbdd920..f73257d58 100644
--- a/tests/src/unit/tests/test_resolver.cpp
+++ b/tests/src/unit/tests/test_resolver.cpp
@@ -31,8 +31,9 @@ class ResolverUnitTest : public LoopTest {
: status_(Resolver::NEW) {}
Resolver::Ptr create(const String& hostname, int port = 9042) {
+ // Use the hostname as the TLS server name.
return Resolver::Ptr(
- new Resolver(hostname, port, bind_callback(&ResolverUnitTest::on_resolve, this)));
+ new Resolver(hostname, port, bind_callback(&ResolverUnitTest::on_resolve, this), hostname));
}
MultiResolver::Ptr create_multi() {
@@ -108,9 +109,9 @@ TEST_F(ResolverUnitTest, Cancel) {
TEST_F(ResolverUnitTest, Multi) {
MultiResolver::Ptr resolver(create_multi());
- resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT);
- resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT);
- resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT);
+ resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost");
+ resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost");
+ resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost");
run_loop();
ASSERT_EQ(3u, resolvers().size());
for (Resolver::Vec::const_iterator it = resolvers().begin(), end = resolvers().end(); end != it;
@@ -130,9 +131,9 @@ TEST_F(ResolverUnitTest, MultiTimeout) {
starve_thread_pool(200);
// Use shortest possible timeout for all requests
- resolver->resolve(loop(), "localhost", 9042, 1);
- resolver->resolve(loop(), "localhost", 9042, 1);
- resolver->resolve(loop(), "localhost", 9042, 1);
+ resolver->resolve(loop(), "localhost", 9042, 1, "localhost");
+ resolver->resolve(loop(), "localhost", 9042, 1, "localhost");
+ resolver->resolve(loop(), "localhost", 9042, 1, "localhost");
run_loop();
ASSERT_EQ(3u, resolvers().size());
@@ -145,9 +146,9 @@ TEST_F(ResolverUnitTest, MultiTimeout) {
TEST_F(ResolverUnitTest, MultiInvalid) {
MultiResolver::Ptr resolver(create_multi());
- resolver->resolve(loop(), "doesnotexist1.dne", 9042, RESOLVE_TIMEOUT);
- resolver->resolve(loop(), "doesnotexist2.dne", 9042, RESOLVE_TIMEOUT);
- resolver->resolve(loop(), "doesnotexist3.dne", 9042, RESOLVE_TIMEOUT);
+ resolver->resolve(loop(), "doesnotexist1.dne", 9042, RESOLVE_TIMEOUT, "doesnotexist1.dne");
+ resolver->resolve(loop(), "doesnotexist2.dne", 9042, RESOLVE_TIMEOUT, "doesnotexist2.dne");
+ resolver->resolve(loop(), "doesnotexist3.dne", 9042, RESOLVE_TIMEOUT, "doesnotexist3.dne");
run_loop();
ASSERT_EQ(3u, resolvers().size());
for (Resolver::Vec::const_iterator it = resolvers().begin(), end = resolvers().end(); end != it;
@@ -159,9 +160,9 @@ TEST_F(ResolverUnitTest, MultiInvalid) {
TEST_F(ResolverUnitTest, MultiCancel) {
MultiResolver::Ptr resolver(create_multi());
- resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT);
- resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT);
- resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT);
+ resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost");
+ resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost");
+ resolver->resolve(loop(), "localhost", 9042, RESOLVE_TIMEOUT, "localhost");
resolver->cancel();
run_loop();
ASSERT_EQ(3u, resolvers().size());
diff --git a/tests/src/unit/tests/test_socket.cpp b/tests/src/unit/tests/test_socket.cpp
index 37bf730be..3939fae33 100644
--- a/tests/src/unit/tests/test_socket.cpp
+++ b/tests/src/unit/tests/test_socket.cpp
@@ -198,7 +198,8 @@ class SocketUnitTest : public LoopTest {
} else {
bool match = false;
do {
- Address address(res->ai_addr);
+ // Use a blank server name as it's not needed here.
+ Address address(res->ai_addr, String());
if (address.is_valid_and_resolved() && address == Address(DNS_IP_ADDRESS, 8888)) {
match = true;
break;