--- a
+++ b/data_process.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+'''
+@time: 2019/9/8 18:44
+Data preprocessing:
+    1. Build the label2index and index2label mappings
+    2. Split the dataset
+@ author: javis
+'''
+import os, torch
+import numpy as np
+from config import config
+
+# Fixed seed, intended to keep the data split identical on every run
+np.random.seed(41)
+
+
+def name2index(path):
+    '''
+    Map each class name to its zero-based line position in the label file.
+    :param path: path to a UTF-8 text file with one class name per line
+    :return: dict of stripped class name -> index; a repeated name keeps
+        the index of its last occurrence
+    '''
+    # The context manager closes the handle instead of leaking it.
+    with open(path, encoding='utf-8') as label_file:
+        names = [line.strip() for line in label_file]
+    return {name: i for i, name in enumerate(names)}
+
+
+def split_data(file2idx, val_ratio=0.1):
+    '''
+    Split the files of config.train_dir into training and validation lists.
+    Files are grouped per class and the leading val_ratio share of each group
+    goes to validation, at least one whenever the class has two or more files.
+    :param file2idx: dict of file name -> list of class indices
+    :param val_ratio: share of each class moved to validation; 0 disables it
+    :return: (train, val) as sorted lists so the order is stable across runs
+    '''
+    idx2file = [[] for _ in range(config.num_classes)]
+    for file, list_idx in file2idx.items():
+        for idx in list_idx:
+            idx2file[idx].append(file)
+    val = set()
+    for item in idx2file:
+        # int() truncates to 0 for small classes; never take a class's only file.
+        floor = min(1, len(item) - 1) if val_ratio > 0 else 0
+        val.update(item[:max(int(len(item) * val_ratio), floor)])
+    return sorted(set(os.listdir(config.train_dir)) - val), sorted(val)
+
+
+def file2index(path, name2idx):
+    '''
+    Read the label file and map every record id to its class indices.
+    :param path: path to a UTF-8, tab-separated label file; column 0 is the
+        id and columns 3 onward are class names (columns 1-2 are ignored)
+    :param name2idx: dict of class name -> class index
+    :return: dict of id -> list of class indices; raises KeyError for a
+        class name that is missing from name2idx
+    '''
+    # The context manager closes the handle instead of leaking it.
+    with open(path, encoding='utf-8') as label_file:
+        rows = [line.strip().split('\t') for line in label_file]
+    # A repeated id keeps the labels of its last row, as dict assignment does.
+    return {row[0]: [name2idx[name] for name in row[3:]] for row in rows}
+
+
+def count_labels(data, file2idx):
+    '''
+    Count how many of the given files carry each class label.
+    :param data: iterable of file names, each a key of file2idx
+    :param file2idx: dict of file name -> list of class indices
+    :return: np.ndarray of length config.num_classes with per-class counts
+    '''
+    counts = [0] * config.num_classes
+    for class_ids in (file2idx[name] for name in data):
+        for class_id in class_ids:
+            counts[class_id] += 1
+    return np.array(counts)
+
+
+def train(name2idx, idx2name):
+    '''Build the split and per-class counts, print the counts, save all to config.train_data.'''
+    file2idx = file2index(config.train_label, name2idx)
+    train_files, val_files = split_data(file2idx)
+    wc = count_labels(train_files, file2idx)
+    print(wc)
+    torch.save(dict(train=train_files, val=val_files, idx2name=idx2name, file2idx=file2idx, wc=wc), config.train_data)
+
+
+if __name__ == '__main__':
+    # Script entry point: build both label maps, then write the dataset file.
+    name2idx = name2index(config.arrythmia)
+    idx2name = dict((idx, name) for name, idx in name2idx.items())
+    train(name2idx, idx2name)