Diff of /train.py [000000] .. [cec8b4]

Switch to unified view

a b/train.py
1
'''Script to execute training. Assumes projects have already been created
2
and tiles have already been extracted from slides.
3
'''
4
5
import click
6
import multiprocessing
7
import re
8
from typing import List
9
10
from biscuit.experiment import Experiment, ALL_EXP
11
12
# -----------------------------------------------------------------------------
13
14
def num_range(s: str) -> List[int]:
    '''Parse a comma separated list of numbers and/or ranges into ints.

    Each comma separated element is either a single number ('3') or an
    inclusive range ('2-5'). Elements may be mixed freely, so '0,2-4,6'
    yields [0, 2, 3, 4, 6]. Whitespace around an element is ignored.

    Args:
        s: Text to parse, such as 'a,b,c', 'a-c' or 'a,b-d'.

    Returns:
        The parsed integers, in the order they were written.

    Raises:
        ValueError: If an element is neither an integer nor an 'a-b' range.
    '''
    range_re = re.compile(r'^(\d+)-(\d+)$')
    values: List[int] = []
    for part in s.split(','):
        part = part.strip()
        m = range_re.match(part)
        if m:
            # Ranges include both endpoints, hence the +1.
            values.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            # int() raises ValueError for anything that is not a number.
            values.append(int(part))
    return values
24
25
# -----------------------------------------------------------------------------
26
27
# Command-line entry point. Builds one Experiment from the project/outcome
# options, configures the experiment groups enabled by --reg, --ratio and
# --gan, then trains every configured group for the requested steps.
# NOTE: documented with comments rather than a docstring, because click would
# surface a docstring as the command's --help text.
@click.command()
@click.option(
    '--train_project',
    default='projects/training',
    type=str,
    help='Override training project',
)
@click.option(
    '--eval_project',
    default='projects/evaluation',
    type=str,
    help='Override eval project',
)
@click.option(
    '--outcome',
    type=str,
    help='Outcome (annotation header) that assigns class labels.',
    default='cohort',
    show_default=True,
)
@click.option(
    '--outcome1',
    type=str,
    help='First class label.',
    default='LUAD',
    show_default=True,
)
@click.option(
    '--outcome2',
    type=str,
    help='Second class label.',
    default='LUSC',
    show_default=True,
)
@click.option('--steps', type=num_range, help='Training steps to perform')
@click.option('--reg', type=bool, help='Train regular models', default=True)
@click.option('--ratio', type=bool, help='Train ratio models', default=True)
@click.option('--gan', type=bool, help='Train gan models', default=False)
def train_models(
    train_project,
    eval_project,
    outcome,
    outcome1,
    outcome2,
    steps=None,
    reg=True,
    ratio=True,
    gan=False,
):
    # --- Configure experiments -----------------------------------------------
    experiment = Experiment(
        train_project,
        eval_projects=[eval_project],
        outcome=outcome,
        outcome1=outcome1,
        outcome2=outcome2,
        outdir='results',
    )

    # Without --steps, fall back to steps 0 through 6.
    if steps is None:
        steps = range(7)

    # Experiment groups to train, in the order they are configured.
    pending = []

    # Regular experiments: forward ('f') and reverse ('r') orderings, each
    # with the default order column and with 'order2'.
    if reg:
        pending.append(experiment.config('{}', ALL_EXP, 1, order='f'))
        pending.append(
            experiment.config('{}2', ALL_EXP, 1, order='f', order_col='order2')
        )
        pending.append(experiment.config('{}_R', ALL_EXP, 1, order='r'))
        pending.append(
            experiment.config('{}_R2', ALL_EXP, 1, order='r', order_col='order2')
        )

    # 3:1 and 10:1 ratio experiments, one group per (name, ratio, order) row.
    if ratio:
        ratio_labels = list('AMDPGZ')
        ratio_table = (
            ('{}_3', 3, 'f'),
            ('{}_R_3', 3, 'r'),
            ('{}_10', 10, 'f'),
            ('{}_R_10', 10, 'r'),
        )
        for name_fmt, ratio_value, direction in ratio_table:
            pending.append(
                experiment.config(
                    name_fmt, ratio_labels, ratio_value, order=direction
                )
            )

    # GAN experiments: every configuration is merged into a single dict so
    # that all of them are handed to experiment.run() as one group.
    if gan:
        gan_labels = list('RALMNDOPQGWY') + ['ZA', 'ZC']
        gan_table = (
            ('{}_g10', 0.1, 'f'),
            ('{}_R_g10', 0.1, 'r'),
            ('{}_g20', 0.2, 'f'),
            ('{}_R_g20', 0.2, 'r'),
            ('{}_g30', 0.3, 'f'),
            ('{}_R_g30', 0.3, 'r'),
            ('{}_g40', 0.4, 'f'),
            ('{}_R_g40', 0.4, 'r'),
            ('{}_g50', 0.5, 'f'),
            ('{}_R_g50', 0.5, 'r'),
        )
        merged = {}
        for name_fmt, fraction, direction in gan_table:
            merged.update(
                experiment.config(
                    name_fmt, gan_labels, 1, gan=fraction, order=direction
                )
            )
        pending.append(merged)

    # --- Train experiments ---------------------------------------------------
    for group in pending:
        experiment.run(group, steps=steps)
97
98
99
# Script entry point: only runs when the file is executed directly.
if __name__ == '__main__':
    # Must be called first so that a frozen Windows executable can spawn
    # worker processes; per the multiprocessing docs it has no effect elsewhere.
    multiprocessing.freeze_support()
    # click supplies the arguments from the command line, hence the pylint
    # suppression for the apparently missing parameters.
    train_models()  # pylint: disable=no-value-for-parameter