# Diff of /make_splits.py [000000] .. [2095ed]
# (diff-viewer header for make_splits.py, preserved as a comment)
1
### data_loaders.py
2
import argparse
3
import os
4
import pickle
5
6
import numpy as np
7
import pandas as pd
8
from PIL import Image
9
from sklearn import preprocessing
10
11
# Env
12
from networks import define_net
13
from utils import getCleanAllDataset
14
import torch
15
from torchvision import transforms
16
from options import parse_gpuids
17
18
### Initializes parser and data
19
"""
20
all_st
21
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 0 --use_vgg_features 0 --roi_dir all_st # for training Surv Path, Surv Graph, and testing Surv Graph
22
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 1 --use_vgg_features 0 --roi_dir all_st # for training Grad Path, Grad Graph, and testing Surv_graph
23
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 0 --roi_dir all_st # for training Surv Omic, Surv Graphomic
24
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 0 --roi_dir all_st # for training Grad Omic, Grad Graphomic
25
26
all_st_patches_512 (no VGG)
27
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 0 --use_vgg_features 0 --roi_dir all_st_patches_512 # for testing Surv Path
28
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 1 --use_vgg_features 0 --roi_dir all_st_patches_512 # for testing Grad Path
29
30
all_st_patches_512 (use VGG)
31
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 0 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name surv_15 --gpu_ids 0 # for Surv Pathgraph
32
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 1 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name grad_15 --act_type LSM --label_dim 3 --gpu_ids 1 # for Grad Pathgraph
33
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name surv_15 --gpu_ids 2 # for Surv Pathomic, Pathgraphomic
34
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name grad_15 --act_type LSM --label_dim 3 --gpu_ids 3 # for Grad Pathomic, Pathgraphomic
35
36
37
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 1 --make_all_train 1
38
39
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 0 --roi_dir all_st --use_rnaseq 1
40
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 0 --roi_dir all_st --use_rnaseq 1
41
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name surv_15 --use_rnaseq 1 --gpu_ids 2
42
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name grad_15 --use_rnaseq 1 --act_type LSM --label_dim 3 --gpu_ids 3
43
44
45
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 0 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name surv_15_rnaseq --gpu_ids 0
46
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name surv_15_rnaseq --use_rnaseq 1 --gpu_ids 0
47
48
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 1 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name grad_15 --act_type LSM --label_dim 3 --gpu_ids 1
49
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name grad_15 --use_rnaseq 1 --act_type LSM --label_dim 3 --gpu_ids 1
50
51
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 0 --roi_dir all_st --use_rnaseq 1
52
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name surv_15_rnaseq --gpu_ids 2
53
54
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 0 --roi_dir all_st --use_rnaseq 1
55
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name grad_15 --act_type LSM --label_dim 3 --gpu_ids 3
56
57
58
59
60
"""
61
def parse_args():
    """Parse command-line options for split creation and return the opt namespace.

    Uses parse_known_args so unrecognized flags are tolerated, then normalizes
    the comma-separated --gpu_ids string via options.parse_gpuids.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', type=str, default='./data/TCGA_GBMLGG/', help="datasets")
    parser.add_argument('--roi_dir', type=str, default='all_st')
    parser.add_argument('--graph_feat_type', type=str, default='cpc', help="graph features to use")
    parser.add_argument('--ignore_missing_moltype', type=int, default=0, help="Ignore data points with missing molecular subtype")
    parser.add_argument('--ignore_missing_histype', type=int, default=0, help="Ignore data points with missing histology subtype")
    parser.add_argument('--make_all_train', type=int, default=0)
    parser.add_argument('--use_vgg_features', type=int, default=0)
    parser.add_argument('--use_rnaseq', type=int, default=0)

    parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints/TCGA_GBMLGG/', help='models are saved here')
    parser.add_argument('--exp_name', type=str, default='surv_15_rnaseq', help='name of the project. It decides where to store samples and models')
    parser.add_argument('--gpu_ids', type=str, default='0,1,2,3', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
    parser.add_argument('--mode', type=str, default='path', help='mode')
    parser.add_argument('--model_name', type=str, default='path', help='model name')
    parser.add_argument('--task', type=str, default='surv', help='surv | grad')
    parser.add_argument('--act_type', type=str, default='Sigmoid', help='activation function')
    parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
    parser.add_argument('--label_dim', type=int, default=1, help='size of output')
    parser.add_argument('--batch_size', type=int, default=32, help="Number of batches to train/test for. Default: 32")
    parser.add_argument('--path_dim', type=int, default=32)
    parser.add_argument('--init_type', type=str, default='none', help='network initialization [normal | xavier | kaiming | orthogonal | max]. Max seems to work well')
    parser.add_argument('--dropout_rate', default=0.25, type=float, help='0 - 0.25. Increasing dropout_rate helps overfitting. Some people have gone as high as 0.5. You can try adding more regularization')

    # parse_known_args()[0]: keep only recognized options, ignore the rest.
    opt = parser.parse_known_args()[0]
    opt = parse_gpuids(opt)
    return opt
90
91
opt = parse_args()
device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
metadata, all_dataset = getCleanAllDataset(opt.dataroot, opt.ignore_missing_moltype, opt.ignore_missing_histype, opt.use_rnaseq)

### Map each TCGA patient ID (first 12 chars of the ROI filename) to its ROI image files
img_fnames = os.listdir(os.path.join(opt.dataroot, opt.roi_dir))
pat2img = {}
for img_fname in img_fnames:
    pat2img.setdefault(img_fname[:12], []).append(img_fname)

### Dictionary holding everything that gets pickled at the end of the script
data_dict = {'data_pd': all_dataset}
#data_dict['pat2img'] = pat2img
#data_dict['img_fnames'] = img_fnames
cv_splits = {}

### Load the 15 PNAS train/test fold assignments, indexed by TCGA ID
pnas_splits = pd.read_csv(opt.dataroot+'pnas_splits.csv')
pnas_splits.columns = ['TCGA ID'] + [str(fold) for fold in range(1, 16)]
pnas_splits.index = pnas_splits['TCGA ID']
pnas_splits = pnas_splits.drop(['TCGA ID'], axis=1)
114
115
### get path_feats
116
def get_vgg_features(model, device, img_path):
    """Return a path feature for the ROI at img_path.

    With no model, the raw image path is returned unchanged so feature
    extraction can happen at load time instead; otherwise the image is run
    through the given network and its feature output is returned as a numpy
    array.
    """
    if model is None:
        return img_path
    # Load the ROI, scale channels to [-1, 1], and add a batch dimension.
    image = Image.open(img_path).convert('RGB')
    to_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    batch = to_tensor(image).unsqueeze(0)
    features, hazard = model(x_path=batch.to(device))
    return features.cpu().detach().numpy()
125
126
### method for constructing aligned
127
def getAlignedMultimodalData(opt, model, device, all_dataset, pat_split, pat2img):
    """Collect aligned multimodal data for every ROI of every patient in pat_split.

    For each ROI image of each patient found in all_dataset, gathers the
    patient name, path feature (VGG features or raw image path, see
    get_vgg_features), matching graph .pt file path, omic feature vector, and
    the censoring / survival-months / grade labels.

    Returns:
        (x_patname, x_path, x_grph, x_omic, e, t, g) — parallel lists.
    """
    x_patname, x_path, x_grph, x_omic, e, t, g = [], [], [], [], [], [], []

    # Hoisted: the graph directory listing is loop-invariant (the original
    # called os.listdir once per image); a set gives O(1) membership checks.
    grph_dir = os.path.join(opt.dataroot, '%s_%s' % (opt.roi_dir, opt.graph_feat_type))
    grph_fnames = set(os.listdir(grph_dir))

    for pat_name in pat_split:
        if pat_name not in all_dataset.index: continue

        # Single-row frame for this patient, computed once instead of per field.
        pat_row = all_dataset[all_dataset['TCGA ID'] == pat_name]
        assert pat_row.shape[0] == 1

        for img_fname in pat2img[pat_name]:
            # BUG FIX: the original used img_fname.rstrip('.png'), which strips
            # any trailing run of the characters {'.','p','n','g'} (e.g.
            # 'x_gn.png' -> 'x_'); splitext drops only the extension.
            grph_fname = os.path.splitext(img_fname)[0] + '.pt'
            assert grph_fname in grph_fnames

            x_patname.append(pat_name)
            x_path.append(get_vgg_features(model, device, os.path.join(opt.dataroot, opt.roi_dir, img_fname)))
            x_grph.append(os.path.join(grph_dir, grph_fname))
            x_omic.append(np.array(pat_row.drop(metadata, axis=1)))
            e.append(int(pat_row['censored']))
            t.append(int(pat_row['Survival months']))
            g.append(int(pat_row['Grade']))

    return x_patname, x_path, x_grph, x_omic, e, t, g
147
148
print(all_dataset.shape)

for k in pnas_splits.columns:
    print('Creating Split %s' % k)
    # make_all_train uses every patient as "training" data (no held-out test).
    pat_train = pnas_splits.index[pnas_splits[k] == 'Train'] if opt.make_all_train == 0 else pnas_splits.index
    pat_test = pnas_splits.index[pnas_splits[k] == 'Test']

    # Optionally load the fold-specific pretrained path network so ROIs are
    # stored as VGG feature vectors instead of image file paths.
    model = None
    if opt.use_vgg_features:
        load_path = os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name, '%s_%s.pt' % (opt.model_name, k))
        model_ckpt = torch.load(load_path, map_location=device)
        model_state_dict = model_ckpt['model_state_dict']
        # Stale _metadata on older checkpoints confuses load_state_dict.
        if hasattr(model_state_dict, '_metadata'): del model_state_dict._metadata
        model = define_net(opt, None)
        if isinstance(model, torch.nn.DataParallel): model = model.module
        print('Loading the model from %s' % load_path)
        model.load_state_dict(model_state_dict)
        model.eval()

    train_x_patname, train_x_path, train_x_grph, train_x_omic, train_e, train_t, train_g = getAlignedMultimodalData(opt, model, device, all_dataset, pat_train, pat2img)
    test_x_patname, test_x_path, test_x_grph, test_x_omic, test_e, test_t, test_g = getAlignedMultimodalData(opt, model, device, all_dataset, pat_test, pat2img)

    # Each omic entry is a (1, D) array; squeeze to (D,) before stacking.
    train_x_omic, train_e, train_t = np.array(train_x_omic).squeeze(axis=1), np.array(train_e, dtype=np.float64), np.array(train_t, dtype=np.float64)
    test_x_omic, test_e, test_t = np.array(test_x_omic).squeeze(axis=1), np.array(test_e, dtype=np.float64), np.array(test_t, dtype=np.float64)

    # Standardize omics with statistics fit on the training fold only.
    scaler = preprocessing.StandardScaler().fit(train_x_omic)
    train_x_omic = scaler.transform(train_x_omic)
    test_x_omic = scaler.transform(test_x_omic)

    train_data = {'x_patname': train_x_patname,
                  'x_path':np.array(train_x_path),
                  'x_grph':train_x_grph,
                  'x_omic':train_x_omic,
                  'e':np.array(train_e, dtype=np.float64), 
                  't':np.array(train_t, dtype=np.float64),
                  'g':np.array(train_g, dtype=np.float64)}

    test_data = {'x_patname': test_x_patname,
                 'x_path':np.array(test_x_path),
                 'x_grph':test_x_grph,
                 'x_omic':test_x_omic,
                 'e':np.array(test_e, dtype=np.float64),
                 't':np.array(test_t, dtype=np.float64),
                 'g':np.array(test_g, dtype=np.float64)}

    # (The original also pre-initialized cv_splits[int(k)] = {} at the top of
    # the loop; that value was unconditionally overwritten here, so it's gone.)
    cv_splits[int(k)] = {'train': train_data, 'test': test_data}

    if opt.make_all_train: break

data_dict['cv_splits'] = cv_splits

# BUG FIX: the original passed an anonymous open(..., 'wb') handle straight to
# pickle.dump, leaking the file descriptor and relying on GC to flush/close.
split_path = '%s/splits/gbmlgg15cv_%s_%d_%d_%d%s.pkl' % (opt.dataroot, opt.roi_dir, opt.ignore_missing_moltype, opt.ignore_missing_histype, opt.use_vgg_features, '_rnaseq' if opt.use_rnaseq else '')
with open(split_path, 'wb') as split_file:
    pickle.dump(data_dict, split_file)