# extract_features_fp.py
1
import torch
2
import torch.nn as nn
3
from math import floor
4
import os
5
import random
6
import numpy as np
7
import pdb
8
import time
9
from datasets.dataset_h5 import Dataset_All_Bags, Whole_Slide_Bag_FP
10
from torch.utils.data import DataLoader
11
from models.resnet_custom import resnet50_baseline
12
import argparse
13
from utils.utils import print_network, collate_features
14
from utils.file_utils import save_hdf5
15
from PIL import Image
16
import h5py
17
import openslide
18
# Compute device for the model and batches: CUDA when available, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
20
def compute_w_loader(file_path, output_path, wsi, model,
                     batch_size=8, verbose=0, print_every=20, pretrained=True,
                     custom_downsample=1, target_patch_size=-1):
    """Embed every patch of one slide with ``model`` and write the features to HDF5.

    Patches are served by ``Whole_Slide_Bag_FP`` (built from ``file_path`` and
    ``wsi``) and pushed through ``model`` in batches. Each batch's features and
    coordinates are written to ``output_path`` with ``save_hdf5``. The first
    batch uses mode ``'w'`` and later batches use mode ``'a'``.

    args:
        file_path: path of the bag (.h5 file) passed to Whole_Slide_Bag_FP
        output_path: path of the .h5 file to save computed features to
        wsi: opened slide handle passed to Whole_Slide_Bag_FP
        model: pytorch model; batches are moved to the module-level ``device``
            before the forward pass, so the model is expected to live there too
        batch_size: number of patches per forward pass
        verbose: level of feedback; 0 is silent, > 0 prints a summary line and
            periodic per-batch progress
        print_every: print progress every this many batches (only when
            verbose > 0; values <= 0 disable per-batch progress)
        pretrained: passed through to Whole_Slide_Bag_FP
        custom_downsample: custom defined downscale factor of image patches
        target_patch_size: custom defined, rescaled image size before embedding

    returns:
        output_path

    raises:
        ValueError: if the bag contains no patches (no output file would be written)
    """
    dataset = Whole_Slide_Bag_FP(file_path=file_path, wsi=wsi, pretrained=pretrained,
                                 custom_downsample=custom_downsample,
                                 target_patch_size=target_patch_size)
    if len(dataset) == 0:
        # Without this check the loop below never runs, and callers would later
        # fail opening an output file that was never created.
        raise ValueError('no patches found in {}'.format(file_path))

    # Worker processes and pinned host memory only help when batches are copied to a GPU.
    kwargs = {'num_workers': 4, 'pin_memory': True} if device.type == "cuda" else {}
    loader = DataLoader(dataset=dataset, batch_size=batch_size,
                        collate_fn=collate_features, **kwargs)
    num_batches = len(loader)

    if verbose > 0:
        print('processing {}: total of {} batches'.format(file_path, num_batches))

    # 'w' truncates any stale output on the first write; later batches append.
    mode = 'w'
    with torch.no_grad():
        for count, (batch, coords) in enumerate(loader):
            if verbose > 0 and print_every > 0 and count % print_every == 0:
                print('batch {}/{}, {} files processed'.format(count, num_batches, count * batch_size))
            batch = batch.to(device, non_blocking=True)
            features = model(batch).cpu().numpy()

            asset_dict = {'features': features, 'coords': coords}
            save_hdf5(output_path, asset_dict, attr_dict=None, mode=mode)
            mode = 'a'

    return output_path
58
59
60
# Command-line interface. The four path options are used unconditionally by the
# main loop, so argparse enforces them instead of failing later with a
# NotImplementedError or TypeError.
parser = argparse.ArgumentParser(description='Feature Extraction')
parser.add_argument('--data_h5_dir', type=str, default=None, required=True,
                    help='directory whose patches/ subfolder holds one <slide_id>.h5 bag per slide')
parser.add_argument('--data_slide_dir', type=str, default=None, required=True,
                    help='directory holding the slide files <slide_id><slide_ext>')
parser.add_argument('--slide_ext', type=str, default='.svs',
                    help='file extension of the slide files (default: .svs)')
parser.add_argument('--csv_path', type=str, default=None, required=True,
                    help='CSV listing the slides to process (read by Dataset_All_Bags)')
parser.add_argument('--feat_dir', type=str, default=None, required=True,
                    help='output directory; features go to h5_files/ and pt_files/ inside it')
parser.add_argument('--batch_size', type=int, default=256,
                    help='number of patches per forward pass')
parser.add_argument('--no_auto_skip', default=False, action='store_true',
                    help='re-process slides even if pt_files/<slide_id>.pt already exists')
parser.add_argument('--custom_downsample', type=int, default=1,
                    help='custom downscale factor of image patches')
parser.add_argument('--target_patch_size', type=int, default=-1,
                    help='rescaled patch size before embedding (default: -1)')
args = parser.parse_args()
71
72
73
if __name__ == '__main__':
    # Embed every slide listed in --csv_path and write the features twice:
    # h5_files/<slide_id>.h5 (features + coords) and pt_files/<slide_id>.pt (features tensor).

    # Every path option is used unconditionally below; report a missing one as a
    # CLI usage error rather than a NotImplementedError/TypeError further down.
    for required_opt in ('csv_path', 'data_h5_dir', 'data_slide_dir', 'feat_dir'):
        if getattr(args, required_opt) is None:
            parser.error('--{} is required'.format(required_opt))

    print('initializing dataset')
    bags_dataset = Dataset_All_Bags(args.csv_path)

    pt_dir = os.path.join(args.feat_dir, 'pt_files')
    h5_dir = os.path.join(args.feat_dir, 'h5_files')
    os.makedirs(args.feat_dir, exist_ok=True)
    os.makedirs(pt_dir, exist_ok=True)
    os.makedirs(h5_dir, exist_ok=True)
    # A set gives O(1) membership tests for the per-slide auto-skip check.
    dest_files = set(os.listdir(pt_dir))

    print('loading model checkpoint')
    model = resnet50_baseline(pretrained=True)
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.eval()

    total = len(bags_dataset)
    slide_ext = args.slide_ext
    for bag_candidate_idx in range(total):
        slide_name = bags_dataset[bag_candidate_idx]
        # Strip only a trailing extension. str.split would also cut at an
        # occurrence of slide_ext inside the name, and it raises on an empty extension.
        if slide_ext and slide_name.endswith(slide_ext):
            slide_id = slide_name[:-len(slide_ext)]
        else:
            slide_id = slide_name
        bag_name = slide_id + '.h5'
        h5_file_path = os.path.join(args.data_h5_dir, 'patches', bag_name)
        slide_file_path = os.path.join(args.data_slide_dir, slide_id + slide_ext)
        print('\nprogress: {}/{}'.format(bag_candidate_idx + 1, total))
        print(slide_id)

        if not args.no_auto_skip and slide_id + '.pt' in dest_files:
            print('skipped {}'.format(slide_id))
            continue

        output_path = os.path.join(h5_dir, bag_name)
        time_start = time.time()
        wsi = openslide.open_slide(slide_file_path)
        try:
            output_file_path = compute_w_loader(
                h5_file_path, output_path, wsi,
                model=model, batch_size=args.batch_size, verbose=1, print_every=20,
                custom_downsample=args.custom_downsample,
                target_patch_size=args.target_patch_size)
        finally:
            # Release the slide's file handle; otherwise one stays open per processed slide.
            wsi.close()
        time_elapsed = time.time() - time_start
        print('\ncomputing features for {} took {} s'.format(output_file_path, time_elapsed))

        with h5py.File(output_file_path, "r") as h5_file:
            features = h5_file['features'][:]
            print('features size: ', features.shape)
            print('coordinates size: ', h5_file['coords'].shape)
        features = torch.from_numpy(features)
        # Same <slide_id>.pt name that the auto-skip check above looks for.
        torch.save(features, os.path.join(pt_dir, slide_id + '.pt'))
126
127
128