In [None]:
# Colab-only setup: mount Google Drive so that the project files referenced
# below (everything under /content/drive/MyDrive/Batoul_Code/) are reachable.
from google.colab import drive
drive.mount('/content/drive')
# One-off step, kept commented out: extracts the CT image archive into the
# CT input folder on Drive. Uncomment and run once if the images are missing.
#!unzip '/content/drive/MyDrive/Batoul_Code/input/CT/images.zip' -d '/content/drive/MyDrive/Batoul_Code/input/CT'

In [None]:
import torch
import torchvision
import os
import glob
import time
import pickle
import sys
sys.path.append('/content/drive/MyDrive/Batoul_Code/')
sys.path.append('/content/drive/MyDrive/Batoul_Code/src')
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from PIL import Image
from sklearn.model_selection import train_test_split


from data import LungDataset, blend, Pad, Crop, Resize
from  OurModel import  CxlNet
from  metrics import jaccard, dice

In [None]:
# Hyper-parameters forwarded verbatim to the CxlNet constructor by
# selectModel(); image_size is also used by the resize transform later on.
in_channels = 1
out_channels = 2
batch_norm = True
upscale_mode = "bilinear"
image_size = 512


def selectModel():
    """Return a newly constructed CxlNet configured from the module-level settings.

    The settings (in_channels, out_channels, batch_norm, upscale_mode,
    image_size) are read from the notebook globals at call time, so editing
    them and calling this function again yields a model with the new values.
    """
    model_kwargs = {
        "in_channels": in_channels,
        "out_channels": out_channels,
        "batch_norm": batch_norm,
        "upscale_mode": upscale_mode,
        "image_size": image_size,
    }
    return CxlNet(**model_kwargs)

In [None]:
# --- Experiment identity ---------------------------------------------------
dataset_name = "dataset"   # sub-folder of <base_path>/input holding the data
dataset_type = "png"       # file extension used when globbing images/masks
version = "CxlNet"         # tag embedded in checkpoint and log file names
approach = "contour"

# --- Locations ---------------------------------------------------------------
# Every path is derived from base_path so that moving the project only
# requires changing this one line.
base_path = "/content/drive/MyDrive/Batoul_Code/"
split_file = base_path + "splits.pk"

# --- Model and device ---------------------------------------------------------
model = selectModel()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The original built this as Path(base_path+"input", base_path+"input/"+dataset_name).
# pathlib discards every segment that precedes an absolute one, so the first
# argument was silently ignored; the single join below yields the same path.
data_folder = Path(base_path) / "input" / dataset_name
origins_folder = data_folder / "images"
masks_folder = data_folder / "masks"
masks_contour_folder = data_folder / "masks_contour"
models_folder = Path(base_path + "models")
images_folder = Path(base_path + "images")




In [None]:
#@title
batch_size = 4
torch.cuda.empty_cache()

origins_list = [f.stem for f in origins_folder.glob(f"*.{dataset_type}")]

# Masks are read from the contour folder. To train on the plain masks instead,
# glob masks_folder here and pass masks_folder to LungDataset below.
# The list is sorted because glob order depends on the filesystem; without a
# stable order the seeded train_test_split below would not be reproducible.
masks_list = sorted(f.stem for f in masks_contour_folder.glob(f"*.{dataset_type}"))

# Pair every mask with its image: the image stem is the mask stem without "_mask".
origin_mask_list = [(mask_name.replace("_mask", ""), mask_name) for mask_name in masks_list]

# Reuse a previously saved split when one exists so that train/val/test
# membership stays fixed across runs; otherwise create and persist one.
if os.path.isfile(split_file):
    # SECURITY NOTE: pickle.load can execute arbitrary code. Only load a
    # split file that this notebook itself produced.
    with open(split_file, "rb") as f:
        splits = pickle.load(f)
else:
    splits = {}
    # 80/20 train/test, then 10% of the remaining train part becomes validation.
    splits["train"], splits["test"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)
    splits["train"], splits["val"] = train_test_split(splits["train"], test_size=0.1, random_state=42)
    with open(split_file, "wb") as f:
        pickle.dump(splits, f)

val_test_transforms = torchvision.transforms.Compose([
    Resize((image_size, image_size)),
])

# Training adds Pad/Crop on top of the resize that val/test also receive.
train_transforms = torchvision.transforms.Compose([
    Pad(200),
    Crop(300),
    val_test_transforms,
])

datasets = {
    x: LungDataset(
        splits[x],
        origins_folder,
        masks_contour_folder,
        train_transforms if x == "train" else val_test_transforms,
        dataset_type=dataset_type,
    )
    for x in ["train", "test", "val"]
}

# Only the training loader reshuffles every epoch; validation and test keep a
# fixed order so their metrics are comparable between epochs and runs.
dataloaders = {
    x: torch.utils.data.DataLoader(
        datasets[x],
        batch_size=batch_size,
        shuffle=(x == "train"),
    )
    for x in ["train", "test", "val"]
}

# Number of training batches per epoch.
print(len(dataloaders['train']))

# Sanity check: display one (image, mask) pair from the chosen split.
idx = 0
phase = "train"

origin, mask = datasets[phase][idx]
print(origin.size())
print(mask.size())

# +0.5 before converting to PIL: presumably the dataset centres pixel values
# around zero (the single-image inference code subtracts 0.5) - TODO confirm
# against LungDataset.
pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert("RGB")
pil_origin.save("1.png")

pil_mask = torchvision.transforms.functional.to_pil_image(mask.float())
pil_mask.save("2.png")

# savefig does not create missing directories, so make sure the target exists.
images_folder.mkdir(parents=True, exist_ok=True)

plt.figure(figsize=(20, 20))

plt.subplot(1, 3, 1)
plt.title("origin image")
plt.imshow(np.array(pil_origin))

plt.subplot(1, 3, 2)
plt.title("manually labeled mask")
plt.imshow(np.array(pil_mask))

plt.subplot(1, 3, 3)
plt.title("blended origin + mask")
plt.imshow(np.array(blend(origin, mask)));

# Save before show(): the figure may no longer be available after show().
plt.savefig(images_folder / "data-example.png", bbox_inches='tight')
plt.show()
# Set to False to skip training and run the test-set evaluation plus the
# qualitative visualisations with the saved checkpoint instead.
train = True
model_name = "ournet_" + version + ".pt"

# torch.save / plt.savefig do not create missing directories.
models_folder.mkdir(parents=True, exist_ok=True)
images_folder.mkdir(parents=True, exist_ok=True)


def _evaluate(net, loader, num_items, device):
    """Run `net` over `loader` without gradients and return averaged metrics.

    Args:
        net: model already moved to `device` and switched to eval mode.
        loader: iterable yielding (origins, masks) batches.
        num_items: number of samples the sums are divided by (dataset length).
        device: torch device the batches are moved to.

    Returns:
        Tuple (loss, jaccard, dice), each a sample-weighted average.

    Prints one "." per batch followed by a newline as a progress indicator.
    """
    total_loss = 0.0
    total_jaccard = 0.0
    total_dice = 0.0

    for origins, masks in loader:
        num = origins.size(0)
        origins = origins.to(device)
        masks = masks.to(device)

        with torch.no_grad():
            log_probs = torch.nn.functional.log_softmax(net(origins), dim=1)
            # Weight by batch size so a smaller last batch is averaged correctly.
            total_loss += torch.nn.functional.nll_loss(log_probs, masks).item() * num

            preds = torch.argmax(log_probs, dim=1).float()
            masks = masks.float()
            total_jaccard += jaccard(masks, preds).item() * num
            total_dice += dice(masks, preds).item() * num

        print(".", end="")
    print()

    return total_loss / num_items, total_jaccard / num_items, total_dice / num_items


if train:
    # Resume from an existing checkpoint when there is one.
    if os.path.isfile(models_folder / model_name):
        model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device("cpu")))
        print("load_state_dict")

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

    train_log_filename = base_path + "train-log-" + version + ".txt"
    epochs = 50
    # NOTE(review): when resuming, best_val_loss still starts at infinity, so
    # the first epoch always overwrites the loaded checkpoint even if it is
    # worse. Evaluate the loaded model first if that matters.
    best_val_loss = np.inf

    hist = []

    for e in range(epochs):
        start_t = time.time()

        print("Epoch "+str(e))
        model.train()

        train_loss = 0.0

        for origins, masks in dataloaders["train"]:
            num = origins.size(0)

            origins = origins.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outs = model(origins)
            # log_softmax + nll_loss over the class dimension.
            softmax = torch.nn.functional.log_softmax(outs, dim=1)
            loss = torch.nn.functional.nll_loss(softmax, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * num
            print(".", end="")

        train_loss = train_loss / len(datasets['train'])
        print()

        print("validation phase")
        model.eval()
        val_loss, val_jaccard, val_dice = _evaluate(
            model, dataloaders["val"], len(datasets["val"]), device
        )

        end_t = time.time()
        spended_t = end_t - start_t

        hist.append({
            "time": spended_t,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_jaccard": val_jaccard,
            "val_dice": val_dice,
        })

        report = f"epoch: {e + 1}/{epochs}, time: {spended_t}, train loss: {train_loss}, \n" \
                 + f"val loss: {val_loss}, val jaccard: {val_jaccard}, val dice: {val_dice}"

        # The log is opened in append mode so earlier runs are preserved.
        with open(train_log_filename, "a") as train_log_file:
            print(report)
            train_log_file.write(report + "\n")

            # Keep only the checkpoint with the lowest validation loss.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), models_folder / model_name)
                print("model saved")
                train_log_file.write("model saved\n")
            print()

    # Plot the per-epoch curves collected in hist.
    plt.figure(figsize=(15, 7))
    train_loss_hist = [h["train_loss"] for h in hist]
    plt.plot(range(len(hist)), train_loss_hist, "b", label="train loss")

    val_loss_hist = [h["val_loss"] for h in hist]
    plt.plot(range(len(hist)), val_loss_hist, "r", label="val loss")

    val_dice_hist = [h["val_dice"] for h in hist]
    plt.plot(range(len(hist)), val_dice_hist, "g", label="val dice")

    val_jaccard_hist = [h["val_jaccard"] for h in hist]
    plt.plot(range(len(hist)), val_jaccard_hist, "y", label="val jaccard")

    plt.legend()
    plt.xlabel("epoch")
    plt.savefig(images_folder / model_name.replace(".pt", "-train-hist.png"))

    time_hist = [h["time"] for h in hist]
    overall_time = sum(time_hist) // 60
    mean_epoch_time = sum(time_hist) / len(hist)
    print(f"epochs: {len(hist)}, overall time: {overall_time}m, mean epoch time: {mean_epoch_time}s")

    torch.cuda.empty_cache()
else:
    # ---- Quantitative evaluation on the test split ----
    # The checkpoint is loaded once and reused by every step below.
    model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device("cpu")))
    model.to(device)
    model.eval()

    test_loss, test_jaccard, test_dice = _evaluate(
        model, dataloaders["test"], len(datasets["test"]), device
    )

    print(f"avg test loss: {test_loss}")
    print(f"avg test jaccard: {test_jaccard}")
    print(f"avg test dice: {test_dice}")

    # ---- Qualitative results: random test samples in a grid ----
    num_samples = 9
    phase = "test"
    num_cols = 3
    # Ceiling division: exactly as many rows as needed, with no empty trailing
    # row when num_samples is a multiple of num_cols.
    num_rows = (num_samples + num_cols - 1) // num_cols

    # Indices are drawn with replacement, so a sample may appear twice.
    subset = torch.utils.data.Subset(
        datasets[phase],
        np.random.randint(0, len(datasets[phase]), num_samples)
    )
    random_samples_loader = torch.utils.data.DataLoader(subset, batch_size=1)
    plt.figure(figsize=(20, 25))

    for idx, (origin, mask) in enumerate(random_samples_loader):
        plt.subplot(num_rows, num_cols, idx + 1)

        origin = origin.to(device)
        mask = mask.to(device)

        with torch.no_grad():
            out = model(origin)
            softmax = torch.nn.functional.log_softmax(out, dim=1)
            out = torch.argmax(softmax, dim=1)

            jaccard_score = jaccard(mask.float(), out.float()).item()
            dice_score = dice(mask.float(), out.float()).item()

            # Drop the batch dimension and move to CPU for plotting.
            origin = origin[0].to("cpu")
            out = out[0].to("cpu")
            mask = mask[0].to("cpu")

            plt.imshow(np.array(blend(origin, mask, out)))
            plt.title(f"jaccard: {jaccard_score:.4f}, dice: {dice_score:.4f}")
            print(".", end="")
    print()

    # Save and show once, after every subplot has been drawn. Calling show()
    # inside the loop split the grid into separate figures, and saving after
    # show() wrote a blank image.
    plt.savefig(images_folder / "obtained-results.png", bbox_inches='tight')
    plt.show()
    print("red area - predict")
    print("green area - ground truth")
    print("yellow area - intersection")

    # ---- Single-image inference demo ----
    # Anchored to base_path: the notebook's working directory is not the
    # project root, so the former cwd-relative path would not resolve.
    origin_filename = Path(base_path) / "input/dataset/images/CHNCXR_0042_0.png"

    origin = Image.open(origin_filename).convert("P")
    # Resize to the resolution the val/test transform uses. The model is
    # constructed with image_size, so it presumably expects this input size -
    # TODO confirm against CxlNet (the previous value here was 200x200).
    origin = torchvision.transforms.functional.resize(origin, (image_size, image_size))
    origin = torchvision.transforms.functional.to_tensor(origin) - 0.5

    with torch.no_grad():
        origin = torch.stack([origin])  # add the batch dimension
        origin = origin.to(device)
        out = model(origin)
        softmax = torch.nn.functional.log_softmax(out, dim=1)
        out = torch.argmax(softmax, dim=1)

        origin = origin[0].to("cpu")
        out = out[0].to("cpu")

    plt.figure(figsize=(20, 10))

    # Undo the -0.5 shift applied above before converting back to an image.
    pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert("RGB")

    plt.subplot(1, 2, 1)
    plt.title("origin image")
    plt.imshow(np.array(pil_origin))

    plt.subplot(1, 2, 2)
    plt.title("blended origin + predict")
    plt.imshow(np.array(blend(origin, out)))
    plt.show()
