Commit
Fix tb receiver (#2349)
Co-authored-by: Chester Chen <[email protected]>
YuanTingHsieh and chesterxgchen authored Feb 2, 2024
1 parent 899f11e commit 2f9d00d
Showing 2 changed files with 47 additions and 18 deletions.
nvflare/apis/analytix.py (6 changes: 1 addition & 5 deletions)
@@ -183,11 +183,7 @@ def convert_data_type(
         return sender_data_type

     if sender == LogWriterName.MLFLOW and receiver == LogWriterName.TORCH_TB:
-        if AnalyticsDataType.PARAMETER == sender_data_type:
-            return AnalyticsDataType.SCALAR
-        elif AnalyticsDataType.PARAMETERS == sender_data_type:
-            return AnalyticsDataType.SCALARS
-        elif AnalyticsDataType.METRIC == sender_data_type:
+        if AnalyticsDataType.METRIC == sender_data_type:
             return AnalyticsDataType.SCALAR
         elif AnalyticsDataType.METRICS == sender_data_type:
             return AnalyticsDataType.SCALARS
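The net effect of this hunk: when data streamed through the MLflow writer reaches a TensorBoard receiver, only METRIC/METRICS are still translated to SCALAR/SCALARS at the converter level; PARAMETER/PARAMETERS no longer get a blanket conversion and are instead unpacked by the receiver (second file below). A minimal sketch of the surviving branch, not the real function; it assumes AnalyticsDataType is importable from nvflare.apis.analytix and that types not listed here pass through unchanged (the rest of the function is not shown in this hunk):

from nvflare.apis.analytix import AnalyticsDataType  # assumed import path

def mlflow_to_tb_type(sender_data_type):
    # sketch of the MLFLOW -> TORCH_TB branch after this commit, not the real function
    if sender_data_type == AnalyticsDataType.METRIC:
        return AnalyticsDataType.SCALAR
    if sender_data_type == AnalyticsDataType.METRICS:
        return AnalyticsDataType.SCALARS
    # PARAMETER / PARAMETERS (and any other type) are assumed to pass through
    # unchanged so the TB receiver can decide scalar-vs-text per value
    return sender_data_type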
nvflare/app_opt/tracking/tb/tb_receiver.py (59 changes: 46 additions & 13 deletions)
@@ -28,13 +28,22 @@
     AnalyticsDataType.TEXT: "add_text",
     AnalyticsDataType.IMAGE: "add_image",
     AnalyticsDataType.SCALARS: "add_scalars",
-    AnalyticsDataType.PARAMETER: "add_scalar",
-    AnalyticsDataType.PARAMETERS: "add_scalars",
     AnalyticsDataType.METRIC: "add_scalar",
     AnalyticsDataType.METRICS: "add_scalars",
 }


+def _create_new_data(key, value, sender):
+    if isinstance(value, (int, float)):
+        data_type = AnalyticsDataType.SCALAR
+    elif isinstance(value, str):
+        data_type = AnalyticsDataType.TEXT
+    else:
+        return None
+
+    return AnalyticsData(key=key, value=value, data_type=data_type, sender=sender)
+
+
 class TBAnalyticsReceiver(AnalyticsReceiver):
     def __init__(self, tb_folder="tb_events", events: Optional[List[str]] = None):
         """Receives analytics data to save to TensorBoard.
Expand Down Expand Up @@ -71,6 +80,27 @@ def initialize(self, fl_ctx: FLContext):
os.makedirs(root_log_dir, exist_ok=True)
self.root_log_dir = root_log_dir

def _convert_to_records(self, analytic_data: AnalyticsData, fl_ctx: FLContext) -> List[AnalyticsData]:
# break dict of stuff to smaller items to support
# AnalyticsDataType.PARAMETER and AnalyticsDataType.PARAMETERS
records = []

if analytic_data.data_type in (AnalyticsDataType.PARAMETER, AnalyticsDataType.PARAMETERS):
for k, v in (
analytic_data.value.items()
if analytic_data.data_type == AnalyticsDataType.PARAMETERS
else [(analytic_data.tag, analytic_data.value)]
):
new_data = _create_new_data(k, v, analytic_data.sender)
if new_data is None:
self.log_warning(fl_ctx, f"Entry {k} of type {type(v)} is not supported.", fire_event=False)
else:
records.append(new_data)
else:
records.append(analytic_data)

return records

def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin):
dxo = from_shareable(shareable)
analytic_data = AnalyticsData.from_dxo(dxo)
@@ -86,19 +116,22 @@ def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin):
         # do different things depending on the type in dxo
         self.log_debug(
             fl_ctx,
-            f"save data {analytic_data} from {record_origin}",
+            f"try to save data {analytic_data} from {record_origin}",
             fire_event=False,
         )
-        func_name = FUNCTION_MAPPING.get(analytic_data.data_type, None)
-        if func_name is None:
-            self.log_error(fl_ctx, f"The data_type {analytic_data.data_type} is not supported.", fire_event=False)
-            return
-
-        func = getattr(writer, func_name)
-        if analytic_data.step:
-            func(analytic_data.tag, analytic_data.value, analytic_data.step)
-        else:
-            func(analytic_data.tag, analytic_data.value)
+        data_records = self._convert_to_records(analytic_data, fl_ctx)
+
+        for data_record in data_records:
+            func_name = FUNCTION_MAPPING.get(data_record.data_type, None)
+            if func_name is None:
+                self.log_warning(fl_ctx, f"The data_type {data_record.data_type} is not supported.", fire_event=False)
+                return
+
+            func = getattr(writer, func_name)
+            if data_record.step:
+                func(data_record.tag, data_record.value, data_record.step)
+            else:
+                func(data_record.tag, data_record.value)

     def finalize(self, fl_ctx: FLContext):
         for writer in self.writers_table.values():
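For a concrete picture of what the reworked save() path writes, here is a minimal standalone sketch (plain PyTorch outside NVFlare; the log directory and sample values are illustrative) of the fan-out that _convert_to_records enables: each entry of an MLflow-style params dict becomes its own add_scalar or add_text call instead of a single unsupported PARAMETERS record.

from torch.utils.tensorboard import SummaryWriter

params = {"learning_rate": 0.01, "batch_size": 32, "optimizer": "adam"}  # sample MLflow-style params

writer = SummaryWriter(log_dir="tb_events_demo")  # illustrative directory
for key, value in params.items():
    if isinstance(value, (int, float)):
        writer.add_scalar(key, value)  # numeric params -> scalars, mirroring _create_new_data
    elif isinstance(value, str):
        writer.add_text(key, value)  # string params -> text; add_scalar would reject them
    # any other type would be skipped with a warning by the receiver
writer.close()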
