Switch to unified view

a b/neuroqwerty-mit-csxpd-dataset-1.0.0/nqDataLoader.py
1
# -*- coding: utf-8 -*-
2
3
# set modules  dir
4
import numpy as np
5
import sys, os, re, datetime
6
7
8
class NqDataLoader:
9
    FLT_NO_MOUSE = 1 << 0
10
    FLT_NO_LETTERS = 1 << 1
11
    FLT_NO_BACK = 1 << 2
12
    FLT_NO_SHORT_META = 1 << 3    # space, enter, arrows, etc.
13
    FLT_NO_LONG_META = 1 << 4 # shift, control, alt, ect.
14
    FLT_NO_PUNCT = 1 << 5
15
    
16
    def __init__(self):
17
        self.dataKeys = None
18
        self.dataHT = None
19
        self.dataTimeStart = None
20
        self.dataTimeEnd = None
21
        pass
22
    
23
24
    def sanityCheck( self ):
25
        """
26
        Filter out keystrokes variables in the member variables. 
27
        Eliminate anything < 0.
28
        returns the number of elements removed
29
        """
30
        assert( self.dataKeys is not None and len(self.dataKeys) > 0 )
31
        assert( self.dataHT is not None and len(self.dataHT) > 0 )
32
        assert( self.dataTimeStart is not None and len(self.dataTimeStart) > 0 )
33
        assert( self.dataTimeEnd is not None and len(self.dataTimeEnd) > 0 )
34
        
35
        badLbl = self.dataTimeStart <= 0
36
        badLbl = np.bitwise_or( badLbl,  self.dataTimeEnd <= 0)
37
        badLbl = np.bitwise_or( badLbl,  self.dataHT < 0)
38
        badLbl = np.bitwise_or( badLbl,  self.dataHT >= 5)
39
        #----- remove non consecutive start times
40
        nonConsTmpLbl = np.zeros( len(self.dataTimeStart) ) == 0 # start with all True labels
41
        nonConsLbl = np.zeros( len(self.dataTimeStart) ) > 0 # start with all False labels
42
        startTmpArr = self.dataTimeStart.copy()
43
        while ( np.sum( nonConsTmpLbl ) > 0 ):
44
            # find non consecutive labels
45
            nonConsTmpLbl = np.append([False], np.diff(startTmpArr)<0)                
46
            # keep track of the indeces to remove
47
            nonConsLbl = np.bitwise_or( nonConsLbl,  nonConsTmpLbl)
48
               # changes value in the temporary array
49
            indecesToChange = np.arange(len(nonConsTmpLbl))[nonConsTmpLbl]
50
            startTmpArr[indecesToChange] = startTmpArr[indecesToChange-1]
51
52
        badLbl = np.bitwise_or( badLbl,  nonConsLbl)
53
        #-----
54
        
55
        # invert bad labels
56
        goodLbl = np.bitwise_not(badLbl)
57
        
58
        self.dataKeys = self.dataKeys[goodLbl]
59
        self.dataHT = self.dataHT[goodLbl]
60
        self.dataTimeStart = self.dataTimeStart[goodLbl]
61
        self.dataTimeEnd = self.dataTimeEnd[goodLbl]
62
             
63
        
64
        return sum(badLbl)
65
    
66
    def loadDataFile(self, fileIn, autoFilt=True, impType=None, debug=False):  
67
        """
68
        Load raw data file
69
        """      
70
        errorStr = ''
71
        try:
72
            data = []
73
            
74
#            if data.dtype == np.int64: # Sleep inertia format
75
            if impType =='si':
76
                data = np.genfromtxt(fileIn, dtype=long, delimiter=',', skip_header=0)
77
                data = data - data.min()
78
                data = data.astype(np.float64) / 1000
79
                self.dataTimeStart = data[:,0]  
80
                self.dataTimeEnd = data[:,1]
81
                self.dataHT = self.dataTimeEnd - self.dataTimeStart
82
                #TO REMOVE
83
                self.dataKeys = np.zeros(len(self.dataHT))#Just to make sanity work
84
                remNum = self.sanityCheck()
85
                #print remNum
86
            else: # PD format
87
                data = np.genfromtxt(fileIn, dtype=None, delimiter=',', skip_header=0)
88
                # load
89
                self.dataKeys = data['f0']
90
                self.dataHT = data['f1']  
91
                self.dataTimeStart = data['f3']  #No CHANGED 2<->3
92
                self.dataTimeEnd = data['f2']
93
                remNum = self.sanityCheck()
94
                #print '{:}, {:} %'.format( remNum, 1.0*remNum/len(self.dataHT) )
95
                
96
                if (debug):
97
                    print 'removed ', str(remNum), ' elements'
98
99
                if( autoFilt ):
100
                    self.filtData(self.FLT_NO_MOUSE  | self.FLT_NO_LONG_META )
101
            
102
            # load flight time
103
            self.dataFT = np.array([ self.dataTimeStart[i]-self.dataTimeStart[i-1]  for i in range(1,self.dataTimeStart.size) ])
104
            self.dataFT = np.append(self.dataFT, 0)
105
            
106
            
107
            
108
            return True
109
        except IOError:
110
            errorStr = 'file {:s} not found'.format(fileIn)
111
            return errorStr
112
    def loadDataArr(self, lstArr):
113
        self.dataKeys = np.zeros((len(lstArr),1), dtype='S30')
114
        self.dataHT = np.zeros((len(lstArr),1))
115
        self.dataTimeStart = np.zeros((len(lstArr),1))  
116
        self.dataTimeEnd =np.zeros((len(lstArr),1))
117
        i = 0
118
        for row in lstArr:
119
            tok = row.split(',')
120
            self.dataKeys[i] = str(tok[0])
121
            self.dataHT[i] = str(tok[1])
122
            self.dataTimeStart[i] = str(tok[2])
123
            self.dataTimeEnd[i] = str(tok[3]) 
124
            i += 1
125
            
126
        #self.loadDataFile(lstArr.toString())
127
    
128
129
    def filtData(self, flags):
130
        """
131
        Filter data
132
        return (fltKeys, fltHT, fltTimeStart, fltTimeEnd)
133
        """
134
        #-- filters
135
        pMouse=re.compile('("mouse.+")')
136
        pChar=re.compile('(".{1}")')
137
        pBack=re.compile('("BackSpace")')
138
        pLongMeta=re.compile('("Shift.+")|("Alt.+")|("Control.+")')
139
        pShortMeta=re.compile('("space")|("Num_Lock")|("Return")|("P_Enter")|("Caps_Lock")|("Left")|("Right")|("Up")|("Down")')
140
        pPunct=re.compile('("more")|("less")|("exclamdown")|("comma")|("\[65027\]")|("\[65105\]")|("ntilde")|("minus")|("equal")|("bracketleft")|("bracketright")|("semicolon")|("backslash")|("apostrophe")|("comma")|("period")|("slash")|("grave")')
141
        #--
142
143
        #-- create mask labels        
144
        lbl = np.ones(len( self.dataKeys ))==1
145
        if( flags & self.FLT_NO_MOUSE ):
146
            lblTmp = [ pMouse.match( k ) is None for k in self.dataKeys]
147
            lbl = lbl & lblTmp
148
        if( flags & self.FLT_NO_LETTERS ):
149
            lblTmp = [ pChar.match( k ) is None for k in self.dataKeys]
150
            lbl = lbl & lblTmp
151
        if( flags & self.FLT_NO_BACK ):
152
            lblTmp = [ pBack.match( k ) is None for k in self.dataKeys]
153
            lbl = lbl & lblTmp
154
        if( flags & self.FLT_NO_SHORT_META ):
155
            lblTmp = [ pShortMeta.match( k ) is None for k in self.dataKeys]
156
            lbl = lbl & lblTmp
157
        if( flags & self.FLT_NO_LONG_META ):
158
            lblTmp = [ pLongMeta.match( k ) is None for k in self.dataKeys]
159
            lbl = lbl & lblTmp
160
        if( flags & self.FLT_NO_PUNCT ):
161
            lblTmp = [ pPunct.match( k ) is None for k in self.dataKeys]
162
            lbl = lbl & lblTmp
163
        #--
164
        
165
        self.lbl = lbl        
166
        
167
        self.dataKeys = self.dataKeys[lbl]
168
        self.dataHT = self.dataHT[lbl]
169
        self.dataTimeStart = self.dataTimeStart[lbl]
170
        self.dataTimeEnd = self.dataTimeEnd[lbl]        
171
        
172
    def getStdVariablesFilt( fileIn, impType=None ):
173
        """
174
        Receives as parameter the location of the raw typing file
175
        Return filtered variables (i.e. no mouse clicks, no long meta buttons, no backspaces) 
176
        format returned (array of keys, array of hold times, array of press events timestamps, array of release events timestamps )
177
        """
178
        nqObj = self
179
        res = nqObj.loadDataFile( fileIn, False, impType)
180
        # remove delete button
181
        nqObj.filtData(nqObj.FLT_NO_MOUSE  | nqObj.FLT_NO_LONG_META | nqObj.FLT_NO_BACK )
182
        assert(res==True) # make sure the file exists
183
        dataKeys = nqObj.dataKeys
184
        dataHT = nqObj.dataHT
185
        dataTimeStart = nqObj.dataTimeStart
186
        dataTimeEnd = nqObj.dataTimeEnd
187
        
188
        return dataKeys, dataHT, dataTimeStart, dataTimeEnd
189
190
191
def getDataFiltHelper( fileIn, impType=None ):
192
    """
193
    Helper method to load filtered keypress data from given file
194
    :param fileIn: path to csv keypress file 
195
    :param impType: format of the csv file ('si': for sleep inertia data, None for PD data)
196
    :return: list of array with dataKeys, dataHT, dataTimeStart, dataTimeEnd
197
    """
198
    nqObj = NqDataLoader()
199
    res = nqObj.loadDataFile( fileIn, False, impType)
200
    # remove delete button
201
    nqObj.filtData(nqObj.FLT_NO_MOUSE  | nqObj.FLT_NO_LONG_META | nqObj.FLT_NO_BACK )
202
    assert(res==True) # make sure the file exists
203
    dataKeys = nqObj.dataKeys
204
    dataHT = nqObj.dataHT
205
    dataTimeStart = nqObj.dataTimeStart
206
    dataTimeEnd = nqObj.dataTimeEnd
207
    
208
    return dataKeys, dataHT, dataTimeStart, dataTimeEnd
209
    
210
    
211
def genFileStruct( dataDir, maxRepNum=4 ):
212
    '''
213
    Generate a dictionary with the NQ file list and test date (legacy method)
214
    :param dataDir: base directory containing the CSV files
215
    :param maxRepNum: integer with the maximum repetition number
216
    :return: two dictionaries: fMap, dateMap = NQ file/date list[pID][repID][expID]
217
    '''
218
    fMap = {} # data container
219
    dateMap = {}
220
    files = os.listdir( dataDir )    
221
    p = re.compile( '([0-9]+)\.{1}([0-9]+)_([0-9]+)_([0-9]+)\.csv' )
222
    for f in files:
223
        m = p.match( f )
224
        
225
        if( m ): # file found
226
            timeStamp = m.group(1)
227
            pID = int(m.group(2))
228
            repID = int(m.group(3))
229
            expID = int(m.group(4))
230
            # store new patient
231
            if( not fMap.has_key(pID) ):
232
                fMap[pID] = {}
233
                dateMap[pID] = {}
234
                for tmpRid in range(1, maxRepNum+1):
235
                    fMap[pID][tmpRid] = {}
236
                    dateMap[pID][tmpRid] = {}
237
                # fMap[pID] = {1: {}, 2: {}, 3: {}, 4:{}}
238
            # store data
239
            fMap[pID][repID][expID] = dataDir + f
240
            dateMap[pID][repID][expID] = datetime.datetime.fromtimestamp(int(timeStamp))
241
        else:
242
            print f, ' no'
243
            
244
    return fMap, dateMap