Diff of /src/inference.py [000000] .. [95f789]


b/src/inference.py
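"""Inference for the RSNA intracranial hemorrhage classification models.

For each fold, the best checkpoint (lowest validation loss) is loaded, predictions
are averaged over the test-time augmentations, and the result is written out as a
.npy array and a submission-style CSV.
"""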
import glob
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from models import *
from augmentation import *
from dataset import *


device = torch.device('cuda')


def predict(model, loader):
    """Run inference over a loader and return sigmoid probabilities as a (num_samples, num_classes) numpy array."""
    model.eval()
    preds = []
    with torch.no_grad():
        for dct in tqdm(loader, total=len(loader)):
            images = dct['images'].to(device)
            # meta = dct["meta"].to(device)
            pred = model(images)
            pred = torch.sigmoid(pred)
            pred = pred.detach().cpu().numpy()
            preds.append(pred)

    preds = np.concatenate(preds, axis=0)
    return preds


def get_best_checkpoints(checkpoint_dir, n_best=3, minimize_metric=True):
    """Return the top n_best (checkpoint_path, validation_loss) pairs found in checkpoint_dir."""
    files = glob.glob(f"{checkpoint_dir}/checkpoints/best*.pth")
    files = [file for file in files if 'full' not in file]

    top_best_metrics = []
    for file in files:
        # only the stored metric is needed here, so load to CPU
        ckp = torch.load(file, map_location='cpu')
        valid_metric = ckp['valid_metrics']['loss']
        top_best_metrics.append((file, valid_metric))

    top_best_metrics = sorted(
        top_best_metrics,
        key=lambda x: x[1],
        reverse=not minimize_metric
    )
    top_best_metrics = top_best_metrics[:n_best]
    return top_best_metrics


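# Illustrative usage (the loss value is hypothetical): with n_best=1 this returns a
# single-element list of (path, loss) tuples, e.g.
# get_best_checkpoints("/logs/rsna/test/densenet169-mw-512-resume-0", n_best=1)
# -> [("/logs/rsna/test/densenet169-mw-512-resume-0/checkpoints/best.pth", 0.0712)]

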
def predict_test_tta_ckp():
    """Predict on the test set with TTA for each fold's best checkpoint and write .npy predictions plus a submission CSV."""
    test_csv = "./csv/patient2_kfold/test.csv"
    # test_root = "/data/stage_1_test_3w/"
    # test_root = "/data/png/test_stage_1/adjacent-brain-cropped/"
    test_root = "/data/stage_1_test_images_jpg_preprocessing/"
    image_type = 'jpg'

    image_size = [512, 512]
    backbone = "densenet169"
    normalization = True

    for fold in [0, 1, 2, 3, 4]:
        scheme = f"{backbone}-mw-512-resume-{fold}"
        log_dir = f"/logs/rsna/test/{scheme}/"

        with_any = True
        if with_any:
            num_classes = 6
            target_cols = LABEL_COLS
        else:
            num_classes = 5
            target_cols = LABEL_COLS_WITHOUT_ANY

        top_best_metrics = get_best_checkpoints(log_dir, n_best=1, minimize_metric=True)

        test_preds = 0
        for best_metric in top_best_metrics:
            checkpoint_path, checkpoint_metric = best_metric
            print("*" * 50)
            print(f"checkpoint: {checkpoint_path}")
            print(f"Metric: {checkpoint_metric}")

            model = CNNFinetuneModels(
                model_name=backbone,
                num_classes=num_classes,
                pretrained=False
            )

            # load the checkpoint selected above
            checkpoint = torch.load(checkpoint_path)
            model.load_state_dict(checkpoint['model_state_dict'])
            model = nn.DataParallel(model)
            model = model.to(device)

            augs = test_tta(image_size, normalization)

            for name, aug in augs.items():
                print(f"Augmentation: {name}")

                test_dataset = RSNADataset(
                    csv_file=test_csv,
                    root=test_root,
                    with_any=with_any,
                    transform=aug,
                    mode="test",
                    image_type=image_type
                )

                test_loader = DataLoader(
                    dataset=test_dataset,
                    batch_size=64,
                    shuffle=False,
                    num_workers=8,
                )

                # running mean over all TTA variants and selected checkpoints
                test_preds += predict(model, test_loader) / (len(augs) * len(top_best_metrics))

        os.makedirs(f"/logs/prediction/{scheme}", exist_ok=True)
        np.save(f"/logs/prediction/{scheme}/test_{fold}_ckp_tta.npy", test_preds)

        test_df = pd.read_csv(test_csv)
        test_ids = test_df['sop_instance_uid'].values

        # flatten predictions into the submission format: one row per (image, class)
        ids = []
        labels = []
        for i, image_id in enumerate(test_ids):
            if "ID" not in image_id:
                image_id = "ID_" + image_id
            pred = test_preds[i]
            for j, target in enumerate(target_cols):
                ids.append(image_id + "_" + target)
                labels.append(pred[j])
            if not with_any:
                # derive the 'any' label as the maximum over the five subtype probabilities
                ids.append(image_id + "_" + "any")
                labels.append(pred.max())

        submission_df = pd.DataFrame({
            'ID': ids,
            'Label': labels
        })
        submission_df.to_csv(f"/logs/prediction/{scheme}/{scheme}_ckp_tta.csv", index=False)


def predict_valid_tta_ckp():
    """Predict on each fold's validation split with TTA for the fold's best checkpoint and save the results."""
    # test_root = "/data/png/train/adjacent-brain-cropped/"
    # test_root = "/data/stage_1_train_3w/"
    test_root = "/data/stage_1_test_images_jpg_preprocessing/"
    image_type = 'jpg'

    image_size = [512, 512]
    backbone = "densenet169"
    normalization = True

    for fold in [0, 1, 2, 3, 4]:
        test_csv = f"./csv/patient2_kfold/valid_{fold}.csv"
        scheme = f"{backbone}-mw-512-resume-{fold}"
        log_dir = f"/logs/rsna/test/{scheme}/"

        with_any = True
        if with_any:
            num_classes = 6
            target_cols = LABEL_COLS
        else:
            num_classes = 5
            target_cols = LABEL_COLS_WITHOUT_ANY

        top_best_metrics = get_best_checkpoints(log_dir, n_best=1, minimize_metric=True)

        test_preds = 0
        for best_metric in top_best_metrics:
            checkpoint_path, checkpoint_metric = best_metric
            print("*" * 50)
            print(f"checkpoint: {checkpoint_path}")
            print(f"Metric: {checkpoint_metric}")

            model = CNNFinetuneModels(
                model_name=backbone,
                num_classes=num_classes,
                pretrained=False
            )

            # load the checkpoint selected above
            checkpoint = torch.load(checkpoint_path)
            model.load_state_dict(checkpoint['model_state_dict'])
            model = nn.DataParallel(model)
            model = model.to(device)

            augs = test_tta(image_size, normalization)

            for name, aug in augs.items():
                print(f"Augmentation: {name}")

                test_dataset = RSNADataset(
                    csv_file=test_csv,
                    root=test_root,
                    with_any=with_any,
                    transform=aug,
                    mode="valid",
                    image_type=image_type
                )

                test_loader = DataLoader(
                    dataset=test_dataset,
                    batch_size=64,
                    shuffle=False,
                    num_workers=8,
                )

                # running mean over all TTA variants and selected checkpoints
                test_preds += predict(model, test_loader) / (len(augs) * len(top_best_metrics))

        os.makedirs(f"/logs/prediction/{scheme}", exist_ok=True)
        np.save(f"/logs/prediction/{scheme}/valid_{scheme}.npy", test_preds)

        test_df = pd.read_csv(test_csv)
        test_ids = test_df['sop_instance_uid'].values

        # flatten predictions into the submission format: one row per (image, class)
        ids = []
        labels = []
        for i, image_id in enumerate(test_ids):
            if "ID" not in image_id:
                image_id = "ID_" + image_id
            pred = test_preds[i]
            for j, target in enumerate(target_cols):
                ids.append(image_id + "_" + target)
                labels.append(pred[j])
            if not with_any:
                # derive the 'any' label as the maximum over the five subtype probabilities
                ids.append(image_id + "_" + "any")
                labels.append(pred.max())

        submission_df = pd.DataFrame({
            'ID': ids,
            'Label': labels
        })
        submission_df.to_csv(f"/logs/prediction/{scheme}/valid_{scheme}.csv", index=False)


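# The following helper is an illustrative sketch, not part of the original pipeline:
# it shows one way the per-fold test predictions saved above
# (/logs/prediction/{scheme}/test_{fold}_ckp_tta.npy) could be averaged into a
# single fold-ensembled array. The function name and defaults are hypothetical.
def average_fold_predictions(backbone="densenet169", folds=(0, 1, 2, 3, 4)):
    fold_preds = []
    for fold in folds:
        scheme = f"{backbone}-mw-512-resume-{fold}"
        fold_preds.append(np.load(f"/logs/prediction/{scheme}/test_{fold}_ckp_tta.npy"))
    # element-wise mean over folds; result has the same shape as a single fold's output
    return np.mean(np.stack(fold_preds, axis=0), axis=0)

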
if __name__ == '__main__':
    # predict_test()
    predict_test_tta_ckp()
    # predict_valid_tta_ckp()