Skip to content

Commit ac9ebb0

Browse files
authored
Merge commit from fork
* Fix "Untrusted HTTP Header Handling (REMOTE*/LOCAL*)" * Fix "Untrusted HTTP Header Handling (X-Forwarded-For)" * Fix security problems in docker/main.cc
1 parent 11eed05 commit ac9ebb0

File tree

3 files changed

+355
-43
lines changed

3 files changed

+355
-43
lines changed

docker/main.cc

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -48,32 +48,10 @@ std::string get_error_time_format() {
4848
return ss.str();
4949
}
5050

51-
std::string get_client_ip(const Request &req) {
52-
// Check for X-Forwarded-For header first (common in reverse proxy setups)
53-
auto forwarded_for = req.get_header_value("X-Forwarded-For");
54-
if (!forwarded_for.empty()) {
55-
// Get the first IP if there are multiple
56-
auto comma_pos = forwarded_for.find(',');
57-
if (comma_pos != std::string::npos) {
58-
return forwarded_for.substr(0, comma_pos);
59-
}
60-
return forwarded_for;
61-
}
62-
63-
// Check for X-Real-IP header
64-
auto real_ip = req.get_header_value("X-Real-IP");
65-
if (!real_ip.empty()) { return real_ip; }
66-
67-
// Fallback to remote address (though cpp-httplib doesn't provide this
68-
// directly) For demonstration, we'll use a placeholder
69-
return "127.0.0.1";
70-
}
71-
7251
// NGINX Combined log format:
7352
// $remote_addr - $remote_user [$time_local] "$request" $status $body_bytes_sent
7453
// "$http_referer" "$http_user_agent"
7554
void nginx_access_logger(const Request &req, const Response &res) {
76-
auto remote_addr = get_client_ip(req);
7755
std::string remote_user =
7856
"-"; // cpp-httplib doesn't have built-in auth user tracking
7957
auto time_local = get_time_format();
@@ -86,7 +64,7 @@ void nginx_access_logger(const Request &req, const Response &res) {
8664
if (http_user_agent.empty()) http_user_agent = "-";
8765

8866
std::cout << std::format("{} - {} [{}] \"{}\" {} {} \"{}\" \"{}\"",
89-
remote_addr, remote_user, time_local, request,
67+
req.remote_addr, remote_user, time_local, request,
9068
status, body_bytes_sent, http_referer,
9169
http_user_agent)
9270
<< std::endl;
@@ -100,16 +78,15 @@ void nginx_error_logger(const Error &err, const Request *req) {
10078
std::string level = "error";
10179

10280
if (req) {
103-
auto client_ip = get_client_ip(*req);
10481
auto request =
10582
std::format("{} {} {}", req->method, req->path, req->version);
10683
auto host = req->get_header_value("Host");
10784
if (host.empty()) host = "-";
10885

10986
std::cerr << std::format("{} [{}] {}, client: {}, request: "
11087
"\"{}\", host: \"{}\"",
111-
time_local, level, to_string(err), client_ip,
112-
request, host)
88+
time_local, level, to_string(err),
89+
req->remote_addr, request, host)
11390
<< std::endl;
11491
} else {
11592
// If no request context, just log the error
@@ -131,6 +108,10 @@ void print_usage(const char *program_name) {
131108
std::cout << " Format: mount_point:document_root"
132109
<< std::endl;
133110
std::cout << " (default: /:./html)" << std::endl;
111+
std::cout << " --trusted-proxy <ip> Add trusted proxy IP address"
112+
<< std::endl;
113+
std::cout << " (can be specified multiple times)"
114+
<< std::endl;
134115
std::cout << " --version Show version information"
135116
<< std::endl;
136117
std::cout << " --help Show this help message" << std::endl;
@@ -140,13 +121,17 @@ void print_usage(const char *program_name) {
140121
<< " --host localhost --port 8080 --mount /:./html" << std::endl;
141122
std::cout << " " << program_name
142123
<< " --host 0.0.0.0 --port 3000 --mount /api:./api" << std::endl;
124+
std::cout << " " << program_name
125+
<< " --trusted-proxy 192.168.1.100 --trusted-proxy 10.0.0.1"
126+
<< std::endl;
143127
}
144128

145129
struct ServerConfig {
146130
std::string hostname = "localhost";
147131
int port = 8080;
148132
std::string mount_point = "/";
149133
std::string document_root = "./html";
134+
std::vector<std::string> trusted_proxies;
150135
};
151136

152137
enum class ParseResult { SUCCESS, HELP_REQUESTED, VERSION_REQUESTED, ERROR };
@@ -205,6 +190,14 @@ ParseResult parse_command_line(int argc, char *argv[], ServerConfig &config) {
205190
} else if (strcmp(argv[i], "--version") == 0) {
206191
std::cout << CPPHTTPLIB_VERSION << std::endl;
207192
return ParseResult::VERSION_REQUESTED;
193+
} else if (strcmp(argv[i], "--trusted-proxy") == 0) {
194+
if (i + 1 >= argc) {
195+
std::cerr << "Error: --trusted-proxy requires an IP address argument"
196+
<< std::endl;
197+
print_usage(argv[0]);
198+
return ParseResult::ERROR;
199+
}
200+
config.trusted_proxies.push_back(argv[++i]);
208201
} else {
209202
std::cerr << "Error: Unknown option '" << argv[i] << "'" << std::endl;
210203
print_usage(argv[0]);
@@ -218,6 +211,11 @@ bool setup_server(Server &svr, const ServerConfig &config) {
218211
svr.set_logger(nginx_access_logger);
219212
svr.set_error_logger(nginx_error_logger);
220213

214+
// Set trusted proxies if specified
215+
if (!config.trusted_proxies.empty()) {
216+
svr.set_trusted_proxies(config.trusted_proxies);
217+
}
218+
221219
auto ret = svr.set_mount_point(config.mount_point, config.document_root);
222220
if (!ret) {
223221
std::cerr
@@ -285,6 +283,16 @@ int main(int argc, char *argv[]) {
285283
<< std::endl;
286284
std::cout << "Mount point: " << config.mount_point << " -> "
287285
<< config.document_root << std::endl;
286+
287+
if (!config.trusted_proxies.empty()) {
288+
std::cout << "Trusted proxies: ";
289+
for (size_t i = 0; i < config.trusted_proxies.size(); ++i) {
290+
if (i > 0) std::cout << ", ";
291+
std::cout << config.trusted_proxies[i];
292+
}
293+
std::cout << std::endl;
294+
}
295+
288296
std::cout << "Press Ctrl+C to shutdown gracefully..." << std::endl;
289297

290298
auto ret = svr.listen(config.hostname, config.port);

httplib.h

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,6 +1132,8 @@ class Server {
11321132
Server &
11331133
set_header_writer(std::function<ssize_t(Stream &, Headers &)> const &writer);
11341134

1135+
Server &set_trusted_proxies(const std::vector<std::string> &proxies);
1136+
11351137
Server &set_keep_alive_max_count(size_t count);
11361138
Server &set_keep_alive_timeout(time_t sec);
11371139

@@ -1170,6 +1172,9 @@ class Server {
11701172
const std::function<void(Request &)> &setup_request);
11711173

11721174
std::atomic<socket_t> svr_sock_{INVALID_SOCKET};
1175+
1176+
std::vector<std::string> trusted_proxies_;
1177+
11731178
size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT;
11741179
time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND;
11751180
time_t read_timeout_sec_ = CPPHTTPLIB_SERVER_READ_TIMEOUT_SECOND;
@@ -4600,13 +4605,35 @@ inline bool zstd_decompressor::decompress(const char *data, size_t data_length,
46004605
}
46014606
#endif
46024607

4608+
inline bool is_prohibited_header_name(const std::string &name) {
4609+
using udl::operator""_t;
4610+
4611+
switch (str2tag(name)) {
4612+
case "REMOTE_ADDR"_t:
4613+
case "REMOTE_PORT"_t:
4614+
case "LOCAL_ADDR"_t:
4615+
case "LOCAL_PORT"_t: return true;
4616+
default: return false;
4617+
}
4618+
}
4619+
46034620
inline bool has_header(const Headers &headers, const std::string &key) {
4621+
if (is_prohibited_header_name(key)) { return false; }
46044622
return headers.find(key) != headers.end();
46054623
}
46064624

46074625
inline const char *get_header_value(const Headers &headers,
46084626
const std::string &key, const char *def,
46094627
size_t id) {
4628+
if (is_prohibited_header_name(key)) {
4629+
#ifndef CPPHTTPLIB_NO_EXCEPTIONS
4630+
std::string msg = "Prohibited header name '" + key + "' is specified.";
4631+
throw std::invalid_argument(msg);
4632+
#else
4633+
return "";
4634+
#endif
4635+
}
4636+
46104637
auto rng = headers.equal_range(key);
46114638
auto it = rng.first;
46124639
std::advance(it, static_cast<ssize_t>(id));
@@ -7501,6 +7528,12 @@ inline Server &Server::set_header_writer(
75017528
return *this;
75027529
}
75037530

7531+
inline Server &
7532+
Server::set_trusted_proxies(const std::vector<std::string> &proxies) {
7533+
trusted_proxies_ = proxies;
7534+
return *this;
7535+
}
7536+
75047537
inline Server &Server::set_keep_alive_max_count(size_t count) {
75057538
keep_alive_max_count_ = count;
75067539
return *this;
@@ -8289,6 +8322,40 @@ inline bool Server::dispatch_request_for_content_reader(
82898322
return false;
82908323
}
82918324

8325+
inline std::string
8326+
get_client_ip(const std::string &x_forwarded_for,
8327+
const std::vector<std::string> &trusted_proxies) {
8328+
// X-Forwarded-For is a comma-separated list per RFC 7239
8329+
std::vector<std::string> ip_list;
8330+
detail::split(x_forwarded_for.data(),
8331+
x_forwarded_for.data() + x_forwarded_for.size(), ',',
8332+
[&](const char *b, const char *e) {
8333+
auto r = detail::trim(b, e, 0, static_cast<size_t>(e - b));
8334+
ip_list.emplace_back(std::string(b + r.first, b + r.second));
8335+
});
8336+
8337+
for (size_t i = 0; i < ip_list.size(); ++i) {
8338+
auto ip = ip_list[i];
8339+
8340+
auto is_trusted_proxy =
8341+
std::any_of(trusted_proxies.begin(), trusted_proxies.end(),
8342+
[&](const std::string &proxy) { return ip == proxy; });
8343+
8344+
if (is_trusted_proxy) {
8345+
if (i == 0) {
8346+
// If the trusted proxy is the first IP, there's no preceding client IP
8347+
return ip;
8348+
} else {
8349+
// Return the IP immediately before the trusted proxy
8350+
return ip_list[i - 1];
8351+
}
8352+
}
8353+
}
8354+
8355+
// If no trusted proxy is found, return the first IP in the list
8356+
return ip_list.front();
8357+
}
8358+
82928359
inline bool
82938360
Server::process_request(Stream &strm, const std::string &remote_addr,
82948361
int remote_port, const std::string &local_addr,
@@ -8352,15 +8419,16 @@ Server::process_request(Stream &strm, const std::string &remote_addr,
83528419
connection_closed = true;
83538420
}
83548421

8355-
req.remote_addr = remote_addr;
8422+
if (!trusted_proxies_.empty() && req.has_header("X-Forwarded-For")) {
8423+
auto x_forwarded_for = req.get_header_value("X-Forwarded-For");
8424+
req.remote_addr = get_client_ip(x_forwarded_for, trusted_proxies_);
8425+
} else {
8426+
req.remote_addr = remote_addr;
8427+
}
83568428
req.remote_port = remote_port;
8357-
req.set_header("REMOTE_ADDR", req.remote_addr);
8358-
req.set_header("REMOTE_PORT", std::to_string(req.remote_port));
83598429

83608430
req.local_addr = local_addr;
83618431
req.local_port = local_port;
8362-
req.set_header("LOCAL_ADDR", req.local_addr);
8363-
req.set_header("LOCAL_PORT", std::to_string(req.local_port));
83648432

83658433
if (req.has_header("Accept")) {
83668434
const auto &accept_header = req.get_header_value("Accept");

0 commit comments

Comments
 (0)