Diff of /ECG/ConvModel.swift [000000] .. [c49678]

Switch to unified view

a b/ECG/ConvModel.swift
1
//
2
//  ConvModel.swift
3
//  ECG
4
//
5
//  Created by Dave Fernandes on 2019-03-04.
6
//  Copyright © 2019 MintLeaf Software Inc. All rights reserved.
7
//
8
9
// Model is from: https://arxiv.org/pdf/1805.00794.pdf
10
11
import TensorFlow
12
13
public struct ConvUnit<Scalar: TensorFlowFloatingPoint> : Layer {
14
    var conv1: Conv1D<Scalar>
15
    var conv2: Conv1D<Scalar>
16
    var pool: MaxPool1D<Scalar>
17
    
18
    public init(kernelSize: Int, channels: Int) {
19
        conv1 = Conv1D<Scalar>(filterShape: (kernelSize, channels, channels), padding: .same, activation: relu)
20
        conv2 = Conv1D<Scalar>(filterShape: (kernelSize, channels, channels), padding: .same)
21
        pool = MaxPool1D<Scalar>(poolSize: kernelSize, stride: 2, padding: .valid)
22
    }
23
    
24
    @differentiable
25
    public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
26
        var tmp = input.sequenced(through: conv1, conv2)
27
        tmp = pool(relu(tmp + input))
28
        return tmp
29
    }
30
}
31
32
public struct ConvModel : Layer {
33
    var conv1: Conv1D<Float>
34
    var convUnit = [ConvUnit<Float>]()
35
    var dense1: Dense<Float>
36
    var dense2: Dense<Float>
37
    
38
    @noDerivative let convUnitCount = 5
39
    
40
    public init() {
41
        conv1 = Conv1D<Float>(filterShape: (5, 1, 32), stride: 1, padding: .same)
42
        for _ in 0..<convUnitCount {
43
            convUnit.append(ConvUnit<Float>(kernelSize: 5, channels: 32))
44
        }
45
        dense1 = Dense<Float>(inputSize: 64, outputSize: 32, activation: relu)
46
        dense2 = Dense<Float>(inputSize: 32, outputSize: 5)
47
    }
48
    
49
    @differentiable
50
    public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
51
        var tmp = conv1(input.expandingShape(at: 2))
52
        
53
        for i in 0..<convUnitCount {
54
            let unit = convUnit[i]
55
            tmp = unit(tmp)
56
        }
57
        
58
        tmp = tmp.reshaped(to: [-1, 64])
59
        tmp = tmp.sequenced(through: dense1, dense2)
60
        return tmp
61
    }
62
    
63
    public func predictedClasses(for input: Tensor<Float>) -> Tensor<Int32> {
64
        return model.inferring(from: input).argmax(squeezingAxis: 1)
65
    }
66
}
67
68
typealias ECGModel = ConvModel
69
70
@differentiable(wrt: model)
71
func loss(model: ECGModel, examples: Example) -> Tensor<Float> {
72
    let logits = model(examples.series)
73
    return softmaxCrossEntropy(logits: logits, labels: examples.labels)
74
}