a b/genomelake/extractors.py
1
from __future__ import absolute_import, division, print_function
2
import numpy as np
3
import six
4
5
import bcolz
6
from pybedtools import BedTool
7
from pybedtools import Interval
8
from pysam import FastaFile
9
import pyBigWig
10
11
from . import backend
12
from .util import nan_to_zero
13
from .util import one_hot_encode_sequence
14
15
# Size of the one-hot channel axis for encoded DNA sequence: one channel per base.
NUM_SEQ_CHARS = len('ACGT')
16
17
18
class BaseExtractor(object):
    """Base class for extractors that read fixed-width intervals into arrays.

    Calling an extractor with a batch of intervals returns an array whose
    first axis indexes the intervals. Subclasses provide
    ``_get_output_shape`` (the output shape for a batch) and ``_extract``
    (fill a pre-allocated output array in place).
    """

    # dtype of every output array; a caller-supplied ``out`` must match it.
    dtype = np.float32

    def __init__(self, datafile, **kwargs):
        """Store the data source (path or handle) as ``self._datafile``."""
        self._datafile = datafile

    def __call__(self, intervals, out=None, **kwargs):
        """Extract data for a batch of equal-width intervals.

        Args:
          intervals: non-empty sequence (supporting ``len`` and indexing) of
            objects with ``start``/``stop`` attributes, plus whatever the
            subclass reads (e.g. ``chrom``). All must have the same width.
          out: optional pre-allocated array to fill. It must have the shape
            given by ``_get_output_shape`` and dtype ``self.dtype``.
          **kwargs: forwarded to ``_extract``.

        Returns:
          The filled output array (``out`` itself when it was given).

        Raises:
          ValueError: if ``intervals`` is empty, the intervals differ in
            width, or ``out`` has the wrong shape or dtype.
        """
        data = self._check_or_create_output_array(intervals, out)
        self._extract(intervals, data, **kwargs)
        return data

    def _check_or_create_output_array(self, intervals, out):
        """Return ``out`` after validating it, or a new zero-filled array.

        Raises:
          ValueError: if ``intervals`` is empty, the intervals do not all
            share the same width, or ``out`` has the wrong shape or dtype.
        """
        if len(intervals) == 0:
            raise ValueError('intervals must be non-empty')

        width = intervals[0].stop - intervals[0].start
        # Every row of the output has one shared width. Reject mixed-width
        # batches here with a clear message instead of failing with an opaque
        # broadcasting error inside a subclass's _extract.
        for index, interval in enumerate(intervals):
            interval_width = interval.stop - interval.start
            if interval_width != width:
                raise ValueError(
                    'all intervals must have the same width: interval {} '
                    'has width {} (expected {})'.format(
                        index, interval_width, width))

        output_shape = self._get_output_shape(len(intervals), width)

        if out is None:
            return np.zeros(output_shape, dtype=self.dtype)

        if out.shape != output_shape:
            raise ValueError('out array has incorrect shape: {} '
                             '(need {})'.format(out.shape, output_shape))
        if out.dtype != self.dtype:
            raise ValueError('out array has incorrect dtype: {} '
                             '(need {})'.format(out.dtype, self.dtype))
        return out

    def _extract(self, intervals, out, **kwargs):
        'Subclasses should implement this to fill ``out`` in place'
        raise NotImplementedError

    @staticmethod
    def _get_output_shape(num_intervals, width):
        'Subclasses should implement this and return the shape of output'
        raise NotImplementedError
52
53
54
class ArrayExtractor(BaseExtractor):
    """Extractor for per-chromosome arrays loaded via ``backend.load_directory``.

    Each interval is read as ``data[interval.chrom][interval.start:interval.stop]``.
    Per-chromosome arrays may be 1D (output shape ``(num_intervals, width)``)
    or 2D (output shape ``(num_intervals, width, n_cols)``). ``n_cols`` is
    taken from the first array found in the directory.
    """

    def __init__(self, datafile, in_memory=False, **kwargs):
        """Load the array directory.

        Args:
          datafile: directory passed to ``backend.load_directory``.
          in_memory (bool): forwarded to ``backend.load_directory``; also
            stored as ``self.multiprocessing_safe``.

        Raises:
          ValueError: if the directory holds no arrays, or its arrays are
            neither 1D nor 2D.
        """
        super(ArrayExtractor, self).__init__(datafile, **kwargs)
        self._data = backend.load_directory(datafile, in_memory=in_memory)
        # NOTE(review): presumably only in-memory data is safe to share across
        # processes (on-disk handles may not be) -- TODO confirm.
        self.multiprocessing_safe = in_memory

        arrays = iter(self._data.values())
        try:
            arr = next(arrays)
        except StopIteration:
            # An empty directory would otherwise leak a bare StopIteration.
            raise ValueError('No arrays found in {}'.format(datafile))

        shape = arr.shape
        if len(shape) == 1:
            self._trailing_shape = ()
        elif len(shape) == 2:
            self._trailing_shape = (shape[1],)
        else:
            raise ValueError('Can only extract from 1D/2D arrays')

    # These are regular methods rather than closures bound in __init__, so
    # instances hold no local functions and can be pickled (e.g. when handed
    # to worker processes).
    def _mm_extract(self, intervals, out, **kwargs):
        """Copy each interval's slice of its chromosome array into ``out``.

        Raises:
          ValueError: if an interval lies outside its chromosome array.
        """
        mm_data = self._data
        for index, interval in enumerate(intervals):
            chrom_data = mm_data[interval.chrom]
            start, stop = interval.start, interval.stop
            chrom_length = chrom_data.shape[0]
            # A negative start would wrap around to the chromosome end under
            # Python slice semantics, and a stop past the end yields a short
            # slice. Reject both explicitly.
            if start < 0 or stop > chrom_length:
                raise ValueError(
                    'Interval {}:{}-{} is out of bounds for chromosome of '
                    'length {}'.format(interval.chrom, start, stop,
                                       chrom_length))
            out[index] = chrom_data[start:stop]
        return out

    _extract = _mm_extract

    def _get_output_shape(self, num_intervals, width):
        """Return ``(num_intervals, width)`` plus the trailing column axis, if any."""
        return (num_intervals, width) + self._trailing_shape
81
82
83
class FastaExtractor(BaseExtractor):
    """Extractor returning one-hot encoded DNA sequence from a FASTA file.

    Output shape is ``(num_intervals, width, NUM_SEQ_CHARS)``.
    """

    def __init__(self, datafile, use_strand=False, **kwargs):
        """Fasta file extractor

        NOTE: The extractor is not thread-safe.
        If you wish to use it with multiprocessing,
        create a new extractor object in each process.

        Args:
          datafile (str): path to the fasta file (opened with pysam.FastaFile)
          use_strand (bool): if True, the extracted sequence
            is reverse complemented in case interval.strand == "-"
        """
        super(FastaExtractor, self).__init__(datafile, **kwargs)
        self.use_strand = use_strand
        self.fasta = FastaFile(self._datafile)

    def _extract(self, intervals, out, **kwargs):
        """One-hot encode each interval's sequence into ``out[index]``."""
        fetch = self.fasta.fetch
        for index, interval in enumerate(intervals):
            seq = fetch(str(interval.chrom), interval.start, interval.stop)
            one_hot_encode_sequence(seq, out[index, :, :])

            # Reverse-complement sequences on the negative strand: reverse the
            # position axis and flip the channel axis. Flipping channels maps
            # each base to its complement only if one_hot_encode_sequence
            # orders channels so that complements mirror each other
            # (e.g. A, C, G, T) -- TODO confirm.
            if self.use_strand and interval.strand == "-":
                # Copy first so the source view never aliases the destination.
                out[index, :, :] = out[index, ::-1, ::-1].copy()

        return out

    @staticmethod
    def _get_output_shape(num_intervals, width):
        return (num_intervals, width, NUM_SEQ_CHARS)

    def close(self):
        """Close the underlying pysam FastaFile handle."""
        return self.fasta.close()
116
117
118
class BigwigExtractor(BaseExtractor):
    """Extractor returning per-base signal values from a bigWig file.

    Output shape is ``(num_intervals, width)``.
    """

    def __init__(self, datafile, **kwargs):
        """Big-wig file extractor

        NOTE: The extractor is not thread-safe.
        If you wish to use it with multiprocessing,
        create a new extractor object in each process.

        Args:
          datafile: path to the bigwig file
          verbose (bool, keyword, optional): stored as ``self._verbose``
        """
        super(BigwigExtractor, self).__init__(datafile, **kwargs)
        self._verbose = kwargs.get('verbose', False)
        self.bw = pyBigWig.open(datafile)

    def _extract(self, intervals, out, **kwargs):
        """Fill ``out`` with bigwig values for ``intervals``.

        Keyword args are forwarded to ``_bigwig_extractor``
        (e.g. ``nan_as_zero``).
        """
        # Write straight into ``out`` rather than allocating a second array
        # inside _bigwig_extractor and copying it over.
        self._bigwig_extractor(self.bw, intervals, out=out, **kwargs)
        return out

    @staticmethod
    def _get_output_shape(num_intervals, width):
        return (num_intervals, width)

    @staticmethod
    def _bigwig_extractor(bw, intervals, out=None, **kwargs):
        """Read ``bw.values`` for each interval into the rows of ``out``.

        Args:
          bw: an open pyBigWig file handle.
          intervals: intervals with ``chrom``/``start``/``stop`` attributes.
          out: array to fill. If None, a zeroed float32 array of shape
            ``(len(intervals), width of intervals[0])`` is allocated.
          nan_as_zero (bool, keyword, default True): if true, each row is
            passed through ``nan_to_zero`` after reading.

        Returns:
          The filled array.
        """
        nan_as_zero = kwargs.get('nan_as_zero', True)
        if out is None:
            width = intervals[0].stop - intervals[0].start
            out = np.zeros((len(intervals), width), dtype=np.float32)

        for index, interval in enumerate(intervals):
            out[index] = bw.values(
                interval.chrom, interval.start, interval.stop)
            if nan_as_zero:
                nan_to_zero(out[index])
        return out

    def __del__(self):
        # __init__ may have failed before ``bw`` was assigned (e.g. if
        # pyBigWig.open raised); don't raise from the finalizer in that case.
        if getattr(self, 'bw', None) is not None:
            self.close()

    def close(self):
        """Close the underlying pyBigWig handle, if one was opened."""
        bw = getattr(self, 'bw', None)
        if bw is not None:
            return bw.close()
        return None