[e9500f]: / pathflowai / model_training.py

Download this file

392 lines (337 with data), 27.7 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
import torch, os, numpy as np, pandas as pd
from pathflowai.utils import *
#from large_data_utils import *
from pathflowai.datasets import *
from pathflowai.models import *
from pathflowai.schedulers import *
from pathflowai.visualize import *
import copy
from pathflowai.sampler import ImbalancedDatasetSampler
import argparse
import sqlite3
#from nonechucks import SafeDataLoader as DataLoader
from torch.utils.data import DataLoader
import click
import pysnooper
# Shared click settings: allow -h/--help and widen help text to 90 columns.
CONTEXT_SETTINGS = dict(help_option_names=['-h','--help'], max_content_width=90)
@click.group(context_settings= CONTEXT_SETTINGS)
@click.version_option(version='0.1')
def train():
    # Root click group; subcommands (e.g. train_model below) attach via @train.command().
    # Intentionally has no docstring so click's group help stays empty.
    pass
def return_model(training_opts):
    """Instantiate the network and optionally restore weights from disk.

    When ``pretrained_save_location`` exists and we are training, its state
    dict is loaded (for non-segmentation tasks the last two state-dict
    entries are kept from the fresh model, so checkpoints with a different
    head can still load). When predicting, ``save_location`` is loaded
    instead. For embedding extraction the terminal classifier module is
    replaced by its first sub-module so forward() emits features rather
    than class scores. The model is moved to GPU when CUDA is available.

    Parameters
    ----------
    training_opts : dict
        Training options populated from the command line.

    Returns
    -------
    torch.nn.Module
    """
    model = generate_model(pretrain=training_opts['pretrain'],
                           architecture=training_opts['architecture'],
                           num_classes=training_opts['num_targets'],
                           add_sigmoid=False,
                           n_hidden=training_opts['n_hidden'],
                           segmentation=training_opts['segmentation'])
    predicting = training_opts['predict']
    if os.path.exists(training_opts['pretrained_save_location']) and not predicting:
        state = torch.load(training_opts['pretrained_save_location'])
        if not training_opts['segmentation']:
            # Overwrite the checkpoint's final two tensors with the freshly
            # initialized ones from this model (presumably the classifier
            # head — TODO confirm against generate_model's layer order).
            state.update(dict(list(model.state_dict().items())[-2:]))
        model.load_state_dict(state)  # this will likely break after pretraining?
    elif os.path.exists(training_opts['save_location']) and predicting:
        model.load_state_dict(torch.load(training_opts['save_location']))
    if training_opts['extract_embedding']:
        # Embedding extraction is only valid in prediction mode.
        assert training_opts['extract_embedding'] == predicting
        arch = training_opts['architecture']
        if hasattr(model, "fc"):
            model.fc = model.fc[0]
        elif hasattr(model, "output"):
            model.output = model.output[0]
        elif arch.startswith(('alexnet', 'vgg', 'densenet')):
            model.classifier[6] = model.classifier[6][0]
    if torch.cuda.is_available():
        model.cuda()
    return model
def run_test(dataloader):
    """Debug helper: dump input/target batches to .npy files in the CWD.

    For each batch ``i`` writes ``X_test_i.npy`` and ``y_test_i.npy``; after
    the sixth batch (i == 5) the process is terminated via exit().
    """
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        np.save('X_test_{}.npy'.format(batch_idx), inputs.detach().cpu().numpy())
        np.save('y_test_{}.npy'.format(batch_idx), targets.detach().cpu().numpy())
        if batch_idx == 5:
            # Hard-stop once enough batches have been dumped for inspection.
            exit()
def return_trainer_opts(model, training_opts, dataloaders, num_train_batches):
    """Assemble the keyword arguments for ModelTrainer.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be trained / used for prediction.
    training_opts : dict
        Training options populated from the command line.
    dataloaders : dict
        Split name -> DataLoader; only 'val' is read here.
    num_train_batches : int or None
        Cap on training batches per epoch (None = use all).

    Returns
    -------
    dict
        Keyword arguments suitable for ModelTrainer(**result).
    """
    optimizer_opts = dict(name=training_opts['optimizer'],
                          lr=training_opts['lr'],
                          weight_decay=training_opts['wd'])
    scheduler_opts = dict(scheduler=training_opts['scheduler_type'],
                          lr_scheduler_decay=0.5,  # fixed decay factor, not CLI-exposed
                          T_max=training_opts['T_max'],
                          eta_min=training_opts['eta_min'],
                          T_mult=training_opts['T_mult'])
    return dict(model=model,
                n_epoch=training_opts['n_epoch'],
                validation_dataloader=dataloaders['val'],
                optimizer_opts=optimizer_opts,
                scheduler_opts=scheduler_opts,
                loss_fn=training_opts['loss_fn'],
                num_train_batches=num_train_batches,
                seg_out_class=training_opts['seg_out_class'],
                apex_opt_level=training_opts['apex_opt_level'],
                checkpointing=training_opts['checkpointing'])
def return_transformer(training_opts):
    """Load/create the train-val-test split and build the data transformers.

    Reads the split CSV when it exists, otherwise creates one from the patch
    info DB. A throwaway dataset description is assembled only so the
    intensity normalizer can be fit/loaded; its mean/std parameterize the
    transform pipeline.

    Parameters
    ----------
    training_opts : dict
        Training options populated from the command line.

    Returns
    -------
    tuple
        (dataset_df, dataset_opts, transformers)
    """
    dataset_df = pd.read_csv(training_opts['dataset_df']) if os.path.exists(training_opts['dataset_df']) else create_train_val_test(training_opts['train_val_test_splits'], training_opts['patch_info_file'], training_opts['patch_size'])
    # BUGFIX: the original guarded two values with `if set=='train'`, but `set`
    # there was the Python *builtin* (never equal to 'train'), so the else
    # branch always won. The always-taken values are now written explicitly;
    # behavior is unchanged.
    dataset_opts = dict(dataset_df=dataset_df,
                        set='pass',
                        patch_info_file=training_opts['patch_info_file'],
                        input_dir=training_opts['input_dir'],
                        target_names=training_opts['target_names'],
                        pos_annotation_class=training_opts['pos_annotation_class'],
                        segmentation=training_opts['segmentation'],
                        patch_size=training_opts['patch_size'],
                        fix_names=training_opts['fix_names'],
                        other_annotations=training_opts['other_annotations'],
                        target_segmentation_class=-1,  # was: [...][0] if set=='train' else -1 (dead branch)
                        target_threshold=training_opts['target_threshold'][0],
                        oversampling_factor=1,  # was: [...][0] if set=='train' else 1 (dead branch)
                        n_segmentation_classes=training_opts['num_targets'],
                        gdl=training_opts['loss_fn'] == 'gdl',
                        mt_bce=training_opts['mt_bce'],
                        classify_annotations=training_opts['classify_annotations'])
    norm_dict = get_normalizer(training_opts['normalization_file'], dataset_opts)
    # Segmentation tasks are forced onto the albumentations transform platform.
    transform_opts = dict(patch_size=training_opts['patch_resize'],
                          mean=norm_dict['mean'],
                          std=norm_dict['std'],
                          resize=True,
                          transform_platform=training_opts['transform_platform'] if not training_opts['segmentation'] else 'albumentations',
                          user_transforms=training_opts['user_transforms'])
    transformers = get_data_transforms(**transform_opts)
    # NOTE(review): transform_opts is referenced downstream (return_datasets
    # returns a `transform_opts` it never defines) but is not returned here —
    # confirm whether it should be part of this return tuple.
    return dataset_df, dataset_opts, transformers
def return_datasets(training_opts,dataset_df,transformers):
    """Construct the train/val/test DynamicImageDataset objects.

    Applies, in order: per-split dataset construction (oversampling and the
    first target segmentation class only on 'train'), concatenation of extra
    per-class training datasets, optional supplementation with an unfiltered
    training set, optional subsampling of train/val, optional binarization of
    annotations, and optional redirection of the 'test' split to another set
    or an external test database.

    Parameters
    ----------
    training_opts : dict
        Training options populated from the command line (mutated:
        'num_targets' and 'subsample_p_val' may be rewritten here).
    dataset_df : pandas.DataFrame
        Train/val/test split with target info.
    transformers : dict
        Data transform pipelines from return_transformer().

    Returns
    -------
    tuple
        (datasets, training_opts, transform_opts) — but see NOTE at return.
    """
    # `set` deliberately(?) shadows the builtin as the comprehension variable.
    datasets= {set: DynamicImageDataset(dataset_df, set, training_opts['patch_info_file'], transformers, training_opts['input_dir'], training_opts['target_names'], training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=training_opts['target_segmentation_class'][0] if set=='train' else -1, target_threshold=training_opts['target_threshold'][0], oversampling_factor=training_opts['oversampling_factor'][0] if set=='train' else 1, n_segmentation_classes=training_opts['num_targets'],gdl=training_opts['loss_fn']=='gdl',mt_bce=training_opts['mt_bce'], classify_annotations=training_opts['classify_annotations'],dilation_jitter=training_opts['dilation_jitter'] if set == 'train' else {}) for set in ['train','val','test']}
    # nc.SafeDataset(
    print(datasets['train'])
    if len(training_opts['target_segmentation_class']) > 1:
        # One extra training dataset per additional segmentation class,
        # concatenated onto the base training set.
        from functools import reduce  # NOTE(review): only used by the commented-out reduce below
        for i in range(1,len(training_opts['target_segmentation_class'])):
            #print(training_opts['classify_annotations'])
            datasets['train'].concat(DynamicImageDataset(dataset_df, 'train', training_opts['patch_info_file'], transformers, training_opts['input_dir'], training_opts['target_names'], training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=training_opts['target_segmentation_class'][i], target_threshold=training_opts['target_threshold'][i], oversampling_factor=training_opts['oversampling_factor'][i],n_segmentation_classes=training_opts['num_targets'],gdl=training_opts['loss_fn']=='gdl',mt_bce=training_opts['mt_bce'],classify_annotations=training_opts['classify_annotations'],dilation_jitter=training_opts['dilation_jitter']))
        #datasets['train']=reduce(lambda x,y: x.concat(y),[DynamicImageDataset(dataset_df, 'train', training_opts['patch_info_file'], transformers, training_opts['input_dir'], training_opts['target_names'], training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=training_opts['target_segmentation_class'][i], target_threshold=training_opts['target_threshold'][i], oversampling_factor=training_opts['oversampling_factor'][i]) for i in range(len(training_opts['target_segmentation_class']))])
    print(datasets['train'])
    if training_opts['supplement']:
        # Prepend an unfiltered (no target class, no oversampling) training
        # set and keep the thresholded one as a supplement.
        old_train_set = copy.deepcopy(datasets['train'])
        datasets['train']=DynamicImageDataset(dataset_df, 'train', training_opts['patch_info_file'], transformers, training_opts['input_dir'], training_opts['target_names'], training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=-1, target_threshold=training_opts['target_threshold'], oversampling_factor=1,n_segmentation_classes=training_opts['num_targets'],gdl=training_opts['loss_fn']=='gdl',mt_bce=training_opts['mt_bce'],classify_annotations=training_opts['classify_annotations'],dilation_jitter=training_opts['dilation_jitter'])
        datasets['train'].concat(old_train_set)
    if training_opts['subsample_p']<1.0:
        datasets['train'].subsample(training_opts['subsample_p'])
    if training_opts['subsample_p_val']<1.0:
        # -1 is the sentinel for "inherit the training subsample fraction".
        if training_opts['subsample_p_val']==-1.:
            training_opts['subsample_p_val']=training_opts['subsample_p']
        if training_opts['subsample_p_val']<1.0:
            datasets['val'].subsample(training_opts['subsample_p_val'])
    if training_opts['classify_annotations']:
        # NOTE(review): `binarizer` is never used after assignment.
        binarizer=datasets['train'].binarize_annotations(num_targets=training_opts['num_targets'],binary_threshold=training_opts['binary_threshold'])
        datasets['val'].binarize_annotations(num_targets=training_opts['num_targets'],binary_threshold=training_opts['binary_threshold'])
        datasets['test'].binarize_annotations(num_targets=training_opts['num_targets'],binary_threshold=training_opts['binary_threshold'])
        # Binarization can change the number of target columns.
        training_opts['num_targets']=len(datasets['train'].targets)
    for Set in ['train','val','test']:
        # Per-split target/annotation counts (columns 6+ of the patch table).
        print(datasets[Set].patch_info.iloc[:,6:].sum(axis=0))
    if training_opts['prediction_set']!='test':
        datasets['test']=datasets[training_opts['prediction_set']]
    if training_opts['external_test_db'] and training_opts['external_test_dir']:
        datasets['test'].update_dataset(input_dir=training_opts['external_test_dir'],new_db=training_opts['external_test_db'],prediction_basename=training_opts['prediction_basename'])
    # NOTE(review): `transform_opts` is NOT defined in this function (it is
    # local to return_transformer), so this return raises NameError unless a
    # module-level transform_opts exists — confirm the intended plumbing.
    return datasets,training_opts,transform_opts
#@pysnooper.snoop('train_model.log')
def train_model_(training_opts):
    """Function to train, predict on model.

    Builds the model and datasets, then either fits the trainer and saves the
    weights, or runs prediction (segmentation maps, embeddings, or per-patch
    scores written to a SQLite table) depending on the options.

    Parameters
    ----------
    training_opts : dict
        Training options populated from command line.
    """
    model = return_model(training_opts)
    dataset_df, dataset_opts, transformers = return_transformer(training_opts)
    if training_opts['extract_embedding'] and training_opts['npy_file']:
        # Fast path: embed a single NPY file and stop.
        dataset = NPYDataset(training_opts['patch_info_file'], training_opts['patch_size'], training_opts['npy_file'], transformers["test"])
        dataset.embed(model, training_opts['batch_size'], training_opts['prediction_output_dir'])
        exit()
    datasets, training_opts, transform_opts = return_datasets(training_opts, dataset_df, transformers)
    if training_opts['num_training_images_epoch'] > 0:
        num_train_batches = min(training_opts['num_training_images_epoch'], len(datasets['train'])) // training_opts['batch_size']
    else:
        num_train_batches = None
    # Shuffle only the training split; when imbalance correction is enabled
    # (and the task is not segmentation) a weighted sampler replaces shuffling.
    dataloaders = {set: DataLoader(datasets[set], batch_size=training_opts['batch_size'], shuffle=(set == 'train') if not (training_opts['imbalanced_correction'] and not training_opts['segmentation']) else False, num_workers=10, sampler=ImbalancedDatasetSampler(datasets[set]) if (training_opts['imbalanced_correction'] and set == 'train' and not training_opts['segmentation']) else None) for set in ['train', 'val', 'test']}
    print(dataloaders['train'].sampler)  # FIXME VAL SEEMS TO BE MISSING DURING PREDICTION
    print(dataloaders['val'].sampler)
    if training_opts['run_test']:
        run_test(dataloaders['train'])
    # BUGFIX: original passed the misspelled `dataloders`, a NameError at runtime.
    model_trainer_opts = return_trainer_opts(model, training_opts, dataloaders, num_train_batches)
    if not training_opts['predict']:
        trainer = ModelTrainer(**model_trainer_opts)
        if training_opts['imbalanced_correction2']:
            trainer.add_class_balance_loss(datasets['train'])
        elif training_opts['custom_weights']:
            trainer.add_class_balance_loss(datasets['train'], custom_weights=training_opts['custom_weights'])
        if training_opts['adopt_training_loss']:
            # Validate with the same loss used for training.
            trainer.val_loss_fn = trainer.loss_fn
        trainer.fit(dataloaders['train'], verbose=True, print_every=1, plot_training_curves=True, plot_save_file=training_opts['training_curve'], print_val_confusion=training_opts['print_val_confusion'], save_val_predictions=training_opts['save_val_predictions'])
        torch.save(trainer.model.state_dict(), training_opts['save_location'])
    else:
        if training_opts['extract_model']:
            # Serialize the whole model plus the options needed to rebuild the
            # test-time dataset/transforms, then stop.
            dataset_opts.update(dict(target_segmentation_class=-1, target_threshold=training_opts['target_threshold'][0] if len(training_opts['target_threshold']) else 0., set='test', binary_threshold=training_opts['binary_threshold'], num_targets=training_opts['num_targets'], oversampling_factor=1))
            torch.save(dict(model=model, dataset_opts=dataset_opts, transform_opts=transform_opts), '{}.{}'.format(training_opts['save_location'], 'extracted_model.pkl'))
            exit()
        trainer = ModelTrainer(**model_trainer_opts)
        if training_opts['segmentation']:
            # Predict one whole-slide ID at a time and stitch patches back
            # into a segmentation map on disk.
            for ID, dataset in (datasets['test'].split_by_ID() if not training_opts['prediction_basename'] else datasets['test'].select_IDs(training_opts['prediction_basename'])):
                dataloader = DataLoader(dataset, batch_size=training_opts['batch_size'], shuffle=False, num_workers=10)
                if training_opts['run_test']:
                    for X, y in dataloader:
                        np.save('test_predictions.npy', model(X.cuda() if torch.cuda.is_available() else X).detach().cpu().numpy())
                        exit()
                y_pred = trainer.predict(dataloader)
                print(ID, y_pred.shape)
                segmentation_predictions2npy(y_pred, dataset.patch_info, dataset.segmentation_maps[ID], npy_output='{}/{}_predict.npy'.format(training_opts['prediction_output_dir'], ID), original_patch_size=training_opts['patch_size'], resized_patch_size=training_opts['patch_resize'], output_probs=(training_opts['seg_out_class'] >= 0))
        else:
            extract_embedding = training_opts['extract_embedding']
            if extract_embedding:
                # Raw network outputs (no BCE post-processing) for embeddings.
                trainer.bce = False
            y_pred = trainer.predict(dataloaders['test'])
            patch_info = dataloaders['test'].dataset.patch_info
            if extract_embedding:
                # Human-readable row labels: "col:value" per patch-info field.
                patch_info['name'] = patch_info.astype(str).apply(lambda x: '\n'.join(['{}:{}'.format(k, v) for k, v in x.to_dict().items()]), axis=1)  # .apply(','.join,axis=1)
                embeddings = pd.DataFrame(y_pred, index=patch_info['name'])
                embeddings['ID'] = patch_info['ID'].values
                torch.save(dict(embeddings=embeddings, patch_info=patch_info), join(training_opts['prediction_output_dir'], 'embeddings.pkl'))
            else:
                if len(y_pred.shape) > 1 and y_pred.shape[1] > 1:
                    # One "<i>_pred" column per output class.
                    annotations = np.vectorize(lambda x: x + '_pred')(np.arange(y_pred.shape[1]).astype(str)).tolist()  # [training_opts['pos_annotation_class']]+training_opts['other_annotations']] if training_opts['classify_annotations'] else
                    for i in range(y_pred.shape[1]):
                        patch_info.loc[:, annotations[i]] = y_pred[:, i]
                # Scalar prediction column: raw score for single-target
                # regression-style runs, argmax class otherwise.
                patch_info['y_pred'] = y_pred if (training_opts['num_targets'] == 1 or not (training_opts['classify_annotations'] or training_opts['mt_bce'])) else y_pred.argmax(axis=1)
                conn = sqlite3.connect(training_opts['prediction_save_path'])
                patch_info.to_sql(str(training_opts['patch_size']), con=conn, if_exists=('replace'))  # if not training_opts['prediction_basename'] else 'append'))
                conn.close()
@train.command()
@click.option('-s', '--segmentation', is_flag=True, help='Segmentation task.', show_default=True)
@click.option('-p', '--prediction', is_flag=True, help='Predict on model.', show_default=True)
@click.option('-pa', '--pos_annotation_class', default='', help='Annotation Class from which to apply positive labels.', type=click.Path(exists=False), show_default=True)
@click.option('-oa', '--other_annotations', default=[], multiple=True, help='Annotations in image.', type=click.Path(exists=False), show_default=True)
@click.option('-o', '--save_location', default='', help='Model Save Location, append with pickle .pkl.', type=click.Path(exists=False), show_default=True)
@click.option('-pt', '--pretrained_save_location', default='', help='Model Save Location, append with pickle .pkl, pretrained by previous analysis to be finetuned.', type=click.Path(exists=False), show_default=True)
@click.option('-i', '--input_dir', default='', help='Input directory containing slides and everything.', type=click.Path(exists=False), show_default=True)
@click.option('-ps', '--patch_size', default=224, help='Patch size.', show_default=True)
@click.option('-pr', '--patch_resize', default=224, help='Patch resized.', show_default=True)
@click.option('-tg', '--target_names', default=[], multiple=True, help='Targets.', type=click.Path(exists=False), show_default=True)
@click.option('-df', '--dataset_df', default='', help='CSV file with train/val/test and target info.', type=click.Path(exists=False), show_default=True)
@click.option('-fn', '--fix_names', is_flag=True, help='Whether to fix names in dataset_df.', show_default=True)
@click.option('-a', '--architecture', default='alexnet', help='Neural Network Architecture.', type=click.Choice(['alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
'inception_v3', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'vgg11', 'vgg11_bn','unet','unet2','nested_unet','fast_scnn',
'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'deeplabv3_resnet101','deeplabv3_resnet50','fcn_resnet101', 'fcn_resnet50',"sqnxt23_w3d2", "sqnxt23_w2", "sqnxt23v5_w1", "sqnxt23v5_w3d2", "sqnxt23v5_w2"]+['efficientnet-b{}'.format(i) for i in range(8)]), show_default=True)
@click.option('-imb', '--imbalanced_correction', is_flag=True, help='Attempt to correct for imbalanced data.', show_default=True)
@click.option('-imb2', '--imbalanced_correction2', is_flag=True, help='Attempt to correct for imbalanced data.', show_default=True)
@click.option('-ca', '--classify_annotations', is_flag=True, help='Classify annotations.', show_default=True)
@click.option('-nt', '--num_targets', default=1, help='Number of targets.', show_default=True)
@click.option('-ss', '--subsample_p', default=1.0, help='Subsample training set.', show_default=True)
@click.option('-ssv', '--subsample_p_val', default=-1., help='Subsample val set. If not set, defaults to that of training set', show_default=True)
@click.option('-t', '--num_training_images_epoch', default=-1, help='Number of training images per epoch. -1 means use all training images each epoch.', show_default=True)
@click.option('-lr', '--learning_rate', default=1e-2, help='Learning rate.', show_default=True)
@click.option('-tp', '--transform_platform', default='torch', help='Transform platform for nonsegmentation tasks.', type=click.Choice(['torch','albumentations']))
@click.option('-ne', '--n_epoch', default=10, help='Number of epochs.', show_default=True)
@click.option('-pi', '--patch_info_file', default='patch_info.db', help='Patch info file.', type=click.Path(exists=False), show_default=True)
@click.option('-tc', '--target_segmentation_class', default=[-1], multiple=True, help='Segmentation Class to finetune on.', show_default=True)
@click.option('-tt', '--target_threshold', default=[0.], multiple=True, help='Threshold to include target for segmentation if saving one class.', show_default=True)
@click.option('-ov', '--oversampling_factor', default=[1.], multiple=True, help='How much to oversample training set.', show_default=True)
@click.option('-sup', '--supplement', is_flag=True, help='Use the thresholding to supplement the original training set.', show_default=True)
@click.option('-bs', '--batch_size', default=10, help='Batch size.', show_default=True)
@click.option('-rt', '--run_test', is_flag=True, help='Output predictions for a batch to "test_predictions.npy". Use for debugging.', show_default=True)
@click.option('-mtb', '--mt_bce', is_flag=True, help='Run multi-target bce predictions on the annotations.', show_default=True)
@click.option('-po', '--prediction_output_dir', default='predictions', help='Where to output segmentation predictions.', type=click.Path(exists=False), show_default=True)
@click.option('-ee', '--extract_embedding', is_flag=True, help='Extract embeddings.', show_default=True)
@click.option('-em', '--extract_model', is_flag=True, help='Save entire torch model.', show_default=True)
@click.option('-bt', '--binary_threshold', default=0., help='If running binary classification on annotations, dichotomize selected annotation as such.', show_default=True)
@click.option('-prt', '--pretrain', is_flag=True, help='Pretrain on ImageNet.', show_default=True)
@click.option('-olf', '--overwrite_loss_fn', default='', help='Overwrite the default training loss functions with loss of choice.', type=click.Choice(['','bce','mse','focal','dice','gdl','ce']), show_default=True)
@click.option('-atl', '--adopt_training_loss', is_flag=True, help='Adopt training loss function for validation calculation.', show_default=True)
@click.option('-tdb', '--external_test_db', default='', help='External database of samples to test on.', type=click.Path(exists=False), show_default=True)
@click.option('-tdir', '--external_test_dir', default='', help='External directory of samples to test on.', type=click.Path(exists=False), show_default=True)
@click.option('-pb', '--prediction_basename', default=[''], multiple=True, help='For segmentation tasks, if supplied, can predict on these basenames rather than the entire test set. Only works for segmentation tasks for now', show_default=True)
@click.option('-cw', '--custom_weights', default='', help='Comma delimited custom weights', type=click.Path(exists=False), show_default=True)
@click.option('-pset', '--prediction_set', default='test', help='Dataset to predict on.', type=click.Choice(['train','val','test']), show_default=True)
@click.option('-ut', '--user_transforms_file', default='', help='YAML file to add transforms from.', type=click.Path(exists=False), show_default=True)
@click.option('-svp', '--save_val_predictions', is_flag=True, help='Whether to save the validation predictions.', show_default=True)
@click.option('-soc', '--seg_out_class', default=-1, help='Output a particular segmentation class probabilities.', show_default=True)
@click.option('-aol', '--apex_opt_level', default='O2', help='Apex mixed-precision optimization level.', type=click.Choice(['O0','O1','O2','O3']), show_default=True)
@click.option('-ckp', '--checkpointing', is_flag=True, help='Save intermediate models to ./checkpoints.', show_default=True)
@click.option('-npy', '--npy_file', default='', help='Specify one file to extract embeddings from. Embeddings are output into predictions directory', type=click.Path(exists=False), show_default=True)
@click.option('-gpu', '--gpu_id', default=-1, help='Set GPU if 0 and greater.', show_default=True)
def train_model(segmentation,prediction,pos_annotation_class,other_annotations,save_location,pretrained_save_location,input_dir,patch_size,patch_resize,target_names,dataset_df,fix_names, architecture, imbalanced_correction, imbalanced_correction2, classify_annotations, num_targets, subsample_p,subsample_p_val,num_training_images_epoch, learning_rate, transform_platform, n_epoch, patch_info_file, target_segmentation_class, target_threshold, oversampling_factor, supplement, batch_size, run_test, mt_bce, prediction_output_dir, extract_embedding, extract_model, binary_threshold, pretrain, overwrite_loss_fn, adopt_training_loss, external_test_db,external_test_dir, prediction_basename, custom_weights, prediction_set, user_transforms_file, save_val_predictions, seg_out_class, apex_opt_level, checkpointing, npy_file, gpu_id):
    """Train and predict using model for regression and classification tasks."""
    # BUGFIX (CLI text only): the -aol/--apex_opt_level help string was a
    # copy-paste of the -ut option's help; it now describes the option.
    # add separate pretrain ability on separating cell types, then transfer learn
    # add pretrain and efficient net, pretraining remove last layer while loading state dict
    if gpu_id >= 0:
        torch.cuda.set_device(gpu_id)
    # Click yields multiple=True options as tuples of strings; coerce types here.
    target_segmentation_class = list(map(int, target_segmentation_class))
    target_threshold = list(map(float, target_threshold))
    oversampling_factor = [(int(x) if float(x) >= 1 else float(x)) for x in oversampling_factor]
    other_annotations = list(other_annotations)
    prediction_basename = list(filter(None, prediction_basename))
    # Everything coming from the CLI, plus fixed optimizer/scheduler settings.
    command_opts = dict(segmentation=segmentation,
                        prediction=prediction,
                        pos_annotation_class=pos_annotation_class,
                        other_annotations=other_annotations,
                        save_location=save_location,
                        pretrained_save_location=pretrained_save_location,
                        input_dir=input_dir,
                        patch_size=patch_size,
                        target_names=target_names,
                        dataset_df=dataset_df,
                        fix_names=fix_names,
                        architecture=architecture,
                        patch_resize=patch_resize,
                        imbalanced_correction=imbalanced_correction,
                        imbalanced_correction2=imbalanced_correction2,
                        classify_annotations=classify_annotations,
                        num_targets=num_targets,
                        subsample_p=subsample_p,
                        num_training_images_epoch=num_training_images_epoch,
                        lr=learning_rate,
                        transform_platform=transform_platform,
                        n_epoch=n_epoch,
                        patch_info_file=patch_info_file,
                        target_segmentation_class=target_segmentation_class,
                        target_threshold=target_threshold,
                        oversampling_factor=oversampling_factor,
                        supplement=supplement,
                        predict=prediction,
                        batch_size=batch_size,
                        run_test=run_test,
                        mt_bce=mt_bce,
                        prediction_output_dir=prediction_output_dir,
                        extract_embedding=extract_embedding,
                        extract_model=extract_model,
                        binary_threshold=binary_threshold,
                        subsample_p_val=subsample_p_val,
                        wd=1e-3,
                        scheduler_type='warm_restarts',
                        T_max=10,
                        T_mult=2,
                        eta_min=5e-8,
                        optimizer='adam',
                        n_hidden=100,
                        pretrain=pretrain,
                        training_curve='training_curve.png',
                        adopt_training_loss=adopt_training_loss,
                        external_test_db=external_test_db,
                        external_test_dir=external_test_dir,
                        prediction_basename=prediction_basename,
                        save_val_predictions=save_val_predictions,
                        custom_weights=custom_weights,
                        prediction_set=prediction_set,
                        user_transforms=dict(),
                        dilation_jitter=dict(),
                        seg_out_class=seg_out_class,
                        apex_opt_level=apex_opt_level,
                        checkpointing=checkpointing,
                        npy_file=npy_file)
    # Task-level defaults; CLI options below override them.
    training_opts = dict(normalization_file="normalization_parameters.pkl",
                         loss_fn='bce',
                         print_val_confusion=True,
                         prediction_save_path='predictions.db',
                         train_val_test_splits='train_val_test.pkl'
                         )
    segmentation_training_opts = copy.deepcopy(training_opts)
    segmentation_training_opts.update(dict(loss_fn='dice',  # gdl dice+ce
                                           normalization_file='normalization_segmentation.pkl',
                                           fix_names=False,
                                           ))
    if segmentation:
        training_opts = segmentation_training_opts
    for k in command_opts:
        training_opts[k] = command_opts[k]
    # Loss-function resolution order: classify_annotations -> mt_bce -> explicit override.
    if classify_annotations:
        if training_opts['num_targets'] == 1:
            training_opts['loss_fn'] = 'bce'
        else:
            training_opts['loss_fn'] = 'ce'
    if mt_bce:
        training_opts['loss_fn'] = 'bce'
    if overwrite_loss_fn:
        training_opts['loss_fn'] = overwrite_loss_fn
    if user_transforms_file and os.path.exists(user_transforms_file):
        from yaml import load as yml_load
        try:
            from yaml import CLoader as Loader, CDumper as Dumper
        except ImportError:
            from yaml import Loader, Dumper
        with open(user_transforms_file) as f:
            # SECURITY NOTE: yaml.load with the full (C)Loader can construct
            # arbitrary Python objects; prefer yaml.safe_load unless the
            # transforms file is fully trusted.
            training_opts['user_transforms'] = yml_load(f, Loader=Loader)
        if 'dilationjitter' in list(training_opts['user_transforms'].keys()):
            # DilationJitter is routed separately from the other user transforms.
            training_opts['dilation_jitter'] = training_opts['user_transforms'].pop('dilationjitter')
    train_model_(training_opts)
if __name__=='__main__':
    # Invoke the click CLI group when run as a script.
    train()