--- a +++ b/bin/combine_plaintext_atac_files.py @@ -0,0 +1,149 @@ +""" +Code to combine a bunch of plaintext RNA files + +Takes as input a series of filename *prefixes* + +Example usage: +python combine_plaintext_atac_files.py GSM4119513 GSM4119514 GSM4119515 GSM4119516 GSM4119517 GSM4119518 GSM4119519 -o output.h5ad +""" + +import os +import sys +import argparse +from typing import * +import multiprocessing +import functools +import logging +import glob +import gzip + +logging.basicConfig(level=logging.INFO) + +import pandas as pd +import scipy +import scanpy as sc +import anndata + +SRC_DIR = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel", +) +assert os.path.isdir(SRC_DIR) +sys.path.append(SRC_DIR) +import sc_data_loaders + + +def separate_trio_files(trio: Tuple[str, str, str]) -> Tuple[str, str, str]: + """ + Organize the trio of files such that they are in the order: + barcodes, peaks, matrix + """ + barcodes_files = [f for f in trio if "barcodes" in f] + assert len(barcodes_files) == 1 + peaks_files = [f for f in trio if "peaks" in f] + assert len(peaks_files) == 1 + mat_files = [f for f in trio if "matrix" in f] + assert len(mat_files) == 1 + return barcodes_files.pop(), peaks_files.pop(), mat_files.pop() + + +def read_barcodes(fname: str) -> List[str]: + """Read the barcodes file""" + opener = gzip.open if fname.endswith(".gz") else open + with opener(fname) as source: + retval = [l.strip() for l in source] + retval = [l.decode() if isinstance(l, bytes) else l for l in retval] + return retval + + +def read_peaks(fname: str) -> List[str]: + """Read the peaks file""" + opener = gzip.open if fname.endswith(".gz") else open + with opener(fname) as source: + tokens = [l.strip().decode() for l in source if l] + tokens = [l.decode() if isinstance(l, bytes) else l for l in tokens] + return tokens + + +def read_prefix(prefix: str) -> sc.AnnData: + """ + Helper function for reading in a prefix + """ + matches = glob.glob(prefix + "*") + assert len(matches) == 3, f"Got unexpected matches with prefix {prefix}" + + barcodes_file, peaks_file, mat_file = separate_trio_files(matches) + barcodes = read_barcodes(barcodes_file) + logging.info(f"Read {len(barcodes)} barcodes from {barcodes_file}") + + peaks = read_peaks(peaks_file) + logging.info(f"Read {len(peaks)} peaks from {peaks_file}") + + adata = sc.AnnData( + scipy.sparse.csr_matrix(scipy.io.mmread(mat_file).T), + obs=pd.DataFrame(index=barcodes), + var=pd.DataFrame(index=peaks), + ) + return adata + + +def build_parser(): + """Build commandline argument parser""" + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "prefix", + nargs="*", + type=str, + help="File prefixes denoting the files to combine", + ) + parser.add_argument( + "--output", "-o", type=str, required=True, help="Output file to write to" + ) + parser.add_argument( + "--threads", "-t", type=int, default=int(multiprocessing.cpu_count() / 2) + ) + return parser + + +def main(): + """Run main body of the script""" + parser = build_parser() + args = parser.parse_args() + assert args.output.endswith(".h5ad"), "Output file must be in .h5ad format" + threads = min(args.threads, len(args.prefix)) + + # Read in all the prefixes + pool = multiprocessing.Pool(threads) + adatas = list(pool.map(read_prefix, args.prefix)) + pool.close() + pool.join() + + # After having read in all the files, aggregate them + common_bins = adatas[0].var_names + for adata in adatas[1:]: + common_bins = sc_data_loaders.harmonize_atac_intervals( + common_bins, adata.var_names + ) + + logging.info(f"Aggregated {len(args.prefix)} prefixes into {len(common_bins)} bins") + + pfunc = functools.partial(sc_data_loaders.repool_atac_bins, target_bins=common_bins) + pool = multiprocessing.Pool(threads) + adatas = list(pool.map(pfunc, adatas)) + pool.close() + pool.join() + + retval = adatas[0] + if len(adatas) > 1: + retval = retval.concatenate(adatas[1:]) + logging.info( + f"Concatenated {len(args.prefix)} prefixes into a single adata of {retval.shape}" + ) + + logging.info(f"Writing to {args.output}") + retval.write(args.output) + + +if __name__ == "__main__": + main()