Skip to content

Commit

Permalink
use ByteView for improved performance
Browse files Browse the repository at this point in the history
  • Loading branch information
saviorand committed Feb 2, 2025
1 parent 4b30bcd commit cdf3b06
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 52 deletions.
42 changes: 21 additions & 21 deletions lightbug_http/address.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -383,19 +383,19 @@ fn is_ipv6(network: NetworkType) -> Bool:
"""Check if the network type is IPv6."""
return network in (NetworkType.tcp6, NetworkType.udp6, NetworkType.ip6)

fn resolve_localhost(host: String, network: NetworkType) -> String:
fn resolve_localhost(host: ByteView[StaticConstantOrigin], network: NetworkType) -> ByteView[StaticConstantOrigin]:
"""Resolve localhost to the appropriate IP address based on network type."""
if host != AddressConstants.LOCALHOST:
if host != AddressConstants.LOCALHOST.as_bytes():
return host

if network.is_ipv4():
return AddressConstants.IPV4_LOCALHOST
return AddressConstants.IPV4_LOCALHOST.as_bytes()
elif network.is_ipv6():
return AddressConstants.IPV6_LOCALHOST
return AddressConstants.IPV6_LOCALHOST.as_bytes()

return host

fn parse_ipv6_bracketed_address(address: ByteView[ImmutableAnyOrigin]) raises -> (ByteView[ImmutableAnyOrigin], UInt16):
fn parse_ipv6_bracketed_address(address: ByteView[StaticConstantOrigin]) raises -> (ByteView[StaticConstantOrigin], UInt16):
"""Parse an IPv6 address enclosed in brackets.
Returns:
Expand All @@ -420,32 +420,32 @@ fn parse_ipv6_bracketed_address(address: ByteView[ImmutableAnyOrigin]) raises ->
UInt16(end_bracket_index + 1)
)

fn validate_no_brackets(address: String, start_idx: UInt16, end_idx: Optional[UInt16] = None) raises:
fn validate_no_brackets(address: ByteView[StaticConstantOrigin], start_idx: UInt16, end_idx: Optional[UInt16] = None) raises:
"""Validate that the address segment contains no brackets."""
var segment: String
var segment: ByteView[StaticConstantOrigin]

if end_idx is None:
segment = address[int(start_idx):]
else:
segment = address[int(start_idx):int(end_idx.value())]

if segment.find("[") != -1:
if segment.find(Byte(ord("["))) != -1:
raise Error("unexpected '[' in address")
if segment.find("]") != -1:
if segment.find(Byte(ord("]"))) != -1:
raise Error("unexpected ']' in address")

fn parse_port(port_str: String) raises -> UInt16:
fn parse_port(port_str: ByteView[StaticConstantOrigin]) raises -> UInt16:
"""Parse and validate port number."""
if port_str == AddressConstants.EMPTY:
if port_str == AddressConstants.EMPTY.as_bytes():
raise MissingPortError

var port = int(port_str)
var port = int(str(port_str))
if port < MIN_PORT or port > MAX_PORT:
raise Error("Port number out of range (0-65535)")

return UInt16(port)

fn parse_address(network: NetworkType, address: String) raises -> (String, UInt16):
fn parse_address(network: NetworkType, address: ByteView[StaticConstantOrigin]) raises -> (ByteView[StaticConstantOrigin], UInt16):
"""Parse an address string into a host and port.
Args:
Expand All @@ -457,28 +457,28 @@ fn parse_address(network: NetworkType, address: String) raises -> (String, UInt1
"""
if network.is_ip_protocol():
var host = resolve_localhost(address, network)
if host == AddressConstants.EMPTY:
if host == AddressConstants.EMPTY.as_bytes():
raise Error("missing host")

# For IPv6 addresses in IP protocol mode, we need to handle the address as-is
if network == NetworkType.ip6 and host.find(":") != -1:
if network == NetworkType.ip6 and host.find(Byte(ord(":"))) != -1:
return host, DEFAULT_IP_PORT

# For other IP protocols, no colons allowed
if host.find(":") != -1:
if host.find(Byte(ord(":"))) != -1:
raise Error("IP protocol addresses should not include ports")

return host, DEFAULT_IP_PORT

var colon_index = address.rfind(":")
var colon_index = address.rfind(Byte(ord(":")))
if colon_index == -1:
raise MissingPortError

var host: String
var host: ByteView[StaticConstantOrigin]
var bracket_offset: UInt16 = 0

# Handle IPv6 addresses
if address[0] == "[":
if address[0] == Byte(ord("[")):
try:
(host, bracket_offset) = parse_ipv6_bracketed_address(address)
except e:
Expand All @@ -488,13 +488,13 @@ fn parse_address(network: NetworkType, address: String) raises -> (String, UInt1
else:
# For IPv4, simply split at the last colon
host = address[:colon_index]
if host.find(":") != -1:
if host.find(Byte(ord(":"))) != -1:
raise TooManyColonsError

var port = parse_port(address[colon_index + 1:])

host = resolve_localhost(host, network)
if host == AddressConstants.EMPTY:
if host == AddressConstants.EMPTY.as_bytes():
raise Error("missing host")

return host, port
Expand Down
20 changes: 19 additions & 1 deletion lightbug_http/io/bytes.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ struct ByteView[origin: Origin]():

fn __iter__(self) -> _SpanIter[Byte, origin]:
return self._inner.__iter__()

fn find(self, target: Byte) -> Int:
"""Finds the index of a byte in a byte span.
Expand All @@ -162,6 +162,24 @@ struct ByteView[origin: Origin]():

return -1

fn rfind(self, target: Byte) -> Int:
"""Finds the index of the last occurrence of a byte in a byte span.
Args:
target: The byte to find.
Returns:
The index of the last occurrence of the byte in the span, or -1 if not found.
"""
# Start from the end and work backwards
var i = len(self) - 1
while i >= 0:
if self[i] == target:
return i
i -= 1

return -1

fn to_bytes(self) -> Bytes:
return Bytes(self._inner)

Expand Down
3 changes: 2 additions & 1 deletion tests/lightbug_http/io/test_bytes.mojo
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import testing
from collections import Dict, List
from lightbug_http.io.bytes import Bytes, bytes
from lightbug_http.io.bytes import Bytes, ByteView, bytes


fn test_string_literal_to_bytes() raises:
Expand Down Expand Up @@ -35,3 +35,4 @@ fn test_string_to_bytes() raises:

for c in cases.items():
testing.assert_equal(Bytes(c[].key.as_bytes()), c[].value)

58 changes: 29 additions & 29 deletions tests/lightbug_http/test_host_port.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -4,63 +4,63 @@ from lightbug_http.address import TCPAddr, NetworkType, join_host_port, parse_ad

def test_split_host_port():
# TCP4
var hp = parse_address(NetworkType.tcp4, "127.0.0.1:8080")
testing.assert_equal(hp[0], "127.0.0.1")
var hp = parse_address(NetworkType.tcp4, "127.0.0.1:8080".as_bytes())
testing.assert_equal(hp[0], "127.0.0.1".as_bytes())
testing.assert_equal(hp[1], 8080)

# TCP4 with localhost
hp = parse_address(NetworkType.tcp4, "localhost:8080")
testing.assert_equal(hp[0], "127.0.0.1")
hp = parse_address(NetworkType.tcp4, "localhost:8080".as_bytes())
testing.assert_equal(hp[0], "127.0.0.1".as_bytes())
testing.assert_equal(hp[1], 8080)

# TCP6
hp = parse_address(NetworkType.tcp6, "[::1]:8080")
testing.assert_equal(hp[0], "::1")
hp = parse_address(NetworkType.tcp6, "[::1]:8080".as_bytes())
testing.assert_equal(hp[0], "::1".as_bytes())
testing.assert_equal(hp[1], 8080)

# TCP6 with localhost
hp = parse_address(NetworkType.tcp6, "localhost:8080")
testing.assert_equal(hp[0], "::1")
hp = parse_address(NetworkType.tcp6, "localhost:8080".as_bytes())
testing.assert_equal(hp[0], "::1".as_bytes())
testing.assert_equal(hp[1], 8080)

# UDP4
hp = parse_address(NetworkType.udp4, "192.168.1.1:53")
testing.assert_equal(hp[0], "192.168.1.1")
hp = parse_address(NetworkType.udp4, "192.168.1.1:53".as_bytes())
testing.assert_equal(hp[0], "192.168.1.1".as_bytes())
testing.assert_equal(hp[1], 53)

# UDP4 with localhost
hp = parse_address(NetworkType.udp4, "localhost:53")
testing.assert_equal(hp[0], "127.0.0.1")
hp = parse_address(NetworkType.udp4, "localhost:53".as_bytes())
testing.assert_equal(hp[0], "127.0.0.1".as_bytes())
testing.assert_equal(hp[1], 53)

# UDP6
hp = parse_address(NetworkType.udp6, "[2001:db8::1]:53")
testing.assert_equal(hp[0], "2001:db8::1")
hp = parse_address(NetworkType.udp6, "[2001:db8::1]:53".as_bytes())
testing.assert_equal(hp[0], "2001:db8::1".as_bytes())
testing.assert_equal(hp[1], 53)

# UDP6 with localhost
hp = parse_address(NetworkType.udp6, "localhost:53")
testing.assert_equal(hp[0], "::1")
hp = parse_address(NetworkType.udp6, "localhost:53".as_bytes())
testing.assert_equal(hp[0], "::1".as_bytes())
testing.assert_equal(hp[1], 53)

# IP4 (no port)
hp = parse_address(NetworkType.ip4, "192.168.1.1")
testing.assert_equal(hp[0], "192.168.1.1")
hp = parse_address(NetworkType.ip4, "192.168.1.1".as_bytes())
testing.assert_equal(hp[0], "192.168.1.1".as_bytes())
testing.assert_equal(hp[1], 0)

# IP4 with localhost
hp = parse_address(NetworkType.ip4, "localhost")
testing.assert_equal(hp[0], "127.0.0.1")
hp = parse_address(NetworkType.ip4, "localhost".as_bytes())
testing.assert_equal(hp[0], "127.0.0.1".as_bytes())
testing.assert_equal(hp[1], 0)

# IP6 (no port)
hp = parse_address(NetworkType.ip6, "2001:db8::1")
testing.assert_equal(hp[0], "2001:db8::1")
hp = parse_address(NetworkType.ip6, "2001:db8::1".as_bytes())
testing.assert_equal(hp[0], "2001:db8::1".as_bytes())
testing.assert_equal(hp[1], 0)

# IP6 with localhost
hp = parse_address(NetworkType.ip6, "localhost")
testing.assert_equal(hp[0], "::1")
hp = parse_address(NetworkType.ip6, "localhost".as_bytes())
testing.assert_equal(hp[0], "::1".as_bytes())
testing.assert_equal(hp[1], 0)

# TODO: IPv6 long form - Not supported yet.
Expand All @@ -71,35 +71,35 @@ def test_split_host_port():
# Error cases
# IP protocol with port
try:
_ = parse_address(NetworkType.ip4, "192.168.1.1:80")
_ = parse_address(NetworkType.ip4, "192.168.1.1:80".as_bytes())
testing.assert_false("Should have raised an error for IP protocol with port")
except Error:
testing.assert_true(True)

# Missing port
try:
_ = parse_address(NetworkType.tcp4, "192.168.1.1")
_ = parse_address(NetworkType.tcp4, "192.168.1.1".as_bytes())
testing.assert_false("Should have raised MissingPortError")
except MissingPortError:
testing.assert_true(True)

# Missing port
try:
_ = parse_address(NetworkType.tcp6, "[::1]")
_ = parse_address(NetworkType.tcp6, "[::1]".as_bytes())
testing.assert_false("Should have raised MissingPortError")
except MissingPortError:
testing.assert_true(True)

# Port out of range
try:
_ = parse_address(NetworkType.tcp4, "192.168.1.1:70000")
_ = parse_address(NetworkType.tcp4, "192.168.1.1:70000".as_bytes())
testing.assert_false("Should have raised error for invalid port")
except Error:
testing.assert_true(True)

# Missing closing bracket
try:
_ = parse_address(NetworkType.tcp6, "[::1:8080")
_ = parse_address(NetworkType.tcp6, "[::1:8080".as_bytes())
testing.assert_false("Should have raised error for missing bracket")
except Error:
testing.assert_true(True)
Expand Down

0 comments on commit cdf3b06

Please sign in to comment.