Diff of /lungs/utils.py [000000] .. [eac570]

Switch to unified view

a b/lungs/utils.py
1
# pylint: disable=missing-docstring
2
import os
3
import time
4
import numpy
5
from six.moves import xrange  # pylint: disable=redefined-builtin
6
import tensorflow as tf
7
import math
8
import numpy as np
9
import gdown
10
from sklearn.metrics import roc_auc_score, roc_curve
11
import scikitplot as skplt
12
import matplotlib.pyplot as plt
13
from matplotlib.pyplot import figure
14
from os.path import join
15
import seaborn as sns
16
sns.set_style("darkgrid")
17
18
REMOTE_CKPTS = {
19
    'cancer_fine_tuned': {'url': '1Zc8KdEz9JUfkT1ZsG9ELYReUPbVapbQC', 'md5': 'cd5271617e090859f73a727da81cc2e3'},
20
    'i3d_imagenet': {'url': '1FMWHGFYPjuvpgzkGm-_gYKdXpmv5fOq2',  'md5': 'f1408b50e5871153516fe87932121745'}
21
}
22
23
def load_pretrained_ckpt(ckpt, data_dir):
24
    if ckpt in REMOTE_CKPTS:
25
        download_ckpt(data_dir, ckpt, REMOTE_CKPTS[ckpt])
26
27
    # Load a pre-defined ckpt or a ckpt from path
28
    predefined = join(data_dir, 'checkpoints', ckpt)
29
    ckpt_dir = predefined if os.path.exists(predefined) else ckpt
30
31
    pre_trained_ckpt = join(ckpt_dir, 'model.ckpt')
32
    print('\nINFO: Loading pre-trained model:', pre_trained_ckpt)
33
    return pre_trained_ckpt
34
35
def download_ckpt(data_dir, name, download_info):
36
    if os.path.exists(join(data_dir, 'checkpoints', name)):
37
        print('\nINFO: {} model already downloaded.'.format(name))
38
    else:
39
        print('\nINFO: Downloading {} model...'.format(name))
40
        url = 'https://drive.google.com/uc?id=' + download_info['url']
41
        zip_output = join(data_dir, 'checkpoints', name + '.zip')
42
        md5 = download_info['md5']
43
        gdown.cached_download(url, zip_output, md5=md5, postprocess=gdown.extractall, quiet=True)
44
        os.remove(zip_output)
45
46
def pretty_print_floats(lst):
47
    return ',  '.join(['{:.3f}'.format(_) for _ in lst])
48
49
def load_npz_as_list(base_dir, npz_file):
50
    return np.load(join(base_dir, npz_file))['arr_0'].tolist()
51
52
def plot_loss(val_loss, tr_loss, plots_dir):
53
    figure(num=None, figsize=(16, 8), dpi=100)
54
    title = 'Training and Validation Loss'
55
    epochs = range(1, len(val_loss) + 1)
56
    plt.plot(epochs, val_loss, label='Val. Loss')
57
    plt.plot(epochs, tr_loss, label='Train Loss')
58
    plt.title(title)
59
    plt.xlabel('Epoch')
60
    plt.ylabel('Loss')
61
    plt.legend()
62
    plt.show()
63
    plt.savefig(join(plots_dir, title + '.png'), bbox_inches='tight')
64
65
def plot_acc_auc(val_acc, tr_acc, val_auc, tr_auc, plots_dir):
66
    figure(num=None, figsize=(16, 8), dpi=100)
67
    title = 'Accuracy and AUC'
68
    epochs = range(1, len(val_acc) + 1)
69
    plt.plot(epochs, val_acc, label='Val. Accuracy')
70
    plt.plot(epochs, tr_acc, label='Train Accuracy')
71
    plt.plot(epochs, val_auc, label='Val. AUC')
72
    plt.plot(epochs, tr_auc, label='Train AUC')
73
    plt.title(title)
74
    plt.xlabel('Epoch')
75
    plt.ylabel('Score')
76
    plt.legend()
77
    plt.savefig(join(plots_dir, title + '.png'), bbox_inches='tight')
78
79
def calc_plot_epoch_auc_roc(y, y_probs, title, plots_dir, verbose=False):
80
    y_prob_2_classes = [(1 - p, p) for p in y_probs]
81
    fpr, tpr, th = roc_curve(y, y_probs)
82
    if verbose:
83
        print('TPR:', pretty_print_floats(tpr))
84
        print('FPR:', pretty_print_floats(fpr))
85
        print('TH: ', pretty_print_floats(th), '\n')
86
    auc = roc_auc_score(y, y_probs)
87
    title = title + ',  AUC={:.3f}'.format(auc)
88
    skplt.metrics.plot_roc(y, y_prob_2_classes, classes_to_plot=[], 
89
                           title= title,
90
                           figsize=(7, 7), plot_micro=False, plot_macro=True, 
91
                           title_fontsize=15, text_fontsize=13)
92
    plt.show()
93
    plt.savefig(join(plots_dir, title) + '.png', bbox_inches='tight')
94
95
def load_and_plot_epoch_auc(metrics_dir, epoch, val_true, tr_true, plots_dir):
96
    val_preds_epoch = load_npz_as_list(metrics_dir, 'val_preds/epoch_' + str(epoch) + '.npz')
97
    calc_plot_epoch_auc_roc(val_true, val_preds_epoch, 
98
                            'Val. ROC for Epoch {}'.format(epoch), plots_dir)
99
100
    tr_preds_epoch = load_npz_as_list(metrics_dir, 'tr_preds/epoch_' + str(epoch) + '.npz')
101
    calc_plot_epoch_auc_roc(tr_true, tr_preds_epoch, 
102
                            'Train ROC for Epoch {}'.format(epoch), plots_dir)
103
104
def plot_metrics(epoch, metrics_dir, plots_dir):
105
    val_loss = load_npz_as_list(metrics_dir, 'val_loss.npz')
106
    val_acc = load_npz_as_list(metrics_dir, 'val_acc.npz')
107
    val_auc = load_npz_as_list(metrics_dir, 'val_auc.npz')
108
    val_true = load_npz_as_list(metrics_dir, 'val_true.npz')
109
110
    tr_loss = load_npz_as_list(metrics_dir, 'tr_loss.npz')
111
    tr_acc = load_npz_as_list(metrics_dir, 'tr_acc.npz')
112
    tr_auc = load_npz_as_list(metrics_dir, 'tr_auc.npz')
113
    tr_true = load_npz_as_list(metrics_dir, 'tr_true.npz')
114
115
    plot_loss(val_loss, tr_loss, plots_dir)
116
    plot_acc_auc(val_acc, tr_acc, val_auc, tr_auc, plots_dir)
117
    load_and_plot_epoch_auc(metrics_dir, epoch, val_true, tr_true, plots_dir)
118
119
def write_metrics(metrics, tr_metrics, val_metrics, metrics_dir, epoch, verbose=False):
120
    for (loss, acc, auc, preds, _), ds in ((tr_metrics, 'tr'), (val_metrics, 'val')):
121
        for metric, key in [(loss, 'loss'), (acc, 'acc'), (auc, 'auc'), (preds, 'preds')]:
122
            name = ds + '_' + key
123
            metrics[name].append(metric)
124
            write_number_list(metrics[name], join(metrics_dir, name))
125
        write_number_list(preds, join(metrics_dir, ds + '_preds', 'epoch_{}'.format(epoch)), verbose=verbose)
126
127
def apply_window(volume, axis=4):
128
    # Windowing
129
    # Our values currently range from -1024 to around 2000. 
130
    # Anything above 400 is not interesting to us, as these are simply bones with different radiodensity.  
131
    # A commonly used set of thresholds in Lungs LDCT to normalize between are -1000 and 400. 
132
    min_bound = -1000.0
133
    max_bound = 400.0
134
    volume = (volume - min_bound) / (max_bound - min_bound)
135
    volume[volume>1] = 1.
136
    volume[volume<0] = 0.
137
138
    # Normalize rgb values to [-1, 1]
139
    volume = (volume * 2) - 1
140
    res =  np.stack((volume, volume, volume), axis=axis)
141
    return res.astype(np.float32)
142
143
def write_number_list(lst, f_name, verbose=False):
144
    if verbose:
145
        print('INFO: Saving ' + f_name + '.npz ...')
146
        print(lst)
147
    np.savez(f_name + '.npz', np.array(lst))       
148
149
def batcher(iterable, batch_size=1):
150
    iter_len = len(iterable)
151
    for i in range(0, iter_len, batch_size):
152
        yield iterable[i: min(i + batch_size, iter_len)]
153
154
def load_data_list(path):
155
    coupled_data = []
156
    with open(path) as file_list_fp:
157
        for line in file_list_fp:
158
            volume_path, label = line.split()
159
            coupled_data.append((volume_path, int(label)))
160
    return coupled_data
161
162
def get_list_labels(coupled_data):
163
        return np.array([l for _, l in coupled_data]).astype(np.int64)
164
165
def placeholder_inputs(num_slices, crop_size, rgb_channels=3):
166
    """Generate placeholder variables to represent the input tensors.
167
168
    These placeholders are used as inputs by the rest of the model building
169
    code and will be fed from the downloaded data in the .run() loop, below.
170
171
    Args:
172
    num_slices: The num of slices per volume.
173
    crop_size: The crop size of per volume.
174
    channels: The number of RGB input channels per volume.
175
176
    Returns:
177
    volumes_placeholder: volumes placeholder.
178
    labels_placeholder: Labels placeholder.
179
    """
180
    # Note that the shapes of the placeholders match the shapes of the full
181
    # volume and label tensors, except the first dimension is now batch_size
182
    # rather than the full size of the train or test data sets.
183
    volumes_placeholder = tf.placeholder(tf.float32, shape=(None,
184
                                                           num_slices,
185
                                                           crop_size,
186
                                                           crop_size,
187
                                                           rgb_channels))
188
    labels_placeholder = tf.placeholder(tf.int64, shape=(None))
189
    is_training = tf.placeholder(tf.bool)
190
    return volumes_placeholder, labels_placeholder, is_training
191
192
def focal_loss(logits, labels, alpha=0.75, gamma=2):
193
    """Compute focal loss for binary classification.
194
195
    Args:
196
      labels: A int32 tensor of shape [batch_size].
197
      logits: A float32 tensor of shape [batch_size].
198
      alpha: A scalar for focal loss alpha hyper-parameter. If positive samples number
199
      > negtive samples number, alpha < 0.5 and vice versa.
200
      gamma: A scalar for focal loss gamma hyper-parameter.
201
    Returns:
202
      A tensor of the same shape as `labels`.
203
    """
204
    y_pred = tf.nn.sigmoid(logits)
205
    labels = tf.to_float(labels)
206
    losses = -(labels * (1 - alpha) * ((1 - y_pred) * gamma) * tf.log(y_pred)) - \
207
        (1 - labels) * alpha * (y_pred ** gamma) * tf.log(1 - y_pred)
208
    return tf.reduce_mean(losses)
209
210
def cross_entropy_loss(logits, labels):
211
    # pylint: disable=no-value-for-parameter
212
    # pylint: disable=unexpected-keyword-arg
213
    cross_entropy_mean = tf.reduce_mean(
214
                  tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
215
                  )
216
    return cross_entropy_mean
217
218
def accuracy(logit, labels):
219
    correct_pred = tf.equal(tf.argmax(logit, 1), labels)
220
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
221
    return accuracy
222
223
def get_preds(preds):
224
    return preds[:, 1]
225
226
def get_logits(logits):
227
    return logits