Diff of /run.py [000000] .. [c0da92]


b/run.py
# -*- coding: utf-8 -*-
import sys
import random
import os
import numpy as np
from collections import defaultdict
sys.path.append(os.getcwd())  # add the project root to the module search path
from sklearn.model_selection import train_test_split, StratifiedKFold
from main import train

from config import DRUG_EXAMPLE, RESULT_LOG, PROCESSED_DATA_DIR, LOG_DIR, MODEL_SAVED_DIR, ENTITY2ID_FILE, KG_FILE, \
    EXAMPLE_FILE, DRUG_VOCAB_TEMPLATE, ENTITY_VOCAB_TEMPLATE, \
    RELATION_VOCAB_TEMPLATE, SEPARATOR, THRESHOLD, TRAIN_DATA_TEMPLATE, DEV_DATA_TEMPLATE, \
    TEST_DATA_TEMPLATE, ADJ_ENTITY_TEMPLATE, ADJ_RELATION_TEMPLATE, ModelConfig, NEIGHBOR_SIZE
from utils import pickle_dump, format_filename, write_log, pickle_load

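# Pipeline overview:
#   read_entity2id_file - build the drug/entity vocabularies from the entity2id file
#   read_example_file   - map DDI example pairs to vocabulary indices
#   read_kg             - read KG triples and sample fixed-size neighbor lists per entity
#   process_data        - persist vocabularies and adjacency arrays, then launch
#                         cross_validation, which trains one model per fold and aggregator type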
def read_entity2id_file(file_path: str, drug_vocab: dict, entity_vocab: dict):
    print(f'Logging Info - Reading entity2id file: {file_path}')
    assert len(drug_vocab) == 0 and len(entity_vocab) == 0
    with open(file_path, encoding='utf8') as reader:
        count = 0
        for line in reader:
            if count == 0:  # skip the header line
                count += 1
                continue
            drug, entity = line.strip().split('\t')
            # both vocabularies are keyed by the KG entity identifier
            drug_vocab[entity] = len(drug_vocab)
            entity_vocab[entity] = len(entity_vocab)

def read_example_file(file_path: str, separator: str, drug_vocab: dict):
    print(f'Logging Info - Reading example file: {file_path}')
    assert len(drug_vocab) > 0
    examples = []
    with open(file_path, encoding='utf8') as reader:
        for idx, line in enumerate(reader):
            d1, d2, flag = line.strip().split(separator)[:3]
            # keep only pairs whose drugs both appear in the vocabulary
            if d1 not in drug_vocab or d2 not in drug_vocab:
                continue
            examples.append([drug_vocab[d1], drug_vocab[d2], int(flag)])

    examples_matrix = np.array(examples)
    print(f'size of example: {examples_matrix.shape}')
    X = examples_matrix[:, :2]
    y = examples_matrix[:, 2:3]
    # stratified 8:1:1 split; these splits are not used downstream, since
    # cross_validation() re-partitions the full example matrix
    train_data_X, valid_data_X, train_y, val_y = train_test_split(X, y, test_size=0.2, stratify=y)
    train_data = np.c_[train_data_X, train_y]
    valid_data_X, test_data_X, val_y, test_y = train_test_split(valid_data_X, val_y, test_size=0.5)
    valid_data = np.c_[valid_data_X, val_y]
    test_data = np.c_[test_data_X, test_y]
    return examples_matrix

def read_kg(file_path: str, entity_vocab: dict, relation_vocab: dict, neighbor_sample_size: int):
    print(f'Logging Info - Reading kg file: {file_path}')

    kg = defaultdict(list)
    with open(file_path, encoding='utf8') as reader:
        count = 0
        for line in reader:
            if count == 0:  # skip the header line
                count += 1
                continue
            head, tail, relation = line.strip().split(' ')

            if head not in entity_vocab:
                entity_vocab[head] = len(entity_vocab)
            if tail not in entity_vocab:
                entity_vocab[tail] = len(entity_vocab)
            if relation not in relation_vocab:
                relation_vocab[relation] = len(relation_vocab)

            # treat the graph as undirected: add both directions of each triple
            kg[entity_vocab[head]].append((entity_vocab[tail], relation_vocab[relation]))
            kg[entity_vocab[tail]].append((entity_vocab[head], relation_vocab[relation]))
    print(f'Logging Info - num of entities: {len(entity_vocab)}, '
          f'num of relations: {len(relation_vocab)}')

    print('Logging Info - Constructing adjacency matrix...')
    n_entity = len(entity_vocab)
    adj_entity = np.zeros(shape=(n_entity, neighbor_sample_size), dtype=np.int64)
    adj_relation = np.zeros(shape=(n_entity, neighbor_sample_size), dtype=np.int64)

    for entity_id in range(n_entity):
        all_neighbors = kg[entity_id]
        n_neighbor = len(all_neighbors)
        # sample a fixed number of neighbors per entity; sample with replacement
        # only when the entity has fewer neighbors than neighbor_sample_size
        sample_indices = np.random.choice(
            n_neighbor,
            neighbor_sample_size,
            replace=n_neighbor < neighbor_sample_size
        )

        adj_entity[entity_id] = np.array([all_neighbors[i][0] for i in sample_indices])
        adj_relation[entity_id] = np.array([all_neighbors[i][1] for i in sample_indices])

    return adj_entity, adj_relation


def process_data(dataset: str, neighbor_sample_size: int, K: int):
    drug_vocab = {}
    entity_vocab = {}
    relation_vocab = {}

    read_entity2id_file(ENTITY2ID_FILE[dataset], drug_vocab, entity_vocab)

    pickle_dump(format_filename(PROCESSED_DATA_DIR, DRUG_VOCAB_TEMPLATE, dataset=dataset), drug_vocab)
    pickle_dump(format_filename(PROCESSED_DATA_DIR, ENTITY_VOCAB_TEMPLATE, dataset=dataset), entity_vocab)

    examples_file = format_filename(PROCESSED_DATA_DIR, DRUG_EXAMPLE, dataset=dataset)
    examples = read_example_file(EXAMPLE_FILE[dataset], SEPARATOR[dataset], drug_vocab)
    np.save(examples_file, examples)

    adj_entity_file = format_filename(PROCESSED_DATA_DIR, ADJ_ENTITY_TEMPLATE, dataset=dataset)
    adj_relation_file = format_filename(PROCESSED_DATA_DIR, ADJ_RELATION_TEMPLATE, dataset=dataset)

    adj_entity, adj_relation = read_kg(KG_FILE[dataset], entity_vocab, relation_vocab,
                                       neighbor_sample_size)

    # dump the vocabularies again: entity_vocab and relation_vocab grow while reading the KG
    pickle_dump(format_filename(PROCESSED_DATA_DIR, DRUG_VOCAB_TEMPLATE, dataset=dataset),
                drug_vocab)
    pickle_dump(format_filename(PROCESSED_DATA_DIR, ENTITY_VOCAB_TEMPLATE, dataset=dataset),
                entity_vocab)
    pickle_dump(format_filename(PROCESSED_DATA_DIR, RELATION_VOCAB_TEMPLATE, dataset=dataset),
                relation_vocab)

    np.save(adj_entity_file, adj_entity)
    print('Logging Info - Saved:', adj_entity_file)

    np.save(adj_relation_file, adj_relation)
    print('Logging Info - Saved:', adj_relation_file)

    cross_validation(K, examples, dataset, neighbor_sample_size)

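# Evaluation protocol: the examples are split into K_fold disjoint subsets; for each
# aggregator type ('sum', 'concat', 'neigh') and each held-out fold, the fold is split
# 50/50 into validation and test sets, the remaining folds form the training set, and
# the test metrics are averaged over the K_fold runs.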
def cross_validation(K_fold, examples, dataset, neighbor_sample_size):
    # partition the example indices into K_fold disjoint subsets
    subsets = dict()
    n_subsets = int(len(examples) / K_fold)
    remain = set(range(0, len(examples)))
    for i in reversed(range(0, K_fold - 1)):
        subsets[i] = random.sample(list(remain), n_subsets)
        remain = remain.difference(subsets[i])
    subsets[K_fold - 1] = remain
    aggregator_types = ['sum', 'concat', 'neigh']
    for t in aggregator_types:
        count = 1
        temp = {'dataset': dataset, 'aggregator_type': t,
                'avg_auc': 0.0, 'avg_acc': 0.0, 'avg_f1': 0.0, 'avg_aupr': 0.0}
        for i in reversed(range(0, K_fold)):
            # fold i is held out: half for validation, half for testing
            test_d = examples[list(subsets[i])]
            val_d, test_data = train_test_split(test_d, test_size=0.5)
            train_d = []
            for j in range(0, K_fold):
                if i != j:
                    train_d.extend(examples[list(subsets[j])])
            train_data = np.array(train_d)
            train_log = train(
                kfold=count,
                dataset=dataset,
                train_d=train_data,
                dev_d=val_d,
                test_d=test_data,
                neighbor_sample_size=neighbor_sample_size,
                embed_dim=32,
                n_depth=2,
                l2_weight=1e-7,
                lr=2e-2,
                # lr=5e-3,
                optimizer_type='adam',
                batch_size=2048,
                aggregator_type=t,
                n_epoch=50,
                callbacks_to_add=['modelcheckpoint', 'earlystopping']
            )
            count += 1
            temp['avg_auc'] += train_log['test_auc']
            temp['avg_acc'] += train_log['test_acc']
            temp['avg_f1'] += train_log['test_f1']
            temp['avg_aupr'] += train_log['test_aupr']
        # average the accumulated test metrics over the K folds
        for key in temp:
            if key == 'aggregator_type' or key == 'dataset':
                continue
            temp[key] = temp[key] / K_fold
        write_log(format_filename(LOG_DIR, RESULT_LOG[dataset]), temp, 'a')
        print(f'Logging Info - {K_fold} fold result: avg_auc: {temp["avg_auc"]}, '
              f'avg_acc: {temp["avg_acc"]}, avg_f1: {temp["avg_f1"]}, avg_aupr: {temp["avg_aupr"]}')

if __name__ == '__main__':
    # create output directories before preprocessing and training
    if not os.path.exists(PROCESSED_DATA_DIR):
        os.makedirs(PROCESSED_DATA_DIR)
    if not os.path.exists(LOG_DIR):
        os.makedirs(LOG_DIR)
    if not os.path.exists(MODEL_SAVED_DIR):
        os.makedirs(MODEL_SAVED_DIR)
    model_config = ModelConfig()
    # preprocess the 'kegg' dataset and run 5-fold cross-validation
    process_data('kegg', NEIGHBOR_SIZE['kegg'], 5)