Skip to content

Commit

Permalink
fix series return
Browse files Browse the repository at this point in the history
  • Loading branch information
konstantinstadler committed Apr 8, 2024
1 parent 84d56ec commit b9959c3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
28 changes: 21 additions & 7 deletions doc/source/notebooks/extract_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@
df_some.keys()



# %% [markdown]
#### Extracting from all extensions

Expand All @@ -119,6 +118,7 @@
('emission_type2', 'water')]}



# %% [markdown]
# And can then use the `extension_extract` method to extract the data, either as a pandas DataFrame,
# which returns a dictionary with the extension names as keys
Expand All @@ -137,19 +137,33 @@
ext_extract_all = mrio.extension_extract(to_extract, return_type="extensions")
ext_extract_all.keys()

extracts = ext_extract_all

r = pymrio.concate_extension(*extracts.values(), name="abc")

# %%
str(ext_extract_all['Factor Inputs'])

# %% [markdown]
# Or merge the extracted data into a new pymrio Extension object (when passing a new name as return_type):

# %%
ext_new = mrio.extension_extract(to_extract, return_type="new_merged_extension")
str(ext_new)

# %% [markdown]
# CONT: Extraction to a single extensio does not work.
# Issue: when only one extension row, it becomes a data series, not a dataframe.
# CONT: Continue with explaining, mention the work with find_all etc

# CONT: Make test cases for the things below


mrio.factor_inputs.extract("Value Added", return_type="ext").F

mrio.factor_inputs.extract(("Value Added"), return_type="ext").F

mrio.factor_inputs.extract(["Value Added"], return_type="ext").F


mrio.factor_inputs.extract(mrio.factor_inputs.get_rows(), return_type="ext").F

mrio.emissions.extract(mrio.emissions.get_rows(), return_type="ext").F

mrio.emissions.extract(mrio.emissions.get_rows()[0], return_type="ext").F

mrio.emissions.get_rows()[0]
4 changes: 3 additions & 1 deletion pymrio/core/mriosystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,6 +1659,8 @@ def extract(self, index, dataframes=None, return_type="dataframes"):
"""
if type(index) is dict:
index = index.get(self.name, None)
if type(index) in (str, tuple):
index = [index]

retdict = {}
if dataframes is None:
Expand All @@ -1672,7 +1674,7 @@ def extract(self, index, dataframes=None, return_type="dataframes"):

for dfname in dataframes:
data = getattr(self, dfname)
retdict[dfname] = pd.DataFrame(data.loc[index])
retdict[dfname] = data.loc[index, :]

if return_type.lower() in ["dataframes", "dataframe", "dfs", "df"]:
return retdict
Expand Down

0 comments on commit b9959c3

Please sign in to comment.