Diff of /main.py [000000] .. [11ca2d]

Switch to unified view

a b/main.py
1
import argparse
2
from networks import (
3
    branchy_linear_network,
4
    deep_linear_network,
5
    dual_input_model,
6
    seq_model,
7
    simple_branchy_linear_network
8
)
9
from data_generator import data_generator
10
from sklearn.metrics import cohen_kappa_score, f1_score
11
from tensorflow.keras import optimizers
12
import tensorflow as tf
13
import tensorflow_addons as tfa
14
15
16
if __name__ == '__main__':
17
    parser = argparse.ArgumentParser(description='Process the inputs')
18
    parser.add_argument(
19
        '--model', 
20
        type=str, 
21
        help='which model would you like to run',
22
        default='simple_branchy_linear_network'
23
    )
24
    parser.add_argument(
25
        '--epochs', 
26
        type=int, 
27
        help='how many epochs',
28
        default=12
29
    )
30
    parser.add_argument(
31
        '--verbose', 
32
        type=int, 
33
        help='0,1,2',
34
        default=1
35
    )
36
37
    args = parser.parse_args()
38
    model_ = args.model
39
    epochs = args.epochs
40
    verbose = args.verbose
41
42
    X_train, X_test, y_train, y_test, y_train_categorical, y_test_categorical, class_weight = data_generator(
43
        'data/challenge_1_gut_microbiome_data.csv'
44
    )
45
46
    # selecting your model
47
    if model_ == 'simple_branchy_linear_network':
48
        model = simple_branchy_linear_network.simple_branchy_linear_network(class_weight)
49
    elif model_ == 'branchy_linear_network':
50
        model = branchy_linear_network.branchy_linear_network(class_weight)
51
    elif model_ == 'seq_model':
52
        model = seq_model.seq_model()
53
    elif model_ == 'deep_linear_network':
54
        model = deep_linear_network.deep_linear_network(class_weight)
55
    elif model_ == 'dual_input_model':
56
        model = dual_input_model.dual_input_model()
57
58
    print(model.summary())
59
60
    # compile the model
61
    adam = optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
62
    model.compile(
63
        optimizer=adam, 
64
        loss=tf.keras.losses.CategoricalCrossentropy(),
65
        metrics=[
66
            tfa.metrics.CohenKappa(num_classes=4, weightage='quadratic'),
67
            tfa.metrics.F1Score(num_classes=4),
68
            'accuracy'
69
        ]
70
    )
71
72
    # create call backs
73
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
74
        monitor='val_loss',
75
        factor=0.5,
76
        patience=5,
77
        verbose=verbose,
78
        min_lr=1e-8,
79
    )
80
81
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
82
        'models/checkpoint',
83
        monitor='val_cohen_kappa',
84
        verbose=verbose,
85
        save_best_only=True,
86
        mode='max',
87
        save_weights_only=False
88
    )
89
90
    # fit the model
91
    model.fit(
92
        x=X_train,
93
        y=y_train_categorical,
94
        batch_size=16,
95
        epochs=epochs,
96
        verbose=verbose,
97
        validation_data=(X_test, y_test_categorical),
98
        shuffle=True,
99
        callbacks=[
100
            reduce_lr,
101
            checkpoint
102
        ]
103
    )
104
105
    # evaluate the model
106
    # load the best model
107
    model.load_weights('models/checkpoint')
108
109
    y_prob = model.predict(X_test) 
110
    y_classes = y_prob.argmax(axis=-1)
111
112
    ck_score = cohen_kappa_score(
113
        y_test,
114
        y_classes,
115
        weights='quadratic'
116
    )
117
118
    f1_score = f1_score(
119
        y_test,
120
        y_classes,
121
        labels=[0,1,2,3],
122
        average='weighted',
123
    )
124
125
    print('Results:')
126
    print('cohen kappa score:', ck_score)
127
    print('f1_score:', f1_score)