Diff of /train.py [000000] .. [dce3d9]


import os

# Must be set before TensorFlow is imported, or the C++ log filter has no effect.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import cv2
import numpy as np
from glob import glob
from sklearn.utils import shuffle

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Recall, Precision

from models.unet import get_unet_model
from metrics import dice_loss, dice_coef, iou

IMG_HEIGHT = 512
IMG_WIDTH = 512
AUTO = tf.data.AUTOTUNE


def create_dir(path):
    """Create a directory if it does not already exist."""
    os.makedirs(path, exist_ok=True)


def shuffling(x, y):
    """Shuffle the image and mask path lists in unison."""
    x, y = shuffle(x, y, random_state=42)
    return x, y


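# Expected data layout, assumed from the glob patterns below: each split holds
# an "image" and a "mask" folder of .jpg files whose sorted orders correspond.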
def load_data(path):
    """Collect sorted image and mask paths; sorting keeps the pairs aligned."""
    x = sorted(glob(os.path.join(path, "image", "*.jpg")))
    y = sorted(glob(os.path.join(path, "mask", "*.jpg")))
    return x, y


def preprocess_image(path):
    """Read an image, resize it to the model input size, and scale to [0, 1]."""
    path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_COLOR)
    x = cv2.resize(x, (IMG_WIDTH, IMG_HEIGHT))  # cv2.resize takes (width, height).
    x = x / 255.0
    x = x.astype(np.float32)
    return x


def preprocess_mask(path):
    """Read a grayscale mask, resize it, and binarize to hard 0/1 labels."""
    path = path.decode()
    x = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    x = cv2.resize(x, (IMG_WIDTH, IMG_HEIGHT))
    x = x / 255.0
    x = (x > 0.5).astype(np.float32)
    x = np.expand_dims(x, axis=-1)  # (H, W) -> (H, W, 1)
    return x


def preprocess_data(x, y):
    """Bridge the NumPy/OpenCV preprocessing into the tf.data graph."""

    def _parse(x, y):
        x = preprocess_image(x)
        y = preprocess_mask(y)
        return x, y

    x, y = tf.numpy_function(_parse, [x, y], [tf.float32, tf.float32])
    # tf.numpy_function drops static shape information, so restore it here;
    # downstream layers need fully defined shapes.
    x.set_shape([IMG_HEIGHT, IMG_WIDTH, 3])
    y.set_shape([IMG_HEIGHT, IMG_WIDTH, 1])
    return x, y


def load_dataset(x, y, batch_size=8):
    """Build a batched, prefetched tf.data pipeline from lists of file paths."""
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    dataset = (
        dataset.map(preprocess_data, num_parallel_calls=AUTO)
        .batch(batch_size)
        .prefetch(AUTO)
    )
    return dataset


if __name__ == "__main__":
    """ Seeding """
    SEED = 42
    np.random.seed(SEED)
    tf.random.set_seed(SEED)

    # Create a MirroredStrategy for synchronous training on all visible GPUs.
    strategy = tf.distribute.MirroredStrategy()
    print(f"Number of devices: {strategy.num_replicas_in_sync}")

    """ Directory for storing files """
    create_dir("files")

    """ Hyperparameters """
    batch_size = 16  # Global batch size; MirroredStrategy splits it across replicas.
    lr = 1e-4
    num_epochs = 200
    model_path = os.path.join("files", "model.h5")
    csv_path = os.path.join("files", "data.csv")

    """ Dataset """
    dataset_path = "new_data"
    train_path = os.path.join(dataset_path, "train")
    valid_path = os.path.join(dataset_path, "valid")

    train_x, train_y = load_data(train_path)
    train_x, train_y = shuffling(train_x, train_y)
    valid_x, valid_y = load_data(valid_path)

    print(f"Train: {len(train_x)} - {len(train_y)}")
    print(f"Valid: {len(valid_x)} - {len(valid_y)}")

    train_dataset = load_dataset(train_x, train_y, batch_size=batch_size)
    valid_dataset = load_dataset(valid_x, valid_y, batch_size=batch_size)

    """ Model """
114
    # Open a strategy scope.
115
    with strategy.scope():
116
        model = get_unet_model((IMG_HEIGHT, IMG_WIDTH, 3))
117
        metrics = [dice_coef, iou, Recall(), Precision()]
118
        model.compile(loss=dice_loss, optimizer=Adam(lr), metrics=metrics)
119

    """ Training callbacks """
    train_callbacks = [
        # Keep only the best model (lowest validation loss) on disk.
        tf.keras.callbacks.ModelCheckpoint(model_path, verbose=1, save_best_only=True),
        # Cut the learning rate by 10x after 10 epochs without improvement.
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.1, patience=10, min_lr=1e-7, verbose=1
        ),
        tf.keras.callbacks.CSVLogger(csv_path),
        tf.keras.callbacks.TensorBoard(),
        # Give early stopping more patience than ReduceLROnPlateau, otherwise
        # training stops before a reduced learning rate can take effect.
        tf.keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=20, restore_best_weights=False
        ),
    ]

    history = model.fit(
        train_dataset,
        epochs=num_epochs,
        validation_data=valid_dataset,
        callbacks=train_callbacks,
        shuffle=False,  # File lists are pre-shuffled; tf.data controls batch order.
    )
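
# Run:  python train.py
# Outputs are written to files/ (model.h5 checkpoint, data.csv training log)
# and TensorBoard event files under the default ./logs directory.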