|
a |
|
b/main.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
|
|
|
3 |
import os |
|
|
4 |
import gc |
|
|
5 |
import time |
|
|
6 |
|
|
|
7 |
import numpy as np |
|
|
8 |
from collections import defaultdict |
|
|
9 |
from keras import backend as K |
|
|
10 |
from keras import optimizers |
|
|
11 |
|
|
|
12 |
from utils import load_data, pickle_load, format_filename, write_log |
|
|
13 |
from models import KGCN |
|
|
14 |
from config import ModelConfig, PROCESSED_DATA_DIR, ENTITY_VOCAB_TEMPLATE, \ |
|
|
15 |
RELATION_VOCAB_TEMPLATE, ADJ_ENTITY_TEMPLATE, ADJ_RELATION_TEMPLATE, LOG_DIR, PERFORMANCE_LOG, \ |
|
|
16 |
DRUG_VOCAB_TEMPLATE |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
# Pin the process to GPU index 1. This unconditionally overwrites any
# CUDA_VISIBLE_DEVICES value inherited from the parent environment.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})
|
|
20 |
|
|
|
21 |
|
|
|
22 |
def get_optimizer(op_type, learning_rate):
    """Build the Keras optimizer selected by `op_type`.

    :param op_type: one of 'sgd', 'rmsprop', 'adagrad', 'adadelta', 'adam'.
    :param learning_rate: passed as the first positional argument of the
        optimizer constructor.
    :return: a freshly constructed optimizer from `keras.optimizers`.
    :raises ValueError: if `op_type` is not one of the names above.
    """
    # (name, attribute on keras.optimizers, extra constructor kwargs).
    # Only Adam gets gradient-norm clipping.
    known_optimizers = (
        ('sgd', 'SGD', {}),
        ('rmsprop', 'RMSprop', {}),
        ('adagrad', 'Adagrad', {}),
        ('adadelta', 'Adadelta', {}),
        ('adam', 'Adam', {'clipnorm': 5}),
    )
    for name, class_name, extra_kwargs in known_optimizers:
        if op_type == name:
            optimizer_cls = getattr(optimizers, class_name)
            return optimizer_cls(learning_rate, **extra_kwargs)
    raise ValueError('Optimizer Not Understood: {}'.format(op_type))
|
|
35 |
|
|
|
36 |
|
|
|
37 |
def _split_xy(data):
    """Split a 2-D array into model inputs (columns 0 and 1) and labels (column 2).

    Every piece is sliced rather than indexed, so each stays 2-D with a single
    column.
    """
    return [data[:, :1], data[:, 1:2]], data[:, 2:3]


def _score_and_log(model, data, prefix, train_log):
    """Score `model` on `data`, store the four metrics under `prefix`, print them."""
    x, y = _split_xy(data)
    auc, acc, f1, aupr = model.score(x=x, y=y)
    train_log[f'{prefix}_auc'] = auc
    train_log[f'{prefix}_acc'] = acc
    train_log[f'{prefix}_f1'] = f1
    train_log[f'{prefix}_aupr'] = aupr
    print(f'Logging Info - {prefix}_auc: {auc}, {prefix}_acc: {acc}, '
          f'{prefix}_f1: {f1}, {prefix}_aupr: {aupr}')


def train(train_d, dev_d, test_d, kfold, dataset, neighbor_sample_size, embed_dim, n_depth, l2_weight, lr,
          optimizer_type, batch_size, aggregator_type, n_epoch, callbacks_to_add=None, overwrite=True):
    """Train a KGCN model for one fold, evaluate it, and append the results to the performance log.

    :param train_d: training rows; converted with `np.array`, columns 0 and 1
        are fed as the two model inputs and column 2 as the label.
    :param dev_d: validation rows, same layout as `train_d`.
    :param test_d: test rows, same layout as `train_d`.
    :param kfold: fold identifier, stored on the config and in the log.
    :param dataset: dataset name used to resolve the vocabulary / adjacency files.
    :param neighbor_sample_size: copied onto the model config.
    :param embed_dim: copied onto the model config.
    :param n_depth: copied onto the model config.
    :param l2_weight: copied onto the model config.
    :param lr: learning rate, passed to `get_optimizer`.
    :param optimizer_type: optimizer name, passed to `get_optimizer`.
    :param batch_size: copied onto the model config.
    :param aggregator_type: copied onto the model config.
    :param n_epoch: copied onto the model config.
    :param callbacks_to_add: iterable of callback names (e.g. 'swa'); `None`
        is treated as "no callbacks".
    :param overwrite: when False and a checkpoint for this experiment name
        already exists, training is skipped and only evaluation runs.
    :return: dict with the experiment settings and the dev/test metrics
        (plus the `swa_*` metrics when 'swa' is among the callbacks).
    :raises ValueError: from `get_optimizer` when `optimizer_type` is unknown.
    """
    # The declared default (None) previously crashed in '_'.join(...) and in
    # the `'swa' in ...` checks below; normalise it to an empty list.
    if callbacks_to_add is None:
        callbacks_to_add = []

    config = ModelConfig()
    config.neighbor_sample_size = neighbor_sample_size
    config.embed_dim = embed_dim
    config.n_depth = n_depth
    config.l2_weight = l2_weight
    config.dataset = dataset
    config.K_Fold = kfold
    config.lr = lr
    config.optimizer = get_optimizer(optimizer_type, lr)
    config.batch_size = batch_size
    config.aggregator_type = aggregator_type
    config.n_epoch = n_epoch
    config.callbacks_to_add = callbacks_to_add

    # Vocabulary sizes and adjacency matrices come from the preprocessed
    # files of this dataset.
    config.drug_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                             DRUG_VOCAB_TEMPLATE,
                                                             dataset=dataset)))
    config.entity_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                               ENTITY_VOCAB_TEMPLATE,
                                                               dataset=dataset)))
    config.relation_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                                 RELATION_VOCAB_TEMPLATE,
                                                                 dataset=dataset)))
    config.adj_entity = np.load(format_filename(PROCESSED_DATA_DIR, ADJ_ENTITY_TEMPLATE,
                                                dataset=dataset))
    config.adj_relation = np.load(format_filename(PROCESSED_DATA_DIR, ADJ_RELATION_TEMPLATE,
                                                  dataset=dataset))

    config.exp_name = f'kgcn_{dataset}_neigh_{neighbor_sample_size}_embed_{embed_dim}_depth_' \
                      f'{n_depth}_agg_{aggregator_type}_optimizer_{optimizer_type}_lr_{lr}_' \
                      f'batch_size_{batch_size}_epoch_{n_epoch}'
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    # 'modelcheckpoint' and 'earlystopping' are left out of the experiment
    # name (original note: these two were dropped in favour of SWA averaging).
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str

    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size, 'optimizer': optimizer_type,
                 'epoch': n_epoch, 'learning_rate': lr}
    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    use_swa = 'swa' in config.callbacks_to_add

    train_data = np.array(train_d)
    valid_data = np.array(dev_d)
    test_data = np.array(test_d)

    model = KGCN(config)
    try:
        if not os.path.exists(model_save_path) or overwrite:
            x_train, y_train = _split_xy(train_data)
            x_valid, y_valid = _split_xy(valid_data)
            start_time = time.time()
            model.fit(x_train=x_train, y_train=y_train, x_valid=x_valid, y_valid=y_valid)
            elapsed_time = time.time() - start_time
            print('Logging Info - Training time: %s' % time.strftime("%H:%M:%S",
                                                                     time.gmtime(elapsed_time)))
            train_log['train_time'] = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))

        print('Logging Info - Evaluate over valid data:')
        model.load_best_model()
        _score_and_log(model, valid_data, 'dev', train_log)
        train_log['k_fold'] = kfold
        train_log['dataset'] = dataset
        train_log['aggregate_type'] = config.aggregator_type
        if use_swa:
            model.load_swa_model()
            print('Logging Info - Evaluate over valid data based on swa model:')
            _score_and_log(model, valid_data, 'swa_dev', train_log)

        print('Logging Info - Evaluate over test data:')
        # Reload the best checkpoint: the SWA weights may have been loaded above.
        model.load_best_model()
        _score_and_log(model, test_data, 'test', train_log)
        if use_swa:
            model.load_swa_model()
            print('Logging Info - Evaluate over test data based on swa model:')
            _score_and_log(model, test_data, 'swa_test', train_log)

        train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
        write_log(format_filename(LOG_DIR, PERFORMANCE_LOG), log=train_log, mode='a')
    finally:
        # Release the model and the Keras session even when training or
        # scoring fails, so a following fold starts from a clean state.
        del model
        gc.collect()
        K.clear_session()
    return train_log
|
|
139 |
|