Skip to content

Commit bd0f361

Browse files
blaise-muhirwaBlaise Munyampirwa
andauthored
Add Retries to the Python SDK with Exponential Backoff (#70)
* add decorator for retries * add unit test * add request_mock in pyproject.toml * remove unnecessary property from internal api client * add more unit tests * fix poetry lock file * fix linting * inherit from ApiException in InternalApiException * remove lock file * disable linting for useless-super-delegation * disable useless-super-delegation linting * allow caching image byte stream for subsequent access when file is closed * add a wrapper class to bytes * fix linting * forgotten return statement * add unit test for ByteStreamWrapper * fix linting * fix linting --------- Co-authored-by: Blaise Munyampirwa <[email protected]>
1 parent b3b4cdc commit bd0f361

File tree

6 files changed

+302
-28
lines changed

6 files changed

+302
-28
lines changed

src/groundlight/client.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,13 @@
1212

1313
from groundlight.binary_labels import Label, convert_display_label_to_internal, convert_internal_label_to_display
1414
from groundlight.config import API_TOKEN_VARIABLE_NAME, API_TOKEN_WEB_URL
15-
from groundlight.images import parse_supported_image_types
16-
from groundlight.internalapi import GroundlightApiClient, NotFoundError, iq_is_confident, sanitize_endpoint_url
15+
from groundlight.images import ByteStreamWrapper, parse_supported_image_types
16+
from groundlight.internalapi import (
17+
GroundlightApiClient,
18+
NotFoundError,
19+
iq_is_confident,
20+
sanitize_endpoint_url,
21+
)
1722
from groundlight.optional_imports import Image, np
1823

1924
logger = logging.getLogger("groundlight.sdk")
@@ -181,7 +186,8 @@ def submit_image_query(
181186
if wait is None:
182187
wait = self.DEFAULT_WAIT
183188
detector_id = detector.id if isinstance(detector, Detector) else detector
184-
image_bytesio: Union[BytesIO, BufferedReader] = parse_supported_image_types(image)
189+
190+
image_bytesio: ByteStreamWrapper = parse_supported_image_types(image)
185191

186192
raw_image_query = self.image_queries_api.submit_image_query(detector_id=detector_id, body=image_bytesio)
187193
image_query = ImageQuery.parse_obj(raw_image_query.to_dict())

src/groundlight/images.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,33 @@
11
import imghdr
2-
from io import BufferedReader, BytesIO
2+
from io import BufferedReader, BytesIO, IOBase
33
from typing import Union
44

55
from groundlight.optional_imports import Image, np
66

77

8+
class ByteStreamWrapper(IOBase):
    """Hold an image payload as in-memory bytes that can be read repeatedly.

    The payload is captured once, at construction time. After that,
    ``read()`` and ``getvalue()`` always return the complete payload and
    ``close()`` does nothing, so the same object can be handed to a request
    again on a retry without re-opening the original file or stream.
    """

    def __init__(self, data: Union[BufferedReader, BytesIO, IOBase, bytes]) -> None:
        """Capture the payload.

        :param data: Raw bytes, or a binary stream whose remaining content is
            read immediately. The stream itself is not closed here.
        """
        super().__init__()
        # Drain any IOBase stream, not only BufferedReader/BytesIO, so that
        # other stream types (for example another ByteStreamWrapper) are
        # stored as their content rather than as the stream object itself.
        if isinstance(data, IOBase):
            self._data = data.read()
        else:
            self._data = data

    def read(self) -> bytes:
        """Return the whole payload; there is no read position to exhaust."""
        return self._data

    def getvalue(self) -> bytes:
        """Return the whole payload (same result as ``read()``)."""
        return self._data

    def close(self) -> None:
        """Do nothing, so the payload stays readable after a caller closes it."""
29+
30+
831
def buffer_from_jpeg_file(image_filename: str) -> BufferedReader:
932
"""Get a buffer from an jpeg image file.
1033
@@ -29,31 +52,32 @@ def jpeg_from_numpy(img: np.ndarray, jpeg_quality: int = 95) -> bytes:
2952
def parse_supported_image_types(
    image: Union[str, bytes, Image.Image, BytesIO, BufferedReader, np.ndarray],
    jpeg_quality: int = 95,
) -> ByteStreamWrapper:
    """Parse the many supported image types into a byte-stream object.

    In some cases (PIL images and numpy arrays) we have to JPEG compress.

    :param image: A filename (str), JPEG bytes, a PIL image, an open binary
        stream (BytesIO or BufferedReader), or a numpy array.
    :param jpeg_quality: JPEG quality used when the input has to be compressed.
    :return: A ``ByteStreamWrapper`` holding the image bytes.
    :raises TypeError: If ``image`` is none of the supported types.
    """
    if isinstance(image, str):
        # Assume it is a filename. The wrapper copies the file content when it
        # is constructed, so the file handle can be released immediately
        # instead of being left open until garbage collection.
        with buffer_from_jpeg_file(image) as buffer:
            return ByteStreamWrapper(data=buffer)
    if isinstance(image, bytes):
        # Raw bytes are wrapped as they are
        return ByteStreamWrapper(data=image)
    if isinstance(image, Image.Image):
        # Save PIL image as jpeg in BytesIO, then hand over the encoded bytes
        bytesio = BytesIO()
        image.save(bytesio, "jpeg", quality=jpeg_quality)
        return ByteStreamWrapper(data=bytesio.getvalue())
    if isinstance(image, (BytesIO, BufferedReader)):
        # The wrapper reads whatever remains in the caller's stream; the
        # stream is left open for the caller to manage.
        return ByteStreamWrapper(data=image)
    if isinstance(image, np.ndarray):
        # Assume it is in BGR format from opencv
        return ByteStreamWrapper(data=jpeg_from_numpy(image[:, :, ::-1], jpeg_quality=jpeg_quality))
    raise TypeError(
        (
            "Unsupported type for image. Must be PIL, numpy (H,W,3) BGR, or a JPEG as a filename (str), bytes,"
            " BytesIO, or BufferedReader."
        ),
    )

src/groundlight/internalapi.py

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import logging
22
import os
3+
import random
34
import time
45
import uuid
5-
from typing import Optional
6+
from functools import wraps
7+
from typing import Callable, Optional
68
from urllib.parse import urlsplit, urlunsplit
79

810
import requests
911
from model import Detector, ImageQuery
10-
from openapi_client.api_client import ApiClient
12+
from openapi_client.api_client import ApiClient, ApiException
1113

1214
from groundlight.status_codes import is_ok
1315

@@ -67,9 +69,76 @@ def iq_is_confident(iq: ImageQuery, confidence_threshold: float) -> bool:
6769
return iq.result.confidence >= confidence_threshold
6870

6971

70-
class InternalApiError(ApiException, RuntimeError):
    """Error raised by the SDK's internal API helpers and by the retry decorator.

    TODO: We should really avoid this double inheritance. ``ApiException``
    and ``RuntimeError`` are both subclasses of ``Exception``, so error
    handling might become more complex where the two base classes overlap.
    """

    # The constructor only forwards its arguments; it is spelled out so the
    # accepted parameters are visible on this class.
    # pylint: disable=useless-super-delegation
    def __init__(self, status=None, reason=None, http_resp=None):
        super().__init__(status, reason, http_resp)
80+
81+
82+
class RequestsRetryDecorator:
    """Decorate a function to retry sending HTTP requests.

    The decorated function is executed again when it raises an
    ``ApiException`` whose ``status`` lies in ``status_code_range`` (by
    default a server error, HTTP 500 - 599). Retries back off exponentially
    with full jitter: before each retry the decorator sleeps for a random
    time between 0 and the current maximum delay, and that maximum is then
    multiplied by ``exponential_backoff``.
    """

    def __init__(
        self,
        initial_delay: float = 0.2,
        exponential_backoff: int = 2,
        status_code_range: tuple = (500, 600),
        max_retries: int = 3,
    ):
        """
        :param initial_delay: Upper bound, in seconds, of the first sleep.
        :param exponential_backoff: Factor applied to the upper bound after
            every failed attempt.
        :param status_code_range: ``(start, stop)`` pair passed to ``range``;
            statuses inside that range are retried (``stop`` is exclusive).
        :param max_retries: Number of retries allowed after the first call.
        """
        self.initial_delay = initial_delay
        self.exponential_backoff = exponential_backoff
        self.status_code_range = range(*status_code_range)
        self.max_retries = max_retries

    def __call__(self, function: Callable) -> Callable:
        """Wrap ``function`` in the retry loop.

        :param function: The function to invoke.
        :return: A wrapper with the same signature and metadata.
        :raises ApiException: Re-raised unchanged when the status is missing
            or outside ``status_code_range``.
        :raises InternalApiError: When a retryable failure still occurs after
            ``max_retries`` retries.
        """

        @wraps(function)
        def decorated(*args, **kwargs):
            delay = self.initial_delay
            retry_count = 0

            # The loop only ends by returning or raising, so the function is
            # always called at least once, even if max_retries is negative.
            while True:
                try:
                    return function(*args, **kwargs)
                except ApiException as e:
                    if e.status is None or e.status not in self.status_code_range:
                        # Not retryable: propagate the original exception
                        # with its traceback untouched.
                        raise
                    if retry_count >= self.max_retries:
                        raise InternalApiError(reason="Maximum retries reached") from e

                    logger.warning(
                        "Current HTTP response status: %s. Remaining retries: %s",
                        e.status,
                        self.max_retries - retry_count,
                        exc_info=True,
                    )
                    # This is implementing a full jitter strategy
                    time.sleep(random.uniform(0, delay))

                retry_count += 1
                delay *= self.exponential_backoff

        return decorated
73142

74143

75144
class GroundlightApiClient(ApiClient):
@@ -80,6 +149,7 @@ class GroundlightApiClient(ApiClient):
80149

81150
REQUEST_ID_HEADER = "X-Request-Id"
82151

152+
@RequestsRetryDecorator()
83153
def call_api(self, *args, **kwargs):
84154
"""Adds a request-id header to each API call."""
85155
# Note we don't look for header_param in kwargs here, because this method is only called in one place
@@ -97,7 +167,6 @@ def call_api(self, *args, **kwargs):
97167
# The methods below will eventually go away when we move to properly model
98168
# these methods with OpenAPI
99169
#
100-
101170
def _headers(self) -> dict:
102171
request_id = _generate_request_id()
103172
return {
@@ -106,6 +175,7 @@ def _headers(self) -> dict:
106175
"X-Request-Id": request_id,
107176
}
108177

178+
@RequestsRetryDecorator()
109179
def _add_label(self, image_query_id: str, label: str) -> dict:
110180
"""Temporary internal call to add a label to an image query. Not supported."""
111181
# TODO: Properly model this with OpenApi spec.
@@ -126,11 +196,14 @@ def _add_label(self, image_query_id: str, label: str) -> dict:
126196

127197
if not is_ok(response.status_code):
128198
raise InternalApiError(
129-
f"Error adding label to image query {image_query_id} status={response.status_code} {response.text}",
199+
status=response.status_code,
200+
reason=f"Error adding label to image query {image_query_id}",
201+
http_resp=response,
130202
)
131203

132204
return response.json()
133205

206+
@RequestsRetryDecorator()
134207
def _get_detector_by_name(self, name: str) -> Detector:
135208
"""Get a detector by name. For now, we use the list detectors API directly.
136209
@@ -141,9 +214,7 @@ def _get_detector_by_name(self, name: str) -> Detector:
141214
response = requests.request("GET", url, headers=headers)
142215

143216
if not is_ok(response.status_code):
144-
raise InternalApiError(
145-
f"Error getting detector by name '{name}' (status={response.status_code}): {response.text}",
146-
)
217+
raise InternalApiError(status=response.status_code, http_resp=response)
147218

148219
parsed = response.json()
149220

src/groundlight/status_codes.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
# Helper functions for checking HTTP status codes.
22

3-
OK_MIN = 200
4-
OK_MAX = 299
5-
USER_ERROR_MIN = 400
6-
USER_ERROR_MAX = 499
3+
4+
# A ``range`` never materialises its values, and checking whether an integer
# belongs to one is a constant-time arithmetic test rather than a scan.
OK_RANGE = range(200, 300)
USER_ERROR_RANGE = range(400, 500)


def is_ok(status_code: int) -> bool:
    """Return True when ``status_code`` is a 2xx code."""
    return status_code in OK_RANGE


def is_user_error(status_code: int) -> bool:
    """Return True when ``status_code`` is a 4xx code."""
    return status_code in USER_ERROR_RANGE

0 commit comments

Comments
 (0)