Diff of /main.py [000000] .. [11ca2d]

Switch to side-by-side view

--- a
+++ b/main.py
@@ -0,0 +1,127 @@
+import argparse
+from networks import (
+    branchy_linear_network,
+    deep_linear_network,
+    dual_input_model,
+    seq_model,
+    simple_branchy_linear_network
+)
+from data_generator import data_generator
+from sklearn.metrics import cohen_kappa_score, f1_score
+from tensorflow.keras import optimizers
+import tensorflow as tf
+import tensorflow_addons as tfa
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Process the inputs')
+    parser.add_argument(
+        '--model', 
+        type=str, 
+        help='which model would you like to run',
+        default='simple_branchy_linear_network'
+    )
+    parser.add_argument(
+        '--epochs', 
+        type=int, 
+        help='how many epochs',
+        default=12
+    )
+    parser.add_argument(
+        '--verbose', 
+        type=int, 
+        help='0,1,2',
+        default=1
+    )
+
+    args = parser.parse_args()
+    model_ = args.model
+    epochs = args.epochs
+    verbose = args.verbose
+
+    X_train, X_test, y_train, y_test, y_train_categorical, y_test_categorical, class_weight = data_generator(
+        'data/challenge_1_gut_microbiome_data.csv'
+    )
+
+    # selecting your model
+    if model_ == 'simple_branchy_linear_network':
+        model = simple_branchy_linear_network.simple_branchy_linear_network(class_weight)
+    elif model_ == 'branchy_linear_network':
+        model = branchy_linear_network.branchy_linear_network(class_weight)
+    elif model_ == 'seq_model':
+        model = seq_model.seq_model()
+    elif model_ == 'deep_linear_network':
+        model = deep_linear_network.deep_linear_network(class_weight)
+    elif model_ == 'dual_input_model':
+        model = dual_input_model.dual_input_model()
+
+    print(model.summary())
+
+    # compile the model
+    adam = optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
+    model.compile(
+        optimizer=adam, 
+        loss=tf.keras.losses.CategoricalCrossentropy(),
+        metrics=[
+            tfa.metrics.CohenKappa(num_classes=4, weightage='quadratic'),
+            tfa.metrics.F1Score(num_classes=4),
+            'accuracy'
+        ]
+    )
+
+    # create call backs
+    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
+        monitor='val_loss',
+        factor=0.5,
+        patience=5,
+        verbose=verbose,
+        min_lr=1e-8,
+    )
+
+    checkpoint = tf.keras.callbacks.ModelCheckpoint(
+        'models/checkpoint',
+        monitor='val_cohen_kappa',
+        verbose=verbose,
+        save_best_only=True,
+        mode='max',
+        save_weights_only=False
+    )
+
+    # fit the model
+    model.fit(
+        x=X_train,
+        y=y_train_categorical,
+        batch_size=16,
+        epochs=epochs,
+        verbose=verbose,
+        validation_data=(X_test, y_test_categorical),
+        shuffle=True,
+        callbacks=[
+            reduce_lr,
+            checkpoint
+        ]
+    )
+
+    # evaluate the model
+    # load the best model
+    model.load_weights('models/checkpoint')
+
+    y_prob = model.predict(X_test) 
+    y_classes = y_prob.argmax(axis=-1)
+
+    ck_score = cohen_kappa_score(
+        y_test,
+        y_classes,
+        weights='quadratic'
+    )
+
+    f1_score = f1_score(
+        y_test,
+        y_classes,
+        labels=[0,1,2,3],
+        average='weighted',
+    )
+
+    print('Results:')
+    print('cohen kappa score:', ck_score)
+    print('f1_score:', f1_score)