|
a |
|
b/GUI.py |
|
|
1 |
import sys |
|
|
2 |
# Print usage and quit on `-h` / `--help`. This check sits above the remaining
# imports, presumably so help is printed without loading the pipeline modules.
if len(sys.argv) == 2 and sys.argv[1] in ('-h', '--help'):
    print("python "+sys.argv[0]+" <path to dir of scans>")
    print("python "+sys.argv[0]+" <path to dir of scans> <path to save dir>")
    # sys.exit instead of the site-module `exit` helper, which is not guaranteed to
    # exist (e.g. `python -S`, frozen builds); status 0 because help was requested.
    sys.exit(0)
|
|
7 |
|
|
|
8 |
import matplotlib.pyplot as plt |
|
|
9 |
from matplotlib.widgets import Button |
|
|
10 |
from procedures.attack_pipeline import * |
|
|
11 |
from utils.equalizer import * |
|
|
12 |
import matplotlib.animation as animation |
|
|
13 |
import time |
|
|
14 |
|
|
|
15 |
class GUI(object): |
|
|
16 |
# If load_path is to a *.dcm or *.mhd file then only this scan is loaded |
|
|
17 |
# If load_path is to a directory, then all scans are loaded. It is assumed that each scan is in its own subdirectory. |
|
|
18 |
# save_path is the directory to save the tampered scans (as dicom) |
|
|
19 |
def __init__(self, load_path, save_path=None): |
|
|
20 |
# init manipulator |
|
|
21 |
self.savepath = save_path |
|
|
22 |
self.filepaths = self._load_paths(load_path) # load all scans filepaths in path |
|
|
23 |
self.fileindex = 0 |
|
|
24 |
self.manipulator = scan_manipulator() |
|
|
25 |
self.manipulator.load_target_scan(self.filepaths[self.fileindex]) |
|
|
26 |
self.hist_state = True |
|
|
27 |
self.inject_coords = [] |
|
|
28 |
self.remove_coords = [] |
|
|
29 |
|
|
|
30 |
# init plot |
|
|
31 |
self.eq = histEq(self.manipulator.scan) |
|
|
32 |
self.slices, self.cols, self.rows = self.manipulator.scan.shape |
|
|
33 |
self.ind = self.slices // 2 |
|
|
34 |
self.pause_start = 0 |
|
|
35 |
self.fig, self.ax = plt.subplots(1, 1, dpi=100) |
|
|
36 |
self.fig.suptitle('CT-GAN: Malicious Tampering of 3D Medical Imagery using Deep Learning\nTool by Yisroel Mirsky', fontsize=14, fontweight='bold') |
|
|
37 |
plt.subplots_adjust(bottom=0.2) |
|
|
38 |
self.ani_direction = 'down' |
|
|
39 |
self.animation = None |
|
|
40 |
self.animation_state = True |
|
|
41 |
self.plot() |
|
|
42 |
self.ax.set_title(os.path.split(self.filepaths[self.fileindex])[-1]) #filename |
|
|
43 |
|
|
|
44 |
|
|
|
45 |
# register click/scroll events |
|
|
46 |
self.action_state = 'inject' #default state |
|
|
47 |
self.fig.canvas.mpl_connect('button_press_event', self.onclick) |
|
|
48 |
self.fig.canvas.mpl_connect('scroll_event', self.onscroll) |
|
|
49 |
|
|
|
50 |
# register buttons |
|
|
51 |
axanim = plt.axes([0.1, 0.21, 0.2, 0.075]) |
|
|
52 |
self.banim = Button(axanim, 'Toggle Animation') |
|
|
53 |
self.banim.on_clicked(self.toggle_animation) |
|
|
54 |
|
|
|
55 |
axinj = plt.axes([0.1, 0.05, 0.1, 0.075]) |
|
|
56 |
axrem = plt.axes([0.21, 0.05, 0.1, 0.075]) |
|
|
57 |
self.binj = Button(axinj, 'Inject') |
|
|
58 |
self.binj.on_clicked(self.inj_on) |
|
|
59 |
self.brem = Button(axrem, 'Remove') |
|
|
60 |
self.brem.on_clicked(self.rem_on) |
|
|
61 |
|
|
|
62 |
axhist = plt.axes([0.35, 0.05, 0.2, 0.075]) |
|
|
63 |
self.bhist = Button(axhist, 'Toggle HistEQ') |
|
|
64 |
self.bhist.on_clicked(self.hist) |
|
|
65 |
|
|
|
66 |
axprev = plt.axes([0.59, 0.05, 0.1, 0.075]) |
|
|
67 |
axsave = plt.axes([0.7, 0.05, 0.1, 0.075]) |
|
|
68 |
axnext = plt.axes([0.81, 0.05, 0.1, 0.075]) |
|
|
69 |
self.bnext = Button(axnext, 'Next') |
|
|
70 |
self.bnext.on_clicked(self.next) |
|
|
71 |
self.bprev = Button(axprev, 'Previous') |
|
|
72 |
self.bprev.on_clicked(self.prev) |
|
|
73 |
self.bsave = Button(axsave, 'Save') |
|
|
74 |
self.bsave.on_clicked(self.save) |
|
|
75 |
self.maximize_window() |
|
|
76 |
self.update() |
|
|
77 |
plt.show() |
|
|
78 |
|
|
|
79 |
def _load_paths(self,path): |
|
|
80 |
filepaths = [] |
|
|
81 |
# load single scan? |
|
|
82 |
if (path.split('.')[-1] == "dcm") or (path.split('.')[-1] == "mhd"): |
|
|
83 |
filepaths.append(path) |
|
|
84 |
return filepaths |
|
|
85 |
# try load directory of scans... |
|
|
86 |
files = os.listdir(path) |
|
|
87 |
for file in files: |
|
|
88 |
if os.path.isdir(file): |
|
|
89 |
subdir = os.path.join(path,file) |
|
|
90 |
subdir_files = os.listdir(subdir) |
|
|
91 |
if subdir_files[0].split('.')[-1] == "dcm": #folder contains dicom |
|
|
92 |
filepaths.append(os.path.join(path,subdir)) |
|
|
93 |
elif (subdir_files[0].split('.')[-1] == "mhd") or (subdir_files[0].split('.')[-1] == "raw"): # MHD |
|
|
94 |
filepaths.append(os.path.join(path,subdir,subdir_files[0])) |
|
|
95 |
elif file.split('.')[-1] == "mhd": |
|
|
96 |
filepaths.append(os.path.join(path,file)) |
|
|
97 |
return filepaths |
|
|
98 |
|
|
|
99 |
def onclick(self, event): |
|
|
100 |
# print('%s click: button=%d, x=%d, y=%d, xdata=%f, ydata=%f' % |
|
|
101 |
# ('double' if event.dblclick else 'single', event.button, |
|
|
102 |
# event.x, event.y, event.xdata, event.ydata)) |
|
|
103 |
if event.xdata is not None: |
|
|
104 |
coord = np.array([self.ind,event.ydata,event.xdata],dtype=int) |
|
|
105 |
if coord[1] > 0 and coord[2] > 0: |
|
|
106 |
self.pause_start = np.Inf #pause while working |
|
|
107 |
if self.action_state == 'inject': |
|
|
108 |
self.ax.set_title("Injecting...") |
|
|
109 |
self.im.axes.figure.canvas.draw() |
|
|
110 |
self.manipulator.tamper(coord, action='inject', isVox=True) |
|
|
111 |
self.inject_coords.append(coord) |
|
|
112 |
else: |
|
|
113 |
self.ax.set_title("Removing...") |
|
|
114 |
self.im.axes.figure.canvas.draw() |
|
|
115 |
self.manipulator.tamper(coord, action='remove', isVox=True) |
|
|
116 |
self.remove_coords.append(coord) |
|
|
117 |
self.pause_start = time.time() #pause few secs to see result before continue |
|
|
118 |
self.ax.set_title(os.path.split(self.filepaths[self.fileindex])[-1]) # filename |
|
|
119 |
self.update() |
|
|
120 |
|
|
|
121 |
def onscroll(self, event): |
|
|
122 |
if event.button == 'up': |
|
|
123 |
self.ind = (self.ind + 1) % self.slices |
|
|
124 |
else: |
|
|
125 |
self.ind = (self.ind - 1) % self.slices |
|
|
126 |
self.update() |
|
|
127 |
|
|
|
128 |
def toggle_animation(self, event): |
|
|
129 |
self.animation_state = not self.animation_state |
|
|
130 |
if self.animation_state: |
|
|
131 |
self.pause_start = 0 |
|
|
132 |
else: |
|
|
133 |
self.pause_start = np.Inf |
|
|
134 |
|
|
|
135 |
def inj_on(self, event): |
|
|
136 |
self.action_state = 'inject' |
|
|
137 |
|
|
|
138 |
def rem_on(self, event): |
|
|
139 |
self.action_state = 'remove' |
|
|
140 |
|
|
|
141 |
def hist(self, event): |
|
|
142 |
self.hist_state = not self.hist_state |
|
|
143 |
self.plot() |
|
|
144 |
self.update() |
|
|
145 |
|
|
|
146 |
def next(self, event): |
|
|
147 |
self.fileindex = (self.fileindex + 1) % len(self.filepaths) |
|
|
148 |
self.loadscan(self.fileindex) |
|
|
149 |
|
|
|
150 |
def prev(self, event): |
|
|
151 |
self.fileindex = (self.fileindex - 1) % len(self.filepaths) |
|
|
152 |
self.loadscan(self.fileindex) |
|
|
153 |
|
|
|
154 |
def save(self, event): |
|
|
155 |
if self.savepath is not None: |
|
|
156 |
self.ax.set_title("Saving...") |
|
|
157 |
self.im.axes.figure.canvas.draw() |
|
|
158 |
uuid = os.path.split(self.filepaths[self.fileindex])[-1][:-4] |
|
|
159 |
#save scan |
|
|
160 |
self.manipulator.save_tampered_scan(os.path.join(self.savepath,uuid),output_type='dicom') |
|
|
161 |
#save coords |
|
|
162 |
file_exists = False |
|
|
163 |
if os.path.exists(os.path.join(self.savepath,"tamper_coordinates.csv")): |
|
|
164 |
file_exists = True |
|
|
165 |
f = open(os.path.join(self.savepath,"tamper_coordinates.csv"),"a+") |
|
|
166 |
load_filename = os.path.split(self.filepaths[self.fileindex])[-1] |
|
|
167 |
if not file_exists: |
|
|
168 |
f.write("filename, x, y, z, tamper_type\n") #header |
|
|
169 |
for coord in self.inject_coords: |
|
|
170 |
f.write(load_filename+", "+str(coord[2])+", "+str(coord[1])+", "+str(coord[0])+", "+"inject\n") |
|
|
171 |
for coord in self.remove_coords: |
|
|
172 |
f.write(load_filename+", "+str(coord[2])+", "+str(coord[1])+", "+str(coord[0])+", "+"remove\n") |
|
|
173 |
f.close() |
|
|
174 |
self.ax.set_title(load_filename) # filename |
|
|
175 |
self.im.axes.figure.canvas.draw() |
|
|
176 |
|
|
|
177 |
def update(self): |
|
|
178 |
if self.hist_state: |
|
|
179 |
self.im.set_data(self.eq.equalize(self.manipulator.scan[self.ind,:,:])) |
|
|
180 |
else: |
|
|
181 |
self.im.set_data(self.manipulator.scan[self.ind,:,:]) |
|
|
182 |
self.ax.set_ylabel('slice %s' % self.ind) |
|
|
183 |
self.im.axes.figure.canvas.draw() |
|
|
184 |
|
|
|
185 |
def loadscan(self,fileindex): |
|
|
186 |
#load screen |
|
|
187 |
self.im.set_data(np.ones((self.cols,self.rows))*-1000) |
|
|
188 |
self.ax.set_title("Loading...") |
|
|
189 |
self.im.axes.figure.canvas.draw() |
|
|
190 |
self.remove_coords.clear() |
|
|
191 |
self.inject_coords.clear() |
|
|
192 |
#load scan |
|
|
193 |
self.manipulator.load_target_scan(self.filepaths[fileindex]) |
|
|
194 |
self.slices, self.cols, self.rows = self.manipulator.scan.shape |
|
|
195 |
self.ind = self.slices//2 |
|
|
196 |
self.ax.clear() |
|
|
197 |
self.eq = histEq(self.manipulator.scan) |
|
|
198 |
self.plot() |
|
|
199 |
self.ax.set_title(os.path.split(self.filepaths[fileindex])[-1]) #filename |
|
|
200 |
self.ax.set_ylabel('slice %s' % self.ind) |
|
|
201 |
self.im.axes.figure.canvas.draw() |
|
|
202 |
|
|
|
203 |
def plot(self): |
|
|
204 |
self.ax.clear() |
|
|
205 |
if self.hist_state: |
|
|
206 |
self.im = self.ax.imshow(self.eq.equalize(self.manipulator.scan[self.ind,:,:]),cmap="bone")#, cmap="bone", vmin=-1000, vmax=1750) |
|
|
207 |
else: |
|
|
208 |
self.im = self.ax.imshow(self.manipulator.scan[self.ind,:,:], cmap="bone", vmin=-1000, vmax=1750) |
|
|
209 |
self.animation = animation.FuncAnimation(self.fig, self.animate, interval=100) |
|
|
210 |
|
|
|
211 |
def animate(self,i): |
|
|
212 |
if self.animation_state: |
|
|
213 |
if time.time() - self.pause_start > 1: |
|
|
214 |
if self.ind == self.slices-1: |
|
|
215 |
self.ani_direction = 'up' |
|
|
216 |
elif self.ind == 0: |
|
|
217 |
self.ani_direction = 'down' |
|
|
218 |
if self.ani_direction == 'up': |
|
|
219 |
self.ind-=1 |
|
|
220 |
else: |
|
|
221 |
self.ind+=1 |
|
|
222 |
self.update() |
|
|
223 |
|
|
|
224 |
def maximize_window(self): |
|
|
225 |
try: #'QT4Agg' |
|
|
226 |
figManager = plt.get_current_fig_manager() |
|
|
227 |
figManager.window.showMaximized() |
|
|
228 |
except: |
|
|
229 |
try: #'TkAgg' |
|
|
230 |
mng = plt.get_current_fig_manager() |
|
|
231 |
mng.window.state('zoomed') |
|
|
232 |
except: |
|
|
233 |
try: #'wxAgg' |
|
|
234 |
mng = plt.get_current_fig_manager() |
|
|
235 |
mng.frame.Maximize(True) |
|
|
236 |
except: |
|
|
237 |
print("Could not maximize window") |
|
|
238 |
|
|
|
239 |
# Command-line entry point:
#   python <script>                        -> default load and save directories
#   python <script> <load path>            -> default save directory
#   python <script> <load path> <save dir>
# Any other argument count falls back to both defaults.
import os  # os.path.join keeps the defaults portable ("data\\..." only worked on Windows)

_DEFAULT_LOAD_PATH = os.path.join("data", "healthy_scans")
_DEFAULT_SAVE_PATH = os.path.join("data", "tampered_scans")

_argc = len(sys.argv)
loadpath = sys.argv[1] if _argc in (2, 3) else _DEFAULT_LOAD_PATH
savepath = sys.argv[2] if _argc == 3 else _DEFAULT_SAVE_PATH

gui = GUI(load_path=loadpath, save_path=savepath)