src/mil_models/OT/ckn/utils.py

# -*- coding: utf-8 -*-
import math
import random

import numpy as np
import torch

EPS = 1e-4


def gaussian_filter_1d(size, sigma=None):
    """Create a 1D Gaussian filter of length `size`, normalized to sum to 1."""
    if size == 1:
        return torch.ones(1)
    if sigma is None:
        sigma = (size - 1.) / (2. * math.sqrt(2))
    m = float((size - 1) // 2)
    filt = torch.arange(-m, m + 1)
    filt = torch.exp(-filt.pow(2) / (2. * sigma * sigma))
    return filt / torch.sum(filt)


def init_kmeans(x, n_clusters, n_local_trials=None, use_cuda=False):
    """k-means++ style seeding: distances are 1 minus dot products
    (cosine distance when the rows of x are unit-normalized)."""
    n_samples, n_features = x.size()
    clusters = torch.Tensor(n_clusters, n_features)
    if use_cuda:
        clusters = clusters.cuda()
    if n_local_trials is None:
        n_local_trials = 2 + int(np.log(n_clusters))

    # Pick the first center uniformly at random.
    clusters[0] = x[np.random.randint(n_samples)]
    closest_dist_sq = 1 - clusters[[0]].mm(x.t())
    closest_dist_sq = closest_dist_sq.view(-1)
    current_pot = closest_dist_sq.sum()

    for c in range(1, n_clusters):
        # Sample candidate centers with probability proportional to their
        # distance to the closest existing center.
        rand_vals = np.random.random_sample(n_local_trials) * current_pot
        candidate_ids = np.searchsorted(closest_dist_sq.cumsum(-1), rand_vals)
        distance_to_candidates = 1 - x[candidate_ids].mm(x.t())

        best_candidate = None
        best_pot = None
        best_dist_sq = None
        for trial in range(n_local_trials):
            # Compute potential when including center candidate
            new_dist_sq = torch.min(closest_dist_sq,
                                    distance_to_candidates[trial])
            new_pot = new_dist_sq.sum()

            # Store result if it is the best local trial so far
            if (best_candidate is None) or (new_pot < best_pot):
                best_candidate = candidate_ids[trial]
                best_pot = new_pot
                best_dist_sq = new_dist_sq

        clusters[c] = x[best_candidate]
        current_pot = best_pot
        closest_dist_sq = best_dist_sq

    return clusters


def spherical_kmeans(x, n_clusters, max_iters=100, verbose=True,
                     init=None, eps=1e-4):
    """Spherical kmeans
    Args:
        x (Tensor n_samples x n_features): data points
        n_clusters (int): number of clusters
    """
    use_cuda = x.is_cuda
    n_samples, n_features = x.size()
    if init == "kmeans++":
        print(init)
        clusters = init_kmeans(x, n_clusters, use_cuda=use_cuda)
    else:
        indices = torch.randperm(n_samples)[:n_clusters]
        if use_cuda:
            indices = indices.cuda()
        clusters = x[indices]

    prev_sim = np.inf
    for n_iter in range(max_iters):
        # assign data points to clusters
        cos_sim = x.mm(clusters.t())
        tmp, assign = cos_sim.max(dim=-1)
        sim = tmp.mean()
        if (n_iter + 1) % 10 == 0 and verbose:
            print("Spherical kmeans iter {}, objective value {}".format(
                n_iter + 1, sim))

        # update clusters
        for j in range(n_clusters):
            index = assign == j
            if index.sum() == 0:
                # clusters[j] = x[random.randrange(n_samples)]
                # reseed an empty cluster with the worst-assigned point
                # (lowest similarity to its current center)
                idx = tmp.argmin()
                clusters[j] = x[idx]
                tmp[idx] = 1  # prevent reusing the same point for another empty cluster
            else:
                xj = x[index]
                c = xj.mean(0)
                clusters[j] = c / c.norm()

        # stop when the relative change of the objective is negligible
        if torch.abs(prev_sim - sim) / (torch.abs(sim) + 1e-20) < 1e-6:
            break
        prev_sim = sim
    return clusters
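
# Usage sketch (illustrative only; the shapes and hyper-parameters below are
# assumptions, not values used elsewhere in the codebase):
#
#     x = normalize_(torch.randn(1000, 64))   # unit-norm rows, as expected here
#     clusters = spherical_kmeans(x, n_clusters=32, init="kmeans++")
#     # clusters is a 32 x 64 tensor of unit-norm centroids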


def normalize_(x, p=2, dim=-1):
    """In-place Lp-normalization of x along dim; norms below EPS are clamped
    to avoid division by zero."""
    norm = x.norm(p=p, dim=dim, keepdim=True)
    x.div_(norm.clamp(min=EPS))
    return x


def flip(x, dim=-1):
    """Reverse a tensor along given axis
    can be removed later when Pytorch updated
    """
    reverse_indices = torch.arange(x.size(dim) - 1, -1, -1)
    reverse_indices = reverse_indices.type_as(x.data).long()
    return x.index_select(dim=dim, index=reverse_indices)


def proj_on_simplex(x, axis=0, r=1., inplace=True):
    """Euclidean projection of x onto the simplex {v: v >= 0, sum(v) = r}
    along the given axis (the threshold broadcasting below assumes axis=0
    of a 2D tensor)."""
    d = x.size(axis)
    mu, _ = torch.sort(x, dim=axis, descending=True)
    diag = torch.cumsum(mu, dim=axis) - r
    theta = diag / torch.arange(1., d + 1).view(-1, 1).expand_as(diag)
    indices = torch.sum((mu > theta).long(), dim=axis, keepdim=True) - 1
    theta = torch.gather(theta, dim=axis, index=indices)
    if inplace:
        x.add_(-theta).clamp_(min=0)
        return x
    return torch.clamp(x - theta, min=0)
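

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): checks
    # the basic invariants of the helpers above on small random inputs.
    torch.manual_seed(0)

    filt = gaussian_filter_1d(5)
    print("gaussian filter sum:", float(filt.sum()))  # ~1.0

    x = normalize_(torch.randn(200, 16))
    clusters = spherical_kmeans(x, n_clusters=8, init="kmeans++", verbose=False)
    print("clusters shape:", tuple(clusters.shape))  # (8, 16)

    y = proj_on_simplex(torch.randn(4, 3), axis=0, inplace=False)
    print("column sums after simplex projection:", y.sum(dim=0))  # each ~1.0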