Diff of /run.py [000000] .. [c0da92]

Switch to side-by-side view

--- a
+++ b/run.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+import sys
+import random
+import os
+import numpy as np
+from collections import defaultdict
+sys.path.append(os.getcwd()) #add the env path
+from sklearn.model_selection import train_test_split,StratifiedKFold
+from main import train
+
+from config import DRUG_EXAMPLE, RESULT_LOG, PROCESSED_DATA_DIR, LOG_DIR, MODEL_SAVED_DIR, ENTITY2ID_FILE, KG_FILE, \
+    EXAMPLE_FILE,  DRUG_VOCAB_TEMPLATE, ENTITY_VOCAB_TEMPLATE, \
+    RELATION_VOCAB_TEMPLATE, SEPARATOR, THRESHOLD, TRAIN_DATA_TEMPLATE, DEV_DATA_TEMPLATE, \
+    TEST_DATA_TEMPLATE, ADJ_ENTITY_TEMPLATE, ADJ_RELATION_TEMPLATE, ModelConfig, NEIGHBOR_SIZE
+from utils import pickle_dump, format_filename,write_log,pickle_load
+
def read_entity2id_file(file_path: str, drug_vocab: dict, entity_vocab: dict):
    """Fill ``drug_vocab`` and ``entity_vocab`` from an entity2id file.

    The first line of the file is treated as a header and skipped. Every
    following non-blank line must contain exactly two tab-separated fields,
    ``drug`` and ``entity``. Both vocabularies are keyed by the *entity*
    field (the ``drug`` field is read but not stored) and map it to a
    running index in order of first appearance.

    Args:
        file_path: Path of the UTF-8 encoded entity2id file.
        drug_vocab: Empty dict, filled in place.
        entity_vocab: Empty dict, filled in place.

    Raises:
        AssertionError: If either vocabulary is not empty on entry.
        ValueError: If a non-blank line does not have exactly two
            tab-separated fields.
    """
    print(f'Logging Info - Reading entity2id file: {file_path}' )
    assert len(drug_vocab) == 0 and len(entity_vocab) == 0
    with open(file_path, encoding='utf8') as reader:
        next(reader, None)  # skip the header line (no-op on an empty file)
        for raw_line in reader:
            line = raw_line.strip()
            if not line:
                # Tolerate blank / trailing empty lines instead of crashing
                # on the two-field unpacking below.
                continue
            _drug, entity = line.split('\t')
            # Only the first occurrence gets an index, so ids stay
            # contiguous (0..len-1) even if an entity is listed twice.
            drug_vocab.setdefault(entity, len(drug_vocab))
            entity_vocab.setdefault(entity, len(entity_vocab))
+
def read_example_file(file_path:str,separator:str,drug_vocab:dict):
    """Read drug-pair examples and encode them with ``drug_vocab``.

    Each non-blank line is split on ``separator``; its first three fields
    are taken as ``drug1``, ``drug2`` and an integer label. Pairs in which
    either drug is missing from ``drug_vocab`` are skipped.

    Args:
        file_path: Path of the UTF-8 encoded example file.
        separator: Field separator used in the file.
        drug_vocab: Non-empty mapping from drug key to integer id.

    Returns:
        Integer ``numpy`` array of shape ``(n_examples, 3)`` whose columns
        are ``[drug1_id, drug2_id, label]``. The shape is ``(0, 3)`` when no
        line matches the vocabulary.

    Raises:
        AssertionError: If ``drug_vocab`` is empty.
        ValueError: If a non-blank line has fewer than three fields or a
            label that is not an integer.
    """
    print(f'Logging Info - Reading example file: {file_path}')
    assert len(drug_vocab)>0
    examples = []
    with open(file_path, encoding='utf8') as reader:
        for raw_line in reader:
            line = raw_line.strip()
            if not line:
                continue  # ignore blank / trailing empty lines
            d1, d2, flag = line.split(separator)[:3]
            if d1 in drug_vocab and d2 in drug_vocab:
                examples.append([drug_vocab[d1], drug_vocab[d2], int(flag)])

    # reshape keeps the result two-dimensional even when nothing matched,
    # so callers can always index columns.
    examples_matrix = np.array(examples, dtype=int).reshape(-1, 3)
    print(f'size of example: {examples_matrix.shape}')
    # NOTE: the train/valid/test split that used to be computed here was
    # never returned or stored; splitting is done by the caller.
    return examples_matrix
+
def read_kg(file_path: str, entity_vocab: dict, relation_vocab: dict, neighbor_sample_size: int):
    """Read a knowledge-graph file and build fixed-width neighbour tables.

    The first line of the file is skipped. Every following non-blank line
    must contain exactly three fields separated by a single space:
    ``head tail relation``. Unknown heads, tails and relations are appended
    to ``entity_vocab`` / ``relation_vocab`` (both are modified in place).
    Each triple is stored in both directions, i.e. the graph is treated as
    undirected.

    For every entity id in ``entity_vocab``, ``neighbor_sample_size``
    neighbours are drawn with ``np.random.choice`` (without replacement when
    the entity has enough neighbours, with replacement otherwise).

    Entities that have no neighbour at all (for example entries that were in
    ``entity_vocab`` before the call but never occur in the file) cannot be
    sampled; their row in ``adj_entity`` is filled with the entity's own id
    (a self-loop) and their row in ``adj_relation`` is left at 0.
    NOTE(review): relation id 0 is only a placeholder for such rows —
    confirm this is acceptable for the downstream model.

    Args:
        file_path: Path of the UTF-8 encoded knowledge-graph file.
        entity_vocab: Mapping entity -> id, extended in place.
        relation_vocab: Mapping relation -> id, extended in place.
        neighbor_sample_size: Number of neighbours stored per entity.

    Returns:
        Tuple ``(adj_entity, adj_relation)`` of ``int64`` arrays, both of
        shape ``(len(entity_vocab), neighbor_sample_size)``.

    Raises:
        ValueError: If a non-blank line does not have exactly three
            space-separated fields.
    """
    print(f'Logging Info - Reading kg file: {file_path}')

    kg = defaultdict(list)
    with open(file_path, encoding='utf8') as reader:
        next(reader, None)  # skip the first line (no-op on an empty file)
        for raw_line in reader:
            line = raw_line.strip()
            if not line:
                continue  # ignore blank / trailing empty lines
            head, tail, relation = line.split(' ')

            if head not in entity_vocab:
                entity_vocab[head] = len(entity_vocab)
            if tail not in entity_vocab:
                entity_vocab[tail] = len(entity_vocab)
            if relation not in relation_vocab:
                relation_vocab[relation] = len(relation_vocab)

            head_id = entity_vocab[head]
            tail_id = entity_vocab[tail]
            relation_id = relation_vocab[relation]

            # undirected graph
            kg[head_id].append((tail_id, relation_id))
            kg[tail_id].append((head_id, relation_id))
    print(f'Logging Info - num of entities: {len(entity_vocab)}, '
          f'num of relations: {len(relation_vocab)}')

    print('Logging Info - Constructing adjacency matrix...')
    n_entity = len(entity_vocab)
    adj_entity = np.zeros(shape=(n_entity, neighbor_sample_size), dtype=np.int64)
    adj_relation = np.zeros(shape=(n_entity, neighbor_sample_size), dtype=np.int64)

    n_isolated = 0
    for entity_id in range(n_entity):
        all_neighbors = kg.get(entity_id)
        if not all_neighbors:
            # np.random.choice(0, k) raises for k > 0, so isolated entities
            # get a self-loop row instead of being sampled.
            adj_entity[entity_id] = entity_id
            n_isolated += 1
            continue

        n_neighbor = len(all_neighbors)
        sample_indices = np.random.choice(
            n_neighbor,
            neighbor_sample_size,
            replace=n_neighbor < neighbor_sample_size
        )

        adj_entity[entity_id] = [all_neighbors[i][0] for i in sample_indices]
        adj_relation[entity_id] = [all_neighbors[i][1] for i in sample_indices]

    if n_isolated:
        print(f'Logging Info - {n_isolated} entities have no neighbor in the kg; '
              f'using self-loops for them')

    return adj_entity, adj_relation
+
+
def process_data(dataset: str, neighbor_sample_size: int,K:int):
    """Pre-process one dataset, persist the artefacts and run K-fold training.

    Steps, in order:
      1. Build the drug and entity vocabularies from ``ENTITY2ID_FILE[dataset]``.
      2. Read the examples from ``EXAMPLE_FILE[dataset]`` and save them with
         ``np.save``.
      3. Read the knowledge graph from ``KG_FILE[dataset]`` (this extends the
         entity vocabulary and fills the relation vocabulary) and sample the
         neighbour tables.
      4. Pickle the three vocabularies and save both neighbour tables under
         ``PROCESSED_DATA_DIR``.
      5. Call ``cross_validation`` on the examples.

    Args:
        dataset: Key used to look up the per-dataset entries of the config
            tables and to format the output file names.
        neighbor_sample_size: Number of neighbours sampled per entity.
        K: Number of cross-validation folds.
    """
    drug_vocab = {}
    entity_vocab = {}
    relation_vocab = {}

    read_entity2id_file(ENTITY2ID_FILE[dataset], drug_vocab, entity_vocab)

    examples_file = format_filename(PROCESSED_DATA_DIR, DRUG_EXAMPLE, dataset=dataset)
    examples = read_example_file(EXAMPLE_FILE[dataset], SEPARATOR[dataset], drug_vocab)
    np.save(examples_file, examples)

    adj_entity, adj_relation = read_kg(KG_FILE[dataset], entity_vocab, relation_vocab,
                                       neighbor_sample_size)

    # The vocabularies are written once, after read_kg, because read_kg still
    # adds entries to entity_vocab; an earlier dump would be incomplete.
    pickle_dump(format_filename(PROCESSED_DATA_DIR, DRUG_VOCAB_TEMPLATE, dataset=dataset),
                drug_vocab)
    pickle_dump(format_filename(PROCESSED_DATA_DIR, ENTITY_VOCAB_TEMPLATE, dataset=dataset),
                entity_vocab)
    pickle_dump(format_filename(PROCESSED_DATA_DIR, RELATION_VOCAB_TEMPLATE, dataset=dataset),
                relation_vocab)

    adj_entity_file = format_filename(PROCESSED_DATA_DIR, ADJ_ENTITY_TEMPLATE, dataset=dataset)
    np.save(adj_entity_file, adj_entity)
    print('Logging Info - Saved:', adj_entity_file)

    adj_relation_file = format_filename(PROCESSED_DATA_DIR, ADJ_RELATION_TEMPLATE, dataset=dataset)
    np.save(adj_relation_file, adj_relation)
    # Report the relation file here (the entity file name was logged twice before).
    print('Logging Info - Saved:', adj_relation_file)

    cross_validation(K, examples, dataset, neighbor_sample_size)
+
+
def _split_into_folds(n_examples, k_fold):
    """Randomly partition ``range(n_examples)`` into ``k_fold`` index lists.

    The first ``k_fold - 1`` folds hold ``n_examples // k_fold`` indices
    each; the last fold receives all remaining indices.
    """
    indices = list(range(n_examples))
    random.shuffle(indices)
    fold_size = n_examples // k_fold
    folds = [indices[i * fold_size:(i + 1) * fold_size] for i in range(k_fold - 1)]
    folds.append(indices[(k_fold - 1) * fold_size:])
    return folds


def cross_validation(K_fold,examples,dataset,neighbor_sample_size):
    """Run K-fold training for every aggregator type and log averaged metrics.

    The examples are randomly partitioned into ``K_fold`` folds once; the
    same partition is reused for all aggregator types (``'sum'``,
    ``'concat'``, ``'neigh'``). For each fold, the held-out part is split in
    half into a validation and a test set, the remaining folds form the
    training set, and ``train`` is called. The ``test_auc``, ``test_acc``,
    ``test_f1`` and ``test_aupr`` entries of the returned log are averaged
    over the folds, appended to the result log and printed.

    Args:
        K_fold: Number of folds; must be at least 2.
        examples: 2-D ``numpy`` array of examples, one row per example.
        dataset: Dataset key, forwarded to ``train`` and used to pick the
            result log file.
        neighbor_sample_size: Forwarded to ``train``.

    Raises:
        ValueError: If ``K_fold`` is smaller than 2.
    """
    if K_fold < 2:
        raise ValueError(f'K_fold must be at least 2, got {K_fold}')

    # Every example index (including the last one) is assigned to a fold.
    # A shuffled list is sliced instead of calling random.sample on a set,
    # which is rejected by Python 3.11+.
    folds = _split_into_folds(len(examples), K_fold)

    metrics = ('auc', 'acc', 'f1', 'aupr')
    aggregator_types=['sum','concat','neigh']
    for t in aggregator_types:
        temp={'dataset':dataset,'aggregator_type':t,'avg_auc':0.0,'avg_acc':0.0,'avg_f1':0.0,'avg_aupr':0.0}
        # count is the 1-based run number passed to train as kfold.
        for count, i in enumerate(reversed(range(K_fold)), start=1):
            held_out = examples[folds[i]]
            val_d, test_data = train_test_split(held_out, test_size=0.5)
            train_data = np.concatenate(
                [examples[folds[j]] for j in range(K_fold) if j != i]
            )
            train_log = train(
                kfold=count,
                dataset=dataset,
                train_d=train_data,
                dev_d=val_d,
                test_d=test_data,
                neighbor_sample_size=neighbor_sample_size,
                embed_dim=32,
                n_depth=2,
                l2_weight=1e-7,
                lr=2e-2,
                optimizer_type='adam',
                batch_size=2048,
                aggregator_type=t,
                n_epoch=50,
                callbacks_to_add=['modelcheckpoint', 'earlystopping']
            )
            for metric in metrics:
                temp[f'avg_{metric}'] += train_log[f'test_{metric}']

        for metric in metrics:
            temp[f'avg_{metric}'] /= K_fold
        write_log(format_filename(LOG_DIR, RESULT_LOG[dataset]),temp,'a')
        print(f'Logging Info - {K_fold} fold result: avg_auc: {temp["avg_auc"]}, avg_acc: {temp["avg_acc"]}, avg_f1: {temp["avg_f1"]}, avg_aupr: {temp["avg_aupr"]}')
+   
if __name__ == '__main__':
    # Create the output directories (processed data, logs, saved models)
    # if they are not present yet.
    for directory in (PROCESSED_DATA_DIR, LOG_DIR, MODEL_SAVED_DIR):
        if not os.path.exists(directory):
            os.makedirs(directory)
    # NOTE(review): model_config is not used below; it is kept in case the
    # ModelConfig constructor has side effects — confirm before removing.
    model_config = ModelConfig()
    # Pre-process the 'kegg' dataset and run 5-fold cross-validation.
    process_data('kegg',NEIGHBOR_SIZE['kegg'],5)
+
+
+