--- a +++ b/models/inpainters.py @@ -0,0 +1,23 @@ +import torch +import torch.nn as nn +from segmentation_models_pytorch.deeplabv3 import DeepLabV3Plus + + +class SegGenerator(nn.Module): + def __init__(self): + super(SegGenerator, self).__init__() + self.model = DeepLabV3Plus(in_channels=3, classes=3, activation=None, encoder_weights=None) + + def forward(self, mask): + return self.model(mask) + + +class SegDiscriminator(nn.Module): + def __init__(self): + super(SegDiscriminator, self).__init__() + self.model = DeepLabV3Plus(in_channels=3, classes=1, activation="sigmoid", encoder_weights=None) + + def forward(self, mask): + # adds nosie to inputs, see https://arxiv.org/pdf/1701.04862.pdf + return self.model(mask + torch.normal(torch.zeros_like(mask), torch.ones_like(mask) / 10)) + # return self.model(mask)