diff --git a/src/muse/agents/agent.py b/src/muse/agents/agent.py index d5f085494..894f21add 100644 --- a/src/muse/agents/agent.py +++ b/src/muse/agents/agent.py @@ -362,7 +362,7 @@ def next( ) # Calculate investments - investments = self.invest( + result = self.invest( search=search[["search_space", "decision"]], technologies=technologies, constraints=constraints, @@ -373,7 +373,7 @@ def next( # Add investments self.add_investments( technologies=technologies, - investments=investments, + investments=result["capacity"].rename("investment"), investment_year=investment_year, ) diff --git a/src/muse/investments.py b/src/muse/investments.py index f5467d919..32b1ce31c 100644 --- a/src/muse/investments.py +++ b/src/muse/investments.py @@ -103,16 +103,22 @@ def decorated( **kwargs, ) - if isinstance(result, xr.Dataset): - investment = result["capacity"].rename("investment") - if "production" in result: - cache_quantity(production=result["production"]) - else: - investment = result.rename("investment") - - cache_quantity(capacity=investment) - - return investment + # Check the output + assert set(result.data_vars) == {"capacity", "production"} + assert set(result.capacity.dims) == {"asset", "replacement"} + assert set(result.production.dims) == { + "asset", + "replacement", + "commodity", + "timeslice", + } + + # Add to cache + cache_quantity(production=result["production"]) + cache_quantity(capacity=result["capacity"]) + + # Return the result + return result return decorated @@ -158,7 +164,7 @@ def compute_investment( constraints=constraints, **params, **kwargs, - ).rename("investment") + ) return compute_investment