task/NextVIsit-12month.ipynb
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, '../')\n",
    "\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sklearn.metrics as skm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pytorch_pretrained_bert as Bert\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.dataset import Dataset\n",
    "\n",
    "from common.common import create_folder, load_obj\n",
    "from dataLoader.utils import seq_padding, code2index, position_idx, index_seg\n",
    "from model.utils import age_vocab\n",
    "from model import optimiser"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# File Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "file_config = {\n",
    "    'vocab': '',  # pickled dict with 'token2idx' and 'idx2token'\n",
    "    'train': '',  # training set (parquet)\n",
    "    'test': '',   # test set (parquet)\n",
    "}\n",
    "\n",
    "optim_config = {\n",
    "    'lr': 3e-5,\n",
    "    'warmup_proportion': 0.1,\n",
    "    'weight_decay': 0.01\n",
    "}\n",
    "\n",
    "global_params = {\n",
    "    'batch_size': 128,\n",
    "    'gradient_accumulation_steps': 1,\n",
    "    'device': 'cuda:1',\n",
    "    'output_dir': '',  # output folder\n",
    "    'best_name': '',   # output model name\n",
    "    'save_model': True,\n",
    "    'max_len_seq': 100,\n",
    "    'max_age': 110,\n",
    "    'month': 1,\n",
    "    'age_symbol': None,\n",
    "    'min_visit': 5\n",
    "}\n",
    "\n",
    "pretrainModel = ''  # path to the pretrained MLM checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_folder(global_params['output_dir'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BertVocab = load_obj(file_config['vocab'])\n",
    "ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_label_vocab(token2idx):\n",
    "    token2idx = token2idx.copy()\n",
    "    del token2idx['PAD']\n",
    "    del token2idx['SEP']\n",
    "    del token2idx['CLS']\n",
    "    del token2idx['MASK']\n",
    "    token = list(token2idx.keys())\n",
    "    labelVocab = {}\n",
    "    for i, x in enumerate(token):\n",
    "        labelVocab[x] = i\n",
    "    return labelVocab\n",
    "\n",
    "Vocab_diag = format_label_vocab(BertVocab['token2idx'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config = {\n",
    "    'vocab_size': len(BertVocab['token2idx']),  # number of disease codes + special tokens for the word embedding\n",
    "    'hidden_size': 288,  # hidden size shared by the word, segment, age and position embeddings\n",
    "    'seg_vocab_size': 2,  # vocabulary size of the segment embedding\n",
    "    'age_vocab_size': len(ageVocab),  # vocabulary size of the age embedding\n",
    "    'max_position_embedding': global_params['max_len_seq'],  # maximum number of tokens\n",
    "    'hidden_dropout_prob': 0.2,  # dropout rate\n",
    "    'num_hidden_layers': 6,  # number of multi-head attention layers\n",
    "    'num_attention_heads': 12,  # number of attention heads\n",
    "    'attention_probs_dropout_prob': 0.22,  # multi-head attention dropout rate\n",
    "    'intermediate_size': 512,  # size of the \"intermediate\" (feed-forward) layer in the transformer encoder\n",
    "    'hidden_act': 'gelu',  # non-linear activation in the encoder and pooler; 'gelu', 'relu' and 'swish' are supported\n",
    "    'initializer_range': 0.02,  # parameter weight initializer range\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NextVisit(Dataset):\n",
    "    def __init__(self, token2idx, diag2idx, age2idx, dataframe, max_len, max_age=110, min_visit=5):\n",
    "        # dataframe preprocessing; patients with fewer than min_visit visits are\n",
    "        # assumed to have been filtered out upstream\n",
    "        self.vocab = token2idx\n",
    "        self.label_vocab = diag2idx\n",
    "        self.max_len = max_len\n",
    "        self.code = dataframe.code\n",
    "        self.age = dataframe.age\n",
    "        self.label = dataframe.label\n",
    "        self.patid = dataframe.patid\n",
    "\n",
    "        self.age2idx = age2idx\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"\n",
    "        return: age, code, position, segmentation, mask, label\n",
    "        \"\"\"\n",
    "        # select the patient record\n",
    "        age = self.age[index]\n",
    "        code = self.code[index]\n",
    "        label = self.label[index]\n",
    "        patid = self.patid[index]\n",
    "\n",
    "        # keep only the most recent max_len-1 tokens (one slot is reserved for 'CLS')\n",
    "        age = age[(-self.max_len+1):]\n",
    "        code = code[(-self.max_len+1):]\n",
    "\n",
    "        # avoid a truncated sequence starting with 'SEP'\n",
    "        if code[0] != 'SEP':\n",
    "            code = np.append(np.array(['CLS']), code)\n",
    "            age = np.append(np.array(age[0]), age)\n",
    "        else:\n",
    "            code[0] = 'CLS'\n",
    "\n",
    "        # attention mask: 1 for real tokens, 0 for padding\n",
    "        mask = np.ones(self.max_len)\n",
    "        mask[len(code):] = 0\n",
    "\n",
    "        # pad the age sequence and map codes/labels to indices\n",
    "        age = seq_padding(age, self.max_len, token2idx=self.age2idx)\n",
    "\n",
    "        tokens, code = code2index(code, self.vocab)\n",
    "        _, label = code2index(label, self.label_vocab)\n",
    "\n",
    "        # get position code and segment code\n",
    "        tokens = seq_padding(tokens, self.max_len)\n",
    "        position = position_idx(tokens)\n",
    "        segment = index_seg(tokens)\n",
    "\n",
    "        # pad code and label sequences\n",
    "        code = seq_padding(code, self.max_len, symbol=self.vocab['PAD'])\n",
    "        label = seq_padding(label, self.max_len, symbol=-1)\n",
    "\n",
    "        return torch.LongTensor(age), torch.LongTensor(code), torch.LongTensor(position), torch.LongTensor(segment), \\\n",
    "               torch.LongTensor(mask), torch.LongTensor(label), torch.LongTensor([int(patid)])\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertConfig(Bert.modeling.BertConfig):\n",
    "    def __init__(self, config):\n",
    "        super(BertConfig, self).__init__(\n",
    "            vocab_size_or_config_json_file=config.get('vocab_size'),\n",
    "            hidden_size=config.get('hidden_size'),\n",
    "            num_hidden_layers=config.get('num_hidden_layers'),\n",
    "            num_attention_heads=config.get('num_attention_heads'),\n",
    "            intermediate_size=config.get('intermediate_size'),\n",
    "            hidden_act=config.get('hidden_act'),\n",
    "            hidden_dropout_prob=config.get('hidden_dropout_prob'),\n",
    "            attention_probs_dropout_prob=config.get('attention_probs_dropout_prob'),\n",
    "            max_position_embeddings=config.get('max_position_embedding'),\n",
    "            initializer_range=config.get('initializer_range'),\n",
    "        )\n",
    "        self.seg_vocab_size = config.get('seg_vocab_size')\n",
    "        self.age_vocab_size = config.get('age_vocab_size')\n",
    "\n",
    "class BertEmbeddings(nn.Module):\n",
    "    \"\"\"Construct the embeddings from word, segment, age and position.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super(BertEmbeddings, self).__init__()\n",
    "        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)\n",
    "        self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size)\n",
    "        self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size)\n",
    "        # fixed (frozen) sinusoidal position embeddings\n",
    "        self.posi_embeddings = nn.Embedding.from_pretrained(\n",
    "            embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size))\n",
    "\n",
    "        self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12)\n",
    "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
    "\n",
    "    def forward(self, word_ids, age_ids=None, seg_ids=None, posi_ids=None, age=True):\n",
    "        if seg_ids is None:\n",
    "            seg_ids = torch.zeros_like(word_ids)\n",
    "        if age_ids is None:\n",
    "            age_ids = torch.zeros_like(word_ids)\n",
    "        if posi_ids is None:\n",
    "            posi_ids = torch.zeros_like(word_ids)\n",
    "\n",
    "        word_embed = self.word_embeddings(word_ids)\n",
    "        segment_embed = self.segment_embeddings(seg_ids)\n",
    "        age_embed = self.age_embeddings(age_ids)\n",
    "        posi_embeddings = self.posi_embeddings(posi_ids)\n",
    "\n",
    "        if age:\n",
    "            embeddings = word_embed + segment_embed + age_embed + posi_embeddings\n",
    "        else:\n",
    "            embeddings = word_embed + segment_embed + posi_embeddings\n",
    "        embeddings = self.LayerNorm(embeddings)\n",
    "        embeddings = self.dropout(embeddings)\n",
    "        return embeddings\n",
    "\n",
    "    def _init_posi_embedding(self, max_position_embedding, hidden_size):\n",
    "        def even_code(pos, idx):\n",
    "            return np.sin(pos/(10000**(2*idx/hidden_size)))\n",
    "\n",
    "        def odd_code(pos, idx):\n",
    "            return np.cos(pos/(10000**(2*idx/hidden_size)))\n",
    "\n",
    "        # initialize position embedding table\n",
    "        lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)\n",
    "\n",
    "        # fill the table with the hard-coded sinusoids\n",
    "        # set even dimensions\n",
    "        for pos in range(max_position_embedding):\n",
    "            for idx in np.arange(0, hidden_size, step=2):\n",
    "                lookup_table[pos, idx] = even_code(pos, idx)\n",
    "        # set odd dimensions\n",
    "        for pos in range(max_position_embedding):\n",
    "            for idx in np.arange(1, hidden_size, step=2):\n",
    "                lookup_table[pos, idx] = odd_code(pos, idx)\n",
    "\n",
    "        return torch.tensor(lookup_table)\n",
    "\n",
    "class BertModel(Bert.modeling.BertPreTrainedModel):\n",
    "    def __init__(self, config):\n",
    "        super(BertModel, self).__init__(config)\n",
    "        self.embeddings = BertEmbeddings(config=config)\n",
    "        self.encoder = Bert.modeling.BertEncoder(config=config)\n",
    "        self.pooler = Bert.modeling.BertPooler(config)\n",
    "        self.apply(self.init_bert_weights)\n",
    "\n",
    "    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, output_all_encoded_layers=True):\n",
    "        if attention_mask is None:\n",
    "            attention_mask = torch.ones_like(input_ids)\n",
    "        if age_ids is None:\n",
    "            age_ids = torch.zeros_like(input_ids)\n",
    "        if seg_ids is None:\n",
    "            seg_ids = torch.zeros_like(input_ids)\n",
    "        if posi_ids is None:\n",
    "            posi_ids = torch.zeros_like(input_ids)\n",
    "\n",
    "        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)\n",
    "        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility\n",
    "        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n",
    "\n",
    "        embedding_output = self.embeddings(input_ids, age_ids, seg_ids, posi_ids)\n",
    "        encoded_layers = self.encoder(embedding_output,\n",
    "                                      extended_attention_mask,\n",
    "                                      output_all_encoded_layers=output_all_encoded_layers)\n",
    "        sequence_output = encoded_layers[-1]\n",
    "        pooled_output = self.pooler(sequence_output)\n",
    "        if not output_all_encoded_layers:\n",
    "            encoded_layers = encoded_layers[-1]\n",
    "        return encoded_layers, pooled_output\n",
    "\n",
    "class BertForMultiLabelPrediction(Bert.modeling.BertPreTrainedModel):\n",
    "    def __init__(self, config, num_labels):\n",
    "        super(BertForMultiLabelPrediction, self).__init__(config)\n",
    "        self.num_labels = num_labels\n",
    "        self.bert = BertModel(config)\n",
    "        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
    "        self.classifier = nn.Linear(config.hidden_size, num_labels)\n",
    "        self.apply(self.init_bert_weights)\n",
    "\n",
    "    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, labels=None):\n",
    "        _, pooled_output = self.bert(input_ids, age_ids, seg_ids, posi_ids, attention_mask,\n",
    "                                     output_all_encoded_layers=False)\n",
    "        pooled_output = self.dropout(pooled_output)\n",
    "        logits = self.classifier(pooled_output)\n",
    "\n",
    "        if labels is not None:\n",
    "            loss_fct = nn.MultiLabelSoftMarginLoss()\n",
    "            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))\n",
    "            return loss, logits\n",
    "        else:\n",
    "            return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_parquet(file_config['train']).reset_index(drop=True)\n",
    "data['label'] = data.label.apply(lambda x: list(set(x)))\n",
    "Dset = NextVisit(token2idx=BertVocab['token2idx'], diag2idx=Vocab_diag, age2idx=ageVocab, dataframe=data, max_len=global_params['max_len_seq'])\n",
    "trainload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=True, num_workers=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_parquet(file_config['test']).reset_index(drop=True)\n",
    "data['label'] = data.label.apply(lambda x: list(set(x)))\n",
    "Dset = NextVisit(token2idx=BertVocab['token2idx'], diag2idx=Vocab_diag, age2idx=ageVocab, dataframe=data, max_len=global_params['max_len_seq'])\n",
    "testload = DataLoader(dataset=Dset, batch_size=global_params['batch_size'], shuffle=False, num_workers=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initialise Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conf = BertConfig(model_config)\n",
    "model = BertForMultiLabelPrediction(conf, num_labels=len(Vocab_diag))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# load the pretrained MLM checkpoint and transfer its weights\n",
    "pretrained_dict = torch.load(pretrainModel)\n",
    "model_dict = model.state_dict()\n",
    "# 1. filter out unnecessary keys\n",
    "pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}\n",
    "# 2. overwrite entries in the existing state dict\n",
    "model_dict.update(pretrained_dict)\n",
    "# 3. load the new state dict\n",
    "model.load_state_dict(model_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(global_params['device'])\n",
    "optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sklearn.metrics\n",
    "\n",
    "def precision(logits, label):\n",
    "    sig = nn.Sigmoid()\n",
    "    output = sig(logits)\n",
    "    label, output = label.cpu(), output.detach().cpu()\n",
    "    tempprc = sklearn.metrics.average_precision_score(label.numpy(), output.numpy(), average='samples')\n",
    "    return tempprc, output, label\n",
    "\n",
    "def precision_test(logits, label):\n",
    "    sig = nn.Sigmoid()\n",
    "    output = sig(logits)\n",
    "    tempprc = sklearn.metrics.average_precision_score(label.numpy(), output.numpy(), average='samples')\n",
    "    return tempprc, output, label\n",
    "\n",
    "def auroc_test(logits, label):\n",
    "    sig = nn.Sigmoid()\n",
    "    output = sig(logits)\n",
    "    tempprc = sklearn.metrics.roc_auc_score(label.numpy(), output.numpy(), average='samples')\n",
    "    return tempprc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi-hot Label Encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import MultiLabelBinarizer\n",
    "mlb = MultiLabelBinarizer(classes=list(Vocab_diag.values()))\n",
    "mlb.fit([[each] for each in list(Vocab_diag.values())])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train and Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(e):\n",
    "    model.train()\n",
    "    tr_loss = 0\n",
    "    temp_loss = 0\n",
    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
    "    cnt = 0\n",
    "    for step, batch in enumerate(trainload):\n",
    "        cnt += 1\n",
    "        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
    "        targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32)\n",
    "\n",
    "        age_ids = age_ids.to(global_params['device'])\n",
    "        input_ids = input_ids.to(global_params['device'])\n",
    "        posi_ids = posi_ids.to(global_params['device'])\n",
    "        segment_ids = segment_ids.to(global_params['device'])\n",
    "        attMask = attMask.to(global_params['device'])\n",
    "        targets = targets.to(global_params['device'])\n",
    "\n",
    "        loss, logits = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)\n",
    "\n",
    "        if global_params['gradient_accumulation_steps'] > 1:\n",
    "            loss = loss / global_params['gradient_accumulation_steps']\n",
    "        loss.backward()\n",
    "\n",
    "        temp_loss += loss.item()\n",
    "        tr_loss += loss.item()\n",
    "        nb_tr_examples += input_ids.size(0)\n",
    "        nb_tr_steps += 1\n",
    "\n",
    "        if step % 2000 == 0:\n",
    "            prec, a, b = precision(logits, targets)\n",
    "            print(\"epoch: {}\\t| Cnt: {}\\t| Loss: {}\\t| precision: {}\".format(e, cnt, temp_loss/2000, prec))\n",
    "            temp_loss = 0\n",
    "\n",
    "        if (step + 1) % global_params['gradient_accumulation_steps'] == 0:\n",
    "            optim.step()\n",
    "            optim.zero_grad()\n",
    "\n",
    "\n",
    "def evaluation():\n",
    "    model.eval()\n",
    "    y = []\n",
    "    y_label = []\n",
    "    for step, batch in enumerate(testload):\n",
    "        age_ids, input_ids, posi_ids, segment_ids, attMask, targets, _ = batch\n",
    "        targets = torch.tensor(mlb.transform(targets.numpy()), dtype=torch.float32)\n",
    "\n",
    "        age_ids = age_ids.to(global_params['device'])\n",
    "        input_ids = input_ids.to(global_params['device'])\n",
    "        posi_ids = posi_ids.to(global_params['device'])\n",
    "        segment_ids = segment_ids.to(global_params['device'])\n",
    "        attMask = attMask.to(global_params['device'])\n",
    "        targets = targets.to(global_params['device'])\n",
    "\n",
    "        with torch.no_grad():\n",
    "            loss, logits = model(input_ids, age_ids, segment_ids, posi_ids, attention_mask=attMask, labels=targets)\n",
    "        logits = logits.cpu()\n",
    "        targets = targets.cpu()\n",
    "\n",
    "        y_label.append(targets)\n",
    "        y.append(logits)\n",
    "\n",
    "    y_label = torch.cat(y_label, dim=0)\n",
    "    y = torch.cat(y, dim=0)\n",
    "\n",
    "    tempprc, output, label = precision_test(y, y_label)\n",
    "    auroc = auroc_test(y, y_label)\n",
    "    return tempprc, auroc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(action='ignore')\n",
    "optim_config = {\n",
    "    'lr': 9e-6,\n",
    "    'warmup_proportion': 0.1\n",
    "}\n",
    "optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)\n",
    "\n",
    "best_pre = 0\n",
    "for e in range(50):\n",
    "    train(e)\n",
    "    prec, auroc = evaluation()\n",
    "    if prec > best_pre:\n",
    "        # save the best model so far\n",
    "        print(\"** Saving the fine-tuned model **\")\n",
    "        model_to_save = model.module if hasattr(model, 'module') else model  # only save the model itself\n",
    "        output_model_file = os.path.join(global_params['output_dir'], global_params['best_name'])\n",
    "        create_folder(global_params['output_dir'])\n",
    "        if global_params['save_model']:\n",
    "            torch.save(model_to_save.state_dict(), output_model_file)\n",
    "        best_pre = prec\n",
    "    print('precision: {}, auroc: {}'.format(prec, auroc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}