b/utils.py

import numpy as np

def train_test_split(adata, test_ratio=0.1, seed=1):
    """Splits adata into a training set and a test set.

    Args:
        adata: the dataset to be split.
        test_ratio: fraction of the observations in adata assigned to the test set.
        seed: random seed.

    Returns:
        The training set and the test set, both in AnnData format.
    """
    rng = np.random.default_rng(seed=seed)
    test_indices = rng.choice(adata.n_obs, size=int(test_ratio * adata.n_obs), replace=False)
    train_indices = list(set(range(adata.n_obs)).difference(test_indices))
    train_adata = adata[adata.obs_names[train_indices], :]
    test_adata = adata[adata.obs_names[test_indices], :]

    return train_adata, test_adata
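
# Usage sketch (illustrative; assumes `adata` is an AnnData object loaded elsewhere):
#
#     train_adata, test_adata = train_test_split(adata, test_ratio=0.1, seed=1)
#     # test_adata holds roughly 10% of the observations; the split is reproducible via `seed`.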

def calc_weight(
        epoch: int,
        n_epochs: int,
        cutoff_ratio: float = 0.,
        warmup_ratio: float = 1 / 3,
        min_weight: float = 0.,
        max_weight: float = 1e-7
) -> float:
    """Calculates the current weight of the KL term.

    Args:
        epoch: current epoch.
        n_epochs: the total number of epochs to train the model.
        cutoff_ratio: ratio of cutoff epochs (weight is set to zero) to n_epochs.
        warmup_ratio: ratio of warmup epochs to n_epochs.
        min_weight: minimum weight.
        max_weight: maximum weight.

    Returns:
        The current weight of the KL term.
    """
    fully_warmup_epoch = n_epochs * warmup_ratio

    if epoch < n_epochs * cutoff_ratio:
        return 0.
    if warmup_ratio:
        return max(min(1., epoch / fully_warmup_epoch) * max_weight, min_weight)
    else:
        return max_weight
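
# Example of the schedule (illustrative numbers): with n_epochs=300 and the defaults
# above, the weight ramps linearly up to max_weight over the first
# n_epochs * warmup_ratio = 100 epochs and then stays flat:
#
#     calc_weight(0, 300)     # 0.0
#     calc_weight(50, 300)    # 5e-08
#     calc_weight(100, 300)   # 1e-07
#     calc_weight(250, 300)   # 1e-07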

import math
import os
import shutil
import sys
import time

import torch
import torch.distributions as dist
import torch.nn.functional as F


# Classes
class Constants(object):
    eta = 1e-6
    eps = 1e-8
    log2 = math.log(2)
    log2pi = math.log(2 * math.pi)
    logceilc = 88     # largest cuda v s.t. exp(v) < inf
    logfloorc = -104  # smallest cuda v s.t. exp(v) > 0


# https://stackoverflow.com/questions/14906764/how-to-redirect-stdout-to-both-file-and-console-with-scripting
class Logger(object):
    def __init__(self, filename, mode="a"):
        self.terminal = sys.stdout
        self.log = open(filename, mode)

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # This flush method is needed for Python 3 compatibility.
        # It handles the flush command by doing nothing.
        # You might want to specify some extra behavior here.
        pass
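
# Usage sketch (following the StackOverflow pattern linked above; `run.log` is a
# hypothetical filename): redirect stdout so that everything printed also lands
# in a log file.
#
#     sys.stdout = Logger('run.log')
#     print('this line goes to both the console and run.log')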

class Timer:
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.begin = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        self.elapsed = self.end - self.begin
        self.elapsedH = time.gmtime(self.elapsed)
        print('====> [{}] Time: {:7.3f}s or {}'
              .format(self.name,
                      self.elapsed,
                      time.strftime("%H:%M:%S", self.elapsedH)))
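
# Usage sketch: time any block with the context manager; the elapsed time is
# printed on exit and kept on the instance (`train_one_epoch` is a placeholder).
#
#     with Timer('epoch 1') as t:
#         train_one_epoch()
#     # prints e.g. "====> [epoch 1] Time:  12.345s or 00:00:12"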

# Functions
def save_vars(vs, filepath):
    """
    Saves variables to the given filepath in a safe manner.
    """
    if os.path.exists(filepath):
        shutil.copyfile(filepath, '{}.old'.format(filepath))
    torch.save(vs, filepath)


def save_model(model, filepath):
    """
    To load a saved model, simply use
    `model.load_state_dict(torch.load('path-to-saved-model'))`.
    """
    save_vars(model.state_dict(), filepath)
    # if hasattr(model, 'vaes'):
    #     for vae in model.vaes:
    #         fdir, fext = os.path.splitext(filepath)
    #         save_vars(vae.state_dict(), fdir + '_' + vae.modelName + fext)
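
# Round-trip sketch (hypothetical path): save_vars first backs up any existing
# checkpoint to `<filepath>.old`, then overwrites `<filepath>`.
#
#     save_model(model, 'run/model.pt')
#     model.load_state_dict(torch.load('run/model.pt'))  # restore, as noted in the docstring above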

def is_multidata(dataB):
    return isinstance(dataB, (list, tuple))


def unpack_data(dataB, device='cuda'):
    # dataB :: (Tensor, Idx) | [(Tensor, Idx)]
    """Unpacks the data batch object in an appropriate manner to extract data."""
    if is_multidata(dataB):
        if torch.is_tensor(dataB[0]):
            if torch.is_tensor(dataB[1]):
                return dataB[0].to(device)  # mnist, svhn, cubI
            elif is_multidata(dataB[1]):
                return dataB[0].to(device), dataB[1][0].to(device)  # cubISft
            else:
                raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB[1])))

        elif is_multidata(dataB[0]):
            return [d.to(device) for d in list(zip(*dataB))[0]]  # mnist-svhn, cubIS
        else:
            raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB[0])))
    elif torch.is_tensor(dataB):
        return dataB.to(device)
    else:
        raise RuntimeError('Invalid data format {} -- check your dataloader!'.format(type(dataB)))


def get_mean(d, K=100):
    """
    Extracts the `mean` parameter of the given distribution.
    If the attribute is not available, estimates it from samples.
    """
    try:
        mean = d.mean
    except NotImplementedError:
        samples = d.rsample(torch.Size([K]))
        mean = samples.mean(0)
    return mean


def log_mean_exp(value, dim=0, keepdim=False):
    return torch.logsumexp(value, dim, keepdim=keepdim) - math.log(value.size(dim))
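
# log_mean_exp computes log((1/K) * sum_k exp(value_k)) along `dim` in a
# numerically stable way, e.g. for importance-weighted likelihood estimates.
# Small sanity check (illustrative):
#
#     x = torch.log(torch.tensor([1., 2., 3.]))
#     log_mean_exp(x)   # == log(2.), since mean([1, 2, 3]) == 2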

def kl_divergence(d1, d2, K=100):
    """Computes the closed-form KL divergence if available, else returns a Monte Carlo estimate."""
    if (type(d1), type(d2)) in torch.distributions.kl._KL_REGISTRY:
        return torch.distributions.kl_divergence(d1, d2)
    else:
        samples = d1.rsample(torch.Size([K]))
        return (d1.log_prob(samples) - d2.log_prob(samples)).mean(0)
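
# Example (illustrative): two univariate Normals have a registered closed form,
# so the first branch is taken; distribution pairs without a registered KL fall
# back to the K-sample Monte Carlo estimate.
#
#     p = dist.Normal(0., 1.)
#     q = dist.Normal(1., 1.)
#     kl_divergence(p, q)   # tensor(0.5000); closed form (mu_p - mu_q)^2 / 2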

def vade_kld_uni(model, zs):
    n_centroids = model.params.n_centroids
    gamma, lgamma, mu_c, var_c, pi = model.get_gamma(zs)  # Constants.eta is added to pi and var_c inside get_gamma

    mu, var = model._qz_x_params
    mu_expand = mu.unsqueeze(2).expand(mu.size(0), mu.size(1), n_centroids)
    var_expand = var.unsqueeze(2).expand(var.size(0), var.size(1), n_centroids)
    lpz_c = -0.5 * torch.sum(gamma * torch.sum(math.log(2 * math.pi) +
                                               torch.log(var_c) +
                                               var_expand / var_c +
                                               (mu_expand - mu_c) ** 2 / var_c, dim=1), dim=1)  # log p(z|c)
    lpc = torch.sum(gamma * torch.log(pi), dim=1)  # log p(c); beware that log(pi) can be -inf
    lqz_x = -0.5 * torch.sum(1 + torch.log(var) + math.log(2 * math.pi), dim=1)  # log q(z|x), see the VaDE paper
    lqc_x = torch.sum(gamma * lgamma, dim=1)  # log q(c|x)

    kld = -lpz_c - lpc + lqz_x + lqc_x

    return kld
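
# Note on the terms above (following the VaDE ELBO): the returned quantity combines
# the four terms as kld = lqz_x + lqc_x - lpz_c - lpc, which corresponds to
#     KL(q(z, c | x) || p(z, c)) = E_q[ log q(z|x) + log q(c|x) - log p(z|c) - log p(c) ],
# with the expectation over clusters c taken via the responsibilities `gamma`.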

def vade_kld(model, zs, r):
    n_centroids = model.params.n_centroids
    gamma, lgamma, mu_c, var_c, pi = model.get_gamma(zs)  # Constants.eta is added to pi and var_c inside get_gamma

    mu, var = model.vaes[r]._qz_x_params
    mu_expand = mu.unsqueeze(2).expand(mu.size(0), mu.size(1), n_centroids)
    var_expand = var.unsqueeze(2).expand(var.size(0), var.size(1), n_centroids)
    lpz_c = -0.5 * torch.sum(gamma * torch.sum(math.log(2 * math.pi) +
                                               torch.log(var_c) +
                                               var_expand / var_c +
                                               (mu_expand - mu_c) ** 2 / var_c, dim=1), dim=1)  # log p(z|c)
    lpc = torch.sum(gamma * torch.log(pi), dim=1)  # log p(c); beware that log(pi) can be -inf
    lqz_x = -0.5 * torch.sum(1 + torch.log(var) + math.log(2 * math.pi), dim=1)  # log q(z|x), see the VaDE paper
    lqc_x = torch.sum(gamma * lgamma, dim=1)  # log q(c|x)

    kld = -lpz_c - lpc + lqz_x + lqc_x

    return kld


def pdist(sample_1, sample_2, eps=1e-5):
    """Computes the matrix of all pairwise distances. Code adapted from the
    torch-two-sample library (with batching added). You can find the original
    implementation of this function here:
    https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/util.py

    Arguments
    ---------
    sample_1 : torch.Tensor
        The first sample, of shape ``(batch_size, n_1, d)`` or ``(n_1, d)``.
    sample_2 : torch.Tensor
        The second sample, of shape ``(batch_size, n_2, d)`` or ``(n_2, d)``.
    eps : float
        Small constant added before the square root for numerical stability.

    Returns
    -------
    torch.Tensor
        Matrix of shape (batch_size, n_1, n_2) whose [i, j]-th entry is the
        Euclidean distance ``|| sample_1[:, i, :] - sample_2[:, j, :] ||_2``.
    """
    if len(sample_1.shape) == 2:
        sample_1, sample_2 = sample_1.unsqueeze(0), sample_2.unsqueeze(0)
    B, n_1, n_2 = sample_1.size(0), sample_1.size(1), sample_2.size(1)
    norms_1 = torch.sum(sample_1 ** 2, dim=-1, keepdim=True)
    norms_2 = torch.sum(sample_2 ** 2, dim=-1, keepdim=True)
    norms = (norms_1.expand(B, n_1, n_2)
             + norms_2.transpose(1, 2).expand(B, n_1, n_2))
    distances_squared = norms - 2 * sample_1.matmul(sample_2.transpose(1, 2))
    return torch.sqrt(eps + torch.abs(distances_squared)).squeeze()  # batch x n_1 x n_2
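
# Quick check (illustrative): for 2-D points the result matches the usual
# Euclidean distances, up to the small eps smoothing term.
#
#     a = torch.tensor([[0., 0.], [1., 0.]])   # (n_1=2, d=2)
#     b = torch.tensor([[0., 3.]])             # (n_2=1, d=2)
#     pdist(a, b)                              # approximately tensor([3.0000, 3.1623])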

def NN_lookup(emb_h, emb, data):
    indices = pdist(emb.to(emb_h.device), emb_h).argmin(dim=0)
    # indices = torch.tensor(cosine_similarity(emb, emb_h.cpu().numpy()).argmax(0)).to(emb_h.device).squeeze()
    return data[indices]


class FakeCategorical(dist.Distribution):
    support = dist.constraints.real
    has_rsample = True

    def __init__(self, locs):
        self.logits = locs
        self._batch_shape = self.logits.shape

    @property
    def mean(self):
        return self.logits

    def sample(self, sample_shape=torch.Size()):
        with torch.no_grad():
            return self.rsample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.logits.expand([*sample_shape, *self.logits.shape]).contiguous()

    def log_prob(self, value):
        # value of shape (K, B, D)
        lpx_z = -F.cross_entropy(input=self.logits.view(-1, self.logits.size(-1)),
                                 target=value.expand(self.logits.size()[:-1]).long().view(-1),
                                 reduction='none',
                                 ignore_index=0)

        return lpx_z.view(*self.logits.shape[:-1])
        # The cross-entropy loss ($\sum_i -gt_i \log(p_i)$, with most gt_i = 0) inevitably
        # sums over the word-embedding (vocabulary) dimension; we adopt the operationally
        # equivalent choice of summing over the sentence dimension in the objective.
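
# Sketch of the intent (illustrative shapes): the logits act as a deterministic
# "sample", and log_prob scores integer targets with a per-token cross-entropy
# (index 0 is treated as padding and ignored).
#
#     logits = torch.randn(4, 7, 100)           # (batch, sentence_len, vocab)
#     tokens = torch.randint(0, 100, (4, 7))    # integer token ids
#     px_z = FakeCategorical(logits)
#     px_z.log_prob(tokens).shape               # torch.Size([4, 7])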

# from github Bjarten/early-stopping-pytorch
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                           Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, runPath):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, runPath)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, runPath)  # runPath argument added
            self.counter = 0

    def save_checkpoint(self, val_loss, model, runPath):
        '''Saves the model when the validation loss decreases.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        # torch.save(model.state_dict(), 'checkpoint.pt')
        save_model(model, runPath + '/model.rar')  # ported from mmvae
        self.val_loss_min = val_loss
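
# Usage sketch in a training loop (`validate` and `run_path` are placeholders
# for illustration):
#
#     early_stopping = EarlyStopping(patience=7, verbose=True)
#     for epoch in range(n_epochs):
#         val_loss = validate(model)
#         early_stopping(val_loss, model, run_path)
#         if early_stopping.early_stop:
#             print('Early stopping')
#             break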

class EarlyStopping_nosave:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                           Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = -1e9
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model, runPath):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0