diff --git a/pymrio/core/mriosystem.py b/pymrio/core/mriosystem.py index 93eef2be..4897d2c0 100644 --- a/pymrio/core/mriosystem.py +++ b/pymrio/core/mriosystem.py @@ -733,7 +733,9 @@ def find(self, term): """ res_dict = dict() try: - index_find = ioutil.index_contains(self.get_index(as_dict=False), find_all=term) + index_find = ioutil.index_contains( + self.get_index(as_dict=False), find_all=term + ) if len(index_find) > 0: res_dict["index"] = index_find except: # noqa: E722 @@ -759,8 +761,8 @@ def find(self, term): try: for ext in self.get_extensions(data=False): ext_index_find = ioutil.index_contains( - getattr(self, ext).get_index(as_dict=False), - find_all=term) + getattr(self, ext).get_index(as_dict=False), find_all=term + ) if len(ext_index_find) > 0: res_dict[ext + "_index"] = ext_index_find except: # noqa: E722 @@ -802,7 +804,6 @@ def contains(self, find_all=None, **kwargs): """ return ioutil.index_contains(self.get_index(as_dict=False), find_all, **kwargs) - def match(self, find_all=None, **kwargs): """Check if index of the system match the regex pattern @@ -872,7 +873,6 @@ def fullmatch(self, find_all=None, **kwargs): return ioutil.index_fullmatch(self.get_index(as_dict=False), find_all, **kwargs) - # API classes class Extension(BaseSystem): """Class which gathers all information for one extension of the IOSystem @@ -2164,6 +2164,67 @@ def get_extensions(self, data=False): else: yield key + def extension_contains(self, extensions=None, find_all=None, **kwargs): + """Get a dict of extension index dicts + + Parameters + ---------- + extensions: str, list of str, list of extensions, None + Which extensions to consider, default (None): all extensions + find_all : None or str + If str (regex pattern) search in all index levels. + All matching rows are returned. The remaining kwargs are ignored. + kwargs : dict + The regex which should be contained. The keys are the index names, + the values are the regex. + If the entry is not in index name, it is ignored silently. + """ + method = "contains" + return self._apply_extension_method( + extensions, method, find_all=find_all, **kwargs + ) + + # TODO: CONT: added methods for match, fullmatch + # TODO: CONT: add method for extract extension + + def _apply_extension_method(self, extensions, method, *args, **kwargs): + """Apply a method to a list of extensions + + Parameters + ---------- + extensions: str, list of str, list of extensions, None + Which extensions to consider, None: all extensions + method: str + The method to apply + args: list + The arguments to pass to the method + kwargs: dict + The keyword arguments to pass to the method + + Returns + ------- + dict + A dict with the extension names as keys and the return values of the + method as values + + """ + if extensions is None: + extensions = self.get_extensions(data=False) + elif ( + str(type(extensions)) == "" + ) or type(extensions) == str: + extensions = [extensions] + result = dict() + for ext in extensions: + if type(ext) is str: + extname = ext + ext = getattr(self, ext) + else: + extname = ext.name + method_fun = getattr(ext, method) + result[extname] = method_fun(*args, **kwargs) + return result + def reset_full(self, force=False): """Remove all accounts which can be recalculated based on Z, Y, F, F_Y diff --git a/pymrio/tools/ioutil.py b/pymrio/tools/ioutil.py index 4f2b5a7a..833a18c0 100644 --- a/pymrio/tools/ioutil.py +++ b/pymrio/tools/ioutil.py @@ -969,7 +969,7 @@ def _index_regex_matcher(_dfs_idx, _method, _find_all=None, **kwargs): if not at_least_one_valid: if type(_dfs_idx) in [pd.DataFrame, pd.Series]: - _dfs_idx = pd.DataFrame(index=[], columns=_dfs_idx.columns) + _dfs_idx = pd.DataFrame(index=[], columns=_dfs_idx.columns) elif type(_dfs_idx) in [pd.Index, pd.MultiIndex]: _dfs_idx = pd.Index([]) diff --git a/tests/test_core.py b/tests/test_core.py index 36275400..28e5d720 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -495,6 +495,7 @@ def test_reset_to_coefficients(fix_testmrio): assert tt.Z is None assert tt.emissions.F is None + def test_find(fix_testmrio): tt = fix_testmrio.testmrio @@ -512,16 +513,17 @@ def test_find(fix_testmrio): assert "regions" not in ext_find.keys() assert "Y_categories" not in ext_find.keys() + def test_contain_match_matchall(fix_testmrio): tt = fix_testmrio.testmrio - + cont_bare = tt.contains("th") - cont_find_all = tt.contains(find_all = "th") + cont_find_all = tt.contains(find_all="th") assert all(cont_bare == cont_find_all) - assert 'other' in cont_bare.get_level_values('sector') - assert 'reg1' in cont_bare.get_level_values('region') - assert 'reg2' in cont_bare.get_level_values('region') - assert 'food' not in cont_bare.get_level_values('sector') + assert "other" in cont_bare.get_level_values("sector") + assert "reg1" in cont_bare.get_level_values("region") + assert "reg2" in cont_bare.get_level_values("region") + assert "food" not in cont_bare.get_level_values("sector") match_test_empty = tt.match("th") fullmatch_test_empty = tt.fullmatch("oth") @@ -543,11 +545,15 @@ def test_contain_match_matchall(fix_testmrio): assert all(fullmatch_test4 == cont_bare) # check with keywors and extensions - ext_air = tt.emissions.match(compartment = "air") - ext_air_none = tt.emissions.match(stressor = "air") + ext_air = tt.emissions.match(compartment="air") + ext_air_none = tt.emissions.match(stressor="air") assert len(ext_air_none) == 0 assert len(ext_air) > 0 + ext_all_comp = tt.emissions.match(compartment="air|water") + assert all(ext_all_comp == tt.emissions.F.index) + + def test_direct_account_calc(fix_testmrio): orig = fix_testmrio.testmrio orig.calc_all() diff --git a/tests/test_util.py b/tests/test_util.py index 71d7f8dc..66c05d2f 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -342,7 +342,9 @@ def test_util_regex(): # 6. test with all kwargs not present - df_some_match = index_match(test_df, region="a.*", sector=".*b.*", not_present_column="abc") + df_some_match = index_match( + test_df, region="a.*", sector=".*b.*", not_present_column="abc" + ) assert df_some_match.index.get_level_values("region").unique() == ["a1"] assert df_some_match.index.get_level_values("sector").unique() == ["bb"]