|
a |
|
b/pretrain/train.py |
|
|
1 |
from data_util import UMLSDataset, fixed_length_dataloader |
|
|
2 |
from model import UMLSPretrainedModel |
|
|
3 |
from transformers import AdamW, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup |
|
|
4 |
from tqdm import tqdm, trange |
|
|
5 |
import torch |
|
|
6 |
from torch import nn |
|
|
7 |
import time |
|
|
8 |
import os |
|
|
9 |
import numpy as np |
|
|
10 |
import argparse |
|
|
11 |
import time |
|
|
12 |
import pathlib |
|
|
13 |
#import ipdb |
|
|
14 |
# try: |
|
|
15 |
# from torch.utils.tensorboard import SummaryWriter |
|
|
16 |
# except: |
|
|
17 |
from tensorboardX import SummaryWriter |
|
|
18 |
|
|
|
19 |
|
|
|
20 |
def train(args, model, train_dataloader, umls_dataset):
    """Pre-train the UMLS model for up to ``args.max_steps`` global steps.

    Supports resuming: the LR scheduler is fast-forwarded by ``args.shift``
    steps, and a full-model checkpoint ``model_<step>.pth`` is written to
    ``args.output_dir`` every ``args.save_step`` steps.

    Args:
        args: hyper-parameter namespace (learning_rate, weight_decay,
            adam_epsilon, warmup_steps, schedule, max_grad_norm,
            gradient_accumulation_steps, max_steps, save_step, shift,
            device, use_re, output_dir, train_batch_size).
        model: UMLSPretrainedModel whose forward returns
            ``loss, (sty_loss, cui_loss, re_loss)``.
        train_dataloader: yields batches indexed 0-2 (input ids), 3-5 (CUI
            labels), 6-8 (semantic-type labels), 9 (re label), 10 (rel label).
        umls_dataset: UMLSDataset providing the id vocabularies
            (re2id / rel2id / sty2id) used as tensorboard embedding metadata.
    """
    writer = SummaryWriter(comment='umls')

    t_total = args.max_steps

    # Standard BERT recipe: no weight decay on biases and LayerNorm weights.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate, eps=args.adam_epsilon)
    # --warmup_steps may arrive from the CLI as a string or float.
    args.warmup_steps = int(args.warmup_steps)
    if args.schedule == 'linear':
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        )
    elif args.schedule == 'constant':
        scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )
    elif args.schedule == 'cosine':
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        )
    else:
        # Previously an unrecognized schedule left ``scheduler`` unbound and
        # crashed later with a NameError; fail fast with a clear message.
        raise ValueError(f"Unknown schedule: {args.schedule!r}")

    print("***** Running training *****")
    print(" Total Steps =", t_total)
    print(" Steps needs to be trained=", t_total - args.shift)
    print(" Instantaneous batch size per GPU =", args.train_batch_size)
    print(
        " Total train batch size (w. parallel, distributed & accumulation) =",
        args.train_batch_size
        * args.gradient_accumulation_steps,
    )
    print(" Gradient Accumulation steps =", args.gradient_accumulation_steps)

    model.zero_grad()

    # Fast-forward the LR schedule when resuming from a checkpoint.
    for _ in range(args.shift):
        scheduler.step()
    global_step = args.shift

    while True:
        model.train()
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", ascii=True)
        batch_loss = 0.
        batch_sty_loss = 0.
        batch_cui_loss = 0.
        batch_re_loss = 0.
        for batch in epoch_iterator:
            input_ids_0 = batch[0].to(args.device)
            input_ids_1 = batch[1].to(args.device)
            input_ids_2 = batch[2].to(args.device)
            cui_label_0 = batch[3].to(args.device)
            cui_label_1 = batch[4].to(args.device)
            cui_label_2 = batch[5].to(args.device)
            sty_label_0 = batch[6].to(args.device)
            sty_label_1 = batch[7].to(args.device)
            sty_label_2 = batch[8].to(args.device)
            # use batch[9] for re, use batch[10] for rel
            if args.use_re:
                re_label = batch[9].to(args.device)
            else:
                re_label = batch[10].to(args.device)

            loss, (sty_loss, cui_loss, re_loss) = \
                model(input_ids_0, input_ids_1, input_ids_2,
                      cui_label_0, cui_label_1, cui_label_2,
                      sty_label_0, sty_label_1, sty_label_2,
                      re_label)
            # Log the un-scaled losses (before gradient-accumulation division).
            batch_loss = float(loss.item())
            batch_sty_loss = float(sty_loss.item())
            batch_cui_loss = float(cui_loss.item())
            batch_re_loss = float(re_loss.item())

            # tensorboardX
            writer.add_scalar(
                'rel_count', train_dataloader.batch_sampler.rel_sampler_count, global_step=global_step)
            writer.add_scalar('batch_loss', batch_loss,
                              global_step=global_step)
            writer.add_scalar('batch_sty_loss', batch_sty_loss,
                              global_step=global_step)
            writer.add_scalar('batch_cui_loss', batch_cui_loss,
                              global_step=global_step)
            writer.add_scalar('batch_re_loss', batch_re_loss,
                              global_step=global_step)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()

            epoch_iterator.set_description("Rel_count: %s, Loss: %0.4f, Sty: %0.4f, Cui: %0.4f, Re: %0.4f" %
                                           (train_dataloader.batch_sampler.rel_sampler_count, batch_loss, batch_sty_loss, batch_cui_loss, batch_re_loss))

            # Only step the optimizer every ``gradient_accumulation_steps`` batches.
            if (global_step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()

            global_step += 1
            if global_step % args.save_step == 0 and global_step > 0:
                save_path = os.path.join(
                    args.output_dir, f'model_{global_step}.pth')
                # NOTE: saves the full pickled model object, not a state_dict.
                torch.save(model, save_path)

                # re_embedding
                if args.use_re:
                    writer.add_embedding(model.re_embedding.weight, metadata=umls_dataset.re2id.keys(
                    ), global_step=global_step, tag="re embedding")
                else:
                    writer.add_embedding(model.re_embedding.weight, metadata=umls_dataset.rel2id.keys(
                    ), global_step=global_step, tag="rel embedding")

                # sty_parameter
                writer.add_embedding(model.linear_sty.weight, metadata=umls_dataset.sty2id.keys(
                ), global_step=global_step, tag="sty weight")

            if args.max_steps > 0 and global_step > args.max_steps:
                return None
|
|
153 |
|
|
|
154 |
|
|
|
155 |
def run(args):
    """Seed all RNGs, build the UMLS data pipeline, create or resume the
    model from ``args.output_dir``, and (when ``args.do_train``) train it.

    Side effects:
        Sets ``args.shift`` to the number of already-completed steps when a
        checkpoint is found (0 for a fresh model), creates ``args.output_dir``
        if needed, and writes ``training_args.bin`` / ``last_model.pth``.
    """
    torch.manual_seed(args.seed)  # cpu
    torch.cuda.manual_seed(args.seed)  # gpu
    np.random.seed(args.seed)  # numpy
    torch.backends.cudnn.deterministic = True  # cudnn

    #args.output_dir = args.output_dir + "_" + str(int(time.time()))

    # dataloader
    if args.lang == "eng":
        lang = ["ENG"]
    elif args.lang == "all":
        lang = None
        assert args.model_name_or_path.find("bio") == -1, "Should use multi-language model"
    else:
        # argparse already restricts --lang; guard direct callers, which
        # previously hit an UnboundLocalError on ``lang`` below.
        raise ValueError(f"Unsupported lang: {args.lang!r}")
    umls_dataset = UMLSDataset(
        umls_folder=args.umls_dir, model_name_or_path=args.model_name_or_path, lang=lang, json_save_path=args.output_dir)
    umls_dataloader = fixed_length_dataloader(
        umls_dataset, fixed_length=args.train_batch_size, num_workers=args.num_workers)

    # Relation vocabulary size depends on training with coarse "re" labels
    # or fine-grained "rel" labels.
    if args.use_re:
        rel_label_count = len(umls_dataset.re2id)
    else:
        rel_label_count = len(umls_dataset.rel2id)

    model_load = False
    if os.path.exists(args.output_dir):
        # Collect step numbers of periodic checkpoints named model_<step>.pth.
        # (The old prefix/suffix slicing also matched "model.pth" and then
        # crashed on int(''); require the full "model_" prefix instead.)
        save_list = []
        for f in os.listdir(args.output_dir):
            if f.startswith("model_") and f.endswith(".pth"):
                save_list.append(int(f[6:-4]))
        if len(save_list) > 0:
            args.shift = max(save_list)
            # Prefer last_model.pth when present, else the newest periodic
            # checkpoint. map_location lets a checkpoint saved on another
            # device (default device is cuda:1) load on args.device.
            last_path = os.path.join(args.output_dir, 'last_model.pth')
            if os.path.exists(last_path):
                model = torch.load(
                    last_path, map_location=args.device).to(args.device)
            else:
                model = torch.load(
                    os.path.join(args.output_dir, f'model_{max(save_list)}.pth'),
                    map_location=args.device).to(args.device)
            model_load = True
    if not model_load:
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        model = UMLSPretrainedModel(device=args.device, model_name_or_path=args.model_name_or_path,
                                    cui_label_count=len(umls_dataset.cui2id),
                                    rel_label_count=rel_label_count,
                                    sty_label_count=len(umls_dataset.sty2id),
                                    re_weight=args.re_weight,
                                    sty_weight=args.sty_weight).to(args.device)
        args.shift = 0
        model_load = True

    if args.do_train:
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
        train(args, model, umls_dataloader, umls_dataset)
        torch.save(model, os.path.join(args.output_dir, 'last_model.pth'))

    return None
|
|
213 |
|
|
|
214 |
|
|
|
215 |
def str2bool(value):
    """Parse a command-line boolean.

    argparse's ``type=bool`` treats every non-empty string — including
    ``"False"`` — as True; this converter accepts the usual spellings.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if value.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def main():
    """Build the argument parser and launch pre-training via ``run``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--umls_dir",
        default="../umls",
        type=str,
        help="UMLS dir",
    )
    parser.add_argument(
        "--model_name_or_path",
        default="../biobert_v1.1",
        type=str,
        help="Path to pre-trained model or shortcut name selected in the list: ",
    )
    parser.add_argument(
        "--output_dir",
        default="output",
        type=str,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--save_step",
        default=25000,
        type=int,
        help="Save step",
    )

    # Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=32,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    # type=bool would parse any non-empty string (even "False") as True.
    parser.add_argument("--do_train", default=True, type=str2bool, help="Whether to run training.")
    parser.add_argument(
        "--train_batch_size", default=256, type=int, help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=8,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=2e-5,
                        type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.01,
                        type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8,
                        type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0,
                        type=float, help="Max gradient norm.")
    parser.add_argument(
        "--max_steps",
        default=1000000,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    # Without type= a CLI value stayed a str; float accepts "10000" and
    # "10000.0" alike (train() converts it with int()).
    parser.add_argument("--warmup_steps", default=10000, type=float,
                        help="Linear warmup over warmup_steps or a float.")
    parser.add_argument("--device", type=str, default='cuda:1', help="device")
    parser.add_argument("--seed", type=int, default=72,
                        help="random seed for initialization")
    parser.add_argument("--schedule", type=str, default="linear",
                        choices=["linear", "cosine", "constant"], help="Schedule.")
    parser.add_argument("--trans_margin", type=float, default=1.0,
                        help="Margin of TransE.")
    # Same type=bool pitfall as --do_train.
    parser.add_argument("--use_re", default=False, type=str2bool,
                        help="Whether to use re or rel.")
    parser.add_argument("--num_workers", default=1, type=int,
                        help="Num workers for data loader, only 0 can be used for Windows")
    parser.add_argument("--lang", default='eng', type=str, choices=["eng", "all"],
                        help="language range, eng or all")
    parser.add_argument("--sty_weight", type=float, default=0.0,
                        help="Weight of sty.")
    parser.add_argument("--re_weight", type=float, default=1.0,
                        help="Weight of re.")

    args = parser.parse_args()

    run(args)


if __name__ == "__main__":
    main()