Generation/utils.py

import math
import os
import sys
import time

import numpy as np
import torch
from torch import inf
import wandb

class NativeScaler:
    """Wraps torch.cuda.amp.GradScaler: scales the loss, runs backward,
    optionally clips gradients, then steps and updates the optimizer."""
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)

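# Hedged usage sketch (not part of the original file): one AMP training step on
# a toy linear model. If CUDA is unavailable, GradScaler disables itself and
# this degenerates to a plain fp32 step.
def _demo_native_scaler():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = torch.nn.Linear(4, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    scaler = NativeScaler()
    loss = model(torch.randn(8, 4, device=device)).pow(2).mean()
    # backward + clip to norm 1.0 + optimizer step, all inside the scaler
    norm = scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())
    optimizer.zero_grad()
    return norm
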
def get_grad_norm_(parameters, norm_type: float = 2.0):
    """Return the total norm of the gradients of `parameters` (supports the Inf norm)."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm

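# Hedged sketch (not part of the original file): measuring the gradient norm of
# a model after a backward pass, e.g. to log exploding gradients.
def _demo_grad_norm():
    model = torch.nn.Linear(4, 1)
    model(torch.randn(8, 4)).pow(2).mean().backward()
    return get_grad_norm_(model.parameters())
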
def train_one_epoch(model, data_loader, optimizer, device, epoch,
                    loss_scaler, log_writer=None, config=None, start_time=None, model_without_ddp=None,
                    img_feature_extractor=None, preprocess=None):
    model.train(True)
    optimizer.zero_grad()
    total_loss = []
    total_cor = []
    accum_iter = config.accum_iter
    for data_iter_step, data_dict in enumerate(data_loader):

        # we use a per-iteration (instead of per-epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, config)
        samples = data_dict['eeg']

        img_features = None
        valid_idx = None
        if img_feature_extractor is not None:
            images = data_dict['image']
            # keep only samples whose paired image is non-empty (non-zero)
            valid_idx = torch.nonzero(images.sum(dim=(1, 2, 3)) != 0).squeeze(1)
            img_feature_extractor.eval()
            with torch.no_grad():
                img_features = img_feature_extractor(preprocess(images[valid_idx]).to(device))['layer2']
        samples = samples.to(device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            loss, pred, _ = model(samples, img_features, valid_idx=valid_idx, mask_ratio=config.mask_ratio)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training at step {data_iter_step} epoch {epoch}")
            sys.exit(1)

        loss_scaler(loss, optimizer, parameters=model.parameters(), clip_grad=config.clip_grad)

        # compute the correlation between prediction and ground truth
        pred = pred.to('cpu').detach()
        samples = samples.to('cpu').detach()
        pred = model_without_ddp.unpatchify(pred)

        cor = torch.mean(torch.tensor([torch.corrcoef(torch.cat([p[0].unsqueeze(0), s[0].unsqueeze(0)], dim=0))[0, 1] for p, s in zip(pred, samples)])).item()
        optimizer.zero_grad()

        total_loss.append(loss_value)
        total_cor.append(cor)
        if device == torch.device('cuda:0'):
            lr = optimizer.param_groups[0]["lr"]
            print('train_loss_step:', np.mean(total_loss), 'lr:', lr, 'cor', np.mean(total_cor))

    if log_writer is not None:
        lr = optimizer.param_groups[0]["lr"]
        log_writer.log('train_loss_step', np.mean(total_loss), step=epoch)
        log_writer.log('lr', lr, step=epoch)
        log_writer.log('cor', np.mean(total_cor), step=epoch)
        if start_time is not None:
            log_writer.log('time (min)', (time.time() - start_time) / 60.0, step=epoch)
    if config.local_rank == 0:
        print(f'[Epoch {epoch}] loss: {np.mean(total_loss)}')

    return np.mean(total_cor)

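# Hedged driver sketch (not part of the original file): how train_one_epoch is
# typically called from a main script. `config` is assumed to carry the fields
# used above (lr, min_lr, warmup_epochs, num_epoch, accum_iter, mask_ratio,
# clip_grad, local_rank); `model` is assumed to already live on `device` and to
# expose unpatchify().
def _demo_training_loop(model, data_loader, config, device=torch.device('cuda:0')):
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    loss_scaler = NativeScaler()
    start_time = time.time()
    for epoch in range(config.num_epoch):
        train_one_epoch(model, data_loader, optimizer, device, epoch, loss_scaler,
                        config=config, start_time=start_time, model_without_ddp=model)
        save_model(config, epoch, model, optimizer, loss_scaler)
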
def get_1d_sincos_pos_embed(embed_dim, length, cls_token=False):
    """
    length: int, the sequence length
    return:
    pos_embed: [length, embed_dim] or [1+length, embed_dim] (w/ or w/o cls_token)
    """
    grid_l = np.arange(length, dtype=float)

    grid_l = grid_l.reshape([1, length])
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_l)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

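# Hedged shape check (not part of the original file): the 1-D sin-cos table for
# 16 positions at width 8, with and without the zero row reserved for a cls token.
def _demo_sincos_pos_embed():
    pe = get_1d_sincos_pos_embed(8, 16)                      # (16, 8)
    pe_cls = get_1d_sincos_pos_embed(8, 16, cls_token=True)  # (17, 8), row 0 all zeros
    assert pe.shape == (16, 8) and pe_cls.shape == (17, 8)
    return pe
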
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches  # cls token
        # sequence length of the checkpoint position embedding
        orig_size = int(pos_embed_checkpoint.shape[-2] - num_extra_tokens)
        # sequence length of the new position embedding
        new_size = int(num_patches)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %d to %d" % (orig_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, embedding_size).permute(0, 2, 1)
            pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_size)
            pos_tokens = pos_tokens.permute(0, 2, 1)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed

def adjust_learning_rate(optimizer, epoch, config):
    """Decay the learning rate with half-cycle cosine after warmup."""
    if epoch < config.warmup_epochs:
        lr = config.lr * epoch / config.warmup_epochs
    else:
        lr = config.min_lr + (config.lr - config.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - config.warmup_epochs) / (config.num_epoch - config.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr

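# Hedged schedule preview (not part of the original file): traces the warmup +
# half-cycle cosine curve against a throwaway optimizer, using a SimpleNamespace
# stand-in for the config object.
def _demo_lr_schedule():
    from types import SimpleNamespace
    cfg = SimpleNamespace(lr=1e-3, min_lr=1e-5, warmup_epochs=5, num_epoch=50)
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=cfg.lr)
    return [adjust_learning_rate(opt, e, cfg) for e in range(cfg.num_epoch)]
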
def save_model(config, epoch, model, optimizer, loss_scaler):
    to_save = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'config': config,
    }
    # NOTE: the output path is hardcoded in the original source
    torch.save(to_save, '/home/weichen/projects/shiyin/DreamDiffusion.pth')

def load_model(config, model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    print(f'Model loaded from {checkpoint_path}')

def patchify(imgs, patch_size):
    """
    imgs: (N, 1, num_voxels)
    x: (N, L, patch_size)
    """
    p = patch_size
    assert imgs.ndim == 3 and imgs.shape[2] % p == 0

    h = imgs.shape[2] // p
    x = imgs.reshape(shape=(imgs.shape[0], h, p))
    return x

def unpatchify(x, patch_size):
    """
    x: (N, L, patch_size)
    imgs: (N, 1, num_voxels)
    """
    p = patch_size
    h = x.shape[1]

    imgs = x.reshape(shape=(x.shape[0], 1, h * p))
    return imgs

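# Hedged round-trip check (not part of the original file): patchify and
# unpatchify are inverses whenever patch_size evenly divides the signal length.
def _demo_patchify_roundtrip():
    imgs = torch.randn(2, 1, 512)      # (N, 1, num_voxels)
    x = patchify(imgs, patch_size=4)   # (2, 128, 4)
    assert torch.equal(unpatchify(x, patch_size=4), imgs)
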
class wandb_logger:
    def __init__(self, config):
        wandb.init(
            # Set the project where this run will be logged
            project=config['project'],
            name=config['name'],
            config=config,
            entity=config['entity'],
        )

        self.config = config
        self.step = None

    # takes a metric name and value, matching the call sites in train_one_epoch
    def log(self, name, data, step=None):
        if step is None:
            wandb.log({name: data})
        else:
            wandb.log({name: data}, step=step)
            self.step = step

    def watch_model(self, *args, **kwargs):
        wandb.watch(*args, **kwargs)

    def log_image(self, figs):
        if self.step is None:
            wandb.log(figs)
        else:
            wandb.log(figs, step=self.step)

    def finish(self):
        wandb.finish(quiet=True)

    def load(self, net):
        path = os.path.join(self.config['path_data'], self.config['path_ckpt'], self.config['file_ckpt'])
        net.load_state_dict(torch.load(path))
        print(f'load {path}')

    def save(self, net, file_name=None):
        path_ckpt = os.path.join(self.config['path_data'], self.config['path_ckpt'])
        if not os.path.exists(path_ckpt):
            os.makedirs(path_ckpt)
            print(f'{path_ckpt} created!')

        path = os.path.join(path_ckpt, file_name)
        torch.save(net.state_dict(), path)

    def watch(self, model, log):
        wandb.watch(model, log)
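
# Hedged usage sketch (not part of the original file): the config keys the
# logger expects; the 'project', 'name', and 'entity' values here are
# placeholders, and wandb must be logged in for init to succeed.
def _demo_wandb_logger():
    cfg = {'project': 'dreamdiffusion', 'name': 'debug-run', 'entity': 'my-team',
           'path_data': './data', 'path_ckpt': 'checkpoints', 'file_ckpt': 'net.pth'}
    logger = wandb_logger(cfg)
    logger.log('train_loss_step', 0.5, step=0)
    logger.finish()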