|
a |
|
b/train.py |
|
|
1 |
'''Script to execute training. Assumes projects have already been created |
|
|
2 |
and tiles have already been extracted from slides. |
|
|
3 |
''' |
|
|
4 |
|
|
|
5 |
import click |
|
|
6 |
import multiprocessing |
|
|
7 |
import re |
|
|
8 |
from typing import List |
|
|
9 |
|
|
|
10 |
from biscuit.experiment import Experiment, ALL_EXP |
|
|
11 |
|
|
|
12 |
# ----------------------------------------------------------------------------- |
|
|
13 |
|
|
|
14 |
def num_range(s: str) -> List[int]:
    '''Parse a string of numbers into a list of ints.

    Accepts a comma separated list of numbers ('a,b,c'), an inclusive
    range ('a-c'), or a comma separated mix of both ('a,c-e'). Whitespace
    around each comma separated item is ignored.

    Args:
        s: Text to parse.

    Returns:
        The numbers in the order given, with each range expanded to
        every integer from its start to its end (inclusive).

    Raises:
        ValueError: If an item is neither an integer nor an 'a-c' range.
    '''
    range_re = re.compile(r'^(\d+)-(\d+)$')
    values: List[int] = []
    for item in s.split(','):
        item = item.strip()
        m = range_re.match(item)
        if m:
            # 'a-c' item: expand to a, a+1, ..., c
            values.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            # Plain number; int() raises ValueError for anything else
            values.append(int(item))
    return values
|
|
24 |
|
|
|
25 |
# ----------------------------------------------------------------------------- |
|
|
26 |
|
|
|
27 |
# Command-line entry point: builds the selected groups of experiment
# configurations, then runs each group for the requested steps.
# Documented with comments instead of a docstring so the help text click
# generates for this command is left exactly as it was.
@click.command()
@click.option('--train_project', default='projects/training', type=str, help='Override training project')
@click.option('--eval_project', default='projects/evaluation', type=str, help='Override eval project')
@click.option('--outcome', type=str, help='Outcome (annotation header) that assigns class labels.', default='cohort', show_default=True)
@click.option('--outcome1', type=str, help='First class label.', default='LUAD', show_default=True)
@click.option('--outcome2', type=str, help='Second class label.', default='LUSC', show_default=True)
@click.option('--steps', type=num_range, help='Training steps to perform')
@click.option('--reg', type=bool, help='Train regular models', default=True)
@click.option('--ratio', type=bool, help='Train ratio models', default=True)
@click.option('--gan', type=bool, help='Train gan models', default=False)
def train_models(
    train_project,
    eval_project,
    outcome,
    outcome1,
    outcome2,
    steps=None,
    reg=True,
    ratio=True,
    gan=False,
):
    # --- Configure experiments -----------------------------------------------
    experiment = Experiment(
        train_project,
        eval_projects=[eval_project],
        outcome=outcome,
        outcome1=outcome1,
        outcome2=outcome2,
        outdir='results',
    )

    # Without --steps, fall back to steps 0 through 6.
    if steps is None:
        steps = range(7)

    # Each entry is one group of configurations handed to experiment.run().
    # All groups are built first; nothing is run until the final loop.
    pending = []

    # Regular experiments: forward and reverse order, each also configured
    # against the 'order2' column.
    if reg:
        pending.extend([
            experiment.config('{}', ALL_EXP, 1, order='f'),
            experiment.config('{}2', ALL_EXP, 1, order='f', order_col='order2'),
            experiment.config('{}_R', ALL_EXP, 1, order='r'),
            experiment.config('{}_R2', ALL_EXP, 1, order='r', order_col='order2'),
        ])

    # 3:1 and 10:1 ratio experiments, as (name pattern, ratio, order) rows.
    if ratio:
        ratio_exp = list('AMDPGZ')
        ratio_specs = (
            ('{}_3', 3, 'f'),
            ('{}_R_3', 3, 'r'),
            ('{}_10', 10, 'f'),
            ('{}_R_10', 10, 'r'),
        )
        pending.extend(
            experiment.config(pattern, ratio_exp, amount, order=direction)
            for pattern, amount, direction in ratio_specs
        )

    # GAN experiments, as (name pattern, gan value, order) rows. All of them
    # are merged into a single dict that is run as one group.
    if gan:
        _g = [*'RALMNDOPQGWY', 'ZA', 'ZC']
        gan_specs = (
            ('{}_g10', 0.1, 'f'),
            ('{}_R_g10', 0.1, 'r'),
            ('{}_g20', 0.2, 'f'),
            ('{}_R_g20', 0.2, 'r'),
            ('{}_g30', 0.3, 'f'),
            ('{}_R_g30', 0.3, 'r'),
            ('{}_g40', 0.4, 'f'),
            ('{}_R_g40', 0.4, 'r'),
            ('{}_g50', 0.5, 'f'),
            ('{}_R_g50', 0.5, 'r'),
        )
        gan_exp = {}
        for pattern, fraction, direction in gan_specs:
            gan_exp.update(
                experiment.config(pattern, _g, 1, gan=fraction, order=direction)
            )
        pending.append(gan_exp)

    # --- Train experiments ---------------------------------------------------
    for group in pending:
        experiment.run(group, steps=steps)
|
|
97 |
|
|
|
98 |
|
|
|
99 |
# Script entry point: only runs when the file is executed directly.
if __name__ == '__main__':
    # Called first, before the command starts any work.
    multiprocessing.freeze_support()
    # Arguments are filled in by click from the command line, which is why
    # the call is empty and pylint's missing-argument check is silenced.
    train_models()  # pylint: disable=no-value-for-parameter