[d69072]: / src / utils / save_predictions.py

Download this file

133 lines (109 with data), 5.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from config import DEFAULT_OUTPUT_DIR
# Lookup from a predicted label name to the concept type written in the
# t="..." field of an output line. "test" and "treatment" map to themselves;
# every other label collapses to "problem".
ast_to_concept = {label: label for label in ("test", "treatment")}
ast_to_concept.update(
    dict.fromkeys(
        (
            "present",
            "absent",
            "possible",
            "conditional",
            "hypothetical",
            "associated_with_someone_else",
        ),
        "problem",
    )
)
# Decodes a predicted token class id into its label name. Id 0 is 'o'; after
# that every label owns two consecutive ids (1-2, 3-4, ..., 15-16).
# NOTE(review): the paired ids look like begin/inside tag variants - confirm
# against the label map used at training time.
token_id_to_label = {0: 'o'}
token_id_to_label.update(
    (2 * position + offset, label)
    for position, label in enumerate(
        (
            'test',
            'treatment',
            'present',
            'absent',
            'possible',
            'conditional',
            'hypothetical',
            'associated_with_someone_else',
        )
    )
    for offset in (1, 2)
)
def _group_label_runs(pred, token_id_to_label):
    """Group consecutive token ids that decode to the same label.

    Returns a list of ``[start, end, token]`` entries, where ``start`` and
    ``end`` are half-open indices into ``pred`` and ``token`` is the first
    id of the run. An empty ``pred`` yields an empty list.
    """
    runs = []
    for position, token in enumerate(pred):
        if runs and token_id_to_label[runs[-1][2]] == token_id_to_label[token]:
            # Same label as the open run: extend it so it includes this token.
            runs[-1][1] = position + 1
        else:
            runs.append([position, position + 1, token])
    return runs


def save_predictions(test_data, predictions, ast_to_concept=ast_to_concept, token_id_to_label=token_id_to_label, output_dir=DEFAULT_OUTPUT_DIR):
    """Write predicted spans to per-document ``.con`` and ``.ast`` files.

    For every sample, consecutive predicted ids that share a label are merged
    into one span. Each span with a non-zero id produces a line in
    ``<output_dir>/con/<filename>.con``::

        c="<span text>" <row>:<first word> <row>:<last word>||t="<concept>"

    Spans whose concept is ``"problem"`` additionally produce the same line
    with ``||a="<label>"`` appended in ``<output_dir>/ast/<filename>.ast``.
    Word positions are the number of space characters in ``a["text"]`` that
    precede the span's first and last character offsets.

    Both files of a document are created empty the first time its filename is
    encountered in this call, so earlier content is discarded and a document
    without predictions ends up with empty files. Later samples carrying the
    same filename append to them.

    Args:
        test_data: Iterable of mappings providing the keys ``'filename'``,
            ``'input_ids'``, ``'offset_mapping'``, ``'text'`` and ``'row'``.
        predictions: Indexable collection; ``predictions[i]`` holds the
            predicted ids for ``test_data[i]``.
        ast_to_concept: Mapping from label name to concept type.
        token_id_to_label: Mapping from predicted id to label name.
        output_dir: Directory under which ``con/`` and ``ast/`` are created.

    Returns:
        None.
    """
    con_output_dir = os.path.join(output_dir, "con")
    ast_output_dir = os.path.join(output_dir, "ast")
    os.makedirs(con_output_dir, exist_ok=True)
    os.makedirs(ast_output_dir, exist_ok=True)

    started_files = set()  # filenames whose output files were already reset
    for i, a in enumerate(test_data):
        con_path = os.path.join(con_output_dir, a['filename'] + ".con")
        ast_path = os.path.join(ast_output_dir, a['filename'] + ".ast")

        if a['filename'] not in started_files:
            started_files.add(a['filename'])
            # Reset both files up front so results from an earlier run never
            # survive, even when this document yields no lines at all.
            for path in (con_path, ast_path):
                with open(path, 'w'):
                    pass

        # A sample whose predicted ids sum to zero has nothing to write.
        if sum(predictions[i]) == 0:
            continue

        # Keep only the first len(input_ids) - 1 predictions (the original
        # code describes this as removing the padding zeros).
        pred = predictions[i][0:len(a['input_ids']) - 1]
        text = a["text"]
        row = a["row"]

        con_lines = []
        ast_lines = []
        for start, end, token in _group_label_runs(pred, token_id_to_label):
            if token == 0:
                continue

            # Character span: start of the first token to end of the last one.
            offsets = a['offset_mapping'][start:end]
            char_start = offsets[0][0]
            char_end = offsets[-1][1]
            word = text[char_start:char_end]

            # Word positions are derived by counting the spaces before each offset.
            word_start = text[:char_start].count(" ")
            word_end = text[:char_end].count(" ")

            label = token_id_to_label[token]
            concept = ast_to_concept[label]
            con_line = 'c="{}" {}:{} {}:{}||t="{}"'.format(
                word, row, word_start, row, word_end, concept
            )
            con_lines.append(con_line + "\n")
            if concept == "problem":
                ast_lines.append('{}||a="{}"\n'.format(con_line, label))

        if con_lines:
            with open(con_path, 'a') as f:
                f.writelines(con_lines)
        if ast_lines:
            with open(ast_path, 'a') as f:
                f.writelines(ast_lines)