Diff of /models/inpainters.py [000000] .. [92cc18]

Switch to unified view

a b/models/inpainters.py
1
import torch
2
import torch.nn as nn
3
from segmentation_models_pytorch.deeplabv3 import DeepLabV3Plus
4
5
6
class SegGenerator(nn.Module):
7
    def __init__(self):
8
        super(SegGenerator, self).__init__()
9
        self.model = DeepLabV3Plus(in_channels=3, classes=3, activation=None, encoder_weights=None)
10
11
    def forward(self, mask):
12
        return self.model(mask)
13
14
15
class SegDiscriminator(nn.Module):
16
    def __init__(self):
17
        super(SegDiscriminator, self).__init__()
18
        self.model = DeepLabV3Plus(in_channels=3, classes=1, activation="sigmoid", encoder_weights=None)
19
20
    def forward(self, mask):
21
        # adds nosie to inputs, see https://arxiv.org/pdf/1701.04862.pdf
22
        return self.model(mask + torch.normal(torch.zeros_like(mask), torch.ones_like(mask) / 10))
23
        # return self.model(mask)