diff --git a/marl_eval/utils/diagnose_data_errors.py b/marl_eval/utils/diagnose_data_errors.py
index 7f03fd6f..05a6a5b3 100644
--- a/marl_eval/utils/diagnose_data_errors.py
+++ b/marl_eval/utils/diagnose_data_errors.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Tools for verifying the json file formatting."""
+"""Tools for verifying JSON file formatting."""
 
 import copy
 from typing import Any, Dict, List
@@ -25,7 +25,7 @@
 
 
 class DiagnoseData:
-    """Class to diagnose the errors."""
+    """Class to diagnose errors in the JSON data."""
 
     def __init__(self, raw_data: Dict[str, Dict[str, Any]]) -> None:
         """Initialise and make all dictionary strings lower case."""
@@ -34,7 +34,7 @@ def __init__(self, raw_data: Dict[str, Dict[str, Any]]) -> None:
 
     def check_algo(self, list_algo: List) -> tuple:
         """Check that through the scenarios, the data share the same algorithms \
-            and that algorithm names are of the correct format."""
+            and that the algorithm names are of the correct format."""
         if list_algo == []:
             return True, []
         identical = True
@@ -47,8 +47,8 @@ def check_algo(self, list_algo: List) -> tuple:
 
         if not identical:
             print(
-                "The algorithms used across the different tasks are not the same\n\
-                The overlapping algorithms are :\n",
+                "The algorithms used across the different tasks are not the same.\n"
+                + "The overlapping algorithms are:\n",
                 sorted(same_algos),
             )
 
@@ -59,15 +59,15 @@ def check_algo(self, list_algo: List) -> tuple:
 
         if not algo_names_valid:
             print(
-                "Some algorithm names contain commas, which is not permitted."
+                "Some algorithm names contain commas, which is not permitted. "
                 + f"Valid algorithm names are {valid_algo_names}."
             )
 
         return identical, algo_names_valid, same_algos, valid_algo_names
 
     def check_metric(self, list_metric: List) -> tuple:
-        """Check that through the steps, runs, algo and scenarios, the data share \
-            the same list of metrics"""
+        """Check that through the steps, runs, algorithms and scenarios, \
+            the data share the same list of metrics."""
         if list_metric == []:
             return True, []
         identical = True
@@ -75,26 +75,29 @@ def check_metric(self, list_metric: List) -> tuple:
 
         if "step_count" in same_metrics:
             same_metrics.remove("step_count")
+        if "elapsed_time" in same_metrics:
+            same_metrics.remove("elapsed_time")
 
         for i in range(1, len(list_metric)):
             if "step_count" in list_metric[i]:
                 list_metric[i].remove("step_count")
+            if "elapsed_time" in list_metric[i]:
+                list_metric[i].remove("elapsed_time")
 
             if sorted(same_metrics) != sorted(list_metric[i]):
                 identical = False
                 same_metrics = list(set(same_metrics) & set(list_metric[i]))
 
         if not identical:
             print(
-                "The metrics used across the different steps, runs, algorithms\
-                and scenarios are not the same\n\
-                The overlapping metrics are :\n",
-                sorted(same_metrics),
+                "The metrics used across the different steps, runs, "
+                + "algorithms and scenarios are not the same.\n"
+                + f"The overlapping metrics are:\n{sorted(same_metrics)}"
             )
 
         return identical, same_metrics
 
     def check_runs(self, num_runs: List) -> tuple:
-        """Check that through the algos, the data share the same num of run"""
+        """Check that the data share the same number of runs through the algorithms."""
         if num_runs == []:
             return True, []
 
@@ -102,14 +105,16 @@ def check_runs(self, num_runs: List) -> tuple:
             return True, num_runs[0]
 
         print(
-            "The number of runs is not identical through the different algorithms and "
-            "scenarios.\nThe minimum number of runs is " + str(min(num_runs)) + " runs."
+            "The number of runs is not identical through the different algorithms "
+            + "and scenarios.\nThe minimum number of runs is "
+            + str(min(num_runs))
+            + " runs."
         )
         return False, min(num_runs)
 
     def check_steps(self, num_steps: List) -> tuple:
-        """Check that through the different runs, algo and scenarios, \
-            the data share the same number of steps"""
+        """Check that through the different runs, algorithms and scenarios, \
+            the data share the same number of steps."""
         if num_steps == []:
             return True, []
 
@@ -117,35 +122,34 @@ def check_steps(self, num_steps: List) -> tuple:
             return True, num_steps[0]
 
         print(
-            "The number of steps is not identical through the different runs, \
-            algorithms and scenarios.\n The minimum number of steps: "
+            "The number of steps is not identical through the different runs, "
+            + "algorithms and scenarios.\nThe minimum number of steps is "
             + str(min(num_steps))
             + " steps."
         )
         return False, min(num_steps)
 
-    def data_format(self) -> Dict[str, Any]:  # noqa: C901
-        """Get the necessary details to figure if there is an issue with the json"""
-
+    def get_data_format(self) -> Dict[str, Any]:  # noqa: C901
+        """Get the necessary details from the JSON file to check for errors."""
         processed_data = copy.deepcopy(self.raw_data)
         data_used: Dict[str, Any] = {}
 
         for env in self.raw_data.keys():
-            # List of algorithms used in the experiment across the tasks
+            # List of algorithms used in the experiment across the tasks.
            algorithms_used = []
-            # List of num or runs used across the algos and the tasks
+            # List of num of runs used across the algos and tasks.
            runs_used = []
-            # List of num of steps used across the runs, the algos and the tasks
+            # List of num of steps used across the runs, algos and tasks.
            steps_used = []
-            # List of metrics used across the steps, the runs, the algos and the tasks
+            # List of metrics used across the steps, runs, algos and tasks.
            metrics_used = []
 
            for task in self.raw_data[env].keys():
-                # Append the list of used algorithms across the tasks
+                # Append the list of used algorithms across the tasks.
                algorithms_used.append(sorted(list(processed_data[env][task].keys())))
 
                for algorithm in self.raw_data[env][task].keys():
-                    # Append the number of runs used across the different algos
+                    # Append the number of runs used across the different algos.
                    runs_used.append(len(processed_data[env][task][algorithm].keys()))
 
                    for run in self.raw_data[env][task][algorithm].keys():
@@ -184,8 +188,8 @@ def data_format(self) -> Dict[str, Any]:  # noqa: C901
         return data_used
 
     def check_data(self) -> Dict[str, Any]:
-        """Check that the format don't issued any issue while using the tools"""
-        data_used = self.data_format()
+        """Check that the data format won't throw errors while using marl-eval tools."""
+        data_used = self.get_data_format()
         check_data_results: Dict[str, Any] = {}
         for env in self.raw_data.keys():
             valid_algo, valid_algo_names, _, _ = self.check_algo(
@@ -195,7 +199,7 @@ def check_data(self) -> Dict[str, Any]:
             valid_steps, _ = self.check_steps(num_steps=data_used[env]["num_steps"])
             valid_metrics, _ = self.check_metric(list_metric=data_used[env]["metrics"])
 
-            # Check that we have valid json file
+            # Check that we have a valid JSON file.
             if (
                 valid_algo
                 and valid_runs
@@ -205,7 +209,7 @@ def check_data(self) -> Dict[str, Any]:
             ):
                 print("Valid format for the environment " + env + "!")
             else:
-                print("invalid format for the environment " + env + "!")
+                print("Invalid format for the environment " + env + "!")
             check_data_results[env] = {
                 "valid_algorithms": valid_algo,
                 "valid_algorithm_names": valid_algo_names,
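
A minimal usage sketch of the renamed API, for reviewers; it is not part of the patch. It assumes results are stored in a hypothetical `metrics.json` with the environment -> task -> algorithm -> run nesting that the methods above iterate over; the file name and data are made up for illustration, while the class and method names come from this change.

```python
# Sketch only: diagnose a raw results file with the renamed API.
# "metrics.json" and its contents are hypothetical.
import json

from marl_eval.utils.diagnose_data_errors import DiagnoseData

with open("metrics.json", "r") as f:
    raw_data = json.load(f)

diagnose = DiagnoseData(raw_data=raw_data)

# Formerly data_format(); returns the per-environment summary
# (algorithms used, number of runs, number of steps, metric names)
# that the check_* methods consume.
data_used = diagnose.get_data_format()

# Prints a verdict per environment and returns flags such as
# "valid_algorithms" and "valid_algorithm_names" for each one.
results = diagnose.check_data()
print(results)
```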