Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propose Middleware Solution #33

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
37 changes: 34 additions & 3 deletions lightbug.🔥
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
from lightbug_http import *
from lightbug_http.middleware.helpers import Success
from sys import is_defined
from lightbug_http import *

@value
struct HelloWorld(HTTPHandler):
fn handle(self, context: Context) -> HTTPResponse:
var name = context.params.find("username")
if name:
return Success("Hello!")
else:
return Success("Hello, World!")

fn main() raises:
var server = SysServer()
var handler = Welcome()
server.listen_and_serve("0.0.0.0:8080", handler)
if not is_defined["TEST"]():
var router = RouterMiddleware()
router.add("GET", "/hello", HelloWorld())

var middleware = MiddlewareChain()
middleware.add(CompressionMiddleware())
middleware.add(ErrorMiddleware())
middleware.add(LoggerMiddleware())
middleware.add(CorsMiddleware(allows_origin = "*"))
middleware.add(BasicAuthMiddleware("admin", "password"))
middleware.add(StaticMiddleware("static"))
middleware.add(router)
middleware.add(NotFoundMiddleware())

var server = SysServer()
server.listen_and_serve("0.0.0.0:8080", middleware)
else:
try:
run_tests()
print("Test suite passed")
except e:
print("Test suite failed: " + e.__str__())
1 change: 1 addition & 0 deletions lightbug_http/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ from lightbug_http.http import HTTPRequest, HTTPResponse, OK
from lightbug_http.service import HTTPService, Welcome
from lightbug_http.sys.server import SysServer
from lightbug_http.tests.run import run_tests
from lightbug_http.middleware import *

trait DefaultConstructible:
fn __init__(inout self) raises:
Expand Down
172 changes: 172 additions & 0 deletions lightbug_http/middleware.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from lightbug_http.http import HTTPRequest, HTTPResponse

struct Context:
var request: Request
var params: Dict[String, AnyType]

fn __init__(self, request: Request):
self.request = request
self.params = Dict[String, AnyType]()

trait Middleware:
var next: Middleware

fn call(self, context: Context) -> Response:
...

struct ErrorMiddleware(Middleware):
fn call(self, context: Context) -> Response:
try:
return next.call(context: context)
catch e: Exception:
return InternalServerError()

struct LoggerMiddleware(Middleware):
fn call(self, context: Context) -> Response:
print("Request: \(context.request)")
return next.call(context: context)

struct StaticMiddleware(Middleware):
var path: String

fnt __init__(self, path: String):
self.path = path

fn call(self, context: Context) -> Response:
if context.request.path == "/":
var file = File(path: path + "index.html")
else:
var file = File(path: path + context.request.path)

if file.exists:
var html: String
with open(file, "r") as f:
html = f.read()
return OK(html.as_bytes(), "text/html")
else:
return next.call(context: context)

struct CorsMiddleware(Middleware):
var allow_origin: String

fn __init__(self, allow_origin: String):
self.allow_origin = allow_origin

fn call(self, context: Context) -> Response:
if context.request.method == "OPTIONS":
var response = next.call(context: context)
response.headers["Access-Control-Allow-Origin"] = allow_origin
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
return response

if context.request.origin == allow_origin:
return next.call(context: context)
else:
return Unauthorized()

struct CompressionMiddleware(Middleware):
fn call(self, context: Context) -> Response:
var response = next.call(context: context)
response.body = compress(response.body)
return response

fn compress(self, body: Bytes) -> Bytes:
#TODO: implement compression
return body


struct RouterMiddleware(Middleware):
var routes: Dict[String, Middleware]

fn __init__(self):
self.routes = Dict[String, Middleware]()

fn add(self, method: String, route: String, middleware: Middleware):
routes[method + ":" + route] = middleware

fn call(self, context: Context) -> Response:
# TODO: create a more advanced router
var method = context.request.method
var route = context.request.path
if middleware = routes[method + ":" + route]:
return middleware.call(context: context)
else:
return next.call(context: context)

struct BasicAuthMiddleware(Middleware):
var username: String
var password: String

fn __init__(self, username: String, password: String):
self.username = username
self.password = password

fn call(self, context: Context) -> Response:
var request = context.request
var auth = request.headers["Authorization"]
if auth == "Basic \(username):\(password)":
context.params["username"] = username
return next.call(context: context)
else:
return Unauthorized()

# always add at the end of the middleware chain
struct NotFoundMiddleware(Middleware):
fn call(self, context: Context) -> Response:
return NotFound()

struct MiddlewareChain(HttpService):
var middlewares: Array[Middleware]

fn __init__(self):
self.middlewares = Array[Middleware]()

fn add(self, middleware: Middleware):
if middlewares.count == 0:
middlewares.append(middleware)
else:
var last = middlewares[middlewares.count - 1]
last.next = middleware
middlewares.append(middleware)

fn func(self, request: Request) -> Response:
self.add(NotFoundMiddleware())
var context = Context(request: request, response: response)
return middlewares[0].call(context: context)

fn OK(body: Bytes) -> HTTPResponse:
return OK(body, String("text/plain"))

fn OK(body: Bytes, content_type: String) -> HTTPResponse:
return HTTPResponse(
ResponseHeader(True, 200, String("OK").as_bytes(), content_type.as_bytes()),
body,
)

fn NotFound(body: Bytes) -> HTTPResponse:
return NotFoundResponse(body, String("text/plain"))

fn NotFound(body: Bytes, content_type: String) -> HTTPResponse:
return HTTPResponse(
ResponseHeader(True, 404, String("Not Found").as_bytes(), content_type.as_bytes()),
body,
)

fn InternalServerError(body: Bytes) -> HTTPResponse:
return InternalServerErrorResponse(body, String("text/plain"))

fn InternalServerError(body: Bytes, content_type: String) -> HTTPResponse:
return HTTPResponse(
ResponseHeader(True, 500, String("Internal Server Error").as_bytes(), content_type.as_bytes()),
body,
)

fn Unauthorized(body: Bytes) -> HTTPResponse:
return UnauthorizedResponse(body, String("text/plain"))

fn Unauthorized(body: Bytes, content_type: String) -> HTTPResponse:
return HTTPResponse(
ResponseHeader(True, 401, String("Unauthorized").as_bytes(), content_type.as_bytes()),
body,
)
18 changes: 18 additions & 0 deletions lightbug_http/middleware/__init__.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from lightbug_http.middleware.helpers import Success
from lightbug_http.middleware.middleware import Context, Middleware, MiddlewareChain

from lightbug_http.middleware.basicauth import BasicAuthMiddleware
from lightbug_http.middleware.compression import CompressionMiddleware
from lightbug_http.middleware.cors import CorsMiddleware
from lightbug_http.middleware.error import ErrorMiddleware
from lightbug_http.middleware.logger import LoggerMiddleware
from lightbug_http.middleware.notfound import NotFoundMiddleware
from lightbug_http.middleware.router import RouterMiddleware, HTTPHandler
from lightbug_http.middleware.static import StaticMiddleware

# from lightbug_http.middleware.csrf import CsrfMiddleware
# from lightbug_http.middleware.session import SessionMiddleware
# from lightbug_http.middleware.websocket import WebSocketMiddleware
# from lightbug_http.middleware.cache import CacheMiddleware
# from lightbug_http.middleware.cookies import CookiesMiddleware
# from lightbug_http.middleware.session import SessionMiddleware
23 changes: 23 additions & 0 deletions lightbug_http/middleware/basicauth.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from lightbug_http.middleware.helpers import Unauthorized

## BasicAuth middleware requires basic authentication to access the route.
@value
struct BasicAuthMiddleware(Middleware):
var next: Middleware
var username: String
var password: String

fn __init__(inout self, username: String, password: String):
self.username = username
self.password = password

fn call(self, context: Context) -> HTTPResponse:
var request = context.request
#TODO: request object should have a way to get headers
# var auth = request.headers["Authorization"]
var auth = "Basic " + self.username + ":" + self.password
if auth == "Basic " + self.username + ":" + self.password:
context.params["username"] = username
return next.call(context)
else:
return Unauthorized("Requires Basic Authentication")
18 changes: 18 additions & 0 deletions lightbug_http/middleware/compression.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from lightbug_http.io.bytes import bytes

alias Bytes = List[Int8]

@value
struct CompressionMiddleware(Middleware):
var next: Middleware

fn call(self, context: Context) -> HTTPResponse:
var response = self.next.call(context)
response.body_raw = self.compress(response.body_raw)
return response

# TODO: implement compression
fn compress(self, body: String) -> Bytes:
#TODO: implement compression
return bytes(body)

28 changes: 28 additions & 0 deletions lightbug_http/middleware/cors.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from lightbug_http.middleware.helpers import Unauthorized
from lightbug_http.io.bytes import bytes, bytes_equal

## CORS middleware adds the necessary headers to allow cross-origin requests.
@value
struct CorsMiddleware(Middleware):
var next: Middleware
var allow_origin: String

fn __init__(inout self, allow_origin: String):
self.allow_origin = allow_origin

fn call(self, context: Context) -> HTTPResponse:
if bytes_equal(context.request.header.method(), bytes("OPTIONS")):
var response = self.next.call(context)
# TODO: implement headers
# response.headers["Access-Control-Allow-Origin"] = self.allow_origin
# response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
# response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
return response

# TODO: implement headers
# if context.request.headers["origin"] == self.allow_origin:
# return self.next.call(context)
# else:
# return Unauthorized("CORS not allowed")

return self.next.call(context)
15 changes: 15 additions & 0 deletions lightbug_http/middleware/error.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from lightbug_http.middleware.helpers import InternalServerError

## Error handler will catch any exceptions thrown by the other
## middleware and return a 500 response.
## It should be the first middleware in the chain.
@value
struct ErrorMiddleware(Middleware):
var next: Middleware

fn call(inout self, context: Context) -> HTTPResponse:
try:
return self.next.call(context)
except e:
return InternalServerError(e)

42 changes: 42 additions & 0 deletions lightbug_http/middleware/helpers.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from lightbug_http.http import HTTPRequest, HTTPResponse, ResponseHeader

### Helper functions to create HTTP responses
fn Success(body: String) -> HTTPResponse:
return Success(body, String("text/plain"))

fn Success(body: String, content_type: String) -> HTTPResponse:
return HTTPResponse(
ResponseHeader(True, 200, String("Success").as_bytes(), content_type.as_bytes()),
body.as_bytes(),
)

fn NotFound(body: String) -> HTTPResponse:
return NotFound(body, String("text/plain"))

fn NotFound(body: String, content_type: String) -> HTTPResponse:
return HTTPResponse(
ResponseHeader(True, 404, String("Not Found").as_bytes(), content_type.as_bytes()),
body.as_bytes(),
)

fn InternalServerError(body: String) -> HTTPResponse:
return InternalServerError(body, String("text/plain"))

fn InternalServerError(body: String, content_type: String) -> HTTPResponse:
return HTTPResponse(
ResponseHeader(True, 500, String("Internal Server Error").as_bytes(), content_type.as_bytes()),
body.as_bytes(),
)

fn Unauthorized(body: String) -> HTTPResponse:
return Unauthorized(body, String("text/plain"))

fn Unauthorized(body: String, content_type: String) -> HTTPResponse:
var header = ResponseHeader(True, 401, String("Unauthorized").as_bytes(), content_type.as_bytes())
# TODO: currently no way to set headers or cookies
# header.headers["WWW-Authenticate"] = "Basic realm=\"Login Required\""

return HTTPResponse(
header,
body.as_bytes(),
)
12 changes: 12 additions & 0 deletions lightbug_http/middleware/logger.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
## Logger middleware logs the request to the console.
@value
struct LoggerMiddleware(Middleware):
var next: Middleware

fn call(self, context: Context) -> HTTPResponse:
var request = context.request
#TODO: request is not printable
# print("Request: ", request)
var response = self.next.call(context)
print("Response:", response)
return response
Loading