Diff of /deepvariant/dv_utils.py [000000] .. [5a4941]

# Copyright 2017 Google LLC.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from this
#    software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""Utility functions for DeepVariant.
30
31
Started with a collection of utilities for working with the TF models. Now this
32
file includes broader utilities we use in DeepVariant.
33
"""
34
35
import json
36
from typing import Optional, Tuple
37
38
from absl import logging
39
import numpy as np
40
import tensorflow as tf
41
42
from deepvariant.protos import deepvariant_pb2
43
from third_party.nucleus.io import sharded_file_utils
44
# TODO: this dep still uses CLIF.
45
from third_party.nucleus.io import tfrecord
46
from third_party.nucleus.protos import variants_pb2
47
from tensorflow.core.example import example_pb2
48
49
50
# Convert strings up to this length, then clip.  We picked a number that
51
# was less than 1K, with a bit of extra space for the length element,
52
# to give enough space without overflowing to a larger multiple of 128.
53
STRING_TO_INT_MAX_CONTENTS_LEN = 1020
54
# This is the length of the resulting buffer (including the length entry).
55
STRING_TO_INT_BUFFER_LENGTH = STRING_TO_INT_MAX_CONTENTS_LEN + 1
56
57
58
def example_variant_type(example):
59
  """Gets the locus field from example as a string."""
60
  return example.features.feature['variant_type'].int64_list.value[0]
61
62
63
def example_locus(example):
64
  """Gets the locus field from example as a string."""
65
  return example.features.feature['locus'].bytes_list.value[0]
66
67
68
def example_alt_alleles_indices(example):
69
  """Gets an iterable of the alt allele indices in example."""
70
  return deepvariant_pb2.CallVariantsOutput.AltAlleleIndices.FromString(
71
      example.features.feature['alt_allele_indices/encoded'].bytes_list.value[0]
72
  ).indices
73
74
75
def example_alt_alleles(example, variant=None):
76
  """Gets a list of the alt alleles in example."""
77
  variant = variant if variant else example_variant(example)
78
  return [
79
      variant.alternate_bases[i] for i in example_alt_alleles_indices(example)
80
  ]
81
82
83
def example_encoded_image(example):
84
  """Gets image field from example as a string."""
85
  return example.features.feature['image/encoded'].bytes_list.value[0]
86
87
88
def example_variant(example):
89
  """Gets and decodes the variant field from example as a Variant."""
90
  encoded = example.features.feature['variant/encoded'].bytes_list.value[0]
91
  return variants_pb2.Variant.FromString(encoded)
92
93
94
def example_label(example: example_pb2.Example) -> Optional[int]:
95
  """Gets the label field from example as a string."""
96
  if 'label' not in example.features.feature:
97
    return None
98
  return int(example.features.feature['label'].int64_list.value[0])
99
100
101
def example_denovo_label(example: example_pb2.Example) -> Optional[int]:
102
  """Gets the label field from example as a string.
103
104
  Args:
105
    example: A tf.Example containing DeepVariant example.
106
107
  Returns:
108
    De novo label for the example.
109
  """
110
  if 'denovo_label' not in example.features.feature:
111
    return None
112
  return int(example.features.feature['denovo_label'].int64_list.value[0])
113
114
115
def example_image_shape(example):
116
  """Gets the image shape field from example as a list of int64."""
117
  if len(example.features.feature['image/shape'].int64_list.value) != 3:
118
    raise ValueError(
119
        'Invalid image/shape: we expect to find an image/shape '
120
        'field with length 3.'
121
    )
122
  return example.features.feature['image/shape'].int64_list.value[0:3]
123
124
125
def example_key(example):
126
  """Constructs a key for example based on its position and alleles."""
127
  variant = example_variant(example)
128
  alts = example_alt_alleles(example)
129
  return '{}:{}:{}->{}'.format(
130
      variant.reference_name,
131
      variant.start + 1,
132
      variant.reference_bases,
133
      '/'.join(alts),
134
  )
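

# Illustrative usage sketch for the accessors above; the TFRecord path passed
# in is a placeholder supplied by the caller.
def _print_example_summary_sketch(path: str) -> None:
  """Prints locus, key, alt alleles, and label for the first example in path."""
  example = next(
      tfrecord.read_tfrecords(
          path, proto=example_pb2.Example, compression_type='GZIP'
      )
  )
  print('locus:', example_locus(example))
  print('key:', example_key(example))
  print('alt alleles:', example_alt_alleles(example))
  print('label:', example_label(example))  # None for unlabeled examples.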


def example_set_label(example, numeric_label):
  """Sets the label features of example.

  Sets the label feature of example to numeric_label.

  Args:
    example: A tf.Example proto.
    numeric_label: A numeric (int64 compatible) label for example.
  """
  example.features.feature['label'].int64_list.value[:] = [numeric_label]


def example_set_denovo_label(
    example: example_pb2.Example, numeric_label: int
) -> None:
  """Sets the denovo label features of example.

  Sets the denovo_label feature of example to numeric_label.

  Args:
    example: A tf.Example proto.
    numeric_label: A numeric (int64 compatible) label for example.
  """
  example.features.feature['denovo_label'].int64_list.value[:] = [numeric_label]


def example_set_variant(example, variant, deterministic=False):
  """Sets the variant/encoded feature of example to variant.SerializeToString().

  Args:
    example: A tf.Example proto.
    variant: third_party.nucleus.protos.Variant protobuf containing information
      about a candidate variant call.
    deterministic: Whether to serialize the variant deterministically (passed
      to SerializeToString).
  """
  example.features.feature['variant/encoded'].bytes_list.value[:] = [
      variant.SerializeToString(deterministic=deterministic)
  ]
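

# Illustrative usage sketch for the setters above; `new_label` is an arbitrary
# caller-supplied value.
def _relabel_example_sketch(
    example: example_pb2.Example, new_label: int
) -> None:
  """Overwrites the label and re-serializes the variant deterministically."""
  example_set_label(example, new_label)
  # Re-encoding deterministically makes byte-identical inputs produce
  # byte-identical examples.
  example_set_variant(example, example_variant(example), deterministic=True)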


def example_sequencing_type(example):
  """Gets the sequencing_type field from example as an int."""
  return example.features.feature['sequencing_type'].int64_list.value[0]


def get_one_example_from_examples_path(source, proto=None):
  """Gets the first record from `source`.

  Args:
    source: str. A pattern or a comma-separated list of patterns that represent
      file names.
    proto: A proto class. proto.FromString() will be called on each serialized
      record in path to parse it.

  Returns:
    The first record, or None.
  """
  files = sharded_file_utils.glob_list_sharded_file_patterns(source)
  if not files:
    raise ValueError(
        'Cannot find matching files with the pattern "{}"'.format(source)
    )
  for f in files:
    try:
      return next(
          tfrecord.read_tfrecords(f, proto=proto, compression_type='GZIP')
      )
    except StopIteration:
      # Getting a StopIteration from next() means this file is empty.
      # Move on to the next one to try to get one example.
      pass
  return None


def get_shape_from_examples_path(source):
  """Reads one record from source to determine the tensor shape for all."""
  one_example = get_one_example_from_examples_path(source)
  if one_example:
    return example_image_shape(one_example)
  return None
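

# Illustrative usage sketch; the examples pattern below is a placeholder, not a
# path this module depends on.
def _log_examples_shape_sketch() -> None:
  """Logs the image shape inferred from the first readable example."""
  shape = get_shape_from_examples_path(
      '/tmp/make_examples.tfrecord@64.gz'  # Placeholder pattern.
  )
  if shape is None:
    logging.warning('No examples found; cannot infer image shape.')
  else:
    logging.info('Examples have image shape %s.', list(shape))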


def _simplify_variant(variant):
  """Returns a new Variant with only the basic fields of variant."""

  def _simplify_variant_call(call):
    """Returns a new VariantCall with the basic fields of call."""
    return variants_pb2.VariantCall(
        call_set_name=call.call_set_name,
        genotype=call.genotype,
        info=dict(call.info),
    )  # dict() is necessary to actually set info.

  return variants_pb2.Variant(
      reference_name=variant.reference_name,
      start=variant.start,
      end=variant.end,
      reference_bases=variant.reference_bases,
      alternate_bases=variant.alternate_bases,
      filter=variant.filter,
      quality=variant.quality,
      calls=[_simplify_variant_call(call) for call in variant.calls],
  )


def string_to_int_tensor(x):
  """Graph operations to decode a string into a fixed-size tensor of ints."""
  decoded = tf.compat.v1.decode_raw(x, tf.uint8)
  clipped = decoded[:STRING_TO_INT_MAX_CONTENTS_LEN]  # clip to allowed max_len
  shape = tf.shape(input=clipped)
  slen = shape[0]
  # pad to desired max_len
  padded = tf.pad(
      tensor=clipped, paddings=[[0, STRING_TO_INT_MAX_CONTENTS_LEN - slen]]
  )
  casted = tf.cast(padded, tf.int32)
  casted.set_shape([STRING_TO_INT_MAX_CONTENTS_LEN])
  return tf.concat([[slen], casted], 0)


def int_tensor_to_string(x):
  """Python operations to encode a tensor of ints into a string of bytes."""
  slen = x[0]
  v = x[1 : slen + 1]
  return np.array(v, dtype=np.uint8).tobytes()
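

# Illustrative sketch showing that the two functions above round-trip a byte
# string (assumes TF eager execution).
def _string_int_roundtrip_sketch() -> None:
  """Shows that int_tensor_to_string inverts string_to_int_tensor."""
  original = b'chr20:10000000-10000999'
  # encoded has shape [STRING_TO_INT_BUFFER_LENGTH]: the length then contents.
  encoded = string_to_int_tensor(tf.constant(original))
  decoded = int_tensor_to_string(encoded.numpy())
  assert decoded == original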


def tpu_available(sess=None):
  """Return true if a TPU device is available to the default session."""
  if sess is None:
    init_op = tf.group(
        tf.compat.v1.global_variables_initializer(),
        tf.compat.v1.local_variables_initializer(),
    )
    with tf.compat.v1.Session() as sess:
      sess.run(init_op)
      devices = sess.list_devices()
  else:
    devices = sess.list_devices()
  return any(dev.device_type == 'TPU' for dev in devices)


def resolve_master(master, tpu_name, tpu_zone, gcp_project):
  """Resolve the master's URL given standard flags."""
  if master is not None:
    return master
  elif tpu_name is not None:
    return tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu=[tpu_name], zone=tpu_zone, project=gcp_project
    ).get_master()
  else:
    # For k8s TPU we do not have/need tpu_name. See
    # https://cloud.google.com/tpu/docs/kubernetes-engine-setup#tensorflow-code
    return tf.distribute.cluster_resolver.TPUClusterResolver().get_master()


def get_example_info_json_filename(
    examples_filename: str, task_id: Optional[int]
) -> str:
  """Returns corresponding example_info.json filename for examples_filename."""
  if sharded_file_utils.is_sharded_file_spec(examples_filename):
    assert task_id is not None
    # If examples_filename has the @shards representation, resolve it into
    # the first shard. We only write .example_info.json to the first shard.
    example_info_prefix = sharded_file_utils.sharded_filename(
        examples_filename, task_id
    )
  else:
    # In all other cases, including non-sharded files,
    # or sharded filenames with -ddddd-of-ddddd, just append.
    example_info_prefix = examples_filename
  return example_info_prefix + '.example_info.json'
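

# Illustrative usage sketch; both paths below are placeholders.
def _example_info_filename_sketch() -> None:
  """Logs the example_info.json path for sharded and unsharded specs."""
  sharded = get_example_info_json_filename(
      '/tmp/examples.tfrecord@64.gz', task_id=0
  )
  unsharded = get_example_info_json_filename(
      '/tmp/examples.tfrecord.gz', task_id=None
  )
  logging.info('sharded -> %s, unsharded -> %s', sharded, unsharded)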


def get_shape_and_channels_from_json(example_info_json):
  """Returns the shape and channels list from the input json."""
  if not tf.io.gfile.exists(example_info_json):
    logging.warning(
        (
            'Starting from v1.4.0, we expect %s to '
            'include information for shape and channels.'
        ),
        example_info_json,
    )
    return None, None
  with tf.io.gfile.GFile(example_info_json) as f:
    example_info = json.load(f)
  example_shape = example_info['shape']
  example_channels_enum = example_info['channels']
  logging.info(
      'From %s: Shape of input examples: %s, Channels of input examples: %s.',
      example_info_json,
      str(example_shape),
      str(example_channels_enum),
  )
  return example_shape, example_channels_enum
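

# Illustrative sketch of the minimal example_info.json layout this reader
# expects; the shape and channel values below are placeholders.
def _write_example_info_sketch(example_info_json: str) -> None:
  """Writes a small example_info.json that the reader above can parse."""
  example_info = {'shape': [100, 221, 7], 'channels': [1, 2, 3, 4, 5, 6, 19]}
  with tf.io.gfile.GFile(example_info_json, 'w') as f:
    json.dump(example_info, f)
  # Reading it back logs and returns the parsed values.
  shape, channels = get_shape_and_channels_from_json(example_info_json)
  assert shape == [100, 221, 7] and channels == [1, 2, 3, 4, 5, 6, 19]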


def get_tf_record_writer(output_filename: str) -> tf.io.TFRecordWriter:
  """Returns a TFRecordWriter that writes GZIP-compressed records."""
  tf_options = tf.io.TFRecordOptions(compression_type='GZIP')
  return tf.io.TFRecordWriter(output_filename, options=tf_options)


def preprocess_images(images):
  """Applies preprocessing operations for Inception images.

  Because this will run in model_fn, on the accelerator, we use operations
  that efficiently execute there.

  Args:
    images: A Tensor with uint8 values.

  Returns:
    A tensor of images of the same shape, containing floating point values,
    with all points rescaled between -1 and 1.
  """
  images = tf.cast(images, dtype=tf.float32)
  images = tf.subtract(images, 128.0)
  images = tf.math.divide(images, 128.0)
  return images


def unpreprocess_images(images: np.ndarray) -> np.ndarray:
  """Reverses preprocess_images in numpy format.

  Args:
    images: A numpy array with floating point values.

  Returns:
    A numpy array of images of the same shape.
  """
  images *= 128.0
  images += 128.0
  # We can optionally convert it to uint8 by .astype(np.uint8).
  # But for now we'll just return it as floating points.
  return images
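

# Illustrative sketch of the scaling math above (assumes TF eager execution).
def _preprocess_roundtrip_sketch() -> None:
  """Checks that unpreprocess_images inverts preprocess_images."""
  pixels = np.array([[0, 128, 255]], dtype=np.uint8)
  # preprocess_images maps these values to [[-1.0, 0.0, 0.9921875]].
  scaled = preprocess_images(tf.constant(pixels)).numpy()
  restored = unpreprocess_images(scaled)
  np.testing.assert_allclose(restored, pixels.astype(np.float32))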


def call_variant_to_tfexample(
    cvo: deepvariant_pb2.CallVariantsOutput,
    image_shape: Tuple[int, int, int] = (100, 221, 7),
) -> tf.train.Example:
  """Converts CallVariantsOutput to tf.train.Example if possible.

  This function is for a specific debugging purpose, so it only works on
  CallVariantsOutput with debug_info.image_encoded.

  Note: Not all values are transferred as there isn't a 1:1 mapping. For
  example, no mapping exists for 'variant/encoded' or 'sequencing_type'.

  Args:
    cvo: A CallVariantsOutput to convert to a tf.Example.
    image_shape: The shape of the image contained within cvo.

  Returns:
    A tf.Example created from the given CallVariantsOutput.

  Raises:
    ValueError: If the input data lacks the needed fields.
  """
  tfexample = tf.train.Example()
  features = tfexample.features.feature
  features['image/shape'].int64_list.value[:] = list(image_shape)
  if cvo.debug_info and cvo.debug_info.image_encoded:
    features['image/encoded'].bytes_list.value[:] = [
        cvo.debug_info.image_encoded
    ]
  else:
    raise ValueError('CallVariantsOutput does not contain an image.')

  features['label'].int64_list.value[:] = [cvo.debug_info.true_label]

  if cvo.alt_allele_indices:
    features['alt_allele_indices'].int64_list.value[:] = (
        cvo.alt_allele_indices.indices
    )

  # Create and assign locus.
  features['locus'].bytes_list.value[:] = [
      bytes(
          f'{cvo.variant.reference_name}:{cvo.variant.start}-{cvo.variant.end}',
          'utf-8',
      )
  ]
  return tfexample
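

# Illustrative sketch of a minimal CallVariantsOutput that satisfies the
# converter above; all field values are placeholders, and the encoded image
# bytes are not a real encoded pileup image.
def _cvo_to_example_sketch() -> tf.train.Example:
  """Builds a CallVariantsOutput with an encoded image and converts it."""
  cvo = deepvariant_pb2.CallVariantsOutput()
  cvo.variant.reference_name = 'chr20'
  cvo.variant.start = 10000000
  cvo.variant.end = 10000001
  cvo.alt_allele_indices.indices.append(0)
  cvo.debug_info.image_encoded = b'\x00' * 64  # Placeholder bytes.
  cvo.debug_info.true_label = 1
  return call_variant_to_tfexample(cvo, image_shape=(100, 221, 7))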