Diff of /train.py [000000] .. [dc40d0]

Switch to unified view

a b/train.py
1
"""
2
 Copyright (c) 2022, salesforce.com, inc.
3
 All rights reserved.
4
 SPDX-License-Identifier: BSD-3-Clause
5
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
"""
7
8
import argparse
9
import os
10
import random
11
12
import numpy as np
13
import torch
14
import torch.backends.cudnn as cudnn
15
16
import lavis.tasks as tasks
17
from lavis.common.config import Config
18
from lavis.common.dist_utils import get_rank, init_distributed_mode
19
from lavis.common.logger import setup_logger
20
from lavis.common.optims import (
21
    LinearWarmupCosineLRScheduler,
22
    LinearWarmupStepLRScheduler,
23
)
24
from lavis.common.registry import registry
25
from lavis.common.utils import now
26
27
# imports modules for registration
28
from lavis.datasets.builders import *
29
from lavis.models import *
30
from lavis.processors import *
31
from lavis.runners import *
32
from lavis.tasks import *
33
from PathBLIP.dataset import Quiltdataset
34
35
def parse_args():
    """Parse the command-line arguments of the training script.

    Recognised options:
        --cfg-path:   (required) path to the configuration file.
        --options:    optional list of ``key=value`` overrides.
        --local-rank: optional integer local rank; the underscore spelling
                      ``--local_rank`` is accepted as an alias.

    Side effect:
        If ``LOCAL_RANK`` is not already present in the environment and a
        local rank was supplied on the command line, it is exported as
        ``os.environ["LOCAL_RANK"]``. When no local rank was supplied the
        environment is left untouched, so the literal string ``"None"`` is
        never exported.

    Returns:
        argparse.Namespace: the parsed arguments, with attributes
        ``cfg_path``, ``options`` and ``local_rank``.
    """
    parser = argparse.ArgumentParser(description="Training")

    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    # Launchers differ on the spelling of this flag (hyphen vs. underscore),
    # so both are accepted and stored under the same destination.
    parser.add_argument("--local-rank", "--local_rank", dest="local_rank", type=int)
    args = parser.parse_args()

    # Only fill in LOCAL_RANK when the launcher has not set it AND a rank was
    # really given; str(None) would otherwise leak "None" into the environment.
    if "LOCAL_RANK" not in os.environ and args.local_rank is not None:
        os.environ["LOCAL_RANK"] = str(args.local_rank)

    return args
52
53
54
def setup_seeds(config):
    """Seed the random number generators used by this script.

    The seed is ``config.run_cfg.seed`` plus the value returned by
    ``get_rank()``, and is applied to Python's ``random``, NumPy and PyTorch.
    cuDNN is additionally switched to deterministic mode with benchmarking
    turned off.

    Args:
        config: configuration object exposing ``run_cfg.seed``.
    """
    base_seed = config.run_cfg.seed
    rank_seed = base_seed + get_rank()

    # Apply the same per-rank seed to every RNG, in the original order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(rank_seed)

    cudnn.deterministic = True
    cudnn.benchmark = False
63
64
65
def get_runner_class(cfg):
    """Return the runner class selected by the run configuration.

    The registry key is read from ``cfg.run_cfg["runner"]``; when that key is
    missing, ``"runner_base"`` is used (described by the original author as
    the epoch-based runner).
    """
    runner_key = cfg.run_cfg.get("runner", "runner_base")
    return registry.get_runner_class(runner_key)
72
73
74
def main():
    """Run a full training job from the command line.

    Steps, in order: take a job id, parse arguments into a ``Config``,
    initialise distributed mode, seed the RNGs, set up logging, print the
    config, build the task, its datasets and its model, then construct the
    configured runner and call ``train()`` on it.
    """
    # Optional knob kept from the original (disabled): setting
    # os.environ["NCCL_BLOCKING_WAIT"] = "1" lets auto-download complete on
    # the main process without a timeout when using the NCCL backend.

    # Taken before init_distributed_mode() to ensure the same job_id is
    # shared across all ranks.
    job_id = now()

    cli_args = parse_args()
    cfg = Config(cli_args)

    init_distributed_mode(cfg.run_cfg)
    setup_seeds(cfg)

    # Configured after init_distributed_mode() so that only the master logs.
    setup_logger()
    cfg.pretty_print()

    task = tasks.setup_task(cfg)
    # NOTE: a disabled alternative in the original built the data directly
    # via Quiltdataset('../test_samples.csv') instead of the task builder.
    datasets = task.build_datasets(cfg)
    model = task.build_model(cfg)

    runner_cls = get_runner_class(cfg)
    runner = runner_cls(
        cfg=cfg,
        job_id=job_id,
        task=task,
        model=model,
        datasets=datasets,
    )
    runner.train()
101
102
103
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()