Diff of /training/adversary.py [000000] .. [249e74]

Switch to unified view

a b/training/adversary.py
1
import torch
2
3
class Adversary:
4
    def __init__(self, model, loss_fn, epsilon):
5
        self.model = model
6
        self.loss_fn = loss_fn
7
        self.epsilon = epsilon
8
9
    def train_adversarial(self, images, optimizer, adv_steps):
10
        self.model.train()
11
        for i in range(adv_steps):
12
            optimizer.zero_grad()
13
            images = self.attacker.attack(images)
14
15
        preds = self.model(images)
16
        loss = self.loss_fn(preds, labels)
17
        loss.backward()
18
        optimizer.step()