[e50482]: / predict_2d.py

Download this file

156 lines (137 with data), 5.8 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from data_2d import *
from train_2d import *
torch.manual_seed(42)
"""-----------------------Arguments-----------------------"""
parser = argparse.ArgumentParser(description="Prediction")
parser.add_argument("--batch_size_test", type=str, default=1)
args = parser.parse_args()
test_batch_size = args.batch_size_test
# Softmax
soft = torch.nn.Softmax(dim=1).cuda()
# IoU
IOU_metric = IoU(num_classes=4, absent_score=-1., reduction="none").cuda()
# F1 score
f1_metric = F1(num_classes=4, mdmc_average="samplewise", average='none').cuda()
"""---------Post Processing---------"""
keep_largest = monai.transforms.KeepLargestConnectedComponent(applied_labels=[0, 1, 2, 3], independent=True)
fill_holes = monai.transforms.FillHoles(applied_labels=[0])
"""---------Test Data---------"""
test_transform = mt.Compose([
mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False)
])
# Test dataset
test_data = DataLoader(test_loader_ACDC(transform=test_transform, test_index=None), batch_size=1, shuffle=False)
# pad the images so that they are divisible by 16
def Pad_images(image):
orig_shape = list(image.size())
original_x = orig_shape[2]
original_y = orig_shape[3]
new_x = (16 - (original_x % 16)) + original_x
new_y = (16 - (original_y % 16)) + original_y
new_shape = [new_x, new_y]
b, c, h, w = image.shape
m = image.min()
x_max = new_shape[0]
y_max = new_shape[1]
result = torch.Tensor(b, c, x_max, y_max).fill_(m)
xx = (x_max - h) // 2
yy = (y_max - w) // 2
result[:, :, xx:xx + h, yy:yy + w] = image
return result, tuple([xx, yy]) # result is a torch tensor in CPU --> have to move to GPU
# pass the padded image, the indices and the original shape
def UnPad_imges(image, indices, org_shape):
b, c, h, w = org_shape
xx = indices[0]
yy = indices[1]
return image[:, :, xx:xx + h, yy:yy + w] # image is a torch tensor --> have to move to GPU
# save the predictions and ground truths
def save_pred(img, mask, pred, outpath, name_model, idx):
# Folder to save the results
if not os.path.exists(os.path.join(outpath, name_model)):
os.makedirs(os.path.join(outpath, name_model))
out_save_path_image = os.path.join(outpath, name_model, f"{idx}_image" + "." + 'png')
out_save_path_pred = os.path.join(outpath, name_model, f"{idx}_pred" + "." + 'png')
out_save_path_mask = os.path.join(outpath, name_model, f"{idx}_gt" + "." + 'png')
# Save images
img = img.squeeze()
img = np.array(img.cpu())
image_file_name = str(idx) + "_image"
plt.title = image_file_name
plt.imsave(out_save_path_image, img, format='png', cmap='gray')
# Save ground truths
mask = mask.squeeze()
mask = np.array(mask.cpu())
image_file_name = str(idx) + "_gt"
plt.title = image_file_name
plt.imsave(out_save_path_mask, mask, format='png', cmap='gray')
# Save predictions
# Post Processing
final_prediction = torch.argmax(pred, dim=1)
final_prediction = keep_largest(final_prediction)
final_prediction = fill_holes(final_prediction)
# final_prediction = torch.argmax(final_prediction, dim=1)
final_pred = np.array(final_prediction.cpu().squeeze())
image_file_name = str(idx) + "_pred"
plt.title = image_file_name
plt.imsave(out_save_path_pred, final_pred, format='png', cmap='gray')
plt.close()
def save_results(iou, dice, out_path, model_name):
# Folder to store the plots
if not os.path.exists(os.path.join(out_path, model_name)):
os.makedirs(os.path.join(out_path, model_name))
out_save_path = os.path.join(os.path.join(out_path, model_name))
# IoU
sorted_iou = sorted(iou)
print("Top IoU:", sorted_iou[-1])
# Dice Scores
sorted_dice = sorted(dice)
print("Top Dice Score:", sorted_dice[-1])
# save results
result_dict = {"IoU": sorted_iou[-1],
"Dice Score": sorted_dice[-1]}
file_name = 'results.txt'
completeName = os.path.join(out_save_path, file_name)
with open(completeName, 'w') as file:
file.write(str(result_dict))
def test_results(model, out_path, model_name):
all_iou = []
all_dice = []
indices = 0
for items in test_data:
image = items["image"].cuda()
image_shape = image.shape
mask = items["mask"].long().cuda().squeeze(dim=1)
# pad the image
image, ind = Pad_images(image)
pred = model(image.float().cuda())
# unpad the images
pred = UnPad_imges(pred, ind, image_shape).cuda()
pred = soft(pred)
# pred = torch.argmax(pred, dim=1)
# Save results
save_pred(image, mask, pred, out_path, model_name, indices)
# calculate iou
iou_all_class = IOU_metric(pred, mask)
iou_all_class = iou_all_class.cpu().numpy()
iou_all_class = iou_all_class[iou_all_class != -1.]
iou = iou_all_class.mean()
all_iou.append(iou)
# calculate dice score
dice_all_class = f1_metric(pred, mask)
dice_all_class = dice_all_class.cpu().numpy()
dice_all_class = dice_all_class[~np.isnan(dice_all_class)]
dice = dice_all_class.mean()
all_dice.append(dice)
indices = indices + 1
# Save plots
save_results(all_iou, all_dice, out_path, model_name)
if __name__ == "__main__":
name = str("UNet2D_Attention_Best_0.3_Fold_1")
model_path = str(Path(r"../unet/cluster_results/best_models", name + ".pt"))
if not os.path.exists(r'../unet/test_results_1/'):
os.makedirs(r'../unet/test_results_1/')
result_path = r'../unet/test_results_1/'
model = torch.load(model_path)
with torch.no_grad():
model.eval().cuda()
test_results(model, result_path, name)