Diff of /train.py [000000] .. [cec8b4]

Switch to side-by-side view

--- a
+++ b/train.py
@@ -0,0 +1,101 @@
+'''Script to execute training. Assumes projects have already been created
+and tiles have already been extracted from slides.
+'''
+
+import click
+import multiprocessing
+import re
+from typing import List
+
+from biscuit.experiment import Experiment, ALL_EXP
+
+# -----------------------------------------------------------------------------
+
+def num_range(s: str) -> List[int]:
+    '''Accept either a comma separated list of numbers 'a,b,c' or a
+    range 'a-c' and return as a list of ints.
+    '''
+    range_re = re.compile(r'^(\d+)-(\d+)$')
+    m = range_re.match(s)
+    if m:
+        return list(range(int(m.group(1)), int(m.group(2))+1))
+    vals = s.split(',')
+    return [int(x) for x in vals]
+
+# -----------------------------------------------------------------------------
+
+@click.command()
+@click.option('--train_project', default='projects/training', type=str, help='Override training project')
+@click.option('--eval_project', default='projects/evaluation',type=str, help='Override eval project')
+@click.option('--outcome', type=str, help='Outcome (annotation header) that assigns class labels.', default='cohort', show_default=True)
+@click.option('--outcome1', type=str, help='First class label.', default='LUAD', show_default=True)
+@click.option('--outcome2', type=str, help='Second class label.', default='LUSC', show_default=True)
+@click.option('--steps', type=num_range, help='Training steps to perform')
+@click.option('--reg', type=bool, help='Train regular models', default=True)
+@click.option('--ratio', type=bool, help='Train ratio models', default=True)
+@click.option('--gan', type=bool, help='Train gan models', default=False)
+def train_models(
+    train_project,
+    eval_project,
+    outcome,
+    outcome1,
+    outcome2,
+    steps=None,
+    reg=True,
+    ratio=True,
+    gan=False,
+):
+    # --- Configure experiments -----------------------------------------------
+    experiment = Experiment(
+        train_project,
+        eval_projects=[eval_project],
+        outcome=outcome,
+        outcome1=outcome1,
+        outcome2=outcome2,
+        outdir='results')
+
+    if steps is None:
+        steps = range(7)
+    to_run = []
+
+    # Configure regular experiments
+    if reg:
+        reg1 = experiment.config('{}', ALL_EXP, 1, order='f')
+        reg2 = experiment.config('{}2', ALL_EXP, 1, order='f', order_col='order2')
+        rev1 = experiment.config('{}_R', ALL_EXP, 1, order='r')
+        rev2 = experiment.config('{}_R2', ALL_EXP, 1, order='r', order_col='order2')
+        to_run += [reg1, reg2, rev1, rev2]
+
+    # Configure 3:1 and 10:1 ratio experiments
+    if ratio:
+        ratio_exp = list('AMDPGZ')
+        ratio_3 = experiment.config('{}_3', ratio_exp, 3, order='f')
+        ratio_3_rev = experiment.config('{}_R_3', ratio_exp, 3, order='r')
+        ratio_10 = experiment.config('{}_10', ratio_exp, 10, order='f')
+        ratio_10_rev = experiment.config('{}_R_10', ratio_exp, 10, order='r')
+        to_run += [ratio_3, ratio_3_rev, ratio_10, ratio_10_rev]
+
+    # GAN experiments
+    if gan:
+        _g = list('RALMNDOPQGWY') + ['ZA', 'ZC']
+        gan_exp = {}
+        gan_exp.update(experiment.config('{}_g10', _g, 1, gan=0.1, order='f'))
+        gan_exp.update(experiment.config('{}_R_g10', _g, 1, gan=0.1, order='r'))
+        gan_exp.update(experiment.config('{}_g20', _g, 1, gan=0.2, order='f'))
+        gan_exp.update(experiment.config('{}_R_g20', _g, 1, gan=0.2, order='r'))
+        gan_exp.update(experiment.config('{}_g30', _g, 1, gan=0.3, order='f'))
+        gan_exp.update(experiment.config('{}_R_g30', _g, 1, gan=0.3, order='r'))
+        gan_exp.update(experiment.config('{}_g40', _g, 1, gan=0.4, order='f'))
+        gan_exp.update(experiment.config('{}_R_g40', _g, 1, gan=0.4, order='r'))
+        gan_exp.update(experiment.config('{}_g50', _g, 1, gan=0.5, order='f'))
+        gan_exp.update(experiment.config('{}_R_g50', _g, 1, gan=0.5, order='r'))
+        to_run += [gan_exp]
+
+    # --- Train experiments ---------------------------------------------------
+    for exp in to_run:
+        experiment.run(exp, steps=steps)
+
+
+if __name__ == '__main__':
+    multiprocessing.freeze_support()
+    train_models()  # pylint: disable=no-value-for-parameter