--- /dev/null
+++ b/src/layers.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class NBLoss(nn.Module):
+    def __init__(self):
+        super(NBLoss, self).__init__()
+
+    def forward(self, x, mean, disp, scale_factor=1.0):
+        """Negative log-likelihood of the negative binomial (NB) distribution.
+
+        x: observed counts; mean, disp: predicted NB mean and dispersion;
+        scale_factor: per-cell size factors (scalar or 1D tensor).
+        """
+        eps = 1e-10
+        # Accept either a scalar or a per-cell tensor of size factors.
+        if not torch.is_tensor(scale_factor):
+            scale_factor = torch.full((x.shape[0],), scale_factor,
+                                      dtype=mean.dtype, device=mean.device)
+        scale_factor = scale_factor[:, None]
+        mean = mean * scale_factor
+
+        t1 = torch.lgamma(disp+eps) + torch.lgamma(x+1.0) - torch.lgamma(x+disp+eps)
+        t2 = (disp+x) * torch.log(1.0 + (mean/(disp+eps))) + (x * (torch.log(disp+eps) - torch.log(mean+eps)))
+        result = t1 + t2
+
+        return torch.mean(result)
+
+
+class ZINBLoss(nn.Module):
+    def __init__(self):
+        super(ZINBLoss, self).__init__()
+
+    def forward(self, x, mean, disp, pi, scale_factor=1.0, ridge_lambda=0.0):
+        """Negative log-likelihood of the zero-inflated negative binomial (ZINB).
+
+        pi is the predicted dropout (zero-inflation) probability; ridge_lambda
+        adds an L2 penalty on pi to discourage spurious zero inflation.
+        """
+        eps = 1e-10
+        # Accept either a scalar or a per-cell tensor of size factors.
+        if not torch.is_tensor(scale_factor):
+            scale_factor = torch.full((x.shape[0],), scale_factor,
+                                      dtype=mean.dtype, device=mean.device)
+        scale_factor = scale_factor[:, None]
+        mean = mean * scale_factor
+
+        # NB negative log-likelihood (same two terms as in NBLoss).
+        t1 = torch.lgamma(disp+eps) + torch.lgamma(x+1.0) - torch.lgamma(x+disp+eps)
+        t2 = (disp+x) * torch.log(1.0 + (mean/(disp+eps))) + (x * (torch.log(disp+eps) - torch.log(mean+eps)))
+        nb_final = t1 + t2
+
+        # Non-zero counts must come from the NB component; zero counts can come
+        # from either the dropout mass pi or the NB component itself.
+        nb_case = nb_final - torch.log(1.0-pi+eps)
+        zero_nb = torch.pow(disp/(disp+mean+eps), disp)
+        zero_case = -torch.log(pi + ((1.0-pi)*zero_nb)+eps)
+        result = torch.where(torch.le(x, 1e-8), zero_case, nb_case)
+
+        if ridge_lambda > 0:
+            result += ridge_lambda * torch.square(pi)
+
+        return torch.mean(result)
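+
+# Hedged usage sketch (names like `counts` and `size_factors` are assumptions,
+# not defined in this file): the loss is applied per batch as
+#   loss = ZINBLoss()(x=counts, mean=mean, disp=disp, pi=pi,
+#                     scale_factor=size_factors, ridge_lambda=1.0)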
+
+
+class GaussianNoise(nn.Module):
+    def __init__(self, sigma=0):
+        super(GaussianNoise, self).__init__()
+        self.sigma = sigma
+    
+    def forward(self, x):
+        # Additive Gaussian noise, applied only during training
+        # (acts as a simple denoising regularizer).
+        if self.training:
+            x = x + self.sigma * torch.randn_like(x)
+        return x
+
+
+class MeanAct(nn.Module):
+    def __init__(self):
+        super(MeanAct, self).__init__()
+
+    def forward(self, x):
+        # Exponential activation for the NB mean, clamped for numerical stability.
+        return torch.clamp(torch.exp(x), min=1e-5, max=1e6)
+
+
+class DispAct(nn.Module):
+    def __init__(self):
+        super(DispAct, self).__init__()
+
+    def forward(self, x):
+        # Softplus activation for the NB dispersion, clamped to a stable range.
+        return torch.clamp(F.softplus(x), min=1e-4, max=1e4)
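+
+
+# Hedged example of wiring these activations into decoder output heads
+# (hidden_dim and n_genes are placeholders, not defined in this file):
+#   mean_head = nn.Sequential(nn.Linear(hidden_dim, n_genes), MeanAct())
+#   disp_head = nn.Sequential(nn.Linear(hidden_dim, n_genes), DispAct())
+#   pi_head   = nn.Sequential(nn.Linear(hidden_dim, n_genes), nn.Sigmoid())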