a b/deeptrio/make_examples_test.py
1
# Copyright 2020 Google LLC.
2
#
3
# Redistribution and use in source and binary forms, with or without
4
# modification, are permitted provided that the following conditions
5
# are met:
6
#
7
# 1. Redistributions of source code must retain the above copyright notice,
8
#    this list of conditions and the following disclaimer.
9
#
10
# 2. Redistributions in binary form must reproduce the above copyright
11
#    notice, this list of conditions and the following disclaimer in the
12
#    documentation and/or other materials provided with the distribution.
13
#
14
# 3. Neither the name of the copyright holder nor the names of its
15
#    contributors may be used to endorse or promote products derived from this
16
#    software without specific prior written permission.
17
#
18
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
22
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28
# POSSIBILITY OF SUCH DAMAGE.
29
"""Tests for deeptrio.make_examples."""
30
31
import errno
32
import json
33
import platform
34
import sys
35
from unittest import mock
36
37
from absl import flags
38
from absl import logging
39
from absl.testing import absltest
40
from absl.testing import flagsaver
41
from absl.testing import parameterized
42
from etils import epath
43
import numpy as np
44
45
from deeptrio import make_examples
46
from deeptrio import testdata
47
from deepvariant import dv_constants
48
from deepvariant import dv_utils
49
from deepvariant import make_examples_core
50
from deepvariant.protos import deepvariant_pb2
51
from tensorflow.python.platform import gfile
52
from third_party.nucleus.io import fasta
53
from third_party.nucleus.io import sharded_file_utils
54
from third_party.nucleus.io import tfrecord
55
from third_party.nucleus.io import vcf
56
from third_party.nucleus.protos import reference_pb2
57
from third_party.nucleus.protos import variants_pb2
58
from third_party.nucleus.testing import test_utils
59
from third_party.nucleus.util import ranges
60
from third_party.nucleus.util import variant_utils
61
from third_party.nucleus.util import variantcall_utils
62
from third_party.nucleus.util import vcf_constants
63
64
FLAGS = flags.FLAGS

# Feature key -> dv_utils decoder used by decode_example(). A tf.Example
# containing any feature key missing from this table is rejected there.
_EXAMPLE_DECODERS = dict([
    ('locus', dv_utils.example_locus),
    ('alt_allele_indices/encoded', dv_utils.example_alt_alleles_indices),
    ('image/encoded', dv_utils.example_encoded_image),
    ('variant/encoded', dv_utils.example_variant),
    ('variant_type', dv_utils.example_variant_type),
    ('label', dv_utils.example_label),
    ('image/shape', dv_utils.example_image_shape),
    ('sequencing_type', dv_utils.example_sequencing_type),
    ('denovo_label', dv_utils.example_denovo_label),
])
78
79
80
def decode_example(example):
  """Turns a DeepVariant tf.Example into a dict of decoded Python values.

  Each feature key present in `example` is looked up in _EXAMPLE_DECODERS, and
  the matching decoder is called with the whole example.

  Args:
    example: tf.Example proto to decode.

  Returns:
    A dict keyed by the example's feature names. Each value is whatever the
    corresponding decoder returns (protos, lists, etc.).

  Raises:
    KeyError: If example contains a feature without a known decoder.
  """
  decoded = {}
  for feature_name in example.features.feature:
    decoder = _EXAMPLE_DECODERS.get(feature_name)
    if decoder is None:
      raise KeyError('Unexpected example key', feature_name)
    decoded[feature_name] = decoder(example)
  return decoded
100
101
102
def setUpModule():
  """Module fixture: quiets logging, then initializes the shared test data."""
  # Only FATAL messages are emitted so test output stays readable.
  logging.set_verbosity(logging.FATAL)
  testdata.init()
105
106
107
def _make_contigs(specs):
  """Builds reference_pb2.ContigInfo protos from compact tuple specs.

  Args:
    specs: A list of 2- or 3-tuples, all of the same length. A 2-tuple is
      (name, length_in_basepairs), and its pos_in_fasta becomes its index in
      `specs`. A 3-tuple is (name, length_in_basepairs, pos_in_fasta).

  Returns:
    A list of ContigInfo protos, one for each spec in specs.
  """
  if specs and len(specs[0]) == 3:
    triples = specs
  else:
    # Lazily attach the list index as pos_in_fasta.
    triples = (
        (name, length, index) for index, (name, length) in enumerate(specs)
    )
  return [
      reference_pb2.ContigInfo(name=name, n_bases=length, pos_in_fasta=pos)
      for name, length, pos in triples
  ]
129
130
131
def _from_literals_list(literals, contig_map=None):
  """Parses region literal strings into a list of Range objects.

  Args:
    literals: Region literal strings, e.g. '20:10,000,000-10,010,000'.
    contig_map: Optional contig map, forwarded to ranges.parse_literals.

  Returns:
    The result of ranges.parse_literals for these literals.
  """
  parsed_ranges = ranges.parse_literals(literals, contig_map)
  return parsed_ranges
134
135
136
def _from_literals(literals, contig_map=None):
  """Builds a ranges.RangeSet covering the given region literals.

  Args:
    literals: Region literal strings, e.g. '20:10,000,000-10,010,000'.
    contig_map: Optional contig map, forwarded to RangeSet.from_regions.

  Returns:
    The RangeSet produced by ranges.RangeSet.from_regions.
  """
  range_set = ranges.RangeSet.from_regions(literals, contig_map)
  return range_set
139
140
141
def _sharded(basename, num_shards=None):
142
  if num_shards:
143
    return basename + '@' + str(num_shards)
144
  else:
145
    return basename
146
147
148
class MakeExamplesEnd2EndTest(parameterized.TestCase):
149
150
  # Golden sets are created with
  # learning/genomics/internal/create_golden_deep_trio.sh
  @parameterized.parameters(
      # All tests are run with fast_pass_aligner enabled. There are no golden
      # sets for the ssw realigner.
      dict(mode='calling', num_shards=0),
      dict(mode='calling', num_shards=3),
      dict(mode='candidate_sweep', num_shards=0),
      dict(mode='candidate_sweep', num_shards=3),
      dict(
          mode='training', num_shards=0, labeler_algorithm='haplotype_labeler'
      ),
      dict(
          mode='training', num_shards=3, labeler_algorithm='haplotype_labeler'
      ),
      dict(
          mode='training', num_shards=0, labeler_algorithm='positional_labeler'
      ),
      dict(
          mode='training', num_shards=3, labeler_algorithm='positional_labeler'
      ),
  )
  @flagsaver.flagsaver
  def test_make_examples_end2end(
      self, mode, num_shards, labeler_algorithm=None, use_fast_pass_aligner=True
  ):
    """Runs trio make_examples end to end and compares with golden outputs.

    Args:
      mode: One of 'calling', 'training' or 'candidate_sweep'.
      num_shards: Number of shards; 0 runs a single unsharded task.
      labeler_algorithm: If not None, assigned to FLAGS.labeler_algorithm.
      use_fast_pass_aligner: Assigned to FLAGS.use_fast_pass_aligner.
    """
    self.assertIn(mode, {'calling', 'training', 'candidate_sweep'})
    region = ranges.parse_literal('20:10,000,000-10,010,000')
    FLAGS.write_run_info = True
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'
    FLAGS.candidates = test_utils.test_tmpfile(
        _sharded('vsc.tfrecord', num_shards)
    )
    FLAGS.examples = test_utils.test_tmpfile(
        _sharded('examples.tfrecord', num_shards)
    )
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    child_examples = test_utils.test_tmpfile(
        _sharded('examples_child.tfrecord', num_shards)
    )
    if mode == 'candidate_sweep':
      # Computed once: the same path is given to make_examples and read back
      # below to compare against the golden candidate positions.
      candidate_positions = test_utils.test_tmpfile(
          _sharded('candidate_positions', num_shards)
      )
      FLAGS.candidate_positions = candidate_positions
      golden_candidate_positions = _sharded(
          testdata.GOLDEN_CANDIDATE_POSITIONS, num_shards
      )
    FLAGS.regions = [ranges.to_literal(region)]
    FLAGS.partition_size = 1000
    FLAGS.mode = mode
    FLAGS.gvcf_gq_binsize = 5
    FLAGS.use_fast_pass_aligner = use_fast_pass_aligner
    if labeler_algorithm is not None:
      FLAGS.labeler_algorithm = labeler_algorithm

    if mode == 'calling':
      FLAGS.gvcf = test_utils.test_tmpfile(
          _sharded('gvcf.tfrecord', num_shards)
      )
      child_gvcf = test_utils.test_tmpfile(
          _sharded('gvcf_child.tfrecord', num_shards)
      )
      child_candidates = test_utils.test_tmpfile(
          _sharded('vsc_child.tfrecord', num_shards)
      )
    else:
      FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
      FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED
      child_candidates = test_utils.test_tmpfile(
          _sharded('vsc.tfrecord', num_shards)
      )

    for task_id in range(max(num_shards, 1)):
      FLAGS.task = task_id
      options = make_examples.default_options(add_flags=True)
      make_examples_core.make_examples_runner(options)

      # Check that our run_info proto contains the basic fields we'd expect:
      # (a) our options are written to the run_info.options field.
      run_info = make_examples_core.read_make_examples_run_info(
          options.run_info_filename
      )
      self.assertEqual(run_info.options, options)
      # (b) run_info.resource_metrics is present and contains our hostname.
      self.assertTrue(run_info.HasField('resource_metrics'))
      self.assertEqual(run_info.resource_metrics.host_name, platform.node())

      # For candidate_sweep mode we verify that candidate positions match
      # golden set exactly.
      if mode == 'candidate_sweep':
        _, candidates_path = sharded_file_utils.resolve_filespecs(
            task_id, candidate_positions
        )
        _, gold_candidates_path = sharded_file_utils.resolve_filespecs(
            task_id, golden_candidate_positions
        )
        self.verify_candidate_positions(candidates_path, gold_candidates_path)

    # In candidate_sweep mode the test stops here.
    if mode == 'candidate_sweep':
      return

    # Test that our candidates are reasonable, calling specific helper functions
    # to check lots of properties of the output.
    candidates = sorted(
        tfrecord.read_tfrecords(
            child_candidates,
            proto=deepvariant_pb2.DeepVariantCall,
            compression_type='GZIP',
        ),
        key=lambda c: variant_utils.variant_range_tuple(c.variant),
    )
    self.verify_deepvariant_calls(candidates, options)
    self.verify_variants(
        [call.variant for call in candidates], region, options, is_gvcf=False
    )

    # Verify that the variants in the examples are all good. Only 'calling'
    # and 'training' can reach this point, so if/else always binds examples.
    if mode == 'calling':
      examples = self.verify_examples(
          child_examples,
          region,
          options,
          verify_labels=False,
          examples_filename=FLAGS.examples,
      )
    else:
      examples = self.verify_examples(
          FLAGS.examples, region, options, verify_labels=True
      )
    example_variants = [dv_utils.example_variant(ex) for ex in examples]
    self.verify_variants(example_variants, region, options, is_gvcf=False)

    # Verify the integrity of the examples and then check that they match our
    # golden labeled examples. Note we expect the order for both training and
    # calling modes to produce deterministic order because we fix the random
    # seed.
    if mode == 'calling':
      golden_file = _sharded(testdata.GOLDEN_CALLING_EXAMPLES, num_shards)
    else:
      golden_file = _sharded(testdata.GOLDEN_TRAINING_EXAMPLES, num_shards)
    self.assertDeepVariantExamplesEqual(
        examples,
        list(tfrecord.read_tfrecords(golden_file, compression_type='GZIP')),
    )

    if mode == 'calling':
      nist_reader = vcf.VcfReader(testdata.TRUTH_VARIANTS_VCF)
      nist_variants = list(nist_reader.query(region))
      self.verify_nist_concordance(example_variants, nist_variants)

      # Check the quality of our generated gvcf file.
      gvcfs = variant_utils.sorted_variants(
          tfrecord.read_tfrecords(
              child_gvcf, proto=variants_pb2.Variant, compression_type='GZIP'
          )
      )
      self.verify_variants(gvcfs, region, options, is_gvcf=True)
      self.verify_contiguity(gvcfs, region)
      gvcf_golden_file = _sharded(
          testdata.GOLDEN_POSTPROCESS_GVCF_INPUT, num_shards
      )
      expected_gvcfs = list(
          tfrecord.read_tfrecords(
              gvcf_golden_file,
              proto=variants_pb2.Variant,
              compression_type='GZIP',
          )
      )
      self.assertCountEqual(gvcfs, expected_gvcfs)

    if (
        mode == 'training'
        and num_shards == 0
        and labeler_algorithm != 'positional_labeler'
    ):
      # The positional labeler doesn't track metrics, so don't try to read them
      # in when that's the mode. With num_shards == 0 there was exactly one
      # task, so run_info is the run_info of the whole run.
      self.assertEqual(
          make_examples_core.read_make_examples_run_info(
              testdata.GOLDEN_MAKE_EXAMPLES_RUN_INFO
          ).labeling_metrics,
          run_info.labeling_metrics,
      )
347
348
  @parameterized.parameters(
      dict(
          denovo_test=False,
          expected_denovo_variants=0,
      ),
      dict(
          denovo_test=True,
          expected_denovo_variants=3,
      ),
  )
  @flagsaver.flagsaver
  def test_make_examples_ont_end2end(
      self,
      denovo_test: bool,
      expected_denovo_variants: int,
  ):
    """Test end to end for long ONT reads with phasing enabled.

    This test runs ONT end to end and compares the output with the golden
    output. It was introduced because previously, in training mode, the
    non-training sample would not be phased, so it now checks that all of the
    training examples are phased correctly.

    Args:
      denovo_test: If true, denovo regions are set and the output is compared
        with the denovo golden set, including run info denovo statistics.
      expected_denovo_variants: Total number of denovo examples expected.
    """
    region = ranges.parse_literal('chr20:5050000-5075000')
    # Run info is always written; the denovo branch below reads it back.
    FLAGS.write_run_info = True
    FLAGS.ref = testdata.GRCH38_CHR0_FASTA
    FLAGS.reads = testdata.ONT_HG002_BAM
    FLAGS.reads_parent1 = testdata.ONT_HG003_BAM
    FLAGS.reads_parent2 = testdata.ONT_HG004_BAM
    FLAGS.confident_regions = testdata.HG002_HIGH_CONFIDENCE_BED
    FLAGS.truth_variants = testdata.HG002_HIGH_CONFIDENCE_VCF
    FLAGS.sample_name = 'HG002'
    FLAGS.sample_name_to_train = 'HG002'
    FLAGS.sample_name_parent1 = 'HG003'
    FLAGS.sample_name_parent2 = 'HG004'
    FLAGS.alt_aligned_pileup = 'diff_channels'
    FLAGS.min_mapping_quality = 1
    FLAGS.mode = 'training'
    FLAGS.parse_sam_aux_fields = True
    FLAGS.partition_size = 25000
    FLAGS.phase_reads = True
    FLAGS.pileup_image_height_child = 100
    FLAGS.pileup_image_height_parent = 100
    FLAGS.pileup_image_width = 199
    FLAGS.realign_reads = False
    FLAGS.skip_parent_calling = True
    FLAGS.sort_by_haplotypes = True
    FLAGS.track_ref_reads = True
    FLAGS.vsc_min_fraction_indels = 0.12
    FLAGS.vsc_min_fraction_snps = 0.1
    num_shards = 0
    FLAGS.examples = test_utils.test_tmpfile(
        _sharded('examples.tfrecord', num_shards)
    )
    FLAGS.channel_list = ','.join(
        dv_constants.PILEUP_DEFAULT_CHANNELS + ['haplotype']
    )
    FLAGS.regions = [ranges.to_literal(region)]
    if denovo_test:
      # If denovo test is enabled, then set the parameters for denovo testing.
      golden_file = _sharded(
          testdata.GOLDEN_ONT_DENOVO_MAKE_EXAMPLES_OUTPUT, num_shards
      )
      FLAGS.denovo_regions = testdata.HG002_DENOVO_BED
    else:
      golden_file = _sharded(
          testdata.GOLDEN_ONT_MAKE_EXAMPLES_OUTPUT, num_shards
      )
      FLAGS.denovo_regions = None

    for task_id in range(max(num_shards, 1)):
      FLAGS.task = task_id
      options = make_examples.default_options(add_flags=True)
      make_examples_core.make_examples_runner(options)

      examples = self.verify_examples(
          FLAGS.examples, region, options, verify_labels=True
      )

      self.assertDeepVariantExamplesEqual(
          examples,
          list(tfrecord.read_tfrecords(golden_file, compression_type='GZIP')),
      )
      if denovo_test:
        # Check total number of denovo examples.
        total_denovo = sum(
            1 for example in examples if dv_utils.example_denovo_label(example)
        )
        self.assertEqual(
            total_denovo,
            expected_denovo_variants,
            msg='ONT denovo golden test: denovo variants count.',
        )
        # Compare the run info statistics with the golden run info.
        runinfo = make_examples_core.read_make_examples_run_info(
            FLAGS.examples + '.run_info.pbtxt'
        )
        golden_runinfo = make_examples_core.read_make_examples_run_info(
            testdata.GOLDEN_ONT_DENOVO_MAKE_EXAMPLES_OUTPUT + '.run_info.pbtxt'
        )
        self.assertEqual(
            runinfo.stats.num_examples,
            golden_runinfo.stats.num_examples,
            msg='ONT denovo golden test: Run info comparison num_examples.',
        )
        self.assertEqual(
            runinfo.stats.num_denovo,
            golden_runinfo.stats.num_denovo,
            msg='ONT denovo golden test: Run info comparison num_denovo.',
        )
        self.assertEqual(
            runinfo.stats.num_nondenovo,
            golden_runinfo.stats.num_nondenovo,
            msg='ONT denovo golden test: Run info comparison num_nondenovo.',
        )
469
470
  # Golden sets are created with learning/genomics/internal/create_golden.sh
  @flagsaver.flagsaver
  def test_make_examples_training_end2end_with_customized_classes_labeler(self):
    """Training run with customized_classes_labeler matches its golden set."""
    # Labeler setup: three classes, taken from the 'type' INFO field of the
    # typed truth VCF configured further down.
    FLAGS.labeler_algorithm = 'customized_classes_labeler'
    FLAGS.customized_classes_labeler_classes_list = 'ref,class1,class2'
    FLAGS.customized_classes_labeler_info_field_name = 'type'

    target_region = ranges.parse_literal('20:10,000,000-10,004,000')
    FLAGS.regions = [ranges.to_literal(target_region)]

    # Trio inputs.
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'

    # Outputs and run settings.
    FLAGS.candidates = test_utils.test_tmpfile(_sharded('vsc.tfrecord'))
    FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord'))
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    FLAGS.partition_size = 1000
    FLAGS.mode = 'training'
    FLAGS.gvcf_gq_binsize = 5
    FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF_WITH_TYPES
    FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED

    run_options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(run_options)

    golden_path = _sharded(testdata.CUSTOMIZED_CLASSES_GOLDEN_TRAINING_EXAMPLES)
    # Verify that the variants in the examples are all good.
    produced = self.verify_examples(
        FLAGS.examples, target_region, run_options, verify_labels=True
    )
    expected = list(tfrecord.read_tfrecords(golden_path, compression_type='GZIP'))
    self.assertDeepVariantExamplesEqual(produced, expected)
505
506
  # Golden sets are created with learning/genomics/internal/create_golden.sh
  @flagsaver.flagsaver
  def test_make_examples_training_end2end_with_alt_aligned_pileup(self):
    """Training run with alt-aligned pileups: golden match and image shape."""
    target_region = ranges.parse_literal('20:10,000,000-10,010,000')
    FLAGS.regions = [ranges.to_literal(target_region)]

    # Trio inputs.
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'

    # Outputs and run settings.
    FLAGS.candidates = test_utils.test_tmpfile(_sharded('vsc.tfrecord'))
    FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord'))
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_DEFAULT_CHANNELS)
    FLAGS.partition_size = 1000
    FLAGS.mode = 'training'
    FLAGS.gvcf_gq_binsize = 5

    # Alt-aligned pileup settings under test; child and parent pileups use
    # different heights.
    FLAGS.alt_aligned_pileup = 'diff_channels'
    FLAGS.pileup_image_height_child = 60
    FLAGS.pileup_image_height_parent = 40
    FLAGS.pileup_image_width = 199

    FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
    FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED

    run_options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(run_options)

    golden_path = _sharded(testdata.ALT_ALIGNED_PILEUP_GOLDEN_TRAINING_EXAMPLES)
    # Verify that the variants in the examples are all good.
    produced = self.verify_examples(
        FLAGS.examples, target_region, run_options, verify_labels=True
    )
    expected = list(tfrecord.read_tfrecords(golden_path, compression_type='GZIP'))
    self.assertDeepVariantExamplesEqual(produced, expected)

    # Expected pileup image: 8 channels, and a height of
    # 60 (child) + 2 * 40 (parents) = 140 rows.
    first_shape = decode_example(produced[0])['image/shape']
    self.assertEqual(first_shape, [140, 199, 8])
548
549
  @flagsaver.flagsaver
  def test_make_examples_compare_realignment_modes(self):
    """Checks that enabling joint realignment changes the calling examples."""

    def _run_with_realignment_mode(enable_joint_realignment, name):
      """Runs calling-mode make_examples once and returns verified examples.

      Args:
        enable_joint_realignment: Value for FLAGS.enable_joint_realignment.
        name: Prefix for this run's temporary output files, so the two runs
          do not overwrite each other's outputs.

      Returns:
        The child examples returned by self.verify_examples.
      """
      FLAGS.enable_joint_realignment = enable_joint_realignment
      region = ranges.parse_literal('20:10,000,000-10,010,000')
      FLAGS.ref = testdata.CHR20_FASTA
      FLAGS.reads = testdata.HG001_CHR20_BAM
      FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
      FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
      FLAGS.sample_name = 'child'
      FLAGS.sample_name_to_train = 'child'
      FLAGS.sample_name_parent1 = 'parent1'
      FLAGS.sample_name_parent2 = 'parent2'
      FLAGS.candidates = test_utils.test_tmpfile(f'{name}.vsc.tfrecord')
      FLAGS.examples = test_utils.test_tmpfile(f'{name}.examples.tfrecord')
      FLAGS.channel_list = ','.join(
          dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE
      )
      child_examples = test_utils.test_tmpfile(
          f'{name}_child.examples.tfrecord'
      )
      FLAGS.regions = [ranges.to_literal(region)]
      FLAGS.partition_size = 1000
      FLAGS.mode = 'calling'
      FLAGS.gvcf = test_utils.test_tmpfile(f'{name}.gvcf.tfrecord')
      options = make_examples.default_options(add_flags=True)
      make_examples_core.make_examples_runner(options)

      return self.verify_examples(
          child_examples,
          region,
          options,
          verify_labels=False,
          examples_filename=FLAGS.examples,
      )

    examples1 = _run_with_realignment_mode(False, 'ex1')
    examples2 = _run_with_realignment_mode(True, 'ex2')
    self.assertNotEmpty(examples1)
    self.assertNotEmpty(examples2)
    # The only assumption is that joint realignment changes the output, so the
    # two lists of examples must differ.
    self.assertDeepVariantExamplesNotEqual(examples1, examples2)
597
598
  @parameterized.parameters(
      dict(select_types=None, expected_count=79),
      dict(select_types='all', expected_count=79),
      dict(select_types='snps', expected_count=64),
      dict(select_types='indels', expected_count=12),
      dict(select_types='snps indels', expected_count=76),
      dict(select_types='multi-allelics', expected_count=3),
      dict(select_types=None, keep_legacy_behavior=True, expected_count=79),
      dict(select_types='all', keep_legacy_behavior=True, expected_count=79),
      dict(select_types='snps', keep_legacy_behavior=True, expected_count=64),
      dict(select_types='indels', keep_legacy_behavior=True, expected_count=11),
      dict(
          select_types='snps indels',
          keep_legacy_behavior=True,
          expected_count=75,
      ),
      dict(
          select_types='multi-allelics',
          keep_legacy_behavior=True,
          expected_count=4,
      ),
  )
  @flagsaver.flagsaver
  def test_make_examples_with_variant_selection(
      self, select_types, expected_count, keep_legacy_behavior=False
  ):
    """Counts child candidates written under each select_variant_types value.

    Args:
      select_types: Value for FLAGS.select_variant_types; None leaves the flag
        untouched.
      expected_count: Expected number of records in the child candidates file.
      keep_legacy_behavior: Value for FLAGS.keep_legacy_allele_counter_behavior.
    """
    if select_types is not None:
      FLAGS.select_variant_types = select_types
    target_region = ranges.parse_literal('20:10,000,000-10,010,000')
    FLAGS.regions = [ranges.to_literal(target_region)]

    # Trio inputs.
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'

    # Outputs; the child candidates are read back from their own file.
    FLAGS.candidates = test_utils.test_tmpfile(_sharded('vsc.tfrecord'))
    child_candidates_path = test_utils.test_tmpfile(
        _sharded('vsc_child.tfrecord')
    )
    FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord'))
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    FLAGS.partition_size = 1000
    FLAGS.mode = 'calling'
    FLAGS.keep_legacy_allele_counter_behavior = keep_legacy_behavior

    run_options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(run_options)

    child_records = list(
        tfrecord.read_tfrecords(child_candidates_path, compression_type='GZIP')
    )
    self.assertLen(child_records, expected_count)
651
652
  @parameterized.parameters(
      dict(
          mode='calling', which_parent='parent1', sample_name_to_train='child'
      ),
      dict(
          mode='calling', which_parent='parent2', sample_name_to_train='child'
      ),
      dict(
          mode='training', which_parent='parent1', sample_name_to_train='child'
      ),
      dict(
          mode='training', which_parent='parent2', sample_name_to_train='child'
      ),
      dict(
          mode='calling', which_parent='parent1', sample_name_to_train='parent1'
      ),
      dict(
          mode='training',
          which_parent='parent1',
          sample_name_to_train='parent1',
      ),
      # Training on parent2 in a duo is not supported (with a clear error
      # message).
  )
  @flagsaver.flagsaver
  def test_make_examples_training_end2end_duos(
      self, mode, which_parent, sample_name_to_train
  ):
    """Smoke test: make_examples runs on a duo (the child plus one parent).

    Args:
      mode: 'calling' or 'training'.
      which_parent: 'parent1' or 'parent2'; the only parent given reads.
      sample_name_to_train: Value for FLAGS.sample_name_to_train.

    Raises:
      ValueError: If which_parent is neither 'parent1' nor 'parent2'.
    """
    target_region = ranges.parse_literal('20:10,000,000-10,010,000')
    FLAGS.regions = [ranges.to_literal(target_region)]
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord'))
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    FLAGS.partition_size = 1000

    FLAGS.mode = mode
    if mode == 'training':
      FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
      FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED

    bam_for_parent = {
        'parent1': testdata.NA12891_CHR20_BAM,
        'parent2': testdata.NA12892_CHR20_BAM,
    }
    if which_parent not in bam_for_parent:
      raise ValueError('Invalid `which_parent` value in test case.')
    # Only the chosen parent's reads/sample flags are set here (e.g.
    # FLAGS.reads_parent1 and FLAGS.sample_name_parent1 = 'parent1').
    setattr(FLAGS, f'reads_{which_parent}', bam_for_parent[which_parent])
    setattr(FLAGS, f'sample_name_{which_parent}', which_parent)
    FLAGS.sample_name_to_train = sample_name_to_train

    # This only checks that make_examples completes without raising.
    make_examples_core.make_examples_runner(
        make_examples.default_options(add_flags=True)
    )
707
708
  @parameterized.parameters(
      dict(mode='calling'),
      dict(mode='training'),
  )
  @flagsaver.flagsaver
  def test_make_examples_end2end_vcf_candidate_importer(self, mode):
    """Runs make_examples with vcf_candidate_importer and checks goldens.

    Args:
      mode: 'calling' (candidates proposed from a VCF for every sample) or
        'training' (labels checked against the truth VCF).
    """
    is_training = mode == 'training'

    FLAGS.variant_caller = 'vcf_candidate_importer'
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'
    FLAGS.pileup_image_height_parent = 40
    FLAGS.pileup_image_height_child = 60
    FLAGS.candidates = test_utils.test_tmpfile(
        _sharded(f'vcf_candidate_importer.candidates.{mode}.tfrecord')
    )
    examples_path = test_utils.test_tmpfile(
        _sharded(f'vcf_candidate_importer.examples.{mode}.tfrecord')
    )
    FLAGS.examples = examples_path
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    FLAGS.mode = mode
    FLAGS.regions = '20:10,000,000-10,010,000'

    if is_training:
      golden_file = _sharded(
          testdata.GOLDEN_VCF_CANDIDATE_IMPORTER_TRAINING_EXAMPLES
      )
      # Training examples are compared straight from the main examples file.
      path_to_output_examples = examples_path
      FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
      FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED
    else:
      golden_file = _sharded(
          testdata.GOLDEN_VCF_CANDIDATE_IMPORTER_CALLING_EXAMPLES_CHILD
      )
      path_to_output_examples = test_utils.test_tmpfile(
          _sharded(f'vcf_candidate_importer_child.examples.{mode}.tfrecord')
      )
      FLAGS.proposed_variants_child = testdata.TRUTH_VARIANTS_VCF
      FLAGS.proposed_variants_parent1 = testdata.TRUTH_VARIANTS_VCF
      FLAGS.proposed_variants_parent2 = testdata.TRUTH_VARIANTS_VCF

    run_options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(run_options)

    # Verify that the variants in the examples are all good.
    produced = self.verify_examples(
        path_to_output_examples,
        None,
        run_options,
        verify_labels=is_training,
        examples_filename=FLAGS.examples,
    )
    expected = list(tfrecord.read_tfrecords(golden_file, compression_type='GZIP'))
    self.assertDeepVariantExamplesEqual(produced, expected)
770
771
  @parameterized.parameters(
      dict(
          max_reads_per_partition=1500,
          expected_len_examples1=88,
          expected_len_examples2=32,
      ),
      dict(
          max_reads_per_partition=8,
          expected_len_examples1=34,
          expected_len_examples2=30,
      ),
  )
  @flagsaver.flagsaver
  def test_make_examples_with_max_reads_for_dynamic_bases_per_region(
      self,
      max_reads_per_partition,
      expected_len_examples1,
      expected_len_examples2,
  ):
    """Checks child example counts before/after max_reads_for_dynamic_bases.

    make_examples runs twice in calling mode over the same region. The first
    run uses the default for --max_reads_for_dynamic_bases_per_region. The
    second run sets it to 1. Each run's child example count is compared with
    the parameterized expectation.
    """
    region = ranges.parse_literal('20:10,000,000-10,010,000')
    FLAGS.regions = [ranges.to_literal(region)]
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'
    FLAGS.examples = test_utils.test_tmpfile(_sharded('ex.tfrecord'))
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    child_examples = test_utils.test_tmpfile(_sharded('ex_child.tfrecord'))
    FLAGS.partition_size = 1000
    FLAGS.mode = 'calling'
    FLAGS.max_reads_per_partition = max_reads_per_partition

    def run_and_verify_child_examples():
      # Rebuild options from the current flags, run make_examples, and return
      # the structurally verified child examples.
      run_options = make_examples.default_options(add_flags=True)
      make_examples_core.make_examples_runner(run_options)
      return self.verify_examples(
          child_examples,
          region,
          run_options,
          verify_labels=False,
          examples_filename=FLAGS.examples,
      )

    self.assertLen(run_and_verify_child_examples(), expected_len_examples1)

    # Main part of the test: re-run after setting
    # max_reads_for_dynamic_bases_per_region.
    FLAGS.max_reads_for_dynamic_bases_per_region = 1
    self.assertLen(run_and_verify_child_examples(), expected_len_examples2)
830
831
  def verify_nist_concordance(self, candidates, nist_variants):
    """Checks candidate recall and volume against a NIST truth callset.

    Two checks are made against the Genome in a Bottle (NA12878) truth set:
      * The number of candidates must be below 5x the number of truth variants
        (an arbitrary bound; exceeding it hints at a major pipeline problem).
      * More than 97.05% of truth variants must be found among the candidates,
        as judged by `assertVariantIsPresent`.

    Args:
      candidates: sequence of candidate Variant protos.
      nist_variants: sequence of truth Variant protos.
    """
    n_truth = len(nist_variants)
    self.assertLess(len(candidates), 5 * n_truth)

    tp_count = sum(
        1
        for truth_variant in nist_variants
        if self.assertVariantIsPresent(truth_variant, candidates)
    )

    self.assertGreater(
        tp_count / n_truth,
        0.9705,
        'Recall must be greater than 0.9705. TP={}, Truth variants={}'.format(
            tp_count, n_truth
        ),
    )
850
851
  def assertDeepVariantExamplesEqual(self, actual, expected):
    """Asserts that actual and expected tf.Examples from DeepVariant are equal.

    Both sequences must have the same length. Examples are compared pairwise
    by index after decoding with `decode_example`.

    Args:
      actual: sequence of tf.Examples from DeepVariant. DeepVariant examples
        that we want to check.
      expected: sequence of tf.Examples. Expected results for actual.
    """
    self.assertEqual(len(actual), len(expected))
    for index, (actual_example, expected_example) in enumerate(
        zip(actual, expected)
    ):
      # The index in the message pinpoints which example diverged, which the
      # bare decoded-dict diff alone does not show.
      self.assertEqual(
          decode_example(actual_example),
          decode_example(expected_example),
          f'Example at index {index} differs from the expected example.',
      )
862
863
  def assertDeepVariantExamplesNotEqual(self, actual, expected):
    """Asserts that actual and expected tf.Examples are not equal.

    The two sequences count as different if their lengths differ, or if any
    pair at the same index (up to the shorter length) decodes to different
    values. Every difference found is logged as a warning before the final
    assertion.

    Args:
      actual: sequence of tf.Examples from DeepVariant. DeepVariant examples
        that we want to check.
      expected: sequence of tf.Examples. Expected results for actual.
    """
    found_difference = len(actual) != len(expected)
    if found_difference:
      logging.warning(
          'In assertDeepVariantExamplesNotEqual: '
          'actual(%d) and expected(%d) has different lengths',
          len(actual),
          len(expected),
      )

    # zip() stops at the shorter sequence, matching a min-length index loop.
    for idx, (got, want) in enumerate(zip(actual, expected)):
      if decode_example(got) != decode_example(want):
        logging.warning(
            'assertDeepVariantExamplesNotEqual: '
            'actual example[%d] and expected example[%d] '
            'are different',
            idx,
            idx,
        )
        found_difference = True

    self.assertTrue(
        found_difference,
        'assertDeepVariantExamplesNotEqual failed - '
        'actual and expected examples are identical.',
    )
902
903
  def assertVariantIsPresent(self, to_find, variants):
    """Returns whether `to_find` is represented among `variants`.

    Despite its name, this method does not raise; it returns a bool.

    A variant in `variants` matches when its (reference_bases, start, end)
    equals that of `to_find`. Only the first such match is consulted. That
    match must contain every alternate allele of `to_find`, though it may
    also carry additional alleles.

    Args:
      to_find: Variant-like object to look for.
      variants: iterable of Variant-like objects to search.

    Returns:
      True if a matching variant containing all of `to_find`'s alt alleles
      was found, False otherwise.
    """
    wanted_key = (to_find.reference_bases, to_find.start, to_find.end)
    first_match = next(
        (
            candidate
            for candidate in variants
            if (candidate.reference_bases, candidate.start, candidate.end)
            == wanted_key
        ),
        None,
    )
    if first_match is None:
      return False
    return all(
        alt in first_match.alternate_bases for alt in to_find.alternate_bases
    )
924
925
  def verify_candidate_positions(
      self, candidate_positions_paths, candidate_positions_golden_set
  ):
    """Asserts a candidate-positions file matches its golden counterpart.

    Both files are read as raw int32 buffers. The comparison ignores ordering
    (assertCountEqual), and both sizes are logged as a warning first.

    Args:
      candidate_positions_paths: path of the generated positions file.
      candidate_positions_golden_set: path of the golden positions file.
    """

    def load_int32_buffer(path):
      with epath.Path(path).open('rb') as handle:
        return np.frombuffer(handle.read(), dtype=np.int32)

    positions_golden = load_int32_buffer(candidate_positions_golden_set)
    positions = load_int32_buffer(candidate_positions_paths)
    logging.warning(
        '%d positions created, %d positions in golden file',
        len(positions),
        len(positions_golden),
    )
    self.assertCountEqual(positions, positions_golden)
938
939
  def verify_variants(self, variants, region, options, is_gvcf):
    """Checks simple structural properties of each Variant proto.

    For every variant:
      * If `region` is truthy, the variant is on region's contig and its start
        lies within [region.start, region.end].
      * reference_bases is non-empty, there is at least one alt allele, and
        there is exactly one call, whose call_set_name equals the sample name
        configured in options.sample_options[1].
      * When `is_gvcf` is True, the call also has a 0/0 or ./. genotype, three
        genotype likelihoods, and a non-negative GQ.

    Args:
      variants: iterable of Variant protos.
      region: optional range restricting where variants may start.
      options: MakeExamplesOptions providing the expected sample name.
      is_gvcf: whether to apply the gVCF reference-record checks.
    """
    for var in variants:
      if region:
        self.assertEqual(var.reference_name, region.reference_name)
        self.assertGreaterEqual(var.start, region.start)
        self.assertLessEqual(var.start, region.end)
      self.assertNotEqual(var.reference_bases, '')
      self.assertNotEmpty(var.alternate_bases)
      self.assertLen(var.calls, 1)

      only = variant_utils.only_call(var)
      self.assertEqual(
          only.call_set_name,
          options.sample_options[1].variant_caller_options.sample_name,
      )
      if not is_gvcf:
        continue

      # gVCF records describe reference sites: genotype is 0/0 or ./.
      # (un-called), with genotype likelihoods and a GQ value present.
      self.assertIn(list(only.genotype), [[0, 0], [-1, -1]])
      self.assertLen(only.genotype_likelihood, 3)
      self.assertGreaterEqual(variantcall_utils.get_gq(only), 0)
964
965
  def verify_contiguity(self, contiguous_variants, region):
    """Verifies region is fully covered by gvcf records.

    The first record must start at region.start and the last must end at
    region.end. Each consecutive pair must also abut:
      * after a gVCF record (alternate_bases == ['<*>']) the next record starts
        at that record's end (end is exclusive, start inclusive);
      * after a real variant the next record starts one base after the
        variant's start, because gVCF records are still produced underneath a
        deletion whose end spans further.
    Pairs with identical start and end are skipped. They can arise when a
    multi-allelic site is split into several bi-allelic variants.

    Args:
      contiguous_variants: sequence of Variant protos sorted by position.
      region: range that the records are expected to tile exactly.
    """
    self.assertNotEmpty(contiguous_variants)
    first_record = contiguous_variants[0]
    last_record = contiguous_variants[-1]
    self.assertEqual(region.start, first_record.start)
    self.assertEqual(region.end, last_record.end)

    for prev, cur in zip(contiguous_variants, contiguous_variants[1:]):
      if prev.start == cur.start and prev.end == cur.end:
        continue
      prev_is_gvcf_record = prev.alternate_bases == ['<*>']
      if prev_is_gvcf_record:
        expected_start = prev.end
      else:
        expected_start = prev.start + 1
      self.assertEqual(cur.start, expected_start)
989
990
  def verify_deepvariant_calls(self, dv_calls, options):
    """Checks structural properties of DeepVariantCall protos.

    For every alt allele other than vcf_constants.NO_ALT_ALLELE, the allele
    must be a key of the call's allele_support. It must also list at least
    options.sample_options[1].variant_caller_options.min_count_snps supporting
    read names.

    Args:
      dv_calls: iterable of DeepVariantCall protos.
      options: MakeExamplesOptions providing the minimum SNP read count.
    """
    for dv_call in dv_calls:
      # Reference-only (no-alt) alleles carry no support information.
      real_alts = (
          alt
          for alt in dv_call.variant.alternate_bases
          if alt != vcf_constants.NO_ALT_ALLELE
      )
      for alt in real_alts:
        self.assertIn(alt, list(dv_call.allele_support))
        supporting_read_names = dv_call.allele_support[alt].read_names
        self.assertGreaterEqual(
            len(supporting_read_names),
            options.sample_options[1].variant_caller_options.min_count_snps,
        )
1007
1008
  def sanity_check_example_info_json(self, example, examples_filename, task_id):
    """Checks `example_info.json` w/ examples_filename has the right info.

    The JSON path is derived from `examples_filename` and `task_id` via
    dv_utils.get_example_info_json_filename. The file must contain a 'shape'
    equal to the image shape of `example`, and a 'channels' list whose length
    equals shape[2].

    Args:
      example: a tf.Example whose image shape is the reference.
      examples_filename: the examples path the JSON file is associated with.
      task_id: task index used to locate the per-task JSON file.
    """
    example_info_json = dv_utils.get_example_info_json_filename(
        examples_filename, task_id
    )
    # Read inside a `with` block so the handle is closed deterministically;
    # passing an anonymous GFile to json.load leaked the open file.
    with gfile.GFile(example_info_json, 'r') as info_file:
      example_info = json.load(info_file)
    self.assertIn('shape', example_info)
    self.assertEqual(
        example_info['shape'], dv_utils.example_image_shape(example)
    )
    self.assertIn('channels', example_info)
    self.assertLen(example_info['channels'], example_info['shape'][2])
1020
1021
  def verify_examples(
      self,
      path_to_output_examples,
      region,
      options,
      verify_labels,
      examples_filename=None,
  ):
    """Reads back examples and runs simple structural checks on them.

    Args:
      path_to_output_examples: path of the GZIP-compressed TFRecord file of
        tf.Examples to read.
      region: optional range passed to verify_variants to bound variant
        positions.
      options: MakeExamplesOptions used when producing the examples.
      verify_labels: if True, each example must also have a 'label' feature.
      examples_filename: the original `--examples` path. In DeepTrio,
        `path_to_output_examples` may point at a suffixed output (such as
        `_child`), while example_info.json is located from this original
        path. When not provided, `path_to_output_examples` is used instead.

    Returns:
      The list of tf.Examples that were read.
    """
    required_features = [
        'variant/encoded',
        'locus',
        'image/encoded',
        'alt_allele_indices/encoded',
    ]
    if verify_labels:
      required_features.append('label')

    examples = list(
        tfrecord.read_tfrecords(
            path_to_output_examples, compression_type='GZIP'
        )
    )
    for example in examples:
      present_features = example.features.feature
      for feature_name in required_features:
        self.assertIn(feature_name, present_features)
      self.assertNotEmpty(dv_utils.example_alt_alleles_indices(example))

    # The variants embedded in the examples must also be well formed.
    self.verify_variants(
        [dv_utils.example_variant(ex) for ex in examples],
        region,
        options,
        is_gvcf=False,
    )

    if not examples:
      return examples

    info_source = (
        path_to_output_examples
        if examples_filename is None
        else examples_filename
    )
    self.sanity_check_example_info_json(
        examples[0], info_source, options.task_id
    )
    return examples
1065
1066
1067
class MakeExamplesUnitTest(parameterized.TestCase):
  """Unit tests for deeptrio make_examples option handling and helpers."""

  def _set_trio_input_flags(self):
    """Points the reference, trio reads and sample-name flags at chr20 data.

    Callers must be decorated with `@flagsaver.flagsaver` so these global flag
    mutations are rolled back after the test.
    """
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'

  def _set_trio_training_flags(self):
    """Sets the full training-mode flag set shared by the option tests.

    Callers must be decorated with `@flagsaver.flagsaver`.
    """
    self._set_trio_input_flags()
    FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
    FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED
    FLAGS.mode = 'training'
    FLAGS.examples = ''
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)

  def test_read_write_run_info(self):
    def _read_lines(path):
      with open(path) as fin:
        return fin.readlines()

    golden_actual = make_examples_core.read_make_examples_run_info(
        testdata.GOLDEN_MAKE_EXAMPLES_RUN_INFO
    )
    # We don't want to inject too much knowledge about the golden here, so we
    # only check that (a) run_info_filename is a non-empty string and (b) the
    # number of candidate sites in the labeling metrics equals
    # testdata.N_GOLDEN_TRAINING_EXAMPLES, i.e. the reader filled it in.
    self.assertNotEmpty(golden_actual.options.run_info_filename)
    self.assertEqual(
        golden_actual.labeling_metrics.n_candidate_variant_sites,
        testdata.N_GOLDEN_TRAINING_EXAMPLES,
    )

    # Check that reading + writing the data produces the same lines.
    tmp_output = test_utils.test_tmpfile('written_run_info.pbtxt')
    make_examples_core.write_make_examples_run_info(golden_actual, tmp_output)
    written_lines = _read_lines(tmp_output)
    logging.info('Written run info lines: %s', written_lines)
    self.assertEqual(
        _read_lines(testdata.GOLDEN_MAKE_EXAMPLES_RUN_INFO),
        written_lines,
    )

  @flagsaver.flagsaver
  def test_keep_duplicates(self):
    FLAGS.keep_duplicates = True
    self._set_trio_training_flags()
    options = make_examples.default_options(add_flags=True)
    self.assertTrue(options.pic_options.read_requirements.keep_duplicates)

  @flagsaver.flagsaver
  def test_keep_supplementary_alignments(self):
    FLAGS.keep_supplementary_alignments = True
    self._set_trio_training_flags()
    options = make_examples.default_options(add_flags=True)
    self.assertTrue(
        options.pic_options.read_requirements.keep_supplementary_alignments
    )

  @flagsaver.flagsaver
  def test_keep_secondary_alignments(self):
    FLAGS.keep_secondary_alignments = True
    self._set_trio_training_flags()
    options = make_examples.default_options(add_flags=True)
    self.assertTrue(
        options.pic_options.read_requirements.keep_secondary_alignments
    )

  @flagsaver.flagsaver
  def test_min_base_quality(self):
    FLAGS.min_base_quality = 5
    self._set_trio_training_flags()
    options = make_examples.default_options(add_flags=True)
    self.assertEqual(options.pic_options.read_requirements.min_base_quality, 5)

  @flagsaver.flagsaver
  def test_min_mapping_quality(self):
    FLAGS.min_mapping_quality = 15
    self._set_trio_training_flags()
    options = make_examples.default_options(add_flags=True)
    self.assertEqual(
        options.pic_options.read_requirements.min_mapping_quality, 15
    )

  @flagsaver.flagsaver
  def test_default_options_with_training_random_emit_ref_sites(self):
    self._set_trio_training_flags()
    FLAGS.training_random_emit_ref_sites = 0.3
    options = make_examples.default_options(add_flags=True)
    self.assertAlmostEqual(
        options.sample_options[
            1
        ].variant_caller_options.fraction_reference_sites_to_emit,
        0.3,
    )

  @flagsaver.flagsaver
  def test_default_options_without_training_random_emit_ref_sites(self):
    self._set_trio_training_flags()
    options = make_examples.default_options(add_flags=True)
    # In proto3 the presence of a scalar field cannot be checked directly, so
    # as an approximation we check that the value is exactly 0.
    self.assertEqual(
        options.sample_options[
            1
        ].variant_caller_options.fraction_reference_sites_to_emit,
        0.0,
    )

  @flagsaver.flagsaver
  def test_confident_regions(self):
    self._set_trio_training_flags()
    options = make_examples.default_options(add_flags=True)
    confident_regions = make_examples_core.read_confident_regions(options)

    # Our expected intervals, inlined from CONFIDENT_REGIONS_BED.
    expected = _from_literals_list([
        '20:10000847-10002407',
        '20:10002521-10004171',
        '20:10004274-10004964',
        '20:10004995-10006386',
        '20:10006410-10007800',
        '20:10007825-10008018',
        '20:10008044-10008079',
        '20:10008101-10008707',
        '20:10008809-10008897',
        '20:10009003-10009791',
        '20:10009934-10010531',
    ])
    # Our confident regions should be exactly those found in the BED file.
    self.assertCountEqual(expected, list(confident_regions))

  @parameterized.parameters(
      ({'examples': ('foo', 'foo')},),
      ({'examples': ('foo', 'foo'), 'gvcf': ('bar', 'bar')},),
      ({'examples': ('foo@10', 'foo-00000-of-00010')},),
      ({'task': (0, 0), 'examples': ('foo@10', 'foo-00000-of-00010')},),
      ({'task': (1, 1), 'examples': ('foo@10', 'foo-00001-of-00010')},),
      (
          {
              'task': (1, 1),
              'examples': ('foo@10', 'foo-00001-of-00010'),
              'gvcf': ('bar@10', 'bar-00001-of-00010'),
          },
      ),
      (
          {
              'task': (1, 1),
              'examples': ('foo@10', 'foo-00001-of-00010'),
              'gvcf': ('bar@10', 'bar-00001-of-00010'),
              'candidates': ('baz@10', 'baz-00001-of-00010'),
          },
      ),
  )
  @flagsaver.flagsaver
  def test_sharded_outputs1(self, settings):
    # Each setting maps a flag name to (flag value, expected option value).
    for name, (flag_val, _) in settings.items():
      setattr(FLAGS, name, flag_val)

    FLAGS.mode = 'training'
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    FLAGS.reads = ''
    FLAGS.ref = ''
    options = make_examples.default_options(add_flags=True)

    # Flags that were not set must map to an empty filename.
    for name, option_val in [
        ('examples', options.examples_filename),
        ('candidates', options.candidates_filename),
        ('gvcf', options.gvcf_filename),
    ]:
      expected = settings[name][1] if name in settings else ''
      self.assertEqual(expected, option_val)

  def test_catches_bad_argv(self):
    with (
        mock.patch.object(logging, 'error') as mock_logging,
        mock.patch.object(sys, 'exit') as mock_exit,
    ):
      make_examples.main(['make_examples.py', 'extra_arg'])
    mock_logging.assert_called_once_with(
        'Command line parsing failure: make_examples does not accept '
        'positional arguments but some are present on the command line: '
        "\"['make_examples.py', 'extra_arg']\"."
    )
    mock_exit.assert_called_once_with(errno.ENOENT)

  @flagsaver.flagsaver
  def test_catches_bad_flags(self):
    region = ranges.parse_literal('20:10,000,000-10,010,000')
    self._set_trio_input_flags()
    FLAGS.candidates = test_utils.test_tmpfile('vsc.tfrecord')
    FLAGS.examples = test_utils.test_tmpfile('examples.tfrecord')
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    FLAGS.regions = [ranges.to_literal(region)]
    FLAGS.partition_size = 1000
    FLAGS.mode = 'training'
    FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
    # This is the bad flag.
    FLAGS.confident_regions = ''

    with (
        mock.patch.object(logging, 'error') as mock_logging,
        mock.patch.object(sys, 'exit') as mock_exit,
    ):
      make_examples.main(['make_examples.py'])
    mock_logging.assert_called_once_with(
        'confident_regions is required when in training mode.'
    )
    mock_exit.assert_called_once_with(errno.ENOENT)

  @flagsaver.flagsaver
  def test_regions_and_exclude_regions_flags_with_trio_options(self):
    FLAGS.mode = 'calling'
    self._set_trio_input_flags()
    FLAGS.regions = '20:10,000,000-11,000,000'
    FLAGS.examples = 'examples.tfrecord'
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    FLAGS.exclude_regions = '20:10,010,000-10,100,000'

    options = make_examples.default_options(add_flags=True)
    _, regions_from_options = (
        make_examples_core.processing_regions_from_options(options)
    )
    self.assertCountEqual(
        list(ranges.RangeSet(regions_from_options)),
        _from_literals_list(
            ['20:10,000,000-10,009,999', '20:10,100,001-11,000,000']
        ),
    )

  @flagsaver.flagsaver
  def test_incorrect_empty_regions_with_trio_options(self):
    FLAGS.mode = 'calling'
    self._set_trio_input_flags()
    # Deliberately incorrect contig name.
    FLAGS.regions = 'xxx20:10,000,000-11,000,000'
    FLAGS.examples = 'examples.tfrecord'
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)

    options = make_examples.default_options(add_flags=True)
    with self.assertRaisesRegex(ValueError, 'The regions to call is empty.'):
      make_examples_core.processing_regions_from_options(options)
1426
1427
1428
class RegionProcessorTest(parameterized.TestCase):
  """Tests for make_examples_core.RegionProcessor plus related option checks."""

  def setUp(self):
    super().setUp()
    # This fixture assigns the global FLAGS.reads below. Snapshot every flag
    # first and restore the snapshot at cleanup so that change cannot leak
    # into test cases that run after this one.
    saved_flag_values = flagsaver.save_flag_values()
    self.addCleanup(flagsaver.restore_flag_values, saved_flag_values)

    self.region = ranges.parse_literal('20:10,000,000-10,000,100')

    FLAGS.reads = ''
    self.options = make_examples.default_options(add_flags=False)
    self.options.reference_filename = testdata.CHR20_FASTA
    self.options.truth_variants_filename = testdata.TRUTH_VARIANTS_VCF
    self.options.mode = deepvariant_pb2.MakeExamplesOptions.TRAINING

    # NOTE(review): ref_reader and default_shape are not read by any test in
    # this class, and ref_reader is never closed; consider removing them.
    self.ref_reader = fasta.IndexedFastaReader(self.options.reference_filename)
    self.default_shape = [5, 5, 7]
    self.processor = make_examples_core.RegionProcessor(self.options)
    # Replace the processor's _initialize with an autospec mock.
    self.mock_init = self.add_mock('_initialize')
    # Give every sample a Mock in-memory reader so tests can assert on the
    # calls made to it.
    for sample in self.processor.samples:
      sample.in_memory_sam_reader = mock.Mock()

  def add_mock(self, name, retval='dontadd', side_effect='dontadd'):
    """Patches attribute `name` of self.processor with an autospec mock.

    The patch starts immediately and is stopped automatically at cleanup.

    Args:
      name: Name of the processor attribute to patch.
      retval: Value for the mock's return_value. The default sentinel string
        'dontadd' leaves return_value untouched.
      side_effect: Value for the mock's side_effect. The default sentinel
        string 'dontadd' leaves side_effect untouched.

    Returns:
      The started mock object.
    """
    patcher = mock.patch.object(self.processor, name, autospec=True)
    self.addCleanup(patcher.stop)
    mocked = patcher.start()
    if retval != 'dontadd':
      mocked.return_value = retval
    if side_effect != 'dontadd':
      mocked.side_effect = side_effect
    return mocked

  @parameterized.parameters([
      deepvariant_pb2.MakeExamplesOptions.TRAINING,
      deepvariant_pb2.MakeExamplesOptions.CALLING,
  ])
  def test_process_keeps_ordering_of_candidates_and_examples(self, mode):
    """process() passes mocked candidates through and feeds reads to a reader.

    Args:
      mode: MakeExamplesOptions mode to run the processor in.
    """
    self.processor.options.mode = mode

    r1, r2 = mock.Mock(), mock.Mock()
    c1, c2 = mock.Mock(), mock.Mock()
    self.add_mock('region_reads_norealign', retval=[r1, r2])
    self.add_mock('candidates_in_region', retval=({'child': [c1, c2]}, {}, {}))
    candidates_dict, gvcfs_dict, runtimes, read_phases = self.processor.process(
        self.region
    )
    self.assertEqual({'child': [c1, c2]}, candidates_dict)
    self.assertEqual({}, gvcfs_dict)
    self.assertEqual({}, read_phases)
    self.assertIsInstance(runtimes, dict)

    # NOTE(review): index 1 is presumably the child sample -- confirm against
    # the sample ordering built by RegionProcessor.
    in_memory_sam_reader = self.processor.samples[1].in_memory_sam_reader
    in_memory_sam_reader.replace_reads.assert_called_once_with([r1, r2])

  @flagsaver.flagsaver
  def test_use_original_quality_scores_without_parse_sam_aux_fields(self):
    """default_options rejects original quality scores without aux parsing."""
    FLAGS.mode = 'calling'
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'
    FLAGS.examples = 'examples.tfrecord'
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    FLAGS.use_original_quality_scores = True
    FLAGS.parse_sam_aux_fields = False

    with self.assertRaisesRegex(
        Exception,
        (
            'If --use_original_quality_scores is set then '
            '--parse_sam_aux_fields must be set too.'
        ),
    ):
      make_examples.default_options(add_flags=True)

  # NOTE(review): the larger cases are out of range only if the total counts
  # the parent height twice (child + 2 * parent): 150*2 + 101 = 401 and
  # 101*2 + 170 = 372, both above 362. Confirm against
  # check_options_are_valid.
  @parameterized.parameters(
      dict(height_parent=10, height_child=9),
      dict(height_parent=9, height_child=10),
      dict(height_parent=150, height_child=101),
      dict(height_parent=101, height_child=170),
  )
  @flagsaver.flagsaver
  def test_image_heights(self, height_parent, height_child):
    """check_options_are_valid rejects out-of-range total pileup heights.

    Args:
      height_parent: Value for --pileup_image_height_parent.
      height_child: Value for --pileup_image_height_child.
    """
    FLAGS.pileup_image_height_parent = height_parent
    FLAGS.pileup_image_height_child = height_child
    FLAGS.mode = 'calling'
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'
    FLAGS.examples = 'examples.tfrecord'
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)

    options = make_examples.default_options(add_flags=True)
    with self.assertRaisesRegex(
        Exception, 'Total pileup image heights must be between 75-362.'
    ):
      make_examples.check_options_are_valid(options)
1531
1532
1533
if __name__ == '__main__':
  # Script entry point: run this module's test cases with the absl runner.
  absltest.main()