--- /dev/null
+++ b/test/mantra/test.py
@@ -0,0 +1,144 @@
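+"""Evaluate an encoder on the Mantra corpora: each mention is linked to the
+UMLS concept whose synonym embedding is most similar, and the predictions
+are scored with micro-averaged precision/recall/F1."""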
+import os
+import sys
+sys.path.append("../../")
+from pretrain.load_umls import UMLS
+import torch
+import numpy as np
+from transformers import AutoTokenizer, AutoModel, AutoConfig
+from data_util import load
+import tqdm
+
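+# Number of phrases embedded per forward pass, and the device used for all model work.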
+batch_size = 128
+device = "cuda:0"
+
+
+def get_umls():
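+    """Collect every UMLS synonym string together with its CUI.
+
+    Returns two parallel lists: umls_label[i] is the CUI whose synonym
+    string is umls_des[i]. Only the MSH, SNOMEDCT_US, and MDR source
+    vocabularies are loaded.
+    """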
+    umls_label = []
+    umls_label_set = set()
+    umls_des = []
+    umls = UMLS("../../umls", source_range=["MSH", "SNOMEDCT_US", "MDR"], only_load_dict=True)
+    for cui in tqdm.tqdm(umls.cui2str):
+        if cui not in umls_label_set:
+            synonyms = list(umls.cui2str[cui])
+            umls_label.extend([cui] * len(synonyms))
+            umls_des.extend(synonyms)
+            umls_label_set.add(cui)
+    print(f"Loaded {len(umls_des)} UMLS synonym strings")
+    return umls_label, umls_des
+
+
+def main(filename, summary_method, umls_label, umls_des):
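+    """Load the model at `filename`, embed all UMLS synonyms once, then
+    evaluate nearest-neighbour concept linking on each corpus/language pair."""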
+    # Prefer a standard Hugging Face checkpoint; fall back to a raw
+    # torch-serialized model (e.g. one saved with torch.save()).
+    try:
+        config = AutoConfig.from_pretrained(filename)
+        model = AutoModel.from_pretrained(
+            filename, config=config).to(device)
+    except Exception:
+        model = torch.load(os.path.join(
+            filename, 'pytorch_model.bin')).to(device)
+
+    # Hidden states are not needed; disable them when the config allows it.
+    try:
+        model.config.output_hidden_states = False
+    except Exception:
+        pass
+
+    # The tokenizer files may live next to the weights or one directory
+    # above them (e.g. for intermediate checkpoints inside a run directory).
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(filename)
+    except Exception:
+        tokenizer = AutoTokenizer.from_pretrained(
+            os.path.join(filename, "../"))
+
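+    # (dataset, language) pairs of the Mantra corpora served by data_util.load.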
+    corpus_list = [("Medline", "es"), ("Medline", "fr"), ("Medline", "nl"), ("Medline", "de"),
+                   ("EMEA", "es"), ("EMEA", "fr"), ("EMEA", "nl"), ("EMEA", "de"),
+                   ("Patent", "fr"), ("Patent", "de")]
+    """
+    sty_list = ["Geographic Area",
+                "Drug Delivery Device", "Medical Device", "Research Device",
+                "Anatomical Abnormality", "Anatomical Structure", "Fully Formed Anatomical Structure",
+                "Chemical", "Chemical Viewed Functionally", "Chemical Viewed Structurally", "Inorganic Chemical", "Organic Chemical", "Clinical Drug"]
+    """
+    result_dict = {}
+    # Embed every UMLS synonym once; a set view of the CUIs makes the
+    # per-corpus membership checks O(1) instead of scanning the list.
+    umls_label_set = set(umls_label)
+    umls_embedding = get_bert_embed(umls_des, model, tokenizer, summary_method=summary_method, tqdm_bar=True)
+
+    for corpus in corpus_list:
+        output_text, output_label, label_set = load(dataset=corpus[0], lang=corpus[1])
+        not_umls_label = [label for label in label_set if label not in umls_label_set]
+        print(f"Gold labels missing from the UMLS subset: {len(not_umls_label)}")
+        text_embedding = get_bert_embed(output_text, model, tokenizer, summary_method=summary_method)
+        predict_label = predict(text_embedding, umls_embedding, umls_label)
+        p, r, f1 = metric(output_label, predict_label)
+        result_dict[corpus[0] + "|" + corpus[1]] = (p, r, f1)
+        print(f"{corpus[0]}|{corpus[1]}: P={p:.4f} R={r:.4f} F1={f1:.4f}")
+
+    return result_dict
+
+
+def predict(text_embedding, umls_embedding, umls_label):
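+    """Assign each mention embedding the CUI of its nearest UMLS term.
+
+    Embeddings are L2-normalized upstream (normalize=True in
+    get_bert_embed), so the dot product is cosine similarity.
+    """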
+    sim = torch.matmul(text_embedding, umls_embedding.t())
+    most_similar = torch.max(sim, dim=1)[1]
+    return [umls_label[idx] for idx in most_similar]
+
+
+def metric(output_label, predict_label):
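+    """Micro-averaged precision, recall, and F1 over per-mention label sets."""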
+    predict_count = 0
+    true_count = 0
+    correct_count = 0
+    for idx in range(len(output_label)):
+        if isinstance(predict_label[idx], str):
+            predict_label[idx] = [predict_label[idx]]
+        if isinstance(output_label[idx], str):
+            output_label[idx] = [output_label[idx]]
+        predict_count += len(predict_label[idx])
+        true_count += len(output_label[idx])
+        for pred in predict_label[idx]:
+            if pred in output_label[idx]:
+                correct_count += 1
+
+    p = correct_count / predict_count
+    r = correct_count / true_count
+    if p == 0. or r == 0.:
+        f1 = 0.
+    else:
+        f1 = 2 * p * r / (p + r)
+    return p, r, f1
+
+
+def get_bert_embed(phrase_list, m, tok, normalize=True, summary_method="CLS", tqdm_bar=False):
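+    """Embed each phrase in `phrase_list` with model `m` and tokenizer `tok`.
+
+    summary_method "CLS" uses the pooler output; "MEAN" averages the last
+    hidden states over all 32 token positions (padding included). Vectors
+    are L2-normalized unless normalize=False.
+    """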
+    input_ids = []
+    for phrase in phrase_list:
+        input_ids.append(tok.encode_plus(
+            phrase, max_length=32, add_special_tokens=True,
+            truncation=True, padding="max_length")['input_ids'])
+    m.eval()
+
+    count = len(input_ids)
+    now_count = 0
+    outputs = []  # per-batch embeddings, concatenated once after the loop
+    with torch.no_grad():
+        if tqdm_bar:
+            pbar = tqdm.tqdm(total=count)
+        while now_count < count:
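+            # Token ids for the next batch, moved to the GPU as one tensor.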
+            input_gpu_0 = torch.LongTensor(input_ids[now_count:min(
+                now_count + batch_size, count)]).to(device)
+            if summary_method == "CLS":
+                # Pooler output: index 1 of the model's tuple outputs.
+                embed = m(input_gpu_0)[1]
+            elif summary_method == "MEAN":
+                # Mean of the last hidden states across all token positions.
+                embed = torch.mean(m(input_gpu_0)[0], dim=1)
+            if normalize:
+                embed_norm = torch.norm(
+                    embed, p=2, dim=1, keepdim=True).clamp(min=1e-12)
+                embed = embed / embed_norm
+            outputs.append(embed)
+            if tqdm_bar:
+                pbar.update(min(now_count + batch_size, count) - now_count)
+            now_count = min(now_count + batch_size, count)
+        if tqdm_bar:
+            pbar.close()
+    return torch.cat(outputs, dim=0)
+
+
+if __name__ == '__main__':
+    umls_label, umls_des = get_umls()
+    main("bert-base-multilingual-cased", "CLS", umls_label, umls_des)
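+    # Hypothetical example: evaluate a local checkpoint with mean pooling
+    # instead of the CLS/pooler vector (path is illustrative only):
+    # main("path/to/local_checkpoint", "MEAN", umls_label, umls_des)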