# training/trainer.py
import torch
import torch.nn as nn
from models import CNNBase, CAM, ResNet50, BloodNet3D, CellMovementEncoder
from weakly_supervised_localization import WSLoss
from self_supervised_distillation import DistillKL
from adversarial_robustness import FGSM
# CNNBase, CAM, ResNet50, BloodNet3D and CellMovementEncoder come from the
# `models` import above.
# NOTE(review): the data sources `loader`, `unsup_blood_cells` and `videos`
# are not defined anywhere in this file — TODO supply them before running.

# 1. Weakly supervised localization
cnn = CNNBase()
cam = CAM(cnn)  # class-activation-map wrapper around the same backbone
loss_fn = nn.BCEWithLogitsLoss()
ws_loss_fn = WSLoss(loss_fn)  # combines classification output with CAM maps
optimizer = torch.optim.Adam(cnn.parameters())
for images, labels in loader:
    cnn.train()  # Set to train mode (cheap; re-asserted every iteration)
    optimizer.zero_grad()  # Clear gradients from the previous step
    preds = cnn(images)
    cam_maps = cam(images)
    loss = ws_loss_fn(preds, cam_maps, labels)
    loss.backward()
    optimizer.step()
# 2. Self-supervised data distillation
# NOTE(review): the teacher is meant to be pretrained ("on 100x more data");
# confirm that ResNet50() actually loads those weights.
teacher = ResNet50()
student = CNNBase()
distill_loss = DistillKL(teacher, student)
optimizer_student = torch.optim.Adam(student.parameters())  # Separate optimizer for the student
for images in unsup_blood_cells:
    # Only the student is optimized. Keep the teacher in eval mode so its
    # normalization/dropout layers behave deterministically and its running
    # statistics are not overwritten by this unlabeled data.
    teacher.eval()
    student.train()
    optimizer_student.zero_grad()  # Clear gradients
    # The teacher's outputs are fixed targets. Skipping autograd here avoids
    # building a graph through the teacher and accumulating .grad tensors on
    # its parameters, which no optimizer would ever clear.
    with torch.no_grad():
        t_preds = teacher(images)
    s_preds = student(images)
    loss = distill_loss(s_preds, t_preds)
    loss.backward()
    optimizer_student.step()
# 3. Adversarial training
adv_steps = 5  # number of successive attack applications per batch
epsilon = 0.1  # attack strength passed to FGSM
model = CNNBase()
loss_fn = nn.BCEWithLogitsLoss()
attacker = FGSM(model, loss_fn, epsilon)
optimizer_adv = torch.optim.Adam(model.parameters())  # Separate optimizer for the adversarial model
for images, labels in loader:  # NOTE(review): `loader` is not defined in this file.
    model.train()  # Set to train mode
    # Build the adversarial batch first. Each attack step presumably
    # back-propagates through `model` to get input gradients, which would also
    # deposit gradients on the model's parameters. Gradients are therefore
    # cleared only after the attack, right before the training pass.
    for _ in range(adv_steps):
        images = attacker.attack(images)
    # NOTE(review): repeated FGSM steps are not projected back into an
    # epsilon-ball around the clean images, so the total perturbation may grow
    # to roughly adv_steps * epsilon — confirm this is intended (vs. PGD).
    images = images.detach()  # train on the perturbed inputs, not the attack graph
    optimizer_adv.zero_grad()  # Clear any gradients left over from the attack
    preds = model(images)
    loss = loss_fn(preds, labels)
    loss.backward()
    optimizer_adv.step()
# 4. 3D convolutions
bloodnet_3d = BloodNet3D()
optimizer_3d = torch.optim.Adam(bloodnet_3d.parameters())  # Separate optimizer for 3D convolutions
bloodnet_3d.train()  # Set to train mode
# TODO: add a training loop for the 3D model. Each iteration should call
# optimizer_3d.zero_grad(), compute a loss, call loss.backward(), and only
# then call optimizer_3d.step(). Calling step() without a preceding backward
# pass does nothing, because no parameter has a gradient to apply.
# 5. Video self-supervision (frame-order recovery)
frame_order_loss = nn.MSELoss()
# Create the model and its optimizer once, so learning accumulates across
# videos instead of restarting from a fresh model for every video.
bloodnet_video = CellMovementEncoder()
optimizer_video = torch.optim.Adam(bloodnet_video.parameters())  # Separate optimizer for video
for blood_cell_video in videos:  # NOTE(review): `videos` is not defined in this file.
    bloodnet_video.train()  # Set to train mode
    optimizer_video.zero_grad()  # Clear gradients
    # Shuffle the frames and keep the permutation. Position i of `frames`
    # holds original frame perm[i], so `perm` is the order the model must
    # recover; an identity target would be solvable by a constant output.
    # Assumes blood_cell_video is a tensor with frames along dim 0 — TODO confirm.
    perm = torch.randperm(len(blood_cell_video))
    frames = blood_cell_video[perm]
    pred_order = bloodnet_video(frames)  # Assuming frames are the input to CellMovementEncoder
    # randperm yields int64 indices; cast to the prediction's dtype/device so
    # MSELoss compares like with like.
    true_order = perm.to(device=pred_order.device, dtype=pred_order.dtype)
    loss = frame_order_loss(pred_order, true_order)
    loss.backward()
    optimizer_video.step()