[61c0d0]: / ECG / main.swift

Download this file

74 lines (59 with data), 2.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
//
// main.swift
// ECG
//
// Created by Dave Fernandes on 2019-03-04.
// Copyright © 2019 MintLeaf Software Inc. All rights reserved.
//
import TensorFlow
import Python
import Foundation
// MARK: - Hyperparameters and setup

/// Number of examples in each training mini-batch.
let batchSize = 200
/// Number of full passes made over the training set.
let maxEpochs = 4

let (trainDataset, testDataset) = loadDatasets()
// Batched once up front; re-iterated for evaluation after every epoch and for the final report.
let testBatches = testDataset.batched(1000)

var model = ECGModel()
// Adam with a fixed learning rate of 0.001 and `decay` set to 0.
let optimizer = Adam(for: model, learningRate: 0.001, decay: 0)
// MARK: - Training loop

// Cumulative number of optimizer steps taken across all epochs. Counting steps
// directly stays correct even if the number of batches differs between epochs.
var globalStep = 0

for epoch in 1...maxEpochs {
    print("Epoch \(epoch), training...")
    var trainingLossSum: Float = 0
    var trainingBatchCount = 0
    // Reshuffle each epoch, seeded by the epoch number.
    // NOTE(review): 500000 is presumably the shuffle buffer size / dataset size — confirm.
    let trainingShuffled = trainDataset.shuffled(sampleCount: 500000, randomSeed: Int64(epoch))
    let trainingStart = Date()

    // Loop over mini-batches in the training set.
    for batch in trainingShuffled.batched(batchSize) {
        // `valueWithGradient` returns the loss together with the gradients, so the
        // differentiated closure has no side effects; statistics are accumulated below.
        let (batchLoss, gradients) = valueWithGradient(at: model) { (model: ECGModel) -> Tensor<Float> in
            loss(model: model, examples: batch)
        }
        optimizer.update(&model.allDifferentiableVariables, along: gradients)
        // `batchLoss` was computed before the update, matching the loss the gradients came from.
        trainingLossSum += batchLoss.scalarized()
        trainingBatchCount += 1
        globalStep += 1
    }
    let trainingSeconds = Date().timeIntervalSince(trainingStart)
    // max(_, 1) avoids a 0/0 NaN if the dataset yields no batches.
    let meanTrainingLoss = trainingLossSum / Float(max(trainingBatchCount, 1))
    print(" training loss: \(meanTrainingLoss) step: \(globalStep) (\(trainingSeconds) sec)")

    // Evaluate on the test set by averaging per-batch losses. Assuming `loss` returns a
    // per-batch mean, a smaller final batch is weighted the same as a full one.
    // NOTE(review): if ECGModel has dropout/batch-norm layers they may need to be put
    // into inference mode here — confirm against the model definition.
    var testLossSum: Float = 0
    var testBatchCount = 0
    for batch in testBatches {
        testLossSum += loss(model: model, examples: batch).scalarized()
        testBatchCount += 1
    }
    print(" test loss: \(testLossSum / Float(max(testBatchCount, 1)))")
}
// MARK: - Evaluation report

// Gather true and predicted class indices for every test example, in batch order.
var yActual = [Int32]()
var yPredicted = [Int32]()
for batch in testBatches {
    yActual += batch.labels.scalars
    yPredicted += model.predictedClasses(for: batch.series).scalars
}

// Metric computation is delegated to scikit-learn through Python interop.
let skm = Python.import("sklearn.metrics")
// Per-class precision / recall / F1 summary.
let report = skm.classification_report(yActual, yPredicted)
print(report)
// Entry (i, j) counts examples whose true class is i and predicted class is j.
let confusionMatrix = skm.confusion_matrix(yActual, yPredicted)
print(confusionMatrix)