--- /dev/null
+++ b/coderpp/train/train.py
@@ -0,0 +1,328 @@
+from dataset import UMLSDataset, my_collate_fn
+from model import UMLSFinetuneModel
+from transformers import AdamW, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
+from tqdm import tqdm, trange
+import torch
+from torch import nn
+import time
+import os
+import numpy as np
+import argparse
+import pathlib
+from torch.utils.data import DataLoader
+import pickle
+from torch.utils.tensorboard import SummaryWriter
+from test_faiss import find_new_index
+from accelerate import Accelerator
+
+
+def train(args, model, train_dataloader, loss_list):
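+    """Run the fine-tuning loop, checkpointing every save_step steps and returning every faiss_step steps so the caller can rebuild the FAISS index."""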
+    writer = SummaryWriter(comment='umls')
+
+    t_total = args.max_steps
+
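+    # Exclude bias and LayerNorm weights from weight decay, as in standard BERT fine-tuning.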
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": args.weight_decay,
+        },
+        {"params": [p for n, p in model.named_parameters() if any(
+            nd in n for nd in no_decay)], "weight_decay": 0.0},
+    ]
+
+    optimizer = AdamW(optimizer_grouped_parameters,
+                      lr=args.learning_rate, eps=args.adam_epsilon)
+    
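+    # Let Accelerate wrap the model, optimizer, and dataloader for multi-GPU training.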
+    if args.use_multi_gpu:
+        model, optimizer, train_dataloader = args.accelerator.prepare(model, optimizer, train_dataloader)
+
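+    # Coerce warmup_steps to int and build the learning-rate scheduler selected by --schedule.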
+    args.warmup_steps = int(args.warmup_steps)
+    if args.schedule == 'linear':
+        scheduler = get_linear_schedule_with_warmup(
+            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
+        )
+    if args.schedule == 'constant':
+        scheduler = get_constant_schedule_with_warmup(
+            optimizer, num_warmup_steps=args.warmup_steps
+        )
+    if args.schedule == 'cosine':
+        scheduler = get_cosine_schedule_with_warmup(
+            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
+        )
+
+    args.accelerator.print("***** Running training *****")
+    args.accelerator.print("  Total Steps =", t_total)
+    args.accelerator.print("  Steps remaining to train =", t_total - args.shift)
+    args.accelerator.print("  Instantaneous batch size per GPU =", args.train_batch_size)
+    args.accelerator.print(
+        "  Total train batch size (w. parallel, distributed & accumulation) =",
+        args.train_batch_size
+        * args.gradient_accumulation_steps * args.nr_gpus,
+    )
+    args.accelerator.print("  Gradient Accumulation steps =", args.gradient_accumulation_steps)
+
+    model.zero_grad()
+
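+    # When resuming from step args.shift, fast-forward the scheduler so the learning rate matches the checkpoint.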
+    for i in range(args.shift):
+        scheduler.step()
+    global_step = args.shift
+
+    while True:
+        model.train()
+        epoch_iterator = tqdm(train_dataloader, desc="Iteration", ascii=True, disable=not args.accelerator.is_local_main_process)
+        batch_loss = 0.
+        for _, batch in enumerate(epoch_iterator):
+            input_ids = batch[0].to(args.device)
+            cui_label = batch[1].to(args.device)
+            attention_mask = batch[2].to(args.device)
+
+            loss = model(input_ids, cui_label, attention_mask)
+            batch_loss = float(loss.item())
+            loss_list.append(batch_loss)
+            # Log the per-batch loss to TensorBoard.
+            writer.add_scalar('batch_loss', batch_loss,
+                              global_step=global_step)
+
+            if args.gradient_accumulation_steps > 1:
+                loss = loss / args.gradient_accumulation_steps
+
+            args.accelerator.backward(loss)
+
+            epoch_iterator.set_description("Loss: %0.4f" % batch_loss)
+
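+            # Step the optimizer only once every gradient_accumulation_steps batches.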
+            if (global_step + 1) % args.gradient_accumulation_steps == 0:
+                torch.nn.utils.clip_grad_norm_(
+                    model.parameters(), args.max_grad_norm)
+                optimizer.step()
+                scheduler.step()  # Update learning rate schedule
+                model.zero_grad()
+            del input_ids
+            del cui_label
+            del attention_mask
+            global_step += 1
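+            # Periodically checkpoint the encoder and the running loss history.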
+            if global_step % args.save_step == 0 and global_step > 0:
+                save_path = os.path.join(
+                    args.output_dir, f'model_{global_step}.pth')
+                unwrapped_model = args.accelerator.unwrap_model(model)
+                torch.save(unwrapped_model.bert, save_path)
+                with open(os.path.join(args.output_dir, f'loss_{global_step}.pkl'), 'wb') as f:
+                    pickle.dump(loss_list, f)
+                torch.cuda.empty_cache()
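+            # Every faiss_step steps, return to run() so the FAISS index can be rebuilt with the updated encoder.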
+            if args.max_steps > 0 and global_step % args.faiss_step == 0:
+                writer.close()
+                return None
+
+
+def run(args):
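+    """Alternate rounds of fine-tuning and FAISS nearest-neighbor re-indexing until max_steps is reached."""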
+    torch.manual_seed(args.seed)  # cpu
+    torch.cuda.manual_seed(args.seed)  # gpu
+    np.random.seed(args.seed)  # numpy
+    torch.backends.cudnn.deterministic = True  # cudnn
+
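+    # Resume support: reuse the most recent indices_{i}.npy found in output_dir, if any.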
+    if os.path.exists(args.output_dir):
+        indices_list = []
+        for f in os.listdir(args.output_dir):
+            if f[0:7] == "indices" and f[-4:] == ".npy":
+                indices_list.append(int(f[8:-4]))
+        if len(indices_list) > 0:
+            i = max(indices_list)
+            args.indices_path = os.path.join(args.output_dir, f'indices_{i}.npy')
+            i += 1
+        else:
+            i = 0
+    else:
+        i = 0
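+    # Each round: build the dataset from the current indices, fine-tune for faiss_step steps, then re-index.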
+    while i < (int(args.max_steps / args.faiss_step)):
+        print('i', i)
+        args.accelerator = Accelerator()
+        args.device = args.accelerator.device
+        # dataloader
+        if args.lang == "eng":
+            lang = ["ENG"]
+        if args.lang == "all":
+            lang = None
+            assert args.model_name_or_path.find("bio") == -1, "Should use multi-language model"
+        args.accelerator.print(args.indices_path)
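+        # Build the UMLS dataset and dataloader from the current nearest-neighbor indices.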
+        umls_dataset = UMLSDataset(umls_folder=args.umls_dir, 
+                                model_name_or_path=args.model_name_or_path, 
+                                idx2phrase_path=args.idx2phrase_path, 
+                                phrase2idx_path=args.phrase2idx_path,
+                                indices_path=args.indices_path,
+                                max_length=args.max_seq_length)
+        umls_dataloader = DataLoader(umls_dataset,
+                                    batch_size=args.train_batch_size, 
+                                    shuffle=True,
+                                    num_workers=args.num_workers, 
+                                    pin_memory=True, 
+                                    drop_last=True,
+                                    collate_fn=my_collate_fn)
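+        # Resume from the latest checkpoint in output_dir if one exists; otherwise start from the pretrained model.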
+        model_load = False
+        if os.path.exists(args.output_dir):
+            save_list = []
+            for f in os.listdir(args.output_dir):
+                if f[0:5] == "model" and f[-4:] == ".pth":
+                    save_list.append(int(f[6:-4]))
+            if len(save_list) > 0:
+                args.shift = max(save_list)
+                with open(os.path.join(args.output_dir, 'loss_'+str(args.shift)+'.pkl'), 'rb') as f:
+                    loss_list = pickle.load(f)
+                model = UMLSFinetuneModel(device=args.device, 
+                                    model_name_or_path=args.model_name_or_path, 
+                                    cui_label_count=len(umls_dataset.cui2idx)).to(args.device)
+                if os.path.exists(os.path.join(args.output_dir, 'last_model.pth')):
+                    model.bert = torch.load(os.path.join(
+                        args.output_dir, 'last_model.pth')).to(args.device)
+                    model_load = True
+                    args.accelerator.print('load last model')
+                else:
+                    model.bert = torch.load(os.path.join(
+                        args.output_dir, f'model_{max(save_list)}.pth')).to(args.device)
+                    model_load = True
+                    args.accelerator.print('load model{}'.format(max(save_list)))
+        if not model_load:
+            if not os.path.exists(args.output_dir):
+                os.makedirs(args.output_dir)
+            model = UMLSFinetuneModel(device=args.device, 
+                                    model_name_or_path=args.model_name_or_path, 
+                                    cui_label_count=len(umls_dataset.cui2idx)).to(args.device)
+            args.shift = 0
+            model_load = True
+            loss_list = []
+        
+        if args.do_train:
+            torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
+            train(args, model, umls_dataloader, loss_list)
+            args.accelerator.wait_for_everyone()
+            unwrapped_model = args.accelerator.unwrap_model(model)
+            args.accelerator.save(unwrapped_model.bert, os.path.join(args.output_dir, 'last_model.pth'))
+            del model
+            torch.cuda.empty_cache()
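+            # Rebuild the FAISS nearest-neighbor index with the newly fine-tuned encoder (main process only).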
+            new_CODER_path = os.path.join(args.output_dir, 'last_model.pth')
+            new_indices_path = os.path.join(args.output_dir, f'indices_{i}.npy')
+            new_similarity_path = os.path.join(args.output_dir, f'similarity_{i}.npy')
+            if args.accelerator.is_main_process:
+                find_new_index(new_CODER_path, new_indices_path, new_similarity_path)
+            args.accelerator.wait_for_everyone()
+            args.indices_path = new_indices_path
+        i += 1
+    return None
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--umls_dir",
+        default="../umls",
+        type=str,
+        help="UMLS dir",
+    )
+    parser.add_argument(
+        "--model_name_or_path",
+        default="GanjinZero/UMLSBert_ENG",
+        type=str,
+        help="Path to pre-trained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--idx2phrase_path",
+        default="data/idx2string.pkl",
+        type=str,
+        help="Path to the pickled dict mapping index to phrase"
+    )
+    parser.add_argument(
+        "--phrase2idx_path",
+        default="data/string2idx.pkl",
+        type=str,
+        help="Path to the pickled dict mapping phrase to index"
+    )
+    parser.add_argument(
+        "--indices_path",
+        default="data/indices.npy",
+        type=str,
+        help="Path to the nearest-neighbor indices produced by FAISS"
+    )
+    parser.add_argument(
+        "--output_dir",
+        default="output_30pos",
+        type=str,
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--faiss_step",
+        default=100000,
+        type=int,
+        help="Rebuild the FAISS index every this many training steps",
+    )
+    parser.add_argument(
+        "--save_step",
+        default=25000,
+        type=int,
+        help="Save a checkpoint every this many training steps",
+    )
+
+    # Other parameters
+    parser.add_argument(
+        "--max_seq_length",
+        default=32,
+        type=int,
+        help="The maximum total input sequence length after tokenization. Sequences longer "
+        "than this will be truncated, sequences shorter will be padded.",
+    )
+    parser.add_argument("--do_train", default=True,
+                        type=lambda x: str(x).lower() in ("true", "1", "yes"),
+                        help="Whether to run training.")
+    parser.add_argument(
+        "--train_batch_size", default=4, type=int, help="Batch size in groups per GPU/CPU for training; each group contains 32 strings (the current string, its top-30 neighbors, and same-CUI strings).",
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=8,
+        help="Number of updates steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument("--learning_rate", default=2e-5,
+                        type=float, help="The initial learning rate for Adam.")
+    parser.add_argument("--weight_decay", default=0.01,
+                        type=float, help="Weight decay if we apply some.")
+    parser.add_argument("--adam_epsilon", default=1e-8,
+                        type=float, help="Epsilon for Adam optimizer.")
+    parser.add_argument("--max_grad_norm", default=1.0,
+                        type=float, help="Max gradient norm.")
+    parser.add_argument(
+        "--max_steps",
+        default=400000,
+        type=int,
+        help="If > 0: set the total number of training steps to perform.",
+    )
+    parser.add_argument("--warmup_steps", default=40000, type=int,
+                        help="Number of warmup steps for the learning-rate scheduler.")
+    parser.add_argument("--device", type=str, default='cuda:1', help="Device; overridden by accelerator.device in run().")
+    parser.add_argument("--seed", type=int, default=72,
+                        help="random seed for initialization")
+    parser.add_argument("--schedule", type=str, default="linear",
+                        choices=["linear", "cosine", "constant"], help="Schedule.")
+    parser.add_argument("--num_workers", default=1, type=int,
+                        help="Number of workers for the data loader (use 0 on Windows).")
+    parser.add_argument("--lang", default='eng', type=str, choices=["eng", "all"],
+                        help="Language range: eng or all.")
+
+    parser.add_argument("--use_multi_gpu", action="store_true", help="Use multiple GPUs via Accelerate.")
+    parser.add_argument("--nr_gpus", type=int, default=2, help="Number of GPUs (used when reporting the effective batch size).")
+    args = parser.parse_args()
+
+    run(args)
+
+
+if __name__ == "__main__":
+    main()