Switch to side-by-side view

--- a
+++ b/slideflow/studio/widgets/segment.py
@@ -0,0 +1,843 @@
+import os
+import torch
+import slideflow as sf
+import imgui
+import glfw
+import segmentation_models_pytorch as smp
+from typing import Optional, List
+from os.path import join, dirname, abspath, exists
+from threading import Thread
+from tkinter.filedialog import askopenfilename, askdirectory
+from slideflow.segment import TileMaskDataset
+from slideflow.model.torch_utils import get_device
+from collections import defaultdict
+
+from ._utils import Widget
+from ..gui import imgui_utils
+from ..utils import LEFT_MOUSE_BUTTON, RIGHT_MOUSE_BUTTON
+from .slide import stride_capture
+
+from pytorch_lightning.callbacks import Callback
+
+class ProgressCallback(Callback):
+
+    def __init__(self, toast, max_epochs):
+        super().__init__()
+        self.toast = toast
+        self.max_epochs = max_epochs
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        percent = (trainer.current_epoch + 1) / self.max_epochs
+        self.toast.set_progress(min(percent, 1.))
+
+# ----------------------------------------------------------------------------
+
+
+class TissueSegWidget(Widget):
+
+    tag = 'segment'
+    description = 'Tissue Segmentation'
+    icon = join(dirname(abspath(__file__)), '..', 'gui', 'buttons', 'button_segment.png')
+    icon_highlighted = join(dirname(abspath(__file__)), '..', 'gui', 'buttons', 'button_segment_highlighted.png')
+
+    def __init__(self, viz):
+        self.viz                    = viz
+        self._segment               = None
+        self._thread                = None
+        self._load_toast            = None
+        self._working_toast         = None
+        self._training_toast        = None
+        self._show_params           = False
+        self._rois_at_start         = 0
+        self._need_to_refresh_rois  = False
+        self._clicking              = False
+        self._show_popup            = False
+        self._load_slide_popup      = None
+        self._load_slide_popup_coords = None
+
+        # Parameters
+        self._supported_archs       = ['FPN', 'DeepLabV3', 'DeepLabV3Plus', 'Linknet', 'MAnet', 'PAN', 'PSPNet', 'Unet', 'UnetPlusPlus']
+        self._selected_arch         = 0
+        self._supported_encoders    = smp.encoders.get_encoder_names()
+        self._selected_encoder      = self._supported_encoders.index('resnet34')
+        self._filter_methods        = ['otsu', 'roi']
+        self._selected_filter_method = 0
+        self._training_modes        = ['binary', 'multiclass', 'multilabel']
+        self._selected_training_mode = 0
+        self.max_epochs             = 20
+        self.tile_px                = 1024
+        self.tile_um                = 2048
+        self.crop_margin            = 256
+        self.stride                 = 1
+        self._capturing_stride      = 1
+        self._selected_slides       = defaultdict(bool)
+        self._unique_training_classes = dict()
+        self._sq_mm_threshold       = 0.01
+
+
+    # --- Properties ---
+
+    @property
+    def cfg(self) -> sf.segment.SegmentConfig:
+        seg = self._segment
+        return None if seg is None else seg.cfg
+
+    @property
+    def arch(self) -> str:
+        return self._supported_archs[self._selected_arch]
+
+    @property
+    def encoder(self) -> str:
+        return self._supported_encoders[self._selected_encoder]
+
+    @property
+    def mpp(self) -> float:
+        return self.tile_um / self.tile_px
+
+    @property
+    def filter_method(self) -> str:
+        return self._filter_methods[self._selected_filter_method]
+
+    @property
+    def mode(self) -> str:
+        return self._training_modes[self._selected_training_mode]
+
+    # --- Internal ---
+
+    def get_training_slides(self) -> List[str]:
+        return [slide for slide in list(self._selected_slides.keys())
+                if self._selected_slides[slide]]
+
+    def get_training_classes(self) -> List[str]:
+        return [(k if k != '<No label>' else None)
+                for k, v in self._unique_training_classes.items() if v]
+
+    def close(self):
+        pass
+
+    def is_thread_running(self):
+        return self._thread is not None and self._thread.is_alive()
+
+    def is_training(self):
+        return self._training_toast is not None
+
+    def drag_and_drop_hook(self, path, ignore_errors=False) -> bool:
+        """Handle file paths provided via drag-and-drop."""
+        if (sf.util.path_to_ext(path).lower() == 'pth'):
+            if exists(join(dirname(path), 'segment_params.json')):
+                self.load(path, ignore_errors=ignore_errors)
+                return True
+        return False
+
+    # --- Model loading ---
+
+    def ask_load_model(self) -> str:
+        model_path = askopenfilename(
+            title="Load model...",
+            filetypes=[("pth", ".pth"), ("All files", ".*")]
+        )
+        if model_path:
+            self.load(model_path)
+
+    def ask_export_model(self) -> Optional[str]:
+        destination = askdirectory(
+            title="Export model (choose directory)..."
+        )
+        if destination:
+            model_path = sf.util.get_new_model_dir(destination, 'segment')
+            self.export(model_path)
+        return model_path
+
+    def export(self, path: str) -> None:
+        """Export a tissue segmentation model."""
+        if self._segment is None:
+            return
+        if not exists(path):
+            os.makedirs(path)
+        model_path = join(path, 'model.pth')
+        torch.save(self._segment.model.state_dict(), model_path)
+        self._segment.cfg.to_json(join(path, 'segment_params.json'))
+        self._segment.model_path = model_path
+        self.viz.create_toast(f"Model exported to {model_path}", icon="success")
+
+    def load(self, path, ignore_errors=False):
+        """Load a tissue segmentation model."""
+        if self.is_thread_running():
+            self._thread.join()
+        self._load_toast = self.viz.create_toast(
+            title=f"Loading segmentation model",
+            icon='info',
+            sticky=True,
+            spinner=True)
+        self._thread = Thread(target=self._load_model, args=(path, ignore_errors))
+        self._thread.start()
+
+    def _load_model(self, path, ignore_errors=False):
+        try:
+            self._segment = sf.slide.qc.StridedSegment(path)
+            self._segment.model.to(get_device())
+        except Exception as e:
+            if self._load_toast is not None:
+                self._load_toast.done()
+            sf.log.error(f"Error loading segment model: {e}")
+            self.viz.create_toast(f"Error loading segment model: {e}", icon="error")
+            self._segment = None
+        else:
+            if self._load_toast is not None:
+                self._load_toast.done()
+            self.viz.create_toast(
+                f"Loaded model at {path}.",
+                icon="success"
+            )
+
+    def close_model(self) -> None:
+        self._segment = None
+
+    def generate_rois(self):
+        """Generate ROIs from the loaded segmentation model."""
+        if self.is_thread_running():
+            self.viz.create_toast("Failed to start thread.", icon="error")
+            return
+        self._rois_at_start = len(self.viz.wsi.rois)
+        self._working_toast = self.viz.create_toast(
+            title="Generating ROIs",
+            message=f"Generating ROIs from segmentation model.",
+            icon='info',
+            sticky=True,
+            spinner=True)
+        self._thread = Thread(target=self._generate_rois)
+        self._thread.start()
+
+    def _generate_rois(self):
+        viz = self.viz
+        self._segment.generate_rois(
+            viz.wsi,
+            sq_mm_threshold=self._sq_mm_threshold,
+            simplify_tolerance=5
+        )
+        self._need_to_refresh_rois = True
+        if self._working_toast is not None:
+            self._working_toast.done()
+        viz.create_toast(
+            "Generated {} ROIs.".format(
+                len(self.viz.wsi.rois) - self._rois_at_start
+            ),
+            icon="success"
+        )
+
+    def train(self) -> None:
+        """Train a segmentation model."""
+        if self.is_thread_running():
+            self.viz.create_toast("Failed to start thread.", icon="error")
+            return
+
+        # Create a progress toast.
+        if self._training_toast is not None:
+            self._training_toast.done()
+        self._training_toast = self.viz.create_toast(
+            title="Training segmentation model",
+            icon='info',
+            sticky=True,
+            progress=True,
+            spinner=True
+        )
+        self._thread = Thread(target=self._train)
+        self._thread.start()
+
+    def finetune(self) -> None:
+        """Finetune a segmentation model."""
+        if self.is_thread_running():
+            self.viz.create_toast("Failed to start thread.", icon="error")
+            return
+
+        # Create a progress toast.
+        if self._training_toast is not None:
+            self._training_toast.done()
+        self._training_toast = self.viz.create_toast(
+            title="Finetuning segmentation model",
+            icon='info',
+            sticky=True,
+            progress=True,
+            spinner=True
+        )
+        self._thread = Thread(target=self._finetune)
+        self._thread.start()
+
+    def _train(self) -> None:
+        """Train a segmentation model."""
+        import pytorch_lightning as pl
+
+        viz = self.viz
+
+        # Prepare the slideflow dataset.
+        dataset = viz.P.dataset(filters={'slide': self.get_training_slides()})
+
+        # Determine the labels, if necessary.
+        all_roi_labels = self.get_training_classes()
+        if self.mode == 'binary':
+            out_classes = 1
+        elif self.mode == 'multiclass':
+            out_classes = len(all_roi_labels) + 1
+        else:
+            out_classes = len(all_roi_labels)
+
+        # Prepare the tile-mask dataset.
+        dts = TileMaskDataset(
+            dataset,
+            tile_px=self.tile_px,
+            tile_um=self.tile_um,
+            stride_div=self.stride,
+            crop_margin=self.crop_margin,
+            filter_method=self.filter_method,
+            roi_labels=all_roi_labels,
+            mode=self.mode
+        )
+
+        # Set the configuration.
+        config = sf.segment.SegmentConfig(
+            arch=self.arch,
+            encoder_name=self.encoder,
+            epochs=self.max_epochs,  # 100
+            mpp=self.mpp,
+            mode=self.mode,
+            out_classes=out_classes,
+            labels=(all_roi_labels if self.mode != 'binary' else None)
+        )
+
+        # Create dataloader.
+        train_dl = torch.utils.data.DataLoader(
+            dts,
+            batch_size=config.train_batch_size,
+            shuffle=True,
+            num_workers=4,
+            drop_last=True,
+            persistent_workers=True
+        )
+
+        # Build the model and trainer.
+        model = config.build_model()
+        trainer = pl.Trainer(
+            max_epochs=config.epochs,
+            devices=1,   # Distributed training not supported in a GUI.
+            num_nodes=1, # Distributed training not supported in a GUI.
+            callbacks=[ProgressCallback(self._training_toast, config.epochs)]
+        )
+
+        # Train the model.
+        trainer.fit(model, train_dataloaders=train_dl)
+
+        # Move model to eval & appropriate device.
+        model.eval()
+        model.to(get_device())
+
+        # Create the segment object.
+        self._segment = sf.slide.qc.StridedSegment.from_model(model, config)
+
+        # Cleanup.
+        self._training_toast.done()
+        self._training_toast = None
+        self.viz.create_toast("Training complete.", icon="success")
+
+    def _finetune(self) -> None:
+        """Finetune a segmentation model."""
+        import pytorch_lightning as pl
+
+        viz = self.viz
+        if not self._segment:
+            self.viz.create_toast("Cannot finetune; no model loaded.", icon="error")
+            return
+
+        # Prepare the dataset.
+        dataset = viz.P.dataset(filters={'slide': self.get_training_slides()})
+        dts = TileMaskDataset(
+            dataset,
+            tile_px=self.tile_px,
+            tile_um=self.tile_um,
+            stride_div=self.stride,
+            crop_margin=self.crop_margin,
+            filter_method=self.filter_method
+        )
+
+        # Set the configuration.
+        config = sf.segment.SegmentConfig(
+            arch=self.arch,
+            encoder_name=self.encoder,
+            epochs=self.max_epochs,  # 100
+            mpp=self.mpp,
+            mode=self.mode,
+        )
+
+        # Create dataloader.
+        train_dl = torch.utils.data.DataLoader(
+            dts,
+            batch_size=config.train_batch_size,
+            shuffle=True,
+            num_workers=4,
+            drop_last=True
+        )
+
+        # Build the model and trainer.
+        trainer = pl.Trainer(
+            max_epochs=config.epochs,
+            devices=1,   # Distributed training not supported in a GUI.
+            num_nodes=1, # Distributed training not supported in a GUI.
+            callbacks=[ProgressCallback(self._training_toast, config.epochs)]
+        )
+
+        # Train the model.
+        self._segment.model.train()
+        trainer.fit(self._segment.model, train_dataloaders=train_dl)
+
+        # Move model to eval & appropriate device.
+        self._segment.model.eval()
+        self._segment.model.to(get_device())
+
+        # Cleanup.
+        self._training_toast.done()
+        self._training_toast = None
+        self.viz.create_toast("Finetuning complete.", icon="success")
+
+    # --- Callbacks ---
+
+    def keyboard_callback(self, key: int, action: int) -> None:
+        """Handle keyboard events.
+
+        Args:
+            key (int): The key that was pressed. See ``glfw.KEY_*``.
+            action (int): The action that was performed (e.g. ``glfw.PRESS``,
+                ``glfw.RELEASE``, ``glfw.REPEAT``).
+
+        """
+        if (key == glfw.KEY_SPACE and action == glfw.PRESS and self.viz._control_down):
+            can_generate_rois = (
+                not self.is_thread_running()
+                and (self._segment is not None)
+                and (self.viz.wsi is not None)
+                and not self.is_training()
+            )
+            if can_generate_rois:
+                self.generate_rois()
+
+
+    # --- Drawing ---
+
+    def draw_info(self):
+        """Draw information about the loaded model."""
+        viz = self.viz
+
+        rows = [
+            ['Architecture', self.cfg.arch],
+            ['Encoder',      self.cfg.encoder_name],
+            ['Mode',         self.cfg.mode],
+            ['Classes',      self.cfg.out_classes],
+            ['MPP',          self.cfg.mpp, 'Microns per pixel (optical resolution)']
+        ]
+        imgui.text_colored('Model', *viz.theme.dim)
+        imgui.same_line(viz.font_size * 6)
+        model_path = self._segment.model_path or 'None'
+        with imgui_utils.clipped_with_tooltip(model_path, 22):
+            imgui.text(imgui_utils.ellipsis_clip(model_path, 22))
+        for y, cols in enumerate(rows):
+            for x, col in enumerate(cols):
+                if x != 0:
+                    imgui.same_line(viz.font_size * (6 + (x - 1) * 6))
+                if x == 0:
+                    imgui.text_colored(str(col), *viz.theme.dim)
+                    if len(cols) == 3 and imgui.is_item_hovered():
+                        imgui.set_tooltip(cols[2])
+                elif x == 1:
+                    imgui.text(str(col))
+
+        imgui.same_line(imgui.get_content_region_max()[0] - viz.font_size - viz.spacing * 2)
+        if imgui.button("HP"):
+            self._show_params = not self._show_params
+
+        imgui_utils.vertical_break()
+
+    def draw_train_data_source(self) -> None:
+        """Draw training data source options."""
+        viz = self.viz
+
+        # Slide sources
+        width = imgui.get_content_region_max()[0] - viz.spacing
+
+        changed = False
+        with imgui.begin_list_box("##segment_data_source", width, 150) as list_box:
+            if list_box.opened:
+                if self.viz.P is None:
+                    imgui.text("No project loaded.")
+                else:
+                    for slide_path in self.viz.project_widget.slide_paths:
+                        name = sf.util.path_to_name(slide_path)
+                        with self.viz.bold_font(self.viz.wsi is not None and slide_path == self.viz.wsi.path):
+                            _clicked, self._selected_slides[name] = imgui.selectable(name, self._selected_slides[name])
+                            if _clicked:
+                                changed = True
+                            if imgui.is_item_hovered():
+                                imgui.set_tooltip(slide_path)
+                                if imgui.is_mouse_down(RIGHT_MOUSE_BUTTON):
+                                    self._load_slide_popup = slide_path
+                                if imgui.is_mouse_double_clicked(LEFT_MOUSE_BUTTON):
+                                    self.viz.load_slide(slide_path)
+        if imgui_utils.button('Select All'):
+            changed = True
+            for name in self._selected_slides:
+                self._selected_slides[name] = True
+
+        imgui.same_line()
+        if imgui_utils.button('With ROIs'):
+            changed = True
+            _rois = [sf.util.path_to_name(r) for r in self.viz.P.dataset().rois()]
+            for name in self._selected_slides:
+                if name in _rois:
+                    self._selected_slides[name] = True
+                else:
+                    self._selected_slides[name] = False
+
+        imgui.same_line()
+        if imgui_utils.button('Select None'):
+            changed = True
+            for name in self._selected_slides:
+                self._selected_slides[name] = False
+
+        imgui.text("{} slides selected".format(sum(self._selected_slides.values())))
+
+        # Update the unique training classes.
+        if changed:
+            dataset = viz.P.dataset(filters={'slide': self.get_training_slides()}, verification=None)
+            _unique = dataset.get_unique_roi_labels(allow_empty=True)
+            _unique = [k if k is not None else '<No label>' for k in _unique]
+            self._unique_training_classes = {
+                k: (True if k not in self._unique_training_classes else self._unique_training_classes[k])
+                for k in _unique
+            }
+
+        imgui_utils.vertical_break()
+
+    def draw_class_selection(self) -> None:
+        """Draw class selection multi-select box."""
+        viz = self.viz
+        imgui.text_colored('Classes', *viz.theme.dim)
+        imgui.same_line(viz.label_w)
+
+        # Class selection
+        width = imgui.get_content_region_max()[0] - viz.spacing - viz.label_w
+        with imgui.begin_list_box("##segment_class_select", width, viz.font_size * 5) as list_box:
+            if list_box.opened:
+                for _class in self._unique_training_classes:
+                    _, self._unique_training_classes[_class] = imgui.selectable(_class, self._unique_training_classes[_class])
+
+        imgui.text('')
+        imgui.same_line(viz.label_w)
+        if imgui_utils.button('Select All##segment_class_select_all'):
+            for _class in self._unique_training_classes:
+                self._unique_training_classes[_class] = True
+
+        imgui.same_line()
+        if imgui_utils.button('Select None##segment_class_select_none'):
+            for _class in self._unique_training_classes:
+                self._unique_training_classes[_class] = False
+
+        imgui_utils.vertical_break()
+
+    def draw_train_data_processing(self) -> None:
+        """Draw training data processing options."""
+        viz = self.viz
+
+        # Tile size.
+        imgui.text_colored('Tile size', *viz.theme.dim)
+        imgui.same_line(viz.label_w)
+        with imgui_utils.item_width(viz.font_size * 3):
+            _, self.tile_px = imgui.input_int(
+                "##segment_tile_px",
+                self.tile_px,
+                step=0,
+            )
+        imgui.same_line()
+        imgui.text('px')
+        imgui.text('')
+        imgui.same_line(viz.label_w)
+        with imgui_utils.item_width(viz.font_size * 3):
+            _, self.tile_um = imgui.input_int(
+                "##segment_tile_um",
+                self.tile_um,
+                step=0,
+            )
+        imgui.same_line()
+        imgui.text('um')
+        imgui.same_line()
+        imgui.text('(MPP={:.2f})'.format(self.mpp))
+
+        # Crop margin.
+        imgui.text_colored('Margin', *viz.theme.dim)
+        if imgui.is_item_hovered():
+            imgui.set_tooltip("Margin for random cropping during training.")
+        imgui.same_line(viz.label_w)
+        with imgui_utils.item_width(viz.font_size * 6):
+            _, self.crop_margin = imgui.input_int(
+                "##segment_crop_margin",
+                self.crop_margin,
+                step=16,
+            )
+            self.crop_margin = max(0, self.crop_margin)
+        imgui.same_line()
+        imgui.text('px')
+
+        # Stride.
+        imgui.text_colored('Stride', *viz.theme.dim)
+        if imgui.is_item_hovered():
+            imgui.set_tooltip("Stride for tiling the slide.")
+        self.stride, self._capturing_stride, _ = stride_capture(
+            viz,
+            self.stride,
+            self._capturing_stride,
+            max_value=16,
+            label='Stride',
+            draw_label=False,
+            offset=viz.label_w,
+            width=imgui.get_content_region_max()[0] - viz.label_w - (viz.spacing)
+        )
+
+        # Filter method.
+        imgui.text_colored('Filter', *viz.theme.dim)
+        if imgui.is_item_hovered():
+            imgui.set_tooltip(
+                "Method for filtering tiles.\n"
+                "If 'otsu', tiles are filtered using Otsu's thresholding.\n"
+                "If 'roi', only tiles touching an ROI are used."
+            )
+        imgui.same_line(viz.label_w)
+        _, self._selected_filter_method = imgui.combo(
+            "##segment_filter_method",
+            self._selected_filter_method,
+            self._filter_methods
+        )
+
+        imgui_utils.vertical_break()
+
+    def draw_train_params(self) -> None:
+        """Draw training architecture & hyperparameter options."""
+        viz = self.viz
+
+        # === Architecture & training parameters ===
+        # Architecture.
+        imgui.text_colored('Arch', *viz.theme.dim)
+        if imgui.is_item_hovered():
+            imgui.set_tooltip("Model architecture")
+        imgui.same_line(viz.label_w)
+        _, self._selected_arch = imgui.combo(
+            "##segment_arch",
+            self._selected_arch,
+            self._supported_archs
+        )
+        # Encoder.
+        imgui.text_colored('Encoder', *viz.theme.dim)
+        imgui.same_line(viz.label_w)
+        _, self._selected_encoder = imgui.combo(
+            "##segment_encoder",
+            self._selected_encoder,
+            self._supported_encoders
+        )
+        # Training mode.
+        imgui.text_colored('Mode', *viz.theme.dim)
+        imgui.same_line(viz.label_w)
+        _, self._selected_training_mode = imgui.combo(
+            "##segment_training_mode",
+            self._selected_training_mode,
+            self._training_modes
+        )
+        # Max epochs.
+        imgui.text_colored('Epochs', *viz.theme.dim)
+        imgui.same_line(viz.label_w)
+        _, self.max_epochs = imgui.input_int(
+            "##segment_max_epochs",
+            self.max_epochs,
+            step=1,
+            step_fast=5
+        )
+        # Class selection (for multilabel and multiclass)
+        self.draw_class_selection()
+
+    def draw_training_button(self) -> None:
+        """Draw the training button."""
+        viz = self.viz
+        width = (self.viz.sidebar.content_width - (self.viz.spacing * 4)) / 3
+
+        # Train button.
+        _button_text = "Train" if not self.is_training() else "Training" + imgui_utils.spinner_text()
+        if viz.sidebar.full_button(_button_text, enabled=(sum(self._selected_slides.values()) and not self.is_training()), width=width):
+            self.train()
+        if imgui.is_item_hovered() and viz.P is None:
+            imgui.set_tooltip("No project loaded. Load a project to train a model.")
+
+        # Finetune button.
+        imgui.same_line()
+        if viz.sidebar.full_button2("Finetune", enabled=(sum(self._selected_slides.values()) and not self.is_training() and self._segment is not None), width=width):
+            self.finetune()
+        if imgui.is_item_hovered() and self._segment is None:
+            imgui.set_tooltip("No model loaded. Load a model to finetune.")
+        if imgui.is_item_hovered() and viz.P is None:
+            imgui.set_tooltip("No project loaded. Load a project to export a model.")
+
+        # Export button.
+        imgui.same_line()
+        if viz.sidebar.full_button2("Export", enabled=(self._segment is not None), width=width):
+            self.ask_export_model()
+        if imgui.is_item_hovered() and self._segment is None:
+            imgui.set_tooltip("No model loaded.")
+
+    def draw_apply(self) -> None:
+        """Show a button prompting the user to generate ROIs."""
+        viz = self.viz
+
+        # Label
+        imgui.text_colored('Min mm²', *viz.theme.dim)
+        if imgui.is_item_hovered():
+            imgui.set_tooltip("Filter out ROIs smaller than this area, in square millimeters.")
+
+        # Free input
+        imgui.same_line(viz.label_w)
+        with imgui_utils.item_width(viz.font_size * 3):
+            _changed, _val = imgui.input_float('##small_roi_filter_freetext', self._sq_mm_threshold, format='%.3f')
+            if _changed:
+                self._sq_mm_threshold = _val
+
+        # Slider
+        imgui.same_line(viz.label_w + viz.font_size * 3 + viz.spacing)
+        width = imgui.get_content_region_max()[0] - viz.label_w - viz.font_size * 3 - viz.spacing
+        with imgui_utils.item_width(width):
+            _changed, _val = imgui.slider_float(
+                '##small_roi_filter',
+                self._sq_mm_threshold,
+                min_value=0.0,
+                max_value=1.0,
+                format=''
+            )
+            if _changed:
+                self._sq_mm_threshold = _val
+
+        # Generate ROIs button
+        if viz.sidebar.full_button(
+            'Generate ROIs',
+            enabled=(
+                not self.is_thread_running()
+                and (self._segment is not None)
+                and (viz.wsi is not None)
+                and not self.is_training()
+            )
+        ):
+            self.generate_rois()
+
+    def draw_load_slide_popup(self):
+        viz = self.viz
+        if self._load_slide_popup:
+            if self._load_slide_popup_coords is None:
+                self._load_slide_popup_coords = self.viz.get_mouse_pos(scale=False)
+            cx, cy = self._load_slide_popup_coords
+            imgui.set_next_window_position(cx, cy)
+            imgui.begin(
+                '##segment_load_slide_popup',
+                flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)
+            )
+            if imgui.menu_item('Load')[0]:
+                viz.load_slide(self._load_slide_popup)
+                self._clicking = False
+                self._load_slide_popup = None
+                self._load_slide_popup_coords = None
+
+            # Hide menu if we click elsewhere
+            if imgui.is_mouse_down(LEFT_MOUSE_BUTTON) and not imgui.is_window_hovered():
+                self._clicking = True
+            if self._clicking and imgui.is_mouse_released(LEFT_MOUSE_BUTTON):
+                self._clicking = False
+                self._load_slide_popup = None
+                self._load_slide_popup_coords = None
+
+            imgui.end()
+
+
+    def draw_config_popup(self):
+        viz = self.viz
+        has_model = self._segment is not None
+
+        if self._show_popup:
+            cx, cy = imgui.get_cursor_pos()
+            imgui.set_next_window_position(viz.sidebar.full_width, cy)
+            imgui.begin(
+                '##segment_config_popup',
+                flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)
+            )
+            if imgui.menu_item('Load model', enabled=(not self.is_training()))[0]:
+                self.ask_load_model()
+                self._clicking = False
+                self._show_popup = False
+            if imgui.menu_item('Close model', enabled=has_model)[0]:
+                self.close_model()
+                self._clicking = False
+                self._show_popup = False
+
+            # Hide menu if we click elsewhere
+            if imgui.is_mouse_down(LEFT_MOUSE_BUTTON) and not imgui.is_window_hovered():
+                self._clicking = True
+            if self._clicking and imgui.is_mouse_released(LEFT_MOUSE_BUTTON):
+                self._clicking = False
+                self._show_popup = False
+
+            imgui.end()
+
+    @imgui_utils.scoped_by_object_id
+    def __call__(self, show=True):
+        viz = self.viz
+
+        if show:
+            with viz.header_with_buttons("Tissue Segmentation"):
+                imgui.same_line(imgui.get_content_region_max()[0] - viz.font_size*1.5)
+                cx, cy = imgui.get_cursor_pos()
+                imgui.set_cursor_position((cx, cy-int(viz.font_size*0.25)))
+                if viz.sidebar.small_button('gear'):
+                    self._clicking = False
+                    self._show_popup = not self._show_popup
+                self.draw_config_popup()
+
+        if show and self._segment is None:
+            imgui_utils.padded_text(
+                'Load or train a model.',
+                vpad=[int(viz.font_size/2),
+                      int(viz.font_size)]
+            )
+            if viz.sidebar.full_button("Load a Model", enabled=(not self.is_training())):
+                self.ask_load_model()
+            if imgui.is_item_hovered() and self.is_training():
+                imgui.set_tooltip("Cannot load model while training.")
+            imgui_utils.vertical_break()
+
+        elif show:
+            if viz.collapsing_header('Model Info', default=True):
+                self.draw_info()
+
+        if show:
+            if viz.collapsing_header('Training', default=False):
+
+                if viz.collapsing_header2('Data Source', default=False):
+                    self.draw_train_data_source()
+                    self.draw_load_slide_popup()
+
+                if viz.collapsing_header2('Data Processing', default=False):
+                    self.draw_train_data_processing()
+
+                if viz.collapsing_header2('Arch & Params', default=False):
+                    self.draw_train_params()
+
+                imgui_utils.vertical_break()
+                self.draw_training_button()
+
+                imgui_utils.vertical_break()
+
+            if viz.collapsing_header('Apply', default=True):
+                self.draw_apply()
+
+        # Refresh ROIs if necessary.
+        # Must be in the main loop.
+        if self._need_to_refresh_rois:
+            self._need_to_refresh_rois = False
+            viz.slide_widget.roi_widget.refresh_rois()
\ No newline at end of file