Skip to content

Commit 2e2376a

Browse files
merge 0.8 branch into main (#56)
* Bump 0.8.0rc1 * (0.8.0rc1) Return YES/NO instead of PASS/FAIL (#50) * Convert PASS/FAIL to YES/NO * Change to classmethod * Remove logger * Bump 0.8.0 * (0.8.0rc1) Only allow "YES"/"NO" in `add_label()` (#51) * Convert PASS/FAIL to YES/NO * Change to classmethod * Remove logger * Bump 0.8.0 * Only allow YES/NO labels through SDK; Add tests * Add comments * Remove duplicates * Feedback * Add type ignore * Remove select all * Enum * Update pyproject.toml --------- Co-authored-by: Michael Vogelsong <[email protected]>
1 parent 6f55ad0 commit 2e2376a

File tree

10 files changed

+223
-60
lines changed

10 files changed

+223
-60
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ generate: install-generator ## Generate the SDK from our public openapi spec
2020
poetry run datamodel-codegen --input spec/public-api.yaml --output generated/model.py
2121
poetry run black .
2222

23-
PYTEST=poetry run pytest -v --cov=src
23+
PYTEST=poetry run pytest -v
2424

2525
# You can pass extra arguments to pytest by setting the TEST_ARGS environment variable.
2626
# For example:

code-quality/lint

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ echo "Linting with mypy (type checking) ..."
3636
poetry run mypy $TARGET_PATHS || ((errors++))
3737

3838
if [[ $errors -gt 0 ]]; then
39-
echo "🚨 $errors linters found errors!"
40-
exit $errors
39+
echo "🚨 $errors linters found errors!"
40+
exit $errors
4141
fi
4242

4343
echo "✅ Success!"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ packages = [
99
{include = "**/*.py", from = "src"},
1010
]
1111
readme = "README.md"
12-
version = "0.7.8"
12+
version = "0.8.0"
1313

1414
[tool.poetry.dependencies]
1515
certifi = "^2021.10.8"

samples/blocking_submit.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
"""Example of how to wait for a confident result
2-
"""
1+
"""Example of how to wait for a confident result."""
32
import logging
43

54
from groundlight import Groundlight

src/groundlight/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
# Imports from our code
66
from .client import Groundlight
7+
from .binary_labels import Label
78

89
try:
910
import importlib.metadata

src/groundlight/binary_labels.py

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,76 @@
1-
"""Defines the possible values for binary class labels like Yes/No or PASS/FAIL.
1+
"""Defines the possible values for binary class labels like YES/NO.
22
Provides methods to convert between them.
33
44
This part of the API is kinda ugly right now. So we'll encapsulate the ugliness in one place.
55
"""
66
import logging
7-
from typing import List, Union
7+
from enum import Enum
8+
from typing import Union
89

910
from model import Detector, ImageQuery
1011

11-
logger = logging.getLogger("groundlight.sdk")
12+
logger = logging.getLogger(__name__)
1213

1314

14-
def internal_labels_for_detector(
15-
context: Union[ImageQuery, Detector, str], # pylint: disable=unused-argument
16-
) -> List[str]:
17-
"""Returns an ordered list of class labels as strings.
18-
These are the versions of labels that the API demands.
19-
:param context: Can be an ImageQuery, a Detector, or a string-id for one of them.
20-
"""
21-
# NOTE: At some point this will need to be an API call, because these will be defined per-detector
22-
return ["PASS", "FAIL"]
15+
class Label(str, Enum):
16+
YES = "YES"
17+
NO = "NO"
18+
UNSURE = "UNSURE"
19+
20+
21+
VALID_DISPLAY_LABELS = {Label.YES, Label.NO, Label.UNSURE}
22+
23+
24+
class DeprecatedLabel(str, Enum):
25+
PASS = "PASS"
26+
FAIL = "FAIL"
27+
NEEDS_REVIEW = "NEEDS_REVIEW"
28+
29+
30+
DEPRECATED_LABEL_NAMES = {DeprecatedLabel.PASS, DeprecatedLabel.FAIL, DeprecatedLabel.NEEDS_REVIEW}
2331

2432

2533
def convert_internal_label_to_display(
2634
context: Union[ImageQuery, Detector, str], # pylint: disable=unused-argument
2735
label: str,
28-
) -> str:
36+
) -> Union[Label, str]:
37+
"""Convert a label that comes from our API into the label string enum that we show to the user.
38+
39+
NOTE: We return UPPERCASE label strings to the user, unless there is a custom label (which
40+
shouldn't be happening at this time).
41+
"""
2942
# NOTE: Someday we will do nothing here, when the server provides properly named classes.
43+
if not isinstance(label, str):
44+
raise ValueError(f"Expected a string label, but got {label} of type {type(label)}")
3045
upper = label.upper()
31-
if upper == "PASS":
32-
return "YES"
33-
if upper == "FAIL":
34-
return "NO"
35-
if upper in ["YES", "NO"]:
36-
return label
37-
logger.warning(f"Unrecognized internal label {label} - leaving alone.")
46+
if upper in {Label.YES, DeprecatedLabel.PASS}:
47+
return Label.YES
48+
if upper in {Label.NO, DeprecatedLabel.FAIL}:
49+
return Label.NO
50+
if upper in {Label.UNSURE, DeprecatedLabel.NEEDS_REVIEW}:
51+
return Label.UNSURE
52+
53+
logger.warning(f"Unrecognized internal label {label} - leaving it alone as a string.")
3854
return label
3955

4056

4157
def convert_display_label_to_internal(
4258
context: Union[ImageQuery, Detector, str], # pylint: disable=unused-argument
43-
label: str,
59+
label: Union[Label, str],
4460
) -> str:
61+
"""Convert a label that comes from the user into the label string that we send to the server. We
62+
are strict here, and only allow YES/NO.
63+
64+
NOTE: We accept case-insensitive label strings from the user, but we send UPPERCASE labels to
65+
the server. E.g., user inputs "yes" -> the label is returned as "YES".
66+
"""
4567
# NOTE: In the future we should validate against actually supported labels for the detector
68+
if not isinstance(label, str):
69+
raise ValueError(f"Expected a string label, but got {label} of type {type(label)}")
4670
upper = label.upper()
47-
if upper in {"PASS", "YES"}:
48-
return "PASS"
49-
if upper in {"FAIL", "NO"}:
50-
return "FAIL"
51-
raise ValueError(f'Invalid label string "{label}". Must be one of YES,NO,PASS,FAIL')
71+
if upper == Label.YES:
72+
return DeprecatedLabel.PASS.value
73+
if upper == Label.NO:
74+
return DeprecatedLabel.FAIL.value
75+
76+
raise ValueError(f"Invalid label string '{label}'. Must be one of '{Label.YES.value}','{Label.NO.value}'.")

src/groundlight/client.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from openapi_client.api.image_queries_api import ImageQueriesApi
1111
from openapi_client.model.detector_creation_input import DetectorCreationInput
1212

13-
from groundlight.binary_labels import convert_display_label_to_internal
13+
from groundlight.binary_labels import Label, convert_display_label_to_internal, convert_internal_label_to_display
1414
from groundlight.config import API_TOKEN_VARIABLE_NAME, API_TOKEN_WEB_URL
1515
from groundlight.images import parse_supported_image_types
1616
from groundlight.internalapi import GroundlightApiClient, NotFoundError, sanitize_endpoint_url
@@ -71,6 +71,15 @@ def __init__(self, endpoint: Optional[str] = None, api_token: Optional[str] = No
7171
self.detectors_api = DetectorsApi(self.api_client)
7272
self.image_queries_api = ImageQueriesApi(self.api_client)
7373

74+
@classmethod
75+
def _post_process_image_query(cls, iq: ImageQuery) -> ImageQuery:
76+
"""Post-process the image query so we don't use confusing internal labels.
77+
78+
TODO: Get rid of this once we clean up the mapping logic server-side.
79+
"""
80+
iq.result.label = convert_internal_label_to_display(iq, iq.result.label)
81+
return iq
82+
7483
def get_detector(self, id: Union[str, Detector]) -> Detector: # pylint: disable=redefined-builtin
7584
if isinstance(id, Detector):
7685
# Short-circuit
@@ -102,7 +111,12 @@ def create_detector(
102111
return Detector.parse_obj(obj.to_dict())
103112

104113
def get_or_create_detector(
105-
self, name: str, query: str, *, confidence_threshold: Optional[float] = None, config_name: Optional[str] = None
114+
self,
115+
name: str,
116+
query: str,
117+
*,
118+
confidence_threshold: Optional[float] = None,
119+
config_name: Optional[str] = None,
106120
) -> Detector:
107121
"""Tries to look up the detector by name. If a detector with that name, query, and
108122
confidence exists, return it. Otherwise, create a detector with the specified query and
@@ -113,30 +127,41 @@ def get_or_create_detector(
113127
except NotFoundError:
114128
logger.debug(f"We could not find a detector with name='{name}'. So we will create a new detector ...")
115129
return self.create_detector(
116-
name=name, query=query, confidence_threshold=confidence_threshold, config_name=config_name
130+
name=name,
131+
query=query,
132+
confidence_threshold=confidence_threshold,
133+
config_name=config_name,
117134
)
118135

119136
# TODO: We may soon allow users to update the retrieved detector's fields.
120137
if existing_detector.query != query:
121138
raise ValueError(
122-
f"Found existing detector with name={name} (id={existing_detector.id}) but the queries don't match."
123-
f" The existing query is '{existing_detector.query}'."
139+
(
140+
f"Found existing detector with name={name} (id={existing_detector.id}) but the queries don't match."
141+
f" The existing query is '{existing_detector.query}'."
142+
),
124143
)
125144
if confidence_threshold is not None and existing_detector.confidence_threshold != confidence_threshold:
126145
raise ValueError(
127-
f"Found existing detector with name={name} (id={existing_detector.id}) but the confidence"
128-
" thresholds don't match. The existing confidence threshold is"
129-
f" {existing_detector.confidence_threshold}."
146+
(
147+
f"Found existing detector with name={name} (id={existing_detector.id}) but the confidence"
148+
" thresholds don't match. The existing confidence threshold is"
149+
f" {existing_detector.confidence_threshold}."
150+
),
130151
)
131152
return existing_detector
132153

133154
def get_image_query(self, id: str) -> ImageQuery: # pylint: disable=redefined-builtin
134155
obj = self.image_queries_api.get_image_query(id=id)
135-
return ImageQuery.parse_obj(obj.to_dict())
156+
iq = ImageQuery.parse_obj(obj.to_dict())
157+
return self._post_process_image_query(iq)
136158

137159
def list_image_queries(self, page: int = 1, page_size: int = 10) -> PaginatedImageQueryList:
138160
obj = self.image_queries_api.list_image_queries(page=page, page_size=page_size)
139-
return PaginatedImageQueryList.parse_obj(obj.to_dict())
161+
image_queries = PaginatedImageQueryList.parse_obj(obj.to_dict())
162+
if image_queries.results is not None:
163+
image_queries.results = [self._post_process_image_query(iq) for iq in image_queries.results]
164+
return image_queries
140165

141166
def submit_image_query(
142167
self,
@@ -166,7 +191,7 @@ def submit_image_query(
166191
if wait:
167192
threshold = self.get_detector(detector).confidence_threshold
168193
image_query = self.wait_for_confident_result(image_query, confidence_threshold=threshold, timeout_sec=wait)
169-
return image_query
194+
return self._post_process_image_query(image_query)
170195

171196
def wait_for_confident_result(
172197
self,
@@ -203,11 +228,11 @@ def wait_for_confident_result(
203228
image_query = self.get_image_query(image_query.id)
204229
return image_query
205230

206-
def add_label(self, image_query: Union[ImageQuery, str], label: str):
231+
def add_label(self, image_query: Union[ImageQuery, str], label: Union[Label, str]):
207232
"""A new label to an image query. This answers the detector's question.
208233
:param image_query: Either an ImageQuery object (returned from `submit_image_query`) or
209234
an image_query id as a string.
210-
:param label: The string "Yes" or the string "No" in answer to the query.
235+
:param label: The string "YES" or the string "NO" in answer to the query.
211236
"""
212237
if isinstance(image_query, ImageQuery):
213238
image_query_id = image_query.id

src/groundlight/internalapi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def _get_detector_by_name(self, name: str) -> Detector:
133133

134134
if not is_ok(response.status_code):
135135
raise InternalApiError(
136-
f"Error getting detector by name '{name}' (status={response.status_code}): {response.text}"
136+
f"Error getting detector by name '{name}' (status={response.status_code}): {response.text}",
137137
)
138138

139139
parsed = response.json()
@@ -142,6 +142,6 @@ def _get_detector_by_name(self, name: str) -> Detector:
142142
raise NotFoundError(f"Detector with name={name} not found.")
143143
if parsed["count"] > 1:
144144
raise RuntimeError(
145-
f"We found multiple ({parsed['count']}) detectors with the same name. This shouldn't happen."
145+
f"We found multiple ({parsed['count']}) detectors with the same name. This shouldn't happen.",
146146
)
147147
return Detector.parse_obj(parsed["results"][0])

0 commit comments

Comments
 (0)