# util/get_SST_ternary_dataset.py

import os
import json

# fuzzy_match provides the trigram-similarity matching used below; the
# package itself spells its module 'algorithims'
from fuzzy_match import algorithims
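
# Build a ternary (negative / neutral / positive) sentence-level sentiment
# dataset from the Stanford Sentiment Treebank (SST), drop every sentence that
# already appears in the ZuCo task1-SR sentiment material (matched via trigram
# similarity), and write the result to a JSON file.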
def get_SST_dataset(SST_dir_path, ZuCo_used_sentences, ZUCO_SENTIMENT_LABELS):
    # NOTE: the ZuCo_used_sentences argument is effectively ignored; it is
    # rebuilt below from the keys of ZUCO_SENTIMENT_LABELS.

    def get_sentiment_label_dict(SST_labels_file_path):
        '''
        return {phrase_id: sentiment_score (0-1)}
        '''
        ret_dict = {}
        with open(SST_labels_file_path) as f:
            for line in f:
                # skip the 'phrase ids|sentiment values' header line
                if line.startswith('phrase'):
                    continue
                phrase_id = int(line.split('|')[0])
                label = float(line.split('|')[1].strip())
                assert phrase_id not in ret_dict
                ret_dict[phrase_id] = label
        return ret_dict
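
    # sentiment_labels.txt holds one 'phrase_id|score' pair per line after the
    # header, e.g. '226|0.5' (the id/score values here are illustrative).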
    def get_phrasestr_phrase_dict(SST_dictionary_file_path):
        '''
        return {phrase_str: phrase_id}
        '''
        ret_dict = {}
        with open(SST_dictionary_file_path) as f:
            for line in f:
                phrase_str = line.split('|')[0]
                phrase_id = int(line.split('|')[1].strip())
                assert phrase_str not in ret_dict
                ret_dict[phrase_str] = phrase_id
        return ret_dict
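
    # dictionary.txt holds one 'phrase|phrase_id' pair per line with no header;
    # it covers every phrase in the treebank, full sentences included, which is
    # why whole sentences can be looked up in it below.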
    def get_sentence_label_dict(SST_sentences_file_path, SST_labels_file_path, SST_dictionary_file_path):
        '''
        return {sentence_str: label (0-1)}
        '''
        phraseID_2_label = get_sentiment_label_dict(SST_labels_file_path)
        phraseStr_2_phraseID = get_phrasestr_phrase_dict(SST_dictionary_file_path)
        sentence_2_label_all = {}
        sentence_2_label_ternary = {}
        with open(SST_sentences_file_path) as f:
            for line in f:
                # skip the 'sentence_index\tsentence' header line
                if line.startswith('sentence_index'):
                    continue
                parsed_line = line.split('\t')
                assert len(parsed_line) == 2
                sentence = parsed_line[1].strip()
                # convert the PTB bracket tokens -LRB-/-RRB- back to ( and ),
                # and repair the 'Ã©' mojibake that datasetSentences.txt uses
                # for 'é'
                sentence = sentence.replace('-LRB-', '(').replace('-RRB-', ')').replace('Ã©', 'é')
                if sentence not in phraseStr_2_phraseID:
                    # sentence-phrase match not found in dictionary, skip it
                    continue
                sent_phrase_id = phraseStr_2_phraseID[sentence]
                label = phraseID_2_label[sent_phrase_id]
                # add to the dict of all labeled sentences
                if sentence not in sentence_2_label_all:
                    sentence_2_label_all[sentence] = label
                # bucket the real-valued score into ternary classes:
                #   [0, 0.2] -> 0 (negative), (0.4, 0.6] -> 1 (neutral),
                #   (0.8, 1] -> 2 (positive); scores in (0.2, 0.4] and
                #   (0.6, 0.8] are treated as ambiguous and dropped
                if sentence not in sentence_2_label_ternary:
                    if label <= 0.2:
                        sentence_2_label_ternary[sentence] = 0
                    elif 0.4 < label <= 0.6:
                        sentence_2_label_ternary[sentence] = 1
                    elif label > 0.8:
                        sentence_2_label_ternary[sentence] = 2
        return sentence_2_label_all, sentence_2_label_ternary
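
    # --- main body: locate the three SST files, build the label dicts, then
    # drop every SST sentence that fuzzy-matches a ZuCo task1 sentence ---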
    SST_sentences_file_path = os.path.join(SST_dir_path, 'datasetSentences.txt')
    if not os.path.isfile(SST_sentences_file_path):
        print(f'NOT FOUND file: {SST_sentences_file_path}')
    SST_labels_file_path = os.path.join(SST_dir_path, 'sentiment_labels.txt')
    if not os.path.isfile(SST_labels_file_path):
        print(f'NOT FOUND file: {SST_labels_file_path}')
    SST_dictionary_file_path = os.path.join(SST_dir_path, 'dictionary.txt')
    if not os.path.isfile(SST_dictionary_file_path):
        print(f'NOT FOUND file: {SST_dictionary_file_path}')

    sentence_2_label_all, sentence_2_label_ternary = get_sentence_label_dict(SST_sentences_file_path, SST_labels_file_path, SST_dictionary_file_path)
    print('original ternary dataset size:', len(sentence_2_label_ternary))

    # the sentences ZuCo used are the keys of its sentiment-label json
    ZuCo_used_sentences = list(ZUCO_SENTIMENT_LABELS)
    filtered_ternary_dataset = {}
    filtered_pairs = []
    for key, value in sentence_2_label_ternary.items():
        add_instance = True
        for used_sent in ZuCo_used_sentences:
            # treat an SST sentence as a duplicate of a ZuCo sentence when
            # their trigram similarity exceeds 0.7; each ZuCo sentence is
            # removed after its first match so it matches at most once
            if algorithims.trigram(used_sent, key) > 0.7:
                filtered_pairs.append((used_sent, key))
                ZuCo_used_sentences.remove(used_sent)
                add_instance = False
                break
        if add_instance:
            filtered_ternary_dataset[key] = value
    print('filtered instance number:', len(filtered_pairs))
    print('filtered ternary dataset size:', len(filtered_ternary_dataset))
    print('unmatched remaining sentences:', ZuCo_used_sentences)
    print('unmatched remaining sentences length:', len(ZuCo_used_sentences))
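
    # dump the matched (ZuCo sentence, SST sentence) pairs so the fuzzy
    # matches can be inspected by hand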
    with open('temp.txt', 'w') as temp:
        for matched_pair in filtered_pairs:
            temp.write('#######\n')
            temp.write('\t' + matched_pair[0] + '\n')
            temp.write('\t' + matched_pair[1] + '\n')
            temp.write('\n')

    # make sure the output directory exists before writing the json
    os.makedirs('./dataset/stanfordsentiment', exist_ok=True)
    with open('./dataset/stanfordsentiment/ternary_dataset.json', 'w') as out:
        json.dump(filtered_ternary_dataset, out, indent=4)
    print('write json to ./dataset/stanfordsentiment/ternary_dataset.json')
if __name__ == '__main__':
    print('##############################')
    print('start generating stanfordSentimentTreebank ternary sentiment dataset...')
    # expanduser() so the '~' in these paths actually resolves when opened
    SST_dir_path = os.path.expanduser('~/datasets/stanfordsentiment/stanfordSentimentTreebank')
    # the task1 csv path is passed through but not used by get_SST_dataset
    ZuCo_task1_csv_path = os.path.expanduser('~/datasets/ZuCo/task_materials/sentiment_labels_task1.csv')
    with open(os.path.expanduser('~/datasets/ZuCo/task1-SR/sentiment_labels/sentiment_labels.json')) as f:
        ZUCO_SENTIMENT_LABELS = json.load(f)
    get_SST_dataset(SST_dir_path, ZuCo_task1_csv_path, ZUCO_SENTIMENT_LABELS)
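
# Usage (the paths above are machine-specific; point them at your local copies
# of SST and the ZuCo task materials):
#   python util/get_SST_ternary_dataset.py
# Expects datasetSentences.txt, sentiment_labels.txt and dictionary.txt under
# SST_dir_path, and writes ./dataset/stanfordsentiment/ternary_dataset.json.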