--- a
+++ b/model_sentiment.py
@@ -0,0 +1,249 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data
+from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BertForSequenceClassification
+import math
+import numpy as np
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+"""MLP baseline for sentiment classification using sentence-level EEG"""
+class BaselineMLPSentence(nn.Module):
+    def __init__(self, input_dim = 840, hidden_dim = 128, output_dim = 3):
+        super(BaselineMLPSentence, self).__init__()
+        self.fc1 = nn.Linear(input_dim, hidden_dim)
+        self.relu1 = nn.ReLU()
+        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+        self.relu2 = nn.ReLU()
+        self.fc3 = nn.Linear(hidden_dim, output_dim)  # positive, negative, neutral
+        self.dropout = nn.Dropout(0.25)
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.relu1(out)
+        out = self.fc2(out)
+        out = self.relu2(out)
+        out = self.dropout(out)
+        out = self.fc3(out)
+        return out
+
+
+"""Bidirectional LSTM baseline using word-level EEG"""
+class BaselineLSTM(nn.Module):
+    def __init__(self, input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 1):
+        super(BaselineLSTM, self).__init__()
+
+        self.hidden_dim = hidden_dim
+
+        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers = num_layers, batch_first = True, bidirectional = True)
+
+        self.hidden2sentiment = nn.Linear(hidden_dim*2, output_dim)
+
+    def forward(self, x_packed):
+        # input: PackedSequence built from (N, seq_len, input_dim)
+        _, (hidden, _) = self.lstm(x_packed)
+        # concatenate the final forward and backward hidden states of the last layer;
+        # indexing the last padded time step instead would return padding zeros for shorter sequences
+        last_hidden_state = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim = 1)
+        out = self.hidden2sentiment(last_hidden_state)
+        return out
+
+
+"""BERT baseline: fine-tuning a pretrained BERT classifier on projected EEG features"""
+class NaiveFineTunePretrainedBert(nn.Module):
+    def __init__(self, input_dim = 840, hidden_dim = 768, output_dim = 3, pretrained_checkpoint = None):
+        super(NaiveFineTunePretrainedBert, self).__init__()
+        # map EEG features to the hidden-state dimension expected by BERT
+        self.fc1 = nn.Linear(input_dim, hidden_dim)
+        self.pretrained_Bert = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels = output_dim)
+
+        if pretrained_checkpoint is not None:
+            self.pretrained_Bert.load_state_dict(torch.load(pretrained_checkpoint))
+
+    def forward(self, input_embeddings_batch, input_masks_batch, labels):
+        embedding = F.relu(self.fc1(input_embeddings_batch))
+        out = self.pretrained_Bert(inputs_embeds = embedding, attention_mask = input_masks_batch, labels = labels, return_dict = True)
+        return out
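+
+
+# Usage sketch (illustrative only, not part of the original training code): shows how the MLP and
+# LSTM baselines above might be driven with dummy tensors. The batch sizes, sequence lengths, and
+# the 840-dim EEG feature size are assumptions taken from the default arguments; `_demo_baselines`
+# is a hypothetical helper and is never called by this module.
+def _demo_baselines():
+    # sentence-level MLP baseline: one 840-dim feature vector per sentence
+    mlp = BaselineMLPSentence()
+    sent_feats = torch.randn(4, 840)
+    mlp_logits = mlp(sent_feats)                      # (4, 3) sentiment logits
+
+    # word-level LSTM baseline: zero-pad per-word EEG features, then pack them
+    lstm = BaselineLSTM()
+    seq_lens = torch.tensor([7, 5, 3])
+    word_feats = torch.zeros(3, 7, 840)               # (N, max_seq_len, input_dim), zero-padded
+    packed = pack_padded_sequence(word_feats, seq_lens, batch_first = True, enforce_sorted = False)
+    lstm_logits = lstm(packed)                        # (3, 3) sentiment logits
+    return mlp_logits, lstm_logits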
+
+
+"""Fine-tuning a pretrained language model (BART) with two-step training"""
+class FineTunePretrainedTwoStep(nn.Module):
+    def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024, additional_encoder_nhead = 8, additional_encoder_dim_feedforward = 2048):
+        super(FineTunePretrainedTwoStep, self).__init__()
+
+        self.pretrained_layers = pretrained_layers
+        # additional transformer encoder applied to the EEG features before the pretrained model
+        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead, dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)
+        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)
+
+        # NOTE: add positional embedding?
+        # self.positional_embedding = PositionalEncoding(in_feature)
+
+        self.fc1 = nn.Linear(in_feature, d_model)
+
+    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, labels):
+        """
+        input_embeddings_batch: batch_size * seq_len * 840
+        input_masks_batch: 1 is not masked, 0 is masked
+        input_masks_invert: 1 is masked, 0 is not masked
+        labels: sentiment labels 0, 1, 2
+        """
+
+        # NOTE: add positional embedding?
+        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch)
+
+        # use src_key_padding_mask to ignore padded positions in the additional encoder
+        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask = input_masks_invert)
+
+        encoded_embedding = F.relu(self.fc1(encoded_embedding))
+        out = self.pretrained_layers(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = labels)
+
+        return out
+
+
+"""Zero-shot sentiment discovery using a finetuned brain-to-text generation model and a sentiment model pretrained on text"""
+class ZeroShotSentimentDiscovery(nn.Module):
+    def __init__(self, brain2text_translator, sentiment_classifier, translation_tokenizer, sentiment_tokenizer, device = 'cpu'):
+        # only for inference
+        super(ZeroShotSentimentDiscovery, self).__init__()
+
+        self.brain2text_translator = brain2text_translator
+        self.sentiment_classifier = sentiment_classifier
+        self.translation_tokenizer = translation_tokenizer
+        self.sentiment_tokenizer = sentiment_tokenizer
+        self.device = device
+
+    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, sentiment_labels):
+        """
+        input_embeddings_batch: batch_size * seq_len * 840
+        input_masks_batch: 1 is not masked, 0 is masked
+        input_masks_invert: 1 is masked, 0 is not masked
+        sentiment_labels: sentiment labels 0, 1, 2
+        """
+
+        def logits2PredString(logits):
+            # greedy decoding: pick the most likely token at every position
+            probs = logits[0].softmax(dim = 1)
+            values, predictions = probs.topk(1)
+            predictions = torch.squeeze(predictions)
+            predict_string = self.translation_tokenizer.decode(predictions)
+            return predict_string
+
+        # only works with batch size 1
+        assert input_embeddings_batch.size()[0] == 1
+
+        seq2seqLMoutput = self.brain2text_translator(input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted)
+        predict_string = logits2PredString(seq2seqLMoutput.logits)
+        # strip BART special tokens from the decoded string
+        predict_string = predict_string.split('</s></s>')[0]
+        predict_string = predict_string.replace('<s>', '')
+        print('predict string:', predict_string)
+
+        # re-tokenize the predicted sentence for the text-based sentiment classifier
+        re_tokenized = self.sentiment_tokenizer(predict_string, return_tensors='pt', return_attention_mask = True)
+        input_ids = re_tokenized['input_ids'].to(self.device)        # batch = 1
+        attn_mask = re_tokenized['attention_mask'].to(self.device)   # batch = 1
+
+        out = self.sentiment_classifier(input_ids = input_ids, attention_mask = attn_mask, return_dict = True, labels = sentiment_labels)
+
+        return out
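+
+
+# Assembly sketch (illustrative only): one way the zero-shot pipeline above might be wired together
+# at inference time. `trained_translator` and `trained_classifier` stand for checkpoints produced
+# elsewhere (a brain-to-text generator and a text-only sentiment classifier); the tokenizer
+# checkpoint names below are assumptions and `_build_zero_shot_pipeline` is a hypothetical helper.
+def _build_zero_shot_pipeline(trained_translator, trained_classifier, device = 'cpu'):
+    from transformers import BertTokenizer
+    translation_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
+    sentiment_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+    pipeline = ZeroShotSentimentDiscovery(
+        brain2text_translator = trained_translator,
+        sentiment_classifier = trained_classifier,
+        translation_tokenizer = translation_tokenizer,
+        sentiment_tokenizer = sentiment_tokenizer,
+        device = device,
+    )
+    return pipeline.to(device).eval()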
+
+
+"""Miscellaneous: jointly learn generation and classification (not working well)"""
+class BartClassificationHead(nn.Module):
+    # from transformers: https://huggingface.co/transformers/_modules/transformers/models/bart/modeling_bart.html
+    """Head for sentence-level classification tasks."""
+    def __init__(
+        self,
+        input_dim: int,
+        inner_dim: int,
+        num_classes: int,
+        pooler_dropout: float,
+    ):
+        super().__init__()
+        self.dense = nn.Linear(input_dim, inner_dim)
+        self.dropout = nn.Dropout(p=pooler_dropout)
+        self.out_proj = nn.Linear(inner_dim, num_classes)
+
+    def forward(self, hidden_states: torch.Tensor):
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dense(hidden_states)
+        hidden_states = torch.tanh(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.out_proj(hidden_states)
+        return hidden_states
+
+
+class JointBrainTranslatorSentimentClassifier(nn.Module):
+    def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024, additional_encoder_nhead = 8, additional_encoder_dim_feedforward = 2048, num_labels = 3):
+        super(JointBrainTranslatorSentimentClassifier, self).__init__()
+
+        self.pretrained_generator = pretrained_layers
+        # additional transformer encoder applied to the EEG features before the pretrained generator
+        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead, dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)
+        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)
+        self.fc1 = nn.Linear(in_feature, d_model)
+        self.num_labels = num_labels
+
+        self.pooler = Pooler(d_model)
+        self.classifier = BartClassificationHead(input_dim = d_model, inner_dim = d_model, num_classes = num_labels, pooler_dropout = pretrained_layers.config.classifier_dropout)
+
+    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, sentiment_labels):
+        """
+        input_embeddings_batch: batch_size * seq_len * 840
+        input_masks_batch: 1 is not masked, 0 is masked
+        input_masks_invert: 1 is masked, 0 is not masked
+        sentiment_labels: sentiment labels 0, 1, 2
+        """
+
+        # NOTE: add positional embedding?
+        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch)
+
+        # use src_key_padding_mask to ignore padded positions in the additional encoder
+        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask = input_masks_invert)
+        encoded_embedding = F.relu(self.fc1(encoded_embedding))
+
+        # generation branch: seq2seq LM loss from the pretrained generator
+        LMoutput = self.pretrained_generator(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = target_ids_batch_converted, output_hidden_states = True)
+
+        # classification branch: pool the last decoder hidden states and classify
+        hidden_states = LMoutput.decoder_hidden_states  # tuple of (N, seq_len, hidden_dim) tensors
+        last_hidden_states = hidden_states[-1]
+        sentence_representation = self.pooler(last_hidden_states)
+
+        classification_logits = self.classifier(sentence_representation)
+        loss_fct = nn.CrossEntropyLoss()
+        classification_loss = loss_fct(classification_logits.view(-1, self.num_labels), sentiment_labels.view(-1))
+        classification_output = {'loss': classification_loss, 'logits': classification_logits}
+
+        return LMoutput, classification_output
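+
+
+# Training sketch (illustrative only): one way the joint model's two outputs might be combined
+# into a single objective. The loss weight `alpha`, the batch layout, and the helper name
+# `_joint_training_step` are assumptions, not part of the original training code.
+def _joint_training_step(model, batch, optimizer, alpha = 0.5):
+    input_embeddings, input_masks, input_masks_invert, target_ids, sentiment_labels = batch
+    LMoutput, classification_output = model(input_embeddings, input_masks, input_masks_invert, target_ids, sentiment_labels)
+    # weighted sum of the generation loss and the classification loss
+    loss = alpha * LMoutput.loss + (1 - alpha) * classification_output['loss']
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+    return loss.item()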
+
+
+"""Helper modules"""
+# modified from BertPooler
+class Pooler(nn.Module):
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
+class PositionalEncoding(nn.Module):
+
+    def __init__(self, d_model, dropout=0.1, max_len=5000):
+        super(PositionalEncoding, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).transpose(0, 1)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        # expects x of shape (seq_len, batch_size, d_model); positions are added along dim 0
+        x = x + self.pe[:x.size(0), :]
+        return self.dropout(x)
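+
+
+# Usage sketch (illustrative only): the tutorial-style PositionalEncoding above adds positions
+# along dim 0, i.e. it expects (seq_len, batch_size, d_model) input, while the additional
+# transformer encoders in this file use batch_first=True. If the commented-out
+# "add positional embedding?" option were enabled, the batch would need to be transposed around
+# the call, as in this hypothetical helper.
+def _add_positional_encoding_batch_first(pos_encoder, batch_first_embeddings):
+    # (N, seq_len, d_model) -> (seq_len, N, d_model)
+    seq_first = batch_first_embeddings.transpose(0, 1)
+    seq_first = pos_encoder(seq_first)
+    # back to (N, seq_len, d_model)
+    return seq_first.transpose(0, 1)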