--- a
+++ b/trainer.py
@@ -0,0 +1,134 @@
+import os
+from typing import List, Dict, Type
+import math
+import torch
+from torch.optim import Optimizer
+import transformers
+import torch.nn as nn
+from PathBLIP.dataset import ImageTextContrastiveCollator
+from tqdm import tqdm
+WEIGHTS_NAME = "pytorch_model.bin"  # NOTE(review): not referenced anywhere in this file; presumably a weights file name kept for external importers - confirm
+
+class Trainer:
+    '''trainer for single-gpu training.
+    '''
+    def __init__(self, args=None):
+        pass
+
+    def train(self,
+        model,
+        train_dataset,
+        eval_dataset,
+        local_rank,
+        epochs: int = 1,
+        scheduler: str = 'WarmupCosine',
+        warmup_steps: int = 10000,
+        warmup_ratio: float = 0.01,
+        output_path: str = './checkpoints/vision_text_pretrain',
+        optimizer_class: Type[Optimizer] = torch.optim.AdamW,
+        optimizer_params : Dict[str, object]= {'lr': 2e-5},
+        weight_decay: float = 0.01,
+        max_grad_norm: float = 1,
+        use_amp: bool = False,
+        accumulation_steps: int = 1,
+        ):
+        '''
+        output_path: model save path
+        checkpoint_path: model load and continue to learn path
+        '''
+        self.accumulation_steps = accumulation_steps
+        if use_amp:
+            from torch.cuda.amp import autocast
+            scaler = torch.cuda.amp.GradScaler()
+
+        train_collate_fn = ImageTextContrastiveCollator()        
+        
+        model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[local_rank], output_device=local_rank,find_unused_parameters=True)
+        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+        dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=False, pin_memory=True, num_workers=4,
+                                                   batch_size=36, drop_last=True, collate_fn=train_collate_fn, sampler=train_sampler)
+
+        steps_per_epoch = len(dataloader)
+        num_train_steps = int((steps_per_epoch) * epochs)
+        warmup_steps = math.ceil(num_train_steps * warmup_ratio) #10% of train data for warm-up
+
+        # Prepare optimizers
+        param_optimizer = list(model.named_parameters())
+
+        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+        optimizer_grouped_parameters = [
+            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
+            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
+        ]
+
+        optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
+        scheduler = self._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps, t_total=num_train_steps)
+
+
+        skip_scheduler = False
+        for epoch in range(epochs):
+            data_iterator = iter(dataloader)
+            with tqdm(total=steps_per_epoch, disable= local_rank != 0) as pbar:
+                for train_iter in range(steps_per_epoch):
+                    model.zero_grad()
+                    model.train()              
+                    data = next(data_iterator)
+                    data['image'] = data['image'].cuda()
+                    if use_amp:
+                        with autocast():
+                            loss = model(data)
+                        loss_value = loss['loss']
+                        scale_before_step = scaler.get_scale()
+                        scaler.scale(loss_value).backward()
+                        scaler.unscale_(optimizer)
+                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
+                        scaler.step(optimizer)
+                        scaler.update()
+                        skip_scheduler = scaler.get_scale() != scale_before_step
+                    else:
+                        loss = model(data)
+                        loss_value = loss['loss'] / self.accumulation_steps
+                        loss_value.backward()
+                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
+                        optimizer.step()
+                    # pbar.set_postfix('Epoch[{}/{}]/Iter[{}/{}]: loss: {:.4f}'.format(epoch,epochs,train_iter,steps_per_epoch,loss_value))  # 输入一个字典,显示实验指标
+                    # tqdm.write('Epoch[{}/{}]/Iter[{}/{}]: loss: {:.4f}'.format(epoch,epochs,train_iter,steps_per_epoch,loss_value), end="\r")
+                    pbar.set_description(f"Epoch {epoch+1}/{epochs}, Batch {train_iter}/{steps_per_epoch} - Loss: {loss_value:.4f}")
+                    pbar.update(1)  # 更新进度条
+                    # pbar.update(1)
+                    # print('Epoch[{}/{}]/Iter[{}/{}]: loss: {:.4f}'.format(epoch,epochs,train_iter,steps_per_epoch,loss_value))
+                    
+
+                    optimizer.zero_grad()
+
+                    if not skip_scheduler:
+                        scheduler.step()
+            if local_rank == 0:        
+                self._save_ckpt(model,epoch,output_path)
+
+
+    @staticmethod
+    def _get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int):
+        """
+        Returns the correct learning rate scheduler. Available scheduler: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
+        """
+        scheduler = scheduler.lower()
+        if scheduler == 'constantlr':
+            return transformers.get_constant_schedule(optimizer)
+        elif scheduler == 'warmupconstant':
+            return transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
+        elif scheduler == 'warmuplinear':
+            return transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
+        elif scheduler == 'warmupcosine':
+            return transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
+        elif scheduler == 'warmupcosinewithhardrestarts':
+            return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
+        else:
+            raise ValueError("Unknown scheduler {}".format(scheduler))
+
+    def _save_ckpt(self, model, epoch, save_dir):
+        if not os.path.exists(save_dir): 
+            os.makedirs(save_dir)
+        state_dict = model.state_dict()
+        torch.save(state_dict, os.path.join(save_dir, 'epoch{}.pth'.format(epoch)))
+        print("Save: ", os.path.join(save_dir, 'epoch{}.pth'.format(epoch)))
\ No newline at end of file