main.py

import os

import fire
import torch
from pytorch_lightning import Trainer

from util import init_exp_folder, Args
from lightning import (get_task,
                       load_task,
                       get_ckpt_callback,
                       get_early_stop_callback,
                       get_logger)


def train(dataset_folder="./data_files",
          save_dir="./sandbox",
          exp_name="DemoExperiment",
          model="ResNet18",
          task='classification',
          gpus=1,
          pretrained=True,
          num_classes=1,
          accelerator=None,
          logger_type='test_tube',
          gradient_clip_val=0.5,
          max_epochs=1,
          patience=10,
          stochastic_weight_avg=True,
          limit_train_batches=1.0,
          tb_path="./sandbox/tb",
          loss_fn="BCE",
          weights_summary=None,
          augmentation='none',
          num_workers=0,
          auto_lr_find=True,
          lr=0.001,
          batch_size=2,
          # pretraining=False,
          # aux_task=None,
          ):
"""
Run the training experiment.
Args:
save_dir: Path to save the checkpoints and logs
exp_name: Name of the experiment
model: Model name
gpus: int. (ie: 2 gpus)
OR list to specify which GPUs [0, 1] OR '0,1'
OR '-1' / -1 to use all available gpus
pretrained: Whether or not to use the pretrained model
num_classes: Number of classes
accelerator: Distributed computing mode
logger_type: 'wandb' or 'test_tube'
gradient_clip_val: Clip value of gradient norm
limit_train_batches: Proportion of training data to use
max_epochs: Max number of epochs
patience: number of epochs with no improvement after
which training will be stopped.
stochastic_weight_avg: Whether to use stochastic weight averaging.
tb_path: Path to global tb folder
loss_fn: Loss function to use
weights_summary: Prints a summary of the weights when training begins.
Returns: None
"""
    args = Args(locals())  # Wraps locals() so entries are accessible as attributes,
                           # e.g. args.exp_name (see the sketch after this function)
    init_exp_folder(args)  # Sets up the experiment directory
    task = get_task(args)  # Builds the PyTorch Lightning module for the chosen
                           # task; for segmentation this calls the constructor
                           # in segmentation.py
    # Then instantiate the trainer and start training.
    trainer = Trainer(gpus=gpus,
                      accelerator=accelerator,
                      logger=get_logger(logger_type, save_dir, exp_name),  # Logging tool
                      callbacks=[get_early_stop_callback(patience),  # Stop after `patience` epochs without improvement
                                 get_ckpt_callback(save_dir, exp_name)],  # Checkpoint to the experiment folder, keyed on a monitored metric
                      weights_save_path=os.path.join(save_dir, exp_name),
                      gradient_clip_val=gradient_clip_val,
                      limit_train_batches=limit_train_batches,  # When debugging, limits the fraction of training data used
                      weights_summary=weights_summary,
                      stochastic_weight_avg=stochastic_weight_avg,
                      max_epochs=max_epochs,
                      auto_lr_find=auto_lr_find,
                      reload_dataloaders_every_n_epochs=1,
                      log_every_n_steps=1)
    trainer.fit(task)
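
# `Args` (defined in util.py) exposes dict entries as attributes. A minimal
# sketch of such a wrapper, assuming the real implementation may do more:
#
#     class Args(dict):
#         def __getattr__(self, name):
#             return self[name]
#
# Example training run through the fire CLI at the bottom of this file
# (flag values here are illustrative):
#   python main.py train --exp_name=DemoExperiment --gpus=1 --max_epochs=5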


def test(ckpt_path,
         gpus=0,
         **kwargs):
    """
    Run the testing experiment.

    Args:
        ckpt_path: Path for the experiment to load
        gpus: int. (ie: 2 gpus)
            OR list to specify which GPUs [0, 1] OR '0,1'
            OR '-1' / -1 to use all available gpus
    Returns: None
    """
    task = load_task(ckpt_path, **kwargs)
    trainer = Trainer(gpus=gpus)
    trainer.test(task)
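
# Example test run (a sketch; the checkpoint path is hypothetical and depends
# on what get_ckpt_callback saved during training):
#   python main.py test --ckpt_path=./sandbox/DemoExperiment/best.ckpt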


def predict(ckpt_path, gpus=1, prediction_path="predictions.pt", **kwargs):
    # There is currently no way to pass a specific dataset as an argument;
    # by default this makes predictions over the test dataset. To change the
    # prediction dataset, edit predict_dataloader() in segmentation.py.
    task = load_task(ckpt_path, **kwargs)
    trainer = Trainer(gpus=gpus)
    trainer.predict(task)
    preds_tensor = task.evaluator.preds
    preds = preds_tensor.cpu().detach()
    torch.save(preds, prediction_path)  # torch.save expects (obj, path)
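
# Example prediction run and reload (a sketch; the checkpoint path is
# hypothetical):
#   python main.py predict --ckpt_path=./sandbox/DemoExperiment/best.ckpt
#   >>> import torch
#   >>> preds = torch.load("predictions.pt")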


if __name__ == "__main__":
    fire.Fire()  # Lets you run the functions above with arguments straight from the command line
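
# fire exposes each top-level function as a subcommand, so
#   python main.py train --help
# should print train's arguments (drawn from its docstring), and running
#   python main.py
# with no command should list the available functions.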