[76022b]: / semseg / semseg.py

Download this file

132 lines (122 with data), 6.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import time
import glob
import os
import cv2
import numpy as np
import torch
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.autograd import Variable
from models import UNet16
#小腸カプセル内視鏡画像(png)に対して学習済み(重み:ptファイル)セマンティックセグメンテーションモデルへ入力する
# ※predict_main()を実行するとできる
def cuda(x):
'''
GPUが使用可能なら入力画像データをGPUにのせる
'''
#Python3.7以降予約語にasyncが指定されたため以下であると"SyntaxError: invalid syntax"、代わりにnon_blocking
#return x.cuda(async=True) if torch.cuda.is_available() else x
return x.cuda(non_blocking=True) if torch.cuda.is_available() else x
def variable(x, volatile=False):
'''
PyTorchで使用可能な型に変換する
'''
#xの型がlistまたはtupleに等しいときTrue
if isinstance(x, (list, tuple)):
return [variable(y, volatile=volatile) for y in x]
#以下一行UserWarning: volatile was removed and now has no effect. Use `with torch.no_grad():` instead.
#return cuda(Variable(x, volatile=volatile))
with torch.no_grad():
return cuda(Variable(x))
def get_model(model_path, model_type):
'''
学習済みのUNet16の重みを読み込み、UNet16の重みに入れる
:param model_path:
:param model_type: 'UNet16'
:return:
'''
model = UNet16(num_classes=1)
state = torch.load(str(model_path))
state = {key.replace('module.', ''): value for key, value in state['model'].items()}
model.load_state_dict(state)
model.eval()
if torch.cuda.is_available():
return model.cuda()
return model
def mask_overlay(image, mask, color=(0, 255, 0)):
'''
出力の二値画像(白が病変、黒が正常)と入力のCE画像を重ね合わせる
'''
mask = np.dstack((mask, mask, mask)) * np.array(color)
mask = mask.astype(np.uint8)
weighted_sum = cv2.addWeighted(mask, 0.5, image, 0.5, 0.)
img = image.copy()
ind = mask[:, :, 1] > 0
img[ind] = weighted_sum[ind]
return img
def predict_main(in_path='./input_image', out_path='./result', model_path='weight/model_1_20epoch_8.pt', stat_area=300, r_display=0, g_display=255, b_display=0, source_save=True):
'''
学習済みモデルによる推論の実行
in_path:入力画像フォルダ(FCNであるため入力サイズは可変、本研究では512×512px)
out_path:出力画像保存フォルダ
model_path:学習済みモデルの保存されているフォルダ
stat_area:出力において、病変と予測された領域の面積がstat_area未満の場合削除される
r_display, g_display, b_display:セグメンテーション領域を何色にするか
source_save:True ⇒ 元画像を保存
'''
file_names_ = [
in_path
]
save_file_names = [
out_path
]
img_transform = Compose([
ToTensor(),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
model = get_model(f'{model_path}', model_type='UNet16')#モデルの読み込み
model_name_only = os.path.splitext(os.path.basename(model_path))[0]
for j, file_names in enumerate(file_names_):#フォルダ数(1回)ループ
read_file_name = glob.glob(f'{file_names}/*.png')
read_file_name_only = [os.path.splitext(os.path.basename(p))[0] for p in read_file_name]
os.makedirs(f'{save_file_names[j]}/{model_name_only}', exist_ok=True)
if source_save:
os.makedirs(f'{save_file_names[j]}/{model_name_only}/source', exist_ok=True)
print(file_names)
start_ = time.time()
for i in range(len(read_file_name)):#画像枚数分ループ
img = cv2.imread(str(read_file_name[i]))#入力画像の読み込み
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
input_img = torch.unsqueeze(variable(img_transform(img), volatile=True), dim=0)
mask = model(input_img)#推論の実行(maskは二値画像で白⇒病変、黒⇒それ以外)
mask_array = mask.data[0].cpu().numpy()[0]
mask_array = np.where(mask_array > 0, np.uint8(255), np.uint8(0))
retval, labels, stats, _ = cv2.connectedComponentsWithStats(mask_array)#推論後の領域毎の情報を取得
if mask_array.sum() >= stat_area:
count_small_blob_num = 0 #stat_area[px]以下の領域の数
for i2, row in enumerate(stats):
#print(f"label {i}")
#print(f"* topleft: ({row[cv2.CC_STAT_LEFT]}, {row[cv2.CC_STAT_TOP]})")
#print(f"* size: ({row[cv2.CC_STAT_WIDTH]}, {row[cv2.CC_STAT_HEIGHT]})")
#print(f"* area: {row[cv2.CC_STAT_AREA]}")
if (row[cv2.CC_STAT_AREA]<stat_area) & (i2!=0):#領域ごとで面積がstat_area[px]未満の場合に黒くする
count_small_blob_num += 1
mask_array = np.where(labels[:] == i2, 0, mask_array)
count_blob_num = (retval - 1) - count_small_blob_num #領域の個数
if count_blob_num>0:#領域の個数が1個以上で画像を保存する
if source_save:
cv2.imwrite(f'{save_file_names[j]}/{model_name_only}/source/{read_file_name_only[i]}.png', cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
g_display = 1 if g_display==0 else g_display #g_displayの値が0だとなぜか色が表示されないので0なら1にする
saveimg = mask_overlay(img, (mask_array > 0).astype(np.uint8), color=(r_display, g_display, b_display))
cv2.imwrite(f'{save_file_names[j]}/{model_name_only}/{read_file_name_only[i]}.png', cv2.cvtColor(saveimg, cv2.COLOR_BGR2RGB))
print(time.time() - start_)
if __name__=="__main__":
predict_main(in_path='./input_image', out_path='./result', model_path='weight/model_1_20epoch_8.pt', stat_area=300, r_display=0, g_display=255, b_display=0, source_save=True)
'''
引数
in_path:入力画像フォルダ(FCNであるため入力サイズは可変、本研究では512×512px)
out_path:出力画像保存フォルダ
model_path:学習済みモデルの保存されているフォルダ
stat_area:出力において、病変と予測された領域の面積がstat_area未満の場合削除される
r_display, g_display, b_display:セグメンテーション領域を何色にするか
source_save:True ⇒ 元画像を保存
'''