Diff of /dosma/gui/ims.py [000000] .. [030aeb]

Switch to side-by-side view

--- a
+++ b/dosma/gui/ims.py
@@ -0,0 +1,601 @@
+import logging
+import os
+import sys
+import tkinter as tk
+from tkinter import IntVar, Radiobutton, filedialog, messagebox, ttk
+from typing import Dict
+
+import numpy as np
+import Pmw
+from skimage.color import label2rgb
+from skimage.measure import label
+
+from dosma.cli import SUPPORTED_QUANTITATIVE_VALUES, SUPPORTED_SCAN_TYPES, parse_args
+from dosma.core.io import format_io_utils as fio_utils
+from dosma.core.orientation import AXIAL, CORONAL, SAGITTAL
+from dosma.gui.dosma_gui import ScanReader
+from dosma.gui.gui_utils.filedialog_reader import FileDialogReader
+from dosma.gui.im_viewer import IndexTracker
+from dosma.gui.preferences_viewer import PreferencesManager
+from dosma.msk import knee
+
+import matplotlib
+import matplotlib.pyplot as plt
+from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
+
+matplotlib.use("TkAgg")
+LARGE_FONT = ("Verdana", 12)
+
+_logger = logging.getLogger(__name__)
+
+
+class DosmaViewer(tk.Tk):
+    def __init__(self, *args, **kwargs):
+        tk.Tk.__init__(self, *args, **kwargs)
+
+        container = tk.Frame(self)
+        container.pack(side="top", fill="both", expand=True)
+        container.grid_rowconfigure(0, weight=1)
+        container.grid_columnconfigure(0, weight=1)
+
+        self.frames = {}
+        self.protocol("WM_DELETE_WINDOW", self.on_closing)
+        for F in (StartPage, DosmaFrame, PageThree, AnalysisFrame):
+            frame = F(container, self)
+
+            self.frames[F] = frame
+
+            frame.grid(row=0, column=0, sticky="nsew")
+
+        self.show_frame(StartPage)
+
+        self.pref = PreferencesManager()
+
+    def on_closing(self):
+        if messagebox.askokcancel("Quit", "Do you want to quit?"):
+            sys.exit()
+
+    def show_frame(self, cont):
+        frame = self.frames[cont]
+        frame.tkraise()
+
+    def show_preferences(self):
+        self.pref.show_window(self)
+
+
+class StartPage(tk.Frame):
+    def __init__(self, parent, controller):
+        tk.Frame.__init__(self, parent)
+        # photo = tk.PhotoImage(file="./defaults/skel-rotate.gif")
+        # label1 = tk.Label(image=photo)
+        # label1.pack()
+
+        label = tk.Label(self, text="Start Page", font=LARGE_FONT)
+        label.pack(pady=10, padx=10)
+
+        button2 = ttk.Button(self, text="Scan", command=lambda: controller.show_frame(DosmaFrame))
+        button2.pack()
+
+        button3 = ttk.Button(
+            self, text="Knee Analysis", command=lambda: controller.show_frame(AnalysisFrame)
+        )
+        button3.pack()
+
+        button3 = ttk.Button(
+            self, text="Image Viewer", command=lambda: controller.show_frame(PageThree)
+        )
+        button3.pack()
+
+        button3 = ttk.Button(
+            self, text="Preferences", command=lambda: controller.show_preferences()
+        )
+        button3.pack()
+
+
+class AnalysisFrame(tk.Frame):
+    __TISSUES_KEY = "Tissues"
+    __QUANTITATIVE_VALUES_KEY = "Quantitative values"
+    __LOAD_PATH_KEY = "Load data"
+
+    __PID_KEY = "pid"
+    __MEDIAL_TO_LATERAL_ORIENTATION_KEY = "ml"
+
+    def __init__(self, parent, controller):
+        tk.Frame.__init__(self, parent)
+
+        self.manager: Dict = {}
+        self.gui_manager: Dict = {}
+        self.balloon = Pmw.Balloon()
+
+        self.__init_manager()
+
+        self.__base_gui()
+        self.preferences = PreferencesManager()
+        self.file_dialog_reader = FileDialogReader()
+        self.scan_reader = ScanReader(self)
+
+        button1 = ttk.Button(self, text="Home", command=lambda: controller.show_frame(StartPage))
+        button1.pack(anchor="se", side="right")
+
+        button1 = ttk.Button(self, text="Run", command=lambda: self.execute())
+        button1.pack(anchor="sw", side="left")
+
+    def execute(self):
+        try:
+            load_path = self.manager[self.__LOAD_PATH_KEY].get()
+            if not load_path:
+                raise ValueError("Load path not defined")
+
+            preferences_str = self.preferences.get_cmd_line_str().strip()
+
+            tissue_str = ""
+            for c, t in enumerate(self.manager[self.__TISSUES_KEY]):
+                if t.get():
+                    tissue_str += "--%s " % knee.SUPPORTED_TISSUES[c].STR_ID
+            tissue_str = tissue_str.strip()
+
+            if not tissue_str:
+                raise ValueError("No tissues selected")
+
+            qv_str = ""
+            for c, qv in enumerate(self.manager[self.__QUANTITATIVE_VALUES_KEY]):
+                if qv.get():
+                    qv_str += "--%s " % SUPPORTED_QUANTITATIVE_VALUES[c].name.lower()
+            qv_str = qv_str.strip()
+
+            if not qv_str:
+                raise ValueError("No quantitative values selected")
+
+            pid = self.manager[self.__PID_KEY].get()
+            medial_to_lateral = self.manager[self.__MEDIAL_TO_LATERAL_ORIENTATION_KEY].get()
+
+            if not pid:
+                raise ValueError("No PID was provided")
+
+            # analysis string
+            str_f = "--l %s %s knee %s --pid %s %s %s" % (
+                load_path,
+                preferences_str,
+                tissue_str,
+                pid,
+                "--ml" if medial_to_lateral else "",
+                qv_str,
+            )
+            str_f = str_f.strip()
+            parse_args(str_f.split())
+        except Exception as e:
+            tk.messagebox.showerror(str(type(e)), e.__str__())
+
+    def __init_manager(self):
+        self.manager[self.__LOAD_PATH_KEY] = tk.StringVar()
+        self.manager[self.__TISSUES_KEY] = [
+            tk.BooleanVar() for i in range(len(knee.SUPPORTED_TISSUES))
+        ]
+        self.manager[self.__QUANTITATIVE_VALUES_KEY] = [
+            tk.BooleanVar() for i in range(len(SUPPORTED_QUANTITATIVE_VALUES))
+        ]
+
+        self.manager[self.__PID_KEY] = tk.StringVar()
+        self.manager[self.__MEDIAL_TO_LATERAL_ORIENTATION_KEY] = tk.BooleanVar()
+
+    def __display_pid_info(self):
+        hb = tk.Frame(self)
+        hb.pack(side="top", anchor="nw")
+        _label = tk.Label(hb, text=self.__PID_KEY.upper())
+        _label.pack(side="left", anchor="w", pady=10)
+        t = tk.Entry(hb, textvariable=self.manager[self.__PID_KEY])
+        t.pack(side="left", anchor="w", pady=10)
+        self.balloon.bind(_label, "Patient id")
+
+    def __display_data_loader(self):
+        hb = tk.Frame(self)
+
+        filedialog = FileDialogReader(self.manager[self.__LOAD_PATH_KEY])
+        b = tk.Button(
+            hb,
+            text=self.__LOAD_PATH_KEY,
+            command=lambda fd=filedialog: self.manager[self.__LOAD_PATH_KEY].set(
+                fd.get_save_dirpath()
+            ),
+        )
+        b.pack(side="left", anchor="nw", pady=10)
+
+        _label = tk.Label(hb, textvariable=self.manager[self.__LOAD_PATH_KEY])
+        _label.pack(side="left", anchor="nw", pady=10)
+
+        hb.pack(side="top", anchor="nw")
+
+    def __display_multi_option(self, label, options_list, boolvar_list):
+        hb = tk.Frame(self)
+        _label = tk.Label(hb, text="%s:" % label)
+        _label.pack(side="left", anchor="w")
+        hb.pack(side="top", anchor="nw")
+        frames = [tk.Frame(hb)] * (len(options_list) // 3 + 1)
+        for ind, option in enumerate(options_list):
+            f = frames[ind // 3]
+            b = tk.Checkbutton(f, text=option, variable=boolvar_list[ind])
+            b.pack(side="top", anchor="nw", pady=5)
+
+        for f in frames:
+            f.pack(side="left", anchor="nw")
+
+        return hb
+
+    def __display_tissues(self):
+        tissue_names = [x.FULL_NAME for x in knee.SUPPORTED_TISSUES]
+        _label = self.__display_multi_option(
+            self.__TISSUES_KEY, tissue_names, self.manager[self.__TISSUES_KEY]
+        )
+        self.balloon.bind(_label, "Tissues to analyze")
+
+    def __display_quant_vals(self):
+        quantitative_value_names = [x.name for x in SUPPORTED_QUANTITATIVE_VALUES]
+        _label = self.__display_multi_option(
+            self.__QUANTITATIVE_VALUES_KEY,
+            quantitative_value_names,
+            self.manager[self.__QUANTITATIVE_VALUES_KEY],
+        )
+        self.balloon.bind(_label, "Quantitative values to analyze")
+
+    def __display_knee_info(self):
+        hb = tk.Frame(self)
+        hb.pack(side="top", anchor="nw")
+        _label = tk.Label(hb, text="Medial -> Lateral: ")
+        _label.pack(side="left", anchor="w", pady=10)
+        t = tk.Checkbutton(hb, variable=self.manager[self.__MEDIAL_TO_LATERAL_ORIENTATION_KEY])
+        t.pack(side="left", anchor="w", pady=10)
+
+        self.balloon.bind(_label, "Select if Dicoms proceed in medial->lateral direction")
+
+    def __base_gui(self):
+        self.__display_data_loader()
+        self.__display_pid_info()
+        self.__display_tissues()
+        self.__display_knee_info()
+        self.__display_quant_vals()
+
+
+class DosmaFrame(tk.Frame):
+    __SCAN_KEY = "Scan"
+    __TISSUES_KEY = "Tissues"
+
+    __DICOM_PATH_KEY = "Read dicoms"
+    __LOAD_PATH_KEY = "Load data"
+
+    __SAVE_PATH_KEY = "Save path"
+
+    __DATA_KEY = "data"  # Track option menu for dicom/load path
+    __DATA_PATH_KEY = "datapath"  # Track filepath associated with option menu
+
+    __IGNORE_EXTENSION_KEY = "Ignore extension"
+
+    def __init__(self, parent, controller):
+        tk.Frame.__init__(self, parent)
+
+        self.file_dialog_reader = FileDialogReader()
+
+        self.manager: Dict = {}
+        self.gui_manager: Dict = {}
+        self.balloon = Pmw.Balloon()
+
+        self.__init_manager()
+
+        self.__base_gui()
+        self.preferences = PreferencesManager()
+        self.scan_reader = ScanReader(self)
+
+        button1 = ttk.Button(self, text="Home", command=lambda: controller.show_frame(StartPage))
+        button1.pack(anchor="se", side="right")
+
+        button1 = ttk.Button(self, text="Run", command=lambda: self.execute())
+        button1.pack(anchor="sw", side="left")
+
+        self.InitUI()
+
+    def execute(self):
+        try:
+            save_path = self.manager[self.__SAVE_PATH_KEY].get()
+            if not save_path:
+                raise ValueError("Save path not defined")
+
+            action_str = self.scan_reader.get_cmd_line_str().strip()
+
+            if not action_str:
+                raise ValueError("No action selected")
+
+            preferences_str = self.preferences.get_cmd_line_str().strip()
+
+            source = "d"
+            if self.manager[self.__DATA_KEY].get() == self.__LOAD_PATH_KEY:
+                source = "l"
+
+            tissue_str = ""
+            for c, t in enumerate(self.manager[self.__TISSUES_KEY]):
+                if t.get():
+                    tissue_str += "--%s " % knee.SUPPORTED_TISSUES[c].STR_ID
+            tissue_str = tissue_str.strip()
+
+            if not tissue_str:
+                raise ValueError("No tissues selected")
+
+            ignore_ext = self.manager[self.__IGNORE_EXTENSION_KEY].get()
+
+            str_f = "--%s %s --s %s %s %s %s %s %s" % (
+                source,
+                self.manager[self.__DATA_PATH_KEY].get(),
+                save_path,
+                preferences_str,
+                "--ignore_ext" if ignore_ext else "",
+                self.manager[self.__SCAN_KEY].get(),
+                tissue_str,
+                action_str,
+            )
+
+            _logger.info("CMD LINE INPUT: %s" % str_f)
+
+            parse_args(str_f.split())
+        except Exception as e:
+            tk.messagebox.showerror(str(type(e)), e.__str__())
+
+    def __init_manager(self):
+        self.manager[self.__SCAN_KEY] = tk.StringVar()
+        self.manager[self.__TISSUES_KEY] = [
+            tk.BooleanVar() for i in range(len(knee.SUPPORTED_TISSUES))
+        ]
+        self.manager[self.__DATA_KEY] = tk.StringVar()
+        self.manager[self.__DATA_PATH_KEY] = tk.StringVar()
+
+        self.manager[self.__SCAN_KEY].trace_add("write", self.__on_scan_change)
+        self.manager[self.__SAVE_PATH_KEY] = tk.StringVar()
+        self.manager[self.__IGNORE_EXTENSION_KEY] = tk.BooleanVar()
+
+    def __on_scan_change(self, *args):
+        scan_id = self.manager[self.__SCAN_KEY].get()
+        scan = None
+        for x in SUPPORTED_SCAN_TYPES:
+            if x.NAME == scan_id:
+                scan = x
+
+        self.scan_reader.load_scan(scan)
+
+        assert scan is not None, "No scan selected"
+
+    def __update_svar(self, *args):
+        svar = self.manager[self.__DATA_PATH_KEY]
+        selected_option = self.manager[self.__DATA_KEY].get()
+        if selected_option == self.__DICOM_PATH_KEY:
+            fp = self.file_dialog_reader.get_volume_filepath(
+                selected_option, im_type=fio_utils.ImageDataFormat.dicom
+            )
+        elif selected_option == self.__LOAD_PATH_KEY:
+            fp = self.file_dialog_reader.get_dirpath(selected_option)
+        else:
+            raise ValueError("%s key not found" % self.__DATA_KEY)
+
+        if not fp:
+            svar.set("")
+            return
+
+        svar.set(fp)
+
+        if selected_option == self.__LOAD_PATH_KEY:
+            self.manager[self.__SAVE_PATH_KEY].set(fp)
+
+    def __display_data_loader(self):
+        s_var = self.manager[self.__DATA_PATH_KEY]
+
+        hb = tk.Frame(self)
+
+        label = tk.Label(hb, text="Data source: ")
+        label.pack(side="left", anchor="nw", pady=10)
+
+        options = [self.__DICOM_PATH_KEY, self.__LOAD_PATH_KEY]
+        menu = tk.OptionMenu(
+            hb, self.manager[self.__DATA_KEY], *options, command=self.__update_svar
+        )
+        menu.pack(side="left", anchor="nw", pady=10)
+
+        label = tk.Label(hb, textvariable=s_var)
+        label.pack(side="left", anchor="nw", pady=10)
+
+        hb.pack(side="top", anchor="nw")
+        self.balloon.bind(hb, "Read dicoms or load data")
+
+        hb = tk.Frame(self)
+
+        # filedialog = FileDialogReader(self.manager[self.__SAVE_PATH_KEY])
+        b = tk.Button(
+            hb,
+            text=self.__SAVE_PATH_KEY,
+            command=lambda fd=self.file_dialog_reader: self.manager[self.__SAVE_PATH_KEY].set(
+                fd.get_save_dirpath()
+            ),
+        )
+        b.pack(side="left", anchor="nw", pady=10)
+
+        label = tk.Label(hb, textvariable=self.manager[self.__SAVE_PATH_KEY])
+        label.pack(side="left", anchor="nw", pady=10)
+
+        hb.pack(side="top", anchor="nw")
+
+        hb = tk.Frame(self)
+
+        b = tk.Checkbutton(
+            hb, text=self.__IGNORE_EXTENSION_KEY, variable=self.manager[self.__IGNORE_EXTENSION_KEY]
+        )
+        b.pack(side="left", anchor="nw", pady=10)
+        self.balloon.bind(b, "Ignore '.dcm' extension when loading dicoms")
+
+        hb.pack(side="top", anchor="nw")
+
+    def __display_tissues(self):
+        hb = tk.Frame(self)
+        _label = tk.Label(hb, text="Tissues:")
+        _label.pack(side="left", anchor="w")
+        hb.pack(side="top", anchor="nw")
+        frames = [tk.Frame(hb)] * (len(knee.SUPPORTED_TISSUES) // 3 + 1)
+        for ind, tissue in enumerate(knee.SUPPORTED_TISSUES):
+            f = frames[ind // 3]
+            b = tk.Checkbutton(
+                f, text=tissue.FULL_NAME, variable=self.manager[self.__TISSUES_KEY][ind]
+            )
+            b.pack(side="top", anchor="nw", pady=5)
+
+        for f in frames:
+            f.pack(side="left", anchor="nw")
+
+        self.balloon.bind(_label, "Tissues to analyze")
+
+    def __base_gui(self):
+        self.__display_data_loader()
+        self.__display_tissues()
+
+        hb = tk.Frame(self)
+        scan_label = tk.Label(hb, text="Scan:")
+        scan_label.pack(side="left", anchor="nw", pady=10)
+        options = [x.NAME for x in SUPPORTED_SCAN_TYPES]
+        scan_dropdown = tk.OptionMenu(hb, self.manager[self.__SCAN_KEY], *options)
+        scan_dropdown.pack(side="left", anchor="nw", pady=10)
+        hb.pack(side="top", anchor="nw")
+
+    def InitUI(self):
+        self.text_box = tk.Text(self, wrap="word", height=11, width=50)
+        self.text_box.pack(anchor="s", side="bottom")
+
+
+class PageThree(tk.Frame):
+    SUPPORTED_FORMATS = (("nifti files", "*.nii\.gz"), ("dicom files", "*.dcm"))  # noqa: W605
+    __base_filepath = "../"
+
+    _ORIENTATIONS = [("sagittal", SAGITTAL), ("coronal", CORONAL), ("axial", AXIAL)]
+
+    def __init__(self, parent, controller):
+        tk.Frame.__init__(self, parent)
+        self._im_display = None
+        self.binding_vars: Dict = {}
+        fig, ax = plt.subplots(1, 1)
+        X = np.random.rand(20, 20, 40)
+
+        self.tracker = IndexTracker(ax, X)
+
+        canvas = FigureCanvasTkAgg(fig, self)
+        canvas.draw()
+        canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
+        canvas.mpl_connect("scroll_event", self.tracker.onscroll)
+
+        toolbar = NavigationToolbar2Tk(canvas, self)
+        toolbar.update()
+        canvas._tkcanvas.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
+
+        self.im = None
+        self.mask = None
+        self._im_display = None
+
+        button1 = ttk.Button(
+            self, text="Back to Home", command=lambda: controller.show_frame(StartPage)
+        )
+        button1.pack(side=tk.BOTTOM, anchor="sw")
+
+        button2 = ttk.Button(self, text="Load main image", command=self.load_volume_callback)
+        button2.pack()
+
+        button3 = ttk.Button(self, text="Load mask", command=self.load_mask_callback)
+        button3.pack()
+
+        self.init_reformat_display()
+
+    def __reformat_callback(self, *args):
+        self.im_update()
+
+    def init_reformat_display(self):
+        orientation_var = IntVar(0)
+        orientation_var.trace_add("write", self.__reformat_callback)
+        count = 0
+        for text, _value in self._ORIENTATIONS:
+            b = Radiobutton(self, text=text, variable=orientation_var, value=count)
+            b.pack(side=tk.TOP, anchor="w")
+            count += 1
+        self._orientation = orientation_var
+
+    def load_volume_callback(self):
+        im = self.load_volume()
+        if not im:
+            return
+        self.im = im
+        self.mask = None
+
+        self.im_update()
+
+    def load_mask_callback(self):
+        if not self.im:
+            messagebox.showerror("Loading mask failed", "Main image must be loaded prior to mask")
+            return
+
+        mask = self.load_volume("Load mask")
+        mask.reformat(self.im.orientation, inplace=True)
+        try:
+            self.__verify_mask_size(self.im.volume, mask.volume)
+        except Exception as e:
+            messagebox.showerror("Loading mask failed", str(e))
+            return
+
+        self.mask = mask
+        self.im_update()
+
+    def __verify_mask_size(self, im: np.ndarray, mask: np.ndarray):
+        if mask.ndim != 3:
+            raise ValueError("Dimension mismatch. Mask must be 3D")
+        if im.shape != mask.shape:
+            raise ValueError(
+                "Dimension mismatch. Image of shape %s, but mask of shape %s"
+                % (str(im.shape), str(mask.shape))
+            )
+
+    def im_update(self):
+        orientation = self.orientation
+        self.im.reformat(orientation, inplace=True)
+        im = self.im.volume
+        im = im / np.max(im)
+        if self.mask:
+            self.mask.reformat(orientation, inplace=True)
+            label_image = label(self.mask.volume)
+            im = self.__labeltorgb_3d__(im, label_image, 0.3)
+
+        self.im_display = im
+
+    def __labeltorgb_3d__(self, im: np.ndarray, labels: np.ndarray, alpha: float = 0.3):
+        im_rgb = np.zeros(im.shape + (3,))  # rgb channel
+        for s in range(im.shape[2]):
+            im_slice = im[..., s]
+            labels_slice = labels[..., s]
+            im_rgb[..., s, :] = label2rgb(labels_slice, image=im_slice, bg_label=0, alpha=alpha)
+        return im_rgb
+
+    def load_volume(self, title="Select volume file(s)"):
+        files = filedialog.askopenfilenames(initialdir=self.__base_filepath, title=title)
+        if len(files) == 0:
+            return
+
+        filepath = files[0]
+        self.__base_filepath = os.path.dirname(filepath)
+
+        if filepath.endswith(".dcm"):
+            filepath = os.path.dirname(filepath)
+
+        im = fio_utils.generic_load(filepath, 1)
+
+        return im
+
+    @property
+    def orientation(self):
+        ind = self._orientation.get()
+        return self._ORIENTATIONS[ind][1]
+
+    @property
+    def im_display(self):
+        return self._im_display
+
+    @im_display.setter
+    def im_display(self, value):
+        self._im_display = value
+        self.tracker.x = self._im_display