Diff of /funcs.py [000000] .. [dff9e0]

Switch to unified view

a b/funcs.py
1
from skimage.measure import label
2
#Scientific computing 
3
import numpy as np
4
import os
5
import matplotlib.pyplot as plt
6
import torch.nn.functional as F
7
from torch.nn.functional import one_hot
8
import cv2
9
import torch
10
import random
11
#Pytorch packages
12
13
def random_sum_to(n, num_terms = None):
    '''
    Split the integer n into random positive integers that sum to n.

    n: total to split.
    num_terms: number of terms to return; when omitted (None or 0) the
        count is drawn uniformly from 2..n.

    Returns a list of num_terms positive ints whose sum is n.

    Raises ValueError (from the random module) when num_terms > n, or when
    num_terms is omitted and n < 2.
    '''
    # k terms are delimited by k - 1 interior cut points
    num_cuts = (num_terms or random.randint(2, n)) - 1
    # distinct cut points in 1..n-1, bracketed by 0 and n, guarantee that
    # every gap between neighbouring cuts is at least 1
    cuts = sorted(random.sample(range(1, n), num_cuts) + [0, n])
    return [upper - lower for lower, upper in zip(cuts, cuts[1:])]
21
22
23
24
def get_first_prompt(mask_cls,dist_thre_ratio=0.3,prompt_num=5,max_prompt_num=15,region_type='random'):
    '''
    Sample positive point prompts from the connected regions of a mask.

    if region_type = random, one region is picked at random and gets a single prompt
    if region_type = largest_k, prompt_num prompts are spread over the k largest regions
    any other value (e.g. all) places one prompt in every region

    mask_cls: 2-D mask; regions are found with skimage.measure.label
        (connectivity=2), which treats 0 as background by default.
    dist_thre_ratio: controls how deep inside a region points are sampled.
        The region's distance-to-boundary values are sorted in descending
        order and the value at position dist_thre_ratio * (number of region
        pixels) becomes the threshold (never below 1); points are drawn from
        pixels at or above it.
    prompt_num: number of prompts for the largest_k mode; -1 draws a random
        count from 1..max_prompt_num.
    max_prompt_num: the prompt list is padded up to this length by
        repeating the last sampled prompt (it is never truncated).

    Returns (prompt, mask_curr):
        prompt: array of (x, y, label) rows, x being the column index and y
            the row index. label is 1 for sampled points; when the mask has
            no region every row is (0, 0, -1).
        mask_curr: int array shaped like the mask, 1 on the regions that
            received prompts.

    Raises ValueError when regions exist but the arguments select none of
    them (e.g. prompt_num=0), so there is no prompt to pad with.
    '''
    if prompt_num == -1:
        prompt_num = random.randint(1, max_prompt_num)

    # Find all disconnected regions
    label_msk, region_count = label(mask_cls, connectivity=2, return_num=True)

    if region_count == 0:  # this image doesn't have the target object
        prompt = [(0, 0, -1)] * max(max_prompt_num, 1)
        mask_curr = np.zeros_like(label_msk)
        return np.array(prompt), np.array(mask_curr, dtype=int)

    # Pixel count of every label in a single pass; index 0 is the background.
    region_sizes = np.bincount(label_msk.ravel(), minlength=region_count + 1)
    # Largest region first; equal sizes are ordered by descending label id.
    ranked_ids = sorted(range(1, region_count + 1),
                        key=lambda rid: (region_sizes[rid], rid),
                        reverse=True)

    if region_type == 'random':
        selected_ids = [random.choice(ranked_ids)]  # random choose 1 region
        prompt_num_each_region = [1]
    elif region_type[:7] == 'largest':
        region_max_num = int(region_type.split('_')[-1])
        valid_region = min(region_max_num, len(ranked_ids))
        if valid_region < prompt_num:
            # more prompts than regions: split them randomly, >= 1 per region
            prompt_num_each_region = random_sum_to(prompt_num, valid_region)
        else:
            prompt_num_each_region = prompt_num * [1]
        selected_ids = ranked_ids[:min(valid_region, prompt_num)]
    else:
        selected_ids = ranked_ids
        prompt_num_each_region = len(ranked_ids) * [1]

    prompt = []
    mask_curr = np.zeros_like(label_msk)

    for region_id, points_wanted in zip(selected_ids, prompt_num_each_region):
        binary_msk = np.where(label_msk == region_id, 1, 0)
        mask_curr = np.logical_or(binary_msk, mask_curr)

        # Pad with a background border so pixels on the image edge also get
        # a finite distance, then crop the border away again.
        padded_mask = np.uint8(np.pad(binary_msk, ((1, 1), (1, 1)), 'constant'))
        dist_img = cv2.distanceTransform(padded_mask, distanceType=cv2.DIST_L2, maskSize=5).astype(np.float32)[1:-1, 1:-1]

        # distances in descending order
        dist_desc = np.sort(dist_img, axis=None)[::-1]
        # find the threshold; the index is clamped so a region that fills
        # the whole image with dist_thre_ratio=1 cannot run past the end
        thre_idx = min(int(dist_thre_ratio * np.count_nonzero(dist_desc > 0)),
                       dist_desc.size - 1)
        dis_thre = max(dist_desc[thre_idx], 1)

        cY, cX = np.where(dist_img >= dis_thre)
        for _ in range(points_wanted):
            # random select one prompt
            random_idx = np.random.randint(0, len(cX))
            prompt.append((int(cX[random_idx]), int(cY[random_idx]), 1))

    if not prompt:
        raise ValueError(
            'no prompt was sampled; check prompt_num and region_type')

    # repeat the last prompt to ensure the same size
    prompt.extend([prompt[-1]] * (max_prompt_num - len(prompt)))

    return np.array(prompt), np.array(mask_curr, dtype=int)
104
105
106
def get_top_boxes(mask_cls,dist_thre_ratio=0.10,prompt_num=15,region_type='largest_15'):
    '''
    Build one box prompt per selected connected region of a mask.

    if region_type = random, a single randomly chosen region is boxed
        (prompt_num is then forced to 1)
    if region_type = largest_k, the k largest regions are boxed
    any other value boxes every region

    mask_cls: mask whose regions are found with skimage.measure.label
        (connectivity=2), which treats 0 as background by default.
    dist_thre_ratio: passed to MaskToBoxSimple as random_thre, i.e. the
        amount of random jitter applied to each side of a box.
    prompt_num: the box list is padded up to this length by repeating the
        last box (it is never truncated).

    Returns (prompt, mask_curr):
        prompt: array of [x0, y0, x1, y1] rows; all-zero rows when the mask
            has no region.
        mask_curr: int array shaped like the mask, 1 on the boxed regions.

    Raises ValueError when regions exist but region_type selects none of
    them (e.g. largest_0), so there is no box to pad with.
    '''
    # Find all disconnected regions
    label_msk, region_count = label(mask_cls, connectivity=2, return_num=True)

    if region_count == 0:
        # No object: return placeholder boxes with the same array types as
        # the regular path so callers need not special-case this branch.
        prompt = np.zeros((max(prompt_num, 1), 4), dtype=float)
        mask_curr = np.array(np.zeros_like(label_msk), dtype=int)
        return prompt, mask_curr

    # Pixel count of every label in a single pass; index 0 is the background.
    region_sizes = np.bincount(label_msk.ravel(), minlength=region_count + 1)
    # sort the regions from largest to smallest (ties: higher label id first)
    ranked_ids = sorted(range(1, region_count + 1),
                        key=lambda rid: (region_sizes[rid], rid),
                        reverse=True)

    if region_type == 'random':
        prompt_num = 1
        selected_ids = [random.choice(ranked_ids)]  # random choose 1 region
    elif region_type[:7] == 'largest':
        region_max_num = int(region_type.split('_')[-1])
        selected_ids = ranked_ids[:min(region_max_num, len(ranked_ids))]
    else:
        selected_ids = ranked_ids

    prompt = []
    mask_curr = np.zeros_like(label_msk)
    for region_id in selected_ids:
        binary_msk = np.where(label_msk == region_id, 1, 0)
        mask_curr = np.logical_or(binary_msk, mask_curr)
        prompt.append(MaskToBoxSimple(binary_msk, dist_thre_ratio))

    if not prompt:
        raise ValueError('no region was selected; check region_type')

    # repeat the last box to ensure the same size
    prompt.extend([prompt[-1]] * (prompt_num - len(prompt)))

    return np.array(prompt), np.array(mask_curr, dtype=int)
150
        
151
def MaskToBoxSimple(mask,random_thre=0.05):
    '''
    Return a randomly enlarged bounding box around the non-zero pixels of a mask.

    mask: array that is 2-D after squeeze(); non-zero entries are the object.
    random_thre: the randomness at each side of the box, as a fraction of
        the tight box's height (top/bottom) or width (left/right). Each side
        moves outward by a uniform random amount in [0, fraction * extent].

    Returns [x0, y0, x1, y1], clamped to 0 below and to the mask's
    width/height above. A mask without any non-zero pixel yields
    [0, 0, 0, 0].
    '''
    mask = mask.squeeze()

    y_max, x_max = mask.shape[0], mask.shape[1]

    # find coordinates of points in the region
    row, col = np.argwhere(mask).T
    if row.size == 0:
        # nothing to enclose: min()/max() of an empty array would raise
        return [0, 0, 0, 0]

    # tight bounding box
    y0, x0 = row.min(), col.min()
    y1, x1 = row.max(), col.max()

    y_thre = (y1 - y0) * random_thre
    x_thre = (x1 - x0) * random_thre

    # push every side outward by a random amount, staying inside the image
    x0 = max(0, x0 - x_thre * random.random())
    x1 = min(x_max, x1 + x_thre * random.random())

    y0 = max(0, y0 - y_thre * random.random())
    y1 = min(y_max, y1 + y_thre * random.random())

    return [x0, y0, x1, y1]
176
177
def min_max_normalize(tensor,p=0.01):
    '''
    Clip a tensor to its [p, 1-p] quantile range.

    Values below the p-quantile are raised to it and values above the
    (1-p)-quantile are lowered to it; everything in between is returned
    unchanged (no rescaling is applied).
    '''
    lower_bound = torch.quantile(tensor, p)
    upper_bound = torch.quantile(tensor, 1 - p)
    return torch.clamp(tensor, min=lower_bound, max=upper_bound)
182