Diff of /main.py [000000] .. [c0da92]

Switch to unified view

a b/main.py
1
# -*- coding: utf-8 -*-
2
3
import os
4
import gc
5
import time
6
7
import numpy as np
8
from collections import defaultdict
9
from keras import backend as K
10
from keras import optimizers
11
12
from utils import load_data, pickle_load, format_filename, write_log
13
from models import KGCN
14
from config import ModelConfig, PROCESSED_DATA_DIR,  ENTITY_VOCAB_TEMPLATE, \
15
    RELATION_VOCAB_TEMPLATE, ADJ_ENTITY_TEMPLATE, ADJ_RELATION_TEMPLATE, LOG_DIR, PERFORMANCE_LOG, \
16
    DRUG_VOCAB_TEMPLATE
17
18
19
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
20
21
22
def get_optimizer(op_type, learning_rate):
23
    if op_type == 'sgd':
24
        return optimizers.SGD(learning_rate)
25
    elif op_type == 'rmsprop':
26
        return optimizers.RMSprop(learning_rate)
27
    elif op_type == 'adagrad':
28
        return optimizers.Adagrad(learning_rate)
29
    elif op_type == 'adadelta':
30
        return optimizers.Adadelta(learning_rate)
31
    elif op_type == 'adam':
32
        return optimizers.Adam(learning_rate, clipnorm=5)
33
    else:
34
        raise ValueError('Optimizer Not Understood: {}'.format(op_type))
35
36
37
def train(train_d,dev_d,test_d,kfold,dataset, neighbor_sample_size, embed_dim, n_depth, l2_weight, lr, optimizer_type,
38
          batch_size, aggregator_type, n_epoch, callbacks_to_add=None, overwrite=True):
39
    config = ModelConfig()
40
    config.neighbor_sample_size = neighbor_sample_size
41
    config.embed_dim = embed_dim
42
    config.n_depth = n_depth
43
    config.l2_weight = l2_weight
44
    config.dataset=dataset
45
    config.K_Fold=kfold
46
    config.lr = lr
47
    config.optimizer = get_optimizer(optimizer_type, lr)
48
    config.batch_size = batch_size
49
    config.aggregator_type = aggregator_type
50
    config.n_epoch = n_epoch
51
    config.callbacks_to_add = callbacks_to_add
52
53
    config.drug_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
54
                                                             DRUG_VOCAB_TEMPLATE,
55
                                                             dataset=dataset)))
56
    config.entity_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
57
                                                               ENTITY_VOCAB_TEMPLATE,
58
                                                               dataset=dataset)))
59
    config.relation_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
60
                                                                 RELATION_VOCAB_TEMPLATE,
61
                                                                 dataset=dataset)))
62
    config.adj_entity = np.load(format_filename(PROCESSED_DATA_DIR, ADJ_ENTITY_TEMPLATE,
63
                                                dataset=dataset))
64
    config.adj_relation = np.load(format_filename(PROCESSED_DATA_DIR, ADJ_RELATION_TEMPLATE,
65
                                                  dataset=dataset))
66
67
    config.exp_name = f'kgcn_{dataset}_neigh_{neighbor_sample_size}_embed_{embed_dim}_depth_' \
68
                      f'{n_depth}_agg_{aggregator_type}_optimizer_{optimizer_type}_lr_{lr}_' \
69
                      f'batch_size_{batch_size}_epoch_{n_epoch}'
70
    callback_str = '_' + '_'.join(config.callbacks_to_add)
71
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')#去掉了这两种方式使用swa得方式平均
72
    config.exp_name += callback_str
73
74
    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size, 'optimizer': optimizer_type,
75
                 'epoch': n_epoch, 'learning_rate': lr}
76
    print('Logging Info - Experiment: %s' % config.exp_name)
77
    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
78
    model = KGCN(config)
79
80
    train_data=np.array(train_d)
81
    valid_data=np.array(dev_d)
82
    test_data=np.array(test_d)
83
    if not os.path.exists(model_save_path) or overwrite:
84
        start_time = time.time()
85
        model.fit(x_train=[train_data[:, :1], train_data[:, 1:2]], y_train=train_data[:, 2:3],
86
                  x_valid=[valid_data[:, :1], valid_data[:, 1:2]], y_valid=valid_data[:, 2:3])
87
        elapsed_time = time.time() - start_time
88
        print('Logging Info - Training time: %s' % time.strftime("%H:%M:%S",
89
                                                                 time.gmtime(elapsed_time)))
90
        train_log['train_time'] = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
91
92
    print('Logging Info - Evaluate over valid data:')
93
    model.load_best_model()
94
    auc, acc, f1,aupr = model.score(x=[valid_data[:, :1], valid_data[:, 1:2]], y=valid_data[:, 2:3])
95
96
    print(f'Logging Info - dev_auc: {auc}, dev_acc: {acc}, dev_f1: {f1}, dev_aupr: {aupr}'
97
          )
98
    train_log['dev_auc'] = auc
99
    train_log['dev_acc'] = acc
100
    train_log['dev_f1'] = f1
101
    train_log['dev_aupr']=aupr
102
    train_log['k_fold']=kfold
103
    train_log['dataset']=dataset
104
    train_log['aggregate_type']=config.aggregator_type
105
    if 'swa' in config.callbacks_to_add:
106
        model.load_swa_model()
107
        print('Logging Info - Evaluate over valid data based on swa model:')
108
        auc, acc, f1,aupr = model.score(x=[valid_data[:, :1], valid_data[:, 1:2]], y=valid_data[:, 2:3])
109
110
        train_log['swa_dev_auc'] = auc
111
        train_log['swa_dev_acc'] = acc
112
        train_log['swa_dev_f1'] = f1
113
        train_log['swa_dev_aupr']=aupr
114
        print(f'Logging Info - swa_dev_auc: {auc}, swa_dev_acc: {acc}, swa_dev_f1: {f1}, swa_dev_aupr: {aupr}') #修改输出指标
115
    print('Logging Info - Evaluate over test data:')
116
    model.load_best_model()
117
    auc, acc, f1, aupr = model.score(x=[test_data[:, :1], test_data[:, 1:2]], y=test_data[:, 2:3])
118
119
    train_log['test_auc'] = auc
120
    train_log['test_acc'] = acc
121
    train_log['test_f1'] = f1
122
    train_log['test_aupr'] =aupr
123
    print(f'Logging Info - test_auc: {auc}, test_acc: {acc}, test_f1: {f1}, test_aupr: {aupr}')
124
    if 'swa' in config.callbacks_to_add:
125
        model.load_swa_model()
126
        print('Logging Info - Evaluate over test data based on swa model:')
127
        auc, acc, f1,aupr = model.score(x=[test_data[:, :1], test_data[:, 1:2]], y=test_data[:, 2:3])
128
        train_log['swa_test_auc'] = auc
129
        train_log['swa_test_acc'] = acc
130
        train_log['swa_test_f1'] = f1
131
        train_log['swa_test_aupr'] = aupr
132
        print(f'Logging Info - swa_test_auc: {auc}, swa_test_acc: {acc}, swa_test_f1: {f1}, swa_test_aupr: {aupr}')
133
    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
134
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG), log=train_log, mode='a')
135
    del model
136
    gc.collect()
137
    K.clear_session()
138
    return train_log
139