Adjust tf/fedopt_ctl to include updates for the model's non-trainable weights #3058

Merged
10 commits merged on Nov 25, 2024
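Background for the change, as a hedged sketch rather than project documentation: a Keras optimizer only updates the variables handed to apply_gradients, so non-trainable weights in the aggregated result (for example BatchNorm moving statistics) have to be carried into the global model separately. The toy snippet below illustrates this behavior; it is not NVFlare code and the model is purely illustrative.

```python
# Toy illustration (not NVFlare code): apply_gradients only touches the
# trainable variables it is given, so BatchNorm's moving_mean/moving_variance
# stay unchanged unless they are assigned from the aggregation result explicitly.
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(4),
    tf.keras.layers.BatchNormalization(),
])

print(len(model.trainable_weights))      # kernel, bias, gamma, beta -> 4
print(len(model.non_trainable_weights))  # moving_mean, moving_variance -> 2

optimizer = tf.keras.optimizers.SGD(learning_rate=1.0)
grads = [tf.ones_like(v) for v in model.trainable_weights]
optimizer.apply_gradients(zip(grads, model.trainable_weights))
# Only the four trainable weights changed; the two moving statistics did not.
```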
120 changes: 67 additions & 53 deletions nvflare/app_opt/tf/fedopt_ctl.py
@@ -18,7 +18,7 @@

import tensorflow as tf

from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.security.logging import secure_format_exception

@@ -106,55 +106,69 @@ def _to_tf_params_list(self, params: Dict, negate: bool = False):
tf_params_list.append(tf.Variable(v))
return tf_params_list

def update_model(self, global_model: FLModel, aggr_result: FLModel):
"""
Override the default version of update_model
to perform update with Keras Optimizer on the
global model stored in memory in persistor, instead of
creating new temporary model on-the-fly.

Creating a new model would not work for Keras
Optimizers, since an optimizer is bind to
specific set of Variables.

"""
# Get the Keras model stored in memory in persistor.
global_model_tf = self.persistor.model
global_params = global_model_tf.trainable_weights

# Compute model diff: need to use model diffs as
# gradients to be applied by the optimizer.
model_diff_params = {k: aggr_result.params[k] - global_model.params[k] for k in global_model.params}
model_diff = self._to_tf_params_list(model_diff_params, negate=True)

# Apply model diffs as gradients, using the optimizer.
start = time.time()
self.optimizer.apply_gradients(zip(model_diff, global_params))
secs = time.time() - start

# Convert updated global model weights to
# numpy format for FLModel.
start = time.time()
weights = global_model_tf.get_weights()
w_idx = 0
new_weights = {}
for key in global_model.params:
w = weights[w_idx]
while global_model.params[key].shape != w.shape:
w_idx += 1
w = weights[w_idx]
new_weights[key] = w
secs_detach = time.time() - start

self.info(
f"FedOpt ({type(self.optimizer)}) server model update "
f"round {self.current_round}, "
f"{type(self.lr_scheduler)} "
f"lr: {self.optimizer.learning_rate}, "
f"update: {secs} secs., detach: {secs_detach} secs.",
)

global_model.params = new_weights
global_model.meta = aggr_result.meta

return global_model
def update_model(self, global_model: FLModel, aggr_result: FLModel):
"""
Override the default version of update_model
to perform update with Keras Optimizer on the
global model stored in memory in persistor, instead of
creating new temporary model on-the-fly.

Creating a new model would not work for Keras
Optimizers, since an optimizer is bind to
specific set of Variables.

"""
# Get the Keras model stored in memory in the persistor.
global_model_tf = self.persistor.model
global_params = global_model_tf.trainable_weights
num_trainable_weights = len(global_params)

# Compute model diffs for the trainable weights, matched to
# trainable_weights by shape; the diffs are applied as
# gradients by the optimizer.
model_diff_params = {}
w_idx = 0
for key, param in global_model.params.items():
if w_idx >= num_trainable_weights:
break

if param.shape == global_params[w_idx].shape:
model_diff_params[key] = (
aggr_result.params[key] - param
if aggr_result.params_type == ParamsType.FULL
else aggr_result.params[key]
)
w_idx += 1

# Apply the model diffs (negated) as gradients, using the optimizer.
model_diff = self._to_tf_params_list(model_diff_params, negate=True)
start = time.time()
self.optimizer.apply_gradients(zip(model_diff, global_params))
secs = time.time() - start

# Convert the updated global model weights to numpy format for FLModel.
# Trainable weights come from the optimizer-updated Keras model; non-trainable
# weights are taken from the aggregation result (full value, or previous value plus diff).
start = time.time()
weights = global_model_tf.get_weights()
new_weights = {}
for w_idx, key in enumerate(global_model.params):
if key in model_diff_params:
new_weights[key] = weights[w_idx]
else:
new_weights[key] = (
aggr_result.params[key]
if aggr_result.params_type == ParamsType.FULL
else global_model.params[key] + aggr_result.params[key]
)
secs_detach = time.time() - start
self.info(
f"FedOpt ({type(self.optimizer)}) server model update "
f"round {self.current_round}, "
f"{type(self.lr_scheduler)} "
f"lr: {self.optimizer.learning_rate(self.optimizer.iterations).numpy()}, "
f"update: {secs} secs., detach: {secs_detach} secs.",
)

global_model.params = new_weights
global_model.meta = aggr_result.meta

return global_model
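To summarize the new flow outside the diff, here is a hedged, standalone sketch. The names fedopt_update and params_is_full are hypothetical, not NVFlare API; the params dicts are assumed to list trainable weights first in the same order as model.trainable_weights, whereas the real code above aligns them by shape.

```python
# Hypothetical, simplified re-statement of the server-side update above.
# Assumptions (not NVFlare API): global_params / aggr_params are dicts of
# numpy arrays, with the trainable weights listed first and in the same
# order as model.trainable_weights.
import tensorflow as tf


def fedopt_update(model, optimizer, global_params, aggr_params, params_is_full):
    trainable = model.trainable_weights
    trainable_keys = list(global_params)[: len(trainable)]

    # Trainable weights: treat the negated model diff as a pseudo-gradient,
    # so the optimizer moves the in-memory global model toward the aggregate.
    pseudo_grads = []
    for var, key in zip(trainable, trainable_keys):
        delta = aggr_params[key] - global_params[key] if params_is_full else aggr_params[key]
        pseudo_grads.append(tf.convert_to_tensor(-delta, dtype=var.dtype))
    optimizer.apply_gradients(zip(pseudo_grads, trainable))

    # Read back the optimizer-updated trainable values.
    new_params = {key: var.numpy() for var, key in zip(trainable, trainable_keys)}

    # Non-trainable weights (e.g. BatchNorm statistics) bypass the optimizer:
    # take the aggregated value directly, or add the received diff.
    for key in list(global_params)[len(trainable):]:
        new_params[key] = aggr_params[key] if params_is_full else global_params[key] + aggr_params[key]
    return new_params
```

The negation mirrors _to_tf_params_list(..., negate=True) in the file: a diff of aggregate minus global becomes a descent direction, so plain SGD with learning rate 1.0 reproduces FedAvg for the trainable weights, while other Keras optimizers give FedOpt-style server updates.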