--- /dev/null
+++ b/ecgxai/network/causalcnn/modules.py
@@ -0,0 +1,176 @@
+import torch
+import torch.nn as nn
+
+
+class Softplus(nn.Module):
+    """
+    Applies Softplus to the input and adds a small constant, so that the
+    output is strictly positive.
+
+    Attributes:
+        eps (float): Small positive number added for numerical stability.
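+
+    Example (values are illustrative):
+        >>> softplus = Softplus(eps=1e-6)
+        >>> bool((softplus(torch.zeros(2, 3)) > 0).all())
+        True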
+    """
+    def __init__(self, eps: float):
+        super(Softplus, self).__init__()
+        self.eps = eps
+        self.softplus = nn.Softplus()
+
+    def forward(self, x):
+        return self.softplus(x) + self.eps
+
+
+class Chomp1d(torch.nn.Module):
+    """
+    Removes the last elements of a time series.
+
+    Takes as input a three-dimensional tensor (`B`, `C`, `L`) where `B` is the
+    batch size, `C` is the number of input channels, and `L` is the length of
+    the input. Outputs a three-dimensional tensor (`B`, `C`, `L - s`) where `s`
+    is the number of elements to remove.
+
+    Attributes:
+        chomp_size (int): Number of elements to remove.
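+
+    Example (values are illustrative):
+        >>> chomp = Chomp1d(chomp_size=2)
+        >>> chomp(torch.zeros(1, 4, 10)).shape
+        torch.Size([1, 4, 8])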
+    """
+    def __init__(self, chomp_size: int):
+        super(Chomp1d, self).__init__()
+        self.chomp_size = chomp_size
+
+    def forward(self, x):
+        # Guard against a zero chomp size: slicing with :-0 would return an
+        # empty tensor instead of the unchanged input.
+        return x[:, :, :-self.chomp_size] if self.chomp_size > 0 else x
+
+
+class SqueezeChannels(torch.nn.Module):
+    """
+    Squeezes the third (length) dimension of a three-dimensional tensor,
+    turning a (`B`, `C`, 1) tensor into (`B`, `C`).
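+
+    Example (values are illustrative):
+        >>> SqueezeChannels()(torch.zeros(2, 16, 1)).shape
+        torch.Size([2, 16])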
+    """
+    def __init__(self):
+        super(SqueezeChannels, self).__init__()
+
+    def forward(self, x):
+        return x.squeeze(2)
+
+
+class CausalConvolutionBlock(torch.nn.Module):
+    """
+    Causal convolution block, composed of two causal convolutions applied in
+    sequence (each followed by a leaky ReLU activation) and a parallel
+    residual connection.
+
+    Takes as input a three-dimensional tensor (`B`, `C`, `L`) where `B` is the
+    batch size, `C` is the number of input channels, and `L` is the length of
+    the input. Outputs a three-dimensional tensor (`B`, `C_out`, `L`), where
+    `C_out` is the number of output channels.
+
+    Attributes:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        kernel_size (int): Kernel size of the applied non-residual convolutions.
+        dilation (int): Dilation of the non-residual convolutions; together
+            with the kernel size it determines the causal padding.
+        final (bool): If True, applies a final activation function after the
+            residual connection.
+        forward (bool): If True, ordinary convolutions are used; otherwise
+            transposed convolutions are used.
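+
+    Example (values are illustrative):
+        >>> block = CausalConvolutionBlock(8, 16, kernel_size=3, dilation=2)
+        >>> block(torch.zeros(4, 8, 100)).shape
+        torch.Size([4, 16, 100])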
+    """
+    def __init__(self, in_channels: int, out_channels: int, kernel_size: int,
+                 dilation: int, final=False, forward=True):
+        super(CausalConvolutionBlock, self).__init__()
+
+        Conv1d = torch.nn.Conv1d if forward else torch.nn.ConvTranspose1d
+
+        # Padding that makes the convolutions causal: ordinary convolutions
+        # are zero-padded on both sides and trimmed on the right, while
+        # transposed convolutions are left unpadded and trimmed on the right,
+        # so that in both cases the output length equals the input length.
+        padding = (kernel_size - 1) * dilation
+        conv_padding = padding if forward else 0
+
+        # First causal convolution
+        conv1 = Conv1d(
+            in_channels, out_channels, kernel_size,
+            padding=conv_padding, dilation=dilation
+        )
+        # The truncation makes the convolution causal
+        chomp1 = Chomp1d(padding)
+        relu1 = torch.nn.LeakyReLU()
+
+        # Second causal convolution
+        conv2 = Conv1d(
+            out_channels, out_channels, kernel_size,
+            padding=conv_padding, dilation=dilation
+        )
+        chomp2 = Chomp1d(padding)
+        relu2 = torch.nn.LeakyReLU()
+
+        # Causal network
+        self.causal = torch.nn.Sequential(
+            conv1, chomp1, relu1, conv2, chomp2, relu2
+        )
+
+        # Residual connection
+        self.upordownsample = Conv1d(
+            in_channels, out_channels, 1
+        ) if in_channels != out_channels else None
+
+        # Final activation function
+        self.relu = torch.nn.LeakyReLU() if final else None
+
+    def forward(self, x):
+        out_causal = self.causal(x)
+        res = x if self.upordownsample is None else self.upordownsample(x)
+        if self.relu is None:
+            return out_causal + res
+        else:
+            return self.relu(out_causal + res)
+
+
+class CausalCNN(torch.nn.Module):
+    """
+    Causal CNN, composed of a sequence of causal convolution blocks.
+
+    Takes as input a three-dimensional tensor (`B`, `C`, `L`) where `B` is the
+    batch size, `C` is the number of input channels, and `L` is the length of
+    the input. Outputs a three-dimensional tensor (`B`, `C_out`, `L`).
+
+    Attributes:
+        in_channels (int): Number of input channels.
+        channels (int): Number of channels processed inside the network.
+        depth (int): Number of stacked causal convolution blocks before the
+            final block.
+        out_channels (int): Number of output channels.
+        kernel_size (int): Kernel size of the applied non-residual convolutions.
+        forward (bool): If True, ordinary convolutions with doubling dilations
+            are used; otherwise transposed convolutions with halving dilations.
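+
+    Example (values are illustrative; 12 channels as in a 12-lead ECG):
+        >>> cnn = CausalCNN(12, 32, depth=3, out_channels=64, kernel_size=3)
+        >>> cnn(torch.zeros(2, 12, 250)).shape
+        torch.Size([2, 64, 250])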
+    """
+    def __init__(self, in_channels, channels, depth, out_channels,
+                 kernel_size, forward=True):
+        super(CausalCNN, self).__init__()
+
+        layers = []  # List of causal convolution blocks
+        # In the forward direction the dilation starts at 1 and doubles with
+        # every block; in the backward (transposed) direction it starts at
+        # 2**depth and is halved with every block.
+        dilation_size = 1 if forward else 2**depth
+
+        for i in range(depth):
+            in_channels_block = in_channels if i == 0 else channels
+            layers += [CausalConvolutionBlock(
+                in_channels_block, channels, kernel_size, dilation_size,
+                forward=forward,
+            )]
+            # double the dilation at each step if forward, otherwise
+            # halve the dilation
+            dilation_size = dilation_size * 2 if forward else dilation_size // 2
+
+        # Last layer, mapping to the requested number of output channels
+        layers += [CausalConvolutionBlock(
+            channels, out_channels, kernel_size, dilation_size,
+            forward=forward,
+        )]
+
+        self.network = torch.nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.network(x)
+
+
+class Spatial(nn.Module):
+    """
+    Pointwise (kernel size 1) convolution followed by batch normalization,
+    ReLU, and dropout, mixing information across channels without mixing
+    along the time axis.
+
+    Attributes:
+        channels (int): Number of input and output channels.
+        dropout (float): Dropout probability applied after the activation.
+        forward (bool): If True, an ordinary convolution is used; otherwise a
+            transposed convolution (functionally equivalent for kernel size 1).
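+
+    Example (values are illustrative):
+        >>> spatial = Spatial(channels=32, dropout=0.1)
+        >>> spatial(torch.zeros(2, 32, 250)).shape
+        torch.Size([2, 32, 250])
+    """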
+    def __init__(self, channels, dropout, forward=True):
+        super(Spatial, self).__init__()
+        Conv1d = nn.Conv1d if forward else nn.ConvTranspose1d
+        self.network = nn.Sequential(
+            Conv1d(channels, channels, 1),
+            nn.BatchNorm1d(num_features=channels),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x):
+        return self.network(x)