# data.py
import os
import numpy as np
import torch
import pickle
from torch.utils.data import Dataset, DataLoader
import json
import matplotlib.pyplot as plt
from glob import glob
from transformers import BartTokenizer, BertTokenizer
from tqdm import tqdm
from fuzzy_match import match
from fuzzy_match import algorithims

# macro
#ZUCO_SENTIMENT_LABELS = json.load(open('./dataset/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json'))
#SST_SENTIMENT_LABELS = json.load(open('./dataset/stanfordsentiment/ternary_dataset.json'))

def normalize_1d(input_tensor):
    # normalize a 1d tensor
    mean = torch.mean(input_tensor)
    std = torch.std(input_tensor)
    input_tensor = (input_tensor - mean)/std
    return input_tensor

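# Illustrative note (not part of the original pipeline): normalize_1d is a plain
# z-score, e.g.
#   >>> normalize_1d(torch.tensor([1.0, 2.0, 3.0]))
#   tensor([-1., 0., 1.])
# A constant input gives std == 0 and therefore NaNs, which is one reason
# get_input_sample() below checks torch.isnan on every EEG tensor.
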
def get_input_sample(sent_obj, tokenizer, eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'], max_len = 56, add_CLS_token = False):

    def get_word_embedding_eeg_tensor(word_obj, eeg_type, bands):
        frequency_features = []
        for band in bands:
            frequency_features.append(word_obj['word_level_EEG'][eeg_type][eeg_type+band])
        word_eeg_embedding = np.concatenate(frequency_features)
        if len(word_eeg_embedding) != 105*len(bands):
            print(f'expect word eeg embedding dim to be {105*len(bands)}, but got {len(word_eeg_embedding)}; returning None')
            return None
        # assert len(word_eeg_embedding) == 105*len(bands)
        return_tensor = torch.from_numpy(word_eeg_embedding)
        return normalize_1d(return_tensor)

    def get_sent_eeg(sent_obj, bands):
        sent_eeg_features = []
        for band in bands:
            key = 'mean'+band
            sent_eeg_features.append(sent_obj['sentence_level_EEG'][key])
        sent_eeg_embedding = np.concatenate(sent_eeg_features)
        assert len(sent_eeg_embedding) == 105*len(bands)
        return_tensor = torch.from_numpy(sent_eeg_embedding)
        return normalize_1d(return_tensor)

    if sent_obj is None:
        # print(f'  - skip bad sentence')
        return None

    input_sample = {}
    # get target label
    target_string = sent_obj['content']
    target_tokenized = tokenizer(target_string, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True)

    input_sample['target_ids'] = target_tokenized['input_ids'][0]

    # get sentence level EEG features
    sent_level_eeg_tensor = get_sent_eeg(sent_obj, bands)
    if torch.isnan(sent_level_eeg_tensor).any():
        # print('[NaN sent level eeg]: ', target_string)
        return None
    input_sample['sent_level_EEG'] = sent_level_eeg_tensor

    # get sentiment label
    # handle some weird cases
    if 'emp11111ty' in target_string:
        target_string = target_string.replace('emp11111ty','empty')
    if 'film.1' in target_string:
        target_string = target_string.replace('film.1','film.')

    #if target_string in ZUCO_SENTIMENT_LABELS:
    #    input_sample['sentiment_label'] = torch.tensor(ZUCO_SENTIMENT_LABELS[target_string]+1) # 0:Negative, 1:Neutral, 2:Positive
    #else:
    #    input_sample['sentiment_label'] = torch.tensor(-100) # dummy value
    input_sample['sentiment_label'] = torch.tensor(-100) # dummy value

    # get input embeddings
    word_embeddings = []

    # add CLS token embedding at the front
    if add_CLS_token:
        word_embeddings.append(torch.ones(105*len(bands)))

    for word in sent_obj['word']:
        # add each word's EEG embedding as a tensor
        word_level_eeg_tensor = get_word_embedding_eeg_tensor(word, eeg_type, bands = bands)
        # check None, for v2 dataset
        if word_level_eeg_tensor is None:
            return None
        # check NaN:
        if torch.isnan(word_level_eeg_tensor).any():
            # print()
            # print('[NaN ERROR] problem sent:',sent_obj['content'])
            # print('[NaN ERROR] problem word:',word['content'])
            # print('[NaN ERROR] problem word feature:',word_level_eeg_tensor)
            # print()
            return None

        word_embeddings.append(word_level_eeg_tensor)

    # pad to max_len
    while len(word_embeddings) < max_len:
        word_embeddings.append(torch.zeros(105*len(bands)))

    input_sample['input_embeddings'] = torch.stack(word_embeddings) # max_len * (105*num_bands)

    # mask out padding tokens
    input_sample['input_attn_mask'] = torch.zeros(max_len) # 0 is masked out

    if add_CLS_token:
        input_sample['input_attn_mask'][:len(sent_obj['word'])+1] = torch.ones(len(sent_obj['word'])+1) # 1 is not masked
    else:
        input_sample['input_attn_mask'][:len(sent_obj['word'])] = torch.ones(len(sent_obj['word'])) # 1 is not masked

    # inverted padding mask for a different use case: this is for pytorch transformers, where 1 means masked out
    input_sample['input_attn_mask_invert'] = torch.ones(max_len) # 1 is masked out

    if add_CLS_token:
        input_sample['input_attn_mask_invert'][:len(sent_obj['word'])+1] = torch.zeros(len(sent_obj['word'])+1) # 0 is not masked
    else:
        input_sample['input_attn_mask_invert'][:len(sent_obj['word'])] = torch.zeros(len(sent_obj['word'])) # 0 is not masked

    # mask out target padding for computing cross entropy loss
    input_sample['target_mask'] = target_tokenized['attention_mask'][0]
    input_sample['seq_len'] = len(sent_obj['word'])

    # discard zero-length data
    if input_sample['seq_len'] == 0:
        print('discard length zero instance: ', target_string)
        return None

    return input_sample
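
# Illustrative sketch (not from the dataset documentation; the layout below is
# inferred from the key accesses in get_input_sample and may differ in detail):
# each pickled sentence object is expected to look roughly like
#
#   sent_obj = {
#       'content': 'A man is reading.',
#       'sentence_level_EEG': {'mean_t1': np.zeros(105), ..., 'mean_g2': np.zeros(105)},
#       'word': [
#           {'content': 'A',
#            'word_level_EEG': {'GD': {'GD_t1': np.zeros(105), ..., 'GD_g2': np.zeros(105)}}},
#           # one dict per word
#       ],
#   }
#
# i.e. one 105-dim feature vector per frequency band, per word and per sentence,
# so the concatenated embeddings above have length 105 * len(bands).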

class ZuCo_dataset(Dataset):
    def __init__(self, input_dataset_dicts, phase, tokenizer, subject = 'ALL', eeg_type = 'GD', bands = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2'], setting = 'unique_sent', is_add_CLS_token = False):
        self.inputs = []
        self.tokenizer = tokenizer

        if not isinstance(input_dataset_dicts,list):
            input_dataset_dicts = [input_dataset_dicts]
        print(f'[INFO]loading {len(input_dataset_dicts)} task datasets')
        for input_dataset_dict in input_dataset_dicts:
            if subject == 'ALL':
                subjects = list(input_dataset_dict.keys())
                print('[INFO]using subjects: ', subjects)
            else:
                subjects = [subject]

            total_num_sentence = len(input_dataset_dict[subjects[0]])

            train_divider = int(0.8*total_num_sentence)
            dev_divider = train_divider + int(0.1*total_num_sentence)

            print(f'train divider = {train_divider}')
            print(f'dev divider = {dev_divider}')

            if setting == 'unique_sent':
                # take first 80% as train set, next 10% as dev, last 10% as test
                if phase == 'train':
                    print('[INFO]initializing a train set...')
                    for key in subjects:
                        for i in range(train_divider):
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                elif phase == 'dev':
                    print('[INFO]initializing a dev set...')
                    for key in subjects:
                        for i in range(train_divider,dev_divider):
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                elif phase == 'test':
                    print('[INFO]initializing a test set...')
                    for key in subjects:
                        for i in range(dev_divider,total_num_sentence):
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
            elif setting == 'unique_subj':
                print('WARNING!!! only implemented for SR v1 dataset')
                # subjects ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH', 'ZKW'] for train
                # subject ['ZMG'] for dev
                # subject ['ZPH'] for test
                if phase == 'train':
                    print(f'[INFO]initializing a train set using {setting} setting...')
                    for i in range(total_num_sentence):
                        for key in ['ZAB', 'ZDM', 'ZGW', 'ZJM', 'ZJN', 'ZJS', 'ZKB', 'ZKH','ZKW']:
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                if phase == 'dev':
                    print(f'[INFO]initializing a dev set using {setting} setting...')
                    for i in range(total_num_sentence):
                        for key in ['ZMG']:
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
                if phase == 'test':
                    print(f'[INFO]initializing a test set using {setting} setting...')
                    for i in range(total_num_sentence):
                        for key in ['ZPH']:
                            input_sample = get_input_sample(input_dataset_dict[key][i],self.tokenizer,eeg_type,bands = bands, add_CLS_token = is_add_CLS_token)
                            if input_sample is not None:
                                self.inputs.append(input_sample)
            print('++ adding task to dataset, now we have:', len(self.inputs))

        print('[INFO]input tensor size:', self.inputs[0]['input_embeddings'].size())
        print()

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_sample = self.inputs[idx]
        return (
            input_sample['input_embeddings'],
            input_sample['seq_len'],
            input_sample['input_attn_mask'],
            input_sample['input_attn_mask_invert'],
            input_sample['target_ids'],
            input_sample['target_mask'],
            input_sample['sentiment_label'],
            input_sample['sent_level_EEG']
        )
        # returns: input_embeddings, seq_len, input_attn_mask, input_attn_mask_invert, target_ids, target_mask, sentiment_label, sent_level_EEG
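
# Illustrative usage sketch (not part of this file's pipeline; batch size is an
# arbitrary choice). __getitem__ returns a plain tuple, so the default collate works:
#
#   train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
#   for (input_embeddings, seq_len, input_masks, input_masks_invert,
#        target_ids, target_mask, sentiment_labels, sent_level_EEG) in train_loader:
#       # input_embeddings: [batch, max_len, 105 * len(bands)]
#       # target_ids / target_mask: [batch, max_len] token ids and their padding mask
#       pass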

"""For training a classifier on Stanford Sentiment Treebank text-sentiment pairs"""
class SST_tenary_dataset(Dataset):
    def __init__(self, ternary_labels_dict, tokenizer, max_len = 56, balance_class = True):
        self.inputs = []

        pos_samples = []
        neg_samples = []
        neu_samples = []

        for key,value in ternary_labels_dict.items():
            tokenized_inputs = tokenizer(key, padding='max_length', max_length=max_len, truncation=True, return_tensors='pt', return_attention_mask = True)
            input_ids = tokenized_inputs['input_ids'][0]
            attn_masks = tokenized_inputs['attention_mask'][0]
            label = torch.tensor(value)
            # count per class:
            if value == 0:
                neg_samples.append((input_ids,attn_masks,label))
            elif value == 1:
                neu_samples.append((input_ids,attn_masks,label))
            elif value == 2:
                pos_samples.append((input_ids,attn_masks,label))
        print(f'Original distribution:\n\tVery positive: {len(pos_samples)}\n\tNeutral: {len(neu_samples)}\n\tVery negative: {len(neg_samples)}')
        if balance_class:
            print(f'balance classes to {min([len(pos_samples),len(neg_samples),len(neu_samples)])} each...')
            for i in range(min([len(pos_samples),len(neg_samples),len(neu_samples)])):
                self.inputs.append(pos_samples[i])
                self.inputs.append(neg_samples[i])
                self.inputs.append(neu_samples[i])
        else:
            self.inputs = pos_samples + neg_samples + neu_samples

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_sample = self.inputs[idx]
        return input_sample
        # returns (input_ids, attention_mask, label)
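
# Illustrative usage sketch (the label dict format {sentence_text: 0/1/2} is
# inferred from the loop above; the json path comes from the commented-out macro
# at the top of this file):
#
#   labels = json.load(open('./dataset/stanfordsentiment/ternary_dataset.json'))
#   sst_set = SST_tenary_dataset(labels, BertTokenizer.from_pretrained('bert-base-cased'))
#   sst_loader = DataLoader(sst_set, batch_size=32, shuffle=True)
#   input_ids, attention_mask, labels = next(iter(sst_loader))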

'''sanity test'''
if __name__ == '__main__':

    check_dataset = 'stanford_sentiment'

    if check_dataset == 'ZuCo':
        whole_dataset_dicts = []

        dataset_path_task1 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task1-SR/pickle/task1-SR-dataset-with-tokens_6-25.pickle'
        with open(dataset_path_task1, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))

        dataset_path_task2 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task2-NR/pickle/task2-NR-dataset-with-tokens_7-10.pickle'
        with open(dataset_path_task2, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))

        # dataset_path_task3 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task3-TSR/pickle/task3-TSR-dataset-with-tokens_7-10.pickle'
        # with open(dataset_path_task3, 'rb') as handle:
        #     whole_dataset_dicts.append(pickle.load(handle))

        dataset_path_task2_v2 = '/shared/nas/data/m1/wangz3/SAO_project/SAO/dataset/ZuCo/task2-NR-2.0/pickle/task2-NR-2.0-dataset-with-tokens_7-15.pickle'
        with open(dataset_path_task2_v2, 'rb') as handle:
            whole_dataset_dicts.append(pickle.load(handle))

        print()
        for key in whole_dataset_dicts[0]:
            print(f'task1-SR, sentence num in {key}:',len(whole_dataset_dicts[0][key]))
        print()

        tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        dataset_setting = 'unique_sent'
        subject_choice = 'ALL'
        print(f'[INFO]using subject: {subject_choice}')
        eeg_type_choice = 'GD'
        print(f'[INFO]eeg type {eeg_type_choice}')
        bands_choice = ['_t1','_t2','_a1','_a2','_b1','_b2','_g1','_g2']
        print(f'[INFO]using bands {bands_choice}')
        train_set = ZuCo_dataset(whole_dataset_dicts, 'train', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
        dev_set = ZuCo_dataset(whole_dataset_dicts, 'dev', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)
        test_set = ZuCo_dataset(whole_dataset_dicts, 'test', tokenizer, subject = subject_choice, eeg_type = eeg_type_choice, bands = bands_choice, setting = dataset_setting)

        print('trainset size:',len(train_set))
        print('devset size:',len(dev_set))
        print('testset size:',len(test_set))

    elif check_dataset == 'stanford_sentiment':
        # SST_SENTIMENT_LABELS is commented out in the macro section above, so load it
        # here (same path as the macro) to keep this sanity check runnable.
        SST_SENTIMENT_LABELS = json.load(open('./dataset/stanfordsentiment/ternary_dataset.json'))
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        SST_dataset = SST_tenary_dataset(SST_SENTIMENT_LABELS, tokenizer)
        print('SST dataset size:',len(SST_dataset))
        print(SST_dataset[0])
        print(SST_dataset[1])