Skip to content

Commit c1c1925

Browse files
committed
Fix detect mocking in CLI tests.
1 parent d61ca7a commit c1c1925

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

tests/cli/test_detect_cli.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pathlib import Path
22
from typing import Callable
3-
from unittest.mock import patch
3+
from unittest.mock import Mock, patch
44

55
import pytest
66
import yaml
@@ -203,7 +203,9 @@ def test_pass_detect_cli(
203203
test_data: dict,
204204
detect_cli: Callable,
205205
) -> None:
206-
with patch("OTVision.detect") as mock_detect:
206+
with patch("detect.OTVisionDetect") as mock_detect:
207+
mock_detect_instance = Mock()
208+
mock_detect.return_value = mock_detect_instance
207209
command = [
208210
*test_data["paths"][PASSED].split(),
209211
*test_data["weights"][PASSED].split(),
@@ -223,6 +225,7 @@ def test_pass_detect_cli(
223225
expected_config = create_expected_config_from_test_data(test_data)
224226

225227
mock_detect.assert_called_once_with(expected_config)
228+
mock_detect_instance.start.assert_called_once()
226229

227230
@pytest.mark.parametrize(argnames="test_fail_data", argvalues=TEST_FAIL_DATA)
228231
def test_fail_wrong_types_passed_to_detect_cli(
@@ -232,7 +235,7 @@ def test_fail_wrong_types_passed_to_detect_cli(
232235
test_fail_data: dict,
233236
) -> None:
234237

235-
with patch("OTVision.detect"):
238+
with patch("detect.OTVisionDetect"):
236239
with pytest.raises(SystemExit) as e:
237240
command = [*test_fail_data[PASSED].split()]
238241
detect_cli(argv=list(filter(None, command)))
@@ -244,13 +247,13 @@ def test_fail_wrong_types_passed_to_detect_cli(
244247
def test_fail_not_existing_path_passed_to_detect_cli(
245248
self, detect_cli: Callable, passed: str
246249
) -> None:
247-
with patch("OTVision.detect"):
250+
with patch("detect.OTVisionDetect"):
248251
with pytest.raises(FileNotFoundError):
249252
command = required_arguments.split() + [*passed.split()]
250253
detect_cli(argv=list(filter(None, command)))
251254

252255
def test_fail_no_paths_passed_to_detect_cli(self, detect_cli: Callable) -> None:
253-
with patch("OTVision.detect"):
256+
with patch("detect.OTVisionDetect"):
254257
error_msg = (
255258
"No paths have been passed as command line args."
256259
+ "No paths have been defined in the user config."

0 commit comments

Comments
 (0)