Diff of /Generation/utils.py [000000] .. [a4067e]

--- a
+++ b/Generation/utils.py
@@ -0,0 +1,289 @@
+import numpy as np
+import math
+import torch
+import os
+import sys
+import time
+from torch import inf
+import wandb
+
+class NativeScaler:
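+    """Loss scaler wrapping torch.cuda.amp.GradScaler: scales the loss, backpropagates,
+    optionally clips gradients, and steps the optimizer."""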
+    state_dict_key = "amp_scaler"
+
+    def __init__(self):
+        self._scaler = torch.cuda.amp.GradScaler()
+
+    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+        self._scaler.scale(loss).backward(create_graph=create_graph)
+        if update_grad:
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+            else:
+                self._scaler.unscale_(optimizer)
+                norm = get_grad_norm_(parameters)
+            self._scaler.step(optimizer)
+            self._scaler.update()
+        else:
+            norm = None
+        return norm
+
+    def state_dict(self):
+        return self._scaler.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self._scaler.load_state_dict(state_dict)
+        
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0):
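+    """Return the total gradient norm of `parameters` (L2 by default; inf-norm supported)."""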
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = [p for p in parameters if p.grad is not None]
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return torch.tensor(0.)
+    device = parameters[0].grad.device
+    if norm_type == inf:
+        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+    else:
+        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+    return total_norm
+        
+def train_one_epoch(model, data_loader, optimizer, device, epoch, 
+                        loss_scaler, log_writer=None, config=None, start_time=None, model_without_ddp=None, 
+                        img_feature_extractor=None, preprocess=None):
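+    """Train `model` for one epoch with AMP, per-iteration cosine LR scheduling, and gradient clipping."""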
+    model.train(True)
+    optimizer.zero_grad()
+    total_loss = []
+    total_cor = []
+    accum_iter = config.accum_iter
+    for data_iter_step, data_dict in enumerate(data_loader):
+
+        # we use a per-iteration (instead of per-epoch) lr scheduler
+        if data_iter_step % accum_iter == 0:
+            adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, config)
+        samples = data_dict['eeg']
+        
+        img_features = None
+        valid_idx = None
+        if img_feature_extractor is not None:
+            images = data_dict['image']
+            valid_idx = torch.nonzero(images.sum(dim=(1,2,3)) != 0).squeeze(1)
+            img_feature_extractor.eval()
+            with torch.no_grad():
+                img_features = img_feature_extractor(preprocess(images[valid_idx]).to(device))['layer2']
+        samples = samples.to(device)
+        # img_features = img_features.to(device)
+
+        optimizer.zero_grad()
+        with torch.cuda.amp.autocast(enabled=True):
+            loss, pred, _ = model(samples, img_features, valid_idx=valid_idx, mask_ratio=config.mask_ratio)
+        # loss.backward()
+        # norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip_grad)
+        # optimizer.step()
+
+        loss_value = loss.item()
+
+        if not math.isfinite(loss_value):
+            print(f"Loss is {loss_value}, stopping training at step {data_iter_step} epoch {epoch}")
+            sys.exit(1)
+
+        # loss /= accum_iter
+        loss_scaler(loss, optimizer, parameters=model.parameters(), clip_grad=config.clip_grad)
+
+        # compute the correlation between the reconstruction and the input signal
+        pred = pred.to('cpu').detach()
+        samples = samples.to('cpu').detach()
+        pred = model_without_ddp.unpatchify(pred)
+        cor = torch.mean(torch.tensor([
+            torch.corrcoef(torch.cat([p[0].unsqueeze(0), s[0].unsqueeze(0)], dim=0))[0, 1]
+            for p, s in zip(pred, samples)
+        ])).item()
+        optimizer.zero_grad()
+
+        total_loss.append(loss_value)
+        total_cor.append(cor)
+        if device == torch.device('cuda:0'):
+            lr = optimizer.param_groups[0]["lr"]
+            print('train_loss_step:', np.mean(total_loss), 'lr:', lr, 'cor', np.mean(total_cor))
+
+    if log_writer is not None:
+        lr = optimizer.param_groups[0]["lr"]
+        log_writer.log('train_loss_step', np.mean(total_loss), step=epoch)
+        log_writer.log('lr', lr, step=epoch)
+        log_writer.log('cor', np.mean(total_cor), step=epoch)
+        if start_time is not None:
+            log_writer.log('time (min)', (time.time() - start_time)/60.0, step=epoch)
+    if config.local_rank == 0:        
+        print(f'[Epoch {epoch}] loss: {np.mean(total_loss)}')
+
+    return np.mean(total_cor)
+
+def get_1d_sincos_pos_embed(embed_dim, length, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_l = np.arange(length, dtype=float)
+
+    grid_l = grid_l.reshape([1, length])
+    pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_l)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=float)
+    omega /= embed_dim / 2.
+    omega = 1. / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out) # (M, D/2)
+    emb_cos = np.cos(out) # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+def interpolate_pos_embed(model, checkpoint_model):
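+    """Resize the checkpoint's 1D positional embedding to match the model's number of patches."""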
+    if 'pos_embed' in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model['pos_embed']
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches # cls token
+        # height (== width) for the checkpoint position embedding
+        orig_size = int(pos_embed_checkpoint.shape[-2] - num_extra_tokens)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print("Position interpolate from %d to %d" % (orig_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, embedding_size).permute(0, 2, 1)
+            pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=new_size)
+            pos_tokens = pos_tokens.permute(0, 2, 1)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model['pos_embed'] = new_pos_embed
+
+
+def adjust_learning_rate(optimizer, epoch, config):
+    """Decay the learning rate with half-cycle cosine after warmup"""
+    if epoch < config.warmup_epochs:
+        lr = config.lr * epoch / config.warmup_epochs 
+    else:
+        lr = config.min_lr + (config.lr - config.min_lr) * 0.5 * \
+            (1. + math.cos(math.pi * (epoch - config.warmup_epochs) / (config.num_epoch - config.warmup_epochs)))
+    for param_group in optimizer.param_groups:
+        if "lr_scale" in param_group:
+            param_group["lr"] = lr * param_group["lr_scale"]
+        else:
+            param_group["lr"] = lr
+    return lr
+
+
+def save_model(config, epoch, model, optimizer, loss_scaler):
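+    """Save model, optimizer, scaler state, epoch, and config to a single checkpoint file."""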
+    
+    to_save = {
+        'model': model.state_dict(),
+        'optimizer': optimizer.state_dict(),
+        'epoch': epoch,
+        'scaler': loss_scaler.state_dict(),
+        'config': config,
+    }
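+    # NOTE: the checkpoint is written to a hard-coded absolute path; adjust for your environment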
+    torch.save(to_save, '/home/weichen/projects/shiyin/DreamDiffusion.pth')
+    
+
+def load_model(config, model, checkpoint_path):
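+    """Load model weights from `checkpoint_path` (mapped to CPU) into `model`."""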
+    checkpoint = torch.load(checkpoint_path, map_location='cpu')
+    model.load_state_dict(checkpoint['model'])
+    print(f'Model loaded with {checkpoint_path}')
+    
+
+def patchify(imgs, patch_size):
+    """
+    imgs: (N, 1, num_voxels)
+    x: (N, L, patch_size)
+    """
+    p = patch_size
+    assert imgs.ndim == 3 and imgs.shape[2] % p == 0
+
+    h = imgs.shape[2] // p
+    x = imgs.reshape(shape=(imgs.shape[0], h, p))
+    return x
+
+def unpatchify(x, patch_size):
+    """
+    x: (N, L, patch_size)
+    imgs: (N, 1, num_voxels)
+    """
+    p = patch_size
+    h = x.shape[1]
+    
+    imgs = x.reshape(shape=(x.shape[0], 1, h * p))
+    return imgs
+
+class wandb_logger:
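+    """Thin wrapper around Weights & Biases for logging scalars, images, and checkpoints."""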
+    def __init__(self, config):
+        wandb.init(
+            # Set the project where this run will be logged
+            project=config['project'],
+            name=config['name'],
+            config=config,
+            entity=config['entity'],            
+            )
+
+
+        self.config = config
+        self.step = None
+    
+    def log(self, name, data, step=None):
+        # wandb.log expects a dict; callers pass a metric name and a scalar value
+        if step is None:
+            wandb.log({name: data})
+        else:
+            wandb.log({name: data}, step=step)
+            self.step = step
+    
+    def watch_model(self, *args, **kwargs):
+        wandb.watch(*args, **kwargs)
+
+    def log_image(self, figs):
+        if self.step is None:
+            wandb.log(figs)
+        else:
+            wandb.log(figs, step=self.step)
+
+    def finish(self):
+        wandb.finish(quiet=True)
+
+    def load(self, net):
+        path = os.path.join(self.config['path_data'], self.config['path_ckpt'], self.config['file_ckpt'])
+        net.load_state_dict(torch.load(path))
+        print(f'load {path}')
+
+    def save(self, net, file_name=None):
+        path_ckpt = os.path.join(self.config['path_data'], self.config['path_ckpt'])
+        if not os.path.exists(path_ckpt):
+            os.makedirs(path_ckpt)
+            print(f'{path_ckpt} created!')
+
+        path = os.path.join(path_ckpt, file_name)
+        torch.save(net.state_dict(), path)
+
+    def watch(self, model, log):
+        wandb.watch(model, log=log)
\ No newline at end of file