--- a
+++ b/training/adversary.py
@@ -0,0 +1,18 @@
+import torch
+
class Adversary:
    """Adversarial training helper: perturb a batch, then take one optimizer step on it.

    By default the perturbation is an iterative signed-gradient attack that keeps
    every element of the input within ``epsilon`` (L-infinity) of the clean batch.
    A custom attacker object exposing ``attack(images)`` may be supplied instead;
    it is then applied ``adv_steps`` times in sequence.
    """

    def __init__(self, model, loss_fn, epsilon, attacker=None, step_size=None):
        """
        Args:
            model: module mapping an input batch to predictions.
            loss_fn: callable ``loss_fn(preds, labels)`` returning a scalar tensor.
            epsilon: L-infinity radius of the built-in attack's perturbation.
            attacker: optional object with an ``attack(images)`` method. When
                given it replaces the built-in attack (``epsilon`` is then unused).
            step_size: per-step size for the built-in attack; when None it
                defaults to ``epsilon / adv_steps`` so the full radius is reachable.
        """
        self.model = model
        self.loss_fn = loss_fn
        self.epsilon = epsilon
        self.attacker = attacker
        self.step_size = step_size

    def _perturb(self, images, labels, adv_steps):
        """Return a detached adversarial copy of ``images`` after ``adv_steps`` steps."""
        if self.attacker is not None:
            adv = images
            for _ in range(adv_steps):
                adv = self.attacker.attack(adv)
            return adv.detach()

        clean = images.detach()
        adv = clean.clone()
        if adv_steps <= 0:
            return adv
        step = self.step_size if self.step_size is not None else self.epsilon / adv_steps
        for _ in range(adv_steps):
            adv.requires_grad_(True)
            attack_loss = self.loss_fn(self.model(adv), labels)
            # autograd.grad computes d(loss)/d(input) only; parameter .grad is untouched.
            (grad,) = torch.autograd.grad(attack_loss, adv)
            with torch.no_grad():
                adv = adv + step * grad.sign()
                # Project back into the L-inf ball of radius epsilon around the clean batch.
                # NOTE(review): no clamp to a valid pixel range is applied — the
                # input range is not known here; add one if inputs are e.g. [0, 1].
                adv = clean + (adv - clean).clamp(-self.epsilon, self.epsilon)
        return adv.detach()

    def train_adversarial(self, images, optimizer, adv_steps, labels=None):
        """Run one adversarial training step on ``images``.

        Args:
            images: clean input batch (tensor).
            optimizer: optimizer over ``self.model``'s parameters.
            adv_steps: number of attack iterations; ``0`` trains on the clean batch.
            labels: targets passed to ``loss_fn``. Required; it is the last
                parameter (defaulting to None) so the original positional
                signature ``(images, optimizer, adv_steps)`` is unchanged.

        Returns:
            The detached scalar training loss on the adversarial batch.

        Raises:
            ValueError: if ``labels`` is not provided.
        """
        if labels is None:
            raise ValueError("train_adversarial requires `labels` for the loss")
        self.model.train()
        adv_images = self._perturb(images, labels, adv_steps)

        # Clear stale gradients once, right before the single training backward pass.
        optimizer.zero_grad()
        preds = self.model(adv_images)
        loss = self.loss_fn(preds, labels)
        loss.backward()
        optimizer.step()
        return loss.detach()