--- a +++ b/bin/train_model.py @@ -0,0 +1,537 @@ +""" +Code to train a model +""" + +import os +import sys +import logging +import argparse +import copy +import functools +import itertools + +import numpy as np +import pandas as pd +import scipy.spatial +import scanpy as sc + +import matplotlib.pyplot as plt +from skorch.helper import predefined_split + +import torch +import torch.nn as nn +import torch.nn.functional as F +import skorch +import skorch.helper + +torch.backends.cudnn.deterministic = True # For reproducibility +torch.backends.cudnn.benchmark = False + +SRC_DIR = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel" +) +assert os.path.isdir(SRC_DIR) +sys.path.append(SRC_DIR) + +MODELS_DIR = os.path.join(SRC_DIR, "models") +assert os.path.isdir(MODELS_DIR) +sys.path.append(MODELS_DIR) + +import sc_data_loaders +import adata_utils +import model_utils +import autoencoders +import loss_functions +import layers +import activations +import plot_utils +import utils +import metrics +import interpretation + +logging.basicConfig(level=logging.INFO) + +OPTIMIZER_DICT = { + "adam": torch.optim.Adam, + "rmsprop": torch.optim.RMSprop, +} + + +def build_parser(): + """Build argument parser""" + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + input_group = parser.add_mutually_exclusive_group(required=True) + input_group.add_argument( + "--data", "-d", type=str, nargs="*", help="Data files to train on", + ) + input_group.add_argument( + "--snareseq", + action="store_true", + help="Data in SNAREseq format, use custom data loading logic for separated RNA ATC files", + ) + input_group.add_argument( + "--shareseq", + nargs="+", + type=str, + choices=["lung", "skin", "brain"], + help="Load in the given SHAREseq datasets", + ) + parser.add_argument( + "--nofilter", + action="store_true", + help="Whether or not to perform filtering (only applies with --data argument)", + ) + parser.add_argument( + "--linear", + action="store_true", + help="Do clustering data splitting in linear instead of log space", + ) + parser.add_argument( + "--clustermethod", + type=str, + choices=["leiden", "louvain"], + default="leiden", + help="Clustering method to determine data splits", + ) + parser.add_argument( + "--validcluster", type=int, default=0, help="Cluster ID to use as valid cluster" + ) + parser.add_argument( + "--testcluster", type=int, default=1, help="Cluster ID to use as test cluster" + ) + parser.add_argument( + "--outdir", "-o", required=True, type=str, help="Directory to output to" + ) + parser.add_argument( + "--naive", + "-n", + action="store_true", + help="Use a naive model instead of lego model", + ) + parser.add_argument( + "--hidden", type=int, nargs="*", default=[16], help="Hidden dimensions" + ) + parser.add_argument( + "--pretrain", + type=str, + default="", + help="params.pt file to use to warm initialize the model (instead of starting from scratch)", + ) + parser.add_argument( + "--lossweight", + type=float, + nargs="*", + default=[1.33], + help="Relative loss weight", + ) + parser.add_argument( + "--optim", + type=str, + default="adam", + choices=OPTIMIZER_DICT.keys(), + help="Optimizer to use", + ) + parser.add_argument( + "--lr", "-l", type=float, default=[0.01], nargs="*", help="Learning rate" + ) + parser.add_argument( + "--batchsize", "-b", type=int, nargs="*", default=[512], help="Batch size" + ) + parser.add_argument( + "--earlystop", type=int, default=25, help="Early stopping after N epochs" + ) + parser.add_argument( + "--seed", type=int, nargs="*", default=[182822], help="Random seed to use" + ) + parser.add_argument("--device", default=0, type=int, help="Device to train on") + parser.add_argument( + "--ext", + type=str, + choices=["png", "pdf", "jpg"], + default="pdf", + help="Output format for plots", + ) + return parser + + +def plot_loss_history(history, fname: str): + """Create a plot of train valid loss""" + fig, ax = plt.subplots(dpi=300) + ax.plot( + np.arange(len(history)), history[:, "train_loss"], label="Train", + ) + ax.plot( + np.arange(len(history)), history[:, "valid_loss"], label="Valid", + ) + ax.legend() + ax.set( + xlabel="Epoch", ylabel="Loss", + ) + fig.savefig(fname) + return fig + + +def main(): + """Run the script""" + parser = build_parser() + args = parser.parse_args() + args.outdir = os.path.abspath(args.outdir) + + if not os.path.isdir(os.path.dirname(args.outdir)): + os.makedirs(os.path.dirname(args.outdir)) + + # Specify output log file + logger = logging.getLogger() + fh = logging.FileHandler(f"{args.outdir}_training.log", "w") + fh.setLevel(logging.INFO) + logger.addHandler(fh) + + # Log parameters and pytorch version + if torch.cuda.is_available(): + logging.info(f"PyTorch CUDA version: {torch.version.cuda}") + for arg in vars(args): + logging.info(f"Parameter {arg}: {getattr(args, arg)}") + + # Borrow parameters + logging.info("Reading RNA data") + if args.snareseq: + rna_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_RNA_DATA_KWARGS) + elif args.shareseq: + logging.info(f"Loading in SHAREseq RNA data for: {args.shareseq}") + rna_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_RNA_DATA_KWARGS) + rna_data_kwargs["fname"] = None + rna_data_kwargs["reader"] = None + rna_data_kwargs["cell_info"] = None + rna_data_kwargs["gene_info"] = None + rna_data_kwargs["transpose"] = False + # Load in the datasets + shareseq_rna_adatas = [] + for tissuetype in args.shareseq: + shareseq_rna_adatas.append( + adata_utils.load_shareseq_data( + tissuetype, + dirname="/data/wukevin/commonspace_data/GSE140203_SHAREseq", + mode="RNA", + ) + ) + shareseq_rna_adata = shareseq_rna_adatas[0] + if len(shareseq_rna_adatas) > 1: + shareseq_rna_adata = shareseq_rna_adata.concatenate( + *shareseq_rna_adatas[1:], + join="inner", + batch_key="tissue", + batch_categories=args.shareseq, + ) + rna_data_kwargs["raw_adata"] = shareseq_rna_adata + else: + rna_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_RNA_DATA_KWARGS) + rna_data_kwargs["fname"] = args.data + if args.nofilter: + rna_data_kwargs = { + k: v for k, v in rna_data_kwargs.items() if not k.startswith("filt_") + } + rna_data_kwargs["data_split_by_cluster_log"] = not args.linear + rna_data_kwargs["data_split_by_cluster"] = args.clustermethod + + sc_rna_dataset = sc_data_loaders.SingleCellDataset( + valid_cluster_id=args.validcluster, + test_cluster_id=args.testcluster, + **rna_data_kwargs, + ) + + sc_rna_train_dataset = sc_data_loaders.SingleCellDatasetSplit( + sc_rna_dataset, split="train", + ) + sc_rna_valid_dataset = sc_data_loaders.SingleCellDatasetSplit( + sc_rna_dataset, split="valid", + ) + sc_rna_test_dataset = sc_data_loaders.SingleCellDatasetSplit( + sc_rna_dataset, split="test", + ) + + # ATAC + logging.info("Aggregating ATAC clusters") + if args.snareseq: + atac_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_ATAC_DATA_KWARGS) + elif args.shareseq: + logging.info(f"Loading in SHAREseq ATAC data for {args.shareseq}") + atac_data_kwargs = copy.copy(sc_data_loaders.SNARESEQ_ATAC_DATA_KWARGS) + atac_data_kwargs["reader"] = None + atac_data_kwargs["fname"] = None + atac_data_kwargs["cell_info"] = None + atac_data_kwargs["gene_info"] = None + atac_data_kwargs["transpose"] = False + atac_adatas = [] + for tissuetype in args.shareseq: + atac_adatas.append( + adata_utils.load_shareseq_data( + tissuetype, + dirname="/data/wukevin/commonspace_data/GSE140203_SHAREseq", + mode="ATAC", + ) + ) + atac_bins = [a.var_names for a in atac_adatas] + if len(atac_adatas) > 1: + atac_bins_harmonized = sc_data_loaders.harmonize_atac_intervals(*atac_bins) + atac_adatas = [ + sc_data_loaders.repool_atac_bins(a, atac_bins_harmonized) + for a in atac_adatas + ] + shareseq_atac_adata = atac_adatas[0] + if len(atac_adatas) > 1: + shareseq_atac_adata = shareseq_atac_adata.concatenate( + *atac_adatas[1:], + join="inner", + batch_key="tissue", + batch_categories=args.shareseq, + ) + atac_data_kwargs["raw_adata"] = shareseq_atac_adata + else: + atac_parsed = [ + utils.sc_read_10x_h5_ft_type(fname, "Peaks") for fname in args.data + ] + if len(atac_parsed) > 1: + atac_bins = sc_data_loaders.harmonize_atac_intervals( + atac_parsed[0].var_names, atac_parsed[1].var_names + ) + for bins in atac_parsed[2:]: + atac_bins = sc_data_loaders.harmonize_atac_intervals( + atac_bins, bins.var_names + ) + logging.info(f"Aggregated {len(atac_bins)} bins") + else: + atac_bins = list(atac_parsed[0].var_names) + + atac_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_ATAC_DATA_KWARGS) + atac_data_kwargs["fname"] = rna_data_kwargs["fname"] + atac_data_kwargs["pool_genomic_interval"] = 0 # Do not pool + atac_data_kwargs["reader"] = functools.partial( + utils.sc_read_multi_files, + reader=lambda x: sc_data_loaders.repool_atac_bins( + utils.sc_read_10x_h5_ft_type(x, "Peaks"), atac_bins, + ), + ) + atac_data_kwargs["cluster_res"] = 0 # Do not bother clustering ATAC data + + sc_atac_dataset = sc_data_loaders.SingleCellDataset( + predefined_split=sc_rna_dataset, **atac_data_kwargs + ) + sc_atac_train_dataset = sc_data_loaders.SingleCellDatasetSplit( + sc_atac_dataset, split="train", + ) + sc_atac_valid_dataset = sc_data_loaders.SingleCellDatasetSplit( + sc_atac_dataset, split="valid", + ) + sc_atac_test_dataset = sc_data_loaders.SingleCellDatasetSplit( + sc_atac_dataset, split="test", + ) + + sc_dual_train_dataset = sc_data_loaders.PairedDataset( + sc_rna_train_dataset, sc_atac_train_dataset, flat_mode=True, + ) + sc_dual_valid_dataset = sc_data_loaders.PairedDataset( + sc_rna_valid_dataset, sc_atac_valid_dataset, flat_mode=True, + ) + sc_dual_test_dataset = sc_data_loaders.PairedDataset( + sc_rna_test_dataset, sc_atac_test_dataset, flat_mode=True, + ) + sc_dual_full_dataset = sc_data_loaders.PairedDataset( + sc_rna_dataset, sc_atac_dataset, flat_mode=True, + ) + + # Model + param_combos = list( + itertools.product( + args.hidden, args.lossweight, args.lr, args.batchsize, args.seed + ) + ) + for h_dim, lw, lr, bs, rand_seed in param_combos: + outdir_name = ( + f"{args.outdir}_hidden_{h_dim}_lossweight_{lw}_lr_{lr}_batchsize_{bs}_seed_{rand_seed}" + if len(param_combos) > 1 + else args.outdir + ) + if not os.path.isdir(outdir_name): + assert not os.path.exists(outdir_name) + os.makedirs(outdir_name) + assert os.path.isdir(outdir_name) + with open(os.path.join(outdir_name, "rna_genes.txt"), "w") as sink: + for gene in sc_rna_dataset.data_raw.var_names: + sink.write(gene + "\n") + with open(os.path.join(outdir_name, "atac_bins.txt"), "w") as sink: + for atac_bin in sc_atac_dataset.data_raw.var_names: + sink.write(atac_bin + "\n") + + # Write dataset + ### Full + sc_rna_dataset.size_norm_counts.write_h5ad( + os.path.join(outdir_name, "full_rna.h5ad") + ) + sc_rna_dataset.size_norm_log_counts.write_h5ad( + os.path.join(outdir_name, "full_rna_log.h5ad") + ) + sc_atac_dataset.data_raw.write_h5ad(os.path.join(outdir_name, "full_atac.h5ad")) + ### Train + sc_rna_train_dataset.size_norm_counts.write_h5ad( + os.path.join(outdir_name, "train_rna.h5ad") + ) + sc_atac_train_dataset.data_raw.write_h5ad( + os.path.join(outdir_name, "train_atac.h5ad") + ) + ### Valid + sc_rna_valid_dataset.size_norm_counts.write_h5ad( + os.path.join(outdir_name, "valid_rna.h5ad") + ) + sc_atac_valid_dataset.data_raw.write_h5ad( + os.path.join(outdir_name, "valid_atac.h5ad") + ) + ### Test + sc_rna_test_dataset.size_norm_counts.write_h5ad( + os.path.join(outdir_name, "truth_rna.h5ad") + ) + sc_atac_dataset.data_raw.write_h5ad(os.path.join(outdir_name, "full_atac.h5ad")) + sc_atac_test_dataset.data_raw.write_h5ad( + os.path.join(outdir_name, "truth_atac.h5ad") + ) + + # Instantiate and train model + model_class = ( + autoencoders.NaiveSplicedAutoEncoder + if args.naive + else autoencoders.AssymSplicedAutoEncoder + ) + spliced_net = autoencoders.SplicedAutoEncoderSkorchNet( + module=model_class, + module__hidden_dim=h_dim, # Based on hyperparam tuning + module__input_dim1=sc_rna_dataset.data_raw.shape[1], + module__input_dim2=sc_atac_dataset.get_per_chrom_feature_count(), + module__final_activations1=[ + activations.Exp(), + activations.ClippedSoftplus(), + ], + module__final_activations2=nn.Sigmoid(), + module__flat_mode=True, + module__seed=rand_seed, + lr=lr, # Based on hyperparam tuning + criterion=loss_functions.QuadLoss, + criterion__loss2=loss_functions.BCELoss, # handle output of encoded layer + criterion__loss2_weight=lw, # numerically balance the two losses with different magnitudes + criterion__record_history=True, + optimizer=OPTIMIZER_DICT[args.optim], + iterator_train__shuffle=True, + device=utils.get_device(args.device), + batch_size=bs, # Based on hyperparam tuning + max_epochs=500, + callbacks=[ + skorch.callbacks.EarlyStopping(patience=args.earlystop), + skorch.callbacks.LRScheduler( + policy=torch.optim.lr_scheduler.ReduceLROnPlateau, + **model_utils.REDUCE_LR_ON_PLATEAU_PARAMS, + ), + skorch.callbacks.GradientNormClipping(gradient_clip_value=5), + skorch.callbacks.Checkpoint( + dirname=outdir_name, fn_prefix="net_", monitor="valid_loss_best", + ), + ], + train_split=skorch.helper.predefined_split(sc_dual_valid_dataset), + iterator_train__num_workers=8, + iterator_valid__num_workers=8, + ) + if args.pretrain: + # Load in the warm start parameters + spliced_net.load_params(f_params=args.pretrain) + spliced_net.partial_fit(sc_dual_train_dataset, y=None) + else: + spliced_net.fit(sc_dual_train_dataset, y=None) + + fig = plot_loss_history( + spliced_net.history, os.path.join(outdir_name, f"loss.{args.ext}") + ) + plt.close(fig) + + logging.info("Evaluating on test set") + logging.info("Evaluating RNA > RNA") + sc_rna_test_preds = spliced_net.translate_1_to_1(sc_dual_test_dataset) + sc_rna_test_preds_anndata = sc.AnnData( + sc_rna_test_preds, + var=sc_rna_test_dataset.data_raw.var, + obs=sc_rna_test_dataset.data_raw.obs, + ) + sc_rna_test_preds_anndata.write_h5ad( + os.path.join(outdir_name, "rna_rna_test_preds.h5ad") + ) + fig = plot_utils.plot_scatter_with_r( + sc_rna_test_dataset.size_norm_counts.X, + sc_rna_test_preds, + one_to_one=True, + logscale=True, + density_heatmap=True, + title="RNA > RNA (test set)", + fname=os.path.join(outdir_name, f"rna_rna_scatter_log.{args.ext}"), + ) + plt.close(fig) + + logging.info("Evaluating ATAC > ATAC") + sc_atac_test_preds = spliced_net.translate_2_to_2(sc_dual_test_dataset) + sc_atac_test_preds_anndata = sc.AnnData( + sc_atac_test_preds, + var=sc_atac_test_dataset.data_raw.var, + obs=sc_atac_test_dataset.data_raw.obs, + ) + sc_atac_test_preds_anndata.write_h5ad( + os.path.join(outdir_name, "atac_atac_test_preds.h5ad") + ) + fig = plot_utils.plot_auroc( + sc_atac_test_dataset.data_raw.X, + sc_atac_test_preds, + title_prefix="ATAC > ATAC", + fname=os.path.join(outdir_name, f"atac_atac_auroc.{args.ext}"), + ) + plt.close(fig) + + logging.info("Evaluating ATAC > RNA") + sc_atac_rna_test_preds = spliced_net.translate_2_to_1(sc_dual_test_dataset) + sc_atac_rna_test_preds_anndata = sc.AnnData( + sc_atac_rna_test_preds, + var=sc_rna_test_dataset.data_raw.var, + obs=sc_rna_test_dataset.data_raw.obs, + ) + sc_atac_rna_test_preds_anndata.write_h5ad( + os.path.join(outdir_name, "atac_rna_test_preds.h5ad") + ) + fig = plot_utils.plot_scatter_with_r( + sc_rna_test_dataset.size_norm_counts.X, + sc_atac_rna_test_preds, + one_to_one=True, + logscale=True, + density_heatmap=True, + title="ATAC > RNA (test set)", + fname=os.path.join(outdir_name, f"atac_rna_scatter_log.{args.ext}"), + ) + plt.close(fig) + + logging.info("Evaluating RNA > ATAC") + sc_rna_atac_test_preds = spliced_net.translate_1_to_2(sc_dual_test_dataset) + sc_rna_atac_test_preds_anndata = sc.AnnData( + sc_rna_atac_test_preds, + var=sc_atac_test_dataset.data_raw.var, + obs=sc_atac_test_dataset.data_raw.obs, + ) + sc_rna_atac_test_preds_anndata.write_h5ad( + os.path.join(outdir_name, "rna_atac_test_preds.h5ad") + ) + fig = plot_utils.plot_auroc( + sc_atac_test_dataset.data_raw.X, + sc_rna_atac_test_preds, + title_prefix="RNA > ATAC", + fname=os.path.join(outdir_name, f"rna_atac_auroc.{args.ext}"), + ) + plt.close(fig) + + del spliced_net + + +if __name__ == "__main__": + main()