a b/utils/normalizations.py
1
import numpy as np
2
import sys
3
import os
4
#implementation of normalization functions
5
    #In-subject variance normalization
6
    #Global variance normalization
7
    #Histogram normalization (in use as default)
8
9
#------------------------------------------------------------------------------------------------------
10
#------------------------------------------------------------------------------------------------------
11
#Histogram normalization functions
12
13
############################################################
14
#TRAINING
15
def applyHistNormalize(img, postfix, main_folder_path):
    """Histogram-normalise `img` using the parameters stored in extra_data/hist.txt.

    Args:
        img: image to normalise (passed straight to getTransform).
        postfix: unused here; presumably kept for call-signature parity with
            applyGlobalNormalize (see applyNormalize) -- confirm before removing.
        main_folder_path: dataset root containing the `extra_data/` folder.

    Returns:
        Whatever getTransform returns for these parameters.
    """
    hist_info = readHistInfo(os.path.join(main_folder_path, 'extra_data/', 'hist.txt'))
    # hist_info is (pc, s, m_p, mean_m), exactly getTransform's trailing parameters.
    return getTransform(img, *hist_info)
20
21
#returns landmark scores of the image
22
def getLandmarks(img, pc=(1, 99), m_p=tuple(range(10, 100, 10))):
    """Return the percentile landmarks of the bright foreground of `img`.

    Zero-valued and non-finite voxels are discarded first. Of the remaining
    voxels, only those strictly above their mean are kept. Percentiles are
    computed on that subset.

    Args:
        img: array-like image of any shape (flattened internally).
        pc: percentiles giving the lower/upper reference points (p_1, p_2).
        m_p: percentiles giving the intermediate landmarks.

    Returns:
        (p, m): `p` is a tuple of np.percentile values at `pc`, and `m` is a
        tuple of np.percentile values at `m_p`.

    Raises:
        ValueError: if no voxel survives the filtering (e.g. an all-zero image,
            or a foreground whose voxels all share one value).
    """
    values = np.asarray(img).ravel()
    # Drop background (zeros) and NaN/inf, which would otherwise poison the mean.
    values = values[(values != 0) & np.isfinite(values)]
    if values.size == 0:
        raise ValueError('getLandmarks: image has no non-zero finite voxels')

    threshold = np.mean(values)
    values = values[values > threshold]
    if values.size == 0:
        raise ValueError('getLandmarks: no voxel lies above the foreground mean')

    p = tuple(np.percentile(values, pc))
    m = tuple(np.percentile(values, m_p))
    return p, m
32
33
#extract linear map from p to s and map m's
34
#p is [p_1, p_2]
35
#s is [s_1, s_2]
36
#m is the landmark value
37
def mapLandmarksVec(p, s, m):
    """Element-wise linear map of `m` from the segment [p_1, p_2] onto [s_1, s_2].

    Vectorised counterpart of mapLandmarks: p = [p_1, p_2] and s = [s_1, s_2]
    may be scalars or arrays broadcastable against `m`.

    Where a segment is degenerate (p_1 == p_2), the result is s_1. This is the
    same convention as mapLandmarks, and avoids dividing by zero.

    Returns:
        Array of mapped values (numpy scalar if every input is scalar).
    """
    p_1, p_2 = np.asarray(p[0]), np.asarray(p[1])
    s_1, s_2 = np.asarray(s[0]), np.asarray(s[1])
    m = np.asarray(m)

    degenerate = p_1 == p_2
    # Use a dummy width of 1 on degenerate entries so the division is always
    # safe; those entries are replaced with s_1 right after.
    safe_width = np.where(degenerate, 1, p_2 - p_1)
    mapped = ((m - p_1) / safe_width * (s_2 - s_1)) + s_1
    result = np.where(degenerate, s_1, mapped)
    return result[()] if result.ndim == 0 else result
54
55
def mapLandmarks(p, s, m):
    """Linearly map the scalar `m` from the segment [p[0], p[1]] onto [s[0], s[1]].

    A degenerate source segment (p[0] == p[1]) maps everything to s[0].
    """
    lo, hi = p[0], p[1]
    target_lo, target_hi = s[0], s[1]

    if lo != hi:
        fraction = (m - lo) / (hi - lo)
        return (fraction * (target_hi - target_lo)) + target_lo
    return target_lo
64
65
##################################################################
66
def getTransform(img, pc, s, m_p, mean_m):
    """Piecewise-linearly map the intensities of `img` onto a standard scale.

    The source breakpoints are [p_1, landmarks..., p_2], taken from
    getLandmarks(img, pc, m_p). They are mapped onto the target breakpoints
    [s_1, mean_m..., s_2]. Voxels below p_1 or above p_2 are linearly
    extrapolated along the first/last segment, then the result is clipped to
    [s_1 - 1, s_2 + 1].

    Args:
        img: image array.
        pc: percentiles of the reference points p_1, p_2.
        s: (s_1, s_2), the standard-scale values for p_1, p_2.
        m_p: landmark percentiles.
        mean_m: standard-scale values of the landmarks; assumed to have the
            same length as m_p (TODO confirm against the training code).

    Returns:
        Float array of the same shape as `img`.
    """
    p, m = getLandmarks(img, pc, m_p)
    s_1, s_2 = s[0], s[1]

    # Source breakpoints (image intensities) and matching target breakpoints.
    src = np.asarray([p[0]] + list(m) + [p[1]], dtype=float)
    dst = np.asarray([s_1] + list(mean_m) + [s_2], dtype=float)

    # Segment index per voxel = number of breakpoints the voxel exceeds. Clipping
    # to [1, len-1] sends out-of-range voxels to the first/last segment.
    seg = np.zeros_like(img, dtype=np.int64)
    for breakpoint in src:
        seg += (img > breakpoint).astype(np.int64)
    np.clip(seg, 1, len(src) - 1, out=seg)

    # Plain fancy-indexing instead of np.vectorize. np.vectorize infers its
    # output dtype from the first element, so an int landmark looked up first
    # would truncate the float s_1/s_2 entries. It is also a per-voxel Python loop.
    new_img = mapLandmarksVec([src[seg - 1], src[seg]],
                              [dst[seg - 1], dst[seg]],
                              img)

    return np.clip(new_img, s_1 - 1, s_2 + 1)
103
104
##################################################################
105
106
def iterOut(i):
    """Serialise the items of `i` as one space-separated, newline-terminated line."""
    return '{}\n'.format(' '.join(map(str, i)))
108
109
#READ AND WRITE
110
def writeHistInfo(filepath, pc, s, m_p, mean_m):
    """Write histogram-normalisation parameters to `filepath`, one iterable per line.

    Line order: pc, s, m_p (the parameters), then mean_m (the learned result).
    This is the order readHistInfo parses them in.
    """
    rows = (pc, s, m_p, mean_m)
    with open(filepath, 'w+') as out_file:
        out_file.writelines(iterOut(row) for row in rows)
118
119
def readHistInfo(filepath):
    """Read histogram-normalisation parameters in the layout writeHistInfo produces.

    The file holds four whitespace-separated lines: pc, s, m_p, mean_m.

    Returns:
        (pc, s, m_p, mean_m): the first three are tuples of float. mean_m is a
        tuple of int, except that a token which is not an int literal
        (e.g. '12.5') is returned as float instead of raising.
    """
    def _number(token):
        # writeHistInfo writes str(x) of whatever it is given, so tolerate
        # float text here rather than crashing in int().
        try:
            return int(token)
        except ValueError:
            return float(token)

    # Context manager so the file handle is always closed.
    with open(filepath) as hist_file:
        lines = [line.rstrip() for line in hist_file]

    pc = tuple(float(x) for x in lines[0].split())
    s = tuple(float(x) for x in lines[1].split())
    m_p = tuple(float(x) for x in lines[2].split())
    mean_m = tuple(_number(x) for x in lines[3].split())
    return pc, s, m_p, mean_m
129
################################################################
130
131
#------------------------------------------------------------------------------------------------------
132
#------------------------------------------------------------------------------------------------------
133
134
#insubjectvar, globalvar and hist supported so far
135
def applyNormalize(img, postfix,
                   norm_method='hist',
                   main_folder_path='../../Data/MS2017b/'):
    """Normalise `img` with the requested method.

    Args:
        img: image array.
        postfix: file-name postfix forwarded to the 'globalvar' and 'hist' methods.
        norm_method: one of 'insubjectvar', 'globalvar', 'hist' (default).
        main_folder_path: dataset root containing the `extra_data/` folder.

    Returns:
        The normalised image produced by the selected method.

    Raises:
        ValueError: if `norm_method` is not supported. A library function should
            not terminate the whole process with sys.exit().
    """
    if norm_method == 'insubjectvar':
        return applyInSubjectNormalize(img)
    if norm_method == 'globalvar':
        return applyGlobalNormalize(img, postfix, main_folder_path)
    if norm_method == 'hist':
        return applyHistNormalize(img, postfix, main_folder_path)
    raise ValueError(
        "Apply normalize doesn't support other functions currently: "
        "{!r}".format(norm_method))
147
148
def applyGlobalNormalize(img, postfix, main_folder_path='../../Data/MS2017b/'):
    """Standardise `img` with dataset-wide per-pixel mean and std.

    Loads extra_data/mean<postfix>.npy and extra_data/std<postfix>.npy under
    `main_folder_path`, then returns (img - mean) / (std + 1e-6).
    Paths are built with os.path.join, so `main_folder_path` works with or
    without a trailing slash (as in applyHistNormalize).
    """
    extra_dir = os.path.join(main_folder_path, 'extra_data')
    means = np.load(os.path.join(extra_dir, 'mean' + postfix + '.npy'))
    stds = np.load(os.path.join(extra_dir, 'std' + postfix + '.npy'))
    # Small epsilon avoids division by zero where a pixel's std is 0.
    return (img - means) / (stds + 0.000001)
154
155
def applyInSubjectNormalize(img):
    """Standardise `img` using the mean/std of its non-zero voxels only.

    The statistics come from img[img != 0], but the whole image is shifted and
    scaled, so zero voxels become -mean/std in the output.
    """
    foreground = img[img != 0]
    mu, sigma = np.mean(foreground), np.std(foreground)
    return (img - mu) / sigma