# training/trainer.py
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
from models import CNNBase, CAM, ResNet50, BloodNet3D, CellMovementEncoder |
|
|
4 |
from weakly_supervised_localization import WSLoss |
|
|
5 |
from self_supervised_distillation import DistillKL |
|
|
6 |
from adversarial_robustness import FGSM |
|
|
7 |
|
|
|
8 |
# Assuming you have defined CNNBase, CAM, ResNet50, and other necessary components |
|
|
9 |
|
|
|
# 1. Weakly supervised localization
# Train the base CNN jointly on classification logits and class-activation
# maps; WSLoss combines both signals around a BCE-with-logits criterion.
backbone = CNNBase()
cam_head = CAM(backbone)

bce_loss = nn.BCEWithLogitsLoss()
weak_sup_loss = WSLoss(bce_loss)

backbone_opt = torch.optim.Adam(backbone.parameters())

for images, labels in loader:
    backbone.train()        # train-mode layers (dropout/BN) active
    backbone_opt.zero_grad()

    logits = backbone(images)
    activation_maps = cam_head(images)

    # NOTE(review): assumes WSLoss signature is (preds, cam_maps, labels) —
    # matches the original call; confirm against its definition.
    batch_loss = weak_sup_loss(logits, activation_maps, labels)
    batch_loss.backward()
    backbone_opt.step()
# 2. Self-supervised data distillation
# Distill a large pretrained teacher into a compact student on unlabeled
# blood-cell images via a KL-divergence distillation loss.
teacher = ResNet50()  # pretrained on 100x more data
student = CNNBase()

# fix: the teacher is frozen reference model — keep it in inference mode so
# BatchNorm statistics and dropout are not perturbed during distillation
teacher.eval()

distill_loss = DistillKL(teacher, student)

optimizer_student = torch.optim.Adam(student.parameters())  # separate optimizer for the student

for images in unsup_blood_cells:
    student.train()                   # set student to train mode
    optimizer_student.zero_grad()     # clear gradients

    # fix: no gradients flow through the frozen teacher — torch.no_grad()
    # avoids building an unused autograd graph (saves memory and compute)
    with torch.no_grad():
        t_preds = teacher(images)
    s_preds = student(images)

    loss = distill_loss(s_preds, t_preds)
    loss.backward()

    optimizer_student.step()
# 3. Adversarial Training
# Harden the model by training on iteratively FGSM-perturbed inputs.
adv_steps = 5
epsilon = 0.1

model = CNNBase()
loss_fn = nn.BCEWithLogitsLoss()
attacker = FGSM(model, loss_fn, epsilon)

optimizer_adv = torch.optim.Adam(model.parameters())  # separate optimizer for the adversarial model

for images, labels in loader:
    model.train()  # set to train mode

    # Iterated FGSM: apply the single-step attack adv_steps times.
    # fix: use a separate name instead of clobbering the loop variable.
    # NOTE(review): attacker.attack is called with images only — many FGSM
    # implementations also need labels; confirm against the FGSM class.
    adv_images = images
    for _ in range(adv_steps):
        adv_images = attacker.attack(adv_images)
    # fix: detach so the training backward() does not traverse the graph
    # built while crafting the perturbation
    adv_images = adv_images.detach()

    # fix: clear gradients AFTER the attack — FGSM backprops through the
    # model itself and would otherwise leave stale gradients in .grad
    optimizer_adv.zero_grad()

    preds = model(adv_images)

    loss = loss_fn(preds, labels)
    loss.backward()

    optimizer_adv.step()
# 4. 3D Convolutions
# NOTE(review): placeholder section — no forward pass, loss, or backward()
# is performed yet, so the optimizer step below is effectively a no-op
# (Adam skips parameters whose .grad is None) until a training loop is added.
bloodnet_3d = BloodNet3D()
optimizer_3d = torch.optim.Adam(bloodnet_3d.parameters())  # separate optimizer for the 3D model

bloodnet_3d.train()        # set to train mode
optimizer_3d.zero_grad()   # clear gradients

# TODO: feed volumetric batches through bloodnet_3d, compute a loss,
# and call loss.backward() before stepping.

optimizer_3d.step()
# 5. Video self-supervision
# Frame-order prediction: shuffle each video's frames and train the encoder
# to regress the temporal position of every frame.
frame_order_loss = nn.MSELoss()

# fix: create the model and optimizer ONCE — the original re-instantiated
# both inside the loop, resetting the weights on every video so no learning
# ever accumulated across videos
bloodnet_video = CellMovementEncoder()
optimizer_video = torch.optim.Adam(bloodnet_video.parameters())  # separate optimizer for video

for blood_cell_video in videos:
    frames = shuffleFrames(blood_cell_video)

    bloodnet_video.train()          # set to train mode
    optimizer_video.zero_grad()     # clear gradients

    pred_order = bloodnet_video(frames)  # assumes frames is a valid encoder input — TODO confirm
    # fix: MSELoss requires a floating-point target; torch.arange defaults
    # to int64, which raises a dtype error against float predictions
    true_order = torch.arange(len(frames), dtype=torch.float32)
    # NOTE(review): a constant 0..N-1 target is independent of the shuffle —
    # shuffleFrames should probably also return the permutation so the
    # target encodes the true original positions; verify its contract.

    loss = frame_order_loss(pred_order, true_order)
    loss.backward()

    optimizer_video.step()