Diff of /Retrieval/util.py [000000] .. [a4067e]

import numpy as np
import math
import torch
import os
import sys
import time
from torch import inf
import wandb

class NativeScaler:
    """Thin wrapper around torch.cuda.amp.GradScaler: scales the loss for the
    backward pass, optionally clips gradients, and steps the optimizer."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)

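# Illustrative usage sketch (not part of the original file; `model`, `optimizer`,
# `batch`, and the clip value are placeholders): the scaler is called like a
# function on each training step.
#
#     loss_scaler = NativeScaler()
#     with torch.cuda.amp.autocast(enabled=True):
#         loss = model(batch)
#     loss_scaler(loss, optimizer, clip_grad=0.8, parameters=model.parameters())
#     optimizer.zero_grad()
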
def get_grad_norm_(parameters, norm_type: float = 2.0):
    """Return the total gradient norm of `parameters` (used for logging when no clipping is applied)."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm

def train_one_epoch(model, data_loader, optimizer, device, epoch,
                    loss_scaler, log_writer=None, config=None, start_time=None, model_without_ddp=None,
                    img_feature_extractor=None, preprocess=None):
    model.train(True)
    optimizer.zero_grad()
    total_loss = []
    total_cor = []
    accum_iter = config.accum_iter
    for data_iter_step, data_dict in enumerate(data_loader):

        # we use a per-iteration (instead of per-epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, config)
        samples = data_dict['eeg']

        img_features = None
        valid_idx = None
        if img_feature_extractor is not None:
            images = data_dict['image']
            valid_idx = torch.nonzero(images.sum(dim=(1, 2, 3)) != 0).squeeze(1)
            img_feature_extractor.eval()
            with torch.no_grad():
                img_features = img_feature_extractor(preprocess(images[valid_idx]).to(device))['layer2']
        samples = samples.to(device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=True):
            loss, pred, _ = model(samples, img_features, valid_idx=valid_idx, mask_ratio=config.mask_ratio)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training at step {data_iter_step} of epoch {epoch}")
            sys.exit(1)

        # backward pass, optional gradient clipping, and optimizer step via the AMP loss scaler
        loss_scaler(loss, optimizer, parameters=model.parameters(), clip_grad=config.clip_grad)

        # compute the correlation between the reconstruction and the input EEG
        pred = pred.to('cpu').detach()
        samples = samples.to('cpu').detach()
        pred = model_without_ddp.unpatchify(pred)

        cor = torch.mean(torch.tensor([torch.corrcoef(torch.cat([p[0].unsqueeze(0), s[0].unsqueeze(0)], dim=0))[0, 1]
                                       for p, s in zip(pred, samples)])).item()
        optimizer.zero_grad()

        total_loss.append(loss_value)
        total_cor.append(cor)
        if device == torch.device('cuda:0'):
            lr = optimizer.param_groups[0]["lr"]
            print('train_loss_step:', np.mean(total_loss), 'lr:', lr, 'cor:', np.mean(total_cor))

    if log_writer is not None:
        lr = optimizer.param_groups[0]["lr"]
        log_writer.log('train_loss_step', np.mean(total_loss), step=epoch)
        log_writer.log('lr', lr, step=epoch)
        log_writer.log('cor', np.mean(total_cor), step=epoch)
        if start_time is not None:
            log_writer.log('time (min)', (time.time() - start_time) / 60.0, step=epoch)
    if config.local_rank == 0:
        print(f'[Epoch {epoch}] loss: {np.mean(total_loss)}')

    return np.mean(total_cor)

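# Illustrative call sketch with a hypothetical config: the attribute names are
# the ones train_one_epoch and adjust_learning_rate actually read, but the
# values and the surrounding objects (model, data_loader, optimizer, device,
# logger) are placeholders.
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(accum_iter=1, mask_ratio=0.75, clip_grad=0.8,
#                              lr=2.5e-4, min_lr=0.0, warmup_epochs=40,
#                              num_epoch=500, local_rank=0)
#     loss_scaler = NativeScaler()
#     for epoch in range(config.num_epoch):
#         train_one_epoch(model, data_loader, optimizer, device, epoch,
#                         loss_scaler, log_writer=logger, config=config,
#                         model_without_ddp=model)
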
def get_1d_sincos_pos_embed(embed_dim, length, cls_token=False):
    """
    embed_dim: output dimension for each position
    length: number of positions to encode
    return:
    pos_embed: [length, embed_dim] or [1+length, embed_dim] (w/ or w/o cls_token)
    """
    grid_l = np.arange(length, dtype=float)
    grid_l = grid_l.reshape([1, length])
    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_l)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

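# Worked shape example (illustrative): for embed_dim=8 and length=4, omega_d is
# 1 / 10000**(d / 4) for d = 0..3, so each position maps to 4 sine and 4 cosine
# features and the returned array has shape (4, 8); cls_token=True prepends a
# zero row, giving (5, 8).
#
#     pe = get_1d_sincos_pos_embed(8, 4)                       # shape (4, 8)
#     pe_cls = get_1d_sincos_pos_embed(8, 4, cls_token=True)   # shape (5, 8)
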
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches  # cls token
        # length of the checkpoint position embedding (excluding extra tokens)
        orig_size = int(pos_embed_checkpoint.shape[-2] - num_extra_tokens)
        # length of the new position embedding
        new_size = int(num_patches)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %d to %d" % (orig_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, embedding_size).permute(0, 2, 1)
            pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_size)
            pos_tokens = pos_tokens.permute(0, 2, 1)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed

def adjust_learning_rate(optimizer, epoch, config):
    """Decay the learning rate with half-cycle cosine after warmup"""
    if epoch < config.warmup_epochs:
        lr = config.lr * epoch / config.warmup_epochs
    else:
        lr = config.min_lr + (config.lr - config.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - config.warmup_epochs) / (config.num_epoch - config.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr

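# Worked example (illustrative values): with lr=1e-3, min_lr=0.0,
# warmup_epochs=10 and num_epoch=100, the rate rises linearly to 1e-3 at epoch
# 10, then follows the half-cycle cosine; at epoch 55 (the midpoint of the
# cosine phase) it equals 0.5 * (lr - min_lr) = 5e-4, and it approaches min_lr
# by epoch 100.
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(lr=1e-3, min_lr=0.0, warmup_epochs=10, num_epoch=100)
#     adjust_learning_rate(optimizer, 55, cfg)  # sets each group's lr to 5e-4 (times any lr_scale)
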
def load_model(config, model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    print(f'Model loaded from {checkpoint_path}')

def patchify(imgs, patch_size):
    """
    imgs: (N, 1, num_voxels)
    x: (N, L, patch_size)
    """
    p = patch_size
    assert imgs.ndim == 3 and imgs.shape[2] % p == 0

    h = imgs.shape[2] // p
    x = imgs.reshape(shape=(imgs.shape[0], h, p))
    return x


def unpatchify(x, patch_size):
    """
    x: (N, L, patch_size)
    imgs: (N, 1, num_voxels)
    """
    p = patch_size
    h = x.shape[1]

    imgs = x.reshape(shape=(x.shape[0], 1, h * p))
    return imgs

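# Round-trip sketch (illustrative): patchify splits the last dimension into
# non-overlapping patches and unpatchify reverses it exactly.
#
#     sig = torch.randn(2, 1, 512)               # (N, 1, num_voxels)
#     x = patchify(sig, 16)                      # (2, 32, 16)
#     assert torch.equal(unpatchify(x, 16), sig)
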
class wandb_logger:
    def __init__(self, config):
        try:
            wandb.init(
                # Set the project where this run will be logged
                project=config['project'],
                name=config['name'],
                config=config,
                entity=config['entity'],
            )
        except (TypeError, KeyError):
            # config is an attribute-style object rather than a dict
            wandb.init(
                project=config.project,
                name=config.name,
                config=config,
                entity=config.entity,
            )

        self.config = config
        self.step = None

    def log(self, data, step=None):
        if step is None:
            wandb.log(data)
        else:
            wandb.log(data, step=step)
            self.step = step

    def watch_model(self, *args, **kwargs):
        wandb.watch(*args, **kwargs)

    def log_image(self, figs):
        if self.step is None:
            wandb.log(figs)
        else:
            wandb.log(figs, step=self.step)

    def finish(self):
        wandb.finish(quiet=True)

    def load(self, net):
        path = os.path.join(self.config['path_data'], self.config['path_ckpt'], self.config['file_ckpt'])
        net.load_state_dict(torch.load(path))
        print(f'load {path}')

    def save(self, net, file_name=None):
        path_ckpt = os.path.join(self.config['path_data'], self.config['path_ckpt'])
        if not os.path.exists(path_ckpt):
            os.makedirs(path_ckpt)
            print(f'{path_ckpt} created!')

        path = os.path.join(path_ckpt, file_name)
        torch.save(net.state_dict(), path)

    def watch(self, model, log):
        wandb.watch(model, log=log)
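
# Minimal usage sketch (illustrative; the project/entity names are
# placeholders): the constructor accepts either a dict or an attribute-style
# config, and log() follows the signature defined above.
#
#     logger = wandb_logger({'project': 'eeg-retrieval', 'name': 'run-0',
#                            'entity': 'my-team'})
#     logger.log({'train_loss': 0.5}, step=0)
#     logger.finish()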