In [None]:
import torch
import torchvision
import os
import glob
import time
import pickle
import sys
# The project sources live on the mounted Drive; extend sys.path before the
# project-local imports further down so they can be resolved.
sys.path.append('/content/drive/MyDrive/Batoul_Code/')
sys.path.append('/content/drive/MyDrive/Batoul_Code/src')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Fixed: pathlib exports `Path` (the previous `Patha` raised ImportError and
# left every later `Path(...)` call undefined).
from pathlib import Path
from PIL import Image
from sklearn.model_selection import train_test_split

from data import LungDataset, blend, Pad, Crop, Resize
# NOTE(review): this second import rebinds blend/Pad/Crop/Resize, so the data2
# versions are the ones actually used below — confirm that is intended.
from data2 import LungDataset2, blend, Pad, Crop, Resize

from  OurModel import CxlNet

from  metrics import jaccard, dice,get_accuracy, get_sensitivity, get_specificity

In [None]:
# Hyper-parameters shared by every network built through selectModel().
in_channels=1
out_channels=2
batch_norm=True
upscale_mode="bilinear"
image_size=512
def selectModel(name=None):
    """Build a new CxlNet configured from the notebook-level hyper-parameters.

    Args:
        name: Optional architecture/version label. The comparison cell calls
            ``selectModel(models[m])``; accepting the argument keeps that call
            from raising ``TypeError``. Only ``CxlNet`` is imported in this
            notebook, so the label does not change which class is built.
            NOTE(review): dispatch on ``name`` here once the other
            architectures are importable.

    Returns:
        A freshly constructed ``CxlNet`` instance (no weights loaded here).
    """
    return CxlNet(
            in_channels=in_channels,
            out_channels=out_channels,
            batch_norm=batch_norm,
            upscale_mode=upscale_mode,
            image_size=image_size)

In [None]:
# ---- Run configuration ----
# File extension used by each supported dataset.
dataset_types = {"dataset": "png", "CT": "jpg"}
dataset_name = "dataset"
dataset_type = dataset_types[dataset_name]
print(dataset_type)

image_size = 512
version = "UNet"
approach = "contour"

# Pickle caches for the train/val/test split and for the image/mask name lists.
split_file = "/content/drive/MyDrive/Batoul_Code/splits.pk"
list_data_file = "/content/drive/MyDrive/Batoul_Code/list_data.pk"

model = selectModel()

base_path = "/content/drive/MyDrive/Batoul_Code/"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Joining two absolute segments keeps only the last one, so the dataset
# directory is simply <base_path>input/<dataset_name>.
data_folder = Path(base_path + "input/" + dataset_name)
origins_folder = data_folder / "images"
masks_contour_folder = data_folder / "masks_contour"
# The contour masks are the ones used as targets in this notebook.
masks_folder = masks_contour_folder
models_folder = Path(base_path + "models")
images_folder = Path(base_path + "images")


In [None]:
# Load the checkpoint that matches `version` and put the network in inference mode.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
models_folder = Path(base_path + "models")
model_name = f"ournet_{version}.pt"
print(model_name)
# Weights are first mapped to the CPU, then the whole module is moved to `device`.
state_dict = torch.load(models_folder / model_name, map_location=torch.device("cpu"))
model.load_state_dict(state_dict)
model.to(device)
model.eval()


# Running totals filled in by the test-set evaluation loop.
test_loss = 0.0
test_jaccard = 0.0
test_dice = 0.0
test_accuracy = 0.0
test_sensitivity = 0.0
test_specificity = 0.0
batch_size = 4

# Image / mask file stems: read from the pickle cache when present, otherwise
# scan the folders and write the cache.
if os.path.isfile(list_data_file):
    with open(list_data_file, "rb") as cache_file:
        list_data = pickle.load(cache_file)
        origins_list, masks_list = list_data[0], list_data[1]
else:
    origins_list = [path.stem for path in origins_folder.glob(f"*.{dataset_type}")]
    masks_list = [path.stem for path in masks_folder.glob(f"*.{dataset_type}")]
    with open(list_data_file, "wb") as cache_file:
        pickle.dump([origins_list, masks_list], cache_file)

# Pair every mask stem with the image stem obtained by dropping "_mask".
origin_mask_list = [(stem.replace("_mask", ""), stem) for stem in masks_list]

# Train/val/test split: reuse the pickled one if present, else create and store it.
if os.path.isfile(split_file):
    with open(split_file, "rb") as cache_file:
        splits = pickle.load(cache_file)
else:
    splits = {}
    splits["train"], splits["test"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)
    splits["train"], splits["val"] = train_test_split(splits["train"], test_size=0.1, random_state=42)
    with open(split_file, "wb") as cache_file:
        pickle.dump(splits, cache_file)

val_test_transforms = torchvision.transforms.Compose([
    Resize((image_size, image_size)),
])
# The training pipeline is the same for both dataset kinds in this cell.
train_transforms = torchvision.transforms.Compose([
    Pad(200),
    Crop(300),
    val_test_transforms,
])

# Only the dataset class depends on which dataset is selected.
dataset_cls = LungDataset if dataset_name == "dataset" else LungDataset2
datasets = {}
for split_name in ["train", "test", "val"]:
    split_transforms = train_transforms if split_name == "train" else val_test_transforms
    datasets[split_name] = dataset_cls(
        splits[split_name],
        origins_folder,
        masks_folder,
        split_transforms,
        dataset_type=dataset_type,
    )

num_samples = 9
phase = "test"
print(len(datasets[phase]))

dataloaders = {
    split_name: torch.utils.data.DataLoader(datasets[split_name], batch_size=batch_size)
    for split_name in ["train", "test", "val"]
}

# Evaluate on the test split. Each batch contributes metric * batch_size to the
# running totals, which are divided by the dataset size afterwards.
# (Assumes the metric helpers return a per-batch mean — TODO confirm in metrics.py.)
for origins, masks in dataloaders["test"]:
    num = origins.size(0)  # number of samples in this batch

    origins = origins.to(device)
    masks = masks.to(device)

    with torch.no_grad():
        outs = model(origins)
        # log_softmax followed by nll_loss is the cross-entropy of the raw outputs.
        log_probs = torch.nn.functional.log_softmax(outs, dim=1)
        test_loss += torch.nn.functional.nll_loss(log_probs, masks).item() * num
        # Predicted class per pixel; cast to float before calling the metric helpers.
        outs = torch.argmax(log_probs, dim=1).float()
        masks = masks.float()
        test_jaccard += jaccard(masks, outs).item() * num
        test_dice += dice(masks, outs).item() * num
        test_accuracy += get_accuracy(masks, outs) * num
        test_sensitivity += get_sensitivity(masks, outs) * num
        test_specificity += get_specificity(masks, outs) * num
    print(".", end="")

num_test_samples = len(datasets["test"])
test_loss = test_loss / num_test_samples
test_jaccard = test_jaccard / num_test_samples
test_dice = test_dice / num_test_samples
test_accuracy = test_accuracy / num_test_samples
# Fixed: sensitivity and specificity were previously reported as raw sums
# because they were never divided by the number of test samples.
test_sensitivity = test_sensitivity / num_test_samples
test_specificity = test_specificity / num_test_samples
print()
print(f"avg test loss: {test_loss}")
print(f"avg test jaccard: {test_jaccard}")
print(f"avg test dice: {test_dice}")
print(f"avg test accuracy: {test_accuracy}")
print(f"avg test sensitivity: {test_sensitivity}")
print(f"avg test specificity: {test_specificity}")



# Visualise `num_samples` randomly chosen items of the selected phase
# (indices are drawn with replacement by np.random.randint).
subset = torch.utils.data.Subset(
    datasets[phase],
    np.random.randint(0, len(datasets[phase]), num_samples)
)
# Fixed: one image per batch. With batch_size=2 only the first image of each
# pair was drawn (5 panels for 9 samples) while the title showed scores
# computed over both images of the pair.
random_samples_loader = torch.utils.data.DataLoader(subset, batch_size=1)
plt.figure(figsize=(20, 25))

for idx, (origin, mask) in enumerate(random_samples_loader):
    plt.subplot((num_samples // 3) + 1, 3, idx + 1)

    origin = origin.to(device)
    mask = mask.to(device)

    with torch.no_grad():
        out = model(origin)
        log_probs = torch.nn.functional.log_softmax(out, dim=1)
        out = torch.argmax(log_probs, dim=1)

        jaccard_score = jaccard(mask.float(), out.float()).item()
        dice_score = dice(mask.float(), out.float()).item()

        # Drop the batch dimension and move back to the CPU for plotting.
        origin = origin[0].to("cpu")
        out = out[0].to("cpu")
        mask = mask[0].to("cpu")
        # NOTE(review): the prediction is passed for both overlay arguments, so
        # the ground-truth mask is not drawn even though the legend printed
        # below mentions it. Use blend(origin, mask, out) to overlay both —
        # confirm which is intended.
        plt.imshow(np.array(blend(origin, out, out)))
        plt.title(f"jaccard: {jaccard_score:.4f}, dice: {dice_score:.4f}")
        print(".", end="")

plt.savefig(images_folder / "obtained-results.png", bbox_inches='tight')
plt.show()
print()
print("red area - predict")
print("green area - ground truth")
print("yellow area - intersection")


# ---- Single-image inference demo ----
# The checkpoint is already loaded earlier in this cell; just make sure the
# network is in inference mode (the redundant second load was removed).
model.eval()

# Fixed: pick a sample that belongs to the selected dataset. The CT file name
# was previously used unconditionally, even with dataset_name == "dataset"
# (PNG chest X-rays). Names match the first samples of the comparison cell.
if dataset_name != "dataset":
    sample_name = "ID00015637202177877247924_110.jpg"
else:
    sample_name = "CHNCXR_0060_0.png"
origin_filename = base_path + f"input/{dataset_name}/images/" + sample_name
# To try an external image instead:
#origin_filename=base_path + "external_samples/1.jpg"

# Same preprocessing as in the comparison cell: resize, to tensor, shift by -0.5.
origin = Image.open(origin_filename).convert("P")
origin = torchvision.transforms.functional.resize(origin, (image_size, image_size))
origin = torchvision.transforms.functional.to_tensor(origin) - 0.5

with torch.no_grad():
    origin = torch.stack([origin])  # add the batch dimension
    origin = origin.to(device)
    out = model(origin)
    softmax = torch.nn.functional.log_softmax(out, dim=1)
    out = torch.argmax(softmax, dim=1)

    origin = origin[0].to("cpu")
    out = out[0].to("cpu")


plt.figure(figsize=(20, 10))

# Undo the -0.5 shift before converting back to a PIL image.
pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert("RGB")

plt.subplot(1, 2, 1)
plt.title("origin image")
plt.imshow(np.array(pil_origin))
# Fixed: no plt.show() between the two panels — it finalised the figure after
# the first subplot, so the second one ended up in a separate figure.
plt.subplot(1, 2, 2)
plt.title("blended origin + predict")
plt.imshow(np.array(blend(origin, out)))
plt.show()


In [None]:

# Same device / checkpoint-folder setup as in the evaluation cell.
models_folder = Path(base_path + "models")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

test_loss, test_jaccard, test_dice = 0.0, 0.0, 0.0

batch_size = 4

# Image / mask stems come from the pickle cache when it exists; otherwise the
# folders are scanned and the cache is written.
if os.path.isfile(list_data_file):
    with open(list_data_file, "rb") as cache_file:
        list_data = pickle.load(cache_file)
        origins_list = list_data[0]
        masks_list = list_data[1]
else:
    origins_list = [entry.stem for entry in origins_folder.glob(f"*.{dataset_type}")]
    masks_list = [entry.stem for entry in masks_folder.glob(f"*.{dataset_type}")]
    with open(list_data_file, "wb") as cache_file:
        pickle.dump([origins_list, masks_list], cache_file)

# (image stem, mask stem) pairs; the image stem is the mask stem without "_mask".
origin_mask_list = [(stem.replace("_mask", ""), stem) for stem in masks_list]

# Reuse the stored split when available, else create it and store it.
if os.path.isfile(split_file):
    with open(split_file, "rb") as cache_file:
        splits = pickle.load(cache_file)
else:
    splits = {}
    splits["train"], splits["test"] = train_test_split(origin_mask_list, test_size=0.2, random_state=42)
    splits["train"], splits["val"] = train_test_split(splits["train"], test_size=0.1, random_state=42)
    with open(split_file, "wb") as cache_file:
        pickle.dump(splits, cache_file)

val_test_transforms = torchvision.transforms.Compose([
    Resize((image_size, image_size)),
])

if dataset_name == "dataset":
    # Pad, crop, then resize for the training split.
    train_transforms = torchvision.transforms.Compose([
        Pad(200),
        Crop(300),
        val_test_transforms,
    ])
    dataset_cls = LungDataset
else:
    # No training-time transforms are applied for the other dataset in this cell.
    train_transforms = torchvision.transforms.Compose([])
    dataset_cls = LungDataset2

datasets = {
    split_name: dataset_cls(
        splits[split_name],
        origins_folder,
        masks_folder,
        train_transforms if split_name == "train" else val_test_transforms,
        dataset_type=dataset_type,
    )
    for split_name in ["train", "test", "val"]
}


def mask_to_class_rgb1(mask):
  """Turn a channels-last colour mask into a binary (H, W) float map.

  A pixel is set to 1 where channel index 2 is >= 200 and to 0 elsewhere.
  `mask` is anything `np.array` can convert; after squeezing it must be
  three-dimensional with the channel axis last.
  """
  threshold = 200
  pixels = torch.squeeze(torch.from_numpy(np.array(mask)))  # drop size-1 dims
  channels_first = pixels.permute(2, 0, 1).contiguous()
  height, width = channels_first.shape[1], channels_first.shape[2]
  # Only the third channel decides the output.
  selected = channels_first[2]
  mask_out = torch.zeros((height, width))
  mask_out[selected >= threshold] = 1
  return mask_out


def mask_to_class_rgb(mask, threshold=200, channel=2):
  """Convert a channels-last colour mask into a binary (H, W) float map.

  Args:
    mask: Anything `np.array` can convert (e.g. a PIL image or an ndarray).
      After squeezing size-1 dimensions it must be three-dimensional with the
      channel axis last.
    threshold: Minimum value of the selected channel for a pixel to be
      marked 1. Defaults to 200, the value that was previously hard-coded.
    channel: Index of the channel that is inspected. Defaults to 2, the
      channel that was previously hard-coded.

  Returns:
    A float tensor of shape (H, W) holding 1 where the selected channel is
    >= `threshold` and 0 elsewhere.
  """
  pixels = torch.squeeze(torch.from_numpy(np.array(mask)))  # drop size-1 dims
  channels_first = pixels.permute(2, 0, 1).contiguous()
  selected = channels_first[channel]
  # The former pass that zeroed values below the threshold was removed: the
  # output only depends on the >= comparison, so it had no effect.
  mask_out = torch.zeros(selected.shape)
  mask_out[selected >= threshold] = 1
  return mask_out

def getitem2(path):
  """Open the colour-coded mask at `path` and return it as an int64 class map.

  The conversion is delegated to mask_to_class_rgb; the image is not resized
  here. NOTE(review): getitem1 resizes to image_size — confirm these masks
  already have the expected size.
  """
  rgb_mask = Image.open(path)
  return mask_to_class_rgb(rgb_mask).long()

def getitem1(path):
  """Open the mask at `path`, resize it to image_size x image_size and binarise it.

  Returns an int64 tensor with 1 where the pixel value is > 128, else 0.
  """
  resized = Image.open(path).resize((image_size, image_size))
  pixels = torch.tensor(np.array(resized))
  return (pixels > 128).long()


# ---- Qualitative comparison grid ----
# One row per sample image: original, ground-truth overlay, then one
# prediction overlay per checkpoint listed in `models`.
idx=1  # 1-based position of the next subplot in the grid
phase = "test"
fig = plt.figure(figsize=(20, 10))
if dataset_name!="dataset":
  samples=["ID00015637202177877247924_110.jpg",
           "ID00009637202177434476278_173.jpg",
           "ID00009637202177434476278_316.jpg",
           "ID00009637202177434476278_204.jpg",]
  masks = [mask_name.replace("_", "_mask_").replace("images", "masks") for mask_name in samples]

else:
  samples=["CHNCXR_0060_0.png",
           "CHNCXR_0074_0.png",
           "CHNCXR_0129_0.png",
           "CHNCXR_0167_0.png",]
  masks = [mask_name.replace("_0.png", "_0_mask.png").replace("images", "masks") for mask_name in samples]


samples=[base_path + f"input/{dataset_name}/images/"+ sample_name for sample_name in samples]
masks=[base_path + f"input/{dataset_name}/masks/"+ mask_name for mask_name in masks]
models=["ResNetDUCHDC","OueNetNew3","NestedUNet","ResNetDUC","FCN_GCN","SegNet2","UNet"]

n_rows = len(samples)
n_cols = len(models) + 2  # two extra columns: original image and ground truth
# Set once for the whole figure instead of on every loop iteration.
fig.subplots_adjust(hspace=0)

# The loop variable used to be called `input`, which shadowed the builtin.
for sample_idx in range(len(samples)):
  # Image and mask depend only on the sample, so load them once per row
  # rather than once per model.
  origin = Image.open(samples[sample_idx]).convert("P")
  origin = torchvision.transforms.functional.resize(origin, (image_size, image_size))
  origin = torchvision.transforms.functional.to_tensor(origin) - 0.5
  if dataset_name!="dataset":
    mask = getitem2(masks[sample_idx])
  else:
    mask = getitem1(masks[sample_idx])
  # Undo the -0.5 shift before converting back to a PIL image.
  pil_origin = torchvision.transforms.functional.to_pil_image(origin + 0.5).convert("RGB")
  batch = torch.stack([origin]).to(device)  # add the batch dimension

  # Column 1: original image.
  ax = fig.add_subplot(n_rows, n_cols, idx)
  ax.set_axis_off()
  ax.imshow(np.array(pil_origin))
  idx = idx + 1
  # Column 2: ground-truth mask blended over the image.
  ax = fig.add_subplot(n_rows, n_cols, idx)
  ax.set_axis_off()
  ax.imshow(np.array(blend(origin, mask, amount=0.4)))
  idx = idx + 1

  for m in range(len(models)):
    version = models[m]
    if dataset_name!="dataset":
      version = version + "_" + dataset_name
    # NOTE(review): this call passes the architecture name; confirm that
    # selectModel accepts it and builds the matching network for every entry
    # in `models`, otherwise this line or load_state_dict below will fail.
    model = selectModel(models[m])
    model_name = "ournet_" + version + ".pt"
    model.load_state_dict(torch.load(models_folder / model_name, map_location=torch.device("cpu")))
    model.to(device)
    # Fixed: switch to inference mode. A freshly built module starts in
    # training mode, so BatchNorm layers would normalise with the statistics
    # of this single image instead of the stored running statistics.
    model.eval()
    with torch.no_grad():
      out = model(batch)
      softmax = torch.nn.functional.log_softmax(out, dim=1)
      out = torch.argmax(softmax, dim=1)[0].to("cpu")

    # Remaining columns: prediction of each model blended over the image.
    ax = fig.add_subplot(n_rows, n_cols, idx)
    ax.set_axis_off()
    ax.imshow(np.array(blend(origin, out, amount=0.5)))
    idx = idx + 1

plt.show()
