Diff of /model_sentiment.py [000000] .. [66af30]

--- a
+++ b/model_sentiment.py
@@ -0,0 +1,249 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data
+from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BertForSequenceClassification
+import math
+import numpy as np
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+"""MLP baseline using sentence level eeg"""
+# using sent level EEG, MLP baseline for sentiment
+class BaselineMLPSentence(nn.Module):
+    def __init__(self, input_dim = 840, hidden_dim = 128, output_dim = 3):
+        super(BaselineMLPSentence, self).__init__()
+        self.fc1 = nn.Linear(input_dim, hidden_dim) 
+        self.relu1 = nn.ReLU()
+        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
+        self.relu2 = nn.ReLU()
+        self.fc3 = nn.Linear(hidden_dim, output_dim) # positive, negative, neutral  
+        self.dropout = nn.Dropout(0.25)
+
+    def forward(self, x):
+        out = self.fc1(x)
+        out = self.relu1(out)
+        out = self.fc2(out)
+        out = self.relu2(out)
+        out = self.dropout(out)
+        out = self.fc3(out)
+        return out
+
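+# A minimal, commented-out usage sketch (illustrative, not part of the original code):
+# it assumes a batch of sentence-level EEG features of shape (batch_size, 840); the
+# batch size of 4 and the random input are placeholders.
+#   model = BaselineMLPSentence()
+#   logits = model(torch.randn(4, 840))                  # -> (4, 3) sentiment logits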
+
+"""bidirectional LSTM baseline using word level eeg"""
+class BaselineLSTM(nn.Module):
+    def __init__(self, input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 1):
+        super(BaselineLSTM, self).__init__()
+        
+        self.hidden_dim = hidden_dim
+
+        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers = num_layers, batch_first = True, bidirectional = True)
+
+        self.hidden2sentiment = nn.Linear(hidden_dim*2, output_dim)
+
+    def forward(self, x_packed):
+        # input: a PackedSequence built from a (N, seq_len, input_dim) batch
+        # print(x_packed.data.size())
+        lstm_out, _ = self.lstm(x_packed)
+        # NOTE: this takes the last *padded* time step; for sequences shorter than the
+        # batch maximum, that position holds zero padding rather than the true final state
+        last_hidden_state = pad_packed_sequence(lstm_out, batch_first = True)[0][:,-1,:]
+        # print(last_hidden_state.size())
+        out = self.hidden2sentiment(last_hidden_state)
+        return out
+
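+# A commented-out usage sketch (illustrative, not part of the original code): word-level
+# EEG features are packed with pack_padded_sequence before the forward pass; the shapes
+# and lengths below are placeholders.
+#   model = BaselineLSTM()
+#   batch = torch.randn(2, 7, 840)                       # (N, max_seq_len, 840)
+#   lengths = torch.tensor([7, 5])                       # true lengths, sorted descending
+#   packed = pack_padded_sequence(batch, lengths, batch_first = True)
+#   logits = model(packed)                               # -> (2, 3) sentiment logits
+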
+""" Bert Baseline: Finetuning from a pretrained language model Bert"""
+class NaiveFineTunePretrainedBert(nn.Module):
+    def __init__(self, input_dim = 840, hidden_dim = 768, output_dim = 3, pretrained_checkpoint = None):
+        super(NaiveFineTunePretrainedBert, self).__init__()
+        # map the EEG feature dimension (840) to BERT's hidden size (768)
+        self.fc1 = nn.Linear(input_dim, hidden_dim)
+        self.pretrained_Bert = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3)
+        
+        if pretrained_checkpoint is not None:
+            self.pretrained_Bert.load_state_dict(torch.load(pretrained_checkpoint))
+
+    def forward(self, input_embeddings_batch, input_masks_batch, labels):
+        embedding = F.relu(self.fc1(input_embeddings_batch))
+        out = self.pretrained_Bert(inputs_embeds = embedding, attention_mask = input_masks_batch, labels = labels, return_dict = True)
+        return out
+
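+# A commented-out usage sketch (illustrative, not part of the original code): word-level
+# EEG features are projected to BERT's 768-dim hidden size and passed as inputs_embeds;
+# the shapes and labels below are placeholders.
+#   model = NaiveFineTunePretrainedBert()                # downloads bert-base-cased
+#   eeg = torch.randn(2, 32, 840)                        # (N, seq_len, 840)
+#   mask = torch.ones(2, 32, dtype = torch.long)         # 1 = attend, 0 = padding
+#   out = model(eeg, mask, torch.tensor([0, 2]))         # out.loss, out.logits: (2, 3)
+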
+""" Finetuning from a pretrained language model BART, two step training"""
+class FineTunePretrainedTwoStep(nn.Module):
+    def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):
+        super(FineTunePretrainedTwoStep, self).__init__()
+        
+        self.pretrained_layers = pretrained_layers
+        # additional transformer encoder, following the BART paper's strategy of adding a
+        # randomly initialized encoder in front of the pretrained model (as in its MT setup)
+        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead,  dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)
+        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)
+        
+        # NOTE: add positional embedding?
+        # print('[INFO]adding positional embedding')
+        # self.positional_embedding = PositionalEncoding(in_feature)
+
+        self.fc1 = nn.Linear(in_feature, d_model)
+
+    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, labels):
+        """input_embeddings_batch: batch_size*Seq_len*840"""
+        """input_mask: 1 is not masked, 0 is masked"""
+        """input_masks_invert: 1 is masked, 0 is not masked"""
+        """labels: sentitment labels 0,1,2"""
+        
+        # NOTE: add positional embedding?
+        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch) 
+
+        # use src_key_padding_masks
+        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask = input_masks_invert) 
+        # encoded_embedding = self.additional_encoder(input_embeddings_batch) 
+        
+        encoded_embedding = F.relu(self.fc1(encoded_embedding))
+        out = self.pretrained_layers(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = labels)                    
+        
+        return out
+
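+# A commented-out usage sketch (illustrative, not part of the original code). It assumes
+# `pretrained_layers` is a HuggingFace sequence-classification model with a 1024-dim
+# hidden size that accepts inputs_embeds (the concrete choice is made in the training
+# script, not here); the shapes below are placeholders.
+#   from transformers import RobertaForSequenceClassification
+#   clf = RobertaForSequenceClassification.from_pretrained('roberta-large', num_labels = 3)
+#   model = FineTunePretrainedTwoStep(clf)
+#   eeg = torch.randn(2, 32, 840)                        # (N, seq_len, 840)
+#   mask = torch.ones(2, 32, dtype = torch.long)         # 1 = not masked
+#   mask_invert = torch.zeros(2, 32, dtype = torch.bool) # True = masked (padding)
+#   out = model(eeg, mask, mask_invert, torch.tensor([0, 2]))   # out.loss, out.logits
+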
+""" Zero-shot sentiment discovery using a finetuned generation model and a sentiment model pretrained on text """
+class ZeroShotSentimentDiscovery(nn.Module):
+    def __init__(self, brain2text_translator, sentiment_classifier, translation_tokenizer, sentiment_tokenizer, device = 'cpu'):
+        # only for inference
+        super(ZeroShotSentimentDiscovery, self).__init__()
+        
+        self.brain2text_translator = brain2text_translator
+        self.sentiment_classifier = sentiment_classifier
+        self.translation_tokenizer = translation_tokenizer
+        self.sentiment_tokenizer = sentiment_tokenizer
+        self.device = device
+    
+
+    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, sentiment_labels):
+        """input_embeddings_batch: batch_size*Seq_len*840"""
+        """input_mask: 1 is not masked, 0 is masked"""
+        """input_masks_invert: 1 is masked, 0 is not masked"""
+        """labels: sentitment labels 0,1,2"""
+        
+        def logits2PredString(logits):
+            probs = logits[0].softmax(dim = 1)
+            # print('probs size:', probs.size())
+            values, predictions = probs.topk(1)
+            # print('predictions before squeeze:',predictions.size())
+            predictions = torch.squeeze(predictions)
+            predict_string = self.translation_tokenizer.decode(predictions)
+            return predict_string
+
+        # only works with batch size 1
+        assert input_embeddings_batch.size()[0] == 1
+
+        seq2seqLMoutput = self.brain2text_translator(input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted)
+        predict_string = logits2PredString(seq2seqLMoutput.logits)
+        predict_string = predict_string.split('</s></s>')[0]
+        predict_string = predict_string.replace('<s>','')
+        print('predict string:', predict_string)
+        re_tokenized = self.sentiment_tokenizer(predict_string, return_tensors='pt', return_attention_mask = True)
+        input_ids = re_tokenized['input_ids'].to(self.device) # batch = 1
+        attn_mask = re_tokenized['attention_mask'].to(self.device) # batch = 1
+
+        out = self.sentiment_classifier(input_ids = input_ids, attention_mask = attn_mask, return_dict = True, labels = sentiment_labels)
+
+        return out
+
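+# A commented-out usage sketch (illustrative, not part of the original code). It assumes a
+# trained brain-to-text translator whose forward takes (embeddings, masks, masks_invert,
+# target_ids) and returns an output with .logits, plus a text sentiment classifier and the
+# matching tokenizers; all concrete names below are placeholders.
+#   from transformers import BertTokenizer
+#   translator = ...                                     # trained brain-to-text generation model
+#   clf = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels = 3)
+#   zs = ZeroShotSentimentDiscovery(translator, clf,
+#                                   BartTokenizer.from_pretrained('facebook/bart-large'),
+#                                   BertTokenizer.from_pretrained('bert-base-cased'))
+#   out = zs(eeg, mask, mask_invert, target_ids, sentiment_labels)   # batch size must be 1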
+
+""" Miscellaneous: jointly learn generation and classification (not working well) """
+class BartClassificationHead(nn.Module):
+    # from transformers: https://huggingface.co/transformers/_modules/transformers/models/bart/modeling_bart.html
+    """Head for sentence-level classification tasks."""
+    def __init__(
+        self,
+        input_dim: int,
+        inner_dim: int,
+        num_classes: int,
+        pooler_dropout: float,
+    ):
+        super().__init__()
+        self.dense = nn.Linear(input_dim, inner_dim)
+        self.dropout = nn.Dropout(p=pooler_dropout)
+        self.out_proj = nn.Linear(inner_dim, num_classes)
+
+    def forward(self, hidden_states: torch.Tensor):
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.dense(hidden_states)
+        hidden_states = torch.tanh(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.out_proj(hidden_states)
+        return hidden_states
+
+class JointBrainTranslatorSentimentClassifier(nn.Module):
+    def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048, num_labels = 3):
+        super(JointBrainTranslatorSentimentClassifier, self).__init__()
+        
+        self.pretrained_generator = pretrained_layers
+        # additional transformer encoder, following the BART paper's strategy of adding a
+        # randomly initialized encoder in front of the pretrained model (as in its MT setup)
+        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead,  dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)
+        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)
+        self.fc1 = nn.Linear(in_feature, d_model)
+        self.num_labels = num_labels
+
+        self.pooler = Pooler(d_model)
+        self.classifier = BartClassificationHead(input_dim = d_model, inner_dim = d_model, num_classes = num_labels, pooler_dropout = pretrained_layers.config.classifier_dropout)
+
+    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, sentiment_labels):
+        """input_embeddings_batch: batch_size*Seq_len*840"""
+        """input_mask: 1 is not masked, 0 is masked"""
+        """input_masks_invert: 1 is masked, 0 is not masked"""
+        
+        # NOTE: add positional embedding?
+        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch) 
+
+        # use src_key_padding_masks
+        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask = input_masks_invert) 
+        
+        # encoded_embedding = self.additional_encoder(input_embeddings_batch) 
+        encoded_embedding = F.relu(self.fc1(encoded_embedding))
+        LMoutput = self.pretrained_generator(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = target_ids_batch_converted, output_hidden_states = True)                    
+        hidden_states = LMoutput.decoder_hidden_states # tuple of (N, seq_len, hidden_dim) tensors, one per decoder layer (plus embeddings)
+        # print('hidden states len:', len(hidden_states))
+        last_hidden_states = hidden_states[-1]
+        # print('last hidden states size:', last_hidden_states.size())
+        sentence_representation = self.pooler(last_hidden_states)
+ 
+        classification_logits = self.classifier(sentence_representation) 
+        loss_fct = nn.CrossEntropyLoss()
+        classification_loss = loss_fct(classification_logits.view(-1, self.num_labels), sentiment_labels.view(-1))
+        classification_output = {'loss':classification_loss,'logits':classification_logits}
+        # print('successful one forward!!!!')
+        return LMoutput, classification_output
+
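+# A commented-out usage sketch (illustrative, not part of the original code). It assumes
+# `pretrained_layers` is a BartForConditionalGeneration, which matches the
+# decoder_hidden_states and config.classifier_dropout accesses above; shapes and ids are
+# placeholders.
+#   bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
+#   model = JointBrainTranslatorSentimentClassifier(bart)
+#   eeg = torch.randn(2, 32, 840)                        # (N, seq_len, 840)
+#   mask = torch.ones(2, 32, dtype = torch.long)         # 1 = not masked
+#   mask_invert = torch.zeros(2, 32, dtype = torch.bool) # True = masked (padding)
+#   target_ids = torch.randint(0, bart.config.vocab_size, (2, 20))
+#   lm_out, clf_out = model(eeg, mask, mask_invert, target_ids, torch.tensor([0, 2]))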
+
+""" helper modules """
+# modified from BertPooler
+class Pooler(nn.Module):
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
+class PositionalEncoding(nn.Module):
+
+    def __init__(self, d_model, dropout=0.1, max_len=5000):
+        super(PositionalEncoding, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).transpose(0, 1)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        # print('[DEBUG] input size:', x.size())
+        # print('[DEBUG] positional embedding size:', self.pe.size())
+        x = x + self.pe[:x.size(0), :]
+        # print('[DEBUG] output x with pe size:', x.size())
+        return self.dropout(x)
+
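+# A commented-out usage sketch (illustrative, not part of the original code): as in the
+# PyTorch tutorial this module comes from, it expects (seq_len, batch_size, d_model)
+# input, i.e. it is NOT batch_first, unlike the TransformerEncoder layers above.
+#   pos_enc = PositionalEncoding(d_model = 840)
+#   x = pos_enc(torch.randn(32, 4, 840))                 # (seq_len, batch, d_model)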