-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathPipelineCTransPathWithStride.py
More file actions
1036 lines (848 loc) · 46.1 KB
/
PipelineCTransPathWithStride.py
File metadata and controls
1036 lines (848 loc) · 46.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import warnings
import os
import logging
import random
import pickle
import datetime
import joblib
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
import torch_geometric
import argparse
warnings.filterwarnings("ignore")
#mpl.rcParams["figure.dpi"] = 300 # for high resolution figure in notebook, default 100
# local utils
from utils.visualisation import mkdir, recur_find_ext, rm_n_mkdir, load_json
from utils.spxl_graph import construct_superpixel_graph
from utils.data import SlideGraphEpiDataset, load_patch_labels, make_label_df_with_slide_labels, \
dual_upsample, find_base_data, get_mask_dir, get_epi_mask_dir, get_wsi_dir, select_cohort, filter_wsis
from utils.model import select_checkpoints, SlideGraphArch
from utils.helper import reset_logging
from utils.metrics import create_resp_metric_dict, find_optimal_cutoff, threshold_predictions, metric_str_thresh_all
from utils.plot import plot_confusion_matrix, density_plot
from utils.utils import str2bool
from superpixels import superpixel_feats_for_one_slide
from graphs import construct_slidegraph
from train import run_once
########## Arguments ##########
# Command-line interface for the whole pipeline: hardware/seed selection,
# data locations, superpixel/graph-construction options, and GNN
# training hyperparameters. Fix applied: help text for --shuffle-splits
# previously read "Generally tru." (typo).
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=3, choices=[0, 1, 2, 3], help='GPU number to use')
parser.add_argument('--seed', type=int, default=4,  # choices=[1, 2, 3, 4],
                    help='set the seed for training and data split. '
                         'should match with baseline model to separate train/val datasets')
parser.add_argument('--clinical-file', type=str,
                    default="Metadata/PatchLabelsInclNAsTm20TGFb.csv",
                    help='CSV file where patch labels and other metadata is defined'
                    )
parser.add_argument('--mag', type=str, default='20X', choices=['5X', '10X', '20X'],
                    help='Magnification of patches')
parser.add_argument('--resp', nargs='+', default=['response_cr_nocr', 'CMS4', 'epithelium'],
                    help='List of response variables')
# use like python script.py --resp response_cr_nocr CMS4 epithelium -- etc.
parser.add_argument('--cohorts', nargs='+', default=['GRAMPIAN', 'ARISTOTLE'],  # SALZBURG
                    help='List of cohorts to train and validate on')
# Added new args
parser.add_argument('--filter-epi', type=str2bool, default=True,
                    help='Filter WSIs by those which have saved epithelial graphs already created')
parser.add_argument('--upsample', type=str2bool, default=True,
                    help='Whether to upsample WSIs from minority classes. Generally always true.')
parser.add_argument('--shuffle-splits', type=str2bool, default=True,
                    help='Whether to shuffle WSIs in training/validation splits. Generally true.')
parser.add_argument('--resolution', default=5.0, type=float,
                    help='Resolution/magnification for graph generation')
parser.add_argument('--compactness', default=20.0, type=float,
                    help='Compactness parameter for SLIC algorithm')
parser.add_argument('--generate-graphs', default=False, type=str2bool,
                    help='Whether to generate graphs or use saved graphs. May depend on other parameters.')
parser.add_argument('--generate-superpixels', default=False, type=str2bool,
                    help='Whether to generate superpixels or use saved features. May depend on other parameters.')
parser.add_argument('--set-max-clusters', default=False, type=str2bool,
                    help='Whether to set max number of clusters in WSI graph')
parser.add_argument('--num-clusters', default=None, help='Number of clusters if setting maximum for graph')
parser.add_argument('--epochs', default=50, type=int, help='Number of epochs for GNN training')
parser.add_argument('--batch-size', default=64, type=int, help='Batch size for GNN training')
parser.add_argument('--lr', default=1.0e-3, type=float, help='Learning rate for GNN')
parser.add_argument('--weight-decay', default=1.0e-4, type=float, help='Weight decay for learning for GNN')
###############
parser.add_argument('--base-name', default='CTransPath', type=str, choices=['CTransPath', 'DINO'],
                    help='Baseline model for patch/node features')
parser.add_argument('--base-version', default='4.04', type=str,
                    help='Baseline model version for patch/node features')
parser.add_argument('--scaler', default=False, type=str2bool,
                    help='True for trainable logistic regression (upside down results), False for nonparametric sigmoid')
parser.add_argument('--temper', type=float, default=1.5, help='Tempering output; 1.5 used for MICCAI; alt 0.1')
parser.add_argument('--connectivity-scale', default=20, help='Graph connectivity', type=int)
parser.add_argument('--gembed', type=str2bool, default=False, help='Whether to gembed the GNN')
parser.add_argument('--superpixel', type=str2bool, default=True, help='True for MICCAI')
parser.add_argument('--scale-slic', type=int, default=2, help='Scale for SLIC algorithm, 8 for Salzburg, 2 otherwise')
parser.add_argument('--spxl-by-patch', type=str2bool, default=False,
                    help='Number of superpixels ~ patches. False for MICCAI. Implemented after v5.x')
parser.add_argument('--with-stride', type=str2bool, default=False,
                    help='To determine number of patches for calculating number of superpixels')
parser.add_argument('--remove-background', type=str2bool, default=False,
                    help='Removing white background superpixels. Implemented after v5.x')
parser.add_argument('--mlp', type=str2bool, default=True, help='MLP layer for output')
parser.add_argument('--mlp-version', type=int, default=1,
                    help='MLP layer version for output. 1 is MICCAI version. 2 is ops.MLP applied earlier.')
parser.add_argument('--loss', type=str, default='bce', choices=['bce', 'slidegraph'], help='Loss function')
parser.add_argument('--loss-weights', nargs='+', type=float, default=[1., 1., 1.],
                    help='Weights on respective response variables')
parser.add_argument('--remove-unclassified-cms4', type=str2bool, default=False,
                    help='Remove unclassified CMS4 WSIs from analysis (usually treated as not CMS4)')
parser.add_argument('--remove-unmatched-cms4', type=str2bool, default=False,
                    help='Remove unmatched CMS4 WSIs from analysis (usually treated as not CMS4)')
parser.add_argument('--preproc', type=str2bool, default=True,
                    help='Whether to preprocess and normalize the node features prior to GNN training')
parser.add_argument('--log', default=False, type=str2bool, help='Whether to log training in Tensorboard')
parser.add_argument('--dev-mode', default=False, type=str2bool, help='Whether to run reduced analysis in dev mode')
parser.add_argument('--epi-graph-dir-root', type=str,
                    default='checkpoint/',
                    help='Root directory where graphs for epithelium are saved. Base model details will be added.')
parser.add_argument('--root-dir', type=str,
                    default='checkpoint/',
                    help='Root directory where everything is saved. Base model details will be added.')
parser.add_argument('--layer-dims', default=[64, 32, 16], nargs='+', type=int, help='Layer dimensions in GNN')
parser.add_argument('--graph-agg', default='min', type=str, choices=['mean', 'max', 'min', 'sum', 'mul'],
                    help='Aggregation method for GNN')
parser.add_argument('--graph-pool', default='mean', type=str, choices=['mean', 'max', 'add'],
                    help='Pooling method for GNN')
parser.add_argument('--dropout', default=0.5, type=float, help='Dropout probability for GNN')
parser.add_argument('--mlp-dropout', default=0.1, type=float, help='Dropout probability for MLP heads')
parser.add_argument('--graph-cache-name', default='default', type=str)
parser.add_argument('--overwrite', default=False, type=str2bool, help='Whether to write over existing model checkpoints')
# Alternative historical values for --graph-cache-name:
# 'slidegraph'
# 'superpixel_5X_compactness_20_scaleslic_2'  # for MICCAI
# 'superpixel_upsample_connectivity_range_.125_gembed_true_temper_1.5_ginconv_scaleslic_2'  # None
args = parser.parse_args()
########## Add defined arguments ##########
# Derive paths and a descriptive model name from the parsed CLI arguments.
# Everything computed here is attached to `args` so downstream code sees a
# single configuration object.
#setattr(args, 'base_version', f'4.0{args.seed}')
setattr(args, 'root_output_dir', os.path.join(args.root_dir, f"{args.base_name}Base{args.base_version}"))
#setattr(args, 'epi_graph_dir', os.path.join(args.root_dir,
#                                            f'{args.base_name}Base{args.base_version}/graph/epithelium'))
# Map the CLI sentinel strings for --graph-cache-name onto real values:
# the literal string 'None' disables caching, 'default' expands to the
# MICCAI-era cache name parameterized by --scale-slic.
if args.graph_cache_name == 'None':
    setattr(args, 'graph_cache_name', None)
elif args.graph_cache_name == 'default':
    setattr(args, 'graph_cache_name', f'superpixel_5X_compactness_20_scaleslic_{args.scale_slic}')
# Set graph dir
#GRAPH_NAME = f'superpixel_{int(args.resolution)}X_compactness_{int(args.compactness)}_scaleslic_{args.scale_slic}'
GRAPH_DIR = f"{args.root_output_dir}/graph/epithelium/{args.graph_cache_name}"
if float(args.base_version) >= 5.0:
    # From base version 5.0 onwards the graph dir is seed-specific.
    GRAPH_DIR = os.path.join(GRAPH_DIR, f'seed_{args.seed}')
    # need seed for train/val split as features have diff augs
print('Graph dir:', GRAPH_DIR)
if args.set_max_clusters:
    print('Setting max number of clusters')
    # Cluster-capped graphs live under a dedicated subdirectory.
    GRAPH_DIR = os.path.join(f"{args.root_output_dir}/graph", f'{args.num_clusters}_clusters')
CLUSTER_DIR = f"{args.root_output_dir}/clusters/{args.graph_cache_name}"
setattr(args, 'epi_graph_dir', GRAPH_DIR)
# Build the (long) model name encoding the run configuration so that
# checkpoints for different settings never collide.
loss_weights_str = 'weight_' + '_'.join(str(num) for num in args.loss_weights)
mlp_str = f"_mlp_{args.mlp_version}_dropout_{str(args.mlp_dropout).lstrip('0')}" if args.mlp else ""  #_dropout_{str(args.mlp_dropout).lstrip('0')}"
print('args.layer_dims:', args.layer_dims)
setattr(args, 'layer_dims', list(args.layer_dims))
layer_str = 'layers_' + '_'.join(str(num) for num in args.layer_dims)
# Append a human-readable size tag for the known layer configurations.
if args.layer_dims == [64, 32, 16]:
    layer_str += "_xlarge"
elif args.layer_dims == [128, 64, 32, 16]:
    layer_str += "_xxlarge"
elif args.layer_dims == [32, 16]:
    layer_str += "_large"
#else:
#    layer_str = ""
setattr(args, 'model_name', os.path.join("_".join(args.resp),
                                         "_".join(args.cohorts) +
                                         f'_{"superpixel" if args.superpixel else "slidegraph"}_' +
                                         f'{"_patch_scaled" if args.spxl_by_patch else ""}' +
                                         f'{"_filtered" if args.remove_background else ""}' +
                                         f'{"rm_unmatched_" if args.remove_unmatched_cms4 else ""}' +
                                         f'{"rm_unclassified_" if args.remove_unclassified_cms4 else ""}' +
                                         f'{"upsample_" if args.upsample else ""}' +
                                         f'{"preproc_false" if args.preproc == False else "normalize_train"}' +
                                         #f'_connectivity_range_{str(1/args.connectivity_scale).lstrip("0")[:4]}' +
                                         f'_connectivity_scale_{str(args.connectivity_scale)}' +
                                         f'_gembed_{str(args.gembed).lower()}_' +
                                         f'temper_{args.temper}_ginconv{mlp_str}' +
                                         f'_{args.loss}' +
                                         f'{layer_str}_' +
                                         f'{args.graph_agg}_aggr_{args.graph_pool}_pool' +
                                         f'{loss_weights_str if (not all(it == 1 for it in args.loss_weights)) else ""}'))
########## Assert parameters as expected ##########
# Sanity-check cohort-dependent parameters before any expensive work.
if 'SALZBURG' in args.cohorts:
    assert args.scale_slic == 8, f"Scale SLIC parameter ({args.scale_slic}) should be 8 for SALZBURG"
else:
    assert args.scale_slic == 2, f"Scale SLIC parameter ({args.scale_slic}) should be 2 for GRAMPIAN/ARISTOTLE"
if args.with_stride and 'nostride' in args.clinical_file.lower():
    raise Exception("with-stride set to True but no stride metadata used")
########## Check if model already exists ##########
MODEL_DIR = os.path.join(f"{args.root_output_dir}/model/", args.model_name)
if args.set_max_clusters:
    MODEL_DIR = os.path.join(f"{args.root_output_dir}/model/{args.num_clusters}_clusters", args.model_name)
print('Model dir:', MODEL_DIR)
if os.path.exists(MODEL_DIR) and args.overwrite:
    print('WARNING: model directory already exists, set to overwrite results')
if not args.overwrite:
    # When not overwriting, bump a trailing '_vN' suffix until we find a free
    # model directory. The print/setattr now live INSIDE the loop: previously
    # the final print referenced `new_model_name` even when the loop body never
    # ran (directory absent), which raised NameError.
    while os.path.exists(MODEL_DIR):
        print('WARNING: overwrite set to False. Set --overwrite True to overwrite previous model.')
        model_version = args.model_name.split('_')[-1]
        if model_version.startswith('v'):
            version_number = model_version[1:]
            new_version_number = int(version_number) + 1
            new_model_name = args.model_name.replace(f'_v{version_number}', f'_v{new_version_number}')
        else:
            new_model_name = args.model_name + '_v1'
        # Re-check existence on the next loop iteration with the bumped name.
        # NOTE(review): this recompute drops the '{num_clusters}_clusters'
        # subdirectory when --set-max-clusters is on, matching the original
        # behaviour — confirm whether that is intended.
        MODEL_DIR = os.path.join(f"{args.root_output_dir}/model/", new_model_name)
        setattr(args, 'model_name', new_model_name)
        print(f'Model name updated to {new_model_name}')
# Where visualisation images for this model get written.
setattr(args, 'save_img_path', os.path.join(args.root_output_dir, 'visualisations', str(args.model_name)))
########## Set GPU ##########
torch.cuda.set_device(args.gpu)
########## Logging ##########
# Optional TensorBoard logging: one writer each for train and validation,
# keyed by model name and timestamp so reruns don't overwrite each other.
if args.log:
    sub_dir = 'tensorboard'
    if args.dev_mode:
        sub_dir = 'tensorboard_dev'
    tensorboard_dir = os.path.join(f'logs/{sub_dir}',
                                   f'{args.base_name}{args.base_version}', args.model_name)
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)
    current_time = str(datetime.datetime.now().strftime("%d%m%Y-%H:%M:%S"))
    train_log_dir = tensorboard_dir + '/train/' + current_time
    val_log_dir = tensorboard_dir + '/val/' + current_time
    train_summary_writer = SummaryWriter(log_dir=train_log_dir)
    val_summary_writer = SummaryWriter(log_dir=val_log_dir)
else:
    # Downstream code checks these for None when logging is disabled.
    train_summary_writer, val_summary_writer = None, None
########## Set seed ##########
# Seed every RNG used by the pipeline for reproducibility.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
########## Make directories ##########
# Dev mode writes nothing to disk.
if not args.dev_mode:
    mkdir(args.save_img_path)
    mkdir(args.root_output_dir)
    mkdir(f'{args.root_output_dir}/{"_".join(args.resp)}')
########## Load data ##########
def load_test_patch_labels(clinical_file, mag, cohort):
    """Load patch labels for a held-out cohort.

    Reads the clinical CSV, keeps only rows at magnification ``mag`` and
    restricts to ``cohort`` via ``select_cohort``.
    """
    labels = pd.read_csv(clinical_file, index_col=0)
    labels = labels[labels.magnification == mag]
    return select_cohort(labels, cohort)
# Resolve WSI, tissue-mask and epithelium-mask directories per cohort,
# and record each on `args` for provenance.
wsi_dirs, msk_dirs, epi_msk_dirs = [], [], []
for cohort in args.cohorts:
    c_wsi_dir = get_wsi_dir(cohort)
    c_msk_dir = get_mask_dir(cohort)
    c_epi_dir = get_epi_mask_dir(cohort)
    wsi_dirs.append(c_wsi_dir)
    msk_dirs.append(c_msk_dir)
    epi_msk_dirs.append(c_epi_dir)
    setattr(args, f'WSI_DIR_{cohort.upper()}', c_wsi_dir)
    setattr(args, f'MSK_DIR_{cohort.upper()}', c_msk_dir)
    setattr(args, f'EPI_MSK_DIR_{cohort.upper()}', c_epi_dir)
def filter_wsis_by_epi_graphs(wsi_names, filter_epi=args.filter_epi, epi_graph_dir=args.epi_graph_dir,
                              graph_cache_name=args.graph_cache_name):
    """Keep only slides for which an epithelial graph JSON already exists.

    Args:
        wsi_names: iterable of slide names.
        filter_epi: when False, return ``wsi_names`` unchanged (as a list).
        epi_graph_dir: directory searched recursively for ``.json`` graphs.
        graph_cache_name: unused since v5.0 (graph dir already includes the
            cache name); kept for backward compatibility with callers.

    Returns:
        list of slide names, in the original order, restricted to those with
        a saved graph.
    """
    if not filter_epi:
        return list(wsi_names)
    #epi_graph_paths = recur_find_ext(os.path.join(epi_graph_dir, graph_cache_name), [".json"])
    # From 5.0 onwards, graph_dir includes graph_name
    epi_graph_paths = recur_find_ext(epi_graph_dir, [".json"])
    # Slide name = file name with the '.json' suffix stripped. Using
    # os.path.basename instead of splitting on '/' keeps this portable,
    # and a set makes membership checks O(1) instead of O(n) per slide.
    epi_wsi_names = {os.path.basename(path).split('.json')[0] for path in epi_graph_paths}
    # those which we have epi masks for
    return [wsi for wsi in wsi_names if wsi in epi_wsi_names]
# Load slide paths, features and labels. SALZBURG is treated as an external
# test set (no train/val structure); all other cohorts load for training.
if args.cohorts == ['SALZBURG']:  # test set
    wsi_names, wsi_paths, msk_paths, base_feature_dir = find_base_data(wsi_dirs, msk_dirs, base_name=args.base_name,
                                                                       base_version=args.base_version, seed=args.seed,
                                                                       epi_msk_dirs=None, test=True)
    epi_msk_paths = recur_find_ext(epi_msk_dirs[0], [".png", ".jpg", ".jpeg"])
    wsi_names = filter_wsis_by_epi_graphs(wsi_names, args.filter_epi, args.epi_graph_dir, args.graph_cache_name)
    # load_dual_patch_labels(clinical_file, mag, responses, cohorts)
    patch_labels = load_test_patch_labels(args.clinical_file, args.mag, args.cohorts)
    # set slide column to type string
    patch_labels.slide = patch_labels.slide.astype('str')
    # One row per slide: take the first patch row and drop the patch column.
    clinical_df = patch_labels.groupby('slide').first().drop(['patch'], axis=1).reset_index()
    # Select response columns, without dealing with possible epithelium response here
    slide_df = patch_labels.groupby('slide').first().drop(['patch'], axis=1).reset_index()
    label_df, slide_responses = make_label_df_with_slide_labels(slide_df, responses=args.resp)
    # Keep only labels for slides we actually have graphs/features for.
    our_sel = np.where([str(wsi) in wsi_names for wsi in label_df['WSI-CODE']])[0]
    label_df = label_df.loc[our_sel].reset_index(drop=True)
    print('Labels:', len(label_df))
else:
    wsi_names, wsi_paths, msk_paths, epi_msk_paths, base_feature_dir = find_base_data(wsi_dirs, msk_dirs,
                                                                                     base_name=args.base_name,
                                                                                     base_version=args.base_version,
                                                                                     seed=args.seed,
                                                                                     epi_msk_dirs=epi_msk_dirs,
                                                                                     test=False)
    # if already generated epi graphs
    wsi_names = filter_wsis_by_epi_graphs(wsi_names, args.filter_epi, args.epi_graph_dir, args.graph_cache_name)
    # if DEV_MODE:
    #     wsi_names = wsi_names[:100]
    #     wsi_paths = wsi_paths[:100]
    #     msk_paths = msk_paths[:100]
    # * Generate WSI labels
    patch_labels = load_patch_labels(args.clinical_file, args.mag, args.resp, args.cohorts)
    slide_df = patch_labels.groupby('slide').first().drop(['patch'], axis=1).reset_index()
    # Optionally drop slides whose CMS status could not be determined
    # (otherwise they are treated as not-CMS4).
    if args.remove_unclassified_cms4:
        print(f'Removing {slide_df.CMS_matching.value_counts()["Unclassified"]} unclassified CMS from dataset')
        slide_df = slide_df[slide_df.CMS_matching != 'Unclassified'].reset_index(drop=True)
    if args.remove_unmatched_cms4:
        print(f'Removing {slide_df.CMS_matching.value_counts()["Unmatched"]} unmatched CMS from dataset')
        slide_df = slide_df[slide_df.CMS_matching != 'Unmatched'].reset_index(drop=True)
    # Select response columns, without dealing with possible epithelium response here
    label_df, slide_responses = make_label_df_with_slide_labels(slide_df, responses=args.resp)
    # label_df.rename(columns={'slide':'WSI-CODE', RESP[0]: 'LABEL_0', RESP[1]: 'LABEL_1'},
    #                 inplace=True)
    # Epithelium labels are stored in graph data, hence only deal with RESP[:2].
    # Filter label_df based on WSIs we have features for
    our_sel = np.where([wsi in wsi_names for wsi in label_df['WSI-CODE']])[0]
    label_df = label_df.loc[our_sel].reset_index(drop=True)
    print('Labels:', len(label_df))
    # Redo wsi_names for normalizer and superpixels
    wsi_names = label_df['WSI-CODE'].values
assert len(label_df) > 0, "Problem loading WSI labels, none found"
########## Create training data splits ##########
# splits is list of length 1 (num_folds). In list is dictionary with keys ['train', 'valid', 'test'].
# Each dict value is list of tuples, tuples of length two, with slide name and response value.
#split_cache_path = f"{args.root_output_dir}/shuffle_splits.dat"
# Define where the serialized splits are cached for this model.
mkdir(f"{args.root_output_dir}/{args.model_name}")
SPLIT_PATH = os.path.join(f"{args.root_output_dir}/{args.model_name}",
                          f"{'shuffle_' if args.shuffle_splits else ''}splits.dat")
# Single fold only; kept as a constant for clarity.
NUM_FOLDS = 1
def split_train_val(label_df, train_val_split=0.7, seed=args.seed):
    """Randomly split slides into sorted train and validation case lists.

    Args:
        label_df: DataFrame with a 'WSI-CODE' column of slide names.
        train_val_split: fraction of cases assigned to training (ceil-ed).
        seed: seed for the (global) random module so splits are reproducible.

    Returns:
        (train_cases, val_cases): two sorted lists of slide names.
    """
    # Copy to a plain list before shuffling: random.shuffle mutates in place,
    # and shuffling the ndarray returned by `.values` can silently reorder the
    # DataFrame's underlying buffer. A list shuffle yields the identical
    # permutation for the same seed.
    cases = list(label_df['WSI-CODE'].values)
    num_train_cases = int(np.ceil(len(cases) * train_val_split))
    random.seed(seed)
    random.shuffle(cases)
    train_cases = cases[:num_train_cases]
    val_cases = cases[num_train_cases:]
    print('Number of train cases:', len(train_cases))
    print('Number of validation cases:', len(val_cases))
    # train_label_df = label_df[label_df['WSI-CODE'].isin(train_cases)].reset_index(drop=True)
    # val_patch_labels = patch_labels[patch_labels.case.isin(val_cases)].reset_index(drop=True)
    return sorted(train_cases), sorted(val_cases)
# Build the train/validation slide lists. Before base version 5.0 the split
# came from the baseline feature directory layout; from 5.0 it is recomputed.
if float(args.base_version) < 5.0:
    if args.cohorts == ['SALZBURG']:
        train_wsis, val_wsis = split_train_val(label_df, train_val_split=0.7, seed=args.seed)
    else:
        # Pre-5.0: splits are implied by the Train/Validation subfolders.
        train_wsis = sorted(os.listdir(os.path.join(base_feature_dir, 'Train')))
        val_wsis = sorted(os.listdir(os.path.join(base_feature_dir, 'Validation')))
else:
    train_wsis, val_wsis = split_train_val(label_df, train_val_split=0.7, seed=args.seed)
# Drop slides without labels.
train_wsis = filter_wsis(train_wsis, label_df)
val_wsis = filter_wsis(val_wsis, label_df)
if args.shuffle_splits:
    random.seed(args.seed)  # changed from 0 after DINO1.11 first two models
    random.shuffle(train_wsis)
    random.shuffle(val_wsis)
    print('Shuffled wsis')
# Per-slide label vectors (one entry per response variable), aligned with
# the train/val slide order.
train_labels = [
    label_df[label_df['WSI-CODE'] == slide][[f'LABEL_{i}' for i in range(len(slide_responses))]].values.tolist()[0] for
    slide in train_wsis]
val_labels = [
    label_df[label_df['WSI-CODE'] == slide][[f'LABEL_{i}' for i in range(len(slide_responses))]].values.tolist()[0] for
    slide in val_wsis]
print('Train and validation set response distribution (across both labels):')
print(f'    {np.unique(train_labels, return_counts=True)}')
print(f'    {np.unique(val_labels, return_counts=True)}')
for i in range(len(slide_responses)):
    print(
        f'There are {np.unique(np.array(val_labels)[:, i], return_counts=True)[1][1]} positive {slide_responses[i]} slides in' +
        ' the validation set')
if args.upsample:
    # Balance minority classes in the training set only.
    train_wsis, train_labels = dual_upsample(train_wsis, train_labels, slide_responses)
    print(f'{len(train_wsis)} slides in training set after upsampling')
print('Number of train slides:', len(train_wsis))
print('Number of validation slides:', len(val_wsis))
splits = []
splits.append(
    {
        "train": list(zip(train_wsis, train_labels)),
        "valid": list(zip(val_wsis, val_labels)),
        # "test": list(zip(val_wsis, val_labels)),
    }
)
# Save splits
if not args.dev_mode:
    joblib.dump(splits, SPLIT_PATH)
########## Generating superpixels ##########
# Node feature dimensionality is fixed by the baseline feature extractor.
if args.base_name == 'CTransPath':
    NUM_NODE_FEATURES = 768
elif args.base_name == 'DINO':
    NUM_NODE_FEATURES = 384
if args.superpixel:
    graph_name = f'superpixel_{int(args.resolution)}X_compactness_{int(args.compactness)}_scaleslic_{args.scale_slic}'
    if args.spxl_by_patch:
        graph_name += '_patch_scaled'
    if args.remove_background:
        graph_name += '_filtered'
    WSI_FEATURE_DIR = os.path.join(args.root_output_dir, 'features', graph_name)
    if float(args.base_version) >= 5.0:
        WSI_FEATURE_DIR = os.path.join(WSI_FEATURE_DIR, f'seed_{args.seed}')
        # need seed for train/val split as features have diff augs
    print(WSI_FEATURE_DIR)
    if args.generate_superpixels:
        patch_labels.slide = patch_labels.slide.astype('str')
        patch_counts_per_slide = patch_labels.groupby('slide')['patch'].count()
        failed_spxl = []

        def _generate_superpixels(wsis, train_or_val):
            """Generate/save superpixel features for each slide in `wsis`;
            record slides that fail in `failed_spxl`.

            The train and validation passes only differ in the slide list and
            the `train_or_val` tag, so they share this helper.
            """
            for wsi in sorted(wsis, reverse=True):
                print(f'Generating superpixels for {wsi}')
                num_patches = None
                if args.spxl_by_patch:
                    # With stride, four overlapping patches cover each tile.
                    num_patches = patch_counts_per_slide[wsi] / 4 if args.with_stride else patch_counts_per_slide[wsi]
                try:
                    _, _ = superpixel_feats_for_one_slide(wsi, wsi_paths=wsi_paths, mask_paths=msk_paths,
                                                          epi_msk_paths=epi_msk_paths,
                                                          wsi_feature_dir=WSI_FEATURE_DIR,
                                                          scale_slic=args.scale_slic,
                                                          base_name=args.base_name,
                                                          base_version=args.base_version,
                                                          seed=args.seed,
                                                          num_node_features=NUM_NODE_FEATURES,
                                                          train_or_val=train_or_val,
                                                          num_patches=num_patches,
                                                          remove_background=args.remove_background,
                                                          resolution=args.resolution, mag=args.mag,
                                                          compactness=args.compactness,
                                                          save_feats=True, jit=False)
                except Exception as e:
                    print(f'Couldn\'t generate superpixels for slide {wsi}. \nError: {e}')
                    failed_spxl.append(wsi)

        _generate_superpixels(train_wsis, 'Train')
        _generate_superpixels(val_wsis, 'Validation')
        # Failed spxl WSIs might still be in wsi_names when reloading features.
        # BUG FIX: convert to a list ONCE before removing — the original called
        # `wsi_names.tolist()` inside the loop, which raised AttributeError on
        # the second failed slide (plain lists have no .tolist()).
        if failed_spxl:
            wsi_names = list(wsi_names)
            for wsi in failed_spxl:
                wsi_names.remove(wsi)
else:
    # Non-superpixel (slidegraph) mode has no feature directory here.
    WSI_FEATURE_DIR = None
########## Generating graphs ##########
#GRAPH_NAME = f'superpixel_{int(args.resolution)}X_compactness_{int(args.compactness)}_scaleslic_{args.scale_slic}'
#GRAPH_DIR = f"{args.root_output_dir}/graph/epithelium/{GRAPH_NAME}"
#if float(args.base_version) >= 5.0:
#    GRAPH_DIR = os.path.join(GRAPH_DIR, f'seed_{args.seed}')
#    # need seed for train/val split as features have diff augs
#print('Graph dir:', GRAPH_DIR)
#
#if args.set_max_clusters:
#    print('Setting max number of clusters')
#    GRAPH_DIR = os.path.join(f"{args.root_output_dir}/graph", f'{args.num_clusters}_clusters')
#    CLUSTER_DIR = f"{args.root_output_dir}/clusters/{GRAPH_NAME}"
# Build one graph JSON per slide, or fall back to already-saved graphs.
if args.generate_graphs:
    # Restrict to slides that actually produced feature files.
    wsi_names = np.unique([file.split('.')[0] for file in os.listdir(WSI_FEATURE_DIR)])
    print('Redoing splits after filtering for graphs')
    # Redo splits after filtering for graphs
    splits = []
    train_wsi_idx = np.nonzero(np.in1d(train_wsis, wsi_names))[0]
    train_labels = [train_labels[i] for i in train_wsi_idx]
    train_wsis = [train_wsis[i] for i in train_wsi_idx]
    print(f'{len(train_wsis)} train slides')
    val_wsi_idx = np.nonzero(np.in1d(val_wsis, wsi_names))[0]
    val_labels = [val_labels[i] for i in val_wsi_idx]
    val_wsis = [val_wsis[i] for i in val_wsi_idx]
    print(f'{len(val_wsis)} val slides')
    splits.append(
        {
            "train": list(zip(train_wsis, train_labels)),
            "valid": list(zip(val_wsis, val_labels)),
            # "test": list(zip(val_wsis, val_labels)),
        }
    )
    # Save splits
    if not args.dev_mode:
        joblib.dump(splits, SPLIT_PATH)
    print(f'Generating graphs for {len(wsi_names)} slides')
    mkdir(args.epi_graph_dir)
    if args.superpixel:
        # One superpixel graph per slide, saved as JSON.
        for wsi in wsi_names:
            construct_superpixel_graph(wsi, save_path=f"{args.epi_graph_dir}/{wsi}.json",
                                       connectivity_scale=args.connectivity_scale,
                                       wsi_feature_dir=WSI_FEATURE_DIR,
                                       add_epi=True)  # assumes binary_epi_labels exist in feature dir
    else:
        # SlideGraph mode: also collect per-slide cluster assignments.
        coords_clusters_dict, coords_max_clusters_dict = {}, {}
        for wsi in wsi_names:
            coords_clusters, coords_max_clusters = construct_slidegraph(wsi, f"{args.epi_graph_dir}/{wsi}.json")
            if coords_clusters is None:
                # Graph construction failed for this slide; skip it.
                continue
            coords_clusters_dict[wsi] = coords_clusters
            coords_max_clusters_dict[wsi] = coords_max_clusters
        mkdir(CLUSTER_DIR)
        pickle.dump(coords_clusters_dict, open(f'{CLUSTER_DIR}/graph_clusters.p', 'wb'))
        if args.set_max_clusters:
            pickle.dump(coords_max_clusters_dict, open(f'{CLUSTER_DIR}/graph_{args.num_clusters}_clusters.p', 'wb'))
else:
    # Reuse previously generated graphs.
    graph_paths = recur_find_ext(f"{args.epi_graph_dir}/", [".json"])
########## Scale graph features ##########
# Fit (or reload) a StandardScaler over all node features so the GNN trains
# on zero-mean, unit-variance inputs.
SCALER_PATH = f"{args.root_output_dir}/{args.model_name}_{'clusters_' if args.set_max_clusters else ''}node_scaler.dat"
if args.preproc:
    if os.path.exists(SCALER_PATH):
        node_scaler = joblib.load(SCALER_PATH)
    else:
        #if NODE_PREDICTION:
        dataset = SlideGraphEpiDataset(wsi_names, graph_dir=args.epi_graph_dir, mode="infer")  # no labels
        #else:
        #dataset = SlideGraphDataset(wsi_names, graph_dir=args.epi_graph_dir, mode="infer")  # train_wsis for train set only
        loader = torch_geometric.loader.DataLoader(
            dataset, num_workers=8, batch_size=1, shuffle=False, drop_last=False
        )
        # Stack every graph's node features into one big array for fitting.
        node_features = [v[0]["graph"].x.numpy() for idx, v in enumerate(tqdm(loader))]
        node_features = np.concatenate(node_features, axis=0)
        node_scaler = StandardScaler(copy=False)  # Standardize features by removing the mean and scaling to unit variance.
        node_scaler.fit(node_features)
        if not args.dev_mode:
            joblib.dump(node_scaler, SCALER_PATH)
    # we must define the function after training/loading
    def nodes_preproc_func(node_feats):
        # Apply the fitted scaler to one graph's node features.
        return node_scaler.transform(node_feats)
########## Train model ##########
# Dev mode trains for only 2 epochs as a smoke test.
NUM_EPOCHS = 2 if args.dev_mode else args.epochs
# Anomaly detection slows training but pinpoints NaN/inf gradients — presumably
# left on deliberately for debugging; TODO(review) confirm it should stay enabled.
torch.autograd.set_detect_anomaly(True)
# Seed every RNG source for reproducibility.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
if not args.dev_mode:
    splits = joblib.load(SPLIT_PATH)
if args.preproc:
    if not args.dev_mode:
        # Reload the scaler saved in the preprocessing step above.
        node_scaler = joblib.load(SCALER_PATH)
    # Re-bind the preprocessing closure (duplicates the definition above so this
    # section also works when run standalone).
    def nodes_preproc_func(node_feats):
        return node_scaler.transform(node_feats)
else:
    # No feature scaling requested.
    nodes_preproc_func = None
# Shared DataLoader settings for every fold.
loader_kwargs = dict(
    num_workers=8,
    batch_size=args.batch_size, # RW: can't have batch_size bigger than dataset. changed from 16 to 4.
)
# SlideGraph-style GIN architecture hyper-parameters, mostly CLI-driven.
arch_kwargs = dict(
    dim_features=NUM_NODE_FEATURES,
    dim_target=1, # RW: changed from 1 to 4
    layers=args.layer_dims, # changed from [16, 16, 8], xlarge is [64, 32, 16], xxlarge is [128, 64, 32, 16]
    dropout=args.dropout, # changed from 0.5 to 0.3
    pooling=args.graph_pool, # changed from mean to max
    conv="GINConv",
    aggr=args.graph_agg, # changed from max to min
    gembed=args.gembed,
    scaler=args.scaler,
    temper=args.temper,
    use_mlp=args.mlp,
    mlp_version=args.mlp_version,
    mlp_dropout=args.mlp_dropout
)
optim_kwargs = dict(
    lr=args.lr, # RW: changed from 1e-3 to 1e-1
    weight_decay=args.weight_decay,
)
logging.basicConfig(
    level=logging.INFO,
)
# Train one model per cross-validation fold.
for split_idx, split in enumerate(splits):
    new_split = {"train": split["train"]}
    if args.scaler:
        # Platt-style recalibration needs inference over the training set too.
        new_split.update({"infer-train": split["train"]})
    new_split.update({"infer-valid-A": split["valid"]})
    # "infer-valid-B": split["test"], # Same as validation for now
    split_save_dir = f"{MODEL_DIR}/{split_idx:02d}/"
    # Wipe any previous run for this fold, then log into its directory.
    rm_n_mkdir(split_save_dir)
    reset_logging(split_save_dir)
    out, wsis = run_once(
        resp=args.resp, loss_name=args.loss, loss_weights=args.loss_weights, scale=args.scaler,
        preproc=args.preproc, temper=args.temper,
        dataset_dict=new_split,
        num_epochs=NUM_EPOCHS,
        graph_dir=args.epi_graph_dir,
        save_dir=split_save_dir,
        nodes_preproc_func=nodes_preproc_func,
        dev_mode=args.dev_mode,
        train_summary_writer=train_summary_writer,
        val_summary_writer=val_summary_writer,
        pretrained=None,
        arch_kwargs=arch_kwargs,
        loader_kwargs=loader_kwargs,
        optim_kwargs=optim_kwargs,
    )
if args.log:
    # Training is done across all folds; flush the TensorBoard writer.
    train_summary_writer.close()
########## Save losses ##########
# For each fold, dump train/val loss curves to .npy and save a loss plot.
if not args.dev_mode:
    for split_idx, split in enumerate(splits):
        stats_dict = load_json(recur_find_ext(f"{MODEL_DIR}/{split_idx:02d}/", [".json"])[0])
        # keys are strings of epochs. Each value contains dict of loss and metrics.
        # (relies on JSON/dict insertion order matching epoch order)
        train_losses = [d['train-EMA-loss'] for d in stats_dict.values()]
        val_losses = [d['infer-valid-A-loss'] for d in stats_dict.values()]
        np.save(f"{MODEL_DIR}/{split_idx:02d}/train_losses.npy", train_losses)
        np.save(f"{MODEL_DIR}/{split_idx:02d}/val_losses.npy", val_losses)
        mpl.rcParams["figure.dpi"] = 100
        plt.figure(figsize=(5, 3))
        plt.plot(train_losses, label='train')
        plt.plot(val_losses, label='val')
        plt.legend()
        plt.title('Loss')
        plt.savefig(os.path.join(MODEL_DIR, f'{split_idx:02d}', 'loss_plot.png'))
        #plt.show()
########## Inference ##########
# Evaluate each trained fold on its validation split using the best checkpoint(s).
if args.dev_mode:
    # Dev runs are train-only smoke tests; skip inference entirely.
    exit()
TOP_K = 1  # number of best checkpoints to ensemble per fold
#metric_name = f"{RESP[0]}-infer-valid-A-auroc" # choose best model based on first response only
metric_name = 'infer-valid-A-auroc' # choose based on all responses
PRETRAINED_DIR = MODEL_DIR
splits = joblib.load(SPLIT_PATH)
if args.preproc:
    node_scaler = joblib.load(SCALER_PATH)
# need loader_kwargs and arch_kwargs defined, usually from training in same go
cum_stats, cum_preds = [], []
# Per-fold inference: pick the top checkpoint(s) by validation AUROC, run them,
# optionally recalibrate logits, then compute per-response metrics.
for split_idx, split in enumerate(splits):
    new_split = {'valid': split["valid"]} # want valid to return epi label
    stat_files = recur_find_ext(f"{PRETRAINED_DIR}/{split_idx:02d}/", [".json"])
    print(stat_files)
    # Ignore stale ".old.json" stat files from previous runs.
    stat_files = [v for v in stat_files if ".old.json" not in v]
    print(stat_files)
    assert len(stat_files) == 1
    chkpts, chkpt_stats_list, best_epoch = select_checkpoints(
        stat_files[0], top_k=TOP_K, metrics=[metric_name]
    )
    # Perform ensembling by averaging probabilities across checkpoint predictions
    cum_results = []
    for chkpt_info in chkpts:
        chkpt_results, wsis = run_once(
            resp=args.resp, loss_name=args.loss, loss_weights=args.loss_weights, scale=args.scaler,
            preproc=args.preproc, temper=args.temper,
            dataset_dict=new_split,
            num_epochs=1,
            graph_dir=args.epi_graph_dir,
            save_dir=None,
            nodes_preproc_func=nodes_preproc_func,
            dev_mode=args.dev_mode,
            val_summary_writer=val_summary_writer,
            pretrained=chkpt_info,
            arch_kwargs=arch_kwargs,
            loader_kwargs=loader_kwargs
        )
        # * re-calibrate logit to probabilities
        chkpt_results = np.array(chkpt_results)
        chkpt_results = np.squeeze(chkpt_results)
        if args.scaler:
            # Load the fold's auxiliary (Platt-style) scaler and map logits to
            # calibrated probabilities.
            model = SlideGraphArch(responses=args.resp, **arch_kwargs)
            model.load(*chkpt_info)
            scaler = model.aux_model["scaler"]
            chkpt_results = scaler.predict_proba(np.array(chkpt_results, ndmin=2).T)[:, 0]
        cum_results.append(chkpt_results)
    cum_results = np.array(cum_results)
    if len(args.resp) > 1:
        cum_results = np.squeeze(cum_results)
    # Generalize for different number of responses with node predictions always last (but check)
    ####################
    metric_dict = {}
    pred_dict = {
        "fold": split_idx, "best_epoch": best_epoch[split_idx],
    }
    all_mets = []
    # Split pooled (logit, truth) pairs back out per response head.
    for i in range(len(args.resp)):
        node_level = False
        if 'epithelium' in args.resp[i]:
            # Epithelium head predicts per-node, not per-slide.
            node_level = True
        output_logit, output_true = [], []
        for out in cum_results:
            if node_level:
                # Node-level outputs occupy the tail of `out` from index i on.
                output_logit.extend([out_[0] for out_ in out[i:]])
                output_true.extend([out_[1] for out_ in out[i:]])
            else:
                # Slide-level head: one (logit, truth) pair at position i.
                output_logit.append(out[i][0])
                output_true.append(out[i][1])
                ############### # TODO: these are scalers ###############
                # NOTE(review): debug prints; original indentation lost in this
                # copy — placement inside the else-branch is reconstructed.
                print(f'out[{i}] length:', len(out[i]))
                print('out length:', len(out))
        # NOTE(review): float16 loses precision on logits/probabilities —
        # presumably a memory saving for node-level outputs; confirm it does not
        # distort AUROC/thresholding.
        output_logit = np.array(output_logit, dtype=np.float16)
        output_true = np.array(output_true)
        metric_dict.update(create_resp_metric_dict(args.resp[i], output_true, output_logit, best_epoch[split_idx]))
        # Add thresholded metrics
        print(' Using thresholding from all cohorts')
        threshold = find_optimal_cutoff(output_true, output_logit)
        resp_mets = create_resp_metric_dict(args.resp[i], output_true, output_logit, best_epoch[split_idx],
                                            cutoff=threshold)
        # Prefix thresholded metrics so they don't collide with the raw ones.
        resp_mets = {'threshold-' + k: v for k, v in resp_mets.items() if not k == 'best_epoch'}
        resp_mets[f'{args.resp[i]}-threshold'] = threshold[0]
        metric_dict.update(resp_mets)
        pred_dict.update({f"{args.resp[i]}_preds": output_logit, f"{args.resp[i]}_true": output_true})
        # Print metrics in table format
        all_mets.append(resp_mets)
    cum_stats.append(metric_dict)
    if args.log:
        # TensorBoard hparams require scalar values: flatten list-valued args
        # into underscore-joined strings.
        hparams = vars(args).copy()
        hparams['layer_dims'] = '_'.join(str(num) for num in hparams['layer_dims'])
        hparams['cohorts'] = '_'.join(str(cohort) for cohort in hparams['cohorts'])
        hparams['loss_weights'] = '_'.join(str(num) for num in hparams['loss_weights'])
        hparams['resp'] = '_'.join(str(response) for response in hparams['resp'])
        #print('hparams')
        #print(hparams)
        #print('\nmetric_dict')
        #print(metric_dict)
        # Either add metrics at end after cohort evaluation or here
        if args.cohorts != ['GRAMPIAN', 'ARISTOTLE']:
            val_summary_writer.add_hparams(hparam_dict=hparams, metric_dict=metric_dict)
    cum_preds.append(pred_dict)
# Save metrics
print(args.base_name, args.base_version, args.model_name)
stat_df = pd.DataFrame(cum_stats)
# Report mean ± std of each metric across folds.
# NOTE(review): assumes every column is numeric; non-numeric entries (e.g.
# thresholds stored as arrays) would break np.mean — verify.
for metric in stat_df.columns:
    vals = stat_df[metric]
    mu = np.mean(vals)
    va = np.std(vals)
    print(f"- {metric}: {mu:0.4f}±{va:0.4f}")
results_save_path = os.path.join(args.root_output_dir, 'results', args.model_name)
if not os.path.exists(results_save_path):
    mkdir(results_save_path)
# NOTE(review): output files have no .csv extension — intentional? They are
# written with to_csv.
stat_df.to_csv(os.path.join(results_save_path, 'mean_best_metrics_over_folds'), index=False)
preds_df = pd.DataFrame(cum_preds)
preds_df.to_csv(os.path.join(results_save_path, 'fold_predictions'), index=False)
# Save confusion matrices and prediction density plots
# (visualised for fold 0 only)
viz_fold = 0
viz_epoch = best_epoch[int(viz_fold)]
met_args = [resp for resp in args.resp if resp!='cohort_cls'] # exclude cohort_cls
#met_args = args.resp[:2] + [args.resp[-1]] # exclude any third value in case is cohort_cls
for response in list(met_args):
    if response == 'cohort_cls':
        # Redundant with the filter above, kept as a safety net.
        continue
    print(response)
    resp_true = preds_df[f'{response}_true'][0]
    resp_preds = preds_df[f'{response}_preds'][0]
    confusion_fig = plot_confusion_matrix(resp_true, threshold_predictions(resp_true, resp_preds), response,
                                          viz_fold, viz_epoch, save=True, save_img_path=args.save_img_path, thresh=True)
    density_fig = density_plot(resp_true, resp_preds, response, viz_fold, viz_epoch, save=True,
                               save_img_path=args.save_img_path)
    if args.log:
        val_summary_writer.add_figure(f'Validation Confusion Matrix with Threshold - {response}', confusion_fig)
        val_summary_writer.add_figure(f'Validation Density Plot - {response}', density_fig)
# Print metrics in table format
#all_mets = [resp_0_mets, resp_1_mets, resp_2_mets]
# NOTE(review): all_mets holds only the LAST fold's metrics (rebound each loop
# iteration above) — confirm the notebook table is meant to use the final fold.
print('Thresholded metrics printed below - can be used in Notebook table')
print()
print(f'| {args.base_name} {args.base_version} | {args.model_name.split("/")[1]} |' +\
      metric_str_thresh_all(all_mets, met_args, 'auroc', threshold=True) +
      metric_str_thresh_all(all_mets, met_args, 'balanced_acc', threshold=True) +
      metric_str_thresh_all(all_mets, met_args, 'weighted_f1', threshold=True))
print()
# Check validation metrics on different cohorts
def validation_metrics(split, chkpt_info=chkpts[0], epoch=best_epoch[0], arch_kwargs=arch_kwargs,
loader_kwargs=loader_kwargs):
chkpt_results, wsis = run_once(
resp=args.resp, loss_name=args.loss, loss_weights=args.loss_weights, scale=args.scaler,
preproc=args.preproc, temper=args.temper,
dataset_dict=split,
num_epochs=1,
graph_dir=args.epi_graph_dir,
save_dir=None,
nodes_preproc_func=nodes_preproc_func,
dev_mode=args.dev_mode,
val_summary_writer=val_summary_writer,
pretrained=chkpt_info,
arch_kwargs=arch_kwargs,
loader_kwargs=loader_kwargs
)
# * re-calibrate logit to probabilities
chkpt_results = np.array(chkpt_results)
chkpt_results = np.squeeze(chkpt_results)
if args.scaler:
model = SlideGraphArch(responses=args.resp, **arch_kwargs)
model.load(*chkpt_info)
scaler = model.aux_model["scaler"]
chkpt_results = scaler.predict_proba(np.array(chkpt_results, ndmin=2).T)[:, 0]
cum_results = chkpt_results
cum_results = np.array(cum_results)
cum_results = np.squeeze(cum_results)
output_1_logit, output_1_true = [], []
output_2_logit, output_2_true = [], []
node_output_logit, node_output_true = [], []
if 'cohort_cls' in args.resp:
epi_idx = 3
else:
epi_idx = 2
for out in cum_results:
output_1_logit.append(out[0][0])
output_1_true.append(out[0][1])
output_2_logit.append(out[1][0])
output_2_true.append(out[1][1])
node_output_logit.extend([out_[0] for out_ in out[epi_idx:]])
node_output_true.extend([out_[1] for out_ in out[epi_idx:]])
output_1_logit = np.array(output_1_logit)
output_1_true = np.array(output_1_true)
output_2_logit = np.array(output_2_logit)
output_2_true = np.array(output_2_true)
node_output_logit = np.array(node_output_logit)
node_output_true = np.array(node_output_true)
print('Without thresholding')
metric_dict = {}
print(args.resp[0])
metric_dict.update(create_resp_metric_dict(args.resp[0], output_1_true, output_1_logit, epoch))
print(args.resp[1])
metric_dict.update(create_resp_metric_dict(args.resp[1], output_2_true, output_2_logit, epoch))
print(args.resp[-1])
metric_dict.update(create_resp_metric_dict(args.resp[-1], node_output_true, node_output_logit, epoch))
print('Using thresholding from joint cohorts')
print(args.resp[0])
threshold_0 = find_optimal_cutoff(output_1_true, output_1_logit)
print('threshold_0:', threshold_0)
resp_0_mets = create_resp_metric_dict(args.resp[0], output_1_true, output_1_logit, epoch, cutoff=threshold_0)