a | b/training/adversary.py | ||
---|---|---|---|
1 | import torch |
||
2 | |||
3 | class Adversary: |
||
4 | def __init__(self, model, loss_fn, epsilon): |
||
5 | self.model = model |
||
6 | self.loss_fn = loss_fn |
||
7 | self.epsilon = epsilon |
||
8 | |||
9 | def train_adversarial(self, images, optimizer, adv_steps): |
||
10 | self.model.train() |
||
11 | for i in range(adv_steps): |
||
12 | optimizer.zero_grad() |
||
13 | images = self.attacker.attack(images) |
||
14 | |||
15 | preds = self.model(images) |
||
16 | loss = self.loss_fn(preds, labels) |
||
17 | loss.backward() |
||
18 | optimizer.step() |