[139527]: / lightning / classification.py

Download this file

91 lines (75 with data), 3.1 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from models import get_model
from eval import get_loss_fn, BinaryClassificationEvaluator
from data import ImageClassificationDemoDataset
from util import constants as C
from .logger import TFLogger
class ClassificationTask(pl.LightningModule, TFLogger):
    """Standard interface for the trainer to interact with the model.

    Wraps the network built by `get_model(params)` and the loss built by
    `get_loss_fn(params)`, and accumulates validation/test predictions in a
    `BinaryClassificationEvaluator` with a 0.5 threshold.
    """

    def __init__(self, params):
        """Build the model, the loss function and the evaluator.

        Args:
            params: hyperparameters; stored via `save_hyperparameters` and
                forwarded to `get_model` and `get_loss_fn`.
        """
        super().__init__()
        self.save_hyperparameters(params)
        self.model = get_model(params)
        self.loss = get_loss_fn(params)
        self.evaluator = BinaryClassificationEvaluator(threshold=0.5)

    def forward(self, x):
        """Run the wrapped model on `x` and return its raw output (logits)."""
        return self.model(x)

    def training_step(self, batch, batch_nb):
        """Compute and log the training loss for one batch.

        Args:
            batch: an `(x, y)` pair of inputs and targets.
            batch_nb: index of the batch (unused).

        Returns:
            The loss used to calculate the gradient. It is also logged
            under the key "loss".
        """
        x, y = batch
        logits = self.forward(x)
        # The logits are flattened before being compared with the targets.
        loss = self.loss(logits.view(-1), y)
        self.log("loss", loss)
        return loss

    def validation_step(self, batch, batch_nb):
        """Compute the loss for one batch and feed the evaluator.

        Args:
            batch: an `(x, y)` pair of inputs and targets.
            batch_nb: index of the batch (unused).

        Returns:
            The loss for this batch; the per-batch losses are averaged in
            `validation_epoch_end`.
        """
        x, y = batch
        logits = self.forward(x)
        loss = self.loss(logits.view(-1), y)
        # NOTE(review): the probabilities keep the model's output shape,
        # whereas the loss above uses flattened logits -- confirm that the
        # evaluator accepts this shape alongside `y`.
        self.evaluator.update((torch.sigmoid(logits), y))
        return loss

    def validation_epoch_end(self, outputs):
        """Aggregate and log the metrics collected during the epoch.

        Logs the mean of the per-batch losses under "val_loss", logs every
        metric returned by the evaluator, and resets the evaluator so the
        next epoch starts from a clean state.

        Args:
            outputs: A list of loss tensors returned by `validation_step()`
                (or by `test_step()` when called from `test_epoch_end()`).

        Returns:
            None. All metrics are reported through `self.log` /
            `self.log_dict`.
        """
        avg_loss = torch.stack(outputs).mean()
        # NOTE(review): this key is also used when running the test loop,
        # because `test_epoch_end` delegates here.
        self.log("val_loss", avg_loss)
        metrics = self.evaluator.evaluate()
        self.evaluator.reset()
        self.log_dict(metrics)

    def test_step(self, batch, batch_nb):
        """Evaluate one test batch exactly like a validation batch."""
        return self.validation_step(batch, batch_nb)

    def test_epoch_end(self, outputs):
        """Aggregate the test metrics exactly like the validation metrics."""
        return self.validation_epoch_end(outputs)

    def configure_optimizers(self):
        """Return a single Adam optimizer over all parameters (lr=0.02)."""
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    def _demo_dataloader(self, shuffle, batch_size, num_workers=8):
        """Build a DataLoader over a fresh `ImageClassificationDemoDataset`.

        Shared by the train/val/test loader hooks, which differ only in
        shuffling and batch size.
        """
        dataset = ImageClassificationDemoDataset()
        return DataLoader(dataset, shuffle=shuffle,
                          batch_size=batch_size, num_workers=num_workers)

    def train_dataloader(self):
        """Return the shuffled training loader (batch size 2)."""
        return self._demo_dataloader(shuffle=True, batch_size=2)

    def val_dataloader(self):
        """Return the unshuffled validation loader (batch size 1)."""
        return self._demo_dataloader(shuffle=False, batch_size=1)

    def test_dataloader(self):
        """Return the unshuffled test loader (batch size 1)."""
        return self._demo_dataloader(shuffle=False, batch_size=1)