Diff of /trainer.py [000000] .. [dc40d0]

import os
from typing import List, Dict, Type
import math
import torch
from torch.optim import Optimizer
import transformers
import torch.nn as nn
from PathBLIP.dataset import ImageTextContrastiveCollator
from tqdm import tqdm

WEIGHTS_NAME = "pytorch_model.bin"


class Trainer:
    '''Trainer for image-text contrastive pretraining.

    Runs one process per GPU via torch.nn.parallel.DistributedDataParallel.
    '''
    def __init__(self, args=None):
        pass

    def train(self,
        model,
        train_dataset,
        eval_dataset,
        local_rank,
        epochs: int = 1,
        scheduler: str = 'WarmupCosine',
        warmup_steps: int = 10000,
        warmup_ratio: float = 0.01,
        output_path: str = './checkpoints/vision_text_pretrain',
        optimizer_class: Type[Optimizer] = torch.optim.AdamW,
        optimizer_params: Dict[str, object] = {'lr': 2e-5},
        weight_decay: float = 0.01,
        max_grad_norm: float = 1,
        use_amp: bool = False,
        accumulation_steps: int = 1,
        ):
        '''Run distributed training and save a checkpoint after every epoch.

        output_path: directory where per-epoch checkpoints are written.
        Note: `eval_dataset` is accepted for API compatibility but not used yet,
        and `warmup_steps` is recomputed below from `warmup_ratio`.
        '''
        self.accumulation_steps = accumulation_steps
        if use_amp:
            from torch.cuda.amp import autocast
            scaler = torch.cuda.amp.GradScaler()
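            # GradScaler scales the loss before backward() so small fp16 gradients do not
            # underflow, then unscales them before the optimizer step and skips the step
            # entirely when inf/NaN gradients are detected.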

        train_collate_fn = ImageTextContrastiveCollator()

        model = torch.nn.parallel.DistributedDataParallel(
            model.cuda(), device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=False, pin_memory=True, num_workers=4,
                                                 batch_size=36, drop_last=True, collate_fn=train_collate_fn, sampler=train_sampler)
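        # DistributedSampler shards the dataset across processes, so the effective global
        # batch size is 36 * world_size. shuffle=False is required when a sampler is given;
        # per-epoch shuffling is handled by train_sampler.set_epoch() in the loop below.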

        steps_per_epoch = len(dataloader)
        # Count optimizer updates, not micro-batches: one update every `accumulation_steps` iterations.
        num_train_steps = int(steps_per_epoch * epochs / accumulation_steps)
        # The `warmup_steps` argument is overridden here: warm up for `warmup_ratio` of all
        # update steps (1% with the default warmup_ratio=0.01).
        warmup_steps = math.ceil(num_train_steps * warmup_ratio)

        # Prepare optimizers
        param_optimizer = list(model.named_parameters())

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
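        # Following the usual Transformer fine-tuning convention, biases and LayerNorm
        # parameters are excluded from weight decay; everything else uses `weight_decay`.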

        optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
        scheduler = self._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps)
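        # The scheduler is stepped once per optimizer update, so warmup_steps and t_total
        # are measured in update steps (see num_train_steps above).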

        skip_scheduler = False
        optimizer.zero_grad()
        for epoch in range(epochs):
            # Re-seed the distributed sampler so every epoch sees a different shuffling order.
            train_sampler.set_epoch(epoch)
            model.train()
            data_iterator = iter(dataloader)
            with tqdm(total=steps_per_epoch, disable=local_rank != 0) as pbar:
                for train_iter in range(steps_per_epoch):
                    data = next(data_iterator)
                    data['image'] = data['image'].cuda()
                    if use_amp:
                        with autocast():
                            loss = model(data)
                        loss_value = loss['loss'] / self.accumulation_steps
                        scale_before_step = scaler.get_scale()
                        scaler.scale(loss_value).backward()
                    else:
                        loss = model(data)
                        loss_value = loss['loss'] / self.accumulation_steps
                        loss_value.backward()

                    # Gradients accumulate across iterations; clip and step only every
                    # `accumulation_steps` micro-batches.
                    if (train_iter + 1) % self.accumulation_steps == 0:
                        if use_amp:
                            scaler.unscale_(optimizer)
                            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                            scaler.step(optimizer)
                            scaler.update()
                            # If the scale changed, the optimizer step may have been skipped
                            # (inf/NaN gradients), so skip the scheduler step as well.
                            skip_scheduler = scaler.get_scale() != scale_before_step
                        else:
                            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                            optimizer.step()
                        optimizer.zero_grad()
                        if not skip_scheduler:
                            scheduler.step()

                    pbar.set_description(f"Epoch {epoch+1}/{epochs}, Batch {train_iter}/{steps_per_epoch} - Loss: {loss['loss'].item():.4f}")
                    pbar.update(1)  # advance the progress bar

            if local_rank == 0:
                self._save_ckpt(model, epoch, output_path)


    @staticmethod
    def _get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int):
        """
        Returns the requested learning-rate scheduler. Available schedulers:
        constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts.
        """
        scheduler = scheduler.lower()
        if scheduler == 'constantlr':
            return transformers.get_constant_schedule(optimizer)
        elif scheduler == 'warmupconstant':
            return transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
        elif scheduler == 'warmuplinear':
            return transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
        elif scheduler == 'warmupcosine':
            return transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
        elif scheduler == 'warmupcosinewithhardrestarts':
            return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
        else:
            raise ValueError("Unknown scheduler {}".format(scheduler))

    def _save_ckpt(self, model, epoch, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        # Unwrap DistributedDataParallel so checkpoint keys carry no 'module.' prefix.
        model_to_save = model.module if hasattr(model, 'module') else model
        state_dict = model_to_save.state_dict()
        save_path = os.path.join(save_dir, 'epoch{}.pth'.format(epoch))
        torch.save(state_dict, save_path)
        print("Save: ", save_path)