Switch to side-by-side view

--- a
+++ b/src/.ipynb_checkpoints/scot-checkpoint.py
@@ -0,0 +1,175 @@
+"""
+Author: Pinar Demetci, Rebecca Santorella
+12 February 2020
+Utils for SCOT
+"""
+import numpy as np
+import ot
+from ot.bregman import sinkhorn
+from ot.utils import dist, UndefinedParameter
+from ot.optim import cg
+from ot.gromov import init_matrix, gwggrad, gwloss
+import src.utils as ut
+import sys
+
+# We use the unbalanced Gromov-Wasserstein solver (utilising PyTorch) written by Thibault Sejourne 
+# https://github.com/thibsej/unbalanced_gromov_wasserstein
+# Clone the above repo and update below to the corresponding path
+sys.path.insert(0, "/home/zsteve/analysis/unbalanced_gromov_wasserstein/solver")
+from tlb_kl_sinkhorn_solver import TLBSinkhornSolver
+import torch
+
+def scot(X, y, k, e, rho=1, mode="connectivity", metric="correlation", XontoY=True, returnCoupling=False, balanced=True):
+    """Align two datasets with entropic Gromov-Wasserstein optimal transport.
+
+    X, y: the datasets to align; they are not normalized here, so callers can
+        pick their own scheme (e.g. z-score or l2) for fair method comparisons.
+    k: number of neighbors for the kNN graphs; e: entropic regularization epsilon.
+    rho: multiplied by the average of the two graph-distance means (unbalanced solver only).
+    mode, metric: forwarded to ut.get_graph_distance_matrix.
+    XontoY: project X onto y when truthy, otherwise y onto X.
+    returnCoupling: return (couplingM, log) instead of the projected data.
+    balanced: use POT's entropic GW solver; otherwise TLBSinkhornSolver (PyTorch).
+
+    Returns (couplingM, log) when returnCoupling is set (log is None for the
+    unbalanced solver); else (X_transported, y) or (X, y_transported), or
+    (None, None) if the solver did not converge.
+    """
+    # Construct the kNN graphs; their distance matrices are the GW cost matrices.
+    Cx = ut.get_graph_distance_matrix(X, k, mode=mode, metric=metric)
+    Cy = ut.get_graph_distance_matrix(y, k, mode=mode, metric=metric)
+
+    # Uniform marginal distributions over the samples of each domain.
+    p = ot.unif(Cx.shape[0])
+    q = ot.unif(Cy.shape[0])
+
+    # Perform optimization to get the coupling matrix between domains.
+    if balanced:
+        couplingM, log = ot.gromov.entropic_gromov_wasserstein(Cx, Cy, p, q, 'square_loss', epsilon=e, log=True, verbose=True)
+    else:
+        # Prefer the GPU but fall back to the CPU instead of crashing without CUDA.
+        # NOTE(review): assumes TLBSinkhornSolver accepts CPU tensors - confirm.
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        p_t, Cx_t, q_t, Cy_t = (torch.Tensor(a).to(device) for a in (p, Cx, q, Cy))
+        solver = TLBSinkhornSolver(nits=1000, nits_sinkhorn=2500, gradient=False, tol=1e-3, tol_sinkhorn=1e-3)
+        couplingM, _ = solver.tlb_sinkhorn(p_t, Cx_t, q_t, Cy_t, rho=rho*0.5*(Cx.mean() + Cy.mean()), eps=e, init=None)
+        couplingM = couplingM.cpu().numpy()
+        log = None
+
+    # NaNs or an all-zero row/column in the coupling mean the solver failed.
+    converged = not (np.isnan(couplingM).any() or np.any(~couplingM.any(axis=1)) or np.any(~couplingM.any(axis=0)))
+    if not converged:
+        print("Did not converge. Try increasing the epsilon value. ")
+
+    # The raw coupling is returned even when unconverged, so callers can inspect it.
+    if returnCoupling:
+        return couplingM, log
+    if not converged:
+        return None, None
+    # Otherwise perform barycentric projection in the desired direction.
+    if XontoY:
+        return ut.transport_data(X, y, couplingM, transposeCoupling=False), y
+    return X, ut.transport_data(X, y, couplingM, transposeCoupling=True)
+
+def search_scot(X, y, ks, es, plot_values=False):
+    """Grid-search k and epsilon for balanced SCOT by Gromov-Wasserstein distance.
+
+    X, y: the two datasets, samples along the first axis.
+    ks: candidate neighbor counts for the kNN graphs.
+    es: candidate epsilon values for the entropic regularization.
+    plot_values: when True, return every converged run instead of only the best.
+
+    Returns (k_best, e_best) for the lowest GW distance, or the parallel lists
+    (g_plot, k_plot, e_plot) of distances, ks and epsilons when plot_values is set.
+    Raises ValueError if no (k, epsilon) combination converged.
+    """
+    p = ot.unif(X.shape[0])
+    q = ot.unif(y.shape[0])
+
+    # k, epsilon and GW distance of every run that converged (parallel lists)
+    k_plot, e_plot, g_plot = [], [], []
+    total = len(es) * len(ks)
+    counter = 0
+
+    # search in k first so that each kNN graph is only computed once
+    for k in ks:
+        Cx = ut.get_graph_distance_matrix(X, k, mode="connectivity", metric="correlation")
+        Cy = ut.get_graph_distance_matrix(y, k, mode="connectivity", metric="correlation")
+
+        for e in es:
+            counter += 1
+            if counter % 10 == 0:
+                print(str(counter)+"/"+str(total))
+
+            gw, log = ot.gromov.entropic_gromov_wasserstein(Cx, Cy, p, q, 'square_loss', epsilon=e, log=True, verbose=False, max_iter=200)
+
+            # reject couplings with NaNs, an all-zero row/column, or total mass below .95
+            if (np.isnan(gw).any() or np.any(~gw.any(axis=1)) or np.any(~gw.any(axis=0)) or gw.sum() < .95):
+                print("Did not converge")
+                continue
+            g_plot.append(log["gw_dist"])
+            k_plot.append(k)
+            e_plot.append(e)
+
+    if not g_plot:
+        raise ValueError("search_scot: no (k, epsilon) combination converged; try larger epsilon values")
+
+    # find the parameters corresponding to the lowest gromov-wasserstein distance
+    gminI = np.argmin(g_plot)
+    k_best, e_best = k_plot[gminI], e_plot[gminI]
+    print("Best result with GW distance is when e and k are:", e_best, k_best, " with lowest GW dist:", g_plot[gminI])
+
+    if plot_values:
+        return g_plot, k_plot, e_plot
+    return k_best, e_best
+def unsupervised_scot(X, y, XontoY=True):
+    """Pick k and epsilon by lowest Gromov-Wasserstein distance, then run scot.
+
+    X, y: the two datasets, samples along the first axis.
+    XontoY: forwarded to scot (project X onto y when True, else y onto X).
+
+    Runs three search_scot passes (epsilon sweep, k sweep, refined grid) and
+    aligns X and y with the best setting found. Every candidate k is derived
+    from the smaller dataset and kept >= 1 so it is usable for both inputs.
+
+    Returns (X_t, y_t, k_best, e_best), where X_t and y_t are scot's outputs.
+    """
+    # use k = 20% of # sample or k = 50 if dataset is large
+    n = min(X.shape[0], y.shape[0])
+    k_best = max(1, min(n // 5, 50))
+
+    # first fix k and find the best epsilon (6 runs)
+    es = np.logspace(-2, -3, 6)
+    g1, k1, e1 = search_scot(X, y, [k_best], es, plot_values=True)
+    gminI = np.argmin(g1)
+    gmin, e_best = g1[gminI], e1[gminI]
+
+    # fix that epsilon and vary k (4 runs)
+    if n > 250:
+        ks = np.linspace(20, 100, 4)
+    else:
+        ks = np.linspace(max(1, n // 20), max(1, n // 6), 4)
+    ks = ks.astype(int)
+    g2, k2, e2 = search_scot(X, y, ks, [e_best], plot_values=True)
+
+    # keep the best k from that search only if it beats the first pass
+    gminI = np.argmin(g2)
+    if g2[gminI] < gmin:
+        gmin, k_best = g2[gminI], k2[gminI]
+
+    # now use that k and epsilon to do a more refined grid search (10 runs)
+    scale = np.log10(e_best)
+    eps_refined = np.logspace(scale + .25, scale - .25, 5)
+    k_cap = max(1, n // 2)  # never ask for more neighbors than half the smaller dataset
+    ks_refined = np.linspace(min(max(5, k_best - 5), k_cap), min(k_cap, k_best + 5), 2).astype(int)
+    g3, k3, e3 = search_scot(X, y, ks_refined, eps_refined, plot_values=True)
+
+    # find the best parameter set from all runs
+    gminI = np.argmin(g3)
+    if g3[gminI] < gmin:
+        gmin, k_best, e_best = g3[gminI], k3[gminI], e3[gminI]
+
+    print("Lowest GW distance is ", gmin, " for epsilon = ", e_best, " and k = ", k_best)
+    # run scot with these parameters
+    X_t, y_t = scot(X, y, k_best, e_best, XontoY=XontoY)
+    return X_t, y_t, k_best, e_best