# Retrieval/contrast_retrieval.py
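# Trains several EEG encoders (NICE, EEGNetv4, EEGConformer, EEGITNet,
# ShallowFBCSPNet, ATCNet, MetaEEG, and an MLP Projector) with a CLIP-style
# contrastive loss against precomputed image/text features, then evaluates
# k-way retrieval accuracy (k = 2, 4, 10, 200) for each subject.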
import argparse
import csv
import datetime
import math
import os
import random
from itertools import combinations

import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset

# Configure Weights & Biases before importing anything that may initialize it.
# "KEY" is a placeholder; logging runs in offline mode regardless.
os.environ["WANDB_API_KEY"] = "KEY"
os.environ["WANDB_MODE"] = 'offline'

import clip
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
import tqdm
from braindecode.models import EEGNetv4, ATCNet, EEGConformer, EEGITNet, ShallowFBCSPNet
from einops.layers.torch import Rearrange, Reduce
from sklearn.metrics import confusion_matrix

from eegdatasets_leaveone import EEGDataset
from loss import ClipLoss
from utils import wandb_logger

# Module-level default device: the encoder classes below read this global at
# construction time. main() builds its own device object from --device for
# data and model placement.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


#--------------------------------NICE-----------------------------------#
class PatchEmbedding(nn.Module):
    def __init__(self, emb_size=40):
        super().__init__()
        # Temporal-spatial convolution stack, revised from ShallowNet.
        self.tsconv = nn.Sequential(
            nn.Conv2d(1, 40, (1, 25), (1, 1)),
            nn.AvgPool2d((1, 51), (1, 5)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.Conv2d(40, 40, (63, 1), (1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.Dropout(0.5),
        )

        self.projection = nn.Sequential(
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        x = x.unsqueeze(1)      # (B, C, T) -> (B, 1, C, T)
        x = self.tsconv(x)      # (B, 1, 63, 250) -> (B, 40, 1, 36)
        x = self.projection(x)  # -> (B, 36, emb_size)
        return x

class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x

class FlattenHead(nn.Sequential):
    def forward(self, x):
        x = x.contiguous().view(x.size(0), -1)
        return x

class Enc_eeg(nn.Sequential):
    def __init__(self, emb_size=40, **kwargs):
        super().__init__(
            PatchEmbedding(emb_size),
            FlattenHead()
        )

class Proj_eeg(nn.Sequential):
    def __init__(self, embedding_dim=1440, proj_dim=1024, drop_proj=0.5):
        super().__init__(
            nn.Linear(embedding_dim, proj_dim),
            ResidualAdd(nn.Sequential(
                nn.GELU(),
                nn.Linear(proj_dim, proj_dim),
                nn.Dropout(drop_proj),
            )),
            nn.LayerNorm(proj_dim),
        )

class NICE(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc_eeg = Enc_eeg()
        self.proj_eeg = Proj_eeg()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.loss_func = ClipLoss()

    def forward(self, data):
        eeg_embedding = self.enc_eeg(data)
        out = self.proj_eeg(eeg_embedding)
        return out
#########################################################################
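
# Shape check for NICE (a sketch, assuming (B, 63, 250) inputs as used below):
# tsconv maps (B, 1, 63, 250) to (B, 40, 1, 36), so the flattened embedding has
# 36 * 40 = 1440 features, matching Proj_eeg's default embedding_dim.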


#-------------------------------EEGNetv4--------------------------------#
class EEGNetv4_Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = device
        self.shape = (63, 250)
        self.eegnet = EEGNetv4(
            in_chans=self.shape[0],
            n_classes=1024,
            input_window_samples=self.shape[1],
            final_conv_length='auto',
            pool_mode='mean',
            F1=8,
            D=20,
            F2=160,
            kernel_length=4,
            third_kernel_size=(4, 2),
            drop_prob=0.25
        )
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.loss_func = ClipLoss()

    def forward(self, data):
        # (B, C, T) -> (B, C, T, 1): this EEGNetv4 variant consumes a trailing
        # singleton dimension.
        data = data.unsqueeze(0)
        data = data.reshape(data.shape[1], data.shape[2], data.shape[3], data.shape[0])
        prediction = self.eegnet(data)
        return prediction
#########################################################################
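
# Note: the braindecode models in this file are repurposed as feature
# extractors: their output size (n_classes / n_outputs = 1024) is set to match
# the CLIP feature dimension, and the outputs are trained against image/text
# features with the contrastive loss rather than used as class logits.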


#--------------------------EEGConformer_Encoder-------------------------#
class EEGConformer_Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = device
        self.shape = (63, 250)
        self.eegConformer = EEGConformer(n_outputs=None,
                                         n_chans=self.shape[0],
                                         n_filters_time=40,
                                         filter_time_length=10,
                                         pool_time_length=25,
                                         pool_time_stride=5,
                                         drop_prob=0.25,
                                         att_depth=2,
                                         att_heads=1,
                                         att_drop_prob=0.5,
                                         final_fc_length=1760,
                                         return_features=False,
                                         n_times=None,
                                         chs_info=None,
                                         input_window_seconds=None,
                                         n_classes=1024,
                                         input_window_samples=self.shape[1],
                                         add_log_softmax=True)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.loss_func = ClipLoss()

    def forward(self, data):
        prediction = self.eegConformer(data)
        return prediction
#########################################################################


#-----------------------------EEGITNet_Encoder--------------------------#
class EEGITNet_Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = device
        self.shape = (63, 250)
        self.eegEEGITNet = EEGITNet(n_outputs=1024,
                                    n_chans=self.shape[0],
                                    n_times=None,
                                    drop_prob=0.4,
                                    chs_info=None,
                                    input_window_seconds=1.0,
                                    sfreq=250,
                                    input_window_samples=self.shape[1],
                                    add_log_softmax=True)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.loss_func = ClipLoss()

    def forward(self, data):
        prediction = self.eegEEGITNet(data)
        return prediction
#########################################################################


#--------------------------------MLP------------------------------------#
def make_block(h_c, h_l, dropout_rate=0.25):
    block = nn.Sequential(
        nn.LayerNorm(h_l),
        nn.Linear(h_l, h_l),
        nn.GELU(),
        nn.Dropout(dropout_rate),
        Rearrange('B C L->B L C'),
        nn.LayerNorm(h_c),
        nn.Linear(h_c, h_c),
        nn.GELU(),
        nn.Dropout(dropout_rate),
        Rearrange('B L C->B C L'),
    )
    return block

class Projector(nn.Module):

    def __init__(self, in_features, h_dim=(64, 1024), n_hidden_layer=2, dropout_rate=0.25):
        # in_features: (c, l)
        super().__init__()
        c, l = in_features
        h_c, h_l = h_dim
        c_o, l_o = 1, 1024

        self.input_layer = nn.Sequential(
            nn.LayerNorm(l),
            nn.Linear(l, h_l),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            Rearrange('B C L->B L C'),
            nn.LayerNorm(c),
            nn.Linear(c, h_c),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            Rearrange('B L C->B C L'),
        )

        self.output_layer = nn.Sequential(
            nn.LayerNorm(h_l),
            nn.Linear(h_l, l_o),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            Rearrange('B C L->B L C'),
            nn.LayerNorm(h_c),
            nn.Linear(h_c, c_o),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            Rearrange('B L C->B (C L)'),
        )

        self.blocks = nn.Sequential(*[
            make_block(h_c, h_l) for _ in range(n_hidden_layer)
        ])

        self.projector = nn.Sequential(*[
            self.input_layer,
            self.blocks,
            self.output_layer,
        ])

        # Note: initialized to 1/0.01 = 100 here, versus 1/0.07 in the other
        # encoders.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.01))

        self.loss_func = ClipLoss()

    def forward(self, eeg_embeds):
        eeg_embeds = self.projector(eeg_embeds)
        eeg_features = F.normalize(eeg_embeds, dim=-1)
        return eeg_features
#########################################################################
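
# Shape flow for Projector (a sketch, assuming in_features=(63, 250)):
# input_layer maps (B, 63, 250) to (B, 64, 1024); the hidden blocks keep that
# shape; output_layer collapses channels to (B, 1024, 1) and flattens to a
# (B, 1024) embedding, which is then L2-normalized.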


#-------------------------ShallowFBCSPNet_Encoder-----------------------#
class ShallowFBCSPNet_Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = device
        self.shape = (63, 250)
        self.ShallowFBCSPNet = ShallowFBCSPNet(n_chans=self.shape[0],
                                               n_outputs=1024,
                                               n_times=self.shape[1],
                                               n_filters_time=20,
                                               filter_time_length=20,
                                               n_filters_spat=20,
                                               pool_time_length=25,
                                               pool_time_stride=5,
                                               final_conv_length='auto',
                                               pool_mode='mean',
                                               split_first_layer=True,
                                               batch_norm=True,
                                               batch_norm_alpha=0.1,
                                               drop_prob=0.5,
                                               chs_info=None,
                                               input_window_seconds=1.0,
                                               sfreq=250,
                                               add_log_softmax=True)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.loss_func = ClipLoss()

    def forward(self, data):
        prediction = self.ShallowFBCSPNet(data)
        return prediction
#########################################################################


#---------------------------ATCNet_Encoder------------------------------#
class ATCNet_Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.device = device
        self.shape = (63, 250)
        self.eegATCNet = ATCNet(n_chans=self.shape[0],
                                n_outputs=1024,
                                input_window_seconds=1.0,
                                sfreq=250.,
                                conv_block_n_filters=8,
                                conv_block_kernel_length_1=32,
                                conv_block_kernel_length_2=8,
                                conv_block_pool_size_1=4,
                                conv_block_pool_size_2=3,
                                conv_block_depth_mult=2,
                                conv_block_dropout=0.3,
                                n_windows=5,
                                att_head_dim=4,
                                att_num_heads=2,
                                att_dropout=0.5,
                                tcn_depth=2,
                                tcn_kernel_size=4,
                                tcn_n_filters=16,
                                tcn_dropout=0.3,
                                tcn_activation=nn.ELU(),
                                concat=False,
                                max_norm_const=0.25,
                                chs_info=None,
                                n_times=None,
                                n_channels=None,
                                n_classes=None,
                                input_size_s=None,
                                add_log_softmax=True)

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.loss_func = ClipLoss()

    def forward(self, data):
        prediction = self.eegATCNet(data)
        return prediction
#########################################################################


#-------------------------------Meta------------------------------------#
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # The extra "+ 1" in the arange (and the matching slices below) makes
        # the sin/cos split line up for odd d_model (here d_model = 63 channels).
        div_term = torch.exp(torch.arange(0, d_model + 1, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term[:d_model // 2 + 1])
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])

        self.register_buffer('pe', pe)

    def forward(self, x):
        pe = self.pe[:x.size(0), :].unsqueeze(1).repeat(1, x.size(1), 1)
        x = x + pe
        return x

class EEGAttention(nn.Module):
    def __init__(self, channel, d_model, nhead):
        super(EEGAttention, self).__init__()
        self.pos_encoder = PositionalEncoding(d_model)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)
        self.channel = channel
        self.d_model = d_model

    def forward(self, src):
        src = src.permute(2, 0, 1)  # Change shape to [time_length, batch_size, channel]
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src)
        return output.permute(1, 2, 0)  # Change shape back to [batch_size, channel, time_length]

class MetaEEG(nn.Module):
    def __init__(self, num_channels, sequence_length, num_subjects=1, num_features=64, num_latents=1024, num_blocks=1):
        super(MetaEEG, self).__init__()
        self.attention_model = EEGAttention(num_channels, num_channels, nhead=1)
        self.subject_wise_linear = nn.ModuleList([nn.Linear(sequence_length, sequence_length) for _ in range(num_subjects)])
        self.conv_blocks = nn.Sequential(*[ConvBlock(num_channels, sequence_length) for _ in range(num_blocks)],
                                         Rearrange('B C L->B L C'))
        self.linear_projection = nn.Sequential(
            Rearrange('B L C->B C L'),
            nn.Linear(sequence_length, num_latents),
            Rearrange('B C L->B L C'))
        self.temporal_aggregation = nn.Linear(sequence_length, 1)
        self.clip_head = MLPHead(num_latents, num_latents)
        self.mse_head = MLPHead(num_latents, num_latents)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.01))
        self.loss_func = ClipLoss()

    # subject_id defaults to 0 so the generic model(eeg_data) call in the
    # train/evaluate loops also works here. Note that, unlike the other
    # encoders, MetaEEG returns two heads (clip_out, mse_out).
    def forward(self, x, subject_id=0):
        x = self.attention_model(x)
        x = self.subject_wise_linear[subject_id](x)
        x = self.conv_blocks(x)
        x = self.linear_projection(x)
        x = self.temporal_aggregation(x)
        clip_out = self.clip_head(x)
        mse_out = self.mse_head(x)
        return clip_out, mse_out

class ConvBlock(nn.Module):
    def __init__(self, num_channels, num_features):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv1d(num_channels, num_features, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(num_features, num_features, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv1d(num_features, num_features, kernel_size=3, stride=1, padding=1)
        self.norm1 = nn.LayerNorm(num_features)
        self.norm2 = nn.LayerNorm(num_features)
        self.norm3 = nn.LayerNorm(num_features)
        self.residual_conv = nn.Conv1d(num_channels, num_features, kernel_size=1)

    def forward(self, x):
        residual = self.residual_conv(x)

        x = F.gelu(self.conv1(x))
        x = self.norm1(x)

        x = F.gelu(self.conv2(x))
        x = self.norm2(x)

        x = F.gelu(self.conv3(x))
        x = self.norm3(x)

        x += residual
        return x

class MLPHead(nn.Module):
    def __init__(self, in_features, num_latents, dropout_rate=0.25):
        super(MLPHead, self).__init__()

        self.layer1 = nn.Sequential(
            Rearrange('B C L->B L C'),
            nn.LayerNorm(in_features),
            nn.Linear(in_features, num_latents),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            Rearrange('B L C->B (C L)'),
        )

    def forward(self, x):
        x = self.layer1(x)
        return x
#########################################################################
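
# Shape flow for MetaEEG (a sketch, assuming (B, 63, 250) inputs): attention
# keeps (B, 63, 250); the conv block lifts it to (B, 250, 250); the linear
# projection gives (B, 1024, 250); temporal aggregation collapses time to
# (B, 1024, 1); and each MLP head flattens that to a (B, 1024) embedding.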


def train_model(model, dataloader, optimizer, device, text_features_all, img_features_all):
    model.train()
    text_features_all = text_features_all.to(device).float()  # (n_cls, d)
    # Subsample one image feature per class, assuming the training set stores
    # 10 image features per class in class order.
    img_features_all = (img_features_all[::10]).to(device).float()
    total_loss = 0
    correct = 0
    total = 0
    alpha = 0.99  # weight of the image loss relative to the text loss
    features_list = []  # EEG features collected per batch (not saved here)
    save_features = True
    for batch_idx, (eeg_data, labels, text, text_features, img, img_features) in enumerate(dataloader):
        eeg_data = eeg_data.to(device)
        text_features = text_features.to(device).float()
        img_features = img_features.to(device).float()
        labels = labels.to(device)

        optimizer.zero_grad()
        eeg_features = model(eeg_data).float()
        features_list.append(eeg_features)
        logit_scale = model.logit_scale

        img_loss = model.loss_func(eeg_features, img_features, logit_scale)
        text_loss = model.loss_func(eeg_features, text_features, logit_scale)
        loss = alpha * img_loss + (1 - alpha) * text_loss
        loss.backward()

        optimizer.step()
        total_loss += loss.item()

        # Training accuracy: nearest class by similarity to the per-class
        # image features.
        logits_img = logit_scale * eeg_features @ img_features_all.T
        logits_single = logits_img
        predicted = torch.argmax(logits_single, dim=1)  # (n_batch,) in {0, ..., n_cls-1}

        batch_size = predicted.shape[0]
        total += batch_size
        correct += (predicted == labels).sum().item()

    average_loss = total_loss / (batch_idx + 1)
    accuracy = correct / total
    return average_loss, accuracy
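
# ClipLoss is imported from loss.py. A minimal sketch of what a CLIP-style
# symmetric contrastive loss computes, assuming L2-normalized features and a
# batch of B matched (EEG, image) pairs:
#
#   logits = logit_scale.exp() * eeg_features @ img_features.T  # (B, B)
#   labels = torch.arange(B, device=logits.device)
#   loss = (F.cross_entropy(logits, labels) +
#           F.cross_entropy(logits.T, labels)) / 2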

def evaluate_model(model, dataloader, device, text_features_all, img_features_all, k):
    model.eval()
    text_features_all = text_features_all.to(device).float()
    img_features_all = img_features_all.to(device).float()
    total_loss = 0
    correct = 0
    total = 0
    alpha = 0.99
    top5_correct_count = 0

    all_labels = set(range(text_features_all.size(0)))
    top5_acc = 0
    with torch.no_grad():
        for batch_idx, (eeg_data, labels, text, text_features, img, img_features) in enumerate(dataloader):
            eeg_data = eeg_data.to(device)
            text_features = text_features.to(device).float()
            labels = labels.to(device)
            img_features = img_features.to(device).float()
            eeg_features = model(eeg_data).float()
            logit_scale = model.logit_scale
            img_loss = model.loss_func(eeg_features, img_features, logit_scale)
            text_loss = model.loss_func(eeg_features, text_features, logit_scale)
            loss = img_loss * alpha + text_loss * (1 - alpha)

            total_loss += loss.item()

            for idx, label in enumerate(labels):
                if k not in (2, 4, 10, 200):
                    print("Error: unsupported k.")
                    continue
                # k-way retrieval: score the true class against k-1 random
                # distractor classes.
                possible_classes = list(all_labels - {label.item()})
                selected_classes = random.sample(possible_classes, k - 1) + [label.item()]
                selected_img_features = img_features_all[selected_classes]

                logits_img = logit_scale * eeg_features[idx] @ selected_img_features.T
                predicted_label = selected_classes[torch.argmax(logits_img).item()]
                if predicted_label == label.item():
                    correct += 1

                if k == 200:
                    # Top-5 retrieval accuracy, reported for the 200-way
                    # setting only.
                    _, top5_indices = torch.topk(logits_img, 5, largest=True)
                    if label.item() in [selected_classes[i] for i in top5_indices.tolist()]:
                        top5_correct_count += 1
                total += 1

    average_loss = total_loss / (batch_idx + 1)
    accuracy = correct / total
    top5_acc = top5_correct_count / total
    return average_loss, accuracy, top5_acc
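
# Chance level for the k-way evaluation is 1/k: 50% for k=2, 25% for k=4,
# 10% for k=10, and 0.5% for k=200.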

def main_train_loop(sub, model, train_dataloader, test_dataloader, optimizer, device,
                    text_features_train_all, text_features_test_all, img_features_train_all, img_features_test_all, config, logger=None):
    logger = wandb_logger(config) if logger else None
    if logger:
        logger.watch(model, logger)

    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []
    v2_accs = []
    v4_accs = []
    v10_accs = []

    best_accuracy = 0.0
    best_model_weights = None
    best_epoch_info = {}
    results = []
    for epoch in range(config['epochs']):

        train_loss, train_accuracy = train_model(model, train_dataloader, optimizer, device, text_features_train_all, img_features_train_all)

        # Checkpoint every 5 epochs.
        if epoch % 5 == 0:
            os.makedirs("./models", exist_ok=True)
            if config['insubject']:
                torch.save(model.state_dict(), f"./models/{sub}_{epoch}.pth")
            else:
                torch.save(model.state_dict(), f"./models/across_{epoch}.pth")
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        test_loss, test_accuracy, top5_acc = evaluate_model(model, test_dataloader, device, text_features_test_all, img_features_test_all, k=200)
        _, v2_acc, _ = evaluate_model(model, test_dataloader, device, text_features_test_all, img_features_test_all, k=2)
        _, v4_acc, _ = evaluate_model(model, test_dataloader, device, text_features_test_all, img_features_test_all, k=4)
        _, v10_acc, _ = evaluate_model(model, test_dataloader, device, text_features_test_all, img_features_test_all, k=10)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        v2_accs.append(v2_acc)
        v4_accs.append(v4_acc)
        v10_accs.append(v10_acc)

        # Append results for this epoch
        epoch_results = {
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "test_loss": test_loss,
            "test_accuracy": test_accuracy,
            "v2_acc": v2_acc,
            "v4_acc": v4_acc,
            "v10_acc": v10_acc,
            "top5_acc": top5_acc
        }
        results.append(epoch_results)

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_model_weights = model.state_dict().copy()
            best_epoch_info = {
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "train_accuracy": train_accuracy,
                "test_loss": test_loss,
                "test_accuracy": test_accuracy,
                "v2_acc": v2_acc,
                "v4_acc": v4_acc,
                "v10_acc": v10_acc
            }
        if logger:
            logger.log({
                "Train Loss": train_loss,
                "Train Accuracy": train_accuracy,
                "Test Loss": test_loss,
                "Test Accuracy": test_accuracy,
                "v2 Accuracy": v2_acc,
                "v4 Accuracy": v4_acc,
                "v10 Accuracy": v10_acc,
                "Epoch": epoch
            })

        print(f"Epoch {epoch + 1}/{config['epochs']} - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}, Top5 Accuracy: {top5_acc:.4f}")
        print(f"Epoch {epoch + 1}/{config['epochs']} - v2 Accuracy: {v2_acc} - v4 Accuracy: {v4_acc} - v10 Accuracy: {v10_acc}")

    # Plot loss/accuracy curves and a summary of the best epoch.
    fig, axs = plt.subplots(3, 2, figsize=(10, 15))

    axs[0, 0].plot(train_losses, label='Train Loss')
    axs[0, 0].plot(test_losses, label='Test Loss')
    axs[0, 0].legend()
    axs[0, 0].set_title("Loss Curve")

    axs[0, 1].plot(train_accuracies, label='Train Accuracy')
    axs[0, 1].plot(test_accuracies, label='Test Accuracy')
    axs[0, 1].legend()
    axs[0, 1].set_title("Accuracy Curve")

    axs[1, 0].plot(v2_accs, label='2-class Accuracy')
    axs[1, 0].legend()
    axs[1, 0].set_title("2-Class Accuracy Curve")

    axs[1, 1].plot(v4_accs, label='4-class Accuracy')
    axs[1, 1].legend()
    axs[1, 1].set_title("4-Class Accuracy Curve")

    axs[2, 0].plot(v10_accs, label='10-class Accuracy')
    axs[2, 0].legend()
    axs[2, 0].set_title("10-Class Accuracy Curve")

    info_text = (f"Best Model Info (from Epoch {best_epoch_info['epoch']}):\n"
                 f"Train Loss: {best_epoch_info['train_loss']:.4f}\n"
                 f"Train Accuracy: {best_epoch_info['train_accuracy']:.4f}\n"
                 f"Test Loss: {best_epoch_info['test_loss']:.4f}\n"
                 f"Test Accuracy: {best_epoch_info['test_accuracy']:.4f}\n"
                 f"v2_acc: {best_epoch_info['v2_acc']:.4f}\n"
                 f"v4_acc: {best_epoch_info['v4_acc']:.4f}\n"
                 f"v10_acc: {best_epoch_info['v10_acc']:.4f}")

    axs[2, 1].axis('off')
    axs[2, 1].text(0.5, 0.5, info_text, fontsize=10, ha='center', va='center', transform=axs[2, 1].transAxes)

    plt.tight_layout()

    plt.suptitle('pos_img_text', fontsize=16, y=1.05)
    plt.savefig('pos_img_text.png')
    if logger:
        logger.finish()
    return results

def main():
    parser = argparse.ArgumentParser(description='Train EEG-Image/Text Model')

    parser.add_argument('--data_path', type=str, default="/home/ldy/Workspace/THINGS/Preprocessed_data_250Hz", help='Path to the preprocessed data')
    parser.add_argument('--project', type=str, default="train_pos_img_text_rep", help='Project name')
    parser.add_argument('--entity', type=str, default="sustech_rethinkingbci", help='Entity name')
    parser.add_argument('--name', type=str, default="lr=3e-4_img_pos_pro_eeg", help='Experiment name')
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('--epochs', type=int, default=40, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=1024, help='Batch size')
    # Note: these two flags have plain Python defaults, not parsed booleans;
    # passing "--logger False" on the command line yields a truthy string.
    parser.add_argument('--logger', default=True, help='Enable logging')
    parser.add_argument('--insubject', default=True, help='Train within subject')
    parser.add_argument('--encoder_type', type=str, default='Projector', help='EEG encoder model type; choose from: Projector, EEGConformer_Encoder, MetaEEG, EEGNetv4_Encoder, ShallowFBCSPNet_Encoder, NICE, ATCNet_Encoder, EEGITNet_Encoder')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to use for training (e.g., "cuda:0" or "cpu")')

    args = parser.parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    data_path = args.data_path
    subjects = ['sub-01', 'sub-02', 'sub-03', 'sub-04', 'sub-05', 'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-10']

    for sub in subjects:
        # Re-initialize the model for each subject. Only Projector and MetaEEG
        # take the (channels, samples) shape as constructor arguments; the
        # other encoders hard-code (63, 250) internally.
        if args.encoder_type == 'Projector':
            model = Projector((63, 250))
        elif args.encoder_type == 'MetaEEG':
            model = MetaEEG(63, 250)
        else:
            model = globals()[args.encoder_type]()
        model.to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

        print(f'Processing {sub}: number of parameters:', sum(p.numel() for p in model.parameters()))

        train_dataset = EEGDataset(
            data_path,
            subjects=[sub] if args.insubject else [],
            exclude_subject=sub if not args.insubject else None,
            train=True
        )
        test_dataset = EEGDataset(
            data_path,
            subjects=[sub] if args.insubject else [],
            exclude_subject=sub if not args.insubject else None,
            train=False
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=0,
            drop_last=True
        )
        test_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=True,
            num_workers=0,
            drop_last=True
        )

        text_features_train_all = train_dataset.text_features
        text_features_test_all = test_dataset.text_features
        img_features_train_all = train_dataset.img_features
        img_features_test_all = test_dataset.img_features

        config = vars(args)

        results = main_train_loop(
            sub,
            model,
            train_loader,
            test_loader,
            optimizer,
            device,
            text_features_train_all,
            text_features_test_all,
            img_features_train_all,
            img_features_test_all,
            config,
            logger=args.logger
        )

        # Save per-epoch results to a CSV file.
        current_time = datetime.datetime.now().strftime("%m-%d_%H-%M")
        results_dir = f"./outputs/{args.encoder_type}/{sub}/{current_time}"
        os.makedirs(results_dir, exist_ok=True)
        results_file = f"{results_dir}/{args.encoder_type}_{'cross_exclude_' if not args.insubject else ''}{sub}.csv"

        with open(results_file, 'w', newline='') as file:
            writer = csv.DictWriter(file, fieldnames=results[0].keys())
            writer.writeheader()
            writer.writerows(results)
        print(f'Results saved to {results_file}')

if __name__ == '__main__':
    main()
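
# Example invocation (values shown are the argparse defaults):
#   python Retrieval/contrast_retrieval.py --encoder_type Projector \
#       --data_path /home/ldy/Workspace/THINGS/Preprocessed_data_250Hz \
#       --lr 3e-4 --epochs 40 --batch_size 1024 --device cuda:0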