|
a |
|
b/task/MLM.ipynb |
|
|
1 |
{ |
|
|
2 |
"cells": [ |
|
|
3 |
{ |
|
|
4 |
"cell_type": "code", |
|
|
5 |
"execution_count": null, |
|
|
6 |
"metadata": {}, |
|
|
7 |
"outputs": [], |
|
|
8 |
"source": [ |
|
|
9 |
"import sys \n", |
|
|
10 |
"sys.path.insert(0, '../')" |
|
|
11 |
] |
|
|
12 |
}, |
|
|
13 |
{ |
|
|
14 |
"cell_type": "code", |
|
|
15 |
"execution_count": null, |
|
|
16 |
"metadata": {}, |
|
|
17 |
"outputs": [], |
|
|
18 |
"source": [ |
|
|
19 |
"from common.common import create_folder\n", |
|
|
20 |
"from common.pytorch import load_model\n", |
|
|
21 |
"import pytorch_pretrained_bert as Bert\n", |
|
|
22 |
"from model.utils import age_vocab\n", |
|
|
23 |
"from common.common import load_obj\n", |
|
|
24 |
"from dataLoader.MLM import MLMLoader\n", |
|
|
25 |
"from torch.utils.data import DataLoader\n", |
|
|
26 |
"import pandas as pd\n", |
|
|
27 |
"from model.MLM import BertForMaskedLM\n", |
|
|
28 |
"from model.optimiser import adam\n", |
|
|
29 |
"import sklearn.metrics as skm\n", |
|
|
30 |
"import numpy as np\n", |
|
|
31 |
"import torch\n", |
|
|
32 |
"import time\n", |
|
|
33 |
"import torch.nn as nn\n", |
|
|
34 |
"import os" |
|
|
35 |
] |
|
|
36 |
}, |
|
|
37 |
{ |
|
|
38 |
"cell_type": "code", |
|
|
39 |
"execution_count": null, |
|
|
40 |
"metadata": {}, |
|
|
41 |
"outputs": [], |
|
|
42 |
"source": [ |
|
|
43 |
"class BertConfig(Bert.modeling.BertConfig):\n", |
|
|
44 |
" def __init__(self, config):\n", |
|
|
45 |
" super(BertConfig, self).__init__(\n", |
|
|
46 |
" vocab_size_or_config_json_file=config.get('vocab_size'),\n", |
|
|
47 |
" hidden_size=config['hidden_size'],\n", |
|
|
48 |
" num_hidden_layers=config.get('num_hidden_layers'),\n", |
|
|
49 |
" num_attention_heads=config.get('num_attention_heads'),\n", |
|
|
50 |
" intermediate_size=config.get('intermediate_size'),\n", |
|
|
51 |
" hidden_act=config.get('hidden_act'),\n", |
|
|
52 |
" hidden_dropout_prob=config.get('hidden_dropout_prob'),\n", |
|
|
53 |
" attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),\n", |
|
|
54 |
" max_position_embeddings = config.get('max_position_embedding'),\n", |
|
|
55 |
" initializer_range=config.get('initializer_range'),\n", |
|
|
56 |
" )\n", |
|
|
57 |
" self.seg_vocab_size = config.get('seg_vocab_size')\n", |
|
|
58 |
" self.age_vocab_size = config.get('age_vocab_size')\n", |
|
|
59 |
" \n", |
|
|
60 |
"class TrainConfig(object):\n", |
|
|
61 |
" def __init__(self, config):\n", |
|
|
62 |
" self.batch_size = config.get('batch_size')\n", |
|
|
63 |
" self.use_cuda = config.get('use_cuda')\n", |
|
|
64 |
" self.max_len_seq = config.get('max_len_seq')\n", |
|
|
65 |
" self.train_loader_workers = config.get('train_loader_workers')\n", |
|
|
66 |
" self.test_loader_workers = config.get('test_loader_workers')\n", |
|
|
67 |
" self.device = config.get('device')\n", |
|
|
68 |
" self.output_dir = config.get('output_dir')\n", |
|
|
69 |
" self.output_name = config.get('output_name')\n", |
|
|
70 |
" self.best_name = config.get('best_name')" |
|
|
71 |
] |
|
|
72 |
}, |
|
|
73 |
{ |
|
|
74 |
"cell_type": "code", |
|
|
75 |
"execution_count": null, |
|
|
76 |
"metadata": {}, |
|
|
77 |
"outputs": [], |
|
|
78 |
"source": [ |
|
|
79 |
"file_config = {\n", |
|
|
80 |
" 'vocab':'', # vocabulary idx2token, token2idx\n", |
|
|
81 |
" 'data': '', # formatted data \n", |
|
|
82 |
" 'model_path': '', # where to save model\n", |
|
|
83 |
" 'model_name': '', # model name\n", |
|
|
84 |
" 'file_name': '', # log path\n", |
|
|
85 |
"}\n", |
|
|
86 |
"create_folder(file_config['model_path'])" |
|
|
87 |
] |
|
|
88 |
}, |
|
|
89 |
{ |
|
|
90 |
"cell_type": "code", |
|
|
91 |
"execution_count": null, |
|
|
92 |
"metadata": {}, |
|
|
93 |
"outputs": [], |
|
|
94 |
"source": [ |
|
|
95 |
"global_params = {\n", |
|
|
96 |
" 'max_seq_len': 64,\n", |
|
|
97 |
" 'max_age': 110,\n", |
|
|
98 |
" 'month': 1,\n", |
|
|
99 |
" 'age_symbol': None,\n", |
|
|
100 |
" 'min_visit': 5,\n", |
|
|
101 |
" 'gradient_accumulation_steps': 1\n", |
|
|
102 |
"}\n", |
|
|
103 |
"\n", |
|
|
104 |
"optim_param = {\n", |
|
|
105 |
" 'lr': 3e-5,\n", |
|
|
106 |
" 'warmup_proportion': 0.1,\n", |
|
|
107 |
" 'weight_decay': 0.01\n", |
|
|
108 |
"}\n", |
|
|
109 |
"\n", |
|
|
110 |
"train_params = {\n", |
|
|
111 |
" 'batch_size': 256,\n", |
|
|
112 |
" 'use_cuda': True,\n", |
|
|
113 |
" 'max_len_seq': global_params['max_seq_len'],\n", |
|
|
114 |
" 'device': 'cuda:0'\n", |
|
|
115 |
"}" |
|
|
116 |
] |
|
|
117 |
}, |
|
|
118 |
{ |
|
|
119 |
"cell_type": "code", |
|
|
120 |
"execution_count": null, |
|
|
121 |
"metadata": {}, |
|
|
122 |
"outputs": [], |
|
|
123 |
"source": [ |
|
|
124 |
"BertVocab = load_obj(file_config['vocab'])\n", |
|
|
125 |
"ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])" |
|
|
126 |
] |
|
|
127 |
}, |
|
|
128 |
{ |
|
|
129 |
"cell_type": "code", |
|
|
130 |
"execution_count": null, |
|
|
131 |
"metadata": {}, |
|
|
132 |
"outputs": [], |
|
|
133 |
"source": [ |
|
|
134 |
"data = pd.read_parquet(file_config['data'])\n", |
|
|
135 |
"# remove patients with fewer visits than min_visit\n", |
|
|
136 |
"data['length'] = data['caliber_id'].apply(lambda x: len([i for i in range(len(x)) if x[i] == 'SEP']))\n", |
|
|
137 |
"data = data[data['length'] >= global_params['min_visit']]\n", |
|
|
138 |
"data = data.reset_index(drop=True)" |
|
|
139 |
] |
|
|
140 |
}, |
|
|
141 |
{ |
|
|
142 |
"cell_type": "code", |
|
|
143 |
"execution_count": null, |
|
|
144 |
"metadata": {}, |
|
|
145 |
"outputs": [], |
|
|
146 |
"source": [ |
|
|
147 |
"Dset = MLMLoader(data, BertVocab['token2idx'], ageVocab, max_len=train_params['max_len_seq'], code='caliber_id')\n", |
|
|
148 |
"trainload = DataLoader(dataset=Dset, batch_size=train_params['batch_size'], shuffle=True, num_workers=3)" |
|
|
149 |
] |
|
|
150 |
}, |
|
|
151 |
{ |
|
|
152 |
"cell_type": "code", |
|
|
153 |
"execution_count": null, |
|
|
154 |
"metadata": {}, |
|
|
155 |
"outputs": [], |
|
|
156 |
"source": [ |
|
|
157 |
"model_config = {\n", |
|
|
158 |
" 'vocab_size': len(BertVocab['token2idx'].keys()), # number of disease + symbols for word embedding\n", |
|
|
159 |
" 'hidden_size': 288, # word embedding and seg embedding hidden size\n", |
|
|
160 |
" 'seg_vocab_size': 2, # number of vocab for seg embedding\n", |
|
|
161 |
" 'age_vocab_size': len(ageVocab.keys()), # number of vocab for age embedding\n", |
|
|
162 |
" 'max_position_embedding': train_params['max_len_seq'], # maximum number of tokens\n", |
|
|
163 |
" 'hidden_dropout_prob': 0.1, # dropout rate\n", |
|
|
164 |
" 'num_hidden_layers': 6, # number of multi-head attention layers required\n", |
|
|
165 |
" 'num_attention_heads': 12, # number of attention heads\n", |
|
|
166 |
" 'attention_probs_dropout_prob': 0.1, # multi-head attention dropout rate\n", |
|
|
167 |
" 'intermediate_size': 512, # the size of the \"intermediate\" layer in the transformer encoder\n", |
|
|
168 |
" 'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler \"gelu\", 'relu', 'swish' are supported\n", |
|
|
169 |
" 'initializer_range': 0.02, # parameter weight initializer range\n", |
|
|
170 |
"}" |
|
|
171 |
] |
|
|
172 |
}, |
|
|
173 |
{ |
|
|
174 |
"cell_type": "code", |
|
|
175 |
"execution_count": null, |
|
|
176 |
"metadata": {}, |
|
|
177 |
"outputs": [], |
|
|
178 |
"source": [ |
|
|
179 |
"conf = BertConfig(model_config)\n", |
|
|
180 |
"model = BertForMaskedLM(conf)" |
|
|
181 |
] |
|
|
182 |
}, |
|
|
183 |
{ |
|
|
184 |
"cell_type": "code", |
|
|
185 |
"execution_count": null, |
|
|
186 |
"metadata": {}, |
|
|
187 |
"outputs": [], |
|
|
188 |
"source": [ |
|
|
189 |
"model = model.to(train_params['device'])\n", |
|
|
190 |
"optim = adam(params=list(model.named_parameters()), config=optim_param)" |
|
|
191 |
] |
|
|
192 |
}, |
|
|
193 |
{ |
|
|
194 |
"cell_type": "code", |
|
|
195 |
"execution_count": null, |
|
|
196 |
"metadata": {}, |
|
|
197 |
"outputs": [], |
|
|
198 |
"source": [ |
|
|
199 |
"def cal_acc(label, pred):\n", |
|
|
200 |
" logs = nn.LogSoftmax(dim=-1)\n", |
|
|
201 |
" label=label.cpu().numpy()\n", |
|
|
202 |
" ind = np.where(label!=-1)[0]\n", |
|
|
203 |
" truepred = pred.detach().cpu().numpy()\n", |
|
|
204 |
" truepred = truepred[ind]\n", |
|
|
205 |
" truelabel = label[ind]\n", |
|
|
206 |
" truepred = logs(torch.tensor(truepred))\n", |
|
|
207 |
" outs = [np.argmax(pred_x) for pred_x in truepred.numpy()]\n", |
|
|
208 |
" precision = skm.precision_score(truelabel, outs, average='micro')\n", |
|
|
209 |
" return precision" |
|
|
210 |
] |
|
|
211 |
}, |
|
|
212 |
{ |
|
|
213 |
"cell_type": "code", |
|
|
214 |
"execution_count": null, |
|
|
215 |
"metadata": {}, |
|
|
216 |
"outputs": [], |
|
|
217 |
"source": [ |
|
|
218 |
"def train(e, loader):\n", |
|
|
219 |
" tr_loss = 0\n", |
|
|
220 |
" temp_loss = 0\n", |
|
|
221 |
" nb_tr_examples, nb_tr_steps = 0, 0\n", |
|
|
222 |
" cnt= 0\n", |
|
|
223 |
" start = time.time()\n", |
|
|
224 |
"\n", |
|
|
225 |
" for step, batch in enumerate(loader):\n", |
|
|
226 |
" cnt +=1\n", |
|
|
227 |
" batch = tuple(t.to(train_params['device']) for t in batch)\n", |
|
|
228 |
" age_ids, input_ids, posi_ids, segment_ids, attMask, masked_label = batch\n", |
|
|
229 |
" loss, pred, label = model(input_ids, age_ids, segment_ids, posi_ids,attention_mask=attMask, masked_lm_labels=masked_label)\n", |
|
|
230 |
" if global_params['gradient_accumulation_steps'] >1:\n", |
|
|
231 |
" loss = loss/global_params['gradient_accumulation_steps']\n", |
|
|
232 |
" loss.backward()\n", |
|
|
233 |
" \n", |
|
|
234 |
" temp_loss += loss.item()\n", |
|
|
235 |
" tr_loss += loss.item()\n", |
|
|
236 |
" \n", |
|
|
237 |
" nb_tr_examples += input_ids.size(0)\n", |
|
|
238 |
" nb_tr_steps += 1\n", |
|
|
239 |
" \n", |
|
|
240 |
" if step % 200==0:\n", |
|
|
241 |
" print(\"epoch: {}\\t| cnt: {}\\t|Loss: {}\\t| precision: {:.4f}\\t| time: {:.2f}\".format(e, cnt, temp_loss/200, cal_acc(label, pred), time.time()-start))\n", |
|
|
242 |
" temp_loss = 0\n", |
|
|
243 |
" start = time.time()\n", |
|
|
244 |
" \n", |
|
|
245 |
" if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n", |
|
|
246 |
" optim.step()\n", |
|
|
247 |
" optim.zero_grad()\n", |
|
|
248 |
"\n", |
|
|
249 |
" print(\"** ** * Saving fine - tuned model ** ** * \")\n", |
|
|
250 |
" model_to_save = model.module if hasattr(model, 'module') else model # Only save the model itself\n", |
|
|
251 |
" create_folder(file_config['model_path'])\n", |
|
|
252 |
" output_model_file = os.path.join(file_config['model_path'], file_config['model_name'])\n", |
|
|
253 |
"\n", |
|
|
254 |
" torch.save(model_to_save.state_dict(), output_model_file)\n", |
|
|
255 |
" \n", |
|
|
256 |
" cost = time.time() - start\n", |
|
|
257 |
" return tr_loss, cost" |
|
|
258 |
] |
|
|
259 |
}, |
|
|
260 |
{ |
|
|
261 |
"cell_type": "code", |
|
|
262 |
"execution_count": null, |
|
|
263 |
"metadata": {}, |
|
|
264 |
"outputs": [], |
|
|
265 |
"source": [ |
|
|
266 |
"f = open(os.path.join(file_config['model_path'], file_config['file_name']), \"w\")\n", |
|
|
267 |
"f.write('{}\\t{}\\t{}\\n'.format('epoch', 'loss', 'time'))\n", |
|
|
268 |
"for e in range(50):\n", |
|
|
269 |
" loss, time_cost = train(e, trainload)\n", |
|
|
270 |
" loss = loss/len(trainload)\n", |
|
|
271 |
" f.write('{}\\t{}\\t{}\\n'.format(e, loss, time_cost))\n", |
|
|
272 |
"f.close() " |
|
|
273 |
] |
|
|
274 |
} |
|
|
275 |
], |
|
|
276 |
"metadata": { |
|
|
277 |
"kernelspec": { |
|
|
278 |
"display_name": "Python 3", |
|
|
279 |
"language": "python", |
|
|
280 |
"name": "python3" |
|
|
281 |
}, |
|
|
282 |
"language_info": { |
|
|
283 |
"codemirror_mode": { |
|
|
284 |
"name": "ipython", |
|
|
285 |
"version": 3 |
|
|
286 |
}, |
|
|
287 |
"file_extension": ".py", |
|
|
288 |
"mimetype": "text/x-python", |
|
|
289 |
"name": "python", |
|
|
290 |
"nbconvert_exporter": "python", |
|
|
291 |
"pygments_lexer": "ipython3", |
|
|
292 |
"version": "3.7.4" |
|
|
293 |
} |
|
|
294 |
}, |
|
|
295 |
"nbformat": 4, |
|
|
296 |
"nbformat_minor": 2 |
|
|
297 |
} |