Switch to unified view

a b/src/janggu/data/visualization.py
1
"""Genomic track visualization utilities."""
2
import warnings
3
from itertools import product
4
5
import matplotlib.pyplot as plt
6
import numpy as np
7
import seaborn as sns
8
9
from janggu.utils import NMAP
10
from janggu.utils import PMAP
11
from janggu.utils import _to_list
12
13
14
def plotGenomeTrack(tracks, chrom, start, end, figsize=(10, 5), plottypes=None):
    """Plot a genome browser-like view of tracks for a genomic interval.

    It takes one or more cover or track objects as well as a genomic interval
    consisting of chromosome name, start and end and creates
    a genome browser-like plot.

    Parameters
    ----------
    tracks : janggu.data.Cover, list(Cover), janggu.data.Track or list(Track)
        One or more track objects. Objects that are not Track instances are
        wrapped into a Track according to `plottypes` (deprecated usage,
        a FutureWarning is issued).
    chrom : str
        Chromosome name.
    start : int
        The start of the required interval.
    end : int
        The end of the required interval.
    figsize : tuple(int, int)
        Figure size passed on to matplotlib.
    plottypes : None or list(str)
        Only consulted if at least one element of `tracks` is not a Track.
        Indicates whether to plot a coverage object as 'line', 'heatmap'
        or 'seqplot'. Any other value is treated like 'line'.
        If plottypes=None, all coverage objects are depicted as line plots.
        Otherwise, a list must be supplied containing one plot type per
        element of `tracks`. For example, ['line', 'heatmap', 'seqplot'].
        While 'line' and 'heatmap' can be used for any type of coverage data,
        'seqplot' is reserved to plot sequence influence on the output. It is
        intended to be used in conjunction with 'input_attribution' method which
        determines the importance of particular sequence letters for the
        output prediction.

    Returns
    -------
    matplotlib Figure
        A matplotlib figure illustrating the genome browser-view of the coverage
        objects for the given interval.
        To depict and save the figure the native matplotlib functions show()
        and savefig() can be used.

    Raises
    ------
    AssertionError
        If non-Track objects are supplied and the number of plot types does
        not match the number of tracks.
    """

    tracks = _to_list(tracks)

    if any(not isinstance(track, Track) for track in tracks):
        # Warn once per call rather than once per converted object.
        warnings.warn('Convert the Dataset object to proper Track objects.'
                      ' In the future, only Track objects will be supported.',
                      FutureWarning)
        if plottypes is None:
            plottypes = ['line'] * len(tracks)

        # Raised explicitly (instead of using an assert statement) so that the
        # validation is not stripped when Python runs with -O. The exception
        # type is kept for backward compatibility.
        if len(plottypes) != len(tracks):
            raise AssertionError(
                "The number of cover objects must be the same as the number of plottyes.")

    def _convert_to_track(cover, plottype):
        """Wrap a coverage object into the Track matching plottype."""
        if plottype == 'heatmap':
            return HeatTrack(cover)
        if plottype == 'seqplot':
            return SeqTrack(cover)
        return LineTrack(cover)

    tracks = [track if isinstance(track, Track)
              else _convert_to_track(track, plottypes[itrack])
              for itrack, track in enumerate(tracks)]

    # Grid layout: two header rows, the rows requested by each track and
    # one spacer row between consecutive tracks.
    headertrack = 2
    trackheights = sum(track.height for track in tracks)
    spacer = len(tracks) - 1

    grid = plt.GridSpec(headertrack + trackheights + spacer,
                        10, wspace=0.4, hspace=0.3)
    fig = plt.figure(figsize=figsize)

    # title and reference track; configured through the axes object itself
    # instead of relying on pyplot's implicit "current axes".
    title = fig.add_subplot(grid[0, 1:])

    title.set_title(chrom)
    title.set_xlim([0, end - start])
    for side in ('right', 'top', 'left'):
        title.spines[side].set_visible(False)
    title.set_xticks([0, end - start])
    title.set_xticklabels([start, end])
    title.set_yticks(())

    # y_offset is advanced by one before every track: this skips the second
    # header row for the first track and leaves a spacer row before the others.
    y_offset = 1
    for track in tracks:
        y_offset += 1

        track.add_side_bar(fig, grid, y_offset)
        track.plot(fig, grid, y_offset, chrom, start, end)
        y_offset += track.height

    return fig
118
119
120
class Track(object):
    """Base class for tracks shown in a genome browser-like plot.

    Parameters
    ----------

    data : Cover object
        Coverage object
    height : int
        Track height.
    """
    def __init__(self, data, height):
        self.data = data
        self.height = height

    @property
    def name(self):
        """Name of the track, taken from the wrapped data object."""
        return self.data.name

    def add_side_bar(self, fig, grid, offset):
        """Add the side-bar axis that labels the track with its name."""
        # The side bar occupies the first grid column over the track's rows.
        rows = slice(offset, offset + self.height)
        sidebar = fig.add_subplot(grid[rows, 0])

        sidebar.set_xticks(())
        for side in ('right', 'top', 'bottom'):
            sidebar.spines[side].set_visible(False)
        sidebar.set_yticks([0.5])
        sidebar.set_yticklabels([self.name])

    def get_track_axis(self, fig, grid, offset, height):
        """Create and return the plotting axis right of the side bar."""
        rows = slice(offset, offset + height)
        return fig.add_subplot(grid[rows, 1:])

    def get_data(self, chrom, start, end):
        """Return the data of the interval with the leading axis dropped."""
        region = self.data[chrom, start, end]
        return region[0, :, :, :]
159
160
161
class LineTrack(Track):
    """Line track

    Draws genomic data as line plots, one sub-axis per condition.

    Parameters
    ----------

    data : Cover object
        Coverage object
    height : int
        Height of the sub-axis of a single condition. Default=3
    linestyle : str
        Linestyle for plot
    marker : str
        Marker code for plot. Only used if the strands are not
        plotted separately.
    color : str
        Color code for plot
    linewidth : float
        Line width.
    """
    def __init__(self, data, height=3, linestyle='-', marker='o', color='b',
                 linewidth=2):
        super(LineTrack, self).__init__(data, height)
        # The total track height grows with the number of conditions.
        self.height = height * len(data.conditions)
        self.color = color
        self.marker = marker
        self.linestyle = linestyle
        self.linewidth = linewidth

    def plot(self, fig, grid, offset, chrom, start, end):
        """Draw one line plot per condition into consecutive grid rows."""
        coverage = self.get_data(chrom, start, end)
        conditions = self.data.conditions
        rows_per_condition = self.height // len(conditions)

        def _finite_points(values):
            # keep only the positions holding finite values
            positions = np.nonzero(np.isfinite(values))[0]
            return positions, values[positions]

        # (legend label, marker) for the first and second strand entry
        strand_styles = (("+", '+'), ("-", 1))

        for icond, condition in enumerate(conditions):
            axis = self.get_track_axis(fig, grid,
                                       offset + icond * rows_per_condition,
                                       rows_per_condition)
            if coverage.shape[1] == 2:
                # both strands are covered separately
                for istrand, (label, marker) in enumerate(strand_styles):
                    xs, ys = _finite_points(coverage[:, istrand, icond])
                    axis.plot(xs, ys,
                              linewidth=self.linewidth,
                              linestyle=self.linestyle,
                              color=self.color, label=label, marker=marker)
                axis.legend()
            else:
                xs, ys = _finite_points(coverage[:, 0, icond])
                axis.plot(xs, ys,
                          linewidth=self.linewidth,
                          color=self.color,
                          linestyle=self.linestyle,
                          marker=self.marker)

            axis.set_yticks(())
            axis.set_xticks(())
            axis.set_xlim([0, end - start])
            if len(conditions) > 1:
                axis.set_ylabel(condition, labelpad=12)
            for side in ('right', 'top'):
                axis.spines[side].set_visible(False)
231
232
233
class SeqTrack(Track):
    """Sequence Track

    Visualizes sequence importance as stacked bars, one bar segment
    per alphabet letter.

    Parameters
    ----------

    data : Cover object
        Coverage object
    height : int
        Track height. Default=3
    """
    def __init__(self, data, height=3):
        super(SeqTrack, self).__init__(data, height)

    def plot(self, fig, grid, offset, chrom, start, end):
        """Plot sequence track

        Raises
        ------
        ValueError
            If the number of conditions is not divisible by the alphabet size
            or if a condition name does not start with an alphabet letter.
        """

        # Choose the alphabet whose size divides the number of conditions;
        # NMAP is tried first, then PMAP.
        if len(self.data.conditions) % len(NMAP) == 0:
            alphabetsize = len(NMAP)
            MAP = NMAP
        elif len(self.data.conditions) % len(PMAP) == 0:  # pragma: no cover
            alphabetsize = len(PMAP)
            MAP = PMAP
        else:  # pragma: no cover
            raise ValueError(
                "Coverage tracks seems not represent biological sequences. "
                "The last dimension must be divisible by the alphabetsize.")

        for cond in self.data.conditions:
            if cond[0] not in MAP:
                raise ValueError(
                    "Coverage tracks seems not represent biological sequences. "
                    "Condition names must represent the alphabet letters.")

        coverage = self.get_data(chrom, start, end)
        # project higher-order sequence structure onto original sequence:
        # flatten everything but the position axis, split it into
        # (alphabetsize, remainder) and sum the remainder away.
        coverage = coverage.reshape(coverage.shape[0], -1)
        coverage = coverage.reshape(coverage.shape[:-1] +
                                    (alphabetsize,
                                     coverage.shape[-1] // alphabetsize))
        coverage = coverage.sum(-1)

        ax = self.get_track_axis(fig, grid, offset, self.height)
        x = np.arange(coverage.shape[0])
        y_figure_offset = np.zeros(coverage.shape[0])
        # The palette does not depend on the letter, so build it only once.
        palette = sns.color_palette("hls", len(MAP))
        handles = []
        for letter in MAP:
            column = MAP[letter]
            handles.append(ax.bar(x, coverage[:, column],
                                  bottom=y_figure_offset,
                                  color=palette[column],
                                  label=letter))
            # stack the next letter on top of the bars drawn so far
            y_figure_offset += coverage[:, column]
        ax.legend(handles=handles)
        ax.set_yticklabels(())
        ax.set_yticks(())
        ax.set_xticks(())
        ax.set_xlim([0, end-start])
293
294
295
class HeatTrack(Track):
    """Heatmap Track

    Visualizes genomic data as heatmap.

    Parameters
    ----------

    data : Cover object
        Coverage object
    height : int
        Track height. Default=3
    """
    def __init__(self, data, height=3):
        super(HeatTrack, self).__init__(data, height)

    def plot(self, fig, grid, offset, chrom, start, end):
        """Plot heatmap track."""
        ax = self.get_track_axis(fig, grid, offset, self.height)
        coverage = self.get_data(chrom, start, end)

        # Positions run along the x-axis; all remaining axes are flattened
        # into the heatmap rows.
        ax.pcolor(coverage.reshape(coverage.shape[0], -1).T)

        if coverage.shape[-2] == 2:
            # Same order as the flattened rows: all conditions of the first
            # strand entry, followed by all conditions of the second.
            ticks = [':'.join([cond, strand]) for strand, cond
                     in product(['+', '-'], self.data.conditions)]
        else:
            ticks = list(self.data.conditions)

        # Fix the tick positions before assigning the labels and use exactly
        # one tick per label. Row i of the heatmap spans [i, i + 1] on the
        # y-axis, so i + 0.5 centers the label on its row.
        ax.set_yticks(np.arange(len(ticks)) + 0.5)
        ax.set_yticklabels(ticks)
        ax.set_xticks(())
        ax.set_xlim([0, end-start])