--- a +++ b/Classifier/Classes/visualizer.py @@ -0,0 +1,43 @@ + +import numpy as np +import cv2 +import matplotlib.pyplot as plt +import torch +import torchvision + +mean_nums = [0.485, 0.456, 0.406] +std_nums = [0.229, 0.224, 0.225] + + +def load_image(img_path, resize=True): + img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB) + + if resize: + img = cv2.resize(img, (64, 64), interpolation = cv2.INTER_AREA) + + return img + +def show_image(img_path): + img = load_image(img_path) + plt.imshow(img) + plt.axis('off') + +def show_grid_image(image_paths): + images = [load_image(img) for img in image_paths] + images = torch.as_tensor(images) + images = images.permute(0, 3, 1, 2) + grid_img = torchvision.utils.make_grid(images, nrow=11) + plt.figure(figsize=(24, 12)) + plt.imshow(grid_img.permute(1, 2, 0)) + plt.axis('off'); + +def image_show(inp, title=None): + inp = inp.numpy().transpose((1, 2, 0)) + mean = np.array([mean_nums]) + std = np.array([std_nums]) + inp = std * inp + mean + inp = np.clip(inp, 0, 1) + plt.imshow(inp) + if title is not None: + plt.title(title) + plt.axis('off') \ No newline at end of file