Switch to unified view

a b/hands-on-session-2/utils.py
1
import matplotlib.pyplot as plt
2
import numpy as np
3
4
5
def show_inline(image, title=''):
6
    f, ax = plt.subplots(1, 1, figsize=(10,10))
7
    ax.grid(False)
8
    ax.set_xticks([])
9
    ax.set_yticks([])
10
    ax.imshow(image)
11
    ax.set_title(title)
12
    plt.show()
13
14
15
def get_patches(out, k, patch_size=36, random=False):
16
    image = out['image']
17
    graph = out['graph']
18
19
    if random:
20
        importance_scores = np.random.uniform(size=out['importance_scores'].size)
21
    else:
22
        importance_scores = out['importance_scores']
23
24
    image = np.pad(image,
25
                   ((patch_size, patch_size), (patch_size, patch_size), (0, 0)),
26
                   mode="constant",
27
                   constant_values=255)
28
    important_indices = (-importance_scores).argsort()[:k]
29
    important_centroids = graph.ndata['centroid'][important_indices, :].cpu().numpy().astype(int)
30
31
    patches = []
32
    for i in range(k):
33
        x, y = important_centroids[i] + patch_size
34
        patch = image[y - int(patch_size / 2): y + int(patch_size / 2),
35
                x - int(patch_size / 2): x + int(patch_size / 2),
36
                :]
37
        patches.append(patch)
38
    return patches
39
40
41
def plot_patches(patches, ncol=5, patch_size=36):
42
    nrow = len(patches) // ncol
43
    patches = np.stack(patches, axis=0)
44
    patches = np.reshape(patches, newshape=(nrow, ncol, patch_size, patch_size, 3))
45
46
    for i in range(nrow):
47
        for j in range(ncol):
48
            if j == 0:
49
                grid_ = patches[i, j]
50
            else:
51
                grid_ = np.hstack((grid_, patches[i, j]))
52
        if i == 0:
53
            grid = grid_
54
        else:
55
            grid = np.vstack((grid, grid_))
56
57
    show_inline(grid)