fix/update diagnose_data_errors #62

Open · wants to merge 2 commits into main
68 changes: 36 additions & 32 deletions marl_eval/utils/diagnose_data_errors.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tools for verifying the json file formatting."""
"""Tools for verifying JSON file formatting."""

import copy
from typing import Any, Dict, List
@@ -25,7 +25,7 @@


class DiagnoseData:
"""Class to diagnose the errors."""
"""Class to diagnose errors in the JSON data."""

def __init__(self, raw_data: Dict[str, Dict[str, Any]]) -> None:
"""Initialise and make all dictionary strings lower case."""
@@ -34,7 +34,7 @@ def __init__(self, raw_data: Dict[str, Dict[str, Any]]) -> None:

def check_algo(self, list_algo: List) -> tuple:
"""Check that through the scenarios, the data share the same algorithms \
and that algorithm names are of the correct format."""
and that the algorithm names are of the correct format."""
if list_algo == []:
return True, True, [], []  # match the 4-tuple returned below
identical = True
@@ -47,8 +47,8 @@ def check_algo(self, list_algo: List) -> tuple:

if not identical:
print(
"The algorithms used across the different tasks are not the same\n\
The overlapping algorithms are :\n",
"The algorithms used across the different tasks are not the same.\n"
+ "The overlapping algorithms are:\n",
sorted(same_algos),
)

@@ -59,93 +59,97 @@ def check_algo(self, list_algo: List) -> tuple:

if not algo_names_valid:
print(
"Some algorithm names contain commas, which is not permitted."
"Some algorithm names contain commas, which is not permitted. "
+ f"Valid algorithm names are {valid_algo_names}."
)

return identical, algo_names_valid, same_algos, valid_algo_names

def check_metric(self, list_metric: List) -> tuple:
"""Check that through the steps, runs, algo and scenarios, the data share \
the same list of metrics"""
"""Check that through the steps, runs, algoirhtms and scenarios, \
the data share the same list of metrics"""
if list_metric == []:
return True, []
identical = True
same_metrics = sorted(list_metric[0])

if "step_count" in same_metrics:
same_metrics.remove("step_count")
if "elapsed_time" in same_metrics:
same_metrics.remove("elapsed_time")

for i in range(1, len(list_metric)):
if "step_count" in list_metric[i]:
list_metric[i].remove("step_count")
if "elapsed_time" in list_metric[i]:
list_metric[i].remove("elapsed_time")
if sorted(same_metrics) != sorted(list_metric[i]):
identical = False
same_metrics = list(set(same_metrics) & set(list_metric[i]))

if not identical:
print(
"The metrics used across the different steps, runs, algorithms\
and scenarios are not the same\n\
The overlapping metrics are :\n",
sorted(same_metrics),
"The metrics used across the different steps, runs, "
+ "algorithms and scenarios are not the same.\n"
+ f"The overlapping metrics are:\n{sorted(same_metrics)}"
)

return identical, same_metrics
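As an aside, the core of check_metric is repeated set intersection over the per-source metric lists. A minimal, self-contained sketch of that idea, using hypothetical metric lists that are not part of this repository, could look like:

```python
# Reduce a list of metric lists to the overlap they all share, mirroring
# the intersection performed in check_metric (hypothetical data).
from functools import reduce

metric_lists = [
    ["return", "win_rate", "step_count"],
    ["return", "win_rate", "elapsed_time"],
    ["return", "success_rate"],
]
# Strip the bookkeeping columns that check_metric also ignores.
cleaned = [
    [m for m in lst if m not in ("step_count", "elapsed_time")]
    for lst in metric_lists
]
identical = all(sorted(lst) == sorted(cleaned[0]) for lst in cleaned)
overlap = sorted(reduce(lambda a, b: set(a) & set(b), cleaned))
print(identical, overlap)  # -> False ['return']
```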

def check_runs(self, num_runs: List) -> tuple:
"""Check that through the algos, the data share the same num of run"""
"""Check that the data share the same number of runs through the algorithms."""
if num_runs == []:
return True, []

if num_runs.count(num_runs[0]) == len(num_runs):
return True, num_runs[0]

print(
"The number of runs is not identical through the different algorithms and "
"scenarios.\nThe minimum number of runs is " + str(min(num_runs)) + " runs."
"The number of runs is not identical through the different algorithms "
+ "and scenarios.\nThe minimum number of runs is "
+ str(min(num_runs))
+ " runs."
)
return False, min(num_runs)

def check_steps(self, num_steps: List) -> tuple:
"""Check that through the different runs, algo and scenarios, \
the data share the same number of steps"""
"""Check that through the different runs, algorithms and scenarios, \
the data share the same number of steps."""
if num_steps == []:
return True, []

if num_steps.count(num_steps[0]) == len(num_steps):
return True, num_steps[0]

print(
"The number of steps is not identical through the different runs, \
algorithms and scenarios.\n The minimum number of steps: "
"The number of steps is not identical through the different runs,"
+ "algorithms and scenarios.\nThe minimum number of steps is "
+ str(min(num_steps))
+ " steps."
)
return False, min(num_steps)

def data_format(self) -> Dict[str, Any]: # noqa: C901
"""Get the necessary details to figure if there is an issue with the json"""

def get_data_format(self) -> Dict[str, Any]: # noqa: C901
"""Get the necessary details from the JSON file to check for errors."""
processed_data = copy.deepcopy(self.raw_data)
data_used: Dict[str, Any] = {}

for env in self.raw_data.keys():
# List of algorithms used in the experiment across the tasks
# List of algorithms used in the experiment across the tasks.
algorithms_used = []
# List of num or runs used across the algos and the tasks
# List of num of runs used across the algos and tasks.
runs_used = []
# List of num of steps used across the runs, the algos and the tasks
# List of num of steps used across the runs, algos and tasks.
steps_used = []
# List of metrics used across the steps, the runs, the algos and the tasks
# List of metrics used across the steps, runs, algos and tasks.
metrics_used = []

for task in self.raw_data[env].keys():
# Append the list of used algorithms across the tasks
# Append the list of used algorithms across the tasks.
algorithms_used.append(sorted(list(processed_data[env][task].keys())))

for algorithm in self.raw_data[env][task].keys():
# Append the number of runs used across the different algos
# Append the number of runs used across the different algos.
runs_used.append(len(processed_data[env][task][algorithm].keys()))

for run in self.raw_data[env][task][algorithm].keys():
@@ -184,8 +188,8 @@ def data_format(self) -> Dict[str, Any]: # noqa: C901
return data_used

def check_data(self) -> Dict[str, Any]:
"""Check that the format don't issued any issue while using the tools"""
data_used = self.data_format()
"""Check that the data format won't throw errors while using marl-eval tools."""
data_used = self.get_data_format()
check_data_results: Dict[str, Any] = {}
for env in self.raw_data.keys():
valid_algo, valid_algo_names, _, _ = self.check_algo(
@@ -195,7 +199,7 @@ def check_data(self) -> Dict[str, Any]:
valid_steps, _ = self.check_steps(num_steps=data_used[env]["num_steps"])
valid_metrics, _ = self.check_metric(list_metric=data_used[env]["metrics"])

# Check that we have valid json file
# Check that we have a valid JSON file.
if (
valid_algo
and valid_runs
@@ -205,7 +209,7 @@
):
print("Valid format for the environment " + env + "!")
else:
print("invalid format for the environment " + env + "!")
print("Invalid format for the environment " + env + "!")
check_data_results[env] = {
"valid_algorithms": valid_algo,
"valid_algorithm_names": valid_algo_names,