--- /dev/null
+++ b/train_blip2qformer.py
@@ -0,0 +1,90 @@
+import os
+import random
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from PathBLIP.dataset import Quiltdataset, ImageTextContrastiveCollator
+from lavis.models import load_model
+from trainer import Trainer
+
+# set random seeds for reproducibility
+seed = 42
+random.seed(seed)
+np.random.seed(seed)
+torch.manual_seed(seed)
+torch.cuda.manual_seed_all(seed)  # seed all GPUs, not just the current device
+os.environ['PYTHONHASHSEED'] = str(seed)  # only affects subprocesses, not this interpreter
+os.environ['TOKENIZERS_PARALLELISM'] = 'false'
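+# Optional (assumption: full determinism is preferred over speed); cuDNN
+# autotuning is nondeterministic, so it can be pinned down explicitly:
+# torch.backends.cudnn.deterministic = True
+# torch.backends.cudnn.benchmark = False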
+
+# GPU visibility can be restricted before launch, e.g.:
+# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
+
+
+train_config = {
+    'num_epochs': 20,
+    'warmup': 0.1,
+    'lr': 2e-5,
+    'weight_decay': 1e-4,
+    'eval_batch_size': 8,
+    'eval_steps': 1000,
+    'save_steps': 1000,
+}
+
+train_dataset = Quiltdataset("../BLIP/LAVIS-main/quilt.csv")
+train_collate_fn = ImageTextContrastiveCollator()
+train_dataloader = DataLoader(
+    train_dataset,
+    batch_size=8,
+    collate_fn=train_collate_fn,
+    shuffle=True,
+    pin_memory=True,
+    num_workers=4,
+    drop_last=True,
+)
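+# NOTE: shuffle=True does not shard data across ranks; under DDP a
+# DistributedSampler is normally used. Trainer is assumed to handle
+# distributed sampling from the datasets passed to it below.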
+
+val_dataset = Quiltdataset("../test_samples.csv")
+val_collate_fn = ImageTextContrastiveCollator()
+
+val_dataloader = DataLoader(
+    val_dataset,
+    batch_size=4,
+    collate_fn=val_collate_fn,
+    shuffle=False,
+    pin_memory=True,
+    num_workers=4,
+)
+
+# distributed setup: launch with torchrun, which sets LOCAL_RANK for each process
+torch.distributed.init_process_group(backend='nccl')
+local_rank = int(os.environ.get('LOCAL_RANK', 0))  # local rank, not global rank, for device placement
+torch.cuda.set_device(local_rank)
+
+model = load_model("blip2", "pretrain", checkpoint="../BLIP/blip2_pretrained.pth")
+# model.load_state_dict(torch.load('./checkpoints/vision_text_pretrain/t5/epoch10.pth', map_location='cpu'), strict=False)
+model.cuda()  # moves to the device selected by set_device above
+model_save_path = '../BLIP/LAVIS-main/checkpoints/VL'
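+os.makedirs(model_save_path, exist_ok=True)  # assumption: Trainer does not create the output dir itself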
+trainer = Trainer()
+trainer.train(
+    model,
+    train_dataset,
+    val_dataset,
+    local_rank,
+    warmup_ratio=train_config['warmup'],
+    epochs=train_config['num_epochs'],
+    optimizer_params={'lr': train_config['lr']},
+    output_path=model_save_path,
+    weight_decay=train_config['weight_decay'],
+    use_amp=True,
+)
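+
+# clean shutdown of the NCCL process group once training completes
+torch.distributed.destroy_process_group()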