|
a |
|
b/config.py |
|
|
1 |
''' |
|
|
2 |
@Author: your name |
|
|
3 |
@Date: 2019-12-20 19:02:25 |
|
|
4 |
@LastEditTime: 2020-05-26 20:58:12 |
|
|
5 |
@LastEditors: Please set LastEditors |
|
|
6 |
@Description: In User Settings Edit |
|
|
7 |
@FilePath: /matengfei/KGCN_Keras-master/config.py |
|
|
8 |
''' |
|
|
9 |
# -*- coding: utf-8 -*- |
|
|
10 |
|
|
|
11 |
import os |
|
|
12 |
|
|
|
13 |
RAW_DATA_DIR = os.getcwd()+'/raw_data' |
|
|
14 |
PROCESSED_DATA_DIR = os.getcwd()+'/data' |
|
|
15 |
LOG_DIR = os.getcwd()+'/log' |
|
|
16 |
MODEL_SAVED_DIR = os.getcwd()+'/ckpt' |
|
|
17 |
|
|
|
18 |
KG_FILE = { |
|
|
19 |
'drugbank':os.path.join(RAW_DATA_DIR,'drugbank','train2id.txt'), |
|
|
20 |
'kegg':os.path.join(RAW_DATA_DIR,'kegg','train2id.txt')} |
|
|
21 |
ENTITY2ID_FILE = { |
|
|
22 |
'drugbank':os.path.join(RAW_DATA_DIR,'drugbank','entity2id.txt'), |
|
|
23 |
'kegg':os.path.join(RAW_DATA_DIR,'kegg','entity2id.txt')} |
|
|
24 |
EXAMPLE_FILE = { |
|
|
25 |
'drugbank':os.path.join(RAW_DATA_DIR,'drugbank','approved_example.txt'), |
|
|
26 |
'kegg':os.path.join(RAW_DATA_DIR,'kegg','approved_example.txt')} |
|
|
27 |
SEPARATOR = {'drug':'\t','kegg':'\t'} |
|
|
28 |
THRESHOLD = {'drug':4,'kegg':4} #添加drug修改 |
|
|
29 |
NEIGHBOR_SIZE = {'drug':4,'kegg':4} |
|
|
30 |
|
|
|
31 |
# |
|
|
32 |
DRUG_VOCAB_TEMPLATE = '{dataset}_drug_vocab.pkl' |
|
|
33 |
ENTITY_VOCAB_TEMPLATE = '{dataset}_entity_vocab.pkl' |
|
|
34 |
RELATION_VOCAB_TEMPLATE = '{dataset}_relation_vocab.pkl' |
|
|
35 |
ADJ_ENTITY_TEMPLATE = '{dataset}_adj_entity.npy' |
|
|
36 |
ADJ_RELATION_TEMPLATE = '{dataset}_adj_relation.npy' |
|
|
37 |
TRAIN_DATA_TEMPLATE = '{dataset}_train.npy' |
|
|
38 |
DEV_DATA_TEMPLATE = '{dataset}_dev.npy' |
|
|
39 |
TEST_DATA_TEMPLATE = '{dataset}_test.npy' |
|
|
40 |
#RESULT_LOG='result.txt' |
|
|
41 |
RESULT_LOG={'drugbank':'drugbank_result.txt','kegg':'kegg_result.txt'} |
|
|
42 |
PERFORMANCE_LOG = 'kgcn_performance.log' |
|
|
43 |
DRUG_EXAMPLE='{dataset}_examples.npy' |
|
|
44 |
|
|
|
45 |
class ModelConfig(object): |
|
|
46 |
def __init__(self): |
|
|
47 |
self.neighbor_sample_size = 4 # neighbor sampling size |
|
|
48 |
self.embed_dim = 32 # dimension of embedding |
|
|
49 |
self.n_depth = 2 # depth of receptive field |
|
|
50 |
self.l2_weight = 1e-7 # l2 regularizer weight |
|
|
51 |
self.lr = 2e-2 # learning rate |
|
|
52 |
self.batch_size = 65536 |
|
|
53 |
self.aggregator_type = 'sum' |
|
|
54 |
self.n_epoch = 50 |
|
|
55 |
self.optimizer = 'adam' |
|
|
56 |
|
|
|
57 |
self.drug_vocab_size = None |
|
|
58 |
self.entity_vocab_size = None |
|
|
59 |
self.relation_vocab_size = None |
|
|
60 |
self.adj_entity = None |
|
|
61 |
self.adj_relation = None |
|
|
62 |
|
|
|
63 |
self.exp_name = None |
|
|
64 |
self.model_name = None |
|
|
65 |
|
|
|
66 |
# checkpoint configuration 设置检查点 |
|
|
67 |
self.checkpoint_dir = MODEL_SAVED_DIR |
|
|
68 |
self.checkpoint_monitor = 'val_auc' |
|
|
69 |
self.checkpoint_save_best_only = True |
|
|
70 |
self.checkpoint_save_weights_only = True |
|
|
71 |
self.checkpoint_save_weights_mode = 'max' |
|
|
72 |
self.checkpoint_verbose = 1 |
|
|
73 |
|
|
|
74 |
# early_stoping configuration |
|
|
75 |
self.early_stopping_monitor = 'val_auc' |
|
|
76 |
self.early_stopping_mode = 'max' |
|
|
77 |
self.early_stopping_patience = 5 |
|
|
78 |
self.early_stopping_verbose = 1 |
|
|
79 |
self.dataset='drug' |
|
|
80 |
self.K_Fold=1 |
|
|
81 |
self.callbacks_to_add = None |
|
|
82 |
|
|
|
83 |
# config for learning rating scheduler and ensembler |
|
|
84 |
self.swa_start = 3 |