Diff of /slideflow/studio/utils.py [000000] .. [78ef36]

Switch to unified view

a b/slideflow/studio/utils.py
1
"""Utilities for Slideflow Studio."""
2
3
from typing import Any, List
4
5
import imgui
6
import os
7
import slideflow as sf
8
import numpy as np
9
from os.path import join, exists
10
from slideflow import log
11
from typing import Tuple, Optional
12
13
if sf.util.tf_available:
14
    import tensorflow as tf
15
    sf.util.allow_gpu_memory_growth()
16
if sf.util.torch_available:
17
    import slideflow.model.torch
18
19
#----------------------------------------------------------------------------
20
21
22
LEFT_MOUSE_BUTTON = 0
23
RIGHT_MOUSE_BUTTON = 1
24
25
#----------------------------------------------------------------------------
26
27
class EasyDict(dict):
28
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""
29
30
    def __getattr__(self, name: str) -> Any:
31
        try:
32
            return self[name]
33
        except KeyError:
34
            raise AttributeError(name)
35
36
    def __setattr__(self, name: str, value: Any) -> None:
37
        self[name] = value
38
39
    def __delattr__(self, name: str) -> None:
40
        del self[name]
41
42
43
#----------------------------------------------------------------------------
44
45
def prediction_to_string(
46
    predictions: np.ndarray,
47
    outcomes: List[str],
48
    is_classification: bool
49
) -> str:
50
    """Convert a prediction array to a human-readable string."""
51
    #TODO: support multi-outcome models
52
    if is_classification:
53
        return f'{outcomes[str(np.argmax(predictions))]} ({np.max(predictions)*100:.1f}%)'
54
    else:
55
        return f'{predictions[0]:.2f}'
56
57
58
def _load_umap_encoders(path, model) -> EasyDict:
59
    import tensorflow as tf
60
61
    layers = [d for d in os.listdir(path) if os.path.isdir(join(path, d))]
62
    log.debug("Layers found at path {} in _load_umap_encoders: {}".format(path, layers))
63
    features = sf.model.Features.from_model(
64
        model,
65
        include_preds=True,
66
        layers=layers,
67
        pooling='avg'
68
    )
69
70
    outputs = []
71
    for i, layer in enumerate(layers):
72
        # Add outputs for each UMAP encoder
73
        encoder = tf.keras.models.load_model(join(path, layer, 'encoder'))
74
        encoder._name = f'{layer}_encoder'
75
        outputs += [encoder(features.model.outputs[i])]
76
77
    # Add the predictions output
78
    outputs += [features.model.outputs[-1]]
79
80
    # Build the encoder model for all layers
81
    encoder_model = tf.keras.models.Model(
82
        inputs=features.model.input,
83
        outputs=outputs
84
    )
85
    return EasyDict(
86
        encoder=encoder_model,
87
        layers=layers,
88
        range={
89
            layer: np.load(join(path, layer, 'range_clip.npz'))['range']
90
            for layer in layers
91
        },
92
        clip={
93
            layer: np.load(join(path, layer, 'range_clip.npz'))['clip']
94
            for layer in layers
95
        }
96
    )
97
98
99
def _load_model_and_saliency(model_path, device=None):
100
    log.debug("Loading model at {}...".format(model_path))
101
    _umap_encoders = None
102
    _saliency = None
103
104
    # Load a PyTorch model
105
    if sf.util.torch_available and sf.util.path_to_ext(model_path) == 'zip':
106
        import slideflow.model.torch
107
        _device = sf.model.torch.torch_utils.get_device()
108
        _model = sf.model.torch.load(model_path)
109
        _model.to(_device)
110
        _model.eval()
111
        if device is not None:
112
            _model = _model.to(device)
113
        _saliency = sf.grad.SaliencyMap(_model, class_idx=0)  #TODO: auto-update from heatmaps logit
114
115
    # Load a TFLite model
116
    elif sf.util.tf_available and sf.util.path_to_ext(model_path) == 'tflite':
117
        interpreter = tf.lite.Interpreter(model_path)
118
        _model = interpreter.get_signature_runner()
119
120
    # Load a Tensorflow model
121
    elif sf.util.tf_available:
122
        import slideflow.model.tensorflow
123
        _model = sf.model.tensorflow.load(model_path, method='weights')
124
        _saliency = sf.grad.SaliencyMap(_model, class_idx=0)  #TODO: auto-update from heatmaps logit
125
        if exists(join(model_path, 'umap_encoders')):
126
            _umap_encoders = _load_umap_encoders(join(model_path, 'umap_encoders'), _model)
127
    else:
128
        raise ValueError(f"Unable to interpret model {model_path}")
129
    return _model, _saliency, _umap_encoders
130
131
#----------------------------------------------------------------------------
132
133
class StatusMessage:
134
    """A class to manage status messages."""
135
    def __init__(
136
        self,
137
        viz: Any,
138
        message: str,
139
        description: Optional[str] = None,
140
        *,
141
        color: Tuple[float, float, float, float] = (0.7, 0, 0, 1),
142
        text_color: Tuple[float, float, float, float] = (1, 1, 1, 1),
143
        rounding: int = 0,
144
    ) -> None:
145
        self.viz = viz
146
        self.message = message
147
        self.description = description
148
        self.color = color
149
        self.text_color = text_color
150
        self.rounding = rounding
151
152
153
    def render(self):
154
        """Render the status message."""
155
        # Calculations.
156
        h = self.viz.status_bar_height
157
        r = self.viz.pixel_ratio
158
        y_pos = int((self.viz.content_frame_height - (h * r)) / r)
159
        size = imgui.calc_text_size(self.message)
160
161
        # Center the text.
162
        x_start = self.viz.content_width/2 - size.x/2
163
        imgui.same_line()
164
        imgui.set_cursor_pos_x(x_start)
165
166
        # Draw the background.
167
        draw_list = imgui.get_window_draw_list()
168
        pad = self.viz.spacing * 2
169
        draw_list.add_rect_filled(
170
            x_start - pad - 4,
171
            y_pos,
172
            x_start + size.x + pad,
173
            y_pos + h,
174
            imgui.get_color_u32_rgba(*self.color),
175
            rounding=self.rounding
176
        )
177
178
        # Draw the text.
179
        imgui.push_style_color(imgui.COLOR_TEXT, *self.text_color)
180
        imgui.text(self.message)
181
        imgui.pop_style_color(1)
182
183
        # Set the tooltip.
184
        if self.description is not None:
185
            if imgui.is_item_hovered():
186
                imgui.set_tooltip(self.description)