Diff of /test/diseasedb/train.py [000000] .. [c3444c]

--- a
+++ b/test/diseasedb/train.py
@@ -0,0 +1,386 @@
+import sys
+sys.path.append("../../pretrain/")
+from linear_model import LinearModel
+from load_umls import UMLS
+import numpy as np
+import os
+import shutil
+import torch
+from torch.utils.data import DataLoader, TensorDataset, Dataset
+from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, AutoConfig, AutoModel
+from time import time
+from tqdm import tqdm
+import ipdb
+
+
+# parameters
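+# usage: python train.py <embedding_path> <embedding_type: cui|word|bert> <freeze_embedding: t|true|f> <device: 0|1 or a torch device string>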
+embedding = sys.argv[1]
+embedding_type = sys.argv[2]
+freeze_embedding = sys.argv[3]
+device = sys.argv[4]
+
+if freeze_embedding.lower() in ['t', 'true']:
+    freeze_embedding = True
+else:
+    freeze_embedding = False
+
+if device == "0":
+    device = "cuda:0"
+if device == "1":
+    device = "cuda:1"
+
+if embedding_type == 'bert':
+    epoch_num = 50
+    if freeze_embedding:
+        batch_size = 512
+        learning_rate = 1e-3
+    else:
+        batch_size = 96
+        learning_rate = 2e-5
+    max_seq_length = 32
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(embedding)
+    except BaseException:
+        tokenizer = AutoTokenizer.from_pretrained(
+            os.path.join(embedding, "../"))
+else:
+    epoch_num = 50
+    batch_size = 512
+    learning_rate = 1e-3
+    max_seq_length = 16
+
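+# truncate or pad a word-id list to max_seq_length; the padding id is the last
+# embedding row (use_embedding_count - 1), the zero vector in the word case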
+def pad(l):
+    if len(l) > max_seq_length:
+        return l[0:max_seq_length]
+    return l + [use_embedding_count - 1] * (max_seq_length - len(l))
+
+# load train and test
+cui_train_0 = []
+cui_train_1 = []
+rel_train = []
+with open("./data/x_train.txt") as f:
+    lines = f.readlines()
+    for line in lines:
+        line = line.strip().split("\t")
+        cui_train_0.append(line[0])
+        cui_train_1.append(line[1])
+with open("./data/y_train.txt") as f:
+    lines = f.readlines()
+    for line in lines:
+        rel_train.append(line.strip())
+
+cui_test_0 = []
+cui_test_1 = []
+rel_test = []
+with open("./data/x_test.txt") as f:
+    lines = f.readlines()
+    for line in lines:
+        line = line.strip().split("\t")
+        cui_test_0.append(line[0])
+        cui_test_1.append(line[1])
+with open("./data/y_test.txt") as f:
+    lines = f.readlines()
+    for line in lines:
+        rel_test.append(line.strip())
+
+# build rel2id / id2rel label maps and collect the set of all CUIs
+rel_set = set(rel_train + rel_test)
+rel2id = {rel: index for index, rel in enumerate(list(rel_set))}
+id2rel = {index: rel for rel, index in rel2id.items()}
+cui_set = set(cui_train_0 + cui_train_1 + cui_test_0 + cui_test_1)
+print('Count of different CUIs:', len(cui_set))
+
+# For txt-format (non-bert) embeddings, read the header line (vocab count, dim) and keep the vector lines
+if embedding_type != 'bert':
+    if embedding.find('txt') >= 0:
+        with open(embedding, "r", encoding="utf-8") as f:
+            line = f.readline()
+            count, dim = map(int, line.strip().split())
+            lines = f.readlines()
+
+# Deal with cui-type embeddings
+if embedding_type == 'cui':
+    # build cui2id and use_embedding
+    if embedding.find('txt') >= 0:
+        cui2id = {}
+        use_embedding_count = 0
+        emb_sum = np.zeros(shape=(dim), dtype=float)
+        for line in lines:
+            l = line.strip().split()
+            cui = l[0]
+            if embedding.find('stanford') >= 0:
+                cui = 'C' + cui
+            emb = np.array(list(map(float, l[1:])))
+            if cui in cui_set:
+                emb_sum += emb
+                cui2id[cui] = use_embedding_count
+                if use_embedding_count == 0:
+                    use_embedding = emb
+                else:
+                    use_embedding = np.concatenate((use_embedding, emb), axis=0)
+                use_embedding_count += 1
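+        # the averaged vector below is appended as the fallback row for CUIs missing
+        # from the embedding file (looked up as use_embedding_count - 1)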
+        emb_avg = emb_sum / use_embedding_count
+        use_embedding = np.concatenate((use_embedding, emb_avg), axis=0)
+        use_embedding_count += 1
+        use_embedding = use_embedding.reshape((-1, dim))
+        print('Embedding shape:', use_embedding.shape)
+    if embedding.find('pkl') >= 0:
+        import pickle
+        with open(embedding, 'rb') as f:
+            W = pickle.load(f)
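+        # W maps CUI -> vector serialized as a delimited, comma-separated string;
+        # the parsing below strips the first/last character before splitting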
+        cui2id = {}
+        use_embedding_count = 0
+        dim = len(list(W.values())[0][1:-1].split(','))
+        emb_sum = np.zeros(shape=(dim), dtype=float)
+        for cui in cui_set:
+            if cui in W and cui not in cui2id:
+                emb = np.array([float(num) for num in W[cui][1:-1].split(',')])
+                #ipdb.set_trace()
+                emb_sum += emb
+                cui2id[cui] = use_embedding_count
+                if use_embedding_count == 0:
+                    use_embedding = emb
+                else:
+                    use_embedding = np.concatenate((use_embedding, emb), axis=0)
+                use_embedding_count += 1
+        emb_avg = emb_sum / use_embedding_count
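+        # prefer an explicit 'empty' vector from the pickle, if present, as the fallback row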
+        if 'empty' in W:
+            emb_avg = np.array([float(num) for num in W['empty'][1:-1].split(',')])
+        use_embedding = np.concatenate((use_embedding, emb_avg), axis=0)
+        use_embedding_count += 1
+        use_embedding = use_embedding.reshape((-1, dim))
+        print('Embedding shape:', use_embedding.shape)
+
+    # apply cui2id and rel2id
+    train_input_0 = [cui2id.get(cui, use_embedding_count - 1)
+                     for cui in cui_train_0]
+    train_input_1 = [cui2id.get(cui, use_embedding_count - 1)
+                     for cui in cui_train_1]
+    train_y = [rel2id[rel] for rel in rel_train]
+    test_input_0 = [cui2id.get(cui, use_embedding_count - 1)
+                    for cui in cui_test_0]
+    test_input_1 = [cui2id.get(cui, use_embedding_count - 1)
+                    for cui in cui_test_1]
+    test_y = [rel2id[rel] for rel in rel_test]
+
+# Map each CUI to its standard term name via UMLS (fall back to the raw CUI string)
+if embedding_type != 'cui':
+    umls = UMLS("../../umls", only_load_dict=True)
+    cui2str = {}
+    #ipdb.set_trace()
+    for cui in cui_set:
+        standard_term = umls.search(code=cui, max_number=1)
+        if standard_term is not None:
+            cui2str[cui] = standard_term[0]
+        else:
+            cui2str[cui] = cui
+
+# Deal with word-type embeddings
+if embedding_type == 'word':
+
+    # tokenize
+    from nltk.tokenize import word_tokenize
+    cui2tokenize = {}
+    for cui in cui2str:
+        cui2tokenize[cui] = word_tokenize(cui2str[cui])
+
+    # build word2id and use_embedding
+    word2id = {}
+    use_embedding_count = 0
+
+    if embedding.find('txt') >= 0:
+        emb_sum = np.zeros(shape=(dim), dtype=float)
+        for line in lines:
+            l = line.strip().split()
+            word = l[0]
+            emb = np.array(list(map(float, l[1:])))
+            emb_sum += emb
+            word2id[word] = use_embedding_count
+            if use_embedding_count == 0:
+                use_embedding = emb
+            else:
+                use_embedding = np.concatenate((use_embedding, emb), axis=0)
+            use_embedding_count += 1
+        emb_avg = emb_sum / use_embedding_count
+        use_embedding = np.concatenate((use_embedding, emb_avg), axis=0)
+        use_embedding_count += 1
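+        # a zero vector is appended last and serves as the padding id (use_embedding_count - 1);
+        # the average vector just above (use_embedding_count - 2) stands in for empty token lists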
+        emb_zero = np.zeros_like(emb_avg)
+        use_embedding = np.concatenate((use_embedding, emb_zero), axis=0)
+        use_embedding_count += 1
+        use_embedding = use_embedding.reshape((-1, dim))
+        print('Embedding shape:', use_embedding.shape)
+    if embedding.find('bin') >= 0:
+        import gensim
+        model = gensim.models.KeyedVectors.load_word2vec_format(embedding, binary=True)
+        emb_sum = np.zeros(shape=(model.vector_size), dtype=float)
+        for cui in cui2tokenize:
+            for w in cui2tokenize[cui]:
+                if w in model and w not in word2id:
+                    emb = model[w]
+                    emb_sum += emb
+                    word2id[w] = use_embedding_count
+                    if use_embedding_count == 0:
+                        use_embedding = emb
+                    else:
+                        use_embedding = np.concatenate((use_embedding, emb), axis=0)
+                    use_embedding_count += 1  
+        emb_avg = emb_sum / use_embedding_count
+        use_embedding = np.concatenate((use_embedding, emb_avg), axis=0)
+        use_embedding_count += 1
+        emb_zero = np.zeros_like(emb_avg)
+        use_embedding = np.concatenate((use_embedding, emb_zero), axis=0)
+        use_embedding_count += 1
+        use_embedding = use_embedding.reshape((-1, model.vector_size))
+        print('Original embedding count:', len(model.vocab))
+        print('Embedding shape:', use_embedding.shape)
+
+    # apply word2id and rel2id
+    train_input_0 = [[word2id[w] for w in cui2tokenize[cui] if w in word2id] for cui in cui_train_0]
+    train_input_1 = [[word2id[w] for w in cui2tokenize[cui] if w in word2id] for cui in cui_train_1]
+    train_y = [rel2id[rel] for rel in rel_train]
+    test_input_0 = [[word2id[w] for w in cui2tokenize[cui] if w in word2id] for cui in cui_test_0]
+    test_input_1 = [[word2id[w] for w in cui2tokenize[cui] if w in word2id] for cui in cui_test_1]
+    test_y = [rel2id[rel] for rel in rel_test]
+
+    # average and padding
+    # deal with input length = 0, use average
+    train_input_0 = [cui if cui else [use_embedding_count - 2] for cui in train_input_0] 
+    train_input_1 = [cui if cui else [use_embedding_count - 2] for cui in train_input_1]
+    test_input_0 = [cui if cui else [use_embedding_count - 2] for cui in test_input_0]
+    test_input_1 = [cui if cui else [use_embedding_count - 2] for cui in test_input_1]
+    # calculate length
+    train_length_0 = [len(cui) for cui in train_input_0]
+    train_length_1 = [len(cui) for cui in train_input_1]
+    test_length_0 = [len(cui) for cui in test_input_0]
+    test_length_1 = [len(cui) for cui in test_input_1]
+    # padding
+    train_input_0 = list(map(pad, train_input_0))
+    train_input_1 = list(map(pad, train_input_1))
+    test_input_0 = list(map(pad, test_input_0))
+    test_input_1 = list(map(pad, test_input_1))
+
+# Deal with bert-type embeddings
+if embedding_type == 'bert':
+    train_input_0 = []
+    train_input_1 = []
+    test_input_0 = []
+    test_input_1 = []
+
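+    # encode each CUI's standard term once with the BERT tokenizer and reuse the ids for train and test pairs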
+    cui2tokenize = {}
+    for cui in cui2str:
+        cui2tokenize[cui] = tokenizer.encode_plus(
+            cui2str[cui], max_length=max_seq_length, add_special_tokens=True,
+            truncation=True, pad_to_max_length=True)['input_ids']
+    
+    train_input_0 = [cui2tokenize[cui] for cui in cui_train_0]
+    train_input_1 = [cui2tokenize[cui] for cui in cui_train_1]
+    test_input_0 = [cui2tokenize[cui] for cui in cui_test_0]
+    test_input_1 = [cui2tokenize[cui] for cui in cui_test_1]
+    train_y = [rel2id[rel] for rel in rel_train]
+    test_y = [rel2id[rel] for rel in rel_test]
+
+# Dataset and Dataloader
+train_input_0 = torch.LongTensor(train_input_0)
+train_input_1 = torch.LongTensor(train_input_1)
+test_input_0 = torch.LongTensor(test_input_0)
+test_input_1 = torch.LongTensor(test_input_1)
+train_y = torch.LongTensor(train_y)
+test_y = torch.LongTensor(test_y)
+if embedding_type != 'word':
+    train_dataset = TensorDataset(train_input_0, train_input_1, train_y)
+    test_dataset = TensorDataset(test_input_0, test_input_1, test_y)
+else:
+    train_length_0 = torch.LongTensor(train_length_0)
+    train_length_1 = torch.LongTensor(train_length_1)
+    test_length_0 = torch.LongTensor(test_length_0)
+    test_length_1 = torch.LongTensor(test_length_1)
+    train_dataset = TensorDataset(train_input_0, train_input_1, train_length_0, train_length_1, train_y)
+    test_dataset = TensorDataset(test_input_0, test_input_1, test_length_0, test_length_1, test_y)
+train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
+test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1)
+
+# Prepare model and optimizer
+# model
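+# LinearModel (imported from ../../pretrain/linear_model.py) is constructed with the label count,
+# the embedding type, the embedding matrix or BERT encoder, and the freeze flag; as used in
+# eval/train below, its forward returns (predictions, loss)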
+if embedding_type != 'bert':
+    use_embedding = torch.FloatTensor(use_embedding)
+    model = LinearModel(len(rel2id), embedding_type, use_embedding, freeze_embedding).to(device)
+else:
+    try:
+        config = AutoConfig.from_pretrained(embedding)
+        bert_model = AutoModel.from_pretrained(embedding, config=config).to(device)
+    except BaseException:
+        bert_model = torch.load(os.path.join(embedding, 'pytorch_model.bin')).to(device)
+    model = LinearModel(len(rel2id), embedding_type, bert_model, freeze_embedding).to(device)
+
+# optimizer
+if embedding_type != 'bert':
+    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+if embedding_type == "bert":
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
+    ]
+    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=1e-8)
+
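+    # warm up linearly over the first 10% of total training steps, then decay linearly to zero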
+    scheduler = get_linear_schedule_with_warmup(optimizer,
+                                                num_warmup_steps=int(epoch_num * len(train_dataloader) * 0.1),
+                                                num_training_steps=epoch_num * len(train_dataloader))
+
+# Prepare eval function
+from sklearn.metrics import accuracy_score, classification_report, f1_score
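+# evaluate a model on a dataloader: return accuracy (%) and a per-class classification report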
+def eval(m, dataloader):
+    y_pred = []
+    y_true = []
+    m.eval()
+    with torch.no_grad():
+        for batch in dataloader:
+            x0 = batch[0].to(device)
+            x1 = batch[1].to(device)
+            if m.embedding_type == "word":
+                l0 = batch[2].to(device)
+                l1 = batch[3].to(device)
+                r = batch[4]
+            else:
+                l0 = l1 = None
+                r = batch[2]
+            pred, loss = m(x0, x1, l0, l1)
+            y_pred += torch.max(pred, dim=1)[1].detach().cpu().numpy().tolist()
+            y_true += r.detach().cpu().numpy().tolist()
+    acc = accuracy_score(y_true, y_pred) * 100
+    #f1 = f1_score(y_true, y_pred) * 100
+    report = classification_report(y_true, y_pred)
+    return acc, report
+
+# Train and eval
+if not os.path.exists("./result/"):
+    os.mkdir("./result/")
+
+for epoch_index in range(epoch_num):
+    model.train()
+    epoch_loss = 0.
+    time_now = time()
+    for batch in tqdm(train_dataloader):
+        optimizer.zero_grad()
+        x0 = batch[0].to(device)
+        x1 = batch[1].to(device)
+        if model.embedding_type == "word":
+            l0 = batch[2].to(device)
+            l1 = batch[3].to(device)
+            r = batch[4].to(device)
+        else:
+            l0 = l1 = None
+            r = batch[2].to(device)
+        pred, loss = model(x0, x1, l0, l1, r)
+        loss.backward()
+        optimizer.step()
+        if model.embedding_type == "bert":
+            scheduler.step()
+        epoch_loss += loss.item()
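+    # per-epoch log: epoch number, wall-clock seconds, summed training loss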
+    print(epoch_index + 1, round(time() - time_now, 1), epoch_loss)
+
+    acc, report = eval(model, test_dataloader)
+    print("Accuracy:", acc)
+    #print(report)