Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 125 additions & 12 deletions zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package zio.http.netty

import java.io.IOException
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference

import scala.collection.mutable

Expand All @@ -28,15 +30,18 @@ import zio.http.netty.NettyBody.UnsafeAsync
import io.netty.buffer.ByteBufUtil
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.handler.codec.http.{HttpContent, LastHttpContent}
import io.netty.util.concurrent.ScheduledFuture

private[netty] abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](true) {
private[netty] abstract class AsyncBodyReader(timeoutMillis: Option[Long])
extends SimpleChannelInboundHandler[HttpContent](true) {
import zio.http.netty.AsyncBodyReader._

private var state: State = State.Buffering
private val buffer = new mutable.ArrayBuilder.ofByte()
private var previousAutoRead: Boolean = false
private var readingDone: Boolean = false
private var ctx: ChannelHandlerContext = _
private var state: State = State.Buffering
private val buffer = new mutable.ArrayBuilder.ofByte()
private var previousAutoRead: Boolean = false
private var readingDone: Boolean = false
private var ctx: ChannelHandlerContext = _
private val timeoutTask: AtomicReference[ScheduledFuture[_]] = new AtomicReference(null)

private def result(buffer: mutable.ArrayBuilder.ofByte): Chunk[Byte] = {
val arr = buffer.result()
Expand All @@ -57,12 +62,50 @@ private[netty] abstract class AsyncBodyReader extends SimpleChannelInboundHandle
case UnsafeAsync.Aggregating(bufSize) => buffer.sizeHint(bufSize)
case cb => cb(result(buffer0), isLast = false)
}

// Schedule timeout task if configured
timeoutMillis.foreach { timeoutMillis =>
val task = ctx
.channel()
.eventLoop()
.schedule(
new Runnable {
override def run(): Unit = {
AsyncBodyReader.this.synchronized {
state match {
case State.Direct(cb) if !readingDone =>
cb.fail(
new IOException(
s"Body read timeout: server stopped sending data after ${timeoutMillis}ms",
),
)
// Mark as done to prevent further processing
readingDone = true
if (ctx.channel().isOpen) {
ctx.channel().close(): Unit
}
case _ => // Already completed or not connected
}
}
}
},
timeoutMillis,
TimeUnit.MILLISECONDS,
)
timeoutTask.set(task)
}

ctx.read(): Unit
} else {
throw new IllegalStateException("Attempting to read from a closed channel, which will never finish")
// Channel is already closed - fail immediately with appropriate error
callback.fail(
new IOException(
"Server closed connection before sending complete response body",
),
)
}
case _ =>
throw new IllegalStateException("Cannot connect twice")
callback.fail(new IllegalStateException("Cannot connect twice"))
}
}
}
Expand All @@ -74,7 +117,13 @@ private[netty] abstract class AsyncBodyReader extends SimpleChannelInboundHandle
}

override def handlerRemoved(ctx: ChannelHandlerContext): Unit = {
val _ = ctx.channel().config().setAutoRead(previousAutoRead)
val _ = ctx.channel().config().setAutoRead(previousAutoRead)
// Cancel any pending timeout task
val currentTask = timeoutTask.get()
if (currentTask != null) {
currentTask.cancel(false)
timeoutTask.set(null)
}
Comment on lines +122 to +126
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you wrap this in synchronized you can simply use a var for timeoutTask instead of an AtomicReference.

}

protected def onLastMessage(): Unit = ()
Expand All @@ -89,6 +138,13 @@ private[netty] abstract class AsyncBodyReader extends SimpleChannelInboundHandle
val isLast = msg.isInstanceOf[LastHttpContent]
val content = ByteBufUtil.getBytes(msg.content())

// Cancel timeout task since we received data
val currentTask = timeoutTask.get()
if (currentTask != null) {
currentTask.cancel(false)
timeoutTask.set(null)
}

if (isLast) {
readingDone = true
ctx.channel().pipeline().remove(this)
Expand Down Expand Up @@ -121,12 +177,52 @@ private[netty] abstract class AsyncBodyReader extends SimpleChannelInboundHandle
!isLast
}

// Reschedule timeout for next chunk if not the last message
if (readMore && !isLast) {
timeoutMillis.foreach { timeoutMillis =>
val task = ctx
.channel()
.eventLoop()
.schedule(
new Runnable {
override def run(): Unit = {
AsyncBodyReader.this.synchronized {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit confusing (I've actually never seen this syntax before). Maybe just use this.synchronized?

state match {
case State.Direct(cb) if !readingDone =>
cb.fail(
new IOException(
s"Body read timeout: server stopped sending data after ${timeoutMillis}ms",
),
)
readingDone = true
if (ctx.channel().isOpen) {
ctx.channel().close(): Unit
}
case _ => // Already completed or not connected
}
}
}
},
timeoutMillis,
TimeUnit.MILLISECONDS,
)
timeoutTask.set(task)
}
}
Comment on lines +182 to +211
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this the same code as above? Maybe extract it all into a method if that's the case?


if (readMore) ctx.read(): Unit
}
}

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
this.synchronized {
// Cancel timeout task
val currentTask = timeoutTask.get()
if (currentTask != null) {
currentTask.cancel(false)
timeoutTask.set(null)
}

state match {
case State.Buffering =>
case State.Direct(callback) =>
Expand All @@ -138,10 +234,27 @@ private[netty] abstract class AsyncBodyReader extends SimpleChannelInboundHandle

override def channelInactive(ctx: ChannelHandlerContext): Unit = {
this.synchronized {
// Cancel timeout task
val currentTask = timeoutTask.get()
if (currentTask != null) {
currentTask.cancel(false)
timeoutTask.set(null)
}
Comment on lines +238 to +242
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is being used everywhere and doesn't have any dependency on the local vals. Maybe extract it into a method to keep things DRY


state match {
case State.Buffering =>
case State.Direct(callback) =>
callback.fail(new IOException("Channel closed unexpectedly"))
case State.Buffering =>
case State.Direct(callback) if !readingDone =>
// Step 4: Premature channel closure detection
// This is the core issue from #2383 - server sent headers but closed before body completed
// Provide a clear, actionable error message
callback.fail(
new IOException(
"Server closed connection before sending complete response body. " +
"This may indicate a broken server, network issue, or server-side timeout.",
),
)
case _ =>
// Reading already done - this is a normal close after completion
}
}
ctx.fireChannelInactive(): Unit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ private[netty] object NettyResponse {
jRes: HttpResponse,
onComplete: Promise[Throwable, ChannelState],
keepAlive: Boolean,
bodyReadTimeoutMillis: Option[Long] = None,
)(implicit
unsafe: Unsafe,
trace: Trace,
Expand All @@ -55,7 +56,7 @@ private[netty] object NettyResponse {
Response(status, headers, Body.empty)
} else {
val contentType = headers.get(Header.ContentType)
val responseHandler = new ClientResponseStreamHandler(onComplete, keepAlive, status)
val responseHandler = new ClientResponseStreamHandler(onComplete, keepAlive, status, bodyReadTimeoutMillis)
ctx
.pipeline()
.addAfter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ private[netty] final class ClientInboundHandler(
onResponse: Promise[Throwable, Response],
onComplete: Promise[Throwable, ChannelState],
enableKeepAlive: Boolean,
bodyReadTimeoutMillis: Option[Long] = None,
)(implicit trace: Trace)
extends SimpleChannelInboundHandler[HttpObject](false) {
implicit private val unsafeClass: Unsafe = Unsafe.unsafe
Expand All @@ -62,7 +63,7 @@ private[netty] final class ClientInboundHandler(
msg match {
case response: HttpResponse =>
val keepAlive = enableKeepAlive && HttpUtil.isKeepAlive(response)
val resp = NettyResponse.make(ctx, response, onComplete, keepAlive)
val resp = NettyResponse.make(ctx, response, onComplete, keepAlive, bodyReadTimeoutMillis)
onResponse.unsafe.done(Exit.succeed(resp))
case content: HttpContent =>
ctx.fireChannelRead(content): Unit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,40 @@ import zio.http.netty.AsyncBodyReader
import io.netty.channel._
import io.netty.handler.codec.http.{HttpContent, LastHttpContent}

/**
* Handles streaming HTTP response bodies and manages connection lifecycle.
*
* This handler extends AsyncBodyReader to provide body reading with timeout
* support, while also ensuring proper connection pool management through the
* onComplete promise.
*
* Connection Lifecycle Management:
* - onLastMessage(): Called when body completes successfully
* - keepAlive=true: Mark connection as reusable (ChannelState.forStatus)
* - keepAlive=false: Mark connection as invalid
* - channelInactive(): Called when channel closes prematurely
* - Always marks connection as Invalid to remove from pool
* - Allows parent AsyncBodyReader to fail the body callback
* - Ensures connection cleanup even on timeout/error
* - exceptionCaught(): Called on any exception during body reading
* - Marks connection as Invalid via Exit.fail
* - Ensures connection removed from pool on errors
*
* This ensures proper coordination with ZClient's connection pool:
* 1. Body reads successfully → connection returned to pool (if keep-alive)
* 2. Body read times out → connection invalidated and removed
* 3. Channel closes early → connection invalidated and removed
* 4. Exception occurs → connection invalidated and removed
*
* The onComplete promise is always fulfilled, preventing connection leaks.
*/
private[netty] final class ClientResponseStreamHandler(
onComplete: Promise[Throwable, ChannelState],
keepAlive: Boolean,
status: Status,
timeoutMillis: Option[Long],
)(implicit trace: Trace)
extends AsyncBodyReader { self =>
extends AsyncBodyReader(timeoutMillis) { self =>

private implicit val unsafe: Unsafe = Unsafe.unsafe

Expand All @@ -47,6 +75,15 @@ private[netty] final class ClientResponseStreamHandler(
if (isLast && !keepAlive) ctx.close(): Unit
}

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit =
override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
onComplete.unsafe.done(Exit.fail(cause))
}

override def channelInactive(ctx: ChannelHandlerContext): Unit = {
// Channel closed before body reading completed
// Mark connection as invalid to ensure it's removed from pool
// The parent AsyncBodyReader will handle failing the body callback
onComplete.unsafe.done(Exit.succeed(ChannelState.Invalid))
super.channelInactive(ctx)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package zio.http.netty.client

import scala.annotation.unroll

import zio._
import zio.stacktracer.TracingImplicits.disableAutoTrace

Expand Down Expand Up @@ -49,18 +51,21 @@ final case class NettyClientDriver private[netty] (
enableKeepAlive: Boolean,
createSocketApp: () => WebSocketApp[Any],
webSocketConfig: WebSocketConfig,
@unroll
bodyReadTimeoutMillis: Option[Long] = None,
)(implicit trace: Trace): ZIO[Scope, Throwable, ChannelInterface] =
if (location.scheme.isWebSocket)
requestWebsocket(channel, req, onResponse, onComplete, createSocketApp, webSocketConfig)
else
requestHttp(channel, req, onResponse, onComplete, enableKeepAlive)
requestHttp(channel, req, onResponse, onComplete, enableKeepAlive, bodyReadTimeoutMillis)

private def requestHttp(
channel: Channel,
req: Request,
onResponse: Promise[Throwable, Response],
onComplete: Promise[Throwable, ChannelState],
enableKeepAlive: Boolean,
bodyReadTimeoutMillis: Option[Long],
)(implicit trace: Trace): RIO[Scope, ChannelInterface] =
ZIO
.succeed(NettyRequestEncoder.encode(req))
Expand All @@ -85,7 +90,15 @@ final case class NettyClientDriver private[netty] (

pipeline.addLast(
Names.ClientInboundHandler,
new ClientInboundHandler(nettyRuntime, req, jReq, onResponse, onComplete, enableKeepAlive),
new ClientInboundHandler(
nettyRuntime,
req,
jReq,
onResponse,
onComplete,
enableKeepAlive,
bodyReadTimeoutMillis,
),
)

pipeline.addLast(
Expand Down
Loading
Loading