From d11db787e375b8b6f5ed7d929665c948ac5a08bb Mon Sep 17 00:00:00 2001 From: amussara Date: Mon, 16 Feb 2026 15:54:58 -0500 Subject: [PATCH] Add comprehensive tests for core modules and enhance RangeSet/APIClient - Add test_range_set.py: 30+ tests covering init, normalization, contains, subtract_ids, random_sample, prioritized_sample, intersection, union - Add test_sampling_list.py: Tests for SamplingListManager init/rotate logic including fill mode, shrink mode, safety checks, prioritize_new - Add test_models.py: Tests for Miner, SampleSubmission, Result models including serialization, sign_data, verify, repr - Add test_api_client_extended.py: Tests for DELETE, PUT 204, POST error parsing - Enhance RangeSet with from_ids() classmethod, __contains__ (binary search), __len__, __eq__, intersection(), and union() methods - Add APIClient.delete() method for DELETE HTTP requests --- affine/core/range_set.py | 109 +++++++++++ affine/utils/api_client.py | 42 +++++ tests/test_api_client_extended.py | 123 ++++++++++++ tests/test_models.py | 177 +++++++++++++++++ tests/test_range_set.py | 303 ++++++++++++++++++++++++++++++ tests/test_sampling_list.py | 187 ++++++++++++++++++ 6 files changed, 941 insertions(+) create mode 100644 tests/test_api_client_extended.py create mode 100644 tests/test_models.py create mode 100644 tests/test_range_set.py create mode 100644 tests/test_sampling_list.py diff --git a/affine/core/range_set.py b/affine/core/range_set.py index 56c5b263..2bdd1eda 100644 --- a/affine/core/range_set.py +++ b/affine/core/range_set.py @@ -24,6 +24,38 @@ def __init__(self, ranges: List[List[int]]): """ self.ranges = self._normalize_ranges(ranges) + @classmethod + def from_ids(cls, ids: List[int]) -> 'RangeSet': + """Create a RangeSet from a list of individual IDs. + + Efficiently converts a list of IDs into merged intervals. + For example, [1, 2, 3, 7, 8, 10] becomes [[1, 4), [7, 9), [10, 11)]. + + Args: + ids: List of integer IDs (need not be sorted or unique) + + Returns: + RangeSet covering exactly the given IDs + """ + if not ids: + return cls([]) + + sorted_ids = sorted(set(ids)) + ranges = [] + start = sorted_ids[0] + prev = start + + for id_val in sorted_ids[1:]: + if id_val == prev + 1: + prev = id_val + else: + ranges.append([start, prev + 1]) + start = id_val + prev = id_val + + ranges.append([start, prev + 1]) + return cls(ranges) + def _normalize_ranges(self, ranges: List[List[int]]) -> List[Tuple[int, int]]: """Normalize ranges: merge overlapping intervals and sort. @@ -242,6 +274,83 @@ def to_list(self) -> List[List[int]]: """ return [[start, end] for start, end in self.ranges] + def __contains__(self, item: int) -> bool: + """Check if an ID is contained in any range using binary search. + + Args: + item: Integer ID to check + + Returns: + True if item is in any range + """ + import bisect + + if not self.ranges: + return False + + # Binary search for the rightmost range whose start <= item + starts = [r[0] for r in self.ranges] + idx = bisect.bisect_right(starts, item) - 1 + + if idx < 0: + return False + + start, end = self.ranges[idx] + return start <= item < end + + def __len__(self) -> int: + """Return the total number of IDs in all ranges.""" + return self.size() + + def __eq__(self, other: object) -> bool: + """Check equality with another RangeSet.""" + if not isinstance(other, RangeSet): + return NotImplemented + return self.ranges == other.ranges + + def intersection(self, other: 'RangeSet') -> 'RangeSet': + """Compute the intersection of two RangeSets. + + Args: + other: Another RangeSet + + Returns: + New RangeSet containing only IDs present in both + """ + result = [] + i, j = 0, 0 + + while i < len(self.ranges) and j < len(other.ranges): + a_start, a_end = self.ranges[i] + b_start, b_end = other.ranges[j] + + # Find overlap + overlap_start = max(a_start, b_start) + overlap_end = min(a_end, b_end) + + if overlap_start < overlap_end: + result.append([overlap_start, overlap_end]) + + # Advance the range that ends first + if a_end < b_end: + i += 1 + else: + j += 1 + + return RangeSet(result) + + def union(self, other: 'RangeSet') -> 'RangeSet': + """Compute the union of two RangeSets. + + Args: + other: Another RangeSet + + Returns: + New RangeSet containing IDs from either set + """ + combined = self.to_list() + other.to_list() + return RangeSet(combined) + def __repr__(self) -> str: """String representation for debugging.""" return f"RangeSet({self.to_list()}, size={self.size()})" \ No newline at end of file diff --git a/affine/utils/api_client.py b/affine/utils/api_client.py index a50dadc3..1c20bcfc 100644 --- a/affine/utils/api_client.py +++ b/affine/utils/api_client.py @@ -290,6 +290,48 @@ async def put( raise NetworkError(f"Network error during PUT {url}: {e}", url, e) + async def delete( + self, + endpoint: str, + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + ) -> Any: + """Make DELETE request to API endpoint. + + Args: + endpoint: API endpoint path + params: Optional query parameters + headers: Optional request headers + + Returns: + Response data dict on success, empty dict for 204 No Content + + Raises: + NetworkError: On network/connection errors + ApiResponseError: On non-2xx response or invalid JSON + """ + + url = f"{self.base_url}{endpoint}" + logger.debug(f"DELETE {url}") + + try: + async with self._session.delete(url, params=params, headers=headers) as response: + if response.status >= 400: + body = await response.text() + raise ApiResponseError(f"HTTP {response.status}: {body[:200]}", response.status, url, body) + + if response.status == 204: + return {} + + try: + return await response.json() + except Exception: + raw = await response.text() + raise ApiResponseError(f"Invalid JSON response: {raw[:200]}", response.status, url, raw) + + except aiohttp.ClientError as e: + raise NetworkError(f"Network error during DELETE {url}: {e}", url, e) + async def get_chute_info(self, chute_id: str) -> Optional[Dict]: """Get chute info from Chutes API. diff --git a/tests/test_api_client_extended.py b/tests/test_api_client_extended.py new file mode 100644 index 00000000..e4fc85ba --- /dev/null +++ b/tests/test_api_client_extended.py @@ -0,0 +1,123 @@ +"""Extended tests for APIClient - DELETE method and edge cases.""" + +import pytest +import aiohttp +from unittest.mock import MagicMock, AsyncMock +from affine.utils.api_client import APIClient +from affine.utils.errors import NetworkError, ApiResponseError + + +class TestAPIClientDelete: + """Test the DELETE method on APIClient.""" + + @pytest.mark.asyncio + async def test_delete_success_json(self): + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json.return_value = {"deleted": True} + + mock_session = MagicMock() + mock_ctx = MagicMock() + mock_ctx.__aenter__.return_value = mock_response + mock_session.delete.return_value = mock_ctx + + client = APIClient("http://test.com", mock_session) + result = await client.delete("/items/123") + assert result == {"deleted": True} + + @pytest.mark.asyncio + async def test_delete_204_no_content(self): + mock_response = AsyncMock() + mock_response.status = 204 + + mock_session = MagicMock() + mock_ctx = MagicMock() + mock_ctx.__aenter__.return_value = mock_response + mock_session.delete.return_value = mock_ctx + + client = APIClient("http://test.com", mock_session) + result = await client.delete("/items/123") + assert result == {} + + @pytest.mark.asyncio + async def test_delete_404(self): + mock_response = AsyncMock() + mock_response.status = 404 + mock_response.text.return_value = "Not Found" + + mock_session = MagicMock() + mock_ctx = MagicMock() + mock_ctx.__aenter__.return_value = mock_response + mock_session.delete.return_value = mock_ctx + + client = APIClient("http://test.com", mock_session) + with pytest.raises(ApiResponseError) as exc: + await client.delete("/items/999") + assert exc.value.status_code == 404 + + @pytest.mark.asyncio + async def test_delete_network_error(self): + mock_session = MagicMock() + mock_ctx = MagicMock() + mock_ctx.__aenter__.side_effect = aiohttp.ClientConnectionError("Connection refused") + mock_session.delete.return_value = mock_ctx + + client = APIClient("http://test.com", mock_session) + with pytest.raises(NetworkError): + await client.delete("/items/1") + + +class TestAPIClientPut: + """Test PUT method edge cases.""" + + @pytest.mark.asyncio + async def test_put_204_no_content(self): + mock_response = AsyncMock() + mock_response.status = 204 + + mock_session = MagicMock() + mock_ctx = MagicMock() + mock_ctx.__aenter__.return_value = mock_response + mock_session.put.return_value = mock_ctx + + client = APIClient("http://test.com", mock_session) + result = await client.put("/items/1", json={"name": "updated"}) + assert result == {} + + @pytest.mark.asyncio + async def test_put_bad_json_response(self): + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.json.side_effect = ValueError("Bad JSON") + mock_response.text.return_value = "not json" + + mock_session = MagicMock() + mock_ctx = MagicMock() + mock_ctx.__aenter__.return_value = mock_response + mock_session.put.return_value = mock_ctx + + client = APIClient("http://test.com", mock_session) + with pytest.raises(ApiResponseError) as exc: + await client.put("/items/1") + assert "Invalid JSON" in str(exc.value) + + +class TestAPIClientPost: + """Test POST method edge cases.""" + + @pytest.mark.asyncio + async def test_post_json_error_parsing(self): + """POST with output_json=True and a JSON-formatted error body.""" + mock_response = AsyncMock() + mock_response.status = 422 + mock_response.text.return_value = '{"detail": "Validation failed"}' + + mock_session = MagicMock() + mock_ctx = MagicMock() + mock_ctx.__aenter__.return_value = mock_response + mock_session.post.return_value = mock_ctx + + client = APIClient("http://test.com", mock_session) + with pytest.raises(ApiResponseError) as exc: + await client.post("/submit", output_json=True) + assert exc.value.status_code == 422 diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 00000000..feab2aa6 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,177 @@ +"""Tests for core data models - Miner, SampleSubmission, Result.""" + +import json +import time +import pytest +from unittest.mock import MagicMock, patch +from affine.core.models import Miner, SampleSubmission, Result, _truncate + + +class TestTruncate: + """Test the _truncate helper.""" + + def test_none_returns_empty(self): + assert _truncate(None) == "" + + def test_short_text_unchanged(self): + assert _truncate("hello", 80) == "hello" + + def test_long_text_truncated(self): + result = _truncate("a" * 200, 10) + assert len(result) <= 10 + assert "…" in result + + def test_empty_string(self): + assert _truncate("") == "" + + +class TestMiner: + """Test Miner model.""" + + def test_basic_creation(self): + miner = Miner(uid=1, hotkey="5abc123") + assert miner.uid == 1 + assert miner.hotkey == "5abc123" + assert miner.model is None + + def test_full_creation(self): + miner = Miner( + uid=42, + hotkey="hotkey123", + model="my_model", + revision="v1.0", + block=100, + slug="my-slug", + ) + assert miner.uid == 42 + assert miner.model == "my_model" + assert miner.revision == "v1.0" + assert miner.block == 100 + + def test_model_dump_property(self): + miner = Miner(uid=1, hotkey="key") + # model_dump is an alias for dict + data = miner.model_dump() + assert data["uid"] == 1 + assert data["hotkey"] == "key" + + def test_optional_fields_default_none(self): + miner = Miner(uid=0, hotkey="k") + assert miner.model is None + assert miner.revision is None + assert miner.block is None + assert miner.chute is None + assert miner.slug is None + assert miner.weights_shas is None + + +class TestSampleSubmission: + """Test SampleSubmission model.""" + + def test_basic_creation(self): + sub = SampleSubmission(task_uuid="uuid-1", score=0.95, latency_ms=100) + assert sub.task_uuid == "uuid-1" + assert sub.score == 0.95 + assert sub.latency_ms == 100 + assert sub.extra == {} + assert sub.signature == "" + + def test_negative_score_allowed(self): + sub = SampleSubmission(task_uuid="t", score=-5.0, latency_ms=0) + assert sub.score == -5.0 + + def test_negative_latency_rejected(self): + with pytest.raises(Exception): # pydantic validation + SampleSubmission(task_uuid="t", score=1.0, latency_ms=-1) + + def test_get_sign_data_deterministic(self): + sub = SampleSubmission( + task_uuid="test-uuid", + score=0.123456, + latency_ms=500, + extra={"b": 2, "a": 1} + ) + data1 = sub.get_sign_data() + data2 = sub.get_sign_data() + assert data1 == data2 + # Keys should be sorted in extra + assert '"a": 1' in data1 + assert data1.index('"a"') < data1.index('"b"') + + def test_get_sign_data_format(self): + sub = SampleSubmission(task_uuid="u1", score=1.5, latency_ms=100, extra={}) + data = sub.get_sign_data() + assert data == "u1:1.500000:100:{}" + + def test_verify_without_signature_fails(self): + sub = SampleSubmission(task_uuid="t", score=1.0, latency_ms=0, signature="") + # Empty signature should fail verification + result = sub.verify("some_hotkey") + assert result is False + + def test_verify_invalid_hex_fails(self): + sub = SampleSubmission(task_uuid="t", score=1.0, latency_ms=0, signature="not_hex") + result = sub.verify("some_hotkey") + assert result is False + + +class TestResult: + """Test Result model.""" + + def test_basic_creation(self): + r = Result(env="coding", score=0.9, latency_seconds=1.5, success=True) + assert r.env == "coding" + assert r.score == 0.9 + assert r.success is True + assert r.error is None + + def test_failed_result(self): + r = Result( + env="math", + score=0.0, + latency_seconds=0.1, + success=False, + error="Timeout" + ) + assert r.success is False + assert r.error == "Timeout" + + def test_timestamp_auto_set(self): + before = time.time() + r = Result(env="test", score=0.5, latency_seconds=0.1, success=True) + after = time.time() + assert before <= r.timestamp <= after + + def test_extra_defaults_empty(self): + r = Result(env="e", score=0.0, latency_seconds=0.0, success=True) + assert r.extra == {} + + def test_dict_serialization(self): + r = Result(env="env1", score=0.5, latency_seconds=1.0, success=True) + d = r.dict() + assert d["env"] == "env1" + assert d["score"] == 0.5 + assert d["success"] is True + + def test_json_serialization(self): + r = Result(env="env1", score=0.5, latency_seconds=1.0, success=True) + j = r.json() + parsed = json.loads(j) + assert parsed["env"] == "env1" + + def test_repr(self): + r = Result( + env="coding", + score=0.9876, + latency_seconds=1.0, + success=True, + miner_hotkey="abcdef123456789" + ) + s = repr(r) + assert "Result" in s + assert "coding" in s + assert "0.9876" in s + + def test_str_same_as_repr(self): + r = Result(env="e", score=0.0, latency_seconds=0.0, success=True) + assert str(r) == repr(r) diff --git a/tests/test_range_set.py b/tests/test_range_set.py new file mode 100644 index 00000000..3771c64e --- /dev/null +++ b/tests/test_range_set.py @@ -0,0 +1,303 @@ +"""Tests for RangeSet - interval-based set operations.""" + +import pytest +import random +from affine.core.range_set import RangeSet + + +class TestRangeSetInit: + """Test RangeSet initialization and normalization.""" + + def test_empty_ranges(self): + rs = RangeSet([]) + assert rs.size() == 0 + assert rs.ranges == [] + + def test_single_range(self): + rs = RangeSet([[0, 10]]) + assert rs.size() == 10 + assert rs.ranges == [(0, 10)] + + def test_zero_width_ranges_filtered(self): + rs = RangeSet([[5, 5], [3, 3], [1, 2]]) + assert rs.size() == 1 + assert rs.ranges == [(1, 2)] + + def test_overlapping_ranges_merged(self): + rs = RangeSet([[0, 5], [3, 8]]) + assert rs.size() == 8 + assert rs.ranges == [(0, 8)] + + def test_adjacent_ranges_merged(self): + rs = RangeSet([[0, 5], [5, 10]]) + assert rs.size() == 10 + assert rs.ranges == [(0, 10)] + + def test_disjoint_ranges_preserved(self): + rs = RangeSet([[0, 3], [10, 15]]) + assert rs.size() == 8 + assert rs.ranges == [(0, 3), (10, 15)] + + def test_unsorted_ranges_sorted(self): + rs = RangeSet([[10, 20], [0, 5]]) + assert rs.ranges == [(0, 5), (10, 20)] + + def test_multiple_overlapping_ranges(self): + rs = RangeSet([[0, 5], [3, 8], [7, 12], [20, 25]]) + assert rs.ranges == [(0, 12), (20, 25)] + assert rs.size() == 17 + + def test_negative_width_range_filtered(self): + rs = RangeSet([[10, 5]]) + assert rs.size() == 0 + + +class TestRangeSetFromIds: + """Test the from_ids classmethod.""" + + def test_empty_ids(self): + rs = RangeSet.from_ids([]) + assert rs.size() == 0 + + def test_single_id(self): + rs = RangeSet.from_ids([42]) + assert rs.size() == 1 + assert rs.ranges == [(42, 43)] + + def test_consecutive_ids(self): + rs = RangeSet.from_ids([1, 2, 3, 4, 5]) + assert rs.ranges == [(1, 6)] + + def test_gaps_in_ids(self): + rs = RangeSet.from_ids([1, 2, 3, 7, 8, 10]) + assert rs.ranges == [(1, 4), (7, 9), (10, 11)] + assert rs.size() == 6 + + def test_duplicate_ids(self): + rs = RangeSet.from_ids([1, 1, 2, 2, 3]) + assert rs.ranges == [(1, 4)] + assert rs.size() == 3 + + def test_unsorted_ids(self): + rs = RangeSet.from_ids([5, 1, 3, 2, 4]) + assert rs.ranges == [(1, 6)] + + +class TestRangeSetContains: + """Test __contains__ (in operator).""" + + def test_contains_in_range(self): + rs = RangeSet([[10, 20]]) + assert 10 in rs + assert 15 in rs + assert 19 in rs + + def test_not_contains_at_end(self): + rs = RangeSet([[10, 20]]) + assert 20 not in rs # end is exclusive + + def test_not_contains_before_start(self): + rs = RangeSet([[10, 20]]) + assert 9 not in rs + + def test_contains_multiple_ranges(self): + rs = RangeSet([[0, 5], [10, 15]]) + assert 0 in rs + assert 4 in rs + assert 5 not in rs + assert 7 not in rs + assert 10 in rs + assert 14 in rs + + def test_contains_empty(self): + rs = RangeSet([]) + assert 0 not in rs + + +class TestRangeSetLen: + """Test __len__.""" + + def test_len(self): + rs = RangeSet([[0, 100], [200, 300]]) + assert len(rs) == 200 + + +class TestRangeSetEquality: + """Test __eq__.""" + + def test_equal(self): + assert RangeSet([[0, 10]]) == RangeSet([[0, 10]]) + + def test_not_equal(self): + assert RangeSet([[0, 10]]) != RangeSet([[0, 11]]) + + def test_overlapping_normalized_equal(self): + assert RangeSet([[0, 5], [3, 10]]) == RangeSet([[0, 10]]) + + def test_not_equal_other_type(self): + assert RangeSet([[0, 10]]) != "not a rangeset" + + +class TestRangeSetSubtractIds: + """Test subtract_ids.""" + + def test_subtract_empty(self): + rs = RangeSet([[0, 10]]) + result = rs.subtract_ids(set()) + assert result.size() == 10 + + def test_subtract_single_id_middle(self): + rs = RangeSet([[0, 10]]) + result = rs.subtract_ids({5}) + assert result.size() == 9 + assert 5 not in result + assert 4 in result + assert 6 in result + + def test_subtract_first_id(self): + rs = RangeSet([[0, 5]]) + result = rs.subtract_ids({0}) + assert result.size() == 4 + assert 0 not in result + + def test_subtract_last_id(self): + rs = RangeSet([[0, 5]]) + result = rs.subtract_ids({4}) + assert result.size() == 4 + assert 4 not in result + + def test_subtract_all_ids(self): + rs = RangeSet([[0, 3]]) + result = rs.subtract_ids({0, 1, 2}) + assert result.size() == 0 + + def test_subtract_ids_outside_range(self): + rs = RangeSet([[0, 5]]) + result = rs.subtract_ids({10, 20, 30}) + assert result.size() == 5 + + def test_subtract_from_multiple_ranges(self): + rs = RangeSet([[0, 5], [10, 15]]) + result = rs.subtract_ids({2, 12}) + assert result.size() == 8 + assert 2 not in result + assert 12 not in result + + +class TestRangeSetRandomSample: + """Test random_sample.""" + + def test_sample_zero(self): + rs = RangeSet([[0, 100]]) + assert rs.random_sample(0) == [] + + def test_sample_exceeds_size(self): + rs = RangeSet([[0, 5]]) + with pytest.raises(ValueError, match="Cannot sample"): + rs.random_sample(10) + + def test_sample_all(self): + rs = RangeSet([[0, 5]]) + samples = rs.random_sample(5) + assert sorted(samples) == [0, 1, 2, 3, 4] + + def test_sample_within_range(self): + rs = RangeSet([[10, 20]]) + samples = rs.random_sample(5) + assert len(samples) == 5 + assert all(10 <= s < 20 for s in samples) + assert len(set(samples)) == 5 # unique + + def test_sample_multiple_ranges(self): + rs = RangeSet([[0, 5], [100, 105]]) + samples = rs.random_sample(8) + assert len(samples) == 8 + for s in samples: + assert (0 <= s < 5) or (100 <= s < 105) + + +class TestRangeSetPrioritizedSample: + """Test prioritized_sample.""" + + def test_prioritized_sample_zero(self): + rs = RangeSet([[0, 100]]) + assert rs.prioritized_sample(0) == [] + + def test_prioritized_sample_exceeds_size(self): + rs = RangeSet([[0, 5]]) + with pytest.raises(ValueError, match="Cannot sample"): + rs.prioritized_sample(10) + + def test_prioritized_favors_later_ranges(self): + """When sampling fewer than a later range, all should come from later range.""" + rs = RangeSet([[0, 100], [1000, 1005]]) + # Sample 5 — should all come from the later range [1000, 1005) + samples = rs.prioritized_sample(5) + assert len(samples) == 5 + assert all(1000 <= s < 1005 for s in samples) + + def test_prioritized_sample_all(self): + rs = RangeSet([[0, 3], [10, 13]]) + samples = rs.prioritized_sample(6) + assert sorted(samples) == [0, 1, 2, 10, 11, 12] + + +class TestRangeSetIntersection: + """Test intersection.""" + + def test_no_overlap(self): + a = RangeSet([[0, 5]]) + b = RangeSet([[10, 15]]) + assert a.intersection(b).size() == 0 + + def test_full_overlap(self): + a = RangeSet([[0, 10]]) + b = RangeSet([[0, 10]]) + assert a.intersection(b) == RangeSet([[0, 10]]) + + def test_partial_overlap(self): + a = RangeSet([[0, 10]]) + b = RangeSet([[5, 15]]) + result = a.intersection(b) + assert result == RangeSet([[5, 10]]) + + def test_subset(self): + a = RangeSet([[0, 20]]) + b = RangeSet([[5, 10]]) + assert a.intersection(b) == RangeSet([[5, 10]]) + + def test_multiple_intersections(self): + a = RangeSet([[0, 5], [10, 15]]) + b = RangeSet([[3, 12]]) + result = a.intersection(b) + assert result == RangeSet([[3, 5], [10, 12]]) + + +class TestRangeSetUnion: + """Test union.""" + + def test_disjoint_union(self): + a = RangeSet([[0, 5]]) + b = RangeSet([[10, 15]]) + result = a.union(b) + assert result.size() == 10 + + def test_overlapping_union(self): + a = RangeSet([[0, 10]]) + b = RangeSet([[5, 15]]) + result = a.union(b) + assert result == RangeSet([[0, 15]]) + + +class TestRangeSetToList: + """Test to_list serialization.""" + + def test_round_trip(self): + original = [[0, 5], [10, 15], [20, 25]] + rs = RangeSet(original) + assert rs.to_list() == original + + def test_repr(self): + rs = RangeSet([[0, 5]]) + assert "RangeSet" in repr(rs) + assert "size=5" in repr(rs) diff --git a/tests/test_sampling_list.py b/tests/test_sampling_list.py new file mode 100644 index 00000000..2c9e621b --- /dev/null +++ b/tests/test_sampling_list.py @@ -0,0 +1,187 @@ +"""Tests for SamplingListManager and get_task_id_set_from_config.""" + +import pytest +from affine.core.sampling_list import SamplingListManager, get_task_id_set_from_config + + +class TestGetTaskIdSetFromConfig: + """Test get_task_id_set_from_config helper.""" + + def test_with_sampling_list(self): + config = {'sampling_config': {'sampling_list': [1, 2, 3, 4, 5]}} + result = get_task_id_set_from_config(config) + assert result == {1, 2, 3, 4, 5} + + def test_with_empty_sampling_list(self): + config = {'sampling_config': {'sampling_list': []}} + result = get_task_id_set_from_config(config) + assert result == set() + + def test_no_sampling_config(self): + result = get_task_id_set_from_config({}) + assert result == set() + + def test_no_sampling_list_in_config(self): + config = {'sampling_config': {'other_key': 'value'}} + result = get_task_id_set_from_config(config) + assert result == set() + + def test_deduplicates(self): + config = {'sampling_config': {'sampling_list': [1, 1, 2, 2, 3]}} + result = get_task_id_set_from_config(config) + assert result == {1, 2, 3} + + +class TestSamplingListManagerInit: + """Test SamplingListManager.initialize_sampling_list.""" + + @pytest.mark.asyncio + async def test_basic_init(self): + manager = SamplingListManager() + result = await manager.initialize_sampling_list("test_env", [[0, 100]], 10) + assert len(result) == 10 + assert all(0 <= x < 100 for x in result) + assert result == sorted(result) # should be sorted + + @pytest.mark.asyncio + async def test_init_larger_than_available(self): + manager = SamplingListManager() + result = await manager.initialize_sampling_list("test_env", [[0, 5]], 100) + assert len(result) == 5 + assert sorted(result) == [0, 1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_init_empty_range(self): + manager = SamplingListManager() + result = await manager.initialize_sampling_list("test_env", [], 10) + assert len(result) == 0 + + @pytest.mark.asyncio + async def test_init_multiple_ranges(self): + manager = SamplingListManager() + result = await manager.initialize_sampling_list("test_env", [[0, 5], [100, 105]], 8) + assert len(result) == 8 + for x in result: + assert (0 <= x < 5) or (100 <= x < 105) + + +class TestSamplingListManagerRotate: + """Test SamplingListManager.rotate_sampling_list.""" + + @pytest.mark.asyncio + async def test_basic_rotation(self): + manager = SamplingListManager() + current = list(range(10)) + dataset_range = [[0, 100]] + new_list, removed, added = await manager.rotate_sampling_list( + "test_env", current, dataset_range, + sampling_count=10, rotation_count=3 + ) + assert len(removed) == 3 + assert len(added) == 3 + assert len(new_list) == 10 + # Removed from front + assert removed == [0, 1, 2] + + @pytest.mark.asyncio + async def test_rotation_zero(self): + """rotation_count=0 means no rotation, only size adjustment.""" + manager = SamplingListManager() + current = list(range(10)) + new_list, removed, added = await manager.rotate_sampling_list( + "test_env", current, [[0, 100]], + sampling_count=10, rotation_count=0 + ) + assert removed == [] + assert added == [] + assert new_list == current + + @pytest.mark.asyncio + async def test_negative_rotation_count(self): + manager = SamplingListManager() + current = list(range(10)) + new_list, removed, added = await manager.rotate_sampling_list( + "test_env", current, [[0, 100]], + sampling_count=10, rotation_count=-1 + ) + assert new_list == current + assert removed == [] + assert added == [] + + @pytest.mark.asyncio + async def test_fill_mode(self): + """When current < target, should only add.""" + manager = SamplingListManager() + current = [0, 1, 2] + new_list, removed, added = await manager.rotate_sampling_list( + "test_env", current, [[0, 100]], + sampling_count=10, rotation_count=3 + ) + assert removed == [] + assert len(added) == 7 # fill to 10 + assert len(new_list) == 10 + + @pytest.mark.asyncio + async def test_shrink_mode(self): + """When current > target, should remove surplus + rotation_count.""" + manager = SamplingListManager() + current = list(range(15)) + new_list, removed, added = await manager.rotate_sampling_list( + "test_env", current, [[0, 100]], + sampling_count=10, rotation_count=2 + ) + # surplus=5, remove 5+2=7, add 2 + assert len(removed) == 7 + assert len(added) == 2 + assert len(new_list) == 10 + + @pytest.mark.asyncio + async def test_safety_check_large_rotation(self): + """Should skip if sampling_count + rotation_count > 80% of dataset.""" + manager = SamplingListManager() + current = list(range(10)) + new_list, removed, added = await manager.rotate_sampling_list( + "test_env", current, [[0, 12]], # dataset size 12 + sampling_count=10, rotation_count=3 # 10+3=13 > 12*0.8=9.6 + ) + assert new_list == current + assert removed == [] + assert added == [] + + @pytest.mark.asyncio + async def test_insufficient_available_ids(self): + """Should skip if not enough available IDs for addition.""" + manager = SamplingListManager() + current = list(range(95)) + new_list, removed, added = await manager.rotate_sampling_list( + "test_env", current, [[0, 200]], + sampling_count=95, rotation_count=110 # need 110, only 105 available + ) + assert new_list == current + + @pytest.mark.asyncio + async def test_added_ids_not_in_remaining(self): + """Added IDs should not duplicate remaining IDs.""" + manager = SamplingListManager() + current = list(range(50)) + new_list, removed, added = await manager.rotate_sampling_list( + "test_env", current, [[0, 200]], + sampling_count=50, rotation_count=10 + ) + remaining = current[10:] # after removing 10 from front + remaining_set = set(remaining) + for a in added: + assert a not in remaining_set + + @pytest.mark.asyncio + async def test_prioritize_new(self): + """With prioritize_new=True, additions should come from later segments.""" + manager = SamplingListManager() + current = list(range(5)) + new_list, removed, added = await manager.rotate_sampling_list( + "test_env", current, [[0, 10], [1000, 1010]], + sampling_count=5, rotation_count=3, prioritize_new=True + ) + # Added should come from [1000, 1010] segment first + assert len(added) == 3 + assert all(1000 <= a < 1010 for a in added)