// src/main/scala/com/neilferguson/PopStrat.scala


package com.neilferguson

import hex.FrameSplitter
import hex.deeplearning.DeepLearning
import hex.deeplearning.DeepLearningModel.DeepLearningParameters
import hex.deeplearning.DeepLearningModel.DeepLearningParameters.Activation
import org.apache.spark.SparkContext.rddToPairRDDFunctions
import org.apache.spark.h2o.H2OContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.{SparkConf, SparkContext}
import org.bdgenomics.adam.rdd.ADAMContext._
import org.bdgenomics.formats.avro.{Genotype, GenotypeAllele}
import water.Key

import scala.collection.JavaConverters._
import scala.collection.immutable.Range.inclusive
import scala.io.Source
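
// Note: this code is written against the Spark 1.x-era APIs of ADAM, Sparkling Water,
// and H2O (e.g. applySchema and new H2OContext(sc).start()); later releases of these
// libraries have renamed or moved several of the calls used below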

object PopStrat {

  // Arguments: a genotype file in ADAM/Parquet format, a panel file mapping sample IDs
  // to populations, and an optional Spark master URL
  def main(args: Array[String]): Unit = {

    val genotypeFile = args(0)
    val panelFile = args(1)
    val master = if (args.length > 2) Some(args(2)) else None

    val conf = new SparkConf().setAppName("PopStrat")
    master.foreach(conf.setMaster)
    val sc = new SparkContext(conf)

    // Create a set of the populations that we want to predict
    // Then create a map of sample ID -> population so that we can filter out the samples we're not interested in
    val populations = Set("GBR", "ASW", "CHB")

    def extract(file: String, filter: (String, String) => Boolean): Map[String, String] = {
      Source.fromFile(file).getLines().map(line => {
        val tokens = line.split("\t").toList
        tokens(0) -> tokens(1)
      }).toMap.filter(tuple => filter(tuple._1, tuple._2))
    }

    val panel: Map[String, String] =
      extract(panelFile, (sampleID: String, pop: String) => populations.contains(pop))
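
    // The panel file is assumed to be tab-separated with the sample ID in the first
    // column and the population code in the second, one sample per line, e.g.:
    //   HG00096  GBR
    // (the sample ID here is illustrative; extract() above relies only on the layout)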

    // Load the ADAM genotypes from the parquet file(s)
    // Next, filter the genotypes so that we're left with only those in the populations we're interested in
    val allGenotypes: RDD[Genotype] = sc.loadGenotypes(genotypeFile)
    val genotypes: RDD[Genotype] = allGenotypes.filter(genotype => panel.contains(genotype.getSampleId))

    // Convert the Genotype objects to our own SampleVariant objects to try and conserve memory
    case class SampleVariant(sampleId: String, variantId: Int, alternateCount: Int)

    // Build an ID for a variant from its contig name and start/end positions
    def variantId(genotype: Genotype): String = {
      val name = genotype.getVariant.getContig.getContigName
      val start = genotype.getVariant.getStart
      val end = genotype.getVariant.getEnd
      s"$name:$start:$end"
    }

    // Count the alleles in this genotype that differ from the reference
    def alternateCount(genotype: Genotype): Int = {
      genotype.getAlleles.asScala.count(_ != GenotypeAllele.Ref)
    }

    def toVariant(genotype: Genotype): SampleVariant = {
      // Intern sample IDs as they will be repeated a lot
      SampleVariant(genotype.getSampleId.intern(), variantId(genotype).hashCode(), alternateCount(genotype))
    }
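
    // Hashing the variant ID down to an Int keeps SampleVariant compact, at the (small)
    // risk that two distinct variants collide on the same hash code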

    val variantsRDD: RDD[SampleVariant] = genotypes.map(toVariant)

    // Group the variants by sample ID so we can process the variants sample-by-sample
    // Then get the total number of samples. This will be used to find variants that are missing for some samples.
    val variantsBySampleId: RDD[(String, Iterable[SampleVariant])] = variantsRDD.groupBy(_.sampleId)
    val sampleCount: Long = variantsBySampleId.count()
    println("Found " + sampleCount + " samples")

    // Group the variants by variant ID and filter out those variants that are missing from some samples
    val variantsByVariantId: RDD[(Int, Iterable[SampleVariant])] = variantsRDD.groupBy(_.variantId).filter {
      case (_, sampleVariants) => sampleVariants.size == sampleCount
    }
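
    // Keeping only variants where sampleVariants.size == sampleCount guarantees that
    // every sample ends up with exactly the same set of features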

    // Make a map of variant ID -> count of samples with an alternate count of greater than zero
    // then filter out those variants that are not in our desired frequency range. The objective here is simply to
    // reduce the number of dimensions in the data set to make it easier to train the model.
    // The specified range is fairly arbitrary and was chosen based on the fact that it includes a reasonable
    // number of variants, but not too many.
    val variantFrequencies: collection.Map[Int, Int] = variantsByVariantId.map {
      case (variantId, sampleVariants) => (variantId, sampleVariants.count(_.alternateCount > 0))
    }.collectAsMap()
    val permittedRange = inclusive(11, 11)
    val filteredVariantsBySampleId: RDD[(String, Iterable[SampleVariant])] = variantsBySampleId.map {
      case (sampleId, sampleVariants) =>
        val filteredSampleVariants = sampleVariants.filter(variant =>
          permittedRange.contains(variantFrequencies.getOrElse(variant.variantId, -1)))
        (sampleId, filteredSampleVariants)
    }
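
    // With inclusive(11, 11) a variant is kept only if exactly 11 samples carry at least
    // one alternate allele; widening the range (e.g. inclusive(10, 12)) keeps more features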

    // Sort the variants for each sample ID. Each sample should now have the same number of sorted variants.
    val sortedVariantsBySampleId: RDD[(String, Array[SampleVariant])] = filteredVariantsBySampleId.map {
      case (sampleId, variants) =>
        (sampleId, variants.toArray.sortBy(_.variantId))
    }

    // All items in the RDD should now have the same variants in the same order so we can just use the first
    // one to construct our header
    val header = StructType(Array(StructField("Region", StringType)) ++
      sortedVariantsBySampleId.first()._2.map(variant => StructField(variant.variantId.toString, IntegerType)))

    // Next construct the rows of our SchemaRDD from the variants
    val rowRDD: RDD[Row] = sortedVariantsBySampleId.map {
      case (sampleId, sortedVariants) =>
        val region: Array[String] = Array(panel.getOrElse(sampleId, "Unknown"))
        val alternateCounts: Array[Int] = sortedVariants.map(_.alternateCount)
        Row.fromSeq(region ++ alternateCounts)
    }
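
    // Each row is now (population, count, count, ...): the label followed by one
    // alternate-allele count per variant, in the same order as the header columns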

    // Create the SchemaRDD from the header and rows and convert the SchemaRDD into a H2O dataframe
    // (applySchema is the Spark 1.x API; later Spark versions use sqlContext.createDataFrame instead)
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    val schemaRDD = sqlContext.applySchema(rowRDD, header)
    val h2oContext = new H2OContext(sc).start()
    import h2oContext._
    val dataFrame = h2oContext.toDataFrame(schemaRDD)
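
    // toDataFrame copies the Spark data into an H2O Frame, the distributed in-memory
    // columnar format that H2O's algorithms operate on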

    // Split the dataframe into 50% training, 30% test, and 20% validation data
    val frameSplitter = new FrameSplitter(dataFrame, Array(.5, .3),
      Array("training", "test", "validation").map(Key.make), null)
    water.H2O.submitTask(frameSplitter)
    val splits = frameSplitter.getResult
    val training = splits(0)
    // splits(1) is the test set; it is not used during training but is scored along with
    // everything else when the model is run against the full dataset below
    val validation = splits(2)

    // Set the parameters for our deep learning model
    val deepLearningParameters = new DeepLearningParameters()
    deepLearningParameters._train = training
    deepLearningParameters._valid = validation
    deepLearningParameters._response_column = "Region"
    deepLearningParameters._epochs = 10
    deepLearningParameters._activation = Activation.RectifierWithDropout
    deepLearningParameters._hidden = Array[Int](100, 100) // two hidden layers of 100 neurons each

    // Train the deep learning model
    val deepLearning = new DeepLearning(deepLearningParameters)
    val deepLearningModel = deepLearning.trainModel.get

    // Score the model against the entire dataset (training, test, and validation data)
    // This causes the confusion matrix to be printed
    deepLearningModel.score(dataFrame)('predict)
  }

}