From 210d431e1a8296dc4e73114f4facb4dbde0712b5 Mon Sep 17 00:00:00 2001 From: qingzhao chu <466363575@qq.com> Date: Thu, 17 Sep 2020 10:49:15 +0800 Subject: [PATCH] Update collect.py fix multi-system --- dpgen/collect/collect.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/dpgen/collect/collect.py b/dpgen/collect/collect.py index a8d5b4060..5e48c7987 100644 --- a/dpgen/collect/collect.py +++ b/dpgen/collect/collect.py @@ -24,8 +24,12 @@ def collect_data(target_folder, param_file, output, init_data = [] init_data_prefix = jdata.get('init_data_prefix', '') init_data_sys = jdata.get('init_data_sys', []) - for ii in init_data_sys: - init_data.append(dpdata.LabeledSystem(os.path.join(init_data_prefix, ii), fmt='deepmd/npy')) +    for ii in init_data_sys: +        if jdata.get('init_multi_systems', False): +            for single_sys in os.listdir(os.path.join(init_data_prefix, ii)): +                init_data.append(dpdata.LabeledSystem(os.path.join(init_data_prefix, ii, single_sys), fmt='deepmd/npy')) +        else: +            init_data.append(dpdata.LabeledSystem(os.path.join(init_data_prefix, ii), fmt='deepmd/npy')) # collect systems from iter dirs coll_data = {} numb_sys = len(sys_configs)