a b/GUI.py
1
import sys
2
# Print usage and quit when help is requested (before the heavy imports below).
if len(sys.argv) == 2 and sys.argv[1] in ('-h', '--help'):
    print("python "+sys.argv[0]+" <path to dir of scans>")
    print("python "+sys.argv[0]+" <path to dir of scans> <path to save dir>")
    # sys.exit instead of the site-provided exit(): exit() is intended for the
    # interactive shell and is not available under `python -S`.
    sys.exit(1)
7
8
import matplotlib.pyplot as plt
9
from matplotlib.widgets import Button
10
from procedures.attack_pipeline import *
11
from utils.equalizer import *
12
import matplotlib.animation as animation
13
import time
14
15
class GUI(object):
    """Interactive matplotlib tool for injecting into / removing from CT scans.

    ``load_path`` may be a single ``*.dcm`` or ``*.mhd`` file, in which case only
    that scan is loaded, or a directory, in which case every scan found in it is
    loaded (see ``_load_paths``). ``save_path`` is the directory the tampered
    scans (as DICOM) and ``tamper_coordinates.csv`` are written to; when it is
    None the Save button only prints a notice.

    The constructor builds the figure and then blocks in ``plt.show()`` until
    the window is closed.
    """

    def __init__(self, load_path, save_path=None):
        # init manipulator
        self.savepath = save_path
        self.filepaths = self._load_paths(load_path)  # all scan filepaths found in load_path
        if not self.filepaths:
            raise FileNotFoundError("No .dcm/.mhd scans found in " + str(load_path))
        self.fileindex = 0
        self.manipulator = scan_manipulator()
        self.manipulator.load_target_scan(self.filepaths[self.fileindex])
        self.hist_state = True  # start with histogram equalization enabled
        self.inject_coords = []  # voxel coords (slice, y, x) tampered since the scan was loaded
        self.remove_coords = []

        # init plot
        self.eq = histEq(self.manipulator.scan)
        self.slices, self.cols, self.rows = self.manipulator.scan.shape
        self.ind = self.slices // 2  # index of the displayed slice
        self.pause_start = 0  # animation resumes 1 s after this timestamp (np.inf = paused)
        self.fig, self.ax = plt.subplots(1, 1, dpi=100)
        self.fig.suptitle('CT-GAN: Malicious Tampering of 3D Medical Imagery using Deep Learning\nTool by Yisroel Mirsky', fontsize=14, fontweight='bold')
        plt.subplots_adjust(bottom=0.2)
        self.ani_direction = 'down'
        self.animation = None
        self.animation_state = True
        self.plot()
        self.ax.set_title(self._current_filename())

        # register click/scroll events
        self.action_state = 'inject'  # default state
        self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.fig.canvas.mpl_connect('scroll_event', self.onscroll)

        # register buttons
        axanim = plt.axes([0.1, 0.21, 0.2, 0.075])
        self.banim = Button(axanim, 'Toggle Animation')
        self.banim.on_clicked(self.toggle_animation)

        axinj = plt.axes([0.1, 0.05, 0.1, 0.075])
        axrem = plt.axes([0.21, 0.05, 0.1, 0.075])
        self.binj = Button(axinj, 'Inject')
        self.binj.on_clicked(self.inj_on)
        self.brem = Button(axrem, 'Remove')
        self.brem.on_clicked(self.rem_on)

        axhist = plt.axes([0.35, 0.05, 0.2, 0.075])
        self.bhist = Button(axhist, 'Toggle HistEQ')
        self.bhist.on_clicked(self.hist)

        axprev = plt.axes([0.59, 0.05, 0.1, 0.075])
        axsave = plt.axes([0.7, 0.05, 0.1, 0.075])
        axnext = plt.axes([0.81, 0.05, 0.1, 0.075])
        self.bnext = Button(axnext, 'Next')
        self.bnext.on_clicked(self.next)
        self.bprev = Button(axprev, 'Previous')
        self.bprev.on_clicked(self.prev)
        self.bsave = Button(axsave, 'Save')
        self.bsave.on_clicked(self.save)
        self.maximize_window()
        self.update()
        plt.show()

    def _load_paths(self, path):
        """Return the sorted list of scan paths to browse for ``path``.

        A path with a ``.dcm``/``.mhd`` extension is returned as the only scan.
        Otherwise ``path`` is listed as a directory, one level deep:
          * a subdirectory containing any ``.dcm`` file contributes the
            subdirectory itself (a DICOM series);
          * a subdirectory containing a ``.mhd`` file contributes that header;
          * ``.mhd`` files directly inside ``path`` are added as-is.
        Extensions are matched case-insensitively.
        """
        if os.path.splitext(path)[1].lower() in (".dcm", ".mhd"):
            return [path]
        filepaths = []
        for name in sorted(os.listdir(path)):
            full = os.path.join(path, name)  # resolve relative to `path`, not the CWD
            if os.path.isdir(full):
                subdir_files = sorted(os.listdir(full))
                exts = [os.path.splitext(f)[1].lower() for f in subdir_files]
                if ".dcm" in exts:  # folder contains a DICOM series
                    filepaths.append(full)
                elif ".mhd" in exts:  # folder contains an MHD header (+ its .raw data)
                    filepaths.append(os.path.join(full, subdir_files[exts.index(".mhd")]))
            elif os.path.splitext(name)[1].lower() == ".mhd":
                filepaths.append(full)
        return filepaths

    @staticmethod
    def _scan_uuid(filepath):
        """Return the scan's base name, minus a trailing .mhd/.dcm extension.

        Directory names (DICOM series, often dotted UIDs) are kept intact.
        """
        name = os.path.basename(os.path.normpath(filepath))
        stem, ext = os.path.splitext(name)
        return stem if ext.lower() in (".mhd", ".dcm") else name

    def _current_filename(self):
        """Return the basename of the currently loaded scan (used as the title)."""
        return os.path.split(self.filepaths[self.fileindex])[-1]

    def _current_slice(self):
        """Return the displayed 2D slice, histogram-equalized when enabled."""
        img = self.manipulator.scan[self.ind, :, :]
        return self.eq.equalize(img) if self.hist_state else img

    def _event_to_voxel(self, event):
        """Map a mouse event to a voxel index ``[slice, y, x]``, or None.

        Returns None when the click is not on the image axes (e.g. on a button)
        or falls outside the current slice.
        """
        if event.inaxes is not self.ax or event.xdata is None or event.ydata is None:
            return None
        # imshow centres pixel (y, x) on data coordinate (x, y), so round rather
        # than truncate to find the pixel that was actually clicked.
        y = int(np.rint(event.ydata))
        x = int(np.rint(event.xdata))
        # A slice is scan[ind] with shape (self.cols, self.rows): y indexes axis 1, x axis 2.
        if not (0 <= y < self.cols and 0 <= x < self.rows):
            return None
        return np.array([self.ind, y, x], dtype=int)

    def onclick(self, event):
        """Inject or remove (per ``action_state``) at the clicked voxel."""
        coord = self._event_to_voxel(event)
        if coord is None:
            return
        action = 'inject' if self.action_state == 'inject' else 'remove'
        self.pause_start = np.inf  # pause while working
        self.ax.set_title("Injecting..." if action == 'inject' else "Removing...")
        self.im.axes.figure.canvas.draw()
        try:
            self.manipulator.tamper(coord, action=action, isVox=True)
            (self.inject_coords if action == 'inject' else self.remove_coords).append(coord)
        finally:
            # Restore the view even if tampering failed; pause a few secs to see
            # the result before the animation continues.
            self.pause_start = time.time()
            self.ax.set_title(self._current_filename())
            self.update()

    def onscroll(self, event):
        """Step one slice up/down per scroll tick, wrapping around."""
        if event.button == 'up':
            self.ind = (self.ind + 1) % self.slices
        else:
            self.ind = (self.ind - 1) % self.slices
        self.update()

    def toggle_animation(self, event):
        """Start/stop the automatic sweep through the slices."""
        self.animation_state = not self.animation_state
        self.pause_start = 0 if self.animation_state else np.inf

    def inj_on(self, event):
        """Make subsequent clicks inject."""
        self.action_state = 'inject'

    def rem_on(self, event):
        """Make subsequent clicks remove."""
        self.action_state = 'remove'

    def hist(self, event):
        """Toggle histogram equalization and redraw."""
        self.hist_state = not self.hist_state
        self.plot()
        self.update()

    def next(self, event):
        """Load the next scan (wrapping around)."""
        self.fileindex = (self.fileindex + 1) % len(self.filepaths)
        self.loadscan(self.fileindex)

    def prev(self, event):
        """Load the previous scan (wrapping around)."""
        self.fileindex = (self.fileindex - 1) % len(self.filepaths)
        self.loadscan(self.fileindex)

    def save(self, event):
        """Save the tampered scan as DICOM and append its tamper coordinates to the CSV log.

        Rows are written as ``filename, x, y, z, tamper_type`` where x/y are the
        in-slice pixel indices and z is the slice index.
        """
        if self.savepath is None:
            print("No save path was given; nothing saved.")
            return
        self.ax.set_title("Saving...")
        self.im.axes.figure.canvas.draw()
        filepath = self.filepaths[self.fileindex]
        load_filename = os.path.split(filepath)[-1]
        os.makedirs(self.savepath, exist_ok=True)  # the CSV below needs the directory to exist
        # save scan
        self.manipulator.save_tampered_scan(os.path.join(self.savepath, self._scan_uuid(filepath)), output_type='dicom')
        # save coords
        csv_path = os.path.join(self.savepath, "tamper_coordinates.csv")
        write_header = not os.path.exists(csv_path)
        with open(csv_path, "a") as f:
            if write_header:
                f.write("filename, x, y, z, tamper_type\n")  # header
            for tamper_type, coords in (("inject", self.inject_coords), ("remove", self.remove_coords)):
                for coord in coords:
                    f.write(", ".join([load_filename, str(coord[2]), str(coord[1]), str(coord[0]), tamper_type]) + "\n")
        self.ax.set_title(load_filename)  # filename
        self.im.axes.figure.canvas.draw()

    def update(self):
        """Redraw the current slice and its slice-number label."""
        self.im.set_data(self._current_slice())
        self.ax.set_ylabel('slice %s' % self.ind)
        self.im.axes.figure.canvas.draw()

    def loadscan(self, fileindex):
        """Load scan ``fileindex``, resetting tamper coordinates and the view."""
        # blank "loading" screen while the scan is read
        self.im.set_data(np.full((self.cols, self.rows), -1000.0))
        self.ax.set_title("Loading...")
        self.im.axes.figure.canvas.draw()
        self.remove_coords.clear()
        self.inject_coords.clear()
        # load scan
        self.manipulator.load_target_scan(self.filepaths[fileindex])
        self.slices, self.cols, self.rows = self.manipulator.scan.shape
        self.ind = self.slices // 2
        self.eq = histEq(self.manipulator.scan)
        self.plot()  # plot() clears the axes itself
        self.ax.set_title(os.path.split(self.filepaths[fileindex])[-1])  # filename
        self.ax.set_ylabel('slice %s' % self.ind)
        self.im.axes.figure.canvas.draw()

    def plot(self):
        """(Re)create the image artist for the current slice.

        The slice animation is created only once: re-creating it on every call
        (HistEQ toggle, scan load) would leave the old timers running and make
        the sweep progressively faster.
        """
        self.ax.clear()
        if self.hist_state:
            self.im = self.ax.imshow(self._current_slice(), cmap="bone")
        else:
            self.im = self.ax.imshow(self._current_slice(), cmap="bone", vmin=-1000, vmax=1750)
        if self.animation is None:
            self.animation = animation.FuncAnimation(self.fig, self.animate, interval=100)

    def animate(self, i):
        """FuncAnimation callback: bounce back and forth through the slices.

        Does nothing while the animation is toggled off, within 1 s of the last
        pause, or when the scan has fewer than two slices.
        """
        if not self.animation_state or self.slices < 2:
            return
        if time.time() - self.pause_start <= 1:
            return
        if self.ind >= self.slices - 1:
            self.ani_direction = 'up'
        elif self.ind <= 0:
            self.ani_direction = 'down'
        self.ind += -1 if self.ani_direction == 'up' else 1
        self.update()

    def maximize_window(self):
        """Best-effort maximize of the figure window; prints a notice on failure."""
        try:
            mng = plt.get_current_fig_manager()
        except Exception:
            print("Could not maximize window")
            return
        strategies = (
            lambda: mng.window.showMaximized(),              # Qt backends
            lambda: mng.window.state('zoomed'),              # TkAgg (Windows)
            lambda: mng.window.attributes('-zoomed', True),  # TkAgg (X11)
            lambda: mng.frame.Maximize(True),                # wxAgg
        )
        for maximize in strategies:
            try:
                maximize()
                return
            except Exception:  # backend lacks this API; try the next one
                continue
        print("Could not maximize window")
238
239
# Resolve load/save directories from the command line (see -h for usage).
# os.path.join keeps the defaults portable; the old "data\\..." literals only
# worked on Windows.
default_loadpath = os.path.join("data", "healthy_scans")
default_savepath = os.path.join("data", "tampered_scans")
if len(sys.argv) == 3:
    loadpath = sys.argv[1]
    savepath = sys.argv[2]
elif len(sys.argv) == 2:
    loadpath = sys.argv[1]
    savepath = default_savepath
else:
    if len(sys.argv) > 3:
        print("Too many arguments; using the default paths (run with -h for usage)")
    loadpath = default_loadpath
    savepath = default_savepath

gui = GUI(load_path=loadpath, save_path=savepath)