# coding=utf-8
"""Token-classification models for biomedical named-entity recognition.

Three models are defined:

* ``BERTMultiNER2``    - BERT encoder with one classification head per
  entity type.
* ``RoBERTaMultiNER2`` - the same head layout on top of a RoBERTa encoder.
* ``NER``              - BERT encoder with a single classification head.
"""
import os
import pdb
import copy
import torch
import torch.nn.functional as F
from torch import nn

from torch.nn import CrossEntropyLoss
from transformers import (
    BertConfig,
    BertModel,
    RobertaModel,
    BertForTokenClassification,
    BertTokenizer,
    RobertaConfig,
    RobertaForTokenClassification,
    RobertaTokenizer,
    AutoTokenizer,
)

# Attribute prefixes of the per-entity heads. The position in this tuple,
# counted from 1, is the entity type id expected in ``entity_type_ids``:
#   1 disease, 2 chemical, 3 gene/protein, 4 species,
#   5 cell line, 6 dna, 7 rna, 8 cell type.
# Seven further clinical types are currently disabled: biological structure
# (9), diagnostic procedure (10), duration (11), date (12), therapeutic
# procedure (13), sign/symptom (14) and lab value (15).
_ENTITY_HEADS = (
    "dise",
    "chem",
    "gene",
    "spec",
    "cellline",
    "dna",
    "rna",
    "celltype",
)


def _add_entity_heads(model, hidden_size, num_labels):
    """Attach the per-entity layers to ``model``.

    For every prefix in ``_ENTITY_HEADS`` two layers are registered:
    ``<prefix>_classifier`` (hidden_size -> num_labels) and
    ``<prefix>_classifier_2`` (hidden_size -> hidden_size). All
    ``*_classifier`` layers are created before any ``*_classifier_2`` layer
    so that the module registration order stays stable.
    """
    for prefix in _ENTITY_HEADS:
        setattr(model, prefix + "_classifier", nn.Linear(hidden_size, num_labels))
    for prefix in _ENTITY_HEADS:
        setattr(model, prefix + "_classifier_2", nn.Linear(hidden_size, hidden_size))


def _sum_in_order(tensors):
    """Add the tensors of a non-empty sequence from left to right."""
    total = tensors[0]
    for tensor in tensors[1:]:
        total = total + tensor
    return total


def _flatten_labels(labels, attention_mask, ignore_index):
    """Flatten ``labels`` and blank out positions that are not attended.

    Positions whose ``attention_mask`` value is not 1 are replaced with
    ``ignore_index`` so that the loss skips them. Without a mask the labels
    are only flattened.
    """
    flat_labels = labels.view(-1)
    if attention_mask is None:
        return flat_labels
    attended = attention_mask.view(-1) == 1
    return torch.where(
        attended, flat_labels, torch.tensor(ignore_index).type_as(labels)
    )


def _multi_entity_forward(model, encoder, input_ids, token_type_ids,
                          attention_mask, labels, entity_type_ids):
    """Shared forward pass of the multi-head NER models.

    Args:
        model: module owning ``dropout``, ``num_labels`` and the layers
            created by ``_add_entity_heads``.
        encoder: the transformer backbone (``model.bert`` / ``model.roberta``).
        input_ids, token_type_ids, attention_mask: forwarded to ``encoder``.
        labels: optional token labels; when given a loss is computed.
        entity_type_ids: tensor indexed as ``[batch][position]``. If the very
            first element is 0 the whole batch is handled in raw-text mode
            (every head sees every token); otherwise each head only sees the
            tokens tagged with its own type id.

    Returns:
        * ``labels is None``: the logits only - a tuple with one tensor per
          head in raw-text mode, a single summed tensor otherwise.
        * ``labels`` and ``attention_mask`` given:
          ``(loss, logits, sequence_output)``.
        * ``labels`` given without ``attention_mask``: the bare loss.

    Raises:
        ValueError: if ``entity_type_ids`` is None.
    """
    if entity_type_ids is None:
        raise ValueError(
            "entity_type_ids is required: pass a (batch, seq_len) tensor of "
            "entity type ids (a leading 0 selects raw-text mode)"
        )

    sequence_output = encoder(
        input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
        head_mask=None,
    )[0]
    sequence_output = model.dropout(sequence_output)

    # Only the first token of the first sample decides the mode for the
    # whole batch.
    raw_text_mode = entity_type_ids[0][0].item() == 0

    head_hidden = []
    head_logits = []
    for type_id, prefix in enumerate(_ENTITY_HEADS, start=1):
        head_input = sequence_output
        if not raw_text_mode:
            # Keep the id where it matches this head and zero elsewhere.
            # NOTE: the surviving gate value is the type id itself (1..8),
            # not 1, so each head's input is scaled by its id. This is kept
            # on purpose; presumably existing checkpoints rely on it.
            gate = torch.where(
                entity_type_ids == type_id,
                entity_type_ids,
                torch.zeros_like(entity_type_ids),
            )
            head_input = gate.unsqueeze(-1) * sequence_output
        projection = getattr(model, prefix + "_classifier_2")
        classifier = getattr(model, prefix + "_classifier")
        hidden = F.relu(projection(head_input))
        head_hidden.append(hidden)
        head_logits.append(classifier(hidden))

    sequence_output = _sum_in_order(head_hidden)
    if raw_text_mode:
        logits = tuple(head_logits)
    else:
        logits = _sum_in_order(head_logits)

    if labels is None:
        return logits

    loss_fct = CrossEntropyLoss()
    targets = _flatten_labels(labels, attention_mask, loss_fct.ignore_index)
    # Raw-text mode scores every head against the same labels and adds the
    # losses; otherwise there is a single loss on the summed logits.
    scored = head_logits if raw_text_mode else [logits]
    loss = _sum_in_order(
        [loss_fct(item.view(-1, model.num_labels), targets) for item in scored]
    )

    if attention_mask is None:
        return loss
    return (loss, logits, sequence_output)


class BERTMultiNER2(BertForTokenClassification):
    """BERT token classifier with one head per entity type.

    The heads are exposed as ``<prefix>_classifier`` and
    ``<prefix>_classifier_2`` for every prefix in ``_ENTITY_HEADS``.
    """

    def __init__(self, config, num_labels=3):
        super(BERTMultiNER2, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        _add_entity_heads(self, config.hidden_size, self.num_labels)

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                labels=None, entity_type_ids=None):
        """Run the encoder and the entity heads.

        See ``_multi_entity_forward`` for the meaning of the arguments and
        the shape of the return value.
        """
        return _multi_entity_forward(
            self, self.bert, input_ids, token_type_ids,
            attention_mask, labels, entity_type_ids,
        )


class RoBERTaMultiNER2(RobertaForTokenClassification):
    """RoBERTa token classifier with one head per entity type.

    The heads are exposed as ``<prefix>_classifier`` and
    ``<prefix>_classifier_2`` for every prefix in ``_ENTITY_HEADS``.
    """

    def __init__(self, config, num_labels=3):
        super(RoBERTaMultiNER2, self).__init__(config)
        self.num_labels = num_labels
        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        _add_entity_heads(self, config.hidden_size, self.num_labels)

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                labels=None, entity_type_ids=None):
        """Run the encoder and the entity heads.

        See ``_multi_entity_forward`` for the meaning of the arguments and
        the shape of the return value.
        """
        return _multi_entity_forward(
            self, self.roberta, input_ids, token_type_ids,
            attention_mask, labels, entity_type_ids,
        )


class NER(BertForTokenClassification):
    """BERT token classifier with a single classification head."""

    def __init__(self, config, num_labels=3):
        super(NER, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        """Classify every token of ``input_ids``.

        Returns:
            * ``labels is None``: the logits tensor.
            * ``labels`` and ``attention_mask`` given:
              ``(loss, logits, sequence_output)``.
            * ``labels`` given without ``attention_mask``: the bare loss.
        """
        # Keyword arguments are required here: the encoder's positional
        # order is (input_ids, attention_mask, token_type_ids), so passing
        # token_type_ids and attention_mask positionally would swap them.
        sequence_output = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            head_mask=None,
        )[0]
        sequence_output = self.dropout(sequence_output)

        logits = self.classifier(sequence_output)
        if labels is None:
            return logits

        loss_fct = CrossEntropyLoss()
        targets = _flatten_labels(labels, attention_mask, loss_fct.ignore_index)
        loss = loss_fct(logits.view(-1, self.num_labels), targets)

        if attention_mask is None:
            return loss
        return (loss, logits, sequence_output)