Diff of /Learning/predict.py [000000] .. [b623ff]

Switch to unified view

a b/Learning/predict.py
1
import argparse
2
import h5py
3
4
import numpy as np
5
import torch
6
import torch.nn.functional as F
7
from model import DeepDrug3D
8
9
def load_data(path):
10
    with h5py.File(path, 'r') as f:
11
        X = f['X'][()]
12
    if len(X.shape) == 4:
13
        X = np.expand_dims(X, axis=0)
14
    X = torch.FloatTensor(X)
15
    return X
16
17
def predict(path, model_path):
18
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
    print('Current device: ' + str(device))
20
    
21
    net = DeepDrug3D(14)
22
    net.load_state_dict(torch.load(model_path))
23
    net.to(device)
24
    data = load_data(path)
25
    data = data.to(device)
26
    output = net(data)
27
    proba = F.softmax(output, dim=1)
28
    score = proba.data.cpu().numpy()[0]
29
    print('The probability of pocket provided binds with ATP ligands: {:.4f}'.format(score[0]))
30
    print('The probability of pocket provided binds with Heme ligands: {:.4f}'.format(score[1]))
31
    print('The probability of pocket provided binds with other ligands: {:.4f}'.format(score[2]))
32
33
if __name__ == "__main__":
34
    parser = argparse.ArgumentParser()
35
    parser.add_argument('--f', type=str, required=True, help='Input h5 file name')
36
    parser.add_argument('--m', type=str, required=True, help='Path to the trained model weights')
37
    opt = parser.parse_args()
38
39
    predict(opt.f, opt.m)