Diff of /evaluate.py [000000] .. [dc40d0]

Switch to unified view

a b/evaluate.py
1
"""
2
 Copyright (c) 2022, salesforce.com, inc.
3
 All rights reserved.
4
 SPDX-License-Identifier: BSD-3-Clause
5
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
"""
7
8
import argparse
9
import random
10
import os
11
import numpy as np
12
import torch
13
import torch.backends.cudnn as cudnn
14
15
import lavis.tasks as tasks
16
from lavis.common.config import Config
17
from lavis.common.dist_utils import get_rank, init_distributed_mode
18
from lavis.common.logger import setup_logger
19
from lavis.common.optims import (
20
    LinearWarmupCosineLRScheduler,
21
    LinearWarmupStepLRScheduler,
22
)
23
from lavis.common.utils import now
24
25
# imports modules for registration
26
from lavis.datasets.builders import *
27
from lavis.models import *
28
from lavis.processors import *
29
from lavis.runners.runner_base import RunnerBase
30
from lavis.tasks import *
31
32
33
def parse_args():
34
    parser = argparse.ArgumentParser(description="Training")
35
36
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
37
    parser.add_argument(
38
        "--options",
39
        nargs="+",
40
        help="override some settings in the used config, the key-value pair "
41
        "in xxx=yyy format will be merged into config file (deprecate), "
42
        "change to --cfg-options instead.",
43
    )
44
45
    parser.add_argument("--local-rank", type=int)
46
    args = parser.parse_args()
47
    if 'LOCAL_RANK' not in os.environ:
48
        os.environ['LOCAL_RANK'] = str(args.local_rank)
49
50
    return args
51
52
53
def setup_seeds(config):
54
    seed = config.run_cfg.seed + get_rank()
55
56
    random.seed(seed)
57
    np.random.seed(seed)
58
    torch.manual_seed(seed)
59
60
    cudnn.benchmark = False
61
    cudnn.deterministic = True
62
63
64
def main():
65
    # allow auto-dl completes on main process without timeout when using NCCL backend.
66
    # os.environ["NCCL_BLOCKING_WAIT"] = "1"
67
68
    # set before init_distributed_mode() to ensure the same job_id shared across all ranks.
69
    job_id = now()
70
71
    cfg = Config(parse_args())
72
73
    init_distributed_mode(cfg.run_cfg)
74
75
    setup_seeds(cfg)
76
77
    # set after init_distributed_mode() to only log on master.
78
    setup_logger()
79
80
    cfg.pretty_print()
81
82
    task = tasks.setup_task(cfg)
83
    datasets = task.build_datasets(cfg)
84
    model = task.build_model(cfg)
85
86
    runner = RunnerBase(
87
        cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
88
    )
89
    runner.evaluate(skip_reload=True)
90
91
92
if __name__ == "__main__":
93
    main()
94