Diff of /ECG/loadDatasets.swift [000000] .. [c49678]

Switch to unified view

a b/ECG/loadDatasets.swift
1
//
2
//  loadDatasets.swift
3
//  ECG
4
//
5
//  Created by Dave Fernandes on 2019-03-05.
6
//  Copyright © 2019 MintLeaf Software Inc. All rights reserved.
7
//
8
9
import Python
10
import TensorFlow
11
12
struct Example: TensorGroup {
13
    var labels: Tensor<Int32>
14
    var series: Tensor<Float>
15
}
16
17
func loadTimeSeries(from path: String) -> Example {
18
    let pickle = Python.import("pickle")
19
    let file = Python.open(path, "rb")
20
    let pyDict = pickle.load(file, encoding: "bytes")
21
    let dict = Dictionary<String, PythonObject>(pyDict)
22
23
    guard let series = dict?["x"],
24
        let labels = dict?["y"] else {
25
        fatalError()
26
    }
27
28
    let labelsTensor = Tensor<Int64>(numpy: labels)!
29
    let seriesTensor = Tensor<Float64>(numpy: series)!
30
    return Example(labels: Tensor<Int32>(labelsTensor), series: Tensor<Float32>(seriesTensor))
31
}
32
33
func loadDatasets() -> (training: Dataset<Example>, test: Dataset<Example>) {
34
    let trainingDataset = Dataset<Example>(elements: loadTimeSeries(from: "train_set.pickle"))
35
    let testDataset = Dataset<Example>(elements: loadTimeSeries(from: "test_set.pickle"))
36
    return (training: trainingDataset, test: testDataset)
37
}