|
a |
|
b/test/layers/test_common.py |
|
|
1 |
import unittest |
|
|
2 |
|
|
|
3 |
import torch |
|
|
4 |
from torch import nn |
|
|
5 |
|
|
|
6 |
from torchdrug import layers |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
class CommonTest(unittest.TestCase): |
|
|
10 |
|
|
|
11 |
def setUp(self): |
|
|
12 |
self.a = torch.randn(10) |
|
|
13 |
self.b = torch.randn(10) |
|
|
14 |
self.g = torch.randn(10) |
|
|
15 |
|
|
|
16 |
def test_sequential(self): |
|
|
17 |
layer1 = nn.Module() |
|
|
18 |
layer2 = nn.Module() |
|
|
19 |
layer3 = nn.Module() |
|
|
20 |
|
|
|
21 |
layer1.forward = lambda a, b: (a + 1, b + 2) |
|
|
22 |
layer2.forward = lambda a, b: a * b |
|
|
23 |
layer = layers.Sequential(layer1, layer2) |
|
|
24 |
result = layer(self.a, self.b) |
|
|
25 |
truth = layer2(*layer1(self.a, self.b)) |
|
|
26 |
self.assertTrue(torch.allclose(result, truth), "Incorrect sequential layer") |
|
|
27 |
|
|
|
28 |
layer1.forward = lambda g, a: g + a |
|
|
29 |
layer2.forward = lambda b: b * 2 |
|
|
30 |
layer3.forward = lambda g, c: g * c |
|
|
31 |
layer = layers.Sequential(layer1, layer2, layer3, global_args=("g",)) |
|
|
32 |
result = layer(self.g, self.a) |
|
|
33 |
truth = layer3(self.g, layer2(layer1(self.g, self.a))) |
|
|
34 |
self.assertTrue(torch.allclose(result, truth), "Incorrect sequential layer") |
|
|
35 |
|
|
|
36 |
layer1.forward = lambda a: {"b": a + 1, "c": a + 2} |
|
|
37 |
layer2.forward = lambda b: b * 2 |
|
|
38 |
layer = layers.Sequential(layer1, layer2, allow_unused=True) |
|
|
39 |
result = layer(self.a) |
|
|
40 |
truth = layer2(layer1(self.a)["b"]) |
|
|
41 |
self.assertTrue(torch.allclose(result, truth), "Incorrect sequential layer") |
|
|
42 |
|
|
|
43 |
layer1.forward = lambda g, a: {"g": g + 1, "b": a + 2} |
|
|
44 |
layer2.forward = lambda g, b: g * b |
|
|
45 |
layer = layers.Sequential(layer1, layer2, global_args=("g",)) |
|
|
46 |
result = layer(self.g, self.a) |
|
|
47 |
truth = layer2(**layer1(self.g, self.a)) |
|
|
48 |
self.assertTrue(torch.allclose(result, truth), "Incorrect sequential layer") |
|
|
49 |
|
|
|
50 |
|
|
|
51 |
if __name__ == "__main__": |
|
|
52 |
unittest.main() |