diff --git a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java index 1a54d2727..334de2c9c 100644 --- a/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java +++ b/codegen/aws/core/src/main/java/software/amazon/smithy/python/aws/codegen/AwsAuthIntegration.java @@ -7,18 +7,20 @@ import static software.amazon.smithy.python.aws.codegen.AwsConfiguration.REGION; import java.util.List; +import java.util.Set; import software.amazon.smithy.aws.traits.auth.SigV4Trait; import software.amazon.smithy.codegen.core.Symbol; +import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.shapes.ShapeId; import software.amazon.smithy.python.codegen.ApplicationProtocol; import software.amazon.smithy.python.codegen.CodegenUtils; import software.amazon.smithy.python.codegen.ConfigProperty; -import software.amazon.smithy.python.codegen.DerivedProperty; import software.amazon.smithy.python.codegen.GenerationContext; import software.amazon.smithy.python.codegen.SmithyPythonDependency; import software.amazon.smithy.python.codegen.integrations.AuthScheme; import software.amazon.smithy.python.codegen.integrations.PythonIntegration; import software.amazon.smithy.python.codegen.integrations.RuntimeClientPlugin; +import software.amazon.smithy.python.codegen.writer.PythonWriter; import software.amazon.smithy.utils.SmithyInternalApi; /** @@ -38,7 +40,7 @@ public List getClientPlugins(GenerationContext context) { .name("aws_credentials_identity_resolver") .documentation("Resolves AWS Credentials. Required for operations that use Sigv4 Auth.") .type(Symbol.builder() - .name("IdentityResolver[AWSCredentialsIdentity, IdentityProperties]") + .name("IdentityResolver[AWSCredentialsIdentity, AWSIdentityProperties]") .addReference(Symbol.builder() .addDependency(SmithyPythonDependency.SMITHY_CORE) .name("IdentityResolver") @@ -51,8 +53,8 @@ public List getClientPlugins(GenerationContext context) { .build()) .addReference(Symbol.builder() .addDependency(SmithyPythonDependency.SMITHY_CORE) - .name("IdentityProperties") - .namespace("smithy_core.interfaces.identity", ".") + .name("AWSIdentityProperties") + .namespace("smithy_aws_core.identity", ".") .build()) .build()) // TODO: Initialize with the provider chain? @@ -69,7 +71,6 @@ public void customize(GenerationContext context) { return; } var trait = context.settings().service(context.model()).expectTrait(SigV4Trait.class); - var params = CodegenUtils.getHttpAuthParamsSymbol(context.settings()); var resolver = CodegenUtils.getHttpAuthSchemeResolverSymbol(context.settings()); // Add a function that generates the http auth option for api key auth. @@ -77,22 +78,19 @@ public void customize(GenerationContext context) { // must be accounted for. context.writerDelegator().useFileWriter(resolver.getDefinitionFile(), resolver.getNamespace(), writer -> { writer.addDependency(SmithyPythonDependency.SMITHY_HTTP); - writer.addImport("smithy_http.aio.interfaces.auth", "HTTPAuthOption"); + writer.addImport("smithy_core.interfaces.auth", "AuthOption", "AuthOptionProtocol"); + writer.addImports("smithy_core.auth", Set.of("AuthOption", "AuthParams")); writer.pushState(); writer.write(""" - def $1L(auth_params: $2T) -> HTTPAuthOption | None: - return HTTPAuthOption( - scheme_id=$3S, - identity_properties={}, - signer_properties={ - "service": $4S, - "region": auth_params.region - } + def $1L(auth_params: AuthParams[Any, Any]) -> AuthOptionProtocol | None: + return AuthOption( + scheme_id=$2S, + identity_properties={}, # type: ignore + signer_properties={} # type: ignore ) """, SIGV4_OPTION_GENERATOR_NAME, - params, SigV4Trait.ID.toString(), trait.getName()); writer.popState(); @@ -119,17 +117,6 @@ public ApplicationProtocol getApplicationProtocol() { return ApplicationProtocol.createDefaultHttpApplicationProtocol(); } - @Override - public List getAuthProperties() { - return List.of( - DerivedProperty.builder() - .name("region") - .source(DerivedProperty.Source.CONFIG) - .type(Symbol.builder().name("str").build()) - .sourcePropertyName("region") - .build()); - } - @Override public Symbol getAuthOptionGenerator(GenerationContext context) { var resolver = CodegenUtils.getHttpAuthSchemeResolverSymbol(context.settings()); @@ -148,5 +135,11 @@ public Symbol getAuthSchemeSymbol(GenerationContext context) { .addDependency(AwsPythonDependency.SMITHY_AWS_CORE) .build(); } + + @Override + public void initializeScheme(GenerationContext context, PythonWriter writer, ServiceShape service) { + var trait = service.expectTrait(SigV4Trait.class); + writer.write("$T(service=$S)", getAuthSchemeSymbol(context), trait.getName()); + } } } diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java index f0020285d..8f4c9096e 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java @@ -256,9 +256,11 @@ private void generateOperationExecutor(PythonWriter writer) { writer.consumer(w -> context.protocolGenerator().wrapInputStream(context, w)), writer.consumer(w -> context.protocolGenerator().wrapOutputStream(context, w))); } + writer.addStdlibImport("typing", "Any"); writer.addStdlibImport("asyncio", "iscoroutine"); writer.addImports("smithy_core.exceptions", Set.of("SmithyError", "CallError", "RetryError")); + writer.addImport("smithy_core.auth", "AuthParams"); writer.pushState(); writer.putContext("request", transportRequest); writer.putContext("response", transportResponse); @@ -438,53 +440,60 @@ await sleep(retry_token.retry_delay) boolean supportsAuth = !ServiceIndex.of(model).getAuthSchemes(service).isEmpty(); writer.pushState(new ResolveIdentitySection()); - if (context.applicationProtocol().isHttpProtocol() && supportsAuth) { - writer.pushState(new InitializeHttpAuthParametersSection()); - writer.write(""" - # Step 7b: Invoke service_auth_scheme_resolver.resolve_auth_scheme - auth_parameters: $1T = $1T( - operation=operation.schema.id.name, - ${2C|} - ) - - """, - CodegenUtils.getHttpAuthParamsSymbol(context.settings()), - writer.consumer(this::initializeHttpAuthParameters)); - writer.popState(); + if (supportsAuth) { + // TODO: delete InitializeHttpAuthParametersSection writer.addDependency(SmithyPythonDependency.SMITHY_CORE); - writer.addDependency(SmithyPythonDependency.SMITHY_HTTP); writer.addImport("smithy_core.interfaces.identity", "Identity"); - writer.addImports("smithy_http.aio.interfaces.auth", Set.of("HTTPSigner", "HTTPAuthOption")); + writer.addImport("smithy_core.interfaces.auth", "AuthOption"); + writer.addImport("smithy_core.aio.interfaces.auth", "Signer"); + writer.addImport("smithy_core.shapes", "ShapeID"); writer.addStdlibImport("typing", "Any"); writer.write(""" - auth_options = config.http_auth_scheme_resolver.resolve_auth_scheme( + auth_parameters = AuthParams( + protocol_id=ShapeID($1S), + operation=operation, + context=context.properties, + ) + auth_options = config.auth_scheme_resolver.resolve_auth_scheme( auth_parameters=auth_parameters ) - auth_option: HTTPAuthOption | None = None + + auth_option: AuthOption | None = None for option in auth_options: - if option.scheme_id in config.http_auth_schemes: + if option.scheme_id in config.auth_schemes: auth_option = option break - signer: HTTPSigner[Any, Any] | None = None + signer: Signer[$2T, Any, Any] | None = None identity: Identity | None = None + auth_scheme: Any = None if auth_option: - auth_scheme = config.http_auth_schemes[auth_option.scheme_id] + auth_scheme = config.auth_schemes[auth_option.scheme_id] + context.properties["auth_scheme"] = auth_scheme # Step 7c: Invoke auth_scheme.identity_resolver - identity_resolver = auth_scheme.identity_resolver(config=config) + identity_resolver = auth_scheme.identity_resolver(context=context.properties) + context.properties["identity_resolver"] = identity_resolver # Step 7d: Invoke auth_scheme.signer - signer = auth_scheme.signer + signer = auth_scheme.signer() + + # TODO: merge from auth_option + identity_properties = auth_scheme.identity_properties( + context=context.properties + ) + context.properties["identity_properties"] = identity_properties # Step 7e: Invoke identity_resolver.get_identity identity = await identity_resolver.get_identity( - identity_properties=auth_option.identity_properties + identity_properties=identity_properties ) - """); + """, + context.protocolGenerator().getProtocol(), + transportRequest); } writer.popState(); @@ -543,7 +552,7 @@ await sleep(retry_token.retry_delay) writer.pushState(new SignRequestSection()); writer.addStdlibImport("typing", "cast"); - if (context.applicationProtocol().isHttpProtocol() && supportsAuth) { + if (supportsAuth) { writer.addStdlibImport("re"); writer.addStdlibImport("typing", "Any"); writer.addImport("smithy_core.interfaces.identity", "Identity"); @@ -551,40 +560,21 @@ await sleep(retry_token.retry_delay) writer.write(""" # Step 7i: sign the request if auth_option and signer: - logger.debug("HTTP request to sign: %s", context.transport_request) - logger.debug( - "Signer properties: %s", - auth_option.signer_properties - ) + signer_properties = auth_scheme.signer_properties(context=context.properties) + context.properties["signer_properties"] = signer_properties + + logger.debug("Request to sign: %s", context.transport_request) + logger.debug("Signer properties: %s", signer_properties) + context = replace( context, - transport_request= await signer.sign( - http_request=context.transport_request, + transport_request = await signer.sign( + request=context.transport_request, identity=identity, - signing_properties=auth_option.signer_properties, + properties=signer_properties, ) ) logger.debug("Signed HTTP request: %s", context.transport_request) - - # TODO - Move this to separate resolution/population function - fields = context.transport_request.fields - auth_value = fields["Authorization"].as_string() # type: ignore - signature = re.split("Signature=", auth_value)[-1] # type: ignore - context.properties["signature"] = signature.encode('utf-8') - - identity_key = cast( - PropertyKey[Identity | None], - PropertyKey( - key="identity", - value_type=Identity | None # type: ignore - ) - ) - sp_key: PropertyKey[dict[str, Any]] = PropertyKey( - key="signer_properties", - value_type=dict[str, Any] # type: ignore - ) - context.properties[identity_key] = identity - context.properties[sp_key] = auth_option.signer_properties """); } writer.popState(); @@ -757,28 +747,6 @@ private boolean hasEventStream() { return false; } - private void initializeHttpAuthParameters(PythonWriter writer) { - var derived = new LinkedHashSet(); - for (PythonIntegration integration : context.integrations()) { - for (RuntimeClientPlugin plugin : integration.getClientPlugins(context)) { - if (plugin.matchesService(model, service) - && plugin.getAuthScheme().isPresent() - && plugin.getAuthScheme().get().getApplicationProtocol().isHttpProtocol()) { - derived.addAll(plugin.getAuthScheme().get().getAuthProperties()); - } - } - } - - for (DerivedProperty property : derived) { - var source = property.source().scopeLocation(); - if (property.initializationFunction().isPresent()) { - writer.write("$L=$T($L),", property.name(), property.initializationFunction().get(), source); - } else if (property.sourcePropertyName().isPresent()) { - writer.write("$L=$L.$L,", property.name(), source, property.sourcePropertyName().get()); - } - } - } - private void writeDefaultPlugins(PythonWriter writer, Collection plugins) { for (SymbolReference plugin : plugins) { writer.write("$T,", plugin); diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/CodegenUtils.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/CodegenUtils.java index 689215ece..994fc9fad 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/CodegenUtils.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/CodegenUtils.java @@ -102,20 +102,6 @@ public static Symbol getServiceError(PythonSettings settings) { .build(); } - /** - * Gets the symbol for the http auth params. - * - * @param settings The client settings, used to account for module configuration. - * @return Returns the http auth params symbol. - */ - public static Symbol getHttpAuthParamsSymbol(PythonSettings settings) { - return Symbol.builder() - .name("HTTPAuthParams") - .namespace(String.format("%s.auth", settings.moduleName()), ".") - .definitionFile(String.format("./src/%s/auth.py", settings.moduleName())) - .build(); - } - /** * Gets the symbol for the http auth scheme resolver. * diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpAuthGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpAuthGenerator.java index 9d684a16c..69780212a 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpAuthGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/HttpAuthGenerator.java @@ -4,9 +4,7 @@ */ package software.amazon.smithy.python.codegen; -import java.util.ArrayList; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import software.amazon.smithy.codegen.core.Symbol; @@ -18,7 +16,6 @@ import software.amazon.smithy.python.codegen.integrations.AuthScheme; import software.amazon.smithy.python.codegen.integrations.PythonIntegration; import software.amazon.smithy.python.codegen.integrations.RuntimeClientPlugin; -import software.amazon.smithy.python.codegen.sections.GenerateHttpAuthParametersSection; import software.amazon.smithy.python.codegen.sections.GenerateHttpAuthSchemeResolverSection; import software.amazon.smithy.python.codegen.writer.PythonWriter; import software.amazon.smithy.utils.SmithyInternalApi; @@ -40,7 +37,6 @@ final class HttpAuthGenerator implements Runnable { @Override public void run() { var supportedAuthSchemes = new HashMap(); - var properties = new ArrayList(); var service = context.settings().service(context.model()); for (PythonIntegration integration : context.integrations()) { for (RuntimeClientPlugin plugin : integration.getClientPlugins(context)) { @@ -49,43 +45,18 @@ public void run() { && plugin.getAuthScheme().get().getApplicationProtocol().isHttpProtocol()) { var scheme = plugin.getAuthScheme().get(); supportedAuthSchemes.put(scheme.getAuthTrait(), scheme); - properties.addAll(scheme.getAuthProperties()); } } } - var params = CodegenUtils.getHttpAuthParamsSymbol(settings); - context.writerDelegator().useFileWriter(params.getDefinitionFile(), params.getNamespace(), writer -> { - generateAuthParameters(writer, params, properties); - }); - var resolver = CodegenUtils.getHttpAuthSchemeResolverSymbol(settings); context.writerDelegator().useFileWriter(resolver.getDefinitionFile(), resolver.getNamespace(), writer -> { - generateAuthSchemeResolver(writer, params, resolver, supportedAuthSchemes); + generateAuthSchemeResolver(writer, resolver, supportedAuthSchemes); }); } - private void generateAuthParameters(PythonWriter writer, Symbol symbol, List properties) { - var propertyMap = new LinkedHashMap(); - for (DerivedProperty property : properties) { - propertyMap.put(property.name(), property.type()); - } - writer.pushState(new GenerateHttpAuthParametersSection(Map.copyOf(propertyMap))); - writer.addStdlibImport("dataclasses", "dataclass"); - writer.write(""" - @dataclass - class $L: - operation: str - ${#properties} - ${key:L}: ${value:T} | None - ${/properties} - """, symbol.getName()); - writer.popState(); - } - private void generateAuthSchemeResolver( PythonWriter writer, - Symbol paramsSymbol, Symbol resolverSymbol, Map supportedAuthSchemes ) { @@ -100,18 +71,19 @@ private void generateAuthSchemeResolver( writer.pushState(new GenerateHttpAuthSchemeResolverSection(resolvedAuthSchemes)); writer.addDependency(SmithyPythonDependency.SMITHY_CORE); writer.addDependency(SmithyPythonDependency.SMITHY_HTTP); - writer.addImport("smithy_http.aio.interfaces.auth", "HTTPAuthOption"); + writer.addImport("smithy_core.interfaces.auth", "AuthOption", "AuthOptionProtocol"); + writer.addImport("smithy_core.auth", "AuthParams"); + writer.addStdlibImport("typing", "Any"); writer.write(""" class $1L: - def resolve_auth_scheme(self, auth_parameters: $2T) -> list[HTTPAuthOption]: - auth_options: list[HTTPAuthOption] = [] + def resolve_auth_scheme(self, auth_parameters: AuthParams[Any, Any]) -> list[AuthOptionProtocol]: + auth_options: list[AuthOptionProtocol] = [] + ${2C|} ${3C|} - ${4C|} """, resolverSymbol.getName(), - paramsSymbol, writer.consumer(w -> writeOperationAuthOptions(w, supportedAuthSchemes)), writer.consumer(w -> writeAuthOptions(w, resolvedAuthSchemes))); writer.popState(); diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java index 27529e5a9..1764331c0 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ConfigGenerator.java @@ -7,7 +7,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Comparator; -import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; import java.util.TreeSet; @@ -78,6 +77,23 @@ public final class ConfigGenerator implements Runnable { .build()) .build()) .documentation("A static URI to route requests to.") + .build(), + ConfigProperty.builder() + .name("endpoint_resolver") + .type(Symbol.builder() + .name("_EndpointResolver") + .build()) + .documentation(""" + The endpoint resolver used to resolve the final endpoint per-operation based on the \ + configuration.""") + .nullable(false) + .initialize(writer -> { + writer.addImport("smithy_core.aio.interfaces", "EndpointResolver", "_EndpointResolver"); + writer.pushState(new InitDefaultEndpointResolverSection()); + writer.addImport("smithy_core.aio.endpoints", "StaticEndpointResolver"); + writer.write("self.endpoint_resolver = endpoint_resolver or StaticEndpointResolver()"); + writer.popState(); + }) .build()); // This list contains any properties that must be added to any http-based @@ -131,24 +147,6 @@ private static List getHttpProperties(GenerationContext context) } properties.add(clientBuilder.build()); - properties.add(ConfigProperty.builder() - .name("endpoint_resolver") - .type(Symbol.builder() - .name("_EndpointResolver") - .build()) - .documentation(""" - The endpoint resolver used to resolve the final endpoint per-operation based on the \ - configuration.""") - .nullable(false) - .initialize(writer -> { - writer.addImport("smithy_core.aio.interfaces", "EndpointResolver", "_EndpointResolver"); - writer.pushState(new InitDefaultEndpointResolverSection()); - writer.addImport("smithy_core.aio.endpoints", "StaticEndpointResolver"); - writer.write("self.endpoint_resolver = endpoint_resolver or StaticEndpointResolver()"); - writer.popState(); - }) - .build()); - properties.addAll(HTTP_PROPERTIES); return List.copyOf(properties); } @@ -180,16 +178,21 @@ private static boolean usesHttp2(GenerationContext context) { return false; } - private static List getHttpAuthProperties(GenerationContext context) { + private static List getAuthProperties(GenerationContext context) { return List.of( ConfigProperty.builder() - .name("http_auth_schemes") + .name("auth_schemes") .type(Symbol.builder() - .name("dict[str, HTTPAuthScheme[Any, Any, Any, Any]]") + .name("dict[ShapeID, AuthScheme[Any, Any, Any, Any]]") + .addReference(Symbol.builder() + .name("ShapeID") + .namespace("smithy_core.shapes", ".") + .addDependency(SmithyPythonDependency.SMITHY_CORE) + .build()) .addReference(Symbol.builder() - .name("HTTPAuthScheme") - .namespace("smithy_http.aio.interfaces.auth", ".") - .addDependency(SmithyPythonDependency.SMITHY_HTTP) + .name("AuthScheme") + .namespace("smithy_core.aio.interfaces.auth", ".") + .addDependency(SmithyPythonDependency.SMITHY_CORE) .build()) .addReference(Symbol.builder() .name("Any") @@ -197,43 +200,38 @@ private static List getHttpAuthProperties(GenerationContext cont .putProperty(SymbolProperties.STDLIB, true) .build()) .build()) - .documentation("A map of http auth scheme ids to http auth schemes.") + .documentation("A map of auth scheme ids to auth schemes.") .nullable(false) - .initialize(writer -> writeDefaultHttpAuthSchemes(context, writer)) + .initialize(writer -> writeDefaultAuthSchemes(context, writer)) .build(), ConfigProperty.builder() - .name("http_auth_scheme_resolver") + .name("auth_scheme_resolver") .type(CodegenUtils.getHttpAuthSchemeResolverSymbol(context.settings())) .documentation( - "An http auth scheme resolver that determines the auth scheme for each operation.") + "An auth scheme resolver that determines the auth scheme for each operation.") .nullable(false) .initialize(writer -> writer.write( - "self.http_auth_scheme_resolver = http_auth_scheme_resolver or HTTPAuthSchemeResolver()")) + "self.auth_scheme_resolver = auth_scheme_resolver or HTTPAuthSchemeResolver()")) .build()); } - private static void writeDefaultHttpAuthSchemes(GenerationContext context, PythonWriter writer) { - var supportedAuthSchemes = new LinkedHashMap(); + private static void writeDefaultAuthSchemes(GenerationContext context, PythonWriter writer) { + writer.pushState(); var service = context.settings().service(context.model()); + + writer.openBlock("self.auth_schemes = auth_schemes or {"); + writer.addImport("smithy_core.shapes", "ShapeID"); for (PythonIntegration integration : context.integrations()) { for (RuntimeClientPlugin plugin : integration.getClientPlugins(context)) { - if (plugin.matchesService(context.model(), service) - && plugin.getAuthScheme().isPresent() - && plugin.getAuthScheme().get().getApplicationProtocol().isHttpProtocol()) { + if (plugin.matchesService(context.model(), service) && plugin.getAuthScheme().isPresent()) { var scheme = plugin.getAuthScheme().get(); - supportedAuthSchemes.put(scheme.getAuthTrait().toString(), scheme.getAuthSchemeSymbol(context)); + writer.write("ShapeID($S): ${C|},", + scheme.getAuthTrait(), + writer.consumer(w -> scheme.initializeScheme(context, writer, service))); } } } - writer.pushState(); - writer.putContext("authSchemes", supportedAuthSchemes); - writer.write(""" - self.http_auth_schemes = http_auth_schemes or { - ${#authSchemes} - ${key:S}: ${value:T}(), - ${/authSchemes} - } - """); + writer.closeBlock("}"); writer.popState(); } @@ -292,16 +290,18 @@ private void generateConfig(GenerationContext context, PythonWriter writer) { var properties = new TreeSet<>(Comparator.comparing(ConfigProperty::name)); properties.addAll(BASE_PROPERTIES); + // Add in auth configuration if the service supports auth. + var serviceIndex = ServiceIndex.of(context.model()); + if (!serviceIndex.getAuthSchemes(settings.service()).isEmpty()) { + properties.addAll(getAuthProperties(context)); + writer.onSection(new AddAuthHelper()); + } + // Smithy is transport agnostic, so we don't add http-related properties by default. // Nevertheless, HTTP is the most common use case so we standardize those settings // and add them in if the protocol is going to need them. - var serviceIndex = ServiceIndex.of(context.model()); if (context.applicationProtocol().isHttpProtocol()) { properties.addAll(getHttpProperties(context)); - if (!serviceIndex.getAuthSchemes(settings.service()).isEmpty()) { - properties.addAll(getHttpAuthProperties(context)); - writer.onSection(new AddAuthHelper()); - } } var model = context.model(); @@ -398,14 +398,14 @@ public void write(PythonWriter writer, String previousText, ConfigSection sectio // Note that this is indented to keep it at the proper indentation level. writer.write(""" - def set_http_auth_scheme(self, scheme: HTTPAuthScheme[Any, Any, Any, Any]) -> None: + def set_auth_scheme(self, scheme: AuthScheme[Any, Any, Any, Any]) -> None: \"""Sets the implementation of an auth scheme. Using this method ensures the correct key is used. :param scheme: The auth scheme to add. \""" - self.http_auth_schemes[scheme.scheme_id] = scheme + self.auth_schemes[scheme.scheme_id] = scheme """); } } diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/AuthScheme.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/AuthScheme.java index 6b2371f2b..b68225fca 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/AuthScheme.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/AuthScheme.java @@ -4,13 +4,12 @@ */ package software.amazon.smithy.python.codegen.integrations; -import java.util.Collections; -import java.util.List; import software.amazon.smithy.codegen.core.Symbol; +import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.shapes.ShapeId; import software.amazon.smithy.python.codegen.ApplicationProtocol; -import software.amazon.smithy.python.codegen.DerivedProperty; import software.amazon.smithy.python.codegen.GenerationContext; +import software.amazon.smithy.python.codegen.writer.PythonWriter; import software.amazon.smithy.utils.SmithyUnstableApi; /** @@ -34,15 +33,6 @@ public interface AuthScheme { */ ApplicationProtocol getApplicationProtocol(); - /** - * Gets a list of properties needed from config or input to authenticate requests. - * - * @return Returns a list of properties to gather for auth. - */ - default List getAuthProperties() { - return Collections.emptyList(); - } - /** * Gets a function that returns a potential auth option for a request. * @@ -59,4 +49,7 @@ default List getAuthProperties() { * @return Returns the symbol for the auth scheme implementation. */ Symbol getAuthSchemeSymbol(GenerationContext context); + + // TODO: replace with from_trait + void initializeScheme(GenerationContext context, PythonWriter writer, ServiceShape service); } diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/HttpApiKeyAuth.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/HttpApiKeyAuth.java index 3cd445fe2..819cd2bf0 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/HttpApiKeyAuth.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/HttpApiKeyAuth.java @@ -5,7 +5,10 @@ package software.amazon.smithy.python.codegen.integrations; import java.util.List; +import java.util.Locale; +import java.util.Set; import software.amazon.smithy.codegen.core.Symbol; +import software.amazon.smithy.model.shapes.ServiceShape; import software.amazon.smithy.model.shapes.ShapeId; import software.amazon.smithy.model.traits.HttpApiKeyAuthTrait; import software.amazon.smithy.python.codegen.ApplicationProtocol; @@ -14,6 +17,7 @@ import software.amazon.smithy.python.codegen.GenerationContext; import software.amazon.smithy.python.codegen.PythonSettings; import software.amazon.smithy.python.codegen.SmithyPythonDependency; +import software.amazon.smithy.python.codegen.writer.PythonWriter; import software.amazon.smithy.utils.SmithyInternalApi; /** @@ -29,11 +33,17 @@ public List getClientPlugins(GenerationContext context) { return List.of( RuntimeClientPlugin.builder() .servicePredicate((model, service) -> service.hasTrait(HttpApiKeyAuthTrait.class)) + .addConfigProperty(ConfigProperty.builder() + .name("api_key") + .documentation("The API key to send along with requests.") + .type(Symbol.builder().name("str").build()) + .nullable(true) + .build()) .addConfigProperty(ConfigProperty.builder() .name("api_key_identity_resolver") - .documentation("Resolves the API key. Required for operations that use API key auth.") + .documentation("Resolves the API key.") .type(Symbol.builder() - .name("IdentityResolver[ApiKeyIdentity, IdentityProperties]") + .name("IdentityResolver[APIKeyIdentity, APIKeyIdentityProperties]") .addReference(Symbol.builder() .addDependency(SmithyPythonDependency.SMITHY_CORE) .name("IdentityResolver") @@ -41,16 +51,22 @@ public List getClientPlugins(GenerationContext context) { .build()) .addReference(Symbol.builder() .addDependency(SmithyPythonDependency.SMITHY_HTTP) - .name("ApiKeyIdentity") + .name("APIKeyIdentity") .namespace("smithy_http.aio.identity.apikey", ".") .build()) .addReference(Symbol.builder() - .addDependency(SmithyPythonDependency.SMITHY_CORE) - .name("IdentityProperties") - .namespace("smithy_core.interfaces.identity", ".") + .addDependency(SmithyPythonDependency.SMITHY_HTTP) + .name("APIKeyIdentityProperties") + .namespace("smithy_http.aio.identity.apikey", ".") .build()) .build()) - .nullable(true) + .initialize(writer -> { + writer.addImport("smithy_http.aio.identity.apikey", "APIKeyIdentityResolver"); + writer.write(""" + if api_key_identity_resolver is None: + api_key_identity_resolver = APIKeyIdentityResolver() + """); + }) .build()) .authScheme(new ApiKeyAuthScheme()) .build()); @@ -61,8 +77,6 @@ public void customize(GenerationContext context) { if (!hasApiKeyAuth(context)) { return; } - var trait = context.settings().service(context.model()).expectTrait(HttpApiKeyAuthTrait.class); - var params = CodegenUtils.getHttpAuthParamsSymbol(context.settings()); var resolver = CodegenUtils.getHttpAuthSchemeResolverSymbol(context.settings()); // Add a function that generates the http auth option for api key auth. @@ -71,32 +85,21 @@ public void customize(GenerationContext context) { context.writerDelegator().useFileWriter(resolver.getDefinitionFile(), resolver.getNamespace(), writer -> { writer.addDependency(SmithyPythonDependency.SMITHY_CORE); writer.addDependency(SmithyPythonDependency.SMITHY_HTTP); - writer.addImport("smithy_http.aio.interfaces.auth", "HTTPAuthOption"); - writer.addImport("smithy_http.aio.auth.apikey", "ApiKeyLocation"); + writer.addImport("smithy_core.interfaces.auth", "AuthOption", "AuthOptionProtocol"); + writer.addImports("smithy_core.auth", Set.of("AuthOption", "AuthParams")); + writer.addImport("smithy_core.shapes", "ShapeID"); + writer.addStdlibImport("typing", "Any"); writer.pushState(); - - // Push the scheme into the context to allow for conditionally adding - // it to the properties dict. - writer.putContext("scheme", trait.getScheme().orElse(null)); writer.write(""" - def $1L(auth_params: $2T) -> HTTPAuthOption | None: - return HTTPAuthOption( - scheme_id=$3S, - identity_properties={}, - signer_properties={ - "name": $4S, - "location": ApiKeyLocation($5S), - ${?scheme} - "scheme": ${scheme:S}, - ${/scheme} - } + def $1L(auth_params: AuthParams[Any, Any]) -> AuthOptionProtocol | None: + return AuthOption( + scheme_id=ShapeID($2S), + identity_properties={}, # type: ignore + signer_properties={}, # type: ignore ) """, OPTION_GENERATOR_NAME, - params, - HttpApiKeyAuthTrait.ID.toString(), - trait.getName(), - trait.getIn().toString()); + HttpApiKeyAuthTrait.ID.toString()); writer.popState(); }); } @@ -137,10 +140,31 @@ public Symbol getAuthOptionGenerator(GenerationContext context) { @Override public Symbol getAuthSchemeSymbol(GenerationContext context) { return Symbol.builder() - .name("ApiKeyAuthScheme") + .name("APIKeyAuthScheme") .namespace("smithy_http.aio.auth.apikey", ".") .addDependency(SmithyPythonDependency.SMITHY_HTTP) .build(); } + + @Override + public void initializeScheme(GenerationContext context, PythonWriter writer, ServiceShape service) { + var trait = service.expectTrait(HttpApiKeyAuthTrait.class); + writer.pushState(); + writer.putContext("scheme", trait.getScheme().orElse(null)); + writer.addImport("smithy_core.traits", "APIKeyLocation"); + writer.write(""" + $T( + name=$S, + location=APIKeyLocation.$L, + ${?scheme} + scheme=${scheme:S}, + ${/scheme} + ) + """, + getAuthSchemeSymbol(context), + trait.getName(), + trait.getIn().name().toUpperCase(Locale.ENGLISH)); + writer.popState(); + } } } diff --git a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java index a5c788e29..bd2dc6101 100644 --- a/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java +++ b/codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java @@ -409,27 +409,24 @@ public void wrapInputStream(GenerationContext context, PythonWriter writer) { writer.addDependency(SmithyPythonDependency.SMITHY_JSON); writer.addDependency(SmithyPythonDependency.SMITHY_AWS_EVENT_STREAM); writer.addImport("smithy_json", "JSONCodec"); - writer.addImport("smithy_core.aio.types", "AsyncBytesReader"); writer.addImport("smithy_core.types", "TimestampFormat"); - writer.addImport("smithy_aws_event_stream.aio", "AWSEventPublisher"); + writer.addImports("smithy_aws_event_stream.aio", Set.of("AWSEventPublisher", "SigningConfig")); writer.addImport("aws_sdk_signers", "AsyncEventSigner"); writer.write( """ # TODO - Move this out of the RestJSON generator - ctx = request_context - signer_properties = ctx.properties.get("signer_properties") # type: ignore - identity = ctx.properties.get("identity") # type: ignore - signature = ctx.properties.get("signature") # type: ignore - signer = AsyncEventSigner( - signing_properties=signer_properties, # type: ignore - identity=identity, # type: ignore - initial_signature=signature, # type: ignore - ) + ctx = request_context.properties + event_signer = ctx["auth_scheme"].event_signer(request=request_context.transport_request) codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS) publisher = AWSEventPublisher[Any]( payload_codec=codec, async_writer=request_context.transport_request.body, # type: ignore - signer=signer, # type: ignore + signing_config=SigningConfig( + signer=event_signer, + signing_properties=ctx["signing_properties"], + identity_resolver=ctx["identity_resolver"], + identity_properties=ctx["identity_properties"], + ) ) """); } diff --git a/designs/auth.md b/designs/auth.md new file mode 100644 index 000000000..3658a7987 --- /dev/null +++ b/designs/auth.md @@ -0,0 +1,178 @@ +# Identity and Authentication + +Smithy services may define any number of authentication schemes via traits and +configure which schemes are available and prioritized on a per-operation basis. +This document describes how an auth scheme is configured and picked at runtime. + +## Auth Schemes + +Everything to do with an auth scheme is contained within an implementation of +the `AuthScheme` Protocol. These implementations construct the +[identity resolvers](#identity-resolvers) and [signers](#signers) as well as the +extra properties needed for identity resolution and signing. + +Each `AuthScheme` has a `scheme_id`, which is the Smithy shape ID of the auth +scheme. + +```python +class AuthScheme[R: Request, I: Identity, IP: Mapping[str, Any], SP: Mapping[str, Any]]( + Protocol +): + scheme_id: ShapeID + + def identity_properties(self, *, context: _TypedProperties) -> IP: + ... + + def identity_resolver( + self, *, context: _TypedProperties + ) -> IdentityResolver[I, IP]: + ... + + def signer_properties(self, *, context: _TypedProperties) -> SP: + ... + + def signer(self) -> Signer[R, I, SP]: + ... + + def event_signer(self, *, request: R) -> EventSigner[I, SP] | None: + return None +``` + +`AuthScheme` implementations SHOULD cache identity resolvers and signers if +possible. + +### Auth Scheme Resolution + +Services and operation may support any number of auth schemes, each of which may +or may not be availble for a number of reasons, such as not being configured. An +`AuthSchemeResolver` is used to figure out which auth scheme to use for each +request. + +```python +class AuthSchemeResolver(Protocol): + def resolve_auth_scheme( + self, *, auth_parameters: AuthParams[Any, Any] + ) -> Sequence[AuthOption]: + ... + +class AuthOption(Protocol): + scheme_id: ShapeID + identity_properties: TypedProperties + signer_properties: TypedProperties + +@dataclass(kw_only=True, frozen=True) +class AuthParams[I: SerializeableShape, O: DeserializeableShape]: + protocol_id: ShapeID + operation: APIOperation[I, O] + context: TypedProperties +``` + +The resolver is given the ID of the protocol being used by the client, the +schema of the operation being invoked, and the operation invocation context. It +returns a priority-ordered list of auth schemes to pick from, along with +optional overrides for identity and signer properties. + +The client will pick the first auth scheme in the list that has an entry in the +`auth_schemes` [configuration](#configuration) dict and which is able to resolve +an identity. + +The resolver itself is stored in the service's [configuration](#configuration) +object, and may be replaced with a custom implemenatation. Default +implementations are generated based on the modeled auth traits. + +## Identity + +Each auth scheme is associated with an identity type, such as an API key or +username and password. In the AWS context, this is the access key id, secret +access key, and optionally the session token. + +Identities MAY be shared between multiple auth schemes. For example, the AWS +sigv4 and sigv4a auth schemes use the same AWS identity. + +In Python, each identity type MUST implement the following `Protocol`: + +```python +@runtime_checkable +class Identity(Protocol): + + expiration: datetime | None = None + + @property + def is_expired(self) -> bool: + if self.expiration is None: + return False + return datetime.now(tz=UTC) >= self.expiration +``` + +An `Identity` may be derived from any number of sources, such as configuration +properties or environement variables. These different sources are loaded by an +[`IdentityResolver`](#identity-resolvers). + +### Identity Resolvers + +Identity resolvers are responsible for contructiong an `Identity` for a request. + +```python +class IdentityResolver[I: Identity, IP: Mapping[str, Any]](Protocol): + + async def get_identity(self, *, properties: IP) -> I: + ... +``` + +Each identity source SHOULD have its own identity resolver implementation. If an +`Identity` is supported by multiple `IdentityResolver`s, those resolver SHOULD +be prioritized to provide a stable resolution strategy. A +`ChainedIdentityResolver` implementation is provided that implements this +behavior generically. + +The `get_identity` function takes only one (keyword-only) argument - a mapping +of properties that is refined by the `IP` generic parameter. The identity +properties are contructed by the `AuthScheme`'s `identity_properties` method. + +Identity resolvers are constructed by the `AuthScheme`'s `identity_resolver` +method. + +## Signers + +Signers are responsible for signing transport requests so that they can be +authenticated by the server. They are given the transport request to sign, the +resolved identity, and a property mapping that is used for any additional +configuration needed. The signing properties are constructed by the +`AuthScheme`'s `signer_properties` method. + +```python +class Signer[R: Request, I, SP: Mapping[str, Any]](Protocol): + async def sign(self, *, request: R, identity: I, properties: SP) -> R: + ... +``` + +Signers are constructed by the `AuthScheme`'s `signer` method. + +Signers MAY modify the given request and return it, or construct a new signed +request. + +### Event Signers + +Auh schemes MAY also have an associated event signer, which signs events that +are sent to a server. They behave in the same way as normal signers, except that +they sign an event instead of a transport request. The properties passed to this +signing method are identical to those pased to the request signer. + +```python +class EventSigner[I, SP: Mapping[str, Any]](Protocol): + + # TODO: add a protocol type for events + async def sign(self, *, event: Any, identity: I, properties: SP) -> Any: + ... +``` + +## Configuration + +All services with at least one auth trait will have the following properites on +their configuration object. + +```python +class AuthConfig[R: Request](Protocol): + auth_scheme_resolver: AuthSchemeResolver + auth_schemes: dict[ShapeID, AuthScheme[R, Any, Any, Any]] +``` diff --git a/packages/aws-sdk-signers/src/aws_sdk_signers/_identity.py b/packages/aws-sdk-signers/src/aws_sdk_signers/_identity.py index 4e32ea52c..2d0a194e2 100644 --- a/packages/aws-sdk-signers/src/aws_sdk_signers/_identity.py +++ b/packages/aws-sdk-signers/src/aws_sdk_signers/_identity.py @@ -2,21 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass -from datetime import UTC, datetime +from datetime import datetime -from .interfaces.identity import Identity +from .interfaces.identity import AWSCredentialsIdentity @dataclass(kw_only=True) -class AWSCredentialIdentity(Identity): +class AWSCredentialIdentity(AWSCredentialsIdentity): access_key_id: str secret_access_key: str session_token: str | None = None expiration: datetime | None = None - - @property - def is_expired(self) -> bool: - """Whether the identity is expired.""" - if self.expiration is None: - return False - return self.expiration < datetime.now(UTC) diff --git a/packages/aws-sdk-signers/src/aws_sdk_signers/interfaces/identity.py b/packages/aws-sdk-signers/src/aws_sdk_signers/interfaces/identity.py index d5599ef90..5c15e54e3 100644 --- a/packages/aws-sdk-signers/src/aws_sdk_signers/interfaces/identity.py +++ b/packages/aws-sdk-signers/src/aws_sdk_signers/interfaces/identity.py @@ -3,39 +3,39 @@ from __future__ import annotations -from datetime import datetime +from datetime import UTC, datetime from typing import Protocol, runtime_checkable +@runtime_checkable class Identity(Protocol): """An entity available to the client representing who the user is.""" - # The expiration time of the identity. If time zone is provided, - # it is updated to UTC. The value must always be in UTC. expiration: datetime | None = None + """The expiration time of the identity. + + If time zone is provided, it is updated to UTC. The value must always be in UTC. + """ @property def is_expired(self) -> bool: """Whether the identity is expired.""" - ... + if self.expiration is None: + return False + return datetime.now(tz=UTC) >= self.expiration @runtime_checkable -class AWSCredentialsIdentity(Protocol): +class AWSCredentialsIdentity(Identity, Protocol): """AWS Credentials Identity.""" - # The access key ID. access_key_id: str + """A unique identifier for an AWS user or role.""" - # The secret access key. secret_access_key: str + """A secret key used in conjunction with the access key ID to authenticate + programmatic access to AWS services.""" - # The session token. - session_token: str | None - - expiration: datetime | None = None - - @property - def is_expired(self) -> bool: - """Whether the identity is expired.""" - ... + session_token: str | None = None + """A temporary token used to specify the current session for the supplied + credentials.""" diff --git a/packages/aws-sdk-signers/src/aws_sdk_signers/signers.py b/packages/aws-sdk-signers/src/aws_sdk_signers/signers.py index 681edf2ef..a43343b14 100644 --- a/packages/aws-sdk-signers/src/aws_sdk_signers/signers.py +++ b/packages/aws-sdk-signers/src/aws_sdk_signers/signers.py @@ -15,7 +15,6 @@ from urllib.parse import parse_qsl, quote from ._http import AWSRequest, Field, URI -from ._identity import AWSCredentialIdentity from ._io import AsyncBytesReader from .exceptions import AWSSDKWarning, MissingExpectedParameterException from .interfaces.identity import AWSCredentialsIdentity as _AWSCredentialsIdentity @@ -55,27 +54,27 @@ class SigV4Signer: def sign( self, *, - signing_properties: SigV4SigningProperties, - http_request: AWSRequest, - identity: AWSCredentialIdentity, + request: AWSRequest, + identity: _AWSCredentialsIdentity, + properties: SigV4SigningProperties, ) -> AWSRequest: """Generate and apply a SigV4 Signature to a copy of the supplied request. - :param signing_properties: SigV4SigningProperties to define signing primitives - such as the target service, region, and date. - :param http_request: An AWSRequest to sign prior to sending to the service. + :param request: An AWSRequest to sign prior to sending to the service. :param identity: A set of credentials representing an AWS Identity or role capacity. + :param properties: SigV4SigningProperties to define signing primitives such as + the target service, region, and date. """ # Copy and prepopulate any missing values in the # supplied request and signing properties. self._validate_identity(identity=identity) new_signing_properties = self._normalize_signing_properties( - signing_properties=signing_properties + signing_properties=properties ) assert "date" in new_signing_properties - new_request = self._generate_new_request(request=http_request) + new_request = self._generate_new_request(request=request) self._apply_required_fields( request=new_request, signing_properties=new_signing_properties, @@ -164,7 +163,7 @@ def _signature( def _hash(self, key: bytes, value: str) -> bytes: return hmac.new(key=key, msg=value.encode(), digestmod=sha256).digest() - def _validate_identity(self, *, identity: AWSCredentialIdentity) -> None: + def _validate_identity(self, *, identity: _AWSCredentialsIdentity) -> None: """Perform runtime and expiration checks before attempting signing.""" if not isinstance(identity, _AWSCredentialsIdentity): # pyright: ignore raise ValueError( @@ -195,7 +194,7 @@ def _apply_required_fields( *, request: AWSRequest, signing_properties: SigV4SigningProperties, - identity: AWSCredentialIdentity, + identity: _AWSCredentialsIdentity, ) -> None: # Apply required X-Amz-Date if neither X-Amz-Date nor Date are present. if "Date" not in request.fields and "X-Amz-Date" not in request.fields: @@ -427,26 +426,26 @@ class AsyncSigV4Signer: async def sign( self, *, - signing_properties: SigV4SigningProperties, - http_request: AWSRequest, - identity: AWSCredentialIdentity, + request: AWSRequest, + identity: _AWSCredentialsIdentity, + properties: SigV4SigningProperties, ) -> AWSRequest: """Generate and apply a SigV4 Signature to a copy of the supplied request. - :param signing_properties: SigV4SigningProperties to define signing primitives - such as the target service, region, and date. - :param http_request: An AWSRequest to sign prior to sending to the service. + :param request: An AWSRequest to sign prior to sending to the service. :param identity: A set of credentials representing an AWS Identity or role capacity. + :param properties: SigV4SigningProperties to define signing primitives such as + the target service, region, and date. """ # Copy and prepopulate any missing values in the # supplied request and signing properties. await self._validate_identity(identity=identity) new_signing_properties = await self._normalize_signing_properties( - signing_properties=signing_properties + signing_properties=properties ) - new_request = await self._generate_new_request(request=http_request) + new_request = await self._generate_new_request(request=request) await self._apply_required_fields( request=new_request, signing_properties=new_signing_properties, @@ -455,7 +454,7 @@ async def sign( # Construct core signing components canonical_request = await self.canonical_request( - signing_properties=signing_properties, + signing_properties=properties, request=new_request, ) string_to_sign = await self.string_to_sign( @@ -535,7 +534,7 @@ async def _signature( async def _hash(self, key: bytes, value: str) -> bytes: return hmac.new(key=key, msg=value.encode(), digestmod=sha256).digest() - async def _validate_identity(self, *, identity: AWSCredentialIdentity) -> None: + async def _validate_identity(self, *, identity: _AWSCredentialsIdentity) -> None: """Perform runtime and expiration checks before attempting signing.""" if not isinstance(identity, _AWSCredentialsIdentity): # pyright: ignore raise ValueError( @@ -566,7 +565,7 @@ async def _apply_required_fields( *, request: AWSRequest, signing_properties: SigV4SigningProperties, - identity: AWSCredentialIdentity, + identity: _AWSCredentialsIdentity, ) -> None: # Apply required X-Amz-Date if neither X-Amz-Date nor Date are present. if "Date" not in request.fields and "X-Amz-Date" not in request.fields: @@ -804,26 +803,25 @@ class AsyncEventSigner: def __init__( self, *, - signing_properties: SigV4SigningProperties, - identity: AWSCredentialIdentity, initial_signature: bytes, + event_encoder_cls: type["EventHeaderEncoder"], ): - self._signing_properties = signing_properties - self._identity = identity self._prior_signature = initial_signature self._signing_lock = asyncio.Lock() + self._event_encoder_cls = event_encoder_cls - async def sign_event( + async def sign( self, *, - event_message: "EventMessage", - event_encoder_cls: type["EventHeaderEncoder"], + event: "EventMessage", + identity: _AWSCredentialsIdentity, + properties: SigV4SigningProperties, ) -> "EventMessage": async with self._signing_lock: # Copy and prepopulate any missing values in the # signing properties. new_signing_properties = SigV4SigningProperties( # type: ignore - **self._signing_properties + **properties ) # TODO: If date is in properties, parse a datetime from it. date_obj = datetime.datetime.now(datetime.UTC) @@ -834,11 +832,11 @@ async def sign_event( timestamp = new_signing_properties["date"] headers: dict[str, str | bytes | datetime.datetime] = {":date": date_obj} - encoder = event_encoder_cls() + encoder = self._event_encoder_cls() encoder.encode_headers(headers) encoded_headers = encoder.get_result() - payload = event_message.encode() + payload = event.encode() string_to_sign = await self._event_string_to_sign( timestamp=timestamp, @@ -848,19 +846,20 @@ async def sign_event( prior_signature=self._prior_signature, ) event_signature = await self._sign_event( + identity=identity, timestamp=timestamp, string_to_sign=string_to_sign, - signing_properties=new_signing_properties, + properties=new_signing_properties, ) headers[":chunk-signature"] = event_signature - event_message.headers = headers - event_message.payload = payload + event.headers = headers + event.payload = payload # set new prior signature before releasing the lock self._prior_signature = hexlify(event_signature) - return event_message + return event async def _event_string_to_sign( self, @@ -885,13 +884,14 @@ async def _sign_event( *, timestamp: str, string_to_sign: str, - signing_properties: SigV4SigningProperties, + identity: _AWSCredentialsIdentity, + properties: SigV4SigningProperties, ) -> bytes: - key = self._identity.secret_access_key.encode("utf-8") + key = identity.secret_access_key.encode("utf-8") today = timestamp[:8].encode("utf-8") k_date = self._hash(b"AWS4" + key, today) - k_region = self._hash(k_date, signing_properties["region"].encode("utf-8")) - k_service = self._hash(k_region, signing_properties["service"].encode("utf-8")) + k_region = self._hash(k_date, properties["region"].encode("utf-8")) + k_service = self._hash(k_region, properties["service"].encode("utf-8")) k_signing = self._hash(k_service, b"aws4_request") return self._hash(k_signing, string_to_sign.encode("utf-8")) diff --git a/packages/aws-sdk-signers/tests/unit/auth/test_sigv4.py b/packages/aws-sdk-signers/tests/unit/auth/test_sigv4.py index 5f596bf0f..73ccfc366 100644 --- a/packages/aws-sdk-signers/tests/unit/auth/test_sigv4.py +++ b/packages/aws-sdk-signers/tests/unit/auth/test_sigv4.py @@ -112,8 +112,8 @@ def _test_signature_version_4_sync(test_case_name: str, signer: SigV4Signer) -> assert test_case.string_to_sign == actual_string_to_sign with pytest.warns(AWSSDKWarning): signed_request = signer.sign( - signing_properties=signing_props, - http_request=request, + properties=signing_props, + request=request, identity=test_case.credentials, ) assert ( @@ -151,8 +151,8 @@ async def _test_signature_version_4_async( assert test_case.string_to_sign == actual_string_to_sign with pytest.warns(AWSSDKWarning): signed_request = await signer.sign( - signing_properties=signing_props, - http_request=request, + properties=signing_props, + request=request, identity=test_case.credentials, ) assert ( diff --git a/packages/aws-sdk-signers/tests/unit/test_signers.py b/packages/aws-sdk-signers/tests/unit/test_signers.py index 1c5e52789..54317493f 100644 --- a/packages/aws-sdk-signers/tests/unit/test_signers.py +++ b/packages/aws-sdk-signers/tests/unit/test_signers.py @@ -67,8 +67,8 @@ def test_sign( signing_properties: SigV4SigningProperties, ) -> None: signed_request = self.SIGV4_SYNC_SIGNER.sign( - signing_properties=signing_properties, - http_request=aws_request, + properties=signing_properties, + request=aws_request, identity=aws_identity, ) assert isinstance(signed_request, AWSRequest) @@ -85,8 +85,8 @@ def test_sign_doesnt_modify_original_request( ) -> None: original_request = copy.deepcopy(aws_request) signed_request = self.SIGV4_SYNC_SIGNER.sign( - signing_properties=signing_properties, - http_request=aws_request, + properties=signing_properties, + request=aws_request, identity=aws_identity, ) assert isinstance(signed_request, AWSRequest) @@ -103,8 +103,8 @@ def test_sign_with_invalid_identity( assert not isinstance(identity, AWSCredentialIdentity) with pytest.raises(ValueError): self.SIGV4_SYNC_SIGNER.sign( - signing_properties=signing_properties, - http_request=aws_request, + properties=signing_properties, + request=aws_request, identity=identity, ) @@ -119,8 +119,8 @@ def test_sign_with_expired_identity( ) with pytest.raises(ValueError): self.SIGV4_SYNC_SIGNER.sign( - signing_properties=signing_properties, - http_request=aws_request, + properties=signing_properties, + request=aws_request, identity=identity, ) @@ -135,8 +135,8 @@ async def test_sign( signing_properties: SigV4SigningProperties, ) -> None: signed_request = await self.SIGV4_ASYNC_SIGNER.sign( - signing_properties=signing_properties, - http_request=aws_request, + properties=signing_properties, + request=aws_request, identity=aws_identity, ) assert isinstance(signed_request, AWSRequest) @@ -153,8 +153,8 @@ async def test_sign_doesnt_modify_original_request( ) -> None: original_request = copy.deepcopy(aws_request) signed_request = await self.SIGV4_ASYNC_SIGNER.sign( - signing_properties=signing_properties, - http_request=aws_request, + properties=signing_properties, + request=aws_request, identity=aws_identity, ) assert isinstance(signed_request, AWSRequest) @@ -171,8 +171,8 @@ async def test_sign_with_invalid_identity( assert not isinstance(identity, AWSCredentialIdentity) with pytest.raises(ValueError): await self.SIGV4_ASYNC_SIGNER.sign( - signing_properties=signing_properties, - http_request=aws_request, + properties=signing_properties, + request=aws_request, identity=identity, ) @@ -187,7 +187,7 @@ async def test_sign_with_expired_identity( ) with pytest.raises(ValueError): await self.SIGV4_ASYNC_SIGNER.sign( - signing_properties=signing_properties, - http_request=aws_request, + properties=signing_properties, + request=aws_request, identity=identity, ) diff --git a/packages/smithy-aws-core/pyproject.toml b/packages/smithy-aws-core/pyproject.toml index dc9ef220e..30dfece40 100644 --- a/packages/smithy-aws-core/pyproject.toml +++ b/packages/smithy-aws-core/pyproject.toml @@ -14,6 +14,11 @@ dependencies = [ requires = ["hatchling"] build-backend = "hatchling.build" +[project.optional-dependencies] +eventstream = [ + "smithy-aws-event-stream" +] + [tool.hatch.build] exclude = [ "tests", diff --git a/packages/smithy-aws-core/src/smithy_aws_core/auth/sigv4.py b/packages/smithy-aws-core/src/smithy_aws_core/auth/sigv4.py index c206974cb..7442db058 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/auth/sigv4.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/auth/sigv4.py @@ -1,39 +1,62 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass -from typing import Protocol +import re +from typing import TYPE_CHECKING, Protocol, Self -from aws_sdk_signers import AsyncSigV4Signer, SigV4SigningProperties -from smithy_core.aio.interfaces.identity import IdentityResolver +from aws_sdk_signers import AsyncEventSigner, AsyncSigV4Signer, SigV4SigningProperties +from smithy_core.aio.interfaces.auth import AuthScheme, EventSigner, Signer from smithy_core.exceptions import SmithyIdentityError -from smithy_core.interfaces.identity import IdentityProperties -from smithy_http.aio.interfaces.auth import HTTPAuthScheme, HTTPSigner +from smithy_core.interfaces import TypedProperties as _TypedProperties +from smithy_core.types import PropertyKey +from smithy_http.aio.interfaces import HTTPRequest -from ..identity import AWSCredentialsIdentity +from ..identity import ( + AWS_IDENTITY_CONFIG, + AWSCredentialsIdentity, + AWSCredentialsResolver, + AWSIdentityProperties, +) +from ..traits import SigV4Trait + +if TYPE_CHECKING: + from smithy_aws_event_stream.events import EventHeaderEncoder + +try: + from smithy_aws_event_stream.events import EventHeaderEncoder + + HAS_EVENT_STREAM = True +except ImportError: + HAS_EVENT_STREAM = False # type: ignore class SigV4Config(Protocol): - aws_credentials_identity_resolver: ( - IdentityResolver[AWSCredentialsIdentity, IdentityProperties] | None - ) + region: str | None + aws_credentials_identity_resolver: AWSCredentialsResolver | None + + +SIGV4_CONFIG = PropertyKey(key="config", value_type=SigV4Config) + +type SigV4Signer = Signer[HTTPRequest, AWSCredentialsIdentity, SigV4SigningProperties] -@dataclass(init=False) class SigV4AuthScheme( - HTTPAuthScheme[ - AWSCredentialsIdentity, SigV4Config, IdentityProperties, SigV4SigningProperties + AuthScheme[ + HTTPRequest, + AWSCredentialsIdentity, + AWSIdentityProperties, + SigV4SigningProperties, ] ): """SigV4 AuthScheme.""" - scheme_id: str = "aws.auth#sigv4" - signer: HTTPSigner[AWSCredentialsIdentity, SigV4SigningProperties] + scheme_id = SigV4Trait.id + _signer: SigV4Signer def __init__( self, *, - signer: HTTPSigner[AWSCredentialsIdentity, SigV4SigningProperties] - | None = None, + service: str, + signer: SigV4Signer | None = None, ) -> None: """Constructor. @@ -41,14 +64,57 @@ def __init__( :param signer: The signer used to sign the request. """ # TODO: There are type mismatches in the signature of the "sign" method. - self.signer = signer or AsyncSigV4Signer() # type: ignore + # The issues seems to be that it's not using protocols in its signature + self._signer = signer or AsyncSigV4Signer() # type: ignore + self._service = service - def identity_resolver( - self, *, config: SigV4Config - ) -> IdentityResolver[AWSCredentialsIdentity, IdentityProperties]: - if not config.aws_credentials_identity_resolver: + def identity_properties( + self, *, context: _TypedProperties + ) -> AWSIdentityProperties: + config = context[AWS_IDENTITY_CONFIG] + return { + "access_key_id": config.access_key_id, + "secret_access_key": config.secret_access_key, + "session_token": config.session_token, + } + + def identity_resolver(self, *, context: _TypedProperties) -> AWSCredentialsResolver: + config = context.get(SIGV4_CONFIG) + if config is None or config.aws_credentials_identity_resolver is None: raise SmithyIdentityError( "Attempted to use SigV4 auth, but aws_credentials_identity_resolver was not " "set on the config." ) return config.aws_credentials_identity_resolver + + def signer_properties(self, *, context: _TypedProperties) -> SigV4SigningProperties: + config = context.get(SIGV4_CONFIG) + if config is None or config.region is None: + raise SmithyIdentityError( + "Attempted to use SigV4 auth, but region was not set on the config." + ) + return { + "region": config.region, + "service": self._service, + } + + def signer(self) -> SigV4Signer: + return self._signer + + def event_signer( + self, *, request: HTTPRequest + ) -> EventSigner[AWSCredentialsIdentity, SigV4SigningProperties] | None: + if not HAS_EVENT_STREAM: + return None + + auth_value = request.fields["Authorization"].as_string() + signature: str = re.split("Signature=", auth_value)[-1] + + return AsyncEventSigner( + initial_signature=signature.encode("utf-8"), + event_encoder_cls=EventHeaderEncoder, + ) + + @classmethod + def from_trait(cls, trait: SigV4Trait, /) -> Self: + return cls(service=trait.name) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py deleted file mode 100644 index 9d4ace5e6..000000000 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 -from .environment import EnvironmentCredentialsResolver -from .imds import IMDSCredentialsResolver -from .static import StaticCredentialsResolver - -__all__ = ( - "EnvironmentCredentialsResolver", - "IMDSCredentialsResolver", - "StaticCredentialsResolver", -) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/static.py b/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/static.py deleted file mode 100644 index 3111a3f8c..000000000 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/static.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 -from smithy_core.aio.interfaces.identity import IdentityResolver -from smithy_core.interfaces.identity import IdentityProperties - -from smithy_aws_core.identity import AWSCredentialsIdentity - - -class StaticCredentialsResolver( - IdentityResolver[AWSCredentialsIdentity, IdentityProperties] -): - """Resolve Static AWS Credentials.""" - - def __init__(self, *, credentials: AWSCredentialsIdentity) -> None: - self._credentials = credentials - - async def get_identity( - self, *, identity_properties: IdentityProperties - ) -> AWSCredentialsIdentity: - return self._credentials diff --git a/packages/smithy-aws-core/src/smithy_aws_core/identity.py b/packages/smithy-aws-core/src/smithy_aws_core/identity.py deleted file mode 100644 index 638322a7c..000000000 --- a/packages/smithy-aws-core/src/smithy_aws_core/identity.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. -from datetime import datetime - -from smithy_core.aio.interfaces.identity import IdentityResolver -from smithy_core.identity import Identity -from smithy_core.interfaces.identity import IdentityProperties - - -class AWSCredentialsIdentity(Identity): - """Container for AWS authentication credentials.""" - - def __init__( - self, - *, - access_key_id: str, - secret_access_key: str, - session_token: str | None = None, - expiration: datetime | None = None, - account_id: str | None = None, - ) -> None: - """Initialize the AWSCredentialIdentity. - - :param access_key_id: A unique identifier for an AWS user or role. - :param secret_access_key: A secret key used in conjunction with the access key - ID to authenticate programmatic access to AWS services. - :param session_token: A temporary token used to specify the current session for - the supplied credentials. - :param expiration: The expiration time of the identity. If time zone is - provided, it is updated to UTC. The value must always be in UTC. - :param account_id: The AWS account's ID. - """ - super().__init__(expiration=expiration) - self._access_key_id: str = access_key_id - self._secret_access_key: str = secret_access_key - self._session_token: str | None = session_token - self._account_id: str | None = account_id - - @property - def access_key_id(self) -> str: - return self._access_key_id - - @property - def secret_access_key(self) -> str: - return self._secret_access_key - - @property - def session_token(self) -> str | None: - return self._session_token - - @property - def account_id(self) -> str | None: - return self._account_id - - -type AWSCredentialsResolver = IdentityResolver[ - AWSCredentialsIdentity, IdentityProperties -] diff --git a/packages/smithy-aws-core/src/smithy_aws_core/identity/__init__.py b/packages/smithy-aws-core/src/smithy_aws_core/identity/__init__.py new file mode 100644 index 000000000..5d4e542da --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/identity/__init__.py @@ -0,0 +1,52 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from datetime import datetime +from typing import Protocol, TypedDict + +from smithy_core.aio.interfaces.identity import IdentityResolver +from smithy_core.interfaces.identity import Identity +from smithy_core.types import PropertyKey + + +@dataclass(kw_only=True) +class AWSCredentialsIdentity(Identity): + access_key_id: str + """A unique identifier for an AWS user or role.""" + + secret_access_key: str + """A secret key used in conjunction with the access key ID to authenticate + programmatic access to AWS services.""" + + session_token: str | None = None + """A temporary token used to specify the current session for the supplied + credentials.""" + + expiration: datetime | None = None + """The expiration time of the identity. + + If time zone is provided, it is updated to UTC. The value must always be in UTC. + """ + + account_id: str | None = None + """The AWS account's ID.""" + + +class AWSIdentityProperties(TypedDict, total=False): + access_key_id: str | None + secret_access_key: str | None + session_token: str | None + + +type AWSCredentialsResolver = IdentityResolver[ + AWSCredentialsIdentity, AWSIdentityProperties +] + + +class AWSIdentityConfig(Protocol): + access_key_id: str | None + secret_access_key: str | None + session_token: str | None = None + + +AWS_IDENTITY_CONFIG = PropertyKey(key="config", value_type=AWSIdentityConfig) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/identity/chain.py b/packages/smithy-aws-core/src/smithy_aws_core/identity/chain.py new file mode 100644 index 000000000..6241f6ce7 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/identity/chain.py @@ -0,0 +1,22 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from smithy_core.aio.identity import ChainedIdentityResolver +from smithy_http.aio.interfaces import HTTPClient + +from smithy_aws_core.identity import AWSCredentialsIdentity + +from . import AWSCredentialsResolver, AWSIdentityProperties +from .environment import EnvironmentCredentialsResolver +from .imds import IMDSCredentialsResolver +from .static import StaticCredentialsResolver + + +def create_default_chain(http_client: HTTPClient) -> AWSCredentialsResolver: + """Creates the default AWS credential provider chain.""" + return ChainedIdentityResolver[AWSCredentialsIdentity, AWSIdentityProperties]( + resolvers=( + StaticCredentialsResolver(), + EnvironmentCredentialsResolver(), + IMDSCredentialsResolver(http_client=http_client), + ) + ) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/environment.py b/packages/smithy-aws-core/src/smithy_aws_core/identity/environment.py similarity index 84% rename from packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/environment.py rename to packages/smithy-aws-core/src/smithy_aws_core/identity/environment.py index 08aa5fc36..7a9436c73 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/environment.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/identity/environment.py @@ -4,13 +4,12 @@ from smithy_core.aio.interfaces.identity import IdentityResolver from smithy_core.exceptions import SmithyIdentityError -from smithy_core.interfaces.identity import IdentityProperties -from ..identity import AWSCredentialsIdentity +from . import AWSCredentialsIdentity, AWSIdentityProperties class EnvironmentCredentialsResolver( - IdentityResolver[AWSCredentialsIdentity, IdentityProperties] + IdentityResolver[AWSCredentialsIdentity, AWSIdentityProperties] ): """Resolves AWS Credentials from system environment variables.""" @@ -18,7 +17,7 @@ def __init__(self): self._credentials = None async def get_identity( - self, *, identity_properties: IdentityProperties + self, *, properties: AWSIdentityProperties ) -> AWSCredentialsIdentity: if self._credentials is not None: return self._credentials diff --git a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py b/packages/smithy-aws-core/src/smithy_aws_core/identity/imds.py similarity index 97% rename from packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py rename to packages/smithy-aws-core/src/smithy_aws_core/identity/imds.py index a3295642a..6365656ad 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/credentials_resolvers/imds.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/identity/imds.py @@ -10,7 +10,6 @@ from smithy_core import URI from smithy_core.aio.interfaces.identity import IdentityResolver from smithy_core.exceptions import SmithyIdentityError -from smithy_core.interfaces.identity import IdentityProperties from smithy_core.interfaces.retries import RetryStrategy from smithy_core.retries import SimpleRetryStrategy from smithy_http import Field, Fields @@ -18,7 +17,7 @@ from smithy_http.aio.interfaces import HTTPClient from .. import __version__ -from ..identity import AWSCredentialsIdentity +from ..identity import AWSCredentialsIdentity, AWSIdentityProperties _USER_AGENT_FIELD = Field( name="User-Agent", @@ -181,7 +180,7 @@ async def get(self, *, path: str) -> str: class IMDSCredentialsResolver( - IdentityResolver[AWSCredentialsIdentity, IdentityProperties] + IdentityResolver[AWSCredentialsIdentity, AWSIdentityProperties] ): """Resolves AWS Credentials from an EC2 Instance Metadata Service (IMDS) client.""" @@ -196,7 +195,7 @@ def __init__(self, http_client: HTTPClient, config: Config | None = None): self._profile_name = self._config.ec2_instance_profile_name async def get_identity( - self, *, identity_properties: IdentityProperties + self, *, properties: AWSIdentityProperties ) -> AWSCredentialsIdentity: if ( self._credentials is not None diff --git a/packages/smithy-aws-core/src/smithy_aws_core/identity/static.py b/packages/smithy-aws-core/src/smithy_aws_core/identity/static.py new file mode 100644 index 000000000..988c95be9 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/identity/static.py @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from smithy_core.aio.interfaces.identity import IdentityResolver +from smithy_core.exceptions import SmithyIdentityError + +from smithy_aws_core.identity import AWSCredentialsIdentity, AWSIdentityProperties + + +class StaticCredentialsResolver( + IdentityResolver[AWSCredentialsIdentity, AWSIdentityProperties] +): + """Resolve Static AWS Credentials.""" + + async def get_identity( + self, *, properties: AWSIdentityProperties + ) -> AWSCredentialsIdentity: + access_key_id = properties.get("access_key_id") + secret_access_key = properties.get("secret_access_key") + if access_key_id is not None and secret_access_key is not None: + return AWSCredentialsIdentity( + access_key_id=access_key_id, + secret_access_key=secret_access_key, + session_token=properties.get("session_token"), + ) + raise SmithyIdentityError( + "Attempted to resolve AWS crendentials from config, but credentials weren't configured." + ) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/traits.py b/packages/smithy-aws-core/src/smithy_aws_core/traits.py index 1f5e0bca7..a95b85ba1 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/traits.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/traits.py @@ -43,3 +43,14 @@ def __init__(self, value: DocumentValue | DynamicTrait = None): object.__setattr__( self, "event_stream_http", tuple(event_stream_http_versions) ) + + +@dataclass(init=False, frozen=True) +class SigV4Trait(Trait, id=ShapeID("aws.auth#sigv4")): + def __post_init__(self): + assert isinstance(self.document_value, Mapping) + assert isinstance(self.document_value["name"], str) + + @property + def name(self) -> str: + return self.document_value["name"] # type: ignore diff --git a/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_environment_credentials_resolver.py b/packages/smithy-aws-core/tests/unit/identity/test_environment_credentials_resolver.py similarity index 68% rename from packages/smithy-aws-core/tests/unit/credentials_resolvers/test_environment_credentials_resolver.py rename to packages/smithy-aws-core/tests/unit/identity/test_environment_credentials_resolver.py index 982396b6e..b4c6a3f37 100644 --- a/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_environment_credentials_resolver.py +++ b/packages/smithy-aws-core/tests/unit/identity/test_environment_credentials_resolver.py @@ -2,52 +2,41 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from smithy_aws_core.credentials_resolvers import EnvironmentCredentialsResolver +from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver from smithy_core.exceptions import SmithyIdentityError -from smithy_core.interfaces.identity import IdentityProperties async def test_no_values_set(): with pytest.raises(SmithyIdentityError): - await EnvironmentCredentialsResolver().get_identity( - identity_properties=IdentityProperties() - ) + await EnvironmentCredentialsResolver().get_identity(properties={}) async def test_required_values_missing(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("AWS_ACCOUNT_ID", "123456789012") with pytest.raises(SmithyIdentityError): - await EnvironmentCredentialsResolver().get_identity( - identity_properties=IdentityProperties() - ) + await EnvironmentCredentialsResolver().get_identity(properties={}) async def test_akid_missing(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret") with pytest.raises(SmithyIdentityError): - await EnvironmentCredentialsResolver().get_identity( - identity_properties=IdentityProperties() - ) + await EnvironmentCredentialsResolver().get_identity(properties={}) async def test_secret_missing(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("AWS_ACCESS_KEY_ID", "akid") with pytest.raises(SmithyIdentityError): - await EnvironmentCredentialsResolver().get_identity( - identity_properties=IdentityProperties() - ) + await EnvironmentCredentialsResolver().get_identity(properties={}) async def test_minimum_required(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("AWS_ACCESS_KEY_ID", "akid") monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "secret") - credentials = await EnvironmentCredentialsResolver().get_identity( - identity_properties=IdentityProperties() - ) + credentials = await EnvironmentCredentialsResolver().get_identity(properties={}) assert credentials.access_key_id == "akid" assert credentials.secret_access_key == "secret" @@ -58,9 +47,7 @@ async def test_all_values(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("AWS_SESSION_TOKEN", "session") monkeypatch.setenv("AWS_ACCOUNT_ID", "123456789012") - credentials = await EnvironmentCredentialsResolver().get_identity( - identity_properties=IdentityProperties() - ) + credentials = await EnvironmentCredentialsResolver().get_identity(properties={}) assert credentials.access_key_id == "akid" assert credentials.secret_access_key == "secret" assert credentials.session_token == "session" diff --git a/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py b/packages/smithy-aws-core/tests/unit/identity/test_imds.py similarity index 97% rename from packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py rename to packages/smithy-aws-core/tests/unit/identity/test_imds.py index 864ca6a7d..81a1df95f 100644 --- a/packages/smithy-aws-core/tests/unit/credentials_resolvers/test_imds.py +++ b/packages/smithy-aws-core/tests/unit/identity/test_imds.py @@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from smithy_aws_core.credentials_resolvers.imds import ( +from smithy_aws_core.identity.imds import ( Config, EC2Metadata, IMDSCredentialsResolver, @@ -170,7 +170,7 @@ async def test_imds_credentials_resolver(): ), ] - credentials = await resolver.get_identity(identity_properties=MagicMock()) + credentials = await resolver.get_identity(properties={}) assert credentials.access_key_id == "test-access-key" assert credentials.secret_access_key == "test-secret-key" assert credentials.session_token == "test-session-token" diff --git a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py index 8d7a17195..611e6a114 100644 --- a/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py +++ b/packages/smithy-aws-event-stream/src/smithy_aws_event_stream/aio/__init__.py @@ -3,10 +3,13 @@ import asyncio import logging from collections.abc import Callable -from typing import Protocol +from dataclasses import dataclass +from typing import TYPE_CHECKING from smithy_core.aio.interfaces import AsyncByteStream, AsyncWriter +from smithy_core.aio.interfaces.auth import EventSigner from smithy_core.aio.interfaces.eventstream import EventPublisher, EventReceiver +from smithy_core.aio.interfaces.identity import IdentityResolver from smithy_core.codecs import Codec from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer from smithy_core.exceptions import ExpectationNotMetError @@ -14,21 +17,23 @@ from .._private.deserializers import EventDeserializer as _EventDeserializer from .._private.serializers import EventSerializer as _EventSerializer -from ..events import Event, EventHeaderEncoder, EventMessage +from ..events import Event from ..exceptions import EventError logger = logging.getLogger(__name__) -class EventSigner(Protocol): - """A signer to manage credentials and EventMessages for an Event Stream lifecyle.""" +if TYPE_CHECKING: + from aws_sdk_signers import SigV4SigningProperties + from smithy_aws_core.identity import AWSCredentialsIdentity, AWSIdentityProperties - async def sign_event( - self, - *, - event_message: EventMessage, - event_encoder_cls: type[EventHeaderEncoder], - ) -> EventMessage: ... + +@dataclass +class SigningConfig: + signer: "EventSigner[AWSCredentialsIdentity, SigV4SigningProperties]" + signing_properties: "SigV4SigningProperties" + identity_resolver: "IdentityResolver[AWSCredentialsIdentity, AWSIdentityProperties]" + identity_properties: "AWSIdentityProperties" class AWSEventPublisher[E: SerializeableShape](EventPublisher[E]): @@ -36,11 +41,11 @@ def __init__( self, payload_codec: Codec, async_writer: AsyncWriter, - signer: EventSigner | None = None, + signing_config: SigningConfig | None = None, is_client_mode: bool = True, ): self._writer = async_writer - self._signer = signer + self._signing_config = signing_config self._serializer = _EventSerializer( payload_codec=payload_codec, is_client_mode=is_client_mode ) @@ -56,11 +61,15 @@ async def send(self, event: E) -> None: raise ExpectationNotMetError( "Expected an event message to be serialized, but was None." ) - if self._signer is not None: - encoder = self._serializer.event_header_encoder_cls - result = await self._signer.sign_event( - event_message=result, - event_encoder_cls=encoder, + + if self._signing_config is not None: + identity = await self._signing_config.identity_resolver.get_identity( + properties=self._signing_config.identity_properties + ) + result = await self._signing_config.signer.sign( + event=event, + identity=identity, + properties=self._signing_config.signing_properties, ) encoded_result = result.encode() diff --git a/packages/smithy-core/src/smithy_core/aio/identity.py b/packages/smithy-core/src/smithy_core/aio/identity.py new file mode 100644 index 000000000..62c55abe4 --- /dev/null +++ b/packages/smithy-core/src/smithy_core/aio/identity.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import logging +from collections.abc import Mapping, Sequence +from typing import Any, Final + +from ..exceptions import SmithyIdentityError +from ..interfaces.identity import Identity +from .interfaces.identity import IdentityResolver + +logger: Final = logging.getLogger(__name__) + + +# TODO: turn this into a decorator +class CachingIdentityResolver[I: Identity, IP: Mapping[str, Any]]( + IdentityResolver[I, IP] +): + def __init__(self) -> None: + self._cached: I | None = None + + async def get_identity(self, *, properties: IP) -> I: + if self._cached is None or self._cached.is_expired: + self._cached = await self._get_identity(properties=properties) + return self._cached + + async def _get_identity(self, *, properties: IP) -> I: + raise NotImplementedError + + +class ChainedIdentityResolver[I: Identity, IP: Mapping[str, Any]]( + CachingIdentityResolver[I, IP] +): + """Attempts to resolve an identity by checking a sequence of sub-resolvers. + + If a nested resolver raises a :py:class:`SmithyIdentityError`, the next + resolver in the chain will be attempted. + """ + + def __init__(self, resolvers: Sequence[IdentityResolver[I, IP]]) -> None: + """Construct a ChainedIdentityResolver. + + :param resolvers: The sequence of resolvers to resolve identity from. + """ + super().__init__() + self._resolvers = resolvers + + async def _get_identity(self, *, properties: IP) -> I: + logger.debug("Attempting to resolve identity from resolver chain.") + for resolver in self._resolvers: + try: + logger.debug("Attempting to resolve identity from %s.", type(resolver)) + return await resolver.get_identity(properties=properties) + except SmithyIdentityError as e: + logger.debug( + "Failed to resolve identity from %s: %s", type(resolver), e + ) + + raise SmithyIdentityError("Failed to resolve identity from resolver chain.") diff --git a/packages/smithy-core/src/smithy_core/aio/interfaces/auth.py b/packages/smithy-core/src/smithy_core/aio/interfaces/auth.py new file mode 100644 index 000000000..9217b8891 --- /dev/null +++ b/packages/smithy-core/src/smithy_core/aio/interfaces/auth.py @@ -0,0 +1,89 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Mapping +from typing import Any, Protocol + +from ...interfaces import TypedProperties as _TypedProperties +from ...interfaces.identity import Identity +from ...shapes import ShapeID +from . import Request +from .identity import IdentityResolver + + +class Signer[R: Request, I, SP: Mapping[str, Any]](Protocol): + """A class that signs requests before they are sent.""" + + async def sign(self, *, request: R, identity: I, properties: SP) -> R: + """Get a signed version of the request. + + :param request: The request to be signed. + :param identity: The identity to use to sign the request. + :param properties: Additional properties used to sign the request. + """ + ... + + +class EventSigner[I, SP: Mapping[str, Any]](Protocol): + """A class that signs requests before they are sent.""" + + # TODO: add a protocol type for events + async def sign(self, *, event: Any, identity: I, properties: SP) -> Any: + """Get a signed version of the event. + + :param event: The event to be signed. + """ + ... + + +class AuthScheme[R: Request, I: Identity, IP: Mapping[str, Any], SP: Mapping[str, Any]]( + Protocol +): + """A class that coordinates identity and auth.""" + + scheme_id: ShapeID + """The ID of the auth scheme.""" + + def identity_properties(self, *, context: _TypedProperties) -> IP: + """Construct identity properties from the request context. + + The context will always include the client's config under "config". Other + properties may be added by :py:class:`smithy_core.interceptors.Interceptor`s. + + :param context: The context of the request. + """ + ... + + def identity_resolver( + self, *, context: _TypedProperties + ) -> IdentityResolver[I, IP]: + """Get an identity resolver for the request. + + The context will always include the client's config under "config". Other + properties may be added by :py:class:`smithy_core.interceptors.Interceptor`s. + + :param context: The context of the request. + """ + ... + + def signer_properties(self, *, context: _TypedProperties) -> SP: + """Construct signer properties from the request context. + + The context will always include the client's config under "config". Other + properties may be added by :py:class:`smithy_core.interceptors.Interceptor`s. + + :param context: The context of the request. + """ + ... + + def signer(self) -> Signer[R, I, SP]: + """Get a signer for the request.""" + ... + + def event_signer(self, *, request: R) -> EventSigner[I, SP] | None: + """Construct a signer for event stream events. + + :param request: The request that will initiate the event stream. The request + will not have been sent when this method is called. + :returns: An event signer if the scheme supports signing events, otherwise None. + """ + return None diff --git a/packages/smithy-core/src/smithy_core/aio/interfaces/identity.py b/packages/smithy-core/src/smithy_core/aio/interfaces/identity.py index 1b5282066..56070b062 100644 --- a/packages/smithy-core/src/smithy_core/aio/interfaces/identity.py +++ b/packages/smithy-core/src/smithy_core/aio/interfaces/identity.py @@ -1,22 +1,20 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Protocol +from collections.abc import Mapping +from typing import Any, Protocol -from ...interfaces.identity import IdentityPropertiesType_contra, IdentityType_cov +from ...interfaces.identity import Identity -class IdentityResolver(Protocol[IdentityType_cov, IdentityPropertiesType_contra]): +class IdentityResolver[I: Identity, IP: Mapping[str, Any]](Protocol): """Used to load a user's `Identity` from a given source. Each `Identity` may have one or more resolver implementations. """ - async def get_identity( - self, *, identity_properties: IdentityPropertiesType_contra - ) -> IdentityType_cov: + async def get_identity(self, *, properties: IP) -> I: """Load the user's identity from this resolver. - :param identity_properties: Properties used to help determine the identity to - return. + :param properties: Properties used to help determine the identity to return. """ ... diff --git a/packages/smithy-core/src/smithy_core/auth.py b/packages/smithy-core/src/smithy_core/auth.py new file mode 100644 index 000000000..b409543de --- /dev/null +++ b/packages/smithy-core/src/smithy_core/auth.py @@ -0,0 +1,66 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence +from dataclasses import dataclass, field +from typing import Any + +from .deserializers import DeserializeableShape +from .interfaces import TypedProperties as _TypedProperties +from .schemas import APIOperation +from .serializers import SerializeableShape +from .shapes import ShapeID +from .types import TypedProperties + + +@dataclass(kw_only=True, frozen=True) +class AuthParams[I: SerializeableShape, O: DeserializeableShape]: + """Parameters passed to an AuthSchemeResolver's ``resolve_auth_scheme`` method.""" + + protocol_id: ShapeID + """The ID of the protocol being used for the operation invocation.""" + + operation: APIOperation[I, O] + """The schema and associated information about the operation being invoked.""" + + context: _TypedProperties + """The context of the operation invocation.""" + + +@dataclass(kw_only=True) +class AuthOption: + """Auth scheme used for signing and identity resolution.""" + + scheme_id: ShapeID + """The ID of the auth scheme to use.""" + + identity_properties: _TypedProperties = field(default_factory=TypedProperties) + """Paramters to pass to the identity resolver method.""" + + signer_properties: _TypedProperties = field(default_factory=TypedProperties) + """Paramters to pass to the signing method.""" + + +class DefaultAuthResolver: + """Determines which authentication scheme to use based on modeled auth schemes.""" + + def resolve_auth_scheme( + self, *, auth_parameters: AuthParams[Any, Any] + ) -> Sequence[AuthOption]: + """Resolve an ordered list of applicable auth schemes. + + :param auth_parameters: The parameters required for determining which + authentication schemes to potentially use. + """ + return [ + AuthOption(scheme_id=id) + for id in auth_parameters.operation.effective_auth_schemes + ] + + +class NoAuthResolver: + """Auth resolver that always returns no auth scheme options.""" + + def resolve_auth_scheme( + self, *, auth_parameters: AuthParams[Any, Any] + ) -> Sequence[AuthOption]: + return [] diff --git a/packages/smithy-core/src/smithy_core/identity.py b/packages/smithy-core/src/smithy_core/identity.py deleted file mode 100644 index ed570bffd..000000000 --- a/packages/smithy-core/src/smithy_core/identity.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 -from datetime import UTC, datetime - -from .interfaces import identity as identity_interface -from .utils import ensure_utc - - -class Identity(identity_interface.Identity): - """An entity available to the client representing who the user is.""" - - def __init__( - self, - *, - expiration: datetime | None = None, - ) -> None: - """Initialize an identity. - - :param expiration: The expiration time of the identity. If time zone is - provided, it is updated to UTC. The value must always be in UTC. - """ - if expiration is not None: - expiration = ensure_utc(expiration) - self.expiration: datetime | None = expiration - - @property - def is_expired(self) -> bool: - """Whether the identity is expired.""" - if self.expiration is None: - return False - return datetime.now(tz=UTC) >= self.expiration diff --git a/packages/smithy-core/src/smithy_core/interfaces/auth.py b/packages/smithy-core/src/smithy_core/interfaces/auth.py new file mode 100644 index 000000000..df2c95370 --- /dev/null +++ b/packages/smithy-core/src/smithy_core/interfaces/auth.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Protocol + +from ..shapes import ShapeID +from . import TypedProperties + +if TYPE_CHECKING: + from ..auth import AuthParams + + +class AuthOption(Protocol): + """Auth scheme used for signing and identity resolution.""" + + scheme_id: ShapeID + """The ID of the auth scheme to use.""" + + identity_properties: TypedProperties + """Paramters to pass to the identity resolver method.""" + + signer_properties: TypedProperties + """Paramters to pass to the signing method.""" + + +class AuthSchemeResolver(Protocol): + """Determines which authentication scheme to use for a given service.""" + + def resolve_auth_scheme( + self, *, auth_parameters: "AuthParams[Any, Any]" + ) -> Sequence[AuthOption]: + """Resolve an ordered list of applicable auth schemes. + + :param auth_parameters: The parameters required for determining which + authentication schemes to potentially use. + """ + ... diff --git a/packages/smithy-core/src/smithy_core/interfaces/identity.py b/packages/smithy-core/src/smithy_core/interfaces/identity.py index e34a8f976..0e950da8f 100644 --- a/packages/smithy-core/src/smithy_core/interfaces/identity.py +++ b/packages/smithy-core/src/smithy_core/interfaces/identity.py @@ -1,35 +1,22 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from datetime import datetime -from typing import Protocol, TypedDict, TypeVar, runtime_checkable +from datetime import UTC, datetime +from typing import Protocol, runtime_checkable @runtime_checkable class Identity(Protocol): """An entity available to the client representing who the user is.""" - # The expiration time of the identity. If time zone is provided, - # it is updated to UTC. The value must always be in UTC. expiration: datetime | None = None + """The expiration time of the identity. + + If time zone is provided, it is updated to UTC. The value must always be in UTC. + """ @property def is_expired(self) -> bool: """Whether the identity is expired.""" - ... - - -IdentityType = TypeVar("IdentityType", bound=Identity) -IdentityType_contra = TypeVar("IdentityType_contra", bound=Identity, contravariant=True) -IdentityType_cov = TypeVar("IdentityType_cov", bound=Identity, covariant=True) - - -class IdentityProperties(TypedDict): - """Properties used to help determine the identity to return.""" - - -IdentityPropertiesType = TypeVar("IdentityPropertiesType", bound=IdentityProperties) -IdentityPropertiesType_contra = TypeVar( - "IdentityPropertiesType_contra", bound=IdentityProperties, contravariant=True -) - -IdentityConfig_contra = TypeVar("IdentityConfig_contra", contravariant=True) + if self.expiration is None: + return False + return datetime.now(tz=UTC) >= self.expiration diff --git a/packages/smithy-core/src/smithy_core/traits.py b/packages/smithy-core/src/smithy_core/traits.py index d7f6ad26c..4be81aba1 100644 --- a/packages/smithy-core/src/smithy_core/traits.py +++ b/packages/smithy-core/src/smithy_core/traits.py @@ -314,3 +314,30 @@ def host_prefix(self) -> str: class HostLabelTrait(Trait, id=ShapeID("smithy.api#hostLabel")): def __post_init__(self): assert self.document_value is None + + +class APIKeyLocation(Enum): + """The locations that the api key could be placed in the signed request.""" + + HEADER = "header" + QUERY = "query" + + +@dataclass(init=False, frozen=True) +class HTTPAPIKeyAuthTrait(Trait, id=ShapeID("smithy.api#httpApiKeyAuth")): + location: APIKeyLocation = field(repr=False, hash=False, compare=False) + + def __init__(self, value: "DocumentValue | DynamicTrait" = None): + super().__init__(value) + object.__setattr__(self, "location", APIKeyLocation(value)) + assert isinstance(self.document_value, Mapping) + assert isinstance(self.document_value["name"], str) + assert isinstance(self.document_value.get("scheme"), str | None) + + @property + def name(self) -> str: + return self.document_value["name"] # type: ignore + + @property + def scheme(self) -> str | None: + return self.document_value.get("scheme") # type: ignore diff --git a/packages/smithy-core/tests/unit/test_identity.py b/packages/smithy-core/tests/unit/test_identity.py index 7a61f4df2..3550a306f 100644 --- a/packages/smithy-core/tests/unit/test_identity.py +++ b/packages/smithy-core/tests/unit/test_identity.py @@ -1,10 +1,16 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass from datetime import UTC, datetime, timedelta, timezone import pytest from freezegun import freeze_time -from smithy_core.identity import Identity +from smithy_core.interfaces.identity import Identity + + +@dataclass(kw_only=True) +class EmptyIdentity(Identity): + expiration: datetime | None = None @pytest.mark.parametrize( @@ -21,7 +27,7 @@ ) def test_expiration_timezone(time_zone: timezone) -> None: expiration = datetime.now(tz=time_zone) - identity = Identity(expiration=expiration) + identity = EmptyIdentity(expiration=expiration) assert identity.expiration is not None assert identity.expiration.tzinfo == UTC @@ -30,20 +36,20 @@ def test_expiration_timezone(time_zone: timezone) -> None: "identity, expected_expired", [ ( - Identity( + EmptyIdentity( expiration=datetime(year=2023, month=1, day=1, tzinfo=UTC), ), True, ), - (Identity(), False), + (EmptyIdentity(), False), ( - Identity( + EmptyIdentity( expiration=datetime(year=2023, month=1, day=2, tzinfo=UTC), ), False, ), ( - Identity( + EmptyIdentity( expiration=datetime(year=2022, month=12, day=31, tzinfo=UTC), ), True, diff --git a/packages/smithy-http/src/smithy_http/aio/auth/apikey.py b/packages/smithy-http/src/smithy_http/aio/auth/apikey.py index 2861357d5..a1bf6a4e1 100644 --- a/packages/smithy-http/src/smithy_http/aio/auth/apikey.py +++ b/packages/smithy-http/src/smithy_http/aio/auth/apikey.py @@ -1,117 +1,122 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass -from enum import Enum -from typing import NotRequired, Protocol, TypedDict +from typing import Any, Protocol, Self from smithy_core import URI +from smithy_core.aio.interfaces.auth import AuthScheme, Signer from smithy_core.aio.interfaces.identity import IdentityResolver from smithy_core.exceptions import SmithyIdentityError -from smithy_core.interfaces.identity import IdentityProperties +from smithy_core.interfaces import TypedProperties as _TypedProperties +from smithy_core.traits import APIKeyLocation, HTTPAPIKeyAuthTrait +from smithy_core.types import PropertyKey from ... import Field -from ..identity.apikey import ApiKeyIdentity +from ..identity.apikey import APIKeyIdentity, APIKeyIdentityProperties from ..interfaces import HTTPRequest -from ..interfaces.auth import HTTPAuthScheme, HTTPSigner -class ApiKeyLocation(Enum): - """The locations that the api key could be placed in the signed request.""" +class APIKeyResolverConfig(Protocol): + """A config bearing API key properties.""" - HEADER = "header" - QUERY = "query" + api_key: str | None + """An explicit API key. + If not set, it MAY be retrieved from elsewhere by the resolver. + """ -class ApiKeySigningProperties(TypedDict): - """The properties needed to sign a request with api key auth. + api_key_identity_resolver: ( + IdentityResolver[APIKeyIdentity, APIKeyIdentityProperties] | None + ) + """An API key identity resolver. - seealso:: The `Smithy API Key auth trait docs `_ - , which have more details on these properties, including examples. + The default implementation only checks the explicitly configured key. """ - name: str - """The name of the HTTP header or query string parameter containing the key.""" - scheme: NotRequired[str] - """The :rfc:`9110#section-11.4` scheme to prefix a header value with.""" +API_KEY_RESOLVER_CONFIG = PropertyKey(key="config", value_type=APIKeyResolverConfig) +"""A context property bearing an API key config.""" - location: ApiKeyLocation - """Where the key is serialized.""" +class APIKeySigner(Signer[HTTPRequest, APIKeyIdentity, Any]): + """A signer that signs http requests with an api key.""" -class ApiKeyConfig(Protocol): - api_key_identity_resolver: ( - IdentityResolver[ApiKeyIdentity, IdentityProperties] | None - ) + def __init__( + self, *, name: str, location: APIKeyLocation, scheme: str | None = None + ) -> None: + self._name = name + self._location = location + self._scheme = scheme + async def sign( + self, + *, + request: HTTPRequest, + identity: APIKeyIdentity, + properties: Any, + ) -> HTTPRequest: + match self._location: + case APIKeyLocation.QUERY: + query = request.destination.query or "" + if query: + query += "&" + query += f"{self._name}={identity.api_key}" + request.destination = URI( + scheme=request.destination.scheme, + username=request.destination.username, + password=request.destination.password, + host=request.destination.host, + port=request.destination.port, + path=request.destination.password, + query=query, + fragment=request.destination.fragment, + ) + case APIKeyLocation.HEADER: + value = identity.api_key + if self._scheme is not None: + value = f"{self._scheme} {value}" + request.fields.set_field(Field(name=self._name, values=[value])) -@dataclass(init=False) -class ApiKeyAuthScheme( - HTTPAuthScheme[ - ApiKeyIdentity, ApiKeyConfig, IdentityProperties, ApiKeySigningProperties - ] + return request + + +class APIKeyAuthScheme( + AuthScheme[HTTPRequest, APIKeyIdentity, APIKeyIdentityProperties, Any] ): """An auth scheme containing necessary data and tools for api key auth.""" - scheme_id: str - signer: HTTPSigner[ApiKeyIdentity, ApiKeySigningProperties] + scheme_id = HTTPAPIKeyAuthTrait.id + _signer: APIKeySigner def __init__( - self, - *, - signer: HTTPSigner[ApiKeyIdentity, ApiKeySigningProperties] | None = None, + self, *, name: str, location: APIKeyLocation, scheme: str | None = None ) -> None: - """Constructor. + self._signer = APIKeySigner(name=name, location=location, scheme=scheme) - :param identity_resolver: The identity resolver to extract the api key identity. - :param signer: The signer used to sign the request. - """ - self.scheme_id = "smithy.api#httpApiKeyAuth" - self.signer = signer or ApiKeySigner() + def identity_properties( + self, *, context: _TypedProperties + ) -> APIKeyIdentityProperties: + config = context.get(API_KEY_RESOLVER_CONFIG) + if config is not None and config.api_key is not None: + return {"api_key": config.api_key} + return {} def identity_resolver( - self, *, config: ApiKeyConfig - ) -> IdentityResolver[ApiKeyIdentity, IdentityProperties]: - if not config.api_key_identity_resolver: + self, *, context: _TypedProperties + ) -> IdentityResolver[APIKeyIdentity, APIKeyIdentityProperties]: + config = context.get(API_KEY_RESOLVER_CONFIG) + if config is None or config.api_key_identity_resolver is None: raise SmithyIdentityError( "Attempted to use API key auth, but api_key_identity_resolver was not " "set on the config." ) return config.api_key_identity_resolver + def signer_properties(self, *, context: _TypedProperties) -> Any: + return {} -class ApiKeySigner(HTTPSigner[ApiKeyIdentity, ApiKeySigningProperties]): - """A signer that signs http requests with an api key.""" - - async def sign( - self, - *, - http_request: HTTPRequest, - identity: ApiKeyIdentity, - signing_properties: ApiKeySigningProperties, - ) -> HTTPRequest: - match signing_properties["location"]: - case ApiKeyLocation.QUERY: - query = http_request.destination.query or "" - if query: - query += "&" - query += f"{signing_properties['name']}={identity.api_key}" - http_request.destination = URI( - scheme=http_request.destination.scheme, - username=http_request.destination.username, - password=http_request.destination.password, - host=http_request.destination.host, - port=http_request.destination.port, - path=http_request.destination.password, - query=query, - fragment=http_request.destination.fragment, - ) - case ApiKeyLocation.HEADER: - value = identity.api_key - if (scheme := signing_properties.get("scheme", None)) is not None: - value = f"{scheme} {value}" - http_request.fields.set_field( - Field(name=signing_properties["name"], values=[value]) - ) + def signer(self) -> Signer[HTTPRequest, APIKeyIdentity, Any]: + return self._signer - return http_request + @classmethod + def from_trait(cls, trait: HTTPAPIKeyAuthTrait, /) -> Self: + return cls(name=trait.name, location=trait.location, scheme=trait.scheme) diff --git a/packages/smithy-http/src/smithy_http/aio/identity/apikey.py b/packages/smithy-http/src/smithy_http/aio/identity/apikey.py index e225725d3..21d3bc375 100644 --- a/packages/smithy-http/src/smithy_http/aio/identity/apikey.py +++ b/packages/smithy-http/src/smithy_http/aio/identity/apikey.py @@ -1,38 +1,38 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from datetime import datetime +from typing import TypedDict + from smithy_core.aio.interfaces.identity import IdentityResolver -from smithy_core.identity import Identity -from smithy_core.interfaces.identity import IdentityProperties +from smithy_core.exceptions import SmithyIdentityError +from smithy_core.interfaces.identity import Identity -class ApiKeyIdentity(Identity): +@dataclass(kw_only=True) +class APIKeyIdentity(Identity): """The identity for auth that uses an api key.""" - def __init__(self, *, api_key: str) -> None: - super().__init__(expiration=None) - self.api_key = api_key + api_key: str + """The API Key to add to requests.""" + + expiration: datetime | None = None + +class APIKeyIdentityProperties(TypedDict, total=False): + api_key: str -class ApiKeyIdentityResolver(IdentityResolver[ApiKeyIdentity, IdentityProperties]): - """Loads the api key identity from the configuration.""" - def __init__(self, *, api_key: str | ApiKeyIdentity) -> None: - """ - :param api_key: The API key to authenticate with. - """ - match api_key: - case str(): - self._identity = ApiKeyIdentity(api_key=api_key) - case ApiKeyIdentity(): - self._identity = api_key +class APIKeyIdentityResolver( + IdentityResolver[APIKeyIdentity, APIKeyIdentityProperties] +): + """Loads the API key identity from the configuration.""" async def get_identity( - self, *, identity_properties: IdentityProperties - ) -> ApiKeyIdentity: - """Load the user's api key identity from this resolver. - - :param identity_properties: Properties used to help determine the identity to - return. - :returns: The api key identity. - """ - return self._identity + self, *, properties: APIKeyIdentityProperties + ) -> APIKeyIdentity: + if (api_key := properties.get("api_key")) is not None: + return APIKeyIdentity(api_key=api_key) + raise SmithyIdentityError( + "Attempted to use API key auth, but api_key was not set on the config." + ) diff --git a/packages/smithy-http/src/smithy_http/aio/interfaces/auth.py b/packages/smithy-http/src/smithy_http/aio/interfaces/auth.py deleted file mode 100644 index 0a40a8dad..000000000 --- a/packages/smithy-http/src/smithy_http/aio/interfaces/auth.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). You -# may not use this file except in compliance with the License. A copy of -# the License is located at -# -# http://aws.amazon.com/apache2.0/ -# -# or in the "license" file accompanying this file. This file is -# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF -# ANY KIND, either express or implied. See the License for the specific -# language governing permissions and limitations under the License. - -from dataclasses import dataclass -from typing import Any, Protocol, TypedDict, TypeVar - -from smithy_core.aio.interfaces.identity import IdentityResolver -from smithy_core.interfaces.identity import Identity, IdentityProperties - -from . import HTTPRequest - - -class SigningProperties(TypedDict): - """Additional properties loaded to modify the signing process.""" - - -SigningPropertiesType = TypeVar("SigningPropertiesType", bound=SigningProperties) -SigningPropertiesType_contra = TypeVar( - "SigningPropertiesType_contra", bound=SigningProperties, contravariant=True -) - - -class HTTPSigner[I: Identity, SP: SigningProperties](Protocol): - """An interface for generating a signed HTTP request.""" - - async def sign( - self, - *, - http_request: HTTPRequest, - identity: I, - signing_properties: SP, - ) -> HTTPRequest: - """Generate a new signed HTTPRequest based on the one provided. - - :param http_request: The HTTP request to sign. - :param identity: The signing identity. - :param signing_properties: Additional properties loaded to modify the signing - process. - """ - ... - - -class HTTPAuthScheme[I: Identity, C, IP: IdentityProperties, SP: SigningProperties]( - Protocol -): - """Represents a way a service will authenticate the user's identity.""" - - # A unique identifier for the authentication scheme. - scheme_id: str - - # An API that can be used to sign HTTP requests. - signer: HTTPSigner[I, SP] - - def identity_resolver(self, *, config: C) -> IdentityResolver[I, IP]: - """An API that can be queried to resolve identity.""" - ... - - -@dataclass(kw_only=True) -class HTTPAuthOption: - """Auth scheme used for signing and identity resolution.""" - - # The ID of the scheme to use. This string matches the one returned by - # HttpAuthScheme.scheme_id - scheme_id: str - - # Parameters to pass to IdentityResolver.get_identity. - identity_properties: dict[str, Any] - - # Parameters to pass to HttpSigner.sign. - signer_properties: dict[str, Any] - - -@dataclass(kw_only=True) -class AuthSchemeParameters: - """The input to the auth scheme resolver. - - A code-generated interface for passing in the data required for determining the - authentication scheme. By default, this only includes the operation name. - """ - - # The service operation being invoked by the client. - operation: str - - -class AuthSchemeResolver(Protocol): - """Determines which authentication scheme to use for a given service.""" - - def resolve_auth_scheme( - self, *, auth_parameters: AuthSchemeParameters - ) -> list[HTTPAuthOption]: - """Resolve an ordered list of applicable auth schemes. - - :param auth_parameters: The parameters required for determining which - authentication schemes to potentially use. - """ - ... diff --git a/packages/smithy-http/tests/unit/aio/auth/test_apikey.py b/packages/smithy-http/tests/unit/aio/auth/test_apikey.py index a489ab329..e80a75365 100644 --- a/packages/smithy-http/tests/unit/aio/auth/test_apikey.py +++ b/packages/smithy-http/tests/unit/aio/auth/test_apikey.py @@ -7,21 +7,16 @@ from smithy_core import URI from smithy_core.aio.interfaces.identity import IdentityResolver from smithy_core.exceptions import SmithyIdentityError -from smithy_core.interfaces.identity import IdentityProperties +from smithy_core.types import TypedProperties from smithy_http import Field, Fields from smithy_http.aio import HTTPRequest from smithy_http.aio.auth.apikey import ( - ApiKeyAuthScheme, - ApiKeyLocation, - ApiKeySigner, - ApiKeySigningProperties, + APIKeyAuthScheme, + APIKeyIdentityProperties, + APIKeyLocation, + APIKeySigner, ) -from smithy_http.aio.identity.apikey import ApiKeyIdentity, ApiKeyIdentityResolver - - -@pytest.fixture -def signer() -> ApiKeySigner: - return ApiKeySigner() +from smithy_http.aio.identity.apikey import APIKeyIdentity, APIKeyIdentityResolver class _FakeBody(AsyncIterable[bytes]): @@ -44,82 +39,69 @@ def request(query: str | None = None, fields: Fields | None = None) -> HTTPReque ) -async def test_sign_empty_query(signer: ApiKeySigner) -> None: +async def test_sign_empty_query() -> None: api_key = "spam" - identity = ApiKeyIdentity(api_key=api_key) - properties: ApiKeySigningProperties = { - "name": "eggs", - "location": ApiKeyLocation.QUERY, - } + identity = APIKeyIdentity(api_key=api_key) + signer = APIKeySigner(name="eggs", location=APIKeyLocation.QUERY) given = request() expected = request(query="eggs=spam") actual = await signer.sign( - http_request=given, + request=given, identity=identity, - signing_properties=properties, + properties={}, ) assert actual == expected -async def test_sign_non_empty_query(signer: ApiKeySigner) -> None: +async def test_sign_non_empty_query() -> None: api_key = "spam" - identity = ApiKeyIdentity(api_key=api_key) - properties: ApiKeySigningProperties = { - "name": "eggs", - "location": ApiKeyLocation.QUERY, - } + identity = APIKeyIdentity(api_key=api_key) + signer = APIKeySigner(name="eggs", location=APIKeyLocation.QUERY) given = request(query="spam=eggs") expected = request(query="spam=eggs&eggs=spam") actual = await signer.sign( - http_request=given, + request=given, identity=identity, - signing_properties=properties, + properties={}, ) assert actual == expected -async def test_sign_header(signer: ApiKeySigner) -> None: +async def test_sign_header() -> None: api_key = "spam" - identity = ApiKeyIdentity(api_key=api_key) - properties: ApiKeySigningProperties = { - "name": "eggs", - "location": ApiKeyLocation.HEADER, - } + identity = APIKeyIdentity(api_key=api_key) + signer = APIKeySigner(name="eggs", location=APIKeyLocation.HEADER) given = request() expected = request(fields=Fields([Field(name="eggs", values=["spam"])])) actual = await signer.sign( - http_request=given, + request=given, identity=identity, - signing_properties=properties, + properties={}, ) assert actual == expected -async def test_sign_header_with_scheme(signer: ApiKeySigner) -> None: +async def test_sign_header_with_scheme() -> None: api_key = "spam" - identity = ApiKeyIdentity(api_key=api_key) - properties: ApiKeySigningProperties = { - "name": "eggs", - "location": ApiKeyLocation.HEADER, - "scheme": "Bearer", - } + identity = APIKeyIdentity(api_key=api_key) + signer = APIKeySigner(name="eggs", location=APIKeyLocation.HEADER, scheme="Bearer") given = request() expected = request(fields=Fields([Field(name="eggs", values=["Bearer spam"])])) actual = await signer.sign( - http_request=given, + request=given, identity=identity, - signing_properties=properties, + properties={}, ) assert actual == expected @@ -128,19 +110,20 @@ async def test_sign_header_with_scheme(signer: ApiKeySigner) -> None: @dataclass class ApiKeyConfig: api_key_identity_resolver: ( - IdentityResolver[ApiKeyIdentity, IdentityProperties] | None + IdentityResolver[APIKeyIdentity, APIKeyIdentityProperties] | None ) = None async def test_auth_scheme_gets_resolver() -> None: - scheme = ApiKeyAuthScheme() - resolver = ApiKeyIdentityResolver(api_key="spam") + scheme = APIKeyAuthScheme(name="eggs", location=APIKeyLocation.QUERY) + resolver = APIKeyIdentityResolver() config = ApiKeyConfig(api_key_identity_resolver=resolver) + properties = TypedProperties({"config": config}) - assert resolver == scheme.identity_resolver(config=config) + assert resolver == scheme.identity_resolver(context=properties) async def test_auth_scheme_missing_resolver() -> None: - scheme = ApiKeyAuthScheme() + scheme = APIKeyAuthScheme(name="eggs", location=APIKeyLocation.QUERY) with pytest.raises(SmithyIdentityError): - scheme.identity_resolver(config=ApiKeyConfig()) + scheme.identity_resolver(context=TypedProperties()) diff --git a/packages/smithy-http/tests/unit/aio/identity/test_apikey.py b/packages/smithy-http/tests/unit/aio/identity/test_apikey.py index 9379caa82..782536b21 100644 --- a/packages/smithy-http/tests/unit/aio/identity/test_apikey.py +++ b/packages/smithy-http/tests/unit/aio/identity/test_apikey.py @@ -1,16 +1,20 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from smithy_http.aio.identity.apikey import ApiKeyIdentity, ApiKeyIdentityResolver +import pytest +from smithy_core.exceptions import SmithyIdentityError +from smithy_http.aio.identity.apikey import APIKeyIdentityResolver async def test_identity_resolver() -> None: api_key = "spam" - resolver = ApiKeyIdentityResolver(api_key=api_key) - identity = await resolver.get_identity(identity_properties={}) + resolver = APIKeyIdentityResolver() + identity = await resolver.get_identity(properties={"api_key": api_key}) assert identity.api_key == api_key - resolver = ApiKeyIdentityResolver(api_key=ApiKeyIdentity(api_key=api_key)) - identity = await resolver.get_identity(identity_properties={}) - assert identity.api_key == api_key +async def test_missing_api_key() -> None: + resolver = APIKeyIdentityResolver() + + with pytest.raises(SmithyIdentityError): + await resolver.get_identity(properties={}) diff --git a/uv.lock b/uv.lock index c55e7e637..e7595ee0a 100644 --- a/uv.lock +++ b/uv.lock @@ -684,12 +684,19 @@ dependencies = [ { name = "smithy-http" }, ] +[package.optional-dependencies] +eventstream = [ + { name = "smithy-aws-event-stream" }, +] + [package.metadata] requires-dist = [ { name = "aws-sdk-signers", editable = "packages/aws-sdk-signers" }, + { name = "smithy-aws-event-stream", marker = "extra == 'eventstream'", editable = "packages/smithy-aws-event-stream" }, { name = "smithy-core", editable = "packages/smithy-core" }, { name = "smithy-http", editable = "packages/smithy-http" }, ] +provides-extras = ["eventstream"] [[package]] name = "smithy-aws-event-stream"