From e4c2572a3c3ef87f959732dd40a9818975ef4f45 Mon Sep 17 00:00:00 2001 From: Roy Haolin Du Date: Fri, 7 Nov 2025 17:17:59 +0000 Subject: [PATCH 1/3] Make gmx configuration and ensure successful run --- a3fe/__init__.py | 26 +- a3fe/configuration/__init__.py | 9 +- a3fe/configuration/engine_config.py | 739 ++++++++++++++++++++++- a3fe/configuration/enums.py | 30 +- a3fe/configuration/slurm_config.py | 21 +- a3fe/configuration/system_prep_config.py | 47 +- a3fe/run/leg.py | 115 ++-- a3fe/run/simulation.py | 119 ++-- a3fe/run/stage.py | 20 +- a3fe/run/system_prep.py | 25 +- devtools/conda-envs/dev_env.yaml | 1 + 11 files changed, 1009 insertions(+), 143 deletions(-) diff --git a/a3fe/__init__.py b/a3fe/__init__.py index 461336f3..2559d58e 100644 --- a/a3fe/__init__.py +++ b/a3fe/__init__.py @@ -13,26 +13,26 @@ import warnings as _warnings from ._version import __version__ -from .run import ( - CalcSet, - Calculation, - LamWindow, - Leg, - Simulation, - Stage, -) - from .configuration import ( - SlurmConfig, - _EngineConfig, - SomdConfig, EngineType, + GromacsSystemPreparationConfig, JobStatus, LegType, PreparationStage, + SlurmConfig, + SomdConfig, + SomdSystemPreparationConfig, StageType, + _EngineConfig, enums, - SomdSystemPreparationConfig, +) +from .run import ( + CalcSet, + Calculation, + LamWindow, + Leg, + Simulation, + Stage, ) # A3FE can open many files due to the use of multiprocessing and diff --git a/a3fe/configuration/__init__.py b/a3fe/configuration/__init__.py index b2eefad4..19da9634 100644 --- a/a3fe/configuration/__init__.py +++ b/a3fe/configuration/__init__.py @@ -1,9 +1,10 @@ """Pydantic configuration classes for the a3fe package.""" +from .engine_config import GromacsConfig, SomdConfig, _EngineConfig +from .enums import EngineType, JobStatus, LegType, PreparationStage, StageType +from .slurm_config import SlurmConfig from .system_prep_config import ( - _BaseSystemPreparationConfig, + GromacsSystemPreparationConfig, SomdSystemPreparationConfig, + _BaseSystemPreparationConfig, ) -from .slurm_config import SlurmConfig -from .engine_config import _EngineConfig, SomdConfig -from .enums import EngineType, JobStatus, LegType, PreparationStage, StageType diff --git a/a3fe/configuration/engine_config.py b/a3fe/configuration/engine_config.py index e114e656..340993a1 100644 --- a/a3fe/configuration/engine_config.py +++ b/a3fe/configuration/engine_config.py @@ -2,26 +2,40 @@ __all__ = [ "SomdConfig", + "GromacsConfig", ] import os as _os +from abc import ABC as _ABC +from abc import abstractmethod as _abstractmethod from decimal import Decimal as _Decimal from typing import ( Dict as _Dict, - Literal as _Literal, +) +from typing import ( List as _List, - Union as _Union, +) +from typing import ( + Literal as _Literal, +) +from typing import ( Optional as _Optional, ) +from typing import ( + Union as _Union, +) + +import yaml as _yaml from pydantic import ( BaseModel as _BaseModel, +) +from pydantic import ( Field as _Field, +) +from pydantic import ( model_validator as _model_validator, ) -import yaml as _yaml -from abc import ABC as _ABC, abstractmethod as _abstractmethod - class _EngineConfig(_BaseModel, _ABC): """Base class for engine runner configurations.""" @@ -96,6 +110,12 @@ def get_run_cmd(self) -> str: """ pass + def setup_lambda_arrays(self, stage_type) -> None: + """ + Make sure GROMACS has the same stages as SOMD. + """ + pass + class SomdConfig(_EngineConfig): """ @@ -431,3 +451,712 @@ def get_run_cmd(self, lam: float) -> str: Get the command to run the simulation. """ return f"somd-freenrg -C {self.get_file_name()} -l {lam} -p CUDA" + + +class GromacsConfig(_EngineConfig): + """ + Pydantic model for holding GROMACS engine configuration. + Based on fragment-opt-abfe-benchmark mdp format with 4fs timestep. + """ + + ### Simulation Type ### + mdp_type: _Literal["em", "nvt", "npt", "npt-norest", "prod"] = _Field( + "prod", description="Type of simulation" + ) + + ### Run Control ### + define: _Optional[str] = _Field( + None, description="Preprocessor defines (e.g., -DPOSRES, -DFLEXIBLE)" + ) + integrator: _Literal["sd", "steep", "md"] = _Field( + "sd", description="Integrator type" + ) + nsteps: int = _Field(5000000, description="Number of MD steps") + dt: float = _Field(0.002, description="Timestep in ps (2 fs)") + comm_mode: _Literal["Linear", "Angular", "None"] = _Field( + "Linear", description="Center of mass motion removal" + ) + nstcomm: int = _Field(50, description="Frequency for COM removal") + emtol: _Optional[float] = _Field( + None, description="Energy minimization tolerance (for EM)" + ) + emstep: _Optional[float] = _Field(None, description="Initial step size for EM") + pbc: _Literal["xyz", "xy", "no"] = _Field("xyz", description="PBC type") + + ### Output Control ### + nstxout: int = _Field(0, description="Write coords to .trr") + nstvout: int = _Field(0, description="Write velocities to .trr") + nstfout: int = _Field(0, description="Write forces to .trr") + nstxout_compressed: int = _Field(500, description="Write compressed trajectory") + compressed_x_precision: int = _Field(1000, description="Compressed precision") + nstlog: int = _Field(500, description="Update log file") + nstenergy: int = _Field(500, description="Save energies") + nstcalcenergy: int = _Field(50, description="Calculate energies") + + ### Bonds ### + constraint_algorithm: _Literal["lincs", "shake"] = _Field( + "lincs", description="Constraint algorithm" + ) + constraints: _Literal["none", "h-bonds", "all-bonds"] = _Field( + "h-bonds", description="Constraint type (none/h-bonds/all-bonds)" + ) + lincs_iter: int = _Field(1, description="LINCS iterations") + lincs_order: int = _Field(6, description="LINCS order") + lincs_warnangle: int = _Field(30, description="LINCS warning angle") + continuation: _Literal["yes", "no"] = _Field("yes", description="Continuation") + + ### Neighbor Searching ### + cutoff_scheme: _Literal["Verlet", "group"] = _Field( + "Verlet", description="Cutoff scheme" + ) + ns_type: _Literal["grid", "simple"] = _Field("grid", description="Neighbor search") + nstlist: int = _Field(20, description="Update neighbor list") + rlist: float = _Field(1.2, description="Neighbor list cutoff (nm)") + + ### Electrostatics ### + coulombtype: _Literal["PME", "Cut-off"] = _Field("PME", description="Coulomb type") + rcoulomb: float = _Field(1.0, description="Coulomb cutoff (nm)") + ewald_geometry: _Literal["3d", "3dc"] = _Field("3d", description="Ewald geometry") + pme_order: int = _Field(4, description="PME order") + fourierspacing: float = _Field(0.10, description="PME grid spacing (nm)") + ewald_rtol: float = _Field(1e-6, description="Ewald tolerance") + + ### VDW ### + vdwtype: _Literal["Cut-off", "PME"] = _Field("Cut-off", description="VdW type") + vdw_modifier: _Literal["Potential-shift-Verlet", "None"] = _Field( + "Potential-shift-Verlet", description="VdW modifier" + ) + verlet_buffer_tolerance: float = _Field( + 0.005, description="Verlet buffer tolerance" + ) + rvdw: float = _Field(1.0, description="VdW cutoff (nm)") + DispCorr: _Literal["EnerPres", "Ener", "no"] = _Field( + "EnerPres", description="Long range dispersion corrections" + ) + + ### Temperature Coupling ### + tcoupl: _Literal["no", "yes"] = _Field("no", description="Temperature coupling") + tc_grps: str = _Field("System", description="Temperature coupling groups") + tau_t: float = _Field(2.0, description="Time constant for T-coupling (ps)") + ref_t: float = _Field(298.15, description="Reference temperature (K)") + + ### Pressure Coupling ### + pcoupl: _Literal["no", "Berendsen", "C-rescale", "Parrinello-Rahman"] = _Field( + "Parrinello-Rahman", description="Pressure coupling" + ) + pcoupltype: _Literal["isotropic", "semiisotropic"] = _Field( + "isotropic", description="Pressure coupling type" + ) + tau_p: float = _Field(2.0, description="Time constant for P-coupling (ps)") + ref_p: float = _Field(1.01325, description="Reference pressure (bar)") + compressibility: float = _Field(4.5e-5, description="Compressibility (bar^-1)") + refcoord_scaling: _Optional[_Literal["all", "com", "no"]] = _Field( + None, description="Reference coordinate scaling" + ) + + ### Velocity Generation ### + gen_vel: _Literal["yes", "no"] = _Field("no", description="Generate velocities") + gen_seed: int = _Field(-1, description="Random seed") + gen_temp: float = _Field(298.15, description="Generation temperature (K)") + + ### Restraints (for interface compatibility, not used in MDP generation) ### + use_boresch_restraints: bool = _Field( + False, + description="Use Boresch restraints mode (interface compatibility, handled in topology)", + ) + turn_on_receptor_ligand_restraints: bool = _Field( + False, + description="Turn on receptor-ligand restraints mode (interface compatibility, handled in topology)", + ) + boresch_restraints_dictionary: _Optional[str] = _Field( + None, + description="Boresch restraints dictionary content (interface compatibility, handled in topology)", + ) + + ### Free Energy ### + perturbed_residue_number: int = _Field( + 1, + alias="perturbed residue number", + ge=1, + description="Residue number to perturb. Must be >= 1", + ) + + ligand_charge: int = _Field( + 0, + description="Net charge of the ligand. If non-zero, must use PME for electrostatics.", + ) + + free_energy: _Literal["yes", "no"] = _Field("yes", description="Enable FEP") + couple_moltype: str = _Field( + "LIG", description="Molecule type to couple (e.g., 'LIG')" + ) + + couple_lambda0: str = _Field( + "vdw-q", + description="Interactions at lambda=0 (e.g., 'vdw-q', 'vdw', 'q', 'none')", + ) + + couple_lambda1: str = _Field( + "none", + description="Interactions at lambda=1 (e.g., 'vdw-q', 'vdw', 'q', 'none')", + ) + sc_alpha: float = _Field(0.5, description="Soft-core alpha") + sc_power: int = _Field(1, description="Soft-core power") + sc_sigma: float = _Field(0.3, description="Soft-core sigma (nm)") + init_lambda_state: _Optional[int] = _Field(None, description="Initial lambda state") + + lambda_values: _Optional[_List[float]] = _Field( + None, description="Lambda values for this stage (set by system_prep_config)" + ) + + bonded_lambdas: _Optional[_List[float]] = _Field( + None, description="Bonded lambda values" + ) + coul_lambdas: _Optional[_List[float]] = _Field( + None, description="Coulomb lambda values" + ) + vdw_lambdas: _Optional[_List[float]] = _Field(None, description="VdW lambda values") + + nstdhdl: int = _Field(100, description="Frequency to write dH/dlambda") + dhdl_print_energy: _Literal["total", "potential"] = _Field( + "total", description="dH/dlambda energy output" + ) + calc_lambda_neighbors: int = _Field( + -1, + description="Lambda neighbors to calculate (-1 = all for MBAR, 1 = adjacent for BAR)", + ) + separate_dhdl_file: _Literal["yes", "no"] = _Field( + "yes", description="Separate dH/dlambda file" + ) + couple_intramol: _Literal["yes", "no"] = _Field( + "yes", description="Couple intramolecular" + ) + + ### Extra options ### + extra_options: _Dict[str, str] = _Field( + default_factory=dict, description="Extra options" + ) + + @staticmethod + def get_file_name() -> str: + return "gromacs.mdp" + + def setup_lambda_arrays(self, stage_type) -> None: + """ + Set up GROMACS-specific bonded/coul/vdw lambda arrays. + + Parameters + ---------- + stage_type : StageType + The type of stage (RESTRAIN, DISCHARGE, or VANISH) + """ + if self.lambda_values is None: + raise ValueError( + "lambda_values must be set before calling _get_lambda_arrays_for_stage(). " + "This should be set from GromacsSystemPreparationConfig." + ) + + stage = stage_type.name.lower() + + if stage == "restrain": + self.bonded_lambdas = self.lambda_values + self.coul_lambdas = None + self.vdw_lambdas = None + elif stage == "discharge": + self.bonded_lambdas = [1.0] * len(self.lambda_values) + self.coul_lambdas = self.lambda_values + self.vdw_lambdas = None + elif stage == "vanish": + self.bonded_lambdas = [1.0] * len(self.lambda_values) + self.coul_lambdas = None + self.vdw_lambdas = self.lambda_values + else: + raise ValueError(f"Unknown stage type: {stage}") + + def _configure_for_mdp_type(self) -> None: + """ + Configure parameters based on mdp_type. + Resets all parameters to GROMACS defaults first, then applies stage-specific settings. + """ + # Reset all parameters to MD defaults (for nvt/npt/prod) + self.integrator = "sd" + self.constraints = "h-bonds" + self.tcoupl = "yes" + self.pcoupl = "Parrinello-Rahman" + self.continuation = "yes" + self.gen_vel = "no" + self.tau_p = 2.0 + self.nstcomm = 50 + self.nstlist = 20 + + # Apply stage-specific settings + if self.mdp_type == "em": + # EM is completely different from MD + self.integrator = "steep" + self.constraints = "none" + self.tcoupl = "no" + self.pcoupl = "no" + self.gen_vel = "no" + self.nsteps = 50000 + self.emtol = 10 + self.emstep = 0.01 + self.nstcomm = 100 + self.nstxout = 250 + self.nstlist = 1 + + elif self.mdp_type == "nvt": + self.nsteps = 5000 # 10 ps + self.continuation = "no" + self.gen_vel = "yes" + self.pcoupl = "no" + self.nstxout = 25000 + + elif self.mdp_type == "npt": + self.nsteps = 50000 # 100 ps + self.pcoupl = "C-rescale" # GROMACS 2025 + self.tau_p = 1.0 + self.refcoord_scaling = "all" + self.nstxout = 25000 + + elif self.mdp_type == "npt-norest": + self.nsteps = 250000 # 500 ps + self.nstxout = 25000 + + else: # prod + self.nsteps = 10000000 # 20 ns (will be overridden by runtime) + self.nstxout = 0 + + def write_config( + self, + run_dir: str, + lambda_val: float, + runtime: float, + top_file: str = "", + coord_file: str = "", + morph_file: str = "", + ) -> None: + """ + Generate GROMACS mdp file matching fragment-opt-abfe-benchmark format. + """ + # Configure based on type + self._configure_for_mdp_type() + + # Override nsteps for prod based on runtime + if self.mdp_type == "prod": + runtime_ps = runtime * 1000 + self.nsteps = int(runtime_ps / self.dt) + + # Find lambda state index from the active lambda array + if self.bonded_lambdas and not self.coul_lambdas and not self.vdw_lambdas: + lambda_array = self.bonded_lambdas + elif self.coul_lambdas: + lambda_array = self.coul_lambdas + elif self.vdw_lambdas: + lambda_array = self.vdw_lambdas + else: + raise ValueError("No lambda arrays set.") + + try: + self.init_lambda_state = lambda_array.index(lambda_val) + except ValueError: + raise ValueError(f"Lambda {lambda_val} not found in {lambda_array}") + + # Build mdp content + mdp_lines = [ + ";====================================================", + f"; {self.mdp_type.upper().replace('-', ' ')} {'simulation' if self.mdp_type != 'em' else ''}".strip(), + ";====================================================", + "", + ] + + # Run Control + if self.mdp_type == "em": + mdp_lines.extend( + [ + ";----------------------------------------------------", + "; RUN CONTROL & MINIMIZATION", + ";----------------------------------------------------", + ] + ) + if self.define: + mdp_lines.append(f"define = {self.define}") + mdp_lines.extend( + [ + f"integrator = {self.integrator}", + f"nsteps = {self.nsteps}", + f"emtol = {self.emtol}", + f"emstep = {self.emstep}", + f"nstcomm = {self.nstcomm}", + f"pbc = {self.pbc}", + ] + ) + else: + mdp_lines.extend( + [ + "; RUN CONTROL", + ";----------------------------------------------------", + ] + ) + if self.define: + mdp_lines.append(f"define = {self.define}") + mdp_lines.extend( + [ + f"integrator = {self.integrator:<13} ; langevin integrator", + f"nsteps = {self.nsteps:<13} ; {self.dt} * {self.nsteps} fs = {self.nsteps * self.dt * 0.001:.0f} ps", + f"dt = {self.dt:<13} ; {self.dt * 1000:.0f} fs", + f"comm-mode = {self.comm_mode:<13} ; remove center of mass translation", + f"nstcomm = {self.nstcomm:<13} ; frequency for center of mass motion removal", + ] + ) + + # Output Control + mdp_lines.extend( + [ + "", + "; OUTPUT CONTROL" + if self.mdp_type != "em" + else ";----------------------------------------------------", + "" + if self.mdp_type == "em" + else ";----------------------------------------------------", + f"nstxout = {self.nstxout:<10} ; " + + ( + "save coordinates to .trr every " + str(self.nstxout) + " steps" + if self.nstxout > 0 + else "don't save coordinates to .trr" + ), + f"nstvout = {self.nstvout:<10} ; don't save velocities to .trr", + f"nstfout = {self.nstfout:<10} ; don't save forces to .trr", + "", + f"nstxout-compressed = {self.nstxout_compressed:<10} ; xtc compressed trajectory output every {self.nstxout_compressed} steps", + f"compressed-x-precision = {self.compressed_x_precision}", + f"nstlog = {self.nstlog:<10} ; update log file every {self.nstlog} steps", + f"nstenergy = {self.nstenergy:<10} ; save energies every {self.nstenergy} steps", + f"nstcalcenergy = {self.nstcalcenergy}", + "", + ] + ) + + # Neighbor Searching + if self.mdp_type == "em": + mdp_lines.extend( + [ + ";----------------------------------------------------", + "; NEIGHBOR SEARCHING", + ";----------------------------------------------------", + f"cutoff-scheme = {self.cutoff_scheme}", + f"ns-type = {self.ns_type}", + f"nstlist = {self.nstlist}", + f"rlist = {self.rlist}", + "", + ] + ) + else: + mdp_lines.extend( + [ + "; NEIGHBOR SEARCHING" + if self.mdp_type == "nvt" + else ";----------------------------------------------------", + ";----------------------------------------------------", + f"cutoff-scheme = {self.cutoff_scheme}", + f"ns-type = {self.ns_type:<6} ; search neighboring grid cells", + f"nstlist = {self.nstlist:<6} ; {self.nstlist * self.dt * 1000:.0f} fs", + f"rlist = {self.rlist:<6} ; short-range neighborlist cutoff (in nm)", + f"pbc = {self.pbc:<6} ; 3D PBC", + "", + ] + ) + + # Bonds (skip for EM) + if self.mdp_type != "em": + bonds_header = ( + "; BONDS" + if self.mdp_type == "nvt" + else ";----------------------------------------------------" + ) + constraints_comment = ( + " ; all bonds are constrained (HMR)" + if self.constraints == "all-bonds" + else "" + ) + + bonds_section = [ + bonds_header, + ";----------------------------------------------------", + f"constraint_algorithm = {self.constraint_algorithm:<9} ; holonomic constraints", + f"constraints = {self.constraints:<9}{constraints_comment}", + ] + + if self.constraints != "none": + bonds_section.extend( + [ + f"lincs_iter = {self.lincs_iter:<9} ; accuracy of LINCS (1 is default)", + f"lincs_order = {self.lincs_order:<9} ; also related to accuracy (4 is default)", + f"lincs-warnangle = {self.lincs_warnangle:<9} ; maximum angle that a bond can rotate before LINCS will complain (30 is default)", + ] + ) + + bonds_section.extend([f"continuation = {self.continuation}", ""]) + mdp_lines.extend(bonds_section) + + # Electrostatics + if self.mdp_type == "em": + mdp_lines.extend( + [ + ";----------------------------------------------------", + "; ELECTROSTATICS", + ";----------------------------------------------------", + f"coulombtype = {self.coulombtype}", + f"rcoulomb = {self.rcoulomb}", + f"pme-order = {self.pme_order}", + f"fourierspacing = {self.fourierspacing}", + f"ewald-rtol = {self.ewald_rtol}", + "", + ] + ) + else: + elec_header = ( + "; ELECTROSTATICS" + if self.mdp_type == "nvt" + else "; ELECTROSTATICS & EWALD" + ) + mdp_lines.extend( + [ + elec_header, + ";----------------------------------------------------", + f"coulombtype = {self.coulombtype:<6} ; Particle Mesh Ewald for long-range electrostatics", + f"rcoulomb = {self.rcoulomb:<6} ; short-range electrostatic cutoff (in nm)", + f"ewald_geometry = {self.ewald_geometry:<6} ; Ewald sum is performed in all three dimensions", + f"pme-order = {self.pme_order:<6} ; interpolation order for PME (default is 4)", + f"fourierspacing = {self.fourierspacing:<6} ; grid spacing for FFT", + f"ewald-rtol = {self.ewald_rtol:<6} ; relative strength of the Ewald-shifted direct potential at rcoulomb", + "", + ] + ) + + # VDW + if self.mdp_type == "em": + vdw_header = [ + ";----------------------------------------------------", + "; VDW", + ";----------------------------------------------------", + ] + else: + vdw_header = ( + [ + "; VAN DER WAALS", + ";----------------------------------------------------", + ] + if self.mdp_type == "nvt" + else [ + ";----------------------------------------------------", + "; VDW", + ";----------------------------------------------------", + ] + ) + + mdp_lines.extend( + [ + *vdw_header, + f"vdwtype = {self.vdwtype}", + f"vdw-modifier = {self.vdw_modifier}", + f"verlet-buffer-tolerance = {self.verlet_buffer_tolerance}", + f"rvdw = {self.rvdw} ; short-range van der Waals cutoff (in nm)", + f"DispCorr = {self.DispCorr} ; apply long range dispersion corrections for Energy and Pressure", + "", + ] + ) + + # Bonds (for EM only) + if self.mdp_type == "em": + mdp_lines.extend( + [ + ";----------------------------------------------------", + "; BONDS", + ";----------------------------------------------------", + f"constraints = {self.constraints}", + "", + ] + ) + + # Temperature & Pressure Coupling + if self.mdp_type == "em": + mdp_lines.extend( + [ + ";----------------------------------------------------", + "; TEMPERATURE & PRESSURE COUPL", + ";----------------------------------------------------", + f"tcoupl = {self.tcoupl}", + f"pcoupl = {self.pcoupl}", + f"gen-vel = {self.gen_vel}", + "", + ] + ) + elif self.mdp_type == "nvt": + mdp_lines.extend( + [ + "; TEMPERATURE COUPLING", + ";----------------------------------------------------", + f"tc-grps = {self.tc_grps}", + f"tau-t = {self.tau_t}", + f"ref-t = {self.ref_t}", + "", + "; PRESSURE COUPLING", + ";----------------------------------------------------", + f"pcoupl = {self.pcoupl}", + "", + ] + ) + else: # npt, npt-norest, prod + mdp_lines.extend( + [ + ";----------------------------------------------------", + "; TEMPERATURE & PRESSURE COUPL", + ";----------------------------------------------------", + f"tc-grps = {self.tc_grps}", + f"tau-t = {self.tau_t}", + f"ref-t = {self.ref_t}", + f"pcoupl = {self.pcoupl}", + f"pcoupltype = {self.pcoupltype} ; uniform scaling of box vectors", + f"tau-p = {self.tau_p} ; time constant (ps)", + f"ref-p = {self.ref_p} ; reference pressure (bar)", + f"compressibility = {self.compressibility} ; isothermal compressibility of water (bar^-1)", + ] + ) + if self.refcoord_scaling: + mdp_lines.append(f"refcoord-scaling = {self.refcoord_scaling}") + mdp_lines.append("") + + # Velocity Generation (only for non-EM stages) + if self.mdp_type != "em": + mdp_lines.extend( + [ + "; VELOCITY GENERATION", + ";----------------------------------------------------", + f"gen_vel = {self.gen_vel} ; Velocity generation is {'on' if self.gen_vel == 'yes' else 'off'}", + f"gen-seed = {self.gen_seed} ; Use random seed", + f"gen-temp = {self.gen_temp}", + "", + ] + ) + + # Free Energy + if self.free_energy == "yes": + mdp_lines.extend( + [ + "; FREE ENERGY", + ";----------------------------------------------------", + f"free-energy = {self.free_energy}", + f"couple-moltype = {self.couple_moltype}", + f"couple-lambda0 = {self.couple_lambda0}", + f"couple-lambda1 = {self.couple_lambda1}", + f"sc-alpha = {self.sc_alpha}", + f"sc-power = {self.sc_power}", + f"sc-sigma = {self.sc_sigma}", + f"init-lambda-state = {'' if self.init_lambda_state is None else self.init_lambda_state}", + ] + ) + + if self.bonded_lambdas: + mdp_lines.append( + f"bonded-lambdas = {' '.join(str(x) for x in self.bonded_lambdas)}" + ) + + if self.coul_lambdas: + mdp_lines.append( + f"coul-lambdas = {' '.join(str(x) for x in self.coul_lambdas)}" + ) + + if self.vdw_lambdas: + mdp_lines.append( + f"vdw-lambdas = {' '.join(str(x) for x in self.vdw_lambdas)}" + ) + + mdp_lines.extend( + [ + f"nstdhdl = {self.nstdhdl}", + f"dhdl-print-energy = {self.dhdl_print_energy}", + f"calc-lambda-neighbors = {self.calc_lambda_neighbors}", + f"separate-dhdl-file = {self.separate_dhdl_file}", + f"couple-intramol = {self.couple_intramol}", + ] + ) + + # Extra options + if self.extra_options: + mdp_lines.append("") + for key, value in self.extra_options.items(): + mdp_lines.append(f"{key} = {value}") + + # Write file + config_path = _os.path.join(run_dir, self.get_file_name()) + with open(config_path, "w") as f: + f.write("\n".join(mdp_lines) + "\n") + + def write_all_stage_configs( + self, + run_dir: str, + lambda_val: float, + runtime: float, + ) -> None: + """ + Generate all GROMACS stage MDP files (em, nvt, npt, npt-norest, prod). + Creates subdirectories for each stage and writes stage-specific MDP files. + + Parameters + ---------- + run_dir : str + Base directory (e.g., output/lambda_X.XXX/run_YY/) + lambda_val : float + Lambda value for this simulation + runtime : float + Runtime for production stage (ns) + """ + stages = ["em", "nvt", "npt", "npt-norest", "prod"] + + for stage in stages: + stage_dir = _os.path.join(run_dir, stage) + _os.makedirs(stage_dir, exist_ok=True) + + original_mdp_type = self.mdp_type + original_define = self.define + + self.mdp_type = stage + + # Set define based on stage + if stage == "em": + self.define = "-DFLEXIBLE" + else: + self.define = None + + self.write_config( + run_dir=stage_dir, + lambda_val=lambda_val, + runtime=runtime if stage == "prod" else 0.1, + ) + + self.mdp_type = original_mdp_type + self.define = original_define + + def get_run_cmd(self, lam: float) -> str: + stages = ["em", "nvt", "npt", "npt-norest", "prod"] + commands = [] + + for i, stage in enumerate(stages): + # prepare input coordinates + if i == 0: + input_gro = "../gromacs.gro" + else: + prev_stage = stages[i - 1] + input_gro = f"../{prev_stage}/{prev_stage}.gro" + + # grompp + mdrun + cmd = ( + f"cd {stage} && " + f"gmx grompp -f gromacs.mdp -c {input_gro} -p ../gromacs.top -o {stage}.tpr && " + f"gmx mdrun -s {stage}.tpr -deffnm {stage} -v && " + f"cd .." + ) + commands.append(cmd) + + return " && ".join(commands) diff --git a/a3fe/configuration/enums.py b/a3fe/configuration/enums.py index 65e36561..a9f773ff 100644 --- a/a3fe/configuration/enums.py +++ b/a3fe/configuration/enums.py @@ -1,12 +1,20 @@ """Enums required for Classes in the run package.""" from enum import Enum as _Enum +from typing import Any as _Any from typing import List as _List -import yaml as _yaml -from .engine_config import _EngineConfig, SomdConfig as _SomdConfig +import yaml as _yaml -from typing import Any as _Any +from .engine_config import ( + GromacsConfig as _GromacsConfig, +) +from .engine_config import ( + SomdConfig as _SomdConfig, +) +from .engine_config import ( + _EngineConfig, +) __all__ = [ "JobStatus", @@ -95,17 +103,22 @@ class LegType(_YamlSerialisableEnum): class EngineType(_YamlSerialisableEnum): SOMD = 1 + GROMACS = 2 @property def engine_config(self) -> _EngineConfig: """Return the configuration class for the engine.""" engine_configs = { EngineType.SOMD: _SomdConfig, + EngineType.GROMACS: _GromacsConfig, } return engine_configs[self] @property def system_prep_config(self): + from .system_prep_config import ( + GromacsSystemPreparationConfig as _GromacsSystemPreparationConfig, + ) from .system_prep_config import ( SomdSystemPreparationConfig as _SomdSystemPreparationConfig, ) @@ -113,6 +126,7 @@ def system_prep_config(self): """Return the system preparation configuration class.""" system_prep_configs = { EngineType.SOMD: _SomdSystemPreparationConfig, + EngineType.GROMACS: _GromacsSystemPreparationConfig, } return system_prep_configs[self] @@ -141,11 +155,17 @@ def file_suffix(self) -> str: @property def prep_fn(self): """The function to use to prepare the input files for this stage.""" + from ..run.system_prep import ( + heat_and_preequil_input as _heat_and_preequil_input, + ) + from ..run.system_prep import ( + minimise_input as _minimise_input, + ) from ..run.system_prep import ( parameterise_input as _parameterise_input, + ) + from ..run.system_prep import ( solvate_input as _solvate_input, - minimise_input as _minimise_input, - heat_and_preequil_input as _heat_and_preequil_input, ) prep_fns = { diff --git a/a3fe/configuration/slurm_config.py b/a3fe/configuration/slurm_config.py index e7abb3d7..4709edd1 100644 --- a/a3fe/configuration/slurm_config.py +++ b/a3fe/configuration/slurm_config.py @@ -4,17 +4,16 @@ "SlurmConfig", ] -import yaml as _yaml -import subprocess as _subprocess +import os as _os import re as _re +import subprocess as _subprocess +from typing import Dict as _Dict +from typing import List as _List +import yaml as _yaml from pydantic import BaseModel as _BaseModel -from pydantic import Field as _Field from pydantic import ConfigDict as _ConfigDict - -import os as _os - -from typing import List as _List, Dict as _Dict +from pydantic import Field as _Field class SlurmConfig(_BaseModel): @@ -33,8 +32,12 @@ class SlurmConfig(_BaseModel): extra_options: _Dict[str, str] = _Field( {}, description="Extra options to pass to SLURM. For example, {'account': 'qt'}" ) - queue_check_interval: int = _Field(30, ge=1, description="Interval in seconds between SLURM queue status checks.") - job_submission_wait: int = _Field(300, ge=1, description="Wait time in seconds for job submission to SLURM queue.") + queue_check_interval: int = _Field( + 30, ge=1, description="Interval in seconds between SLURM queue status checks." + ) + job_submission_wait: int = _Field( + 300, ge=1, description="Wait time in seconds for job submission to SLURM queue." + ) model_config = _ConfigDict(validate_assignment=True) diff --git a/a3fe/configuration/system_prep_config.py b/a3fe/configuration/system_prep_config.py index 3e9fea5b..1cba3fa0 100644 --- a/a3fe/configuration/system_prep_config.py +++ b/a3fe/configuration/system_prep_config.py @@ -4,20 +4,20 @@ __all__ = [ "SomdSystemPreparationConfig", + "GromacsSystemPreparationConfig", ] -import yaml as _yaml - from abc import ABC as _ABC +from typing import Dict as _Dict +from typing import List as _List +import yaml as _yaml from pydantic import BaseModel as _BaseModel -from pydantic import Field as _Field from pydantic import ConfigDict as _ConfigDict +from pydantic import Field as _Field -from .enums import StageType as _StageType from .enums import LegType as _LegType - -from typing import List as _List, Dict as _Dict +from .enums import StageType as _StageType class _BaseSystemPreparationConfig(_ABC, _BaseModel): @@ -243,3 +243,38 @@ class SomdSystemPreparationConfig(_BaseSystemPreparationConfig): Currently this doesn't modify the base settings, but it may do in the future. """ + + +class GromacsSystemPreparationConfig(_BaseSystemPreparationConfig): + """ + Pydantic model for holding system preparation configuration + for running simulations with GROMACS. + + Uses lambda values optimized for GROMACS soft-core parameters. + """ + + lambda_values: _Dict[_LegType, _Dict[_StageType, _List[float]]] = _Field( + default={ + _LegType.BOUND: { + _StageType.RESTRAIN: [0.0, 0.05, 0.15, 0.5, 0.75, 1.0], + _StageType.DISCHARGE: [0.0, 0.3, 0.6, 0.9, 1.0], + _StageType.VANISH: [ + 0.0, + 0.05, + 0.2, + 0.25, + 0.4, + 0.5, + 0.65, + 0.8, + 0.9, + 1.0, + ], + }, + _LegType.FREE: { + _StageType.DISCHARGE: [0.0, 0.2, 0.4, 0.5, 0.7, 0.9, 1.0], + _StageType.VANISH: [0.0, 0.2, 0.35, 0.4, 0.55, 0.7, 0.85, 1.0], + }, + }, + description="Lambda values optimized for GROMACS.", + ) diff --git a/a3fe/run/leg.py b/a3fe/run/leg.py index af3b61ad..68b7472d 100644 --- a/a3fe/run/leg.py +++ b/a3fe/run/leg.py @@ -22,20 +22,19 @@ from ..analyse.plot import plot_convergence as _plot_convergence from ..analyse.plot import plot_rmsds as _plot_rmsds from ..analyse.plot import plot_sq_sem_convergence as _plot_sq_sem_convergence +from ..configuration import EngineType as _EngineType +from ..configuration import LegType as _LegType +from ..configuration import PreparationStage as _PreparationStage +from ..configuration import SlurmConfig as _SlurmConfig +from ..configuration import StageType as _StageType +from ..configuration import _BaseSystemPreparationConfig, _EngineConfig from . import system_prep as _system_prep from ._restraint import A3feRestraint as _A3feRestraint from ._simulation_runner import SimulationRunner as _SimulationRunner from ._utils import get_single_mol as _get_single_mol from ._virtual_queue import Job as _Job from ._virtual_queue import VirtualQueue as _VirtualQueue -from ..configuration import LegType as _LegType -from ..configuration import PreparationStage as _PreparationStage -from ..configuration import StageType as _StageType -from ..configuration import EngineType as _EngineType -from ..configuration import _EngineConfig -from ..configuration import SlurmConfig as _SlurmConfig from .stage import Stage as _Stage -from ..configuration import _BaseSystemPreparationConfig class Leg(_SimulationRunner): @@ -533,7 +532,11 @@ def run_ensemble_equilibration( _PreparationStage.PREEQUILIBRATED.get_simulation_input_files( self.leg_type ) - + ["somd.rst7"] + + ( + ["somd.rst7"] + if self.engine_type == _EngineType.SOMD + else ["gromacs.gro"] + ) ): if not _os.path.isfile(f"{outdir}/{file}"): raise RuntimeError( @@ -556,10 +559,20 @@ def run_ensemble_equilibration( # Give the output files unique names equil_numbers = [int(outdir.split("_")[-1]) for outdir in outdirs_to_run] for equil_number, outdir in zip(equil_numbers, outdirs_to_run): - _subprocess.run( - ["mv", f"{outdir}/somd.rst7", f"{outdir}/somd_{equil_number}.rst7"], - check=True, - ) + if self.engine_type == _EngineType.GROMACS: + _subprocess.run( + [ + "mv", + f"{outdir}/gromacs.gro", + f"{outdir}/gromacs_{equil_number}.gro", + ], + check=True, + ) + else: + _subprocess.run( + ["mv", f"{outdir}/somd.rst7", f"{outdir}/somd_{equil_number}.rst7"], + check=True, + ) # Load the system and mark the ligand to be decoupled self._logger.info("Loading pre-equilibrated system...") @@ -603,7 +616,10 @@ def run_ensemble_equilibration( # Save the restraints to a text file and store within the Leg object with open(f"{outdir}/restraint_{i + 1}.txt", "w") as f: - f.write(restraint.toString(engine="SOMD")) # type: ignore + run_engine = ( + "GROMACS" if self.engine_type == _EngineType.GROMACS else "SOMD" + ) + f.write(restraint.toString(engine=run_engine)) # type: ignore self.restraints.append(restraint) return pre_equilibrated_system @@ -656,28 +672,38 @@ def setup_stages( f"Setting up {self.leg_type.name} leg {stage_type.name} stage" ) restraint = self.restraints[0] if self.leg_type == _LegType.BOUND else None - protocol = _BSS.Protocol.FreeEnergy( - runtime=dummy_runtime * _BSS.Units.Time.nanosecond, # type: ignore - lam_vals=dummy_lam_vals, - perturbation_type=stage_type.bss_perturbation_type, - ) + if self.engine_type == _EngineType.GROMACS: + protocol = _BSS.Protocol.FreeEnergy( + runtime=dummy_runtime * _BSS.Units.Time.nanosecond, # type: ignore + lam_vals=dummy_lam_vals, + perturbation_type="full", + ) + else: + protocol = _BSS.Protocol.FreeEnergy( + runtime=dummy_runtime * _BSS.Units.Time.nanosecond, # type: ignore + lam_vals=dummy_lam_vals, + perturbation_type=stage_type.bss_perturbation_type, + ) self._logger.info(f"Perturbation type: {stage_type.bss_perturbation_type}") # Ensure we remove the velocites to avoid RST7 file writing issues, as before + run_engine = ( + "gromacs" if self.engine_type == _EngineType.GROMACS else "somd" + ) _BSS.FreeEnergy.AlchemicalFreeEnergy( pre_equilibrated_system, protocol, - engine="SOMD", + engine=run_engine, restraint=restraint, work_dir=stage_input_dir, setup_only=True, property_map={"velocity": "foo"}, ) # We will run outside of BSS - # Copy input written by BSS to the stage input directory, excluding only somd.cfg + # Copy input written by BSS to the stage input directory, excluding cfg and mdp files = [ file - for file in _glob.glob(f"{stage_input_dir}/lambda_0.0000/*") - if not file.endswith(".cfg") + for file in _glob.glob(f"{stage_input_dir}/lambda_*/*") + if not file.endswith((".cfg", ".mdp", ".tpr")) ] for file in files: _shutil.copy(file, stage_input_dir) @@ -691,25 +717,34 @@ def setup_stages( # and, if this is the bound stage, read in the restraints for i in range(self.ensemble_size): ens_equil_output_dir = f"{self.base_dir}/ensemble_equilibration_{i + 1}" - coordinates_file = f"{ens_equil_output_dir}/somd_{i + 1}.rst7" - _shutil.copy(coordinates_file, f"{stage_input_dir}/somd_{i + 1}.rst7") + coordinates_file = ( + f"{ens_equil_output_dir}/somd_{i + 1}.rst7" + if self.engine_type == _EngineType.SOMD + else f"{ens_equil_output_dir}/gromacs_{i + 1}.gro" + ) + _shutil.copy( + coordinates_file, + f"{stage_input_dir}/somd_{i + 1}.rst7" + if self.engine_type == _EngineType.SOMD + else f"{stage_input_dir}/gromacs_{i + 1}.gro", + ) if self.leg_type == _LegType.BOUND: - # Read in the first restraint file - with open( - f"{ens_equil_output_dir}/restraint_{i + 1}.txt", "r" - ) as f: - lines = f.readlines() - restraint_type, restraint_dict = [ - item.strip() for item in lines[0].split("=") - ] - - if restraint_type != "boresch restraints dictionary": - raise ValueError( - f"Only Boresch restraints are supported. Found {restraint_type} restraints." - ) - - stage_config.boresch_restraints_dictionary = restraint_dict + if self.engine_type == _EngineType.SOMD: + with open( + f"{ens_equil_output_dir}/restraint_{i + 1}.txt", "r" + ) as f: + lines = f.readlines() + restraint_type, restraint_dict = [ + item.strip() for item in lines[0].split("=") + ] + + if restraint_type != "boresch restraints dictionary": + raise ValueError( + f"Only Boresch restraints are supported. Found {restraint_type} restraints." + ) + + stage_config.boresch_restraints_dictionary = restraint_dict # Set configuration options stage_config.perturbed_residue_number = perturbed_resnum @@ -723,7 +758,7 @@ def setup_stages( stage_config.lambda_values = sys_prep_config.lambda_values[self.leg_type][ stage_type ] - + stage_config.setup_lambda_arrays(stage_type) stage_configs[stage_type] = stage_config # We no longer need to store the large BSS restraint classes. diff --git a/a3fe/run/simulation.py b/a3fe/run/simulation.py index 45013677..4fccbfeb 100644 --- a/a3fe/run/simulation.py +++ b/a3fe/run/simulation.py @@ -1,4 +1,4 @@ -"""Functionality to run a single SOMD simulation.""" +"""Functionality to run a single simulation.""" __all__ = ["Simulation"] @@ -14,23 +14,29 @@ import numpy as _np from sire.units import k_boltz as _k_boltz -from ._simulation_runner import SimulationRunner as _SimulationRunner -from ._virtual_queue import Job as _Job -from ._virtual_queue import VirtualQueue as _VirtualQueue +from ..configuration import EngineType as _EngineType from ..configuration import JobStatus as _JobStatus from ..configuration import SlurmConfig as _SlurmConfig from ..configuration import _EngineConfig -from ..configuration import EngineType as _EngineType +from ._simulation_runner import SimulationRunner as _SimulationRunner +from ._virtual_queue import Job as _Job +from ._virtual_queue import VirtualQueue as _VirtualQueue class Simulation(_SimulationRunner): """Class to store information about a single SOMD simulation.""" - required_input_files = [ - "somd.prm7", - "somd.rst7", - "somd.pert", - ] + required_input_files = { + _EngineType.SOMD: [ + "somd.prm7", + "somd.rst7", + "somd.pert", + ], + _EngineType.GROMACS: [ + "gromacs.top", + "gromacs.gro", + ], + } # Files to be cleaned by self.clean() run_files = _SimulationRunner.run_files + [ @@ -212,37 +218,59 @@ def _validate_input(self) -> None: raise FileNotFoundError("Input directory does not exist.") # Check that the required input files are present - for file in Simulation.required_input_files: + for file in Simulation.required_input_files[self.engine_type]: if not _os.path.isfile(_os.path.join(self.input_dir, file)): raise FileNotFoundError("Required input file " + file + " not found.") def _select_input_files(self) -> None: """Select the correct rst7 and, if supplied, restraints, according to the run number.""" - - # Check if we have multiple rst7 files, or only one - rst7_files = _glob.glob(_os.path.join(self.input_dir, "*.rst7")) - if len(rst7_files) == 0: - raise FileNotFoundError("No rst7 files found in input directory") - elif len(rst7_files) > 1: - # Rename the rst7 file for this run to somd.rst7 and delete any other - # rst7 files - self._logger.debug("Multiple rst7 files found - renaming") - _subprocess.run( - [ - "mv", - _os.path.join(self.input_dir, f"somd_{self.run_no}.rst7"), - _os.path.join(self.input_dir, "somd.rst7"), - ] - ) - unwanted_rst7_files = _glob.glob( - _os.path.join(self.input_dir, "somd_?.rst7") - ) - for file in unwanted_rst7_files: - _subprocess.run(["rm", file]) + if self.engine_type == _EngineType.SOMD: + # Check if we have multiple rst7 files, or only one + rst7_files = _glob.glob(_os.path.join(self.input_dir, "*.rst7")) + if len(rst7_files) == 0: + raise FileNotFoundError("No rst7 files found in input directory") + elif len(rst7_files) > 1: + # Rename the rst7 file for this run to somd.rst7 and delete any other + # rst7 files + self._logger.debug("Multiple rst7 files found - renaming") + _subprocess.run( + [ + "mv", + _os.path.join(self.input_dir, f"somd_{self.run_no}.rst7"), + _os.path.join(self.input_dir, "somd.rst7"), + ] + ) + unwanted_rst7_files = _glob.glob( + _os.path.join(self.input_dir, "somd_?.rst7") + ) + for file in unwanted_rst7_files: + _subprocess.run(["rm", file]) + else: + self._logger.info("Only one rst7 file found - not renaming") + + elif self.engine_type == _EngineType.GROMACS: + gro_files = _glob.glob(_os.path.join(self.input_dir, "*.gro")) + if len(gro_files) == 0: + raise FileNotFoundError("No gro files found in input directory") + elif len(gro_files) > 1: + self._logger.debug("Multiple gro files found - renaming") + _subprocess.run( + [ + "mv", + _os.path.join(self.input_dir, f"gromacs_{self.run_no}.gro"), + _os.path.join(self.input_dir, "gromacs.gro"), + ] + ) + unwanted_gro_files = _glob.glob( + _os.path.join(self.input_dir, "gromacs_?.gro") + ) + for file in unwanted_gro_files: + _subprocess.run(["rm", file]) + else: + self._logger.info("Only one gro file found - not renaming") else: - self._logger.info("Only one rst7 file found - not renaming") - + raise ValueError(f"Engine type {self.engine_type} not supported") # Deal with restraints. Get the name of the restraint file for this run old_restr_file = _os.path.join(self.input_dir, f"restraint_{self.run_no}.txt") @@ -285,14 +313,21 @@ def run(self, runtime: float = 2.5) -> None: None """ # Write updated config to file - self.engine_config.write_config( - run_dir=self.output_dir, - lambda_val=self.lam, - runtime=runtime, - top_file="somd.prm7", # TODO - make generic - coord_file="somd.rst7", # TODO - make generic - morph_file="somd.pert", # TODO - make generic - ) + if self.engine_type == _EngineType.SOMD: + self.engine_config.write_config( + run_dir=self.output_dir, + lambda_val=self.lam, + runtime=runtime, + top_file="somd.prm7", + coord_file="somd.rst7", + morph_file="somd.pert", + ) + else: # GROMACS + self.engine_config.write_all_stage_configs( + run_dir=self.output_dir, + lambda_val=self.lam, + runtime=runtime, + ) # Get the commands to run the simulation cmd = self.engine_config.get_run_cmd(self.lam) diff --git a/a3fe/run/stage.py b/a3fe/run/stage.py index 4b714010..2438470f 100644 --- a/a3fe/run/stage.py +++ b/a3fe/run/stage.py @@ -9,7 +9,6 @@ import threading as _threading from copy import deepcopy as _deepcopy from math import ceil as _ceil -import matplotlib.pyplot as _plt from multiprocessing import get_context as _get_context from time import sleep as _sleep from typing import Any as _Any @@ -20,6 +19,7 @@ from typing import Tuple as _Tuple from typing import Union as _Union +import matplotlib.pyplot as _plt import numpy as _np import pandas as _pd import scipy.stats as _stats @@ -49,13 +49,13 @@ from ..analyse.plot import plot_rmsds as _plot_rmsds from ..analyse.plot import plot_sq_sem_convergence as _plot_sq_sem_convergence from ..analyse.process_grads import GradientData as _GradientData +from ..configuration import EngineType as _EngineType +from ..configuration import SlurmConfig as _SlurmConfig +from ..configuration import StageType as _StageType +from ..configuration import _EngineConfig from ._simulation_runner import SimulationRunner as _SimulationRunner from ._virtual_queue import VirtualQueue as _VirtualQueue -from ..configuration import StageType as _StageType from .lambda_window import LamWindow as _LamWindow -from ..configuration import SlurmConfig as _SlurmConfig -from ..configuration import _EngineConfig -from ..configuration import EngineType as _EngineType class Stage(_SimulationRunner): @@ -790,8 +790,10 @@ def analyse( win._write_equilibrated_simfiles() # Run MBAR and compute mean and 95 % C.I. of free energy - if not slurm: + # GROMACS doesn't need SLURM (Python function can use multiprocessing) + if not slurm or self.engine_type == _EngineType.GROMACS: free_energies, errors, mbar_outfiles, _ = _run_mbar( + engine_type=self.engine_type, run_nos=run_nos, output_dir=self.output_dir, percentage_end=fraction * 100, @@ -1046,7 +1048,8 @@ def analyse_convergence( for win in self.lam_windows: win._write_equilibrated_simfiles() - if not slurm: + # GROMACS doesn't need SLURM (Python function can use multiprocessing) + if not slurm or self.engine_type == _EngineType.GROMACS: # Now run mbar with multiprocessing to speed things up with _get_context("spawn").Pool() as pool: results = pool.starmap( @@ -1060,13 +1063,14 @@ def analyse_convergence( False, # Subsample True, # Delete output files equilibrated, # Equilibrated + self.engine_type, ) for start_percent, end_percent in zip( start_percents, end_percents ) ], ) - else: # Use slurm + else: # Use SLURM (SOMD only) frac_jobs = [] results = [] for start_percent, end_percent in zip(start_percents, end_percents): diff --git a/a3fe/run/system_prep.py b/a3fe/run/system_prep.py index f1d92758..f75ecfb5 100644 --- a/a3fe/run/system_prep.py +++ b/a3fe/run/system_prep.py @@ -14,11 +14,11 @@ import BioSimSpace.Sandpit.Exscientia as _BSS -from ..read._process_bss_systems import rename_lig as _rename_lig -from ._utils import check_has_wat_and_box as _check_has_wat_and_box +from ..configuration import EngineType as _EngineType from ..configuration import LegType as _LegType from ..configuration import PreparationStage as _PreparationStage -from ..configuration import EngineType as _EngineType +from ..read._process_bss_systems import rename_lig as _rename_lig +from ._utils import check_has_wat_and_box as _check_has_wat_and_box def parameterise_input( @@ -452,7 +452,7 @@ def run_ensemble_equilibration( print( f"Running ensemble equilibration simulation with GROMACS for {cfg.ensemble_equilibration_time} ps" ) - if leg_type == _LegType.BOUND: + if leg_type == _LegType.BOUND or engine_type == _EngineType.GROMACS: work_dir = output_dir else: work_dir = None @@ -461,13 +461,16 @@ def run_ensemble_equilibration( # Save the coordinates only, renaming the velocity property to foo so avoid saving velocities. Saving the # velocities sometimes causes issues with the size of the floats overflowing the RST7 # format. - print(f"Saving somd.rst7 to {output_dir}") - _BSS.IO.saveMolecules( - f"{output_dir}/somd", - final_system, - fileformat=["rst7"], - property_map={"velocity": "foo"}, - ) + if engine_type == _EngineType.GROMACS: + print(f"GROMACS files already saved to {output_dir}") + else: + print(f"Saving somd.rst7 to {output_dir}") + _BSS.IO.saveMolecules( + f"{output_dir}/somd", + final_system, + fileformat=["rst7"], + property_map={"velocity": "foo"}, + ) def run_process( diff --git a/devtools/conda-envs/dev_env.yaml b/devtools/conda-envs/dev_env.yaml index 3a4b2f59..8b5b4ea9 100644 --- a/devtools/conda-envs/dev_env.yaml +++ b/devtools/conda-envs/dev_env.yaml @@ -17,6 +17,7 @@ dependencies: - scipy - ipython - pymbar<4 + - alchemlyb=1.0.1 # for pymbar<4 - ambertools - biosimspace>2023.4 - numpydoc From 44602c92fb0dd1f04ca4939c7c7b9cd7519b76a3 Mon Sep 17 00:00:00 2001 From: Roy Haolin Du Date: Mon, 8 Dec 2025 09:58:26 +0000 Subject: [PATCH 2/3] Add GROMACS engine support and configuration system --- a3fe/configuration/engine_config.py | 296 ++++++++++++----------- a3fe/configuration/system_prep_config.py | 21 +- a3fe/read/_read_exp_dgs.py | 3 +- a3fe/run/_simulation_runner.py | 7 +- a3fe/run/_virtual_queue.py | 2 +- a3fe/run/calc_set.py | 6 +- a3fe/run/calculation.py | 11 +- a3fe/run/lambda_window.py | 53 ++-- a3fe/run/simulation.py | 207 ++++++++++------ a3fe/run/stage.py | 55 ++--- 10 files changed, 371 insertions(+), 290 deletions(-) diff --git a/a3fe/configuration/engine_config.py b/a3fe/configuration/engine_config.py index 340993a1..c210f673 100644 --- a/a3fe/configuration/engine_config.py +++ b/a3fe/configuration/engine_config.py @@ -492,7 +492,7 @@ class GromacsConfig(_EngineConfig): nstlog: int = _Field(500, description="Update log file") nstenergy: int = _Field(500, description="Save energies") nstcalcenergy: int = _Field(50, description="Calculate energies") - + ### Bonds ### constraint_algorithm: _Literal["lincs", "shake"] = _Field( "lincs", description="Constraint algorithm" @@ -504,7 +504,7 @@ class GromacsConfig(_EngineConfig): lincs_order: int = _Field(6, description="LINCS order") lincs_warnangle: int = _Field(30, description="LINCS warning angle") continuation: _Literal["yes", "no"] = _Field("yes", description="Continuation") - + ### Neighbor Searching ### cutoff_scheme: _Literal["Verlet", "group"] = _Field( "Verlet", description="Cutoff scheme" @@ -512,7 +512,7 @@ class GromacsConfig(_EngineConfig): ns_type: _Literal["grid", "simple"] = _Field("grid", description="Neighbor search") nstlist: int = _Field(20, description="Update neighbor list") rlist: float = _Field(1.2, description="Neighbor list cutoff (nm)") - + ### Electrostatics ### coulombtype: _Literal["PME", "Cut-off"] = _Field("PME", description="Coulomb type") rcoulomb: float = _Field(1.0, description="Coulomb cutoff (nm)") @@ -520,7 +520,7 @@ class GromacsConfig(_EngineConfig): pme_order: int = _Field(4, description="PME order") fourierspacing: float = _Field(0.10, description="PME grid spacing (nm)") ewald_rtol: float = _Field(1e-6, description="Ewald tolerance") - + ### VDW ### vdwtype: _Literal["Cut-off", "PME"] = _Field("Cut-off", description="VdW type") vdw_modifier: _Literal["Potential-shift-Verlet", "None"] = _Field( @@ -533,13 +533,13 @@ class GromacsConfig(_EngineConfig): DispCorr: _Literal["EnerPres", "Ener", "no"] = _Field( "EnerPres", description="Long range dispersion corrections" ) - + ### Temperature Coupling ### tcoupl: _Literal["no", "yes"] = _Field("no", description="Temperature coupling") tc_grps: str = _Field("System", description="Temperature coupling groups") tau_t: float = _Field(2.0, description="Time constant for T-coupling (ps)") ref_t: float = _Field(298.15, description="Reference temperature (K)") - + ### Pressure Coupling ### pcoupl: _Literal["no", "Berendsen", "C-rescale", "Parrinello-Rahman"] = _Field( "Parrinello-Rahman", description="Pressure coupling" @@ -553,7 +553,7 @@ class GromacsConfig(_EngineConfig): refcoord_scaling: _Optional[_Literal["all", "com", "no"]] = _Field( None, description="Reference coordinate scaling" ) - + ### Velocity Generation ### gen_vel: _Literal["yes", "no"] = _Field("no", description="Generate velocities") gen_seed: int = _Field(-1, description="Random seed") @@ -572,7 +572,7 @@ class GromacsConfig(_EngineConfig): None, description="Boresch restraints dictionary content (interface compatibility, handled in topology)", ) - + ### Free Energy ### perturbed_residue_number: int = _Field( 1, @@ -631,7 +631,7 @@ class GromacsConfig(_EngineConfig): couple_intramol: _Literal["yes", "no"] = _Field( "yes", description="Couple intramolecular" ) - + ### Extra options ### extra_options: _Dict[str, str] = _Field( default_factory=dict, description="Extra options" @@ -641,10 +641,15 @@ class GromacsConfig(_EngineConfig): def get_file_name() -> str: return "gromacs.mdp" + @property + def timestep(self) -> float: + """Return timestep in femtoseconds for compatibility with SomdConfig.""" + return self.dt * 1000.0 # ps to fs + def setup_lambda_arrays(self, stage_type) -> None: """ Set up GROMACS-specific bonded/coul/vdw lambda arrays. - + Parameters ---------- stage_type : StageType @@ -660,15 +665,15 @@ def setup_lambda_arrays(self, stage_type) -> None: if stage == "restrain": self.bonded_lambdas = self.lambda_values - self.coul_lambdas = None - self.vdw_lambdas = None + self.coul_lambdas = [0.0] * len(self.lambda_values) + self.vdw_lambdas = [0.0] * len(self.lambda_values) elif stage == "discharge": self.bonded_lambdas = [1.0] * len(self.lambda_values) self.coul_lambdas = self.lambda_values - self.vdw_lambdas = None + self.vdw_lambdas = [0.0] * len(self.lambda_values) elif stage == "vanish": self.bonded_lambdas = [1.0] * len(self.lambda_values) - self.coul_lambdas = None + self.coul_lambdas = [1.0] * len(self.lambda_values) self.vdw_lambdas = self.lambda_values else: raise ValueError(f"Unknown stage type: {stage}") @@ -697,33 +702,33 @@ def _configure_for_mdp_type(self) -> None: self.tcoupl = "no" self.pcoupl = "no" self.gen_vel = "no" - self.nsteps = 50000 + self.nsteps = 10000 self.emtol = 10 self.emstep = 0.01 self.nstcomm = 100 self.nstxout = 250 self.nstlist = 1 - + elif self.mdp_type == "nvt": self.nsteps = 5000 # 10 ps self.continuation = "no" self.gen_vel = "yes" self.pcoupl = "no" self.nstxout = 25000 - + elif self.mdp_type == "npt": self.nsteps = 50000 # 100 ps self.pcoupl = "C-rescale" # GROMACS 2025 self.tau_p = 1.0 self.refcoord_scaling = "all" self.nstxout = 25000 - + elif self.mdp_type == "npt-norest": self.nsteps = 250000 # 500 ps self.nstxout = 25000 - + else: # prod - self.nsteps = 10000000 # 20 ns (will be overridden by runtime) + self.nsteps = 2500000 # 5 ns (will be overridden by runtime) self.nstxout = 0 def write_config( @@ -740,22 +745,41 @@ def write_config( """ # Configure based on type self._configure_for_mdp_type() - + # Override nsteps for prod based on runtime if self.mdp_type == "prod": runtime_ps = runtime * 1000 self.nsteps = int(runtime_ps / self.dt) # Find lambda state index from the active lambda array - if self.bonded_lambdas and not self.coul_lambdas and not self.vdw_lambdas: - lambda_array = self.bonded_lambdas - elif self.coul_lambdas: - lambda_array = self.coul_lambdas - elif self.vdw_lambdas: + # Priority: find array containing lambda_val, otherwise use varying array + lambda_array = None + + # First, try to find array containing lambda_val + if self.vdw_lambdas and lambda_val in self.vdw_lambdas: lambda_array = self.vdw_lambdas - else: - raise ValueError("No lambda arrays set.") - + elif self.coul_lambdas and lambda_val in self.coul_lambdas: + lambda_array = self.coul_lambdas + elif self.bonded_lambdas and lambda_val in self.bonded_lambdas: + lambda_array = self.bonded_lambdas + + # If not found, use varying array (not all values are the same) + if lambda_array is None: + if self.vdw_lambdas and len(set(self.vdw_lambdas)) > 1: + lambda_array = self.vdw_lambdas + elif self.coul_lambdas and len(set(self.coul_lambdas)) > 1: + lambda_array = self.coul_lambdas + elif self.bonded_lambdas and len(set(self.bonded_lambdas)) > 1: + lambda_array = self.bonded_lambdas + + if lambda_array is None: + raise ValueError( + f"Lambda {lambda_val} not found in any lambda array. " + f"bonded: {self.bonded_lambdas}, " + f"coul: {self.coul_lambdas}, " + f"vdw: {self.vdw_lambdas}" + ) + try: self.init_lambda_state = lambda_array.index(lambda_val) except ValueError: @@ -768,47 +792,47 @@ def write_config( ";====================================================", "", ] - + # Run Control if self.mdp_type == "em": mdp_lines.extend( [ - ";----------------------------------------------------", - "; RUN CONTROL & MINIMIZATION", - ";----------------------------------------------------", + ";----------------------------------------------------", + "; RUN CONTROL & MINIMIZATION", + ";----------------------------------------------------", ] ) if self.define: mdp_lines.append(f"define = {self.define}") mdp_lines.extend( [ - f"integrator = {self.integrator}", - f"nsteps = {self.nsteps}", - f"emtol = {self.emtol}", - f"emstep = {self.emstep}", - f"nstcomm = {self.nstcomm}", - f"pbc = {self.pbc}", + f"integrator = {self.integrator}", + f"nsteps = {self.nsteps}", + f"emtol = {self.emtol}", + f"emstep = {self.emstep}", + f"nstcomm = {self.nstcomm}", + f"pbc = {self.pbc}", ] ) else: mdp_lines.extend( [ - "; RUN CONTROL", - ";----------------------------------------------------", + "; RUN CONTROL", + ";----------------------------------------------------", ] ) if self.define: mdp_lines.append(f"define = {self.define}") mdp_lines.extend( [ - f"integrator = {self.integrator:<13} ; langevin integrator", - f"nsteps = {self.nsteps:<13} ; {self.dt} * {self.nsteps} fs = {self.nsteps * self.dt * 0.001:.0f} ps", - f"dt = {self.dt:<13} ; {self.dt * 1000:.0f} fs", - f"comm-mode = {self.comm_mode:<13} ; remove center of mass translation", - f"nstcomm = {self.nstcomm:<13} ; frequency for center of mass motion removal", + f"integrator = {self.integrator:<13} ; langevin integrator", + f"nsteps = {self.nsteps:<13} ; {self.dt} * {self.nsteps} fs = {self.nsteps * self.dt * 0.001:.0f} ps", + f"dt = {self.dt:<13} ; {self.dt * 1000:.0f} fs", + f"comm-mode = {self.comm_mode:<13} ; remove center of mass translation", + f"nstcomm = {self.nstcomm:<13} ; frequency for center of mass motion removal", ] ) - + # Output Control mdp_lines.extend( [ @@ -825,30 +849,30 @@ def write_config( if self.nstxout > 0 else "don't save coordinates to .trr" ), - f"nstvout = {self.nstvout:<10} ; don't save velocities to .trr", - f"nstfout = {self.nstfout:<10} ; don't save forces to .trr", - "", - f"nstxout-compressed = {self.nstxout_compressed:<10} ; xtc compressed trajectory output every {self.nstxout_compressed} steps", - f"compressed-x-precision = {self.compressed_x_precision}", - f"nstlog = {self.nstlog:<10} ; update log file every {self.nstlog} steps", - f"nstenergy = {self.nstenergy:<10} ; save energies every {self.nstenergy} steps", - f"nstcalcenergy = {self.nstcalcenergy}", - "", + f"nstvout = {self.nstvout:<10} ; don't save velocities to .trr", + f"nstfout = {self.nstfout:<10} ; don't save forces to .trr", + "", + f"nstxout-compressed = {self.nstxout_compressed:<10} ; xtc compressed trajectory output every {self.nstxout_compressed} steps", + f"compressed-x-precision = {self.compressed_x_precision}", + f"nstlog = {self.nstlog:<10} ; update log file every {self.nstlog} steps", + f"nstenergy = {self.nstenergy:<10} ; save energies every {self.nstenergy} steps", + f"nstcalcenergy = {self.nstcalcenergy}", + "", ] ) - + # Neighbor Searching if self.mdp_type == "em": mdp_lines.extend( [ - ";----------------------------------------------------", - "; NEIGHBOR SEARCHING", - ";----------------------------------------------------", - f"cutoff-scheme = {self.cutoff_scheme}", - f"ns-type = {self.ns_type}", - f"nstlist = {self.nstlist}", - f"rlist = {self.rlist}", - "", + ";----------------------------------------------------", + "; NEIGHBOR SEARCHING", + ";----------------------------------------------------", + f"cutoff-scheme = {self.cutoff_scheme}", + f"ns-type = {self.ns_type}", + f"nstlist = {self.nstlist}", + f"rlist = {self.rlist}", + "", ] ) else: @@ -857,16 +881,16 @@ def write_config( "; NEIGHBOR SEARCHING" if self.mdp_type == "nvt" else ";----------------------------------------------------", - ";----------------------------------------------------", - f"cutoff-scheme = {self.cutoff_scheme}", - f"ns-type = {self.ns_type:<6} ; search neighboring grid cells", - f"nstlist = {self.nstlist:<6} ; {self.nstlist * self.dt * 1000:.0f} fs", - f"rlist = {self.rlist:<6} ; short-range neighborlist cutoff (in nm)", - f"pbc = {self.pbc:<6} ; 3D PBC", - "", + ";----------------------------------------------------", + f"cutoff-scheme = {self.cutoff_scheme}", + f"ns-type = {self.ns_type:<6} ; search neighboring grid cells", + f"nstlist = {self.nstlist:<6} ; {self.nstlist * self.dt * 1000:.0f} fs", + f"rlist = {self.rlist:<6} ; short-range neighborlist cutoff (in nm)", + f"pbc = {self.pbc:<6} ; 3D PBC", + "", ] ) - + # Bonds (skip for EM) if self.mdp_type != "em": bonds_header = ( @@ -898,20 +922,20 @@ def write_config( bonds_section.extend([f"continuation = {self.continuation}", ""]) mdp_lines.extend(bonds_section) - + # Electrostatics if self.mdp_type == "em": mdp_lines.extend( [ - ";----------------------------------------------------", - "; ELECTROSTATICS", - ";----------------------------------------------------", - f"coulombtype = {self.coulombtype}", - f"rcoulomb = {self.rcoulomb}", - f"pme-order = {self.pme_order}", - f"fourierspacing = {self.fourierspacing}", - f"ewald-rtol = {self.ewald_rtol}", - "", + ";----------------------------------------------------", + "; ELECTROSTATICS", + ";----------------------------------------------------", + f"coulombtype = {self.coulombtype}", + f"rcoulomb = {self.rcoulomb}", + f"pme-order = {self.pme_order}", + f"fourierspacing = {self.fourierspacing}", + f"ewald-rtol = {self.ewald_rtol}", + "", ] ) else: @@ -923,17 +947,17 @@ def write_config( mdp_lines.extend( [ elec_header, - ";----------------------------------------------------", - f"coulombtype = {self.coulombtype:<6} ; Particle Mesh Ewald for long-range electrostatics", - f"rcoulomb = {self.rcoulomb:<6} ; short-range electrostatic cutoff (in nm)", - f"ewald_geometry = {self.ewald_geometry:<6} ; Ewald sum is performed in all three dimensions", - f"pme-order = {self.pme_order:<6} ; interpolation order for PME (default is 4)", - f"fourierspacing = {self.fourierspacing:<6} ; grid spacing for FFT", - f"ewald-rtol = {self.ewald_rtol:<6} ; relative strength of the Ewald-shifted direct potential at rcoulomb", - "", + ";----------------------------------------------------", + f"coulombtype = {self.coulombtype:<6} ; Particle Mesh Ewald for long-range electrostatics", + f"rcoulomb = {self.rcoulomb:<6} ; short-range electrostatic cutoff (in nm)", + f"ewald_geometry = {self.ewald_geometry:<6} ; Ewald sum is performed in all three dimensions", + f"pme-order = {self.pme_order:<6} ; interpolation order for PME (default is 4)", + f"fourierspacing = {self.fourierspacing:<6} ; grid spacing for FFT", + f"ewald-rtol = {self.ewald_rtol:<6} ; relative strength of the Ewald-shifted direct potential at rcoulomb", + "", ] ) - + # VDW if self.mdp_type == "em": vdw_header = [ @@ -978,68 +1002,68 @@ def write_config( "", ] ) - + # Temperature & Pressure Coupling if self.mdp_type == "em": mdp_lines.extend( [ - ";----------------------------------------------------", - "; TEMPERATURE & PRESSURE COUPL", - ";----------------------------------------------------", - f"tcoupl = {self.tcoupl}", - f"pcoupl = {self.pcoupl}", - f"gen-vel = {self.gen_vel}", - "", + ";----------------------------------------------------", + "; TEMPERATURE & PRESSURE COUPL", + ";----------------------------------------------------", + f"tcoupl = {self.tcoupl}", + f"pcoupl = {self.pcoupl}", + f"gen-vel = {self.gen_vel}", + "", ] ) elif self.mdp_type == "nvt": mdp_lines.extend( [ - "; TEMPERATURE COUPLING", - ";----------------------------------------------------", - f"tc-grps = {self.tc_grps}", - f"tau-t = {self.tau_t}", - f"ref-t = {self.ref_t}", - "", - "; PRESSURE COUPLING", - ";----------------------------------------------------", - f"pcoupl = {self.pcoupl}", - "", + "; TEMPERATURE COUPLING", + ";----------------------------------------------------", + f"tc-grps = {self.tc_grps}", + f"tau-t = {self.tau_t}", + f"ref-t = {self.ref_t}", + "", + "; PRESSURE COUPLING", + ";----------------------------------------------------", + f"pcoupl = {self.pcoupl}", + "", ] ) else: # npt, npt-norest, prod mdp_lines.extend( [ - ";----------------------------------------------------", - "; TEMPERATURE & PRESSURE COUPL", - ";----------------------------------------------------", - f"tc-grps = {self.tc_grps}", - f"tau-t = {self.tau_t}", - f"ref-t = {self.ref_t}", - f"pcoupl = {self.pcoupl}", - f"pcoupltype = {self.pcoupltype} ; uniform scaling of box vectors", - f"tau-p = {self.tau_p} ; time constant (ps)", - f"ref-p = {self.ref_p} ; reference pressure (bar)", - f"compressibility = {self.compressibility} ; isothermal compressibility of water (bar^-1)", + ";----------------------------------------------------", + "; TEMPERATURE & PRESSURE COUPL", + ";----------------------------------------------------", + f"tc-grps = {self.tc_grps}", + f"tau-t = {self.tau_t}", + f"ref-t = {self.ref_t}", + f"pcoupl = {self.pcoupl}", + f"pcoupltype = {self.pcoupltype} ; uniform scaling of box vectors", + f"tau-p = {self.tau_p} ; time constant (ps)", + f"ref-p = {self.ref_p} ; reference pressure (bar)", + f"compressibility = {self.compressibility} ; isothermal compressibility of water (bar^-1)", ] ) if self.refcoord_scaling: mdp_lines.append(f"refcoord-scaling = {self.refcoord_scaling}") mdp_lines.append("") - + # Velocity Generation (only for non-EM stages) if self.mdp_type != "em": mdp_lines.extend( [ "; VELOCITY GENERATION", - ";----------------------------------------------------", - f"gen_vel = {self.gen_vel} ; Velocity generation is {'on' if self.gen_vel == 'yes' else 'off'}", + ";----------------------------------------------------", + f"gen_vel = {self.gen_vel} ; Velocity generation is {'on' if self.gen_vel == 'yes' else 'off'}", f"gen-seed = {self.gen_seed} ; Use random seed", f"gen-temp = {self.gen_temp}", "", ] ) - + # Free Energy if self.free_energy == "yes": mdp_lines.extend( @@ -1056,38 +1080,38 @@ def write_config( f"init-lambda-state = {'' if self.init_lambda_state is None else self.init_lambda_state}", ] ) - + if self.bonded_lambdas: mdp_lines.append( f"bonded-lambdas = {' '.join(str(x) for x in self.bonded_lambdas)}" ) - + if self.coul_lambdas: mdp_lines.append( f"coul-lambdas = {' '.join(str(x) for x in self.coul_lambdas)}" ) - + if self.vdw_lambdas: mdp_lines.append( f"vdw-lambdas = {' '.join(str(x) for x in self.vdw_lambdas)}" ) - + mdp_lines.extend( [ - f"nstdhdl = {self.nstdhdl}", - f"dhdl-print-energy = {self.dhdl_print_energy}", - f"calc-lambda-neighbors = {self.calc_lambda_neighbors}", - f"separate-dhdl-file = {self.separate_dhdl_file}", - f"couple-intramol = {self.couple_intramol}", + f"nstdhdl = {self.nstdhdl}", + f"dhdl-print-energy = {self.dhdl_print_energy}", + f"calc-lambda-neighbors = {self.calc_lambda_neighbors}", + f"separate-dhdl-file = {self.separate_dhdl_file}", + f"couple-intramol = {self.couple_intramol}", ] ) - + # Extra options if self.extra_options: mdp_lines.append("") for key, value in self.extra_options.items(): mdp_lines.append(f"{key} = {value}") - + # Write file config_path = _os.path.join(run_dir, self.get_file_name()) with open(config_path, "w") as f: diff --git a/a3fe/configuration/system_prep_config.py b/a3fe/configuration/system_prep_config.py index 1cba3fa0..20d6dfc3 100644 --- a/a3fe/configuration/system_prep_config.py +++ b/a3fe/configuration/system_prep_config.py @@ -256,24 +256,13 @@ class GromacsSystemPreparationConfig(_BaseSystemPreparationConfig): lambda_values: _Dict[_LegType, _Dict[_StageType, _List[float]]] = _Field( default={ _LegType.BOUND: { - _StageType.RESTRAIN: [0.0, 0.05, 0.15, 0.5, 0.75, 1.0], - _StageType.DISCHARGE: [0.0, 0.3, 0.6, 0.9, 1.0], - _StageType.VANISH: [ - 0.0, - 0.05, - 0.2, - 0.25, - 0.4, - 0.5, - 0.65, - 0.8, - 0.9, - 1.0, - ], + _StageType.RESTRAIN: [0.0, 1.0], + _StageType.DISCHARGE: [0.0, 0.16, 0.33, 0.5, 0.67, 0.83, 1.0], + _StageType.VANISH: [ 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0], }, _LegType.FREE: { - _StageType.DISCHARGE: [0.0, 0.2, 0.4, 0.5, 0.7, 0.9, 1.0], - _StageType.VANISH: [0.0, 0.2, 0.35, 0.4, 0.55, 0.7, 0.85, 1.0], + _StageType.DISCHARGE: [0.0, 0.16, 0.33, 0.5, 0.67, 0.83, 1.0], + _StageType.VANISH: [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0], }, }, description="Lambda values optimized for GROMACS.", diff --git a/a3fe/read/_read_exp_dgs.py b/a3fe/read/_read_exp_dgs.py index 6677f11a..5934cb55 100644 --- a/a3fe/read/_read_exp_dgs.py +++ b/a3fe/read/_read_exp_dgs.py @@ -2,11 +2,10 @@ This must have the columns: calc_base_dir, name, exp_dg, exp_er""" import os as _os - from typing import Optional as _Optional -import pandas as _pd import numpy as _np +import pandas as _pd def read_exp_dgs( diff --git a/a3fe/run/_simulation_runner.py b/a3fe/run/_simulation_runner.py index c56714eb..bd7ad06e 100644 --- a/a3fe/run/_simulation_runner.py +++ b/a3fe/run/_simulation_runner.py @@ -22,15 +22,14 @@ import pandas as _pd import scipy.stats as _stats +from .._version import __version__ as _version from ..analyse.exceptions import AnalysisError as _AnalysisError from ..analyse.plot import plot_convergence as _plot_convergence from ..analyse.plot import plot_sq_sem_convergence as _plot_sq_sem_convergence -from ._logging_formatters import _A3feFileFormatter, _A3feStreamFormatter - -from ..configuration import SlurmConfig as _SlurmConfig from ..configuration import EngineType as _EngineType +from ..configuration import SlurmConfig as _SlurmConfig from ..configuration import _EngineConfig -from .._version import __version__ as _version +from ._logging_formatters import _A3feFileFormatter, _A3feStreamFormatter class SimulationRunner(ABC): diff --git a/a3fe/run/_virtual_queue.py b/a3fe/run/_virtual_queue.py index be3a4ffa..f9d429c5 100644 --- a/a3fe/run/_virtual_queue.py +++ b/a3fe/run/_virtual_queue.py @@ -10,9 +10,9 @@ from typing import List as _List from typing import Optional as _Optional +from ..configuration.enums import JobStatus as _JobStatus from ._logging_formatters import _A3feFileFormatter, _A3feStreamFormatter from ._utils import retry as _retry -from ..configuration.enums import JobStatus as _JobStatus @_dataclass diff --git a/a3fe/run/calc_set.py b/a3fe/run/calc_set.py index ededb7da..29fb966b 100644 --- a/a3fe/run/calc_set.py +++ b/a3fe/run/calc_set.py @@ -12,12 +12,10 @@ import numpy as _np from scipy import stats as _stats -from ..configuration import SlurmConfig as _SlurmConfig -from ..configuration import _EngineConfig -from ..configuration import _BaseSystemPreparationConfig - from ..analyse.analyse_set import compute_stats as _compute_stats from ..analyse.plot import plot_against_exp as _plt_against_exp +from ..configuration import SlurmConfig as _SlurmConfig +from ..configuration import _BaseSystemPreparationConfig, _EngineConfig from ..read._read_exp_dgs import read_exp_dgs as _read_exp_dgs from ._simulation_runner import SimulationRunner as _SimulationRunner from ._utils import SimulationRunnerIterator as _SimulationRunnerIterator diff --git a/a3fe/run/calculation.py b/a3fe/run/calculation.py index e8e101a7..d6acb7e7 100644 --- a/a3fe/run/calculation.py +++ b/a3fe/run/calculation.py @@ -6,17 +6,16 @@ import logging as _logging import os as _os import time as _time +from pathlib import Path as _Path from typing import List as _List from typing import Optional as _Optional -from pathlib import Path as _Path -from ._simulation_runner import SimulationRunner as _SimulationRunner +from ..configuration import EngineType as _EngineType +from ..configuration import SlurmConfig as _SlurmConfig +from ..configuration import _BaseSystemPreparationConfig, _EngineConfig from ..configuration.enums import PreparationStage as _PreparationStage +from ._simulation_runner import SimulationRunner as _SimulationRunner from .leg import Leg as _Leg -from ..configuration import _BaseSystemPreparationConfig -from ..configuration import SlurmConfig as _SlurmConfig -from ..configuration import _EngineConfig -from ..configuration import EngineType as _EngineType class Calculation(_SimulationRunner): diff --git a/a3fe/run/lambda_window.py b/a3fe/run/lambda_window.py index 87150768..4fa6712a 100644 --- a/a3fe/run/lambda_window.py +++ b/a3fe/run/lambda_window.py @@ -17,12 +17,12 @@ from ..analyse.detect_equil import ( dummy_check_equil_multiwindow as _dummy_check_equil_multiwindow, ) +from ..configuration import EngineType as _EngineType +from ..configuration import SlurmConfig as _SlurmConfig +from ..configuration import _EngineConfig from ._simulation_runner import SimulationRunner as _SimulationRunner from ._virtual_queue import VirtualQueue as _VirtualQueue from .simulation import Simulation as _Simulation -from ..configuration import SlurmConfig as _SlurmConfig -from ..configuration import _EngineConfig -from ..configuration import EngineType as _EngineType class LamWindow(_SimulationRunner): @@ -350,36 +350,53 @@ def _write_equilibrated_simfiles(self) -> None: ) # Get the index of the first equilibrated data point - # Minus 1 because first energy is only written after the first nrg_freq steps - equil_index = ( - int( - self._equil_time - / ( - self.sims[0].engine_config.timestep - * self.sims[0].engine_config.energy_frequency - ) - ) - - 1 # type: ignore + config = self.sims[0].engine_config + + if self.sims[0].engine_type == _EngineType.GROMACS: + # GROMACS: dt in ps, nstdhdl is steps (dH/dlambda output frequency), convert to ns + # First energy is written at time 0, so no offset needed + time_per_energy = config.dt * config.nstdhdl / 1000 + equil_index = int(self._equil_time / time_per_energy) + else: + # SOMD: timestep in fs, energy_frequency is steps, convert to ns + # First energy is only written after the first nrg_freq steps, so subtract 1 + time_per_energy = config.timestep * config.energy_frequency / 1_000_000 + equil_index = int(self._equil_time / time_per_energy) - 1 + + if equil_index < 0: + raise ValueError( + f"Equilibration time ({self._equil_time:.3f} ns) is too short. " + f"Must be at least {time_per_energy:.6f} ns (one energy output interval)." ) # Write the equilibrated data for each simulation for sim in self.sims: - in_simfile = sim.output_dir + "/simfile.dat" - out_simfile = sim.output_dir + "/simfile_equilibrated.dat" + # Set file paths based on engine type + if sim.engine_type == _EngineType.GROMACS: + in_file = sim.output_dir + "/prod/prod.xvg" + out_file = sim.output_dir + "/prod/prod_equilibrated.xvg" + header_chars = ("#", "@") + else: + in_file = sim.output_dir + "/simfile.dat" + out_file = sim.output_dir + "/simfile_equilibrated.dat" + header_chars = ("#",) + + if not _os.path.exists(in_file): + continue - with open(in_simfile, "r") as ifile: + with open(in_file, "r") as ifile: lines = ifile.readlines() # Figure out how many lines come before the data non_data_lines = 0 for line in lines: - if line.startswith("#"): + if line.startswith(header_chars): non_data_lines += 1 else: break # Overwrite the original file with one containing only the equilibrated data - with open(out_simfile, "w") as ofile: + with open(out_file, "w") as ofile: # First, write the header for line in lines[:non_data_lines]: ofile.write(line) diff --git a/a3fe/run/simulation.py b/a3fe/run/simulation.py index 4fccbfeb..64834a6c 100644 --- a/a3fe/run/simulation.py +++ b/a3fe/run/simulation.py @@ -183,23 +183,20 @@ def running(self) -> bool: elif self.job.status == "FAILED": old_job = self.job self._logger.info(f"{old_job} failed - resubmitting") - # Move log files and s3 files so that the job does not restart - _subprocess.run(["mkdir", f"{self.output_dir}/failure"]) - for s3_file in _glob.glob(f"{self.output_dir}/*.s3"): - _subprocess.run( - [ - "mv", - f"{self.output_dir}/{s3_file}", - f"{self.output_dir}/failure", - ] - ) - _subprocess.run( - [ - "mv", - old_job.slurm_outfile, - f"{self.output_dir}/failure", - ] - ) + # Move log files and checkpoint files so that the job does not restart + _subprocess.run(["mkdir", "-p", f"{self.output_dir}/failure"]) + + if self.engine_type == _EngineType.GROMACS: + # Move GROMACS checkpoint files + for cpt_file in _glob.glob(f"{self.output_dir}/**/*.cpt", recursive=True): + _subprocess.run(["mv", cpt_file, f"{self.output_dir}/failure"]) + else: # SOMD + # Move SOMD s3 files + for s3_file in _glob.glob(f"{self.output_dir}/*.s3"): + _subprocess.run(["mv", s3_file, f"{self.output_dir}/failure"]) + + _subprocess.run(["mv", old_job.slurm_outfile, f"{self.output_dir}/failure"]) + # Now resubmit cmd_list = old_job.command_list self.job = self.virtual_queue.submit( @@ -349,28 +346,35 @@ def get_tot_simtime(self) -> float: tot_simtime : float Total simulation time in ns. """ - data_simfile = f"{self.output_dir}/simfile.dat" - if not _pathlib.Path(data_simfile).is_file(): - # Simuation has not been run, hence total simulation time is 0 - return 0 - elif _os.stat(data_simfile).st_size == 0: - # Simfile is empty, hence total simulation time is 0 - return 0 - else: - # Read last line of simfile with subprocess to make as fast as possible - step = int( - _subprocess.check_output( - [ - "tail", - "-1", - f"{self.output_dir}/simfile.dat", - ] + if self.engine_type == _EngineType.GROMACS: + data_file = f"{self.output_dir}/prod/prod.xvg" + if not _pathlib.Path(data_file).is_file(): + return 0 + elif _os.stat(data_file).st_size == 0: + return 0 + else: + # Read last non-comment line of xvg file + last_line = _subprocess.check_output( + ["grep", "-v", "^[#@]", data_file] + ).decode("utf-8").strip().split("\n")[-1] + time_ps = float(last_line.split()[0]) + return time_ps / 1000.0 # ps to ns + else: # SOMD + data_simfile = f"{self.output_dir}/simfile.dat" + if not _pathlib.Path(data_simfile).is_file(): + return 0 + elif _os.stat(data_simfile).st_size == 0: + return 0 + else: + step = int( + _subprocess.check_output( + ["tail", "-1", f"{self.output_dir}/simfile.dat"] + ) + .decode("utf-8") + .strip() + .split()[0] ) - .decode("utf-8") - .strip() - .split()[0] - ) - return step * (self.engine_config.timestep / 1_000_000) # ns + return step * (self.engine_config.timestep / 1_000_000) # ns def get_tot_gpu_time(self) -> float: """ @@ -390,11 +394,28 @@ def get_tot_gpu_time(self) -> float: # Otherwise, add up the simulation time in seconds tot_gpu_time = 0 - for file in slurm_output_files: - with open(file, "rt") as file: - for line in file.readlines(): - if line.startswith("Simulation took"): - tot_gpu_time += float(line.split(" ")[2]) + + if self.engine_type == _EngineType.GROMACS: + # GROMACS: look for "Time:" in performance summary + # Format: " Time: 2496.669 156.055 1599.9" + # We want the Wall time (second number, in seconds) + for file in slurm_output_files: + with open(file, "rt") as f: + for line in f.readlines(): + if line.strip().startswith("Time:"): + try: + parts = line.split() + wall_time = float(parts[2]) # Wall time in seconds + tot_gpu_time += wall_time + except (IndexError, ValueError): + continue + else: # SOMD + # SOMD: look for "Simulation took" + for file in slurm_output_files: + with open(file, "rt") as f: + for line in f.readlines(): + if line.startswith("Simulation took"): + tot_gpu_time += float(line.split(" ")[2]) # And convert to GPU hours return tot_gpu_time / 3600 @@ -421,21 +442,35 @@ def failed(self) -> bool: # "Simulation took" line if self.slurm_output_files: for file in self.slurm_output_files: - with open(file, "rt") as file: - failed = True - for line in file.readlines(): - if line.startswith("Simulation took"): - # File shows success, so continue to next file - failed = False - break - # We haven't found "Simulation took" in this file, indicating failure - if failed: - return True - - # Either We aren't running and have output files, all with the "Simulation took" line, - # or we aren't running and have no output files - either way, we haven't failed + if not self._check_simulation_success(file): + return True + return False + def _check_simulation_success(self, slurm_file: str) -> bool: + """ + Check if simulation completed successfully based on engine type. + + Parameters + ---------- + slurm_file : str + Path to SLURM output file + + Returns + ------- + success : bool + True if simulation succeeded, False otherwise + """ + with open(slurm_file, "rt") as f: + content = f.read() + + if self.engine_type == _EngineType.SOMD: + # SOMD success: "Simulation took" line in SLURM output + return "Simulation took" in content + else: # GROMACS + # GROMACS success: Performance line appears only when stage completes + return "Performance:" in content + @property def slurm_output_files(self) -> _List[str]: """Get a list of all slurm output files for this simulation.""" @@ -452,25 +487,20 @@ def kill(self) -> None: def lighten(self) -> None: """Lighten the simulation by deleting all restart and trajectory files.""" - delete_files = [ - "*.dcd", - "*.s3", - "*.s3.previous", - "gradients.s3", - "simfile_equilibrated.dat", - "latest.pdb", - ] - - for del_file in delete_files: - # Delete files in base directory - for file in _pathlib.Path(self.base_dir).glob(del_file): - self._logger.info(f"Deleting {file}") - _subprocess.run(["rm", file]) - - # Delete files in output directory - for file in _pathlib.Path(self.output_dir).glob(del_file): - self._logger.info(f"Deleting {file}") - _subprocess.run(["rm", file]) + if self.engine_type == _EngineType.GROMACS: + patterns = ["*.xtc", "*.trr", "*.cpt", "*_equilibrated.xvg"] + use_recursive = True + else: # SOMD + patterns = ["*.dcd", "*.s3", "*.s3.previous", "gradients.s3", + "simfile_equilibrated.dat", "latest.pdb"] + use_recursive = False + + for directory in [self.base_dir, self.output_dir]: + for pattern in patterns: + glob_func = _pathlib.Path(directory).rglob if use_recursive else _pathlib.Path(directory).glob + for file in glob_func(pattern): + self._logger.info(f"Deleting {file}") + _subprocess.run(["rm", str(file)]) def read_gradients( self, equilibrated_only: bool = False, endstate: bool = False @@ -496,7 +526,32 @@ def read_gradients( grads : np.ndarray Array of gradients, in kcal/mol. """ - # Read the output file + # GROMACS: read .xvg file + if self.engine_type == _EngineType.GROMACS: + filename = "prod/prod_equilibrated.xvg" if equilibrated_only else "prod/prod.xvg" + times = [] + grads = [] + + with open(_os.path.join(self.output_dir, filename), "r") as f: + for line in f: + if line.startswith(('#', '@')) or not line.strip(): + continue + vals = line.split() + time_ps = float(vals[0]) + + if endstate: + energy_start = float(vals[3]) + energy_end = float(vals[-2]) + grad_kj = energy_end - energy_start + else: + grad_kj = float(vals[1]) + + times.append(time_ps / 1000.0) + grads.append(grad_kj / 4.184) + + return _np.array(times), _np.array(grads) + + # SOMD: read simfile.dat if equilibrated_only: with open( _os.path.join(self.output_dir, "simfile_equilibrated.dat"), "r" diff --git a/a3fe/run/stage.py b/a3fe/run/stage.py index 2438470f..f735bfe2 100644 --- a/a3fe/run/stage.py +++ b/a3fe/run/stage.py @@ -822,36 +822,37 @@ def analyse( tmp_files=tmp_files, ) - mean_free_energy = _np.mean(free_energies) - # Gaussian 95 % C.I. - conf_int = ( - _stats.t.interval( - 0.95, - len(free_energies) - 1, - mean_free_energy, - scale=_stats.sem(free_energies), - )[1] - - mean_free_energy - ) # 95 % C.I. - - # Write overall MBAR stats to file - with open(f"{self.output_dir}/overall_stats.dat", "a") as ofile: - if get_frnrg: - ofile.write( - "###################################### Free Energies ########################################\n" - ) - ofile.write( - f"Mean free energy: {mean_free_energy: .3f} + /- {conf_int:.3f} kcal/mol\n" - ) - for i in range(len(free_energies)): - ofile.write( - f"Free energy from run {i + 1}: {free_energies[i]: .3f} +/- {errors[i]:.3f} kcal/mol\n" - ) + mean_free_energy = _np.mean(free_energies) + # Gaussian 95 % C.I. + conf_int = ( + _stats.t.interval( + 0.95, + len(free_energies) - 1, + mean_free_energy, + scale=_stats.sem(free_energies), + )[1] + - mean_free_energy + ) # 95 % C.I. + + # Write overall MBAR stats to file + with open(f"{self.output_dir}/overall_stats.dat", "a") as ofile: + if get_frnrg: + ofile.write( + "###################################### Free Energies ########################################\n" + ) + ofile.write( + f"Mean free energy: {mean_free_energy: .3f} + /- {conf_int:.3f} kcal/mol\n" + ) + for i in range(len(free_energies)): ofile.write( - "Errors are 95 % C.I.s based on the assumption of a Gaussian distribution of free energies\n" + f"Free energy from run {i + 1}: {free_energies[i]: .3f} +/- {errors[i]:.3f} kcal/mol\n" ) - ofile.write(f"Runs analysed: {run_nos}\n") + ofile.write( + "Errors are 95 % C.I.s based on the assumption of a Gaussian distribution of free energies\n" + ) + ofile.write(f"Runs analysed: {run_nos}\n") + if get_frnrg: # Plot overlap matrices and PMFs _plot_overlap_mats( output_dir=self.output_dir, From 2c00347a96e0e28256d67db063183b12ae915d69 Mon Sep 17 00:00:00 2001 From: Roy Haolin Du Date: Mon, 8 Dec 2025 10:15:11 +0000 Subject: [PATCH 3/3] Format code with ruff --- a3fe/configuration/engine_config.py | 252 +++++++++--------- a3fe/configuration/system_prep_config.py | 38 ++- a3fe/run/lambda_window.py | 2 +- a3fe/run/simulation.py | 65 +++-- a3fe/tests/test_calc_set.py | 3 +- a3fe/tests/test_engine_configuration.py | 3 +- a3fe/tests/test_run.py | 1 - a3fe/tests/test_run_integration.py | 10 +- a3fe/tests/test_slurm_configuration.py | 5 +- .../tests/test_update_engine_config_option.py | 5 +- 10 files changed, 217 insertions(+), 167 deletions(-) diff --git a/a3fe/configuration/engine_config.py b/a3fe/configuration/engine_config.py index c210f673..dead8577 100644 --- a/a3fe/configuration/engine_config.py +++ b/a3fe/configuration/engine_config.py @@ -492,7 +492,7 @@ class GromacsConfig(_EngineConfig): nstlog: int = _Field(500, description="Update log file") nstenergy: int = _Field(500, description="Save energies") nstcalcenergy: int = _Field(50, description="Calculate energies") - + ### Bonds ### constraint_algorithm: _Literal["lincs", "shake"] = _Field( "lincs", description="Constraint algorithm" @@ -504,7 +504,7 @@ class GromacsConfig(_EngineConfig): lincs_order: int = _Field(6, description="LINCS order") lincs_warnangle: int = _Field(30, description="LINCS warning angle") continuation: _Literal["yes", "no"] = _Field("yes", description="Continuation") - + ### Neighbor Searching ### cutoff_scheme: _Literal["Verlet", "group"] = _Field( "Verlet", description="Cutoff scheme" @@ -512,7 +512,7 @@ class GromacsConfig(_EngineConfig): ns_type: _Literal["grid", "simple"] = _Field("grid", description="Neighbor search") nstlist: int = _Field(20, description="Update neighbor list") rlist: float = _Field(1.2, description="Neighbor list cutoff (nm)") - + ### Electrostatics ### coulombtype: _Literal["PME", "Cut-off"] = _Field("PME", description="Coulomb type") rcoulomb: float = _Field(1.0, description="Coulomb cutoff (nm)") @@ -520,7 +520,7 @@ class GromacsConfig(_EngineConfig): pme_order: int = _Field(4, description="PME order") fourierspacing: float = _Field(0.10, description="PME grid spacing (nm)") ewald_rtol: float = _Field(1e-6, description="Ewald tolerance") - + ### VDW ### vdwtype: _Literal["Cut-off", "PME"] = _Field("Cut-off", description="VdW type") vdw_modifier: _Literal["Potential-shift-Verlet", "None"] = _Field( @@ -533,13 +533,13 @@ class GromacsConfig(_EngineConfig): DispCorr: _Literal["EnerPres", "Ener", "no"] = _Field( "EnerPres", description="Long range dispersion corrections" ) - + ### Temperature Coupling ### tcoupl: _Literal["no", "yes"] = _Field("no", description="Temperature coupling") tc_grps: str = _Field("System", description="Temperature coupling groups") tau_t: float = _Field(2.0, description="Time constant for T-coupling (ps)") ref_t: float = _Field(298.15, description="Reference temperature (K)") - + ### Pressure Coupling ### pcoupl: _Literal["no", "Berendsen", "C-rescale", "Parrinello-Rahman"] = _Field( "Parrinello-Rahman", description="Pressure coupling" @@ -553,7 +553,7 @@ class GromacsConfig(_EngineConfig): refcoord_scaling: _Optional[_Literal["all", "com", "no"]] = _Field( None, description="Reference coordinate scaling" ) - + ### Velocity Generation ### gen_vel: _Literal["yes", "no"] = _Field("no", description="Generate velocities") gen_seed: int = _Field(-1, description="Random seed") @@ -572,7 +572,7 @@ class GromacsConfig(_EngineConfig): None, description="Boresch restraints dictionary content (interface compatibility, handled in topology)", ) - + ### Free Energy ### perturbed_residue_number: int = _Field( 1, @@ -631,7 +631,7 @@ class GromacsConfig(_EngineConfig): couple_intramol: _Literal["yes", "no"] = _Field( "yes", description="Couple intramolecular" ) - + ### Extra options ### extra_options: _Dict[str, str] = _Field( default_factory=dict, description="Extra options" @@ -649,7 +649,7 @@ def timestep(self) -> float: def setup_lambda_arrays(self, stage_type) -> None: """ Set up GROMACS-specific bonded/coul/vdw lambda arrays. - + Parameters ---------- stage_type : StageType @@ -708,25 +708,25 @@ def _configure_for_mdp_type(self) -> None: self.nstcomm = 100 self.nstxout = 250 self.nstlist = 1 - + elif self.mdp_type == "nvt": self.nsteps = 5000 # 10 ps self.continuation = "no" self.gen_vel = "yes" self.pcoupl = "no" self.nstxout = 25000 - + elif self.mdp_type == "npt": self.nsteps = 50000 # 100 ps self.pcoupl = "C-rescale" # GROMACS 2025 self.tau_p = 1.0 self.refcoord_scaling = "all" self.nstxout = 25000 - + elif self.mdp_type == "npt-norest": self.nsteps = 250000 # 500 ps self.nstxout = 25000 - + else: # prod self.nsteps = 2500000 # 5 ns (will be overridden by runtime) self.nstxout = 0 @@ -745,7 +745,7 @@ def write_config( """ # Configure based on type self._configure_for_mdp_type() - + # Override nsteps for prod based on runtime if self.mdp_type == "prod": runtime_ps = runtime * 1000 @@ -754,7 +754,7 @@ def write_config( # Find lambda state index from the active lambda array # Priority: find array containing lambda_val, otherwise use varying array lambda_array = None - + # First, try to find array containing lambda_val if self.vdw_lambdas and lambda_val in self.vdw_lambdas: lambda_array = self.vdw_lambdas @@ -762,7 +762,7 @@ def write_config( lambda_array = self.coul_lambdas elif self.bonded_lambdas and lambda_val in self.bonded_lambdas: lambda_array = self.bonded_lambdas - + # If not found, use varying array (not all values are the same) if lambda_array is None: if self.vdw_lambdas and len(set(self.vdw_lambdas)) > 1: @@ -771,7 +771,7 @@ def write_config( lambda_array = self.coul_lambdas elif self.bonded_lambdas and len(set(self.bonded_lambdas)) > 1: lambda_array = self.bonded_lambdas - + if lambda_array is None: raise ValueError( f"Lambda {lambda_val} not found in any lambda array. " @@ -779,7 +779,7 @@ def write_config( f"coul: {self.coul_lambdas}, " f"vdw: {self.vdw_lambdas}" ) - + try: self.init_lambda_state = lambda_array.index(lambda_val) except ValueError: @@ -792,47 +792,47 @@ def write_config( ";====================================================", "", ] - + # Run Control if self.mdp_type == "em": mdp_lines.extend( [ - ";----------------------------------------------------", - "; RUN CONTROL & MINIMIZATION", - ";----------------------------------------------------", + ";----------------------------------------------------", + "; RUN CONTROL & MINIMIZATION", + ";----------------------------------------------------", ] ) if self.define: mdp_lines.append(f"define = {self.define}") mdp_lines.extend( [ - f"integrator = {self.integrator}", - f"nsteps = {self.nsteps}", - f"emtol = {self.emtol}", - f"emstep = {self.emstep}", - f"nstcomm = {self.nstcomm}", - f"pbc = {self.pbc}", + f"integrator = {self.integrator}", + f"nsteps = {self.nsteps}", + f"emtol = {self.emtol}", + f"emstep = {self.emstep}", + f"nstcomm = {self.nstcomm}", + f"pbc = {self.pbc}", ] ) else: mdp_lines.extend( [ - "; RUN CONTROL", - ";----------------------------------------------------", + "; RUN CONTROL", + ";----------------------------------------------------", ] ) if self.define: mdp_lines.append(f"define = {self.define}") mdp_lines.extend( [ - f"integrator = {self.integrator:<13} ; langevin integrator", - f"nsteps = {self.nsteps:<13} ; {self.dt} * {self.nsteps} fs = {self.nsteps * self.dt * 0.001:.0f} ps", - f"dt = {self.dt:<13} ; {self.dt * 1000:.0f} fs", - f"comm-mode = {self.comm_mode:<13} ; remove center of mass translation", - f"nstcomm = {self.nstcomm:<13} ; frequency for center of mass motion removal", + f"integrator = {self.integrator:<13} ; langevin integrator", + f"nsteps = {self.nsteps:<13} ; {self.dt} * {self.nsteps} fs = {self.nsteps * self.dt * 0.001:.0f} ps", + f"dt = {self.dt:<13} ; {self.dt * 1000:.0f} fs", + f"comm-mode = {self.comm_mode:<13} ; remove center of mass translation", + f"nstcomm = {self.nstcomm:<13} ; frequency for center of mass motion removal", ] ) - + # Output Control mdp_lines.extend( [ @@ -849,30 +849,30 @@ def write_config( if self.nstxout > 0 else "don't save coordinates to .trr" ), - f"nstvout = {self.nstvout:<10} ; don't save velocities to .trr", - f"nstfout = {self.nstfout:<10} ; don't save forces to .trr", - "", - f"nstxout-compressed = {self.nstxout_compressed:<10} ; xtc compressed trajectory output every {self.nstxout_compressed} steps", - f"compressed-x-precision = {self.compressed_x_precision}", - f"nstlog = {self.nstlog:<10} ; update log file every {self.nstlog} steps", - f"nstenergy = {self.nstenergy:<10} ; save energies every {self.nstenergy} steps", - f"nstcalcenergy = {self.nstcalcenergy}", - "", + f"nstvout = {self.nstvout:<10} ; don't save velocities to .trr", + f"nstfout = {self.nstfout:<10} ; don't save forces to .trr", + "", + f"nstxout-compressed = {self.nstxout_compressed:<10} ; xtc compressed trajectory output every {self.nstxout_compressed} steps", + f"compressed-x-precision = {self.compressed_x_precision}", + f"nstlog = {self.nstlog:<10} ; update log file every {self.nstlog} steps", + f"nstenergy = {self.nstenergy:<10} ; save energies every {self.nstenergy} steps", + f"nstcalcenergy = {self.nstcalcenergy}", + "", ] ) - + # Neighbor Searching if self.mdp_type == "em": mdp_lines.extend( [ - ";----------------------------------------------------", - "; NEIGHBOR SEARCHING", - ";----------------------------------------------------", - f"cutoff-scheme = {self.cutoff_scheme}", - f"ns-type = {self.ns_type}", - f"nstlist = {self.nstlist}", - f"rlist = {self.rlist}", - "", + ";----------------------------------------------------", + "; NEIGHBOR SEARCHING", + ";----------------------------------------------------", + f"cutoff-scheme = {self.cutoff_scheme}", + f"ns-type = {self.ns_type}", + f"nstlist = {self.nstlist}", + f"rlist = {self.rlist}", + "", ] ) else: @@ -881,16 +881,16 @@ def write_config( "; NEIGHBOR SEARCHING" if self.mdp_type == "nvt" else ";----------------------------------------------------", - ";----------------------------------------------------", - f"cutoff-scheme = {self.cutoff_scheme}", - f"ns-type = {self.ns_type:<6} ; search neighboring grid cells", - f"nstlist = {self.nstlist:<6} ; {self.nstlist * self.dt * 1000:.0f} fs", - f"rlist = {self.rlist:<6} ; short-range neighborlist cutoff (in nm)", - f"pbc = {self.pbc:<6} ; 3D PBC", - "", + ";----------------------------------------------------", + f"cutoff-scheme = {self.cutoff_scheme}", + f"ns-type = {self.ns_type:<6} ; search neighboring grid cells", + f"nstlist = {self.nstlist:<6} ; {self.nstlist * self.dt * 1000:.0f} fs", + f"rlist = {self.rlist:<6} ; short-range neighborlist cutoff (in nm)", + f"pbc = {self.pbc:<6} ; 3D PBC", + "", ] ) - + # Bonds (skip for EM) if self.mdp_type != "em": bonds_header = ( @@ -922,20 +922,20 @@ def write_config( bonds_section.extend([f"continuation = {self.continuation}", ""]) mdp_lines.extend(bonds_section) - + # Electrostatics if self.mdp_type == "em": mdp_lines.extend( [ - ";----------------------------------------------------", - "; ELECTROSTATICS", - ";----------------------------------------------------", - f"coulombtype = {self.coulombtype}", - f"rcoulomb = {self.rcoulomb}", - f"pme-order = {self.pme_order}", - f"fourierspacing = {self.fourierspacing}", - f"ewald-rtol = {self.ewald_rtol}", - "", + ";----------------------------------------------------", + "; ELECTROSTATICS", + ";----------------------------------------------------", + f"coulombtype = {self.coulombtype}", + f"rcoulomb = {self.rcoulomb}", + f"pme-order = {self.pme_order}", + f"fourierspacing = {self.fourierspacing}", + f"ewald-rtol = {self.ewald_rtol}", + "", ] ) else: @@ -947,17 +947,17 @@ def write_config( mdp_lines.extend( [ elec_header, - ";----------------------------------------------------", - f"coulombtype = {self.coulombtype:<6} ; Particle Mesh Ewald for long-range electrostatics", - f"rcoulomb = {self.rcoulomb:<6} ; short-range electrostatic cutoff (in nm)", - f"ewald_geometry = {self.ewald_geometry:<6} ; Ewald sum is performed in all three dimensions", - f"pme-order = {self.pme_order:<6} ; interpolation order for PME (default is 4)", - f"fourierspacing = {self.fourierspacing:<6} ; grid spacing for FFT", - f"ewald-rtol = {self.ewald_rtol:<6} ; relative strength of the Ewald-shifted direct potential at rcoulomb", - "", + ";----------------------------------------------------", + f"coulombtype = {self.coulombtype:<6} ; Particle Mesh Ewald for long-range electrostatics", + f"rcoulomb = {self.rcoulomb:<6} ; short-range electrostatic cutoff (in nm)", + f"ewald_geometry = {self.ewald_geometry:<6} ; Ewald sum is performed in all three dimensions", + f"pme-order = {self.pme_order:<6} ; interpolation order for PME (default is 4)", + f"fourierspacing = {self.fourierspacing:<6} ; grid spacing for FFT", + f"ewald-rtol = {self.ewald_rtol:<6} ; relative strength of the Ewald-shifted direct potential at rcoulomb", + "", ] ) - + # VDW if self.mdp_type == "em": vdw_header = [ @@ -1002,68 +1002,68 @@ def write_config( "", ] ) - + # Temperature & Pressure Coupling if self.mdp_type == "em": mdp_lines.extend( [ - ";----------------------------------------------------", - "; TEMPERATURE & PRESSURE COUPL", - ";----------------------------------------------------", - f"tcoupl = {self.tcoupl}", - f"pcoupl = {self.pcoupl}", - f"gen-vel = {self.gen_vel}", - "", + ";----------------------------------------------------", + "; TEMPERATURE & PRESSURE COUPL", + ";----------------------------------------------------", + f"tcoupl = {self.tcoupl}", + f"pcoupl = {self.pcoupl}", + f"gen-vel = {self.gen_vel}", + "", ] ) elif self.mdp_type == "nvt": mdp_lines.extend( [ - "; TEMPERATURE COUPLING", - ";----------------------------------------------------", - f"tc-grps = {self.tc_grps}", - f"tau-t = {self.tau_t}", - f"ref-t = {self.ref_t}", - "", - "; PRESSURE COUPLING", - ";----------------------------------------------------", - f"pcoupl = {self.pcoupl}", - "", + "; TEMPERATURE COUPLING", + ";----------------------------------------------------", + f"tc-grps = {self.tc_grps}", + f"tau-t = {self.tau_t}", + f"ref-t = {self.ref_t}", + "", + "; PRESSURE COUPLING", + ";----------------------------------------------------", + f"pcoupl = {self.pcoupl}", + "", ] ) else: # npt, npt-norest, prod mdp_lines.extend( [ - ";----------------------------------------------------", - "; TEMPERATURE & PRESSURE COUPL", - ";----------------------------------------------------", - f"tc-grps = {self.tc_grps}", - f"tau-t = {self.tau_t}", - f"ref-t = {self.ref_t}", - f"pcoupl = {self.pcoupl}", - f"pcoupltype = {self.pcoupltype} ; uniform scaling of box vectors", - f"tau-p = {self.tau_p} ; time constant (ps)", - f"ref-p = {self.ref_p} ; reference pressure (bar)", - f"compressibility = {self.compressibility} ; isothermal compressibility of water (bar^-1)", + ";----------------------------------------------------", + "; TEMPERATURE & PRESSURE COUPL", + ";----------------------------------------------------", + f"tc-grps = {self.tc_grps}", + f"tau-t = {self.tau_t}", + f"ref-t = {self.ref_t}", + f"pcoupl = {self.pcoupl}", + f"pcoupltype = {self.pcoupltype} ; uniform scaling of box vectors", + f"tau-p = {self.tau_p} ; time constant (ps)", + f"ref-p = {self.ref_p} ; reference pressure (bar)", + f"compressibility = {self.compressibility} ; isothermal compressibility of water (bar^-1)", ] ) if self.refcoord_scaling: mdp_lines.append(f"refcoord-scaling = {self.refcoord_scaling}") mdp_lines.append("") - + # Velocity Generation (only for non-EM stages) if self.mdp_type != "em": mdp_lines.extend( [ "; VELOCITY GENERATION", - ";----------------------------------------------------", - f"gen_vel = {self.gen_vel} ; Velocity generation is {'on' if self.gen_vel == 'yes' else 'off'}", + ";----------------------------------------------------", + f"gen_vel = {self.gen_vel} ; Velocity generation is {'on' if self.gen_vel == 'yes' else 'off'}", f"gen-seed = {self.gen_seed} ; Use random seed", f"gen-temp = {self.gen_temp}", "", ] ) - + # Free Energy if self.free_energy == "yes": mdp_lines.extend( @@ -1080,38 +1080,38 @@ def write_config( f"init-lambda-state = {'' if self.init_lambda_state is None else self.init_lambda_state}", ] ) - + if self.bonded_lambdas: mdp_lines.append( f"bonded-lambdas = {' '.join(str(x) for x in self.bonded_lambdas)}" ) - + if self.coul_lambdas: mdp_lines.append( f"coul-lambdas = {' '.join(str(x) for x in self.coul_lambdas)}" ) - + if self.vdw_lambdas: mdp_lines.append( f"vdw-lambdas = {' '.join(str(x) for x in self.vdw_lambdas)}" ) - + mdp_lines.extend( [ - f"nstdhdl = {self.nstdhdl}", - f"dhdl-print-energy = {self.dhdl_print_energy}", - f"calc-lambda-neighbors = {self.calc_lambda_neighbors}", - f"separate-dhdl-file = {self.separate_dhdl_file}", - f"couple-intramol = {self.couple_intramol}", + f"nstdhdl = {self.nstdhdl}", + f"dhdl-print-energy = {self.dhdl_print_energy}", + f"calc-lambda-neighbors = {self.calc_lambda_neighbors}", + f"separate-dhdl-file = {self.separate_dhdl_file}", + f"couple-intramol = {self.couple_intramol}", ] ) - + # Extra options if self.extra_options: mdp_lines.append("") for key, value in self.extra_options.items(): mdp_lines.append(f"{key} = {value}") - + # Write file config_path = _os.path.join(run_dir, self.get_file_name()) with open(config_path, "w") as f: diff --git a/a3fe/configuration/system_prep_config.py b/a3fe/configuration/system_prep_config.py index 20d6dfc3..f4ea757c 100644 --- a/a3fe/configuration/system_prep_config.py +++ b/a3fe/configuration/system_prep_config.py @@ -258,11 +258,45 @@ class GromacsSystemPreparationConfig(_BaseSystemPreparationConfig): _LegType.BOUND: { _StageType.RESTRAIN: [0.0, 1.0], _StageType.DISCHARGE: [0.0, 0.16, 0.33, 0.5, 0.67, 0.83, 1.0], - _StageType.VANISH: [ 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0], + _StageType.VANISH: [ + 0.0, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.55, + 0.6, + 0.65, + 0.7, + 0.75, + 0.8, + 0.85, + 0.9, + 0.95, + 1.0, + ], }, _LegType.FREE: { _StageType.DISCHARGE: [0.0, 0.16, 0.33, 0.5, 0.67, 0.83, 1.0], - _StageType.VANISH: [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0], + _StageType.VANISH: [ + 0.0, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.55, + 0.6, + 0.65, + 0.7, + 0.75, + 0.8, + 0.85, + 0.9, + 0.95, + 1.0, + ], }, }, description="Lambda values optimized for GROMACS.", diff --git a/a3fe/run/lambda_window.py b/a3fe/run/lambda_window.py index 4fa6712a..ce1bd0fe 100644 --- a/a3fe/run/lambda_window.py +++ b/a3fe/run/lambda_window.py @@ -367,7 +367,7 @@ def _write_equilibrated_simfiles(self) -> None: raise ValueError( f"Equilibration time ({self._equil_time:.3f} ns) is too short. " f"Must be at least {time_per_energy:.6f} ns (one energy output interval)." - ) + ) # Write the equilibrated data for each simulation for sim in self.sims: diff --git a/a3fe/run/simulation.py b/a3fe/run/simulation.py index 64834a6c..03037568 100644 --- a/a3fe/run/simulation.py +++ b/a3fe/run/simulation.py @@ -185,18 +185,22 @@ def running(self) -> bool: self._logger.info(f"{old_job} failed - resubmitting") # Move log files and checkpoint files so that the job does not restart _subprocess.run(["mkdir", "-p", f"{self.output_dir}/failure"]) - + if self.engine_type == _EngineType.GROMACS: # Move GROMACS checkpoint files - for cpt_file in _glob.glob(f"{self.output_dir}/**/*.cpt", recursive=True): + for cpt_file in _glob.glob( + f"{self.output_dir}/**/*.cpt", recursive=True + ): _subprocess.run(["mv", cpt_file, f"{self.output_dir}/failure"]) else: # SOMD # Move SOMD s3 files for s3_file in _glob.glob(f"{self.output_dir}/*.s3"): _subprocess.run(["mv", s3_file, f"{self.output_dir}/failure"]) - - _subprocess.run(["mv", old_job.slurm_outfile, f"{self.output_dir}/failure"]) - + + _subprocess.run( + ["mv", old_job.slurm_outfile, f"{self.output_dir}/failure"] + ) + # Now resubmit cmd_list = old_job.command_list self.job = self.virtual_queue.submit( @@ -354,9 +358,12 @@ def get_tot_simtime(self) -> float: return 0 else: # Read last non-comment line of xvg file - last_line = _subprocess.check_output( - ["grep", "-v", "^[#@]", data_file] - ).decode("utf-8").strip().split("\n")[-1] + last_line = ( + _subprocess.check_output(["grep", "-v", "^[#@]", data_file]) + .decode("utf-8") + .strip() + .split("\n")[-1] + ) time_ps = float(last_line.split()[0]) return time_ps / 1000.0 # ps to ns else: # SOMD @@ -394,7 +401,7 @@ def get_tot_gpu_time(self) -> float: # Otherwise, add up the simulation time in seconds tot_gpu_time = 0 - + if self.engine_type == _EngineType.GROMACS: # GROMACS: look for "Time:" in performance summary # Format: " Time: 2496.669 156.055 1599.9" @@ -450,12 +457,12 @@ def failed(self) -> bool: def _check_simulation_success(self, slurm_file: str) -> bool: """ Check if simulation completed successfully based on engine type. - + Parameters ---------- slurm_file : str Path to SLURM output file - + Returns ------- success : bool @@ -463,7 +470,7 @@ def _check_simulation_success(self, slurm_file: str) -> bool: """ with open(slurm_file, "rt") as f: content = f.read() - + if self.engine_type == _EngineType.SOMD: # SOMD success: "Simulation took" line in SLURM output return "Simulation took" in content @@ -491,13 +498,23 @@ def lighten(self) -> None: patterns = ["*.xtc", "*.trr", "*.cpt", "*_equilibrated.xvg"] use_recursive = True else: # SOMD - patterns = ["*.dcd", "*.s3", "*.s3.previous", "gradients.s3", - "simfile_equilibrated.dat", "latest.pdb"] + patterns = [ + "*.dcd", + "*.s3", + "*.s3.previous", + "gradients.s3", + "simfile_equilibrated.dat", + "latest.pdb", + ] use_recursive = False - + for directory in [self.base_dir, self.output_dir]: for pattern in patterns: - glob_func = _pathlib.Path(directory).rglob if use_recursive else _pathlib.Path(directory).glob + glob_func = ( + _pathlib.Path(directory).rglob + if use_recursive + else _pathlib.Path(directory).glob + ) for file in glob_func(pattern): self._logger.info(f"Deleting {file}") _subprocess.run(["rm", str(file)]) @@ -528,29 +545,31 @@ def read_gradients( """ # GROMACS: read .xvg file if self.engine_type == _EngineType.GROMACS: - filename = "prod/prod_equilibrated.xvg" if equilibrated_only else "prod/prod.xvg" + filename = ( + "prod/prod_equilibrated.xvg" if equilibrated_only else "prod/prod.xvg" + ) times = [] grads = [] - + with open(_os.path.join(self.output_dir, filename), "r") as f: for line in f: - if line.startswith(('#', '@')) or not line.strip(): + if line.startswith(("#", "@")) or not line.strip(): continue vals = line.split() time_ps = float(vals[0]) - + if endstate: energy_start = float(vals[3]) energy_end = float(vals[-2]) grad_kj = energy_end - energy_start else: grad_kj = float(vals[1]) - + times.append(time_ps / 1000.0) grads.append(grad_kj / 4.184) - + return _np.array(times), _np.array(grads) - + # SOMD: read simfile.dat if equilibrated_only: with open( diff --git a/a3fe/tests/test_calc_set.py b/a3fe/tests/test_calc_set.py index ba818fd0..8a78abdc 100644 --- a/a3fe/tests/test_calc_set.py +++ b/a3fe/tests/test_calc_set.py @@ -2,9 +2,8 @@ import os -import pytest - import pandas as pd +import pytest def test_calc_set_analysis(calc_set): diff --git a/a3fe/tests/test_engine_configuration.py b/a3fe/tests/test_engine_configuration.py index 2c0e770b..705634d8 100644 --- a/a3fe/tests/test_engine_configuration.py +++ b/a3fe/tests/test_engine_configuration.py @@ -1,7 +1,8 @@ """Unit and regression tests for the engine configuration class.""" -from tempfile import TemporaryDirectory import os +from tempfile import TemporaryDirectory + import pytest from pydantic import ValidationError diff --git a/a3fe/tests/test_run.py b/a3fe/tests/test_run.py index 2d637c7d..0ad3c9b8 100644 --- a/a3fe/tests/test_run.py +++ b/a3fe/tests/test_run.py @@ -16,7 +16,6 @@ import a3fe as a3 from a3fe.analyse.detect_equil import dummy_check_equil_multiwindow - LEGS_WITH_STAGES = {"bound": ["discharge", "vanish"], "free": ["discharge", "vanish"]} diff --git a/a3fe/tests/test_run_integration.py b/a3fe/tests/test_run_integration.py index 715ef64e..0f89579b 100644 --- a/a3fe/tests/test_run_integration.py +++ b/a3fe/tests/test_run_integration.py @@ -18,16 +18,16 @@ See README.md in this directory for more information on running these tests. """ +import glob +import logging import os -import pytest import subprocess -import glob from tempfile import TemporaryDirectory -import logging -import a3fe as a3 -from a3fe.tests import SLURM_PRESENT, RUN_SLURM_TESTS +import pytest +import a3fe as a3 +from a3fe.tests import RUN_SLURM_TESTS, SLURM_PRESENT # Define the legs and stages for testing LEGS_WITH_STAGES = { diff --git a/a3fe/tests/test_slurm_configuration.py b/a3fe/tests/test_slurm_configuration.py index 9a718a45..9ffe7c2f 100644 --- a/a3fe/tests/test_slurm_configuration.py +++ b/a3fe/tests/test_slurm_configuration.py @@ -1,14 +1,11 @@ """Unit and regression tests for the SlurmConfig class.""" +import os from tempfile import TemporaryDirectory - - from unittest.mock import patch from a3fe import SlurmConfig -import os - def test_create_default_config(): """Test that the default config is created correctly.""" diff --git a/a3fe/tests/test_update_engine_config_option.py b/a3fe/tests/test_update_engine_config_option.py index c754f29e..e65ee8e6 100644 --- a/a3fe/tests/test_update_engine_config_option.py +++ b/a3fe/tests/test_update_engine_config_option.py @@ -1,10 +1,11 @@ """Test the functionality of updating engine_config (SomdConfig) options.""" -from a3fe import SomdConfig -from a3fe.run._simulation_runner import SimulationRunner import pytest from pydantic import ValidationError +from a3fe import SomdConfig +from a3fe.run._simulation_runner import SimulationRunner + class MockSimulationRunner(SimulationRunner): """Simple mock for testing update_engine_config_option."""