#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Parts are based on https://github.com/multimodallearning/pytorch-mask-rcnn
published under MIT license.
"""
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils

sys.path.append("..")
import utils.model_utils as mutils
import utils.exp_utils as utils
from custom_extensions.nms import nms
from custom_extensions.roi_align import roi_align


############################################################
#  Networks on top of backbone
############################################################


class RPN(nn.Module):
    """
    Region Proposal Network.
    """

    def __init__(self, cf, conv):

        super(RPN, self).__init__()
        self.dim = conv.dim

        self.conv_shared = conv(cf.end_filts, cf.n_rpn_features, ks=3, stride=cf.rpn_anchor_stride, pad=1, relu=cf.relu)
        self.conv_class = conv(cf.n_rpn_features, 2 * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)
        self.conv_bbox = conv(cf.n_rpn_features, 2 * self.dim * len(cf.rpn_anchor_ratios), ks=1, stride=1, relu=None)


    def forward(self, x):
        """
        :param x: input feature maps (b, in_channels, y, x, (z))
        :return: rpn_class_logits (b, n_anchors, 2)
        :return: rpn_probs (b, n_anchors, 2)
        :return: rpn_bbox (b, n_anchors, 2 * dim)
        """

        # Shared convolutional base of the RPN.
        x = self.conv_shared(x)

        # Anchor Score. (batch, anchors per location * 2, y, x, (z)).
        rpn_class_logits = self.conv_class(x)
        # Reshape to (batch, anchors, 2)
        axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
        rpn_class_logits = rpn_class_logits.permute(*axes)
        rpn_class_logits = rpn_class_logits.contiguous()
        rpn_class_logits = rpn_class_logits.view(x.size()[0], -1, 2)

        # Softmax on last dimension (fg vs. bg).
        rpn_probs = F.softmax(rpn_class_logits, dim=2)

        # Bounding box refinement. (batch, anchors_per_location * 2 * dim, y, x, (z))
        rpn_bbox = self.conv_bbox(x)

        # Reshape to (batch, anchors, 2 * dim)
        rpn_bbox = rpn_bbox.permute(*axes)
        rpn_bbox = rpn_bbox.contiguous()
        rpn_bbox = rpn_bbox.view(x.size()[0], -1, self.dim * 2)

        return [rpn_class_logits, rpn_probs, rpn_bbox]


class Classifier(nn.Module):
    """
    Head network for classification and bounding box refinement. Performs RoiAlign, processes resulting features through a
    shared convolutional base and finally branches off the classifier- and regression head.
    """
    def __init__(self, cf, conv):
        super(Classifier, self).__init__()

        self.dim = conv.dim
        self.in_channels = cf.end_filts
        self.pool_size = cf.pool_size
        self.pyramid_levels = cf.pyramid_levels
        # instance_norm does not work with spatial dims (1, 1, (1))
        norm = cf.norm if cf.norm != 'instance_norm' else None

        self.conv1 = conv(cf.end_filts, cf.end_filts * 4, ks=self.pool_size, stride=1, norm=norm, relu=cf.relu)
        self.conv2 = conv(cf.end_filts * 4, cf.end_filts * 4, ks=1, stride=1, norm=norm, relu=cf.relu)
        self.linear_class = nn.Linear(cf.end_filts * 4, cf.head_classes)
        self.linear_bbox = nn.Linear(cf.end_filts * 4, cf.head_classes * 2 * self.dim)

    def forward(self, x, rois):
        """
        :param x: input feature maps (b, in_channels, y, x, (z))
        :param rois: normalized box coordinates as proposed by the RPN to be forwarded through
        the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)). Proposals of all batch elements
        have been merged to one vector, while the origin info has been stored for re-allocation.
        :return: mrcnn_class_logits (n_proposals, n_head_classes)
        :return: mrcnn_bbox (n_proposals, n_head_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
        """
        x = pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim)
        x = self.conv1(x)
        x = self.conv2(x)
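        # conv1 uses ks = pool_size on the pooled grid (no pad argument is passed, so
        # assuming default zero padding this collapses the spatial dims to 1); it thus
        # acts as a fully connected layer, and the view below flattens the features
        # to (n_proposals, end_filts * 4).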
        x = x.view(-1, self.in_channels * 4)
        mrcnn_class_logits = self.linear_class(x)
        mrcnn_bbox = self.linear_bbox(x)
        mrcnn_bbox = mrcnn_bbox.view(mrcnn_bbox.size()[0], -1, self.dim * 2)

        return [mrcnn_class_logits, mrcnn_bbox]


class Mask(nn.Module):
    """
    Head network for proposal-based mask segmentation. Performs RoiAlign, some convolutions and applies sigmoid on the
    output logits to allow for overlapping classes.
    """
    def __init__(self, cf, conv):
        super(Mask, self).__init__()
        self.pool_size = cf.mask_pool_size
        self.pyramid_levels = cf.pyramid_levels
        self.dim = conv.dim
        self.conv1 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
        self.conv2 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
        self.conv3 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
        self.conv4 = conv(cf.end_filts, cf.end_filts, ks=3, stride=1, pad=1, norm=cf.norm, relu=cf.relu)
        if conv.dim == 2:
            self.deconv = nn.ConvTranspose2d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2)
        else:
            self.deconv = nn.ConvTranspose3d(cf.end_filts, cf.end_filts, kernel_size=2, stride=2)

        self.relu = nn.ReLU(inplace=True) if cf.relu == 'relu' else nn.LeakyReLU(inplace=True)
        self.conv5 = conv(cf.end_filts, cf.head_classes, ks=1, stride=1, relu=None)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, rois):
        """
        :param x: input feature maps (b, in_channels, y, x, (z))
        :param rois: normalized box coordinates as proposed by the RPN to be forwarded through
        the second stage (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)). Proposals of all batch elements
        have been merged to one vector, while the origin info has been stored for re-allocation.
        :return: x: masks (n_sampled_proposals (n_detections in inference), n_classes, y, x, (z))
        """
        x = pyramid_roi_align(x, rois, self.pool_size, self.pyramid_levels, self.dim)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
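        # the transposed conv (kernel 2, stride 2) doubles each spatial dim of the pooled
        # grid, then the 1x1 conv5 maps to one channel per head class; per-class sigmoids
        # (instead of a softmax over classes) let predicted masks of different classes overlap.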
        x = self.relu(self.deconv(x))
        x = self.conv5(x)
        x = self.sigmoid(x)
        return x


############################################################
#  Loss Functions
############################################################

def compute_rpn_class_loss(rpn_match, rpn_class_logits, shem_poolsize):
    """
    :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
    :param rpn_class_logits: (n_anchors, 2). logits from RPN classifier.
    :param shem_poolsize: int. factor of top-k candidates to draw from per negative sample
    (stochastic-hard-example-mining).
    :return: loss: torch tensor
    :return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training.
    """

    # filter out neutral anchors.
    pos_indices = torch.nonzero(rpn_match == 1)
    neg_indices = torch.nonzero(rpn_match == -1)

    # loss for positive samples
    if 0 not in pos_indices.size():
        pos_indices = pos_indices.squeeze(1)
        roi_logits_pos = rpn_class_logits[pos_indices]
        pos_loss = F.cross_entropy(roi_logits_pos, torch.LongTensor([1] * pos_indices.shape[0]).cuda())
    else:
        pos_loss = torch.FloatTensor([0]).cuda()

    # loss for negative samples: draw hard negative examples (SHEM)
    # that match the number of positive samples, but at least 1.
    if 0 not in neg_indices.size():
        neg_indices = neg_indices.squeeze(1)
        roi_logits_neg = rpn_class_logits[neg_indices]
        negative_count = np.max((1, pos_indices.cpu().data.numpy().size))
        roi_probs_neg = F.softmax(roi_logits_neg, dim=1)
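        # SHEM sketch (illustrative numbers, inferred from the shem_poolsize docstring):
        # with 4 positive anchors and shem_poolsize 10, negative_count = 4 and mutils.shem
        # draws those 4 negatives at random from the 4 * 10 = 40 highest-scoring negative
        # candidates, mixing hard-example mining with stochasticity.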
        neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize)
        neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda())
        np_neg_ix = neg_ix.cpu().data.numpy()
    else:
        neg_loss = torch.FloatTensor([0]).cuda()
        np_neg_ix = np.array([]).astype('int32')

    loss = (pos_loss + neg_loss) / 2
    return loss, np_neg_ix


def compute_rpn_bbox_loss(rpn_target_deltas, rpn_pred_deltas, rpn_match):
    """
    :param rpn_target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))).
    Uses 0 padding to fill in unused bbox deltas.
    :param rpn_pred_deltas: predicted deltas from RPN. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
    :param rpn_match: (n_anchors). [-1, 0, 1] for negative, neutral, and positive matched anchors.
    :return: loss: torch 1D tensor.
    """
    if 0 not in torch.nonzero(rpn_match == 1).size():

        indices = torch.nonzero(rpn_match == 1).squeeze(1)
        # Pick bbox deltas that contribute to the loss
        rpn_pred_deltas = rpn_pred_deltas[indices]
        # Trim target bounding box deltas to the same length as rpn_bbox.
        target_deltas = rpn_target_deltas[:rpn_pred_deltas.size()[0], :]
        # Smooth L1 loss
        loss = F.smooth_l1_loss(rpn_pred_deltas, target_deltas)
    else:
        loss = torch.FloatTensor([0]).cuda()

    return loss


def compute_mrcnn_class_loss(target_class_ids, pred_class_logits):
    """
    :param target_class_ids: (n_sampled_rois) batch dimension was merged into roi dimension.
    :param pred_class_logits: (n_sampled_rois, n_classes)
    :return: loss: torch 1D tensor.
    """
    if 0 not in target_class_ids.size():
        loss = F.cross_entropy(pred_class_logits, target_class_ids.long())
    else:
        loss = torch.FloatTensor([0.]).cuda()

    return loss


def compute_mrcnn_bbox_loss(mrcnn_target_deltas, mrcnn_pred_deltas, target_class_ids):
    """
    :param mrcnn_target_deltas: (n_sampled_rois, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
    :param mrcnn_pred_deltas: (n_sampled_rois, n_classes, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
    :param target_class_ids: (n_sampled_rois)
    :return: loss: torch 1D tensor.
    """
    if 0 not in torch.nonzero(target_class_ids > 0).size():
        positive_roi_ix = torch.nonzero(target_class_ids > 0)[:, 0]
        positive_roi_class_ids = target_class_ids[positive_roi_ix].long()
        target_bbox = mrcnn_target_deltas[positive_roi_ix, :].detach()
        pred_bbox = mrcnn_pred_deltas[positive_roi_ix, positive_roi_class_ids, :]
        loss = F.smooth_l1_loss(pred_bbox, target_bbox)
    else:
        loss = torch.FloatTensor([0]).cuda()

    return loss


def compute_mrcnn_mask_loss(target_masks, pred_masks, target_class_ids):
    """
    :param target_masks: (n_sampled_rois, y, x, (z)) A float32 tensor of values 0 or 1. Uses zero padding to fill array.
    :param pred_masks: (n_sampled_rois, n_classes, y, x, (z)) float32 tensor with values between [0, 1].
    :param target_class_ids: (n_sampled_rois)
    :return: loss: torch 1D tensor.
    """
    if 0 not in torch.nonzero(target_class_ids > 0).size():
        # Only positive ROIs contribute to the loss. And only
        # the class specific mask of each ROI.
        positive_ix = torch.nonzero(target_class_ids > 0)[:, 0]
        positive_class_ids = target_class_ids[positive_ix].long()
        y_true = target_masks[positive_ix, :, :].detach()
        y_pred = pred_masks[positive_ix, positive_class_ids, :, :]
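        # y_pred holds per-class sigmoid outputs, so binary cross entropy is computed only
        # on the channel of each ROI's target class; spatial overlap between classes is
        # therefore not penalized.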
        loss = F.binary_cross_entropy(y_pred, y_true)
    else:
        loss = torch.FloatTensor([0]).cuda()

    return loss


############################################################
#  Helper Layers
############################################################

def refine_proposals(rpn_pred_probs, rpn_pred_deltas, proposal_count, batch_anchors, cf):
    """
    Receives anchor scores and selects a subset to pass as proposals
    to the second stage. Filtering is done based on anchor scores and
    non-max suppression to remove overlaps. It also applies bounding
    box refinement deltas to anchors.
    :param rpn_pred_probs: (b, n_anchors, 2)
    :param rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
    :return: batch_normalized_props: proposals in normalized coordinates (b, proposal_count, (y1, x1, y2, x2, (z1), (z2)))
    :return: batch_out_proposals: box coords + RPN foreground scores
    for monitoring/plotting (b, proposal_count, (y1, x1, y2, x2, (z1), (z2), score))
    """
    std_dev = torch.from_numpy(cf.rpn_bbox_std_dev[None]).float().cuda()
    norm = torch.from_numpy(cf.scale).float().cuda()
    anchors = batch_anchors.clone()

    batch_scores = rpn_pred_probs[:, :, 1]
    # norm deltas
    batch_deltas = rpn_pred_deltas * std_dev
    batch_normalized_props = []
    batch_out_proposals = []

    # loop over batch dimension.
    for ix in range(batch_scores.shape[0]):

        scores = batch_scores[ix]
        deltas = batch_deltas[ix]

        # improve performance by trimming to top anchors by score
        # and doing the rest on the smaller subset.
        pre_nms_limit = min(cf.pre_nms_limit, anchors.size()[0])
        scores, order = scores.sort(descending=True)
        order = order[:pre_nms_limit]
        scores = scores[:pre_nms_limit]
        deltas = deltas[order, :]

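        # delta application follows the usual Faster R-CNN parameterization (a sketch of
        # what mutils.apply_box_deltas_2D/3D are assumed to implement): for an anchor with
        # center (yc, xc) and size (h, w), deltas (dy, dx, dh, dw) yield the refined box
        # yc + dy * h, xc + dx * w, h * exp(dh), w * exp(dw).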
        # apply deltas to anchors to get refined anchors and filter with non-maximum suppression.
        if batch_deltas.shape[-1] == 4:
            boxes = mutils.apply_box_deltas_2D(anchors[order, :], deltas)
            boxes = mutils.clip_boxes_2D(boxes, cf.window)
        else:
            boxes = mutils.apply_box_deltas_3D(anchors[order, :], deltas)
            boxes = mutils.clip_boxes_3D(boxes, cf.window)
        # boxes are y1,x1,y2,x2, torchvision-nms requires x1,y1,x2,y2, but consistent swap x<->y is irrelevant.
        keep = nms.nms(boxes, scores, cf.rpn_nms_threshold)

        keep = keep[:proposal_count]
        boxes = boxes[keep, :]
        rpn_scores = scores[keep][:, None]

        # pad missing boxes with 0.
        if boxes.shape[0] < proposal_count:
            n_pad_boxes = proposal_count - boxes.shape[0]
            zeros = torch.zeros([n_pad_boxes, boxes.shape[1]]).cuda()
            boxes = torch.cat([boxes, zeros], dim=0)
            zeros = torch.zeros([n_pad_boxes, rpn_scores.shape[1]]).cuda()
            rpn_scores = torch.cat([rpn_scores, zeros], dim=0)

        # concat box and score info for monitoring/plotting.
        batch_out_proposals.append(torch.cat((boxes, rpn_scores), 1).cpu().data.numpy())
        # normalize dimensions to range of 0 to 1.
        normalized_boxes = boxes / norm
        assert torch.all(normalized_boxes <= 1), "normalized box coords >1 found"

        # add back the batch dimension.
        batch_normalized_props.append(normalized_boxes.unsqueeze(0))

    batch_normalized_props = torch.cat(batch_normalized_props)
    batch_out_proposals = np.array(batch_out_proposals)

    return batch_normalized_props, batch_out_proposals


def pyramid_roi_align(feature_maps, rois, pool_size, pyramid_levels, dim):
    """
    Implements ROI Pooling on multiple levels of the feature pyramid.
    :param feature_maps: list of feature maps, each of shape (b, c, y, x, (z))
    :param rois: proposals (normalized coords.) as returned by RPN. contain info about original batch element allocation.
    (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs))
    :param pool_size: list of poolsizes in dims: [x, y, (z)]
    :param pyramid_levels: list. [0, 1, 2, ...]
    :return: pooled: pooled feature map rois (n_proposals, c, poolsize_y, poolsize_x, (poolsize_z)).
    The spatial output dimensions equal pool_size.
    """
    boxes = rois[:, :dim*2]
    batch_ixs = rois[:, dim*2]

    # Assign each ROI to a level in the pyramid based on the ROI area.
    if dim == 2:
        y1, x1, y2, x2 = boxes.chunk(4, dim=1)
    else:
        y1, x1, y2, x2, z1, z2 = boxes.chunk(6, dim=1)

    h = y2 - y1
    w = x2 - x1

    # Equation 1 in https://arxiv.org/abs/1612.03144. Account for
    # the fact that our coordinates are normalized here:
    # divide sqrt(h*w) by 1 instead of the image area.
    roi_level = (4 + torch.log2(torch.sqrt(h*w))).round().int().clamp(pyramid_levels[0], pyramid_levels[-1])
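    # e.g. a normalized roi with h = w = 0.25 yields 4 + log2(0.25) = 2, i.e. it is pooled
    # from pyramid level P2; larger rois are assigned to coarser levels.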
    # if Pyramid contains additional level P6, adapt the roi_level assignment accordingly.
    if len(pyramid_levels) == 5:
        roi_level[h*w > 0.65] = 5

    # Loop through levels and apply ROI pooling to each.
    pooled = []
    box_to_level = []
    fmap_shapes = [f.shape for f in feature_maps]
    for level_ix, level in enumerate(pyramid_levels):
        ix = roi_level == level
        if not ix.any():
            continue
        ix = torch.nonzero(ix)[:, 0]
        level_boxes = boxes[ix, :]
        # re-assign rois to feature map of original batch element.
        ind = batch_ixs[ix].int()

        # Keep track of which box is mapped to which level
        box_to_level.append(ix)

        # Stop gradient propagation to ROI proposals
        level_boxes = level_boxes.detach()
        if len(pool_size) == 2:
            # remap to feature map coordinate system
            y_exp, x_exp = fmap_shapes[level_ix][2:]  # exp = expansion
            level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
            pooled_features = roi_align.roi_align_2d(feature_maps[level_ix],
                                                     torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1),
                                                     pool_size)
        else:
            y_exp, x_exp, z_exp = fmap_shapes[level_ix][2:]
            level_boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp, z_exp, z_exp], dtype=torch.float32).cuda())
            pooled_features = roi_align.roi_align_3d(feature_maps[level_ix],
                                                     torch.cat((ind.unsqueeze(1).float(), level_boxes), dim=1),
                                                     pool_size)
        pooled.append(pooled_features)

    # Pack pooled features into one tensor
    pooled = torch.cat(pooled, dim=0)

    # Pack box_to_level mapping into one array and add another
    # column representing the order of pooled boxes
    box_to_level = torch.cat(box_to_level, dim=0)

    # Rearrange pooled features to match the order of the original boxes
    _, box_to_level = torch.sort(box_to_level)
    pooled = pooled[box_to_level, :, :]

    return pooled


def detection_target_layer(batch_proposals, batch_mrcnn_class_scores, batch_gt_class_ids, batch_gt_boxes, batch_gt_masks, cf):
    """
    Subsamples proposals for mrcnn losses and generates targets. Sampling is done per batch element, which seems to have
    positive effects on training, as opposed to sampling over the entire batch. Negatives are sampled via
    stochastic-hard-example-mining (SHEM), where a number of negative proposals is drawn from a larger pool of the
    highest-scoring proposals for stochasticity. Scoring is obtained here as the max over all foreground probabilities
    as returned by the mrcnn classifier (this worked better than loss-based class-balancing methods like
    "online-hard-example-mining" or "focal loss").
    :param batch_proposals: (n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ixs)).
    boxes as proposed by RPN. n_proposals here is determined by batch_size * POST_NMS_ROIS.
    :param batch_mrcnn_class_scores: (n_proposals, n_classes)
    :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
    :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
    :param batch_gt_masks: list over batch elements. Each element is a binary mask of shape (n_gt_rois, y, x, (z), c).
    :return: sample_indices: (n_sampled_rois) indices of sampled proposals to be used for loss functions.
    :return: target_class_ids: (n_sampled_rois) containing target class labels of sampled proposals.
    :return: target_deltas: (n_sampled_rois, 2 * dim) containing target deltas of sampled proposals for box refinement.
    :return: target_masks: (n_sampled_rois, y, x, (z)) containing target masks of sampled proposals.
    """
    # normalization of target coordinates
    if cf.dim == 2:
        h, w = cf.patch_size
        scale = torch.from_numpy(np.array([h, w, h, w])).float().cuda()
    else:
        h, w, z = cf.patch_size
        scale = torch.from_numpy(np.array([h, w, h, w, z, z])).float().cuda()

    positive_count = 0
    negative_count = 0
    sample_positive_indices = []
    sample_negative_indices = []
    sample_deltas = []
    sample_masks = []
    sample_class_ids = []

    std_dev = torch.from_numpy(cf.bbox_std_dev).float().cuda()

    # loop over batch and get positive and negative sample rois.
    for b in range(len(batch_gt_class_ids)):

        gt_class_ids = torch.from_numpy(batch_gt_class_ids[b]).int().cuda()
        gt_masks = torch.from_numpy(batch_gt_masks[b]).float().cuda()
        if np.any(batch_gt_class_ids[b] > 0):  # skip roi selection for no gt images.
            gt_boxes = torch.from_numpy(batch_gt_boxes[b]).float().cuda() / scale
        else:
            gt_boxes = torch.FloatTensor().cuda()

        # get proposals and indices of current batch element.
        proposals = batch_proposals[batch_proposals[:, -1] == b][:, :-1]
        batch_element_indices = torch.nonzero(batch_proposals[:, -1] == b).squeeze(1)

        # Compute overlaps matrix [proposals, gt_boxes]
        if 0 not in gt_boxes.size():
            if gt_boxes.shape[1] == 4:
                assert cf.dim == 2, "gt_boxes shape {} doesn't match cf.dim {}".format(gt_boxes.shape, cf.dim)
                overlaps = mutils.bbox_overlaps_2D(proposals, gt_boxes)
            else:
                assert cf.dim == 3, "gt_boxes shape {} doesn't match cf.dim {}".format(gt_boxes.shape, cf.dim)
                overlaps = mutils.bbox_overlaps_3D(proposals, gt_boxes)

            # Determine positive and negative ROIs
            roi_iou_max = torch.max(overlaps, dim=1)[0]
            # 1. Positive ROIs are those with >= 0.5 IoU with a GT box
            positive_roi_bool = roi_iou_max >= (0.5 if cf.dim == 2 else 0.3)
            # 2. Negative ROIs are those with < 0.1 IoU with every GT box.
            negative_roi_bool = roi_iou_max < (0.1 if cf.dim == 2 else 0.01)
        else:
            positive_roi_bool = torch.FloatTensor().cuda()
            negative_roi_bool = torch.from_numpy(np.array([1] * proposals.shape[0])).cuda()

        # Sample Positive ROIs
        if 0 not in torch.nonzero(positive_roi_bool).size():
            positive_indices = torch.nonzero(positive_roi_bool).squeeze(1)
            positive_samples = int(cf.train_rois_per_image * cf.roi_positive_ratio)
            rand_idx = torch.randperm(positive_indices.size()[0])
            rand_idx = rand_idx[:positive_samples].cuda()
            positive_indices = positive_indices[rand_idx]
            positive_samples = positive_indices.size()[0]
            positive_rois = proposals[positive_indices, :]
            # Assign positive ROIs to GT boxes.
            positive_overlaps = overlaps[positive_indices, :]
            roi_gt_box_assignment = torch.max(positive_overlaps, dim=1)[1]
            roi_gt_boxes = gt_boxes[roi_gt_box_assignment, :]
            roi_gt_class_ids = gt_class_ids[roi_gt_box_assignment]

            # Compute bbox refinement targets for positive ROIs
            deltas = mutils.box_refinement(positive_rois, roi_gt_boxes)
            deltas /= std_dev

            # Assign positive ROIs to GT masks
            roi_masks = gt_masks[roi_gt_box_assignment]
            assert roi_masks.shape[1] == 1, "desired to have more than one channel in gt masks?"

            # Compute mask targets
            boxes = positive_rois
            box_ids = torch.arange(roi_masks.shape[0]).cuda().unsqueeze(1).float()
            if len(cf.mask_shape) == 2:
                # need to remap normalized box coordinates to unnormalized mask coordinates.
                y_exp, x_exp = roi_masks.shape[2:]  # exp = expansion
                boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp], dtype=torch.float32).cuda())
                masks = roi_align.roi_align_2d(roi_masks, torch.cat((box_ids, boxes), dim=1), cf.mask_shape)
            else:
                y_exp, x_exp, z_exp = roi_masks.shape[2:]  # exp = expansion
                boxes.mul_(torch.tensor([y_exp, x_exp, y_exp, x_exp, z_exp, z_exp], dtype=torch.float32).cuda())
                masks = roi_align.roi_align_3d(roi_masks, torch.cat((box_ids, boxes), dim=1), cf.mask_shape)
            masks = masks.squeeze(1)
            # Threshold mask pixels at 0.5 to have GT masks be 0 or 1 to use with
            # binary cross entropy loss.
            masks = torch.round(masks)

            sample_positive_indices.append(batch_element_indices[positive_indices])
            sample_deltas.append(deltas)
            sample_masks.append(masks)
            sample_class_ids.append(roi_gt_class_ids)
            positive_count += positive_samples
        else:
            positive_samples = 0

        # Negative ROIs. Add enough to maintain positive:negative ratio, but at least 1. Sample via SHEM.
        if 0 not in torch.nonzero(negative_roi_bool).size():
            negative_indices = torch.nonzero(negative_roi_bool).squeeze(1)
            r = 1.0 / cf.roi_positive_ratio
            b_neg_count = np.max((int(r * positive_samples - positive_samples), 1))
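            # e.g. with roi_positive_ratio = 0.5, r = 2 and b_neg_count equals the number
            # of sampled positives (but at least 1), i.e. a 1:1 positive:negative ratio.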
            roi_probs_neg = batch_mrcnn_class_scores[batch_element_indices[negative_indices]]
            raw_sampled_indices = mutils.shem(roi_probs_neg, b_neg_count, cf.shem_poolsize)
            sample_negative_indices.append(batch_element_indices[negative_indices[raw_sampled_indices]])
            negative_count += raw_sampled_indices.size()[0]

    if len(sample_positive_indices) > 0:
        target_deltas = torch.cat(sample_deltas)
        target_masks = torch.cat(sample_masks)
        target_class_ids = torch.cat(sample_class_ids)

    # Pad target information with zeros for negative ROIs.
    if positive_count > 0 and negative_count > 0:
        sample_indices = torch.cat((torch.cat(sample_positive_indices), torch.cat(sample_negative_indices)), dim=0)
        zeros = torch.zeros(negative_count).int().cuda()
        target_class_ids = torch.cat([target_class_ids, zeros], dim=0)
        zeros = torch.zeros(negative_count, cf.dim * 2).cuda()
        target_deltas = torch.cat([target_deltas, zeros], dim=0)
        zeros = torch.zeros(negative_count, *cf.mask_shape).cuda()
        target_masks = torch.cat([target_masks, zeros], dim=0)
    elif positive_count > 0:
        sample_indices = torch.cat(sample_positive_indices)
    elif negative_count > 0:
        sample_indices = torch.cat(sample_negative_indices)
        zeros = torch.zeros(negative_count).int().cuda()
        target_class_ids = zeros
        zeros = torch.zeros(negative_count, cf.dim * 2).cuda()
        target_deltas = zeros
        zeros = torch.zeros(negative_count, *cf.mask_shape).cuda()
        target_masks = zeros
    else:
        sample_indices = torch.LongTensor().cuda()
        target_class_ids = torch.IntTensor().cuda()
        target_deltas = torch.FloatTensor().cuda()
        target_masks = torch.FloatTensor().cuda()

    return sample_indices, target_class_ids, target_deltas, target_masks


############################################################
#  Output Handler
############################################################


def refine_detections(cf, batch_ixs, rois, deltas, scores):
    """
    Refine classified proposals (apply deltas to rpn rois), filter overlaps (nms) and return final detections.
    :param rois: (n_proposals, 2 * dim) normalized boxes as proposed by RPN. n_proposals = batch_size * POST_NMS_ROIS
    :param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by mrcnn bbox regressor.
    :param batch_ixs: (n_proposals) batch element assignment info for re-allocation.
    :param scores: (n_proposals, n_classes) probabilities for all classes per roi as predicted by mrcnn classifier.
    :return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score, *regression vector features))
    """
    # class IDs per ROI. Since scores of all classes are of interest (not just max class), all are kept at this point.
    class_ids = []
    fg_classes = cf.head_classes - 1
    # repeat vectors to fill in predictions for all foreground classes.
    for ii in range(1, fg_classes + 1):
        class_ids += [ii] * rois.shape[0]
    class_ids = torch.from_numpy(np.array(class_ids)).cuda()

    batch_ixs = batch_ixs.repeat(fg_classes)
    rois = rois.repeat(fg_classes, 1)
    deltas = deltas.repeat(fg_classes, 1, 1)
    scores = scores.repeat(fg_classes, 1)

    # get class-specific scores and bounding box deltas
    idx = torch.arange(class_ids.size()[0]).long().cuda()
    # using idx instead of slice [:,] squashes the first dimension.
    # len(class_ids) > scores.shape[1] --> probs is broadcast by expansion from fg_classes --> len(class_ids)
    batch_ixs = batch_ixs[idx]
    deltas_specific = deltas[idx, class_ids]
    class_scores = scores[idx, class_ids]

    # apply bounding box deltas. re-scale to image coordinates.
    std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda()
    scale = torch.from_numpy(cf.scale).float().cuda()
    refined_rois = mutils.apply_box_deltas_2D(rois, deltas_specific * std_dev) * scale if cf.dim == 2 else \
        mutils.apply_box_deltas_3D(rois, deltas_specific * std_dev) * scale

    # round and cast to int since we're dealing with pixels now
    refined_rois = mutils.clip_to_window(cf.window, refined_rois)
    refined_rois = torch.round(refined_rois)

    # filter out low confidence boxes
    keep = idx
    keep_bool = (class_scores >= cf.model_min_confidence)
    if 0 not in torch.nonzero(keep_bool).size():

        score_keep = torch.nonzero(keep_bool)[:, 0]
        pre_nms_class_ids = class_ids[score_keep]
        pre_nms_rois = refined_rois[score_keep]
        pre_nms_scores = class_scores[score_keep]
        pre_nms_batch_ixs = batch_ixs[score_keep]

        for j, b in enumerate(mutils.unique1d(pre_nms_batch_ixs)):

            bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0]
            bix_class_ids = pre_nms_class_ids[bixs]
            bix_rois = pre_nms_rois[bixs]
            bix_scores = pre_nms_scores[bixs]

            for i, class_id in enumerate(mutils.unique1d(bix_class_ids)):

                ixs = torch.nonzero(bix_class_ids == class_id)[:, 0]
                # nms expects boxes sorted by score.
                ix_rois = bix_rois[ixs]
                ix_scores = bix_scores[ixs]
                ix_scores, order = ix_scores.sort(descending=True)
                ix_rois = ix_rois[order, :]

                class_keep = nms.nms(ix_rois, ix_scores, cf.detection_nms_threshold)

                # map indices back.
                class_keep = keep[score_keep[bixs[ixs[order[class_keep]]]]]
                # merge indices over classes for current batch element
                b_keep = class_keep if i == 0 else mutils.unique1d(torch.cat((b_keep, class_keep)))

            # only keep top-k boxes of current batch-element
            top_ids = class_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element]
            b_keep = b_keep[top_ids]

            # merge indices over batch elements.
            batch_keep = b_keep if j == 0 else mutils.unique1d(torch.cat((batch_keep, b_keep)))

        keep = batch_keep

    else:
        keep = torch.tensor([0]).long().cuda()

    # arrange output
    output = [refined_rois[keep], batch_ixs[keep].unsqueeze(1)]
    output += [class_ids[keep].unsqueeze(1).float(), class_scores[keep].unsqueeze(1)]

    result = torch.cat(output, dim=1)
    # shape: (n_keeps, catted feats), catted feats: [0:dim*2] are box_coords, [dim*2] are batch_ixs,
    # [dim*2+1] are class_ids, [dim*2+2] are scores, [dim*2+3:] are regression vector features (incl. uncertainty)
    return result


def get_results(cf, img_shape, detections, detection_masks, box_results_list=None, return_masks=True):
    """
    Restores batch dimension of merged detections, unmolds detections, creates and fills results dict.
    :param img_shape:
    :param detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score))
    :param detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
    :param box_results_list: None or list of output boxes for monitoring/plotting.
    each element is a list of boxes per batch element.
    :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
    :return: results_dict: dictionary with keys:
    'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
    [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
    'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, 1], only fg. vs. bg for now.
    class-specific return of masks will come with implementation of instance segmentation evaluation.
    """
    detections = detections.cpu().data.numpy()
    if cf.dim == 2:
        detection_masks = detection_masks.permute(0, 2, 3, 1).cpu().data.numpy()
    else:
        detection_masks = detection_masks.permute(0, 2, 3, 4, 1).cpu().data.numpy()

    # restore batch dimension of merged detections using the batch_ix info.
    batch_ixs = detections[:, cf.dim*2]
    detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])]
    mrcnn_mask = [detection_masks[batch_ixs == ix] for ix in range(img_shape[0])]

    # for test_forward, where no previous list exists.
    if box_results_list is None:
        box_results_list = [[] for _ in range(img_shape[0])]

    seg_preds = []
    # loop over batch and unmold detections.
    for ix in range(img_shape[0]):

        if 0 not in detections[ix].shape:
            boxes = detections[ix][:, :2 * cf.dim].astype(np.int32)
            class_ids = detections[ix][:, 2 * cf.dim + 1].astype(np.int32)
            scores = detections[ix][:, 2 * cf.dim + 2]
            masks = mrcnn_mask[ix][np.arange(boxes.shape[0]), ..., class_ids]

            # Filter out detections with zero area. Often only happens in early
            # stages of training when the network weights are still a bit random.
            if cf.dim == 2:
                exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0]
            else:
                exclude_ix = np.where(
                    (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0]

            if exclude_ix.shape[0] > 0:
                boxes = np.delete(boxes, exclude_ix, axis=0)
                class_ids = np.delete(class_ids, exclude_ix, axis=0)
                scores = np.delete(scores, exclude_ix, axis=0)
                masks = np.delete(masks, exclude_ix, axis=0)

            # Resize masks to original image size and set boundary threshold.
            full_masks = []
            permuted_image_shape = list(img_shape[2:]) + [img_shape[1]]
            if return_masks:
                for i in range(masks.shape[0]):
                    # Convert neural network mask to full size mask.
                    full_masks.append(mutils.unmold_mask_2D(masks[i], boxes[i], permuted_image_shape)
                                      if cf.dim == 2 else mutils.unmold_mask_3D(masks[i], boxes[i], permuted_image_shape))
            # if masks are returned, take max over binary full masks of all predictions in this image.
            # right now only binary masks for plotting/monitoring. for instance segmentation return all proposal masks.
            final_masks = np.max(np.array(full_masks), 0) if len(full_masks) > 0 else np.zeros(
                (*permuted_image_shape[:-1],))

            # add final predictions to results.
            if 0 not in boxes.shape:
                for ix2, score in enumerate(scores):
                    box_results_list[ix].append({'box_coords': boxes[ix2], 'box_score': score,
                                                 'box_type': 'det', 'box_pred_class_id': class_ids[ix2]})
        else:
            # pad with zero dummy masks.
            final_masks = np.zeros(img_shape[2:])

        seg_preds.append(final_masks)

    # create and fill results dictionary.
    results_dict = {'boxes': box_results_list,
                    'seg_preds': np.round(np.array(seg_preds))[:, np.newaxis].astype('uint8')}

    return results_dict
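
# Shape sketch for get_results (illustrative, for a 2D batch of two images):
# results_dict['boxes'] is a list of two lists of dicts such as
# {'box_coords': array([y1, x1, y2, x2]), 'box_score': 0.9, 'box_type': 'det',
#  'box_pred_class_id': 1}, and results_dict['seg_preds'] has shape (2, 1, y, x)
# with uint8 values in [0, 1].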


############################################################
#  Mask R-CNN Class
############################################################


class net(nn.Module):


    def __init__(self, cf, logger):

        super(net, self).__init__()
        self.cf = cf
        self.logger = logger
        self.build()

        if self.cf.weight_init is not None:
            logger.info("using pytorch weight init of type {}".format(self.cf.weight_init))
            mutils.initialize_weights(self)
        else:
            logger.info("using default pytorch weight init")

    def build(self):
        """Build Mask R-CNN architecture."""

        # Image size must be divisible by 2 multiple times.
        h, w = self.cf.patch_size[:2]
        if h / 2**5 != int(h / 2**5) or w / 2**5 != int(w / 2**5):
            raise Exception("Image size must be divisible by 2 at least 5 times "
                            "to avoid fractions when downscaling and upscaling. "
                            "For example, use 256, 320, 384, 448, 512, ... etc. ")
        if len(self.cf.patch_size) == 3:
            d = self.cf.patch_size[2]
            if d / 2**3 != int(d / 2**3):
                raise Exception("Image z dimension must be divisible by 2 at least 3 times "
                                "to avoid fractions when downscaling and upscaling.")

        # instantiate abstract multi-dimensional conv class and backbone class.
        conv = mutils.NDConvGenerator(self.cf.dim)
        backbone = utils.import_module('bbone', self.cf.backbone_path)

        # build Anchors, FPN, RPN, Classifier / Bbox-Regressor -head, Mask-head
        self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf)
        self.anchors = torch.from_numpy(self.np_anchors).float().cuda()
        self.fpn = backbone.FPN(self.cf, conv)
        self.rpn = RPN(self.cf, conv)
        self.classifier = Classifier(self.cf, conv)
        self.mask = Mask(self.cf, conv)
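        # resulting two-stage graph (cf. forward below): FPN backbone -> shared RPN over
        # the selected pyramid levels -> proposal refinement/NMS -> RoiAlign feeding the
        # Classifier/Bbox head and the Mask head.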


    def train_forward(self, batch, is_validation=False):
        """
        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
        for processing, computes losses, and stores outputs in a dictionary.
        :param batch: dictionary containing 'data', 'seg', etc.
        batch['roi_masks']: (b, n(b), 1, h(n), w(n), (z(n))) list like batch['class_target'] but with
        arrays (masks) in place of integers. n == number of rois per this batch element.
        :return: results_dict: dictionary with keys:
        'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
        'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes].
        'monitor_values': dict of values to be monitored.
        """
        img = batch['data']
        if "roi_labels" in batch.keys():
            raise Exception("Key for roi-wise class targets changed in v0.1.0 from 'roi_labels' to 'class_target'.\n"
                            "If you use DKFZ's batchgenerators, please make sure you run version >= 0.20.1.")
        gt_class_ids = batch['class_target']
        gt_boxes = batch['bb_target']
        # axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1)
        # gt_masks = [np.transpose(batch['roi_masks'][ii], axes=axes) for ii in range(len(batch['roi_masks']))]
        # --> now GT masks have c == channels in the last dimension.
        gt_masks = batch['roi_masks']
        img = torch.from_numpy(img).float().cuda()
        batch_rpn_class_loss = torch.FloatTensor([0]).cuda()
        batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda()

        # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
        box_results_list = [[] for _ in range(img.shape[0])]

        # forward passes. 1. general forward pass, where no activations are saved in second stage (for performance
        # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for backprop.
        rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, detection_masks = self.forward(img)
        mrcnn_class_logits, mrcnn_pred_deltas, mrcnn_pred_mask, target_class_ids, mrcnn_target_deltas, target_mask, \
            sample_proposals = self.loss_samples_forward(gt_class_ids, gt_boxes, gt_masks)

        # loop over batch
        for b in range(img.shape[0]):
            if len(gt_boxes[b]) > 0:

                # add gt boxes to output list for monitoring.
                for ix in range(len(gt_boxes[b])):
                    box_results_list[b].append({'box_coords': batch['bb_target'][b][ix],
                                                'box_label': batch['class_target'][b][ix], 'box_type': 'gt'})

                # match gt boxes with anchors to generate targets for RPN losses.
                rpn_match, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b])

                # add positive anchors used for loss to output list for monitoring.
                pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:])
                for p in pos_anchors:
                    box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})

            else:
                rpn_match = np.array([-1] * self.np_anchors.shape[0])
                rpn_target_deltas = np.array([0])

            rpn_match_gpu = torch.from_numpy(rpn_match).cuda()
            rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda()

            # compute RPN losses.
            rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_match_gpu, rpn_class_logits[b], self.cf.shem_poolsize)
            rpn_bbox_loss = compute_rpn_bbox_loss(rpn_target_deltas, rpn_pred_deltas[b], rpn_match_gpu)
            batch_rpn_class_loss += rpn_class_loss / img.shape[0]
            batch_rpn_bbox_loss += rpn_bbox_loss / img.shape[0]

            # add negative anchors used for loss to output list for monitoring.
            neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[rpn_match == -1][neg_anchor_ix], img.shape[2:])
            for n in neg_anchors:
                box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})

            # add highest scoring proposals to output list for monitoring.
            rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1]
            for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]:
                box_results_list[b].append({'box_coords': r, 'box_type': 'prop'})

        # add positive and negative roi samples used for mrcnn losses to output list for monitoring.
        if 0 not in sample_proposals.shape:
            rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy()
            for ix, r in enumerate(rois):
                box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale,
                                                     'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'})

        # compute mrcnn losses.
        mrcnn_class_loss = compute_mrcnn_class_loss(target_class_ids, mrcnn_class_logits)
        mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_target_deltas, mrcnn_pred_deltas, target_class_ids)

        # mrcnn can be run without pixelwise annotations available (Faster R-CNN mode).
        # In this case, the mask_loss is taken out of training.
        if not self.cf.frcnn_mode:
            mrcnn_mask_loss = compute_mrcnn_mask_loss(target_mask, mrcnn_pred_mask, target_class_ids)
        else:
            mrcnn_mask_loss = torch.FloatTensor([0]).cuda()

        loss = batch_rpn_class_loss + batch_rpn_bbox_loss + mrcnn_class_loss + mrcnn_bbox_loss + mrcnn_mask_loss

        # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class.
        dcount = [list(target_class_ids.cpu().data.numpy()).count(c) for c in np.arange(self.cf.head_classes)[1:]]

        # run unmolding of predictions for monitoring and merge all results to one dictionary.
        return_masks = True  # self.cf.return_masks_in_val if is_validation else False
        results_dict = get_results(self.cf, img.shape, detections, detection_masks,
                                   box_results_list, return_masks=return_masks)

        results_dict['torch_loss'] = loss
        results_dict['monitor_values'] = {'loss': loss.item(), 'class_loss': mrcnn_class_loss.item()}

        results_dict['logger_string'] = \
            "loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, mrcnn_bbox: {4:.2f}, " \
            "mrcnn_mask: {5:.2f}, dcount {6}".format(loss.item(), batch_rpn_class_loss.item(),
                                                     batch_rpn_bbox_loss.item(), mrcnn_class_loss.item(),
                                                     mrcnn_bbox_loss.item(), mrcnn_mask_loss.item(), dcount)

        return results_dict
    def test_forward(self, batch, return_masks=True):
        """
        Test method. Wraps the forward pass of the network without using any ground truth information.
        Prepares input data for processing and stores outputs in a dictionary.
        :param batch: dictionary containing 'data'
        :param return_masks: boolean. If True, full resolution masks are returned for all proposals (speed trade-off).
        :return: results_dict: dictionary with keys:
                 'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                          [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                 'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
        """
        img = batch['data']
        img = torch.from_numpy(img).float().cuda()
        _, _, _, detections, detection_masks = self.forward(img)
        results_dict = get_results(self.cf, img.shape, detections, detection_masks, return_masks=return_masks)
        return results_dict
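
    # A minimal inference sketch, assuming a configured and initialized `net` instance and a
    # numpy batch from the data loader (`net` and `batch_array` are illustrative names, not
    # part of this file):
    #
    #     with torch.no_grad():
    #         results = net.test_forward({'data': batch_array}, return_masks=False)
    #     boxes_elem0 = results['boxes'][0]   # list of box dicts for the first batch element
    #     seg = results['seg_preds']          # (b, 1, y, x, (z)) pixel-wise class predictions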
|
|
    def forward(self, img, is_training=True):
        """
        :param img: input images (b, c, y, x, (z)).
        :param is_training: boolean. If True, cf.post_nms_rois_training proposals are kept after NMS,
                            otherwise cf.post_nms_rois_inference.
        :return: rpn_pred_logits: (b, n_anchors, 2)
        :return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
        :return: batch_proposal_boxes: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting.
        :return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score))
        :return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
        """
        # extract features.
        fpn_outs = self.fpn(img)
        rpn_feature_maps = [fpn_outs[i] for i in self.cf.pyramid_levels]
        self.mrcnn_feature_maps = rpn_feature_maps

        # loop through pyramid layers and apply RPN.
        layer_outputs = []  # list of lists
        for p in rpn_feature_maps:
            layer_outputs.append(self.rpn(p))

        # concatenate layer outputs.
        # convert from list of lists of level outputs to list of lists of outputs across levels.
        # e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]]
        outputs = list(zip(*layer_outputs))
        outputs = [torch.cat(list(o), dim=1) for o in outputs]
        rpn_pred_logits, rpn_pred_probs, rpn_pred_deltas = outputs

        # generate proposals: apply predicted deltas to anchors and filter by foreground scores from RPN classifier.
        proposal_count = self.cf.post_nms_rois_training if is_training else self.cf.post_nms_rois_inference
        batch_rpn_rois, batch_proposal_boxes = refine_proposals(rpn_pred_probs, rpn_pred_deltas, proposal_count, self.anchors, self.cf)

        # merge batch dimension of proposals while storing allocation info in coordinate dimension.
        batch_ixs = torch.from_numpy(np.repeat(np.arange(batch_rpn_rois.shape[0]), batch_rpn_rois.shape[1])).float().cuda()
        rpn_rois = batch_rpn_rois.view(-1, batch_rpn_rois.shape[2])
        self.rpn_rois_batch_info = torch.cat((rpn_rois, batch_ixs.unsqueeze(1)), dim=1)
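        # e.g. for batch_size 2 with 3 proposals each, batch_ixs is [0., 0., 0., 1., 1., 1.],
        # so after the cat every flattened roi row carries the index of the image it belongs to.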
|
|

        # this is the first of two forward passes through the second stage, where no activations are stored for backprop.
        # here, all proposals are forwarded (with virtual_batch_size = batch_size * post_nms_rois)
        # for inference/monitoring as well as for sampling of rois for the loss functions.
        # proposals are processed in chunks of roi_chunk_size to fit into GPU memory.
        chunked_rpn_rois = self.rpn_rois_batch_info.split(self.cf.roi_chunk_size)
        class_logits_list, bboxes_list = [], []
        with torch.no_grad():
            for chunk in chunked_rpn_rois:
                chunk_class_logits, chunk_bboxes = self.classifier(self.mrcnn_feature_maps, chunk)
                class_logits_list.append(chunk_class_logits)
                bboxes_list.append(chunk_bboxes)
        batch_mrcnn_class_logits = torch.cat(class_logits_list, 0)
        batch_mrcnn_bbox = torch.cat(bboxes_list, 0)
        self.batch_mrcnn_class_scores = F.softmax(batch_mrcnn_class_logits, dim=1)
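        # e.g. 2 images * 1000 post-nms rois with roi_chunk_size = 600 yields chunks of sizes
        # (600, 600, 600, 200); torch.Tensor.split keeps the remainder as a smaller final chunk.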
|
|

        # refine classified proposals, filter and return final detections.
        detections = refine_detections(self.cf, batch_ixs, rpn_rois, batch_mrcnn_bbox, self.batch_mrcnn_class_scores)

        # forward remaining detections through mask-head to generate corresponding masks.
        scale = [img.shape[2]] * 4 + [img.shape[-1]] * 2
        scale = torch.from_numpy(np.array(scale[:self.cf.dim * 2] + [1])[None]).float().cuda()

        detection_boxes = detections[:, :self.cf.dim * 2 + 1] / scale
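        # e.g. for 2D input of shape (b, c, 256, 256): scale = [[256, 256, 256, 256, 1]], which maps
        # the (y1, x1, y2, x2) coordinates to [0, 1] while leaving the trailing batch index untouched.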
|
|
        with torch.no_grad():
            detection_masks = self.mask(self.mrcnn_feature_maps, detection_boxes)

        return [rpn_pred_logits, rpn_pred_deltas, batch_proposal_boxes, detections, detection_masks]
|
|

    def loss_samples_forward(self, batch_gt_class_ids, batch_gt_boxes, batch_gt_masks):
        """
        This is the second forward pass through the second stage (features from stage one are re-used).
        Samples a few rois in detection_target_layer and forwards only those for loss computation.
        :param batch_gt_class_ids: list over batch elements. Each element is a list over the corresponding roi target labels.
        :param batch_gt_boxes: list over batch elements. Each element is a list over the corresponding roi target coordinates.
        :param batch_gt_masks: list over batch elements. Each element is a binary mask of shape (n_gt_rois, y, x, (z), c).
        :return: sample_logits: (n_sampled_rois, n_classes) predicted class scores.
        :return: sample_boxes: (n_sampled_rois, n_classes, 2 * dim) predicted corrections to be applied to proposals for refinement.
        :return: sample_mask: (n_sampled_rois, n_classes, y, x, (z)) predicted masks per class and proposal.
        :return: sample_target_class_ids: (n_sampled_rois) target class labels of sampled proposals.
        :return: sample_target_deltas: (n_sampled_rois, 2 * dim) target deltas of sampled proposals for box refinement.
        :return: sample_target_masks: (n_sampled_rois, y, x, (z)) target masks of sampled proposals.
        :return: sample_proposals: (n_sampled_rois, 2 * dim) RPN output for sampled proposals. only for monitoring/plotting.
        """
        # sample rois for loss and get corresponding targets for all Mask R-CNN head network losses.
        sample_ix, sample_target_class_ids, sample_target_deltas, sample_target_mask = \
            detection_target_layer(self.rpn_rois_batch_info, self.batch_mrcnn_class_scores,
                                   batch_gt_class_ids, batch_gt_boxes, batch_gt_masks, self.cf)

        # re-use feature maps and RPN output from first forward pass.
        sample_proposals = self.rpn_rois_batch_info[sample_ix]
        if 0 not in sample_proposals.size():
            sample_logits, sample_boxes = self.classifier(self.mrcnn_feature_maps, sample_proposals)
            sample_mask = self.mask(self.mrcnn_feature_maps, sample_proposals)
        else:
            # no rois were sampled for this batch: return empty tensors.
            sample_logits = torch.FloatTensor().cuda()
            sample_boxes = torch.FloatTensor().cuda()
            sample_mask = torch.FloatTensor().cuda()

        return [sample_logits, sample_boxes, sample_mask, sample_target_class_ids, sample_target_deltas,
                sample_target_mask, sample_proposals]
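
    # A minimal sketch of how the two second-stage passes fit together during training
    # (an assumption based on train_forward above; `net`, `img`, and the gt_* names are illustrative):
    #
    #     _, _, proposal_boxes, detections, detection_masks = net.forward(img)
    #     mrcnn_class_logits, mrcnn_pred_deltas, mrcnn_pred_mask, target_class_ids, \
    #         mrcnn_target_deltas, target_mask, sample_proposals = \
    #         net.loss_samples_forward(gt_class_ids, gt_boxes, gt_masks)
    #     # these outputs feed compute_mrcnn_class_loss / compute_mrcnn_bbox_loss /
    #     # compute_mrcnn_mask_loss as shown in train_forward above.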