--- a +++ b/training/adversary.py @@ -0,0 +1,18 @@ +import torch + +class Adversary: + def __init__(self, model, loss_fn, epsilon): + self.model = model + self.loss_fn = loss_fn + self.epsilon = epsilon + + def train_adversarial(self, images, optimizer, adv_steps): + self.model.train() + for i in range(adv_steps): + optimizer.zero_grad() + images = self.attacker.attack(images) + + preds = self.model(images) + loss = self.loss_fn(preds, labels) + loss.backward() + optimizer.step()