--- a/train_decoding.py
+++ b/train_decoding.py
@@ -0,0 +1,373 @@
+import os
+import pickle
+import json
+import time
+import copy
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.optim import lr_scheduler
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import BertLMHeadModel, BartTokenizer, BartForConditionalGeneration, BartConfig, \
+    BertTokenizer, BertConfig
+
+from data import ZuCo_dataset
+from model_decoding import BrainTranslator, BrainTranslatorNaive
+from config import get_config
+
+
+def train_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25,
+                checkpoint_path_best='./checkpoints/decoding/best/temp_decoding.pt',
+                checkpoint_path_last='./checkpoints/decoding/last/temp_decoding.pt'):
+    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
+    os.makedirs(os.path.dirname(checkpoint_path_best), exist_ok=True)
+    since = time.time()
+
+    best_model_wts = copy.deepcopy(model.state_dict())
+    best_loss = float('inf')
+
+    for epoch in range(num_epochs):
+        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
+        print('-' * 10)
+
+        # each epoch has a training and a validation phase
+        for phase in ['train', 'dev']:
+            if phase == 'train':
+                model.train()  # set model to training mode
+            else:
+                model.eval()   # set model to evaluation mode
+
+            running_loss = 0.0
+
+            # iterate over data
+            for (input_embeddings, seq_len, input_masks, input_mask_invert,
+                 target_ids, target_mask, sentiment_labels, sent_level_EEG) in tqdm(dataloaders[phase]):
+                # load in batch
+                input_embeddings_batch = input_embeddings.to(device).float()
+                input_masks_batch = input_masks.to(device)
+                input_mask_invert_batch = input_mask_invert.to(device)
+                target_ids_batch = target_ids.to(device)
+
+                """replace padding ids in target_ids with -100"""
+                # NOTE: `tokenizer` is the module-level tokenizer created in __main__
+                target_ids_batch[target_ids_batch == tokenizer.pad_token_id] = -100
+
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                # forward
+                seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch,
+                                        target_ids_batch)
+
+                """calculate loss"""
+                # logits = seq2seqLMoutput.logits                    # (batch, seq_len, vocab), e.g. 8 x 48 x 50265
+                # logits = logits.permute(0, 2, 1)                   # (batch, vocab, seq_len)
+                # loss = criterion(logits, target_ids_batch_label)   # cross entropy on the encoded target parts only
+                # NOTE: the criterion passed in is not used here; the language-modeling loss
+                # returned by the wrapped BART/BERT model is optimized instead.
+                loss = seq2seqLMoutput.loss
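+                # When labels are supplied, Hugging Face seq2seq LMs such as
+                # BartForConditionalGeneration build the decoder inputs by shifting the
+                # labels right and compute a token-level cross entropy that ignores every
+                # position set to -100, which is why the pad ids are overwritten above.
+                # How target_ids_batch reaches the underlying LM is defined in
+                # model_decoding.py, not in this file.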
+                # """check prediction for instance 0 of each batch"""
+                # logits = seq2seqLMoutput.logits
+                # probs = logits[0].softmax(dim=1)
+                # values, predictions = probs.topk(1)
+                # predictions = torch.squeeze(predictions)
+                # print('[DEBUG]predicted tokens:', tokenizer.decode(predictions))
+
+                # backward + optimize only in the training phase
+                if phase == 'train':
+                    # with torch.autograd.detect_anomaly():
+                    loss.backward()
+                    optimizer.step()
+
+                # statistics
+                running_loss += loss.item() * input_embeddings_batch.size(0)  # sum of per-sample loss in this batch
+
+            if phase == 'train':
+                scheduler.step()
+
+            epoch_loss = running_loss / dataset_sizes[phase]
+            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
+
+            # deep copy the model
+            if phase == 'dev' and epoch_loss < best_loss:
+                best_loss = epoch_loss
+                best_model_wts = copy.deepcopy(model.state_dict())
+                '''save checkpoint'''
+                torch.save(model.state_dict(), checkpoint_path_best)
+                print(f'update best on dev checkpoint: {checkpoint_path_best}')
+        print()
+
+    time_elapsed = time.time() - since
+    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
+    print('Best val loss: {:.4f}'.format(best_loss))
+    torch.save(model.state_dict(), checkpoint_path_last)
+    print(f'update last checkpoint: {checkpoint_path_last}')
+
+    # load best model weights
+    model.load_state_dict(best_model_wts)
+    return model
+
+
+def show_require_grad_layers(model):
+    # sanity check: list every parameter that will receive gradients
+    print()
+    print(' require_grad layers:')
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            print(' ', name)
+
+
+if __name__ == '__main__':
+    home_directory = os.path.expanduser("~")
+    args = get_config('train_decoding')
+
+    ''' config param'''
+    dataset_setting = 'unique_sent'
+
+    num_epochs_step1 = args['num_epoch_step1']
+    num_epochs_step2 = args['num_epoch_step2']
+    step1_lr = args['learning_rate_step1']
+    step2_lr = args['learning_rate_step2']
+
+    batch_size = args['batch_size']
+
+    model_name = args['model_name']
+    # model_name = 'BrainTranslatorNaive'  # variant without the additional transformer encoder
+    # model_name = 'BrainTranslator'
+
+    # task_name = 'task1'
+    # task_name = 'task1_task2'
+    # task_name = 'task1_task2_task3'
+    # task_name = 'task1_task2_taskNRv2'
+    task_name = args['task_name']
+
+    save_path = args['save_path']
+    os.makedirs(save_path, exist_ok=True)
+
+    skip_step_one = args['skip_step_one']
+    load_step1_checkpoint = args['load_step1_checkpoint']
+    use_random_init = args['use_random_init']
+
+    if use_random_init and skip_step_one:
+        step2_lr = 5e-4
+
+    print(f'[INFO]using model: {model_name}')
+
+    if skip_step_one:
+        save_name = f'{task_name}_finetune_{model_name}_skipstep1_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}'
+    else:
+        save_name = f'{task_name}_finetune_{model_name}_2steptraining_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}'
+
+    if use_random_init:
+        save_name = 'randinit_' + save_name
+
+    save_path_best = os.path.join(save_path, 'best')
+    os.makedirs(save_path_best, exist_ok=True)
+    output_checkpoint_name_best = os.path.join(save_path_best, f'{save_name}.pt')
+
+    save_path_last = os.path.join(save_path, 'last')
+    os.makedirs(save_path_last, exist_ok=True)
+    output_checkpoint_name_last = os.path.join(save_path_last, f'{save_name}.pt')
+
+    # subject_choice = 'ALL'
+    subject_choice = args['subjects']
+    print(f'[INFO]using subjects: {subject_choice}')
+    # eeg_type_choice = 'GD'
+    eeg_type_choice = args['eeg_type']
+    print(f'[INFO]eeg type {eeg_type_choice}')
+    # bands_choice = ['_t1']
+    # bands_choice = ['_t1', '_t2', '_a1', '_a2', '_b1', '_b2', '_g1', '_g2']
+    bands_choice = args['eeg_bands']
+    print(f'[INFO]using bands {bands_choice}')
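+
+    # The feature dimension used when the model is built below is
+    # in_feature = 105 * len(bands_choice); the factor of 105 is the number of EEG
+    # channels kept per frequency band in the preprocessed ZuCo features, so the
+    # default 8 bands listed above give an 840-dimensional word-level input.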
+
+    ''' set random seeds '''
+    seed_val = 312
+    np.random.seed(seed_val)
+    torch.manual_seed(seed_val)
+    torch.cuda.manual_seed_all(seed_val)
+
+    ''' set up device '''
+    # use cuda if available
+    if torch.cuda.is_available():
+        # dev = "cuda:3"
+        dev = args['cuda']
+    else:
+        dev = "cpu"
+    # CUDA_VISIBLE_DEVICES=0,1,2,3
+    device = torch.device(dev)
+    print(f'[INFO]using device {dev}')
+    print()
+
+    ''' set up dataloader '''
+    whole_dataset_dicts = []
+    if 'task1' in task_name:
+        dataset_path_task1 = os.path.join(home_directory, 'datasets/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle')
+        with open(dataset_path_task1, 'rb') as handle:
+            whole_dataset_dicts.append(pickle.load(handle))
+    if 'task2' in task_name:
+        dataset_path_task2 = os.path.join(home_directory, 'datasets/ZuCo/task2-NR/pickle/task2-NR-dataset.pickle')
+        with open(dataset_path_task2, 'rb') as handle:
+            whole_dataset_dicts.append(pickle.load(handle))
+    if 'task3' in task_name:
+        dataset_path_task3 = os.path.join(home_directory, 'datasets/ZuCo/task3-TSR/pickle/task3-TSR-dataset.pickle')
+        with open(dataset_path_task3, 'rb') as handle:
+            whole_dataset_dicts.append(pickle.load(handle))
+    if 'taskNRv2' in task_name:
+        dataset_path_taskNRv2 = os.path.join(home_directory, 'datasets/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset.pickle')
+        with open(dataset_path_taskNRv2, 'rb') as handle:
+            whole_dataset_dicts.append(pickle.load(handle))
+    print()
+
+    """save config"""
+    cfg_dir = './config/decoding/'
+    os.makedirs(cfg_dir, exist_ok=True)
+    with open(os.path.join(cfg_dir, f'{save_name}.json'), 'w') as out_config:
+        json.dump(args, out_config, indent=4)
+
+    ''' set up tokenizer '''
+    if model_name in ['BrainTranslator', 'BrainTranslatorNaive']:
+        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
+    elif model_name == 'BertGeneration':
+        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+        config = BertConfig.from_pretrained("bert-base-cased")
+        config.is_decoder = True
+
+    # train dataset
+    train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject=subject_choice,
+                             eeg_type=eeg_type_choice, bands=bands_choice, setting=dataset_setting)
+    # dev dataset
+    dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject=subject_choice,
+                           eeg_type=eeg_type_choice, bands=bands_choice, setting=dataset_setting)
+    # test dataset (not needed during training)
+    # test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject=subject_choice,
+    #                         eeg_type=eeg_type_choice, bands=bands_choice, setting=dataset_setting)
+
+    dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)}
+    print('[INFO]train_set size: ', len(train_set))
+    print('[INFO]dev_set size: ', len(dev_set))
+
+    # train dataloader
+    train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
+    # dev dataloader
+    val_dataloader = DataLoader(dev_set, batch_size=1, shuffle=False, num_workers=4)
+    # dataloaders
+    dataloaders = {'train': train_dataloader, 'dev': val_dataloader}
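+
+    # Optional sanity check (commented out): pull one batch and confirm the tensor layout
+    # before building the model. The shape noted below is an assumption based on
+    # in_feature = 105 * len(bands_choice); the padded sequence length is fixed inside
+    # ZuCo_dataset in data.py.
+    # (input_embeddings, seq_len, input_masks, input_mask_invert,
+    #  target_ids, target_mask, sentiment_labels, sent_level_EEG) = next(iter(train_dataloader))
+    # print(input_embeddings.shape)  # expected (batch_size, padded_seq_len, 105 * len(bands_choice))
+    # print(target_ids.shape)        # tokenized target sentence ids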
+
+    ''' set up model '''
+    if model_name == 'BrainTranslator':
+        if use_random_init:
+            # randomly initialized BART (architecture only, no pretrained weights)
+            config = BartConfig.from_pretrained('facebook/bart-large')
+            pretrained = BartForConditionalGeneration(config)
+        else:
+            pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
+        model = BrainTranslator(pretrained, in_feature=105 * len(bands_choice), decoder_embedding_size=1024,
+                                additional_encoder_nhead=8, additional_encoder_dim_feedforward=2048)
+    elif model_name == 'BertGeneration':
+        pretrained = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
+        model = BrainTranslator(pretrained, in_feature=105 * len(bands_choice), decoder_embedding_size=768,
+                                additional_encoder_nhead=8, additional_encoder_dim_feedforward=2048)
+    elif model_name == 'BrainTranslatorNaive':
+        pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
+        model = BrainTranslatorNaive(pretrained, in_feature=105 * len(bands_choice), decoder_embedding_size=1024,
+                                     additional_encoder_nhead=8, additional_encoder_dim_feedforward=2048)
+
+    model.to(device)
+
+    ''' training loop '''
+
+    ######################################################
+    '''step one training: freeze most of BART params'''
+    ######################################################
+
+    # closely follow the BART paper: in step one, only the (positional) embeddings and the
+    # first encoder layer of the pretrained model stay trainable
+    if model_name in ['BrainTranslator', 'BrainTranslatorNaive']:
+        for name, param in model.named_parameters():
+            if param.requires_grad and 'pretrained' in name:
+                if ('shared' in name) or ('embed_positions' in name) or ('encoder.layers.0' in name):
+                    continue
+                else:
+                    param.requires_grad = False
+    elif model_name == 'BertGeneration':
+        for name, param in model.named_parameters():
+            if param.requires_grad and 'pretrained' in name:
+                if ('embeddings' in name) or ('encoder.layer.0' in name):
+                    continue
+                else:
+                    param.requires_grad = False
+
+    if skip_step_one:
+        if load_step1_checkpoint:
+            stepone_checkpoint = 'path_to_step_1_checkpoint.pt'
+            print(f'skip step one, load checkpoint: {stepone_checkpoint}')
+            model.load_state_dict(torch.load(stepone_checkpoint, map_location=device))
+        else:
+            print('skip step one, start from scratch at step two')
+    else:
+        ''' set up optimizer and scheduler'''
+        optimizer_step1 = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=step1_lr, momentum=0.9)
+        exp_lr_scheduler_step1 = lr_scheduler.StepLR(optimizer_step1, step_size=20, gamma=0.1)
+
+        ''' set up loss function '''
+        # NOTE: kept for the train_model interface; the LM loss inside the model is what is optimized
+        criterion = nn.CrossEntropyLoss()
+
+        print('=== start Step1 training ... ===')
+        # print trainable layers
+        show_require_grad_layers(model)
+        # return the best-on-dev model from step one training
+        model = train_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1,
+                            num_epochs=num_epochs_step1, checkpoint_path_best=output_checkpoint_name_best,
+                            checkpoint_path_last=output_checkpoint_name_last)
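+
+    # A quick, commented-out way to verify the freeze/unfreeze logic is to count trainable
+    # parameters here and again after the unfreeze below; `n_trainable` / `n_total` are
+    # just illustrative local names:
+    # n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    # n_total = sum(p.numel() for p in model.parameters())
+    # print(f'[INFO]trainable params: {n_trainable} / {n_total}')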
+
+    ######################################################
+    '''step two training: update the whole model for a few iterations'''
+    ######################################################
+    for param in model.parameters():
+        param.requires_grad = True
+
+    ''' set up optimizer and scheduler'''
+    optimizer_step2 = optim.SGD(model.parameters(), lr=step2_lr, momentum=0.9)
+    exp_lr_scheduler_step2 = lr_scheduler.StepLR(optimizer_step2, step_size=30, gamma=0.1)
+
+    ''' set up loss function '''
+    criterion = nn.CrossEntropyLoss()
+
+    print()
+    print('=== start Step2 training ... ===')
+    # print trainable layers
+    show_require_grad_layers(model)
+
+    '''main loop'''
+    trained_model = train_model(dataloaders, device, model, criterion, optimizer_step2, exp_lr_scheduler_step2,
+                                num_epochs=num_epochs_step2, checkpoint_path_best=output_checkpoint_name_best,
+                                checkpoint_path_last=output_checkpoint_name_last)
+
+    # '''save checkpoint'''
+    # torch.save(trained_model.state_dict(), os.path.join(save_path, output_checkpoint_name))
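+
+    # Example of inspecting the trained decoder (commented out, and only a sketch): reload
+    # the best checkpoint and greedily decode one dev batch via argmax over the LM logits.
+    # This assumes the model's forward returns a Hugging Face-style output with `.logits`,
+    # as the debug code in train_model does; model_decoding.py defines the actual interface.
+    # model.load_state_dict(torch.load(output_checkpoint_name_best, map_location=device))
+    # model.eval()
+    # with torch.no_grad():
+    #     (input_embeddings, seq_len, input_masks, input_mask_invert,
+    #      target_ids, target_mask, sentiment_labels, sent_level_EEG) = next(iter(val_dataloader))
+    #     out = model(input_embeddings.to(device).float(), input_masks.to(device),
+    #                 input_mask_invert.to(device), target_ids.to(device))
+    #     pred_ids = out.logits.argmax(dim=-1)
+    #     print('[DEBUG]target   :', tokenizer.decode(target_ids[0], skip_special_tokens=True))
+    #     print('[DEBUG]predicted:', tokenizer.decode(pred_ids[0], skip_special_tokens=True))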