# data_loader.py
"""Code for the data loader."""
import numpy as np
import os, sys, copy
import random
import tensorflow as tf

from sklearn.model_selection import StratifiedKFold
from tensorflow.python.platform import flags

import tqdm
import pickle as pkl

FLAGS = flags.FLAGS

PADDING_ID = 1016  # use the number of group codes as the padding id;
                   # the maximum group-code index is 1015, starting from 0
N_WORDS = 1017
TIMESTEPS = 21  # chosen from data statistics

TASKS = ["AD", "PD", "DM", "AM", "MCI"]

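# Hedged sketch of the padding convention above (illustrative only; this helper
# is made up for the example and is not called anywhere in the pipeline):
def _sketch_pad_visit(codes, code_size):
    """Right-pad one visit's group codes to a fixed width with PADDING_ID."""
    row = np.full(code_size, PADDING_ID, dtype="int32")
    row[:len(codes)] = codes
    return row

# e.g. _sketch_pad_visit([3, 17, 250], 5) -> array([3, 17, 250, 1016, 1016], dtype=int32)
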
class DataLoader(object):
    '''
    Data loader capable of generating batches of OHSU data.
    '''
    def __init__(self, source, target, true_target, n_tasks, n_samples_per_task, meta_batch_size):
        """
        Args:
            source:             source tasks
            target:             simulated target task(s)
            true_target:        true target task (used for testing)
            n_tasks:            number of tasks, including both source and simulated target tasks
            n_samples_per_task: number of samples to generate per task in one batch
            meta_batch_size:    size of a meta batch (e.g. number of tasks sampled per meta-update)
        """
        ### load data: training
        self.intmd_path = 'intermediate/'
        self.source = source
        self.target = target
        self.timesteps = TIMESTEPS
        self.code_size = 0
        # self.code_size = N_WORDS-1 # set code_size to the number of all possible codes,
        #                            # for use in pretraining
        self.task_code_size = dict() # per-task code width: task name -> number of code columns
        print("The selected number of timesteps is:", self.timesteps)

        self.data_to_show = dict()
        self.label_to_show = dict()
        self.ratio_t = 0.8
        self.pat_reduce = False
        self.code_set = set()
        self.data_s, self.data_t, self.label_s, self.label_t = self.load_data()

        ## load data: validate & test
        self.true_target = true_target
        if FLAGS.method == "mlp":
            data_tt, label_tt = self.load_data_vector(self.true_target[0]) # only 1 true target, at index 0
        elif FLAGS.method in ("rnn", "cnn"):
            data_tt, label_tt = self.load_data_matrix(self.true_target[0])
            # compute code_size as the widest code row across tasks
            self.code_size = max(self.task_code_size.values())
            print("The code_size is:", self.code_size)
            # turn every sample into a matrix of the same size
            data_tt, label_tt = self.get_data_prepared(data_tt, label_tt)

            for i in range(len(self.source)):
                self.data_s[i], self.label_s[i] = self.get_data_prepared(self.data_s[i], self.label_s[i])

            for i in range(len(self.target)):
                self.data_t[i], self.label_t[i] = self.get_data_prepared(self.data_t[i], self.label_t[i])

        # cross-validation for the true target
        self.n_fold = 5
        self.get_cross_val(data_tt, label_tt, n_fold=self.n_fold)

        ### set model params
        self.meta_batch_size = meta_batch_size
        self.n_samples_per_task = n_samples_per_task # in one meta batch
        self.n_tasks = n_tasks
        self.n_words = N_WORDS

        ## generate finetune data
        self.tt_sample, self.tt_label = dict(), dict()
        self.tt_sample_val, self.tt_label_val = dict(), dict()
        for ifold in range(self.n_fold): # generate n-fold CV data for fine-tuning
            self.tt_sample[ifold], self.tt_label[ifold] = self.generate_finetune_data(is_training=True, ifold=ifold)
            self.tt_sample_val[ifold], self.tt_label_val[ifold] = self.generate_finetune_data(is_training=False, ifold=ifold)

        self.episode = self.generate_meta_idx_batches(is_training=True)
        self.episode_val = dict()
        for ifold in range(self.n_fold): # true target validation
            self.episode_val[ifold] = self.generate_meta_idx_batches(is_training=False, ifold=ifold)

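    @staticmethod
    def _sketch_constructor_usage():
        '''Illustrative sketch only (task names taken from TASKS, all other
        values made up): how the loader is meant to be constructed. Requires
        FLAGS.method / FLAGS.n_total_batches to be parsed and the
        intermediate/*.pkl files to exist.'''
        return DataLoader(source=["AD", "PD", "DM"], target=["AM"],
                          true_target=["MCI"], n_tasks=4,
                          n_samples_per_task=16, meta_batch_size=1)
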
    def get_cross_val(self, X, y, n_fold=5):
        '''split the true target into train (useful for fine-tuning) and test (for evaluation)'''
        self.data_tt_tr, self.data_tt_val = dict(), dict()
        self.label_tt_tr, self.label_tt_val = dict(), dict()
        # shuffle=True so that random_state takes effect (sklearn ignores or
        # rejects random_state when shuffle=False)
        skf = StratifiedKFold(n_splits=n_fold, shuffle=True, random_state=99991)
        ifold = 0
        print("splitting the true target ...")
        for train_index, test_index in skf.split(X, y):
            self.data_tt_tr[ifold], self.data_tt_val[ifold] = X[train_index], X[test_index]
            self.label_tt_tr[ifold], self.label_tt_val[ifold] = y[train_index], y[test_index]
            ifold += 1

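    @staticmethod
    def _sketch_cross_val():
        '''Illustrative sketch only (not used by the loader): how the stratified
        split above behaves on toy labels; every name here is made up.'''
        X_toy = np.arange(10).reshape(10, 1)
        y_toy = np.array([0] * 5 + [1] * 5)
        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=99991)
        for _, test_index in skf.split(X_toy, y_toy):
            # each test fold preserves the 50/50 class ratio: one of each class
            assert y_toy[test_index].sum() == 1
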
    def load_data_matrix(self, task):
        '''load sequential data vectors for cnn or rnn: one matrix per sample'''
        X_pos, y_pos = [], []
        X_neg, y_neg = [], []
        with open(self.intmd_path + task + '.pos.pkl', 'rb') as f:
            X_pos_mat, y_pos_mat = pkl.load(f)

        with open(self.intmd_path + task + '.neg.pkl', 'rb') as f:
            X_neg_mat, y_neg_mat = pkl.load(f)

        print("The number of positive samples in task %s is:" % task, len(y_pos_mat))
        print("The number of negative samples in task %s is:" % task, len(y_neg_mat))

        # record the widest code row seen for this task; __init__ later sets
        # code_size to the maximum width across all tasks
        self.task_code_size[task] = max(
            arr.shape[1] for arr in list(X_pos_mat.values()) + list(X_neg_mat.values()))

        for s, array in X_pos_mat.items():
            X_pos.append(array) # X_pos_mat[s] size: seq_len x n_words
            y_pos.append(y_pos_mat[s])

        for s, array in X_neg_mat.items():
            X_neg.append(array)
            y_neg.append(y_neg_mat[s])
        return (X_pos, X_neg), (y_pos, y_neg)

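    @staticmethod
    def _sketch_pickle_layout(path):
        '''Illustrative sketch only: writes a toy `<task>.pos.pkl`-style file.
        The on-disk layout is assumed (from the loading code above) to be a pair
        of dicts keyed by sample id: sample id -> (seq_len x code_width) int
        array, and sample id -> label. This writer is a made-up helper, not
        part of the pipeline.'''
        X_mat = {0: np.array([[3, 17], [5, PADDING_ID]], dtype="int32")}
        y_mat = {0: 1}
        with open(path, 'wb') as f:
            pkl.dump((X_mat, y_mat), f)
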
    def get_fixed_timesteps(self, X_pos, X_neg):
        '''drop the earliest timesteps so that no sample exceeds the selected number'''
        # positives:
        for i in range(len(X_pos)):
            timesteps = X_pos[i].shape[0]
            if timesteps > self.timesteps:
                X_pos[i] = X_pos[i][timesteps-self.timesteps:, :]
        # negatives:
        for i in range(len(X_neg)):
            timesteps = X_neg[i].shape[0]
            if timesteps > self.timesteps:
                X_neg[i] = X_neg[i][timesteps-self.timesteps:, :]
        return (X_pos, X_neg)

    def get_fixed_codesize(self, X_pos, X_neg):
        '''drop trailing code columns (the -1 padding values) beyond code_size'''
        # positives:
        for i in range(len(X_pos)):
            code_size = X_pos[i].shape[1]
            if code_size > self.code_size:
                X_pos[i] = X_pos[i][:, :self.code_size]
        # negatives:
        for i in range(len(X_neg)):
            code_size = X_neg[i].shape[1]
            if code_size > self.code_size:
                X_neg[i] = X_neg[i][:, :self.code_size]
        return (X_pos, X_neg)

    def get_feed_records(self, X):
        '''generate EHRs as a 3-d tensor that can be fed to the networks'''
        n_samples = len(X)
        X_new = np.full([n_samples, self.timesteps, self.code_size], PADDING_ID, dtype="int32")
        for i in range(n_samples):
            timesteps = X[i].shape[0]
            # left-pad short histories; slice the last axis so samples narrower
            # than code_size keep PADDING_ID in the remaining columns
            X_new[i, self.timesteps-timesteps:, :X[i].shape[1]] = X[i]
        return X_new

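    def _sketch_feed_pipeline(self):
        '''Illustrative sketch only (made-up helper): one toy record flowing
        through the three steps above; assumes timesteps and code_size are set.'''
        X = [np.full((2, self.code_size), 7, dtype="int32")] # one sample, 2 visits
        X, _ = self.get_fixed_timesteps(X, []) # keep only the last `timesteps` visits
        X, _ = self.get_fixed_codesize(X, [])  # keep only the first `code_size` codes per visit
        X = self.get_feed_records(X)           # left-pad short histories with PADDING_ID
        assert X.shape == (1, self.timesteps, self.code_size)
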
    def get_data_prepared(self, data, label):
        '''fix timesteps and code size, pad, and merge positives with negatives'''
        X_pos, X_neg = data
        y_pos, y_neg = label

        X_pos, X_neg = self.get_fixed_timesteps(X_pos, X_neg)
        X_pos, X_neg = self.get_fixed_codesize(X_pos, X_neg)
        X_pos = self.get_feed_records(X_pos)
        X_neg = self.get_feed_records(X_neg)
        # concatenate pos and neg
        data, label = np.concatenate((X_pos, X_neg), axis=0), np.concatenate((y_pos, y_neg), axis=0)
        return data, label

    def load_data(self):
        '''load data vectors or matrices for samples with labels'''
        data_s, label_s = dict(), dict()
        data_t, label_t = dict(), dict()

        self.dim_input = [TIMESTEPS, N_WORDS]
        for i in range(len(self.source)):
            data_s[i], label_s[i] = self.load_data_matrix(self.source[i])

        for i in range(len(self.target)):
            data_t[i], label_t[i] = self.load_data_matrix(self.target[i])
        return data_s, data_t, label_s, label_t

    def generate_finetune_data(self, is_training=True, ifold=0):
        '''get fine-tuning samples and labels'''
        try:
            if is_training:
                sample = self.data_tt_tr[ifold]
                label = self.label_tt_tr[ifold]
            else:
                sample = self.data_tt_val[ifold]
                label = self.label_tt_val[ifold]
        except AttributeError:
            # get_cross_val must run before this method can serve folds
            raise RuntimeError("Error: split training and validation data first!")
        return sample, label

    def generate_meta_batches(self, is_training=True, ifold=0):
        '''get samples and the corresponding labels, organized as episodes for batching'''
        if is_training: # training
            prefix = "metatrain"
            data_s = self.data_s
            data_t = self.data_t
            label_s = self.label_s
            label_t = self.label_t
            self.n_total_batches = FLAGS.n_total_batches
        else: # test & eval, i.e., the true target task is used here
            try:
                prefix = "metaval" + str(ifold)
                data_s = self.data_s
                label_s = self.label_s
                data_t = self.data_tt_val[ifold]
                label_t = self.label_tt_val[ifold]
                self.n_total_batches = int(len(label_t)/self.n_samples_per_task)
            except AttributeError:
                raise RuntimeError("Error: split training and validation data first!")
        # check whether the meta batch file was already dumped
        batch_file = self.intmd_path + "meta.batch." + prefix + ".pkl"
        if os.path.isfile(batch_file):
            print('meta batch file exists')
            with open(batch_file, 'rb') as f:
                sample, label = pkl.load(f)
        else:
            # generate episodes
            sample, label = [], []
            s_dict, t_dict = dict(), dict()
            for i in range(len(self.source)):
                s_dict[i] = range(len(self.label_s[i]))
            for i in range(len(self.target)):
                t_dict[i] = range(len(self.label_t[i]))
            batch_count = 0
            for _ in tqdm.tqdm(range(self.n_total_batches), 'generating meta batches'): # progress bar
                # e.g., sample 16 patients from each selected task,
                # so len(spl) == len(lbl) == n_tasks * n_samples_per_task
                spl, lbl = [], [] # samples and labels in one episode
                for i in range(len(self.source)): # fetch from source tasks in order
                    ### do not keep the pos/neg ratio
                    s_idx = random.sample(s_dict[i], self.n_samples_per_task)
                    spl.extend(data_s[i][s_idx])
                    lbl.extend(label_s[i][s_idx])
                ### do not keep the pos/neg ratio
                if is_training:
                    t_idx = random.sample(t_dict[0], self.n_samples_per_task)
                    spl.extend(data_t[0][t_idx])
                    lbl.extend(label_t[0][t_idx])
                else:
                    spl.extend(data_t[batch_count*self.n_samples_per_task:(batch_count+1)*self.n_samples_per_task])
                    lbl.extend(label_t[batch_count*self.n_samples_per_task:(batch_count+1)*self.n_samples_per_task])
                batch_count += 1
                # add the meta batch
                sample.append(spl)
                label.append(lbl)
            print("batch counts:", batch_count) # batch_count only exists on this branch

        sample = np.array(sample, dtype="float32")
        label = np.array(label, dtype="float32")
        return sample, label

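    @staticmethod
    def _sketch_meta_batch_shapes():
        '''Illustrative sketch only (made-up example values): the arrays
        returned above are shaped [n_total_batches, n_tasks * n_samples_per_task,
        ...], with samples carrying two extra axes (timesteps, code_size).'''
        n_total_batches, n_tasks, n_samples_per_task = 2, 4, 16
        timesteps, code_size = 21, 30
        sample = np.zeros((n_total_batches, n_tasks * n_samples_per_task,
                           timesteps, code_size), dtype="float32")
        label = np.zeros((n_total_batches, n_tasks * n_samples_per_task), dtype="float32")
        assert sample.shape[:2] == label.shape
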
    def generate_meta_idx_batches(self, is_training=True, ifold=0):
        '''get per-task sample indices, organized as episodes for batching'''
        if is_training: # training
            prefix = "metatrain"
            data_s = self.data_s
            data_t = self.data_t
            label_s = self.label_s
            label_t = self.label_t
            self.n_total_batches = FLAGS.n_total_batches
        else: # test & eval, i.e., the true target task is used here
            try:
                prefix = "metaval" + str(ifold)
                data_s = self.data_s
                label_s = self.label_s
                data_t = self.data_tt_val[ifold]
                label_t = self.label_tt_val[ifold]
                self.n_total_batches = int(len(label_t)/self.n_samples_per_task)
                print("validation data/label shapes:", data_t.shape, label_t.shape)
            except AttributeError:
                raise RuntimeError("Error: split training and validation data first!")

        # generate episodes of indices; the data/label variables above are kept
        # for symmetry with generate_meta_batches, only the index ranges below are used
        episode = []
        s_dict, t_dict = dict(), dict()
        for i in range(len(self.source)):
            s_dict[i] = range(len(self.label_s[i]))
        for i in range(len(self.target)):
            t_dict[i] = range(len(self.label_t[i]))
        batch_count = 0
        for _ in tqdm.tqdm(range(self.n_total_batches), 'generating meta batches'): # progress bar
            # e.g., sample 16 patient indices from each selected task,
            # so len(idx) == n_tasks * n_samples_per_task
            idx = [] # indices in one episode
            for i in range(len(self.source)): # fetch from source tasks in order
                ### do not keep the pos/neg ratio
                s_idx = random.sample(s_dict[i], self.n_samples_per_task)
                idx.extend(s_idx)
            ### do not keep the pos/neg ratio
            if is_training:
                t_idx = random.sample(t_dict[0], self.n_samples_per_task)
                idx.extend(t_idx)
            else:
                t_idx = range(batch_count*self.n_samples_per_task, (batch_count+1)*self.n_samples_per_task)
                idx.extend(t_idx)
            batch_count += 1
            # add the meta batch
            episode.append(idx)

        print("batch counts:", batch_count)
        return episode
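
# Hedged usage sketch (illustrative only; this helper is made up): episodes from
# generate_meta_idx_batches are flat lists laid out as one block of
# n_samples_per_task indices per source task, followed by one target block;
# each block indexes that task's own data array.
def _sketch_split_episode(idx, n_sources, n_samples_per_task):
    '''Recover the per-task index blocks from one flat episode.'''
    n_blocks = n_sources + 1 # source tasks first, target task last
    return [idx[i * n_samples_per_task:(i + 1) * n_samples_per_task]
            for i in range(n_blocks)]

# e.g. with 3 sources and 2 samples per task:
# _sketch_split_episode(list(range(8)), 3, 2) -> [[0, 1], [2, 3], [4, 5], [6, 7]]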