# Genomic Analysis Using ADAM, Spark and Deep Learning

Can we use deep learning to predict which population group you belong to, based solely on your genome?

Yes, we can - and in this post, we will show you exactly how to do this in a scalable way, using Apache Spark. We
will explain how to apply [deep learning](https://en.wikipedia.org/wiki/Deep_learning) using
[artificial neural networks](https://en.wikipedia.org/wiki/Artificial_neural_network) to predict which population
group an individual belongs to - based entirely on his or her genomic data.

This is a follow-up to an earlier post,
[Scalable Genomes Clustering With ADAM and Spark](http://bdgenomics.org/blog/2015/02/02/scalable-genomes-clustering-with-adam-and-spark/),
and attempts to replicate the results of that post using a different machine learning technique: where the original
post used [k-means clustering](https://en.wikipedia.org/wiki/K-means_clustering), we will use deep learning.

We will use [ADAM](https://github.com/bigdatagenomics/adam) and [Apache Spark](https://spark.apache.org/) in
combination with [H2O](http://0xdata.com/product/), an open source predictive analytics platform, and
[Sparkling Water](http://0xdata.com/product/sparkling-water/), which integrates H2O with Spark.
## Code

In this section, we'll dive straight into the code. If you'd rather get something working before looking at the code,
you can skip to the "Building and Running" section.

The complete Scala code for this example can be found in
[the PopStrat.scala class on GitHub](https://github.com/nfergu/popstrat/blob/master/src/main/scala/com/neilferguson/PopStrat.scala),
and we'll refer to sections of it here. Basic familiarity with Scala and
[Apache Spark](https://spark.apache.org/) is assumed.
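For reference, the imports that the rest of this walkthrough assumes at the top of the class look roughly like the sketch below. The exact package paths depend on the ADAM, Spark and Sparkling Water versions you build against, so treat this as an outline rather than a verbatim copy of PopStrat.scala; in particular, the H2O model-building classes (DeepLearning, DeepLearningParameters, FrameSplitter) are omitted here because their packages have moved between H2O releases.

```scala
// Standard library: reading the panel file, Java<->Scala collection conversions,
// and the Range.inclusive helper used later when filtering variant frequencies
import scala.collection.JavaConverters._
import scala.collection.immutable.Range.inclusive
import scala.io.Source

// Spark core and Spark SQL types used to build the tabular training data
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// ADAM: adds loadGenotypes to the SparkContext, plus the Genotype Avro records
import org.bdgenomics.adam.rdd.ADAMContext._
import org.bdgenomics.formats.avro.{Genotype, GenotypeAllele}

// Sparkling Water: the H2OContext used to turn Spark data into H2O frames
import org.apache.spark.h2o.H2OContext
import water.Key
```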
### Setting Up

The first thing we need to do is read the names of the Genotype and Panel files that are passed into our program.
The Genotype file contains data about a set of individuals (referred to here as "samples") and their genetic
variation. The Panel file lists the population group (or "region") for each sample in the Genotype file; this is
what we will try to predict.

```scala
val genotypeFile = args(0)
val panelFile = args(1)
```

Next, we set up our Spark Context. Our program permits the Spark master to be specified as one of its arguments.
This is useful when running from an IDE, but it is omitted when running via the ```spark-submit``` script (see below).

```scala
val master = if (args.length > 2) Some(args(2)) else None
val conf = new SparkConf().setAppName("PopStrat")
master.foreach(conf.setMaster)
val sc = new SparkContext(conf)
```

Next, we declare a set called ```populations``` which contains all of the population groups that we're interested
in predicting. We then read the Panel file into a Map, filtering it based on the population groups in the
```populations``` set. The format of the panel file is described [here](http://www.1000genomes.org/faq/what-panel-file).
Luckily, it's very simple, containing the sample ID in the first column and the population group in the second.
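To make the parsing code below concrete, the first few lines of a panel file look something like this (tab-separated; the samples shown here are just illustrative, and only the first two columns are used by our code):

```
HG00096	GBR	EUR	male
NA18526	CHB	EAS	female
NA19625	ASW	AFR	female
```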
```scala
val populations = Set("GBR", "ASW", "CHB")
def extract(file: String, filter: (String, String) => Boolean): Map[String,String] = {
  Source.fromFile(file).getLines().map(line => {
    val tokens = line.split("\t").toList
    tokens(0) -> tokens(1)
  }).toMap.filter(tuple => filter(tuple._1, tuple._2))
}
val panel: Map[String,String] = extract(panelFile, (sampleID: String, pop: String) => populations.contains(pop))
```

### Preparing the Genomics Data

Next, we use [ADAM](https://github.com/bigdatagenomics/adam) to read our genotype data into a Spark RDD. Since we've
imported ```ADAMContext._``` at the top of our class, this is simply a matter of calling `loadGenotypes` on the
Spark Context. Then, we filter the genotype data to contain only samples that are in the population groups which we're
interested in.
```scala
val allGenotypes: RDD[Genotype] = sc.loadGenotypes(genotypeFile)
val genotypes: RDD[Genotype] = allGenotypes.filter(genotype => {panel.contains(genotype.getSampleId)})
```

Next, we convert the ADAM ```Genotype``` objects into our own ```SampleVariant``` objects. These objects contain just the data
we need for further processing: the sample ID (which uniquely identifies a particular sample), a variant ID (which
uniquely identifies a particular genetic variant), and a count of alternate
[alleles](http://www.snpedia.com/index.php/Allele) - the places where the sample differs from the
reference genome. These variations will help us to classify individuals according to their population group.

```scala
case class SampleVariant(sampleId: String, variantId: Int, alternateCount: Int)
def variantId(genotype: Genotype): String = {
  val name = genotype.getVariant.getContig.getContigName
  val start = genotype.getVariant.getStart
  val end = genotype.getVariant.getEnd
  s"$name:$start:$end"
}
def alternateCount(genotype: Genotype): Int = {
  genotype.getAlleles.asScala.count(_ != GenotypeAllele.Ref)
}
def toVariant(genotype: Genotype): SampleVariant = {
  // Intern sample IDs as they will be repeated a lot
  new SampleVariant(genotype.getSampleId.intern(), variantId(genotype).hashCode(), alternateCount(genotype))
}
val variantsRDD: RDD[SampleVariant] = genotypes.map(toVariant)
```
Next, we count the total number of samples (individuals) in the data. We then group the data by variant ID and filter
out those variants which do not appear in all of the samples. The aim of this is to simplify processing; since the
data contains a very large number of variants (up to 30 million, depending on the exact data set), filtering out a
small number of them will not make a significant difference to the results. In fact, in the next step we'll reduce the
number of variants even further.

```scala
val variantsBySampleId: RDD[(String, Iterable[SampleVariant])] = variantsRDD.groupBy(_.sampleId)
val sampleCount: Long = variantsBySampleId.count()
println("Found " + sampleCount + " samples")
val variantsByVariantId: RDD[(Int, Iterable[SampleVariant])] = variantsRDD.groupBy(_.variantId).filter {
  case (_, sampleVariants) => sampleVariants.size == sampleCount
}
```
When we train our machine learning model, each variant will be treated as a
"[feature](https://en.wikipedia.org/wiki/Feature_(machine_learning))" that is used to train the model.
Since it can be difficult to train machine learning models with very large numbers of features in the data
(particularly if the number of samples is relatively small), we first need to try to reduce the number of variants
in the data.

To do this, we first compute the frequency with which alternate alleles have occurred for each variant. We then
filter the variants down to just those that appear within a certain frequency range. In this case, we've chosen a
fairly arbitrary frequency of exactly 11; this value was found through experimentation to leave around 3,000 variants
in the data set we are using.

There are more structured approaches to
[dimensionality reduction](https://en.wikipedia.org/wiki/Dimensionality_reduction), which we perhaps could have
employed, but this technique seems to work well enough for this example.

```scala
val variantFrequencies: collection.Map[Int, Int] = variantsByVariantId.map {
  case (variantId, sampleVariants) => (variantId, sampleVariants.count(_.alternateCount > 0))
}.collectAsMap()
val permittedRange = inclusive(11, 11)
val filteredVariantsBySampleId: RDD[(String, Iterable[SampleVariant])] = variantsBySampleId.map {
  case (sampleId, sampleVariants) =>
    val filteredSampleVariants = sampleVariants.filter(variant => permittedRange.contains(
      variantFrequencies.getOrElse(variant.variantId, -1)))
    (sampleId, filteredSampleVariants)
}
```
### Creating the Training Data

To train our model, we need our data to be in tabular form where each row represents a single sample, and each
column represents a specific variant. The table also contains a column for the population group or "Region", which is
what we are trying to predict.
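Concretely, we are aiming for a shape like the following (purely illustrative - the real column names are the hashed variant IDs produced above, there are around 3,000 of them, and the cell values are alternate allele counts):

```
Region   371672154   908469138   1247583665   ...
GBR              0           1            2   ...
ASW              1           0            2   ...
CHB              0           2            1   ...
```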
Ultimately, in order for our data to be consumed by H2O, we need it to end up in an H2O `DataFrame` object. Currently,
the best way to do this in Spark seems to be to convert our data to an RDD of Spark SQL
[Row](http://spark.apache.org/docs/1.4.0/api/scala/index.html#org.apache.spark.sql.Row) objects, which can then
be converted automatically to an H2O DataFrame.

To achieve this, we first group the data by sample ID, and then sort the variants for each sample in a
consistent manner (by variant ID). We can then create a header row for our table, containing the Region column
followed by a column for each variant, and build an RDD containing one `Row` per sample.

```scala
val sortedVariantsBySampleId: RDD[(String, Array[SampleVariant])] = filteredVariantsBySampleId.map {
  case (sampleId, variants) =>
    (sampleId, variants.toArray.sortBy(_.variantId))
}
val header = StructType(Array(StructField("Region", StringType)) ++
  sortedVariantsBySampleId.first()._2.map(variant => {StructField(variant.variantId.toString, IntegerType)}))
val rowRDD: RDD[Row] = sortedVariantsBySampleId.map {
  case (sampleId, sortedVariants) =>
    val region: Array[String] = Array(panel.getOrElse(sampleId, "Unknown"))
    val alternateCounts: Array[Int] = sortedVariants.map(_.alternateCount)
    Row.fromSeq(region ++ alternateCounts)
}
```

As mentioned above, once we have our RDD of `Row` objects we can convert these automatically to an H2O
DataFrame using Sparkling Water (H2O's Spark integration).
```scala
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
val schemaRDD = sqlContext.applySchema(rowRDD, header)
val h2oContext = new H2OContext(sc).start()
import h2oContext._
val dataFrame = h2oContext.toDataFrame(schemaRDD)
```

Now that we have a DataFrame, we want to split it into the training data (which we'll use to train our model) and a
[test set](https://en.wikipedia.org/wiki/Test_set) (which we'll use to ensure that
[overfitting](https://en.wikipedia.org/wiki/Overfitting) has not occurred).

We will also create a "validation" set, which serves a similar purpose to the test set, in that it will be used to
validate the strength of our model as it is being built, while avoiding overfitting. However, when training a neural
network, we typically keep the validation set distinct from the test set, to enable us to learn
[hyper-parameters](http://colinraffel.com/wiki/neural_network_hyperparameters) for the model.
See [chapter 3 of Michael Nielsen's "Neural Networks and Deep Learning"](http://neuralnetworksanddeeplearning.com/chap3.html)
for more details on this.

H2O comes with a class called `FrameSplitter`, so splitting the data is simply a matter of creating one of those and
letting it split the data set. Here we split the data 50/30/20 into training, test, and validation sets (the final
ratio is implied by the first two).

```scala
val frameSplitter = new FrameSplitter(dataFrame, Array(.5, .3), Array("training", "test", "validation").map(Key.make), null)
water.H2O.submitTask(frameSplitter)
val splits = frameSplitter.getResult
val training = splits(0)
val validation = splits(2)
```
### Training the Model

Next, we need to set the parameters for our deep learning model. We specify the training and validation data sets,
as well as the column in the data which contains the item we are trying to predict (in this case, the Region).
We also set some [hyper-parameters](http://colinraffel.com/wiki/neural_network_hyperparameters) which affect the way
the model learns. We won't go into detail about these here, but you can read more in the
[H2O documentation](http://docs.h2o.ai/h2oclassic/datascience/deeplearning.html). These parameters have been
chosen through experimentation - however, H2O provides methods for
[automatically tuning hyper-parameters](http://learn.h2o.ai/content/hands-on_training/deep_learning.html) so
it may be possible to achieve better results by employing one of these methods.

```scala
val deepLearningParameters = new DeepLearningParameters()
deepLearningParameters._train = training
deepLearningParameters._valid = validation
deepLearningParameters._response_column = "Region"
deepLearningParameters._epochs = 10
deepLearningParameters._activation = Activation.RectifierWithDropout
deepLearningParameters._hidden = Array[Int](100,100)
```
Finally, we're ready to train our deep learning model! Now that we've set everything up, this is easy:
we simply create an H2O `DeepLearning` object and call `trainModel` on it.

```scala
val deepLearning = new DeepLearning(deepLearningParameters)
val deepLearningModel = deepLearning.trainModel.get
```

Having trained our model in the previous step, we now need to check how well it predicts the population
groups in our data set. To do this we "score" our entire data set (including training, test, and validation data)
against our model:

```scala
deepLearningModel.score(dataFrame)('predict)
```
This final step will print a [confusion matrix](https://en.wikipedia.org/wiki/Confusion_matrix) which shows how
well our model predicts our population groups. All being well, the confusion matrix should look something like this:

```
Confusion Matrix (vertical: actual; across: predicted):
       ASW CHB GBR  Error      Rate
   ASW  60   1   0 0.0164 =  1 / 61
   CHB   0 103   0 0.0000 = 0 / 103
   GBR   0   1  90 0.0110 =  1 / 91
Totals  60 105  90 0.0078 = 2 / 255
```

This tells us that the model has correctly predicted the population group for 253 out of 255 samples (an accuracy of
more than 99%). Nice!
## Building and Running

### Prerequisites

Before building and running the example, please ensure you have version 7 or later of the
[Java JDK](http://www.oracle.com/technetwork/java/javase/downloads/index.html) installed.
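If you are unsure which version you have, the following prints it:

```
java -version
```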
### Building

To build the example, first clone the GitHub repo at [https://github.com/nfergu/popstrat](https://github.com/nfergu/popstrat).

Next, [download and install Maven](http://maven.apache.org/download.cgi). Then, at the command line, type:

```
mvn clean package
```

This will build a JAR (`target/uber-popstrat-0.1-SNAPSHOT.jar`) containing the `PopStrat` class,
as well as all of its dependencies.
281
282
### Running
283
284
First, [download Spark version 1.2.0](http://spark.apache.org/downloads.html) and unpack it on your machine.
285
286
Next you'll need to get some genomics data. Go to your
287
[nearest mirror of the 1000 genomes FTP site](http://www.1000genomes.org/data#DataAccess).
288
From the `release/20130502/` directory download
289
the `ALL.chr22.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf.gz` file and
290
the `integrated_call_samples_v3.20130502.ALL.panel` file. The first file file is the genotype data for chromosome 22,
291
and the second file is the panel file, which describes the population group for each sample in the genotype data.
292
293
Unzip the genotype data before continuing. This will require around 10GB of disk space.
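For example, on Linux or OS X (assuming the downloaded file is in your current directory):

```
gunzip ALL.chr22.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf.gz
```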
To speed up execution and save disk space, you can convert the genotype VCF file to [ADAM](https://github.com/bigdatagenomics/adam)
format (using the ADAM `transform` command) if you wish. However,
this will take some time up-front. Both ADAM and VCF formats are supported.

Next, run the following command:

```
YOUR_SPARK_HOME/bin/spark-submit --class "com.neilferguson.PopStrat" --master local[6] --driver-memory 6G target/uber-popstrat-0.1-SNAPSHOT.jar <genotypesfile> <panelfile>
```

Replace `<genotypesfile>` with the path to your genotype data file (ADAM or VCF), and `<panelfile>` with the panel file
from 1000 Genomes.

This runs the example using a local (in-process) Spark master with 6 cores and 6GB of RAM. You can run against a different
Spark cluster by modifying the options in the above command line. See the
[Spark documentation](https://spark.apache.org/docs/1.2.0/submitting-applications.html) for further details.
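For example, a sketch of submitting to a standalone Spark cluster (the master URL, memory, and core counts here are placeholders - adjust them for your cluster):

```
YOUR_SPARK_HOME/bin/spark-submit --class "com.neilferguson.PopStrat" --master spark://your-master-host:7077 --executor-memory 6G --total-executor-cores 24 target/uber-popstrat-0.1-SNAPSHOT.jar <genotypesfile> <panelfile>
```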
Using the above data, the example may take up to 2-3 hours to run, depending on hardware. When it is finished, you should
see a [confusion matrix](http://en.wikipedia.org/wiki/Confusion_matrix) which shows the predicted versus the actual
populations. If all has gone well, this should show an accuracy of more than 99%.
See the "Code" section above for more details on what exactly you should expect to see.

## Conclusion

In this post, we have shown how to combine ADAM and Apache Spark with H2O's deep learning capabilities to predict
an individual's population group based on his or her genomic data. Our results demonstrate that we can predict these
groups very well, with more than 99% accuracy. Our choice of technologies makes for a relatively straightforward
implementation, and we expect it to be very scalable.

Future work could involve validating the scalability of our solution on more hardware, trying to predict a wider
range of population groups (currently we only predict 3 groups), and tuning the deep learning hyper-parameters to
achieve even better accuracy.