Download this file

15 lines (13 with data), 427 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
import numpy as np
import matplotlib.pyplot as plt
def imshow(inp, title=None):
"""Imshow for Tensor."""
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(0.001) # pause a bit so that plots are updated