--- a +++ b/src/.ipynb_checkpoints/scot-checkpoint.py @@ -0,0 +1,175 @@ +""" +Author: Pinar Demetci, Rebecca Santorella +12 February 2020 +Utils for SCOT +""" +import numpy as np +import ot +from ot.bregman import sinkhorn +from ot.utils import dist, UndefinedParameter +from ot.optim import cg +from ot.gromov import init_matrix, gwggrad, gwloss +import src.utils as ut +import sys + +# We use the unbalanced Gromov-Wasserstein solver (utilising PyTorch) written by Thibault Sejourne +# https://github.com/thibsej/unbalanced_gromov_wasserstein +# Clone the above repo and update below to the corresponding path +sys.path.insert(0, "/home/zsteve/analysis/unbalanced_gromov_wasserstein/solver") +from tlb_kl_sinkhorn_solver import TLBSinkhornSolver +import torch + +def scot(X, y, k, e, rho = 1, mode="connectivity", metric="correlation", XontoY=True, returnCoupling=False, balanced = True): + """ + Given two datasets (X and y) and + the hyperparameters (k: number of neighbors to be used in kNN graph construction; and e: eplison value in entropic regularization), + returns the resulting datasets after transport + For transport in the opposite direction, set XontoY to False + """ + ## I think we should let users choose what type of normalization they want to use since that might matter in comparisons to other methods + ## e.g. MMD-MA uses l-2 norm sometimes and they might want to use that for fair comparison. So commented out the part below -Pinar + # X=ut.zscore_standardize(np.asarray(X)) + # y=ut.zscore_standardize(np.asarray(y)) + + # Construct the kNN graphs + Cx=ut.get_graph_distance_matrix(X, k, mode=mode, metric=metric) + Cy=ut.get_graph_distance_matrix(y, k, mode=mode, metric=metric); + + # Initialize marginal distributions over data: + X_sampleNo= Cx.shape[0] + y_sampleNo= Cy.shape[0] + p=ot.unif(X_sampleNo) + q=ot.unif(y_sampleNo) + + # Perform optimization to get the coupling matrix between domains: + if balanced: + couplingM, log = ot.gromov.entropic_gromov_wasserstein(Cx, Cy, p, q, 'square_loss', epsilon=e, log=True, verbose=True) + else: + solver = TLBSinkhornSolver(nits=1000, nits_sinkhorn=2500, gradient=False, tol=1e-3, tol_sinkhorn=1e-3) + couplingM, _ = solver.tlb_sinkhorn(torch.Tensor(p).cuda(), torch.Tensor(Cx).cuda(), torch.Tensor(q).cuda(), torch.Tensor(Cy).cuda(), rho=rho*0.5*(Cx.mean() + Cy.mean()), eps=e, init=None) + couplingM = couplingM.cpu().numpy() + log = None + + # check to make sure GW congerged, if not, warn the user with an error statement + converged=True #initialize the convergence flag + if (np.isnan(couplingM).any() or np.any(~couplingM.any(axis=1)) or np.any(~couplingM.any(axis=0))): # or sum(sum(couplingM)) < .95): + print("Did not converge. Try increasing the epsilon value. ") + converged=False + + # If the user wants to get the coupling matrix and the optimization log at the end of this and investigate it or perform some projection themselves, + # allow them (useful for hyperparameter tuning with projections in both directions) : + if returnCoupling==True: + return couplingM, log + + # Otherwise perform barycentric projection in the desired direction and return the aligned matrices + else: + if converged==False: #except if the convergence failed, just return None, None. + return None, None + if XontoY==True: + X_transported = ut.transport_data(X,y,couplingM,transposeCoupling=False) + return X_transported, y + else: + y_transported = ut.transport_data(X,y,couplingM,transposeCoupling=True) + return X, y_transported +# runs scot for given values of k and epsilon +# by default returns the parameters corresponding to the lowest gromov-wasserstein distance +# (optional) returns all data points for plotting +def search_scot(X,y, ks, es, plot_values = False): + + X_sampleNo= X.shape[0] + y_sampleNo= y.shape[0] + p=ot.unif(X_sampleNo) + q=ot.unif(y_sampleNo) + + # store values of k, epsilon, and gw distance + k_plot=[] + e_plot=[] + g_plot=[] + + total=len(es)*len(ks) + counter=0 + + # search in k first to reduce graph computation + for k in ks: + Cx=ut.get_graph_distance_matrix(X, k, mode="connectivity", metric="correlation") + Cy=ut.get_graph_distance_matrix(y, k, mode="connectivity", metric="correlation") + + for e in es: + + counter+=1 + if (counter % 10 == 0): + print(str(counter)+"/"+str(total)) + + # run scot + gw, log = ot.gromov.entropic_gromov_wasserstein(Cx, Cy, p, q, 'square_loss', epsilon = e, log=True, verbose=False, max_iter = 200) + + if (np.isnan(gw).any() or np.any(~gw.any(axis=1)) or np.any(~gw.any(axis=0)) or sum(sum(gw)) < .95): + print("Did not converge") + else: + g_plot.append(log["gw_dist"]) + k_plot.append(k) + e_plot.append(e) + + # find the parameters corresponding to the lowest gromov-wasserstein distance + gmin=np.amin(g_plot) + gminI=np.argmin(g_plot) + e_best = e_plot[gminI] + k_best = k_plot[gminI] + print("Best result with GW distance is when e and k are:", e_best, k_best, " with lowest GW dist:", gmin) + + if plot_values: + return g_plot, k_plot, e_plot + else: + return k_best, e_best + +# find the best alignment by gromov-wasserstein distance +def unsupervised_scot(X,y, XontoY=True): + + # use k = 20% of # sample or k = 50 if dataset is large + n = min(X.shape[0], y.shape[0]) + k_best = min(n // 5, 50) + + # first fix k and find the best epsilon (6 runs) + es = np.logspace(-2, -3, 6) + g1, k1, e1 = search_scot(X,y,[k_best], es, plot_values = True) + + # save the best epsilon from that search + gmin = np.min(g1) + gminI=np.argmin(g1) + e_best = e1[gminI] + + # fix that epsilon and vary k (4 runs) + if ( n > 250): + ks = np.linspace(20, 100, 4) + else: + ks = np.linspace(X.shape[0]//20, X.shape[0]//6, 4) + ks = ks.astype(int) + g2, k2, e2 = search_scot(X,y,ks, [e_best], plot_values = True) + + # save the best k from that search + gminI=np.argmin(g2) + if (g2[gminI] < gmin): + gmin = g2[gminI] + k_best = k2[gminI] + + # now use that k and epsilon to do a more refined grid search (10 runs) + scale = np.log10(e_best) + eps_refined = np.logspace(scale + .25, scale - .25, 5) + + ks_refined = np.linspace( max(5, k_best - 5), min(X.shape[0]//2, k_best + 5), 2) + ks_refined = ks_refined.astype(int) + g3, k3, e3 = search_scot(X, y, ks_refined, eps_refined, plot_values = True) + + # find the best parameter set from all runs + gminI=np.argmin(g3) + if (g3[gminI] < gmin): + gmin = g3[gminI] + k_best = k3[gminI] + e_best = e3[gminI] + + print("Lowest GW distance is ", gmin, " for epsilon = ", e_best, " and k = ", k_best) + + # run soct with these parameters + X_t, y_t = scot(X, y, k_best, e_best, XontoY = XontoY) + + return X_t, y_t, k_best, e_best