diff --git a/sniproxy.conf b/sniproxy.conf index b6bbcc9b..7bddb6c1 100644 --- a/sniproxy.conf +++ b/sniproxy.conf @@ -28,7 +28,17 @@ listen 80 { listen 443 { proto tls + +# The prefer keyword enables sniproxy to select which handler +# will take precedence if both ALPN and SNI tables are set, and +# the client sends both extensions. + + prefer alpn +# prefer sni + table https_hosts + ALPNtable service_hosts + fallback localhost:8080 } listen 192.0.2.10:80 { @@ -85,3 +95,9 @@ table { example.com 192.0.2.10 example.net 192.0.2.20 } + +ALPNtable service_hosts { + http/1.1 192.0.2.10 + http/2.0 192.0.2.20 + spdy/3 192.0.2.25 +} diff --git a/src/Makefile.am b/src/Makefile.am index fc3eaa1d..0792aeba 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -17,6 +17,7 @@ sniproxy_SOURCES = address.c \ connection.h \ http.c \ http.h \ + types.h \ listener.c \ listener.h \ logger.c \ diff --git a/src/backend.c b/src/backend.c index b279169b..b273c745 100644 --- a/src/backend.c +++ b/src/backend.c @@ -107,15 +107,17 @@ init_backend(struct Backend *backend) { } struct Backend * -lookup_backend(const struct Backend_head *head, const char *hostname) { +lookup_backend(const struct Backend_head *head, const char *name, size_t name_size) { struct Backend *iter; - if (hostname == NULL) - hostname = ""; + if (name == NULL) { + name = ""; + name_size = 0; + } STAILQ_FOREACH(iter, head, entries) if (pcre_exec(iter->hostname_re, NULL, - hostname, strlen(hostname), 0, 0, NULL, 0) >= 0) + name, name_size, 0, 0, NULL, 0) >= 0) return iter; return NULL; diff --git a/src/backend.h b/src/backend.h index 35e5dc7e..9c152ee8 100644 --- a/src/backend.h +++ b/src/backend.h @@ -34,7 +34,7 @@ STAILQ_HEAD(Backend_head, Backend); struct Backend { - char *hostname; + char *hostname; /* name actually */ struct Address *address; /* Runtime fields */ @@ -44,7 +44,7 @@ struct Backend { void add_backend(struct Backend_head *, struct Backend *); int init_backend(struct Backend *); -struct Backend *lookup_backend(const struct Backend_head *, const char *); +struct Backend *lookup_backend(const struct Backend_head *, const char *, size_t); int open_backend_socket(struct Backend *, const char *); void print_backend_config(FILE *, const struct Backend *); void remove_backend(struct Backend_head *, struct Backend *); diff --git a/src/config.c b/src/config.c index 17f7e988..54009bf9 100644 --- a/src/config.c +++ b/src/config.c @@ -41,6 +41,7 @@ static int accept_username(struct Config *, char *); static int accept_pidfile(struct Config *, char *); static int end_listener_stanza(struct Config *, struct Listener *); static int end_table_stanza(struct Config *, struct Table *); +static int end_alpn_table_stanza(struct Config *, struct Table *); static int end_backend(struct Table *, struct Backend *); static struct LoggerBuilder *new_logger_builder(); static int accept_logger_filename(struct LoggerBuilder *, char *); @@ -55,11 +56,21 @@ struct Keyword listener_stanza_grammar[] = { (int(*)(void *, char *))accept_listener_protocol, NULL, NULL}, + { "prefer", + NULL, + (int(*)(void *, char *))prefer_in_listener, + NULL, + NULL}, { "table", NULL, (int(*)(void *, char *))accept_listener_table_name, NULL, NULL}, + { "ALPNtable", + NULL, + (int(*)(void *, char *))accept_listener_alpn_table_name, + NULL, + NULL}, { "fallback", NULL, (int(*)(void *, char *))accept_listener_fallback_address, @@ -76,6 +87,14 @@ static struct Keyword table_stanza_grammar[] = { (int(*)(void *, void *))end_backend}, }; +static struct Keyword alpn_table_stanza_grammar[] = { + { NULL, + (void *(*)())new_backend, + (int(*)(void *, char *))accept_backend_arg, + NULL, + (int(*)(void *, void *))end_backend}, +}; + static struct Keyword logger_stanza_grammar[] = { { "filename", NULL, @@ -121,6 +140,11 @@ static struct Keyword global_grammar[] = { (int(*)(void *, char *))accept_table_arg, table_stanza_grammar, (int(*)(void *, void *))end_table_stanza}, + { "ALPNtable", + (void *(*)())new_table, + (int(*)(void *, char *))accept_table_arg, + alpn_table_stanza_grammar, + (int(*)(void *, void *))end_alpn_table_stanza}, { NULL, NULL, NULL, NULL, NULL } }; @@ -142,6 +166,7 @@ init_config(const char *filename) { config->pidfile = NULL; SLIST_INIT(&config->listeners); SLIST_INIT(&config->tables); + SLIST_INIT(&config->alpn_tables); config->filename = strdup(filename); if (config->filename == NULL) { @@ -186,6 +211,7 @@ free_config(struct Config *config) { free_listeners(&config->listeners); free_tables(&config->tables); + free_tables(&config->alpn_tables); free(config); } @@ -224,6 +250,10 @@ print_config(FILE *file, struct Config *config) { SLIST_FOREACH(table, &config->tables, entries) { print_table_config(file, table); } + + SLIST_FOREACH(table, &config->alpn_tables, entries) { + print_table_config(file, table); + } } static int @@ -271,6 +301,15 @@ end_table_stanza(struct Config *config, struct Table *table) { return 1; } +static int +end_alpn_table_stanza(struct Config *config, struct Table *table) { + /* TODO check table */ + + add_table(&config->alpn_tables, table); + + return 1; +} + static int end_backend(struct Table *table, struct Backend *backend) { /* TODO check backend */ diff --git a/src/config.h b/src/config.h index 32911b66..d133e17c 100644 --- a/src/config.h +++ b/src/config.h @@ -36,6 +36,7 @@ struct Config { char *pidfile; struct Listener_head listeners; struct Table_head tables; + struct Table_head alpn_tables; }; struct Config *init_config(const char *); diff --git a/src/connection.c b/src/connection.c index c8161bc0..33420327 100644 --- a/src/connection.c +++ b/src/connection.c @@ -299,14 +299,15 @@ static void handle_connection_client_hello(struct Connection *con, struct ev_loop *loop) { char buffer[1460]; /* TCP MSS over standard Ethernet and IPv4 */ ssize_t len; - char *hostname = NULL; int parse_result; char peer_ip[INET6_ADDRSTRLEN + 8]; int sockfd = -1; + struct ProtocolRes pres; len = buffer_peek(con->client.buffer, buffer, sizeof(buffer)); - parse_result = con->listener->protocol->parse_packet(buffer, len, &hostname); + parse_result = con->listener->protocol->parse_packet(con->listener, buffer, len, + &pres); if (parse_result == -1) { return; /* incomplete request: try again */ } else if (parse_result < -1) { @@ -329,12 +330,19 @@ handle_connection_client_hello(struct Connection *con, struct ev_loop *loop) { return; } } - con->hostname = hostname; + con->hostname = pres.name; + + if (pres.name == NULL) { + warn("name returned from parse_packet is null"); + return; + } + /* TODO break the remainder out into other states */ - /* lookup server for hostname and connect */ + /* lookup server for name and connect */ struct Address *server_address = - listener_lookup_server_address(con->listener, hostname); + listener_lookup_server_address(con->listener, pres.name, pres.name_size, pres.name_type); + if (server_address == NULL) { close_client_socket(con, loop); return; @@ -399,7 +407,7 @@ handle_connection_client_hello(struct Connection *con, struct ev_loop *loop) { } if (sockfd < 0) { - warn("Server connection failed to %s", hostname); + warn("Server connection failed to %.*s", pres.name_size, pres.name); close_client_socket(con, loop); return; } diff --git a/src/http.c b/src/http.c index 3c8ae92a..018d0e04 100644 --- a/src/http.c +++ b/src/http.c @@ -42,7 +42,7 @@ static const char http_503[] = "Connection: close\r\n\r\n" "Backend not available"; -static int parse_http_header(const char *, size_t, char **); +static int parse_http_header(const struct Listener *, const char*, size_t, struct ProtocolRes*); static int get_header(const char *, const char *, int, char **); static int next_header(const char **, int *); @@ -69,13 +69,15 @@ const struct Protocol *http_protocol = &http_protocol_st; * */ static int -parse_http_header(const char* data, size_t data_len, char **hostname) { +parse_http_header(const struct Listener * t, const char* data, size_t data_len, + struct ProtocolRes* pres) { int result, i; + char* hostname; - if (hostname == NULL) + if (pres == NULL) return -3; - result = get_header("Host:", data, data_len, hostname); + result = get_header("Host:", data, data_len, &hostname); if (result < 0) return result; @@ -85,12 +87,17 @@ parse_http_header(const char* data, size_t data_len, char **hostname) { * so we trim off port portion */ for (i = result - 1; i >= 0; i--) - if ((*hostname)[i] == ':') { - (*hostname)[i] = '\0'; + if (hostname[i] == ':') { + hostname[i] = '\0'; result = i; + break; } + pres->name = hostname; + pres->name_size = result; + pres->name_type = NTYPE_HOST; + return result; } diff --git a/src/listener.c b/src/listener.c index be1ace96..48d1bd66 100644 --- a/src/listener.c +++ b/src/listener.c @@ -55,11 +55,11 @@ static void accept_cb(struct ev_loop *, struct ev_io *, int); */ void init_listeners(struct Listener_head *listeners, - const struct Table_head *tables) { + const struct Table_head *tables, const struct Table_head *alpn_tables) { struct Listener *iter; SLIST_FOREACH(iter, listeners, entries) { - if (init_listener(iter, tables) < 0) { + if (init_listener(iter, tables, alpn_tables) < 0) { fprintf(stderr, "Failed to initialize listener\n"); print_listener_config(stderr, iter); exit(1); @@ -113,12 +113,33 @@ accept_listener_arg(struct Listener *listener, char *arg) { return 1; } +int +prefer_in_listener(struct Listener *listener, char *arg) { + if (strcasecmp(arg, "alpn") == 0) { + listener->prefer_alpn = 1; + } else { + listener->prefer_alpn = 0; + } + + return 1; +} + int accept_listener_table_name(struct Listener *listener, char *table_name) { if (listener->table_name == NULL) listener->table_name = strdup(table_name); else - fprintf(stderr, "Duplicate table_name: %s\n", table_name); + fprintf(stderr, "Duplicate table name: %s\n", table_name); + + return 1; +} + +int +accept_listener_alpn_table_name(struct Listener *listener, char *table_name) { + if (listener->alpn_table_name == NULL) + listener->alpn_table_name = strdup(table_name); + else + fprintf(stderr, "Duplicate ALPNtable name: %s\n", table_name); return 1; } @@ -210,16 +231,35 @@ valid_listener(const struct Listener *listener) { } int -init_listener(struct Listener *listener, const struct Table_head *tables) { +init_listener(struct Listener *listener, const struct Table_head *tables, + const struct Table_head *alpn_tables) { int sockfd; int on = 1; - listener->table = table_lookup(tables, listener->table_name); - if (listener->table == NULL) { - fprintf(stderr, "Table \"%s\" not defined\n", listener->table_name); - return -1; + listener->table = NULL; + listener->alpn_table = NULL; + + if (listener->table_name != NULL) { + listener->table = table_lookup(tables, listener->table_name); + if (listener->table == NULL) { + fprintf(stderr, "Table \"%s\" not defined\n", listener->table_name); + return -1; + } + init_table(listener->table); } - init_table(listener->table); + + if (listener->alpn_table_name != NULL) { + listener->alpn_table = table_lookup(alpn_tables, listener->alpn_table_name); + if (listener->alpn_table == NULL) { + fprintf(stderr, "ALPNTable \"%s\" not defined\n", listener->alpn_table_name); + return -1; + } + init_table(listener->alpn_table); + } + + /* Here listener->table and listener->alpn_table may both be null. + * In that case the default SNI table will be used. + */ /* If no port was specified on the fallback address, inherit the address * from the listening address */ @@ -261,10 +301,15 @@ init_listener(struct Listener *listener, const struct Table_head *tables) { struct Address * listener_lookup_server_address(const struct Listener *listener, - const char *hostname) { + const char *name, unsigned name_size, unsigned ntype) { struct Address *new_addr = NULL; - const struct Address *addr = - table_lookup_server_address(listener->table, hostname); + const struct Address *addr = NULL; + + if (ntype == NTYPE_ALPN && listener->alpn_table) { + addr = table_lookup_server_address(listener->alpn_table, name, name_size); + } else if (ntype == NTYPE_HOST && listener->table) { + addr = table_lookup_server_address(listener->table, name, name_size); + } if (addr == NULL) addr = listener->fallback_address; @@ -275,9 +320,9 @@ listener_lookup_server_address(const struct Listener *listener, int port = address_port(addr); if (address_is_wildcard(addr)) { - new_addr = new_address(hostname); + new_addr = new_address(name); if (new_addr == NULL) { - warn("Invalid hostname %s", hostname); + warn("Invalid name %s", name); return NULL; } diff --git a/src/listener.h b/src/listener.h index 82cdea66..32d22ffa 100644 --- a/src/listener.h +++ b/src/listener.h @@ -38,10 +38,13 @@ struct Listener { struct Address *address, *fallback_address; const struct Protocol *protocol; char *table_name; + char *alpn_table_name; + unsigned prefer_alpn; /* Runtime fields */ struct ev_io watcher; struct Table *table; + struct Table *alpn_table; SLIST_ENTRY(Listener) entries; }; @@ -49,18 +52,20 @@ struct Listener { struct Listener *new_listener(); int accept_listener_arg(struct Listener *, char *); int accept_listener_table_name(struct Listener *, char *); +int accept_listener_alpn_table_name(struct Listener *, char *); +int prefer_in_listener(struct Listener *, char *); int accept_listener_fallback_address(struct Listener *, char *); int accept_listener_protocol(struct Listener *, char *); void add_listener(struct Listener_head *, struct Listener *); -void init_listeners(struct Listener_head *, const struct Table_head *); +void init_listeners(struct Listener_head *, const struct Table_head *, const struct Table_head *); void remove_listener(struct Listener_head *, struct Listener *); void free_listeners(struct Listener_head *); int valid_listener(const struct Listener *); -int init_listener(struct Listener *, const struct Table_head *); +int init_listener(struct Listener *, const struct Table_head *, const struct Table_head *); struct Address *listener_lookup_server_address(const struct Listener *, - const char *); + const char *, unsigned, unsigned ntype); void print_listener_config(FILE *, const struct Listener *); void free_listener(struct Listener *); diff --git a/src/protocol.h b/src/protocol.h index a8b0df40..e007a362 100644 --- a/src/protocol.h +++ b/src/protocol.h @@ -26,10 +26,24 @@ #ifndef PROTOCOL_H #define PROTOCOL_H +#include "listener.h" + +enum NameType { + NTYPE_NONE = 0, + NTYPE_HOST = 1, + NTYPE_ALPN = 2, +}; + +struct ProtocolRes { + char* name; + unsigned name_size; + unsigned name_type; +}; + struct Protocol { const char *name; int default_port; - int (*parse_packet)(const char*, size_t, char **); + int (*parse_packet)(const struct Listener *, const char*, size_t, struct ProtocolRes*); const char *abort_message; size_t abort_message_len; }; diff --git a/src/server.c b/src/server.c index 1214e3d6..3126af83 100644 --- a/src/server.c +++ b/src/server.c @@ -56,7 +56,7 @@ init_server(struct Config *c) { /* ignore SIGPIPE, or it will kill us */ signal(SIGPIPE, SIG_IGN); - init_listeners(&config->listeners, &config->tables); + init_listeners(&config->listeners, &config->tables, &config->alpn_tables); } void diff --git a/src/table.c b/src/table.c index 420428cb..1fd36893 100644 --- a/src/table.c +++ b/src/table.c @@ -30,19 +30,13 @@ #include "backend.h" #include "address.h" #include "logger.h" - - -static inline struct Backend * -table_lookup_backend(const struct Table *table, const char *hostname) { - return lookup_backend(&table->backends, hostname); -} +#include "protocol.h" static inline void remove_table_backend(struct Table *table, struct Backend *backend) { remove_backend(&table->backends, backend); } - struct Table * new_table() { struct Table *table; @@ -122,12 +116,14 @@ remove_table(struct Table_head *tables, struct Table *table) { } const struct Address * -table_lookup_server_address(const struct Table *table, const char *hostname) { +table_lookup_server_address(const struct Table *table, + const char *name, + unsigned name_size) { struct Backend *b; - b = table_lookup_backend(table, hostname); + b = table_lookup_backend(table, name, name_size); if (b == NULL) { - info("No match found for %s", hostname); + info("No match found for %.*s", name_size, name); return NULL; } diff --git a/src/table.h b/src/table.h index 8fc55efb..0aefc6fd 100644 --- a/src/table.h +++ b/src/table.h @@ -39,6 +39,8 @@ struct Table { char *name; /* Runtime fields */ + unsigned prefer_alpn; + struct Backend_head backends; SLIST_ENTRY(Table) entries; }; @@ -48,12 +50,18 @@ int accept_table_arg(struct Table *, char *); void add_table(struct Table_head *, struct Table *); struct Table *table_lookup(const struct Table_head *, const char *); const struct Address *table_lookup_server_address(const struct Table *, - const char *); + const char *, unsigned); void print_table_config(FILE *, struct Table *); int valid_table(struct Table *); void free_table(struct Table *); void init_table(struct Table *); +static inline struct Backend * +table_lookup_backend(const struct Table *table, const char *name, unsigned name_size) +{ + return lookup_backend(&table->backends, name, name_size); +} + void free_tables(struct Table_head *); #endif diff --git a/src/tls.c b/src/tls.c index bdc31206..ba739a98 100644 --- a/src/tls.c +++ b/src/tls.c @@ -53,9 +53,12 @@ static const char tls_alert[] = { 0x02, 0x28, /* Fatal, handshake failure */ }; -static int parse_tls_header(const char *, size_t, char **); -static int parse_extensions(const char *, size_t, char **); -static int parse_server_name_extension(const char *, size_t, char **); +static int parse_tls_header(const struct Listener *, const char*, size_t, struct ProtocolRes*); +static int parse_extensions(const struct Listener *, const char *, size_t, struct ProtocolRes*); +static int parse_server_name_extension(const struct Listener *, const char *, size_t, struct ProtocolRes*); +static int +parse_alpn_extension(const struct Listener * t, const char *data, size_t data_len, + struct ProtocolRes* pres); static const struct Protocol tls_protocol_st = { .name = "tls", @@ -81,14 +84,16 @@ const struct Protocol *tls_protocol = &tls_protocol_st; * < -4 - Invalid TLS client hello */ static int -parse_tls_header(const char *data, size_t data_len, char **hostname) { +parse_tls_header(const struct Listener * t, const char* data, size_t data_len, + struct ProtocolRes* pres) +{ char tls_content_type; char tls_version_major; char tls_version_minor; size_t pos = TLS_HEADER_LEN; size_t len; - if (hostname == NULL) + if (pres == NULL) return -3; /* Check that our TCP payload is at least large enough for a TLS header */ @@ -166,13 +171,19 @@ parse_tls_header(const char *data, size_t data_len, char **hostname) { if (pos + len > data_len) return -5; - return parse_extensions(data + pos, len, hostname); + return parse_extensions(t, data + pos, len, pres); } static int -parse_extensions(const char *data, size_t data_len, char **hostname) { +parse_extensions(const struct Listener * l, const char *data, size_t data_len, + struct ProtocolRes* pres) { size_t pos = 0; - size_t len; + size_t sn_pos, alpn_pos; + size_t len, alpn_len, sn_len; + int ret; + + sn_pos = alpn_pos = 0; + alpn_len = sn_len = 0; /* Parse each 4 bytes for the extension header */ while (pos + 4 < data_len) { @@ -181,27 +192,50 @@ parse_extensions(const char *data, size_t data_len, char **hostname) { (unsigned char)data[pos + 3]; /* Check if it's a server name extension */ - if (data[pos] == 0x00 && data[pos + 1] == 0x00) { - /* There can be only one extension of each type, so we break - our state and move p to beinnging of the extension here */ - if (pos + 4 + len > data_len) - return -5; - return parse_server_name_extension(data + pos + 4, len, hostname); + if (data[pos] == 0x00) { + if (data[pos + 1] == 0x00) { /* server name */ + /* There can be only one extension of each type, so we break + our state and move p to beinnging of the extension here */ + if (pos + 4 + len > data_len) + return -5; + + sn_pos = pos + 4; + sn_len = len; + } else if (data[pos + 1] == 0x10) { /* ALPN */ + if (pos + 4 + len > data_len) + return -5; + + alpn_pos = pos + 4; + alpn_len = len; + } } pos += 4 + len; /* Advance to the next extension header */ } + /* Check we ended where we expected to */ if (pos != data_len) return -5; + if ((l->prefer_alpn && alpn_pos != 0) || (sn_pos == 0 && alpn_pos != 0)) { + ret = parse_alpn_extension(l, data + alpn_pos, alpn_len, pres); + /* if we fail allow fall back to SNI */ + if (ret >= 0) + return ret; + } + + if (sn_pos != 0) { + return parse_server_name_extension(l, data + sn_pos, sn_len, pres); + } + return -2; } static int -parse_server_name_extension(const char *data, size_t data_len, - char **hostname) { +parse_server_name_extension(const struct Listener * l, const char *data, size_t data_len, + struct ProtocolRes* pres) { size_t pos = 2; /* skip server name list length */ size_t len; + char* hostname; while (pos + 3 < data_len) { len = ((unsigned char)data[pos + 1] << 8) + @@ -212,15 +246,19 @@ parse_server_name_extension(const char *data, size_t data_len, switch (data[pos]) { /* name type */ case 0x00: /* host_name */ - *hostname = malloc(len + 1); - if (*hostname == NULL) { + hostname = malloc(len + 1); + if (hostname == NULL) { err("malloc() failure"); return -4; } - strncpy(*hostname, data + pos + 3, len); + strncpy(hostname, data + pos + 3, len); - (*hostname)[len] = '\0'; + hostname[len] = '\0'; + + pres->name = hostname; + pres->name_size = len; + pres->name_type = NTYPE_HOST; return len; default: @@ -235,3 +273,54 @@ parse_server_name_extension(const char *data, size_t data_len, return -2; } + +static int is_alpn_proto_known(const struct Listener * l, const char* name, unsigned name_size) +{ +#ifdef SELF_CONTAINED + return 1; +#else + if (table_lookup_backend(l->alpn_table, name, name_size) != NULL) + return 1; + return 0; +#endif +} + +static int +parse_alpn_extension(const struct Listener * l, const char *data, size_t data_len, + struct ProtocolRes* pres) { + size_t pos = 2; + size_t len; + char* hostname; + + while (pos + 1 < data_len) { + len = (unsigned char)data[pos]; + + if (pos + 1 + len > data_len) + return -5; + + if (len > 0 && is_alpn_proto_known(l, data + pos + 1, len)) { + hostname = malloc(len + 1); + if (hostname == NULL) { + err("malloc() failure"); + return -4; + } + + memcpy(hostname, data + pos + 1, len); + hostname[len] = '\0'; + + pres->name = hostname; + pres->name_size = len; + pres->name_type = NTYPE_ALPN; + + return len; + } else if (len > 0) { + debug("Unknown ALPN name: %.*s", (int)len, data + pos + 2); + } + pos += 1 + len; + } + /* Check we ended where we expected to */ + if (pos != data_len) + return -5; + + return -2; +} diff --git a/tests/Makefile.am b/tests/Makefile.am index 48397335..06db8c11 100644 --- a/tests/Makefile.am +++ b/tests/Makefile.am @@ -1,4 +1,4 @@ -AM_CPPFLAGS = -I$(top_srcdir)/src -g +AM_CPPFLAGS = -I$(top_srcdir)/src -g -DSELF_CONTAINED TESTS = http_test \ tls_test \ diff --git a/tests/http_test.c b/tests/http_test.c index d1cfdcde..e6056189 100644 --- a/tests/http_test.c +++ b/tests/http_test.c @@ -44,11 +44,14 @@ int main() { unsigned int i; int result; char *hostname; + struct Listener l; + + memset(&l, 0, sizeof(l)); for (i = 0; i < sizeof(good) / sizeof(const char *); i++) { hostname = NULL; - result = http_protocol->parse_packet(good[i], strlen(good[i]), &hostname); + result = http_protocol->parse_packet(&l, good[i], strlen(good[i]), &hostname); assert(result == 9); @@ -62,7 +65,7 @@ int main() { for (i = 0; i < sizeof(bad) / sizeof(const char *); i++) { hostname = NULL; - result = http_protocol->parse_packet(bad[i], strlen(bad[i]), &hostname); + result = http_protocol->parse_packet(&l, bad[i], strlen(bad[i]), &hostname); assert(result < 0); diff --git a/tests/tls_test.c b/tests/tls_test.c index 4edecffe..f0c8e879 100644 --- a/tests/tls_test.c +++ b/tests/tls_test.c @@ -7,6 +7,41 @@ struct test_packet { const char *packet; int len; + unsigned expected_type; + const char* expected_name; + int expected_res; +}; + +const unsigned char alpn_good_data_1[] = { + 0x16, 0x03, 0x00, 0x01, 0x16, 0x01, 0x00, 0x01, 0x12, 0x03, + 0x03, 0x53, 0x03, 0x59, 0x03, 0xa8, 0xb2, 0xa9, 0x36, 0x4c, + 0x2d, 0x04, 0x72, 0x4f, 0xea, 0x98, 0xd5, 0xb5, 0xbb, 0xea, + 0x07, 0x4f, 0x00, 0x83, 0x1c, 0xfa, 0xa0, 0x01, 0xcc, 0x7d, + 0x2f, 0x4f, 0x6f, 0x00, 0x00, 0x84, 0xc0, 0x2b, 0xc0, 0x2c, + 0xc0, 0x86, 0xc0, 0x87, 0xc0, 0x09, 0xc0, 0x23, 0xc0, 0x0a, + 0xc0, 0x24, 0xc0, 0x72, 0xc0, 0x73, 0xc0, 0x08, 0xc0, 0x07, + 0xc0, 0x2f, 0xc0, 0x30, 0xc0, 0x8a, 0xc0, 0x8b, 0xc0, 0x13, + 0xc0, 0x27, 0xc0, 0x14, 0xc0, 0x28, 0xc0, 0x76, 0xc0, 0x77, + 0xc0, 0x12, 0xc0, 0x11, 0x00, 0x9c, 0x00, 0x9d, 0xc0, 0x7a, + 0xc0, 0x7b, 0x00, 0x2f, 0x00, 0x3c, 0x00, 0x35, 0x00, 0x3d, + 0x00, 0x41, 0x00, 0xba, 0x00, 0x84, 0x00, 0xc0, 0x00, 0x0a, + 0x00, 0x05, 0x00, 0x04, 0x00, 0x9e, 0x00, 0x9f, 0xc0, 0x7c, + 0xc0, 0x7d, 0x00, 0x33, 0x00, 0x67, 0x00, 0x39, 0x00, 0x6b, + 0x00, 0x45, 0x00, 0xbe, 0x00, 0x88, 0x00, 0xc4, 0x00, 0x16, + 0x00, 0xa2, 0x00, 0xa3, 0xc0, 0x80, 0xc0, 0x81, 0x00, 0x32, + 0x00, 0x40, 0x00, 0x38, 0x00, 0x6a, 0x00, 0x44, 0x00, 0xbd, + 0x00, 0x87, 0x00, 0xc3, 0x00, 0x13, 0x00, 0x66, 0x01, 0x00, + 0x00, 0x65, 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x0c, 0x00, 0x00, 0x09, + 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x68, 0x6f, 0x73, 0x74, 0xff, + 0x01, 0x00, 0x01, 0x00, 0x00, 0x23, 0x00, 0x00, 0x00, 0x0a, + 0x00, 0x08, 0x00, 0x06, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, + 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, 0x00, 0x0d, 0x00, 0x1c, + 0x00, 0x1a, 0x04, 0x01, 0x04, 0x02, 0x04, 0x03, 0x05, 0x01, + 0x05, 0x03, 0x06, 0x01, 0x06, 0x03, 0x03, 0x01, 0x03, 0x02, + 0x03, 0x03, 0x02, 0x01, 0x02, 0x02, 0x02, 0x03, 0x00, 0x10, + 0x00, 0x0b, 0x00, 0x09, 0x08, 0x68, 0x74, 0x74, 0x70, 0x2f, + 0x32, 0x2e, 0x30 }; const unsigned char good_data_1[] = { @@ -59,44 +94,51 @@ const unsigned char bad_data_2[] = { }; static struct test_packet good[] = { - { (char *)good_data_1, sizeof(good_data_1) } + { (char *)good_data_1, sizeof(good_data_1), NTYPE_HOST, "localhost", 9 }, + { (char *)alpn_good_data_1, sizeof(alpn_good_data_1), NTYPE_ALPN, "http/2.0", 8 } }; static struct test_packet bad[] = { - { (char *)bad_data_1, sizeof(bad_data_1) }, - { (char *)bad_data_2, sizeof(bad_data_2) } + { (char *)bad_data_1, sizeof(bad_data_1), 0, "localhost"}, + { (char *)bad_data_2, sizeof(bad_data_2), 0, "localhost"} }; int main() { unsigned int i; int result; - char *hostname; + struct ProtocolRes res; + struct Listener l; + + memset(&l, 0, sizeof(l)); + l.prefer_alpn = 1; for (i = 0; i < sizeof(good) / sizeof(struct test_packet); i++) { - hostname = NULL; + memset(&res, 0, sizeof(res)); + + result = tls_protocol->parse_packet(&l, good[i].packet, good[i].len, &res); - result = tls_protocol->parse_packet(good[i].packet, good[i].len, &hostname); + assert(result == good[i].expected_res); - assert(result == 9); + assert(NULL != res.name); - assert(NULL != hostname); + assert(good[i].expected_type == res.name_type); - assert(0 == strcmp("localhost", hostname)); + assert(0 == strcmp(good[i].expected_name, res.name)); - free(hostname); + free(res.name); } for (i = 0; i < sizeof(bad) / sizeof(struct test_packet); i++) { - hostname = NULL; + memset(&res, 0, sizeof(res)); - result = tls_protocol->parse_packet(bad[i].packet, bad[i].len, &hostname); + result = tls_protocol->parse_packet(&l, bad[i].packet, bad[i].len, &res); // parse failure or not "localhost" assert(result < 0 || - hostname == NULL || - strcmp("localhost", hostname) != 0); + res.name == NULL || + strcmp(bad[i].expected_name, res.name) != 0); - free(hostname); + free(res.name); } return 0;