Retrieval/ATME_retrieval.py
import argparse
import csv
import datetime
import math
import os
import random

import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader

os.environ["WANDB_API_KEY"] = "KEY"
os.environ["WANDB_MODE"] = 'offline'
from itertools import combinations

import clip
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms
import tqdm
from torch import Tensor
from torch.utils.data import Dataset
from BrainAligning_retrieval.eegdatasets_leaveone import EEGDataset
# Assumed source of EEGNetv4: the parameter set used below (F1, D, F2,
# kernel_length, third_kernel_size, ...) matches braindecode's EEGNetv4.
from braindecode.models import EEGNetv4
from einops.layers.torch import Rearrange, Reduce
from lavis.models.clip_models.loss import ClipLoss
from sklearn.metrics import confusion_matrix
from utils import wandb_logger

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # For odd d_model (here 63 EEG channels) the even-indexed columns get one
        # more entry than the odd-indexed ones, hence the slice on the cos branch.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])

        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [time_length, batch_size, channel]; repeat the encoding across the batch
        pe = self.pe[:x.size(0), :].unsqueeze(1).repeat(1, x.size(1), 1)
        x = x + pe
        return x

class EEGAttention(nn.Module):
    def __init__(self, channel, d_model, nhead):
        super(EEGAttention, self).__init__()
        self.pos_encoder = PositionalEncoding(d_model)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)
        self.channel = channel
        self.d_model = d_model

    def forward(self, src):
        src = src.permute(2, 0, 1)  # Change shape to [time_length, batch_size, channel]
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src)
        return output.permute(1, 2, 0)  # Change shape back to [batch_size, channel, time_length]

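# Shape sketch for EEGAttention (assuming the [B, 63, 250] inputs used below):
#   [B, 63, 250] --permute--> [250, B, 63] --encoder--> [250, B, 63] --permute--> [B, 63, 250]
# i.e. self-attention runs over time steps, with the 63 channels as d_model.
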
class PatchEmbedding(nn.Module):
    def __init__(self, emb_size=40):
        super().__init__()
        # Temporal-spatial conv backbone. (The original comment said "revised from
        # shallownet"; here the backbone is an EEGNetv4 whose 1440-way head doubles
        # as a 1440-d embedding.)
        self.shape = (63, 250)
        self.tsconv = EEGNetv4(
            in_chans=self.shape[0],
            n_classes=1440,
            input_window_samples=self.shape[1],
            final_conv_length='auto',
            pool_mode='mean',
            F1=8,
            D=20,
            F2=160,
            kernel_length=4,
            third_kernel_size=(4, 2),
            drop_prob=0.25
        )

    def forward(self, x: Tensor) -> Tensor:
        x = x.unsqueeze(3)  # add a trailing dim: this EEGNetv4 consumes [B, C, T, 1] input
        # print("x", x.shape)
        x = self.tsconv(x)
        return x

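# Assuming EEGNetv4 returns one 1440-d vector per trial (n_classes=1440),
# FlattenHead below is effectively a no-op here, and Enc_eeg's output matches
# Proj_eeg's embedding_dim=1440 without any reshaping.
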
class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x

class FlattenHead(nn.Sequential):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x.contiguous().view(x.size(0), -1)
        return x

class Enc_eeg(nn.Sequential):
    def __init__(self, emb_size=40, **kwargs):
        super().__init__(
            PatchEmbedding(emb_size),
            FlattenHead()
        )

class Proj_eeg(nn.Sequential):
    def __init__(self, embedding_dim=1440, proj_dim=1024, drop_proj=0.5):
        super().__init__(
            nn.Linear(embedding_dim, proj_dim),
            ResidualAdd(nn.Sequential(
                nn.GELU(),
                nn.Linear(proj_dim, proj_dim),
                nn.Dropout(drop_proj),
            )),
            nn.LayerNorm(proj_dim),
        )

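# Proj_eeg projects the 1440-d EEG embedding into the same 1024-d space as the
# precomputed image/text features (a dimension consistent with a CLIP ViT-H-style
# encoder; the feature extraction itself happens in the dataset code, not here).
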
class ATM_E(nn.Module):
    def __init__(self, num_channels=63, sequence_length=250, num_subjects=1, num_features=64, num_latents=1024, num_blocks=1):
        super(ATM_E, self).__init__()
        self.attention_model = EEGAttention(num_channels, num_channels, nhead=1)
        self.subject_wise_linear = nn.ModuleList([nn.Linear(sequence_length, sequence_length) for _ in range(num_subjects)])
        self.enc_eeg = Enc_eeg()
        self.proj_eeg = Proj_eeg()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.loss_func = ClipLoss()

    def forward(self, x):
        x = self.attention_model(x)
        # print(f'After attention shape: {x.shape}')

        x = self.subject_wise_linear[0](x)
        # print(f'After subject-specific linear transformation shape: {x.shape}')

        eeg_embedding = self.enc_eeg(x)
        # print(f'After enc_eeg shape: {eeg_embedding.shape}')

        out = self.proj_eeg(eeg_embedding)
        return out

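# Minimal end-to-end shape check (a sketch, assuming EEGNetv4 emits one 1440-d
# vector per trial; not executed as part of training):
#   model = ATM_E()
#   out = model(torch.randn(2, 63, 250))
#   assert out.shape == (2, 1024)
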
def train_model(model, dataloader, optimizer, device, text_features_all, img_features_all):
    model.train()
    text_features_all = text_features_all.to(device).float()  # (n_cls, d)
    # keep one image feature per class; assumes the training set stores 10
    # exemplar features per class in class order
    img_features_all = (img_features_all[::10]).to(device).float()
    total_loss = 0
    correct = 0
    total = 0
    alpha = 0.99  # weight of the image loss vs. the text loss
    features_list = []  # List to store features
    save_features = True
    for batch_idx, (eeg_data, labels, text, text_features, img, img_features) in enumerate(dataloader):
        eeg_data = eeg_data.to(device)
        text_features = text_features.to(device).float()
        img_features = img_features.to(device).float()
        labels = labels.to(device)

        optimizer.zero_grad()
        eeg_features = model(eeg_data).float()
        features_list.append(eeg_features)
        logit_scale = model.logit_scale

        img_loss = model.loss_func(eeg_features, img_features, logit_scale)
        text_loss = model.loss_func(eeg_features, text_features, logit_scale)
        # loss = img_loss + text_loss
        # print("text_loss", text_loss)
        # print("img_loss", img_loss)
        loss = alpha * img_loss + (1 - alpha) * text_loss
        loss.backward()

        optimizer.step()
        total_loss += loss.item()

        # logits = logit_scale * eeg_features @ text_features_all.T # (n_batch, n_cls)

        logits_img = logit_scale * eeg_features @ img_features_all.T
        # logits_text = logit_scale * eeg_features @ text_features_all.T
        # logits_single = (logits_text + logits_img) / 2.0
        logits_single = logits_img
        predicted = torch.argmax(logits_single, dim=1)  # (n_batch, ) \in {0, 1, ..., n_cls-1}

        batch_size = predicted.shape[0]
        total += batch_size
        correct += (predicted == labels).sum().item()

    average_loss = total_loss / (batch_idx + 1)
    accuracy = correct / total
    return average_loss, accuracy

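# Note on the training objective above: with alpha = 0.99 the loss is almost
# entirely the EEG-image contrastive term, L = 0.99 * L_img + 0.01 * L_text,
# and training accuracy is class retrieval against one image feature per class.
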
def evaluate_model(model, dataloader, device, text_features_all, img_features_all, k):
    model.eval()
    text_features_all = text_features_all.to(device).float()
    img_features_all = img_features_all.to(device).float()
    total_loss = 0
    correct = 0
    total = 0
    alpha = 0.99
    top5_correct_count = 0

    all_labels = set(range(text_features_all.size(0)))
    top5_acc = 0
    with torch.no_grad():
        for batch_idx, (eeg_data, labels, text, text_features, img, img_features) in enumerate(dataloader):
            eeg_data = eeg_data.to(device)
            text_features = text_features.to(device).float()
            labels = labels.to(device)
            img_features = img_features.to(device).float()
            eeg_features = model(eeg_data).float()
            logit_scale = model.logit_scale
            # print(eeg_features.type, text_features.type, img_features.type)
            img_loss = model.loss_func(eeg_features, img_features, logit_scale)
            text_loss = model.loss_func(eeg_features, text_features, logit_scale)
            loss = img_loss * alpha + text_loss * (1 - alpha)

            total_loss += loss.item()

            for idx, label in enumerate(labels):
                # build a k-way candidate set: k-1 random distractor classes plus the true class
                possible_classes = list(all_labels - {label.item()})
                selected_classes = random.sample(possible_classes, k - 1) + [label.item()]
                # selected_text_features = text_features_all[selected_classes]
                selected_img_features = img_features_all[selected_classes]
                if k == 200:
                    logits_img = logit_scale * eeg_features[idx] @ selected_img_features.T
                    # logits_text = logit_scale * eeg_features[idx] @ selected_text_features.T
                    # logits_single = (logits_text + logits_img) / 2.0
                    logits_single = logits_img

                    predicted_label = selected_classes[torch.argmax(logits_single).item()]
                    if predicted_label == label.item():
                        correct += 1

                    # top-5: is the true class among the 5 highest-scoring candidates?
                    _, top5_indices = torch.topk(logits_single, 5, largest=True)
                    if label.item() in [selected_classes[i] for i in top5_indices.tolist()]:
                        top5_correct_count += 1
                    total += 1

                elif k == 2 or k == 4 or k == 10:
                    logits_img = logit_scale * eeg_features[idx] @ selected_img_features.T
                    # logits_text = logit_scale * eeg_features[idx] @ selected_text_features.T
                    # logits_single = (logits_text + logits_img) / 2.0
                    logits_single = logits_img

                    predicted_label = selected_classes[torch.argmax(logits_single).item()]
                    if predicted_label == label.item():
                        correct += 1
                    total += 1
                else:
                    print(f"Error: unsupported k={k}.")

    average_loss = total_loss / (batch_idx + 1)
    accuracy = correct / total
    top5_acc = top5_correct_count / total
    return average_loss, accuracy, top5_acc

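# Note on the evaluation protocol above: each test trial is scored against k
# candidates (the true class plus k-1 random distractors), so chance level is
# 1/k: 0.5% for k=200, 50% for k=2, 25% for k=4, 10% for k=10. Top-5 accuracy
# is only meaningful (and only tracked) for the k=200 setting.
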
def main_train_loop(sub, model, train_dataloader, test_dataloader, optimizer, device,
                    text_features_train_all, text_features_test_all, img_features_train_all, img_features_test_all, config, logger=None):
    logger = wandb_logger(config) if logger else None
    if logger:
        logger.watch(model, logger)

    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []
    v2_accs = []
    v4_accs = []
    v10_accs = []

    best_accuracy = 0.0
    best_model_weights = None
    best_epoch_info = {}
    results = []
    for epoch in range(config['epochs']):

        train_loss, train_accuracy = train_model(model, train_dataloader, optimizer, device, text_features_train_all, img_features_train_all)

        if epoch % 5 == 0:
            os.makedirs("./models", exist_ok=True)
            if config['insubject']:
                torch.save(model.state_dict(), f"./models/{sub}_{epoch}.pth")
            else:
                torch.save(model.state_dict(), f"./models/across_{epoch}.pth")
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        test_loss, test_accuracy, top5_acc = evaluate_model(model, test_dataloader, device, text_features_test_all, img_features_test_all, k=200)
        _, v2_acc, _ = evaluate_model(model, test_dataloader, device, text_features_test_all, img_features_test_all, k=2)
        _, v4_acc, _ = evaluate_model(model, test_dataloader, device, text_features_test_all, img_features_test_all, k=4)
        _, v10_acc, _ = evaluate_model(model, test_dataloader, device, text_features_test_all, img_features_test_all, k=10)
        test_losses.append(test_loss)
        test_accuracies.append(test_accuracy)
        v2_accs.append(v2_acc)
        v4_accs.append(v4_acc)
        v10_accs.append(v10_acc)
        # Append results for this epoch
        epoch_results = {
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "test_loss": test_loss,
            "test_accuracy": test_accuracy,
            "v2_acc": v2_acc,
            "v4_acc": v4_acc,
            "v10_acc": v10_acc,
            "top5_acc": top5_acc
        }
        results.append(epoch_results)

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            best_model_weights = model.state_dict().copy()
            best_epoch_info = {
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "train_accuracy": train_accuracy,
                "test_loss": test_loss,
                "test_accuracy": test_accuracy,
                "v2_acc": v2_acc,
                "v4_acc": v4_acc,
                "v10_acc": v10_acc
            }
        if logger:
            logger.log({
                "Train Loss": train_loss,
                "Train Accuracy": train_accuracy,
                "Test Loss": test_loss,
                "Test Accuracy": test_accuracy,
                "v2 Accuracy": v2_acc,
                "v4 Accuracy": v4_acc,
                "v10 Accuracy": v10_acc,
                "Epoch": epoch
            })

        print(f"Epoch {epoch + 1}/{config['epochs']} - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}, Top5 Accuracy: {top5_acc:.4f}")
        print(f"Epoch {epoch + 1}/{config['epochs']} - v2 Accuracy:{v2_acc} - v4 Accuracy:{v4_acc} - v10 Accuracy:{v10_acc}")

    # model.load_state_dict(best_model_weights)
    # torch.save(model.state_dict(), '{train_pos_img_text}.pth')

    fig, axs = plt.subplots(3, 2, figsize=(10, 15))

    axs[0, 0].plot(train_losses, label='Train Loss')
    axs[0, 0].plot(test_losses, label='Test Loss')
    axs[0, 0].legend()
    axs[0, 0].set_title("Loss Curve")

    axs[0, 1].plot(train_accuracies, label='Train Accuracy')
    axs[0, 1].plot(test_accuracies, label='Test Accuracy')
    axs[0, 1].legend()
    axs[0, 1].set_title("Accuracy Curve")

    axs[1, 0].plot(v2_accs, label='2-class Accuracy')
    axs[1, 0].legend()
    axs[1, 0].set_title("2-Class Accuracy Curve")

    axs[1, 1].plot(v4_accs, label='4-class Accuracy')
    axs[1, 1].legend()
    axs[1, 1].set_title("4-Class Accuracy Curve")

    axs[2, 0].plot(v10_accs, label='10-class Accuracy')
    axs[2, 0].legend()
    axs[2, 0].set_title("10-Class Accuracy Curve")

    info_text = (f"Best Model Info (from Epoch {best_epoch_info['epoch']}):\n"
                 f"Train Loss: {best_epoch_info['train_loss']:.4f}\n"
                 f"Train Accuracy: {best_epoch_info['train_accuracy']:.4f}\n"
                 f"Test Loss: {best_epoch_info['test_loss']:.4f}\n"
                 f"Test Accuracy: {best_epoch_info['test_accuracy']:.4f}\n"
                 f"v2_acc: {best_epoch_info['v2_acc']:.4f}\n"
                 f"v4_acc: {best_epoch_info['v4_acc']:.4f}\n"
                 f"v10_acc: {best_epoch_info['v10_acc']:.4f}")

    axs[2, 1].axis('off')
    axs[2, 1].text(0.5, 0.5, info_text, fontsize=10, ha='center', va='center', transform=axs[2, 1].transAxes)

    plt.tight_layout()

    plt.suptitle('pos_img_text', fontsize=16, y=1.05)
    plt.savefig('pos_img_text')
    if logger:
        logger.finish()
    return results

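# Example invocation (a sketch; the default --data_path in main() below is
# machine-specific and will need adjusting):
#   python Retrieval/ATME_retrieval.py --insubject --logger
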
def main():
    parser = argparse.ArgumentParser(description='Train EEG-Image/Text Model')

    parser.add_argument('--data_path', type=str, default="/home/ldy/Workspace/THINGS/Preprocessed_data_250Hz", help='Path to the preprocessed data')
    parser.add_argument('--project', type=str, default="train_pos_img_text_rep", help='Project name')
    parser.add_argument('--entity', type=str, default="sustech_rethinkingbci", help='Entity name')
    parser.add_argument('--name', type=str, default="lr=3e-4_img_pos_pro_eeg", help='Experiment name')
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('--epochs', type=int, default=40, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=1024, help='Batch size')
    parser.add_argument('--logger', action='store_true', help='Enable logging')
    parser.add_argument('--insubject', action='store_true', help='Train within subject')
    parser.add_argument('--encoder_type', type=str, default='ATM_E', help='EEG encoder model type')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to use for training (e.g., "cuda:0" or "cpu")')

    args = parser.parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    data_path = args.data_path
    subjects = ['sub-01', 'sub-02', 'sub-03', 'sub-04', 'sub-05', 'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-10']

    for sub in subjects:
        # Re-initialize the model for each subject. ATM_E takes num_channels and
        # sequence_length separately, so pass them as keyword arguments rather
        # than a single (63, 250) tuple.
        model = globals()[args.encoder_type](num_channels=63, sequence_length=250)
        model.to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

        print(f'Processing {sub}: number of parameters:', sum(p.numel() for p in model.parameters()))

        train_dataset = EEGDataset(
            data_path,
            subjects=[sub] if args.insubject else [],
            exclude_subject=sub if not args.insubject else None,
            train=True
        )
        test_dataset = EEGDataset(
            data_path,
            subjects=[sub] if args.insubject else [],
            exclude_subject=sub if not args.insubject else None,
            train=False
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=0,
            drop_last=True
        )
        test_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=True,
            num_workers=0,
            drop_last=True
        )

        text_features_train_all = train_dataset.text_features
        text_features_test_all = test_dataset.text_features
        img_features_train_all = train_dataset.img_features
        img_features_test_all = test_dataset.img_features

        config = vars(args)

        results = main_train_loop(
            sub,
            model,
            train_loader,
            test_loader,
            optimizer,
            device,
            text_features_train_all,
            text_features_test_all,
            img_features_train_all,
            img_features_test_all,
            config,
            logger=args.logger
        )

        # Save results to a CSV file
        current_time = datetime.datetime.now().strftime("%m-%d_%H-%M")
        results_dir = f"./outputs/{args.encoder_type}/{sub}/{current_time}"
        os.makedirs(results_dir, exist_ok=True)
        results_file = f"{results_dir}/{args.encoder_type}_{'cross_exclude_' if not args.insubject else ''}{sub}.csv"

        with open(results_file, 'w', newline='') as file:
            writer = csv.DictWriter(file, fieldnames=results[0].keys())
            writer.writeheader()
            writer.writerows(results)
        print(f'Results saved to {results_file}')


if __name__ == '__main__':
    main()