Diff of /funcs.py [000000] .. [dff9e0]

Switch to side-by-side view

--- a
+++ b/funcs.py
@@ -0,0 +1,182 @@
+from skimage.measure import label
+#Scientific computing 
+import numpy as np
+import os
+import matplotlib.pyplot as plt
+import torch.nn.functional as F
+from torch.nn.functional import one_hot
+import cv2
+import torch
+import random
+#Pytorch packages
+
+def random_sum_to(n, num_terms = None):
+    '''
+    generate num_tersm with sum as n
+    '''
+    num_terms = (num_terms or r.randint(2, n)) - 1
+    a = random.sample(range(1, n), num_terms) + [0, n]
+    list.sort(a)
+    return [a[i+1] - a[i] for i in range(len(a) - 1)]
+
+
+
+def get_first_prompt(mask_cls,dist_thre_ratio=0.3,prompt_num=5,max_prompt_num=15,region_type='random'):
+    '''
+    if region_type = random, we random select one region and generate prompt
+    if region_type = all, we generate prompt at each object region
+    if region_type = largest_k, we generate prompt at largest k region, k <10
+    '''
+    if prompt_num==-1:
+        prompt_num = random.randint(1, max_prompt_num)
+    # Find all disconnected regions
+    label_msk, region_ids = label(mask_cls, connectivity=2, return_num=True)
+    #print('num of regions found', region_ids)
+    ratio_list, regionid_list = [], []
+    for region_id in range(1, region_ids+1):
+        #find coordinates of points in the region
+        binary_msk = np.where(label_msk==region_id, 1, 0)
+
+        # clean some region that is abnormally small
+        r = np.sum(binary_msk) / np.sum(mask_cls)
+        #print('curr mask over all mask ratio', r)
+        ratio_list.append(r)
+        regionid_list.append(region_id)
+    if len(ratio_list)>0:
+        ratio_list, regionid_list = zip(*sorted(zip(ratio_list, regionid_list)))
+        regionid_list = regionid_list[::-1]
+    
+        if region_type == 'random':
+            prompt_num = 1
+            regionid_list = [random.choice(regionid_list)] # random choose 1 region
+            prompt_num_each_region = [1]
+        elif region_type[:7] == 'largest':
+            region_max_num = int(region_type.split('_')[-1])
+            #print(region_max_num,prompt_num,len(regionid_list))
+            valid_region = min(region_max_num,len(regionid_list))
+            if valid_region<prompt_num:
+                prompt_num_each_region = random_sum_to(prompt_num,valid_region)
+            else:
+                prompt_num_each_region = prompt_num*[1]
+            regionid_list = regionid_list[:min(valid_region,prompt_num)]
+            #print(prompt_num_each_region)
+        else:
+            prompt_num_each_region = len(regionid_list)*[1]
+
+
+        prompt = []
+        mask_curr = np.zeros_like(label_msk)
+        
+
+        for reg_id in range(len(regionid_list)):
+            binary_msk = np.where(label_msk==regionid_list[reg_id], 1, 0)
+            mask_curr = np.logical_or(binary_msk,mask_curr)
+
+
+            padded_mask = np.uint8(np.pad(binary_msk, ((1, 1), (1, 1)), 'constant'))
+            dist_img = cv2.distanceTransform(padded_mask, distanceType=cv2.DIST_L2, maskSize=5).astype(np.float32)[1:-1, 1:-1]
+
+            # sort the distances 
+            dist_array=sorted(dist_img.copy().flatten())[::-1]
+            dist_array = np.array(dist_array)
+            # find the threshold:
+            dis_thre = max(dist_array[int(dist_thre_ratio*np.sum(dist_array>0))],1)
+            #print(np.max(dist_array))
+            #print(dis_thre)
+            cY, cX = np.where(dist_img>=dis_thre)
+            while prompt_num_each_region[reg_id]>0:
+                # random select one prompt
+                random_idx = np.random.randint(0, len(cX))
+                cx, cy = int(cX[random_idx]), int(cY[random_idx])
+                prompt.append((cx,cy,1))
+                prompt_num_each_region[reg_id] -=1
+
+        while len(prompt)<max_prompt_num: # repeat prompt to ensure the same size
+            prompt.append((cx,cy,1))
+    else: # if this image doesn't have target object
+        prompt = [(0,0,-1)]
+        mask_curr = np.zeros_like(label_msk)
+        while len(prompt)<max_prompt_num: # repeat prompt to ensure the same size
+            prompt.append((0,0,-1))
+    prompt = np.array(prompt) 
+    mask_curr = np.array(mask_curr,dtype=int)
+    return prompt,mask_curr
+
+
+def get_top_boxes(mask_cls,dist_thre_ratio=0.10,prompt_num=15,region_type='largest_15'):
+    # Find all disconnected regions
+    label_msk, region_ids = label(mask_cls, connectivity=2, return_num=True)
+    #print('num of regions found', region_ids)
+    ratio_list, regionid_list = [], []
+    for region_id in range(1, region_ids+1):
+        #find coordinates of points in the region
+        binary_msk = np.where(label_msk==region_id, 1, 0)
+
+        # clean some region that is abnormally small
+        r = np.sum(binary_msk) / np.sum(mask_cls)
+        #print('curr mask over all mask ratio', r)
+        ratio_list.append(r)
+        regionid_list.append(region_id)
+    if len(ratio_list)>0:
+        # sort the region from largest to smallest
+        ratio_list, regionid_list = zip(*sorted(zip(ratio_list, regionid_list)))
+        regionid_list = regionid_list[::-1]
+
+        if region_type == 'random':
+            prompt_num = 1
+            regionid_list = [random.choice(regionid_list)] # random choose 1 region
+        elif region_type[:7] == 'largest':
+            region_max_num = int(region_type.split('_')[-1])
+            regionid_list = regionid_list[:min(region_max_num,len(regionid_list))]
+
+        prompt = []
+        mask_curr = np.zeros_like(label_msk)
+        for reg_id in range(len(regionid_list)):
+            binary_msk = np.where(label_msk==regionid_list[reg_id], 1, 0)
+            mask_curr = np.logical_or(binary_msk,mask_curr)
+            box = MaskToBoxSimple(binary_msk,dist_thre_ratio)
+            prompt.append(box)
+
+        while len(prompt)<prompt_num: # repeat prompt to ensure the same size
+            prompt.append(box)
+        prompt = np.array(prompt) 
+        mask_curr = np.array(mask_curr,dtype=int)
+    else:
+        prompt = [[0,0,0,0]]
+        mask_curr = np.zeros_like(label_msk)
+        while len(prompt)<prompt_num:
+            prompt.append(prompt[0])
+    return prompt,mask_curr
+        
+def MaskToBoxSimple(mask,random_thre=0.05):
+    '''
+    random_thre, the randomness at each side of box
+    '''
+    mask = mask.squeeze()
+    
+    y_max,x_max = mask.shape[0],mask.shape[1]
+    
+    #find coordinates of points in the region
+    row, col = np.argwhere(mask).T
+    # find the four corner coordinates
+    y0,x0 = row.min(),col.min()
+    y1,x1 = row.max(),col.max()
+    
+    y_thre = (y1-y0)*random_thre
+    x_thre = (x1-x0)*random_thre
+    
+    x0 = max(0,x0-x_thre*random.random())
+    x1 = min(x_max,x1+x_thre*random.random())
+    
+    y0 = max(0,y0-y_thre*random.random())
+    y1 = min(y_max,y1+y_thre*random.random())
+    
+
+    return [x0,y0,x1,y1]
+
+def min_max_normalize(tensor,p=0.01):
+    p_min = torch.quantile(tensor,p)
+    p_max = torch.quantile(tensor,1-p)
+    tensor = torch.clamp(tensor,p_min,p_max)
+    return tensor
+