Diff of /model_decoding.py [000000] .. [66af30]

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, Text2TextGenerationPipeline
import math
import numpy as np

""" main architecture for open vocabulary EEG-To-Text decoding"""

class BrainTranslator(nn.Module):
    def __init__(self, pretrained_layers, in_feature=840, decoder_embedding_size=1024, additional_encoder_nhead=8,
                 additional_encoder_dim_feedforward=2048):
        super(BrainTranslator, self).__init__()

        self.pretrained = pretrained_layers
        # additional transformer encoder applied before the pretrained BART model
        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead,
                                                                   dim_feedforward=additional_encoder_dim_feedforward,
                                                                   batch_first=True)
        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)

        # print('[INFO]adding positional embedding')
        # self.positional_embedding = PositionalEncoding(in_feature)

        self.fc1 = nn.Linear(in_feature, decoder_embedding_size)

    def addin_forward(self, input_embeddings_batch, input_masks_invert):
        """
        input_embeddings_batch: batch_size * Seq_len * 840
        input_mask: 1 is not masked, 0 is masked
        input_masks_invert: 1 is masked, 0 is not masked
        """
        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch)
        # use src_key_padding_mask (True/1 positions are ignored by the encoder)
        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask=input_masks_invert)

        # encoded_embedding = self.additional_encoder(input_embeddings_batch)
        encoded_embedding = F.relu(self.fc1(encoded_embedding))
        return encoded_embedding

    @torch.no_grad()
    def generate(
            self,
            input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted,
            generation_config=None,
            logits_processor=None,
            stopping_criteria=None,
            prefix_allowed_tokens_fn=None,
            synced_gpus=None,
            assistant_model=None,
            streamer=None,
            negative_prompt_ids=None,
            negative_prompt_attention_mask=None,
            **kwargs,
    ):
        encoded_embedding = self.addin_forward(input_embeddings_batch, input_masks_invert)
        output = self.pretrained.generate(
            inputs_embeds=encoded_embedding,
            attention_mask=input_masks_batch[:, :encoded_embedding.shape[1]],
            labels=target_ids_batch_converted,
            return_dict=True,
            generation_config=generation_config,
            logits_processor=logits_processor,
            stopping_criteria=stopping_criteria,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            synced_gpus=synced_gpus,
            assistant_model=assistant_model,
            streamer=streamer,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            **kwargs,
        )

        return output

    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):
        encoded_embedding = self.addin_forward(input_embeddings_batch, input_masks_invert)
        # print(f'forward:{input_embeddings_batch.shape,input_masks_batch.shape,input_masks_invert.shape,target_ids_batch_converted.shape,encoded_embedding.shape}')
        out = self.pretrained(inputs_embeds=encoded_embedding, attention_mask=input_masks_batch,
                              return_dict=True, labels=target_ids_batch_converted)

        return out

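
# Usage sketch (not from the original file): how BrainTranslator might be wired up with a
# pretrained BART model. The checkpoint name, tensor shapes, and values below are
# illustrative assumptions, not fixed anywhere in this repo; the function is never called here.
def _brain_translator_usage_example():
    bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
    model = BrainTranslator(bart, in_feature=840, decoder_embedding_size=1024)

    batch_size, seq_len = 2, 56                                        # assumed EEG batch shape
    eeg = torch.randn(batch_size, seq_len, 840)                        # EEG feature embeddings
    mask = torch.ones(batch_size, seq_len, dtype=torch.long)           # 1 = not masked
    mask_invert = torch.zeros(batch_size, seq_len, dtype=torch.bool)   # 1 = masked (inverted mask)
    labels = torch.randint(0, bart.config.vocab_size, (batch_size, 20))

    out = model(eeg, mask, mask_invert, labels)                        # Seq2SeqLMOutput with .loss / .logits
    return out
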
""" crippled open vocabulary EEG-To-Text decoding model w/o additional MTE encoder"""


class BrainTranslatorNaive(nn.Module):
    def __init__(self, pretrained_layers, in_feature=840, decoder_embedding_size=1024, additional_encoder_nhead=8,
                 additional_encoder_dim_feedforward=2048):
        super(BrainTranslatorNaive, self).__init__()
        '''no additional transformer encoder version'''
        self.pretrained = pretrained_layers
        self.fc1 = nn.Linear(in_feature, decoder_embedding_size)

    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):
        """
        input_embeddings_batch: batch_size * Seq_len * 840
        input_mask: 1 is not masked, 0 is masked
        input_masks_invert: 1 is masked, 0 is not masked
        """
        encoded_embedding = F.relu(self.fc1(input_embeddings_batch))
        out = self.pretrained(inputs_embeds=encoded_embedding, attention_mask=input_masks_batch, return_dict=True,
                              labels=target_ids_batch_converted)
        return out

""" helper modules """


# modified from BertPooler
class Pooler(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

# from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        # expects sequence-first input: x is [seq_len, batch_size, d_model]
        # print('[DEBUG] input size:', x.size())
        # print('[DEBUG] positional embedding size:', self.pe.size())
        x = x + self.pe[:x.size(0), :]
        # print('[DEBUG] output x with pe size:', x.size())
        return self.dropout(x)

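
# Note (not from the original file): PositionalEncoding follows the sequence-first convention of the
# PyTorch tutorial ([seq_len, batch, d_model]), while the encoders above use batch_first=True, so a
# batch-first tensor would be permuted around the call. Shapes below are assumptions; never called here.
def _positional_encoding_usage_example():
    pos_enc = PositionalEncoding(d_model=840)
    x = torch.randn(2, 56, 840)                         # [batch, seq_len, 840]
    x = pos_enc(x.permute(1, 0, 2)).permute(1, 0, 2)    # add encodings along the sequence dim, back to batch-first
    return x
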
""" Miscellaneous (not working well) """


class BrainTranslatorBert(nn.Module):
    def __init__(self, pretrained_layers, in_feature=840, hidden_size=768):
        super(BrainTranslatorBert, self).__init__()

        self.pretrained_Bert = pretrained_layers
        self.fc1 = nn.Linear(in_feature, hidden_size)

    def forward(self, input_embeddings_batch, input_masks_batch, target_ids_batch):
        embedding = F.relu(self.fc1(input_embeddings_batch))
        out = self.pretrained_Bert(inputs_embeds=embedding, attention_mask=input_masks_batch, labels=target_ids_batch,
                                   return_dict=True)
        return out

class EEG2BertMapping(nn.Module):
    def __init__(self, in_feature=840, hidden_size=512, out_feature=768):
        super(EEG2BertMapping, self).__init__()
        self.fc1 = nn.Linear(in_feature, hidden_size)
        self.fc2 = nn.Linear(hidden_size, out_feature)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = self.fc2(out)
        return out

class ContrastiveBrainTextEncoder(nn.Module):
    def __init__(self, pretrained_text_encoder, in_feature=840, eeg_encoder_nhead=8, eeg_encoder_dim_feedforward=2048,
                 embed_dim=768):
        super(ContrastiveBrainTextEncoder, self).__init__()
        # EEG Encoder
        self.positional_embedding = PositionalEncoding(in_feature)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=eeg_encoder_nhead,
                                                        dim_feedforward=eeg_encoder_dim_feedforward, batch_first=True)
        self.EEG_Encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
        self.EEG_pooler = Pooler(in_feature)
        self.ln_final = nn.LayerNorm(in_feature)  # to be considered

        # project to text embedding
        self.EEG_projection = nn.Parameter(torch.empty(in_feature, embed_dim))
        nn.init.normal_(self.EEG_projection, std=in_feature ** -0.5)  # added init: torch.empty alone leaves the projection uninitialized

        # Text Encoder
        self.TextEncoder = pretrained_text_encoder

        # learned temperature parameter
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_EEG_features, input_EEG_attn_mask, input_ids, input_text_attention_masks):
        # add positional embedding
        input_EEG_features = self.positional_embedding(input_EEG_features)
        # get EEG feature embedding
        EEG_hiddenstates = self.EEG_Encoder(input_EEG_features, src_key_padding_mask=input_EEG_attn_mask)
        EEG_hiddenstates = self.ln_final(EEG_hiddenstates)
        EEG_features = self.EEG_pooler(EEG_hiddenstates)  # [N, 840]

        # project to text embed size
        EEG_features = EEG_features @ self.EEG_projection  # [N, 768]

        # get text feature embedding
        Text_features = self.TextEncoder(input_ids=input_ids, attention_mask=input_text_attention_masks,
                                         return_dict=True).pooler_output  # [N, 768]

        # normalized features
        EEG_features = EEG_features / EEG_features.norm(dim=-1, keepdim=True)  # [N, 768]
        Text_features = Text_features / Text_features.norm(dim=-1, keepdim=True)  # [N, 768]

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_EEG = logit_scale * EEG_features @ Text_features.t()  # [N, N]
        logits_per_text = logit_scale * Text_features @ EEG_features.t()  # [N, N]

        return logits_per_EEG, logits_per_text
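

# Training-loss sketch (not from the original file): the two [N, N] logit matrices returned by
# ContrastiveBrainTextEncoder can be trained with a CLIP-style symmetric cross-entropy, assuming
# the i-th EEG segment and the i-th sentence in a batch form the positive pair. This helper is an
# assumption about the intended training objective, not something defined elsewhere in this repo.
def clip_style_contrastive_loss(logits_per_EEG, logits_per_text):
    # the diagonal entries are the matching EEG/text pairs
    targets = torch.arange(logits_per_EEG.size(0), device=logits_per_EEG.device)
    loss_EEG = F.cross_entropy(logits_per_EEG, targets)
    loss_text = F.cross_entropy(logits_per_text, targets)
    return (loss_EEG + loss_text) / 2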