--- a +++ b/semseg/semseg.py @@ -0,0 +1,132 @@ +import time +import glob +import os +import cv2 +import numpy as np +import torch +from torchvision.transforms import ToTensor, Normalize, Compose +from torch.autograd import Variable + +from models import UNet16 + +#小腸カプセル内視鏡画像(png)に対して学習済み(重み:ptファイル)セマンティックセグメンテーションモデルへ入力する +# ※predict_main()を実行するとできる + +def cuda(x): + ''' + GPUが使用可能なら入力画像データをGPUにのせる + ''' + #Python3.7以降予約語にasyncが指定されたため以下であると"SyntaxError: invalid syntax"、代わりにnon_blocking + #return x.cuda(async=True) if torch.cuda.is_available() else x + return x.cuda(non_blocking=True) if torch.cuda.is_available() else x + +def variable(x, volatile=False): + ''' + PyTorchで使用可能な型に変換する + ''' + #xの型がlistまたはtupleに等しいときTrue + if isinstance(x, (list, tuple)): + return [variable(y, volatile=volatile) for y in x] + #以下一行UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead. + #return cuda(Variable(x, volatile=volatile)) + with torch.no_grad(): + return cuda(Variable(x)) + +def get_model(model_path, model_type): + ''' + 学習済みのUNet16の重みを読み込み、UNet16の重みに入れる + :param model_path: + :param model_type: 'UNet16' + :return: + ''' + model = UNet16(num_classes=1) + state = torch.load(str(model_path)) + state = {key.replace('module.', ''): value for key, value in state['model'].items()} + model.load_state_dict(state) + model.eval() + if torch.cuda.is_available(): + return model.cuda() + return model + +def mask_overlay(image, mask, color=(0, 255, 0)): + ''' + 出力の二値画像(白が病変、黒が正常)と入力のCE画像を重ね合わせる + ''' + mask = np.dstack((mask, mask, mask)) * np.array(color) + mask = mask.astype(np.uint8) + weighted_sum = cv2.addWeighted(mask, 0.5, image, 0.5, 0.) + img = image.copy() + ind = mask[:, :, 1] > 0 + img[ind] = weighted_sum[ind] + return img + +def predict_main(in_path='./input_image', out_path='./result', model_path='weight/model_1_20epoch_8.pt', stat_area=300, r_display=0, g_display=255, b_display=0, source_save=True): + ''' + 学習済みモデルによる推論の実行 + + in_path:入力画像フォルダ(FCNであるため入力サイズは可変、本研究では512×512px) + out_path:出力画像保存フォルダ + model_path:学習済みモデルの保存されているフォルダ + stat_area:出力において、病変と予測された領域の面積がstat_area未満の場合削除される + r_display, g_display, b_display:セグメンテーション領域を何色にするか + source_save:True ⇒ 元画像を保存 + ''' + file_names_ = [ + in_path + ] + save_file_names = [ + out_path + ] + img_transform = Compose([ + ToTensor(), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + model = get_model(f'{model_path}', model_type='UNet16')#モデルの読み込み + model_name_only = os.path.splitext(os.path.basename(model_path))[0] + for j, file_names in enumerate(file_names_):#フォルダ数(1回)ループ + read_file_name = glob.glob(f'{file_names}/*.png') + read_file_name_only = [os.path.splitext(os.path.basename(p))[0] for p in read_file_name] + os.makedirs(f'{save_file_names[j]}/{model_name_only}', exist_ok=True) + if source_save: + os.makedirs(f'{save_file_names[j]}/{model_name_only}/source', exist_ok=True) + print(file_names) + start_ = time.time() + for i in range(len(read_file_name)):#画像枚数分ループ + img = cv2.imread(str(read_file_name[i]))#入力画像の読み込み + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + input_img = torch.unsqueeze(variable(img_transform(img), volatile=True), dim=0) + mask = model(input_img)#推論の実行(maskは二値画像で白⇒病変、黒⇒それ以外) + mask_array = mask.data[0].cpu().numpy()[0] + mask_array = np.where(mask_array > 0, np.uint8(255), np.uint8(0)) + retval, labels, stats, _ = cv2.connectedComponentsWithStats(mask_array)#推論後の領域毎の情報を取得 + if mask_array.sum() >= stat_area: + count_small_blob_num = 0 #stat_area[px]以下の領域の数 + for i2, row in enumerate(stats): + #print(f"label {i}") + #print(f"* topleft: ({row[cv2.CC_STAT_LEFT]}, {row[cv2.CC_STAT_TOP]})") + #print(f"* size: ({row[cv2.CC_STAT_WIDTH]}, {row[cv2.CC_STAT_HEIGHT]})") + #print(f"* area: {row[cv2.CC_STAT_AREA]}") + if (row[cv2.CC_STAT_AREA]<stat_area) & (i2!=0):#領域ごとで面積がstat_area[px]未満の場合に黒くする + count_small_blob_num += 1 + mask_array = np.where(labels[:] == i2, 0, mask_array) + count_blob_num = (retval - 1) - count_small_blob_num #領域の個数 + if count_blob_num>0:#領域の個数が1個以上で画像を保存する + if source_save: + cv2.imwrite(f'{save_file_names[j]}/{model_name_only}/source/{read_file_name_only[i]}.png', cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + g_display = 1 if g_display==0 else g_display #g_displayの値が0だとなぜか色が表示されないので0なら1にする + saveimg = mask_overlay(img, (mask_array > 0).astype(np.uint8), color=(r_display, g_display, b_display)) + cv2.imwrite(f'{save_file_names[j]}/{model_name_only}/{read_file_name_only[i]}.png', cv2.cvtColor(saveimg, cv2.COLOR_BGR2RGB)) + print(time.time() - start_) + +if __name__=="__main__": + predict_main(in_path='./input_image', out_path='./result', model_path='weight/model_1_20epoch_8.pt', stat_area=300, r_display=0, g_display=255, b_display=0, source_save=True) + ''' + 引数 + in_path:入力画像フォルダ(FCNであるため入力サイズは可変、本研究では512×512px) + out_path:出力画像保存フォルダ + model_path:学習済みモデルの保存されているフォルダ + stat_area:出力において、病変と予測された領域の面積がstat_area未満の場合削除される + r_display, g_display, b_display:セグメンテーション領域を何色にするか + source_save:True ⇒ 元画像を保存 + ''' \ No newline at end of file