Diff of /training/distributed.py [000000] .. [249e74]

Switch to unified view

a b/training/distributed.py
1
import torch
2
import torch.nn as nn
3
from weakly_supervised_localization import WSLoss
4
5
class Interpreter:
6
    def __init__(self, cnn, cam, loader):
7
        self.cnn = cnn
8
        self.cam = cam
9
        self.loader = loader
10
        self.loss_fn = nn.BCEWithLogitsLoss()
11
        self.ws_loss_fn = WSLoss(self.loss_fn)
12
        self.optimizer = torch.optim.Adam(self.cnn.parameters())
13
14
    def train_weakly_supervised(self):
15
        for images, labels in self.loader:
16
            self.optimizer.zero_grad()
17
18
            preds = self.cnn(images)
19
            cam_maps = self.cam(images)
20
21
            loss = self.ws_loss_fn(preds, cam_maps, labels)
22
            loss.backward()
23
24
            self.optimizer.step()