Skip to content

Commit

Permalink
Merge pull request #1527 from expectedparrot/custodial_keys
Browse files Browse the repository at this point in the history
Updates to KeyLookup
  • Loading branch information
rbyh authored Jan 25, 2025
2 parents 7e64fa3 + fbf0adb commit 25937d8
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 26 deletions.
67 changes: 47 additions & 20 deletions edsl/language_models/key_management/KeyLookupBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,14 @@ class KeyLookupBuilder:
DEFAULT_RPM = int(CONFIG.get("EDSL_SERVICE_RPM_BASELINE"))
DEFAULT_TPM = int(CONFIG.get("EDSL_SERVICE_TPM_BASELINE"))

def __init__(self, fetch_order: Optional[tuple[str]] = None):
def __init__(
self,
fetch_order: Optional[tuple[str]] = None,
coop: Optional["Coop"] = None,
):
from edsl.coop import Coop

# Fetch order goes from lowest priority to highest priority
if fetch_order is None:
self.fetch_order = ("config", "env")
else:
Expand All @@ -70,6 +77,11 @@ def __init__(self, fetch_order: Optional[tuple[str]] = None):
if not isinstance(self.fetch_order, tuple):
raise ValueError("fetch_order must be a tuple")

if coop is None:
self.coop = Coop()
else:
self.coop = coop

self.limit_data = {}
self.key_data = {}
self.id_data = {}
Expand Down Expand Up @@ -131,7 +143,8 @@ def get_language_model_input(self, service: str) -> LanguageModelInput:
service=service,
rpm=self.DEFAULT_RPM,
tpm=self.DEFAULT_TPM,
source="default",
rpm_source="default",
tpm_source="default",
)

if limit_entry.rpm is None:
Expand All @@ -145,7 +158,8 @@ def get_language_model_input(self, service: str) -> LanguageModelInput:
tpm=int(limit_entry.tpm),
api_id=api_id,
token_source=api_key_entry.source,
limit_source=limit_entry.source,
rpm_source=limit_entry.rpm_source,
tpm_source=limit_entry.tpm_source,
id_source=id_source,
)

Expand All @@ -156,10 +170,7 @@ def _os_env_key_value_pairs(self):
return dict(list(os.environ.items()))

def _coop_key_value_pairs(self):
    """Return the rate-limit config variables reported by the Coop client.

    Uses the ``Coop`` instance stored on ``self.coop`` rather than building a
    new client on every call, so an injected client is honoured.

    Returns:
        A new ``dict`` built from the items of
        ``self.coop.fetch_rate_limit_config_vars()``.
    """
    remote_vars = self.coop.fetch_rate_limit_config_vars()
    # Copy into a fresh dict so callers never mutate the client's own mapping.
    return dict(remote_vars.items())

def _config_key_value_pairs(self):
from edsl.config import CONFIG
Expand All @@ -169,7 +180,7 @@ def _config_key_value_pairs(self):
@staticmethod
def extract_service(key: str) -> tuple[str, str]:
    """Extract the service and limit type from a rate-limit key.

    The ``EDSL_SERVICE_`` marker is removed and the remainder is split on the
    first underscore only, so service names that themselves contain
    underscores are kept intact.

    >>> KeyLookupBuilder.extract_service("EDSL_SERVICE_RPM_OPENAI")
    ('openai', 'rpm')
    >>> KeyLookupBuilder.extract_service("EDSL_SERVICE_TPM_AZURE_OPENAI")
    ('azure_openai', 'tpm')

    Args:
        key: A key of the form ``EDSL_SERVICE_<LIMIT>_<SERVICE>``.

    Returns:
        ``(service, limit_type)``, both lower-cased.

    Raises:
        ValueError: If nothing separates the limit type from the service.
    """
    remainder = key.replace("EDSL_SERVICE_", "")
    pieces = remainder.split("_", 1)
    if len(pieces) != 2:
        raise ValueError(
            f"Cannot extract limit type and service from key {key!r}; "
            "expected the form EDSL_SERVICE_<LIMIT>_<SERVICE>"
        )
    limit_type, service_raw = pieces
    return service_raw.lower(), limit_type.lower()

def get_key_value_pairs(self) -> dict:
Expand All @@ -187,17 +198,17 @@ def get_key_value_pairs(self) -> dict:
d[k] = (v, source)
return d

def _entry_type(self, key, value) -> str:
def _entry_type(self, key: str) -> str:
"""Determine the type of entry from a key.
>>> builder = KeyLookupBuilder()
>>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI", "60")
>>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI")
'limit'
>>> builder._entry_type("OPENAI_API_KEY", "sk-1234")
>>> builder._entry_type("OPENAI_API_KEY")
'api_key'
>>> builder._entry_type("AWS_ACCESS_KEY_ID", "AKIA1234")
>>> builder._entry_type("AWS_ACCESS_KEY_ID")
'api_id'
>>> builder._entry_type("UNKNOWN_KEY", "value")
>>> builder._entry_type("UNKNOWN_KEY")
'unknown'
"""
if key.startswith("EDSL_SERVICE_"):
Expand Down Expand Up @@ -243,11 +254,13 @@ def _add_limit(self, key: str, value: str, source: str) -> None:
service, limit_type = self.extract_service(key)
if service in self.limit_data:
setattr(self.limit_data[service], limit_type.lower(), value)
setattr(self.limit_data[service], f"{limit_type}_source", source)
else:
new_limit_entry = LimitEntry(
service=service, rpm=None, tpm=None, source=source
service=service, rpm=None, tpm=None, rpm_source=None, tpm_source=None
)
setattr(new_limit_entry, limit_type.lower(), value)
setattr(new_limit_entry, f"{limit_type}_source", source)
self.limit_data[service] = new_limit_entry

def _add_api_key(self, key: str, value: str, source: str) -> None:
Expand All @@ -265,13 +278,27 @@ def _add_api_key(self, key: str, value: str, source: str) -> None:
else:
self.key_data[service].append(new_entry)

def process_key_value_pairs(self) -> None:
"""Process all key-value pairs from the configured sources."""
for key, value_pair in self.get_key_value_pairs().items():
def update_from_dict(self, d: dict) -> None:
    """
    Update data from a dictionary of key-value pairs.
    Each key is a key name, and each value is a tuple of (value, source).

    Every key is classified once with ``_entry_type`` and then handed to
    ``_add_limit``, ``_add_api_key`` or ``_add_id``. Keys of any other type
    are ignored.

    >>> builder = KeyLookupBuilder()
    >>> builder.update_from_dict({"OPENAI_API_KEY": ("sk-1234", "custodial_keys")})
    >>> 'sk-1234' == builder.key_data["openai"][-1].value
    True
    >>> 'custodial_keys' == builder.key_data["openai"][-1].source
    True
    """
    for key, value_pair in d.items():
        value, source = value_pair
        # Classify once per key instead of re-running the lookup in every branch.
        entry_type = self._entry_type(key)
        if entry_type == "limit":
            self._add_limit(key, value, source)
        elif entry_type == "api_key":
            self._add_api_key(key, value, source)
        elif entry_type == "api_id":
            self._add_id(key, value, source)

def process_key_value_pairs(self) -> None:
    """Gather the key-value pairs from the configured sources and load them.

    The mapping returned by ``get_key_value_pairs`` is passed unchanged to
    ``update_from_dict``.
    """
    collected = self.get_key_value_pairs()
    self.update_from_dict(collected)
14 changes: 10 additions & 4 deletions edsl/language_models/key_management/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,23 @@ class LimitEntry:
60
>>> limit.tpm
100000
>>> limit.source
>>> limit.rpm_source
'config'
>>> limit.tpm_source
'env'
"""

service: str
rpm: int
tpm: int
source: Optional[str] = None
rpm_source: Optional[str] = None
tpm_source: Optional[str] = None

@classmethod
def example(cls):
    """Return a sample ``openai`` entry with separately sourced limits.

    The RPM limit is attributed to ``"config"`` and the TPM limit to
    ``"env"``, so the two source fields can be told apart.
    """
    # Build through ``cls`` so a subclass gets an instance of its own type.
    return cls(
        service="openai",
        rpm=60,
        tpm=100000,
        rpm_source="config",
        tpm_source="env",
    )


@dataclass
Expand Down Expand Up @@ -108,7 +113,8 @@ class LanguageModelInput:
tpm: int
api_id: Optional[str] = None
token_source: Optional[str] = None
limit_source: Optional[str] = None
rpm_source: Optional[str] = None
tpm_source: Optional[str] = None
id_source: Optional[str] = None

def to_dict(self):
Expand Down
31 changes: 29 additions & 2 deletions tests/language_models/test_KeyLookupBuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_invalid_fetch_order():
)
def test_entry_type_detection(builder, key, expected_type):
"""Test correct detection of entry types"""
entry_type = builder._entry_type(key, "dummy-value")
entry_type = builder._entry_type(key)
assert entry_type == expected_type


Expand Down Expand Up @@ -100,7 +100,9 @@ def test_get_language_model_input(builder):
]
}
builder.limit_data = {
"test": LimitEntry(service="test", rpm=10, tpm=2000000, source="env")
"test": LimitEntry(
service="test", rpm=10, tpm=2000000, rpm_source="env", tpm_source="env"
)
}
builder.id_data = {
"test": APIIDEntry(
Expand Down Expand Up @@ -160,6 +162,31 @@ def test_build_method():
assert "test" in result # Default test service should always be present


def test_update_from_dict(mock_env_vars):
    """Test that update_from_dict overrides entries first loaded from the environment."""
    with patch.dict("os.environ", mock_env_vars, clear=True):
        builder = KeyLookupBuilder(fetch_order=("env",))

        # Baseline state after construction; the expected values presumably
        # mirror the mock_env_vars fixture (defined outside this view).
        initial_key = builder.key_data["openai"][-1]
        assert initial_key.value == "test-openai-key"
        assert initial_key.source == "env"

        initial_limit = builder.limit_data["openai"]
        assert initial_limit.rpm == "20"
        assert initial_limit.rpm_source == "env"

        overrides = {
            "OPENAI_API_KEY": ("sk-1234", "custodial_keys"),
            "EDSL_SERVICE_RPM_OPENAI": ("40", "custodial_keys"),
        }
        builder.update_from_dict(overrides)

        # The most recent key entry and the RPM limit now carry the new source.
        latest_key = builder.key_data["openai"][-1]
        assert latest_key.value == "sk-1234"
        assert latest_key.source == "custodial_keys"

        updated_limit = builder.limit_data["openai"]
        assert updated_limit.rpm == "40"
        assert updated_limit.rpm_source == "custodial_keys"


def test_duplicate_id_handling():
"""Test handling of duplicate API IDs"""
builder = KeyLookupBuilder()
Expand Down

0 comments on commit 25937d8

Please sign in to comment.