# experimental/object_detection.py
import argparse
import pickle

import brambox as bb
import lightnet as ln
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split

from datasets import BramboxPathFlowDataset

# Settings
ln.logger.setConsoleLevel('ERROR')  # Only show error log messages
bb.logger.setConsoleLevel('ERROR')
# Engine layout follows https://eavise.gitlab.io/lightnet/notes/02-B-engine.html

p = argparse.ArgumentParser()
p.add_argument('--num_classes', default=4, type=int)
p.add_argument('--patch_size', default=512, type=int)
p.add_argument('--patch_info_file', default='cell_info.db', type=str)
p.add_argument('--input_dir', default='inputs', type=str)
p.add_argument('--sample_p', default=1., type=float)
p.add_argument('--conf_thresh', default=0.01, type=float)
p.add_argument('--nms_thresh', default=0.5, type=float)

args = p.parse_args()

np.random.seed(42)
num_classes = args.num_classes + 1
patch_size = args.patch_size
batch_size = 64
patch_info_file = args.patch_info_file
input_dir = args.input_dir
sample_p = args.sample_p
conf_thresh = args.conf_thresh
nms_thresh = args.nms_thresh

with open('anchors.pkl', 'rb') as f:
    anchors = pickle.load(f)
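# NOTE (assumption): anchors.pkl is expected to hold an array of (w, h) anchor
# priors in the units ln.models.Yolo expects; the .tolist() call below suggests
# it is stored as a numpy array.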

annotation_file = 'annotations_bbox_{}.pkl'.format(patch_size)
annotations = bb.io.load('pandas', annotation_file)

if sample_p < 1.:
    annotations = annotations.sample(frac=sample_p)

annotations_dict = {}
annotations_dict['train'], annotations_dict['test'] = train_test_split(annotations)
annotations_dict['train'], annotations_dict['val'] = train_test_split(annotations_dict['train'])
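# NOTE: train_test_split uses its default 75/25 split here; it partitions
# annotation rows, so boxes from the same image can land in different subsets.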

model = ln.models.Yolo(num_classes=num_classes, anchors=anchors.tolist())

loss = ln.network.loss.RegionLoss(
    num_classes=model.num_classes,
    anchors=model.anchors,
    stride=model.stride,
)

transforms = ln.data.transform.Compose([
    ln.data.transform.RandomHSV(
        hue=1,
        saturation=2,
        value=2,
    ),
])
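
# NOTE: `transforms` is defined above but never wired into the datasets below
# (img_transform=None); passing it as img_transform for the training split
# would enable the HSV augmentation.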

# Create HyperParameters
params = ln.engine.HyperParameters(
    network=model,
    input_dimension=(patch_size, patch_size),
    mini_batch_size=16,
    batch_size=batch_size,
    max_batches=80000,
)
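# NOTE: with batch_size=64 and mini_batch_size=16, the engine accumulates
# gradients over 4 forward/backward passes before each optimizer step
# (see the lightnet engine notes linked above).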

post = ln.data.transform.Compose([
    # Threshold the raw network output into candidate boxes
    ln.data.transform.GetBoundingBoxes(
        num_classes=params.network.num_classes,
        anchors=params.network.anchors,
        conf_thresh=conf_thresh,
    ),
    # Merge overlapping candidates
    ln.data.transform.NonMaxSuppression(
        nms_thresh=nms_thresh,
    ),
    # Convert the surviving boxes to brambox dataframes for bb.stat
    ln.data.transform.TensorToBrambox(
        network_size=(patch_size, patch_size),
        # class_label_map=class_label_map,
    ),
])

datasets = {
    k: BramboxPathFlowDataset(
        input_dir,
        patch_info_file,
        patch_size,
        annotations_dict[k],
        input_dimension=(patch_size, patch_size),
        class_label_map=None,
        identify=None,
        img_transform=None,  # see note above: `transforms` could be passed here
        anno_transform=None,
    )
    for k in ['train', 'val', 'test']
}

params.loss = loss  # reuse the RegionLoss built above, which includes the stride
params.optim = torch.optim.SGD(params.network.parameters(), lr=1e-4)
params.scheduler = ln.engine.SchedulerCompositor(
    #   batch   scheduler
        (0,     torch.optim.lr_scheduler.CosineAnnealingLR(params.optim, T_max=200)),
)

dls = {
    k: ln.data.DataLoader(
        datasets[k],
        batch_size=batch_size,
        collate_fn=ln.data.brambox_collate,  # group the annotations as a list
    )
    for k in ['train', 'val', 'test']
}

params.val_loader = dls['val']
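
# NOTE: attributes stored on HyperParameters (loss, optim, scheduler,
# val_loader, max_batches) are reachable as self.<name> inside the engine,
# which is how val_loop below gets at the validation dataloader.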

class CustomEngine(ln.engine.Engine):
    def start(self):
        """Do whatever needs to be done before starting."""
        self.params.to(self.device)  # Cast parameters to the training device
        self.optim.zero_grad()       # Make sure to start with no gradients
        self.loss_acc = []           # Loss accumulator

    def process_batch(self, data):
        """Forward and backward pass for a single mini-batch."""
        data, target = data  # Unpack
        data = data.permute(0, 3, 1, 2).float()  # NHWC -> NCHW
        if torch.cuda.is_available():
            data = data.cuda()

        output = self.network(data)
        loss = self.loss(output, target)
        loss.backward()

        self.loss_acc.append(loss.item())

    @ln.engine.Engine.batch_end(100)  # run validation every 100 batches
    def val_loop(self):
        """Validate on up to the first 100 validation batches."""
        self.network.eval()
        bbox_final, val_losses = [], []
        with torch.no_grad():
            for i, data in enumerate(self.val_loader):
                if i > 100:
                    break
                data, target = data
                data = data.permute(0, 3, 1, 2).float()  # NHWC -> NCHW
                if torch.cuda.is_available():
                    data = data.cuda()
                output = self.network(data)
                val_losses.append(self.loss(output, target).item())
                bbox_final.append(post(output))

        detections = pd.concat(bbox_final)
        # PR curve at IoU threshold 0.5; its area approximates average precision
        pr = bb.stat.pr(detections, annotations_dict['val'], threshold=0.5)
        auc = bb.stat.auc(pr)
        print('VAL loss={} AUC={}'.format(sum(val_losses) / len(val_losses), auc))
        self.network.train()

    @ln.engine.Engine.batch_end(300)
    def save_model(self):
        self.params.save(f'backup-{self.batch}.state.pt')
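    # NOTE (assumption): HyperParameters.save serializes the state of everything
    # stored on params (network weights, optimizer, scheduler), so a backup file
    # should be enough to resume training from.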

    def train_batch(self):
        """Weight update and logging."""
        self.optim.step()
        self.optim.zero_grad()

        batch_loss = sum(self.loss_acc) / len(self.loss_acc)
        self.loss_acc = []
        self.log(f'Loss: {batch_loss}')

    def quit(self):
        if self.batch >= self.max_batches:
            self.params.save('final.state.pt')  # save final weights on exit
            print('Reached end of training')
            return True
        return False


# Create engine
engine = CustomEngine(
    params,
    dls['train'],  # a real training dataloader is required here; None is not valid
    device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'),
)

engine()  # runs until quit() returns True, i.e. max_batches is reached
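
# Usage sketch (all flags default to the values defined above):
#   python object_detection.py --num_classes 4 --patch_size 512 \
#       --patch_info_file cell_info.db --input_dir inputs --sample_p 1.0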