From 2d0be5fcceae2b0036a1ba5ee667c7c3962fdbfe Mon Sep 17 00:00:00 2001 From: Brad Herrmann Date: Mon, 5 Apr 2021 14:53:40 -0500 Subject: [PATCH] implement grpc request scope --- .../grpc/demo/DemoAppConfiguration.java | 10 + .../springboot/grpc/demo/RandomUUID.java | 11 + .../grpc/demo/RequestScopedService.java | 52 ++ .../src/main/proto/request_scoped.proto | 11 + .../io/grpc/examples/RequestScopedGrpc.java | 259 ++++++++ .../examples/RequestScopedOuterClass.java | 583 ++++++++++++++++++ .../springboot/grpc/GrpcRequestScopeTest.java | 139 +++++ .../autoconfigure/GRpcAutoConfiguration.java | 11 +- .../autoconfigure/scope/GRpcRequestScope.java | 264 ++++++++ 9 files changed, 1339 insertions(+), 1 deletion(-) create mode 100644 grpc-spring-boot-starter-demo/src/main/java/org/lognet/springboot/grpc/demo/RandomUUID.java create mode 100644 grpc-spring-boot-starter-demo/src/main/java/org/lognet/springboot/grpc/demo/RequestScopedService.java create mode 100644 grpc-spring-boot-starter-demo/src/main/proto/request_scoped.proto create mode 100644 grpc-spring-boot-starter-demo/src/main/protoGen/io/grpc/examples/RequestScopedGrpc.java create mode 100644 grpc-spring-boot-starter-demo/src/main/protoGen/io/grpc/examples/RequestScopedOuterClass.java create mode 100644 grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/GrpcRequestScopeTest.java create mode 100644 grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/autoconfigure/scope/GRpcRequestScope.java diff --git a/grpc-spring-boot-starter-demo/src/main/java/org/lognet/springboot/grpc/demo/DemoAppConfiguration.java b/grpc-spring-boot-starter-demo/src/main/java/org/lognet/springboot/grpc/demo/DemoAppConfiguration.java index 42da990b..3c99e555 100644 --- a/grpc-spring-boot-starter-demo/src/main/java/org/lognet/springboot/grpc/demo/DemoAppConfiguration.java +++ b/grpc-spring-boot-starter-demo/src/main/java/org/lognet/springboot/grpc/demo/DemoAppConfiguration.java @@ -5,7 +5,11 @@ import io.grpc.examples.SecuredCalculatorGrpc; import io.grpc.stub.StreamObserver; import org.lognet.springboot.grpc.GRpcService; +import org.lognet.springboot.grpc.autoconfigure.scope.GRpcRequestScope; +import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Scope; +import org.springframework.context.annotation.ScopedProxyMode; import org.springframework.security.access.annotation.Secured; @Configuration @@ -58,4 +62,10 @@ public void calculate(CalculatorOuterClass.CalculatorRequest request, StreamObse } + + @Bean + @Scope(scopeName = GRpcRequestScope.GRPC_REQUEST_SCOPE_NAME, proxyMode = ScopedProxyMode.TARGET_CLASS) + public RandomUUID testRequestScopeBean() { + return new RandomUUID(); + } } diff --git a/grpc-spring-boot-starter-demo/src/main/java/org/lognet/springboot/grpc/demo/RandomUUID.java b/grpc-spring-boot-starter-demo/src/main/java/org/lognet/springboot/grpc/demo/RandomUUID.java new file mode 100644 index 00000000..5c7970f6 --- /dev/null +++ b/grpc-spring-boot-starter-demo/src/main/java/org/lognet/springboot/grpc/demo/RandomUUID.java @@ -0,0 +1,11 @@ +package org.lognet.springboot.grpc.demo; + +import java.util.UUID; + +public class RandomUUID { + private final String id = UUID.randomUUID().toString(); + + public String getId() { + return this.id; + } +} diff --git a/grpc-spring-boot-starter-demo/src/main/java/org/lognet/springboot/grpc/demo/RequestScopedService.java b/grpc-spring-boot-starter-demo/src/main/java/org/lognet/springboot/grpc/demo/RequestScopedService.java new file mode 100644 index 00000000..bc5751c2 --- /dev/null +++ b/grpc-spring-boot-starter-demo/src/main/java/org/lognet/springboot/grpc/demo/RequestScopedService.java @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2016-2021 Michael Zhang + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package org.lognet.springboot.grpc.demo; + +import io.grpc.examples.RequestScopedGrpc; +import io.grpc.examples.RequestScopedOuterClass; +import io.grpc.stub.StreamObserver; +import org.lognet.springboot.grpc.GRpcService; +import org.springframework.beans.factory.annotation.Autowired; + +import java.util.UUID; + +@GRpcService +public class RequestScopedService extends RequestScopedGrpc.RequestScopedImplBase { + @Autowired + private RandomUUID testRequestScopeBean; + + @Override + public StreamObserver requestScoped(StreamObserver observer) { + return new StreamObserver() { + @Override + public void onNext(RequestScopedOuterClass.RequestScopedMessage value) { + observer.onNext(value.toBuilder().setStr(RequestScopedService.this.testRequestScopeBean.getId()).build()); + } + + @Override + public void onError(Throwable t) { + observer.onError(t); + } + + @Override + public void onCompleted() { + observer.onCompleted(); + } + }; + } +} diff --git a/grpc-spring-boot-starter-demo/src/main/proto/request_scoped.proto b/grpc-spring-boot-starter-demo/src/main/proto/request_scoped.proto new file mode 100644 index 00000000..f439f22f --- /dev/null +++ b/grpc-spring-boot-starter-demo/src/main/proto/request_scoped.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +option java_package = "io.grpc.examples"; + +service RequestScoped { + rpc RequestScoped(stream RequestScopedMessage) returns (stream RequestScopedMessage) {} +} + +message RequestScopedMessage { + string str = 1; +} \ No newline at end of file diff --git a/grpc-spring-boot-starter-demo/src/main/protoGen/io/grpc/examples/RequestScopedGrpc.java b/grpc-spring-boot-starter-demo/src/main/protoGen/io/grpc/examples/RequestScopedGrpc.java new file mode 100644 index 00000000..fcf7e1bf --- /dev/null +++ b/grpc-spring-boot-starter-demo/src/main/protoGen/io/grpc/examples/RequestScopedGrpc.java @@ -0,0 +1,259 @@ +package io.grpc.examples; + +import static io.grpc.MethodDescriptor.generateFullMethodName; + +/** + */ +@javax.annotation.Generated( + value = "by gRPC proto compiler (version 1.36.0)", + comments = "Source: request_scoped.proto") +public final class RequestScopedGrpc { + + private RequestScopedGrpc() {} + + public static final String SERVICE_NAME = "RequestScoped"; + + // Static method descriptors that strictly reflect the proto. + private static volatile io.grpc.MethodDescriptor getRequestScopedMethod; + + @io.grpc.stub.annotations.RpcMethod( + fullMethodName = SERVICE_NAME + '/' + "RequestScoped", + requestType = io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage.class, + responseType = io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage.class, + methodType = io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING) + public static io.grpc.MethodDescriptor getRequestScopedMethod() { + io.grpc.MethodDescriptor getRequestScopedMethod; + if ((getRequestScopedMethod = RequestScopedGrpc.getRequestScopedMethod) == null) { + synchronized (RequestScopedGrpc.class) { + if ((getRequestScopedMethod = RequestScopedGrpc.getRequestScopedMethod) == null) { + RequestScopedGrpc.getRequestScopedMethod = getRequestScopedMethod = + io.grpc.MethodDescriptor.newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.BIDI_STREAMING) + .setFullMethodName(generateFullMethodName(SERVICE_NAME, "RequestScoped")) + .setSampledToLocalTracing(true) + .setRequestMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage.getDefaultInstance())) + .setResponseMarshaller(io.grpc.protobuf.ProtoUtils.marshaller( + io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage.getDefaultInstance())) + .setSchemaDescriptor(new RequestScopedMethodDescriptorSupplier("RequestScoped")) + .build(); + } + } + } + return getRequestScopedMethod; + } + + /** + * Creates a new async stub that supports all call types for the service + */ + public static RequestScopedStub newStub(io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public RequestScopedStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RequestScopedStub(channel, callOptions); + } + }; + return RequestScopedStub.newStub(factory, channel); + } + + /** + * Creates a new blocking-style stub that supports unary and streaming output calls on the service + */ + public static RequestScopedBlockingStub newBlockingStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public RequestScopedBlockingStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RequestScopedBlockingStub(channel, callOptions); + } + }; + return RequestScopedBlockingStub.newStub(factory, channel); + } + + /** + * Creates a new ListenableFuture-style stub that supports unary calls on the service + */ + public static RequestScopedFutureStub newFutureStub( + io.grpc.Channel channel) { + io.grpc.stub.AbstractStub.StubFactory factory = + new io.grpc.stub.AbstractStub.StubFactory() { + @java.lang.Override + public RequestScopedFutureStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RequestScopedFutureStub(channel, callOptions); + } + }; + return RequestScopedFutureStub.newStub(factory, channel); + } + + /** + */ + public static abstract class RequestScopedImplBase implements io.grpc.BindableService { + + /** + */ + public io.grpc.stub.StreamObserver requestScoped( + io.grpc.stub.StreamObserver responseObserver) { + return io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall(getRequestScopedMethod(), responseObserver); + } + + @java.lang.Override public final io.grpc.ServerServiceDefinition bindService() { + return io.grpc.ServerServiceDefinition.builder(getServiceDescriptor()) + .addMethod( + getRequestScopedMethod(), + io.grpc.stub.ServerCalls.asyncBidiStreamingCall( + new MethodHandlers< + io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage, + io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage>( + this, METHODID_REQUEST_SCOPED))) + .build(); + } + } + + /** + */ + public static final class RequestScopedStub extends io.grpc.stub.AbstractAsyncStub { + private RequestScopedStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected RequestScopedStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RequestScopedStub(channel, callOptions); + } + + /** + */ + public io.grpc.stub.StreamObserver requestScoped( + io.grpc.stub.StreamObserver responseObserver) { + return io.grpc.stub.ClientCalls.asyncBidiStreamingCall( + getChannel().newCall(getRequestScopedMethod(), getCallOptions()), responseObserver); + } + } + + /** + */ + public static final class RequestScopedBlockingStub extends io.grpc.stub.AbstractBlockingStub { + private RequestScopedBlockingStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected RequestScopedBlockingStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RequestScopedBlockingStub(channel, callOptions); + } + } + + /** + */ + public static final class RequestScopedFutureStub extends io.grpc.stub.AbstractFutureStub { + private RequestScopedFutureStub( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + super(channel, callOptions); + } + + @java.lang.Override + protected RequestScopedFutureStub build( + io.grpc.Channel channel, io.grpc.CallOptions callOptions) { + return new RequestScopedFutureStub(channel, callOptions); + } + } + + private static final int METHODID_REQUEST_SCOPED = 0; + + private static final class MethodHandlers implements + io.grpc.stub.ServerCalls.UnaryMethod, + io.grpc.stub.ServerCalls.ServerStreamingMethod, + io.grpc.stub.ServerCalls.ClientStreamingMethod, + io.grpc.stub.ServerCalls.BidiStreamingMethod { + private final RequestScopedImplBase serviceImpl; + private final int methodId; + + MethodHandlers(RequestScopedImplBase serviceImpl, int methodId) { + this.serviceImpl = serviceImpl; + this.methodId = methodId; + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + default: + throw new AssertionError(); + } + } + + @java.lang.Override + @java.lang.SuppressWarnings("unchecked") + public io.grpc.stub.StreamObserver invoke( + io.grpc.stub.StreamObserver responseObserver) { + switch (methodId) { + case METHODID_REQUEST_SCOPED: + return (io.grpc.stub.StreamObserver) serviceImpl.requestScoped( + (io.grpc.stub.StreamObserver) responseObserver); + default: + throw new AssertionError(); + } + } + } + + private static abstract class RequestScopedBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoFileDescriptorSupplier, io.grpc.protobuf.ProtoServiceDescriptorSupplier { + RequestScopedBaseDescriptorSupplier() {} + + @java.lang.Override + public com.google.protobuf.Descriptors.FileDescriptor getFileDescriptor() { + return io.grpc.examples.RequestScopedOuterClass.getDescriptor(); + } + + @java.lang.Override + public com.google.protobuf.Descriptors.ServiceDescriptor getServiceDescriptor() { + return getFileDescriptor().findServiceByName("RequestScoped"); + } + } + + private static final class RequestScopedFileDescriptorSupplier + extends RequestScopedBaseDescriptorSupplier { + RequestScopedFileDescriptorSupplier() {} + } + + private static final class RequestScopedMethodDescriptorSupplier + extends RequestScopedBaseDescriptorSupplier + implements io.grpc.protobuf.ProtoMethodDescriptorSupplier { + private final String methodName; + + RequestScopedMethodDescriptorSupplier(String methodName) { + this.methodName = methodName; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.MethodDescriptor getMethodDescriptor() { + return getServiceDescriptor().findMethodByName(methodName); + } + } + + private static volatile io.grpc.ServiceDescriptor serviceDescriptor; + + public static io.grpc.ServiceDescriptor getServiceDescriptor() { + io.grpc.ServiceDescriptor result = serviceDescriptor; + if (result == null) { + synchronized (RequestScopedGrpc.class) { + result = serviceDescriptor; + if (result == null) { + serviceDescriptor = result = io.grpc.ServiceDescriptor.newBuilder(SERVICE_NAME) + .setSchemaDescriptor(new RequestScopedFileDescriptorSupplier()) + .addMethod(getRequestScopedMethod()) + .build(); + } + } + } + return result; + } +} diff --git a/grpc-spring-boot-starter-demo/src/main/protoGen/io/grpc/examples/RequestScopedOuterClass.java b/grpc-spring-boot-starter-demo/src/main/protoGen/io/grpc/examples/RequestScopedOuterClass.java new file mode 100644 index 00000000..61717f56 --- /dev/null +++ b/grpc-spring-boot-starter-demo/src/main/protoGen/io/grpc/examples/RequestScopedOuterClass.java @@ -0,0 +1,583 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: request_scoped.proto + +package io.grpc.examples; + +public final class RequestScopedOuterClass { + private RequestScopedOuterClass() {} + public static void registerAllExtensions( + com.google.protobuf.ExtensionRegistryLite registry) { + } + + public static void registerAllExtensions( + com.google.protobuf.ExtensionRegistry registry) { + registerAllExtensions( + (com.google.protobuf.ExtensionRegistryLite) registry); + } + public interface RequestScopedMessageOrBuilder extends + // @@protoc_insertion_point(interface_extends:RequestScopedMessage) + com.google.protobuf.MessageOrBuilder { + + /** + * string str = 1; + */ + java.lang.String getStr(); + /** + * string str = 1; + */ + com.google.protobuf.ByteString + getStrBytes(); + } + /** + * Protobuf type {@code RequestScopedMessage} + */ + public static final class RequestScopedMessage extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:RequestScopedMessage) + RequestScopedMessageOrBuilder { + private static final long serialVersionUID = 0L; + // Use RequestScopedMessage.newBuilder() to construct. + private RequestScopedMessage(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private RequestScopedMessage() { + str_ = ""; + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + private RequestScopedMessage( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + this(); + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + int mutable_bitField0_ = 0; + com.google.protobuf.UnknownFieldSet.Builder unknownFields = + com.google.protobuf.UnknownFieldSet.newBuilder(); + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!parseUnknownFieldProto3( + input, unknownFields, extensionRegistry, tag)) { + done = true; + } + break; + } + case 10: { + java.lang.String s = input.readStringRequireUtf8(); + + str_ = s; + break; + } + } + } + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(this); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException( + e).setUnfinishedMessage(this); + } finally { + this.unknownFields = unknownFields.build(); + makeExtensionsImmutable(); + } + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return io.grpc.examples.RequestScopedOuterClass.internal_static_RequestScopedMessage_descriptor; + } + + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return io.grpc.examples.RequestScopedOuterClass.internal_static_RequestScopedMessage_fieldAccessorTable + .ensureFieldAccessorsInitialized( + io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage.class, io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage.Builder.class); + } + + public static final int STR_FIELD_NUMBER = 1; + private volatile java.lang.Object str_; + /** + * string str = 1; + */ + public java.lang.String getStr() { + java.lang.Object ref = str_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + str_ = s; + return s; + } + } + /** + * string str = 1; + */ + public com.google.protobuf.ByteString + getStrBytes() { + java.lang.Object ref = str_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + str_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + private byte memoizedIsInitialized = -1; + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (!getStrBytes().isEmpty()) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, str_); + } + unknownFields.writeTo(output); + } + + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (!getStrBytes().isEmpty()) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, str_); + } + size += unknownFields.getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage)) { + return super.equals(obj); + } + io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage other = (io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage) obj; + + boolean result = true; + result = result && getStr() + .equals(other.getStr()); + result = result && unknownFields.equals(other.unknownFields); + return result; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + STR_FIELD_NUMBER; + hash = (53 * hash) + getStr().hashCode(); + hash = (29 * hash) + unknownFields.hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code RequestScopedMessage} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:RequestScopedMessage) + io.grpc.examples.RequestScopedOuterClass.RequestScopedMessageOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return io.grpc.examples.RequestScopedOuterClass.internal_static_RequestScopedMessage_descriptor; + } + + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return io.grpc.examples.RequestScopedOuterClass.internal_static_RequestScopedMessage_fieldAccessorTable + .ensureFieldAccessorsInitialized( + io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage.class, io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage.Builder.class); + } + + // Construct using io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + } + } + public Builder clear() { + super.clear(); + str_ = ""; + + return this; + } + + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return io.grpc.examples.RequestScopedOuterClass.internal_static_RequestScopedMessage_descriptor; + } + + public io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage getDefaultInstanceForType() { + return io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage.getDefaultInstance(); + } + + public io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage build() { + io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + public io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage buildPartial() { + io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage result = new io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage(this); + result.str_ = str_; + onBuilt(); + return result; + } + + public Builder clone() { + return (Builder) super.clone(); + } + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.setField(field, value); + } + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return (Builder) super.clearField(field); + } + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return (Builder) super.clearOneof(oneof); + } + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return (Builder) super.setRepeatedField(field, index, value); + } + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return (Builder) super.addRepeatedField(field, value); + } + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage) { + return mergeFrom((io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage other) { + if (other == io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage.getDefaultInstance()) return this; + if (!other.getStr().isEmpty()) { + str_ = other.str_; + onChanged(); + } + this.mergeUnknownFields(other.unknownFields); + onChanged(); + return this; + } + + public final boolean isInitialized() { + return true; + } + + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage parsedMessage = null; + try { + parsedMessage = PARSER.parsePartialFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + parsedMessage = (io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage) e.getUnfinishedMessage(); + throw e.unwrapIOException(); + } finally { + if (parsedMessage != null) { + mergeFrom(parsedMessage); + } + } + return this; + } + + private java.lang.Object str_ = ""; + /** + * string str = 1; + */ + public java.lang.String getStr() { + java.lang.Object ref = str_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + str_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + * string str = 1; + */ + public com.google.protobuf.ByteString + getStrBytes() { + java.lang.Object ref = str_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + str_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * string str = 1; + */ + public Builder setStr( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + + str_ = value; + onChanged(); + return this; + } + /** + * string str = 1; + */ + public Builder clearStr() { + + str_ = getDefaultInstance().getStr(); + onChanged(); + return this; + } + /** + * string str = 1; + */ + public Builder setStrBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + str_ = value; + onChanged(); + return this; + } + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFieldsProto3(unknownFields); + } + + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:RequestScopedMessage) + } + + // @@protoc_insertion_point(class_scope:RequestScopedMessage) + private static final io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage(); + } + + public static io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + public RequestScopedMessage parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return new RequestScopedMessage(input, extensionRegistry); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + public io.grpc.examples.RequestScopedOuterClass.RequestScopedMessage getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_RequestScopedMessage_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_RequestScopedMessage_fieldAccessorTable; + + public static com.google.protobuf.Descriptors.FileDescriptor + getDescriptor() { + return descriptor; + } + private static com.google.protobuf.Descriptors.FileDescriptor + descriptor; + static { + java.lang.String[] descriptorData = { + "\n\024request_scoped.proto\"#\n\024RequestScopedM" + + "essage\022\013\n\003str\030\001 \001(\t2T\n\rRequestScoped\022C\n\r" + + "RequestScoped\022\025.RequestScopedMessage\032\025.R" + + "equestScopedMessage\"\000(\0010\001B\022\n\020io.grpc.exa" + + "mplesb\006proto3" + }; + com.google.protobuf.Descriptors.FileDescriptor.InternalDescriptorAssigner assigner = + new com.google.protobuf.Descriptors.FileDescriptor. InternalDescriptorAssigner() { + public com.google.protobuf.ExtensionRegistry assignDescriptors( + com.google.protobuf.Descriptors.FileDescriptor root) { + descriptor = root; + return null; + } + }; + com.google.protobuf.Descriptors.FileDescriptor + .internalBuildGeneratedFileFrom(descriptorData, + new com.google.protobuf.Descriptors.FileDescriptor[] { + }, assigner); + internal_static_RequestScopedMessage_descriptor = + getDescriptor().getMessageTypes().get(0); + internal_static_RequestScopedMessage_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_RequestScopedMessage_descriptor, + new java.lang.String[] { "Str", }); + } + + // @@protoc_insertion_point(outer_class_scope) +} diff --git a/grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/GrpcRequestScopeTest.java b/grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/GrpcRequestScopeTest.java new file mode 100644 index 00000000..d48d12bd --- /dev/null +++ b/grpc-spring-boot-starter-demo/src/test/java/org/lognet/springboot/grpc/GrpcRequestScopeTest.java @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2016-2021 Michael Zhang + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package org.lognet.springboot.grpc; + +import io.grpc.examples.RequestScopedGrpc; +import io.grpc.examples.RequestScopedOuterClass; +import io.grpc.stub.StreamObserver; +import junit.framework.AssertionFailedError; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.lognet.springboot.grpc.demo.DemoApp; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.junit4.SpringRunner; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.NONE; + +@RunWith(SpringRunner.class) +@SpringBootTest(classes = {DemoApp.class}, webEnvironment = NONE) +@DirtiesContext +public class GrpcRequestScopeTest extends GrpcServerTestBase { + + @Test + @DirtiesContext + public void testScope() throws InterruptedException { + // Prepare + RequestScopedGrpc.RequestScopedStub requestScopedServiceStub = RequestScopedGrpc.newStub(channel); + ScopedStreamObserverChecker scope1 = new ScopedStreamObserverChecker(); + StreamObserver request1 = requestScopedServiceStub.requestScoped(scope1); + ScopedStreamObserverChecker scope2 = new ScopedStreamObserverChecker(); + StreamObserver request2 = requestScopedServiceStub.requestScoped(scope2); + + // Run + request1.onNext(RequestScopedOuterClass.RequestScopedMessage.getDefaultInstance()); + request1.onNext(RequestScopedOuterClass.RequestScopedMessage.getDefaultInstance()); + Thread.sleep(150); + + request2.onNext(RequestScopedOuterClass.RequestScopedMessage.getDefaultInstance()); + request2.onNext(RequestScopedOuterClass.RequestScopedMessage.getDefaultInstance()); + Thread.sleep(150); + + request1.onNext(RequestScopedOuterClass.RequestScopedMessage.getDefaultInstance()); + request2.onNext(RequestScopedOuterClass.RequestScopedMessage.getDefaultInstance()); + Thread.sleep(150); + + request2.onNext(RequestScopedOuterClass.RequestScopedMessage.getDefaultInstance()); + request1.onNext(RequestScopedOuterClass.RequestScopedMessage.getDefaultInstance()); + Thread.sleep(150); + + request1.onCompleted(); + request2.onCompleted(); + Thread.sleep(150); + + // Assert + assertTrue(scope1.isCompleted()); + assertTrue(scope2.isCompleted()); + assertNull(scope1.getError()); + assertNull(scope2.getError()); + assertNotNull(scope1.getText()); + assertNotNull(scope2.getText()); + assertNotEquals(scope1.getText(), scope2.getText()); + } + + + /** + * Helper class used to check that the scoped responses are different per request, but the same for different + * messages in the same request. + */ + private static class ScopedStreamObserverChecker implements StreamObserver { + + private String text; + private boolean completed = false; + private Throwable error; + + @Override + public void onNext(RequestScopedOuterClass.RequestScopedMessage value) { + if (this.text == null) { + this.text = value.getStr(); + } + try { + assertEquals(this.text, value.getStr()); + } catch (AssertionFailedError e) { + if (this.error == null) { + this.error = e; + } else { + this.error.addSuppressed(e); + } + } + } + + @Override + public void onError(Throwable t) { + if (this.error == null) { + this.error = t; + } else { + this.error.addSuppressed(t); + } + this.completed = true; + } + + @Override + public void onCompleted() { + this.completed = true; + } + + public String getText() { + return this.text; + } + + public boolean isCompleted() { + return this.completed; + } + + public Throwable getError() { + return this.error; + } + + } +} diff --git a/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/autoconfigure/GRpcAutoConfiguration.java b/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/autoconfigure/GRpcAutoConfiguration.java index c34e1d43..fc34e0d4 100644 --- a/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/autoconfigure/GRpcAutoConfiguration.java +++ b/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/autoconfigure/GRpcAutoConfiguration.java @@ -8,6 +8,7 @@ import org.lognet.springboot.grpc.GRpcServerBuilderConfigurer; import org.lognet.springboot.grpc.GRpcServerRunner; import org.lognet.springboot.grpc.GRpcService; +import org.lognet.springboot.grpc.autoconfigure.scope.GRpcRequestScope; import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; @@ -174,7 +175,15 @@ public InetSocketAddress convert(String source) { }; } - + /** + * A scope that is valid for the duration of a grpc request. + * + * @return The grpc request scope bean. + */ + @Bean + public static GRpcRequestScope grpcRequestScope() { + return new GRpcRequestScope(); + } } diff --git a/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/autoconfigure/scope/GRpcRequestScope.java b/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/autoconfigure/scope/GRpcRequestScope.java new file mode 100644 index 00000000..a002e2ba --- /dev/null +++ b/grpc-spring-boot-starter/src/main/java/org/lognet/springboot/grpc/autoconfigure/scope/GRpcRequestScope.java @@ -0,0 +1,264 @@ +/* + * Copyright (c) 2016-2021 Michael Zhang + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package org.lognet.springboot.grpc.autoconfigure.scope; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.lognet.springboot.grpc.GRpcGlobalInterceptor; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.ObjectFactory; +import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.config.Scope; +import org.springframework.context.annotation.Bean; +import org.springframework.core.annotation.Order; +import org.springframework.core.Ordered; + +import com.google.common.util.concurrent.MoreExecutors; + +import io.grpc.Context; +import io.grpc.Context.CancellationListener; +import io.grpc.Contexts; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; + +/** + * The scope for beans that have their lifecycle bound to the grpc {@link Context}. + * + *

+ * Note: If you write the {@link Bean @Bean} definition of this class, you must use the {@code static} keyword. + *

+ * + * @author Daniel Theuke (daniel.theuke@heuboe.de) + */ +@GRpcGlobalInterceptor +@Order(Ordered.HIGHEST_PRECEDENCE) +public class GRpcRequestScope implements Scope, BeanFactoryPostProcessor, ServerInterceptor, CancellationListener { + public static final String GRPC_REQUEST_SCOPE_NAME = "grpcRequest"; + private static final String GRPC_REQUEST_SCOPE_ID = "grpc-request"; + private static final Context.Key GRPC_REQUEST_KEY = Context.key("grpcRequestScope"); + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory factory) throws BeansException { + factory.registerScope(GRPC_REQUEST_SCOPE_NAME, this); + } + + @Override + public Listener interceptCall(ServerCall call, Metadata headers, + ServerCallHandler next) { + ScopedBeansContainer container = new ScopedBeansContainer(); + Context context = Context.current().withValue(GRPC_REQUEST_KEY, container); + context.addListener(this, MoreExecutors.directExecutor()); + return Contexts.interceptCall(context, call, headers, next); + } + + @Override + public Object get(String name, ObjectFactory objectFactory) { + return getCurrentScopeContainer().getOrCreate(name, objectFactory); + } + + @Override + public Object remove(String name) { + return getCurrentScopeContainer().remove(name); + } + + @Override + public void registerDestructionCallback(String name, Runnable callback) { + getCurrentScopeContainer().registerDestructionCallback(name, callback); + } + + @Override + public Object resolveContextualObject(String key) { + return null; + } + + @Override + public String getConversationId() { + return GRPC_REQUEST_SCOPE_ID; + } + + @Override + public void cancelled(Context context) { + final ScopedBeansContainer container = GRPC_REQUEST_KEY.get(context); + if (container != null) { + container.destroy(); + } + } + + /** + * Gets the current container for the grpc request scope. + * + * @return The currently active scope container. + * @throws IllegalStateException If the grpc request scope is currently not active. + */ + private ScopedBeansContainer getCurrentScopeContainer() { + ScopedBeansContainer scopedBeansContainer = GRPC_REQUEST_KEY.get(); + if (scopedBeansContainer == null) { + throw new IllegalStateException( + "Trying to access grpcRequest-Scope, but it was not started for this thread."); + } + return scopedBeansContainer; + } + + /** + * Container for all beans used in the active scope. + */ + private static class ScopedBeansContainer { + + private final Map references = new ConcurrentHashMap<>(); + + /** + * Gets or creates the bean with the given name using the given object factory. + * + * @param name The name of the bean. + * @param objectFactory The object factory used to create new instances. + * @return The bean associated with the given name. + */ + public Object getOrCreate(final String name, final ObjectFactory objectFactory) { + return this.references.computeIfAbsent(name, key -> new ScopedBeanReference(objectFactory)) + .getBean(); + } + + /** + * Removes the bean with the given name from this scope. + * + * @param name The name of the bean to remove. + * @return The bean instances that was removed from the scope or null, if it wasn't present. + */ + public Object remove(final String name) { + final ScopedBeanReference ref = this.references.remove(name); + if (ref == null) { + return null; + } else { + return ref.getBeanIfExists(); + } + } + + /** + * Attaches a destruction callback to the bean with the given name. + * + * @param name The name of the bean to attach the destruction callback to. + * @param callback The callback to register for the bean. + */ + public void registerDestructionCallback(final String name, final Runnable callback) { + final ScopedBeanReference ref = this.references.get(name); + if (ref != null) { + ref.setDestructionCallback(callback); + } + } + + /** + * Destroys all beans in the scope and executes their destruction callbacks. + */ + public void destroy() { + final List errors = new ArrayList<>(); + final Iterator it = this.references.values().iterator(); + while (it.hasNext()) { + ScopedBeanReference val = it.next(); + it.remove(); + try { + val.destroy(); + } catch (RuntimeException e) { + errors.add(e); + } + } + if (!errors.isEmpty()) { + RuntimeException rex = errors.remove(0); + for (RuntimeException error : errors) { + rex.addSuppressed(error); + } + throw rex; + } + } + + } + + /** + * Container for a single scoped bean. This class manages the bean creation + */ + private static class ScopedBeanReference { + + private final ObjectFactory objectFactory; + private Object bean; + private Runnable destructionCallback; + + /** + * Creates a new scoped bean reference using the given object factory. + * + * @param objectFactory The object factory used to create instances of that bean. + */ + public ScopedBeanReference(ObjectFactory objectFactory) { + this.objectFactory = objectFactory; + } + + /** + * Gets or creates the bean managed by this instance. + * + * @return The existing or newly created bean instance. + */ + public synchronized Object getBean() { + if (this.bean == null) { + this.bean = this.objectFactory.getObject(); + } + return this.bean; + } + + /** + * Gets the bean managed by this instance, if it exists. + * + * @return The existing bean or null. + */ + public Object getBeanIfExists() { + return this.bean; + } + + /** + * Sets the given callback used to destroy the managed bean. + * + * @param destructionCallback The destruction callback to use. + */ + public void setDestructionCallback(final Runnable destructionCallback) { + this.destructionCallback = destructionCallback; + } + + /** + * Executes the destruction callback if set and clears the internal bean references. + */ + public synchronized void destroy() { + Runnable callback = this.destructionCallback; + if (callback != null) { + callback.run(); + } + this.bean = null; + this.destructionCallback = null; + } + + @Override + public String toString() { + return "ScopedBeanReference [objectFactory=" + this.objectFactory + ", bean=" + this.bean + "]"; + } + + } +}