b/tasks/cls-train.py

""" Classifier Network training
"""

import argparse
import json
import os
import sys
import time
from tqdm.autonotebook import tqdm

import torch
from torch import nn, optim
import torchinfo

import numpy as np
from sklearn.model_selection import train_test_split as sk_train_test_split

sys.path.append(os.getcwd())
import utilities.runUtils as rutl
import utilities.logUtils as lutl
from utilities.metricUtils import MultiClassMetrics
from algorithms.classifiers import ClassifierNet
from datacode.ultrasound_data import ClassifyDataFromCSV, get_class_weights
from datacode.augmentations import ClassifierTransform

print(f"Pytorch version: {torch.__version__}")
print(f"cuda version: {torch.version.cuda}")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device Used:", device)

###============================= Configure and Setup ===========================

CFG = rutl.ObjDict(
    data_folder = "/home/joseph.benjamin/WERK/fetal-ultrasound/data/Fetal-UltraSound/US-Planes-Heart-Views-V3",
    balance_data = False,   # while loading in dataloader; removed
    seed = 1792,            # previously 73

    epochs = 100,
    image_size = 256,
    batch_size = 128,
    workers = 16,
    learning_rate = 1e-3,
    weight_decay = 1e-6,

    featx_arch = "resnet50",
    featx_pretrain = "IMAGENET-1K",  # "IMAGENET-1K", a checkpoint file path, or None
    featx_freeze = False,
    featx_bnorm = False,
    featx_dropout = 0.5,
    clsfy_layers = [5],      # input size of the first MLP layer is set w.r.t. the FeatureExtractor
    clsfy_dropout = 0.0,

    checkpoint_dir = "hypotheses/#dummy/Classify/trail-002",
    disable_tqdm = False,    # True --> disable the progress bars
    restart_training = True
)
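
# A JSON file passed via --load-json only needs the keys it wants to override,
# e.g. (hypothetical contents, for illustration only):
#   {"learning_rate": 5e-4, "featx_freeze": true, "checkpoint_dir": "hypotheses/run-001"}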

### ----------------------------------------------------------------------------
# CLI TAKES PRECEDENCE OVER JSON CONFIG
# e.g. a CLI flag overwrites the value set for featx-pretrain in the JSON file;
# without CLI overrides, the default values from the dict above are used.

parser = argparse.ArgumentParser(description='Classification task')
parser.add_argument('--load-json', type=str, metavar='JSON',
    help='Load settings from file in json format. Command line options override values in file.')

parser.add_argument('--seed', type=int, metavar='INT',
    help='random seed for the experiment')

parser.add_argument('--featx-freeze', type=bool, metavar='BOOL',
    help='freeze the pretrained feature extractor or not')

parser.add_argument('--featx-bnorm', type=bool, metavar='BOOL',
    help='add batchnorm between feature extractor and classifier')

parser.add_argument('--featx-pretrain', type=str, metavar='PATH',
    help='where to load the pretrained weights from: "IMAGENET-1K" or a checkpoint path')

parser.add_argument('--checkpoint-dir', type=str, metavar='PATH',
    help='directory where logs and weight checkpoints are written')


args = parser.parse_args()

if args.load_json:
    with open(args.load_json, 'rt') as f:
        CFG.__dict__.update(json.load(f))

for arg in vars(args):
    att = getattr(args, arg)
    if att: CFG.__dict__[arg] = att
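
# Example invocation (paths are illustrative only):
#   python tasks/cls-train.py --load-json configs/classify.json \
#       --featx-pretrain IMAGENET-1K --checkpoint-dir hypotheses/#dummy/Classify/trail-002
# Caveats of the simple merge above: argparse's type=bool turns any non-empty
# string into True, and the `if att:` check means falsy CLI values (0, False, "")
# never override the JSON/default config.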

### ----------------------------------------------------------------------------
CFG.gLogPath = CFG.checkpoint_dir
CFG.gWeightPath = CFG.checkpoint_dir + '/weights/'

### ============================================================================

def getDataLoaders(data_percent=None):
    ## Augmentations
    train_transforms = ClassifierTransform(image_size=CFG.image_size, mode="train")
    valid_transforms = ClassifierTransform(image_size=CFG.image_size, mode="infer")

    ## Dataset Class
    traindataset = ClassifyDataFromCSV(CFG.data_folder,
                                       CFG.data_folder+"/trainV3.csv",
                                       transform = train_transforms,)
    validdataset = ClassifyDataFromCSV(CFG.data_folder,
                                       CFG.data_folder+"/validV3.csv",
                                       transform = valid_transforms,)
    class_weights, _ = get_class_weights(traindataset.targets, nclasses=5)

    ### Choose P% of data from train data
    if data_percent and (data_percent < 100):
        _idx, used_idx = sk_train_test_split( np.arange(len(traindataset)),
                            test_size=data_percent/100, random_state=CFG.seed,
                            stratify=traindataset.targets)
        traindataset = torch.utils.data.Subset(traindataset, sorted(used_idx))
        lutl.LOG2CSV(sorted(used_idx), CFG.gLogPath +'/train_indices_used.csv')

    torch.manual_seed(CFG.seed)
    ## Loaders Class
    trainloader = torch.utils.data.DataLoader( traindataset, shuffle=True,
                        batch_size=CFG.batch_size, num_workers=CFG.workers,
                        pin_memory=True)

    validloader = torch.utils.data.DataLoader( validdataset, shuffle=False,
                        batch_size=CFG.batch_size, num_workers=CFG.workers,
                        pin_memory=True)

    lutl.LOG2DICTXT({"Train->":len(traindataset),
                     "class-weights":str(class_weights),
                     "TransformsClass": str(train_transforms.get_composition()),
                    },CFG.gLogPath +'/misc.txt')
    lutl.LOG2DICTXT({"Valid->":len(validdataset),
                     "TransformsClass": str(valid_transforms.get_composition()),
                    },CFG.gLogPath +'/misc.txt')

    return trainloader, validloader, class_weights


def getModelnOptimizer():

    ## pretrain setting
    m_state = 0; torch_pretrain_flag = None
    if CFG.featx_pretrain and os.path.isfile(CFG.featx_pretrain):
        m_state = torch.load(CFG.featx_pretrain, map_location='cpu')
    else: torch_pretrain_flag = CFG.featx_pretrain

    model = ClassifierNet(arch=CFG.featx_arch,
                          fc_layer_sizes=CFG.clsfy_layers,
                          feature_freeze=CFG.featx_freeze,
                          feature_dropout=CFG.featx_dropout,
                          feature_bnorm=CFG.featx_bnorm,
                          classifier_dropout=CFG.clsfy_dropout,
                          torch_pretrain=torch_pretrain_flag )

    ## load from checkpoints
    if m_state:
        m_state = m_state["model"]
        ret_msg = model.load_state_dict(m_state, strict=False)
        lutl.LOG2TXT(f"Manual Pretrain Loaded...{CFG.featx_pretrain},{str(ret_msg)}",
                     CFG.gLogPath +'/misc.txt')

    model_info = torchinfo.summary(model, (1, 3, CFG.image_size, CFG.image_size),
                                   verbose=0)
    lutl.LOG2TXT(model_info, CFG.gLogPath +'/misc.txt', console= False)

    ##--------------

    optimizer = optim.AdamW(model.parameters(), lr=CFG.learning_rate,
                            weight_decay=CFG.weight_decay)
    scheduler = False
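    # No LR scheduler is used by default; the training loop only calls
    # scheduler.step() when this is truthy. A standard PyTorch scheduler could be
    # plugged in here instead, e.g. (illustrative only, not part of the original setup):
    #   scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.epochs)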

    return model.to(device), optimizer, scheduler


def getLossFunc(class_weights):
    lossfn = nn.CrossEntropyLoss(weight=torch.tensor(class_weights,
                                 dtype=torch.float32).to(device) )
    return lossfn


def simple_main(data_percent=None):

    ### SETUP
    rutl.START_SEED(CFG.seed)
    gpu = 0
    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    ## paths and logs setup
    if data_percent: CFG.gLogPath = CFG.checkpoint_dir+f"/{data_percent}_percent/"
    CFG.gWeightPath = CFG.gLogPath+"/weights/"

    if os.path.exists(CFG.gLogPath) and (not CFG.restart_training):
        raise Exception("Checkpoint folder already exists and restart_training is not enabled; something is wrong!",
                        CFG.checkpoint_dir)
    if not os.path.exists(CFG.gWeightPath): os.makedirs(CFG.gWeightPath)

    with open(CFG.gLogPath+"/exp_cfg.json", 'a') as f:
        json.dump(vars(CFG), f, indent=4)


    ### DATA ACCESS
    trainloader, validloader, class_weights = getDataLoaders(data_percent)

    ### MODEL, OPTIM
    model, optimizer, scheduler = getModelnOptimizer()
    lossfn = getLossFunc(class_weights)


    ## Automatically resume from checkpoint if it exists and enabled
    if os.path.exists(CFG.gWeightPath +'/checkpoint.pth') and CFG.restart_training:
        ckpt = torch.load(CFG.gWeightPath +'/checkpoint.pth',
                          map_location='cpu')
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        lutl.LOG2TXT(f"Restarting Training from EPOCH:{start_epoch} of {CFG.gLogPath}", CFG.gLogPath +'/misc.txt')
    else:
        start_epoch = 0

    ### MODEL TRAINING
    start_time = time.time()
    best_acc = 0 ; best_loss = float('inf')
    trainMetric = MultiClassMetrics(CFG.gLogPath)
    validMetric = MultiClassMetrics(CFG.gLogPath)

    for epoch in range(start_epoch, CFG.epochs):

        ## ---- Training Routine ----
        model.train()
        for img, tgt in tqdm(trainloader, disable=CFG.disable_tqdm):
            img = img.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            optimizer.zero_grad()
            pred = model.forward(img)
            loss = lossfn(pred, tgt)
            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(),
            #                          max_norm=2.0, norm_type=2)
            optimizer.step()
            trainMetric.add_entry(torch.argmax(pred, dim=1), tgt, loss)
        if scheduler: scheduler.step()

        ## save checkpoint states
        state = dict(epoch=epoch + 1, model=model.state_dict(),
                     optimizer=optimizer.state_dict())
        torch.save(state, CFG.gWeightPath +'/checkpoint.pth')


        ## ---- Validation Routine ----
        model.eval()
        with torch.no_grad():
            for img, tgt in tqdm(validloader, disable=CFG.disable_tqdm):
                img = img.to(device, non_blocking=True)
                tgt = tgt.to(device, non_blocking=True)
                pred = model.forward(img)
                loss = lossfn(pred, tgt)
                validMetric.add_entry(torch.argmax(pred, dim=1), tgt, loss)

        ## Log Metrics TODO Add balanced and F1
        stats = dict(
            epoch=epoch, time=int(time.time() - start_time),
            trainloss = trainMetric.get_loss(),
            trainacc = trainMetric.get_balanced_accuracy(),
            trainF1 = trainMetric.get_f1score(),
            validloss = validMetric.get_loss(),
            validacc = validMetric.get_balanced_accuracy(),
            validF1 = validMetric.get_f1score(),
        )
        lutl.LOG2DICTXT(stats, CFG.gLogPath+'/train-stats.txt')

        ## save best model
        best_flag = False
        if stats['validacc'] > best_acc:
            torch.save(model.state_dict(), CFG.gWeightPath +'/bestmodel.pth')
            best_acc = stats['validacc']
            best_loss = stats['validloss']
            best_flag = True

        ## Log detailed validation
        detail_stat = dict(
            epoch=epoch, time=int(time.time() - start_time),
            best = best_flag,
            validf1scr = validMetric.get_f1score(),
            validbalacc = validMetric.get_balanced_accuracy(),
            validacc = validMetric.get_accuracy(),
            validreport = validMetric.get_class_report(),
            validconfus = validMetric.get_confusion_matrix().tolist(),
        )
        lutl.LOG2DICTXT(detail_stat, CFG.gLogPath+'/validation-details.txt', console=False)

        trainMetric.reset()
        validMetric.reset(best_flag)

    return CFG.gLogPath



def simple_test(saved_logpath):

    ### SETUP
    rutl.START_SEED()
    gpu = 0
    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    ### DATA ACCESS
    test_transforms = ClassifierTransform(image_size=CFG.image_size,
                                          mode="infer")
    testdataset = ClassifyDataFromCSV( CFG.data_folder,
                                       CFG.data_folder+"/testV3.csv",
                                       transform = test_transforms,)
    testloader = torch.utils.data.DataLoader( testdataset,
                                              shuffle=False,
                                              batch_size=CFG.batch_size,
                                              num_workers=CFG.workers,
                                              pin_memory=True)
    lutl.LOG2DICTXT({"TEST->":len(testdataset),
                     "TransformsClass": str(test_transforms.get_composition()),
                    },saved_logpath +'/test-results.txt')

    ### MODEL
    model = ClassifierNet(arch=CFG.featx_arch,
                          fc_layer_sizes=CFG.clsfy_layers,
                          feature_freeze=CFG.featx_freeze,
                          feature_dropout=CFG.featx_dropout,
                          feature_bnorm=CFG.featx_bnorm,
                          classifier_dropout=CFG.clsfy_dropout)
    model = model.to(device)
    model.load_state_dict(torch.load(saved_logpath+"/weights/bestmodel.pth"))
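    # The weights loaded here are from the epoch with the best validation
    # balanced accuracy (saved as bestmodel.pth in simple_main).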


    ### MODEL TESTING
    testMetric = MultiClassMetrics(saved_logpath)
    model.eval()

    start_time = time.time()
    with torch.no_grad():
        for img, tgt in tqdm(testloader, disable=CFG.disable_tqdm):
            img = img.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            pred = model.forward(img)
            testMetric.add_entry(torch.argmax(pred, dim=1), tgt)

    ## Log detailed test results
    detail_stat = dict(
        timetaken = int(time.time() - start_time),
        testf1scr = testMetric.get_f1score(),
        testbalacc = testMetric.get_balanced_accuracy(),
        testacc = testMetric.get_accuracy(),
        testreport = testMetric.get_class_report(),
        testconfus = testMetric.get_confusion_matrix(
                        save_png= True, title="test").tolist(),
    )
    lutl.LOG2DICTXT(detail_stat, saved_logpath+'/test-results.txt',
                    console=True)

    testMetric._write_predictions(title="test")


if __name__ == '__main__':

    # logpth = simple_main()
    # simple_test(logpth)
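
    # Data-efficiency sweep: each fraction below trains in its own sub-folder of
    # CFG.checkpoint_dir (e.g. "<checkpoint_dir>/25_percent/", set in simple_main)
    # and is then evaluated on the held-out test split.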
    for p in [100, 50, 25, 10, 5, 1]:
        logpth = simple_main(data_percent=p)
        simple_test(logpth)