|
a |
|
b/neuroqwerty-mit-csxpd-dataset-1.0.0/nqDataLoader.py |
|
|
1 |
# -*- coding: utf-8 -*- |
|
|
2 |
|
|
|
3 |
# set modules dir |
|
|
4 |
import numpy as np |
|
|
5 |
import sys, os, re, datetime |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
class NqDataLoader: |
|
|
9 |
FLT_NO_MOUSE = 1 << 0 |
|
|
10 |
FLT_NO_LETTERS = 1 << 1 |
|
|
11 |
FLT_NO_BACK = 1 << 2 |
|
|
12 |
FLT_NO_SHORT_META = 1 << 3 # space, enter, arrows, etc. |
|
|
13 |
FLT_NO_LONG_META = 1 << 4 # shift, control, alt, ect. |
|
|
14 |
FLT_NO_PUNCT = 1 << 5 |
|
|
15 |
|
|
|
16 |
def __init__(self): |
|
|
17 |
self.dataKeys = None |
|
|
18 |
self.dataHT = None |
|
|
19 |
self.dataTimeStart = None |
|
|
20 |
self.dataTimeEnd = None |
|
|
21 |
pass |
|
|
22 |
|
|
|
23 |
|
|
|
24 |
def sanityCheck( self ): |
|
|
25 |
""" |
|
|
26 |
Filter out keystrokes variables in the member variables. |
|
|
27 |
Eliminate anything < 0. |
|
|
28 |
returns the number of elements removed |
|
|
29 |
""" |
|
|
30 |
assert( self.dataKeys is not None and len(self.dataKeys) > 0 ) |
|
|
31 |
assert( self.dataHT is not None and len(self.dataHT) > 0 ) |
|
|
32 |
assert( self.dataTimeStart is not None and len(self.dataTimeStart) > 0 ) |
|
|
33 |
assert( self.dataTimeEnd is not None and len(self.dataTimeEnd) > 0 ) |
|
|
34 |
|
|
|
35 |
badLbl = self.dataTimeStart <= 0 |
|
|
36 |
badLbl = np.bitwise_or( badLbl, self.dataTimeEnd <= 0) |
|
|
37 |
badLbl = np.bitwise_or( badLbl, self.dataHT < 0) |
|
|
38 |
badLbl = np.bitwise_or( badLbl, self.dataHT >= 5) |
|
|
39 |
#----- remove non consecutive start times |
|
|
40 |
nonConsTmpLbl = np.zeros( len(self.dataTimeStart) ) == 0 # start with all True labels |
|
|
41 |
nonConsLbl = np.zeros( len(self.dataTimeStart) ) > 0 # start with all False labels |
|
|
42 |
startTmpArr = self.dataTimeStart.copy() |
|
|
43 |
while ( np.sum( nonConsTmpLbl ) > 0 ): |
|
|
44 |
# find non consecutive labels |
|
|
45 |
nonConsTmpLbl = np.append([False], np.diff(startTmpArr)<0) |
|
|
46 |
# keep track of the indeces to remove |
|
|
47 |
nonConsLbl = np.bitwise_or( nonConsLbl, nonConsTmpLbl) |
|
|
48 |
# changes value in the temporary array |
|
|
49 |
indecesToChange = np.arange(len(nonConsTmpLbl))[nonConsTmpLbl] |
|
|
50 |
startTmpArr[indecesToChange] = startTmpArr[indecesToChange-1] |
|
|
51 |
|
|
|
52 |
badLbl = np.bitwise_or( badLbl, nonConsLbl) |
|
|
53 |
#----- |
|
|
54 |
|
|
|
55 |
# invert bad labels |
|
|
56 |
goodLbl = np.bitwise_not(badLbl) |
|
|
57 |
|
|
|
58 |
self.dataKeys = self.dataKeys[goodLbl] |
|
|
59 |
self.dataHT = self.dataHT[goodLbl] |
|
|
60 |
self.dataTimeStart = self.dataTimeStart[goodLbl] |
|
|
61 |
self.dataTimeEnd = self.dataTimeEnd[goodLbl] |
|
|
62 |
|
|
|
63 |
|
|
|
64 |
return sum(badLbl) |
|
|
65 |
|
|
|
66 |
def loadDataFile(self, fileIn, autoFilt=True, impType=None, debug=False): |
|
|
67 |
""" |
|
|
68 |
Load raw data file |
|
|
69 |
""" |
|
|
70 |
errorStr = '' |
|
|
71 |
try: |
|
|
72 |
data = [] |
|
|
73 |
|
|
|
74 |
# if data.dtype == np.int64: # Sleep inertia format |
|
|
75 |
if impType =='si': |
|
|
76 |
data = np.genfromtxt(fileIn, dtype=long, delimiter=',', skip_header=0) |
|
|
77 |
data = data - data.min() |
|
|
78 |
data = data.astype(np.float64) / 1000 |
|
|
79 |
self.dataTimeStart = data[:,0] |
|
|
80 |
self.dataTimeEnd = data[:,1] |
|
|
81 |
self.dataHT = self.dataTimeEnd - self.dataTimeStart |
|
|
82 |
#TO REMOVE |
|
|
83 |
self.dataKeys = np.zeros(len(self.dataHT))#Just to make sanity work |
|
|
84 |
remNum = self.sanityCheck() |
|
|
85 |
#print remNum |
|
|
86 |
else: # PD format |
|
|
87 |
data = np.genfromtxt(fileIn, dtype=None, delimiter=',', skip_header=0) |
|
|
88 |
# load |
|
|
89 |
self.dataKeys = data['f0'] |
|
|
90 |
self.dataHT = data['f1'] |
|
|
91 |
self.dataTimeStart = data['f3'] #No CHANGED 2<->3 |
|
|
92 |
self.dataTimeEnd = data['f2'] |
|
|
93 |
remNum = self.sanityCheck() |
|
|
94 |
#print '{:}, {:} %'.format( remNum, 1.0*remNum/len(self.dataHT) ) |
|
|
95 |
|
|
|
96 |
if (debug): |
|
|
97 |
print 'removed ', str(remNum), ' elements' |
|
|
98 |
|
|
|
99 |
if( autoFilt ): |
|
|
100 |
self.filtData(self.FLT_NO_MOUSE | self.FLT_NO_LONG_META ) |
|
|
101 |
|
|
|
102 |
# load flight time |
|
|
103 |
self.dataFT = np.array([ self.dataTimeStart[i]-self.dataTimeStart[i-1] for i in range(1,self.dataTimeStart.size) ]) |
|
|
104 |
self.dataFT = np.append(self.dataFT, 0) |
|
|
105 |
|
|
|
106 |
|
|
|
107 |
|
|
|
108 |
return True |
|
|
109 |
except IOError: |
|
|
110 |
errorStr = 'file {:s} not found'.format(fileIn) |
|
|
111 |
return errorStr |
|
|
112 |
def loadDataArr(self, lstArr): |
|
|
113 |
self.dataKeys = np.zeros((len(lstArr),1), dtype='S30') |
|
|
114 |
self.dataHT = np.zeros((len(lstArr),1)) |
|
|
115 |
self.dataTimeStart = np.zeros((len(lstArr),1)) |
|
|
116 |
self.dataTimeEnd =np.zeros((len(lstArr),1)) |
|
|
117 |
i = 0 |
|
|
118 |
for row in lstArr: |
|
|
119 |
tok = row.split(',') |
|
|
120 |
self.dataKeys[i] = str(tok[0]) |
|
|
121 |
self.dataHT[i] = str(tok[1]) |
|
|
122 |
self.dataTimeStart[i] = str(tok[2]) |
|
|
123 |
self.dataTimeEnd[i] = str(tok[3]) |
|
|
124 |
i += 1 |
|
|
125 |
|
|
|
126 |
#self.loadDataFile(lstArr.toString()) |
|
|
127 |
|
|
|
128 |
|
|
|
129 |
def filtData(self, flags): |
|
|
130 |
""" |
|
|
131 |
Filter data |
|
|
132 |
return (fltKeys, fltHT, fltTimeStart, fltTimeEnd) |
|
|
133 |
""" |
|
|
134 |
#-- filters |
|
|
135 |
pMouse=re.compile('("mouse.+")') |
|
|
136 |
pChar=re.compile('(".{1}")') |
|
|
137 |
pBack=re.compile('("BackSpace")') |
|
|
138 |
pLongMeta=re.compile('("Shift.+")|("Alt.+")|("Control.+")') |
|
|
139 |
pShortMeta=re.compile('("space")|("Num_Lock")|("Return")|("P_Enter")|("Caps_Lock")|("Left")|("Right")|("Up")|("Down")') |
|
|
140 |
pPunct=re.compile('("more")|("less")|("exclamdown")|("comma")|("\[65027\]")|("\[65105\]")|("ntilde")|("minus")|("equal")|("bracketleft")|("bracketright")|("semicolon")|("backslash")|("apostrophe")|("comma")|("period")|("slash")|("grave")') |
|
|
141 |
#-- |
|
|
142 |
|
|
|
143 |
#-- create mask labels |
|
|
144 |
lbl = np.ones(len( self.dataKeys ))==1 |
|
|
145 |
if( flags & self.FLT_NO_MOUSE ): |
|
|
146 |
lblTmp = [ pMouse.match( k ) is None for k in self.dataKeys] |
|
|
147 |
lbl = lbl & lblTmp |
|
|
148 |
if( flags & self.FLT_NO_LETTERS ): |
|
|
149 |
lblTmp = [ pChar.match( k ) is None for k in self.dataKeys] |
|
|
150 |
lbl = lbl & lblTmp |
|
|
151 |
if( flags & self.FLT_NO_BACK ): |
|
|
152 |
lblTmp = [ pBack.match( k ) is None for k in self.dataKeys] |
|
|
153 |
lbl = lbl & lblTmp |
|
|
154 |
if( flags & self.FLT_NO_SHORT_META ): |
|
|
155 |
lblTmp = [ pShortMeta.match( k ) is None for k in self.dataKeys] |
|
|
156 |
lbl = lbl & lblTmp |
|
|
157 |
if( flags & self.FLT_NO_LONG_META ): |
|
|
158 |
lblTmp = [ pLongMeta.match( k ) is None for k in self.dataKeys] |
|
|
159 |
lbl = lbl & lblTmp |
|
|
160 |
if( flags & self.FLT_NO_PUNCT ): |
|
|
161 |
lblTmp = [ pPunct.match( k ) is None for k in self.dataKeys] |
|
|
162 |
lbl = lbl & lblTmp |
|
|
163 |
#-- |
|
|
164 |
|
|
|
165 |
self.lbl = lbl |
|
|
166 |
|
|
|
167 |
self.dataKeys = self.dataKeys[lbl] |
|
|
168 |
self.dataHT = self.dataHT[lbl] |
|
|
169 |
self.dataTimeStart = self.dataTimeStart[lbl] |
|
|
170 |
self.dataTimeEnd = self.dataTimeEnd[lbl] |
|
|
171 |
|
|
|
172 |
def getStdVariablesFilt( fileIn, impType=None ): |
|
|
173 |
""" |
|
|
174 |
Receives as parameter the location of the raw typing file |
|
|
175 |
Return filtered variables (i.e. no mouse clicks, no long meta buttons, no backspaces) |
|
|
176 |
format returned (array of keys, array of hold times, array of press events timestamps, array of release events timestamps ) |
|
|
177 |
""" |
|
|
178 |
nqObj = self |
|
|
179 |
res = nqObj.loadDataFile( fileIn, False, impType) |
|
|
180 |
# remove delete button |
|
|
181 |
nqObj.filtData(nqObj.FLT_NO_MOUSE | nqObj.FLT_NO_LONG_META | nqObj.FLT_NO_BACK ) |
|
|
182 |
assert(res==True) # make sure the file exists |
|
|
183 |
dataKeys = nqObj.dataKeys |
|
|
184 |
dataHT = nqObj.dataHT |
|
|
185 |
dataTimeStart = nqObj.dataTimeStart |
|
|
186 |
dataTimeEnd = nqObj.dataTimeEnd |
|
|
187 |
|
|
|
188 |
return dataKeys, dataHT, dataTimeStart, dataTimeEnd |
|
|
189 |
|
|
|
190 |
|
|
|
191 |
def getDataFiltHelper( fileIn, impType=None ): |
|
|
192 |
""" |
|
|
193 |
Helper method to load filtered keypress data from given file |
|
|
194 |
:param fileIn: path to csv keypress file |
|
|
195 |
:param impType: format of the csv file ('si': for sleep inertia data, None for PD data) |
|
|
196 |
:return: list of array with dataKeys, dataHT, dataTimeStart, dataTimeEnd |
|
|
197 |
""" |
|
|
198 |
nqObj = NqDataLoader() |
|
|
199 |
res = nqObj.loadDataFile( fileIn, False, impType) |
|
|
200 |
# remove delete button |
|
|
201 |
nqObj.filtData(nqObj.FLT_NO_MOUSE | nqObj.FLT_NO_LONG_META | nqObj.FLT_NO_BACK ) |
|
|
202 |
assert(res==True) # make sure the file exists |
|
|
203 |
dataKeys = nqObj.dataKeys |
|
|
204 |
dataHT = nqObj.dataHT |
|
|
205 |
dataTimeStart = nqObj.dataTimeStart |
|
|
206 |
dataTimeEnd = nqObj.dataTimeEnd |
|
|
207 |
|
|
|
208 |
return dataKeys, dataHT, dataTimeStart, dataTimeEnd |
|
|
209 |
|
|
|
210 |
|
|
|
211 |
def genFileStruct( dataDir, maxRepNum=4 ): |
|
|
212 |
''' |
|
|
213 |
Generate a dictionary with the NQ file list and test date (legacy method) |
|
|
214 |
:param dataDir: base directory containing the CSV files |
|
|
215 |
:param maxRepNum: integer with the maximum repetition number |
|
|
216 |
:return: two dictionaries: fMap, dateMap = NQ file/date list[pID][repID][expID] |
|
|
217 |
''' |
|
|
218 |
fMap = {} # data container |
|
|
219 |
dateMap = {} |
|
|
220 |
files = os.listdir( dataDir ) |
|
|
221 |
p = re.compile( '([0-9]+)\.{1}([0-9]+)_([0-9]+)_([0-9]+)\.csv' ) |
|
|
222 |
for f in files: |
|
|
223 |
m = p.match( f ) |
|
|
224 |
|
|
|
225 |
if( m ): # file found |
|
|
226 |
timeStamp = m.group(1) |
|
|
227 |
pID = int(m.group(2)) |
|
|
228 |
repID = int(m.group(3)) |
|
|
229 |
expID = int(m.group(4)) |
|
|
230 |
# store new patient |
|
|
231 |
if( not fMap.has_key(pID) ): |
|
|
232 |
fMap[pID] = {} |
|
|
233 |
dateMap[pID] = {} |
|
|
234 |
for tmpRid in range(1, maxRepNum+1): |
|
|
235 |
fMap[pID][tmpRid] = {} |
|
|
236 |
dateMap[pID][tmpRid] = {} |
|
|
237 |
# fMap[pID] = {1: {}, 2: {}, 3: {}, 4:{}} |
|
|
238 |
# store data |
|
|
239 |
fMap[pID][repID][expID] = dataDir + f |
|
|
240 |
dateMap[pID][repID][expID] = datetime.datetime.fromtimestamp(int(timeStamp)) |
|
|
241 |
else: |
|
|
242 |
print f, ' no' |
|
|
243 |
|
|
|
244 |
return fMap, dateMap |