Skip to content

Commit 6abaca0

Browse files
committed
[ENH] For chroma cloud efs, extract api key from header if available to authenticate
1 parent cbe2c4f commit 6abaca0

File tree

4 files changed

+235
-4
lines changed

4 files changed

+235
-4
lines changed

chromadb/api/shared_system_client.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import ClassVar, Dict
1+
from typing import ClassVar, Dict, Optional
22
import uuid
33

44
from chromadb.api import ServerAPI
@@ -94,3 +94,50 @@ def _system(self) -> System:
9494
def _submit_client_start_event(self) -> None:
    """Capture a ``ClientStartEvent`` via the system's ``ProductTelemetryClient``."""
    self._system.instance(ProductTelemetryClient).capture(ClientStartEvent())
97+
98+
@staticmethod
def get_chroma_cloud_api_key_from_clients() -> Optional[str]:
    """
    Try to extract an API key from existing client instances by checking
    their session headers, if available.

    Requirements to pull an API key from a registered system:
    - its ServerAPI instance must expose ``_session.headers`` (per the
      original author's note, RustBindingsAPI and SegmentAPI have no
      ``_session`` attribute and are therefore ignored)
    - its ``_api_url`` must be a string containing "api.trychroma.com"
      (local/self-hosted instances are ignored)
    - the headers must contain "X-Chroma-Token" or "x-chroma-token"

    Returns:
        The first API key found, as ``str``, or None if no client instance
        has an API key set. A ``bytes`` header value is decoded as UTF-8
        rather than stringified, so the result never looks like "b'...'".
    """
    # Session headers are checked because (per the original author's note)
    # this is where both the CloudClient and HttpClient paths converge.
    for system in SharedSystemClient._identifier_to_system.values():
        try:
            server_api = system.instance(ServerAPI)

            # Skip implementations that carry no HTTP session / headers.
            session = getattr(server_api, "_session", None)
            headers = getattr(session, "headers", None)
            if headers is None:
                continue

            # Only pull the API key from clients pointed at Chroma Cloud.
            api_url = getattr(server_api, "_api_url", None)
            if not isinstance(api_url, str) or "api.trychroma.com" not in api_url:
                continue

            # Both spellings are tried because a plain dict of headers is
            # case-sensitive.
            api_key = headers.get("X-Chroma-Token") or headers.get(
                "x-chroma-token"
            )
            if not api_key:
                continue

            # str(b"key") would yield the repr "b'key'", so bytes values
            # must be decoded instead of passed through str().
            if isinstance(api_key, (bytes, bytearray)):
                return bytes(api_key).decode("utf-8")
            return str(api_key)
        except Exception:
            # Deliberate best-effort: if this system cannot provide a usable
            # ServerAPI/session (or the key cannot be decoded), move on to
            # the next registered system instead of failing the caller.
            continue

    return None
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import pytest
2+
from unittest.mock import MagicMock
3+
from chromadb.api.shared_system_client import SharedSystemClient
4+
from chromadb.config import System
5+
from chromadb.api import ServerAPI
6+
from typing import Optional, Dict, Generator
7+
8+
9+
@pytest.fixture(autouse=True)
def clear_cache() -> Generator[None, None, None]:
    """Reset the shared system cache around every test.

    The cache is cleared once before the test body runs and once more after
    the fixture resumes.
    """
    reset_cache = SharedSystemClient.clear_system_cache
    reset_cache()
    yield
    reset_cache()
15+
16+
17+
def create_mock_server_api(
    api_url: Optional[str] = None,
    headers: Optional[Dict[str, str]] = None,
    has_session: bool = True,
    has_headers_attr: bool = True,
) -> MagicMock:
    """Build a ``MagicMock`` standing in for a ``ServerAPI`` instance.

    Args:
        api_url: Assigned to ``_api_url`` when truthy; otherwise the
            attribute is not set.
        headers: Used as ``_session.headers``; an empty dict is used when
            this is falsy.
        has_session: When False, no ``_session`` is attached (and any
            existing one is deleted).
        has_headers_attr: When False, the session mock has its ``headers``
            attribute deleted.

    Returns:
        The configured mock.
    """
    server_api = MagicMock(spec=ServerAPI)

    if api_url:
        server_api._api_url = api_url

    # Guard clause: the session-less variant needs no further setup.
    if not has_session:
        if hasattr(server_api, "_session"):
            del server_api._session
        return server_api

    session = MagicMock()
    if has_headers_attr:
        session.headers = headers if headers else {}
    else:
        # Deleting the attribute leaves the mock without a ``headers`` attribute.
        del session.headers
    server_api._session = session

    return server_api
42+
43+
44+
def register_mock_system(system_id: str, mock_server_api: MagicMock) -> MagicMock:
    """Store a mock ``System`` under ``system_id`` in the shared registry.

    The mock's ``instance(...)`` call returns ``mock_server_api``. The mock
    system itself is returned to the caller.
    """
    fake_system = MagicMock(spec=System)
    fake_system.instance.return_value = mock_server_api

    registry = SharedSystemClient._identifier_to_system
    registry[system_id] = fake_system
    return fake_system
50+
51+
52+
def test_extracts_api_key_from_chroma_cloud_client() -> None:
    """A cloud client carrying ``X-Chroma-Token`` yields that token."""
    expected_key = "test-api-key-123"
    register_mock_system(
        "test-id",
        create_mock_server_api(
            api_url="https://api.trychroma.com/api/v2",
            headers={"X-Chroma-Token": expected_key},
        ),
    )

    found = SharedSystemClient.get_chroma_cloud_api_key_from_clients()

    assert found == expected_key
62+
63+
64+
def test_extracts_api_key_with_lowercase_header() -> None:
    """The all-lowercase ``x-chroma-token`` header name is honoured too."""
    expected_key = "test-api-key-456"
    register_mock_system(
        "test-id",
        create_mock_server_api(
            api_url="https://api.trychroma.com/api/v2",
            headers={"x-chroma-token": expected_key},
        ),
    )

    found = SharedSystemClient.get_chroma_cloud_api_key_from_clients()

    assert found == expected_key
74+
75+
76+
def test_skips_non_chroma_cloud_clients() -> None:
    """A token on a client whose URL is not Chroma Cloud is ignored."""
    local_api = create_mock_server_api(
        api_url="https://localhost:8000/api/v2",
        headers={"X-Chroma-Token": "local-api-key"},
    )
    register_mock_system("test-id", local_api)

    found = SharedSystemClient.get_chroma_cloud_api_key_from_clients()

    assert found is None
86+
87+
88+
def test_skips_clients_without_session() -> None:
    """A cloud client that has no ``_session`` produces no key."""
    sessionless_api = create_mock_server_api(
        api_url="https://api.trychroma.com/api/v2",
        has_session=False,
    )
    register_mock_system("test-id", sessionless_api)

    found = SharedSystemClient.get_chroma_cloud_api_key_from_clients()

    assert found is None
98+
99+
100+
def test_skips_clients_without_api_url() -> None:
    """A client with a token but no ``_api_url`` produces no key."""
    urlless_api = create_mock_server_api(
        api_url=None,
        headers={"X-Chroma-Token": "test-api-key"},
    )
    register_mock_system("test-id", urlless_api)

    found = SharedSystemClient.get_chroma_cloud_api_key_from_clients()

    assert found is None
110+
111+
112+
def test_returns_none_when_no_api_key_in_headers() -> None:
    """A cloud client with empty headers produces no key."""
    tokenless_api = create_mock_server_api(
        api_url="https://api.trychroma.com/api/v2",
        headers={},
    )
    register_mock_system("test-id", tokenless_api)

    found = SharedSystemClient.get_chroma_cloud_api_key_from_clients()

    assert found is None
122+
123+
124+
def test_returns_first_api_key_found_from_multiple_clients() -> None:
    """With two cloud clients registered, the first-registered key is returned."""
    cloud_url = "https://api.trychroma.com/api/v2"
    first_api = create_mock_server_api(
        api_url=cloud_url, headers={"X-Chroma-Token": "first-key"}
    )
    second_api = create_mock_server_api(
        api_url=cloud_url, headers={"X-Chroma-Token": "second-key"}
    )

    # Registration order matters: the assertion expects the first one to win.
    for system_id, api in (("test-id-1", first_api), ("test-id-2", second_api)):
        register_mock_system(system_id, api)

    found = SharedSystemClient.get_chroma_cloud_api_key_from_clients()

    assert found == "first-key"
139+
140+
141+
def test_handles_exception_gracefully() -> None:
    """A system whose ``instance`` call raises is skipped and yields None."""
    failing_system = MagicMock(spec=System)
    failing_system.instance.side_effect = Exception("Test exception")
    registry = SharedSystemClient._identifier_to_system
    registry["test-id"] = failing_system

    found = SharedSystemClient.get_chroma_cloud_api_key_from_clients()

    assert found is None
149+
150+
151+
def test_returns_none_when_no_clients_exist() -> None:
    """With nothing registered by the test, the lookup returns None."""
    found = SharedSystemClient.get_chroma_cloud_api_key_from_clients()

    assert found is None
155+
156+
157+
def test_skips_chroma_cloud_client_without_headers_attribute() -> None:
    """A cloud client whose session has no ``headers`` attribute yields None."""
    headerless_api = create_mock_server_api(
        api_url="https://api.trychroma.com/api/v2",
        has_headers_attr=False,
    )
    register_mock_system("test-id", headerless_api)

    found = SharedSystemClient.get_chroma_cloud_api_key_from_clients()

    assert found is None

chromadb/utils/embedding_functions/chroma_cloud_qwen_embedding_function.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,19 @@ def __init__(
5656
)
5757

5858
self.api_key_env_var = api_key_env_var
59+
# First, try to get API key from environment variable
5960
self.api_key = os.getenv(api_key_env_var)
61+
# If not found in env var, try to get it from existing client instances
6062
if not self.api_key:
61-
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
63+
from chromadb.api.shared_system_client import SharedSystemClient
64+
65+
self.api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients()
66+
# Raise error if still no API key found
67+
if not self.api_key:
68+
raise ValueError(
69+
f"API key not found in environment variable {api_key_env_var} "
70+
f"or in any existing client instances"
71+
)
6272

6373
self.model = model
6474
self.task = task

chromadb/utils/embedding_functions/chroma_cloud_splade_embedding_function.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from enum import Enum
88
from chromadb.utils.embedding_functions.schemas import validate_config_schema
99
from chromadb.utils.sparse_embedding_utils import normalize_sparse_vector
10-
from chromadb.base_types import SparseVector
1110
import os
1211
from typing import Union
1312

@@ -36,10 +35,19 @@ def __init__(
3635
"The httpx python package is not installed. Please install it with `pip install httpx`"
3736
)
3837
self.api_key_env_var = api_key_env_var
38+
# First, try to get API key from environment variable
3939
self.api_key = os.getenv(self.api_key_env_var)
40+
# If not found in env var, try to get it from existing client instances
41+
if not self.api_key:
42+
# Import here to avoid circular import
43+
from chromadb.api.shared_system_client import SharedSystemClient
44+
45+
self.api_key = SharedSystemClient.get_chroma_cloud_api_key_from_clients()
46+
# Raise error if still no API key found
4047
if not self.api_key:
4148
raise ValueError(
42-
f"API key not found in environment variable {self.api_key_env_var}"
49+
f"API key not found in environment variable {self.api_key_env_var} "
50+
f"or in any existing client instances"
4351
)
4452
self.model = model
4553
self._api_url = "https://embed.trychroma.com/embed_sparse"

0 commit comments

Comments
 (0)