Diff of /src/layers.py [000000] .. [ac720d]

Switch to unified view

a b/src/layers.py
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F
4
import numpy as np
5
6
class NBLoss(nn.Module):
    """Negative binomial (NB) negative log-likelihood loss.

    Computes the mean NB negative log-likelihood of observed counts ``x``
    under an NB distribution parameterized by ``mean`` and dispersion
    ``disp`` (theta).
    """

    def __init__(self):
        super(NBLoss, self).__init__()

    def forward(self, x, mean, disp, scale_factor=1.0):
        """Return the mean NB negative log-likelihood.

        Args:
            x: observed counts, shape (batch, features).
            mean: predicted NB mean, same shape as ``x``.
            disp: predicted NB dispersion (theta), same shape as ``x``.
            scale_factor: per-sample library-size factor; either a scalar
                or a 1-D tensor of length batch. Defaults to 1.0.

        Returns:
            Scalar tensor: mean NB NLL over all entries.
        """
        eps = 1e-10
        # Bug fix: the default scale_factor=1.0 is a plain float, and the
        # original `scale_factor[:, None]` raised TypeError on it. Coerce to
        # a tensor and only add the column axis for 1-D per-sample factors;
        # scalars broadcast as-is.
        scale_factor = torch.as_tensor(scale_factor, dtype=mean.dtype, device=mean.device)
        if scale_factor.dim() == 1:
            scale_factor = scale_factor[:, None]
        mean = mean * scale_factor

        # NB log-pmf split into log-Gamma terms (t1) and log-ratio terms (t2);
        # eps keeps lgamma/log away from 0.
        t1 = torch.lgamma(disp+eps) + torch.lgamma(x+1.0) - torch.lgamma(x+disp+eps)
        t2 = (disp+x) * torch.log(1.0 + (mean/(disp+eps))) + (x * (torch.log(disp+eps) - torch.log(mean+eps)))
        result = t1 + t2

        result = torch.mean(result)
        return result
21
22
23
class ZINBLoss(nn.Module):
    """Zero-inflated negative binomial (ZINB) negative log-likelihood loss.

    Mixes an NB component with a point mass at zero weighted by the dropout
    probability ``pi``; optionally adds a ridge penalty on ``pi``.
    """

    def __init__(self):
        super(ZINBLoss, self).__init__()

    def forward(self, x, mean, disp, pi, scale_factor=1.0, ridge_lambda=0.0):
        """Return the mean ZINB negative log-likelihood.

        Args:
            x: observed counts, shape (batch, features).
            mean: predicted NB mean, same shape as ``x``.
            disp: predicted NB dispersion (theta), same shape as ``x``.
            pi: predicted dropout (zero-inflation) probability in [0, 1],
                same shape as ``x``.
            scale_factor: per-sample library-size factor; a scalar or a
                1-D tensor of length batch. Defaults to 1.0.
            ridge_lambda: coefficient of an L2 penalty on ``pi`` (0 disables).

        Returns:
            Scalar tensor: mean ZINB NLL (plus ridge term) over all entries.
        """
        eps = 1e-10
        # Bug fix: the default scale_factor=1.0 is a plain float, and the
        # original `scale_factor[:, None]` raised TypeError on it. Coerce to
        # a tensor and only add the column axis for 1-D per-sample factors.
        scale_factor = torch.as_tensor(scale_factor, dtype=mean.dtype, device=mean.device)
        if scale_factor.dim() == 1:
            scale_factor = scale_factor[:, None]
        mean = mean * scale_factor

        # Plain NB NLL (same decomposition as NBLoss).
        t1 = torch.lgamma(disp+eps) + torch.lgamma(x+1.0) - torch.lgamma(x+disp+eps)
        t2 = (disp+x) * torch.log(1.0 + (mean/(disp+eps))) + (x * (torch.log(disp+eps) - torch.log(mean+eps)))
        nb_final = t1 + t2

        # Nonzero counts: NB likelihood weighted by (1 - pi).
        nb_case = nb_final - torch.log(1.0-pi+eps)
        # P(X=0) under NB: (disp / (disp + mean))^disp.
        zero_nb = torch.pow(disp/(disp+mean+eps), disp)
        # Zero counts: mixture of dropout mass pi and NB zero probability.
        zero_case = -torch.log(pi + ((1.0-pi)*zero_nb)+eps)
        # Counts are floats; treat anything <= 1e-8 as an exact zero.
        result = torch.where(torch.le(x, 1e-8), zero_case, nb_case)

        if ridge_lambda > 0:
            # L2 regularization discouraging large dropout probabilities.
            ridge = ridge_lambda*torch.square(pi)
            result += ridge

        result = torch.mean(result)
        return result
47
48
49
class GaussianNoise(nn.Module):
    """Additive Gaussian-noise layer, active only while training.

    In training mode, adds zero-mean Gaussian noise with standard
    deviation ``sigma`` to the input; in eval mode it is the identity.
    """

    def __init__(self, sigma=0):
        super(GaussianNoise, self).__init__()
        self.sigma = sigma

    def forward(self, x):
        # Identity in eval mode — noise is a training-time regularizer only.
        if not self.training:
            return x
        noise = self.sigma * torch.randn_like(x)
        return x + noise
58
59
60
class MeanAct(nn.Module):
    """Exponential activation for the NB mean, clamped to a stable range.

    Maps the raw network output through ``exp`` and clips the result to
    [1e-5, 1e6] to keep the downstream loss numerically well-behaved.
    """

    def __init__(self):
        super(MeanAct, self).__init__()

    def forward(self, x):
        out = torch.exp(x)
        return out.clamp(min=1e-5, max=1e6)
66
67
class DispAct(nn.Module):
    """Softplus activation for the NB dispersion, clamped to a stable range.

    Maps the raw network output through ``softplus`` (guaranteeing
    positivity) and clips the result to [1e-4, 1e4].
    """

    def __init__(self):
        super(DispAct, self).__init__()

    def forward(self, x):
        out = F.softplus(x)
        return out.clamp(min=1e-4, max=1e4)