b/slideflow/studio/utils.py
"""Utilities for Slideflow Studio."""

from typing import Any, List, Optional, Tuple

import imgui
import os
import slideflow as sf
import numpy as np
from os.path import join, exists
from slideflow import log

if sf.util.tf_available:
    import tensorflow as tf
    sf.util.allow_gpu_memory_growth()
if sf.util.torch_available:
    import slideflow.model.torch

#----------------------------------------------------------------------------

LEFT_MOUSE_BUTTON = 0
RIGHT_MOUSE_BUTTON = 1

#----------------------------------------------------------------------------

class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]
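
# Illustrative usage sketch (comment only; not part of the original module).
# EasyDict mirrors item access with attribute access, and a missing key raises
# AttributeError rather than KeyError so that hasattr() behaves as expected:
#
#     d = EasyDict(message='Loading...', progress=0.5)
#     d.progress              # 0.5, same as d['progress']
#     d.done = True           # same as d['done'] = True
#     hasattr(d, 'missing')   # False, since __getattr__ raises AttributeError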

#----------------------------------------------------------------------------

def prediction_to_string(
    predictions: np.ndarray,
    outcomes: List[str],
    is_classification: bool
) -> str:
    """Convert a prediction array to a human-readable string."""
    #TODO: support multi-outcome models
    if is_classification:
        return f'{outcomes[str(np.argmax(predictions))]} ({np.max(predictions)*100:.1f}%)'
    else:
        return f'{predictions[0]:.2f}'
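
# Illustrative usage sketch (comment only; the values are hypothetical). Note
# that the classification branch indexes `outcomes` with a *string* key
# (str(np.argmax(...))), which matches a mapping keyed by stringified class
# index (e.g. a model config's outcome labels) rather than the annotated
# List[str]:
#
#     preds = np.array([0.2, 0.8])
#     prediction_to_string(preds, {'0': 'Class A', '1': 'Class B'}, True)
#     # -> 'Class B (80.0%)'
#     prediction_to_string(np.array([1.37]), [], False)
#     # -> '1.37'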


def _load_umap_encoders(path, model) -> EasyDict:
    """Load layer UMAP encoders and their range/clip values from a directory."""
    import tensorflow as tf

    layers = [d for d in os.listdir(path) if os.path.isdir(join(path, d))]
    log.debug("Layers found at path {} in _load_umap_encoders: {}".format(path, layers))
    features = sf.model.Features.from_model(
        model,
        include_preds=True,
        layers=layers,
        pooling='avg'
    )

    outputs = []
    for i, layer in enumerate(layers):
        # Add outputs for each UMAP encoder
        encoder = tf.keras.models.load_model(join(path, layer, 'encoder'))
        encoder._name = f'{layer}_encoder'
        outputs += [encoder(features.model.outputs[i])]

    # Add the predictions output
    outputs += [features.model.outputs[-1]]

    # Build the encoder model for all layers
    encoder_model = tf.keras.models.Model(
        inputs=features.model.input,
        outputs=outputs
    )
    return EasyDict(
        encoder=encoder_model,
        layers=layers,
        range={
            layer: np.load(join(path, layer, 'range_clip.npz'))['range']
            for layer in layers
        },
        clip={
            layer: np.load(join(path, layer, 'range_clip.npz'))['clip']
            for layer in layers
        }
    )
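
# Directory layout sketch (inferred from the loader above; the layer name
# 'postconv' is only an example). Each layer subdirectory holds a saved Keras
# encoder plus a range_clip.npz file with 'range' and 'clip' arrays:
#
#     umap_encoders/
#         postconv/
#             encoder/            # tf.keras SavedModel: layer activations -> UMAP coordinates
#             range_clip.npz      # arrays stored under keys 'range' and 'clip'
#
#     encoders = _load_umap_encoders('/path/to/umap_encoders', model)  # hypothetical path
#     encoders.encoder            # combined tf.keras.Model: per-layer UMAP coords + predictions
#     encoders.range['postconv']  # normalization range for that layer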


def _load_model_and_saliency(model_path, device=None):
    """Load a saved model, its saliency map wrapper, and any UMAP encoders."""
    log.debug("Loading model at {}...".format(model_path))
    _umap_encoders = None
    _saliency = None

    # Load a PyTorch model
    if sf.util.torch_available and sf.util.path_to_ext(model_path) == 'zip':
        import slideflow.model.torch
        _device = sf.model.torch.torch_utils.get_device()
        _model = sf.model.torch.load(model_path)
        _model.to(_device)
        _model.eval()
        if device is not None:
            _model = _model.to(device)
        _saliency = sf.grad.SaliencyMap(_model, class_idx=0)  #TODO: auto-update from heatmaps logit

    # Load a TFLite model
    elif sf.util.tf_available and sf.util.path_to_ext(model_path) == 'tflite':
        interpreter = tf.lite.Interpreter(model_path)
        _model = interpreter.get_signature_runner()

    # Load a Tensorflow model
    elif sf.util.tf_available:
        import slideflow.model.tensorflow
        _model = sf.model.tensorflow.load(model_path, method='weights')
        _saliency = sf.grad.SaliencyMap(_model, class_idx=0)  #TODO: auto-update from heatmaps logit
        if exists(join(model_path, 'umap_encoders')):
            _umap_encoders = _load_umap_encoders(join(model_path, 'umap_encoders'), _model)
    else:
        raise ValueError(f"Unable to interpret model {model_path}")
    return _model, _saliency, _umap_encoders
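
# Illustrative usage sketch (comment only; the path is hypothetical). The file
# extension selects the backend: '.zip' loads a PyTorch model, '.tflite' a
# TFLite signature runner, and anything else is treated as a Tensorflow model:
#
#     model, saliency, umap_encoders = _load_model_and_saliency('/path/to/model')
#     # `saliency` stays None for TFLite models; `umap_encoders` is only set for
#     # Tensorflow models saved with a umap_encoders/ subdirectory.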

#----------------------------------------------------------------------------

class StatusMessage:
    """A class to manage status messages."""

    def __init__(
        self,
        viz: Any,
        message: str,
        description: Optional[str] = None,
        *,
        color: Tuple[float, float, float, float] = (0.7, 0, 0, 1),
        text_color: Tuple[float, float, float, float] = (1, 1, 1, 1),
        rounding: int = 0,
    ) -> None:
        self.viz = viz
        self.message = message
        self.description = description
        self.color = color
        self.text_color = text_color
        self.rounding = rounding

    def render(self):
        """Render the status message."""
        # Calculations.
        h = self.viz.status_bar_height
        r = self.viz.pixel_ratio
        y_pos = int((self.viz.content_frame_height - (h * r)) / r)
        size = imgui.calc_text_size(self.message)

        # Center the text.
        x_start = self.viz.content_width/2 - size.x/2
        imgui.same_line()
        imgui.set_cursor_pos_x(x_start)

        # Draw the background.
        draw_list = imgui.get_window_draw_list()
        pad = self.viz.spacing * 2
        draw_list.add_rect_filled(
            x_start - pad - 4,
            y_pos,
            x_start + size.x + pad,
            y_pos + h,
            imgui.get_color_u32_rgba(*self.color),
            rounding=self.rounding
        )

        # Draw the text.
        imgui.push_style_color(imgui.COLOR_TEXT, *self.text_color)
        imgui.text(self.message)
        imgui.pop_style_color(1)

        # Set the tooltip.
        if self.description is not None:
            if imgui.is_item_hovered():
                imgui.set_tooltip(self.description)
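
# Illustrative usage sketch (comment only; assumes a Studio `viz` object with
# the status_bar_height, pixel_ratio, content_frame_height, content_width, and
# spacing attributes referenced in render() above):
#
#     message = StatusMessage(
#         viz,
#         'Loading model...',
#         description='Reading weights from disk',
#         color=(0.7, 0, 0, 1)
#     )
#     message.render()   # call each frame while the status bar is being drawn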