|
a |
|
b/test/data/test_split.py |
|
|
1 |
import unittest |
|
|
2 |
|
|
|
3 |
import torch |
|
|
4 |
|
|
|
5 |
from torchdrug import data, datasets |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
class SplitTest(unittest.TestCase): |
|
|
9 |
|
|
|
10 |
def setUp(self): |
|
|
11 |
smiles = ["CC1CCC(C(C)C)C(O)C1", # scaffold: C1CCCCC1 |
|
|
12 |
"OC1CCCCC1", |
|
|
13 |
"CCSC(=O)N(CC)C1CCCCC1", |
|
|
14 |
"ClC1C(Cl)C(Cl)C(Cl)C(Cl)C1Cl", |
|
|
15 |
"CC1CCC(C)CC1", |
|
|
16 |
"CCN(CC)c1nc(Cl)nc(N(CC)CC)n1", # scaffold: c1ncncn1 |
|
|
17 |
"CCNc1nc(NC(C)C)nc(SC)n1", |
|
|
18 |
"CCNc1nc(NC(C)(C)C)nc(SC)n1", |
|
|
19 |
"CCNc1nc(NC(C)C)nc(OC)n1", |
|
|
20 |
"CCNc1nc(Cl)nc(NCC)n1"] |
|
|
21 |
self.dataset = data.MoleculeDataset() |
|
|
22 |
self.dataset.load_smiles(smiles, {}) |
|
|
23 |
self.lengths = [5, 5] |
|
|
24 |
|
|
|
25 |
def test_scaffold(self): |
|
|
26 |
train, test = data.scaffold_split(self.dataset, self.lengths) |
|
|
27 |
train_scaffolds = set(sample["graph"].to_scaffold() for sample in train) |
|
|
28 |
test_scaffolds = set(sample["graph"].to_scaffold() for sample in test) |
|
|
29 |
self.assertEqual(len(train_scaffolds), 1, "Incorrect scaffold split") |
|
|
30 |
self.assertEqual(len(test_scaffolds), 1, "Incorrect scaffold split") |
|
|
31 |
self.assertFalse(train_scaffolds.intersection(test_scaffolds), "Incorrect scaffold split") |
|
|
32 |
|
|
|
33 |
|
|
|
34 |
if __name__ == "__main__": |
|
|
35 |
unittest.main() |