Diff of /tasks/cls-train.py [000000] .. [a18f15]

Switch to unified view

a b/tasks/cls-train.py
1
""" Classifier Network training
2
"""
3
4
import argparse
5
import json
6
import os
7
import sys
8
import time
9
from tqdm.autonotebook import tqdm
10
11
import torch
12
from torch import nn, optim
13
import torchinfo
14
15
import numpy as np
16
from sklearn.model_selection import train_test_split as sk_train_test_split
17
18
sys.path.append(os.getcwd())
19
import utilities.runUtils as rutl
20
import utilities.logUtils as lutl
21
from utilities.metricUtils import MultiClassMetrics
22
from algorithms.classifiers import ClassifierNet
23
from datacode.ultrasound_data import ClassifyDataFromCSV, get_class_weights
24
from datacode.augmentations import ClassifierTransform
25
26
print(f"Pytorch version: {torch.__version__}")
print(f"cuda version: {torch.version.cuda}")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device Used:", device)

###============================= Configure and Setup ===========================

# Default experiment configuration; JSON file and CLI flags (below) override.
CFG = rutl.ObjDict(
    data_folder  = "/home/joseph.benjamin/WERK/fetal-ultrasound/data/Fetal-UltraSound/US-Planes-Heart-Views-V3",
    balance_data = False, # while loading in dataloader; removed
    seed = 1792,  # previously 73

    epochs        = 100,
    image_size    = 256,
    batch_size    = 128,
    workers       = 16,
    learning_rate = 1e-3,
    weight_decay  = 1e-6,

    featx_arch     = "resnet50",
    featx_pretrain =  "IMAGENET-1K" , # "IMAGENET-1K" or None
    featx_freeze   = False,
    featx_bnorm    = False,
    featx_dropout  = 0.5,
    clsfy_layers   = [5], # first MLP input size will be set w.r.t. FeatureExtractor
    clsfy_dropout  = 0.0,

    checkpoint_dir   = "hypotheses/#dummy/Classify/trail-002",
    disable_tqdm     = False, # True --> to disable
    restart_training = True
)

### ----------------------------------------------------------------------------
# CLI TAKES PRECEDENCE OVER JSON CONFIG
# e.g. CLI overwrites the value set for featx-pretrain in JSON while running;
# without CLI, default values from the dict above will be used.

def _str2bool(value):
    """Parse a boolean CLI flag.

    argparse's `type=bool` is a trap: bool("False") is True, since any
    non-empty string is truthy. Parse the text explicitly instead.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ('true', '1', 'yes', 'y'):
        return True
    if value.lower() in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")

parser = argparse.ArgumentParser(description='Classification task')
parser.add_argument('--load-json', type=str, metavar='JSON',
    help='Load settings from file in json format. Command line options override values in file.')

parser.add_argument('--seed', type=int, metavar='INT',
    help='random seed for reproducible runs')

parser.add_argument('--featx-freeze', type=_str2bool, metavar='BOOL',
    help='freeze pretrain or not')

parser.add_argument('--featx-bnorm', type=_str2bool, metavar='BOOL',
    help='add batchnorm between feature extractor and classifier')

parser.add_argument('--featx-pretrain', type=str, metavar='PATH',
    help='Set from where to load the pretrained weight from')

parser.add_argument('--checkpoint-dir', type=str, metavar='PATH',
    help='Directory where checkpoints, weights and logs are written.')


args = parser.parse_args()

# JSON config (if given) overrides the defaults ...
if args.load_json:
    with open(args.load_json, 'rt') as f:
        CFG.__dict__.update(json.load(f))

# ... and CLI flags override both. Compare against None (not truthiness) so an
# explicit False/0 passed on the command line still takes effect.
for arg in vars(args):
    att = getattr(args, arg)
    if att is not None: CFG.__dict__[arg] = att

### ----------------------------------------------------------------------------
CFG.gLogPath = CFG.checkpoint_dir
CFG.gWeightPath = CFG.checkpoint_dir + '/weights/'

### ============================================================================
99
def getDataLoaders(data_percent=None):
    """Build train/valid DataLoaders and the train-set class weights.

    Args:
        data_percent: optional int; when given and < 100, only a stratified
            `data_percent`% subset of the training data is used (indices are
            logged to train_indices_used.csv under CFG.gLogPath).

    Returns:
        tuple: (trainloader, validloader, class_weights)
    """
    ## Augmentations per split: train-time augmentation vs. plain inference
    transform_train = ClassifierTransform(image_size=CFG.image_size, mode="train")
    transform_valid = ClassifierTransform(image_size=CFG.image_size, mode="infer")

    ## Datasets built from the split CSVs
    train_ds = ClassifyDataFromCSV(CFG.data_folder,
                                   CFG.data_folder + "/trainV3.csv",
                                   transform=transform_train)
    valid_ds = ClassifyDataFromCSV(CFG.data_folder,
                                   CFG.data_folder + "/validV3.csv",
                                   transform=transform_valid)
    # NOTE(review): weights are computed on the FULL train set, even when a
    # percentage subset is selected below — confirm this is intended.
    class_weights, _ = get_class_weights(train_ds.targets, nclasses=5)

    ### Optionally keep only a stratified P% of the train data
    if data_percent and (data_percent < 100):
        _idx, used_idx = sk_train_test_split(np.arange(len(train_ds)),
                                test_size=data_percent / 100,
                                random_state=CFG.seed,
                                stratify=train_ds.targets)
        used_idx = sorted(used_idx)
        train_ds = torch.utils.data.Subset(train_ds, used_idx)
        lutl.LOG2CSV(used_idx, CFG.gLogPath + '/train_indices_used.csv')

    # Seed just before loader creation so shuffle order is reproducible.
    torch.manual_seed(CFG.seed)

    ## DataLoaders (shared settings hoisted into kwargs)
    loader_kwargs = dict(batch_size=CFG.batch_size, num_workers=CFG.workers,
                         pin_memory=True)
    trainloader = torch.utils.data.DataLoader(train_ds, shuffle=True,
                                              **loader_kwargs)
    validloader = torch.utils.data.DataLoader(valid_ds, shuffle=False,
                                              **loader_kwargs)

    lutl.LOG2DICTXT({"Train->": len(train_ds),
                     "class-weights": str(class_weights),
                     "TransformsClass": str(transform_train.get_composition()),
                     }, CFG.gLogPath + '/misc.txt')
    lutl.LOG2DICTXT({"Valid->": len(valid_ds),
                     "TransformsClass": str(transform_valid.get_composition()),
                     }, CFG.gLogPath + '/misc.txt')

    return trainloader, validloader, class_weights
140
141
142
def getModelnOptimizer():
    """Build the ClassifierNet, its AdamW optimizer and (disabled) scheduler.

    CFG.featx_pretrain may be: a filesystem path to a saved checkpoint (loaded
    manually below), a torch pretrain flag string such as "IMAGENET-1K"
    (forwarded to ClassifierNet), or None (no pretraining).

    Returns:
        tuple: (model moved to `device`, optimizer, scheduler) where scheduler
        is False — no LR schedule is configured.
    """

    ## pretrain setting
    m_state = None; torch_pretrain_flag = None
    # Guard the None case explicitly: os.path.isfile(None) raises TypeError,
    # and None is a documented legal value for featx_pretrain.
    if CFG.featx_pretrain and os.path.isfile(CFG.featx_pretrain):
        m_state = torch.load(CFG.featx_pretrain, map_location='cpu')
    else: torch_pretrain_flag = CFG.featx_pretrain

    model = ClassifierNet(arch=CFG.featx_arch,
                    fc_layer_sizes=CFG.clsfy_layers,
                    feature_freeze=CFG.featx_freeze,
                    feature_dropout=CFG.featx_dropout,
                    feature_bnorm=CFG.featx_bnorm,
                    classifier_dropout=CFG.clsfy_dropout,
                    torch_pretrain=torch_pretrain_flag )

    ## load from checkpoint file; strict=False so partial/legacy state dicts
    ## still load (mismatches are logged via ret_msg)
    if m_state is not None:
        m_state = m_state["model"]
        ret_msg = model.load_state_dict(m_state, strict=False)
        lutl.LOG2TXT(f"Manual Pretrain Loaded...{CFG.featx_pretrain},{str(ret_msg)}",
                     CFG.gLogPath +'/misc.txt')

    model_info = torchinfo.summary(model, (1, 3, CFG.image_size, CFG.image_size),
                                verbose=0)
    lutl.LOG2TXT(model_info, CFG.gLogPath +'/misc.txt', console= False)

    ##--------------

    optimizer = optim.AdamW(model.parameters(), lr=CFG.learning_rate,
                        weight_decay=CFG.weight_decay)
    scheduler = False  # placeholder: callers check `if scheduler:` before stepping

    return model.to(device), optimizer, scheduler
176
177
178
def getLossFunc(class_weights):
    """Weighted cross-entropy loss; weights counter train-set class imbalance."""
    weight_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
    return nn.CrossEntropyLoss(weight=weight_tensor)
182
183
184
def simple_main(data_percent=None):
    """Train the classifier and return the log path of this run.

    Args:
        data_percent: optional int percentage of train data to use. When set,
            logs/weights go to a "<checkpoint_dir>/<P>_percent/" subfolder
            (CFG.gLogPath / CFG.gWeightPath are mutated globally).

    Returns:
        str: CFG.gLogPath for this run (suitable input for simple_test).
    """

    ### SETUP
    rutl.START_SEED(CFG.seed)
    gpu = 0
    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    ## paths and logs setup
    if data_percent: CFG.gLogPath = CFG.checkpoint_dir+f"/{data_percent}_percent/"
    CFG.gWeightPath = CFG.gLogPath+"/weights/"

    # Refuse to clobber an existing run unless restart_training is enabled.
    if os.path.exists(CFG.gLogPath) and (not CFG.restart_training):
        raise Exception("CheckPoint folder already exists and restart_training not enabled; Somethings Wrong!",
                        CFG.checkpoint_dir)
    if not os.path.exists(CFG.gWeightPath): os.makedirs(CFG.gWeightPath)

    # NOTE(review): opened in append mode, so restarted runs accumulate
    # multiple JSON dumps in exp_cfg.json — confirm this is intended.
    with open(CFG.gLogPath+"/exp_cfg.json", 'a') as f:
        json.dump(vars(CFG), f, indent=4)


    ### DATA ACCESS
    trainloader, validloader, class_weights  = getDataLoaders(data_percent)

    ### MODEL, OPTIM
    model, optimizer, scheduler = getModelnOptimizer()
    lossfn = getLossFunc(class_weights)


    ## Automatically resume from checkpoint if it exists and enabled
    if os.path.exists(CFG.gWeightPath +'/checkpoint.pth') and CFG.restart_training:
        ckpt = torch.load(CFG.gWeightPath  +'/checkpoint.pth',
                            map_location='cpu')
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        lutl.LOG2TXT(f"Restarting Training from EPOCH:{start_epoch} of {CFG.gLogPath}",  CFG.gLogPath +'/misc.txt')
    else:
        start_epoch = 0

    ### MODEL TRAINING
    start_time = time.time()
    # NOTE(review): best_acc restarts at 0 on resume, so the first post-resume
    # epoch always overwrites bestmodel.pth — confirm this is acceptable.
    best_acc = 0 ; best_loss = float('inf')
    trainMetric = MultiClassMetrics(CFG.gLogPath)
    validMetric = MultiClassMetrics(CFG.gLogPath)

    for epoch in range(start_epoch, CFG.epochs):

        ## ---- Training Routine ----
        model.train()
        for img, tgt in tqdm(trainloader, disable=CFG.disable_tqdm):
            img = img.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            optimizer.zero_grad()
            pred = model.forward(img)
            loss = lossfn(pred, tgt)
            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(),
            #                          max_norm=2.0, norm_type=2)
            optimizer.step()
            trainMetric.add_entry(torch.argmax(pred, dim=1), tgt, loss)
        if scheduler: scheduler.step()

        ## save checkpoint states (epoch+1 so resume continues at next epoch)
        state = dict(epoch=epoch + 1, model=model.state_dict(),
                        optimizer=optimizer.state_dict())
        torch.save(state, CFG.gWeightPath +'/checkpoint.pth')


        ## ---- Validation Routine ----
        model.eval()
        with torch.no_grad():
            for img, tgt in tqdm(validloader, disable=CFG.disable_tqdm):
                img = img.to(device, non_blocking=True)
                tgt = tgt.to(device, non_blocking=True)
                pred = model.forward(img)
                loss = lossfn(pred, tgt)
                validMetric.add_entry(torch.argmax(pred, dim=1), tgt, loss)

        ## Log Metrics TODO Add balanced and F1
        stats = dict(
                epoch=epoch, time=int(time.time() - start_time),
                trainloss = trainMetric.get_loss(),
                trainacc  = trainMetric.get_balanced_accuracy(),
                trainF1   = trainMetric.get_f1score(),
                validloss = validMetric.get_loss(),
                validacc  = validMetric.get_balanced_accuracy(),
                validF1   = validMetric.get_f1score(),
                )
        lutl.LOG2DICTXT(stats, CFG.gLogPath+'/train-stats.txt')


        ## save best model (selection criterion: balanced validation accuracy)
        best_flag = False
        if stats['validacc'] > best_acc:
            torch.save(model.state_dict(), CFG.gWeightPath +'/bestmodel.pth')
            best_acc = stats['validacc']
            best_loss = stats['validloss']
            best_flag = True

        ## Log detailed validation
        detail_stat = dict(
                epoch=epoch, time=int(time.time() - start_time),
                best = best_flag,
                validf1scr  = validMetric.get_f1score(),
                validbalacc = validMetric.get_balanced_accuracy(),
                validacc    = validMetric.get_accuracy(),
                validreport = validMetric.get_class_report(),
                validconfus = validMetric.get_confusion_matrix().tolist(),
            )
        lutl.LOG2DICTXT(detail_stat, CFG.gLogPath+'/validation-details.txt', console=False)

        # Clear per-epoch accumulators; best_flag is forwarded to the valid
        # metric's reset — presumably to persist best-epoch details; verify
        # against MultiClassMetrics.reset.
        trainMetric.reset()
        validMetric.reset(best_flag)

    return CFG.gLogPath
300
301
302
303
def simple_test(saved_logpath):
    """Evaluate the best checkpoint under `saved_logpath` on the test split.

    Loads <saved_logpath>/weights/bestmodel.pth, runs inference over the
    testV3.csv split and appends detailed metrics to test-results.txt.
    """

    ### SETUP
    rutl.START_SEED()
    gpu = 0
    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    ### DATA ACCESS
    test_transforms = ClassifierTransform(image_size=CFG.image_size,
                                          mode="infer")
    testdataset = ClassifyDataFromCSV(CFG.data_folder,
                                      CFG.data_folder + "/testV3.csv",
                                      transform=test_transforms)
    testloader = torch.utils.data.DataLoader(testdataset,
                                             shuffle=False,
                                             batch_size=CFG.batch_size,
                                             num_workers=CFG.workers,
                                             pin_memory=True)
    lutl.LOG2DICTXT({"TEST->": len(testdataset),
                     "TransformsClass": str(test_transforms.get_composition()),
                     }, saved_logpath + '/test-results.txt')

    ### MODEL: rebuild the network, then restore the best validation weights
    model = ClassifierNet(arch=CFG.featx_arch,
                          fc_layer_sizes=CFG.clsfy_layers,
                          feature_freeze=CFG.featx_freeze,
                          feature_dropout=CFG.featx_dropout,
                          feature_bnorm=CFG.featx_bnorm,
                          classifier_dropout=CFG.clsfy_dropout)
    model = model.to(device)
    model.load_state_dict(torch.load(saved_logpath + "/weights/bestmodel.pth"))

    ### MODEL TESTING
    testMetric = MultiClassMetrics(saved_logpath)
    model.eval()

    start_time = time.time()
    with torch.no_grad():
        for img, tgt in tqdm(testloader, disable=CFG.disable_tqdm):
            img = img.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            pred = model.forward(img)
            testMetric.add_entry(torch.argmax(pred, dim=1), tgt)

        ## Log detailed test metrics (confusion matrix also saved as PNG)
        detail_stat = dict(
            timetaken  = int(time.time() - start_time),
            testf1scr  = testMetric.get_f1score(),
            testbalacc = testMetric.get_balanced_accuracy(),
            testacc    = testMetric.get_accuracy(),
            testreport = testMetric.get_class_report(),
            testconfus = testMetric.get_confusion_matrix(
                                    save_png=True, title="test").tolist(),
        )
        lutl.LOG2DICTXT(detail_stat, saved_logpath + '/test-results.txt',
                        console=True)

        testMetric._write_predictions(title="test")
363
364
365
366
if __name__ == '__main__':

    # Data-efficiency sweep: train and test on decreasing fractions of the
    # training set. (For a single full run: logpth = simple_main();
    # simple_test(logpth).)
    for percent in [100, 50, 25, 10, 5, 1]:
        run_logpath = simple_main(data_percent=percent)
        simple_test(run_logpath)