# train_blip2qformer.py — BLIP-2 Q-Former contrastive pre-training script
1
import pdb, os
2
import random
3
import argparse
4
5
import numpy as np
6
import torch
7
from torch.utils.data import DataLoader
8
from PathBLIP.dataset import Quiltdataset, ImageTextContrastiveCollator
9
from lavis.models import load_model
10
from trainer import Trainer
11
12
# set random seed — make the run reproducible across python, numpy and torch
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Seed every visible GPU, not just the current device, so DDP replicas agree.
torch.cuda.manual_seed_all(seed)
# BUG FIX: the original wrote 'PYTHONASHSEED' (typo); the variable Python
# actually reads is PYTHONHASHSEED. NOTE(review): set here it only affects
# child processes (e.g. DataLoader workers) — this interpreter's hash seed
# was fixed at startup.
os.environ['PYTHONHASHSEED'] = str(seed)
# Avoid HuggingFace tokenizers' fork warning / potential deadlock once
# DataLoader worker processes are spawned.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
20
21
# set cuda devices
22
# os.environ['CUDA_VISIBLE_DEVICES']='0,1,2,3'
23
# device = "cuda:0,1,2,3" if torch.cuda.is_available() else "cpu"
24
25
26
# Hyper-parameters for the Q-Former pre-training run.
train_config = dict(
    num_epochs=20,        # total passes over the training set
    warmup=0.1,           # fraction of steps used for LR warmup
    lr=2e-5,              # peak learning rate
    weight_decay=1e-4,    # decoupled weight decay
    eval_batch_size=8,    # batch size used during evaluation
    eval_steps=1000,      # evaluate every N steps
    save_steps=1000,      # checkpoint every N steps
)
35
36
# Training data: Quilt image/text pairs batched with the contrastive collator.
train_dataset = Quiltdataset("../BLIP/LAVIS-main/quilt.csv")
train_collate_fn = ImageTextContrastiveCollator()
train_dataloader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,               # reshuffle every epoch
    num_workers=4,
    collate_fn=train_collate_fn,
    pin_memory=True,            # faster host-to-GPU transfers
    drop_last=True,             # keep the batch size constant for the contrastive loss
)
46
47
# Validation data: fixed order, partial final batch kept.
val_dataset = Quiltdataset("../test_samples.csv")
val_collate_fn = ImageTextContrastiveCollator()
val_dataloader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=4,
    collate_fn=val_collate_fn,
    pin_memory=True,
)
57
58
# Distributed setup: this script is expected to be launched with
# torchrun / torch.distributed.launch, one process per GPU (NCCL backend).
torch.distributed.init_process_group(backend='nccl')
# FIX: get_rank() returns the *global* rank, which is not a valid local GPU
# index on multi-node runs. Prefer the LOCAL_RANK environment variable that
# torchrun exports, falling back to the global rank (identical behavior on a
# single node, which is what the original code assumed).
local_rank = int(os.environ.get('LOCAL_RANK', torch.distributed.get_rank()))
torch.cuda.set_device(local_rank)
67
68
69
# parser.add_argument("--local-rank", type=int)
70
# args = parser.parse_args()
71
# if 'LOCAL_RANK' not in os.environ:
72
#     os.environ['LOCAL_RANK'] = str(args.local_rank)
73
74
# Load the pre-trained BLIP-2 model from a local checkpoint and move it onto
# the GPU selected for this process.
model = load_model("blip2", "pretrain", checkpoint="../BLIP/blip2_pretrained.pth")
model.cuda()

model_save_path = '../BLIP/LAVIS-main/checkpoints/VL'

# NOTE(review): the DataLoaders built above are never passed in — Trainer.train
# receives the raw datasets, so presumably it builds its own loaders; confirm
# against the Trainer implementation.
trainer = Trainer()
fit_kwargs = dict(
    warmup_ratio=train_config['warmup'],
    epochs=train_config['num_epochs'],
    optimizer_params={'lr': train_config['lr']},
    output_path=model_save_path,
    weight_decay=train_config['weight_decay'],
    use_amp=True,               # mixed-precision training
)
trainer.train(
    model,
    train_dataset,
    val_dataset,
    local_rank,
    **fit_kwargs,
)
91
92
93