* bionemo_examples
* update readme and restructure folders
* update model paths
* add result figures
* add result files
* add license headers
* formatting
* add license
* add license
* update notebook
* update nvflare version
1 parent 81caf29 · commit 04e3511
Showing 193 changed files with 14,255 additions and 2 deletions.
@@ -0,0 +1,3 @@
# BioNeMo code
bionemo

@@ -0,0 +1,42 @@
# BioNeMo

[BioNeMo](https://www.nvidia.com/en-us/clara/bionemo/) is NVIDIA's generative AI platform for drug discovery.

This directory contains examples of running BioNeMo in a federated learning environment using [NVFlare](https://github.com/NVIDIA/NVFlare).

1. The [task_fitting](./task_fitting/README.md) example includes a notebook that shows how to obtain learned protein representations in the form of embeddings using the pre-trained ESM-1nv model. The model is trained with NVIDIA's BioNeMo framework for large language model training and inference.
2. The [downstream](./downstream/README.md) example shows three different downstream tasks for fine-tuning a BioNeMo ESM-style model.

## Requirements

Download and run the latest [BioNeMo docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/clara/containers/bionemo-framework).

We recommend following the [Quickstart Guide](https://docs.nvidia.com/bionemo-framework/latest/quickstart-fw.html#docker-container-access) on how to get the BioNeMo container.

First, copy the BioNeMo code to a local directory and configure the launch script so that downloaded models can be reused:
```commandline
CONTAINER="nvcr.io/nvidia/clara/bionemo-framework:latest"
DEST_PATH="."
CONTAINER_NAME=bionemo
docker run --name $CONTAINER_NAME -itd --rm $CONTAINER bash
docker cp $CONTAINER_NAME:/opt/nvidia/bionemo $DEST_PATH
docker kill $CONTAINER_NAME
```

Next, download the pre-trained models using
```commandline
cd ./bionemo
./launch.sh download
cd ..
```

Then, start the container and Jupyter Lab to run the NVFlare experiments using
```commandline
./start_bionemo.sh
```

**Note:** The examples here were tested with `nvcr.io/nvidia/clara/bionemo-framework:1.0`.

For information about how to get started with BioNeMo, refer to the [documentation](https://docs.nvidia.com/bionemo-framework/latest).

@@ -0,0 +1,107 @@
# Federated BioNeMo with NVFlare

## 1. Install requirements

Follow the instructions provided [here](../README.md#requirements) on how to start the BioNeMo container.

Inside the container, install NVFlare and PyTDC: `pip install nvflare~=2.4.0 PyTDC`

## 2. Run examples

The example datasets used here are made available by [Therapeutics Data Commons](https://tdcommons.ai/) through PyTDC.

### 2.1. Cross-endpoint multi-task fitting

#### Data: "Five computational developability guidelines for therapeutic antibody profiling"
See https://tdcommons.ai/single_pred_tasks/develop/#tap
- 241 antibodies (both chains)

#### Task Description: *Regression*
Given the antibody's heavy-chain and light-chain sequences, predict its developability. The input X is a list of two sequences, where the first is the heavy chain and the second the light chain.

The dataset includes five metrics measuring the developability of an antibody:
- Complementarity-determining regions (CDR) length - trivial (excluded)
- Patches of surface hydrophobicity (PSH)
- Patches of positive charge (PPC)
- Patches of negative charge (PNC)
- Structural Fv charge symmetry parameter (SFvCSP)
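
The preparation scripts pull these datasets through PyTDC; as a hedged sketch following TDC's public API (the choice of label and the printed columns are illustrative, not the script's exact code), loading TAP looks roughly like this:
```python
# Sketch: load the TAP developability data via PyTDC (API per TDC docs).
from tdc.single_pred import Develop
from tdc.utils import retrieve_label_name_list

label_list = retrieve_label_name_list("TAP")  # the TAP developability metrics
data = Develop(name="TAP", label_name=label_list[0])

split = data.get_split()  # dict of 'train' / 'valid' / 'test' DataFrames
print(split["train"].head())
```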

#### Download and prepare the data
```commandline
python prepare_tap_data.py
```
In the data preparation script, one can choose between uniform sampling of the data among clients and heterogeneous data splits using a Dirichlet sampling strategy, where different values of alpha control the level of heterogeneity (a minimal sketch of the splitting idea follows the figures below). The figures show a Dirichlet sampling with `alpha=1.0`.

| Uniform sampling | Dirichlet sampling |
|:---:|:---:|
| <img src="./tap/figs/tap_uniform.svg" alt="Uniform data sampling" width="150"/> | <img src="./tap/figs/tap_alpha1.0.svg" alt="Dirichlet sampling (alpha=1.0)" width="150"/> |
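
The exact split logic lives in `prepare_tap_data.py` and is not shown in this diff; as a minimal sketch of the idea (the function name and the share-per-client variant are assumptions, not the script's code), a Dirichlet split can be implemented as:
```python
# Sketch: split sample indices among clients with Dirichlet-distributed shares.
# Smaller alpha -> more skewed (heterogeneous) client shares.
import numpy as np

def dirichlet_split(num_samples, num_clients, alpha, seed=0):
    rng = np.random.default_rng(seed)
    shares = rng.dirichlet([alpha] * num_clients)    # client fractions, sum to 1
    counts = (shares * num_samples).astype(int)
    counts[-1] = num_samples - counts[:-1].sum()     # assign remainder to last client
    indices = rng.permutation(num_samples)
    return np.split(indices, np.cumsum(counts)[:-1]) # one index array per client

print([len(s) for s in dirichlet_split(num_samples=241, num_clients=4, alpha=1.0)])
```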

#### Run training (central, local, & FL)
```commandline
python run_sim_tap.py
```
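
`run_sim_tap.py` itself is not part of this excerpt; one plausible shape for such a script, using the `SimulatorRunner` API from NVFlare 2.4 (the job folder, workspace, and client count below are assumptions), is:
```python
# Sketch: run an NVFlare job in the FL simulator (nvflare ~2.4).
from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner

runner = SimulatorRunner(
    job_folder="jobs/fedavg_tap_esm1nv",   # hypothetical job folder name
    workspace="/tmp/nvflare/bionemo/tap",  # simulator workspace
    n_clients=4,                           # number of simulated clients
    threads=4,
)
runner.run()
```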

#### Results with uniform data sampling
<img src="./tap/figs/tap_uniform_results.svg" alt="Results on TAP with uniform sampling" width="300"/>

#### Results with heterogeneous data sampling
<img src="./tap/figs/tap_alpha1.0_results.svg" alt="Results on TAP with heterogeneous sampling" width="300"/>

### 2.2. Cross-compound task fitting

#### Data: "Predicting Antibody Developability from Sequence using Machine Learning"
See https://tdcommons.ai/single_pred_tasks/develop/#sabdab-chen-et-al
- 2,409 antibodies (both chains)

#### Task Description: *Binary classification*
Given the antibody's heavy-chain and light-chain sequences, predict its developability. The input X is a list of two sequences, where the first is the heavy chain and the second the light chain.
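
This dataset is likewise available through PyTDC; a hedged sketch of loading it, per TDC's public API:
```python
# Sketch: load the SAbDab (Chen et al.) developability data via PyTDC.
from tdc.single_pred import Develop

data = Develop(name="SAbDab_Chen")
split = data.get_split()  # 'train' / 'valid' / 'test' DataFrames
print(len(split["train"]), "training antibodies")
```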

#### Download and prepare the data
```commandline
python prepare_sabdab_data.py
```
Again, we use the Dirichlet sampling strategy to generate heterogeneous data distributions among clients; lower values of `alpha` generate higher levels of heterogeneity.

| Alpha 10.0 | Alpha 1.0 |
|:---:|:---:|
| <img src="./sabdab/figs/sabdab_alpha10.0.svg" alt="Dirichlet sampling (alpha=10.0)" width="150"/> | <img src="./sabdab/figs/sabdab_alpha1.0.svg" alt="Dirichlet sampling (alpha=1.0)" width="150"/> |

#### Run training (central, local, & FL)
```commandline
python run_sim_sabdab.py
```

#### Results with heterogeneous data sampling (alpha=10.0)
| Setting | Accuracy  |
|:-------:|:---------:|
| Local   | 0.821     |
| FL      | **0.833** |

#### Results with heterogeneous data sampling (alpha=1.0)
| Setting | Accuracy  |
|:-------:|:---------:|
| Local   | 0.813     |
| FL      | **0.835** |

### 2.3 Subcellular location prediction with ESM2nv 650M
Follow the data download and preparation in [task_fitting.ipynb](../task_fitting/task_fitting.ipynb).

Here, we use heterogeneous sampling with `alpha=1.0`.

<img src="./scl/figs/scl_alpha1.0.svg" alt="Dirichlet sampling (alpha=1.0)" width="300"/>

#### Run training (local & FL)
```commandline
python run_sim_scl.py
```

#### Results with heterogeneous data sampling (alpha=1.0)
| Setting | Accuracy  |
|:-------:|:---------:|
| Local   | 0.773     |
| FL      | **0.776** |

<img src="./scl/figs/scl_results.svg" alt="Results on SCL with heterogeneous sampling (alpha=1.0)" width="300"/>

1 change: 1 addition & 0 deletions
examples/advanced/bionemo/downstream/sabdab/figs/sabdab_alpha1.0.svg

1 change: 1 addition & 0 deletions
examples/advanced/bionemo/downstream/sabdab/figs/sabdab_alpha10.0.svg

94 changes: 94 additions & 0 deletions
...ed/bionemo/downstream/sabdab/jobs/central_sabdab_esm1nv/app/config/config_fed_client.conf
@@ -0,0 +1,94 @@
{
  # version of the configuration
  format_version = 2

  # This is the application script that will be invoked. Clients can replace this script with their own training script.
  app_script = "downstream_flip.py"

  # Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx.
  app_config = ""

  # Additional arguments needed by DDP.
  #ddp_config = "--nnodes=1 --nproc_per_node=1 --master_port=7777"

  # Client Computing Executors.
  executors = [
    {
      # tasks the executors are defined to handle
      tasks = ["train"]

      # This particular executor
      executor {

        # This is an executor for PyTorch + Client API. The underlying data exchange uses a Pipe.
        path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"

        args {
          # launcher_id is used to locate the Launcher object in "components"
          launcher_id = "launcher"

          # pipe_id is used to locate the Pipe object in "components"
          pipe_id = "pipe"

          # Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds.
          # Please refer to the class docstring for all available arguments.
          heartbeat_timeout = 60

          # format of the exchanged parameters
          params_exchange_format = "pytorch"

          # if the transfer_type is FULL, the parameters are sent directly;
          # if the transfer_type is DIFF, the difference vs. the received parameters is calculated and sent
          params_transfer_type = "FULL"

          # if train_with_evaluation is true, the executor expects the custom code
          # to send back both the trained parameters and the evaluation metric;
          # otherwise only trained parameters are expected
          train_with_evaluation = false
        }
      }
    }
  ],

  # this defines an array of task data filters. If provided, it will control the data from server controller to client executor
  task_data_filters = []

  # this defines an array of task result filters. If provided, it will control the result from client executor to server controller
  task_result_filters = []

  components = [
    {
      # component id is "launcher"
      id = "launcher"

      # the class path of this component
      path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"

      args {
        # the launcher will invoke the script
        #script = "python3 -m torch.distributed.run {ddp_config} custom/{app_script} {app_config} "
        script = "python3 custom/{app_script} {app_config} "
        # if launch_once is true, the SubprocessLauncher will launch once for the whole job;
        # if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from the server
        launch_once = true
      }
    }
    {
      id = "pipe"

      path = "nvflare.fuel.utils.pipe.file_pipe.FilePipe"

      args {
        # Mode of the endpoint. A pipe has two endpoints.
        # An endpoint can be either the one that initiates communication or the one listening.
        # PASSIVE is the one listening.
        mode = "PASSIVE"

        # root_path is the directory location of the parameter exchange.
        # You can also set it to an absolute path in your system.
        root_path = "{WORKSPACE}/{JOB_ID}/{SITE_NAME}"
      }
    }
  ]
}
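
This client config launches `downstream_flip.py` via the `PTClientAPILauncherExecutor` and exchanges PyTorch weights over a `FilePipe`. The training script itself is not shown in this excerpt; a minimal sketch of the Client API loop such a script follows (the "training" step is a placeholder, and since `train_with_evaluation = false` no metrics are sent back):
```python
# Sketch: NVFlare (~2.4) Client API pattern expected by the launcher executor.
import torch
import nvflare.client as flare

def main():
    flare.init()  # connect to the pipe set up by the launcher
    while flare.is_running():
        input_model = flare.receive()  # FLModel carrying the current global weights
        print(f"current_round={input_model.current_round}")

        # Placeholder: the real script fine-tunes the BioNeMo ESM model on
        # the local antibody data here.
        local_params = {k: torch.as_tensor(v) for k, v in input_model.params.items()}

        # params_transfer_type = "FULL": send the full local weights back.
        flare.send(flare.FLModel(params=local_params))

if __name__ == "__main__":
    main()
```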

110 changes: 110 additions & 0 deletions
...ed/bionemo/downstream/sabdab/jobs/central_sabdab_esm1nv/app/config/config_fed_server.conf
@@ -0,0 +1,110 @@
{
  # version of the configuration
  format_version = 2

  # task data filter: if filters are provided, they will filter the data flowing from server to client.
  task_data_filters = []

  # task result filter: if filters are provided, they will filter the results flowing from client to server.
  task_result_filters = []

  # This assumes that there will be a "net.py" file with class name "Net".
  # If your model code is not in "net.py" and the class name is not "Net", please modify here.
  #model_class_path = "nemo_nvflare.peft_model.PEFTmodel"

  # Location of the pre-trained NeMo model file.
  #restore_from_path = "/models/megatron_gpt_345m.nemo"

  # Location of the pre-trained PEFT model file.
  #peft_restore_from_path = null

  # workflows: Array of workflows that control the federated learning workflow lifecycle.
  # One can specify multiple workflows. NVFlare will run them in the order specified.
  workflows = [
    {
      # 1st workflow
      id = "scatter_and_gather"

      # path is the class path of the ScatterAndGather controller.
      path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather"
      args {
        # arguments of the ScatterAndGather class.
        # min number of clients required for the ScatterAndGather controller to move to the next round
        # during the workflow cycle. The controller will wait until min_clients have returned results
        # before moving to the next step.
        min_clients = 1

        # number of global rounds of training.
        num_rounds = 1

        # starting round is 0-based
        start_round = 0

        # after receiving the min number of clients' results,
        # how much longer to wait before moving to the next step
        wait_time_after_min_received = 0

        # For ScatterAndGather, the server aggregates the weights based on the clients' results.
        # The aggregator component id is named here. One can use this ID to find the corresponding
        # aggregator component listed below.
        aggregator_id = "aggregator"

        # The ScatterAndGather controller uses a persistor to load and save the model.
        # The persistor component can be identified by the component ID specified here.
        #persistor_id = "persistor"

        # A Shareable is a communication message, i.e. shared between clients and server.
        # The shareable generator is the component responsible for converting the model to/from this communication message.
        # The component can be identified via "shareable_generator_id".
        shareable_generator_id = "shareable_generator"

        # train task name: clients will start training once they receive this task.
        train_task_name = "train"

        # train timeout in seconds. If zero, there is no timeout.
        train_timeout = 0
      }
    }
  ]

  # List of components used in the server-side workflow.
  components = [
    #{
      # This is the persistence component used in the above workflow.
      # PTFileModelPersistor is a PyTorch persistor that saves/reads the model to/from file.

      # id = "persistor"
      # path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor"

      # the persistor class takes the model class as an argument.
      # This implies that the model is initialized on the server side.
      # The initialized model will be broadcast to all the clients to start the training.
      # args.model.path = "{model_class_path}"
      # args.model.args.restore_from_path = "{restore_from_path}"
      # args.model.args.peft_restore_from_path = "{peft_restore_from_path}"
    #},
    {
      # This is the generator that converts the model to the shareable communication message structure used in the workflow.
      id = "shareable_generator"
      path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator"
      args = {}
    },
    {
      # This is the aggregator that performs the weighted average aggregation.
      # The aggregation is "in-time": it doesn't wait for all client results, but aggregates as soon as data is received.
      id = "aggregator"
      path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator"
      args.expected_data_kind = "WEIGHTS"
    },
    {
      # This component is not directly used in the workflow.
      # It selects the best model based on the incoming global validation metrics.
      id = "model_selector"
      path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector"
      # need to make sure this "key_metric" matches what the server side receives
      args.key_metric = "validation_exact_string_match"
    }
  ]
}
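
For intuition, the weighted averaging that `InTimeAccumulateWeightedAggregator` performs is conceptually plain FedAvg; a self-contained sketch of the math (not NVFlare's actual implementation, which accumulates each contribution as it arrives):
```python
# Sketch: weighted average of client weight dicts, as in FedAvg.
import numpy as np

def weighted_average(client_params, client_weights):
    """Average parameter dicts, weighting each client by its contribution count."""
    total = sum(client_weights)
    return {
        name: sum(w * p[name] for p, w in zip(client_params, client_weights)) / total
        for name in client_params[0]
    }

a = {"layer.weight": np.ones(3)}   # client 1, e.g. 100 training steps
b = {"layer.weight": np.zeros(3)}  # client 2, e.g. 300 training steps
print(weighted_average([a, b], [100, 300]))  # -> 0.25 everywhere
```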