Adjust tf/fedopt_ctl to include updates for the model's non-trainable #3058

Merged · 10 commits · Nov 25, 2024

nvflare/app_opt/tf/fedopt_ctl.py: 32 additions & 11 deletions
@@ -18,7 +18,7 @@

 import tensorflow as tf

-from nvflare.app_common.abstract.fl_model import FLModel
+from nvflare.app_common.abstract.fl_model import FLModel, ParamsType
 from nvflare.app_common.workflows.fedavg import FedAvg
 from nvflare.security.logging import secure_format_exception
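
Note: ParamsType tells the workflow whether aggr_result.params holds full
model weights (ParamsType.FULL) or only deltas (ParamsType.DIFF); the updated
logic below branches on this flag. A minimal sketch of that convention, with
aggr_value and global_value as hypothetical stand-ins for one weight array:

from nvflare.app_common.abstract.fl_model import ParamsType

def to_update(aggr_value, global_value, params_type):
    # FULL: clients sent absolute weights, so subtract the current global
    # weights to recover the update; DIFF: the value already is an update.
    if params_type == ParamsType.FULL:
        return aggr_value - global_value
    return aggr_value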

@@ -121,36 +121,57 @@ def update_model(self, global_model: FLModel, aggr_result: FLModel):
         # Get the Keras model stored in memory in persistor.
         global_model_tf = self.persistor.model
         global_params = global_model_tf.trainable_weights
+        num_trainable_weights = len(global_params)

         # Compute model diff: need to use model diffs as
         # gradients to be applied by the optimizer.
-        model_diff_params = {k: aggr_result.params[k] - global_model.params[k] for k in global_model.params}
+        model_diff_params = {}
+
+        w_idx = 0
+
+        for key, param in global_model.params.items():
+            if w_idx >= num_trainable_weights:
+                break
+
+            if param.shape == global_params[w_idx].shape:
+                model_diff_params[key] = (
+                    aggr_result.params[key] - param
+                    if aggr_result.params_type == ParamsType.FULL
+                    else aggr_result.params[key]
+                )
+                w_idx += 1
+
         model_diff = self._to_tf_params_list(model_diff_params, negate=True)

         # Apply model diffs as gradients, using the optimizer.
         start = time.time()
-
         self.optimizer.apply_gradients(zip(model_diff, global_params))
         secs = time.time() - start

         # Convert updated global model weights to
         # numpy format for FLModel.
         start = time.time()
         weights = global_model_tf.get_weights()
-        w_idx = 0
+
         new_weights = {}
-        for key in global_model.params:
-            w = weights[w_idx]
-            while global_model.params[key].shape != w.shape:
-                w_idx += 1
-                w = weights[w_idx]
-            new_weights[key] = w
-        secs_detach = time.time() - start
+        for w_idx, key in enumerate(global_model.params):
+            if key in model_diff_params:
+                new_weights[key] = weights[w_idx]
+
+            else:
+
+                new_weights[key] = (
+                    aggr_result.params[key]
+                    if aggr_result.params_type == ParamsType.FULL
+                    else global_model.params[key] + aggr_result.params[key]
+                )
+        secs_detach = time.time() - start
         self.info(
             f"FedOpt ({type(self.optimizer)}) server model update "
             f"round {self.current_round}, "
             f"{type(self.lr_scheduler)} "
-            f"lr: {self.optimizer.learning_rate}, "
+            f"lr: {self.optimizer.learning_rate(self.optimizer.iterations).numpy()}, "
             f"update: {secs} secs., detach: {secs_detach} secs.",
         )
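
Note: the first added loop aligns the incoming params dict with the model's
trainable_weights by walking both in order and comparing shapes, so only
trainable tensors land in model_diff_params. A standalone sketch of that walk
(the toy model and key names are illustrative, not from this PR):

import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(4), tf.keras.layers.BatchNormalization()]
)
model.build(input_shape=(None, 8))

# All weights in model order, including BatchNormalization's non-trainable
# moving_mean / moving_variance.
params = {f"w{i}": w.numpy() for i, w in enumerate(model.weights)}

trainable = model.trainable_weights
matched, w_idx = {}, 0
for key, value in params.items():
    if w_idx >= len(trainable):
        break
    if value.shape == trainable[w_idx].shape:
        matched[key] = value  # lines up with the next trainable weight
        w_idx += 1

# matched now holds kernel, bias, gamma, and beta; the walk stops before the
# moving statistics, which is exactly what keeps them out of the optimizer.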

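Note: the server-side FedOpt step itself is unchanged in spirit: the negated
aggregated update is handed to a Keras optimizer as if it were a gradient. A
self-contained illustration with plain SGD and toy values:

import tensorflow as tf

var = tf.Variable([1.0, 2.0])     # stand-in for one global trainable weight
delta = tf.constant([0.5, -0.5])  # aggregated client update for that weight

opt = tf.keras.optimizers.SGD(learning_rate=1.0)
# SGD applies var -= lr * grad, so passing -delta moves var along delta.
opt.apply_gradients([(-delta, var)])
print(var.numpy())  # [1.5 1.5]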
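
Note: the logging change at the bottom assumes the optimizer was constructed
with a LearningRateSchedule, in which case optimizer.learning_rate is the
schedule object and the current value must be evaluated at the iteration
counter, as the PR's new log line does. A sketch under that assumption
(schedule and numbers are illustrative):

import tensorflow as tf

schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.1, decay_steps=100, decay_rate=0.9
)
opt = tf.keras.optimizers.SGD(learning_rate=schedule)

# Printing opt.learning_rate directly would show the schedule object;
# evaluating it at the current step yields the scalar the log line reports.
current_lr = opt.learning_rate(opt.iterations).numpy()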