--- a +++ b/main.py @@ -0,0 +1,127 @@ +import argparse +from networks import ( + branchy_linear_network, + deep_linear_network, + dual_input_model, + seq_model, + simple_branchy_linear_network +) +from data_generator import data_generator +from sklearn.metrics import cohen_kappa_score, f1_score +from tensorflow.keras import optimizers +import tensorflow as tf +import tensorflow_addons as tfa + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Process the inputs') + parser.add_argument( + '--model', + type=str, + help='which model would you like to run', + default='simple_branchy_linear_network' + ) + parser.add_argument( + '--epochs', + type=int, + help='how many epochs', + default=12 + ) + parser.add_argument( + '--verbose', + type=int, + help='0,1,2', + default=1 + ) + + args = parser.parse_args() + model_ = args.model + epochs = args.epochs + verbose = args.verbose + + X_train, X_test, y_train, y_test, y_train_categorical, y_test_categorical, class_weight = data_generator( + 'data/challenge_1_gut_microbiome_data.csv' + ) + + # selecting your model + if model_ == 'simple_branchy_linear_network': + model = simple_branchy_linear_network.simple_branchy_linear_network(class_weight) + elif model_ == 'branchy_linear_network': + model = branchy_linear_network.branchy_linear_network(class_weight) + elif model_ == 'seq_model': + model = seq_model.seq_model() + elif model_ == 'deep_linear_network': + model = deep_linear_network.deep_linear_network(class_weight) + elif model_ == 'dual_input_model': + model = dual_input_model.dual_input_model() + + print(model.summary()) + + # compile the model + adam = optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0) + model.compile( + optimizer=adam, + loss=tf.keras.losses.CategoricalCrossentropy(), + metrics=[ + tfa.metrics.CohenKappa(num_classes=4, weightage='quadratic'), + tfa.metrics.F1Score(num_classes=4), + 'accuracy' + ] + ) + + # create call backs + reduce_lr = tf.keras.callbacks.ReduceLROnPlateau( + monitor='val_loss', + factor=0.5, + patience=5, + verbose=verbose, + min_lr=1e-8, + ) + + checkpoint = tf.keras.callbacks.ModelCheckpoint( + 'models/checkpoint', + monitor='val_cohen_kappa', + verbose=verbose, + save_best_only=True, + mode='max', + save_weights_only=False + ) + + # fit the model + model.fit( + x=X_train, + y=y_train_categorical, + batch_size=16, + epochs=epochs, + verbose=verbose, + validation_data=(X_test, y_test_categorical), + shuffle=True, + callbacks=[ + reduce_lr, + checkpoint + ] + ) + + # evaluate the model + # load the best model + model.load_weights('models/checkpoint') + + y_prob = model.predict(X_test) + y_classes = y_prob.argmax(axis=-1) + + ck_score = cohen_kappa_score( + y_test, + y_classes, + weights='quadratic' + ) + + f1_score = f1_score( + y_test, + y_classes, + labels=[0,1,2,3], + average='weighted', + ) + + print('Results:') + print('cohen kappa score:', ck_score) + print('f1_score:', f1_score)