# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
class MatrixInverseSqrt(torch.autograd.Function):
    """Differentiable inverse square root of a symmetric matrix.

    For a symmetric 2-D ``input`` with eigendecomposition
    ``V diag(lam) V^T`` the forward pass returns
    ``V diag(1 / (sqrt(max(lam, 0)) + eps)) V^T``.

    Only the upper triangle of ``input`` is read (``UPLO='U'``), so the
    caller is expected to pass a symmetric (positive semi-definite) matrix.
    """

    # Matrices with fewer rows than this are decomposed on the CPU.
    # NOTE(review): presumably because small eigendecompositions are faster
    # there than on the GPU -- confirm before tuning.
    _CPU_EIGH_MAX_DIM = 300

    @staticmethod
    def forward(ctx, input, eps=1e-2):
        """Compute the regularised inverse square root of ``input``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            input: square, symmetric 2-D tensor.
            eps: value added to every square-rooted eigenvalue before it is
                inverted, bounding the result when eigenvalues approach 0.

        Returns:
            Tensor with the same shape and on the same device as ``input``.
        """
        device = input.device
        small = input.shape[0] < MatrixInverseSqrt._CPU_EIGH_MAX_DIM
        work = input.cpu() if small else input
        e, v = torch.linalg.eigh(work, UPLO='U')
        # Return to the *original* device (not merely the default CUDA
        # device) so multi-GPU inputs get their result where they live.
        e = e.to(device)
        v = v.to(device)

        # Negative eigenvalues (numerical noise) are clipped to zero so the
        # square root is defined; ``clamp`` returns a new tensor, which makes
        # the in-place ops below safe.
        e_sqrt = e.clamp(min=0).sqrt_().add_(eps)
        ctx.save_for_backward(e_sqrt, v)

        # V diag(1 / e_sqrt) V^T; scaling the columns of V avoids building
        # the dense diagonal matrix.
        return (v * e_sqrt.reciprocal()).mm(v.t())

    @staticmethod
    def backward(ctx, grad_output):
        """Back-propagate ``grad_output`` through the inverse square root.

        With ``s = e_sqrt`` (the regularised square-rooted eigenvalues) the
        gradient is ``-V (F * (V^T G V)) V^T`` where
        ``F[i, j] = 1 / ((s_i + s_j) * s_i * s_j)``.

        NOTE(review): ``s`` already includes ``eps``, so this matches the
        analytic derivative of the forward pass exactly only for
        ``eps == 0``; for ``eps > 0`` it is a regularised approximation that
        stays finite for zero eigenvalues.

        Returns:
            ``(grad_input, None)`` -- ``eps`` receives no gradient.
        """
        e_sqrt, v = ctx.saved_tensors
        s_row = e_sqrt.unsqueeze(0)  # s_j along columns
        s_col = e_sqrt.unsqueeze(1)  # s_i along rows
        f = torch.reciprocal((s_row + s_col) * s_row * s_col)
        # Rotate the incoming gradient into the eigenbasis, scale, rotate back.
        projected = v.t().mm(grad_output.mm(v))
        grad_input = -v.mm((f * projected).mm(v.t()))
        return grad_input, None
def matrix_inverse_sqrt(input, eps=1e-2):
    """Functional entry point for :class:`MatrixInverseSqrt`.

    Args:
        input: tensor forwarded unchanged to ``MatrixInverseSqrt.apply``.
        eps: regularisation value forwarded unchanged as the second argument.

    Returns:
        Whatever ``MatrixInverseSqrt.apply(input, eps)`` returns.
    """
    # Autograd functions are invoked through ``apply`` rather than by
    # instantiating the class.
    result = MatrixInverseSqrt.apply(input, eps)
    return result