Diff of /src/Parser/models.py [000000] .. [f87529]

Switch to side-by-side view

--- a
+++ b/src/Parser/models.py
@@ -0,0 +1,617 @@
+# coding=utf-8
+import os
+import pdb
+import copy
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from torch.nn import CrossEntropyLoss
+from transformers import (
+        BertConfig,
+        BertModel,
+        RobertaModel,
+        BertForTokenClassification,
+        BertTokenizer,
+        RobertaConfig,
+        RobertaForTokenClassification,
+        RobertaTokenizer, 
+        AutoTokenizer,
+)
+
+class BERTMultiNER2(BertForTokenClassification):
+    def __init__(self, config, num_labels=3):
+        super(BERTMultiNER2, self).__init__(config)
+        self.num_labels = num_labels
+        self.bert = BertModel(config)
+        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+        
+        self.dise_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # disease
+        self.chem_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # chemical
+        self.gene_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # gene/protein
+        self.spec_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # species
+        self.cellline_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # cell line
+        self.dna_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # dna
+        self.rna_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # rna
+        self.celltype_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # cell type
+        
+        # self.biological_structure_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # biological structure
+        # self.diagnostic_procedure_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # diagnostic procedure
+        # self.duration_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # duration
+        # self.date_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # date
+        # self.therapeutic_procedure_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # therapeutic procedure
+        # self.sign_symptom_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # sign/symptom
+        # self.lab_value_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # lab value
+        
+
+        self.dise_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.chem_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.gene_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.spec_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.cellline_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.dna_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.rna_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.celltype_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        
+        # self.biological_structure_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        # self.diagnostic_procedure_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)    
+        # self.duration_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)    
+        # self.date_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        # self.therapeutic_procedure_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        # self.sign_symptom_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        # self.lab_value_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+
+        self.init_weights()
+
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, entity_type_ids=None):
+        sequence_output = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, head_mask=None)[0]
+        batch_size,max_len,feat_dim = sequence_output.shape
+        sequence_output = self.dropout(sequence_output)
+
+        if entity_type_ids[0][0].item() == 0:
+            '''
+            Raw text data with trained parameters
+            '''
+            dise_sequence_output = F.relu(self.dise_classifier_2(sequence_output)) # disease logit value
+            chem_sequence_output = F.relu(self.chem_classifier_2(sequence_output)) # chemical logit value
+            gene_sequence_output = F.relu(self.gene_classifier_2(sequence_output)) # gene/protein logit value
+            spec_sequence_output = F.relu(self.spec_classifier_2(sequence_output)) # species logit value
+            cellline_sequence_output = F.relu(self.cellline_classifier_2(sequence_output)) # cell line logit value
+            dna_sequence_output = F.relu(self.dna_classifier_2(sequence_output)) # dna logit value
+            rna_sequence_output = F.relu(self.rna_classifier_2(sequence_output)) # rna logit value
+            celltype_sequence_output = F.relu(self.celltype_classifier_2(sequence_output)) # cell type logit value
+            
+            # biological_structure_sequence_output = F.relu(self.biological_structure_classifier_2(sequence_output)) # biological structure logit value
+            # diagnostic_procedure_sequence_output = F.relu(self.diagnostic_procedure_classifier_2(sequence_output)) # diagnostic procedure logit value
+            # duration_sequence_output = F.relu(self.duration_classifier_2(sequence_output)) # duration logit value
+            # date_sequence_output = F.relu(self.date_classifier_2(sequence_output)) # date logit value
+            # therapeutic_procedure_sequence_output = F.relu(self.therapeutic_procedure_classifier_2(sequence_output)) # therapeutic procedure logit value
+            # sign_symptom_sequence_output = F.relu(self.sign_symptom_classifier_2(sequence_output)) # sign/symptom logit value
+            # lab_value_sequence_output = F.relu(self.lab_value_classifier_2(sequence_output)) # lab value logit value
+
+
+            dise_logits = self.dise_classifier(dise_sequence_output) # disease logit value
+            chem_logits = self.chem_classifier(chem_sequence_output) # chemical logit value
+            gene_logits = self.gene_classifier(gene_sequence_output) # gene/protein logit value
+            spec_logits = self.spec_classifier(spec_sequence_output) # species logit value
+            cellline_logits = self.cellline_classifier(cellline_sequence_output) # cell line logit value
+            dna_logits = self.dna_classifier(dna_sequence_output) # dna logit value
+            rna_logits = self.rna_classifier(rna_sequence_output) # rna logit value
+            celltype_logits = self.celltype_classifier(celltype_sequence_output) # cell type logit value
+            
+            # biological_logits = self.biological_structure_classifier(biological_structure_sequence_output) # biological structure logit value
+            # diagnostic_logits = self.diagnostic_procedure_classifier(diagnostic_procedure_sequence_output) # diagnostic procedure logit value
+            # duration_logits = self.duration_classifier(duration_sequence_output) # duration logit value
+            # date_logits = self.date_classifier(date_sequence_output) # date logit value
+            # therapeutic_logits = self.therapeutic_procedure_classifier(therapeutic_procedure_sequence_output) # therapeutic procedure logit value
+            # sign_symptom_logits = self.sign_symptom_classifier(sign_symptom_sequence_output) # sign/symptom logit value
+            # lab_value_logits = self.lab_value_classifier(lab_value_sequence_output) # lab value logit value
+
+            
+
+            # update logit and sequence_output
+            sequence_output = dise_sequence_output + chem_sequence_output + gene_sequence_output + spec_sequence_output + cellline_sequence_output + dna_sequence_output + rna_sequence_output + celltype_sequence_output 
+            # + \
+            #     biological_structure_sequence_output + diagnostic_procedure_sequence_output + duration_sequence_output + date_sequence_output + \
+            #     therapeutic_procedure_sequence_output + sign_symptom_sequence_output + lab_value_sequence_output 
+                
+            logits = (dise_logits, chem_logits, gene_logits, spec_logits, cellline_logits, 
+                      dna_logits, rna_logits, celltype_logits)
+                    #   biological_logits, diagnostic_logits,
+                    #   duration_logits, date_logits, therapeutic_logits,
+                    #   sign_symptom_logits, lab_value_logits)
+        else:
+            ''' 
+            Train, Eval, Test with pre-defined entity type tags
+            '''
+            # make 1*1 conv to adopt entity type
+            dise_idx = copy.deepcopy(entity_type_ids)
+            chem_idx = copy.deepcopy(entity_type_ids)
+            gene_idx = copy.deepcopy(entity_type_ids)
+            spec_idx = copy.deepcopy(entity_type_ids)
+            cellline_idx = copy.deepcopy(entity_type_ids)
+            dna_idx = copy.deepcopy(entity_type_ids)
+            rna_idx = copy.deepcopy(entity_type_ids)
+            celltype_idx = copy.deepcopy(entity_type_ids)
+            
+            # biological_idx = copy.deepcopy(entity_type_ids)
+            # diagnostic_idx = copy.deepcopy(entity_type_ids)
+            # duration_idx = copy.deepcopy(entity_type_ids)
+            # date_idx = copy.deepcopy(entity_type_ids)
+            # therapeutic_idx = copy.deepcopy(entity_type_ids)
+            # sign_symptom_idx = copy.deepcopy(entity_type_ids)
+            # lab_value_idx = copy.deepcopy(entity_type_ids)
+
+            
+
+            dise_idx[dise_idx != 1] = 0
+            chem_idx[chem_idx != 2] = 0
+            gene_idx[gene_idx != 3] = 0
+            spec_idx[spec_idx != 4] = 0
+            cellline_idx[cellline_idx != 5] = 0
+            dna_idx[dna_idx != 6] = 0
+            rna_idx[rna_idx != 7] = 0
+            celltype_idx[celltype_idx != 8] = 0
+            # biological_idx[biological_idx != 9] = 0
+            # diagnostic_idx[diagnostic_idx != 10] = 0
+            # duration_idx[duration_idx != 11] = 0
+            # date_idx[date_idx != 12] = 0
+            # therapeutic_idx[therapeutic_idx != 13] = 0
+            # sign_symptom_idx[sign_symptom_idx != 14] = 0
+            # lab_value_idx[lab_value_idx != 15] = 0
+
+            dise_sequence_output = dise_idx.unsqueeze(-1) * sequence_output        
+            chem_sequence_output = chem_idx.unsqueeze(-1) * sequence_output
+            gene_sequence_output = gene_idx.unsqueeze(-1) * sequence_output
+            spec_sequence_output = spec_idx.unsqueeze(-1) * sequence_output
+            cellline_sequence_output = cellline_idx.unsqueeze(-1) * sequence_output
+            dna_sequence_output = dna_idx.unsqueeze(-1) * sequence_output
+            rna_sequence_output = rna_idx.unsqueeze(-1) * sequence_output
+            celltype_sequence_output = celltype_idx.unsqueeze(-1) * sequence_output
+            # biological_structure_sequence_output = biological_idx.unsqueeze(-1) * sequence_output
+            # diagnostic_procedure_sequence_output = diagnostic_idx.unsqueeze(-1) * sequence_output
+            # duration_sequence_output = duration_idx.unsqueeze(-1) * sequence_output
+            # date_sequence_output = date_idx.unsqueeze(-1) * sequence_output
+            # therapeutic_procedure_sequence_output = therapeutic_idx.unsqueeze(-1) * sequence_output
+            # sign_symptom_sequence_output = sign_symptom_idx.unsqueeze(-1) * sequence_output
+            # lab_value_sequence_output = lab_value_idx.unsqueeze(-1) * sequence_output
+
+            
+            # F.tanh or F.relu
+            dise_sequence_output = F.relu(self.dise_classifier_2(dise_sequence_output)) # disease logit value
+            chem_sequence_output = F.relu(self.chem_classifier_2(chem_sequence_output)) # chemical logit value
+            gene_sequence_output = F.relu(self.gene_classifier_2(gene_sequence_output)) # gene/protein logit value
+            spec_sequence_output = F.relu(self.spec_classifier_2(spec_sequence_output)) # species logit value
+            cellline_sequence_output = F.relu(self.cellline_classifier_2(cellline_sequence_output)) # cell line logit value
+            dna_sequence_output = F.relu(self.dna_classifier_2(dna_sequence_output)) # dna logit value
+            rna_sequence_output = F.relu(self.rna_classifier_2(rna_sequence_output)) # rna logit value
+            celltype_sequence_output = F.relu(self.celltype_classifier_2(celltype_sequence_output)) # cell type logit value
+
+            # biological_structure_sequence_output = F.relu(self.biological_structure_classifier_2(biological_structure_sequence_output)) # biological structure logit value
+            # diagnostic_procedure_sequence_output = F.relu(self.diagnostic_procedure_classifier_2(diagnostic_procedure_sequence_output)) # diagnostic procedure logit value
+            # duration_sequence_output = F.relu(self.duration_classifier_2(duration_sequence_output)) # duration logit value
+            # date_sequence_output = F.relu(self.date_classifier_2(date_sequence_output)) # date logit value
+            # therapeutic_procedure_sequence_output = F.relu(self.therapeutic_procedure_classifier_2(therapeutic_procedure_sequence_output)) # therapeutic procedure logit value
+            # sign_symptom_sequence_output = F.relu(self.sign_symptom_classifier_2(sign_symptom_sequence_output)) # sign/symptom logit value
+            # lab_value_sequence_output = F.relu(self.lab_value_classifier_2(lab_value_sequence_output)) # lab value logit value
+
+            
+
+            dise_logits = self.dise_classifier(dise_sequence_output) # disease logit value
+            chem_logits = self.chem_classifier(chem_sequence_output) # chemical logit value
+            gene_logits = self.gene_classifier(gene_sequence_output) # gene/protein logit value
+            spec_logits = self.spec_classifier(spec_sequence_output) # species logit value
+            cellline_logits = self.cellline_classifier(cellline_sequence_output) # cell line logit value
+            dna_logits = self.dna_classifier(dna_sequence_output) # dna logit value
+            rna_logits = self.rna_classifier(rna_sequence_output) # rna logit value
+            celltype_logits = self.celltype_classifier(celltype_sequence_output) # cell type logit value
+            # biological_logits = self.biological_structure_classifier(biological_structure_sequence_output) # biological structure logit value
+            # diagnostic_logits = self.diagnostic_procedure_classifier(diagnostic_procedure_sequence_output) # diagnostic procedure logit value
+            # duration_logits = self.duration_classifier(duration_sequence_output) # duration logit value
+            # date_logits = self.date_classifier(date_sequence_output)
+            # therapeutic_logits = self.therapeutic_procedure_classifier(therapeutic_procedure_sequence_output) # therapeutic procedure logit value
+            # sign_symptom_logits = self.sign_symptom_classifier(sign_symptom_sequence_output) # sign/symptom logit value
+            # lab_value_logits = self.lab_value_classifier(lab_value_sequence_output) # lab value logit value
+
+            
+
+            # update logit and sequence_output
+            sequence_output =dise_sequence_output + chem_sequence_output + gene_sequence_output + spec_sequence_output + cellline_sequence_output + dna_sequence_output + rna_sequence_output + celltype_sequence_output 
+            # \
+            #     + biological_structure_sequence_output + diagnostic_procedure_sequence_output + duration_sequence_output + date_sequence_output \
+            #     + therapeutic_procedure_sequence_output + sign_symptom_sequence_output + lab_value_sequence_output
+                
+            logits = dise_logits + chem_logits + gene_logits + spec_logits + cellline_logits \
+                + dna_logits + rna_logits + celltype_logits 
+                # + biological_logits \
+                # + diagnostic_logits + duration_logits + date_logits + therapeutic_logits \
+                # + sign_symptom_logits + lab_value_logits 
+                
+
+        outputs = (logits, sequence_output)
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Only keep active parts of the loss
+            if attention_mask is not None:
+                if entity_type_ids[0][0].item() == 0:
+                    active_loss = attention_mask.view(-1) == 1
+                    dise_logits, chem_logits, gene_logits, spec_logits, cellline_logits, \
+                    dna_logits, rna_logits, celltype_logits = logits
+                    
+                    #  biological_logits, diagnostic_logits, \
+                    # duration_logits, date_logits, therapeutic_logits, \
+                    # sign_symptom_logits, lab_value_logits
+
+
+                    active_dise_logits = dise_logits.view(-1, self.num_labels)
+                    active_chem_logits = chem_logits.view(-1, self.num_labels)
+                    active_gene_logits = gene_logits.view(-1, self.num_labels)
+                    active_spec_logits = spec_logits.view(-1, self.num_labels)
+                    active_cellline_logits = cellline_logits.view(-1, self.num_labels)
+                    active_dna_logits = dna_logits.view(-1, self.num_labels)
+                    active_rna_logits = rna_logits.view(-1, self.num_labels)
+                    active_celltype_logits = celltype_logits.view(-1, self.num_labels)
+                    # active_biological_logits = biological_logits.view(-1, self.num_labels)
+                    # active_diagnostic_logits = diagnostic_logits.view(-1, self.num_labels)
+                    # active_duration_logits = duration_logits.view(-1, self.num_labels)
+                    # active_date_logits = date_logits.view(-1, self.num_labels)
+                    # active_therapeutic_logits = therapeutic_logits.view(-1, self.num_labels)
+                    # active_sign_symptom_logits = sign_symptom_logits.view(-1, self.num_labels)
+                    # active_lab_value_logits = lab_value_logits.view(-1, self.num_labels)
+                    
+                    
+                    active_labels = torch.where(
+                        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                    )
+                    dise_loss = loss_fct(active_dise_logits, active_labels)
+                    chem_loss = loss_fct(active_chem_logits, active_labels)
+                    gene_loss = loss_fct(active_gene_logits, active_labels)
+                    spec_loss = loss_fct(active_spec_logits, active_labels)
+                    cellline_loss = loss_fct(active_cellline_logits, active_labels)
+                    dna_loss = loss_fct(active_dna_logits, active_labels)
+                    rna_loss = loss_fct(active_rna_logits, active_labels)
+                    celltype_loss = loss_fct(active_celltype_logits, active_labels)
+                    # biological_loss = loss_fct(active_biological_logits, active_labels)
+                    # diagnostic_loss = loss_fct(active_diagnostic_logits, active_labels)
+                    # duration_loss = loss_fct(active_duration_logits, active_labels)
+                    # date_loss = loss_fct(active_date_logits, active_labels)
+                    # therapeutic_loss = loss_fct(active_therapeutic_logits, active_labels)
+                    # sign_symptom_loss = loss_fct(active_sign_symptom_logits, active_labels)
+                    # lab_value_loss = loss_fct(active_lab_value_logits, active_labels)
+                    
+                    loss = dise_loss + chem_loss + gene_loss + spec_loss + cellline_loss + dna_loss + rna_loss + celltype_loss 
+                    # \
+                    #      + biological_loss + diagnostic_loss \
+                    #     + duration_loss + date_loss + therapeutic_loss + sign_symptom_loss + lab_value_loss 
+                        
+                    return ((loss,) + outputs)
+                else:
+                    active_loss = attention_mask.view(-1) == 1
+                    active_logits = logits.view(-1, self.num_labels)
+                    active_labels = torch.where(
+                        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                    )
+                    loss = loss_fct(active_logits, active_labels)
+                    return ((loss,) + outputs)
+            else:
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+                return loss
+        else:
+            return logits
+
+class RoBERTaMultiNER2(RobertaForTokenClassification):
+    def __init__(self, config, num_labels=3):
+        super(RoBERTaMultiNER2, self).__init__(config)
+        self.num_labels = num_labels
+        self.roberta = RobertaModel(config)
+        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+        
+        self.dise_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # disease
+        self.chem_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # chemical
+        self.gene_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # gene/protein
+        self.spec_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # species
+        self.cellline_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # cell line
+        self.dna_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # dna
+        self.rna_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # rna
+        self.celltype_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # cell type
+        
+        # self.biological_structure_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # biological structure
+        # self.diagnostic_procedure_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # diagnostic procedure
+        # self.duration_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # duration
+        # self.date_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # date
+        # self.therapeutic_procedure_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # therapeutic procedure
+        # self.sign_symptom_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # sign/symptom
+        # self.lab_value_classifier = torch.nn.Linear(config.hidden_size, self.num_labels) # lab value
+        
+
+        self.dise_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.chem_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.gene_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.spec_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.cellline_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.dna_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.rna_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        self.celltype_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        
+        # self.biological_structure_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        # self.diagnostic_procedure_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)    
+        # self.duration_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)    
+        # self.date_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        # self.therapeutic_procedure_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        # self.sign_symptom_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        # self.lab_value_classifier_2 = torch.nn.Linear(config.hidden_size, config.hidden_size)
+
+        self.init_weights()
+
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, entity_type_ids=None):
+        sequence_output = self.roberta(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, head_mask=None)[0]
+        batch_size,max_len,feat_dim = sequence_output.shape
+        sequence_output = self.dropout(sequence_output)
+
+        if entity_type_ids[0][0].item() == 0:
+            '''
+            Raw text data with trained parameters
+            '''
+            dise_sequence_output = F.relu(self.dise_classifier_2(sequence_output)) # disease logit value
+            chem_sequence_output = F.relu(self.chem_classifier_2(sequence_output)) # chemical logit value
+            gene_sequence_output = F.relu(self.gene_classifier_2(sequence_output)) # gene/protein logit value
+            spec_sequence_output = F.relu(self.spec_classifier_2(sequence_output)) # species logit value
+            cellline_sequence_output = F.relu(self.cellline_classifier_2(sequence_output)) # cell line logit value
+            dna_sequence_output = F.relu(self.dna_classifier_2(sequence_output)) # dna logit value
+            rna_sequence_output = F.relu(self.rna_classifier_2(sequence_output)) # rna logit value
+            celltype_sequence_output = F.relu(self.celltype_classifier_2(sequence_output)) # cell type logit value
+            
+            # biological_structure_sequence_output = F.relu(self.biological_structure_classifier_2(sequence_output)) # biological structure logit value
+            # diagnostic_procedure_sequence_output = F.relu(self.diagnostic_procedure_classifier_2(sequence_output)) # diagnostic procedure logit value
+            # duration_sequence_output = F.relu(self.duration_classifier_2(sequence_output)) # duration logit value
+            # date_sequence_output = F.relu(self.date_classifier_2(sequence_output)) # date logit value
+            # therapeutic_procedure_sequence_output = F.relu(self.therapeutic_procedure_classifier_2(sequence_output)) # therapeutic procedure logit value
+            # sign_symptom_sequence_output = F.relu(self.sign_symptom_classifier_2(sequence_output)) # sign/symptom logit value
+            # lab_value_sequence_output = F.relu(self.lab_value_classifier_2(sequence_output)) # lab value logit value
+
+
+            dise_logits = self.dise_classifier(dise_sequence_output) # disease logit value
+            chem_logits = self.chem_classifier(chem_sequence_output) # chemical logit value
+            gene_logits = self.gene_classifier(gene_sequence_output) # gene/protein logit value
+            spec_logits = self.spec_classifier(spec_sequence_output) # species logit value
+            cellline_logits = self.cellline_classifier(cellline_sequence_output) # cell line logit value
+            dna_logits = self.dna_classifier(dna_sequence_output) # dna logit value
+            rna_logits = self.rna_classifier(rna_sequence_output) # rna logit value
+            celltype_logits = self.celltype_classifier(celltype_sequence_output) # cell type logit value
+            
+            # biological_logits = self.biological_structure_classifier(biological_structure_sequence_output) # biological structure logit value
+            # diagnostic_logits = self.diagnostic_procedure_classifier(diagnostic_procedure_sequence_output) # diagnostic procedure logit value
+            # duration_logits = self.duration_classifier(duration_sequence_output) # duration logit value
+            # date_logits = self.date_classifier(date_sequence_output) # date logit value
+            # therapeutic_logits = self.therapeutic_procedure_classifier(therapeutic_procedure_sequence_output) # therapeutic procedure logit value
+            # sign_symptom_logits = self.sign_symptom_classifier(sign_symptom_sequence_output) # sign/symptom logit value
+            # lab_value_logits = self.lab_value_classifier(lab_value_sequence_output) # lab value logit value
+
+            
+
+            # update logit and sequence_output
+            sequence_output = dise_sequence_output + chem_sequence_output + gene_sequence_output + spec_sequence_output + cellline_sequence_output + dna_sequence_output + rna_sequence_output + celltype_sequence_output 
+            # + \
+            #     biological_structure_sequence_output + diagnostic_procedure_sequence_output + duration_sequence_output + date_sequence_output + \
+            #     therapeutic_procedure_sequence_output + sign_symptom_sequence_output + lab_value_sequence_output 
+                
+            logits = (dise_logits, chem_logits, gene_logits, spec_logits, cellline_logits, 
+                      dna_logits, rna_logits, celltype_logits)
+                    #   biological_logits, diagnostic_logits,
+                    #   duration_logits, date_logits, therapeutic_logits,
+                    #   sign_symptom_logits, lab_value_logits)
+        else:
+            ''' 
+            Train, Eval, Test with pre-defined entity type tags
+            '''
+            # make 1*1 conv to adopt entity type
+            dise_idx = copy.deepcopy(entity_type_ids)
+            chem_idx = copy.deepcopy(entity_type_ids)
+            gene_idx = copy.deepcopy(entity_type_ids)
+            spec_idx = copy.deepcopy(entity_type_ids)
+            cellline_idx = copy.deepcopy(entity_type_ids)
+            dna_idx = copy.deepcopy(entity_type_ids)
+            rna_idx = copy.deepcopy(entity_type_ids)
+            celltype_idx = copy.deepcopy(entity_type_ids)
+            
+            # biological_idx = copy.deepcopy(entity_type_ids)
+            # diagnostic_idx = copy.deepcopy(entity_type_ids)
+            # duration_idx = copy.deepcopy(entity_type_ids)
+            # date_idx = copy.deepcopy(entity_type_ids)
+            # therapeutic_idx = copy.deepcopy(entity_type_ids)
+            # sign_symptom_idx = copy.deepcopy(entity_type_ids)
+            # lab_value_idx = copy.deepcopy(entity_type_ids)
+
+            
+
+            dise_idx[dise_idx != 1] = 0
+            chem_idx[chem_idx != 2] = 0
+            gene_idx[gene_idx != 3] = 0
+            spec_idx[spec_idx != 4] = 0
+            cellline_idx[cellline_idx != 5] = 0
+            dna_idx[dna_idx != 6] = 0
+            rna_idx[rna_idx != 7] = 0
+            celltype_idx[celltype_idx != 8] = 0
+            # biological_idx[biological_idx != 9] = 0
+            # diagnostic_idx[diagnostic_idx != 10] = 0
+            # duration_idx[duration_idx != 11] = 0
+            # date_idx[date_idx != 12] = 0
+            # therapeutic_idx[therapeutic_idx != 13] = 0
+            # sign_symptom_idx[sign_symptom_idx != 14] = 0
+            # lab_value_idx[lab_value_idx != 15] = 0
+
+
+            dise_sequence_output = dise_idx.unsqueeze(-1) * sequence_output        
+            chem_sequence_output = chem_idx.unsqueeze(-1) * sequence_output
+            gene_sequence_output = gene_idx.unsqueeze(-1) * sequence_output
+            spec_sequence_output = spec_idx.unsqueeze(-1) * sequence_output
+            cellline_sequence_output = cellline_idx.unsqueeze(-1) * sequence_output
+            dna_sequence_output = dna_idx.unsqueeze(-1) * sequence_output
+            rna_sequence_output = rna_idx.unsqueeze(-1) * sequence_output
+            celltype_sequence_output = celltype_idx.unsqueeze(-1) * sequence_output
+            # biological_structure_sequence_output = biological_idx.unsqueeze(-1) * sequence_output
+            # diagnostic_procedure_sequence_output = diagnostic_idx.unsqueeze(-1) * sequence_output
+            # duration_sequence_output = duration_idx.unsqueeze(-1) * sequence_output
+            # date_sequence_output = date_idx.unsqueeze(-1) * sequence_output
+            # therapeutic_procedure_sequence_output = therapeutic_idx.unsqueeze(-1) * sequence_output
+            # sign_symptom_sequence_output = sign_symptom_idx.unsqueeze(-1) * sequence_output
+            # lab_value_sequence_output = lab_value_idx.unsqueeze(-1) * sequence_output
+
+            
+            # F.tanh or F.relu
+            dise_sequence_output = F.relu(self.dise_classifier_2(dise_sequence_output)) # disease logit value
+            chem_sequence_output = F.relu(self.chem_classifier_2(chem_sequence_output)) # chemical logit value
+            gene_sequence_output = F.relu(self.gene_classifier_2(gene_sequence_output)) # gene/protein logit value
+            spec_sequence_output = F.relu(self.spec_classifier_2(spec_sequence_output)) # species logit value
+            cellline_sequence_output = F.relu(self.cellline_classifier_2(cellline_sequence_output)) # cell line logit value
+            dna_sequence_output = F.relu(self.dna_classifier_2(dna_sequence_output)) # dna logit value
+            rna_sequence_output = F.relu(self.rna_classifier_2(rna_sequence_output)) # rna logit value
+            celltype_sequence_output = F.relu(self.celltype_classifier_2(celltype_sequence_output)) # cell type logit value
+
+            # biological_structure_sequence_output = F.relu(self.biological_structure_classifier_2(biological_structure_sequence_output)) # biological structure logit value
+            # diagnostic_procedure_sequence_output = F.relu(self.diagnostic_procedure_classifier_2(diagnostic_procedure_sequence_output)) # diagnostic procedure logit value
+            # duration_sequence_output = F.relu(self.duration_classifier_2(duration_sequence_output)) # duration logit value
+            # date_sequence_output = F.relu(self.date_classifier_2(date_sequence_output)) # date logit value
+            # therapeutic_procedure_sequence_output = F.relu(self.therapeutic_procedure_classifier_2(therapeutic_procedure_sequence_output)) # therapeutic procedure logit value
+            # sign_symptom_sequence_output = F.relu(self.sign_symptom_classifier_2(sign_symptom_sequence_output)) # sign/symptom logit value
+            # lab_value_sequence_output = F.relu(self.lab_value_classifier_2(lab_value_sequence_output)) # lab value logit value
+
+            
+
+            dise_logits = self.dise_classifier(dise_sequence_output) # disease logit value
+            chem_logits = self.chem_classifier(chem_sequence_output) # chemical logit value
+            gene_logits = self.gene_classifier(gene_sequence_output) # gene/protein logit value
+            spec_logits = self.spec_classifier(spec_sequence_output) # species logit value
+            cellline_logits = self.cellline_classifier(cellline_sequence_output) # cell line logit value
+            dna_logits = self.dna_classifier(dna_sequence_output) # dna logit value
+            rna_logits = self.rna_classifier(rna_sequence_output) # rna logit value
+            celltype_logits = self.celltype_classifier(celltype_sequence_output) # cell type logit value
+            # biological_logits = self.biological_structure_classifier(biological_structure_sequence_output) # biological structure logit value
+            # diagnostic_logits = self.diagnostic_procedure_classifier(diagnostic_procedure_sequence_output) # diagnostic procedure logit value
+            # duration_logits = self.duration_classifier(duration_sequence_output) # duration logit value
+            # date_logits = self.date_classifier(date_sequence_output)
+            # therapeutic_logits = self.therapeutic_procedure_classifier(therapeutic_procedure_sequence_output) # therapeutic procedure logit value
+            # sign_symptom_logits = self.sign_symptom_classifier(sign_symptom_sequence_output) # sign/symptom logit value
+            # lab_value_logits = self.lab_value_classifier(lab_value_sequence_output) # lab value logit value
+
+            
+
+            # update logit and sequence_output
+            sequence_output =dise_sequence_output + chem_sequence_output + gene_sequence_output + spec_sequence_output + cellline_sequence_output + dna_sequence_output + rna_sequence_output + celltype_sequence_output 
+                # + biological_structure_sequence_output + diagnostic_procedure_sequence_output + duration_sequence_output + date_sequence_output \
+                # + therapeutic_procedure_sequence_output + sign_symptom_sequence_output + lab_value_sequence_output
+                
+            logits = dise_logits + chem_logits + gene_logits + spec_logits + cellline_logits \
+                + dna_logits + rna_logits + celltype_logits 
+                # + biological_logits \
+                # + diagnostic_logits + duration_logits + date_logits + therapeutic_logits \
+                # + sign_symptom_logits + lab_value_logits 
+                
+
+        outputs = (logits, sequence_output)
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Only keep active parts of the loss
+            if attention_mask is not None:
+                if entity_type_ids[0][0].item() == 0:
+                    active_loss = attention_mask.view(-1) == 1
+                    dise_logits, chem_logits, gene_logits, spec_logits, cellline_logits, \
+                    dna_logits, rna_logits, celltype_logits = logits
+                    
+                    #  biological_logits, diagnostic_logits, \
+                    # duration_logits, date_logits, therapeutic_logits, \
+                    # sign_symptom_logits, lab_value_logits
+
+
+                    active_dise_logits = dise_logits.view(-1, self.num_labels)
+                    active_chem_logits = chem_logits.view(-1, self.num_labels)
+                    active_gene_logits = gene_logits.view(-1, self.num_labels)
+                    active_spec_logits = spec_logits.view(-1, self.num_labels)
+                    active_cellline_logits = cellline_logits.view(-1, self.num_labels)
+                    active_dna_logits = dna_logits.view(-1, self.num_labels)
+                    active_rna_logits = rna_logits.view(-1, self.num_labels)
+                    active_celltype_logits = celltype_logits.view(-1, self.num_labels)
+                    # active_biological_logits = biological_logits.view(-1, self.num_labels)
+                    # active_diagnostic_logits = diagnostic_logits.view(-1, self.num_labels)
+                    # active_duration_logits = duration_logits.view(-1, self.num_labels)
+                    # active_date_logits = date_logits.view(-1, self.num_labels)
+                    # active_therapeutic_logits = therapeutic_logits.view(-1, self.num_labels)
+                    # active_sign_symptom_logits = sign_symptom_logits.view(-1, self.num_labels)
+                    # active_lab_value_logits = lab_value_logits.view(-1, self.num_labels)
+                    
+                    
+                    active_labels = torch.where(
+                        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                    )
+                    dise_loss = loss_fct(active_dise_logits, active_labels)
+                    chem_loss = loss_fct(active_chem_logits, active_labels)
+                    gene_loss = loss_fct(active_gene_logits, active_labels)
+                    spec_loss = loss_fct(active_spec_logits, active_labels)
+                    cellline_loss = loss_fct(active_cellline_logits, active_labels)
+                    dna_loss = loss_fct(active_dna_logits, active_labels)
+                    rna_loss = loss_fct(active_rna_logits, active_labels)
+                    celltype_loss = loss_fct(active_celltype_logits, active_labels)
+                    # biological_loss = loss_fct(active_biological_logits, active_labels)
+                    # diagnostic_loss = loss_fct(active_diagnostic_logits, active_labels)
+                    # duration_loss = loss_fct(active_duration_logits, active_labels)
+                    # date_loss = loss_fct(active_date_logits, active_labels)
+                    # therapeutic_loss = loss_fct(active_therapeutic_logits, active_labels)
+                    # sign_symptom_loss = loss_fct(active_sign_symptom_logits, active_labels)
+                    # lab_value_loss = loss_fct(active_lab_value_logits, active_labels)
+                    
+                    loss = dise_loss + chem_loss + gene_loss + spec_loss + cellline_loss + dna_loss + rna_loss + celltype_loss 
+                        #  + biological_loss + diagnostic_loss \
+                        # + duration_loss + date_loss + therapeutic_loss + sign_symptom_loss + lab_value_loss 
+                        
+                    return ((loss,) + outputs)
+                else:
+                    active_loss = attention_mask.view(-1) == 1
+                    active_logits = logits.view(-1, self.num_labels)
+                    active_labels = torch.where(
+                        active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                    )
+                    loss = loss_fct(active_logits, active_labels)
+                    return ((loss,) + outputs)
+            else:
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+                return loss
+        else:
+            return logits
+
+
+class NER(BertForTokenClassification):
+    def __init__(self, config, num_labels=3):
+        super(NER, self).__init__(config)
+        self.num_labels = num_labels
+        self.bert = BertModel(config)
+        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = torch.nn.Linear(config.hidden_size, self.num_labels)
+
+        self.init_weights()
+
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
+        sequence_output = self.bert(input_ids, token_type_ids, attention_mask, head_mask=None)[0]
+        batch_size,max_len,feat_dim = sequence_output.shape
+        sequence_output = self.dropout(sequence_output)
+
+        logits = self.classifier(sequence_output)
+
+        outputs = (logits, sequence_output)
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Only keep active parts of the loss
+            if attention_mask is not None:
+                active_loss = attention_mask.view(-1) == 1
+                active_logits = logits.view(-1, self.num_labels)
+                active_labels = torch.where(
+                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+                )
+                loss = loss_fct(active_logits, active_labels)
+                return ((loss,) + outputs)
+            else:
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+                return loss
+        else:
+            return logits
+
+