# Generation/utils.py

import numpy as np
import math
import torch
import os
import sys
import time
from torch import inf
import wandb

class NativeScaler:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of the optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)

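# Minimal usage sketch for NativeScaler (assumes a CUDA device and a `model`/
# `optimizer` defined elsewhere); the scaler wraps backward, optional gradient
# clipping, and the optimizer step under mixed precision, mirroring how
# train_one_epoch below uses it:
#
#     loss_scaler = NativeScaler()
#     with torch.cuda.amp.autocast(enabled=True):
#         loss, pred, _ = model(samples)
#     norm = loss_scaler(loss, optimizer, clip_grad=0.8,
#                        parameters=model.parameters())
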
def get_grad_norm_(parameters, norm_type: float = 2.0):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm

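# Example (hypothetical `model` and `loss`): the global norm over all parameter
# gradients, computed after backward():
#
#     loss.backward()
#     l2_norm = get_grad_norm_(model.parameters())         # L2 norm (default)
#     max_abs = get_grad_norm_(model.parameters(), inf)    # max absolute gradient
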
def train_one_epoch(model, data_loader, optimizer, device, epoch,
                    loss_scaler, log_writer=None, config=None, start_time=None, model_without_ddp=None,
                    img_feature_extractor=None, preprocess=None):
    model.train(True)
    optimizer.zero_grad()
    total_loss = []
    total_cor = []
    accum_iter = config.accum_iter
    for data_iter_step, data_dict in enumerate(data_loader):

        # we use a per-iteration (instead of per-epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, config)
        samples = data_dict['eeg']

        img_features = None
        valid_idx = None
        if img_feature_extractor is not None:
            images = data_dict['image']
            # keep only samples that come with a non-empty paired image
            valid_idx = torch.nonzero(images.sum(dim=(1, 2, 3)) != 0).squeeze(1)
            img_feature_extractor.eval()
            with torch.no_grad():
                img_features = img_feature_extractor(preprocess(images[valid_idx]).to(device))['layer2']
        samples = samples.to(device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            loss, pred, _ = model(samples, img_features, valid_idx=valid_idx, mask_ratio=config.mask_ratio)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training at step {data_iter_step} epoch {epoch}")
            sys.exit(1)

        loss_scaler(loss, optimizer, parameters=model.parameters(), clip_grad=config.clip_grad)

        # compute the correlation between reconstruction and input on channel 0
        pred = pred.to('cpu').detach()
        samples = samples.to('cpu').detach()
        pred = model_without_ddp.unpatchify(pred)

        cor = torch.mean(torch.tensor([torch.corrcoef(torch.cat([p[0].unsqueeze(0), s[0].unsqueeze(0)], dim=0))[0, 1] for p, s in zip(pred, samples)])).item()
        optimizer.zero_grad()

        total_loss.append(loss_value)
        total_cor.append(cor)
        # print running means only on the first GPU
        if device == torch.device('cuda:0'):
            lr = optimizer.param_groups[0]["lr"]
            print('train_loss_step:', np.mean(total_loss), 'lr:', lr, 'cor', np.mean(total_cor))

    if log_writer is not None:
        lr = optimizer.param_groups[0]["lr"]
        log_writer.log('train_loss_step', np.mean(total_loss), step=epoch)
        log_writer.log('lr', lr, step=epoch)
        log_writer.log('cor', np.mean(total_cor), step=epoch)
        if start_time is not None:
            log_writer.log('time (min)', (time.time() - start_time) / 60.0, step=epoch)
    if config.local_rank == 0:
        print(f'[Epoch {epoch}] loss: {np.mean(total_loss)}')

    return np.mean(total_cor)

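# Sketch of the per-epoch driver loop (hypothetical `config` and `logger`;
# `model_without_ddp` is the underlying module when training with
# DistributedDataParallel, otherwise just the model itself):
#
#     for epoch in range(config.num_epoch):
#         cor = train_one_epoch(model, data_loader, optimizer, device, epoch,
#                               loss_scaler, log_writer=logger, config=config,
#                               model_without_ddp=model)
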
def get_1d_sincos_pos_embed(embed_dim, length, cls_token=False):
    """
    embed_dim: output dimension for each position
    length: number of positions in the 1D sequence
    return:
    pos_embed: [length, embed_dim] or [1+length, embed_dim] (w/ or w/o cls_token)
    """
    grid_l = np.arange(length, dtype=float)
    grid_l = grid_l.reshape([1, length])
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_l)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

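# Shape check (illustrative values): a 128-dim embedding over 16 positions.
#
#     emb = get_1d_sincos_pos_embed(128, 16)                      # (16, 128)
#     emb_cls = get_1d_sincos_pos_embed(128, 16, cls_token=True)  # (17, 128)
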
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches  # cls token
        # length of the checkpoint position embedding
        orig_size = int(pos_embed_checkpoint.shape[-2] - num_extra_tokens)
        # length of the new position embedding
        new_size = int(num_patches)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %d to %d" % (orig_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, embedding_size).permute(0, 2, 1)
            pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_size)
            pos_tokens = pos_tokens.permute(0, 2, 1)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed

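# Typical use (sketch; `ckpt_path` and the strict=False flag are assumptions):
# call before load_state_dict when fine-tuning with a different sequence length
# than the checkpoint was pretrained with:
#
#     checkpoint = torch.load(ckpt_path, map_location='cpu')
#     interpolate_pos_embed(model, checkpoint['model'])
#     model.load_state_dict(checkpoint['model'], strict=False)
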
def adjust_learning_rate(optimizer, epoch, config):
    """Decay the learning rate with half-cycle cosine after warmup"""
    if epoch < config.warmup_epochs:
        lr = config.lr * epoch / config.warmup_epochs
    else:
        lr = config.min_lr + (config.lr - config.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - config.warmup_epochs) / (config.num_epoch - config.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr

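# Schedule sketch (illustrative hyperparameters): with lr=1e-3, min_lr=1e-5,
# warmup_epochs=5, num_epoch=100, the rate ramps linearly from 0 to 1e-3 over
# epochs [0, 5), then follows a half-cosine decay from 1e-3 down to 1e-5 over
# epochs [5, 100]. `epoch` may be fractional, which is what makes the
# per-iteration schedule in train_one_epoch possible.
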
def save_model(config, epoch, model, optimizer, loss_scaler):
    to_save = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'config': config,
    }
    # NOTE: the checkpoint path is hardcoded here
    torch.save(to_save, '/home/weichen/projects/shiyin/DreamDiffusion.pth')

def load_model(config, model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    print(f'Model loaded with {checkpoint_path}')

def patchify(imgs, patch_size):
    """
    imgs: (N, 1, num_voxels)
    x: (N, L, patch_size)
    """
    p = patch_size
    assert imgs.ndim == 3 and imgs.shape[2] % p == 0

    h = imgs.shape[2] // p
    x = imgs.reshape(shape=(imgs.shape[0], h, p))
    return x

def unpatchify(x, patch_size):
    """
    x: (N, L, patch_size)
    imgs: (N, 1, num_voxels)
    """
    p = patch_size
    h = x.shape[1]

    imgs = x.reshape(shape=(x.shape[0], 1, h * p))
    return imgs

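# Round-trip sketch (assumed shapes): with patch_size p=4 and 1D signals of
# length 512, the two functions invert each other:
#
#     x = patchify(imgs, 4)      # (N, 1, 512) -> (N, 128, 4)
#     rec = unpatchify(x, 4)     # (N, 128, 4) -> (N, 1, 512)
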
class wandb_logger:
    def __init__(self, config):
        wandb.init(
            # Set the project where this run will be logged
            project=config['project'],
            name=config['name'],
            config=config,
            entity=config['entity'],
        )

        self.config = config
        self.step = None

    def log(self, name, data, step=None):
        # takes a metric name and value, matching the call sites in train_one_epoch
        if step is None:
            wandb.log({name: data})
        else:
            wandb.log({name: data}, step=step)
            self.step = step

    def watch_model(self, *args, **kwargs):
        wandb.watch(*args, **kwargs)

    def log_image(self, figs):
        if self.step is None:
            wandb.log(figs)
        else:
            wandb.log(figs, step=self.step)

    def finish(self):
        wandb.finish(quiet=True)

    def load(self, net):
        path = os.path.join(self.config['path_data'], self.config['path_ckpt'], self.config['file_ckpt'])
        net.load_state_dict(torch.load(path))
        print(f'load {path}')

    def save(self, net, file_name=None):
        path_ckpt = os.path.join(self.config['path_data'], self.config['path_ckpt'])
        if not os.path.exists(path_ckpt):
            os.makedirs(path_ckpt)
            print(f'{path_ckpt} created!')

        path = os.path.join(path_ckpt, file_name)
        torch.save(net.state_dict(), path)

    def watch(self, model, log):
        wandb.watch(model, log)
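
# Minimal usage sketch (assumes a config dict carrying the keys read in
# __init__; the project/entity names are placeholders):
#
#     logger = wandb_logger({'project': 'dreamdiffusion', 'name': 'run-0',
#                            'entity': 'my-team'})
#     logger.log('train_loss', 0.123, step=0)
#     logger.finish()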