[030aeb]: / dosma / tissues / tissue.py

Download this file

351 lines (269 with data), 12.6 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
import os
from abc import ABC, abstractmethod
from typing import Union
import numpy as np
import pandas as pd
import scipy.ndimage as sni
from dosma.core.io import format_io_utils as fio_utils
from dosma.core.io.format_io import ImageDataFormat
from dosma.core.med_volume import MedicalVolume
from dosma.core.orientation import SAGITTAL
from dosma.core.quant_vals import QuantitativeValue, QuantitativeValueType
from dosma.defaults import preferences
from dosma.utils import io_utils
# Filename suffix (no leading dot) used by `Tissue.find_weights` to recognize weights files.
WEIGHTS_FILE_EXT = "h5"
# Names exported by `from <this module> import *`.
__all__ = ["Tissue"]
class Tissue(ABC):
    """Abstract class for tissues.

    Tissues are defined loosely as any tissue structures (bones, soft tissue, etc.).

    Args:
        weights_dir (str): Directory to all segmentation weights. If provided, the
            weights file for this tissue is located with :meth:`find_weights`.
        medial_to_lateral (`bool`, optional): If `True`, anatomy is from medial_to_lateral.

    Attributes:
        FULL_NAME (str): Full name of tissue 'femoral cartilage' for femoral cartilage.
        ID (int): Unique integer ID for tissue. Should be unique to all tissues,
            and should not change.
        STR_ID (str): Short hand string id such as 'fc' for femoral cartilage.
        T1_EXPECTED (float): Expected T1 value (in milliseconds).
        medial_to_lateral (bool): If ``True``, mask is in medial to lateral direction.
        pid (str): Patient/subject ID. Should be anonymized.
        quant_vals (dict[str, tuple[np.ndarray, pd.DataFrame]]): Mapping from quantitative value
            name (t2, t1-rho, etc.) to tuple of unrolled map and DataFrame containing
            measurement values.
        quantitative_values (list[QuantitativeValue]): Quantitative values attached to
            this tissue via :meth:`add_quantitative_value` or :meth:`load_data`.
        weights_file_path (str): File path to the weights file for neural network
            segmentation. ``None`` if no weights directory was given.
    """

    ID = -1
    STR_ID = ""
    FULL_NAME = ""

    # Expected quantitative param values.
    T1_EXPECTED = None

    def __init__(self, weights_dir: str = None, medial_to_lateral: bool = None):
        self.pid = None
        self.__mask__ = None
        self.quant_vals = {}
        self.weights_file_path = None

        if weights_dir is not None:
            self.weights_file_path = self.find_weights(weights_dir)

        self.medial_to_lateral = medial_to_lateral

        # quantitative value list
        self.quantitative_values = []

    @abstractmethod
    def split_regions(self, base_map: Union[np.ndarray, MedicalVolume]):
        """Split mask into anatomical regions.

        Args:
            base_map (np.ndarray): 3D numpy array typically corresponding to volume to split.

        Returns:
            np.ndarray: 4D numpy array (region, height, width, depth).
                Saved in variable `self.regions`.
        """
        pass

    def calc_quant_vals(self):
        """Calculate quantitative values for pixels corresponding to the tissue.

        Runs :meth:`__calc_quant_vals__` once for every entry in
        ``self.quantitative_values``. Requires mask to be set for this tissue.
        """
        for qv in self.quantitative_values:
            self.__calc_quant_vals__(qv.volumetric_map, qv.qv_type)

    @abstractmethod
    def __calc_quant_vals__(self, quant_map: MedicalVolume, map_type: QuantitativeValueType):
        """Helper method to get quantitative values for tissue - implemented per tissue.

        Different tissues should override this as they see fit. This base
        implementation only validates the arguments and reformats ``quant_map``
        **in place** to the orientation of the tissue mask.

        Args:
            quant_map (MedicalVolume): 3D map of pixel-wise quantitative measures
                (T2, T2*, T1-rho, etc.). Volume should have ``np.nan`` values for
                all pixels unable to be calculated.
            map_type (QuantitativeValueType): Type of quantitative value to analyze.

        Raises:
            TypeError: If `quant_map` is not of type `MedicalVolume` or `map_type` is not of type
                `QuantitativeValueType`.
            ValueError: If no mask is found for tissue.
        """
        if not isinstance(quant_map, MedicalVolume):
            raise TypeError("`Expected type 'MedicalVolume' for `quant_map`")

        if not isinstance(map_type, QuantitativeValueType):
            raise TypeError("`Expected type 'QuantitativeValueType' for `map_type`")

        if self.__mask__ is None:
            raise ValueError("Please initialize mask for {}".format(self.FULL_NAME))

        # Align the quantitative map with the mask so the two can be compared voxel-wise.
        quant_map.reformat(self.__mask__.orientation, inplace=True)

    def __store_quant_vals__(
        self, quant_map: MedicalVolume, quant_df: pd.DataFrame, map_type: QuantitativeValueType
    ):
        """Adds quantitative value in `self.quant_vals`.

        The pair ``(quant_map, quant_df)`` is stored as-is under the key
        ``map_type.name``, replacing any previous entry for that name.

        Args:
            quant_map: Map data for this quantitative value.
                NOTE(review): annotated as ``MedicalVolume`` but previously documented
                as ``list[dict]`` of unrolled maps with plotting data (title, xlabel,
                etc.) - confirm against subclasses.
            quant_df (pd.DataFrame): Computed data for this quantitative value.
            map_type (QuantitativeValueType): Type of quantitative value to analyze.
        """
        self.quant_vals[map_type.name] = (quant_map, quant_df)

    def find_weights(self, weights_dir: str):
        """Search for weights file in weights directory.

        A file matches when its name ends in ``WEIGHTS_FILE_EXT`` and contains
        ``self.STR_ID``. The match is also stored in ``self.weights_file_path``.

        Args:
            weights_dir (str): Directory where weights are stored.

        Returns:
            str: File path to weights corresponding to tissue.

        Raises:
            ValueError: If multiple weights files exists for the tissue
                or no valid weights file found.
        """
        # Find weights file with NAME in the filename, like 'fc_weights.h5'
        weights_file = None
        for fname in os.listdir(weights_dir):
            fpath = os.path.join(weights_dir, fname)
            if os.path.isfile(fpath) and fname.endswith(WEIGHTS_FILE_EXT) and self.STR_ID in fname:
                if weights_file is not None:
                    raise ValueError("There are multiple weights files, please remove duplicates")
                weights_file = fpath

        if weights_file is None:
            raise ValueError(
                "No file found that contains '{}' and ends in '{}'".format(
                    self.STR_ID, WEIGHTS_FILE_EXT
                )
            )

        self.weights_file_path = weights_file
        return weights_file

    def save_data(
        self, save_dirpath: str, data_format: ImageDataFormat = preferences.image_data_format
    ):
        """Save data for tissue.

        Saves mask and quantitative values associated with this tissue.

        Override in subclasses to save additional data. When overriding in subclasses, call
        ``super().save_data(save_dirpath)`` first to save mask and quantitative values by default.
        See :mod:`dosma.tissues.femoral_cartilage` for details.

        .. literalinclude:: femoral_cartilage.py

        Args:
            save_dirpath (str): Directory path where all data is stored.
            data_format (`ImageDataFormat`, optional): Format to save data.
        """
        # All output goes into the tissue-specific subdirectory.
        save_dirpath = self.__save_dirpath__(save_dirpath)

        if self.__mask__ is not None:
            mask_file_path = os.path.join(save_dirpath, "{}.nii.gz".format(self.STR_ID))
            mask_file_path = fio_utils.convert_image_data_format(mask_file_path, data_format)
            self.__mask__.save_volume(mask_file_path, data_format=data_format)

        for qv in self.quantitative_values:
            qv.save_data(save_dirpath, data_format)

        self.__save_quant_data__(save_dirpath)

    @abstractmethod
    def __save_quant_data__(self, dirpath: str):
        """Save quantitative data generated for this tissue.

        Called by `save_data`.

        Args:
            dirpath (str): Directory path to tissue data.
        """
        pass

    def save_quant_data(self, dirpath: str):
        """Save quantitative data generated for this tissue.

        Does not save mask or quantitative parameter map.

        Args:
            dirpath (str): Directory path to tissue data.
        """
        return self.__save_quant_data__(dirpath)

    def load_data(self, load_dir_path: str):
        """Load data for tissue.

        All tissue information is based on the mask. If mask for tissue doesn't exist,
        there is no information to load.

        Args:
            load_dir_path (str): Directory path where all data is stored.
        """
        # Resolves to the tissue-specific subdirectory (created if it does not exist).
        load_dir_path = self.__save_dirpath__(load_dir_path)
        mask_file_path = os.path.join(load_dir_path, "{}.nii.gz".format(self.STR_ID))

        # Try to load mask, if file exists.
        try:
            msk = fio_utils.generic_load(mask_file_path, expected_num_volumes=1)
            self.set_mask(msk)
        except FileNotFoundError:
            # A missing mask is not an error: the tissue simply has no mask yet.
            pass

        self.quantitative_values = QuantitativeValue.load_qvs(load_dir_path)

    def __save_dirpath__(self, dirpath):
        """Tissue-specific subdirectory to store data.

        Subdirectory will have path '`dirpath`/`self.STR_ID`/'.
        If directory does not exist, it will be created.

        Args:
            dirpath (str): Directory path where all data is stored.

        Returns:
            str: Tissue-specific data directory.
        """
        return io_utils.mkdirs(os.path.join(dirpath, self.STR_ID))

    # TODO (arjundd): Refactor get/set methods of mask to property.
    def set_mask(self, mask: MedicalVolume):
        """Set mask for tissue.

        The mask is stored reformatted to the sagittal orientation.

        Args:
            mask (MedicalVolume): Binary mask of segmented tissue.

        Raises:
            TypeError: If `mask` is not a `MedicalVolume`.
        """
        # Explicit check instead of `assert`: assertions are stripped under `python -O`,
        # and `isinstance` matches the validation style of `__calc_quant_vals__`.
        if not isinstance(mask, MedicalVolume):
            raise TypeError("mask for tissue must be of type MedicalVolume")

        self.__mask__ = mask.reformat(SAGITTAL)

    def get_mask(self):
        """
        Returns:
            MedicalVolume: Binary mask of segmented tissue (``None`` if not set).
        """
        return self.__mask__

    def add_quantitative_value(self, qv_new: QuantitativeValue):
        """Add quantitative value to the tissue.

        Args:
            qv_new (QuantitativeValue): Quantitative value to add to tissue.
        """
        # NOTE: no uniqueness check is performed; a value of a type that is already
        # present is appended alongside the existing one.
        self.quantitative_values.append(qv_new)

    def __get_axis_bounds__(
        self, im: np.ndarray, ignore_nan: bool = True, leave_buffer: bool = False
    ):
        """Get tightest index bounds of the non-zero data in the array.

        When plotting data, we would like to avoid making our displayed range too large such
        that we cannot detect changes in the regions that matter.
        To avoid this, we make our bounds as tight as possible.

        Bounds are calculated per axis from the indices of the non-zero elements
        (as returned by ``np.nonzero``).

        Args:
            im (np.ndarray): Array containing information for which bounds have to be computed.
            ignore_nan (obj:`bool`, optional): Ignore `nan` values when computing the bounds.
                NaNs are replaced via ``np.nan_to_num`` before the non-zero search.
            leave_buffer (obj:`bool`, optional): Add buffer of +/-5 to the bounds.
                The buffered bounds are not clipped to the array extent.

        Returns:
            list[tuple]: One ``(min_index, max_index)`` pair per axis of `im`.
        """
        im_temp = np.nan_to_num(im) if ignore_nan else im

        axs = []
        # np.nonzero yields one index array per axis.
        for axis_indices in np.nonzero(im_temp):
            v_min = np.min(axis_indices)
            v_max = np.max(axis_indices)
            if leave_buffer:
                v_min -= 5
                v_max += 5
            axs.append((v_min, v_max))

        return axs
def largest_cc(mask, num=1):
    """Return the largest `num` connected component(s) of a 3D mask array.

    Args:
        mask (np.ndarray): 3D mask array (boolean or integer dtype).
        num (int, optional): Maximum number of connected components to keep.

    Returns:
        mask (np.ndarray): Boolean 3D mask array with `num` connected components.

    Raises:
        ValueError: If `mask` contains no non-zero values.

    Note:
        Adapted from nipy (https://github.com/nipy/nipy/blob/master/nipy/labs/mask.py)
        due to dependency issues.
    """
    # We use asarray to be able to work with masked arrays.
    mask = np.asarray(mask)
    labels, label_nb = sni.label(mask)
    if not label_nb:
        raise ValueError("No non-zero values: no connected components")

    # Builtin `bool`/`int` are used instead of the `np.bool`/`np.int` aliases,
    # which were removed in NumPy 1.24.
    if label_nb == 1:
        return mask.astype(bool)

    label_count = np.bincount(labels.ravel().astype(int))
    # Zero out the background (label 0) so it can never be selected as a component.
    label_count[0] = 0

    # Split num=1 case for speed.
    if num == 1:
        return labels == label_count.argmax()

    # Ascending argsort puts the zeroed background first: drop it, then reverse
    # so the labels are ordered from largest to smallest component.
    order = np.argsort(label_count)[1:][::-1]
    return np.isin(labels, order[:num])