Switch to side-by-side view

--- a
+++ b/src/scpanel/split_patient.py
@@ -0,0 +1,118 @@
+import pandas as pd
+
+# from collections import Counter
+# import os
+import scanpy as sc
+from sklearn.model_selection import train_test_split
+from anndata._core.anndata import AnnData
+from typing import Dict, Tuple
+
+# import numpy as np
+
+
+def split_train_test(
+    adata: AnnData,
+    out_dir: str,
+    min_cells: int=20,
+    min_samples: int=3,
+    test_pt_size: float=0.2,
+    random_state: int=3467,
+    verbose: int=0,
+) -> Tuple[Dict[str, AnnData], Dict[str, AnnData]]:
+    # split dataset by cell type-----------------------------
+    # for each cell type,
+    ## filter patients with # of cells < 20
+    ## if # of patients in each group >= 3:
+    ## Dowsample and upsample(SMOTE) # of cells in each patient to Median value.
+
+    adata_train_dict = {}
+    adata_test_dict = {}
+    adata_train_meta = pd.DataFrame(columns=adata.obs.columns)
+    adata_test_meta = pd.DataFrame(columns=adata.obs.columns)
+
+    celltypes_all = adata.obs["ct"].astype("category").cat.categories.tolist()
+    celltypes_exclude = []
+
+    for ct in celltypes_all:
+        if verbose > 0:
+            print("Start working on", ct)
+
+        adata_ct = adata[adata.obs["ct"] == ct,]
+        # Number of cells in each patient
+        n_cell_pt = adata_ct.obs.groupby(
+            ["y", "patient_id"], observed=True, as_index=False
+        ).size()
+        # Remove paients with cells less than min_cells
+        pt_keep = n_cell_pt.patient_id[n_cell_pt["size"] >= min_cells].tolist()
+
+        # Cell types with 0 patient has cells >= min_cells
+        if len(pt_keep) == 0:
+            # print(ct, f': all patients has < {min_cells} cells')
+            celltypes_exclude.append(ct)
+            continue
+
+        adata_ct = adata_ct[adata_ct.obs["patient_id"].isin(pt_keep),]
+
+        n_cell_pt = adata_ct.obs.groupby(
+            ["y", "patient_id"], observed=True, as_index=False
+        ).size()
+
+        # Skip cell types with less than 3 patients in at least one condition
+        if (n_cell_pt.y.nunique() < 2) | (
+            (n_cell_pt.y.value_counts() < min_samples).any()
+        ):
+            # print(ct, 'has less than 3 patients in condition')
+            celltypes_exclude.append(ct)
+            continue
+
+        # Remove lowly-expressed genes
+        sc.pp.filter_genes(adata_ct, min_cells=(adata_ct.n_obs) * 0.1)
+        sc.pp.highly_variable_genes(adata_ct, n_top_genes=2000)
+        adata_ct = adata_ct[:, adata_ct.var.highly_variable]
+
+        # Split data into train and test set--------------------------
+        pat_meta_temp = n_cell_pt
+        cell_meta_temp = adata_ct.obs
+
+        # split data at patient-level
+        rest_, patient_test_id_list = train_test_split(
+            pat_meta_temp.patient_id,
+            test_size=test_pt_size,
+            stratify=pat_meta_temp.y,
+            random_state=random_state,
+        )
+
+        # print(len(patient_test_id_list),'patients in test set')
+        # print(len(rest_), 'patients remaining...')
+
+        # retrieve cell-level index for train and test set
+        test_idx = cell_meta_temp[
+            cell_meta_temp["patient_id"].isin(patient_test_id_list)
+        ].index.tolist()
+        train_idx = cell_meta_temp[
+            cell_meta_temp["patient_id"].isin(rest_)
+        ].index.tolist()
+
+        # get adata object for train and test
+        train_adata = adata_ct[train_idx, :]
+        test_adata = adata_ct[test_idx, :]
+
+        # Scale training data
+        sc.pp.scale(train_adata, max_value=10)
+        scaler_stat = train_adata.var[["mean", "std"]]
+        train_adata.var = scaler_stat
+
+        adata_train_dict[ct] = train_adata
+        adata_test_dict[ct] = test_adata
+
+        # Summarize metadata
+        adata_train_meta = pd.concat([adata_train_meta, train_adata.obs])
+        adata_test_meta = pd.concat([adata_test_meta, test_adata.obs])
+
+        adata_train_meta.to_csv(f"{out_dir}/Meta_all_train.csv", index=True)
+        adata_test_meta.to_csv(f"{out_dir}/Meta_all_test.csv", index=True)
+
+    print("Following cell types are excluded since no enough cells/patients:")
+    print(*celltypes_exclude, sep=", ")
+
+    return adata_train_dict, adata_test_dict