|
a |
|
b/Learning/predict.py |
|
|
1 |
import argparse |
|
|
2 |
import h5py |
|
|
3 |
|
|
|
4 |
import numpy as np |
|
|
5 |
import torch |
|
|
6 |
import torch.nn.functional as F |
|
|
7 |
from model import DeepDrug3D |
|
|
8 |
|
|
|
9 |
def load_data(path): |
|
|
10 |
with h5py.File(path, 'r') as f: |
|
|
11 |
X = f['X'][()] |
|
|
12 |
if len(X.shape) == 4: |
|
|
13 |
X = np.expand_dims(X, axis=0) |
|
|
14 |
X = torch.FloatTensor(X) |
|
|
15 |
return X |
|
|
16 |
|
|
|
17 |
def predict(path, model_path): |
|
|
18 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
19 |
print('Current device: ' + str(device)) |
|
|
20 |
|
|
|
21 |
net = DeepDrug3D(14) |
|
|
22 |
net.load_state_dict(torch.load(model_path)) |
|
|
23 |
net.to(device) |
|
|
24 |
data = load_data(path) |
|
|
25 |
data = data.to(device) |
|
|
26 |
output = net(data) |
|
|
27 |
proba = F.softmax(output, dim=1) |
|
|
28 |
score = proba.data.cpu().numpy()[0] |
|
|
29 |
print('The probability of pocket provided binds with ATP ligands: {:.4f}'.format(score[0])) |
|
|
30 |
print('The probability of pocket provided binds with Heme ligands: {:.4f}'.format(score[1])) |
|
|
31 |
print('The probability of pocket provided binds with other ligands: {:.4f}'.format(score[2])) |
|
|
32 |
|
|
|
33 |
if __name__ == "__main__": |
|
|
34 |
parser = argparse.ArgumentParser() |
|
|
35 |
parser.add_argument('--f', type=str, required=True, help='Input h5 file name') |
|
|
36 |
parser.add_argument('--m', type=str, required=True, help='Path to the trained model weights') |
|
|
37 |
opt = parser.parse_args() |
|
|
38 |
|
|
|
39 |
predict(opt.f, opt.m) |