Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to KeyLookup #1527

Merged
merged 7 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 47 additions & 20 deletions edsl/language_models/key_management/KeyLookupBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,14 @@ class KeyLookupBuilder:
DEFAULT_RPM = int(CONFIG.get("EDSL_SERVICE_RPM_BASELINE"))
DEFAULT_TPM = int(CONFIG.get("EDSL_SERVICE_TPM_BASELINE"))

def __init__(self, fetch_order: Optional[tuple[str]] = None):
def __init__(
self,
fetch_order: Optional[tuple[str]] = None,
coop: Optional["Coop"] = None,
):
from edsl.coop import Coop

# Fetch order goes from lowest priority to highest priority
if fetch_order is None:
self.fetch_order = ("config", "env")
else:
Expand All @@ -70,6 +77,11 @@ def __init__(self, fetch_order: Optional[tuple[str]] = None):
if not isinstance(self.fetch_order, tuple):
raise ValueError("fetch_order must be a tuple")

if coop is None:
self.coop = Coop()
else:
self.coop = coop

self.limit_data = {}
self.key_data = {}
self.id_data = {}
Expand Down Expand Up @@ -131,7 +143,8 @@ def get_language_model_input(self, service: str) -> LanguageModelInput:
service=service,
rpm=self.DEFAULT_RPM,
tpm=self.DEFAULT_TPM,
source="default",
rpm_source="default",
tpm_source="default",
)

if limit_entry.rpm is None:
Expand All @@ -145,7 +158,8 @@ def get_language_model_input(self, service: str) -> LanguageModelInput:
tpm=int(limit_entry.tpm),
api_id=api_id,
token_source=api_key_entry.source,
limit_source=limit_entry.source,
rpm_source=limit_entry.rpm_source,
tpm_source=limit_entry.tpm_source,
id_source=id_source,
)

Expand All @@ -156,10 +170,7 @@ def _os_env_key_value_pairs(self):
return dict(list(os.environ.items()))

def _coop_key_value_pairs(self):
from edsl.coop import Coop

c = Coop()
return dict(list(c.fetch_rate_limit_config_vars().items()))
return dict(list(self.coop.fetch_rate_limit_config_vars().items()))

def _config_key_value_pairs(self):
from edsl.config import CONFIG
Expand All @@ -169,7 +180,7 @@ def _config_key_value_pairs(self):
@staticmethod
def extract_service(key: str) -> str:
"""Extract the service and limit type from the key"""
limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_")
limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_", 1)
return service_raw.lower(), limit_type.lower()

def get_key_value_pairs(self) -> dict:
Expand All @@ -187,17 +198,17 @@ def get_key_value_pairs(self) -> dict:
d[k] = (v, source)
return d

def _entry_type(self, key, value) -> str:
def _entry_type(self, key: str) -> str:
"""Determine the type of entry from a key.

>>> builder = KeyLookupBuilder()
>>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI", "60")
>>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI")
'limit'
>>> builder._entry_type("OPENAI_API_KEY", "sk-1234")
>>> builder._entry_type("OPENAI_API_KEY")
'api_key'
>>> builder._entry_type("AWS_ACCESS_KEY_ID", "AKIA1234")
>>> builder._entry_type("AWS_ACCESS_KEY_ID")
'api_id'
>>> builder._entry_type("UNKNOWN_KEY", "value")
>>> builder._entry_type("UNKNOWN_KEY")
'unknown'
"""
if key.startswith("EDSL_SERVICE_"):
Expand Down Expand Up @@ -243,11 +254,13 @@ def _add_limit(self, key: str, value: str, source: str) -> None:
service, limit_type = self.extract_service(key)
if service in self.limit_data:
setattr(self.limit_data[service], limit_type.lower(), value)
setattr(self.limit_data[service], f"{limit_type}_source", source)
else:
new_limit_entry = LimitEntry(
service=service, rpm=None, tpm=None, source=source
service=service, rpm=None, tpm=None, rpm_source=None, tpm_source=None
)
setattr(new_limit_entry, limit_type.lower(), value)
setattr(new_limit_entry, f"{limit_type}_source", source)
self.limit_data[service] = new_limit_entry

def _add_api_key(self, key: str, value: str, source: str) -> None:
Expand All @@ -265,13 +278,27 @@ def _add_api_key(self, key: str, value: str, source: str) -> None:
else:
self.key_data[service].append(new_entry)

def process_key_value_pairs(self) -> None:
"""Process all key-value pairs from the configured sources."""
for key, value_pair in self.get_key_value_pairs().items():
def update_from_dict(self, d: dict) -> None:
"""
Update data from a dictionary of key-value pairs.
Each key is a key name, and each value is a tuple of (value, source).

>>> builder = KeyLookupBuilder()
>>> builder.update_from_dict({"OPENAI_API_KEY": ("sk-1234", "custodial_keys")})
>>> 'sk-1234' == builder.key_data["openai"][-1].value
True
>>> 'custodial_keys' == builder.key_data["openai"][-1].source
True
"""
for key, value_pair in d.items():
value, source = value_pair
if (entry_type := self._entry_type(key, value)) == "limit":
if self._entry_type(key) == "limit":
self._add_limit(key, value, source)
elif entry_type == "api_key":
elif self._entry_type(key) == "api_key":
self._add_api_key(key, value, source)
elif entry_type == "api_id":
elif self._entry_type(key) == "api_id":
self._add_id(key, value, source)

def process_key_value_pairs(self) -> None:
"""Process all key-value pairs from the configured sources."""
self.update_from_dict(self.get_key_value_pairs())
14 changes: 10 additions & 4 deletions edsl/language_models/key_management/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,23 @@ class LimitEntry:
60
>>> limit.tpm
100000
>>> limit.source
>>> limit.rpm_source
'config'
>>> limit.tpm_source
'env'
"""

service: str
rpm: int
tpm: int
source: Optional[str] = None
rpm_source: Optional[str] = None
tpm_source: Optional[str] = None

@classmethod
def example(cls):
return LimitEntry(service="openai", rpm=60, tpm=100000, source="config")
return LimitEntry(
service="openai", rpm=60, tpm=100000, rpm_source="config", tpm_source="env"
)


@dataclass
Expand Down Expand Up @@ -108,7 +113,8 @@ class LanguageModelInput:
tpm: int
api_id: Optional[str] = None
token_source: Optional[str] = None
limit_source: Optional[str] = None
rpm_source: Optional[str] = None
tpm_source: Optional[str] = None
id_source: Optional[str] = None

def to_dict(self):
Expand Down
31 changes: 29 additions & 2 deletions tests/language_models/test_KeyLookupBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_invalid_fetch_order():
)
def test_entry_type_detection(builder, key, expected_type):
"""Test correct detection of entry types"""
entry_type = builder._entry_type(key, "dummy-value")
entry_type = builder._entry_type(key)
assert entry_type == expected_type


Expand Down Expand Up @@ -100,7 +100,9 @@ def test_get_language_model_input(builder):
]
}
builder.limit_data = {
"test": LimitEntry(service="test", rpm=10, tpm=2000000, source="env")
"test": LimitEntry(
service="test", rpm=10, tpm=2000000, rpm_source="env", tpm_source="env"
)
}
builder.id_data = {
"test": APIIDEntry(
Expand Down Expand Up @@ -160,6 +162,31 @@ def test_build_method():
assert "test" in result # Default test service should always be present


def test_update_from_dict(mock_env_vars):
"""Test fetching key-value pairs from environment"""
with patch.dict("os.environ", mock_env_vars, clear=True):
builder = KeyLookupBuilder(fetch_order=("env",))

assert builder.key_data["openai"][-1].value == "test-openai-key"
assert builder.key_data["openai"][-1].source == "env"

assert builder.limit_data["openai"].rpm == "20"
assert builder.limit_data["openai"].rpm_source == "env"

builder.update_from_dict(
{
"OPENAI_API_KEY": ("sk-1234", "custodial_keys"),
"EDSL_SERVICE_RPM_OPENAI": ("40", "custodial_keys"),
}
)

assert builder.key_data["openai"][-1].value == "sk-1234"
assert builder.key_data["openai"][-1].source == "custodial_keys"

assert builder.limit_data["openai"].rpm == "40"
assert builder.limit_data["openai"].rpm_source == "custodial_keys"


def test_duplicate_id_handling():
"""Test handling of duplicate API IDs"""
builder = KeyLookupBuilder()
Expand Down
Loading