--- a
+++ b/util/get_SST_ternary_dataset.py
@@ -0,0 +1,149 @@
+import os
+import json
+# note: 'algorithims' (sic) is the fuzzy_match package's own module spelling
+from fuzzy_match import algorithims
+
+
+def get_SST_dataset(SST_dir_path, ZUCO_SENTIMENT_LABELS):
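+    '''
+        Build a ternary (negative/neutral/positive) sentence-level sentiment
+        dataset from the Stanford Sentiment Treebank, excluding sentences
+        already used in the ZuCo task1-SR sentiment recordings.
+    '''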
+    
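+    # Input file formats, as parsed below (from the SST distribution):
+    #   dictionary.txt:        <phrase_str>|<phrase_id>
+    #   sentiment_labels.txt:  <phrase_id>|<score in [0, 1]>, with one header line
+    #   datasetSentences.txt:  <sentence_index><TAB><sentence>, with one header line
+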
+    def get_sentiment_label_dict(SST_labels_file_path):
+        '''
+            return {phrase_id: sentiment_score in [0, 1]}
+        '''
+        ret_dict = {}
+        with open(SST_labels_file_path) as f:
+            for line in f:
+                if line.startswith('phrase'):
+                    continue
+                else:
+                    phrase_id = int(line.split('|')[0])
+                    label = float(line.split('|')[1].strip())
+                    assert phrase_id not in ret_dict
+                    ret_dict[phrase_id] = label
+        return ret_dict
+
+    def get_phrasestr_phrase_dict(SST_dictionary_file_path):
+        '''
+            return {phrase_str: phrase_id}
+        '''
+        ret_dict = {}
+        with open(SST_dictionary_file_path) as f:
+            for line in f:
+                phrase_str = line.split('|')[0]
+                phrase_id = int(line.split('|')[1].strip())
+                assert phrase_str not in ret_dict
+                ret_dict[phrase_str] = phrase_id
+        return ret_dict
+
+    def get_sentence_label_dict(SST_sentences_file_path, SST_labels_file_path, SST_dictionary_file_path):
+        '''
+            return ({sentence_str: score in [0, 1]}, {sentence_str: ternary label in {0, 1, 2}})
+        '''
+        phraseID_2_label = get_sentiment_label_dict(SST_labels_file_path)
+        phraseStr_2_phraseID = get_phrasestr_phrase_dict(SST_dictionary_file_path)
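+        # composing the two maps gives: sentence text -> phrase id -> sentiment score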
+
+        sentence_2_label_all = {}
+        sentence_2_label_ternary = {}
+        with open(SST_sentences_file_path) as f:
+            for line in f:
+                if line.startswith('sentence_index'):
+                    continue
+                else:
+                    parsed_line = line.split('\t')
+                    assert len(parsed_line) == 2
+                    sentence = parsed_line[1].strip()
+                    # convert PTB tokens -LRB-/-RRB- back to parentheses and
+                    # repair the common UTF-8 mojibake 'Ã©' -> 'é':
+                    sentence = sentence.replace('-LRB-', '(').replace('-RRB-', ')').replace('Ã©', 'é')
+                    if sentence not in phraseStr_2_phraseID:
+                        # no exact phrase match in dictionary.txt; skip this sentence
+                        continue
+                    sent_phrase_id = phraseStr_2_phraseID[sentence]
+                    label = phraseID_2_label[sent_phrase_id]
+                    
+                    # keep the raw score for every matched sentence
+                    if sentence not in sentence_2_label_all:
+                        sentence_2_label_all[sentence] = label
+
+                    # bucket clearly-polar scores into ternary classes; scores in
+                    # (0.2, 0.4] and (0.6, 0.8] are dropped as ambiguous
+                    if sentence not in sentence_2_label_ternary:
+                        if label <= 0.2:
+                            sentence_2_label_ternary[sentence] = 0  # negative
+                        elif 0.4 < label <= 0.6:
+                            sentence_2_label_ternary[sentence] = 1  # neutral
+                        elif label > 0.8:
+                            sentence_2_label_ternary[sentence] = 2  # positive
+
+        return sentence_2_label_all, sentence_2_label_ternary
+
+
+    SST_sentences_file_path = os.path.join(SST_dir_path, 'datasetSentences.txt')
+    SST_labels_file_path = os.path.join(SST_dir_path, 'sentiment_labels.txt')
+    SST_dictionary_file_path = os.path.join(SST_dir_path, 'dictionary.txt')
+    for path in (SST_sentences_file_path, SST_labels_file_path, SST_dictionary_file_path):
+        if not os.path.isfile(path):
+            raise FileNotFoundError(f'NOT FOUND file: {path}')
+
+    sentence_2_label_all, sentence_2_label_ternary = get_sentence_label_dict(SST_sentences_file_path, SST_labels_file_path, SST_dictionary_file_path)
+    print('original ternary dataset size:', len(sentence_2_label_ternary))
+
+    # the keys of the ZuCo sentiment-labels JSON are the task1-SR sentence strings;
+    # copy them into a list so matched entries can be removed during filtering
+    ZuCo_used_sentences = list(ZUCO_SENTIMENT_LABELS)
+
+    filtered_ternary_dataset = {}
+    filtered_pairs = []
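+    # fuzzy-match every SST sentence against the ZuCo sentences: a character
+    # trigram similarity above 0.7 is treated as the same sentence, and each
+    # ZuCo sentence is matched at most once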
+    for key,value in sentence_2_label_ternary.items():
+        add_instance = True
+        for used_sent in ZuCo_used_sentences:
+            if algorithims.trigram(used_sent, key) > 0.7:
+                filtered_pairs.append((used_sent, key))
+                ZuCo_used_sentences.remove(used_sent)
+                add_instance = False
+                break
+        if add_instance:
+            filtered_ternary_dataset[key] = value
+    
+    print('filtered instance number:', len(filtered_pairs))
+    print('filtered ternary dataset size:', len(filtered_ternary_dataset))
+    print('unmatched remaining sentences:', ZuCo_used_sentences)
+    print('unmatched remaining sentences length:', len(ZuCo_used_sentences))
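+    # dump the matched pairs to a scratch file so the fuzzy matches can be inspected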
+    with open('temp.txt','w') as temp:
+        for matched_pair in filtered_pairs:
+            temp.write('#######\n')
+            temp.write('\t'+matched_pair[0]+'\n')
+            temp.write('\t'+matched_pair[1]+'\n')
+            temp.write('\n')
+
+    out_path = './dataset/stanfordsentiment/ternary_dataset.json'
+    os.makedirs(os.path.dirname(out_path), exist_ok=True)
+    with open(out_path, 'w') as out:
+        json.dump(filtered_ternary_dataset, out, indent=4)
+    print(f'write json to {out_path}')
+
+if __name__ == '__main__':
+    print('##############################')
+    print('start generating stanfordSentimentTreebank ternary sentiment dataset...')
+    # expand '~' explicitly; open() does not perform shell-style tilde expansion
+    SST_dir_path = os.path.expanduser('~/datasets/stanfordsentiment/stanfordSentimentTreebank')
+    zuco_labels_path = os.path.expanduser('~/datasets/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json')
+    with open(zuco_labels_path) as f:
+        ZUCO_SENTIMENT_LABELS = json.load(f)
+
+    get_SST_dataset(SST_dir_path, ZUCO_SENTIMENT_LABELS)
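
For reference, a minimal sketch of how the emitted JSON might be consumed downstream. The path mirrors the one written by the script above; everything else (variable names, the class-balance check) is an illustrative assumption, not part of this patch:

    import json

    # load {sentence: label}, where 0 = negative, 1 = neutral, 2 = positive
    with open('./dataset/stanfordsentiment/ternary_dataset.json') as f:
        sentence_2_label = json.load(f)

    # e.g. inspect the class balance of the filtered dataset
    counts = {0: 0, 1: 0, 2: 0}
    for label in sentence_2_label.values():
        counts[label] += 1
    print(len(sentence_2_label), 'sentences; class counts:', counts)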