Diff of /test/utils/test_comm.py [000000] .. [36b44b]

Switch to side-by-side view

--- a
+++ b/test/utils/test_comm.py
@@ -0,0 +1,78 @@
+import os
+import unittest
+
+import torch
+from torch import multiprocessing as mp
+
+from torchdrug.utils import comm
+
+
+def worker(rank, reduce_fn, objs, queue, event):
+    comm.init_process_group("nccl", init_method="env://", rank=rank)
+    result = reduce_fn(objs[rank])
+    queue.put((rank, result))
+    event.wait()
+
+
+class ReduceTest(unittest.TestCase):
+
+    def setUp(self):
+        self.num_worker = 4
+        self.asymmetric_objs = []
+        self.objs = []
+        for i in range(self.num_worker):
+            obj = {"a": torch.randint(5, (3,)).cuda(), "b": torch.rand(5).cuda()}
+            asymmetric_obj = {"a": torch.randint(5, (i + 1,)).cuda(), "b": torch.rand(i * 3 + 1).cuda()}
+            self.objs.append(obj)
+            self.asymmetric_objs.append(asymmetric_obj)
+
+        self.ctx = mp.get_context("spawn")
+        os.environ["WORLD_SIZE"] = str(self.num_worker)
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "1024"
+
+    def test_reduce(self):
+        queue = self.ctx.Queue()
+        event = self.ctx.Event()
+
+        spawn_ctx = mp.spawn(worker, (comm.reduce, self.objs, queue, event), nprocs=self.num_worker, join=False)
+        truth = {}
+        truth["a"] = torch.stack([obj["a"] for obj in self.objs]).sum(dim=0)
+        truth["b"] = torch.stack([obj["b"] for obj in self.objs]).sum(dim=0)
+        for i in range(self.num_worker):
+            rank, result = queue.get()
+            self.assertTrue(torch.allclose(result["a"], truth["a"]), "Incorrect reduce operator")
+            self.assertTrue(torch.allclose(result["b"], truth["b"]), "Incorrect reduce operator")
+            del result
+        event.set()
+        spawn_ctx.join()
+
+        event.clear()
+        spawn_ctx = mp.spawn(worker, (comm.stack, self.objs, queue, event), nprocs=self.num_worker, join=False)
+        truth = {}
+        truth["a"] = torch.stack([obj["a"] for obj in self.objs])
+        truth["b"] = torch.stack([obj["b"] for obj in self.objs])
+        for i in range(self.num_worker):
+            rank, result = queue.get()
+            self.assertTrue(torch.allclose(result["a"], truth["a"]), "Incorrect stack operator")
+            self.assertTrue(torch.allclose(result["b"], truth["b"]), "Incorrect stack operator")
+            del result
+        event.set()
+        spawn_ctx.join()
+
+        event.clear()
+        spawn_ctx = mp.spawn(worker, (comm.cat, self.asymmetric_objs, queue, event), nprocs=self.num_worker, join=False)
+        truth = {}
+        truth["a"] = torch.cat([obj["a"] for obj in self.asymmetric_objs])
+        truth["b"] = torch.cat([obj["b"] for obj in self.asymmetric_objs])
+        for i in range(self.num_worker):
+            rank, result = queue.get()
+            self.assertTrue(torch.allclose(result["a"], truth["a"]), "Incorrect cat operator")
+            self.assertTrue(torch.allclose(result["b"], truth["b"]), "Incorrect cat operator")
+            del result
+        event.set()
+        spawn_ctx.join()
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file