Diff of /dl/utils/utils.py [000000] .. [4807fa]

import os
import functools
import itertools
import collections
import numpy as np
import pandas
from PIL import Image
import sklearn.metrics

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data

from .outlier import normalization
from .train import get_label_prob

def discrete_to_id(targets, start=0, sort=True, complex_object=False):
  """Map discrete target values to consecutive integer ids

  Args:
    targets: 1-d torch.Tensor, np.ndarray, or list
    start: the id assigned to the first unique value
    sort: sort the unique values so that 'smaller' values get smaller ids
    complex_object: input elements are not numeric but hashable objects, e.g., tuples

  Returns:
    target_ids: torch.Tensor or np.ndarray with integer ids starting from start (default 0)
    cls_id_dict: a dictionary mapping original values to their integer ids

  """
  is_tensor = isinstance(targets, torch.Tensor)
  original = targets
  if complex_object:
    unique_targets = sorted(collections.Counter(targets))
  else:
    if is_tensor:
      targets = targets.cpu().detach().numpy()
    else:
      targets = np.array(targets) # a no-op if targets is already an np.ndarray
    unique_targets = np.unique(targets)
    if sort: # np.unique already returns sorted values, so this is a harmless no-op
      unique_targets = np.sort(unique_targets)
  cls_id_dict = {v: i+start for i, v in enumerate(unique_targets)}
  target_ids = np.array([cls_id_dict[v] for v in targets])
  if is_tensor:
    # keep the dtype/device of the input tensor
    target_ids = original.new_tensor(target_ids)
  return target_ids, cls_id_dict

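# A minimal usage sketch for discrete_to_id (hypothetical values, not from the
# original module):
#   ids, mapping = discrete_to_id(['b', 'a', 'b', 'c'])
#   # mapping == {'a': 0, 'b': 1, 'c': 2}; ids == np.array([1, 0, 1, 2])
#   ids, mapping = discrete_to_id([('x', 1), ('y', 2), ('x', 1)], complex_object=True)
#   # tuples are treated as atomic labels: ids == np.array([0, 1, 0])
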
def get_f1_score(m, average='weighted', verbose=False):
  """Given a confusion matrix for binary classification,
    calculate accuracy, precision, recall, and F1 score

  Args:
    m: 2x2 confusion matrix with m[i, j] the number of samples of true class i
      predicted as class j
    average: if 'weighted', calculate metrics for each label, then average them
      weighted by support; if 'macro', take the unweighted mean over labels;
      see http://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
    verbose: if True, print per-class results
  """
  def cal_f1(precision, recall):
    if precision + recall == 0:
      print('Both precision and recall are zero')
      return 0
    return 2*precision*recall / (precision+recall)
  m = np.array(m)
  t0 = m[0,0] + m[0,1]  # support of class 0
  t1 = m[1,0] + m[1,1]  # support of class 1
  p0 = m[0,0] + m[1,0]  # number of class-0 predictions
  p1 = m[0,1] + m[1,1]  # number of class-1 predictions
  prec0 = m[0,0] / p0
  prec1 = m[1,1] / p1
  recall0 = m[0,0] / t0
  recall1 = m[1,1] / t1
  f1_0 = cal_f1(prec0, recall0)
  f1_1 = cal_f1(prec1, recall1)
  if average == 'macro':
    w0 = 0.5
    w1 = 0.5
  elif average == 'weighted':
    w0 = t0 / (t0+t1)
    w1 = t1 / (t0+t1)
  else:
    raise ValueError(f"average must be 'macro' or 'weighted', but is {average}")
  prec = prec0*w0 + prec1*w1
  recall = recall0*w0 + recall1*w1
  f1 = f1_0*w0 + f1_1*w1
  acc = (m[0,0] + m[1,1]) / (t0+t1)
  if verbose:
    print(f'prec0={prec0}, recall0={recall0}, f1_0={f1_0}\n'
          f'prec1={prec1}, recall1={recall1}, f1_1={f1_1}')
  return acc, prec, recall, f1

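# A minimal sketch (hypothetical labels) cross-checking get_f1_score against
# sklearn on the confusion matrix of y_true vs y_pred:
#   y_true = [0, 0, 1, 1, 1]
#   y_pred = [0, 1, 1, 1, 0]
#   m = sklearn.metrics.confusion_matrix(y_true, y_pred)  # [[1, 1], [1, 2]]
#   acc, prec, recall, f1 = get_f1_score(m, average='weighted')
#   # f1 should match sklearn.metrics.f1_score(y_true, y_pred, average='weighted')
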
def dist(params1, params2=None, dist_fn=torch.norm): # pylint: disable=no-member
    """Calculate the norm of params1, or the distance between params1 and params2;
        a common usage is calculating the distance between two model state_dicts.

    Args:
        params1: a dictionary with each value a torch.Tensor
        params2: if not None, must have the same structure (keys, dtypes and shapes) as params1
    """
    if params2 is None:
        return dist_fn(torch.Tensor([dist_fn(params1[k]) for k in params1]))
    d = torch.Tensor([dist_fn(params1[k] - params2[k]) for k in params1])
    return dist_fn(d)

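# A minimal sketch (hypothetical models) of measuring how far two sets of
# weights are apart:
#   net1, net2 = nn.Linear(4, 2), nn.Linear(4, 2)
#   d = dist(net1.state_dict(), net2.state_dict())  # scalar tensor
#   n = dist(net1.state_dict())                     # overall norm of net1's weights
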
class AverageMeter(object):
    """Track the most recent value, running sum, count, and running average"""
    def __init__(self):
        self._reset()

    def _reset(self):
        self.val = 0
        self.sum = 0
        self.cnt = 0
        self.avg = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt

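# A minimal sketch: tracking an average loss over batches of unequal size
# (hypothetical numbers):
#   meter = AverageMeter()
#   meter.update(0.9, n=32)
#   meter.update(0.7, n=16)
#   # meter.avg == (0.9*32 + 0.7*16) / 48
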
def pil_loader(path, format='RGB'):
    """Load an image with PIL and convert it to the given mode (default 'RGB')"""
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert(format)

class ImageFolder(data.Dataset):
    """A simple dataset over a list of image files (and optionally targets) under root

    Args:
        imgs: a list of image file names relative to root if is_test is True,
            otherwise a list of (image file name, target) pairs
    """
    def __init__(self, root, imgs, transform=None, target_transform=None,
                 loader=pil_loader, is_test=False):
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.is_test = is_test

    def __getitem__(self, idx):
        if self.is_test:
            img = self.imgs[idx]
        else:
            img, target = self.imgs[idx]
        img = self.loader(os.path.join(self.root, img))
        if self.transform is not None:
            img = self.transform(img)
        if not self.is_test and self.target_transform is not None:
            target = self.target_transform(target)
        if self.is_test:
            return img
        else:
            return img, target

    def __len__(self):
        return len(self.imgs)

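# A minimal sketch of wrapping ImageFolder in a DataLoader; the file names and
# labels are hypothetical, and torchvision is assumed to be available:
#   import torchvision.transforms as T
#   dataset = ImageFolder('data/train', [('cat1.jpg', 0), ('dog1.jpg', 1)],
#                         transform=T.Compose([T.Resize((224, 224)), T.ToTensor()]))
#   loader = data.DataLoader(dataset, batch_size=2, shuffle=True)
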
def check_acc(output, target, topk=(1,)):
    """Compute top-k accuracies (in percent) of output logits against target labels"""
    if isinstance(output, tuple):
        output = output[0]
    maxk = max(topk)
    _, pred = output.topk(maxk, 1)  # indices of the maxk largest logits per sample
    res = []
    for k in topk:
        acc = (pred.eq(target.contiguous().view(-1,1).expand(pred.size()))[:, :k]
               .float().contiguous().view(-1).sum(0))
        acc.mul_(100 / target.size(0))
        res.append(acc)
    return res

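# A minimal sketch (hypothetical logits) of top-1 and top-2 accuracy:
#   output = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
#   target = torch.tensor([1, 0, 0])
#   top1, top2 = check_acc(output, target, topk=(1, 2))
#   # top1 == 66.67 (2 of 3 correct); top2 == 100.0
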
### Mainly developed for TCGA data analysis
def select_samples(mat, aliquot_ids, feature_ids, patient_clinical=None, clinical_variable='PFI',
                   sample_type='01', drop_duplicates=True, remove_na=True):
  """Select samples with the given sample_type (default '01', i.e., primary tumor);
     if drop_duplicates is True (default), remove technical replicates;
     if remove_na is True (default), remove features that have NA values;
     if patient_clinical is not None, further filter out samples whose clinical_variable is NA
  """
  mat = pandas.DataFrame(mat, columns=feature_ids) # use pandas to drop NA
  # Select samples with the given sample_type; in a TCGA aliquot barcode,
  # characters [0, 12) identify the patient and [13, 15) the sample type
  idx = np.array([[i,s[:12]] for i, s in enumerate(aliquot_ids) if s[13:15]==sample_type])
  # Remove technical replicates: keep one aliquot per patient
  if drop_duplicates:
    idx = pandas.DataFrame(idx).drop_duplicates(subset=[1]).values
  mat = mat.iloc[idx[:,0].astype(int)]
  aliquot_ids = aliquot_ids[idx[:,0].astype(int)]
  if remove_na:
    # Remove features that have NA values
    mat = mat.dropna(axis=1)
    feature_ids = mat.columns.values
  mat = mat.values
  if patient_clinical is not None:
    idx = [s[:12] in patient_clinical and not np.isnan(patient_clinical[s[:12]][clinical_variable])
           for s in aliquot_ids]
    mat = mat[idx]
    aliquot_ids = aliquot_ids[idx]
  return mat, aliquot_ids, feature_ids

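# A minimal sketch with hypothetical TCGA-style aliquot barcodes
# (patient id = characters [0, 12), sample type = characters [13, 15)):
#   aliquot_ids = np.array(['TCGA-AB-0001-01A-11R', 'TCGA-AB-0001-01B-11R',
#                           'TCGA-AB-0002-11A-11R'])
#   mat = np.array([[1., np.nan], [2., 0.5], [3., 0.7]])
#   m, ids, feats = select_samples(mat, aliquot_ids, ['f1', 'f2'])
#   # keeps the first '01' aliquot of TCGA-AB-0001, drops the '11' (normal) sample,
#   # and drops feature f2 because it contains NA in the selected rows
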
def get_feature_feature_mat(feature_ids, gene_ids, feature_gene_adj, gene_gene_adj,
                            max_score=1000):
  """Calculate a feature-feature interaction matrix from feature-gene mappings
    and gene-gene interactions:
    feature_feature = feature_gene * gene_gene * feature_gene^T (transpose)

  Args:
    feature_ids: np.array([feature_names]), dict {id: feature_name}, or {feature_name: id}
    gene_ids: np.array([gene_names]), dict {id: gene_name}, or {gene_name: id}
    feature_gene_adj: an edge list, i.e., an np.array with rows (feature_name, gene_name, score);
      or (deprecated) a list over genes of lists of feature ids,
      with len(feature_gene_adj) == len(gene_ids)
    gene_gene_adj: an np.array with rows (gene_name1, gene_name2, score)
    max_score: default 1000; divide confidence scores in gene_gene_adj by max_score
      so that they fall in [0, 1]

  Returns:
    feature_feature_mat: np.array of shape (len(feature_ids), len(feature_ids))

  """
  def check_input_ids(ids):
    if isinstance(ids, (np.ndarray, list)):
      ids = {v: i for i, v in enumerate(ids)} # map names to indices starting from 0
    elif isinstance(ids, dict):
      if sorted(ids) == list(range(len(ids))):
        # the dict follows the format {id: name}; invert it to {name: id}
        ids = {v: k for k, v in ids.items()}
    else:
      raise ValueError(f'The input ids should be a list/np.ndarray/dictionary, '
                       f'but is {type(ids)}')
    return ids
  feature_ids = check_input_ids(feature_ids)
  gene_ids = check_input_ids(gene_ids)

  idx = []
  if isinstance(feature_gene_adj, list): # deprecated list-of-lists input
    for i, v in enumerate(feature_gene_adj):
      for j in v:
        idx.append([j, i, 1])
  elif isinstance(feature_gene_adj, np.ndarray) and feature_gene_adj.shape[1] == 3:
    for v in feature_gene_adj:
      if v[0] in feature_ids and v[1] in gene_ids:
        idx.append([feature_ids[v[0]], gene_ids[v[1]], float(v[2])])
  else:
    raise ValueError('feature_gene_adj should be an np.ndarray of shape (N, 3) '
                     'or a list of lists (deprecated).')
  idx = np.array(idx).T
  # torch.sparse.FloatTensor is the legacy sparse COO constructor
  # (torch.sparse_coo_tensor in newer PyTorch)
  feature_gene_mat = torch.sparse.FloatTensor(torch.tensor(idx[:2]).long(),
                                              torch.tensor(idx[2]).float(),
                                              (len(feature_ids), len(gene_ids)))
  # Extract a subnetwork from gene_gene_adj;
  # assume gene_gene_adj has no self-loops
  # and contains two records for each undirected edge
  idx = []
  for v in gene_gene_adj:
    if v[0] in gene_ids and v[1] in gene_ids:
      idx.append([gene_ids[v[0]], gene_ids[v[1]], v[2]/max_score])
  # Add self-loops
  for i in range(len(gene_ids)):
    idx.append([i, i, 1.])
  idx = np.array(idx).T
  gene_gene_mat = torch.sparse.FloatTensor(torch.tensor(idx[:2]).long(),
                                           torch.tensor(idx[2]).float(),
                                           (len(gene_ids), len(gene_ids)))
  feature_feature_mat = feature_gene_mat.mm(gene_gene_mat.mm(feature_gene_mat.to_dense().t()))
  return feature_feature_mat.numpy()

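# A minimal sketch with two hypothetical features mapped one-to-one to two genes
# that interact with score 500 (recorded in both directions, as the function assumes):
#   feature_gene_adj = np.array([['f1', 'g1', '1'], ['f2', 'g2', '1']])
#   gene_gene_adj = np.array([['g1', 'g2', 500.], ['g2', 'g1', 500.]], dtype=object)
#   ff = get_feature_feature_mat(np.array(['f1', 'f2']), np.array(['g1', 'g2']),
#                                feature_gene_adj, gene_gene_adj)
#   # with self-loops added, ff == [[1.0, 0.5], [0.5, 1.0]]
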
def get_overlap_samples(sample_lists, common_list=None, start=0, end=12, return_common_list=False):
  """Given a list of aliquot_id lists, find the indices of the common sample ids

  Args:
    sample_lists: an iterable of sample (aliquot) id lists
    common_list: if None (default), use the intersection of sample_lists;
      if provided, it should not be a set, because iteration order over a set
      can differ across runs
    start: default 0; assume sample ids are strings;
      when finding overlapping samples, only consider the substring [start, end)
    end: default 12, the patient part of a TCGA BCR barcode
    return_common_list: if True, return the set of common ids for backward compatibility
      (returning a sorted list would be a better option)

  Returns:
    np.array of shape (len(sample_lists), len(common_list)): for each sample list,
    the indices of the common ids in that list
  """
  sample_lists = [[s_id[start:end] for s_id in sample_list] for sample_list in sample_lists]
  if common_list is None:
    common_list = functools.reduce(lambda x,y: set(x).intersection(y), sample_lists)
    if return_common_list:
      return common_list
    common_list = sorted(common_list) # iteration order over a set can vary across runs
  for s in sample_lists: # make sure every list in sample_lists contains all elements of common_list
    assert len(set(common_list).difference(s)) == 0
  idx_lists = np.array([[sample_list.index(s_id) for s_id in common_list]
                        for sample_list in sample_lists])
  return idx_lists

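# A minimal sketch (hypothetical ids; end=1 so the whole one-character ids are compared):
#   idx = get_overlap_samples([['a', 'b', 'c'], ['c', 'a']], end=1)
#   # common ids are ['a', 'c']; idx == [[0, 2], [1, 0]]
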
# Select patients whose target variable(s) are present in the clinical file
def filter_clinical_dict(target_variable, target_variable_type, target_variable_range,
                         clinical_dict):
  """Select patients whose target variable has the given type and range in clinical data
  To save computation time, assume all target variable names are in clinical_dict
  without verification

  Args:
    target_variable: str or a list of strings
    target_variable_type: 'discrete' or 'continuous', or a list of such strings
    target_variable_range: a list of admissible values for a 'discrete' variable;
      for a 'continuous' variable, [lower_bound, upper_bound]; or a list of such lists;
      target_variable, target_variable_type, and target_variable_range must match element-wise
    clinical_dict: a dictionary of dictionaries;
      first-level keys: patient ids, second-level keys: variable names

  Returns:
    clinical_dict: a newly constructed clinical_dict with all patients having target_variable(s)

  Examples:
    target_variable = ['PFI', 'OS.time']
    target_variable_type = ['discrete', 'continuous']
    target_variable_range = [[0, 1], [0, float('Inf')]]
    clinical_dict = filter_clinical_dict(target_variable, target_variable_type,
                                         target_variable_range, patient_clinical)
    assert sorted([k for k, v in patient_clinical.items()
                   if v['PFI'] in [0, 1] and not np.isnan(v['OS.time'])]) == \
      sorted(clinical_dict.keys())

  """
  if isinstance(target_variable, str):
    if target_variable_type == 'discrete':
      clinical_dict = {p:v for p, v in clinical_dict.items()
                       if v[target_variable] in target_variable_range}
    elif target_variable_type == 'continuous':
      clinical_dict = {p:v for p, v in clinical_dict.items()
                       if v[target_variable] >= target_variable_range[0]
                       and v[target_variable] <= target_variable_range[1]}

  elif isinstance(target_variable, (list, tuple)):
    # recursively filter on one target variable at a time
    for tar_var, tar_var_type, tar_var_range in zip(target_variable, target_variable_type,
                                                    target_variable_range):
      clinical_dict = filter_clinical_dict(tar_var, tar_var_type, tar_var_range, clinical_dict)

  return clinical_dict

def get_target_variable(target_variable, clinical_dict, sel_patient_ids):
  """Extract target_variable from clinical_dict for sel_patient_ids
  If target_variable is a single str, this is only one line of code;
  if target_variable is a list, recursively call itself and return a list of target variables

  Assume all sel_patient_ids have target_variable in clinical_dict

  """
  if isinstance(target_variable, str):
    return [clinical_dict[s][target_variable] for s in sel_patient_ids]
  elif isinstance(target_variable, (list, tuple)):
    return [[clinical_dict[s][tar_var] for s in sel_patient_ids] for tar_var in target_variable]

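# A minimal sketch (hypothetical patients):
#   clinical = {'p1': {'PFI': 0, 'OS.time': 100.}, 'p2': {'PFI': 1, 'OS.time': 50.}}
#   get_target_variable('PFI', clinical, ['p1', 'p2'])               # [0, 1]
#   get_target_variable(['PFI', 'OS.time'], clinical, ['p1', 'p2'])  # [[0, 1], [100., 50.]]
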
def normalize_continuous_variable(y_targets, target_variable_type, transform=True, forced=False,
                        threshold=10, rm_outlier=True, whis=1.5, only_positive=True, max_val=1):
  """Normalize continuous variable(s)
    If a variable is 'continuous', call normalization() from outlier.py;
    a 'discrete' variable is returned unchanged

  Args:
    y_targets: an np.array or a list of np.arrays
    target_variable_type: a string, 'continuous' or 'discrete', or a list of such strings
    transform, forced, threshold, rm_outlier, whis, only_positive, max_val are all
      passed through to normalization()

  """
  if isinstance(target_variable_type, str):
    if target_variable_type=='continuous':
      y_targets = normalization(y_targets, transform=transform, forced=forced, threshold=threshold,
                                rm_outlier=rm_outlier, whis=whis, only_positive=only_positive,
                                max_val=max_val, diagonal=False, symmetric=False)
    return y_targets
  elif isinstance(target_variable_type, list):
    return [normalize_continuous_variable(y, var_type, transform=transform, forced=forced,
            threshold=threshold, rm_outlier=rm_outlier, whis=whis, only_positive=only_positive,
            max_val=max_val) for y, var_type in zip(y_targets, target_variable_type)]
  else:
    raise ValueError(f'target_variable_type should be a str or list of strs, but is {target_variable_type}')

def get_label_distribution(ys, check_num_cls=True):
  """Get label distributions for a list of label arrays

  Args:
    ys: an iterable (e.g., list) of labels (1-d numpy.array or torch.Tensor);
      the most common usage is get_label_distribution([y_train, y_val, y_test])
    check_num_cls: if True, ensure that all label arrays have the same number of classes,
      and print the resulting distributions

  Returns:
    label_probs: a list of label distributions (multinomial)

  """
  num_cls = 0
  label_probs = []
  for i, y in enumerate(ys):
    if len(y)>0:
      label_prob = get_label_prob(y, verbose=False)
      label_probs.append(label_prob)
      if check_num_cls:
        if num_cls > 0:
          assert num_cls == len(label_probs[-1]), f'{i}: {num_cls} != {len(label_probs[-1])}'
        else:
          num_cls = len(label_probs[-1])
    else:
      label_probs.append([])
  if check_num_cls:
    if isinstance(label_probs[0], torch.Tensor):
      print('label distribution:\n', torch.stack(label_probs, dim=1))
    else:
      print('label distribution:\n', np.stack(label_probs, axis=1))
  return label_probs

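# A minimal sketch (hypothetical labels; get_label_prob comes from .train and is
# assumed to return a length-num_cls probability vector):
#   y_train = np.array([0, 0, 1, 1, 1])
#   y_val = np.array([0, 1])
#   probs = get_label_distribution([y_train, y_val])
#   # probs[0] should be close to [0.4, 0.6] and probs[1] to [0.5, 0.5]
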
def get_shuffled_data(sel_patient_ids, clinical_dict, cv_type, instance_portions, group_sizes,
                     group_variable_name, seed=None, verbose=True):
  """Shuffle sel_patient_ids and split them into multiple splits,
    in most cases train, val and test sets

  Args:
    sel_patient_ids: a list of object (patient) ids
    clinical_dict: a dictionary of dictionaries;
      first-level keys: object ids; second-level keys: attribute names
    cv_type: either 'group-shuffle' or 'instance-shuffle';
      if 'group-shuffle', split groups into train, val and test sets according to
      group_sizes or, implicitly, instance_portions;
      if 'instance-shuffle', split based on instance_portions
    instance_portions: a list of floats, the proportion of samples in each split;
      when cv_type == 'group-shuffle', only consulted for entries with group_sizes[i] < 0
    group_sizes: the number of groups in each split; only used when cv_type == 'group-shuffle'
    group_variable_name: the attribute name carrying the group information

  Returns:
    sel_patient_ids: shuffled object ids
    idx_splits: a list of index lists, e.g., [train_idx, val_idx, test_idx];
      sel_patient_ids[train_idx] gives the patient ids for training

  """
  np.random.seed(seed)
  sel_patient_ids = np.random.permutation(sel_patient_ids)
  num_samples = len(sel_patient_ids)
  idx_splits = []
  if cv_type == 'group-shuffle':
    # for the TCGA project, disease types were used as groups; hence the variable name
    disease_types = sorted({clinical_dict[s][group_variable_name] for s in sel_patient_ids})
    num_disease_types = len(disease_types)
    np.random.shuffle(disease_types)
    type_splits = []
    cnt = 0
    for i in range(len(group_sizes)-1):
      if group_sizes[i] < 0:
        # use instance_portions as group portions
        assert sum(instance_portions) == 1
        group_sizes[i] = round(instance_portions[i] * num_disease_types)
      type_splits.append(disease_types[cnt:cnt+group_sizes[i]])
      cnt = cnt+group_sizes[i]
      # use j as the inner loop variable because i is already taken
      idx_splits.append([j for j, s in enumerate(sel_patient_ids)
                         if clinical_dict[s][group_variable_name] in type_splits[i]])
    # process the last split
    if group_sizes[-1] >= 0: # most of the time, set group_sizes[-1] = num_test_types = -1
      # almost never set group_sizes[-1] = 0, which would be useless
      assert group_sizes[-1] == num_disease_types - sum(group_sizes[:-1])
    if cnt == len(disease_types):
      print('The last group is empty, thus not included')
    else:
      type_splits.append(disease_types[cnt:])
      idx_splits.append([i for i, s in enumerate(sel_patient_ids)
                          if clinical_dict[s][group_variable_name] in type_splits[-1]])
  elif cv_type == 'instance-shuffle':
    # because sel_patient_ids has already been shuffled, there is no need to shuffle indices
    cnt = 0
    assert sum(instance_portions) == 1
    for i in range(len(instance_portions)-1):
      n = round(instance_portions[i]*num_samples)
      idx_splits.append(list(range(cnt, cnt+n)))
      cnt = cnt + n
    # process the last split
    if cnt == num_samples:
      # this should rarely happen
      print('The last split is empty, thus not included')
    else:
      idx_splits.append(list(range(cnt, num_samples)))

  def get_type_cnt_msg(p_ids):
    """For a list of patient ids, prepare group statistics for printing"""
    cnt_dict = dict(collections.Counter([clinical_dict[p_id][group_variable_name]
                                       for p_id in p_ids]))
    return f'{len(cnt_dict)} groups: {cnt_dict}'

  if verbose:
    msg = f'{cv_type}: \n'
    msg += '\n'.join([f'split {i}: {len(v)} samples ({len(v)/num_samples:.2f}), '
                      f'{get_type_cnt_msg(sel_patient_ids[v])}'
                      for i, v in enumerate(idx_splits)])
    print(msg)
  return sel_patient_ids, idx_splits

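# A minimal sketch (hypothetical patients, 'type' as the group variable):
#   clinical = {f'p{i}': {'type': 'LUAD' if i < 6 else 'BRCA'} for i in range(10)}
#   ids, (train_idx, test_idx) = get_shuffled_data(
#       list(clinical), clinical, 'group-shuffle', instance_portions=[0.5, 0.5],
#       group_sizes=[1, -1], group_variable_name='type', seed=0, verbose=False)
#   # one disease type goes entirely to train, the other entirely to test
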
def target_to_numpy(y_targets, target_variable_type, target_variable_range):
  """Transform y_targets, a list or a list of lists, into numpy array(s)
  For a discrete variable, generate numerical class labels starting from 0;
  for a continuous variable, simply call np.array(y_targets);
  use recursion to handle a list of target variables

  Args:
    y_targets: a list of objects (strings/numbers, must be comparable) or a list of such lists
    target_variable_type: a string ('discrete' or 'continuous') or a list of such strings
    target_variable_range: only used as a sanity check for discrete variables

  Returns:
    y_true: a numpy array or a list of numpy arrays of type either float or int

  """
  if isinstance(target_variable_type, str):
    y_true = np.array(y_targets)
    if target_variable_type == 'discrete':
      unique_cls = np.unique(y_true)
      num_cls = len(unique_cls)
      if sorted(unique_cls) != sorted(target_variable_range):
        print(f'unique_cls: {unique_cls} !=\ntarget_variable_range {target_variable_range}')
      cls_idx_dict = {p.item(): i for i, p in enumerate(sorted(unique_cls))}
      y_true = np.array([cls_idx_dict[i.item()] for i in y_true])
      print(f'Changed class labels for the model: {cls_idx_dict}')
  elif isinstance(target_variable_type, (list, tuple)):
    y_true = [target_to_numpy(y, tar_var_type, tar_var_range)
              for y, tar_var_type, tar_var_range in
              zip(y_targets, target_variable_type, target_variable_range)]
  else:
    raise ValueError(f'target_variable_type must be str, list or tuple, '
                     f'but is {type(target_variable_type)}')
  return y_true

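# A minimal sketch (hypothetical labels): string classes become 0-based integer ids
#   y = target_to_numpy(['neg', 'pos', 'neg'], 'discrete', ['neg', 'pos'])
#   # y == np.array([0, 1, 0])
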
def get_mi_acc(xs, y_true, var_names, var_name_length=35):
  """Get mutual information (MI), adjusted MI, and the maximal accuracy of the Bayes
  classifier for a list of discrete predictors xs and target y_true;
  compute MI, Adj_MI, and Bayes_ACC for every non-empty combination of predictors in xs

  Args:
    xs: a list of tensors or numpy arrays
    y_true: a tensor or numpy array

  Returns:
    a list of dictionaries, each with a single key naming the variable combination
  """
  if isinstance(xs[0], torch.Tensor):
    xs = [x.cpu().detach().numpy() for x in xs]
  if isinstance(y_true, torch.Tensor):
    y_true = y_true.cpu().detach().numpy()
  result = []
  print('{:^{var_name_length}}\t{:^5}\t{:^6}\t{:^9}'.format('Variable', 'MI', 'Adj_MI', 'Bayes_ACC',
    var_name_length=var_name_length))
  for l in itertools.chain.from_iterable(itertools.combinations(range(len(xs)), r)
                                         for r in range(1, 1+len(xs))):
    if len(l) == 1:
      new_x = xs[l[0]]
      msg = f'{var_names[l[0]]:^{var_name_length}}\t'
    else: # len(l) > 1: combine several predictors into one by forming tuples
      new_x = [tuple([v.item() for v in s]) for s in zip(*[xs[j] for j in l])]
      new_x = discrete_to_id(new_x, complex_object=True)[0]
      msg = f'{"-".join(map(str, l)):^{var_name_length}}\t'
    mi = sklearn.metrics.mutual_info_score(y_true, new_x)
    adj_mi = sklearn.metrics.adjusted_mutual_info_score(y_true, new_x)
    # the Bayes classifier predicts the majority class of y_true within each value of new_x
    bayes_acc = (sklearn.metrics.confusion_matrix(y_true, new_x).max(axis=0).sum() / len(y_true))
    result.append({msg: [mi, adj_mi, bayes_acc]})
    msg += f'{mi:^5.3f}\t{adj_mi:^6.3f}\t{bayes_acc:^9.3f}'
    print(msg)
  return result
  # p1 = sklearn.metrics.confusion_matrix(y_true.numpy(), new_x)[:2].reshape(-1)
  # p2 = (np.bincount(y_true.numpy())[:,None] * np.bincount(new_x)).reshape(-1)
  # p = torch.distributions.categorical.Categorical(torch.tensor(p1, dtype=torch.float))
  # q = torch.distributions.categorical.Categorical(torch.tensor(p2, dtype=torch.float))
  # torch.distributions.kl.kl_divergence(p,q)
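
# A minimal sketch (hypothetical binary predictors):
#   x1 = np.array([0, 0, 1, 1])
#   x2 = np.array([0, 1, 0, 1])
#   y = np.array([0, 1, 1, 0])   # y = XOR(x1, x2)
#   get_mi_acc([x1, x2], y, ['x1', 'x2'])
#   # x1 and x2 alone carry no information about y (MI ~ 0, Bayes_ACC ~ 0.5),
#   # while the combination '0-1' determines y exactly (Bayes_ACC == 1.0)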