Switch to unified view

a b/test/layers/test_common.py
1
import unittest
2
3
import torch
4
from torch import nn
5
6
from torchdrug import layers
7
8
9
class CommonTest(unittest.TestCase):
10
11
    def setUp(self):
12
        self.a = torch.randn(10)
13
        self.b = torch.randn(10)
14
        self.g = torch.randn(10)
15
16
    def test_sequential(self):
17
        layer1 = nn.Module()
18
        layer2 = nn.Module()
19
        layer3 = nn.Module()
20
21
        layer1.forward = lambda a, b: (a + 1, b + 2)
22
        layer2.forward = lambda a, b: a * b
23
        layer = layers.Sequential(layer1, layer2)
24
        result = layer(self.a, self.b)
25
        truth = layer2(*layer1(self.a, self.b))
26
        self.assertTrue(torch.allclose(result, truth), "Incorrect sequential layer")
27
28
        layer1.forward = lambda g, a: g + a
29
        layer2.forward = lambda b: b * 2
30
        layer3.forward = lambda g, c: g * c
31
        layer = layers.Sequential(layer1, layer2, layer3, global_args=("g",))
32
        result = layer(self.g, self.a)
33
        truth = layer3(self.g, layer2(layer1(self.g, self.a)))
34
        self.assertTrue(torch.allclose(result, truth), "Incorrect sequential layer")
35
36
        layer1.forward = lambda a: {"b": a + 1, "c": a + 2}
37
        layer2.forward = lambda b: b * 2
38
        layer = layers.Sequential(layer1, layer2, allow_unused=True)
39
        result = layer(self.a)
40
        truth = layer2(layer1(self.a)["b"])
41
        self.assertTrue(torch.allclose(result, truth), "Incorrect sequential layer")
42
43
        layer1.forward = lambda g, a: {"g": g + 1, "b": a + 2}
44
        layer2.forward = lambda g, b: g * b
45
        layer = layers.Sequential(layer1, layer2, global_args=("g",))
46
        result = layer(self.g, self.a)
47
        truth = layer2(**layer1(self.g, self.a))
48
        self.assertTrue(torch.allclose(result, truth), "Incorrect sequential layer")
49
50
51
if __name__ == "__main__":
52
    unittest.main()