Diff of /train_decoding.py [000000] .. [66af30]

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import pickle
import json
import matplotlib.pyplot as plt
from glob import glob
import time
import copy
from tqdm import tqdm
from transformers import BertLMHeadModel, BartTokenizer, BartForConditionalGeneration, BartConfig, \
    BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, \
    RobertaForSequenceClassification

from data import ZuCo_dataset
from model_decoding import BrainTranslator, BrainTranslatorNaive
from config import get_config


def train_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25,
                checkpoint_path_best='./checkpoints/decoding/best/temp_decoding.pt',
                checkpoint_path_last='./checkpoints/decoding/last/temp_decoding.pt'):
    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
    # NOTE: relies on the module-level `tokenizer` and `dataset_sizes` defined under __main__
    os.makedirs(os.path.dirname(checkpoint_path_best), exist_ok=True)
    os.makedirs(os.path.dirname(checkpoint_path_last), exist_ok=True)
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 100000000000

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'dev']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode (no torch.no_grad() here, so gradients are still tracked)

            running_loss = 0.0

            # Iterate over data.
            for (input_embeddings, seq_len, input_masks, input_mask_invert,
                 target_ids, target_mask, sentiment_labels, sent_level_EEG) in tqdm(dataloaders[phase]):
                # print(input_embeddings, seq_len, input_masks, input_mask_invert,
                #  target_ids, target_mask, sentiment_labels, sent_level_EEG)
                # load in batch
                input_embeddings_batch = input_embeddings.to(device).float()
                input_masks_batch = input_masks.to(device)
                input_mask_invert_batch = input_mask_invert.to(device)
                target_ids_batch = target_ids.to(device)
                """replace padding ids in target_ids with -100"""
                target_ids_batch[target_ids_batch == tokenizer.pad_token_id] = -100
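                # NOTE: -100 is the default ignore_index of torch.nn.CrossEntropyLoss, which the
                # Hugging Face seq2seq models use for their internal LM loss, so padded target
                # positions are excluded from the loss used below.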

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch,
                                        target_ids_batch)

                """calculate loss"""
                # logits = seq2seqLMoutput.logits # 8*48*50265
                # logits = logits.permute(0,2,1) # 8*50265*48

                # loss = criterion(logits, target_ids_batch_label) # calculate cross entropy loss only on encoded target parts
                # NOTE: the criterion argument is not used here; we rely on the model's built-in LM loss
                loss = seq2seqLMoutput.loss  # use the BART language modeling loss

                # """check prediction, instance 0 of each batch"""
                # print('target size:', target_ids_batch.size(), ',original logits size:', logits.size(), ',target_mask size', target_mask_batch.size())
                # logits = logits.permute(0,2,1)
                # for idx in [0]:
                #     print(f'-- instance {idx} --')
                #     # print('permuted logits size:', logits.size())
                #     probs = logits[idx].softmax(dim = 1)
                #     # print('probs size:', probs.size())
                #     values, predictions = probs.topk(1)
                #     # print('predictions before squeeze:',predictions.size())
                #     predictions = torch.squeeze(predictions)
                #     # print('predictions:',predictions)
                #     # print('target mask:', target_mask_batch[idx])
                #     # print('[DEBUG]target tokens:',tokenizer.decode(target_ids_batch_copy[idx]))
                #     print('[DEBUG]predicted tokens:',tokenizer.decode(predictions))

                # backward + optimize only if in training phase
                if phase == 'train':
                    # with torch.autograd.detect_anomaly():
                    loss.backward()
                    optimizer.step()

                # statistics
                running_loss += loss.item() * input_embeddings_batch.size()[0]  # batch loss
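                # running_loss accumulates a batch-size-weighted sum of the mean per-batch loss;
                # dividing by dataset_sizes[phase] below yields the average loss per sample.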
                # print('[DEBUG]loss:',loss.item())
                # print('#################################')

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]

            print('{} Loss: {:.4f}'.format(phase, epoch_loss))

            # deep copy the model
            if phase == 'dev' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())
                '''save checkpoint'''
                torch.save(model.state_dict(), checkpoint_path_best)
                print(f'update best on dev checkpoint: {checkpoint_path_best}')
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:.4f}'.format(best_loss))
    torch.save(model.state_dict(), checkpoint_path_last)
    print(f'update last checkpoint: {checkpoint_path_last}')

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


def show_require_grad_layers(model):
    print()
    print(' require_grad layers:')
    # sanity check
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(' ', name)

if __name__ == '__main__':
    home_directory = os.path.expanduser("~")
    args = get_config('train_decoding')

    ''' config param'''
    dataset_setting = 'unique_sent'

    num_epochs_step1 = args['num_epoch_step1']
    num_epochs_step2 = args['num_epoch_step2']
    step1_lr = args['learning_rate_step1']
    step2_lr = args['learning_rate_step2']

    batch_size = args['batch_size']

    model_name = args['model_name']
    # model_name = 'BrainTranslatorNaive' # with no additional transformers
    # model_name = 'BrainTranslator'

    # task_name = 'task1'
    # task_name = 'task1_task2'
    # task_name = 'task1_task2_task3'
    # task_name = 'task1_task2_taskNRv2'
    task_name = args['task_name']

    save_path = args['save_path']
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    skip_step_one = args['skip_step_one']
    load_step1_checkpoint = args['load_step1_checkpoint']
    use_random_init = args['use_random_init']

    if use_random_init and skip_step_one:
        step2_lr = 5*1e-4

    print(f'[INFO]using model: {model_name}')

    if skip_step_one:
        save_name = f'{task_name}_finetune_{model_name}_skipstep1_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}'
    else:
        save_name = f'{task_name}_finetune_{model_name}_2steptraining_b{batch_size}_{num_epochs_step1}_{num_epochs_step2}_{step1_lr}_{step2_lr}_{dataset_setting}'

    if use_random_init:
        save_name = 'randinit_' + save_name

    save_path_best = os.path.join(save_path, 'best')
    if not os.path.exists(save_path_best):
        os.makedirs(save_path_best)

    output_checkpoint_name_best = os.path.join(save_path_best, f'{save_name}.pt')

    save_path_last = os.path.join(save_path, 'last')
    if not os.path.exists(save_path_last):
        os.makedirs(save_path_last)

    output_checkpoint_name_last = os.path.join(save_path_last, f'{save_name}.pt')

    # subject_choice = 'ALL'
    subject_choice = args['subjects']
    print(f'[INFO]using subjects {subject_choice}')
    # eeg_type_choice = 'GD'
    eeg_type_choice = args['eeg_type']
    print(f'[INFO]eeg type {eeg_type_choice}')
    # bands_choice = ['_t1']
    # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']
    bands_choice = args['eeg_bands']
    print(f'[INFO]using bands {bands_choice}')

    ''' set random seeds '''
    seed_val = 312
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
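    # NOTE: this seeds NumPy and PyTorch (CPU and all GPUs) but not Python's built-in `random`
    # module, and cuDNN determinism flags are not set, so runs are not fully deterministic.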

    ''' set up device '''
    # use cuda
    if torch.cuda.is_available():
        # dev = "cuda:3"
        dev = args['cuda']
    else:
        dev = "cpu"
    # CUDA_VISIBLE_DEVICES=0,1,2,3
    device = torch.device(dev)
    print(f'[INFO]using device {dev}')
    print()

    ''' set up dataloader '''
    whole_dataset_dicts = []
    if 'task1' in task_name:
        dataset_path_task1 = 'datasets/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle'
        dataset_path_task1 = os.path.join(home_directory, dataset_path_task1)
        with open(dataset_path_task1, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))
    if 'task2' in task_name:
        dataset_path_task2 = 'datasets/ZuCo/task2-NR/pickle/task2-NR-dataset.pickle'
        dataset_path_task2 = os.path.join(home_directory, dataset_path_task2)
        with open(dataset_path_task2, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))
    if 'task3' in task_name:
        dataset_path_task3 = 'datasets/ZuCo/task3-TSR/pickle/task3-TSR-dataset.pickle'
        dataset_path_task3 = os.path.join(home_directory, dataset_path_task3)
        with open(dataset_path_task3, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))
    if 'taskNRv2' in task_name:
        dataset_path_taskNRv2 = 'datasets/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset.pickle'
        dataset_path_taskNRv2 = os.path.join(home_directory, dataset_path_taskNRv2)
        with open(dataset_path_taskNRv2, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))
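    # whole_dataset_dicts now holds one pickled dict per selected task; it is passed as a list to
    # ZuCo_dataset below (see data.py), which presumably merges all tasks when building each split.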

    print()

    """save config"""
    cfg_dir = './config/decoding/'

    if not os.path.exists(cfg_dir):
        os.makedirs(cfg_dir)

    with open(os.path.join(cfg_dir,f'{save_name}.json'), 'w') as out_config:
        json.dump(args, out_config, indent = 4)

    if model_name in ['BrainTranslator','BrainTranslatorNaive']:
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    elif model_name == 'BertGeneration':
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        config = BertConfig.from_pretrained("bert-base-cased")
        config.is_decoder = True

    # train dataset
    train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
    # dev dataset
    dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
    # test dataset
    # test_set = ZuCo_dataset(whole_dataset_dict, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice)

    dataset_sizes = {'train': len(train_set), 'dev': len(dev_set)}
    print('[INFO]train_set size: ', len(train_set))
    print('[INFO]dev_set size: ', len(dev_set))

    # train dataloader
    train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers=4)
    # dev dataloader
    val_dataloader = DataLoader(dev_set, batch_size = 1, shuffle=False, num_workers=4)
    # dataloaders
    dataloaders = {'train':train_dataloader, 'dev':val_dataloader}
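    # NOTE: the dev loader uses batch_size=1, so the 'dev' epoch loss reported by train_model is
    # simply the mean per-sentence loss; only the train loader shuffles.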

    ''' set up model '''
    if model_name == 'BrainTranslator':
        if use_random_init:
            config = BartConfig.from_pretrained('facebook/bart-large')
            pretrained = BartForConditionalGeneration(config)
        else:
            pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large')

        model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)

    elif model_name == 'BertGeneration':
        pretrained = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
        model = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 768, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
    elif model_name == 'BrainTranslatorNaive':
        pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
        model = BrainTranslatorNaive(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)

    model.to(device)
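    # in_feature = 105 * len(bands_choice): each word-level EEG feature is (presumably) a
    # 105-dimensional vector per frequency band, concatenated across the selected bands.
    # decoder_embedding_size matches the hidden size of the pretrained LM (1024 for
    # facebook/bart-large, 768 for bert-base-cased).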

    ''' training loop '''

    ######################################################
    '''step one training: freeze most of BART params'''
    ######################################################

    # closely follow BART paper
    if model_name in ['BrainTranslator','BrainTranslatorNaive']:
        for name, param in model.named_parameters():
            if param.requires_grad and 'pretrained' in name:
                if ('shared' in name) or ('embed_positions' in name) or ('encoder.layers.0' in name):
                    continue
                else:
                    param.requires_grad = False
    elif model_name == 'BertGeneration':
        for name, param in model.named_parameters():
            if param.requires_grad and 'pretrained' in name:
                if ('embeddings' in name) or ('encoder.layer.0' in name):
                    continue
                else:
                    param.requires_grad = False
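    # After this block, only parameters whose name does not contain 'pretrained' remain trainable,
    # plus (inside the pretrained model) the shared/positional embeddings and the first encoder
    # layer; in other words, step one mainly trains the EEG-side input modules.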

    if skip_step_one:
        if load_step1_checkpoint:
            stepone_checkpoint = 'path_to_step_1_checkpoint.pt'
            print(f'skip step one, load checkpoint: {stepone_checkpoint}')
            model.load_state_dict(torch.load(stepone_checkpoint))
        else:
            print('skip step one, start from scratch at step two')
    else:

        ''' set up optimizer and scheduler'''
        optimizer_step1 = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=step1_lr, momentum=0.9)

        exp_lr_scheduler_step1 = lr_scheduler.StepLR(optimizer_step1, step_size=20, gamma=0.1)
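        # StepLR multiplies the learning rate by gamma=0.1 every step_size=20 epochs
        # (train_model calls scheduler.step() once per training epoch).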

        ''' set up loss function '''
        criterion = nn.CrossEntropyLoss()

        print('=== start Step1 training ... ===')
        # print training layers
        show_require_grad_layers(model)
        # return best loss model from step1 training
        model = train_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs_step1, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)

    ######################################################
    '''step two training: update whole model for a few iterations'''
    ######################################################
    for name, param in model.named_parameters():
        param.requires_grad = True

    ''' set up optimizer and scheduler'''
    optimizer_step2 = optim.SGD(model.parameters(), lr=step2_lr, momentum=0.9)

    exp_lr_scheduler_step2 = lr_scheduler.StepLR(optimizer_step2, step_size=30, gamma=0.1)

    ''' set up loss function '''
    criterion = nn.CrossEntropyLoss()

    print()
    print('=== start Step2 training ... ===')
    # print training layers
    show_require_grad_layers(model)

    '''main loop'''
    trained_model = train_model(dataloaders, device, model, criterion, optimizer_step2, exp_lr_scheduler_step2, num_epochs=num_epochs_step2, checkpoint_path_best = output_checkpoint_name_best, checkpoint_path_last = output_checkpoint_name_last)

    # '''save checkpoint'''
    # torch.save(trained_model.state_dict(), os.path.join(save_path,output_checkpoint_name))