from collections import defaultdict
from typing import Union, List, Dict, Callable, Mapping, Iterable

import dask.dataframe as dd
import numpy as np
import pandas as pd
from pandas.core.groupby import SeriesGroupBy


def get_agg_func(keyword: Union[str, Callable], use_dask=False) -> Union[str, Callable, dd.Aggregation]:
    """Resolve an aggregation keyword into something a groupby can consume.

    Args:
        keyword (str): an aggregation name ("unique", "concat", or any pandas
            aggregator keyword), a callable, or a ready-made dd.Aggregation.
        use_dask (bool): when True, "unique" resolves to a dd.Aggregation
            instead of a plain callable.

    Returns:
        func (callable): a callable function, pandas aggregator func name, or
            a Dask Aggregation.
    """
    # Callables and pre-built Dask aggregations pass through untouched.
    if callable(keyword) or isinstance(keyword, dd.Aggregation):
        return keyword

    if keyword == "unique":
        # Collect unique values (in a list-like np.array) from each groupby key.
        return concat_unique_dask_agg() if use_dask else concat_uniques

    if keyword == "concat":
        # Concatenate all values of a group into a single list.
        return concat

    # Any other aggregation keyword is assumed to be a pandas aggregator name.
    return keyword
def get_multi_aggregators(agg: str, agg_for: Dict[str, Union[str, Callable, dd.Aggregation]] = None, use_dask=False) \
    -> Mapping[str, Union[str, dd.Aggregation]]:
    """Build a per-column mapping of aggregation functions.

    Args:
        agg (str): default aggregation keyword used for any column not listed
            in `agg_for`.
        agg_for (dict): optional explicit {column: aggregation keyword/callable}
            overrides.
        use_dask (bool): forwarded to get_agg_func to build dd.Aggregation
            objects when needed.

    Returns:
        A defaultdict mapping column name -> aggregator; lookups of columns
        without an override fall back to the default `agg` aggregator.
    """
    if agg_for is None:
        overrides = {}
    else:
        overrides = {col: get_agg_func(keyword, use_dask=use_dask)
                     for col, keyword in agg_for.items()}

    # Unknown columns lazily resolve to the default aggregator.
    return defaultdict(lambda: get_agg_func(agg, use_dask=use_dask), overrides)
def concat_unique_dask_agg() -> dd.Aggregation:
    """Build a Dask Aggregation (named 'unique') that collects the unique
    values of each groupby key across partitions."""

    def _normalize(value: Union[str, List, np.ndarray]) -> List:
        # Wrap scalars/strings into a one-element list; flatten and dedupe
        # nested iterables so every cell becomes list-like.
        if isinstance(value, str):
            return [value]
        if isinstance(value, np.ndarray):
            return value
        if isinstance(value, Iterable):
            if any(isinstance(elem, Iterable) for elem in value):
                return list(set(np.hstack(value)))
            return list(set(value))
        return [value]

    def _chunk(part: pd.Series) -> pd.Series:
        # Map step: applied to each individual partition.
        return part.apply(_normalize)

    def _agg(grouped: SeriesGroupBy) -> pd.Series:
        # Reduce step: concatenate the per-partition lists for each group key.
        obj = grouped._selected_obj
        levels = list(range(obj.index.nlevels))
        return obj.groupby(level=levels, group_keys=True).apply(
            lambda li: np.hstack(li) if isinstance(li, Iterable) and len(li) else None)

    def _finalize(out) -> pd.Series:
        # Final step: drop NaNs and keep only unique values per group.
        return out.apply(lambda arr: np.unique(arr[~pd.isna(arr)]))

    return dd.Aggregation(
        name='unique',
        chunk=_chunk,
        agg=_agg,
        finalize=_finalize,
    )
def merge_concat(a: Union[str, None, Iterable], b: Union[str, None, Iterable]) -> Union[np.ndarray, str, None]:
    """Combine two cell values into one; used as the func in pd.combine()
    or dd.combine().

    Args:
        a (Union[str,None,Iterable]): cell value in a pd.Series
        b (Union[str,None,Iterable]): cell value in a pd.Series

    Returns:
        combined_value (Union[np.ndarray, str, None]): the non-null side when
        one side is missing, the shared string when both sides are equal
        strings, otherwise the values stacked into one array.
    """
    a_na, b_na = pd.isna(a), pd.isna(b)

    # A side counts as "missing" when it is a null scalar or an all-null iterable.
    if a_na is True or (isinstance(a_na, Iterable) and all(a_na)):
        return b
    if b_na is True or (isinstance(b_na, Iterable) and all(b_na)):
        return a

    if isinstance(a, str) and isinstance(b, str):
        return a if a == b else np.array([a, b])

    a_iter = isinstance(a, Iterable)
    b_iter = isinstance(b, Iterable)
    if a_iter or b_iter:
        # Promote any scalar side to a one-element list, then stack.
        return np.hstack([a if a_iter else [a], b if b_iter else [b]])

    # Two non-null, non-iterable scalars: keep the second value.
    return b
def concat_uniques(series: pd.Series) -> Union[str, List, np.ndarray, None]:
    """An aggregation custom function to be applied to each column of a groupby.

    Args:
        series (pd.Series): Entries can be strings, list-likes of strings,
            non-iterable scalars, or NaN.

    Returns:
        unique_values: the unique values across all entries (list-like entries
        are flattened), or None when the series is empty after dropping NaNs.
    """
    series = series.dropna()
    if series.empty:
        return None

    # Boolean mask of entries that are plain strings.
    is_str_idx = series.map(type) == str

    # NOTE: strings are themselves Iterable, so this branch is also taken for
    # all-string series; each string is wrapped into a one-element list below.
    if series.map(lambda x: isinstance(x, Iterable)).any():
        if is_str_idx.any():
            # Convert mixed dtypes to lists; empty strings map to None.
            series.loc[is_str_idx] = series.loc[is_str_idx].map(lambda s: [s] if len(s) else None)
            # Bug fix: drop the Nones produced by empty strings, otherwise
            # np.unique raises TypeError sorting a mixed None/str object array.
            series = series.dropna()
            if series.empty:
                return None
        return np.unique(np.hstack(series))

    elif is_str_idx.any():
        # Unreachable in practice (str is Iterable, so handled above);
        # kept as a defensive fallback.
        concat_str = series.astype(str).unique()
        if len(concat_str):  # Avoid empty string
            return concat_str

    else:
        # Non-iterable scalar entries (e.g. numbers): return them as a list.
        return series.tolist()
def concat(series: pd.Series) -> Union[str, List, np.ndarray, None]:
    """Concatenate all entries of a groupby column into one flat collection.

    Args:
        series (pd.Series): Entries can be strings, list-likes of strings,
            non-iterable scalars, or NaN.

    Returns:
        All values concatenated (list-like entries are flattened, duplicates
        kept), or None when the series is empty after dropping NaNs.
    """
    series = series.dropna()
    if series.empty:
        return None

    # Boolean mask of entries that are plain strings.
    is_str_idx = series.map(type) == str

    # NOTE: strings are themselves Iterable, so this branch is also taken for
    # all-string series; each string is wrapped into a one-element list below.
    if series.map(lambda x: isinstance(x, Iterable)).any():
        if is_str_idx.any():
            # Convert mixed dtypes to lists; empty strings map to None.
            series.loc[is_str_idx] = series.loc[is_str_idx].map(lambda s: [s] if len(s) else None)
            # Bug fix: drop the Nones produced by empty strings so they don't
            # leak into the concatenated output.
            series = series.dropna()
            if series.empty:
                return None
        return np.hstack(series)

    elif is_str_idx.any():
        # Unreachable in practice (str is Iterable, so handled above);
        # kept as a defensive fallback.
        concat_str = series.astype(str).tolist()
        if len(concat_str):  # Avoid empty string
            return concat_str

    else:
        # Non-iterable scalar entries (e.g. numbers): return them as a list.
        return series.tolist()