[e66fb7]: / src / mil_models / OT / ckn / ops.py

Download this file

40 lines (33 with data), 1.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
class MatrixInverseSqrt(torch.autograd.Function):
    """Differentiable inverse square root of a symmetric positive definite matrix.

    With the eigendecomposition ``A = V diag(e) V^T`` the forward pass returns
    ``V diag(1 / (sqrt(max(e, 0)) + eps)) V^T``; ``eps`` regularises small or
    zero eigenvalues so the reciprocal stays finite.
    """

    @staticmethod
    def forward(ctx, input, eps=1e-2):
        """Compute the regularised inverse square root of ``input``.

        Args:
            ctx: autograd context; receives ``e_sqrt`` and ``v`` for backward.
            input: square 2-D tensor. Only its upper triangle is read by
                ``torch.linalg.eigh`` (``UPLO='U'``), so it is treated as
                symmetric.
            eps: value added to the square roots of the (clamped) eigenvalues.

        Returns:
            A tensor with the same shape and device as ``input``.
        """
        # Remember the caller's device: a bare ``.cuda()`` would move the
        # results to the *current* CUDA device, which is wrong when ``input``
        # lives on another GPU.
        device = input.device
        # Small matrices are decomposed on the CPU (presumably because that is
        # faster than the GPU solver for such sizes -- TODO confirm).
        if input.shape[0] < 300:
            input = input.cpu()
        e, v = torch.linalg.eigh(input, UPLO='U')
        # No-ops when the decomposition already ran on ``device``.
        e = e.to(device)
        v = v.to(device)
        # Negative eigenvalues (round-off on near-singular input) are clamped
        # to zero before the square root; ``clamp`` returns a fresh tensor, so
        # the in-place ops below never touch ``e`` itself.
        e_sqrt = e.clamp(min=0).sqrt_().add_(eps)
        ctx.e_sqrt = e_sqrt
        ctx.v = v
        # V diag(d) V^T: scaling the columns of V by d avoids materialising a
        # dense diagonal matrix and one extra matrix product.
        return (v * e_sqrt.reciprocal()).mm(v.t())

    @staticmethod
    def backward(ctx, grad_output):
        """Back-propagate through the inverse square root.

        With ``s`` the regularised roots stored by ``forward`` this evaluates
        ``-V (F * (V^T G V)) V^T`` where ``F[i, j] = 1 / ((s_i + s_j) s_i s_j)``
        and ``G`` is ``grad_output``.

        NOTE(review): ``s`` already includes ``eps`` and the clamp is not
        differentiated, so for ``eps > 0`` this looks like the gradient of the
        regularised spectrum rather than an exact derivative -- confirm that
        this approximation is intended.

        Returns:
            ``(grad_input, None)``; ``eps`` is not a tensor and gets no gradient.
        """
        e_sqrt, v = ctx.e_sqrt, ctx.v
        s_row = e_sqrt.unsqueeze(0)  # shape (1, n): s_j along columns
        s_col = e_sqrt.unsqueeze(1)  # shape (n, 1): s_i along rows
        f = torch.reciprocal((s_row + s_col) * s_row * s_col)
        # Rotate the incoming gradient into the eigenbasis, weight it
        # element-wise, and rotate back.
        projected = v.t().mm(grad_output.mm(v))
        grad_input = -v.mm((f * projected).mm(v.t()))
        return grad_input, None
def matrix_inverse_sqrt(input, eps=1e-2):
    """Apply :class:`MatrixInverseSqrt` to ``input``.

    Thin functional entry point: forwards ``input`` and ``eps`` positionally
    to ``MatrixInverseSqrt.apply`` and returns its result unchanged.
    """
    inverse_root = MatrixInverseSqrt.apply(input, eps)
    return inverse_root