src/Parser/biomedner_init.py

import logging
import os
import time
import json
import torch
import argparse
import numpy as np

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, NewType, NamedTuple, Union, Tuple
from tqdm import tqdm
from torch import nn
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SequentialSampler

from transformers import (
    AutoConfig,
    AutoTokenizer,
    set_seed,
    PreTrainedTokenizer,
    BertTokenizerFast
)

from ops import (
    json_to_sent,
    input_form,
    get_prob,
    detokenize,
    preprocess,
    Profile,
)
from models import RoBERTaMultiNER2, BERTMultiNER2

logger = logging.getLogger(__name__)

InputDataClass = NewType("InputDataClass", Any)
DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])

@dataclass
class InputExample:
    """
    A single training/test example for token classification.

    Args:
        guid: Unique id for the example.
        words: list. The words of the sequence.
        labels: (Optional) list. The labels for each word of the sequence. This should be
        specified for train and dev examples, but not for test examples.
    """

    guid: str
    words: List[str]
    labels: Optional[List[str]]
    entity_labels: Optional[List[int]]

@dataclass
class InputFeatures:
    """
    A single set of features of data.
    Property names are the same names as the corresponding inputs to a model.
    """

    input_ids: List[int]
    attention_mask: List[int]
    token_type_ids: Optional[List[int]] = None
    label_ids: Optional[List[int]] = None
    entity_type_ids: Optional[List[int]] = None

class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_data(cls, data, pmids):
        """Reads BIO data."""
        lines = []
        words = []
        labels = []
        entity_labels = []
        for pmid in pmids:
            for sent in data[pmid]['words']:
                words = sent[:]
                labels = ['O'] * len(words)
                entity_labels = [str(0)] * len(words)

                if len(words) >= 30:
                    while len(words) >= 30:
                        tmplabel = labels[:30]
                        l = ' '.join([label for label
                                      in labels[:len(tmplabel)]
                                      if len(label) > 0])
                        w = ' '.join([word for word
                                      in words[:len(tmplabel)]
                                      if len(word) > 0])
                        e = ' '.join([el for el
                                      in entity_labels[:len(tmplabel)]
                                      if len(el) > 0])
                        lines.append([l, w, e])
                        words = words[len(tmplabel):]
                        labels = labels[len(tmplabel):]
                        entity_labels = entity_labels[len(tmplabel):]
                if len(words) == 0:
                    continue

                l = ' '.join([label for label in labels if len(label) > 0])
                w = ' '.join([word for word in words if len(word) > 0])
                # Filter each element (was `len(entity_labels) > 0`, which tested the whole list).
                e = ' '.join([el for el in entity_labels if len(el) > 0])
                lines.append([l, w, e])
                words = []
                labels = []
                entity_labels = []
                continue

        return lines

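# Illustrative sketch (not used by the pipeline): the shape of `DataProcessor._read_data`
# input and output, with a made-up pmid and a two-word sentence. Each output line is
# [labels, words, entity_labels], all space-joined strings.
def _read_data_example() -> List[List[str]]:
    toy_data = {'12345': {'words': [['Autophagy', 'maintains']]}}
    lines = DataProcessor._read_data(toy_data, ['12345'])
    # -> [['O O', 'Autophagy maintains', '0 0']]
    return lines
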
class NerDataset(Dataset):
    """
        This will be superseded by a framework-agnostic approach soon.
    """
    features: List[InputFeatures]
    pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index
    def __init__(
        self,
        predict_examples,
        labels: List[str],
        tokenizer: PreTrainedTokenizer,
        config,
        params,
        base_name
    ):
        logger.info("Creating features from dataset file")
        self.labels = labels
        self.predict_examples = predict_examples
        self.tokenizer = tokenizer
        self.config = config
        self.params = params

        self.features = convert_examples_to_features(
            self.predict_examples,
            self.labels,
            self.params.max_seq_length,
            self.tokenizer,
            cls_token_at_end=bool(self.config.model_type in ["xlnet"]),
            cls_token=self.tokenizer.cls_token,
            cls_token_segment_id=2 if self.config.model_type in ["xlnet"] else 0,
            sep_token=self.tokenizer.sep_token,
            sep_token_extra=False,
            pad_on_left=bool(self.tokenizer.padding_side == "left"),
            pad_token=self.tokenizer.pad_token_id,
            pad_token_segment_id=self.tokenizer.pad_token_type_id,
            pad_token_label_id=self.pad_token_label_id,
            base_name=base_name,
        )

    def __len__(self):
        return len(self.features)

    def __getitem__(self, i) -> InputFeatures:
        return self.features[i]

class PredictionOutput(NamedTuple):
    predictions: np.ndarray
    label_ids: Optional[np.ndarray]

def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
    """
    Very simple data collator that:
    - simply collates batches of dict-like objects
    - performs special handling for potential keys named:
        - `label`: handles a single value (int or float) per object
        - `label_ids`: handles a list of values per object
    - does not do any additional preprocessing

    i.e., property names of the input object will be used as corresponding inputs to the model.
    See glue and ner for an example of how it's useful.
    """

    # In this function we'll make the assumption that all `features` in the batch
    # have the same attributes.
    # So we will look at the first element as a proxy for what attributes exist
    # on the whole batch.
    if not isinstance(features[0], dict):
        features = [vars(f) for f in features]

    first = features[0]
    batch = {}

    # Special handling for labels.
    # Ensure that tensor is created with the correct type
    # (it should be automatically the case, but let's make sure of it.)
    if "label" in first and first["label"] is not None:
        dtype = torch.long if type(first["label"]) is int else torch.float
        batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
    elif "label_ids" in first and first["label_ids"] is not None:
        if isinstance(first["label_ids"], torch.Tensor):
            batch["labels"] = torch.stack([f["label_ids"] for f in features])
        else:
            dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
            batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)

    # Handling of all other possible keys.
    # Again, we will use the first element to figure out which key/values are not None for this model.
    for k, v in first.items():
        if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
            if isinstance(v, torch.Tensor):
                batch[k] = torch.stack([f[k] for f in features])
            else:
                batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)

    return batch

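# Illustrative sketch (not part of the pipeline): how `default_data_collator` turns a couple
# of `InputFeatures` into the tensor batch consumed by the model. The ids below are made-up
# toy values, not real tokenizer output.
def _collator_usage_example() -> Dict[str, torch.Tensor]:
    toy_features = [
        InputFeatures(input_ids=[101, 7592, 102, 0], attention_mask=[1, 1, 1, 0],
                      token_type_ids=[0, 0, 0, 0], label_ids=[-100, 2, -100, -100],
                      entity_type_ids=[0, 0, 0, 0]),
        InputFeatures(input_ids=[101, 2088, 102, 0], attention_mask=[1, 1, 1, 0],
                      token_type_ids=[0, 0, 0, 0], label_ids=[-100, 1, -100, -100],
                      entity_type_ids=[0, 0, 0, 0]),
    ]
    batch = default_data_collator(toy_features)
    # `label_ids` is renamed to "labels"; every other non-None field keeps its name, so the
    # batch holds (2, 4) long tensors for input_ids, attention_mask, token_type_ids,
    # entity_type_ids and labels.
    return batch
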
def convert_examples_to_features(
    examples: List[InputExample],
    label_list: List[str],
    max_seq_length: int,
    tokenizer: PreTrainedTokenizer,
    cls_token_at_end=False,
    cls_token="[CLS]",
    cls_token_segment_id=1,
    sep_token="[SEP]",
    sep_token_extra=False,
    pad_on_left=False,
    pad_token=0,
    pad_token_segment_id=0,
    pad_token_label_id=-100,
    sequence_a_segment_id=0,
    mask_padding_with_zero=True,
    base_name="",
) -> List[InputFeatures]:
    """ Loads a data file into a list of `InputFeatures`
        `cls_token_at_end` defines the location of the CLS token:
            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
            - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
        `cls_token_segment_id` defines the segment id associated with the CLS token (0 for BERT, 2 for XLNet)
    """
    # TODO clean up all this to leverage built-in features of tokenizers

    label_map = {label: i for i, label in enumerate(label_list)}
    features = []

    for (ex_index, example) in tqdm(enumerate(examples)):
        if ex_index % 10_000 == 0:
            logger.info("Writing example %d of %d", ex_index, len(examples))

        tokens, label_ids = [], []
        det_tokens = []

        for word_idx, (word, label) in enumerate(zip(example.words.split(), example.labels.split())):
            word_tokens = tokenizer.tokenize(word)

            # bert-base-multilingual-cased sometimes outputs "nothing" ([]) when calling tokenize with just a space.
            if len(word_tokens) > 0:
                tokens.extend(word_tokens)
                # Use the real label id for the first token of the word, and padding ids for the remaining tokens
                label_ids.extend([label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1))

                if len(word_tokens) == 1:
                    det_tokens.extend(word_tokens)
                elif len(word_tokens) > 1:
                    for det_idx, det_word in enumerate(word_tokens):
                        if det_idx > 0:
                            det_word = '##' + det_word
                            det_tokens.append(det_word)
                        else:
                            det_tokens.append(det_word)

        # calculate temperature with length : temp = 1 - 0.02 * length
        # temperature = [1 - sharpening * i if i > 1 else i for _, i in enumerate(entity_length)]

        # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
        special_tokens_count = tokenizer.num_special_tokens_to_add()
        ## truncating tokens with max_seq_length
        # if len(tokens) > max_seq_length - special_tokens_count:
        #     tokens = tokens[: (max_seq_length - special_tokens_count)]
        #     label_ids = label_ids[: (max_seq_length - special_tokens_count)]
        #     det_tokens = det_tokens[: (max_seq_length - special_tokens_count)]

        # for sliding window tokens - update 23.11.13
        for i in range(0, (len(tokens) // max_seq_length) + 1):
            if i == 0:
                window_tokens = tokens[i*max_seq_length:(i+1)*max_seq_length-special_tokens_count]
                window_label_ids = label_ids[i*max_seq_length:(i+1)*max_seq_length-special_tokens_count]
                window_det_tokens = det_tokens[i*max_seq_length:(i+1)*max_seq_length-special_tokens_count]
            elif i >= 1:
                window_tokens = tokens[i*max_seq_length-special_tokens_count:(i+1)*max_seq_length-special_tokens_count]
                window_label_ids = label_ids[i*max_seq_length-special_tokens_count:(i+1)*max_seq_length-special_tokens_count]
                window_det_tokens = det_tokens[i*max_seq_length-special_tokens_count:(i+1)*max_seq_length-special_tokens_count]

            # The convention in BERT is:
            # (a) For sequence pairs:
            #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
            #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
            # (b) For single sequences:
            #  tokens:   [CLS] the dog is hairy . [SEP]
            #  type_ids:   0   0   0   0  0     0   0
            #
            # Where "type_ids" are used to indicate whether this is the first
            # sequence or the second sequence. The embedding vectors for `type=0` and
            # `type=1` were learned during pre-training and are added to the wordpiece
            # embedding vector (and position vector). This is not *strictly* necessary
            # since the [SEP] token unambiguously separates the sequences, but it makes
            # it easier for the model to learn the concept of sequences.
            #
            # For classification tasks, the first vector (corresponding to [CLS]) is
            # used as the "sentence vector". Note that this only makes sense because
            # the entire model is fine-tuned.
            window_tokens += [sep_token]
            window_label_ids += [pad_token_label_id]
            window_det_tokens += [sep_token]

            if sep_token_extra:
                # roberta uses an extra separator b/w pairs of sentences
                window_tokens += [sep_token]
                window_label_ids += [pad_token_label_id]
                window_det_tokens += [sep_token]

            # make entity type label index for multiner
            entity_type_ids = [int(example.entity_labels[0])] * len(window_tokens)
            segment_ids = [sequence_a_segment_id] * len(window_tokens)
            if cls_token_at_end:
                window_tokens += [cls_token]
                window_label_ids += [pad_token_label_id]
                segment_ids += [cls_token_segment_id]
                entity_type_ids += [int(example.entity_labels[0])]
                window_det_tokens += [cls_token]
            else:
                window_tokens = [cls_token] + window_tokens
                window_label_ids = [pad_token_label_id] + window_label_ids
                segment_ids = [cls_token_segment_id] + segment_ids
                entity_type_ids = [int(example.entity_labels[0])] + entity_type_ids
                window_det_tokens = [cls_token] + window_det_tokens

            input_ids = tokenizer.convert_tokens_to_ids(window_tokens)
            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding_length = max_seq_length - len(input_ids)

            if pad_on_left:
                input_ids = ([pad_token] * padding_length) + input_ids
                input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
                segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
                window_label_ids = ([pad_token_label_id] * padding_length) + window_label_ids
                entity_type_ids = ([int(example.entity_labels[0])] * padding_length) + entity_type_ids
                window_tokens = (["**NULL**"] * padding_length) + window_tokens
                window_det_tokens = (["**NULL**"] * padding_length) + window_det_tokens
            else:
                input_ids += [pad_token] * padding_length
                input_mask += [0 if mask_padding_with_zero else 1] * padding_length
                segment_ids += [pad_token_segment_id] * padding_length
                window_label_ids += [pad_token_label_id] * padding_length
                entity_type_ids += [int(example.entity_labels[0])] * padding_length
                window_tokens += ["**NULL**"] * padding_length
                window_det_tokens += ["**NULL**"] * padding_length

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert len(window_label_ids) == max_seq_length
            assert len(entity_type_ids) == max_seq_length
            assert len(window_tokens) == max_seq_length

            if ex_index < 1:
                logger.info("*** Example ***")
                logger.info("guid: %s", example.guid)
                logger.info("tokens: %s", " ".join([str(x) for x in window_tokens]))
                logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
                logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
                logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
                logger.info("label_ids: %s", " ".join([str(x) for x in window_label_ids]))
                logger.info("entity_type_ids: %s", " ".join([str(x) for x in entity_type_ids]))

            if "token_type_ids" not in tokenizer.model_input_names:
                segment_ids = None

            features.append(
                InputFeatures(
                    input_ids=input_ids, attention_mask=input_mask, token_type_ids=segment_ids,
                    label_ids=window_label_ids, entity_type_ids=entity_type_ids,
                )
            )
            write_tokens(window_tokens, window_det_tokens, 'test', base_name)

    return features

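# Illustrative sketch of the sliding-window slicing used in `convert_examples_to_features`
# above: the first window reserves room at the end of the slice for the special tokens that
# are added afterwards, and every later window shifts its start back by
# `special_tokens_count`. Toy numbers only.
def _sliding_window_example(n_tokens: int = 300, max_seq_length: int = 128,
                            special_tokens_count: int = 2) -> List[range]:
    windows = []
    for i in range(0, (n_tokens // max_seq_length) + 1):
        if i == 0:
            windows.append(range(0, max_seq_length - special_tokens_count))
        else:
            windows.append(range(i * max_seq_length - special_tokens_count,
                                 (i + 1) * max_seq_length - special_tokens_count))
    # With the defaults this yields the index ranges [0, 126), [126, 254), [254, 382).
    return windows
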
def write_tokens(tokens, det_tokens, mode, base_name):
    if mode == "test":
        tmp_path = os.path.join('multi_ner', 'tmp')
        if not os.path.exists(tmp_path):
            os.makedirs(tmp_path)

        path = os.path.join("multi_ner", "tmp",
                            "token_{}_{}.txt".format(mode, base_name))
        with open(path, 'a') as wf:
            for token in tokens:
                if token != "**NULL**":
                    wf.write(token + '\n')

        det_path = os.path.join("multi_ner", "tmp",
                                "det_token_{}_{}.txt".format(mode, base_name))
        with open(det_path, 'a') as wf:
            for token in det_tokens:
                if token != "**NULL**":
                    wf.write(token + '\n')

class NerProcessor(DataProcessor):
    def get_test_examples(self, data_dir):
        data = list()
        pmids = list()
        with open(data_dir, 'r') as in_:
            for line in in_:
                line = line.strip()
                tmp = json.loads(line)
                tmp['title'] = preprocess(tmp['title'])
                tmp['abstract'] = preprocess(tmp['abstract'])
                data.append(tmp)
                pmids.append(tmp["pmid"])

        json_file = input_form(json_to_sent(data))

        return \
            self._create_example(self._read_data(json_file, pmids), "test"), \
            json_file, data

    def get_test_dict_list(self, dict_list):
        pmids = list()
        for d in dict_list:
            pmids.append(d["pmid"])

        json_file = input_form(json_to_sent(dict_list))

        return \
            self._create_example(self._read_data(json_file, pmids), "test"), \
            json_file

    def get_labels(self):
        return ["B", "I", "O"]

    def _create_example(self, lines, set_type):
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text = line[1]
            label = line[0]
            entity_labels = line[2]
            examples.append(InputExample(guid=guid, words=text, labels=label, entity_labels=entity_labels))

        return examples


class BioMedNER:
    def __init__(self, params):
        # See all possible arguments in src/transformers/training_args.py
        # or by passing the --help flag to this script.
        # We now keep distinct sets of args, for a cleaner separation of concerns.

        init_start_t = time.time()

        # Set ner processor
        self.processor = NerProcessor()

        # Setup parsing
        self.params = params
        self.prediction_loss_only = False

        # Set seed
        set_seed(self.params.seed)

        # Prepare Labels
        self.labels = self.processor.get_labels()
        self.id2label: Dict[int, str] = {i: label for i, label in enumerate(self.labels)}
        self.label2id = {label: i for i, label in enumerate(self.labels)}
        self.num_labels = len(self.labels)

        self.config = AutoConfig.from_pretrained(
            self.params.model_name_or_path,
            num_labels=self.num_labels,
            id2label=self.id2label,
            label2id=self.label2id,
        )
        self.tokenizer = BertTokenizerFast.from_pretrained(
            self.params.model_name_or_path,
        )
        self.model = BERTMultiNER2.from_pretrained(
            self.params.model_name_or_path,
            num_labels=self.num_labels,
            config=self.config,
        )
        if not self.params.no_cuda:
            self.model = self.model.cuda()
        self.entity_types = ['disease', 'drug', 'gene', 'species', 'cell_line', 'DNA', 'RNA', 'cell_type']
                            #  'biological_structure', 'diagnostic_procedure', 'duration', 'date', 'therapeutic_procedure',
                            #  'sign_symptom', 'lab_value']
        self.estimator_dict = {}
        for etype in self.entity_types:
            self.estimator_dict[etype] = {}
            self.estimator_dict[etype]['prediction'] = []
            self.estimator_dict[etype]['log_probs'] = []

        self.counter = 0
        self.pad_token_label_id: int = nn.CrossEntropyLoss().ignore_index
        init_end_t = time.time()
        print('BioMedNER init_t {:.3f} sec.'.format(init_end_t - init_start_t))

    @Profile(__name__)
    def recognize(self, input_dl, base_name, indent=None):
        if type(input_dl) is str:
            predict_examples, self.json_dict, self.data_list = \
                self.processor.get_test_examples(input_dl)
        elif type(input_dl) is list:
            predict_examples, self.json_dict = \
                self.processor.get_test_dict_list(input_dl)
            self.data_list = input_dl
        else:
            raise ValueError('Wrong type')

        token_path = os.path.join("multi_ner", "tmp",
                                  "token_test_{}.txt".format(base_name))
        det_token_path = os.path.join("multi_ner", "tmp",
                                  "det_token_test_{}.txt".format(base_name))

        if os.path.exists(token_path):
            os.remove(token_path)
        if os.path.exists(det_token_path):
            os.remove(det_token_path)

        predict_example_list = (NerDataset(predict_examples, self.labels,
                                self.tokenizer, self.config, self.params, base_name))

        tokens, tot_tokens = list(), list()

        """
        Aggregate label results with detokenized tokens

        words: <s> Auto phagy main tain s tumour growth ... </s>
        label:  O   O     O     O    O  O    B      O   ...   O

        detok_words: <s> Autophagy maintains tumour growth ... </s>
        detok_label:  O       O         O        B      O   ...  O
        """

        with open(det_token_path, 'r') as reader:
            for line_idx, line in enumerate(reader):
                tok = line.strip()
                tot_tokens.append(tok)

                if tok == '[CLS]' or tok == '<s>':
                    tmp_toks = [tok]
                elif tok == '[SEP]' or tok == '</s>':
                    tmp_toks.append(tok)
                    tokens.append(tmp_toks)
                else:
                    tmp_toks.append(tok)

        self.predict_dict, self.prob_dict = dict(), dict()
        threads, self.out_tag_dict = list(), dict()

        all_type = self._predict(predict_example_list)
        # disease, drug, gene, spec, cell_line, dna, rna, cell_type
        for etype_idx, etype in enumerate(self.entity_types):

            predictions, label_ids = all_type[etype_idx] # batch, seq, labels
            preds_array = self.align_predictions(predictions) # batch, seq

            self.out_tag_dict[etype] = (False, None)
            self.recognize_etype(etype, tokens, tot_tokens, predictions, preds_array)

        for etype in self.entity_types:
            if self.out_tag_dict[etype][0]:
                if type(input_dl) is str:
                    print(os.path.split(input_dl)[1],
                          'Found an error:', self.out_tag_dict[etype][1])
                else:
                    print('Found an error:', self.out_tag_dict[etype][1])
                if os.path.exists(token_path):
                    os.remove(token_path)
                return None

        # get probability of all mentions
        data_list = get_prob(self.data_list, self.json_dict, self.predict_dict,
                             self.prob_dict, entity_types=self.entity_types)

        if type(input_dl) is str:
            output_path = os.path.join('result/', os.path.splitext(
                os.path.basename(input_dl))[0] + '_NER_{}.json'.format(base_name))
            print('pred', output_path)

            with open(output_path, 'w') as resultf:
                for paper in data_list:
                    paper['ner_model'] = "MULTI-TASK NER v.20210707"
                    resultf.write(
                        json.dumps(paper, sort_keys=True, indent=indent) + '\n'
                    )
        # delete temp files
        if os.path.exists(token_path):
            os.remove(token_path)
        if os.path.exists(det_token_path):
            os.remove(det_token_path)

        return data_list

    @Profile(__name__)
    def recognize_etype(self, etype, tokens, tot_tokens, predictions, preds_array):
        result = []

        for one_batch in range(predictions.shape[0]):
            result.append({'prediction': preds_array[one_batch],
                           'log_probs': predictions[one_batch]})

        predicts = list()
        logits = list()

        for pidx, prediction in enumerate(result):
            slen = len(tokens[pidx])
            for p in prediction['prediction'][:slen]:
                predicts.append(self.id2label[p])
            for l in prediction['log_probs'][:slen]:
                logits.append(l)

        de_toks, de_labels, de_logits = detokenize(tot_tokens, predicts, logits)

        self.predict_dict[etype] = dict()
        self.prob_dict[etype] = dict()
        piv = 0
        for data in self.data_list:
            pmid = data['pmid']
            self.predict_dict[etype][pmid] = list()
            self.prob_dict[etype][pmid] = list()

            sent_lens = list()
            for sent in self.json_dict[pmid]['words']:
                sent_lens.append(len(sent))
            sent_idx = 0
            de_i = 0
            overlen = False
            while True:
                if overlen:

                    try:
                        self.predict_dict[etype][pmid][-1].extend(
                            de_labels[piv + de_i])
                    except Exception as e:
                        self.out_tag_dict[etype] = (True, e)
                        break
                    self.prob_dict[etype][pmid][-1].extend(de_logits[piv + de_i])
                    de_i += 1
                    if len(self.predict_dict[etype][pmid][-1]) == len(
                            self.json_dict[pmid]['words'][
                                len(self.predict_dict[etype][pmid]) - 1]):
                        sent_idx += 1
                        overlen = False

                else:
                    self.predict_dict[etype][pmid].append(de_labels[piv + de_i])
                    self.prob_dict[etype][pmid].append(de_logits[piv + de_i])
                    de_i += 1
                    if len(self.predict_dict[etype][pmid][-1]) == len(
                            self.json_dict[pmid]['words'][
                                len(self.predict_dict[etype][pmid]) - 1]):
                        sent_idx += 1
                        overlen = False
                    else:
                        overlen = True

                if sent_idx == len(self.json_dict[pmid]['words']):
                    piv += de_i
                    break

            if self.out_tag_dict[etype][0]:
                break

    def _predict(self, test_dataset: Dataset):
        sampler = SequentialSampler(test_dataset)
        data_loader = DataLoader(
            test_dataset,
            sampler=sampler,
            batch_size=32, # you can adjust evaluation batch size, we prefer using 32
            collate_fn=default_data_collator,
            drop_last=False,
        )
        return self._prediction_loop(data_loader, description="Prediction")

    def _prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> Tuple[PredictionOutput, ...]:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels. Returns one PredictionOutput per entity type.
        """

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        model = self.model

        eval_losses: List[float] = []
        dise_preds: torch.Tensor = None
        chem_preds: torch.Tensor = None
        gene_preds: torch.Tensor = None
        spec_preds: torch.Tensor = None
        cl_preds: torch.Tensor = None
        dna_preds: torch.Tensor = None
        rna_preds: torch.Tensor = None
        ct_preds: torch.Tensor = None
        # biological_preds: torch.Tensor = None
        # diagnostic_preds: torch.Tensor = None
        # duration_preds: torch.Tensor = None
        # date_preds: torch.Tensor = None
        # therapeutic_preds: torch.Tensor = None
        # sign_symptom_preds: torch.Tensor = None
        # lab_value_preds: torch.Tensor = None
        label_ids: torch.Tensor = None
        model.eval()

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    inputs[k] = v.to(self.model.device)

            with torch.no_grad():
                outputs = model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]

            if not prediction_loss_only:
                (dise_logits, chem_logits, gene_logits, spec_logits, cl_logits, dna_logits, rna_logits, ct_logits) = logits
                #  biological_logits, diagnostic_logits, duration_logits, date_logits, therapeutic_logits,
                #  sign_symptom_logits, lab_value_logits) = logits

                if dise_preds is None \
                and chem_preds is None \
                and gene_preds is None \
                and spec_preds is None \
                and cl_preds is None \
                and dna_preds is None \
                and rna_preds is None \
                and ct_preds is None:
                # and biological_preds is None \
                # and diagnostic_preds is None \
                # and duration_preds is None \
                # and date_preds is None \
                # and therapeutic_preds is None \
                # and sign_symptom_preds is None \
                # and lab_value_preds is None:

                    dise_preds = dise_logits.detach()
                    chem_preds = chem_logits.detach()
                    gene_preds = gene_logits.detach()
                    spec_preds = spec_logits.detach()
                    cl_preds = cl_logits.detach()
                    dna_preds = dna_logits.detach()
                    rna_preds = rna_logits.detach()
                    ct_preds = ct_logits.detach()
                    # biological_preds = biological_logits.detach()
                    # diagnostic_preds = diagnostic_logits.detach()
                    # duration_preds = duration_logits.detach()
                    # date_preds = date_logits.detach()
                    # therapeutic_preds = therapeutic_logits.detach()
                    # sign_symptom_preds = sign_symptom_logits.detach()
                    # lab_value_preds = lab_value_logits.detach()
                else:
                    dise_preds = torch.cat((dise_preds, dise_logits.detach()), dim=0)
                    chem_preds = torch.cat((chem_preds, chem_logits.detach()), dim=0)
                    gene_preds = torch.cat((gene_preds, gene_logits.detach()), dim=0)
                    spec_preds = torch.cat((spec_preds, spec_logits.detach()), dim=0)
                    cl_preds = torch.cat((cl_preds, cl_logits.detach()), dim=0)
                    dna_preds = torch.cat((dna_preds, dna_logits.detach()), dim=0)
                    rna_preds = torch.cat((rna_preds, rna_logits.detach()), dim=0)
                    ct_preds = torch.cat((ct_preds, ct_logits.detach()), dim=0)
                    # biological_preds = torch.cat((biological_preds, biological_logits.detach()), dim=0)
                    # diagnostic_preds = torch.cat((diagnostic_preds, diagnostic_logits.detach()), dim=0)
                    # duration_preds = torch.cat((duration_preds, duration_logits.detach()), dim=0)
                    # date_preds = torch.cat((date_preds, date_logits.detach()), dim=0)
                    # therapeutic_preds = torch.cat((therapeutic_preds, therapeutic_logits.detach()), dim=0)
                    # sign_symptom_preds = torch.cat((sign_symptom_preds, sign_symptom_logits.detach()), dim=0)
                    # lab_value_preds = torch.cat((lab_value_preds, lab_value_logits.detach()), dim=0)
                if inputs.get("labels") is not None:
                    if label_ids is None:
                        label_ids = inputs["labels"].detach()
                    else:
                        label_ids = torch.cat((label_ids, inputs["labels"].detach()), dim=0)

        # Finally, turn the aggregated tensors into numpy arrays.
        if dise_preds is not None \
        and chem_preds is not None \
        and gene_preds is not None \
        and spec_preds is not None \
        and cl_preds is not None \
        and dna_preds is not None \
        and rna_preds is not None \
        and ct_preds is not None:
        # and biological_preds is not None \
        # and diagnostic_preds is not None \
        # and duration_preds is not None \
        # and date_preds is not None \
        # and therapeutic_preds is not None \
        # and sign_symptom_preds is not None \
        # and lab_value_preds is not None:

            dise_preds = dise_preds.cpu().numpy()
            chem_preds = chem_preds.cpu().numpy()
            gene_preds = gene_preds.cpu().numpy()
            spec_preds = spec_preds.cpu().numpy()
            cl_preds = cl_preds.cpu().numpy()
            dna_preds = dna_preds.cpu().numpy()
            rna_preds = rna_preds.cpu().numpy()
            ct_preds = ct_preds.cpu().numpy()
            # biological_preds = biological_preds.cpu().numpy()
            # diagnostic_preds = diagnostic_preds.cpu().numpy()
            # duration_preds = duration_preds.cpu().numpy()
            # date_preds = date_preds.cpu().numpy()
            # therapeutic_preds = therapeutic_preds.cpu().numpy()
            # sign_symptom_preds = sign_symptom_preds.cpu().numpy()
            # lab_value_preds = lab_value_preds.cpu().numpy()

        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        return_output = (PredictionOutput(predictions=dise_preds, label_ids=label_ids),
                         PredictionOutput(predictions=chem_preds, label_ids=label_ids),
                         PredictionOutput(predictions=gene_preds, label_ids=label_ids),
                         PredictionOutput(predictions=spec_preds, label_ids=label_ids),
                         PredictionOutput(predictions=cl_preds, label_ids=label_ids),
                         PredictionOutput(predictions=dna_preds, label_ids=label_ids),
                         PredictionOutput(predictions=rna_preds, label_ids=label_ids),
                         PredictionOutput(predictions=ct_preds, label_ids=label_ids))
                         # PredictionOutput(predictions=biological_preds, label_ids=label_ids),
                         # PredictionOutput(predictions=diagnostic_preds, label_ids=label_ids),
                         # PredictionOutput(predictions=duration_preds, label_ids=label_ids),
                         # PredictionOutput(predictions=date_preds, label_ids=label_ids),
                         # PredictionOutput(predictions=therapeutic_preds, label_ids=label_ids),
                         # PredictionOutput(predictions=sign_symptom_preds, label_ids=label_ids),
                         # PredictionOutput(predictions=lab_value_preds, label_ids=label_ids))

        return return_output

    def align_predictions(self, predictions: np.ndarray) -> np.ndarray:
        preds = np.argmax(predictions, axis=2)
        batch_size, seq_len = preds.shape

        preds_list = [[] for _ in range(batch_size)]

        for i in range(batch_size):
            for j in range(seq_len):
                preds_list[i].append(preds[i][j])

        return np.array(preds_list)

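# Illustrative sketch (toy numbers, not pipeline data): `align_predictions` above reduces a
# (batch, seq, num_labels) logit array to per-token label indices via argmax over the last axis.
def _align_predictions_example() -> np.ndarray:
    toy_logits = np.array([[[0.1, 0.7, 0.2],    # -> label index 1
                            [0.8, 0.1, 0.1]]])  # -> label index 0
    # Equivalent to BioMedNER.align_predictions for this array: argmax over axis 2.
    return np.argmax(toy_logits, axis=2)  # array([[1, 0]])
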
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = "6"

    argparser = argparse.ArgumentParser()
    argparser.add_argument('--model_name_or_path', default='dmis-lab/bern2-ner')
    argparser.add_argument('--max_seq_length', type=int, help='The maximum total input sequence length after tokenization. Sequences longer than this will be truncated, sequences shorter will be padded.',
                            default=128)
    argparser.add_argument('--seed', type=int, help='random seed for initialization',
                            default=1)
    # BioMedNER.__init__ reads params.no_cuda, so expose it here as well.
    argparser.add_argument('--no_cuda', action='store_true', help='do not move the model to GPU')
    args = argparser.parse_args()

    biomedner = BioMedNER(args)

if __name__ == "__main__":
    main()
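
# Illustrative usage (assumed invocation; adjust paths and flags to your setup):
#
#   python biomedner_init.py --model_name_or_path dmis-lab/bern2-ner --max_seq_length 128 --seed 1
#
# Programmatic use follows the same shape as main(): build an args namespace, construct
# BioMedNER, then call recognize() with a list of dicts carrying 'pmid', 'title' and
# 'abstract' (the fields get_test_examples/get_test_dict_list work with), e.g.
#
#   ner = BioMedNER(args)
#   results = ner.recognize([{'pmid': '12345',
#                             'title': 'Autophagy maintains tumour growth',
#                             'abstract': '...'}], base_name='demo')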