[c3444c]: / test / diseasedb / generate_data.py

Download this file

45 lines (37 with data), 1.5 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from sklearn.model_selection import train_test_split
def load_all_relation(path="./data/relation_all.txt", source="diseasedb"):
    """Load ``(cui1, cui2, relation)`` triples for one source from a relation file.

    The file is read as UTF-8 text. Its first line is treated as a header and
    skipped. Every following non-blank line must contain exactly four
    ``|``-separated fields, in the order ``cui1|rel|cui2|source``. Only rows
    whose fourth field equals *source* are kept. Summary statistics are printed
    afterwards: the number of triples, the number of distinct relations, and
    the relation set itself.

    Args:
        path: Path of the ``|``-delimited relation file. Defaults to the
            location the script has always used.
        source: Value of the fourth column that a row must have to be kept.

    Returns:
        A tuple ``(cui1_list, cui2_list, rel_list)`` of parallel lists, in
        file order. Element *i* of each list belongs to the same triple.

    Raises:
        ValueError: If a non-blank data line does not have exactly four fields.
    """
    cui1_list = []
    rel_list = []
    cui2_list = []
    with open(path, "r", encoding="utf-8") as f:
        next(f, None)  # skip the header line
        for line_no, line in enumerate(f, start=2):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF), which
                # would otherwise break the 4-way unpack below.
                continue
            fields = line.split("|")
            if len(fields) != 4:
                raise ValueError(
                    f"{path}:{line_no}: expected 4 '|'-separated fields, "
                    f"got {len(fields)}: {line!r}"
                )
            cui1, rel, cui2, row_source = fields
            if row_source == source:
                cui1_list.append(cui1)
                cui2_list.append(cui2)
                rel_list.append(rel)
    relations = set(rel_list)
    print("Tri group count:", len(cui1_list))
    print("Relation count:", len(relations))
    print(relations)
    return cui1_list, cui2_list, rel_list
def _write_lines(path, rows):
    """Write each string in *rows* to *path* (UTF-8, overwriting), one per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(row + "\n" for row in rows)


def split_and_save():
    """Split the loaded triples into stratified train/test sets and save them.

    Triples come from ``load_all_relation()``. Each ``(cui1, cui2)`` pair is a
    sample and its relation is the label. The split is 80/20, stratified on
    the relation, with a fixed ``random_state`` so it is reproducible. Four
    files are written under ``./data/``:

    * ``x_train.txt`` / ``x_test.txt``: one ``cui1<TAB>cui2`` pair per line
    * ``y_train.txt`` / ``y_test.txt``: the matching relation per line

    Line *i* of an ``x_*.txt`` file corresponds to line *i* of the matching
    ``y_*.txt`` file.
    """
    cui1_list, cui2_list, rel_list = load_all_relation()
    pairs = [[cui1, cui2] for cui1, cui2 in zip(cui1_list, cui2_list)]
    # stratify=rel_list keeps relation proportions equal across both splits;
    # scikit-learn rejects stratification when any relation occurs only once.
    x_train, x_test, y_train, y_test = train_test_split(
        pairs, rel_list, test_size=0.2, random_state=72, stratify=rel_list)
    _write_lines("./data/x_train.txt", ("\t".join(pair) for pair in x_train))
    _write_lines("./data/x_test.txt", ("\t".join(pair) for pair in x_test))
    _write_lines("./data/y_train.txt", y_train)
    _write_lines("./data/y_test.txt", y_test)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Print diseasedb relation statistics, or regenerate the "
                    "train/test split files under ./data/.")
    parser.add_argument(
        "--split", action="store_true",
        help="run split_and_save() instead of only printing statistics")
    args = parser.parse_args()
    if args.split:
        split_and_save()
    else:
        # Default (no flags) keeps the previous behavior: statistics only.
        load_all_relation()