Skip to content

Commit

Permalink
added Extension extract method
Browse files Browse the repository at this point in the history
  • Loading branch information
konstantinstadler committed Jan 5, 2024
1 parent 4dd72ba commit 7aeb626
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 27 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ Changelog
v0.5.3dev
***************************

Deprecated
==========

* extension.get_row_data()
The method get_row_data() is deprecated and will be removed in v0.6.0.
Use extension.extract() instead.


***************************
v0.5.2 - 20230815
Expand Down
86 changes: 66 additions & 20 deletions pymrio/core/mriosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,11 +1568,13 @@ def get_rows(self):
return None

def get_row_data(self, row, name=None):
"""Returns a dict with all available data for a row in the extension
"""Returns a dict with all available data for a row in the extension.
If you need a new extension, see the extract method.
Parameters
----------
row : tuple, list, string
row : index, tuple, list, string
A valid index for the extension DataFrames
name : string, optional
If given, adds a key 'name' with the given value to the dict. In
Expand All @@ -1581,15 +1583,56 @@ def get_row_data(self, row, name=None):
Returns
-------
dict object with the data (pandas DataFrame)for the specific rows
dict object with the data (pandas DataFrame) for the specific rows
"""
# deprecation warning

warnings.warn(
"This method will be removed in future versions. "
"Use extract method instead",
DeprecationWarning,
)

retdict = {}
for rowname, data in zip(self.get_DataFrame(), self.get_DataFrame(data=True)):
retdict[rowname] = pd.DataFrame(data.loc[row])
if name:
retdict["name"] = name
return retdict

def extract(self, index, dataframes=None):
    """Extract rows of the extension accounts into a dict of DataFrames.

    Notes
    -----
    To build a new extension from the extracted data, pass the result
    to the Extension constructor:

        new_extension = Extension(name="new_extension", **ext.extract(index))

    where ``ext`` is an existing Extension instance.

    Parameters
    ----------
    index : valid row index
        A valid row indexer for the extension DataFrames (anything
        accepted by ``DataFrame.loc``). A single row label is accepted
        as well and yields a one-row DataFrame.
    dataframes : str or list of str, optional
        Name(s) of the dataframes which should be extracted.
        If None (default), all names reported by ``get_DataFrame()``
        are extracted.

    Returns
    -------
    dict
        Maps each dataframe name to a pandas DataFrame holding the
        selected rows.

    Raises
    ------
    AttributeError
        If a requested name is not an attribute of the extension.
    """
    if dataframes is None:
        dataframes = self.get_DataFrame()
    elif isinstance(dataframes, str):
        # A bare string would otherwise be iterated character by character.
        dataframes = [dataframes]

    retdict = {}
    for dfname in dataframes:
        data = getattr(self, dfname)
        selection = data.loc[index]
        if isinstance(selection, pd.Series):
            # A single row label makes .loc return a Series; wrapping that
            # in a DataFrame would transpose it. Select again with a
            # one-element list so rows stay rows.
            selection = data.loc[[index]]
        retdict[dfname] = pd.DataFrame(selection)

    return retdict


def diag_stressor(self, stressor, name=None, _meta=None):
"""Diagonalize one row of the stressor matrix for a flow analysis.
Expand Down Expand Up @@ -2164,7 +2207,7 @@ def get_extensions(self, data=False):
else:
yield key

def extension_fullmatch(self, extensions=None, find_all=None, **kwargs):
def extension_fullmatch(self, find_all=None, extensions=None, **kwargs):
"""Get a dict of extension index dicts with full match of a search pattern.
This calls the extension.fullmatch for all extensions.
Expand All @@ -2181,11 +2224,11 @@ def extension_fullmatch(self, extensions=None, find_all=None, **kwargs):
Parameters
----------
extensions: str, list of str, list of extensions, None
Which extensions to consider, default (None): all extensions
find_all : None or str
If str (regex pattern) search in all index levels.
All matching rows are returned. The remaining kwargs are ignored.
extensions: str, list of str, list of extensions, None
Which extensions to consider, default (None): all extensions
kwargs : dict
The regex which should be contained. The keys are the index names,
the values are the regex.
Expand All @@ -2201,7 +2244,7 @@ def extension_fullmatch(self, extensions=None, find_all=None, **kwargs):
extensions, method="match", find_all=find_all, **kwargs
)

def extension_match(self, extensions=None, find_all=None, **kwargs):
def extension_match(self, find_all=None, extensions=None, **kwargs):
"""Get a dict of extension index dicts which match a search pattern
This calls the extension.match for all extensions.
Expand All @@ -2218,11 +2261,11 @@ def extension_match(self, extensions=None, find_all=None, **kwargs):
Parameters
----------
extensions: str, list of str, list of extensions, None
Which extensions to consider, default (None): all extensions
find_all : None or str
If str (regex pattern) search in all index levels.
All matching rows are returned. The remaining kwargs are ignored.
extensions: str, list of str, list of extensions, None
Which extensions to consider, default (None): all extensions
kwargs : dict
The regex which should be contained. The keys are the index names,
the values are the regex.
Expand All @@ -2238,7 +2281,7 @@ def extension_match(self, extensions=None, find_all=None, **kwargs):
extensions, method="match", find_all=find_all, **kwargs
)

def extension_contains(self, extensions=None, find_all=None, **kwargs):
def extension_contains(self, find_all=None, extensions=None, **kwargs):
"""Get a dict of extension index dicts which contains a search pattern
This calls the extension.contains for all extensions.
Expand All @@ -2256,11 +2299,11 @@ def extension_contains(self, extensions=None, find_all=None, **kwargs):
Parameters
----------
extensions: str, list of str, list of extensions, None
Which extensions to consider, default (None): all extensions
find_all : None or str
If str (regex pattern) search in all index levels.
All matching rows are returned. The remaining kwargs are ignored.
extensions: str, list of str, list of extensions, None
Which extensions to consider, default (None): all extensions
kwargs : dict
The regex which should be contained. The keys are the index names,
the values are the regex.
Expand All @@ -2276,21 +2319,24 @@ def extension_contains(self, extensions=None, find_all=None, **kwargs):
extensions, method="contains", find_all=find_all, **kwargs
)

# TODO: CONT: add method for extract extension



# TODO: CONT: add extract method for all extensions

def _apply_extension_method(self, extensions, method, *args, **kwargs):
"""Apply a method to a list of extensions
Parameters
----------
extensions: str, list of str, list of extensions, None
Which extensions to consider, None: all extensions
extensions: str, list of str, list of extensions, or None
Specifies which extensions to consider. Use None to consider all extensions.
method: str
The method to apply
Specifies the method to apply.
args: list
The arguments to pass to the method
Specifies the arguments to pass to the method.
kwargs: dict
The keyword arguments to pass to the method
Specifies the keyword arguments to pass to the method.
Returns
-------
Expand All @@ -2303,11 +2349,11 @@ def _apply_extension_method(self, extensions, method, *args, **kwargs):
extensions = self.get_extensions(data=False)
elif (
str(type(extensions)) == "<class 'pymrio.core.mriosystem.Extension'>"
) or type(extensions) == str:
) or isinstance(extensions, str):
extensions = [extensions]
result = dict()
for ext in extensions:
if type(ext) is str:
if isinstance(ext, str):
extname = ext
ext = getattr(self, ext)
else:
Expand Down
23 changes: 16 additions & 7 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,15 +237,24 @@ def test_copy_and_extensions(fix_testmrio):
assert len(list(fix_testmrio.testmrio.get_extensions())) == 2


def test_get_row_data(fix_testmrio):
stressor = ("emission_type1", "air")
def test_extract(fix_testmrio):
tt = fix_testmrio.testmrio.copy().calc_all()
td = tt.emissions.get_row_data(stressor)["D_exp_reg"]
md = pd.DataFrame(tt.emissions.D_exp_reg.loc[stressor])
pdt.assert_frame_equal(td, md)

for df_name in tt.emissions.get_DataFrame():
assert df_name in tt.emissions.get_row_data(stressor)
all_index = tt.emissions.get_index()
new_all = pymrio.Extension(name="new_all", **tt.emissions.extract(all_index))
assert new_all.name == "new_all"
for df in tt.emissions.get_DataFrame():
assert df in new_all.get_DataFrame()

id_air = tt.emissions.match(compartment="air")
new_air = pymrio.Extension(name="new_air", **tt.emissions.extract(index=id_air, dataframes=["S", "S_Y"]))

assert "F" not in new_air.get_DataFrame()
assert "S" in new_air.get_DataFrame()
assert "S_Y" in new_air.get_DataFrame()

with pytest.raises(AttributeError):
tt.emissions.extract(index=id_air, dataframes=["S", "FOO"])


def test_diag_stressor(fix_testmrio):
Expand Down

0 comments on commit 7aeb626

Please sign in to comment.