From f5580b4e829812b2df766ac577e589ff783bbf40 Mon Sep 17 00:00:00 2001 From: Antonio Nuno Monteiro Date: Sat, 24 Aug 2024 17:29:05 -0700 Subject: [PATCH] http upgrade error handler (#71) --- async/httpun_ws_async.ml | 8 +++++--- async/httpun_ws_async.mli | 9 +++++---- eio/httpun_ws_eio.ml | 9 ++++++--- eio/httpun_ws_eio.mli | 5 +++-- examples/async/echo_server.ml | 4 ++-- examples/eio/echo_server.ml | 4 ++-- examples/lwt/echo_server.ml | 4 ++-- lib/handshake.ml | 10 +++++----- lib/httpun_ws.mli | 6 ++++-- lib/server_connection.ml | 35 ++++++++++++++++++++--------------- lwt/httpun_ws_lwt.ml | 8 +++++--- lwt/httpun_ws_lwt_intf.ml | 5 +++-- mirage/httpun_ws_mirage.ml | 20 ++++++++++++-------- mirage/httpun_ws_mirage.mli | 8 ++++---- 14 files changed, 78 insertions(+), 57 deletions(-) diff --git a/async/httpun_ws_async.ml b/async/httpun_ws_async.ml index 53f2d7ae..8cd21fc6 100644 --- a/async/httpun_ws_async.ml +++ b/async/httpun_ws_async.ml @@ -9,12 +9,14 @@ let sha1 s = module Server = struct let create_connection_handler ?(config = Httpun.Config.default) - ~websocket_handler - ~error_handler = fun client_addr socket -> + ?error_handler + ?websocket_error_handler + websocket_handler = fun client_addr socket -> let connection = Httpun_ws.Server_connection.create ~sha1 - ~error_handler:(error_handler client_addr) + ?error_handler:(Option.map ~f:(fun f -> f client_addr) error_handler) + ?websocket_error_handler:(Option.map ~f:(fun f -> f client_addr) websocket_error_handler) (websocket_handler client_addr) in Gluten_async.Server.create_connection_handler diff --git a/async/httpun_ws_async.mli b/async/httpun_ws_async.mli index 031fc633..3419ac88 100644 --- a/async/httpun_ws_async.mli +++ b/async/httpun_ws_async.mli @@ -3,10 +3,11 @@ open Async module Server : sig val create_connection_handler : ?config : Httpun.Config.t - -> websocket_handler : ( 'a - -> Httpun_ws.Wsd.t - -> Httpun_ws.Websocket_connection.input_handlers) - -> error_handler : ('a -> Httpun_ws.Server_connection.error_handler) + -> ?error_handler: ('a -> Httpun.Server_connection.error_handler) + -> ?websocket_error_handler: ('a -> Httpun_ws.Server_connection.error_handler) + -> ('a + -> Httpun_ws.Wsd.t + -> Httpun_ws.Websocket_connection.input_handlers) -> 'a -> ([`Active], [< Socket.Address.t] as 'a) Socket.t -> unit Deferred.t diff --git a/eio/httpun_ws_eio.ml b/eio/httpun_ws_eio.ml index 9bdded71..5e48adde 100644 --- a/eio/httpun_ws_eio.ml +++ b/eio/httpun_ws_eio.ml @@ -8,13 +8,16 @@ module Server = struct * error handler?*) let create_connection_handler ?(config = Httpun.Config.default) - ~websocket_handler - ~error_handler ~sw = + ?error_handler + ?websocket_error_handler + ~sw + websocket_handler = fun client_addr socket -> let connection = Httpun_ws.Server_connection.create ~sha1 - ~error_handler:(error_handler client_addr) + ?error_handler:(Option.map (fun f -> f client_addr) error_handler) + ?websocket_error_handler:(Option.map (fun f -> f client_addr) websocket_error_handler) (websocket_handler client_addr) in Gluten_eio.Server.create_connection_handler diff --git a/eio/httpun_ws_eio.mli b/eio/httpun_ws_eio.mli index 375a57c6..6965d0de 100644 --- a/eio/httpun_ws_eio.mli +++ b/eio/httpun_ws_eio.mli @@ -4,9 +4,10 @@ module Server : sig val create_connection_handler : ?config : Httpun.Config.t - -> websocket_handler : (Eio.Net.Sockaddr.stream -> Wsd.t -> Websocket_connection.input_handlers) - -> error_handler : (Eio.Net.Sockaddr.stream -> Server_connection.error_handler) + -> ?error_handler : (Eio.Net.Sockaddr.stream -> Httpun.Server_connection.error_handler) + -> ?websocket_error_handler : (Eio.Net.Sockaddr.stream -> Server_connection.error_handler) -> sw:Eio.Switch.t + -> (Eio.Net.Sockaddr.stream -> Wsd.t -> Websocket_connection.input_handlers) -> (Eio.Net.Sockaddr.stream -> _ Eio.Net.stream_socket -> unit) end diff --git a/examples/async/echo_server.ml b/examples/async/echo_server.ml index 78e6e7d8..2682b5fa 100644 --- a/examples/async/echo_server.ml +++ b/examples/async/echo_server.ml @@ -45,8 +45,8 @@ let connection_handler : ([< Socket.Address.t] as 'a) -> ([`Active], 'a) Socket. Httpun_ws_async.Server.create_connection_handler ?config:None - ~websocket_handler - ~error_handler + ~websocket_error_handler:error_handler + websocket_handler let main port max_accepts_per_batch () = let where_to_listen = Tcp.Where_to_listen.of_port port in diff --git a/examples/eio/echo_server.ml b/examples/eio/echo_server.ml index 56ec01d5..79921e23 100644 --- a/examples/eio/echo_server.ml +++ b/examples/eio/echo_server.ml @@ -47,9 +47,9 @@ let connection_handler ~sw : Eio.Net.Sockaddr.stream -> _ Eio.Net.stream_socket Httpun_ws_eio.Server.create_connection_handler ?config:None - ~websocket_handler - ~error_handler ~sw + ~websocket_error_handler:error_handler + websocket_handler let () = diff --git a/examples/lwt/echo_server.ml b/examples/lwt/echo_server.ml index 16839774..852e5aaf 100644 --- a/examples/lwt/echo_server.ml +++ b/examples/lwt/echo_server.ml @@ -40,8 +40,8 @@ let connection_handler : Unix.sockaddr -> Lwt_unix.file_descr -> unit Lwt.t = Httpun_ws_lwt_unix.Server.create_connection_handler ?config:None - ~websocket_handler - ~error_handler + ~websocket_error_handler:error_handler + websocket_handler diff --git a/lib/handshake.ml b/lib/handshake.ml index 9b4ec827..6a6f65dc 100644 --- a/lib/handshake.ml +++ b/lib/handshake.ml @@ -114,15 +114,15 @@ let passes_scrutiny ~request_method headers = let upgrade_headers ~sha1 ~request_method headers = if passes_scrutiny ~request_method headers then begin - let sec_websocket_key = Headers.get_exn headers "sec-websocket-key" in - let accept = sec_websocket_key_proof ~sha1 sec_websocket_key in - let upgrade_headers = + let accept = + let sec_websocket_key = Headers.get_exn headers "sec-websocket-key" in + sec_websocket_key_proof ~sha1 sec_websocket_key + in + Ok [ "Upgrade", "websocket" ; "Connection", "upgrade" ; "Sec-Websocket-Accept", accept ] - in - Ok upgrade_headers end else Error "Didn't pass scrutiny" diff --git a/lib/httpun_ws.mli b/lib/httpun_ws.mli index 54f7a220..945ce98b 100644 --- a/lib/httpun_ws.mli +++ b/lib/httpun_ws.mli @@ -202,8 +202,10 @@ module Server_connection : sig (* TODO: should take handshake error handler. *) val create - : sha1 : (string -> string) - -> ?error_handler : error_handler + : ?config : Httpun.Config.t + -> ?error_handler : Httpun.Server_connection.error_handler + -> ?websocket_error_handler : error_handler + -> sha1 : (string -> string) -> (Wsd.t -> Websocket_connection.input_handlers) -> t diff --git a/lib/server_connection.ml b/lib/server_connection.ml index 2a09e2ff..22fe1712 100644 --- a/lib/server_connection.ml +++ b/lib/server_connection.ml @@ -20,24 +20,32 @@ let is_closed t = | Websocket websocket -> Websocket_connection.is_closed websocket -let create ~sha1 ?error_handler websocket_handler = - let rec upgrade_handler upgrade () = - let t = Lazy.force t in +let create ?config ?error_handler ?websocket_error_handler ~sha1 websocket_handler = + let upgrade_handler t upgrade () = let ws_connection = - Websocket_connection.create ~mode:`Server ?error_handler websocket_handler + Websocket_connection.create + ~mode:`Server + ?error_handler:websocket_error_handler + websocket_handler in t.state <- Websocket ws_connection; upgrade (Gluten.make (module Websocket_connection) ws_connection); - and request_handler { Gluten.reqd; upgrade } = + in + let rec request_handler { Gluten.reqd; upgrade } = let error msg = - let response = Httpun.(Response.create - ~headers:(Headers.of_list ["Connection", "close"]) - `Bad_request) + let response = Httpun.Response.create + ~headers:(Httpun.Headers.of_list ["Connection", "close"]) + `Bad_request in Httpun.Reqd.respond_with_string reqd response msg in let ret = Httpun.Reqd.try_with reqd (fun () -> - match Handshake.respond_with_upgrade ~sha1 reqd (upgrade_handler upgrade) with + match + Handshake.respond_with_upgrade + ~sha1 + reqd + (upgrade_handler (Lazy.force t) upgrade) + with | Ok () -> () | Error msg -> error msg) in @@ -51,7 +59,7 @@ let create ~sha1 ?error_handler websocket_handler = (Server_handshake.create_upgradable ~protocol:(module Httpun.Server_connection) ~create: - (Httpun.Server_connection.create ?config:None ?error_handler:None) + (Httpun.Server_connection.create ?config ?error_handler) request_handler) ; websocket_handler } @@ -76,11 +84,8 @@ let shutdown t = let report_exn t exn = match t.state with - | Handshake _ -> - (* TODO: we need to handle this properly. There was an error in the upgrade *) - assert false - | Websocket websocket -> - Websocket_connection.report_exn websocket exn + | Handshake hs -> Server_handshake.report_exn hs exn; + | Websocket websocket -> Websocket_connection.report_exn websocket exn let next_read_operation t = match t.state with diff --git a/lwt/httpun_ws_lwt.ml b/lwt/httpun_ws_lwt.ml index def64776..0ea0698d 100644 --- a/lwt/httpun_ws_lwt.ml +++ b/lwt/httpun_ws_lwt.ml @@ -12,12 +12,14 @@ module Server (Server_runtime: Gluten_lwt.Server) = struct * error handler?*) let create_connection_handler ?(config = Httpun.Config.default) - ~websocket_handler - ~error_handler = fun client_addr socket -> + ?error_handler + ?websocket_error_handler + websocket_handler = fun client_addr socket -> let connection = Httpun_ws.Server_connection.create ~sha1 - ~error_handler:(error_handler client_addr) + ?error_handler:(Option.map (fun f -> f client_addr) error_handler) + ?websocket_error_handler:(Option.map (fun f -> f client_addr) websocket_error_handler) (websocket_handler client_addr) in Server_runtime.create_connection_handler diff --git a/lwt/httpun_ws_lwt_intf.ml b/lwt/httpun_ws_lwt_intf.ml index 8a6aa476..1515a032 100644 --- a/lwt/httpun_ws_lwt_intf.ml +++ b/lwt/httpun_ws_lwt_intf.ml @@ -41,8 +41,9 @@ module type Server = sig val create_connection_handler : ?config : Httpun.Config.t - -> websocket_handler : (addr -> Wsd.t -> Websocket_connection.input_handlers) - -> error_handler : (addr -> Httpun_ws.Server_connection.error_handler) + -> ?error_handler : (addr -> Httpun.Server_connection.error_handler) + -> ?websocket_error_handler : (addr -> Server_connection.error_handler) + -> (addr -> Wsd.t -> Websocket_connection.input_handlers) -> (addr -> socket -> unit Lwt.t) end diff --git a/mirage/httpun_ws_mirage.ml b/mirage/httpun_ws_mirage.ml index a5975067..a30a7713 100644 --- a/mirage/httpun_ws_mirage.ml +++ b/mirage/httpun_ws_mirage.ml @@ -37,14 +37,18 @@ module Server (Flow : Mirage_flow.S) = struct module Server_runtime = Httpun_ws_lwt.Server (Gluten_mirage.Server (Flow)) - let create_connection_handler ?config ~websocket_handler ~error_handler = + let create_connection_handler + ?config + ?error_handler + ?websocket_error_handler + websocket_handler = fun flow -> let websocket_handler = fun () -> websocket_handler in - let error_handler = fun () -> error_handler in Server_runtime.create_connection_handler ?config - ~websocket_handler - ~error_handler + ?error_handler:(Option.map (fun f -> fun () -> f) error_handler) + ?websocket_error_handler:(Option.map (fun f -> fun () -> f) websocket_error_handler) + websocket_handler () (Gluten_mirage.Buffered_flow.create flow) end @@ -57,10 +61,10 @@ module type Server = sig val create_connection_handler : ?config : Httpun.Config.t - -> websocket_handler : (Wsd.t -> Websocket_connection.input_handlers) - -> error_handler : Server_connection.error_handler - -> socket - -> unit Lwt.t + -> ?error_handler : Httpun.Server_connection.error_handler + -> ?websocket_error_handler : Server_connection.error_handler + -> (Wsd.t -> Websocket_connection.input_handlers) + -> (socket -> unit Lwt.t) end module type Client = Httpun_ws_lwt.Client diff --git a/mirage/httpun_ws_mirage.mli b/mirage/httpun_ws_mirage.mli index 33cbcf3b..df946cdd 100644 --- a/mirage/httpun_ws_mirage.mli +++ b/mirage/httpun_ws_mirage.mli @@ -39,10 +39,10 @@ module type Server = sig val create_connection_handler : ?config : Httpun.Config.t - -> websocket_handler : (Wsd.t -> Websocket_connection.input_handlers) - -> error_handler : Server_connection.error_handler - -> socket - -> unit Lwt.t + -> ?error_handler : Httpun.Server_connection.error_handler + -> ?websocket_error_handler : Server_connection.error_handler + -> (Wsd.t -> Websocket_connection.input_handlers) + -> (socket -> unit Lwt.t) end module Server (Flow : Mirage_flow.S) : Server with type socket = Flow.flow