Diff of /indrops.py [000000] .. [4d6235]


a b/indrops.py
1
import os, subprocess, sys
2
import itertools
3
import operator
4
from collections import defaultdict, OrderedDict
5
import errno
6
7
# cPickle is a faster, C-implemented version of pickle that is not available in Python 3.
8
# The import is wrapped in try/except so the module still loads when cPickle is missing.
9
try:
10
   import cPickle as pickle
11
except ImportError:
12
   import pickle
13
14
from io import BytesIO
15
16
import numpy as np
17
import re
18
import shutil
19
import gzip
20
21
# product: product(A, B) returns the same as ((x,y) for x in A for y in B).
22
# combinations: returns r-length subsequences of elements from the input iterable.
23
from itertools import product, combinations
24
import time
25
26
import yaml
27
28
import tempfile
29
import string
30
from contextlib import contextmanager
31
32
# -----------------------
33
#
34
# Helper functions
35
#
36
# -----------------------
37
38
def string_hamming_distance(str1, str2):
39
    """
40
    Fast Hamming distance between two strings known to be of the same length.
41
    In information theory, the Hamming distance between two strings of equal 
42
    length is the number of positions at which the corresponding symbols 
43
    are different.
44
45
    eg "karolin" and "kathrin" is 3.
46
    """
47
    return sum(itertools.imap(operator.ne, str1, str2))
48
49
___tbl = {'A':'T', 'T':'A', 'C':'G', 'G':'C', 'N':'N'}
50
def rev_comp(seq):
51
    return ''.join(___tbl[s] for s in seq[::-1])
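# Illustrative examples (editor's sketch, not part of the original pipeline):
#
#   >>> string_hamming_distance('karolin', 'kathrin')
#   3
#   >>> rev_comp('ATCGN')
#   'NCGAT'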
52
53
54
def to_fastq(name, seq, qual):
55
    """
56
    Return string that can be written to fastQ file
57
    """
58
    return '@'+name+'\n'+seq+'\n+\n'+qual+'\n'
59
60
def to_fastq_lines(bc, umi, seq, qual, read_name=''):
61
    """
62
    Return string that can be written to fastQ file
63
    """
64
    reformated_name = read_name.replace(':', '_')
65
    name = '%s:%s:%s' % (bc, umi, reformated_name)
66
    return to_fastq(name, seq, qual)
67
68
def from_fastq(handle):
69
    while True:
70
        name = next(handle).rstrip()[1:] #Read name
71
        seq = next(handle).rstrip() #Read seq
72
        next(handle) #+ line
73
        qual = next(handle).rstrip() #Read qual
74
        if not name or not seq or not qual:
75
            break
76
        yield name, seq, qual
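# Illustrative round-trip (editor's sketch): to_fastq builds a 4-line FASTQ
# record and from_fastq parses it back.
#
#   >>> rec = to_fastq('read1', 'ACGT', 'FFFF')
#   >>> rec
#   '@read1\nACGT\n+\nFFFF\n'
#   >>> next(from_fastq(iter(rec.splitlines(True))))
#   ('read1', 'ACGT', 'FFFF')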
77
78
def seq_neighborhood(seq, n_subs=1):
79
    """
80
    Given a sequence, yield all sequences within n_subs substitutions of 
81
    that sequence by looping through each combination of base pairs within
82
    each combination of positions.
83
    """
84
    for positions in combinations(range(len(seq)), n_subs):
85
    # yields all unique combinations of indices for n_subs mutations
86
        for subs in product(*("ATGCN",)*n_subs):
87
        # yields all combinations of possible nucleotides for strings of length
88
        # n_subs
89
            seq_copy = list(seq)
90
            for p, s in zip(positions, subs):
91
                seq_copy[p] = s
92
            yield ''.join(seq_copy)
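# Illustrative example (editor's sketch): for a 3-mer and one substitution over
# the alphabet "ATGCN", the neighborhood is the sequence itself plus
# 3 positions x 4 alternative characters = 13 distinct strings.
#
#   >>> len(set(seq_neighborhood('ACG', 1)))
#   13
#   >>> 'NCG' in set(seq_neighborhood('ACG', 1))
#   True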
93
94
def build_barcode_neighborhoods(barcode_file, expect_reverse_complement=True):
95
    """
96
    Given a set of barcodes, produce sequences which can unambiguously be
97
    mapped to these barcodes, within 2 substitutions. If a sequence maps to 
98
    multiple barcodes, get rid of it. However, if a sequence maps to one barcode
99
    with 1 change and to another with 2 changes, keep the 1-change mapping.
100
    """
101
102
    # contains all mutants that map uniquely to a barcode
103
    clean_mapping = dict()
104
105
    # contain single or double mutants 
106
    mapping1 = defaultdict(set)
107
    mapping2 = defaultdict(set)
108
    
109
    #Build the full neighborhood and iterate through barcodes
110
    with open(barcode_file, 'rU') as f:
111
        # iterate through each barcode (rstrip cleans string of whitespace)
112
        for line in f:
113
            barcode = line.rstrip()
114
            if expect_reverse_complement:
115
                barcode = rev_comp(line.rstrip())
116
117
            # each barcode obviously maps to itself uniquely
118
            clean_mapping[barcode] = barcode
119
120
            # for each possible mutated form of a given barcode, either add
121
            # the origin barcode into the set corresponding to that mutant or 
122
            # create a new entry for a mutant not already in mapping1
123
            # eg: barcodes CATG and CCTG would be in the set for mutant CTTG
124
            # but only barcode CATG could generate mutant CANG
125
            for n in seq_neighborhood(barcode, 1):
126
                mapping1[n].add(barcode)
127
            
128
            # same as above but with double mutants
129
            for n in seq_neighborhood(barcode, 2):
130
                mapping2[n].add(barcode)   
131
    
132
    # take all single-mutants and find those that could only have come from one
133
    # specific barcode
134
    for k, v in mapping1.items():
135
        if k not in clean_mapping:
136
            if len(v) == 1:
137
                clean_mapping[k] = list(v)[0]
138
    
139
    for k, v in mapping2.items():
140
        if k not in clean_mapping:
141
            if len(v) == 1:
142
                clean_mapping[k] = list(v)[0]
143
    del mapping1
144
    del mapping2
145
    return clean_mapping
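# Editor's sketch of the disambiguation rule above (hypothetical helper, never
# called by the pipeline). With barcodes AAAA and AATA: a mutant like TAAA is
# 1 change from AAAA only, so it is kept; AACA is 1 change from both, so it is
# dropped as ambiguous.
def _barcode_neighborhood_sketch():
    with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
        f.write('AAAA\nAATA\n')
        barcode_file = f.name
    mapping = build_barcode_neighborhoods(barcode_file, expect_reverse_complement=False)
    os.remove(barcode_file)
    assert mapping['TAAA'] == 'AAAA'   # unambiguous single mutant
    assert 'AACA' not in mapping       # ambiguous: 1 change away from both barcodes
    return mapping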
146
147
def check_dir(path):
148
    """
149
    Creates the directory at 'path' if it does not already exist.
150
    """
151
    try:
152
        os.makedirs(path)
153
    except OSError as exception:
154
        if exception.errno != errno.EEXIST:
155
            raise
156
157
def print_to_stderr(msg, newline=True):
158
    """
159
    Wrapper for writing to stderr, kept in one place so output handling can be changed later.
160
    """
161
    sys.stderr.write(str(msg))
162
    if newline:
163
        sys.stderr.write('\n')
164
165
def worker_filter(iterable, worker_index, total_workers):
166
    return (p for i,p in enumerate(iterable) if (i-worker_index)%total_workers==0)
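# Illustrative example (editor's sketch): worker_filter deals items out
# round-robin, so with 3 workers, worker 0 gets indices 0, 3, 6, ...
#
#   >>> list(worker_filter(['a', 'b', 'c', 'd', 'e'], 0, 3))
#   ['a', 'd']
#   >>> list(worker_filter(['a', 'b', 'c', 'd', 'e'], 1, 3))
#   ['b', 'e']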
167
168
class FIFO():
169
    """
170
    A context manager for a named pipe.
171
    """
172
    def __init__(self, filename="", suffix="", prefix="tmp_fifo_dir", dir=None):
173
        if filename:
174
            self.filename = filename
175
        else:
176
            self.tmpdir = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=dir)
177
            self.filename = os.path.join(self.tmpdir, 'fifo')
178
179
    def __enter__(self):
180
        if os.path.exists(self.filename):
181
            os.unlink(self.filename)
182
        os.mkfifo(self.filename)
183
        return self
184
185
    def __exit__(self, type, value, traceback):
186
        os.remove(self.filename)
187
        if hasattr(self, 'tmpdir'):
188
            shutil.rmtree(self.tmpdir)
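# Editor's sketch of typical FIFO usage (hypothetical helper, POSIX only): a
# child process writes into the named pipe while the parent reads from it, so
# the intermediate data never has to touch a regular file.
def _fifo_usage_sketch():
    with FIFO(suffix='.fifo') as fifo:
        writer = subprocess.Popen('echo hello > %s' % fifo.filename, shell=True)
        with open(fifo.filename) as f:
            data = f.read()   # blocks until the writer has opened the pipe
        writer.wait()
    return data               # 'hello\n'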
189
190
# -----------------------
191
#
192
# Core objects
193
#
194
# -----------------------
195
196
class IndropsProject():
197
198
    def __init__(self, project_yaml_file_handle, read_only=False):
199
200
        self.yaml = yaml.load(project_yaml_file_handle)
201
202
        self.name = self.yaml['project_name']
203
        self.project_dir = self.yaml['project_dir']
204
205
        self.libraries = OrderedDict()
206
        self.runs = OrderedDict()
207
208
        self.read_only = read_only
209
210
        for run in self.yaml['sequencing_runs']:
211
            """
212
            After filtering, each sequencing run generates between 1 and X files with filtered reads,
213
            where X = N x M:
214
             - N: The run is often split into several files (a typical NextSeq run is split into L001,
215
                  L002, L003 and L004, matching the different lanes), but the split can also be done artificially.
216
             - M: The same run might contain several libraries. The demultiplexing can be handled by this script or externally;
217
                  if it is done externally, there will be a different .fastq file for each library.
218
            (See the sketch of the expected YAML layout after this constructor.)
            """
219
            version = run['version']
220
221
            filtered_filename = '{library_name}_{run_name}'
222
            if run['version'] == 'v3':
223
                filtered_filename += '_{library_index}'
224
            # Prepare to iterate over run split into several files
225
            if 'split_affixes' in run:
226
                filtered_filename += '_{split_affix}'
227
                split_affixes = run['split_affixes']
228
            else:
229
                split_affixes = ['']
230
231
            filtered_filename += '.fastq'
232
233
            # Prepare to iterate over libraries
234
            if 'libraries' in run:
235
                run_libraries = run['libraries']
236
            elif 'library_name' in run:
237
                run_libraries = [{'library_name' : run['library_name'], 'library_prefix':''}]
238
            else:
239
                raise Exception('No library name or libraries specified.')
240
241
            if run['version']=='v1' or run['version']=='v2':
242
                for affix in split_affixes:
243
                    for lib in run_libraries:
244
                        lib_name = lib['library_name']
245
                        if lib_name not in self.libraries:
246
                            self.libraries[lib_name] = IndropsLibrary(name=lib_name, project=self, version=run['version'])
247
                        else:
248
                            assert self.libraries[lib_name].version == run['version']
249
250
                        if version == 'v1':
251
                            metaread_filename = os.path.join(run['dir'],run['fastq_path'].format(split_affix=affix, read='R1', library_prefix=lib['library_prefix']))
252
                            bioread_filename = os.path.join(run['dir'],run['fastq_path'].format(split_affix=affix, read='R2', library_prefix=lib['library_prefix']))
253
                        elif version == 'v2':
254
                            metaread_filename  = os.path.join(run['dir'],run['fastq_path'].format(split_affix=affix, read='R2', library_prefix=lib['library_prefix']))
255
                            bioread_filename = os.path.join(run['dir'],run['fastq_path'].format(split_affix=affix, read='R1', library_prefix=lib['library_prefix']))
256
257
                        filtered_part_filename = filtered_filename.format(run_name=run['name'], split_affix=affix, library_name=lib_name)
258
                        filtered_part_path = os.path.join(self.project_dir, lib_name, 'filtered_parts', filtered_part_filename)
259
                        part = V1V2Filtering(filtered_fastq_filename=filtered_part_path,
260
                            project=self, 
261
                            bioread_filename=bioread_filename,
262
                            metaread_filename=metaread_filename,
263
                            run_name=run['name'],
264
                            library_name=lib_name,
265
                            part_name=affix)
266
267
                        if run['name'] not in self.runs:
268
                            self.runs[run['name']] = []
269
                        self.runs[run['name']].append(part)
270
                        self.libraries[lib_name].parts.append(part)
271
272
            elif run['version'] == 'v3' or run['version'] == 'v3-miseq':
273
                for affix in split_affixes:
274
                    filtered_part_filename = filtered_filename.format(run_name=run['name'], split_affix=affix,
275
                        library_name='{library_name}', library_index='{library_index}')
276
                    part_filename = os.path.join(self.project_dir, '{library_name}', 'filtered_parts', filtered_part_filename)
277
278
                    input_filename = os.path.join(run['dir'], run['fastq_path'].format(split_affix=affix, read='{read}'))
279
                    part = V3Demultiplexer(run['libraries'], project=self, part_filename=part_filename, input_filename=input_filename, run_name=run['name'], part_name=affix,
280
                        run_version_details=run['version'])
281
282
                    if run['name'] not in self.runs:
283
                        self.runs[run['name']] = []
284
                    self.runs[run['name']].append(part)
285
286
                    for lib in run_libraries:
287
                        lib_name = lib['library_name']
288
                        lib_index = lib['library_index']
289
                        if lib_name not in self.libraries:
290
                            self.libraries[lib_name] = IndropsLibrary(name=lib_name, project=self, version=run['version'])
291
                        self.libraries[lib_name].parts.append(part.libraries[lib_index])
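    # Editor's sketch of the YAML this constructor expects; the field names are
    # taken from the code above, the values are hypothetical:
    #
    #   project_name: my_project
    #   project_dir: /path/to/project
    #   paths:
    #     bowtie_index: /path/to/index/my_index
    #   sequencing_runs:
    #     - name: Run1
    #       version: v2
    #       dir: /path/to/run1_fastqs
    #       fastq_path: "{library_prefix}{split_affix}_{read}.fastq.gz"
    #       split_affixes: [L001, L002]
    #       libraries:
    #         - {library_name: sample1, library_prefix: S1_}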
292
293
294
    @property
295
    def paths(self):
296
        if not hasattr(self, '_paths'):
297
            script_dir = os.path.dirname(os.path.realpath(__file__))
298
            #Read defaults
299
            with open(os.path.join(script_dir, 'default_parameters.yaml'), 'r') as f:
300
                paths = yaml.load(f)['paths']
301
            # Update with user provided values
302
            paths.update(self.yaml['paths'])
303
304
            paths['python'] = os.path.join(paths['python_dir'], 'python')
305
            paths['java'] = os.path.join(paths['java_dir'], 'java')
306
            paths['bowtie'] = os.path.join(paths['bowtie_dir'], 'bowtie')
307
            paths['samtools'] = os.path.join(paths['samtools_dir'], 'samtools')
308
            paths['trimmomatic_jar'] = os.path.join(script_dir, 'bins', 'trimmomatic-0.33.jar')
309
            paths['rsem_tbam2gbam'] = os.path.join(paths['rsem_dir'], 'rsem-tbam2gbam')
310
            paths['rsem_prepare_reference'] = os.path.join(paths['rsem_dir'], 'rsem-prepare-reference')
311
312
            self._paths = type('Paths_anonymous_object',(object,),paths)()
313
            self._paths.trim_polyA_and_filter_low_complexity_reads_py = os.path.join(script_dir, 'trim_polyA_and_filter_low_complexity_reads.py')
314
            self._paths.quantify_umifm_from_alignments_py = os.path.join(script_dir, 'quantify_umifm_from_alignments.py')
315
            self._paths.count_barcode_distribution_py = os.path.join(script_dir, 'count_barcode_distribution.py')
316
            self._paths.gel_barcode1_list = os.path.join(script_dir, 'ref/barcode_lists/gel_barcode1_list.txt')
317
            self._paths.gel_barcode2_list = os.path.join(script_dir, 'ref/barcode_lists/gel_barcode2_list.txt')
318
        return self._paths
319
320
    @property
321
    def parameters(self):
322
        if not hasattr(self, '_parameters'):
323
            #Read defaults
324
            with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'default_parameters.yaml'), 'r') as f:
325
                self._parameters = yaml.load(f)['parameters']
326
            # Update with user provided values
327
            if 'parameters' in self.yaml:
328
                for k, d in self.yaml['parameters'].items():
329
                    self._parameters[k].update(d)
330
331
        return self._parameters
332
333
    @property
334
    def gel_barcode1_revcomp_list_neighborhood(self):
335
        if not hasattr(self, '_gel_barcode1_revcomp_list_neighborhood'):
336
            self._gel_barcode1_revcomp_list_neighborhood = build_barcode_neighborhoods(self.paths.gel_barcode1_list, True)
337
        return self._gel_barcode1_revcomp_list_neighborhood
338
    
339
    @property
340
    def gel_barcode2_revcomp_list_neighborhood(self):
341
        if not hasattr(self, '_gel_barcode2_revcomp_list_neighborhood'):
342
            self._gel_barcode2_revcomp_list_neighborhood = build_barcode_neighborhoods(self.paths.gel_barcode2_list, True)
343
        return self._gel_barcode2_revcomp_list_neighborhood
344
345
    @property
346
    def gel_barcode2_list_neighborhood(self):
347
        if not hasattr(self, '_gel_barcode2_list_neighborhood'):
348
            self._gel_barcode2_list_neighborhood = build_barcode_neighborhoods(self.paths.gel_barcode2_list, False)
349
        return self._gel_barcode2_list_neighborhood
350
351
    @property
352
    def stable_barcode_names(self):
353
        if not hasattr(self, '_stable_barcode_names'):
354
            with open(self.paths.gel_barcode1_list) as f:
355
                rev_bc1s = [rev_comp(line.rstrip()) for line in f]
356
            with open(self.paths.gel_barcode2_list) as f:
357
                bc2s = [line.rstrip() for line in f]
358
                rev_bc2s = [rev_comp(bc2) for bc2 in bc2s]
359
360
            # V1, V2 names:
361
            v1v2_names = {}
362
            barcode_iter = product(rev_bc1s, rev_bc2s)
363
            name_iter = product(string.ascii_uppercase, repeat=4)
364
            for barcode, name in zip(barcode_iter, name_iter):
365
                v1v2_names['-'.join(barcode)] = 'bc' + ''.join(name)
366
367
            # V3 names:
368
            v3_names = {}
369
            barcode_iter = product(bc2s, rev_bc2s)
370
            name_iter = product(string.ascii_uppercase, repeat=4)
371
            for barcode, name in zip(barcode_iter, name_iter):
372
                v3_names['-'.join(barcode)] = 'bc' + ''.join(name)
373
374
375
            self._stable_barcode_names = {
376
                'v1' : v1v2_names,
377
                'v2' : v1v2_names,
378
                'v3': v3_names,
379
                'v3-miseq':v3_names,
380
            }
381
        return self._stable_barcode_names
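    # Illustrative note (editor's sketch): barcode pairs are keyed as 'BC1-BC2'
    # and the stable names simply follow product(string.ascii_uppercase, repeat=4),
    # so the first few assigned names are bcAAAA, bcAAAB, bcAAAC, ...
    #
    #   >>> [''.join(n) for n in itertools.islice(product(string.ascii_uppercase, repeat=4), 3)]
    #   ['AAAA', 'AAAB', 'AAAC']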
382
383
    def project_check_dir(self, path):
384
        if not self.read_only:
385
            check_dir(path)
386
387
    def filter_gtf(self, gzipped_transcriptome_gtf, gtf_with_genenames_in_transcript_id):
388
389
        # A small number of genes are flagged as having two different biotypes.
390
        gene_biotype_dict = defaultdict(set)
391
392
        # Read through GTF file once to get all gene names
393
        for line in subprocess.Popen(["gzip", "--stdout", "-d", gzipped_transcriptome_gtf], stdout=subprocess.PIPE).stdout:
394
            # Skip non-gene feature lines.
395
            if '\tgene\t' not in line:
396
                continue
397
                
398
            gene_biotype_match = re.search(r'gene_biotype \"(.*?)\";', line)
399
            gene_name_match = re.search(r'gene_name \"(.*?)\";', line)
400
            if gene_name_match and gene_biotype_match:
401
                gene_name = gene_name_match.group(1)
402
                gene_biotype = gene_biotype_match.group(1)
403
                
404
                # Record biotype.
405
                gene_biotype_dict[gene_name].add(gene_biotype)
406
407
        # Detect read-through genes by name. Name must be a fusion of two other gene names 'G1-G2'.
408
        readthrough_genes = set()
409
        for gene in gene_biotype_dict.keys():
410
            if '-' in gene and len(gene.split('-')) == 2:
411
                g1, g2 = gene.split('-')
412
                if g1 in gene_biotype_dict and g2 in gene_biotype_dict:
413
                    readthrough_genes.add(gene)
414
415
416
        # Detect pseudogenes: genes where all associated biotypes have 'pseudogene' in name
417
        pseudogenes = set()
418
        for gene, biotypes in gene_biotype_dict.items():
419
            if all('pseudogene' in b for b in biotypes):
420
                pseudogenes.add(gene)
421
422
        all_genes = set(gene_biotype_dict.keys())
423
        valid_genes = all_genes.difference(pseudogenes).difference(readthrough_genes)
424
425
        transcripts_counter = defaultdict(int)
426
427
428
        # Go through GTF file again, annotating each transcript_id with the gene and outputting to a new GTF file.
429
        output_gtf = open(gtf_with_genenames_in_transcript_id, 'w')
430
        for line in subprocess.Popen(["gzip", "--stdout", "-d", gzipped_transcriptome_gtf], stdout=subprocess.PIPE).stdout:
431
            # Skip non-transcript feature lines.
432
            if 'transcript_id' not in line:
433
                continue
434
                
435
            gene_name_match = re.search(r'gene_name \"(.*?)\";', line)
436
            if gene_name_match:
437
                gene_name = gene_name_match.group(1)
438
                if gene_name in valid_genes:
439
                    
440
                    # An unusual edgecase in the GTF for Danio Rerio rel89
441
                    if ' ' in gene_name:
442
                        gene_name = gene_name.replace(' ', '_')
443
444
                    out_line = re.sub(r'(?<=transcript_id ")(.*?)(?=";)', r'\1|'+gene_name, line)
445
                    output_gtf.write(out_line)
446
                    if '\ttranscript\t' in line:
447
                        transcripts_counter['valid'] += 1
448
                elif gene_name in pseudogenes and '\ttranscript\t' in line:
449
                    transcripts_counter['pseudogenes'] += 1
450
                elif gene_name in readthrough_genes and '\ttranscript\t' in line:
451
                    transcripts_counter['readthrough_genes'] += 1
452
        output_gtf.close()
453
454
        print_to_stderr('Filtered GTF contains %d transcripts (%d genes)' % (transcripts_counter['valid'], len(valid_genes)))
455
        print_to_stderr('   - ignored %d pseudogene transcripts (%d genes)' % (transcripts_counter['pseudogenes'], len(pseudogenes)))
456
        print_to_stderr('   - ignored %d read-through transcripts (%d genes)' % (transcripts_counter['readthrough_genes'], len(readthrough_genes)))
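    # Editor's sketch of the transcript_id rewrite used above, on a toy GTF
    # attribute string (values hypothetical):
    #
    #   >>> line = 'gene_id "G1"; transcript_id "T1"; gene_name "Actb";'
    #   >>> re.sub(r'(?<=transcript_id ")(.*?)(?=";)', r'\1|' + 'Actb', line)
    #   'gene_id "G1"; transcript_id "T1|Actb"; gene_name "Actb";'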
457
458
    def build_transcriptome(self, gzipped_genome_softmasked_fasta_filename, gzipped_transcriptome_gtf,
459
            mode='strict'):
460
        import pyfasta
461
        
462
        index_dir = os.path.dirname(self.paths.bowtie_index)
463
        self.project_check_dir(index_dir)
464
465
        genome_filename = os.path.join(index_dir, '.'.join(gzipped_genome_softmasked_fasta_filename.split('.')[:-1]))
466
467
        gtf_filename = os.path.join(index_dir, gzipped_transcriptome_gtf.split('/')[-1])
468
        gtf_prefix = '.'.join(gtf_filename.split('.')[:-2])
469
        # gtf_with_genenames_in_transcript_id = gtf_prefix + '.annotated.gtf'
470
        gtf_with_genenames_in_transcript_id = self.paths.bowtie_index + '.gtf'
471
472
        print_to_stderr('Filtering GTF')
473
        self.filter_gtf(gzipped_transcriptome_gtf, gtf_with_genenames_in_transcript_id)
474
        # accepted_gene_biotypes_for_NA_transcripts = set(["protein_coding","IG_V_gene","IG_J_gene","TR_J_gene","TR_D_gene","TR_V_gene","IG_C_gene","IG_D_gene","TR_C_gene"])
475
        # tsl1_or_tsl2_strings = ['transcript_support_level "1"', 'transcript_support_level "1 ', 'transcript_support_level "2"', 'transcript_support_level "2 ']
476
        # tsl_NA =  'transcript_support_level "NA'
477
478
        # def filter_ensembl_transcript(transcript_line):
479
480
        #     line_valid_for_output = False
481
        #     if mode == 'strict':
482
        #         for string in tsl1_or_tsl2_strings:
483
        #             if string in line:
484
        #                 line_valid_for_output = True
485
        #                 break
486
        #         if tsl_NA in line:
487
        #             gene_biotype = re.search(r'gene_biotype \"(.*?)\";', line)
488
        #             if gene_biotype and gene_biotype.group(1) in accepted_gene_biotypes_for_NA_transcripts:
489
        #                 line_valid_for_output = True
490
        #         return line_valid_for_output
491
492
        #     elif mode == 'all_ensembl':
493
        #         line_valid_for_output = True
494
        #         return line_valid_for_output
495
496
497
498
        # print_to_stderr('Filtering GTF')
499
        # output_gtf = open(gtf_with_genenames_in_transcript_id, 'w')
500
        # for line in subprocess.Popen(["gzip", "--stdout", "-d", gzipped_transcriptome_gtf], stdout=subprocess.PIPE).stdout:
501
        #     if 'transcript_id' not in line:
502
        #         continue
503
504
        #     if filter_ensembl_transcript(line):
505
        #         gene_name = re.search(r'gene_name \"(.*?)\";', line)
506
        #         if gene_name:
507
        #             gene_name = gene_name.group(1)
508
        #             out_line = re.sub(r'(?<=transcript_id ")(.*?)(?=";)', r'\1|'+gene_name, line)
509
        #             output_gtf.write(out_line)
510
        # output_gtf.close()
511
512
        print_to_stderr('Gunzipping Genome')
513
        p_gzip = subprocess.Popen(["gzip", "-dfc", gzipped_genome_softmasked_fasta_filename], stdout=open(genome_filename, 'wb'))
514
        if p_gzip.wait() != 0:
515
            raise Exception(" Error in rsem-prepare reference ")
516
517
        p_rsem = subprocess.Popen([self.paths.rsem_prepare_reference, '--bowtie', '--bowtie-path', self.paths.bowtie_dir,
518
                            '--gtf', gtf_with_genenames_in_transcript_id, 
519
                            '--polyA', '--polyA-length', '5', genome_filename, self.paths.bowtie_index])
520
521
        if p_rsem.wait() != 0:
522
            raise Exception(" Error in rsem-prepare reference ")
523
524
        print_to_stderr('Finding soft masked regions in transcriptome')
525
        
526
        transcripts_fasta = pyfasta.Fasta(self.paths.bowtie_index + '.transcripts.fa')
527
        soft_mask = {}
528
        for tx, seq in transcripts_fasta.items():
529
            seq = str(seq)
530
            soft_mask[tx] = set((m.start(), m.end()) for m in re.finditer(r'[atcgn]+', seq))
531
        with open(self.paths.bowtie_index + '.soft_masked_regions.pickle', 'w') as out:
532
            pickle.dump(soft_mask, out)
533
534
class IndropsLibrary():
535
536
    def __init__(self, name='', project=None, version=''):
537
        self.project = project
538
        self.name = name
539
        self.parts = []
540
        self.version = version
541
542
        self.paths = {}
543
        for lib_dir in ['filtered_parts', 'quant_dir']:
544
            dir_path = os.path.join(self.project.project_dir, self.name, lib_dir)
545
            self.project.project_check_dir(dir_path)
546
            self.paths[lib_dir] = dir_path
547
        self.paths = type('Paths_anonymous_object',(object,),self.paths)()
548
549
        self.paths.abundant_barcodes_names_filename = os.path.join(self.project.project_dir, self.name, 'abundant_barcodes.pickle')
550
        self.paths.filtering_statistics_filename = os.path.join(self.project.project_dir, self.name, self.name+'.filtering_stats.csv')
551
        self.paths.barcode_abundance_histogram_filename = os.path.join(self.project.project_dir, self.name, self.name+'.barcode_abundance.png')
552
        self.paths.barcode_abundance_by_barcode_filename = os.path.join(self.project.project_dir, self.name, self.name+'.barcode_abundance_by_barcode.png')
553
        self.paths.missing_quants_filename = os.path.join(self.project.project_dir, self.name, self.name+'.missing_barcodes.pickle')
554
555
    @property
556
    def barcode_counts(self):
557
        if not hasattr(self, '_barcode_counts'):
558
            self._barcode_counts = defaultdict(int)
559
            for part in self.parts:
560
                for k, v in part.part_barcode_counts.items():
561
                    self._barcode_counts[k] += v
562
563
        return self._barcode_counts
564
565
    @property
566
    def abundant_barcodes(self):
567
        if not hasattr(self, '_abundant_barcodes'):
568
            with open(self.paths.abundant_barcodes_names_filename) as f:
569
                self._abundant_barcodes = pickle.load(f)
570
        return self._abundant_barcodes
571
572
    def sorted_barcode_names(self, min_reads=0, max_reads=10**10):
573
        return [name for bc,(name,abun) in sorted(self.abundant_barcodes.items(), key=lambda i:-i[1][1]) if (abun>min_reads) & (abun<max_reads)]
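    # Illustrative example (editor's sketch) of the data shape and ordering:
    # abundant_barcodes maps barcode sequence -> (stable name, read count), and
    # sorted_barcode_names returns the names by decreasing read count above a threshold.
    #
    #   >>> abundant = {'AAAA-CCCC': ('bcAAAB', 500), 'GGGG-TTTT': ('bcAAAA', 900)}
    #   >>> [name for bc, (name, abun) in sorted(abundant.items(), key=lambda i: -i[1][1]) if abun > 250]
    #   ['bcAAAA', 'bcAAAB']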
574
575
    def identify_abundant_barcodes(self, make_histogram=True, absolute_min_reads=250):
576
        """
577
        Identify which barcodes are above the absolute minimal abundance, 
578
        and make a histogram summarizing the barcode distribution
579
        """
580
        keep_barcodes = []
581
        for k, v in self.barcode_counts.items():
582
            if v > absolute_min_reads:
583
                keep_barcodes.append(k)
584
585
        abundant_barcodes = {}
586
        print_to_stderr(" %d barcodes above absolute minimum threshold" % len(keep_barcodes))
587
        for bc in keep_barcodes:
588
            abundant_barcodes[bc] = (self.project.stable_barcode_names[self.version][bc], self.barcode_counts[bc])
589
590
        self._abundant_barcodes = abundant_barcodes
591
        with open(self.paths.abundant_barcodes_names_filename, 'w') as f:
592
            pickle.dump(abundant_barcodes, f)
593
594
        # Create table about the filtering process
595
        with open(self.paths.filtering_statistics_filename, 'w') as filtering_stats:
596
597
            header = ['Run', 'Part', 'Input Reads', 'Valid Structure', 'Surviving Trimmomatic', 'Surviving polyA trim and complexity filter']
598
599
            if self.version == 'v1' or self.version == 'v2':
600
                structure_parts = ['W1_in_R2', 'empty_read',  'No_W1', 'No_polyT', 'BC1', 'BC2', 'Umi_error']
601
                header += ['W1 in R2', 'empty read',  'No W1 in R1', 'No polyT', 'BC1', 'BC2', 'UMI_contains_N']
602
            elif self.version == 'v3' or self.version == 'v3-miseq':
603
                structure_parts = ['Invalid_BC1', 'Invalid_BC2', 'UMI_contains_N']
604
                header += ['Invalid BC1', 'Invalid BC2', 'UMI_contains_N']
605
606
            trimmomatic_parts = ['dropped']
607
            header += ['Dropped by Trimmomatic']
608
609
            complexity_filter_parts = ['rejected_because_too_short', 'rejected_because_complexity_too_low']
610
            header += ['Too short after polyA trim', 'Read complexity too low']
611
612
            filtering_stats.write(','.join(header)+'\n')
613
614
            for part in self.parts:
615
                with open(part.filtering_metrics_filename) as f:
616
                    part_stats = yaml.load(f)
617
                    line = [part.run_name, part.part_name, part_stats['read_structure']['Total'], part_stats['read_structure']['Valid'], part_stats['trimmomatic']['output'], part_stats['complexity_filter']['output']]
618
                    line += [part_stats['read_structure'][k] if k in part_stats['read_structure'] else 0 for k in structure_parts]
619
                    line += [part_stats['trimmomatic'][k] if k in part_stats['trimmomatic'] else 0 for k in trimmomatic_parts]
620
                    line += [part_stats['complexity_filter'][k] if k in part_stats['complexity_filter'] else 0 for k in complexity_filter_parts]
621
                    line = [str(l) for l in line]
622
                    filtering_stats.write(','.join(line)+'\n')
623
624
        print_to_stderr("Created Library filtering summary:")
625
        print_to_stderr("  " + self.paths.filtering_statistics_filename)
626
 
627
        # Make the histogram figure
628
        if not make_histogram:
629
            return
630
631
        count_freq = defaultdict(int)
632
        for bc, count in self.barcode_counts.items():
633
            count_freq[count] += 1
634
635
        x = np.array(count_freq.keys())
636
        y = np.array(count_freq.values())
637
        w = x*y
638
639
        # need to use the non-interactive Agg backend
640
        import matplotlib
641
        matplotlib.use('Agg')
642
        from matplotlib import pyplot as plt
643
        fig = plt.figure()
644
        ax = fig.add_subplot(111)
645
        ax.hist(x, bins=np.logspace(0, 6, 50), weights=w, color='green')
646
        ax.set_xscale('log')
647
        ax.set_xlabel('Reads per barcode')
648
        ax.set_ylabel('#reads coming from bin')
649
        fig.savefig(self.paths.barcode_abundance_histogram_filename)
650
651
        print_to_stderr("Created Barcode Abundance Histogram at:")
652
        print_to_stderr("  " + self.paths.barcode_abundance_histogram_filename)
653
654
655
        fig = plt.figure()
656
        ax = fig.add_subplot(111)
657
        ax.hist(list(self.barcode_counts.values()), bins=np.logspace(2, 6, 50), color='green')
658
        ax.set_xlim((1, 10**6))
659
        ax.set_xscale('log')
660
        ax.set_xlabel('Reads per barcode')
661
        ax.set_ylabel('# of barcodes')
662
        fig.savefig(self.paths.barcode_abundance_by_barcode_filename)
663
        print_to_stderr("Created Barcode Abundance Histogram by barcodes at:")
664
        print_to_stderr("  " + self.paths.barcode_abundance_by_barcode_filename)
665
666
    def sort_reads_by_barcode(self, index=0):
667
        self.parts[index].sort_reads_by_barcode(self.abundant_barcodes)
668
669
    def get_reads_for_barcode(self, barcode, run_filter=[]):
670
        for part in self.parts:
671
            if (not run_filter) or (part.run_name in run_filter):
672
                for line in part.get_reads_for_barcode(barcode):
673
                    yield line
674
675
    def output_barcode_fastq(self, analysis_prefix='', min_reads=750, max_reads=10**10, total_workers=1, worker_index=0, run_filter=[]):
676
        if analysis_prefix:
677
            analysis_prefix = analysis_prefix + '.'
678
679
        output_dir_path = os.path.join(self.project.project_dir, self.name, 'barcode_fastq')
680
        self.project.project_check_dir(output_dir_path)
681
682
        sorted_barcode_names = self.sorted_barcode_names(min_reads=min_reads, max_reads=max_reads)
683
684
        # Identify which barcodes belong to this worker
685
        barcodes_for_this_worker = []
686
        i = worker_index
687
        while i < len(sorted_barcode_names):
688
            barcodes_for_this_worker.append(sorted_barcode_names[i])
689
            i += total_workers
690
691
        print_to_stderr("""[%s] This worker assigned %d out of %d total barcodes.""" % (self.name, len(barcodes_for_this_worker), len(sorted_barcode_names)))        
692
693
        for barcode in barcodes_for_this_worker:
694
            barcode_fastq_filename = analysis_prefix+'%s.%s.fastq' % (self.name, barcode)
695
            print_to_stderr("  "+barcode_fastq_filename)
696
            with open(os.path.join(output_dir_path, barcode_fastq_filename), 'w') as f:
697
                for line in self.get_reads_for_barcode(barcode, run_filter):
698
                    f.write(line)
699
700
    def quantify_expression(self, analysis_prefix='', max_reads=10**10, min_reads=750, min_counts=0, total_workers=1, worker_index=0, no_bam=False, run_filter=[]):
701
        if analysis_prefix:
702
            analysis_prefix = analysis_prefix + '.'
703
704
        sorted_barcode_names = self.sorted_barcode_names(min_reads=min_reads, max_reads=max_reads)
705
        #print_to_stderr("   min_reads: %d sorted_barcode_names counts: %d" % (min_reads, len(sorted_barcode_names)))
706
707
        # Identify which barcodes belong to this worker
708
        barcodes_for_this_worker = []
709
        i = worker_index
710
        while i < len(sorted_barcode_names):
711
            barcodes_for_this_worker.append(sorted_barcode_names[i])
712
            i += total_workers
713
714
        counts_output_filename = os.path.join(self.paths.quant_dir, '%sworker%d_%d.counts.tsv' % (analysis_prefix, worker_index, total_workers))
715
        ambig_counts_output_filename = os.path.join(self.paths.quant_dir, '%sworker%d_%d.ambig.counts.tsv' % (analysis_prefix, worker_index, total_workers))
716
        ambig_partners_output_filename = os.path.join(self.paths.quant_dir, '%sworker%d_%d.ambig.partners' % (analysis_prefix, worker_index, total_workers))
717
        metrics_output_filename = os.path.join(self.paths.quant_dir, '%sworker%d_%d.metrics.tsv' % (analysis_prefix, worker_index, total_workers))
718
        ignored_for_output_filename = counts_output_filename+'.ignored'
719
720
        merged_bam_filename = os.path.join(self.paths.quant_dir, '%sworker%d_%d.bam'% (analysis_prefix, worker_index, total_workers))
721
        merged_bam_index_filename = merged_bam_filename + '.bai'
722
723
        get_barcode_genomic_bam_filename = lambda bc: os.path.join(self.paths.quant_dir, '%s%s.genomic.sorted.bam' % (analysis_prefix, bc))
724
725
        # If we wanted BAM output, and the merged BAM and merged BAM index are present, then we are done
726
        if (not no_bam) and (os.path.isfile(merged_bam_filename) and os.path.isfile(merged_bam_index_filename)):
727
            print_to_stderr('Indexed, merged BAM file detected for this worker. Done.')
728
            return 
729
730
        # Otherwise, we have to check what we need to quantify
731
732
        
733
        """
734
        Function to determine which barcodes this quantification worker might have already quantified.
735
        This tries to handle interruption during any step of the process.
736
737
        The worker is assigned some list of barcodes L. For every barcode:
738
            - It could have been quantified
739
                - but have less than min_counts ---> so it got written to `ignored` file.
740
                - and quantification succeeded, meaning
741
                    1. there is a line (ending in \n) in the `metrics` file. 
742
                    2. there is a line (ending in \n) in the `quantification` file.
743
                    3. there (could) be a line (ending in \n) in the `ambiguous quantification` file.
744
                    4. there (could) be a line (ending in \n) in the `ambiguous quantification partners` file.
745
                        [If any line doesn't end in \n, then likely the output of that line was interrupted!]
746
                    5. (If BAM output is desired) There should be a sorted genomic BAM
747
                    6. (If BAM output is desired) There should be a sorted genomic BAM index
748
        """
749
        succesfully_previously_quantified = set()
750
        previously_ignored = set()
751
        header_written = False
752
753
        if os.path.isfile(counts_output_filename) and os.path.isfile(metrics_output_filename):
754
            # Load in list of ignored barcodes
755
            if os.path.isfile(ignored_for_output_filename):
756
                with open(ignored_for_output_filename, 'r') as f:
757
                    previously_ignored = set([line.rstrip().split('\t')[0] for line in f])
758
759
            # Load the metrics data into memory
760
            # (It should be fairly small, this is fast and safe)
761
            existing_metrics_data = {}
762
            with open(metrics_output_filename, 'r') as f:
763
                existing_metrics_data = dict((line.partition('\t')[0], line) for line in f if line[-1]=='\n')
764
765
766
            # Quantification data could be large, read it line by line and output it back for barcodes that have a matching metrics line.
767
            with open(counts_output_filename, 'r') as in_counts, \
768
                     open(counts_output_filename+'.tmp', 'w') as tmp_counts, \
769
                     open(metrics_output_filename+'.tmp', 'w') as tmp_metrics:
770
771
                for line in in_counts:
772
                    # The first worker is responsible for writing the header.
773
                    # Make sure we carry that over
774
                    if (not header_written) and (worker_index==0):
775
                        tmp_counts.write(line)
776
                        tmp_metrics.write(existing_metrics_data['Barcode'])
777
                        header_written = True
778
                        continue
779
780
                    # This line has incomplete output, skip it.
781
                    # (This can only happen with the last line)
782
                    if line[-1] != '\n':
783
                        continue
784
785
                    barcode = line.partition('\t')[0]
786
787
                    # Skip barcode if we don't have existing metrics data
788
                    if barcode not in existing_metrics_data:
789
                        continue
790
791
                    # Check that the required BAM files exist
792
                    barcode_genomic_bam_filename = get_barcode_genomic_bam_filename(barcode)
793
                    bam_files_required_and_present = no_bam or (os.path.isfile(barcode_genomic_bam_filename) and os.path.isfile(barcode_genomic_bam_filename+'.bai'))
794
                    if not bam_files_required_and_present:
795
                        continue
796
797
                    # This passed all the required checks, write the line to the temporary output files
798
                    tmp_counts.write(line)
799
                    tmp_metrics.write(existing_metrics_data[barcode])
800
                    succesfully_previously_quantified.add(barcode)
801
802
            shutil.move(counts_output_filename+'.tmp', counts_output_filename)
803
            shutil.move(metrics_output_filename+'.tmp', metrics_output_filename)
804
805
            # For any 'already quantified' barcode, make sure we also copy over the ambiguity data
806
            with open(ambig_counts_output_filename, 'r') as in_f, \
807
                 open(ambig_counts_output_filename+'.tmp', 'w') as tmp_f:
808
                 f_first_line = (worker_index == 0)
809
                 for line in in_f:
810
                    if f_first_line:
811
                        tmp_f.write(line)
812
                        f_first_line = False
813
                        continue
814
                    if (line.partition('\t')[0] in succesfully_previously_quantified) and (line[-1]=='\n'):
815
                        tmp_f.write(line)
816
            shutil.move(ambig_counts_output_filename+'.tmp', ambig_counts_output_filename)
817
818
            with open(ambig_partners_output_filename, 'r') as in_f, \
819
                 open(ambig_partners_output_filename+'.tmp', 'w') as tmp_f:
820
                 for line in in_f:
821
                    if (line.partition('\t')[0] in succesfully_previously_quantified) and (line[-1]=='\n'):
822
                        tmp_f.write(line)
823
            shutil.move(ambig_partners_output_filename+'.tmp', ambig_partners_output_filename)
824
825
        barcodes_to_quantify = [bc for bc in barcodes_for_this_worker if (bc not in succesfully_previously_quantified and bc not in previously_ignored)]
826
827
828
        print_to_stderr("""[%s] This worker assigned %d out of %d total barcodes.""" % (self.name, len(barcodes_for_this_worker), len(sorted_barcode_names)))
829
        if len(barcodes_for_this_worker)-len(barcodes_to_quantify) > 0:
830
            print_to_stderr("""    %d previously quantified, %d previously ignored, %d left for this run.""" % (len(succesfully_previously_quantified), len(previously_ignored), len(barcodes_to_quantify)))
831
        
832
833
834
        print_to_stderr(('{0:<14.12}'.format('Prefix') if analysis_prefix else '') + '{0:<14.12}{1:<9}'.format("Library", "Barcode"), False)
835
        print_to_stderr("{0:<8s}{1:<8s}{2:<10s}".format("Reads", "Counts", "Ambigs"))
836
        for barcode in barcodes_to_quantify:
837
            self.quantify_expression_for_barcode(barcode,
838
                counts_output_filename, metrics_output_filename,
839
                ambig_counts_output_filename, ambig_partners_output_filename,
840
                no_bam=no_bam, write_header=(not header_written) and (worker_index==0), analysis_prefix=analysis_prefix,
841
                min_counts = min_counts, run_filter=run_filter)
842
            header_written = True
843
        print_to_stderr("Per barcode quantification completed.")
844
845
        if no_bam:
846
            return
847
848
        #Gather list of barcodes with output from the metrics file
849
        genomic_bams = []
850
        with open(metrics_output_filename, 'r') as f:
851
            for line in f:
852
                bc = line.partition('\t')[0]
853
                if bc == 'Barcode': # This is the header line
854
                    continue
855
                genomic_bams.append(get_barcode_genomic_bam_filename(bc))
856
857
        print_to_stderr("Merging BAM output.")
858
        try:
859
            subprocess.check_output([self.project.paths.samtools, 'merge', '-f', merged_bam_filename]+genomic_bams, stderr=subprocess.STDOUT)
860
        except subprocess.CalledProcessError, err:
861
            print_to_stderr("   CMD: %s" % str(err.cmd)[:400])
862
            print_to_stderr("   stdout/stderr:")
863
            print_to_stderr(err.output)
864
            raise Exception(" === Error in samtools merge === ")
865
866
        print_to_stderr("Indexing merged BAM output.")
867
        try:
868
            subprocess.check_output([self.project.paths.samtools, 'index', merged_bam_filename], stderr=subprocess.STDOUT)
869
        except subprocess.CalledProcessError, err:
870
            print_to_stderr("   CMD: %s" % str(err.cmd)[:400])
871
            print_to_stderr("   stdout/stderr:")
872
            print_to_stderr(err.output)
873
            raise Exception(" === Error in samtools index === ")
874
875
        print(genomic_bams)
876
        for filename in genomic_bams:
877
            os.remove(filename)
878
            os.remove(filename + '.bai')
879
880
    def quantify_expression_for_barcode(self, barcode, counts_output_filename, metrics_output_filename,
881
            ambig_counts_output_filename, ambig_partners_output_filename,
882
            min_counts=0, analysis_prefix='', no_bam=False, write_header=False, run_filter=[]):
883
        print_to_stderr(('{0:<14.12}'.format(analysis_prefix) if analysis_prefix else '') + '{0:<14.12}{1:<9}'.format(self.name, barcode), False)
884
885
        unaligned_reads_output = os.path.join(self.paths.quant_dir, '%s%s.unaligned.fastq' % (analysis_prefix,barcode))
886
        aligned_bam = os.path.join(self.paths.quant_dir, '%s%s.aligned.bam' % (analysis_prefix,barcode))
887
888
        # Bowtie command
889
        bowtie_cmd = [self.project.paths.bowtie, self.project.paths.bowtie_index, '-q', '-',
890
            '-p', '1', '-a', '--best', '--strata', '--chunkmbs', '1000', '--norc', '--sam',
891
            '-shmem', #should sometimes reduce memory usage...?
892
            '-m', str(self.project.parameters['bowtie_arguments']['m']),
893
            '-n', str(self.project.parameters['bowtie_arguments']['n']),
894
            '-l', str(self.project.parameters['bowtie_arguments']['l']),
895
            '-e', str(self.project.parameters['bowtie_arguments']['e']),
896
            ]
897
        if self.project.parameters['output_arguments']['output_unaligned_reads_to_other_fastq']:
898
            bowtie_cmd += ['--un', unaligned_reads_output]
899
900
        # Quantification command
901
        script_dir = os.path.dirname(os.path.realpath(__file__))
902
        quant_cmd = [self.project.paths.python, self.project.paths.quantify_umifm_from_alignments_py,
903
            '-m', str(self.project.parameters['umi_quantification_arguments']['m']),
904
            '-u', str(self.project.parameters['umi_quantification_arguments']['u']),
905
            '-d', str(self.project.parameters['umi_quantification_arguments']['d']),
906
            '--min_non_polyA', str(self.project.parameters['umi_quantification_arguments']['min_non_polyA']),
907
            '--library', str(self.name),
908
            '--barcode', str(barcode),
909
            '--counts', counts_output_filename,
910
            '--metrics', metrics_output_filename,
911
            '--ambigs', ambig_counts_output_filename,
912
            '--ambig-partners', ambig_partners_output_filename,
913
            '--min-counts', str(min_counts),
914
        ]
915
        if not no_bam:
916
            quant_cmd += ['--bam', aligned_bam]
917
        if write_header:
918
            quant_cmd += ['--write-header']
919
920
        if self.project.parameters['umi_quantification_arguments']['split-ambigs']:
921
            quant_cmd.append('--split-ambig')
922
        if self.project.parameters['output_arguments']['filter_alignments_to_softmasked_regions']:
923
            quant_cmd += ['--soft-masked-regions', self.project.paths.bowtie_index + '.soft_masked_regions.pickle']
924
925
        # Spawn processes
926
927
        p1 = subprocess.Popen(bowtie_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
928
        p2 = subprocess.Popen(quant_cmd, stdin=p1.stdout, stderr=subprocess.PIPE)
929
        
930
                
931
        for line in self.get_reads_for_barcode(barcode, run_filter=run_filter):
932
            try:
933
                p1.stdin.write(line)
934
            except IOError as e:
935
                print_to_stderr('\n')
936
                print_to_stderr(p1.stderr.read())
937
                raise Exception('\n === Error on piping data to bowtie ===')
938
939
940
        p1.stdin.close()
941
942
        if p1.wait() != 0:
943
            print_to_stderr('\n')
944
            print_to_stderr(p1.stderr.read())
945
            raise Exception('\n === Error on bowtie ===')
946
947
        if p2.wait() != 0:
948
            print_to_stderr(p2.stderr.read())
949
            raise Exception('\n === Error on Quantification Script ===')
950
        print_to_stderr(p2.stderr.read(), False)
951
952
        if no_bam:
953
            # We are done here
954
            return False
955
956
        if not os.path.isfile(aligned_bam):
957
            raise Exception("\n === No aligned bam was output for barcode %s ===" % barcode)
958
959
        genomic_bam = os.path.join(self.paths.quant_dir, '%s%s.genomic.bam' % (analysis_prefix,barcode))
960
        sorted_bam = os.path.join(self.paths.quant_dir, '%s%s.genomic.sorted.bam' % (analysis_prefix,barcode))
961
        try:
962
            subprocess.check_output([self.project.paths.rsem_tbam2gbam, self.project.paths.bowtie_index, aligned_bam, genomic_bam], stderr=subprocess.STDOUT)
963
        except subprocess.CalledProcessError, err:
964
            print_to_stderr("   CMD: %s" % str(err.cmd)[:100])
965
            print_to_stderr("   stdout/stderr:")
966
            print_to_stderr(err.output)
967
            raise Exception(" === Error in rsem-tbam2gbam === ")
968
969
        try:
970
            subprocess.check_output([self.project.paths.samtools, 'sort', '-o', sorted_bam, genomic_bam], stderr=subprocess.STDOUT)
971
        except subprocess.CalledProcessError, err:
972
            print_to_stderr("   CMD: %s" % str(err.cmd)[:100])
973
            print_to_stderr("   stdout/stderr:")
974
            print_to_stderr(err.output)
975
            raise Exception(" === Error in samtools sort === ")
976
977
        try:
978
            subprocess.check_output([self.project.paths.samtools, 'index', sorted_bam], stderr=subprocess.STDOUT)
979
        except subprocess.CalledProcessError, err:
980
            print_to_stderr("   CMD: %s" % str(err.cmd)[:100])
981
            print_to_stderr("   stdout/stderr:")
982
            print_to_stderr(err.output)
983
            raise Exception(" === Error in samtools index === ")
984
985
        os.remove(aligned_bam)
986
        os.remove(genomic_bam)
987
988
989
        return True
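    # Editor's note (sketch): the two Popen calls above implement, per barcode,
    # roughly this shell pipeline (paths and parameter values hypothetical),
    # with the barcode's reads streamed into bowtie's stdin by
    # get_reads_for_barcode():
    #
    #   bowtie <index> -q - -p 1 -a --best --strata --chunkmbs 1000 --norc --sam ... \
    #     | python quantify_umifm_from_alignments.py --barcode <bc> --counts <counts.tsv> ...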
990
991
    def aggregate_counts(self, analysis_prefix='', process_ambiguity_data=False):
992
        if analysis_prefix:
993
            analysis_prefix = analysis_prefix + '.'
994
            quant_output_files = [fn[len(analysis_prefix):].split('.')[0] for fn in os.listdir(self.paths.quant_dir) if ('worker' in fn and fn[:len(analysis_prefix)]==analysis_prefix)]
995
        else:
996
            quant_output_files = [fn.split('.')[0] for fn in os.listdir(self.paths.quant_dir) if (fn[:6]=='worker')]
997
        
998
        worker_names = [w[6:] for w in quant_output_files]
999
        worker_indices = set(int(w.split('_')[0]) for w in worker_names)
1000
1001
        total_workers = set(int(w.split('_')[1]) for w in worker_names)
1002
        if len(total_workers) > 1:
1003
            raise Exception("""Quantification for library %s, prefix '%s' was run with different numbers of total_workers.""" % (self.name, analysis_prefix))
1004
        total_workers = list(total_workers)[0]
1005
1006
        missing_workers = []
1007
        for i in range(total_workers):
1008
            if i not in worker_indices:
1009
                missing_workers.append(i)
1010
        if missing_workers:
1011
            missing_workers = ','.join([str(i) for i in sorted(missing_workers)])
1012
            raise Exception("""Output from workers %s (total %d) is missing. """ % (missing_workers, total_workers))
1013
1014
        aggregated_counts_filename = os.path.join(self.project.project_dir, self.name, self.name+'.'+analysis_prefix+'counts.tsv')
1015
        aggregated_quant_metrics_filename = os.path.join(self.project.project_dir, self.name, self.name+'.'+analysis_prefix+'quant_metrics.tsv')
1016
        aggregated_ignored_filename = os.path.join(self.project.project_dir, self.name, self.name+'.'+analysis_prefix+'ignored_barcodes.txt')
1017
        aggregated_bam_output = os.path.join(self.project.project_dir, self.name, self.name+'.'+analysis_prefix+'bam')
1018
1019
        aggregated_ambig_counts_filename = os.path.join(self.project.project_dir, self.name, self.name+'.'+analysis_prefix+'ambig_counts.tsv')
1020
        aggregated_ambig_partners_filename = os.path.join(self.project.project_dir, self.name, self.name+'.'+analysis_prefix+'ambig_partners.tsv')
1021
1022
        agg_counts = open(aggregated_counts_filename, mode='w')
1023
        agg_metrics = open(aggregated_quant_metrics_filename, mode='w')
1024
        agg_ignored = open(aggregated_ignored_filename, mode='w')
1025
        if process_ambiguity_data:
1026
            agg_ambigs = open(aggregated_ambig_counts_filename, mode='w')
1027
            agg_ambig_partners = open(aggregated_ambig_partners_filename, mode='w')
1028
1029
        end_of_counts_header = 0
1030
        end_of_metrics_header = 0
1031
        end_of_ambigs_header = 0
1032
        print_to_stderr('  Concatenating output from all workers.')
1033
        for worker_index in range(total_workers):
1034
            counts_output_filename = os.path.join(self.paths.quant_dir, '%sworker%d_%d.counts.tsv' % (analysis_prefix, worker_index, total_workers))
1035
            ambig_counts_output_filename = os.path.join(self.paths.quant_dir, '%sworker%d_%d.ambig.counts.tsv' % (analysis_prefix, worker_index, total_workers))
1036
            ambig_partners_output_filename = os.path.join(self.paths.quant_dir, '%sworker%d_%d.ambig.partners' % (analysis_prefix, worker_index, total_workers))
1037
            metrics_output_filename = os.path.join(self.paths.quant_dir, '%sworker%d_%d.metrics.tsv' % (analysis_prefix, worker_index, total_workers))
1038
            ignored_for_output_filename = counts_output_filename+'.ignored'
1039
1040
            # Counts
1041
            with open(counts_output_filename, 'r') as f:
1042
                shutil.copyfileobj(f, agg_counts)
1043
1044
            # Metrics
1045
            with open(metrics_output_filename, 'r') as f:
1046
                shutil.copyfileobj(f, agg_metrics)
1047
1048
            # Ignored
1049
            if os.path.isfile(counts_output_filename+'.ignored'):
1050
                with open(counts_output_filename+'.ignored', 'r') as f:
1051
                    shutil.copyfileobj(f, agg_ignored)
1052
1053
            if process_ambiguity_data:
1054
                with open(ambig_counts_output_filename, 'r') as f:
1055
                    shutil.copyfileobj(f, agg_ambigs)
1056
1057
                with open(ambig_partners_output_filename, 'r') as f:
1058
                    shutil.copyfileobj(f, agg_ambig_partners)
1059
1060
        print_to_stderr('  GZIPping concatenated output.')
1061
        agg_counts.close()
1062
        subprocess.Popen(['gzip', '-f', aggregated_counts_filename]).wait()
1063
        agg_metrics.close()
1064
        subprocess.Popen(['gzip', '-f', aggregated_quant_metrics_filename]).wait()
1065
        print_to_stderr('Aggregation completed in %s.gz' % aggregated_counts_filename)
1066
1067
        if process_ambiguity_data:
1068
            agg_ambigs.close()
1069
            subprocess.Popen(['gzip', '-f', aggregated_ambig_counts_filename]).wait()
1070
            agg_ambig_partners.close()
1071
            subprocess.Popen(['gzip', '-f', aggregated_ambig_partners_filename]).wait()
1072
1073
        target_bams = [os.path.join(self.paths.quant_dir, '%sworker%d_%d.bam'% (analysis_prefix, worker_index, total_workers)) for worker_index in range(total_workers)]
1074
        target_bams = [t for t in target_bams if os.path.isfile(t)]
1075
        if target_bams:
1076
            print_to_stderr('  Merging BAM files.')
1077
            p1 = subprocess.Popen([self.project.paths.samtools, 'merge', '-f', aggregated_bam_output]+target_bams, stderr=subprocess.PIPE, stdout=subprocess.PIPE)
1078
            if p1.wait() == 0:
1079
                print_to_stderr('  Indexing merged BAM file.')
1080
                p2 = subprocess.Popen([self.project.paths.samtools, 'index', aggregated_bam_output], stderr=subprocess.PIPE, stdout=subprocess.PIPE)
1081
                if p2.wait() == 0:
1082
                    for filename in target_bams:
1083
                        os.remove(filename)
1084
                        os.remove(filename + '.bai')
1085
                else:
1086
                    print_to_stderr(" === Error in samtools index ===")
1087
                    print_to_stderr(p2.stderr.read())
1088
            else:
1089
                print_to_stderr(" === Error in samtools merge ===")
1090
                print_to_stderr(p1.stderr.read())
1091
1092
        # print_to_stderr('Deleting per-worker counts files.')
1093
        # for worker_index in range(total_workers):
1094
        #     counts_output_filename = os.path.join(self.paths.quant_dir, '%sworker%d_%d.counts.tsv' % (analysis_prefix, worker_index, total_workers))
1095
        #     os.remove(counts_output_filename)
1096
1097
        #     ambig_counts_output_filename = os.path.join(self.paths.quant_dir, '%sworker%d_%d.ambig.counts.tsv' % (analysis_prefix, worker_index, total_workers))
1098
        #     os.remove(ambig_counts_output_filename)
1099
1100
        #     ambig_partners_output_filename = os.path.join(self.paths.quant_dir, '%sworker%d_%d.ambig.partners' % (analysis_prefix, worker_index, total_workers))
1101
        #     os.remove(ambig_partners_output_filename)
1102
1103
        #     metrics_output_filename = os.path.join(self.paths.quant_dir, '%sworker%d_%d.metrics.tsv' % (analysis_prefix, worker_index, total_workers))
1104
        #     os.remove(metrics_output_filename)
1105
1106
        #     ignored_for_output_filename = counts_output_filename+'.ignored'
1107
        #     os.remove(ignored_for_output_filename)
1108
1109
1110
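
# The aggregation above leaves per-library, gzipped, tab-delimited tables (for
# example <library>.<prefix>counts.tsv.gz). The helper below is a minimal,
# purely illustrative sketch of streaming such a table back for downstream
# work; it assumes only that the file is a TSV and is not called anywhere in
# the pipeline.
def iter_aggregated_counts(path):
    """Yield each line of a gzipped, tab-delimited table as a list of fields."""
    with gzip.open(path) as handle:
        for line in handle:
            if not isinstance(line, str):  # gzip yields bytes under Python 3
                line = line.decode('utf-8')
            yield line.rstrip('\n').split('\t')
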
class LibrarySequencingPart():
1111
    def __init__(self, filtered_fastq_filename=None, project=None, run_name='', library_name='', part_name=''):
1112
        self.project = project
1113
        self.run_name = run_name
1114
        self.part_name = part_name
1115
        self.library_name = library_name
1116
        self.filtered_fastq_filename = filtered_fastq_filename
1117
        self.barcode_counts_pickle_filename = filtered_fastq_filename + '.counts.pickle'
1118
        self.filtering_metrics_filename = '.'.join(filtered_fastq_filename.split('.')[:-1]) + '.metrics.yaml'
1119
1120
        self.sorted_gzipped_fastq_filename = filtered_fastq_filename + '.sorted.fastq.gz'
1121
        self.sorted_gzipped_fastq_index_filename = filtered_fastq_filename + '.sorted.fastq.gz.index.pickle'
1122
1123
    @property
1124
    def is_filtered(self):
1125
        if not hasattr(self, '_is_filtered'):
1126
            self._is_filtered = os.path.exists(self.filtered_fastq_filename) and os.path.exists(self.barcode_counts_pickle_filename)
1127
        return self._is_filtered
1128
    
1129
    @property
1130
    def is_sorted(self):
1131
        if not hasattr(self, '_is_sorted'):
1132
            self._is_sorted = os.path.exists(self.sorted_gzipped_fastq_filename) and os.path.exists(self.sorted_gzipped_fastq_index_filename)
1133
        return self._is_sorted
1134
1135
    @property
1136
    def part_barcode_counts(self):
1137
        if not hasattr(self, '_part_barcode_counts'):
1138
            with open(self.barcode_counts_pickle_filename, 'r') as f:
1139
                self._part_barcode_counts = pickle.load(f)
1140
        return self._part_barcode_counts
1141
1142
    @property
1143
    def sorted_index(self):
1144
        if not hasattr(self, '_sorted_index'):
1145
            with open(self.sorted_gzipped_fastq_index_filename, 'r') as f:
1146
                self._sorted_index = pickle.load(f)
1147
        return self._sorted_index
1148
1149
    def contains_library_in_query(self, query_libraries):
1150
        return self.library_name in query_libraries
1151
1152
    def sort_reads_by_barcode(self, abundant_barcodes={}):
1153
        sorted_barcodes = [j for j,v in sorted(abundant_barcodes.items(), key=lambda i:-i[1][1])]
1154
        sorted_barcodes = [j for j in sorted_barcodes if j in self.part_barcode_counts]
1155
1156
        barcode_buffers = {}
1157
        barcode_gzippers = {}
1158
        for bc in sorted_barcodes + ['ignored']:
1159
            barcode_buffers[bc] = BytesIO()
1160
            barcode_gzippers[bc] = gzip.GzipFile(fileobj=barcode_buffers[bc], mode='wb')
1161
1162
        total_processed_reads = 0
1163
        total_ignored_reads = 0
1164
        bcs_with_data = set()
1165
        bcs_with_tmp_data = set()
1166
        barcode_tmp_filename = lambda bc: '%s.%s.tmp.gz' % (self.sorted_gzipped_fastq_filename, bc)
1167
1168
1169
        total_reads = sum(self.part_barcode_counts.values())
1170
        print_to_stderr('Sorting %d reads from %d barcodes above absolute minimum threshold.' % (total_reads, len(abundant_barcodes)))
1171
        with open(self.filtered_fastq_filename, 'r') as input_fastq:
1172
            for name, seq, qual in from_fastq(input_fastq):
1173
                total_processed_reads += 1
1174
                bc = name.split(':')[0]
1175
1176
                if total_processed_reads%1000000 == 0:
1177
                    print_to_stderr('Read in %.02f percent of all reads (%d)' % (100.*total_processed_reads/total_reads, total_processed_reads))
1178
                
1179
                if bc in abundant_barcodes:
1180
                    barcode_gzippers[bc].write(to_fastq(name, seq, qual))
1181
                    bcs_with_data.add(bc)
1182
                else:
1183
                    total_ignored_reads += 1
1184
                    barcode_gzippers['ignored'].write(to_fastq(name, seq, qual))
1185
                    bcs_with_data.add('ignored')
1186
1187
1188
        sorted_output_index = {}
1189
        with open(self.sorted_gzipped_fastq_filename, 'wb') as sorted_output:
1190
            for original_bc in sorted_barcodes + ['ignored']:
1191
                if original_bc != 'ignored':
1192
                    new_bc_name = abundant_barcodes[original_bc][0]
1193
                    barcode_reads_count = self.part_barcode_counts[original_bc]
1194
                else:
1195
                    new_bc_name = 'ignored'
1196
                    barcode_reads_count = total_ignored_reads
1197
1198
                start_pos = sorted_output.tell()
1199
                barcode_gzippers[original_bc].close()
1200
                if original_bc in bcs_with_data:
1201
                    barcode_buffers[original_bc].seek(0)
1202
                    shutil.copyfileobj(barcode_buffers[original_bc], sorted_output)
1203
                barcode_buffers[original_bc].close()
1204
                end_pos = sorted_output.tell()
1205
1206
                if end_pos > start_pos:
1207
                    sorted_output_index[new_bc_name] = (original_bc, start_pos, end_pos, end_pos-start_pos, barcode_reads_count)
1208
1209
        with open(self.sorted_gzipped_fastq_index_filename, 'w') as f:
1210
            pickle.dump(sorted_output_index, f)      
1211
1212
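    # Design note: each barcode's reads are compressed into their own in-memory
    # gzip member, and the members are concatenated into a single file. Because
    # a gzip stream may legally be a concatenation of independent members, the
    # byte range recorded in sorted_output_index for one barcode can later be
    # decompressed on its own. A tiny illustration (the FastQ record below is a
    # placeholder):
    #
    #   buf = BytesIO()
    #   gz = gzip.GzipFile(fileobj=buf, mode='wb'); gz.write(b'@r1\nACGT\n+\nIIII\n'); gz.close()
    #   member = buf.getvalue()
    #   two_members = member + member
    #   # Decompressing two_members yields both records; decompressing just
    #   # two_members[:len(member)] recovers only the first.
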
    def get_reads_for_barcode(self, barcode):
1213
        if barcode not in self.sorted_index:
1214
            return
1215
1216
        original_barcode, start_byte_offset, end_byte_offset, byte_length, barcode_reads = self.sorted_index[barcode]
1217
1218
        with open(self.sorted_gzipped_fastq_filename, 'rb') as sorted_output:
1219
            sorted_output.seek(start_byte_offset)
1220
            byte_buffer = BytesIO(sorted_output.read(byte_length))
1221
            ungzipper = gzip.GzipFile(fileobj=byte_buffer, mode='rb')
1222
            for line in ungzipper:
1223
                yield line
1224
1225
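    # Typical consumption of the sorted index (the filename and barcode string
    # below are placeholders): reads for a single corrected barcode can be
    # streamed straight out of the sorted, gzipped file without touching the
    # rest of it, e.g.
    #
    #   part = LibrarySequencingPart(filtered_fastq_filename='lib_part.fastq')
    #   for fastq_line in part.get_reads_for_barcode('bc1-bc2'):
    #       handle.write(fastq_line)   # handle: any writable file object
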
    @contextmanager
1226
    def trimmomatic_and_low_complexity_filter_process(self):
1227
        """
1228
        We start 3 processes that are connected with Unix pipes.
1229
1230
        Process 1 - Trimmomatic. Doesn't support stdin/stdout, so we instead use named pipes (FIFOs). It reads from FIFO1, and writes to FIFO2. 
1231
        Process 2 - In-line low-complexity filter (a Python script). It reads from FIFO2 (Trimmomatic output) and writes surviving reads to its stdout.
1232
        Process 3 - Indexer that counts the number of reads for every barcode. It reads Process 2's output from stdin, writes the reads to stdout (redirected to the filtered FastQ file) and writes the per-barcode counts as a pickle to stderr (redirected to the counts pickle).
1233
1234
        The caller writes raw reads into FIFO1; when the context exits, the Trimmomatic and complexity-filter metrics are collected and written to a YAML file.
1235
        """
1236
        filtered_dir = os.path.dirname(self.filtered_fastq_filename) #We will use the same directory for creating temporary FIFOs, assuming we have write access.
1237
        
1238
        self.filtering_statistics_counter = defaultdict(int)
1239
        with FIFO(dir=filtered_dir) as fifo2, open(self.filtered_fastq_filename, 'w') as filtered_fastq_file, open(self.filtered_fastq_filename+'.counts.pickle', 'w') as filtered_index_file:
1240
            
1241
            low_complexity_filter_cmd = [self.project.paths.python, self.project.paths.trim_polyA_and_filter_low_complexity_reads_py,
1242
                '-input', fifo2.filename, 
1243
                '--min-post-trim-length', self.project.parameters['trimmomatic_arguments']['MINLEN'],
1244
                '--max-low-complexity-fraction', str(self.project.parameters['low_complexity_filter_arguments']['max_low_complexity_fraction']),
1245
                ]
1246
            counter_cmd = [self.project.paths.python,  self.project.paths.count_barcode_distribution_py]
1247
1248
            p2 = subprocess.Popen(low_complexity_filter_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
1249
            p3 = subprocess.Popen(counter_cmd, stdin=p2.stdout, stdout=filtered_fastq_file, stderr=filtered_index_file)
1250
1251
            with FIFO(dir=filtered_dir) as fifo1:
1252
1253
                trimmomatic_cmd = [self.project.paths.java, '-Xmx500m', '-jar', self.project.paths.trimmomatic_jar,
1254
                        'SE', '-threads', "1", '-phred33', fifo1.filename, fifo2.filename]
1255
                for arg in self.project.parameters['trimmomatic_arguments']['argument_order']:
1256
                    val = self.project.parameters['trimmomatic_arguments'][arg]
1257
                    trimmomatic_cmd.append('%s:%s' % (arg, val))
1258
1259
                p1 = subprocess.Popen(trimmomatic_cmd, stderr=subprocess.PIPE)
1260
1261
                fifo1_filehandle = open(fifo1.filename, 'w')
1262
                yield fifo1_filehandle
1263
                fifo1_filehandle.close()
1264
                trimmomatic_stderr = p1.stderr.read().splitlines()
1265
                if trimmomatic_stderr[2] != 'TrimmomaticSE: Completed successfully':
1266
                    raise Exception('Trimmomatic did not complete successfully on %s' % self.filtered_fastq_filename)
1267
                trimmomatic_metrics = trimmomatic_stderr[1].split() 
1268
                # ['Input', 'Reads:', #READS, 'Surviving:', #SURVIVING, (%SURVIVING), 'Dropped:', #DROPPED, (%DROPPED)]
1269
                trimmomatic_metrics = {'input' : trimmomatic_metrics[2], 'output': trimmomatic_metrics[4], 'dropped': trimmomatic_metrics[7]}
1270
                p1.wait()
1271
1272
            complexity_filter_metrics = pickle.load(p2.stderr)
1273
            p2.wait()
1274
            p3.wait()
1275
1276
1277
        filtering_metrics = {
1278
            'read_structure' : dict(self.filtering_statistics_counter),
1279
            'trimmomatic' : trimmomatic_metrics,
1280
            'complexity_filter': complexity_filter_metrics,
1281
        }
1282
        with open(self.filtering_metrics_filename, 'w') as f:
1283
            yaml.dump(dict(filtering_metrics), f, default_flow_style=False)
1284
1285
1286
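
# A minimal sketch of the named-pipe (FIFO) wiring used by
# trimmomatic_and_low_complexity_filter_process above, reduced to one external
# tool. Here 'cat' merely stands in for a program (like Trimmomatic) that must
# be handed a file path rather than stdin; the paths are placeholder defaults
# and the function is illustrative only, never called by the pipeline. The
# consumer is launched before the FIFO is opened for writing, because opening a
# FIFO for writing blocks until a reader attaches.
def _fifo_pipeline_sketch(lines, fifo_path='reads.fifo', out_path='reads.out'):
    os.mkfifo(fifo_path)
    try:
        with open(out_path, 'wb') as out:
            # Consumer: handed the pipe's path, copies whatever appears there into out_path.
            consumer = subprocess.Popen(['cat', fifo_path], stdout=out)
            # Producer: stream the caller's data into the pipe, then let the consumer drain it.
            with open(fifo_path, 'w') as fifo:
                for line in lines:
                    fifo.write(line)
            consumer.wait()
    finally:
        os.remove(fifo_path)
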
class V1V2Filtering(LibrarySequencingPart):
1287
1288
    def __init__(self, bioread_filename=None, metaread_filename=None, *args, **kwargs):
1289
1290
        self.bioread_filename = bioread_filename
1291
        self.metaread_filename = metaread_filename
1292
        LibrarySequencingPart.__init__(self, *args, **kwargs)
1293
1294
1295
    def filter_and_count_reads(self):
1296
        """
1297
        Input the two raw FastQ files
1298
        Output: 
1299
            - A single fastQ file that uses the read name to store the barcoding information
1300
            - A pickle of the number of reads originating from each barcode 
1301
        """
1302
        # Relevant paths
1303
        r1_filename, r2_filename = self.metaread_filename, self.bioread_filename
1304
1305
        #Get barcode neighborhoods
1306
        bc1s = self.project.gel_barcode1_revcomp_list_neighborhood
1307
        bc2s = self.project.gel_barcode2_revcomp_list_neighborhood 
1308
1309
1310
        # This starts a Trimmomatic process, a low complexity filter process, and will 
1311
        # upon closing, start the barcode distribution counting process.
1312
        last_ping = time.time()
1313
        ping_every_n_reads = 1000000
1314
        ping_header = "{0:>12}{1:>16}{2:>12}{3:>10}{4:>10}{5:>10}{6:>10}{7:>10}{8:>10}{9:>10}"
1315
        ping_header = ping_header.format("Total Reads", "", "Valid Reads", "W1 in R2", "Empty", "No W1", "No polyT", "No BC1", "No BC2", "No UMI")
1316
        ping_template = "{total:12d}    {rate:5.1f} sec/M {Valid:12.1%}{W1_in_R2:10.1%}{empty_read:10.1%}{No_W1:10.1%}{No_polyT:10.1%}{BC1:10.1%}{BC2:10.1%}{Umi_error:10.1%}"
1317
        def print_ping_to_log(last_ping):
1318
            sec_per_mil = (time.time()-last_ping)/(ping_every_n_reads/10**6) if last_ping else 0.0
1319
            total = self.filtering_statistics_counter['Total']
1320
            if total > 0:
1321
                ping_format_data = {k: float(self.filtering_statistics_counter[k])/total for k in ['Valid', 'W1_in_R2', 'empty_read', 'No_W1', 'No_polyT', 'BC1', 'BC2', 'UMI_error']}
1322
                print_to_stderr(ping_template.format(total=total, rate=sec_per_mil, **ping_format_data))
1323
1324
1325
        with self.trimmomatic_and_low_complexity_filter_process() as trim_process:
1326
            #Iterate over the weaved reads
1327
            for r_name, r1_seq, r1_qual, r2_seq, r2_qual in self._weave_fastqs(r1_filename, r2_filename):
1328
                    
1329
                # Check if they should be kept
1330
                keep, result = self._process_reads(r1_seq, r2_seq, valid_bc1s=bc1s, valid_bc2s=bc2s)
1331
1332
                # Write the reads worth keeping
1333
                if keep:
1334
                    bc, umi = result
1335
                    trim_process.write(to_fastq_lines(bc, umi, r2_seq, r2_qual, r_name))
1336
                    self.filtering_statistics_counter['Valid'] += 1
1337
                else:
1338
                    self.filtering_statistics_counter[result] += 1
1339
1340
                # Track speed per M reads
1341
                self.filtering_statistics_counter['Total'] += 1
1342
                if self.filtering_statistics_counter['Total']%(10*ping_every_n_reads) == 1:
1343
                    print_to_stderr(ping_header)
1344
1345
                if self.filtering_statistics_counter['Total']%ping_every_n_reads == 0:
1346
                    print_ping_to_log(last_ping)
1347
                    last_ping = time.time()
1348
1349
            print_ping_to_log(False)
1350
1351
        print_to_stderr(self.filtering_statistics_counter)
1352
1353
    def _weave_fastqs(self, r1_fastq, r2_fastq):
1354
        """
1355
        Merge 2 FastQ files by returning paired reads for each.
1356
        Yields (name, R1_seq, R1_qual, R2_seq, R2_qual) for each read pair.
1357
        """
1358
1359
        is_gz_compressed = False
1360
        is_bz_compressed = False
1361
        if r1_fastq.split('.')[-1] == 'gz' and r2_fastq.split('.')[-1] == 'gz':
1362
            is_gz_compressed = True
1363
            
1364
        #Added bz2 support VS
1365
        if r1_fastq.split('.')[-1] == 'bz2' and r2_fastq.split('.')[-1] == 'bz2':
1366
            is_bz_compressed = True
1367
1368
        # Decompress Gzips using subprocesses because python gzip is incredibly slow.
1369
        if is_gz_compressed:    
1370
            r1_gunzip = subprocess.Popen("gzip --stdout -d %s" % (r1_fastq), shell=True, stdout=subprocess.PIPE)
1371
            r1_stream = r1_gunzip.stdout
1372
            r2_gunzip = subprocess.Popen("gzip --stdout -d %s" % (r2_fastq), shell=True, stdout=subprocess.PIPE)
1373
            r2_stream = r2_gunzip.stdout
1374
        elif is_bz_compressed:
1375
            r1_bunzip = subprocess.Popen("bzcat %s" % (r1_fastq), shell=True, stdout=subprocess.PIPE)
1376
            r1_stream = r1_bunzip.stdout
1377
            r2_bunzip = subprocess.Popen("bzcat %s" % (r2_fastq), shell=True, stdout=subprocess.PIPE)
1378
            r2_stream = r2_bunzip.stdout
1379
        else:
1380
            r1_stream = open(r1_fastq, 'r')
1381
            r2_stream = open(r2_fastq, 'r')
1382
1383
        while True:
1384
            #Read 4 lines from each FastQ
1385
            name = next(r1_stream).rstrip()[1:].split()[0] #Read name
1386
            r1_seq = next(r1_stream).rstrip() #Read seq
1387
            next(r1_stream) #+ line
1388
            r1_qual = next(r1_stream).rstrip() #Read qual
1389
            
1390
            next(r2_stream) #Read name
1391
            r2_seq = next(r2_stream).rstrip() #Read seq
1392
            next(r2_stream) #+ line
1393
            r2_qual = next(r2_stream).rstrip() #Read qual
1394
            
1395
            # changed to allow for empty reads (caused by adapter trimming)
1396
            if name:
1397
                yield name, r1_seq, r1_qual, r2_seq, r2_qual
1398
            else:
1399
            # if not r1_seq or not r2_seq:
1400
                break
1401
1402
        r1_stream.close()
1403
        r2_stream.close()
1404
1405
    def _process_reads(self, name, read, valid_bc1s={}, valid_bc2s={}):
1406
        """
1407
        Returns either:
1408
            True, (barcode, umi)
1409
                (if read passes filter)
1410
            False, name of filter that failed
1411
                (for stats collection)
1412
        
1413
        R1 anatomy: BBBBBBBB[BBB]WWWWWWWWWWWWWWWWWWWWWWCCCCCCCCUUUUUUTTTTTTTTTT______________
1414
            B = Barcode1, can be 8, 9, 10 or 11 bases long.
1415
            W = 'W1' sequence, specified below
1416
            C = Barcode2, always 8 bases
1417
            U = UMI, always 6 bases
1418
            T = Beginning of polyT tail.
1419
            _ = Either sequencing survives across the polyT tail, or signal starts dropping off
1420
                (and the bases can be anything, likely with poor quality)
1421
        """
1422
1423
        minimal_polyT_len_on_R1 = 7
1424
        hamming_threshold_for_W1_matching = 3
1425
1426
        w1 = "GAGTGATTGCTTGTGACGCCTT"
1427
        rev_w1 = "AAGGCGTCACAAGCAATCACTC" #Hard-code so we don't recompute on every one of millions of calls
1428
        # If R2 contains rev_W1, this is almost certainly empty library
1429
        if rev_w1 in read:
1430
            return False, 'W1_in_R2'
1431
1432
        # # With reads sufficiently long, we will often see a PolyA sequence in R2. 
1433
        # if polyA in read:
1434
        #     return False, 'PolyA_in_R2'
1435
1436
        # Check for polyT signal at 3' end.
1437
        # 44 is the minimum length of BC1+W1+BC2+UMI (8+22+8+6, with the shortest BC1)
1438
        #BC1: 8-11 bases
1439
        #W1 : 22 bases
1440
        #BC2: 8 bases
1441
        #UMI: 6 bases
1442
1443
        # check for empty reads (due to adapter trimming)
1444
        if not read:
1445
            return False, 'empty_read'
1446
        
1447
        #Check for W1 adapter
1448
        #Allow for up to hamming_threshold errors
1449
        if w1 in name:
1450
            w1_pos = name.find(w1)
1451
            if not 7 < w1_pos < 12:
1452
                return False, 'No_W1'
1453
        else:
1454
            #Try to find W1 adapter at start positions 8-11
1455
            #by checking hamming distance to W1.
1456
            for w1_pos in range(8, 12):
1457
                if string_hamming_distance(w1, name[w1_pos:w1_pos+22]) <= hamming_threshold_for_W1_matching:
1458
                    break
1459
            else:
1460
                return False, 'No_W1'
1461
                
1462
        bc2_pos=w1_pos+22
1463
        umi_pos=bc2_pos+8
1464
        polyTpos=umi_pos+6
1465
        expected_poly_t = name[polyTpos:polyTpos+minimal_polyT_len_on_R1]
1466
        if string_hamming_distance(expected_poly_t, 'T'*minimal_polyT_len_on_R1) > 3:
1467
            return False, 'No_polyT'
1468
            
1469
        bc1 = str(name[:w1_pos])
1470
        bc2 = str(name[bc2_pos:umi_pos])
1471
        umi = str(name[umi_pos:umi_pos+6])
1472
        
1473
        #Validate barcode (and try to correct when there is no ambiguity)
1474
        if valid_bc1s and valid_bc2s:
1475
            # Check if BC1 and BC2 can be mapped to expected barcodes
1476
            if bc1 in valid_bc1s:
1477
                # BC1 might be a neighboring BC, rather than a valid BC itself. 
1478
                bc1 = valid_bc1s[bc1]
1479
            else:
1480
                return False, 'BC1'
1481
            if bc2 in valid_bc2s:
1482
                bc2 = valid_bc2s[bc2]
1483
            else:
1484
                return False, 'BC2'
1485
            if 'N' in umi:
1486
                return False, 'UMI_error'
1487
        bc = '%s-%s'%(bc1, bc2)
1488
        return True, (bc, umi)
1489
1490
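
# Worked example of the V1/V2 R1 read layout parsed by
# V1V2Filtering._process_reads above. The BC1, BC2 and UMI strings are
# arbitrary placeholders chosen only to make the arithmetic concrete; this
# function is illustrative and is not called by the pipeline.
def _v1v2_read_anatomy_example():
    w1 = "GAGTGATTGCTTGTGACGCCTT"  # the 22-base W1 adapter used above
    r1 = 'ACACACAC' + w1 + 'GTGTGTGT' + 'AACGTA' + 'T' * 10
    #    BC1 (8 bp)   W1    BC2 (8 bp)  UMI (6 bp)  polyT
    w1_pos = r1.find(w1)     # 8, inside the allowed 8-11 window
    bc2_pos = w1_pos + 22    # 30
    umi_pos = bc2_pos + 8    # 38
    bc1 = r1[:w1_pos]
    bc2 = r1[bc2_pos:umi_pos]
    umi = r1[umi_pos:umi_pos + 6]
    assert (bc1, bc2, umi) == ('ACACACAC', 'GTGTGTGT', 'AACGTA')
    return '%s-%s' % (bc1, bc2), umi
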
class V3Demultiplexer():
1491
1492
    def __init__(self, library_indices, project=None, part_filename="", input_filename="", run_name="", part_name="", run_version_details="v3"):
1493
1494
        self.run_version_details = run_version_details
1495
        self.input_filename = input_filename
1496
        self.project = project
1497
        self.run_name = run_name
1498
        self.part_name = part_name
1499
        self.libraries = {}
1500
        for lib in library_indices:
1501
            lib_index = lib['library_index']
1502
            lib_name = lib['library_name']
1503
            library_part_filename = part_filename.format(library_name=lib_name, library_index=lib_index)
1504
            self.libraries[lib_index] = LibrarySequencingPart(filtered_fastq_filename=library_part_filename, project=project, run_name=run_name, library_name=lib_name, part_name=part_name)
1505
1506
    def _weave_fastqs(self, fastqs):
1507
        last_extension = [fn.split('.')[-1] for fn in fastqs]
1508
        if all(ext == 'gz' for ext in last_extension):
1509
            processes = [subprocess.Popen("gzip --stdout -d %s" % (fn), shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) for fn in fastqs]
1510
            streams = [r.stdout for r in processes]
1511
        elif all(ext == 'bz2' for ext in last_extension):
1512
            processes = [subprocess.Popen("bzcat %s" % (fn), shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) for fn in fastqs]
1513
            streams = [r.stdout for r in processes]
1514
        elif all(ext == 'fastq' for ext in last_extension):
1515
            streams = [open(fn, 'r') for fn in fastqs]
1516
        else:
1517
            raise("ERROR: Different files are compressed differently. Check input.")
1518
1519
        while True:
1520
            names = [next(s)[:-1].split()[0] for s in streams]
1521
            seqs = [next(s)[:-1] for s in streams]
1522
            blanks = [next(s)[:-1]  for s in streams]
1523
            quals = [next(s)[:-1]  for s in streams]
1524
1525
            assert all(name==names[0] for name in names)
1526
            yield names[0], seqs, quals
1527
1528
        for s in streams:
1529
            s.close()
1530
1531
1532
    def _process_reads(self, name, seqs, quals, valid_bc1s={}, valid_bc2s={}, valid_libs={}):
1533
        """
1534
        Returns either:
1535
            True, library index, (barcode, umi)
1536
                (if read passes filter)
1537
            False, library index (or the raw index read, if the index was invalid), name of filter that failed
1538
                (for stats collection)
1539
        """
1540
1541
        r1, r2, r3, r4 = seqs
1542
        if self.run_version_details=='v3-miseq':
1543
            r2 = rev_comp(r2)
1544
            r4 = rev_comp(r4)
1545
1546
        if r3 in valid_libs:
1547
            lib_index = valid_libs[r3]
1548
        else:
1549
            return False, r3, 'Invalid_library_index'
1550
1551
        if r2 in valid_bc1s:
1552
            bc1 = valid_bc1s[r2]
1553
        else:
1554
            return False, lib_index, 'Invalid_BC1'
1555
1556
        orig_bc2 = r4[:8]
1557
        umi = r4[8:8+6]
1558
        polyA = r4[8+6:]
1559
1560
        if orig_bc2 in valid_bc2s:
1561
            bc2 = valid_bc2s[orig_bc2]
1562
        else:
1563
            return False, lib_index, 'Invalid_BC2'
1564
1565
        if 'N' in umi:
1566
            return False, lib_index, 'UMI_contains_N'
1567
1568
        final_bc = '%s-%s' % (bc1, bc2)
1569
        return True, lib_index, (final_bc, umi)
1570
1571
1572
    def filter_and_count_reads(self):
1573
        # Prepare error corrected index sets
1574
        self.sequence_to_index_mapping = {}
1575
        libs = self.libraries.keys()
1576
        self.sequence_to_index_mapping = dict(zip(libs, libs))
1577
        index_neighborhoods = [set(seq_neighborhood(lib, 1)) for lib in libs]
1578
        for lib, clibs in zip(libs, index_neighborhoods):
1579
            # Quick check that error-correction maps to a single index
1580
            for clib in clibs:
1581
                if sum(clib in hood for hood in index_neighborhoods)==1:
1582
                    self.sequence_to_index_mapping[clib] = lib
1583
1584
        # Prepare error corrected barcode sets
1585
        error_corrected_barcodes = self.project.gel_barcode2_list_neighborhood
1586
        error_corrected_rev_compl_barcodes = self.project.gel_barcode2_revcomp_list_neighborhood
1587
1588
        # Open up our context managers
1589
        manager_order = [] # It's imperative to exit the managers in the reverse of the order in which we opened them!
1590
        trim_processes = {}
1591
        trim_processes_managers = {}
1592
1593
        for lib in self.libraries.keys():
1594
            manager_order.append(lib)
1595
            trim_processes_managers[lib] = self.libraries[lib].trimmomatic_and_low_complexity_filter_process()
1596
            trim_processes[lib] = trim_processes_managers[lib].__enter__()
1597
1598
        overall_filtering_statistics = defaultdict(int)
1599
1600
        # Paths for the 4 expected FastQs
1601
        input_fastqs = []
1602
        for r in ['R1', 'R2', 'R3', 'R4']:
1603
            input_fastqs.append(self.input_filename.format(read=r))
1604
1605
        last_ping = time.time()
1606
        ping_every_n_reads = 1000000
1607
        ping_header = "{0:>12}{1:>16}{2:>12}{3:>10}{4:>10}{5:>10}{6:>10}   |" + ''.join("{%d:>12.10}"%i for i in range(7,7+len(manager_order)))
1608
        ping_header = ping_header.format("Total Reads", "", "Valid Reads", "No index", "No BC1", "No BC2", "No UMI", *[self.libraries[k].library_name for k in manager_order])
1609
        ping_template = "{total:12d}    {rate:5.1f} sec/M {Valid:12.1%}{Invalid_library_index:10.1%}{Invalid_BC1:10.1%}{Invalid_BC2:10.1%}{UMI_contains_N:10.1%}   |{"+":>12.1%}{".join(manager_order)+":>12.1%}"
1610
        
1611
        def print_ping_to_log(last_ping):
1612
            sec_per_mil = (time.time() - last_ping)/(float(ping_every_n_reads)/10**6) if last_ping else 0
1613
            total = overall_filtering_statistics['Total']
1614
            ping_format_data = {k: float(overall_filtering_statistics[k])/total for k in ['Valid', 'Invalid_library_index', 'Invalid_BC1',  'Invalid_BC2', 'UMI_contains_N']}
1615
            if overall_filtering_statistics['Valid'] > 0:
1616
                ping_format_data.update({k: float(self.libraries[k].filtering_statistics_counter['Valid'])/overall_filtering_statistics['Valid'] for k in manager_order})
1617
            print_to_stderr(ping_template.format(total=total, rate=sec_per_mil, **ping_format_data))
1618
1619
        common__ = defaultdict(int)
1620
        print_to_stderr('Filtering %s, file %s' % (self.run_name, self.input_filename))
1621
        for r_name, seqs, quals in self._weave_fastqs(input_fastqs):
1622
1623
            # Python 3 compatibility in mind!
1624
            seqs = [s.decode('utf-8') for s in seqs]
1625
1626
            keep, lib_index, result = self._process_reads(r_name, seqs, quals,
1627
                                                    error_corrected_barcodes, error_corrected_rev_compl_barcodes, 
1628
                                                    self.sequence_to_index_mapping)
1629
            common__[seqs[1]] += 1
1630
            if keep:
1631
                bc, umi = result
1632
                bio_read = seqs[0]
1633
                bio_qual = quals[0]
1634
                trim_processes[lib_index].write(to_fastq_lines(bc, umi, bio_read, bio_qual, r_name[1:]))
1635
                self.libraries[lib_index].filtering_statistics_counter['Valid'] += 1
1636
                self.libraries[lib_index].filtering_statistics_counter['Total'] += 1
1637
                overall_filtering_statistics['Valid'] += 1
1638
1639
            else:
1640
                if result != 'Invalid_library_index':
1641
                    self.libraries[lib_index].filtering_statistics_counter[result] += 1
1642
                    self.libraries[lib_index].filtering_statistics_counter['Total'] += 1
1643
                overall_filtering_statistics[result] += 1
1644
1645
            # Track speed per M reads
1646
            overall_filtering_statistics['Total'] += 1
1647
1648
            if overall_filtering_statistics['Total']%(ping_every_n_reads*10)==1:
1649
                print_to_stderr(ping_header)
1650
            
1651
            if overall_filtering_statistics['Total']%ping_every_n_reads == 0:
1652
                print_ping_to_log(last_ping)
1653
                last_ping = time.time()
1654
                
1655
        print_ping_to_log(False)
1656
        # Close up the context managers
1657
        for lib in manager_order[::-1]:
1658
            trim_processes_managers[lib].__exit__(None, None, None)
1659
1660
    def contains_library_in_query(self, query_libraries):
1661
        for lib in self.libraries.values():
1662
            if lib.contains_library_in_query(query_libraries):
1663
                return True
1664
        return False
1665
1666
1667
1668
1669
1670
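
# Sketch of the one-mismatch library-index correction that
# V3Demultiplexer.filter_and_count_reads builds with seq_neighborhood(): every
# single-substitution variant that belongs unambiguously to exactly one index
# is mapped back to that index. The two 4 bp indices below are made up for
# illustration (real indices come from the project YAML); the function is not
# called by the pipeline.
def _library_index_correction_example():
    libs = ['ACGT', 'TTAG']
    mapping = dict(zip(libs, libs))
    neighborhoods = [set(seq_neighborhood(lib, 1)) for lib in libs]
    for lib, hood in zip(libs, neighborhoods):
        for variant in hood:
            # Keep a variant only if it cannot be confused with another index.
            if sum(variant in other for other in neighborhoods) == 1:
                mapping[variant] = lib
    assert mapping['ACGA'] == 'ACGT'  # one substitution away from ACGT only
    return mapping
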
if __name__=="__main__":
1671
1672
    import sys, argparse
1673
    parser = argparse.ArgumentParser()
1674
1675
    parser.add_argument('project', type=argparse.FileType('r'), help='Project YAML File.')
1676
    parser.add_argument('-l', '--libraries', type=str, help='[all] Library name(s) to work on. If blank, will iterate over all libraries in project.', nargs='?', default='')
1677
    parser.add_argument('-r', '--runs', type=str, help='[all] Run name(s) to work on. If blank, will iterate over all runs in project.', nargs='?', default='')
1678
    parser.add_argument('command', type=str, choices=['info', 'filter', 'identify_abundant_barcodes', 'sort', 'quantify', 'aggregate', 'build_index', 'get_reads', 'output_barcode_fastq'])
1679
    parser.add_argument('--total-workers', type=int, help='[all] Total workers that are working together. This takes precedence over barcodes-per-worker.', default=1)
1680
    parser.add_argument('--worker-index', type=int, help='[all] Index of current worker (the first worker should have index 0).', default=0)
1681
    parser.add_argument('--min-reads', type=int, help='[quantify] Minimum number of reads for barcode to be processed', nargs='?', default=750)
1682
    parser.add_argument('--max-reads', type=int, help='[quantify] Maximum number of reads for barcode to be processed', nargs='?', default=100000000)
1683
    parser.add_argument('--min-counts', type=int, help='[aggregate] Minimum number of UMIFM counts for barcode to be aggregated', nargs='?', default=0)
1684
    parser.add_argument('--analysis-prefix', type=str, help='[quantify/aggregate/convert_bam/merge_bam] Prefix for analysis files.', nargs='?', default='')
1685
    parser.add_argument('--no-bam', help='[quantify] Do not output alignments to bam file.', action='store_true')
1686
    parser.add_argument('--genome-fasta-gz', help='[build_index] Path to gzipped soft-masked genomic FASTA file.')
1687
    parser.add_argument('--ensembl-gtf-gz', help='[build_index] Path to gzipped ENSEMBL GTF file. ')
1688
    parser.add_argument('--mode', help='[build_index] Stringency mode for transcriptome build. [strict|all_ensembl]', default='strict')
1689
    parser.add_argument('--override-yaml', help="[all] Dictionary to update the project YAML with. [You don't need this.]", nargs='?', default='')
1690
1691
    args = parser.parse_args()
1692
    project = IndropsProject(args.project)
1693
    if args.override_yaml:
1694
        override = eval(args.override_yaml)
1695
        if 'paths' in override:
1696
            project.yaml['paths'].update(override['paths'])
1697
        if 'parameters' in override:
1698
            for k,v in override['parameters'].items():
1699
                project.yaml['parameters'][k].update(v)
1700
        if hasattr(project, '_paths'):
1701
            del project._paths
1702
        if hasattr(project, '_parameters'):
1703
            del project._parameters
1704
1705
    target_libraries = []
1706
    if args.libraries:
1707
        for lib in args.libraries.split(','):
1708
            assert lib in project.libraries
1709
            if lib not in target_libraries:
1710
                target_libraries.append(lib)
1711
    else:
1712
        target_libraries = project.libraries.keys()
1713
    lib_query = set(target_libraries)
1714
1715
    target_runs = []
1716
    if args.runs:
1717
        for run in args.runs.split(','):
1718
            assert run in project.runs
1719
            target_runs.append(run)
1720
    else:
1721
        target_runs = project.runs.keys()
1722
1723
    target_library_parts = []
1724
    for lib in target_libraries:
1725
        for pi, part in enumerate(project.libraries[lib].parts):
1726
            if part.run_name in target_runs:
1727
                target_library_parts.append((lib, pi))
1728
1729
    if args.command == 'info':
1730
        print_to_stderr('Project Name: ' + project.name)
1731
        target_run_parts = []
1732
        for run in target_runs:
1733
            target_run_parts += [part for part in project.runs[run] if part.contains_library_in_query(lib_query)]
1734
        print_to_stderr('Total library parts in search query: ' + str(len(target_run_parts)))
1735
1736
    elif args.command == 'filter':
1737
        target_run_parts = []
1738
        for run in target_runs:
1739
            target_run_parts += [part for part in project.runs[run] if part.contains_library_in_query(lib_query)]
1740
1741
        for part in worker_filter(target_run_parts, args.worker_index, args.total_workers):
1742
            print_to_stderr('Filtering run "%s", library "%s", part "%s"' % (part.run_name, part.library_name if hasattr(part, 'library_name') else 'N/A', part.part_name))
1743
            part.filter_and_count_reads()
1744
1745
    elif args.command == 'identify_abundant_barcodes':
1746
        for library in worker_filter(target_libraries, args.worker_index, args.total_workers):
1747
            project.libraries[library].identify_abundant_barcodes()
1748
1749
    elif args.command == 'sort':
1750
        for library, part_index in worker_filter(target_library_parts, args.worker_index, args.total_workers):
1751
            print_to_stderr('Sorting %s, part "%s"' % (library, project.libraries[library].parts[part_index].filtered_fastq_filename))
1752
            project.libraries[library].sort_reads_by_barcode(index=part_index)
1753
1754
    elif args.command == 'quantify':
1755
        for library in target_libraries:
1756
            project.libraries[library].quantify_expression(worker_index=args.worker_index, total_workers=args.total_workers,
1757
                    min_reads=args.min_reads, max_reads=args.max_reads, min_counts=args.min_counts,
1758
                    analysis_prefix=args.analysis_prefix,
1759
                    no_bam=args.no_bam, run_filter=target_runs)
1760
1761
            for part in project.libraries[library].parts:
1762
                if hasattr(part, '_sorted_index'):
1763
                    del part._sorted_index
1764
1765
    elif args.command == 'aggregate':
1766
        for library in target_libraries:
1767
            project.libraries[library].aggregate_counts(analysis_prefix=args.analysis_prefix)
1768
1769
    elif args.command == 'build_index':
1770
        project.build_transcriptome(args.genome_fasta_gz, args.ensembl_gtf_gz, mode=args.mode)
1771
1772
    elif args.command == 'get_reads':
1773
        for library in target_libraries:
1774
            sorted_barcode_names = project.libraries[library].sorted_barcode_names(min_reads=args.min_reads, max_reads=args.max_reads)
1775
            for bc in sorted_barcode_names:
1776
                for line in project.libraries[library].get_reads_for_barcode(bc, run_filter=target_runs):
1777
                    sys.stdout.write(line)
1778
1779
            for part in project.libraries[library].parts:
1780
                if hasattr(part, '_sorted_index'):
1781
                    del part._sorted_index
1782
1783
    elif args.command == 'output_barcode_fastq':
1784
        for library in target_libraries:
1785
            project.libraries[library].output_barcode_fastq(worker_index=args.worker_index, total_workers=args.total_workers,
1786
                    min_reads=args.min_reads, max_reads=args.max_reads, analysis_prefix=args.analysis_prefix, run_filter=target_runs)
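
# Example invocations (file names, prefixes and worker counts below are
# placeholders; a typical order is filter -> identify_abundant_barcodes ->
# sort -> quantify -> aggregate, and several steps can be split across workers
# with --total-workers / --worker-index):
#
#   python indrops.py project.yaml filter --total-workers 4 --worker-index 0
#   python indrops.py project.yaml identify_abundant_barcodes
#   python indrops.py project.yaml sort
#   python indrops.py project.yaml quantify --min-reads 750 --analysis-prefix myrun.
#   python indrops.py project.yaml aggregate --analysis-prefix myrun.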