# Copyright 2020 Google LLC.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""Tests for deeptrio.make_examples."""
import copy
import errno
import json
import platform
import sys
from unittest import mock
from absl import flags
from absl import logging
from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
from etils import epath
import numpy as np
from deeptrio import make_examples
from deeptrio import testdata
from deepvariant import dv_utils
from deepvariant import dv_utils_using_clif
from deepvariant import make_examples_core
from deepvariant.labeler import variant_labeler
from deepvariant.protos import deepvariant_pb2
from tensorflow.python.platform import gfile
from third_party.nucleus.io import fasta
from third_party.nucleus.io import sharded_file_utils
from third_party.nucleus.io import tfrecord
from third_party.nucleus.io import vcf
from third_party.nucleus.protos import reference_pb2
from third_party.nucleus.protos import variants_pb2
from third_party.nucleus.testing import test_utils
from third_party.nucleus.util import ranges
from third_party.nucleus.util import variant_utils
from third_party.nucleus.util import variantcall_utils
from third_party.nucleus.util import vcf_constants
# Shared absl flags object. The tests in this file configure each
# make_examples run by assigning attributes on it.
FLAGS = flags.FLAGS
# Dictionary mapping keys to decoders for decode_example function.
# decode_example() looks up every feature key of an example here and calls the
# matching decoder with the whole example, so a feature without an entry makes
# decode_example() raise KeyError.
_EXAMPLE_DECODERS = {
'locus': dv_utils.example_locus,
'alt_allele_indices/encoded': dv_utils.example_alt_alleles_indices,
'image/encoded': dv_utils.example_encoded_image,
'variant/encoded': dv_utils.example_variant,
'variant_type': dv_utils.example_variant_type,
'label': dv_utils.example_label,
'image/shape': dv_utils.example_image_shape,
'sequencing_type': dv_utils.example_sequencing_type,
'denovo_label': dv_utils.example_denovo_label,
}
def decode_example(example):
  """Converts a DeepVariant tf.Example into a plain Python dictionary.

  Every feature key present in the example is looked up in _EXAMPLE_DECODERS
  and the matching decoder is applied to the whole example.

  Args:
    example: tf.Example proto. The example to turn into a dictionary.

  Returns:
    A dict with one entry per feature of example, keyed by the feature name,
    whose value is whatever the corresponding decoder returned.

  Raises:
    KeyError: If example contains a feature without a known decoder.
  """
  decoded = {}
  for feature_name in example.features.feature:
    decoder = _EXAMPLE_DECODERS.get(feature_name)
    if decoder is None:
      raise KeyError('Unexpected example key', feature_name)
    decoded[feature_name] = decoder(example)
  return decoded
def setUpModule():
  """Module-level fixture: sets log verbosity to FATAL and inits testdata."""
  quiet_level = logging.FATAL
  logging.set_verbosity(quiet_level)
  testdata.init()
def _make_contigs(specs):
"""Makes ContigInfo protos from specs.
Args:
specs: A list of 2- or 3-tuples. All tuples should be of the same length. If
2-element, these should be the name and length in basepairs of each
contig, and their pos_in_fasta will be set to their index in the list. If
the 3-element, the tuple should contain name, length, and pos_in_fasta.
Returns:
A list of ContigInfo protos, one for each spec in specs.
"""
if specs and len(specs[0]) == 3:
return [
reference_pb2.ContigInfo(name=name, n_bases=length, pos_in_fasta=i)
for name, length, i in specs
]
else:
return [
reference_pb2.ContigInfo(name=name, n_bases=length, pos_in_fasta=i)
for i, (name, length) in enumerate(specs)
]
def _from_literals_list(literals, contig_map=None):
  """Parses region literals with ranges.parse_literals and returns the result."""
  parsed_ranges = ranges.parse_literals(literals, contig_map)
  return parsed_ranges
def _from_literals(literals, contig_map=None):
  """Builds a RangeSet from region literals via RangeSet.from_regions."""
  range_set = ranges.RangeSet.from_regions(literals, contig_map)
  return range_set
def _sharded(basename, num_shards=None):
if num_shards:
return basename + '@' + str(num_shards)
else:
return basename
class MakeExamplesEnd2EndTest(parameterized.TestCase):
# Golden sets are created with
# learning/genomics/internal/create_golden_deep_trio.sh
@parameterized.parameters(
# All tests are run with fast_pass_aligner enabled. There are no
# golden sets version for ssw realigner.
dict(mode='calling', num_shards=0),
dict(mode='calling', num_shards=3),
dict(mode='candidate_sweep', num_shards=0),
dict(mode='candidate_sweep', num_shards=3),
dict(
mode='training', num_shards=0, labeler_algorithm='haplotype_labeler'
),
dict(
mode='training', num_shards=3, labeler_algorithm='haplotype_labeler'
),
dict(
mode='training', num_shards=0, labeler_algorithm='positional_labeler'
),
dict(
mode='training', num_shards=3, labeler_algorithm='positional_labeler'
),
)
@flagsaver.flagsaver
def test_make_examples_end2end(
self, mode, num_shards, labeler_algorithm=None, use_fast_pass_aligner=True
):
"""Runs make_examples on region 20:10,000,000-10,010,000 and checks outputs.

The run_info, candidates, examples and (in calling mode) gVCF records are
verified and compared against golden files from testdata.

Args:
mode: One of 'calling', 'training' or 'candidate_sweep'; assigned to
FLAGS.mode.
num_shards: Shard count used to build the output file specs. The runner is
invoked once per task in range(max(num_shards, 1)), so 0 runs a single
task with unsharded file names.
labeler_algorithm: If not None, assigned to FLAGS.labeler_algorithm.
use_fast_pass_aligner: Assigned to FLAGS.use_fast_pass_aligner.
"""
self.assertIn(mode, {'calling', 'training', 'candidate_sweep'})
region = ranges.parse_literal('20:10,000,000-10,010,000')
FLAGS.write_run_info = True
FLAGS.ref = testdata.CHR20_FASTA
FLAGS.reads = testdata.HG001_CHR20_BAM
FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
FLAGS.sample_name = 'child'
FLAGS.sample_name_to_train = 'child'
FLAGS.sample_name_parent1 = 'parent1'
FLAGS.sample_name_parent2 = 'parent2'
FLAGS.candidates = test_utils.test_tmpfile(
_sharded('vsc.tfrecord', num_shards)
)
FLAGS.examples = test_utils.test_tmpfile(
_sharded('examples.tfrecord', num_shards)
)
# NOTE(review): the '*_child' paths in this test are never assigned to a
# flag; presumably make_examples derives them from the flag values by adding
# a per-sample suffix in calling mode - confirm against make_examples_core.
child_examples = test_utils.test_tmpfile(
_sharded('examples_child.tfrecord', num_shards)
)
if mode == 'candidate_sweep':
FLAGS.candidate_positions = test_utils.test_tmpfile(
_sharded('candidate_positions', num_shards)
)
candidate_positions = test_utils.test_tmpfile(
_sharded('candidate_positions', num_shards)
)
FLAGS.regions = [ranges.to_literal(region)]
FLAGS.partition_size = 1000
FLAGS.mode = mode
FLAGS.gvcf_gq_binsize = 5
FLAGS.use_fast_pass_aligner = use_fast_pass_aligner
if labeler_algorithm is not None:
FLAGS.labeler_algorithm = labeler_algorithm
if mode == 'calling':
FLAGS.gvcf = test_utils.test_tmpfile(
_sharded('gvcf.tfrecord', num_shards)
)
child_gvcf = test_utils.test_tmpfile(
_sharded('gvcf_child.tfrecord', num_shards)
)
child_candidates = test_utils.test_tmpfile(
_sharded('vsc_child.tfrecord', num_shards)
)
else:
FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED
child_candidates = test_utils.test_tmpfile(
_sharded('vsc.tfrecord', num_shards)
)
if mode == 'candidate_sweep':
golden_candidate_positions = _sharded(
testdata.GOLDEN_CANDIDATE_POSITIONS, num_shards
)
# One runner invocation per task; FLAGS.task selects the shard to produce.
for task_id in range(max(num_shards, 1)):
FLAGS.task = task_id
options = make_examples.default_options(add_flags=True)
make_examples_core.make_examples_runner(options)
# Check that our run_info proto contains the basic fields we'd expect:
# (a) our options are written to the run_info.options field.
run_info = make_examples_core.read_make_examples_run_info(
options.run_info_filename
)
self.assertEqual(run_info.options, options)
# (b) run_info.resource_metrics is present and contains our hostname.
self.assertTrue(run_info.HasField('resource_metrics'))
self.assertEqual(run_info.resource_metrics.host_name, platform.node())
# For candidate_sweep mode we verify that candidate positions match
# golden set exactly.
if mode == 'candidate_sweep':
_, candidates_path = sharded_file_utils.resolve_filespecs(
task_id, candidate_positions
)
_, gold_candidates_path = sharded_file_utils.resolve_filespecs(
task_id, golden_candidate_positions
)
self.verify_candidate_positions(candidates_path, gold_candidates_path)
# In candidate_sweep mode the test stops here.
if mode == 'candidate_sweep':
return
# Test that our candidates are reasonable, calling specific helper functions
# to check lots of properties of the output.
candidates = sorted(
tfrecord.read_tfrecords(
child_candidates, proto=deepvariant_pb2.DeepVariantCall
),
key=lambda c: variant_utils.variant_range_tuple(c.variant),
)
self.verify_deepvariant_calls(candidates, options)
self.verify_variants(
[call.variant for call in candidates], region, options, is_gvcf=False
)
# Verify that the variants in the examples are all good.
# Calling mode reads the '_child' examples path; training mode reads
# FLAGS.examples directly and additionally verifies labels.
if mode == 'calling':
examples = self.verify_examples(
child_examples,
region,
options,
verify_labels=False,
examples_filename=FLAGS.examples,
)
if mode == 'training':
examples = self.verify_examples(
FLAGS.examples, region, options, verify_labels=True
)
example_variants = [dv_utils.example_variant(ex) for ex in examples]
self.verify_variants(example_variants, region, options, is_gvcf=False)
# Verify the integrity of the examples and then check that they match our
# golden labeled examples. Note we expect the order for both training and
# calling modes to produce deterministic order because we fix the random
# seed.
if mode == 'calling':
golden_file = _sharded(testdata.GOLDEN_CALLING_EXAMPLES, num_shards)
else:
golden_file = _sharded(testdata.GOLDEN_TRAINING_EXAMPLES, num_shards)
self.assertDeepVariantExamplesEqual(
examples, list(tfrecord.read_tfrecords(golden_file))
)
if mode == 'calling':
nist_reader = vcf.VcfReader(testdata.TRUTH_VARIANTS_VCF)
nist_variants = list(nist_reader.query(region))
self.verify_nist_concordance(example_variants, nist_variants)
# Check the quality of our generated gvcf file.
gvcfs = variant_utils.sorted_variants(
tfrecord.read_tfrecords(child_gvcf, proto=variants_pb2.Variant)
)
self.verify_variants(gvcfs, region, options, is_gvcf=True)
self.verify_contiguity(gvcfs, region)
gvcf_golden_file = _sharded(
testdata.GOLDEN_POSTPROCESS_GVCF_INPUT, num_shards
)
expected_gvcfs = list(
tfrecord.read_tfrecords(gvcf_golden_file, proto=variants_pb2.Variant)
)
# Despite its name, assertCountEqual checks that all items are equal.
self.assertCountEqual(gvcfs, expected_gvcfs)
if (
mode == 'training'
and num_shards == 0
and labeler_algorithm != 'positional_labeler'
):
# The positional labeler doesn't track metrics, so don't try to read them
# in when that's the mode.
self.assertEqual(
make_examples_core.read_make_examples_run_info(
testdata.GOLDEN_MAKE_EXAMPLES_RUN_INFO
).labeling_metrics,
run_info.labeling_metrics,
)
@parameterized.parameters(
dict(
denovo_test=False,
expected_denovo_variants=0,
),
dict(
denovo_test=True,
expected_denovo_variants=3,
),
)
@flagsaver.flagsaver
def test_make_examples_ont_end2end(
self,
denovo_test: bool,
expected_denovo_variants: int,
):
"""Test end to end for long ONT reads with phasing enabled.

This test runs ONT end to end and compares the output with the golden
output. This test is introduced because previously in training mode the
non training sample would not be phased. So this now tests to make sure
all of the training examples are phased correctly.

Args:
denovo_test: If true, denovo parameters will be set.
expected_denovo_variants: Total number of denovo examples expected.
"""
region = ranges.parse_literal('chr20:5050000-5075000')
FLAGS.write_run_info = True
FLAGS.ref = testdata.GRCH38_CHR0_FASTA
FLAGS.reads = testdata.ONT_HG002_BAM
FLAGS.reads_parent1 = testdata.ONT_HG003_BAM
FLAGS.reads_parent2 = testdata.ONT_HG004_BAM
FLAGS.confident_regions = testdata.HG002_HIGH_CONFIDENCE_BED
FLAGS.truth_variants = testdata.HG002_HIGH_CONFIDENCE_VCF
FLAGS.sample_name = 'HG002'
FLAGS.sample_name_to_train = 'HG002'
FLAGS.sample_name_parent1 = 'HG003'
FLAGS.sample_name_parent2 = 'HG004'
FLAGS.add_hp_channel = True
FLAGS.alt_aligned_pileup = 'diff_channels'
FLAGS.min_mapping_quality = 1
FLAGS.mode = 'training'
FLAGS.parse_sam_aux_fields = True
FLAGS.partition_size = 25000
FLAGS.phase_reads = True
FLAGS.pileup_image_height_child = 100
FLAGS.pileup_image_height_parent = 100
FLAGS.pileup_image_width = 199
FLAGS.realign_reads = False
FLAGS.skip_parent_calling = True
FLAGS.sort_by_haplotypes = True
FLAGS.track_ref_reads = True
FLAGS.vsc_min_fraction_indels = 0.12
FLAGS.vsc_min_fraction_snps = 0.1
# This test always runs unsharded: _sharded() returns the bare name for 0 and
# the task loop below runs exactly once.
num_shards = 0
FLAGS.examples = test_utils.test_tmpfile(
_sharded('examples.tfrecord', num_shards)
)
FLAGS.regions = [ranges.to_literal(region)]
golden_file = _sharded(testdata.GOLDEN_ONT_MAKE_EXAMPLES_OUTPUT, num_shards)
FLAGS.denovo_regions = None
if denovo_test:
# If denovo test is enabled, then set the parameters for denovo testing.
golden_file = _sharded(
testdata.GOLDEN_ONT_DENOVO_MAKE_EXAMPLES_OUTPUT, num_shards
)
# NOTE(review): write_run_info was already set to True at the top of this
# test, so this assignment is redundant.
FLAGS.write_run_info = True
FLAGS.denovo_regions = testdata.HG002_DENOVO_BED
for task_id in range(max(num_shards, 1)):
FLAGS.task = task_id
options = make_examples.default_options(add_flags=True)
make_examples_core.make_examples_runner(options)
examples = self.verify_examples(
FLAGS.examples, region, options, verify_labels=True
)
self.assertDeepVariantExamplesEqual(
examples, list(tfrecord.read_tfrecords(golden_file))
)
if denovo_test:
# Check total number of denovo examples.
total_denovo = sum(
[
1
for example in examples
if dv_utils.example_denovo_label(example)
]
)
self.assertEqual(
total_denovo,
expected_denovo_variants,
msg='ONT denovo golden test: denovo variants count.',
)
# Read the runinfo file
# The run info is read from the examples path with a '.run_info.pbtxt'
# suffix, for both this run and the golden output.
runinfo = make_examples_core.read_make_examples_run_info(
FLAGS.examples + '.run_info.pbtxt'
)
golden_runinfo = make_examples_core.read_make_examples_run_info(
testdata.GOLDEN_ONT_DENOVO_MAKE_EXAMPLES_OUTPUT + '.run_info.pbtxt'
)
self.assertEqual(
runinfo.stats.num_examples,
golden_runinfo.stats.num_examples,
msg='ONT denovo golden test: Run info comparison num_examples.',
)
self.assertEqual(
runinfo.stats.num_denovo,
golden_runinfo.stats.num_denovo,
msg='ONT denovo golden test: Run info comparison num_denovo.',
)
self.assertEqual(
runinfo.stats.num_nondenovo,
golden_runinfo.stats.num_nondenovo,
msg='ONT denovo golden test: Run info comparison num_nondenovo.',
)
# Golden sets are created with learning/genomics/internal/create_golden.sh
@flagsaver.flagsaver
def test_make_examples_training_end2end_with_customized_classes_labeler(self):
  """Training run with customized_classes_labeler matches its golden file."""
  # Labeler under test and its class configuration.
  FLAGS.labeler_algorithm = 'customized_classes_labeler'
  FLAGS.customized_classes_labeler_classes_list = 'ref,class1,class2'
  FLAGS.customized_classes_labeler_info_field_name = 'type'
  calling_region = ranges.parse_literal('20:10,000,000-10,004,000')
  FLAGS.regions = [ranges.to_literal(calling_region)]

  # Trio inputs.
  FLAGS.ref = testdata.CHR20_FASTA
  FLAGS.reads = testdata.HG001_CHR20_BAM
  FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
  FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
  FLAGS.sample_name = 'child'
  FLAGS.sample_name_to_train = 'child'
  FLAGS.sample_name_parent1 = 'parent1'
  FLAGS.sample_name_parent2 = 'parent2'

  # Outputs and run configuration.
  FLAGS.candidates = test_utils.test_tmpfile(_sharded('vsc.tfrecord'))
  FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord'))
  FLAGS.partition_size = 1000
  FLAGS.mode = 'training'
  FLAGS.gvcf_gq_binsize = 5
  FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF_WITH_TYPES
  FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED

  run_options = make_examples.default_options(add_flags=True)
  make_examples_core.make_examples_runner(run_options)

  golden_path = _sharded(testdata.CUSTOMIZED_CLASSES_GOLDEN_TRAINING_EXAMPLES)
  # Verify that the variants in the examples are all good.
  actual_examples = self.verify_examples(
      FLAGS.examples, calling_region, run_options, verify_labels=True
  )
  golden_examples = list(tfrecord.read_tfrecords(golden_path))
  self.assertDeepVariantExamplesEqual(actual_examples, golden_examples)
# Golden sets are created with learning/genomics/internal/create_golden.sh
@flagsaver.flagsaver
def test_make_examples_training_end2end_with_alt_aligned_pileup(self):
  """Training run with alt_aligned_pileup='diff_channels' matches golden."""
  calling_region = ranges.parse_literal('20:10,000,000-10,010,000')
  FLAGS.regions = [ranges.to_literal(calling_region)]

  # Trio inputs.
  FLAGS.ref = testdata.CHR20_FASTA
  FLAGS.reads = testdata.HG001_CHR20_BAM
  FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
  FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
  FLAGS.sample_name = 'child'
  FLAGS.sample_name_to_train = 'child'
  FLAGS.sample_name_parent1 = 'parent1'
  FLAGS.sample_name_parent2 = 'parent2'

  # Outputs and run configuration.
  FLAGS.candidates = test_utils.test_tmpfile(_sharded('vsc.tfrecord'))
  FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord'))
  FLAGS.partition_size = 1000
  FLAGS.mode = 'training'
  FLAGS.gvcf_gq_binsize = 5

  # Pileup settings specific to this test.
  FLAGS.alt_aligned_pileup = 'diff_channels'
  FLAGS.pileup_image_height_child = 60
  FLAGS.pileup_image_height_parent = 40
  FLAGS.pileup_image_width = 199

  FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
  FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED

  run_options = make_examples.default_options(add_flags=True)
  make_examples_core.make_examples_runner(run_options)

  golden_path = _sharded(testdata.ALT_ALIGNED_PILEUP_GOLDEN_TRAINING_EXAMPLES)
  # Verify that the variants in the examples are all good.
  actual_examples = self.verify_examples(
      FLAGS.examples, calling_region, run_options, verify_labels=True
  )
  golden_examples = list(tfrecord.read_tfrecords(golden_path))
  self.assertDeepVariantExamplesEqual(actual_examples, golden_examples)

  # Pileup image should now have 8 channels.
  # Height should be 60 + 40 * 2 = 140.
  first_image_shape = decode_example(actual_examples[0])['image/shape']
  self.assertEqual(first_image_shape, [140, 199, 8])
@flagsaver.flagsaver
def test_make_examples_compare_realignment_modes(self):
  """Checks that toggling joint realignment changes the calling examples.

  make_examples is run twice in calling mode over the same region, once with
  FLAGS.enable_joint_realignment off and once with it on. Both runs must
  produce examples, and the two example lists must differ.
  """

  def _run_with_realignment_mode(enable_joint_realignment, name):
    """Runs make_examples once and returns the verified examples.

    Args:
      enable_joint_realignment: Value for FLAGS.enable_joint_realignment.
      name: Prefix of this run's output files, keeping the two runs apart.

    Returns:
      The examples returned by self.verify_examples for this run.
    """
    FLAGS.enable_joint_realignment = enable_joint_realignment
    region = ranges.parse_literal('20:10,000,000-10,010,000')
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'
    FLAGS.candidates = test_utils.test_tmpfile(f'{name}.vsc.tfrecord')
    FLAGS.examples = test_utils.test_tmpfile(f'{name}.examples.tfrecord')
    child_examples = test_utils.test_tmpfile(
        f'{name}_child.examples.tfrecord'
    )
    FLAGS.regions = [ranges.to_literal(region)]
    FLAGS.partition_size = 1000
    FLAGS.mode = 'calling'
    FLAGS.gvcf = test_utils.test_tmpfile(f'{name}.gvcf.tfrecord')
    options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(options)
    return self.verify_examples(
        child_examples,
        region,
        options,
        verify_labels=False,
        examples_filename=FLAGS.examples,
    )

  examples1 = _run_with_realignment_mode(False, 'ex1')
  examples2 = _run_with_realignment_mode(True, 'ex2')
  self.assertNotEmpty(examples1)
  self.assertNotEmpty(examples2)
  # The assumption is just that these two lists of examples should be
  # different. In this case, it happens to be that we got different numbers
  # of examples.
  self.assertDeepVariantExamplesNotEqual(examples1, examples2)
@parameterized.parameters(
    dict(select_types=None, expected_count=79),
    dict(select_types='all', expected_count=79),
    dict(select_types='snps', expected_count=64),
    dict(select_types='indels', expected_count=12),
    dict(select_types='snps indels', expected_count=76),
    dict(select_types='multi-allelics', expected_count=3),
    dict(select_types=None, keep_legacy_behavior=True, expected_count=79),
    dict(select_types='all', keep_legacy_behavior=True, expected_count=79),
    dict(select_types='snps', keep_legacy_behavior=True, expected_count=64),
    dict(select_types='indels', keep_legacy_behavior=True, expected_count=11),
    dict(
        select_types='snps indels',
        keep_legacy_behavior=True,
        expected_count=75,
    ),
    dict(
        select_types='multi-allelics',
        keep_legacy_behavior=True,
        expected_count=4,
    ),
)
@flagsaver.flagsaver
def test_make_examples_with_variant_selection(
    self, select_types, expected_count, keep_legacy_behavior=False
):
  """Checks the candidate count for each --select_variant_types value.

  Args:
    select_types: Value for FLAGS.select_variant_types, or None to leave the
      flag untouched.
    expected_count: Number of candidate records expected in the output.
    keep_legacy_behavior: Value for
      FLAGS.keep_legacy_allele_counter_behavior.
  """
  if select_types is not None:
    FLAGS.select_variant_types = select_types
  calling_region = ranges.parse_literal('20:10,000,000-10,010,000')
  FLAGS.regions = [ranges.to_literal(calling_region)]

  # Trio inputs.
  FLAGS.ref = testdata.CHR20_FASTA
  FLAGS.reads = testdata.HG001_CHR20_BAM
  FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
  FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
  FLAGS.sample_name = 'child'
  FLAGS.sample_name_to_train = 'child'
  FLAGS.sample_name_parent1 = 'parent1'
  FLAGS.sample_name_parent2 = 'parent2'

  # Outputs and run configuration.
  FLAGS.candidates = test_utils.test_tmpfile(_sharded('vsc.tfrecord'))
  child_candidates_path = test_utils.test_tmpfile(
      _sharded('vsc_child.tfrecord')
  )
  FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord'))
  FLAGS.partition_size = 1000
  FLAGS.mode = 'calling'
  FLAGS.keep_legacy_allele_counter_behavior = keep_legacy_behavior

  run_options = make_examples.default_options(add_flags=True)
  make_examples_core.make_examples_runner(run_options)

  candidate_records = list(tfrecord.read_tfrecords(child_candidates_path))
  self.assertLen(candidate_records, expected_count)
@parameterized.parameters(
    dict(
        mode='calling', which_parent='parent1', sample_name_to_train='child'
    ),
    dict(
        mode='calling', which_parent='parent2', sample_name_to_train='child'
    ),
    dict(
        mode='training', which_parent='parent1', sample_name_to_train='child'
    ),
    dict(
        mode='training', which_parent='parent2', sample_name_to_train='child'
    ),
    dict(
        mode='calling', which_parent='parent1', sample_name_to_train='parent1'
    ),
    dict(
        mode='training',
        which_parent='parent1',
        sample_name_to_train='parent1',
    ),
    # Training on parent2 in a duo is not supported (with a clear error
    # message).
)
@flagsaver.flagsaver
def test_make_examples_training_end2end_duos(
    self, mode, which_parent, sample_name_to_train
):
  """Smoke test: make_examples runs with the child plus a single parent.

  Args:
    mode: Value for FLAGS.mode; 'training' also sets the truth inputs.
    which_parent: 'parent1' or 'parent2', the only parent whose reads and
      sample name are provided.
    sample_name_to_train: Value for FLAGS.sample_name_to_train.

  Raises:
    ValueError: If which_parent is neither 'parent1' nor 'parent2'.
  """
  calling_region = ranges.parse_literal('20:10,000,000-10,010,000')
  FLAGS.regions = [ranges.to_literal(calling_region)]
  FLAGS.ref = testdata.CHR20_FASTA
  FLAGS.reads = testdata.HG001_CHR20_BAM
  FLAGS.sample_name = 'child'
  FLAGS.examples = test_utils.test_tmpfile(_sharded('examples.tfrecord'))
  FLAGS.partition_size = 1000
  FLAGS.mode = mode
  if mode == 'training':
    FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
    FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED

  # Exactly one parent is configured; the other parent's flags stay untouched.
  if which_parent not in ('parent1', 'parent2'):
    raise ValueError('Invalid `which_parent` value in test case.')
  if which_parent == 'parent1':
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.sample_name_parent1 = 'parent1'
  else:
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name_parent2 = 'parent2'
  FLAGS.sample_name_to_train = sample_name_to_train

  # This is only a simple test that it runs without errors.
  run_options = make_examples.default_options(add_flags=True)
  make_examples_core.make_examples_runner(run_options)
@parameterized.parameters(
    dict(mode='calling'),
    dict(mode='training'),
)
@flagsaver.flagsaver
def test_make_examples_end2end_vcf_candidate_importer(self, mode):
  """Runs make_examples with the vcf_candidate_importer caller vs golden.

  Args:
    mode: 'calling' or 'training'. Calling mode supplies proposed variants
      for all three samples and reads the '_child' examples file. Any other
      value is treated as training, which supplies truth variants and
      confident regions and reads FLAGS.examples.
  """
  FLAGS.variant_caller = 'vcf_candidate_importer'
  FLAGS.ref = testdata.CHR20_FASTA
  FLAGS.reads = testdata.HG001_CHR20_BAM
  FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
  FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
  FLAGS.sample_name = 'child'
  FLAGS.sample_name_parent1 = 'parent1'
  FLAGS.sample_name_parent2 = 'parent2'
  FLAGS.pileup_image_height_parent = 40
  FLAGS.pileup_image_height_child = 60
  FLAGS.candidates = test_utils.test_tmpfile(
      _sharded(f'vcf_candidate_importer.candidates.{mode}.tfrecord')
  )
  FLAGS.examples = test_utils.test_tmpfile(
      _sharded(f'vcf_candidate_importer.examples.{mode}.tfrecord')
  )
  FLAGS.mode = mode
  FLAGS.regions = '20:10,000,000-10,010,000'

  if mode == 'calling':
    golden_path = _sharded(
        testdata.GOLDEN_VCF_CANDIDATE_IMPORTER_CALLING_EXAMPLES_CHILD
    )
    output_examples_path = test_utils.test_tmpfile(
        _sharded(f'vcf_candidate_importer_child.examples.{mode}.tfrecord')
    )
    FLAGS.proposed_variants_child = testdata.TRUTH_VARIANTS_VCF
    FLAGS.proposed_variants_parent1 = testdata.TRUTH_VARIANTS_VCF
    FLAGS.proposed_variants_parent2 = testdata.TRUTH_VARIANTS_VCF
  else:
    golden_path = _sharded(
        testdata.GOLDEN_VCF_CANDIDATE_IMPORTER_TRAINING_EXAMPLES
    )
    output_examples_path = test_utils.test_tmpfile(
        _sharded(f'vcf_candidate_importer.examples.{mode}.tfrecord')
    )
    FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
    FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED

  run_options = make_examples.default_options(add_flags=True)
  make_examples_core.make_examples_runner(run_options)

  # Verify that the variants in the examples are all good.
  labels_expected = mode == 'training'
  actual_examples = self.verify_examples(
      output_examples_path,
      None,
      run_options,
      verify_labels=labels_expected,
      examples_filename=FLAGS.examples,
  )
  golden_examples = list(tfrecord.read_tfrecords(golden_path))
  self.assertDeepVariantExamplesEqual(actual_examples, golden_examples)
@parameterized.parameters(
    dict(
        max_reads_per_partition=1500,
        expected_len_examples1=88,
        expected_len_examples2=32,
    ),
    dict(
        max_reads_per_partition=8,
        expected_len_examples1=34,
        expected_len_examples2=30,
    ),
)
@flagsaver.flagsaver
def test_make_examples_with_max_reads_for_dynamic_bases_per_region(
    self,
    max_reads_per_partition,
    expected_len_examples1,
    expected_len_examples2,
):
  """Checks example counts before and after limiting dynamic-bases reads.

  Args:
    max_reads_per_partition: Value for FLAGS.max_reads_per_partition.
    expected_len_examples1: Expected example count for the first run, made
      before max_reads_for_dynamic_bases_per_region is set.
    expected_len_examples2: Expected example count for the second run, made
      with FLAGS.max_reads_for_dynamic_bases_per_region = 1.
  """
  calling_region = ranges.parse_literal('20:10,000,000-10,010,000')
  FLAGS.regions = [ranges.to_literal(calling_region)]
  FLAGS.ref = testdata.CHR20_FASTA
  FLAGS.reads = testdata.HG001_CHR20_BAM
  FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
  FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
  FLAGS.sample_name = 'child'
  FLAGS.sample_name_to_train = 'child'
  FLAGS.sample_name_parent1 = 'parent1'
  FLAGS.sample_name_parent2 = 'parent2'
  FLAGS.examples = test_utils.test_tmpfile(_sharded('ex.tfrecord'))
  child_examples_path = test_utils.test_tmpfile(_sharded('ex_child.tfrecord'))
  FLAGS.partition_size = 1000
  FLAGS.mode = 'calling'
  FLAGS.max_reads_per_partition = max_reads_per_partition

  def _run_and_verify():
    """Runs make_examples with the current flags; returns verified examples."""
    run_options = make_examples.default_options(add_flags=True)
    make_examples_core.make_examples_runner(run_options)
    return self.verify_examples(
        child_examples_path,
        calling_region,
        run_options,
        verify_labels=False,
        examples_filename=FLAGS.examples,
    )

  baseline_examples = _run_and_verify()
  self.assertLen(baseline_examples, expected_len_examples1)

  # Now, this is the main part of the test. I want to test the behavior after
  # I set max_reads_for_dynamic_bases_per_region.
  FLAGS.max_reads_for_dynamic_bases_per_region = 1
  limited_examples = _run_and_verify()
  self.assertLen(limited_examples, expected_len_examples2)
def verify_nist_concordance(self, candidates, nist_variants):
  """Checks that the candidates recover almost all of the truth variants.

  Two properties are asserted:
    * There are fewer than 5x (arbitrary) as many candidates as truth
      variants; more would likely point at a major pipeline problem.
    * The fraction of truth variants found among the candidates (according to
      assertVariantIsPresent) is greater than 0.9705.

  Args:
    candidates: Variants produced by our pipeline.
    nist_variants: Truth variants (NIST's Genome in a Bottle callset for
      NA12878) to look for in candidates.
  """
  num_truth = len(nist_variants)
  self.assertLess(len(candidates), 5 * num_truth)
  tp_count = sum(
      1
      for truth_variant in nist_variants
      if self.assertVariantIsPresent(truth_variant, candidates)
  )
  recall = tp_count / num_truth
  self.assertGreater(
      recall,
      0.9705,
      'Recall must be greater than 0.9705. TP={}, Truth variants={}'.format(
          tp_count, num_truth
      ),
  )
def assertDeepVariantExamplesEqual(self, actual, expected):
  """Asserts that actual and expected tf.Examples from DeepVariant are equal.

  Both inputs must have the same number of examples, and each pair of examples
  at the same position must decode (via decode_example) to equal dictionaries.

  Args:
    actual: iterable of tf.Examples from DeepVariant. DeepVariant examples
      that we want to check.
    expected: iterable of tf.Examples. Expected results for actual.
  """
  # Materialize first so that any iterable (not only sequences) is accepted,
  # as the documented contract promises.
  actual = list(actual)
  expected = list(expected)
  self.assertEqual(
      len(actual),
      len(expected),
      msg='Number of actual and expected examples differ.',
  )
  for i, (actual_example, expected_example) in enumerate(
      zip(actual, expected)
  ):
    self.assertEqual(
        decode_example(actual_example),
        decode_example(expected_example),
        msg=f'Example at index {i} does not match the expected example.',
    )
def assertDeepVariantExamplesNotEqual(self, actual, expected):
  """Asserts that actual and expected tf.Examples are not equal.

  The two inputs count as different when their lengths differ or when any
  pair of examples at the same position decodes to different dictionaries.
  Every difference found is logged as a warning.

  Args:
    actual: sequence of tf.Examples from DeepVariant. DeepVariant examples
      that we want to check.
    expected: sequence of tf.Examples. Expected results for actual.
  """
  num_actual = len(actual)
  num_expected = len(expected)
  found_difference = False
  if num_actual != num_expected:
    logging.warning(
        (
            'In assertDeepVariantExamplesNotEqual: '
            'actual(%d) and expected(%d) has different lengths'
        ),
        num_actual,
        num_expected,
    )
    found_difference = True

  # zip() stops at the shorter input, so only the common prefix is compared.
  for i, (actual_example, expected_example) in enumerate(
      zip(actual, expected)
  ):
    if decode_example(actual_example) != decode_example(expected_example):
      logging.warning(
          (
              'assertDeepVariantExamplesNotEqual: '
              'actual example[%d] and expected example[%d] '
              'are different'
          ),
          i,
          i,
      )
      found_difference = True

  self.assertTrue(
      found_difference,
      (
          'assertDeepVariantExamplesNotEqual failed - '
          'actual and expected examples are identical.'
      ),
  )
def assertVariantIsPresent(self, to_find, variants):
  """Returns whether to_find is represented in variants.

  Despite the assert-style name, nothing is asserted; the result is returned.

  Args:
    to_find: Variant to look for.
    variants: iterable of Variants to search.

  Returns:
    True if the first variant with the same (reference_bases, start, end) as
    to_find contains every alternate allele of to_find (it may contain more);
    False if no variant matches or an alternate allele is missing.
  """
  wanted_key = (to_find.reference_bases, to_find.start, to_find.end)
  first_match = next(
      (
          candidate
          for candidate in variants
          if (candidate.reference_bases, candidate.start, candidate.end)
          == wanted_key
      ),
      None,
  )
  if first_match is None:
    return False
  return all(
      alt in first_match.alternate_bases for alt in to_find.alternate_bases
  )
def verify_candidate_positions(
    self, candidate_positions_paths, candidate_positions_golden_set
):
  """Asserts that a candidate positions file matches its golden file.

  Both files are read as raw bytes and interpreted as arrays of int32; the
  two arrays must contain the same items, regardless of order.

  Args:
    candidate_positions_paths: Path of the positions file that was produced.
    candidate_positions_golden_set: Path of the golden positions file.
  """

  def _read_positions(path):
    """Reads the whole file at path as an int32 array."""
    with epath.Path(path).open('rb') as positions_file:
      return np.frombuffer(positions_file.read(), dtype=np.int32)

  golden_positions = _read_positions(candidate_positions_golden_set)
  actual_positions = _read_positions(candidate_positions_paths)
  logging.warning(
      '%d positions created, %d positions in golden file',
      len(actual_positions),
      len(golden_positions),
  )
  self.assertCountEqual(actual_positions, golden_positions)
def verify_variants(self, variants, region, options, is_gvcf):
  """Verifies simple properties of the Variant protos in variants.

  Args:
    variants: iterable of Variant protos to check.
    region: Range the variants are expected to lie in (same reference_name,
      start within [region.start, region.end]). A falsy value skips these
      location checks.
    options: Options of the run; the sample name of sample_options[1] is the
      expected call_set_name of each variant's single call.
    is_gvcf: If true, additionally enforce the expectations for gVCF records
      on the call of each variant.
  """
  for variant in variants:
    if region:
      self.assertEqual(variant.reference_name, region.reference_name)
      self.assertGreaterEqual(variant.start, region.start)
      self.assertLessEqual(variant.start, region.end)
    self.assertNotEqual(variant.reference_bases, '')
    self.assertNotEmpty(variant.alternate_bases)
    self.assertLen(variant.calls, 1)

    only_call = variant_utils.only_call(variant)
    expected_name = options.sample_options[1].variant_caller_options.sample_name
    self.assertEqual(only_call.call_set_name, expected_name)
    if not is_gvcf:
      continue
    # GVCF records should have 0/0 or ./. (un-called) genotypes as they are
    # reference sites, have genotype likelihoods and a GQ value.
    self.assertIn(list(only_call.genotype), [[0, 0], [-1, -1]])
    self.assertLen(only_call.genotype_likelihood, 3)
    self.assertGreaterEqual(variantcall_utils.get_gq(only_call), 0)
def verify_contiguity(self, contiguous_variants, region):
  """Verifies region is fully covered by gvcf records.

  Args:
    contiguous_variants: Non-empty sequence of records (gVCF and variant)
      ordered by position.
    region: The calling interval that the records must cover.
  """
  self.assertNotEmpty(contiguous_variants)
  # The intervals must cover every base, so the first record starts where the
  # interval starts and the last record ends where the interval ends.
  first, last = contiguous_variants[0], contiguous_variants[-1]
  self.assertEqual(region.start, first.start)
  self.assertEqual(region.end, last.end)

  # Once every adjacent pair has been checked we know the GVCF and Variant
  # records together form a gap-free cover of the calling interval.
  for prev, curr in zip(contiguous_variants, contiguous_variants[1:]):
    same_span = (prev.start, prev.end) == (curr.start, curr.end)
    if not same_span:
      if prev.alternate_bases == ['<*>']:
        # A gvcf record is followed immediately at its (exclusive) end.
        expected_start = prev.end
      else:
        # A variant is followed at the very next base: its end may span a
        # larger interval when it is a deletion, and gvcf records are still
        # produced under the deletion.
        expected_start = prev.start + 1
      self.assertEqual(curr.start, expected_start)
    # Records with an identical span are skipped: a multi-allelic variant can
    # turn into several bi-allelic variants at the same site.
def verify_deepvariant_calls(self, dv_calls, options):
  """Verifies simple structural properties of DeepVariantCall objects.

  For every alternate allele of every call, other than the no-alt marker used
  by ref calls, the allele must be a key of the call's allele_support field
  and must be backed by at least min_count_snps supporting reads.

  Args:
    dv_calls: Iterable of DeepVariantCall objects to check.
    options: Options object; sample_options[1].variant_caller_options
      supplies the min_count_snps threshold.
  """
  for dv_call in dv_calls:
    for allele in dv_call.variant.alternate_bases:
      # Ref calls carry only the no-alt marker, which has no support entry.
      if allele != vcf_constants.NO_ALT_ALLELE:
        self.assertIn(allele, list(dv_call.allele_support))
        n_supporting_reads = len(dv_call.allele_support[allele].read_names)
        min_required = (
            options.sample_options[1].variant_caller_options.min_count_snps
        )
        self.assertGreaterEqual(n_supporting_reads, min_required)
def sanity_check_example_info_json(self, example, examples_filename, task_id):
  """Checks `example_info.json` w/ examples_filename has the right info.

  The JSON file must contain a 'shape' entry equal to the image shape of
  `example`, and a 'channels' entry with one element per channel, i.e. as
  many elements as the third dimension of the shape.

  Args:
    example: A tf.Example whose image shape is the expected 'shape' value.
    examples_filename: Examples path from which the name of the
      example_info.json file is derived.
    task_id: Task id used, together with examples_filename, to derive the
      name of the example_info.json file.
  """
  example_info_json = dv_utils.get_example_info_json_filename(
      examples_filename, task_id
  )
  # Open the file in a `with` block so the handle is closed as soon as the
  # JSON has been parsed, rather than being left open.
  with gfile.GFile(example_info_json, 'r') as json_file:
    example_info = json.load(json_file)
  self.assertIn('shape', example_info)
  self.assertEqual(
      example_info['shape'], dv_utils.example_image_shape(example)
  )
  self.assertIn('channels', example_info)
  self.assertLen(example_info['channels'], example_info['shape'][2])
def verify_examples(
    self,
    path_to_output_examples,
    region,
    options,
    verify_labels,
    examples_filename=None,
):
  """Runs simple structural checks on the tf.Examples stored in a file.

  Args:
    path_to_output_examples: Path of the TFRecord file of tf.Examples to read.
      In DeepTrio this can point to a file carrying a sample suffix (such as
      _child).
    region: Region passed through to verify_variants.
    options: Options passed through to verify_variants; options.task_id is
      also used to locate the example_info.json file.
    verify_labels: If true, every example must also have a 'label' feature.
    examples_filename: The original (un-suffixed) examples path used to locate
      the example_info.json file. If None, path_to_output_examples is used
      directly.

  Returns:
    The list of tf.Examples read from path_to_output_examples.
  """
  required_keys = [
      'variant/encoded',
      'locus',
      'image/encoded',
      'alt_allele_indices/encoded',
  ]
  if verify_labels:
    required_keys.append('label')

  examples = list(tfrecord.read_tfrecords(path_to_output_examples))
  for example in examples:
    feature_map = example.features.feature
    for key in required_keys:
      self.assertIn(key, feature_map)
    self.assertNotEmpty(dv_utils.example_alt_alleles_indices(example))

  # Check that the variants in the examples are good.
  decoded_variants = [dv_utils.example_variant(example) for example in examples]
  self.verify_variants(decoded_variants, region, options, is_gvcf=False)

  # The example_info.json check needs at least one example to compare with.
  if not examples:
    return examples

  info_source = (
      path_to_output_examples
      if examples_filename is None
      else examples_filename
  )
  self.sanity_check_example_info_json(
      examples[0], info_source, options.task_id
  )
  return examples
class MakeExamplesUnitTest(parameterized.TestCase):
  """Unit tests for how make_examples turns flags into options."""

  def _set_trio_input_flags(self):
    """Points FLAGS at the chr20 trio test inputs and names the samples.

    Only call this from tests decorated with `flagsaver.flagsaver`, so that
    the assignments made here are undone when the test finishes.
    """
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'

  def _set_training_flags(self):
    """Sets the flags shared by the training-mode tests of this class.

    Only call this from tests decorated with `flagsaver.flagsaver`.
    """
    self._set_trio_input_flags()
    FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
    FLAGS.confident_regions = testdata.CONFIDENT_REGIONS_BED
    FLAGS.mode = 'training'
    FLAGS.examples = ''

  def test_read_write_run_info(self):
    """Reading the golden run info and writing it back is lossless."""

    def _read_lines(path):
      with open(path) as fin:
        return fin.readlines()

    golden_actual = make_examples_core.read_make_examples_run_info(
        testdata.GOLDEN_MAKE_EXAMPLES_RUN_INFO
    )
    # We don't really want to inject too much knowledge about the golden right
    # here, so we only check that (a) the run_info_filename is a non-empty
    # string and (b) the number of candidate sites in the labeling metrics
    # field equals the expected number of golden training examples, which
    # shows that the reader filled in the value.
    self.assertNotEmpty(golden_actual.options.run_info_filename)
    self.assertEqual(
        golden_actual.labeling_metrics.n_candidate_variant_sites,
        testdata.N_GOLDEN_TRAINING_EXAMPLES,
    )

    # Check that reading + writing the data produces the same lines:
    tmp_output = test_utils.test_tmpfile('written_run_info.pbtxt')
    make_examples_core.write_make_examples_run_info(golden_actual, tmp_output)
    self.assertEqual(
        _read_lines(testdata.GOLDEN_MAKE_EXAMPLES_RUN_INFO),
        _read_lines(tmp_output),
    )

  @flagsaver.flagsaver
  def test_keep_duplicates(self):
    """--keep_duplicates is copied into the read requirements."""
    FLAGS.keep_duplicates = True
    self._set_training_flags()
    options = make_examples.default_options(add_flags=True)
    self.assertTrue(options.pic_options.read_requirements.keep_duplicates)

  @flagsaver.flagsaver
  def test_keep_supplementary_alignments(self):
    """--keep_supplementary_alignments is copied into the read requirements."""
    FLAGS.keep_supplementary_alignments = True
    self._set_training_flags()
    options = make_examples.default_options(add_flags=True)
    self.assertTrue(
        options.pic_options.read_requirements.keep_supplementary_alignments
    )

  @flagsaver.flagsaver
  def test_keep_secondary_alignments(self):
    """--keep_secondary_alignments is copied into the read requirements."""
    FLAGS.keep_secondary_alignments = True
    self._set_training_flags()
    options = make_examples.default_options(add_flags=True)
    self.assertTrue(
        options.pic_options.read_requirements.keep_secondary_alignments
    )

  @flagsaver.flagsaver
  def test_min_base_quality(self):
    """--min_base_quality is copied into the read requirements."""
    FLAGS.min_base_quality = 5
    self._set_training_flags()
    options = make_examples.default_options(add_flags=True)
    self.assertEqual(options.pic_options.read_requirements.min_base_quality, 5)

  @flagsaver.flagsaver
  def test_min_mapping_quality(self):
    """--min_mapping_quality is copied into the read requirements."""
    FLAGS.min_mapping_quality = 15
    self._set_training_flags()
    options = make_examples.default_options(add_flags=True)
    self.assertEqual(
        options.pic_options.read_requirements.min_mapping_quality, 15
    )

  @flagsaver.flagsaver
  def test_default_options_with_training_random_emit_ref_sites(self):
    """--training_random_emit_ref_sites reaches the variant caller options."""
    self._set_training_flags()
    FLAGS.training_random_emit_ref_sites = 0.3
    options = make_examples.default_options(add_flags=True)
    self.assertAlmostEqual(
        options.sample_options[
            1
        ].variant_caller_options.fraction_reference_sites_to_emit,
        0.3,
    )

  @flagsaver.flagsaver
  def test_default_options_without_training_random_emit_ref_sites(self):
    """fraction_reference_sites_to_emit stays 0 when the flag is not set."""
    self._set_training_flags()
    options = make_examples.default_options(add_flags=True)
    # In proto3, there is no way to check presence of a scalar field. As an
    # approximation, we directly check that the value should be exactly 0.
    self.assertEqual(
        options.sample_options[
            1
        ].variant_caller_options.fraction_reference_sites_to_emit,
        0.0,
    )

  @flagsaver.flagsaver
  def test_confident_regions(self):
    """Confident regions read via the options match the BED file contents."""
    self._set_training_flags()
    options = make_examples.default_options(add_flags=True)
    confident_regions = make_examples_core.read_confident_regions(options)

    # Our expected intervals, inlined from CONFIDENT_REGIONS_BED.
    expected = _from_literals_list([
        '20:10000847-10002407',
        '20:10002521-10004171',
        '20:10004274-10004964',
        '20:10004995-10006386',
        '20:10006410-10007800',
        '20:10007825-10008018',
        '20:10008044-10008079',
        '20:10008101-10008707',
        '20:10008809-10008897',
        '20:10009003-10009791',
        '20:10009934-10010531',
    ])
    # Our confident regions should be exactly those found in the BED file.
    self.assertCountEqual(expected, list(confident_regions))

  # Each settings dict maps a flag name to (flag value, expected option value).
  @parameterized.parameters(
      ({'examples': ('foo', 'foo')},),
      ({'examples': ('foo', 'foo'), 'gvcf': ('bar', 'bar')},),
      ({'examples': ('foo@10', 'foo-00000-of-00010')},),
      ({'task': (0, 0), 'examples': ('foo@10', 'foo-00000-of-00010')},),
      ({'task': (1, 1), 'examples': ('foo@10', 'foo-00001-of-00010')},),
      (
          {
              'task': (1, 1),
              'examples': ('foo@10', 'foo-00001-of-00010'),
              'gvcf': ('bar@10', 'bar-00001-of-00010'),
          },
      ),
      (
          {
              'task': (1, 1),
              'examples': ('foo@10', 'foo-00001-of-00010'),
              'gvcf': ('bar@10', 'bar-00001-of-00010'),
              'candidates': ('baz@10', 'baz-00001-of-00010'),
          },
      ),
  )
  @flagsaver.flagsaver
  def test_sharded_outputs1(self, settings):
    """Sharded output specs are resolved to the filename of the task."""
    # Set all of the requested flag values.
    for name, (flag_val, _) in settings.items():
      setattr(FLAGS, name, flag_val)
    FLAGS.mode = 'training'
    FLAGS.reads = ''
    FLAGS.ref = ''
    options = make_examples.default_options(add_flags=True)

    # Check all of the flags. An output that was not requested must be ''.
    for name, option_val in [
        ('examples', options.examples_filename),
        ('candidates', options.candidates_filename),
        ('gvcf', options.gvcf_filename),
    ]:
      expected = settings[name][1] if name in settings else ''
      self.assertEqual(expected, option_val)

  def test_catches_bad_argv(self):
    """main() logs an error and exits when given positional arguments."""
    with (
        mock.patch.object(logging, 'error') as mock_logging,
        mock.patch.object(sys, 'exit') as mock_exit,
    ):
      make_examples.main(['make_examples.py', 'extra_arg'])
    mock_logging.assert_called_once_with(
        'Command line parsing failure: make_examples does not accept '
        'positional arguments but some are present on the command line: '
        "\"['make_examples.py', 'extra_arg']\"."
    )
    mock_exit.assert_called_once_with(errno.ENOENT)

  @flagsaver.flagsaver
  def test_catches_bad_flags(self):
    """main() logs an error and exits when confident_regions is missing."""
    # Set all of the requested flag values.
    region = ranges.parse_literal('20:10,000,000-10,010,000')
    self._set_trio_input_flags()
    FLAGS.candidates = test_utils.test_tmpfile('vsc.tfrecord')
    FLAGS.examples = test_utils.test_tmpfile('examples.tfrecord')
    FLAGS.regions = [ranges.to_literal(region)]
    FLAGS.partition_size = 1000
    FLAGS.mode = 'training'
    FLAGS.truth_variants = testdata.TRUTH_VARIANTS_VCF
    # This is the bad flag.
    FLAGS.confident_regions = ''

    with (
        mock.patch.object(logging, 'error') as mock_logging,
        mock.patch.object(sys, 'exit') as mock_exit,
    ):
      make_examples.main(['make_examples.py'])
    mock_logging.assert_called_once_with(
        'confident_regions is required when in training mode.'
    )
    mock_exit.assert_called_once_with(errno.ENOENT)

  @flagsaver.flagsaver
  def test_regions_and_exclude_regions_flags_with_trio_options(self):
    """--exclude_regions is subtracted from --regions."""
    FLAGS.mode = 'calling'
    self._set_trio_input_flags()
    FLAGS.regions = '20:10,000,000-11,000,000'
    FLAGS.examples = 'examples.tfrecord'
    FLAGS.exclude_regions = '20:10,010,000-10,100,000'

    options = make_examples.default_options(add_flags=True)
    _, regions_from_options = (
        make_examples_core.processing_regions_from_options(options)
    )
    self.assertCountEqual(
        list(ranges.RangeSet(regions_from_options)),
        _from_literals_list(
            ['20:10,000,000-10,009,999', '20:10,100,001-11,000,000']
        ),
    )

  @flagsaver.flagsaver
  def test_incorrect_empty_regions_with_trio_options(self):
    """A region on an unknown contig leaves nothing to call and raises."""
    FLAGS.mode = 'calling'
    self._set_trio_input_flags()
    # Deliberately incorrect contig name.
    FLAGS.regions = 'xxx20:10,000,000-11,000,000'
    FLAGS.examples = 'examples.tfrecord'

    options = make_examples.default_options(add_flags=True)
    with self.assertRaisesRegex(ValueError, 'The regions to call is empty.'):
      make_examples_core.processing_regions_from_options(options)
class RegionProcessorTest(parameterized.TestCase):
  """Tests for make_examples_core.RegionProcessor built from trio options."""

  def setUp(self):
    """Builds a RegionProcessor with a mocked-out `_initialize`."""
    super().setUp()
    self.region = ranges.parse_literal('20:10,000,000-10,000,100')
    # NOTE(review): this assignment is not wrapped in flagsaver, so the value
    # persists after the test finishes. Presumably default_options() needs
    # FLAGS.reads to be set even with add_flags=False -- TODO confirm.
    FLAGS.reads = ''
    self.options = make_examples.default_options(add_flags=False)
    self.options.reference_filename = testdata.CHR20_FASTA
    self.options.truth_variants_filename = testdata.TRUTH_VARIANTS_VCF
    self.options.mode = deepvariant_pb2.MakeExamplesOptions.TRAINING
    self.ref_reader = fasta.IndexedFastaReader(self.options.reference_filename)
    self.default_shape = [5, 5, 7]
    self.processor = make_examples_core.RegionProcessor(self.options)
    self.mock_init = self.add_mock('_initialize')
    for sample in self.processor.samples:
      sample.in_memory_sam_reader = mock.Mock()

  def add_mock(self, name, retval='dontadd', side_effect='dontadd'):
    """Patches attribute `name` of self.processor for the current test.

    Args:
      name: Name of the attribute of self.processor to patch.
      retval: Return value for the mock. The default string 'dontadd' is a
        sentinel meaning "leave the mock's return value alone".
      side_effect: Side effect for the mock. The default string 'dontadd' is a
        sentinel meaning "leave the mock's side effect alone".

    Returns:
      The started mock. It is stopped automatically on test cleanup.
    """
    patcher = mock.patch.object(self.processor, name, autospec=True)
    self.addCleanup(patcher.stop)
    mocked = patcher.start()
    if retval != 'dontadd':
      mocked.return_value = retval
    if side_effect != 'dontadd':
      mocked.side_effect = side_effect
    return mocked

  @parameterized.parameters([
      deepvariant_pb2.MakeExamplesOptions.TRAINING,
      deepvariant_pb2.MakeExamplesOptions.CALLING,
  ])
  def test_process_keeps_ordering_of_candidates_and_examples(self, mode):
    """process() returns the candidates in the order they were produced."""
    self.processor.options.mode = mode

    r1, r2 = mock.Mock(), mock.Mock()
    c1, c2 = mock.Mock(), mock.Mock()
    self.add_mock('region_reads_norealign', retval=[r1, r2])
    self.add_mock('candidates_in_region', retval=({'child': [c1, c2]}, {}))
    candidates_dict, gvcfs_dict, runtimes = self.processor.process(self.region)
    self.assertEqual({'child': [c1, c2]}, candidates_dict)
    self.assertEqual({}, gvcfs_dict)
    self.assertIsInstance(runtimes, dict)

    in_memory_sam_reader = self.processor.samples[1].in_memory_sam_reader
    in_memory_sam_reader.replace_reads.assert_called_once_with([r1, r2])

  def test_create_pileup_examples_handles_none(self):
    """No examples are created when no pileup images are returned."""
    self.processor.pic = mock.Mock()
    dv_call = mock.Mock()
    self.processor.pic.create_pileup_images.return_value = None
    self.assertEqual(
        [], self.processor.create_pileup_examples(dv_call, 'child')
    )
    self.processor.pic.create_pileup_images.assert_called_once()

  def test_create_pileup_examples(self):
    """One example is created per (alt alleles, image) pair, in order."""
    self.processor.pic = mock.Mock()
    self.processor.pic.get_channels.return_value = None
    self.add_mock(
        '_encode_tensor',
        side_effect=[
            (b'tensor1', self.default_shape),
            (b'tensor2', self.default_shape),
        ],
    )
    dv_call = mock.Mock()
    dv_call.variant = test_utils.make_variant(start=10, alleles=['A', 'C', 'G'])
    alt1, alt2 = ['C'], ['G']
    self.processor.pic.create_pileup_images.return_value = [
        (alt1, b'tensor1'),
        (alt2, b'tensor2'),
    ]

    actual = self.processor.create_pileup_examples(dv_call, 'child')

    self.processor.pic.create_pileup_images.assert_called_once()

    self.assertLen(actual, 2)
    for ex, (alt, img) in zip(actual, [(alt1, b'tensor1'), (alt2, b'tensor2')]):
      self.assertEqual(dv_utils.example_alt_alleles(ex), alt)
      self.assertEqual(dv_utils.example_variant(ex), dv_call.variant)
      self.assertEqual(dv_utils.example_encoded_image(ex), img)
      self.assertEqual(dv_utils.example_image_shape(ex), self.default_shape)

  @parameterized.parameters(
      # Test that a het variant gets a label value of 1 assigned to the example.
      dict(
          label=variant_labeler.VariantLabel(
              is_confident=True,
              variant=test_utils.make_variant(start=10, alleles=['A', 'C']),
              genotype=(0, 1),
          ),
          denovo_label=0,
          expected_label_value=1,
          denovo_enabled=True,
      ),
      # Test that a reference variant gets a label value of 0 in the example.
      dict(
          label=variant_labeler.VariantLabel(
              is_confident=True,
              variant=test_utils.make_variant(start=10, alleles=['A', '.']),
              genotype=(0, 0),
          ),
          denovo_label=1,
          expected_label_value=0,
          denovo_enabled=True,
      ),
      # Same reference variant with de novo labeling disabled.
      dict(
          label=variant_labeler.VariantLabel(
              is_confident=True,
              variant=test_utils.make_variant(start=10, alleles=['A', '.']),
              genotype=(0, 0),
          ),
          denovo_label=None,
          expected_label_value=0,
          denovo_enabled=False,
      ),
  )
  def test_add_label_to_example(
      self, label, denovo_label, expected_label_value, denovo_enabled
  ):
    """add_label_to_example sets the label, de novo label and genotype."""
    example = self._example_for_variant(label.variant)
    labeled = copy.deepcopy(example)
    actual = self.processor.add_label_to_example(
        labeled, label, denovo_label, denovo_enabled
    )

    # The add_label_to_example command modifies labeled and returns it.
    self.assertIs(actual, labeled)

    # Check that all keys from example are present in labeled.
    for key, value in example.features.feature.items():
      if key != 'variant/encoded':  # Special case tested below.
        self.assertEqual(value, labeled.features.feature[key])

    # The genotype of our example_variant should be set to the true genotype
    # according to our label.
    self.assertEqual(expected_label_value, dv_utils.example_label(labeled))
    if denovo_enabled:
      self.assertEqual(denovo_label, dv_utils.example_denovo_label(labeled))
    else:
      self.assertIsNone(dv_utils.example_denovo_label(labeled))
    labeled_variant = dv_utils.example_variant(labeled)
    call = variant_utils.only_call(labeled_variant)
    self.assertEqual(tuple(call.genotype), label.genotype)

    # The original variant and labeled_variant from our tf.Example should be
    # equal except for the genotype field, since this is set by
    # add_label_to_example. Clear the genotype on both before comparing.
    label.variant.calls[0].genotype[:] = []
    call.genotype[:] = []
    self.assertEqual(label.variant, labeled_variant)

  def test_label_variant_raises_for_non_confident_variant(self):
    """Labeling an example with a non-confident label raises ValueError."""
    label = variant_labeler.VariantLabel(
        is_confident=False,
        variant=test_utils.make_variant(start=10, alleles=['A', 'C']),
        genotype=(0, 1),
    )
    example = self._example_for_variant(label.variant)
    with self.assertRaisesRegex(
        ValueError, 'Cannot add a non-confident label to an example'
    ):
      self.processor.add_label_to_example(
          example, label, denovo_label=0, denovo_enabled=False
      )

  def _example_for_variant(self, variant):
    """Returns an example for variant with a dummy b'foo' image."""
    return dv_utils_using_clif.make_example(
        variant, list(variant.alternate_bases), b'foo', self.default_shape
    )

  # The flags set below must be restored afterwards, exactly as in the other
  # flag-setting tests; otherwise use_original_quality_scores=True and
  # parse_sam_aux_fields=False would leak into whichever test runs next.
  @flagsaver.flagsaver
  def test_use_original_quality_scores_without_parse_sam_aux_fields(self):
    """--use_original_quality_scores requires --parse_sam_aux_fields."""
    FLAGS.mode = 'calling'
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'
    FLAGS.examples = 'examples.tfrecord'
    FLAGS.use_original_quality_scores = True
    FLAGS.parse_sam_aux_fields = False

    with self.assertRaisesRegex(
        Exception,
        (
            'If --use_original_quality_scores is set then '
            '--parse_sam_aux_fields must be set too.'
        ),
    ):
      make_examples.default_options(add_flags=True)

  # Every case below is expected to be rejected as an invalid total height.
  @parameterized.parameters(
      dict(height_parent=10, height_child=9),
      dict(height_parent=9, height_child=10),
      dict(height_parent=150, height_child=101),
      dict(height_parent=101, height_child=170),
  )
  @flagsaver.flagsaver
  def test_image_heights(self, height_parent, height_child):
    """Out-of-range pileup image heights are rejected by option validation."""
    FLAGS.pileup_image_height_parent = height_parent
    FLAGS.pileup_image_height_child = height_child
    FLAGS.mode = 'calling'
    FLAGS.ref = testdata.CHR20_FASTA
    FLAGS.reads = testdata.HG001_CHR20_BAM
    FLAGS.reads_parent1 = testdata.NA12891_CHR20_BAM
    FLAGS.reads_parent2 = testdata.NA12892_CHR20_BAM
    FLAGS.sample_name = 'child'
    FLAGS.sample_name_to_train = 'child'
    FLAGS.sample_name_parent1 = 'parent1'
    FLAGS.sample_name_parent2 = 'parent2'
    FLAGS.examples = 'examples.tfrecord'

    options = make_examples.default_options(add_flags=True)
    with self.assertRaisesRegex(
        Exception, 'Total pileup image heights must be between 75-362.'
    ):
      make_examples.check_options_are_valid(options)

  @parameterized.parameters(
      [
          dict(window_width=221),
          dict(window_width=1001),
      ],
  )
  def test_align_to_all_haplotypes(self, window_width):
    """align_to_all_haplotypes keys its outputs by alt allele."""
    # align_to_all_haplotypes() will pull from the reference, so choose a
    # real variant.
    region = ranges.parse_literal('20:10,046,000-10,046,400')
    nist_reader = vcf.VcfReader(testdata.TRUTH_VARIANTS_VCF)
    nist_variants = list(nist_reader.query(region))
    # We picked this region to have exactly one known variant:
    # reference_bases: "AAGAAAGAAAG"
    # alternate_bases: "A", a deletion of 10 bp
    # start: 10046177
    # end: 10046188
    # reference_name: "chr20"
    variant = nist_variants[0]

    self.processor.pic = mock.Mock()
    self.processor.pic.width = window_width
    self.processor.pic.half_width = (window_width - 1) // 2

    self.processor.realigner = mock.Mock()
    # Using a real ref_reader to test that the reference allele matches
    # between the variant and the reference at the variant's coordinates.
    self.processor.realigner.ref_reader = self.ref_reader

    read = test_utils.make_read(
        'A' * 101, start=10046100, cigar='101M', quals=[30] * 101
    )

    self.processor.realigner.align_to_haplotype = mock.Mock()
    alt_info = self.processor.align_to_all_haplotypes(variant, [read])
    hap_alignments = alt_info['alt_alignments']
    hap_sequences = alt_info['alt_sequences']
    # Both outputs are keyed by alt allele.
    self.assertCountEqual(hap_alignments.keys(), ['A'])
    self.assertCountEqual(hap_sequences.keys(), ['A'])

    # Sequence must be the length of the window.
    self.assertLen(hap_sequences['A'], self.processor.pic.width)

    # align_to_haplotype should be called once for each alt (1 alt here).
    self.processor.realigner.align_to_haplotype.assert_called_once()

    # If variant reference_bases are wrong, it should raise a ValueError.
    variant.reference_bases = 'G'
    with self.assertRaisesRegex(
        ValueError, 'does not match the bases in the reference'
    ):
      self.processor.align_to_all_haplotypes(variant, [read])
# Script entry point: hand control to absltest's test runner when this file is
# executed directly rather than imported.
if __name__ == '__main__':
  absltest.main()