Diff of /test.py [000000] .. [f2ca4d]

--- a
+++ b/test.py
@@ -0,0 +1,142 @@
+import numpy as np
+import sys
+import os
+import glob
+import nibabel as nib
+import torch
+#docker
+#fpx = '/wmhseg_code/'
+#inputDir = '/input'
+#outputDir = '/output'
+
+#local
+main_folder_path = '../Data/MS2017b/'
+fpx = './'
+inputDir = 'input/'
+outputDir = 'output/'
+
+#PARAMS
+useGPU = 0
+gpu0 = 0
+patch_size = 60
+extra_patch = 5
+model_paths = [fpx + 'analysis/models/EXP3D_1x1x1x1_0_0_dice_1_best.pth']
+weights = [1]
+#EXP3D_1x1x1x1_0_0_dice_1_best.pth
+#EXP3D_2x2x2x2_0_0_dice_1_best.pth
+#EXP3D_1x1x1x1_0_1_dice_1_best.pth
+#EXP3D_1x1x1x1_1_1_dice_1_best.pth
+##EXP3D_1x1x1x1_0_1_dice_1_best.pth (with ep = 16)
+sys.path.append(fpx + 'utils/')
+sys.path.append(fpx + 'architectures/deeplab_3D/')
+sys.path.append(fpx + 'architectures/unet_3D/')
+sys.path.append(fpx + 'architectures/hrnet_3D/')
+sys.path.append(fpx + 'architectures/experiment_nets_3D/')
+sys.path.append('utils/')
+
+import deeplab_resnet_3D
+import unet_3D
+import highresnet_3D
+import exp_net_3D
+
+import augmentations as AUG
+import normalizations as NORM
+import resizeScans as RS
+import evalF as EF
+import evalFP as EFP
+import PP
+
+#step 1: read image from input folder
+#step 2: resize image to 200x200x100 + apply normalizations
+#step 3: make prediction by patches (with augmentations)
+#step 4: save prediction to output folder
+#step 5: resize prediction back to original size of image
+
+
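+#paths for the input FLAIR scan, the resized intermediate files and the final result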
+img_path = os.path.join(inputDir, 'FLAIR.nii.gz')
+img_path_rs = os.path.join(outputDir, 'FLAIR_rs.nii.gz')
+
+wmh_path_rs = os.path.join(outputDir, 'wmh_rs.nii.gz')
+wmh_path = os.path.join(outputDir, 'result.nii.gz')
+
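+#remember the original scan dimensions so the prediction can be resized back later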
+old_size = PP.numpyFromScan(img_path).shape
+
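+#target size for resampling and number of output labels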
+new_size = [200,200,100]
+num_labels = 2
+
+#convert scan to 200x200x100
+RS.convertSize2(img_path, img_path_rs, new_size)
+#get the affine value
+affine_rs = nib.load(img_path_rs).get_affine()
+
+#normalize using histogram and variance normalization
+RS.normalizeScan(img_path_rs, img_path_rs, main_folder_path=main_folder_path)
+
+#read the preprocessed image, move the channel axis first and add a batch dimension
+img, affine = PP.numpyFromScan(img_path_rs, get_affine = True)
+img = img.transpose((3,0,1,2))
+img = img[np.newaxis, :]
+
+print('Image ready')
+print('Loading model')
+
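+#run every model and accumulate the soft predictions for ensemble averaging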
+out = None
+for model_path in model_paths:
+	f_name = model_path.split('/')[-1]
+	isPriv = False
+
+	#build the network architecture indicated by the checkpoint filename
+	if 'EXP3D' in f_name:
+		experiment = f_name.replace('EXP3D_', '').replace('.pth', '').split('_')
+		experiment = '_'.join(experiment[0:3])
+		dilation_arr, isPriv, withASPP = PP.getExperimentInfo(experiment)
+		model = exp_net_3D.getExpNet(num_labels, dilation_arr, isPriv, NoLabels2 = 209, withASPP = withASPP)
+	elif 'HR3D' in f_name:
+		model = highresnet_3D.getHRNet(num_labels)
+	elif 'DL3D' in f_name:
+		model = deeplab_resnet_3D.Res_Deeplab(num_labels)
+	elif 'UNET3D' in f_name:
+		model = unet_3D.UNet3D(1, num_labels)
+	else:
+		print('No model available for this .pth')
+		sys.exit(1)
+
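+	#load the saved weights, mapping tensors to the CPU when no GPU is used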
+	if useGPU:
+		saved_state_dict = torch.load(model_path)
+	else:
+		saved_state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
+	model.load_state_dict(saved_state_dict)
+	model.float()
+	model.eval()
+	print('Model ready')
+	print('Predicting...')
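+	#the first model initialises the prediction volume; later models add to it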
+	if not isinstance(out, np.ndarray):
+		if isPriv:
+			out = EFP.testPredict(img, model, num_labels, 209, 1, gpu0, useGPU, stride = 50, patch_size = patch_size, test_augm = False, extra_patch = extra_patch, get_soft = True)
+		else:
+			out = EF.testPredict(img, model, num_labels, 1, gpu0, useGPU, stride = 50, patch_size = patch_size, test_augm = False, extra_patch = extra_patch, get_soft = True)
+	else:
+		if isPriv:
+			out += EFP.testPredict(img, model, num_labels, 209, 1, gpu0, useGPU, stride = 50, patch_size = patch_size, test_augm = False, extra_patch = extra_patch, get_soft = True)
+		else:
+			out += EF.testPredict(img, model, num_labels, 1, gpu0, useGPU, stride = 50, patch_size = patch_size, test_augm = False, extra_patch = extra_patch, get_soft = True)
+
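+#average the accumulated soft predictions over the ensemble and take the most likely label per voxel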
+out /= float(len(model_paths))
+out = np.argmax(out, axis = 0)
+#remove any remaining singleton dimensions
+out = out.squeeze()
+print('Prediction complete')
+print('Saving...')
+#save output
+PP.saveScan(out.astype(np.float64), affine_rs, wmh_path_rs)
+
+#resize output to original input size and save (this is our final result)
+RS.convertSize2(wmh_path_rs, wmh_path, old_size, interpolation = 'nearest')
+
+#read the image and save it with same affine and header as original FLAIR image
+print('Saving final wmh file')
+orig_flair = nib.load(img_path)
+wmh_final = nib.load(wmh_path).get_data()
+PP.saveScan(wmh_final, orig_flair.get_affine(), wmh_path, header=orig_flair.header)
+print('Done')
\ No newline at end of file