[554282]: / Learning / predict.py

Download this file

39 lines (33 with data), 1.3 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import argparse
import h5py
import numpy as np
import torch
import torch.nn.functional as F
from model import DeepDrug3D
def load_data(path):
    """Load the input grid stored under dataset ``'X'`` of an HDF5 file.

    A 4-D array is treated as a single sample and gets a leading batch axis
    prepended, so the returned tensor is always 5-D.

    Args:
        path: Path to the input ``.h5`` file; it must contain a dataset
            named ``'X'``.

    Returns:
        A ``torch.float32`` CPU tensor with 5 dimensions. The axis order is
        taken as-is from the file (presumably ``(batch, channels, D, H, W)``
        for the 3-D network in ``model.py`` -- TODO confirm against the data
        pipeline).

    Raises:
        ValueError: If dataset ``'X'`` is neither 4-D nor 5-D.
    """
    with h5py.File(path, 'r') as f:
        # `[()]` reads the whole dataset into memory, so the file can be
        # closed immediately afterwards.
        X = f['X'][()]
    if X.ndim == 4:
        # Single sample: prepend the batch axis.
        X = np.expand_dims(X, axis=0)
    if X.ndim != 5:
        # Fail here with a readable message rather than deep inside the model.
        raise ValueError(
            "Expected dataset 'X' in {} to be 4-D or 5-D, got shape {}".format(
                path, X.shape))
    # torch.as_tensor replaces the legacy torch.FloatTensor constructor and
    # converts to float32 in one step.
    return torch.as_tensor(X, dtype=torch.float32)
def predict(path, model_path):
    """Run a trained DeepDrug3D model on one input file and print class probabilities.

    Args:
        path: Path to the input ``.h5`` file, read with ``load_data``.
        model_path: Path to a saved ``state_dict`` for ``DeepDrug3D``.

    Returns:
        NumPy array of softmax probabilities for the first sample of the
        batch, ordered (ATP, Heme, other) as in the printed report.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('Current device: ' + str(device))

    # NOTE(review): 14 is passed positionally to DeepDrug3D; presumably the
    # number of input feature channels -- confirm against model.py.
    net = DeepDrug3D(14)
    # map_location remaps tensors saved on a GPU so a checkpoint can be loaded
    # on a CPU-only machine instead of failing during deserialization.
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints
    # (on PyTorch >= 1.13 consider passing weights_only=True).
    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict)
    net.to(device)
    # Inference mode: makes layers such as dropout / batch norm deterministic.
    net.eval()

    data = load_data(path).to(device)
    # No gradients are needed for prediction; skip building the autograd graph.
    with torch.no_grad():
        output = net(data)
        # Output is presumably (batch, n_classes); softmax over the class axis.
        proba = F.softmax(output, dim=1)
    # Only the first sample of the batch is reported.
    score = proba.cpu().numpy()[0]

    print('The probability of pocket provided binds with ATP ligands: {:.4f}'.format(score[0]))
    print('The probability of pocket provided binds with Heme ligands: {:.4f}'.format(score[1]))
    print('The probability of pocket provided binds with other ligands: {:.4f}'.format(score[2]))
    return score
if __name__ == "__main__":
    # Command-line entry point: both the input file and the weights are mandatory.
    cli = argparse.ArgumentParser()
    cli_options = (
        ('--f', 'Input h5 file name'),
        ('--m', 'Path to the trained model weights'),
    )
    for flag, help_text in cli_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    args = cli.parse_args()
    predict(args.f, args.m)