diff --git a/build.gradle b/build.gradle index f52304c..17337c8 100644 --- a/build.gradle +++ b/build.gradle @@ -28,7 +28,7 @@ repositories { group 'org.wiremock' allprojects { - version = "0.11.0" + version = "0.11.2" sourceCompatibility = 11 targetCompatibility = 11 diff --git a/src/main/java/org/wiremock/grpc/internal/GrpcClient.java b/src/main/java/org/wiremock/grpc/internal/GrpcClient.java index 8573b05..3298473 100644 --- a/src/main/java/org/wiremock/grpc/internal/GrpcClient.java +++ b/src/main/java/org/wiremock/grpc/internal/GrpcClient.java @@ -17,6 +17,7 @@ import static com.github.tomakehurst.wiremock.http.Response.response; +import com.github.tomakehurst.wiremock.common.Pair; import com.github.tomakehurst.wiremock.core.Options; import com.github.tomakehurst.wiremock.http.HttpHeader; import com.github.tomakehurst.wiremock.http.HttpHeaders; @@ -27,9 +28,13 @@ import com.google.protobuf.DynamicMessage; import io.grpc.*; import io.grpc.stub.ClientCalls; +import io.grpc.stub.MetadataUtils; + import java.io.IOException; +import java.net.URL; import java.util.ArrayList; import java.util.List; +import java.util.Objects; public class GrpcClient implements HttpClient { private final HttpClient delegateClient; @@ -54,10 +59,19 @@ public Response execute(Request request) throws IOException { GrpcContext context = BaseCallHandler.CONTEXT.get(); BaseCallHandler.CONTEXT.remove(); - Channel channel = - ManagedChannelBuilder.forAddress(request.getHost(), request.getPort()) - .usePlaintext() - .build(); + ManagedChannelBuilder managedChannelBuilder = ManagedChannelBuilder.forAddress(request.getHost(), request.getPort()); + if (request.getScheme().equals("https")) { + managedChannelBuilder.useTransportSecurity(); + } else { + managedChannelBuilder.usePlaintext(); + } + Metadata metadata = new Metadata(); + request.getHeaders().all().forEach(header -> + metadata.put(Metadata.Key.of(header.key(), Metadata.ASCII_STRING_MARSHALLER), header.firstValue()) + ); + ClientInterceptor clientInterceptor = MetadataUtils.newAttachHeadersInterceptor(metadata); + ManagedChannel channel = managedChannelBuilder.intercept(clientInterceptor).build(); + List headers = new ArrayList<>(); headers.add(new HttpHeader("Content-Type", "application/json")); Response.Builder grpcRespBuilder = response(); @@ -96,6 +110,12 @@ public Response execute(Request request) throws IOException { } headers.add(new HttpHeader(GrpcUtils.GRPC_STATUS_NAME, statusName)); headers.add(new HttpHeader(GrpcUtils.GRPC_STATUS_REASON, statusReason)); + } finally { + try { + channel.shutdown().awaitTermination(5, java.util.concurrent.TimeUnit.SECONDS); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + } } return grpcRespBuilder.headers(new HttpHeaders(headers.toArray(HttpHeader[]::new))).build(); diff --git a/src/test/java/org/wiremock/grpc/GrpcProxy2Test.java b/src/test/java/org/wiremock/grpc/GrpcProxy2Test.java new file mode 100644 index 0000000..e2e4cef --- /dev/null +++ b/src/test/java/org/wiremock/grpc/GrpcProxy2Test.java @@ -0,0 +1,70 @@ +package org.wiremock.grpc; + +import com.example.grpc.GreetingServiceGrpc; +import com.github.tomakehurst.wiremock.client.WireMock; +import com.github.tomakehurst.wiremock.junit5.WireMockExtension; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.wiremock.grpc.client.GreetingsClient; +import org.wiremock.grpc.dsl.WireMockGrpcService; +import org.wiremock.grpc.server.GreetingServer; + +import static com.github.tomakehurst.wiremock.client.WireMock.*; +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +public class GrpcProxy2Test { + + WireMockGrpcService mockGreetingService; + ManagedChannel channel; + GreetingsClient greetingsClient; + WireMock wireMock; + GreetingServer greetingServer; + + @RegisterExtension + public static WireMockExtension wm = + WireMockExtension.newInstance() + .options( + wireMockConfig() + .dynamicPort() + .withRootDirectory("src/test/resources/wiremock") + .extensions(new GrpcExtensionFactory())) + .build(); + + @BeforeEach + void init() { + wireMock = wm.getRuntimeInfo().getWireMock(); + mockGreetingService = new WireMockGrpcService(wireMock, GreetingServiceGrpc.SERVICE_NAME); + + channel = ManagedChannelBuilder.forAddress("localhost", wm.getPort()).usePlaintext().build(); + greetingsClient = new GreetingsClient(channel); + greetingServer = new GreetingServer(5088); + greetingServer.start(); + } + + @AfterEach + void tearDown() { + channel.shutdown(); + greetingServer.stop(); + } + + @Test + void withProxy() { + + wm.stubFor( + post(urlPathEqualTo("/com.example.grpc.GreetingService/greeting")) + .willReturn( + aResponse() + .proxiedFrom("http://localhost:5088"))); + + String greeting = greetingsClient.greet("Tommy"); + + assertThat(greeting, is("Hello from GRPC proxy, Tommy")); + } + +} diff --git a/src/testFixtures/java/org/wiremock/grpc/server/GreetingServer.java b/src/testFixtures/java/org/wiremock/grpc/server/GreetingServer.java new file mode 100644 index 0000000..c729191 --- /dev/null +++ b/src/testFixtures/java/org/wiremock/grpc/server/GreetingServer.java @@ -0,0 +1,53 @@ +package org.wiremock.grpc.server; + +import com.example.grpc.GreetingServiceGrpc; +import com.example.grpc.request.HelloRequest; +import com.example.grpc.response.HelloResponse; +import io.grpc.BindableService; +import io.grpc.Server; +import io.grpc.ServerBuilder; + +import java.io.IOException; + +public class GreetingServer { + + private final int port; + private Server server; + + public GreetingServer(int port) { + this.port = port; + } + + public void start() { + try { + server = ServerBuilder.forPort(port) + .addService(new GreetingServiceImpl()) + .build() + .start(); + System.out.println("gRPC server started on port " + port); + } catch (Exception e) { + e.printStackTrace(); + } + + } + + public void stop() { + if (server != null) { + server.shutdown(); + System.out.println("gRPC stopped"); + } + } + + static class GreetingServiceImpl extends GreetingServiceGrpc.GreetingServiceImplBase { + + @Override + public void greeting(HelloRequest request, io.grpc.stub.StreamObserver responseObserver) { + String name = request.getName(); + HelloResponse response = HelloResponse.newBuilder() + .setGreeting("Hello from GRPC proxy, " + name) + .build(); + responseObserver.onNext(response); + responseObserver.onCompleted(); + } + } +}