"""
Net.py

Base class for the networks trained in this repository: it sets up the data
pipeline, the loss, the optimizer and the training/validation/testing loops,
while the forward pass (inference) must be provided by subclasses.

Stefania Fresca, MOX Laboratory, Politecnico di Milano
April 2019
"""

import tensorflow as tf
import numpy as np
import scipy.io as sio
import time
import os

import utils
|
|
# Fix the NumPy seed for reproducible shuffling and train/validation splits.
seed = 374
np.random.seed(seed)
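
# Note: only NumPy's RNG is seeded above. For fully reproducible weight
# initialization one could also call tf.set_random_seed(seed) (TF 1.x API).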
|
|
class Net:
    def __init__(self, config):
        # Optimization hyperparameters.
        self.lr = config['lr']
        self.batch_size = config['batch_size']
        self.g_step = tf.Variable(0, dtype = tf.int32, trainable = False, name = 'global_step')

        # Dataset sizes: N_h is the full-order dimension of each snapshot,
        # N_t the number of time instances per parameter instance; 80% of
        # the data is used for training, the rest for validation.
        self.n_data = config['n_data']
        self.n_train = int(0.8 * self.n_data)
        self.n_params = config['n_params']  # number of input parameters; required by get_data() (key name assumed)
        self.N_h = config['N_h']
        self.N_t = config['N_t']

        # Snapshot matrices and parameter files.
        self.train_mat = config['train_mat']
        self.test_mat = config['test_mat']
        self.train_params = config['train_params']
        self.test_params = config['test_params']

        # Weights of the two loss terms (see loss()).
        self.omega_h = config['omega_h']
        self.omega_n = config['omega_n']

        self.checkpoints_folder = config['checkpoints_folder']
        self.graph_folder = config['graph_folder']
        self.large = config['large']
        self.zero_padding = config['zero_padding']
        self.p = config['p']
        self.restart = config['restart']
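
    # Example of the expected config dictionary (key names taken from __init__,
    # values purely illustrative):
    #     config = {'lr' : 1e-4, 'batch_size' : 20, 'n_data' : 10000,
    #               'n_params' : 2, 'N_h' : 256 * 256, 'N_t' : 100,
    #               'train_mat' : 'S_train.mat', 'test_mat' : 'S_test.mat',
    #               'train_params' : 'params_train.mat', 'test_params' : 'params_test.mat',
    #               'omega_h' : 0.5, 'omega_n' : 0.5,
    #               'checkpoints_folder' : 'checkpoints', 'graph_folder' : 'graphs',
    #               'large' : False, 'zero_padding' : False, 'p' : 0, 'restart' : False}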
|
|
    def get_data(self):
        """Build the tf.data input pipeline on top of two placeholders."""
        with tf.name_scope('data'):
            self.X = tf.placeholder(tf.float32, shape = [None, self.N_h])
            self.Y = tf.placeholder(tf.float32, shape = [None, self.n_params])

            dataset = tf.data.Dataset.from_tensor_slices((self.X, self.Y))
            dataset = dataset.shuffle(self.n_data)
            dataset = dataset.batch(self.batch_size)

            # Initializable iterator: running self.init with a new feed_dict
            # switches the underlying data (training, validation or testing set).
            iterator = dataset.make_initializable_iterator()
            self.init = iterator.initializer

            # Reshape each snapshot into a square single-channel image
            # (N_h is assumed to be a perfect square).
            inputs, self.params = iterator.get_next()
            self.input = tf.reshape(inputs, shape = [-1, int(np.sqrt(self.N_h)), int(np.sqrt(self.N_h)), 1])
|
|
    def inference(self):
        raise NotImplementedError("Must be overridden with a proper definition of the forward pass")
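
    # A subclass's inference() is expected to define at least the following
    # attributes (names inferred from their use in loss(), build() and
    # test_once() below):
    #   self.enc - encoding of self.input in the latent space
    #   self.u_n - latent representation predicted from self.params
    #   self.u_h - reconstructed high-dimensional output, shape [None, N_h]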
|
|
    def loss(self, u_h, u_n):
        """Weighted sum of the reconstruction and latent-space losses:

        loss = omega_h * mean(||u - u_h||^2) + omega_n * mean(||enc - u_n||^2)
        """
        with tf.name_scope('loss'):
            output = tf.reshape(self.input, shape = [-1, self.N_h])
            self.loss_h = self.omega_h * tf.reduce_mean(tf.reduce_sum(tf.pow(output - u_h, 2), axis = 1))
            self.loss_n = self.omega_n * tf.reduce_mean(tf.reduce_sum(tf.pow(self.enc - u_n, 2), axis = 1))
            # N.B. this assignment shadows the loss() method: after build() has
            # run, self.loss is the scalar loss tensor.
            self.loss = self.loss_h + self.loss_n
|
|
    def optimize(self):
        self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.loss, global_step = self.g_step)
|
|
    def summary(self):
        with tf.name_scope('summaries'):
            # As with self.loss, this shadows the method after build() has run.
            self.summary = tf.summary.scalar('loss', self.loss)
|
|
    def build(self):
        """Assemble the full graph: data pipeline, forward pass, loss, optimizer, summaries."""
        self.get_data()
        self.inference()
        self.loss(self.u_h, self.u_n)
        self.optimize()
        self.summary()
|
|
    def train_one_epoch(self, sess, init, writer, epoch, step):
        start_time = time.time()
        # Re-initialize the iterator on the training set.
        sess.run(init, feed_dict = {self.X : self.S_train, self.Y : self.params_train})
        total_loss_h = 0
        total_loss_n = 0
        total_loss = 0
        n_batches = 0
        print('------------ TRAINING -------------', flush = True)
        try:
            # Iterate until the dataset iterator is exhausted.
            while True:
                _, l_h, l_n, l, summary = sess.run([self.opt, self.loss_h, self.loss_n, self.loss, self.summary])
                writer.add_summary(summary, global_step = step)
                step += 1
                total_loss_h += l_h
                total_loss_n += l_n
                total_loss += l
                n_batches += 1
        except tf.errors.OutOfRangeError:
            pass
        print('Average loss_h at epoch {0} on training set: {1}'.format(epoch, total_loss_h / n_batches))
        print('Average loss_n at epoch {0} on training set: {1}'.format(epoch, total_loss_n / n_batches))
        print('Average loss at epoch {0} on training set: {1}'.format(epoch, total_loss / n_batches))
        print('Took: {0} seconds'.format(time.time() - start_time))
        return step
|
|
    def eval_once(self, sess, saver, init, writer, epoch, step):
        start_time = time.time()
        sess.run(init, feed_dict = {self.X : self.S_val, self.Y : self.params_val})
        total_loss_h = 0
        total_loss_n = 0
        total_loss = 0
        n_batches = 0
        print('------------ VALIDATION ------------')
        try:
            while True:
                l_h, l_n, l, summary = sess.run([self.loss_h, self.loss_n, self.loss, self.summary])
                writer.add_summary(summary, global_step = step)
                total_loss_h += l_h
                total_loss_n += l_n
                total_loss += l
                n_batches += 1
        except tf.errors.OutOfRangeError:
            pass
        total_loss_mean = total_loss / n_batches
        # Checkpoint the model whenever the validation loss improves.
        if total_loss_mean < self.loss_best:
            saver.save(sess, self.checkpoints_folder + '/Net', step)
        print('Average loss_h at epoch {0} on validation set: {1}'.format(epoch, total_loss_h / n_batches))
        print('Average loss_n at epoch {0} on validation set: {1}'.format(epoch, total_loss_n / n_batches))
        print('Average loss at epoch {0} on validation set: {1}'.format(epoch, total_loss_mean))
        print('Took: {0} seconds'.format(time.time() - start_time))
        return total_loss_mean
|
|
    def test_once(self, sess, init):
        start_time = time.time()
        sess.run(init, feed_dict = {self.X : self.S_test, self.Y : self.params_test})
        total_loss_h = 0
        total_loss_n = 0
        total_loss = 0
        n_batches = 0
        # Collect the reconstructed snapshots batch by batch.
        self.U_h = np.zeros(self.S_test.shape)
        print('------------ TESTING ------------')
        try:
            while True:
                l_h, l_n, l, u_h = sess.run([self.loss_h, self.loss_n, self.loss, self.u_h])
                self.U_h[self.batch_size * n_batches : self.batch_size * (n_batches + 1)] = u_h
                total_loss_h += l_h
                total_loss_n += l_n
                total_loss += l
                n_batches += 1
        except tf.errors.OutOfRangeError:
            pass
        print('Average loss_h on testing set: {0}'.format(total_loss_h / n_batches))
        print('Average loss_n on testing set: {0}'.format(total_loss_n / n_batches))
        print('Average loss on testing set: {0}'.format(total_loss / n_batches))
        print('Took: {0} seconds'.format(time.time() - start_time))
|
|
    # @profile  # uncomment to enable line-by-line memory profiling
    def train_all(self, n_epochs):
        if (not self.restart):
            utils.safe_mkdir(self.checkpoints_folder)
        saver = tf.train.Saver()
        # N.B. test_writer actually receives the validation summaries (see eval_once()).
        train_writer = tf.summary.FileWriter('./' + self.graph_folder + '/train', tf.get_default_graph())
        test_writer = tf.summary.FileWriter('./' + self.graph_folder + '/test', tf.get_default_graph())
|
|
        print('Loading snapshot matrix...')
        if (self.large):
            S = utils.read_large_data(self.train_mat)
        else:
            S = utils.read_data(self.train_mat)

        # Shuffle the snapshots, then min-max scale them in place using
        # statistics computed on the training portion only.
        idxs = np.random.permutation(S.shape[0])
        S = S[idxs]
        S_max, S_min = utils.max_min(S, self.n_train)
        utils.scaling(S, S_max, S_min)
|
|
        if (self.zero_padding):
            S = utils.zero_pad(S, self.p)

        self.S_train, self.S_val = S[:self.n_train, :], S[self.n_train:, :]
        del S

        print('Loading parameters...')
        params = utils.read_params(self.train_params)

        # Apply the same permutation used for the snapshots.
        params = params[idxs]

        self.params_train, self.params_val = params[:self.n_train], params[self.n_train:]
        del params

        # Initial threshold for checkpointing: a model is saved only once the
        # validation loss drops below this value (see eval_once()).
        self.loss_best = 1
        count = 0
|
|
        with tf.Session(config = tf.ConfigProto(gpu_options = tf.GPUOptions(allow_growth = True))) as sess:
            sess.run(tf.global_variables_initializer())

            # When restarting, resume from the latest checkpoint (if any).
            if (self.restart):
                ckpt = tf.train.get_checkpoint_state(os.path.dirname(self.checkpoints_folder + '/checkpoint'))
                if ckpt and ckpt.model_checkpoint_path:
                    print(ckpt.model_checkpoint_path)
                    saver.restore(sess, ckpt.model_checkpoint_path)

            step = self.g_step.eval()

            for epoch in range(n_epochs):
                step = self.train_one_epoch(sess, self.init, train_writer, epoch, step)
                total_loss_mean = self.eval_once(sess, saver, self.init, test_writer, epoch, step)
                if total_loss_mean < self.loss_best:
                    self.loss_best = total_loss_mean
                    count = 0
                else:
                    count += 1
                # Early stopping: give up after 500 epochs without improvement
                # on the validation set.
                if count == 500:
                    print('Training stopped early: no improvement on the validation set for 500 epochs')
                    break
            print('Best loss on validation set: {0}'.format(self.loss_best))

            train_writer.close()
            test_writer.close()
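
        # The summaries written by the FileWriters above can be inspected with
        # TensorBoard, e.g.: tensorboard --logdir=<graph_folder>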
|
|
        # Reload the best checkpoint and evaluate on the testing set.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            ckpt = tf.train.get_checkpoint_state(os.path.dirname(self.checkpoints_folder + '/checkpoint'))
            if ckpt and ckpt.model_checkpoint_path:
                print(ckpt.model_checkpoint_path)
                saver.restore(sess, ckpt.model_checkpoint_path)

            print('Loading testing snapshot matrix...')
            if (self.large):
                self.S_test = utils.read_large_data(self.test_mat)
            else:
                self.S_test = utils.read_data(self.test_mat)

            utils.scaling(self.S_test, S_max, S_min)

            if (self.zero_padding):
                self.S_test = utils.zero_pad(self.S_test, self.p)

            print('Loading testing parameters...')
            self.params_test = utils.read_params(self.test_params)

            self.test_once(sess, self.init)
|
|
246 |
|
|
|
247 |
utils.inverse_scaling(self.U_h, S_max, S_min) |
|
|
248 |
utils.inverse_scaling(self.S_test, S_max, S_min) |
|
|
249 |
n_test = self.S_test.shape[0] // self.N_t |
|
|
250 |
err = np.zeros((n_test, 1)) |
|
|
251 |
for i in range(n_test): |
|
|
252 |
num = np.sqrt(np.mean(np.linalg.norm(self.S_test[i * self.N_t : (i + 1) * self.N_t] - self.U_h[i * self.N_t : (i + 1) * self.N_t], 2, axis = 1) ** 2)) |
|
|
253 |
den = np.sqrt(np.mean(np.linalg.norm(self.S_test[i * self.N_t : (i + 1) * self.N_t], 2, axis = 1) ** 2)) |
|
|
254 |
err[i] = num / den |
|
|
255 |
print('Error indicator epsilon_rel: {0}'.format(np.mean(err))) |