# src/scpanel/utils_func.py
import warnings

# from sklearn.preprocessing import LabelEncoder
from typing import Dict, Optional, Tuple, Union

import anndata
import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from numpy import ndarray
from pandas import Categorical, DataFrame
from scipy import sparse
from scipy.sparse import csr_matrix, spmatrix
from sklearn.utils import resample

warnings.filterwarnings("ignore")


def preprocess(
    adata: AnnData,
    integrated: bool = False,
    ct_col: Optional[str] = None,
    y_col: Optional[str] = None,
    pt_col: Optional[str] = None,
    class_map: Optional[Dict[str, int]] = None,
) -> AnnData:
    """Standardize input data.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix to standardize.
    integrated : bool, default False
        If True, skip count detection and log-normalization of X.
    ct_col : Optional[str], default None
        Name of the obs column holding cell-type labels; copied to obs["ct"].
    y_col : Optional[str], default None
        Name of the obs column holding condition labels; copied to obs["y"].
    pt_col : Optional[str], default None
        Name of the obs column holding patient IDs; copied to obs["patient_id"].
    class_map : Optional[Dict[str, int]], default None
        Mapping from condition labels to integer class labels, stored in
        obs["label"].

    Returns
    -------
    AnnData
        The standardized AnnData object.
    """
    # Rename the "_index" column of the raw var table to "features".
    # If adata.raw is unset, `_raw` is None and the attribute access
    # raises AttributeError, which we can safely ignore.
    try:
        adata.__dict__["_raw"].__dict__["_var"] = (
            adata.__dict__["_raw"]
            .__dict__["_var"]
            .rename(columns={"_index": "features"})
        )
    except AttributeError:
        pass

    # adata.raw = adata
    # can be recovered by: adata_raw = adata.raw.to_adata()

    if not integrated:
        # Standardize the X matrix -----------------------------------------
        # 1. Check whether X is a raw count matrix.
        #    If so, keep the counts in a layer and log-normalize X.
        if check_nonnegative_integers(adata.X):
            adata.layers["count"] = adata.X
            print("X is raw count matrix, adding to `count` layer......")
            sc.pp.normalize_total(adata, target_sum=1e4)
            sc.pp.log1p(adata)
        elif check_nonnegative_float(adata.X):
            adata.layers["logcount"] = adata.X
            print("X is logcount matrix, adding to `logcount` layer......")
        else:
            raise TypeError("X should be either raw counts or log-normalized counts")

    # Filter lowly expressed genes -------------------------------------------
    sc.pp.filter_genes(adata, min_cells=1)

    # Calculate HVGs ---------------------------------------------------------
    # sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0)

    # Standardize metadata ---------------------------------------------------
    ## cell type -> obs["ct"]
    ## condition -> obs["y"]
    ## patient   -> obs["patient_id"]

    # If a source column is given, copy it; otherwise assume the target
    # column ("ct" / "y" / "patient_id") already exists in adata.obs.
    if ct_col is not None:
        adata.obs["ct"] = adata.obs[ct_col]
    adata.obs["ct"] = adata.obs["ct"].astype("category")

    if y_col is not None:
        adata.obs["y"] = adata.obs[y_col]
    adata.obs["y"] = adata.obs["y"].astype("category")

    if pt_col is not None:
        adata.obs["patient_id"] = adata.obs[pt_col]
    adata.obs["patient_id"] = adata.obs["patient_id"].astype("category")

    # Encode the condition as an integer label, e.g. class_map={"control": 0, "disease": 1}
    adata.obs["label"] = adata.obs["y"].map(class_map)

    # # Add HVGs in each cell type:
    # # split each cell type into new objects.
    # celltypes_all = adata.obs["ct"].cat.categories.tolist()
    #
    # alldata = {}
    # for celltype in celltypes_all:
    #     adata_tmp = adata[adata.obs["ct"] == celltype,]
    #     if adata_tmp.shape[0] > 20:
    #         sc.pp.highly_variable_genes(adata_tmp, min_mean=0.0125, max_mean=3, min_disp=1.5)
    #         adata_tmp = adata_tmp[:, adata_tmp.var.highly_variable]
    #         alldata[celltype] = adata_tmp
    #
    # celltype_hvgs = anndata.concat(alldata, join="outer").var_names.tolist()
    # hvgs = adata.var.highly_variable.index[adata.var.highly_variable].tolist()
    # hvgs_idx = adata.var_names.isin(list(set(celltype_hvgs + hvgs)))

    return adata
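
# A minimal usage sketch for `preprocess` (illustrative only): the file path,
# the obs column names "cell_type", "condition", "patient", and the class
# mapping below are hypothetical and depend on your dataset.
#
#   adata = sc.read_h5ad("dataset.h5ad")
#   adata = preprocess(
#       adata,
#       ct_col="cell_type",
#       y_col="condition",
#       pt_col="patient",
#       class_map={"healthy": 0, "disease": 1},
#   )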


def get_X_y_from_ann(
    adata: AnnData, return_adj: bool = False, n_neigh: int = 10
) -> Union[Tuple[ndarray, Categorical], Tuple[ndarray, Categorical, csr_matrix]]:
    """Extract the feature matrix and labels, optionally with a kNN adjacency.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix with a "label" column in obs.
    return_adj : bool, default False
        If True, also return a sparse cell-cell adjacency matrix built
        from a kNN graph, with self-loops added.
    n_neigh : int, default 10
        Number of neighbors for the kNN graph.

    Returns
    -------
    (X, y) or (X, y, adj)
    """
    # scale and store results in a layer
    # adata.layers["scaled"] = sc.pp.scale(adata, max_value=10, copy=True).X

    # This X matrix should hold the scaled values
    X = adata.to_df().values

    y = adata.obs["label"].values

    if return_adj:
        if adata.n_vars >= 50:
            # enough genes: build the kNN graph in PCA space
            sc.pp.pca(adata, n_comps=30, use_highly_variable=False)
            sc.pl.pca_variance_ratio(adata, log=True, n_pcs=30)
            knn = sc.Neighbors(adata)
            knn.compute_neighbors(n_neighbors=n_neigh, n_pcs=30)
        else:
            print(
                "Number of input genes < 50, use original space to get adjacency matrix..."
            )
            knn = sc.Neighbors(adata)
            knn.compute_neighbors(n_neighbors=n_neigh, n_pcs=0)

        # add self-loops so every cell is connected to itself
        adj = knn.connectivities + sparse.diags([1] * adata.shape[0]).tocsr()

        return X, y, adj

    else:
        return X, y
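
# Hedged usage sketch: extract features, labels, and a self-looped kNN
# adjacency from a preprocessed AnnData (`adata` as returned by `preprocess`
# is assumed here).
#
#   X, y, adj = get_X_y_from_ann(adata, return_adj=True, n_neigh=10)
#   X.shape    # (n_cells, n_genes)
#   adj.shape  # (n_cells, n_cells), CSR with unit diagonal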


def check_nonnegative_integers(X: Union[np.ndarray, sparse.spmatrix]) -> bool:
    """Check values of X to ensure it is count data."""
    from numbers import Integral

    # Inspect up to 10 random rows (sampling keeps this cheap on large X
    # and avoids an indexing error on very small matrices)
    n_rows = min(10, X.shape[0])
    data = X[np.random.choice(X.shape[0], n_rows, replace=False), :]
    data = data.todense() if sparse.issparse(data) else data

    # Reject if any value is negative
    if np.signbit(data).any():
        return False
    # Accept integer dtypes outright
    elif issubclass(data.dtype.type, Integral):
        return True
    # Reject floats with a fractional part
    elif np.any(~np.equal(np.mod(data, 1), 0)):
        return False
    else:
        return True


def check_nonnegative_float(X: Union[np.ndarray, sparse.spmatrix]) -> bool:
    """Check values of X to ensure it is logcount data."""
    # Inspect up to 10 random rows, as in check_nonnegative_integers
    n_rows = min(10, X.shape[0])
    data = X[np.random.choice(X.shape[0], n_rows, replace=False), :]
    data = data.todense() if sparse.issparse(data) else data

    # Reject if any value is negative
    if np.signbit(data).any():
        return False
    # Accept floating-point dtypes
    elif np.issubdtype(data.dtype, np.floating):
        return True
    # Anything else (e.g. non-negative integers) is not logcount data
    else:
        return False
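
# A quick sanity check of the two detectors above (illustrative; the arrays
# are made up):
#
#   counts = np.random.poisson(1.0, size=(20, 5)).astype(float)
#   check_nonnegative_integers(counts)   # True: non-negative, integral values
#   logc = np.log1p(counts)
#   check_nonnegative_float(logc)        # True: non-negative floats
#   check_nonnegative_integers(logc)     # False once any value is fractional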


def compute_cell_weight(data: Union[AnnData, DataFrame]) -> ndarray:
    """Compute per-cell weights as class weight * patient weight.

    `data` is either an AnnData object or a DataFrame like `adata.obs`;
    it must contain "label" and "patient_id" columns.
    """
    if isinstance(data, anndata.AnnData):
        data = data.obs

    from sklearn.utils.class_weight import compute_sample_weight

    ## class weight: balance the disease classes
    w_c = compute_sample_weight(class_weight="balanced", y=data["label"])
    w_c_df = pd.DataFrame({"patient_id": data["patient_id"], "w_c": w_c})

    ## patient weight: within each class, balance the patients
    gped = data.groupby(["label"])
    w_pt_df = pd.DataFrame([])
    for _, group in gped:
        w_pt = compute_sample_weight(class_weight="balanced", y=group["patient_id"])
        w_pt_df = pd.concat(
            [
                w_pt_df,
                pd.DataFrame(
                    {
                        "label": group["label"],
                        "patient_id": group["patient_id"],
                        "w_pt": w_pt,
                    }
                ),
            ],
            axis=0,
        )

    ## cell weight = class weight * patient weight (joined on cell index)
    w_df = w_c_df.merge(w_pt_df, left_index=True, right_index=True)
    w_df["w"] = w_df["w_c"] * w_df["w_pt"]

    ## TODO: add results to adata metadata or save to csv?
    # adata_train.obs = adata_train.obs.merge(w_df[['w_pt','w_c','w']], left_index=True, right_index=True)

    return w_df["w"].values
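
# Hedged usage sketch for `compute_cell_weight`: cells from rare classes or
# under-represented patients receive larger weights. Assumes `adata.obs`
# carries "label" and "patient_id" as set up by `preprocess`.
#
#   w = compute_cell_weight(adata)
#   # e.g. pass as per-sample weights when fitting a classifier:
#   # clf.fit(X, y, sample_weight=w)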


def downsample_adata(
    adata: AnnData, downsample_size: int = 4000, random_state: int = 1
) -> AnnData:
    """Stratified downsampling by patient_id, y, and ct
    (patient, disease status, and cell type)."""
    # Build a composite key so each (patient, condition, cell type)
    # combination forms one stratum
    adata.obs["downsample_stratify"] = adata.obs[["patient_id", "y", "ct"]].apply(
        lambda x: "_".join(x.dropna().astype(str)), axis=1
    )
    adata.obs["downsample_stratify"] = adata.obs["downsample_stratify"].astype(
        "category"
    )

    # Sample cells without replacement, preserving stratum proportions
    down_index = resample(
        adata.obs_names,
        replace=False,
        n_samples=downsample_size,
        stratify=adata.obs["downsample_stratify"],
        random_state=random_state,
    )

    adata = adata[down_index, :]
    return adata
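
# Hedged usage sketch for `downsample_adata`: `downsample_size` must not
# exceed the number of cells, and each stratum needs enough members for
# stratified sampling to succeed.
#
#   adata_small = downsample_adata(adata, downsample_size=4000, random_state=1)
#   adata_small.n_obs  # 4000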