--- /dev/null
+++ b/model_decoding.py
@@ -0,0 +1,221 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data
+from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, Text2TextGenerationPipeline
+import math
+import numpy as np
+
+""" main architecture for open vocabulary EEG-To-Text decoding"""
+
+
+class BrainTranslator(nn.Module):
+    def __init__(self, pretrained_layers, in_feature=840, decoder_embedding_size=1024, additional_encoder_nhead=8,
+                 additional_encoder_dim_feedforward=2048):
+        super(BrainTranslator, self).__init__()
+
+        self.pretrained = pretrained_layers
+        # additional Transformer encoder prepended to the pretrained BART model
+        # (analogous to the extra encoder BART adds for machine-translation fine-tuning)
+        self.additional_encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=additional_encoder_nhead,
+                                                                   dim_feedforward=additional_encoder_dim_feedforward,
+                                                                   batch_first=True)
+        self.additional_encoder = nn.TransformerEncoder(self.additional_encoder_layer, num_layers=6)
+
+        # print('[INFO]adding positional embedding')
+        # self.positional_embedding = PositionalEncoding(in_feature)
+
+        self.fc1 = nn.Linear(in_feature, decoder_embedding_size)
+
+    def addin_forward(self, input_embeddings_batch, input_masks_invert):
+        """input_embeddings_batch: batch_size*Seq_len*840"""
+        """input_mask: 1 is not masked, 0 is masked"""
+        """input_masks_invert: 1 is masked, 0 is not masked"""
+
+        # input_embeddings_batch = self.positional_embedding(input_embeddings_batch)
+        # use src_key_padding_mask so padded time steps are ignored by self-attention
+        encoded_embedding = self.additional_encoder(input_embeddings_batch, src_key_padding_mask=input_masks_invert)
+
+        # encoded_embedding = self.additional_encoder(input_embeddings_batch)
+        encoded_embedding = F.relu(self.fc1(encoded_embedding))
+        return encoded_embedding
+
+    @torch.no_grad()
+    def generate(
+            self,
+            input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted,
+            generation_config=None,
+            logits_processor=None,
+            stopping_criteria=None,
+            prefix_allowed_tokens_fn=None,
+            synced_gpus=None,
+            assistant_model=None,
+            streamer=None,
+            negative_prompt_ids=None,
+            negative_prompt_attention_mask=None,
+            **kwargs,
+    ):
+        encoded_embedding = self.addin_forward(input_embeddings_batch, input_masks_invert)
+        # labels are not needed for generation, so target_ids_batch_converted is accepted
+        # only for interface compatibility with forward() and is not passed on here
+        output = self.pretrained.generate(
+            inputs_embeds=encoded_embedding,
+            attention_mask=input_masks_batch[:, :encoded_embedding.shape[1]],
+            generation_config=generation_config,
+            logits_processor=logits_processor,
+            stopping_criteria=stopping_criteria,
+            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+            synced_gpus=synced_gpus,
+            assistant_model=assistant_model,
+            streamer=streamer,
+            negative_prompt_ids=negative_prompt_ids,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+            **kwargs,
+        )
+
+        return output
+
+    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):
+        encoded_embedding = self.addin_forward(input_embeddings_batch, input_masks_invert)
+        # print(f'forward:{input_embeddings_batch.shape,input_masks_batch.shape,input_masks_invert.shape,target_ids_batch_converted.shape,encoded_embedding.shape}')
+        out = self.pretrained(inputs_embeds=encoded_embedding, attention_mask=input_masks_batch,
+                              return_dict=True, labels=target_ids_batch_converted)
+
+        return out
+
+
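+# --- Hedged usage sketch (illustrative only, not part of the original training pipeline) ---
+# Shows how BrainTranslator is wired up, assuming 'facebook/bart-large' as the pretrained
+# seq2seq model (d_model=1024, matching decoder_embedding_size) and random tensors standing
+# in for the 840-dim word-level EEG features. Shapes follow the docstrings above; real data
+# loading and mask construction live elsewhere. Generation from inputs_embeds assumes a
+# reasonably recent transformers version.
+def _example_brain_translator_usage():
+    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
+    bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
+    model = BrainTranslator(bart, in_feature=840, decoder_embedding_size=1024)
+
+    batch_size, seq_len = 2, 56
+    eeg_features = torch.randn(batch_size, seq_len, 840)                      # EEG feature sequence
+    input_masks = torch.ones(batch_size, seq_len, dtype=torch.long)           # 1 = not masked
+    input_masks_invert = torch.zeros(batch_size, seq_len, dtype=torch.bool)   # True = masked (none here)
+    target_ids = tokenizer(['a placeholder target sentence.'] * batch_size,
+                           padding='max_length', max_length=seq_len,
+                           truncation=True, return_tensors='pt')['input_ids']
+
+    out = model(eeg_features, input_masks, input_masks_invert, target_ids)    # teacher-forced loss
+    print('loss:', out.loss.item())
+
+    generated = model.generate(eeg_features, input_masks, input_masks_invert, target_ids,
+                               max_new_tokens=20, num_beams=5)
+    print(tokenizer.batch_decode(generated, skip_special_tokens=True))
+
+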
+""" crippled open vocabulary EEG-To-Text decoding model w/o additional MTE encoder"""
+
+
+class BrainTranslatorNaive(nn.Module):
+    def __init__(self, pretrained_layers, in_feature=840, decoder_embedding_size=1024, additional_encoder_nhead=8,
+                 additional_encoder_dim_feedforward=2048):
+        super(BrainTranslatorNaive, self).__init__()
+        # version without the additional Transformer encoder
+        self.pretrained = pretrained_layers
+        self.fc1 = nn.Linear(in_feature, decoder_embedding_size)
+
+    def forward(self, input_embeddings_batch, input_masks_batch, input_masks_invert, target_ids_batch_converted):
+        """input_embeddings_batch: batch_size*Seq_len*840"""
+        """input_mask: 1 is not masked, 0 is masked"""
+        """input_masks_invert: 1 is masked, 0 is not masked"""
+        encoded_embedding = F.relu(self.fc1(input_embeddings_batch))
+        out = self.pretrained(inputs_embeds=encoded_embedding, attention_mask=input_masks_batch, return_dict=True,
+                              labels=target_ids_batch_converted)
+        return out
+
+
+""" helper modules """
+
+
+# modified from BertPooler
+class Pooler(nn.Module):
+    def __init__(self, hidden_size):
+        super().__init__()
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+# adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html,
+# adjusted here for batch-first inputs (to match the batch_first=True encoders above)
+class PositionalEncoding(nn.Module):
+
+    def __init__(self, d_model, dropout=0.1, max_len=5000):
+        super(PositionalEncoding, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)  # [1, max_len, d_model], broadcasts over batch-first inputs
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        # print('[DEBUG] input size:', x.size())
+        # print('[DEBUG] positional embedding size:', self.pe.size())
+        x = x + self.pe[:, :x.size(1), :]  # x: [batch_size, seq_len, d_model]
+        # print('[DEBUG] output x with pe size:', x.size())
+        return self.dropout(x)
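+
+
+# Minimal shape check for PositionalEncoding (an illustrative sketch, not from the original
+# code): the 'pe' buffer is [1, max_len, d_model], so it broadcasts over a batch-first input
+# of shape [batch_size, seq_len, d_model], matching the batch_first=True encoders in this file.
+def _example_positional_encoding_shapes():
+    pos_enc = PositionalEncoding(d_model=840, dropout=0.1)
+    x = torch.randn(2, 56, 840)        # [batch_size, seq_len, d_model]
+    y = pos_enc(x)
+    assert y.shape == x.shape          # adding positional encodings preserves the shape
+    print('output shape:', tuple(y.shape))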
+
+
+""" Miscellaneous (not working well) """
+
+
+class BrainTranslatorBert(nn.Module):
+    def __init__(self, pretrained_layers, in_feature=840, hidden_size=768):
+        super(BrainTranslatorBert, self).__init__()
+
+        self.pretrained_Bert = pretrained_layers
+        self.fc1 = nn.Linear(in_feature, hidden_size)
+
+    def forward(self, input_embeddings_batch, input_masks_batch, target_ids_batch):
+        embedding = F.relu(self.fc1(input_embeddings_batch))
+        out = self.pretrained_Bert(inputs_embeds=embedding, attention_mask=input_masks_batch, labels=target_ids_batch,
+                                   return_dict=True)
+        return out
+
+
+class EEG2BertMapping(nn.Module):
+    def __init__(self, in_feature=840, hidden_size=512, out_feature=768):
+        super(EEG2BertMapping, self).__init__()
+        self.fc1 = nn.Linear(in_feature, hidden_size)
+        self.fc2 = nn.Linear(hidden_size, out_feature)
+
+    def forward(self, x):
+        out = F.relu(self.fc1(x))
+        out = self.fc2(out)
+        return out
+
+
+class ContrastiveBrainTextEncoder(nn.Module):
+    def __init__(self, pretrained_text_encoder, in_feature=840, eeg_encoder_nhead=8, eeg_encoder_dim_feedforward=2048,
+                 embed_dim=768):
+        super(ContrastiveBrainTextEncoder, self).__init__()
+        # EEG Encoder
+        self.positional_embedding = PositionalEncoding(in_feature)
+        self.encoder_layer = nn.TransformerEncoderLayer(d_model=in_feature, nhead=eeg_encoder_nhead,
+                                                        dim_feedforward=eeg_encoder_dim_feedforward, batch_first=True)
+        self.EEG_Encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
+        self.EEG_pooler = Pooler(in_feature)
+        self.ln_final = nn.LayerNorm(in_feature)  # to be considered
+
+        # project pooled EEG features to the text embedding size
+        self.EEG_projection = nn.Parameter(torch.empty(in_feature, embed_dim))
+        nn.init.normal_(self.EEG_projection, std=in_feature ** -0.5)  # torch.empty is uninitialised memory
+
+        # Text Encoder
+        self.TextEncoder = pretrained_text_encoder
+
+        # learned temperature parameter
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+    def forward(self, input_EEG_features, input_EEG_attn_mask, input_ids, input_text_attention_masks):
+        # add positional embedding
+        input_EEG_features = self.positional_embedding(input_EEG_features)
+        # get EEG feature embedding; note that nn.TransformerEncoder expects True/1 at padding
+        # positions to be ignored (the opposite of a HuggingFace-style attention mask)
+        EEG_hiddenstates = self.EEG_Encoder(input_EEG_features, src_key_padding_mask=input_EEG_attn_mask)
+        EEG_hiddenstates = self.ln_final(EEG_hiddenstates)
+        EEG_features = self.EEG_pooler(EEG_hiddenstates)  # [N, 840]
+
+        # project to text embed size
+        EEG_features = EEG_features @ self.EEG_projection  # [N, 768]
+
+        # get text feature embedding
+        Text_features = self.TextEncoder(input_ids=input_ids, attention_mask=input_text_attention_masks,
+                                         return_dict=True).pooler_output  # [N, 768]
+
+        # normalized features
+        EEG_features = EEG_features / EEG_features.norm(dim=-1, keepdim=True)  # [N, 768]
+        Text_features = Text_features / Text_features.norm(dim=-1, keepdim=True)  # [N, 768]
+
+        # cosine similarity as logits
+        logit_scale = self.logit_scale.exp()
+        logits_per_EEG = logit_scale * EEG_features @ Text_features.t()  # [N, N]
+        logits_per_text = logit_scale * Text_features @ EEG_features.t()  # [N, N]
+
+        return logits_per_EEG, logits_per_text
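+
+
+# ContrastiveBrainTextEncoder only returns CLIP-style similarity logits; the training loss is
+# computed outside the module. Below is a minimal sketch of the usual symmetric cross-entropy
+# objective over matched EEG/text pairs (an assumption mirroring CLIP, not code taken from the
+# original training script).
+def contrastive_loss(logits_per_EEG, logits_per_text):
+    # the i-th EEG segment in the batch matches the i-th text
+    batch_size = logits_per_EEG.size(0)
+    ground_truth = torch.arange(batch_size, device=logits_per_EEG.device)
+    loss_EEG = F.cross_entropy(logits_per_EEG, ground_truth)
+    loss_text = F.cross_entropy(logits_per_text, ground_truth)
+    return (loss_EEG + loss_text) / 2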