Skip to content

Commit 4110107

Browse files
authored
Merge pull request #5 from FullFact/dont-error-on-bad-labels
fix: drop bad billing labels rather than error
2 parents 2c21661 + 9aa486b commit 4110107

File tree

2 files changed

+63
-52
lines changed

2 files changed

+63
-52
lines changed

src/genai_utils/gemini.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -213,50 +213,57 @@ def add_citations(response: types.GenerateContentResponse) -> str:
213213
return text
214214

215215

216-
def validate_labels(labels: dict[str, str]) -> None:
216+
def validate_labels(labels: dict[str, str]) -> dict[str, str]:
217217
"""
218-
Validates labels for GCP requirements.
218+
Validates labels for GCP requirements, removing any labels that would cause GCP to
219+
return an error.
219220
220221
GCP label requirements:
221222
- Keys must start with a lowercase letter
222223
- Keys and values can only contain lowercase letters, numbers, hyphens, and underscores
223224
- Keys and values must be max 63 characters
224225
- Keys cannot be empty
225-
226-
Raises:
227-
GeminiError: If labels don't meet GCP requirements
228226
"""
229227
label_pattern = re.compile(r"^[a-z0-9_-]{1,63}$")
230228
key_start_pattern = re.compile(r"^[a-z]")
231229

230+
valid_labels: dict[str, str] = {}
232231
for key, value in labels.items():
233232
if not key:
234-
raise GeminiError("Label keys cannot be empty")
233+
_logger.warning("Label keys cannot be empty")
234+
continue
235235

236236
if len(key) > 63:
237-
raise GeminiError(
237+
_logger.warning(
238238
f"Label key '{key}' exceeds 63 characters (length: {len(key)})"
239239
)
240+
continue
240241

241242
if len(value) > 63:
242-
raise GeminiError(
243+
_logger.warning(
243244
f"Label value for key '{key}' exceeds 63 characters (length: {len(value)})"
244245
)
246+
continue
245247

246248
if not key_start_pattern.match(key):
247-
raise GeminiError(f"Label key '{key}' must start with a lowercase letter")
249+
_logger.warning(f"Label key '{key}' must start with a lowercase letter")
250+
continue
248251

249252
if not label_pattern.match(key):
250-
raise GeminiError(
253+
_logger.warning(
251254
f"Label key '{key}' contains invalid characters. "
252255
"Only lowercase letters, numbers, hyphens, and underscores are allowed"
253256
)
257+
continue
254258

255259
if not label_pattern.match(value):
256-
raise GeminiError(
260+
_logger.warning(
257261
f"Label value '{value}' for key '{key}' contains invalid characters. "
258262
"Only lowercase letters, numbers, hyphens, and underscores are allowed"
259263
)
264+
continue
265+
valid_labels[key] = value
266+
return valid_labels
260267

261268

262269
def check_grounding_ran(response: types.GenerateContentResponse) -> bool:
@@ -543,8 +550,7 @@ class Movie(BaseModel):
543550

544551
if inline_citations and not use_grounding:
545552
raise GeminiError("Inline citations only work if `use_grounding = True`")
546-
merged_labels = DEFAULT_LABELS | labels
547-
validate_labels(merged_labels)
553+
merged_labels = validate_labels(DEFAULT_LABELS | labels)
548554

549555
response = await client.aio.models.generate_content(
550556
model=model_config.model_name,

tests/genai_utils/test_labels.py

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from google.genai.models import Models
88

99
from genai_utils.gemini import (
10-
GeminiError,
1110
ModelConfig,
1211
run_prompt_async,
1312
validate_labels,
@@ -24,74 +23,74 @@ async def get_dummy():
2423

2524

2625
def test_validate_labels_valid():
27-
"""Test that valid labels pass validation"""
26+
"""Test that valid labels pass validation and are returned"""
2827
valid_labels = {
2928
"team": "ai",
3029
"project": "genai-utils",
3130
"environment": "production",
3231
"version": "1-2-3",
3332
"my_label": "my_value",
3433
}
35-
# Should not raise any exception
36-
validate_labels(valid_labels)
34+
result = validate_labels(valid_labels)
35+
assert result == valid_labels
3736

3837

3938
def test_validate_labels_empty_key():
40-
"""Test that empty keys are rejected"""
41-
with pytest.raises(GeminiError, match="cannot be empty"):
42-
validate_labels({"": "value"})
39+
"""Test that empty keys are filtered out"""
40+
result = validate_labels({"": "value", "valid": "label"})
41+
assert result == {"valid": "label"}
4342

4443

4544
def test_validate_labels_key_too_long():
46-
"""Test that keys exceeding 63 characters are rejected"""
45+
"""Test that keys exceeding 63 characters are filtered out"""
4746
long_key = "a" * 64
48-
with pytest.raises(GeminiError, match="exceeds 63 characters"):
49-
validate_labels({long_key: "value"})
47+
result = validate_labels({long_key: "value", "valid": "label"})
48+
assert result == {"valid": "label"}
5049

5150

5251
def test_validate_labels_value_too_long():
53-
"""Test that values exceeding 63 characters are rejected"""
52+
"""Test that values exceeding 63 characters are filtered out"""
5453
long_value = "a" * 64
55-
with pytest.raises(GeminiError, match="exceeds 63 characters"):
56-
validate_labels({"key": long_value})
54+
result = validate_labels({"key": long_value, "valid": "label"})
55+
assert result == {"valid": "label"}
5756

5857

5958
def test_validate_labels_key_starts_with_number():
60-
"""Test that keys starting with numbers are rejected"""
61-
with pytest.raises(GeminiError, match="must start with a lowercase letter"):
62-
validate_labels({"1key": "value"})
59+
"""Test that keys starting with numbers are filtered out"""
60+
result = validate_labels({"1key": "value", "valid": "label"})
61+
assert result == {"valid": "label"}
6362

6463

6564
def test_validate_labels_key_starts_with_uppercase():
66-
"""Test that keys starting with uppercase are rejected"""
67-
with pytest.raises(GeminiError, match="must start with a lowercase letter"):
68-
validate_labels({"Key": "value"})
65+
"""Test that keys starting with uppercase are filtered out"""
66+
result = validate_labels({"Key": "value", "valid": "label"})
67+
assert result == {"valid": "label"}
6968

7069

7170
@pytest.mark.parametrize(
7271
"invalid_key", ["key@value", "key.name", "key$", "key with space", "key/name"]
7372
)
7473
def test_validate_labels_key_invalid_characters(invalid_key):
75-
"""Test that keys with invalid characters are rejected"""
76-
with pytest.raises(GeminiError, match="contains invalid characters"):
77-
validate_labels({invalid_key: "value"})
74+
"""Test that keys with invalid characters are filtered out"""
75+
result = validate_labels({invalid_key: "value", "valid": "label"})
76+
assert result == {"valid": "label"}
7877

7978

8079
@pytest.mark.parametrize(
8180
"invalid_value", ["value@", "value.txt", "value$", "value with space", "value/"]
8281
)
8382
def test_validate_labels_value_invalid_characters(invalid_value):
84-
"""Test that values with invalid characters are rejected"""
85-
with pytest.raises(GeminiError, match="contains invalid characters"):
86-
validate_labels({"key": invalid_value})
83+
"""Test that values with invalid characters are filtered out"""
84+
result = validate_labels({"key": invalid_value, "valid": "label"})
85+
assert result == {"valid": "label"}
8786

8887

8988
def test_validate_labels_max_length_valid():
9089
"""Test that keys and values at exactly 63 characters are valid"""
9190
max_key = "a" * 63
9291
max_value = "b" * 63
93-
# Should not raise any exception
94-
validate_labels({max_key: max_value})
92+
result = validate_labels({max_key: max_value})
93+
assert result == {max_key: max_value}
9594

9695

9796
def test_validate_labels_valid_special_chars():
@@ -102,8 +101,8 @@ def test_validate_labels_valid_special_chars():
102101
"my-key_name": "my-value_123",
103102
"key123": "value456",
104103
}
105-
# Should not raise any exception
106-
validate_labels(valid_labels)
104+
result = validate_labels(valid_labels)
105+
assert result == valid_labels
107106

108107

109108
@patch("genai_utils.gemini.genai.Client")
@@ -137,7 +136,7 @@ async def test_run_prompt_with_valid_labels(mock_client):
137136

138137
@patch("genai_utils.gemini.genai.Client")
139138
async def test_run_prompt_with_invalid_labels(mock_client):
140-
"""Test that run_prompt rejects invalid labels"""
139+
"""Test that run_prompt filters out invalid labels"""
141140
client = Mock(Client)
142141
models = Mock(Models)
143142
async_client = Mock(AsyncClient)
@@ -147,16 +146,22 @@ async def test_run_prompt_with_invalid_labels(mock_client):
147146
async_client.models = models
148147
mock_client.return_value = client
149148

150-
invalid_labels = {"Invalid": "value"} # uppercase key
149+
invalid_labels = {"Invalid": "value", "valid": "label"} # uppercase key is invalid
151150

152-
with pytest.raises(GeminiError, match="must start with a lowercase letter"):
153-
await run_prompt_async(
154-
"test prompt",
155-
labels=invalid_labels,
156-
model_config=ModelConfig(
157-
project="project", location="location", model_name="model"
158-
),
159-
)
151+
await run_prompt_async(
152+
"test prompt",
153+
labels=invalid_labels,
154+
model_config=ModelConfig(
155+
project="project", location="location", model_name="model"
156+
),
157+
)
158+
159+
# Verify the call was made with only valid labels
160+
assert models.generate_content.called
161+
call_kwargs = models.generate_content.call_args[1]
162+
assert "config" in call_kwargs
163+
# The invalid "Invalid" key should be filtered out
164+
assert call_kwargs["config"].labels == {"valid": "label"}
160165

161166

162167
@patch("genai_utils.gemini.genai.Client")

0 commit comments

Comments
 (0)