Diff of /training/trainer.py [000000] .. [249e74]

Switch to side-by-side view

--- /dev/null
+++ b/training/trainer.py
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+from models import CNNBase, CAM, ResNet50, BloodNet3D, CellMovementEncoder
+from weakly_supervised_localization import WSLoss
+from self_supervised_distillation import DistillKL
+from adversarial_robustness import FGSM
+
+# Assuming you have defined CNNBase, CAM, ResNet50, and other necessary components
+
+# 1. Weakly supervised localization
+cnn = CNNBase()
+cam = CAM(cnn)
+
+loss_fn = nn.BCEWithLogitsLoss()
+ws_loss_fn = WSLoss(loss_fn)
+
+optimizer = torch.optim.Adam(cnn.parameters())
+
+for images, labels in loader:
+    cnn.train()  # Set to train mode
+    optimizer.zero_grad()  # Clear gradients
+
+    preds = cnn(images)  
+    cam_maps = cam(images)
+
+    loss = ws_loss_fn(preds, cam_maps, labels)
+    loss.backward()
+
+    optimizer.step()
+
+# 2. Self-supervised data distillation
+teacher = ResNet50()  # Pretrained on 100x more data
+student = CNNBase()
+
+distill_loss = DistillKL(teacher, student)
+
+optimizer_student = torch.optim.Adam(student.parameters())  # Separate optimizer for the student
+
+for images in unsup_blood_cells:
+    student.train()  # Set to train mode
+    optimizer_student.zero_grad()  # Clear gradients
+
+    t_preds = teacher(images) 
+    s_preds = student(images)
+
+    loss = distill_loss(s_preds, t_preds)
+    loss.backward()
+
+    optimizer_student.step()
+
+# 3. Adversarial Training 
+adv_steps = 5
+epsilon = 0.1
+
+model = CNNBase()  # Assuming you need a model instance for FGSM
+loss_fn = nn.BCEWithLogitsLoss()
+attacker = FGSM(model, loss_fn, epsilon)
+
+optimizer_adv = torch.optim.Adam(model.parameters())  # Separate optimizer for the adversarial model
+
+for images, labels in loader:  # Assuming you have a loader for this as well
+    model.train()  # Set to train mode
+    optimizer_adv.zero_grad()  # Clear gradients
+
+    for i in range(adv_steps):
+        images = attacker.attack(images)
+
+    preds = model(images)
+
+    loss = loss_fn(preds, labels)
+    loss.backward()
+
+    optimizer_adv.step()
+
+# 4. 3D Convolutions
+# Scaffold for training a volumetric (3-D convolution) model; the actual
+# forward/backward pass has not been written yet.
+bloodnet_3d = BloodNet3D()
+optimizer_3d = torch.optim.Adam(bloodnet_3d.parameters())  # Separate optimizer for 3D Convolutions
+
+bloodnet_3d.train()  # Set to train mode
+optimizer_3d.zero_grad()  # Clear gradients
+
+# Add code for 3D Convolutions if needed
+# NOTE(review): no forward pass or loss.backward() precedes step(), so no
+# parameter gradients exist and the step below is effectively a no-op.
+optimizer_3d.step()
+
+# 5. Video self-supervision
+frame_order_loss = nn.MSELoss()
+
+for blood_cell_video in videos:
+    bloodnet_video = CellMovementEncoder()  
+    optimizer_video = torch.optim.Adam(bloodnet_video.parameters())  # Separate optimizer for video
+
+    frames = shuffleFrames(blood_cell_video)
+
+    bloodnet_video.train()  # Set to train mode
+    optimizer_video.zero_grad()  # Clear gradients
+
+    pred_order = bloodnet_video(frames)  # Assuming frames are the input to CellMovementEncoder
+    true_order = torch.arange(len(frames))
+
+    loss = frame_order_loss(pred_order, true_order)
+    loss.backward()
+
+    optimizer_video.step()