Switch to unified view

a b/keras_CNN/keras_describe.py
1
"""
2
Describe a keras model (load it and create graph and description output)
3
4
keras_describe.py  model_file
5
6
"""
7
8
9
from __future__ import print_function
10
from keras.preprocessing.image import ImageDataGenerator
11
from keras.models import Sequential
12
from keras.layers import Dense, Dropout, Activation, Flatten
13
from keras.layers import Convolution2D, MaxPooling2D
14
from keras.optimizers import SGD, Adagrad, Adadelta, RMSprop
15
from keras.utils import np_utils
16
try:
17
    from keras.utils.visualize_util import plot
18
except ImportError:
19
    from keras.utils.vis_utils import plot_model as plot
20
from keras.models import load_model
21
from keras.models import model_from_json
22
from keras.callbacks import EarlyStopping, TensorBoard
23
import sys, random, pickle, numpy as np
24
import matplotlib.pyplot as plt
25
from mpl_toolkits.axes_grid1 import AxesGrid
26
from datetime import date
27
28
29
filename = sys.argv[1]
30
31
model = None
32
if filename[-3:3] in ['hd5', 'hdf', 'df5']:
33
    model = load_model(filename)
34
else:
35
    with open(filename, 'rU') as json_file:
36
        model = model_from_json(json_file.read())
37
38
print('Network Layout:')
39
model.summary()
40
41
plot_name = '.'.join(filename.split('.')[0:-1]) + ".png"
42
43
plot(model, to_file=plot_name, show_shapes=True)
44