|
a |
|
b/ECG/loadDatasets.swift |
|
|
1 |
// |
|
|
2 |
// loadDatasets.swift |
|
|
3 |
// ECG |
|
|
4 |
// |
|
|
5 |
// Created by Dave Fernandes on 2019-03-05. |
|
|
6 |
// Copyright © 2019 MintLeaf Software Inc. All rights reserved. |
|
|
7 |
// |
|
|
8 |
|
|
|
9 |
import Python |
|
|
10 |
import TensorFlow |
|
|
11 |
|
|
|
12 |
struct Example: TensorGroup { |
|
|
13 |
var labels: Tensor<Int32> |
|
|
14 |
var series: Tensor<Float> |
|
|
15 |
} |
|
|
16 |
|
|
|
17 |
func loadTimeSeries(from path: String) -> Example { |
|
|
18 |
let pickle = Python.import("pickle") |
|
|
19 |
let file = Python.open(path, "rb") |
|
|
20 |
let pyDict = pickle.load(file, encoding: "bytes") |
|
|
21 |
let dict = Dictionary<String, PythonObject>(pyDict) |
|
|
22 |
|
|
|
23 |
guard let series = dict?["x"], |
|
|
24 |
let labels = dict?["y"] else { |
|
|
25 |
fatalError() |
|
|
26 |
} |
|
|
27 |
|
|
|
28 |
let labelsTensor = Tensor<Int64>(numpy: labels)! |
|
|
29 |
let seriesTensor = Tensor<Float64>(numpy: series)! |
|
|
30 |
return Example(labels: Tensor<Int32>(labelsTensor), series: Tensor<Float32>(seriesTensor)) |
|
|
31 |
} |
|
|
32 |
|
|
|
33 |
func loadDatasets() -> (training: Dataset<Example>, test: Dataset<Example>) { |
|
|
34 |
let trainingDataset = Dataset<Example>(elements: loadTimeSeries(from: "train_set.pickle")) |
|
|
35 |
let testDataset = Dataset<Example>(elements: loadTimeSeries(from: "test_set.pickle")) |
|
|
36 |
return (training: trainingDataset, test: testDataset) |
|
|
37 |
} |