Skip to content

Commit

Permalink
continue regex matching on all extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
konstantinstadler committed Dec 15, 2023
1 parent 9c3bc88 commit 0711f8b
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 15 deletions.
71 changes: 66 additions & 5 deletions pymrio/core/mriosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,9 @@ def find(self, term):
"""
res_dict = dict()
try:
index_find = ioutil.index_contains(self.get_index(as_dict=False), find_all=term)
index_find = ioutil.index_contains(
self.get_index(as_dict=False), find_all=term
)
if len(index_find) > 0:
res_dict["index"] = index_find
except: # noqa: E722
Expand All @@ -759,8 +761,8 @@ def find(self, term):
try:
for ext in self.get_extensions(data=False):
ext_index_find = ioutil.index_contains(
getattr(self, ext).get_index(as_dict=False),
find_all=term)
getattr(self, ext).get_index(as_dict=False), find_all=term
)
if len(ext_index_find) > 0:
res_dict[ext + "_index"] = ext_index_find
except: # noqa: E722
Expand Down Expand Up @@ -802,7 +804,6 @@ def contains(self, find_all=None, **kwargs):
"""
return ioutil.index_contains(self.get_index(as_dict=False), find_all, **kwargs)


def match(self, find_all=None, **kwargs):
"""Check if index of the system match the regex pattern
Expand Down Expand Up @@ -872,7 +873,6 @@ def fullmatch(self, find_all=None, **kwargs):
return ioutil.index_fullmatch(self.get_index(as_dict=False), find_all, **kwargs)



# API classes
class Extension(BaseSystem):
"""Class which gathers all information for one extension of the IOSystem
Expand Down Expand Up @@ -2164,6 +2164,67 @@ def get_extensions(self, data=False):
else:
yield key

def extension_contains(self, extensions=None, find_all=None, **kwargs):
    """Run the ``contains`` index search on several extensions at once

    The call is forwarded to ``_apply_extension_method`` with the method
    name "contains"; ``find_all`` is always passed on as a keyword argument.

    Parameters
    ----------
    extensions: str, list of str, list of extensions, None
        Which extensions to consider, default (None): all extensions

    find_all : None or str
        If str (regex pattern) search in all index levels.
        All matching rows are returned. The remaining kwargs are ignored.

    kwargs : dict
        The regex which should be contained. The keys are the index names,
        the values are the regex.
        If the entry is not in index name, it is ignored silently.

    Returns
    -------
    dict
        Whatever ``_apply_extension_method`` returns: one entry per
        extension, holding the result of that extension's ``contains`` call.
    """
    search_kwargs = dict(find_all=find_all, **kwargs)
    return self._apply_extension_method(extensions, "contains", **search_kwargs)

# TODO: CONT: add methods for match, fullmatch
# TODO: CONT: add method for extract extension

def _apply_extension_method(self, extensions, method, *args, **kwargs):
    """Apply a method to a list of extensions

    Parameters
    ----------
    extensions: str, Extension, list of str, list of extensions, None
        Which extensions to consider, None: all extensions
        (as returned by ``self.get_extensions(data=False)``).
        A single name or a single Extension instance is handled like a
        one-element list. Names are resolved with ``getattr(self, name)``.

    method: str
        The name of the method to apply; it is looked up on each extension.

    args: list
        The arguments to pass to the method

    kwargs: dict
        The keyword arguments to pass to the method

    Returns
    -------
    dict
        A dict with the extension names as keys and the return values of the
        method as values. For extensions given by name the key is that name,
        for extension instances it is their ``name`` attribute.

    Raises
    ------
    AttributeError
        If a given name is not an attribute of the system or the method
        does not exist on an extension.
    """
    if extensions is None:
        extensions = self.get_extensions(data=False)
    elif isinstance(extensions, (str, Extension)):
        # isinstance instead of comparing the string representation of the
        # type: subclasses of Extension (and of str) are recognised as a
        # single item as well, and the check does not depend on the module
        # path under which the class was imported.
        extensions = [extensions]

    result = dict()
    for ext in extensions:
        if isinstance(ext, str):
            ext_name, ext_obj = ext, getattr(self, ext)
        else:
            ext_name, ext_obj = ext.name, ext
        result[ext_name] = getattr(ext_obj, method)(*args, **kwargs)
    return result

def reset_full(self, force=False):
"""Remove all accounts which can be recalculated based on Z, Y, F, F_Y
Expand Down
2 changes: 1 addition & 1 deletion pymrio/tools/ioutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ def _index_regex_matcher(_dfs_idx, _method, _find_all=None, **kwargs):

if not at_least_one_valid:
if type(_dfs_idx) in [pd.DataFrame, pd.Series]:
_dfs_idx = pd.DataFrame(index=[], columns=_dfs_idx.columns)
_dfs_idx = pd.DataFrame(index=[], columns=_dfs_idx.columns)
elif type(_dfs_idx) in [pd.Index, pd.MultiIndex]:
_dfs_idx = pd.Index([])

Expand Down
22 changes: 14 additions & 8 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ def test_reset_to_coefficients(fix_testmrio):
assert tt.Z is None
assert tt.emissions.F is None


def test_find(fix_testmrio):
tt = fix_testmrio.testmrio

Expand All @@ -512,16 +513,17 @@ def test_find(fix_testmrio):
assert "regions" not in ext_find.keys()
assert "Y_categories" not in ext_find.keys()


def test_contain_match_matchall(fix_testmrio):
tt = fix_testmrio.testmrio

cont_bare = tt.contains("th")
cont_find_all = tt.contains(find_all = "th")
cont_find_all = tt.contains(find_all="th")
assert all(cont_bare == cont_find_all)
assert 'other' in cont_bare.get_level_values('sector')
assert 'reg1' in cont_bare.get_level_values('region')
assert 'reg2' in cont_bare.get_level_values('region')
assert 'food' not in cont_bare.get_level_values('sector')
assert "other" in cont_bare.get_level_values("sector")
assert "reg1" in cont_bare.get_level_values("region")
assert "reg2" in cont_bare.get_level_values("region")
assert "food" not in cont_bare.get_level_values("sector")

match_test_empty = tt.match("th")
fullmatch_test_empty = tt.fullmatch("oth")
Expand All @@ -543,11 +545,15 @@ def test_contain_match_matchall(fix_testmrio):
assert all(fullmatch_test4 == cont_bare)

# check with keywords and extensions
ext_air = tt.emissions.match(compartment = "air")
ext_air_none = tt.emissions.match(stressor = "air")
ext_air = tt.emissions.match(compartment="air")
ext_air_none = tt.emissions.match(stressor="air")
assert len(ext_air_none) == 0
assert len(ext_air) > 0

ext_all_comp = tt.emissions.match(compartment="air|water")
assert all(ext_all_comp == tt.emissions.F.index)


def test_direct_account_calc(fix_testmrio):
orig = fix_testmrio.testmrio
orig.calc_all()
Expand Down
4 changes: 3 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,9 @@ def test_util_regex():

# 6. test with all kwargs not present

df_some_match = index_match(test_df, region="a.*", sector=".*b.*", not_present_column="abc")
df_some_match = index_match(
test_df, region="a.*", sector=".*b.*", not_present_column="abc"
)
assert df_some_match.index.get_level_values("region").unique() == ["a1"]
assert df_some_match.index.get_level_values("sector").unique() == ["bb"]

Expand Down

0 comments on commit 0711f8b

Please sign in to comment.