Skip to content

Commit a600a87

Browse files
Tyler Romero and Auto-format Bot authored
Support creating text-mode detectors (#363)
Co-authored-by: Auto-format Bot <[email protected]>
1 parent 04ac2ea commit a600a87

File tree

5 files changed

+130
-55
lines changed

5 files changed

+130
-55
lines changed

src/groundlight/experimental_api.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from groundlight_openapi_client.model.payload_template_request import PayloadTemplateRequest
3131
from groundlight_openapi_client.model.rule_request import RuleRequest
3232
from groundlight_openapi_client.model.status_enum import StatusEnum
33+
from groundlight_openapi_client.model.text_mode_configuration import TextModeConfiguration
3334
from groundlight_openapi_client.model.webhook_action_request import WebhookActionRequest
3435
from model import (
3536
ROI,
@@ -1053,6 +1054,60 @@ def create_bounding_box_detector( # noqa: PLR0913 # pylint: disable=too-many-ar
10531054
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
10541055
return Detector.parse_obj(obj.to_dict())
10551056

1057+
def create_text_recognition_detector(  # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-locals
    self,
    name: str,
    query: str,
    *,
    group_name: Optional[str] = None,
    confidence_threshold: Optional[float] = None,
    patience_time: Optional[float] = None,
    pipeline_config: Optional[str] = None,
    metadata: Union[dict, str, None] = None,
) -> Detector:
    """
    Creates a detector in text mode, which reads specified spans of text from images.

    **Example usage**::

        gl = ExperimentalApi()

        # Create a text recognition detector
        detector = gl.create_text_recognition_detector(
            name="date_and_time_detector",
            query="Read the date and time from the bottom left corner of the image.",
        )

    :param name: A short, descriptive name for the detector.
    :param query: An instruction describing which text to read from the image.
    :param group_name: Optional name of a group to organize related detectors together.
    :param confidence_threshold: A value that sets the minimum confidence level required for the ML model's
                        predictions. If confidence is below this threshold, the query may be sent for human review.
    :param patience_time: The maximum time in seconds that Groundlight will attempt to generate a
                        confident prediction before falling back to human review. Defaults to 30 seconds.
    :param pipeline_config: Advanced usage only. Configuration string needed to instantiate a specific
                        prediction pipeline for this detector.
    :param metadata: A dictionary or JSON string containing custom key/value pairs to associate with
                        the detector.

    :return: The created Detector object
    """
    # Build the common creation request first, then mark it as a text-mode detector.
    creation_request = self._prep_create_detector(
        name=name,
        query=query,
        group_name=group_name,
        confidence_threshold=confidence_threshold,
        patience_time=patience_time,
        pipeline_config=pipeline_config,
        metadata=metadata,
    )
    creation_request.mode = ModeEnum.TEXT
    # Text mode is configured with a default-constructed TextModeConfiguration.
    creation_request.mode_configuration = TextModeConfiguration()

    response = self.detectors_api.create_detector(creation_request, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
    return Detector.parse_obj(response.to_dict())
1110+
10561111
def _download_mlbinary_url(self, detector: Union[str, Detector]) -> EdgeModelInfo:
10571112
"""
10581113
Gets a temporary presigned URL to download the model binaries for the given detector, along

test/unit/conftest.py renamed to test/conftest.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
from model import Detector, ImageQuery, ImageQueryTypeEnum, ResultTypeEnum
66

77

8-
def pytest_configure(config):
8+
def pytest_configure(config): # pylint: disable=unused-argument
99
# Run environment check before tests
1010
gl = Groundlight()
11-
if gl._user_is_privileged():
12-
raise Exception(
11+
if gl._user_is_privileged(): # pylint: disable=protected-access
12+
raise RuntimeError(
1313
"ERROR: You are running tests with a privileged user. Please run tests with a non-privileged user."
1414
)
1515

@@ -31,6 +31,17 @@ def fixture_detector(gl: Groundlight) -> Detector:
3131
return gl.create_detector(name=name, query=query, pipeline_config=pipeline_config)
3232

3333

34+
@pytest.fixture(name="count_detector")
def fixture_count_detector(gl_experimental: ExperimentalApi) -> Detector:
    """Creates a new counting Test detector for the "dog" class."""
    detector_kwargs = {
        "name": f"Test {datetime.utcnow()}",  # Timestamp keeps the name unique
        "query": "How many dogs?",
        "class_name": "dog",
        "pipeline_config": "never-review-multi",  # always predicts 0
    }
    return gl_experimental.create_counting_detector(**detector_kwargs)
43+
44+
3445
@pytest.fixture(name="image_query_yes")
3546
def fixture_image_query_yes(gl: Groundlight, detector: Detector) -> ImageQuery:
3647
iq = gl.submit_image_query(detector=detector.id, image="test/assets/dog.jpeg", human_review="NEVER")
@@ -43,9 +54,27 @@ def fixture_image_query_no(gl: Groundlight, detector: Detector) -> ImageQuery:
4354
return iq
4455

4556

57+
@pytest.fixture(name="image_query_one")
def fixture_image_query_one(gl_experimental: Groundlight, count_detector: Detector) -> ImageQuery:
    """Submits the dog image to the counting detector with human_review="NEVER"."""
    return gl_experimental.submit_image_query(
        detector=count_detector.id,
        image="test/assets/dog.jpeg",
        human_review="NEVER",
    )
63+
64+
65+
@pytest.fixture(name="image_query_zero")
def fixture_image_query_zero(gl_experimental: Groundlight, count_detector: Detector) -> ImageQuery:
    """Submits the no-dogs image to the counting detector with human_review="NEVER"."""
    return gl_experimental.submit_image_query(
        detector=count_detector.id,
        image="test/assets/no_dogs.jpeg",
        human_review="NEVER",
    )
71+
72+
4673
@pytest.fixture(name="gl_experimental")
47-
def _gl() -> ExperimentalApi:
48-
return ExperimentalApi()
74+
def fixture_gl_experimental() -> ExperimentalApi:
    """Creates an ExperimentalApi client object for testing, with DEFAULT_WAIT set to 10."""
    client = ExperimentalApi()
    client.DEFAULT_WAIT = 10
    return client
4978

5079

5180
@pytest.fixture(name="initial_iq")

test/integration/test_groundlight.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -72,35 +72,6 @@ def is_valid_display_label(label: str) -> bool:
7272
return label in VALID_DISPLAY_LABELS
7373

7474

75-
@pytest.fixture(name="gl")
76-
def fixture_gl() -> Groundlight:
77-
"""Creates a Groundlight client object for testing."""
78-
_gl = Groundlight()
79-
_gl.DEFAULT_WAIT = 10
80-
return _gl
81-
82-
83-
@pytest.fixture(name="detector")
84-
def fixture_detector(gl: Groundlight) -> Detector:
85-
"""Creates a new Test detector."""
86-
name = f"Test {datetime.utcnow()}" # Need a unique name
87-
query = "Is there a dog?"
88-
pipeline_config = "never-review"
89-
return gl.create_detector(name=name, query=query, pipeline_config=pipeline_config)
90-
91-
92-
@pytest.fixture(name="image_query_yes")
93-
def fixture_image_query_yes(gl: Groundlight, detector: Detector) -> ImageQuery:
94-
iq = gl.submit_image_query(detector=detector.id, image="test/assets/dog.jpeg", human_review="NEVER")
95-
return iq
96-
97-
98-
@pytest.fixture(name="image_query_no")
99-
def fixture_image_query_no(gl: Groundlight, detector: Detector) -> ImageQuery:
100-
iq = gl.submit_image_query(detector=detector.id, image="test/assets/cat.jpeg", human_review="NEVER")
101-
return iq
102-
103-
10475
@pytest.fixture(name="image")
def fixture_image() -> str:
    """Returns the path string of the dog test image."""
    image_path = "test/assets/dog.jpeg"
    return image_path

test/unit/test_experimental.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -66,19 +66,13 @@ def test_update_detector_escalation_type(gl_experimental: ExperimentalApi):
6666
updated_detector.escalation_type == "STANDARD"
6767

6868

69-
@pytest.mark.skip(
70-
reason=(
71-
"Users currently don't have permission to turn object detection on their own. If you have questions, reach out"
72-
" to Groundlight support."
73-
)
74-
)
75-
def test_submit_roi(gl_experimental: ExperimentalApi, image_query_yes: ImageQuery):
69+
def test_submit_roi(gl_experimental: ExperimentalApi, image_query_one: ImageQuery):
7670
"""
7771
verify that we can submit an ROI
7872
"""
7973
label_name = "dog"
8074
roi = gl_experimental.create_roi(label_name, (0, 0), (0.5, 0.5))
81-
gl_experimental.add_label(image_query_yes.id, "YES", [roi])
75+
gl_experimental.add_label(image_query_one.id, 1, [roi])
8276

8377

8478
@pytest.mark.skip(
@@ -87,21 +81,21 @@ def test_submit_roi(gl_experimental: ExperimentalApi, image_query_yes: ImageQuer
8781
" to Groundlight support."
8882
)
8983
)
90-
def test_submit_multiple_rois(gl_experimental: ExperimentalApi, image_query_no: ImageQuery):
84+
def test_submit_multiple_rois(gl_experimental: ExperimentalApi, image_query_one: ImageQuery):
9185
"""
9286
verify that we can submit multiple ROIs
9387
"""
9488
label_name = "dog"
9589
roi = gl_experimental.create_roi(label_name, (0, 0), (0.5, 0.5))
96-
gl_experimental.add_label(image_query_no, "YES", [roi] * 3)
90+
gl_experimental.add_label(image_query_one, 3, [roi] * 3)
9791

9892

9993
def test_counting_detector(gl_experimental: ExperimentalApi):
10094
"""
10195
verify that we can create and submit to a counting detector
10296
"""
10397
name = f"Test {datetime.utcnow()}"
104-
created_detector = gl_experimental.create_counting_detector(name, "How many dogs", "dog")
98+
created_detector = gl_experimental.create_counting_detector(name, "How many dogs", "dog", confidence_threshold=0.0)
10599
assert created_detector is not None
106100
count_iq = gl_experimental.submit_image_query(created_detector, "test/assets/dog.jpeg")
107101
assert count_iq.result.count is not None
@@ -112,7 +106,7 @@ def test_counting_detector_async(gl_experimental: ExperimentalApi):
112106
verify that we can create and submit to a counting detector
113107
"""
114108
name = f"Test {datetime.utcnow()}"
115-
created_detector = gl_experimental.create_counting_detector(name, "How many dogs", "dog")
109+
created_detector = gl_experimental.create_counting_detector(name, "How many dogs", "dog", confidence_threshold=0.0)
116110
assert created_detector is not None
117111
async_iq = gl_experimental.ask_async(created_detector, "test/assets/dog.jpeg")
118112
# attempting to access fields within the result should raise an exception
@@ -126,27 +120,34 @@ def test_counting_detector_async(gl_experimental: ExperimentalApi):
126120
assert _image_query.result is not None
127121

128122

129-
@pytest.mark.skip(
130-
reason=(
131-
"General users currently currently can't use multiclass detectors. If you have questions, reach out"
132-
" to Groundlight support, or upgrade your plan."
133-
)
134-
)
135123
def test_multiclass_detector(gl_experimental: ExperimentalApi):
136124
"""
137125
verify that we can create and submit to a multi-class detector
138126
"""
139127
name = f"Test {datetime.utcnow()}"
140128
class_names = ["Golden Retriever", "Labrador Retriever", "Poodle"]
141129
created_detector = gl_experimental.create_multiclass_detector(
142-
name, "What kind of dog is this?", class_names=class_names
130+
name, "What kind of dog is this?", class_names=class_names, confidence_threshold=0.0
143131
)
144132
assert created_detector is not None
145133
mc_iq = gl_experimental.submit_image_query(created_detector, "test/assets/dog.jpeg")
146134
assert mc_iq.result.label is not None
147135
assert mc_iq.result.label in class_names
148136

149137

138+
def test_text_recognition_detector(gl_experimental: ExperimentalApi):
    """
    verify that a text recognition detector can be created and that a submitted
    image query comes back with a text result
    """
    detector_name = f"Test {datetime.utcnow()}"
    text_detector = gl_experimental.create_text_recognition_detector(
        detector_name,
        "What is the date and time?",
        confidence_threshold=0.0,
    )
    assert text_detector is not None

    text_iq = gl_experimental.submit_image_query(text_detector, "test/assets/dog.jpeg")
    assert text_iq.result.text is not None
149+
150+
150151
@pytest.mark.skip(
151152
reason=(
152153
"General users currently currently can't use bounding box detectors. If you have questions, reach out"
@@ -159,7 +160,7 @@ def test_bounding_box_detector(gl_experimental: ExperimentalApi):
159160
"""
160161
name = f"Test {datetime.now(timezone.utc)}"
161162
created_detector = gl_experimental.create_bounding_box_detector(
162-
name, "Draw a bounding box around each dog in the image", "dog"
163+
name, "Draw a bounding box around each dog in the image", "dog", confidence_threshold=0.0
163164
)
164165
assert created_detector is not None
165166
bbox_iq = gl_experimental.submit_image_query(created_detector, "test/assets/dog.jpeg")
@@ -179,7 +180,7 @@ def test_bounding_box_detector_async(gl_experimental: ExperimentalApi):
179180
"""
180181
name = f"Test {datetime.now(timezone.utc)}"
181182
created_detector = gl_experimental.create_bounding_box_detector(
182-
name, "Draw a bounding box around each dog in the image", "dog"
183+
name, "Draw a bounding box around each dog in the image", "dog", confidence_threshold=0.0
183184
)
184185
assert created_detector is not None
185186
async_iq = gl_experimental.ask_async(created_detector, "test/assets/dog.jpeg")

test/unit/test_labels.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,22 @@ def test_multiclass_labels(gl_experimental: ExperimentalApi):
6464
assert iq1.result.label == "cherry"
6565
with pytest.raises(ApiException) as _:
6666
gl_experimental.add_label(iq1, "MAYBE")
67+
68+
69+
def test_text_recognition_labels(gl_experimental: ExperimentalApi):
    """
    verify that text labels (including the empty string) are read back as the
    result text, and that an "UNCLEAR" label is read back as no text
    """
    detector_name = f"Test text recognition labels{datetime.utcnow()}"
    detector = gl_experimental.create_text_recognition_detector(detector_name, "test_query")
    image_query = gl_experimental.submit_image_query(detector, "test/assets/cat.jpeg")

    # Apply each label in turn, re-fetch the image query, and check the stored text.
    for label in ("apple text", "banana text", ""):
        gl_experimental.add_label(image_query, label)
        image_query = gl_experimental.get_image_query(image_query.id)
        assert image_query.result.text == label

    # "UNCLEAR" is the one label expected to come back as None rather than as text.
    gl_experimental.add_label(image_query, "UNCLEAR")
    image_query = gl_experimental.get_image_query(image_query.id)
    assert image_query.result.text is None

0 commit comments

Comments
 (0)