diff --git a/nvflare/lighter/utils.py b/nvflare/lighter/utils.py index ae71e409fa..8cee919cca 100644 --- a/nvflare/lighter/utils.py +++ b/nvflare/lighter/utils.py @@ -40,7 +40,8 @@ def serialize_cert(cert): def load_crt(path): - return load_crt_bytes(open(path, "rb").read()) + with open(path, "rb") as f: + return load_crt_bytes(f.read()) def load_crt_bytes(data: bytes): @@ -116,17 +117,19 @@ def sign_folders(folder, signing_pri_key, crt_path, max_depth=9999): for file in files: if file == NVFLARE_SIG_FILE or file == NVFLARE_SUBMITTER_CRT_FILE: continue - signatures[file] = sign_content( - content=open(os.path.join(root, file), "rb").read(), - signing_pri_key=signing_pri_key, - ) + with open(os.path.join(root, file), "rb") as f: + signatures[file] = sign_content( + content=f.read(), + signing_pri_key=signing_pri_key, + ) for folder in folders: signatures[folder] = sign_content( content=folder, signing_pri_key=signing_pri_key, ) - json.dump(signatures, open(os.path.join(root, NVFLARE_SIG_FILE), "wt")) + with open(os.path.join(root, NVFLARE_SIG_FILE), "wt") as f: + json.dump(signatures, f) shutil.copyfile(crt_path, os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE)) if depth >= max_depth: break @@ -138,7 +141,8 @@ def verify_folder_signature(src_folder, root_ca_path): root_ca_public_key = root_ca_cert.public_key() for root, folders, files in os.walk(src_folder): try: - signatures = json.load(open(os.path.join(root, NVFLARE_SIG_FILE), "rt")) + with open(os.path.join(root, NVFLARE_SIG_FILE), "rt") as f: + signatures = json.load(f) cert = load_crt(os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE)) public_key = cert.public_key() except: @@ -150,11 +154,12 @@ def verify_folder_signature(src_folder, root_ca_path): continue signature = signatures.get(file) if signature: - verify_content( - content=open(os.path.join(root, file), "rb").read(), - signature=signature, - public_key=public_key, - ) + with open(os.path.join(root, file), "rb") as f: + verify_content( + content=f.read(), + signature=signature, + public_key=public_key, + ) for folder in folders: signature = signatures.get(folder) if signature: @@ -173,20 +178,52 @@ def sign_all(content_folder, signing_pri_key): for f in os.listdir(content_folder): path = os.path.join(content_folder, f) if os.path.isfile(path): - signatures[f] = sign_content( - content=open(path, "rb").read(), - signing_pri_key=signing_pri_key, - ) + with open(path, "rb") as file: + signatures[f] = sign_content( + content=file.read(), + signing_pri_key=signing_pri_key, + ) return signatures def load_yaml(file): + + root = os.path.split(file)[0] + yaml_data = None if isinstance(file, str): - return yaml.safe_load(open(file, "r")) + with open(file, "r") as f: + yaml_data = yaml.safe_load(f) elif isinstance(file, bytes): - return yaml.safe_load(file) - else: - return None + yaml_data = yaml.safe_load(file) + + yaml_data = load_yaml_include(root, yaml_data) + + return yaml_data + + +def load_yaml_include(root, yaml_data): + new_data = {} + for k, v in yaml_data.items(): + if k == "include": + if isinstance(v, str): + includes = [v] + elif isinstance(v, list): + includes = v + for item in includes: + new_data.update(load_yaml(os.path.join(root, item))) + elif isinstance(v, list): + new_list = [] + for item in v: + if isinstance(item, dict): + item = load_yaml_include(root, item) + new_list.append(item) + new_data[k] = new_list + elif isinstance(v, dict): + new_data[k] = load_yaml_include(root, v) + else: + new_data[k] = v + + return new_data def sh_replace(src, mapping_dict): diff --git a/tests/unit_test/lighter/0.yml b/tests/unit_test/lighter/0.yml new file mode 100644 index 0000000000..bc9ee7e07c --- /dev/null +++ b/tests/unit_test/lighter/0.yml @@ -0,0 +1,15 @@ +api_version: 3 +name: example_project + +include: 1.yml + +participants: + - name: server + port: 123 + include: [1.yml] + extra: + location: "east" + include: 3.yml + - name: client + port: 234 + include: 2.yml diff --git a/tests/unit_test/lighter/1.yml b/tests/unit_test/lighter/1.yml new file mode 100644 index 0000000000..4dece82e9e --- /dev/null +++ b/tests/unit_test/lighter/1.yml @@ -0,0 +1 @@ +server_name: server \ No newline at end of file diff --git a/tests/unit_test/lighter/2.yml b/tests/unit_test/lighter/2.yml new file mode 100644 index 0000000000..18d2519c61 --- /dev/null +++ b/tests/unit_test/lighter/2.yml @@ -0,0 +1 @@ +client_name: client-1 \ No newline at end of file diff --git a/tests/unit_test/lighter/3.yml b/tests/unit_test/lighter/3.yml new file mode 100644 index 0000000000..08117a92cc --- /dev/null +++ b/tests/unit_test/lighter/3.yml @@ -0,0 +1,2 @@ +size: 4 +gpus: large \ No newline at end of file diff --git a/tests/unit_test/lighter/utils_test.py b/tests/unit_test/lighter/utils_test.py index f1bdb7266a..60c3eaa9c8 100644 --- a/tests/unit_test/lighter/utils_test.py +++ b/tests/unit_test/lighter/utils_test.py @@ -25,7 +25,7 @@ from cryptography.x509.oid import NameOID from nvflare.lighter.impl.cert import serialize_cert -from nvflare.lighter.utils import sign_folders, verify_folder_signature +from nvflare.lighter.utils import load_yaml, sign_folders, verify_folder_signature folders = ["folder1", "folder2"] files = ["file1", "file2"] @@ -144,3 +144,21 @@ def test_verify_updated_folder(self): os.unlink("client.crt") os.unlink("root.crt") shutil.rmtree(folder) + + def _get_participant(self, name, participants): + for p in participants: + if p.get("name") == name: + return p + + def test_load_yaml(self): + dir_path = os.path.dirname(os.path.realpath(__file__)) + data = load_yaml(os.path.join(dir_path, "0.yml")) + + assert data.get("server_name") == "server" + + participant = self._get_participant("server", data.get("participants")) + assert participant.get("server_name") == "server" + assert participant.get("extra").get("gpus") == "large" + + participant = self._get_participant("client", data.get("participants")) + assert participant.get("client_name") == "client-1"