--- a
+++ b/inpainting/train.py
@@ -0,0 +1,143 @@
+import os
+from torch.utils.data import DataLoader
+from torchvision import transforms
+import torchvision.utils as vutils
+import torch
+import torch.nn as nn
+from data.data import InpaintingDataset, ToTensor
+from model.net import InpaintingModel_GMCNN
+from options.train_options import TrainOptions
+from util.utils import getLatest
+from data.hyperkvasir import KvasirSegmentationDataset
+import torch.nn.functional as F
+from inpainting.model.basemodel import BaseModel
+from inpainting.model.basenet import BaseNet
+from inpainting.model.loss import WGANLoss, IDMRFLoss
+from inpainting.model.layer import init_weights, PureUpsampling, ConfidenceDrivenMaskLayer, SpectralNorm
+from inpainting.model.net import *
+
+# Default hyper-parameters handed to the trainer defined below.
+# NOTE(review): the "model" entry is not read directly by the code in this file.
+config = dict(model="DeepLab", device="cuda", lr=1e-5,
+              batch_size=8, epochs=250)
+# Dataset root passed to KvasirSegmentationDataset; empty by default.
+DATASET_PATH = ""
+
+
+class gmcnn_inpainter_trainer:
+    """Adversarial trainer for a GMCNN inpainting generator.
+
+    Owns a ``GMCNN`` generator and a ``GlobalLocalDiscriminator``; for every
+    batch it takes five discriminator steps followed by one generator step.
+
+    Args:
+        config: dict with the keys ``device``, ``lr``, ``batch_size`` and
+            ``epochs``; ``lambda_rec``, ``lambda_ae``, ``lambda_mrf`` and
+            ``lambda_adv`` are optional loss-weight overrides.
+    """
+
+    def __init__(self, config):
+        self.config = config  # train() reads this instead of the module-level global
+        self.device = config["device"]
+        self.model = InpaintingModel_GMCNN(in_channels=3, opt=config)
+        self.dataset = KvasirSegmentationDataset(DATASET_PATH, augment=True)  # todo make inpainting dataset
+        self.dataloader = DataLoader(self.dataset, batch_size=config["batch_size"], shuffle=True)
+        self.confidence_mask_layer = ConfidenceDrivenMaskLayer()
+        self.model_names = ['GM']
+
+        # NOTE(review): train_epoch feeds image + mask channels to netGM; in_channels=3 may need to be 4 -- confirm.
+        self.netGM = GMCNN(3, out_channels=3, cnum=32, act=F.elu, norm=F.instance_norm).to(self.device)
+        init_weights(self.netGM)
+        fc_channels = 512 // 16 * 512 // 16 * 64 * 4
+        self.netD = GlobalLocalDiscriminator(3, cnum=64, act=F.elu, spectral_norm=True,
+                                             g_fc_channels=fc_channels, l_fc_channels=fc_channels).to(self.device)
+        init_weights(self.netD)
+
+        self.optimizer_G = torch.optim.Adam(self.netGM.parameters(), lr=config["lr"], betas=(0.5, 0.9))
+        self.optimizer_D = torch.optim.Adam(filter(lambda x: x.requires_grad, self.netD.parameters()), lr=config["lr"],
+                                            betas=(0.5, 0.9))
+
+        self.wganloss = WGANLoss()
+        self.mrfloss = IDMRFLoss()
+        self.recloss = nn.L1Loss()
+        self.aeloss = nn.L1Loss()
+        # Weights read by forward_G; defaults assumed from the reference GMCNN options -- TODO confirm.
+        self.lambda_rec = config.get("lambda_rec", 1.4)
+        self.lambda_ae = config.get("lambda_ae", 1.2)
+        self.lambda_mrf = config.get("lambda_mrf", 0.05)
+        self.lambda_adv = config.get("lambda_adv", 1e-3)
+
+        # Per-batch state, filled in by train_epoch / forward_G / forward_D.
+        self.G_loss = self.G_loss_reconstruction = self.G_loss_mrf = self.G_loss_ae = None
+        self.G_loss_adv = self.G_loss_adv_local = self.D_loss = self.D_loss_local = self.GAN_loss = None
+        self.gt = self.gt_local = self.mask = self.mask_01 = self.rect = None
+        self.im_in = self.gin = self.gen_in = self.pred = self.completed = self.completed_local = None
+        self.completed_logit = self.completed_local_logit = self.gt_logit = self.gt_local_logit = None
+
+    def train(self):
+        """Train for ``config["epochs"]`` epochs, saving ``self.model`` after each one."""
+        for epoch in range(self.config["epochs"]):
+            self.train_epoch(epoch)
+            ret_loss = self.model.get_current_losses()  # not used yet
+            self.model.save_networks(epoch + 1)
+
+    def train_epoch(self, epoch):
+        """Run one pass over the dataloader; ``epoch`` is accepted but not used."""
+        for img, mask, fname in self.dataloader:
+            img, mask = img.to(self.device), mask.to(self.device)
+            # State read by forward_G / forward_D; without it they would see None.
+            self.gt, self.mask_01 = img, mask
+            self.gt_local = self.gt  # no local crop is taken here
+            self.mask = self.confidence_mask_layer(self.mask_01)  # presumably confidence weights -- verify
+            self.im_in = self.gt * (1 - self.mask_01)
+            self.gin = self.gen_in = torch.cat((self.im_in, self.mask_01), 1)
+            self.model.setInput(img, mask)
+            self.model.optimize_parameters()
+
+            self.pred = self.netGM(self.gin)
+            self.completed = self.pred * self.mask_01 + self.gt * (1 - self.mask_01)
+            self.completed_local = self.completed
+            for i in range(5):  # train discriminator 5 times interleaved
+                self.optimizer_D.zero_grad()
+                self.optimizer_G.zero_grad()
+                self.forward_D()
+                self.backward_D()
+                self.optimizer_D.step()
+            self.optimizer_G.zero_grad()
+            self.forward_G()
+            self.backward_G()
+            self.optimizer_G.step()
+            # TODO come back here to finish gmcnn
+
+    def forward_G(self):
+        """Build ``self.G_loss`` from the L1, MRF and adversarial terms."""
+        self.G_loss_reconstruction = self.recloss(self.completed * self.mask, self.gt.detach() * self.mask)
+        self.G_loss_reconstruction = self.G_loss_reconstruction / torch.mean(self.mask_01)
+        self.G_loss_ae = self.aeloss(self.pred * (1 - self.mask_01), self.gt.detach() * (1 - self.mask_01))
+        self.G_loss_ae = self.G_loss_ae / torch.mean(1 - self.mask_01)
+        self.G_loss = self.lambda_rec * self.G_loss_reconstruction + self.lambda_ae * self.G_loss_ae
+
+        self.completed_logit, self.completed_local_logit = self.netD(self.completed, self.completed_local)
+        self.G_loss_mrf = self.mrfloss((self.completed_local + 1) / 2.0, (self.gt_local.detach() + 1) / 2.0)
+        self.G_loss = self.G_loss + self.lambda_mrf * self.G_loss_mrf
+
+        self.G_loss_adv = -self.completed_logit.mean()
+        self.G_loss_adv_local = -self.completed_local_logit.mean()
+        self.G_loss = self.G_loss + self.lambda_adv * (self.G_loss_adv + self.G_loss_adv_local)
+
+    def forward_D(self):
+        """Score ground-truth and detached completed images and set ``self.D_loss``."""
+        self.completed_logit, self.completed_local_logit = self.netD(self.completed.detach(), self.completed_local.detach())
+        self.gt_logit, self.gt_local_logit = self.netD(self.gt, self.gt_local)
+        # hinge loss
+        self.D_loss_local = nn.ReLU()(1.0 - self.gt_local_logit).mean() + nn.ReLU()(1.0 + self.completed_local_logit).mean()
+        self.D_loss = nn.ReLU()(1.0 - self.gt_logit).mean() + nn.ReLU()(1.0 + self.completed_logit).mean()
+        self.D_loss = self.D_loss + self.D_loss_local
+
+    def backward_G(self):
+        """Back-propagate ``self.G_loss``."""
+        self.G_loss.backward()
+
+    def backward_D(self):
+        """Back-propagate ``self.D_loss``, asking autograd to retain the graph."""
+        self.D_loss.backward(retain_graph=True)