Diff of /training/trainer.py [000000] .. [249e74]

Switch to unified view

a b/training/trainer.py
1
import torch
2
import torch.nn as nn
3
from models import CNNBase, CAM, ResNet50, BloodNet3D, CellMovementEncoder
4
from weakly_supervised_localization import WSLoss
5
from self_supervised_distillation import DistillKL
6
from adversarial_robustness import FGSM
7
8
# Assuming you have defined CNNBase, CAM, ResNet50, and other necessary components
9
10
# 1. Weakly supervised localization
11
cnn = CNNBase()
12
cam = CAM(cnn)
13
14
loss_fn = nn.BCEWithLogitsLoss()
15
ws_loss_fn = WSLoss(loss_fn)
16
17
optimizer = torch.optim.Adam(cnn.parameters())
18
19
for images, labels in loader:
20
    cnn.train()  # Set to train mode
21
    optimizer.zero_grad()  # Clear gradients
22
23
    preds = cnn(images)  
24
    cam_maps = cam(images)
25
26
    loss = ws_loss_fn(preds, cam_maps, labels)
27
    loss.backward()
28
29
    optimizer.step()
30
31
# 2. Self-supervised data distillation
32
teacher = ResNet50()  # Pretrained on 100x more data
33
student = CNNBase()
34
35
distill_loss = DistillKL(teacher, student)
36
37
optimizer_student = torch.optim.Adam(student.parameters())  # Separate optimizer for the student
38
39
for images in unsup_blood_cells:
40
    student.train()  # Set to train mode
41
    optimizer_student.zero_grad()  # Clear gradients
42
43
    t_preds = teacher(images) 
44
    s_preds = student(images)
45
46
    loss = distill_loss(s_preds, t_preds)
47
    loss.backward()
48
49
    optimizer_student.step()
50
51
# 3. Adversarial Training 
52
adv_steps = 5
53
epsilon = 0.1
54
55
model = CNNBase()  # Assuming you need a model instance for FGSM
56
loss_fn = nn.BCEWithLogitsLoss()
57
attacker = FGSM(model, loss_fn, epsilon)
58
59
optimizer_adv = torch.optim.Adam(model.parameters())  # Separate optimizer for the adversarial model
60
61
for images, labels in loader:  # Assuming you have a loader for this as well
62
    model.train()  # Set to train mode
63
    optimizer_adv.zero_grad()  # Clear gradients
64
65
    for i in range(adv_steps):
66
        images = attacker.attack(images)
67
68
    preds = model(images)
69
70
    loss = loss_fn(preds, labels)
71
    loss.backward()
72
73
    optimizer_adv.step()
74
75
# 4. 3D Convolutions
76
bloodnet_3d = BloodNet3D()
77
optimizer_3d = torch.optim.Adam(bloodnet_3d.parameters())  # Separate optimizer for 3D Convolutions
78
79
bloodnet_3d.train()  # Set to train mode
80
optimizer_3d.zero_grad()  # Clear gradients
81
82
# Add code for 3D Convolutions if needed
83
84
optimizer_3d.step()
85
86
# 5. Video self-supervision
87
frame_order_loss = nn.MSELoss()
88
89
for blood_cell_video in videos:
90
    bloodnet_video = CellMovementEncoder()  
91
    optimizer_video = torch.optim.Adam(bloodnet_video.parameters())  # Separate optimizer for video
92
93
    frames = shuffleFrames(blood_cell_video)
94
95
    bloodnet_video.train()  # Set to train mode
96
    optimizer_video.zero_grad()  # Clear gradients
97
98
    pred_order = bloodnet_video(frames)  # Assuming frames are the input to CellMovementEncoder
99
    true_order = torch.arange(len(frames))
100
101
    loss = frame_order_loss(pred_order, true_order)
102
    loss.backward()
103
104
    optimizer_video.step()