diff --git a/src/genai_utils/gemini.py b/src/genai_utils/gemini.py index 9c17114..63d41e1 100644 --- a/src/genai_utils/gemini.py +++ b/src/genai_utils/gemini.py @@ -213,50 +213,57 @@ def add_citations(response: types.GenerateContentResponse) -> str: return text -def validate_labels(labels: dict[str, str]) -> None: +def validate_labels(labels: dict[str, str]) -> dict[str, str]: """ - Validates labels for GCP requirements. + Validates labels for GCP requirements, removing any labels that would cause GCP to + return an error. GCP label requirements: - Keys must start with a lowercase letter - Keys and values can only contain lowercase letters, numbers, hyphens, and underscores - Keys and values must be max 63 characters - Keys cannot be empty - - Raises: - GeminiError: If labels don't meet GCP requirements """ label_pattern = re.compile(r"^[a-z0-9_-]{1,63}$") key_start_pattern = re.compile(r"^[a-z]") + valid_labels: dict[str, str] = {} for key, value in labels.items(): if not key: - raise GeminiError("Label keys cannot be empty") + _logger.warning("Label keys cannot be empty") + continue if len(key) > 63: - raise GeminiError( + _logger.warning( f"Label key '{key}' exceeds 63 characters (length: {len(key)})" ) + continue if len(value) > 63: - raise GeminiError( + _logger.warning( f"Label value for key '{key}' exceeds 63 characters (length: {len(value)})" ) + continue if not key_start_pattern.match(key): - raise GeminiError(f"Label key '{key}' must start with a lowercase letter") + _logger.warning(f"Label key '{key}' must start with a lowercase letter") + continue if not label_pattern.match(key): - raise GeminiError( + _logger.warning( f"Label key '{key}' contains invalid characters. " "Only lowercase letters, numbers, hyphens, and underscores are allowed" ) + continue if not label_pattern.match(value): - raise GeminiError( + _logger.warning( f"Label value '{value}' for key '{key}' contains invalid characters. " "Only lowercase letters, numbers, hyphens, and underscores are allowed" ) + continue + valid_labels[key] = value + return valid_labels def check_grounding_ran(response: types.GenerateContentResponse) -> bool: @@ -543,8 +550,7 @@ class Movie(BaseModel): if inline_citations and not use_grounding: raise GeminiError("Inline citations only work if `use_grounding = True`") - merged_labels = DEFAULT_LABELS | labels - validate_labels(merged_labels) + merged_labels = validate_labels(DEFAULT_LABELS | labels) response = await client.aio.models.generate_content( model=model_config.model_name, diff --git a/tests/genai_utils/test_labels.py b/tests/genai_utils/test_labels.py index 5866dfc..7443551 100644 --- a/tests/genai_utils/test_labels.py +++ b/tests/genai_utils/test_labels.py @@ -7,7 +7,6 @@ from google.genai.models import Models from genai_utils.gemini import ( - GeminiError, ModelConfig, run_prompt_async, validate_labels, @@ -24,7 +23,7 @@ async def get_dummy(): def test_validate_labels_valid(): - """Test that valid labels pass validation""" + """Test that valid labels pass validation and are returned""" valid_labels = { "team": "ai", "project": "genai-utils", @@ -32,66 +31,66 @@ def test_validate_labels_valid(): "version": "1-2-3", "my_label": "my_value", } - # Should not raise any exception - validate_labels(valid_labels) + result = validate_labels(valid_labels) + assert result == valid_labels def test_validate_labels_empty_key(): - """Test that empty keys are rejected""" - with pytest.raises(GeminiError, match="cannot be empty"): - validate_labels({"": "value"}) + """Test that empty keys are filtered out""" + result = validate_labels({"": "value", "valid": "label"}) + assert result == {"valid": "label"} def test_validate_labels_key_too_long(): - """Test that keys exceeding 63 characters are rejected""" + """Test that keys exceeding 63 characters are filtered out""" long_key = "a" * 64 - with pytest.raises(GeminiError, match="exceeds 63 characters"): - validate_labels({long_key: "value"}) + result = validate_labels({long_key: "value", "valid": "label"}) + assert result == {"valid": "label"} def test_validate_labels_value_too_long(): - """Test that values exceeding 63 characters are rejected""" + """Test that values exceeding 63 characters are filtered out""" long_value = "a" * 64 - with pytest.raises(GeminiError, match="exceeds 63 characters"): - validate_labels({"key": long_value}) + result = validate_labels({"key": long_value, "valid": "label"}) + assert result == {"valid": "label"} def test_validate_labels_key_starts_with_number(): - """Test that keys starting with numbers are rejected""" - with pytest.raises(GeminiError, match="must start with a lowercase letter"): - validate_labels({"1key": "value"}) + """Test that keys starting with numbers are filtered out""" + result = validate_labels({"1key": "value", "valid": "label"}) + assert result == {"valid": "label"} def test_validate_labels_key_starts_with_uppercase(): - """Test that keys starting with uppercase are rejected""" - with pytest.raises(GeminiError, match="must start with a lowercase letter"): - validate_labels({"Key": "value"}) + """Test that keys starting with uppercase are filtered out""" + result = validate_labels({"Key": "value", "valid": "label"}) + assert result == {"valid": "label"} @pytest.mark.parametrize( "invalid_key", ["key@value", "key.name", "key$", "key with space", "key/name"] ) def test_validate_labels_key_invalid_characters(invalid_key): - """Test that keys with invalid characters are rejected""" - with pytest.raises(GeminiError, match="contains invalid characters"): - validate_labels({invalid_key: "value"}) + """Test that keys with invalid characters are filtered out""" + result = validate_labels({invalid_key: "value", "valid": "label"}) + assert result == {"valid": "label"} @pytest.mark.parametrize( "invalid_value", ["value@", "value.txt", "value$", "value with space", "value/"] ) def test_validate_labels_value_invalid_characters(invalid_value): - """Test that values with invalid characters are rejected""" - with pytest.raises(GeminiError, match="contains invalid characters"): - validate_labels({"key": invalid_value}) + """Test that values with invalid characters are filtered out""" + result = validate_labels({"key": invalid_value, "valid": "label"}) + assert result == {"valid": "label"} def test_validate_labels_max_length_valid(): """Test that keys and values at exactly 63 characters are valid""" max_key = "a" * 63 max_value = "b" * 63 - # Should not raise any exception - validate_labels({max_key: max_value}) + result = validate_labels({max_key: max_value}) + assert result == {max_key: max_value} def test_validate_labels_valid_special_chars(): @@ -102,8 +101,8 @@ def test_validate_labels_valid_special_chars(): "my-key_name": "my-value_123", "key123": "value456", } - # Should not raise any exception - validate_labels(valid_labels) + result = validate_labels(valid_labels) + assert result == valid_labels @patch("genai_utils.gemini.genai.Client") @@ -137,7 +136,7 @@ async def test_run_prompt_with_valid_labels(mock_client): @patch("genai_utils.gemini.genai.Client") async def test_run_prompt_with_invalid_labels(mock_client): - """Test that run_prompt rejects invalid labels""" + """Test that run_prompt filters out invalid labels""" client = Mock(Client) models = Mock(Models) async_client = Mock(AsyncClient) @@ -147,16 +146,22 @@ async def test_run_prompt_with_invalid_labels(mock_client): async_client.models = models mock_client.return_value = client - invalid_labels = {"Invalid": "value"} # uppercase key + invalid_labels = {"Invalid": "value", "valid": "label"} # uppercase key is invalid - with pytest.raises(GeminiError, match="must start with a lowercase letter"): - await run_prompt_async( - "test prompt", - labels=invalid_labels, - model_config=ModelConfig( - project="project", location="location", model_name="model" - ), - ) + await run_prompt_async( + "test prompt", + labels=invalid_labels, + model_config=ModelConfig( + project="project", location="location", model_name="model" + ), + ) + + # Verify the call was made with only valid labels + assert models.generate_content.called + call_kwargs = models.generate_content.call_args[1] + assert "config" in call_kwargs + # The invalid "Invalid" key should be filtered out + assert call_kwargs["config"].labels == {"valid": "label"} @patch("genai_utils.gemini.genai.Client")