|
a |
|
b/keras_CNN/keras_describe.py |
|
|
1 |
""" |
|
|
2 |
Describe a keras model (load it and create graph and description output) |
|
|
3 |
|
|
|
4 |
keras_describe.py model_file |
|
|
5 |
|
|
|
6 |
""" |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
from __future__ import print_function |
|
|
10 |
from keras.preprocessing.image import ImageDataGenerator |
|
|
11 |
from keras.models import Sequential |
|
|
12 |
from keras.layers import Dense, Dropout, Activation, Flatten |
|
|
13 |
from keras.layers import Convolution2D, MaxPooling2D |
|
|
14 |
from keras.optimizers import SGD, Adagrad, Adadelta, RMSprop |
|
|
15 |
from keras.utils import np_utils |
|
|
16 |
try: |
|
|
17 |
from keras.utils.visualize_util import plot |
|
|
18 |
except ImportError: |
|
|
19 |
from keras.utils.vis_utils import plot_model as plot |
|
|
20 |
from keras.models import load_model |
|
|
21 |
from keras.models import model_from_json |
|
|
22 |
from keras.callbacks import EarlyStopping, TensorBoard |
|
|
23 |
import sys, random, pickle, numpy as np |
|
|
24 |
import matplotlib.pyplot as plt |
|
|
25 |
from mpl_toolkits.axes_grid1 import AxesGrid |
|
|
26 |
from datetime import date |
|
|
27 |
|
|
|
28 |
|
|
|
29 |
filename = sys.argv[1] |
|
|
30 |
|
|
|
31 |
model = None |
|
|
32 |
if filename[-3:3] in ['hd5', 'hdf', 'df5']: |
|
|
33 |
model = load_model(filename) |
|
|
34 |
else: |
|
|
35 |
with open(filename, 'rU') as json_file: |
|
|
36 |
model = model_from_json(json_file.read()) |
|
|
37 |
|
|
|
38 |
print('Network Layout:') |
|
|
39 |
model.summary() |
|
|
40 |
|
|
|
41 |
plot_name = '.'.join(filename.split('.')[0:-1]) + ".png" |
|
|
42 |
|
|
|
43 |
plot(model, to_file=plot_name, show_shapes=True) |
|
|
44 |
|