Switch to unified view

a b/third_party/nucleus/util/vis.py
1
# Copyright 2019 Google LLC.
2
#
3
# Redistribution and use in source and binary forms, with or without
4
# modification, are permitted provided that the following conditions
5
# are met:
6
#
7
# 1. Redistributions of source code must retain the above copyright notice,
8
#    this list of conditions and the following disclaimer.
9
#
10
# 2. Redistributions in binary form must reproduce the above copyright
11
#    notice, this list of conditions and the following disclaimer in the
12
#    documentation and/or other materials provided with the distribution.
13
#
14
# 3. Neither the name of the copyright holder nor the names of its
15
#    contributors may be used to endorse or promote products derived from this
16
#    software without specific prior written permission.
17
#
18
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
22
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28
# POSSIBILITY OF SUCH DAMAGE.
29
30
"""Utility functions for visualization and inspection of pileup examples.
31
32
Visualization and inspection utility functions enable showing image-like array
33
data including those used in DeepVariant.
34
"""
35
36
from __future__ import absolute_import
37
from __future__ import division
38
from __future__ import print_function
39
40
import enum
41
import math
42
from typing import List, NamedTuple, Tuple
43
44
from etils import epath
45
from IPython import display
46
import numpy as np
47
from PIL import Image
48
from PIL import ImageDraw
49
50
from third_party.nucleus.protos import variants_pb2
51
52
53
# Names of the channels in a DeepVariant pileup image, in the order they
# appear along the last axis of the image tensor. Used to label channels
# when drawing pileups (see _deepvariant_channel_names below).
DEEPVARIANT_CHANNEL_NAMES = [
    'read base', 'base quality', 'mapping quality', 'strand',
    'read supports variant', 'base differs from ref', 'haplotype tag',
    'alternate allele 1', 'alternate allele 2'
]
58
59
60
class Diff(enum.Enum):
  """Categorical description of the ref-diff channel of a pileup.

  MANY_DIFFS indicates many differences outside putative nearby variants;
  NEARBY_VARIANTS indicates several column-like difference clusters;
  FEW_DIFFS is the default when neither applies (see describe_diff).
  """
  FEW_DIFFS = 1
  MANY_DIFFS = 2
  NEARBY_VARIANTS = 3
64
65
66
class BaseQuality(enum.Enum):
  """Categorical verdict on the base qualities seen in a pileup."""
  GOOD = 1
  BAD = 2
69
70
71
class MappingQuality(enum.Enum):
  """Categorical verdict on the mapping qualities seen in a pileup."""
  GOOD = 1
  BAD = 2
74
75
76
class StrandBias(enum.Enum):
  """Categorical verdict on strand bias among variant-supporting reads."""
  GOOD = 1
  BIASED = 2
79
80
81
class ReadSupport(enum.Enum):
  """Categorical level of read support for the variant in a pileup."""
  ALL = 1
  HALF = 2
  LOW = 3
85
86
87
PileupCuration = NamedTuple('PileupCuration',
88
                            [('base_quality', BaseQuality),
89
                             ('mapping_quality', MappingQuality),
90
                             ('strand_bias', StrandBias),
91
                             ('diff_category', Diff),
92
                             ('read_support', ReadSupport)])
93
94
95
def get_image_array_from_example(example):
  """Decode image/encoded and image/shape of an Example into a numpy array.

  Parse image/encoded and image/shape features from a tensorflow Example and
  decode the image into that shape.

  Args:
    example: a tensorflow Example containing features that include
      "image/encoded" and "image/shape"

  Returns:
    numpy array of dtype np.uint8.
  """
  feature_map = example.features.feature
  encoded = feature_map['image/encoded'].bytes_list.value[0]
  # The first three ints of image/shape give the (height, width, channels).
  dims = feature_map['image/shape'].int64_list.value[0:3]
  return np.frombuffer(encoded, np.uint8).reshape(dims)
112
113
114
def split_3d_array_into_channels(arr):
  """Split 3D array into a list of 2D arrays.

  e.g. given a numpy array of shape (100, 200, 6), return a list of 6 channels,
  each with shape (100, 200).

  Args:
    arr: a 3D numpy array.

  Returns:
    list of 2D numpy arrays.
  """
  num_channels = arr.shape[-1]
  channels = []
  for channel_index in range(num_channels):
    channels.append(arr[..., channel_index])
  return channels
127
128
129
def channels_from_example(example):
  """Extract image from an Example and return the list of channels.

  Args:
    example: a tensorflow Example containing features that include
      "image/encoded" and "image/shape"

  Returns:
    list of 2D numpy arrays, one for each channel.
  """
  return split_3d_array_into_channels(get_image_array_from_example(example))
141
142
143
def convert_6_channels_to_rgb(channels):
  """Convert 6-channel image from DeepVariant to RGB for quick visualization.

  The 6 channels are: "read base", "base quality", "mapping quality", "strand",
  "supports variant", "base != reference".

  Args:
    channels: a list of 6 numpy arrays.

  Returns:
    3D numpy array of 3 colors (Red, green, blue).
  """
  # Red: the read base itself.
  red = channels[0]
  # Green: the weaker of base quality and mapping quality at each position.
  # 254 is the max value for quality scores because the SAM specification has
  # 255 reserved for unavailable values.
  green = np.minimum(channels[1], channels[2])
  # Blue: the strand channel.
  blue = channels[3]
  # Each pixel's brightness is scaled by <supports variant> * <base != ref>.
  alpha = (channels[4] / 254.0) * (channels[5] / 254.0)
  rgb = np.stack([red, green, blue]) * alpha
  # Move the color axis last: (3, H, W) -> (H, W, 3).
  return rgb.astype(np.uint8).transpose([1, 2, 0])
165
166
167
def scale_colors_for_png(arr, vmin=0, vmax=255):
  """Scale an array to integers between 0 and 255 to prep it for a PNG image.

  Args:
    arr: numpy array. Input array made up of integers or floats.
    vmin: number. Minimum data value to map to 0. Values below this will be
      clamped to this value and therefore become 0.
    vmax: number. Maximum data value to map to 255. Values above this will be
      clamped to this value and therefore become 255.

  Returns:
    numpy array of dtype np.uint8 (integers between 0 and 255).

  Raises:
    ValueError: if vmax is zero or not larger than vmin.
  """
  if vmax == 0 or vmax <= vmin:
    # Bug fix: the message previously read "vmin must be non-zero and higher
    # than vmin.", blaming the wrong parameter; the check is on vmax.
    raise ValueError('vmax must be non-zero and higher than vmin.')

  # Snap numbers falling outside [vmin, vmax] into the range, otherwise they
  # will produce artifacts due to byte overflow. np.clip returns a new array,
  # so the caller's input is never modified.
  scaled = np.clip(arr, vmin, vmax)

  # Scale the input into the range of vmin to vmax.
  if vmin != 0 or vmax != 255:
    scaled = ((scaled - vmin) / (vmax - vmin)) * 255
  return scaled.astype(np.uint8)
195
196
197
def _get_image_type_from_array(arr):
198
  """Find image type based on array dimensions.
199
200
  Raises error on invalid image dimensions.
201
  Args:
202
    arr: numpy array. Input array.
203
204
  Returns:
205
    str. "RGB" or "L", meant for PIL.Image.fromarray.
206
  """
207
  if len(arr.shape) == 3 and arr.shape[2] == 3:
208
    # 8-bit x 3 colors
209
    return 'RGB'
210
  elif len(arr.shape) == 2:
211
    # 8-bit, gray-scale
212
    return 'L'
213
  else:
214
    raise ValueError(
215
        'Input array must have either 2 dimensions or 3 dimensions where the '
216
        'third dimension has 3 channels. i.e. arr.shape is (x,y) or (x,y,3). '
217
        'Found shape {}.'.format(arr.shape))
218
219
220
def autoscale_colors_for_png(arr, vmin=None, vmax=None):
  """Adjust an array to prepare it for saving to an image.

  Re-scale numbers in the input array to go from 0 to 255 to adapt them for a
  PNG image.

  Args:
    arr: numpy array. Should be 2-dimensional or 3-dimensional where the third
      dimension has 3 channels.
    vmin: number (float or int). Minimum data value, which will correspond to
      black in greyscale or lack of each color in RGB images. Default None takes
      the minimum of the data from arr.
    vmax: number (float or int). Maximum data value, which will correspond to
      white in greyscale or full presence of each color in RGB images. Default
      None takes the max of the data from arr.

  Returns:
    (modified numpy array, image_mode)
  """
  image_mode = _get_image_type_from_array(arr)

  vmin = np.min(arr) if vmin is None else vmin
  vmax = np.max(arr) if vmax is None else vmax

  # When every element is identical, bump vmax so that even though the whole
  # image will be black, the user can at least see the shape.
  if vmin == vmax:
    vmax = vmin + 1

  return scale_colors_for_png(arr, vmin=vmin, vmax=vmax), image_mode
253
254
255
def add_header(img, labels, mark_midpoints=True, header_height=20):
  """Adds labels to the image, evenly distributed across the top.

  This is primarily useful for showing the names of channels.

  Args:
    img: A PIL Image.
    labels: list of strs. Labels for segments to write across the top.
    mark_midpoints: bool. Whether to add a small vertical line marking the
      center of each segment of the image.
    header_height: int. Height of the header in pixels.

  Returns:
    A new PIL Image, taller than the original img and annotated.
  """

  # Create a taller image to make space for a header at the top.
  new_height = header_height + img.size[1]
  new_width = img.size[0]

  # numpy arrays are (rows, cols), so height comes before width here.
  if img.mode == 'RGB':
    placeholder_size = (new_height, new_width, 3)
  else:
    placeholder_size = (new_height, new_width)
  # 255 everywhere makes the header background white.
  placeholder = np.ones(placeholder_size, dtype=np.uint8) * 255

  # Divide the image width into segments.
  segment_width = img.size[0] / len(labels)

  # Calculate midpoints for all segments.
  midpoints = [int(segment_width * (i + 0.5)) for i in range(len(labels))]

  if mark_midpoints:
    # For each label, add a small line to mark the middle.
    for x_position in midpoints:
      # Draw a 5-pixel tick at the bottom of the header band.
      placeholder[header_height - 5:header_height, x_position] = 0
      # If image has an even width, it will need 2 pixels marked as the middle.
      # NOTE(review): segment_width is a float, so this only fires when the
      # segment width is a whole even number — confirm that is intended.
      if segment_width % 2 == 0:
        placeholder[header_height - 5:header_height, x_position + 1] = 0

  bigger_img = Image.fromarray(placeholder, mode=img.mode)
  # Place the original image inside the taller placeholder image.
  bigger_img.paste(img, (0, header_height))

  # Add a label for each segment.
  draw = ImageDraw.Draw(bigger_img)
  for i in range(len(labels)):
    text = labels[i]
    # textbbox returns (left, top, right, bottom); with the text anchored at
    # (0, 0) the right edge equals the rendered text width.
    text_width = draw.textbbox((0, 0), text, anchor='lt')[2]
    # xy refers to the left top corner of the text, so to center the text on
    # the midpoint, subtract half the text width from the midpoint position.
    x_position = int(midpoints[i] - text_width / 2)
    draw.text(xy=(x_position, 0), text=text, fill='black')
  return bigger_img
309
310
311
def save_to_png(arr,
                path=None,
                image_mode=None,
                show=True,
                labels=None,
                scale=None):
  """Make a PNG and show it from a numpy array of dtype=np.uint8.

  Args:
    arr: numpy array. Input array to save.
    path: str. File path at which to save the image. A .png prefix is added if
      the path does not already have one. Leave empty to save at /tmp/tmp.png,
      which is useful when only temporarily showing the image in a Colab
      notebook.
    image_mode: "RGB" or "L". Leave as default=None to choose based on image
      dimensions.
    show: bool. Whether to display the image using IPython (for notebooks).
    labels: list of str. Labels to show across the top of the image.
    scale: integer. Number of pixels wide and tall to show each cell in the
      array. This sizes up the image while keeping exactly the same number of
      pixels for every cell in the array, preserving resolution and preventing
      any interpolation or overlapping of pixels. Default None adapts to the
      size of the image to multiply it up until a limit of 500 pixels, a
      convenient size for use in notebooks. If saving to a file for automated
      processing, scale=1 is recommended to keep output files small and simple
      while still retaining all the information content.

  Returns:
    None. Saves an image at path and optionally shows it with IPython.display.
  """
  # Infer greyscale ("L") vs color ("RGB") from the array dimensions.
  if image_mode is None:
    image_mode = _get_image_type_from_array(arr)

  img = Image.fromarray(arr, mode=image_mode)

  if labels is not None:
    img = add_header(img, labels)

  # Default: enlarge small images up to roughly 500 pixels for readability.
  if scale is None:
    scale = max(1, int(500 / max(arr.shape)))

  if scale != 1:
    img = img.resize((img.size[0] * scale, img.size[1] * scale))

  # Saving to a temporary file is needed even when showing in a notebook
  if path is None:
    path = '/tmp/tmp.png'
  elif not path.endswith('.png'):
    # Only PNG is supported because JPEG files are unnecessarily 3 times larger.
    path = '{}.png'.format(path)
  # The format argument ("png") is taken from the file extension.
  img.save(epath.Path(path).open('wb'), format=path.split('.')[-1])

  # Show image (great for notebooks)
  if show:
    display.display(display.Image(path))
366
367
368
def array_to_png(arr,
                 path=None,
                 show=True,
                 vmin=None,
                 vmax=None,
                 scale=None,
                 labels=None):
  """Save an array as a PNG image with PIL and show it.

  Args:
    arr: numpy array. Should be 2-dimensional or 3-dimensional where the third
      dimension has 3 channels.
    path: str. Path for the image output. Default is /tmp/tmp.png for quickly
      showing the image in a notebook.
    show: bool. Whether to show the image using IPython utilities, only works in
      notebooks.
    vmin: number. Minimum data value, which will correspond to black in
      greyscale or lack of each color in RGB images. Default None takes the
      minimum of the data from arr.
    vmax: number. Maximum data value, which will correspond to white in
      greyscale or full presence of each color in RGB images. Default None takes
      the max of the data from arr.
    scale: integer. Number of pixels wide and tall to show each cell in the
      array. This sizes up the image while keeping exactly the same number of
      pixels for every cell in the array, preserving resolution and preventing
      any interpolation or overlapping of pixels. Default None adapts to the
      size of the image to multiply it up until a limit of 500 pixels, a
      convenient size for use in notebooks. If saving to a file for automated
      processing, scale=1 is recommended to keep output files small and simple
      while still retaining all the information content.
    labels: list of str. Labels to show across the top of the image.

  Returns:
    None. Saves an image at path and optionally shows it with IPython.display.
  """
  # Normalize the data into PNG-friendly uint8 first, then delegate the
  # actual file writing and display to save_to_png.
  png_ready, mode = autoscale_colors_for_png(arr, vmin=vmin, vmax=vmax)
  save_to_png(
      png_ready,
      path=path,
      image_mode=mode,
      show=show,
      labels=labels,
      scale=scale)
411
412
413
def _deepvariant_channel_names(num_channels):
  """Get DeepVariant channel names for the given number of channels."""
  num_known = len(DEEPVARIANT_CHANNEL_NAMES)
  # Pad with generic placeholders when there are more channels than names.
  extras = ['channel {}'.format(i + 1) for i in range(num_known, num_channels)]
  # Truncate in case there are fewer channels than known names.
  return (DEEPVARIANT_CHANNEL_NAMES + extras)[:num_channels]
423
424
425
def draw_deepvariant_pileup(example=None,
                            channels=None,
                            composite_type=None,
                            annotated=True,
                            labels=None,
                            path=None,
                            show=True,
                            scale=None):
  """Quick utility for showing a pileup example as channels or RGB.

  Args:
    example: A tensorflow Example containing image/encoded and image/shape
      features. Will be parsed through channels_from_example. Ignored if
      channels are provided directly. Either example OR channels is required.
    channels: list of 2D arrays containing the data to draw. Either example OR
      channels is required.
    composite_type: str or None. Method for combining channels. One of
      [None,"RGB"].
    annotated: bool. Whether to add channel labels and mark midpoints.
    labels: list of str. Which labels to add to the image. If annotated=True,
      use default channels labels for DeepVariant.
    path: str. Output file path for saving as an image. If None, just show plot.
    show: bool. Whether to display the image for ipython notebooks. Set to False
      to prevent extra output when running in bulk.
    scale: integer. Multiplier to enlarge the image. Default: None, which will
      set it automatically for a human-readable size. Set to 1 for no scaling.

  Returns:
    None. Saves an image at path and optionally shows it with IPython.display.
  """
  if not channels:
    if not example:
      raise ValueError('Either example OR channels must be specified.')
    channels = channels_from_example(example)

  if composite_type is None:
    # Lay the channels out side by side as one wide greyscale image.
    pileup_img = np.concatenate(channels, axis=1)
    if annotated and labels is None:
      labels = _deepvariant_channel_names(len(channels))
  elif composite_type == 'RGB':
    pileup_img = convert_6_channels_to_rgb(channels)
    if annotated and labels is None:
      labels = ['']  # Creates one midpoint with no label.
  else:
    raise ValueError(
        "Unrecognized composite_type: {}. Must be None or 'RGB'".format(
            composite_type))

  # 254 is the maximum channel value (255 is reserved), so pin vmax there.
  array_to_png(
      pileup_img,
      path=path,
      show=show,
      scale=scale,
      labels=labels,
      vmin=0,
      vmax=254)
481
482
483
def variant_from_example(example):
  """Extract Variant object from the 'variant/encoded' feature of an Example.

  Args:
    example: a DeepVariant-style make_examples output example.

  Returns:
    A Nucleus Variant.
  """
  encoded_variant = (
      example.features.feature['variant/encoded'].bytes_list.value[0])
  return variants_pb2.Variant.FromString(encoded_variant)
495
496
497
def locus_id_from_variant(variant):
  """Create a locus ID of form "chr:pos_ref" from a Variant object.

  Args:
    variant: a nucleus variant.

  Returns:
    str.
  """
  parts = (variant.reference_name, variant.start, variant.reference_bases)
  return '{}:{}_{}'.format(*parts)
508
509
510
def alt_allele_indices_from_example(example):
  """Extract indices of the particular alt allele(s) the example represents.

  Args:
    example: a DeepVariant make_examples output example.

  Returns:
    list of indices.
  """
  encoded = (
      example.features.feature['alt_allele_indices/encoded'].bytes_list
      .value[0])
  # Decode the raw proto bytes as unsigned ints, converted to plain ints.
  as_ints = [int(b) for b in np.frombuffer(encoded, dtype=np.uint8)]
  # Format is [<field id + type>, <number of elements in array>, ...<array>].
  # Keep only the array, dropping the two metadata bytes.
  return as_ints[2:]
526
527
528
def alt_bases_from_indices(alt_allele_indices, alternate_bases):
  """Get alt allele bases based on their indices.

  e.g. one alt allele: [0], ["C"] => "C"
  or with two alt alleles: [0,2], ["C", "TT", "A"] => "C-A"

  Args:
    alt_allele_indices: list of integers. Indices of the alt alleles for a
      particular example.
    alternate_bases: list of strings. All alternate alleles for the variant.

  Returns:
    str. Alt allele(s) at the indices, joined by '-' if more than 1.
  """
  # Joining with '-' rather than '/' keeps the result safe for file paths.
  return '-'.join(alternate_bases[i] for i in alt_allele_indices)
545
546
547
def alt_from_example(example):
  """Get alt allele(s) from a DeepVariant example.

  Args:
    example: a DeepVariant make_examples output example.

  Returns:
    str. The bases of the alt alleles, joined by a -.
  """
  alt_indices = alt_allele_indices_from_example(example)
  return alt_bases_from_indices(
      alt_indices, variant_from_example(example).alternate_bases)
559
560
561
def locus_id_with_alt(example):
  """Get complete locus ID from a DeepVariant example.

  Args:
    example: a DeepVariant make_examples output example.

  Returns:
    str in the form "chr:pos_ref_alt".
  """
  variant = variant_from_example(example)
  return '{}_{}'.format(
      locus_id_from_variant(variant), alt_from_example(example))
574
575
576
def label_from_example(example):
  """Get the "label" from an example.

  Args:
    example: a DeepVariant make_examples output example.

  Returns:
    integer (0, 1, or 2 for regular DeepVariant examples) or None if the
        example has no label.
  """
  values = example.features.feature['label'].int64_list.value
  # An unlabeled example has an empty value list.
  return int(values[0]) if values else None
591
592
593
def remove_ref_band(arr: np.ndarray,
                    num_top_rows_to_skip: int = 5) -> np.ndarray:
  """Removes the reference rows at the top of a pileup image array."""
  # Only 2D single-channel arrays with more rows than the ref band are valid.
  assert arr.ndim == 2
  assert arr.shape[0] > num_top_rows_to_skip
  return arr[num_top_rows_to_skip:]
599
600
601
def fraction_low_base_quality(channels: List[np.ndarray],
                              threshold: int = 127) -> float:
  """Gets fraction of bases that have low base quality scores in a pileup.

  Args:
      channels: A list of channels of a DeepVariant pileup image. This only uses
        channels[1], the base quality channel.
      threshold: Bases qualities below this will be considered low quality. The
        default is 127 because this is half of the max (254).

  Returns:
      The fraction of bases with base quality below the threshold.
  """
  basequal_channel = remove_ref_band(channels[1])
  # Zero pixels are background (no read base there), so exclude them.
  non_zero_values = basequal_channel[basequal_channel > 0]
  if non_zero_values.size == 0:
    return 0.0
  # np.mean of the boolean mask computes the fraction in C; the previous
  # builtin sum() iterated the numpy array element-by-element in Python.
  return float(np.mean(non_zero_values < threshold))
621
622
623
def fraction_reads_with_low_mapq(channels: List[np.ndarray],
                                 threshold: int = 127) -> float:
  """Gets fraction of reads that have low mapping quality scores in pileup.

  Args:
      channels: A list of channels of a DeepVariant pileup image. This only uses
        channels[2], the mapping quality channel.
      threshold: int. Default is 127 because this is half of the max (254).

  Returns:
      The fraction of bases with mapping quality below the threshold.
  """
  mapq_channel = remove_ref_band(channels[2])
  # Get max value of each row, aka each read.
  max_row_values = np.amax(mapq_channel, axis=1)

  # Rows whose max is zero are empty (no read), so exclude them.
  non_zero_values = max_row_values[max_row_values > 0]
  if non_zero_values.size == 0:
    return 0.0
  # np.mean of the boolean mask computes the fraction in C; the previous
  # builtin sum() iterated the numpy array element-by-element in Python.
  return float(np.mean(non_zero_values < threshold))
644
645
646
def fraction_read_support(channels: List[np.ndarray]) -> float:
  """Gets fraction of reads that support the variant.

  Args:
      channels: A list of channels of a DeepVariant pileup image. This only uses
        channels[4], the 'read supports variant' channel.

  Returns:
      Fraction of reads supporting the alternate allele(s), ranging from [0, 1].
  """
  support_channel = remove_ref_band(channels[4])
  # Each row is one read; its max recovers the read's support value.
  max_row_values = np.amax(support_channel, axis=1)

  # Rows whose max is zero are empty (no read), so exclude them.
  non_zero_values = max_row_values[max_row_values > 0]
  if non_zero_values.size == 0:
    return 0.0
  # 254 marks a fully supporting read. np.mean of the boolean mask computes
  # the fraction in C; the previous builtin sum() looped in Python.
  return float(np.mean(non_zero_values == 254))
664
665
666
def describe_read_support(channels: List[np.ndarray]) -> ReadSupport:
  """Calculates read support and describes it categorically.

  Computes read support as a fraction and returns a convenient descriptive term
  according to the following thresholds: LOW is [0, 0.3], HALF is (0.3, 0.8],
  and ALL is (0.8, 1].

  Args:
      channels: A list of channels of a DeepVariant pileup image. This only uses
        channels[4], the 'read supports variant' channel.

  Returns:
      A ReadSupport value.
  """
  fraction = fraction_read_support(channels)
  if fraction > 0.8:
    return ReadSupport.ALL
  if fraction > 0.3:
    return ReadSupport.HALF
  return ReadSupport.LOW
687
688
689
def binomial_test(k: int, n: int) -> float:
  """Calculates a two-tailed binomial test with p=0.5, without scipy.

  Since the expected probability is 0.5, it simplifies a few things:
  1) (0.5**x)*(0.5**(n-x)) = (0.5**n)
  2) A two-tailed test is simply doubling when p = 0.5.
  Scipy is much larger than Nucleus, so this avoids adding it as a dependency.

  Args:
    k: Number of "successes", in this case, the number of supporting reads.
    n: Number of "trials", in this case, the total number of reads.

  Returns:
    The p-value for the binomial test.

  Raises:
    ValueError: if k > n.
  """
  if not k <= n:
    raise ValueError('k must be <= n')
  if k == n / 2:
    return 1.0

  # With p=0.5, the distribution is symmetric, allowing this simplification:
  k = min(k, n - k)
  # Sum the exact probabilities for each scenario at least as extreme as k.
  # math.comb keeps the binomial coefficient as an exact integer; the previous
  # factorial-division form went through float division, losing precision and
  # risking overflow for large n.
  sum_of_ps = sum(math.comb(n, x) * (0.5**n) for x in range(0, k + 1))
  return sum_of_ps * 2  # Doubling because it's a two-tailed test.
719
720
721
def pvalue_for_strand_bias(channels: List[np.ndarray]) -> float:
  """Calculates a rough p-value for strand bias in pileup.

  Using the strand and read-supports-variant channels, compares the numbers of
  forward and reverse reads among the supporting reads and returns a p-value
  using a two-tailed binomial test.

  Args:
      channels: List of DeepVariant channels. Uses channels[3] (strand) and
        channels[4] (read support).

  Returns:
      P-value for whether the supporting reads show strand bias.
  """
  strand = remove_ref_band(channels[3])
  # Pixel values 240 and 70 encode forward and reverse strand respectively.
  is_forward = strand == 240
  is_reverse = strand == 70

  # 254 in the support channel marks pixels of supporting reads.
  supports = (remove_ref_band(channels[4]) == 254) * 1.0
  # Taking the row-wise max counts each supporting read exactly once.
  num_forward = int(sum(np.amax(supports * is_forward, axis=1)))
  num_reverse = int(sum(np.amax(supports * is_reverse, axis=1)))

  return binomial_test(k=num_forward, n=num_forward + num_reverse)
749
750
751
def analyze_diff_and_nearby_variants(
    channels: List[np.ndarray]) -> Tuple[float, int]:
  """Analyzes which differences belong to nearby variants and which do not.

  This attempts to identify putative nearby variants from the pileup image
  alone, and then excludes these columns of the pileup to calculate the
  remaining fraction of differences that may be attributed to sequencing errors.

  Args:
      channels: A list of channels of a DeepVariant pileup image. This only uses
        channels[5], the 'differs from ref' channel.

  Returns:
      Two outputs: diff fraction, number of likely nearby variants.
  """
  diff_channel = remove_ref_band(channels[5])

  # Count the number of diff pixels per column (254 marks a difference).
  column_diffs = np.sum(diff_channel == 254, axis=0)
  # Count the number of read bases covering each base position (column).
  column_read_count = np.sum(diff_channel != 0, axis=0)
  # Divide to get the fraction of reads showing a diff at each base (column).
  # Adding 1 here avoids dividing by zero (exact fraction here is not vital).
  fraction = column_diffs * 1.0 / (column_read_count + 1)

  # Columns with more differences could be variants: over 10% of covering
  # reads differ AND at least 5 diff pixels.
  nearby_variant_columns = (fraction > 0.1) * (column_diffs > 4) * 1
  num_potential_nearby_variants = sum(nearby_variant_columns)

  # Exclude potential variants when calculating error fraction. Repeating the
  # per-column flags for every row builds a full-size 0/1 mask.
  nearby_variant_mask = np.array([nearby_variant_columns] *
                                 diff_channel.shape[0])
  mask_to_remove_nearby_variants = 1 - nearby_variant_mask
  non_variant_diffs = (diff_channel == 254) * mask_to_remove_nearby_variants

  # Calculate differences as fraction of the total number of read bases.
  total_read_area = np.sum((diff_channel != 0))
  diff_fraction = 0 if total_read_area == 0 else np.sum(
      non_variant_diffs) / total_read_area
  return diff_fraction, num_potential_nearby_variants
791
792
793
def describe_diff(channels: List[np.ndarray],
                  diff_fraction_threshold: float = 0.01) -> Diff:
  """Describes a pileup image by its diff channel, including nearby variants.

  Returns Diff.MANY_DIFFS if the fraction of differences outside potential
  nearby variants is above the diff_fraction_threshold, which is usually
  indicative of sequencing errors. Otherwise return Diff.NEARBY_VARIANTS if
  there are five or more of these, or Diff.FEW_DIFFS if neither of these
  special cases apply.

  Args:
      channels: A list of channels of a DeepVariant pileup image. This only uses
        channels[5], the 'differs from ref' channel.
      diff_fraction_threshold: Fraction of total bases of all reads that can
        differ, above which the pileup will be designated as 'many_diffs'.
        Differences that appear due to nearby variants (neater columns) do not
        count towards this threshold. The default is set by visual curation of
        Illumina reads, so it may be necessary to increase this for higher-error
        sequencing types.

  Returns:
      One Diff value.
  """
  diff_fraction, nearby_variants = analyze_diff_and_nearby_variants(channels)
  # Thresholds were chosen by visual experimentation, i.e. human curation.
  if diff_fraction > diff_fraction_threshold:
    return Diff.MANY_DIFFS
  if nearby_variants >= 5:
    return Diff.NEARBY_VARIANTS
  return Diff.FEW_DIFFS
824
825
826
def curate_pileup(channels: List[np.ndarray]) -> PileupCuration:
  """Runs all automated curation functions and outputs categorical tags.

  The following values are possible for each descriptor:
  * base_quality: GOOD (<5% low quality) or BAD.
  * mapping_quality: GOOD (<5% low quality) or BAD.
  * strand_biased: BIASED (p-value < 0.05) or GOOD.
  * diff_category: MANY_DIFFS (>1% differences), NEARBY_VARIANTS (5+ variants),
  or FEW_DIFFS otherwise.
  * read_support: LOW (<=30%), HALF (30-80%), ALL (>80%).

  The thresholds were all set by trying to match human curation.

  Args:
      channels: A list of DeepVariant channels.

  Returns:
      A PileupCuration NamedTuple.
  """
  low_basequal = fraction_low_base_quality(channels)
  low_mapq = fraction_reads_with_low_mapq(channels)
  strand_pvalue = pvalue_for_strand_bias(channels)

  return PileupCuration(
      base_quality=(BaseQuality.GOOD
                    if low_basequal < 0.05 else BaseQuality.BAD),
      mapping_quality=(MappingQuality.GOOD
                       if low_mapq < 0.05 else MappingQuality.BAD),
      strand_bias=(StrandBias.BIASED
                   if strand_pvalue < 0.05 else StrandBias.GOOD),
      diff_category=describe_diff(channels),
      read_support=describe_read_support(channels))