Diff of /test/diseasedb/train.py [000000] .. [c3444c]

import sys
sys.path.append("../../pretrain/")
from linear_model import LinearModel
from load_umls import UMLS
import numpy as np
import os
import shutil
import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, AutoConfig, AutoModel
from time import time
from tqdm import tqdm
import ipdb


# parameters
embedding = sys.argv[1]
embedding_type = sys.argv[2]
freeze_embedding = sys.argv[3]
device = sys.argv[4]

if freeze_embedding.lower() in ['t', 'true']:
    freeze_embedding = True
else:
    freeze_embedding = False

if device == "0":
    device = "cuda:0"
if device == "1":
    device = "cuda:1"
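# Illustrative invocations (the file names and paths below are hypothetical;
# the embedding argument is whatever vector file or checkpoint you actually use):
#   python train.py /path/to/cui_vectors.txt cui true 0
#   python train.py /path/to/word_vectors.bin word true 1
#   python train.py /path/to/bert_checkpoint_dir bert false 0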
if embedding_type == 'bert':
    epoch_num = 50
    if freeze_embedding:
        batch_size = 512
        learning_rate = 1e-3
    else:
        batch_size = 96
        learning_rate = 2e-5
    max_seq_length = 32
    try:
        tokenizer = AutoTokenizer.from_pretrained(embedding)
    except BaseException:
        tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(embedding, "../"))
else:
    epoch_num = 50
    batch_size = 512
    learning_rate = 1e-3
    max_seq_length = 16


def pad(l):
    if len(l) > max_seq_length:
        return l[0:max_seq_length]
    return l + [use_embedding_count - 1] * (max_seq_length - len(l))
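# pad() above is only applied to the word-type inputs: it truncates a token-id
# list to max_seq_length and right-pads shorter lists with use_embedding_count - 1
# (the all-zero row appended to use_embedding in the word branch). For example,
# if max_seq_length were 4, pad([7, 2]) would return
# [7, 2, use_embedding_count - 1, use_embedding_count - 1].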
# load train and test
cui_train_0 = []
cui_train_1 = []
rel_train = []
with open("./data/x_train.txt") as f:
    lines = f.readlines()
    for line in lines:
        line = line.strip().split("\t")
        cui_train_0.append(line[0])
        cui_train_1.append(line[1])
with open("./data/y_train.txt") as f:
    lines = f.readlines()
    for line in lines:
        rel_train.append(line.strip())

cui_test_0 = []
cui_test_1 = []
rel_test = []
with open("./data/x_test.txt") as f:
    lines = f.readlines()
    for line in lines:
        line = line.strip().split("\t")
        cui_test_0.append(line[0])
        cui_test_1.append(line[1])
with open("./data/y_test.txt") as f:
    lines = f.readlines()
    for line in lines:
        rel_test.append(line.strip())
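# Expected on-disk format, as read above: each line of x_train.txt / x_test.txt
# holds two tab-separated CUIs, and the line at the same position in
# y_train.txt / y_test.txt holds the relation label for that pair, e.g.
#   x_train.txt:  C0011849\tC0027051
#   y_train.txt:  may_cause
# (the CUIs and label shown here are only illustrative).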
# build rel2id
rel_set = set(rel_train + rel_test)
rel2id = {rel: index for index, rel in enumerate(list(rel_set))}
id2rel = {index: rel for rel, index in rel2id.items()}
cui_set = set(cui_train_0 + cui_train_1 + cui_test_0 + cui_test_1)
print('Count of different CUIs:', len(cui_set))

# Read the header of a .txt embedding file (used by the CUI and word branches below)
if embedding_type != 'bert':
    if embedding.find('txt') >= 0:
        with open(embedding, "r", encoding="utf-8") as f:
            line = f.readline()
            count, dim = map(int, line.strip().split())
            lines = f.readlines()
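# The .txt embedding files are assumed to be in word2vec text format: a
# "count dim" header line followed by one "token v1 ... v_dim" row per line.
# "stanford" embedding files are assumed to store CUIs without the leading 'C',
# which is re-added in the CUI branch below.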
if embedding_type == 'cui':
    # build cui2id and use_embedding
    if embedding.find('txt') >= 0:
        cui2id = {}
        use_embedding_count = 0
        emb_sum = np.zeros(shape=(dim), dtype=float)
        for line in lines:
            l = line.strip().split()
            cui = l[0]
            if embedding.find('stanford') >= 0:
                cui = 'C' + cui
            emb = np.array(list(map(float, l[1:])))
            emb_sum += emb
            if cui in cui_set:
                cui2id[cui] = use_embedding_count
                if use_embedding_count == 0:
                    use_embedding = emb
                else:
                    use_embedding = np.concatenate((use_embedding, emb), axis=0)
                use_embedding_count += 1
        emb_avg = emb_sum / use_embedding_count
        use_embedding = np.concatenate((use_embedding, emb_avg), axis=0)
        use_embedding_count += 1
        use_embedding = use_embedding.reshape((-1, dim))
        print('Embedding shape:', use_embedding.shape)
    if embedding.find('pkl') >= 0:
        import pickle
        with open(embedding, 'rb') as f:
            W = pickle.load(f)
        cui2id = {}
        use_embedding_count = 0
        dim = len(list(W.values())[0][1:-1].split(','))
        emb_sum = np.zeros(shape=(dim), dtype=float)
        for cui in cui_set:
            if cui in W and cui not in cui2id:
                emb = np.array([float(num) for num in W[cui][1:-1].split(',')])
                #ipdb.set_trace()
                emb_sum += emb
                cui2id[cui] = use_embedding_count
                if use_embedding_count == 0:
                    use_embedding = emb
                else:
                    use_embedding = np.concatenate((use_embedding, emb), axis=0)
                use_embedding_count += 1
        emb_avg = emb_sum / use_embedding_count
        if 'empty' in W:
            emb_avg = np.array([float(num) for num in W['empty'][1:-1].split(',')])
        use_embedding = np.concatenate((use_embedding, emb_avg), axis=0)
        use_embedding_count += 1
        use_embedding = use_embedding.reshape((-1, dim))
        print('Embedding shape:', use_embedding.shape)

    # apply cui2id and rel2id; CUIs without a vector fall back to the last row (index use_embedding_count - 1)
    train_input_0 = [cui2id.get(cui, use_embedding_count - 1)
                     for cui in cui_train_0]
    train_input_1 = [cui2id.get(cui, use_embedding_count - 1)
                     for cui in cui_train_1]
    train_y = [rel2id[rel] for rel in rel_train]
    test_input_0 = [cui2id.get(cui, use_embedding_count - 1)
                    for cui in cui_test_0]
    test_input_1 = [cui2id.get(cui, use_embedding_count - 1)
                    for cui in cui_test_1]
    test_y = [rel2id[rel] for rel in rel_test]
# Find standard term name
if embedding_type != 'cui':
    umls = UMLS("../../umls", only_load_dict=True)
    cui2str = {}
    #ipdb.set_trace()
    for cui in cui_set:
        standard_term = umls.search(code=cui, max_number=1)
        if standard_term is not None:
            cui2str[cui] = standard_term[0]
        else:
            cui2str[cui] = cui
# Handle word-type embeddings
if embedding_type == 'word':

    # tokenize
    from nltk.tokenize import word_tokenize
    cui2tokenize = {}
    for cui in cui2str:
        cui2tokenize[cui] = word_tokenize(cui2str[cui])

    # build word2id and use_embedding
    word2id = {}
    use_embedding_count = 0

    if embedding.find('txt') >= 0:
        emb_sum = np.zeros(shape=(dim), dtype=float)
        for line in lines:
            l = line.strip().split()
            word = l[0]
            emb = np.array(list(map(float, l[1:])))
            emb_sum += emb
            word2id[word] = use_embedding_count
            if use_embedding_count == 0:
                use_embedding = emb
            else:
                use_embedding = np.concatenate((use_embedding, emb), axis=0)
            use_embedding_count += 1
        emb_avg = emb_sum / use_embedding_count
        use_embedding = np.concatenate((use_embedding, emb_avg), axis=0)
        use_embedding_count += 1
        emb_zero = np.zeros_like(emb_avg)
        use_embedding = np.concatenate((use_embedding, emb_zero), axis=0)
        use_embedding_count += 1
        use_embedding = use_embedding.reshape((-1, dim))
        print('Embedding shape:', use_embedding.shape)
    if embedding.find('bin') >= 0:
        import gensim
        model = gensim.models.KeyedVectors.load_word2vec_format(embedding, binary=True)
        emb_sum = np.zeros(shape=(model.vector_size), dtype=float)
        for cui in cui2tokenize:
            for w in cui2tokenize[cui]:
                if w in model and w not in word2id:
                    emb = model[w]
                    emb_sum += emb
                    word2id[w] = use_embedding_count
                    if use_embedding_count == 0:
                        use_embedding = emb
                    else:
                        use_embedding = np.concatenate((use_embedding, emb), axis=0)
                    use_embedding_count += 1
        emb_avg = emb_sum / use_embedding_count
        use_embedding = np.concatenate((use_embedding, emb_avg), axis=0)
        use_embedding_count += 1
        emb_zero = np.zeros_like(emb_avg)
        use_embedding = np.concatenate((use_embedding, emb_zero), axis=0)
        use_embedding_count += 1
        use_embedding = use_embedding.reshape((-1, model.vector_size))
        print('Original embedding count:', len(model.wv.vocab))
        print('Embedding shape:', use_embedding.shape)

    # apply word2id and rel2id
    train_input_0 = [[word2id[w] for w in cui2tokenize[cui] if w in word2id] for cui in cui_train_0]
    train_input_1 = [[word2id[w] for w in cui2tokenize[cui] if w in word2id] for cui in cui_train_1]
    train_y = [rel2id[rel] for rel in rel_train]
    test_input_0 = [[word2id[w] for w in cui2tokenize[cui] if w in word2id] for cui in cui_test_0]
    test_input_1 = [[word2id[w] for w in cui2tokenize[cui] if w in word2id] for cui in cui_test_1]
    test_y = [rel2id[rel] for rel in rel_test]

    # average and padding
    # inputs with no known word fall back to the average-embedding row (index use_embedding_count - 2)
    train_input_0 = [cui if cui else [use_embedding_count - 2] for cui in train_input_0]
    train_input_1 = [cui if cui else [use_embedding_count - 2] for cui in train_input_1]
    test_input_0 = [cui if cui else [use_embedding_count - 2] for cui in test_input_0]
    test_input_1 = [cui if cui else [use_embedding_count - 2] for cui in test_input_1]
    # calculate length
    train_length_0 = [len(cui) for cui in train_input_0]
    train_length_1 = [len(cui) for cui in train_input_1]
    test_length_0 = [len(cui) for cui in test_input_0]
    test_length_1 = [len(cui) for cui in test_input_1]
    # padding
    train_input_0 = list(map(pad, train_input_0))
    train_input_1 = list(map(pad, train_input_1))
    test_input_0 = list(map(pad, test_input_0))
    test_input_1 = list(map(pad, test_input_1))
# Handle BERT-type embeddings
if embedding_type == 'bert':
    train_input_0 = []
    train_input_1 = []
    test_input_0 = []
    test_input_1 = []

    cui2tokenize = {}
    for cui in cui2str:
        cui2tokenize[cui] = tokenizer.encode_plus(
            cui2str[cui], max_length=max_seq_length, add_special_tokens=True,
            truncation=True, pad_to_max_length=True)['input_ids']

    train_input_0 = [cui2tokenize[cui] for cui in cui_train_0]
    train_input_1 = [cui2tokenize[cui] for cui in cui_train_1]
    test_input_0 = [cui2tokenize[cui] for cui in cui_test_0]
    test_input_1 = [cui2tokenize[cui] for cui in cui_test_1]
    train_y = [rel2id[rel] for rel in rel_train]
    test_y = [rel2id[rel] for rel in rel_test]
# Dataset and DataLoader
train_input_0 = torch.LongTensor(train_input_0)
train_input_1 = torch.LongTensor(train_input_1)
test_input_0 = torch.LongTensor(test_input_0)
test_input_1 = torch.LongTensor(test_input_1)
train_y = torch.LongTensor(train_y)
test_y = torch.LongTensor(test_y)
if embedding_type != 'word':
    train_dataset = TensorDataset(train_input_0, train_input_1, train_y)
    test_dataset = TensorDataset(test_input_0, test_input_1, test_y)
else:
    train_length_0 = torch.LongTensor(train_length_0)
    train_length_1 = torch.LongTensor(train_length_1)
    test_length_0 = torch.LongTensor(test_length_0)
    test_length_1 = torch.LongTensor(test_length_1)
    train_dataset = TensorDataset(train_input_0, train_input_1, train_length_0, train_length_1, train_y)
    test_dataset = TensorDataset(test_input_0, test_input_1, test_length_0, test_length_1, test_y)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1)
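# Batch layout produced by the datasets above:
#   cui / bert: (input_0, input_1, label)
#   word:       (input_0, input_1, length_0, length_1, label)
# The training loop and eval() index batch elements accordingly.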
# Prepare model and optimizer
# model
if embedding_type != 'bert':
    use_embedding = torch.FloatTensor(use_embedding)
    model = LinearModel(len(rel2id), embedding_type, use_embedding, freeze_embedding).to(device)
else:
    try:
        config = AutoConfig.from_pretrained(embedding)
        bert_model = AutoModel.from_pretrained(embedding, config=config).to(device)
    except BaseException:
        bert_model = torch.load(os.path.join(embedding, 'pytorch_model.bin')).to(device)
    model = LinearModel(len(rel2id), embedding_type, bert_model, freeze_embedding).to(device)

# optimizer
if embedding_type != 'bert':
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
if embedding_type == "bert":
    no_decay = ["bias", "LayerNorm.weight"]
    # note: weight_decay is 0.0 for both parameter groups
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=1e-8)

    # linear warmup over the first 10% of training steps, then linear decay
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=int(epoch_num * len(train_dataloader) * 0.1),
                                                num_training_steps=epoch_num * len(train_dataloader))
# Prepare eval function
from sklearn.metrics import accuracy_score, classification_report, f1_score


def eval(m, dataloader):
    y_pred = []
    y_true = []
    m.eval()
    with torch.no_grad():
        for batch in dataloader:
            x0 = batch[0].to(device)
            x1 = batch[1].to(device)
            if m.embedding_type == "word":
                l0 = batch[2].to(device)
                l1 = batch[3].to(device)
                r = batch[4]
            else:
                l0 = l1 = None
                r = batch[2]
            pred, loss = m(x0, x1, l0, l1)
            y_pred += torch.max(pred, dim=1)[1].detach().cpu().numpy().tolist()
            y_true += r.detach().cpu().numpy().tolist()
    acc = accuracy_score(y_true, y_pred) * 100
    #f1 = f1_score(y_true, y_pred) * 100
    report = classification_report(y_true, y_pred)
    return acc, report
# Train and eval
if not os.path.exists("./result/"):
    os.mkdir("./result/")

for epoch_index in range(epoch_num):
    model.train()
    epoch_loss = 0.
    time_now = time()
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        x0 = batch[0].to(device)
        x1 = batch[1].to(device)
        if model.embedding_type == "word":
            l0 = batch[2].to(device)
            l1 = batch[3].to(device)
            r = batch[4].to(device)
        else:
            l0 = l1 = None
            r = batch[2].to(device)
        pred, loss = model(x0, x1, l0, l1, r)
        loss.backward()
        optimizer.step()
        if model.embedding_type == "bert":
            scheduler.step()
        epoch_loss += loss.item()
    print(epoch_index + 1, round(time() - time_now, 1), epoch_loss)

    acc, report = eval(model, test_dataloader)
    print("Accuracy:", acc)
    #print(report)