|
a |
|
b/pathflowai/model_training.py |
|
|
1 |
import torch, os, numpy as np, pandas as pd |
|
|
2 |
from pathflowai.utils import * |
|
|
3 |
#from large_data_utils import * |
|
|
4 |
from pathflowai.datasets import * |
|
|
5 |
from pathflowai.models import * |
|
|
6 |
from pathflowai.schedulers import * |
|
|
7 |
from pathflowai.visualize import * |
|
|
8 |
import copy |
|
|
9 |
from pathflowai.sampler import ImbalancedDatasetSampler |
|
|
10 |
import argparse |
|
|
11 |
import sqlite3 |
|
|
12 |
#from nonechucks import SafeDataLoader as DataLoader |
|
|
13 |
from torch.utils.data import DataLoader |
|
|
14 |
import click |
|
|
15 |
import pysnooper |
|
|
16 |
|
|
|
17 |
# Context settings shared by the click group below: accept both -h and
# --help, and pass max_content_width=90 to click.
CONTEXT_SETTINGS = {
    'help_option_names': ['-h', '--help'],
    'max_content_width': 90,
}


# Root command group; subcommands register themselves with @train.command().
# This is documented with a comment rather than a docstring on purpose: click
# uses a command's docstring as its help text, so adding one would change the
# output of --help.
@click.group(context_settings=CONTEXT_SETTINGS)
@click.version_option(version='0.1')
def train():
    pass
|
|
23 |
|
|
|
24 |
def return_model(training_opts):
    """Build the network described by ``training_opts`` and load saved weights.

    Parameters
    ----------
    training_opts : dict
        Options read here: ``pretrain``, ``architecture``, ``num_targets``,
        ``n_hidden``, ``segmentation``, ``pretrained_save_location``,
        ``save_location``, ``predict`` and ``extract_embedding``.

    Returns
    -------
    model
        The object returned by ``generate_model``, after optional weight
        loading and head replacement, moved to the GPU when CUDA is available.

    Raises
    ------
    AssertionError
        If ``extract_embedding`` is set but differs from ``predict``.
    """
    model = generate_model(pretrain=training_opts['pretrain'],
                           architecture=training_opts['architecture'],
                           num_classes=training_opts['num_targets'],
                           add_sigmoid=False,
                           n_hidden=training_opts['n_hidden'],
                           segmentation=training_opts['segmentation'])

    # A checkpoint written from a GPU cannot be deserialized on a CPU-only
    # host unless its tensors are remapped, so force CPU only in that case.
    # With CUDA present torch.load keeps its default behaviour.
    map_location = None if torch.cuda.is_available() else 'cpu'

    if os.path.exists(training_opts['pretrained_save_location']) and not training_opts['predict']:
        # Training run starting from a previously trained model.
        model_dict = torch.load(training_opts['pretrained_save_location'], map_location=map_location)
        if not training_opts['segmentation']:
            # Overwrite the last two entries with those of the freshly built
            # model (presumably the final layer's weight and bias, so a new
            # number of targets can be used -- TODO confirm).
            model_dict.update(dict(list(model.state_dict().items())[-2:]))
        model.load_state_dict(model_dict)  # this will likely break after pretraining?
    elif os.path.exists(training_opts['save_location']) and training_opts['predict']:
        # Prediction run: load the weights saved by a previous training run.
        model_dict = torch.load(training_opts['save_location'], map_location=map_location)
        model.load_state_dict(model_dict)

    if training_opts['extract_embedding']:
        # Embedding extraction is only valid together with prediction.
        assert training_opts['extract_embedding'] == training_opts['predict']
        architecture = training_opts['architecture']
        # Replace the output head by its first element.
        # NOTE(review): assumes the head is an indexable container of layers
        # (e.g. nn.Sequential) -- confirm against generate_model.
        if hasattr(model, "fc"):
            model.fc = model.fc[0]
        elif hasattr(model, "output"):
            model.output = model.output[0]
        elif architecture.startswith(('alexnet', 'vgg', 'densenet')):
            model.classifier[6] = model.classifier[6][0]

    if torch.cuda.is_available():
        model.cuda()
    return model
|
|
48 |
|
|
|
49 |
def run_test(dataloader):
    """Dump the first six batches of ``dataloader`` to ``.npy`` files, then exit.

    Batch ``i`` is written to ``X_test_{i}.npy`` and ``y_test_{i}.npy`` in the
    current working directory. Intended as a debugging aid.

    Parameters
    ----------
    dataloader : iterable
        Yields ``(X, y)`` pairs whose elements support
        ``.detach().cpu().numpy()``.

    Raises
    ------
    SystemExit
        Once the batch with index 5 has been saved. If the loader yields fewer
        than six batches the function simply returns.
    """
    for i, (X, y) in enumerate(dataloader):
        np.save('X_test_{}.npy'.format(i), X.detach().cpu().numpy())
        np.save('y_test_{}.npy'.format(i), y.detach().cpu().numpy())
        if i == 5:
            # Raise SystemExit directly: the exit() helper is injected by the
            # site module for interactive use and is not guaranteed to exist.
            raise SystemExit
|
|
55 |
|
|
|
56 |
def return_trainer_opts(model,training_opts,dataloaders,num_train_batches):
    """Assemble the keyword arguments later unpacked into ``ModelTrainer``.

    Parameters
    ----------
    model
        Stored under ``model``.
    training_opts : dict
        Source of the optimizer, scheduler, loss and miscellaneous settings.
    dataloaders : dict
        Only the ``'val'`` entry is used.
    num_train_batches
        Stored unchanged under ``num_train_batches``.

    Returns
    -------
    dict
        Trainer options, with nested ``optimizer_opts`` and ``scheduler_opts``.
    """
    trainer_opts = {'model': model}
    trainer_opts['n_epoch'] = training_opts['n_epoch']
    trainer_opts['validation_dataloader'] = dataloaders['val']
    trainer_opts['optimizer_opts'] = {
        'name': training_opts['optimizer'],
        'lr': training_opts['lr'],
        'weight_decay': training_opts['wd'],
    }
    trainer_opts['scheduler_opts'] = {
        'scheduler': training_opts['scheduler_type'],
        'lr_scheduler_decay': 0.5,  # fixed value, not taken from training_opts
        'T_max': training_opts['T_max'],
        'eta_min': training_opts['eta_min'],
        'T_mult': training_opts['T_mult'],
    }
    trainer_opts['loss_fn'] = training_opts['loss_fn']
    trainer_opts['num_train_batches'] = num_train_batches
    # These options are forwarded under the same key they have in training_opts.
    for key in ('seg_out_class', 'apex_opt_level', 'checkpointing'):
        trainer_opts[key] = training_opts[key]
    return trainer_opts
|
|
73 |
|
|
|
74 |
def return_transformer(training_opts):
    """Load the dataset table and build the dataset options and transforms.

    Parameters
    ----------
    training_opts : dict
        Options read here include ``dataset_df``, ``train_val_test_splits``,
        ``patch_info_file``, ``patch_size``, ``patch_resize``,
        ``normalization_file``, ``transform_platform``, ``user_transforms``
        and the dataset settings copied into ``dataset_opts``.

    Returns
    -------
    tuple
        ``(dataset_df, dataset_opts, transformers)`` where ``transformers`` is
        the return value of ``get_data_transforms``.
    """
    # Use the CSV if it exists, otherwise derive the table from the split file.
    if os.path.exists(training_opts['dataset_df']):
        dataset_df = pd.read_csv(training_opts['dataset_df'])
    else:
        dataset_df = create_train_val_test(training_opts['train_val_test_splits'],
                                           training_opts['patch_info_file'],
                                           training_opts['patch_size'])

    # These options describe the 'pass' split, so the non-train defaults apply:
    # target_segmentation_class=-1 and oversampling_factor=1. The previous code
    # spelled these as "... if set=='train' else ...", but `set` there was the
    # builtin type (the dict key `set` is not a variable), so the comparison
    # was always False. The constants are now written out; values are the same.
    dataset_opts = dict(dataset_df=dataset_df,
                        set='pass',
                        patch_info_file=training_opts['patch_info_file'],
                        input_dir=training_opts['input_dir'],
                        target_names=training_opts['target_names'],
                        pos_annotation_class=training_opts['pos_annotation_class'],
                        segmentation=training_opts['segmentation'],
                        patch_size=training_opts['patch_size'],
                        fix_names=training_opts['fix_names'],
                        other_annotations=training_opts['other_annotations'],
                        target_segmentation_class=-1,
                        target_threshold=training_opts['target_threshold'][0],
                        oversampling_factor=1,
                        n_segmentation_classes=training_opts['num_targets'],
                        gdl=training_opts['loss_fn']=='gdl',
                        mt_bce=training_opts['mt_bce'],
                        classify_annotations=training_opts['classify_annotations'])

    norm_dict = get_normalizer(training_opts['normalization_file'], dataset_opts)

    # Segmentation always uses albumentations; otherwise the user's choice.
    if training_opts['segmentation']:
        transform_platform = 'albumentations'
    else:
        transform_platform = training_opts['transform_platform']

    transform_opts = dict(patch_size=training_opts['patch_resize'],
                          mean=norm_dict['mean'],
                          std=norm_dict['std'],
                          resize=True,
                          transform_platform=transform_platform,
                          user_transforms=training_opts['user_transforms'])

    transformers = get_data_transforms(**transform_opts)
    return dataset_df,dataset_opts,transformers
|
|
85 |
|
|
|
86 |
def return_datasets(training_opts,dataset_df,transformers,transform_opts=None):
    """Build the train/val/test datasets described by ``training_opts``.

    Parameters
    ----------
    training_opts : dict
        Training options. ``subsample_p_val`` and ``num_targets`` may be
        updated in place.
    dataset_df
        Dataset table handed to every ``DynamicImageDataset``.
    transformers
        Transforms handed to every ``DynamicImageDataset``.
    transform_opts : optional
        Returned unchanged as the third element of the result. This function
        used to return a ``transform_opts`` name that was never defined in its
        scope, so every call raised ``NameError``; callers that need the value
        back must now pass it in. Defaults to ``None``.

    Returns
    -------
    tuple
        ``(datasets, training_opts, transform_opts)`` where ``datasets`` maps
        ``'train'``, ``'val'`` and ``'test'`` to dataset objects.
    """
    def build_dataset(split, target_segmentation_class, target_threshold, oversampling_factor, dilation_jitter):
        # Single place for the constructor arguments shared by every dataset;
        # training_opts is read at call time.
        return DynamicImageDataset(dataset_df, split,
                                   training_opts['patch_info_file'],
                                   transformers,
                                   training_opts['input_dir'],
                                   training_opts['target_names'],
                                   training_opts['pos_annotation_class'],
                                   segmentation=training_opts['segmentation'],
                                   patch_size=training_opts['patch_size'],
                                   fix_names=training_opts['fix_names'],
                                   other_annotations=training_opts['other_annotations'],
                                   target_segmentation_class=target_segmentation_class,
                                   target_threshold=target_threshold,
                                   oversampling_factor=oversampling_factor,
                                   n_segmentation_classes=training_opts['num_targets'],
                                   gdl=training_opts['loss_fn']=='gdl',
                                   mt_bce=training_opts['mt_bce'],
                                   classify_annotations=training_opts['classify_annotations'],
                                   dilation_jitter=dilation_jitter)

    # Only the training split gets class targeting, oversampling and jitter.
    datasets = {}
    for split in ['train', 'val', 'test']:
        is_train = split == 'train'
        datasets[split] = build_dataset(
            split,
            training_opts['target_segmentation_class'][0] if is_train else -1,
            training_opts['target_threshold'][0],
            training_opts['oversampling_factor'][0] if is_train else 1,
            training_opts['dilation_jitter'] if is_train else {})
    print(datasets['train'])

    # Each additional target class contributes another training dataset that
    # is concatenated onto the first. Indexing (not zip) is kept so mismatched
    # list lengths still raise IndexError.
    n_target_classes = len(training_opts['target_segmentation_class'])
    if n_target_classes > 1:
        for i in range(1, n_target_classes):
            datasets['train'].concat(build_dataset(
                'train',
                training_opts['target_segmentation_class'][i],
                training_opts['target_threshold'][i],
                training_opts['oversampling_factor'][i],
                training_opts['dilation_jitter']))
        print(datasets['train'])

    if training_opts['supplement']:
        # Rebuild an untargeted training set and append the targeted one to it.
        old_train_set = copy.deepcopy(datasets['train'])
        # NOTE(review): the whole target_threshold list is passed here, not a
        # single element as elsewhere -- confirm this is intended.
        datasets['train'] = build_dataset('train', -1,
                                          training_opts['target_threshold'],
                                          1,
                                          training_opts['dilation_jitter'])
        datasets['train'].concat(old_train_set)

    if training_opts['subsample_p'] < 1.0:
        datasets['train'].subsample(training_opts['subsample_p'])

    if training_opts['subsample_p_val'] < 1.0:
        # -1 is the sentinel for "same fraction as the training set".
        if training_opts['subsample_p_val'] == -1.:
            training_opts['subsample_p_val'] = training_opts['subsample_p']
        if training_opts['subsample_p_val'] < 1.0:
            datasets['val'].subsample(training_opts['subsample_p_val'])

    if training_opts['classify_annotations']:
        for split in ['train', 'val', 'test']:
            datasets[split].binarize_annotations(num_targets=training_opts['num_targets'],
                                                 binary_threshold=training_opts['binary_threshold'])
        training_opts['num_targets'] = len(datasets['train'].targets)

    # Print per-column sums of patch_info from column 6 onward for each split.
    for split in ['train', 'val', 'test']:
        print(datasets[split].patch_info.iloc[:, 6:].sum(axis=0))

    # Predictions are always read from datasets['test']; alias another split
    # there when requested.
    if training_opts['prediction_set'] != 'test':
        datasets['test'] = datasets[training_opts['prediction_set']]

    if training_opts['external_test_db'] and training_opts['external_test_dir']:
        datasets['test'].update_dataset(input_dir=training_opts['external_test_dir'],
                                        new_db=training_opts['external_test_db'],
                                        prediction_basename=training_opts['prediction_basename'])
    return datasets,training_opts,transform_opts
|
|
128 |
|
|
|
129 |
#@pysnooper.snoop('train_model.log') |
|
|
130 |
def train_model_(training_opts):
    """Function to train, predict on model.

    Builds the model, transforms, datasets and data loaders, then either fits
    the model and saves its ``state_dict`` (``predict`` false) or runs
    prediction (``predict`` true). Prediction writes segmentation ``.npy``
    files, an ``embeddings.pkl`` file, or a table in the SQLite file named by
    ``prediction_save_path``, depending on the options.

    Parameters
    ----------
    training_opts : dict
        Training options populated from command line.

    Raises
    ------
    SystemExit
        After the single-file embedding run, after ``extract_model`` has saved
        the model, or after the ``run_test`` debugging dumps.

    """
    cuda_available = torch.cuda.is_available()

    model = return_model(training_opts)

    dataset_df, dataset_opts, transformers = return_transformer(training_opts)

    if training_opts['extract_embedding'] and training_opts['npy_file']:
        # Embed a single .npy file and stop; no datasets are needed.
        dataset = NPYDataset(training_opts['patch_info_file'], training_opts['patch_size'], training_opts['npy_file'], transformers["test"])
        dataset.embed(model, training_opts['batch_size'], training_opts['prediction_output_dir'])
        raise SystemExit

    # NOTE(review): transform_opts is only consumed by the extract_model branch
    # below; verify that return_datasets can actually supply it.
    datasets, training_opts, transform_opts = return_datasets(training_opts, dataset_df, transformers)

    if training_opts['num_training_images_epoch'] > 0:
        num_train_batches = min(training_opts['num_training_images_epoch'], len(datasets['train'])) // training_opts['batch_size']
    else:
        num_train_batches = None

    # The imbalanced sampler is used for non-segmentation training only; when
    # it is requested, shuffling is switched off for every loader.
    use_sampler = training_opts['imbalanced_correction'] and not training_opts['segmentation']
    dataloaders = {}
    for split in ['train', 'val', 'test']:
        is_train = split == 'train'
        dataloaders[split] = DataLoader(datasets[split],
                                        batch_size=training_opts['batch_size'],
                                        shuffle=is_train and not use_sampler,
                                        num_workers=10,
                                        sampler=ImbalancedDatasetSampler(datasets[split]) if (use_sampler and is_train) else None)

    print(dataloaders['train'].sampler)  # FIXME VAL SEEMS TO BE MISSING DURING PREDICTION
    print(dataloaders['val'].sampler)

    if training_opts['run_test']:
        run_test(dataloaders['train'])

    # Fixed: this used to pass the misspelled name `dataloders` (NameError).
    model_trainer_opts = return_trainer_opts(model, training_opts, dataloaders, num_train_batches)

    if not training_opts['predict']:

        trainer = ModelTrainer(**model_trainer_opts)

        if training_opts['imbalanced_correction2']:
            trainer.add_class_balance_loss(datasets['train'])
        elif training_opts['custom_weights']:
            trainer.add_class_balance_loss(datasets['train'], custom_weights=training_opts['custom_weights'])

        if training_opts['adopt_training_loss']:
            trainer.val_loss_fn = trainer.loss_fn

        trainer.fit(dataloaders['train'], verbose=True, print_every=1, plot_training_curves=True,
                    plot_save_file=training_opts['training_curve'],
                    print_val_confusion=training_opts['print_val_confusion'],
                    save_val_predictions=training_opts['save_val_predictions'])

        torch.save(trainer.model.state_dict(), training_opts['save_location'])
        return

    if training_opts['extract_model']:
        # Save the whole model object with the dataset and transform options.
        dataset_opts.update(dict(target_segmentation_class=-1,
                                 target_threshold=training_opts['target_threshold'][0] if len(training_opts['target_threshold']) else 0.,
                                 set='test',
                                 binary_threshold=training_opts['binary_threshold'],
                                 num_targets=training_opts['num_targets'],
                                 oversampling_factor=1))
        torch.save(dict(model=model, dataset_opts=dataset_opts, transform_opts=transform_opts),
                   '{}.{}'.format(training_opts['save_location'], 'extracted_model.pkl'))
        raise SystemExit

    trainer = ModelTrainer(**model_trainer_opts)

    if training_opts['segmentation']:
        # One prediction file per ID, optionally restricted to given basenames.
        if not training_opts['prediction_basename']:
            per_id_datasets = datasets['test'].split_by_ID()
        else:
            per_id_datasets = datasets['test'].select_IDs(training_opts['prediction_basename'])
        for ID, dataset in per_id_datasets:
            dataloader = DataLoader(dataset, batch_size=training_opts['batch_size'], shuffle=False, num_workers=10)
            if training_opts['run_test']:
                # Debug aid: save the raw model output of the first batch, stop.
                for X, y in dataloader:
                    np.save('test_predictions.npy', model(X.cuda() if cuda_available else X).detach().cpu().numpy())
                    raise SystemExit
            y_pred = trainer.predict(dataloader)
            print(ID, y_pred.shape)
            segmentation_predictions2npy(y_pred, dataset.patch_info, dataset.segmentation_maps[ID],
                                         npy_output='{}/{}_predict.npy'.format(training_opts['prediction_output_dir'], ID),
                                         original_patch_size=training_opts['patch_size'],
                                         resized_patch_size=training_opts['patch_resize'],
                                         output_probs=(training_opts['seg_out_class'] >= 0))
        return

    extract_embedding = training_opts['extract_embedding']
    if extract_embedding:
        trainer.bce = False

    y_pred = trainer.predict(dataloaders['test'])

    patch_info = dataloaders['test'].dataset.patch_info

    if extract_embedding:
        # Label each row with a "column:value" line per patch_info column.
        patch_info['name'] = patch_info.astype(str).apply(lambda x: '\n'.join(['{}:{}'.format(k, v) for k, v in x.to_dict().items()]), axis=1)
        embeddings = pd.DataFrame(y_pred, index=patch_info['name'])
        embeddings['ID'] = patch_info['ID'].values
        # os.path.join is spelled out instead of relying on a bare `join`
        # arriving through one of the star imports.
        torch.save(dict(embeddings=embeddings, patch_info=patch_info),
                   os.path.join(training_opts['prediction_output_dir'], 'embeddings.pkl'))
        return

    if len(y_pred.shape) > 1 and y_pred.shape[1] > 1:
        # One "<column index>_pred" column per output column.
        for i in range(y_pred.shape[1]):
            patch_info.loc[:, '{}_pred'.format(i)] = y_pred[:, i]

    keep_raw = training_opts['num_targets'] == 1 or not (training_opts['classify_annotations'] or training_opts['mt_bce'])
    patch_info['y_pred'] = y_pred if keep_raw else y_pred.argmax(axis=1)

    conn = sqlite3.connect(training_opts['prediction_save_path'])
    try:
        # The table is named after the patch size and replaced on each run.
        patch_info.to_sql(str(training_opts['patch_size']), con=conn, if_exists=('replace'))
    finally:
        # Close the connection even when the write fails.
        conn.close()
|
|
227 |
|
|
|
228 |
# Command-line entry point: gathers the click options into one options dict
# and hands it to train_model_.
# The function docstring is deliberately left as a single line because click
# shows it as the command's help text.
@train.command()
@click.option('-s', '--segmentation', is_flag=True, help='Segmentation task.', show_default=True)
@click.option('-p', '--prediction', is_flag=True, help='Predict on model.', show_default=True)
@click.option('-pa', '--pos_annotation_class', default='', help='Annotation Class from which to apply positive labels.', type=click.Path(exists=False), show_default=True)
@click.option('-oa', '--other_annotations', default=[], multiple=True, help='Annotations in image.', type=click.Path(exists=False), show_default=True)
@click.option('-o', '--save_location', default='', help='Model Save Location, append with pickle .pkl.', type=click.Path(exists=False), show_default=True)
@click.option('-pt', '--pretrained_save_location', default='', help='Model Save Location, append with pickle .pkl, pretrained by previous analysis to be finetuned.', type=click.Path(exists=False), show_default=True)
@click.option('-i', '--input_dir', default='', help='Input directory containing slides and everything.', type=click.Path(exists=False), show_default=True)
@click.option('-ps', '--patch_size', default=224, help='Patch size.', show_default=True)
@click.option('-pr', '--patch_resize', default=224, help='Patch resized.', show_default=True)
@click.option('-tg', '--target_names', default=[], multiple=True, help='Targets.', type=click.Path(exists=False), show_default=True)
@click.option('-df', '--dataset_df', default='', help='CSV file with train/val/test and target info.', type=click.Path(exists=False), show_default=True)
@click.option('-fn', '--fix_names', is_flag=True, help='Whether to fix names in dataset_df.', show_default=True)
@click.option('-a', '--architecture', default='alexnet', help='Neural Network Architecture.', type=click.Choice(['alexnet', 'densenet121', 'densenet161', 'densenet169', 'densenet201',
                                            'inception_v3', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'vgg11', 'vgg11_bn','unet','unet2','nested_unet','fast_scnn',
                                            'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'deeplabv3_resnet101','deeplabv3_resnet50','fcn_resnet101', 'fcn_resnet50',"sqnxt23_w3d2", "sqnxt23_w2", "sqnxt23v5_w1", "sqnxt23v5_w3d2", "sqnxt23v5_w2"]+['efficientnet-b{}'.format(i) for i in range(8)]), show_default=True)
@click.option('-imb', '--imbalanced_correction', is_flag=True, help='Attempt to correct for imbalanced data.', show_default=True)
@click.option('-imb2', '--imbalanced_correction2', is_flag=True, help='Attempt to correct for imbalanced data.', show_default=True)
@click.option('-ca', '--classify_annotations', is_flag=True, help='Classify annotations.', show_default=True)
@click.option('-nt', '--num_targets', default=1, help='Number of targets.', show_default=True)
@click.option('-ss', '--subsample_p', default=1.0, help='Subsample training set.', show_default=True)
@click.option('-ssv', '--subsample_p_val', default=-1., help='Subsample val set. If not set, defaults to that of training set', show_default=True)
@click.option('-t', '--num_training_images_epoch', default=-1, help='Number of training images per epoch. -1 means use all training images each epoch.', show_default=True)
@click.option('-lr', '--learning_rate', default=1e-2, help='Learning rate.', show_default=True)
@click.option('-tp', '--transform_platform', default='torch', help='Transform platform for nonsegmentation tasks.', type=click.Choice(['torch','albumentations']))
@click.option('-ne', '--n_epoch', default=10, help='Number of epochs.', show_default=True)
@click.option('-pi', '--patch_info_file', default='patch_info.db', help='Patch info file.', type=click.Path(exists=False), show_default=True)
@click.option('-tc', '--target_segmentation_class', default=[-1], multiple=True, help='Segmentation Class to finetune on.', show_default=True)
@click.option('-tt', '--target_threshold', default=[0.], multiple=True, help='Threshold to include target for segmentation if saving one class.', show_default=True)
@click.option('-ov', '--oversampling_factor', default=[1.], multiple=True, help='How much to oversample training set.', show_default=True)
@click.option('-sup', '--supplement', is_flag=True, help='Use the thresholding to supplement the original training set.', show_default=True)
@click.option('-bs', '--batch_size', default=10, help='Batch size.', show_default=True)
@click.option('-rt', '--run_test', is_flag=True, help='Output predictions for a batch to "test_predictions.npy". Use for debugging.', show_default=True)
@click.option('-mtb', '--mt_bce', is_flag=True, help='Run multi-target bce predictions on the annotations.', show_default=True)
@click.option('-po', '--prediction_output_dir', default='predictions', help='Where to output segmentation predictions.', type=click.Path(exists=False), show_default=True)
@click.option('-ee', '--extract_embedding', is_flag=True, help='Extract embeddings.', show_default=True)
@click.option('-em', '--extract_model', is_flag=True, help='Save entire torch model.', show_default=True)
@click.option('-bt', '--binary_threshold', default=0., help='If running binary classification on annotations, dichotomize selected annotation as such.', show_default=True)
@click.option('-prt', '--pretrain', is_flag=True, help='Pretrain on ImageNet.', show_default=True)
@click.option('-olf', '--overwrite_loss_fn', default='', help='Overwrite the default training loss functions with loss of choice.', type=click.Choice(['','bce','mse','focal','dice','gdl','ce']), show_default=True)
@click.option('-atl', '--adopt_training_loss', is_flag=True, help='Adopt training loss function for validation calculation.', show_default=True)
@click.option('-tdb', '--external_test_db', default='', help='External database of samples to test on.', type=click.Path(exists=False), show_default=True)
@click.option('-tdir', '--external_test_dir', default='', help='External directory of samples to test on.', type=click.Path(exists=False), show_default=True)
@click.option('-pb', '--prediction_basename', default=[''], multiple=True, help='For segmentation tasks, if supplied, can predict on these basenames rather than the entire test set. Only works for segmentation tasks for now', show_default=True)
@click.option('-cw', '--custom_weights', default='', help='Comma delimited custom weights', type=click.Path(exists=False), show_default=True)
@click.option('-pset', '--prediction_set', default='test', help='Dataset to predict on.', type=click.Choice(['train','val','test']), show_default=True)
@click.option('-ut', '--user_transforms_file', default='', help='YAML file to add transforms from.', type=click.Path(exists=False), show_default=True)
@click.option('-svp', '--save_val_predictions', is_flag=True, help='Whether to save the validation predictions.', show_default=True)
@click.option('-soc', '--seg_out_class', default=-1, help='Output a particular segmentation class probabilities.', show_default=True)
@click.option('-aol', '--apex_opt_level', default='O2', help='Apex optimization level.', type=click.Choice(['O0','O1','O2','O3']), show_default=True)
@click.option('-ckp', '--checkpointing', is_flag=True, help='Save intermediate models to ./checkpoints.', show_default=True)
@click.option('-npy', '--npy_file', default='', help='Specify one file to extract embeddings from. Embeddings are output into predictions directory', type=click.Path(exists=False), show_default=True)
@click.option('-gpu', '--gpu_id', default=-1, help='Set GPU if 0 and greater.', show_default=True)
def train_model(segmentation,prediction,pos_annotation_class,other_annotations,save_location,pretrained_save_location,input_dir,patch_size,patch_resize,target_names,dataset_df,fix_names, architecture, imbalanced_correction, imbalanced_correction2, classify_annotations, num_targets, subsample_p,subsample_p_val,num_training_images_epoch, learning_rate, transform_platform, n_epoch, patch_info_file, target_segmentation_class, target_threshold, oversampling_factor, supplement, batch_size, run_test, mt_bce, prediction_output_dir, extract_embedding, extract_model, binary_threshold, pretrain, overwrite_loss_fn, adopt_training_loss, external_test_db,external_test_dir, prediction_basename, custom_weights, prediction_set, user_transforms_file, save_val_predictions, seg_out_class, apex_opt_level, checkpointing, npy_file, gpu_id):
    """Train and predict using model for regression and classification tasks."""
    # add separate pretrain ability on separating cell types, then transfer learn
    # add pretrain and efficient net, pretraining remove last layer while loading state dict
    if gpu_id >= 0:
        torch.cuda.set_device(gpu_id)

    # click delivers multiple=True options as tuples; convert them to lists
    # of the expected element type.
    target_segmentation_class = list(map(int, target_segmentation_class))
    target_threshold = list(map(float, target_threshold))
    # Factors >= 1 become ints, fractional factors stay floats.
    oversampling_factor = [(int(x) if float(x) >= 1 else float(x)) for x in oversampling_factor]
    other_annotations = list(other_annotations)
    # Drop empty basenames (the option's default is a single empty string).
    prediction_basename = list(filter(None, prediction_basename))

    # Values taken from the command line, plus settings fixed in this command.
    command_opts = dict(segmentation=segmentation,
                        prediction=prediction,
                        pos_annotation_class=pos_annotation_class,
                        other_annotations=other_annotations,
                        save_location=save_location,
                        pretrained_save_location=pretrained_save_location,
                        input_dir=input_dir,
                        patch_size=patch_size,
                        target_names=target_names,
                        dataset_df=dataset_df,
                        fix_names=fix_names,
                        architecture=architecture,
                        patch_resize=patch_resize,
                        imbalanced_correction=imbalanced_correction,
                        imbalanced_correction2=imbalanced_correction2,
                        classify_annotations=classify_annotations,
                        num_targets=num_targets,
                        subsample_p=subsample_p,
                        num_training_images_epoch=num_training_images_epoch,
                        lr=learning_rate,
                        transform_platform=transform_platform,
                        n_epoch=n_epoch,
                        patch_info_file=patch_info_file,
                        target_segmentation_class=target_segmentation_class,
                        target_threshold=target_threshold,
                        oversampling_factor=oversampling_factor,
                        supplement=supplement,
                        predict=prediction,
                        batch_size=batch_size,
                        run_test=run_test,
                        mt_bce=mt_bce,
                        prediction_output_dir=prediction_output_dir,
                        extract_embedding=extract_embedding,
                        extract_model=extract_model,
                        binary_threshold=binary_threshold,
                        subsample_p_val=subsample_p_val,
                        wd=1e-3,
                        scheduler_type='warm_restarts',
                        T_max=10,
                        T_mult=2,
                        eta_min=5e-8,
                        optimizer='adam',
                        n_hidden=100,
                        pretrain=pretrain,
                        training_curve='training_curve.png',
                        adopt_training_loss=adopt_training_loss,
                        external_test_db=external_test_db,
                        external_test_dir=external_test_dir,
                        prediction_basename=prediction_basename,
                        save_val_predictions=save_val_predictions,
                        custom_weights=custom_weights,
                        prediction_set=prediction_set,
                        user_transforms=dict(),
                        dilation_jitter=dict(),
                        seg_out_class=seg_out_class,
                        apex_opt_level=apex_opt_level,
                        checkpointing=checkpointing,
                        npy_file=npy_file)

    # Task defaults; segmentation swaps the loss and the normalization file.
    training_opts = dict(normalization_file="normalization_parameters.pkl",
                         loss_fn='bce',
                         print_val_confusion=True,
                         prediction_save_path='predictions.db',
                         train_val_test_splits='train_val_test.pkl')
    if segmentation:
        training_opts.update(dict(loss_fn='dice',  # gdl dice+ce
                                  normalization_file='normalization_segmentation.pkl',
                                  fix_names=False))
    # Command-line values take precedence over the task defaults.
    training_opts.update(command_opts)

    # Loss selection; later rules override earlier ones.
    if classify_annotations:
        training_opts['loss_fn'] = 'bce' if training_opts['num_targets'] == 1 else 'ce'
    if mt_bce:
        training_opts['loss_fn'] = 'bce'
    if overwrite_loss_fn:
        training_opts['loss_fn'] = overwrite_loss_fn

    if user_transforms_file and os.path.exists(user_transforms_file):
        from yaml import load as yml_load
        try:
            from yaml import CLoader as Loader
        except ImportError:
            from yaml import Loader
        # NOTE(review): the full (C)Loader can construct arbitrary Python
        # objects; only point this at trusted files, or switch to SafeLoader
        # if the transform files never need custom tags.
        with open(user_transforms_file) as f:
            # An empty YAML document loads as None; treat it as "no transforms"
            # instead of failing on None.keys().
            training_opts['user_transforms'] = yml_load(f, Loader=Loader) or {}
        if 'dilationjitter' in training_opts['user_transforms']:
            training_opts['dilation_jitter'] = training_opts['user_transforms'].pop('dilationjitter')

    train_model_(training_opts)
|
|
388 |
|
|
|
389 |
|
|
|
390 |
# Script entry point: dispatch to the click command group when the module is
# executed directly rather than imported.
if __name__ == "__main__":
    train()