From ca7b69ce4d24c94071d24973d3427672853c69c4 Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Mon, 27 Jan 2025 21:08:11 -0800 Subject: [PATCH] Address comments --- .../hello-pt/src/hello-pt_cifar10_fl.py | 96 +++++----- .../hello-pt/src/hello-pt_cifar10_fl_v2.py | 99 ++++++++++ nvflare/client/__init__.py | 4 +- nvflare/client/api.py | 177 ++++++++---------- nvflare/client/api_context.py | 59 ++++++ nvflare/client/api_spec.py | 13 ++ nvflare/client/ex_process/api.py | 4 + nvflare/client/in_process/api.py | 4 + nvflare/client/task_registry.py | 2 +- nvflare/client/tracking.py | 43 ++++- 10 files changed, 341 insertions(+), 160 deletions(-) create mode 100644 examples/hello-world/hello-pt/src/hello-pt_cifar10_fl_v2.py create mode 100644 nvflare/client/api_context.py diff --git a/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py b/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py index 905b80a9a0..2d32a3c783 100644 --- a/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py +++ b/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py @@ -22,7 +22,7 @@ from torchvision.datasets import CIFAR10 from torchvision.transforms import Compose, Normalize, ToTensor -from nvflare.client import FlareClientContext, FLModel +import nvflare.client as flare from nvflare.client.tracking import SummaryWriter DATASET_PATH = "/tmp/nvflare/data" @@ -30,7 +30,7 @@ def main(): batch_size = 4 - epochs = 5 + epochs = 2 lr = 0.01 model = SimpleNetwork() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -43,55 +43,53 @@ def main(): ] ) - with FlareClientContext() as flare: - sys_info = flare.system_info() - client_name = sys_info["site_name"] + flare.init() + sys_info = flare.system_info() + client_name = sys_info["site_name"] - train_dataset = CIFAR10( - root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True + train_dataset = CIFAR10( + root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True + ) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + + summary_writer = SummaryWriter() + while flare.is_running(): + input_model = flare.receive() + print(f"current_round={input_model.current_round}") + + model.load_state_dict(input_model.params) + model.to(device) + + steps = epochs * len(train_loader) + for epoch in range(epochs): + running_loss = 0.0 + for i, batch in enumerate(train_loader): + images, labels = batch[0].to(device), batch[1].to(device) + optimizer.zero_grad() + + predictions = model(images) + cost = loss(predictions, labels) + cost.backward() + optimizer.step() + + running_loss += cost.cpu().detach().numpy() / images.size()[0] + if i % 3000 == 0: + print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}") + global_step = input_model.current_round * steps + epoch * len(train_loader) + i + summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step) + running_loss = 0.0 + + print("Finished Training") + + PATH = "./cifar_net.pth" + torch.save(model.state_dict(), PATH) + + output_model = flare.FLModel( + params=model.cpu().state_dict(), + meta={"NUM_STEPS_CURRENT_ROUND": steps}, ) - train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) - - summary_writer = SummaryWriter() - while flare.is_running(): - input_model = flare.receive() - print(f"current_round={input_model.current_round}") - - model.load_state_dict(input_model.params) - model.to(device) - - steps = epochs * 
len(train_loader) - for epoch in range(epochs): - running_loss = 0.0 - for i, batch in enumerate(train_loader): - images, labels = batch[0].to(device), batch[1].to(device) - optimizer.zero_grad() - - predictions = model(images) - cost = loss(predictions, labels) - cost.backward() - optimizer.step() - - running_loss += cost.cpu().detach().numpy() / images.size()[0] - if i % 3000 == 0: - print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}") - global_step = input_model.current_round * steps + epoch * len(train_loader) + i - summary_writer.add_scalar( - tag="loss_for_each_batch", scalar=running_loss, global_step=global_step - ) - running_loss = 0.0 - - print("Finished Training") - - PATH = "./cifar_net.pth" - torch.save(model.state_dict(), PATH) - - output_model = FLModel( - params=model.cpu().state_dict(), - meta={"NUM_STEPS_CURRENT_ROUND": steps}, - ) - - flare.send(output_model) + + flare.send(output_model) if __name__ == "__main__": diff --git a/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl_v2.py b/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl_v2.py new file mode 100644 index 0000000000..11ba7b9737 --- /dev/null +++ b/examples/hello-world/hello-pt/src/hello-pt_cifar10_fl_v2.py @@ -0,0 +1,99 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import torch +from simple_network import SimpleNetwork +from torch import nn +from torch.optim import SGD +from torch.utils.data.dataloader import DataLoader +from torchvision.datasets import CIFAR10 +from torchvision.transforms import Compose, Normalize, ToTensor + +import nvflare.client as flare +from nvflare.client import FLModel +from nvflare.client.tracking import SummaryWriter + +DATASET_PATH = "/tmp/nvflare/data" + + +def main(): + batch_size = 4 + epochs = 2 + lr = 0.01 + model = SimpleNetwork() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + loss = nn.CrossEntropyLoss() + optimizer = SGD(model.parameters(), lr=lr, momentum=0.9) + transforms = Compose( + [ + ToTensor(), + Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + + with flare.init() as ctx: + sys_info = flare.system_info(ctx=ctx) + client_name = sys_info["site_name"] + + train_dataset = CIFAR10( + root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True + ) + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + + summary_writer = SummaryWriter(ctx=ctx) + while flare.is_running(ctx=ctx): + input_model = flare.receive(ctx=ctx) + print(f"current_round={input_model.current_round}") + + model.load_state_dict(input_model.params) + model.to(device) + + steps = epochs * len(train_loader) + for epoch in range(epochs): + running_loss = 0.0 + for i, batch in enumerate(train_loader): + images, labels = batch[0].to(device), batch[1].to(device) + optimizer.zero_grad() + + predictions = model(images) + cost = loss(predictions, labels) + cost.backward() + optimizer.step() + + running_loss += cost.cpu().detach().numpy() / images.size()[0] + if i % 3000 == 0: + print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}") + global_step = input_model.current_round * steps + epoch * len(train_loader) + i + summary_writer.add_scalar( + tag="loss_for_each_batch", scalar=running_loss, global_step=global_step + ) + running_loss = 0.0 + + print("Finished Training") + + PATH = "./cifar_net.pth" + torch.save(model.state_dict(), PATH) + + output_model = FLModel( + params=model.cpu().state_dict(), + meta={"NUM_STEPS_CURRENT_ROUND": steps}, + ) + + flare.send(output_model, ctx=ctx) + + +if __name__ == "__main__": + main() diff --git a/nvflare/client/__init__.py b/nvflare/client/__init__.py index fb8ac01ca3..50411ce3c9 100644 --- a/nvflare/client/__init__.py +++ b/nvflare/client/__init__.py @@ -19,7 +19,6 @@ from nvflare.app_common.abstract.fl_model import FLModel as FLModel from nvflare.app_common.abstract.fl_model import ParamsType as ParamsType -from .api import FlareClientContext as FlareClientContext from .api import get_config as get_config from .api import get_job_id as get_job_id from .api import get_site_name as get_site_name @@ -31,7 +30,8 @@ from .api import log as log from .api import receive as receive from .api import send as send +from .api import shutdown as shutdown from .api import system_info as system_info from .decorator import evaluate as evaluate from .decorator import train as train -from .ipc.ipc_agent import IPCAgent as IPCAgent +from .ipc.ipc_agent import IPCAgent diff --git a/nvflare/client/api.py b/nvflare/client/api.py index fc65705fc2..9b7c2a9028 100644 --- a/nvflare/client/api.py +++ b/nvflare/client/api.py @@ -13,97 +13,55 @@ # limitations under the License. 
import logging -import os -from enum import Enum from typing import Any, Dict, Optional from nvflare.apis.analytix import AnalyticsDataType from nvflare.app_common.abstract.fl_model import FLModel -from nvflare.client.constants import CLIENT_API_CONFIG -from nvflare.fuel.data_event.data_bus import DataBus -from .api_spec import CLIENT_API_KEY, CLIENT_API_TYPE_KEY, APISpec -from .ex_process.api import ExProcessClientAPI +from .api_context import APIContext +context_dict = {} +default_context = None -class ClientAPIType(Enum): - IN_PROCESS_API = "IN_PROCESS_API" - EX_PROCESS_API = "EX_PROCESS_API" - -DEFAULT_CONFIG = f"config/{CLIENT_API_CONFIG}" - -client_api: Optional[APISpec] = None -data_bus = DataBus() - - -class FlareClientContext: - def __init__(self, rank: Optional[str] = None, config_file: str = None): - self.rank = rank - self.config_file = config_file if config_file else DEFAULT_CONFIG - self._client_api = None - - def __enter__(self): - """Initialize the client API in the context.""" - api_type_name = os.environ.get(CLIENT_API_TYPE_KEY, ClientAPIType.IN_PROCESS_API.value) - api_type = ClientAPIType(api_type_name) - - if not self._client_api: - global client_api - client_api = self._create_client_api(api_type) - client_api.init(rank=self.rank) - self._client_api = client_api - - return self._client_api - - def __exit__(self, exc_type, exc_val, exc_tb): - """Cleanup the client API when the context ends.""" - if self._client_api: - self._client_api.clear() - self._client_api = None - - def _create_client_api(self, api_type: ClientAPIType) -> APISpec: - """Creates a new client_api based on the provided API type.""" - if api_type == ClientAPIType.IN_PROCESS_API: - return data_bus.get_data(CLIENT_API_KEY) - else: - return ExProcessClientAPI(config_file=self.config_file) - - -def init(rank: Optional[str] = None): +def init(rank: Optional[str] = None, config_file: Optional[str] = None) -> APIContext: """Initializes NVFlare Client API environment. Args: rank (str): local rank of the process. It is only useful when the training script has multiple worker processes. (for example multi GPU) + config_file (str): client api configuration. Returns: - None + APIContext """ - api_type_name = os.environ.get(CLIENT_API_TYPE_KEY, ClientAPIType.IN_PROCESS_API.value) - api_type = ClientAPIType(api_type_name) - global client_api - if client_api is None: - if api_type == ClientAPIType.IN_PROCESS_API: - client_api = data_bus.get_data(CLIENT_API_KEY) - else: - client_api = ExProcessClientAPI(config_file=DEFAULT_CONFIG) - client_api.init(rank=rank) + global context_dict + global default_context + local_ctx = context_dict.get((rank, config_file)) + + if local_ctx is None: + local_ctx = APIContext(rank=rank, config_file=config_file) + context_dict[(rank, config_file)] = local_ctx + default_context = local_ctx else: - logging.warning("Warning: called init() more than once. The subsequence calls are ignored") + logging.warning( + "Warning: called init() more than once with same parameters." "The subsequence calls are ignored" + ) + return local_ctx -def receive(timeout: Optional[float] = None) -> Optional[FLModel]: +def receive(timeout: Optional[float] = None, ctx: Optional[APIContext] = None) -> Optional[FLModel]: """Receives model from NVFlare side. Returns: An FLModel received. 
""" - global client_api - return client_api.receive(timeout) + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.receive(timeout) -def send(model: FLModel, clear_cache: bool = True) -> None: +def send(model: FLModel, clear_cache: bool = True, ctx: Optional[APIContext] = None) -> None: """Sends the model to NVFlare side. Args: @@ -112,11 +70,12 @@ def send(model: FLModel, clear_cache: bool = True) -> None: """ if not isinstance(model, FLModel): raise TypeError("model needs to be an instance of FLModel") - global client_api - return client_api.send(model, clear_cache) + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.send(model, clear_cache) -def system_info() -> Dict: +def system_info(ctx: Optional[APIContext] = None) -> Dict: """Gets NVFlare system information. System information will be available after a valid FLModel is received. @@ -129,91 +88,100 @@ def system_info() -> Dict: A dict of system information. """ - global client_api - return client_api.system_info() + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.system_info() -def get_config() -> Dict: +def get_config(ctx: Optional[APIContext] = None) -> Dict: """Gets the ClientConfig dictionary. Returns: A dict of the configuration used in Client API. """ - global client_api - return client_api.get_config() + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.get_config() -def get_job_id() -> str: +def get_job_id(ctx: Optional[APIContext] = None) -> str: """Gets job id. Returns: The current job id. """ - global client_api - return client_api.get_job_id() + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.get_job_id() -def get_site_name() -> str: +def get_site_name(ctx: Optional[APIContext] = None) -> str: """Gets site name. Returns: The site name of this client. """ - global client_api - return client_api.get_site_name() + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.get_site_name() -def get_task_name() -> str: +def get_task_name(ctx: Optional[APIContext] = None) -> str: """Gets task name. Returns: The task name. """ - global client_api - return client_api.get_task_name() + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.get_task_name() -def is_running() -> bool: +def is_running(ctx: Optional[APIContext] = None) -> bool: """Returns whether the NVFlare system is up and running. Returns: True, if the system is up and running. False, otherwise. """ - global client_api - return client_api.is_running() + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.is_running() -def is_train() -> bool: +def is_train(ctx: Optional[APIContext] = None) -> bool: """Returns whether the current task is a training task. Returns: True, if the current task is a training task. False, otherwise. """ - global client_api - return client_api.is_train() + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.is_train() -def is_evaluate() -> bool: +def is_evaluate(ctx: Optional[APIContext] = None) -> bool: """Returns whether the current task is an evaluate task. Returns: True, if the current task is an evaluate task. False, otherwise. 
""" - global client_api - return client_api.is_evaluate() + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.is_evaluate() -def is_submit_model() -> bool: +def is_submit_model(ctx: Optional[APIContext] = None) -> bool: """Returns whether the current task is a submit_model task. Returns: True, if the current task is a submit_model. False, otherwise. """ - global client_api - return client_api.is_submit_model() + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.is_submit_model() -def log(key: str, value: Any, data_type: AnalyticsDataType, **kwargs): +def log(key: str, value: Any, data_type: AnalyticsDataType, ctx: Optional[APIContext] = None, **kwargs): """Logs a key value pair. We suggest users use the high-level APIs in nvflare/client/tracking.py @@ -227,11 +195,20 @@ def log(key: str, value: Any, data_type: AnalyticsDataType, **kwargs): Returns: whether the key value pair is logged successfully """ - global client_api - return client_api.log(key, value, data_type, **kwargs) + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.log(key, value, data_type, **kwargs) -def clear(): +def clear(ctx: Optional[APIContext] = None): """Clears the cache.""" - global client_api - return client_api.clear() + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.clear() + + +def shutdown(ctx: Optional[APIContext] = None): + """Releases all threads and resources used by the API and stops operation.""" + global default_context + local_ctx = ctx if ctx else default_context + return local_ctx.api.shutdown() diff --git a/nvflare/client/api_context.py b/nvflare/client/api_context.py new file mode 100644 index 0000000000..d576b322d7 --- /dev/null +++ b/nvflare/client/api_context.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import os +from enum import Enum +from typing import Optional + +from nvflare.client.constants import CLIENT_API_CONFIG +from nvflare.fuel.data_event.data_bus import DataBus + +from .api_spec import CLIENT_API_KEY, CLIENT_API_TYPE_KEY, APISpec +from .ex_process.api import ExProcessClientAPI + +DEFAULT_CONFIG = f"config/{CLIENT_API_CONFIG}" +data_bus = DataBus() + + +class ClientAPIType(Enum): + IN_PROCESS_API = "IN_PROCESS_API" + EX_PROCESS_API = "EX_PROCESS_API" + + +class APIContext: + def __init__(self, rank: Optional[str] = None, config_file: str = None): + self.rank = rank + self.config_file = config_file if config_file else DEFAULT_CONFIG + + api_type_name = os.environ.get(CLIENT_API_TYPE_KEY, ClientAPIType.IN_PROCESS_API.value) + api_type = ClientAPIType(api_type_name) + self.api = self._create_client_api(api_type) + self.api.init(rank=self.rank) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Cleanup the client API when the context ends.""" + if self.api: + self.api.shutdown() + self.api = None + + def _create_client_api(self, api_type: ClientAPIType) -> APISpec: + """Creates a new client_api based on the provided API type.""" + if api_type == ClientAPIType.IN_PROCESS_API: + return data_bus.get_data(CLIENT_API_KEY) + else: + return ExProcessClientAPI(config_file=self.config_file) diff --git a/nvflare/client/api_spec.py b/nvflare/client/api_spec.py index d2a5802852..458976638e 100644 --- a/nvflare/client/api_spec.py +++ b/nvflare/client/api_spec.py @@ -277,3 +277,16 @@ def clear(self): """ pass + + @abstractmethod + def shutdown(self): + """Releases all threads and resources used by the API and stops operation. + + Example: + + .. code-block:: python + + nvflare.client.shutdown() + + """ + pass diff --git a/nvflare/client/ex_process/api.py b/nvflare/client/ex_process/api.py index f3eeaa8140..e9b1801287 100644 --- a/nvflare/client/ex_process/api.py +++ b/nvflare/client/ex_process/api.py @@ -205,3 +205,7 @@ def clear(self): model_registry = self.get_model_registry() model_registry.clear() self.receive_called = False + + def shutdown(self): + model_registry = self.get_model_registry() + model_registry.shutdown() diff --git a/nvflare/client/in_process/api.py b/nvflare/client/in_process/api.py index 83d9c4cdb1..0014d2a662 100644 --- a/nvflare/client/in_process/api.py +++ b/nvflare/client/in_process/api.py @@ -234,3 +234,7 @@ def __continue_job(self) -> bool: return False return True + + def shutdown(self): + self.stop = True + self.stop_reason = "API shutdown called." 
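Note on the shutdown() additions above: with shutdown() now part of APISpec and implemented by both the in-process and ex-process APIs, a training script can release the agent/executor threads deterministically instead of relying on TaskRegistry.__del__ at interpreter exit (removed in the task_registry.py change below). A minimal sketch of the intended call pattern, illustrative only, with the training body elided:

    import nvflare.client as flare

    flare.init()
    try:
        while flare.is_running():
            input_model = flare.receive()
            # ... local training on input_model.params ...
            flare.send(input_model)
    finally:
        # In-process this sets the executor's stop flag; ex-process it
        # stops the flare agent through the model/task registries.
        flare.shutdown()
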
diff --git a/nvflare/client/task_registry.py b/nvflare/client/task_registry.py index 182356a340..bcfeefdaca 100644 --- a/nvflare/client/task_registry.py +++ b/nvflare/client/task_registry.py @@ -114,6 +114,6 @@ def clear(self) -> None: def __str__(self): return f"{self.__class__.__name__}(config: {self.config.get_config()})" - def __del__(self): + def shutdown(self): if self.flare_agent: self.flare_agent.stop() diff --git a/nvflare/client/tracking.py b/nvflare/client/tracking.py index e2ba7e66b7..51e506d0d1 100644 --- a/nvflare/client/tracking.py +++ b/nvflare/client/tracking.py @@ -17,10 +17,19 @@ from nvflare.apis.analytix import AnalyticsDataType from nvflare.app_common.tracking.tracker_types import LogWriterName +# flake8: noqa +from .api import default_context as default_context from .api import log +from .api_context import APIContext -class SummaryWriter: +class _BaseWriter: + def __init__(self, ctx: Optional[APIContext] = None): + global default_context + self.ctx = ctx if ctx else default_context + + +class SummaryWriter(_BaseWriter): """SummaryWriter mimics the usage of Tensorboard's SummaryWriter. Users can replace the import of Tensorboard's SummaryWriter with FLARE's SummaryWriter. @@ -43,6 +52,7 @@ def add_scalar(self, tag: str, scalar: float, global_step: Optional[int] = None, data_type=AnalyticsDataType.SCALAR, global_step=global_step, writer=LogWriterName.TORCH_TB, + ctx=self.ctx, **kwargs, ) @@ -61,6 +71,7 @@ def add_scalars(self, tag: str, scalars: dict, global_step: Optional[int] = None data_type=AnalyticsDataType.SCALARS, global_step=global_step, writer=LogWriterName.TORCH_TB, + ctx=self.ctx, **kwargs, ) @@ -69,7 +80,7 @@ def flush(self): pass -class WandBWriter: +class WandBWriter(_BaseWriter): """WandBWriter mimics the usage of weights and biases. Users can replace the import of wandb with FLARE's WandBWriter. @@ -90,10 +101,11 @@ def log(self, metrics: Dict[str, float], step: Optional[int] = None): data_type=AnalyticsDataType.METRICS, global_step=step, writer=LogWriterName.WANDB, + ctx=self.ctx, ) -class MLflowWriter: +class MLflowWriter(_BaseWriter): """MLflowWriter mimics the usage of MLflow. Users can replace the import of MLflow with FLARE's MLflowWriter. @@ -113,7 +125,7 @@ def log_param(self, key: str, value: any) -> None: All backend stores support values up to length 500, but some may support larger values. """ - log(key=key, value=value, data_type=AnalyticsDataType.PARAMETER, writer=LogWriterName.MLFLOW) + log(key=key, value=value, data_type=AnalyticsDataType.PARAMETER, writer=LogWriterName.MLFLOW, ctx=self.ctx) def log_params(self, values: dict) -> None: """Log a batch of params for the current run. @@ -121,7 +133,13 @@ def log_params(self, values: dict) -> None: Args: values (dict): Dictionary of param_name: String -> value: (String, but will be string-ified if not) """ - log(key="params", value=values, data_type=AnalyticsDataType.PARAMETERS, writer=LogWriterName.MLFLOW) + log( + key="params", + value=values, + data_type=AnalyticsDataType.PARAMETERS, + writer=LogWriterName.MLFLOW, + ctx=self.ctx, + ) def log_metric(self, key: str, value: float, step: Optional[int] = None) -> None: """Log a metric under the current run. @@ -136,7 +154,14 @@ def log_metric(self, key: str, value: float, step: Optional[int] = None) -> None support larger values. step (int, optional): Metric step. Defaults to zero if unspecified. 
""" - log(key=key, value=value, data_type=AnalyticsDataType.METRIC, global_step=step, writer=LogWriterName.MLFLOW) + log( + key=key, + value=value, + data_type=AnalyticsDataType.METRIC, + global_step=step, + writer=LogWriterName.MLFLOW, + ctx=self.ctx, + ) def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: """Log multiple metrics for the current run. @@ -154,6 +179,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> data_type=AnalyticsDataType.METRICS, global_step=step, writer=LogWriterName.MLFLOW, + ctx=self.ctx, ) def log_text(self, text: str, artifact_file_path: str) -> None: @@ -170,6 +196,7 @@ def log_text(self, text: str, artifact_file_path: str) -> None: data_type=AnalyticsDataType.TEXT, path=artifact_file_path, writer=LogWriterName.MLFLOW, + ctx=self.ctx, ) def set_tag(self, key: str, tag: any) -> None: @@ -181,7 +208,7 @@ def set_tag(self, key: str, tag: any) -> None: All backend stores will support values up to length 5000, but some may support larger values. """ - log(key=key, value=tag, data_type=AnalyticsDataType.TAG, writer=LogWriterName.MLFLOW) + log(key=key, value=tag, data_type=AnalyticsDataType.TAG, writer=LogWriterName.MLFLOW, ctx=self.ctx) def set_tags(self, tags: dict) -> None: """Log a batch of tags for the current run. @@ -190,4 +217,4 @@ def set_tags(self, tags: dict) -> None: tags (dict): Dictionary of tag_name: String -> value: (String, but will be string-ified if not) """ - log(key="tags", value=tags, data_type=AnalyticsDataType.TAGS, writer=LogWriterName.MLFLOW) + log(key="tags", value=tags, data_type=AnalyticsDataType.TAGS, writer=LogWriterName.MLFLOW, ctx=self.ctx)