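# -*- coding: utf-8 -*-
"""
Build per-sample infos for the validation and test sets and save them as .npy.

Validation rows hold the labels, subject and series of each sample
(``infos_val.npy``); test rows hold the id, subject and series of each sample
(``infos_test.npy``).

Created on Thu Aug 13 21:35:28 2015.

@author: fornax
"""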
from glob import glob
import re

import numpy as np
import pandas as pd
from mne import concatenate_raws

from preprocessing.aux import creat_mne_raw_object
# ---- accumulators --------------------------------------------------------
# The per-subject loop appends one entry per subject to each list below; the
# entries are concatenated and saved at the end of the script.
subjects = range(1, 13)  # subject ids 1..12

# validation set: label matrices, subject ids, series numbers
lbls_tot, subjects_val_tot, series_val_tot = [], [], []

# test set: sample ids, subject ids, series numbers
ids_tot, subjects_test_tot, series_test_tot = [], [], []
# #### collect per-subject infos #####
def _series_number(fname):
    """Return the series number encoded in a data file name.

    e.g. ``'data/test/subj1_series10_data.csv'`` -> ``10``.
    Raises ValueError if the name contains no ``series<N>_`` part.
    """
    match = re.search(r'series(\d+)_', fname)
    if match is None:
        raise ValueError('no series number in file name %r' % fname)
    return int(match.group(1))


for subject in subjects:
    print('Loading data for subject %d...' % subject)
    # ############### READ DATA ###############################################
    # Keep the plain lexicographic sort: every array saved below follows this
    # file order (for the test set it puts series10 before series9).
    fnames = sorted(glob('data/train/subj%d_series*_data.csv' % subject))
    fnames_val = fnames[-2:]  # the last two training series (7 & 8)
    fnames_test = sorted(glob('data/test/subj%d_series*_data.csv' % subject))
    if not fnames_val or not fnames_test:
        raise IOError('missing data files for subject %d' % subject)

    # Load each file exactly once. Per-file sample counts are recorded before
    # concatenation because concatenate_raws grows the first raw in place.
    raws_val = [creat_mne_raw_object(fname, read_events=True)
                for fname in fnames_val]
    raws_test = [creat_mne_raw_object(fname, read_events=False)
                 for fname in fnames_test]
    n_val = [raw.n_times for raw in raws_val]
    n_test = [raw.n_times for raw in raws_test]
    raw_val = concatenate_raws(raws_val)
    raw_test = concatenate_raws(raws_test)

    # ---- validation infos: labels, subject and series per sample ----
    # NOTE(review): assumes rows 32+ of the data are the event channels added
    # by read_events=True -- confirm against creat_mne_raw_object.
    labels = raw_val._data[32:]
    lbls_tot.append(labels.transpose())
    # series number of each sample, in the same file order as raw_val
    series_val_tot.append(
        np.repeat([_series_number(f) for f in fnames_val], n_val))
    subjects_val_tot.append(np.array([subject] * labels.shape[1]))

    # ---- test infos: id, subject and series per sample ----
    ids = np.concatenate([np.array(pd.read_csv(fname, usecols=['id'])['id'])
                          for fname in fnames_test])
    ids_tot.append(ids)
    # series number of each sample, in the same file order as raw_test/ids
    series_test_tot.append(
        np.repeat([_series_number(f) for f in fnames_test], n_test))
    subjects_test_tot.append(np.array([subject] * raw_test.n_times))
# ---- save validation infos ----
# Flatten the per-subject pieces into single arrays, then stack them as
# columns (label columns, then subject, then series): one row per sample.
lbls_tot, subjects_val_tot, series_val_tot = map(
    np.concatenate, (lbls_tot, subjects_val_tot, series_val_tot))
toSave = np.c_[lbls_tot, subjects_val_tot, series_val_tot]
np.save('infos_val.npy', toSave)
# ---- save test infos ----
# Flatten the per-subject pieces into single arrays, then stack them as
# columns (id, subject, series): one row per sample.
ids_tot, subjects_test_tot, series_test_tot = map(
    np.concatenate, (ids_tot, subjects_test_tot, series_test_tot))
toSave = np.c_[ids_tot, subjects_test_tot, series_test_tot]
np.save('infos_test.npy', toSave)