--- a/model_decoding.py
+++ b/model_decoding.py
@@ -0,0 +1,221 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data
+from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, Text2TextGenerationPipeline
+import math
+import numpy as np
+
+""" main architecture for open vocabulary EEG-To-Text decoding """
+
+
+class BrainTranslator(nn.Module):
+    def __init__(self, pretrained_layers, in_feature=840, decoder_embedding_size=1024, additional_encoder_nhead=8,
+                 additional_encoder_dim_feedforward=2048):
+        super(BrainTranslator, self).__init__()
+
+        self.pretrained = pretrained_layers
+        # additional transformer encoder, following the BART paper's recipe of placing a
+        # randomly initialized encoder in front of the pretrained seq2seq model
+        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead,
+                                                                   dim_feedforward=additional_encoder_dim_feedforward,
+                                                                   batch_first=True)
+        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)
+
+        # print('[INFO]adding positional embedding')
+        # self.positional_embedding = PositionalEncoding(in_feature)
+
+        # project EEG features (840) to the decoder embedding size of the pretrained model
+        self.fc1 = nn.Linear(in_feature, decoder_embedding_size)
+
+    def addin_forward(self, input_embeddings_batch, input_masks_invert):
+        """
+        input_embeddings_batch: batch_size * seq_len * 840
+        input_masks_invert: 1/True is masked (padding), 0/False is not masked
+        (PyTorch src_key_padding_mask convention, i.e. the inverse of the HF attention mask)
+        """
+        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch)
+        # use src_key_padding_mask to ignore padded time steps
+        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask=input_masks_invert)
+
+        # encoded_embedding = self.additional_encoder(input_embeddings_batch)
+        encoded_embedding = F.relu(self.fc1(encoded_embedding))
+        return encoded_embedding
+
+    @torch.no_grad()
+    def generate(
+            self,
+            input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted,
+            generation_config=None,
+            logits_processor=None,
+            stopping_criteria=None,
+            prefix_allowed_tokens_fn=None,
+            synced_gpus=None,
+            assistant_model=None,
+            streamer=None,
+            negative_prompt_ids=None,
+            negative_prompt_attention_mask=None,
+            **kwargs,
+    ):
+        # run the additional encoder, then delegate decoding to the pretrained model's generate()
+        encoded_embedding = self.addin_forward(input_embeddings_batch, input_masks_invert)
+        output = self.pretrained.generate(
+            inputs_embeds=encoded_embedding,
+            attention_mask=input_masks_batch[:, :encoded_embedding.shape[1]],
+            labels=target_ids_batch_converted,
+            return_dict=True,
+            generation_config=generation_config,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            synced_gpus=synced_gpus,
+            assistant_model=assistant_model,
+            streamer=streamer,
+            negative_prompt_ids=negative_prompt_ids,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+            **kwargs,
+        )
+        return output
+
+    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):
+        """
+        input_embeddings_batch: batch_size * seq_len * 840
+        input_masks_batch: 1 is not masked, 0 is masked (HF attention mask convention)
+        input_masks_invert: 1 is masked, 0 is not masked (PyTorch src_key_padding_mask convention)
+        """
+        encoded_embedding = self.addin_forward(input_embeddings_batch, input_masks_invert)
+        # print(f'forward:{input_embeddings_batch.shape, input_masks_batch.shape, input_masks_invert.shape, target_ids_batch_converted.shape, encoded_embedding.shape}')
+        out = self.pretrained(inputs_embeds=encoded_embedding, attention_mask=input_masks_batch,
+                              return_dict=True, labels=target_ids_batch_converted)
+        return out
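+
+
+# Minimal usage sketch (illustrative only): assumes a pretrained BART checkpoint whose d_model
+# matches decoder_embedding_size (1024 for bart-large); the function name, checkpoint, and dummy
+# tensor sizes below are examples, not fixed by the architecture.
+def _example_brain_translator_usage():
+    bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
+    model = BrainTranslator(bart, in_feature=840, decoder_embedding_size=1024)
+    eeg = torch.randn(2, 56, 840)                  # batch_size * seq_len * 840 EEG features
+    masks = torch.ones(2, 56, dtype=torch.long)    # HF convention: 1 = attend, 0 = padding
+    masks_invert = (masks == 0)                    # PyTorch convention: True = padding
+    labels = torch.randint(0, bart.config.vocab_size, (2, 20))
+    out = model(eeg, masks, masks_invert, labels)  # Seq2SeqLMOutput with .loss and .logits
+    # decoding: model.generate(eeg, masks, masks_invert, labels, max_length=56, num_beams=5)
+    return out.loss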
+
+
+""" ablated open vocabulary EEG-To-Text decoding model, without the additional multi-layer transformer encoder (MTE) """
+
+
+class BrainTranslatorNaive(nn.Module):
+    def __init__(self, pretrained_layers, in_feature=840, decoder_embedding_size=1024, additional_encoder_nhead=8,
+                 additional_encoder_dim_feedforward=2048):
+        super(BrainTranslatorNaive, self).__init__()
+        '''no additional transformer encoder version'''
+        self.pretrained = pretrained_layers
+        self.fc1 = nn.Linear(in_feature, decoder_embedding_size)
+
+    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):
+        """
+        input_embeddings_batch: batch_size * seq_len * 840
+        input_masks_batch: 1 is not masked, 0 is masked (HF attention mask convention)
+        input_masks_invert: 1 is masked, 0 is not masked (unused here, kept for interface compatibility)
+        """
+        encoded_embedding = F.relu(self.fc1(input_embeddings_batch))
+        out = self.pretrained(inputs_embeds=encoded_embedding, attention_mask=input_masks_batch, return_dict=True,
+                              labels=target_ids_batch_converted)
+        return out
+
+
+""" helper modules """
+
+
+# modified from BertPooler
+class Pooler(nn.Module):
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+# adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html,
+# changed to batch-first inputs (batch_size * seq_len * d_model) to match the
+# batch_first=True transformer encoders used in this file
+class PositionalEncoding(nn.Module):
+
+    def __init__(self, d_model, dropout=0.1, max_len=5000):
+        super(PositionalEncoding, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)  # [1, max_len, d_model], broadcast over the batch dimension
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        # x: [batch_size, seq_len, d_model]
+        x = x + self.pe[:, :x.size(1), :]
+        return self.dropout(x)
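+
+
+# Quick shape check for the batch-first PositionalEncoding above (illustrative only;
+# the function name and dummy sizes are arbitrary):
+def _example_positional_encoding_shapes():
+    pos_enc = PositionalEncoding(d_model=840)
+    x = torch.randn(4, 56, 840)  # batch_size * seq_len * d_model, as used by the EEG encoders
+    return pos_enc(x).shape      # torch.Size([4, 56, 840])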
+
+
+""" Miscellaneous (not working well) """
+
+
+class BrainTranslatorBert(nn.Module):
+    def __init__(self, pretrained_layers, in_feature=840, hidden_size=768):
+        super(BrainTranslatorBert, self).__init__()
+
+        self.pretrained_Bert = pretrained_layers
+        self.fc1 = nn.Linear(in_feature, hidden_size)
+
+    def forward(self, input_embeddings_batch, input_masks_batch, target_ids_batch):
+        embedding = F.relu(self.fc1(input_embeddings_batch))
+        out = self.pretrained_Bert(inputs_embeds=embedding, attention_mask=input_masks_batch, labels=target_ids_batch,
+                                   return_dict=True)
+        return out
+
+
+class EEG2BertMapping(nn.Module):
+    def __init__(self, in_feature=840, hidden_size=512, out_feature=768):
+        super(EEG2BertMapping, self).__init__()
+        self.fc1 = nn.Linear(in_feature, hidden_size)
+        self.fc2 = nn.Linear(hidden_size, out_feature)
+
+    def forward(self, x):
+        out = F.relu(self.fc1(x))
+        out = self.fc2(out)
+        return out
+
+
+class ContrastiveBrainTextEncoder(nn.Module):
+    def __init__(self, pretrained_text_encoder, in_feature=840, eeg_encoder_nhead=8, eeg_encoder_dim_feedforward=2048,
+                 embed_dim=768):
+        super(ContrastiveBrainTextEncoder, self).__init__()
+        # EEG encoder
+        self.positional_embedding = PositionalEncoding(in_feature)
+        self.encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=eeg_encoder_nhead,
+                                                        dim_feedforward=eeg_encoder_dim_feedforward, batch_first=True)
+        self.EEG_Encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
+        self.EEG_pooler = Pooler(in_feature)
+        self.ln_final = nn.LayerNorm(in_feature)  # to be considered
+
+        # projection to the text embedding space, initialized as in CLIP
+        self.EEG_projection = nn.Parameter(torch.empty(in_feature, embed_dim))
+        nn.init.normal_(self.EEG_projection, std=in_feature ** -0.5)
+
+        # Text encoder
+        self.TextEncoder = pretrained_text_encoder
+
+        # learned temperature parameter
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+    def forward(self, input_EEG_features, input_EEG_attn_mask, input_ids, input_text_attention_masks):
+        # add positional embedding
+        input_EEG_features = self.positional_embedding(input_EEG_features)
+        # get EEG feature embedding; input_EEG_attn_mask follows the src_key_padding_mask
+        # convention (True/1 = padded position)
+        EEG_hiddenstates = self.EEG_Encoder(input_EEG_features, src_key_padding_mask=input_EEG_attn_mask)
+        EEG_hiddenstates = self.ln_final(EEG_hiddenstates)
+        EEG_features = self.EEG_pooler(EEG_hiddenstates)  # [N, 840]
+
+        # project to text embedding size
+        EEG_features = EEG_features @ self.EEG_projection  # [N, 768]
+
+        # get text feature embedding
+        Text_features = self.TextEncoder(input_ids=input_ids, attention_mask=input_text_attention_masks,
+                                         return_dict=True).pooler_output  # [N, 768]
+
+        # normalized features
+        EEG_features = EEG_features / EEG_features.norm(dim=-1, keepdim=True)    # [N, 768]
+        Text_features = Text_features / Text_features.norm(dim=-1, keepdim=True) # [N, 768]
+
+        # cosine similarity as logits
+        logit_scale = self.logit_scale.exp()
+        logits_per_EEG = logit_scale * EEG_features @ Text_features.t()   # [N, N]
+        logits_per_text = logit_scale * Text_features @ EEG_features.t()  # [N, N]
+
+        return logits_per_EEG, logits_per_text
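+
+
+# Training-objective sketch for ContrastiveBrainTextEncoder (illustrative only): the CLIP-style
+# symmetric cross-entropy over the in-batch EEG/text similarity matrix, assuming the i-th EEG
+# segment and the i-th sentence in a batch form a matched pair. The function name is arbitrary.
+def _example_contrastive_loss(logits_per_EEG, logits_per_text):
+    targets = torch.arange(logits_per_EEG.size(0), device=logits_per_EEG.device)
+    loss_EEG = F.cross_entropy(logits_per_EEG, targets)
+    loss_text = F.cross_entropy(logits_per_text, targets)
+    return (loss_EEG + loss_text) / 2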