--- a +++ b/slideflow/studio/_render_manager.py @@ -0,0 +1,260 @@ +import multiprocessing +import slideflow as sf +from typing import Optional + +from ._renderer import Renderer, CapturedException + +#---------------------------------------------------------------------------- + +class AsyncRenderManager: + + """Manager to assist with rendering tile-level model predictions.""" + + def __init__(self): + self._closed = False + self._is_async = False + self._cur_args = None + self._cur_result = None + self._cur_stamp = 0 + self._renderer_obj = None + self._args_queue = None + self._result_queue = None + self._process = None + self._model_path = None + self._live_updates = False + self.tile_px = None + self.extract_px = None + self._addl_render = [] + self._set_device() + + def _set_device(self) -> None: + """Set the device for the renderer.""" + if sf.util.torch_available: + from slideflow.model import torch_utils + self.device = torch_utils.get_device() + else: + self.device = None + + def close(self) -> None: + """Close the renderer.""" + self._closed = True + self._renderer_obj = None + if self._process is not None: + self._process.terminate() + self._process = None + self._args_queue = None + self._result_queue = None + + @property + def is_async(self) -> bool: + """Return whether the renderer is in asynchronous mode. + + Returns: + bool: Whether the renderer is in asynchronous mode. + + """ + return self._is_async + + def set_renderer(self, renderer_class: type, **kwargs) -> None: + """Set the renderer class for the renderer. + + Args: + renderer_class (type): Renderer class to use. + + """ + assert not self._closed + if self.is_async: + self._set_args_async(set_renderer=(renderer_class, kwargs)) + else: + self._renderer_obj = renderer_class(device=self.device, **kwargs) + for _renderer in self._addl_render: + self._renderer_obj.add_renderer(_renderer) + + def close_renderer(self) -> None: + if self.is_async: + self._set_args_async(close_renderer=True) + else: + self._renderer_obj = None + + def add_to_render_pipeline(self, renderer: Renderer) -> None: + """Add a renderer to the rendering pipeline. + + Args: + renderer (Renderer): Renderer to add to the pipeline. + This renderer will be triggered before the main renderer. + + Raises: + ValueError: If the renderer is in asynchronous mode. + + """ + if self.is_async: + raise ValueError("Cannot add to rendering pipeline when in " + "asynchronous mode.") + self._addl_render += [renderer] + if self._renderer_obj is not None: + self._renderer_obj.add_renderer(renderer) + + def remove_from_render_pipeline(self, renderer: Renderer) -> None: + """Remove a renderer from the rendering pipeline. + + Args: + renderer (Renderer): Renderer to remove from the pipeline. + + Raises: + ValueError: If the renderer is in asynchronous mode. + + """ + if self.is_async: + raise ValueError("Cannot remove rendering pipeline when in " + "asynchronous mode.") + idx = self._addl_render.index(renderer) + del self._addl_render[idx] + if self._renderer_obj is not None: + self._renderer_obj.remove_renderer(renderer) + + def set_async(self, is_async): + """Set the renderer to synchronous or asynchronous mode. + + Args: + is_async (bool): Whether to set the renderer to asynchronous mode. + + """ + self._is_async = is_async + + def set_args(self, **args): + """Set the arguments for the renderer.""" + assert not self._closed + if args != self._cur_args or self._live_updates: + if self._is_async: + self._set_args_async(**args) + else: + self._set_args_sync(**args) + if not self._live_updates: + self._cur_args = args + + def _set_args_async(self, **args): + """Set the arguments for the renderer in asynchronous mode.""" + if self._process is None: + ctx = multiprocessing.get_context('spawn') + self._args_queue = ctx.Queue() + self._result_queue = ctx.Queue() + self._process = ctx.Process(target=self._process_fn, + args=(self._args_queue, + self._result_queue, + self._model_path, + self._live_updates), + daemon=True) + self._process.start() + self._args_queue.put([args, self._cur_stamp]) + + def _set_args_sync(self, **args): + """Set the arguments for the renderer in synchronous mode.""" + if self._renderer_obj is None: + self._renderer_obj = Renderer(device=self.device) + for _renderer in self._addl_render: + self._renderer_obj.add_renderer(_renderer) + self._renderer_obj._model = self._model + self._renderer_obj._saliency = self._saliency + self._cur_result = self._renderer_obj.render(**args) + + def get_result(self): + """Get the result of the renderer. + + Returns: + EasyDict: The result of the renderer. + + """ + assert not self._closed + if self._result_queue is not None: + while self._result_queue.qsize() > 0: + result, stamp = self._result_queue.get() + if stamp == self._cur_stamp: + self._cur_result = result + return self._cur_result + + def clear_result(self): + """Clear the result of the renderer.""" + assert not self._closed + self._cur_args = None + self._cur_result = None + self._cur_stamp += 1 + + def load_model(self, model_path: str) -> None: + """Load a model for the renderer. + + Args: + model_path (str): Path to the model. + + """ + if self._is_async: + self._set_args_async(load_model=model_path) + elif model_path != self._model_path: + self._model_path = model_path + if self._renderer_obj is None: + self._renderer_obj = Renderer(device=self.device) + for _renderer in self._addl_render: + self._renderer_obj.add_renderer(_renderer) + self._renderer_obj.load_model(model_path, device=self.device) + + def clear_model(self): + """Clear the model for the renderer.""" + self._model_path = None + if self._renderer_obj is not None: + self._renderer_obj._umap_encoders = None + self._renderer_obj._model = None + self._renderer_obj._saliency = None + + @property + def _model(self): + if self._renderer_obj is not None: + return self._renderer_obj._model + else: + return None + + @property + def _saliency(self): + if self._renderer_obj is not None: + return self._renderer_obj._saliency + else: + return None + + @property + def _umap_encoders(self): + if self._renderer_obj is not None: + return self._renderer_obj._umap_encoders + else: + return None + + @staticmethod + def _process_fn( + args_queue: multiprocessing.Queue, + result_queue: multiprocessing.Queue, + model_path: Optional[str] = None, + live_updates: bool = False + ): + if sf.util.torch_available: + from slideflow.model import torch_utils + device = torch_utils.get_device() + else: + device = None + renderer_obj = Renderer(device=device) + if model_path: + renderer_obj.load_model(model_path, device=device) + while True: + while args_queue.qsize() > 0: + args, stamp = args_queue.get() + if 'close_renderer' in args: + renderer_obj = Renderer(device=device) + if 'set_renderer' in args: + renderer_class, kwargs = args['set_renderer'] + renderer_obj = renderer_class(**kwargs) + if 'load_model' in args: + renderer_obj.load_model(args['load_model'], device=device) + if 'quit' in args: + return + if (live_updates and not result_queue.qsize()): + result = renderer_obj.render(**args) + if 'error' in result: + result.error = CapturedException(result.error) + + result_queue.put([result, stamp])