diff --git a/areal/api/reward_api.py b/areal/api/reward_api.py index a1b4564618..6b63ff3925 100644 --- a/areal/api/reward_api.py +++ b/areal/api/reward_api.py @@ -56,7 +56,7 @@ def reward_fn( :param completion_ids: The token IDs of the trajectory generated by the model. :param kwargs: Other attributes of the data in the dataset, such as solutions, input_outputs, etc. Any other attributes in the dataset will be passed as keyword arguments to this function. - :rtype: float + :rtype: float | dict[str, float] """ @@ -135,7 +135,7 @@ def _recreate_executor(cls, executor_key, max_workers): return cls._executors[executor_key] return None - async def __call__(self, *args, **kwargs) -> float: + async def __call__(self, *args, **kwargs) -> float | dict[str, float]: last_exception = None for attempt in range(self.max_retries + 1): diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py index 63946842e0..3c340f9263 100644 --- a/areal/infra/remote_inf_engine.py +++ b/areal/infra/remote_inf_engine.py @@ -102,6 +102,12 @@ async def arun_episode( "trajectories returned None, using remaining results" ) + aggregate_group_results = getattr( + self.workflow, "aggregate_group_results", None + ) + if callable(aggregate_group_results): + return aggregate_group_results(valid_results) + # Check if results are InteractionWithTokenLogpReward dicts first = valid_results[0] if (