Diff of /blog/bdgblog.md [000000] .. [51428b]

--- a
+++ b/blog/bdgblog.md
@@ -0,0 +1,326 @@
+# Genomic Analysis Using ADAM, Spark and Deep Learning
+
+Can we use deep learning to predict which population group you belong to, based solely on your genome?
+
+Yes, we can - and in this post, we will show you exactly how to do this in a scalable way, using Apache Spark. We
+will explain how to apply [deep learning](https://en.wikipedia.org/wiki/Deep_learning) using
+[artificial neural networks](https://en.wikipedia.org/wiki/Artificial_neural_network) to predict which population
+group an individual belongs to, based entirely on his or her genomic data.
+
+This is a follow-up to an earlier post:
+[Scalable Genomes Clustering With ADAM and Spark](http://bdgenomics.org/blog/2015/02/02/scalable-genomes-clustering-with-adam-and-spark/)
+and attempts to replicate the results of that post. However, we will use a different machine learning technique.
+Where the original post used [k-means clustering](https://en.wikipedia.org/wiki/K-means_clustering), we will use
+deep learning.
+
+We will use [ADAM](https://github.com/bigdatagenomics/adam) and [Apache Spark](https://spark.apache.org/) in
+combination with [H2O](http://0xdata.com/product/), an open source predictive analytics platform, and
+[Sparkling Water](http://0xdata.com/product/sparkling-water/), which integrates H2O with Spark.
+
+## Code
+
+In this section, we'll dive straight into the code. If you'd rather get something working before looking at the code,
+you can skip to the "Building and Running" section.
+
+The complete Scala code for this example can be found in
+[the PopStrat.scala class on GitHub](https://github.com/nfergu/popstrat/blob/master/src/main/scala/com/neilferguson/PopStrat.scala)
+and we'll refer to sections of the code here. Basic familiarity with Scala and
+[Apache Spark](https://spark.apache.org/) is assumed.
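+
+The snippets in this post are all taken from a single class, so they rely on a shared set of imports. As a partial
+sketch (the exact packages can vary with the ADAM, Spark and Sparkling Water versions in use, so see the full
+PopStrat.scala on GitHub for the definitive list), the imports look roughly like this:
+
+```scala
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.h2o.H2OContext
+import org.bdgenomics.adam.rdd.ADAMContext._           // makes loadGenotypes available on the SparkContext
+import org.bdgenomics.formats.avro.{Genotype, GenotypeAllele}
+import hex.FrameSplitter
+import hex.deeplearning.DeepLearning
+import water.Key
+import scala.collection.JavaConverters._               // for .asScala on Java collections
+import scala.collection.immutable.Range.inclusive
+import scala.io.Source
+```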
+
+### Setting Up
+
+The first thing we need to do is to read the names of the Genotype and Panel files that are passed into our program.
+The Genotype file contains data about a set of individuals (referred to here as "samples") and their genetic
+variation. The Panel file lists the population group (or "region") for each sample in the Genotype file; this is
+what we will try to predict.
+
+```scala
+val genotypeFile = args(0)
+val panelFile = args(1)
+```
+
+Next, we set up our Spark Context. Our program permits the Spark master to be specified as one of its arguments.
+This is useful when running from an IDE, but is omitted when running from the ```spark-submit``` script (see below).
+
+```scala
+val master = if (args.length > 2) Some(args(2)) else None
+val conf = new SparkConf().setAppName("PopStrat")
+master.foreach(conf.setMaster)
+val sc = new SparkContext(conf)
+```
+
+Next, we declare a set called ```populations``` which contains all of the population groups that we're interested
+in predicting. We then read the Panel file into a Map, filtering it based on the population groups in the
+```populations``` set. The format of the panel file is described [here](http://www.1000genomes.org/faq/what-panel-file).
+Luckily it's very simple, containing the sample ID in the first column and the population group in the second.
+
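+For illustration, rows of the panel file look something like this (the columns are tab-separated in the actual file,
+with the sample ID first and the population group second; any further columns are ignored by our code):
+
+```
+HG00096  GBR  EUR  male
+HG00097  GBR  EUR  female
+```
+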
+```scala
+val populations = Set("GBR", "ASW", "CHB")
+def extract(file: String, filter: (String, String) => Boolean): Map[String,String] = {
+  Source.fromFile(file).getLines().map(line => {
+    val tokens = line.split("\t").toList
+    tokens(0) -> tokens(1)
+  }).toMap.filter(tuple => filter(tuple._1, tuple._2))
+}
+val panel: Map[String,String] = extract(panelFile, (sampleID: String, pop: String) => populations.contains(pop))
+```
+
+### Preparing the Genomics Data
+
+Next, we use [ADAM](https://github.com/bigdatagenomics/adam) to read our genotype data into a Spark RDD. Since we've
+imported ```ADAMContext._``` at the top of our class, this is simply a matter of calling `loadGenotypes` on the
+Spark Context. Then, we filter the genotype data to contain only samples that are in the population groups which we're
+interested in.
+
+```scala
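+// loadGenotypes is made available on the SparkContext by the implicits imported from ADAMContext._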
+val allGenotypes: RDD[Genotype] = sc.loadGenotypes(genotypeFile)
+val genotypes: RDD[Genotype] = allGenotypes.filter(genotype => {panel.contains(genotype.getSampleId)})
+```
+
+Next, we convert the ADAM ```Genotype``` objects into our own ```SampleVariant``` objects. These objects contain just the data
+we need for further processing: the sample ID (which uniquely identifies a particular sample), a variant ID (which
+uniquely identifies a particular genetic variant) and a count of alternate
+[alleles](http://www.snpedia.com/index.php/Allele), i.e. the number of alleles at that variant where the sample differs from the
+reference genome. These variations will help us to classify individuals according to their population group.
+
+```scala
+case class SampleVariant(sampleId: String, variantId: Int, alternateCount: Int)
+def variantId(genotype: Genotype): String = {
+  val name = genotype.getVariant.getContig.getContigName
+  val start = genotype.getVariant.getStart
+  val end = genotype.getVariant.getEnd
+  s"$name:$start:$end"
+}
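+// Count the alleles in this genotype that differ from the reference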
+def alternateCount(genotype: Genotype): Int = {
+  genotype.getAlleles.asScala.count(_ != GenotypeAllele.Ref)
+}
+def toVariant(genotype: Genotype): SampleVariant = {
+  // Intern sample IDs as they will be repeated a lot
+  new SampleVariant(genotype.getSampleId.intern(), variantId(genotype).hashCode(), alternateCount(genotype))
+}
+val variantsRDD: RDD[SampleVariant] = genotypes.map(toVariant)
+```
+
+Next, we count the total number of samples (individuals) in the data. We then group the data by variant ID and filter
+out those variants which do not appear in all of the samples. The aim of this is to simplify the processing of the data;
+since we have a very large number of variants in the data (up to 30 million, depending on the exact data set), filtering out
+a small number will not make a significant difference to the results. In fact, in the next step we'll reduce the
+number of variants even further.
+
+```scala
+val variantsBySampleId: RDD[(String, Iterable[SampleVariant])] = variantsRDD.groupBy(_.sampleId)
+val sampleCount: Long = variantsBySampleId.count()
+println("Found " + sampleCount + " samples")
+val variantsByVariantId: RDD[(Int, Iterable[SampleVariant])] = variantsRDD.groupBy(_.variantId).filter {
+  case (_, sampleVariants) => sampleVariants.size == sampleCount
+}
+```
+
+When we train our machine learning model, each variant will be treated as a
+"[feature](https://en.wikipedia.org/wiki/Feature_(machine_learning))" that is used to train the model.
+Since it can be difficult to train machine learning models with very large numbers of features in the data
+(particularly if the number of samples is relatively small), we first need to try and reduce the number of variants
+in the data.
+
+To do this, we first compute the frequency with which alternate alleles have occurred for each variant. We then
+filter the variants down to just those that appear within a certain frequency range. In this case, we've chosen a
+fairly arbitrary frequency of 11, a value found through experimentation to leave around 3,000 variants
+in the data set we are using.
+
+There are more structured approaches to
+[dimensionality reduction](https://en.wikipedia.org/wiki/Dimensionality_reduction), which we perhaps could have
+employed, but this technique seems to work well enough for this example.
+
+```scala
+val variantFrequencies: collection.Map[Int, Int] = variantsByVariantId.map {
+  case (variantId, sampleVariants) => (variantId, sampleVariants.count(_.alternateCount > 0))
+}.collectAsMap()
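+// Only variants whose alternate-allele frequency falls in this range (inclusive at both ends) will be kept - here, exactly 11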
+val permittedRange = inclusive(11, 11)
+val filteredVariantsBySampleId: RDD[(String, Iterable[SampleVariant])] = variantsBySampleId.map {
+  case (sampleId, sampleVariants) =>
+    val filteredSampleVariants = sampleVariants.filter(variant => permittedRange.contains(
+      variantFrequencies.getOrElse(variant.variantId, -1)))
+    (sampleId, filteredSampleVariants)
+}
+```
+
+### Creating the Training Data
+
+To train our model, we need our data to be in tabular form where each row represents a single sample, and each
+column represents a specific variant. The table also contains a column for the population group or "Region", which is
+what we are trying to predict.
+
+Ultimately, in order for our data to be consumed by H2O we need it to end up in an H2O `DataFrame` object. Currently,
+the best way to do this in Spark seems to be to convert our data to an RDD of Spark SQL
+[Row](http://spark.apache.org/docs/1.4.0/api/scala/index.html#org.apache.spark.sql.Row) objects, and then this can
+automatically be converted to an H2O DataFrame.
+
+To achieve this, we first need to group the data by sample ID, and then sort the variants for each sample in a
+consistent manner (by variant ID). We can then create a header row for our table, containing the Region column,
+the sample ID and all of the variants. We then create an RDD of type `Row` for each sample.
+
+```scala
+val sortedVariantsBySampleId: RDD[(String, Array[SampleVariant])] = filteredVariantsBySampleId.map {
+  case (sampleId, variants) =>
+    (sampleId, variants.toArray.sortBy(_.variantId))
+}
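+// Build the table header from the first sample's variants; after the earlier filtering, every sample shares the same variant set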
+val header = StructType(Array(StructField("Region", StringType)) ++
+  sortedVariantsBySampleId.first()._2.map(variant => {StructField(variant.variantId.toString, IntegerType)}))
+val rowRDD: RDD[Row] = sortedVariantsBySampleId.map {
+  case (sampleId, sortedVariants) =>
+    val region: Array[String] = Array(panel.getOrElse(sampleId, "Unknown"))
+    val alternateCounts: Array[Int] = sortedVariants.map(_.alternateCount)
+    Row.fromSeq(region ++ alternateCounts)
+}
+```
+
+As mentioned above, once we have our RDD of `Row` objects we can then convert these automatically to an H2O
+DataFrame using Sparkling Water (H2O's Spark integration).
+
+```scala
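+// Apply the schema to the Row RDD, start an H2OContext, and convert the resulting SchemaRDD to an H2O DataFrame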
+val sqlContext = new org.apache.spark.sql.SQLContext(sc)
+val schemaRDD = sqlContext.applySchema(rowRDD, header)
+val h2oContext = new H2OContext(sc).start()
+import h2oContext._
+val dataFrame = h2oContext.toDataFrame(schemaRDD)
+```
+
+Now that we have a DataFrame, we want to split it into the training data (which we'll use to train our model), and a
+[test set](https://en.wikipedia.org/wiki/Test_set) (which we'll use to ensure that
+[overfitting](https://en.wikipedia.org/wiki/Overfitting) has not occurred).
+
+We will also create a "validation" set, which serves a similar purpose to the test set - in that it will be used to
+validate the strength of our model as it is being built, while avoiding overfitting. However, when training a neural
+network, we typically keep the validation set distinct from the test set, to enable us to learn
+[hyper-parameters](http://colinraffel.com/wiki/neural_network_hyperparameters) for the model.
+See [chapter 3 of Michael Nielsen's "Neural Networks and Deep Learning"](http://neuralnetworksanddeeplearning.com/chap3.html)
+for more details on this.
+
+H2O comes with a class called `FrameSplitter`, so splitting the data is simply a matter of creating one
+of these and letting it split the data set.
+
+```scala
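+// Split the data 50% / 30% / 20% into training, test and validation sets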
+val frameSplitter = new FrameSplitter(dataFrame, Array(.5, .3),
+  Array("training", "test", "validation").map(Key.make), null)
+water.H2O.submitTask(frameSplitter)
+val splits = frameSplitter.getResult
+val training = splits(0)
+val validation = splits(2)
+```
+
+### Training the Model
+
+Next, we need to set the parameters for our deep learning model. We specify the training and validation data sets,
+as well as the column in the data which contains the item we are trying to predict (in this case, the Region).
+We also set some [hyper-parameters](http://colinraffel.com/wiki/neural_network_hyperparameters) which affect the way
+the model learns. We won't go into detail about these here, but you can read more in the
+[H2O documentation](http://docs.h2o.ai/h2oclassic/datascience/deeplearning.html). These parameters have been
+chosen through experimentation - however, H2O provides methods for
+[automatically tuning hyper-parameters](http://learn.h2o.ai/content/hands-on_training/deep_learning.html) so
+it may be possible to achieve better results by employing one of these methods.
+
+```scala
+val deepLearningParameters = new DeepLearningParameters()
+deepLearningParameters._train = training
+deepLearningParameters._valid = validation
+deepLearningParameters._response_column = "Region"
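+// Hyper-parameters: 10 passes over the training data, rectifier activation with dropout, two hidden layers of 100 neurons each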
+deepLearningParameters._epochs = 10
+deepLearningParameters._activation = Activation.RectifierWithDropout
+deepLearningParameters._hidden = Array[Int](100,100)
+```
+
+Finally, we're ready to train our deep learning model! Now that we've set everything up this is easy:
+we simply create an H2O `DeepLearning` object and call `trainModel` on it.
+
+```scala
+val deepLearning = new DeepLearning(deepLearningParameters)
+val deepLearningModel = deepLearning.trainModel.get
+```
+
+Having trained our model in the previous step, we now need to check how well it predicts the population
+groups in our data set. To do this we "score" our entire data set (including training, test, and validation data)
+against our model:
+
+```scala
+deepLearningModel.score(dataFrame)('predict)
+```
+
+This final step will print a [confusion matrix](https://en.wikipedia.org/wiki/Confusion_matrix) which shows how
+well our model predicts our population groups. All being well, the confusion matrix should look something like this:
+
+```
+Confusion Matrix (vertical: actual; across: predicted):
+       ASW CHB GBR  Error      Rate
+   ASW  60   1   0 0.0164 =  1 / 61
+   CHB   0 103   0 0.0000 = 0 / 103
+   GBR   0   1  90 0.0110 =  1 / 91
+Totals  60 105  90 0.0078 = 2 / 255
+```
+
+This tells us that the model has correctly predicted the population group for 253 out of 255 samples (an accuracy of
+more than 99%). Nice!
+
+## Building and Running
+
+### Prerequisites
+
+Before building and running the example, please ensure you have version 7 or later of the
+[Java JDK](http://www.oracle.com/technetwork/java/javase/downloads/index.html) installed.
+
+### Building
+
+To build the example, first clone the GitHub repo at [https://github.com/nfergu/popstrat](https://github.com/nfergu/popstrat).
+
+Then [download and install Maven](http://maven.apache.org/download.cgi). Then, at the command line, type:
+
+```
+mvn clean package
+```
+
+This will build a JAR (`target/uber-popstrat-0.1-SNAPSHOT.jar`) containing the `PopStrat` class,
+as well as all of its dependencies.
+
+### Running
+
+First, [download Spark version 1.2.0](http://spark.apache.org/downloads.html) and unpack it on your machine.
+
+Next you'll need to get some genomics data. Go to your
+[nearest mirror of the 1000 genomes FTP site](http://www.1000genomes.org/data#DataAccess).
+From the `release/20130502/` directory download
+the `ALL.chr22.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf.gz` file and
+the `integrated_call_samples_v3.20130502.ALL.panel` file. The first file is the genotype data for chromosome 22,
+and the second file is the panel file, which describes the population group for each sample in the genotype data.
+
+Unzip the genotype data before continuing. This will require around 10GB of disk space.
+
+To speed up execution and save disk space, you can convert the genotype VCF file to [ADAM](https://github.com/bigdatagenomics/adam)
+format (using the ADAM `transform` command) if you wish. However,
+this will take some time up-front. Both ADAM and VCF formats are supported.
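+
+As a rough sketch, and assuming the `vcf2adam` sub-command that older ADAM releases provide for converting genotype
+VCF files (newer releases have reorganised their sub-commands, so check the usage printed by `adam-submit` for your
+version), the conversion looks something like this:
+
+```
+YOUR_ADAM_HOME/bin/adam-submit vcf2adam ALL.chr22.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf chr22.adam
+```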
+
+Next, run the following command:
+
+```
+YOUR_SPARK_HOME/bin/spark-submit --class "com.neilferguson.PopStrat" --master local[6] --driver-memory 6G target/uber-popstrat-0.1-SNAPSHOT.jar <genotypesfile> <panelfile>
+```
+
+Replace &lt;genotypesfile&gt; with the path to your genotype data file (ADAM or VCF), and &lt;panelfile&gt; with the panel file
+from 1000 genomes.
+
+This runs the example using a local (in-process) Spark master with 6 cores and 6GB of RAM. You can run against a different
+Spark cluster by modifying the options in the above command line. See the
+[Spark documentation](https://spark.apache.org/docs/1.2.0/submitting-applications.html) for further details.
+
+Using the above data, the example may take 2-3 hours to run, depending on your hardware. When it is finished, you should
+see a [confusion matrix](http://en.wikipedia.org/wiki/Confusion_matrix) which shows the predicted versus the actual
+populations. If all has gone well, this should show an accuracy of more than 99%.
+See the "Code" section above for more details on what exactly you should expect to see.
+
+## Conclusion
+
+In this post, we have shown how to combine ADAM and Apache Spark with H2O's deep learning capabilities to predict
+an individual's population group based on his or her genomic data. Our results demonstrate that population groups can be predicted
+very well, with more than 99% accuracy. Our choice of technologies makes for a relatively straightforward implementation,
+and we expect it to be very scalable.
+
+Future work could involve validating the scalability of our solution on more hardware, trying to predict a wider
+range of population groups (currently we only predict 3 groups), and tuning the deep learning hyper-parameters to
+achieve even better accuracy.