Skip to content

Commit

Permalink
change Fed Stats output format (#2199)
Browse files Browse the repository at this point in the history
* 1. Change the Fed Stats output format: the output no longer has to be reshaped to fit the visualization needs
2. Change the visualization to reformat the dictionary to fit the visualization needs

* formatting
  • Loading branch information
chesterxgchen authored Dec 8, 2023
1 parent dd7e312 commit 9de8f31
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 15 deletions.
20 changes: 18 additions & 2 deletions examples/hello-world/step-by-step/higgs/stats/tabular_stats.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,9 @@
"cell_type": "code",
"execution_count": null,
"id": "4af4d563-ec0e-4d03-a6ae-008d9ba62171",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!nvflare job create -w stats_df -force -j /tmp/nvflare/jobs/stats_df \\-sd code \\\n",
Expand Down Expand Up @@ -485,7 +487,9 @@
"cell_type": "code",
"execution_count": null,
"id": "3fb4ffde-627c-4f72-b114-fb9e92c1dc6f",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!nvflare simulator /tmp/nvflare/jobs/stats_df -w /tmp/nvflare/tabular/stats_df -n 3 -t 3"
Expand Down Expand Up @@ -516,6 +520,18 @@
"!ls -al /tmp/nvflare/tabular/stats_df/simulate_job/statistics/\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48c955f1-d002-4d5a-a408-6c9c78ad9854",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!cat /tmp/nvflare/tabular/stats_df/simulate_job/statistics/stats.json"
]
},
{
"cell_type": "markdown",
"id": "653f83f8-f96f-4943-af27-c5e6551d3449",
Expand Down
22 changes: 16 additions & 6 deletions nvflare/app_common/workflows/statistics_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,34 +384,44 @@ def _combine_all_statistics(self):
for statistic in filtered_client_statistics:
for client in self.client_statistics[statistic]:
for ds in self.client_statistics[statistic][client]:
client_dataset = f"{client}-{ds}"
for feature_name in self.client_statistics[statistic][client][ds]:
if feature_name not in result:
result[feature_name] = {}
if statistic not in result[feature_name]:
result[feature_name][statistic] = {}

if client not in result[feature_name][statistic]:
result[feature_name][statistic][client] = {}

if ds not in result[feature_name][statistic][client]:
result[feature_name][statistic][client][ds] = {}

if statistic == StC.STATS_HISTOGRAM:
hist: Histogram = self.client_statistics[statistic][client][ds][feature_name]
buckets = StatisticsController._apply_histogram_precision(hist.bins, self.precision)
result[feature_name][statistic][client_dataset] = buckets
result[feature_name][statistic][client][ds] = buckets
else:
result[feature_name][statistic][client_dataset] = round(
result[feature_name][statistic][client][ds] = round(
self.client_statistics[statistic][client][ds][feature_name], self.precision
)

precision = self.precision
for statistic in filtered_global_statistics:
for ds in self.global_statistics[statistic]:
global_dataset = f"{StC.GLOBAL}-{ds}"
for feature_name in self.global_statistics[statistic][ds]:
if StC.GLOBAL not in result[feature_name][statistic]:
result[feature_name][statistic][StC.GLOBAL] = {}

if ds not in result[feature_name][statistic][StC.GLOBAL]:
result[feature_name][statistic][StC.GLOBAL][ds] = {}

if statistic == StC.STATS_HISTOGRAM:
hist: Histogram = self.global_statistics[statistic][ds][feature_name]
buckets = StatisticsController._apply_histogram_precision(hist.bins, self.precision)
result[feature_name][statistic][global_dataset] = buckets
result[feature_name][statistic][StC.GLOBAL][ds] = buckets
else:
result[feature_name][statistic].update(
{global_dataset: round(self.global_statistics[statistic][ds][feature_name], precision)}
{StC.GLOBAL: {ds: round(self.global_statistics[statistic][ds][feature_name], precision)}}
)

return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@
from nvflare.fuel.utils.import_utils import optional_import


def convert_data(feature_metrics: dict) -> dict:
    """Flatten the site/dataset levels of one feature's metrics into a single key.

    The input is indexed as ``feature_metrics[statistic][site][dataset]``.
    The result is indexed as ``converted[statistic]["{site}-{dataset}"]`` and
    holds the same leaf values (they are not copied).

    Args:
        feature_metrics: mapping of statistic name -> site -> dataset -> value.

    Returns:
        A new dict mapping statistic name -> "<site>-<dataset>" -> value.
        A statistic with no sites maps to an empty dict.
    """
    converted = {}
    for statistic in feature_metrics:
        per_site = feature_metrics[statistic]
        # Collapse the two nested levels (site, dataset) into one combined key.
        converted[statistic] = {
            f"{site}-{ds}": per_site[site][ds]
            for site in per_site
            for ds in per_site[site]
        }
    return converted


class Visualization:
def import_modules(self):
display, import_flag = optional_import(module="IPython.display", name="display")
Expand All @@ -27,17 +38,23 @@ def import_modules(self):
print(pd.failure)
return display, pd

def show_stats(self, data, white_list_features=[]):
def show_stats(self, data, white_list_features=None):
if white_list_features is None:
white_list_features = []

display, pd = self.import_modules()
all_features = [k for k in data]
target_features = self._get_target_features(all_features, white_list_features)
for feature in target_features:
print(f"\n{feature}\n")
feature_metrics = data[feature]
df = pd.DataFrame.from_dict(feature_metrics)
converted = convert_data(feature_metrics)
df = pd.DataFrame.from_dict(converted)
display(df)

def show_histograms(self, data, display_format="sample_count", white_list_features=[], plot_type="both"):
def show_histograms(self, data, display_format="sample_count", white_list_features=None, plot_type="both"):
if white_list_features is None:
white_list_features = []
feature_dfs = self.get_histogram_dataframes(data, display_format, white_list_features)
self.show_dataframe_plots(feature_dfs, plot_type)

Expand All @@ -54,7 +71,9 @@ def show_dataframe_plots(self, feature_dfs, plot_type="both"):
else:
print(f"not supported plot type: '{plot_type}'")

def get_histogram_dataframes(self, data, display_format="sample_count", white_list_features=[]) -> Dict:
def get_histogram_dataframes(self, data, display_format="sample_count", white_list_features=None) -> Dict:
if white_list_features is None:
white_list_features = []
display, pd = self.import_modules()
(hists, edges) = self._prepare_histogram_data(data, display_format, white_list_features)
all_features = [k for k in edges]
Expand All @@ -69,15 +88,18 @@ def get_histogram_dataframes(self, data, display_format="sample_count", white_li

return feature_dfs

def _prepare_histogram_data(self, data, display_format="sample_count", white_list_features=[]):
def _prepare_histogram_data(self, data, display_format="sample_count", white_list_features=None):
if white_list_features is None:
white_list_features = []
all_features = [k for k in data]
target_features = self._get_target_features(all_features, white_list_features)

feature_hists = {}
feature_edges = {}

for feature in target_features:
xs = data[feature]["histogram"]
converted = convert_data(data[feature])
xs = converted["histogram"]
hists = {}
feature_edges[feature] = []
for i, ds in enumerate(xs):
Expand All @@ -103,7 +125,10 @@ def sum_counts_in_histogram(self, hist):
sum_value += bucket[2]
return sum_value

def _get_target_features(self, all_features, white_list_features=[]):
def _get_target_features(self, all_features, white_list_features=None):
if white_list_features is None:
white_list_features = []

target_features = white_list_features
if not white_list_features:
target_features = all_features
Expand Down

0 comments on commit 9de8f31

Please sign in to comment.