# pathflowai/model_training.py
import torch, os, numpy as np, pandas as pd
from pathflowai.utils import *
#from large_data_utils import *
from pathflowai.datasets import *
from pathflowai.models import *
from pathflowai.schedulers import *
from pathflowai.visualize import *
import copy
from pathflowai.sampler import ImbalancedDatasetSampler
import argparse
import sqlite3
#from nonechucks import SafeDataLoader as DataLoader
from torch.utils.data import DataLoader
import click
import pysnooper

CONTEXT_SETTINGS = dict(help_option_names=['-h','--help'], max_content_width=90)

@click.group(context_settings=CONTEXT_SETTINGS)
@click.version_option(version='0.1')
def train():
    pass

def return_model(training_opts):
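    """Instantiate the model from the training options, loading pretrained or previously saved weights when available."""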
    model=generate_model(pretrain=training_opts['pretrain'],architecture=training_opts['architecture'],num_classes=training_opts['num_targets'], add_sigmoid=False, n_hidden=training_opts['n_hidden'], segmentation=training_opts['segmentation'])

    if os.path.exists(training_opts['pretrained_save_location']) and not training_opts['predict']:
        model_dict = torch.load(training_opts['pretrained_save_location'])
        keys=list(model_dict.keys())
        if not training_opts['segmentation']:
            model_dict.update(dict(list(model.state_dict().items())[-2:])) # keep the new model's final-layer weights so the head matches the current number of targets
        model.load_state_dict(model_dict) # this will likely break after pretraining?
    elif os.path.exists(training_opts['save_location']) and training_opts['predict']:
        model_dict = torch.load(training_opts['save_location'])
        model.load_state_dict(model_dict)
    if training_opts['extract_embedding']:
        assert training_opts['extract_embedding']==training_opts['predict']
        architecture=training_opts['architecture']
        if hasattr(model,"fc"):
            model.fc = model.fc[0]
        elif hasattr(model,"output"):
            model.output = model.output[0]
        elif architecture.startswith('alexnet') or architecture.startswith('vgg') or architecture.startswith('densenet'):
            model.classifier[6]=model.classifier[6][0]
    if torch.cuda.is_available():
        model.cuda()
    return model

def run_test(dataloader):
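    """Debugging helper: dump the first few (X, y) batches to .npy files, then exit."""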
    for i,(X,y) in enumerate(dataloader):
        np.save('X_test_{}.npy'.format(i),X.detach().cpu().numpy())
        np.save('y_test_{}.npy'.format(i),y.detach().cpu().numpy())
        if i==5:
            exit()

def return_trainer_opts(model,training_opts,dataloaders,num_train_batches):
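    """Assemble the keyword arguments passed to ModelTrainer."""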
    return dict(model=model,
                n_epoch=training_opts['n_epoch'],
                validation_dataloader=dataloaders['val'],
                optimizer_opts=dict(name=training_opts['optimizer'],
                                    lr=training_opts['lr'],
                                    weight_decay=training_opts['wd']),
                scheduler_opts=dict(scheduler=training_opts['scheduler_type'],
                                    lr_scheduler_decay=0.5,
                                    T_max=training_opts['T_max'],
                                    eta_min=training_opts['eta_min'],
                                    T_mult=training_opts['T_mult']),
                loss_fn=training_opts['loss_fn'],
                num_train_batches=num_train_batches,
                seg_out_class=training_opts['seg_out_class'],
                apex_opt_level=training_opts['apex_opt_level'],
                checkpointing=training_opts['checkpointing'])

def return_transformer(training_opts):
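    """Load or create the train/val/test split dataframe, fit or load normalization statistics, and build the data transformers."""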
    dataset_df = pd.read_csv(training_opts['dataset_df']) if os.path.exists(training_opts['dataset_df']) else create_train_val_test(training_opts['train_val_test_splits'],training_opts['patch_info_file'],training_opts['patch_size'])

    # options for the normalization pass only; no class-specific targeting or oversampling here
    # (the original guarded these with set=='train', but set is the builtin at this point, so the condition was always False)
    dataset_opts=dict(dataset_df=dataset_df, set='pass', patch_info_file=training_opts['patch_info_file'], input_dir=training_opts['input_dir'], target_names=training_opts['target_names'], pos_annotation_class=training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=-1, target_threshold=training_opts['target_threshold'][0], oversampling_factor=1, n_segmentation_classes=training_opts['num_targets'], gdl=training_opts['loss_fn']=='gdl', mt_bce=training_opts['mt_bce'], classify_annotations=training_opts['classify_annotations'])

    norm_dict = get_normalizer(training_opts['normalization_file'], dataset_opts)

    transform_opts=dict(patch_size=training_opts['patch_resize'], mean=norm_dict['mean'], std=norm_dict['std'], resize=True, transform_platform=training_opts['transform_platform'] if not training_opts['segmentation'] else 'albumentations', user_transforms=training_opts['user_transforms'])

    transformers = get_data_transforms(**transform_opts)
    return dataset_df,dataset_opts,transformers,transform_opts

def return_datasets(training_opts,dataset_df,transformers,transform_opts):
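    """Build the train/val/test DynamicImageDatasets and apply class concatenation, supplementation, subsampling, and annotation binarization per the training options."""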
    datasets= {set: DynamicImageDataset(dataset_df, set, training_opts['patch_info_file'], transformers, training_opts['input_dir'], training_opts['target_names'], training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=training_opts['target_segmentation_class'][0] if set=='train' else -1, target_threshold=training_opts['target_threshold'][0], oversampling_factor=training_opts['oversampling_factor'][0] if set=='train' else 1, n_segmentation_classes=training_opts['num_targets'],gdl=training_opts['loss_fn']=='gdl',mt_bce=training_opts['mt_bce'], classify_annotations=training_opts['classify_annotations'],dilation_jitter=training_opts['dilation_jitter'] if set == 'train' else {}) for set in ['train','val','test']}
    print(datasets['train'])

    # concatenate an additional training set for each extra target segmentation class
    if len(training_opts['target_segmentation_class']) > 1:
        for i in range(1,len(training_opts['target_segmentation_class'])):
            datasets['train'].concat(DynamicImageDataset(dataset_df, 'train', training_opts['patch_info_file'], transformers, training_opts['input_dir'], training_opts['target_names'], training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=training_opts['target_segmentation_class'][i], target_threshold=training_opts['target_threshold'][i], oversampling_factor=training_opts['oversampling_factor'][i],n_segmentation_classes=training_opts['num_targets'],gdl=training_opts['loss_fn']=='gdl',mt_bce=training_opts['mt_bce'],classify_annotations=training_opts['classify_annotations'],dilation_jitter=training_opts['dilation_jitter']))
        print(datasets['train'])

    if training_opts['supplement']:
        old_train_set = copy.deepcopy(datasets['train'])
        datasets['train']=DynamicImageDataset(dataset_df, 'train', training_opts['patch_info_file'], transformers, training_opts['input_dir'], training_opts['target_names'], training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=-1, target_threshold=training_opts['target_threshold'], oversampling_factor=1,n_segmentation_classes=training_opts['num_targets'],gdl=training_opts['loss_fn']=='gdl',mt_bce=training_opts['mt_bce'],classify_annotations=training_opts['classify_annotations'],dilation_jitter=training_opts['dilation_jitter'])
        datasets['train'].concat(old_train_set)

    if training_opts['subsample_p']<1.0:
        datasets['train'].subsample(training_opts['subsample_p'])

    if training_opts['subsample_p_val']==-1.:
        training_opts['subsample_p_val']=training_opts['subsample_p']
    if training_opts['subsample_p_val']<1.0:
        datasets['val'].subsample(training_opts['subsample_p_val'])

    if training_opts['classify_annotations']:
        binarizer=datasets['train'].binarize_annotations(num_targets=training_opts['num_targets'],binary_threshold=training_opts['binary_threshold'])
        datasets['val'].binarize_annotations(num_targets=training_opts['num_targets'],binary_threshold=training_opts['binary_threshold'])
        datasets['test'].binarize_annotations(num_targets=training_opts['num_targets'],binary_threshold=training_opts['binary_threshold'])
        training_opts['num_targets']=len(datasets['train'].targets)

    for Set in ['train','val','test']:
        print(datasets[Set].patch_info.iloc[:,6:].sum(axis=0))

    if training_opts['prediction_set']!='test':
        datasets['test']=datasets[training_opts['prediction_set']]

    if training_opts['external_test_db'] and training_opts['external_test_dir']:
        datasets['test'].update_dataset(input_dir=training_opts['external_test_dir'],new_db=training_opts['external_test_db'],prediction_basename=training_opts['prediction_basename'])
    return datasets,training_opts,transform_opts

#@pysnooper.snoop('train_model.log')
def train_model_(training_opts):
    """Train a model or run prediction/embedding extraction, depending on the supplied options.

    Parameters
    ----------
    training_opts : dict
        Training options populated from the command line.

    """

    cuda_available=torch.cuda.is_available()

    model = return_model(training_opts)

    dataset_df,dataset_opts,transformers,transform_opts=return_transformer(training_opts)

    if training_opts['extract_embedding'] and training_opts['npy_file']:
        dataset=NPYDataset(training_opts['patch_info_file'],training_opts['patch_size'],training_opts['npy_file'],transformers["test"])
        dataset.embed(model,training_opts['batch_size'],training_opts['prediction_output_dir'])
        exit()

    datasets,training_opts,transform_opts=return_datasets(training_opts,dataset_df,transformers,transform_opts)

    if training_opts['num_training_images_epoch']>0:
        num_train_batches = min(training_opts['num_training_images_epoch'],len(datasets['train']))//training_opts['batch_size']
    else:
        num_train_batches = None

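    # Shuffle only the training split; when imbalanced correction is enabled for a
    # non-segmentation task, shuffling is turned off and the training set is drawn
    # through an ImbalancedDatasetSampler instead (DataLoader forbids combining
    # shuffle with a custom sampler).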
    dataloaders={set: DataLoader(datasets[set], batch_size=training_opts['batch_size'], shuffle=(set=='train') if not (training_opts['imbalanced_correction'] and not training_opts['segmentation']) else False, num_workers=10, sampler=ImbalancedDatasetSampler(datasets[set]) if (training_opts['imbalanced_correction'] and set=='train' and not training_opts['segmentation']) else None) for set in ['train', 'val', 'test']}

    print(dataloaders['train'].sampler) # FIXME VAL SEEMS TO BE MISSING DURING PREDICTION
    print(dataloaders['val'].sampler)

    if training_opts['run_test']: run_test(dataloaders['train'])

    model_trainer_opts=return_trainer_opts(model,training_opts,dataloaders,num_train_batches)

    if not training_opts['predict']:

        trainer = ModelTrainer(**model_trainer_opts)

        if training_opts['imbalanced_correction2']:
            trainer.add_class_balance_loss(datasets['train'])
        elif training_opts['custom_weights']:
            trainer.add_class_balance_loss(datasets['train'],custom_weights=training_opts['custom_weights'])

        if training_opts['adopt_training_loss']:
            trainer.val_loss_fn = trainer.loss_fn

        trainer.fit(dataloaders['train'], verbose=True, print_every=1, plot_training_curves=True, plot_save_file=training_opts['training_curve'], print_val_confusion=training_opts['print_val_confusion'], save_val_predictions=training_opts['save_val_predictions'])

        torch.save(trainer.model.state_dict(),training_opts['save_location'])

    else:

        if training_opts['extract_model']:
            dataset_opts.update(dict(target_segmentation_class=-1, target_threshold=training_opts['target_threshold'][0] if len(training_opts['target_threshold']) else 0., set='test', binary_threshold=training_opts['binary_threshold'], num_targets=training_opts['num_targets'], oversampling_factor=1))
            torch.save(dict(model=model,dataset_opts=dataset_opts, transform_opts=transform_opts),'{}.{}'.format(training_opts['save_location'],'extracted_model.pkl'))
            exit()

        trainer = ModelTrainer(**model_trainer_opts)

        if training_opts['segmentation']:
            for ID, dataset in (datasets['test'].split_by_ID() if not training_opts['prediction_basename'] else datasets['test'].select_IDs(training_opts['prediction_basename'])):
                dataloader = DataLoader(dataset, batch_size=training_opts['batch_size'], shuffle=False, num_workers=10)
                if training_opts['run_test']:
                    for X,y in dataloader:
                        np.save('test_predictions.npy',model(X.cuda() if torch.cuda.is_available() else X).detach().cpu().numpy())
                        exit()
                y_pred = trainer.predict(dataloader)
                print(ID,y_pred.shape)
                segmentation_predictions2npy(y_pred, dataset.patch_info, dataset.segmentation_maps[ID], npy_output='{}/{}_predict.npy'.format(training_opts['prediction_output_dir'],ID), original_patch_size=training_opts['patch_size'], resized_patch_size=training_opts['patch_resize'], output_probs=(training_opts['seg_out_class']>=0))
        else:
            extract_embedding=training_opts['extract_embedding']
            if extract_embedding:
                trainer.bce=False

            y_pred = trainer.predict(dataloaders['test'])

            patch_info = dataloaders['test'].dataset.patch_info

            if extract_embedding:
                patch_info['name']=patch_info.astype(str).apply(lambda x: '\n'.join(['{}:{}'.format(k,v) for k,v in x.to_dict().items()]),axis=1)
                embeddings=pd.DataFrame(y_pred,index=patch_info['name'])
                embeddings['ID']=patch_info['ID'].values
                torch.save(dict(embeddings=embeddings,patch_info=patch_info),os.path.join(training_opts['prediction_output_dir'],'embeddings.pkl'))

            else:
                if len(y_pred.shape)>1 and y_pred.shape[1]>1:
                    annotations = np.vectorize(lambda x: x+'_pred')(np.arange(y_pred.shape[1]).astype(str)).tolist()
                    for i in range(y_pred.shape[1]):
                        patch_info.loc[:,annotations[i]]=y_pred[:,i]
                patch_info['y_pred']=y_pred if (training_opts['num_targets']==1 or not (training_opts['classify_annotations'] or training_opts['mt_bce'])) else y_pred.argmax(axis=1)

                conn = sqlite3.connect(training_opts['prediction_save_path'])
                patch_info.to_sql(str(training_opts['patch_size']),con=conn, if_exists='replace')
                conn.close()

@train.command()
@click.option('-s', '--segmentation', is_flag=True, help='Segmentation task.', show_default=True)
@click.option('-p', '--prediction', is_flag=True, help='Predict on model.', show_default=True)
@click.option('-pa', '--pos_annotation_class', default='', help='Annotation class from which to apply positive labels.', type=click.Path(exists=False), show_default=True)
@click.option('-oa', '--other_annotations', default=[], multiple=True, help='Annotations in image.', type=click.Path(exists=False), show_default=True)
@click.option('-o', '--save_location', default='', help='Model save location; append pickle extension .pkl.', type=click.Path(exists=False), show_default=True)
@click.option('-pt', '--pretrained_save_location', default='', help='Save location of model pretrained by a previous analysis, to be finetuned; append pickle extension .pkl.', type=click.Path(exists=False), show_default=True)
@click.option('-i', '--input_dir', default='', help='Input directory containing slides and associated files.', type=click.Path(exists=False), show_default=True)
@click.option('-ps', '--patch_size', default=224, help='Patch size.', show_default=True)
@click.option('-pr', '--patch_resize', default=224, help='Patch size after resizing.', show_default=True)
@click.option('-tg', '--target_names', default=[], multiple=True, help='Targets.', type=click.Path(exists=False), show_default=True)
@click.option('-df', '--dataset_df', default='', help='CSV file with train/val/test and target info.', type=click.Path(exists=False), show_default=True)
@click.option('-fn', '--fix_names', is_flag=True, help='Whether to fix names in dataset_df.', show_default=True)
@click.option('-a', '--architecture', default='alexnet', help='Neural network architecture.', type=click.Choice(['alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
                                            'inception_v3', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'vgg11', 'vgg11_bn','unet','unet2','nested_unet','fast_scnn',
                                            'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'deeplabv3_resnet101','deeplabv3_resnet50','fcn_resnet101', 'fcn_resnet50',"sqnxt23_w3d2", "sqnxt23_w2", "sqnxt23v5_w1", "sqnxt23v5_w3d2", "sqnxt23v5_w2"]+['efficientnet-b{}'.format(i) for i in range(8)]), show_default=True)
@click.option('-imb', '--imbalanced_correction', is_flag=True, help='Attempt to correct for imbalanced data by sampling.', show_default=True)
@click.option('-imb2', '--imbalanced_correction2', is_flag=True, help='Attempt to correct for imbalanced data by reweighting the loss.', show_default=True)
@click.option('-ca', '--classify_annotations', is_flag=True, help='Classify annotations.', show_default=True)
@click.option('-nt', '--num_targets', default=1, help='Number of targets.', show_default=True)
@click.option('-ss', '--subsample_p', default=1.0, help='Subsample training set.', show_default=True)
@click.option('-ssv', '--subsample_p_val', default=-1., help='Subsample val set; if not set, defaults to that of the training set.', show_default=True)
@click.option('-t', '--num_training_images_epoch', default=-1, help='Number of training images per epoch. -1 means use all training images each epoch.', show_default=True)
@click.option('-lr', '--learning_rate', default=1e-2, help='Learning rate.', show_default=True)
@click.option('-tp', '--transform_platform', default='torch', help='Transform platform for non-segmentation tasks.', type=click.Choice(['torch','albumentations']))
@click.option('-ne', '--n_epoch', default=10, help='Number of epochs.', show_default=True)
@click.option('-pi', '--patch_info_file', default='patch_info.db', help='Patch info file.', type=click.Path(exists=False), show_default=True)
@click.option('-tc', '--target_segmentation_class', default=[-1], multiple=True, help='Segmentation class to finetune on.', show_default=True)
@click.option('-tt', '--target_threshold', default=[0.], multiple=True, help='Threshold to include target for segmentation if saving one class.', show_default=True)
@click.option('-ov', '--oversampling_factor', default=[1.], multiple=True, help='How much to oversample training set.', show_default=True)
@click.option('-sup', '--supplement', is_flag=True, help='Use the thresholding to supplement the original training set.', show_default=True)
@click.option('-bs', '--batch_size', default=10, help='Batch size.', show_default=True)
@click.option('-rt', '--run_test', is_flag=True, help='Output predictions for a batch to "test_predictions.npy". Use for debugging.', show_default=True)
@click.option('-mtb', '--mt_bce', is_flag=True, help='Run multi-target bce predictions on the annotations.', show_default=True)
@click.option('-po', '--prediction_output_dir', default='predictions', help='Where to output segmentation predictions.', type=click.Path(exists=False), show_default=True)
@click.option('-ee', '--extract_embedding', is_flag=True, help='Extract embeddings.', show_default=True)
@click.option('-em', '--extract_model', is_flag=True, help='Save entire torch model.', show_default=True)
@click.option('-bt', '--binary_threshold', default=0., help='If running binary classification on annotations, dichotomize selected annotation as such.', show_default=True)
@click.option('-prt', '--pretrain', is_flag=True, help='Pretrain on ImageNet.', show_default=True)
@click.option('-olf', '--overwrite_loss_fn', default='', help='Overwrite the default training loss functions with loss of choice.', type=click.Choice(['','bce','mse','focal','dice','gdl','ce']), show_default=True)
@click.option('-atl', '--adopt_training_loss', is_flag=True, help='Adopt training loss function for validation calculation.', show_default=True)
@click.option('-tdb', '--external_test_db', default='', help='External database of samples to test on.', type=click.Path(exists=False), show_default=True)
@click.option('-tdir', '--external_test_dir', default='', help='External directory of samples to test on.', type=click.Path(exists=False), show_default=True)
@click.option('-pb', '--prediction_basename', default=[''], multiple=True, help='If supplied, predict on these basenames rather than the entire test set. Only works for segmentation tasks for now.', show_default=True)
@click.option('-cw', '--custom_weights', default='', help='Comma-delimited custom class weights.', type=click.Path(exists=False), show_default=True)
@click.option('-pset', '--prediction_set', default='test', help='Dataset to predict on.', type=click.Choice(['train','val','test']), show_default=True)
@click.option('-ut', '--user_transforms_file', default='', help='YAML file to add transforms from.', type=click.Path(exists=False), show_default=True)
@click.option('-svp', '--save_val_predictions', is_flag=True, help='Whether to save the validation predictions.', show_default=True)
@click.option('-soc', '--seg_out_class', default=-1, help='Output a particular segmentation class probabilities.', show_default=True)
@click.option('-aol', '--apex_opt_level', default='O2', help='Apex optimization level for mixed-precision training.', type=click.Choice(['O0','O1','O2','O3']), show_default=True)
@click.option('-ckp', '--checkpointing', is_flag=True, help='Save intermediate models to ./checkpoints.', show_default=True)
@click.option('-npy', '--npy_file', default='', help='Specify one file to extract embeddings from. Embeddings are output into the predictions directory.', type=click.Path(exists=False), show_default=True)
@click.option('-gpu', '--gpu_id', default=-1, help='GPU device to use, if set to 0 or greater.', show_default=True)
def train_model(segmentation,prediction,pos_annotation_class,other_annotations,save_location,pretrained_save_location,input_dir,patch_size,patch_resize,target_names,dataset_df,fix_names, architecture, imbalanced_correction, imbalanced_correction2, classify_annotations, num_targets, subsample_p,subsample_p_val,num_training_images_epoch, learning_rate, transform_platform, n_epoch, patch_info_file, target_segmentation_class, target_threshold, oversampling_factor, supplement, batch_size, run_test, mt_bce, prediction_output_dir, extract_embedding, extract_model, binary_threshold, pretrain, overwrite_loss_fn, adopt_training_loss, external_test_db,external_test_dir, prediction_basename, custom_weights, prediction_set, user_transforms_file, save_val_predictions, seg_out_class, apex_opt_level, checkpointing, npy_file, gpu_id):
    """Train and predict using model for regression and classification tasks."""
    # add separate pretrain ability on separating cell types, then transfer learn
    # add pretrain and efficient net, pretraining remove last layer while loading state dict
    if gpu_id>=0:
        torch.cuda.set_device(gpu_id)
    target_segmentation_class=list(map(int,target_segmentation_class))
    target_threshold=list(map(float,target_threshold))
    oversampling_factor=[(int(x) if float(x)>=1 else float(x)) for x in oversampling_factor]
    other_annotations=list(other_annotations)
    prediction_basename=list(filter(None,prediction_basename))
    command_opts = dict(segmentation=segmentation,
                        prediction=prediction,
                        pos_annotation_class=pos_annotation_class,
                        other_annotations=other_annotations,
                        save_location=save_location,
                        pretrained_save_location=pretrained_save_location,
                        input_dir=input_dir,
                        patch_size=patch_size,
                        target_names=target_names,
                        dataset_df=dataset_df,
                        fix_names=fix_names,
                        architecture=architecture,
                        patch_resize=patch_resize,
                        imbalanced_correction=imbalanced_correction,
                        imbalanced_correction2=imbalanced_correction2,
                        classify_annotations=classify_annotations,
                        num_targets=num_targets,
                        subsample_p=subsample_p,
                        num_training_images_epoch=num_training_images_epoch,
                        lr=learning_rate,
                        transform_platform=transform_platform,
                        n_epoch=n_epoch,
                        patch_info_file=patch_info_file,
                        target_segmentation_class=target_segmentation_class,
                        target_threshold=target_threshold,
                        oversampling_factor=oversampling_factor,
                        supplement=supplement,
                        predict=prediction,
                        batch_size=batch_size,
                        run_test=run_test,
                        mt_bce=mt_bce,
                        prediction_output_dir=prediction_output_dir,
                        extract_embedding=extract_embedding,
                        extract_model=extract_model,
                        binary_threshold=binary_threshold,
                        subsample_p_val=subsample_p_val,
                        wd=1e-3,
                        scheduler_type='warm_restarts',
                        T_max=10,
                        T_mult=2,
                        eta_min=5e-8,
                        optimizer='adam',
                        n_hidden=100,
                        pretrain=pretrain,
                        training_curve='training_curve.png',
                        adopt_training_loss=adopt_training_loss,
                        external_test_db=external_test_db,
                        external_test_dir=external_test_dir,
                        prediction_basename=prediction_basename,
                        save_val_predictions=save_val_predictions,
                        custom_weights=custom_weights,
                        prediction_set=prediction_set,
                        user_transforms=dict(),
                        dilation_jitter=dict(),
                        seg_out_class=seg_out_class,
                        apex_opt_level=apex_opt_level,
                        checkpointing=checkpointing,
                        npy_file=npy_file)

    training_opts = dict(normalization_file="normalization_parameters.pkl",
                         loss_fn='bce',
                         print_val_confusion=True,
                         prediction_save_path='predictions.db',
                         train_val_test_splits='train_val_test.pkl'
                         )
    segmentation_training_opts = copy.deepcopy(training_opts)
    segmentation_training_opts.update(dict(loss_fn='dice', # alternatives: gdl, dice+ce
                                            normalization_file='normalization_segmentation.pkl',
                                            fix_names=False,
                                            ))
    if segmentation:
        training_opts = segmentation_training_opts
    for k in command_opts:
        training_opts[k] = command_opts[k]
    if classify_annotations:
        if training_opts['num_targets']==1:
            training_opts['loss_fn']='bce'
        else:
            training_opts['loss_fn']='ce'
    if mt_bce:
        training_opts['loss_fn']='bce'
    if overwrite_loss_fn:
        training_opts['loss_fn']=overwrite_loss_fn

    if user_transforms_file and os.path.exists(user_transforms_file):
        from yaml import load as yml_load
        try:
            from yaml import CLoader as Loader, CDumper as Dumper
        except ImportError:
            from yaml import Loader, Dumper
        with open(user_transforms_file) as f:
            training_opts['user_transforms']=yml_load(f,Loader=Loader)
            if 'dilationjitter' in list(training_opts['user_transforms'].keys()):
                training_opts['dilation_jitter']=training_opts['user_transforms'].pop('dilationjitter')

    train_model_(training_opts)


if __name__=='__main__':
    train()
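
# A minimal example invocation, assuming this module is run directly (the click
# group exposes the train_model subcommand defined above); the paths and flag
# values below are hypothetical and should be adapted to your data:
#
#   python model_training.py train_model -s -a unet -i ./inputs -pi patch_info.db \
#       -ps 256 -pr 256 -bs 8 -ne 20 -o segmentation_model.pkl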