[92cc18]: / models / inpainters.py

Download this file

24 lines (17 with data), 814 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
from segmentation_models_pytorch.deeplabv3 import DeepLabV3Plus
class SegGenerator(nn.Module):
    """Generator network: a thin wrapper around a DeepLabV3+ model.

    The wrapped network is configured with 3 input channels and 3 output
    channels, no final activation (``activation=None``, so the output is
    raw / unbounded), and no pretrained encoder weights
    (``encoder_weights=None``).
    """

    def __init__(self):
        super().__init__()
        backbone_kwargs = {
            "in_channels": 3,
            "classes": 3,
            "activation": None,
            "encoder_weights": None,
        }
        self.model = DeepLabV3Plus(**backbone_kwargs)

    def forward(self, mask):
        """Run the wrapped DeepLabV3+ network on ``mask``.

        Args:
            mask: input tensor fed straight to the network; presumably an
                (N, 3, H, W) image batch given ``in_channels=3`` — TODO
                confirm against the caller.

        Returns:
            The network's raw 3-channel output (no activation applied).
        """
        output = self.model(mask)
        return output
class SegDiscriminator(nn.Module):
    """Discriminator network: a DeepLabV3+ model with a sigmoid head.

    The wrapped network is configured with 3 input channels, a single output
    channel squashed by a sigmoid (``activation="sigmoid"``), and no
    pretrained encoder weights (``encoder_weights=None``).

    In training mode, Gaussian "instance noise" is added to the input before
    it reaches the network (Arjovsky & Bottou, 2017,
    https://arxiv.org/abs/1701.04862). In eval mode the input is passed
    through unchanged, so inference is deterministic.

    Args:
        noise_std: standard deviation of the instance noise added in training
            mode. Defaults to 0.1 (the value previously hard-coded here);
            pass 0 to disable the noise entirely.

    Raises:
        ValueError: if ``noise_std`` is negative.
    """

    def __init__(self, noise_std=0.1):
        super().__init__()
        if noise_std < 0:
            raise ValueError(f"noise_std must be non-negative, got {noise_std}")
        self.noise_std = noise_std
        self.model = DeepLabV3Plus(
            in_channels=3,
            classes=1,
            activation="sigmoid",
            encoder_weights=None,
        )

    def forward(self, mask):
        """Score ``mask`` with the discriminator network.

        Args:
            mask: input tensor; presumably an (N, 3, H, W) image batch given
                ``in_channels=3`` — TODO confirm against the caller. It is
                never modified in place.

        Returns:
            The network's single-channel, sigmoid-activated output.
        """
        if self.training and self.noise_std > 0:
            # Instance noise is a training-time regularizer (like dropout),
            # so it is skipped under .eval(). randn_like matches the input's
            # shape/dtype/device without allocating separate mean/std tensors.
            mask = mask + self.noise_std * torch.randn_like(mask)
        return self.model(mask)