|
a |
|
b/tasks/bt-train.py |
|
|
1 |
""" Barlow Twin self-supervision training |
|
|
2 |
""" |
|
|
3 |
import argparse |
|
|
4 |
import json |
|
|
5 |
import math |
|
|
6 |
import os |
|
|
7 |
import random |
|
|
8 |
import signal |
|
|
9 |
import subprocess |
|
|
10 |
import sys |
|
|
11 |
import time |
|
|
12 |
from tqdm import tqdm |
|
|
13 |
|
|
|
14 |
from torch import nn, optim |
|
|
15 |
import torch |
|
|
16 |
import torchvision |
|
|
17 |
import torchinfo |
|
|
18 |
import numpy as np |
|
|
19 |
from sklearn.model_selection import train_test_split as sk_train_test_split |
|
|
20 |
|
|
|
21 |
sys.path.append(os.getcwd()) |
|
|
22 |
import utilities.runUtils as rutl |
|
|
23 |
import utilities.logUtils as lutl |
|
|
24 |
from algorithms.barlowtwins import BarlowTwins, LARS, adjust_learning_rate |
|
|
25 |
from datacode.natural_image_data import Cifar100Dataset |
|
|
26 |
from datacode.ultrasound_data import FetalUSFramesDataset, ClassifyDataFromCSV |
|
|
27 |
from datacode.augmentations import BarlowTwinsTransformOrig, ClassifierTransform |
|
|
28 |
|
|
|
29 |
|
|
|
30 |
print(f"Pytorch version: {torch.__version__}") |
|
|
31 |
print(f"cuda version: {torch.version.cuda}") |
|
|
32 |
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
33 |
print("Device Used:", device) |
|
|
34 |
|
|
|
35 |
###============================= Configure and Setup =========================== |
|
|
36 |
|
|
|
37 |
CFG = rutl.ObjDict( |
|
|
38 |
use_amp = True, #automatic Mixed precision |
|
|
39 |
|
|
|
40 |
datapath = "/home/USR/WERK/data/a.hdf5", |
|
|
41 |
valdatapath = "/home/USR/WERK/valdata/b.hdf5", |
|
|
42 |
skip_count = 5, |
|
|
43 |
|
|
|
44 |
epochs = 1000, |
|
|
45 |
batch_size = 2048, |
|
|
46 |
workers = 24, |
|
|
47 |
image_size = 256, |
|
|
48 |
|
|
|
49 |
learning_rate_weights = 0.2, |
|
|
50 |
learning_rate_biases = 0.0048, |
|
|
51 |
weight_decay = 1e-6, |
|
|
52 |
lmbd = 0.0051, |
|
|
53 |
|
|
|
54 |
featx_arch = "resnet50", # "resnet34/50/101" |
|
|
55 |
featx_pretrain = None, # "IMGNET-1K" or None |
|
|
56 |
projector = [8192,8192,8192], |
|
|
57 |
|
|
|
58 |
print_freq_step = 10 , #steps |
|
|
59 |
ckpt_freq_epoch = 5, #epochs |
|
|
60 |
valid_freq_epoch = 5, #epochs |
|
|
61 |
disable_tqdm = False, #True--> to disable |
|
|
62 |
|
|
|
63 |
checkpoint_dir = "hypotheses/-dummy/ssl-barlow/", |
|
|
64 |
resume_training = False, |
|
|
65 |
) |
|
|
66 |
|
|
|
67 |
## -------- |
|
|
68 |
parser = argparse.ArgumentParser(description='Barlow Twins Training') |
|
|
69 |
parser.add_argument('--load-json', type=str, metavar='JSON', |
|
|
70 |
help='Load settings from file in json format. Command line options override values in file.') |
|
|
71 |
|
|
|
72 |
args = parser.parse_args() |
|
|
73 |
|
|
|
74 |
if args.load_json: |
|
|
75 |
with open(args.load_json, 'rt') as f: |
|
|
76 |
CFG.__dict__.update(json.load(f)) |
|
|
77 |
|
|
|
78 |
### ---------------------------------------------------------------------------- |
|
|
79 |
CFG.gLogPath = CFG.checkpoint_dir |
|
|
80 |
CFG.gWeightPath = CFG.checkpoint_dir + '/weights/' |
|
|
81 |
|
|
|
82 |
### ============================================================================ |
|
|
83 |
|
|
|
84 |
|
|
|
85 |
def getDataLoaders(): |
|
|
86 |
""" Unlabelled SSL Dataset |
|
|
87 |
""" |
|
|
88 |
|
|
|
89 |
transform_obj = BarlowTwinsTransformOrig(image_size=CFG.image_size) |
|
|
90 |
|
|
|
91 |
traindataset = FetalUSFramesDataset( hdf5_file= CFG.datapath, |
|
|
92 |
transform = transform_obj, |
|
|
93 |
load2ram = False, frame_skip=CFG.skip_count) |
|
|
94 |
|
|
|
95 |
|
|
|
96 |
trainloader = torch.utils.data.DataLoader( traindataset, shuffle=True, |
|
|
97 |
batch_size=CFG.batch_size, num_workers=CFG.workers, |
|
|
98 |
pin_memory=True) |
|
|
99 |
|
|
|
100 |
validdataset = FetalUSFramesDataset( hdf5_file= CFG.valdatapath, |
|
|
101 |
transform = transform_obj, |
|
|
102 |
load2ram = False, frame_skip=CFG.skip_count) |
|
|
103 |
|
|
|
104 |
|
|
|
105 |
validloader = torch.utils.data.DataLoader( validdataset, shuffle=False, |
|
|
106 |
batch_size=CFG.batch_size, num_workers=CFG.workers, |
|
|
107 |
pin_memory=True) |
|
|
108 |
|
|
|
109 |
|
|
|
110 |
lutl.LOG2DICTXT({"TRAIN DatasetClass":traindataset.get_info(), |
|
|
111 |
"TransformsClass": str(transform_obj.get_composition()), |
|
|
112 |
}, CFG.gLogPath +'/misc.txt') |
|
|
113 |
lutl.LOG2DICTXT({"VALID DatasetClass":validdataset.get_info(), |
|
|
114 |
"TransformsClass": str(transform_obj.get_composition()), |
|
|
115 |
}, CFG.gLogPath +'/misc.txt') |
|
|
116 |
|
|
|
117 |
return trainloader, validloader |
|
|
118 |
|
|
|
119 |
|
|
|
120 |
def getModelnOptimizer(): |
|
|
121 |
model = BarlowTwins(featx_arch=CFG.featx_arch, |
|
|
122 |
projector_sizes=CFG.projector, |
|
|
123 |
batch_size=CFG.batch_size, |
|
|
124 |
lmbd=CFG.lmbd, |
|
|
125 |
pretrained=CFG.featx_pretrain).to(device) |
|
|
126 |
|
|
|
127 |
optimizer = LARS(model.parameters(), lr=0, weight_decay=CFG.weight_decay, |
|
|
128 |
weight_decay_filter=True, lars_adaptation_filter=True) |
|
|
129 |
|
|
|
130 |
model_info = torchinfo.summary(model, 2*[(1, 3, CFG.image_size, CFG.image_size)], |
|
|
131 |
verbose=0) |
|
|
132 |
lutl.LOG2TXT(model_info, CFG.gLogPath +'/misc.txt', console= False) |
|
|
133 |
|
|
|
134 |
return model.to(device), optimizer |
|
|
135 |
|
|
|
136 |
|
|
|
137 |
### ---------------------------------------------------------------------------- |
|
|
138 |
|
|
|
139 |
def simple_main(): |
|
|
140 |
### SETUP |
|
|
141 |
rutl.START_SEED() |
|
|
142 |
torch.cuda.device(device) |
|
|
143 |
torch.backends.cudnn.benchmark = True |
|
|
144 |
|
|
|
145 |
if os.path.exists(CFG.checkpoint_dir) and (not CFG.resume_training): |
|
|
146 |
raise Exception("CheckPoint folder already exists and restart_training not enabled; Somethings Wrong!") |
|
|
147 |
if not os.path.exists(CFG.gWeightPath): os.makedirs(CFG.gWeightPath) |
|
|
148 |
|
|
|
149 |
with open(CFG.gLogPath+"/exp_config.json", 'a') as f: |
|
|
150 |
json.dump(vars(CFG), f, indent=4) |
|
|
151 |
|
|
|
152 |
|
|
|
153 |
### DATA ACCESS |
|
|
154 |
trainloader, validloader = getDataLoaders() |
|
|
155 |
|
|
|
156 |
### MODEL, OPTIM |
|
|
157 |
model, optimizer = getModelnOptimizer() |
|
|
158 |
|
|
|
159 |
## Automatically resume from checkpoint if it exists and enabled |
|
|
160 |
ckpt = None |
|
|
161 |
if CFG.resume_training: |
|
|
162 |
try: ckpt = torch.load(CFG.gWeightPath+'/checkpoint-1.pth', map_location='cpu') |
|
|
163 |
except: |
|
|
164 |
try:ckpt = torch.load(CFG.gWeightPath+'/checkpoint-0.pth', map_location='cpu') |
|
|
165 |
except: print("Check points are not loadable. Starting fresh...") |
|
|
166 |
if ckpt: |
|
|
167 |
start_epoch = ckpt['epoch'] |
|
|
168 |
model.load_state_dict(ckpt['model']) |
|
|
169 |
optimizer.load_state_dict(ckpt['optimizer']) |
|
|
170 |
lutl.LOG2TXT(f"Restarting Training from EPOCH:{start_epoch} of {CFG.checkpoint_dir}", CFG.gLogPath +'/misc.txt') |
|
|
171 |
else: |
|
|
172 |
start_epoch = 0 |
|
|
173 |
|
|
|
174 |
|
|
|
175 |
### MODEL TRAINING |
|
|
176 |
start_time = time.time() |
|
|
177 |
best_loss = float('inf') |
|
|
178 |
wgt_suf = 0 # foolproof savetime crash |
|
|
179 |
if CFG.use_amp: scaler = torch.cuda.amp.GradScaler() # for mixed precision |
|
|
180 |
|
|
|
181 |
for epoch in range(start_epoch, CFG.epochs): |
|
|
182 |
|
|
|
183 |
## ---- Training Routine ---- |
|
|
184 |
t_running_loss_ = 0 |
|
|
185 |
model.train() |
|
|
186 |
for step, (y1, y2) in tqdm(enumerate(trainloader, |
|
|
187 |
start=epoch * len(trainloader)), |
|
|
188 |
disable=CFG.disable_tqdm): |
|
|
189 |
y1 = y1.to(device, non_blocking=True) |
|
|
190 |
y2 = y2.to(device, non_blocking=True) |
|
|
191 |
|
|
|
192 |
adjust_learning_rate(CFG, optimizer, trainloader, step) |
|
|
193 |
optimizer.zero_grad() |
|
|
194 |
|
|
|
195 |
if CFG.use_amp: ## with mixed precision |
|
|
196 |
with torch.cuda.amp.autocast(): |
|
|
197 |
loss = model.forward(y1, y2) |
|
|
198 |
scaler.scale(loss).backward() |
|
|
199 |
scaler.step(optimizer) |
|
|
200 |
scaler.update() |
|
|
201 |
else: |
|
|
202 |
loss = model.forward(y1, y2) |
|
|
203 |
loss.backward() |
|
|
204 |
optimizer.step() |
|
|
205 |
t_running_loss_+=loss.item() |
|
|
206 |
|
|
|
207 |
if step % CFG.print_freq_step == 0: |
|
|
208 |
stats = dict(epoch=epoch, step=step, |
|
|
209 |
time=int(time.time() - start_time), |
|
|
210 |
step_loss = loss.item(), |
|
|
211 |
lr_weights = optimizer.param_groups[0]['lr'], |
|
|
212 |
lr_biases = optimizer.param_groups[1]['lr'],) |
|
|
213 |
lutl.LOG2DICTXT(stats, CFG.checkpoint_dir +'/train-stats.txt') |
|
|
214 |
train_epoch_loss = t_running_loss_/len(trainloader) |
|
|
215 |
|
|
|
216 |
# save checkpoint |
|
|
217 |
if (epoch+1) % CFG.ckpt_freq_epoch == 0: |
|
|
218 |
wgt_suf = (wgt_suf+1) %2 |
|
|
219 |
state = dict(epoch=epoch, model=model.state_dict(), |
|
|
220 |
optimizer=optimizer.state_dict()) |
|
|
221 |
torch.save(state, CFG.gWeightPath +f'/checkpoint-{wgt_suf}.pth') |
|
|
222 |
|
|
|
223 |
|
|
|
224 |
## ---- Validation Routine ---- |
|
|
225 |
if (epoch+1) % CFG.valid_freq_epoch == 0: |
|
|
226 |
model.eval() |
|
|
227 |
v_running_loss_ = 0 |
|
|
228 |
with torch.no_grad(): |
|
|
229 |
for (y1, y2) in tqdm(validloader, total=len(validloader), |
|
|
230 |
disable=CFG.disable_tqdm): |
|
|
231 |
y1 = y1.to(device, non_blocking=True) |
|
|
232 |
y2 = y2.to(device, non_blocking=True) |
|
|
233 |
loss = model.forward(y1, y2) |
|
|
234 |
v_running_loss_ += loss.item() |
|
|
235 |
valid_epoch_loss = v_running_loss_/len(validloader) |
|
|
236 |
|
|
|
237 |
# just check |
|
|
238 |
best_flag = False |
|
|
239 |
if valid_epoch_loss < best_loss: |
|
|
240 |
best_flag = True |
|
|
241 |
best_loss = valid_epoch_loss |
|
|
242 |
|
|
|
243 |
v_stats = dict(epoch=epoch, best=best_flag, wgt_suf=wgt_suf, |
|
|
244 |
train_loss=train_epoch_loss, |
|
|
245 |
valid_loss=valid_epoch_loss) |
|
|
246 |
lutl.LOG2DICTXT(v_stats, CFG.gLogPath+'/valid-stats.txt') |
|
|
247 |
|
|
|
248 |
|
|
|
249 |
if __name__ == '__main__': |
|
|
250 |
simple_main() |