|
a |
|
b/funcs.py |
|
|
1 |
from skimage.measure import label |
|
|
2 |
#Scientific computing |
|
|
3 |
import numpy as np |
|
|
4 |
import os |
|
|
5 |
import matplotlib.pyplot as plt |
|
|
6 |
import torch.nn.functional as F |
|
|
7 |
from torch.nn.functional import one_hot |
|
|
8 |
import cv2 |
|
|
9 |
import torch |
|
|
10 |
import random |
|
|
11 |
#Pytorch packages |
|
|
12 |
|
|
|
13 |
def random_sum_to(n, num_terms = None): |
|
|
14 |
''' |
|
|
15 |
generate num_tersm with sum as n |
|
|
16 |
''' |
|
|
17 |
num_terms = (num_terms or r.randint(2, n)) - 1 |
|
|
18 |
a = random.sample(range(1, n), num_terms) + [0, n] |
|
|
19 |
list.sort(a) |
|
|
20 |
return [a[i+1] - a[i] for i in range(len(a) - 1)] |
|
|
21 |
|
|
|
22 |
|
|
|
23 |
|
|
|
24 |
def get_first_prompt(mask_cls,dist_thre_ratio=0.3,prompt_num=5,max_prompt_num=15,region_type='random'): |
|
|
25 |
''' |
|
|
26 |
if region_type = random, we random select one region and generate prompt |
|
|
27 |
if region_type = all, we generate prompt at each object region |
|
|
28 |
if region_type = largest_k, we generate prompt at largest k region, k <10 |
|
|
29 |
''' |
|
|
30 |
if prompt_num==-1: |
|
|
31 |
prompt_num = random.randint(1, max_prompt_num) |
|
|
32 |
# Find all disconnected regions |
|
|
33 |
label_msk, region_ids = label(mask_cls, connectivity=2, return_num=True) |
|
|
34 |
#print('num of regions found', region_ids) |
|
|
35 |
ratio_list, regionid_list = [], [] |
|
|
36 |
for region_id in range(1, region_ids+1): |
|
|
37 |
#find coordinates of points in the region |
|
|
38 |
binary_msk = np.where(label_msk==region_id, 1, 0) |
|
|
39 |
|
|
|
40 |
# clean some region that is abnormally small |
|
|
41 |
r = np.sum(binary_msk) / np.sum(mask_cls) |
|
|
42 |
#print('curr mask over all mask ratio', r) |
|
|
43 |
ratio_list.append(r) |
|
|
44 |
regionid_list.append(region_id) |
|
|
45 |
if len(ratio_list)>0: |
|
|
46 |
ratio_list, regionid_list = zip(*sorted(zip(ratio_list, regionid_list))) |
|
|
47 |
regionid_list = regionid_list[::-1] |
|
|
48 |
|
|
|
49 |
if region_type == 'random': |
|
|
50 |
prompt_num = 1 |
|
|
51 |
regionid_list = [random.choice(regionid_list)] # random choose 1 region |
|
|
52 |
prompt_num_each_region = [1] |
|
|
53 |
elif region_type[:7] == 'largest': |
|
|
54 |
region_max_num = int(region_type.split('_')[-1]) |
|
|
55 |
#print(region_max_num,prompt_num,len(regionid_list)) |
|
|
56 |
valid_region = min(region_max_num,len(regionid_list)) |
|
|
57 |
if valid_region<prompt_num: |
|
|
58 |
prompt_num_each_region = random_sum_to(prompt_num,valid_region) |
|
|
59 |
else: |
|
|
60 |
prompt_num_each_region = prompt_num*[1] |
|
|
61 |
regionid_list = regionid_list[:min(valid_region,prompt_num)] |
|
|
62 |
#print(prompt_num_each_region) |
|
|
63 |
else: |
|
|
64 |
prompt_num_each_region = len(regionid_list)*[1] |
|
|
65 |
|
|
|
66 |
|
|
|
67 |
prompt = [] |
|
|
68 |
mask_curr = np.zeros_like(label_msk) |
|
|
69 |
|
|
|
70 |
|
|
|
71 |
for reg_id in range(len(regionid_list)): |
|
|
72 |
binary_msk = np.where(label_msk==regionid_list[reg_id], 1, 0) |
|
|
73 |
mask_curr = np.logical_or(binary_msk,mask_curr) |
|
|
74 |
|
|
|
75 |
|
|
|
76 |
padded_mask = np.uint8(np.pad(binary_msk, ((1, 1), (1, 1)), 'constant')) |
|
|
77 |
dist_img = cv2.distanceTransform(padded_mask, distanceType=cv2.DIST_L2, maskSize=5).astype(np.float32)[1:-1, 1:-1] |
|
|
78 |
|
|
|
79 |
# sort the distances |
|
|
80 |
dist_array=sorted(dist_img.copy().flatten())[::-1] |
|
|
81 |
dist_array = np.array(dist_array) |
|
|
82 |
# find the threshold: |
|
|
83 |
dis_thre = max(dist_array[int(dist_thre_ratio*np.sum(dist_array>0))],1) |
|
|
84 |
#print(np.max(dist_array)) |
|
|
85 |
#print(dis_thre) |
|
|
86 |
cY, cX = np.where(dist_img>=dis_thre) |
|
|
87 |
while prompt_num_each_region[reg_id]>0: |
|
|
88 |
# random select one prompt |
|
|
89 |
random_idx = np.random.randint(0, len(cX)) |
|
|
90 |
cx, cy = int(cX[random_idx]), int(cY[random_idx]) |
|
|
91 |
prompt.append((cx,cy,1)) |
|
|
92 |
prompt_num_each_region[reg_id] -=1 |
|
|
93 |
|
|
|
94 |
while len(prompt)<max_prompt_num: # repeat prompt to ensure the same size |
|
|
95 |
prompt.append((cx,cy,1)) |
|
|
96 |
else: # if this image doesn't have target object |
|
|
97 |
prompt = [(0,0,-1)] |
|
|
98 |
mask_curr = np.zeros_like(label_msk) |
|
|
99 |
while len(prompt)<max_prompt_num: # repeat prompt to ensure the same size |
|
|
100 |
prompt.append((0,0,-1)) |
|
|
101 |
prompt = np.array(prompt) |
|
|
102 |
mask_curr = np.array(mask_curr,dtype=int) |
|
|
103 |
return prompt,mask_curr |
|
|
104 |
|
|
|
105 |
|
|
|
106 |
def get_top_boxes(mask_cls,dist_thre_ratio=0.10,prompt_num=15,region_type='largest_15'): |
|
|
107 |
# Find all disconnected regions |
|
|
108 |
label_msk, region_ids = label(mask_cls, connectivity=2, return_num=True) |
|
|
109 |
#print('num of regions found', region_ids) |
|
|
110 |
ratio_list, regionid_list = [], [] |
|
|
111 |
for region_id in range(1, region_ids+1): |
|
|
112 |
#find coordinates of points in the region |
|
|
113 |
binary_msk = np.where(label_msk==region_id, 1, 0) |
|
|
114 |
|
|
|
115 |
# clean some region that is abnormally small |
|
|
116 |
r = np.sum(binary_msk) / np.sum(mask_cls) |
|
|
117 |
#print('curr mask over all mask ratio', r) |
|
|
118 |
ratio_list.append(r) |
|
|
119 |
regionid_list.append(region_id) |
|
|
120 |
if len(ratio_list)>0: |
|
|
121 |
# sort the region from largest to smallest |
|
|
122 |
ratio_list, regionid_list = zip(*sorted(zip(ratio_list, regionid_list))) |
|
|
123 |
regionid_list = regionid_list[::-1] |
|
|
124 |
|
|
|
125 |
if region_type == 'random': |
|
|
126 |
prompt_num = 1 |
|
|
127 |
regionid_list = [random.choice(regionid_list)] # random choose 1 region |
|
|
128 |
elif region_type[:7] == 'largest': |
|
|
129 |
region_max_num = int(region_type.split('_')[-1]) |
|
|
130 |
regionid_list = regionid_list[:min(region_max_num,len(regionid_list))] |
|
|
131 |
|
|
|
132 |
prompt = [] |
|
|
133 |
mask_curr = np.zeros_like(label_msk) |
|
|
134 |
for reg_id in range(len(regionid_list)): |
|
|
135 |
binary_msk = np.where(label_msk==regionid_list[reg_id], 1, 0) |
|
|
136 |
mask_curr = np.logical_or(binary_msk,mask_curr) |
|
|
137 |
box = MaskToBoxSimple(binary_msk,dist_thre_ratio) |
|
|
138 |
prompt.append(box) |
|
|
139 |
|
|
|
140 |
while len(prompt)<prompt_num: # repeat prompt to ensure the same size |
|
|
141 |
prompt.append(box) |
|
|
142 |
prompt = np.array(prompt) |
|
|
143 |
mask_curr = np.array(mask_curr,dtype=int) |
|
|
144 |
else: |
|
|
145 |
prompt = [[0,0,0,0]] |
|
|
146 |
mask_curr = np.zeros_like(label_msk) |
|
|
147 |
while len(prompt)<prompt_num: |
|
|
148 |
prompt.append(prompt[0]) |
|
|
149 |
return prompt,mask_curr |
|
|
150 |
|
|
|
151 |
def MaskToBoxSimple(mask,random_thre=0.05): |
|
|
152 |
''' |
|
|
153 |
random_thre, the randomness at each side of box |
|
|
154 |
''' |
|
|
155 |
mask = mask.squeeze() |
|
|
156 |
|
|
|
157 |
y_max,x_max = mask.shape[0],mask.shape[1] |
|
|
158 |
|
|
|
159 |
#find coordinates of points in the region |
|
|
160 |
row, col = np.argwhere(mask).T |
|
|
161 |
# find the four corner coordinates |
|
|
162 |
y0,x0 = row.min(),col.min() |
|
|
163 |
y1,x1 = row.max(),col.max() |
|
|
164 |
|
|
|
165 |
y_thre = (y1-y0)*random_thre |
|
|
166 |
x_thre = (x1-x0)*random_thre |
|
|
167 |
|
|
|
168 |
x0 = max(0,x0-x_thre*random.random()) |
|
|
169 |
x1 = min(x_max,x1+x_thre*random.random()) |
|
|
170 |
|
|
|
171 |
y0 = max(0,y0-y_thre*random.random()) |
|
|
172 |
y1 = min(y_max,y1+y_thre*random.random()) |
|
|
173 |
|
|
|
174 |
|
|
|
175 |
return [x0,y0,x1,y1] |
|
|
176 |
|
|
|
177 |
def min_max_normalize(tensor,p=0.01): |
|
|
178 |
p_min = torch.quantile(tensor,p) |
|
|
179 |
p_max = torch.quantile(tensor,1-p) |
|
|
180 |
tensor = torch.clamp(tensor,p_min,p_max) |
|
|
181 |
return tensor |
|
|
182 |
|