Diff of /lightning/detection.py [000000] .. [139527]


--- a/lightning/detection.py
+++ b/lightning/detection.py
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from models import get_model
from eval import DetectionEvaluator
from data import ImageDetectionDemoDataset
from .logger import TFLogger


class DetectionTask(pl.LightningModule, TFLogger):
    """Standard interface for the trainer to interact with the model."""

    def __init__(self, params):
        super().__init__()  # Initialize the parent classes (pl.LightningModule, TFLogger)
        self.save_hyperparameters(params)  # Save hyperparameters to the experiment directory (Lightning built-in)
        self.model = get_model(params)  # Instantiate the model defined in the models package
        self.evaluator = DetectionEvaluator()

    def training_step(self, batch, batch_nb):
        # A batch from the train dataloader is passed here.
        losses = self.model.forward(batch)
        loss = torch.stack(list(losses.values())).mean()
        return loss

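    # Note: self.model.forward(batch) is assumed to return a dict of scalar loss
    # tensors (e.g. {"loss_cls": ..., "loss_box": ...}; the keys are illustrative,
    # not taken from this repo). training_step and validation_step average them
    # into a single scalar with torch.stack(list(losses.values())).mean().
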
    def validation_step(self, batch, batch_nb):
        # Called once for every validation batch.
        losses = self.model.forward(batch)
        loss = torch.stack(list(losses.values())).mean()
        preds = self.model.infer(batch)
        self.evaluator.process(batch, preds)
        return loss

    def validation_epoch_end(self, outputs):
        # `outputs` is the list of loss tensors returned by validation_step.
        avg_loss = torch.stack(outputs).mean()
        self.log("val_loss", avg_loss)
        metrics = self.evaluator.evaluate()
        self.evaluator.reset()
        self.log_dict(metrics, prog_bar=True)

    def test_step(self, batch, batch_nb):
        preds = self.model.infer(batch)
        self.evaluator.process(batch, preds)

    def test_epoch_end(self, outputs):
        metrics = self.evaluator.evaluate()
        self.log_dict(metrics)
        return metrics

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

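    # Returning a single-element list is one of the return forms Lightning accepts
    # from configure_optimizers; a bare optimizer, or (optimizers, schedulers)
    # lists, would also work.
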
    def train_dataloader(self):
        # Called by the trainer when it sets up training.
        dataset = ImageDetectionDemoDataset()
        # The identity collate_fn keeps each batch as a plain list of samples
        # instead of stacking them into a single tensor.
        return DataLoader(dataset, shuffle=True,
                          batch_size=2, num_workers=8,
                          collate_fn=lambda x: x)

    def val_dataloader(self):
        dataset = ImageDetectionDemoDataset()
        return DataLoader(dataset, shuffle=False,
                          batch_size=1, num_workers=8, collate_fn=lambda x: x)

    def test_dataloader(self):
        dataset = ImageDetectionDemoDataset()
        return DataLoader(dataset, shuffle=False,
                          batch_size=1, num_workers=8, collate_fn=lambda x: x)

    # Process:
    # 1. Call Trainer.fit.
    # 2. Run validation_step on two batches as a sanity check.
    # 3. Run validation_epoch_end on those sanity-check outputs.
    # 4. Call training_step for every training batch until the end of the epoch.
    # 5. Call validation_step for every validation batch until the loader is exhausted.
    # 6. Compute the validation metrics in validation_epoch_end, log them, and save a
    #    checkpoint if the validation metrics improve.
    # 7. Return to step 4 and continue training.

    # test_step / test_epoch_end are only called when testing is run from main.py.
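
    # A minimal usage sketch of the flow above (illustrative only: the real entry
    # point is main.py, and the `params` dict and Trainer arguments come from the
    # project's configuration, not from this file):
    #
    #     task = DetectionTask(params)
    #     trainer = pl.Trainer(max_epochs=10)
    #     trainer.fit(task)    # steps 1-7 above
    #     trainer.test(task)   # runs test_step / test_epoch_end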