|
a |
|
b/src/scpanel/utils_func.py |
|
|
1 |
import warnings |
|
|
2 |
|
|
|
3 |
# from sklearn.preprocessing import LabelEncoder |
|
|
4 |
from typing import Tuple, Dict, Optional, Union |
|
|
5 |
|
|
|
6 |
import anndata |
|
|
7 |
import numpy as np |
|
|
8 |
import pandas as pd |
|
|
9 |
import scanpy as sc |
|
|
10 |
from scipy import sparse |
|
|
11 |
from sklearn.utils import resample |
|
|
12 |
from anndata._core.anndata import AnnData |
|
|
13 |
from numpy import ndarray |
|
|
14 |
from scipy.sparse._base import spmatrix |
|
|
15 |
from pandas.core.arrays.categorical import Categorical |
|
|
16 |
from pandas.core.frame import DataFrame |
|
|
17 |
from scipy.sparse._csr import csr_matrix |
|
|
18 |
|
|
|
19 |
warnings.filterwarnings("ignore") |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
def preprocess(
    adata: AnnData,
    integrated: bool = False,
    ct_col: Optional[str] = None,
    y_col: Optional[str] = None,
    pt_col: Optional[str] = None,
    class_map: Optional[Dict[str, int]] = None,
) -> AnnData:
    """Standardize an AnnData object for downstream steps.

    Ensures ``adata.X`` holds log-normalized counts (raw counts are moved to
    the ``count`` layer first), filters out genes expressed in no cell, and
    renames the cell-type / condition / patient metadata columns to the
    canonical ``ct`` / ``y`` / ``patient_id`` names. An integer ``label``
    column is then derived from ``y`` via ``class_map``.

    Parameters
    ----------
    adata : AnnData
        Input single-cell dataset; modified in place and returned.
    integrated : bool, default False
        If True, skip the count/log-count standardization and gene
        filtering (the matrix is assumed to be already processed).
    ct_col : str, optional
        Name of the cell-type column in ``adata.obs``. If None, a ``ct``
        column must already exist.
    y_col : str, optional
        Name of the condition column in ``adata.obs``. If None, a ``y``
        column must already exist.
    pt_col : str, optional
        Name of the patient-id column in ``adata.obs``. If None, a
        ``patient_id`` column must already exist.
    class_map : dict of str to int
        Mapping from condition value to integer class label. Required.

    Returns
    -------
    AnnData
        The standardized object.

    Raises
    ------
    TypeError
        If ``adata.X`` is neither raw counts nor log-normalized counts,
        or if ``class_map`` is None.
    """
    # Rename the raw var index column so it does not collide on `_index`
    # when the object is written out; objects without `.raw` skip this.
    try:
        adata.__dict__["_raw"].__dict__["_var"] = (
            adata.__dict__["_raw"]
            .__dict__["_var"]
            .rename(columns={"_index": "features"})
        )
    except AttributeError:
        pass

    if not integrated:
        # Standardize the X matrix so it ends up log-normalized.
        if check_nonnegative_integers(adata.X):
            adata.layers["count"] = adata.X
            print("X is raw count matrix, adding to `count` layer......")
            sc.pp.normalize_total(adata, target_sum=1e4)
            sc.pp.log1p(adata)
        elif check_nonnegative_float(adata.X):
            adata.layers["logcount"] = adata.X
            print("X is logcount matrix, adding to `logcount` layer......")
        else:
            raise TypeError("X should be either raw counts or log-normalized counts")

        # Drop genes expressed in no cell.
        sc.pp.filter_genes(adata, min_cells=1)

    # Standardize metadata column names:
    #   cell type -> `ct`, condition -> `y`, patient -> `patient_id`.
    for src_col, dst_col in ((ct_col, "ct"), (y_col, "y"), (pt_col, "patient_id")):
        if src_col is not None:
            adata.obs[dst_col] = adata.obs[src_col]
        # If no source column was given, the canonical column must already
        # exist (a KeyError here means the caller did not provide it).
        adata.obs[dst_col] = adata.obs[dst_col].astype("category")

    if class_map is None:
        # `Series.map(None)` fails with an opaque message; fail clearly instead.
        raise TypeError("class_map is required to derive the integer `label` column")
    adata.obs["label"] = adata.obs["y"].map(class_map)

    return adata
|
|
124 |
|
|
|
125 |
|
|
|
126 |
def get_X_y_from_ann(adata: AnnData, return_adj: bool=False, n_neigh: int=10) -> Union[Tuple[ndarray, Categorical], Tuple[ndarray, Categorical, csr_matrix]]:
    """Extract the feature matrix and labels (and optionally a kNN adjacency).

    Parameters
    ----------
    adata
        AnnData object whose ``X`` holds the (scaled) expression values and
        whose ``obs['label']`` holds the integer class labels.
    return_adj
        If True, additionally build and return a kNN adjacency matrix with
        self-loops added on the diagonal.
    n_neigh
        Number of neighbors used when building the adjacency matrix.

    Returns
    -------
    ``(X, y)`` or ``(X, y, adj)`` depending on ``return_adj``.
    """
    # X is taken as-is from the current matrix (expected to be scaled values).
    feature_matrix = adata.to_df().values
    labels = adata.obs["label"].values

    if not return_adj:
        return feature_matrix, labels

    # With enough genes, build neighbors in a 30-component PCA space;
    # otherwise use the original feature space directly.
    if adata.n_vars >= 50:
        sc.pp.pca(adata, n_comps=30, use_highly_variable=False)
        sc.pl.pca_variance_ratio(adata, log=True, n_pcs=30)
        n_pcs = 30
    else:
        print(
            "Number of input genes < 50, use original space to get adjacency matrix..."
        )
        n_pcs = 0

    knn = sc.Neighbors(adata)
    knn.compute_neighbors(n_neighbors=n_neigh, n_pcs=n_pcs)

    # Add self-loops so each cell is connected to itself.
    adjacency = knn.connectivities + sparse.diags([1] * adata.shape[0]).tocsr()

    return feature_matrix, labels, adjacency
|
|
166 |
|
|
|
167 |
|
|
|
168 |
def check_nonnegative_integers(X: Union[np.ndarray, sparse.spmatrix]) -> bool:
    """Check a sample of rows of ``X`` to decide whether it is raw count data.

    Parameters
    ----------
    X : ndarray or sparse matrix
        Cell-by-gene expression matrix.

    Returns
    -------
    bool
        True if the sampled values are all non-negative integers (integer
        dtype, or floats with no fractional part), False otherwise.
    """
    from numbers import Integral

    # Inspect at most 10 randomly-chosen rows. Clamping to the row count
    # fixes a crash: sampling 10 rows with replace=False from a matrix
    # with fewer than 10 rows raises ValueError.
    n_rows = min(10, X.shape[0])
    data = X[np.random.choice(X.shape[0], n_rows, replace=False), :]
    data = data.todense() if sparse.issparse(data) else data

    # Any negative value rules out count data.
    if np.signbit(data).any():
        return False
    # Integer dtype: trivially count data.
    if issubclass(data.dtype.type, Integral):
        return True
    # Float dtype: count data only if every value is a whole number.
    if np.any(~np.equal(np.mod(data, 1), 0)):
        return False
    return True
|
|
185 |
|
|
|
186 |
|
|
|
187 |
def check_nonnegative_float(X: Union[np.ndarray, sparse.spmatrix]) -> bool:
    """Check a sample of rows of ``X`` to decide whether it is logcount data.

    Parameters
    ----------
    X : ndarray or sparse matrix
        Cell-by-gene expression matrix.

    Returns
    -------
    bool
        True if the sampled values are non-negative and the matrix has a
        floating-point dtype, False otherwise.
    """
    # Inspect at most 10 randomly-chosen rows. Clamping to the row count
    # fixes a crash: sampling 10 rows with replace=False from a matrix
    # with fewer than 10 rows raises ValueError.
    n_rows = min(10, X.shape[0])
    data = X[np.random.choice(X.shape[0], n_rows, replace=False), :]
    data = data.todense() if sparse.issparse(data) else data

    # Any negative value rules out log-normalized count data.
    if np.signbit(data).any():
        return False
    # Non-negative floats are accepted as log-normalized counts.
    if np.issubdtype(data.dtype, np.floating):
        return True
    # Non-negative but not a float dtype (e.g. raw integer counts): the
    # original fell off the end and returned None here despite the bool
    # annotation; return an explicit False.
    return False
|
|
200 |
|
|
|
201 |
|
|
|
202 |
def compute_cell_weight(data: Union[AnnData, DataFrame]) -> ndarray:
    """Compute per-cell weights balancing both class and patient sizes.

    Accepts either an AnnData object (its ``.obs`` is used) or a DataFrame
    with ``label`` and ``patient_id`` columns. Each cell's weight is the
    product of its balanced class weight and its balanced within-class
    patient weight.
    """
    if isinstance(data, anndata.AnnData):
        data = data.obs

    from sklearn.utils.class_weight import compute_sample_weight

    # Balanced weight for each cell's class label.
    class_w = compute_sample_weight(class_weight="balanced", y=data["label"])
    class_df = pd.DataFrame({"patient_id": data["patient_id"], "w_c": class_w})

    # Balanced patient weight, computed separately within each class.
    # Seed the list with an empty frame so the concat result matches the
    # incremental-concat construction even when there are no groups.
    per_class_frames = [pd.DataFrame([])]
    for _, grp in data.groupby(["label"]):
        patient_w = compute_sample_weight(class_weight="balanced", y=grp["patient_id"])
        per_class_frames.append(
            pd.DataFrame(
                {
                    "label": grp["label"],
                    "patient_id": grp["patient_id"],
                    "w_pt": patient_w,
                }
            )
        )
    patient_df = pd.concat(per_class_frames, axis=0)

    # Cell weight = class weight * patient weight, aligned on the cell index.
    combined = class_df.merge(patient_df, left_index=True, right_index=True)
    combined["w"] = combined["w_c"] * combined["w_pt"]

    return combined["w"].values
|
|
240 |
|
|
|
241 |
|
|
|
242 |
def downsample_adata(adata, downsample_size=4000, random_state=1):
    """Stratified downsampling of cells by patient, condition and cell type.

    Builds a composite stratum key from ``patient_id``, ``y`` and ``ct``
    for every cell, then draws ``downsample_size`` cells without
    replacement while preserving the stratum proportions.
    """
    # Composite stratification key: patient, disease status and cell type.
    strata = adata.obs[["patient_id", "y", "ct"]].apply(
        lambda x: "_".join(x.dropna().astype(str)), axis=1
    )
    adata.obs["downsample_stratify"] = strata.astype("category")

    # Stratified sample of cell names, without replacement.
    keep = resample(
        adata.obs_names,
        replace=False,
        n_samples=downsample_size,
        stratify=adata.obs["downsample_stratify"],
        random_state=random_state,
    )

    return adata[keep,]