--- a +++ b/slideflow/studio/widgets/segment.py @@ -0,0 +1,843 @@ +import os +import torch +import slideflow as sf +import imgui +import glfw +import segmentation_models_pytorch as smp +from typing import Optional, List +from os.path import join, dirname, abspath, exists +from threading import Thread +from tkinter.filedialog import askopenfilename, askdirectory +from slideflow.segment import TileMaskDataset +from slideflow.model.torch_utils import get_device +from collections import defaultdict + +from ._utils import Widget +from ..gui import imgui_utils +from ..utils import LEFT_MOUSE_BUTTON, RIGHT_MOUSE_BUTTON +from .slide import stride_capture + +from pytorch_lightning.callbacks import Callback + +class ProgressCallback(Callback): + + def __init__(self, toast, max_epochs): + super().__init__() + self.toast = toast + self.max_epochs = max_epochs + + def on_train_epoch_end(self, trainer, pl_module): + percent = (trainer.current_epoch + 1) / self.max_epochs + self.toast.set_progress(min(percent, 1.)) + +# ---------------------------------------------------------------------------- + + +class TissueSegWidget(Widget): + + tag = 'segment' + description = 'Tissue Segmentation' + icon = join(dirname(abspath(__file__)), '..', 'gui', 'buttons', 'button_segment.png') + icon_highlighted = join(dirname(abspath(__file__)), '..', 'gui', 'buttons', 'button_segment_highlighted.png') + + def __init__(self, viz): + self.viz = viz + self._segment = None + self._thread = None + self._load_toast = None + self._working_toast = None + self._training_toast = None + self._show_params = False + self._rois_at_start = 0 + self._need_to_refresh_rois = False + self._clicking = False + self._show_popup = False + self._load_slide_popup = None + self._load_slide_popup_coords = None + + # Parameters + self._supported_archs = ['FPN', 'DeepLabV3', 'DeepLabV3Plus', 'Linknet', 'MAnet', 'PAN', 'PSPNet', 'Unet', 'UnetPlusPlus'] + self._selected_arch = 0 + self._supported_encoders = smp.encoders.get_encoder_names() + self._selected_encoder = self._supported_encoders.index('resnet34') + self._filter_methods = ['otsu', 'roi'] + self._selected_filter_method = 0 + self._training_modes = ['binary', 'multiclass', 'multilabel'] + self._selected_training_mode = 0 + self.max_epochs = 20 + self.tile_px = 1024 + self.tile_um = 2048 + self.crop_margin = 256 + self.stride = 1 + self._capturing_stride = 1 + self._selected_slides = defaultdict(bool) + self._unique_training_classes = dict() + self._sq_mm_threshold = 0.01 + + + # --- Properties --- + + @property + def cfg(self) -> sf.segment.SegmentConfig: + seg = self._segment + return None if seg is None else seg.cfg + + @property + def arch(self) -> str: + return self._supported_archs[self._selected_arch] + + @property + def encoder(self) -> str: + return self._supported_encoders[self._selected_encoder] + + @property + def mpp(self) -> float: + return self.tile_um / self.tile_px + + @property + def filter_method(self) -> str: + return self._filter_methods[self._selected_filter_method] + + @property + def mode(self) -> str: + return self._training_modes[self._selected_training_mode] + + # --- Internal --- + + def get_training_slides(self) -> List[str]: + return [slide for slide in list(self._selected_slides.keys()) + if self._selected_slides[slide]] + + def get_training_classes(self) -> List[str]: + return [(k if k != '<No label>' else None) + for k, v in self._unique_training_classes.items() if v] + + def close(self): + pass + + def is_thread_running(self): + return self._thread is not None and self._thread.is_alive() + + def is_training(self): + return self._training_toast is not None + + def drag_and_drop_hook(self, path, ignore_errors=False) -> bool: + """Handle file paths provided via drag-and-drop.""" + if (sf.util.path_to_ext(path).lower() == 'pth'): + if exists(join(dirname(path), 'segment_params.json')): + self.load(path, ignore_errors=ignore_errors) + return True + return False + + # --- Model loading --- + + def ask_load_model(self) -> str: + model_path = askopenfilename( + title="Load model...", + filetypes=[("pth", ".pth"), ("All files", ".*")] + ) + if model_path: + self.load(model_path) + + def ask_export_model(self) -> Optional[str]: + destination = askdirectory( + title="Export model (choose directory)..." + ) + if destination: + model_path = sf.util.get_new_model_dir(destination, 'segment') + self.export(model_path) + return model_path + + def export(self, path: str) -> None: + """Export a tissue segmentation model.""" + if self._segment is None: + return + if not exists(path): + os.makedirs(path) + model_path = join(path, 'model.pth') + torch.save(self._segment.model.state_dict(), model_path) + self._segment.cfg.to_json(join(path, 'segment_params.json')) + self._segment.model_path = model_path + self.viz.create_toast(f"Model exported to {model_path}", icon="success") + + def load(self, path, ignore_errors=False): + """Load a tissue segmentation model.""" + if self.is_thread_running(): + self._thread.join() + self._load_toast = self.viz.create_toast( + title=f"Loading segmentation model", + icon='info', + sticky=True, + spinner=True) + self._thread = Thread(target=self._load_model, args=(path, ignore_errors)) + self._thread.start() + + def _load_model(self, path, ignore_errors=False): + try: + self._segment = sf.slide.qc.StridedSegment(path) + self._segment.model.to(get_device()) + except Exception as e: + if self._load_toast is not None: + self._load_toast.done() + sf.log.error(f"Error loading segment model: {e}") + self.viz.create_toast(f"Error loading segment model: {e}", icon="error") + self._segment = None + else: + if self._load_toast is not None: + self._load_toast.done() + self.viz.create_toast( + f"Loaded model at {path}.", + icon="success" + ) + + def close_model(self) -> None: + self._segment = None + + def generate_rois(self): + """Generate ROIs from the loaded segmentation model.""" + if self.is_thread_running(): + self.viz.create_toast("Failed to start thread.", icon="error") + return + self._rois_at_start = len(self.viz.wsi.rois) + self._working_toast = self.viz.create_toast( + title="Generating ROIs", + message=f"Generating ROIs from segmentation model.", + icon='info', + sticky=True, + spinner=True) + self._thread = Thread(target=self._generate_rois) + self._thread.start() + + def _generate_rois(self): + viz = self.viz + self._segment.generate_rois( + viz.wsi, + sq_mm_threshold=self._sq_mm_threshold, + simplify_tolerance=5 + ) + self._need_to_refresh_rois = True + if self._working_toast is not None: + self._working_toast.done() + viz.create_toast( + "Generated {} ROIs.".format( + len(self.viz.wsi.rois) - self._rois_at_start + ), + icon="success" + ) + + def train(self) -> None: + """Train a segmentation model.""" + if self.is_thread_running(): + self.viz.create_toast("Failed to start thread.", icon="error") + return + + # Create a progress toast. + if self._training_toast is not None: + self._training_toast.done() + self._training_toast = self.viz.create_toast( + title="Training segmentation model", + icon='info', + sticky=True, + progress=True, + spinner=True + ) + self._thread = Thread(target=self._train) + self._thread.start() + + def finetune(self) -> None: + """Finetune a segmentation model.""" + if self.is_thread_running(): + self.viz.create_toast("Failed to start thread.", icon="error") + return + + # Create a progress toast. + if self._training_toast is not None: + self._training_toast.done() + self._training_toast = self.viz.create_toast( + title="Finetuning segmentation model", + icon='info', + sticky=True, + progress=True, + spinner=True + ) + self._thread = Thread(target=self._finetune) + self._thread.start() + + def _train(self) -> None: + """Train a segmentation model.""" + import pytorch_lightning as pl + + viz = self.viz + + # Prepare the slideflow dataset. + dataset = viz.P.dataset(filters={'slide': self.get_training_slides()}) + + # Determine the labels, if necessary. + all_roi_labels = self.get_training_classes() + if self.mode == 'binary': + out_classes = 1 + elif self.mode == 'multiclass': + out_classes = len(all_roi_labels) + 1 + else: + out_classes = len(all_roi_labels) + + # Prepare the tile-mask dataset. + dts = TileMaskDataset( + dataset, + tile_px=self.tile_px, + tile_um=self.tile_um, + stride_div=self.stride, + crop_margin=self.crop_margin, + filter_method=self.filter_method, + roi_labels=all_roi_labels, + mode=self.mode + ) + + # Set the configuration. + config = sf.segment.SegmentConfig( + arch=self.arch, + encoder_name=self.encoder, + epochs=self.max_epochs, # 100 + mpp=self.mpp, + mode=self.mode, + out_classes=out_classes, + labels=(all_roi_labels if self.mode != 'binary' else None) + ) + + # Create dataloader. + train_dl = torch.utils.data.DataLoader( + dts, + batch_size=config.train_batch_size, + shuffle=True, + num_workers=4, + drop_last=True, + persistent_workers=True + ) + + # Build the model and trainer. + model = config.build_model() + trainer = pl.Trainer( + max_epochs=config.epochs, + devices=1, # Distributed training not supported in a GUI. + num_nodes=1, # Distributed training not supported in a GUI. + callbacks=[ProgressCallback(self._training_toast, config.epochs)] + ) + + # Train the model. + trainer.fit(model, train_dataloaders=train_dl) + + # Move model to eval & appropriate device. + model.eval() + model.to(get_device()) + + # Create the segment object. + self._segment = sf.slide.qc.StridedSegment.from_model(model, config) + + # Cleanup. + self._training_toast.done() + self._training_toast = None + self.viz.create_toast("Training complete.", icon="success") + + def _finetune(self) -> None: + """Finetune a segmentation model.""" + import pytorch_lightning as pl + + viz = self.viz + if not self._segment: + self.viz.create_toast("Cannot finetune; no model loaded.", icon="error") + return + + # Prepare the dataset. + dataset = viz.P.dataset(filters={'slide': self.get_training_slides()}) + dts = TileMaskDataset( + dataset, + tile_px=self.tile_px, + tile_um=self.tile_um, + stride_div=self.stride, + crop_margin=self.crop_margin, + filter_method=self.filter_method + ) + + # Set the configuration. + config = sf.segment.SegmentConfig( + arch=self.arch, + encoder_name=self.encoder, + epochs=self.max_epochs, # 100 + mpp=self.mpp, + mode=self.mode, + ) + + # Create dataloader. + train_dl = torch.utils.data.DataLoader( + dts, + batch_size=config.train_batch_size, + shuffle=True, + num_workers=4, + drop_last=True + ) + + # Build the model and trainer. + trainer = pl.Trainer( + max_epochs=config.epochs, + devices=1, # Distributed training not supported in a GUI. + num_nodes=1, # Distributed training not supported in a GUI. + callbacks=[ProgressCallback(self._training_toast, config.epochs)] + ) + + # Train the model. + self._segment.model.train() + trainer.fit(self._segment.model, train_dataloaders=train_dl) + + # Move model to eval & appropriate device. + self._segment.model.eval() + self._segment.model.to(get_device()) + + # Cleanup. + self._training_toast.done() + self._training_toast = None + self.viz.create_toast("Finetuning complete.", icon="success") + + # --- Callbacks --- + + def keyboard_callback(self, key: int, action: int) -> None: + """Handle keyboard events. + + Args: + key (int): The key that was pressed. See ``glfw.KEY_*``. + action (int): The action that was performed (e.g. ``glfw.PRESS``, + ``glfw.RELEASE``, ``glfw.REPEAT``). + + """ + if (key == glfw.KEY_SPACE and action == glfw.PRESS and self.viz._control_down): + can_generate_rois = ( + not self.is_thread_running() + and (self._segment is not None) + and (self.viz.wsi is not None) + and not self.is_training() + ) + if can_generate_rois: + self.generate_rois() + + + # --- Drawing --- + + def draw_info(self): + """Draw information about the loaded model.""" + viz = self.viz + + rows = [ + ['Architecture', self.cfg.arch], + ['Encoder', self.cfg.encoder_name], + ['Mode', self.cfg.mode], + ['Classes', self.cfg.out_classes], + ['MPP', self.cfg.mpp, 'Microns per pixel (optical resolution)'] + ] + imgui.text_colored('Model', *viz.theme.dim) + imgui.same_line(viz.font_size * 6) + model_path = self._segment.model_path or 'None' + with imgui_utils.clipped_with_tooltip(model_path, 22): + imgui.text(imgui_utils.ellipsis_clip(model_path, 22)) + for y, cols in enumerate(rows): + for x, col in enumerate(cols): + if x != 0: + imgui.same_line(viz.font_size * (6 + (x - 1) * 6)) + if x == 0: + imgui.text_colored(str(col), *viz.theme.dim) + if len(cols) == 3 and imgui.is_item_hovered(): + imgui.set_tooltip(cols[2]) + elif x == 1: + imgui.text(str(col)) + + imgui.same_line(imgui.get_content_region_max()[0] - viz.font_size - viz.spacing * 2) + if imgui.button("HP"): + self._show_params = not self._show_params + + imgui_utils.vertical_break() + + def draw_train_data_source(self) -> None: + """Draw training data source options.""" + viz = self.viz + + # Slide sources + width = imgui.get_content_region_max()[0] - viz.spacing + + changed = False + with imgui.begin_list_box("##segment_data_source", width, 150) as list_box: + if list_box.opened: + if self.viz.P is None: + imgui.text("No project loaded.") + else: + for slide_path in self.viz.project_widget.slide_paths: + name = sf.util.path_to_name(slide_path) + with self.viz.bold_font(self.viz.wsi is not None and slide_path == self.viz.wsi.path): + _clicked, self._selected_slides[name] = imgui.selectable(name, self._selected_slides[name]) + if _clicked: + changed = True + if imgui.is_item_hovered(): + imgui.set_tooltip(slide_path) + if imgui.is_mouse_down(RIGHT_MOUSE_BUTTON): + self._load_slide_popup = slide_path + if imgui.is_mouse_double_clicked(LEFT_MOUSE_BUTTON): + self.viz.load_slide(slide_path) + if imgui_utils.button('Select All'): + changed = True + for name in self._selected_slides: + self._selected_slides[name] = True + + imgui.same_line() + if imgui_utils.button('With ROIs'): + changed = True + _rois = [sf.util.path_to_name(r) for r in self.viz.P.dataset().rois()] + for name in self._selected_slides: + if name in _rois: + self._selected_slides[name] = True + else: + self._selected_slides[name] = False + + imgui.same_line() + if imgui_utils.button('Select None'): + changed = True + for name in self._selected_slides: + self._selected_slides[name] = False + + imgui.text("{} slides selected".format(sum(self._selected_slides.values()))) + + # Update the unique training classes. + if changed: + dataset = viz.P.dataset(filters={'slide': self.get_training_slides()}, verification=None) + _unique = dataset.get_unique_roi_labels(allow_empty=True) + _unique = [k if k is not None else '<No label>' for k in _unique] + self._unique_training_classes = { + k: (True if k not in self._unique_training_classes else self._unique_training_classes[k]) + for k in _unique + } + + imgui_utils.vertical_break() + + def draw_class_selection(self) -> None: + """Draw class selection multi-select box.""" + viz = self.viz + imgui.text_colored('Classes', *viz.theme.dim) + imgui.same_line(viz.label_w) + + # Class selection + width = imgui.get_content_region_max()[0] - viz.spacing - viz.label_w + with imgui.begin_list_box("##segment_class_select", width, viz.font_size * 5) as list_box: + if list_box.opened: + for _class in self._unique_training_classes: + _, self._unique_training_classes[_class] = imgui.selectable(_class, self._unique_training_classes[_class]) + + imgui.text('') + imgui.same_line(viz.label_w) + if imgui_utils.button('Select All##segment_class_select_all'): + for _class in self._unique_training_classes: + self._unique_training_classes[_class] = True + + imgui.same_line() + if imgui_utils.button('Select None##segment_class_select_none'): + for _class in self._unique_training_classes: + self._unique_training_classes[_class] = False + + imgui_utils.vertical_break() + + def draw_train_data_processing(self) -> None: + """Draw training data processing options.""" + viz = self.viz + + # Tile size. + imgui.text_colored('Tile size', *viz.theme.dim) + imgui.same_line(viz.label_w) + with imgui_utils.item_width(viz.font_size * 3): + _, self.tile_px = imgui.input_int( + "##segment_tile_px", + self.tile_px, + step=0, + ) + imgui.same_line() + imgui.text('px') + imgui.text('') + imgui.same_line(viz.label_w) + with imgui_utils.item_width(viz.font_size * 3): + _, self.tile_um = imgui.input_int( + "##segment_tile_um", + self.tile_um, + step=0, + ) + imgui.same_line() + imgui.text('um') + imgui.same_line() + imgui.text('(MPP={:.2f})'.format(self.mpp)) + + # Crop margin. + imgui.text_colored('Margin', *viz.theme.dim) + if imgui.is_item_hovered(): + imgui.set_tooltip("Margin for random cropping during training.") + imgui.same_line(viz.label_w) + with imgui_utils.item_width(viz.font_size * 6): + _, self.crop_margin = imgui.input_int( + "##segment_crop_margin", + self.crop_margin, + step=16, + ) + self.crop_margin = max(0, self.crop_margin) + imgui.same_line() + imgui.text('px') + + # Stride. + imgui.text_colored('Stride', *viz.theme.dim) + if imgui.is_item_hovered(): + imgui.set_tooltip("Stride for tiling the slide.") + self.stride, self._capturing_stride, _ = stride_capture( + viz, + self.stride, + self._capturing_stride, + max_value=16, + label='Stride', + draw_label=False, + offset=viz.label_w, + width=imgui.get_content_region_max()[0] - viz.label_w - (viz.spacing) + ) + + # Filter method. + imgui.text_colored('Filter', *viz.theme.dim) + if imgui.is_item_hovered(): + imgui.set_tooltip( + "Method for filtering tiles.\n" + "If 'otsu', tiles are filtered using Otsu's thresholding.\n" + "If 'roi', only tiles touching an ROI are used." + ) + imgui.same_line(viz.label_w) + _, self._selected_filter_method = imgui.combo( + "##segment_filter_method", + self._selected_filter_method, + self._filter_methods + ) + + imgui_utils.vertical_break() + + def draw_train_params(self) -> None: + """Draw training architecture & hyperparameter options.""" + viz = self.viz + + # === Architecture & training parameters === + # Architecture. + imgui.text_colored('Arch', *viz.theme.dim) + if imgui.is_item_hovered(): + imgui.set_tooltip("Model architecture") + imgui.same_line(viz.label_w) + _, self._selected_arch = imgui.combo( + "##segment_arch", + self._selected_arch, + self._supported_archs + ) + # Encoder. + imgui.text_colored('Encoder', *viz.theme.dim) + imgui.same_line(viz.label_w) + _, self._selected_encoder = imgui.combo( + "##segment_encoder", + self._selected_encoder, + self._supported_encoders + ) + # Training mode. + imgui.text_colored('Mode', *viz.theme.dim) + imgui.same_line(viz.label_w) + _, self._selected_training_mode = imgui.combo( + "##segment_training_mode", + self._selected_training_mode, + self._training_modes + ) + # Max epochs. + imgui.text_colored('Epochs', *viz.theme.dim) + imgui.same_line(viz.label_w) + _, self.max_epochs = imgui.input_int( + "##segment_max_epochs", + self.max_epochs, + step=1, + step_fast=5 + ) + # Class selection (for multilabel and multiclass) + self.draw_class_selection() + + def draw_training_button(self) -> None: + """Draw the training button.""" + viz = self.viz + width = (self.viz.sidebar.content_width - (self.viz.spacing * 4)) / 3 + + # Train button. + _button_text = "Train" if not self.is_training() else "Training" + imgui_utils.spinner_text() + if viz.sidebar.full_button(_button_text, enabled=(sum(self._selected_slides.values()) and not self.is_training()), width=width): + self.train() + if imgui.is_item_hovered() and viz.P is None: + imgui.set_tooltip("No project loaded. Load a project to train a model.") + + # Finetune button. + imgui.same_line() + if viz.sidebar.full_button2("Finetune", enabled=(sum(self._selected_slides.values()) and not self.is_training() and self._segment is not None), width=width): + self.finetune() + if imgui.is_item_hovered() and self._segment is None: + imgui.set_tooltip("No model loaded. Load a model to finetune.") + if imgui.is_item_hovered() and viz.P is None: + imgui.set_tooltip("No project loaded. Load a project to export a model.") + + # Export button. + imgui.same_line() + if viz.sidebar.full_button2("Export", enabled=(self._segment is not None), width=width): + self.ask_export_model() + if imgui.is_item_hovered() and self._segment is None: + imgui.set_tooltip("No model loaded.") + + def draw_apply(self) -> None: + """Show a button prompting the user to generate ROIs.""" + viz = self.viz + + # Label + imgui.text_colored('Min mm²', *viz.theme.dim) + if imgui.is_item_hovered(): + imgui.set_tooltip("Filter out ROIs smaller than this area, in square millimeters.") + + # Free input + imgui.same_line(viz.label_w) + with imgui_utils.item_width(viz.font_size * 3): + _changed, _val = imgui.input_float('##small_roi_filter_freetext', self._sq_mm_threshold, format='%.3f') + if _changed: + self._sq_mm_threshold = _val + + # Slider + imgui.same_line(viz.label_w + viz.font_size * 3 + viz.spacing) + width = imgui.get_content_region_max()[0] - viz.label_w - viz.font_size * 3 - viz.spacing + with imgui_utils.item_width(width): + _changed, _val = imgui.slider_float( + '##small_roi_filter', + self._sq_mm_threshold, + min_value=0.0, + max_value=1.0, + format='' + ) + if _changed: + self._sq_mm_threshold = _val + + # Generate ROIs button + if viz.sidebar.full_button( + 'Generate ROIs', + enabled=( + not self.is_thread_running() + and (self._segment is not None) + and (viz.wsi is not None) + and not self.is_training() + ) + ): + self.generate_rois() + + def draw_load_slide_popup(self): + viz = self.viz + if self._load_slide_popup: + if self._load_slide_popup_coords is None: + self._load_slide_popup_coords = self.viz.get_mouse_pos(scale=False) + cx, cy = self._load_slide_popup_coords + imgui.set_next_window_position(cx, cy) + imgui.begin( + '##segment_load_slide_popup', + flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE) + ) + if imgui.menu_item('Load')[0]: + viz.load_slide(self._load_slide_popup) + self._clicking = False + self._load_slide_popup = None + self._load_slide_popup_coords = None + + # Hide menu if we click elsewhere + if imgui.is_mouse_down(LEFT_MOUSE_BUTTON) and not imgui.is_window_hovered(): + self._clicking = True + if self._clicking and imgui.is_mouse_released(LEFT_MOUSE_BUTTON): + self._clicking = False + self._load_slide_popup = None + self._load_slide_popup_coords = None + + imgui.end() + + + def draw_config_popup(self): + viz = self.viz + has_model = self._segment is not None + + if self._show_popup: + cx, cy = imgui.get_cursor_pos() + imgui.set_next_window_position(viz.sidebar.full_width, cy) + imgui.begin( + '##segment_config_popup', + flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE) + ) + if imgui.menu_item('Load model', enabled=(not self.is_training()))[0]: + self.ask_load_model() + self._clicking = False + self._show_popup = False + if imgui.menu_item('Close model', enabled=has_model)[0]: + self.close_model() + self._clicking = False + self._show_popup = False + + # Hide menu if we click elsewhere + if imgui.is_mouse_down(LEFT_MOUSE_BUTTON) and not imgui.is_window_hovered(): + self._clicking = True + if self._clicking and imgui.is_mouse_released(LEFT_MOUSE_BUTTON): + self._clicking = False + self._show_popup = False + + imgui.end() + + @imgui_utils.scoped_by_object_id + def __call__(self, show=True): + viz = self.viz + + if show: + with viz.header_with_buttons("Tissue Segmentation"): + imgui.same_line(imgui.get_content_region_max()[0] - viz.font_size*1.5) + cx, cy = imgui.get_cursor_pos() + imgui.set_cursor_position((cx, cy-int(viz.font_size*0.25))) + if viz.sidebar.small_button('gear'): + self._clicking = False + self._show_popup = not self._show_popup + self.draw_config_popup() + + if show and self._segment is None: + imgui_utils.padded_text( + 'Load or train a model.', + vpad=[int(viz.font_size/2), + int(viz.font_size)] + ) + if viz.sidebar.full_button("Load a Model", enabled=(not self.is_training())): + self.ask_load_model() + if imgui.is_item_hovered() and self.is_training(): + imgui.set_tooltip("Cannot load model while training.") + imgui_utils.vertical_break() + + elif show: + if viz.collapsing_header('Model Info', default=True): + self.draw_info() + + if show: + if viz.collapsing_header('Training', default=False): + + if viz.collapsing_header2('Data Source', default=False): + self.draw_train_data_source() + self.draw_load_slide_popup() + + if viz.collapsing_header2('Data Processing', default=False): + self.draw_train_data_processing() + + if viz.collapsing_header2('Arch & Params', default=False): + self.draw_train_params() + + imgui_utils.vertical_break() + self.draw_training_button() + + imgui_utils.vertical_break() + + if viz.collapsing_header('Apply', default=True): + self.draw_apply() + + # Refresh ROIs if necessary. + # Must be in the main loop. + if self._need_to_refresh_rois: + self._need_to_refresh_rois = False + viz.slide_widget.roi_widget.refresh_rois() \ No newline at end of file