dl/utils/utils.py

import os
import functools
import itertools
import collections
import numpy as np
import pandas
from PIL import Image
import sklearn.metrics

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data

from .outlier import normalization
from .train import get_label_prob


def discrete_to_id(targets, start=0, sort=True, complex_object=False):
    """Change discrete variable targets to numeric values

    Args:
        targets: 1-d torch.Tensor, np.array, or list
        start: the starting index for the first element
        sort: sort the unique values, so that 'smaller' values have smaller indices
        complex_object: the input elements are not numeric but complex objects, e.g., tuples

    Returns:
        target_ids: torch.Tensor or np.array with integer elements starting from start (default 0)
        cls_id_dict: a dictionary mapping variables to their numeric ids
    """
    # Keep a reference to the input tensor (if any) so the result can be returned as a
    # tensor; targets itself is reassigned to an np.array below, so checking
    # isinstance(targets, torch.Tensor) at the end would always be False
    original_tensor = targets if isinstance(targets, torch.Tensor) else None
    if complex_object:
        unique_targets = sorted(collections.Counter(targets))
    else:
        if original_tensor is not None:
            targets = original_tensor.cpu().detach().numpy()
        else:
            targets = np.array(targets)  # does nothing if targets is already an np.array
        unique_targets = np.unique(targets)  # note: np.unique already returns sorted values
        if sort:
            unique_targets = np.sort(unique_targets)
    cls_id_dict = {v: i + start for i, v in enumerate(unique_targets)}
    target_ids = np.array([cls_id_dict[v] for v in targets])
    if original_tensor is not None:
        target_ids = original_tensor.new_tensor(target_ids)
    return target_ids, cls_id_dict
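
# Minimal usage sketch (hypothetical values):
#   >>> discrete_to_id(['b', 'a', 'b'])
#   (array([1, 0, 1]), {'a': 0, 'b': 1})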


def get_f1_score(m, average='weighted', verbose=False):
    """Given a confusion matrix for binary classification,
    calculate accuracy, precision, recall, and F1 measure

    Args:
        m: confusion matrix for binary classification
        average: if 'weighted', calculate metrics for each label, then take their weighted
            average (weights are the supports);
            if 'macro', take the unweighted mean of the per-label metrics;
            see http://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
        verbose: if True, print per-class results
    """
    def cal_f1(precision, recall):
        if precision + recall == 0:
            print('Both precision and recall are zero')
            return 0
        return 2*precision*recall / (precision+recall)

    m = np.array(m)
    t0 = m[0,0] + m[0,1]  # support (true count) of class 0
    t1 = m[1,0] + m[1,1]  # support of class 1
    p0 = m[0,0] + m[1,0]  # predicted count for class 0
    p1 = m[0,1] + m[1,1]  # predicted count for class 1
    prec0 = m[0,0] / p0
    prec1 = m[1,1] / p1
    recall0 = m[0,0] / t0
    recall1 = m[1,1] / t1
    f1_0 = cal_f1(prec0, recall0)
    f1_1 = cal_f1(prec1, recall1)
    if average == 'macro':
        w0 = 0.5
        w1 = 0.5
    elif average == 'weighted':
        w0 = t0 / (t0+t1)
        w1 = t1 / (t0+t1)
    else:
        raise ValueError(f"average must be 'weighted' or 'macro', but is {average}")
    prec = prec0*w0 + prec1*w1
    recall = recall0*w0 + recall1*w1
    f1 = f1_0*w0 + f1_1*w1
    acc = (m[0,0] + m[1,1]) / (t0+t1)
    if verbose:
        print(f'prec0={prec0}, recall0={recall0}, f1_0={f1_0}\n'
              f'prec1={prec1}, recall1={recall1}, f1_1={f1_1}')
    return acc, prec, recall, f1
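
# Sanity-check sketch against sklearn (hypothetical labels):
#   y_true, y_pred = [0, 0, 1, 1], [0, 1, 1, 1]
#   m = sklearn.metrics.confusion_matrix(y_true, y_pred)
#   acc, prec, recall, f1 = get_f1_score(m)
#   # f1 should match sklearn.metrics.f1_score(y_true, y_pred, average='weighted')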


def dist(params1, params2=None, dist_fn=torch.norm):  # pylint: disable=no-member
    """Calculate the norm of params1, or the distance between params1 and params2;
    a common usage is to calculate the distance between two model state_dicts.

    Args:
        params1: dictionary, with each item a torch.Tensor
        params2: if not None, should have the same structure (keys, data types, and dimensions) as params1
    """
    if params2 is None:
        return dist_fn(torch.Tensor([dist_fn(params1[k]) for k in params1]))
    d = torch.Tensor([dist_fn(params1[k] - params2[k]) for k in params1])
    return dist_fn(d)
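
# Minimal usage sketch: distance between the state_dicts of two models with
# identical architectures (hypothetical models):
#   model_a, model_b = nn.Linear(4, 2), nn.Linear(4, 2)
#   d = dist(model_a.state_dict(), model_b.state_dict())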


class AverageMeter(object):
    """Track the most recent value, running sum, count, and average of a metric"""
    def __init__(self):
        self._reset()

    def _reset(self):
        self.val = 0
        self.sum = 0
        self.cnt = 0
        self.avg = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt
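
# Typical usage sketch inside a training loop (hypothetical values):
#   meter = AverageMeter()
#   meter.update(0.52, n=32)  # e.g., per-batch loss with batch size 32
#   print(meter.avg)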


def pil_loader(path, format='RGB'):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert(format)


class ImageFolder(data.Dataset):
    def __init__(self, root, imgs, transform=None, target_transform=None,
                 loader=pil_loader, is_test=False):
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader  # was hard-coded to pil_loader, ignoring the loader argument
        self.is_test = is_test

    def __getitem__(self, idx):
        if self.is_test:
            img = self.imgs[idx]
        else:
            img, target = self.imgs[idx]
        img = self.loader(os.path.join(self.root, img))
        if self.transform is not None:
            img = self.transform(img)
        if not self.is_test and self.target_transform is not None:
            target = self.target_transform(target)
        if self.is_test:
            return img
        else:
            return img, target

    def __len__(self):
        return len(self.imgs)
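
# Usage sketch (hypothetical paths and transform; in practice pass a
# torchvision-style transform so PIL images become tensors before batching):
#   train_set = ImageFolder('data/train', [('img0.png', 0), ('img1.png', 1)],
#                           transform=to_tensor)  # to_tensor is hypothetical
#   train_loader = data.DataLoader(train_set, batch_size=2)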


def check_acc(output, target, topk=(1,)):
    """Compute top-k accuracies (in percent) of output logits against target labels"""
    if isinstance(output, tuple):
        output = output[0]
    maxk = max(topk)
    _, pred = output.topk(maxk, 1)
    res = []
    for k in topk:
        acc = (pred.eq(target.contiguous().view(-1,1).expand(pred.size()))[:, :k]
               .float().contiguous().view(-1).sum(0))
        acc.mul_(100 / target.size(0))
        res.append(acc)
    return res
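
# Minimal sketch (hypothetical logits for 3 samples and 4 classes):
#   logits = torch.randn(3, 4)
#   labels = torch.tensor([0, 1, 2])
#   top1, top3 = check_acc(logits, labels, topk=(1, 3))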


### Mainly developed for TCGA data analysis
def select_samples(mat, aliquot_ids, feature_ids, patient_clinical=None, clinical_variable='PFI',
                   sample_type='01', drop_duplicates=True, remove_na=True):
    """Select samples with the given sample_type ('01');
    if drop_duplicates is True (default), remove technical duplicates;
    if remove_na is True (default), remove features that have NA values;
    if patient_clinical is not None, further filter out samples whose clinical_variable is NA
    """
    mat = pandas.DataFrame(mat, columns=feature_ids)  # use pandas to drop NA
    # Select samples with sample_type (='01')
    idx = np.array([[i, s[:12]] for i, s in enumerate(aliquot_ids) if s[13:15] == sample_type])
    # Remove technical duplicates
    if drop_duplicates:
        idx = pandas.DataFrame(idx).drop_duplicates(subset=[1]).values
    mat = mat.iloc[idx[:,0].astype(int)]
    aliquot_ids = aliquot_ids[idx[:,0].astype(int)]
    if remove_na:
        # Remove features that have NA values
        mat = mat.dropna(axis=1)
        feature_ids = mat.columns.values
    mat = mat.values
    if patient_clinical is not None:
        idx = [s[:12] in patient_clinical and not np.isnan(patient_clinical[s[:12]][clinical_variable])
               for s in aliquot_ids]
        mat = mat[idx]
        aliquot_ids = aliquot_ids[idx]
    return mat, aliquot_ids, feature_ids
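
# Minimal sketch (hypothetical TCGA-style aliquot barcodes; the second sample
# has type '11' (normal tissue) and is filtered out):
#   aliquots = np.array(['TCGA-02-0001-01A-01D-0001-01',
#                        'TCGA-02-0002-11A-01D-0002-01'])
#   mat, aliquots, feats = select_samples(np.random.rand(2, 3), aliquots, ['f1', 'f2', 'f3'])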


def get_feature_feature_mat(feature_ids, gene_ids, feature_gene_adj, gene_gene_adj,
                            max_score=1000):
    """Calculate the feature-feature interaction matrix based on the mapping from
    features to genes and on gene-gene interactions:
        feature_feature = feature_gene * gene_gene * feature_gene^T (transpose)

    Args:
        feature_ids: np.array([feature_names]), dict {id: feature_name}, or {feature_name: id}
        gene_ids: np.array([gene_names]), dict {id: gene_name}, or {gene_name: id}
        feature_gene_adj: np.array([[feature_name, gene_name, score]])
            with rows corresponding to features and columns genes;
            or (deprecated) a list (one entry per gene) of lists of feature ids.
            Note this is different from the np.array input; len(feature_gene_adj) == len(gene_ids)
        gene_gene_adj: an np.array; each row is (gene_name1, gene_name2, score)
        max_score: default 1000; normalize confidence scores in gene_gene_adj to be in [0, 1]

    Returns:
        feature_feature_mat: np.array of shape (len(feature_ids), len(feature_ids))
    """
    def check_input_ids(ids):
        if isinstance(ids, (np.ndarray, list)):
            ids = {v: i for i, v in enumerate(ids)}  # map names to indices starting from 0
        elif isinstance(ids, dict):
            if sorted(ids) == list(range(len(ids))):
                # the dict follows the format {id: name}; invert it to {name: id}
                ids = {v: k for k, v in ids.items()}
        else:
            raise ValueError(f'The input ids should be a list/np.ndarray/dictionary, '
                             f'but is {type(ids)}')
        return ids

    feature_ids = check_input_ids(feature_ids)
    gene_ids = check_input_ids(gene_ids)

    idx = []
    if isinstance(feature_gene_adj, list):  # deprecated list-of-lists input
        for i, v in enumerate(feature_gene_adj):
            for j in v:
                idx.append([j, i, 1])
    elif isinstance(feature_gene_adj, np.ndarray) and feature_gene_adj.shape[1] == 3:
        for v in feature_gene_adj:
            if v[0] in feature_ids and v[1] in gene_ids:
                idx.append([feature_ids[v[0]], gene_ids[v[1]], float(v[2])])
    else:
        raise ValueError('feature_gene_adj should be an np.ndarray of shape (N, 3) '
                         'or a list of lists (deprecated).')
    idx = np.array(idx).T
    feature_gene_mat = torch.sparse.FloatTensor(torch.tensor(idx[:2]).long(),
                                                torch.tensor(idx[2]).float(),
                                                (len(feature_ids), len(gene_ids)))
    # Extract a subnetwork from gene_gene_adj;
    # assume there is no self-loop in gene_gene_adj
    # and it contains two records for each undirected edge
    idx = []
    for v in gene_gene_adj:
        if v[0] in gene_ids and v[1] in gene_ids:
            idx.append([gene_ids[v[0]], gene_ids[v[1]], v[2]/max_score])
    # Add self-loops
    for i in range(len(gene_ids)):
        idx.append([i, i, 1.])
    idx = np.array(idx).T
    gene_gene_mat = torch.sparse.FloatTensor(torch.tensor(idx[:2]).long(),
                                             torch.tensor(idx[2]).float(),
                                             (len(gene_ids), len(gene_ids)))
    feature_feature_mat = feature_gene_mat.mm(gene_gene_mat.mm(feature_gene_mat.to_dense().t()))
    return feature_feature_mat.numpy()
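
# Minimal sketch with tiny (hypothetical) inputs:
#   ff = get_feature_feature_mat(
#       np.array(['f1', 'f2']), np.array(['g1', 'g2']),
#       feature_gene_adj=np.array([['f1', 'g1', 1.0], ['f2', 'g2', 1.0]]),
#       gene_gene_adj=np.array([['g1', 'g2', 500], ['g2', 'g1', 500]], dtype=object))
#   # ff has shape (2, 2); off-diagonal entries reflect the g1-g2 edge (500/1000)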


def get_overlap_samples(sample_lists, common_list=None, start=0, end=12, return_common_list=False):
    """Given a list of aliquot_id lists, find the common sample ids

    Args:
        sample_lists: an iterable of sample (aliquot) id lists
        common_list: if None (default), find the intersection of sample_lists;
            if provided, it should not be a set, because iteration order over a set can differ between runs
        start: default 0; assume sample ids are strings;
            when finding overlapping samples, only consider the character range [start, end)
        end: default 12, for TCGA BCR barcodes
        return_common_list: if True, return a set containing the common list, for backward
            compatibility; returning a sorted common list would be a better option

    Returns:
        np.array of shape (len(sample_lists), len(common_list))
    """
    sample_lists = [[s_id[start:end] for s_id in sample_list] for sample_list in sample_lists]
    if common_list is None:
        common_list = functools.reduce(lambda x, y: set(x).intersection(y), sample_lists)
        if return_common_list:
            return common_list
    common_list = sorted(common_list)  # iteration order over a set can differ between runs
    for s in sample_lists:  # make sure every list in sample_lists contains all elements of common_list
        assert len(set(common_list).difference(s)) == 0
    idx_lists = np.array([[sample_list.index(s_id) for s_id in common_list]
                          for sample_list in sample_lists])
    return idx_lists
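
# Minimal sketch (hypothetical barcodes; only the first 12 characters are compared):
#   lists = [['TCGA-02-0001-01A', 'TCGA-02-0002-01A'],
#            ['TCGA-02-0002-01A', 'TCGA-02-0001-01A']]
#   idx_lists = get_overlap_samples(lists)
#   # row i gives, for each common id (sorted), its position in lists[i]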


# Select samples whose target variable(s) are in the clinical file
def filter_clinical_dict(target_variable, target_variable_type, target_variable_range,
                         clinical_dict):
    """Select patients with the given target variable, its type, and its range in the clinical data.
    To save computation time, assume all target variable names are in clinical_dict without verification.

    Args:
        target_variable: str or a list of strings
        target_variable_type: 'discrete' or 'continuous', or a list of 'discrete'/'continuous'
        target_variable_range: a list of values; for the 'continuous' type, it is
            [lower_bound, upper_bound]; or a list of lists;
            target_variable, target_variable_type, target_variable_range must match
        clinical_dict: a dictionary of dictionaries;
            first-level keys: patient ids, second-level keys: variable names

    Returns:
        clinical_dict: newly constructed clinical_dict with all patients having target_variable(s)

    Examples:
        target_variable = ['PFI', 'OS.time']
        target_variable_type = ['discrete', 'continuous']
        target_variable_range = [[0, 1], [0, float('Inf')]]
        clinical_dict = filter_clinical_dict(target_variable, target_variable_type, target_variable_range,
                                             patient_clinical)
        assert sorted([k for k, v in patient_clinical.items()
                       if v['PFI'] in [0, 1] and not np.isnan(v['OS.time'])]) == sorted(clinical_dict.keys())
    """
    if isinstance(target_variable, str):
        if target_variable_type == 'discrete':
            clinical_dict = {p: v for p, v in clinical_dict.items()
                             if v[target_variable] in target_variable_range}
        elif target_variable_type == 'continuous':
            clinical_dict = {p: v for p, v in clinical_dict.items()
                             if v[target_variable] >= target_variable_range[0]
                             and v[target_variable] <= target_variable_range[1]}
    elif isinstance(target_variable, (list, tuple)):
        # Recursively filter on one target variable at a time
        for tar_var, tar_var_type, tar_var_range in zip(target_variable, target_variable_type,
                                                        target_variable_range):
            clinical_dict = filter_clinical_dict(tar_var, tar_var_type, tar_var_range, clinical_dict)

    return clinical_dict


def get_target_variable(target_variable, clinical_dict, sel_patient_ids):
    """Extract target_variable from clinical_dict for sel_patient_ids.
    If target_variable is a single str, this takes only one line of code;
    if target_variable is a list, recursively call itself and return a list of target variables.

    Assume all sel_patient_ids have target_variable in clinical_dict
    """
    if isinstance(target_variable, str):
        return [clinical_dict[s][target_variable] for s in sel_patient_ids]
    elif isinstance(target_variable, (list, tuple)):  # was (list, str); str is handled above
        return [[clinical_dict[s][tar_var] for s in sel_patient_ids] for tar_var in target_variable]


def normalize_continuous_variable(y_targets, target_variable_type, transform=True, forced=False,
                                  threshold=10, rm_outlier=True, whis=1.5, only_positive=True, max_val=1):
    """Normalize continuous variable(s).
    If a variable is 'continuous', call normalization() from outlier.py

    Args:
        y_targets: an np.array or a list of np.arrays
        target_variable_type: a string, 'continuous' or 'discrete' (for 'discrete', do nothing
            but return the input), or a list of strings
        transform, forced, threshold, rm_outlier, whis, only_positive, max_val are all passed to normalization()
    """
    if isinstance(target_variable_type, str):
        if target_variable_type == 'continuous':
            y_targets = normalization(y_targets, transform=transform, forced=forced, threshold=threshold,
                                      rm_outlier=rm_outlier, whis=whis, only_positive=only_positive,
                                      max_val=max_val, diagonal=False, symmetric=False)
        return y_targets
    elif isinstance(target_variable_type, list):
        return [normalize_continuous_variable(y, var_type, transform=transform, forced=forced,
                                              threshold=threshold, rm_outlier=rm_outlier, whis=whis,
                                              only_positive=only_positive, max_val=max_val)
                for y, var_type in zip(y_targets, target_variable_type)]
    else:
        raise ValueError(f'target_variable_type should be a str or list of strs, but is {target_variable_type}')


def get_label_distribution(ys, check_num_cls=True):
    """Get label distributions for a list of label arrays

    Args:
        ys: an iterable (e.g., list) of labels (1-d numpy.array or torch.Tensor);
            the most common usage is get_label_distribution([y_train, y_val, y_test])
        check_num_cls: if True, ensure that all label arrays have the same number of classes,
            and print out the distributions

    Returns:
        label_probs: a list of label distributions (multinomial)
    """
    num_cls = 0
    label_probs = []
    for i, y in enumerate(ys):
        if len(y) > 0:
            label_prob = get_label_prob(y, verbose=False)
            label_probs.append(label_prob)
            if check_num_cls:
                if num_cls > 0:
                    assert num_cls == len(label_probs[-1]), f'{i}: {num_cls} != {len(label_probs[-1])}'
                else:
                    num_cls = len(label_probs[-1])
        else:
            label_probs.append([])
    if check_num_cls:
        # check the first element; label_probs itself is a list, never a torch.Tensor
        if isinstance(label_probs[0], torch.Tensor):
            print('label distribution:\n', torch.stack(label_probs, dim=1))
        else:
            print('label distribution:\n', np.stack(label_probs, axis=1))
    return label_probs
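
# Usage sketch (hypothetical splits; assumes .train.get_label_prob returns a
# per-class probability vector):
#   probs = get_label_distribution([np.array([0, 0, 1]), np.array([0, 1])])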


def get_shuffled_data(sel_patient_ids, clinical_dict, cv_type, instance_portions, group_sizes,
                      group_variable_name, seed=None, verbose=True):
    """Shuffle sel_patient_ids and split them into multiple splits,
    in most cases train, val and test sets

    Args:
        sel_patient_ids: a list of object (patient) ids
        clinical_dict: a dictionary of dictionaries;
            first-level keys: object ids; second-level keys: attribute names
        cv_type: either 'group-shuffle' or 'instance-shuffle';
            if 'group-shuffle', split groups into train, val and test sets according to
            group_sizes, or implicitly instance_portions;
            if 'instance-shuffle', split based on instance_portions
        instance_portions: a list of floats; the proportion of samples in each split;
            not used when cv_type == 'group-shuffle' and group_sizes is given
        group_sizes: the number of groups in each split; only used when cv_type == 'group-shuffle'
        group_variable_name: the attribute name carrying the group information

    Returns:
        sel_patient_ids: shuffled object ids
        idx_splits: a list of index lists, e.g., [train_idx, val_idx, test_idx];
            sel_patient_ids[train_idx] gives the patient ids for training
    """
    np.random.seed(seed)
    sel_patient_ids = np.random.permutation(sel_patient_ids)
    num_samples = len(sel_patient_ids)
    idx_splits = []
    if cv_type == 'group-shuffle':
        # For the TCGA project, disease types were used as groups; hence the variable name
        disease_types = sorted({clinical_dict[s][group_variable_name] for s in sel_patient_ids})
        num_disease_types = len(disease_types)
        np.random.shuffle(disease_types)
        type_splits = []
        cnt = 0
        for i in range(len(group_sizes)-1):
            if group_sizes[i] < 0:
                # use instance_portions as group portions
                assert sum(instance_portions) == 1
                group_sizes[i] = round(instance_portions[i] * num_disease_types)
            type_splits.append(disease_types[cnt:cnt+group_sizes[i]])
            cnt = cnt + group_sizes[i]
            # use j below because i is already the loop variable
            idx_splits.append([j for j, s in enumerate(sel_patient_ids)
                               if clinical_dict[s][group_variable_name] in type_splits[i]])
        # Process the last split
        if group_sizes[-1] >= 0:  # most of the time, set group_sizes[-1] = num_test_types = -1
            # almost never set group_sizes[-1] = 0, which would be useless
            assert group_sizes[-1] == num_disease_types - sum(group_sizes[:-1])
        if cnt == len(disease_types):
            print('The last group is empty, thus not included')
        else:
            type_splits.append(disease_types[cnt:])
            idx_splits.append([i for i, s in enumerate(sel_patient_ids)
                               if clinical_dict[s][group_variable_name] in type_splits[-1]])
    elif cv_type == 'instance-shuffle':
        # sel_patient_ids has already been shuffled, so there is no need to shuffle indices
        cnt = 0
        assert sum(instance_portions) == 1
        for i in range(len(instance_portions)-1):
            n = round(instance_portions[i]*num_samples)
            idx_splits.append(list(range(cnt, cnt+n)))
            cnt = cnt + n
        # Process the last split
        if cnt == num_samples:
            # this can rarely happen
            print('The last split is empty, thus not included')
        else:
            idx_splits.append(list(range(cnt, num_samples)))

    def get_type_cnt_msg(p_ids):
        """For a list of p_ids, prepare group statistics for printing"""
        cnt_dict = dict(collections.Counter([clinical_dict[p_id][group_variable_name]
                                             for p_id in p_ids]))
        return f'{len(cnt_dict)} groups: {cnt_dict}'

    if verbose:
        msg = f'{cv_type}: \n'
        msg += '\n'.join([f'split {i}: {len(v)} samples ({len(v)/num_samples:.2f}), '
                          f'{get_type_cnt_msg(sel_patient_ids[v])}'
                          for i, v in enumerate(idx_splits)])
        print(msg)
    return sel_patient_ids, idx_splits
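
# Usage sketch (hypothetical patient ids and clinical records):
#   patient_clinical = {f'P{i}': {'disease_type': 'LUAD'} for i in range(8)}
#   ids, idx_splits = get_shuffled_data(list(patient_clinical), patient_clinical,
#                                       'instance-shuffle', instance_portions=[0.5, 0.25, 0.25],
#                                       group_sizes=None, group_variable_name='disease_type', seed=0)
#   train_ids = ids[idx_splits[0]]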


def target_to_numpy(y_targets, target_variable_type, target_variable_range):
    """y_targets is a list or a list of lists; transform it to numpy array(s).
    For a discrete variable, generate numerical class labels starting from 0;
    for a continuous variable, simply call np.array(y_targets);
    use recursion to handle a list of target variables

    Args:
        y_targets: a list of objects (strings/numbers, must be comparable) or lists
        target_variable_type: a string or a list of strings ('discrete' or 'continuous')
        target_variable_range: only used as a sanity check for discrete variables

    Returns:
        y_true: a numpy array or a list of numpy arrays of type either float or int
    """
    if isinstance(target_variable_type, str):
        y_true = np.array(y_targets)
        if target_variable_type == 'discrete':
            unique_cls = np.unique(y_true)
            num_cls = len(unique_cls)
            if sorted(unique_cls) != sorted(target_variable_range):
                print(f'unique_cls: {unique_cls} !=\ntarget_variable_range {target_variable_range}')
            cls_idx_dict = {p.item(): i for i, p in enumerate(sorted(unique_cls))}
            y_true = np.array([cls_idx_dict[i.item()] for i in y_true])
            print(f'Changed class labels for the model: {cls_idx_dict}')
    elif isinstance(target_variable_type, (list, tuple)):
        y_true = [target_to_numpy(y, tar_var_type, tar_var_range)
                  for y, tar_var_type, tar_var_range in
                  zip(y_targets, target_variable_type, target_variable_range)]
    else:
        raise ValueError(f'target_variable_type must be str, list or tuple, '
                         f'but is {type(target_variable_type)}')
    return y_true
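
# Minimal sketch (hypothetical labels):
#   y = target_to_numpy(['cat', 'dog', 'cat'], 'discrete', ['cat', 'dog'])
#   # y == np.array([0, 1, 0])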


def get_mi_acc(xs, y_true, var_names, var_name_length=35):
    """Get mutual information (MI), adjusted MI, and the maximal accuracy of a Bayes classifier
    for a list of discrete predictors xs and target y_true.
    Calculate MI, Adj_MI, and Bayes_ACC for all combinations of xs

    Args:
        xs: a list of tensors or numpy arrays
        y_true: a tensor or numpy array

    Returns:
        a list of dictionaries, with each key being a variable name
        (or, for combinations, the joined variable indices)
    """
    if isinstance(xs[0], torch.Tensor):
        xs = [x.cpu().detach().numpy() for x in xs]
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.cpu().detach().numpy()
    result = []
    print('{:^{var_name_length}}\t{:^5}\t{:^6}\t{:^9}'.format('Variable', 'MI', 'Adj_MI', 'Bayes_ACC',
                                                              var_name_length=var_name_length))
    for l in itertools.chain.from_iterable(itertools.combinations(range(len(xs)), r)
                                           for r in range(1, 1+len(xs))):
        if len(l) == 1:
            new_x = xs[l[0]]
            name = var_names[l[0]]
        else:  # len(l) > 1: combine the selected predictors into tuples
            new_x = [tuple([v.item() for v in s]) for s in zip(*[xs[j] for j in l])]
            new_x = discrete_to_id(new_x, complex_object=True)[0]
            name = '-'.join(map(str, l))
        mi = sklearn.metrics.mutual_info_score(y_true, new_x)
        adj_mi = sklearn.metrics.adjusted_mutual_info_score(y_true, new_x)
        # A Bayes classifier predicts the majority class of y_true for each value of new_x
        bayes_acc = (sklearn.metrics.confusion_matrix(y_true, new_x).max(axis=0).sum() / len(y_true))
        # use the plain name as the key; previously the key was the padded print string
        result.append({name: [mi, adj_mi, bayes_acc]})
        print(f'{name:^{var_name_length}}\t{mi:^5.3f}\t{adj_mi:^6.3f}\t{bayes_acc:^9.3f}')
    return result

# Exploratory snippet kept for reference: KL divergence between the joint
# distribution of (y_true, new_x) and the product of their marginals (related to MI):
# p1 = sklearn.metrics.confusion_matrix(y_true.numpy(), new_x)[:2].reshape(-1)
# p2 = (np.bincount(y_true.numpy())[:,None] * np.bincount(new_x)).reshape(-1)
# p = torch.distributions.categorical.Categorical(torch.tensor(p1, dtype=torch.float))
# q = torch.distributions.categorical.Categorical(torch.tensor(p2, dtype=torch.float))
# torch.distributions.kl.kl_divergence(p, q)
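
# Minimal sketch for get_mi_acc (hypothetical cluster assignments vs. labels):
#   xs = [np.array([0, 0, 1, 1]), np.array([0, 1, 0, 1])]
#   y = np.array([0, 0, 1, 1])
#   res = get_mi_acc(xs, y, var_names=['clusterA', 'clusterB'])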