Diff of /training/distributed.py [000000] .. [249e74]

Switch to side-by-side view

--- a
+++ b/training/distributed.py
@@ -0,0 +1,24 @@
+import torch
+import torch.nn as nn
+from weakly_supervised_localization import WSLoss
+
+class Interpreter:
+    def __init__(self, cnn, cam, loader):
+        self.cnn = cnn
+        self.cam = cam
+        self.loader = loader
+        self.loss_fn = nn.BCEWithLogitsLoss()
+        self.ws_loss_fn = WSLoss(self.loss_fn)
+        self.optimizer = torch.optim.Adam(self.cnn.parameters())
+
+    def train_weakly_supervised(self):
+        for images, labels in self.loader:
+            self.optimizer.zero_grad()
+
+            preds = self.cnn(images)
+            cam_maps = self.cam(images)
+
+            loss = self.ws_loss_fn(preds, cam_maps, labels)
+            loss.backward()
+
+            self.optimizer.step()