--- a +++ b/predict_2d.py @@ -0,0 +1,155 @@ +from data_2d import * +from train_2d import * +torch.manual_seed(42) + +"""-----------------------Arguments-----------------------""" +parser = argparse.ArgumentParser(description="Prediction") +parser.add_argument("--batch_size_test", type=str, default=1) + +args = parser.parse_args() +test_batch_size = args.batch_size_test + +# Softmax +soft = torch.nn.Softmax(dim=1).cuda() +# IoU +IOU_metric = IoU(num_classes=4, absent_score=-1., reduction="none").cuda() +# F1 score +f1_metric = F1(num_classes=4, mdmc_average="samplewise", average='none').cuda() + +"""---------Post Processing---------""" +keep_largest = monai.transforms.KeepLargestConnectedComponent(applied_labels=[0, 1, 2, 3], independent=True) +fill_holes = monai.transforms.FillHoles(applied_labels=[0]) + +"""---------Test Data---------""" +test_transform = mt.Compose([ + mt.ToTensorD(keys=["image", "mask"], allow_missing_keys=False) +]) +# Test dataset +test_data = DataLoader(test_loader_ACDC(transform=test_transform, test_index=None), batch_size=1, shuffle=False) + + +# pad the images so that they are divisible by 16 +def Pad_images(image): + orig_shape = list(image.size()) + original_x = orig_shape[2] + original_y = orig_shape[3] + new_x = (16 - (original_x % 16)) + original_x + new_y = (16 - (original_y % 16)) + original_y + new_shape = [new_x, new_y] + b, c, h, w = image.shape + m = image.min() + x_max = new_shape[0] + y_max = new_shape[1] + result = torch.Tensor(b, c, x_max, y_max).fill_(m) + xx = (x_max - h) // 2 + yy = (y_max - w) // 2 + result[:, :, xx:xx + h, yy:yy + w] = image + return result, tuple([xx, yy]) # result is a torch tensor in CPU --> have to move to GPU + + +# pass the padded image, the indices and the original shape +def UnPad_imges(image, indices, org_shape): + b, c, h, w = org_shape + xx = indices[0] + yy = indices[1] + return image[:, :, xx:xx + h, yy:yy + w] # image is a torch tensor --> have to move to GPU + + +# save the predictions and ground truths +def save_pred(img, mask, pred, outpath, name_model, idx): + # Folder to save the results + if not os.path.exists(os.path.join(outpath, name_model)): + os.makedirs(os.path.join(outpath, name_model)) + out_save_path_image = os.path.join(outpath, name_model, f"{idx}_image" + "." + 'png') + out_save_path_pred = os.path.join(outpath, name_model, f"{idx}_pred" + "." + 'png') + out_save_path_mask = os.path.join(outpath, name_model, f"{idx}_gt" + "." + 'png') + # Save images + img = img.squeeze() + img = np.array(img.cpu()) + image_file_name = str(idx) + "_image" + plt.title = image_file_name + plt.imsave(out_save_path_image, img, format='png', cmap='gray') + # Save ground truths + mask = mask.squeeze() + mask = np.array(mask.cpu()) + image_file_name = str(idx) + "_gt" + plt.title = image_file_name + plt.imsave(out_save_path_mask, mask, format='png', cmap='gray') + # Save predictions + # Post Processing + final_prediction = torch.argmax(pred, dim=1) + final_prediction = keep_largest(final_prediction) + final_prediction = fill_holes(final_prediction) + # final_prediction = torch.argmax(final_prediction, dim=1) + final_pred = np.array(final_prediction.cpu().squeeze()) + image_file_name = str(idx) + "_pred" + plt.title = image_file_name + plt.imsave(out_save_path_pred, final_pred, format='png', cmap='gray') + plt.close() + + +def save_results(iou, dice, out_path, model_name): + # Folder to store the plots + if not os.path.exists(os.path.join(out_path, model_name)): + os.makedirs(os.path.join(out_path, model_name)) + out_save_path = os.path.join(os.path.join(out_path, model_name)) + # IoU + sorted_iou = sorted(iou) + print("Top IoU:", sorted_iou[-1]) + # Dice Scores + sorted_dice = sorted(dice) + print("Top Dice Score:", sorted_dice[-1]) + # save results + result_dict = {"IoU": sorted_iou[-1], + "Dice Score": sorted_dice[-1]} + + file_name = 'results.txt' + completeName = os.path.join(out_save_path, file_name) + with open(completeName, 'w') as file: + file.write(str(result_dict)) + + +def test_results(model, out_path, model_name): + all_iou = [] + all_dice = [] + indices = 0 + for items in test_data: + image = items["image"].cuda() + image_shape = image.shape + mask = items["mask"].long().cuda().squeeze(dim=1) + # pad the image + image, ind = Pad_images(image) + pred = model(image.float().cuda()) + # unpad the images + pred = UnPad_imges(pred, ind, image_shape).cuda() + pred = soft(pred) + # pred = torch.argmax(pred, dim=1) + # Save results + save_pred(image, mask, pred, out_path, model_name, indices) + # calculate iou + iou_all_class = IOU_metric(pred, mask) + iou_all_class = iou_all_class.cpu().numpy() + iou_all_class = iou_all_class[iou_all_class != -1.] + iou = iou_all_class.mean() + all_iou.append(iou) + # calculate dice score + dice_all_class = f1_metric(pred, mask) + dice_all_class = dice_all_class.cpu().numpy() + dice_all_class = dice_all_class[~np.isnan(dice_all_class)] + dice = dice_all_class.mean() + all_dice.append(dice) + indices = indices + 1 + # Save plots + save_results(all_iou, all_dice, out_path, model_name) + + +if __name__ == "__main__": + name = str("UNet2D_Attention_Best_0.3_Fold_1") + model_path = str(Path(r"../unet/cluster_results/best_models", name + ".pt")) + if not os.path.exists(r'../unet/test_results_1/'): + os.makedirs(r'../unet/test_results_1/') + result_path = r'../unet/test_results_1/' + model = torch.load(model_path) + with torch.no_grad(): + model.eval().cuda() + test_results(model, result_path, name)