From b8dad2fe3e66fd47c167361096992146dcb8e3bf Mon Sep 17 00:00:00 2001
From: Hannes Mehnert <hannes@mehnert.org>
Date: Thu, 29 Aug 2024 12:27:59 +0200
Subject: [PATCH] tls-lwt: read add an optional ?off argument (#510)

* tls-lwt: read has an optional ?off argument
* add a check for off, as proposed by @reynir
---
 lwt/tls_lwt.ml  | 10 ++++++----
 lwt/tls_lwt.mli |  8 ++++----
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/lwt/tls_lwt.ml b/lwt/tls_lwt.ml
index b2754fa9..5a3eea76 100644
--- a/lwt/tls_lwt.ml
+++ b/lwt/tls_lwt.ml
@@ -153,12 +153,14 @@ module Unix = struct
             handle tls (String.sub (Bytes.unsafe_to_string t.recv_buf) 0 n)
           | `Closed -> Lwt.return `Eof
 
-  let rec read t buf =
+  let rec read t ?(off = 0) buf =
+    if off < 0 || off >= Bytes.length buf then
+      invalid_arg "offset must be >= 0 and < Bytes.length buf";
 
     let writeout res =
       let rlen = String.length res in
-      let n    = min (Bytes.length buf) rlen in
-      Bytes.blit_string res 0 buf 0 n ;
+      let n    = min (Bytes.length buf - off) rlen in
+      Bytes.blit_string res 0 buf off n ;
       t.linger <-
         (if n < rlen then Some (String.sub res n (rlen - n)) else None) ;
       Lwt.return n in
@@ -168,7 +170,7 @@ module Unix = struct
     | None     ->
         read_react t >>= function
           | `Eof           -> Lwt.return 0
-          | `Ok None       -> read t buf
+          | `Ok None       -> read t ~off buf
           | `Ok (Some res) -> writeout res
 
   let writev t css =
diff --git a/lwt/tls_lwt.mli b/lwt/tls_lwt.mli
index 59444d6d..8b1a93ad 100644
--- a/lwt/tls_lwt.mli
+++ b/lwt/tls_lwt.mli
@@ -51,12 +51,12 @@ module Unix : sig
 
   (** {2 Common stream operations} *)
 
-  (** [read t buffer] is [length], the number of bytes read into
-      [buffer]. *)
-  val read   : t -> bytes       -> int  Lwt.t
+  (** [read t ~off buffer] is [length], the number of bytes read into
+      [buffer]. It fills [buffer] starting at [off] (default is 0). *)
+  val read   : t -> ?off:int -> bytes -> int  Lwt.t
 
   (** [write t buffer] writes the [buffer] to the session. *)
-  val write  : t -> string      -> unit Lwt.t
+  val write  : t -> string -> unit Lwt.t
 
   (** [writev t buffers] writes the [buffers] to the session. *)
   val writev : t -> string list -> unit Lwt.t