# bin/train_protein_predictor.py
1
"""
2
Script for training a protein predictor
3
"""
4
5
import os
6
import sys
7
import logging
8
import argparse
9
import copy
10
import functools
11
import itertools
12
import collections
13
from typing import *
14
import json
15
16
import numpy as np
17
import pandas as pd
18
from scipy import sparse
19
import scanpy as sc
20
import anndata as ad
21
22
import matplotlib.pyplot as plt
23
24
import torch
25
import torch.nn as nn
26
import torch.nn.functional as F
27
import skorch
28
import skorch.helper
29
30
torch.backends.cudnn.deterministic = True  # For reproducibility
torch.backends.cudnn.benchmark = False

# Locate the project's "babel" source tree relative to this script
# (<repo>/bin/../babel) and make it, plus its models/ subfolder, importable.
SRC_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "babel"
)
assert os.path.isdir(SRC_DIR)
sys.path.append(SRC_DIR)
MODELS_DIR = os.path.join(SRC_DIR, "models")
assert os.path.isdir(MODELS_DIR)
sys.path.append(MODELS_DIR)

# Project-local imports -- these require the sys.path additions above,
# so they intentionally sit below the path setup rather than at file top.
import sc_data_loaders
import autoencoders
import loss_functions
import model_utils
from protein_utils import LOSS_DICT, OPTIM_DICT, ACT_DICT
import utils

from train_model import plot_loss_history

logging.basicConfig(level=logging.INFO)
53
54
def load_rna_files(
    rna_counts_fnames: List[str], model_dir: str, transpose: bool = True
) -> ad.AnnData:
    """Load the RNA files in, filling in unmeasured genes as necessary.

    Args:
        rna_counts_fnames: Raw RNA count files to read and concatenate.
        model_dir: Trained model directory containing ``rna_genes.txt``.
        transpose: Whether the input matrices need transposing to obs x var.

    Returns:
        AnnData of shape (n_obs, n_learned_genes) with CSR-sparse counts;
        genes the model knows but the input lacks are zero-filled.
    """
    # Find the genes that the model understands
    rna_genes_list_fname = os.path.join(model_dir, "rna_genes.txt")
    assert os.path.isfile(
        rna_genes_list_fname
    ), f"Cannot find RNA genes file: {rna_genes_list_fname}"
    learned_rna_genes = utils.read_delimited_file(rna_genes_list_fname)
    assert isinstance(learned_rna_genes, list)
    assert utils.is_all_unique(
        learned_rna_genes
    ), "Learned genes list contains duplicates"

    temp_ad = utils.sc_read_multi_files(
        rna_counts_fnames,
        feature_type="Gene Expression",
        transpose=transpose,
        join="outer",
    )
    logging.info(f"Read input RNA files for {temp_ad.shape}")
    temp_ad.X = utils.ensure_arr(temp_ad.X)

    # Filter out mouse genes and remove the human species prefix
    temp_ad.var_names_make_unique()
    kept_var_names = [
        vname for vname in temp_ad.var_names if not vname.startswith("MOUSE_")
    ]
    if len(kept_var_names) != temp_ad.n_vars:
        temp_ad = temp_ad[:, kept_var_names]
    # BUGFIX: the original used v.strip("HUMAN_"), but str.strip removes any
    # of the characters H/U/M/A/N/_ from BOTH ends (e.g. "NAMPT" -> "PT").
    # Only the literal "HUMAN_" prefix should be removed.
    temp_ad.var = pd.DataFrame(
        index=[
            v[len("HUMAN_"):] if v.startswith("HUMAN_") else v
            for v in kept_var_names
        ]
    )

    # Expand adata to span all genes
    # Initiating as a sparse matrix doesn't allow vectorized building
    intersected_genes = set(temp_ad.var_names).intersection(learned_rna_genes)
    assert intersected_genes, "No overlap between learned and input genes!"
    # Precompute destination indices once; list.index() per gene is O(n^2)
    learned_gene_to_idx = {g: i for i, g in enumerate(learned_rna_genes)}
    expanded_mat = np.zeros((temp_ad.n_obs, len(learned_rna_genes)))
    skip_count = 0
    for gene in intersected_genes:
        dest_idx = learned_gene_to_idx[gene]
        src_idx = temp_ad.var_names.get_loc(gene)
        if not isinstance(src_idx, int):
            # get_loc returns a non-int (slice/mask) for duplicated labels
            logging.warning(f"Got multiple source matches for {gene}, skipping")
            skip_count += 1
            continue
        v = utils.ensure_arr(temp_ad.X[:, src_idx]).flatten()
        expanded_mat[:, dest_idx] = v
    if skip_count:
        logging.warning(
            f"Skipped {skip_count}/{len(intersected_genes)} genes due to multiple matches"
        )
    expanded_mat = sparse.csr_matrix(expanded_mat)  # Compress
    retval = ad.AnnData(
        expanded_mat, obs=temp_ad.obs, var=pd.DataFrame(index=learned_rna_genes)
    )
    return retval
111
112
113
def build_parser():
    """Build CLI parser"""
    argparser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Required inputs: raw counts and the pretrained encoder to build on
    argparser.add_argument(
        "--rnaCounts",
        type=str,
        nargs="*",
        required=True,
        help="file containing raw RNA counts",
    )
    argparser.add_argument(
        "--proteinCounts",
        type=str,
        nargs="*",
        required=True,
        help="file containing raw protein counts",
    )
    argparser.add_argument(
        "--encoder", required=True, type=str, help="Model folder to find encoder"
    )
    argparser.add_argument(
        "--outdir",
        type=str,
        default=os.getcwd(),
        help="Output directory for model, defaults to current dir",
    )
    # Data-split controls: clustering resolution plus held-out cluster IDs
    argparser.add_argument(
        "--clusterres",
        type=float,
        default=1.5,
        help="Cluster resolution for train/valid/test splits",
    )
    argparser.add_argument(
        "--validcluster", type=int, default=0, help="Cluster ID to use as valid cluster"
    )
    argparser.add_argument(
        "--testcluster", type=int, default=1, help="Cluster ID to use as test cluster"
    )
    argparser.add_argument(
        "--preprocessonly",
        action="store_true",
        help="Preprocess data only, do not train model",
    )
    # Model / optimization hyperparameters
    argparser.add_argument(
        "--act",
        type=str,
        choices=ACT_DICT.keys(),
        default="prelu",
        help="Activation function",
    )
    argparser.add_argument(
        "--loss", type=str, choices=LOSS_DICT.keys(), default="L1", help="Loss"
    )
    argparser.add_argument(
        "--optim", type=str, choices=OPTIM_DICT.keys(), default="adam", help="Optimizer"
    )
    argparser.add_argument("--interdim", type=int, default=64, help="Intermediate dim")
    argparser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    argparser.add_argument("--bs", type=int, default=512, help="Batch size")
    argparser.add_argument(
        "--epochs", type=int, default=600, help="Maximum number of epochs to train"
    )
    argparser.add_argument(
        "--notrans",
        action="store_true",
        help="Do not transpose (already in row obs form)",
    )
    argparser.add_argument("--device", default=0, type=int, help="Device for training")
    return argparser
184
185
186
def main():
    """Train a protein predictor on top of a pretrained BABEL encoder.

    Pipeline: parse args -> set up logging/output dir -> load pretrained
    encoder -> build clustered train/valid/test RNA splits -> encode RNA and
    pair with protein counts -> fit a skorch-wrapped Decoder network.
    """
    parser = build_parser()
    args = parser.parse_args()

    # Create output directory
    if not os.path.isdir(args.outdir):
        os.makedirs(args.outdir)

    # Specify output log file (in addition to the default stream handler)
    logger = logging.getLogger()
    fh = logging.FileHandler(os.path.join(args.outdir, "training.log"))
    fh.setLevel(logging.INFO)
    logger.addHandler(fh)

    # Log parameters and persist them for reproducibility
    for arg in vars(args):
        logging.info(f"Parameter {arg}: {getattr(args, arg)}")
    with open(os.path.join(args.outdir, "params.json"), "w") as sink:
        json.dump(vars(args), sink, indent=4)

    # Load the pretrained encoder model
    pretrained_net = model_utils.load_model(args.encoder, device=args.device)

    # Load in some files
    # NOTE(review): rna_genes is not referenced below -- load_rna_files reads
    # the same file itself; kept here presumably as a sanity/availability check.
    rna_genes = utils.read_delimited_file(os.path.join(args.encoder, "rna_genes.txt"))
    atac_bins = utils.read_delimited_file(os.path.join(args.encoder, "atac_bins.txt"))

    # Read in the RNA, overriding the default PBMC kwargs with CLI values
    rna_data_kwargs = copy.copy(sc_data_loaders.TENX_PBMC_RNA_DATA_KWARGS)
    rna_data_kwargs["cluster_res"] = args.clusterres
    rna_data_kwargs["fname"] = args.rnaCounts
    rna_data_kwargs["reader"] = lambda x: load_rna_files(
        x, args.encoder, transpose=not args.notrans
    )

    # Construct data folds; valid/test are held-out clusters by ID
    full_sc_rna_dataset = sc_data_loaders.SingleCellDataset(
        valid_cluster_id=args.validcluster,
        test_cluster_id=args.testcluster,
        **rna_data_kwargs,
    )
    full_sc_rna_dataset.data_raw.write_h5ad(os.path.join(args.outdir, "full_rna.h5ad"))

    # For each split, build (encoded RNA -> protein) paired datasets and
    # write the intermediate AnnData objects to disk for later inspection.
    train_valid_test_dsets = []
    for mode in ["all", "train", "valid", "test"]:
        logging.info(f"Constructing {mode} dataset")
        sc_rna_dataset = sc_data_loaders.SingleCellDatasetSplit(
            full_sc_rna_dataset, split=mode
        )
        sc_rna_dataset.data_raw.write_h5ad(
            os.path.join(args.outdir, f"{mode}_rna.h5ad")
        )  # Write RNA input
        # Zero-filled stand-in for ATAC so the paired interface is satisfied
        sc_atac_dummy_dataset = sc_data_loaders.DummyDataset(
            shape=len(atac_bins), length=len(sc_rna_dataset)
        )
        # RNA and fake ATAC
        sc_dual_dataset = sc_data_loaders.PairedDataset(
            sc_rna_dataset,
            sc_atac_dummy_dataset,
            flat_mode=True,
        )
        # encoded(RNA) as "x" and RNA + fake ATAC as "y"
        sc_rna_encoded_dataset = sc_data_loaders.EncodedDataset(
            sc_dual_dataset, model=pretrained_net, input_mode="RNA"
        )
        sc_rna_encoded_dataset.encoded.write_h5ad(
            os.path.join(args.outdir, f"{mode}_encoded.h5ad")
        )
        # Protein counts aligned to this split's cells via obs_names
        sc_protein_dataset = sc_data_loaders.SingleCellProteinDataset(
            args.proteinCounts,
            obs_names=sc_rna_dataset.obs_names,
            transpose=not args.notrans,
        )
        sc_protein_dataset.data_raw.write_h5ad(
            os.path.join(args.outdir, f"{mode}_protein.h5ad")
        )  # Write protein
        # x = 16 dimensional encoded layer, y = 25 dimensional protein array
        sc_rna_protein_dataset = sc_data_loaders.SplicedDataset(
            sc_rna_encoded_dataset, sc_protein_dataset
        )
        _temp = sc_rna_protein_dataset[0]  # ensure calling works
        train_valid_test_dsets.append(sc_rna_protein_dataset)

    # Unpack and do sanity checks ("all" split is discarded here)
    _, sc_rna_prot_train, sc_rna_prot_valid, sc_rna_prot_test = train_valid_test_dsets
    x, y, z = sc_rna_prot_train[0], sc_rna_prot_valid[0], sc_rna_prot_test[0]
    assert (
        x[0].shape == y[0].shape == z[0].shape
    ), f"Got mismatched shapes: {x[0].shape} {y[0].shape} {z[0].shape}"
    assert (
        x[1].shape == y[1].shape == z[1].shape
    ), f"Got mismatched shapes: {x[1].shape} {y[1].shape} {z[1].shape}"

    # NOTE(review): sc_protein_dataset here is the loop variable from the
    # last ("test") iteration; presumably var_names are identical across
    # splits -- confirm against SingleCellProteinDataset.
    protein_markers = list(sc_protein_dataset.data_raw.var_names)
    with open(os.path.join(args.outdir, "protein_proteins.txt"), "w") as sink:
        sink.write("\n".join(protein_markers) + "\n")
    # Round-trip check that the marker list file reads back completely
    assert len(
        utils.read_delimited_file(os.path.join(args.outdir, "protein_proteins.txt"))
    ) == len(protein_markers)
    logging.info(f"Predicting on {len(protein_markers)} proteins")

    if args.preprocessonly:
        return

    # Decoder maps the 16-dim encoded representation to protein counts;
    # datasets yield (x, y) pairs so skorch is fit with y=None below.
    protein_decoder_skorch = skorch.NeuralNet(
        module=autoencoders.Decoder,
        module__num_units=16,
        module__intermediate_dim=args.interdim,
        module__num_outputs=len(protein_markers),
        module__activation=ACT_DICT[args.act],
        module__final_activation=nn.Identity(),
        # module__final_activation=nn.Linear(
        #     len(protein_markers), len(protein_markers), bias=True
        # ),  # Paper uses identity activation instead
        lr=args.lr,
        criterion=LOSS_DICT[args.loss],  # Other works use L1 loss
        optimizer=OPTIM_DICT[args.optim],
        batch_size=args.bs,
        max_epochs=args.epochs,
        callbacks=[
            skorch.callbacks.EarlyStopping(patience=15),
            skorch.callbacks.LRScheduler(
                policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
                patience=5,
                factor=0.1,
                min_lr=1e-6,
                # **model_utils.REDUCE_LR_ON_PLATEAU_PARAMS,
            ),
            skorch.callbacks.GradientNormClipping(gradient_clip_value=5),
            skorch.callbacks.Checkpoint(
                dirname=args.outdir,
                fn_prefix="net_",
                monitor="valid_loss_best",
            ),
        ],
        train_split=skorch.helper.predefined_split(sc_rna_prot_valid),
        iterator_train__num_workers=8,
        iterator_valid__num_workers=8,
        device=utils.get_device(args.device),
    )
    protein_decoder_skorch.fit(sc_rna_prot_train, y=None)

    # Plot the loss history
    fig = plot_loss_history(
        protein_decoder_skorch.history, os.path.join(args.outdir, "loss.pdf")
    )


if __name__ == "__main__":
    main()