Diff of /helper_functions.py [000000] .. [e056f2]

# functions for future use

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Function to load and resize an image
def load_prepare_images(filename, img_shape=224, scale=True):
    """Reads an image from filename, decodes it and resizes it to (img_shape, img_shape, colour_channels).

    Args:
        filename (str): Path to the target image file.
        img_shape (int, optional): Height/width to resize the image to. Defaults to 224.
        scale (bool, optional): Whether to scale pixel values to the range (0, 1). Defaults to True.

    Returns:
        Processed image tensor.
    """
    img = tf.io.read_file(filename)
    img = tf.image.decode_jpeg(img)
    img = tf.image.resize(img, [img_shape, img_shape])
    if scale:
        return img/255.
    else:
        return img

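# Usage sketch (illustrative; "example.jpg" is a hypothetical local image file):
# img = load_prepare_images("example.jpg", img_shape=224, scale=True)
# print(img.shape)  # (224, 224, 3) for an RGB JPEG
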
# Function to plot a confusion matrix
import numpy as np
import matplotlib.pyplot as plt
import itertools
from sklearn.metrics import confusion_matrix
def create_confusion_matrix(y_true, y_pred, classes=None, figsize=(10, 10), text_size=15, norm=False, savefig=False):
    """Plots a labelled confusion matrix comparing predictions and ground truth labels.

    Args:
        y_true: Array of ground truth labels (same shape as y_pred).
        y_pred: Array of predicted labels (same shape as y_true).
        classes (list, optional): Class names for the axis labels. Defaults to None (integer labels are used).
        figsize (tuple, optional): Size of the output figure. Defaults to (10, 10).
        text_size (int, optional): Size of the cell text. Defaults to 15.
        norm (bool, optional): Whether to display normalized percentages instead of raw counts. Defaults to False.
        savefig (bool, optional): Whether to save the figure to "confusion_matrix.png". Defaults to False.

    Returns:
        None. Draws (and optionally saves) the confusion matrix figure.
    """
    # Compute the confusion matrix and a row-normalized copy
    cm = confusion_matrix(y_true, y_pred)
    cm_norm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    n_classes = cm.shape[0]

    # Plot confusion matrix
    fig, ax = plt.subplots(figsize=figsize)
    cax = ax.matshow(cm, cmap=plt.cm.Blues)
    fig.colorbar(cax)
    if classes:
        labels = classes
    else:
        labels = np.arange(cm.shape[0])
    ax.set(title="Confusion Matrix",
           ylabel="True label",
           xlabel="Predicted label",
           xticks=np.arange(n_classes),
           yticks=np.arange(n_classes),
           xticklabels=labels,
           yticklabels=labels)

    # Put the x-axis labels at the bottom and pick a colour threshold for the cell text
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()
    threshold = (cm.max() + cm.min()) / 2.

    # Write the count (or normalized percentage) into every cell
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if norm:
            plt.text(j, i, f"{cm_norm[i, j]*100:.1f}%",
                     horizontalalignment="center",
                     color="white" if cm[i, j] > threshold else "black",
                     size=text_size)
        else:
            plt.text(j, i, f"{cm[i, j]}",
                     horizontalalignment="center",
                     color="white" if cm[i, j] > threshold else "black",
                     size=text_size)

    if savefig:
        fig.savefig("confusion_matrix.png")

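# Usage sketch (illustrative; the labels and class names below are hypothetical placeholders):
# y_true = [0, 1, 1, 0]
# y_pred = [0, 1, 0, 0]
# create_confusion_matrix(y_true, y_pred, classes=["cat", "dog"], figsize=(5, 5), text_size=10)
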
# Function to predict on an image file and plot the result
def pred_plot(model, filename, class_names):
    # Load and preprocess the image, then predict with a batch dimension added
    img = load_prepare_images(filename)
    pred = model.predict(tf.expand_dims(img, axis=0))

    # Multi-class output: take the highest-probability class; binary output: round the single probability
    if len(pred[0]) > 1:
        pred_class = class_names[pred.argmax()]
    else:
        pred_class = class_names[int(tf.round(pred)[0][0])]

    plt.imshow(img)
    plt.title(f"Prediction: {pred_class}")
    plt.axis(False)

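# Usage sketch (illustrative; `model`, "example.jpg" and the class list are hypothetical placeholders):
# pred_plot(model=model, filename="example.jpg", class_names=["pizza", "steak"])
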
# Function to create a TensorBoard callback for logging training runs
import datetime

def create_tensorboard_callback(dir_name, experiment_name):
    # Create a timestamped log directory so each run gets its own folder
    log_dir = dir_name + "/" + experiment_name + "/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir
    )
    print(f"Saving TensorBoard log files to: {log_dir}")
    return tensorboard_callback

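# Usage sketch (illustrative; the directory/experiment names, `model` and `train_data` are hypothetical):
# tb_callback = create_tensorboard_callback(dir_name="tensorboard_logs", experiment_name="baseline_model")
# model.fit(train_data, epochs=5, callbacks=[tb_callback])
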
# Function to plot training and validation loss/accuracy curves on separate figures
import matplotlib.pyplot as plt

def plot_loss_curves(history):
    loss = history.history["loss"]
    val_loss = history.history["val_loss"]
    accuracy = history.history["accuracy"]
    val_accuracy = history.history["val_accuracy"]
    epochs = range(len(history.history["loss"]))

    # Plot loss
    plt.plot(epochs, loss, label="Training_loss")
    plt.plot(epochs, val_loss, label="Validation_loss")
    plt.title("Loss")
    plt.xlabel("Epochs")
    plt.legend()

    # Plot accuracy
    plt.figure()
    plt.plot(epochs, accuracy, label="Training_accuracy")
    plt.plot(epochs, val_accuracy, label="Validation_accuracy")
    plt.title("Accuracy")
    plt.xlabel("Epochs")
    plt.legend()

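# Usage sketch (illustrative; assumes model.fit() was run with validation_data so that
# "val_loss" and "val_accuracy" exist in history.history):
# history = model.fit(train_data, epochs=5, validation_data=test_data)
# plot_loss_curves(history)
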
# Function to compare two History objects (e.g. before and after fine-tuning)
def compare_history(original_history, new_history, initial_epochs=5):
    # Combine the metrics from the original run with the continued (fine-tuning) run
    acc = original_history.history["accuracy"]
    loss = original_history.history["loss"]
    val_acc = original_history.history["val_accuracy"]
    val_loss = original_history.history["val_loss"]
    total_acc = acc + new_history.history["accuracy"]
    total_loss = loss + new_history.history["loss"]
    total_val_acc = val_acc + new_history.history["val_accuracy"]
    total_val_loss = val_loss + new_history.history["val_loss"]

    plt.figure(figsize=(10, 10))
    plt.subplot(2, 1, 1)
    plt.plot(total_acc, label="Training_accuracy")
    plt.plot(total_val_acc, label="Validation_accuracy")
    plt.plot([initial_epochs-1, initial_epochs-1],
             plt.ylim(), label="Start Fine Tuning")
    plt.legend(loc="lower right")
    plt.title("Training and Validation Accuracy")
    plt.subplot(2, 1, 2)
    plt.plot(total_loss, label="Training_loss")
    plt.plot(total_val_loss, label="Validation_loss")
    plt.plot([initial_epochs-1, initial_epochs-1],
             plt.ylim(), label="Start Fine Tuning")
    plt.legend(loc="upper right")
    plt.title("Training and Validation Loss")
    plt.xlabel("epoch")
    plt.show()

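# Usage sketch (illustrative; assumes two consecutive model.fit() calls, the second one after
# unfreezing layers for fine-tuning):
# compare_history(history_feature_extract, history_fine_tune, initial_epochs=5)
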
# Function to unzip a file
import zipfile

def unzip_data(filename):
    with zipfile.ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall()

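# Usage sketch (illustrative; "images.zip" is a hypothetical archive in the working directory):
# unzip_data("images.zip")
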
# Function to walk through a directory and print how many subdirectories and files it contains
import os

def walk_dir(dir_path):
    for dirpath, dirnames, filenames in os.walk(dir_path):
        print(f"There are {len(dirnames)} directories and {len(filenames)} files in '{dirpath}'.")

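# Usage sketch (illustrative; "images" is a hypothetical directory, e.g. one created by unzip_data above):
# walk_dir("images")
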
# Function for evaluation: accuracy, precision, recall and F1 score
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def calculate_results(y_true, y_pred):
    """Calculates accuracy, precision, recall and F1 score of a classification model."""
    # Calculate accuracy as a percentage
    model_accuracy = accuracy_score(y_true, y_pred) * 100
    # Calculate precision, recall and F1 score using a weighted average across classes
    model_precision, model_recall, model_f1, _ = precision_recall_fscore_support(y_true, y_pred, average="weighted")
    model_results = {"accuracy": model_accuracy,
                     "precision": model_precision,
                     "recall": model_recall,
                     "f1": model_f1}
    return model_results
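
# Usage sketch (illustrative; the label arrays below are hypothetical placeholders):
# results = calculate_results(y_true=[0, 1, 1, 0], y_pred=[0, 1, 0, 0])
# print(results)  # {"accuracy": 75.0, "precision": ..., "recall": ..., "f1": ...}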