data_loader.py

""" Code for data loader """
import numpy as np
import os, sys, copy
import random
import tensorflow as tf

from sklearn.model_selection import StratifiedKFold
from tensorflow.python.platform import flags

import tqdm
import pickle as pkl

FLAGS = flags.FLAGS

PADDING_ID = 1016  # use the number of group codes as the padding id
# the maximum group code index is 1015, starting from 0
N_WORDS = 1017
TIMESTEPS = 21  # chosen by statistics

TASKS = ["AD", "PD", "DM", "AM", "MCI"]

class DataLoader(object):
    '''
    Data loader capable of generating batches of OHSU data.
    '''
    def __init__(self, source, target, true_target, n_tasks, n_samples_per_task, meta_batch_size):
        """
        Args:
            source: source tasks
            target: simulated target task(s)
            true_target: true target task (used for testing)
            n_tasks: number of tasks, including both source and simulated target tasks
            n_samples_per_task: number of samples to generate per task in one batch
            meta_batch_size: size of the meta batch (e.g. number of functions)
        """
        ### load data: training
        self.intmd_path = 'intermediate/'
        self.source = source
        self.target = target
        self.timesteps = TIMESTEPS
        self.code_size = 0
        # self.code_size = N_WORDS-1  # set code_size to the number of all possible codes,
        #                             # in order to use it in pretraining
        self.task_code_size = dict()  # per-task icd code size, disease : code size
        print("The selected number of timesteps is:", self.timesteps)

        self.data_to_show = dict()
        self.label_to_show = dict()
        self.ratio_t = 0.8
        self.pat_reduce = False
        self.code_set = set()
        self.data_s, self.data_t, self.label_s, self.label_t = self.load_data()

        ## load data: validate & test
        self.true_target = true_target
        if FLAGS.method == "mlp":
            data_tt, label_tt = self.load_data_vector(self.true_target[0])  # only 1 true target, index is 0
        elif FLAGS.method == "rnn" or FLAGS.method == "cnn":
            data_tt, label_tt = self.load_data_matrix(self.true_target[0])
        # compute code_size
        self.code_size = max(self.task_code_size.values())
        print("The code_size is:", self.code_size)
        # pad/truncate all samples into matrices of the same size
        data_tt, label_tt = self.get_data_prepared(data_tt, label_tt)

        for i in range(len(self.source)):
            self.data_s[i], self.label_s[i] = self.get_data_prepared(self.data_s[i], self.label_s[i])

        for i in range(len(self.target)):
            self.data_t[i], self.label_t[i] = self.get_data_prepared(self.data_t[i], self.label_t[i])

        # cross validation for the true target
        self.n_fold = 5
        self.get_cross_val(data_tt, label_tt, n_fold=self.n_fold)

        ### set model params
        self.meta_batch_size = meta_batch_size
        self.n_samples_per_task = n_samples_per_task  # in one meta batch
        self.n_tasks = n_tasks
        self.n_words = N_WORDS

        ## generate finetune data
        self.tt_sample, self.tt_label = dict(), dict()
        self.tt_sample_val, self.tt_label_val = dict(), dict()
        for ifold in range(self.n_fold):  # generate n-fold cv data for finetuning
            self.tt_sample[ifold], self.tt_label[ifold] = self.generate_finetune_data(is_training=True, ifold=ifold)
            self.tt_sample_val[ifold], self.tt_label_val[ifold] = self.generate_finetune_data(is_training=False, ifold=ifold)

        self.episode = self.generate_meta_idx_batches(is_training=True)
        self.episode_val = dict()
        for ifold in range(self.n_fold):  # true target validation
            self.episode_val[ifold] = self.generate_meta_idx_batches(is_training=False, ifold=ifold)

    def get_cross_val(self, X, y, n_fold=5):
        '''split the true target into train (might be useful in finetuning) and test (for evaluation)'''
        self.data_tt_tr, self.data_tt_val = dict(), dict()
        self.label_tt_tr, self.label_tt_val = dict(), dict()
        # shuffle=True is required for random_state to take effect in StratifiedKFold
        skf = StratifiedKFold(n_splits=n_fold, shuffle=True, random_state=99991)
        ifold = 0
        print("split the true target ...")
        for train_index, test_index in skf.split(X, y):
            self.data_tt_tr[ifold], self.data_tt_val[ifold] = X[train_index], X[test_index]
            self.label_tt_tr[ifold], self.label_tt_val[ifold] = y[train_index], y[test_index]
            ifold += 1

    def load_data_matrix(self, task):
        '''load sequential data for cnn or rnn; one matrix per sample'''
        X_pos, y_pos = [], []
        X_neg, y_neg = [], []
        with open(self.intmd_path + task + '.pos.pkl', 'rb') as f:
            X_pos_mat, y_pos_mat = pkl.load(f)

        with open(self.intmd_path + task + '.neg.pkl', 'rb') as f:
            X_neg_mat, y_neg_mat = pkl.load(f)

        print("The number of positive samples in task %s is:" % task, len(y_pos_mat))
        print("The number of negative samples in task %s is:" % task, len(y_neg_mat))

        for s, array in X_pos_mat.items():
            X_pos.append(array)  # X_pos_mat[s] size: seq_len x n_words
            y_pos.append(y_pos_mat[s])

        for s, array in X_neg_mat.items():
            X_neg.append(array)
            y_neg.append(y_neg_mat[s])
        return (X_pos, X_neg), (y_pos, y_neg)

    def get_fixed_timesteps(self, X_pos, X_neg):
        '''drop the earliest timesteps so that no sample is longer than the selected number'''
        # positives:
        for i in range(len(X_pos)):
            timesteps = X_pos[i].shape[0]
            if timesteps > self.timesteps:
                X_pos[i] = X_pos[i][timesteps-self.timesteps:, :]
        # negatives:
        for i in range(len(X_neg)):
            timesteps = X_neg[i].shape[0]
            if timesteps > self.timesteps:
                X_neg[i] = X_neg[i][timesteps-self.timesteps:, :]
        return (X_pos, X_neg)

    def get_fixed_codesize(self, X_pos, X_neg):
        '''delete the -1 (padding) columns beyond the selected code size'''
        # positives:
        for i in range(len(X_pos)):
            code_size = X_pos[i].shape[1]
            if code_size > self.code_size:
                X_pos[i] = X_pos[i][:, :self.code_size]
        # negatives:
        for i in range(len(X_neg)):
            code_size = X_neg[i].shape[1]
            if code_size > self.code_size:
                X_neg[i] = X_neg[i][:, :self.code_size]
        return (X_pos, X_neg)

    def get_feed_records(self, X):
        '''generate ehrs as a 3-d tensor that can be fed to the networks'''
        n_samples = len(X)
        # pad every sample to (timesteps, code_size) with PADDING_ID
        X_new = np.zeros([n_samples, self.timesteps, self.code_size], dtype="int32") + PADDING_ID
        for i in range(n_samples):
            timesteps = X[i].shape[0]
            X_new[i, self.timesteps-timesteps:, :] = X[i]
        return X_new

    def get_data_prepared(self, data, label):
        '''truncate, pad, and stack positive and negative samples into single arrays'''
        X_pos, X_neg = data
        y_pos, y_neg = label

        X_pos, X_neg = self.get_fixed_timesteps(X_pos, X_neg)
        X_pos, X_neg = self.get_fixed_codesize(X_pos, X_neg)
        X_pos = self.get_feed_records(X_pos)
        X_neg = self.get_feed_records(X_neg)
        # concatenate pos and neg
        data, label = np.concatenate((X_pos, X_neg), axis=0), np.concatenate((y_pos, y_neg), axis=0)
        return data, label

    def load_data(self):
        '''load data vectors or matrices for samples with labels'''
        data_s, label_s = dict(), dict()
        data_t, label_t = dict(), dict()

        self.dim_input = [TIMESTEPS, N_WORDS]
        for i in range(len(self.source)):
            data_s[i], label_s[i] = self.load_data_matrix(self.source[i])

        for i in range(len(self.target)):
            data_t[i], label_t[i] = self.load_data_matrix(self.target[i])
        return data_s, data_t, label_s, label_t

    def generate_finetune_data(self, is_training=True, ifold=0):
        '''get finetuning samples and labels'''
        try:
            if is_training:
                sample = self.data_tt_tr[ifold]
                label = self.label_tt_tr[ifold]
            else:
                sample = self.data_tt_val[ifold]
                label = self.label_tt_val[ifold]
        except (AttributeError, KeyError):
            print("Error: split training and validation data first!")
            raise
        return sample, label

    def generate_meta_batches(self, is_training=True, ifold=0):
        '''get samples and the corresponding labels, batched by episode'''
        if is_training:  # training
            prefix = "metatrain"
            data_s = self.data_s
            data_t = self.data_t
            label_s = self.label_s
            label_t = self.label_t
            self.n_total_batches = FLAGS.n_total_batches
        else:  # test & eval: the true target task is used here
            try:
                prefix = "metaval" + str(ifold)
                data_s = self.data_s
                label_s = self.label_s
                data_t = self.data_tt_val[ifold]
                label_t = self.label_tt_val[ifold]
                self.n_total_batches = int(len(label_t)/self.n_samples_per_task)
            except (AttributeError, KeyError):
                print("Error: split training and validation data first!")
                raise
        # check whether the meta batch file has already been dumped
        if os.path.isfile(self.intmd_path + "meta.batch." + prefix + ".pkl"):
            print('meta batch file exists')
            with open(self.intmd_path + "meta.batch." + prefix + ".pkl", 'rb') as f:
                sample, label = pkl.load(f)
        else:
            # generate episodes
            sample, label = [], []
            s_dict, t_dict = dict(), dict()
            for i in range(len(self.source)):
                s_dict[i] = range(len(self.label_s[i]))
            for i in range(len(self.target)):
                t_dict[i] = range(len(self.label_t[i]))
            batch_count = 0
            for _ in tqdm.tqdm(range(self.n_total_batches), 'generating meta batches'):  # progress bar
                # i.e., sample 16 patients from the selected tasks
                # len of spl and lbl: 4 * 16
                spl, lbl = [], []  # samples and labels in one episode
                for i in range(len(self.source)):  # fetch from source tasks in order
                    ### do not keep the pos/neg ratio
                    s_idx = random.sample(s_dict[i], self.n_samples_per_task)
                    spl.extend(data_s[i][s_idx])
                    lbl.extend(label_s[i][s_idx])
                ### do not keep the pos/neg ratio
                if is_training:
                    t_idx = random.sample(t_dict[0], self.n_samples_per_task)
                    spl.extend(data_t[0][t_idx])
                    lbl.extend(label_t[0][t_idx])
                else:
                    spl.extend(data_t[batch_count*self.n_samples_per_task:(batch_count+1)*self.n_samples_per_task])
                    lbl.extend(label_t[batch_count*self.n_samples_per_task:(batch_count+1)*self.n_samples_per_task])
                batch_count += 1
                # add the meta batch
                sample.append(spl)
                label.append(lbl)

            print("batch counts: ", batch_count)
            sample = np.array(sample, dtype="float32")
            label = np.array(label, dtype="float32")
        return sample, label

    def generate_meta_idx_batches(self, is_training=True, ifold=0):
        '''get sample indices, batched by episode'''
        if is_training:  # training
            prefix = "metatrain"
            data_s = self.data_s
            data_t = self.data_t
            label_s = self.label_s
            label_t = self.label_t
            self.n_total_batches = FLAGS.n_total_batches
        else:  # test & eval: the true target task is used here
            try:
                prefix = "metaval" + str(ifold)
                data_s = self.data_s
                label_s = self.label_s
                data_t = self.data_tt_val[ifold]
                label_t = self.label_tt_val[ifold]
                self.n_total_batches = int(len(label_t)/self.n_samples_per_task)
                print(data_t.shape)
                print(label_t.shape)
                print(len(label_t))
            except (AttributeError, KeyError):
                print("Error: split training and validation data first!")
                raise

        # generate episodes
        episode = []
        s_dict, t_dict = dict(), dict()
        for i in range(len(self.source)):
            s_dict[i] = range(len(self.label_s[i]))
        for i in range(len(self.target)):
            t_dict[i] = range(len(self.label_t[i]))
        batch_count = 0
        for _ in tqdm.tqdm(range(self.n_total_batches), 'generating meta batches'):  # progress bar
            # i.e., sample 16 patients from the selected tasks
            # len of idx: 4 * 16
            idx = []  # indices in one episode
            for i in range(len(self.source)):  # fetch from source tasks in order
                ### do not keep the pos/neg ratio
                s_idx = random.sample(s_dict[i], self.n_samples_per_task)
                idx.extend(s_idx)
            ### do not keep the pos/neg ratio
            if is_training:
                t_idx = random.sample(t_dict[0], self.n_samples_per_task)
                idx.extend(t_idx)
            else:
                t_idx = range(batch_count*self.n_samples_per_task, (batch_count+1)*self.n_samples_per_task)
                idx.extend(t_idx)
            batch_count += 1
            # add the meta batch
            episode.append(idx)

        print("batch counts: ", batch_count)
        return episode
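
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# Assumptions: the training script defines the flags read above
# (FLAGS.method, FLAGS.n_total_batches), and the pickled task files
# intermediate/<task>.pos.pkl / <task>.neg.pkl exist. The task names and
# sizes below are hypothetical examples.
#
# if __name__ == "__main__":
#     flags.DEFINE_string("method", "rnn", "model type: mlp / rnn / cnn")
#     flags.DEFINE_integer("n_total_batches", 1000, "number of meta batches")
#
#     loader = DataLoader(source=["AD", "PD", "DM"],   # source tasks
#                         target=["AM"],               # simulated target task
#                         true_target=["MCI"],         # true target task (held out)
#                         n_tasks=4,
#                         n_samples_per_task=16,
#                         meta_batch_size=4)
#     # meta-training index episodes and fold-0 finetuning data
#     train_episodes = loader.episode
#     ft_x, ft_y = loader.tt_sample[0], loader.tt_label[0]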