Skip to content

Commit

Permalink
Address a few misc. bugs (#2252)
Browse files Browse the repository at this point in the history
* address a few misc. bugs
1) The step-by-step example train_with_mlflow.py should use global_step when logging metrics
2) fl_model_utils.update_model() loses metrics after an update

* update fl_model_utils.py

* address comments

* reorder imports

* remove logger for now
  • Loading branch information
chesterxgchen authored Jan 9, 2024
1 parent 1d2bf4d commit b0ca8eb
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def evaluate(input_weights):
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print(f"({client_id}) [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
mlflow.log_metric("loss", running_loss / 2000, i)
global_step = input_model.current_round * local_epochs * batch_size + epoch * batch_size + i
mlflow.log_metric("loss", running_loss / 2000, global_step)
running_loss = 0.0

print(f"({client_id}) Finished Training")
Expand Down
20 changes: 13 additions & 7 deletions nvflare/app_common/utils/fl_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ def to_shareable(fl_model: FLModel) -> Shareable:
raise ValueError("FLModel without params and metrics is NOT supported.")
elif fl_model.params is not None:
if fl_model.params_type is None:
raise ValueError(f"Invalid ParamsType: ({fl_model.params_type}).")
data_kind = params_type_to_data_kind.get(fl_model.params_type)
fl_model.params_type = ParamsType.FULL

data_kind = params_type_to_data_kind.get(fl_model.params_type.value)
if data_kind is None:
raise ValueError(f"Invalid ParamsType: ({fl_model.params_type}).")

Expand Down Expand Up @@ -103,11 +104,15 @@ def from_shareable(shareable: Shareable, fl_ctx: Optional[FLContext] = None) ->
metrics = dxo.data
else:
params_type = data_kind_to_params_type.get(dxo.data_kind)
params = dxo.data
if params_type is None:
raise ValueError(f"Invalid shareable with dxo that has data kind: {dxo.data_kind}")
if params is None:
raise ValueError(f"Invalid shareable with dxo that has data kind: {dxo.data_kind}")
else:
params_type = ParamsType.FULL

params_type = ParamsType(params_type)

params = dxo.data
if MetaKey.INITIAL_METRICS in meta:
metrics = meta[MetaKey.INITIAL_METRICS]
except:
Expand Down Expand Up @@ -197,14 +202,15 @@ def get_configs(model: FLModel) -> Optional[dict]:
@staticmethod
def update_model(model: FLModel, model_update: FLModel, replace_meta: bool = True) -> FLModel:
if model.params_type != ParamsType.FULL:
raise RuntimeError(
f"params_type {model_update.params_type} of `model` not supported! Expected `ParamsType.FULL`."
)
raise RuntimeError(f"params_type {model.params_type} of `model` not supported! Expected `ParamsType.FULL`.")

if replace_meta:
model.meta = model_update.meta
else:
model.meta.update(model_update.meta)

model.metrics = model_update.metrics

if model_update.params_type == ParamsType.FULL:
model.params = model_update.params
elif model_update.params_type == ParamsType.DIFF:
Expand Down

0 comments on commit b0ca8eb

Please sign in to comment.