Skip to content

Commit 722a264

Browse files
Parse errors in http binding protocols
1 parent a9088c2 commit 722a264

File tree

8 files changed

+205
-17
lines changed

8 files changed

+205
-17
lines changed

packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,44 @@
1-
from typing import Final
1+
from typing import Any, Final
22

33
from smithy_core.codecs import Codec
4+
from smithy_core.schemas import APIOperation
45
from smithy_core.shapes import ShapeID
6+
from smithy_http.aio.interfaces import HTTPErrorIdentifier, HTTPResponse
57
from smithy_http.aio.protocols import HttpBindingClientProtocol
68
from smithy_json import JSONCodec
79

810
from ..traits import RestJson1Trait
911

1012

13+
class AWSErrorIdentifier(HTTPErrorIdentifier):
14+
_HEADER_KEY: Final = "x-amzn-errortype"
15+
16+
def identify(
17+
self,
18+
*,
19+
operation: APIOperation[Any, Any],
20+
response: HTTPResponse,
21+
) -> ShapeID | None:
22+
if self._HEADER_KEY not in response.fields:
23+
return None
24+
25+
code = response.fields[self._HEADER_KEY].values[0]
26+
if not code:
27+
return None
28+
29+
code = code.split(":")[0]
30+
if "#" in code:
31+
return ShapeID(code)
32+
return ShapeID.from_parts(name=code, namespace=operation.schema.id.namespace)
33+
34+
1135
class RestJsonClientProtocol(HttpBindingClientProtocol):
1236
"""An implementation of the aws.protocols#restJson1 protocol."""
1337

14-
_id: ShapeID = RestJson1Trait.id
15-
_codec: JSONCodec = JSONCodec()
38+
_id: Final = RestJson1Trait.id
39+
_codec: Final = JSONCodec()
1640
_contentType: Final = "application/json"
41+
_error_identifier: Final = AWSErrorIdentifier()
1742

1843
@property
1944
def id(self) -> ShapeID:
@@ -26,3 +51,7 @@ def payload_codec(self) -> Codec:
2651
@property
2752
def content_type(self) -> str:
2853
return self._contentType
54+
55+
@property
56+
def error_identifier(self) -> HTTPErrorIdentifier:
57+
return self._error_identifier

packages/smithy-aws-core/src/smithy_aws_core/traits.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from collections.abc import Mapping, Sequence
1111
from dataclasses import dataclass, field
1212

13+
from smithy_core.documents import DocumentValue
1314
from smithy_core.shapes import ShapeID
14-
from smithy_core.traits import DocumentValue, DynamicTrait, Trait
15+
from smithy_core.traits import DynamicTrait, Trait
1516

1617

1718
@dataclass(init=False, frozen=True)

packages/smithy-aws-core/tests/unit/aio/__init__.py

Whitespace-only changes.
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from unittest.mock import Mock
5+
6+
import pytest
7+
from smithy_aws_core.aio.protocols import AWSErrorIdentifier
8+
from smithy_core.schemas import APIOperation, Schema
9+
from smithy_core.shapes import ShapeID, ShapeType
10+
from smithy_http import Fields, tuples_to_fields
11+
from smithy_http.aio import HTTPResponse
12+
13+
14+
@pytest.mark.parametrize(
15+
"header, expected",
16+
[
17+
("FooError", "com.test#FooError"),
18+
(
19+
"FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/",
20+
"com.test#FooError",
21+
),
22+
(
23+
"com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate",
24+
"com.test#FooError",
25+
),
26+
("", None),
27+
(None, None),
28+
],
29+
)
30+
def test_aws_error_identifier(header: str | None, expected: ShapeID | None) -> None:
31+
fields = Fields()
32+
if header is not None:
33+
fields = tuples_to_fields([("x-amzn-errortype", header)])
34+
http_response = HTTPResponse(status=500, fields=fields)
35+
36+
operation = Mock(spec=APIOperation)
37+
operation.schema = Schema(
38+
id=ShapeID("com.test#TestOperation"), shape_type=ShapeType.OPERATION
39+
)
40+
41+
error_identifier = AWSErrorIdentifier()
42+
actual = error_identifier.identify(operation=operation, response=http_response)
43+
44+
assert actual == expected

packages/smithy-core/src/smithy_core/interfaces/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ def is_bytes_reader(obj: Any) -> TypeGuard[BytesReader]:
7979
)
8080

8181

82+
@runtime_checkable
83+
class SeekableBytesReader(BytesReader, Protocol):
84+
"""A synchronous bytes reader with seek and tell methods."""
85+
86+
def tell(self) -> int: ...
87+
def seek(self, offset: int, whence: int = 0, /) -> int: ...
88+
89+
8290
# A union of all acceptable streaming blob types. Deserialized payloads will
8391
# always return a ByteStream, or AsyncByteStream if async is enabled.
8492
type StreamingBlob = BytesReader | bytes | bytearray

packages/smithy-http/src/smithy_http/aio/interfaces/__init__.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3-
from typing import Protocol
3+
from typing import Any, Protocol
44

55
from smithy_core.aio.interfaces import ClientTransport, Request, Response
66
from smithy_core.aio.utils import read_streaming_blob, read_streaming_blob_async
7+
from smithy_core.schemas import APIOperation
8+
from smithy_core.shapes import ShapeID
79

810
from ...interfaces import (
911
Fields,
@@ -83,3 +85,19 @@ async def send(
8385
:param request_config: Configuration specific to this request.
8486
"""
8587
...
88+
89+
90+
class HTTPErrorIdentifier:
91+
"""A class that uses HTTP response metadata to identify errors.
92+
93+
The body of the response SHOULD NOT be touched by this. The payload codec will be
94+
used instead to check for an ID in the body.
95+
"""
96+
97+
def identify(
98+
self,
99+
*,
100+
operation: APIOperation[Any, Any],
101+
response: HTTPResponse,
102+
) -> ShapeID | None:
103+
"""Idenitfy the ShapeID of an error from an HTTP response."""

packages/smithy-http/src/smithy_http/aio/protocols.py

Lines changed: 97 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,29 @@
11
import os
22
from inspect import iscoroutinefunction
33
from io import BytesIO
4+
from typing import Any
45

56
from smithy_core.aio.interfaces import ClientProtocol
67
from smithy_core.codecs import Codec
78
from smithy_core.deserializers import DeserializeableShape
89
from smithy_core.documents import TypeRegistry
9-
from smithy_core.exceptions import ExpectationNotMetError
10-
from smithy_core.interfaces import Endpoint, TypedProperties, URI
10+
from smithy_core.exceptions import CallError, ExpectationNotMetError, ModeledError
11+
from smithy_core.interfaces import (
12+
Endpoint,
13+
SeekableBytesReader,
14+
TypedProperties,
15+
URI,
16+
is_streaming_blob,
17+
)
18+
from smithy_core.interfaces import StreamingBlob as SyncStreamingBlob
19+
from smithy_core.prelude import DOCUMENT
1120
from smithy_core.schemas import APIOperation
1221
from smithy_core.serializers import SerializeableShape
1322
from smithy_core.traits import EndpointTrait, HTTPTrait
1423

15-
from smithy_http.aio.interfaces import HTTPRequest, HTTPResponse
16-
from smithy_http.deserializers import HTTPResponseDeserializer
17-
from smithy_http.serializers import HTTPRequestSerializer
24+
from ..deserializers import HTTPResponseDeserializer
25+
from ..serializers import HTTPRequestSerializer
26+
from .interfaces import HTTPErrorIdentifier, HTTPRequest, HTTPResponse
1827

1928

2029
class HttpClientProtocol(ClientProtocol[HTTPRequest, HTTPResponse]):
@@ -54,6 +63,12 @@ def content_type(self) -> str:
5463
"""The media type of the http payload."""
5564
raise NotImplementedError()
5665

66+
@property
67+
def error_identifier(self) -> HTTPErrorIdentifier:
68+
"""The class used to identify the shape IDs of errors based on fields or other
69+
response information."""
70+
raise NotImplementedError()
71+
5772
def serialize_request[
5873
OperationInput: "SerializeableShape",
5974
OperationOutput: "DeserializeableShape",
@@ -94,19 +109,25 @@ async def deserialize_response[
94109
error_registry: TypeRegistry,
95110
context: TypedProperties,
96111
) -> OperationOutput:
97-
if not (200 <= response.status <= 299):
98-
# TODO: implement error serde from type registry
99-
raise NotImplementedError
100-
101112
body = response.body
102113

103114
# if body is not streaming and is async, we have to buffer it
104-
if not operation.output_stream_member:
115+
if not operation.output_stream_member and not is_streaming_blob(body):
105116
if (
106117
read := getattr(body, "read", None)
107118
) is not None and iscoroutinefunction(read):
108119
body = BytesIO(await read())
109120

121+
if not self._is_success(operation, context, response):
122+
raise await self._create_error(
123+
operation=operation,
124+
request=request,
125+
response=response,
126+
response_body=body, # type: ignore
127+
error_registry=error_registry,
128+
context=context,
129+
)
130+
110131
# TODO(optimization): response binding cache like done in SJ
111132
deserializer = HTTPResponseDeserializer(
112133
payload_codec=self.payload_codec,
@@ -116,3 +137,69 @@ async def deserialize_response[
116137
)
117138

118139
return operation.output.deserialize(deserializer)
140+
141+
def _is_success(
142+
self,
143+
operation: APIOperation[Any, Any],
144+
context: TypedProperties,
145+
response: HTTPResponse,
146+
) -> bool:
147+
return 200 <= response.status <= 299
148+
149+
async def _create_error(
150+
self,
151+
operation: APIOperation[Any, Any],
152+
request: HTTPRequest,
153+
response: HTTPResponse,
154+
response_body: SyncStreamingBlob,
155+
error_registry: TypeRegistry,
156+
context: TypedProperties,
157+
) -> CallError:
158+
error_id = self.error_identifier.identify(
159+
operation=operation, response=response
160+
)
161+
162+
if error_id is None:
163+
if isinstance(response_body, bytearray):
164+
response_body = bytes(response_body)
165+
deserializer = self.payload_codec.create_deserializer(source=response_body)
166+
document = deserializer.read_document(schema=DOCUMENT)
167+
168+
if document.discriminator in error_registry:
169+
error_id = document.discriminator
170+
if isinstance(response_body, SeekableBytesReader):
171+
response_body.seek(0)
172+
173+
if error_id is not None and error_id in error_registry:
174+
error_shape = error_registry.get(error_id)
175+
176+
# make sure the error shape is derived from modeled exception
177+
if not issubclass(error_shape, ModeledError):
178+
raise ExpectationNotMetError(
179+
f"Modeled errors must be derived from 'ModeledError', "
180+
f"but got {error_shape}"
181+
)
182+
183+
deserializer = HTTPResponseDeserializer(
184+
payload_codec=self.payload_codec,
185+
http_trait=operation.schema.expect_trait(HTTPTrait),
186+
response=response,
187+
body=response_body,
188+
)
189+
return error_shape.deserialize(deserializer)
190+
191+
is_throttle = response.status == 429
192+
message = (
193+
f"Unknown error for operation {operation.schema.id} "
194+
f"- status: {response.status}"
195+
)
196+
if error_id is not None:
197+
message += f" - id: {error_id}"
198+
if response.reason is not None:
199+
message += f" - reason: {response.status}"
200+
return CallError(
201+
message=message,
202+
fault="client" if response.status < 500 else "server",
203+
is_throttling_error=is_throttle,
204+
is_retry_safe=is_throttle or None,
205+
)

packages/smithy-http/src/smithy_http/deserializers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,17 @@ class HTTPResponseDeserializer(SpecificShapeDeserializer):
3939
# Note: caller will have to read the body if it's async and not streaming
4040
def __init__(
4141
self,
42+
*,
4243
payload_codec: Codec,
43-
http_trait: HTTPTrait,
4444
response: HTTPResponse,
45+
http_trait: HTTPTrait | None = None,
4546
body: "SyncStreamingBlob | None" = None,
4647
) -> None:
4748
"""Initialize an HTTPResponseDeserializer.
4849
4950
:param payload_codec: The Codec to use to deserialize the payload, if present.
50-
:param http_trait: The HTTP trait of the operation being handled.
5151
:param response: The HTTP response to read from.
52+
:param http_trait: The HTTP trait of the operation being handled.
5253
:param body: The HTTP response body in a synchronously readable form. This is
5354
necessary for async response bodies when there is no streaming member.
5455
"""

0 commit comments

Comments
 (0)