Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions src/boltzgen/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,13 +1325,60 @@ def _from_feat(
]
token_to_res_old = token_to_res_old[not_padding_selector]

# Create mapping from original token indices to filtered token indices
# This is needed because atom_to_token contains original indices but we're using filtered tokens
# Build remap tensor: old_token_idx -> new_token_idx
# Find the max old token index that might be referenced by atoms
max_old_token = atom_to_token.max().item() if len(atom_to_token) > 0 else -1
if max_old_token >= 0:
# Initialize with -1 (invalid) for all possible old token indices
token_remap_tensor = torch.full(
(max_old_token + 1,), -1, dtype=atom_to_token.dtype, device=atom_to_token.device
)
# Map old token indices to new token indices
for new_idx, old_idx in enumerate(not_padding_selector):
old_idx_item = old_idx.item()
if old_idx_item <= max_old_token:
token_remap_tensor[old_idx_item] = new_idx
else:
token_remap_tensor = None

# remove atom padding if there is any
ref_element = ref_element[atom_pad_mask.bool()]
ref_charge = ref_charge[atom_pad_mask.bool()]
ref_atom_name_chars = ref_atom_name_chars[atom_pad_mask.bool()]
coords = coords[atom_pad_mask.bool()]
atom_to_token = atom_to_token[atom_pad_mask.bool()]

# Remap atom_to_token to use filtered token indices
# atom_to_token currently contains original token indices, but we need filtered indices
if token_remap_tensor is not None:
atom_to_token = token_remap_tensor[atom_to_token]
# Check for atoms that mapped to -1 (invalid token references)
invalid_mask = atom_to_token == -1
if invalid_mask.any():
# This indicates atom_pad_mask is incorrect - atoms reference non-existent tokens
invalid_count = invalid_mask.sum().item()
raise ValueError(
f"Found {invalid_count} atoms with token indices that don't exist in filtered token array. "
"This indicates data corruption in atom_pad_mask or atom_to_token. "
f"Invalid atom indices: {torch.where(invalid_mask)[0].tolist()[:10]}..."
)

# Validate that all atom_to_token values are within bounds of filtered token array
if len(not_padding_selector) > 0:
max_valid_token = len(not_padding_selector) - 1
if (atom_to_token > max_valid_token).any():
invalid_count = (atom_to_token > max_valid_token).sum().item()
invalid_indices = torch.where(atom_to_token > max_valid_token)[0]
raise ValueError(
f"Found {invalid_count} atoms with token indices > {max_valid_token} "
f"(filtered token array size = {len(not_padding_selector)}). "
"This indicates data corruption in atom_pad_mask. "
f"Invalid atom indices: {invalid_indices.tolist()[:10]}... "
f"Invalid token values: {atom_to_token[invalid_indices].tolist()[:10]}..."
)

# create residue identifiers
res_identifiers = []
for asym, res_idx in zip(asym_id, residue_index):
Expand Down
13 changes: 8 additions & 5 deletions src/boltzgen/data/parse/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,12 +1476,14 @@ def recursive_check(data):
if c2 in total_renaming.keys():
c2 = total_renaming[c2]

if c1 not in all_parsed_chains.keys():
# Check if chain exists in either parsed_chains (protein/ligand) or data.chains (file entities)
chain_names_in_data = {chain["name"].item() for chain in data.chains}
if c1 not in all_parsed_chains.keys() and c1 not in chain_names_in_data:
msg = f"Chain {c1} in the specified connection does not exist: {constraint}"
ValueError(msg)
if c2 not in all_parsed_chains.keys():
raise ValueError(msg)
if c2 not in all_parsed_chains.keys() and c2 not in chain_names_in_data:
msg = f"Chain {c2} in the specified connection does not exist: {constraint}"
ValueError(msg)
raise ValueError(msg)

# Map index
if (
Expand Down Expand Up @@ -1861,7 +1863,8 @@ def parse_file(self, item, mols, mol_dir, ligand_id, base_file_path=Path(".")):

# Handle the "all" case where all chains are set to be specified
if chain_id == "all":
new_groups = np.ones(num_res)
visibility = group["visibility"]
new_groups = np.full(num_res, visibility, dtype=np.int32)
continue

if chain_id not in structure.chains["name"]:
Expand Down
8 changes: 6 additions & 2 deletions src/boltzgen/data/write/mmcif.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def add_design_cols(structure, block, colors):
"metric_value",
]
plddt_loop = block.init_loop("_ma_qa_metric_local.", plddt_cols)
ordinal_id = 0
global_res_idx = -1
for chain in structure.chains:
chain_name_str = re.sub(r"\d+", "", chain["name"].item())
Expand All @@ -322,11 +323,12 @@ def add_design_cols(structure, block, colors):
global_res_idx += 1
if not res["is_present"]:
continue
ordinal_id += 1
mon_id = res["name"].item()
design_label = colors[global_res_idx] # [plddt_loop.length()]
plddt_loop.add_row(
[
str(plddt_loop.length() + 1), # ordinal_id
str(ordinal_id), # ordinal_id
"1", # model_id
chain_id, # label_asym_id
str(res["res_idx"].item() + 1),
Expand Down Expand Up @@ -362,6 +364,7 @@ def add_plddt_cols(structure, block):
"metric_value",
]
plddt_loop = block.init_loop("_ma_qa_metric_local.", plddt_cols)
ordinal_id = 0
for chain in structure.chains:
if chain["mol_type"].item() == const.chain_type_ids["NONPOLYMER"]:
continue
Expand All @@ -376,12 +379,13 @@ def add_plddt_cols(structure, block):
for _, res in enumerate(residues, 1):
if not res["is_present"]:
continue
ordinal_id += 1
mon_id = res["name"].item()
ref_atom_idx = res["atom_idx"]
plddt_score = structure.atoms[ref_atom_idx]["bfactor"].item()
plddt_loop.add_row(
[
str(plddt_loop.length() + 1), # ordinal_id
str(ordinal_id), # ordinal_id
"1", # model_id
chain_id, # label_asym_id
str(res["res_idx"].item() + 1),
Expand Down