--- a
+++ b/pathflowai/model_training.py
@@ -0,0 +1,391 @@
+import torch, os, numpy as np, pandas as pd
+from pathflowai.utils import *
+#from large_data_utils import *
+from pathflowai.datasets import *
+from pathflowai.models import *
+from pathflowai.schedulers import *
+from pathflowai.visualize import *
+import copy
+from pathflowai.sampler import ImbalancedDatasetSampler
+import argparse
+import sqlite3
+#from nonechucks import SafeDataLoader as DataLoader
+from torch.utils.data import DataLoader
+import click
+import pysnooper
+
+CONTEXT_SETTINGS = dict(help_option_names=['-h','--help'], max_content_width=90)
+
+@click.group(context_settings= CONTEXT_SETTINGS)
+@click.version_option(version='0.1')
+def train():
+	pass
+
+def return_model(training_opts):
+	"""Instantiate the network, load pretrained or saved weights when available, and strip the final layer if embeddings are to be extracted."""
+	model=generate_model(pretrain=training_opts['pretrain'],architecture=training_opts['architecture'],num_classes=training_opts['num_targets'], add_sigmoid=False, n_hidden=training_opts['n_hidden'], segmentation=training_opts['segmentation'])
+
+	if os.path.exists(training_opts['pretrained_save_location']) and not training_opts['predict']:
+		model_dict = torch.load(training_opts['pretrained_save_location'])
+		if not training_opts['segmentation']:
+			model_dict.update(dict(list(model.state_dict().items())[-2:])) # keep the current model's final-layer weights/bias so a pretrained head with a different number of outputs does not clash
+		model.load_state_dict(model_dict) # this will likely break after pretraining?
+	elif os.path.exists(training_opts['save_location']) and training_opts['predict']:
+		model_dict = torch.load(training_opts['save_location'])
+		model.load_state_dict(model_dict)
+	if training_opts['extract_embedding']:
+		assert training_opts['predict'], 'Embedding extraction is only available in prediction mode.'
+		architecture=training_opts['architecture']
+		if hasattr(model,"fc"):
+			model.fc = model.fc[0]
+		elif hasattr(model,"output"):
+			model.output = model.output[0]
+		elif architecture.startswith('alexnet') or architecture.startswith('vgg') or architecture.startswith('densenet'):
+			model.classifier[6]=model.classifier[6][0]
+	if torch.cuda.is_available():
+		model.cuda()
+	return model
+
+def run_test(dataloader):
+	"""Dump the first few batches to .npy files for debugging, then exit."""
+	for i,(X,y) in enumerate(dataloader):
+		np.save('X_test_{}.npy'.format(i),X.detach().cpu().numpy())
+		np.save('y_test_{}.npy'.format(i),y.detach().cpu().numpy())
+		if i==5:
+			exit()
+
+def return_trainer_opts(model,training_opts,dataloaders,num_train_batches):
+	"""Assemble the keyword arguments passed to ModelTrainer."""
+	return dict(model=model,
+				n_epoch=training_opts['n_epoch'],
+				validation_dataloader=dataloaders['val'],
+				optimizer_opts=dict(name=training_opts['optimizer'],
+									lr=training_opts['lr'],
+									weight_decay=training_opts['wd']),
+				scheduler_opts=dict(scheduler=training_opts['scheduler_type'],
+									lr_scheduler_decay=0.5,
+									T_max=training_opts['T_max'],
+									eta_min=training_opts['eta_min'],
+									T_mult=training_opts['T_mult']),
+				loss_fn=training_opts['loss_fn'],
+				num_train_batches=num_train_batches,
+				seg_out_class=training_opts['seg_out_class'],
+				apex_opt_level=training_opts['apex_opt_level'],
+				checkpointing=training_opts['checkpointing'])
+
+def return_transformer(training_opts):
+	"""Load or create the train/val/test split dataframe, compute normalization statistics, and build the data transforms."""
+	dataset_df = pd.read_csv(training_opts['dataset_df']) if os.path.exists(training_opts['dataset_df']) else create_train_val_test(training_opts['train_val_test_splits'],training_opts['patch_info_file'],training_opts['patch_size'])
+
+	# options used only to fit the normalizer; segmentation-class targeting and oversampling are left at their non-training defaults here
+	dataset_opts=dict(dataset_df=dataset_df, set='pass', patch_info_file=training_opts['patch_info_file'], input_dir=training_opts['input_dir'], target_names=training_opts['target_names'], pos_annotation_class=training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=-1, target_threshold=training_opts['target_threshold'][0], oversampling_factor=1, n_segmentation_classes=training_opts['num_targets'], gdl=training_opts['loss_fn']=='gdl', mt_bce=training_opts['mt_bce'], classify_annotations=training_opts['classify_annotations'])
+
+	norm_dict = get_normalizer(training_opts['normalization_file'], dataset_opts)
+
+	transform_opts=dict(patch_size = training_opts['patch_resize'], mean=norm_dict['mean'], std=norm_dict['std'], resize=True, transform_platform=training_opts['transform_platform'] if not training_opts['segmentation'] else 'albumentations', user_transforms=training_opts['user_transforms'])
+
+	transformers = get_data_transforms(**transform_opts)
+	return dataset_df,dataset_opts,transform_opts,transformers
+
+def return_datasets(training_opts,dataset_df,transformers,transform_opts):
+	"""Construct the train/val/test DynamicImageDatasets, applying oversampling, supplementation, subsampling and annotation binarization as requested."""
+	datasets= {set: DynamicImageDataset(dataset_df, set, training_opts['patch_info_file'], transformers, training_opts['input_dir'], training_opts['target_names'], training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=training_opts['target_segmentation_class'][0] if set=='train' else -1, target_threshold=training_opts['target_threshold'][0], oversampling_factor=training_opts['oversampling_factor'][0] if set=='train' else 1, n_segmentation_classes=training_opts['num_targets'],gdl=training_opts['loss_fn']=='gdl',mt_bce=training_opts['mt_bce'], classify_annotations=training_opts['classify_annotations'],dilation_jitter=training_opts['dilation_jitter'] if set == 'train' else {}) for set in ['train','val','test']}
+	print(datasets['train'])
+
+	if len(training_opts['target_segmentation_class']) > 1:
+		for i in range(1,len(training_opts['target_segmentation_class'])):
+			datasets['train'].concat(DynamicImageDataset(dataset_df, 'train', training_opts['patch_info_file'], transformers, training_opts['input_dir'], training_opts['target_names'], training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=training_opts['target_segmentation_class'][i], target_threshold=training_opts['target_threshold'][i], oversampling_factor=training_opts['oversampling_factor'][i],n_segmentation_classes=training_opts['num_targets'],gdl=training_opts['loss_fn']=='gdl',mt_bce=training_opts['mt_bce'],classify_annotations=training_opts['classify_annotations'],dilation_jitter=training_opts['dilation_jitter']))
+		print(datasets['train'])
+
+	if training_opts['supplement']:
+		old_train_set = copy.deepcopy(datasets['train'])
+		datasets['train']=DynamicImageDataset(dataset_df, 'train', training_opts['patch_info_file'], transformers, training_opts['input_dir'], training_opts['target_names'], training_opts['pos_annotation_class'], segmentation=training_opts['segmentation'], patch_size=training_opts['patch_size'], fix_names=training_opts['fix_names'], other_annotations=training_opts['other_annotations'], target_segmentation_class=-1, target_threshold=training_opts['target_threshold'], oversampling_factor=1,n_segmentation_classes=training_opts['num_targets'],gdl=training_opts['loss_fn']=='gdl',mt_bce=training_opts['mt_bce'],classify_annotations=training_opts['classify_annotations'],dilation_jitter=training_opts['dilation_jitter'])
+		datasets['train'].concat(old_train_set)
+
+	if training_opts['subsample_p']<1.0:
+		datasets['train'].subsample(training_opts['subsample_p'])
+
+	if training_opts['subsample_p_val']==-1.:
+		training_opts['subsample_p_val']=training_opts['subsample_p']
+	if training_opts['subsample_p_val']<1.0:
+		datasets['val'].subsample(training_opts['subsample_p_val'])
+
+	if training_opts['classify_annotations']:
+		binarizer=datasets['train'].binarize_annotations(num_targets=training_opts['num_targets'],binary_threshold=training_opts['binary_threshold'])
+		datasets['val'].binarize_annotations(num_targets=training_opts['num_targets'],binary_threshold=training_opts['binary_threshold'])
+		datasets['test'].binarize_annotations(num_targets=training_opts['num_targets'],binary_threshold=training_opts['binary_threshold'])
+		training_opts['num_targets']=len(datasets['train'].targets)
+
+	for Set in ['train','val','test']: # print summed annotation/target columns for each split
+		print(datasets[Set].patch_info.iloc[:,6:].sum(axis=0))
+
+	if training_opts['prediction_set']!='test':
+		datasets['test']=datasets[training_opts['prediction_set']]
+
+	if training_opts['external_test_db'] and training_opts['external_test_dir']:
+		datasets['test'].update_dataset(input_dir=training_opts['external_test_dir'],new_db=training_opts['external_test_db'],prediction_basename=training_opts['prediction_basename'])
+	return datasets,training_opts,transform_opts
+
+#@pysnooper.snoop('train_model.log')
+def train_model_(training_opts):
+	"""Train the model, or run prediction / embedding extraction, depending on the supplied options.
+
+	Parameters
+	----------
+	training_opts : dict
+		Training options populated from command line.
+
+	"""
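+	# Note: training_opts is the merged dict assembled in train_model() below (CLI-derived command_opts plus
+	# the loss/normalization defaults); calling this function programmatically requires supplying all of those keys.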
+
+	cuda_available=torch.cuda.is_available()
+
+	model = return_model(training_opts)
+
+	dataset_df,dataset_opts,transform_opts,transformers=return_transformer(training_opts)
+
+	if training_opts['extract_embedding'] and training_opts['npy_file']:
+		dataset=NPYDataset(training_opts['patch_info_file'],training_opts['patch_size'],training_opts['npy_file'],transformers["test"])
+		dataset.embed(model,training_opts['batch_size'],training_opts['prediction_output_dir'])
+		exit()
+
+	datasets,training_opts,transform_opts=return_datasets(training_opts,dataset_df,transformers,transform_opts)
+
+	if training_opts['num_training_images_epoch']>0:
+		num_train_batches = min(training_opts['num_training_images_epoch'],len(datasets['train']))//training_opts['batch_size']
+	else:
+		num_train_batches = None
+
+	dataloaders={set: DataLoader(datasets[set], batch_size=training_opts['batch_size'], shuffle=(set=='train') if not (training_opts['imbalanced_correction'] and not training_opts['segmentation']) else False, num_workers=10, sampler=ImbalancedDatasetSampler(datasets[set]) if (training_opts['imbalanced_correction'] and set=='train' and not training_opts['segmentation']) else None) for set in ['train', 'val', 'test']}
+
+	print(dataloaders['train'].sampler) # FIXME VAL SEEMS TO BE MISSING DURING PREDICTION
+	print(dataloaders['val'].sampler)
+
+	if training_opts['run_test']: run_test(dataloaders['train'])
+
+	model_trainer_opts=return_trainer_opts(model,training_opts,dataloaders,num_train_batches)
+
+	if not training_opts['predict']:
+
+		trainer = ModelTrainer(**model_trainer_opts)
+
+		if training_opts['imbalanced_correction2']:
+			trainer.add_class_balance_loss(datasets['train'])
+		elif training_opts['custom_weights']:
+			trainer.add_class_balance_loss(datasets['train'],custom_weights=training_opts['custom_weights'])
+
+		if training_opts['adopt_training_loss']:
+			trainer.val_loss_fn = trainer.loss_fn
+
+		trainer.fit(dataloaders['train'], verbose=True, print_every=1, plot_training_curves=True, plot_save_file=training_opts['training_curve'], print_val_confusion=training_opts['print_val_confusion'], save_val_predictions=training_opts['save_val_predictions'])
+
+		torch.save(trainer.model.state_dict(),training_opts['save_location'])
+
+	else:
+
+		if training_opts['extract_model']:
+			dataset_opts.update(dict(target_segmentation_class=-1, target_threshold=training_opts['target_threshold'][0] if len(training_opts['target_threshold']) else 0., set='test', binary_threshold=training_opts['binary_threshold'], num_targets=training_opts['num_targets'], oversampling_factor=1))
+			torch.save(dict(model=model,dataset_opts=dataset_opts, transform_opts=transform_opts),'{}.{}'.format(training_opts['save_location'],'extracted_model.pkl'))
+			exit()
+
+		trainer = ModelTrainer(**model_trainer_opts)
+
+		if training_opts['segmentation']:
+			for ID, dataset in (datasets['test'].split_by_ID() if not training_opts['prediction_basename'] else datasets['test'].select_IDs(training_opts['prediction_basename'])):
+				dataloader = DataLoader(dataset, batch_size=training_opts['batch_size'], shuffle=False, num_workers=10)
+				if training_opts['run_test']:
+					for X,y in dataloader:
+						np.save('test_predictions.npy',model(X.cuda() if torch.cuda.is_available() else X).detach().cpu().numpy())
+						exit()
+				y_pred = trainer.predict(dataloader)
+				print(ID,y_pred.shape)
+				segmentation_predictions2npy(y_pred, dataset.patch_info, dataset.segmentation_maps[ID], npy_output='{}/{}_predict.npy'.format(training_opts['prediction_output_dir'],ID), original_patch_size=training_opts['patch_size'], resized_patch_size=training_opts['patch_resize'], output_probs=(training_opts['seg_out_class']>=0))
+		else:
+			extract_embedding=training_opts['extract_embedding']
+			if extract_embedding:
+				trainer.bce=False
+
+			y_pred = trainer.predict(dataloaders['test'])
+
+			patch_info = dataloaders['test'].dataset.patch_info
+
+			if extract_embedding:
+				patch_info['name']=patch_info.astype(str).apply(lambda x: '\n'.join(['{}:{}'.format(k,v) for k,v in x.to_dict().items()]),axis=1)
+				embeddings=pd.DataFrame(y_pred,index=patch_info['name'])
+				embeddings['ID']=patch_info['ID'].values
+				torch.save(dict(embeddings=embeddings,patch_info=patch_info),os.path.join(training_opts['prediction_output_dir'],'embeddings.pkl'))
+
+			else:
+				if len(y_pred.shape)>1 and y_pred.shape[1]>1:
+					annotations = ['{}_pred'.format(i) for i in range(y_pred.shape[1])]
+					for i in range(y_pred.shape[1]):
+						patch_info.loc[:,annotations[i]]=y_pred[:,i]
+				patch_info['y_pred']=y_pred if (training_opts['num_targets']==1 or not (training_opts['classify_annotations'] or training_opts['mt_bce'])) else y_pred.argmax(axis=1)
+
+				conn = sqlite3.connect(training_opts['prediction_save_path'])
+				patch_info.to_sql(str(training_opts['patch_size']),con=conn, if_exists='replace')
+				conn.close()
+
+@train.command()
+@click.option('-s', '--segmentation', is_flag=True, help='Segmentation task.', show_default=True)
+@click.option('-p', '--prediction', is_flag=True, help='Predict on model.', show_default=True)
+@click.option('-pa', '--pos_annotation_class', default='', help='Annotation Class from which to apply positive labels.', type=click.Path(exists=False), show_default=True)
+@click.option('-oa', '--other_annotations', default=[], multiple=True, help='Annotations in image.', type=click.Path(exists=False), show_default=True)
+@click.option('-o', '--save_location', default='', help='Model save location (use a .pkl extension).', type=click.Path(exists=False), show_default=True)
+@click.option('-pt', '--pretrained_save_location', default='', help='Location of a model (.pkl) pretrained by a previous analysis, to be fine-tuned.', type=click.Path(exists=False), show_default=True)
+@click.option('-i', '--input_dir', default='', help='Input directory containing the slides and associated preprocessed files.', type=click.Path(exists=False), show_default=True)
+@click.option('-ps', '--patch_size', default=224, help='Patch size.',  show_default=True)
+@click.option('-pr', '--patch_resize', default=224, help='Size patches are resized to.',  show_default=True)
+@click.option('-tg', '--target_names', default=[], multiple=True, help='Targets.', type=click.Path(exists=False), show_default=True)
+@click.option('-df', '--dataset_df', default='', help='CSV file with train/val/test and target info.', type=click.Path(exists=False), show_default=True)
+@click.option('-fn', '--fix_names', is_flag=True, help='Whether to fix names in dataset_df.', show_default=True)
+@click.option('-a', '--architecture', default='alexnet', help='Neural Network Architecture.', type=click.Choice(['alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
+											'inception_v3', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'vgg11', 'vgg11_bn','unet','unet2','nested_unet','fast_scnn',
+											'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'deeplabv3_resnet101','deeplabv3_resnet50','fcn_resnet101', 'fcn_resnet50',"sqnxt23_w3d2", "sqnxt23_w2", "sqnxt23v5_w1", "sqnxt23v5_w3d2", "sqnxt23v5_w2"]+['efficientnet-b{}'.format(i) for i in range(8)]), show_default=True)
+@click.option('-imb', '--imbalanced_correction', is_flag=True, help='Attempt to correct for imbalanced data.', show_default=True)
+@click.option('-imb2', '--imbalanced_correction2', is_flag=True, help='Attempt to correct for imbalanced data.', show_default=True)
+@click.option('-ca', '--classify_annotations', is_flag=True, help='Classify annotations.', show_default=True)
+@click.option('-nt', '--num_targets', default=1, help='Number of targets.', show_default=True)
+@click.option('-ss', '--subsample_p', default=1.0, help='Subsample training set.', show_default=True)
+@click.option('-ssv', '--subsample_p_val', default=-1., help='Subsample validation set; -1 defaults to the training subsample fraction.', show_default=True)
+@click.option('-t', '--num_training_images_epoch', default=-1, help='Number of training images per epoch. -1 means use all training images each epoch.', show_default=True)
+@click.option('-lr', '--learning_rate', default=1e-2, help='Learning rate.', show_default=True)
+@click.option('-tp', '--transform_platform', default='torch', help='Transform platform for nonsegmentation tasks.', type=click.Choice(['torch','albumentations']))
+@click.option('-ne', '--n_epoch', default=10, help='Number of epochs.', show_default=True)
+@click.option('-pi', '--patch_info_file', default='patch_info.db', help='Patch info file.', type=click.Path(exists=False), show_default=True)
+@click.option('-tc', '--target_segmentation_class', default=[-1], multiple=True, help='Segmentation Class to finetune on.',  show_default=True)
+@click.option('-tt', '--target_threshold', default=[0.], multiple=True, help='Threshold to include target for segmentation if saving one class.',  show_default=True)
+@click.option('-ov', '--oversampling_factor', default=[1.], multiple=True, help='How much to oversample training set.',  show_default=True)
+@click.option('-sup', '--supplement', is_flag=True, help='Use the thresholding to supplement the original training set.', show_default=True)
+@click.option('-bs', '--batch_size', default=10, help='Batch size.',  show_default=True)
+@click.option('-rt', '--run_test', is_flag=True, help='Output predictions for a batch to "test_predictions.npy". Use for debugging.',  show_default=True)
+@click.option('-mtb', '--mt_bce', is_flag=True, help='Run multi-target bce predictions on the annotations.',  show_default=True)
+@click.option('-po', '--prediction_output_dir', default='predictions', help='Where to output segmentation predictions and embeddings.', type=click.Path(exists=False), show_default=True)
+@click.option('-ee', '--extract_embedding', is_flag=True, help='Extract embeddings.',  show_default=True)
+@click.option('-em', '--extract_model', is_flag=True, help='Save entire torch model.',  show_default=True)
+@click.option('-bt', '--binary_threshold', default=0., help='If running binary classification on annotations, threshold at this value to dichotomize the selected annotation.',  show_default=True)
+@click.option('-prt', '--pretrain', is_flag=True, help='Pretrain on ImageNet.', show_default=True)
+@click.option('-olf', '--overwrite_loss_fn', default='', help='Overwrite the default training loss functions with loss of choice.', type=click.Choice(['','bce','mse','focal','dice','gdl','ce']), show_default=True)
+@click.option('-atl', '--adopt_training_loss', is_flag=True, help='Adopt training loss function for validation calculation.', show_default=True)
+@click.option('-tdb', '--external_test_db', default='', help='External database of samples to test on.', type=click.Path(exists=False), show_default=True)
+@click.option('-tdir', '--external_test_dir', default='', help='External directory of samples to test on.', type=click.Path(exists=False), show_default=True)
+@click.option('-pb', '--prediction_basename', default=[''], multiple=True, help='If supplied, predict on these basenames rather than the entire test set; currently only supported for segmentation tasks.',  show_default=True)
+@click.option('-cw', '--custom_weights', default='', help='Comma-delimited custom class weights, e.g. 1.,2.,0.5.', type=click.Path(exists=False),  show_default=True)
+@click.option('-pset', '--prediction_set', default='test', help='Dataset to predict on.', type=click.Choice(['train','val','test']), show_default=True)
+@click.option('-ut', '--user_transforms_file', default='', help='YAML file to add transforms from.', type=click.Path(exists=False), show_default=True)
+@click.option('-svp', '--save_val_predictions', is_flag=True, help='Whether to save the validation predictions.',  show_default=True)
+@click.option('-soc', '--seg_out_class', default=-1, help='Output a particular segmentation class probabilities.',  show_default=True)
+@click.option('-aol', '--apex_opt_level', default='O2', help='NVIDIA Apex AMP optimization level.', type=click.Choice(['O0','O1','O2','O3']), show_default=True)
+@click.option('-ckp', '--checkpointing', is_flag=True, help='Save intermediate models to ./checkpoints.',  show_default=True)
+@click.option('-npy', '--npy_file', default='', help='Specify one file to extract embeddings from. Embeddings are output into predictions directory', type=click.Path(exists=False), show_default=True)
+@click.option('-gpu', '--gpu_id', default=-1, help='GPU device ID; set to 0 or greater to select a specific GPU.',  show_default=True)
+def train_model(segmentation,prediction,pos_annotation_class,other_annotations,save_location,pretrained_save_location,input_dir,patch_size,patch_resize,target_names,dataset_df,fix_names, architecture, imbalanced_correction, imbalanced_correction2, classify_annotations, num_targets, subsample_p,subsample_p_val,num_training_images_epoch, learning_rate, transform_platform, n_epoch, patch_info_file, target_segmentation_class, target_threshold, oversampling_factor, supplement, batch_size, run_test, mt_bce, prediction_output_dir, extract_embedding, extract_model, binary_threshold, pretrain, overwrite_loss_fn, adopt_training_loss, external_test_db,external_test_dir, prediction_basename, custom_weights, prediction_set, user_transforms_file, save_val_predictions, seg_out_class, apex_opt_level, checkpointing, npy_file, gpu_id):
+	"""Train and predict using model for regression and classification tasks."""
+	# add separate pretrain ability on separating cell types, then transfer learn
+	# add pretrain and efficient net, pretraining remove last layer while loading state dict
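+	# Illustrative invocation (paths and values below are hypothetical), run through this module's click group:
+	#   python model_training.py train_model -s -i ./inputs -pi patch_info.db -ps 256 -bs 8 -ne 20 -o segmentation_model.pkl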
+	if gpu_id>=0:
+		torch.cuda.set_device(gpu_id)
+	target_segmentation_class=list(map(int,target_segmentation_class))
+	target_threshold=list(map(float,target_threshold))
+	oversampling_factor=[(int(x) if float(x)>=1 else float(x)) for x in oversampling_factor]
+	other_annotations=list(other_annotations)
+	prediction_basename=list(filter(None,prediction_basename))
+	command_opts = dict(segmentation=segmentation,
+						prediction=prediction,
+						pos_annotation_class=pos_annotation_class,
+						other_annotations=other_annotations,
+						save_location=save_location,
+						pretrained_save_location=pretrained_save_location,
+						input_dir=input_dir,
+						patch_size=patch_size,
+						target_names=target_names,
+						dataset_df=dataset_df,
+						fix_names=fix_names,
+						architecture=architecture,
+						patch_resize=patch_resize,
+						imbalanced_correction=imbalanced_correction,
+						imbalanced_correction2=imbalanced_correction2,
+						classify_annotations=classify_annotations,
+						num_targets=num_targets,
+						subsample_p=subsample_p,
+						num_training_images_epoch=num_training_images_epoch,
+						lr=learning_rate,
+						transform_platform=transform_platform,
+						n_epoch=n_epoch,
+						patch_info_file=patch_info_file,
+						target_segmentation_class=target_segmentation_class,
+						target_threshold=target_threshold,
+						oversampling_factor=oversampling_factor,
+						supplement=supplement,
+						predict=prediction,
+						batch_size=batch_size,
+						run_test=run_test,
+						mt_bce=mt_bce,
+						prediction_output_dir=prediction_output_dir,
+						extract_embedding=extract_embedding,
+						extract_model=extract_model,
+						binary_threshold=binary_threshold,
+						subsample_p_val=subsample_p_val,
+						wd=1e-3,
+						scheduler_type='warm_restarts',
+						T_max=10,
+						T_mult=2,
+						eta_min=5e-8,
+						optimizer='adam',
+						n_hidden=100,
+						pretrain=pretrain,
+						training_curve='training_curve.png',
+						adopt_training_loss=adopt_training_loss,
+						external_test_db=external_test_db,
+						external_test_dir=external_test_dir,
+						prediction_basename=prediction_basename,
+						save_val_predictions=save_val_predictions,
+						custom_weights=custom_weights,
+						prediction_set=prediction_set,
+						user_transforms=dict(),
+						dilation_jitter=dict(),
+						seg_out_class=seg_out_class,
+						apex_opt_level=apex_opt_level,
+						checkpointing=checkpointing,
+						npy_file=npy_file)
+
+	training_opts = dict(normalization_file="normalization_parameters.pkl",
+						 loss_fn='bce',
+						 print_val_confusion=True,
+						 prediction_save_path = 'predictions.db',
+						 train_val_test_splits='train_val_test.pkl'
+						 )
+	segmentation_training_opts = copy.deepcopy(training_opts)
+	segmentation_training_opts.update(dict(loss_fn='dice',#gdl dice+ce
+											normalization_file='normalization_segmentation.pkl',
+											fix_names=False,
+											))
+	if segmentation:
+		training_opts = segmentation_training_opts
+	for k in command_opts:
+		training_opts[k] = command_opts[k]
+	if classify_annotations:
+		if training_opts['num_targets']==1:
+			training_opts['loss_fn']='bce'
+		else:
+			training_opts['loss_fn']='ce'
+	if mt_bce:
+		training_opts['loss_fn']='bce'
+	if overwrite_loss_fn:
+		training_opts['loss_fn']=overwrite_loss_fn
+
+	if user_transforms_file and os.path.exists(user_transforms_file):
+		from yaml import load as yml_load
+		try:
+			from yaml import CLoader as Loader
+		except ImportError:
+			from yaml import Loader
+		with open(user_transforms_file) as f:
+			training_opts['user_transforms']=yml_load(f,Loader=Loader)
+			if 'dilationjitter' in training_opts['user_transforms']:
+				training_opts['dilation_jitter']=training_opts['user_transforms'].pop('dilationjitter')
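+		# Hypothetical layout of the user transforms YAML: a top-level 'dilationjitter' mapping is split out
+		# into dilation_jitter, and every remaining top-level key is forwarded to get_data_transforms as a
+		# user transform; the exact per-transform schema is defined by that function, not here.
+		#   dilationjitter: {...}
+		#   <transform_name>: {...}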
+
+	train_model_(training_opts)
+
+
+if __name__=='__main__':
+	train()