diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 6807fd5c..2db3154e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,5 +14,5 @@ jobs: run: | curl -ssL https://magic.modular.com | bash source $HOME/.bash_profile - magic run mojo run_tests.mojo + magic run test diff --git a/.gitignore b/.gitignore index 06727c22..b89f8c78 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,7 @@ install_id .magic # Rattler -output \ No newline at end of file +output + +# misc +.vscode \ No newline at end of file diff --git a/README.md b/README.md index b2cc4456..304bfa39 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,8 @@ Once you have a Mojo project set up locally, ``` or import individual structs and functions, e.g. ```mojo - from lightbug_http.http import HTTPService, HTTPRequest, HTTPResponse, OK, NotFound + from lightbug_http.service import HTTPService + from lightbug_http.http import HTTPRequest, HTTPResponse, OK, NotFound ``` there are some default handlers you can play with: ```mojo diff --git a/bench.mojo b/bench.mojo index 744a1d2d..1c455651 100644 --- a/bench.mojo +++ b/bench.mojo @@ -4,12 +4,6 @@ from lightbug_http.header import Headers, Header from lightbug_http.utils import ByteReader, ByteWriter from lightbug_http.http import HTTPRequest, HTTPResponse, encode from lightbug_http.uri import URI -from tests.utils import ( - TestStruct, - FakeResponder, - new_fake_listener, - FakeServer, -) alias headers = bytes( """GET /index.html HTTP/1.1\r\nHost: example.com\r\nUser-Agent: Mozilla/5.0\r\nContent-Type: text/html\r\nContent-Length: 1234\r\nConnection: close\r\nTrailer: end-of-message\r\n\r\n""" @@ -148,43 +142,3 @@ fn lightbug_benchmark_header_parse(inout b: Bencher): b.iter[header_parse]() - -fn lightbug_benchmark_server(): - var server_report = benchmark.run[run_fake_server](max_iters=1) - print("Server: ") - server_report.print(benchmark.Unit.ms) - - -fn lightbug_benchmark_misc() -> None: - var direct_set_report = benchmark.run[init_test_and_set_a_direct]( - max_iters=1 - ) - - var recreating_set_report = benchmark.run[init_test_and_set_a_copy]( - max_iters=1 - ) - - print("Direct set: ") - direct_set_report.print(benchmark.Unit.ms) - print("Recreating set: ") - recreating_set_report.print(benchmark.Unit.ms) - - -var GetRequest = HTTPRequest(URI.parse("http://127.0.0.1/path")[URI]) - - -fn run_fake_server(): - var handler = FakeResponder() - var listener = new_fake_listener(2, encode(GetRequest)) - var server = FakeServer(listener, handler) - server.serve() - - -fn init_test_and_set_a_copy() -> None: - var test = TestStruct("a", "b") - _ = test.set_a_copy("c") - - -fn init_test_and_set_a_direct() -> None: - var test = TestStruct("a", "b") - _ = test.set_a_direct("c") diff --git a/lightbug_http/header.mojo b/lightbug_http/header.mojo index 86bb0459..3e9db7a3 100644 --- a/lightbug_http/header.mojo +++ b/lightbug_http/header.mojo @@ -12,6 +12,8 @@ struct HeaderKey: alias CONTENT_LENGTH = "content-length" alias CONTENT_ENCODING = "content-encoding" alias DATE = "date" + alias LOCATION = "location" + alias HOST = "host" @value @@ -70,16 +72,12 @@ struct Headers(Formattable, Stringable): self._inner[key.lower()] = value fn content_length(self) -> Int: - if HeaderKey.CONTENT_LENGTH not in self: - return 0 try: return int(self[HeaderKey.CONTENT_LENGTH]) except: return 0 - fn parse_raw( - inout self, inout r: ByteReader - ) raises -> (String, String, String): + fn parse_raw(inout self, inout r: ByteReader) raises -> (String, String, String): var first_byte = r.peek() if not first_byte: raise Error("Failed to read first byte from response header") diff --git a/lightbug_http/http.mojo b/lightbug_http/http.mojo index 71b8183f..6d8804a3 100644 --- a/lightbug_http/http.mojo +++ b/lightbug_http/http.mojo @@ -35,6 +35,15 @@ fn encode(owned res: HTTPResponse) -> Bytes: return res._encoded() +struct StatusCode: + alias OK = 200 + alias MOVED_PERMANENTLY = 301 + alias FOUND = 302 + alias TEMPORARY_REDIRECT = 307 + alias PERMANENT_REDIRECT = 308 + alias NOT_FOUND = 404 + + @value struct HTTPRequest(Formattable, Stringable): var headers: Headers @@ -48,9 +57,7 @@ struct HTTPRequest(Formattable, Stringable): var timeout: Duration @staticmethod - fn from_bytes( - addr: String, max_body_size: Int, owned b: Bytes - ) raises -> HTTPRequest: + fn from_bytes(addr: String, max_body_size: Int, owned b: Bytes) raises -> HTTPRequest: var reader = ByteReader(b^) var headers = Headers() var method: String @@ -65,16 +72,10 @@ struct HTTPRequest(Formattable, Stringable): var content_length = headers.content_length() - if ( - content_length > 0 - and max_body_size > 0 - and content_length > max_body_size - ): + if content_length > 0 and max_body_size > 0 and content_length > max_body_size: raise Error("Request body too large") - var request = HTTPRequest( - uri, headers=headers, method=method, protocol=protocol - ) + var request = HTTPRequest(uri, headers=headers, method=method, protocol=protocol) try: request.read_body(reader, content_length, max_body_size) @@ -103,6 +104,8 @@ struct HTTPRequest(Formattable, Stringable): self.set_content_length(len(body)) if HeaderKey.CONNECTION not in self.headers: self.set_connection_close() + if HeaderKey.HOST not in self.headers: + self.headers[HeaderKey.HOST] = uri.host fn set_connection_close(inout self): self.headers[HeaderKey.CONNECTION] = "close" @@ -114,20 +117,22 @@ struct HTTPRequest(Formattable, Stringable): return self.headers[HeaderKey.CONNECTION] == "close" @always_inline - fn read_body( - inout self, inout r: ByteReader, content_length: Int, max_body_size: Int - ) raises -> None: + fn read_body(inout self, inout r: ByteReader, content_length: Int, max_body_size: Int) raises -> None: if content_length > max_body_size: raise Error("Request body too large") - r.consume(self.body_raw) + r.consume(self.body_raw, content_length) self.set_content_length(content_length) fn format_to(self, inout writer: Formatter): + writer.write(self.method, whitespace) + path = self.uri.path if len(self.uri.path) > 1 else strSlash + if len(self.uri.query_string) > 0: + path += "?" + self.uri.query_string + + writer.write(path) + writer.write( - self.method, - whitespace, - self.uri.path if len(self.uri.path) > 1 else strSlash, whitespace, self.protocol, lineBreak, @@ -147,6 +152,8 @@ struct HTTPRequest(Formattable, Stringable): writer.write(self.method) writer.write(whitespace) var path = self.uri.path if len(self.uri.path) > 1 else strSlash + if len(self.uri.query_string) > 0: + path += "?" + self.uri.query_string writer.write(path) writer.write(whitespace) writer.write(self.protocol) @@ -215,8 +222,16 @@ struct HTTPResponse(Formattable, Stringable): self.status_text = status_text self.protocol = protocol self.body_raw = body_bytes - self.set_connection_keep_alive() - self.set_content_length(len(body_bytes)) + if HeaderKey.CONNECTION not in self.headers: + self.set_connection_keep_alive() + if HeaderKey.CONTENT_LENGTH not in self.headers: + self.set_content_length(len(body_bytes)) + if HeaderKey.DATE not in self.headers: + try: + var current_time = now(utc=True).__str__() + self.headers[HeaderKey.DATE] = current_time + except: + pass fn get_body_bytes(self) -> Bytes: return self.body_raw @@ -236,9 +251,25 @@ struct HTTPResponse(Formattable, Stringable): fn set_content_length(inout self, l: Int): self.headers[HeaderKey.CONTENT_LENGTH] = str(l) + @always_inline + fn content_length(inout self) -> Int: + try: + return int(self.headers[HeaderKey.CONTENT_LENGTH]) + except: + return 0 + + fn is_redirect(self) -> Bool: + return ( + self.status_code == StatusCode.MOVED_PERMANENTLY + or self.status_code == StatusCode.FOUND + or self.status_code == StatusCode.TEMPORARY_REDIRECT + or self.status_code == StatusCode.PERMANENT_REDIRECT + ) + @always_inline fn read_body(inout self, inout r: ByteReader) raises -> None: - r.consume(self.body_raw) + r.consume(self.body_raw, self.content_length()) + self.set_content_length(len(self.body_raw)) fn format_to(self, inout writer: Formatter): writer.write( @@ -252,13 +283,6 @@ struct HTTPResponse(Formattable, Stringable): lineBreak, ) - if HeaderKey.DATE not in self.headers: - try: - var current_time = now(utc=True).__str__() - write_header(writer, HeaderKey.DATE, current_time) - except: - pass - self.headers.format_to(writer) writer.write(lineBreak) @@ -326,9 +350,7 @@ fn OK(body: Bytes, content_type: String) -> HTTPResponse: ) -fn OK( - body: Bytes, content_type: String, content_encoding: String -) -> HTTPResponse: +fn OK(body: Bytes, content_type: String, content_encoding: String) -> HTTPResponse: return HTTPResponse( headers=Headers( Header(HeaderKey.CONTENT_TYPE, content_type), diff --git a/lightbug_http/libc.mojo b/lightbug_http/libc.mojo index 506eff23..81d87b32 100644 --- a/lightbug_http/libc.mojo +++ b/lightbug_http/libc.mojo @@ -460,9 +460,7 @@ fn inet_ntop( ](af, src, dst, size) -fn inet_pton( - af: c_int, src: UnsafePointer[c_char], dst: UnsafePointer[c_void] -) -> c_int: +fn inet_pton(af: c_int, src: UnsafePointer[c_char], dst: UnsafePointer[c_void]) -> c_int: """Libc POSIX `inet_pton` function Reference: https://man7.org/linux/man-pages/man3/inet_ntop.3p.html Fn signature: int inet_pton(int af, const char *restrict src, void *restrict dst). @@ -513,9 +511,7 @@ fn socket(domain: c_int, type: c_int, protocol: c_int) -> c_int: protocol: The protocol to use. Returns: A File Descriptor or -1 in case of failure. """ - return external_call[ - "socket", c_int, c_int, c_int, c_int # FnName, RetType # Args - ](domain, type, protocol) + return external_call["socket", c_int, c_int, c_int, c_int](domain, type, protocol) # FnName, RetType # Args fn setsockopt( @@ -593,16 +589,12 @@ fn getpeername( ](sockfd, addr, address_len) -fn bind( - socket: c_int, address: UnsafePointer[sockaddr], address_len: socklen_t -) -> c_int: +fn bind(socket: c_int, address: UnsafePointer[sockaddr], address_len: socklen_t) -> c_int: """Libc POSIX `bind` function Reference: https://man7.org/linux/man-pages/man3/bind.3p.html Fn signature: int bind(int socket, const struct sockaddr *address, socklen_t address_len). """ - return external_call[ - "bind", c_int, c_int, UnsafePointer[sockaddr], socklen_t - ](socket, address, address_len) + return external_call["bind", c_int, c_int, UnsafePointer[sockaddr], socklen_t](socket, address, address_len) fn listen(socket: c_int, backlog: c_int) -> c_int: @@ -640,9 +632,7 @@ fn accept( ](socket, address, address_len) -fn connect( - socket: c_int, address: Reference[sockaddr], address_len: socklen_t -) -> c_int: +fn connect(socket: c_int, address: Reference[sockaddr], address_len: socklen_t) -> c_int: """Libc POSIX `connect` function Reference: https://man7.org/linux/man-pages/man3/connect.3p.html Fn signature: int connect(int socket, const struct sockaddr *address, socklen_t address_len). @@ -675,9 +665,7 @@ fn recv( ](socket, buffer, length, flags) -fn send( - socket: c_int, buffer: UnsafePointer[c_void], length: c_size_t, flags: c_int -) -> c_ssize_t: +fn send(socket: c_int, buffer: UnsafePointer[c_void], length: c_size_t, flags: c_int) -> c_ssize_t: """Libc POSIX `send` function Reference: https://man7.org/linux/man-pages/man3/send.3p.html Fn signature: ssize_t send(int socket, const void *buffer, size_t length, int flags). @@ -700,11 +688,7 @@ fn shutdown(socket: c_int, how: c_int) -> c_int: how: How to shutdown the socket. Returns: 0 on success, -1 on error. """ - return external_call[ - "shutdown", c_int, c_int, c_int - ]( # FnName, RetType # Args - socket, how - ) + return external_call["shutdown", c_int, c_int, c_int](socket, how) # FnName, RetType # Args fn getaddrinfo( @@ -735,9 +719,7 @@ fn gai_strerror(ecode: c_int) -> UnsafePointer[c_char]: Args: ecode: The error code. Returns: A UnsafePointer to a string describing the error. """ - return external_call[ - "gai_strerror", UnsafePointer[c_char], c_int # FnName, RetType # Args - ](ecode) + return external_call["gai_strerror", UnsafePointer[c_char], c_int](ecode) # FnName, RetType # Args fn inet_pton(address_family: Int, address: String) -> Int: @@ -746,9 +728,7 @@ fn inet_pton(address_family: Int, address: String) -> Int: ip_buf_size = 16 var ip_buf = UnsafePointer[c_void].alloc(ip_buf_size) - var conv_status = inet_pton( - rebind[c_int](address_family), to_char_ptr(address), ip_buf - ) + var conv_status = inet_pton(rebind[c_int](address_family), to_char_ptr(address), ip_buf) return int(ip_buf.bitcast[c_uint]()) @@ -773,9 +753,7 @@ fn close(fildes: c_int) -> c_int: return external_call["close", c_int, c_int](fildes) -fn open[ - *T: AnyType -](path: UnsafePointer[c_char], oflag: c_int, *args: *T) -> c_int: +fn open[*T: AnyType](path: UnsafePointer[c_char], oflag: c_int, *args: *T) -> c_int: """Libc POSIX `open` function Reference: https://man7.org/linux/man-pages/man3/open.3p.html Fn signature: int open(const char *path, int oflag, ...). @@ -815,9 +793,7 @@ fn read(fildes: c_int, buf: UnsafePointer[c_void], nbyte: c_size_t) -> c_int: nbyte: The number of bytes to read. Returns: The number of bytes read or -1 in case of failure. """ - return external_call[ - "read", c_ssize_t, c_int, UnsafePointer[c_void], c_size_t - ](fildes, buf, nbyte) + return external_call["read", c_ssize_t, c_int, UnsafePointer[c_void], c_size_t](fildes, buf, nbyte) fn write(fildes: c_int, buf: UnsafePointer[c_void], nbyte: c_size_t) -> c_int: @@ -830,9 +806,7 @@ fn write(fildes: c_int, buf: UnsafePointer[c_void], nbyte: c_size_t) -> c_int: nbyte: The number of bytes to write. Returns: The number of bytes written or -1 in case of failure. """ - return external_call[ - "write", c_ssize_t, c_int, UnsafePointer[c_void], c_size_t - ](fildes, buf, nbyte) + return external_call["write", c_ssize_t, c_int, UnsafePointer[c_void], c_size_t](fildes, buf, nbyte) fn __test_getaddrinfo__(): @@ -854,8 +828,8 @@ fn __test_getaddrinfo__(): UnsafePointer.address_of(servinfo), ) var msg_ptr = gai_strerror(c_int(status)) - _ = external_call[ - "printf", c_int, UnsafePointer[c_char], UnsafePointer[c_char] - ](to_char_ptr("gai_strerror: %s"), msg_ptr) + _ = external_call["printf", c_int, UnsafePointer[c_char], UnsafePointer[c_char]]( + to_char_ptr("gai_strerror: %s"), msg_ptr + ) var msg = c_charptr_to_string(msg_ptr) print("getaddrinfo satus: " + msg) diff --git a/lightbug_http/net.mojo b/lightbug_http/net.mojo index 48beea6f..bc18b716 100644 --- a/lightbug_http/net.mojo +++ b/lightbug_http/net.mojo @@ -122,9 +122,7 @@ struct TCPAddr(Addr): fn string(self) -> String: if self.zone != "": - return join_host_port( - self.ip + "%" + self.zone, self.port.__str__() - ) + return join_host_port(self.ip + "%" + self.zone, self.port.__str__()) return join_host_port(self.ip, self.port.__str__()) @@ -145,11 +143,7 @@ fn resolve_internet_addr(network: String, address: String) raises -> TCPAddr: host = host_port.host port = host_port.port portnum = atol(port.__str__()) - elif ( - network == NetworkType.ip.value - or network == NetworkType.ip4.value - or network == NetworkType.ip6.value - ): + elif network == NetworkType.ip.value or network == NetworkType.ip4.value or network == NetworkType.ip6.value: if address != "": host = address elif network == NetworkType.unix.value: @@ -223,9 +217,7 @@ fn convert_binary_port_to_int(port: UInt16) -> Int: return int(ntohs(port)) -fn convert_binary_ip_to_string( - owned ip_address: UInt32, address_family: Int32, address_length: UInt32 -) -> String: +fn convert_binary_ip_to_string(owned ip_address: UInt32, address_family: Int32, address_length: UInt32) -> String: """Convert a binary IP address to a string by calling inet_ntop. Args: diff --git a/lightbug_http/server.mojo b/lightbug_http/server.mojo index 8f50c097..2d17ca57 100644 --- a/lightbug_http/server.mojo +++ b/lightbug_http/server.mojo @@ -17,9 +17,7 @@ trait ServerTrait: fn get_concurrency(self) -> Int: ... - fn listen_and_serve( - self, address: String, handler: HTTPService - ) raises -> None: + fn listen_and_serve(self, address: String, handler: HTTPService) raises -> None: ... fn serve(self, ln: Listener, handler: HTTPService) raises -> None: diff --git a/lightbug_http/service.mojo b/lightbug_http/service.mojo index e7c5eec4..5df31ccc 100644 --- a/lightbug_http/service.mojo +++ b/lightbug_http/service.mojo @@ -18,9 +18,7 @@ struct Printer(HTTPService): var header = req.headers print("Request protocol: ", req.protocol) print("Request method: ", req.method) - print( - "Request Content-Type: ", to_string(header[HeaderKey.CONTENT_TYPE]) - ) + print("Request Content-Type: ", to_string(header[HeaderKey.CONTENT_TYPE])) var body = req.body_raw print("Request Body: ", to_string(body)) diff --git a/lightbug_http/sys/client.mojo b/lightbug_http/sys/client.mojo index 9d30a931..f45dd0eb 100644 --- a/lightbug_http/sys/client.mojo +++ b/lightbug_http/sys/client.mojo @@ -12,26 +12,23 @@ from lightbug_http.strings import to_string from lightbug_http.client import Client from lightbug_http.net import default_buffer_size from lightbug_http.http import HTTPRequest, HTTPResponse, encode -from lightbug_http.header import Headers +from lightbug_http.header import Headers, HeaderKey from lightbug_http.sys.net import create_connection from lightbug_http.io.bytes import Bytes from lightbug_http.utils import ByteReader struct MojoClient(Client): - var fd: c_int var host: StringLiteral var port: Int var name: String fn __init__(inout self) raises: - self.fd = socket(AF_INET, SOCK_STREAM, 0) self.host = "127.0.0.1" self.port = 8888 self.name = "lightbug_http_client" fn __init__(inout self, host: StringLiteral, port: Int) raises: - self.fd = socket(AF_INET, SOCK_STREAM, 0) self.host = host self.port = port self.name = "lightbug_http_client" @@ -87,8 +84,10 @@ struct MojoClient(Client): port = 443 else: port = 80 - var conn = create_connection(self.fd, host_str, port) - var bytes_sent = conn.write(encode(req^)) + + # TODO: Actually handle persistent connections + var conn = create_connection(socket(AF_INET, SOCK_STREAM, 0), host_str, port) + var bytes_sent = conn.write(encode(req)) if bytes_sent == -1: raise Error("Failed to send message") @@ -97,11 +96,28 @@ struct MojoClient(Client): if bytes_recv == 0: conn.close() - try: - return HTTPResponse.from_bytes(new_buf^) + var res = HTTPResponse.from_bytes(new_buf^) + if res.is_redirect(): + conn.close() + return self._handle_redirect(req^, res^) + return res except e: conn.close() raise e return HTTPResponse(Bytes()) + + fn _handle_redirect( + self, owned original_req: HTTPRequest, owned original_response: HTTPResponse + ) raises -> HTTPResponse: + var new_uri: URI + var new_location = original_response.headers[HeaderKey.LOCATION] + if new_location.startswith("http"): + new_uri = URI.parse_raises(new_location) + original_req.headers[HeaderKey.HOST] = new_uri.host + else: + new_uri = original_req.uri + new_uri.path = new_location + original_req.uri = new_uri + return self.do(original_req^) diff --git a/lightbug_http/sys/net.mojo b/lightbug_http/sys/net.mojo index b3fc03da..5413b577 100644 --- a/lightbug_http/sys/net.mojo +++ b/lightbug_http/sys/net.mojo @@ -111,23 +111,16 @@ struct SysListener: var their_addr_ptr = Reference[sockaddr](their_addr) var sin_size = socklen_t(sizeof[socklen_t]()) var sin_size_ptr = Reference[socklen_t](sin_size) - var new_sockfd = external_call["accept", c_int]( - self.fd, their_addr_ptr, sin_size_ptr - ) + var new_sockfd = external_call["accept", c_int](self.fd, their_addr_ptr, sin_size_ptr) # var new_sockfd = accept( # self.fd, their_addr_ptr, UnsafePointer[socklen_t].address_of(sin_size) # ) if new_sockfd == -1: - print( - "Failed to accept connection, system accept() returned an" - " error." - ) + print("Failed to accept connection, system accept() returned an error.") var peer = get_peer_name(new_sockfd) - return SysConnection( - self.__addr, TCPAddr(peer.host, atol(peer.port)), new_sockfd - ) + return SysConnection(self.__addr, TCPAddr(peer.host, atol(peer.port)), new_sockfd) fn close(self) raises: _ = shutdown(self.fd, SHUT_RDWR) @@ -148,9 +141,7 @@ struct SysListenConfig(ListenConfig): fn __init__(inout self, keep_alive: Duration) raises: self.__keep_alive = keep_alive - fn listen( - inout self, network: String, address: String - ) raises -> SysListener: + fn listen(inout self, network: String, address: String) raises -> SysListener: var addr = resolve_internet_addr(network, address) var address_family = AF_INET var ip_buf_size = 4 @@ -174,30 +165,21 @@ struct SysListenConfig(ListenConfig): var bind_fail_logged = False var ip_buf = UnsafePointer[c_void].alloc(ip_buf_size) - var conv_status = inet_pton( - address_family, to_char_ptr(addr.ip), ip_buf - ) + var conv_status = inet_pton(address_family, to_char_ptr(addr.ip), ip_buf) var raw_ip = ip_buf.bitcast[c_uint]()[] var bin_port = htons(UInt16(addr.port)) - var ai = sockaddr_in( - address_family, bin_port, raw_ip, StaticTuple[c_char, 8]() - ) + var ai = sockaddr_in(address_family, bin_port, raw_ip, StaticTuple[c_char, 8]()) var ai_ptr = Reference[sockaddr_in](ai) while not bind_success: # var bind = bind(sockfd, ai_ptr, sizeof[sockaddr_in]()) - var bind = external_call["bind", c_int]( - sockfd, ai_ptr, sizeof[sockaddr_in]() - ) + var bind = external_call["bind", c_int](sockfd, ai_ptr, sizeof[sockaddr_in]()) if bind == 0: bind_success = True else: if not bind_fail_logged: - print( - "Bind attempt failed. The address might be in use or" - " the socket might not be available." - ) + print("Bind attempt failed. The address might be in use or the socket might not be available.") print("Retrying. Might take 10-15 seconds.") bind_fail_logged = True print(".", end="", flush=True) @@ -209,13 +191,7 @@ struct SysListenConfig(ListenConfig): var listener = SysListener(addr, sockfd) - print( - "\nšŸ”„šŸ Lightbug is listening on " - + "http://" - + addr.ip - + ":" - + addr.port.__str__() - ) + print("\nšŸ”„šŸ Lightbug is listening on " + "http://" + addr.ip + ":" + addr.port.__str__()) print("Ready to accept connections...") return listener @@ -359,10 +335,7 @@ struct addrinfo_macos(AnAddrInfo): var ai_addr = addrinfo.ai_addr if not ai_addr: print("ai_addr is null") - raise Error( - "Failed to get IP address. getaddrinfo was called successfully," - " but ai_addr is null." - ) + raise Error("Failed to get IP address. getaddrinfo was called successfully, but ai_addr is null.") var addr_in = ai_addr.bitcast[sockaddr_in]()[] @@ -430,19 +403,14 @@ struct addrinfo_unix(AnAddrInfo): var ai_addr = addrinfo.ai_addr if not ai_addr: print("ai_addr is null") - raise Error( - "Failed to get IP address. getaddrinfo was called successfully," - " but ai_addr is null." - ) + raise Error("Failed to get IP address. getaddrinfo was called successfully, but ai_addr is null.") var addr_in = ai_addr.bitcast[sockaddr_in]()[] return addr_in.sin_addr -fn create_connection( - sock: c_int, host: String, port: UInt16 -) raises -> SysConnection: +fn create_connection(sock: c_int, host: String, port: UInt16) raises -> SysConnection: """ Connect to a server using a socket. @@ -461,15 +429,10 @@ fn create_connection( ip = addrinfo_unix().get_ip_address(host) # Convert ip address to network byte order. - var addr: sockaddr_in = sockaddr_in( - AF_INET, htons(port), ip, StaticTuple[c_char, 8](0, 0, 0, 0, 0, 0, 0, 0) - ) + var addr: sockaddr_in = sockaddr_in(AF_INET, htons(port), ip, StaticTuple[c_char, 8](0, 0, 0, 0, 0, 0, 0, 0)) var addr_ptr = Reference[sockaddr_in](addr) - if ( - external_call["connect", c_int](sock, addr_ptr, sizeof[sockaddr_in]()) - == -1 - ): + if external_call["connect", c_int](sock, addr_ptr, sizeof[sockaddr_in]()) == -1: _ = shutdown(sock, SHUT_RDWR) raise Error("Failed to connect to server") diff --git a/lightbug_http/sys/server.mojo b/lightbug_http/sys/server.mojo index 08d78ea8..d550fd5f 100644 --- a/lightbug_http/sys/server.mojo +++ b/lightbug_http/sys/server.mojo @@ -82,9 +82,7 @@ struct SysServer: self.tcp_keep_alive = False self.ln = SysListener() - fn __init__( - inout self, max_request_body_size: Int, tcp_keep_alive: Bool - ) raises: + fn __init__(inout self, max_request_body_size: Int, tcp_keep_alive: Bool) raises: self.error_handler = ErrorHandler() self.name = "lightbug_http" self.__address = "127.0.0.1" @@ -121,9 +119,7 @@ struct SysServer: concurrency = DefaultConcurrency return concurrency - fn listen_and_serve[ - T: HTTPService - ](inout self, address: String, handler: T) raises -> None: + fn listen_and_serve[T: HTTPService](inout self, address: String, handler: T) raises -> None: """ Listen for incoming connections and serve HTTP requests. @@ -136,9 +132,7 @@ struct SysServer: _ = self.set_address(address) self.serve(listener, handler) - fn serve[ - T: HTTPService - ](inout self, ln: SysListener, handler: T) raises -> None: + fn serve[T: HTTPService](inout self, ln: SysListener, handler: T) raises -> None: """ Serve HTTP requests. @@ -155,9 +149,7 @@ struct SysServer: var conn = self.ln.accept() self.serve_connection(conn, handler) - fn serve_connection[ - T: HTTPService - ](inout self, conn: SysConnection, handler: T) raises -> None: + fn serve_connection[T: HTTPService](inout self, conn: SysConnection, handler: T) raises -> None: """ Serve a single connection. @@ -189,9 +181,7 @@ struct SysServer: conn.close() break - var request = HTTPRequest.from_bytes( - self.address(), max_request_body_size, b^ - ) + var request = HTTPRequest.from_bytes(self.address(), max_request_body_size, b^) var res = handler.func(request) diff --git a/lightbug_http/uri.mojo b/lightbug_http/uri.mojo index 51b09c22..d188e2f2 100644 --- a/lightbug_http/uri.mojo +++ b/lightbug_http/uri.mojo @@ -12,7 +12,7 @@ from lightbug_http.strings import ( @value -struct URI: +struct URI(Formattable, Stringable): var __path_original: String var scheme: String var path: String @@ -35,7 +35,7 @@ struct URI: return "Failed to parse URI: " + str(e) return u - + @staticmethod fn parse_raises(uri: String) raises -> URI: var u = URI(uri) @@ -47,7 +47,7 @@ struct URI: uri: String = "", ) -> None: self.__path_original = "/" - self.scheme = + self.scheme = "" self.path = "/" self.query_string = "" self.__hash = "" @@ -57,6 +57,15 @@ struct URI: self.username = "" self.password = "" + fn __str__(self) -> String: + var s = self.scheme + "://" + self.host + self.path + if len(self.query_string) > 0: + s += "?" + self.query_string + return s + + fn format_to(self, inout writer: Formatter): + writer.write(str(self)) + fn is_https(self) -> Bool: return self.scheme == https @@ -74,12 +83,12 @@ struct URI: proto_str = raw_uri[:proto_end] if proto_str == https: is_https = True - remainder_uri = raw_uri[proto_end + 3:] + remainder_uri = raw_uri[proto_end + 3 :] else: remainder_uri = raw_uri - + self.scheme = proto_str^ - + var path_start = remainder_uri.find("/") var host_and_port: String var request_uri: String @@ -96,7 +105,7 @@ struct URI: self.scheme = https else: self.scheme = http - + var n = request_uri.find("?") if n >= 0: self.__path_original = request_uri[:n] diff --git a/lightbug_http/utils.mojo b/lightbug_http/utils.mojo index 8ca06387..7055edc3 100644 --- a/lightbug_http/utils.mojo +++ b/lightbug_http/utils.mojo @@ -90,9 +90,15 @@ struct ByteReader: self.read_pos += v @always_inline - fn consume(inout self, inout buffer: Bytes): + fn consume(inout self, inout buffer: Bytes, bytes_len: Int = -1): var pos = self.read_pos - self.read_pos = -1 - var read_len = len(self._inner) - pos + var read_len: Int + if bytes_len == -1: + self.read_pos = -1 + read_len = len(self._inner) - pos + else: + self.read_pos += bytes_len + read_len = bytes_len + buffer.resize(read_len, 0) memcpy(buffer.data, self._inner.data + pos, read_len) diff --git a/mojoproject.toml b/mojoproject.toml index c635214c..6d9b887d 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -14,10 +14,11 @@ repository = "https://github.com/saviorand/lightbug_http" build = { cmd = "bash scripts/build.sh nightly", env = {MODULAR_MOJO_IMPORT_PATH = "$CONDA_PREFIX/lib/mojo"} } publish = { cmd = "bash scripts/publish.sh mojo-community-nightly", env = { PREFIX_API_KEY = "$PREFIX_API_KEY" } } bp = { depends_on=["build", "publish"] } -test = { cmd = "magic run mojo run_tests.mojo" } template = { cmd = "magic run python scripts/templater.py" } +test = { cmd = "magic run mojo test -I . tests" } bench = { cmd = "magic run mojo bench.mojo" } bench_server = { cmd = "magic run mojo build bench_server.mojo && ./bench_server ; rm bench_server" } +format = { cmd = "magic run mojo format -l 120 lightbug_http" } [dependencies] max = ">=24.5.0,<25" diff --git a/run_tests.mojo b/run_tests.mojo deleted file mode 100644 index ccfd1244..00000000 --- a/run_tests.mojo +++ /dev/null @@ -1,13 +0,0 @@ -from tests.test_io import test_io -from tests.test_http import test_http -from tests.test_header import test_header -from tests.test_uri import test_uri -from tests.test_client import test_client - - -fn main() raises: - test_io() - test_http() - test_header() - test_uri() - test_client() diff --git a/tests/__init__.mojo b/tests/__init__.mojo deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/test_io.mojo b/tests/lightbug_http/io/test_bytes.mojo similarity index 96% rename from tests/test_io.mojo rename to tests/lightbug_http/io/test_bytes.mojo index 9520659a..79748733 100644 --- a/tests/test_io.mojo +++ b/tests/lightbug_http/io/test_bytes.mojo @@ -3,10 +3,6 @@ from collections import Dict, List from lightbug_http.io.bytes import Bytes, bytes_equal, bytes -def test_io(): - test_string_literal_to_bytes() - - fn test_string_literal_to_bytes() raises: var cases = Dict[StringLiteral, Bytes]() cases[""] = Bytes() diff --git a/tests/lightbug_http/sys/test_client.mojo b/tests/lightbug_http/sys/test_client.mojo new file mode 100644 index 00000000..2631af61 --- /dev/null +++ b/tests/lightbug_http/sys/test_client.mojo @@ -0,0 +1,93 @@ +import testing +from lightbug_http.sys.client import MojoClient +from lightbug_http.http import HTTPRequest, encode +from lightbug_http.uri import URI +from lightbug_http.header import Header, Headers +from lightbug_http.io.bytes import bytes + + +fn test_mojo_client_redirect_external_req_google() raises: + var client = MojoClient() + var req = HTTPRequest( + uri=URI.parse_raises("http://google.com"), + headers=Headers( + Header("Connection", "close")), + method="GET", + ) + try: + var res = client.do(req) + testing.assert_equal(res.status_code, 200) + except e: + print(e) + +fn test_mojo_client_redirect_external_req_302() raises: + var client = MojoClient() + var req = HTTPRequest( + uri=URI.parse_raises("http://httpbin.org/status/302"), + headers=Headers( + Header("Connection", "close")), + method="GET", + ) + try: + var res = client.do(req) + testing.assert_equal(res.status_code, 200) + except e: + print(e) + +fn test_mojo_client_redirect_external_req_308() raises: + var client = MojoClient() + var req = HTTPRequest( + uri=URI.parse_raises("http://httpbin.org/status/308"), + headers=Headers( + Header("Connection", "close")), + method="GET", + ) + try: + var res = client.do(req) + testing.assert_equal(res.status_code, 200) + except e: + print(e) + +fn test_mojo_client_redirect_external_req_307() raises: + var client = MojoClient() + var req = HTTPRequest( + uri=URI.parse_raises("http://httpbin.org/status/307"), + headers=Headers( + Header("Connection", "close")), + method="GET", + ) + try: + var res = client.do(req) + testing.assert_equal(res.status_code, 200) + except e: + print(e) + +fn test_mojo_client_redirect_external_req_301() raises: + var client = MojoClient() + var req = HTTPRequest( + uri=URI.parse_raises("http://httpbin.org/status/301"), + headers=Headers( + Header("Connection", "close")), + method="GET", + ) + try: + var res = client.do(req) + testing.assert_equal(res.status_code, 200) + testing.assert_equal(res.headers.content_length(), 228) + except e: + print(e) + +fn test_mojo_client_lightbug_external_req_200() raises: + var client = MojoClient() + var req = HTTPRequest( + uri=URI.parse_raises("http://httpbin.org/status/200"), + headers=Headers( + Header("Connection", "close")), + method="GET", + ) + + try: + var res = client.do(req) + testing.assert_equal(res.status_code, 200) + except e: + print(e) diff --git a/tests/test_header.mojo b/tests/lightbug_http/test_header.mojo similarity index 91% rename from tests/test_header.mojo rename to tests/lightbug_http/test_header.mojo index eed820f8..438b2c28 100644 --- a/tests/test_header.mojo +++ b/tests/lightbug_http/test_header.mojo @@ -2,15 +2,8 @@ from testing import assert_equal, assert_true from lightbug_http.utils import ByteReader from lightbug_http.header import Headers, Header from lightbug_http.io.bytes import Bytes, bytes -from lightbug_http.strings import empty_string -from lightbug_http.net import default_buffer_size -def test_header(): - test_parse_request_header() - test_parse_response_header() - test_header_case_insensitive() - def test_header_case_insensitive(): var headers = Headers(Header("Host", "SomeHost")) diff --git a/tests/test_http.mojo b/tests/lightbug_http/test_http.mojo similarity index 78% rename from tests/test_http.mojo rename to tests/lightbug_http/test_http.mojo index c7f40d5d..2a9a1cb3 100644 --- a/tests/test_http.mojo +++ b/tests/lightbug_http/test_http.mojo @@ -5,16 +5,13 @@ from lightbug_http.http import HTTPRequest, HTTPResponse, encode from lightbug_http.header import Header, Headers, HeaderKey from lightbug_http.uri import URI from lightbug_http.strings import to_string -from tests.utils import default_server_conn_string +alias default_server_conn_string = "http://localhost:8080" -def test_http(): - test_encode_http_request() - test_encode_http_response() def test_encode_http_request(): - var uri = URI(default_server_conn_string + "/foobar?baz") + var uri = URI.parse_raises(default_server_conn_string + "/foobar?baz") var req = HTTPRequest( uri, body=String("Hello world!").as_bytes(), @@ -23,12 +20,16 @@ def test_encode_http_request(): var as_str = str(req) var req_encoded = to_string(encode(req^)) + + + var expected = String( + "GET /foobar?baz HTTP/1.1\r\nconnection: keep-alive\r\ncontent-length:" + " 12\r\nhost: localhost:8080\r\n\r\nHello world!" + ) + testing.assert_equal( req_encoded, - ( - "GET / HTTP/1.1\r\nconnection: keep-alive\r\ncontent-length:" - " 12\r\n\r\nHello world!" - ), + expected ) testing.assert_equal(req_encoded, as_str) diff --git a/tests/test_net.mojo b/tests/lightbug_http/test_net.mojo similarity index 100% rename from tests/test_net.mojo rename to tests/lightbug_http/test_net.mojo diff --git a/tests/test_uri.mojo b/tests/lightbug_http/test_uri.mojo similarity index 90% rename from tests/test_uri.mojo rename to tests/lightbug_http/test_uri.mojo index 1c4a8a14..885234b8 100644 --- a/tests/test_uri.mojo +++ b/tests/lightbug_http/test_uri.mojo @@ -5,17 +5,6 @@ from lightbug_http.strings import empty_string, to_string from lightbug_http.io.bytes import Bytes -def test_uri(): - test_uri_no_parse_defaults() - test_uri_parse_http_with_port() - test_uri_parse_https_with_port() - test_uri_parse_http_with_path() - test_uri_parse_https_with_path() - test_uri_parse_http_basic() - test_uri_parse_http_basic_www() - test_uri_parse_http_with_query_string() - test_uri_parse_http_with_hash() - def test_uri_no_parse_defaults(): var uri = URI.parse("http://example.com")[URI] diff --git a/tests/test_client.mojo b/tests/test_client.mojo deleted file mode 100644 index 5ca9717f..00000000 --- a/tests/test_client.mojo +++ /dev/null @@ -1,31 +0,0 @@ -import testing -from tests.utils import ( - default_server_conn_string, -) -from lightbug_http.sys.client import MojoClient -from lightbug_http.http import HTTPRequest, encode -from lightbug_http.uri import URI -from lightbug_http.header import Header, Headers -from lightbug_http.io.bytes import bytes - - -def test_client(): - var mojo_client = MojoClient() - test_mojo_client_lightbug_external_req(mojo_client) - - -fn test_mojo_client_lightbug_external_req(client: MojoClient) raises: - var req = HTTPRequest( - uri=URI.parse("http://httpbin.org/status/200")[URI], - headers=Headers( - Header("Connection", "keep-alive"), - Header("Host", "httpbin.org")), - method="GET", - ) - - try: - var res = client.do(req) - testing.assert_equal(res.status_code, 200) - - except e: - print(e) diff --git a/testutils/__init__.mojo b/testutils/__init__.mojo new file mode 100644 index 00000000..90f60fdd --- /dev/null +++ b/testutils/__init__.mojo @@ -0,0 +1 @@ +from .utils import * \ No newline at end of file diff --git a/tests/utils.mojo b/testutils/utils.mojo similarity index 98% rename from tests/utils.mojo rename to testutils/utils.mojo index 8cc4bd2e..c2131e4b 100644 --- a/tests/utils.mojo +++ b/testutils/utils.mojo @@ -71,7 +71,7 @@ struct FakeClient(Client): self.req_host = "" self.req_is_tls = False - fn do(self, req: HTTPRequest) raises -> HTTPResponse: + fn do(self, owned req: HTTPRequest) raises -> HTTPResponse: return OK(String(defaultExpectedGetResponse)) fn extract(inout self, req: HTTPRequest) raises -> ReqInfo: