Diff of /LoadBatches1D.py [000000] .. [eaa663]

Switch to unified view

a b/LoadBatches1D.py
1
# -*- coding: utf-8 -*-
2
"""
3
Created on Sun Apr 21 13:46:44 2019
4
5
@author: Winham
6
7
LoadBatches1D.py: 迭代生成训练时的batch
8
9
实现参考:https://github.com/divamgupta/image-segmentation-keras/blob/master/LoadBatches.py
10
11
"""
12
13
import os
14
import itertools
15
import numpy as np
16
from sklearn import preprocessing as prep
17
18
19
def getSigArr(path, sigNorm='scale'):
20
    sig = np.load(path)
21
    if sigNorm == 'scale':
22
        sig = prep.scale(sig)
23
    elif sigNorm == 'minmax':
24
        min_max_scaler = prep.MinMaxScaler()
25
        sig = min_max_scaler.fit_transform(sig)
26
    return np.expand_dims(sig, axis=1)
27
28
29
def getSegmentationArr(path, nClasses=3, output_length=1800, class_value=[0, 0.5, 1]):
30
    # class_value是在generate_labels.py中定义的,背景0,正常0.5,早搏1,必须要保持一致
31
    seg_labels = np.zeros([output_length, nClasses])
32
    seg = np.load(path)
33
    for i in range(nClasses):
34
        seg_labels[:, i] = (seg == class_value[i]).astype(float)
35
    return seg_labels
36
37
38
def SigSegmentationGenerator(sigs_path, segs_path, batch_size, n_classes, output_length=1800):
39
    sigs = os.listdir(sigs_path)
40
    segmentations = os.listdir(segs_path)
41
    sigs.sort()
42
    segmentations.sort()
43
    for i in range(len(sigs)):
44
        sigs[i] = sigs_path + sigs[i]
45
        segmentations[i] = segs_path + segmentations[i]
46
    assert len(sigs) == len(segmentations)
47
    for sig, seg in zip(sigs, segmentations):
48
        assert (sig.split('/')[-1].split(".")[0] == seg.split('/')[-1].split(".")[0])
49
    zipped = itertools.cycle(zip(sigs, segmentations))
50
    while True:
51
        X = []
52
        Y = []
53
        for _ in range(batch_size):
54
            sig, seg = next(zipped)
55
            X.append(getSigArr(sig))
56
            Y.append(getSegmentationArr(seg, n_classes, output_length))
57
        yield np.array(X), np.array(Y)