# Retrieval/ATMS_retrieval.py
import argparse
import csv
import datetime
import itertools
import math
import os
import random
import re
from itertools import combinations

import clip
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import tqdm
from braindecode.models import EEGNetv4, ATCNet, EEGConformer, EEGITNet, ShallowFBCSPNet
from einops.layers.torch import Rearrange, Reduce
from sklearn.metrics import confusion_matrix
from torch import Tensor
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch.optim import Adam, AdamW
from torch.utils.data import DataLoader, Dataset

from eegdatasets_leaveone import EEGDataset
from loss import ClipLoss
from subject_layers.Embed import DataEmbedding
from subject_layers.SelfAttention_Family import FullAttention, AttentionLayer
from subject_layers.Transformer_EncDec import Encoder, EncoderLayer

# Configure Weights & Biases before the wandb-backed logger is imported
os.environ["WANDB_API_KEY"] = "KEY"
os.environ["WANDB_MODE"] = 'offline'
from util import wandb_logger
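# WANDB_MODE='offline' keeps runs local (no network access needed); a stored
# run can be uploaded later with the `wandb sync <run_dir>` CLI command.
# Replace "KEY" with a real API key only if switching back to online logging.
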

class Config:
    def __init__(self):
        self.task_name = 'classification'  # Example task name
        self.seq_len = 250                 # Sequence length (time samples per trial)
        self.pred_len = 250                # Prediction length
        self.output_attention = False      # Whether to output attention weights
        self.d_model = 250                 # Model (token embedding) dimension
        self.embed = 'timeF'               # Time encoding method
        self.freq = 'h'                    # Time frequency
        self.dropout = 0.25                # Dropout rate
        self.factor = 1                    # Attention scaling factor
        self.n_heads = 4                   # Number of attention heads
        self.e_layers = 1                  # Number of encoder layers
        self.d_ff = 256                    # Feedforward network dimension
        self.activation = 'gelu'           # Activation function
        self.enc_in = 63                   # Encoder input dimension (number of EEG channels)

class iTransformer(nn.Module):
    def __init__(self, configs, joint_train=False, num_subjects=10):
        super(iTransformer, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.output_attention = configs.output_attention
        # Embedding: maps each channel's seq_len-sample series to a
        # d_model-dimensional token (channel-as-token, iTransformer-style)
        self.enc_embedding = DataEmbedding(configs.seq_len, configs.d_model, configs.embed, configs.freq, configs.dropout, joint_train=False, num_subjects=num_subjects)
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=configs.output_attention),
                        configs.d_model, configs.n_heads
                    ),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for _ in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )

    def forward(self, x_enc, x_mark_enc, subject_ids=None):
        # Embedding
        enc_out = self.enc_embedding(x_enc, x_mark_enc, subject_ids)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)
        # Keep only the 63 channel tokens, dropping any extra tokens
        enc_out = enc_out[:, :63, :]
        return enc_out

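# Shape sketch for the forward pass (the exact token count coming out of
# DataEmbedding depends on subject_layers internals; the [:, :63, :] slice
# suggests it may emit extra covariate/subject tokens beyond the 63 channels):
#   x_enc (B, 63, 250) -> enc_embedding (B, n_tokens, 250)
#                      -> encoder       (B, n_tokens, 250)
#                      -> slice         (B, 63, 250)
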
class PatchEmbedding(nn.Module):
    def __init__(self, emb_size=40):
        super().__init__()
        # Revised from ShallowNet: temporal conv + pooling, then spatial conv
        self.tsconv = nn.Sequential(
            nn.Conv2d(1, 40, (1, 25), stride=(1, 1)),
            nn.AvgPool2d((1, 51), (1, 5)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.Conv2d(40, 40, (63, 1), stride=(1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.Dropout(0.5),
        )

        self.projection = nn.Sequential(
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        x = x.unsqueeze(1)      # (B, 63, 250) -> (B, 1, 63, 250)
        x = self.tsconv(x)      # -> (B, 40, 1, 36)
        x = self.projection(x)  # -> (B, 36, 40)
        return x

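# Width bookkeeping for the defaults above (250 time samples per trial):
#   temporal conv, kernel (1, 25), stride 1:  250 - 25 + 1              = 226
#   avg-pool, kernel (1, 51), stride (1, 5):  floor((226 - 51) / 5) + 1 = 36
#   spatial conv, kernel (63, 1):             collapses the channel axis to 1
# so Enc_eeg below flattens 40 * 36 = 1440 features per trial, which is
# exactly the embedding_dim expected by Proj_eeg. A quick sanity check:
#   >>> PatchEmbedding()(torch.randn(2, 63, 250)).shape
#   torch.Size([2, 36, 40])
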
class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x


class FlattenHead(nn.Sequential):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x.contiguous().view(x.size(0), -1)
        return x


class Enc_eeg(nn.Sequential):
    def __init__(self, emb_size=40, **kwargs):
        super().__init__(
            PatchEmbedding(emb_size),
            FlattenHead()
        )


class Proj_eeg(nn.Sequential):
    def __init__(self, embedding_dim=1440, proj_dim=1024, drop_proj=0.5):
        super().__init__(
            nn.Linear(embedding_dim, proj_dim),
            ResidualAdd(nn.Sequential(
                nn.GELU(),
                nn.Linear(proj_dim, proj_dim),
                nn.Dropout(drop_proj),
            )),
            nn.LayerNorm(proj_dim),
        )

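# Proj_eeg projects the flattened 1440-d EEG embedding into the shared
# retrieval space. proj_dim=1024 is assumed to match the dimensionality of
# the precomputed CLIP image/text features loaded by EEGDataset; adjust it if
# a CLIP variant with a different feature width was used. For example:
#   >>> Proj_eeg()(torch.randn(2, 1440)).shape
#   torch.Size([2, 1024])
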
class ATMS(nn.Module):
    def __init__(self, num_channels=63, sequence_length=250, num_subjects=2, num_features=64, num_latents=1024, num_blocks=1):
        super(ATMS, self).__init__()
        default_config = Config()
        self.encoder = iTransformer(default_config)
        self.subject_wise_linear = nn.ModuleList([nn.Linear(default_config.d_model, sequence_length) for _ in range(num_subjects)])
        self.enc_eeg = Enc_eeg()
        self.proj_eeg = Proj_eeg()
        # Learnable temperature, initialized to CLIP's 1/0.07
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.loss_func = ClipLoss()

    def forward(self, x, subject_ids):
        x = self.encoder(x, None, subject_ids)  # (B, 63, 250)
        # subject_wise_linear is kept for compatibility but not applied here
        eeg_embedding = self.enc_eeg(x)         # (B, 1440)
        out = self.proj_eeg(eeg_embedding)      # (B, 1024)
        return out

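# End-to-end shape sketch (illustrative only; running it requires the
# subject_layers package and its DataEmbedding to accept these inputs):
#   >>> model = ATMS()
#   >>> eeg = torch.randn(2, 63, 250)            # (batch, channels, samples)
#   >>> sids = torch.zeros(2, dtype=torch.long)  # per-trial subject ids
#   >>> model(eeg, sids).shape
#   torch.Size([2, 1024])
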
def extract_id_from_string(s):
    """Return the trailing integer of a subject string, or None if absent."""
    match = re.search(r'\d+$', s)
    if match:
        return int(match.group())
    return None

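# Example: extract_id_from_string('sub-08') -> 8; used below to build the
# per-batch subject_ids tensor for the subject-aware embedding.
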
def train_model(sub, eeg_model, dataloader, optimizer, device, text_features_all, img_features_all, config):
    eeg_model.train()
    text_features_all = text_features_all.to(device).float()  # (n_cls, d)
    # Keep one image feature per class (assumes 10 exemplar features per
    # class, stored contiguously in class order)
    img_features_all = (img_features_all[::10]).to(device).float()
    total_loss = 0
    correct = 0
    total = 0
    alpha = 0.99  # weight of the image loss vs. the text loss
    features_list = []  # List to store features
    for batch_idx, (eeg_data, labels, text, text_features, img, img_features) in enumerate(dataloader):
        eeg_data = eeg_data.to(device)
        text_features = text_features.to(device).float()
        img_features = img_features.to(device).float()
        labels = labels.to(device)

        optimizer.zero_grad()

        batch_size = eeg_data.size(0)
        subject_id = extract_id_from_string(sub)
        subject_ids = torch.full((batch_size,), subject_id, dtype=torch.long).to(device)
        eeg_features = eeg_model(eeg_data, subject_ids).float()

        # Store a detached copy; the features are only returned for inspection
        features_list.append(eeg_features.detach())
        logit_scale = eeg_model.logit_scale

        img_loss = eeg_model.loss_func(eeg_features, img_features, logit_scale)
        text_loss = eeg_model.loss_func(eeg_features, text_features, logit_scale)
        loss = alpha * img_loss + (1 - alpha) * text_loss
        loss.backward()

        optimizer.step()
        total_loss += loss.item()

        # Training retrieval accuracy against all class image features
        logits_img = logit_scale * eeg_features @ img_features_all.T
        logits_single = logits_img
        predicted = torch.argmax(logits_single, dim=1)  # (batch,) in {0, ..., n_cls-1}

        batch_size = predicted.shape[0]
        total += batch_size
        correct += (predicted == labels).sum().item()
        del eeg_data, eeg_features, img_features
    average_loss = total_loss / (batch_idx + 1)
    accuracy = correct / total
    return average_loss, accuracy, torch.cat(features_list, dim=0)

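# For reference, a minimal sketch of a symmetric CLIP-style contrastive loss.
# This illustrates what the imported ClipLoss is assumed to compute (an
# open_clip-style InfoNCE over in-batch pairs, with logit_scale stored as a
# log-temperature); the repo's own `loss.ClipLoss` remains the one used above.
def _clip_loss_sketch(eeg_feats: Tensor, img_feats: Tensor, logit_scale: Tensor) -> Tensor:
    eeg_feats = F.normalize(eeg_feats, dim=-1)
    img_feats = F.normalize(img_feats, dim=-1)
    logits = logit_scale.exp() * eeg_feats @ img_feats.T  # (B, B) similarities
    targets = torch.arange(logits.size(0), device=logits.device)
    # Matching pairs sit on the diagonal; average both retrieval directions
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)) / 2
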
def evaluate_model(sub, eeg_model, dataloader, device, text_features_all, img_features_all, k, config):
    eeg_model.eval()

    text_features_all = text_features_all.to(device).float()
    img_features_all = img_features_all.to(device).float()
    total_loss = 0
    correct = 0
    total = 0
    alpha = 0.99
    top5_correct_count = 0
    # Get all unique classes
    all_labels = set(range(text_features_all.size(0)))
    with torch.no_grad():
        for batch_idx, (eeg_data, labels, text, text_features, img, img_features) in enumerate(dataloader):
            eeg_data = eeg_data.to(device)
            text_features = text_features.to(device).float()
            labels = labels.to(device)
            img_features = img_features.to(device).float()

            batch_size = eeg_data.size(0)
            subject_id = extract_id_from_string(sub)
            subject_ids = torch.full((batch_size,), subject_id, dtype=torch.long).to(device)
            eeg_features = eeg_model(eeg_data, subject_ids)

            logit_scale = eeg_model.logit_scale
            img_loss = eeg_model.loss_func(eeg_features, img_features, logit_scale)
            text_loss = eeg_model.loss_func(eeg_features, text_features, logit_scale)
            loss = img_loss * alpha + text_loss * (1 - alpha)

            total_loss += loss.item()

            for idx, label in enumerate(labels):
                if k not in (2, 4, 10, 50, 100, 200):
                    print("Error.")
                    continue

                # k-way retrieval: the true class plus k-1 random distractors.
                # Sample the candidate set once so the class list and the
                # gathered image features stay aligned.
                possible_classes = list(all_labels - {label.item()})
                selected_classes = random.sample(possible_classes, k - 1) + [label.item()]
                selected_img_features = img_features_all[selected_classes]

                logits_single = logit_scale * eeg_features[idx] @ selected_img_features.T
                predicted_label = selected_classes[torch.argmax(logits_single).item()]
                if predicted_label == label.item():
                    correct += 1

                # Top-5 accuracy is only tracked for the larger candidate sets
                if k >= 50:
                    _, top5_indices = torch.topk(logits_single, 5, largest=True)
                    if label.item() in [selected_classes[i] for i in top5_indices.tolist()]:
                        top5_correct_count += 1
                total += 1
            del eeg_data, eeg_features, img_features
    average_loss = total_loss / (batch_idx + 1)
    accuracy = correct / total
    top5_acc = top5_correct_count / total
    return average_loss, accuracy, top5_acc

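# Chance level for the k-way protocol above: top-1 = 1/k (k=2 -> 50%,
# k=10 -> 10%, k=200 -> 0.5%) and top-5 = 5/k for the k >= 50 settings
# (k=50 -> 10%, k=100 -> 5%, k=200 -> 2.5%). The v2/v4/v10/v50/v100
# accuracies reported below should be read against these baselines.
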
def main_train_loop(sub, current_time, eeg_model, train_dataloader, test_dataloader, optimizer, device, text_features_train_all, text_features_test_all, img_features_train_all, img_features_test_all, config, logger=None):
    logger = wandb_logger(config) if logger else None
    if logger:
        logger.watch(eeg_model, logger)
    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []
    v2_accs = []
    v4_accs = []
    v10_accs = []

    best_accuracy = 0.0
    best_epoch_info = {}
    results = []  # List to store results for each epoch

    for epoch in range(config.epochs):
        # Train the model
        train_loss, train_accuracy, features_tensor = train_model(sub, eeg_model, train_dataloader, optimizer, device, text_features_train_all, img_features_train_all, config=config)
        # Save a checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            if config.insubject:
                save_dir = f"./models/contrast/{config.encoder_type}/{sub}/{current_time}"
            else:
                save_dir = f"./models/contrast/across/{config.encoder_type}/{current_time}"
            os.makedirs(save_dir, exist_ok=True)
            file_path = f"{save_dir}/{epoch+1}.pth"
            torch.save(eeg_model.state_dict(), file_path)
            print(f"model saved in {file_path}!")
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        # Evaluate the model at several k-way difficulties
        test_loss, test_accuracy, top5_acc = evaluate_model(sub, eeg_model, test_dataloader, device, text_features_test_all, img_features_test_all, k=200, config=config)
        _, v2_acc, _ = evaluate_model(sub, eeg_model, test_dataloader, device, text_features_test_all, img_features_test_all, k=2, config=config)
        _, v4_acc, _ = evaluate_model(sub, eeg_model, test_dataloader, device, text_features_test_all, img_features_test_all, k=4, config=config)
        _, v10_acc, _ = evaluate_model(sub, eeg_model, test_dataloader, device, text_features_test_all, img_features_test_all, k=10, config=config)
        _, v50_acc, v50_top5_acc = evaluate_model(sub, eeg_model, test_dataloader, device, text_features_test_all, img_features_test_all, k=50, config=config)
        _, v100_acc, v100_top5_acc = evaluate_model(sub, eeg_model, test_dataloader, device, text_features_test_all, img_features_test_all, k=100, config=config)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        v2_accs.append(v2_acc)
        v4_accs.append(v4_acc)
        v10_accs.append(v10_acc)

        # Append results for this epoch
        epoch_results = {
            "epoch": epoch + 1,
            "test_loss": test_loss,
            "test_accuracy": test_accuracy,
            "v2_acc": v2_acc,
            "v4_acc": v4_acc,
            "v10_acc": v10_acc,
            "top5_acc": top5_acc,
            "v50_acc": v50_acc,
            "v100_acc": v100_acc,
            "v50_top5_acc": v50_top5_acc,
            "v100_top5_acc": v100_top5_acc
        }
        results.append(epoch_results)

        # Track the best epoch by 200-way test accuracy
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_epoch_info = {
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "train_accuracy": train_accuracy,
                "test_loss": test_loss,
                "test_accuracy": test_accuracy,
                "v2_acc": v2_acc,
                "v4_acc": v4_acc,
                "v10_acc": v10_acc
            }
        if logger:
            logger.log({
                "Train Loss": train_loss,
                "Train Accuracy": train_accuracy,
                "Test Loss": test_loss,
                "Test Accuracy": test_accuracy,
                "v2 Accuracy": v2_acc,
                "v4 Accuracy": v4_acc,
                "v10 Accuracy": v10_acc,
                "Epoch": epoch
            })

        print(f"Epoch {epoch + 1}/{config.epochs} - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}, Top5 Accuracy: {top5_acc:.4f}")
        print(f"Epoch {epoch + 1}/{config.epochs} - v2 Accuracy: {v2_acc} - v4 Accuracy: {v4_acc} - v10 Accuracy: {v10_acc} - v50 Accuracy: {v50_acc} - v100 Accuracy: {v100_acc}")

    # Plot training curves and the best-epoch summary in 5 subplots
    fig, axs = plt.subplots(3, 2, figsize=(10, 15))

    # Loss curve
    axs[0, 0].plot(train_losses, label='Train Loss')
    axs[0, 0].plot(test_losses, label='Test Loss')
    axs[0, 0].legend()
    axs[0, 0].set_title("Loss Curve")

    # Overall accuracy curve
    axs[0, 1].plot(train_accuracies, label='Train Accuracy')
    axs[0, 1].plot(test_accuracies, label='Test Accuracy')
    axs[0, 1].legend()
    axs[0, 1].set_title("Accuracy Curve")

    # 2-class accuracy plot
    axs[1, 0].plot(v2_accs, label='2-class Accuracy')
    axs[1, 0].legend()
    axs[1, 0].set_title("2-Class Accuracy Curve")

    # 4-class accuracy plot
    axs[1, 1].plot(v4_accs, label='4-class Accuracy')
    axs[1, 1].legend()
    axs[1, 1].set_title("4-Class Accuracy Curve")

    # 10-class accuracy plot
    axs[2, 0].plot(v10_accs, label='10-class Accuracy')
    axs[2, 0].legend()
    axs[2, 0].set_title("10-Class Accuracy Curve")

    # Construct the annotation text for the best epoch
    info_text = (f"Best Model Info (from Epoch {best_epoch_info['epoch']}):\n"
                 f"Train Loss: {best_epoch_info['train_loss']:.4f}\n"
                 f"Train Accuracy: {best_epoch_info['train_accuracy']:.4f}\n"
                 f"Test Loss: {best_epoch_info['test_loss']:.4f}\n"
                 f"Test Accuracy: {best_epoch_info['test_accuracy']:.4f}\n"
                 f"v2_acc: {best_epoch_info['v2_acc']:.4f}\n"
                 f"v4_acc: {best_epoch_info['v4_acc']:.4f}\n"
                 f"v10_acc: {best_epoch_info['v10_acc']:.4f}")

    axs[2, 1].axis('off')
    axs[2, 1].text(0.5, 0.5, info_text, fontsize=10, ha='center', va='center', transform=axs[2, 1].transAxes)

    plt.tight_layout()

    # Add main title and save the figure
    plt.suptitle('pos_img_text', fontsize=16, y=1.05)
    plt.savefig('pos_img_text.png')
    if logger:
        logger.finish()
    return results


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='EEG Transformer Training Script')
    parser.add_argument('--data_path', type=str, default="/root/autodl-tmp/THINGS/Preprocessed_data_250Hz", help='Path to the EEG dataset')
    parser.add_argument('--output_dir', type=str, default='./outputs/contrast', help='Directory to save output results')
    parser.add_argument('--project', type=str, default="train_pos_img_text_rep", help='WandB project name')
    parser.add_argument('--entity', type=str, default="sustech_rethinkingbci", help='WandB entity name')
    parser.add_argument('--name', type=str, default="lr=3e-4_img_pos_pro_eeg", help='Experiment name')
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('--epochs', type=int, default=40, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
    # Caveat: argparse's type=bool treats any non-empty string as True, so
    # `--logger False` still enables logging; only an empty string disables it.
    parser.add_argument('--logger', type=bool, default=True, help='Enable WandB logging')
    parser.add_argument('--gpu', type=str, default='cuda:0', help='GPU device to use')
    parser.add_argument('--device', type=str, choices=['cpu', 'gpu'], default='gpu', help='Device to run on (cpu or gpu)')
    parser.add_argument('--insubject', type=bool, default=True, help='In-subject mode or cross-subject mode (same type=bool caveat as --logger)')
    parser.add_argument('--encoder_type', type=str, default='ATMS', help='Encoder type')
    parser.add_argument('--subjects', nargs='+', default=['sub-01', 'sub-02', 'sub-03', 'sub-04', 'sub-05', 'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-10'], help='List of subject IDs (default: sub-01 to sub-10)')
    args = parser.parse_args()

    # Set device based on the argument
    if args.device == 'gpu' and torch.cuda.is_available():
        device = torch.device(args.gpu)
    else:
        device = torch.device('cpu')

    subjects = args.subjects
    current_time = datetime.datetime.now().strftime("%m-%d_%H-%M")

    for sub in subjects:
        # Look up the encoder class by name (e.g. ATMS) and instantiate it
        eeg_model = globals()[args.encoder_type]()
        eeg_model.to(device)

        optimizer = AdamW(eeg_model.parameters(), lr=args.lr)

        if args.insubject:
            # Train and test on the same subject
            train_dataset = EEGDataset(args.data_path, subjects=[sub], train=True)
            test_dataset = EEGDataset(args.data_path, subjects=[sub], train=False)
        else:
            # Leave-one-subject-out: train on the others, test on `sub`
            train_dataset = EEGDataset(args.data_path, exclude_subject=sub, subjects=subjects, train=True)
            test_dataset = EEGDataset(args.data_path, exclude_subject=sub, subjects=subjects, train=False)

        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=0, drop_last=True)

        text_features_train_all = train_dataset.text_features
        text_features_test_all = test_dataset.text_features
        img_features_train_all = train_dataset.img_features
        img_features_test_all = test_dataset.img_features

        results = main_train_loop(sub, current_time, eeg_model, train_loader, test_loader, optimizer, device,
                                  text_features_train_all, text_features_test_all, img_features_train_all, img_features_test_all, config=args, logger=args.logger)

        # Save results to a CSV file
        results_dir = os.path.join(args.output_dir, args.encoder_type, sub, current_time)
        os.makedirs(results_dir, exist_ok=True)

        if args.insubject:
            results_file = f"{results_dir}/{args.encoder_type}_{sub}.csv"
        else:
            results_file = f"{results_dir}/{args.encoder_type}_cross_exclude_{sub}.csv"

        with open(results_file, 'w', newline='') as file:
            writer = csv.DictWriter(file, fieldnames=results[0].keys())
            writer.writeheader()
            writer.writerows(results)
            print(f'Results saved to {results_file}')


if __name__ == '__main__':
    main()
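
# Example invocation (paths and flags are illustrative; adjust to your setup):
#   python Retrieval/ATMS_retrieval.py \
#       --data_path /path/to/THINGS/Preprocessed_data_250Hz \
#       --subjects sub-01 sub-02 --epochs 40 --batch_size 64 --gpu cuda:0
# Because of the type=bool caveat noted above, toggle in-subject vs.
# cross-subject mode by editing the --insubject default rather than passing
# `--insubject False` on the command line.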