[aedd99]: / dae_main_train.py

Download this file

48 lines (29 with data), 1.1 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from models import *
#from ImageExp import ImgExp
from ae_exp import AEExp
import numpy as np
def init_dae_exp(pre_load=None, regularizer_list=None, *, batch_size=16,
                 epochs=1, img_width=64, img_height=64, hor_flip=False,
                 initial_epoch=0, dset='UR-Filled'):
    '''
    Build a DAE model and wrap it in an AEExp experiment object.

    The training settings are keyword-only arguments whose defaults are the
    values this script originally hard-coded, so existing calls behave
    exactly as before while new callers can override them.

    Args:
        pre_load: Forwarded unchanged to AEExp as ``pre_load``.
        regularizer_list: Regularizer names forwarded to DAE (this script
            uses ['Dropout']; 'L1L2' is noted as an alternative). ``None``
            (the default) means an empty list. A fresh list is created per
            call so no state can leak between calls through a shared default.
        batch_size: Forwarded to AEExp. Defaults to 16.
        epochs: Forwarded to AEExp. Defaults to 1.
        img_width: Image width passed to both DAE and AEExp. Defaults to 64.
        img_height: Image height passed to both DAE and AEExp. Defaults to 64.
        hor_flip: Forwarded to AEExp as ``hor_flip``. Defaults to False.
        initial_epoch: Forwarded to AEExp. Defaults to 0.
        dset: Data set name forwarded to AEExp. Defaults to 'UR-Filled'.

    Returns:
        The constructed AEExp instance.
    '''
    # Avoid the shared-mutable-default pitfall of ``regularizer_list=[]``.
    if regularizer_list is None:
        regularizer_list = []

    # DAE returns a (model, model_name, model_type) triple.
    autoencoder, model_name, model_type = DAE(
        img_width=img_width,
        img_height=img_height,
        regularizer_list=regularizer_list,
    )

    dae_exp = AEExp(
        model=autoencoder,
        img_width=img_width,
        img_height=img_height,
        model_name=model_name,
        model_type=model_type,
        pre_load=pre_load,
        initial_epoch=initial_epoch,
        epochs=epochs,
        batch_size=batch_size,
        dset=dset,
        hor_flip=hor_flip,
    )
    return dae_exp
if __name__ == "__main__":
    # Regularizer configurations to train; one experiment is run per entry.
    # 'L1L2' can be used as well.
    configs = [['Dropout']]
    for regs in configs:
        exp = init_dae_exp(regularizer_list=regs)
        # raw=False is passed straight through to AEExp.set_train_data; its
        # meaning is defined there (not visible in this file).
        exp.set_train_data(raw=False)
        train_shape = exp.train_data.shape
        print(train_shape)
        exp.train()