Diff of /models/inpainters.py [000000] .. [92cc18]

Switch to side-by-side view

--- a
+++ b/models/inpainters.py
@@ -0,0 +1,23 @@
+import torch
+import torch.nn as nn
+from segmentation_models_pytorch.deeplabv3 import DeepLabV3Plus
+
+
+class SegGenerator(nn.Module):
+    def __init__(self):
+        super(SegGenerator, self).__init__()
+        self.model = DeepLabV3Plus(in_channels=3, classes=3, activation=None, encoder_weights=None)
+
+    def forward(self, mask):
+        return self.model(mask)
+
+
+class SegDiscriminator(nn.Module):
+    def __init__(self):
+        super(SegDiscriminator, self).__init__()
+        self.model = DeepLabV3Plus(in_channels=3, classes=1, activation="sigmoid", encoder_weights=None)
+
+    def forward(self, mask):
+        # adds nosie to inputs, see https://arxiv.org/pdf/1701.04862.pdf
+        return self.model(mask + torch.normal(torch.zeros_like(mask), torch.ones_like(mask) / 10))
+        # return self.model(mask)