|
a |
|
b/tasks/moco-train.py |
|
|
1 |
""" Momentum Contrastive (MoCo) Learning |
|
|
2 |
""" |
|
|
3 |
import argparse |
|
|
4 |
import json |
|
|
5 |
import os |
|
|
6 |
import sys |
|
|
7 |
import time |
|
|
8 |
import numpy as np |
|
|
9 |
from tqdm import tqdm |
|
|
10 |
|
|
|
11 |
import torch |
|
|
12 |
import torch.nn as nn |
|
|
13 |
import torchinfo |
|
|
14 |
|
|
|
15 |
sys.path.append(os.getcwd()) |
|
|
16 |
import utilities.runUtils as rutl |
|
|
17 |
import utilities.logUtils as lutl |
|
|
18 |
from algorithms.moco import MoCo |
|
|
19 |
from algorithms.loss.ssl_losses import NTXentLoss |
|
|
20 |
from datacode.natural_image_data import Cifar100Dataset |
|
|
21 |
from datacode.ultrasound_data import FetalUSFramesDataset |
|
|
22 |
from datacode.augmentations import SimCLRTransform |
|
|
23 |
|
|
|
24 |
print(f"Pytorch version: {torch.__version__}") |
|
|
25 |
print(f"cuda version: {torch.version.cuda}") |
|
|
26 |
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
27 |
print("Device Used:", device) |
|
|
28 |
|
|
|
29 |
###============================= Configure and Setup =========================== |
|
|
30 |
|
|
|
31 |
CFG = rutl.ObjDict( |
|
|
32 |
use_amp = True, #automatic Mixed precision |
|
|
33 |
|
|
|
34 |
datapath = "/home/mothilal.asokan/Downloads/HC701/Project/US-Fetal-Video-Frames_V1-1/train-all-frames.hdf5", |
|
|
35 |
valdatapath = "/home/mothilal.asokan/Downloads/HC701/Project/US-Fetal-Video-Frames_V1-1/valid-all-frames.hdf5", |
|
|
36 |
skip_count = 5, |
|
|
37 |
|
|
|
38 |
epochs = 20, |
|
|
39 |
batch_size = 288, |
|
|
40 |
workers = 24, |
|
|
41 |
image_size = 256, |
|
|
42 |
|
|
|
43 |
weight_decay = 1e-4, |
|
|
44 |
lr = 0.03, |
|
|
45 |
|
|
|
46 |
featx_arch = "resnet50", # "resnet34/50/101" |
|
|
47 |
featx_pretrain = "IMGNET-1K" , # "IMGNET-1K" or None |
|
|
48 |
|
|
|
49 |
print_freq_step = 10, #steps |
|
|
50 |
ckpt_freq_epoch = 5, #epochs |
|
|
51 |
valid_freq_epoch = 5, #epochs |
|
|
52 |
disable_tqdm = False, #True--> to disable |
|
|
53 |
|
|
|
54 |
checkpoint_dir= "hypotheses/-dummy/ssl-moco", |
|
|
55 |
resume_training = True, |
|
|
56 |
) |
|
|
57 |
|
|
|
58 |
## -------- |
|
|
59 |
parser = argparse.ArgumentParser(description='MoCo Training') |
|
|
60 |
parser.add_argument('--load-json', type=str, metavar='JSON', |
|
|
61 |
help='Load settings from file in json format. Command line options override values in python file.') |
|
|
62 |
|
|
|
63 |
|
|
|
64 |
args = parser.parse_args() |
|
|
65 |
|
|
|
66 |
if args.load_json: |
|
|
67 |
with open(args.load_json, 'rt') as f: |
|
|
68 |
CFG.__dict__.update(json.load(f)) |
|
|
69 |
|
|
|
70 |
### ---------------------------------------------------------------------------- |
|
|
71 |
CFG.gLogPath = CFG.checkpoint_dir |
|
|
72 |
CFG.gWeightPath = CFG.checkpoint_dir + '/weights/' |
|
|
73 |
|
|
|
74 |
### ============================================================================ |
|
|
75 |
|
|
|
76 |
def getDataLoaders(): |
|
|
77 |
|
|
|
78 |
transform_obj = SimCLRTransform(image_size=CFG.image_size) |
|
|
79 |
|
|
|
80 |
traindataset = FetalUSFramesDataset( hdf5_file= CFG.datapath, |
|
|
81 |
transform = transform_obj, |
|
|
82 |
load2ram = False, frame_skip=CFG.skip_count) |
|
|
83 |
|
|
|
84 |
|
|
|
85 |
trainloader = torch.utils.data.DataLoader( traindataset, shuffle=True, |
|
|
86 |
batch_size=CFG.batch_size, num_workers=CFG.workers, |
|
|
87 |
pin_memory=True,drop_last=True ) |
|
|
88 |
|
|
|
89 |
|
|
|
90 |
validdataset = FetalUSFramesDataset( hdf5_file= CFG.valdatapath, |
|
|
91 |
transform = transform_obj, |
|
|
92 |
load2ram = False, frame_skip=CFG.skip_count) |
|
|
93 |
|
|
|
94 |
|
|
|
95 |
validloader = torch.utils.data.DataLoader( validdataset, shuffle=False, |
|
|
96 |
batch_size=CFG.batch_size, num_workers=CFG.workers, |
|
|
97 |
pin_memory=True, drop_last=True) |
|
|
98 |
|
|
|
99 |
|
|
|
100 |
lutl.LOG2DICTXT({"TRAIN DatasetClass":traindataset.get_info(), |
|
|
101 |
"TransformsClass": str(transform_obj.get_composition()), |
|
|
102 |
}, CFG.gLogPath +'/misc.txt') |
|
|
103 |
lutl.LOG2DICTXT({"VALID DatasetClass":validdataset.get_info(), |
|
|
104 |
"TransformsClass": str(transform_obj.get_composition()), |
|
|
105 |
}, CFG.gLogPath +'/misc.txt') |
|
|
106 |
|
|
|
107 |
return trainloader, validloader |
|
|
108 |
|
|
|
109 |
|
|
|
110 |
def getModelnOptimizer(): |
|
|
111 |
model = MoCo(featx_arch=CFG.featx_arch, |
|
|
112 |
pretrained=CFG.featx_pretrain).to(device) |
|
|
113 |
|
|
|
114 |
optimizer = torch.optim.SGD(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, |
|
|
115 |
momentum=0.9) |
|
|
116 |
|
|
|
117 |
|
|
|
118 |
model_info = torchinfo.summary(model, [(CFG.batch_size, 3, CFG.image_size, CFG.image_size)], |
|
|
119 |
verbose=0) |
|
|
120 |
lutl.LOG2TXT(model_info, CFG.gLogPath +'/misc.txt', console= False) |
|
|
121 |
|
|
|
122 |
return model.to(device), optimizer |
|
|
123 |
|
|
|
124 |
|
|
|
125 |
def update_momentum(model: nn.Module, model_ema: nn.Module, m: float): |
|
|
126 |
"""Updates parameters of `model_ema` with Exponential Moving Average of `model` |
|
|
127 |
Momentum encoders are a crucial component fo models such as MoCo or BYOL. |
|
|
128 |
Examples: |
|
|
129 |
>>> backbone = resnet18() |
|
|
130 |
>>> projection_head = MoCoProjectionHead() |
|
|
131 |
>>> backbone_momentum = copy.deepcopy(moco) |
|
|
132 |
>>> projection_head_momentum = copy.deepcopy(projection_head) |
|
|
133 |
>>> |
|
|
134 |
>>> # update momentum |
|
|
135 |
>>> update_momentum(moco, moco_momentum, m=0.999) |
|
|
136 |
>>> update_momentum(projection_head, projection_head_momentum, m=0.999) |
|
|
137 |
""" |
|
|
138 |
for model_ema, model in zip(model_ema.parameters(), model.parameters()): |
|
|
139 |
model_ema.data = model_ema.data * m + model.data * (1.0 - m) |
|
|
140 |
|
|
|
141 |
|
|
|
142 |
|
|
|
143 |
def cosine_schedule( |
|
|
144 |
step: int, max_steps: int, start_value: float, end_value: float |
|
|
145 |
) -> float: |
|
|
146 |
""" |
|
|
147 |
Use cosine decay to gradually modify start_value to reach target end_value during iterations. |
|
|
148 |
Args: |
|
|
149 |
step: |
|
|
150 |
Current step number. |
|
|
151 |
max_steps: |
|
|
152 |
Total number of steps. |
|
|
153 |
start_value: |
|
|
154 |
Starting value. |
|
|
155 |
end_value: |
|
|
156 |
Target value. |
|
|
157 |
Returns: |
|
|
158 |
Cosine decay value. |
|
|
159 |
""" |
|
|
160 |
if step < 0: |
|
|
161 |
raise ValueError("Current step number can't be negative") |
|
|
162 |
if max_steps < 1: |
|
|
163 |
raise ValueError("Total step number must be >= 1") |
|
|
164 |
if step > max_steps: |
|
|
165 |
# Note: we allow step == max_steps even though step starts at 0 and should end |
|
|
166 |
# at max_steps - 1. This is because Pytorch Lightning updates the LR scheduler |
|
|
167 |
# always for the next epoch, even after the last training epoch. This results in |
|
|
168 |
# Pytorch Lightning calling the scheduler with step == max_steps. |
|
|
169 |
raise ValueError( |
|
|
170 |
f"The current step cannot be larger than max_steps but found step {step} and max_steps {max_steps}." |
|
|
171 |
) |
|
|
172 |
|
|
|
173 |
if max_steps == 1: |
|
|
174 |
# Avoid division by zero |
|
|
175 |
decay = end_value |
|
|
176 |
elif step == max_steps: |
|
|
177 |
# Special case for Pytorch Lightning which updates LR scheduler also for epoch |
|
|
178 |
# after last training epoch. |
|
|
179 |
decay = end_value |
|
|
180 |
else: |
|
|
181 |
decay = ( |
|
|
182 |
end_value |
|
|
183 |
- (end_value - start_value) |
|
|
184 |
* (np.cos(np.pi * step / (max_steps - 1)) + 1) |
|
|
185 |
/ 2 |
|
|
186 |
) |
|
|
187 |
return decay |
|
|
188 |
|
|
|
189 |
|
|
|
190 |
|
|
|
191 |
### ---------------------------------------------------------------------------- |
|
|
192 |
|
|
|
193 |
def simple_main(): |
|
|
194 |
### SETUP |
|
|
195 |
rutl.START_SEED() |
|
|
196 |
torch.cuda.device(device) |
|
|
197 |
torch.backends.cudnn.benchmark = True |
|
|
198 |
|
|
|
199 |
if os.path.exists(CFG.checkpoint_dir) and (not CFG.resume_training): |
|
|
200 |
raise Exception("CheckPoint folder already exists and restart_training not enabled; Somethings Wrong!") |
|
|
201 |
if not os.path.exists(CFG.gWeightPath): os.makedirs(CFG.gWeightPath) |
|
|
202 |
|
|
|
203 |
with open(CFG.gLogPath+"/exp_config.json", 'a') as f: |
|
|
204 |
json.dump(vars(CFG), f, indent=4) |
|
|
205 |
|
|
|
206 |
|
|
|
207 |
### DATA ACCESS |
|
|
208 |
trainloader, validloader = getDataLoaders() |
|
|
209 |
|
|
|
210 |
### MODEL, OPTIM |
|
|
211 |
model, optimizer = getModelnOptimizer() |
|
|
212 |
|
|
|
213 |
criterion = NTXentLoss(memory_bank_size=4096) |
|
|
214 |
|
|
|
215 |
|
|
|
216 |
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, |
|
|
217 |
len(trainloader), eta_min=0,last_epoch=-1) |
|
|
218 |
## Automatically resume from checkpoint if it exists and enabled |
|
|
219 |
ckpt = None |
|
|
220 |
if CFG.resume_training: |
|
|
221 |
try: ckpt = torch.load(CFG.gWeightPath+'/checkpoint-1.pth', map_location='cpu') |
|
|
222 |
except: |
|
|
223 |
try:ckpt = torch.load(CFG.gWeightPath+'/checkpoint-0.pth', map_location='cpu') |
|
|
224 |
except: print("Check points are not loadable. Starting fresh...") |
|
|
225 |
if ckpt: |
|
|
226 |
start_epoch = ckpt['epoch'] |
|
|
227 |
model.load_state_dict(ckpt['model']) |
|
|
228 |
optimizer.load_state_dict(ckpt['optimizer']) |
|
|
229 |
lutl.LOG2TXT(f"Restarting Training from EPOCH:{start_epoch} of {CFG.checkpoint_dir}", CFG.gLogPath +'/misc.txt') |
|
|
230 |
else: |
|
|
231 |
start_epoch = 0 |
|
|
232 |
|
|
|
233 |
|
|
|
234 |
### MODEL TRAINING |
|
|
235 |
start_time = time.time() |
|
|
236 |
best_loss = float('inf') |
|
|
237 |
wgt_suf = 0 # foolproof savetime crash |
|
|
238 |
if CFG.use_amp: scaler = torch.cuda.amp.GradScaler() # for mixed precision |
|
|
239 |
|
|
|
240 |
for epoch in range(start_epoch, CFG.epochs): |
|
|
241 |
|
|
|
242 |
## ---- Training Routine ---- |
|
|
243 |
t_running_loss_ = 0 |
|
|
244 |
momentum_val = cosine_schedule(epoch, CFG.epochs, 0.996, 1) |
|
|
245 |
|
|
|
246 |
model.train() |
|
|
247 |
for step, (x_query, x_key) in tqdm(enumerate(trainloader, |
|
|
248 |
start=epoch * len(trainloader)), |
|
|
249 |
disable=CFG.disable_tqdm): |
|
|
250 |
|
|
|
251 |
update_momentum(model.backbone, model.backbone_momentum, m=momentum_val) |
|
|
252 |
update_momentum(model.projection_head, model.projection_head_momentum, m=momentum_val) |
|
|
253 |
x_query = x_query.to(device, non_blocking=True) |
|
|
254 |
x_key = x_key.to(device, non_blocking=True) |
|
|
255 |
optimizer.zero_grad() |
|
|
256 |
|
|
|
257 |
if CFG.use_amp: ## with mixed precision |
|
|
258 |
with torch.cuda.amp.autocast(): |
|
|
259 |
query = model(x_query) |
|
|
260 |
key = model.forward_momentum(x_key) |
|
|
261 |
loss = criterion(query, key) |
|
|
262 |
|
|
|
263 |
scaler.scale(loss).backward() |
|
|
264 |
scaler.step(optimizer) |
|
|
265 |
scaler.update() |
|
|
266 |
else: |
|
|
267 |
query = model(x_query) |
|
|
268 |
key = model.forward_momentum(x_key) |
|
|
269 |
loss = criterion(query, key) |
|
|
270 |
loss.backward() |
|
|
271 |
optimizer.step() |
|
|
272 |
t_running_loss_+=loss.item() |
|
|
273 |
|
|
|
274 |
if step % CFG.print_freq_step == 0: |
|
|
275 |
stats = dict(epoch=epoch, step=step, |
|
|
276 |
lr_weights=optimizer.param_groups[0]['lr'], |
|
|
277 |
step_loss=loss.item(), |
|
|
278 |
time=int(time.time() - start_time)) |
|
|
279 |
lutl.LOG2DICTXT(stats, CFG.checkpoint_dir +'/train-stats.txt') |
|
|
280 |
train_epoch_loss = t_running_loss_/len(trainloader) |
|
|
281 |
|
|
|
282 |
scheduler.step() |
|
|
283 |
|
|
|
284 |
# save checkpoint |
|
|
285 |
if (epoch+1) % CFG.ckpt_freq_epoch == 0: |
|
|
286 |
wgt_suf = (wgt_suf+1) %2 |
|
|
287 |
state = dict(epoch=epoch, model=model.state_dict(), |
|
|
288 |
optimizer=optimizer.state_dict()) |
|
|
289 |
torch.save(state, CFG.gWeightPath +f'/checkpoint-{wgt_suf}.pth') |
|
|
290 |
|
|
|
291 |
|
|
|
292 |
## ---- Validation Routine ---- |
|
|
293 |
if (epoch+1) % CFG.valid_freq_epoch == 0: |
|
|
294 |
model.eval() |
|
|
295 |
v_running_loss_ = 0 |
|
|
296 |
with torch.no_grad(): |
|
|
297 |
for (x_query, x_key) in tqdm(validloader, total=len(validloader), |
|
|
298 |
disable=CFG.disable_tqdm): |
|
|
299 |
update_momentum(model.backbone, model.backbone_momentum, m=momentum_val) |
|
|
300 |
update_momentum(model.projection_head, model.projection_head_momentum, m=momentum_val) |
|
|
301 |
x_query = x_query.to(device, non_blocking=True) |
|
|
302 |
x_key = x_key.to(device, non_blocking=True) |
|
|
303 |
query = model(x_query) |
|
|
304 |
key = model.forward_momentum(x_key) |
|
|
305 |
loss = criterion(query, key) |
|
|
306 |
v_running_loss_ += loss.item() |
|
|
307 |
valid_epoch_loss = v_running_loss_/len(validloader) |
|
|
308 |
|
|
|
309 |
# just check |
|
|
310 |
best_flag = False |
|
|
311 |
if valid_epoch_loss < best_loss: |
|
|
312 |
best_flag = True |
|
|
313 |
best_loss = valid_epoch_loss |
|
|
314 |
|
|
|
315 |
v_stats = dict(epoch=epoch, best=best_flag, wgt_suf=wgt_suf, |
|
|
316 |
train_loss=train_epoch_loss, |
|
|
317 |
valid_loss=valid_epoch_loss) |
|
|
318 |
lutl.LOG2DICTXT(v_stats, CFG.gLogPath+'/valid-stats.txt') |
|
|
319 |
|
|
|
320 |
|
|
|
321 |
if __name__ == '__main__': |
|
|
322 |
simple_main() |