Diff of /pretraining/train.py [000000] .. [4abb48]

"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import argparse
import os
import random

import numpy as np
import pickle
import torch
import torch.backends.cudnn as cudnn
import wandb

from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm


import model.lavis.tasks as tasks
from model.lavis.common.config import Config
from model.lavis.common.dist_utils import get_rank
from model.lavis.common.logger import setup_logger

from local_config import WANDB_ENTITY
from model.lavis.common.registry import registry
from model.lavis.common.utils import now

# imports modules for registration
from model.lavis.common.optims import (
    LinearWarmupCosineLRScheduler,
    LinearWarmupStepLRScheduler,
)
from model.lavis.datasets.builders import *
from model.lavis.models import *
from model.lavis.processors import *
from model.lavis.runners import *
from model.lavis.tasks import *
from model.lavis.data.ReportDataset import MIMIC_CXR_Dataset
from local_config import PATH_TO_MIMIC_CXR

def parse_args():
    parser = argparse.ArgumentParser(description="Training")

    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override settings in the config: key-value pairs in xxx=yyy format "
             "are merged into the loaded config (deprecated; use --cfg-options instead).",
    )

    args = parser.parse_args()
    # if 'LOCAL_RANK' not in os.environ:
    #     os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args

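# Example invocation (the config filename is illustrative; point --cfg-path at
# the pretraining config for your run):
#   python pretraining/train.py --cfg-path pretraining/configs/<config>.yaml
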
def setup_seeds(config):
    # offset the base seed by the process rank so that each worker
    # in distributed training draws a distinct random stream
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True

def get_runner_class(cfg):
    """
    Get runner class from config. Default to the epoch-based runner.
    """
    runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base"))

    return runner_cls

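# NOTE: main() below instantiates RunnerBase directly; get_runner_class is the
# hook for selecting a different runner via the "runner" key in run_cfg.
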
def main():
    registry.mapping['paths']['cache_root'] = '.'
    cfg = Config(parse_args())

    job_id = now()

    # init_distributed_mode(cfg)
    setup_seeds(cfg)

    # set after init_distributed_mode() to only log on master.
    setup_logger()

    wandb_run = wandb.init(
        project=cfg.run_cfg.project_name,
        entity=WANDB_ENTITY,
        name=cfg.run_cfg.run_name
    )

    cfg.pretty_print()

    task = tasks.setup_task(cfg)

    # build the MIMIC-CXR report dataset splits directly (instead of going
    # through the LAVIS dataset builders); 'train_val' is the training split
    # truncated to 1000 samples
    datasets = {}
    datasets['mimic_cxr'] = {}
    vis_root = f"{PATH_TO_MIMIC_CXR}/mimic-cxr-jpg/2.0.0"
    datasets['mimic_cxr']['train'] = MIMIC_CXR_Dataset(vis_processor=None, text_processor=None, vis_root=vis_root,
                                                       split="train", cfg=cfg, truncate=None)
    datasets['mimic_cxr']['train_val'] = MIMIC_CXR_Dataset(vis_processor=None, text_processor=None, vis_root=vis_root,
                                                           split="train", cfg=cfg, truncate=1000)
    datasets['mimic_cxr']['val'] = MIMIC_CXR_Dataset(vis_processor=None, text_processor=None, vis_root=vis_root,
                                                     split="validate", cfg=cfg, truncate=None)
    datasets['mimic_cxr']['test'] = MIMIC_CXR_Dataset(vis_processor=None, text_processor=None, vis_root=vis_root,
                                                      split="test", cfg=cfg, truncate=None)

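    # each dataset item is expected to yield at least an 'image' tensor and an
    # 'image_id' (used in the evaluate branch below to map embeddings back to
    # DICOM ids)
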
    model = task.build_model(cfg)
    # with input_size=None, torchinfo only reports the module tree and parameter counts
    print(summary(model, input_size=None, device='cpu'))

    if not cfg.run_cfg.evaluate:
        # training: hand task, model, and datasets to the epoch-based runner
        runner = RunnerBase(
            cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
        )

        runner.train(wandb_run)

    else:
        # precompute Q-Former output embeddings for every image in each split
        model.cuda()
        model.eval()

        os.makedirs("pretraining/embs", exist_ok=True)
        for split in ['test', 'val', 'train']:
            dataloader = DataLoader(datasets['mimic_cxr'][split], batch_size=256, shuffle=False,
                                    num_workers=cfg.run_cfg.num_workers)
            embeddings = {}
            with torch.no_grad():
                for batch in tqdm(dataloader):
                    qformer_embs, _ = model.forward_image(batch['image'].cuda())
                    # key each embedding by its DICOM identifier
                    for j, img_id in enumerate(batch['image_id']):
                        dicom = datasets['mimic_cxr'][split].id_to_dicom[img_id.item()]
                        embeddings[dicom] = qformer_embs[j].cpu().numpy()

            # save embeddings for this split
            with open(f"pretraining/embs/{cfg.run_cfg.run_name}_embeddings_{split}.pkl", "wb") as f:
                pickle.dump(embeddings, f)

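# The embedding pickles written in the evaluate branch can be reloaded as
# {dicom_id: np.ndarray} dictionaries, e.g. (path and run name illustrative):
#   with open("pretraining/embs/<run_name>_embeddings_test.pkl", "rb") as f:
#       embeddings = pickle.load(f)
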
if __name__ == "__main__":
    main()