diff --git a/edsl/language_models/key_management/KeyLookupBuilder.py b/edsl/language_models/key_management/KeyLookupBuilder.py index a0aa52fe..b5a61e37 100644 --- a/edsl/language_models/key_management/KeyLookupBuilder.py +++ b/edsl/language_models/key_management/KeyLookupBuilder.py @@ -61,7 +61,14 @@ class KeyLookupBuilder: DEFAULT_RPM = int(CONFIG.get("EDSL_SERVICE_RPM_BASELINE")) DEFAULT_TPM = int(CONFIG.get("EDSL_SERVICE_TPM_BASELINE")) - def __init__(self, fetch_order: Optional[tuple[str]] = None): + def __init__( + self, + fetch_order: Optional[tuple[str]] = None, + coop: Optional["Coop"] = None, + ): + from edsl.coop import Coop + + # Fetch order goes from lowest priority to highest priority if fetch_order is None: self.fetch_order = ("config", "env") else: @@ -70,6 +77,11 @@ def __init__(self, fetch_order: Optional[tuple[str]] = None): if not isinstance(self.fetch_order, tuple): raise ValueError("fetch_order must be a tuple") + if coop is None: + self.coop = Coop() + else: + self.coop = coop + self.limit_data = {} self.key_data = {} self.id_data = {} @@ -131,7 +143,8 @@ def get_language_model_input(self, service: str) -> LanguageModelInput: service=service, rpm=self.DEFAULT_RPM, tpm=self.DEFAULT_TPM, - source="default", + rpm_source="default", + tpm_source="default", ) if limit_entry.rpm is None: @@ -145,7 +158,8 @@ def get_language_model_input(self, service: str) -> LanguageModelInput: tpm=int(limit_entry.tpm), api_id=api_id, token_source=api_key_entry.source, - limit_source=limit_entry.source, + rpm_source=limit_entry.rpm_source, + tpm_source=limit_entry.tpm_source, id_source=id_source, ) @@ -156,10 +170,7 @@ def _os_env_key_value_pairs(self): return dict(list(os.environ.items())) def _coop_key_value_pairs(self): - from edsl.coop import Coop - - c = Coop() - return dict(list(c.fetch_rate_limit_config_vars().items())) + return dict(list(self.coop.fetch_rate_limit_config_vars().items())) def _config_key_value_pairs(self): from edsl.config import CONFIG @@ -169,7 +180,7 @@ def _config_key_value_pairs(self): @staticmethod def extract_service(key: str) -> str: """Extract the service and limit type from the key""" - limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_") + limit_type, service_raw = key.replace("EDSL_SERVICE_", "").split("_", 1) return service_raw.lower(), limit_type.lower() def get_key_value_pairs(self) -> dict: @@ -187,17 +198,17 @@ def get_key_value_pairs(self) -> dict: d[k] = (v, source) return d - def _entry_type(self, key, value) -> str: + def _entry_type(self, key: str) -> str: """Determine the type of entry from a key. >>> builder = KeyLookupBuilder() - >>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI", "60") + >>> builder._entry_type("EDSL_SERVICE_RPM_OPENAI") 'limit' - >>> builder._entry_type("OPENAI_API_KEY", "sk-1234") + >>> builder._entry_type("OPENAI_API_KEY") 'api_key' - >>> builder._entry_type("AWS_ACCESS_KEY_ID", "AKIA1234") + >>> builder._entry_type("AWS_ACCESS_KEY_ID") 'api_id' - >>> builder._entry_type("UNKNOWN_KEY", "value") + >>> builder._entry_type("UNKNOWN_KEY") 'unknown' """ if key.startswith("EDSL_SERVICE_"): @@ -243,11 +254,13 @@ def _add_limit(self, key: str, value: str, source: str) -> None: service, limit_type = self.extract_service(key) if service in self.limit_data: setattr(self.limit_data[service], limit_type.lower(), value) + setattr(self.limit_data[service], f"{limit_type}_source", source) else: new_limit_entry = LimitEntry( - service=service, rpm=None, tpm=None, source=source + service=service, rpm=None, tpm=None, rpm_source=None, tpm_source=None ) setattr(new_limit_entry, limit_type.lower(), value) + setattr(new_limit_entry, f"{limit_type}_source", source) self.limit_data[service] = new_limit_entry def _add_api_key(self, key: str, value: str, source: str) -> None: @@ -265,13 +278,27 @@ def _add_api_key(self, key: str, value: str, source: str) -> None: else: self.key_data[service].append(new_entry) - def process_key_value_pairs(self) -> None: - """Process all key-value pairs from the configured sources.""" - for key, value_pair in self.get_key_value_pairs().items(): + def update_from_dict(self, d: dict) -> None: + """ + Update data from a dictionary of key-value pairs. + Each key is a key name, and each value is a tuple of (value, source). + + >>> builder = KeyLookupBuilder() + >>> builder.update_from_dict({"OPENAI_API_KEY": ("sk-1234", "custodial_keys")}) + >>> 'sk-1234' == builder.key_data["openai"][-1].value + True + >>> 'custodial_keys' == builder.key_data["openai"][-1].source + True + """ + for key, value_pair in d.items(): value, source = value_pair - if (entry_type := self._entry_type(key, value)) == "limit": + if self._entry_type(key) == "limit": self._add_limit(key, value, source) - elif entry_type == "api_key": + elif self._entry_type(key) == "api_key": self._add_api_key(key, value, source) - elif entry_type == "api_id": + elif self._entry_type(key) == "api_id": self._add_id(key, value, source) + + def process_key_value_pairs(self) -> None: + """Process all key-value pairs from the configured sources.""" + self.update_from_dict(self.get_key_value_pairs()) diff --git a/edsl/language_models/key_management/models.py b/edsl/language_models/key_management/models.py index bf96abd8..933bec57 100644 --- a/edsl/language_models/key_management/models.py +++ b/edsl/language_models/key_management/models.py @@ -40,18 +40,23 @@ class LimitEntry: 60 >>> limit.tpm 100000 - >>> limit.source + >>> limit.rpm_source 'config' + >>> limit.tpm_source + 'env' """ service: str rpm: int tpm: int - source: Optional[str] = None + rpm_source: Optional[str] = None + tpm_source: Optional[str] = None @classmethod def example(cls): - return LimitEntry(service="openai", rpm=60, tpm=100000, source="config") + return LimitEntry( + service="openai", rpm=60, tpm=100000, rpm_source="config", tpm_source="env" + ) @dataclass @@ -108,7 +113,8 @@ class LanguageModelInput: tpm: int api_id: Optional[str] = None token_source: Optional[str] = None - limit_source: Optional[str] = None + rpm_source: Optional[str] = None + tpm_source: Optional[str] = None id_source: Optional[str] = None def to_dict(self): diff --git a/tests/language_models/test_KeyLookupBuilder.py b/tests/language_models/test_KeyLookupBuilder.py index 7dd67d3b..3cb6ede2 100644 --- a/tests/language_models/test_KeyLookupBuilder.py +++ b/tests/language_models/test_KeyLookupBuilder.py @@ -55,7 +55,7 @@ def test_invalid_fetch_order(): ) def test_entry_type_detection(builder, key, expected_type): """Test correct detection of entry types""" - entry_type = builder._entry_type(key, "dummy-value") + entry_type = builder._entry_type(key) assert entry_type == expected_type @@ -100,7 +100,9 @@ def test_get_language_model_input(builder): ] } builder.limit_data = { - "test": LimitEntry(service="test", rpm=10, tpm=2000000, source="env") + "test": LimitEntry( + service="test", rpm=10, tpm=2000000, rpm_source="env", tpm_source="env" + ) } builder.id_data = { "test": APIIDEntry( @@ -160,6 +162,31 @@ def test_build_method(): assert "test" in result # Default test service should always be present +def test_update_from_dict(mock_env_vars): + """Test fetching key-value pairs from environment""" + with patch.dict("os.environ", mock_env_vars, clear=True): + builder = KeyLookupBuilder(fetch_order=("env",)) + + assert builder.key_data["openai"][-1].value == "test-openai-key" + assert builder.key_data["openai"][-1].source == "env" + + assert builder.limit_data["openai"].rpm == "20" + assert builder.limit_data["openai"].rpm_source == "env" + + builder.update_from_dict( + { + "OPENAI_API_KEY": ("sk-1234", "custodial_keys"), + "EDSL_SERVICE_RPM_OPENAI": ("40", "custodial_keys"), + } + ) + + assert builder.key_data["openai"][-1].value == "sk-1234" + assert builder.key_data["openai"][-1].source == "custodial_keys" + + assert builder.limit_data["openai"].rpm == "40" + assert builder.limit_data["openai"].rpm_source == "custodial_keys" + + def test_duplicate_id_handling(): """Test handling of duplicate API IDs""" builder = KeyLookupBuilder()