Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/autoplex/auto/rss/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def make(self, **kwargs):
- 'test_error': float, The test error from the last completed training step.
- 'pre_database_dir': str, Path to the directory containing the pre-existing database for resuming.
- 'mlip_path': str, Path to the file of a previous MLIP model.
- 'mlip_path': str | Path, Path to the file of a previous MLIP model.
- 'isolated_atom_energies': dict, A dictionary with isolated atom energy values mapped to atomic numbers.
generated_struct_numbers: list[int]
Expand Down Expand Up @@ -286,7 +286,7 @@ def make(self, **kwargs):
- 'test_error': float, The test error of the fitted MLIP.
- 'pre_database_dir': str, The directory of the latest RSS database.
- 'mlip_path': List of path to the latest fitted MLIP.
- 'mlip_path': str | Path, Path to the latest fitted MLIP.
- 'isolated_atom_energies': dict, The isolated energy values.
- 'current_iter': int, The current iteration index.
- 'kb_temp': float, The temperature (in eV) for Boltzmann sampling.
Expand Down
29 changes: 17 additions & 12 deletions src/autoplex/auto/rss/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def initial_rss(

- 'test_error': float, The test error of the fitted MLIP.
- 'pre_database_dir': str, The directory of the preprocessed database.
- 'mlip_path': List of path to the fitted MLIP.
- 'mlip_path': Path to the fitted MLIP.
- 'isolated_atom_energies': dict, The isolated energy values.
- 'current_iter': int, The current iteration index, set to 0.
"""
Expand Down Expand Up @@ -257,17 +257,18 @@ def initial_rss(
do_dft_static = DFTStaticLabelling(
e0_spin=e0_spin,
isolatedatom_box=isolatedatom_box,
isolated_atom=include_isolated_atom,
dimer=include_dimer,
include_isolated_atom=include_isolated_atom,
include_dimer=include_dimer,
dimer_box=dimer_box,
dimer_range=dimer_range,
dimer_num=dimer_num,
custom_incar=custom_incar,
custom_potcar=custom_potcar,
static_energy_maker=static_energy_maker,
static_energy_maker_isolated_atoms=static_energy_maker_isolated_atoms,
config_type=config_type,
).make(
structures=do_randomized_structure_generation.output, config_type=config_type
structures=do_randomized_structure_generation.output,
)
do_data_collection = collect_dft_data(
dft_ref_file=dft_ref_file, rss_group=rss_group, dft_dirs=do_dft_static.output
Expand Down Expand Up @@ -315,7 +316,7 @@ def initial_rss(
output={
"test_error": do_mlip_fit.output["test_error"],
"pre_database_dir": do_data_preprocessing.output,
"mlip_path": do_mlip_fit.output["mlip_path"],
"mlip_path": do_mlip_fit.output["mlip_path"][0],
"isolated_atom_energies": do_data_collection.output[
"isolated_atom_energies"
],
Expand Down Expand Up @@ -410,8 +411,8 @@ def do_rss_iterations(
The test error of the fitted MLIP.
pre_database_dir: str
The directory of the preprocessed database.
mlip_path: list[str]
List of path to the fitted MLIP.
mlip_path: str | path
Path to the fitted MLIP.
isolated_atom_energies: dict
The isolated energy values.
current_iter: int
Expand Down Expand Up @@ -573,7 +574,7 @@ def do_rss_iterations(

- 'test_error': float, The test error of the fitted MLIP.
- 'pre_database_dir': str, The directory of the preprocessed database.
- 'mlip_path': List of path to the fitted MLIP.
- 'mlip_path': Path to the fitted MLIP.
- 'isolated_atom_energies': dict, The isolated energy values.
- 'current_iter': int, The current iteration index.
- 'kt': float, The temperature (in eV) for Boltzmann sampling.
Expand Down Expand Up @@ -650,23 +651,27 @@ def do_rss_iterations(
num_of_selection=num_of_rss_selected_structs,
bcur_params=bcur_params,
traj_path=do_rss.output,
traj_type="rss",
random_seed=random_seed,
isolated_atom_energies=input["isolated_atom_energies"],
remove_traj_files=remove_traj_files,
)
do_dft_static = DFTStaticLabelling(
e0_spin=e0_spin,
isolatedatom_box=isolatedatom_box,
isolated_atom=include_isolated_atom,
dimer=include_dimer,
include_isolated_atom=include_isolated_atom,
include_dimer=include_dimer,
dimer_box=dimer_box,
dimer_range=dimer_range,
dimer_num=dimer_num,
custom_incar=custom_incar,
custom_potcar=custom_potcar,
static_energy_maker=static_energy_maker,
static_energy_maker_isolated_atoms=static_energy_maker_isolated_atoms,
).make(structures=do_data_sampling.output, config_type=config_type)
config_type=config_type,
).make(
structures=do_data_sampling.output,
)
do_data_collection = collect_dft_data(
dft_ref_file=dft_ref_file,
rss_group=rss_group,
Expand Down Expand Up @@ -713,7 +718,7 @@ def do_rss_iterations(
input={
"test_error": do_mlip_fit.output["test_error"],
"pre_database_dir": do_data_preprocessing.output,
"mlip_path": do_mlip_fit.output["mlip_path"],
"mlip_path": do_mlip_fit.output["mlip_path"][0],
"isolated_atom_energies": input["isolated_atom_energies"],
"current_iter": current_iter,
"kt": kt,
Expand Down
Loading
Loading