Diff of /eval_sentiment.py [000000] .. [66af30]

b/eval_sentiment.py
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pack_padded_sequence
import pickle
import json
import matplotlib.pyplot as plt
from glob import glob
import time
import copy
from tqdm import tqdm

from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BartForSequenceClassification, BertTokenizer, BertConfig, BertForSequenceClassification, RobertaTokenizer, RobertaForSequenceClassification
from data import ZuCo_dataset
from model_sentiment import BaselineMLPSentence, BaselineLSTM, FineTunePretrainedTwoStep, ZeroShotSentimentDiscovery, JointBrainTranslatorSentimentClassifier
from model_decoding import BrainTranslator, BrainTranslatorNaive
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score
from config import get_config

# Function to calculate the accuracy of our predictions vs labels
def flat_accuracy(preds, labels):
    # preds: numpy array: N * 3
    # labels: numpy array: N
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)
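
# A quick illustrative check of flat_accuracy (toy values, not part of the original
# script): np.argmax([[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]], axis=1) gives [1, 0], so
# against labels [1, 0] the function returns 2/2 = 1.0.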

def flat_accuracy_top_k(preds, labels, k):
    # preds: numpy array: N * num_classes
    # labels: numpy array: N * 1 (label id in the first column)
    topk_preds = []
    for pred in preds:
        topk = pred.argsort()[-k:][::-1]
        topk_preds.append(list(topk))
    right_count = 0
    for i in range(len(labels)):
        l = labels[i][0]
        if l in topk_preds[i]:
            right_count += 1
    return right_count / len(labels)
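
# Illustrative top-k check (toy values, not part of the original script): with k=2,
# preds=[[0.5, 0.3, 0.2]] and labels=[[1]], the two highest-scoring classes are
# [0, 1]; the true label 1 is among them, so the returned accuracy is 1.0.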

def eval_model(dataloaders, device, model, criterion, optimizer, scheduler, num_epochs=25, tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')):

    def logits2PredString(logits, tokenizer):
        probs = logits[0].softmax(dim = 1)
        values, predictions = probs.topk(1)
        predictions = torch.squeeze(predictions)
        predict_string = tokenizer.decode(predictions)
        return predict_string
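
    # Shape note (assumed from how logits2PredString is called below with BART LM
    # logits): logits[0] is (seq_len, vocab_size) for the single sample in the batch,
    # so softmax over dim=1 followed by topk(1) picks the most likely token id per
    # position, and tokenizer.decode turns those ids back into a string.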

    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 100000000000
    best_acc = 0.0

    total_pred_labels = np.array([])
    total_true_labels = np.array([])

    for epoch in range(1):  # a single pass over the test set is enough for evaluation
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch would have a training and a validation phase; here only the test phase is run
        for phase in ['test']:
            total_accuracy = 0.0
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0

            # Iterate over data.
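            # Each batch provides (contents inferred from the unpacking below, batch_size = 1):
            # word-level EEG features, their sequence lengths, attention masks and inverted
            # masks, tokenized target-sentence ids and mask, the sentiment label, and a
            # sentence-level EEG vector.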
            for input_word_eeg_features, seq_lens, input_masks, input_mask_invert, target_ids, target_mask, sentiment_labels, sent_level_EEG in dataloaders[phase]:

                input_word_eeg_features = input_word_eeg_features.to(device).float()
                input_masks = input_masks.to(device)
                input_mask_invert = input_mask_invert.to(device)

                sent_level_EEG = sent_level_EEG.to(device)
                sentiment_labels = sentiment_labels.to(device)

                target_ids = target_ids.to(device)
                target_mask = target_mask.to(device)

                ## forward ###################
                if isinstance(model, BaselineMLPSentence):
                    logits = model(sent_level_EEG) # before softmax
                    # calculate loss
                    loss = criterion(logits, sentiment_labels)

                elif isinstance(model, BaselineLSTM):
                    # pack the variable-length word-level EEG sequences so the LSTM skips padded timesteps
                    x_packed = pack_padded_sequence(input_word_eeg_features, seq_lens, batch_first=True, enforce_sorted=False)
                    logits = model(x_packed)
                    # calculate loss
                    loss = criterion(logits, sentiment_labels)

                elif isinstance(model, (BertForSequenceClassification, RobertaForSequenceClassification, BartForSequenceClassification)):
                    output = model(input_ids = target_ids, attention_mask = target_mask, return_dict = True, labels = sentiment_labels)
                    logits = output.logits
                    loss = output.loss

                elif isinstance(model, FineTunePretrainedTwoStep):
                    output = model(input_word_eeg_features, input_masks, input_mask_invert, sentiment_labels)
                    logits = output.logits
                    loss = output.loss
                elif isinstance(model, ZeroShotSentimentDiscovery):    
121
                    print()
122
                    print('target string:',tokenizer.decode(target_ids[0]).replace('<pad>','').split('</s>')[0]) 
123
124
                    """replace padding ids in target_ids with -100"""
125
                    target_ids[target_ids == tokenizer.pad_token_id] = -100 
126
127
                    output = model(input_word_eeg_features, input_masks, input_mask_invert, target_ids, sentiment_labels)
128
                    logits = output.logits
129
                    loss = output.loss
130
                
131
                elif isinstance(model, JointBrainTranslatorSentimentClassifier):
132
133
                    print()
134
                    print('target string:',tokenizer.decode(target_ids[0]).replace('<pad>','').split('</s>')[0]) 
135
136
                    """replace padding ids in target_ids with -100"""
137
                    target_ids[target_ids == tokenizer.pad_token_id] = -100 
138
139
                    LM_output, classification_output = model(input_word_eeg_features, input_masks, input_mask_invert, target_ids, sentiment_labels)
140
                    LM_logits = LM_output.logits
141
                    print('pred string:', logits2PredString(LM_logits, tokenizer).split('</s></s>')[0].replace('<s>',''))
142
                    classification_loss = classification_output['loss']
143
                    logits = classification_output['logits']
144
                    loss = classification_loss 
145
                ###############################

                # backward + optimize only in the training phase
                # (this evaluation script only runs the 'test' phase, so this branch is never taken)
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # calculate accuracy
                preds_cpu = logits.detach().cpu().numpy()
                label_cpu = sentiment_labels.cpu().numpy()

                total_accuracy += flat_accuracy(preds_cpu, label_cpu)

                # accumulate predictions and labels for calculating F1, precision, recall
                pred_flat = np.argmax(preds_cpu, axis=1).flatten()
                labels_flat = label_cpu.flatten()

                total_pred_labels = np.concatenate((total_pred_labels, pred_flat))
                total_true_labels = np.concatenate((total_true_labels, labels_flat))

                # statistics
                running_loss += loss.item() * sent_level_EEG.size()[0] # batch loss

            if phase == 'train':
                scheduler.step()

            # NOTE: dataset_sizes is a module-level dict defined under __main__ below
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = total_accuracy / len(dataloaders[phase])
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            print('{} Acc: {:.4f}'.format(phase, epoch_acc))

            # keep track of the best loss and accuracy
            if phase == 'test' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_acc = epoch_acc
        print()

    time_elapsed = time.time() - since
    print('Evaluation complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best test loss: {:.4f}'.format(best_loss))
    print('Best test acc: {:.4f}'.format(best_acc))
    print()
    print('test sample num:', len(total_pred_labels))
    print('total preds:', total_pred_labels)
    print('total truth:', total_true_labels)
    print('sklearn macro: precision, recall, F1:')
    print(precision_recall_fscore_support(total_true_labels, total_pred_labels, average='macro'))
    print()
    print('sklearn micro: precision, recall, F1:')
    print(precision_recall_fscore_support(total_true_labels, total_pred_labels, average='micro'))
    print()
    print('sklearn accuracy:')
    print(accuracy_score(total_true_labels, total_pred_labels))
    print()

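# For reference: with average='macro' or 'micro', sklearn's
# precision_recall_fscore_support returns a tuple (precision, recall, f1, None),
# and accuracy_score returns the fraction of exact matches, e.g.
# accuracy_score([0, 1, 2], [0, 2, 2]) == 2/3.
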

if __name__ == '__main__':
    args = get_config('eval_sentiment')

    ''' config param'''
    num_epochs = 1

    dataset_setting = 'unique_sent'

    '''model name'''
    # model_name = 'BaselineMLP'
    # model_name = 'BaselineLSTM'
    # model_name = 'NaiveFinetuneBert'
    # model_name = 'FinetunedBertOnText'
    # model_name = 'FinetunedRoBertaOnText'
    # model_name = 'FinetunedBartOnText'
    # model_name = 'ZeroShotSentimentDiscovery'
    model_name = args['model_name']

    print(f'[INFO] eval {model_name}')
    if model_name == 'ZeroShotSentimentDiscovery':
        '''load decoder and classifier config'''
        config_decoder = json.load(open(args['decoder_config_path']))
        config_classifier = json.load(open(args['classifier_config_path']))
        '''choose generator'''
        # decoder_name = 'BrainTranslator'
        # decoder_name = 'BrainTranslatorNaive'
        decoder_name = config_decoder['model_name']
        decoder_checkpoint = args['decoder_checkpoint_path']
        print(f'[INFO] using decoder: {decoder_name}')

        '''choose classifier'''
        # pretrain_Bert, pretrain_RoBerta, pretrain_Bart
        classifier_name = config_classifier['model_name']
        classifier_checkpoint = args['classifier_checkpoint_path']
        print(f'[INFO] using classifier: {classifier_name}')
    else:
        checkpoint_path = args['checkpoint_path']
        print('[INFO] loading baseline:', checkpoint_path)

    batch_size = 1

    # subject_choice = 'ALL'
    subject_choice = args['subjects']
    print(f'[Debug] using {subject_choice}')
    # eeg_type_choice = 'GD'
    eeg_type_choice = args['eeg_type']
    print(f'[INFO]eeg type {eeg_type_choice}')
    # bands_choice = ['_t1']
    # bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']
    bands_choice = args['eeg_bands']
    print(f'[INFO]using bands {bands_choice}')

    ''' set random seeds '''
    seed_val = 312
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    ''' set up device '''
    # use cuda
    if torch.cuda.is_available():
        dev = args['cuda']
    else:
        dev = "cpu"
    # CUDA_VISIBLE_DEVICES=0,1,2,3
    device = torch.device(dev)
    print(f'[INFO]using device {dev}')

    ''' load pickle'''
    # list of dataset dicts (one per ZuCo task); only task1-SR is used for sentiment
    whole_dataset_dict = []
    dataset_path_task1 = './dataset/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle'
    with open(dataset_path_task1, 'rb') as handle:
        whole_dataset_dict.append(pickle.load(handle))

    '''set up tokenizer'''
    if model_name in ['BaselineMLP', 'BaselineLSTM', 'NaiveFinetuneBert', 'FinetunedBertOnText']:
        print('[INFO]using Bert tokenizer')
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    elif model_name == 'FinetunedBartOnText':
        print('[INFO]using Bart tokenizer')
        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    elif model_name == 'FinetunedRoBertaOnText':
        print('[INFO]using RoBerta tokenizer')
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    elif model_name == 'ZeroShotSentimentDiscovery':
        decoder_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') # Bart
        tokenizer = decoder_tokenizer
        if classifier_name == 'pretrain_Bert':
            sentiment_tokenizer = BertTokenizer.from_pretrained('bert-base-cased') # Bert
        elif classifier_name == 'pretrain_Bart':
            sentiment_tokenizer = decoder_tokenizer
        elif classifier_name == 'pretrain_RoBerta':
            sentiment_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    ''' set up model '''
    if model_name == 'BaselineMLP':
        print('[INFO]Model: BaselineMLP')
        model = BaselineMLPSentence(input_dim = 840, hidden_dim = 128, output_dim = 3)
    elif model_name == 'BaselineLSTM':
        print('[INFO]Model: BaselineLSTM')
        # model = BaselineLSTM(input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 1)
        model = BaselineLSTM(input_dim = 840, hidden_dim = 256, output_dim = 3, num_layers = 4)
    elif model_name == 'FinetunedBertOnText':
        print('[INFO]Model: FinetunedBertOnText')
        model = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=3)
    elif model_name == 'FinetunedRoBertaOnText':
        print('[INFO]Model: FinetunedRoBertaOnText')
        model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3)
    elif model_name == 'FinetunedBartOnText':
        print('[INFO]Model: FinetunedBartOnText')
        model = BartForSequenceClassification.from_pretrained('facebook/bart-large', num_labels=3)
    elif model_name == 'ZeroShotSentimentDiscovery':
        print(f'[INFO]Model: ZeroShotSentimentDiscovery, using classifier: {classifier_name}, using generator: {decoder_name}')
        pretrained = BartForConditionalGeneration.from_pretrained('facebook/bart-large')
        # in_feature: 105 EEG features per frequency band (840 in total when all 8 bands are used)
        if decoder_name == 'BrainTranslator':
            decoder = BrainTranslator(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
        elif decoder_name == 'BrainTranslatorNaive':
            decoder = BrainTranslatorNaive(pretrained, in_feature = 105*len(bands_choice), decoder_embedding_size = 1024, additional_encoder_nhead=8, additional_encoder_dim_feedforward = 2048)
        decoder.load_state_dict(torch.load(decoder_checkpoint))

        if classifier_name == 'pretrain_Bert':
            classifier = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=3)
        elif classifier_name == 'pretrain_Bart':
            classifier = BartForSequenceClassification.from_pretrained('facebook/bart-large', num_labels=3)
        elif classifier_name == 'pretrain_RoBerta':
            classifier = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3)

        classifier.load_state_dict(torch.load(classifier_checkpoint))

        model = ZeroShotSentimentDiscovery(decoder, classifier, decoder_tokenizer, sentiment_tokenizer, device = device)
        model.to(device)
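
        # As suggested by its constructor arguments, ZeroShotSentimentDiscovery chains the
        # two stages: the decoder generates a sentence from the word-level EEG features,
        # the generated text is re-tokenized with sentiment_tokenizer, and the pretrained
        # classifier predicts its sentiment, so no EEG-to-sentiment training is needed.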

    if model_name != 'ZeroShotSentimentDiscovery':
        # load model and send to device
        model.load_state_dict(torch.load(checkpoint_path))
        model.to(device)

    ''' set up dataloader '''
    # test dataset
    test_set = ZuCo_dataset(whole_dataset_dict, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)

    dataset_sizes = {'test': len(test_set)}
    print('[INFO]test_set size: ', len(test_set))

    test_dataloader = DataLoader(test_set, batch_size = batch_size, shuffle=False, num_workers=4)
    # dataloaders
    dataloaders = {'test': test_dataloader}

    ''' set up optimizer and scheduler (unused for evaluation) '''
    optimizer_step1 = None
    exp_lr_scheduler_step1 = None

    ''' set up loss function '''
    criterion = nn.CrossEntropyLoss()

    print('=== start evaluating ... ===')
    # run evaluation on the test set (eval_model prints its metrics and returns nothing)
    eval_model(dataloaders, device, model, criterion, optimizer_step1, exp_lr_scheduler_step1, num_epochs=num_epochs, tokenizer=tokenizer)