Switch to side-by-side view

--- a
+++ b/dl/utils/visualization/visualization.py
@@ -0,0 +1,127 @@
+import os
+
+import numpy as np
+import matplotlib
+import matplotlib.pyplot as plt
+from sklearn.decomposition import PCA, IncrementalPCA
+
+import torch
+
+def plot_cdf(y, title=''):
+  """Plot the sorted values of y against evenly spaced levels in [0, 1] (an empirical CDF).
+
+  Args:
+    y: a torch.Tensor, or anything np.array accepts; it is flattened before sorting.
+    title: figure title; no title is set when it equals the empty string.
+  """
+  if not isinstance(y, torch.Tensor):
+    values = np.sort(np.array(y).reshape(-1))
+  else:
+    # Sort with torch first, then bring the sorted values over as a numpy array.
+    flat = y.contiguous().view(-1)
+    values, _ = flat.sort(dim=0)
+    values = values.detach().cpu().numpy()
+  # The i-th smallest value is paired with the i-th of len(values) levels in [0, 1].
+  levels = np.linspace(0, 1, len(values))
+  plt.plot(values, levels, 'ro', markersize=0.5)
+  if title != '':
+    plt.title(title)
+  plt.show()
+
+def pca(x, n_components=2, verbose=False):
+    r"""Project x onto its leading principal components (2 by default, for 2-D visualization).
+
+    Args:
+        x: a torch.Tensor or an array accepted by scikit-learn's ``fit_transform``
+            (presumably of shape (n_samples, n_features) -- confirm against callers).
+            Tensors are detached, moved to CPU and copied into a numpy array.
+        n_components: number of principal components to keep.
+        verbose: if True, print the explained variances and the noise variance of the
+            fitted estimator and plot them as red dots.
+
+    Returns:
+        The value returned by the estimator's ``fit_transform(x)``.
+    """
+    # Inputs with more than 10000 rows go through IncrementalPCA instead of PCA.
+    if len(x) > 10000:
+        estimator = IncrementalPCA(n_components=n_components)
+    else:
+        estimator = PCA(n_components=n_components)
+    if isinstance(x, torch.Tensor):
+        x = x.detach().cpu().numpy().copy()
+    # Fit exactly once. The previous fit() followed by fit_transform() did the work twice
+    # and could report diagnostics from a different fit than the one that produced the
+    # returned projection.
+    projected = estimator.fit_transform(x)
+    if verbose:
+        print(estimator.explained_variance_, estimator.noise_variance_)
+        plt.title('explained_variance')
+        plt.plot(estimator.explained_variance_.tolist() + [estimator.noise_variance_], 'ro')
+        plt.show()
+    return projected
+
+def _distinct_named_colors():
+    """Return matplotlib's single-letter and CSS4 color names without visual duplicates.
+
+    Names are ordered as sorted(BASE_COLORS) followed by sorted(CSS4_COLORS); for each
+    distinct hex value only the first name is kept (e.g. 'b' is kept, 'blue' is dropped).
+    """
+    names = sorted(matplotlib.colors.BASE_COLORS) + sorted(matplotlib.colors.CSS4_COLORS)
+    seen = set()
+    palette = []
+    for name in names:
+        hex_value = matplotlib.colors.to_hex(name)
+        if hex_value not in seen:
+            seen.add(hex_value)
+            palette.append(name)
+    return palette
+
+
+def plot_scatter(y_=None, model_=None, x_=None, title='', labels=None, colors=None, size=15, 
+                 marker_size=20, folder='.', save_fig=False):
+    r"""2D scatter plot
+
+    Args:
+        y_: points to plot, a torch.Tensor or 2-D array; when it has more than two
+            columns it is reduced with ``pca`` first. If None, it is computed as
+            ``model_(x_.contiguous())``.
+        model_: callable used to produce ``y_`` when ``y_`` is None.
+        x_: input passed to ``model_`` (``.contiguous()`` is called on it, so presumably
+            a torch.Tensor -- confirm against callers).
+        title: figure title when shown; file name (without '.png') when saved.
+        labels: one integer label per point, used to pick a color when ``colors`` is None.
+            Labels index a palette of distinct named colors and wrap around when there
+            are more labels than colors.
+        colors: one color per point; takes precedence over ``labels``.
+        size: width and height of the square figure.
+        marker_size: marker size passed to ``plt.scatter``.
+        folder: output directory used when ``save_fig`` is True; created if missing.
+        save_fig: if True, save '<folder>/<title>.png' instead of showing the figure.
+    """
+    if y_ is None:
+        assert model_ is not None and x_ is not None, \
+            'either y_ or both model_ and x_ must be provided'
+        y_ = model_(x_.contiguous())
+    if colors is not None:
+        assert len(colors) == len(y_), 'colors must have one entry per point'
+    elif labels is not None:
+        assert len(y_) == len(labels), 'labels must have one entry per point'
+        # BASE_COLORS and CSS4_COLORS contain visually identical colors (e.g. 'b' and
+        # 'blue'), so use a de-duplicated palette to keep different labels distinguishable.
+        palette = _distinct_named_colors()
+        colors = [palette[int(label) % len(palette)] for label in labels]
+    if isinstance(y_, torch.Tensor):
+        y_ = y_.detach().cpu().numpy()
+    y_ = np.asarray(y_)
+    if y_.shape[1] > 2:
+        y_ = pca(y_)
+    plt.figure(figsize=(size, size))
+    try:
+        plt.scatter(y_[:, 0], y_[:, 1], c=colors, s=marker_size)
+        if save_fig:
+            # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
+            os.makedirs(folder, exist_ok=True)
+            plt.savefig(os.path.join(folder, title + '.png'), bbox_inches='tight', dpi=200)
+        else:
+            plt.title(title)
+            plt.show()
+    finally:
+        # Always release the figure, even if plotting or saving raised.
+        plt.close()
+
+
+def plot_history(history, title='', indices=None, colors='rgbkmc', markers='ov+*,<',
+                 labels=None, linestyles=None, markersize=4):
+  """Plot curves such as loss history and accuracy history during training
+
+  Args:
+    history: N * m numpy array (or nested list); N is the number of steps, and m is the
+      number of histories. A 1-D input is treated as a single curve.
+    title: figure title.
+    indices: a list of selected column indices to plot; defaults to all columns.
+    colors: sequence of colors indexed by curve index; reused cyclically when there are
+      more curves than colors.
+    markers: sequence of markers indexed by curve index; reused cyclically.
+    labels: sequence of legend labels indexed by curve index; curves without a
+      corresponding label get an empty label. Defaults to no labels.
+    linestyles: sequence of line styles indexed by curve index; reused cyclically.
+      Defaults to '' (markers only) for every curve.
+    markersize: marker size passed to matplotlib.
+  """
+  history = np.asarray(history)
+  if history.ndim == 1:
+    history = history[:, None]
+  if indices is None:
+    indices = range(history.shape[1])
+  if labels is None:
+    labels = []
+  if linestyles is None:
+    linestyles = ['']
+  steps = range(len(history))
+  fig = plt.figure(figsize=(10,10))
+  ax = fig.add_subplot(1,1,1)
+  for i in indices:
+    # Labels are not cycled: reusing another curve's label would be misleading.
+    label = labels[i] if -len(labels) <= i < len(labels) else ''
+    ax.plot(steps, history[:,i],
+            color=colors[i % len(colors)],
+            linestyle=linestyles[i % len(linestyles)],
+            marker=markers[i % len(markers)],
+            markersize=markersize, label=label)
+  # Only draw a legend when at least one curve has a label matplotlib would display;
+  # calling legend() with nothing to show emits a warning.
+  handles, _ = ax.get_legend_handles_labels()
+  if handles:
+    ax.legend()
+  plt.title(title)
+  plt.show()
+
+
+def plot_history_multi_splits(histories, title='Loss', idx=0, labels=('Train', 'Validation', 'Test'),
+  colors='rgbk', markers='ov+*', linestyles=('', '', '', ''), markersize=4):
+  """Given a list of histories, e.g., [loss_train_his, loss_val_his, loss_test_his], plot them in one figure
+
+  Empty histories are skipped. Each remaining history keeps the label, color, marker and
+  line style of its original position in ``histories``, so a split looks the same whether
+  or not other splits are empty. Nothing is plotted when every history is empty.
+
+  Args:
+    Most arguments are passed to plot_history
+    histories: a list of list (or np.array), e.g., [loss_train_his, loss_val_his, loss_test_his]; 
+      len(loss_train_his) = num_points; len(loss_train_his[0]) = num_losses (for multiple losses)
+    title: plot title
+    idx: if there are multiple losses, idx specifies which one should be plotted
+    labels: the names for different histories; default: ('Train', 'Validation', 'Test').
+      Replaced by 'History 0', 'History 1', ... when its length differs from len(histories).
+    colors, markers, linestyles: per-history styles indexed by position in ``histories``;
+      reused cyclically when shorter than ``histories``.
+    markersize: marker size passed to plot_history
+
+  Raises:
+    AssertionError: if the non-empty histories do not all have the same shape.
+  """
+  if len(histories) != len(labels): # labels must match with histories
+    labels = [f'History {i}' for i in range(len(histories))] 
+  positions = []  # indices into `histories` of the non-empty entries
+  arrays = []
+  for i, history in enumerate(histories):
+    if len(history) == 0: # only include non-empty array
+      continue
+    array = np.array(history)
+    if arrays: # make sure all arrays have the same shape
+      assert array.shape == arrays[0].shape, 'all non-empty histories must have the same shape'
+    positions.append(i)
+    arrays.append(array)
+  if not arrays: # nothing to plot
+    return
+  history = np.array(arrays)
+  if history.ndim == 3: # For multiple losses/accuracies, use idx to select one
+    history = history[:,:,idx]
+  history = history.T  # To use plot_history, make sure history is of shape N * m; N=num_points, m = num of curves
+  plot_history(history, title=title,
+               labels=[labels[i] for i in positions],
+               colors=[colors[i % len(colors)] for i in positions],
+               markers=[markers[i % len(markers)] for i in positions],
+               linestyles=[linestyles[i % len(linestyles)] for i in positions],
+               markersize=markersize)
+
+
+def plot_acc_history(acc_his, title='', color='r', marker='v', linestyle='', markersize=2):
+  """Plot a single history (e.g. accuracy per step) on a new 10 x 10 figure.
+
+  Args:
+    acc_his: sequence of values; the x-axis is the position of each value in the sequence.
+    title: figure title.
+    color, marker, linestyle, markersize: style options forwarded to plt.plot.
+  """
+  plt.figure(figsize=(10,10))
+  steps = range(len(acc_his))
+  style = dict(color=color, marker=marker, linestyle=linestyle, markersize=markersize)
+  plt.plot(steps, acc_his, **style)
+  plt.title(title)
+  plt.show()
+