--- a +++ b/tests/graphs/test_object_api.py @@ -0,0 +1,124 @@ +from pathaia.graphs import Graph, UGraph, Tree +from pathaia.graphs.errors import InvalidTree +from scipy.sparse import csr_matrix +from pytest import raises, warns +import itertools +from ordered_set import OrderedSet + + +def test_graph_init(): + nodes_init = (None, (1, 2, 3), "123") + edges_init = (None, ((1, 3), (1, 2), (2, 1))) + A_init = (None, csr_matrix(((1, 1), ((0, 1), (1, 2))), shape=(3, 3), dtype=bool)) + + expected = ( + (OrderedSet(), set(), csr_matrix((0, 0), dtype=bool)), + ( + OrderedSet((0, 1, 2)), + {(0, 1), (1, 2)}, + csr_matrix(((1, 1), ((0, 1), (1, 2))), shape=(3, 3), dtype=bool), + ), + ( + OrderedSet((1, 3, 2)), + {(1, 3), (1, 2), (2, 1)}, + csr_matrix(((1, 1, 1), ((0, 0, 2), (1, 2, 0))), shape=(3, 3), dtype=bool), + ), + ( + OrderedSet((1, 3, 2)), + {(1, 3), (1, 2), (2, 1)}, + csr_matrix(((1, 1, 1), ((0, 0, 2), (1, 2, 0))), shape=(3, 3), dtype=bool), + ), + (OrderedSet((1, 2, 3)), set(), csr_matrix((3, 3), dtype=bool)), + ( + OrderedSet((1, 2, 3)), + {(1, 2), (2, 3)}, + csr_matrix(((1, 1), ((0, 1), (1, 2))), shape=(3, 3), dtype=bool), + ), + ( + OrderedSet((1, 2, 3)), + {(1, 3), (1, 2), (2, 1)}, + csr_matrix(((1, 1, 1), ((0, 0, 1), (2, 1, 0))), shape=(3, 3), dtype=bool), + ), + ( + OrderedSet((1, 2, 3)), + {(1, 3), (1, 2), (2, 1)}, + csr_matrix(((1, 1, 1), ((0, 0, 1), (2, 1, 0))), shape=(3, 3), dtype=bool), + ), + (OrderedSet(("1", "2", "3")), set(), csr_matrix((3, 3), dtype=bool)), + ( + OrderedSet(("1", "2", "3")), + {("1", "2"), ("2", "3")}, + csr_matrix(((1, 1), ((0, 1), (1, 2))), shape=(3, 3), dtype=bool), + ), + KeyError, + KeyError, + ) + + for exp, (nodes, edges, A) in zip( + expected, itertools.product(nodes_init, edges_init, A_init) + ): + if isinstance(exp, tuple): + exp_nodes, exp_edges, exp_A = exp + G = Graph(nodes, edges, A) + assert G.nodes == exp_nodes + assert G.edges == exp_edges + assert (G.A[G.A > 0].A == exp_A[exp_A > 0].A).all() + else: + assert raises(exp, Graph, nodes, edges, A) + + +def test_ugraph_init(): + nodes = (1, 2, 3) + edges = ((1, 3), (1, 2), (2, 1)) + A = None + + exp_nodes, exp_edges, exp_A = ( + OrderedSet((1, 2, 3)), + {(1, 2), (1, 3)}, + csr_matrix( + ((1, 1, 1, 1), ((0, 1, 0, 2), (1, 0, 2, 0))), shape=(3, 3), dtype=bool + ), + ) + + G = UGraph(nodes, edges, A) + assert G.nodes == exp_nodes + assert G.edges == exp_edges + assert (G.A[G.A > 0].A == exp_A[exp_A > 0].A).all() + + +def test_tree_init(): + parents_init = (None, {2: 1, 3: 2, 4: 2}) + children_init = (None, {1: [2], 2: [3, 4]}, {1: [3]}) + edges_init = (None, ((1, 4), (2, 1))) + + expected = ( + (dict(), dict(), set()), + ({4: 1, 1: 2}, {1: {4}, 2: {1}}, {(1, 4), (2, 1)}), + ({2: 1, 3: 2, 4: 2}, {1: {2}, 2: {3, 4}}, {(1, 2), (2, 3), (2, 4)}), + ({4: 1, 1: 2}, {1: {4}, 2: {1}}, {(1, 4), (2, 1)}, UserWarning), + ({3: 1}, {1: {3}}, {(1, 3)}), + ({4: 1, 1: 2}, {1: {4}, 2: {1}}, {(1, 4), (2, 1)}, UserWarning), + ({2: 1, 3: 2, 4: 2}, {1: {2}, 2: {3, 4}}, {(1, 2), (2, 3), (2, 4)}), + ({4: 1, 1: 2}, {1: {4}, 2: {1}}, {(1, 4), (2, 1)}, UserWarning), + ({2: 1, 3: 2, 4: 2}, {1: {2}, 2: {3, 4}}, {(1, 2), (2, 3), (2, 4)}), + ({4: 1, 1: 2}, {1: {4}, 2: {1}}, {(1, 4), (2, 1)}, UserWarning), + InvalidTree, + ({4: 1, 1: 2}, {1: {4}, 2: {1}}, {(1, 4), (2, 1)}, UserWarning), + ) + + for exp, (parents, children, edges) in zip( + expected, itertools.product(parents_init, children_init, edges_init) + ): + if isinstance(exp, tuple): + if len(exp) == 3: + exp_parents, exp_children, exp_edges = exp + T = Tree(parents=parents, children=children, edges=edges) + else: + exp_parents, exp_children, exp_edges, warning = exp + with warns(warning): + T = Tree(parents=parents, children=children, edges=edges) + assert T.parents == exp_parents + assert T.children == exp_children + assert T.edges == exp_edges + else: + assert raises(exp, Tree, parents=parents, children=children, edges=edges)