|
a |
|
b/eval_decoding.py |
|
|
1 |
import os |
|
|
2 |
import numpy as np |
|
|
3 |
import torch |
|
|
4 |
import torch.nn as nn |
|
|
5 |
import torch.optim as optim |
|
|
6 |
from torch.optim import lr_scheduler |
|
|
7 |
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler |
|
|
8 |
import pickle |
|
|
9 |
import json |
|
|
10 |
import matplotlib.pyplot as plt |
|
|
11 |
from glob import glob |
|
|
12 |
import time |
|
|
13 |
import copy |
|
|
14 |
from tqdm import tqdm |
|
|
15 |
import re |
|
|
16 |
from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig, BertTokenizer |
|
|
17 |
from data import ZuCo_dataset |
|
|
18 |
from model_decoding import BrainTranslator, BrainTranslatorNaive |
|
|
19 |
from metrics import compute_metrics |
|
|
20 |
from config import get_config |
|
|
21 |
|
|
|
22 |
|
|
|
23 |
def eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path='./results/temp.txt',
               teacher_forcing=None, intput_noise=None, output_all_metrics_results_path=None):
    """Decode every batch of the test split, dump the texts and compute metrics.

    For each batch the target ids are decoded to reference strings, the model
    produces predictions (free-running generation or a single teacher-forced
    forward pass), and each predicted/reference pair is appended to
    ``output_all_results_path``. After the loop ``compute_metrics`` is called on
    all predicted and reference strings and its result is printed and written
    as JSON.

    Args:
        dataloaders: mapping with a ``'test'`` entry; each batch is an 8-tuple
            ``(input_embeddings, seq_len, input_masks, input_mask_invert,
            target_ids, target_mask, sentiment_labels, sent_level_EEG)``.
        device: device the batch tensors are moved to.
        tokenizer: tokenizer providing ``batch_decode`` and ``pad_token_id``.
        criterion: unused; kept so existing callers keep working.
        model: model exposing ``__call__`` (returning an object with
            ``.logits``) and ``generate``.
            NOTE(review): ``generate`` is called with the project-specific
            positional signature (embeddings, masks, inverted masks, target
            ids) -- confirm against ``model_decoding``.
        output_all_results_path: text file receiving every prediction/target
            pair.
        teacher_forcing: if true, take the arg-max of a teacher-forced forward
            pass instead of calling ``generate``. ``None`` falls back to the
            module-level ``teacher_forcing`` set by the script entry point
            (``False`` when that does not exist).
        intput_noise: if true, replace the input embeddings with uniform random
            noise of the same shape. ``None`` falls back to the module-level
            ``intput_noise`` (``False`` when that does not exist).
        output_all_metrics_results_path: JSON file receiving the metrics.
            ``None`` falls back to the module-level value of the same name and,
            failing that, to ``output_all_results_path`` with a ``.json``
            extension.
    """
    # modified from: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

    # Historically these switches were read straight from module globals that
    # only exist when the file runs as a script; keep that as the fallback so
    # the original call signature behaves exactly as before.
    script_globals = globals()
    if teacher_forcing is None:
        teacher_forcing = script_globals.get('teacher_forcing', False)
    if intput_noise is None:
        intput_noise = script_globals.get('intput_noise', False)
    if output_all_metrics_results_path is None:
        output_all_metrics_results_path = script_globals.get('output_all_metrics_results_path')
    if output_all_metrics_results_path is None:
        output_all_metrics_results_path = os.path.splitext(output_all_results_path)[0] + '.json'

    model.eval()  # Set model to evaluate mode

    target_string_list = []
    pred_string_list = []

    # no_grad: pure inference, so do not build autograd graphs for each batch.
    # utf-8: decoded text may contain non-ASCII characters.
    with open(output_all_results_path, 'w', encoding='utf-8') as f, torch.no_grad():
        for (input_embeddings, _seq_len, input_masks, input_mask_invert,
             target_ids, _target_mask, _sentiment_labels, _sent_level_EEG) in dataloaders['test']:
            # load in batch
            input_embeddings_batch = input_embeddings.to(device).float()
            input_masks_batch = input_masks.to(device)
            target_ids_batch = target_ids.to(device)
            input_mask_invert_batch = input_mask_invert.to(device)

            if intput_noise:
                # Noise baseline: discard the real features entirely.
                input_embeddings_batch = torch.rand_like(input_embeddings_batch)

            # Decode the references BEFORE padding ids are overwritten below.
            target_string = tokenizer.batch_decode(target_ids_batch, skip_special_tokens=True)
            target_string_list.extend(target_string)

            # replace padding ids in target_ids with -100
            target_ids_batch[target_ids_batch == tokenizer.pad_token_id] = -100

            if not teacher_forcing:
                predictions = model.generate(
                    input_embeddings_batch, input_masks_batch, input_mask_invert_batch,
                    target_ids_batch,
                    max_length=100,
                    num_beams=5, do_sample=False, repetition_penalty=5.0,
                )
            else:
                seq2seqLMoutput = model(input_embeddings_batch, input_masks_batch, input_mask_invert_batch,
                                        target_ids_batch)
                logits = seq2seqLMoutput.logits  # bs*seq_len*voc_sz
                probs = logits.softmax(dim=-1)
                _values, predictions = probs.topk(1)
                predictions = torch.squeeze(predictions, dim=-1)

            predicted_string = tokenizer.batch_decode(predictions, skip_special_tokens=True)

            for str_id, true_text in enumerate(target_string):
                f.write('start################################################\n')
                f.write(f'Predicted: {predicted_string[str_id]}\n')
                f.write(f'True: {true_text}\n')
                f.write('end################################################\n\n\n')

            pred_string_list.extend(predicted_string)

    metrics_results = compute_metrics(pred_string_list, target_string_list)
    print(f'teacher_forcing{teacher_forcing} intput_noise{intput_noise}')
    print(metrics_results)
    print(output_all_results_path)
    print(output_all_metrics_results_path)
    with open(output_all_metrics_results_path, "w", encoding='utf-8') as json_file:
        json.dump(metrics_results, json_file, indent=4, ensure_ascii=False)
|
|
126 |
|
|
|
127 |
|
|
|
128 |
if __name__ == '__main__':
    # Script entry point: load the training config, build the ZuCo test
    # dataloader, restore the checkpointed model and run eval_model on it.
    home_directory = os.path.expanduser("~")

    # get args
    args = get_config('eval_decoding')

    # load training config (context manager so the file handle is closed)
    with open(args['config_path']) as config_file:
        training_config = json.load(config_file)

    batch_size = 1

    subject_choice = training_config['subjects']
    print(f'[INFO]subjects: {subject_choice}')
    eeg_type_choice = training_config['eeg_type']
    print(f'[INFO]eeg type: {eeg_type_choice}')
    bands_choice = training_config['eeg_bands']
    print(f'[INFO]using bands: {bands_choice}')

    dataset_setting = 'unique_sent'

    task_name = training_config['task_name']

    model_name = training_config['model_name']
    # Previously recorded result, kept for reference:
    # {'wer': 0.7980769276618958, 'rouge1_fmeasure': 23.912235260009766, 'rouge1_precision': 24.66936492919922, 'rouge1_recall': 23.318071365356445, 'rouge2_fmeasure': 6.851282119750977, 'rouge2_precision': 6.962162017822266, 'rouge2_recall': 6.751219272613525, 'rougeL_fmeasure': 22.912235260009766, 'rougeL_precision': 23.61673355102539, 'rougeL_recall': 22.36568832397461, 'rougeLsum_fmeasure': 22.912235260009766, 'rougeLsum_precision': 23.61673355102539, 'rougeLsum_recall': 22.36568832397461, 'bleu-1': 0.23883000016212463, 'bleu-2': 0.13888777792453766, 'bleu-3': 0.0, 'bleu-4': 0.0}

    # NOTE(review): eval() executes arbitrary expressions taken from the
    # command line. Presumably only 'True'/'False' are ever passed; consider
    # an explicit boolean parser once the accepted spellings are confirmed.
    # These three names are module globals that eval_model falls back on.
    teacher_forcing = eval(args['tf'])
    intput_noise = eval(args['noise'])
    print(f'teacher_forcing{teacher_forcing} intput_noise{intput_noise}')
    output_all_results_path = (f'./results/{task_name}-{model_name}{"-teacher_forcing" if teacher_forcing else ""}{"-intput_noise" if intput_noise else ""}-all_decoding_results.txt')
    # Swap only the extension; str.replace('txt', 'json') would also rewrite
    # any 'txt' occurring inside task_name or model_name.
    output_all_metrics_results_path = os.path.splitext(output_all_results_path)[0] + '.json'
    # Make sure the output directory exists before eval_model opens files in it.
    os.makedirs(os.path.dirname(output_all_results_path), exist_ok=True)

    # set random seeds
    seed_val = 312
    np.random.seed(seed_val)
    torch.manual_seed(seed_val)
    torch.cuda.manual_seed_all(seed_val)

    # set up device: use the configured CUDA device when available, else CPU
    dev = args['cuda'] if torch.cuda.is_available() else "cpu"
    # CUDA_VISIBLE_DEVICES=0,1,2,3
    device = torch.device(dev)
    print(f'[INFO]using device {dev}')

    # set up dataloader
    # (substring of task_name, pickle path relative to the home directory);
    # order matters because ZuCo_dataset receives the dicts as a list.
    task_pickles = (
        ('task1', 'datasets/ZuCo/task1-SR/pickle/task1-SR-dataset.pickle'),
        ('task2', 'datasets/ZuCo/task2-NR/pickle/task2-NR-dataset.pickle'),
        ('task3', 'datasets/ZuCo/task3-TSR/pickle/task3-TSR-dataset.pickle'),
        ('taskNRv2', 'datasets/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset.pickle'),
    )
    whole_dataset_dicts = []
    for task_key, relative_path in task_pickles:
        if task_key in task_name:
            dataset_path = os.path.join(home_directory, relative_path)
            with open(dataset_path, 'rb') as handle:
                whole_dataset_dicts.append(pickle.load(handle))
    print()

    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

    # test dataset
    test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject=subject_choice,
                            eeg_type=eeg_type_choice, bands=bands_choice, setting=dataset_setting)

    dataset_sizes = {"test_set": len(test_set)}
    print('[INFO]test_set size: ', len(test_set))

    # dataloaders
    test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)

    dataloaders = {'test': test_dataloader}

    # set up model
    checkpoint_path = args['checkpoint_path']
    pretrained_bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large')

    if model_name == 'BrainTranslator':
        model = BrainTranslator(pretrained_bart, in_feature=105 * len(bands_choice), decoder_embedding_size=1024,
                                additional_encoder_nhead=8, additional_encoder_dim_feedforward=2048)
    elif model_name == 'BrainTranslatorNaive':
        model = BrainTranslatorNaive(pretrained_bart, in_feature=105 * len(bands_choice), decoder_embedding_size=1024,
                                     additional_encoder_nhead=8, additional_encoder_dim_feedforward=2048)
    else:
        # Fail with a clear message instead of a NameError on `model` below.
        raise ValueError(f'unsupported model_name in training config: {model_name!r}')

    # map_location lets a checkpoint saved on a GPU be restored on whatever
    # device was selected above (including CPU-only machines).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)

    criterion = nn.CrossEntropyLoss()

    # eval
    eval_model(dataloaders, device, tokenizer, criterion, model, output_all_results_path=output_all_results_path)