|
a |
|
b/training/interpreter.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
from weakly_supervised_localization import WSLoss |
|
|
4 |
|
|
|
5 |
class Interpreter: |
|
|
6 |
def __init__(self, cnn, cam, loader): |
|
|
7 |
self.cnn = cnn |
|
|
8 |
self.cam = cam |
|
|
9 |
self.loader = loader |
|
|
10 |
self.loss_fn = nn.BCEWithLogitsLoss() |
|
|
11 |
self.ws_loss_fn = WSLoss(self.loss_fn) |
|
|
12 |
self.optimizer = torch.optim.Adam(self.cnn.parameters()) |
|
|
13 |
|
|
|
14 |
def train_weakly_supervised(self): |
|
|
15 |
for images, labels in self.loader: |
|
|
16 |
self.optimizer.zero_grad() |
|
|
17 |
|
|
|
18 |
preds = self.cnn(images) |
|
|
19 |
cam_maps = self.cam(images) |
|
|
20 |
|
|
|
21 |
loss = self.ws_loss_fn(preds, cam_maps, labels) |
|
|
22 |
loss.backward() |
|
|
23 |
|
|
|
24 |
self.optimizer.step() |