a b/aiagents4pharma/talk2knowledgegraphs/utils/kg_utils.py
1
#!/usr/bin/env python3
2
3
'''A utility module for knowledge graph operations'''
4
5
from typing import Tuple
6
import networkx as nx
7
import pandas as pd
8
9
def kg_to_df_pandas(kg: nx.DiGraph) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Split a directed knowledge graph into a node table and an edge table.

    Args:
        kg: The directed knowledge graph in networkX format.

    Returns:
        A ``(df_nodes, df_edges)`` tuple:
        df_nodes: A pandas DataFrame built from the graph's node mapping,
            with one row per node (the node is the row index).
        df_edges: A pandas DataFrame edge list whose endpoint columns are
            named 'node_source' and 'node_target'.
    """
    # Node table: kg.nodes maps node -> attribute dict; orient='index'
    # turns every key of that mapping into a row label.
    node_table = pd.DataFrame.from_dict(kg.nodes, orient='index')

    # Edge table: the endpoint columns get fixed names so that callers can
    # rely on them regardless of how the graph was originally built.
    edge_table = nx.to_pandas_edgelist(
        kg,
        source='node_source',
        target='node_target',
    )

    return node_table, edge_table
30
31
def df_pandas_to_kg(df: pd.DataFrame,
                    df_nodes_attrs: pd.DataFrame,
                    node_source: str,
                    node_target: str
                    ) -> nx.DiGraph:
    """
    Convert a pandas DataFrame to a directed knowledge graph.

    Args:
        df: A pandas DataFrame of the edges in the knowledge graph.
        df_nodes_attrs: A pandas DataFrame of the nodes in the knowledge graph.
            Its index holds the node identifiers and its columns hold the
            node attributes.
        node_source: The column name of the source node in the df.
        node_target: The column name of the target node in the df.

    Returns:
        kg: The directed knowledge graph in networkX format. Every column of
            df other than node_source / node_target is stored as an edge
            attribute; every column of df_nodes_attrs is stored as a node
            attribute.

    Raises:
        AssertionError: If node_source or node_target is not a column of df,
            or if df_nodes_attrs has an index entry that appears in neither
            the source nor the target column of df.
    """
    # Validate with explicit raises instead of `assert` statements so the
    # checks are not stripped when Python runs with -O. AssertionError and
    # the original messages are kept so existing callers see no difference.
    for column in (node_source, node_target):
        if column not in df.columns:
            raise AssertionError(f'{column} not in df')

    # Every node that carries attributes must take part in at least one edge.
    edge_nodes = set(df[node_source]).union(df[node_target])
    if not set(df_nodes_attrs.index).issubset(edge_nodes):
        raise AssertionError('Nodes in index of df_nodes not found in df_edges')

    # networkx rejects edge_attr=True when df has no columns besides the two
    # endpoint columns, so only request edge attributes when some exist.
    has_edge_attrs = any(col not in (node_source, node_target)
                         for col in df.columns)

    # Create a knowledge graph from the dataframes:
    # edges (with their attributes) first, then the node attributes.
    kg = nx.from_pandas_edgelist(df,
                                 source=node_source,
                                 target=node_target,
                                 create_using=nx.DiGraph,
                                 edge_attr=True if has_edge_attrs else None)
    kg.add_nodes_from(df_nodes_attrs.to_dict('index').items())

    return kg