Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions test/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,11 @@ def test_caching():
class SampleConfig(Config):
prop1 = key('SampleConfig', 'PROP1')

mock_source: ConfigSource = MagicMock(spec=ConfigSource)
mock_source.get_config_value = MagicMock()
class MockSource(ConfigSource):
get_config_value = MagicMock()

mock_source = MockSource()

s = SampleConfig()
s.add_source(mock_source)

Expand Down
30 changes: 20 additions & 10 deletions typedconfig/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,20 @@ def getter(self: Config) -> T:
resolved_key_name = self._resolve_key_name(getter)
_mutable_state['key_name'] = resolved_key_name

section_hierarchy = [*self._parent_section_hierarchy, resolved_section_name]

# If value is cached, just use the cached value
cached_value = self._provider.get_from_cache(resolved_section_name, resolved_key_name)
cached_value = self._provider.get_from_cache(section_hierarchy, resolved_key_name)
if cached_value is not None:
return cached_value

value = self._provider.get_key(resolved_section_name, resolved_key_name)
value = self._provider.get_key(section_hierarchy, resolved_key_name)

# If we still haven't found a config value and this parameter is required,
# raise an exception, otherwise use the default
if value is None:
if required:
raise KeyError("Config parameter {0}.{1} not found".format(resolved_section_name, resolved_key_name))
raise KeyError(f"Config parameter {''.join(section_hierarchy)}/{resolved_key_name} not found")
else:
value = default

Expand All @@ -86,7 +88,7 @@ def getter(self: Config) -> T:

# Cache this for next time if still not none
if value is not None:
self._provider.add_to_cache(resolved_section_name, resolved_key_name, value)
self._provider.add_to_cache(section_hierarchy, resolved_key_name, value)

return value

Expand All @@ -111,10 +113,16 @@ def group_key(cls: Type[TConfig], group_section_name: str = None, hierarchical:
"""

@property
def wrapped_f(self):
def wrapped_f(self: Config):
if hierarchical:
resolved_section_name = self._resolve_section_name(group_section_name)
parent_section_hierarchy = [*self._parent_section_hierarchy, resolved_section_name]
else:
parent_section_hierarchy = None

attr_name = '_' + self._get_property_name_from_object(wrapped_f)
if not hasattr(self, attr_name):
setattr(self, attr_name, cls(provider=self._provider))
setattr(self, attr_name, cls(provider=self._provider, parent_section_hierarchy=parent_section_hierarchy))
return typing.cast(TConfig, getattr(self, attr_name))

setattr(wrapped_f.fget, Config._composed_config_registration_string, True)
Expand Down Expand Up @@ -143,14 +151,16 @@ class Config:
_config_key_key_name_string = '__config_key_key_name__'
_config_key_section_name_string = '__config_key_section_name__'

def __init__(self, section_name=None, sources: List[ConfigSource] = None,
provider: Optional[ConfigProvider] = None):
def __init__(self, section_name: str = None, sources: List[ConfigSource] = None,
provider: Optional[ConfigProvider] = None,
parent_section_hierarchy: List[str] = None):
if provider is None:
provider = ConfigProvider(sources=sources)
elif not isinstance(provider, ConfigProvider):
raise TypeError("provider must be a ConfigProvider object")
self._section_name = section_name
self._provider: ConfigProvider = provider
self._parent_section_hierarchy: List[str] = parent_section_hierarchy if parent_section_hierarchy is not None else []

def __repr__(self):
key_names = self.get_registered_properties()
Expand Down Expand Up @@ -274,7 +284,7 @@ def _post_read(self, updated_values: dict):

section_name = self._resolve_section_name(getattr(property_object.fget, self._config_key_section_name_string))
key_name = self._resolve_key_name(property_object)
self._provider.add_to_cache(section_name, key_name, v)
self._provider.add_to_cache([*self._parent_section_hierarchy, section_name], key_name, v)

def clear_cache(self):
"""
Expand Down Expand Up @@ -343,4 +353,4 @@ def get_key(self, section_name: str, key_name: str) -> Optional[str]:
-------
value: the loaded config value as a string
"""
return self._provider.get_key(section_name, key_name)
return self._provider.get_key([*self._parent_section_hierarchy, section_name], key_name)
28 changes: 15 additions & 13 deletions typedconfig/provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Tuple
from typedconfig.source import ConfigSource

logger = logging.getLogger(__name__)
Expand All @@ -11,21 +11,23 @@ class ConfigProvider:
across the configuration objects
"""
def __init__(self, sources: List[ConfigSource] = None):
self._cache: Dict[str, Dict[str, str]] = {}
self._cache: Dict[Tuple[str], Dict[str, str]] = {}
self._config_sources: List[ConfigSource] = []
if sources is not None:
for source in sources:
self.add_source(source)

def add_to_cache(self, section_name: str, key_name: str, value) -> None:
if section_name not in self._cache:
self._cache[section_name] = {}
self._cache[section_name][key_name] = value
def add_to_cache(self, section_hierarchy: List[str], key_name: str, value) -> None:
hashable_section_hierarchy = tuple(section_hierarchy)
if hashable_section_hierarchy not in self._cache:
self._cache[hashable_section_hierarchy] = {}
self._cache[hashable_section_hierarchy][key_name] = value

def get_from_cache(self, section_name: str, key_name: str):
if section_name not in self._cache:
def get_from_cache(self, section_hierarchy: List[str], key_name: str):
hashable_section_hierarchy = tuple(section_hierarchy)
if hashable_section_hierarchy not in self._cache:
return None
return self._cache[section_name].get(key_name, None)
return self._cache[hashable_section_hierarchy].get(key_name, None)

def clear_cache(self) -> None:
self._cache.clear()
Expand All @@ -34,15 +36,15 @@ def clear_cache(self) -> None:
def config_sources(self) -> List[ConfigSource]:
return self._config_sources

def get_key(self, section_name: str, key_name: str) -> Optional[str]:
def get_key(self, section_hierarchy: List[str], key_name: str) -> Optional[str]:
value = None

# Go through the config sources until we find one which supplies the requested value
for source in self._config_sources:
logger.debug(f'Looking for config value {section_name}/{key_name} in {source}')
value = source.get_config_value(section_name, key_name)
logger.debug(f'Looking for config value {".".join(section_hierarchy)}/{key_name} in {source}')
value = source.get_hierarchical_config_value(section_hierarchy, key_name)
if value is not None:
logger.debug(f'Found config value {section_name}/{key_name} in {source}')
logger.debug(f'Found config value {".".join(section_hierarchy)}/{key_name} in {source}')
break

return value
Expand Down
7 changes: 6 additions & 1 deletion typedconfig/source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from abc import ABC, abstractmethod
from typing import Optional, Dict
from typing import Optional, Dict, List
from configparser import ConfigParser


Expand All @@ -9,6 +9,11 @@ class ConfigSource(ABC):
def get_config_value(self, section_name: str, key_name: str) -> Optional[str]:
raise NotImplementedError()

def get_hierarchical_config_value(self, section_hierarchy: List[str], key_name: str) -> Optional[str]:
# Default implementation only works with hierarchy of the correct depth
assert len(section_hierarchy) == 1
return self.get_config_value(section_hierarchy[0], key_name)

def __repr__(self):
return f'<{self.__class__.__qualname__}>'

Expand Down