|
a |
|
b/utils/load_plot.py |
|
|
1 |
import matplotlib.pyplot as plt |
|
|
2 |
from PIL import Image |
|
|
3 |
from subprocess import check_output |
|
|
4 |
from random import sample |
|
|
5 |
from os.path import join |
|
|
6 |
from matplotlib.pyplot import imsave |
|
|
7 |
from torchvision import transforms |
|
|
8 |
|
|
|
9 |
# Imagenes de prueba y transformaciones |
|
|
10 |
img_path = './data/img/' |
|
|
11 |
all_transforms = image = transforms.Compose([ |
|
|
12 |
transforms.Resize((224, 224)), # las imagenes originales son de tamaño 512x512 |
|
|
13 |
transforms.ToTensor(), # convertir a torch.Tensor |
|
|
14 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # normalización |
|
|
15 |
]) |
|
|
16 |
|
|
|
17 |
|
|
|
18 |
def load_random_samples(n): |
|
|
19 |
""" |
|
|
20 |
Arguments |
|
|
21 |
--------- |
|
|
22 |
n: numero de ejemplos |
|
|
23 |
|
|
|
24 |
Returns |
|
|
25 |
------- |
|
|
26 |
imgs: lista de torch.Tensor con las imagenes |
|
|
27 |
""" |
|
|
28 |
|
|
|
29 |
img_names = check_output(['ls', img_path]).decode('utf8').splitlines() |
|
|
30 |
# si n > nmax, devolver n_max |
|
|
31 |
selected_images = sample(img_names, min(n, len(img_names))) |
|
|
32 |
samples_path = [join(img_path, img) for img in selected_images] |
|
|
33 |
|
|
|
34 |
imgs = [] |
|
|
35 |
for sample_path in samples_path: |
|
|
36 |
x = Image.open(sample_path).convert("RGB") # leerlas con 3 canales |
|
|
37 |
x = all_transforms(x) # aplicar las transformaciones |
|
|
38 |
imgs.append(x) |
|
|
39 |
|
|
|
40 |
return imgs |
|
|
41 |
|
|
|
42 |
def plot_images(rows, cols, images): |
|
|
43 |
""" |
|
|
44 |
Arguments: |
|
|
45 |
---------- |
|
|
46 |
rows: número de filas |
|
|
47 |
cols: número de columnas |
|
|
48 |
images: lista de imágenes ( de tipo torch.Tensor) |
|
|
49 |
|
|
|
50 |
Returns: |
|
|
51 |
-------- |
|
|
52 |
""" |
|
|
53 |
|
|
|
54 |
fig, axs = plt.subplots(rows, cols, sharex='col', sharey='row', |
|
|
55 |
gridspec_kw={'hspace': 0, 'wspace': 0}) |
|
|
56 |
|
|
|
57 |
for i in range(rows): |
|
|
58 |
for j in range(cols): |
|
|
59 |
try: |
|
|
60 |
axs[i, j].imshow(images[i*cols + j][0, ...], cmap='gray') |
|
|
61 |
except IndexError: |
|
|
62 |
pass |
|
|
63 |
|
|
|
64 |
fig.show() |