--- a
+++ b/lightning/detection.py
@@ -0,0 +1,81 @@
+import pytorch_lightning as pl
+import torch
+import torchvision.transforms as T
+from torch.utils.data import DataLoader
+from ignite.metrics import Accuracy
+
+from models import get_model
+from eval import DetectionEvaluator
+from data import ImageDetectionDemoDataset
+from util import constants as C
+from .logger import TFLogger
+
+import pdb
+
+class DetectionTask(pl.LightningModule, TFLogger):
+    """Standard interface for the trainer to interact with the model."""
+
+    def __init__(self, params):
+        super().__init__() #Initialize parent classes (pl.LightningModule params)
+        self.save_hyperparameters(params) #Save hyperparameters to experiment directory, pytorch lightning function
+        self.model = get_model(params) #Instantiates model from model folder
+        self.evaluator = DetectionEvaluator()
+
+    def training_step(self, batch, batch_nb): #Batch of data from train dataloader passed here
+        losses = self.model.forward(batch)
+        loss = torch.stack(list(losses.values())).mean()
+        return loss
+
+    def validation_step(self, batch, batch_nb): #Called once for every batch
+
+        losses = self.model.forward(batch)
+        loss = torch.stack(list(losses.values())).mean()
+        preds = self.model.infer(batch)
+        self.evaluator.process(batch, preds)
+        return loss
+
+    def validation_epoch_end(self, outputs): #outputs are loss tensors from validation step
+        avg_loss = torch.stack(outputs).mean()
+        self.log("val_loss", avg_loss)
+        metrics = self.evaluator.evaluate()
+        self.evaluator.reset()
+        self.log_dict(metrics, prog_bar=True)
+
+    def test_step(self, batch, batch_nb):
+        preds = self.model.infer(batch)
+        self.evaluator.process(batch, preds)
+
+    def test_epoch_end(self, outputs):
+        metrics = self.evaluator.evaluate()
+        self.log_dict(metrics)
+        return metrics
+
+    def configure_optimizers(self):
+        return [torch.optim.Adam(self.parameters(), lr=0.02)]
+
+    def train_dataloader(self): #Called during init
+        dataset = ImageDetectionDemoDataset() #For specific examples
+        return DataLoader(dataset, shuffle=True, #For entire batch
+                          batch_size=2, num_workers=8,
+                          collate_fn=lambda x: x)
+
+    def val_dataloader(self): #Called during init
+        dataset = ImageDetectionDemoDataset() 
+        return DataLoader(dataset, shuffle=False,
+                batch_size=1, num_workers=8, collate_fn=lambda x: x)
+
+    def test_dataloader(self): #Called during init
+        dataset = ImageDetectionDemoDataset() 
+        return DataLoader(dataset, shuffle=False,
+                batch_size=1, num_workers=8, collate_fn=lambda x: x)
+
+    #Process
+    #1. Call Trainer.fit
+    #2. Run validation step twice
+    #3. Run validation epoch end as a dummy run
+    #4. Call training step for all training batches, till end of epoch
+    #5. Call validation step for all validation batches till validation batches exhausted
+    #6. Compute  validation metric at validation epoch end, log it, save checkpoint if validation metrics improve
+    #7. Return to training step 4 and continue
+
+    #Test steps are only called when you call main.py
\ No newline at end of file