# Copyright 2017 Google LLC.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""Code for calling variants with a trained DeepVariant model."""
import os
import time
from absl import flags
from absl import logging
import numpy as np
import tensorflow as tf
from tensorflow import estimator as tf_estimator
from deepvariant import data_providers
from deepvariant import dv_utils
from deepvariant import logging_level
from deepvariant import modeling
from deepvariant.protos import deepvariant_pb2
from google.protobuf import text_format
from third_party.nucleus.io import sharded_file_utils
from third_party.nucleus.io import tfrecord
from third_party.nucleus.protos import variants_pb2
from third_party.nucleus.util import errors
from third_party.nucleus.util import proto_utils
from third_party.nucleus.util import variant_utils
tf.compat.v1.disable_eager_execution()
_ALLOW_EXECUTION_HARDWARE = [
'auto', # Default, no validation.
'cpu', # Don't use accelerators, even if available.
'accelerator', # Must be hardware acceleration or an error will be raised.
]
# The number of digits past the decimal point that genotype likelihoods are
# rounded to, for numerical stability.
_GL_PRECISION = 10
# This number is estimated by the following logic:
# For a sample with 10,000,000 examples, if we log every 50,000 examples,
# there will be 200 lines per sample.
_LOG_EVERY_N = 50000
FLAGS = flags.FLAGS
flags.DEFINE_string(
'examples',
None,
(
'Required. tf.Example protos containing DeepVariant candidate variants'
' in TFRecord format, as emitted by make_examples. Can be a'
' comma-separated list of files, and the file names can contain'
' wildcard characters.'
),
)
flags.DEFINE_string(
'outfile',
None,
(
'Required. Destination path where we will write output candidate'
' variants with additional likelihood information in TFRecord format of'
' CallVariantsOutput protos.'
),
)
flags.DEFINE_string(
'checkpoint',
None,
(
'Required. Path to the TensorFlow model checkpoint to use to evaluate '
'candidate variant calls.'
),
)
flags.DEFINE_integer(
'batch_size',
512,
(
'Number of candidate variant tensors to batch together during'
' inference. Larger batches use more memory but are more computational'
' efficient.'
),
)
flags.DEFINE_integer(
'max_batches', None, 'Max. batches to evaluate. Defaults to all.'
)
flags.DEFINE_integer(
'num_readers', 8, 'Number of parallel readers to create for examples.'
)
flags.DEFINE_string(
'model_name',
'inception_v3',
'The name of the model architecture of --checkpoint.',
)
flags.DEFINE_boolean(
'include_debug_info',
False,
'If true, include extra debug info in the output.',
)
flags.DEFINE_boolean(
'debugging_true_label_mode',
False,
(
        'If true, read the true labels from the examples and add them to the '
        'output. Note that the program will crash if the input '
        'examples do not have the label field. '
        'When true, this also populates the full debug info output, as if '
        '--include_debug_info were set to true.'
),
)
flags.DEFINE_string(
'execution_hardware',
'auto',
(
        'When in cpu mode, call_variants will not place any ops on the GPU,'
        ' even if one is available. In accelerator mode, call_variants'
        ' validates that at least some hardware accelerator (GPU/TPU) is'
        ' available for use. This option is primarily for QA purposes, to'
        ' allow users to validate that their accelerator environment is'
        ' correctly configured. In auto mode, the default, op placement is'
        ' left entirely up to TensorFlow. In tpu mode, a TPU is used and'
        ' required.'
),
)
flags.DEFINE_string(
'config_string',
None,
(
'String representation of a tf.ConfigProto message, with'
' comma-separated key: value pairs, such as "allow_soft_placement:'
' True". The value can itself be another message, such as "gpu_options:'
' {per_process_gpu_memory_fraction: 0.5}".'
),
)
# Cloud TPU Cluster Resolvers
flags.DEFINE_string(
'gcp_project',
None,
(
'Project name for the Cloud TPU-enabled project. If not specified, we '
'will attempt to automatically detect the GCE project from metadata.'
),
)
flags.DEFINE_string(
'tpu_zone',
None,
(
        'GCE zone where the Cloud TPU is located. If not specified, we '
        'will attempt to automatically detect the zone from metadata.'
),
)
flags.DEFINE_string(
'tpu_name',
None,
( # pylint: disable=line-too-long
'Name of the Cloud TPU for Cluster Resolvers. You must specify either '
'this flag or --primary. An empty value corresponds to no Cloud TPU.'
' See '
'https://www.tensorflow.org/api_docs/python/tf/distribute/cluster_resolver/TPUClusterResolver'
),
)
flags.DEFINE_string(
'primary',
None,
(
'GRPC URL of the primary (e.g. grpc://ip.address.of.tpu:8470). You '
'must specify either this flag or --tpu_name.'
),
)
flags.DEFINE_boolean('use_tpu', False, 'Use tpu if available.')
flags.DEFINE_boolean('use_openvino', False, 'Use Intel OpenVINO as backend.')
flags.DEFINE_string(
'openvino_model_dir',
'',
'If set, use this directory to save the temporary model file for OpenVINO.',
)
flags.DEFINE_string(
'kmp_blocktime',
'0',
(
'Value to set the KMP_BLOCKTIME environment variable to for efficient'
' MKL inference. See'
' https://www.tensorflow.org/performance/performance_guide for more'
' information. The default value is 0, which provides the best'
' performance in our tests. Set this flag to "" to not set the'
' variable.'
),
)
class ExecutionHardwareError(Exception):
  """Raised when the requested execution hardware is not available."""
def prepare_inputs(source_path, use_tpu=False, num_readers=None):
"""Return a tf.data input_fn from the source_path.
Args:
source_path: Path to a TFRecord file containing deepvariant tf.Example
protos.
use_tpu: boolean. Use the tpu code path.
num_readers: int > 0 or None. Number of parallel readers to use to read
examples from source_path. If None, uses FLAGS.num_readers instead.
Returns:
A tf input_fn yielding batches of image, encoded_variant,
encoded_alt_allele_indices.
    The image is a [batch_size, height, width, channel] tensor. The
    encoded_variant is a tf.string or TPU-encoded tensor containing a
    serialized Variant proto describing the variant call associated with the
    image. The encoded_alt_allele_indices is a tf.string or TPU-encoded
    tensor containing a serialized CallVariantsOutput.AltAlleleIndices proto
    with the alternate allele indices used as "alt" when constructing the
    image.
"""
if not num_readers:
num_readers = FLAGS.num_readers
return data_providers.get_input_fn_from_filespec(
input_file_spec=source_path,
mode=tf_estimator.ModeKeys.PREDICT,
use_tpu=use_tpu,
input_read_threads=num_readers,
debugging_true_label_mode=FLAGS.debugging_true_label_mode,
)
def round_gls(gls, precision=None):
"""Returns genotype likelihoods rounded to the desired precision level.
Args:
gls: A list of floats. The input genotype likelihoods at any precision.
precision: Positive int. The number of places past the decimal point to
round to. If None, no rounding is performed.
Returns:
A list of floats rounded to the desired precision.
Raises:
ValueError: The input gls do not sum to nearly 1.
"""
if abs(sum(gls) - 1) > 1e-6:
raise ValueError(
'Invalid genotype likelihoods do not sum to one: sum({}) = {}'.format(
gls, sum(gls)
)
)
if precision is None:
return gls
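  # Find the smallest likelihood; the rounding residue is absorbed there below
  # so that the rounded values still sum to exactly 1.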
min_ix = 0
min_gl = gls[0]
for ix, gl in enumerate(gls):
if gl < min_gl:
min_gl = gl
min_ix = ix
rounded_gls = [round(gl, precision) for gl in gls]
rounded_gls[min_ix] = max(
0.0,
round(
1 - sum(rounded_gls[:min_ix] + rounded_gls[min_ix + 1 :]), precision
),
)
return rounded_gls
def write_variant_call(writer, prediction, use_tpu):
"""Write the variant call based on prediction.
Args:
    writer: An object with a write() function that will be called for each
      encoded_variant and its genotype likelihoods.
    prediction: A [3] tensor of floats. These are the predicted genotype
      likelihoods (p00, p0x, pxx) for some alt allele x, in the same order as
      encoded_variants.
    use_tpu: bool. Decode the TPU-specific encoding of prediction.
Returns:
The return status from writer.
"""
encoded_variant = prediction['variant']
if use_tpu:
encoded_variant = dv_utils.int_tensor_to_string(encoded_variant)
encoded_alt_allele_indices = prediction['alt_allele_indices']
if use_tpu:
encoded_alt_allele_indices = dv_utils.int_tensor_to_string(
encoded_alt_allele_indices
)
rounded_gls = round_gls(prediction['probabilities'], precision=_GL_PRECISION)
# Write it out.
true_labels = prediction['label'] if FLAGS.debugging_true_label_mode else None
cvo = _create_cvo_proto(
encoded_variant,
rounded_gls,
encoded_alt_allele_indices,
true_labels,
logits=prediction.get('logits'),
prelogits=prediction.get('prelogits'),
)
return writer.write(cvo)
def _create_cvo_proto(
encoded_variant,
gls,
encoded_alt_allele_indices,
true_labels=None,
logits=None,
prelogits=None,
):
"""Returns a CallVariantsOutput proto from the relevant input information."""
variant = variants_pb2.Variant.FromString(encoded_variant)
alt_allele_indices = (
deepvariant_pb2.CallVariantsOutput.AltAlleleIndices.FromString(
encoded_alt_allele_indices
)
)
debug_info = None
if FLAGS.include_debug_info or FLAGS.debugging_true_label_mode:
if prelogits is not None:
assert prelogits.shape == (1, 1, 2048)
prelogits = prelogits[0][0]
debug_info = deepvariant_pb2.CallVariantsOutput.DebugInfo(
has_insertion=variant_utils.has_insertion(variant),
has_deletion=variant_utils.has_deletion(variant),
is_snp=variant_utils.is_snp(variant),
predicted_label=np.argmax(gls),
true_label=true_labels,
logits=logits,
prelogits=prelogits,
)
call_variants_output = deepvariant_pb2.CallVariantsOutput(
variant=variant,
alt_allele_indices=alt_allele_indices,
genotype_probabilities=gls,
debug_info=debug_info,
)
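  # A minimal sketch of reading such records back (the path is hypothetical;
  # the reader and proto names follow the imports above):
  #   for cvo in tfrecord.read_tfrecords(
  #       '/tmp/output.tfrecord.gz', proto=deepvariant_pb2.CallVariantsOutput):
  #     print(cvo.variant.reference_name, list(cvo.genotype_probabilities))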
return call_variants_output
def call_variants(
examples_filename,
checkpoint_path,
model,
output_file,
execution_hardware='auto',
batch_size=16,
max_batches=None,
use_tpu=False,
primary='',
):
"""Main driver of call_variants."""
if FLAGS.kmp_blocktime:
os.environ['KMP_BLOCKTIME'] = FLAGS.kmp_blocktime
logging.vlog(
3, 'Set KMP_BLOCKTIME to {}'.format(os.environ['KMP_BLOCKTIME'])
)
# Read a single TFExample to make sure we're not loading an older version.
first_example = dv_utils.get_one_example_from_examples_path(examples_filename)
if first_example is None:
logging.warning(
'Unable to read any records from %s. Output will contain zero records.',
examples_filename,
)
tfrecord.write_tfrecords([], output_file)
return
example_info_json = dv_utils.get_example_info_json_filename(
examples_filename, 0
)
example_shape, example_channels_enum = (
dv_utils.get_shape_and_channels_from_json(example_info_json)
)
  # Check that the checkpoint was trained on examples of the same shape.
if checkpoint_path is not None and example_shape is not None:
reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
shape_map_for_layers = reader.get_variable_to_shape_map()
first_layer = 'InceptionV3/Conv2d_1a_3x3/weights'
# For a shape map of [3, 3, 6, 32] for the Conv2d_1a_3x3 layer, the 6
# is the number of channels.
num_channels_in_checkpoint_model = shape_map_for_layers[first_layer][2]
if num_channels_in_checkpoint_model != example_shape[2]:
raise ValueError(
'The number of channels in examples and checkpoint '
'should match, but the checkpoint has {} channels while '
'the examples have {}.'.format(
num_channels_in_checkpoint_model, example_shape[2]
)
)
input_info_file = os.path.join(
os.path.dirname(checkpoint_path), 'model.ckpt.example_info.json'
)
ckpt_shape, ckpt_channels_enum = dv_utils.get_shape_and_channels_from_json(
input_info_file
)
if ckpt_shape is not None and ckpt_channels_enum is not None:
if example_shape != ckpt_shape:
raise ValueError(
f'Shape mismatch in {example_info_json} and {input_info_file}.'
)
if example_channels_enum != ckpt_channels_enum:
raise ValueError(
f'Channels mismatch in {example_info_json} and {input_info_file}.'
)
else:
# We can consider more strictly enforcing this.
logging.warning(
'Starting from v1.4.0, we recommend having a '
'model.ckpt.example_info.json file with your model.'
)
# Check accelerator status.
if execution_hardware not in _ALLOW_EXECUTION_HARDWARE:
raise ValueError(
'Unexpected execution_hardware={} value. Allowed values are {}'.format(
execution_hardware, ','.join(_ALLOW_EXECUTION_HARDWARE)
)
)
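  # Build a single op that initializes both global and local variables; it is
  # run once in the sanity-check session below.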
init_op = tf.group(
tf.compat.v1.global_variables_initializer(),
tf.compat.v1.local_variables_initializer(),
)
config = tf.compat.v1.ConfigProto()
if FLAGS.config_string is not None:
text_format.Parse(FLAGS.config_string, config)
if execution_hardware == 'cpu':
# Don't overwrite entire dictionary.
config.device_count['GPU'] = 0
config.device_count['TPU'] = 0
# Perform sanity check.
with tf.compat.v1.Session(config=config) as sess:
sess.run(init_op)
if execution_hardware == 'accelerator':
if not any(dev.device_type != 'CPU' for dev in sess.list_devices()):
raise ExecutionHardwareError(
'execution_hardware is set to accelerator, but no accelerator '
'was found'
)
    # TODO: Sort out auto-detection of TPU. Just calling sess.list_devices
    # here doesn't return the correct answer. That can only work later, after
    # the device (on the other VM) has been initialized, which generally has
    # not happened yet at this point.
# Prepare input stream and estimator.
tf_dataset = prepare_inputs(source_path=examples_filename, use_tpu=use_tpu)
estimator = model.make_estimator(
batch_size=batch_size,
master=primary,
use_tpu=use_tpu,
session_config=config,
include_debug_info=FLAGS.include_debug_info,
)
# Instantiate the prediction "stream", and select the EMA values from
# the model.
if checkpoint_path is None:
# Unit tests use this branch.
predict_hooks = []
else:
predict_hooks = [
h(checkpoint_path) for h in model.session_predict_hooks()
]
predictions = iter(
estimator.predict(
input_fn=tf_dataset,
checkpoint_path=checkpoint_path,
hooks=predict_hooks,
)
)
  # The following code is introduced to stay in sync with call_variants,
  # where we use multiple writers to write outputs.
  # If the output file is already sharded, don't dynamically shard.
if sharded_file_utils.is_sharded_filename(output_file):
logging.info('Output is already sharded, so dynamic sharding is disabled.')
else:
# For call_variants, we always use one writer process.
total_writer_process = 1
# Convert output filename to sharded output filename.
filename_pattern = output_file.replace(
'.tfrecord.gz', '@' + str(total_writer_process) + '.tfrecord.gz'
)
output_file = sharded_file_utils.maybe_generate_sharded_filenames(
filename_pattern
)[0]
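    # For example, 'output.tfrecord.gz' becomes the filespec
    # 'output@1.tfrecord.gz', which expands to the single shard
    # 'output-00000-of-00001.tfrecord.gz' (assuming the usual nucleus
    # sharded-filename convention).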
# Consume predictions one at a time and write them to output_file.
logging.info('Writing calls to %s', output_file)
writer = tfrecord.Writer(output_file)
with writer:
start_time = time.time()
n_examples, n_batches = 0, 0
while max_batches is None or n_batches <= max_batches:
try:
prediction = next(predictions)
except (StopIteration, tf.errors.OutOfRangeError):
break
write_variant_call(writer, prediction, use_tpu)
n_examples += 1
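      # The estimator yields one prediction per example, so derive an
      # approximate batch count from the running example count.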
n_batches = n_examples // batch_size + 1
duration = time.time() - start_time
logging.log_every_n(
logging.INFO,
'Processed %s examples in %s batches [%.3f sec per 100]',
_LOG_EVERY_N,
n_examples,
n_batches,
(100 * duration) / n_examples,
)
# One last log to capture the extra examples.
duration = time.time() - start_time
logging.info(
'Processed %s examples in %s batches [%.3f sec per 100]',
n_examples,
n_batches,
(100 * duration) / n_examples,
)
logging.info(
'Done calling variants from a total of %d examples.', n_examples
)
def main(argv=()):
with errors.clean_commandline_error_exit():
if len(argv) > 1:
errors.log_and_raise(
'Command line parsing failure: call_variants does not accept '
'positional arguments but some are present on the command line: '
'"{}".'.format(str(argv)),
errors.CommandLineError,
)
del argv # Unused.
proto_utils.uses_fast_cpp_protos_or_die()
logging_level.set_from_flag()
if FLAGS.use_tpu:
primary = dv_utils.resolve_master(
FLAGS.primary, FLAGS.tpu_name, FLAGS.tpu_zone, FLAGS.gcp_project
)
else:
primary = ''
model = modeling.get_model(FLAGS.model_name)
call_variants(
examples_filename=FLAGS.examples,
checkpoint_path=FLAGS.checkpoint,
model=model,
execution_hardware=FLAGS.execution_hardware,
output_file=FLAGS.outfile,
max_batches=FLAGS.max_batches,
batch_size=FLAGS.batch_size,
primary=primary,
use_tpu=FLAGS.use_tpu,
)
if __name__ == '__main__':
flags.mark_flags_as_required([
'examples',
'outfile',
'checkpoint',
])
tf.compat.v1.app.run()