"""Cell-type selection utilities for scPanel.

Splits an AnnData object into patient-level train/test sets, scores each
cell type by bootstrapped random-forest AUROC, plots the per-cell-type
scores, and extracts the selected cell type for downstream analysis.
"""

import os
import time
from collections import Counter
from typing import Dict, Tuple

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from anndata._core.anndata import AnnData
from matplotlib.axes._axes import Axes
from numpy import float64
from pandas.core.frame import DataFrame
from scipy.stats import shapiro
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from .utils_func import *  # provides get_X_y_from_ann (and friends)


def split_patients(
    adata: AnnData,
    test_pt_size: float,
    random_state: int,
    out_dir: str,
    verbose: bool,
) -> Tuple[AnnData, AnnData]:
    """Split ``adata`` into train and test sets at the patient level.

    Patients (not individual cells) are assigned to one split, stratified
    by the patient-level label ``y``, so no patient contributes cells to
    both sets.  The split membership mask, a per-split summary table, and
    both AnnData subsets are written to ``out_dir``.

    Parameters
    ----------
    adata : AnnData
        Preprocessed data; ``X`` holds log-transformed values and ``obs``
        carries ``y`` (label) and ``patient_id`` columns.
    test_pt_size : float
        Fraction of patients assigned to the test set.
    random_state : int
        Seed forwarded to ``train_test_split`` for reproducibility.
    out_dir : str
        Output directory for split data and split information files.
    verbose : bool
        If True, print the patient ids landing in each split.

    Returns
    -------
    Tuple[AnnData, AnnData]
        ``(train_adata, test_adata)`` subsets of ``adata``.
    """
    # One row per patient: basis for the patient-level split.
    pat_meta_temp = adata.obs[["y", "patient_id"]].drop_duplicates().reset_index()
    cell_meta_temp = adata.obs

    # Stratified split of patient ids (not cells) by label y.
    rest_, patient_test_id_list = train_test_split(
        pat_meta_temp.patient_id,
        test_size=test_pt_size,
        stratify=pat_meta_temp.y,
        random_state=random_state,
    )
    if verbose:
        print(
            len(patient_test_id_list),
            "patients in test set: ",
            patient_test_id_list.values,
        )
        print(len(rest_), "patients in train set: ", rest_.values)

    # Map patient membership back down to cell-level indices.
    test_idx = cell_meta_temp[
        cell_meta_temp["patient_id"].isin(patient_test_id_list)
    ].index.tolist()
    train_idx = cell_meta_temp[cell_meta_temp["patient_id"].isin(rest_)].index.tolist()

    y = adata.obs["y"]

    train_adata = adata[train_idx, :]
    test_adata = adata[test_idx, :]

    # Per-split summary: cell count per label, total cells, patient ids,
    # and number of patients per label (columns prefixed with "N_").
    cell_info_train = pd.DataFrame(
        dict(
            **Counter(y[train_idx]),
            **{"total_cells": len(train_idx), "patient_ids": [rest_.values]},
        ),
        index=["train"],
    )
    n_patient_train = (
        train_adata.obs[["y", "patient_id"]]
        .drop_duplicates()
        .groupby(["y"])["patient_id"]
        .count()
        .to_frame()
        .T
    )
    n_patient_train.columns = ["N_" + x for x in n_patient_train.columns.tolist()]
    split_info_train = pd.concat(
        [cell_info_train, n_patient_train.set_index(cell_info_train.index)], axis=1
    )

    cell_info_test = pd.DataFrame(
        dict(
            **Counter(y[test_idx]),
            **{
                "total_cells": len(test_idx),
                "patient_ids": [patient_test_id_list.values],
            },
        ),
        index=["test"],
    )
    n_patient_test = (
        test_adata.obs[["y", "patient_id"]]
        .drop_duplicates()
        .groupby(["y"])["patient_id"]
        .count()
        .to_frame()
        .T
    )
    n_patient_test.columns = ["N_" + x for x in n_patient_test.columns.tolist()]
    split_info_test = pd.concat(
        [cell_info_test, n_patient_test.set_index(cell_info_test.index)], axis=1
    )

    split_info = pd.concat([split_info_train, split_info_test], axis=0)

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Boolean mask over adata.obs_names (True = train), plus summary CSV.
    np.savetxt(
        f"{out_dir}/train_test_idx.txt", adata.obs_names.isin(train_idx), fmt="%s"
    )
    split_info.to_csv(f"{out_dir}/split_patient_train_test_info.csv")

    train_adata.write_h5ad(f"{out_dir}/processed_rna_assay_train.h5ad")
    test_adata.write_h5ad(f"{out_dir}/processed_rna_assay_test.h5ad")

    return train_adata, test_adata


def cal_bootstrap_score(
    adata: AnnData,
    out_dir: str,
    sample_n_cell: int,
    n_iterations: int = 100,
    n_threads: int = 16,
    show_progress: bool = True,
    verbose: bool = False,
) -> DataFrame:
    """Bootstrap classification performance for a single cell type.

    Each iteration draws a class-balanced sample (``sample_n_cell`` cells
    per patient from each class), splits patients into train/test, fits a
    random forest, and records balanced accuracy and AUROC.

    Parameters
    ----------
    adata : AnnData
        Cells of one cell type; ``obs`` carries ``ct``, ``label`` (0/1)
        and ``patient_id`` columns.
    out_dir : str
        Directory for per-iteration split artifacts (``tmp/split_<i>``).
    sample_n_cell : int
        Cells sampled per patient per class in each iteration.
    n_iterations : int, default 100
        Number of bootstrap iterations.
    n_threads : int, default 16
        ``n_jobs`` for the random forest.
    show_progress : bool, default True
        Print a start message and a note every 10 iterations.
    verbose : bool, default False
        Forwarded to ``split_patients``; also prints each iteration.

    Returns
    -------
    DataFrame
        One row per iteration with columns ``bACC``, ``AUC``, ``celltype``.
    """
    celltype = adata.obs["ct"].unique()[0]

    # Accumulates one row of statistics per bootstrap iteration.
    bootstrapped_stats = pd.DataFrame()

    if show_progress:
        print(celltype, "start calculating...")

    # The class subsets do not change across iterations: copy them once
    # instead of re-slicing inside the loop (pure performance hoist).
    adata_0 = adata[adata.obs.label == 0].copy()
    adata_1 = adata[adata.obs.label == 1].copy()

    # Each loop iteration is a single bootstrap resample and model fit.
    for i in range(n_iterations):
        if verbose:
            print("Starting iteration #", i)

        # Balanced resample: the same number of cells per patient from
        # each class, seeded by the iteration index for reproducibility.
        adata_0_i_index = (
            adata_0.obs.groupby("patient_id")
            .sample(n=sample_n_cell, replace=False, random_state=i)
            .index
        )
        adata_0_i = adata_0[adata_0_i_index]
        adata_1_i_index = (
            adata_1.obs.groupby("patient_id")
            .sample(n=sample_n_cell, replace=False, random_state=i)
            .index
        )
        adata_1_i = adata_1[adata_1_i_index]

        adata_i = anndata.concat([adata_0_i, adata_1_i])
        adata_i.obs_names_make_unique()

        adata_train_i, adata_test_i = split_patients(
            adata_i,
            random_state=i,
            test_pt_size=0.4,
            out_dir=f"{out_dir}/tmp/split_{i}",
            verbose=verbose,
        )

        X_i_train, y_i_train = get_X_y_from_ann(adata_train_i)
        X_i_test, y_i_test = get_X_y_from_ann(adata_test_i)

        # Replace NaNs so the estimator never sees missing values.
        X_i_train = np.nan_to_num(X_i_train)
        y_i_train = np.nan_to_num(y_i_train)
        X_i_test = np.nan_to_num(X_i_test)
        y_i_test = np.nan_to_num(y_i_test)

        rf = RandomForestClassifier(
            n_jobs=n_threads, class_weight="balanced", random_state=i
        )
        rf.fit(X_i_train, y_i_train)

        y_i_pred = rf.predict(X_i_test)
        y_i_pred_score = rf.predict_proba(X_i_test)

        bACC = balanced_accuracy_score(y_i_test, y_i_pred)
        # AUROC from the positive-class probability column.
        AUC = roc_auc_score(y_i_test, y_i_pred_score[:, 1])

        bootstrapped_stats_i = pd.DataFrame(
            data=dict(bACC=bACC, AUC=AUC, celltype=celltype), index=[i]
        )
        bootstrapped_stats = pd.concat(objs=[bootstrapped_stats, bootstrapped_stats_i])

        # Logical `and` (was bitwise `&` on booleans — same result here,
        # but `and` is the correct idiom and short-circuits).
        if show_progress and (i + 1) % 10 == 0:
            print("n_iterations", i + 1, " is done")

    return bootstrapped_stats


def custom_metrics(grouping: Tuple[str, DataFrame], metric: str) -> Dict[str, float64]:
    """Summarise one ``groupby`` group with a normality-aware average.

    A Shapiro–Wilk p-value <= 0.05 suggests a non-normal distribution, so
    the median is used; otherwise the mean.

    Parameters
    ----------
    grouping : Tuple[str, DataFrame]
        One ``(group_label, group_frame)`` pair from a pandas groupby.
    metric : str
        Column of the group frame to summarise (e.g. ``"AUC"``).

    Returns
    -------
    Dict[str, float64]
        ``{group_label: summary_value}``.
    """
    group_label, df = grouping

    if shapiro(df[metric])[1] <= 0.05:
        return {group_label: df[metric].median()}
    return {group_label: df[metric].mean()}


def cell_type_score(
    adata_train_dict: Dict[str, AnnData],
    out_dir: str,
    ncpus: int,
    sample_n_cell: int,
    n_iterations: int = 100,
    verbose: bool = False,
) -> Tuple[DataFrame, DataFrame]:
    """Score every cell type by bootstrapped random-forest AUROC.

    Runs ``cal_bootstrap_score`` for each cell type, then summarises the
    per-iteration AUROCs per cell type (median or mean via
    ``custom_metrics``) and writes both tables to ``out_dir``.

    Parameters
    ----------
    adata_train_dict : Dict[str, AnnData]
        Mapping of cell-type name to its training AnnData.
    out_dir : str
        Output directory for the CSV results.
    ncpus : int
        Threads passed to the random forest.
    sample_n_cell : int
        Cells sampled per patient per class in each bootstrap iteration.
    n_iterations : int, default 100
        Bootstrap iterations per cell type.
    verbose : bool, default False
        Forwarded to ``cal_bootstrap_score``.

    Returns
    -------
    Tuple[DataFrame, DataFrame]
        ``(result_AUC, bootstrapped_stats_all)``: per-cell-type summary
        (sorted ascending by AUC, with ``n_cell`` counts) and the raw
        per-iteration statistics.
    """
    bootstrapped_stats_all = pd.DataFrame()
    celltypes_all = adata_train_dict.keys()

    for celltype in tqdm(celltypes_all):
        adata = adata_train_dict[celltype]

        bootstrapped_stats_celltype = cal_bootstrap_score(
            adata,
            n_iterations=n_iterations,
            sample_n_cell=sample_n_cell,
            n_threads=ncpus,
            show_progress=False,
            out_dir=out_dir,
            verbose=verbose,
        )

        bootstrapped_stats_all = pd.concat(
            objs=[bootstrapped_stats_all, bootstrapped_stats_celltype]
        )
        print(celltype, " DONE")

    # Collapse per-iteration AUROCs into one summary value per cell type.
    grouping = bootstrapped_stats_all.groupby("celltype")

    AUC_dict = dict()
    for group in grouping:
        AUC_dict.update(custom_metrics(group, "AUC"))

    result_AUC = pd.DataFrame.from_dict(AUC_dict, orient="index", columns=["AUC"])

    # Attach the number of cells of each cell type as a column.
    n_cell_df = pd.DataFrame.from_dict(
        {k: len(v) for k, v in adata_train_dict.items()},
        orient="index",
        columns=["n_cell"],
    )
    result_AUC = result_AUC.merge(n_cell_df, left_index=True, right_index=True)
    result_AUC.index = result_AUC.index.set_names(["celltype"])
    result_AUC = result_AUC.reset_index().sort_values(by=["AUC"])

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    result_AUC.to_csv(f"{out_dir}/celltype_AUC.csv")
    bootstrapped_stats_all.to_csv(f"{out_dir}/celltype_bootstrap_stats_all.csv")

    return result_AUC, bootstrapped_stats_all


def plot_cell_type_score(
    AUC: DataFrame, AUC_all: DataFrame, width: int = 4, height: int = 5
) -> Axes:
    """Draw horizontal boxplots of bootstrapped AUROCs per cell type.

    Parameters
    ----------
    AUC : DataFrame
        Summary table from ``cell_type_score`` with ``celltype``, ``AUC``
        and ``n_cell`` columns (sorted ascending by AUC).
    AUC_all : DataFrame
        Raw per-iteration statistics with ``celltype`` and ``AUC`` columns.
    width, height : int
        Figure size in inches.

    Returns
    -------
    Axes
        The matplotlib axes holding the plot.
    """
    AUC = AUC.set_index("celltype")

    # Long to wide table, columns ordered best-first (reversed summary order).
    pData = AUC_all.pivot(columns="celltype", values="AUC")
    pData = pData[AUC.index[::-1]]

    fig, axes = plt.subplots(figsize=(width, height), dpi=200)

    axes = sns.boxplot(data=pData, orient="h", ax=axes)
    axes.set_xlabel("AUROC")
    axes.set_ylabel("Disease Responsive Cell Types")

    # Append the cell count to each cell-type tick label.
    axes.set_yticklabels(
        [y + " (" + str(AUC.loc[y, "n_cell"]) + ")" for y in pData.columns]
    )

    axes.set_xlim([0.5, 1.05])

    axes.spines[["right", "top"]].set_visible(False)

    # Annotate each box with its summary AUROC, placed past the box maximum.
    yticks_dict = {k: v for k, v in zip(pData.columns, plt.yticks()[0])}
    for y, x in pData.max().items():
        s = AUC.loc[y, "AUC"]
        plt.text(
            x + 0.1,
            yticks_dict[y],
            f"{s:.3f}",
            horizontalalignment="center",
            verticalalignment="center",
        )
    return axes


def select_celltype(
    adata_train_dict: Dict[str, AnnData], out_dir: str, celltype_selected: str
) -> AnnData:
    """Extract the chosen cell type's training data for downstream steps.

    Records the selection in ``<out_dir>/selected_celltype.txt``, removes
    genes expressed in no cell, and returns a copy of the subset.

    Parameters
    ----------
    adata_train_dict : Dict[str, AnnData]
        Mapping of cell-type name to its training AnnData.
    out_dir : str
        Directory for the selection record file.
    celltype_selected : str
        Name of the cell type to select.

    Returns
    -------
    AnnData
        Copy of the selected cell type's AnnData with all-zero genes removed.
    """
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    with open(f"{out_dir}/selected_celltype.txt", "w") as f:
        for item in celltype_selected:
            f.write("%s" % item)

    adata_train = adata_train_dict[celltype_selected]

    # Drop genes detected in zero cells of this subset (in place).
    sc.pp.filter_genes(adata_train, min_cells=1)

    # BUG FIX: the original `print("Selecting ", *celltype_selected, "...")`
    # unpacked the string into individual characters.
    print("Selecting", celltype_selected, "...")
    return adata_train.copy()