--- /dev/null
+++ b/experimental/object_detection.py
@@ -0,0 +1,191 @@
+import argparse
+import pickle
+
+import brambox as bb
+import lightnet as ln
+import numpy as np
+import pandas as pd
+import torch
+from sklearn.model_selection import train_test_split
+
+from datasets import BramboxPathFlowDataset
+
+# Settings
+ln.logger.setConsoleLevel('ERROR')             # Only show error log messages
+bb.logger.setConsoleLevel('ERROR')
+# Engine structure based on https://eavise.gitlab.io/lightnet/notes/02-B-engine.html
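+# Pipeline overview: load bbox annotations, split train/val/test, build a YOLO
+# network with RegionLoss, then train via a custom lightnet Engine that runs
+# validation and checkpointing at fixed batch intervals.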
+
+p = argparse.ArgumentParser(description='Train a YOLO cell detector on image patches.')
+p.add_argument('--num_classes', default=4, type=int, help='number of object classes (one extra index is added below)')
+p.add_argument('--patch_size', default=512, type=int, help='square patch edge length in pixels')
+p.add_argument('--patch_info_file', default='cell_info.db', type=str, help='database with patch metadata')
+p.add_argument('--input_dir', default='inputs', type=str, help='directory containing input images')
+p.add_argument('--sample_p', default=1., type=float, help='fraction of annotations to sample')
+p.add_argument('--conf_thresh', default=0.01, type=float, help='confidence threshold for box extraction')
+p.add_argument('--nms_thresh', default=0.5, type=float, help='IoU threshold for non-maximum suppression')
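+
+# Example invocation (values shown are the defaults; paths are illustrative):
+#   python object_detection.py --input_dir inputs --patch_size 512 --conf_thresh 0.01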
+
+
+args = p.parse_args()
+np.random.seed(42)
+num_classes = args.num_classes + 1  # one extra class index (e.g. background)
+patch_size = args.patch_size
+batch_size = 64
+patch_info_file = args.patch_info_file
+input_dir = args.input_dir
+sample_p = args.sample_p
+conf_thresh = args.conf_thresh
+nms_thresh = args.nms_thresh
+with open('anchors.pkl', 'rb') as f:
+    anchors = pickle.load(f)
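+# anchors.pkl is assumed to hold precomputed anchor priors (e.g. from k-means
+# clustering of training-set box dimensions), stored as a numpy array.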
+
+annotation_file = 'annotations_bbox_{}.pkl'.format(patch_size)
+annotations = bb.io.load('pandas', annotation_file)
+
+if sample_p < 1.:
+    annotations = annotations.sample(frac=sample_p)
+
+annotations_dict = {}
+annotations_dict['train'], annotations_dict['test'] = train_test_split(annotations)
+annotations_dict['train'], annotations_dict['val'] = train_test_split(annotations_dict['train'])
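+# Note: train_test_split partitions annotation rows (individual boxes), so
+# boxes from the same image may land in different splits.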
+
+model = ln.models.Yolo(num_classes=num_classes, anchors=anchors.tolist())
+
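+# HSV color-jitter augmentation; defined here but only applied if passed as
+# img_transform when building the datasets below.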
+transforms = ln.data.transform.Compose([ln.data.transform.RandomHSV(
+    hue=1,
+    saturation=2,
+    value=2
+)])
+
+# Create HyperParameters
+params = ln.engine.HyperParameters(
+    network=model,
+    input_dimension=(patch_size, patch_size),
+    mini_batch_size=16,
+    batch_size=batch_size,
+    max_batches=80000,
+)
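+# batch_size=64 with mini_batch_size=16: lightnet accumulates gradients over
+# four mini-batches before each optimizer step.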
+
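+# Post-processing chain: extract boxes above conf_thresh, suppress overlaps
+# with NMS, then convert the output tensors to a brambox dataframe for evaluation.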
+post = ln.data.transform.Compose([
+    ln.data.transform.GetBoundingBoxes(
+        num_classes=params.network.num_classes,
+        anchors=params.network.anchors,
+        conf_thresh=conf_thresh,
+    ),
+
+    ln.data.transform.NonMaxSuppression(
+        nms_thresh=nms_thresh
+    ),
+
+    ln.data.transform.TensorToBrambox(
+        network_size=(patch_size,patch_size),
+        # class_label_map=class_label_map,
+    )
+])
+
+datasets = {
+    k: BramboxPathFlowDataset(
+        input_dir,
+        patch_info_file,
+        patch_size,
+        annotations_dict[k],
+        input_dimension=(patch_size, patch_size),
+        class_label_map=None,
+        identify=None,
+        img_transform=None,  # TODO: pass `transforms` here to enable HSV augmentation
+        anno_transform=None,
+    )
+    for k in ['train', 'val', 'test']
+}
+
+params.loss = ln.network.loss.RegionLoss(
+    num_classes=params.network.num_classes,
+    anchors=params.network.anchors,
+    stride=params.network.stride,
+)
+params.optim = torch.optim.SGD(params.network.parameters(), lr=1e-4)
+params.scheduler = ln.engine.SchedulerCompositor(
+    #   batch   scheduler
+    (0, torch.optim.lr_scheduler.CosineAnnealingLR(params.optim, T_max=200)),
+)
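+# SchedulerCompositor chains (start_batch, scheduler) pairs; here a single
+# cosine-annealing schedule runs from batch 0.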
+
+dls = {k: ln.data.DataLoader(
+    datasets[k],
+    batch_size=batch_size,
+    collate_fn=ln.data.brambox_collate,  # we want the data grouped as a list
+) for k in ['train', 'val', 'test']}
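+# brambox_collate keeps each image's annotations in a list rather than a
+# stacked tensor, since images contain varying numbers of boxes.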
+
+params.val_loader = dls['val']  # accessible as self.val_loader inside the engine
+
+class CustomEngine(ln.engine.Engine):
+    def start(self):
+        """ Do whatever needs to be done before starting """
+        self.params.to(self.device)  # Casting parameters to a certain device
+        self.optim.zero_grad()       # Make sure to start with no gradients
+        self.loss_acc = []           # Loss accumulator
+
+    def process_batch(self, data):
+        """ Forward and backward pass """
+        data, target = data  # unpack (image batch, brambox annotations)
+        data = data.permute(0, 3, 1, 2).float()  # NHWC -> NCHW
+        if torch.cuda.is_available():
+            data = data.cuda()
+
+        output = self.network(data)
+        loss = self.loss(output, target)
+        loss.backward()
+
+        self.loss_acc.append(loss.item())
+
+    @ln.engine.Engine.batch_end(100)  # the validation dataloader comes in via params
+    def val_loop(self):
+        """ Evaluate on up to 100 validation batches and report loss and PR-AUC """
+        bbox_final = []
+        val_losses = []
+        with torch.no_grad():
+            for i, data in enumerate(self.val_loader):
+                if i > 100:
+                    break
+                data, target = data
+                data = data.permute(0, 3, 1, 2).float()
+                if torch.cuda.is_available():
+                    data = data.cuda()
+                output = self.network(data)
+                val_losses.append(self.loss(output, target).item())
+                bbox_final.append(post(output))
+
+        detections = pd.concat(bbox_final)
+        pr = bb.stat.pr(detections, annotations_dict['val'], threshold=0.5)
+        auc = bb.stat.auc(pr)
+        print('VAL LOSS={}'.format(sum(val_losses) / len(val_losses)))
+        print('VAL AUC={}'.format(auc))
+
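+    # Checkpoint every 300 batches; params.save stores the full training state.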
+    @ln.engine.Engine.batch_end(300)
+    def save_model(self):
+        self.params.save(f'backup-{self.batch}.state.pt')
+
+    def train_batch(self):
+        """ Weight update and logging """
+        self.optim.step()
+        self.optim.zero_grad()
+
+        batch_loss = sum(self.loss_acc) / len(self.loss_acc)
+        self.loss_acc = []
+        self.log(f'Loss: {batch_loss}')
+
+    def quit(self):
+        if self.batch >= self.max_batches:
+            self.params.save('final.state.pt')  # save final weights before stopping
+            print('Reached end of training')
+            return True
+        return False
+
+
+
+# Create engine
+engine = CustomEngine(
+    params, dls['train'],  # a training dataloader is required (None is not valid)
+    device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'),
+)
+
+engine()  # runs until quit() returns True at max_batches