[4807fa]: / dl / utils / visualization / visualization.py

Download this file

128 lines (112 with data), 4.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA, IncrementalPCA
import torch
def plot_cdf(y, title=''):
    """Plot the empirical CDF of the values in ``y`` as small red dots.

    Args:
        y: a ``torch.Tensor`` or anything ``np.array`` accepts; it is flattened
            to one dimension before sorting.
        title: optional figure title; nothing is set when it is the empty string.
    """
    if isinstance(y, torch.Tensor):
        flat = y.contiguous().view(-1)
        ordered, _ = flat.sort(dim=0)
        y_sorted = ordered.detach().cpu().numpy()
    else:
        # np.array copies the input, so the caller's data is left untouched.
        y_sorted = np.sort(np.array(y).reshape(-1))
    # The k-th smallest value is paired with an evenly spaced level in [0, 1].
    levels = np.linspace(0, 1, len(y_sorted))
    plt.plot(y_sorted, levels, 'ro', markersize=0.5)
    if title != '':
        plt.title(title)
    plt.show()
def pca(x, n_components=2, verbose=False):
    r"""Project ``x`` onto its leading principal components (2-D by default).

    Args:
        x: samples to project. A ``torch.Tensor`` is detached, moved to the CPU
            and copied into a NumPy array first; any other input is handed to
            scikit-learn unchanged.
        n_components: number of principal components to keep.
        verbose: if True, print the explained variances and the noise variance
            of the fitted model and plot them as red dots.

    Returns:
        The transformed samples returned by the estimator's ``fit_transform``.
    """
    # Inputs with more than 10000 rows use IncrementalPCA instead of PCA.
    if len(x) > 10000:
        model = IncrementalPCA(n_components=n_components)
    else:
        model = PCA(n_components=n_components)
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy().copy()
    # Fit exactly once: the fitted attributes reported below then describe the
    # same fit that produced the returned projection, and the cost is halved
    # compared with calling fit() followed by fit_transform().
    projected = model.fit_transform(x)
    if verbose:
        print(model.explained_variance_, model.noise_variance_)
        plt.title('explained_variance')
        plt.plot(model.explained_variance_.tolist() + [model.noise_variance_], 'ro')
        plt.show()
    return projected
def plot_scatter(y_=None, model_=None, x_=None, title='', labels=None, colors=None, size=15,
                 marker_size=20, folder='.', save_fig=False):
    r"""Draw a 2-D scatter plot of ``y_``, or of ``model_(x_)`` when ``y_`` is None.

    Args:
        y_: points to plot, one row per point; a ``torch.Tensor`` or array-like.
            When it has more than two columns it is reduced with ``pca`` first.
        model_: callable applied to ``x_.contiguous()`` to obtain ``y_``; only
            used (and then required) when ``y_`` is None.
        x_: input for ``model_``; only used (and then required) when ``y_`` is None.
        title: figure title when the plot is shown; file name stem when saved.
        labels: optional integer label per point, used to pick a colour from a
            fixed palette when ``colors`` is not given.
        colors: optional explicit colour per point; takes precedence over ``labels``.
        size: side length of the square figure, passed to ``figsize``.
        marker_size: marker size, passed to ``plt.scatter`` as ``s``.
        folder: directory for the saved image; created if missing.
        save_fig: if True, save ``<folder>/<title>.png`` instead of showing the plot.

    Raises:
        AssertionError: if ``y_`` is None and ``model_`` or ``x_`` is missing, or
            if ``colors`` / ``labels`` do not have one entry per point.
    """
    if y_ is None:
        assert model_ is not None and x_ is not None
        y_ = model_(x_.contiguous())
    if colors is not None:
        assert len(colors) == len(y_)
    elif labels is not None:
        assert len(y_) == len(labels)
        # NOTE: known limitation flagged by the original author: BASE_COLORS and
        # CSS4_COLORS overlap, so distinct labels can end up with colours that
        # look the same. The palette order is kept so existing plots do not change.
        palette = sorted(matplotlib.colors.BASE_COLORS) + sorted(matplotlib.colors.CSS4_COLORS)
        # Wrap around instead of raising IndexError when a label exceeds the
        # palette size; in-range and negative labels map exactly as before.
        colors = [palette[i % len(palette)] for i in labels]
    if isinstance(y_, torch.Tensor):
        y_ = y_.detach().cpu().numpy()
    else:
        # Accept plain sequences as well; a no-op for NumPy arrays.
        y_ = np.asarray(y_)
    if y_.shape[1] > 2:
        y_ = pca(y_)
    plt.figure(figsize=(size, size))
    plt.scatter(y_[:, 0], y_[:, 1], c=colors, s=marker_size)
    if save_fig:
        if folder:
            # exist_ok avoids the check-then-create race of a separate exists() test.
            os.makedirs(folder, exist_ok=True)
        plt.savefig(os.path.join(folder, title + '.png'), bbox_inches='tight', dpi=200)
    else:
        plt.title(title)
        plt.show()
    # Release the figure on both paths so repeated calls do not accumulate figures.
    plt.close()
def plot_history(history, title='', indices=None, colors='rgbkmc', markers='ov+*,<',
                 labels=('',)*6, linestyles=('',)*6, markersize=4):
    """Plot curves such as loss history and accuracy history during training

    Args:
        history: N * m numpy array (or array-like); N is the number of steps, and m is
            the number of histories.
        title: figure title.
        indices: a list of selected column indices to plot; all columns when None.
            Usually less than four curves are plotted in a single figure.
        colors: sequence of colours, indexed by column index; reused cyclically when
            there are more curves than entries.
        markers: sequence of markers, indexed and cycled like ``colors``.
        labels: sequence of legend labels, indexed by column index; a curve whose index
            falls outside the sequence gets no label.
        linestyles: sequence of line styles, indexed and cycled like ``colors``.
        markersize: marker size shared by all curves.
    """
    history = np.asarray(history)
    if indices is None:
        indices = range(history.shape[1])
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    steps = range(len(history))
    for i in indices:
        # Cycle through the style sequences so plotting more curves than there are
        # style entries no longer raises IndexError.
        in_range = -len(labels) <= i < len(labels)
        ax.plot(steps, history[:, i],
                color=colors[i % len(colors)],
                linestyle=linestyles[i % len(linestyles)],
                marker=markers[i % len(markers)],
                markersize=markersize,
                label=labels[i] if in_range else '')
    # Only draw a legend when at least one curve carries a usable label.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    plt.title(title)
    plt.show()
def plot_history_multi_splits(histories, title='Loss', idx=0, labels=('Train', 'Validation', 'Test'),
                              colors='rgbk', markers='ov+*', linestyles=('', '', '', ''), markersize=4):
    """Given a list of histories, e.g., [loss_train_his, loss_val_his, loss_test_his], plot them in one figure

    Empty histories are skipped; when every history is empty nothing is plotted.

    Args:
        Most arguments are passed to plot_history
        histories: a list of list (or np.array), e.g., [loss_train_his, loss_val_his, loss_test_his];
            len(loss_train_his) = num_points; len(loss_train_his[0]) = num_losses (for multiple losses)
        title: plot title
        idx: if there are multiple losses, idx specifies which one should be plotted
        labels: the names for different histories; default: ('Train', 'Validation', 'Test').
            Replaced by 'History 0', 'History 1', ... when its length differs from histories.

    Raises:
        AssertionError: if the non-empty histories do not all have the same shape.
    """
    if len(histories) != len(labels):  # labels must match with histories
        labels = [f'History {i}' for i in range(len(histories))]
    arrays = []
    labels_included = []
    for label, history in zip(labels, histories):
        if len(history) == 0:  # only include non-empty array
            continue
        current = np.array(history)
        if arrays:  # make sure all arrays have the same shape
            assert current.shape == arrays[-1].shape
        arrays.append(current)
        labels_included.append(label)
    if not arrays:
        # Nothing recorded yet: stacking would give a 1-D empty array, which
        # plot_history cannot index by column.
        return
    stacked = np.array(arrays)
    if stacked.ndim == 3:  # For multiple losses/accuracies, use idx to select one
        stacked = stacked[:, :, idx]
    # To use plot_history, make sure the array is of shape N * m; N=num_points, m = num of curves
    stacked = stacked.T
    plot_history(stacked, title=title, labels=labels_included, colors=colors, markers=markers,
                 linestyles=linestyles, markersize=markersize)
def plot_acc_history(acc_his, title='', color='r', marker='v', linestyle='', markersize=2):
    """Plot a single history (e.g. accuracy per step) on a new 10x10 figure.

    Args:
        acc_his: sequence of values; the x-axis is the position in the sequence.
        title: figure title.
        color, marker, linestyle, markersize: forwarded to ``plt.plot``.
    """
    steps = range(len(acc_his))
    style = dict(color=color, marker=marker, linestyle=linestyle, markersize=markersize)
    plt.figure(figsize=(10, 10))
    plt.title(title)
    plt.plot(steps, acc_his, **style)
    plt.show()