import torch, os, numpy as np, pandas as pd
from pathflowai.utils import *
#from large_data_utils import *
from pathflowai.datasets import *
from pathflowai.models import *
from pathflowai.schedulers import *
from pathflowai.visualize import *
import copy
from pathflowai.sampler import ImbalancedDatasetSampler
import argparse
import sqlite3
#from nonechucks import SafeDataLoader as DataLoader
from torch.utils.data import DataLoader
import click
import pysnooper
# Shared click settings: accept -h as an alias of --help and wrap help text at 90 columns.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help'], 'max_content_width': 90}


@click.group(context_settings=CONTEXT_SETTINGS)
@click.version_option(version='0.1')
def train():
    # Command group only: subcommands are registered with @train.command().
    # (A docstring is intentionally omitted because click would print it as help.)
    pass
def return_model(training_opts):
    """Build the network described by ``training_opts`` and load saved weights if present.

    Parameters
    ----------
    training_opts : dict
        Options from the ``train_model`` command. Keys read here: ``pretrain``,
        ``architecture``, ``num_targets``, ``n_hidden``, ``segmentation``,
        ``pretrained_save_location``, ``save_location``, ``predict`` and
        ``extract_embedding``.

    Returns
    -------
    model
        The model returned by ``generate_model``, moved to the current CUDA
        device when one is available.

    Raises
    ------
    ValueError
        If ``extract_embedding`` is requested without ``predict``.
    """
    model = generate_model(pretrain=training_opts['pretrain'],
                           architecture=training_opts['architecture'],
                           num_classes=training_opts['num_targets'],
                           add_sigmoid=False,
                           n_hidden=training_opts['n_hidden'],
                           segmentation=training_opts['segmentation'])
    if os.path.exists(training_opts['pretrained_save_location']) and not training_opts['predict']:
        # Fine-tuning from earlier weights. Loading onto the CPU lets checkpoints saved
        # on a GPU load on CPU-only hosts; the model itself is moved to the GPU below.
        model_dict = torch.load(training_opts['pretrained_save_location'], map_location='cpu')
        if not training_opts['segmentation']:
            # Replace the checkpoint's last two state-dict entries with this model's
            # freshly initialised ones (presumably the output layer's weight/bias), so
            # the number of targets may differ from the pretrained checkpoint.
            model_dict.update(dict(list(model.state_dict().items())[-2:]))
        # NOTE(review): strict load; may fail if the checkpoint's architecture differs.
        model.load_state_dict(model_dict)
    elif os.path.exists(training_opts['save_location']) and training_opts['predict']:
        model.load_state_dict(torch.load(training_opts['save_location'], map_location='cpu'))
    if training_opts['extract_embedding']:
        if training_opts['extract_embedding'] != training_opts['predict']:
            raise ValueError('extract_embedding requires predict (run with --prediction).')
        architecture = training_opts['architecture']
        # Keep only the first module of the final head; presumably this makes the model
        # emit the hidden representation instead of class scores (confirm against
        # generate_model).
        if hasattr(model, "fc"):
            model.fc = model.fc[0]
        elif hasattr(model, "output"):
            model.output = model.output[0]
        elif architecture.startswith(('alexnet', 'vgg', 'densenet')):
            model.classifier[6] = model.classifier[6][0]
    if torch.cuda.is_available():
        model.cuda()
    return model
def run_test(dataloader):
    """Debug helper: dump the first six batches of ``dataloader`` as ``.npy`` files.

    Batch ``i`` is written to ``X_test_<i>.npy`` (inputs) and ``y_test_<i>.npy``
    (targets) in the current working directory. Once batch index 5 has been
    written the process exits; a loader with fewer batches simply returns.
    """
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        for name_template, tensor in (('X_test_{}.npy', inputs), ('y_test_{}.npy', targets)):
            np.save(name_template.format(batch_idx), tensor.detach().cpu().numpy())
        if batch_idx == 5:
            exit()
def return_trainer_opts(model,training_opts,dataloaders,num_train_batches):
    """Assemble the keyword arguments for ``ModelTrainer``.

    Parameters
    ----------
    model : torch.nn.Module
        Network to train or evaluate.
    training_opts : dict
        CLI-derived options; only the keys read below are used.
    dataloaders : dict
        Must contain a ``'val'`` loader, used as the validation dataloader.
    num_train_batches : int or None
        Batches per training epoch, or None to use the full training loader.

    Returns
    -------
    dict
        Keyword arguments ready to be splatted into ``ModelTrainer(**...)``.
    """
    n_epoch = training_opts['n_epoch']
    validation_dataloader = dataloaders['val']
    optimizer_opts = {
        'name': training_opts['optimizer'],
        'lr': training_opts['lr'],
        'weight_decay': training_opts['wd'],
    }
    scheduler_opts = {
        'scheduler': training_opts['scheduler_type'],
        'lr_scheduler_decay': 0.5,
        'T_max': training_opts['T_max'],
        'eta_min': training_opts['eta_min'],
        'T_mult': training_opts['T_mult'],
    }
    trainer_opts = {
        'model': model,
        'n_epoch': n_epoch,
        'validation_dataloader': validation_dataloader,
        'optimizer_opts': optimizer_opts,
        'scheduler_opts': scheduler_opts,
        'loss_fn': training_opts['loss_fn'],
        'num_train_batches': num_train_batches,
    }
    # These options share their name between training_opts and ModelTrainer.
    for passthrough_key in ('seg_out_class', 'apex_opt_level', 'checkpointing'):
        trainer_opts[passthrough_key] = training_opts[passthrough_key]
    return trainer_opts
def return_transformer(training_opts):
    """Load the train/val/test split table and build the data transforms.

    The split table is read from ``training_opts['dataset_df']`` when that CSV
    exists; otherwise it is created with ``create_train_val_test``.
    Normalisation statistics come from ``get_normalizer`` and feed the
    transform construction.

    Parameters
    ----------
    training_opts : dict
        CLI-derived options.

    Returns
    -------
    dataset_df : pandas.DataFrame
        Split table.
    dataset_opts : dict
        Dataset keyword options (with the placeholder split ``'pass'``), as
        passed to ``get_normalizer``.
    transformers
        Result of ``get_data_transforms``; callers index it by split name.
    """
    if os.path.exists(training_opts['dataset_df']):
        dataset_df = pd.read_csv(training_opts['dataset_df'])
    else:
        dataset_df = create_train_val_test(training_opts['train_val_test_splits'],
                                           training_opts['patch_info_file'],
                                           training_opts['patch_size'])
    # These options describe a non-training split ('pass'), so no target-class
    # filtering (-1) and no oversampling (1). The previous expression compared the
    # builtin ``set`` type to 'train', which was always False and produced the same values.
    dataset_opts = dict(dataset_df=dataset_df,
                        set='pass',
                        patch_info_file=training_opts['patch_info_file'],
                        input_dir=training_opts['input_dir'],
                        target_names=training_opts['target_names'],
                        pos_annotation_class=training_opts['pos_annotation_class'],
                        segmentation=training_opts['segmentation'],
                        patch_size=training_opts['patch_size'],
                        fix_names=training_opts['fix_names'],
                        other_annotations=training_opts['other_annotations'],
                        target_segmentation_class=-1,
                        target_threshold=training_opts['target_threshold'][0],
                        oversampling_factor=1,
                        n_segmentation_classes=training_opts['num_targets'],
                        gdl=training_opts['loss_fn'] == 'gdl',
                        mt_bce=training_opts['mt_bce'],
                        classify_annotations=training_opts['classify_annotations'])
    norm_dict = get_normalizer(training_opts['normalization_file'], dataset_opts)
    # Segmentation always uses albumentations; other tasks use the CLI choice.
    transform_platform = training_opts['transform_platform'] if not training_opts['segmentation'] else 'albumentations'
    transform_opts = dict(patch_size=training_opts['patch_resize'],
                          mean=norm_dict['mean'],
                          std=norm_dict['std'],
                          resize=True,
                          transform_platform=transform_platform,
                          user_transforms=training_opts['user_transforms'])
    transformers = get_data_transforms(**transform_opts)
    return dataset_df, dataset_opts, transformers
def return_datasets(training_opts,dataset_df,transformers,transform_opts=None):
    """Build the train/val/test ``DynamicImageDataset`` objects.

    Parameters
    ----------
    training_opts : dict
        CLI-derived options. Updated in place: ``subsample_p_val`` (when -1) and,
        when annotations are classified, ``num_targets``.
    dataset_df : pandas.DataFrame
        Split table from ``return_transformer``.
    transformers
        Transforms from ``return_transformer``; passed straight to each dataset.
    transform_opts : dict, optional
        Returned unchanged as the third element so callers can keep unpacking
        three values. ``None`` when not supplied.

    Returns
    -------
    datasets : dict
        Datasets keyed by ``'train'``, ``'val'`` and ``'test'``. ``'test'`` may be
        replaced by the split named in ``prediction_set``.
    training_opts : dict
        The (possibly updated) options.
    transform_opts : dict or None
        The ``transform_opts`` argument.
    """
    def _build_dataset(split, target_segmentation_class, target_threshold, oversampling_factor, dilation_jitter):
        # Shared constructor; only the split and class-filtering/sampling options vary.
        return DynamicImageDataset(dataset_df, split, training_opts['patch_info_file'], transformers,
                                   training_opts['input_dir'], training_opts['target_names'],
                                   training_opts['pos_annotation_class'],
                                   segmentation=training_opts['segmentation'],
                                   patch_size=training_opts['patch_size'],
                                   fix_names=training_opts['fix_names'],
                                   other_annotations=training_opts['other_annotations'],
                                   target_segmentation_class=target_segmentation_class,
                                   target_threshold=target_threshold,
                                   oversampling_factor=oversampling_factor,
                                   n_segmentation_classes=training_opts['num_targets'],
                                   gdl=training_opts['loss_fn'] == 'gdl',
                                   mt_bce=training_opts['mt_bce'],
                                   classify_annotations=training_opts['classify_annotations'],
                                   dilation_jitter=dilation_jitter)

    target_classes = training_opts['target_segmentation_class']
    thresholds = training_opts['target_threshold']
    oversampling = training_opts['oversampling_factor']

    # Only the training split gets target-class filtering, oversampling and dilation jitter.
    datasets = {}
    for split in ['train', 'val', 'test']:
        is_train = split == 'train'
        datasets[split] = _build_dataset(split,
                                         target_classes[0] if is_train else -1,
                                         thresholds[0],
                                         oversampling[0] if is_train else 1,
                                         training_opts['dilation_jitter'] if is_train else {})
    print(datasets['train'])

    # Each additional target class adds its own filtered/oversampled training subset.
    if len(target_classes) > 1:
        for i in range(1, len(target_classes)):
            datasets['train'].concat(_build_dataset('train', target_classes[i], thresholds[i],
                                                    oversampling[i], training_opts['dilation_jitter']))
        print(datasets['train'])

    if training_opts['supplement']:
        # Unfiltered training set first, then the thresholded one(s) appended.
        old_train_set = copy.deepcopy(datasets['train'])
        datasets['train'] = _build_dataset('train', -1, thresholds[0], 1, training_opts['dilation_jitter'])
        datasets['train'].concat(old_train_set)

    if training_opts['subsample_p'] < 1.0:
        datasets['train'].subsample(training_opts['subsample_p'])
    # -1 (the CLI default) means "reuse the training subsample fraction".
    if training_opts['subsample_p_val'] == -1.:
        training_opts['subsample_p_val'] = training_opts['subsample_p']
    if training_opts['subsample_p_val'] < 1.0:
        datasets['val'].subsample(training_opts['subsample_p_val'])

    if training_opts['classify_annotations']:
        for split in ['train', 'val', 'test']:
            datasets[split].binarize_annotations(num_targets=training_opts['num_targets'],
                                                 binary_threshold=training_opts['binary_threshold'])
        training_opts['num_targets'] = len(datasets['train'].targets)

    for split in ['train', 'val', 'test']:
        print(datasets[split].patch_info.iloc[:, 6:].sum(axis=0))

    if training_opts['prediction_set'] != 'test':
        datasets['test'] = datasets[training_opts['prediction_set']]
    if training_opts['external_test_db'] and training_opts['external_test_dir']:
        datasets['test'].update_dataset(input_dir=training_opts['external_test_dir'],
                                        new_db=training_opts['external_test_db'],
                                        prediction_basename=training_opts['prediction_basename'])
    return datasets, training_opts, transform_opts
#@pysnooper.snoop('train_model.log')
def _predict_segmentation(trainer, model, datasets, training_opts):
    """Predict each test slide (ID) separately and write one ``<ID>_predict.npy`` per ID."""
    test_set = datasets['test']
    if training_opts['prediction_basename']:
        id_datasets = test_set.select_IDs(training_opts['prediction_basename'])
    else:
        id_datasets = test_set.split_by_ID()
    for ID, dataset in id_datasets:
        dataloader = DataLoader(dataset, batch_size=training_opts['batch_size'], shuffle=False, num_workers=10)
        if training_opts['run_test']:
            # Debug mode: save the raw model output for the first batch, then stop.
            for X, y in dataloader:
                np.save('test_predictions.npy', model(X.cuda() if torch.cuda.is_available() else X).detach().cpu().numpy())
                exit()
        y_pred = trainer.predict(dataloader)
        print(ID, y_pred.shape)
        segmentation_predictions2npy(y_pred, dataset.patch_info, dataset.segmentation_maps[ID],
                                     npy_output='{}/{}_predict.npy'.format(training_opts['prediction_output_dir'], ID),
                                     original_patch_size=training_opts['patch_size'],
                                     resized_patch_size=training_opts['patch_resize'],
                                     output_probs=(training_opts['seg_out_class'] >= 0))


def _predict_classification(trainer, dataloaders, training_opts):
    """Predict on the test loader; save embeddings to a pickle, or predictions to sqlite."""
    extract_embedding = training_opts['extract_embedding']
    if extract_embedding:
        trainer.bce = False
    y_pred = trainer.predict(dataloaders['test'])
    patch_info = dataloaders['test'].dataset.patch_info
    if extract_embedding:
        # Index each embedding row by a "field:value" dump of its patch_info row.
        patch_info['name'] = patch_info.astype(str).apply(
            lambda row: '\n'.join('{}:{}'.format(k, v) for k, v in row.to_dict().items()), axis=1)
        embeddings = pd.DataFrame(y_pred, index=patch_info['name'])
        embeddings['ID'] = patch_info['ID'].values
        torch.save(dict(embeddings=embeddings, patch_info=patch_info),
                   os.path.join(training_opts['prediction_output_dir'], 'embeddings.pkl'))
        return
    if len(y_pred.shape) > 1 and y_pred.shape[1] > 1:
        # One "<k>_pred" column per output column.
        for k in range(y_pred.shape[1]):
            patch_info.loc[:, '{}_pred'.format(k)] = y_pred[:, k]
    patch_info['y_pred'] = y_pred if (training_opts['num_targets'] == 1 or not (training_opts['classify_annotations'] or training_opts['mt_bce'])) else y_pred.argmax(axis=1)
    conn = sqlite3.connect(training_opts['prediction_save_path'])
    try:
        # Table is named after the patch size and replaced on every run.
        patch_info.to_sql(str(training_opts['patch_size']), con=conn, if_exists='replace')
    finally:
        conn.close()


def train_model_(training_opts):
    """Train a model, or predict / extract embeddings / export it, as configured.

    With ``predict`` false, the model is fit and its state dict is saved to
    ``save_location``. Otherwise, depending on the options, the whole model is
    exported (``extract_model``), segmentation maps are written per slide ID, or
    classification predictions / embeddings are saved. Several debug and export
    paths terminate the process with ``exit()``.

    Parameters
    ----------
    training_opts : dict
        Training options populated from command line.
    """
    model = return_model(training_opts)
    dataset_df, dataset_opts, transformers = return_transformer(training_opts)

    # Single-file embedding extraction short-circuits everything else.
    if training_opts['extract_embedding'] and training_opts['npy_file']:
        dataset = NPYDataset(training_opts['patch_info_file'], training_opts['patch_size'],
                             training_opts['npy_file'], transformers["test"])
        dataset.embed(model, training_opts['batch_size'], training_opts['prediction_output_dir'])
        exit()

    datasets, training_opts, transform_opts = return_datasets(training_opts, dataset_df, transformers)

    if training_opts['num_training_images_epoch'] > 0:
        num_train_batches = min(training_opts['num_training_images_epoch'], len(datasets['train'])) // training_opts['batch_size']
    else:
        num_train_batches = None

    # For non-segmentation tasks with imbalance correction, the training split is
    # drawn by ImbalancedDatasetSampler instead of being shuffled.
    use_sampler = training_opts['imbalanced_correction'] and not training_opts['segmentation']
    dataloaders = {
        split: DataLoader(datasets[split],
                          batch_size=training_opts['batch_size'],
                          shuffle=(split == 'train') if not use_sampler else False,
                          num_workers=10,
                          sampler=ImbalancedDatasetSampler(datasets[split]) if (use_sampler and split == 'train') else None)
        for split in ['train', 'val', 'test']}
    print(dataloaders['train'].sampler)  # FIXME VAL SEEMS TO BE MISSING DURING PREDICTION
    print(dataloaders['val'].sampler)
    if training_opts['run_test']:
        run_test(dataloaders['train'])
    model_trainer_opts = return_trainer_opts(model, training_opts, dataloaders, num_train_batches)

    if not training_opts['predict']:
        trainer = ModelTrainer(**model_trainer_opts)
        if training_opts['imbalanced_correction2']:
            trainer.add_class_balance_loss(datasets['train'])
        elif training_opts['custom_weights']:
            trainer.add_class_balance_loss(datasets['train'], custom_weights=training_opts['custom_weights'])
        if training_opts['adopt_training_loss']:
            trainer.val_loss_fn = trainer.loss_fn
        trainer.fit(dataloaders['train'], verbose=True, print_every=1, plot_training_curves=True,
                    plot_save_file=training_opts['training_curve'],
                    print_val_confusion=training_opts['print_val_confusion'],
                    save_val_predictions=training_opts['save_val_predictions'])
        torch.save(trainer.model.state_dict(), training_opts['save_location'])
        return

    if training_opts['extract_model']:
        # Bundle the model with the options needed to rebuild its data pipeline.
        dataset_opts.update(dict(target_segmentation_class=-1,
                                 target_threshold=training_opts['target_threshold'][0] if len(training_opts['target_threshold']) else 0.,
                                 set='test',
                                 binary_threshold=training_opts['binary_threshold'],
                                 num_targets=training_opts['num_targets'],
                                 oversampling_factor=1))
        torch.save(dict(model=model, dataset_opts=dataset_opts, transform_opts=transform_opts),
                   '{}.{}'.format(training_opts['save_location'], 'extracted_model.pkl'))
        exit()
    trainer = ModelTrainer(**model_trainer_opts)
    if training_opts['segmentation']:
        _predict_segmentation(trainer, model, datasets, training_opts)
    else:
        _predict_classification(trainer, dataloaders, training_opts)
@train.command()
@click.option('-s', '--segmentation', is_flag=True, help='Segmentation task.', show_default=True)
@click.option('-p', '--prediction', is_flag=True, help='Predict on model.', show_default=True)
@click.option('-pa', '--pos_annotation_class', default='', help='Annotation Class from which to apply positive labels.', type=click.Path(exists=False), show_default=True)
@click.option('-oa', '--other_annotations', default=[], multiple=True, help='Annotations in image.', type=click.Path(exists=False), show_default=True)
@click.option('-o', '--save_location', default='', help='Model Save Location, append with pickle .pkl.', type=click.Path(exists=False), show_default=True)
@click.option('-pt', '--pretrained_save_location', default='', help='Model Save Location, append with pickle .pkl, pretrained by previous analysis to be finetuned.', type=click.Path(exists=False), show_default=True)
@click.option('-i', '--input_dir', default='', help='Input directory containing slides and everything.', type=click.Path(exists=False), show_default=True)
@click.option('-ps', '--patch_size', default=224, help='Patch size.', show_default=True)
@click.option('-pr', '--patch_resize', default=224, help='Patch resized.', show_default=True)
@click.option('-tg', '--target_names', default=[], multiple=True, help='Targets.', type=click.Path(exists=False), show_default=True)
@click.option('-df', '--dataset_df', default='', help='CSV file with train/val/test and target info.', type=click.Path(exists=False), show_default=True)
@click.option('-fn', '--fix_names', is_flag=True, help='Whether to fix names in dataset_df.', show_default=True)
@click.option('-a', '--architecture', default='alexnet', help='Neural Network Architecture.', type=click.Choice(['alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
    'inception_v3', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'vgg11', 'vgg11_bn','unet','unet2','nested_unet','fast_scnn',
    'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'deeplabv3_resnet101','deeplabv3_resnet50','fcn_resnet101', 'fcn_resnet50',"sqnxt23_w3d2", "sqnxt23_w2", "sqnxt23v5_w1", "sqnxt23v5_w3d2", "sqnxt23v5_w2"]+['efficientnet-b{}'.format(i) for i in range(8)]), show_default=True)
@click.option('-imb', '--imbalanced_correction', is_flag=True, help='Attempt to correct for imbalanced data by resampling the training set (non-segmentation tasks).', show_default=True)
@click.option('-imb2', '--imbalanced_correction2', is_flag=True, help='Attempt to correct for imbalanced data with a class-balanced loss.', show_default=True)
@click.option('-ca', '--classify_annotations', is_flag=True, help='Classify annotations.', show_default=True)
@click.option('-nt', '--num_targets', default=1, help='Number of targets.', show_default=True)
@click.option('-ss', '--subsample_p', default=1.0, help='Subsample training set.', show_default=True)
@click.option('-ssv', '--subsample_p_val', default=-1., help='Subsample val set. If not set, defaults to that of training set', show_default=True)
@click.option('-t', '--num_training_images_epoch', default=-1, help='Number of training images per epoch. -1 means use all training images each epoch.', show_default=True)
@click.option('-lr', '--learning_rate', default=1e-2, help='Learning rate.', show_default=True)
@click.option('-tp', '--transform_platform', default='torch', help='Transform platform for nonsegmentation tasks.', type=click.Choice(['torch','albumentations']))
@click.option('-ne', '--n_epoch', default=10, help='Number of epochs.', show_default=True)
@click.option('-pi', '--patch_info_file', default='patch_info.db', help='Patch info file.', type=click.Path(exists=False), show_default=True)
@click.option('-tc', '--target_segmentation_class', default=[-1], multiple=True, help='Segmentation Class to finetune on.', show_default=True)
@click.option('-tt', '--target_threshold', default=[0.], multiple=True, help='Threshold to include target for segmentation if saving one class.', show_default=True)
@click.option('-ov', '--oversampling_factor', default=[1.], multiple=True, help='How much to oversample training set.', show_default=True)
@click.option('-sup', '--supplement', is_flag=True, help='Use the thresholding to supplement the original training set.', show_default=True)
@click.option('-bs', '--batch_size', default=10, help='Batch size.', show_default=True)
@click.option('-rt', '--run_test', is_flag=True, help='Output predictions for a batch to "test_predictions.npy". Use for debugging.', show_default=True)
@click.option('-mtb', '--mt_bce', is_flag=True, help='Run multi-target bce predictions on the annotations.', show_default=True)
@click.option('-po', '--prediction_output_dir', default='predictions', help='Where to output segmentation predictions.', type=click.Path(exists=False), show_default=True)
@click.option('-ee', '--extract_embedding', is_flag=True, help='Extract embeddings.', show_default=True)
@click.option('-em', '--extract_model', is_flag=True, help='Save entire torch model.', show_default=True)
@click.option('-bt', '--binary_threshold', default=0., help='If running binary classification on annotations, dichotomize selected annotation as such.', show_default=True)
@click.option('-prt', '--pretrain', is_flag=True, help='Pretrain on ImageNet.', show_default=True)
@click.option('-olf', '--overwrite_loss_fn', default='', help='Overwrite the default training loss functions with loss of choice.', type=click.Choice(['','bce','mse','focal','dice','gdl','ce']), show_default=True)
@click.option('-atl', '--adopt_training_loss', is_flag=True, help='Adopt training loss function for validation calculation.', show_default=True)
@click.option('-tdb', '--external_test_db', default='', help='External database of samples to test on.', type=click.Path(exists=False), show_default=True)
@click.option('-tdir', '--external_test_dir', default='', help='External directory of samples to test on.', type=click.Path(exists=False), show_default=True)
@click.option('-pb', '--prediction_basename', default=[''], multiple=True, help='For segmentation tasks, if supplied, can predict on these basenames rather than the entire test set. Only works for segmentation tasks for now', show_default=True)
@click.option('-cw', '--custom_weights', default='', help='Comma delimited custom weights', type=click.Path(exists=False), show_default=True)
@click.option('-pset', '--prediction_set', default='test', help='Dataset to predict on.', type=click.Choice(['train','val','test']), show_default=True)
@click.option('-ut', '--user_transforms_file', default='', help='YAML file to add transforms from.', type=click.Path(exists=False), show_default=True)
@click.option('-svp', '--save_val_predictions', is_flag=True, help='Whether to save the validation predictions.', show_default=True)
@click.option('-soc', '--seg_out_class', default=-1, help='Output a particular segmentation class probabilities.', show_default=True)
@click.option('-aol', '--apex_opt_level', default='O2', help='Apex mixed-precision optimization level passed to the trainer.', type=click.Choice(['O0','O1','O2','O3']), show_default=True)
@click.option('-ckp', '--checkpointing', is_flag=True, help='Save intermediate models to ./checkpoints.', show_default=True)
@click.option('-npy', '--npy_file', default='', help='Specify one file to extract embeddings from. Embeddings are output into predictions directory', type=click.Path(exists=False), show_default=True)
@click.option('-gpu', '--gpu_id', default=-1, help='Set GPU if 0 and greater.', show_default=True)
def train_model(segmentation,prediction,pos_annotation_class,other_annotations,save_location,pretrained_save_location,input_dir,patch_size,patch_resize,target_names,dataset_df,fix_names, architecture, imbalanced_correction, imbalanced_correction2, classify_annotations, num_targets, subsample_p,subsample_p_val,num_training_images_epoch, learning_rate, transform_platform, n_epoch, patch_info_file, target_segmentation_class, target_threshold, oversampling_factor, supplement, batch_size, run_test, mt_bce, prediction_output_dir, extract_embedding, extract_model, binary_threshold, pretrain, overwrite_loss_fn, adopt_training_loss, external_test_db,external_test_dir, prediction_basename, custom_weights, prediction_set, user_transforms_file, save_val_predictions, seg_out_class, apex_opt_level, checkpointing, npy_file, gpu_id):
    """Train and predict using model for regression and classification tasks."""
    # (The docstring above is the click help text; keep it user-facing.)
    # TODO: add separate pretrain ability on separating cell types, then transfer learn
    # TODO: add pretrain and efficient net, pretraining remove last layer while loading state dict
    if gpu_id >= 0:
        torch.cuda.set_device(gpu_id)
    # click yields tuples for multiple=True options; normalise them to typed lists.
    target_segmentation_class = list(map(int, target_segmentation_class))
    target_threshold = list(map(float, target_threshold))
    # Factors >= 1 become ints; fractional factors stay floats.
    oversampling_factor = [(int(x) if float(x) >= 1 else float(x)) for x in oversampling_factor]
    other_annotations = list(other_annotations)
    # Drop the empty-string default so an unset option yields [].
    prediction_basename = list(filter(None, prediction_basename))
    command_opts = dict(segmentation=segmentation,
                        prediction=prediction,
                        pos_annotation_class=pos_annotation_class,
                        other_annotations=other_annotations,
                        save_location=save_location,
                        pretrained_save_location=pretrained_save_location,
                        input_dir=input_dir,
                        patch_size=patch_size,
                        target_names=target_names,
                        dataset_df=dataset_df,
                        fix_names=fix_names,
                        architecture=architecture,
                        patch_resize=patch_resize,
                        imbalanced_correction=imbalanced_correction,
                        imbalanced_correction2=imbalanced_correction2,
                        classify_annotations=classify_annotations,
                        num_targets=num_targets,
                        subsample_p=subsample_p,
                        num_training_images_epoch=num_training_images_epoch,
                        lr=learning_rate,
                        transform_platform=transform_platform,
                        n_epoch=n_epoch,
                        patch_info_file=patch_info_file,
                        target_segmentation_class=target_segmentation_class,
                        target_threshold=target_threshold,
                        oversampling_factor=oversampling_factor,
                        supplement=supplement,
                        predict=prediction,
                        batch_size=batch_size,
                        run_test=run_test,
                        mt_bce=mt_bce,
                        prediction_output_dir=prediction_output_dir,
                        extract_embedding=extract_embedding,
                        extract_model=extract_model,
                        binary_threshold=binary_threshold,
                        subsample_p_val=subsample_p_val,
                        wd=1e-3,
                        scheduler_type='warm_restarts',
                        T_max=10,
                        T_mult=2,
                        eta_min=5e-8,
                        optimizer='adam',
                        n_hidden=100,
                        pretrain=pretrain,
                        training_curve='training_curve.png',
                        adopt_training_loss=adopt_training_loss,
                        external_test_db=external_test_db,
                        external_test_dir=external_test_dir,
                        prediction_basename=prediction_basename,
                        save_val_predictions=save_val_predictions,
                        custom_weights=custom_weights,
                        prediction_set=prediction_set,
                        user_transforms=dict(),
                        dilation_jitter=dict(),
                        seg_out_class=seg_out_class,
                        apex_opt_level=apex_opt_level,
                        checkpointing=checkpointing,
                        npy_file=npy_file)
    training_opts = dict(normalization_file="normalization_parameters.pkl",
                         loss_fn='bce',
                         print_val_confusion=True,
                         prediction_save_path='predictions.db',
                         train_val_test_splits='train_val_test.pkl')
    segmentation_training_opts = copy.deepcopy(training_opts)
    # NOTE: fix_names=False here is overridden below, because command_opts also carries fix_names.
    segmentation_training_opts.update(dict(loss_fn='dice',  # gdl dice+ce
                                           normalization_file='normalization_segmentation.pkl',
                                           fix_names=False))
    if segmentation:
        training_opts = segmentation_training_opts
    # Command-line values take precedence over the task defaults above.
    training_opts.update(command_opts)
    if classify_annotations:
        training_opts['loss_fn'] = 'bce' if training_opts['num_targets'] == 1 else 'ce'
    if mt_bce:
        training_opts['loss_fn'] = 'bce'
    if overwrite_loss_fn:
        training_opts['loss_fn'] = overwrite_loss_fn
    if user_transforms_file and os.path.exists(user_transforms_file):
        from yaml import load as yml_load
        try:
            from yaml import CLoader as Loader
        except ImportError:
            from yaml import Loader
        # NOTE(review): (C)Loader can instantiate arbitrary Python objects; only load
        # trusted transform files (yaml.SafeLoader would be the safe alternative).
        with open(user_transforms_file) as f:
            # An empty YAML file loads as None; treat it as "no extra transforms".
            training_opts['user_transforms'] = yml_load(f, Loader=Loader) or {}
        if 'dilationjitter' in training_opts['user_transforms']:
            training_opts['dilation_jitter'] = training_opts['user_transforms'].pop('dilationjitter')
    train_model_(training_opts)
if __name__ == '__main__':
    # Executed as a script: hand control to the click command group.
    train()