Skip to content

Commit

Permalink
move bias_register to core.register
Browse files Browse the repository at this point in the history
  • Loading branch information
hsulab committed Apr 19, 2024
1 parent 4403446 commit 1cd58f6
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 15 deletions.
20 changes: 8 additions & 12 deletions src/gdpx/bias/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@
# -*- coding: utf-8 -*-



from typing import List

from ..core.register import Register, registers
bias_register = Register("bias")
from ..core.register import registers

"""Add bias on potential energy surface.
Expand All @@ -16,27 +12,27 @@
"""

from .afir import AFIRCalculator
bias_register.register("afir")(AFIRCalculator)
registers.bias.register("afir")(AFIRCalculator)

from .bondboost import BondBoostCalculator
bias_register.register("bondboost")(BondBoostCalculator)
registers.bias.register("bondboost")(BondBoostCalculator)

# from .harmonic import HarmonicBias
# bias_register.register("harmonic")(HarmonicBias)

from .harmonic import DistanceHarmonicCalculator, PlaneHarmonicCalculator
bias_register.register("distance_harmonic")(DistanceHarmonicCalculator)
bias_register.register("plane_harmonic")(PlaneHarmonicCalculator)
registers.bias.register("distance_harmonic")(DistanceHarmonicCalculator)
registers.bias.register("plane_harmonic")(PlaneHarmonicCalculator)

# - gaussian

# from .gaussian import GaussianCalculator
# bias_register.register("gaussian")(GaussianCalculator)

from .gaussian import BondGaussianCalculator, CenterOfMassGaussianCalculator, RMSDGaussian
bias_register.register("bond_gaussian")(BondGaussianCalculator)
bias_register.register("center_of_mass_gaussian")(CenterOfMassGaussianCalculator)
bias_register.register("rmsd_gaussian")(RMSDGaussian)
registers.bias.register("bond_gaussian")(BondGaussianCalculator)
registers.bias.register("center_of_mass_gaussian")(CenterOfMassGaussianCalculator)
registers.bias.register("rmsd_gaussian")(RMSDGaussian)


if __name__ == "__main__":
Expand Down
5 changes: 5 additions & 0 deletions src/gdpx/core/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ class registers:
#: Builders.
builder: Register = Register("builder")

#: Bias.
bias: Register = Register("bias")

#: Colvars.
colvar: Register = Register("colvar")

Expand Down Expand Up @@ -161,6 +164,8 @@ def create(
("gdpx.potential", ["managers"]),
# -- dataloaders (datasets)
("gdpx.data", ["dataset"]),
# -- bias
("gdpx", ["bias"]),
# -- builders
("gdpx", ["builder"]),
# -- genetic-algorithm-related
Expand Down
1 change: 0 additions & 1 deletion src/gdpx/potential/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from .. import config
from ..core.register import registers
from ..bias import bias_register


if __name__ == "__main__":
Expand Down
5 changes: 3 additions & 2 deletions src/gdpx/potential/managers/bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from . import registers
from . import AbstractPotentialManager, DummyCalculator
from .. import bias_register


"""This manager registers ALL bias calculators."""

Expand Down Expand Up @@ -51,7 +51,8 @@ def register_calculator(self, calc_params, *agrs, **kwargs) -> None:
# - instantiate calculator
calc = DummyCalculator()
if self.calc_backend == "ase":
calc = bias_register[bias_type](**calc_params)
bias_cls = registers.bias[bias_type]
calc = bias_cls(**calc_params)
else:
...

Expand Down

0 comments on commit 1cd58f6

Please sign in to comment.