Diff of /Net.py [000000] .. [6d7935]

"""

Stefania Fresca, MOX Laboratory, Politecnico di Milano
April 2019

"""

import tensorflow as tf
import numpy as np
import scipy.io as sio
import time
import os

import utils

seed = 374
np.random.seed(seed)

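# Net is an abstract base class: a subclass must override inference() to
# define the forward path and set the tensors used downstream, namely
# self.enc (encoder output), self.u_h (reconstructed snapshot) and
# self.u_n (matched against the encoder output in the loss).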
class Net:
    def __init__(self, config):
        self.lr = config['lr']
        self.batch_size = config['batch_size']
        self.g_step = tf.Variable(0, dtype = tf.int32, trainable = False, name = 'global_step')

        self.n_data = config['n_data']
        self.n_train = int(0.8 * self.n_data)
        self.N_h = config['N_h']
        self.N_t = config['N_t']
        # read n_params here since get_data() uses it; assumed to be a key of config
        self.n_params = config['n_params']

        self.train_mat = config['train_mat']
        self.test_mat = config['test_mat']
        self.train_params = config['train_params']
        self.test_params = config['test_params']

        self.omega_h = config['omega_h']
        self.omega_n = config['omega_n']

        self.checkpoints_folder = config['checkpoints_folder']
        self.graph_folder = config['graph_folder']
        self.large = config['large']
        self.zero_padding = config['zero_padding']
        self.p = config['p']
        self.restart = config['restart']

    def get_data(self):
        with tf.name_scope('data'):
            self.X = tf.placeholder(tf.float32, shape = [None, self.N_h])
            self.Y = tf.placeholder(tf.float32, shape = [None, self.n_params])

            # no shuffling in the pipeline: the training snapshots are already
            # permuted in train_all(), and shuffling here would scramble the row
            # order of the test reconstructions in self.U_h w.r.t. self.S_test
            dataset = tf.data.Dataset.from_tensor_slices((self.X, self.Y))
            dataset = dataset.batch(self.batch_size)

            iterator = dataset.make_initializable_iterator()
            self.init = iterator.initializer

            inputs, self.params = iterator.get_next()
            # reshape each snapshot into a square single-channel image
            self.input = tf.reshape(inputs, shape = [-1, int(np.sqrt(self.N_h)), int(np.sqrt(self.N_h)), 1])

    def inference(self):
        raise NotImplementedError("Must be overridden with a proper definition of the forward path")

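    # The total loss is the weighted sum of two mean squared errors computed
    # per batch: loss_h penalizes the reconstruction error between the decoder
    # output u_h and the input snapshot, while loss_n penalizes the misfit
    # between the encoder output self.enc and u_n; omega_h and omega_n weigh
    # the two terms.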
    def loss(self, u_h, u_n):
        with tf.name_scope('loss'):
            output = tf.reshape(self.input, shape = [-1, self.N_h])
            self.loss_h = self.omega_h * tf.reduce_mean(tf.reduce_sum(tf.pow(output - u_h, 2), axis = 1))
            self.loss_n = self.omega_n * tf.reduce_mean(tf.reduce_sum(tf.pow(self.enc - u_n, 2), axis = 1))
            # note: this rebinds self.loss from the method to the loss tensor;
            # build() calls each of these methods exactly once, so this is safe
            self.loss = self.loss_h + self.loss_n

    def optimize(self):
        self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.loss, global_step = self.g_step)

    def summary(self):
        with tf.name_scope('summaries'):
            # likewise rebinds self.summary from the method to the summary op
            self.summary = tf.summary.scalar('loss', self.loss)

    def build(self):
        self.get_data()
        self.inference()
        self.loss(self.u_h, self.u_n)
        self.optimize()
        self.summary()

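    # One full pass over the training set: reinitialize the iterator on the
    # training split, then run the optimizer batch by batch until the dataset
    # is exhausted (tf.errors.OutOfRangeError), accumulating the losses.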
    def train_one_epoch(self, sess, init, writer, epoch, step):
        start_time = time.time()
        sess.run(init, feed_dict = {self.X : self.S_train, self.Y : self.params_train})
        total_loss_h = 0
        total_loss_n = 0
        total_loss = 0
        n_batches = 0
        print('------------ TRAINING -------------', flush = True)
        try:
            while True:
                _, l_h, l_n, l, summary = sess.run([self.opt, self.loss_h, self.loss_n, self.loss, self.summary])
                writer.add_summary(summary, global_step = step)
                step += 1
                total_loss_h += l_h
                total_loss_n += l_n
                total_loss += l
                n_batches += 1
        except tf.errors.OutOfRangeError:
            pass
        print('Average loss_h at epoch {0} on training set: {1}'.format(epoch, total_loss_h / n_batches))
        print('Average loss_n at epoch {0} on training set: {1}'.format(epoch, total_loss_n / n_batches))
        print('Average loss at epoch {0} on training set: {1}'.format(epoch, total_loss / n_batches))
        print('Took: {0} seconds'.format(time.time() - start_time))
        return step

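    # One pass over the validation split without optimization; a checkpoint is
    # saved whenever the mean validation loss improves on self.loss_best.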
    def eval_once(self, sess, saver, init, writer, epoch, step):
        start_time = time.time()
        sess.run(init, feed_dict = {self.X : self.S_val, self.Y : self.params_val})
        total_loss_h = 0
        total_loss_n = 0
        total_loss = 0
        n_batches = 0
        print('------------ VALIDATION ------------')
        try:
            while True:
                l_h, l_n, l, summary = sess.run([self.loss_h, self.loss_n, self.loss, self.summary])
                writer.add_summary(summary, global_step = step)
                total_loss_h += l_h
                total_loss_n += l_n
                total_loss += l
                n_batches += 1
        except tf.errors.OutOfRangeError:
            pass
        total_loss_mean = total_loss / n_batches
        if total_loss_mean < self.loss_best:
            saver.save(sess, self.checkpoints_folder + '/Net', step)
        print('Average loss_h at epoch {0} on validation set: {1}'.format(epoch, total_loss_h / n_batches))
        print('Average loss_n at epoch {0} on validation set: {1}'.format(epoch, total_loss_n / n_batches))
        print('Average loss at epoch {0} on validation set: {1}'.format(epoch, total_loss_mean))
        print('Took: {0} seconds'.format(time.time() - start_time))
        return total_loss_mean

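    # One pass over the test split: accumulates the losses and stores the
    # reconstructed snapshots batch by batch in self.U_h for the error
    # computation in train_all().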
    def test_once(self, sess, init):
        start_time = time.time()
        sess.run(init, feed_dict = {self.X : self.S_test, self.Y : self.params_test})
        total_loss_h = 0
        total_loss_n = 0
        total_loss = 0
        n_batches = 0
        self.U_h = np.zeros(self.S_test.shape)
        print('------------ TESTING ------------')
        try:
            while True:
                l_h, l_n, l, u_h = sess.run([self.loss_h, self.loss_n, self.loss, self.u_h])
                self.U_h[self.batch_size * n_batches : self.batch_size * (n_batches + 1)] = u_h
                total_loss_h += l_h
                total_loss_n += l_n
                total_loss += l
                n_batches += 1
        except tf.errors.OutOfRangeError:
            pass
        print('Average loss_h on testing set: {0}'.format(total_loss_h / n_batches))
        print('Average loss_n on testing set: {0}'.format(total_loss_n / n_batches))
        print('Average loss on testing set: {0}'.format(total_loss / n_batches))
        print('Took: {0} seconds'.format(time.time() - start_time))

    #@profile (uncomment when memory profiling is needed)
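    # Full pipeline: load and scale the training data, train with early
    # stopping on the validation loss, restore the best checkpoint, evaluate
    # on the test set and report the error indicator epsilon_rel.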
    def train_all(self, n_epochs):
        if (not self.restart):
            utils.safe_mkdir(self.checkpoints_folder)
        saver = tf.train.Saver()
        train_writer = tf.summary.FileWriter('./' + self.graph_folder + '/train', tf.get_default_graph())
        test_writer = tf.summary.FileWriter('./' + self.graph_folder + '/test', tf.get_default_graph())

        print('Loading snapshot matrix...')
        if (self.large):
            S = utils.read_large_data(self.train_mat)
        else:
            S = utils.read_data(self.train_mat)

        # shuffle the snapshots, then normalize with the max/min computed on
        # the training split only
        idxs = np.random.permutation(S.shape[0])
        S = S[idxs]
        S_max, S_min = utils.max_min(S, self.n_train)
        utils.scaling(S, S_max, S_min)

        if (self.zero_padding):
            S = utils.zero_pad(S, self.p)

        self.S_train, self.S_val = S[:self.n_train, :], S[self.n_train:, :]
        del S

        print('Loading parameters...')
        params = utils.read_params(self.train_params)

        # apply the same permutation used for the snapshots
        params = params[idxs]

        self.params_train, self.params_val = params[:self.n_train], params[self.n_train:]
        del params

        self.loss_best = 1
        count = 0
        with tf.Session(config = tf.ConfigProto(gpu_options = tf.GPUOptions(allow_growth = True))) as sess:
            sess.run(tf.global_variables_initializer())

            if (self.restart):
                ckpt = tf.train.get_checkpoint_state(os.path.dirname(self.checkpoints_folder + '/checkpoint'))
                if ckpt and ckpt.model_checkpoint_path:
                    print(ckpt.model_checkpoint_path)
                    saver.restore(sess, ckpt.model_checkpoint_path)

            step = self.g_step.eval()

            for epoch in range(n_epochs):
                step = self.train_one_epoch(sess, self.init, train_writer, epoch, step)
                total_loss_mean = self.eval_once(sess, saver, self.init, test_writer, epoch, step)
                if total_loss_mean < self.loss_best:
                    self.loss_best = total_loss_mean
                    count = 0
                else:
                    count += 1
                # early stopping: give up after 500 epochs without improvement
                # of the validation loss
                if count == 500:
                    print('Stopped training due to early stopping on the validation loss')
                    break
            print('Best loss on validation set: {0}'.format(self.loss_best))

        train_writer.close()
        test_writer.close()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            # restore the best checkpoint saved by eval_once
            ckpt = tf.train.get_checkpoint_state(os.path.dirname(self.checkpoints_folder + '/checkpoint'))
            if ckpt and ckpt.model_checkpoint_path:
                print(ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)

            print('Loading testing snapshot matrix...')
            if (self.large):
                self.S_test = utils.read_large_data(self.test_mat)
            else:
                self.S_test = utils.read_data(self.test_mat)

            utils.scaling(self.S_test, S_max, S_min)

            if (self.zero_padding):
                # pad with the same width p used for the training snapshots
                self.S_test = utils.zero_pad(self.S_test, self.p)

            print('Loading testing parameters...')
            self.params_test = utils.read_params(self.test_params)

            self.test_once(sess, self.init)

            utils.inverse_scaling(self.U_h, S_max, S_min)
            utils.inverse_scaling(self.S_test, S_max, S_min)
            # error indicator epsilon_rel: time-averaged relative L2 error
            # between reconstructed and true snapshots, one value per test
            # instance (each instance spans N_t consecutive rows)
            n_test = self.S_test.shape[0] // self.N_t
            err = np.zeros((n_test, 1))
            for i in range(n_test):
                num = np.sqrt(np.mean(np.linalg.norm(self.S_test[i * self.N_t : (i + 1) * self.N_t] - self.U_h[i * self.N_t : (i + 1) * self.N_t], 2, axis = 1) ** 2))
                den = np.sqrt(np.mean(np.linalg.norm(self.S_test[i * self.N_t : (i + 1) * self.N_t], 2, axis = 1) ** 2))
                err[i] = num / den
            print('Error indicator epsilon_rel: {0}'.format(np.mean(err)))
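
# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): Net is abstract, so a subclass has
# to implement inference() and set self.enc, self.u_h and self.u_n before
# build()/train_all() can be called. DenseNet and every config value below are
# hypothetical placeholders, not part of this repository.
#
#     import tensorflow as tf
#     from Net import Net
#
#     class DenseNet(Net):
#         def inference(self):
#             n = 16   # latent dimension (illustrative)
#             flat = tf.reshape(self.input, shape = [-1, self.N_h])
#             self.enc = tf.layers.dense(flat, n, activation = tf.nn.elu, name = 'encoder')
#             self.u_h = tf.layers.dense(self.enc, self.N_h, name = 'decoder')
#             self.u_n = tf.layers.dense(self.params, n, activation = tf.nn.elu, name = 'param_map')
#
#     config = {'lr': 1e-4, 'batch_size': 20, 'n_data': 1000, 'N_h': 256,
#               'N_t': 10, 'n_params': 2, 'train_mat': 'S_train.mat',
#               'test_mat': 'S_test.mat', 'train_params': 'params_train.mat',
#               'test_params': 'params_test.mat', 'omega_h': 0.5, 'omega_n': 0.5,
#               'checkpoints_folder': 'checkpoints', 'graph_folder': 'graphs',
#               'large': False, 'zero_padding': False, 'p': 0, 'restart': False}
#
#     model = DenseNet(config)
#     model.build()
#     model.train_all(n_epochs = 500)
# ----------------------------------------------------------------------------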