10
10
from openapi_client .api .image_queries_api import ImageQueriesApi
11
11
from openapi_client .model .detector_creation_input import DetectorCreationInput
12
12
13
- from groundlight .binary_labels import convert_display_label_to_internal
13
+ from groundlight .binary_labels import Label , convert_display_label_to_internal , convert_internal_label_to_display
14
14
from groundlight .config import API_TOKEN_VARIABLE_NAME , API_TOKEN_WEB_URL
15
15
from groundlight .images import parse_supported_image_types
16
16
from groundlight .internalapi import GroundlightApiClient , NotFoundError , sanitize_endpoint_url
@@ -71,6 +71,15 @@ def __init__(self, endpoint: Optional[str] = None, api_token: Optional[str] = No
71
71
self .detectors_api = DetectorsApi (self .api_client )
72
72
self .image_queries_api = ImageQueriesApi (self .api_client )
73
73
74
+ @classmethod
75
+ def _post_process_image_query (cls , iq : ImageQuery ) -> ImageQuery :
76
+ """Post-process the image query so we don't use confusing internal labels.
77
+
78
+ TODO: Get rid of this once we clean up the mapping logic server-side.
79
+ """
80
+ iq .result .label = convert_internal_label_to_display (iq , iq .result .label )
81
+ return iq
82
+
74
83
def get_detector (self , id : Union [str , Detector ]) -> Detector : # pylint: disable=redefined-builtin
75
84
if isinstance (id , Detector ):
76
85
# Short-circuit
@@ -102,7 +111,12 @@ def create_detector(
102
111
return Detector .parse_obj (obj .to_dict ())
103
112
104
113
def get_or_create_detector (
105
- self , name : str , query : str , * , confidence_threshold : Optional [float ] = None , config_name : Optional [str ] = None
114
+ self ,
115
+ name : str ,
116
+ query : str ,
117
+ * ,
118
+ confidence_threshold : Optional [float ] = None ,
119
+ config_name : Optional [str ] = None ,
106
120
) -> Detector :
107
121
"""Tries to look up the detector by name. If a detector with that name, query, and
108
122
confidence exists, return it. Otherwise, create a detector with the specified query and
@@ -113,30 +127,41 @@ def get_or_create_detector(
113
127
except NotFoundError :
114
128
logger .debug (f"We could not find a detector with name='{ name } '. So we will create a new detector ..." )
115
129
return self .create_detector (
116
- name = name , query = query , confidence_threshold = confidence_threshold , config_name = config_name
130
+ name = name ,
131
+ query = query ,
132
+ confidence_threshold = confidence_threshold ,
133
+ config_name = config_name ,
117
134
)
118
135
119
136
# TODO: We may soon allow users to update the retrieved detector's fields.
120
137
if existing_detector .query != query :
121
138
raise ValueError (
122
- f"Found existing detector with name={ name } (id={ existing_detector .id } ) but the queries don't match."
123
- f" The existing query is '{ existing_detector .query } '."
139
+ (
140
+ f"Found existing detector with name={ name } (id={ existing_detector .id } ) but the queries don't match."
141
+ f" The existing query is '{ existing_detector .query } '."
142
+ ),
124
143
)
125
144
if confidence_threshold is not None and existing_detector .confidence_threshold != confidence_threshold :
126
145
raise ValueError (
127
- f"Found existing detector with name={ name } (id={ existing_detector .id } ) but the confidence"
128
- " thresholds don't match. The existing confidence threshold is"
129
- f" { existing_detector .confidence_threshold } ."
146
+ (
147
+ f"Found existing detector with name={ name } (id={ existing_detector .id } ) but the confidence"
148
+ " thresholds don't match. The existing confidence threshold is"
149
+ f" { existing_detector .confidence_threshold } ."
150
+ ),
130
151
)
131
152
return existing_detector
132
153
133
154
def get_image_query (self , id : str ) -> ImageQuery : # pylint: disable=redefined-builtin
134
155
obj = self .image_queries_api .get_image_query (id = id )
135
- return ImageQuery .parse_obj (obj .to_dict ())
156
+ iq = ImageQuery .parse_obj (obj .to_dict ())
157
+ return self ._post_process_image_query (iq )
136
158
137
159
def list_image_queries (self , page : int = 1 , page_size : int = 10 ) -> PaginatedImageQueryList :
138
160
obj = self .image_queries_api .list_image_queries (page = page , page_size = page_size )
139
- return PaginatedImageQueryList .parse_obj (obj .to_dict ())
161
+ image_queries = PaginatedImageQueryList .parse_obj (obj .to_dict ())
162
+ if image_queries .results is not None :
163
+ image_queries .results = [self ._post_process_image_query (iq ) for iq in image_queries .results ]
164
+ return image_queries
140
165
141
166
def submit_image_query (
142
167
self ,
@@ -166,7 +191,7 @@ def submit_image_query(
166
191
if wait :
167
192
threshold = self .get_detector (detector ).confidence_threshold
168
193
image_query = self .wait_for_confident_result (image_query , confidence_threshold = threshold , timeout_sec = wait )
169
- return image_query
194
+ return self . _post_process_image_query ( image_query )
170
195
171
196
def wait_for_confident_result (
172
197
self ,
@@ -203,11 +228,11 @@ def wait_for_confident_result(
203
228
image_query = self .get_image_query (image_query .id )
204
229
return image_query
205
230
206
- def add_label (self , image_query : Union [ImageQuery , str ], label : str ):
231
+ def add_label (self , image_query : Union [ImageQuery , str ], label : Union [ Label , str ] ):
207
232
"""A new label to an image query. This answers the detector's question.
208
233
:param image_query: Either an ImageQuery object (returned from `submit_image_query`) or
209
234
an image_query id as a string.
210
- :param label: The string "Yes " or the string "No " in answer to the query.
235
+ :param label: The string "YES " or the string "NO " in answer to the query.
211
236
"""
212
237
if isinstance (image_query , ImageQuery ):
213
238
image_query_id = image_query .id
0 commit comments