--- a +++ b/code/utils.py @@ -0,0 +1,44 @@ +import h5py +import numpy as np +import random + +WINDOW_SIZE = 100 + +def rescale_array(X): + X = X / 20 + X = np.clip(X, -5, 5) + return X + + +def aug_X(X): + scale = 1 + np.random.uniform(-0.1, 0.1) + offset = np.random.uniform(-0.1, 0.1) + noise = np.random.normal(scale=0.05, size=X.shape) + X = scale * X + offset + noise + return X + +def gen(dict_files, aug=False): + while True: + record_name = random.choice(list(dict_files.keys())) + batch_data = dict_files[record_name] + all_rows = batch_data['x'] + + for i in range(10): + start_index = random.choice(range(all_rows.shape[0]-WINDOW_SIZE)) + + X = all_rows[start_index:start_index+WINDOW_SIZE, ...] + Y = batch_data['y'][start_index:start_index+WINDOW_SIZE] + + X = np.expand_dims(X, 0) + Y = np.expand_dims(Y, -1) + Y = np.expand_dims(Y, 0) + + if aug: + X = aug_X(X) + X = rescale_array(X) + + yield X, Y + + +def chunker(seq, size=WINDOW_SIZE): + return (seq[pos:pos + size] for pos in range(0, len(seq), size)) \ No newline at end of file