Switch to side-by-side view

--- a
+++ b/slideflow/studio/_render_manager.py
@@ -0,0 +1,260 @@
+import multiprocessing
+import slideflow as sf
+from typing import Optional
+
+from ._renderer import Renderer, CapturedException
+
+#----------------------------------------------------------------------------
+
+class AsyncRenderManager:
+
+    """Manager to assist with rendering tile-level model predictions."""
+
+    def __init__(self):
+        self._closed        = False
+        self._is_async      = False
+        self._cur_args      = None
+        self._cur_result    = None
+        self._cur_stamp     = 0
+        self._renderer_obj  = None
+        self._args_queue    = None
+        self._result_queue  = None
+        self._process       = None
+        self._model_path    = None
+        self._live_updates  = False
+        self.tile_px        = None
+        self.extract_px     = None
+        self._addl_render   = []
+        self._set_device()
+
+    def _set_device(self) -> None:
+        """Set the device for the renderer."""
+        if sf.util.torch_available:
+            from slideflow.model import torch_utils
+            self.device = torch_utils.get_device()
+        else:
+            self.device = None
+
+    def close(self) -> None:
+        """Close the renderer."""
+        self._closed = True
+        self._renderer_obj = None
+        if self._process is not None:
+            self._process.terminate()
+        self._process = None
+        self._args_queue = None
+        self._result_queue = None
+
+    @property
+    def is_async(self) -> bool:
+        """Return whether the renderer is in asynchronous mode.
+
+        Returns:
+            bool: Whether the renderer is in asynchronous mode.
+
+        """
+        return self._is_async
+
+    def set_renderer(self, renderer_class: type, **kwargs) -> None:
+        """Set the renderer class for the renderer.
+
+        Args:
+            renderer_class (type): Renderer class to use.
+
+        """
+        assert not self._closed
+        if self.is_async:
+            self._set_args_async(set_renderer=(renderer_class, kwargs))
+        else:
+            self._renderer_obj = renderer_class(device=self.device, **kwargs)
+            for _renderer in self._addl_render:
+                self._renderer_obj.add_renderer(_renderer)
+
+    def close_renderer(self) -> None:
+        if self.is_async:
+            self._set_args_async(close_renderer=True)
+        else:
+            self._renderer_obj = None
+
+    def add_to_render_pipeline(self, renderer: Renderer) -> None:
+        """Add a renderer to the rendering pipeline.
+
+        Args:
+            renderer (Renderer): Renderer to add to the pipeline.
+                This renderer will be triggered before the main renderer.
+
+        Raises:
+            ValueError: If the renderer is in asynchronous mode.
+
+        """
+        if self.is_async:
+            raise ValueError("Cannot add to rendering pipeline when in "
+                             "asynchronous mode.")
+        self._addl_render += [renderer]
+        if self._renderer_obj is not None:
+            self._renderer_obj.add_renderer(renderer)
+
+    def remove_from_render_pipeline(self, renderer: Renderer) -> None:
+        """Remove a renderer from the rendering pipeline.
+
+        Args:
+            renderer (Renderer): Renderer to remove from the pipeline.
+
+        Raises:
+            ValueError: If the renderer is in asynchronous mode.
+
+        """
+        if self.is_async:
+            raise ValueError("Cannot remove rendering pipeline when in "
+                             "asynchronous mode.")
+        idx = self._addl_render.index(renderer)
+        del self._addl_render[idx]
+        if self._renderer_obj is not None:
+            self._renderer_obj.remove_renderer(renderer)
+
+    def set_async(self, is_async):
+        """Set the renderer to synchronous or asynchronous mode.
+
+        Args:
+            is_async (bool): Whether to set the renderer to asynchronous mode.
+
+        """
+        self._is_async = is_async
+
+    def set_args(self, **args):
+        """Set the arguments for the renderer."""
+        assert not self._closed
+        if args != self._cur_args or self._live_updates:
+            if self._is_async:
+                self._set_args_async(**args)
+            else:
+                self._set_args_sync(**args)
+            if not self._live_updates:
+                self._cur_args = args
+
+    def _set_args_async(self, **args):
+        """Set the arguments for the renderer in asynchronous mode."""
+        if self._process is None:
+            ctx = multiprocessing.get_context('spawn')
+            self._args_queue = ctx.Queue()
+            self._result_queue = ctx.Queue()
+            self._process = ctx.Process(target=self._process_fn,
+                                        args=(self._args_queue,
+                                              self._result_queue,
+                                              self._model_path,
+                                              self._live_updates),
+                                        daemon=True)
+            self._process.start()
+        self._args_queue.put([args, self._cur_stamp])
+
+    def _set_args_sync(self, **args):
+        """Set the arguments for the renderer in synchronous mode."""
+        if self._renderer_obj is None:
+            self._renderer_obj = Renderer(device=self.device)
+            for _renderer in self._addl_render:
+                self._renderer_obj.add_renderer(_renderer)
+            self._renderer_obj._model = self._model
+            self._renderer_obj._saliency = self._saliency
+        self._cur_result = self._renderer_obj.render(**args)
+
+    def get_result(self):
+        """Get the result of the renderer.
+
+        Returns:
+            EasyDict: The result of the renderer.
+
+        """
+        assert not self._closed
+        if self._result_queue is not None:
+            while self._result_queue.qsize() > 0:
+                result, stamp = self._result_queue.get()
+                if stamp == self._cur_stamp:
+                    self._cur_result = result
+        return self._cur_result
+
+    def clear_result(self):
+        """Clear the result of the renderer."""
+        assert not self._closed
+        self._cur_args = None
+        self._cur_result = None
+        self._cur_stamp += 1
+
+    def load_model(self, model_path: str) -> None:
+        """Load a model for the renderer.
+
+        Args:
+            model_path (str): Path to the model.
+
+        """
+        if self._is_async:
+            self._set_args_async(load_model=model_path)
+        elif model_path != self._model_path:
+            self._model_path = model_path
+            if self._renderer_obj is None:
+                self._renderer_obj = Renderer(device=self.device)
+                for _renderer in self._addl_render:
+                    self._renderer_obj.add_renderer(_renderer)
+            self._renderer_obj.load_model(model_path, device=self.device)
+
+    def clear_model(self):
+        """Clear the model for the renderer."""
+        self._model_path = None
+        if self._renderer_obj is not None:
+            self._renderer_obj._umap_encoders = None
+            self._renderer_obj._model = None
+            self._renderer_obj._saliency = None
+
+    @property
+    def _model(self):
+        if self._renderer_obj is not None:
+            return self._renderer_obj._model
+        else:
+            return None
+
+    @property
+    def _saliency(self):
+        if self._renderer_obj is not None:
+            return self._renderer_obj._saliency
+        else:
+            return None
+
+    @property
+    def _umap_encoders(self):
+        if self._renderer_obj is not None:
+            return self._renderer_obj._umap_encoders
+        else:
+            return None
+
+    @staticmethod
+    def _process_fn(
+        args_queue: multiprocessing.Queue,
+        result_queue: multiprocessing.Queue,
+        model_path: Optional[str] = None,
+        live_updates: bool = False
+    ):
+        if sf.util.torch_available:
+            from slideflow.model import torch_utils
+            device = torch_utils.get_device()
+        else:
+            device = None
+        renderer_obj = Renderer(device=device)
+        if model_path:
+            renderer_obj.load_model(model_path, device=device)
+        while True:
+            while args_queue.qsize() > 0:
+                args, stamp = args_queue.get()
+                if 'close_renderer' in args:
+                    renderer_obj = Renderer(device=device)
+                if 'set_renderer' in args:
+                    renderer_class, kwargs = args['set_renderer']
+                    renderer_obj = renderer_class(**kwargs)
+                if 'load_model' in args:
+                    renderer_obj.load_model(args['load_model'], device=device)
+                if 'quit' in args:
+                    return
+            if (live_updates and not result_queue.qsize()):
+                result = renderer_obj.render(**args)
+                if 'error' in result:
+                    result.error = CapturedException(result.error)
+
+                result_queue.put([result, stamp])