--- a +++ b/training/trainer.py @@ -0,0 +1,104 @@ +import torch +import torch.nn as nn +from models import CNNBase, CAM, ResNet50, BloodNet3D, CellMovementEncoder +from weakly_supervised_localization import WSLoss +from self_supervised_distillation import DistillKL +from adversarial_robustness import FGSM + +# Assuming you have defined CNNBase, CAM, ResNet50, and other necessary components + +# 1. Weakly supervised localization +cnn = CNNBase() +cam = CAM(cnn) + +loss_fn = nn.BCEWithLogitsLoss() +ws_loss_fn = WSLoss(loss_fn) + +optimizer = torch.optim.Adam(cnn.parameters()) + +for images, labels in loader: + cnn.train() # Set to train mode + optimizer.zero_grad() # Clear gradients + + preds = cnn(images) + cam_maps = cam(images) + + loss = ws_loss_fn(preds, cam_maps, labels) + loss.backward() + + optimizer.step() + +# 2. Self-supervised data distillation +teacher = ResNet50() # Pretrained on 100x more data +student = CNNBase() + +distill_loss = DistillKL(teacher, student) + +optimizer_student = torch.optim.Adam(student.parameters()) # Separate optimizer for the student + +for images in unsup_blood_cells: + student.train() # Set to train mode + optimizer_student.zero_grad() # Clear gradients + + t_preds = teacher(images) + s_preds = student(images) + + loss = distill_loss(s_preds, t_preds) + loss.backward() + + optimizer_student.step() + +# 3. Adversarial Training +adv_steps = 5 +epsilon = 0.1 + +model = CNNBase() # Assuming you need a model instance for FGSM +loss_fn = nn.BCEWithLogitsLoss() +attacker = FGSM(model, loss_fn, epsilon) + +optimizer_adv = torch.optim.Adam(model.parameters()) # Separate optimizer for the adversarial model + +for images, labels in loader: # Assuming you have a loader for this as well + model.train() # Set to train mode + optimizer_adv.zero_grad() # Clear gradients + + for i in range(adv_steps): + images = attacker.attack(images) + + preds = model(images) + + loss = loss_fn(preds, labels) + loss.backward() + + optimizer_adv.step() + +# 4. 3D Convolutions +bloodnet_3d = BloodNet3D() +optimizer_3d = torch.optim.Adam(bloodnet_3d.parameters()) # Separate optimizer for 3D Convolutions + +bloodnet_3d.train() # Set to train mode +optimizer_3d.zero_grad() # Clear gradients + +# Add code for 3D Convolutions if needed + +optimizer_3d.step() + +# 5. Video self-supervision +frame_order_loss = nn.MSELoss() + +for blood_cell_video in videos: + bloodnet_video = CellMovementEncoder() + optimizer_video = torch.optim.Adam(bloodnet_video.parameters()) # Separate optimizer for video + + frames = shuffleFrames(blood_cell_video) + + bloodnet_video.train() # Set to train mode + optimizer_video.zero_grad() # Clear gradients + + pred_order = bloodnet_video(frames) # Assuming frames are the input to CellMovementEncoder + true_order = torch.arange(len(frames)) + + loss = frame_order_loss(pred_order, true_order) + loss.backward() + + optimizer_video.step()