a b/code/data_provider.py
1
import numpy as np
2
import os
3
# import tensorflow as tf
4
import h5py
5
6
7
def load_h5_all(file, is_training):
    """Load all datasets of an HDF5 file into memory.

    Args:
        file: Path of the HDF5 file to read. The file must contain the
            datasets ``label``, ``feature``, ``gene_name`` and ``sample``.
        is_training: Currently unused; kept so existing callers keep working.

    Returns:
        A tuple ``(feat, label, gene, sample)`` with the full contents of the
        ``feature``, ``label``, ``gene_name`` and ``sample`` datasets, in
        that order. No train/test split is applied.

    Errors raised by h5py (for example when the file or one of the datasets
    is missing) propagate to the caller unchanged.
    """
    # Open read-only: nothing is written, so 'r+' would needlessly require
    # write permission. The context manager closes the handle on exit.
    with h5py.File(file, 'r') as hf:
        # Slicing with [:] reads each dataset fully into memory, so the
        # returned values remain usable after the file has been closed.
        label = hf['label'][:]
        feat = hf['feature'][:]
        gene = hf['gene_name'][:]
        sample = hf['sample'][:]

    # Interpolate the path; the placeholder used to be printed literally.
    print('%s has data:' % (file,), feat.shape)

    return feat, label, gene, sample
22
23
24
25
if __name__ == '__main__':
    # Script entry point: load every dataset from tcga.h5 (the path is
    # resolved relative to the current working directory).
    m_rna, label, gene, sample_id = load_h5_all(
        file='../data_process/tcga.h5',
        is_training=True,
    )
27