# data.py

import os
import numpy as np
import torch
import pickle
from torch.utils.data import Dataset, DataLoader
import json
import matplotlib.pyplot as plt
from glob import glob
from transformers import BartTokenizer, BertTokenizer
from tqdm import tqdm
from fuzzy_match import match
from fuzzy_match import algorithims

# macro
# ZUCO_SENTIMENT_LABELS = json.load(open('./dataset/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json'))
# SST_SENTIMENT_LABELS = json.load(open('./dataset/stanfordsentiment/ternary_dataset.json'))

def normalize_1d(input_tensor):
    # normalize a 1d tensor
    mean = torch.mean(input_tensor)
    std = torch.std(input_tensor)
    input_tensor = (input_tensor - mean) / std
    return input_tensor
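
# Illustrative note (not part of the original pipeline): normalize_1d z-scores a 1-D tensor, e.g.
#   normalize_1d(torch.tensor([1.0, 2.0, 3.0]))  # -> tensor([-1., 0., 1.])
# (torch.std uses the unbiased estimator by default, so the std of [1, 2, 3] is exactly 1.0)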

def get_input_sample(sent_obj, tokenizer, eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'], max_len = 56, add_CLS_token = False):

    def get_word_embedding_eeg_tensor(word_obj, eeg_type, bands):
        frequency_features = []
        for band in bands:
            frequency_features.append(word_obj['word_level_EEG'][eeg_type][eeg_type+band])
        word_eeg_embedding = np.concatenate(frequency_features)
        if len(word_eeg_embedding) != 105*len(bands):
            print(f'expect word eeg embedding dim to be {105*len(bands)}, but got {len(word_eeg_embedding)}, return None')
            return None
        # assert len(word_eeg_embedding) == 105*len(bands)
        return_tensor = torch.from_numpy(word_eeg_embedding)
        return normalize_1d(return_tensor)

    def get_sent_eeg(sent_obj, bands):
        sent_eeg_features = []
        for band in bands:
            key = 'mean'+band
            sent_eeg_features.append(sent_obj['sentence_level_EEG'][key])
        sent_eeg_embedding = np.concatenate(sent_eeg_features)
        assert len(sent_eeg_embedding) == 105*len(bands)
        return_tensor = torch.from_numpy(sent_eeg_embedding)
        return normalize_1d(return_tensor)

    if sent_obj is None:
        # print(f' - skip bad sentence')
        return None

    input_sample = {}
    # get target label
    target_string = sent_obj['content']
    target_tokenized = tokenizer(target_string, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True)

    input_sample['target_ids'] = target_tokenized['input_ids'][0]

    # get sentence level EEG features
    sent_level_eeg_tensor = get_sent_eeg(sent_obj, bands)
    if torch.isnan(sent_level_eeg_tensor).any():
        # print('[NaN sent level eeg]: ', target_string)
        return None
    input_sample['sent_level_EEG'] = sent_level_eeg_tensor

    # get sentiment label
    # handle some weird cases in the source text
    if 'emp11111ty' in target_string:
        target_string = target_string.replace('emp11111ty','empty')
    if 'film.1' in target_string:
        target_string = target_string.replace('film.1','film.')

    # if target_string in ZUCO_SENTIMENT_LABELS:
    #     input_sample['sentiment_label'] = torch.tensor(ZUCO_SENTIMENT_LABELS[target_string]+1) # 0:Negative, 1:Neutral, 2:Positive
    # else:
    #     input_sample['sentiment_label'] = torch.tensor(-100) # dummy value
    input_sample['sentiment_label'] = torch.tensor(-100) # dummy value

    # get input embeddings
    word_embeddings = []

    """add CLS token embedding at the front"""
    if add_CLS_token:
        word_embeddings.append(torch.ones(105*len(bands)))

    for word in sent_obj['word']:
        # add each word's EEG embedding as a tensor
        word_level_eeg_tensor = get_word_embedding_eeg_tensor(word, eeg_type, bands = bands)
        # check None, for v2 dataset
        if word_level_eeg_tensor is None:
            return None
        # check NaN:
        if torch.isnan(word_level_eeg_tensor).any():
            # print()
            # print('[NaN ERROR] problem sent:', sent_obj['content'])
            # print('[NaN ERROR] problem word:', word['content'])
            # print('[NaN ERROR] problem word feature:', word_level_eeg_tensor)
            # print()
            return None

        word_embeddings.append(word_level_eeg_tensor)
    # pad to max_len
    while len(word_embeddings) < max_len:
        word_embeddings.append(torch.zeros(105*len(bands)))

    input_sample['input_embeddings'] = torch.stack(word_embeddings) # max_len * (105*num_bands)

    # mask out padding tokens
    input_sample['input_attn_mask'] = torch.zeros(max_len) # 0 is masked out

    if add_CLS_token:
        input_sample['input_attn_mask'][:len(sent_obj['word'])+1] = torch.ones(len(sent_obj['word'])+1) # 1 is not masked
    else:
        input_sample['input_attn_mask'][:len(sent_obj['word'])] = torch.ones(len(sent_obj['word'])) # 1 is not masked

    # inverted padding mask, for a different use case: the PyTorch nn.Transformer convention
    input_sample['input_attn_mask_invert'] = torch.ones(max_len) # 1 is masked out

    if add_CLS_token:
        input_sample['input_attn_mask_invert'][:len(sent_obj['word'])+1] = torch.zeros(len(sent_obj['word'])+1) # 0 is not masked
    else:
        input_sample['input_attn_mask_invert'][:len(sent_obj['word'])] = torch.zeros(len(sent_obj['word'])) # 0 is not masked

    # mask out target padding for computing cross entropy loss
    input_sample['target_mask'] = target_tokenized['attention_mask'][0]
    input_sample['seq_len'] = len(sent_obj['word'])

    # discard zero-length instances
    if input_sample['seq_len'] == 0:
        print('discard length zero instance: ', target_string)
        return None

    return input_sample
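
# Usage sketch (illustrative only): the field names below mirror what get_input_sample
# reads from a ZuCo sentence object; the EEG arrays are random stand-ins, not real data,
# and _demo_get_input_sample is a hypothetical helper, not part of the original pipeline.
def _demo_get_input_sample():
    bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']
    fake_word = {
        'content': 'hello',
        'word_level_EEG': {'GD': {'GD'+b: np.random.rand(105).astype('float32') for b in bands}},
    }
    fake_sent = {
        'content': 'hello world',
        'word': [fake_word, fake_word],
        'sentence_level_EEG': {'mean'+b: np.random.rand(105).astype('float32') for b in bands},
    }
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
    sample = get_input_sample(fake_sent, tokenizer, eeg_type='GD', bands=bands)
    print(sample['input_embeddings'].shape)  # torch.Size([56, 840]) = max_len x (105*len(bands))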

class ZuCo_dataset(Dataset):
    def __init__(self, input_dataset_dicts, phase, tokenizer, subject = 'ALL', eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'], setting = 'unique_sent', is_add_CLS_token = False):
        self.inputs = []
        self.tokenizer = tokenizer

        if not isinstance(input_dataset_dicts, list):
            input_dataset_dicts = [input_dataset_dicts]
        print(f'[INFO]loading {len(input_dataset_dicts)} task datasets')
        for input_dataset_dict in input_dataset_dicts:
            if subject == 'ALL':
                subjects = list(input_dataset_dict.keys())
                print('[INFO]using subjects: ', subjects)
            else:
                subjects = [subject]

            total_num_sentence = len(input_dataset_dict[subjects[0]])

            train_divider = int(0.8*total_num_sentence)
            dev_divider = train_divider + int(0.1*total_num_sentence)

            print(f'train divider = {train_divider}')
            print(f'dev divider = {dev_divider}')

            if setting == 'unique_sent':
                # take first 80% as train set, next 10% as dev set and last 10% as test set
                if phase == 'train':
                    print('[INFO]initializing a train set...')
                    for key in subjects:
                        for i in range(train_divider):
                            input_sample = get_input_sample(input_dataset_dict[key][i], self.tokenizer, eeg_type, bands = bands, add_CLS_token = is_add_CLS_token)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                elif phase == 'dev':
                    print('[INFO]initializing a dev set...')
                    for key in subjects:
                        for i in range(train_divider, dev_divider):
                            input_sample = get_input_sample(input_dataset_dict[key][i], self.tokenizer, eeg_type, bands = bands, add_CLS_token = is_add_CLS_token)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                elif phase == 'test':
                    print('[INFO]initializing a test set...')
                    for key in subjects:
                        for i in range(dev_divider, total_num_sentence):
                            input_sample = get_input_sample(input_dataset_dict[key][i], self.tokenizer, eeg_type, bands = bands, add_CLS_token = is_add_CLS_token)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
            elif setting == 'unique_subj':
                print('WARNING!!! only implemented for SR v1 dataset')
                # subjects ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW'] for train
                # subject ['ZMG'] for dev
                # subject ['ZPH'] for test
                if phase == 'train':
                    print(f'[INFO]initializing a train set using {setting} setting...')
                    for i in range(total_num_sentence):
                        for key in ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW']:
                            input_sample = get_input_sample(input_dataset_dict[key][i], self.tokenizer, eeg_type, bands = bands, add_CLS_token = is_add_CLS_token)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                if phase == 'dev':
                    print(f'[INFO]initializing a dev set using {setting} setting...')
                    for i in range(total_num_sentence):
                        for key in ['ZMG']:
                            input_sample = get_input_sample(input_dataset_dict[key][i], self.tokenizer, eeg_type, bands = bands, add_CLS_token = is_add_CLS_token)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                if phase == 'test':
                    print(f'[INFO]initializing a test set using {setting} setting...')
                    for i in range(total_num_sentence):
                        for key in ['ZPH']:
                            input_sample = get_input_sample(input_dataset_dict[key][i], self.tokenizer, eeg_type, bands = bands, add_CLS_token = is_add_CLS_token)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
            print('++ adding task to dataset, now we have:', len(self.inputs))

        print('[INFO]input tensor size:', self.inputs[0]['input_embeddings'].size())
        print()

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_sample = self.inputs[idx]
        return (
            input_sample['input_embeddings'],
            input_sample['seq_len'],
            input_sample['input_attn_mask'],
            input_sample['input_attn_mask_invert'],
            input_sample['target_ids'],
            input_sample['target_mask'],
            input_sample['sentiment_label'],
            input_sample['sent_level_EEG']
        )
        # returns: input_embeddings, seq_len, input_attn_mask, input_attn_mask_invert, target_ids, target_mask, sentiment_label, sent_level_EEG
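
# Wrapping sketch (illustrative only): ZuCo_dataset yields fixed-size tensors, so the default
# collate_fn of torch.utils.data.DataLoader can batch it directly. The batch size is arbitrary,
# `dataset_dicts` is assumed to be the list of unpickled ZuCo task dictionaries used in the
# sanity test at the bottom of this file, and _demo_zuco_dataloader is a hypothetical helper.
def _demo_zuco_dataloader(dataset_dicts, tokenizer):
    train_set = ZuCo_dataset(dataset_dicts, 'train', tokenizer, subject='ALL', eeg_type='GD')
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    for input_embeddings, seq_len, attn_mask, attn_mask_invert, target_ids, target_mask, sentiment_label, sent_level_EEG in train_loader:
        print(input_embeddings.shape)  # [batch_size, 56, 105*len(bands)]
        break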

"""for training a classifier on Stanford Sentiment Treebank text-sentiment pairs"""
class SST_tenary_dataset(Dataset):
    def __init__(self, ternary_labels_dict, tokenizer, max_len = 56, balance_class = True):
        self.inputs = []

        pos_samples = []
        neg_samples = []
        neu_samples = []

        for key, value in ternary_labels_dict.items():
            tokenized_inputs = tokenizer(key, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True)
            input_ids = tokenized_inputs['input_ids'][0]
            attn_masks = tokenized_inputs['attention_mask'][0]
            label = torch.tensor(value)
            # count:
            if value == 0:
                neg_samples.append((input_ids, attn_masks, label))
            elif value == 1:
                neu_samples.append((input_ids, attn_masks, label))
            elif value == 2:
                pos_samples.append((input_ids, attn_masks, label))
        print(f'Original distribution:\n\tVery positive: {len(pos_samples)}\n\tNeutral: {len(neu_samples)}\n\tVery negative: {len(neg_samples)}')
        if balance_class:
            print(f'balance class to {min([len(pos_samples), len(neg_samples), len(neu_samples)])} each...')
            for i in range(min([len(pos_samples), len(neg_samples), len(neu_samples)])):
                self.inputs.append(pos_samples[i])
                self.inputs.append(neg_samples[i])
                self.inputs.append(neu_samples[i])
        else:
            self.inputs = pos_samples + neg_samples + neu_samples

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_sample = self.inputs[idx]
        return input_sample
        # returns a (input_ids, attention_mask, label) tuple
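
# Usage sketch (illustrative only): the dict maps raw sentence strings to ternary labels
# (0 = negative, 1 = neutral, 2 = positive), which is the format this class consumes.
# The toy dict, batch size, and _demo_sst_dataloader helper are made up for illustration.
def _demo_sst_dataloader():
    toy_labels = {'a gorgeous film': 2, 'an ordinary story': 1, 'a tedious mess': 0}
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    sst_set = SST_tenary_dataset(toy_labels, tokenizer, balance_class=True)
    sst_loader = DataLoader(sst_set, batch_size=2, shuffle=True)
    for input_ids, attention_mask, label in sst_loader:
        print(input_ids.shape, label)  # [batch_size, 56] and a tensor of labels
        break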

'''sanity test'''
if __name__ == '__main__':

    check_dataset = 'stanford_sentiment'

    if check_dataset == 'ZuCo':
        whole_dataset_dicts = []

        dataset_path_task1 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task1-SR/pickle/task1-SR-dataset-with-tokens_6-25.pickle'
        with open(dataset_path_task1, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))

        dataset_path_task2 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task2-NR/pickle/task2-NR-dataset-with-tokens_7-10.pickle'
        with open(dataset_path_task2, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))

        # dataset_path_task3 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task3-TSR/pickle/task3-TSR-dataset-with-tokens_7-10.pickle'
        # with open(dataset_path_task3, 'rb') as handle:
        #     whole_dataset_dicts.append(pickle.load(handle))

        dataset_path_task2_v2 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset-with-tokens_7-15.pickle'
        with open(dataset_path_task2_v2, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))

        print()
        for key in whole_dataset_dicts[0]:
            # whole_dataset_dicts[0] is the task1-SR dictionary loaded above
            print(f'task1-SR, sentence num in {key}:', len(whole_dataset_dicts[0][key]))
        print()

        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        dataset_setting = 'unique_sent'
        subject_choice = 'ALL'
        print(f'![Debug]using {subject_choice}')
        eeg_type_choice = 'GD'
        print(f'[INFO]eeg type {eeg_type_choice}')
        bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']
        print(f'[INFO]using bands {bands_choice}')
        train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
        dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
        test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)

        print('trainset size:', len(train_set))
        print('devset size:', len(dev_set))
        print('testset size:', len(test_set))

    elif check_dataset == 'stanford_sentiment':
        # load the ternary SST labels here, since the module-level load is commented out above
        SST_SENTIMENT_LABELS = json.load(open('./dataset/stanfordsentiment/ternary_dataset.json'))
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        SST_dataset = SST_tenary_dataset(SST_SENTIMENT_LABELS, tokenizer)
        print('SST dataset size:', len(SST_dataset))
        print(SST_dataset[0])
        print(SST_dataset[1])