[a9a70b]: / utils / visualize.py

Download this file

110 lines (91 with data), 3.8 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from .misc import *
# Names exported by ``from <this module> import *``.
# NOTE(review): ``gauss`` and ``colorize`` are defined in this module but are
# not listed here; confirm that keeping them out of the star-export is
# intentional.
__all__ = [
    'make_image',
    'show_batch',
    'show_mask',
    'show_mask_single',
]
# functions to show an image
def make_image(img, mean=(0,0,0), std=(1,1,1)):
for i in range(0, 3):
img[i] = img[i] * std[i] + mean[i] # unnormalize
npimg = img.numpy()
return np.transpose(npimg, (1, 2, 0))
def gauss(x, a, b, c):
    """Evaluate the Gaussian bump ``a * exp(-(x - b)**2 / (2 * c * c))`` element-wise.

    Args:
        x: input tensor.
        a: peak height.
        b: position of the peak.
        c: width of the peak (plays the role of a standard deviation).

    Returns:
        Tensor with the same shape as ``x``.
    """
    squared_offset = (x - b) ** 2
    return a * torch.exp(-(squared_offset / (2 * c * c)))
def colorize(x):
    """Convert a one-channel grayscale image (or batch) into an RGB heatmap.

    Each output channel is built from Gaussian bumps of the gray level (see
    ``gauss``):
    - blue peaks at 0.2;
    - green peaks at 0.5;
    - red is the sum of bumps at 0.6 and 0.8.

    Low values therefore map towards blue and high values towards red. Every
    channel is clipped to at most 1.

    Args:
        x: tensor of shape (H, W), (1, H, W) or (N, 1, H, W). Only the first
            channel is read. Gray levels are presumably in [0, 1], where the
            bump centres lie; confirm with callers.

    Returns:
        Tensor of shape (3, H, W) for 2-D/3-D input, or (N, 3, H, W) for 4-D
        input, on the same device as ``x``.

    Raises:
        ValueError: if ``x`` is not 2-, 3- or 4-dimensional.
    """
    if x.dim() == 2:
        # Add the channel axis. ``torch.unsqueeze`` has no ``out=`` argument,
        # so this cannot be done in place on ``x``.
        x = x.unsqueeze(0)
    if x.dim() == 3:
        gray, channel_dim = x[0], 0
    elif x.dim() == 4:
        # Drop the channel axis so each colour plane is (N, H, W), then stack
        # the planes back along dim 1.
        gray, channel_dim = x[:, 0], 1
    else:
        raise ValueError(
            'colorize expects a 2-, 3- or 4-D tensor, got %d-D' % x.dim())
    red = gauss(gray, .5, .6, .2) + gauss(gray, 1, .8, .3)
    green = gauss(gray, 1, .5, .3)
    blue = gauss(gray, 1, .2, .3)
    # The two red bumps can sum above 1, so clip for every input rank.
    return torch.stack([red, green, blue], dim=channel_dim).clamp(max=1)
def show_batch(images, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Tile ``images`` into one grid, un-normalize it and display it.

    The grid comes from ``torchvision.utils.make_grid``. It is un-normalized
    per channel as ``x * Std[c] + Mean[c]`` by ``make_image``, drawn with
    ``plt.imshow`` and shown with ``plt.show()``.

    NOTE(review): the defaults Mean=(2, 2, 2) / Std=(0.5, 0.5, 0.5) are
    unusual; confirm they match the normalization used by callers.
    """
    grid = torchvision.utils.make_grid(images)
    picture = make_image(grid, Mean, Std)
    plt.imshow(picture)
    plt.show()
def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot a batch of images above a single row of mask overlays.

    This is ``show_mask(images, [mask], Mean, Std)``. It gives a 2x1 subplot
    layout:
    - top row: the un-normalized image grid;
    - bottom row: the images blended 30/70 with the upsampled ``mask``.

    Delegating keeps the two functions from drifting apart. It also gives
    ``mask`` the same ``.cpu()`` conversion that ``show_mask`` applies to each
    of its masks. ``plt.show()`` is not called.

    Args:
        images: batch tensor indexed as (N, C, H, W).
        mask: mask tensor indexed as (N, ?, h, w). After upsampling it must be
            expandable to the shape of ``images``, so it is presumably
            single-channel; confirm with callers.
        Mean, Std: per-channel statistics used to undo the input normalization.
    """
    show_mask(images, [mask], Mean, Std)
def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5, 0.5, 0.5)):
    """Plot a batch of images followed by one overlay row per mask.

    The figure has ``1 + len(masklist)`` subplot rows:
    - row 1: ``images`` tiled into a grid and un-normalized with ``Mean``/``Std``;
    - each later row: the same un-normalized images blended 30/70 with one
      mask, after the mask is rescaled by the ratio of image height to mask
      height.

    ``plt.show()`` is not called.

    Args:
        images: batch tensor indexed as (N, C, H, W). The first three channels
            are un-normalized as ``x * Std[c] + Mean[c]``.
        masklist: sequence of mask tensors indexed as (N, ?, h, w). After
            upsampling, each must be expandable to the shape of ``images``, so
            masks are presumably single-channel; confirm with callers.
        Mean, Std: per-channel statistics used to undo the input normalization.
            NOTE(review): the defaults Mean=(2, 2, 2) / Std=(0.5, 0.5, 0.5) are
            unusual; confirm they match the callers' normalization.
    """
    # The masks are moved to the CPU below. The images must live there too,
    # both for the blend and for ``.numpy()`` inside make_image.
    images = images.detach().cpu()
    im_size = images.size(2)
    rows = 1 + len(masklist)

    # Un-normalized copy of the batch that every mask is blended onto.
    im_data = images.clone()
    for c in range(3):
        im_data[:, c, :, :] = im_data[:, c, :, :] * Std[c] + Mean[c]

    plt.subplot(rows, 1, 1)
    plt.imshow(make_image(torchvision.utils.make_grid(images), Mean, Std))
    plt.axis('off')

    for row, mask in enumerate(masklist, start=2):
        mask = mask.detach().cpu()
        mask_size = mask.size(2)
        # NOTE(review): ``upsampling`` is not defined in this file. It is
        # presumably provided by ``from .misc import *``; confirm it exists
        # there and accepts ``scale_factor``.
        mask = upsampling(mask, scale_factor=im_size / mask_size)
        overlay = 0.3 * im_data + 0.7 * mask.expand_as(im_data)
        # im_data is already un-normalized, so make_image's identity defaults
        # are used here.
        plt.subplot(rows, 1, row)
        plt.imshow(make_image(torchvision.utils.make_grid(overlay)))
        plt.axis('off')
# Manual smoke test for `colorize`: render the heatmap of an all-zero
# 1x3x3 grayscale image.
# x = torch.zeros(1, 3, 3)
# out = colorize(x)
# out_im = make_image(out)
# plt.imshow(out_im)
# plt.show()