Skip to content

Commit

Permalink
Add proper dof masking to torsion-space minimization
Browse files Browse the repository at this point in the history
  • Loading branch information
fdimaio committed Apr 12, 2024
1 parent 41be4b7 commit 9a79b48
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 61 deletions.
71 changes: 48 additions & 23 deletions tmol/kinematics/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,29 @@ class NodeType(enum.IntEnum):
bond = enum.auto()


class BondDOFTypes(enum.IntEnum):
"""Indices of bond dof types within KinDOF.raw."""

phi_p = 0
theta = enum.auto()
d = enum.auto()
phi_c = enum.auto()


class JumpDOFTypes(enum.IntEnum):
"""Indices of jump dof types within KinDOF.raw."""

RBx = 0
RBy = enum.auto()
RBz = enum.auto()
RBdel_alpha = enum.auto()
RBdel_beta = enum.auto()
RBdel_gamma = enum.auto()
RBalpha = enum.auto()
RBbeta = enum.auto()
RBgamma = enum.auto()


@attr.s(auto_attribs=True, frozen=True)
class KinForest(TensorGroup, ConvertAttrs):
"""A collection of atom-level kinematic trees, each of which can be processed
Expand Down Expand Up @@ -120,6 +143,31 @@ def root_node(cls):
id=-1, doftype=NodeType.root, parent=0, frame_x=0, frame_y=0, frame_z=0
)

def default_mask(self, nonideal=False):
"""Get a default DOF mask corresponding to ideal or nonideal refinement.
This can be updated in the future to accept something equivalent to a
MoveMap as an input."""
mask = torch.zeros(
(self.id.shape[0], 9), dtype=torch.bool, device=self.doftype.device
)
mask[self.doftype == NodeType.jump, JumpDOFTypes.RBx : JumpDOFTypes.RBz] = (
True # jump translations
)
mask[
self.doftype == NodeType.jump,
JumpDOFTypes.RBdel_alpha : JumpDOFTypes.RBdel_gamma,
] = True # jump rotations
mask[self.doftype == NodeType.bond, BondDOFTypes.phi_p] = True # bond rotation

if nonideal:
mask[self.doftype == NodeType.bond, BondDOFTypes.theta] = True # bond angle
mask[self.doftype == NodeType.bond, BondDOFTypes.d] = True # bond length
mask[self.doftype == NodeType.bond, BondDOFTypes.phi_c] = (
True # child bond rotation
)

return mask


@attr.s(auto_attribs=True, slots=True, frozen=True)
class KinDOF(TensorGroup, ConvertAttrs):
Expand All @@ -146,29 +194,6 @@ def clone(self):
return KinDOF(raw=self.raw.clone())


class BondDOFTypes(enum.IntEnum):
"""Indices of bond dof types within KinDOF.raw."""

phi_p = 0
theta = enum.auto()
d = enum.auto()
phi_c = enum.auto()


class JumpDOFTypes(enum.IntEnum):
"""Indices of jump dof types within KinDOF.raw."""

RBx = 0
RBy = enum.auto()
RBz = enum.auto()
RBdel_alpha = enum.auto()
RBdel_beta = enum.auto()
RBdel_gamma = enum.auto()
RBalpha = enum.auto()
RBbeta = enum.auto()
RBgamma = enum.auto()


@attr.s(auto_attribs=True, slots=True, frozen=True)
class BondDOF(TensorGroup, ConvertAttrs):
"""A bond dof view of KinDOF."""
Expand Down
25 changes: 9 additions & 16 deletions tmol/optimization/sfxn_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,6 @@ def forward(self):
return self.whole_pose_scoring_module(self.full_coords)


def default_mask_ideal(kinforest):
mask = torch.zeros()
pass


def default_mask_nonideal(kinforest):
pass


class KinematicSfxnNetwork(torch.nn.Module):
def __init__(self, score_function, pose_stack, nonideal=False):
super(KinematicSfxnNetwork, self).__init__()
Expand All @@ -58,19 +49,21 @@ def __init__(self, score_function, pose_stack, nonideal=False):
# fd There are two issues to address:
# fd 1) this call only works on CPU (since we use numba)
# fd 2) 'pose_stack.n_res_per_pose' does not work on multichain poses
fold_forest = FoldForest.polymeric_forest(pose_stack.n_res_per_pose.cpu())

self.kinforest = construct_pose_stack_kinforest(pose_stack, fold_forest)
dofs = inverseKin(self.kinforest, kincoords.to(torch.double))
if nonideal:
self.dof_mask = default_mask_nonideal(self.kinforest.doftype)
else:
self.dof_mask = default_mask_ideal(self.kinforest)
# a) construct the kinforest and dof mask
fold_forest = FoldForest.polymeric_forest(pose_stack.n_res_per_pose.cpu())
self.kinforest = construct_pose_stack_kinforest(pose_stack, fold_forest).to(
pose_stack.coords.device
)
self.dof_mask = self.kinforest.default_mask(nonideal=nonideal)

# b) inverse fold to get torsions from pose
self.pose_stack_coords = pose_stack.coords
coords_flat = pose_stack.coords.reshape(-1, 3)
kincoords = coords_flat[self.kinforest.id.to(torch.long)]
dofs = inverseKin(self.kinforest, kincoords.to(torch.double))

# c) apply mask, allocate params
self.full_dofs = dofs.raw
self.masked_dofs = torch.nn.Parameter(self.full_dofs[self.dof_mask])

Expand Down
57 changes: 35 additions & 22 deletions tmol/tests/optimization/test_scorefunction_minimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,45 +85,58 @@ def closure():
run


def test_cart_minimizer(ubq_pdb, torch_device):
@pytest.mark.benchmark(group="pose_50step_minimization")
def test_cart_minimizer(benchmark, ubq_pdb, torch_device):
pose_stack = pose_stack_from_pdb(ubq_pdb, torch_device)
sfxn = beta2016_score_function(torch_device)

wpsm = sfxn.render_whole_pose_scoring_module(pose_stack)
wpsm(pose_stack.coords)

network = CartesianSfxnNetwork(sfxn, pose_stack)
optimizer = LBFGS_Armijo(network.parameters(), lr=0.1, max_iter=20)
@benchmark
def do_minimize():
network = CartesianSfxnNetwork(sfxn, pose_stack)
optimizer = LBFGS_Armijo(network.parameters(), lr=0.001, max_iter=50)

def closure():
optimizer.zero_grad()
E = network().sum()
E.backward()
return E
def closure():
optimizer.zero_grad()
E = network().sum()
E.backward()
return E

Estart = network().sum()
optimizer.step(closure)
Estop = network().sum()
Estart = network().sum()
optimizer.step(closure)
Estop = network().sum()
return Estart, Estop

Estart, Estop = do_minimize
assert Estop < Estart


def test_dof_minimizer(ubq_pdb, torch_device):
@pytest.mark.parametrize("nonideal", [False, True])
@pytest.mark.benchmark(group="pose_50step_minimization")
def test_dof_minimizer(benchmark, ubq_pdb, torch_device, nonideal):
pose_stack = pose_stack_from_pdb(ubq_pdb, torch_device)
sfxn = beta2016_score_function(torch_device)

wpsm = sfxn.render_whole_pose_scoring_module(pose_stack)
wpsm(pose_stack.coords)

network = KinematicSfxnNetwork(sfxn, pose_stack)
optimizer = LBFGS_Armijo(network.parameters(), lr=0.001, max_iter=20)
@benchmark
def do_minimize():
network = KinematicSfxnNetwork(sfxn, pose_stack, nonideal=nonideal)
optimizer = LBFGS_Armijo(network.parameters(), lr=0.001, max_iter=50)

def closure():
optimizer.zero_grad()
E = network().sum()
E.backward()
return E
def closure():
optimizer.zero_grad()
E = network().sum()
E.backward()
return E

Estart = network().sum()
optimizer.step(closure)
Estop = network().sum()
Estart = network().sum()
optimizer.step(closure)
Estop = network().sum()
return Estart, Estop

Estart, Estop = do_minimize
assert Estop < Estart

0 comments on commit 9a79b48

Please sign in to comment.