--- a
+++ b/test/embeddings_reimplement/ndfrt_analysis.py
@@ -0,0 +1,443 @@
+from transformers import AutoConfig, AutoModel, AutoTokenizer
+import torch
+from tqdm import tqdm
+import numpy as np
+import sys
+sys.path.append("../../pretrain")
+from load_umls import UMLS
+from nltk.tokenize import word_tokenize
+import os
+import ipdb
+
+
+device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
+
+batch_size = 512
+max_seq_length = 32
+
+def get_drug_diseases_to_check(concept_filename):
+    query_to_targets = {}
+    with open(concept_filename, 'r') as infile:
+        data = infile.readlines()
+        for row in data:
+            drug, diseases = row.strip().split(':')
+            diseases = diseases.split(',')[:-1]
+            disease_cui_set = set([])
+            for disease in diseases:
+                disease_cui_set.add(disease)
+            if len(disease_cui_set) > 0:
+                query_to_targets[drug] = disease_cui_set
+
+    cui_list = set()
+    for query, targets in query_to_targets.items():
+        cui_list.update([query])
+        cui_list.update(targets)
+    return query_to_targets, list(cui_list)
+
+def normalize(tensor):
+    norm = torch.norm(tensor, p=2, dim=1, keepdim=True).clamp(min=1e-12)
+    return torch.div(tensor, norm)
+
+def calculate_mrm_ndfrt_origin(term_embedding, cui_list, query_to_targets, k):
+    return calculate_mrm_ndfrt_delta(term_embedding, cui_list, query_to_targets, None, k)
+
+
+def calculate_mrm_ndfrt_q2t(term_embedding, cui_list, query_to_targets, k):
+    delta_list = []
+
+    term_embedding = torch.FloatTensor(term_embedding).to(device)
+    norm_embedding = normalize(term_embedding)
+
+    id2cui = {i:cui_list[i] for i in range(len(cui_list))}
+    cui2id = {cui:index for index, cui in id2cui.items()}
+
+    for query, targets in query_to_targets.items():
+        if query in cui2id:
+            for target in targets:
+                if target in cui2id:
+                    delta = term_embedding[cui2id[query]] - term_embedding[cui2id[target]]
+                    delta_list.append(delta)
+
+    overall_output = []
+    for _, delta in tqdm(enumerate(delta_list)):
+        output = []
+        for query, targets in query_to_targets.items():
+            if query in cui2id:
+                find_embedding = term_embedding[cui2id[query]] - delta
+                similarity = torch.matmul(norm_embedding, find_embedding)
+                _, indices = torch.topk(similarity, k=k + 1)
+                find_cui = [cui_list[index] for index in indices[1:]]
+                score = 0.
+                for cui in find_cui:
+                    if cui in targets:
+                        score = 1.
+                        break
+                output.append(score)
+        if len(output) > 0:
+            score = sum(output) / len(output)
+        else:
+            score = 0.  
+        overall_output.append(score)
+
+    if len(overall_output) > 0:
+        overall_score = sum(overall_output) / len(overall_output)
+        overall_max = max(overall_output)
+    else:
+        overall_score = 0
+        overall_max = 0
+    return overall_score, overall_max
+
+
+def calculate_mrm_ndfrt_delta(term_embedding, cui_list, query_to_targets, delta=None, k=40):
+    term_embedding = torch.FloatTensor(term_embedding).to(device)
+    norm_embedding = normalize(term_embedding)
+
+    id2cui = {i:cui_list[i] for i in range(len(cui_list))}
+    cui2id = {cui:index for index, cui in id2cui.items()}
+
+    output = []
+    check_count = 0
+    for query, targets in query_to_targets.items():
+        if query in cui2id:
+            query_embedding = term_embedding[cui2id[query]]
+            if delta is None:
+                find_embedding = query_embedding
+            else:
+                find_embedding = query_embedding - torch.FloatTensor(delta).to(device)
+            similarity = torch.matmul(norm_embedding, find_embedding)
+            _, indices = torch.topk(similarity, k=k + 1)
+            find_cui = [cui_list[index] for index in indices[1:]]
+            score = 0.
+            for cui in find_cui:
+                if cui in targets:
+                    score = 1.
+                    break
+            output.append(score)
+            check_count += 1
+    del term_embedding
+
+    if len(output) > 0:
+        score = sum(output) / len(output)
+    else:
+        score = 0.
+
+    """
+    print(f"Check count: {check_count}")
+    print(score)
+    """
+
+    return score
+
+
+def mrm_ndfrt_cui(cui_embedding, umls, cui_list, query_to_targets, k, method):
+    w, _ = load_embedding(cui_embedding)
+
+    new_cui_list = [cui for cui in cui_list if cui in w]
+    term_embedding = np.array([w[cui] for cui in new_cui_list])
+
+    print(f"Cui count:{len(new_cui_list)}")
+
+    if method == "origin":
+        score = calculate_mrm_ndfrt_origin(term_embedding, new_cui_list, query_to_targets, k)
+        print(f"Origin: {score}")
+    if method == "all":
+        score = calculate_mrm_ndfrt_q2t(term_embedding, new_cui_list, query_to_targets, k)
+        average_score, max_score = score
+        print(f"Average: {average_score}")
+        print(f"Max: {max_score}")
+    return score
+
+
+def mrm_ndfrt_word(word_embedding, umls, cui_list, query_to_targets, k, method):
+    w, dim = load_embedding(word_embedding)
+
+    print("Tokenize and calculate avg embedding.")
+    cui_str = [[word for word in word_tokenize(
+        list(umls.cui2str[cui])[0]) if word in w] for cui in cui_list if cui in umls.cui2str]
+
+    new_cui_list = []
+    check_count = 0
+    for index, des in enumerate(cui_str):
+        if len(des) > 0:
+            tmp_emb = np.zeros((dim))
+            for word in des:
+                tmp_emb += w[word]
+
+            if check_count == 0:
+                term_embedding = tmp_emb
+            else:
+                term_embedding = np.concatenate(
+                    (term_embedding, tmp_emb), axis=0)
+            check_count += 1
+            new_cui_list.append(cui_list[index])
+    term_embedding = term_embedding.reshape((-1, dim))
+
+    print(f"Cui count:{len(new_cui_list)}")
+
+    if method == "origin":
+        score = calculate_mrm_ndfrt_origin(term_embedding, new_cui_list, query_to_targets, k)
+        print(f"Origin: {score}")
+    if method == "all":
+        score = calculate_mrm_ndfrt_q2t(term_embedding, new_cui_list, query_to_targets, k)
+        average_score, max_score = score
+        print(f"Average: {average_score}")
+        print(f"Max: {max_score}")
+    return score
+
+
+def mrm_ndfrt_bert(bert_embedding, umls, cui_list, query_to_targets, k, method, summary_method):
+    print(summary_method)
+    model, tokenizer = load_bert(bert_embedding)
+    model.eval()
+
+    input_ids = []
+    new_cui_list = []
+    for cui in cui_list:
+        if cui in umls.cui2str:
+            input_ids.append(tokenizer.encode_plus(
+                list(umls.cui2str[cui])[
+                    0], max_length=max_seq_length, add_special_tokens=True,
+                truncation=True, pad_to_max_length=True)['input_ids'])
+            new_cui_list.append(cui)
+
+    count = len(input_ids)
+    now_count = 0
+    # with tqdm(total=count) as pbar:
+    with torch.no_grad():
+        while now_count < count:
+            input_gpu_0 = torch.LongTensor(input_ids[now_count:min(
+                now_count + batch_size, count)]).to(device)
+            if summary_method == "CLS":
+                embed = model(input_gpu_0)[1]
+            if summary_method == "MEAN":
+                embed = torch.mean(model(input_gpu_0)[0], dim=1)
+            embed_np = embed.cpu().detach().numpy()
+            if now_count == 0:
+                term_embedding = embed_np
+            else:
+                term_embedding = np.concatenate((term_embedding, embed_np), axis=0)
+            update = min(now_count + batch_size, count) - now_count
+            now_count = now_count + update
+            # pbar.update(update)
+
+    print(f"Cui count:{len(new_cui_list)}")
+
+    if method == "origin":
+        score = calculate_mrm_ndfrt_origin(term_embedding, new_cui_list, query_to_targets, k)
+        print(f"Origin: {score}")
+    if method == "all":
+        score = calculate_mrm_ndfrt_q2t(term_embedding, new_cui_list, query_to_targets, k)
+        average_score, max_score = score
+        print(f"Average: {average_score}")
+        print(f"Max: {max_score}")
+    if method in ["may_treat", "may_prevent"]:
+        beta_path = os.path.join(bert_embedding, "run", "1000000", "rel embedding")
+        with open(os.path.join(beta_path, "metadata.tsv"), "r", encoding="utf-8") as f:
+            metadata = f.readlines()
+        metadata = [line.strip() for line in metadata]
+        with open(os.path.join(beta_path, "tensors.tsv"), "r", encoding="utf-8") as f:
+            tensor = f.readlines()
+
+        tensor = [[float(num) for num in line.split("\t")] for line in tensor]
+        for index, title in enumerate(metadata):
+            if title == method:
+                delta = tensor[index]
+        
+        score = calculate_mrm_ndfrt_delta(term_embedding, new_cui_list, query_to_targets, delta, k)
+        print(f"{method}: {score}")
+    return score
+
+
+def mrm_ndfrt(embedding_list, embedding_type_list, concept_filename, k=40, check_intersection=True):
+    if check_intersection:
+        if not os.path.exists("intersection.txt"):
+            intersection_cui = get_intersection(
+                embedding_list, embedding_type_list)
+            with open("intersection.txt", "w", encoding="utf-8") as f:
+                for cui in intersection_cui:
+                    f.write(cui.strip() + "\n")
+        else:
+            with open("intersection.txt", "r", encoding="utf-8") as f:
+                lines = f.readlines()
+            intersection_cui = [line.strip() for line in lines]
+
+    query_to_targets, cui_list = get_drug_diseases_to_check(concept_filename)
+    umls = UMLS("../../umls", only_load_dict=True) # source_range='SNOMEDCT_US')#, only_load_dict=True)
+
+    if check_intersection:
+        cui_list = [cui for cui in cui_list if cui in intersection_cui]
+
+    #cui_list = [cui for cui in umls.cui2str if umls.cui2sty[cui] in sty_list]
+    #cui_list = [cui for cui in cui_list if cui in umls.sty_list]
+
+    """
+    for cui in cui_list:
+        if not cui in umls.cui2str:
+            print(cui)
+
+    ipdb.set_trace()
+    """
+
+    opt = []
+    """
+    # Origin
+    print("ORIGIN")
+    for index, embedding in enumerate(embedding_list):
+        if embedding_type_list[index].lower() == "cui":
+            opt.append(mrm_ndfrt_cui(embedding, umls, cui_list, query_to_targets, k, "origin"))
+        if embedding_type_list[index].lower() == "word":
+            opt.append(mrm_ndfrt_word(embedding, umls, cui_list, query_to_targets, k, "origin"))
+        if embedding_type_list[index].lower() == "bert":
+            #opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
+            #                     query_to_targets, k, "origin", summary_method="MEAN"))
+            opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
+                                 query_to_targets, k, "origin", summary_method="CLS"))
+
+    # For UMLSBert
+    for index, embedding in enumerate(embedding_list):
+        if embedding_type_list[index].lower() == "bert":
+            print("BETA")
+            beta_path = os.path.join(embedding, "run", "1000000", "rel embedding")
+            if os.path.exists(beta_path):
+                if concept_filename.find('treat') >= 0:
+                    method = "may_treat"
+                else:
+                    method = "may_prevent"
+                #opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
+                #                 query_to_targets, k, method, summary_method="MEAN"))
+                opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
+                                 query_to_targets, k, method, summary_method="CLS"))                
+
+    # For average and max
+
+    print("ALL")
+    for index, embedding in enumerate(embedding_list):
+        if embedding_type_list[index].lower() == "cui":
+            opt.append(mrm_ndfrt_cui(embedding, umls, cui_list, query_to_targets, k, "all"))
+        if embedding_type_list[index].lower() == "word":
+            opt.append(mrm_ndfrt_word(embedding, umls, cui_list, query_to_targets, k, "all"))
+        if embedding_type_list[index].lower() == "bert":
+            #opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
+            #                     query_to_targets, k, "all", summary_method="MEAN"))
+            opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
+                                 query_to_targets, k, "all", summary_method="CLS"))
+    """
+    for index, embedding in enumerate(embedding_list):
+        if embedding_type_list[index].lower() == "cui":
+            opt.append(mrm_ndfrt_cui(embedding, umls, cui_list, query_to_targets, k, "origin"))
+            opt.append(mrm_ndfrt_cui(embedding, umls, cui_list, query_to_targets, k, "all"))
+        if embedding_type_list[index].lower() == "word":
+            opt.append(mrm_ndfrt_word(embedding, umls, cui_list, query_to_targets, k, "origin"))
+            opt.append(mrm_ndfrt_word(embedding, umls, cui_list, query_to_targets, k, "all"))
+        if embedding_type_list[index].lower() == "bert":
+            opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
+                                      query_to_targets, k, "origin", summary_method="CLS"))
+            beta_path = os.path.join(embedding, "run", "1000000", "rel embedding")
+            if os.path.exists(beta_path):
+                if concept_filename.find('treat') >= 0:
+                    method = "may_treat"
+                else:
+                    method = "may_prevent"  
+                opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
+                                            query_to_targets, k, method, summary_method="CLS"))
+            opt.append(mrm_ndfrt_bert(embedding, umls, cui_list,
+                                      query_to_targets, k, "all", summary_method="CLS"))                   
+
+    return opt
+
+def load_embedding(filename):
+    print(filename)
+    if filename.find('bin') >= 0:
+        from gensim import models
+        W = models.KeyedVectors.load_word2vec_format(filename, binary=True)
+        dim = W.vector_size
+        return W, dim
+
+    if filename.find('pkl') >= 0:
+        import pickle
+        with open(filename, 'rb') as f:
+            W = pickle.load(f)
+        for key, value in W.items():
+            W[key] = np.array(list(map(float, value[1:-1].split(","))))
+        dim = len(list(W.values())[0])
+        return W, dim
+
+    W = {}
+    with open(filename, 'r') as f:
+        for i, line in enumerate(f.readlines()):
+            if i == 0:
+                continue
+            toks = line.strip().split()
+            w = toks[0]
+            vec = np.array(list(map(float, toks[1:])))
+            W[w] = vec
+    dim = len(list(W.values())[0])
+    return W, dim
+
+
+def load_bert(model_name_or_path):
+    print(model_name_or_path)
+    try:
+        config = AutoConfig.from_pretrained(model_name_or_path)
+        model = AutoModel.from_pretrained(
+            model_name_or_path, config=config).to(device)
+    except BaseException:
+        model = torch.load(os.path.join(
+            model_name_or_path, 'pytorch_model.bin')).to(device)
+
+    try:
+        model.output_hidden_states = False
+    except BaseException:
+        pass
+
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+    except BaseException:
+        tokenizer = AutoTokenizer.from_pretrained(
+            os.path.join(model_name_or_path, "../"))
+    return model, tokenizer
+
+def get_intersection(embedding_list, embedding_type_list):
+    intersection_cui = set()
+    checker = True
+    for index, embed in enumerate(embedding_list):
+        if embedding_type_list[index] == "cui":
+            w, _ = load_embedding(embed)
+            if checker:
+                intersection_cui = set(list(w.keys()))
+                checker = False
+            else:
+                intersection_cui = set(
+                    list(w.keys())).intersection(intersection_cui)
+    print(f"Intersection count: {len(intersection_cui)}")
+    return list(intersection_cui)
+
+if __name__ == "__main__":
+    """
+    embedding_list = ["../../embeddings/claims_codes_hs_300.txt",
+                      "../../embeddings/GoogleNews-vectors-negative300.bin",
+                      "../../models/2020_eng"]
+    embedding_type_list = ["cui", "word", "bert"]
+    mrm_ndfrt(embedding_list, embedding_type_list, "may_prevent_cui.txt", check_intersection=False)
+    """
+
+    embedding_list = ["../../embeddings/wikipedia-pubmed-and-PMC-w2v.bin",
+                      "../../embeddings/bio_nlp_vec/PubMed-shuffle-win-2.bin",
+                      "../../embeddings/bio_nlp_vec/PubMed-shuffle-win-30.bin"]
+    embedding_type_list = ["word", "word", "word"]
+    embedding_list += ["../../embeddings/DeVine_etal_200.txt",
+                      "/home/yz/pretraining_models/cui2vec.pkl"]
+    embedding_type_list += ["cui", "cui"]
+    embedding_list += ["../../models/2020_all",
+                      "/home/yz/pretraining_models/bert-base-cased",
+                      "/home/yz/pretraining_models/biobert_v1.1",
+                      "/home/yz/pretraining_models/BiomedNLP-PubMedBERT-base-uncased-abstract",
+                      "/home/yz/pretraining_models/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
+                      "/home/yz/pretraining_models/kexinghuang_clinical",
+                      "emilyalsentzer/Bio_ClinicalBERT"]
+    embedding_type_list += ["bert"] * 7
+
+    #mrm_ndfrt(embedding_list, embedding_type_list, "may_treat_cui.txt", check_intersection=True)
+    mrm_ndfrt(embedding_list, embedding_type_list, "may_treat_cui.txt", check_intersection=False)
+    #mrm_ndfrt(embedding_list[-6:], embedding_type_list[-6:], "may_prevent_cui.txt", check_intersection=True)
+    #mrm_ndfrt(embedding_list, embedding_type_list, "may_prevent_cui.txt", check_intersection=False)
\ No newline at end of file