Diff of /test/data/test_split.py [000000] .. [36b44b]

Switch to unified view

a b/test/data/test_split.py
1
import unittest
2
3
import torch
4
5
from torchdrug import data, datasets
6
7
8
class SplitTest(unittest.TestCase):
9
10
    def setUp(self):
11
        smiles = ["CC1CCC(C(C)C)C(O)C1", # scaffold: C1CCCCC1
12
                  "OC1CCCCC1",
13
                  "CCSC(=O)N(CC)C1CCCCC1",
14
                  "ClC1C(Cl)C(Cl)C(Cl)C(Cl)C1Cl",
15
                  "CC1CCC(C)CC1",
16
                  "CCN(CC)c1nc(Cl)nc(N(CC)CC)n1", # scaffold: c1ncncn1
17
                  "CCNc1nc(NC(C)C)nc(SC)n1",
18
                  "CCNc1nc(NC(C)(C)C)nc(SC)n1",
19
                  "CCNc1nc(NC(C)C)nc(OC)n1",
20
                  "CCNc1nc(Cl)nc(NCC)n1"]
21
        self.dataset = data.MoleculeDataset()
22
        self.dataset.load_smiles(smiles, {})
23
        self.lengths = [5, 5]
24
25
    def test_scaffold(self):
26
        train, test = data.scaffold_split(self.dataset, self.lengths)
27
        train_scaffolds = set(sample["graph"].to_scaffold() for sample in train)
28
        test_scaffolds = set(sample["graph"].to_scaffold() for sample in test)
29
        self.assertEqual(len(train_scaffolds), 1, "Incorrect scaffold split")
30
        self.assertEqual(len(test_scaffolds), 1, "Incorrect scaffold split")
31
        self.assertFalse(train_scaffolds.intersection(test_scaffolds), "Incorrect scaffold split")
32
33
34
if __name__ == "__main__":
35
    unittest.main()