Diff of /code/model.py [000000] .. [ebf7be]

# -*- coding:utf-8 -*-
"""
 The idea of this script comes from the LUNA16 champion paper.
 The model is composed of three networks, namely Archi-1 (cube size 20x20x6),
 Archi-2 (cube size 30x30x10) and Archi-3 (cube size 40x40x26).
"""
import tensorflow as tf
from data_prepare import get_train_batch, get_all_filename, get_test_batch
import random
import time

class model(object):

    def __init__(self, learning_rate, keep_prob, batch_size, epoch):
        print("network begin...")
        self.learning_rate = learning_rate
        self.keep_prob = keep_prob
        self.batch_size = batch_size
        self.epoch = epoch

        # input cube sizes as [depth, height, width] for Archi-1/2/3
        self.cubic_shape = [[6, 20, 20], [10, 30, 30], [26, 40, 40]]

    def archi_1(self, input, keep_prob):
        with tf.name_scope("Archi-1"):
            # input shape is [batch_size, depth, height, width, 1]
            # conv1 kernel is [depth, height, width, in_channels, out_channels]:
            # a 3x5x5 kernel with 1 input channel and 32 output channels
            w_conv1 = tf.Variable(tf.random_normal([3, 5, 5, 1, 32], stddev=0.001), dtype=tf.float32, name='w_conv1')
            b_conv1 = tf.Variable(tf.constant(0.01, shape=[32]), dtype=tf.float32, name='b_conv1')
            out_conv1 = tf.nn.relu(tf.add(tf.nn.conv3d(input, w_conv1, strides=[1, 1, 1, 1, 1], padding='SAME'), b_conv1))
            out_conv1 = tf.nn.dropout(out_conv1, keep_prob)
            out_conv1_shape = tf.shape(out_conv1)

            # the 'SAME' stride-1 convolution keeps the spatial size;
            # 2x2x2 max pooling with stride 2 halves every spatial dimension
            hidden_conv1 = tf.nn.max_pool3d(out_conv1, strides=[1, 2, 2, 2, 1], ksize=[1, 2, 2, 2, 1], padding='SAME')
            hidden_conv1_shape = tf.shape(hidden_conv1)

            # after conv1 + pool1 the shape is [batch_size, ceil(D/2), ceil(H/2), ceil(W/2), 32]
            w_conv2 = tf.Variable(tf.random_normal([3, 5, 5, 32, 64], stddev=0.001), dtype=tf.float32, name='w_conv2')
            b_conv2 = tf.Variable(tf.constant(0.01, shape=[64]), dtype=tf.float32, name='b_conv2')
            out_conv2 = tf.nn.relu(tf.add(tf.nn.conv3d(hidden_conv1, w_conv2, strides=[1, 1, 1, 1, 1], padding='SAME'), b_conv2))
            out_conv2 = tf.nn.dropout(out_conv2, keep_prob)
            out_conv2_shape = tf.shape(out_conv2)

            hidden_conv2 = tf.nn.max_pool3d(out_conv2, strides=[1, 2, 2, 2, 1], ksize=[1, 2, 2, 2, 1], padding='SAME')
            hidden_conv2_shape = tf.shape(hidden_conv2)

            # after conv2 + pool2 the shape is [batch_size, ceil(D/4), ceil(H/4), ceil(W/4), 64]
            w_conv3 = tf.Variable(tf.random_normal([1, 5, 5, 64, 64], stddev=0.001), dtype=tf.float32, name='w_conv3')
            b_conv3 = tf.Variable(tf.constant(0.01, shape=[64]), dtype=tf.float32, name='b_conv3')
            out_conv3 = tf.nn.relu(tf.add(tf.nn.conv3d(hidden_conv2, w_conv3, strides=[1, 1, 1, 1, 1], padding='SAME'), b_conv3))
            out_conv3 = tf.nn.dropout(out_conv3, keep_prob)
            out_conv3_shape = tf.shape(out_conv3)
            tf.summary.scalar('out_conv3_shape', out_conv3_shape[0])

            # conv3 keeps that shape, so flattening all feature maps yields one
            # long vector; the hard-coded size 64*10*10*7 matches the 40x40x26
            # cube (cubic_shape[2]: 26 -> 13 -> 7 and 40 -> 20 -> 10 after the
            # two poolings), i.e. this graph expects model_index == 2
            out_conv3 = tf.reshape(out_conv3, [-1, 64 * 10 * 10 * 7])
            w_fc1 = tf.Variable(tf.random_normal([64 * 10 * 10 * 7, 2048], stddev=0.001), name='w_fc1')
            b_fc1 = tf.Variable(tf.constant(0.01, shape=[2048]), dtype=tf.float32, name='b_fc1')
            out_fc1 = tf.nn.relu(tf.add(tf.matmul(out_conv3, w_fc1), b_fc1))
            out_fc1 = tf.nn.dropout(out_fc1, keep_prob)
            out_fc1_shape = tf.shape(out_fc1)
            tf.summary.scalar('out_fc1_shape', out_fc1_shape[0])

            w_fc2 = tf.Variable(tf.random_normal([2048, 256], stddev=0.001), name='w_fc2')
            b_fc2 = tf.Variable(tf.constant(0.01, shape=[256]), dtype=tf.float32, name='b_fc2')
            out_fc2 = tf.nn.relu(tf.add(tf.matmul(out_fc1, w_fc2), b_fc2))
            out_fc2 = tf.nn.dropout(out_fc2, keep_prob)

            w_fc3 = tf.Variable(tf.random_normal([256, 2], stddev=0.001), name='w_fc3')
            b_fc3 = tf.Variable(tf.constant(0.01, shape=[2]), dtype=tf.float32, name='b_fc3')
            out_fc3 = tf.nn.relu(tf.add(tf.matmul(out_fc2, w_fc3), b_fc3))
            out_fc3 = tf.nn.dropout(out_fc3, keep_prob)

            # final linear layer: return the raw logits here, because
            # tf.nn.softmax_cross_entropy_with_logits in inference() applies
            # softmax internally and must not receive already-normalized values
            w_sm = tf.Variable(tf.random_normal([2, 2], stddev=0.001), name='w_sm')
            b_sm = tf.constant(0.001, shape=[2])
            out_sm = tf.add(tf.matmul(out_fc3, w_sm), b_sm)

            return out_sm

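    # A small sanity-check helper (a sketch added for illustration, not part of
    # the original file): it reproduces the flatten-size arithmetic used above.
    # 'SAME' stride-1 convolutions keep each spatial dimension n unchanged,
    # while 'SAME' max pooling with stride 2 maps n to ceil(n / 2).
    def _flatten_size(self, cube_shape, channels=64, n_pools=2):
        dims = list(cube_shape)
        for _ in range(n_pools):
            dims = [(d + 1) // 2 for d in dims]  # ceil(d / 2) per pooling
        size = channels
        for d in dims:
            size *= d
        return size  # e.g. [26, 40, 40] -> 64 * 7 * 10 * 10 = 44800
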
    def inference(self, npy_path, test_path, model_index, train_flag=True):

        # statistics tracked across training
        highest_acc = 0.0
        highest_iterator = 1

        all_filenames = get_all_filename(npy_path, self.cubic_shape[model_index][1])

        # how many batches one epoch must loop through to feed all the data
        times = int(len(all_filenames) / self.batch_size)
        if (len(all_filenames) % self.batch_size) != 0:
            times = times + 1

        # keep_prob is used for dropout
        keep_prob = tf.placeholder(tf.float32)
        # take placeholders as input: [batch_size, depth, height, width]
        x = tf.placeholder(tf.float32, [None, self.cubic_shape[model_index][0], self.cubic_shape[model_index][1], self.cubic_shape[model_index][2]])
        x_image = tf.reshape(x, [-1, self.cubic_shape[model_index][0], self.cubic_shape[model_index][1], self.cubic_shape[model_index][2], 1])
        net_out = self.archi_1(x_image, keep_prob)

        saver = tf.train.Saver()  # defaults to saving all variables; also used to restore from a path

        if train_flag:
            # cross-entropy loss on the raw logits (softmax is applied inside
            # softmax_cross_entropy_with_logits)
            real_label = tf.placeholder(tf.float32, [None, 2])
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=net_out, labels=real_label)
            net_loss = tf.reduce_mean(cross_entropy)

            train_step = tf.train.MomentumOptimizer(self.learning_rate, 0.9).minimize(net_loss)

            correct_prediction = tf.equal(tf.argmax(net_out, 1), tf.argmax(real_label, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

            merged = tf.summary.merge_all()

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                train_writer = tf.summary.FileWriter('./tensorboard/', sess.graph)
                # loop over epochs
                for i in range(self.epoch):
                    epoch_start = time.time()
                    # the data is reshuffled every epoch
                    random.shuffle(all_filenames)
                    for t in range(times):
                        batch_files = all_filenames[t * self.batch_size:(t + 1) * self.batch_size]
                        batch_data, batch_label = get_train_batch(batch_files)
                        feed_dict = {x: batch_data, real_label: batch_label,
                                     keep_prob: self.keep_prob}
                        _, summary = sess.run([train_step, merged], feed_dict=feed_dict)
                        train_writer.add_summary(summary, i * times + t)

                    # one checkpoint per epoch is enough, so save outside the batch loop
                    saver.save(sess, './ckpt/archi-1', global_step=i + 1)

                    epoch_end = time.time()
                    test_batch, test_label = get_test_batch(test_path)
                    # dropout is disabled (keep_prob = 1.0) at evaluation time
                    test_dict = {x: test_batch, real_label: test_label, keep_prob: 1.0}
                    acc_test, loss = sess.run([accuracy, net_loss], feed_dict=test_dict)
                    print('accuracy is %f' % acc_test)
                    print("loss is ", loss)
                    print("epoch %d time consumed %f seconds" % (i, (epoch_end - epoch_start)))
                    if acc_test > highest_acc:
                        highest_acc = acc_test
                        highest_iterator = i

            print("training finished. highest accuracy is %f, the iterator is %d" % (highest_acc, highest_iterator))