--- a +++ b/test/diseasedb/generate_data.py @@ -0,0 +1,44 @@ +from sklearn.model_selection import train_test_split + + +def load_all_relation(): + with open("./data/relation_all.txt", "r", encoding="utf-8") as f: + lines = f.readlines()[1:] + cui1_list = [] + rel_list = [] + cui2_list = [] + for line in lines: + cui1, rel, cui2, source = line.strip().split("|") + if source == "diseasedb": + cui1_list.append(cui1) + cui2_list.append(cui2) + rel_list.append(rel) + + print("Tri group count:", len(cui1_list)) + print("Relation count:", len(set(rel_list))) + print(set(rel_list)) + + return cui1_list, cui2_list, rel_list + + +def split_and_save(): + cui1_list, cui2_list, rel_list = load_all_relation() + x = [[cui1_list[i], cui2_list[i]] for i in range(len(cui1_list))] + x_train, x_test, y_train, y_test = train_test_split( + x, rel_list, test_size=0.2, random_state=72, stratify=rel_list) + with open("./data/x_train.txt", "w", encoding="utf-8") as f: + for x in x_train: + f.write(x[0] + "\t" + x[1] + "\n") + with open("./data/x_test.txt", "w", encoding="utf-8") as f: + for x in x_test: + f.write(x[0] + "\t" + x[1] + "\n") + with open("./data/y_train.txt", "w", encoding="utf-8") as f: + for y in y_train: + f.write(y + "\n") + with open("./data/y_test.txt", "w", encoding="utf-8") as f: + for y in y_test: + f.write(y + "\n") + +if __name__ == "__main__": + #split_and_save() + load_all_relation()