|
a |
|
b/experimental/object_detection.py |
|
|
1 |
import lightnet as ln |
|
|
2 |
import torch |
|
|
3 |
import numpy as np, pandas as pd |
|
|
4 |
import matplotlib.pyplot as plt |
|
|
5 |
import brambox as bb |
|
|
6 |
import dask as da |
|
|
7 |
from datasets import BramboxPathFlowDataset |
|
|
8 |
import argparse, pickle |
|
|
9 |
from sklearn.model_selection import train_test_split |
|
|
10 |
|
|
|
11 |
# Settings |
|
|
12 |
ln.logger.setConsoleLevel('ERROR') # Only show error log messages |
|
|
13 |
bb.logger.setConsoleLevel('ERROR') |
|
|
14 |
# https://eavise.gitlab.io/lightnet/notes/02-B-engine.html |
|
|
15 |
|
|
|
16 |
p=argparse.ArgumentParser() |
|
|
17 |
p.add_argument('--num_classes',default=4,type=int) |
|
|
18 |
p.add_argument('--patch_size',default=512,type=int) |
|
|
19 |
p.add_argument('--patch_info_file',default='cell_info.db',type=str) |
|
|
20 |
p.add_argument('--input_dir',default='inputs',type=str) |
|
|
21 |
p.add_argument('--sample_p',default=1.,type=float) |
|
|
22 |
p.add_argument('--conf_thresh',default=0.01,type=float) |
|
|
23 |
p.add_argument('--nms_thresh',default=0.5,type=float) |
|
|
24 |
|
|
|
25 |
|
|
|
26 |
args=p.parse_args() |
|
|
27 |
np.random.seed(42) |
|
|
28 |
num_classes=args.num_classes+1 |
|
|
29 |
patch_size=args.patch_size |
|
|
30 |
batch_size=64 |
|
|
31 |
patch_info_file=args.patch_info_file |
|
|
32 |
input_dir=args.input_dir |
|
|
33 |
sample_p=args.sample_p |
|
|
34 |
conf_thresh=args.conf_thresh |
|
|
35 |
nms_thresh=args.nms_thresh |
|
|
36 |
anchors=pickle.load(open('anchors.pkl','rb')) |
|
|
37 |
|
|
|
38 |
annotation_file = 'annotations_bbox_{}.pkl'.format(patch_size) |
|
|
39 |
annotations=bb.io.load('pandas',annotation_file) |
|
|
40 |
|
|
|
41 |
if sample_p < 1.: |
|
|
42 |
annotations=annotations.sample(frac=sample_p) |
|
|
43 |
|
|
|
44 |
annotations_dict={} |
|
|
45 |
annotations_dict['train'],annotations_dict['test']=train_test_split(annotations) |
|
|
46 |
annotations_dict['train'],annotations_dict['val']=train_test_split(annotations_dict['train']) |
|
|
47 |
|
|
|
48 |
model=ln.models.Yolo(num_classes=num_classes,anchors=anchors.tolist()) |
|
|
49 |
|
|
|
50 |
loss = ln.network.loss.RegionLoss( |
|
|
51 |
num_classes=model.num_classes, |
|
|
52 |
anchors=model.anchors, |
|
|
53 |
stride=model.stride |
|
|
54 |
) |
|
|
55 |
|
|
|
56 |
transforms = ln.data.transform.Compose([ln.data.transform.RandomHSV( |
|
|
57 |
hue=1, |
|
|
58 |
saturation=2, |
|
|
59 |
value=2 |
|
|
60 |
)]) |
|
|
61 |
|
|
|
62 |
# Create HyperParameters |
|
|
63 |
params = ln.engine.HyperParameters( |
|
|
64 |
network=model, |
|
|
65 |
input_dimension = (patch_size,patch_size), |
|
|
66 |
mini_batch_size=16, |
|
|
67 |
batch_size=batch_size, |
|
|
68 |
max_batches=80000 |
|
|
69 |
) |
|
|
70 |
|
|
|
71 |
post = ln.data.transform.Compose([ |
|
|
72 |
ln.data.transform.GetBoundingBoxes( |
|
|
73 |
num_classes=params.network.num_classes, |
|
|
74 |
anchors=params.network.anchors, |
|
|
75 |
conf_thresh=conf_thresh, |
|
|
76 |
), |
|
|
77 |
|
|
|
78 |
ln.data.transform.NonMaxSuppression( |
|
|
79 |
nms_thresh=nms_thresh |
|
|
80 |
), |
|
|
81 |
|
|
|
82 |
ln.data.transform.TensorToBrambox( |
|
|
83 |
network_size=(patch_size,patch_size), |
|
|
84 |
# class_label_map=class_label_map, |
|
|
85 |
) |
|
|
86 |
]) |
|
|
87 |
|
|
|
88 |
datasets={k:BramboxPathFlowDataset(input_dir,patch_info_file, patch_size, annotations_dict[k], input_dimension=(patch_size,patch_size), class_label_map=None, identify=None, img_transform=None, anno_transform=None) for k in ['train','val','test']} |
|
|
89 |
# transforms |
|
|
90 |
|
|
|
91 |
params.loss = ln.network.loss.RegionLoss(params.network.num_classes, params.network.anchors) |
|
|
92 |
params.optim = torch.optim.SGD(params.network.parameters(), lr=1e-4) |
|
|
93 |
params.scheduler = ln.engine.SchedulerCompositor( |
|
|
94 |
# batch scheduler |
|
|
95 |
(0, torch.optim.lr_scheduler.CosineAnnealingLR(params.optim,T_max=200)) |
|
|
96 |
) |
|
|
97 |
|
|
|
98 |
dls = {k:ln.data.DataLoader( |
|
|
99 |
datasets[k], |
|
|
100 |
batch_size = batch_size, |
|
|
101 |
collate_fn = ln.data.brambox_collate # We want the data to be grouped as a list |
|
|
102 |
) for k in ['train','val','test']} |
|
|
103 |
|
|
|
104 |
params.val_loader=dls['val'] |
|
|
105 |
|
|
|
106 |
class CustomEngine(ln.engine.Engine): |
|
|
107 |
def start(self): |
|
|
108 |
""" Do whatever needs to be done before starting """ |
|
|
109 |
self.params.to(self.device) # Casting parameters to a certain device |
|
|
110 |
self.optim.zero_grad() # Make sure to start with no gradients |
|
|
111 |
self.loss_acc = [] # Loss accumulator |
|
|
112 |
|
|
|
113 |
def process_batch(self, data): |
|
|
114 |
""" Forward and backward pass """ |
|
|
115 |
data, target = data # Unpack |
|
|
116 |
#print(target) |
|
|
117 |
data=data.permute(0,3,1,2).float() |
|
|
118 |
if torch.cuda.is_available(): |
|
|
119 |
data=data.cuda() |
|
|
120 |
|
|
|
121 |
#print(data) |
|
|
122 |
|
|
|
123 |
output = self.network(data) |
|
|
124 |
#print(output) |
|
|
125 |
|
|
|
126 |
loss = self.loss(output, target) |
|
|
127 |
|
|
|
128 |
#print(loss) |
|
|
129 |
loss.backward() |
|
|
130 |
bbox=post(output) |
|
|
131 |
print(bbox) |
|
|
132 |
|
|
|
133 |
self.loss_acc.append(loss.item()) |
|
|
134 |
|
|
|
135 |
@ln.engine.Engine.batch_end(100) # how to pass in validation dataloader |
|
|
136 |
def val_loop(self): |
|
|
137 |
with torch.no_grad(): |
|
|
138 |
for i,data in enumerate(self.val_loader): |
|
|
139 |
if i > 100: |
|
|
140 |
break |
|
|
141 |
data, target = data |
|
|
142 |
data=data.permute(0,3,1,2).float() |
|
|
143 |
if torch.cuda.is_available(): |
|
|
144 |
data=data.cuda() |
|
|
145 |
output = self.network(data) |
|
|
146 |
#print(output) |
|
|
147 |
loss = self.loss(output, target) |
|
|
148 |
print(loss) |
|
|
149 |
bbox=post(output) |
|
|
150 |
print(bbox) |
|
|
151 |
if not i: |
|
|
152 |
bbox_final=[bbox] |
|
|
153 |
else: |
|
|
154 |
bbox_final.append(bbox) |
|
|
155 |
|
|
|
156 |
detections=pd.concat(bbox_final) |
|
|
157 |
print(detections) |
|
|
158 |
print(annotations_dict['val']) |
|
|
159 |
pr=bb.stat.pr(detections, annotations_dict['val'], threshold=0.5) |
|
|
160 |
auc=bb.stat.auc(pr) |
|
|
161 |
print('VAL AUC={}'.format(auc)) |
|
|
162 |
|
|
|
163 |
@ln.engine.Engine.batch_end(300) |
|
|
164 |
def save_model(self): |
|
|
165 |
self.params.save(f'backup-{self.batch}.state.pt') |
|
|
166 |
|
|
|
167 |
def train_batch(self): |
|
|
168 |
""" Weight update and logging """ |
|
|
169 |
self.optim.step() |
|
|
170 |
self.optim.zero_grad() |
|
|
171 |
|
|
|
172 |
batch_loss = sum(self.loss_acc) / len(self.loss_acc) |
|
|
173 |
self.loss_acc = [] |
|
|
174 |
self.log(f'Loss: {batch_loss}') |
|
|
175 |
|
|
|
176 |
def quit(self): |
|
|
177 |
if self.batch >= self.max_batches: # Should probably save weights here |
|
|
178 |
print('Reached end of training') |
|
|
179 |
return True |
|
|
180 |
return False |
|
|
181 |
|
|
|
182 |
|
|
|
183 |
|
|
|
184 |
# Create engine |
|
|
185 |
engine = CustomEngine( |
|
|
186 |
params, dls['train'], # Dataloader (None) is not valid |
|
|
187 |
device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') |
|
|
188 |
) |
|
|
189 |
|
|
|
190 |
for i in range(10): |
|
|
191 |
engine() |