--- a
+++ b/U-Net/unet/unet_model.py
@@ -0,0 +1,48 @@
+""" Full assembly of the parts to form the complete network """
+
+from .unet_parts import *
+
+
+class UNet(nn.Module):
+    def __init__(self, n_channels, n_classes, bilinear=False):
+        super(UNet, self).__init__()
+        self.n_channels = n_channels
+        self.n_classes = n_classes
+        self.bilinear = bilinear
+
+        self.inc = (DoubleConv(n_channels, 64))
+        self.down1 = (Down(64, 128))
+        self.down2 = (Down(128, 256))
+        self.down3 = (Down(256, 512))
+        factor = 2 if bilinear else 1
+        self.down4 = (Down(512, 1024 // factor))
+        self.up1 = (Up(1024, 512 // factor, bilinear))
+        self.up2 = (Up(512, 256 // factor, bilinear))
+        self.up3 = (Up(256, 128 // factor, bilinear))
+        self.up4 = (Up(128, 64, bilinear))
+        self.outc = (OutConv(64, n_classes))
+
+    def forward(self, x):
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        x5 = self.down4(x4)
+        x = self.up1(x5, x4)
+        x = self.up2(x, x3)
+        x = self.up3(x, x2)
+        x = self.up4(x, x1)
+        logits = self.outc(x)
+        return logits
+
+    def use_checkpointing(self):
+        self.inc = torch.utils.checkpoint(self.inc)
+        self.down1 = torch.utils.checkpoint(self.down1)
+        self.down2 = torch.utils.checkpoint(self.down2)
+        self.down3 = torch.utils.checkpoint(self.down3)
+        self.down4 = torch.utils.checkpoint(self.down4)
+        self.up1 = torch.utils.checkpoint(self.up1)
+        self.up2 = torch.utils.checkpoint(self.up2)
+        self.up3 = torch.utils.checkpoint(self.up3)
+        self.up4 = torch.utils.checkpoint(self.up4)
+        self.outc = torch.utils.checkpoint(self.outc)