|
a |
|
b/lightning/detection.py |
|
|
1 |
import pytorch_lightning as pl |
|
|
2 |
import torch |
|
|
3 |
import torchvision.transforms as T |
|
|
4 |
from torch.utils.data import DataLoader |
|
|
5 |
from ignite.metrics import Accuracy |
|
|
6 |
|
|
|
7 |
from models import get_model |
|
|
8 |
from eval import DetectionEvaluator |
|
|
9 |
from data import ImageDetectionDemoDataset |
|
|
10 |
from util import constants as C |
|
|
11 |
from .logger import TFLogger |
|
|
12 |
|
|
|
13 |
import pdb |
|
|
14 |
|
|
|
15 |
class DetectionTask(pl.LightningModule, TFLogger): |
|
|
16 |
"""Standard interface for the trainer to interact with the model.""" |
|
|
17 |
|
|
|
18 |
def __init__(self, params): |
|
|
19 |
super().__init__() #Initialize parent classes (pl.LightningModule params) |
|
|
20 |
self.save_hyperparameters(params) #Save hyperparameters to experiment directory, pytorch lightning function |
|
|
21 |
self.model = get_model(params) #Instantiates model from model folder |
|
|
22 |
self.evaluator = DetectionEvaluator() |
|
|
23 |
|
|
|
24 |
def training_step(self, batch, batch_nb): #Batch of data from train dataloader passed here |
|
|
25 |
losses = self.model.forward(batch) |
|
|
26 |
loss = torch.stack(list(losses.values())).mean() |
|
|
27 |
return loss |
|
|
28 |
|
|
|
29 |
def validation_step(self, batch, batch_nb): #Called once for every batch |
|
|
30 |
|
|
|
31 |
losses = self.model.forward(batch) |
|
|
32 |
loss = torch.stack(list(losses.values())).mean() |
|
|
33 |
preds = self.model.infer(batch) |
|
|
34 |
self.evaluator.process(batch, preds) |
|
|
35 |
return loss |
|
|
36 |
|
|
|
37 |
def validation_epoch_end(self, outputs): #outputs are loss tensors from validation step |
|
|
38 |
avg_loss = torch.stack(outputs).mean() |
|
|
39 |
self.log("val_loss", avg_loss) |
|
|
40 |
metrics = self.evaluator.evaluate() |
|
|
41 |
self.evaluator.reset() |
|
|
42 |
self.log_dict(metrics, prog_bar=True) |
|
|
43 |
|
|
|
44 |
def test_step(self, batch, batch_nb): |
|
|
45 |
preds = self.model.infer(batch) |
|
|
46 |
self.evaluator.process(batch, preds) |
|
|
47 |
|
|
|
48 |
def test_epoch_end(self, outputs): |
|
|
49 |
metrics = self.evaluator.evaluate() |
|
|
50 |
self.log_dict(metrics) |
|
|
51 |
return metrics |
|
|
52 |
|
|
|
53 |
def configure_optimizers(self): |
|
|
54 |
return [torch.optim.Adam(self.parameters(), lr=0.02)] |
|
|
55 |
|
|
|
56 |
def train_dataloader(self): #Called during init |
|
|
57 |
dataset = ImageDetectionDemoDataset() #For specific examples |
|
|
58 |
return DataLoader(dataset, shuffle=True, #For entire batch |
|
|
59 |
batch_size=2, num_workers=8, |
|
|
60 |
collate_fn=lambda x: x) |
|
|
61 |
|
|
|
62 |
def val_dataloader(self): #Called during init |
|
|
63 |
dataset = ImageDetectionDemoDataset() |
|
|
64 |
return DataLoader(dataset, shuffle=False, |
|
|
65 |
batch_size=1, num_workers=8, collate_fn=lambda x: x) |
|
|
66 |
|
|
|
67 |
def test_dataloader(self): #Called during init |
|
|
68 |
dataset = ImageDetectionDemoDataset() |
|
|
69 |
return DataLoader(dataset, shuffle=False, |
|
|
70 |
batch_size=1, num_workers=8, collate_fn=lambda x: x) |
|
|
71 |
|
|
|
72 |
#Process |
|
|
73 |
#1. Call Trainer.fit |
|
|
74 |
#2. Run validation step twice |
|
|
75 |
#3. Run validation epoch end as a dummy run |
|
|
76 |
#4. Call training step for all training batches, till end of epoch |
|
|
77 |
#5. Call validation step for all validation batches till validation batches exhausted |
|
|
78 |
#6. Compute validation metric at validation epoch end, log it, save checkpoint if validation metrics improve |
|
|
79 |
#7. Return to training step 4 and continue |
|
|
80 |
|
|
|
81 |
#Test steps are only called when you call main.py |