--- a +++ b/ECG/loadDatasets.swift @@ -0,0 +1,37 @@ +// +// loadDatasets.swift +// ECG +// +// Created by Dave Fernandes on 2019-03-05. +// Copyright © 2019 MintLeaf Software Inc. All rights reserved. +// + +import Python +import TensorFlow + +struct Example: TensorGroup { + var labels: Tensor<Int32> + var series: Tensor<Float> +} + +func loadTimeSeries(from path: String) -> Example { + let pickle = Python.import("pickle") + let file = Python.open(path, "rb") + let pyDict = pickle.load(file, encoding: "bytes") + let dict = Dictionary<String, PythonObject>(pyDict) + + guard let series = dict?["x"], + let labels = dict?["y"] else { + fatalError() + } + + let labelsTensor = Tensor<Int64>(numpy: labels)! + let seriesTensor = Tensor<Float64>(numpy: series)! + return Example(labels: Tensor<Int32>(labelsTensor), series: Tensor<Float32>(seriesTensor)) +} + +func loadDatasets() -> (training: Dataset<Example>, test: Dataset<Example>) { + let trainingDataset = Dataset<Example>(elements: loadTimeSeries(from: "train_set.pickle")) + let testDataset = Dataset<Example>(elements: loadTimeSeries(from: "test_set.pickle")) + return (training: trainingDataset, test: testDataset) +}