# -*- coding: utf-8 -*-
import math
import torch
from .utils import spherical_kmeans
import ot
# from src.utils.losses import compute_distance_euclidean
def compute_distance_euclidean(inp, out):
    """
    Compute pairwise Euclidean distances between two batches of prototype sets.
    Args:
        inp: (n_input_samples, n_proto_input, feature_dim)
        out: (n_output_samples, n_proto_output, feature_dim)
    Returns:
        dist_mat: (n_input_samples, n_output_samples, n_proto_input, n_proto_output)
            Euclidean distances between all pairs of input and output prototypes
    """
    n_inp_samples, p_inp, _ = inp.shape
    n_out_samples, p_out, _ = out.shape
    dist_mat = torch.zeros((n_inp_samples, n_out_samples, p_inp, p_out), device=inp.device)
    for i in range(n_inp_samples):
        for j in range(n_out_samples):
            # ot.dist returns the (p_inp x p_out) pairwise distance matrix
            dist = ot.dist(inp[i], out[j], metric='euclidean')
            dist_mat[i, j] = dist
    return dist_mat
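
# A minimal shape-check sketch (hypothetical `_demo_*` helper, not used by the
# module): the double loop above should match a broadcasted torch.cdist call.
def _demo_compute_distance_euclidean():
    inp = torch.randn(2, 3, 8)   # (n_input_samples, n_proto_input, feature_dim)
    out = torch.randn(4, 5, 8)   # (n_output_samples, n_proto_output, feature_dim)
    d_loop = compute_distance_euclidean(inp, out)              # (2, 4, 3, 5)
    d_cdist = torch.cdist(inp.unsqueeze(1), out.unsqueeze(0))  # same distances, vectorized
    assert torch.allclose(d_loop, d_cdist, atol=1e-4)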
def sinkhorn(dot, mask=None, eps=1e-03, return_kernel=False, max_iter=100):
"""
dot: n x in_size x out_size
mask: n x in_size
output: n x in_size x out_size
"""
n, in_size, out_size = dot.shape
if return_kernel:
K = torch.exp(dot / eps)
else:
K = dot
# K: n x in_size x out_size
u = K.new_ones((n, in_size))
v = K.new_ones((n, out_size))
    # Row marginal: each row of the plan carries mass out_size / in_size,
    # so that every column of the plan sums to one
    a = float(out_size / in_size)
    if mask is not None:
        mask = mask.float()
        # With masking, only the unmasked positions share the row mass
        a = out_size / mask.sum(1, keepdim=True)
    for _ in range(max_iter):
        # Alternately rescale rows (u) and columns (v) to match the marginals
        u = a / torch.bmm(K, v.view(n, out_size, 1)).view(n, in_size)
if mask is not None:
u = u * mask
v = 1. / torch.bmm(u.view(n, 1, in_size), K).view(n, out_size)
K = u.view(n, in_size, 1) * (K * v.view(n, 1, out_size))
if return_kernel:
K = K / out_size
return (K * dot).sum(dim=[1, 2])
return K
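
# A minimal usage sketch (hypothetical `_demo_*` helper): after the iterations,
# each column of the plan sums to ~1 and each row to ~out_size / in_size.
def _demo_sinkhorn():
    dot = torch.randn(2, 6, 4)
    T = sinkhorn(torch.exp(dot / 1.0), eps=1.0)
    assert torch.allclose(T.sum(dim=1), torch.ones(2, 4), atol=1e-4)
    assert torch.allclose(T.sum(dim=2), torch.full((2, 6), 4 / 6), atol=1e-4)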
def log_sinkhorn(K, mask=None, eps=1.0, return_kernel=False, max_iter=100):
"""
dot: n x in_size x out_size
mask: n x in_size
output: n x in_size x out_size
"""
batch_size, in_size, out_size = K.shape
def min_eps(u, v, dim):
Z = (K + u.view(batch_size, in_size, 1) + v.view(batch_size, 1, out_size)) / eps
return -torch.logsumexp(Z, dim=dim)
# K: batch_size x in_size x out_size
u = K.new_zeros((batch_size, in_size))
v = K.new_zeros((batch_size, out_size))
    # Log of the row marginal out_size / in_size (cf. `a` in `sinkhorn`)
    a = torch.ones_like(u).fill_(out_size / in_size)
    if mask is not None:
        a = out_size / mask.float().sum(1, keepdim=True)
    a = torch.log(a)
    for _ in range(max_iter):
        # Log-domain analogues of the u and v rescalings in `sinkhorn`
        u = eps * (a + min_eps(u, v, dim=-1)) + u
        if mask is not None:
            # `mask` must be boolean here; masked positions get ~zero mass
            u = u.masked_fill(~mask, -1e8)
        v = eps * min_eps(u, v, dim=1) + v
if return_kernel:
output = torch.exp(
(K + u.view(batch_size, in_size, 1) + v.view(batch_size, 1, out_size)) / eps)
output = output / out_size
return (output * K).sum(dim=[1, 2])
K = torch.exp(
(K + u.view(batch_size, in_size, 1) + v.view(batch_size, 1, out_size)) / eps)
return K
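
# A minimal consistency sketch (hypothetical `_demo_*` helper): the log-domain
# solver should agree with the standard one on well-conditioned inputs.
def _demo_log_sinkhorn():
    dot = torch.randn(2, 6, 4)
    T_log = log_sinkhorn(dot, eps=1.0)
    T_std = sinkhorn(torch.exp(dot / 1.0), eps=1.0)
    assert torch.allclose(T_log, T_std, atol=1e-4)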
def multihead_attn(input, weight, mask=None, eps=1.0, return_kernel=False, distance='euclidean',
max_iter=100, log_domain=False, position_filter=None):
"""Compute the attention weight using Sinkhorn OT
input: n x in_size x in_dim
mask: n x in_size
weight: m x out_size x in_dim (m: number of heads/ref)
output: n x out_size x m x in_size
"""
n, in_size, in_dim = input.shape
m, out_size = weight.shape[:-1]
    # The inner product tends to be numerically unstable; default to Euclidean distance
if distance == 'euclidean':
K = compute_distance_euclidean(input, weight)
elif distance == 'inner':
K = torch.tensordot(input, weight, dims=[[-1], [-1]])
K = K.permute(0, 2, 1, 3)
else:
raise NotImplementedError(f"Not implemented for {distance}")
if position_filter is not None:
K = position_filter * K
# K: n x m x in_size x out_size
K = K.reshape(-1, in_size, out_size)
# K: nm x in_size x out_size
    if mask is not None:
        # Replicate the mask per head so it lines up with the flattened batch
        mask = mask.repeat_interleave(m, dim=0)
if log_domain:
K = log_sinkhorn(K, mask, eps, return_kernel=return_kernel, max_iter=max_iter)
    else:
        if not return_kernel:
            # `sinkhorn` expects the pre-exponentiated kernel when return_kernel=False
            K = torch.exp(K / eps)
        K = sinkhorn(K, mask, eps, return_kernel=return_kernel, max_iter=max_iter)
# K: nm x in_size x out_size
if return_kernel:
return K.reshape(n, m)
K = K.reshape(n, m, in_size, out_size)
if position_filter is not None:
K = position_filter * K
K = K.permute(0, 3, 1, 2).contiguous()
return K
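
# A minimal shape sketch (hypothetical `_demo_*` helper): three references of
# four prototypes each, attending over batches of six input elements.
def _demo_multihead_attn():
    x = torch.randn(2, 6, 8)   # n x in_size x in_dim
    w = torch.randn(3, 4, 8)   # m x out_size x in_dim
    attn_weights = multihead_attn(x, w, eps=1.0, max_iter=10)
    assert attn_weights.shape == (2, 4, 3, 6)  # n x out_size x m x in_size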
def wasserstein_barycenter(x, c, eps=1.0, max_iter=100, sinkhorn_iter=50, log_domain=False):
    """
    Refine the centroids c as Wasserstein barycenters of the samples in x.
    x: n x in_size x in_dim
    c: out_size x in_dim
    """
    prev_c = c
    for i in range(max_iter):
        # Transport plans between each sample and the current centroids,
        # treating c as a single reference (m = 1)
        T = multihead_attn(x, c.unsqueeze(0), eps=eps, log_domain=log_domain,
                           max_iter=sinkhorn_iter).squeeze(2)
        # T: n x out_size x in_size
        # Damped update: average the barycentric projections over the batch
        c = 0.5 * c + 0.5 * torch.bmm(T, x).mean(dim=0) / math.sqrt(c.shape[0])
        # Project the centroids back onto the unit sphere
        c = c / c.norm(dim=-1, keepdim=True).clamp(min=1e-06)
        if ((c - prev_c) ** 2).sum() < 1e-06:
            break
        prev_c = c
    return c
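
# A minimal usage sketch (hypothetical `_demo_*` helper; unit-norm inputs to
# match the spherical projection inside the loop).
def _demo_wasserstein_barycenter():
    x = torch.randn(5, 6, 8)
    x = x / x.norm(dim=-1, keepdim=True)
    c0 = torch.randn(4, 8)
    c0 = c0 / c0.norm(dim=-1, keepdim=True)
    c = wasserstein_barycenter(x, c0, eps=1.0, max_iter=10)
    assert c.shape == (4, 8)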
def wasserstein_kmeans(x, n_clusters, out_size, eps=1.0, block_size=None, max_iter=100,
sinkhorn_iter=50, wb=False, verbose=True, log_domain=False, use_cuda=False):
"""
x: n x in_size x in_dim
output: n_clusters x out_size x in_dim
out_size <= in_size
"""
n, in_size, in_dim = x.shape
if n_clusters == 1:
if use_cuda:
x = x.cuda()
clusters = spherical_kmeans(x.view(-1, in_dim), out_size, block_size=block_size)
        if wb:
            clusters = wasserstein_barycenter(x, clusters, eps=0.1, log_domain=False)
        return clusters.unsqueeze(0)
    ## initialization: seed each cluster with the first out_size elements of a random sample
    indices = torch.randperm(n)[:n_clusters]
    clusters = x[indices, :out_size, :].clone()
if use_cuda:
clusters = clusters.cuda()
wass_sim = x.new_empty(n)
assign = x.new_empty(n, dtype=torch.long)
if block_size is None or block_size == 0:
block_size = n
prev_sim = float('inf')
    for n_iter in range(max_iter):
        # Assignment step: score every sample against every cluster, in blocks
        for i in range(0, n, block_size):
            end_i = min(i + block_size, n)
            x_batch = x[i: end_i]
if use_cuda:
x_batch = x_batch.cuda()
            tmp_sim = multihead_attn(x_batch, clusters, eps=eps, return_kernel=True,
                                     max_iter=sinkhorn_iter, log_domain=log_domain)
            tmp_sim = tmp_sim.cpu()
            # Assign each sample to its highest-scoring cluster
            wass_sim[i: end_i], assign[i: end_i] = tmp_sim.max(dim=-1)
del x_batch
sim = wass_sim.mean()
if verbose and (n_iter + 1) % 10 == 0:
print("Wasserstein spherical kmeans iter {}, objective value {}".format(
n_iter + 1, sim))
        # Update step: recompute each cluster from its assigned samples
        for j in range(n_clusters):
            index = assign == j
            if index.sum() == 0:
                # Empty cluster: reseed it with the worst-matched sample and bump
                # that sample's score so the next reseed picks a different one
                idx = wass_sim.argmin()
                clusters[j].copy_(x[idx, :out_size, :])
                wass_sim[idx] = 1
else:
xj = x[index]
if use_cuda:
xj = xj.cuda()
                c = spherical_kmeans(xj.view(-1, in_dim), out_size, block_size=block_size, verbose=False)
                if wb:
                    # Optionally refine the centroids as a Wasserstein barycenter
                    c = wasserstein_barycenter(xj, c, eps=0.001, log_domain=True, sinkhorn_iter=50)
                clusters[j] = c
        # Stop when the relative change of the objective becomes negligible
        if torch.abs(prev_sim - sim) / sim.clamp(min=1e-10) < 1e-6:
            break
prev_sim = sim
return clusters
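
# A minimal end-to-end sketch (hypothetical `_demo_*` helper, unit-norm data;
# assumes `spherical_kmeans` from .utils is importable in this environment).
def _demo_wasserstein_kmeans():
    x = torch.randn(10, 6, 8)
    x = x / x.norm(dim=-1, keepdim=True)
    clusters = wasserstein_kmeans(x, n_clusters=3, out_size=4, eps=1.0,
                                  max_iter=5, sinkhorn_iter=10, verbose=False)
    assert clusters.shape == (3, 4, 8)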