[249e74]: / training / interpreter.py

Download this file

25 lines (19 with data), 704 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch.nn as nn
from weakly_supervised_localization import WSLoss
class Interpreter:
    """Train a CNN with a weakly supervised (CAM-based) loss.

    Holds a classifier ``cnn``, a class-activation-map producer ``cam`` and a
    ``loader`` yielding ``(images, labels)`` batches, and runs optimisation
    passes over the loader.
    """

    def __init__(self, cnn, cam, loader, optimizer=None):
        """Set up the loss functions and the optimizer.

        Args:
            cnn: Model whose parameters are optimised; called as ``cnn(images)``.
            cam: Callable producing CAM maps, called as ``cam(images)``.
            loader: Iterable of ``(images, labels)`` batches.
            optimizer: Optional pre-built optimizer. Defaults to
                ``torch.optim.Adam`` over ``cnn.parameters()``.
        """
        self.cnn = cnn
        self.cam = cam
        self.loader = loader
        # BCE-with-logits implies cnn emits raw logits — NOTE(review): confirm
        # against the model definition.
        self.loss_fn = nn.BCEWithLogitsLoss()
        # WSLoss wraps the base loss; it is called as (preds, cam_maps, labels).
        self.ws_loss_fn = WSLoss(self.loss_fn)
        self.optimizer = (
            optimizer
            if optimizer is not None
            else torch.optim.Adam(self.cnn.parameters())
        )

    def train_weakly_supervised(self):
        """Run one pass over ``loader``, updating ``cnn`` in place.

        Returns:
            The mean of the per-batch losses as a float, or ``0.0`` if the
            loader yields no batches.
        """
        # Force training behaviour (dropout, batch-norm statistics) even if the
        # model was previously switched to eval mode elsewhere.
        self.cnn.train()
        total = 0.0
        batches = 0
        for images, labels in self.loader:
            self.optimizer.zero_grad()
            preds = self.cnn(images)
            cam_maps = self.cam(images)
            loss = self.ws_loss_fn(preds, cam_maps, labels)
            loss.backward()
            self.optimizer.step()
            # Accumulate a detached tensor, which avoids a device->host sync
            # per batch and keeps the autograd graph from being retained.
            total = total + loss.detach()
            batches += 1
        return float(total) / batches if batches else 0.0