Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 186 additions & 69 deletions ipsae.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,100 @@ def csv_header_line() -> str:
return "i,AlignChn,ScoredChain,AlignResNum,AlignResType,AlignRespLDDT,n0chn,n0dom,n0res,d0chn,d0dom,d0res,ipTM_pae,ipSAE_d0chn,ipSAE_d0dom,ipSAE\n"


@dataclass
class ChainPairScoreResults:
"""Container for chain-pair summary score results.

Attributes:
Chn1: first chain identifier
Chn2: second chain identifier
PAE: PAE cutoff value
Dist: Distance cutoff for CA-CA contacts
Type: "asym" or "max"; asym means asymmetric ipTM/ipSAE values; max is maximum of asym values
ipSAE: ipSAE value for given PAE cutoff and d0 determined by number of residues in 2nd chain with PAE<cutoff
ipSAE_d0chn: ipSAE calculated with PAE cutoff and d0 = sum of chain lengths
ipSAE_d0dom: ipSAE calculated with PAE cutoff and d0 = total number of residues in both chains with any interchain PAE<cutoff
ipTM_af: AlphaFold ipTM values. For AF2, this is for whole complex from json file. For AF3, this is symmetric pairwise value from summary json file.
ipTM_d0chn: ipTM (no PAE cutoff) calculated from PAE matrix and d0 = sum of chain lengths
pDockQ: score from pLDDTs from Bryant, Pozotti, and Eloffson
pDockQ2: score based on PAE, calculated pairwise from Zhu, Shenoy, Kundrotas, Elofsson
LIS: Local Interaction Score based on transform of PAEs from Kim, Hu, Comjean, Rodiger, Mohr, Perrimon
n0res: number of residues for d0 in ipSAE calculation
n0chn: number of residues in d0 in ipSAE_d0chn calculation
n0dom: number of residues in d0 in ipSAE_d0dom calculation
d0res: d0 for ipSAE
d0chn: d0 for ipSAE_d0chn
d0dom: d0 for ipSAE_d0dom
nres1: number of residues in chain1 with PAE<cutoff with residues in chain2
nres2: number of residues in chain2 with PAE<cutoff with residues in chain1
dist1: number of residues in chain 1 with PAE<cutoff and dist<cutoff from chain2
dist2: number of residues in chain 2 with PAE<cutoff and dist<cutoff from chain1
Model: AlphaFold filename
"""

Chn1: str
Chn2: str
PAE: int
Dist: int
Type: str
ipSAE: float
ipSAE_d0chn: float
ipSAE_d0dom: float
ipTM_af: float
ipTM_d0chn: float
pDockQ: float
pDockQ2: float
LIS: float
n0res: int
n0chn: int
n0dom: int
d0res: float
d0chn: float
d0dom: float
nres1: int
nres2: int
dist1: int
dist2: int
Model: str

def to_formatted_line(self) -> str:
"""Format the summary result as a fixed-width string."""
pae_str = str(self.PAE).zfill(2)
dist_str = str(self.Dist).zfill(2)
return (
f"{self.Chn1} {self.Chn2} {pae_str:3} {dist_str:3} {self.Type:5} "
f"{self.ipSAE:8.6f} "
f"{self.ipSAE_d0chn:8.6f} "
f"{self.ipSAE_d0dom:8.6f} "
f"{self.ipTM_af:5.3f} "
f"{self.ipTM_d0chn:8.6f} "
f"{self.pDockQ:8.4f} "
f"{self.pDockQ2:8.4f} "
f"{self.LIS:8.4f} "
f"{self.n0res:5d} "
f"{self.n0chn:5d} "
f"{self.n0dom:5d} "
f"{self.d0res:6.2f} "
f"{self.d0chn:6.2f} "
f"{self.d0dom:6.2f} "
f"{self.nres1:5d} "
f"{self.nres2:5d} "
f"{self.dist1:5d} "
f"{self.dist2:5d} "
f"{self.Model}\n"
)

@staticmethod
def header_line() -> str:
"""Return the header line for the summary output."""
return "Chn1 Chn2 PAE Dist Type ipSAE ipSAE_d0chn ipSAE_d0dom ipTM_af ipTM_d0chn pDockQ pDockQ2 LIS n0res n0chn n0dom d0res d0chn d0dom nres1 nres2 dist1 dist2 Model\n"

@staticmethod
def csv_header_line() -> str:
"""Return the CSV header line for the summary output."""
return "Chn1,Chn2,PAE,Dist,Type,ipSAE,ipSAE_d0chn,ipSAE_d0dom,ipTM_af,ipTM_d0chn,pDockQ,pDockQ2,LIS,n0res,n0chn,n0dom,d0res,d0chn,d0dom,nres1,nres2,dist1,dist2,Model\n"


@dataclass
class ScoreResults:
"""Container for calculated scores and output data.
Expand All @@ -281,8 +375,8 @@ class ScoreResults:
pdockq2_scores: Dictionary of pDockQ2 scores (by chain pair).
lis_scores: Dictionary of LIS scores (by chain pair).
metrics: Dictionary of pDockQ, pDockQ2, and LIS scores for each chain pair.
by_res_data: Lists of per-residue scores.
summary_lines: List of summarized chain-pair scores.
by_res_scores: Lists of per-residue scores.
chain_pair_scores: List of chain-pair summary score results.
pymol_script: List of formatted strings for PyMOL script.
"""

Expand All @@ -293,8 +387,10 @@ class ScoreResults:
lis_scores: dict[str, dict[str, float]] # {c1: {c2: score}}
metrics: dict[str, dict[str, float]] # {`<c1>_<c2>`: {metric_name: value}}

by_res_data: list[PerResScoreResults]
summary_lines: list[str] # Storing the formatted lines for the summary output file
by_res_scores: list[PerResScoreResults]
chain_pair_scores: list[
ChainPairScoreResults
] # List of chain-pair summary score results
pymol_script: list[str]


Expand Down Expand Up @@ -938,15 +1034,19 @@ def aggregate_byres_scores(
d0chn: dict[str, dict[str, float]],
d0dom: dict[str, dict[str, float]],
pdb_stem: str,
) -> tuple[list, list, dict[str, dict[str, float]]]:
"""Aggregate per-residue scores into chain-pair-specific scores."""
) -> tuple[list[ChainPairScoreResults], list[str], dict[str, dict[str, float]]]:
"""Aggregate per-residue scores into chain-pair-specific scores.

Returns:
A tuple containing:
- List of ChainPairScoreResults objects with chain-pair scores.
- List of PyMOL script lines.
- Dictionary of metrics for each chain pair.
"""
# Store results in a structured way
results_metrics: dict[str, dict[str, float]] = {}

summary_lines = []
summary_lines.append(
"\nChn1 Chn2 PAE Dist Type ipSAE ipSAE_d0chn ipSAE_d0dom ipTM_af ipTM_d0chn pDockQ pDockQ2 LIS n0res n0chn n0dom d0res d0chn d0dom nres1 nres2 dist1 dist2 Model\n"
)
chain_pair_scores: list[ChainPairScoreResults] = []

pymol_lines = []
pymol_lines.append(
Expand All @@ -962,8 +1062,8 @@ def get_max_info(values_array, c1, c2):
idx = np.argmax(vals)
return vals[idx], residues[idx].residue_str, idx

pae_str = str(int(pae_cutoff)).zfill(2)
dist_str = str(int(dist_cutoff)).zfill(2)
pae_int = int(pae_cutoff)
dist_int = int(dist_cutoff)
chainpairs = set()
for c1 in unique_chains:
for c2 in unique_chains:
Expand Down Expand Up @@ -1000,30 +1100,34 @@ def get_max_info(values_array, c1, c2):
if iptm_af == 0.0 and pae_data.iptm != -1.0:
iptm_af = pae_data.iptm # Fallback to global if per-pair not found

outstring = (
f"{c1} {c2} {pae_str:3} {dist_str:3} {'asym':5} "
f"{ipsae_res_val:8.6f} "
f"{ipsae_chn_val:8.6f} "
f"{ipsae_dom_val:8.6f} "
f"{iptm_af:5.3f} "
f"{iptm_chn_val:8.6f} "
f"{pDockQ[c1][c2]:8.4f} "
f"{pDockQ2[c1][c2]:8.4f} "
f"{LIS[c1][c2]:8.4f} "
f"{int(n0res_val):5d} "
f"{int(n0chn[c1][c2]):5d} "
f"{int(n0dom[c1][c2]):5d} "
f"{d0res_val:6.2f} "
f"{d0chn[c1][c2]:6.2f} "
f"{d0dom[c1][c2]:6.2f} "
f"{res1_cnt:5d} "
f"{res2_cnt:5d} "
f"{dist1_cnt:5d} "
f"{dist2_cnt:5d} "
f"{pdb_stem}\n"
summary_result = ChainPairScoreResults(
Chn1=c1,
Chn2=c2,
PAE=pae_int,
Dist=dist_int,
Type="asym",
ipSAE=float(ipsae_res_val),
ipSAE_d0chn=float(ipsae_chn_val),
ipSAE_d0dom=float(ipsae_dom_val),
ipTM_af=float(iptm_af),
ipTM_d0chn=float(iptm_chn_val),
pDockQ=float(pDockQ[c1][c2]),
pDockQ2=float(pDockQ2[c1][c2]),
LIS=float(LIS[c1][c2]),
n0res=int(n0res_val),
n0chn=int(n0chn[c1][c2]),
n0dom=int(n0dom[c1][c2]),
d0res=float(d0res_val),
d0chn=float(d0chn[c1][c2]),
d0dom=float(d0dom[c1][c2]),
nres1=res1_cnt,
nres2=res2_cnt,
dist1=dist1_cnt,
dist2=dist2_cnt,
Model=pdb_stem,
)
summary_lines.append(outstring)
pymol_lines.append("# " + outstring)
chain_pair_scores.append(summary_result)
pymol_lines.append("# " + summary_result.to_formatted_line())

# Store in results dict
results_metrics[f"{c1}_{c2}"] = {
Expand Down Expand Up @@ -1117,33 +1221,36 @@ def get_max_of_pair(arr, k1, k2):
len(dist_unique_residues_chain2[c2][c1]),
)

outstring = (
f"{c2} {c1} {pae_str:3} {dist_str:3} {'max':5} "
f"{ipsae_res_max:8.6f} "
f"{ipsae_chn_max:8.6f} "
f"{ipsae_dom_max:8.6f} "
f"{iptm_af_max:5.3f} "
f"{iptm_chn_max:8.6f} "
f"{pDockQ[c1][c2]:8.4f} "
f"{pdockq2_max:8.4f} "
f"{lis_avg:8.4f} "
f"{int(n0res_max):5d} "
f"{int(n0chn[c1][c2]):5d} "
f"{int(n0dom_max):5d} "
f"{d0res_max:6.2f} "
f"{d0chn[c1][c2]:6.2f} "
f"{d0dom_max:6.2f} "
f"{res1_max:5d} "
f"{res2_max:5d} "
f"{dist1_max:5d} "
f"{dist2_max:5d} "
f"{pdb_stem}\n"
summary_result = ChainPairScoreResults(
Chn1=c2,
Chn2=c1,
PAE=pae_int,
Dist=dist_int,
Type="max",
ipSAE=float(ipsae_res_max),
ipSAE_d0chn=float(ipsae_chn_max),
ipSAE_d0dom=float(ipsae_dom_max),
ipTM_af=float(iptm_af_max),
ipTM_d0chn=float(iptm_chn_max),
pDockQ=float(pDockQ[c1][c2]),
pDockQ2=float(pdockq2_max),
LIS=float(lis_avg),
n0res=int(n0res_max),
n0chn=int(n0chn[c1][c2]),
n0dom=int(n0dom_max),
d0res=float(d0res_max),
d0chn=float(d0chn[c1][c2]),
d0dom=float(d0dom_max),
nres1=res1_max,
nres2=res2_max,
dist1=dist1_max,
dist2=dist2_max,
Model=pdb_stem,
)
summary_lines.append(outstring)
summary_lines.append("\n")
pymol_lines.append("# " + outstring)
chain_pair_scores.append(summary_result)
pymol_lines.append("# " + summary_result.to_formatted_line())

return summary_lines, pymol_lines, results_metrics
return chain_pair_scores, pymol_lines, results_metrics


def calculate_scores(
Expand Down Expand Up @@ -1354,7 +1461,7 @@ def calculate_scores(
# We need to store these to generate the summary table

# Store results in a structured way
summary_lines, pymol_lines, results_metrics = aggregate_byres_scores(
chain_pair_scores, pymol_lines, results_metrics = aggregate_byres_scores(
residues,
pae_cutoff,
dist_cutoff,
Expand Down Expand Up @@ -1387,8 +1494,8 @@ def calculate_scores(
pdockq2_scores=pDockQ2,
lis_scores=LIS,
metrics=results_metrics,
by_res_data=by_res_lines,
summary_lines=summary_lines,
by_res_scores=by_res_lines,
chain_pair_scores=chain_pair_scores,
pymol_script=pymol_lines,
)

Expand All @@ -1406,11 +1513,17 @@ def write_outputs(results: ScoreResults, output_prefix: str | Path) -> None:
output_prefix: The prefix for the output filenames (including path).
"""
with Path(f"{output_prefix}.txt").open("w") as f:
f.writelines(results.summary_lines)
f.write("\n") # Leading newline
f.write(ChainPairScoreResults.header_line())
for summary in results.chain_pair_scores:
f.write(summary.to_formatted_line())
# Add newline after "max" line (end of each chain pair group)
if summary.Type == "max":
f.write("\n")

with Path(f"{output_prefix}_byres.txt").open("w") as f:
f.write(results.by_res_data[0].header_line())
for res_line in results.by_res_data:
f.write(results.by_res_scores[0].header_line())
for res_line in results.by_res_scores:
f.write(res_line.to_formatted_line())

with Path(f"{output_prefix}.pml").open("w") as f:
Expand Down Expand Up @@ -1536,10 +1649,14 @@ def main(
else:
# Print summary to stdout
print("#" * 90 + "\n# Summary\n" + "#" * 90)
print("".join(results.summary_lines))
print("\n" + ChainPairScoreResults.header_line(), end="")
for summary in results.chain_pair_scores:
print(summary.to_formatted_line(), end="")
if summary.Type == "max":
print()
print("#" * 90 + "\n# Per-residue scores\n" + "#" * 90)
print(results.by_res_data[0].header_line())
print("".join(x.to_formatted_line() for x in results.by_res_data))
print(results.by_res_scores[0].header_line())
print("".join(x.to_formatted_line() for x in results.by_res_scores))
print("#" * 90 + "\n# PyMOL script\n" + "#" * 90)
print("".join(results.pymol_script))

Expand Down