Diff of /train.py [000000] .. [4b8af8]

import os
import time

import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

from dataset_dataloader import LungsDataset, get_augmentations, get_dataloader
from loss_metric import Meter, BCEDiceLoss

from segmentation_models_pytorch import Unet
from IPython.display import clear_output


class Trainer:
    """
    Factory for the training process.

    Args:
        display_plot: if True, plot the training history after each epoch.
        net: neural network for mask prediction.
        criterion: factory for computing the objective loss.
        optimizer: optimizer for updating the weights.
        phases: list with the train and validation phases.
        dataloaders: dict with data loaders for the train and val phases.
        imgs_dir: path to the folder with images.
        masks_dir: path to the folder with masks.
        path_to_csv: path to the csv file.
        meter: factory for storing and updating metrics.
        batch_size: data batch size for one weight-update step.
        num_epochs: number of passes over the full dataset.
        accumulation_steps: number of samples to accumulate before taking an
                    optimization step
                    (https://www.kaggle.com/c/understanding_cloud_organization/discussion/105614).
        lr: learning rate for the optimizer.
        scheduler: scheduler for controlling the learning rate.
        losses: dict storing lists of losses for each phase.
        jaccard_scores: dict storing lists of jaccard scores for each phase.
        dice_scores: dict storing lists of dice scores for each phase.
    """
    def __init__(self,
                 net: nn.Module,
                 criterion: nn.Module,
                 lr: float,
                 accumulation_steps: int,
                 batch_size: int,
                 num_epochs: int,
                 imgs_dir: str,
                 masks_dir: str,
                 path_to_csv: str,
                 display_plot: bool = True
                ):
        """Initialization."""
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("device:", self.device)
        self.display_plot = display_plot
        self.net = net.to(self.device)
        self.criterion = criterion
        self.optimizer = Adam(self.net.parameters(), lr=lr)
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode="min",
                                           patience=3, verbose=True)
        # accumulation_steps is given in samples; store the number of
        # micro-batches to accumulate between optimizer steps.
        self.accumulation_steps = accumulation_steps // batch_size
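        # Worked example with illustrative values (not from this repository):
        # accumulation_steps=32 samples and batch_size=8 give 32 // 8 = 4
        # micro-batches per optimizer step, i.e. an effective batch of 32.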
        self.phases = ["train", "val"]
        self.num_epochs = num_epochs

        self.dataloaders = {
            phase: get_dataloader(
                imgs_dir=imgs_dir,
                masks_dir=masks_dir,
                path_to_csv=path_to_csv,
                phase=phase,
                batch_size=batch_size,
                num_workers=6
            )
            for phase in self.phases
        }
        self.best_loss = float("inf")
        self.losses = {phase: [] for phase in self.phases}
        self.dice_scores = {phase: [] for phase in self.phases}
        self.jaccard_scores = {phase: [] for phase in self.phases}

    def _compute_loss_and_outputs(self,
                                  images: torch.Tensor,
                                  targets: torch.Tensor):
        images = images.to(self.device)
        targets = targets.to(self.device)
        logits = self.net(images)
        loss = self.criterion(logits, targets)
        return loss, logits

    def _do_epoch(self, epoch: int, phase: str):
        print(f"{phase} epoch: {epoch} | time: {time.strftime('%H:%M:%S')}")

        if phase == "train":
            self.net.train()
        else:
            self.net.eval()
        meter = Meter()
        dataloader = self.dataloaders[phase]
        total_batches = len(dataloader)
        running_loss = 0.0
        self.optimizer.zero_grad()
        for itr, (images, targets) in enumerate(dataloader):
            loss, logits = self._compute_loss_and_outputs(images, targets)
            # Scale the loss so that gradients summed over one accumulation
            # window average out to a single batch-sized step.
            loss = loss / self.accumulation_steps
            if phase == "train":
                loss.backward()
                if (itr + 1) % self.accumulation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
            running_loss += loss.item()
            meter.update(logits.detach().cpu(),
                         targets.detach().cpu())

        # Multiplying back by accumulation_steps undoes the per-batch scaling,
        # recovering the mean unscaled loss over the epoch.
        epoch_loss = (running_loss * self.accumulation_steps) / total_batches
        epoch_dice, epoch_iou = meter.get_metrics()

        self.losses[phase].append(epoch_loss)
        self.dice_scores[phase].append(epoch_dice)
        self.jaccard_scores[phase].append(epoch_iou)

        return epoch_loss

    def train(self):
        for epoch in range(self.num_epochs):
            self._do_epoch(epoch, "train")
            with torch.no_grad():
                val_loss = self._do_epoch(epoch, "val")
                self.scheduler.step(val_loss)
            if self.display_plot:
                self._plot_train_history()

            if val_loss < self.best_loss:
                print(f"\n{'#'*20}\nSaved new checkpoint\n{'#'*20}\n")
                self.best_loss = val_loss
                torch.save(self.net.state_dict(), "best_model.pth")
            print()
        self._save_train_history()

    def _plot_train_history(self):
        data = [self.losses, self.dice_scores, self.jaccard_scores]
        colors = ['deepskyblue', "crimson"]
        labels = [
            f"""
            train loss {self.losses['train'][-1]:.4f}
            val loss {self.losses['val'][-1]:.4f}
            """,

            f"""
            train dice score {self.dice_scores['train'][-1]:.4f}
            val dice score {self.dice_scores['val'][-1]:.4f}
            """,

            f"""
            train jaccard score {self.jaccard_scores['train'][-1]:.4f}
            val jaccard score {self.jaccard_scores['val'][-1]:.4f}
            """,
        ]

        clear_output(True)
        with plt.style.context("seaborn-dark-palette"):
            fig, axes = plt.subplots(3, 1, figsize=(8, 10))
            for i, ax in enumerate(axes):
                ax.plot(data[i]['val'], c=colors[0], label="val")
                ax.plot(data[i]['train'], c=colors[-1], label="train")
                ax.set_title(labels[i])
                ax.legend(loc="upper right")

            plt.tight_layout()
            plt.show()

    def load_pretrained_model(self, state_path: str):
        self.net.load_state_dict(torch.load(state_path))
        print("Pretrained model loaded")

    def _save_train_history(self):
        """Write the model weights and the training logs to files."""
        torch.save(self.net.state_dict(), "last_epoch_model.pth")

        logs_ = [self.losses, self.dice_scores, self.jaccard_scores]
        log_names_ = ["_loss", "_dice", "_jaccard"]
        logs = [logs_[i][key]
                for i in range(len(logs_))
                for key in logs_[i]]
        log_names = [key + log_names_[i]
                     for i in range(len(logs_))
                     for key in logs_[i]]
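        # With phases ["train", "val"] this produces one column per
        # phase/metric pair: train_loss, val_loss, train_dice, val_dice,
        # train_jaccard, val_jaccard, with one row per epoch.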
        pd.DataFrame(
            dict(zip(log_names, logs))
        ).to_csv("train_log.csv", index=False)
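

# Minimal usage sketch. It relies only on the Unet and BCEDiceLoss imported
# above; BCEDiceLoss is assumed to be constructible with default arguments,
# and the hyperparameters and data paths are illustrative placeholders,
# not values from this repository.
if __name__ == "__main__":
    model = Unet("resnet34", encoder_weights="imagenet", classes=1)
    trainer = Trainer(net=model,
                      criterion=BCEDiceLoss(),
                      lr=5e-4,
                      accumulation_steps=32,
                      batch_size=8,
                      num_epochs=20,
                      imgs_dir="data/imgs/",          # placeholder path
                      masks_dir="data/masks/",        # placeholder path
                      path_to_csv="data/train.csv")   # placeholder path
    trainer.train()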