|
a |
|
b/models/inpainters.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
from segmentation_models_pytorch.deeplabv3 import DeepLabV3Plus |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
class SegGenerator(nn.Module): |
|
|
7 |
def __init__(self): |
|
|
8 |
super(SegGenerator, self).__init__() |
|
|
9 |
self.model = DeepLabV3Plus(in_channels=3, classes=3, activation=None, encoder_weights=None) |
|
|
10 |
|
|
|
11 |
def forward(self, mask): |
|
|
12 |
return self.model(mask) |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
class SegDiscriminator(nn.Module): |
|
|
16 |
def __init__(self): |
|
|
17 |
super(SegDiscriminator, self).__init__() |
|
|
18 |
self.model = DeepLabV3Plus(in_channels=3, classes=1, activation="sigmoid", encoder_weights=None) |
|
|
19 |
|
|
|
20 |
def forward(self, mask): |
|
|
21 |
# adds nosie to inputs, see https://arxiv.org/pdf/1701.04862.pdf |
|
|
22 |
return self.model(mask + torch.normal(torch.zeros_like(mask), torch.ones_like(mask) / 10)) |
|
|
23 |
# return self.model(mask) |