Diff of /Visualization.py [000000] .. [271336]

Switch to unified view

a b/Visualization.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Thu Apr 06 11:35:51 2017
4
5
@author: Shreyas_V
6
"""
7
8
'''
9
Visualizes the weights of a keras model's layers, keeping the dimensions intact'''
10
11
import theano
12
import numpy as np
13
import numpy.ma as ma
14
import pylab
15
16
def make_mosaic(imgs, nrows, ncols, border=1):
17
    """
18
    Given a set of images with all the same shape, makes a
19
    mosaic with nrows and ncols
20
    """
21
    nimgs = imgs.shape[0]
22
    imshape = imgs.shape[1:]
23
    
24
    mosaic = ma.masked_all((nrows * imshape[0] + (nrows - 1) * border,
25
                            ncols * imshape[1] + (ncols - 1) * border),
26
                            dtype=np.float32)
27
    
28
    paddedh = imshape[0] + border
29
    paddedw = imshape[1] + border
30
    for i in xrange(nimgs):
31
        row = int(np.floor(i / ncols))
32
        col = i % ncols
33
        
34
        mosaic[row * paddedh:row * paddedh + imshape[0],
35
               col * paddedw:col * paddedw + imshape[1]] = imgs[i]
36
    return mosaic
37
38
39
def visualize_layer_output(model, layer_num, inp, nr=6, nc=6):
40
    '''
41
    Given a keras model, a layer number, a particular input for the model,
42
    shows a mosaic image of all filters in that layer with the image dimension of nr x nc,
43
    activated by the model's input given to the first layer
44
    '''
45
    out = theano.function([model.get_input(train=False)], model.layers[layer_num].get_output(train=False))
46
    op = out(inp)
47
    op = np.squeeze(op)
48
    pylab.imshow(make_mosaic(op, nr, nc), 'gray')
49
    
50
def visualize_layer(model, layer_num, nr=6, nc=6):
51
    '''
52
    Given a model and its layer number,
53
    shows a mosaic image of all the filter weights, 
54
    with the image dimension being nr x nc
55
    '''
56
    W = model.layers[layer_num].W.get_value(borrow=True)
57
    W = np.squeeze(W)
58
    pylab.imshow(make_mosaic(W, nr, nc))