Diff of /main.py [000000] .. [0f2bcf]

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import tensorflow as tf

import random
import sys, os

from sklearn.model_selection import train_test_split

import import_data as impt
from helper import f_get_minibatch_set, evaluate
from class_DeepIMV_AISTATS import DeepIMV_AISTATS


import argparse

def init_arg():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=1234, help='random seed', type=int)

    parser.add_argument('--h_dim_p', default=100, help='number of hidden nodes -- predictor', type=int)
    parser.add_argument('--num_layers_p', default=2, help='number of layers -- predictor', type=int)

    parser.add_argument('--h_dim_e', default=100, help='number of hidden nodes -- encoder', type=int)
    parser.add_argument('--num_layers_e', default=3, help='number of layers -- encoder', type=int)

    parser.add_argument('--z_dim', default=50, help='dimension of latent representations', type=int)

    parser.add_argument('--lr_rate', default=1e-4, help='learning rate', type=float)
    parser.add_argument('--l1_reg', default=0., help='l1-regularization', type=float)

    parser.add_argument('--itrs', default=50000, help='number of training iterations', type=int)
    parser.add_argument('--step_size', default=1000, help='iterations per evaluation window', type=int)
    parser.add_argument('--max_flag', default=20, help='early-stopping patience (non-improving windows)', type=int)

    parser.add_argument('--mb_size', default=32, help='minibatch size', type=int)
    parser.add_argument('--keep_prob', default=0.7, help='keep probability for dropout', type=float)

    parser.add_argument('--alpha', default=1.0, help='coefficient -- alpha', type=float)
    parser.add_argument('--beta', default=0.01, help='coefficient -- beta', type=float)

    parser.add_argument('--save_path', default='./storage/', help='path to save files', type=str)

    return parser.parse_args()
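
# Example invocation (a sketch; every flag is optional and falls back to the
# defaults defined above):
#   python main.py --seed 1234 --mb_size 32 --itrs 50000 --save_path ./storage/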


if __name__ == '__main__':

    args             = init_arg()
    seed             = args.seed
    ### Import multi-view dataset with arbitrary view-missing patterns.
    X_set, Y_onehot, Mask = impt.import_incomplete_handwritten()
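    # `X_set` is assumed to map view index -> feature matrix, `Y_onehot` to hold
    # one-hot labels, and `Mask` to flag which views are observed per sample.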

    tr_X_set, te_X_set, va_X_set = {}, {}, {}

    # 64/16/20 training/validation/testing split
    for m in range(len(X_set)):
        tr_X_set[m], te_X_set[m] = train_test_split(X_set[m], test_size=0.2, random_state=seed)
        tr_X_set[m], va_X_set[m] = train_test_split(tr_X_set[m], test_size=0.2, random_state=seed)

    tr_Y_onehot, te_Y_onehot, tr_M, te_M = train_test_split(Y_onehot, Mask, test_size=0.2, random_state=seed)
    tr_Y_onehot, va_Y_onehot, tr_M, va_M = train_test_split(tr_Y_onehot, tr_M, test_size=0.2, random_state=seed)
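    # NOTE: reusing the same `random_state=seed` and split sizes across arrays of
    # equal length makes sklearn draw the same permutation, so every view stays
    # row-aligned with the label and mask splits.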

    x_dim_set    = [tr_X_set[m].shape[1] for m in range(len(tr_X_set))]
    y_dim        = np.shape(tr_Y_onehot)[1]

    if y_dim == 1:
        y_type = 'continuous'
    elif y_dim == 2:
        y_type = 'binary'
    else:
        y_type = 'categorical'
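    # Labels are assumed one-hot: two columns -> binary classification, more
    # columns -> multi-class; a single column is treated as a continuous target.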


    mb_size         = args.mb_size
    steps_per_batch = int(np.shape(tr_M)[0]/mb_size)  # steps per epoch, used for the moving average

    input_dims = {
        'x_dim_set': x_dim_set,
        'y_dim': y_dim,
        'y_type': y_type,
        'z_dim': args.z_dim,
        'steps_per_batch': steps_per_batch
    }

    network_settings = {
        'h_dim_p1': args.h_dim_p,
        'num_layers_p1': args.num_layers_p,   # view-specific predictors
        'h_dim_p2': args.h_dim_p,
        'num_layers_p2': args.num_layers_p,   # multi-view predictor

        'h_dim_e': args.h_dim_e,
        'num_layers_e': args.num_layers_e,

        'fc_activate_fn': tf.nn.relu,
        'reg_scale': args.l1_reg,
    }


    lr_rate         = args.lr_rate
    iteration       = args.itrs
    stepsize        = args.step_size
    max_flag        = args.max_flag

    k_prob          = args.keep_prob

    alpha           = args.alpha
    beta            = args.beta

    save_path       = args.save_path

    if not os.path.exists(save_path):
        os.makedirs(save_path)


    tf.reset_default_graph()
    gpu_options = tf.GPUOptions()

    sess  = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    model = DeepIMV_AISTATS(sess, "DeepIMV_AISTATS", input_dims, network_settings)

    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
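    # NOTE: this script targets the TensorFlow 1.x graph/session API; under
    # TF 2.x the same calls are available via tf.compat.v1 (with
    # tf.compat.v1.disable_v2_behavior()).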

    ##### TRAINING
    min_loss  = 1e+8
    max_acc   = 0.0

    tr_avg_Lt, tr_avg_Lp, tr_avg_Lkl, tr_avg_Lps, tr_avg_Lkls, tr_avg_Lc = 0, 0, 0, 0, 0, 0
    va_avg_Lt, va_avg_Lp, va_avg_Lkl, va_avg_Lps, va_avg_Lkls, va_avg_Lc = 0, 0, 0, 0, 0, 0

    stop_flag = 0
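    # Early stopping: train/validation losses are accumulated as running means
    # over windows of `stepsize` iterations; the best validation total loss is
    # checkpointed, and training stops after `max_flag` consecutive windows
    # without improvement.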
    for itr in range(iteration):
        x_mb_set, y_mb, m_mb = f_get_minibatch_set(mb_size, tr_X_set, tr_Y_onehot, tr_M)
        _, Lt, Lp, Lkl, Lps, Lkls, Lc = model.train(x_mb_set, y_mb, m_mb, alpha, beta, lr_rate, k_prob)

        tr_avg_Lt   += Lt/stepsize
        tr_avg_Lp   += Lp/stepsize
        tr_avg_Lkl  += Lkl/stepsize
        tr_avg_Lps  += Lps/stepsize
        tr_avg_Lkls += Lkls/stepsize
        tr_avg_Lc   += Lc/stepsize

        x_mb_set, y_mb, m_mb = f_get_minibatch_set(min(np.shape(va_M)[0], mb_size), va_X_set, va_Y_onehot, va_M)
        Lt, Lp, Lkl, Lps, Lkls, Lc, _, _ = model.get_loss(x_mb_set, y_mb, m_mb, alpha, beta)

        va_avg_Lt   += Lt/stepsize
        va_avg_Lp   += Lp/stepsize
        va_avg_Lkl  += Lkl/stepsize
        va_avg_Lps  += Lps/stepsize
        va_avg_Lkls += Lkls/stepsize
        va_avg_Lc   += Lc/stepsize

        if (itr+1) % stepsize == 0:
            y_pred, y_preds = model.predict_ys(va_X_set, va_M)

            print("{:05d}: TRAIN| Lt={:.3f} Lp={:.3f} Lkl={:.3f} Lps={:.3f} Lkls={:.3f} Lc={:.3f} | VALID| Lt={:.3f} Lp={:.3f} Lkl={:.3f} Lps={:.3f} Lkls={:.3f} Lc={:.3f} score={}".format(
                itr+1, tr_avg_Lt, tr_avg_Lp, tr_avg_Lkl, tr_avg_Lps, tr_avg_Lkls, tr_avg_Lc,
                va_avg_Lt, va_avg_Lp, va_avg_Lkl, va_avg_Lps, va_avg_Lkls, va_avg_Lc,
                evaluate(va_Y_onehot, np.mean(y_preds, axis=0), y_type)))

            if min_loss > va_avg_Lt:
                min_loss  = va_avg_Lt
                stop_flag = 0
                saver.save(sess, save_path + 'best_model')
                print('saved...')
            else:
                stop_flag += 1

            tr_avg_Lt, tr_avg_Lp, tr_avg_Lkl, tr_avg_Lps, tr_avg_Lkls, tr_avg_Lc = 0, 0, 0, 0, 0, 0
            va_avg_Lt, va_avg_Lp, va_avg_Lkl, va_avg_Lps, va_avg_Lkls, va_avg_Lc = 0, 0, 0, 0, 0, 0

            if stop_flag >= max_flag:
                break

    print('FINISHED...')


    ##### TESTING
    saver.restore(sess, save_path + 'best_model')

    _, pred_ys = model.predict_ys(te_X_set, te_M)
    pred_y = np.mean(pred_ys, axis=0)
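    # `pred_ys` appears to stack multiple stochastic forward passes (e.g., MC
    # samples and/or per-view predictors); averaging over axis 0 yields the
    # final ensemble prediction.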

    print('Test Score: {}'.format(evaluate(te_Y_onehot, pred_y, y_type)))