# Genomic Analysis Using ADAM, Spark and Deep Learning

Can we use deep learning to predict which population group you belong to, based solely on your genome?

Yes, we can - and in this post, we will show you exactly how to do this in a scalable way, using Apache Spark. We
will explain how to apply [deep learning](https://en.wikipedia.org/wiki/Deep_learning) using
[artificial neural networks](https://en.wikipedia.org/wiki/Artificial_neural_network) to predict which population group
an individual belongs to - based entirely on his or her genomic data.

This is a follow-up to an earlier post,
[Scalable Genomes Clustering With ADAM and Spark](http://bdgenomics.org/blog/2015/02/02/scalable-genomes-clustering-with-adam-and-spark/),
and attempts to replicate the results of that post. However, we will use a different machine learning technique:
where the original post used [k-means clustering](https://en.wikipedia.org/wiki/K-means_clustering), we will use
deep learning.

We will use [ADAM](https://github.com/bigdatagenomics/adam) and [Apache Spark](https://spark.apache.org/) in
combination with [H2O](http://0xdata.com/product/), an open source predictive analytics platform, and
[Sparkling Water](http://0xdata.com/product/sparkling-water/), which integrates H2O with Spark.

## Code

In this section, we'll dive straight into the code. If you'd rather get something working before looking at the code,
you can skip to the "Building and Running" section.

The complete Scala code for this example can be found in
[the PopStrat.scala class on GitHub](https://github.com/nfergu/popstrat/blob/master/src/main/scala/com/neilferguson/PopStrat.scala),
and we'll refer to sections of the code here. Basic familiarity with Scala and
[Apache Spark](https://spark.apache.org/) is assumed.
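
The snippets below omit the surrounding class and its import statements. Roughly speaking - and this is only a partial
sketch, since the exact packages (particularly the Spark SQL and H2O/Sparkling Water ones) vary between versions - they
assume imports along these lines; the authoritative list is at the top of
[PopStrat.scala](https://github.com/nfergu/popstrat/blob/master/src/main/scala/com/neilferguson/PopStrat.scala):

```scala
// Partial sketch of the imports assumed by the snippets below; see PopStrat.scala for the full, exact list.
import scala.io.Source
import scala.collection.JavaConverters._             // for genotype.getAlleles.asScala
import scala.collection.immutable.Range.inclusive    // for the inclusive(11, 11) frequency range used later
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD
import org.bdgenomics.adam.rdd.ADAMContext._         // adds loadGenotypes to the SparkContext
import org.bdgenomics.formats.avro.{Genotype, GenotypeAllele}
```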

### Setting Up

The first thing we need to do is to read the names of the Genotype and Panel files that are passed into our program.
The Genotype file contains data about a set of individuals (referred to here as "samples") and their genetic
variation. The Panel file lists the population group (or "region") for each sample in the Genotype file; this is
what we will try to predict.

```scala
val genotypeFile = args(0)
val panelFile = args(1)
```

Next, we set up our Spark Context. Our program permits the Spark master to be specified as one of its arguments.
This is useful when running from an IDE, but is omitted when running from the ```spark-submit``` script (see below).

```scala
val master = if (args.length > 2) Some(args(2)) else None
val conf = new SparkConf().setAppName("PopStrat")
master.foreach(conf.setMaster)
val sc = new SparkContext(conf)
```
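
For instance, when running the class directly from an IDE, the optional third argument might be supplied along these
lines (the file paths here are purely placeholders - substitute your own):

```scala
// Hypothetical invocation from an IDE or test harness, passing a local Spark master as the third argument.
PopStrat.main(Array(
  "/data/ALL.chr22.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf",
  "/data/integrated_call_samples_v3.20130502.ALL.panel",
  "local[4]"))
```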

Next, we declare a set called ```populations``` which contains all of the population groups that we're interested
in predicting. We then read the Panel file into a Map, filtering it based on the population groups in the
```populations``` set. The format of the panel file is described [here](http://www.1000genomes.org/faq/what-panel-file).
Luckily it's very simple, containing the sample ID in the first column and the population group in the second.

```scala
val populations = Set("GBR", "ASW", "CHB")
def extract(file: String, filter: (String, String) => Boolean): Map[String,String] = {
  Source.fromFile(file).getLines().map(line => {
    val tokens = line.split("\t").toList
    tokens(0) -> tokens(1)
  }).toMap.filter(tuple => filter(tuple._1, tuple._2))
}
val panel: Map[String,String] = extract(panelFile, (sampleID: String, pop: String) => populations.contains(pop))
```

### Preparing the Genomics Data

Next, we use [ADAM](https://github.com/bigdatagenomics/adam) to read our genotype data into a Spark RDD. Since we've
imported ```ADAMContext._``` at the top of our class, this is simply a matter of calling `loadGenotypes` on the
Spark Context. Then, we filter the genotype data to contain only samples that are in the population groups which we're
interested in.

```scala
val allGenotypes: RDD[Genotype] = sc.loadGenotypes(genotypeFile)
val genotypes: RDD[Genotype] = allGenotypes.filter(genotype => {panel.contains(genotype.getSampleId)})
```

Next, we convert the ADAM ```Genotype``` objects into our own ```SampleVariant``` objects. These objects contain just the data
we need for further processing: the sample ID (which uniquely identifies a particular sample), a variant ID (which
uniquely identifies a particular genetic variant) and a count of alternate
[alleles](http://www.snpedia.com/index.php/Allele), where the sample differs from the
reference genome. These variations will help us to classify individuals according to their population group.

```scala
case class SampleVariant(sampleId: String, variantId: Int, alternateCount: Int)
def variantId(genotype: Genotype): String = {
  val name = genotype.getVariant.getContig.getContigName
  val start = genotype.getVariant.getStart
  val end = genotype.getVariant.getEnd
  s"$name:$start:$end"
}
def alternateCount(genotype: Genotype): Int = {
  genotype.getAlleles.asScala.count(_ != GenotypeAllele.Ref)
}
def toVariant(genotype: Genotype): SampleVariant = {
  // Intern sample IDs as they will be repeated a lot
  new SampleVariant(genotype.getSampleId.intern(), variantId(genotype).hashCode(), alternateCount(genotype))
}
val variantsRDD: RDD[SampleVariant] = genotypes.map(toVariant)
```

Next, we count the total number of samples (individuals) in the data. We then group the data by variant ID and filter
out those variants which do not appear in all of the samples. The aim of this is to simplify the processing of the data
and, since we have a very large number of variants in the data (up to 30 million, depending on the exact data set),
filtering out a small number will not make a significant difference to the results. In fact, in the next step we'll
reduce the number of variants even further.

```scala
val variantsBySampleId: RDD[(String, Iterable[SampleVariant])] = variantsRDD.groupBy(_.sampleId)
val sampleCount: Long = variantsBySampleId.count()
println("Found " + sampleCount + " samples")
val variantsByVariantId: RDD[(Int, Iterable[SampleVariant])] = variantsRDD.groupBy(_.variantId).filter {
  case (_, sampleVariants) => sampleVariants.size == sampleCount
}
```

When we train our machine learning model, each variant will be treated as a
"[feature](https://en.wikipedia.org/wiki/Feature_(machine_learning))" that is used to train the model.
Since it can be difficult to train machine learning models with very large numbers of features in the data
(particularly if the number of samples is relatively small), we first need to try and reduce the number of variants
in the data.

To do this, we first compute the frequency with which alternate alleles have occurred for each variant. We then
filter the variants down to just those that appear within a certain frequency range. In this case, we've chosen a
fairly arbitrary frequency of 11. This was chosen through experimentation as a value that leaves around 3,000 variants
in the data set we are using.

There are more structured approaches to
[dimensionality reduction](https://en.wikipedia.org/wiki/Dimensionality_reduction), which we perhaps could have
employed, but this technique seems to work well enough for this example.

```scala
val variantFrequencies: collection.Map[Int, Int] = variantsByVariantId.map {
  case (variantId, sampleVariants) => (variantId, sampleVariants.count(_.alternateCount > 0))
}.collectAsMap()
val permittedRange = inclusive(11, 11)
val filteredVariantsBySampleId: RDD[(String, Iterable[SampleVariant])] = variantsBySampleId.map {
  case (sampleId, sampleVariants) =>
    val filteredSampleVariants = sampleVariants.filter(variant => permittedRange.contains(
      variantFrequencies.getOrElse(variant.variantId, -1)))
    (sampleId, filteredSampleVariants)
}
```
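
As a sketch of what one of those more structured approaches might look like - we did not use this in PopStrat - the
frequency-filtered data could be reduced further with
[PCA](https://en.wikipedia.org/wiki/Principal_component_analysis) from Spark MLlib. (Running PCA directly on the
millions of unfiltered variants would not be practical with this API, which is one reason the simple frequency filter
above is attractive.)

```scala
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.distributed.RowMatrix

// Build one dense vector of alternate-allele counts per sample, sorting by variant ID so that
// element i of every vector refers to the same variant.
val sampleVectors: RDD[Vector] = filteredVariantsBySampleId.map {
  case (_, sampleVariants) =>
    Vectors.dense(sampleVariants.toArray.sortBy(_.variantId).map(_.alternateCount.toDouble))
}
val matrix = new RowMatrix(sampleVectors)
val topComponents = matrix.computePrincipalComponents(50) // keep the top 50 principal components
val projected: RowMatrix = matrix.multiply(topComponents)  // samples projected into 50 dimensions
```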

### Creating the Training Data

To train our model, we need our data to be in tabular form where each row represents a single sample, and each
column represents a specific variant. The table also contains a column for the population group or "Region", which is
what we are trying to predict.

Ultimately, in order for our data to be consumed by H2O we need it to end up in an H2O `DataFrame` object. Currently,
the best way to do this in Spark seems to be to convert our data to an RDD of Spark SQL
[Row](http://spark.apache.org/docs/1.4.0/api/scala/index.html#org.apache.spark.sql.Row) objects, and then this can
automatically be converted to an H2O DataFrame.

To achieve this, we first need to group the data by sample ID, and then sort the variants for each sample in a
consistent manner (by variant ID). We can then create a header row for our table, containing the Region column
and a column for each of the variants. We then create an RDD of type `Row` for each sample.

```scala
val sortedVariantsBySampleId: RDD[(String, Array[SampleVariant])] = filteredVariantsBySampleId.map {
  case (sampleId, variants) =>
    (sampleId, variants.toArray.sortBy(_.variantId))
}
val header = StructType(Array(StructField("Region", StringType)) ++
  sortedVariantsBySampleId.first()._2.map(variant => {StructField(variant.variantId.toString, IntegerType)}))
val rowRDD: RDD[Row] = sortedVariantsBySampleId.map {
  case (sampleId, sortedVariants) =>
    val region: Array[String] = Array(panel.getOrElse(sampleId, "Unknown"))
    val alternateCounts: Array[Int] = sortedVariants.map(_.alternateCount)
    Row.fromSeq(region ++ alternateCounts)
}
```

As mentioned above, once we have our RDD of `Row` objects we can then convert these automatically to an H2O
DataFrame using Sparkling Water (H2O's Spark integration).

```scala
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
val schemaRDD = sqlContext.applySchema(rowRDD, header)
val h2oContext = new H2OContext(sc).start()
import h2oContext._
val dataFrame = h2oContext.toDataFrame(schemaRDD)
```

Now that we have a DataFrame, we want to split it into the training data (which we'll use to train our model), and a
[test set](https://en.wikipedia.org/wiki/Test_set) (which we'll use to ensure that
[overfitting](https://en.wikipedia.org/wiki/Overfitting) has not occurred).

We will also create a "validation" set, which serves a similar purpose to the test set - in that it will be used to
validate the strength of our model as it is being built, while avoiding overfitting. However, when training a neural
network, we typically keep the validation set distinct from the test set, to enable us to learn
[hyper-parameters](http://colinraffel.com/wiki/neural_network_hyperparameters) for the model.
See [chapter 3 of Michael Nielsen's "Neural Networks and Deep Learning"](http://neuralnetworksanddeeplearning.com/chap3.html)
for more details on this.

H2O comes with a class called `FrameSplitter`, so splitting the data is simply a matter of creating one
of those and letting it split the data set.

```scala
// Split the DataFrame into 50% training, 30% test and the remaining 20% validation data.
val frameSplitter = new FrameSplitter(dataFrame, Array(.5, .3), Array("training", "test", "validation").map(Key.make), null)
water.H2O.submitTask(frameSplitter)
val splits = frameSplitter.getResult
val training = splits(0)
val validation = splits(2)
```

### Training the Model

Next, we need to set the parameters for our deep learning model. We specify the training and validation data sets,
as well as the column in the data which contains the item we are trying to predict (in this case, the Region).
We also set some [hyper-parameters](http://colinraffel.com/wiki/neural_network_hyperparameters) which affect the way
the model learns. We won't go into detail about these here, but you can read more in the
[H2O documentation](http://docs.h2o.ai/h2oclassic/datascience/deeplearning.html). These parameters have been
chosen through experimentation - however, H2O provides methods for
[automatically tuning hyper-parameters](http://learn.h2o.ai/content/hands-on_training/deep_learning.html), so
it may be possible to achieve better results by employing one of these methods.

```scala
val deepLearningParameters = new DeepLearningParameters()
deepLearningParameters._train = training
deepLearningParameters._valid = validation
deepLearningParameters._response_column = "Region"
deepLearningParameters._epochs = 10
deepLearningParameters._activation = Activation.RectifierWithDropout
deepLearningParameters._hidden = Array[Int](100,100)
```

Finally, we're ready to train our deep learning model! Now that we've set everything up this is easy:
we simply create an H2O `DeepLearning` object and call `trainModel` on it.

```scala
val deepLearning = new DeepLearning(deepLearningParameters)
val deepLearningModel = deepLearning.trainModel.get
```

Having trained our model in the previous step, we now need to check how well it predicts the population
groups in our data set. To do this we "score" our entire data set (including training, test, and validation data)
against our model:

```scala
deepLearningModel.score(dataFrame)('predict)
```

This final step will print a [confusion matrix](https://en.wikipedia.org/wiki/Confusion_matrix) which shows how
well our model predicts our population groups. All being well, the confusion matrix should look something like this:

```
Confusion Matrix (vertical: actual; across: predicted):
          ASW   CHB   GBR  Error      Rate
ASW        60     1     0  0.0164 =  1 / 61
CHB         0   103     0  0.0000 =  0 / 103
GBR         0     1    90  0.0110 =  1 / 91
Totals     60   105    90  0.0078 =  2 / 255
```

This tells us that the model has correctly predicted 253 out of 255 population groups (an accuracy of
more than 99%). Nice!

## Building and Running

### Prerequisites

Before building and running the example, please ensure you have version 7 or later of the
[Java JDK](http://www.oracle.com/technetwork/java/javase/downloads/index.html) installed.

### Building

To build the example, first clone the GitHub repo at [https://github.com/nfergu/popstrat](https://github.com/nfergu/popstrat).

Then [download and install Maven](http://maven.apache.org/download.cgi). Then, at the command line, type:

```
mvn clean package
```

This will build a JAR (target/uber-popstrat-0.1-SNAPSHOT.jar), containing the `PopStrat` class,
as well as all of its dependencies.

### Running

First, [download Spark version 1.2.0](http://spark.apache.org/downloads.html) and unpack it on your machine.

Next you'll need to get some genomics data. Go to your
[nearest mirror of the 1000 genomes FTP site](http://www.1000genomes.org/data#DataAccess).
From the `release/20130502/` directory download
the `ALL.chr22.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf.gz` file and
the `integrated_call_samples_v3.20130502.ALL.panel` file. The first file is the genotype data for chromosome 22,
and the second file is the panel file, which describes the population group for each sample in the genotype data.

Unzip the genotype data before continuing. This will require around 10GB of disk space.

To speed up execution and save disk space, you can convert the genotype VCF file to [ADAM](https://github.com/bigdatagenomics/adam)
format (using the ADAM `transform` command) if you wish. However,
this will take some time up-front. Both ADAM and VCF formats are supported.

Next, run the following command:

```
YOUR_SPARK_HOME/bin/spark-submit --class "com.neilferguson.PopStrat" --master local[6] --driver-memory 6G target/uber-popstrat-0.1-SNAPSHOT.jar <genotypesfile> <panelfile>
```

Replace `<genotypesfile>` with the path to your genotype data file (ADAM or VCF), and `<panelfile>` with the panel file
from 1000 genomes.

This runs the example using a local (in-process) Spark master with 6 cores and 6GB of RAM. You can run against a different
Spark cluster by modifying the options in the above command line. See the
[Spark documentation](https://spark.apache.org/docs/1.2.0/submitting-applications.html) for further details.

Using the above data, the example may take 2-3 hours to run, depending on hardware. When it is finished, you should
see a [confusion matrix](http://en.wikipedia.org/wiki/Confusion_matrix) which shows the predicted versus the actual
populations. If all has gone well, this should show an accuracy of more than 99%.
See the "Code" section above for more details on what exactly you should expect to see.

## Conclusion

In this post, we have shown how to combine ADAM and Apache Spark with H2O's deep learning capabilities to predict
an individual's population group based on his or her genomic data. Our results demonstrate that we can predict these
very well, with more than 99% accuracy. Our choice of technologies makes for a relatively straightforward implementation,
and we expect it to be very scalable.

Future work could involve validating the scalability of our solution on more hardware, trying to predict a wider
range of population groups (currently we only predict 3 groups), and tuning the deep learning hyper-parameters to
achieve even better accuracy.