--- /dev/null
+++ b/run.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+import os
+import random
+import sys
+
+import numpy as np
+from collections import defaultdict
+
+sys.path.append(os.getcwd())  # make the project root importable
+
+from sklearn.model_selection import train_test_split
+
+from main import train
+from config import DRUG_EXAMPLE, RESULT_LOG, PROCESSED_DATA_DIR, LOG_DIR, MODEL_SAVED_DIR, \
+    ENTITY2ID_FILE, KG_FILE, EXAMPLE_FILE, DRUG_VOCAB_TEMPLATE, ENTITY_VOCAB_TEMPLATE, \
+    RELATION_VOCAB_TEMPLATE, SEPARATOR, ADJ_ENTITY_TEMPLATE, ADJ_RELATION_TEMPLATE, \
+    ModelConfig, NEIGHBOR_SIZE
+from utils import pickle_dump, format_filename, write_log
+
+
+def read_entity2id_file(file_path: str, drug_vocab: dict, entity_vocab: dict):
+    print(f'Logging Info - Reading entity2id file: {file_path}')
+    assert len(drug_vocab) == 0 and len(entity_vocab) == 0
+    with open(file_path, encoding='utf8') as reader:
+        next(reader)  # skip the header line
+        for line in reader:
+            _drug, entity = line.strip().split('\t')
+            # both vocabularies are keyed by the KG entity id of the drug
+            drug_vocab[entity] = len(drug_vocab)
+            entity_vocab[entity] = len(entity_vocab)
+
+
+def read_example_file(file_path: str, separator: str, drug_vocab: dict):
+    print(f'Logging Info - Reading example file: {file_path}')
+    assert len(drug_vocab) > 0
+    examples = []
+    with open(file_path, encoding='utf8') as reader:
+        for line in reader:
+            d1, d2, flag = line.strip().split(separator)[:3]
+            # keep only pairs whose drugs both appear in the entity2id mapping
+            if d1 in drug_vocab and d2 in drug_vocab:
+                examples.append([drug_vocab[d1], drug_vocab[d2], int(flag)])
+
+    examples_matrix = np.array(examples)
+    print(f'size of example: {examples_matrix.shape}')
+    # the train/dev/test split is done per fold in cross_validation()
+    return examples_matrix
+
+
+def read_kg(file_path: str, entity_vocab: dict, relation_vocab: dict, neighbor_sample_size: int):
+    print(f'Logging Info - Reading kg file: {file_path}')
+
+    kg = defaultdict(list)
+    with open(file_path, encoding='utf8') as reader:
+        next(reader)  # skip the header line
+        for line in reader:
+            head, tail, relation = line.strip().split(' ')
+
+            if head not in entity_vocab:
+                entity_vocab[head] = len(entity_vocab)
+            if tail not in entity_vocab:
+                entity_vocab[tail] = len(entity_vocab)
+            if relation not in relation_vocab:
+                relation_vocab[relation] = len(relation_vocab)
+
+            # treat the KG as undirected: add the triple in both directions
+            kg[entity_vocab[head]].append((entity_vocab[tail], relation_vocab[relation]))
+            kg[entity_vocab[tail]].append((entity_vocab[head], relation_vocab[relation]))
+    print(f'Logging Info - num of entities: {len(entity_vocab)}, '
+          f'num of relations: {len(relation_vocab)}')
+
+    print('Logging Info - Constructing adjacency matrix...')
+    n_entity = len(entity_vocab)
+    adj_entity = np.zeros(shape=(n_entity, neighbor_sample_size), dtype=np.int64)
+    adj_relation = np.zeros(shape=(n_entity, neighbor_sample_size), dtype=np.int64)
+
+    for entity_id in range(n_entity):
+        all_neighbors = kg[entity_id]
+        n_neighbor = len(all_neighbors)
+        # sample with replacement only when the entity has fewer neighbors
+        # than neighbor_sample_size
+        sample_indices = np.random.choice(
+            n_neighbor,
+            neighbor_sample_size,
+            replace=n_neighbor < neighbor_sample_size
+        )
+        adj_entity[entity_id] = np.array([all_neighbors[i][0] for i in sample_indices])
+        adj_relation[entity_id] = np.array([all_neighbors[i][1] for i in sample_indices])
+
+    return adj_entity, adj_relation
+
+
+def process_data(dataset: str, neighbor_sample_size: int, K: int):
+    drug_vocab = {}
+    entity_vocab = {}
+    relation_vocab = {}
+
+    read_entity2id_file(ENTITY2ID_FILE[dataset], drug_vocab, entity_vocab)
+
+    examples_file = format_filename(PROCESSED_DATA_DIR, DRUG_EXAMPLE, dataset=dataset)
+    examples = read_example_file(EXAMPLE_FILE[dataset], SEPARATOR[dataset], drug_vocab)
+    np.save(examples_file, examples)
+
+    adj_entity, adj_relation = read_kg(KG_FILE[dataset], entity_vocab, relation_vocab,
+                                       neighbor_sample_size)
+
+    # dump the vocabularies once, after read_kg has finished extending entity_vocab
+    pickle_dump(format_filename(PROCESSED_DATA_DIR, DRUG_VOCAB_TEMPLATE, dataset=dataset),
+                drug_vocab)
+    pickle_dump(format_filename(PROCESSED_DATA_DIR, ENTITY_VOCAB_TEMPLATE, dataset=dataset),
+                entity_vocab)
+    pickle_dump(format_filename(PROCESSED_DATA_DIR, RELATION_VOCAB_TEMPLATE, dataset=dataset),
+                relation_vocab)
+
+    adj_entity_file = format_filename(PROCESSED_DATA_DIR, ADJ_ENTITY_TEMPLATE, dataset=dataset)
+    np.save(adj_entity_file, adj_entity)
+    print('Logging Info - Saved:', adj_entity_file)
+
+    adj_relation_file = format_filename(PROCESSED_DATA_DIR, ADJ_RELATION_TEMPLATE, dataset=dataset)
+    np.save(adj_relation_file, adj_relation)
+    print('Logging Info - Saved:', adj_relation_file)
+
+    cross_validation(K, examples, dataset, neighbor_sample_size)
+
+
+def cross_validation(K_fold, examples, dataset, neighbor_sample_size):
+    # partition the example indices into K_fold disjoint, (nearly) equal subsets
+    subsets = dict()
+    n_subsets = int(len(examples) / K_fold)
+    remain = set(range(0, len(examples)))
+    for i in reversed(range(0, K_fold - 1)):
+        # random.sample needs a sequence (sampling from a set is an error in Python 3.11+)
+        subsets[i] = random.sample(list(remain), n_subsets)
+        remain = remain.difference(subsets[i])
+    subsets[K_fold - 1] = list(remain)
+
+    aggregator_types = ['sum', 'concat', 'neigh']
+    for t in aggregator_types:
+        count = 1
+        temp = {'dataset': dataset, 'aggregator_type': t,
+                'avg_auc': 0.0, 'avg_acc': 0.0, 'avg_f1': 0.0, 'avg_aupr': 0.0}
+        for i in reversed(range(0, K_fold)):
+            # fold i is held out; half of it becomes the dev set, half the test set
+            test_d = examples[list(subsets[i])]
+            val_d, test_data = train_test_split(test_d, test_size=0.5)
+            train_d = []
+            for j in range(0, K_fold):
+                if i != j:
+                    train_d.extend(examples[list(subsets[j])])
+            train_data = np.array(train_d)
+            train_log = train(
+                kfold=count,
+                dataset=dataset,
+                train_d=train_data,
+                dev_d=val_d,
+                test_d=test_data,
+                neighbor_sample_size=neighbor_sample_size,
+                embed_dim=32,
+                n_depth=2,
+                l2_weight=1e-7,
+                lr=2e-2,
+                optimizer_type='adam',
+                batch_size=2048,
+                aggregator_type=t,
+                n_epoch=50,
+                callbacks_to_add=['modelcheckpoint', 'earlystopping']
+            )
+            count += 1
+            temp['avg_auc'] += train_log['test_auc']
+            temp['avg_acc'] += train_log['test_acc']
+            temp['avg_f1'] += train_log['test_f1']
+            temp['avg_aupr'] += train_log['test_aupr']
+        for key in temp:
+            if key == 'aggregator_type' or key == 'dataset':
+                continue
+            temp[key] = temp[key] / K_fold
+        write_log(format_filename(LOG_DIR, RESULT_LOG[dataset]), temp, 'a')
+        print(f'Logging Info - {K_fold} fold result: avg_auc: {temp["avg_auc"]}, '
+              f'avg_acc: {temp["avg_acc"]}, avg_f1: {temp["avg_f1"]}, '
+              f'avg_aupr: {temp["avg_aupr"]}')
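+
+# Hypothetical usage (illustrative only, not part of the pipeline): to re-run just
+# the K-fold loop on a previously cached example matrix, one could load it and call
+# cross_validation() directly, assuming examples_file is the path built in
+# process_data() and that np.save appended a '.npy' suffix:
+#   examples = np.load(examples_file + '.npy')
+#   cross_validation(5, examples, 'kegg', NEIGHBOR_SIZE['kegg'])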
+
+if __name__ == '__main__':
+    os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
+    os.makedirs(LOG_DIR, exist_ok=True)
+    os.makedirs(MODEL_SAVED_DIR, exist_ok=True)
+    model_config = ModelConfig()
+    process_data('kegg', NEIGHBOR_SIZE['kegg'], 5)
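
For reference, a minimal standalone sketch of the fixed-size neighbor sampling that
read_kg performs in the patch above; the toy KG and sizes here are illustrative and
not part of the patch:

    import numpy as np
    from collections import defaultdict

    kg = defaultdict(list)
    kg[0] = [(1, 0), (2, 1)]   # entity 0 has two (neighbor_id, relation_id) pairs
    neighbor_sample_size = 4

    n_neighbor = len(kg[0])
    # sample with replacement only when there are fewer neighbors than slots
    idx = np.random.choice(n_neighbor, neighbor_sample_size,
                           replace=n_neighbor < neighbor_sample_size)
    adj_entity_row = np.array([kg[0][i][0] for i in idx])
    adj_relation_row = np.array([kg[0][i][1] for i in idx])
    print(adj_entity_row)      # e.g. [2 1 1 2]: neighbors oversampled to fill the row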