|
a |
|
b/predict.py |
|
|
1 |
import dataset |
|
|
2 |
import utils |
|
|
3 |
import os |
|
|
4 |
import argparse |
|
|
5 |
import torch.backends.cudnn as cudnn |
|
|
6 |
import torch |
|
|
7 |
import torch.nn.functional as F |
|
|
8 |
import numpy as np |
|
|
9 |
import pandas as pd |
|
|
10 |
import time |
|
|
11 |
|
|
|
12 |
# Command-line interface for the prediction script. The parsed namespace is
# stored in the module-level ``args`` by main() and read by predict()/test().
parser = argparse.ArgumentParser(description='PET lymphoma classification')

# --- I/O options ---
parser.add_argument(
    '--output',
    type=str,
    default='.',
    help='name of output folder (default: ".")',
)
# main() reads the first two values as an inclusive [first, last] split range.
parser.add_argument(
    '--splits',
    type=int,
    default=None,
    nargs='+',
    help='data split range to predict for',
)

# --- Model options ---
parser.add_argument(
    '--chptfolder',
    type=str,
    default='results',
    help='path to trained models (default: "results")',
)
parser.add_argument(
    '--normalize',
    action='store_true',
    default=False,
    help='normalize images',
)
parser.add_argument(
    '--nfeat',
    type=int,
    default=512,
    help='number of embedded features (default: 512)',
)

# --- Data-loading options ---
parser.add_argument(
    '--batch_size',
    type=int,
    default=200,
    help='how many images to sample per slide (default: 200)',
)
parser.add_argument(
    '--workers',
    default=10,
    type=int,
    help='number of data loading workers (default: 10)',
)
|
|
26 |
|
|
|
27 |
def main():
    """Run inference for every requested data split and save the predictions.

    For each split this looks up the best run in ``best_run.csv`` (inside
    ``--chptfolder``), loads the matching checkpoint through
    ``utils.model_prediction``, predicts on the dataset returned by
    ``dataset.get_datasets_singleview`` and writes
    ``pred_split<split>_run<run>.csv`` into ``--output``.

    Raises:
        IndexError: if a requested split has no row in ``best_run.csv``.
    """
    # Get user input; predict() and test() read the options via this global
    global args
    args = parser.parse_args()

    # Create the destination up front so a long run cannot fail at save time.
    # An empty path means "current directory" for os.path.join, so skip it.
    if args.output:
        os.makedirs(args.output, exist_ok=True)

    # best run per split
    bestrun = pd.read_csv(os.path.join(args.chptfolder, 'best_run.csv'))

    # set per split
    if args.splits is None:
        splits = range(20)
    else:
        # --splits is an inclusive "first last" range; a single value selects
        # just that split instead of failing on the missing upper bound.
        first = args.splits[0]
        last = args.splits[1] if len(args.splits) > 1 else first
        splits = range(first, last + 1)

    # Global cuDNN switch: setting it once is enough for all splits
    cudnn.benchmark = True

    for split in splits:
        t0 = time.time()

        # Best checkpoint of this split
        runs = bestrun.loc[bestrun['split'] == split, 'run'].values
        if len(runs) == 0:
            raise IndexError(
                'split {} not found in best_run.csv'.format(split))
        run = runs[0]
        chpnt = os.path.join(
            args.chptfolder,
            'checkpoint_split' + str(split) + '_run' + str(run) + '.pth')

        # Get model
        model = utils.model_prediction(chpnt)
        model.cuda()

        # Test dataset (fourth element of the returned tuple)
        _, _, _, dset, _ = dataset.get_datasets_singleview(
            None, args.normalize, False, split)

        print(' dataset size: {}'.format(len(dset.df)))
        out = predict(dset, model)
        pred_name = 'pred_split' + str(split) + '_run' + str(run) + '.csv'
        out.to_csv(os.path.join(args.output, pred_name), index=False)
        print(' loop time: {:.0f} min'.format((time.time() - t0) / 60))
|
|
56 |
|
|
|
57 |
def predict(dset, model):
    """Predict on ``dset`` and return its metadata joined with the outputs.

    Args:
        dset: dataset exposing a ``df`` DataFrame and usable by a DataLoader.
        model: model handed to ``test()``.

    Returns:
        A DataFrame made of a re-indexed copy of ``dset.df``, a ``probs``
        column, and feature columns ``F1`` .. ``F<args.nfeat>``.
    """
    # Unshuffled loading; presumably this keeps outputs in the same order as
    # the rows of dset.df (depends on the dataset's indexing, confirm there).
    data_loader = torch.utils.data.DataLoader(
        dset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )

    # Evaluate on test data
    probabilities, features = test(data_loader, model)

    # One named column per embedded feature
    feature_names = ['F{}'.format(k) for k in range(1, args.nfeat + 1)]
    feature_table = pd.DataFrame(features, columns=feature_names)

    # Work on a copy so the dataset's own table is left untouched
    table = dset.df.copy().reset_index(drop=True)
    table['probs'] = probabilities

    return pd.concat([table, feature_table], axis=1)
|
|
67 |
|
|
|
68 |
def test(loader, model):
    """Run ``model`` over ``loader`` and collect probabilities and features.

    Args:
        loader: iterable of ``(batch, target)`` pairs exposing ``.dataset``;
            the targets are ignored.
        model: network providing ``eval()`` and ``forward_feat(batch)``, which
            returns ``(features, scores)``.

    Returns:
        Tuple ``(probs, feats)``: a NumPy vector holding, for every sample,
        the softmax probability of class index 1, and a NumPy array of shape
        ``(len(loader.dataset), args.nfeat)`` with the embedded features.
        Both follow the order in which the loader yields samples.
    """
    # Set model in test mode
    model.eval()

    n_samples = len(loader.dataset)
    # Uninitialized buffers; every row is overwritten in the loop below
    probs = torch.FloatTensor(n_samples).cuda()
    feats = np.empty((n_samples, args.nfeat))

    # Running write position. Tracking the real batch sizes (rather than
    # i * args.batch_size) keeps rows aligned even if the loader was built
    # with a different batch size than the command-line option.
    start = 0
    with torch.no_grad():
        for batch, _ in loader:
            # Copy batch to GPU
            batch = batch.cuda()

            # Forward pass
            f, output = model.forward_feat(batch)
            output = F.softmax(output, dim=1)

            stop = start + batch.size(0)
            # Features
            feats[start:stop, :] = f.detach().cpu().numpy()
            # Probability of class 1 (slice assignment copies the values)
            probs[start:stop] = output.detach()[:, 1]
            start = stop

    return probs.cpu().numpy(), feats
|
|
88 |
|
|
|
89 |
# Script entry point: run the prediction loop only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()