|
a |
|
b/src/layers.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
import torch.nn.functional as F |
|
|
4 |
import numpy as np |
|
|
5 |
|
|
|
6 |
class NBLoss(nn.Module): |
|
|
7 |
def __init__(self): |
|
|
8 |
super(NBLoss, self).__init__() |
|
|
9 |
|
|
|
10 |
def forward(self, x, mean, disp, scale_factor=1.0): |
|
|
11 |
eps = 1e-10 |
|
|
12 |
scale_factor = scale_factor[:, None] |
|
|
13 |
mean = mean * scale_factor |
|
|
14 |
|
|
|
15 |
t1 = torch.lgamma(disp+eps) + torch.lgamma(x+1.0) - torch.lgamma(x+disp+eps) |
|
|
16 |
t2 = (disp+x) * torch.log(1.0 + (mean/(disp+eps))) + (x * (torch.log(disp+eps) - torch.log(mean+eps))) |
|
|
17 |
result = t1 + t2 |
|
|
18 |
|
|
|
19 |
result = torch.mean(result) |
|
|
20 |
return result |
|
|
21 |
|
|
|
22 |
|
|
|
23 |
class ZINBLoss(nn.Module): |
|
|
24 |
def __init__(self): |
|
|
25 |
super(ZINBLoss, self).__init__() |
|
|
26 |
|
|
|
27 |
def forward(self, x, mean, disp, pi, scale_factor=1.0, ridge_lambda=0.0): |
|
|
28 |
eps = 1e-10 |
|
|
29 |
scale_factor = scale_factor[:, None] |
|
|
30 |
mean = mean * scale_factor |
|
|
31 |
|
|
|
32 |
t1 = torch.lgamma(disp+eps) + torch.lgamma(x+1.0) - torch.lgamma(x+disp+eps) |
|
|
33 |
t2 = (disp+x) * torch.log(1.0 + (mean/(disp+eps))) + (x * (torch.log(disp+eps) - torch.log(mean+eps))) |
|
|
34 |
nb_final = t1 + t2 |
|
|
35 |
|
|
|
36 |
nb_case = nb_final - torch.log(1.0-pi+eps) |
|
|
37 |
zero_nb = torch.pow(disp/(disp+mean+eps), disp) |
|
|
38 |
zero_case = -torch.log(pi + ((1.0-pi)*zero_nb)+eps) |
|
|
39 |
result = torch.where(torch.le(x, 1e-8), zero_case, nb_case) |
|
|
40 |
|
|
|
41 |
if ridge_lambda > 0: |
|
|
42 |
ridge = ridge_lambda*torch.square(pi) |
|
|
43 |
result += ridge |
|
|
44 |
|
|
|
45 |
result = torch.mean(result) |
|
|
46 |
return result |
|
|
47 |
|
|
|
48 |
|
|
|
49 |
class GaussianNoise(nn.Module): |
|
|
50 |
def __init__(self, sigma=0): |
|
|
51 |
super(GaussianNoise, self).__init__() |
|
|
52 |
self.sigma = sigma |
|
|
53 |
|
|
|
54 |
def forward(self, x): |
|
|
55 |
if self.training: |
|
|
56 |
x = x + self.sigma * torch.randn_like(x) |
|
|
57 |
return x |
|
|
58 |
|
|
|
59 |
|
|
|
60 |
class MeanAct(nn.Module): |
|
|
61 |
def __init__(self): |
|
|
62 |
super(MeanAct, self).__init__() |
|
|
63 |
|
|
|
64 |
def forward(self, x): |
|
|
65 |
return torch.clamp(torch.exp(x), min=1e-5, max=1e6) |
|
|
66 |
|
|
|
67 |
class DispAct(nn.Module): |
|
|
68 |
def __init__(self): |
|
|
69 |
super(DispAct, self).__init__() |
|
|
70 |
|
|
|
71 |
def forward(self, x): |
|
|
72 |
return torch.clamp(F.softplus(x), min=1e-4, max=1e4) |