# model_sentiment.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BertForSequenceClassification
import math
import numpy as np
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

"""MLP baseline using sentence level eeg"""
10
# using sent level EEG, MLP baseline for sentiment
11
class BaselineMLPSentence(nn.Module):
12
    def __init__(self, input_dim = 840, hidden_dim = 128, output_dim = 3):
13
        super(BaselineMLPSentence, self).__init__()
14
        self.fc1 = nn.Linear(input_dim, hidden_dim) 
15
        self.relu1 = nn.ReLU()
16
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
17
        self.relu2 = nn.ReLU()
18
        self.fc3 = nn.Linear(hidden_dim, output_dim) # positive, negative, neutral  
19
        self.dropout = nn.Dropout(0.25)
20
21
    def forward(self, x):
22
        out = self.fc1(x)
23
        out = self.relu1(out)
24
        out = self.fc2(out)
25
        out = self.relu2(out)
26
        out = self.dropout(out)
27
        out = self.fc3(out)
28
        return out
29
30
31
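# Hedged usage sketch (not part of the original file): how the MLP baseline might be called
# on a dummy batch of sentence-level EEG features. The batch size and the helper name
# `_example_mlp_forward` are illustrative assumptions.
def _example_mlp_forward():
    model = BaselineMLPSentence()         # input_dim=840, 3-way sentiment output
    dummy_eeg = torch.randn(4, 840)       # hypothetical batch of 4 sentence-level EEG vectors
    logits = model(dummy_eeg)             # -> shape (4, 3), one logit per sentiment class
    return logits
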
"""bidirectional LSTM baseline using word level eeg"""
32
class BaselineLSTM(nn.Module):
33
    def __init__(self, input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 1):
34
        super(BaselineLSTM, self).__init__()
35
        
36
        self.hidden_dim = hidden_dim
37
38
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers = 1, batch_first = True, bidirectional = True)
39
40
        self.hidden2sentiment = nn.Linear(hidden_dim*2, output_dim)
41
42
    def forward(self, x_packed):
43
        # input: (N,seq_len,input_dim)
44
        # print(x_packed.data.size())
45
        lstm_out, _ = self.lstm(x_packed)
46
        last_hidden_state = pad_packed_sequence(lstm_out, batch_first = True)[0][:,-1,:]
47
        # print(last_hidden_state.size())
48
        out = self.hidden2sentiment(last_hidden_state)
49
        return out
50
51
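# Hedged usage sketch (not part of the original file): feeding variable-length word-level EEG
# sequences to the LSTM baseline via pack_padded_sequence. Shapes, lengths, and the helper name
# are illustrative assumptions.
def _example_lstm_forward():
    model = BaselineLSTM()
    dummy_batch = torch.randn(2, 10, 840)     # 2 sequences padded to length 10
    lengths = torch.tensor([10, 7])           # true lengths before padding (sorted, descending)
    packed = pack_padded_sequence(dummy_batch, lengths, batch_first=True, enforce_sorted=True)
    logits = model(packed)                    # -> shape (2, 3)
    return logits
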
""" Bert Baseline: Finetuning from a pretrained language model Bert"""
52
class NaiveFineTunePretrainedBert(nn.Module):
53
    def __init__(self, input_dim = 840, hidden_dim = 768, output_dim = 3, pretrained_checkpoint = None):
54
        super(NaiveFineTunePretrainedBert, self).__init__()
55
        # mapping hidden states dimensioin
56
        self.fc1 = nn.Linear(input_dim, hidden_dim)
57
        self.pretrained_Bert = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3)
58
        
59
        if pretrained_checkpoint is not None:
60
            self.pretrained_Bert.load_state_dict(torch.load(pretrained_checkpoint))
61
62
    def forward(self, input_embeddings_batch, input_masks_batch, labels):
63
        embedding = F.relu(self.fc1(input_embeddings_batch))
64
        out = self.pretrained_Bert(inputs_embeds = embedding, attention_mask = input_masks_batch, labels = labels, return_dict = True)
65
        return out
66
67
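# Hedged usage sketch (not part of the original file): the EEG features are first projected to
# Bert's 768-dim embedding space and then passed through `inputs_embeds`. Shapes are illustrative
# assumptions; calling this downloads the 'bert-base-cased' weights.
def _example_bert_baseline_forward():
    model = NaiveFineTunePretrainedBert()
    eeg = torch.randn(2, 16, 840)               # (batch, seq_len, 840) word-level EEG features
    mask = torch.ones(2, 16, dtype=torch.long)  # 1 = attend, 0 = padding
    labels = torch.tensor([0, 2])               # sentiment labels in {0, 1, 2}
    out = model(eeg, mask, labels)              # SequenceClassifierOutput with .loss and .logits
    return out.loss, out.logits
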
""" Finetuning from a pretrained language model BART, two step training"""
68
class FineTunePretrainedTwoStep(nn.Module):
69
    def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048):
70
        super(FineTunePretrainedTwoStep, self).__init__()
71
        
72
        self.pretrained_layers = pretrained_layers
73
        # additional transformer encoder, following BART paper about 
74
        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead,  dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)
75
        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)
76
        
77
        # NOTE: add positional embedding?
78
        # print('[INFO]adding positional embedding')
79
        # self.positional_embedding = PositionalEncoding(in_feature)
80
81
        self.fc1 = nn.Linear(in_feature, d_model)
82
83
    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, labels):
84
        """input_embeddings_batch: batch_size*Seq_len*840"""
85
        """input_mask: 1 is not masked, 0 is masked"""
86
        """input_masks_invert: 1 is masked, 0 is not masked"""
87
        """labels: sentitment labels 0,1,2"""
88
        
89
        # NOTE: add positional embedding?
90
        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch) 
91
92
        # use src_key_padding_masks
93
        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask = input_masks_invert) 
94
        # encoded_embedding = self.additional_encoder(input_embeddings_batch) 
95
        
96
        encoded_embedding = F.relu(self.fc1(encoded_embedding))
97
        out = self.pretrained_layers(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = labels)                    
98
        
99
        return out
100
101
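# Hedged usage sketch (not part of the original file): `pretrained_layers` is assumed here to be a
# Hugging Face sequence-classification model whose hidden size matches d_model=1024 and which
# accepts `inputs_embeds`, e.g. BertForSequenceClassification on 'bert-large-cased'; the original
# training script may construct it differently.
def _example_two_step_forward():
    pretrained = BertForSequenceClassification.from_pretrained('bert-large-cased', num_labels=3)
    model = FineTunePretrainedTwoStep(pretrained)
    eeg = torch.randn(1, 16, 840)                        # (batch, seq_len, 840) word-level EEG features
    mask = torch.ones(1, 16, dtype=torch.long)           # 1 = not masked (Hugging Face convention)
    mask_invert = torch.zeros(1, 16, dtype=torch.bool)   # True = masked (nn.TransformerEncoder convention)
    labels = torch.tensor([1])                           # sentiment label in {0, 1, 2}
    out = model(eeg, mask, mask_invert, labels)          # output with .loss and .logits
    return out.loss, out.logits
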
""" Zero-shot sentiment discovery using a finetuned generation model and a sentiment model pretrained on text """
102
class ZeroShotSentimentDiscovery(nn.Module):
103
    def __init__(self, brain2text_translator, sentiment_classifier, translation_tokenizer, sentiment_tokenizer, device = 'cpu'):
104
        # only for inference
105
        super(ZeroShotSentimentDiscovery, self).__init__()
106
        
107
        self.brain2text_translator = brain2text_translator
108
        self.sentiment_classifier = sentiment_classifier
109
        self.translation_tokenizer = translation_tokenizer
110
        self.sentiment_tokenizer = sentiment_tokenizer
111
        self.device = device
112
    
113
114
    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, sentiment_labels):
115
        """input_embeddings_batch: batch_size*Seq_len*840"""
116
        """input_mask: 1 is not masked, 0 is masked"""
117
        """input_masks_invert: 1 is masked, 0 is not masked"""
118
        """labels: sentitment labels 0,1,2"""
119
        
120
        def logits2PredString(logits):
121
            probs = logits[0].softmax(dim = 1)
122
            # print('probs size:', probs.size())
123
            values, predictions = probs.topk(1)
124
            # print('predictions before squeeze:',predictions.size())
125
            predictions = torch.squeeze(predictions)
126
            predict_string = self.translation_tokenizer.decode(predictions)
127
            return predict_string
128
129
        # only works on batch is one
130
        assert input_embeddings_batch.size()[0] == 1
131
132
        seq2seqLMoutput = self.brain2text_translator(input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted)
133
        predict_string = logits2PredString(seq2seqLMoutput.logits)
134
        predict_string = predict_string.split('</s></s>')[0]
135
        predict_string = predict_string.replace('<s>','')
136
        print('predict string:', predict_string)
137
        re_tokenized = self.sentiment_tokenizer(predict_string, return_tensors='pt', return_attention_mask = True)
138
        input_ids = re_tokenized['input_ids'].to(self.device) # batch = 1
139
        attn_mask = re_tokenized['attention_mask'].to(self.device) # batch = 1
140
141
        out = self.sentiment_classifier(input_ids = input_ids, attention_mask = attn_mask, return_dict = True, labels = sentiment_labels)
142
143
        return out
144
145
146
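# Hedged usage sketch (not part of the original file): wiring the zero-shot pipeline from an
# already-trained brain-to-text translator (defined elsewhere in this repo) plus a text sentiment
# classifier. The checkpoint names here are illustrative assumptions.
def _example_zero_shot_pipeline(brain2text_translator):
    from transformers import BertTokenizer
    translation_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    sentiment_model = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=3)
    sentiment_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    return ZeroShotSentimentDiscovery(
        brain2text_translator, sentiment_model, translation_tokenizer, sentiment_tokenizer, device='cpu')
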
""" Miscellaneous: jointly learn generation and classification (not working well) """
147
class BartClassificationHead(nn.Module):
148
    # from transformers: https://huggingface.co/transformers/_modules/transformers/models/bart/modeling_bart.html
149
    """Head for sentence-level classification tasks."""
150
    def __init__(
151
        self,
152
        input_dim: int,
153
        inner_dim: int,
154
        num_classes: int,
155
        pooler_dropout: float,
156
    ):
157
        super().__init__()
158
        self.dense = nn.Linear(input_dim, inner_dim)
159
        self.dropout = nn.Dropout(p=pooler_dropout)
160
        self.out_proj = nn.Linear(inner_dim, num_classes)
161
162
    def forward(self, hidden_states: torch.Tensor):
163
        hidden_states = self.dropout(hidden_states)
164
        hidden_states = self.dense(hidden_states)
165
        hidden_states = torch.tanh(hidden_states)
166
        hidden_states = self.dropout(hidden_states)
167
        hidden_states = self.out_proj(hidden_states)
168
        return hidden_states
169
170
class JointBrainTranslatorSentimentClassifier(nn.Module):
    def __init__(self, pretrained_layers, in_feature = 840, d_model = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048, num_labels = 3):
        super(JointBrainTranslatorSentimentClassifier, self).__init__()

        self.pretrained_generator = pretrained_layers
        # additional transformer encoder, following the BART paper
        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead, dim_feedforward = additional_encoder_dim_feedforward, batch_first=True)
        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)
        self.fc1 = nn.Linear(in_feature, d_model)
        self.num_labels = num_labels

        self.pooler = Pooler(d_model)
        self.classifier = BartClassificationHead(input_dim = d_model, inner_dim = d_model, num_classes = num_labels, pooler_dropout = pretrained_layers.config.classifier_dropout)

    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted, sentiment_labels):
        """
        input_embeddings_batch: batch_size * seq_len * 840
        input_masks_batch: 1 is not masked, 0 is masked
        input_masks_invert: 1 is masked, 0 is not masked
        """

        # NOTE: add positional embedding?
        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch)

        # use src_key_padding_mask
        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask = input_masks_invert)

        # encoded_embedding = self.additional_encoder(input_embeddings_batch)
        encoded_embedding = F.relu(self.fc1(encoded_embedding))
        LMoutput = self.pretrained_generator(inputs_embeds = encoded_embedding, attention_mask = input_masks_batch, return_dict = True, labels = target_ids_batch_converted, output_hidden_states = True)
        hidden_states = LMoutput.decoder_hidden_states # tuple of (N, seq_len, hidden_dim) tensors
        # print('hidden states len:', len(hidden_states))
        last_hidden_states = hidden_states[-1]
        # print('last hidden states size:', last_hidden_states.size())
        sentence_representation = self.pooler(last_hidden_states)

        classification_logits = self.classifier(sentence_representation)
        loss_fct = nn.CrossEntropyLoss()
        classification_loss = loss_fct(classification_logits.view(-1, self.num_labels), sentiment_labels.view(-1))
        classification_output = {'loss': classification_loss, 'logits': classification_logits}
        # print('successful one forward!!!!')
        return LMoutput, classification_output

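# Hedged usage sketch (not part of the original file): the joint model is assumed here to wrap a
# BartForConditionalGeneration backbone (whose config provides `classifier_dropout`), returning both
# the seq2seq LM output and a 3-way sentiment prediction. Shapes and names are illustrative.
def _example_joint_forward():
    bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
    model = JointBrainTranslatorSentimentClassifier(bart)
    eeg = torch.randn(1, 16, 840)                                    # (batch, seq_len, 840) EEG features
    mask = torch.ones(1, 16, dtype=torch.long)                       # 1 = not masked
    mask_invert = torch.zeros(1, 16, dtype=torch.bool)               # True = masked
    target_ids = torch.randint(0, bart.config.vocab_size, (1, 16))   # stand-in target token ids
    sentiment_labels = torch.tensor([2])
    lm_out, cls_out = model(eeg, mask, mask_invert, target_ids, sentiment_labels)
    return lm_out.loss, cls_out['loss'], cls_out['logits']
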
""" helper modules """
213
# modified from BertPooler
214
class Pooler(nn.Module):
215
    def __init__(self, hidden_size):
216
        super().__init__()
217
        self.dense = nn.Linear(hidden_size, hidden_size)
218
        self.activation = nn.Tanh()
219
220
    def forward(self, hidden_states):
221
        # We "pool" the model by simply taking the hidden state corresponding
222
        # to the first token.
223
        first_token_tensor = hidden_states[:, 0]
224
        pooled_output = self.dense(first_token_tensor)
225
        pooled_output = self.activation(pooled_output)
226
        return pooled_output
227
228
# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# NOTE: expects sequence-first input of shape (seq_len, batch, d_model)
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, 1, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # print('[DEBUG] input size:', x.size())
        # print('[DEBUG] positional embedding size:', self.pe.size())
        x = x + self.pe[:x.size(0), :]
        # print('[DEBUG] output x with pe size:', x.size())
        return self.dropout(x)
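
# Hedged usage sketch (not part of the original file): the tutorial-style PositionalEncoding above
# expects sequence-first input, so the batch-first EEG feature tensors used elsewhere in this file
# would need a transpose before the commented-out positional-embedding lines could be enabled.
def _example_positional_encoding():
    pos_enc = PositionalEncoding(d_model=840)
    eeg = torch.randn(1, 16, 840)         # batch-first (batch, seq_len, 840)
    eeg_seq_first = eeg.transpose(0, 1)   # -> (seq_len, batch, 840) for this module
    return pos_enc(eeg_seq_first)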