|
a |
|
b/test/data/test_molecule.py |
|
|
1 |
import unittest |
|
|
2 |
|
|
|
3 |
import torch |
|
|
4 |
|
|
|
5 |
from torchdrug import data |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
class MoleculeTest(unittest.TestCase):
    """Tests for SMILES parsing/serialization and graph features of ``data.Molecule``."""

    def setUp(self):
        # The only element symbols in this SMILES are C, O, N and P (H appears
        # only inside bracket atoms such as [C@H]). No two-letter symbols occur,
        # so per-element counts can be read directly off the string.
        self.smiles = ("CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)"
                       "([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=C1")

    @staticmethod
    def _count_heavy_atoms(smiles):
        """Return the total number of ``C``, ``O``, ``N`` and ``P`` characters in ``smiles``.

        This equals the heavy-atom count only for SMILES strings whose element
        symbols are limited to C, O, N, P (and bracketed H); any other symbol
        containing those capital letters (e.g. an upper-cased ``[Co]``) would
        be miscounted.
        """
        return sum(smiles.count(symbol) for symbol in "CONP")

    def test_smiles(self):
        mol = data.Molecule.from_smiles(self.smiles)
        # to_smiles() may emit aromatic atoms in lower case; upper-casing lets
        # the same letter counting be applied to the round-tripped string.
        smiles = mol.to_smiles().upper()

        carbon_truth = self.smiles.count("C")
        atom_truth = self._count_heavy_atoms(self.smiles)

        # Parsing: the constructed graph must contain exactly the input atoms
        # (atom_type 6 is compared against the carbon count).
        carbon_result = (mol.atom_type == 6).sum().item()
        atom_result = mol.num_atom
        self.assertEqual(carbon_result, carbon_truth, "Incorrect SMILES construction")
        self.assertEqual(atom_result, atom_truth, "Incorrect SMILES construction")

        # Serialization: writing the molecule back to SMILES must preserve atoms.
        carbon_result = smiles.count("C")
        atom_result = self._count_heavy_atoms(smiles)
        self.assertEqual(carbon_result, carbon_truth, "Incorrect SMILES serialization")
        self.assertEqual(atom_result, atom_truth, "Incorrect SMILES serialization")

        # Side case: an empty SMILES string must produce an empty graph, both for
        # a single Molecule and for each entry of a PackedMolecule batch.
        mol = data.Molecule.from_smiles("")
        self.assertEqual(mol.num_node, 0, "Incorrect SMILES side case")
        self.assertEqual(mol.num_edge, 0, "Incorrect SMILES side case")
        mols = data.PackedMolecule.from_smiles([""])
        self.assertTrue((mols.num_nodes == 0).all(), "Incorrect SMILES side case")
        self.assertTrue((mols.num_edges == 0).all(), "Incorrect SMILES side case")

    def test_feature(self):
        # Requesting mol_feature="ecfp" must populate graph_feature with at
        # least one positive entry for this non-trivial molecule.
        mol = data.Molecule.from_smiles(self.smiles, mol_feature="ecfp")
        self.assertTrue((mol.graph_feature > 0).any(), "Incorrect ECFP feature")
|
|
38 |
|
|
|
39 |
|
|
|
40 |
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()