Skip to content

Commit

Permalink
Merge pull request #441 from hannesm/no-rresult
Browse files Browse the repository at this point in the history
remove rresult dependency
  • Loading branch information
hannesm authored Oct 29, 2021
2 parents 6e13d2f + 4b599be commit 94f4c81
Show file tree
Hide file tree
Showing 11 changed files with 871 additions and 762 deletions.
8 changes: 8 additions & 0 deletions lib/core.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ open Sexplib.Conv
open Packet
open Ciphersuite

let (<+>) = Cstruct.append

let ( let* ) = Result.bind

let guard p e = if p then Ok () else Error e

let map_reader_error r = Result.map_error (fun re -> `Fatal (`ReaderError re)) r

type tls13 = [ `TLS_1_3 ] [@@deriving sexp_of]

type tls_before_13 = [
Expand Down
2 changes: 1 addition & 1 deletion lib/dune
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
(library
(name tls)
(public_name tls)
(libraries cstruct cstruct-sexp logs hkdf mirage-crypto mirage-crypto-rng mirage-crypto-pk x509 sexplib domain-name fmt mirage-crypto-ec rresult ipaddr ipaddr-sexp)
(libraries cstruct cstruct-sexp logs hkdf mirage-crypto mirage-crypto-rng mirage-crypto-pk x509 sexplib domain-name fmt mirage-crypto-ec ipaddr ipaddr-sexp)
(preprocess (pps ppx_sexp_conv ppx_cstruct)))
148 changes: 70 additions & 78 deletions lib/engine.ml
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
open Core
open State

open Rresult.R.Infix

let guard p e = if p then Ok () else Error e

type state = State.state

type client_hello_errors = State.client_hello_errors
Expand Down Expand Up @@ -92,8 +88,6 @@ type ret =
* [ `Data of Cstruct.t option ],
failure * [ `Response of Cstruct.t ]) result

let (<+>) = Cstruct.append

let new_state config role =
let handshake_state = match role with
| `Client -> Client ClientInitial
Expand Down Expand Up @@ -187,14 +181,14 @@ let encrypt (version : tls_version) (st : crypto_state) ty buf =
(* well-behaved pure decryptor *)
let verify_mac sequence mac mac_k ty ver decrypted =
let macstart = Cstruct.length decrypted - Mirage_crypto.Hash.digest_size mac in
guard (macstart >= 0) (`Fatal `MACUnderflow) >>= fun () ->
let* () = guard (macstart >= 0) (`Fatal `MACUnderflow) in
let (body, mmac) = Cstruct.split decrypted macstart in
let cmac =
let ver = pair_of_tls_version ver in
let hdr = Crypto.pseudo_header sequence ty ver (Cstruct.length body) in
Crypto.mac mac mac_k hdr body in
guard (Cstruct.equal cmac mmac) (`Fatal `MACMismatch) >>| fun () ->
body
let* () = guard (Cstruct.equal cmac mmac) (`Fatal `MACMismatch) in
Ok body


let decrypt ?(trial = false) (version : tls_version) (st : crypto_state) ty buf =
Expand All @@ -210,7 +204,8 @@ let decrypt ?(trial = false) (version : tls_version) (st : crypto_state) ty buf
(* defense against http://lasecwww.epfl.ch/memo/memo_ssl.shtml 1) in
https://www.openssl.org/~bodo/tls-cbc.txt *)
let mask_decrypt_failure seq mac mac_k =
compute_mac seq mac mac_k buf >>= fun _ -> Error (`Fatal `MACMismatch)
let* _ = compute_mac seq mac mac_k buf in
Error (`Fatal `MACMismatch)
in

let dec ctx =
Expand All @@ -222,20 +217,20 @@ let decrypt ?(trial = false) (version : tls_version) (st : crypto_state) ty buf
| None ->
mask_decrypt_failure seq c.hmac c.hmac_secret
| Some (dec, iv') ->
compute_mac seq c.hmac c.hmac_secret dec >>| fun msg ->
(msg, iv')
let* msg = compute_mac seq c.hmac c.hmac_secret dec in
Ok (msg, iv')
in
( match c.iv_mode with
| Iv iv ->
dec iv buf >>| fun (msg, iv') ->
CBC { c with iv_mode = Iv iv' }, msg
let* msg, iv' = dec iv buf in
Ok (CBC { c with iv_mode = Iv iv' }, msg)
| Random_iv ->
if Cstruct.length buf < Crypto.cbc_block c.cipher then
Error (`Fatal `MACUnderflow)
else
let iv, buf = Cstruct.split buf (Crypto.cbc_block c.cipher) in
dec iv buf >>| fun (msg, _) ->
(CBC c, msg) )
let* msg, _ = dec iv buf in
Ok (CBC c, msg) )

| AEAD c ->
match c.cipher with
Expand Down Expand Up @@ -301,12 +296,12 @@ let decrypt ?(trial = false) (version : tls_version) (st : crypto_state) ty buf
else
Error (`Fatal `MACMismatch)
| Some x ->
unpad x >>| fun (data, ty) ->
(Some { ctx with sequence = Int64.succ ctx.sequence }, data, ty))
let* data, ty = unpad x in
Ok (Some { ctx with sequence = Int64.succ ctx.sequence }, data, ty))
| _ -> Error (`Fatal `InvalidMessage))
| _ -> Error (`Fatal `InvalidMessage))
| Some ctx, _ ->
dec ctx >>= fun (st', msg) ->
let* st', msg = dec ctx in
let ctx' = { cipher_st = st' ; sequence = Int64.succ ctx.sequence } in
Ok (Some ctx', msg, ty)

Expand All @@ -317,8 +312,8 @@ let rec separate_records : Cstruct.t -> ((tls_hdr * Cstruct.t) list * Cstruct.t,
match parse_record buf with
| Ok (`Fragment b) -> Ok ([], b)
| Ok (`Record (packet, fragment)) ->
separate_records fragment >>| fun (tl, frag) ->
(packet :: tl, frag)
let* tl, frag = separate_records fragment in
Ok (packet :: tl, frag)
| Error (Overflow x) ->
Tracing.cs ~tag:"buf-in" buf ;
Error (`Fatal (`RecordOverflow x))
Expand Down Expand Up @@ -359,15 +354,14 @@ module Alert = struct
let close_notify = make ~level:WARNING CLOSE_NOTIFY

let handle buf =
match Reader.parse_alert buf with
| Ok (_, a_type as alert) ->
Tracing.sexpf ~tag:"alert-in" ~f:sexp_of_tls_alert alert ;
let err = match a_type with
| CLOSE_NOTIFY -> `Eof
| _ -> `Alert a_type in
Tracing.sexpf ~tag:"alert-out" ~f:sexp_of_tls_alert (Packet.WARNING, Packet.CLOSE_NOTIFY) ;
Ok (err, [`Record close_notify])
| Error re -> Error (`Fatal (`ReaderError re))
let* alert = map_reader_error (Reader.parse_alert buf) in
let _, a_type = alert in
Tracing.sexpf ~tag:"alert-in" ~f:sexp_of_tls_alert alert ;
let err = match a_type with
| CLOSE_NOTIFY -> `Eof
| _ -> `Alert a_type in
Tracing.sexpf ~tag:"alert-out" ~f:sexp_of_tls_alert (Packet.WARNING, Packet.CLOSE_NOTIFY) ;
Ok (err, [`Record close_notify])
end

let hs_can_handle_appdata s =
Expand Down Expand Up @@ -396,8 +390,8 @@ let rec separate_handshakes buf =
match Reader.parse_handshake_frame buf with
| None, rest -> Ok ([], rest)
| Some hs, rest ->
separate_handshakes rest >>| fun (rt, frag) ->
(hs :: rt, frag)
let* rt, frag = separate_handshakes rest in
Ok (hs :: rt, frag)

let handle_change_cipher_spec = function
| Client cs -> Handshake_client.handle_change_cipher_spec cs
Expand Down Expand Up @@ -431,8 +425,8 @@ let handle_packet hs buf = function
*)

| Packet.ALERT ->
Alert.handle buf >>| fun (err, out) ->
(hs, out, None, err)
let* err, out = Alert.handle buf in
Ok (hs, out, None, err)

| Packet.APPLICATION_DATA ->
if hs_can_handle_appdata hs || (early_data hs && Cstruct.length hs.hs_fragment = 0) then
Expand All @@ -442,20 +436,20 @@ let handle_packet hs buf = function
Error (`Fatal `CannotHandleApplicationDataYet)

| Packet.CHANGE_CIPHER_SPEC ->
handle_change_cipher_spec hs.machina hs buf
>>| fun (hs, items) -> (hs, items, None, `No_err)
let* hs, items = handle_change_cipher_spec hs.machina hs buf in
Ok (hs, items, None, `No_err)

| Packet.HANDSHAKE ->
separate_handshakes (hs.hs_fragment <+> buf)
>>= fun (hss, hs_fragment) ->
let hs = { hs with hs_fragment } in
let* hss, hs_fragment = separate_handshakes (hs.hs_fragment <+> buf) in
let hs = { hs with hs_fragment } in
let* hs, items =
List.fold_left (fun acc raw ->
acc >>= fun (hs, items) ->
handle_handshake hs.machina hs raw
>>| fun (hs', items') -> (hs', items @ items'))
let* hs, items = acc in
let* hs', items' = handle_handshake hs.machina hs raw in
Ok (hs', items @ items'))
(Ok (hs, [])) hss
>>| fun (hs, items) ->
(hs, items, None, `No_err)
in
Ok (hs, items, None, `No_err)

| Packet.HEARTBEAT -> Error (`Fatal `NoHeartbeat)

Expand All @@ -471,8 +465,8 @@ let decrement_early_data hs ty buf =
| _ -> `AES_128_GCM_SHA256
(* TODO assert and ensure that all early_data states have a cipher *)
in
bytes hs.early_data_left cipher >>| fun early_data_left ->
{ hs with early_data_left }
let* early_data_left = bytes hs.early_data_left cipher in
Ok { hs with early_data_left }
else
Ok hs

Expand All @@ -482,38 +476,37 @@ let handle_raw_record state (hdr, buf as record : raw_record) =
Tracing.sexpf ~tag:"record-in" ~f:sexp_of_raw_record record ;
let hs = state.handshake in
let version = hs.protocol_version in
( match hs.machina, version with
let* () =
match hs.machina, version with
| Client (AwaitServerHello _), _ -> Ok ()
| Server AwaitClientHello , _ -> Ok ()
| Server13 AwaitClientHelloHRR13, _ -> Ok ()
| _ , `TLS_1_3 -> guard (hdr.version = `TLS_1_2) (`Fatal (`BadRecordVersion hdr.version))
| _ , v -> guard (version_eq hdr.version v) (`Fatal (`BadRecordVersion hdr.version)) )
>>= fun () ->
| _ , v -> guard (version_eq hdr.version v) (`Fatal (`BadRecordVersion hdr.version))
in
let trial = match hs.machina with
| Server13 (AwaitEndOfEarlyData13 _) | Server13 Established13 -> false
| Server13 _ -> hs.early_data_left > 0l && Cstruct.length hs.hs_fragment = 0
| _ -> false
in
decrypt ~trial version state.decryptor hdr.content_type buf
>>= fun (dec_st, dec, ty) ->
decrement_early_data hs ty buf >>= fun handshake ->
let* dec_st, dec, ty = decrypt ~trial version state.decryptor hdr.content_type buf in
let* handshake = decrement_early_data hs ty buf in
Tracing.sexpf ~tag:"frame-in" ~f:sexp_of_record (ty, dec) ;
handle_packet handshake dec ty
>>| fun (handshake, items, data, err) ->
let (encryptor, decryptor, encs) =
let* handshake, items, data, err = handle_packet handshake dec ty in
let encryptor, decryptor, encs =
List.fold_left (fun (enc, dec, es) -> function
| `Change_enc enc' -> (Some enc', dec, es)
| `Change_dec dec' -> (enc, Some dec', es)
| `Record r ->
Tracing.sexpf ~tag:"frame-out" ~f:sexp_of_record r ;
let (enc', encbuf) = encrypt_records enc handshake.protocol_version [r] in
(enc', dec, es @ encbuf))
let (enc', encbuf) = encrypt_records enc handshake.protocol_version [r] in
(enc', dec, es @ encbuf))
(state.encryptor, dec_st, [])
items
in
List.iter (Tracing.sexpf ~tag:"record-out" ~f:sexp_of_record) encs ;
let state' = { state with handshake ; encryptor ; decryptor } in
(state', encs, data, err)
Ok (state', encs, data, err)

let maybe_app a b = match a, b with
| Some x, Some y -> Some (x <+> y)
Expand All @@ -535,26 +528,25 @@ let handle_tls state buf =
let rec handle_records st = function
| [] -> Ok (st, [], None, `No_err)
| r::rs ->
handle_raw_record st r >>= function
| (st, raw_rs, data, `No_err) ->
handle_records st rs >>| fun (st', raw_rs', data', err') ->
(st', raw_rs @ raw_rs', maybe_app data data', err')
| res -> Ok res
let* r = handle_raw_record st r in
match r with
| st, raw_rs, data, `No_err ->
let* st', raw_rs', data', err' = handle_records st rs in
Ok (st', raw_rs @ raw_rs', maybe_app data data', err')
| res -> Ok res
in
match
separate_records (state.fragment <+> buf)
>>= fun (in_records, fragment) ->
handle_records state in_records
>>| fun (state', out_records, data, err) ->
let version = state'.handshake.protocol_version in
let resp = match out_records with
| [] -> None
| _ ->
let out = assemble_records version out_records in
Tracing.cs ~tag:"wire-out" out ;
Some out
in
({ state' with fragment }, resp, data, err)
let* in_records, fragment = separate_records (state.fragment <+> buf) in
let* state', out_records, data, err = handle_records state in_records in
let version = state'.handshake.protocol_version in
let resp = match out_records with
| [] -> None
| _ ->
let out = assemble_records version out_records in
Tracing.cs ~tag:"wire-out" out ;
Some out
in
Ok ({ state' with fragment }, resp, data, err)
with
| Ok (state, resp, data, err) ->
let res = match err with
Expand Down Expand Up @@ -641,9 +633,9 @@ let reneg ?authenticator ?acceptable_cas ?cert st =
| _ -> None

let key_update ?(request = true) state =
Handshake_common.output_key_update ~request state >>| fun (state', out) ->
let* state', out = Handshake_common.output_key_update ~request state in
let _, outbuf = send_records state [out] in
state', outbuf
Ok (state', outbuf)

let client config =
let config = Config.of_client config in
Expand Down
Loading

0 comments on commit 94f4c81

Please sign in to comment.