Diff of /src/scpanel/utils_func.py [000000] .. [d90ecf]


--- a
+++ b/src/scpanel/utils_func.py
@@ -0,0 +1,262 @@
+import warnings
+
+from typing import Dict, Optional, Tuple, Union
+
+import anndata
+import numpy as np
+import pandas as pd
+import scanpy as sc
+from anndata import AnnData
+from numpy import ndarray
+from pandas import Categorical, DataFrame
+from scipy import sparse
+from scipy.sparse import csr_matrix, spmatrix
+from sklearn.utils import resample
+
+warnings.filterwarnings("ignore")
+
+
+def preprocess(
+    adata: AnnData,
+    integrated: bool = False,
+    ct_col: Optional[str] = None,
+    y_col: Optional[str] = None,
+    pt_col: Optional[str] = None,
+    class_map: Optional[Dict[str, int]] = None,
+) -> AnnData:
+    """Standardize input data.
+
+    Parameters
+    ----------
+    adata : AnnData
+        Annotated data matrix to standardize.
+    integrated : bool, default False
+        If True, assume X is already integrated/normalized and skip
+        count detection and log-normalization.
+    ct_col : str, optional
+        Name of the ``.obs`` column holding cell-type labels; copied to ``ct``.
+    y_col : str, optional
+        Name of the ``.obs`` column holding the condition; copied to ``y``.
+    pt_col : str, optional
+        Name of the ``.obs`` column holding patient IDs; copied to ``patient_id``.
+    class_map : dict of str to int, optional
+        Mapping from condition labels to integer classes, stored in ``label``.
+
+    Returns
+    -------
+    AnnData
+        The standardized AnnData object.
+    """
+    try:
+        adata.__dict__["_raw"].__dict__["_var"] = (
+            adata.__dict__["_raw"]
+            .__dict__["_var"]
+            .rename(columns={"_index": "features"})
+        )
+    except AttributeError:
+        # adata has no .raw slot, so there is nothing to rename
+        pass
+
+    # adata.raw = adata
+    # can be recovered by: adata_raw = adata.raw.to_adata()
+
+    if not integrated:
+        # Standardize X matrix-----------------------------------------
+        ## 1. Check if X is raw count matrix
+        ### If True, make X into logcount matrix
+        if check_nonnegative_integers(adata.X):
+            adata.layers["count"] = adata.X
+            print("X is raw count matrix, adding to `count` layer......")
+            sc.pp.normalize_total(adata, target_sum=1e4)
+            sc.pp.log1p(adata)
+        elif check_nonnegative_float(adata.X):
+            adata.layers["logcount"] = adata.X
+            print("X is logcount matrix, adding to `logcount` layer......")
+        else:
+            raise TypeError("X should be either raw counts or log-normalized counts")
+
+    # Filter low-express genes-----------------------------------------------
+    sc.pp.filter_genes(adata, min_cells=1)
+
+    # Calculate HVGs-------------------------------------------------------
+    # sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0)
+
+    # Standardize Metadata-----------------------------------------
+    ## cell type as ct
+    ## condition as y
+    ## patient as patient_id
+
+    if ct_col is not None:
+        adata.obs["ct"] = adata.obs[ct_col]
+    # if ct_col is None, an existing `ct` column in .obs is assumed
+    adata.obs["ct"] = adata.obs["ct"].astype("category")
+
+    if y_col is not None:
+        adata.obs["y"] = adata.obs[y_col]
+    # if y_col is None, an existing `y` column in .obs is assumed
+    adata.obs["y"] = adata.obs["y"].astype("category")
+
+    if pt_col is not None:
+        adata.obs["patient_id"] = adata.obs[pt_col]
+    # if pt_col is None, an existing `patient_id` column in .obs is assumed
+    adata.obs["patient_id"] = adata.obs["patient_id"].astype("category")
+
+    # encode condition as an integer label
+    if class_map is not None:
+        adata.obs["label"] = adata.obs["y"].map(class_map)
+    else:
+        # no class_map given: fall back to the categorical codes of `y`
+        adata.obs["label"] = adata.obs["y"].cat.codes
+
+    #     # add HVGs in each cell type
+    #     # split each cell type into new objects.
+    #     celltypes_all = adata.obs['ct'].cat.categories.tolist()
+
+    #     alldata = {}
+    #     for celltype in celltypes_all:
+    #         #print(celltype)
+    #         adata_tmp = adata[adata.obs['ct'] == celltype,]
+    #         if adata_tmp.shape[0] > 20:
+    #             sc.pp.highly_variable_genes(adata_tmp, min_mean=0.0125, max_mean=3, min_disp=1.5)
+    #             adata_tmp = adata_tmp[:, adata_tmp.var.highly_variable]
+    #             alldata[celltype] = adata_tmp
+
+    #     celltype_hvgs = anndata.concat(alldata, join="outer").var_names.tolist()
+    #     hvgs = adata.var.highly_variable.index[adata.var.highly_variable].tolist()
+    #     hvgs_idx = adata.var_names.isin(list(set(celltype_hvgs + hvgs)))
+
+    return adata
+
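+# Illustrative usage sketch (the .obs column names and the class_map values
+# below are hypothetical, not part of this module):
+#
+#     adata = preprocess(
+#         adata,
+#         ct_col="cell_type",
+#         y_col="condition",
+#         pt_col="patient",
+#         class_map={"healthy": 0, "disease": 1},
+#     )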
+
+def get_X_y_from_ann(
+    adata: AnnData,
+    return_adj: bool = False,
+    n_neigh: int = 10,
+) -> Union[Tuple[ndarray, Categorical], Tuple[ndarray, Categorical, csr_matrix]]:
+    """Extract feature matrix X, labels y, and optionally a kNN adjacency matrix.
+
+    Parameters
+    ----------
+    adata : AnnData
+        Annotated data matrix with a ``label`` column in ``.obs``.
+    return_adj : bool, default False
+        If True, also return a kNN adjacency matrix with self-loops.
+    n_neigh : int, default 10
+        Number of neighbors for the kNN graph.
+
+    Returns
+    -------
+    X : ndarray
+        Feature matrix (cells x genes).
+    y : Categorical
+        Labels from ``adata.obs['label']``.
+    adj : csr_matrix
+        kNN adjacency matrix, only returned when ``return_adj`` is True.
+    """
+    # scale and store results in layer
+    # adata.layers['scaled'] = sc.pp.scale(adata, max_value=10, copy=True).X
+
+    # X is expected to hold scaled values
+    X = adata.to_df().values
+
+    y = adata.obs["label"].values
+
+    if return_adj:
+        if adata.n_vars >= 50:
+            sc.pp.pca(adata, n_comps=30, use_highly_variable=False)
+            sc.pl.pca_variance_ratio(adata, log=True, n_pcs=30)
+            knn = sc.Neighbors(adata)
+            knn.compute_neighbors(n_neighbors=n_neigh, n_pcs=30)
+        else:
+            print(
+                "Number of input genes < 50, use original space to get adjacency matrix..."
+            )
+            knn = sc.Neighbors(adata)
+            knn.compute_neighbors(n_neighbors=n_neigh, n_pcs=0)
+
+        # add self-loops so each cell is connected to itself
+        adj = knn.connectivities + sparse.diags([1] * adata.shape[0]).tocsr()
+
+        return X, y, adj
+
+    else:
+        return X, y
+
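+# Illustrative usage sketch (assumes `adata` went through preprocess() above,
+# so `.obs` carries a `label` column):
+#
+#     X, y = get_X_y_from_ann(adata)
+#     X, y, adj = get_X_y_from_ann(adata, return_adj=True, n_neigh=10)
+#     # adj is a sparse kNN connectivity matrix with self-loops added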
+
+def check_nonnegative_integers(X: Union[np.ndarray, sparse.spmatrix]) -> bool:
+    """Check values of X to ensure it is count data."""
+    from numbers import Integral
+
+    # sample up to 10 rows to keep the check cheap
+    n_sample = min(10, X.shape[0])
+    data = X[np.random.choice(X.shape[0], n_sample, replace=False), :]
+    data = data.todense() if sparse.issparse(data) else data
+
+    # Check no negatives
+    if np.signbit(data).any():
+        return False
+    # Check all values are integers
+    elif issubclass(data.dtype.type, Integral):
+        return True
+    elif np.any(~np.equal(np.mod(data, 1), 0)):
+        return False
+    else:
+        return True
+
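+# Quick illustration of the count check (toy arrays, not real data):
+#
+#     check_nonnegative_integers(np.array([[0, 1], [2, 3]]))      # True
+#     check_nonnegative_integers(np.array([[0.5, 1.0], [2, 3]]))  # False (non-integer)
+#     check_nonnegative_integers(np.array([[-1, 1], [2, 3]]))     # False (negative)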
+
+def check_nonnegative_float(X: Union[np.ndarray, sparse.spmatrix]) -> bool:
+    """Check values of X to ensure it is log-normalized data."""
+    # sample up to 10 rows to keep the check cheap
+    n_sample = min(10, X.shape[0])
+    data = X[np.random.choice(X.shape[0], n_sample, replace=False), :]
+    data = data.todense() if sparse.issparse(data) else data
+
+    # Check no negatives
+    if np.signbit(data).any():
+        return False
+    # Check the dtype is floating point
+    elif np.issubdtype(data.dtype, np.floating):
+        return True
+    else:
+        # neither condition holds: not log-normalized data
+        return False
+
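+# Quick illustration of the log-count check (toy arrays, not real data):
+#
+#     check_nonnegative_float(np.array([[0.0, 1.3], [2.7, 0.5]]))   # True
+#     check_nonnegative_float(np.array([[-0.1, 1.3], [2.7, 0.5]]))  # False (negative)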
+
+def compute_cell_weight(data: Union[AnnData, DataFrame]) -> ndarray:
+    """Compute per-cell weights as class_weight * patient_weight.
+
+    ``data`` is an AnnData object or a DataFrame from ``anndata.obs`` and
+    must contain ``label`` and ``patient_id`` columns.
+    """
+    if isinstance(data, anndata.AnnData):
+        data = data.obs
+
+    from sklearn.utils.class_weight import compute_sample_weight
+
+    ## class weight
+    w_c = compute_sample_weight(class_weight="balanced", y=data["label"])
+    w_c_df = pd.DataFrame({"patient_id": data["patient_id"], "w_c": w_c})
+
+    ## patient weight
+    gped = data.groupby("label")
+    w_pt_df = pd.DataFrame([])
+    for name, group in gped:
+        w_pt = compute_sample_weight(class_weight="balanced", y=group["patient_id"])
+        w_pt_df = pd.concat(
+            [
+                w_pt_df,
+                pd.DataFrame(
+                    {
+                        "label": group["label"],
+                        "patient_id": group["patient_id"],
+                        "w_pt": w_pt,
+                    }
+                ),
+            ],
+            axis=0,
+        )
+
+    ## cell_weight = class weight*patient_weight
+    w_df = w_c_df.merge(w_pt_df, left_index=True, right_index=True)
+    w_df["w"] = w_df["w_c"] * w_df["w_pt"]
+
+    ## TODO: add results to adata metadata or save to CSV?
+    # adata_train.obs = adata_train.obs.merge(w_df[['w_pt','w_c','w']], left_index=True, right_index=True)
+
+    return w_df["w"].values
+
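+# Illustrative usage sketch (assumes `.obs` carries `label` and `patient_id`,
+# e.g. after preprocess()):
+#
+#     w = compute_cell_weight(adata)  # one weight per cell
+#     # cells from rare classes or under-represented patients get larger weights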
+
+def downsample_adata(
+    adata: AnnData, downsample_size: int = 4000, random_state: int = 1
+) -> AnnData:
+    """Stratified downsampling by patient_id, y, and ct
+    (patient, disease status, and cell type)."""
+    # build the metadata column used for stratification
+    adata.obs["downsample_stratify"] = adata.obs[["patient_id", "y", "ct"]].apply(
+        lambda x: "_".join(x.dropna().astype(str)), axis=1
+    )
+    adata.obs["downsample_stratify"] = adata.obs["downsample_stratify"].astype(
+        "category"
+    )
+
+    down_index = resample(
+        adata.obs_names,
+        replace=False,
+        n_samples=downsample_size,
+        stratify=adata.obs["downsample_stratify"],
+        random_state=random_state,
+    )
+
+    adata = adata[down_index, :]
+    return adata
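+
+# Illustrative usage sketch (downsample_size must not exceed the number of
+# cells, and each stratum needs enough cells for stratified resampling):
+#
+#     adata_small = downsample_adata(adata, downsample_size=4000, random_state=1)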