@@ -1439,6 +1439,7 @@ def organism(self):
14391439 return self .__organism
14401440
14411441 def format_tcr_data (self , tcrs , epitopes , pairwise , ** kwargs ):
1442+ self .download_models (** kwargs )
14421443 rename_columns = {
14431444 "VJ_cdr3" : "cdr3_TRA" ,
14441445 "VDJ_cdr3" : "cdr3_TRB" ,
@@ -1452,34 +1453,64 @@ def format_tcr_data(self, tcrs, epitopes, pairwise, **kwargs):
14521453 df_tcrs = self .combine_tcrs_epitopes_pairwise (df_tcrs , epitopes )
14531454 else :
14541455 df_tcrs = self .combine_tcrs_epitopes_list (df_tcrs , epitopes )
1455- df_tcrs = df_tcrs [list (rename_columns .values ())]
1456+ df_tcrs = df_tcrs [list (rename_columns .values ()) + [ "Epitope" , "MHC" ] ]
14561457 df_tcrs = self .filter_by_length (df_tcrs , "cdr3_TRA" , "cdr3_TRB" , None )
14571458 for el in rename_columns .values ():
14581459 df_tcrs = df_tcrs [(~ df_tcrs [el ].isna ()) & (df_tcrs [el ] != "" )]
1460+ df_tcrs = self .filter_epitopes (df_tcrs , ** kwargs )
14591461 df_tcrs = df_tcrs .drop_duplicates ()
14601462 return df_tcrs
14611463
14621464 def download_models (self , ** kwargs ):
14631465 path_models = os .path .join (kwargs ["repository" ], "pretrained_models" )
14641466 n_files = len ([f for f in os .listdir (path_models )])
1465- if n_files < 148 :
1466- cmd = f"python { kwargs ['repository' ]} /MixTCRpred.py --download_all"
1467- path_file = os .path .join (path_models , "model_A0101_ATDALMTGF.ckpt" )
1468- self .run_exec_cmd (cmd , path_file , ** kwargs )
1467+ if n_files >= 148 :
1468+ return
1469+ cmd = f"MixTCRpred.py --download_all"
1470+ path_file = os .path .join (path_models , "model_A0101_ATDALMTGF.ckpt" )
1471+ try :
1472+ super ().run_exec_cmd (cmd , path_file , ** kwargs )
1473+ except Exception as e :
1474+ if "urllib.error.HTTPError: HTTP Error 404: NOT FOUND" not in str (e ):
1475+ raise
1476+
1477+ def filter_epitopes (self , df , repository = None , ** kwargs ):
1478+ def epitope_contained (row ):
1479+ epitope = row ["Epitope" ]
1480+ mhc = row ["MHC" ].name
1481+ mhc = mhc .replace ("*" , "" )
1482+ mhc = mhc .replace (":" , "" )
1483+ name = f"{ mhc } _{ epitope } "
1484+ return name in models
1485+
1486+ models = os .listdir (os .path .join (repository , "pretrained_models" ))
1487+ models = [el .split ("model_" )[1 ].split (".ckpt" )[0 ] for el in models if el .endswith (".ckpt" )]
1488+ mask_epitope = df .apply (epitope_contained , axis = 1 ).values
1489+ delta = len (df ) - sum (mask_epitope )
1490+ if delta > 0 :
1491+ warnings .warn (f"Filtering { delta } rows as Epitope not available for categorical model" )
1492+ df = df [mask_epitope ].copy ()
1493+ return df
1494+
1495+
14691496
14701497 def save_tmp_files (self , data , ** kwargs ):
14711498 tmp_folder = self .get_tmp_folder_path ()
1499+ tmp_foldername = "/lustre/groups/imm01/workspace/felix.drost/2023_benchmark/epytope/external/tmp"
1500+
14721501 paths = []
1473- for _ , row in data [["Epitope" , "MHC" ]].iterrows ():
1474- epitope = row ["Epitope" ] # TODO transform mhc / epitope
1502+ for _ , row in data [["Epitope" , "MHC" ]].drop_duplicates (). iterrows ():
1503+ epitope = row ["Epitope" ]
14751504 mhc = row ["MHC" ]
14761505
14771506 data_epitope = data [(data ["Epitope" ] == epitope ) & (data ["MHC" ] == mhc )]
14781507 cols = ["cdr3_TRA" , "cdr3_TRB" , "TRAV" , "TRAJ" , "TRBV" , "TRBJ" ]
14791508 data_epitope = data_epitope [cols ].copy ()
14801509
1481- path_in = os .path .join (tmp_folder , f"input_{ mhc } _{ epitope } .csv" )
1482- path_out = os .path .join (tmp_folder , f"output_{ mhc } _{ epitope } .csv" )
1510+ mhc = mhc .name .replace ("*" , "" ).replace (":" , "" ).replace ("H-2" , "H2" )
1511+
1512+ path_in = os .path .join (tmp_foldername , f"input_{ mhc } _{ epitope } .csv" )
1513+ path_out = os .path .join (tmp_foldername , f"output_{ mhc } _{ epitope } .csv" )
14831514
14841515 paths .append (path_in )
14851516 paths .append (path_out )
@@ -1502,20 +1533,27 @@ def run_exec_cmd(self, cmd, filenames, interpreter=None, conda=None, cmd_prefix=
15021533 conda , cmd_prefix , m_cmd = False , ** kwargs )
15031534
15041535 def format_results (self , filenames , tmp_folder , tcrs , epitopes , pairwise , ** kwargs ):
1536+ import re
15051537 results_joined = []
15061538 for i in range (0 , len (filenames ), 2 ):
15071539 path_out = filenames [i + 1 ]
1508- results_predictor = pd .read_csv (path_out )
1540+ results_predictor = pd .read_csv (path_out , index_col = 0 , comment = "#" )
1541+ results_predictor .index .name = None
15091542 results_predictor = results_predictor .fillna ("" )
15101543 results_predictor = results_predictor .rename (columns = {"cdr3_TRB" : "VDJ_cdr3" ,
15111544 "TRBV" : "VDJ_v_gene" ,
15121545 "TRBJ" : "VDJ_j_gene" ,
15131546 "cdr3_TRA" : "VJ_cdr3" ,
15141547 "TRAV" : "VJ_v_gene" ,
15151548 "TRAJ" : "VJ_j_gene" ,
1516- "prediction" : "Score" })
1517- epitope = ""
1518- mhc = ""
1549+ "score" : "Score" })
1550+ epitope = path_out .split ("_" )[- 1 ].split (".csv" )[0 ]
1551+ mhc = path_out .split ("_" )[- 2 ]
1552+ mhc = mhc .replace ("H2" , "H-2" )
1553+ if mhc [0 ] != "H" :
1554+ match = re .match (r"([A-Za-z]+)(\d+)" , mhc )
1555+ letters , numbers = match .groups ()
1556+ mhc = f"HLA-{ letters } *{ numbers [:- 2 ]} :{ numbers [- 2 :]} "
15191557 results_predictor ["Epitope" ] = epitope
15201558 results_predictor ["MHC" ] = mhc
15211559 results_joined .append (results_predictor )
0 commit comments