# pretrain/train.py
from data_util import UMLSDataset, fixed_length_dataloader
from model import UMLSPretrainedModel
from transformers import AdamW, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
from tqdm import tqdm, trange
import torch
from torch import nn
import time
import os
import numpy as np
import argparse
import pathlib

# use the SummaryWriter bundled with PyTorch when available, otherwise fall back to tensorboardX
try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter


def train(args, model, train_dataloader, umls_dataset):
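    """Pretrain the UMLS model.

    Runs the joint CUI / STY / relation objectives over the dataloader,
    logs per-batch losses to TensorBoard, and saves a full checkpoint every
    `args.save_step` steps until `args.max_steps` is reached.
    """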
    writer = SummaryWriter(comment='umls')

    t_total = args.max_steps

    # no weight decay for bias and LayerNorm parameters
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(
            nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate, eps=args.adam_epsilon)
    args.warmup_steps = int(args.warmup_steps)
    if args.schedule == 'linear':
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        )
    elif args.schedule == 'constant':
        scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )
    elif args.schedule == 'cosine':
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        )

    print("***** Running training *****")
    print("  Total Steps =", t_total)
    print("  Steps to be trained =", t_total - args.shift)
    print("  Instantaneous batch size per GPU =", args.train_batch_size)
    print(
        "  Total train batch size (w. parallel, distributed & accumulation) =",
        args.train_batch_size
        * args.gradient_accumulation_steps,
    )
    print("  Gradient Accumulation steps =", args.gradient_accumulation_steps)

    model.zero_grad()
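
    # when resuming from a checkpoint, replay the scheduler so the learning rate matches the saved step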
    for _ in range(args.shift):
        scheduler.step()
    global_step = args.shift

    while True:
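        # loop over the dataloader repeatedly; the only exit is reaching args.max_steps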
        model.train()
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", ascii=True)
        batch_loss = 0.
        batch_sty_loss = 0.
        batch_cui_loss = 0.
        batch_re_loss = 0.
        for _, batch in enumerate(epoch_iterator):
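            # move the three tokenized terms and their CUI / STY labels to the training device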
            input_ids_0 = batch[0].to(args.device)
            input_ids_1 = batch[1].to(args.device)
            input_ids_2 = batch[2].to(args.device)
            cui_label_0 = batch[3].to(args.device)
            cui_label_1 = batch[4].to(args.device)
            cui_label_2 = batch[5].to(args.device)
            sty_label_0 = batch[6].to(args.device)
            sty_label_1 = batch[7].to(args.device)
            sty_label_2 = batch[8].to(args.device)
            # batch[9] holds the RE label, batch[10] the REL label
            if args.use_re:
                re_label = batch[9].to(args.device)
            else:
                re_label = batch[10].to(args.device)

            loss, (sty_loss, cui_loss, re_loss) = \
                model(input_ids_0, input_ids_1, input_ids_2,
                      cui_label_0, cui_label_1, cui_label_2,
                      sty_label_0, sty_label_1, sty_label_2,
                      re_label)
            batch_loss = float(loss.item())
            batch_sty_loss = float(sty_loss.item())
            batch_cui_loss = float(cui_loss.item())
            batch_re_loss = float(re_loss.item())

            # tensorboardX
            writer.add_scalar(
                'rel_count', train_dataloader.batch_sampler.rel_sampler_count, global_step=global_step)
            writer.add_scalar('batch_loss', batch_loss,
                              global_step=global_step)
            writer.add_scalar('batch_sty_loss', batch_sty_loss,
                              global_step=global_step)
            writer.add_scalar('batch_cui_loss', batch_cui_loss,
                              global_step=global_step)
            writer.add_scalar('batch_re_loss', batch_re_loss,
                              global_step=global_step)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            epoch_iterator.set_description("Rel_count: %s, Loss: %0.4f, Sty: %0.4f, Cui: %0.4f, Re: %0.4f" %
                                           (train_dataloader.batch_sampler.rel_sampler_count, batch_loss, batch_sty_loss, batch_cui_loss, batch_re_loss))

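            # only step the optimizer every `gradient_accumulation_steps` batches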
            if (global_step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()

            global_step += 1
            if global_step % args.save_step == 0 and global_step > 0:
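                # checkpoint the full model and log embedding projections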
                save_path = os.path.join(
                    args.output_dir, f'model_{global_step}.pth')
                torch.save(model, save_path)

                # re_embedding
                if args.use_re:
                    writer.add_embedding(model.re_embedding.weight,
                                         metadata=list(umls_dataset.re2id.keys()),
                                         global_step=global_step, tag="re embedding")
                else:
                    writer.add_embedding(model.re_embedding.weight,
                                         metadata=list(umls_dataset.rel2id.keys()),
                                         global_step=global_step, tag="rel embedding")

                # sty_parameter
                writer.add_embedding(model.linear_sty.weight,
                                     metadata=list(umls_dataset.sty2id.keys()),
                                     global_step=global_step, tag="sty weight")

            if args.max_steps > 0 and global_step > args.max_steps:
                return None

    return None


def run(args):
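    """Build the UMLS dataset and dataloader, create or resume the model, and train."""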
    torch.manual_seed(args.seed)  # cpu
    torch.cuda.manual_seed(args.seed)  # gpu
    np.random.seed(args.seed)  # numpy
    torch.backends.cudnn.deterministic = True  # cudnn

    # args.output_dir = args.output_dir + "_" + str(int(time.time()))

    # dataloader
    if args.lang == "eng":
        lang = ["ENG"]
    elif args.lang == "all":
        lang = None
        assert args.model_name_or_path.find("bio") == -1, "Use a multilingual model when --lang is 'all'"
    umls_dataset = UMLSDataset(
        umls_folder=args.umls_dir, model_name_or_path=args.model_name_or_path, lang=lang, json_save_path=args.output_dir)
    umls_dataloader = fixed_length_dataloader(
        umls_dataset, fixed_length=args.train_batch_size, num_workers=args.num_workers)

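    # the relation vocabulary size depends on whether RE or REL labels are used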
    if args.use_re:
        rel_label_count = len(umls_dataset.re2id)
    else:
        rel_label_count = len(umls_dataset.rel2id)

    model_load = False
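    # resume from the most recent checkpoint in output_dir, if one exists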
    if os.path.exists(args.output_dir):
        save_list = []
        for f in os.listdir(args.output_dir):
            if f[0:5] == "model" and f[-4:] == ".pth":
                save_list.append(int(f[6:-4]))
        if len(save_list) > 0:
            args.shift = max(save_list)
            if os.path.exists(os.path.join(args.output_dir, 'last_model.pth')):
                model = torch.load(os.path.join(
                    args.output_dir, 'last_model.pth')).to(args.device)
                model_load = True
            else:
                model = torch.load(os.path.join(
                    args.output_dir, f'model_{max(save_list)}.pth')).to(args.device)
                model_load = True
    if not model_load:
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        model = UMLSPretrainedModel(device=args.device, model_name_or_path=args.model_name_or_path,
                                    cui_label_count=len(umls_dataset.cui2id),
                                    rel_label_count=rel_label_count,
                                    sty_label_count=len(umls_dataset.sty2id),
                                    re_weight=args.re_weight,
                                    sty_weight=args.sty_weight).to(args.device)
        args.shift = 0
        model_load = True

    if args.do_train:
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
        train(args, model, umls_dataloader, umls_dataset)
        torch.save(model, os.path.join(args.output_dir, 'last_model.pth'))

    return None


def main():
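    """Parse command-line arguments and start pretraining."""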
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--umls_dir",
        default="../umls",
        type=str,
        help="UMLS directory",
    )
    parser.add_argument(
        "--model_name_or_path",
        default="../biobert_v1.1",
        type=str,
        help="Path to the pre-trained model or a shortcut name.",
    )
    parser.add_argument(
        "--output_dir",
        default="output",
        type=str,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--save_step",
        default=25000,
        type=int,
        help="Save a checkpoint every this many steps.",
    )

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=32,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument("--do_train", default=True, type=lambda x: str(x).lower() in ("true", "1"),
                        help="Whether to run training.")
    parser.add_argument(
        "--train_batch_size", default=256, type=int, help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=8,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=2e-5,
                        type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.01,
                        type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8,
                        type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0,
                        type=float, help="Max gradient norm.")
    parser.add_argument(
        "--max_steps",
        default=1000000,
        type=int,
        help="Total number of training steps to perform.",
    )
    parser.add_argument("--warmup_steps", default=10000,
                        help="Number of linear warmup steps.")
    parser.add_argument("--device", type=str, default='cuda:1', help="Device to train on.")
    parser.add_argument("--seed", type=int, default=72,
                        help="Random seed for initialization.")
    parser.add_argument("--schedule", type=str, default="linear",
                        choices=["linear", "cosine", "constant"], help="Learning rate schedule.")
    parser.add_argument("--trans_margin", type=float, default=1.0,
                        help="Margin of TransE.")
    parser.add_argument("--use_re", default=False, type=lambda x: str(x).lower() in ("true", "1"),
                        help="Whether to use RE labels instead of REL labels.")
    parser.add_argument("--num_workers", default=1, type=int,
                        help="Number of workers for the data loader; only 0 works on Windows.")
    parser.add_argument("--lang", default='eng', type=str, choices=["eng", "all"],
                        help="Language range: eng or all.")
    parser.add_argument("--sty_weight", type=float, default=0.0,
                        help="Weight of the STY (semantic type) loss.")
    parser.add_argument("--re_weight", type=float, default=1.0,
                        help="Weight of the RE/REL (relation) loss.")

    args = parser.parse_args()

    run(args)


if __name__ == "__main__":
    main()