|
a |
|
b/tasks/simclr-train.py |
|
|
1 |
""" SimCLR self-supervision training |
|
|
2 |
""" |
|
|
3 |
import argparse |
|
|
4 |
import json |
|
|
5 |
import os |
|
|
6 |
import sys |
|
|
7 |
import time |
|
|
8 |
from tqdm import tqdm |
|
|
9 |
|
|
|
10 |
import torch |
|
|
11 |
import torchinfo |
|
|
12 |
|
|
|
13 |
sys.path.append(os.getcwd()) |
|
|
14 |
import utilities.runUtils as rutl |
|
|
15 |
import utilities.logUtils as lutl |
|
|
16 |
from algorithms.simclr import SimCLR, LARS |
|
|
17 |
from datacode.natural_image_data import Cifar100Dataset |
|
|
18 |
from datacode.ultrasound_data import FetalUSFramesDataset |
|
|
19 |
from datacode.augmentations import SimCLRTransform |
|
|
20 |
|
|
|
21 |
print(f"Pytorch version: {torch.__version__}") |
|
|
22 |
print(f"cuda version: {torch.version.cuda}") |
|
|
23 |
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
24 |
print("Device Used:", device) |
|
|
25 |
|
|
|
26 |
###============================= Configure and Setup =========================== |
|
|
27 |
|
|
|
28 |
CFG = rutl.ObjDict( |
|
|
29 |
use_amp = True, #automatic Mixed precision |
|
|
30 |
|
|
|
31 |
datapath = "/home/mothilal.asokan/Downloads/HC701/Project/US-Fetal-Video-Frames_V1-1/train-all-frames.hdf5", |
|
|
32 |
valdatapath = "/home/mothilal.asokan/Downloads/HC701/Project/US-Fetal-Video-Frames_V1-1/valid-all-frames.hdf5", |
|
|
33 |
skip_count = 5, |
|
|
34 |
|
|
|
35 |
epochs = 10, |
|
|
36 |
batch_size = 160, |
|
|
37 |
workers = 24, |
|
|
38 |
image_size = 256, |
|
|
39 |
|
|
|
40 |
weight_decay = 1e-4, |
|
|
41 |
temp = 0.5, |
|
|
42 |
lr = 0.3, |
|
|
43 |
|
|
|
44 |
featx_arch = "resnet50", # "resnet34/50/101" |
|
|
45 |
featx_pretrain = None, # "IMGNET-1K" or None |
|
|
46 |
projector = [512, 128], |
|
|
47 |
|
|
|
48 |
print_freq_step = 10, #steps |
|
|
49 |
ckpt_freq_epoch = 5, #epochs |
|
|
50 |
valid_freq_epoch = 5, #epochs |
|
|
51 |
disable_tqdm = False, #True--> to disable |
|
|
52 |
|
|
|
53 |
checkpoint_dir= "hypotheses/-dummy/ssl-simclr/", |
|
|
54 |
resume_training = True, |
|
|
55 |
) |
|
|
56 |
|
|
|
57 |
## -------- |
|
|
58 |
parser = argparse.ArgumentParser(description='SimCLR Training') |
|
|
59 |
parser.add_argument('--load-json', type=str, metavar='JSON', |
|
|
60 |
help='Load settings from file in json format. Command line options override values in python file.') |
|
|
61 |
|
|
|
62 |
|
|
|
63 |
|
|
|
64 |
args = parser.parse_args() |
|
|
65 |
|
|
|
66 |
if args.load_json: |
|
|
67 |
with open(args.load_json, 'rt') as f: |
|
|
68 |
CFG.__dict__.update(json.load(f)) |
|
|
69 |
|
|
|
70 |
### ---------------------------------------------------------------------------- |
|
|
71 |
CFG.gLogPath = CFG.checkpoint_dir |
|
|
72 |
CFG.gWeightPath = CFG.checkpoint_dir + '/weights/' |
|
|
73 |
|
|
|
74 |
### ============================================================================ |
|
|
75 |
|
|
|
76 |
def getDataLoaders(): |
|
|
77 |
|
|
|
78 |
transform_obj = SimCLRTransform(image_size=CFG.image_size) |
|
|
79 |
|
|
|
80 |
traindataset = FetalUSFramesDataset( hdf5_file= CFG.datapath, |
|
|
81 |
transform = transform_obj, |
|
|
82 |
load2ram = False, frame_skip=CFG.skip_count) |
|
|
83 |
|
|
|
84 |
|
|
|
85 |
trainloader = torch.utils.data.DataLoader( traindataset, shuffle=True, |
|
|
86 |
batch_size=CFG.batch_size, num_workers=CFG.workers, |
|
|
87 |
pin_memory=True,drop_last=True ) |
|
|
88 |
|
|
|
89 |
# val_transform_obj = SimCLREvalDataTransform(image_size=CFG.image_size) |
|
|
90 |
|
|
|
91 |
validdataset = FetalUSFramesDataset( hdf5_file= CFG.valdatapath, |
|
|
92 |
transform = transform_obj, |
|
|
93 |
load2ram = False, frame_skip=CFG.skip_count) |
|
|
94 |
|
|
|
95 |
|
|
|
96 |
validloader = torch.utils.data.DataLoader( validdataset, shuffle=False, |
|
|
97 |
batch_size=CFG.batch_size, num_workers=CFG.workers, |
|
|
98 |
pin_memory=True, drop_last=True) |
|
|
99 |
|
|
|
100 |
|
|
|
101 |
lutl.LOG2DICTXT({"TRAIN DatasetClass":traindataset.get_info(), |
|
|
102 |
"TransformsClass": str(transform_obj.get_composition()), |
|
|
103 |
}, CFG.gLogPath +'/misc.txt') |
|
|
104 |
lutl.LOG2DICTXT({"VALID DatasetClass":validdataset.get_info(), |
|
|
105 |
"TransformsClass": str(transform_obj.get_composition()), |
|
|
106 |
}, CFG.gLogPath +'/misc.txt') |
|
|
107 |
|
|
|
108 |
return trainloader, validloader |
|
|
109 |
|
|
|
110 |
def getModelnOptimizer(): |
|
|
111 |
model = SimCLR(featx_arch=CFG.featx_arch, |
|
|
112 |
projector_sizes=CFG.projector, |
|
|
113 |
batch_size=CFG.batch_size, |
|
|
114 |
temp=CFG.temp, |
|
|
115 |
pretrained=CFG.featx_pretrain).to(device) |
|
|
116 |
|
|
|
117 |
optimizer = LARS(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, |
|
|
118 |
momentum=0.9) |
|
|
119 |
|
|
|
120 |
|
|
|
121 |
model_info = torchinfo.summary(model, 2*[(CFG.batch_size, 3, CFG.image_size, CFG.image_size)], |
|
|
122 |
verbose=0) |
|
|
123 |
lutl.LOG2TXT(model_info, CFG.gLogPath +'/misc.txt', console= False) |
|
|
124 |
|
|
|
125 |
return model.to(device), optimizer |
|
|
126 |
|
|
|
127 |
|
|
|
128 |
### ---------------------------------------------------------------------------- |
|
|
129 |
|
|
|
130 |
def simple_main(): |
|
|
131 |
### SETUP |
|
|
132 |
rutl.START_SEED() |
|
|
133 |
torch.cuda.device(device) |
|
|
134 |
torch.backends.cudnn.benchmark = True |
|
|
135 |
|
|
|
136 |
if os.path.exists(CFG.checkpoint_dir) and (not CFG.resume_training): |
|
|
137 |
raise Exception("CheckPoint folder already exists and restart_training not enabled; Somethings Wrong!") |
|
|
138 |
if not os.path.exists(CFG.gWeightPath): os.makedirs(CFG.gWeightPath) |
|
|
139 |
|
|
|
140 |
with open(CFG.gLogPath+"/exp_config.json", 'a') as f: |
|
|
141 |
json.dump(vars(CFG), f, indent=4) |
|
|
142 |
|
|
|
143 |
|
|
|
144 |
### DATA ACCESS |
|
|
145 |
trainloader, validloader = getDataLoaders() |
|
|
146 |
|
|
|
147 |
### MODEL, OPTIM |
|
|
148 |
model, optimizer = getModelnOptimizer() |
|
|
149 |
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, |
|
|
150 |
len(trainloader), eta_min=0,last_epoch=-1) |
|
|
151 |
## Automatically resume from checkpoint if it exists and enabled |
|
|
152 |
ckpt = None |
|
|
153 |
if CFG.resume_training: |
|
|
154 |
try: ckpt = torch.load(CFG.gWeightPath+'/checkpoint-1.pth', map_location='cpu') |
|
|
155 |
except: |
|
|
156 |
try:ckpt = torch.load(CFG.gWeightPath+'/checkpoint-0.pth', map_location='cpu') |
|
|
157 |
except: print("Check points are not loadable. Starting fresh...") |
|
|
158 |
if ckpt: |
|
|
159 |
start_epoch = ckpt['epoch'] |
|
|
160 |
model.load_state_dict(ckpt['model']) |
|
|
161 |
optimizer.load_state_dict(ckpt['optimizer']) |
|
|
162 |
lutl.LOG2TXT(f"Restarting Training from EPOCH:{start_epoch} of {CFG.checkpoint_dir}", CFG.gLogPath +'/misc.txt') |
|
|
163 |
else: |
|
|
164 |
start_epoch = 0 |
|
|
165 |
|
|
|
166 |
|
|
|
167 |
### MODEL TRAINING |
|
|
168 |
start_time = time.time() |
|
|
169 |
best_loss = float('inf') |
|
|
170 |
wgt_suf = 0 # foolproof savetime crash |
|
|
171 |
if CFG.use_amp: scaler = torch.cuda.amp.GradScaler() # for mixed precision |
|
|
172 |
|
|
|
173 |
for epoch in range(start_epoch, CFG.epochs): |
|
|
174 |
|
|
|
175 |
## ---- Training Routine ---- |
|
|
176 |
t_running_loss_ = 0 |
|
|
177 |
model.train() |
|
|
178 |
for step, (y1, y2) in tqdm(enumerate(trainloader, |
|
|
179 |
start=epoch * len(trainloader)), |
|
|
180 |
disable=CFG.disable_tqdm): |
|
|
181 |
|
|
|
182 |
y1 = y1.to(device, non_blocking=True) |
|
|
183 |
y2 = y2.to(device, non_blocking=True) |
|
|
184 |
optimizer.zero_grad() |
|
|
185 |
|
|
|
186 |
if CFG.use_amp: ## with mixed precision |
|
|
187 |
with torch.cuda.amp.autocast(): |
|
|
188 |
loss = model.forward(y1, y2) |
|
|
189 |
scaler.scale(loss).backward() |
|
|
190 |
scaler.step(optimizer) |
|
|
191 |
scaler.update() |
|
|
192 |
else: |
|
|
193 |
loss = model.forward(y1, y2) |
|
|
194 |
loss.backward() |
|
|
195 |
optimizer.step() |
|
|
196 |
t_running_loss_+=loss.item() |
|
|
197 |
|
|
|
198 |
if step % CFG.print_freq_step == 0: |
|
|
199 |
stats = dict(epoch=epoch, step=step, |
|
|
200 |
lr_weights=optimizer.param_groups[0]['lr'], |
|
|
201 |
step_loss=loss.item(), |
|
|
202 |
time=int(time.time() - start_time)) |
|
|
203 |
lutl.LOG2DICTXT(stats, CFG.checkpoint_dir +'/train-stats.txt') |
|
|
204 |
train_epoch_loss = t_running_loss_/len(trainloader) |
|
|
205 |
|
|
|
206 |
scheduler.step() |
|
|
207 |
|
|
|
208 |
# save checkpoint |
|
|
209 |
if (epoch+1) % CFG.ckpt_freq_epoch == 0: |
|
|
210 |
wgt_suf = (wgt_suf+1) %2 |
|
|
211 |
state = dict(epoch=epoch, model=model.state_dict(), |
|
|
212 |
optimizer=optimizer.state_dict()) |
|
|
213 |
torch.save(state, CFG.gWeightPath +f'/checkpoint-{wgt_suf}.pth') |
|
|
214 |
|
|
|
215 |
|
|
|
216 |
## ---- Validation Routine ---- |
|
|
217 |
if (epoch+1) % CFG.valid_freq_epoch == 0: |
|
|
218 |
model.eval() |
|
|
219 |
v_running_loss_ = 0 |
|
|
220 |
with torch.no_grad(): |
|
|
221 |
for (y1, y2) in tqdm(validloader, total=len(validloader), |
|
|
222 |
disable=CFG.disable_tqdm): |
|
|
223 |
y1 = y1.to(device, non_blocking=True) |
|
|
224 |
y2 = y2.to(device, non_blocking=True) |
|
|
225 |
loss = model.forward(y1, y2) |
|
|
226 |
v_running_loss_ += loss.item() |
|
|
227 |
valid_epoch_loss = v_running_loss_/len(validloader) |
|
|
228 |
|
|
|
229 |
# just check |
|
|
230 |
best_flag = False |
|
|
231 |
if valid_epoch_loss < best_loss: |
|
|
232 |
best_flag = True |
|
|
233 |
best_loss = valid_epoch_loss |
|
|
234 |
|
|
|
235 |
v_stats = dict(epoch=epoch, best=best_flag, wgt_suf=wgt_suf, |
|
|
236 |
train_loss=train_epoch_loss, |
|
|
237 |
valid_loss=valid_epoch_loss) |
|
|
238 |
lutl.LOG2DICTXT(v_stats, CFG.gLogPath+'/valid-stats.txt') |
|
|
239 |
|
|
|
240 |
|
|
|
241 |
if __name__ == '__main__': |
|
|
242 |
simple_main() |