b/run.py
# -*- coding: utf-8 -*-
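# Data preprocessing and K-fold cross-validation driver for the knowledge-graph
# based drug-drug interaction model defined in main.py and config.py: builds the
# drug/entity/relation vocabularies, the fixed-size KG adjacency tensors, and
# trains the model once per aggregator type and fold.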
|
|
import sys
import random
import os
import numpy as np
from collections import defaultdict

sys.path.append(os.getcwd())  # add the project root so that main, config and utils are importable

from sklearn.model_selection import train_test_split, StratifiedKFold
from main import train

from config import DRUG_EXAMPLE, RESULT_LOG, PROCESSED_DATA_DIR, LOG_DIR, MODEL_SAVED_DIR, ENTITY2ID_FILE, KG_FILE, \
    EXAMPLE_FILE, DRUG_VOCAB_TEMPLATE, ENTITY_VOCAB_TEMPLATE, \
    RELATION_VOCAB_TEMPLATE, SEPARATOR, THRESHOLD, TRAIN_DATA_TEMPLATE, DEV_DATA_TEMPLATE, \
    TEST_DATA_TEMPLATE, ADJ_ENTITY_TEMPLATE, ADJ_RELATION_TEMPLATE, ModelConfig, NEIGHBOR_SIZE
from utils import pickle_dump, format_filename, write_log, pickle_load


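# Map each drug listed in the entity2id file to a vocabulary index. The first
# line is treated as a header and skipped; both drug_vocab and entity_vocab are
# keyed by the KG entity identifier of the drug, so the two stay aligned.
# Illustrative usage:
#   drug_vocab, entity_vocab = {}, {}
#   read_entity2id_file(ENTITY2ID_FILE['kegg'], drug_vocab, entity_vocab)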
def read_entity2id_file(file_path: str, drug_vocab: dict, entity_vocab: dict):
    print(f'Logging Info - Reading entity2id file: {file_path}')
    assert len(drug_vocab) == 0 and len(entity_vocab) == 0
    with open(file_path, encoding='utf8') as reader:
        count = 0
        for line in reader:
            if count == 0:
                count += 1
                continue
            drug, entity = line.strip().split('\t')
            drug_vocab[entity] = len(drug_vocab)
            entity_vocab[entity] = len(entity_vocab)


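# Load the labelled drug-pair examples. Each line holds at least
# (drug1, drug2, label) separated by the dataset-specific separator; pairs with
# a drug missing from drug_vocab are skipped. Returns the full example matrix
# of shape (n_examples, 3) with columns [drug1_index, drug2_index, label].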
def read_example_file(file_path: str, separator: str, drug_vocab: dict):
    print(f'Logging Info - Reading example file: {file_path}')
    assert len(drug_vocab) > 0
    examples = []
    with open(file_path, encoding='utf8') as reader:
        for idx, line in enumerate(reader):
            d1, d2, flag = line.strip().split(separator)[:3]
            if d1 not in drug_vocab or d2 not in drug_vocab:
                continue
            examples.append([drug_vocab[d1], drug_vocab[d2], int(flag)])

    examples_matrix = np.array(examples)
    print(f'size of example: {examples_matrix.shape}')
    X = examples_matrix[:, :2]
    y = examples_matrix[:, 2:3]
    # NOTE: the split below is not used downstream; cross_validation() re-splits
    # the full example matrix that this function returns.
    train_data_X, valid_data_X, train_y, val_y = train_test_split(X, y, test_size=0.2, stratify=y)
    train_data = np.c_[train_data_X, train_y]
    valid_data_X, test_data_X, val_y, test_y = train_test_split(valid_data_X, val_y, test_size=0.5)
    valid_data = np.c_[valid_data_X, val_y]
    test_data = np.c_[test_data_X, test_y]
    return examples_matrix


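# Read the knowledge graph triples (head, tail, relation; space-separated,
# header line skipped), extend entity_vocab and relation_vocab with unseen ids,
# and build a fixed-size neighbor table for every entity. Returns
# (adj_entity, adj_relation), both of shape (n_entity, neighbor_sample_size).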
def read_kg(file_path: str, entity_vocab: dict, relation_vocab: dict, neighbor_sample_size: int):
    print(f'Logging Info - Reading kg file: {file_path}')

    kg = defaultdict(list)
    with open(file_path, encoding='utf8') as reader:
        count = 0
        for line in reader:
            if count == 0:
                count += 1
                continue
            head, tail, relation = line.strip().split(' ')

            if head not in entity_vocab:
                entity_vocab[head] = len(entity_vocab)
            if tail not in entity_vocab:
                entity_vocab[tail] = len(entity_vocab)
            if relation not in relation_vocab:
                relation_vocab[relation] = len(relation_vocab)

            # undirected graph
            kg[entity_vocab[head]].append((entity_vocab[tail], relation_vocab[relation]))
            kg[entity_vocab[tail]].append((entity_vocab[head], relation_vocab[relation]))
    print(f'Logging Info - num of entities: {len(entity_vocab)}, '
          f'num of relations: {len(relation_vocab)}')

    print('Logging Info - Constructing adjacency matrix...')
    n_entity = len(entity_vocab)
    adj_entity = np.zeros(shape=(n_entity, neighbor_sample_size), dtype=np.int64)
    adj_relation = np.zeros(shape=(n_entity, neighbor_sample_size), dtype=np.int64)

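    # Sample a fixed-size neighborhood for every entity: without replacement
    # when the entity has at least neighbor_sample_size neighbors, with
    # replacement otherwise. This assumes every entity id occurs in at least
    # one triple; an entity with an empty neighbor list would make
    # np.random.choice raise a ValueError.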
    for entity_id in range(n_entity):
        all_neighbors = kg[entity_id]
        n_neighbor = len(all_neighbors)
        sample_indices = np.random.choice(
            n_neighbor,
            neighbor_sample_size,
            replace=n_neighbor < neighbor_sample_size
        )

        adj_entity[entity_id] = np.array([all_neighbors[i][0] for i in sample_indices])
        adj_relation[entity_id] = np.array([all_neighbors[i][1] for i in sample_indices])

    return adj_entity, adj_relation


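# End-to-end preprocessing for one dataset: build and pickle the vocabularies,
# save the example matrix and the adjacency tensors under PROCESSED_DATA_DIR,
# then launch K-fold cross-validation (K is the number of folds).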
def process_data(dataset: str, neighbor_sample_size: int, K: int):
    drug_vocab = {}
    entity_vocab = {}
    relation_vocab = {}

    read_entity2id_file(ENTITY2ID_FILE[dataset], drug_vocab, entity_vocab)

    pickle_dump(format_filename(PROCESSED_DATA_DIR, DRUG_VOCAB_TEMPLATE, dataset=dataset), drug_vocab)
    pickle_dump(format_filename(PROCESSED_DATA_DIR, ENTITY_VOCAB_TEMPLATE, dataset=dataset), entity_vocab)

    examples_file = format_filename(PROCESSED_DATA_DIR, DRUG_EXAMPLE, dataset=dataset)
    examples = read_example_file(EXAMPLE_FILE[dataset], SEPARATOR[dataset], drug_vocab)
    np.save(examples_file, examples)

    adj_entity_file = format_filename(PROCESSED_DATA_DIR, ADJ_ENTITY_TEMPLATE, dataset=dataset)
    adj_relation_file = format_filename(PROCESSED_DATA_DIR, ADJ_RELATION_TEMPLATE, dataset=dataset)

    adj_entity, adj_relation = read_kg(KG_FILE[dataset], entity_vocab, relation_vocab,
                                       neighbor_sample_size)

    # save the final vocabularies: read_kg() has extended entity_vocab and filled relation_vocab
    pickle_dump(format_filename(PROCESSED_DATA_DIR, DRUG_VOCAB_TEMPLATE, dataset=dataset),
                drug_vocab)
    pickle_dump(format_filename(PROCESSED_DATA_DIR, ENTITY_VOCAB_TEMPLATE, dataset=dataset),
                entity_vocab)
    pickle_dump(format_filename(PROCESSED_DATA_DIR, RELATION_VOCAB_TEMPLATE, dataset=dataset),
                relation_vocab)

    np.save(adj_entity_file, adj_entity)
    print('Logging Info - Saved:', adj_entity_file)

    np.save(adj_relation_file, adj_relation)
    print('Logging Info - Saved:', adj_relation_file)

    cross_validation(K, examples, dataset, neighbor_sample_size)


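# K-fold cross-validation: example indices are partitioned into K_fold disjoint
# subsets; in each round one subset is held out and split 50/50 into validation
# and test data, while the remaining folds form the training data. Test metrics
# are averaged over the folds for each aggregator and appended to the result log.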
def cross_validation(K_fold, examples, dataset, neighbor_sample_size):
    subsets = dict()
    n_subsets = int(len(examples) / K_fold)
    remain = set(range(0, len(examples)))
    for i in reversed(range(0, K_fold - 1)):
        subsets[i] = random.sample(sorted(remain), n_subsets)
        remain = remain.difference(subsets[i])
    subsets[K_fold - 1] = remain
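    # Train and evaluate each neighborhood aggregator supported by the model.
    # The hyperparameters passed to train() below are the values hard-coded in
    # this script; the commented-out lr=5e-3 is an alternative learning rate.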
    aggregator_types = ['sum', 'concat', 'neigh']
    for t in aggregator_types:
        count = 1
        temp = {'dataset': dataset, 'aggregator_type': t,
                'avg_auc': 0.0, 'avg_acc': 0.0, 'avg_f1': 0.0, 'avg_aupr': 0.0}
        for i in reversed(range(0, K_fold)):
            test_d = examples[list(subsets[i])]
            val_d, test_data = train_test_split(test_d, test_size=0.5)
            train_d = []
            for j in range(0, K_fold):
                if i != j:
                    train_d.extend(examples[list(subsets[j])])
            train_data = np.array(train_d)
            train_log = train(
                kfold=count,
                dataset=dataset,
                train_d=train_data,
                dev_d=val_d,
                test_d=test_data,
                neighbor_sample_size=neighbor_sample_size,
                embed_dim=32,
                n_depth=2,
                l2_weight=1e-7,
                lr=2e-2,
                # lr=5e-3,
                optimizer_type='adam',
                batch_size=2048,
                aggregator_type=t,
                n_epoch=50,
                callbacks_to_add=['modelcheckpoint', 'earlystopping']
            )
            count += 1
            temp['avg_auc'] += train_log['test_auc']
            temp['avg_acc'] += train_log['test_acc']
            temp['avg_f1'] += train_log['test_f1']
            temp['avg_aupr'] += train_log['test_aupr']
        for key in temp:
            if key == 'aggregator_type' or key == 'dataset':
                continue
            temp[key] = temp[key] / K_fold
        write_log(format_filename(LOG_DIR, RESULT_LOG[dataset]), temp, 'a')
        print(f'Logging Info - {K_fold} fold result: avg_auc: {temp["avg_auc"]}, '
              f'avg_acc: {temp["avg_acc"]}, avg_f1: {temp["avg_f1"]}, avg_aupr: {temp["avg_aupr"]}')


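# Create the output directories, then run preprocessing and 5-fold
# cross-validation on the 'kegg' dataset with its configured neighbor size.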
if __name__ == '__main__':
    if not os.path.exists(PROCESSED_DATA_DIR):
        os.makedirs(PROCESSED_DATA_DIR)
    if not os.path.exists(LOG_DIR):
        os.makedirs(LOG_DIR)
    if not os.path.exists(MODEL_SAVED_DIR):
        os.makedirs(MODEL_SAVED_DIR)
    model_config = ModelConfig()
    process_data('kegg', NEIGHBOR_SIZE['kegg'], 5)