diff --git a/src/main/java/org/wiremock/grpc/dsl/GrpcStubMappingBuilder.java b/src/main/java/org/wiremock/grpc/dsl/GrpcStubMappingBuilder.java index d6ed4da..d5b19d7 100644 --- a/src/main/java/org/wiremock/grpc/dsl/GrpcStubMappingBuilder.java +++ b/src/main/java/org/wiremock/grpc/dsl/GrpcStubMappingBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright (C) 2023 Thomas Akehurst + * Copyright (C) 2023-2025 Thomas Akehurst * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,13 @@ import com.github.tomakehurst.wiremock.http.Fault; import com.github.tomakehurst.wiremock.http.LogNormal; import com.github.tomakehurst.wiremock.http.UniformDistribution; +import com.github.tomakehurst.wiremock.matching.MultiValuePattern; import com.github.tomakehurst.wiremock.matching.StringValuePattern; import com.github.tomakehurst.wiremock.stubbing.StubMapping; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.wiremock.annotations.Beta; @Beta(justification = "Incubating extension: https://github.com/wiremock/wiremock/issues/2383") @@ -36,6 +39,8 @@ public class GrpcStubMappingBuilder { private GrpcResponseDefinitionBuilder responseBuilder; private List requestMessageJsonPatterns = new ArrayList<>(); + private Map stringValuePatternMap = new HashMap<>(); + private Map multiValuePatternMap = new HashMap<>(); public GrpcStubMappingBuilder(String method) { this.method = method; @@ -46,6 +51,16 @@ public GrpcStubMappingBuilder withRequestMessage(StringValuePattern requestMessa return this; } + public GrpcStubMappingBuilder withHeader(String header, StringValuePattern pattern) { + this.stringValuePatternMap.put(header, pattern); + return this; + } + + public GrpcStubMappingBuilder withHeader(String header, MultiValuePattern pattern) { + this.multiValuePatternMap.put(header, pattern); + return this; + } + public GrpcStubMappingBuilder willReturn(GrpcResponseDefinitionBuilder responseBuilder) { this.responseBuilder = responseBuilder; return this; @@ -83,6 +98,8 @@ public GrpcStubMappingBuilder withUniformRandomDelay( public StubMapping build(String serviceName) { final MappingBuilder mappingBuilder = WireMock.post(grpcUrlPath(serviceName, method)); requestMessageJsonPatterns.forEach(mappingBuilder::withRequestBody); + stringValuePatternMap.forEach((mappingBuilder::withHeader)); + multiValuePatternMap.forEach((mappingBuilder::withHeader)); return mappingBuilder.willReturn(responseBuilder.build()).build(); } } diff --git a/src/test/java/org/wiremock/grpc/GrpcAcceptanceTest.java b/src/test/java/org/wiremock/grpc/GrpcAcceptanceTest.java index f30b1c0..4679e64 100644 --- a/src/test/java/org/wiremock/grpc/GrpcAcceptanceTest.java +++ b/src/test/java/org/wiremock/grpc/GrpcAcceptanceTest.java @@ -1,5 +1,5 @@ /* - * Copyright (C) 2023-2024 Thomas Akehurst + * Copyright (C) 2023-2025 Thomas Akehurst * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,9 @@ package org.wiremock.grpc; import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; +import static com.github.tomakehurst.wiremock.client.WireMock.including; import static com.github.tomakehurst.wiremock.client.WireMock.moreThanOrExactly; import static com.github.tomakehurst.wiremock.client.WireMock.okJson; import static com.github.tomakehurst.wiremock.client.WireMock.post; @@ -34,6 +36,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.wiremock.grpc.dsl.WireMockGrpc.Status; +import static org.wiremock.grpc.dsl.WireMockGrpc.Status.UNIMPLEMENTED; import static org.wiremock.grpc.dsl.WireMockGrpc.equalToMessage; import static org.wiremock.grpc.dsl.WireMockGrpc.json; import static org.wiremock.grpc.dsl.WireMockGrpc.jsonTemplate; @@ -50,6 +53,8 @@ import com.github.tomakehurst.wiremock.junit5.WireMockExtension; import com.google.common.base.Stopwatch; import com.google.protobuf.Empty; +import io.grpc.Channel; +import io.grpc.ClientInterceptors; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.StatusRuntimeException; @@ -71,6 +76,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.wiremock.grpc.RequestHeadersAcceptanceTest.HeaderAdditionInterceptor; import org.wiremock.grpc.client.AnotherGreetingsClient; import org.wiremock.grpc.client.GreetingsClient; import org.wiremock.grpc.dsl.WireMockGrpcService; @@ -529,4 +535,71 @@ public void onCompleted() { serverReflectionResponses.get(0).getListServicesResponse().getServiceList(); assertThat(serviceList.size(), is(4)); } + + @Test + void supportsGrpcHeaderMatchingWithStringValueMatchingHeader() { + String headerName = "My-Fancy-Header"; + String headerValue = "My-Fancy-Value"; + + Channel interceptedChannel = + ClientInterceptors.intercept( + channel, new HeaderAdditionInterceptor(headerName, headerValue)); + + GreetingsClient interceptedGreetingsClient = new GreetingsClient(interceptedChannel); + + // Create a new stub + mockGreetingService.stubFor( + method("greeting") + .withHeader(headerName, equalTo(headerValue)) + .willReturn(message(HelloResponse.newBuilder().setGreeting("matched header")))); + + // Ensure the request was matched + assertThat(interceptedGreetingsClient.greet("Whatever"), is("matched header")); + } + + @Test + void supportsGrpcHeaderMatchingWithMultiValueMatchingHeader() { + String headerName = "My-Fancy-Header"; + String headerValue = "My-Fancy-Value"; + + Channel interceptedChannel = + ClientInterceptors.intercept( + channel, new HeaderAdditionInterceptor(headerName, headerValue)); + + GreetingsClient interceptedGreetingsClient = new GreetingsClient(interceptedChannel); + + // Create a new stub + mockGreetingService.stubFor( + method("greeting") + .withHeader(headerName, including(headerValue)) + .willReturn(message(HelloResponse.newBuilder().setGreeting("matched header")))); + + // Ensure the request was matched + assertThat(interceptedGreetingsClient.greet("Whatever"), is("matched header")); + } + + @Test + void supportsGrpcHeaderMatchingWithMismatchingHeader() { + String headerName = "My-Fancy-Header"; + String headerValue = "My-Fancy-Value"; + + Channel interceptedChannel = + ClientInterceptors.intercept( + channel, new HeaderAdditionInterceptor(headerName, headerValue)); + + GreetingsClient interceptedGreetingsClient = new GreetingsClient(interceptedChannel); + + // Create a new stub + mockGreetingService.stubFor( + method("greeting") + .withHeader(headerName, equalTo("it-will-never-match-this-header-value")) + .willReturn( + message(HelloResponse.newBuilder().setGreeting("this should not have matched")))); + + // It should throw exception as the request should not match the setup: + Exception exception = + assertThrows( + StatusRuntimeException.class, () -> interceptedGreetingsClient.greet("Whatever")); + assertThat(exception.getMessage(), startsWith(UNIMPLEMENTED.toString())); + } } diff --git a/src/test/java/org/wiremock/grpc/RequestHeadersAcceptanceTest.java b/src/test/java/org/wiremock/grpc/RequestHeadersAcceptanceTest.java index 2c07282..43a478b 100644 --- a/src/test/java/org/wiremock/grpc/RequestHeadersAcceptanceTest.java +++ b/src/test/java/org/wiremock/grpc/RequestHeadersAcceptanceTest.java @@ -69,11 +69,14 @@ void tearDown() { @Test void arbitraryRequestHeaderCanBeUsedWhenMatchingAndTemplating() { - channel = ClientInterceptors.intercept(managedChannel, new HeaderAdditionInterceptor()); + String headerValue = "match me"; + channel = + ClientInterceptors.intercept( + managedChannel, new HeaderAdditionInterceptor(X_MY_HEADER, headerValue)); greetingsClient = new GreetingsClient(channel); wm.stubFor( post(urlPathEqualTo("/com.example.grpc.GreetingService/greeting")) - .withHeader(X_MY_HEADER, equalTo("match me")) + .withHeader(X_MY_HEADER, equalTo(headerValue)) .willReturn( okJson( "{\n" @@ -83,7 +86,7 @@ void arbitraryRequestHeaderCanBeUsedWhenMatchingAndTemplating() { String greeting = greetingsClient.greet("Whatever"); - assertThat(greeting, is("The header value was: match me")); + assertThat(greeting, is("The header value was: " + headerValue)); } @Test @@ -104,8 +107,15 @@ void binaryRequestHeaderCanBeUsed() { public static class HeaderAdditionInterceptor implements ClientInterceptor { - static final Metadata.Key CUSTOM_HEADER_KEY = - Metadata.Key.of(X_MY_HEADER, Metadata.ASCII_STRING_MARSHALLER); + final String headerName; + final String headerValue; + final Metadata.Key grpcHeaderKey; + + public HeaderAdditionInterceptor(String headerName, String headerValue) { + this.headerName = headerName; + this.headerValue = headerValue; + this.grpcHeaderKey = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER); + } @Override public ClientCall interceptCall( @@ -115,7 +125,7 @@ public ClientCall interceptCall( @Override public void start(Listener responseListener, Metadata headers) { - headers.put(CUSTOM_HEADER_KEY, "match me"); + headers.put(grpcHeaderKey, headerValue); super.start(responseListener, headers); } };