--- a
+++ b/gpsa/plotting/callbacks.py
@@ -0,0 +1,443 @@
+import torch
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+from matplotlib.lines import Line2D
+
+SCATTER_POINT_SIZE = 50
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def callback_oned(
+    model,
+    X,
+    Y,
+    X_aligned,
+    data_expression_ax,
+    latent_expression_ax,
+    prediction_ax=None,
+    X_test=None,
+    Y_pred=None,
+    Y_test_true=None,
+    X_test_aligned=None,
+    F_samples=None,
+):
+    """Plot observed and aligned data for one-dimensional spatial coordinates.
+
+    Optionally overlays posterior function samples (``F_samples``) and
+    held-out predictions (``Y_pred`` vs. ``Y_test_true``).
+    """
+    model.eval()
+    markers = list(Line2D.markers.keys())
+
+    # Keep the template (fixed) view pinned to its observed coordinates.
+    if model.fixed_view_idx is not None:
+        curr_idx = model.view_idx["expression"][model.fixed_view_idx]
+        X_aligned["expression"][curr_idx] = torch.tensor(
+            X[curr_idx].astype(np.float32), device=device
+        )
+
+    data_expression_ax.cla()
+    latent_expression_ax.cla()
+
+    data_expression_ax.set_title("Observed data")
+    latent_expression_ax.set_title("Aligned data")
+
+    data_expression_ax.set_xlabel("Spatial coordinate")
+    latent_expression_ax.set_xlabel("Spatial coordinate")
+
+    data_expression_ax.set_ylabel("Outcome")
+    latent_expression_ax.set_ylabel("Outcome")
+
+    data_expression_ax.set_xlim([X.min(), X.max()])
+    latent_expression_ax.set_xlim([X.min(), X.max()])
+
+    view_idx = model.view_idx["expression"]
+    aligned_coords = X_aligned["expression"].detach().cpu().numpy()
+
+    for vv in range(model.n_views):
+
+        # Observed coordinates vs. outcome(s); one color per outcome column.
+        data_expression_ax.scatter(
+            X[view_idx[vv], 0],
+            Y[view_idx[vv], 0],
+            label="View {}".format(vv + 1),
+            marker=markers[vv],
+            s=SCATTER_POINT_SIZE,
+            c="blue",
+        )
+        if Y.shape[1] > 1:
+            data_expression_ax.scatter(
+                X[view_idx[vv], 0],
+                Y[view_idx[vv], 1],
+                label="View {}".format(vv + 1),
+                marker=markers[vv],
+                s=SCATTER_POINT_SIZE,
+                c="orange",
+            )
+
+        # Aligned coordinates vs. outcome(s).
+        latent_expression_ax.scatter(
+            aligned_coords[view_idx[vv], 0],
+            Y[view_idx[vv], 0],
+            c="blue",
+            label="View {}".format(vv + 1),
+            marker=markers[vv],
+            s=SCATTER_POINT_SIZE,
+        )
+        if Y.shape[1] > 1:
+            latent_expression_ax.scatter(
+                aligned_coords[view_idx[vv], 0],
+                Y[view_idx[vv], 1],
+                c="orange",
+                label="View {}".format(vv + 1),
+                marker=markers[vv],
+                s=SCATTER_POINT_SIZE,
+            )
+
+        # Posterior samples of the latent function, if provided.
+        if F_samples is not None:
+            latent_expression_ax.scatter(
+                aligned_coords[view_idx[vv], 0],
+                F_samples.detach().cpu().numpy()[view_idx[vv], 0],
+                c="red",
+                marker=markers[vv],
+                s=SCATTER_POINT_SIZE,
+            )
+            if Y.shape[1] > 1:
+                latent_expression_ax.scatter(
+                    aligned_coords[view_idx[vv], 0],
+                    F_samples.detach().cpu().numpy()[view_idx[vv], 1],
+                    c="green",
+                    marker=markers[vv],
+                    s=SCATTER_POINT_SIZE,
+                )
+
+    # Held-out predictions, if a prediction axis was supplied.
+    if prediction_ax is not None:
+        prediction_ax.cla()
+        prediction_ax.set_title("Predictions")
+        prediction_ax.set_xlabel("True outcome")
+        prediction_ax.set_ylabel("Predicted outcome")
+
+        test_aligned_coords = X_test_aligned["expression"].detach().cpu().numpy()
+        Y_pred_np = Y_pred.detach().cpu().numpy()
+
+        latent_expression_ax.scatter(
+            test_aligned_coords[:, 0],
+            Y_pred_np[:, 0],
+            c="blue",
+            label="Prediction",
+            marker="^",
+            s=SCATTER_POINT_SIZE,
+        )
+        prediction_ax.scatter(
+            Y_test_true[:, 0],
+            Y_pred_np[:, 0],
+            c="black",
+            s=SCATTER_POINT_SIZE,
+        )
+        if Y_pred_np.shape[1] > 1:
+            latent_expression_ax.scatter(
+                test_aligned_coords[:, 0],
+                Y_pred_np[:, 1],
+                c="orange",
+                label="Prediction",
+                marker="^",
+                s=SCATTER_POINT_SIZE,
+            )
+            prediction_ax.scatter(
+                Y_test_true[:, 1],
+                Y_pred_np[:, 1],
+                c="black",
+                s=SCATTER_POINT_SIZE,
+                marker="^",
+            )
+
+    data_expression_ax.legend()
+    plt.draw()
+    plt.pause(1 / 60.0)
+
+
+def callback_twod(
+    model,
+    X,
+    Y,
+    X_aligned,
+    data_expression_ax,
+    latent_expression_ax,
+    is_mle=False,
+    gene_idx=0,
+    s=200,
+    include_legend=False,
+):
+    """Plot observed and aligned data for two-dimensional spatial coordinates."""
+
+    # Keep the template (fixed) view pinned to its observed coordinates.
+    if model.fixed_view_idx is not None and not is_mle:
+        curr_idx = model.view_idx["expression"][model.fixed_view_idx]
+        X_aligned["expression"][curr_idx] = torch.tensor(
+            X[curr_idx].astype(np.float32), device=device
+        )
+
+    model.eval()
+    markers = [".", "+", "^"]
+
+    data_expression_ax.cla()
+    latent_expression_ax.cla()
+    data_expression_ax.set_title("Observed data")
+    latent_expression_ax.set_title("Aligned data")
+
+    curr_view_idx = model.view_idx["expression"]
+
+    latent_Xs = []
+    Xs = []
+    Ys = []
+    markers_list = []
+    viewname_list = []
+
+    for vv in range(model.n_views):
+
+        # Observed coordinates.
+        Xs.append(X[curr_view_idx[vv]])
+
+        # Aligned (latent) coordinates.
+        curr_latent_Xs = (
+            X_aligned["expression"].detach().cpu().numpy()[curr_view_idx[vv]]
+        )
+        latent_Xs.append(curr_latent_Xs)
+        Ys.append(Y[curr_view_idx[vv], gene_idx])
+        markers_list.append([markers[vv]] * curr_latent_Xs.shape[0])
+        viewname_list.append(
+            ["Observation {}".format(vv + 1)] * curr_latent_Xs.shape[0]
+        )
+
+    Xs = np.concatenate(Xs, axis=0)
+    latent_Xs = np.concatenate(latent_Xs, axis=0)
+    Ys = np.concatenate(Ys)
+    markers_list = np.concatenate(markers_list)
+    viewname_list = np.concatenate(viewname_list)
+
+    data_df = pd.DataFrame(
+        {
+            "X1": Xs[:, 0],
+            "X2": Xs[:, 1],
+            "Y": Ys,
+            "marker": markers_list,
+            "view": viewname_list,
+        }
+    )
+
+    latent_df = pd.DataFrame(
+        {
+            "X1": latent_Xs[:, 0],
+            "X2": latent_Xs[:, 1],
+            "Y": Ys,
+            "marker": markers_list,
+            "view": viewname_list,
+        }
+    )
+
+    plt.sca(data_expression_ax)
+    g = sns.scatterplot(
+        data=data_df,
+        x="X1",
+        y="X2",
+        hue="Y",
+        style="view",
+        ax=data_expression_ax,
+        s=s,
+        linewidth=1.8,
+        edgecolor="black",
+        palette="viridis",
+    )
+    if not include_legend:
+        g.legend_.remove()
+
+    plt.sca(latent_expression_ax)
+    g = sns.scatterplot(
+        data=latent_df,
+        x="X1",
+        y="X2",
+        hue="Y",
+        style="view",
+        ax=latent_expression_ax,
+        s=s,
+        linewidth=1.8,
+        edgecolor="black",
+        palette="viridis",
+    )
+    if not include_legend:
+        g.legend_.remove()
+
+
+def callback_twod_aligned_only(
+    model,
+    X,
+    Y,
+    X_aligned,
+    latent_expression_ax1,
+    latent_expression_ax2,
+    is_mle=False,
+    gene_idx=0,
+):
+    """Plot the aligned coordinates of two views on separate axes."""
+
+    # Keep the template (fixed) view pinned to its observed coordinates.
+    if model.fixed_view_idx is not None and not is_mle:
+        curr_idx = model.view_idx["expression"][model.fixed_view_idx]
+        X_aligned["expression"][curr_idx] = torch.tensor(
+            X[curr_idx].astype(np.float32), device=device
+        )
+
+    model.eval()
+
+    latent_expression_ax1.cla()
+    latent_expression_ax2.cla()
+    latent_expression_ax1.set_title("Aligned data, view 1")
+    latent_expression_ax2.set_title("Aligned data, view 2")
+
+    curr_view_idx = model.view_idx["expression"]
+    aligned_coords = X_aligned["expression"].detach().cpu().numpy()
+
+    latent_expression_ax1.scatter(
+        aligned_coords[curr_view_idx[0]][:, 0],
+        aligned_coords[curr_view_idx[0]][:, 1],
+        c=Y[curr_view_idx[0]][:, gene_idx].squeeze(),
+        s=24,
+        marker="h",
+    )
+    latent_expression_ax2.scatter(
+        aligned_coords[curr_view_idx[1]][:, 0],
+        aligned_coords[curr_view_idx[1]][:, 1],
+        c=Y[curr_view_idx[1]][:, gene_idx].squeeze(),
+        s=24,
+        marker="h",
+    )
+
+    latent_expression_ax1.axis("off")
+    latent_expression_ax2.axis("off")
+
+
+def callback_twod_multimodal(
+    model, data_dict, X_aligned, axes, rgb=False, scatterpoint_size=100
+):
+    """Plot observed and aligned coordinates for each modality side by side.
+
+    Expects four axes ordered as: observed expression, aligned expression,
+    observed histology, aligned histology.
+    """
+    # NOTE: unlike the callbacks above, the fixed (template) view is not
+    # re-pinned to its observed coordinates here.
+    model.eval()
+    markers = [".", "+", "^"]
+
+    [ax.cla() for ax in axes]
+
+    axes[0].set_title("Observed expression")
+    axes[1].set_title("Aligned expression")
+    axes[2].set_title("Observed histology")
+    axes[3].set_title("Aligned histology")
+
+    axis_counter = 0
+    n_mods = 2
+    for mod in ["expression", "histology"]:
+        curr_view_idx = model.view_idx[mod]
+        for vv in range(model.n_views):
+
+            curr_coords = data_dict[mod]["spatial_coords"]
+
+            # Color histology by its RGB values if requested; otherwise use
+            # the first output column (as for expression).
+            if mod == "histology" and rgb:
+                curr_outputs = data_dict[mod]["outputs"][curr_view_idx[vv], :]
+            else:
+                curr_outputs = data_dict[mod]["outputs"][curr_view_idx[vv], 0]
+            axes[axis_counter].scatter(
+                curr_coords[curr_view_idx[vv], 0],
+                curr_coords[curr_view_idx[vv], 1],
+                c=curr_outputs,
+                label="View {}".format(vv + 1),
+                marker=markers[vv],
+                s=scatterpoint_size,
+            )
+            axes[axis_counter + 1].scatter(
+                X_aligned[mod].detach().cpu().numpy()[curr_view_idx[vv], 0],
+                X_aligned[mod].detach().cpu().numpy()[curr_view_idx[vv], 1],
+                c=curr_outputs,
+                label="View {}".format(vv + 1),
+                marker=markers[vv],
+                s=scatterpoint_size,
+            )
+        axis_counter += n_mods
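
Usage note (not part of the patch): the snippet below is a minimal sketch of how `callback_twod` can be driven. It only assumes the attributes the callback actually reads (`n_views`, `fixed_view_idx`, `view_idx`, `eval`); the `SimpleNamespace` stand-in, the synthetic arrays, and the reuse of `X` as the "aligned" coordinates are illustrative only. In practice `model` would be a fitted model from this package and `X_aligned` would come from its forward pass.

```python
import numpy as np
import torch
import matplotlib.pyplot as plt
from types import SimpleNamespace

from gpsa.plotting.callbacks import callback_twod  # module path added in this patch

# Two synthetic views of 100 two-dimensional points with one outcome column.
n_per_view, n_views = 100, 2
X = np.random.uniform(size=(n_per_view * n_views, 2))
Y = np.random.normal(size=(n_per_view * n_views, 1))
view_idx = [
    np.arange(vv * n_per_view, (vv + 1) * n_per_view) for vv in range(n_views)
]

# Stand-in exposing only what callback_twod touches on the model (illustrative).
model = SimpleNamespace(
    n_views=n_views,
    fixed_view_idx=None,
    view_idx={"expression": view_idx},
    eval=lambda: None,
)

# In a real run these would be the model's aligned coordinates; here we reuse X.
X_aligned = {"expression": torch.tensor(X, dtype=torch.float32)}

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
callback_twod(
    model,
    X,
    Y,
    X_aligned,
    data_expression_ax=ax1,
    latent_expression_ax=ax2,
    include_legend=True,
)
plt.show()
```

Inside a training loop the same call would typically follow `plt.ion()` so each invocation refreshes the figure in place, mirroring the `plt.draw()` / `plt.pause()` pattern used by `callback_oned`.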