diff --git a/ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/RequestResponse.kt b/ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/RequestResponse.kt index cb7b92896c7..62990d3635c 100644 --- a/ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/RequestResponse.kt +++ b/ktor-http/ktor-http-cio/common/src/io/ktor/http/cio/RequestResponse.kt @@ -6,6 +6,7 @@ package io.ktor.http.cio import io.ktor.http.* import io.ktor.http.cio.internals.* +import io.ktor.utils.io.InternalAPI import io.ktor.utils.io.core.* /** diff --git a/ktor-io/api/ktor-io.api b/ktor-io/api/ktor-io.api index dd4c604a636..eb04c5872e4 100644 --- a/ktor-io/api/ktor-io.api +++ b/ktor-io/api/ktor-io.api @@ -226,6 +226,12 @@ public final class io/ktor/utils/io/ConcurrentIOException : java/lang/IllegalSta public synthetic fun (Ljava/lang/String;Ljava/lang/Throwable;ILkotlin/jvm/internal/DefaultConstructorMarker;)V } +public final class io/ktor/utils/io/ConnectionClosedException : java/io/IOException { + public fun ()V + public fun (Ljava/lang/String;)V + public synthetic fun (Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V +} + public final class io/ktor/utils/io/CountedByteReadChannel : io/ktor/utils/io/ByteReadChannel { public fun (Lio/ktor/utils/io/ByteReadChannel;)V public fun awaitContent (ILkotlin/coroutines/Continuation;)Ljava/lang/Object; diff --git a/ktor-io/api/ktor-io.klib.api b/ktor-io/api/ktor-io.klib.api index d91ee7cc4d7..7162990faf8 100644 --- a/ktor-io/api/ktor-io.klib.api +++ b/ktor-io/api/ktor-io.klib.api @@ -236,6 +236,10 @@ final class io.ktor.utils.io/ConcurrentIOException : kotlin/IllegalStateExceptio constructor (kotlin/String, kotlin/Throwable? = ...) // io.ktor.utils.io/ConcurrentIOException.|(kotlin.String;kotlin.Throwable?){}[0] } +final class io.ktor.utils.io/ConnectionClosedException : kotlinx.io/IOException { // io.ktor.utils.io/ConnectionClosedException|null[0] + constructor (kotlin/String = ...) // io.ktor.utils.io/ConnectionClosedException.|(kotlin.String){}[0] +} + final class io.ktor.utils.io/CountedByteReadChannel : io.ktor.utils.io/ByteReadChannel { // io.ktor.utils.io/CountedByteReadChannel|null[0] constructor (io.ktor.utils.io/ByteReadChannel) // io.ktor.utils.io/CountedByteReadChannel.|(io.ktor.utils.io.ByteReadChannel){}[0] diff --git a/ktor-io/common/src/io/ktor/utils/io/Exceptions.kt b/ktor-io/common/src/io/ktor/utils/io/Exceptions.kt index f4cc7a25fb7..8d58e1a0bee 100644 --- a/ktor-io/common/src/io/ktor/utils/io/Exceptions.kt +++ b/ktor-io/common/src/io/ktor/utils/io/Exceptions.kt @@ -28,3 +28,11 @@ public class ClosedWriteChannelException(cause: Throwable? = null) : ClosedByteC * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.utils.io.ClosedReadChannelException) */ public class ClosedReadChannelException(cause: Throwable? = null) : ClosedByteChannelException(cause) + +/** + * Exception thrown when a network connection is closed or reset by peer. + * This exception is used to signal that the underlying connection was terminated. + * + * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.utils.io.ConnectionClosedException) + */ +public class ConnectionClosedException(message: String = "Connection was closed") : IOException(message) diff --git a/ktor-server/ktor-server-cio/common/src/io/ktor/server/cio/CIOApplicationEngine.kt b/ktor-server/ktor-server-cio/common/src/io/ktor/server/cio/CIOApplicationEngine.kt index 18bfc89b103..e79afa1328b 100644 --- a/ktor-server/ktor-server-cio/common/src/io/ktor/server/cio/CIOApplicationEngine.kt +++ b/ktor-server/ktor-server-cio/common/src/io/ktor/server/cio/CIOApplicationEngine.kt @@ -6,10 +6,12 @@ package io.ktor.server.cio import io.ktor.events.* import io.ktor.http.* +import io.ktor.http.cio.Request import io.ktor.server.application.* import io.ktor.server.cio.backend.* import io.ktor.server.cio.internal.* import io.ktor.server.engine.* +import io.ktor.server.http.HttpRequestCloseHandlerKey import io.ktor.server.request.* import io.ktor.server.response.* import io.ktor.util.pipeline.* @@ -169,7 +171,15 @@ public class CIOApplicationEngine( return transferEncoding != null || (contentLength != null && contentLength > 0) } - private suspend fun ServerRequestScope.handleRequest(request: io.ktor.http.cio.Request) { + @OptIn(InternalAPI::class) + private fun ServerRequestScope.setCloseHandler(call: CIOApplicationCall) { + onClose = { + val requestCloseHandler = call.attributes.getOrNull(HttpRequestCloseHandlerKey) + requestCloseHandler?.invoke() + } + } + + private suspend fun ServerRequestScope.handleRequest(request: Request) { withContext(userDispatcher) requestContext@{ val call = CIOApplicationCall( applicationProvider(), @@ -186,6 +196,7 @@ public class CIOApplicationEngine( try { addHandlerForExpectedHeader(output, call) + setCloseHandler(call) pipeline.execute(call) } catch (error: Throwable) { handleFailure(call, error) diff --git a/ktor-server/ktor-server-cio/common/src/io/ktor/server/cio/backend/ServerPipeline.kt b/ktor-server/ktor-server-cio/common/src/io/ktor/server/cio/backend/ServerPipeline.kt index 4d0d3236033..3be61a73bfb 100644 --- a/ktor-server/ktor-server-cio/common/src/io/ktor/server/cio/backend/ServerPipeline.kt +++ b/ktor-server/ktor-server-cio/common/src/io/ktor/server/cio/backend/ServerPipeline.kt @@ -59,11 +59,12 @@ public fun CoroutineScope.startServerConnectionPipeline( val requestContext = RequestHandlerCoroutine + Dispatchers.Unconfined + var handlerScope: ServerRequestScope? = null try { while (true) { // parse requests loop val request = try { parseRequest(connection.input) ?: break - } catch (cause: TooLongLineException) { + } catch (_: TooLongLineException) { respondBadRequest(actorChannel) break // end pipeline loop } catch (io: IOException) { @@ -113,7 +114,7 @@ public fun CoroutineScope.startServerConnectionPipeline( contentType ) expectedHttpUpgrade = !expectedHttpBody && expectHttpUpgrade(request.method, upgrade, connectionOptions) - } catch (cause: Throwable) { + } catch (_: Throwable) { request.release() response.writePacket(BadRequestPacket.copy()) response.close() @@ -129,7 +130,7 @@ public fun CoroutineScope.startServerConnectionPipeline( val upgraded = if (expectedHttpUpgrade) CompletableDeferred() else null launch(requestContext, start = CoroutineStart.UNDISPATCHED) { - val handlerScope = ServerRequestScope( + handlerScope = ServerRequestScope( coroutineContext, requestBody, response, @@ -181,10 +182,11 @@ public fun CoroutineScope.startServerConnectionPipeline( if (isLastHttpRequest(version, connectionOptions)) break } - } catch (cause: IOException) { + } catch (_: IOException) { // already handled coroutineContext.cancel() } finally { + handlerScope?.onClose?.invoke() actorChannel.close() } } diff --git a/ktor-server/ktor-server-cio/common/src/io/ktor/server/cio/backend/ServerRequestScope.kt b/ktor-server/ktor-server-cio/common/src/io/ktor/server/cio/backend/ServerRequestScope.kt index b1b6ad7710f..bc35ec73ce6 100644 --- a/ktor-server/ktor-server-cio/common/src/io/ktor/server/cio/backend/ServerRequestScope.kt +++ b/ktor-server/ktor-server-cio/common/src/io/ktor/server/cio/backend/ServerRequestScope.kt @@ -42,4 +42,6 @@ public class ServerRequestScope internal constructor( localAddress, upgraded ) + + internal var onClose: (() -> Unit)? = null } diff --git a/ktor-server/ktor-server-cio/jvm/test/io/ktor/tests/server/cio/CIOEngineTestJvm.kt b/ktor-server/ktor-server-cio/jvm/test/io/ktor/tests/server/cio/CIOEngineTestJvm.kt index 5b1829bc5da..4f0219ef7aa 100644 --- a/ktor-server/ktor-server-cio/jvm/test/io/ktor/tests/server/cio/CIOEngineTestJvm.kt +++ b/ktor-server/ktor-server-cio/jvm/test/io/ktor/tests/server/cio/CIOEngineTestJvm.kt @@ -88,3 +88,11 @@ class CIOHooksTest : HooksTestSuite(CIO) { + init { + enableSsl = false + enableHttp2 = false + } +} diff --git a/ktor-server/ktor-server-core/api/ktor-server-core.api b/ktor-server/ktor-server-core/api/ktor-server-core.api index 05ff1a15817..b0a64d09935 100644 --- a/ktor-server/ktor-server-core/api/ktor-server-core.api +++ b/ktor-server/ktor-server-core/api/ktor-server-core.api @@ -782,6 +782,16 @@ public final class io/ktor/server/http/HttpDateJvmKt { public static final fun toHttpDateString (Ljava/time/temporal/Temporal;)Ljava/lang/String; } +public final class io/ktor/server/http/HttpRequestLifecycleConfig { + public final fun getCancelCallOnClose ()Z + public final fun setCancelCallOnClose (Z)V +} + +public final class io/ktor/server/http/HttpRequestLifecycleKt { + public static final fun getHttpRequestCloseHandlerKey ()Lio/ktor/util/AttributeKey; + public static final fun getHttpRequestLifecycle ()Lio/ktor/server/application/RouteScopedPlugin; +} + public final class io/ktor/server/http/LinkHeaderKt { public static final fun link (Lio/ktor/server/response/ApplicationResponse;Lio/ktor/http/LinkHeader;)V public static final fun link (Lio/ktor/server/response/ApplicationResponse;Ljava/lang/String;[Ljava/lang/String;)V diff --git a/ktor-server/ktor-server-core/api/ktor-server-core.klib.api b/ktor-server/ktor-server-core/api/ktor-server-core.klib.api index f574b5634e1..97b976340fd 100644 --- a/ktor-server/ktor-server-core/api/ktor-server-core.klib.api +++ b/ktor-server/ktor-server-core/api/ktor-server-core.klib.api @@ -671,6 +671,12 @@ final class io.ktor.server.http.content/HttpStatusCodeContent : io.ktor.http.con final fun toString(): kotlin/String // io.ktor.server.http.content/HttpStatusCodeContent.toString|toString(){}[0] } +final class io.ktor.server.http/HttpRequestLifecycleConfig { // io.ktor.server.http/HttpRequestLifecycleConfig|null[0] + final var cancelCallOnClose // io.ktor.server.http/HttpRequestLifecycleConfig.cancelCallOnClose|{}cancelCallOnClose[0] + final fun (): kotlin/Boolean // io.ktor.server.http/HttpRequestLifecycleConfig.cancelCallOnClose.|(){}[0] + final fun (kotlin/Boolean) // io.ktor.server.http/HttpRequestLifecycleConfig.cancelCallOnClose.|(kotlin.Boolean){}[0] +} + final class io.ktor.server.plugins/CannotTransformContentToTypeException : io.ktor.server.plugins/ContentTransformationException, kotlinx.coroutines/CopyableThrowable { // io.ktor.server.plugins/CannotTransformContentToTypeException|null[0] constructor (kotlin.reflect/KType) // io.ktor.server.plugins/CannotTransformContentToTypeException.|(kotlin.reflect.KType){}[0] @@ -1709,6 +1715,10 @@ final val io.ktor.server.http.content/isCompressionSuppressed // io.ktor.server. final fun (io.ktor.server.application/ApplicationCall).(): kotlin/Boolean // io.ktor.server.http.content/isCompressionSuppressed.|@io.ktor.server.application.ApplicationCall(){}[0] final val io.ktor.server.http.content/isDecompressionSuppressed // io.ktor.server.http.content/isDecompressionSuppressed|@io.ktor.server.application.ApplicationCall{}isDecompressionSuppressed[0] final fun (io.ktor.server.application/ApplicationCall).(): kotlin/Boolean // io.ktor.server.http.content/isDecompressionSuppressed.|@io.ktor.server.application.ApplicationCall(){}[0] +final val io.ktor.server.http/HttpRequestCloseHandlerKey // io.ktor.server.http/HttpRequestCloseHandlerKey|{}HttpRequestCloseHandlerKey[0] + final fun (): io.ktor.util/AttributeKey> // io.ktor.server.http/HttpRequestCloseHandlerKey.|(){}[0] +final val io.ktor.server.http/HttpRequestLifecycle // io.ktor.server.http/HttpRequestLifecycle|{}HttpRequestLifecycle[0] + final fun (): io.ktor.server.application/RouteScopedPlugin // io.ktor.server.http/HttpRequestLifecycle.|(){}[0] final val io.ktor.server.logging/mdcProvider // io.ktor.server.logging/mdcProvider|@io.ktor.server.application.Application{}mdcProvider[0] final fun (io.ktor.server.application/Application).(): io.ktor.server.logging/MDCProvider // io.ktor.server.logging/mdcProvider.|@io.ktor.server.application.Application(){}[0] final val io.ktor.server.plugins/MutableOriginConnectionPointKey // io.ktor.server.plugins/MutableOriginConnectionPointKey|{}MutableOriginConnectionPointKey[0] diff --git a/ktor-server/ktor-server-core/common/src/io/ktor/server/http/HttpRequestLifecycle.kt b/ktor-server/ktor-server-core/common/src/io/ktor/server/http/HttpRequestLifecycle.kt new file mode 100644 index 00000000000..3c4e808a410 --- /dev/null +++ b/ktor-server/ktor-server-core/common/src/io/ktor/server/http/HttpRequestLifecycle.kt @@ -0,0 +1,106 @@ +/* + * Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.http + +import io.ktor.server.application.* +import io.ktor.server.application.hooks.* +import io.ktor.util.* +import io.ktor.utils.io.* +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.cancel + +/** + * Configuration for the [HttpRequestLifecycle] plugin. + */ +public class HttpRequestLifecycleConfig internal constructor() { + /** + * When `true`, cancels the call coroutine context if the other peer resets the client connection. + * When `false` (default), request processing continues even if the connection is closed. + * + * **When to use this property: ** + * - Set to `true` for long-running or resource-intensive requests where you want to stop processing + * immediately when the client disconnects (e.g., streaming, batch processing, heavy computations) + * - Keep as `false` (default) for short requests, or when you need to complete processing regardless + * of client connection status (e.g., important side effects, database transactions) + * + * Example: + * ```kotlin + * install(HttpRequestLifecycle) { + * cancelCallOnClose = true + * } + * ``` + */ + public var cancelCallOnClose: Boolean = false +} + +/** + * Internal attribute key for storing the connection close handler callback. + */ +@InternalAPI +public val HttpRequestCloseHandlerKey: AttributeKey<() -> Unit> = AttributeKey<() -> Unit>("HttpRequestCloseHandler") + +/** + * A plugin that manages the HTTP request lifecycle, particularly handling client disconnections. + * + * The [HttpRequestLifecycle] plugin allows you to detect and respond to client connection closures + * during request processing. When configured with [HttpRequestLifecycleConfig.cancelCallOnClose] set to `true`, + * the plugin will automatically cancel the request handling coroutine if the client disconnects, + * preventing unnecessary processing and freeing up resources. + * + * Remember, when the coroutine context is canceled, the next suspension point will throw [CancellationException], but until + * that moment it doesn't stop any blocking operations, so call `call.coroutineContext.ensureActive` if needed. + * Plugin only works for CIO and Netty engines. Other implementations fail on closed connection only when trying to write some response. + * + * This is particularly useful for: + * - Long-running requests where the client may disconnect before completion + * - Streaming responses where detecting disconnection allows early cleanup + * - Resource-intensive operations that should be canceled when the client is no longer waiting + * + * ## Example + * + * ```kotlin + * install(HttpRequestLifecycle) { + * cancelCallOnClose = true + * } + * + * routing { + * get("/long-process") { + * try { + * // Long-running operation + * repeat(100) { + * // throws an exception if the client disconnects during processing + * call.coroutineContext.ensureActive() + * // Process more data... + * logger.info("Very important work.") + * } + * call.respond("Completed") + * } catch (e: CancellationException) { + * // Handle client disconnected, clean up resources + * } + * } + * } + * ``` + */ +@OptIn(InternalAPI::class) +public val HttpRequestLifecycle: RouteScopedPlugin = createRouteScopedPlugin( + name = "HttpRequestLifecycle", + createConfiguration = ::HttpRequestLifecycleConfig +) { + on(CallSetup) { call -> + if ( + !this@createRouteScopedPlugin.pluginConfig.cancelCallOnClose || + call.attributes.contains(HttpRequestCloseHandlerKey) + ) { + return@on + } + call.attributes.put(HttpRequestCloseHandlerKey) { + val cause = CancellationException( + "Call context was cancelled by `HttpRequestLifecycle` plugin", + ConnectionClosedException() + ) + call.coroutineContext.cancel(cause) + } + } +} diff --git a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyApplicationCallHandler.kt b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyApplicationCallHandler.kt index 1e6eb77f15e..a2b290e70bc 100644 --- a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyApplicationCallHandler.kt +++ b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyApplicationCallHandler.kt @@ -8,6 +8,7 @@ import io.ktor.http.* import io.ktor.http.HttpHeaders import io.ktor.server.application.* import io.ktor.server.engine.* +import io.ktor.server.http.HttpRequestCloseHandlerKey import io.ktor.server.netty.http1.* import io.ktor.util.pipeline.* import io.ktor.utils.io.* @@ -25,6 +26,7 @@ internal class NettyApplicationCallHandler( private val enginePipeline: EnginePipeline ) : ChannelInboundHandlerAdapter(), CoroutineScope { private var currentJob: Job? = null + private var currentCall: PipelineCall? = null override val coroutineContext: CoroutineContext = userCoroutineContext @@ -35,9 +37,21 @@ internal class NettyApplicationCallHandler( } } + internal fun onConnectionClose(context: ChannelHandlerContext) { + if (context.channel().isActive) { + return + } + currentCall?.let { + currentCall = null + @OptIn(InternalAPI::class) + it.attributes.getOrNull(HttpRequestCloseHandlerKey)?.invoke() + } + } + private fun handleRequest(context: ChannelHandlerContext, call: PipelineCall) { val callContext = CallHandlerCoroutineName + NettyDispatcher.CurrentContext(context) + currentCall = call currentJob = launch(callContext, start = CoroutineStart.UNDISPATCHED) { when { call is NettyHttp1ApplicationCall && !call.request.isValid() -> { @@ -67,6 +81,11 @@ internal class NettyApplicationCallHandler( } } + override fun channelInactive(ctx: ChannelHandlerContext) { + onConnectionClose(ctx) + ctx.fireChannelInactive() + } + private fun respond408RequestTimeout(ctx: ChannelHandlerContext) { val response = DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_TIMEOUT) response.headers().add(HttpHeaders.ContentLength, "0") diff --git a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyApplicationResponse.kt b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyApplicationResponse.kt index 3ca04b25553..85ee9113e44 100644 --- a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyApplicationResponse.kt +++ b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/NettyApplicationResponse.kt @@ -33,6 +33,9 @@ public abstract class NettyApplicationResponse( internal var responseChannel: ByteReadChannel = ByteReadChannel.Empty + private val canRespond: Boolean + get() = !responseMessageSent && context.channel().isActive + override suspend fun respondOutgoingContent(content: OutgoingContent) { try { super.respondOutgoingContent(content) @@ -51,7 +54,7 @@ public abstract class NettyApplicationResponse( // because it should've been set by commitHeaders earlier val chunked = headers[HttpHeaders.TransferEncoding] == "chunked" - if (responseMessageSent) return + if (!canRespond) return val message = responseMessage(chunked, bytes) responseChannel = when (message) { @@ -111,7 +114,7 @@ public abstract class NettyApplicationResponse( } internal fun sendResponse(chunked: Boolean = true, content: ByteReadChannel) { - if (responseMessageSent) return + if (!canRespond) return responseChannel = content responseMessage = when { diff --git a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http1/NettyHttp1Handler.kt b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http1/NettyHttp1Handler.kt index dcad4e3875d..cfdf3919c45 100644 --- a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http1/NettyHttp1Handler.kt +++ b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http1/NettyHttp1Handler.kt @@ -88,7 +88,8 @@ internal class NettyHttp1Handler( } override fun channelInactive(context: ChannelHandlerContext) { - context.pipeline().remove(NettyApplicationCallHandler::class.java) + val handler = context.pipeline().remove(NettyApplicationCallHandler::class.java) + handler?.onConnectionClose(context) context.fireChannelInactive() } diff --git a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http2/NettyHttp2Handler.kt b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http2/NettyHttp2Handler.kt index 34cad8e9a42..c645828cac2 100644 --- a/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http2/NettyHttp2Handler.kt +++ b/ktor-server/ktor-server-netty/jvm/src/io/ktor/server/netty/http2/NettyHttp2Handler.kt @@ -162,7 +162,7 @@ internal class NettyHttp2Handler( try { Http2FrameCodec::class.javaObjectType.getDeclaredField("streamKey") .also { it.isAccessible = true } - } catch (cause: Throwable) { + } catch (_: Throwable) { null } } @@ -176,7 +176,7 @@ internal class NettyHttp2Handler( try { function.invoke(this, streamKey, childStream) - } catch (cause: Throwable) { + } catch (_: Throwable) { return false } @@ -189,7 +189,7 @@ internal class NettyHttp2Handler( private tailrec fun Class<*>.findIdField(): Field { val idField = try { getDeclaredField("id") - } catch (t: NoSuchFieldException) { + } catch (_: NoSuchFieldException) { null } if (idField != null) { diff --git a/ktor-server/ktor-server-netty/jvm/test/io/ktor/tests/server/netty/NettyEngineTest.kt b/ktor-server/ktor-server-netty/jvm/test/io/ktor/tests/server/netty/NettyEngineTest.kt index 67fbee0f3d3..31bcfde79b9 100644 --- a/ktor-server/ktor-server-netty/jvm/test/io/ktor/tests/server/netty/NettyEngineTest.kt +++ b/ktor-server/ktor-server-netty/jvm/test/io/ktor/tests/server/netty/NettyEngineTest.kt @@ -386,3 +386,11 @@ class NettyH2cEnabledTest : } } } + +class NettyHttpRequestLifecycleTest : + HttpRequestLifecycleTest(Netty) { + init { + enableSsl = true + enableHttp2 = true + } +} diff --git a/ktor-server/ktor-server-test-base/jvm/src/io/ktor/server/test/base/EngineTestBaseJvm.kt b/ktor-server/ktor-server-test-base/jvm/src/io/ktor/server/test/base/EngineTestBaseJvm.kt index 085fa81f926..88d5def036c 100644 --- a/ktor-server/ktor-server-test-base/jvm/src/io/ktor/server/test/base/EngineTestBaseJvm.kt +++ b/ktor-server/ktor-server-test-base/jvm/src/io/ktor/server/test/base/EngineTestBaseJvm.kt @@ -7,7 +7,6 @@ package io.ktor.server.test.base import io.ktor.client.* import io.ktor.client.engine.apache.* import io.ktor.client.engine.cio.* -import io.ktor.client.plugins.* import io.ktor.client.request.* import io.ktor.client.statement.* import io.ktor.http.* @@ -249,10 +248,10 @@ actual abstract class EngineTestBase< builder: suspend HttpRequestBuilder.() -> Unit, block: suspend HttpResponse.(Int) -> Unit ) { - withUrl("http://127.0.0.1:$port$path", port, builder, block) + withHttp1("http://127.0.0.1:$port$path", port, builder, block) if (enableSsl) { - withUrl("https://127.0.0.1:$sslPort$path", sslPort, builder, block) + withHttp1("https://127.0.0.1:$sslPort$path", sslPort, builder, block) } if (enableHttp2 && enableSsl) { @@ -270,7 +269,7 @@ actual abstract class EngineTestBase< } } - private suspend fun withUrl( + protected suspend fun withHttp1( urlString: String, port: Int, builder: suspend HttpRequestBuilder.() -> Unit, @@ -284,22 +283,13 @@ actual abstract class EngineTestBase< } } - private suspend fun withHttp2( + protected suspend fun withHttp2( url: String, port: Int, builder: suspend HttpRequestBuilder.() -> Unit, block: suspend HttpResponse.(Int) -> Unit ) { - HttpClient(Apache) { - followRedirects = false - expectSuccess = false - engine { - pipelining = true - sslContext = SSLContext.getInstance("SSL").apply { - init(null, trustAllCertificates, SecureRandom()) - } - } - }.use { client -> + createApacheClient().use { client -> client.prepareRequest(url) { builder() }.execute { response -> @@ -310,33 +300,48 @@ actual abstract class EngineTestBase< companion object { val keyStoreFile: File = File("build/temp.jks") - lateinit var keyStore: KeyStore - lateinit var sslContext: SSLContext - lateinit var trustManager: X509TrustManager + val keyStore: KeyStore by lazy { generateCertificate(keyStoreFile) } lateinit var client: HttpClient - @BeforeAll - @JvmStatic - fun setupAll() { - keyStore = generateCertificate(keyStoreFile, algorithm = "SHA256withECDSA", keySizeInBits = 256) + fun createTrustManager(): X509TrustManager { + val sslContext = SSLContext.getInstance("TLS") val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) tmf.init(keyStore) - sslContext = SSLContext.getInstance("TLS") sslContext.init(null, tmf.trustManagers, null) - trustManager = tmf.trustManagers.first { it is X509TrustManager } as X509TrustManager + return tmf.trustManagers.first { it is X509TrustManager } as X509TrustManager + } - client = HttpClient(CIO) { + fun createCIOClient(): HttpClient { + return HttpClient(CIO) { engine { - https.trustManager = trustManager + https.trustManager = createTrustManager() https.serverName = "localhost" requestTimeout = 0 } + followRedirects = false + expectSuccess = false + } + } + fun createApacheClient(): HttpClient { + return HttpClient(Apache) { followRedirects = false expectSuccess = false + engine { + pipelining = true + sslContext = SSLContext.getInstance("SSL").apply { + init(null, trustAllCertificates, SecureRandom()) + } + } } } + @BeforeAll + @JvmStatic + fun setupAll() { + client = createCIOClient() + } + @AfterAll @JvmStatic fun cleanup() { @@ -354,13 +359,14 @@ actual abstract class EngineTestBase< } } while (true) } - } - private val trustAllCertificates = arrayOf( - object : X509TrustManager { - override fun getAcceptedIssuers(): Array = emptyArray() - override fun checkClientTrusted(certs: Array, authType: String) {} - override fun checkServerTrusted(certs: Array, authType: String) {} - } - ) + private val trustAllCertificates = arrayOf( + @Suppress("CustomX509TrustManager") + object : X509TrustManager { + override fun getAcceptedIssuers(): Array = emptyArray() + override fun checkClientTrusted(certs: Array, authType: String) {} + override fun checkServerTrusted(certs: Array, authType: String) {} + } + ) + } } diff --git a/ktor-server/ktor-server-test-suites/jvm/src/io/ktor/server/testing/suites/ConfigTestSuite.kt b/ktor-server/ktor-server-test-suites/jvm/src/io/ktor/server/testing/suites/ConfigTestSuite.kt index 2befe58ea6b..8a807c016e3 100644 --- a/ktor-server/ktor-server-test-suites/jvm/src/io/ktor/server/testing/suites/ConfigTestSuite.kt +++ b/ktor-server/ktor-server-test-suites/jvm/src/io/ktor/server/testing/suites/ConfigTestSuite.kt @@ -10,8 +10,6 @@ import java.util.concurrent.* import kotlin.system.* import kotlin.test.* -var count = 0 - abstract class ConfigTestSuite(val engine: ApplicationEngineFactory<*, *>) { @Test @@ -58,4 +56,8 @@ abstract class ConfigTestSuite(val engine: ApplicationEngineFactory<*, *>) { assertTrue(time < 100, "Stop time is $time") } + + private companion object { + var count = 0 + } } diff --git a/ktor-server/ktor-server-test-suites/jvm/src/io/ktor/server/testing/suites/HttpRequestLifecycleTest.kt b/ktor-server/ktor-server-test-suites/jvm/src/io/ktor/server/testing/suites/HttpRequestLifecycleTest.kt new file mode 100644 index 00000000000..3d0de3b9de8 --- /dev/null +++ b/ktor-server/ktor-server-test-suites/jvm/src/io/ktor/server/testing/suites/HttpRequestLifecycleTest.kt @@ -0,0 +1,153 @@ +/* + * Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package io.ktor.server.testing.suites + +import io.ktor.client.statement.* +import io.ktor.http.* +import io.ktor.server.engine.* +import io.ktor.server.http.* +import io.ktor.server.response.* +import io.ktor.server.routing.* +import io.ktor.server.test.base.* +import io.ktor.util.* +import io.ktor.utils.io.* +import kotlinx.coroutines.* +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.channels.Channel +import kotlin.concurrent.atomics.AtomicInt +import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.concurrent.atomics.incrementAndFetch +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.Duration.Companion.seconds + +abstract class HttpRequestLifecycleTest( + val engine: ApplicationEngineFactory +) : EngineTestBase(engine) { + + private suspend fun cancellableRoute(handler: RoutingHandler) { + createAndStartServer { + install(plugin = HttpRequestLifecycle) { + cancelCallOnClose = true + } + get(handler) + } + } + + @Test + @OptIn(ExperimentalAtomicApi::class) + fun testClientDisconnectionCancelsRequest() = runTest { + val requestStartedCnt = AtomicInt(0) + val requestCancelledCnt = AtomicInt(0) + + val requestStarted = Channel(Channel.UNLIMITED) + val requestCancelled = Channel(Channel.UNLIMITED) + + cancellableRoute { + requestStarted.send(requestStartedCnt.incrementAndFetch()) + try { + // very long operation + repeat(100) { + call.coroutineContext.ensureActive() + delay(200.milliseconds) + } + } catch (err: CancellationException) { + @OptIn(InternalAPI::class) + assertTrue(err.rootCause is ConnectionClosedException) + requestCancelled.send(requestCancelledCnt.incrementAndFetch()) + } + } + + fun resetRequestOnStart(request: suspend () -> Unit) = launch { + client = createApacheClient() + client.use { + val requestJob = launch { + runCatching { request() } + } + withTimeout(10.seconds) { + requestStarted.receive() // Wait for the request to start processing on the server + } + // Cancel the request and close the client to force TCP to disconnect + requestJob.cancel() + } + } + + buildList { + resetRequestOnStart { + withHttp1("http://127.0.0.1:$port", port, {}, {}) + }.also { add(it) } + if (enableSsl) { + resetRequestOnStart { + withHttp1("https://127.0.0.1:$sslPort", sslPort, {}, {}) + }.also { add(it) } + } + if (enableSsl && enableHttp2) { + resetRequestOnStart { + withHttp2("https://127.0.0.1:$sslPort", sslPort, {}, {}) + }.also { add(it) } + } + }.joinAll() + + withTimeout(10.seconds) { + do { + // Wait for the request to be canceled + val cancelledCount = requestCancelled.receive() + } while (cancelledCount < requestStartedCnt.load()) + } + } + + @Test + fun testHttpRequestLifecycleSuccess() = runTest { + val requestCompleted = CompletableDeferred() + + cancellableRoute { + delay(100.milliseconds) + call.respondText("OK") + requestCompleted.complete(Unit) + } + + client = createApacheClient() + client.use { + withUrl("/") { + assertEquals(HttpStatusCode.OK, status) + assertEquals("OK", bodyAsText()) + } + } + + withTimeout(10.seconds) { + requestCompleted.await() + } + } + + @Test + fun testHttpRequestLifecycleWithStream() = runTest { + val requestCompleted = CompletableDeferred() + + cancellableRoute { + call.respondOutputStream { + repeat(3) { + write("OK;".toByteArray()) + delay(100.milliseconds) + } + requestCompleted.complete(Unit) + } + } + + client = createApacheClient() + client.use { + withUrl("/") { + assertEquals(HttpStatusCode.OK, status) + assertEquals(ContentType.Application.OctetStream, contentType()) + assertEquals("OK;OK;OK;", bodyAsText()) + } + } + + withTimeout(10.seconds) { + requestCompleted.await() + } + } +}