eval_decoding.py
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import pickle
import json
import matplotlib.pyplot as plt
from glob import glob
import time
import copy
from tqdm import tqdm
import re
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BertTokenizer
from data import ZuCo_dataset
from model_decoding import BrainTranslator, BrainTranslatorNaive
from metrics import compute_metrics
from config import get_config


def eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path='./results/temp.txt',
               output_all_metrics_results_path='./results/temp.json', teacher_forcing=False, input_noise=False):
    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
    # teacher_forcing, input_noise, and the metrics output path were previously read
    # from module-level globals; they are explicit parameters here.

    model.eval()  # Set model to evaluate mode

    target_string_list = []
    pred_string_list = []
    with open(output_all_results_path, 'w') as f:
        # Iterate over the test data.
        for input_embeddings, seq_len, input_masks, input_mask_invert, target_ids, target_mask, \
                sentiment_labels, sent_level_EEG in dataloaders['test']:
            # load the batch onto the device
            input_embeddings_batch = input_embeddings.to(device).float()
            input_masks_batch = input_masks.to(device)
            target_ids_batch = target_ids.to(device)
            input_mask_invert_batch = input_mask_invert.to(device)

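            # Random-noise control (assumption about intent: a sanity-check baseline):
            # replacing the EEG embeddings with uniform random tensors of the same shape
            # shows how much of the decoding score comes from the BART decoder's
            # language prior rather than from the brain signal.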
            if input_noise:
                input_embeddings_batch = torch.rand_like(input_embeddings_batch)

            # decode the reference sentences before padding ids are overwritten below
            target_string = tokenizer.batch_decode(target_ids_batch, skip_special_tokens=True)

            # collect references for the corpus-level metrics computed after the loop
            target_string_list.extend(target_string)

            # replace padding ids in target_ids with -100
            target_ids_batch[target_ids_batch == tokenizer.pad_token_id] = -100
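            # Note: -100 is the default ignore_index of nn.CrossEntropyLoss (and of the
            # Hugging Face seq2seq loss), so the masked padding positions are simply
            # skipped when target_ids_batch is consumed as labels below.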
            if not teacher_forcing:
                predictions = model.generate(input_embeddings_batch, input_masks_batch, input_mask_invert_batch,
                                             target_ids_batch,
                                             max_length=100,
                                             num_beams=5, do_sample=False, repetition_penalty=5.0)
                # alternative decoding settings:
                # num_beams=5, encoder_no_repeat_ngram_size=1,
                # do_sample=True, top_k=15, temperature=0.5, num_return_sequences=5,
                # early_stopping=True
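                # model.generate here is the BrainTranslator wrapper from model_decoding,
                # which presumably forwards these beam-search settings (num_beams=5,
                # do_sample=False, repetition_penalty=5.0) to the underlying BART
                # generate; the extra mask arguments belong to that wrapper's API,
                # not to vanilla BART.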
            else:
                seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch,
                                        target_ids_batch)
                logits = seq2seqLMoutput.logits  # (batch_size, seq_len, vocab_size)
                probs = logits.softmax(dim=-1)
                values, predictions = probs.topk(1)
                predictions = torch.squeeze(predictions, dim=-1)
            predicted_string = tokenizer.batch_decode(predictions, skip_special_tokens=True)
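            # Two evaluation modes are exercised above: with teacher_forcing the decoder
            # is conditioned on the gold target prefix and the per-position argmax is
            # taken, which generally yields much higher scores than free-running
            # decoding; without it, beam search must produce the whole sentence from
            # the EEG encoding alone.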
            # write the prediction/reference pairs to the results file
            for str_id in range(len(target_string)):
                f.write('start################################################\n')
                f.write(f'Predicted: {predicted_string[str_id]}\n')
                f.write(f'True: {target_string[str_id]}\n')
                f.write('end################################################\n\n\n')

            pred_string_list.extend(predicted_string)

    metrics_results = compute_metrics(pred_string_list, target_string_list)
    print(f'teacher_forcing: {teacher_forcing}, input_noise: {input_noise}')
    print(metrics_results)
    print(output_all_results_path)
    print(output_all_metrics_results_path)
    with open(output_all_metrics_results_path, "w") as json_file:
        json.dump(metrics_results, json_file, indent=4, ensure_ascii=False)
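    # The metrics dict is expected to contain corpus-level WER, ROUGE-1/2/L
    # (precision / recall / F-measure) and BLEU-1..4, matching the example result
    # recorded in the comment in the __main__ block below.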


if __name__ == '__main__':
    home_directory = os.path.expanduser("~")
    ''' get args '''
    args = get_config('eval_decoding')

    ''' load training config '''
    training_config = json.load(open(args['config_path']))

    batch_size = 1

    subject_choice = training_config['subjects']
    print(f'[INFO]subjects: {subject_choice}')
    eeg_type_choice = training_config['eeg_type']
    print(f'[INFO]eeg type: {eeg_type_choice}')
    bands_choice = training_config['eeg_bands']
    print(f'[INFO]using bands: {bands_choice}')

    dataset_setting = 'unique_sent'

    task_name = training_config['task_name']

    model_name = training_config['model_name']
    # model_name = 'BrainTranslator'
    # model_name = 'BrainTranslatorNaive'
    # result previously recorded with teacher_forcing = True:
    # {'wer': 0.7980769276618958, 'rouge1_fmeasure': 23.912235260009766, 'rouge1_precision': 24.66936492919922, 'rouge1_recall': 23.318071365356445, 'rouge2_fmeasure': 6.851282119750977, 'rouge2_precision': 6.962162017822266, 'rouge2_recall': 6.751219272613525, 'rougeL_fmeasure': 22.912235260009766, 'rougeL_precision': 23.61673355102539, 'rougeL_recall': 22.36568832397461, 'rougeLsum_fmeasure': 22.912235260009766, 'rougeLsum_precision': 23.61673355102539, 'rougeLsum_recall': 22.36568832397461, 'bleu-1': 0.23883000016212463, 'bleu-2': 0.13888777792453766, 'bleu-3': 0.0, 'bleu-4': 0.0}

    # the flags arrive as the strings 'True'/'False'; compare explicitly rather than
    # eval()-ing untrusted command-line input
    teacher_forcing = args['tf'] == 'True'
    input_noise = args['noise'] == 'True'
    print(f'teacher_forcing: {teacher_forcing}, input_noise: {input_noise}')
    output_all_results_path = (f'./results/{task_name}-{model_name}'
                               f'{"-teacher_forcing" if teacher_forcing else ""}'
                               f'{"-input_noise" if input_noise else ""}'
                               '-all_decoding_results.txt')
    output_all_metrics_results_path = output_all_results_path.replace('.txt', '.json')

    ''' set random seeds '''
    seed_val = 312
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)
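    # With shuffle=False and do_sample=False the evaluation pass itself is
    # deterministic; the seeds mainly pin down the torch.rand_like tensors used
    # by the input_noise baseline.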

    ''' set up device '''
    # use cuda
    if torch.cuda.is_available():
        dev = args['cuda']
    else:
        dev = "cpu"
    # CUDA_VISIBLE_DEVICES=0,1,2,3
    device = torch.device(dev)
    print(f'[INFO]using device {dev}')

    ''' set up dataloader '''
    whole_dataset_dicts = []
    if 'task1' in task_name:
        dataset_path_task1 = os.path.join(home_directory, 'datasets/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle')
        with open(dataset_path_task1, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))
    if 'task2' in task_name:
        dataset_path_task2 = os.path.join(home_directory, 'datasets/ZuCo/task2-NR/pickle/task2-NR-dataset.pickle')
        with open(dataset_path_task2, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))
    if 'task3' in task_name:
        dataset_path_task3 = os.path.join(home_directory, 'datasets/ZuCo/task3-TSR/pickle/task3-TSR-dataset.pickle')
        with open(dataset_path_task3, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))
    if 'taskNRv2' in task_name:
        dataset_path_taskNRv2 = os.path.join(home_directory, 'datasets/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset.pickle')
        with open(dataset_path_taskNRv2, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))
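    # Each pickle is assumed to hold the preprocessed sentence dictionaries for one
    # ZuCo task; ZuCo_dataset presumably concatenates the list into a single split.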

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

    # test dataset
    test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject=subject_choice,
                            eeg_type=eeg_type_choice, bands=bands_choice, setting=dataset_setting)

    dataset_sizes = {"test_set": len(test_set)}
    print('[INFO]test_set size: ', len(test_set))

    # dataloaders
    test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)

    dataloaders = {'test': test_dataloader}

    ''' set up model '''
    checkpoint_path = args['checkpoint_path']
    pretrained_bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')

    if model_name == 'BrainTranslator':
        model = BrainTranslator(pretrained_bart, in_feature=105 * len(bands_choice), decoder_embedding_size=1024,
                                additional_encoder_nhead=8, additional_encoder_dim_feedforward=2048)
    elif model_name == 'BrainTranslatorNaive':
        model = BrainTranslatorNaive(pretrained_bart, in_feature=105 * len(bands_choice), decoder_embedding_size=1024,
                                     additional_encoder_nhead=8, additional_encoder_dim_feedforward=2048)
    else:
        raise ValueError(f'Unknown model_name: {model_name}')

    # map_location keeps CPU-only evaluation working when the checkpoint was saved on GPU
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)

    criterion = nn.CrossEntropyLoss()

    ''' eval '''
    eval_model(dataloaders, device, tokenizer, criterion, model,
               output_all_results_path=output_all_results_path,
               output_all_metrics_results_path=output_all_metrics_results_path,
               teacher_forcing=teacher_forcing, input_noise=input_noise)
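
    # Hypothetical invocation (the flag spellings depend on config.get_config and
    # are assumptions, not confirmed by this file):
    #   python eval_decoding.py -config_path ./results/task1-BrainTranslator.json \
    #       -checkpoint_path ./checkpoints/best.pt -tf False -noise False -cuda cuda:0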