diff --git a/ktor-server/ktor-server-plugins/ktor-server-cors/common/src/io/ktor/server/plugins/cors/CORS.kt b/ktor-server/ktor-server-plugins/ktor-server-cors/common/src/io/ktor/server/plugins/cors/CORS.kt index 355e1b17128..1c900063b74 100644 --- a/ktor-server/ktor-server-plugins/ktor-server-cors/common/src/io/ktor/server/plugins/cors/CORS.kt +++ b/ktor-server/ktor-server-plugins/ktor-server-cors/common/src/io/ktor/server/plugins/cors/CORS.kt @@ -54,7 +54,7 @@ internal fun PluginBuilder.buildPlugin() { val allowNonSimpleContentTypes: Boolean = pluginConfig.allowNonSimpleContentTypes val headersList = pluginConfig.headers.filterNot { it in CORSConfig.CorsSimpleRequestHeaders } .let { if (allowNonSimpleContentTypes) it + HttpHeaders.ContentType else it } - val methodsListHeaderValue = methods.filterNot { it in CORSConfig.CorsDefaultMethods } + val methodsListHeaderValue = methods.distinct() .map { it.value } .sorted() .joinToString(", ") diff --git a/ktor-server/ktor-server-tests/common/test/io/ktor/tests/server/plugins/CORSTest.kt b/ktor-server/ktor-server-tests/common/test/io/ktor/tests/server/plugins/CORSTest.kt index 98b15d52a33..f841034d766 100644 --- a/ktor-server/ktor-server-tests/common/test/io/ktor/tests/server/plugins/CORSTest.kt +++ b/ktor-server/ktor-server-tests/common/test/io/ktor/tests/server/plugins/CORSTest.kt @@ -1592,4 +1592,33 @@ class CORSTest { assertEquals(response.status, HttpStatusCode.Forbidden) } } + + @Test + fun testPreflightIncludesDefaultMethods() = testApplication { + install(CORS) { + anyHost() + allowMethod(HttpMethod.Put) + } + + routing { + get("/") { call.respond("OK") } + post("/") { call.respond("OK") } + put("/") { call.respond("OK") } + } + + val response = client.options("/") { + header(HttpHeaders.Origin, "http://my-host") + header(HttpHeaders.AccessControlRequestMethod, "PUT") + } + + assertEquals(HttpStatusCode.OK, response.status) + val allowMethods = response.headers[HttpHeaders.AccessControlAllowMethods]?.split(", ")?.toSet() + assertNotNull(allowMethods) + + assertTrue(HttpMethod.Get.value in allowMethods!!) + assertTrue(HttpMethod.Post.value in allowMethods) + assertTrue(HttpMethod.Head.value in allowMethods) + assertTrue(HttpMethod.Put.value in allowMethods) + } + }