diff --git a/client/base/pom.xml b/client/base/pom.xml index 14264c5f..7dcf9eca 100644 --- a/client/base/pom.xml +++ b/client/base/pom.xml @@ -48,6 +48,11 @@ ${project.groupId} a2a-java-sdk-spec + + ${project.groupId} + a2a-java-sdk-spec-grpc + test + org.junit.jupiter junit-jupiter-api @@ -64,6 +69,16 @@ slf4j-jdk14 test + + io.grpc + grpc-testing + test + + + io.grpc + grpc-inprocess + test + \ No newline at end of file diff --git a/client/base/src/test/java/io/a2a/client/AuthenticationAuthorizationTest.java b/client/base/src/test/java/io/a2a/client/AuthenticationAuthorizationTest.java new file mode 100644 index 00000000..29520574 --- /dev/null +++ b/client/base/src/test/java/io/a2a/client/AuthenticationAuthorizationTest.java @@ -0,0 +1,380 @@ +package io.a2a.client; + +import io.a2a.client.config.ClientConfig; +import io.a2a.client.transport.grpc.GrpcTransport; +import io.a2a.client.transport.grpc.GrpcTransportConfigBuilder; +import io.a2a.client.transport.jsonrpc.JSONRPCTransport; +import io.a2a.client.transport.jsonrpc.JSONRPCTransportConfigBuilder; +import io.a2a.client.transport.rest.RestTransport; +import io.a2a.client.transport.rest.RestTransportConfigBuilder; +import io.a2a.grpc.A2AServiceGrpc; +import io.a2a.grpc.SendMessageRequest; +import io.a2a.grpc.SendMessageResponse; +import io.a2a.grpc.StreamResponse; +import io.a2a.spec.A2AClientException; +import io.a2a.spec.AgentCapabilities; +import io.a2a.spec.AgentCard; +import io.a2a.spec.AgentInterface; +import io.a2a.spec.AgentSkill; +import io.a2a.spec.Message; +import io.a2a.spec.TextPart; +import io.a2a.spec.TransportProtocol; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.Status; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.StreamObserver; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockserver.integration.ClientAndServer; + +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockserver.model.HttpRequest.request; +import static org.mockserver.model.HttpResponse.response; + +/** + * Tests for handling HTTP 401 (Unauthorized) and 403 (Forbidden) responses + * when the client sends streaming and non-streaming messages. + * + * These tests verify that the client properly fails when the server returns + * authentication or authorization errors. + */ +public class AuthenticationAuthorizationTest { + + private static final String AGENT_URL = "http://localhost:4001"; + private static final String AUTHENTICATION_FAILED_MESSAGE = "Authentication failed"; + private static final String AUTHORIZATION_FAILED_MESSAGE = "Authorization failed"; + + private ClientAndServer server; + private Message MESSAGE; + private AgentCard agentCard; + private Server grpcServer; + private ManagedChannel grpcChannel; + private String grpcServerName; + + @BeforeEach + public void setUp() { + server = new ClientAndServer(4001); + MESSAGE = new Message.Builder() + .role(Message.Role.USER) + .parts(Collections.singletonList(new TextPart("test message"))) + .contextId("context-1234") + .messageId("message-1234") + .build(); + + grpcServerName = InProcessServerBuilder.generateName(); + + agentCard = new AgentCard.Builder() + .name("Test Agent") + .description("Test agent for auth tests") + .url(AGENT_URL) + .version("1.0.0") + .capabilities(new AgentCapabilities.Builder() + .streaming(true) // Support streaming for all tests + .build()) + .defaultInputModes(Collections.singletonList("text")) + .defaultOutputModes(Collections.singletonList("text")) + .skills(Collections.singletonList(new AgentSkill.Builder() + .id("test_skill") + .name("Test skill") + .description("Test skill") + .tags(Collections.singletonList("test")) + .build())) + .protocolVersion("0.3.0") + .additionalInterfaces(java.util.Arrays.asList( + new AgentInterface(TransportProtocol.JSONRPC.asString(), AGENT_URL), + new AgentInterface(TransportProtocol.HTTP_JSON.asString(), AGENT_URL), + new AgentInterface(TransportProtocol.GRPC.asString(), grpcServerName))) + .build(); + } + + @AfterEach + public void tearDown() { + server.stop(); + if (grpcChannel != null) { + grpcChannel.shutdownNow(); + } + if (grpcServer != null) { + grpcServer.shutdownNow(); + } + } + + // ========== JSON-RPC Transport Tests ========== + + @Test + public void testJsonRpcNonStreamingUnauthenticated() throws A2AClientException { + // Mock server to return 401 for non-streaming message + server.when( + request() + .withMethod("POST") + .withPath("/") + ).respond( + response() + .withStatusCode(401) + ); + + Client client = getJSONRPCClientBuilder(false).build(); + + A2AClientException exception = assertThrows(A2AClientException.class, () -> { + client.sendMessage(MESSAGE); + }); + + assertTrue(exception.getMessage().contains(AUTHENTICATION_FAILED_MESSAGE)); + } + + @Test + public void testJsonRpcNonStreamingUnauthorized() throws A2AClientException { + // Mock server to return 403 for non-streaming message + server.when( + request() + .withMethod("POST") + .withPath("/") + ).respond( + response() + .withStatusCode(403) + ); + + Client client = getJSONRPCClientBuilder(false).build(); + + A2AClientException exception = assertThrows(A2AClientException.class, () -> { + client.sendMessage(MESSAGE); + }); + + assertTrue(exception.getMessage().contains(AUTHORIZATION_FAILED_MESSAGE)); + } + + @Test + public void testJsonRpcStreamingUnauthenticated() throws Exception { + // Mock server to return 401 for streaming message + server.when( + request() + .withMethod("POST") + .withPath("/") + ).respond( + response() + .withStatusCode(401) + ); + + assertStreamingError( + getJSONRPCClientBuilder(true), + AUTHENTICATION_FAILED_MESSAGE); + } + + @Test + public void testJsonRpcStreamingUnauthorized() throws Exception { + // Mock server to return 403 for streaming message + server.when( + request() + .withMethod("POST") + .withPath("/") + ).respond( + response() + .withStatusCode(403) + ); + + assertStreamingError( + getJSONRPCClientBuilder(true), + AUTHORIZATION_FAILED_MESSAGE); + } + + // ========== REST Transport Tests ========== + + @Test + public void testRestNonStreamingUnauthenticated() throws A2AClientException { + // Mock server to return 401 for non-streaming message + server.when( + request() + .withMethod("POST") + .withPath("/v1/message:send") + ).respond( + response() + .withStatusCode(401) + ); + + Client client = getRestClientBuilder(false).build(); + + A2AClientException exception = assertThrows(A2AClientException.class, () -> { + client.sendMessage(MESSAGE); + }); + + assertTrue(exception.getMessage().contains(AUTHENTICATION_FAILED_MESSAGE)); + } + + @Test + public void testRestNonStreamingUnauthorized() throws A2AClientException { + // Mock server to return 403 for non-streaming message + server.when( + request() + .withMethod("POST") + .withPath("/v1/message:send") + ).respond( + response() + .withStatusCode(403) + ); + + Client client = getRestClientBuilder(false).build(); + + A2AClientException exception = assertThrows(A2AClientException.class, () -> { + client.sendMessage(MESSAGE); + }); + + assertTrue(exception.getMessage().contains(AUTHORIZATION_FAILED_MESSAGE)); + } + + @Test + public void testRestStreamingUnauthenticated() throws Exception { + // Mock server to return 401 for streaming message + server.when( + request() + .withMethod("POST") + .withPath("/v1/message:stream") + ).respond( + response() + .withStatusCode(401) + ); + + assertStreamingError( + getRestClientBuilder(true), + AUTHENTICATION_FAILED_MESSAGE); + } + + @Test + public void testRestStreamingUnauthorized() throws Exception { + // Mock server to return 403 for streaming message + server.when( + request() + .withMethod("POST") + .withPath("/v1/message:stream") + ).respond( + response() + .withStatusCode(403) + ); + + assertStreamingError( + getRestClientBuilder(true), + AUTHORIZATION_FAILED_MESSAGE); + } + + // ========== gRPC Transport Tests ========== + + @Test + public void testGrpcNonStreamingUnauthenticated() throws Exception { + setupGrpcServer(Status.UNAUTHENTICATED); + + Client client = getGrpcClientBuilder(false).build(); + + A2AClientException exception = assertThrows(A2AClientException.class, () -> { + client.sendMessage(MESSAGE); + }); + + assertTrue(exception.getMessage().contains(AUTHENTICATION_FAILED_MESSAGE)); + } + + @Test + public void testGrpcNonStreamingUnauthorized() throws Exception { + setupGrpcServer(Status.PERMISSION_DENIED); + + Client client = getGrpcClientBuilder(false).build(); + + A2AClientException exception = assertThrows(A2AClientException.class, () -> { + client.sendMessage(MESSAGE); + }); + + assertTrue(exception.getMessage().contains(AUTHORIZATION_FAILED_MESSAGE)); + } + + @Test + public void testGrpcStreamingUnauthenticated() throws Exception { + setupGrpcServer(Status.UNAUTHENTICATED); + + assertStreamingError( + getGrpcClientBuilder(true), + AUTHENTICATION_FAILED_MESSAGE); + } + + @Test + public void testGrpcStreamingUnauthorized() throws Exception { + setupGrpcServer(Status.PERMISSION_DENIED); + + assertStreamingError( + getGrpcClientBuilder(true), + AUTHORIZATION_FAILED_MESSAGE); + } + + private ClientBuilder getJSONRPCClientBuilder(boolean streaming) { + return Client.builder(agentCard) + .clientConfig(new ClientConfig.Builder().setStreaming(streaming).build()) + .withTransport(JSONRPCTransport.class, new JSONRPCTransportConfigBuilder()); + } + + private ClientBuilder getRestClientBuilder(boolean streaming) { + return Client.builder(agentCard) + .clientConfig(new ClientConfig.Builder().setStreaming(streaming).build()) + .withTransport(RestTransport.class, new RestTransportConfigBuilder()); + } + + private ClientBuilder getGrpcClientBuilder(boolean streaming) { + return Client.builder(agentCard) + .clientConfig(new ClientConfig.Builder().setStreaming(streaming).build()) + .withTransport(GrpcTransport.class, new GrpcTransportConfigBuilder() + .channelFactory(target -> grpcChannel)); + } + + private void assertStreamingError(ClientBuilder clientBuilder, String expectedErrorMessage) throws Exception { + AtomicReference errorRef = new AtomicReference<>(); + CountDownLatch errorLatch = new CountDownLatch(1); + + Consumer errorHandler = error -> { + errorRef.set(error); + errorLatch.countDown(); + }; + + Client client = clientBuilder.streamingErrorHandler(errorHandler).build(); + + try { + client.sendMessage(MESSAGE); + // If no immediate exception, wait for async error + assertTrue(errorLatch.await(5, TimeUnit.SECONDS), "Expected error handler to be called"); + Throwable error = errorRef.get(); + assertTrue(error.getMessage().contains(expectedErrorMessage), + "Expected error message to contain '" + expectedErrorMessage + "' but got: " + error.getMessage()); + } catch (Exception e) { + // Immediate exception is also acceptable + assertTrue(e.getMessage().contains(expectedErrorMessage), + "Expected error message to contain '" + expectedErrorMessage + "' but got: " + e.getMessage()); + } + } + + private void setupGrpcServer(Status status) throws IOException { + grpcServerName = InProcessServerBuilder.generateName(); + grpcServer = InProcessServerBuilder.forName(grpcServerName) + .directExecutor() + .addService(new A2AServiceGrpc.A2AServiceImplBase() { + @Override + public void sendMessage(SendMessageRequest request, StreamObserver responseObserver) { + responseObserver.onError(status.asRuntimeException()); + } + + @Override + public void sendStreamingMessage(SendMessageRequest request, StreamObserver responseObserver) { + responseObserver.onError(status.asRuntimeException()); + } + }) + .build() + .start(); + + grpcChannel = InProcessChannelBuilder.forName(grpcServerName) + .directExecutor() + .build(); + } +} \ No newline at end of file diff --git a/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcErrorMapper.java b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcErrorMapper.java index 7340f7ce..5f0db8f0 100644 --- a/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcErrorMapper.java +++ b/client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcErrorMapper.java @@ -1,5 +1,6 @@ package io.a2a.client.transport.grpc; +import io.a2a.common.A2AErrorMessages; import io.a2a.spec.A2AClientException; import io.a2a.spec.ContentTypeNotSupportedError; import io.a2a.spec.InvalidAgentResponseError; @@ -64,6 +65,10 @@ public static A2AClientException mapGrpcError(StatusRuntimeException e, String e return new A2AClientException(errorPrefix + (description != null ? description : e.getMessage()), new InvalidParamsError()); case INTERNAL: return new A2AClientException(errorPrefix + (description != null ? description : e.getMessage()), new io.a2a.spec.InternalError(null, e.getMessage(), null)); + case UNAUTHENTICATED: + return new A2AClientException(errorPrefix + A2AErrorMessages.AUTHENTICATION_FAILED); + case PERMISSION_DENIED: + return new A2AClientException(errorPrefix + A2AErrorMessages.AUTHORIZATION_FAILED); default: return new A2AClientException(errorPrefix + e.getMessage(), e); } diff --git a/common/src/main/java/io/a2a/common/A2AErrorMessages.java b/common/src/main/java/io/a2a/common/A2AErrorMessages.java new file mode 100644 index 00000000..22b587d8 --- /dev/null +++ b/common/src/main/java/io/a2a/common/A2AErrorMessages.java @@ -0,0 +1,11 @@ +package io.a2a.common; + +public final class A2AErrorMessages { + + private A2AErrorMessages() { + // prevent instantiation + } + + public static final String AUTHENTICATION_FAILED = "Authentication failed: Client credentials are missing or invalid"; + public static final String AUTHORIZATION_FAILED = "Authorization failed: Client does not have permission for the operation"; +} diff --git a/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java b/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java index abcecc8e..8cd0089d 100644 --- a/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java +++ b/http-client/src/main/java/io/a2a/client/http/JdkA2AHttpClient.java @@ -1,19 +1,31 @@ package io.a2a.client.http; +import static java.net.HttpURLConnection.HTTP_FORBIDDEN; +import static java.net.HttpURLConnection.HTTP_MULT_CHOICE; +import static java.net.HttpURLConnection.HTTP_OK; +import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; + import java.io.IOException; +import java.net.HttpURLConnection; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; import java.net.http.HttpResponse.BodyHandlers; +import java.net.http.HttpResponse.BodySubscribers; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Flow; import java.util.function.Consumer; +import io.a2a.common.A2AErrorMessages; +import io.a2a.spec.A2AClientException; + public class JdkA2AHttpClient implements A2AHttpClient { private final HttpClient httpClient; @@ -88,6 +100,7 @@ protected CompletableFuture asyncRequest( ) { Flow.Subscriber subscriber = new Flow.Subscriber() { private Flow.Subscription subscription; + private volatile boolean errorRaised = false; @Override public void onSubscribe(Flow.Subscription subscription) { @@ -109,24 +122,72 @@ public void onNext(String item) { @Override public void onError(Throwable throwable) { - errorConsumer.accept(throwable); - subscription.cancel(); + if (!errorRaised) { + errorRaised = true; + errorConsumer.accept(throwable); + } + if (subscription != null) { + subscription.cancel(); + } } @Override public void onComplete() { - completeRunnable.run(); - subscription.cancel(); + if (!errorRaised) { + completeRunnable.run(); + } + if (subscription != null) { + subscription.cancel(); + } } }; - BodyHandler bodyHandler = BodyHandlers.fromLineSubscriber(subscriber); + // Create a custom body handler that checks status before processing body + BodyHandler bodyHandler = responseInfo -> { + // Check for authentication/authorization errors only + if (responseInfo.statusCode() == HTTP_UNAUTHORIZED || responseInfo.statusCode() == HTTP_FORBIDDEN) { + final String errorMessage; + if (responseInfo.statusCode() == HTTP_UNAUTHORIZED) { + errorMessage = A2AErrorMessages.AUTHENTICATION_FAILED; + } else { + errorMessage = A2AErrorMessages.AUTHORIZATION_FAILED; + } + // Return a body subscriber that immediately signals error + return BodySubscribers.fromSubscriber(new Flow.Subscriber>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscriber.onError(new IOException(errorMessage)); + } + + @Override + public void onNext(List item) { + // Should not be called + } + + @Override + public void onError(Throwable throwable) { + // Should not be called + } + + @Override + public void onComplete() { + // Should not be called + } + }); + } else { + // For all other status codes (including other errors), proceed with normal line subscriber + return BodyHandlers.fromLineSubscriber(subscriber).apply(responseInfo); + } + }; // Send the response async, and let the subscriber handle the lines. return httpClient.sendAsync(request, bodyHandler) .thenAccept(response -> { - if (!JdkHttpResponse.success(response)) { - subscriber.onError(new IOException("Request failed " + response.statusCode())); + // Handle non-authentication/non-authorization errors here + if (!isSuccessStatus(response.statusCode()) && + response.statusCode() != HTTP_UNAUTHORIZED && + response.statusCode() != HTTP_FORBIDDEN) { + subscriber.onError(new IOException("Request failed with status " + response.statusCode() + ":" + response.body())); } }); } @@ -200,6 +261,13 @@ public A2AHttpResponse post() throws IOException, InterruptedException { .build(); HttpResponse response = httpClient.send(request, BodyHandlers.ofString(StandardCharsets.UTF_8)); + + if (response.statusCode() == HTTP_UNAUTHORIZED) { + throw new IOException(A2AErrorMessages.AUTHENTICATION_FAILED); + } else if (response.statusCode() == HTTP_FORBIDDEN) { + throw new IOException(A2AErrorMessages.AUTHORIZATION_FAILED); + } + return new JdkHttpResponse(response); } @@ -227,7 +295,7 @@ public boolean success() {// Send the request and get the response } static boolean success(HttpResponse response) { - return response.statusCode() >= 200 && response.statusCode() < 300; + return response.statusCode() >= HTTP_OK && response.statusCode() < HTTP_MULT_CHOICE; } @Override @@ -235,4 +303,8 @@ public String body() { return response.body(); } } + + private static boolean isSuccessStatus(int statusCode) { + return statusCode >= HTTP_OK && statusCode < HTTP_MULT_CHOICE; + } } diff --git a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java index b259e91a..57dc5ae5 100644 --- a/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java +++ b/transport/grpc/src/main/java/io/a2a/transport/grpc/handler/GrpcHandler.java @@ -18,6 +18,7 @@ import java.util.logging.Logger; import com.google.protobuf.Empty; +import io.a2a.common.A2AErrorMessages; import io.a2a.grpc.A2AServiceGrpc; import io.a2a.grpc.StreamResponse; import io.a2a.server.AgentCardValidator; @@ -80,6 +81,8 @@ public void sendMessage(io.a2a.grpc.SendMessageRequest request, responseObserver.onCompleted(); } catch (JSONRPCError e) { handleError(responseObserver, e); + } catch (SecurityException e) { + handleSecurityException(responseObserver, e); } catch (Throwable t) { handleInternalError(responseObserver, t); } @@ -100,6 +103,8 @@ public void getTask(io.a2a.grpc.GetTaskRequest request, } } catch (JSONRPCError e) { handleError(responseObserver, e); + } catch (SecurityException e) { + handleSecurityException(responseObserver, e); } catch (Throwable t) { handleInternalError(responseObserver, t); } @@ -120,6 +125,8 @@ public void cancelTask(io.a2a.grpc.CancelTaskRequest request, } } catch (JSONRPCError e) { handleError(responseObserver, e); + } catch (SecurityException e) { + handleSecurityException(responseObserver, e); } catch (Throwable t) { handleInternalError(responseObserver, t); } @@ -141,6 +148,8 @@ public void createTaskPushNotificationConfig(io.a2a.grpc.CreateTaskPushNotificat responseObserver.onCompleted(); } catch (JSONRPCError e) { handleError(responseObserver, e); + } catch (SecurityException e) { + handleSecurityException(responseObserver, e); } catch (Throwable t) { handleInternalError(responseObserver, t); } @@ -162,6 +171,8 @@ public void getTaskPushNotificationConfig(io.a2a.grpc.GetTaskPushNotificationCon responseObserver.onCompleted(); } catch (JSONRPCError e) { handleError(responseObserver, e); + } catch (SecurityException e) { + handleSecurityException(responseObserver, e); } catch (Throwable t) { handleInternalError(responseObserver, t); } @@ -179,7 +190,7 @@ public void listTaskPushNotificationConfig(io.a2a.grpc.ListTaskPushNotificationC ServerCallContext context = createCallContext(responseObserver); ListTaskPushNotificationConfigParams params = FromProto.listTaskPushNotificationConfigParams(request); List configList = getRequestHandler().onListTaskPushNotificationConfig(params, context); - io.a2a.grpc.ListTaskPushNotificationConfigResponse.Builder responseBuilder = + io.a2a.grpc.ListTaskPushNotificationConfigResponse.Builder responseBuilder = io.a2a.grpc.ListTaskPushNotificationConfigResponse.newBuilder(); for (TaskPushNotificationConfig config : configList) { responseBuilder.addConfigs(ToProto.taskPushNotificationConfig(config)); @@ -188,6 +199,8 @@ public void listTaskPushNotificationConfig(io.a2a.grpc.ListTaskPushNotificationC responseObserver.onCompleted(); } catch (JSONRPCError e) { handleError(responseObserver, e); + } catch (SecurityException e) { + handleSecurityException(responseObserver, e); } catch (Throwable t) { handleInternalError(responseObserver, t); } @@ -208,6 +221,8 @@ public void sendStreamingMessage(io.a2a.grpc.SendMessageRequest request, convertToStreamResponse(publisher, responseObserver); } catch (JSONRPCError e) { handleError(responseObserver, e); + } catch (SecurityException e) { + handleSecurityException(responseObserver, e); } catch (Throwable t) { handleInternalError(responseObserver, t); } @@ -228,6 +243,8 @@ public void taskSubscription(io.a2a.grpc.TaskSubscriptionRequest request, convertToStreamResponse(publisher, responseObserver); } catch (JSONRPCError e) { handleError(responseObserver, e); + } catch (SecurityException e) { + handleSecurityException(responseObserver, e); } catch (Throwable t) { handleInternalError(responseObserver, t); } @@ -308,6 +325,8 @@ public void deleteTaskPushNotificationConfig(io.a2a.grpc.DeleteTaskPushNotificat responseObserver.onCompleted(); } catch (JSONRPCError e) { handleError(responseObserver, e); + } catch (SecurityException e) { + handleSecurityException(responseObserver, e); } catch (Throwable t) { handleInternalError(responseObserver, t); } @@ -413,6 +432,32 @@ private void handleError(StreamObserver responseObserver, JSONRPCError er responseObserver.onError(status.withDescription(description).asRuntimeException()); } + private void handleSecurityException(StreamObserver responseObserver, SecurityException e) { + Status status; + String description; + + String exceptionClassName = e.getClass().getName(); + + // Attempt to detect common authentication and authorization related exceptions + if (exceptionClassName.contains("Unauthorized") || + exceptionClassName.contains("Unauthenticated") || + exceptionClassName.contains("Authentication")) { + status = Status.UNAUTHENTICATED; + description = A2AErrorMessages.AUTHENTICATION_FAILED; + } else if (exceptionClassName.contains("Forbidden") || + exceptionClassName.contains("AccessDenied") || + exceptionClassName.contains("Authorization")) { + status = Status.PERMISSION_DENIED; + description = A2AErrorMessages.AUTHORIZATION_FAILED; + } else { + // If the security exception type cannot be detected, default to PERMISSION_DENIED + status = Status.PERMISSION_DENIED; + description = "Authorization failed: " + (e.getMessage() != null ? e.getMessage() : "Access denied"); + } + + responseObserver.onError(status.withDescription(description).asRuntimeException()); + } + private void handleInternalError(StreamObserver responseObserver, Throwable t) { handleError(responseObserver, new InternalError(t.getMessage())); }