--- a +++ b/deeptrio/make_examples_test.py @@ -0,0 +1,1534 @@ +# Copyright 2020 Google LLC. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. 
"""Tests for deeptrio.make_examples."""

import errno
import json
import platform
import sys
from unittest import mock

from absl import flags
from absl import logging
from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
from etils import epath
import numpy as np

from deeptrio import make_examples
from deeptrio import testdata
from deepvariant import dv_constants
from deepvariant import dv_utils
from deepvariant import make_examples_core
from deepvariant.protos import deepvariant_pb2
from tensorflow.python.platform import gfile
from third_party.nucleus.io import fasta
from third_party.nucleus.io import sharded_file_utils
from third_party.nucleus.io import tfrecord
from third_party.nucleus.io import vcf
from third_party.nucleus.protos import reference_pb2
from third_party.nucleus.protos import variants_pb2
from third_party.nucleus.testing import test_utils
from third_party.nucleus.util import ranges
from third_party.nucleus.util import variant_utils
from third_party.nucleus.util import variantcall_utils
from third_party.nucleus.util import vcf_constants

FLAGS = flags.FLAGS

# Dictionary mapping keys to decoders for decode_example function. Each value
# takes a tf.Example and returns the decoded Python value for that feature.
_EXAMPLE_DECODERS = {
    'locus': dv_utils.example_locus,
    'alt_allele_indices/encoded': dv_utils.example_alt_alleles_indices,
    'image/encoded': dv_utils.example_encoded_image,
    'variant/encoded': dv_utils.example_variant,
    'variant_type': dv_utils.example_variant_type,
    'label': dv_utils.example_label,
    'image/shape': dv_utils.example_image_shape,
    'sequencing_type': dv_utils.example_sequencing_type,
    'denovo_label': dv_utils.example_denovo_label,
}


def decode_example(example):
  """Decodes a tf.Example from DeepVariant into a dict of Pythonic structures.

  Args:
    example: tf.Example proto. The example to make into a dictionary.

  Returns:
    A python dictionary with key/value pairs for each of the fields of example,
    with each value decoded as needed into Python structures like protos, list,
    etc.

  Raises:
    KeyError: If example contains a feature without a known decoder.
  """
  as_dict = {}
  for key in example.features.feature:
    if key not in _EXAMPLE_DECODERS:
      raise KeyError('Unexpected example key', key)
    as_dict[key] = _EXAMPLE_DECODERS[key](example)
  return as_dict


def setUpModule():
  # Silence log spam during the test run and materialize testdata paths.
  logging.set_verbosity(logging.FATAL)
  testdata.init()


def _make_contigs(specs):
  """Makes ContigInfo protos from specs.

  Args:
    specs: A list of 2- or 3-tuples. All tuples should be of the same length. If
      2-element, these should be the name and length in basepairs of each
      contig, and their pos_in_fasta will be set to their index in the list. If
      the 3-element, the tuple should contain name, length, and pos_in_fasta.

  Returns:
    A list of ContigInfo protos, one for each spec in specs.
  """
  if specs and len(specs[0]) == 3:
    return [
        reference_pb2.ContigInfo(name=name, n_bases=length, pos_in_fasta=i)
        for name, length, i in specs
    ]
  else:
    return [
        reference_pb2.ContigInfo(name=name, n_bases=length, pos_in_fasta=i)
        for i, (name, length) in enumerate(specs)
    ]


def _from_literals_list(literals, contig_map=None):
  """Makes a list of Range objects from literals."""
  return ranges.parse_literals(literals, contig_map)


def _from_literals(literals, contig_map=None):
  """Makes a RangeSet of intervals from literals."""
  return ranges.RangeSet.from_regions(literals, contig_map)


def _sharded(basename, num_shards=None):
  # Returns a 'basename@num_shards' sharded-file spec, or the plain basename
  # when num_shards is falsy (None or 0, i.e. unsharded output).
  if num_shards:
    return basename + '@' + str(num_shards)
  else:
    return basename


class MakeExamplesEnd2EndTest(parameterized.TestCase):

  # Golden sets are created with
  # learning/genomics/internal/create_golden_deep_trio.sh
  @parameterized.parameters(
      # All tests are run with fast_pass_aligner enabled.
      # There are no golden sets version for ssw realigner.
      dict(mode='calling', num_shards=0),
      dict(mode='calling', num_shards=3),
      dict(mode='candidate_sweep', num_shards=0),
      dict(mode='candidate_sweep', num_shards=3),
      dict(
          mode='training', num_shards=0, labeler_algorithm='haplotype_labeler'
      ),
      dict(
          mode='training', num_shards=3, labeler_algorithm='haplotype_labeler'
      ),
      dict(
          mode='training', num_shards=0, labeler_algorithm='positional_labeler'
      ),
      dict(
          mode='training', num_shards=3, labeler_algorithm='positional_labeler'
      ),
  )
  @flagsaver.flagsaver
  def test_make_examples_end2end(
      self, mode, num_shards, labeler_algorithm=None, use_fast_pass_aligner=True
  ):
    """Runs make_examples end-to-end and checks outputs against golden data."""
    self.assertIn(mode, {'calling', 'training', 'candidate_sweep'})
    region = ranges.parse_literal('20:10,000,000-10,010,000')
    FLAGS.write_run_info = True
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'
    FLAGS.candidates = test_utils.test_tmpfile(
        _sharded('vsc.tfrecord', num_shards)
    )
    FLAGS.examples = test_utils.test_tmpfile(
        _sharded('examples.tfrecord', num_shards)
    )
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    child_examples = test_utils.test_tmpfile(
        _sharded('examples_child.tfrecord', num_shards)
    )
    if mode == 'candidate_sweep':
      FLAGS.candidate_positions = test_utils.test_tmpfile(
          _sharded('candidate_positions', num_shards)
      )
      # NOTE(review): intentionally the same path as FLAGS.candidate_positions;
      # kept in a local so we can resolve per-task sharded paths below.
      candidate_positions = test_utils.test_tmpfile(
          _sharded('candidate_positions', num_shards)
      )
    FLAGS.regions = [ranges.to_literal(region)]
    FLAGS.partition_size = 1000
    FLAGS.mode = mode
    FLAGS.gvcf_gq_binsize = 5
    FLAGS.use_fast_pass_aligner = use_fast_pass_aligner
    if labeler_algorithm is not None:
      FLAGS.labeler_algorithm = labeler_algorithm

    if mode == 'calling':
      FLAGS.gvcf = test_utils.test_tmpfile(
          _sharded('gvcf.tfrecord', num_shards)
      )
      child_gvcf = test_utils.test_tmpfile(
          _sharded('gvcf_child.tfrecord', num_shards)
      )
      child_candidates = test_utils.test_tmpfile(
          _sharded('vsc_child.tfrecord', num_shards)
      )
    else:
      FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
      FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED
      child_candidates = test_utils.test_tmpfile(
          _sharded('vsc.tfrecord', num_shards)
      )

    if mode == 'candidate_sweep':
      golden_candidate_positions = _sharded(
          testdata.GOLDEN_CANDIDATE_POSITIONS, num_shards
      )
    # Run one make_examples invocation per shard (at least one for the
    # unsharded num_shards=0 case).
    for task_id in range(max(num_shards, 1)):
      FLAGS.task = task_id
      options = make_examples.default_options(add_flags=True)
      make_examples_core.make_examples_runner(options)

      # Check that our run_info proto contains the basic fields we'd expect:
      # (a) our options are written to the run_info.options field.
      run_info = make_examples_core.read_make_examples_run_info(
          options.run_info_filename
      )
      self.assertEqual(run_info.options, options)
      # (b) run_info.resource_metrics is present and contains our hostname.
      self.assertTrue(run_info.HasField('resource_metrics'))
      self.assertEqual(run_info.resource_metrics.host_name, platform.node())

      # For candidate_sweep mode we verify that candidate positions match
      # golden set exactly.
      if mode == 'candidate_sweep':
        _, candidates_path = sharded_file_utils.resolve_filespecs(
            task_id, candidate_positions
        )
        _, gold_candidates_path = sharded_file_utils.resolve_filespecs(
            task_id, golden_candidate_positions
        )
        self.verify_candidate_positions(candidates_path, gold_candidates_path)

    # In candidate_sweep mode the test stops here.
    if mode == 'candidate_sweep':
      return

    # Test that our candidates are reasonable, calling specific helper
    # functions to check lots of properties of the output.
    candidates = sorted(
        tfrecord.read_tfrecords(
            child_candidates,
            proto=deepvariant_pb2.DeepVariantCall,
            compression_type='GZIP',
        ),
        key=lambda c: variant_utils.variant_range_tuple(c.variant),
    )
    self.verify_deepvariant_calls(candidates, options)
    self.verify_variants(
        [call.variant for call in candidates], region, options, is_gvcf=False
    )

    # Verify that the variants in the examples are all good.
    if mode == 'calling':
      examples = self.verify_examples(
          child_examples,
          region,
          options,
          verify_labels=False,
          examples_filename=FLAGS.examples,
      )
    if mode == 'training':
      examples = self.verify_examples(
          FLAGS.examples, region, options, verify_labels=True
      )
    example_variants = [dv_utils.example_variant(ex) for ex in examples]
    self.verify_variants(example_variants, region, options, is_gvcf=False)

    # Verify the integrity of the examples and then check that they match our
    # golden labeled examples. Note we expect the order for both training and
    # calling modes to produce deterministic order because we fix the random
    # seed.
    if mode == 'calling':
      golden_file = _sharded(testdata.GOLDEN_CALLING_EXAMPLES, num_shards)
    else:
      golden_file = _sharded(testdata.GOLDEN_TRAINING_EXAMPLES, num_shards)
    self.assertDeepVariantExamplesEqual(
        examples,
        list(tfrecord.read_tfrecords(golden_file, compression_type='GZIP')),
    )

    if mode == 'calling':
      nist_reader = vcf.VcfReader(testdata.TRUTH_VARIANTS_VCF)
      nist_variants = list(nist_reader.query(region))
      self.verify_nist_concordance(example_variants, nist_variants)

      # Check the quality of our generated gvcf file.
      gvcfs = variant_utils.sorted_variants(
          tfrecord.read_tfrecords(
              child_gvcf, proto=variants_pb2.Variant, compression_type='GZIP'
          )
      )
      self.verify_variants(gvcfs, region, options, is_gvcf=True)
      self.verify_contiguity(gvcfs, region)
      gvcf_golden_file = _sharded(
          testdata.GOLDEN_POSTPROCESS_GVCF_INPUT, num_shards
      )
      expected_gvcfs = list(
          tfrecord.read_tfrecords(
              gvcf_golden_file,
              proto=variants_pb2.Variant,
              compression_type='GZIP',
          )
      )
      self.assertCountEqual(gvcfs, expected_gvcfs)

    if (
        mode == 'training'
        and num_shards == 0
        and labeler_algorithm != 'positional_labeler'
    ):
      # The positional labeler doesn't track metrics, so don't try to read them
      # in when that's the mode.
      self.assertEqual(
          make_examples_core.read_make_examples_run_info(
              testdata.GOLDEN_MAKE_EXAMPLES_RUN_INFO
          ).labeling_metrics,
          run_info.labeling_metrics,
      )

  @parameterized.parameters(
      dict(
          denovo_test=False,
          expected_denovo_variants=0,
      ),
      dict(
          denovo_test=True,
          expected_denovo_variants=3,
      ),
  )
  @flagsaver.flagsaver
  def test_make_examples_ont_end2end(
      self,
      denovo_test: bool,
      expected_denovo_variants: int,
  ):
    """Test end to end for long ONT reads with phasing enabled.

    Args:
      denovo_test: If true, denovo parameters will be set.
      expected_denovo_variants: Total number of denovo examples expected.

    This test runs ONT end to end and compares the output with the golden
    output. This test is introduced because previously in training mode the
    non training sample would not be phased. So this now tests to make sure
    all of the training examples are phased correctly.
    """
    region = ranges.parse_literal('chr20:5050000-5075000')
    FLAGS.write_run_info = True
    FLAGS.ref = testdata.GRCH38_CHR0_FASTA
    FLAGS.reads = testdata.ONT_HG002_BAM
    FLAGS.reads_parent1 = testdata.ONT_HG003_BAM
    FLAGS.reads_parent2 = testdata.ONT_HG004_BAM
    FLAGS.confident_regions = testdata.HG002_HIGH_CONFIDENCE_BED
    FLAGS.truth_variants = testdata.HG002_HIGH_CONFIDENCE_VCF
    FLAGS.sample_name = 'HG002'
    FLAGS.sample_name_to_train = 'HG002'
    FLAGS.sample_name_parent1 = 'HG003'
    FLAGS.sample_name_parent2 = 'HG004'
    # ONT-specific settings: alt-aligned pileups, haplotype sorting, and
    # read phasing are all enabled for this long-read configuration.
    FLAGS.alt_aligned_pileup = 'diff_channels'
    FLAGS.min_mapping_quality = 1
    FLAGS.mode = 'training'
    FLAGS.parse_sam_aux_fields = True
    FLAGS.partition_size = 25000
    FLAGS.phase_reads = True
    FLAGS.pileup_image_height_child = 100
    FLAGS.pileup_image_height_parent = 100
    FLAGS.pileup_image_width = 199
    FLAGS.realign_reads = False
    FLAGS.skip_parent_calling = True
    FLAGS.sort_by_haplotypes = True
    FLAGS.track_ref_reads = True
    FLAGS.vsc_min_fraction_indels = 0.12
    FLAGS.vsc_min_fraction_snps = 0.1
    num_shards = 0
    FLAGS.examples = test_utils.test_tmpfile(
        _sharded('examples.tfrecord', num_shards)
    )
    FLAGS.channel_list = ','.join(
        dv_constants.PILEUP_DEFAULT_CHANNELS + ['haplotype']
    )
    FLAGS.regions = [ranges.to_literal(region)]
    golden_file = _sharded(testdata.GOLDEN_ONT_MAKE_EXAMPLES_OUTPUT, num_shards)
    FLAGS.denovo_regions = None
    if denovo_test:
      # If denovo test is enabled, then set the parameters for denovo testing.
      golden_file = _sharded(
          testdata.GOLDEN_ONT_DENOVO_MAKE_EXAMPLES_OUTPUT, num_shards
      )
      FLAGS.write_run_info = True
      FLAGS.denovo_regions = testdata.HG002_DENOVO_BED

    for task_id in range(max(num_shards, 1)):
      FLAGS.task = task_id
      options = make_examples.default_options(add_flags=True)
      make_examples_core.make_examples_runner(options)

    examples = self.verify_examples(
        FLAGS.examples, region, options, verify_labels=True
    )

    self.assertDeepVariantExamplesEqual(
        examples,
        list(tfrecord.read_tfrecords(golden_file, compression_type='GZIP')),
    )
    if denovo_test:
      # Check total number of denovo examples.
      total_denovo = sum(
          [
              1
              for example in examples
              if dv_utils.example_denovo_label(example)
          ]
      )
      self.assertEqual(
          total_denovo,
          expected_denovo_variants,
          msg='ONT denovo golden test: denovo variants count.',
      )
      # Read the runinfo file and compare its stats against the golden run.
      runinfo = make_examples_core.read_make_examples_run_info(
          FLAGS.examples + '.run_info.pbtxt'
      )
      golden_runinfo = make_examples_core.read_make_examples_run_info(
          testdata.GOLDEN_ONT_DENOVO_MAKE_EXAMPLES_OUTPUT + '.run_info.pbtxt'
      )
      self.assertEqual(
          runinfo.stats.num_examples,
          golden_runinfo.stats.num_examples,
          msg='ONT denovo golden test: Run info comparison num_examples.',
      )
      self.assertEqual(
          runinfo.stats.num_denovo,
          golden_runinfo.stats.num_denovo,
          msg='ONT denovo golden test: Run info comparison num_denovo.',
      )
      self.assertEqual(
          runinfo.stats.num_nondenovo,
          golden_runinfo.stats.num_nondenovo,
          msg='ONT denovo golden test: Run info comparison num_nondenovo.',
      )

  # Golden sets are created with learning/genomics/internal/create_golden.sh
  @flagsaver.flagsaver
  def test_make_examples_training_end2end_with_customized_classes_labeler(self):
    """Trains with the customized-classes labeler and checks golden output."""
    FLAGS.labeler_algorithm = 'customized_classes_labeler'
    FLAGS.customized_classes_labeler_classes_list = 'ref,class1,class2'
    FLAGS.customized_classes_labeler_info_field_name = 'type'
ranges.parse_literal('20:10,000,000-10,004,000') + FLAGS.regions = [ranges.to_literal(region)] + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.candidates = test_utils.test_tmpfile(_sharded('vsc.tfrecord')) + FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord')) + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + FLAGS.partition_size = 1000 + FLAGS.mode = 'training' + FLAGS.gvcf_gq_binsize = 5 + FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF_WITH_TYPES + FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED + options = make_examples.default_options(add_flags=True) + make_examples_core.make_examples_runner(options) + golden_file = _sharded(testdata.CUSTOMIZED_CLASSES_GOLDEN_TRAINING_EXAMPLES) + # Verify that the variants in the examples are all good. 
+ examples = self.verify_examples( + FLAGS.examples, region, options, verify_labels=True + ) + self.assertDeepVariantExamplesEqual( + examples, + list(tfrecord.read_tfrecords(golden_file, compression_type='GZIP')), + ) + + # Golden sets are created with learning/genomics/internal/create_golden.sh + @flagsaver.flagsaver + def test_make_examples_training_end2end_with_alt_aligned_pileup(self): + region = ranges.parse_literal('20:10,000,000-10,010,000') + FLAGS.regions = [ranges.to_literal(region)] + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.candidates = test_utils.test_tmpfile(_sharded('vsc.tfrecord')) + FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord')) + FLAGS.channel_list = ','.join(dv_constants.PILEUP_DEFAULT_CHANNELS) + FLAGS.partition_size = 1000 + FLAGS.mode = 'training' + FLAGS.gvcf_gq_binsize = 5 + + # The following 4 lines are added. + FLAGS.alt_aligned_pileup = 'diff_channels' + FLAGS.pileup_image_height_child = 60 + FLAGS.pileup_image_height_parent = 40 + FLAGS.pileup_image_width = 199 + + FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF + FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED + options = make_examples.default_options(add_flags=True) + make_examples_core.make_examples_runner(options) + golden_file = _sharded(testdata.ALT_ALIGNED_PILEUP_GOLDEN_TRAINING_EXAMPLES) + # Verify that the variants in the examples are all good. + examples = self.verify_examples( + FLAGS.examples, region, options, verify_labels=True + ) + self.assertDeepVariantExamplesEqual( + examples, + list(tfrecord.read_tfrecords(golden_file, compression_type='GZIP')), + ) + # Pileup image should now have 8 channels. + # Height should be 60 + 40 * 2 = 140. 
+ self.assertEqual(decode_example(examples[0])['image/shape'], [140, 199, 8]) + + @flagsaver.flagsaver + def test_make_examples_compare_realignment_modes(self): + def _run_with_realignment_mode(enable_joint_realignment, name): + FLAGS.enable_joint_realignment = enable_joint_realignment + region = ranges.parse_literal('20:10,000,000-10,010,000') + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.candidates = test_utils.test_tmpfile(f'{name}.vsc.tfrecord') + FLAGS.examples = test_utils.test_tmpfile(f'{name}.examples.tfrecord') + FLAGS.channel_list = ','.join( + dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE + ) + child_examples = test_utils.test_tmpfile( + f'{name}_child.examples.tfrecord' + ) + FLAGS.regions = [ranges.to_literal(region)] + FLAGS.partition_size = 1000 + FLAGS.mode = 'calling' + FLAGS.gvcf = test_utils.test_tmpfile(f'{name}.gvcf.tfrecord') + # child_gvcf = test_utils.test_tmpfile(f'{name}.gvcf_child.tfrecord') + # child_candidates = test_utils.test_tmpfile(f'{name}.vsc_child.tfrecord') + options = make_examples.default_options(add_flags=True) + make_examples_core.make_examples_runner(options) + + examples = self.verify_examples( + child_examples, + region, + options, + verify_labels=False, + examples_filename=FLAGS.examples, + ) + return examples + + examples1 = _run_with_realignment_mode(False, 'ex1') + examples2 = _run_with_realignment_mode(True, 'ex2') + self.assertNotEmpty(examples1) + self.assertNotEmpty(examples2) + # The assumption is just that these two lists of examples should be + # different. 
In this case, it happens to be that we got different numbers + # of examples: + self.assertNotEmpty(examples1) + self.assertDeepVariantExamplesNotEqual(examples1, examples2) + + @parameterized.parameters( + dict(select_types=None, expected_count=79), + dict(select_types='all', expected_count=79), + dict(select_types='snps', expected_count=64), + dict(select_types='indels', expected_count=12), + dict(select_types='snps indels', expected_count=76), + dict(select_types='multi-allelics', expected_count=3), + dict(select_types=None, keep_legacy_behavior=True, expected_count=79), + dict(select_types='all', keep_legacy_behavior=True, expected_count=79), + dict(select_types='snps', keep_legacy_behavior=True, expected_count=64), + dict(select_types='indels', keep_legacy_behavior=True, expected_count=11), + dict( + select_types='snps indels', + keep_legacy_behavior=True, + expected_count=75, + ), + dict( + select_types='multi-allelics', + keep_legacy_behavior=True, + expected_count=4, + ), + ) + @flagsaver.flagsaver + def test_make_examples_with_variant_selection( + self, select_types, expected_count, keep_legacy_behavior=False + ): + if select_types is not None: + FLAGS.select_variant_types = select_types + region = ranges.parse_literal('20:10,000,000-10,010,000') + FLAGS.regions = [ranges.to_literal(region)] + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.candidates = test_utils.test_tmpfile(_sharded('vsc.tfrecord')) + child_candidates = test_utils.test_tmpfile(_sharded('vsc_child.tfrecord')) + FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord')) + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + FLAGS.partition_size = 1000 + FLAGS.mode = 
    FLAGS.mode = 'calling'
    FLAGS.keep_legacy_allele_counter_behavior = keep_legacy_behavior

    options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(options)

    candidates = list(
        tfrecord.read_tfrecords(child_candidates, compression_type='GZIP')
    )
    self.assertLen(candidates, expected_count)

  @parameterized.parameters(
      dict(
          mode='calling', which_parent='parent1', sample_name_to_train='child'
      ),
      dict(
          mode='calling', which_parent='parent2', sample_name_to_train='child'
      ),
      dict(
          mode='training', which_parent='parent1', sample_name_to_train='child'
      ),
      dict(
          mode='training', which_parent='parent2', sample_name_to_train='child'
      ),
      dict(
          mode='calling', which_parent='parent1', sample_name_to_train='parent1'
      ),
      dict(
          mode='training',
          which_parent='parent1',
          sample_name_to_train='parent1',
      ),
      # Training on parent2 in a duo is not supported (with a clear error
      # message).
  )
  @flagsaver.flagsaver
  def test_make_examples_training_end2end_duos(
      self, mode, which_parent, sample_name_to_train
  ):
    """Smoke test: duo (single-parent) configurations run without errors."""
    region = ranges.parse_literal('20:10,000,000-10,010,000')
    FLAGS.regions = [ranges.to_literal(region)]
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord'))
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    FLAGS.partition_size = 1000

    FLAGS.mode = mode
    if mode == 'training':
      FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
      FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED

    # Only one of the two parents is configured, making this a duo.
    if which_parent == 'parent1':
      FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
      FLAGS.sample_name_parent1 = 'parent1'
    elif which_parent == 'parent2':
      FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
      FLAGS.sample_name_parent2 = 'parent2'
    else:
      raise ValueError('Invalid `which_parent` value in test case.')
    FLAGS.sample_name_to_train = sample_name_to_train

    # This is only a simple test that it runs without errors.
    options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(options)

  @parameterized.parameters(
      dict(mode='calling'),
      dict(mode='training'),
  )
  @flagsaver.flagsaver
  def test_make_examples_end2end_vcf_candidate_importer(self, mode):
    """Runs the vcf_candidate_importer caller and compares against goldens."""
    FLAGS.variant_caller = 'vcf_candidate_importer'
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'
    FLAGS.pileup_image_height_parent = 40
    FLAGS.pileup_image_height_child = 60
    FLAGS.candidates = test_utils.test_tmpfile(
        _sharded('vcf_candidate_importer.candidates.{}.tfrecord'.format(mode))
    )
    FLAGS.examples = test_utils.test_tmpfile(
        _sharded('vcf_candidate_importer.examples.{}.tfrecord'.format(mode))
    )
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    FLAGS.mode = mode
    FLAGS.regions = '20:10,000,000-10,010,000'

    if mode == 'calling':
      golden_file = _sharded(
          testdata.GOLDEN_VCF_CANDIDATE_IMPORTER_CALLING_EXAMPLES_CHILD
      )
      path_to_output_examples = test_utils.test_tmpfile(
          _sharded(
              'vcf_candidate_importer_child.examples.{}.tfrecord'.format(mode)
          )
      )
      # In calling mode, candidates are proposed from VCFs for all three
      # samples rather than discovered from the reads.
      FLAGS.proposed_variants_child = testdata.TRUTH_VARIANTS_VCF
      FLAGS.proposed_variants_parent1 = testdata.TRUTH_VARIANTS_VCF
      FLAGS.proposed_variants_parent2 = testdata.TRUTH_VARIANTS_VCF
    else:
      golden_file = _sharded(
          testdata.GOLDEN_VCF_CANDIDATE_IMPORTER_TRAINING_EXAMPLES
      )
      path_to_output_examples = test_utils.test_tmpfile(
          _sharded('vcf_candidate_importer.examples.{}.tfrecord'.format(mode))
      )
      FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
      FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED

    options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(options)
    # Verify that the variants in the examples are all good.
    output_examples_to_compare = self.verify_examples(
        path_to_output_examples,
        None,
        options,
        verify_labels=mode == 'training',
        examples_filename=FLAGS.examples,
    )
    self.assertDeepVariantExamplesEqual(
        output_examples_to_compare,
        list(tfrecord.read_tfrecords(golden_file, compression_type='GZIP')),
    )

  @parameterized.parameters(
      dict(
          max_reads_per_partition=1500,
          expected_len_examples1=88,
          expected_len_examples2=32,
      ),
      dict(
          max_reads_per_partition=8,
          expected_len_examples1=34,
          expected_len_examples2=30,
      ),
  )
  @flagsaver.flagsaver
  def test_make_examples_with_max_reads_for_dynamic_bases_per_region(
      self,
      max_reads_per_partition,
      expected_len_examples1,
      expected_len_examples2,
  ):
    """Checks example counts with and without dynamic bases-per-region."""
    region = ranges.parse_literal('20:10,000,000-10,010,000')
    FLAGS.regions = [ranges.to_literal(region)]
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'
    FLAGS.examples = test_utils.test_tmpfile(_sharded('ex.tfrecord'))
    FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE)
    child_examples = test_utils.test_tmpfile(_sharded('ex_child.tfrecord'))
    FLAGS.partition_size = 1000
    FLAGS.mode = 'calling'
    FLAGS.max_reads_per_partition = max_reads_per_partition

    options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(options)
    examples1 = self.verify_examples(
        child_examples,
        region,
        options,
        verify_labels=False,
        examples_filename=FLAGS.examples,
    )
    self.assertLen(examples1, expected_len_examples1)
    # Now, this is the main part of the test. I want to test the behavior
    # after I set max_reads_for_dynamic_bases_per_region.
    FLAGS.max_reads_for_dynamic_bases_per_region = 1
    options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(options)
    examples2 = self.verify_examples(
        child_examples,
        region,
        options,
        verify_labels=False,
        examples_filename=FLAGS.examples,
    )
    self.assertLen(examples2, expected_len_examples2)

  def verify_nist_concordance(self, candidates, nist_variants):
    # Tests that we call almost all of the real variants (according to NIST's
    # Genome in a Bottle callset for NA12878) in our candidate callset.
    # Tests that we don't have an enormous number of FP calls. We should have
    # no more than 5x (arbitrary) more candidate calls than real calls. If we
    # have more it's likely due to some major pipeline problem.
    self.assertLess(len(candidates), 5 * len(nist_variants))
    tp_count = 0
    for nist_variant in nist_variants:
      if self.assertVariantIsPresent(nist_variant, candidates):
        tp_count = tp_count + 1

    self.assertGreater(
        tp_count / len(nist_variants),
        0.9705,
        'Recall must be greater than 0.9705. TP={}, Truth variants={}'.format(
            tp_count, len(nist_variants)
        ),
    )

  def assertDeepVariantExamplesEqual(self, actual, expected):
    """Asserts that actual and expected tf.Examples from DeepVariant are equal.

    Args:
      actual: iterable of tf.Examples from DeepVariant. DeepVariant examples
        that we want to check.
      expected: iterable of tf.Examples. Expected results for actual.
    """
    self.assertEqual(len(actual), len(expected))
    for i in range(len(actual)):
      self.assertEqual(decode_example(actual[i]), decode_example(expected[i]))

  def assertDeepVariantExamplesNotEqual(self, actual, expected):
    """Asserts that actual and expected tf.Examples are not equal.

    Args:
      actual: iterable of tf.Examples from DeepVariant. DeepVariant examples
        that we want to check.
      expected: iterable of tf.Examples.
        Expected results for actual.
    """
    # Collect evidence of any difference (length mismatch or any differing
    # example) and fail only if everything matched.
    pass_not_equal_check = False
    if len(actual) != len(expected):
      logging.warning(
          (
              'In assertDeepVariantExamplesNotEqual: '
              'actual(%d) and expected(%d) has different lengths'
          ),
          len(actual),
          len(expected),
      )
      pass_not_equal_check = True
    min_size = min(len(actual), len(expected))
    for i in range(min_size):
      if decode_example(actual[i]) != decode_example(expected[i]):
        logging.warning(
            (
                'assertDeepVariantExamplesNotEqual: '
                'actual example[%d] and expected example[%d] '
                'are different'
            ),
            i,
            i,
        )
        pass_not_equal_check = True
    self.assertTrue(
        pass_not_equal_check,
        (
            'assertDeepVariantExamplesNotEqual failed - '
            'actual and expected examples are identical.'
        ),
    )

  def assertVariantIsPresent(self, to_find, variants):
    """Returns True if a variant matching to_find's locus and alts is present.

    Despite the assert* name this returns a bool rather than failing the
    test, so callers (e.g. verify_nist_concordance) can count matches.
    """

    def variant_key(v):
      return (v.reference_bases, v.start, v.end)

    # Find calls in our actual call set matching the NIST variant's locus and
    # reference bases; only the first match is inspected below.
    matches = [
        variant
        for variant in variants
        if variant_key(to_find) == variant_key(variant)
    ]
    if not matches:
      return False

    # Verify that every alt allele appears in the call (but the call might
    # have more alts than just those).
    for alt in to_find.alternate_bases:
      if alt not in matches[0].alternate_bases:
        return False

    return True

  def verify_candidate_positions(
      self, candidate_positions_paths, candidate_positions_golden_set
  ):
    """Compares generated candidate positions with the golden set.

    Both files are raw little-endian int32 position arrays; order-insensitive
    equality is checked.
    """
    with epath.Path(candidate_positions_golden_set).open('rb') as my_file:
      positions_golden = np.frombuffer(my_file.read(), dtype=np.int32)
    with epath.Path(candidate_positions_paths).open('rb') as my_file:
      positions = np.frombuffer(my_file.read(), dtype=np.int32)
    logging.warning(
        '%d positions created, %d positions in golden file',
        len(positions),
        len(positions_golden),
    )
    self.assertCountEqual(positions, positions_golden)

  def verify_variants(self, variants, region, options, is_gvcf):
    # Verifies simple properties of the Variant protos in variants. For
    # example, checks that the reference_name() is our expected chromosome.
    # The flag is_gvcf determines how we check the VariantCall field of each
    # variant, enforcing expectations for gVCF records if true or variant
    # calls if false.
    for variant in variants:
      if region:
        self.assertEqual(variant.reference_name, region.reference_name)
        self.assertGreaterEqual(variant.start, region.start)
        self.assertLessEqual(variant.start, region.end)
      self.assertNotEqual(variant.reference_bases, '')
      self.assertNotEmpty(variant.alternate_bases)
      self.assertLen(variant.calls, 1)

      call = variant_utils.only_call(variant)
      # sample_options[1] is the child sample — the sample whose calls these
      # verification checks target.
      self.assertEqual(
          call.call_set_name,
          options.sample_options[1].variant_caller_options.sample_name,
      )
      if is_gvcf:
        # GVCF records should have 0/0 or ./. (un-called) genotypes as they
        # are reference sites, have genotype likelihoods and a GQ value.
        self.assertIn(list(call.genotype), [[0, 0], [-1, -1]])
        self.assertLen(call.genotype_likelihood, 3)
        self.assertGreaterEqual(variantcall_utils.get_gq(call), 0)

  def verify_contiguity(self, contiguous_variants, region):
    """Verifies region is fully covered by gvcf records."""
    # We expect that the intervals cover every base, so the first variant
    # should be at our interval start and the last one should end at our
    # interval end.
    self.assertNotEmpty(contiguous_variants)
    self.assertEqual(region.start, contiguous_variants[0].start)
    self.assertEqual(region.end, contiguous_variants[-1].end)

    # After this loop completes successfully we know that together the GVCF
    # and Variants form a fully contiguous cover of our calling interval, as
    # expected.
    for v1, v2 in zip(contiguous_variants, contiguous_variants[1:]):
      # Sequential variants should be contiguous, meaning that v2.start
      # should be v1's end, as the end is exclusive and the start is
      # inclusive.
      if v1.start == v2.start and v1.end == v2.end:
        # Skip duplicates here as we may have multi-allelic variants turning
        # into multiple bi-allelic variants at the same site.
        continue
      # We expect to immediately follow the end of a gvcf record but to occur
      # at the base immediately after a variant, since the variant's end can
      # span over a larger interval when it's a deletion and we still produce
      # gvcf records under the deletion.
      expected_start = v1.end if v1.alternate_bases == ['<*>'] else v1.start + 1
      self.assertEqual(v2.start, expected_start)

  def verify_deepvariant_calls(self, dv_calls, options):
    # Verifies simple structural properties of the DeepVariantCall objects
    # emitted by the VerySensitiveCaller, such as that the AlleleCount and
    # Variant both have the same position.
    for call in dv_calls:
      for alt_allele in call.variant.alternate_bases:
        # Skip ref calls.
+ if alt_allele == vcf_constants.NO_ALT_ALLELE: + continue + # Make sure allele appears in our allele_support field and that at + # least our min number of reads to call an alt allele are present in + # the supporting reads list for that allele. + self.assertIn(alt_allele, list(call.allele_support)) + self.assertGreaterEqual( + len(call.allele_support[alt_allele].read_names), + options.sample_options[1].variant_caller_options.min_count_snps, + ) + + def sanity_check_example_info_json(self, example, examples_filename, task_id): + """Checks `example_info.json` w/ examples_filename has the right info.""" + example_info_json = dv_utils.get_example_info_json_filename( + examples_filename, task_id + ) + example_info = json.load(gfile.GFile(example_info_json, 'r')) + self.assertIn('shape', example_info) + self.assertEqual( + example_info['shape'], dv_utils.example_image_shape(example) + ) + self.assertIn('channels', example_info) + self.assertLen(example_info['channels'], example_info['shape'][2]) + + def verify_examples( + self, + path_to_output_examples, + region, + options, + verify_labels, + examples_filename=None, + ): + # Do some simple structural checks on the tf.Examples in the file. + expected_features = [ + 'variant/encoded', + 'locus', + 'image/encoded', + 'alt_allele_indices/encoded', + ] + if verify_labels: + expected_features += ['label'] + + examples = list( + tfrecord.read_tfrecords( + path_to_output_examples, compression_type='GZIP' + ) + ) + for example in examples: + for label_feature in expected_features: + self.assertIn(label_feature, example.features.feature) + # pylint: disable=g-explicit-length-test + self.assertNotEmpty(dv_utils.example_alt_alleles_indices(example)) + + # Check that the variants in the examples are good. 
+ variants = [dv_utils.example_variant(x) for x in examples] + self.verify_variants(variants, region, options, is_gvcf=False) + + # In DeepTrio, path_to_output_examples can be pointing to the ones with + # the suffixes (such as _child). In that case, we pass in the original + # examples path to the `examples_filename` arg. + # If `examples_filename` arg, directly use `path_to_output_examples`. + if examples: + if examples_filename is None: + examples_filename = path_to_output_examples + self.sanity_check_example_info_json( + examples[0], examples_filename, options.task_id + ) + return examples + + +class MakeExamplesUnitTest(parameterized.TestCase): + + def test_read_write_run_info(self): + def _read_lines(path): + with open(path) as fin: + return list(fin.readlines()) + + golden_actual = make_examples_core.read_make_examples_run_info( + testdata.GOLDEN_MAKE_EXAMPLES_RUN_INFO + ) + # We don't really want to inject too much knowledge about the golden right + # here, so we only use a minimal test that (a) the run_info_filename is + # a non-empty string and (b) the number of candidates sites in the labeling + # metrics field is greater than 0. Any reasonable golden output will have at + # least one candidate variant, and the reader should have filled in the + # value. 
+ self.assertNotEmpty(golden_actual.options.run_info_filename) + self.assertEqual( + golden_actual.labeling_metrics.n_candidate_variant_sites, + testdata.N_GOLDEN_TRAINING_EXAMPLES, + ) + + # Check that reading + writing the data produces the same lines: + tmp_output = test_utils.test_tmpfile('written_run_info.pbtxt') + make_examples_core.write_make_examples_run_info(golden_actual, tmp_output) + print('*' * 100) + print(_read_lines(tmp_output)) + print('*' * 100) + self.assertEqual( + _read_lines(testdata.GOLDEN_MAKE_EXAMPLES_RUN_INFO), + _read_lines(tmp_output), + ) + + @flagsaver.flagsaver + def test_keep_duplicates(self): + FLAGS.keep_duplicates = True + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF + FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED + FLAGS.mode = 'training' + FLAGS.examples = '' + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + options = make_examples.default_options(add_flags=True) + self.assertEqual( + options.pic_options.read_requirements.keep_duplicates, True + ) + + @flagsaver.flagsaver + def test_keep_supplementary_alignments(self): + FLAGS.keep_supplementary_alignments = True + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF + FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED + FLAGS.mode = 'training' + FLAGS.examples = '' + FLAGS.channel_list 
= ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + options = make_examples.default_options(add_flags=True) + self.assertEqual( + options.pic_options.read_requirements.keep_supplementary_alignments, + True, + ) + + @flagsaver.flagsaver + def test_keep_secondary_alignments(self): + FLAGS.keep_secondary_alignments = True + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF + FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED + FLAGS.mode = 'training' + FLAGS.examples = '' + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + options = make_examples.default_options(add_flags=True) + self.assertEqual( + options.pic_options.read_requirements.keep_secondary_alignments, True + ) + + @flagsaver.flagsaver + def test_min_base_quality(self): + FLAGS.min_base_quality = 5 + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF + FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED + FLAGS.mode = 'training' + FLAGS.examples = '' + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + options = make_examples.default_options(add_flags=True) + self.assertEqual(options.pic_options.read_requirements.min_base_quality, 5) + + @flagsaver.flagsaver + def test_min_mapping_quality(self): + FLAGS.min_mapping_quality = 15 + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = 
testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF + FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED + FLAGS.mode = 'training' + FLAGS.examples = '' + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + options = make_examples.default_options(add_flags=True) + self.assertEqual( + options.pic_options.read_requirements.min_mapping_quality, 15 + ) + + @flagsaver.flagsaver + def test_default_options_with_training_random_emit_ref_sites(self): + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF + FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED + FLAGS.mode = 'training' + FLAGS.examples = '' + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + + FLAGS.training_random_emit_ref_sites = 0.3 + options = make_examples.default_options(add_flags=True) + self.assertAlmostEqual( + options.sample_options[ + 1 + ].variant_caller_options.fraction_reference_sites_to_emit, + 0.3, + ) + + @flagsaver.flagsaver + def test_default_options_without_training_random_emit_ref_sites(self): + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + 
FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF + FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED + FLAGS.mode = 'training' + FLAGS.examples = '' + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + + options = make_examples.default_options(add_flags=True) + # In proto3, there is no way to check presence of scalar field: + # redacted + # As an approximation, we directly check that the value should be exactly 0. + self.assertEqual( + options.sample_options[ + 1 + ].variant_caller_options.fraction_reference_sites_to_emit, + 0.0, + ) + + @flagsaver.flagsaver + def test_confident_regions(self): + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF + FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED + FLAGS.mode = 'training' + FLAGS.examples = '' + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + + options = make_examples.default_options(add_flags=True) + confident_regions = make_examples_core.read_confident_regions(options) + + # Our expected intervals, inlined from CONFIDENT_REGIONS_BED. + expected = _from_literals_list([ + '20:10000847-10002407', + '20:10002521-10004171', + '20:10004274-10004964', + '20:10004995-10006386', + '20:10006410-10007800', + '20:10007825-10008018', + '20:10008044-10008079', + '20:10008101-10008707', + '20:10008809-10008897', + '20:10009003-10009791', + '20:10009934-10010531', + ]) + # Our confident regions should be exactly those found in the BED file. 
+ self.assertCountEqual(expected, list(confident_regions)) + + @parameterized.parameters( + ({'examples': ('foo', 'foo')},), + ({'examples': ('foo', 'foo'), 'gvcf': ('bar', 'bar')},), + ({'examples': ('foo@10', 'foo-00000-of-00010')},), + ({'task': (0, 0), 'examples': ('foo@10', 'foo-00000-of-00010')},), + ({'task': (1, 1), 'examples': ('foo@10', 'foo-00001-of-00010')},), + ( + { + 'task': (1, 1), + 'examples': ('foo@10', 'foo-00001-of-00010'), + 'gvcf': ('bar@10', 'bar-00001-of-00010'), + }, + ), + ( + { + 'task': (1, 1), + 'examples': ('foo@10', 'foo-00001-of-00010'), + 'gvcf': ('bar@10', 'bar-00001-of-00010'), + 'candidates': ('baz@10', 'baz-00001-of-00010'), + }, + ), + ) + @flagsaver.flagsaver + def test_sharded_outputs1(self, settings): + # Set all of the requested flag values. + for name, (flag_val, _) in settings.items(): + setattr(FLAGS, name, flag_val) + + FLAGS.mode = 'training' + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + FLAGS.reads = '' + FLAGS.ref = '' + options = make_examples.default_options(add_flags=True) + + # Check all of the flags. + for name, option_val in [ + ('examples', options.examples_filename), + ('candidates', options.candidates_filename), + ('gvcf', options.gvcf_filename), + ]: + expected = settings[name][1] if name in settings else '' + self.assertEqual(expected, option_val) + + def test_catches_bad_argv(self): + with ( + mock.patch.object(logging, 'error') as mock_logging, + mock.patch.object(sys, 'exit') as mock_exit, + ): + make_examples.main(['make_examples.py', 'extra_arg']) + mock_logging.assert_called_once_with( + 'Command line parsing failure: make_examples does not accept ' + 'positional arguments but some are present on the command line: ' + "\"['make_examples.py', 'extra_arg']\"." + ) + mock_exit.assert_called_once_with(errno.ENOENT) + + @flagsaver.flagsaver + def test_catches_bad_flags(self): + # Set all of the requested flag values. 
+ region = ranges.parse_literal('20:10,000,000-10,010,000') + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.candidates = test_utils.test_tmpfile('vsc.tfrecord') + FLAGS.examples = test_utils.test_tmpfile('examples.tfrecord') + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + FLAGS.regions = [ranges.to_literal(region)] + FLAGS.partition_size = 1000 + FLAGS.mode = 'training' + FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF + # This is the bad flag. + FLAGS.confident_regions = '' + + with ( + mock.patch.object(logging, 'error') as mock_logging, + mock.patch.object(sys, 'exit') as mock_exit, + ): + make_examples.main(['make_examples.py']) + mock_logging.assert_called_once_with( + 'confident_regions is required when in training mode.' 
+ ) + mock_exit.assert_called_once_with(errno.ENOENT) + + @flagsaver.flagsaver + def test_regions_and_exclude_regions_flags_with_trio_options(self): + FLAGS.mode = 'calling' + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.regions = '20:10,000,000-11,000,000' + FLAGS.examples = 'examples.tfrecord' + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + FLAGS.exclude_regions = '20:10,010,000-10,100,000' + + options = make_examples.default_options(add_flags=True) + _, regions_from_options = ( + make_examples_core.processing_regions_from_options(options) + ) + self.assertCountEqual( + list(ranges.RangeSet(regions_from_options)), + _from_literals_list( + ['20:10,000,000-10,009,999', '20:10,100,001-11,000,000'] + ), + ) + + @flagsaver.flagsaver + def test_incorrect_empty_regions_with_trio_options(self): + FLAGS.mode = 'calling' + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + # Deliberately incorrect contig name. 
+ FLAGS.regions = 'xxx20:10,000,000-11,000,000' + FLAGS.examples = 'examples.tfrecord' + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + + options = make_examples.default_options(add_flags=True) + with self.assertRaisesRegex(ValueError, 'The regions to call is empty.'): + make_examples_core.processing_regions_from_options(options) + + +class RegionProcessorTest(parameterized.TestCase): + + def setUp(self): + super(RegionProcessorTest, self).setUp() + self.region = ranges.parse_literal('20:10,000,000-10,000,100') + + FLAGS.reads = '' + self.options = make_examples.default_options(add_flags=False) + self.options.reference_filename = testdata.CHR20_FASTA + self.options.truth_variants_filename = testdata.TRUTH_VARIANTS_VCF + self.options.mode = deepvariant_pb2.MakeExamplesOptions.TRAINING + + self.ref_reader = fasta.IndexedFastaReader(self.options.reference_filename) + self.default_shape = [5, 5, 7] + self.processor = make_examples_core.RegionProcessor(self.options) + self.mock_init = self.add_mock('_initialize') + for sample in self.processor.samples: + sample.in_memory_sam_reader = mock.Mock() + + def add_mock(self, name, retval='dontadd', side_effect='dontadd'): + patcher = mock.patch.object(self.processor, name, autospec=True) + self.addCleanup(patcher.stop) + mocked = patcher.start() + if retval != 'dontadd': + mocked.return_value = retval + if side_effect != 'dontadd': + mocked.side_effect = side_effect + return mocked + + @parameterized.parameters([ + deepvariant_pb2.MakeExamplesOptions.TRAINING, + deepvariant_pb2.MakeExamplesOptions.CALLING, + ]) + def test_process_keeps_ordering_of_candidates_and_examples(self, mode): + self.processor.options.mode = mode + + r1, r2 = mock.Mock(), mock.Mock() + c1, c2 = mock.Mock(), mock.Mock() + self.add_mock('region_reads_norealign', retval=[r1, r2]) + self.add_mock('candidates_in_region', retval=({'child': [c1, c2]}, {}, {})) + candidates_dict, gvcfs_dict, runtimes, read_phases = 
self.processor.process( + self.region + ) + self.assertEqual({'child': [c1, c2]}, candidates_dict) + self.assertEqual({}, gvcfs_dict) + self.assertEqual({}, read_phases) + self.assertIsInstance(runtimes, dict) + + in_memory_sam_reader = self.processor.samples[1].in_memory_sam_reader + in_memory_sam_reader.replace_reads.assert_called_once_with([r1, r2]) + + @flagsaver.flagsaver + def test_use_original_quality_scores_without_parse_sam_aux_fields(self): + FLAGS.mode = 'calling' + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.examples = 'examples.tfrecord' + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + FLAGS.use_original_quality_scores = True + FLAGS.parse_sam_aux_fields = False + + with self.assertRaisesRegex( + Exception, + ( + 'If --use_original_quality_scores is set then ' + '--parse_sam_aux_fields must be set too.' 
+ ), + ): + make_examples.default_options(add_flags=True) + + @parameterized.parameters( + dict(height_parent=10, height_child=9), + dict(height_parent=9, height_child=10), + dict(height_parent=150, height_child=101), + dict(height_parent=101, height_child=170), + ) + @flagsaver.flagsaver + def test_image_heights(self, height_parent, height_child): + FLAGS.pileup_image_height_parent = height_parent + FLAGS.pileup_image_height_child = height_child + FLAGS.mode = 'calling' + FLAGS.ref = testdata.CHR20_FASTA + FLAGS.reads = testdata.HG001_CHR20_BAM + FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM + FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM + FLAGS.sample_name = 'child' + FLAGS.sample_name_to_train = 'child' + FLAGS.sample_name_parent1 = 'parent1' + FLAGS.sample_name_parent2 = 'parent2' + FLAGS.examples = 'examples.tfrecord' + FLAGS.channel_list = ','.join(dv_constants.PILEUP_CHANNELS_WITH_INSERT_SIZE) + + options = make_examples.default_options(add_flags=True) + with self.assertRaisesRegex( + Exception, 'Total pileup image heights must be between 75-362.' + ): + make_examples.check_options_are_valid(options) + + +if __name__ == '__main__': + absltest.main()