Skip to content

Commit 8f17022

Browse files
committed
Added mixtcrpred
1 parent 2347825 commit 8f17022

1 file changed

Lines changed: 51 additions & 13 deletions

File tree

epytope/TCRSpecificityPrediction/External.py

Lines changed: 51 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1439,6 +1439,7 @@ def organism(self):
14391439
return self.__organism
14401440

14411441
def format_tcr_data(self, tcrs, epitopes, pairwise, **kwargs):
1442+
self.download_models(**kwargs)
14421443
rename_columns = {
14431444
"VJ_cdr3": "cdr3_TRA",
14441445
"VDJ_cdr3": "cdr3_TRB",
@@ -1452,34 +1453,64 @@ def format_tcr_data(self, tcrs, epitopes, pairwise, **kwargs):
14521453
df_tcrs = self.combine_tcrs_epitopes_pairwise(df_tcrs, epitopes)
14531454
else:
14541455
df_tcrs = self.combine_tcrs_epitopes_list(df_tcrs, epitopes)
1455-
df_tcrs = df_tcrs[list(rename_columns.values())]
1456+
df_tcrs = df_tcrs[list(rename_columns.values()) + ["Epitope", "MHC"]]
14561457
df_tcrs = self.filter_by_length(df_tcrs, "cdr3_TRA", "cdr3_TRB", None)
14571458
for el in rename_columns.values():
14581459
df_tcrs = df_tcrs[(~df_tcrs[el].isna()) & (df_tcrs[el] != "")]
1460+
df_tcrs = self.filter_epitopes(df_tcrs, **kwargs)
14591461
df_tcrs = df_tcrs.drop_duplicates()
14601462
return df_tcrs
14611463

14621464
def download_models(self, **kwargs):
14631465
path_models = os.path.join(kwargs["repository"], "pretrained_models")
14641466
n_files = len([f for f in os.listdir(path_models)])
1465-
if n_files < 148:
1466-
cmd = f"python {kwargs['repository']}/MixTCRpred.py --download_all"
1467-
path_file = os.path.join(path_models, "model_A0101_ATDALMTGF.ckpt")
1468-
self.run_exec_cmd(cmd, path_file, **kwargs)
1467+
if n_files >= 148:
1468+
return
1469+
cmd = f"MixTCRpred.py --download_all"
1470+
path_file = os.path.join(path_models, "model_A0101_ATDALMTGF.ckpt")
1471+
try:
1472+
super().run_exec_cmd(cmd, path_file, **kwargs)
1473+
except Exception as e:
1474+
if "urllib.error.HTTPError: HTTP Error 404: NOT FOUND" not in str(e):
1475+
raise
1476+
1477+
def filter_epitopes(self, df, repository=None, **kwargs):
1478+
def epitope_contained(row):
1479+
epitope = row["Epitope"]
1480+
mhc = row["MHC"].name
1481+
mhc = mhc.replace("*", "")
1482+
mhc = mhc.replace(":", "")
1483+
name = f"{mhc}_{epitope}"
1484+
return name in models
1485+
1486+
models = os.listdir(os.path.join(repository, "pretrained_models"))
1487+
models = [el.split("model_")[1].split(".ckpt")[0] for el in models if el.endswith(".ckpt")]
1488+
mask_epitope = df.apply(epitope_contained, axis=1).values
1489+
delta = len(df) - sum(mask_epitope)
1490+
if delta > 0:
1491+
warnings.warn(f"Filtering {delta} rows as Epitope not available for categorical model")
1492+
df = df[mask_epitope].copy()
1493+
return df
1494+
1495+
14691496

14701497
def save_tmp_files(self, data, **kwargs):
14711498
tmp_folder = self.get_tmp_folder_path()
1499+
tmp_foldername = "/lustre/groups/imm01/workspace/felix.drost/2023_benchmark/epytope/external/tmp"
1500+
14721501
paths = []
1473-
for _, row in data[["Epitope", "MHC"]].iterrows():
1474-
epitope = row["Epitope"] # TODO transform mhc / epitope
1502+
for _, row in data[["Epitope", "MHC"]].drop_duplicates().iterrows():
1503+
epitope = row["Epitope"]
14751504
mhc = row["MHC"]
14761505

14771506
data_epitope = data[(data["Epitope"] == epitope) & (data["MHC"] == mhc)]
14781507
cols = ["cdr3_TRA", "cdr3_TRB", "TRAV", "TRAJ", "TRBV", "TRBJ"]
14791508
data_epitope = data_epitope[cols].copy()
14801509

1481-
path_in = os.path.join(tmp_folder, f"input_{mhc}_{epitope}.csv")
1482-
path_out = os.path.join(tmp_folder, f"output_{mhc}_{epitope}.csv")
1510+
mhc = mhc.name.replace("*", "").replace(":", "").replace("H-2", "H2")
1511+
1512+
path_in = os.path.join(tmp_foldername, f"input_{mhc}_{epitope}.csv")
1513+
path_out = os.path.join(tmp_foldername, f"output_{mhc}_{epitope}.csv")
14831514

14841515
paths.append(path_in)
14851516
paths.append(path_out)
@@ -1502,20 +1533,27 @@ def run_exec_cmd(self, cmd, filenames, interpreter=None, conda=None, cmd_prefix=
15021533
conda, cmd_prefix, m_cmd=False, **kwargs)
15031534

15041535
def format_results(self, filenames, tmp_folder, tcrs, epitopes, pairwise, **kwargs):
1536+
import re
15051537
results_joined = []
15061538
for i in range(0, len(filenames), 2):
15071539
path_out = filenames[i+1]
1508-
results_predictor = pd.read_csv(path_out)
1540+
results_predictor = pd.read_csv(path_out, index_col=0, comment="#")
1541+
results_predictor.index.name = None
15091542
results_predictor = results_predictor.fillna("")
15101543
results_predictor = results_predictor.rename(columns={"cdr3_TRB": "VDJ_cdr3",
15111544
"TRBV": "VDJ_v_gene",
15121545
"TRBJ": "VDJ_j_gene",
15131546
"cdr3_TRA": "VJ_cdr3",
15141547
"TRAV": "VJ_v_gene",
15151548
"TRAJ": "VJ_j_gene",
1516-
"prediction": "Score"})
1517-
epitope = ""
1518-
mhc = ""
1549+
"score": "Score"})
1550+
epitope = path_out.split("_")[-1].split(".csv")[0]
1551+
mhc = path_out.split("_")[-2]
1552+
mhc = mhc.replace("H2", "H-2")
1553+
if mhc[0] != "H":
1554+
match = re.match(r"([A-Za-z]+)(\d+)", mhc)
1555+
letters, numbers = match.groups()
1556+
mhc = f"HLA-{letters}*{numbers[:-2]}:{numbers[-2:]}"
15191557
results_predictor["Epitope"] = epitope
15201558
results_predictor["MHC"] = mhc
15211559
results_joined.append(results_predictor)

0 commit comments

Comments
 (0)