a b/slideflow/studio/_render_manager.py
1
import multiprocessing
2
import slideflow as sf
3
from typing import Optional
4
5
from ._renderer import Renderer, CapturedException
6
7
#----------------------------------------------------------------------------
8
9
class AsyncRenderManager:

    """Manager to assist with rendering tile-level model predictions.

    Rendering runs either synchronously in the calling process, or
    asynchronously in a spawned worker process that communicates through a
    pair of multiprocessing queues. Every request sent to the worker is
    tagged with a "stamp"; results whose stamp differs from the current one
    are discarded by :meth:`get_result`.
    """

    def __init__(self):
        self._closed        = False
        self._is_async      = False
        self._cur_args      = None    # Last render args sent; used to skip duplicates.
        self._cur_result    = None
        self._cur_stamp     = 0       # Bumped by clear_result() to invalidate old results.
        self._renderer_obj  = None    # Renderer used in synchronous mode.
        self._args_queue    = None
        self._result_queue  = None
        self._process       = None
        # Path of the model loaded into the synchronous renderer; also handed
        # to the worker process when it is spawned.
        self._model_path    = None
        self._live_updates  = False
        self.tile_px        = None
        self.extract_px     = None
        self._addl_render   = []      # Extra renderers triggered before the main one.
        self._set_device()

    def _set_device(self) -> None:
        """Set the device for the renderer."""
        if sf.util.torch_available:
            from slideflow.model import torch_utils
            self.device = torch_utils.get_device()
        else:
            self.device = None

    def close(self) -> None:
        """Close the renderer and stop the worker process, if one is running."""
        self._closed = True
        self._renderer_obj = None
        if self._process is not None:
            self._process.terminate()
            # Reap the terminated worker so it does not linger as a zombie;
            # bounded so a stuck worker cannot hang the caller.
            self._process.join(timeout=5)
        self._process = None
        self._args_queue = None
        self._result_queue = None

    @property
    def is_async(self) -> bool:
        """Return whether the renderer is in asynchronous mode.

        Returns:
            bool: Whether the renderer is in asynchronous mode.

        """
        return self._is_async

    def set_renderer(self, renderer_class: type, **kwargs) -> None:
        """Set the renderer class for the renderer.

        Args:
            renderer_class (type): Renderer class to use. It is instantiated
                as ``renderer_class(device=..., **kwargs)``.

        """
        assert not self._closed
        if self.is_async:
            self._set_args_async(set_renderer=(renderer_class, kwargs))
        else:
            self._renderer_obj = renderer_class(device=self.device, **kwargs)
            for _renderer in self._addl_render:
                self._renderer_obj.add_renderer(_renderer)
            # The new renderer has no model loaded; forget the old path so a
            # later load_model() with the same path is not skipped.
            self._model_path = None

    def close_renderer(self) -> None:
        """Discard the current renderer.

        In asynchronous mode the worker replaces its renderer with a fresh
        default ``Renderer``. In synchronous mode the renderer is dropped and
        a default one is created on the next render request.
        """
        if self.is_async:
            self._set_args_async(close_renderer=True)
        else:
            self._renderer_obj = None
            # Any model loaded into the discarded renderer is gone as well.
            self._model_path = None

    def add_to_render_pipeline(self, renderer: Renderer) -> None:
        """Add a renderer to the rendering pipeline.

        Args:
            renderer (Renderer): Renderer to add to the pipeline.
                This renderer will be triggered before the main renderer.

        Raises:
            ValueError: If the renderer is in asynchronous mode.

        """
        if self.is_async:
            raise ValueError("Cannot add to rendering pipeline when in "
                             "asynchronous mode.")
        self._addl_render.append(renderer)
        if self._renderer_obj is not None:
            self._renderer_obj.add_renderer(renderer)

    def remove_from_render_pipeline(self, renderer: Renderer) -> None:
        """Remove a renderer from the rendering pipeline.

        Args:
            renderer (Renderer): Renderer to remove from the pipeline.

        Raises:
            ValueError: If the renderer is in asynchronous mode, or if
                ``renderer`` is not in the pipeline.

        """
        if self.is_async:
            raise ValueError("Cannot remove rendering pipeline when in "
                             "asynchronous mode.")
        self._addl_render.remove(renderer)
        if self._renderer_obj is not None:
            self._renderer_obj.remove_renderer(renderer)

    def set_async(self, is_async: bool) -> None:
        """Set the renderer to synchronous or asynchronous mode.

        Args:
            is_async (bool): Whether to set the renderer to asynchronous mode.

        """
        self._is_async = is_async

    def set_args(self, **args):
        """Set the arguments for the renderer and trigger a render.

        Identical consecutive arguments are skipped unless live updates are
        enabled.
        """
        assert not self._closed
        if args != self._cur_args or self._live_updates:
            if self._is_async:
                self._set_args_async(**args)
            else:
                self._set_args_sync(**args)
            if not self._live_updates:
                self._cur_args = args

    def _set_args_async(self, **args):
        """Send a message to the worker process, spawning it on first use."""
        if self._process is None:
            ctx = multiprocessing.get_context('spawn')
            self._args_queue = ctx.Queue()
            self._result_queue = ctx.Queue()
            self._process = ctx.Process(target=self._process_fn,
                                        args=(self._args_queue,
                                              self._result_queue,
                                              self._model_path,
                                              self._live_updates),
                                        daemon=True)
            self._process.start()
        self._args_queue.put([args, self._cur_stamp])

    def _ensure_renderer(self) -> None:
        """Create the default synchronous renderer if none exists yet."""
        if self._renderer_obj is None:
            self._renderer_obj = Renderer(device=self.device)
            for _renderer in self._addl_render:
                self._renderer_obj.add_renderer(_renderer)

    def _set_args_sync(self, **args):
        """Render immediately with the given arguments (synchronous mode)."""
        self._ensure_renderer()
        self._cur_result = self._renderer_obj.render(**args)

    def get_result(self):
        """Get the result of the renderer.

        In asynchronous mode, all results waiting in the result queue are
        drained; only those tagged with the current stamp are kept.

        Returns:
            EasyDict: The result of the renderer.

        """
        assert not self._closed
        if self._result_queue is not None:
            # Queue.qsize() raises NotImplementedError on macOS, so drain
            # with empty()/get(), which are portable.
            while not self._result_queue.empty():
                result, stamp = self._result_queue.get()
                if stamp == self._cur_stamp:
                    self._cur_result = result
        return self._cur_result

    def clear_result(self):
        """Clear the result and invalidate any in-flight async results."""
        assert not self._closed
        self._cur_args = None
        self._cur_result = None
        self._cur_stamp += 1

    def load_model(self, model_path: str) -> None:
        """Load a model for the renderer.

        In synchronous mode, loading is skipped if ``model_path`` is already
        loaded.

        Args:
            model_path (str): Path to the model.

        """
        if self._is_async:
            self._set_args_async(load_model=model_path)
        elif model_path != self._model_path:
            self._ensure_renderer()
            self._renderer_obj.load_model(model_path, device=self.device)
            # Record the path only after loading succeeded, so a failed load
            # can be retried with the same path.
            self._model_path = model_path

    def clear_model(self):
        """Clear the model from the synchronous renderer."""
        self._model_path = None
        if self._renderer_obj is not None:
            self._renderer_obj._umap_encoders = None
            self._renderer_obj._model = None
            self._renderer_obj._saliency = None

    @property
    def _model(self):
        if self._renderer_obj is not None:
            return self._renderer_obj._model
        else:
            return None

    @property
    def _saliency(self):
        if self._renderer_obj is not None:
            return self._renderer_obj._saliency
        else:
            return None

    @property
    def _umap_encoders(self):
        if self._renderer_obj is not None:
            return self._renderer_obj._umap_encoders
        else:
            return None

    @staticmethod
    def _process_fn(
        args_queue: multiprocessing.Queue,
        result_queue: multiprocessing.Queue,
        model_path: Optional[str] = None,
        live_updates: bool = False
    ):
        """Worker loop executed in the spawned rendering process.

        Messages arrive on ``args_queue`` as ``[args, stamp]``. ``args`` is
        either a control message (containing ``close_renderer``,
        ``set_renderer``, ``load_model`` or ``quit``) or keyword arguments
        for ``render()``. Each new render request is rendered once and the
        result is put on ``result_queue`` as ``[result, stamp]``. When
        ``live_updates`` is True, the latest request is also re-rendered
        whenever the previous result has been consumed.
        """
        import queue

        if sf.util.torch_available:
            from slideflow.model import torch_utils
            device = torch_utils.get_device()
        else:
            device = None
        renderer_obj = Renderer(device=device)
        if model_path:
            renderer_obj.load_model(model_path, device=device)

        control_keys = ('close_renderer', 'set_renderer', 'load_model', 'quit')
        poll_interval = 0.01   # Seconds between checks while a live result is unconsumed.
        render_args = None     # Latest render request (kwargs for render()).
        render_stamp = None    # Stamp that accompanied render_args.
        needs_render = False   # True until render_args has been rendered once.

        while True:
            can_live_render = live_updates and render_args is not None
            ready = needs_render or (can_live_render and result_queue.empty())

            # Collect pending messages. Don't wait if there is work to do;
            # wait briefly while a live-update result awaits consumption;
            # otherwise block until a message arrives (avoids busy-waiting).
            # qsize() is avoided because it is unsupported on macOS.
            messages = []
            try:
                if ready:
                    messages.append(args_queue.get_nowait())
                elif can_live_render:
                    messages.append(args_queue.get(timeout=poll_interval))
                else:
                    messages.append(args_queue.get())
                while True:
                    messages.append(args_queue.get_nowait())
            except queue.Empty:
                pass

            for args, stamp in messages:
                if 'close_renderer' in args:
                    renderer_obj = Renderer(device=device)
                if 'set_renderer' in args:
                    renderer_class, kwargs = args['set_renderer']
                    # Pass the device, matching the synchronous code path.
                    renderer_obj = renderer_class(device=device, **kwargs)
                if 'load_model' in args:
                    renderer_obj.load_model(args['load_model'], device=device)
                if 'quit' in args:
                    return
                # Only non-control messages are render requests; the most
                # recent one wins.
                if not any(key in args for key in control_keys):
                    render_args, render_stamp = args, stamp
                    needs_render = True

            can_live_render = live_updates and render_args is not None
            if needs_render or (can_live_render and result_queue.empty()):
                result = renderer_obj.render(**render_args)
                if 'error' in result:
                    result.error = CapturedException(result.error)

                result_queue.put([result, render_stamp])
                needs_render = False