Skip to content

Commit d88db2c

Browse files
authored
Requires class name for count detectors (#291)
1 parent d82ea4e commit d88db2c

File tree

10 files changed

+55
-15
lines changed

10 files changed

+55
-15
lines changed

generated/docs/CountModeConfiguration.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
## Properties
55
Name | Type | Description | Notes
66
------------ | ------------- | ------------- | -------------
7+
**class_name** | **str** | |
78
**max_count** | **int** | | [optional]
89
**any string name** | **bool, date, datetime, dict, float, int, list, str, none_type** | any string name can be used but the value must be the correct type | [optional]
910

generated/docs/ImageQueriesApi.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ with groundlight_openapi_client.ApiClient(configuration) as api_client:
282282
# Create an instance of the API class
283283
api_instance = image_queries_api.ImageQueriesApi(api_client)
284284
detector_id = "detector_id_example" # str | Choose a detector by its ID.
285+
confidence_threshold = 0 # float | The confidence threshold for the image query. (optional)
285286
human_review = "human_review_example" # str | If set to `DEFAULT`, use the regular escalation logic (i.e., send the image query for human review if the ML model is not confident). If set to `ALWAYS`, always send the image query for human review even if the ML model is confident. If set to `NEVER`, never send the image query for human review even if the ML model is not confident. (optional)
286287
image_query_id = "image_query_id_example" # str | The ID to assign to the created image query. (optional)
287288
inspection_id = "inspection_id_example" # str | Associate the image query with an inspection. (optional)
@@ -300,7 +301,7 @@ with groundlight_openapi_client.ApiClient(configuration) as api_client:
300301
# example passing only required values which don't have defaults set
301302
# and optional values
302303
try:
303-
api_response = api_instance.submit_image_query(detector_id, human_review=human_review, image_query_id=image_query_id, inspection_id=inspection_id, metadata=metadata, patience_time=patience_time, want_async=want_async, body=body)
304+
api_response = api_instance.submit_image_query(detector_id, confidence_threshold=confidence_threshold, human_review=human_review, image_query_id=image_query_id, inspection_id=inspection_id, metadata=metadata, patience_time=patience_time, want_async=want_async, body=body)
304305
pprint(api_response)
305306
except groundlight_openapi_client.ApiException as e:
306307
print("Exception when calling ImageQueriesApi->submit_image_query: %s\n" % e)
@@ -312,6 +313,7 @@ with groundlight_openapi_client.ApiClient(configuration) as api_client:
312313
Name | Type | Description | Notes
313314
------------- | ------------- | ------------- | -------------
314315
**detector_id** | **str**| Choose a detector by its ID. |
316+
**confidence_threshold** | **float**| The confidence threshold for the image query. | [optional]
315317
**human_review** | **str**| If set to `DEFAULT`, use the regular escalation logic (i.e., send the image query for human review if the ML model is not confident). If set to `ALWAYS`, always send the image query for human review even if the ML model is confident. If set to `NEVER`, never send the image query for human review even if the ML model is not confident. | [optional]
316318
**image_query_id** | **str**| The ID to assign to the created image query. | [optional]
317319
**inspection_id** | **str**| Associate the image query with an inspection. | [optional]

generated/groundlight_openapi_client/api/image_queries_api.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def __init__(self, api_client=None):
170170
params_map={
171171
"all": [
172172
"detector_id",
173+
"confidence_threshold",
173174
"human_review",
174175
"image_query_id",
175176
"inspection_id",
@@ -183,13 +184,21 @@ def __init__(self, api_client=None):
183184
],
184185
"nullable": [],
185186
"enum": [],
186-
"validation": [],
187+
"validation": [
188+
"confidence_threshold",
189+
],
187190
},
188191
root_map={
189-
"validations": {},
192+
"validations": {
193+
("confidence_threshold",): {
194+
"inclusive_maximum": 1,
195+
"inclusive_minimum": 0,
196+
},
197+
},
190198
"allowed_values": {},
191199
"openapi_types": {
192200
"detector_id": (str,),
201+
"confidence_threshold": (float,),
193202
"human_review": (str,),
194203
"image_query_id": (str,),
195204
"inspection_id": (str,),
@@ -200,6 +209,7 @@ def __init__(self, api_client=None):
200209
},
201210
"attribute_map": {
202211
"detector_id": "detector_id",
212+
"confidence_threshold": "confidence_threshold",
203213
"human_review": "human_review",
204214
"image_query_id": "image_query_id",
205215
"inspection_id": "inspection_id",
@@ -209,6 +219,7 @@ def __init__(self, api_client=None):
209219
},
210220
"location_map": {
211221
"detector_id": "query",
222+
"confidence_threshold": "query",
212223
"human_review": "query",
213224
"image_query_id": "query",
214225
"inspection_id": "query",
@@ -421,6 +432,7 @@ def submit_image_query(self, detector_id, **kwargs):
421432
detector_id (str): Choose a detector by its ID.
422433
423434
Keyword Args:
435+
confidence_threshold (float): The confidence threshold for the image query.. [optional]
424436
human_review (str): If set to `DEFAULT`, use the regular escalation logic (i.e., send the image query for human review if the ML model is not confident). If set to `ALWAYS`, always send the image query for human review even if the ML model is confident. If set to `NEVER`, never send the image query for human review even if the ML model is not confident.. [optional]
425437
image_query_id (str): The ID to assign to the created image query.. [optional]
426438
inspection_id (str): Associate the image query with an inspection.. [optional]

generated/groundlight_openapi_client/model/count_mode_configuration.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def openapi_types():
9393
and the value is attribute type.
9494
"""
9595
return {
96+
"class_name": (str,), # noqa: E501
9697
"max_count": (int,), # noqa: E501
9798
}
9899

@@ -101,6 +102,7 @@ def discriminator():
101102
return None
102103

103104
attribute_map = {
105+
"class_name": "class_name", # noqa: E501
104106
"max_count": "max_count", # noqa: E501
105107
}
106108

@@ -110,9 +112,12 @@ def discriminator():
110112

111113
@classmethod
112114
@convert_js_args_to_python_args
113-
def _from_openapi_data(cls, *args, **kwargs): # noqa: E501
115+
def _from_openapi_data(cls, class_name, *args, **kwargs): # noqa: E501
114116
"""CountModeConfiguration - a model defined in OpenAPI
115117
118+
Args:
119+
class_name (str):
120+
116121
Keyword Args:
117122
_check_type (bool): if True, values for parameters in openapi_types
118123
will be type checked and a TypeError will be
@@ -173,6 +178,7 @@ def _from_openapi_data(cls, *args, **kwargs): # noqa: E501
173178
self._configuration = _configuration
174179
self._visited_composed_classes = _visited_composed_classes + (self.__class__,)
175180

181+
self.class_name = class_name
176182
for var_name, var_value in kwargs.items():
177183
if (
178184
var_name not in self.attribute_map
@@ -195,9 +201,12 @@ def _from_openapi_data(cls, *args, **kwargs): # noqa: E501
195201
])
196202

197203
@convert_js_args_to_python_args
198-
def __init__(self, *args, **kwargs): # noqa: E501
204+
def __init__(self, class_name, *args, **kwargs): # noqa: E501
199205
"""CountModeConfiguration - a model defined in OpenAPI
200206
207+
Args:
208+
class_name (str):
209+
201210
Keyword Args:
202211
_check_type (bool): if True, values for parameters in openapi_types
203212
will be type checked and a TypeError will be
@@ -256,6 +265,7 @@ def __init__(self, *args, **kwargs): # noqa: E501
256265
self._configuration = _configuration
257266
self._visited_composed_classes = _visited_composed_classes + (self.__class__,)
258267

268+
self.class_name = class_name
259269
for var_name, var_value in kwargs.items():
260270
if (
261271
var_name not in self.attribute_map

generated/groundlight_openapi_client/model/counting_result.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ class CountingResult(ModelNormal):
6464
}
6565

6666
validations = {
67+
("count",): {
68+
"inclusive_minimum": 0,
69+
},
6770
("confidence",): {
6871
"inclusive_maximum": 1.0,
6972
"inclusive_minimum": 0.0,

generated/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# generated by datamodel-codegen:
22
# filename: public-api.yaml
3-
# timestamp: 2024-12-05T21:27:07+00:00
3+
# timestamp: 2024-12-07T00:51:02+00:00
44

55
from __future__ import annotations
66

@@ -200,7 +200,7 @@ class BinaryClassificationResult(BaseModel):
200200
class CountingResult(BaseModel):
201201
confidence: Optional[confloat(ge=0.0, le=1.0)] = None
202202
source: Optional[Source] = None
203-
count: int
203+
count: conint(ge=0)
204204
greater_than_max: Optional[bool] = None
205205

206206

@@ -212,6 +212,7 @@ class MultiClassificationResult(BaseModel):
212212

213213
class CountModeConfiguration(BaseModel):
214214
max_count: Optional[conint(ge=1, le=50)] = None
215+
class_name: str
215216

216217

217218
class MultiClassModeConfiguration(BaseModel):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ packages = [
99
{include = "**/*.py", from = "src"},
1010
]
1111
readme = "README.md"
12-
version = "0.20.0"
12+
version = "0.21.0"
1313

1414
[tool.poetry.dependencies]
1515
# For certifi, use ">=" instead of "^" since it upgrades its "major version" every year, not really following semver

spec/public-api.yaml

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,14 @@ paths:
347347
--data-binary @path/to/filename.jpeg
348348
```
349349
parameters:
350+
- in: query
351+
name: confidence_threshold
352+
schema:
353+
type: number
354+
format: float
355+
minimum: 0
356+
maximum: 1
357+
description: The confidence threshold for the image query.
350358
- in: query
351359
name: detector_id
352360
schema:
@@ -1367,8 +1375,7 @@ components:
13671375
- ALGORITHM
13681376
count:
13691377
type: integer
1370-
minimum: null
1371-
maximum: null
1378+
minimum: 0
13721379
greater_than_max:
13731380
type: boolean
13741381
required:
@@ -1400,7 +1407,10 @@ components:
14001407
type: integer
14011408
minimum: 1
14021409
maximum: 50
1403-
required: []
1410+
class_name:
1411+
type: string
1412+
required:
1413+
- class_name
14041414
MultiClassModeConfiguration:
14051415
type: object
14061416
properties:
@@ -1410,8 +1420,6 @@ components:
14101420
type: string
14111421
num_classes:
14121422
type: integer
1413-
minimum: null
1414-
maximum: null
14151423
required:
14161424
- class_names
14171425
ChannelEnum:

src/groundlight/experimental_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,7 @@ def create_counting_detector( # noqa: PLR0913 # pylint: disable=too-many-argume
687687
self,
688688
name: str,
689689
query: str,
690+
class_name: str,
690691
*,
691692
max_count: Optional[int] = None,
692693
group_name: Optional[str] = None,
@@ -706,6 +707,7 @@ def create_counting_detector( # noqa: PLR0913 # pylint: disable=too-many-argume
706707
detector = gl.create_counting_detector(
707708
name="people_counter",
708709
query="How many people are in the image?",
710+
class_name="person",
709711
max_count=5,
710712
confidence_threshold=0.9,
711713
patience_time=30.0
@@ -718,6 +720,7 @@ def create_counting_detector( # noqa: PLR0913 # pylint: disable=too-many-argume
718720
719721
:param name: A short, descriptive name for the detector.
720722
:param query: A question about the count of an object in the image.
723+
:param class_name: The class name of the object to count.
721724
:param max_count: Maximum number of objects to count (default: 10)
722725
:param group_name: Optional name of a group to organize related detectors together.
723726
:param confidence_threshold: A value that sets the minimum confidence level required for the ML model's
@@ -747,7 +750,7 @@ def create_counting_detector( # noqa: PLR0913 # pylint: disable=too-many-argume
747750
# TODO: pull the BE defined default
748751
if max_count is None:
749752
max_count = 10
750-
mode_config = CountModeConfiguration(max_count=max_count)
753+
mode_config = CountModeConfiguration(max_count=max_count, class_name=class_name)
751754
detector_creation_input.mode_configuration = mode_config
752755
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
753756
return Detector.parse_obj(obj.to_dict())

test/unit/test_experimental.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def test_counting_detector(gl_experimental: ExperimentalApi):
100100
verify that we can create and submit to a counting detector
101101
"""
102102
name = f"Test {datetime.utcnow()}"
103-
created_detector = gl_experimental.create_counting_detector(name, "How many dogs")
103+
created_detector = gl_experimental.create_counting_detector(name, "How many dogs", "dog")
104104
assert created_detector is not None
105105
count_iq = gl_experimental.submit_image_query(created_detector, "test/assets/dog.jpeg")
106106
assert count_iq.result.count is not None

0 commit comments

Comments
 (0)