Switch to unified view

a b/utils/direct_field/df_cardia.py
1
import numpy as np
2
from scipy import ndimage
3
import math
4
import cv2
5
from PIL import Image
6
7
def direct_field(a, norm=True):
8
    """ a: np.ndarray, (h, w)
9
    """
10
    if a.ndim == 3:
11
        a = np.squeeze(a)
12
13
    h, w = a.shape
14
15
    a_Image = Image.fromarray(a)
16
    a = a_Image.resize((w, h), Image.NEAREST)
17
    a = np.array(a)
18
    
19
    accumulation = np.zeros((2, h, w), dtype=np.float32)
20
    for i in np.unique(a)[1:]:
21
        # b, ind = ndimage.distance_transform_edt(a==i, return_indices=True)
22
        # c = np.indices((h, w))
23
        # diff = c - ind
24
        # dr = np.sqrt(np.sum(diff ** 2, axis=0))
25
26
        img = (a == i).astype(np.uint8)
27
        dst, labels = cv2.distanceTransformWithLabels(img, cv2.DIST_L2, cv2.DIST_MASK_PRECISE, labelType=cv2.DIST_LABEL_PIXEL)
28
        index = np.copy(labels)
29
        index[img > 0] = 0
30
        place = np.argwhere(index > 0)
31
        nearCord = place[labels-1,:]
32
        x = nearCord[:, :, 0]
33
        y = nearCord[:, :, 1]
34
        nearPixel = np.zeros((2, h, w))
35
        nearPixel[0,:,:] = x
36
        nearPixel[1,:,:] = y
37
        grid = np.indices(img.shape)
38
        grid = grid.astype(float)
39
        diff = grid - nearPixel
40
        if norm:
41
            dr = np.sqrt(np.sum(diff**2, axis = 0))
42
        else:
43
            dr = np.ones_like(img)
44
45
        # direction = np.zeros((2, h, w), dtype=np.float32)
46
        # direction[0, b>0] = np.divide(diff[0, b>0], dr[b>0])
47
        # direction[1, b>0] = np.divide(diff[1, b>0], dr[b>0])
48
49
        direction = np.zeros((2, h, w), dtype=np.float32)
50
        direction[0, img>0] = np.divide(diff[0, img>0], dr[img>0])
51
        direction[1, img>0] = np.divide(diff[1, img>0], dr[img>0])
52
53
        accumulation[:, img>0] = 0
54
        accumulation = accumulation + direction
55
    
56
    # mag, angle = cv2.cartToPolar(accumulation[0, ...], accumulation[1, ...])
57
    # for l in np.unique(a)[1:]:
58
    #     mag_i = mag[a==l].astype(float)
59
    #     t = 1 / mag_i * mag_i.max()
60
    #     mag[a==l] = t
61
    # x, y = cv2.polarToCart(mag, angle)
62
    # accumulation = np.stack([x, y], axis=0)
63
64
    return accumulation
65
66
67
if __name__ == "__main__":
68
    import matplotlib.pyplot as plt
69
    # gt_p = "/home/ffbian/chencheng/XieheCardiac/npydata/dianfen/16100000/gts/16100000_CINE_segmented_SAX_b3.npy"
70
    # gt = np.load(gt_p)[..., 9]  # uint8
71
    # print(gt.shape)
72
73
    # a_Image = Image.fromarray(gt)
74
    # a = a_Image.resize((224, 224), Image.NEAREST)
75
    # a = np.array(a)  # uint8
76
    # print(a.shape, np.unique(a))
77
78
    # # plt.imshow(a)
79
    # # plt.show()
80
81
    # ############################################################
82
    # direction = direct_field(gt)
83
    
84
    # theta = np.arctan2(direction[1,...], direction[0,...])
85
    # degree = theta * 180 / math.pi
86
    # degree = (degree + 180) / 360
87
88
    # plt.imshow(degree)
89
    # plt.show()
90
91
    ########################################################
92
    import json, time, pdb, h5py
93
    data_list = "/home/ffbian/chencheng/XieheCardiac/2DUNet/UNet/libs/datasets/train_new.json"
94
    data_list = "/root/chengfeng/Cardiac/source_code/libs/datasets/jsonLists/acdcList/Dense_TestList.json"
95
    with open(data_list, 'r') as f:
96
        data_infos = json.load(f)
97
    
98
    mag_stat = []
99
    st = time.time()
100
    for i, di in enumerate(data_infos):
101
        # img_p, times_idx = di
102
        # gt_p = img_p.replace("/imgs/", "/gts/")
103
        # gt = np.load(gt_p)[..., times_idx]
104
        
105
        img = h5py.File(di,'r')['image']
106
        gt = h5py.File(di,'r')['label']
107
        gt = np.array(gt).astype(np.float32)
108
109
        print(gt.shape)
110
        direction = direct_field(gt, False)
111
        # theta = np.arctan2(direction[1,...], direction[0,...])
112
        mag, angle = cv2.cartToPolar(direction[0, ...], direction[1, ...])
113
        # degree = theta * 180 / math.pi
114
        # degree = (degree + 180) / 360
115
        degree = angle / (2 * math.pi) * 255
116
        # degree = (theta - theta.min()) / (theta.max() - theta.min()) * 255
117
        # mag = np.sqrt(np.sum(direction ** 2, axis=0, keepdims=False))
118
        
119
120
        # 归一化
121
        # for l in np.unique(gt)[1:]:
122
        #     mag_i = mag[gt==l].astype(float)
123
        #     # mag[gt==l] = 1. - mag[gt==l] / np.max(mag[gt==l])
124
        #     t = (mag_i - mag_i.min()) / (mag_i.max() - mag_i.min())
125
        #     mag[gt==l] = np.exp(-10*t)
126
        #     print(mag_i.max(), mag_i.min())
127
128
        # for l in np.unique(gt)[1:]:
129
        #     mag_i = mag[gt==l].astype(float)
130
        #     t = 1 / (mag_i) * mag_i.max()
131
        #     # t = np.exp(-0.8*mag_i) * mag_i.max()
132
        #     # t = 1 / np.sqrt(mag_i+1) * mag_i.max()
133
        #     mag[gt==l] = t
134
        #     # print(mag_i.max(), mag_i.min())
135
136
        # mag[mag>0] = 2 * np.exp(-0.8*(mag[mag>0]-1))
137
        # mag[mag>0] = 2 * np.exp(0.8*(mag[mag>0]-1))
138
139
140
        mag_stat.append(mag.max())
141
        # pdb.set_trace()
142
143
        # plt.imshow(degree)
144
        # plt.show()
145
146
        ######################
147
        fig, axs = plt.subplots(1, 3)
148
        axs[0].imshow(degree)
149
        axs[1].imshow(gt)
150
        axs[2].imshow(mag)
151
        plt.show()
152
153
        ######################
154
        if i % 100 == 0:
155
            print("\r\r{}/{}  {:.4}s".format(i+1, len(data_infos), time.time()-st))
156
    print()
157
158
    print("total time: ", time.time()-st)
159
    print("Average time: ", (time.time()-st) / len(data_infos))
160
    # total time:  865.811030626297
161
    # Average time:  0.012969593759126428
162
163
    plt.hist(mag_stat)
164
    plt.show()
165
    print(mag_stat)