From 8ca9405630fdfd7fdf40458a2aca543353a4960c Mon Sep 17 00:00:00 2001 From: Matthew Pearson Date: Sun, 26 Oct 2025 08:44:09 +0000 Subject: [PATCH] Copy response http headers into grpc headers --- .../wiremock/grpc/internal/GrpcRequest.java | 2 +- .../HeaderCopyingServerInterceptor.java | 55 +++++++-- .../grpc/internal/UnaryServerCallHandler.java | 9 +- .../grpc/ResponseHeadersAcceptanceTest.java | 104 ++++++++++++++++++ 4 files changed, 157 insertions(+), 13 deletions(-) create mode 100644 src/test/java/org/wiremock/grpc/ResponseHeadersAcceptanceTest.java diff --git a/src/main/java/org/wiremock/grpc/internal/GrpcRequest.java b/src/main/java/org/wiremock/grpc/internal/GrpcRequest.java index a99e381..65596f4 100644 --- a/src/main/java/org/wiremock/grpc/internal/GrpcRequest.java +++ b/src/main/java/org/wiremock/grpc/internal/GrpcRequest.java @@ -97,7 +97,7 @@ public ContentTypeHeader contentTypeHeader() { @Override public HttpHeaders getHeaders() { - return HeaderCopyingServerInterceptor.HTTP_HEADERS_CONTEXT_KEY.get(); + return HeaderCopyingServerInterceptor.HTTP_REQUEST_HEADERS_CONTEXT_KEY.get(); } @Override diff --git a/src/main/java/org/wiremock/grpc/internal/HeaderCopyingServerInterceptor.java b/src/main/java/org/wiremock/grpc/internal/HeaderCopyingServerInterceptor.java index 6237faa..fbb22cd 100644 --- a/src/main/java/org/wiremock/grpc/internal/HeaderCopyingServerInterceptor.java +++ b/src/main/java/org/wiremock/grpc/internal/HeaderCopyingServerInterceptor.java @@ -17,28 +17,40 @@ import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; import static io.grpc.Metadata.BINARY_BYTE_MARSHALLER; +import static java.util.stream.Collectors.toList; import com.github.tomakehurst.wiremock.http.HttpHeader; import com.github.tomakehurst.wiremock.http.HttpHeaders; -import io.grpc.*; +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; + import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; +import java.util.concurrent.atomic.AtomicReference; public class HeaderCopyingServerInterceptor implements ServerInterceptor { - public static final Context.Key HTTP_HEADERS_CONTEXT_KEY = - Context.key("HTTP_HEADERS_CONTEXT_KEY"); + public static final Context.Key HTTP_REQUEST_HEADERS_CONTEXT_KEY = Context.key("HTTP_REQUEST_HEADERS_CONTEXT_KEY"); + + public static final Context.Key> HTTP_RESPONSE_HEADERS_CONTEXT_KEY = + Context.key("HTTP_RESPONSE_HEADERS_CONTEXT_KEY"); @Override public ServerCall.Listener interceptCall( ServerCall call, Metadata headers, ServerCallHandler next) { - final HttpHeaders httpHeaders = buildHttpHeaders(headers); - Context newContext = Context.current().withValue(HTTP_HEADERS_CONTEXT_KEY, httpHeaders); - return Contexts.interceptCall(newContext, call, headers, next); + final HttpHeaders httpHeaders = toHttpHeaders(headers); + Context newContext = Context.current().withValue(HTTP_REQUEST_HEADERS_CONTEXT_KEY, httpHeaders) + .withValue(HTTP_RESPONSE_HEADERS_CONTEXT_KEY, new AtomicReference<>()); + ServerCall responseHeadersHttpToGrpc = new HttpResponseHeadersToGrpcHeadersForwardingServerCall<>(call); + return Contexts.interceptCall(newContext, responseHeadersHttpToGrpc, headers, next); } - private static HttpHeaders buildHttpHeaders(Metadata metadata) { + private static HttpHeaders toHttpHeaders(Metadata metadata) { final List httpHeaderList = metadata.keys().stream() .map( @@ -54,8 +66,31 @@ private static HttpHeaders buildHttpHeaders(Metadata metadata) { return new HttpHeader( key, metadata.get(Metadata.Key.of(key, ASCII_STRING_MARSHALLER))); } - }) - .collect(Collectors.toList()); + }).collect(toList()); return new HttpHeaders(httpHeaderList); } + + private static Metadata fromHttpHeaders(HttpHeaders httpHeaders) { + Metadata metadata = new Metadata(); + httpHeaders.all().forEach(responseHttpHeader -> responseHttpHeader.values() + .forEach(v -> metadata.put(Metadata.Key.of(responseHttpHeader.key(), ASCII_STRING_MARSHALLER), v))); + return metadata; + } + + private static class HttpResponseHeadersToGrpcHeadersForwardingServerCall + extends SimpleForwardingServerCall { + + public HttpResponseHeadersToGrpcHeadersForwardingServerCall(ServerCall call) { + super(call); + } + + @Override + public void sendHeaders(Metadata headers) { + HttpHeaders responseHttpHeaders = HTTP_RESPONSE_HEADERS_CONTEXT_KEY.get().get(); + if (responseHttpHeaders != null) { + headers.merge(fromHttpHeaders(responseHttpHeaders)); + } + super.sendHeaders(headers); + } + } } diff --git a/src/main/java/org/wiremock/grpc/internal/UnaryServerCallHandler.java b/src/main/java/org/wiremock/grpc/internal/UnaryServerCallHandler.java index bcf0175..ac67de2 100644 --- a/src/main/java/org/wiremock/grpc/internal/UnaryServerCallHandler.java +++ b/src/main/java/org/wiremock/grpc/internal/UnaryServerCallHandler.java @@ -18,9 +18,11 @@ import static org.wiremock.grpc.dsl.GrpcResponseDefinitionBuilder.GRPC_STATUS_NAME; import static org.wiremock.grpc.dsl.GrpcResponseDefinitionBuilder.GRPC_STATUS_REASON; import static org.wiremock.grpc.internal.Delays.delayIfRequired; +import static org.wiremock.grpc.internal.HeaderCopyingServerInterceptor.HTTP_RESPONSE_HEADERS_CONTEXT_KEY; import com.github.tomakehurst.wiremock.common.Pair; import com.github.tomakehurst.wiremock.http.HttpHeader; +import com.github.tomakehurst.wiremock.http.HttpHeaders; import com.github.tomakehurst.wiremock.http.StubRequestHandler; import com.github.tomakehurst.wiremock.stubbing.ServeEvent; import com.google.protobuf.Descriptors; @@ -60,7 +62,10 @@ public void invoke(DynamicMessage request, StreamObserver respon stubRequestHandler.handle( wireMockRequest, (req, resp, attributes) -> { - final HttpHeader statusHeader = resp.getHeaders().getHeader(GRPC_STATUS_NAME); + HttpHeaders respHeaders = resp.getHeaders(); + HTTP_RESPONSE_HEADERS_CONTEXT_KEY.get().set(respHeaders); + + final HttpHeader statusHeader = respHeaders.getHeader(GRPC_STATUS_NAME); delayIfRequired(resp); @@ -75,7 +80,7 @@ public void invoke(DynamicMessage request, StreamObserver respon if (statusHeader.isPresent() && !statusHeader.firstValue().equals(Status.Code.OK.name())) { - final HttpHeader statusReasonHeader = resp.getHeaders().getHeader(GRPC_STATUS_REASON); + final HttpHeader statusReasonHeader = respHeaders.getHeader(GRPC_STATUS_REASON); final String reason = statusReasonHeader.isPresent() ? statusReasonHeader.firstValue() : ""; diff --git a/src/test/java/org/wiremock/grpc/ResponseHeadersAcceptanceTest.java b/src/test/java/org/wiremock/grpc/ResponseHeadersAcceptanceTest.java new file mode 100644 index 0000000..da12e45 --- /dev/null +++ b/src/test/java/org/wiremock/grpc/ResponseHeadersAcceptanceTest.java @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2023-2025 Thomas Akehurst + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wiremock.grpc; + +import static com.github.tomakehurst.wiremock.client.WireMock.okJson; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +import com.example.grpc.GreetingServiceGrpc; +import com.github.tomakehurst.wiremock.client.WireMock; +import com.github.tomakehurst.wiremock.junit5.WireMockExtension; +import com.google.common.collect.Lists; +import io.grpc.Channel; +import io.grpc.ClientInterceptor; +import io.grpc.ClientInterceptors; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; +import io.grpc.stub.MetadataUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.wiremock.grpc.client.GreetingsClient; +import org.wiremock.grpc.dsl.WireMockGrpcService; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +public class ResponseHeadersAcceptanceTest { + + public static final String X_MY_HEADER = "x-my-Header"; + WireMockGrpcService mockGreetingService; + ManagedChannel managedChannel; + Channel channel; + GreetingsClient greetingsClient; + WireMock wireMock; + + @RegisterExtension + public static WireMockExtension wm = + WireMockExtension.newInstance() + .options( + wireMockConfig() + // .dynamicPort() + .port(8282) + .withRootDirectory("src/test/resources/wiremock") + .extensions(new GrpcExtensionFactory())) + .build(); + + @BeforeEach + void init() { + wireMock = wm.getRuntimeInfo().getWireMock(); + mockGreetingService = new WireMockGrpcService(wireMock, GreetingServiceGrpc.SERVICE_NAME); + + managedChannel = + ManagedChannelBuilder.forAddress("localhost", wm.getPort()).usePlaintext().build(); + } + + @AfterEach + void tearDown() { + managedChannel.shutdown(); + } + + @Test + void httpResponseHeadersAreAddedToTheGrpcTrailers() { + AtomicReference headersCapture = new AtomicReference<>(); + AtomicReference trailersCapture = new AtomicReference<>(); + ClientInterceptor metadataInterceptor = MetadataUtils.newCaptureMetadataInterceptor(headersCapture, trailersCapture); + channel = ClientInterceptors.intercept(managedChannel, metadataInterceptor); + + greetingsClient = new GreetingsClient(channel); + wm.stubFor( + post(urlPathEqualTo("/com.example.grpc.GreetingService/greeting")) + .willReturn( + okJson("{\n" + " \"greeting\": \"Howdy!\"\n" + "}") + .withHeader(X_MY_HEADER, "first", "second", "third"))); + + String greeting = greetingsClient.greet("Whatever"); + + assertThat(greeting, is("Howdy!")); + + Metadata.Key key = Metadata.Key.of("x-my-header", Metadata.ASCII_STRING_MARSHALLER); + ArrayList values = Lists.newArrayList(headersCapture.get().getAll(key)); + assertThat(values, is(List.of("first", "second", "third"))); + } +}