Switch to side-by-side view

--- a
+++ b/openomics/transforms/adj.py
@@ -0,0 +1,73 @@
+from typing import Union, List, Dict, Tuple
+
+import networkx as nx
+
+
+def to_scipy_adjacency(g: nx.DiGraph, nodes: Union[List[str], Dict[str, List[str]]],
+                       edge_types: Union[List[str], Tuple[str, str, str]] = None,
+                       reverse=False,
+                       format="coo", d_ntype="_N"):
+    if reverse:
+        g = g.reverse(copy=True)
+
+    if not isinstance(g, nx.MultiGraph):
+        raise NotImplementedError
+
+    if not isinstance(edge_types, (list, tuple, set)):
+        edge_types = ["_E"]
+
+    edge_index_dict = {}
+    for etype in edge_types:
+        if isinstance(g, nx.MultiGraph) and isinstance(etype, str):
+            edge_subgraph = g.edge_subgraph([(u, v, e) for u, v, e in g.edges if e == etype])
+            nodes_A = nodes
+            nodes_B = nodes
+            metapath = (d_ntype, etype, d_ntype)
+
+        elif isinstance(g, nx.MultiGraph) and isinstance(etype, tuple) and isinstance(nodes, dict):
+            metapath: Tuple[str, str, str] = etype
+            head, etype, tail = metapath
+            edge_subgraph = g.edge_subgraph([(u, v, e) for u, v, e in g.edges if e == etype])
+
+            nodes_A = nodes[head]
+            nodes_B = nodes[tail]
+
+        elif etype == "_E":
+            edge_subgraph = g.edges
+            nodes_A = nodes
+            nodes_B = nodes
+            metapath = (d_ntype, etype, d_ntype)
+        else:
+            raise Exception(f"Edge types `{edge_types}` is ill formed.")
+
+        biadj = nx.bipartite.biadjacency_matrix(edge_subgraph, row_order=nodes_A, column_order=nodes_B,
+                                                format="coo")
+
+        if format == "coo":
+            edge_index_dict[metapath] = (biadj.row, biadj.col)
+        elif format == "pyg":
+            import torch
+            edge_index_dict[metapath] = torch.stack(
+                [torch.tensor(biadj.row, dtype=torch.long),
+                 torch.tensor(biadj.col, dtype=torch.long)])
+        else:
+            edge_index_dict[metapath] = biadj
+
+    return edge_index_dict
+
+
+def slice_adj(adj, node_list: list, nodes_A, nodes_B=None):
+    """
+    Args:
+        adj:
+        node_list (list):
+        nodes_A:
+        nodes_B:
+    """
+    if nodes_B is None:
+        idx = [node_list.index(node) for node in nodes_A]
+        return adj[idx, :][:, idx]
+    else:
+        idx_A = [node_list.index(node) for node in nodes_A]
+        idx_B = [node_list.index(node) for node in nodes_B]
+        return adj[idx_A, :][:, idx_B]