|
a |
|
b/ECG/main.swift |
|
|
1 |
// |
|
|
2 |
// main.swift |
|
|
3 |
// ECG |
|
|
4 |
// |
|
|
5 |
// Created by Dave Fernandes on 2019-03-04. |
|
|
6 |
// Copyright © 2019 MintLeaf Software Inc. All rights reserved. |
|
|
7 |
// |
|
|
8 |
|
|
|
9 |
import TensorFlow |
|
|
10 |
import Python |
|
|
11 |
import Foundation |
|
|
12 |
|
|
|
13 |
let batchSize: Int = 200 |
|
|
14 |
let maxEpochs: Int = 4 |
|
|
15 |
|
|
|
16 |
let (trainDataset, testDataset) = loadDatasets() |
|
|
17 |
let testBatches = testDataset.batched(1000) |
|
|
18 |
|
|
|
19 |
var model = ECGModel() |
|
|
20 |
let optimizer = Adam(for: model, learningRate: 0.001, decay: 0) |
|
|
21 |
|
|
|
22 |
// Training loop |
|
|
23 |
for epoch in 1...maxEpochs { |
|
|
24 |
print("Epoch \(epoch), training...") |
|
|
25 |
|
|
|
26 |
var trainingLossSum: Float = 0 |
|
|
27 |
var trainingBatchCount = 0 |
|
|
28 |
let trainingShuffled = trainDataset.shuffled(sampleCount: 500000, randomSeed: Int64(epoch)) |
|
|
29 |
let t0 = Date() |
|
|
30 |
|
|
|
31 |
// Loop over mini-batches in training set |
|
|
32 |
for batch in trainingShuffled.batched(batchSize) { |
|
|
33 |
let gradients = gradient(at: model) { |
|
|
34 |
(model: ECGModel) -> Tensor<Float> in |
|
|
35 |
|
|
|
36 |
let thisLoss = loss(model: model, examples: batch) |
|
|
37 |
trainingLossSum += thisLoss.scalarized() |
|
|
38 |
trainingBatchCount += 1 |
|
|
39 |
return thisLoss |
|
|
40 |
} |
|
|
41 |
optimizer.update(&model.allDifferentiableVariables, along: gradients) |
|
|
42 |
} |
|
|
43 |
|
|
|
44 |
let t1 = Date() |
|
|
45 |
print(" training loss: \(trainingLossSum / Float(trainingBatchCount)) step: \(trainingBatchCount * epoch) (\(t1.timeIntervalSince(t0)) sec)") |
|
|
46 |
|
|
|
47 |
var testLossSum: Float = 0 |
|
|
48 |
var testBatchCount = 0 |
|
|
49 |
|
|
|
50 |
// Loop over test set |
|
|
51 |
for batch in testBatches { |
|
|
52 |
testLossSum += loss(model: model, examples: batch).scalarized() |
|
|
53 |
testBatchCount += 1 |
|
|
54 |
} |
|
|
55 |
print(" test loss: \(testLossSum / Float(testBatchCount))") |
|
|
56 |
} |
|
|
57 |
|
|
|
58 |
// Print metrics and confusion matrix |
|
|
59 |
var yActual = [Int32]() |
|
|
60 |
var yPredicted = [Int32]() |
|
|
61 |
|
|
|
62 |
for batch in testBatches { |
|
|
63 |
let labelValues = batch.labels.scalars |
|
|
64 |
let predictedValues = model.predictedClasses(for: batch.series).scalars |
|
|
65 |
yActual.append(contentsOf: labelValues) |
|
|
66 |
yPredicted.append(contentsOf: predictedValues) |
|
|
67 |
} |
|
|
68 |
|
|
|
69 |
let skm = Python.import("sklearn.metrics") |
|
|
70 |
let report = skm.classification_report(yActual, yPredicted) |
|
|
71 |
print(report) |
|
|
72 |
let confusionMatrix = skm.confusion_matrix(yActual, yPredicted) |
|
|
73 |
print(confusionMatrix) |