Skip to content

Commit

Permalink
updated search functions for new _apply method
Browse files Browse the repository at this point in the history
  • Loading branch information
konstantinstadler committed Jan 9, 2024
1 parent 33bb725 commit 7314ee5
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 54 deletions.
13 changes: 11 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,19 @@ Breaking changes

* get_extensions has a new signature.
Two new paramters, names and istance_names.
Names allows to filter the extensions by name (set names of the attribute .name).
instance_names and be set to False to get the set names of the extensions.

- 'names' allows to filter the extensions by name (set names of the attribute .name
or the instance names, also allows to pass the extension itself). Can be used
to harmonize the names of an extension list.
- 'instance_names' can be set to False to get the "set names" of the extensions.

When using keyword arguments before, the new signature should just work.

* remove_extension:
Previous all extensions got removed when no name was given.
This has changed. Now all extensions are kept when no name is given (TypeError is raised).
To remove all extensions, use mrio.remove_extension(mrio.get_extensions())



Depracted
Expand Down
85 changes: 49 additions & 36 deletions pymrio/core/mriosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -2194,7 +2194,8 @@ def get_extensions(self, names=None, data=False, instance_names=True):
----------
names = str or list like, optional
Extension names to yield. If None (default), all extensions are
yielded. This can be used to convert from set names to instance names.
yielded. This can be used to convert from set names to instance names
and vice versa or to harmonize a list of extensions.
data : boolean, optional
If True, returns a generator which yields the extensions.
Expand All @@ -2214,21 +2215,35 @@ def get_extensions(self, names=None, data=False, instance_names=True):
"""

ext_list = [
key for key in self.__dict__ if type(self.__dict__[key]) is Extension
all_ext_list = [
key for key in self.__dict__ if isinstance(self.__dict__[key], Extension)
]
ext_names = names if names else [getattr(self, ext).name for ext in ext_list]
all_name_list = [getattr(self, key).name for key in all_ext_list]

if isinstance(names, str):
names = [names]
_pre_ext = names if names else all_ext_list
ext_name_or_inst = [
nn.name if isinstance(nn, Extension) else nn for nn in _pre_ext
]

for name in ext_name_or_inst:
if name in all_ext_list:
inst_name = name
ext_name = all_name_list[all_ext_list.index(name)]
elif name in all_name_list:
inst_name = all_ext_list[all_name_list.index(name)]
ext_name = name
else:
raise ValueError(f"Extension {name} not present in the system.")

for key in ext_list:
if getattr(self, key).name not in ext_names:
continue
if data:
yield getattr(self, key)
yield getattr(self, inst_name)
else:
if instance_names:
yield key
yield inst_name
else:
yield getattr(self, key).name
yield ext_name

def extension_fullmatch(self, find_all=None, extensions=None, **kwargs):
"""Get a dict of extension index dicts with full match of a search pattern.
Expand Down Expand Up @@ -2387,24 +2402,25 @@ def _apply_extension_method(self, extensions, method, *args, **kwargs):
-------
dict
A dict with the extension names as keys and the return values of the
method as values
method as values. The keys are the same as in 'extensions', thus
convert these to the set names or instance names before (using
mrio.get_extensions)
"""
if extensions is None:
extensions = self.get_extensions(data=False)
elif (
str(type(extensions)) == "<class 'pymrio.core.mriosystem.Extension'>"
) or isinstance(extensions, str):
extensions = list(self.get_extensions(data=False, instance_names=False))
if isinstance(extensions, (Extension, str)):
extensions = [extensions]
result = dict()
for ext in extensions:
if isinstance(ext, str):
extname = ext
ext = getattr(self, ext)
else:
extname = ext.name

instance_names = self.get_extensions(
names=extensions, data=False, instance_names=True
)
ext_data = self.get_extensions(names=extensions, data=True)

result = {}
for ext_name, inst_name, ext in zip(extensions, instance_names, ext_data):
method_fun = getattr(ext, method)
result[extname] = method_fun(*args, **kwargs)
result[ext_name] = method_fun(*args, **kwargs)
return result

def reset_full(self, force=False):
Expand Down Expand Up @@ -2531,7 +2547,9 @@ def save_all(self, path, table_format="txt", sep="\t", float_format="%.12g"):
float_format=float_format,
)

for ext, ext_name in zip(self.get_extensions(data=True), self.get_extensions()):
for ext, ext_name in zip(
self.get_extensions(data=True), self.get_extensions(instance_names=True)
):
ext_path = path / ext_name

ext.save(
Expand Down Expand Up @@ -2922,7 +2940,7 @@ def aggregate(
self.calc_extensions()
return self

def remove_extension(self, ext=None):
def remove_extension(self, ext):
"""Remove extension from IOSystem
For single Extensions the same can be achieved with del
Expand All @@ -2932,26 +2950,21 @@ def remove_extension(self, ext=None):
----------
ext : string or list, optional
The extension to remove, this can be given as the name of the
instance or of Extension.name (the latter will be checked if no
instance or of Extension.name.
instance was found)
If ext is None (default) all Extensions will be removed
"""
if ext is None:
ext = list(self.get_extensions())
if type(ext) is str:
ext = [ext]

for ee in ext:
try:
del self.__dict__[ee]
except KeyError:
for exinstancename, exdata in zip(
self.get_extensions(data=False), self.get_extensions(data=True)
):
if exdata.name == ee:
del self.__dict__[exinstancename]
finally:
self.meta._add_modify("Removed extension {}".format(ee))
except KeyError:
ext_instance = self.get_extensions(ee, instance_names=True)
for x in ext_instance:
del self.__dict__[x]
self.meta._add_modify("Removed extension {}".format(x))

return self

Expand Down
72 changes: 56 additions & 16 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ def test_get_gross_trade(fix_testmrio):
(None, False, False, ["Emissions", "Factor Inputs"]),
(["Emissions", "Factor Inputs"], False, False, ["Emissions", "Factor Inputs"]),
(["Emissions", "Factor Inputs"], False, True, ["emissions", "factor_inputs"]),
(["Emissions", "factor_inputs"], False, True, ["emissions", "factor_inputs"]),
(["emissions", "factor_inputs"], False, False, ["Emissions", "Factor Inputs"]),
(
["emissions", "factor_inputs", "emissions"],
False,
False,
["Emissions", "Factor Inputs", "Emissions"],
),
],
)
def test_get_extensions(fix_testmrio, names, data, instance_names, result):
Expand All @@ -114,6 +122,19 @@ def test_get_extensions(fix_testmrio, names, data, instance_names, result):
assert sorted(exts) == sorted(result)


def test_get_extension_raise(fix_testmrio):
tt = fix_testmrio.testmrio
with pytest.raises(ValueError):
list(
tt.get_extensions(
names=["emissions", "foo"], data=False, instance_names=True
)
)


(["emissions", "facor_inputs"], False, False, pytest.raises(ValueError)),


def test_get_index(fix_testmrio):
"""Testing the different options for get_index in core.mriosystem"""
tt = fix_testmrio.testmrio
Expand Down Expand Up @@ -253,9 +274,11 @@ def test_copy_and_extensions(fix_testmrio):
tcp = fix_testmrio.testmrio.copy()
tcp.remove_extension("Emissions")
assert len(list(tcp.get_extensions())) == 1
tcp.remove_extension()
assert len(list(tcp.get_extensions())) == 0
assert len(list(fix_testmrio.testmrio.get_extensions())) == 2
with pytest.raises(TypeError):
tcp.remove_extension()
tcnew = fix_testmrio.testmrio.copy()
tcnew.remove_extension(tcnew.get_extensions())
assert len(list(tcnew.get_extensions())) == 0


def test_extract(fix_testmrio):
Expand Down Expand Up @@ -576,7 +599,7 @@ def test_contain_match_matchall(fix_testmrio):
assert all(fullmatch_test3 == cont_bare)
assert all(fullmatch_test4 == cont_bare)

# check with keywors and extensions
# check with keywords and extensions
ext_air = tt.emissions.match(compartment="air")
ext_air_none = tt.emissions.match(stressor="air")
assert len(ext_air_none) == 0
Expand All @@ -589,28 +612,45 @@ def test_contain_match_matchall(fix_testmrio):
def test_extension_match_contain(fix_testmrio):
tt = fix_testmrio.testmrio
match_air = tt.extension_match(find_all="air")
assert len(match_air["factor_inputs"]) == 0
assert len(match_air["emissions"]) == 1
assert len(match_air["Factor Inputs"]) == 0
assert len(match_air["Emissions"]) == 1

contain_value_added = tt.extension_contains(inputtype="dded")
assert len(contain_value_added["factor_inputs"]) == 1
assert len(contain_value_added["emissions"]) == 0
assert len(contain_value_added["Factor Inputs"]) == 1
assert len(contain_value_added["Emissions"]) == 0

fullmatch_0 = tt.extension_fullmatch(emissions="dded")
assert len(fullmatch_0["factor_inputs"]) == 0
assert len(fullmatch_0["emissions"]) == 0
assert len(fullmatch_0["Factor Inputs"]) == 0
assert len(fullmatch_0["Emissions"]) == 0
fullmatch_1 = tt.extension_fullmatch(stressor="emission_type.*")
assert len(fullmatch_1["factor_inputs"]) == 0
assert len(fullmatch_1["emissions"]) == 2
assert len(fullmatch_1["Factor Inputs"]) == 0
assert len(fullmatch_1["Emissions"]) == 2

# dual match
dual_match1 = tt.extension_match(stressor="emission_type.*", compartment="air")
assert len(dual_match1["factor_inputs"]) == 0
assert len(dual_match1["emissions"]) == 1
assert len(dual_match1["Factor Inputs"]) == 0
assert len(dual_match1["Emissions"]) == 1

dual_match2 = tt.extension_contains(stressor="1", inputtype="alue")
assert len(dual_match2["factor_inputs"]) == 1
assert len(dual_match2["emissions"]) == 1
assert len(dual_match2["Factor Inputs"]) == 1
assert len(dual_match2["Emissions"]) == 1

# Test for extension instance and set names
inst_match = tt.extension_match(
extensions=["emissions", "factor_inputs"], stressor="emission_type.*"
)
assert len(inst_match["emissions"]) == 2
assert len(inst_match["factor_inputs"]) == 0

inst_match2 = tt.extension_match(
extensions=["emissions"], stressor="emission_type.*"
)
assert len(inst_match2["emissions"]) == 2
assert "factor_inputs" not in inst_match2.keys()

name_match = tt.extension_contains(extensions=["Factor Inputs"], inputtype="Value")
assert "factor_inputs" not in name_match.keys()
assert len(name_match["Factor Inputs"]) == 1


def test_direct_account_calc(fix_testmrio):
Expand Down

0 comments on commit 7314ee5

Please sign in to comment.