Skip to content

Commit f558e2c

Browse files
committed
refactor: centralize aws-chunked encoding in ChunkingInterceptor
Move aws-chunked encoding logic from individual HTTP clients to a centralized ChunkingInterceptor for better separation of concerns. - Add ChunkingInterceptor to handle aws-chunked encoding at request level - Remove custom chunking logic from CRT, Curl, and Windows HTTP clients - Simplify HTTP clients to focus on transport-only responsibilities - Maintain full backwards compatibility with existing APIs unit test for chunking stream added logic to detect custom http client and smart default reversing logic to check for chunked mode changing chunking interceptor to use array instead of vector
1 parent 5e2ccc1 commit f558e2c

File tree

18 files changed

+458
-94
lines changed

18 files changed

+458
-94
lines changed

src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,16 @@ namespace Aws
7878
WHEN_REQUIRED,
7979
};
8080

81+
/**
82+
* Control HTTP client chunking implementation mode.
83+
* DEFAULT: Use SDK's ChunkingInterceptor for aws-chunked encoding
84+
* CLIENT_IMPLEMENTATION: Rely on HTTP client's native chunking (default for custom clients)
85+
*/
86+
enum class HttpClientChunkedMode {
87+
DEFAULT,
88+
CLIENT_IMPLEMENTATION,
89+
};
90+
8191
struct RequestCompressionConfig {
8292
UseRequestCompression useRequestCompression=UseRequestCompression::ENABLE;
8393
size_t requestMinCompressionSizeBytes = 10240;
@@ -493,6 +503,12 @@ namespace Aws
493503
* https://docs.aws.amazon.com/sdkref/latest/guide/feature-account-endpoints.html
494504
*/
495505
Aws::String accountIdEndpointMode = "preferred";
506+
507+
/**
508+
* Control HTTP client chunking implementation mode.
509+
* Default is set automatically: CLIENT_IMPLEMENTATION for custom clients, DEFAULT for AWS clients.
510+
*/
511+
HttpClientChunkedMode httpClientChunkedMode = HttpClientChunkedMode::CLIENT_IMPLEMENTATION;
496512
/**
497513
* Configuration structure for credential providers in the AWS SDK.
498514
* This structure allows passing configuration options to credential providers

src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ namespace Aws
4848
*/
4949
virtual bool SupportsChunkedTransferEncoding() const { return true; }
5050

51+
/**
52+
* Returns true if this is a default AWS SDK HTTP client implementation.
53+
*/
54+
virtual bool IsDefaultAwsHttpClient() const { return false; }
55+
5156
/**
5257
* Stops all requests in progress and prevents any others from initiating.
5358
*/

src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ namespace Aws
5252
Aws::Utils::RateLimits::RateLimiterInterface* readLimiter,
5353
Aws::Utils::RateLimits::RateLimiterInterface* writeLimiter) const override;
5454

55+
bool IsDefaultAwsHttpClient() const override { return true; }
56+
5557
private:
5658
// Yeah, I know, but someone made MakeRequest() const and didn't think about the fact that
5759
// making an HTTP request most certainly mutates state. It was me. I'm the person that did that, and

src/aws-cpp-sdk-core/include/aws/core/http/curl/CurlHttpClient.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class AWS_CORE_API CurlHttpClient: public HttpClient
3737
Aws::Utils::RateLimits::RateLimiterInterface* readLimiter = nullptr,
3838
Aws::Utils::RateLimits::RateLimiterInterface* writeLimiter = nullptr) const override;
3939

40+
bool IsDefaultAwsHttpClient() const override { return true; }
41+
4042
static void InitGlobalState();
4143
static void CleanupGlobalState();
4244

src/aws-cpp-sdk-core/include/aws/core/http/windows/IXmlHttpRequest2HttpClient.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ namespace Aws
5454
*/
5555
virtual bool SupportsChunkedTransferEncoding() const override { return false; }
5656

57+
bool IsDefaultAwsHttpClient() const override { return true; }
58+
5759
protected:
5860
/**
5961
* Override any configuration on request handle.

src/aws-cpp-sdk-core/include/aws/core/http/windows/WinHttpSyncHttpClient.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ namespace Aws
4242
*/
4343
const char* GetLogTag() const override { return "WinHttpSyncHttpClient"; }
4444

45+
bool IsDefaultAwsHttpClient() const override { return true; }
46+
4547
private:
4648
// WinHttp specific implementations
4749
void* OpenRequest(const std::shared_ptr<HttpRequest>& request, void* connection, const Aws::StringStream& ss) const override;

src/aws-cpp-sdk-core/include/aws/core/http/windows/WinINetSyncHttpClient.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ namespace Aws
3939
* Gets log tag for use in logging in the base class.
4040
*/
4141
const char* GetLogTag() const override { return "WinInetSyncHttpClient"; }
42+
43+
bool IsDefaultAwsHttpClient() const override { return true; }
4244
private:
4345

4446
// WinHttp specific implementations

src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <smithy/tracing/TelemetryProvider.h>
1010
#include <smithy/interceptor/Interceptor.h>
1111
#include <smithy/client/features/ChecksumInterceptor.h>
12+
#include <smithy/client/features/ChunkingInterceptor.h>
1213
#include <smithy/client/features/UserAgentInterceptor.h>
1314

1415
#include <aws/crt/Variant.h>
@@ -20,6 +21,7 @@
2021
#include <aws/core/utils/Outcome.h>
2122
#include <aws/core/NoResult.h>
2223
#include <aws/core/http/HttpClientFactory.h>
24+
#include <aws/core/http/HttpClient.h>
2325
#include <aws/core/client/AWSErrorMarshaller.h>
2426
#include <aws/core/AmazonWebServiceResult.h>
2527
#include <utility>
@@ -99,8 +101,13 @@ namespace client
99101
m_serviceUserAgentName(std::move(serviceUserAgentName)),
100102
m_httpClient(std::move(httpClient)),
101103
m_errorMarshaller(std::move(errorMarshaller)),
102-
m_interceptors{Aws::MakeShared<ChecksumInterceptor>("AwsSmithyClientBase", *m_clientConfig)}
104+
m_interceptors({
105+
Aws::MakeShared<ChecksumInterceptor>("AwsSmithyClientBase", *m_clientConfig),
106+
Aws::MakeShared<features::ChunkingInterceptor>("AwsSmithyClientBase",
107+
m_httpClient->IsDefaultAwsHttpClient() ? Aws::Client::HttpClientChunkedMode::DEFAULT : m_clientConfig->httpClientChunkedMode)
108+
})
103109
{
110+
104111
baseInit();
105112
}
106113

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
/**
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
#pragma once
6+
#include <aws/core/http/HttpRequest.h>
7+
#include <aws/core/utils/Array.h>
8+
#include <aws/core/utils/StringUtils.h>
9+
#include <aws/core/utils/HashingUtils.h>
10+
#include <aws/core/utils/logging/LogMacros.h>
11+
#include <aws/core/utils/memory/stl/AWSStringStream.h>
12+
#include <aws/core/utils/memory/stl/AWSVector.h>
13+
#include <smithy/interceptor/Interceptor.h>
14+
#include <aws/core/client/ClientConfiguration.h>
15+
#include <aws/core/utils/Outcome.h>
16+
#include <aws/core/client/AWSError.h>
17+
#include <memory>
18+
19+
namespace smithy {
20+
namespace client {
21+
namespace features {
22+
23+
static const size_t AWS_DATA_BUFFER_SIZE = 65536;
24+
static const char* ALLOCATION_TAG = "ChunkingInterceptor";
25+
static const char* CHECKSUM_HEADER_PREFIX = "x-amz-checksum-";
26+
27+
template <size_t DataBufferSize = AWS_DATA_BUFFER_SIZE>
28+
class AwsChunkedStreamBuf : public std::streambuf {
29+
public:
30+
AwsChunkedStreamBuf(Aws::Http::HttpRequest* request,
31+
const std::shared_ptr<Aws::IOStream>& stream,
32+
size_t bufferSize = DataBufferSize)
33+
: m_request(request),
34+
m_stream(stream),
35+
m_data(bufferSize)
36+
{
37+
assert(m_stream != nullptr);
38+
if (m_stream == nullptr) {
39+
AWS_LOGSTREAM_ERROR("AwsChunkedStream", "stream is null");
40+
}
41+
assert(m_request != nullptr);
42+
if (m_request == nullptr) {
43+
AWS_LOGSTREAM_ERROR("AwsChunkedStream", "request is null");
44+
}
45+
}
46+
47+
protected:
48+
int_type underflow() override {
49+
if (gptr() && gptr() < egptr()) {
50+
return traits_type::to_int_type(*gptr());
51+
}
52+
53+
// only read and write to chunked stream if the underlying stream
54+
// is still in a valid state and we have buffer space
55+
if (m_stream->good() && m_chunkingBufferPos >= m_chunkingBufferSize) {
56+
// Reset buffer for new data only when buffer is consumed
57+
m_chunkingBufferPos = 0;
58+
m_chunkingBufferSize = 0;
59+
60+
// Check if we have enough space for worst-case chunk (data + header + footer)
61+
size_t maxChunkSize = m_data.GetLength() + 20; // data + hex header + CRLF
62+
if (m_chunkingBufferSize + maxChunkSize <= m_chunkingBuffer.GetLength()) {
63+
// Try to read in a 64K chunk, if we cant we know the stream is over
64+
m_stream->read(m_data.GetUnderlyingData(), m_data.GetLength());
65+
size_t bytesRead = static_cast<size_t>(m_stream->gcount());
66+
writeChunk(bytesRead);
67+
68+
// if we've read everything from the stream, we want to add the trailer
69+
// to the underlying stream
70+
if ((m_stream->peek() == EOF || m_stream->eof()) && !m_stream->bad()) {
71+
writeTrailerToUnderlyingStream();
72+
}
73+
}
74+
}
75+
76+
// if the chunking buffer is empty there is nothing to read
77+
if (m_chunkingBufferPos >= m_chunkingBufferSize) {
78+
return traits_type::eof();
79+
}
80+
81+
// Set up buffer pointers to read from chunking buffer
82+
size_t remainingBytes = m_chunkingBufferSize - m_chunkingBufferPos;
83+
size_t bytesToRead = std::min(remainingBytes, DataBufferSize);
84+
85+
setg(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferPos,
86+
m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferPos,
87+
m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferPos + bytesToRead);
88+
89+
m_chunkingBufferPos += bytesToRead;
90+
91+
return traits_type::to_int_type(*gptr());
92+
}
93+
94+
private:
95+
void writeTrailerToUnderlyingStream() {
96+
Aws::String trailer = "0\r\n";
97+
if (m_request->GetRequestHash().second != nullptr) {
98+
trailer += "x-amz-checksum-" + m_request->GetRequestHash().first + ":"
99+
+ Aws::Utils::HashingUtils::Base64Encode(m_request->GetRequestHash().second->GetHash().GetResult()) + "\r\n";
100+
}
101+
trailer += "\r\n";
102+
if (m_chunkingBufferSize + trailer.length() <= m_chunkingBuffer.GetLength()) {
103+
std::memcpy(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferSize, trailer.c_str(), trailer.length());
104+
m_chunkingBufferSize += trailer.length();
105+
}
106+
}
107+
108+
void writeChunk(size_t bytesRead) {
109+
if (m_request->GetRequestHash().second != nullptr) {
110+
m_request->GetRequestHash().second->Update(reinterpret_cast<unsigned char*>(m_data.GetUnderlyingData()), bytesRead);
111+
}
112+
113+
if (bytesRead > 0) {
114+
Aws::String chunkHeader = Aws::Utils::StringUtils::ToHexString(bytesRead) + "\r\n";
115+
size_t totalSize = chunkHeader.length() + bytesRead + 2;
116+
if (m_chunkingBufferSize + totalSize <= m_chunkingBuffer.GetLength()) {
117+
std::memcpy(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferSize, chunkHeader.c_str(), chunkHeader.length());
118+
m_chunkingBufferSize += chunkHeader.length();
119+
std::memcpy(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferSize, m_data.GetUnderlyingData(), bytesRead);
120+
m_chunkingBufferSize += bytesRead;
121+
std::memcpy(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferSize, "\r\n", 2);
122+
m_chunkingBufferSize += 2;
123+
}
124+
}
125+
}
126+
127+
// Buffer for chunked data plus overhead for HTTP chunked encoding headers, trailers, and safety margin
128+
Aws::Utils::Array<char> m_chunkingBuffer{DataBufferSize + 128};
129+
size_t m_chunkingBufferSize{0};
130+
size_t m_chunkingBufferPos{0};
131+
Aws::Http::HttpRequest* m_request{nullptr};
132+
std::shared_ptr<Aws::IOStream> m_stream;
133+
Aws::Utils::Array<char> m_data;
134+
};
135+
136+
class AwsChunkedIOStream : public Aws::IOStream {
137+
public:
138+
AwsChunkedIOStream(Aws::Http::HttpRequest* request,
139+
const std::shared_ptr<Aws::IOStream>& originalBody,
140+
size_t bufferSize = AWS_DATA_BUFFER_SIZE)
141+
: Aws::IOStream(&m_buf),
142+
m_buf(request, originalBody, bufferSize) {}
143+
144+
private:
145+
AwsChunkedStreamBuf<> m_buf;
146+
};
147+
148+
/**
149+
* Interceptor that handles chunked encoding for streaming requests with checksums.
150+
* Wraps request body with chunked stream and sets appropriate headers.
151+
*/
152+
class ChunkingInterceptor : public smithy::interceptor::Interceptor {
153+
public:
154+
explicit ChunkingInterceptor(Aws::Client::HttpClientChunkedMode httpClientChunkedMode)
155+
: m_httpClientChunkedMode(httpClientChunkedMode) {}
156+
~ChunkingInterceptor() override = default;
157+
158+
ModifyRequestOutcome ModifyBeforeSigning(smithy::interceptor::InterceptorContext& context) override {
159+
auto request = context.GetTransmitRequest();
160+
161+
if (!ShouldApplyChunking(request)) {
162+
return request;
163+
}
164+
165+
auto originalBody = request->GetContentBody();
166+
if (!originalBody) {
167+
return request;
168+
}
169+
170+
// Set up chunked encoding headers for checksum calculation
171+
const auto& hashPair = request->GetRequestHash();
172+
if (hashPair.second != nullptr) {
173+
Aws::String checksumHeaderValue = Aws::String(CHECKSUM_HEADER_PREFIX) + hashPair.first;
174+
request->DeleteHeader(checksumHeaderValue.c_str());
175+
request->SetHeaderValue(Aws::Http::AWS_TRAILER_HEADER, checksumHeaderValue);
176+
request->SetTransferEncoding(Aws::Http::CHUNKED_VALUE);
177+
178+
if (!request->HasContentEncoding()) {
179+
request->SetContentEncoding(Aws::Http::AWS_CHUNKED_VALUE);
180+
} else {
181+
Aws::String currentEncoding = request->GetContentEncoding();
182+
if (currentEncoding.find(Aws::Http::AWS_CHUNKED_VALUE) == Aws::String::npos) {
183+
request->SetContentEncoding(Aws::String{Aws::Http::AWS_CHUNKED_VALUE} + "," + currentEncoding);
184+
}
185+
}
186+
187+
if (request->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER)) {
188+
request->SetHeaderValue(Aws::Http::DECODED_CONTENT_LENGTH_HEADER, request->GetHeaderValue(Aws::Http::CONTENT_LENGTH_HEADER));
189+
request->DeleteHeader(Aws::Http::CONTENT_LENGTH_HEADER);
190+
}
191+
}
192+
193+
auto chunkedBody = Aws::MakeShared<AwsChunkedIOStream>(
194+
ALLOCATION_TAG, request.get(), originalBody);
195+
196+
request->AddContentBody(chunkedBody);
197+
return request;
198+
}
199+
200+
ModifyResponseOutcome ModifyBeforeDeserialization(smithy::interceptor::InterceptorContext& context) override {
201+
return context.GetTransmitResponse();
202+
}
203+
204+
private:
205+
bool ShouldApplyChunking(const std::shared_ptr<Aws::Http::HttpRequest>& request) const {
206+
// Use configuration setting to determine chunking behavior
207+
if (m_httpClientChunkedMode != Aws::Client::HttpClientChunkedMode::DEFAULT) {
208+
return false;
209+
}
210+
211+
if (!request || !request->GetContentBody()) {
212+
return false;
213+
}
214+
215+
// Check if request has checksum requirements
216+
const auto& hashPair = request->GetRequestHash();
217+
return hashPair.second != nullptr;
218+
}
219+
220+
Aws::Client::HttpClientChunkedMode m_httpClientChunkedMode;
221+
};
222+
223+
} // namespace features
224+
} // namespace client
225+
} // namespace smithy

src/aws-cpp-sdk-core/source/auth/signer/AWSAuthV4Signer.cpp

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -218,26 +218,10 @@ bool AWSAuthV4Signer::SignRequestWithCreds(Aws::Http::HttpRequest& request, cons
218218
request.SetAwsSessionToken(credentials.GetSessionToken());
219219
}
220220

221-
// If the request checksum, set the signer to use a unsigned
222-
// trailing payload. otherwise use it in the header
223-
if (request.GetRequestHash().second != nullptr && !request.GetRequestHash().first.empty() && request.GetContentBody() != nullptr) {
224-
AWS_LOGSTREAM_DEBUG(v4LogTag, "Note: Http payloads are not being signed. signPayloads="
225-
<< signBody << " http scheme=" << Http::SchemeMapper::ToString(request.GetUri().GetScheme()));
226-
if (request.GetRequestHash().second != nullptr) {
221+
// If the request has checksum and chunking was applied by interceptor, use streaming payload
222+
if (request.GetRequestHash().second != nullptr && !request.GetRequestHash().first.empty() &&
223+
request.GetContentBody() != nullptr && request.HasHeader(Http::AWS_TRAILER_HEADER)) {
227224
payloadHash = STREAMING_UNSIGNED_PAYLOAD_TRAILER;
228-
Aws::String checksumHeaderValue = Aws::String("x-amz-checksum-") + request.GetRequestHash().first;
229-
request.DeleteHeader(checksumHeaderValue.c_str());
230-
request.SetHeaderValue(Http::AWS_TRAILER_HEADER, checksumHeaderValue);
231-
request.SetTransferEncoding(CHUNKED_VALUE);
232-
request.HasContentEncoding()
233-
? request.SetContentEncoding(Aws::String{Http::AWS_CHUNKED_VALUE} + "," + request.GetContentEncoding())
234-
: request.SetContentEncoding(Http::AWS_CHUNKED_VALUE);
235-
236-
if (request.HasHeader(Http::CONTENT_LENGTH_HEADER)) {
237-
request.SetHeaderValue(Http::DECODED_CONTENT_LENGTH_HEADER, request.GetHeaderValue(Http::CONTENT_LENGTH_HEADER));
238-
request.DeleteHeader(Http::CONTENT_LENGTH_HEADER);
239-
}
240-
}
241225
} else {
242226
payloadHash = ComputePayloadHash(request);
243227
if (payloadHash.empty()) {

0 commit comments

Comments
 (0)