# predict.py — diff of /predict.py [000000] .. [1928b6] (unified view header from diff tool)
1
import dataset
2
import utils
3
import os
4
import argparse
5
import torch.backends.cudnn as cudnn
6
import torch
7
import torch.nn.functional as F
8
import numpy as np
9
import pandas as pd
10
import time
11
12
parser = argparse.ArgumentParser(description='PET lymphoma classification')
13
14
#I/O PARAMS
15
parser.add_argument('--output', type=str, default='.', help='name of output folder (default: ".")')
16
parser.add_argument('--splits', type=int, default=None, nargs='+', help='data split range to predict for')
17
18
#MODEL PARAMS
19
parser.add_argument('--chptfolder', type=str, default='results', help='path to trained models (default: "results")')
20
parser.add_argument('--normalize', action='store_true', default=False, help='normalize images')
21
parser.add_argument('--nfeat', type=int, default=512, help='number of embedded features (default: 512)')
22
   
23
#TRAINING PARAMS
24
parser.add_argument('--batch_size', type=int, default=200, help='how many images to sample per slide (default: 200)')
25
parser.add_argument('--workers', default=10, type=int, help='number of data loading workers (default: 10)')
26
27
def main():
    """Predict for each data split using its best checkpoint and save one CSV per split.

    Reads best_run.csv from --chptfolder to find the best run per split,
    rebuilds the model from that checkpoint, runs it over the split's test
    dataset, and writes pred_split<k>_run<r>.csv into --output.
    """
    global args
    args = parser.parse_args()
    # Table mapping each split to its best-performing run
    bestrun = pd.read_csv(os.path.join(args.chptfolder, 'best_run.csv'))
    # Resolve which splits to process (default: all 20)
    if args.splits is None:
        split_range = range(20)
    else:
        split_range = range(args.splits[0], args.splits[1] + 1)
    t0 = time.time()
    for split in split_range:
        # Best checkpoint of this split
        run = bestrun[bestrun.split == split].run.values[0]
        chpnt = os.path.join(args.chptfolder, 'checkpoint_split' + str(split) + '_run' + str(run) + '.pth')
        # Rebuild the model from the checkpoint and move it to the GPU
        model = utils.model_prediction(chpnt)
        model.cuda()
        cudnn.benchmark = True
        # Only the test dataset (4th return value) is needed here
        _, _, _, dset, _ = dataset.get_datasets_singleview(None, args.normalize, False, split)

        print('  dataset size: {}'.format(len(dset.df)))
        out = predict(dset, model)
        pred_name = 'pred_split' + str(split) + '_run' + str(run) + '.csv'
        out.to_csv(os.path.join(args.output, pred_name), index=False)
        print('  loop time: {:.0f} min'.format((time.time() - t0) / 60))
        t0 = time.time()
def predict(dset, model):
    """Run the model over dset and return its dataframe augmented with outputs.

    Returns a copy of dset.df with a 'probs' column (class-1 probability)
    and one column per embedded feature, named F1 .. F<nfeat>.
    """
    loader = torch.utils.data.DataLoader(
        dset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
    )
    # Forward pass over the whole test set
    probs, feats = test(loader, model)
    # Start from a fresh copy of the slide metadata
    out = dset.df.copy().reset_index(drop=True)
    out['probs'] = probs
    # One column per embedded feature: F1 .. F<nfeat>
    feat_cols = ['F{}'.format(i + 1) for i in range(args.nfeat)]
    out = pd.concat([out, pd.DataFrame(feats, columns=feat_cols)], axis=1)
    return out
def test(loader, model):
    """Run inference over every batch in loader.

    Returns:
        probs: numpy array of shape (N,), softmax probability of class 1.
        feats: numpy array of shape (N, args.nfeat), embedded features
               from model.forward_feat.
    """
    # Set model in test mode
    model.eval()
    n = len(loader.dataset)
    # torch.empty replaces the legacy uninitialized-FloatTensor constructor
    # (same float32 dtype); every position is overwritten in the loop below.
    probs = torch.empty(n).cuda()
    feats = np.empty((n, args.nfeat))
    with torch.no_grad():
        # Track a running offset instead of i*args.batch_size so the slicing
        # follows the actual batch sizes (robust to a short final batch and
        # decoupled from the global batch_size argument).
        offset = 0
        for batch, _ in loader:  # renamed from `input`, which shadows the builtin
            ## Copy batch to GPU
            batch = batch.cuda()
            ## Forward pass: forward_feat returns (features, logits)
            f, output = model.forward_feat(batch)
            output = F.softmax(output, dim=1)
            size = batch.size(0)
            ## Store features for this batch
            feats[offset:offset + size, :] = f.detach().cpu().numpy()
            ## Store class-1 probabilities for this batch
            probs[offset:offset + size] = output.detach()[:, 1].clone()
            offset += size
    return probs.cpu().numpy(), feats
# Script entry point: parse arguments and run prediction for all requested splits.
if __name__ == '__main__':
    main()