# test/data/test_molecule.py
import unittest
2
3
import torch
4
5
from torchdrug import data
6
7
8
class MoleculeTest(unittest.TestCase):
    """Tests for SMILES parsing/serialization and molecule-level features of ``data.Molecule``."""

    # Element symbols present in the test molecule. Counting raw characters is
    # only a valid atom count because ``self.smiles`` contains no two-letter
    # elements (e.g. "Cl", "Br"); the "H" inside bracket atoms such as "[C@H]"
    # is an implicit-hydrogen count and is deliberately not counted.
    ELEMENTS = ("C", "O", "N", "P")

    def setUp(self):
        """Set the reference SMILES string used by the tests."""
        self.smiles = "CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@](C#N)" \
                      "([C@H](O)[C@@H]1O)C1=CC=C2N1N=CN=C2N)OC1=CC=CC=C1"

    @classmethod
    def _count_atoms(cls, smiles):
        """Return ``(num_carbon, num_atom)`` by counting element symbols in ``smiles``.

        Only valid for SMILES whose elements are all single uppercase letters
        drawn from ``ELEMENTS`` (see the class-level note).
        """
        num_carbon = smiles.count("C")
        num_atom = sum(smiles.count(element) for element in cls.ELEMENTS)
        return num_carbon, num_atom

    def test_smiles(self):
        """SMILES -> Molecule -> SMILES preserves the carbon and atom counts."""
        carbon_truth, atom_truth = self._count_atoms(self.smiles)

        mol = data.Molecule.from_smiles(self.smiles)
        # Atomic number 6 is carbon.
        carbon_result = (mol.atom_type == 6).sum().item()
        self.assertEqual(carbon_result, carbon_truth, "Incorrect SMILES construction")
        self.assertEqual(mol.num_atom, atom_truth, "Incorrect SMILES construction")

        # Upper-case before counting: presumably to_smiles() may write aromatic
        # atoms in lowercase ("c", "n"), while the input uses Kekule form.
        round_trip = mol.to_smiles().upper()
        carbon_round_trip, atom_round_trip = self._count_atoms(round_trip)
        self.assertEqual(carbon_round_trip, carbon_truth, "Incorrect SMILES conversion")
        self.assertEqual(atom_round_trip, atom_truth, "Incorrect SMILES conversion")

    def test_empty_smiles(self):
        """An empty SMILES string yields an empty graph, both unpacked and packed."""
        mol = data.Molecule.from_smiles("")
        self.assertEqual(mol.num_node, 0, "Incorrect SMILES edge case")
        self.assertEqual(mol.num_edge, 0, "Incorrect SMILES edge case")

        mols = data.PackedMolecule.from_smiles([""])
        self.assertTrue((mols.num_nodes == 0).all(), "Incorrect SMILES edge case")
        self.assertTrue((mols.num_edges == 0).all(), "Incorrect SMILES edge case")

    def test_feature(self):
        """Requesting ``mol_feature="ecfp"`` yields a graph feature with a positive entry."""
        mol = data.Molecule.from_smiles(self.smiles, mol_feature="ecfp")
        self.assertTrue((mol.graph_feature > 0).any(), "Incorrect ECFP feature")
38
39
40
if __name__ == "__main__":
    # Run this module's tests when executed directly as a script.
    unittest.main()