[3eb847]: /test/utils/test_comm.py

Download this file

78 lines (64 with data), 3.0 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import unittest
import torch
from torch import multiprocessing as mp
from torchdrug.utils import comm
def worker(rank, reduce_fn, objs, queue, event):
    """Body of one spawned process: join the NCCL group and run one collective.

    Args:
        rank: index of this process (``mp.spawn`` passes it as the first argument).
        reduce_fn: the ``comm`` collective under test, applied to this rank's input.
        objs: per-rank inputs; this process uses ``objs[rank]``.
        queue: channel used to send ``(rank, result)`` back to the parent.
        event: set by the parent once it has finished with the results.
    """
    comm.init_process_group("nccl", init_method="env://", rank=rank)
    local_input = objs[rank]
    reduced = reduce_fn(local_input)
    queue.put((rank, reduced))
    # Stay alive until the parent signals it is done: tensors shared across
    # processes are presumably only valid while the sending process exists.
    event.wait()
class ReduceTest(unittest.TestCase):
    """Multi-process tests for the ``comm.reduce``, ``comm.stack`` and ``comm.cat`` collectives.

    Each check spawns ``num_worker`` processes (one per rank) that join an NCCL
    process group and apply one collective to their own input. The parent then
    compares every rank's result against a reference computed locally.
    """

    # Upper bound (seconds) on waiting for any single worker result, so a worker
    # that dies before reporting fails the test instead of hanging it forever.
    result_timeout = 300

    def setUp(self):
        self.num_worker = 4
        self.objs = []
        self.asymmetric_objs = []
        for i in range(self.num_worker):
            # Identical shapes on every rank: valid input for sum-reduce and stack.
            self.objs.append({"a": torch.randint(5, (3,)).cuda(), "b": torch.rand(5).cuda()})
            # Rank-dependent lengths: exercises cat with unequal sizes.
            self.asymmetric_objs.append(
                {"a": torch.randint(5, (i + 1,)).cuda(), "b": torch.rand(i * 3 + 1).cuda()}
            )
        self.ctx = mp.get_context("spawn")

        # Rendezvous settings read by ``init_method="env://"`` in the workers.
        # The previous values are restored afterwards so they don't leak into other tests.
        env = {
            "WORLD_SIZE": str(self.num_worker),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "1024",
        }
        saved = {key: os.environ.get(key) for key in env}
        os.environ.update(env)
        self.addCleanup(self._restore_env, saved)

    @staticmethod
    def _restore_env(saved):
        """Restore environment variables captured before ``setUp`` overwrote them."""
        for key, value in saved.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value

    def _check(self, reduce_fn, objs, truth, name):
        """Run ``reduce_fn`` on every rank and assert that each rank's result matches ``truth``.

        Args:
            reduce_fn: the ``comm`` collective under test.
            objs: per-rank input dicts; rank ``r`` receives ``objs[r]``.
            truth: maps each key of the input dicts to the expected tensor.
            name: operator name used in the failure message.
        """
        queue = self.ctx.Queue()
        event = self.ctx.Event()
        spawn_ctx = mp.spawn(worker, (reduce_fn, objs, queue, event), nprocs=self.num_worker, join=False)
        try:
            ranks = []
            for _ in range(self.num_worker):
                rank, result = queue.get(timeout=self.result_timeout)
                ranks.append(rank)
                for key, expected in truth.items():
                    self.assertTrue(torch.allclose(result[key], expected), f"Incorrect {name} operator")
                # Drop the reference to the worker's tensors before letting it exit.
                del result
            # Every rank must have reported exactly once.
            self.assertEqual(sorted(ranks), list(range(self.num_worker)))
        finally:
            # Always release the workers, even when an assertion above failed.
            # Otherwise they block on ``event`` forever and hang the test run.
            event.set()
            # A single ProcessContext.join() may return False after only some processes
            # have exited, so loop until all are joined (it raises if a worker failed).
            # This also ensures no worker from this phase is still alive (possibly
            # holding MASTER_PORT) when the next phase spawns.
            while not spawn_ctx.join():
                pass

    def test_reduce(self):
        """Check ``comm.reduce`` (sum), ``comm.stack`` and ``comm.cat`` against local references."""
        keys = ("a", "b")

        def reference(objs, combine):
            # Combine the per-rank tensors of each key locally to get the expected result.
            return {key: combine([obj[key] for obj in objs]) for key in keys}

        self._check(
            comm.reduce, self.objs,
            reference(self.objs, lambda tensors: torch.stack(tensors).sum(dim=0)), "reduce",
        )
        self._check(comm.stack, self.objs, reference(self.objs, torch.stack), "stack")
        self._check(comm.cat, self.asymmetric_objs, reference(self.asymmetric_objs, torch.cat), "cat")
if __name__ == "__main__":
    # Allow running this module directly as a script, outside a test discoverer.
    unittest.main()