|
a |
|
b/make_splits.py |
|
|
1 |
### make_splits.py
|
|
2 |
import argparse |
|
|
3 |
import os |
|
|
4 |
import pickle |
|
|
5 |
|
|
|
6 |
import numpy as np |
|
|
7 |
import pandas as pd |
|
|
8 |
from PIL import Image |
|
|
9 |
from sklearn import preprocessing |
|
|
10 |
|
|
|
11 |
# Env |
|
|
12 |
from networks import define_net |
|
|
13 |
from utils import getCleanAllDataset |
|
|
14 |
import torch |
|
|
15 |
from torchvision import transforms |
|
|
16 |
from options import parse_gpuids |
|
|
17 |
|
|
|
18 |
### Initializes parser and data |
|
|
19 |
""" |
|
|
20 |
all_st |
|
|
21 |
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 0 --use_vgg_features 0 --roi_dir all_st # for training Surv Path, Surv Graph, and testing Surv Graph |
|
|
22 |
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 1 --use_vgg_features 0 --roi_dir all_st # for training Grad Path, Grad Graph, and testing Surv_graph |
|
|
23 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 0 --roi_dir all_st # for training Surv Omic, Surv Graphomic |
|
|
24 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 0 --roi_dir all_st # for training Grad Omic, Grad Graphomic |
|
|
25 |
|
|
|
26 |
all_st_patches_512 (no VGG) |
|
|
27 |
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 0 --use_vgg_features 0 --roi_dir all_st_patches_512 # for testing Surv Path |
|
|
28 |
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 1 --use_vgg_features 0 --roi_dir all_st_patches_512 # for testing Grad Path |
|
|
29 |
|
|
|
30 |
all_st_patches_512 (use VGG) |
|
|
31 |
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 0 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name surv_15 --gpu_ids 0 # for Surv Pathgraph |
|
|
32 |
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 1 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name grad_15 --act_type LSM --label_dim 3 --gpu_ids 1 # for Grad Pathgraph |
|
|
33 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name surv_15 --gpu_ids 2 # for Surv Pathomic, Pathgraphomic |
|
|
34 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name grad_15 --act_type LSM --label_dim 3 --gpu_ids 3 # for Grad Pathomic, Pathgraphomic |
|
|
35 |
|
|
|
36 |
|
|
|
37 |
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 1 --make_all_train 1 |
|
|
38 |
|
|
|
39 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 0 --roi_dir all_st --use_rnaseq 1 |
|
|
40 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 0 --roi_dir all_st --use_rnaseq 1 |
|
|
41 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name surv_15 --use_rnaseq 1 --gpu_ids 2 |
|
|
42 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name grad_15 --use_rnaseq 1 --act_type LSM --label_dim 3 --gpu_ids 3 |
|
|
43 |
|
|
|
44 |
|
|
|
45 |
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 0 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name surv_15_rnaseq --gpu_ids 0 |
|
|
46 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name surv_15_rnaseq --use_rnaseq 1 --gpu_ids 0 |
|
|
47 |
|
|
|
48 |
python make_splits.py --ignore_missing_moltype 0 --ignore_missing_histype 1 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name grad_15 --act_type LSM --label_dim 3 --gpu_ids 1 |
|
|
49 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name grad_15 --use_rnaseq 1 --act_type LSM --label_dim 3 --gpu_ids 1 |
|
|
50 |
|
|
|
51 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 0 --roi_dir all_st --use_rnaseq 1 |
|
|
52 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 0 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name surv_15_rnaseq --gpu_ids 2 |
|
|
53 |
|
|
|
54 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 0 --roi_dir all_st --use_rnaseq 1 |
|
|
55 |
python make_splits.py --ignore_missing_moltype 1 --ignore_missing_histype 1 --use_vgg_features 1 --roi_dir all_st_patches_512 --exp_name grad_15 --act_type LSM --label_dim 3 --gpu_ids 3 |
|
|
56 |
|
|
|
57 |
|
|
|
58 |
|
|
|
59 |
|
|
|
60 |
""" |
|
|
61 |
def parse_args():
    """Parse command-line options controlling which split pickle is built.

    Returns:
        argparse.Namespace: parsed options, after GPU ids have been resolved
        by the project's ``parse_gpuids`` helper (from ``options``).
    """
    parser = argparse.ArgumentParser()
    # Data location / cohort-filtering options.
    parser.add_argument('--dataroot', type=str, default='./data/TCGA_GBMLGG/', help="datasets")
    parser.add_argument('--roi_dir', type=str, default='all_st')
    parser.add_argument('--graph_feat_type', type=str, default='cpc', help="graph features to use")
    parser.add_argument('--ignore_missing_moltype', type=int, default=0, help="Ignore data points with missing molecular subtype")
    # FIX: help text typo ("missign" -> "missing").
    parser.add_argument('--ignore_missing_histype', type=int, default=0, help="Ignore data points with missing histology subtype")
    parser.add_argument('--make_all_train', type=int, default=0)
    parser.add_argument('--use_vgg_features', type=int, default=0)
    parser.add_argument('--use_rnaseq', type=int, default=0)

    # Checkpoint / model options (only consulted when --use_vgg_features 1).
    parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints/TCGA_GBMLGG/', help='models are saved here')
    parser.add_argument('--exp_name', type=str, default='surv_15_rnaseq', help='name of the project. It decides where to store samples and models')
    parser.add_argument('--gpu_ids', type=str, default='0,1,2,3', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
    parser.add_argument('--mode', type=str, default='path', help='mode')
    parser.add_argument('--model_name', type=str, default='path', help='mode')
    parser.add_argument('--task', type=str, default='surv', help='surv | grad')
    parser.add_argument('--act_type', type=str, default='Sigmoid', help='activation function')
    parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
    parser.add_argument('--label_dim', type=int, default=1, help='size of output')
    # FIX: help text claimed "Default: 256" while the actual default is 32.
    parser.add_argument('--batch_size', type=int, default=32, help="Number of batches to train/test for. Default: 32")
    parser.add_argument('--path_dim', type=int, default=32)
    parser.add_argument('--init_type', type=str, default='none', help='network initialization [normal | xavier | kaiming | orthogonal | max]. Max seems to work well')
    parser.add_argument('--dropout_rate', default=0.25, type=float, help='0 - 0.25. Increasing dropout_rate helps overfitting. Some people have gone as high as 0.5. You can try adding more regularization')

    # parse_known_args ignores unrecognized flags instead of erroring,
    # so this script tolerates options meant for other project scripts.
    opt = parser.parse_known_args()[0]
    opt = parse_gpuids(opt)
    return opt
|
|
90 |
|
|
|
91 |
### Initializes options, device, and the cleaned dataset.
opt = parse_args()
# Run on the first requested GPU if any ids were parsed; otherwise CPU.
device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
# metadata: non-omic column names to drop from feature vectors later;
# all_dataset: cleaned clinical+omic DataFrame (project helper from utils).
metadata, all_dataset = getCleanAllDataset(opt.dataroot, opt.ignore_missing_moltype, opt.ignore_missing_histype, opt.use_rnaseq)

### Creates a mapping from TCGA ID -> Image ROI
# The first 12 characters of each ROI filename are the TCGA patient barcode,
# so every patient maps to the list of their ROI image filenames.
img_fnames = os.listdir(os.path.join(opt.dataroot, opt.roi_dir))
pat2img = {}
for pat, img_fname in zip([img_fname[:12] for img_fname in img_fnames], img_fnames):
    if pat not in pat2img.keys(): pat2img[pat] = []
    pat2img[pat].append(img_fname)

### Dictionary file containing split information
data_dict = {}
data_dict['data_pd'] = all_dataset
#data_dict['pat2img'] = pat2img
#data_dict['img_fnames'] = img_fnames
cv_splits = {}  # filled below: fold number -> {'train': ..., 'test': ...}

### Extracting K-Fold Splits
# pnas_splits.csv: one row per patient; columns '1'..'15' mark each patient
# as 'Train' or 'Test' for the corresponding cross-validation fold.
pnas_splits = pd.read_csv(opt.dataroot+'pnas_splits.csv')
pnas_splits.columns = ['TCGA ID']+[str(k) for k in range(1, 16)]
pnas_splits.index = pnas_splits['TCGA ID']
pnas_splits = pnas_splits.drop(['TCGA ID'], axis=1)
|
|
114 |
|
|
|
115 |
### get path_feats |
|
|
116 |
def get_vgg_features(model, device, img_path):
    """Return features for the ROI at ``img_path``.

    When ``model`` is None the path itself is returned unchanged (features
    are extracted later, at load time). Otherwise the image is loaded,
    normalized to [-1, 1], pushed through ``model`` as a single-image batch,
    and the resulting feature tensor is returned as a numpy array.
    """
    if model is None:
        # No extractor requested: defer by keeping the raw image path.
        return img_path

    image = Image.open(img_path).convert('RGB')
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    batch = to_tensor(image).unsqueeze(0).to(device)
    features, hazard = model(x_path=batch)
    return features.cpu().detach().numpy()
|
|
125 |
|
|
|
126 |
### method for constructing aligned |
|
|
127 |
def getAlignedMultimodalData(opt, model, device, all_dataset, pat_split, pat2img):
    """Collect aligned multimodal data for the patients in ``pat_split``.

    For every ROI image of every patient present in ``all_dataset``, appends
    in lockstep: patient barcode, path feature (VGG features or the raw image
    path, see get_vgg_features), graph ``.pt`` file path, omic feature vector,
    censoring indicator, survival time (months), and grade.

    Returns:
        tuple: (x_patname, x_path, x_grph, x_omic, e, t, g) parallel lists.
    """
    x_patname, x_path, x_grph, x_omic, e, t, g = [], [], [], [], [], [], []

    # Hoisted out of the loops: the original called os.listdir once per image
    # (accidental O(n^2) directory scans).
    grph_dir = os.path.join(opt.dataroot, '%s_%s' % (opt.roi_dir, opt.graph_feat_type))
    grph_fnames = set(os.listdir(grph_dir))

    for pat_name in pat_split:
        if pat_name not in all_dataset.index: continue

        # Hoisted: one DataFrame filter per patient instead of six per image.
        pat_row = all_dataset[all_dataset['TCGA ID'] == pat_name]
        assert pat_row.shape[0] == 1

        for img_fname in pat2img[pat_name]:
            # BUGFIX: the original used img_fname.rstrip('.png'), which strips
            # a *set of characters* ('.', 'p', 'n', 'g'), not the suffix —
            # filenames whose stem ends in p/n/g were truncated and mapped to
            # the wrong (or a missing) graph file. Strip the extension properly.
            grph_fname = os.path.splitext(img_fname)[0] + '.pt'
            assert grph_fname in grph_fnames

            x_patname.append(pat_name)
            x_path.append(get_vgg_features(model, device, os.path.join(opt.dataroot, opt.roi_dir, img_fname)))
            x_grph.append(os.path.join(grph_dir, grph_fname))
            # Drop the metadata columns to keep only the omic features.
            x_omic.append(np.array(pat_row.drop(metadata, axis=1)))
            e.append(int(pat_row['censored']))
            t.append(int(pat_row['Survival months']))
            g.append(int(pat_row['Grade']))

    return x_patname, x_path, x_grph, x_omic, e, t, g
|
|
147 |
|
|
|
148 |
print(all_dataset.shape)

# Build one aligned train/test dataset per cross-validation fold.
for k in pnas_splits.columns:
    print('Creating Split %s' % k)
    # With --make_all_train every patient goes into the training split.
    pat_train = pnas_splits.index[pnas_splits[k] == 'Train'] if opt.make_all_train == 0 else pnas_splits.index
    pat_test = pnas_splits.index[pnas_splits[k] == 'Test']
    cv_splits[int(k)] = {}

    model = None
    if opt.use_vgg_features:
        # Load the fold-specific pretrained path model so ROI images can be
        # converted to VGG features instead of being stored as file paths.
        load_path = os.path.join(opt.checkpoints_dir, opt.exp_name, opt.model_name, '%s_%s.pt' % (opt.model_name, k))
        model_ckpt = torch.load(load_path, map_location=device)
        model_state_dict = model_ckpt['model_state_dict']
        # Drop serialization metadata so load_state_dict doesn't choke on it.
        if hasattr(model_state_dict, '_metadata'): del model_state_dict._metadata
        model = define_net(opt, None)
        # Unwrap DataParallel so state-dict keys match the bare module.
        if isinstance(model, torch.nn.DataParallel): model = model.module
        print('Loading the model from %s' % load_path)
        model.load_state_dict(model_state_dict)
        model.eval()

    train_x_patname, train_x_path, train_x_grph, train_x_omic, train_e, train_t, train_g = getAlignedMultimodalData(opt, model, device, all_dataset, pat_train, pat2img)
    test_x_patname, test_x_path, test_x_grph, test_x_omic, test_e, test_t, test_g = getAlignedMultimodalData(opt, model, device, all_dataset, pat_test, pat2img)

    # Each omic entry is a (1, F) row; squeeze to (F,) before stacking.
    train_x_omic, train_e, train_t = np.array(train_x_omic).squeeze(axis=1), np.array(train_e, dtype=np.float64), np.array(train_t, dtype=np.float64)
    test_x_omic, test_e, test_t = np.array(test_x_omic).squeeze(axis=1), np.array(test_e, dtype=np.float64), np.array(test_t, dtype=np.float64)

    # Standardize omic features: fit on train only, apply to both splits
    # (avoids leaking test statistics into training).
    scaler = preprocessing.StandardScaler().fit(train_x_omic)
    train_x_omic = scaler.transform(train_x_omic)
    test_x_omic = scaler.transform(test_x_omic)

    train_data = {'x_patname': train_x_patname,
                  'x_path':np.array(train_x_path),
                  'x_grph':train_x_grph,
                  'x_omic':train_x_omic,
                  'e':np.array(train_e, dtype=np.float64),
                  't':np.array(train_t, dtype=np.float64),
                  'g':np.array(train_g, dtype=np.float64)}

    test_data = {'x_patname': test_x_patname,
                 'x_path':np.array(test_x_path),
                 'x_grph':test_x_grph,
                 'x_omic':test_x_omic,
                 'e':np.array(test_e, dtype=np.float64),
                 't':np.array(test_t, dtype=np.float64),
                 'g':np.array(test_g, dtype=np.float64)}

    dataset = {'train':train_data, 'test':test_data}
    cv_splits[int(k)] = dataset

    # All patients are already in fold 1's train split; no point looping.
    if opt.make_all_train: break

data_dict['cv_splits'] = cv_splits

# Pickle filename encodes every option that changes the dataset contents.
pickle.dump(data_dict, open('%s/splits/gbmlgg15cv_%s_%d_%d_%d%s.pkl' % (opt.dataroot, opt.roi_dir, opt.ignore_missing_moltype, opt.ignore_missing_histype, opt.use_vgg_features, '_rnaseq' if opt.use_rnaseq else ''), 'wb'))