Diff of /src/Parser/models.py [000000] .. [f87529]

Switch to unified view

a b/src/Parser/models.py
1
# coding=utf-8
2
import os
3
import pdb
4
import copy
5
import torch
6
import torch.nn.functional as F
7
from torch import nn
8
9
from torch.nn import CrossEntropyLoss
10
from transformers import (
11
        BertConfig,
12
        BertModel,
13
        RobertaModel,
14
        BertForTokenClassification,
15
        BertTokenizer,
16
        RobertaConfig,
17
        RobertaForTokenClassification,
18
        RobertaTokenizer, 
19
        AutoTokenizer,
20
)
21
22
class BERTMultiNER2(BertForTokenClassification):
23
    def __init__(self, config, num_labels=3):
24
        super(BERTMultiNER2, self).__init__(config)
25
        self.num_labels = num_labels
26
        self.bert = BertModel(config)
27
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
28
        
29
        self.dise_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # disease
30
        self.chem_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # chemical
31
        self.gene_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # gene/protein
32
        self.spec_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # species
33
        self.cellline_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # cell line
34
        self.dna_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # dna
35
        self.rna_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # rna
36
        self.celltype_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # cell type
37
        
38
        # self.biological_structure_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # biological structure
39
        # self.diagnostic_procedure_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # diagnostic procedure
40
        # self.duration_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # duration
41
        # self.date_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # date
42
        # self.therapeutic_procedure_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # therapeutic procedure
43
        # self.sign_symptom_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # sign/symptom
44
        # self.lab_value_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # lab value
45
        
46
47
        self.dise_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
48
        self.chem_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
49
        self.gene_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
50
        self.spec_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
51
        self.cellline_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
52
        self.dna_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
53
        self.rna_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
54
        self.celltype_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
55
        
56
        # self.biological_structure_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
57
        # self.diagnostic_procedure_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)    
58
        # self.duration_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)    
59
        # self.date_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
60
        # self.therapeutic_procedure_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
61
        # self.sign_symptom_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
62
        # self.lab_value_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
63
64
        self.init_weights()
65
66
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, entity_type_ids=None):
67
        sequence_output = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, head_mask=None)[0]
68
        batch_size,max_len,feat_dim = sequence_output.shape
69
        sequence_output = self.dropout(sequence_output)
70
71
        if entity_type_ids[0][0].item() == 0:
72
            '''
73
            Raw text data with trained parameters
74
            '''
75
            dise_sequence_output = F.relu(self.dise_classifier_2(sequence_output)) # disease logit value
76
            chem_sequence_output = F.relu(self.chem_classifier_2(sequence_output)) # chemical logit value
77
            gene_sequence_output = F.relu(self.gene_classifier_2(sequence_output)) # gene/protein logit value
78
            spec_sequence_output = F.relu(self.spec_classifier_2(sequence_output)) # species logit value
79
            cellline_sequence_output = F.relu(self.cellline_classifier_2(sequence_output)) # cell line logit value
80
            dna_sequence_output = F.relu(self.dna_classifier_2(sequence_output)) # dna logit value
81
            rna_sequence_output = F.relu(self.rna_classifier_2(sequence_output)) # rna logit value
82
            celltype_sequence_output = F.relu(self.celltype_classifier_2(sequence_output)) # cell type logit value
83
            
84
            # biological_structure_sequence_output = F.relu(self.biological_structure_classifier_2(sequence_output)) # biological structure logit value
85
            # diagnostic_procedure_sequence_output = F.relu(self.diagnostic_procedure_classifier_2(sequence_output)) # diagnostic procedure logit value
86
            # duration_sequence_output = F.relu(self.duration_classifier_2(sequence_output)) # duration logit value
87
            # date_sequence_output = F.relu(self.date_classifier_2(sequence_output)) # date logit value
88
            # therapeutic_procedure_sequence_output = F.relu(self.therapeutic_procedure_classifier_2(sequence_output)) # therapeutic procedure logit value
89
            # sign_symptom_sequence_output = F.relu(self.sign_symptom_classifier_2(sequence_output)) # sign/symptom logit value
90
            # lab_value_sequence_output = F.relu(self.lab_value_classifier_2(sequence_output)) # lab value logit value
91
92
93
            dise_logits = self.dise_classifier(dise_sequence_output) # disease logit value
94
            chem_logits = self.chem_classifier(chem_sequence_output) # chemical logit value
95
            gene_logits = self.gene_classifier(gene_sequence_output) # gene/protein logit value
96
            spec_logits = self.spec_classifier(spec_sequence_output) # species logit value
97
            cellline_logits = self.cellline_classifier(cellline_sequence_output) # cell line logit value
98
            dna_logits = self.dna_classifier(dna_sequence_output) # dna logit value
99
            rna_logits = self.rna_classifier(rna_sequence_output) # rna logit value
100
            celltype_logits = self.celltype_classifier(celltype_sequence_output) # cell type logit value
101
            
102
            # biological_logits = self.biological_structure_classifier(biological_structure_sequence_output) # biological structure logit value
103
            # diagnostic_logits = self.diagnostic_procedure_classifier(diagnostic_procedure_sequence_output) # diagnostic procedure logit value
104
            # duration_logits = self.duration_classifier(duration_sequence_output) # duration logit value
105
            # date_logits = self.date_classifier(date_sequence_output) # date logit value
106
            # therapeutic_logits = self.therapeutic_procedure_classifier(therapeutic_procedure_sequence_output) # therapeutic procedure logit value
107
            # sign_symptom_logits = self.sign_symptom_classifier(sign_symptom_sequence_output) # sign/symptom logit value
108
            # lab_value_logits = self.lab_value_classifier(lab_value_sequence_output) # lab value logit value
109
110
            
111
112
            # update logit and sequence_output
113
            sequence_output = dise_sequence_output + chem_sequence_output + gene_sequence_output + spec_sequence_output + cellline_sequence_output + dna_sequence_output + rna_sequence_output + celltype_sequence_output 
114
            # + \
115
            #     biological_structure_sequence_output + diagnostic_procedure_sequence_output + duration_sequence_output + date_sequence_output + \
116
            #     therapeutic_procedure_sequence_output + sign_symptom_sequence_output + lab_value_sequence_output 
117
                
118
            logits = (dise_logits, chem_logits, gene_logits, spec_logits, cellline_logits, 
119
                      dna_logits, rna_logits, celltype_logits)
120
                    #   biological_logits, diagnostic_logits,
121
                    #   duration_logits, date_logits, therapeutic_logits,
122
                    #   sign_symptom_logits, lab_value_logits)
123
        else:
124
            ''' 
125
            Train, Eval, Test with pre-defined entity type tags
126
            '''
127
            # make 1*1 conv to adopt entity type
128
            dise_idx = copy.deepcopy(entity_type_ids)
129
            chem_idx = copy.deepcopy(entity_type_ids)
130
            gene_idx = copy.deepcopy(entity_type_ids)
131
            spec_idx = copy.deepcopy(entity_type_ids)
132
            cellline_idx = copy.deepcopy(entity_type_ids)
133
            dna_idx = copy.deepcopy(entity_type_ids)
134
            rna_idx = copy.deepcopy(entity_type_ids)
135
            celltype_idx = copy.deepcopy(entity_type_ids)
136
            
137
            # biological_idx = copy.deepcopy(entity_type_ids)
138
            # diagnostic_idx = copy.deepcopy(entity_type_ids)
139
            # duration_idx = copy.deepcopy(entity_type_ids)
140
            # date_idx = copy.deepcopy(entity_type_ids)
141
            # therapeutic_idx = copy.deepcopy(entity_type_ids)
142
            # sign_symptom_idx = copy.deepcopy(entity_type_ids)
143
            # lab_value_idx = copy.deepcopy(entity_type_ids)
144
145
            
146
147
            dise_idx[dise_idx != 1] = 0
148
            chem_idx[chem_idx != 2] = 0
149
            gene_idx[gene_idx != 3] = 0
150
            spec_idx[spec_idx != 4] = 0
151
            cellline_idx[cellline_idx != 5] = 0
152
            dna_idx[dna_idx != 6] = 0
153
            rna_idx[rna_idx != 7] = 0
154
            celltype_idx[celltype_idx != 8] = 0
155
            # biological_idx[biological_idx != 9] = 0
156
            # diagnostic_idx[diagnostic_idx != 10] = 0
157
            # duration_idx[duration_idx != 11] = 0
158
            # date_idx[date_idx != 12] = 0
159
            # therapeutic_idx[therapeutic_idx != 13] = 0
160
            # sign_symptom_idx[sign_symptom_idx != 14] = 0
161
            # lab_value_idx[lab_value_idx != 15] = 0
162
163
            dise_sequence_output = dise_idx.unsqueeze(-1) * sequence_output        
164
            chem_sequence_output = chem_idx.unsqueeze(-1) * sequence_output
165
            gene_sequence_output = gene_idx.unsqueeze(-1) * sequence_output
166
            spec_sequence_output = spec_idx.unsqueeze(-1) * sequence_output
167
            cellline_sequence_output = cellline_idx.unsqueeze(-1) * sequence_output
168
            dna_sequence_output = dna_idx.unsqueeze(-1) * sequence_output
169
            rna_sequence_output = rna_idx.unsqueeze(-1) * sequence_output
170
            celltype_sequence_output = celltype_idx.unsqueeze(-1) * sequence_output
171
            # biological_structure_sequence_output = biological_idx.unsqueeze(-1) * sequence_output
172
            # diagnostic_procedure_sequence_output = diagnostic_idx.unsqueeze(-1) * sequence_output
173
            # duration_sequence_output = duration_idx.unsqueeze(-1) * sequence_output
174
            # date_sequence_output = date_idx.unsqueeze(-1) * sequence_output
175
            # therapeutic_procedure_sequence_output = therapeutic_idx.unsqueeze(-1) * sequence_output
176
            # sign_symptom_sequence_output = sign_symptom_idx.unsqueeze(-1) * sequence_output
177
            # lab_value_sequence_output = lab_value_idx.unsqueeze(-1) * sequence_output
178
179
            
180
            # F.tanh or F.relu
181
            dise_sequence_output = F.relu(self.dise_classifier_2(dise_sequence_output)) # disease logit value
182
            chem_sequence_output = F.relu(self.chem_classifier_2(chem_sequence_output)) # chemical logit value
183
            gene_sequence_output = F.relu(self.gene_classifier_2(gene_sequence_output)) # gene/protein logit value
184
            spec_sequence_output = F.relu(self.spec_classifier_2(spec_sequence_output)) # species logit value
185
            cellline_sequence_output = F.relu(self.cellline_classifier_2(cellline_sequence_output)) # cell line logit value
186
            dna_sequence_output = F.relu(self.dna_classifier_2(dna_sequence_output)) # dna logit value
187
            rna_sequence_output = F.relu(self.rna_classifier_2(rna_sequence_output)) # rna logit value
188
            celltype_sequence_output = F.relu(self.celltype_classifier_2(celltype_sequence_output)) # cell type logit value
189
190
            # biological_structure_sequence_output = F.relu(self.biological_structure_classifier_2(biological_structure_sequence_output)) # biological structure logit value
191
            # diagnostic_procedure_sequence_output = F.relu(self.diagnostic_procedure_classifier_2(diagnostic_procedure_sequence_output)) # diagnostic procedure logit value
192
            # duration_sequence_output = F.relu(self.duration_classifier_2(duration_sequence_output)) # duration logit value
193
            # date_sequence_output = F.relu(self.date_classifier_2(date_sequence_output)) # date logit value
194
            # therapeutic_procedure_sequence_output = F.relu(self.therapeutic_procedure_classifier_2(therapeutic_procedure_sequence_output)) # therapeutic procedure logit value
195
            # sign_symptom_sequence_output = F.relu(self.sign_symptom_classifier_2(sign_symptom_sequence_output)) # sign/symptom logit value
196
            # lab_value_sequence_output = F.relu(self.lab_value_classifier_2(lab_value_sequence_output)) # lab value logit value
197
198
            
199
200
            dise_logits = self.dise_classifier(dise_sequence_output) # disease logit value
201
            chem_logits = self.chem_classifier(chem_sequence_output) # chemical logit value
202
            gene_logits = self.gene_classifier(gene_sequence_output) # gene/protein logit value
203
            spec_logits = self.spec_classifier(spec_sequence_output) # species logit value
204
            cellline_logits = self.cellline_classifier(cellline_sequence_output) # cell line logit value
205
            dna_logits = self.dna_classifier(dna_sequence_output) # dna logit value
206
            rna_logits = self.rna_classifier(rna_sequence_output) # rna logit value
207
            celltype_logits = self.celltype_classifier(celltype_sequence_output) # cell type logit value
208
            # biological_logits = self.biological_structure_classifier(biological_structure_sequence_output) # biological structure logit value
209
            # diagnostic_logits = self.diagnostic_procedure_classifier(diagnostic_procedure_sequence_output) # diagnostic procedure logit value
210
            # duration_logits = self.duration_classifier(duration_sequence_output) # duration logit value
211
            # date_logits = self.date_classifier(date_sequence_output)
212
            # therapeutic_logits = self.therapeutic_procedure_classifier(therapeutic_procedure_sequence_output) # therapeutic procedure logit value
213
            # sign_symptom_logits = self.sign_symptom_classifier(sign_symptom_sequence_output) # sign/symptom logit value
214
            # lab_value_logits = self.lab_value_classifier(lab_value_sequence_output) # lab value logit value
215
216
            
217
218
            # update logit and sequence_output
219
            sequence_output =dise_sequence_output + chem_sequence_output + gene_sequence_output + spec_sequence_output + cellline_sequence_output + dna_sequence_output + rna_sequence_output + celltype_sequence_output 
220
            # \
221
            #     + biological_structure_sequence_output + diagnostic_procedure_sequence_output + duration_sequence_output + date_sequence_output \
222
            #     + therapeutic_procedure_sequence_output + sign_symptom_sequence_output + lab_value_sequence_output
223
                
224
            logits = dise_logits + chem_logits + gene_logits + spec_logits + cellline_logits \
225
                + dna_logits + rna_logits + celltype_logits 
226
                # + biological_logits \
227
                # + diagnostic_logits + duration_logits + date_logits + therapeutic_logits \
228
                # + sign_symptom_logits + lab_value_logits 
229
                
230
231
        outputs = (logits, sequence_output)
232
        if labels is not None:
233
            loss_fct = CrossEntropyLoss()
234
            # Only keep active parts of the loss
235
            if attention_mask is not None:
236
                if entity_type_ids[0][0].item() == 0:
237
                    active_loss = attention_mask.view(-1) == 1
238
                    dise_logits, chem_logits, gene_logits, spec_logits, cellline_logits, \
239
                    dna_logits, rna_logits, celltype_logits = logits
240
                    
241
                    #  biological_logits, diagnostic_logits, \
242
                    # duration_logits, date_logits, therapeutic_logits, \
243
                    # sign_symptom_logits, lab_value_logits
244
245
246
                    active_dise_logits = dise_logits.view(-1, self.num_labels)
247
                    active_chem_logits = chem_logits.view(-1, self.num_labels)
248
                    active_gene_logits = gene_logits.view(-1, self.num_labels)
249
                    active_spec_logits = spec_logits.view(-1, self.num_labels)
250
                    active_cellline_logits = cellline_logits.view(-1, self.num_labels)
251
                    active_dna_logits = dna_logits.view(-1, self.num_labels)
252
                    active_rna_logits = rna_logits.view(-1, self.num_labels)
253
                    active_celltype_logits = celltype_logits.view(-1, self.num_labels)
254
                    # active_biological_logits = biological_logits.view(-1, self.num_labels)
255
                    # active_diagnostic_logits = diagnostic_logits.view(-1, self.num_labels)
256
                    # active_duration_logits = duration_logits.view(-1, self.num_labels)
257
                    # active_date_logits = date_logits.view(-1, self.num_labels)
258
                    # active_therapeutic_logits = therapeutic_logits.view(-1, self.num_labels)
259
                    # active_sign_symptom_logits = sign_symptom_logits.view(-1, self.num_labels)
260
                    # active_lab_value_logits = lab_value_logits.view(-1, self.num_labels)
261
                    
262
                    
263
                    active_labels = torch.where(
264
                        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
265
                    )
266
                    dise_loss = loss_fct(active_dise_logits, active_labels)
267
                    chem_loss = loss_fct(active_chem_logits, active_labels)
268
                    gene_loss = loss_fct(active_gene_logits, active_labels)
269
                    spec_loss = loss_fct(active_spec_logits, active_labels)
270
                    cellline_loss = loss_fct(active_cellline_logits, active_labels)
271
                    dna_loss = loss_fct(active_dna_logits, active_labels)
272
                    rna_loss = loss_fct(active_rna_logits, active_labels)
273
                    celltype_loss = loss_fct(active_celltype_logits, active_labels)
274
                    # biological_loss = loss_fct(active_biological_logits, active_labels)
275
                    # diagnostic_loss = loss_fct(active_diagnostic_logits, active_labels)
276
                    # duration_loss = loss_fct(active_duration_logits, active_labels)
277
                    # date_loss = loss_fct(active_date_logits, active_labels)
278
                    # therapeutic_loss = loss_fct(active_therapeutic_logits, active_labels)
279
                    # sign_symptom_loss = loss_fct(active_sign_symptom_logits, active_labels)
280
                    # lab_value_loss = loss_fct(active_lab_value_logits, active_labels)
281
                    
282
                    loss = dise_loss + chem_loss + gene_loss + spec_loss + cellline_loss + dna_loss + rna_loss + celltype_loss 
283
                    # \
284
                    #      + biological_loss + diagnostic_loss \
285
                    #     + duration_loss + date_loss + therapeutic_loss + sign_symptom_loss + lab_value_loss 
286
                        
287
                    return ((loss,) + outputs)
288
                else:
289
                    active_loss = attention_mask.view(-1) == 1
290
                    active_logits = logits.view(-1, self.num_labels)
291
                    active_labels = torch.where(
292
                        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
293
                    )
294
                    loss = loss_fct(active_logits, active_labels)
295
                    return ((loss,) + outputs)
296
            else:
297
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
298
                return loss
299
        else:
300
            return logits
301
302
class RoBERTaMultiNER2(RobertaForTokenClassification):
303
    def __init__(self, config, num_labels=3):
304
        super(RoBERTaMultiNER2, self).__init__(config)
305
        self.num_labels = num_labels
306
        self.roberta = RobertaModel(config)
307
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
308
        
309
        self.dise_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # disease
310
        self.chem_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # chemical
311
        self.gene_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # gene/protein
312
        self.spec_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # species
313
        self.cellline_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # cell line
314
        self.dna_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # dna
315
        self.rna_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # rna
316
        self.celltype_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # cell type
317
        
318
        # self.biological_structure_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # biological structure
319
        # self.diagnostic_procedure_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # diagnostic procedure
320
        # self.duration_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # duration
321
        # self.date_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # date
322
        # self.therapeutic_procedure_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # therapeutic procedure
323
        # self.sign_symptom_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # sign/symptom
324
        # self.lab_value_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # lab value
325
        
326
327
        self.dise_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
328
        self.chem_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
329
        self.gene_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
330
        self.spec_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
331
        self.cellline_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
332
        self.dna_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
333
        self.rna_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
334
        self.celltype_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
335
        
336
        # self.biological_structure_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
337
        # self.diagnostic_procedure_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)    
338
        # self.duration_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)    
339
        # self.date_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
340
        # self.therapeutic_procedure_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
341
        # self.sign_symptom_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
342
        # self.lab_value_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
343
344
        self.init_weights()
345
346
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, entity_type_ids=None):
347
        sequence_output = self.roberta(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, head_mask=None)[0]
348
        batch_size,max_len,feat_dim = sequence_output.shape
349
        sequence_output = self.dropout(sequence_output)
350
351
        if entity_type_ids[0][0].item() == 0:
352
            '''
353
            Raw text data with trained parameters
354
            '''
355
            dise_sequence_output = F.relu(self.dise_classifier_2(sequence_output)) # disease logit value
356
            chem_sequence_output = F.relu(self.chem_classifier_2(sequence_output)) # chemical logit value
357
            gene_sequence_output = F.relu(self.gene_classifier_2(sequence_output)) # gene/protein logit value
358
            spec_sequence_output = F.relu(self.spec_classifier_2(sequence_output)) # species logit value
359
            cellline_sequence_output = F.relu(self.cellline_classifier_2(sequence_output)) # cell line logit value
360
            dna_sequence_output = F.relu(self.dna_classifier_2(sequence_output)) # dna logit value
361
            rna_sequence_output = F.relu(self.rna_classifier_2(sequence_output)) # rna logit value
362
            celltype_sequence_output = F.relu(self.celltype_classifier_2(sequence_output)) # cell type logit value
363
            
364
            # biological_structure_sequence_output = F.relu(self.biological_structure_classifier_2(sequence_output)) # biological structure logit value
365
            # diagnostic_procedure_sequence_output = F.relu(self.diagnostic_procedure_classifier_2(sequence_output)) # diagnostic procedure logit value
366
            # duration_sequence_output = F.relu(self.duration_classifier_2(sequence_output)) # duration logit value
367
            # date_sequence_output = F.relu(self.date_classifier_2(sequence_output)) # date logit value
368
            # therapeutic_procedure_sequence_output = F.relu(self.therapeutic_procedure_classifier_2(sequence_output)) # therapeutic procedure logit value
369
            # sign_symptom_sequence_output = F.relu(self.sign_symptom_classifier_2(sequence_output)) # sign/symptom logit value
370
            # lab_value_sequence_output = F.relu(self.lab_value_classifier_2(sequence_output)) # lab value logit value
371
372
373
            dise_logits = self.dise_classifier(dise_sequence_output) # disease logit value
374
            chem_logits = self.chem_classifier(chem_sequence_output) # chemical logit value
375
            gene_logits = self.gene_classifier(gene_sequence_output) # gene/protein logit value
376
            spec_logits = self.spec_classifier(spec_sequence_output) # species logit value
377
            cellline_logits = self.cellline_classifier(cellline_sequence_output) # cell line logit value
378
            dna_logits = self.dna_classifier(dna_sequence_output) # dna logit value
379
            rna_logits = self.rna_classifier(rna_sequence_output) # rna logit value
380
            celltype_logits = self.celltype_classifier(celltype_sequence_output) # cell type logit value
381
            
382
            # biological_logits = self.biological_structure_classifier(biological_structure_sequence_output) # biological structure logit value
383
            # diagnostic_logits = self.diagnostic_procedure_classifier(diagnostic_procedure_sequence_output) # diagnostic procedure logit value
384
            # duration_logits = self.duration_classifier(duration_sequence_output) # duration logit value
385
            # date_logits = self.date_classifier(date_sequence_output) # date logit value
386
            # therapeutic_logits = self.therapeutic_procedure_classifier(therapeutic_procedure_sequence_output) # therapeutic procedure logit value
387
            # sign_symptom_logits = self.sign_symptom_classifier(sign_symptom_sequence_output) # sign/symptom logit value
388
            # lab_value_logits = self.lab_value_classifier(lab_value_sequence_output) # lab value logit value
389
390
            
391
392
            # update logit and sequence_output
393
            sequence_output = dise_sequence_output + chem_sequence_output + gene_sequence_output + spec_sequence_output + cellline_sequence_output + dna_sequence_output + rna_sequence_output + celltype_sequence_output 
394
            # + \
395
            #     biological_structure_sequence_output + diagnostic_procedure_sequence_output + duration_sequence_output + date_sequence_output + \
396
            #     therapeutic_procedure_sequence_output + sign_symptom_sequence_output + lab_value_sequence_output 
397
                
398
            logits = (dise_logits, chem_logits, gene_logits, spec_logits, cellline_logits, 
399
                      dna_logits, rna_logits, celltype_logits)
400
                    #   biological_logits, diagnostic_logits,
401
                    #   duration_logits, date_logits, therapeutic_logits,
402
                    #   sign_symptom_logits, lab_value_logits)
403
        else:
404
            ''' 
405
            Train, Eval, Test with pre-defined entity type tags
406
            '''
407
            # make 1*1 conv to adopt entity type
408
            dise_idx = copy.deepcopy(entity_type_ids)
409
            chem_idx = copy.deepcopy(entity_type_ids)
410
            gene_idx = copy.deepcopy(entity_type_ids)
411
            spec_idx = copy.deepcopy(entity_type_ids)
412
            cellline_idx = copy.deepcopy(entity_type_ids)
413
            dna_idx = copy.deepcopy(entity_type_ids)
414
            rna_idx = copy.deepcopy(entity_type_ids)
415
            celltype_idx = copy.deepcopy(entity_type_ids)
416
            
417
            # biological_idx = copy.deepcopy(entity_type_ids)
418
            # diagnostic_idx = copy.deepcopy(entity_type_ids)
419
            # duration_idx = copy.deepcopy(entity_type_ids)
420
            # date_idx = copy.deepcopy(entity_type_ids)
421
            # therapeutic_idx = copy.deepcopy(entity_type_ids)
422
            # sign_symptom_idx = copy.deepcopy(entity_type_ids)
423
            # lab_value_idx = copy.deepcopy(entity_type_ids)
424
425
            
426
427
            dise_idx[dise_idx != 1] = 0
428
            chem_idx[chem_idx != 2] = 0
429
            gene_idx[gene_idx != 3] = 0
430
            spec_idx[spec_idx != 4] = 0
431
            cellline_idx[cellline_idx != 5] = 0
432
            dna_idx[dna_idx != 6] = 0
433
            rna_idx[rna_idx != 7] = 0
434
            celltype_idx[celltype_idx != 8] = 0
435
            # biological_idx[biological_idx != 9] = 0
436
            # diagnostic_idx[diagnostic_idx != 10] = 0
437
            # duration_idx[duration_idx != 11] = 0
438
            # date_idx[date_idx != 12] = 0
439
            # therapeutic_idx[therapeutic_idx != 13] = 0
440
            # sign_symptom_idx[sign_symptom_idx != 14] = 0
441
            # lab_value_idx[lab_value_idx != 15] = 0
442
443
444
            dise_sequence_output = dise_idx.unsqueeze(-1) * sequence_output        
445
            chem_sequence_output = chem_idx.unsqueeze(-1) * sequence_output
446
            gene_sequence_output = gene_idx.unsqueeze(-1) * sequence_output
447
            spec_sequence_output = spec_idx.unsqueeze(-1) * sequence_output
448
            cellline_sequence_output = cellline_idx.unsqueeze(-1) * sequence_output
449
            dna_sequence_output = dna_idx.unsqueeze(-1) * sequence_output
450
            rna_sequence_output = rna_idx.unsqueeze(-1) * sequence_output
451
            celltype_sequence_output = celltype_idx.unsqueeze(-1) * sequence_output
452
            # biological_structure_sequence_output = biological_idx.unsqueeze(-1) * sequence_output
453
            # diagnostic_procedure_sequence_output = diagnostic_idx.unsqueeze(-1) * sequence_output
454
            # duration_sequence_output = duration_idx.unsqueeze(-1) * sequence_output
455
            # date_sequence_output = date_idx.unsqueeze(-1) * sequence_output
456
            # therapeutic_procedure_sequence_output = therapeutic_idx.unsqueeze(-1) * sequence_output
457
            # sign_symptom_sequence_output = sign_symptom_idx.unsqueeze(-1) * sequence_output
458
            # lab_value_sequence_output = lab_value_idx.unsqueeze(-1) * sequence_output
459
460
            
461
            # F.tanh or F.relu
462
            dise_sequence_output = F.relu(self.dise_classifier_2(dise_sequence_output)) # disease logit value
463
            chem_sequence_output = F.relu(self.chem_classifier_2(chem_sequence_output)) # chemical logit value
464
            gene_sequence_output = F.relu(self.gene_classifier_2(gene_sequence_output)) # gene/protein logit value
465
            spec_sequence_output = F.relu(self.spec_classifier_2(spec_sequence_output)) # species logit value
466
            cellline_sequence_output = F.relu(self.cellline_classifier_2(cellline_sequence_output)) # cell line logit value
467
            dna_sequence_output = F.relu(self.dna_classifier_2(dna_sequence_output)) # dna logit value
468
            rna_sequence_output = F.relu(self.rna_classifier_2(rna_sequence_output)) # rna logit value
469
            celltype_sequence_output = F.relu(self.celltype_classifier_2(celltype_sequence_output)) # cell type logit value
470
471
            # biological_structure_sequence_output = F.relu(self.biological_structure_classifier_2(biological_structure_sequence_output)) # biological structure logit value
472
            # diagnostic_procedure_sequence_output = F.relu(self.diagnostic_procedure_classifier_2(diagnostic_procedure_sequence_output)) # diagnostic procedure logit value
473
            # duration_sequence_output = F.relu(self.duration_classifier_2(duration_sequence_output)) # duration logit value
474
            # date_sequence_output = F.relu(self.date_classifier_2(date_sequence_output)) # date logit value
475
            # therapeutic_procedure_sequence_output = F.relu(self.therapeutic_procedure_classifier_2(therapeutic_procedure_sequence_output)) # therapeutic procedure logit value
476
            # sign_symptom_sequence_output = F.relu(self.sign_symptom_classifier_2(sign_symptom_sequence_output)) # sign/symptom logit value
477
            # lab_value_sequence_output = F.relu(self.lab_value_classifier_2(lab_value_sequence_output)) # lab value logit value
478
479
            
480
481
            dise_logits = self.dise_classifier(dise_sequence_output) # disease logit value
482
            chem_logits = self.chem_classifier(chem_sequence_output) # chemical logit value
483
            gene_logits = self.gene_classifier(gene_sequence_output) # gene/protein logit value
484
            spec_logits = self.spec_classifier(spec_sequence_output) # species logit value
485
            cellline_logits = self.cellline_classifier(cellline_sequence_output) # cell line logit value
486
            dna_logits = self.dna_classifier(dna_sequence_output) # dna logit value
487
            rna_logits = self.rna_classifier(rna_sequence_output) # rna logit value
488
            celltype_logits = self.celltype_classifier(celltype_sequence_output) # cell type logit value
489
            # biological_logits = self.biological_structure_classifier(biological_structure_sequence_output) # biological structure logit value
490
            # diagnostic_logits = self.diagnostic_procedure_classifier(diagnostic_procedure_sequence_output) # diagnostic procedure logit value
491
            # duration_logits = self.duration_classifier(duration_sequence_output) # duration logit value
492
            # date_logits = self.date_classifier(date_sequence_output)
493
            # therapeutic_logits = self.therapeutic_procedure_classifier(therapeutic_procedure_sequence_output) # therapeutic procedure logit value
494
            # sign_symptom_logits = self.sign_symptom_classifier(sign_symptom_sequence_output) # sign/symptom logit value
495
            # lab_value_logits = self.lab_value_classifier(lab_value_sequence_output) # lab value logit value
496
497
            
498
499
            # update logit and sequence_output
500
            sequence_output =dise_sequence_output + chem_sequence_output + gene_sequence_output + spec_sequence_output + cellline_sequence_output + dna_sequence_output + rna_sequence_output + celltype_sequence_output 
501
                # + biological_structure_sequence_output + diagnostic_procedure_sequence_output + duration_sequence_output + date_sequence_output \
502
                # + therapeutic_procedure_sequence_output + sign_symptom_sequence_output + lab_value_sequence_output
503
                
504
            logits = dise_logits + chem_logits + gene_logits + spec_logits + cellline_logits \
505
                + dna_logits + rna_logits + celltype_logits 
506
                # + biological_logits \
507
                # + diagnostic_logits + duration_logits + date_logits + therapeutic_logits \
508
                # + sign_symptom_logits + lab_value_logits 
509
                
510
511
        outputs = (logits, sequence_output)
512
        if labels is not None:
513
            loss_fct = CrossEntropyLoss()
514
            # Only keep active parts of the loss
515
            if attention_mask is not None:
516
                if entity_type_ids[0][0].item() == 0:
517
                    active_loss = attention_mask.view(-1) == 1
518
                    dise_logits, chem_logits, gene_logits, spec_logits, cellline_logits, \
519
                    dna_logits, rna_logits, celltype_logits = logits
520
                    
521
                    #  biological_logits, diagnostic_logits, \
522
                    # duration_logits, date_logits, therapeutic_logits, \
523
                    # sign_symptom_logits, lab_value_logits
524
525
526
                    active_dise_logits = dise_logits.view(-1, self.num_labels)
527
                    active_chem_logits = chem_logits.view(-1, self.num_labels)
528
                    active_gene_logits = gene_logits.view(-1, self.num_labels)
529
                    active_spec_logits = spec_logits.view(-1, self.num_labels)
530
                    active_cellline_logits = cellline_logits.view(-1, self.num_labels)
531
                    active_dna_logits = dna_logits.view(-1, self.num_labels)
532
                    active_rna_logits = rna_logits.view(-1, self.num_labels)
533
                    active_celltype_logits = celltype_logits.view(-1, self.num_labels)
534
                    # active_biological_logits = biological_logits.view(-1, self.num_labels)
535
                    # active_diagnostic_logits = diagnostic_logits.view(-1, self.num_labels)
536
                    # active_duration_logits = duration_logits.view(-1, self.num_labels)
537
                    # active_date_logits = date_logits.view(-1, self.num_labels)
538
                    # active_therapeutic_logits = therapeutic_logits.view(-1, self.num_labels)
539
                    # active_sign_symptom_logits = sign_symptom_logits.view(-1, self.num_labels)
540
                    # active_lab_value_logits = lab_value_logits.view(-1, self.num_labels)
541
                    
542
                    
543
                    active_labels = torch.where(
544
                        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
545
                    )
546
                    dise_loss = loss_fct(active_dise_logits, active_labels)
547
                    chem_loss = loss_fct(active_chem_logits, active_labels)
548
                    gene_loss = loss_fct(active_gene_logits, active_labels)
549
                    spec_loss = loss_fct(active_spec_logits, active_labels)
550
                    cellline_loss = loss_fct(active_cellline_logits, active_labels)
551
                    dna_loss = loss_fct(active_dna_logits, active_labels)
552
                    rna_loss = loss_fct(active_rna_logits, active_labels)
553
                    celltype_loss = loss_fct(active_celltype_logits, active_labels)
554
                    # biological_loss = loss_fct(active_biological_logits, active_labels)
555
                    # diagnostic_loss = loss_fct(active_diagnostic_logits, active_labels)
556
                    # duration_loss = loss_fct(active_duration_logits, active_labels)
557
                    # date_loss = loss_fct(active_date_logits, active_labels)
558
                    # therapeutic_loss = loss_fct(active_therapeutic_logits, active_labels)
559
                    # sign_symptom_loss = loss_fct(active_sign_symptom_logits, active_labels)
560
                    # lab_value_loss = loss_fct(active_lab_value_logits, active_labels)
561
                    
562
                    loss = dise_loss + chem_loss + gene_loss + spec_loss + cellline_loss + dna_loss + rna_loss + celltype_loss 
563
                        #  + biological_loss + diagnostic_loss \
564
                        # + duration_loss + date_loss + therapeutic_loss + sign_symptom_loss + lab_value_loss 
565
                        
566
                    return ((loss,) + outputs)
567
                else:
568
                    active_loss = attention_mask.view(-1) == 1
569
                    active_logits = logits.view(-1, self.num_labels)
570
                    active_labels = torch.where(
571
                        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
572
                    )
573
                    loss = loss_fct(active_logits, active_labels)
574
                    return ((loss,) + outputs)
575
            else:
576
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
577
                return loss
578
        else:
579
            return logits
580
581
582
class NER(BertForTokenClassification):
583
    def __init__(self, config, num_labels=3):
584
        super(NER, self).__init__(config)
585
        self.num_labels = num_labels
586
        self.bert = BertModel(config)
587
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
588
        self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)
589
590
        self.init_weights()
591
592
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
593
        sequence_output = self.bert(input_ids, token_type_ids, attention_mask, head_mask=None)[0]
594
        batch_size,max_len,feat_dim = sequence_output.shape
595
        sequence_output = self.dropout(sequence_output)
596
597
        logits = self.classifier(sequence_output)
598
599
        outputs = (logits, sequence_output)
600
        if labels is not None:
601
            loss_fct = CrossEntropyLoss()
602
            # Only keep active parts of the loss
603
            if attention_mask is not None:
604
                active_loss = attention_mask.view(-1) == 1
605
                active_logits = logits.view(-1, self.num_labels)
606
                active_labels = torch.where(
607
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
608
                )
609
                loss = loss_fct(active_logits, active_labels)
610
                return ((loss,) + outputs)
611
            else:
612
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
613
                return loss
614
        else:
615
            return logits
616
617