|
a |
|
b/main.py |
|
|
1 |
import os |
|
|
2 |
import fire |
|
|
3 |
from pytorch_lightning import Trainer |
|
|
4 |
import numpy as np |
|
|
5 |
import pdb |
|
|
6 |
|
|
|
7 |
from util import init_exp_folder, Args |
|
|
8 |
from lightning import (get_task, |
|
|
9 |
load_task, |
|
|
10 |
get_ckpt_callback, |
|
|
11 |
get_early_stop_callback, |
|
|
12 |
get_logger) |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def train(dataset_folder="./data_files", |
|
|
16 |
save_dir="./sandbox", |
|
|
17 |
exp_name="DemoExperiment", |
|
|
18 |
model="ResNet18", |
|
|
19 |
task='classification', |
|
|
20 |
gpus=1, |
|
|
21 |
pretrained=True, |
|
|
22 |
num_classes=1, |
|
|
23 |
accelerator=None, |
|
|
24 |
logger_type='test_tube', |
|
|
25 |
gradient_clip_val=0.5, |
|
|
26 |
max_epochs=1, |
|
|
27 |
patience=10, |
|
|
28 |
stochastic_weight_avg=True, |
|
|
29 |
limit_train_batches=1.0, |
|
|
30 |
tb_path="./sandbox/tb", |
|
|
31 |
loss_fn="BCE", |
|
|
32 |
weights_summary=None, |
|
|
33 |
augmentation = 'none', |
|
|
34 |
num_workers=0, |
|
|
35 |
auto_lr_find= True, |
|
|
36 |
lr = 0.001, |
|
|
37 |
batch_size = 2, |
|
|
38 |
# pretraining = False, |
|
|
39 |
# aux_task = None |
|
|
40 |
): |
|
|
41 |
""" |
|
|
42 |
Run the training experiment. |
|
|
43 |
|
|
|
44 |
Args: |
|
|
45 |
save_dir: Path to save the checkpoints and logs |
|
|
46 |
exp_name: Name of the experiment |
|
|
47 |
model: Model name |
|
|
48 |
gpus: int. (ie: 2 gpus) |
|
|
49 |
OR list to specify which GPUs [0, 1] OR '0,1' |
|
|
50 |
OR '-1' / -1 to use all available gpus |
|
|
51 |
pretrained: Whether or not to use the pretrained model |
|
|
52 |
num_classes: Number of classes |
|
|
53 |
accelerator: Distributed computing mode |
|
|
54 |
logger_type: 'wandb' or 'test_tube' |
|
|
55 |
gradient_clip_val: Clip value of gradient norm |
|
|
56 |
limit_train_batches: Proportion of training data to use |
|
|
57 |
max_epochs: Max number of epochs |
|
|
58 |
patience: number of epochs with no improvement after |
|
|
59 |
which training will be stopped. |
|
|
60 |
stochastic_weight_avg: Whether to use stochastic weight averaging. |
|
|
61 |
tb_path: Path to global tb folder |
|
|
62 |
loss_fn: Loss function to use |
|
|
63 |
weights_summary: Prints a summary of the weights when training begins. |
|
|
64 |
|
|
|
65 |
Returns: None |
|
|
66 |
|
|
|
67 |
""" |
|
|
68 |
args = Args(locals()) #Allows you to access stuff in the dictionary as args.exp_name |
|
|
69 |
init_exp_folder(args) #Sets up experiment directory |
|
|
70 |
task = get_task(args) #Have to define this pytorch lightning module, for implementation, where constructor in segmentation.py is called |
|
|
71 |
#Then you instantiate trainer and start training |
|
|
72 |
trainer = Trainer(gpus=gpus, |
|
|
73 |
accelerator=accelerator, |
|
|
74 |
logger=get_logger(logger_type, save_dir, exp_name), #Logging tool |
|
|
75 |
callbacks=[get_early_stop_callback(patience), #Set number of epochs without improvement before stopping |
|
|
76 |
get_ckpt_callback(save_dir, exp_name)], #Save model checkpoints to folder, defined by certain metrics |
|
|
77 |
weights_save_path=os.path.join(save_dir, exp_name), |
|
|
78 |
gradient_clip_val=None, |
|
|
79 |
limit_train_batches=limit_train_batches, #When debugging, limit number of batches to run (percentage of data) |
|
|
80 |
weights_summary=weights_summary, |
|
|
81 |
stochastic_weight_avg=stochastic_weight_avg, |
|
|
82 |
max_epochs=max_epochs, |
|
|
83 |
auto_lr_find=True, |
|
|
84 |
reload_dataloaders_every_n_epochs=1, |
|
|
85 |
log_every_n_steps=1) #Handles functionality of training |
|
|
86 |
trainer.fit(task) |
|
|
87 |
|
|
|
88 |
|
|
|
89 |
def test(ckpt_path, |
|
|
90 |
gpus=0, |
|
|
91 |
**kwargs): |
|
|
92 |
""" |
|
|
93 |
Run the testing experiment. |
|
|
94 |
|
|
|
95 |
Args: |
|
|
96 |
ckpt_path: Path for the experiment to load |
|
|
97 |
gpus: int. (ie: 2 gpus) |
|
|
98 |
OR list to specify which GPUs [0, 1] OR '0,1' |
|
|
99 |
OR '-1' / -1 to use all available gpus |
|
|
100 |
Returns: None |
|
|
101 |
|
|
|
102 |
""" |
|
|
103 |
task = load_task(ckpt_path, **kwargs) |
|
|
104 |
trainer = Trainer(gpus=gpus) |
|
|
105 |
trainer.test(task) |
|
|
106 |
|
|
|
107 |
def predict(ckpt_path, gpus=1, prediction_path="predictions.pt", **kwargs): |
|
|
108 |
# couldn't figure out how to pass in a specific dataset as an argument |
|
|
109 |
# by default, this makes predictions over the test dataset |
|
|
110 |
# can change the prediction dataset in predict_dataloader() function in segmentation.py |
|
|
111 |
task = load_task(ckpt_path, **kwargs) |
|
|
112 |
trainer = Trainer(gpus=gpus) |
|
|
113 |
trainer.predict(task) |
|
|
114 |
preds_tensor = task.evaluator.preds |
|
|
115 |
preds = preds_tensor.cpu().detach() |
|
|
116 |
torch.save(prediction_path, predictions) |
|
|
117 |
|
|
|
118 |
if __name__ == "__main__": |
|
|
119 |
fire.Fire() #Allows you to run functions and supply arguments directly in command line |