--- /dev/null
+++ b/train_sentiment_textbased.py
@@ -0,0 +1,368 @@
+import os
+import copy
+import json
+import pickle
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim import lr_scheduler
+from torch.utils.data import DataLoader, random_split
+from tqdm import tqdm
+
+from transformers import (
+    BartTokenizer, BartForSequenceClassification,
+    BertTokenizer, BertForSequenceClassification,
+    RobertaTokenizer, RobertaForSequenceClassification,
+)
+
+from data import ZuCo_dataset, SST_tenary_dataset
+from config import get_config
+
+# Function to calculate the accuracy of our predictions vs labels
+def flat_accuracy(preds, labels):
+    # preds: numpy array of logits, shape (N, 3)
+    # labels: numpy array of gold labels, shape (N,)
+    pred_flat = np.argmax(preds, axis=1).flatten()
+    labels_flat = labels.flatten()
+    return np.sum(pred_flat == labels_flat) / len(labels_flat)
+
+def flat_accuracy_top_k(preds, labels, k):
+    # preds: numpy array of logits, shape (N, num_labels)
+    # labels: numpy array of gold labels, shape (N, 1)
+    # returns the fraction of samples whose gold label is among the top-k predictions
+    topk_preds = []
+    for pred in preds:
+        topk = pred.argsort()[-k:][::-1]
+        topk_preds.append(list(topk))
+    right_count = 0
+    for i in range(len(labels)):
+        if labels[i][0] in topk_preds[i]:
+            right_count += 1
+    return right_count / len(labels)
+
+def train_model_ZuCo(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/text_sentiment_classifier/best/test.pt', checkpoint_path_last = './checkpoints/text_sentiment_classifier/last/test.pt'):
+    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
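+    # Text-based sentiment training on the ZuCo sentences: the classifier is
+    # trained on the tokenized sentence text (target_ids) and sentiment labels;
+    # the EEG features in each batch are ignored by this baseline.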
+    since = time.time()
+      
+    best_model_wts = copy.deepcopy(model.state_dict())
+    best_loss = float('inf')
+    best_acc = 0.0
+    
+
+    for epoch in range(num_epochs):
+        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
+        print('-' * 10)
+
+        # Each epoch has a training and validation phase
+        for phase in ['train', 'dev']:
+            total_accuracy = 0.0
+            if phase == 'train':
+                model.train()  # Set model to training mode
+            else:
+                model.eval()   # Set model to evaluate mode
+
+            running_loss = 0.0
+
+            # Iterate over data.
+            for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in tqdm(dataloaders[phase]):
+                
+                # EEG inputs (input_word_eeg_features, input_masks, input_mask_invert)
+                # stay on CPU and are unused; only text and labels are moved to device
+                target_ids = target_ids.to(device)
+                target_mask = target_mask.to(device)
+                sentiment_labels = sentiment_labels.to(device)
+
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                # forward; the HuggingFace model computes the cross-entropy loss
+                # internally when `labels` is passed. Track gradients only in the
+                # training phase so evaluation does not build autograd graphs.
+                with torch.set_grad_enabled(phase == 'train'):
+                    output = model(input_ids = target_ids, attention_mask = target_mask, return_dict = True, labels = sentiment_labels)
+                    logits = output.logits
+                    loss = output.loss
+
+                    # backward + optimize only in the training phase
+                    if phase == 'train':
+                        loss.backward()
+                        optimizer.step()
+
+                # calculate accuracy
+                preds_cpu = logits.detach().cpu().numpy()
+                label_cpu = sentiment_labels.cpu().numpy()
+
+                total_accuracy += flat_accuracy(preds_cpu, label_cpu)
+
+                # statistics (loss.item() is the mean loss over the batch)
+                running_loss += loss.item() * target_ids.size(0)
+
+
+            if phase == 'train':
+                scheduler.step()
+
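+            # dataset_sizes is a module-level dict defined in __main__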
+            epoch_loss = running_loss / dataset_sizes[phase]
+            epoch_acc = total_accuracy / len(dataloaders[phase])
+            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
+            print('{} Acc: {:.4f}'.format(phase, epoch_acc))
+
+            # deep copy the model
+            if phase == 'dev' and (epoch_acc > best_acc):
+                best_loss = epoch_loss
+                best_acc = epoch_acc
+                best_model_wts = copy.deepcopy(model.state_dict())
+                # save checkpoint
+                torch.save(model.state_dict(), checkpoint_path_best)
+                print(f'update best on dev checkpoint: {checkpoint_path_best}')
+        print()
+
+    time_elapsed = time.time() - since
+    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
+    print('Best val loss: {:.4f}'.format(best_loss))
+    print('Best val acc: {:.4f}'.format(best_acc))
+    torch.save(model.state_dict(), checkpoint_path_last)
+    print(f'update last checkpoint: {checkpoint_path_last}')
+
+    # write to log (output_log_file_name is defined in __main__)
+    with open(output_log_file_name, 'w') as outlog:
+        outlog.write(f'best val loss: {best_loss}\n')
+        outlog.write('best val acc: {:.4f}\n'.format(best_acc))
+    # load best model weights
+    model.load_state_dict(best_model_wts)
+    return model
+
+def train_model_SST(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, checkpoint_path_best = './checkpoints/text_sentiment_classifier/best/test.pt', checkpoint_path_last = './checkpoints/text_sentiment_classifier/last/test.pt'):
+    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
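+    # Sentiment training on the Stanford Sentiment Treebank (ternary labels),
+    # used for the zero-shot setting selected via dataset_name = 'SST' in __main__.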
+    since = time.time()
+      
+    best_model_wts = copy.deepcopy(model.state_dict())
+    best_loss = float('inf')
+    best_acc = 0.0
+    
+
+    for epoch in range(num_epochs):
+        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
+        print('-' * 10)
+
+        # Each epoch has a training and validation phase
+        for phase in ['train', 'dev']:
+            total_accuracy = 0.0
+            if phase == 'train':
+                model.train()  # Set model to training mode
+            else:
+                model.eval()   # Set model to evaluate mode
+
+            running_loss = 0.0
+
+            # Iterate over data.
+            for input_ids, input_masks, sentiment_labels in tqdm(dataloaders[phase]):
+                
+                input_ids = input_ids.to(device)
+                input_masks = input_masks.to(device)
+                sentiment_labels = sentiment_labels.to(device)
+
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                # forward; the HuggingFace model computes the cross-entropy loss
+                # internally when `labels` is passed. Track gradients only in the
+                # training phase so evaluation does not build autograd graphs.
+                with torch.set_grad_enabled(phase == 'train'):
+                    output = model(input_ids = input_ids, attention_mask = input_masks, return_dict = True, labels = sentiment_labels)
+                    logits = output.logits
+                    loss = output.loss
+
+                    # backward + optimize only in the training phase
+                    if phase == 'train':
+                        loss.backward()
+                        optimizer.step()
+
+                # calculate accuracy
+                preds_cpu = logits.detach().cpu().numpy()
+                label_cpu = sentiment_labels.cpu().numpy()
+
+                total_accuracy += flat_accuracy(preds_cpu, label_cpu)
+
+                # statistics (loss.item() is the mean loss over the batch)
+                running_loss += loss.item() * input_ids.size(0)
+
+
+            if phase == 'train':
+                scheduler.step()
+
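+            # dataset_sizes is a module-level dict defined in __main__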
+            epoch_loss = running_loss / dataset_sizes[phase]
+            epoch_acc = total_accuracy / len(dataloaders[phase])
+            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
+            print('{} Acc: {:.4f}'.format(phase, epoch_acc))
+
+            # deep copy the model
+            if phase == 'dev' and (epoch_acc > best_acc):
+                best_loss = epoch_loss
+                best_acc = epoch_acc
+                best_model_wts = copy.deepcopy(model.state_dict())
+                # save checkpoint
+                torch.save(model.state_dict(), checkpoint_path_best)
+                print(f'update best on dev checkpoint: {checkpoint_path_best}')
+        print()
+
+    time_elapsed = time.time() - since
+    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
+    print('Best val loss: {:.4f}'.format(best_loss))
+    print('Best val acc: {:.4f}'.format(best_acc))
+    torch.save(model.state_dict(), checkpoint_path_last)
+    print(f'update last checkpoint: {checkpoint_path_last}')
+    
+    # load best model weights
+    model.load_state_dict(best_model_wts)
+    return model
+
+
+if __name__ == '__main__':
+    args = get_config('train_sentiment_textbased')
+
+    ''' config param'''
+
+    num_epoch = args['num_epoch']
+    # lr = 1e-3 # Bert, RoBerta
+    # lr = 1e-4 # Bart
+    lr = args['learning_rate']
+
+    dataset_name = args['dataset_name'] # 'SST' trains on the external Stanford Sentiment Treebank (zero-shot setting); 'ZuCo' trains on ZuCo's text-sentiment pairs
+
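+    # split setting passed through to ZuCo_dataset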
+    dataset_setting = 'unique_sent'
+
+    batch_size = args['batch_size']
+    
+    # model_name = 'pretrain_Bert'
+    # model_name = 'pretrain_RoBerta'
+    # model_name = 'pretrain_Bart'
+    model_name = args['model_name']
+    print(f'[INFO]model name: {model_name}')
+
+    save_path = args['save_path'] 
+
+    if dataset_name == 'ZuCo':
+        # subject_choice = 'ALL'
+        subject_choice = args['subjects']
+        print(f'[INFO]using subjects {subject_choice}')
+        # eeg_type_choice = 'GD'
+        eeg_type_choice = args['eeg_type']
+        print(f'[INFO]eeg type {eeg_type_choice}')
+        # bands_choice = ['_t1'] 
+        # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'] 
+        bands_choice = args['eeg_bands']
+        print(f'[INFO]using bands {bands_choice}')
+        save_name = f'Textbased_ZuCo_{model_name}_b{batch_size}_{num_epoch}_{lr}_{dataset_setting}_{eeg_type_choice}'
+    elif dataset_name == 'SST':
+        save_name = f'Textbased_StanfordSentimentTreebank_{model_name}_b{batch_size}_{num_epoch}_{lr}'
+
+    output_checkpoint_name_best = save_path + f'/best/{save_name}.pt'
+    output_checkpoint_name_last = save_path + f'/last/{save_name}.pt'
+    # log file written by train_model_ZuCo (the name is an assumption; any writable path works)
+    output_log_file_name = save_path + f'/{save_name}.txt'
+    # make sure the output directories exist before torch.save is called
+    os.makedirs(save_path + '/best', exist_ok=True)
+    os.makedirs(save_path + '/last', exist_ok=True)
+
+
+    ''' set random seeds '''
+    seed_val = 312
+    np.random.seed(seed_val)
+    torch.manual_seed(seed_val)
+    torch.cuda.manual_seed_all(seed_val)
+
+    ''' set up device '''
+    # use cuda
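+    # args['cuda'] is expected to be a torch device string such as 'cuda:0'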
+    if torch.cuda.is_available():  
+        dev = args['cuda']
+    else:  
+        dev = "cpu"
+    # CUDA_VISIBLE_DEVICES=0,1,2,3  
+    device = torch.device(dev)
+    print(f'[INFO]using device {dev}')
+
+
+    ''' load pickle '''
+    if dataset_name == 'ZuCo':
+        whole_dataset_dicts = []  # a list of dataset dicts, one per ZuCo task
+        dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle'
+        with open(dataset_path_task1, 'rb') as handle:
+            whole_dataset_dicts.append(pickle.load(handle))
+    
+    '''tokenizer'''
+    if model_name == 'pretrain_Bert':
+        print('[INFO]pretrained checkpoint: bert-base-cased')
+        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+    elif model_name == 'pretrain_RoBerta':
+        print('[INFO]pretrained checkpoint: roberta-base')
+        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+    elif model_name == 'pretrain_Bart':
+        print('[INFO]pretrained checkpoint: bart-large')
+        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
+
+    ''' set up dataloader '''
+    if dataset_name == 'ZuCo':
+        # train dataset
+        train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
+        # dev dataset
+        dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
+    
+    elif dataset_name == 'SST':
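+        # preprocessed SST with ternary sentiment labels (negative / neutral / positive)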
+        with open('./dataset/stanfordsentiment/ternary_dataset.json') as f:
+            SST_SENTIMENT_LABELS = json.load(f)
+
+        SST_dataset = SST_tenary_dataset(SST_SENTIMENT_LABELS, tokenizer)  
+        
+        train_size = int(0.9 * len(SST_dataset))
+        val_size = len(SST_dataset) - train_size
+
+        train_set, dev_set = random_split(SST_dataset, [train_size, val_size])
+        print('{:>5,} training samples'.format(len(train_set)))
+        print('{:>5,} validation samples'.format(len(dev_set)))
+
+
+    dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)}
+    print('[INFO]train_set size: ', len(train_set))
+    print('[INFO]dev_set size: ', len(dev_set))
+    
+    # train dataloader
+    train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4)
+    # dev dataloader
+    val_dataloader = DataLoader(dev_set, batch_size = 1, shuffle=False, num_workers=4)
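+    # NOTE: with batch_size=1 on dev, the per-batch accuracy averaged in the train loops is exact per-sample accuracy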
+    # dataloaders
+    dataloaders = {'train':train_dataloader, 'dev':val_dataloader}
+
+    ''' set up model '''
+    if model_name == 'pretrain_Bert':
+        model = BertForSequenceClassification.from_pretrained('bert-base-cased',num_labels=3)
+    elif model_name == 'pretrain_RoBerta':
+        model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3)
+    elif model_name == 'pretrain_Bart':
+        model = BartForSequenceClassification.from_pretrained('facebook/bart-large', num_labels = 3)
+    
+    model.to(device)
+    
+
+    """save config"""
+    with open(f'./config/text_sentiment_classifier/{save_name}.json', 'w') as out_config:
+        json.dump(args, out_config, indent = 4)
+
+
+    ''' training loop '''
+    # single-step training: fine-tune the full pretrained classifier on text;
+    # no parameters are frozen in this text-based baseline
+
+    ''' set up optimizer and scheduler '''
+    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
+
+    # decay the learning rate by a factor of 10 every 10 epochs
+    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
+
+    # TODO: rethink the loss function
+    ''' set up loss function '''
+    # NOTE: criterion is passed to the train functions but never used there;
+    # the HuggingFace models compute cross-entropy internally from `labels`
+    criterion = nn.CrossEntropyLoss()
+
+    # train and return the model with the best dev accuracy
+    print(f'=== start training {dataset_name} ... ===')
+    if dataset_name == 'ZuCo':
+        model = train_model_ZuCo(dataloaders, device, model, criterion, optimizer, exp_lr_scheduler, num_epochs=num_epoch, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)
+    elif dataset_name == 'SST':
+        model = train_model_SST(dataloaders, device, model, criterion, optimizer, exp_lr_scheduler, num_epochs=num_epoch, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)
+
\ No newline at end of file