Diff of /main.py [000000] .. [139527]

import os
import fire
import torch
from pytorch_lightning import Trainer
import numpy as np
import pdb

from util import init_exp_folder, Args
from lightning import (get_task,
                       load_task,
                       get_ckpt_callback,
                       get_early_stop_callback,
                       get_logger)


def train(dataset_folder="./data_files",
          save_dir="./sandbox",
          exp_name="DemoExperiment",
          model="ResNet18",
          task='classification',
          gpus=1,
          pretrained=True,
          num_classes=1,
          accelerator=None,
          logger_type='test_tube',
          gradient_clip_val=0.5,
          max_epochs=1,
          patience=10,
          stochastic_weight_avg=True,
          limit_train_batches=1.0,
          tb_path="./sandbox/tb",
          loss_fn="BCE",
          weights_summary=None,
          augmentation='none',
          num_workers=0,
          auto_lr_find=True,
          lr=0.001,
          batch_size=2,
          # pretraining=False,
          # aux_task=None,
          ):
    """
    Run the training experiment.

    Args:
        dataset_folder: Path to the folder containing the data files
        save_dir: Path to save the checkpoints and logs
        exp_name: Name of the experiment
        model: Model name
        task: Task to train on, e.g. 'classification'
        gpus: int. (i.e. 2 gpus)
             OR list to specify which GPUs [0, 1] OR '0,1'
             OR '-1' / -1 to use all available gpus
        pretrained: Whether or not to use the pretrained model
        num_classes: Number of classes
        accelerator: Distributed computing mode
        logger_type: 'wandb' or 'test_tube'
        gradient_clip_val: Clip value of gradient norm
        limit_train_batches: Proportion of training data to use
        max_epochs: Max number of epochs
        patience: Number of epochs with no improvement after
                  which training will be stopped
        stochastic_weight_avg: Whether to use stochastic weight averaging
        tb_path: Path to global tb folder
        loss_fn: Loss function to use
        weights_summary: Prints a summary of the weights when training begins
        augmentation: Data augmentation strategy ('none' by default)
        num_workers: Number of dataloader worker processes
        auto_lr_find: Whether to run the learning-rate finder
        lr: Learning rate
        batch_size: Batch size

    Returns: None

    """
    args = Args(locals())  # Wraps locals() so every keyword argument above is accessible as an attribute, e.g. args.exp_name
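    # A minimal sketch of the assumed behavior (util.Args is not shown in this
    # diff): a dict subclass that exposes its keys as attributes, roughly
    #
    #     class Args(dict):
    #         def __getattr__(self, name):
    #             return self[name]
    #
    # so that args.exp_name is equivalent to args["exp_name"].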
    init_exp_folder(args)  # Set up the experiment directory
    task = get_task(args)  # Build the PyTorch Lightning module for the task (its constructor lives in e.g. segmentation.py)
    # Instantiate the trainer and start training
    trainer = Trainer(gpus=gpus,
                      accelerator=accelerator,
                      logger=get_logger(logger_type, save_dir, exp_name),  # Logging tool
                      callbacks=[get_early_stop_callback(patience),  # Stop after `patience` epochs without improvement
                                 get_ckpt_callback(save_dir, exp_name)],  # Save checkpoints, selected by a monitored metric
                      weights_save_path=os.path.join(save_dir, exp_name),
                      gradient_clip_val=gradient_clip_val,
                      limit_train_batches=limit_train_batches,  # When debugging, limit the fraction of training data used
                      weights_summary=weights_summary,
                      stochastic_weight_avg=stochastic_weight_avg,
                      max_epochs=max_epochs,
                      auto_lr_find=auto_lr_find,
                      reload_dataloaders_every_n_epochs=1,
                      log_every_n_steps=1)  # The Trainer object handles the training loop
    trainer.fit(task)
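
# Example command-line invocation of train() via python-fire (a sketch; the
# flag values are illustrative, and the flags map one-to-one onto the keyword
# arguments above):
#
#     python main.py train --exp_name=DemoExperiment --gpus=1 --max_epochs=10 --lr=0.001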


def test(ckpt_path,
         gpus=0,
         **kwargs):
    """
    Run the testing experiment.

    Args:
        ckpt_path: Path for the experiment to load
        gpus: int. (i.e. 2 gpus)
             OR list to specify which GPUs [0, 1] OR '0,1'
             OR '-1' / -1 to use all available gpus
    Returns: None

    """
    task = load_task(ckpt_path, **kwargs)
    trainer = Trainer(gpus=gpus)
    trainer.test(task)
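
# Example invocation (illustrative; the checkpoint path is hypothetical):
#
#     python main.py test --ckpt_path=./sandbox/DemoExperiment/ckpts/best.ckpt --gpus=0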


def predict(ckpt_path, gpus=1, prediction_path="predictions.pt", **kwargs):
    # There is currently no way to pass a specific dataset as an argument;
    # by default, this makes predictions over the test dataset. The prediction
    # dataset can be changed in the predict_dataloader() function in
    # segmentation.py (see the sketch after this function).
    task = load_task(ckpt_path, **kwargs)
    trainer = Trainer(gpus=gpus)
    trainer.predict(task)
    preds_tensor = task.evaluator.preds
    preds = preds_tensor.cpu().detach()
    torch.save(preds, prediction_path)
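
# Sketch of the kind of override mentioned above (assumptions: the task module
# in segmentation.py stores batch_size/num_workers in self.hparams, and
# PredictionDataset is a hypothetical torch Dataset; neither appears in this diff):
#
#     from torch.utils.data import DataLoader
#
#     def predict_dataloader(self):
#         # Swap in any dataset here to change what predict() runs over.
#         return DataLoader(PredictionDataset(split="test"),
#                           batch_size=self.hparams.batch_size,
#                           num_workers=self.hparams.num_workers)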


if __name__ == "__main__":
    fire.Fire()  # Lets you run the functions above and supply their arguments directly from the command line
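
# fire.Fire() with no arguments exposes every public object in this module, so
# train, test, and predict are all available as subcommands, e.g. (illustrative
# checkpoint path):
#
#     python main.py predict --ckpt_path=./sandbox/DemoExperiment/ckpts/best.ckpt --prediction_path=preds.pt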