Skip to content

Commit

Permalink
add context check during state machine initialization in client
Browse files Browse the repository at this point in the history
Summary: Add a check during state machine initialization to ensure that the FizzClientContext is compatible with the Factory. Only checking keyshares and ciphers for now. Sigschemes are a bit more complicated.

Reviewed By: mingtaoy

Differential Revision: D65295680

fbshipit-source-id: 1ed244a44b23ef964bbf3991beded1cc3cd6832f
  • Loading branch information
Zale Young authored and facebook-github-bot committed Jan 14, 2025
1 parent effa3fe commit e1f180a
Show file tree
Hide file tree
Showing 6 changed files with 226 additions and 14 deletions.
8 changes: 8 additions & 0 deletions fizz/client/ClientProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,12 @@ static ClientHello constructEncryptedClientHello(
return chloOuter;
}

static void checkContext(std::shared_ptr<const FizzClientContext>& context) {
#if FIZZ_ENABLE_CONTEXT_COMPATIBILITY_CHECKS
context->validate();
#endif
}

Actions
EventHandler<ClientTypes, StateEnum::Uninitialized, Event::Connect>::handle(
const State& /*state*/,
Expand All @@ -809,6 +815,8 @@ EventHandler<ClientTypes, StateEnum::Uninitialized, Event::Connect>::handle(

auto context = std::move(connect.context);

checkContext(context);

// Set up SNI (including possible replacement ECH SNI)
folly::Optional<std::string> echSni;
auto sni = std::move(connect.sni);
Expand Down
23 changes: 23 additions & 0 deletions fizz/client/FizzClientContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,28 @@ FizzClientContext::FizzClientContext()
: factory_(std::make_shared<DefaultFactory>()),
clock_(std::make_shared<SystemClock>()) {}

void FizzClientContext::validate() const {
// TODO: check supported sig schemes
for (auto& c : supportedCiphers_) {
if (!FIZZ_CONTEXT_VALIDATION_SHOULD_CHECK_CIPHER(c)) {
continue;
}
// will throw if factory doesn't support this cipher
factory_->makeAead(c);
}

for (auto& g : supportedGroups_) {
// will throw if factory doesn't support this named group
factory_->makeKeyExchange(g, KeyExchangeRole::Client);
}

for (auto& share : defaultShares_) {
if (std::find(supportedGroups_.begin(), supportedGroups_.end(), share) ==
supportedGroups_.end()) {
throw std::runtime_error("unsupported named group in default shares");
}
}
}

} // namespace client
} // namespace fizz
5 changes: 5 additions & 0 deletions fizz/client/FizzClientContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ class FizzClientContext {
return factory_;
}

/* Ensure that the TLS parameters set in this context are valid (eg.
* compatible with the factory, etc.). Will throw if invalid.
*/
virtual void validate() const;

/**
* Sets the certificate decompression manager for server certs.
*/
Expand Down
71 changes: 57 additions & 14 deletions fizz/client/test/ClientProtocolTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,13 @@ namespace test {

class ClientProtocolTest : public ProtocolTest<ClientTypes, Actions> {
public:
class ContextWithMockValidate : public FizzClientContext {
public:
MOCK_METHOD(void, validate, (), (const override));
};

void SetUp() override {
context_ = std::make_shared<FizzClientContext>();
context_ = std::make_shared<ContextWithMockValidate>();
context_->setSupportedVersions({ProtocolVersion::tls_1_3});
context_->setSupportedCiphers(
{CipherSuite::TLS_AES_128_GCM_SHA256,
Expand Down Expand Up @@ -279,7 +284,13 @@ class ClientProtocolTest : public ProtocolTest<ClientTypes, Actions> {

void doFinishedFlow(ClientAuthType authType);

std::shared_ptr<FizzClientContext> context_;
void maybeExpectValidate() {
#if FIZZ_ENABLE_CONTEXT_COMPATIBILITY_CHECKS
EXPECT_CALL(*context_, validate()).Times(1);
#endif
}

std::shared_ptr<ContextWithMockValidate> context_;
MockPlaintextReadRecordLayer* mockRead_;
MockPlaintextWriteRecordLayer* mockWrite_;
MockEncryptedWriteRecordLayer* mockEarlyWrite_;
Expand Down Expand Up @@ -355,6 +366,9 @@ TEST_F(ClientProtocolTest, TestConnectFlow) {
}));
return ret;
}));

maybeExpectValidate();

MockKeyExchange* mockKex;
EXPECT_CALL(
*factory_, makeKeyExchange(NamedGroup::x25519, KeyExchangeRole::Client))
Expand Down Expand Up @@ -398,6 +412,22 @@ TEST_F(ClientProtocolTest, TestConnectFlow) {
EXPECT_FALSE(state_.earlyDataParams().has_value());
}

TEST_F(ClientProtocolTest, TestConnectInvalidContext) {
#if FIZZ_ENABLE_CONTEXT_COMPATIBILITY_CHECKS
EXPECT_CALL(*context_, validate()).Times(1).WillRepeatedly(Invoke([]() {
throw std::runtime_error("unsupported parameter");
}));

Connect connect;
connect.context = context_;
fizz::Param param = std::move(connect);

auto actions = detail::processEvent(state_, param);

expectError<std::runtime_error>(actions, {}, "unsupported parameter");
#endif
}

TEST_F(ClientProtocolTest, TestConnectPskFlow) {
auto psk = getCachedPsk();
EXPECT_CALL(*factory_, makePlaintextReadRecordLayer())
Expand All @@ -420,6 +450,9 @@ TEST_F(ClientProtocolTest, TestConnectPskFlow) {
}));
return ret;
}));

maybeExpectValidate();

MockKeyExchange* mockKex;
EXPECT_CALL(
*factory_, makeKeyExchange(NamedGroup::x25519, KeyExchangeRole::Client))
Expand Down Expand Up @@ -516,6 +549,9 @@ TEST_F(ClientProtocolTest, TestConnectPskEarlyFlow) {
}));
return ret;
}));

maybeExpectValidate();

MockKeyExchange* mockKex;
EXPECT_CALL(
*factory_, makeKeyExchange(NamedGroup::x25519, KeyExchangeRole::Client))
Expand Down Expand Up @@ -832,6 +868,9 @@ TEST_F(ClientProtocolTest, TestConnectSniExtFirst) {
TEST_F(ClientProtocolTest, TestConnectMultipleShares) {
MockKeyExchange* mockKex1;
MockKeyExchange* mockKex2;

maybeExpectValidate();

EXPECT_CALL(
*factory_, makeKeyExchange(NamedGroup::x25519, KeyExchangeRole::Client))
.WillOnce(InvokeWithoutArgs([&mockKex1]() {
Expand All @@ -843,6 +882,7 @@ TEST_F(ClientProtocolTest, TestConnectMultipleShares) {
mockKex1 = ret.get();
return ret;
}));

EXPECT_CALL(
*factory_,
makeKeyExchange(NamedGroup::secp256r1, KeyExchangeRole::Client))
Expand Down Expand Up @@ -872,6 +912,9 @@ TEST_F(ClientProtocolTest, TestConnectMultipleShares) {

TEST_F(ClientProtocolTest, TestConnectCachedGroup) {
context_->setDefaultShares({NamedGroup::x25519});

maybeExpectValidate();

MockKeyExchange* mockKex;
EXPECT_CALL(
*factory_,
Expand Down Expand Up @@ -1051,8 +1094,8 @@ TEST_F(ClientProtocolTest, TestConnectECH) {
connect.sni = "www.hostname.com";
const auto& actualChlo = getDefaultClientHello();

// Two randoms should be generated, 1 for the client hello inner and 1 for the
// client hello outer.
// Two randoms should be generated, 1 for the client hello inner and 1 for
// the client hello outer.
EXPECT_CALL(*factory_, makeRandomBytes(_, 32)).Times(2);

fizz::Param param = std::move(connect);
Expand Down Expand Up @@ -1118,8 +1161,8 @@ TEST_F(ClientProtocolTest, TestConnectECHWithHybridSupportedGroup) {
connect.sni = "www.hostname.com";
const auto& actualChlo = getDefaultClientHello();

// Two randoms should be generated, 1 for the client hello inner and 1 for the
// client hello outer.
// Two randoms should be generated, 1 for the client hello inner and 1 for
// the client hello outer.
EXPECT_CALL(*factory_, makeRandomBytes(_, 32)).Times(2);

fizz::Param param = std::move(connect);
Expand Down Expand Up @@ -1186,8 +1229,8 @@ TEST_F(ClientProtocolTest, TestConnectECHWithAEGIS) {
connect.sni = "www.hostname.com";
const auto& actualChlo = getDefaultClientHello();

// Two randoms should be generated, 1 for the client hello inner and 1 for the
// client hello outer.
// Two randoms should be generated, 1 for the client hello inner and 1 for
// the client hello outer.
EXPECT_CALL(*factory_, makeRandomBytes(_, 32)).Times(2);

fizz::Param param = std::move(connect);
Expand Down Expand Up @@ -3063,8 +3106,8 @@ TEST_F(ClientProtocolTest, TestHelloRetryRequestECHFlow) {
// Add the extension to the inner one
chlo.extensions.push_back(encodeExtension(ech::InnerECHClientHello()));

// Save this one (the real one), then blank the legacy session id and emplace
// OuterExtensions for AAD construction
// Save this one (the real one), then blank the legacy session id and
// emplace OuterExtensions for AAD construction
auto encodedClientHelloInner = encodeHandshake(chlo.clone());

chlo.legacy_session_id = folly::IOBuf::copyBuffer("");
Expand Down Expand Up @@ -3334,8 +3377,8 @@ TEST_F(ClientProtocolTest, TestHelloRetryRequestECHRejectedFlow) {
// Add the extension to the inner one
chlo.extensions.push_back(encodeExtension(ech::InnerECHClientHello()));

// Save this one (the real one), then blank the legacy session id and emplace
// OuterExtensions for AAD construction
// Save this one (the real one), then blank the legacy session id and
// emplace OuterExtensions for AAD construction
auto encodedClientHelloInner = encodeHandshake(chlo.clone());

chlo.legacy_session_id = folly::IOBuf::copyBuffer("");
Expand Down Expand Up @@ -5897,8 +5940,8 @@ TEST_F(ClientProtocolTest, TestPskWithoutCerts) {
// Because CachedPsks can be serialized, and because certificates may fail
// to serialize for whatever reason, there may be an instance where a client
// uses a deserialized cached psk that does not contain either a client or
// a server certificate, but the PSK itself is valid (and the server accepted
// the offered PSK).
// a server certificate, but the PSK itself is valid (and the server
// accepted the offered PSK).
setupExpectingServerHello();

CachedPsk psk = getCachedPsk();
Expand Down
122 changes: 122 additions & 0 deletions fizz/client/test/FizzClientContextTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright (c) 2018-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>

#include <fizz/client/ClientProtocol.h>
#include <fizz/client/FizzClientContext.h>
#include <fizz/protocol/test/Mocks.h>
#include <fizz/protocol/test/ProtocolTest.h>

using namespace fizz::test;

namespace fizz {
namespace client {
namespace test {

class FizzClientContextTest : public ::testing::Test {
public:
void SetUp() override {
auto mockFactory = std::make_shared<MockFactory>();
mockFactory->setDefaults();
factory_ = mockFactory.get();

context_ = std::make_shared<FizzClientContext>(mockFactory);
}

void expectValidateThrows(std::string msg) {
try {
context_->validate();
} catch (const std::exception& error) {
EXPECT_THAT(error.what(), HasSubstr(msg));
return;
}
// shouldn't reach here
ASSERT_TRUE(false);
}

std::shared_ptr<FizzClientContext> context_;
MockFactory* factory_;
};

TEST_F(FizzClientContextTest, TestValidateUnsupportedCipher) {
const auto unsupportedCipher = static_cast<fizz::CipherSuite>(0xFFFF);
EXPECT_CALL(*factory_, makeAead(_)).WillRepeatedly([](CipherSuite cipher) {
if (cipher == unsupportedCipher) {
throw std::runtime_error("unsupported cipher");
} else {
return std::make_unique<MockAead>();
}
});

context_->setSupportedCiphers({unsupportedCipher});

expectValidateThrows("unsupported cipher");
}

TEST_F(FizzClientContextTest, TestValidateUnsupportedGroup) {
const auto unsupportedGroup = static_cast<fizz::NamedGroup>(0xFFFF);
EXPECT_CALL(*factory_, makeKeyExchange(_, _))
.WillRepeatedly([](NamedGroup group, KeyExchangeRole /*unused*/) {
if (group == unsupportedGroup) {
throw std::runtime_error("unsupported group");
} else {
return std::make_unique<MockKeyExchange>();
}
});

context_->setSupportedGroups({unsupportedGroup});

expectValidateThrows("unsupported group");
}

TEST_F(FizzClientContextTest, TestValidateUnsupportedDefaultShare) {
context_->setSupportedGroups(
{static_cast<fizz::NamedGroup>(0x01),
static_cast<fizz::NamedGroup>(0x02)});

context_->setDefaultShares(
{static_cast<fizz::NamedGroup>(0x02),
static_cast<fizz::NamedGroup>(0x03)});

expectValidateThrows("unsupported named group in default shares");
}

TEST_F(FizzClientContextTest, TestValidateSuccess) {
EXPECT_CALL(*factory_, makeAead(_)).WillRepeatedly([](CipherSuite cipher) {
if (cipher == static_cast<fizz::CipherSuite>(0xFFFF)) {
throw std::runtime_error("unsupported cipher");
} else {
return std::make_unique<MockAead>();
}
});
EXPECT_CALL(*factory_, makeKeyExchange(_, _))
.WillRepeatedly([](NamedGroup group, KeyExchangeRole /*unused*/) {
if (group == static_cast<fizz::NamedGroup>(0xFFFF)) {
throw std::runtime_error("unsupported group");
} else {
return std::make_unique<MockKeyExchange>();
}
});

context_->setSupportedCiphers(
{static_cast<fizz::CipherSuite>(0x01),
static_cast<fizz::CipherSuite>(0x02)});

context_->setSupportedGroups(
{static_cast<fizz::NamedGroup>(0x03),
static_cast<fizz::NamedGroup>(0x04)});

context_->setDefaultShares({static_cast<fizz::NamedGroup>(0x03)});

EXPECT_NO_THROW(context_->validate());
}
} // namespace test
} // namespace client
} // namespace fizz
11 changes: 11 additions & 0 deletions fizz/fizz-config.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,16 @@
#cmakedefine01 FIZZ_CERTIFICATE_USE_OPENSSL_CERT
#cmakedefine01 FIZZ_HAVE_OQS

#if !defined(FIZZ_ENABLE_CONTEXT_COMPATIBILITY_CHECKS)
#if defined(NDEBUG)
#define FIZZ_ENABLE_CONTEXT_COMPATIBILITY_CHECKS 0
#else
#define FIZZ_ENABLE_CONTEXT_COMPATIBILITY_CHECKS 1
#endif
#endif

#define FIZZ_CONTEXT_VALIDATION_SHOULD_CHECK_CIPHER(x) (true)

#define FIZZ_DEFAULT_FACTORY_HEADER <fizz/protocol/MultiBackendFactory.h>
#define FIZZ_DEFAULT_FACTORY ::fizz::MultiBackendFactory

0 comments on commit e1f180a

Please sign in to comment.