[249e74]: / training / adversary.py

Download this file

19 lines (15 with data), 507 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
import torch
class Adversary:
    """Adversarial-training helper built around an epsilon-bounded attack.

    By default, adversarial examples are generated with a single Fast Gradient
    Sign Method (FGSM) step of size ``epsilon``. An external attacker object
    exposing ``attack(images)`` may be supplied instead.
    """

    def __init__(self, model, loss_fn, epsilon, attacker=None):
        """Store the model, loss function and attack budget.

        Args:
            model: Callable mapping a batch of inputs to predictions. It must
                expose ``train()`` (e.g. a ``torch.nn.Module``).
            loss_fn: Callable ``loss_fn(preds, labels)`` returning a scalar
                tensor.
            epsilon: Step size of the built-in FGSM attack. Each input element
                is moved by at most ``epsilon``.
            attacker: Optional object with an ``attack(images)`` method. When
                provided, it is used instead of the built-in FGSM attack.
        """
        self.model = model
        self.loss_fn = loss_fn
        self.epsilon = epsilon
        self.attacker = attacker

    def attack(self, images, labels):
        """Return FGSM adversarial examples for ``images``.

        Computes ``images + epsilon * sign(d loss / d images)``.

        Args:
            images: Input batch (floating-point tensor).
            labels: Targets passed as the second argument of ``loss_fn``.

        Returns:
            A tensor shaped like ``images`` that does not require grad.
            No clamping is applied, because the valid input range is not
            known here.
        """
        inputs = images.detach().clone().requires_grad_(True)
        loss = self.loss_fn(self.model(inputs), labels)
        # autograd.grad (rather than loss.backward()) computes the gradient
        # w.r.t. the inputs only, leaving the parameters' .grad buffers
        # untouched so they cannot leak into the training step.
        (grad,) = torch.autograd.grad(loss, inputs)
        return images.detach() + self.epsilon * grad.sign()

    def train_adversarial(self, images, optimizer, adv_steps, labels=None):
        """Run ``adv_steps`` optimizer steps on adversarial versions of a batch.

        Each step perturbs the *clean* ``images``, so perturbations do not
        compound across steps.

        Args:
            images: Clean input batch.
            optimizer: Optimizer over the model's parameters.
            adv_steps: Number of attack + update iterations.
            labels: Targets for ``loss_fn``. Required.

        Returns:
            The loss of the final step as a Python float, or ``None`` if
            ``adv_steps`` is 0.

        Raises:
            TypeError: If ``labels`` is not provided.
        """
        if labels is None:
            raise TypeError("train_adversarial() requires labels")
        self.model.train()
        # NOTE(review): the attack also runs in train mode, so layers such as
        # BatchNorm/Dropout behave as in training; confirm this is intended.
        last_loss = None
        for _ in range(adv_steps):
            if self.attacker is not None:
                adv_images = self.attacker.attack(images)
            else:
                adv_images = self.attack(images, labels)
            # Zero gradients *after* generating the attack so any gradients
            # accumulated by the attacker do not affect this update.
            optimizer.zero_grad()
            loss = self.loss_fn(self.model(adv_images), labels)
            loss.backward()
            optimizer.step()
            last_loss = loss.item()
        return last_loss