|
a |
|
b/hands-on-session-2/utils.py |
|
|
1 |
import matplotlib.pyplot as plt |
|
|
2 |
import numpy as np |
|
|
3 |
|
|
|
4 |
|
|
|
5 |
def show_inline(image, title=''): |
|
|
6 |
f, ax = plt.subplots(1, 1, figsize=(10,10)) |
|
|
7 |
ax.grid(False) |
|
|
8 |
ax.set_xticks([]) |
|
|
9 |
ax.set_yticks([]) |
|
|
10 |
ax.imshow(image) |
|
|
11 |
ax.set_title(title) |
|
|
12 |
plt.show() |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def get_patches(out, k, patch_size=36, random=False): |
|
|
16 |
image = out['image'] |
|
|
17 |
graph = out['graph'] |
|
|
18 |
|
|
|
19 |
if random: |
|
|
20 |
importance_scores = np.random.uniform(size=out['importance_scores'].size) |
|
|
21 |
else: |
|
|
22 |
importance_scores = out['importance_scores'] |
|
|
23 |
|
|
|
24 |
image = np.pad(image, |
|
|
25 |
((patch_size, patch_size), (patch_size, patch_size), (0, 0)), |
|
|
26 |
mode="constant", |
|
|
27 |
constant_values=255) |
|
|
28 |
important_indices = (-importance_scores).argsort()[:k] |
|
|
29 |
important_centroids = graph.ndata['centroid'][important_indices, :].cpu().numpy().astype(int) |
|
|
30 |
|
|
|
31 |
patches = [] |
|
|
32 |
for i in range(k): |
|
|
33 |
x, y = important_centroids[i] + patch_size |
|
|
34 |
patch = image[y - int(patch_size / 2): y + int(patch_size / 2), |
|
|
35 |
x - int(patch_size / 2): x + int(patch_size / 2), |
|
|
36 |
:] |
|
|
37 |
patches.append(patch) |
|
|
38 |
return patches |
|
|
39 |
|
|
|
40 |
|
|
|
41 |
def plot_patches(patches, ncol=5, patch_size=36): |
|
|
42 |
nrow = len(patches) // ncol |
|
|
43 |
patches = np.stack(patches, axis=0) |
|
|
44 |
patches = np.reshape(patches, newshape=(nrow, ncol, patch_size, patch_size, 3)) |
|
|
45 |
|
|
|
46 |
for i in range(nrow): |
|
|
47 |
for j in range(ncol): |
|
|
48 |
if j == 0: |
|
|
49 |
grid_ = patches[i, j] |
|
|
50 |
else: |
|
|
51 |
grid_ = np.hstack((grid_, patches[i, j])) |
|
|
52 |
if i == 0: |
|
|
53 |
grid = grid_ |
|
|
54 |
else: |
|
|
55 |
grid = np.vstack((grid, grid_)) |
|
|
56 |
|
|
|
57 |
show_inline(grid) |