Diff of /utils.py [000000] .. [e6696a]

from numpy.random import seed
seed(1017)
from tensorflow import set_random_seed
set_random_seed(1017)

import os
from glob import glob
from collections import OrderedDict

import mne
from mne.io import RawArray
from mne import read_evokeds, read_source_spaces, compute_covariance
from mne import channels, find_events, concatenate_raws
from mne import pick_types, viz, io, Epochs, create_info
from mne import pick_channels, concatenate_epochs
from mne.datasets import sample
from mne.simulation import simulate_sparse_stc, simulate_raw
from mne.channels import read_montage
from mne.time_frequency import tfr_morlet

import numpy as np
from numpy import genfromtxt

import pandas as pd
pd.options.display.precision = 4
pd.options.display.max_columns = None

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (12,12)

import keras
from keras import regularizers
from keras.callbacks import TensorBoard
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Activation, Input
from keras.layers import Flatten, Conv2D, MaxPooling2D, LSTM
from keras.layers import BatchNormalization, Conv3D, MaxPooling3D

from sklearn.utils import class_weight
from sklearn.model_selection import train_test_split


class Feats:
  def __init__(self, num_classes=2, class_weights=[1,1], input_shape=[16,],
               new_times=1, model_type='1',
               x_train=1, y_train=1, x_test=1, y_test=1, x_val=1, y_val=1):
    self.num_classes = num_classes
    self.class_weights = class_weights
    self.input_shape = input_shape
    self.new_times = new_times
    self.model_type = model_type
    self.x_train = x_train
    self.y_train = y_train
    self.x_test = x_test
    self.y_test = y_test
    self.x_val = x_val
    self.y_val = y_val

def LoadBVData(sub, session, data_dir, exp):
  #for isub,sub in enumerate(subs):
  print('Loading data for subject number: ' + sub)
  fname = data_dir + exp + '/' + sub + '_' + exp + '_' + session + '.vhdr'
  raw, sfreq = loadBV(fname, plot_sensors=False, plot_raw=False,
                      plot_raw_psd=False, stim_channel=True)
  return raw

def loadBV(filename, plot_sensors=True, plot_raw=True,
           plot_raw_psd=True, stim_channel=False):
  """Load BrainVision Recorder data files (.vhdr)."""

  #load .vhdr files from brain vision recorder
  raw = io.read_raw_brainvision(filename,
                          montage='standard_1020',
                          eog=('HEOG', 'VEOG'),
                          preload=True, stim_channel=stim_channel)

  #read sampling rate
  sfreq = raw.info['sfreq']
  print('Sampling Rate = ' + str(sfreq))

  #plot channel locations
  print('Loading Channel Locations')
  if plot_sensors:
    raw.plot_sensors(show_names=True)

  ##Plot raw data
  if plot_raw:
    raw.plot(n_channels=16, block=True)

  #plot raw psd
  if plot_raw_psd:
    raw.plot_psd(fmin=.1, fmax=100)

  return raw, sfreq
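
# Hedged usage sketch (added for illustration, not part of the original module).
# The file layout and subject/experiment names below are hypothetical; LoadBVData
# expects <data_dir><exp>/<sub>_<exp>_<session>.vhdr to exist on disk.
def _example_load_brainvision():
  raw = LoadBVData(sub='001', session='1', data_dir='data/', exp='P3')
  return raw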


def LoadMuseData(subs, nsesh, data_dir, load_verbose=False, sfreq=256.):
  nsubs = len(subs)
  raw = []
  print('Loading Data')
  for isub, sub in enumerate(subs):
    print('Subject number ' + str(isub+1) + '/' + str(nsubs))
    for isesh in range(nsesh):
      print(' Session number ' + str(isesh+1) + '/' + str(nsesh))
      raw.append(muse_load_data(data_dir, sfreq=sfreq, subject_nb=sub,
                    session_nb=isesh+1, verbose=load_verbose))
  raw = concatenate_raws(raw)
  return raw
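
# Hedged usage sketch (illustrative only): load two subjects, two sessions each,
# of Muse CSV recordings. The 'visual/P300' directory name is hypothetical and
# must exist under eeg-notebooks_v0.1/data/ for this to run.
def _example_load_muse():
  raw = LoadMuseData(subs=[1, 2], nsesh=2, data_dir='visual/P300')
  return raw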


#from eeg-notebooks load_data
def muse_load_data(data_dir, subject_nb=1, session_nb=1, sfreq=256.,
                   ch_ind=[0, 1, 2, 3], stim_ind=5, replace_ch_names=None,
                   verbose=1):
    """Load CSV files from the /data directory into a Raw object.

    Args:
        data_dir (str): directory inside /data that contains the
            CSV files to load, e.g., 'auditory/P300'

    Keyword Args:
        subject_nb (int or str): subject number. If 'all', load all
            subjects.
        session_nb (int or str): session number. If 'all', load all
            sessions.
        sfreq (float): EEG sampling frequency
        ch_ind (list): indices of the EEG channels to keep
        stim_ind (int): index of the stim channel
        replace_ch_names (dict or None): dictionary containing a mapping to
            rename channels. Useful when an external electrode was used.

    Returns:
        (mne.io.array.array.RawArray): loaded EEG
    """

    if subject_nb == 'all':
        subject_nb = '*'
    if session_nb == 'all':
        session_nb = '*'

    data_path = os.path.join(
            'eeg-notebooks_v0.1/data', data_dir,
            'subject{}/session{}/*.csv'.format(subject_nb, session_nb))
    fnames = glob(data_path)

    return load_muse_csv_as_raw(fnames,
                                sfreq=sfreq,
                                ch_ind=ch_ind,
                                stim_ind=stim_ind,
                                replace_ch_names=replace_ch_names,
                                verbose=verbose)


#from eeg-notebooks
def load_muse_csv_as_raw(filename, sfreq=256., ch_ind=[0, 1, 2, 3],
                         stim_ind=5, replace_ch_names=None, verbose=1):
    """Load CSV files into a Raw object.

    Args:
        filename (str or list): path or paths to CSV files to load

    Keyword Args:
        sfreq (float): EEG sampling frequency
        ch_ind (list): indices of the EEG channels to keep
        stim_ind (int): index of the stim channel
        replace_ch_names (dict or None): dictionary containing a mapping to
            rename channels. Useful when an external electrode was used.

    Returns:
        (mne.io.array.array.RawArray): loaded EEG
    """

    n_channel = len(ch_ind)

    raw = []
    for fname in filename:
        # read the file
        data = pd.read_csv(fname, index_col=0)

        # name of each channel
        ch_names = list(data.columns)[0:n_channel] + ['Stim']

        if replace_ch_names is not None:
            ch_names = [c if c not in replace_ch_names.keys()
                        else replace_ch_names[c] for c in ch_names]

        # type of each channel
        ch_types = ['eeg'] * n_channel + ['stim']
        montage = read_montage('standard_1005')

        # get data and exclude Aux channel
        data = data.values[:, ch_ind + [stim_ind]].T

        # convert from microvolts to volts (the stim channel is left as-is)
        data[:-1] *= 1e-6

        # create MNE object
        info = create_info(ch_names=ch_names, ch_types=ch_types,
                           sfreq=sfreq, montage=montage, verbose=verbose)
        raw.append(RawArray(data=data, info=info, verbose=verbose))

    # concatenate all raw objects
    if len(raw) > 0:
      raws = concatenate_raws(raw, verbose=verbose)
    else:
      print('No files for subject with filename ' + str(filename))
      raws = raw

    return raws
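
# Hedged usage sketch (illustrative): pass an explicit list of CSV paths instead
# of the subject/session layout that muse_load_data() builds; the glob pattern
# below is hypothetical.
def _example_load_muse_csv():
  fnames = glob('eeg-notebooks_v0.1/data/visual/P300/subject1/session1/*.csv')
  return load_muse_csv_as_raw(fnames, sfreq=256., ch_ind=[0, 1, 2, 3], stim_ind=5)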


def SimulateRaw(amp1=50, amp2=100, freq=1., batch=1):
  """Create simulated raw data and events of two kinds.

  Keyword Args:
      amp1 (float): amplitude of the first condition effect
      amp2 (float): amplitude of the second condition effect;
          amp1 == amp2 simulates the null hypothesis
      freq (float): frequency of the simulated signal (1. for ERP, 10. for alpha)
      batch (int): number of groups of 255 trials in each condition
  Returns:
      raw: simulated EEG MNE raw object with two event types
      event_id: dict of the two events for input to PreProcess()
  """

  data_path = sample.data_path()
  raw_fname = data_path + '/MEG/sample/sample_audvis_raw.fif'
  trans_fname = data_path + '/MEG/sample/sample_audvis_raw-trans.fif'
  src_fname = data_path + '/subjects/sample/bem/sample-oct-6-src.fif'
  bem_fname = (data_path +
        '/subjects/sample/bem/sample-5120-5120-5120-bem-sol.fif')

  raw_single = mne.io.read_raw_fif(raw_fname, preload=True)
  raw_single.set_eeg_reference(projection=True)
  raw_single = raw_single.crop(0., 255.)
  raw_single = raw_single.copy().pick_types(meg=False, eeg=True, eog=True, stim=True)

  #concatenate `batch` copies of the 255 s raw to increase the number of trials
  raw = []
  for i in range(batch):
    raw.append(raw_single)
  raw = concatenate_raws(raw)

  epoch_duration = 1.

  def data_fun(amp, freq):
    """Create function to create fake signal"""
    def data_fun_inner(times):
      """Create fake signal with no noise"""
      n_samp = len(times)
      window = np.zeros(n_samp)
      start, stop = [int(ii * float(n_samp) / 2)
        for ii in (0, 1)]
      window[start:stop] = np.hamming(stop - start)
      data = amp * 1e-9 * np.sin(2. * np.pi * freq * times)
      data *= window
      return data
    return data_fun_inner

  times = raw.times[:int(raw.info['sfreq'] * epoch_duration)]
  src = read_source_spaces(src_fname)

  stc_zero = simulate_sparse_stc(src, n_dipoles=1, times=times,
              data_fun=data_fun(amp1, freq), random_state=0)
  stc_one = simulate_sparse_stc(src, n_dipoles=1, times=times,
              data_fun=data_fun(amp2, freq), random_state=0)

  raw_sim_zero = simulate_raw(raw, stc_zero, trans_fname, src, bem_fname,
            cov='simple', blink=True, n_jobs=1, verbose=True)
  raw_sim_one = simulate_raw(raw, stc_one, trans_fname, src, bem_fname,
            cov='simple', blink=True, n_jobs=1, verbose=True)

  #recode the triggers of the second simulation so the two conditions differ
  stim_pick = raw_sim_one.info['ch_names'].index('STI 014')
  raw_sim_one._data[stim_pick][np.where(raw_sim_one._data[stim_pick]==1)] = 2
  raw = concatenate_raws([raw_sim_zero, raw_sim_one])
  event_id = {'CondZero': 1, 'CondOne': 2}
  return raw, event_id
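
# Hedged usage sketch (illustrative): the MNE 'sample' dataset is downloaded by
# sample.data_path() on first use, and the forward simulation is slow; the
# parameter values here are examples only.
def _example_simulate():
  raw, event_id = SimulateRaw(amp1=50, amp2=100, freq=10., batch=1)
  return raw, event_id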


def mastoidReref(raw):
  #re-reference EEG to the average mastoid by subtracting half of M2
  #(assumes the data were recorded against the other mastoid)
  ref_idx = pick_channels(raw.info['ch_names'], ['M2'])
  eeg_idx = pick_types(raw.info, eeg=True)
  raw._data[eeg_idx,:] = raw._data[eeg_idx,:] - raw._data[ref_idx,:] * .5
  return raw

def GrattonEmcpRaw(raw):
  #regress the EOG channels out of the EEG over the whole continuous recording
  raw_eeg = raw.copy().pick_types(eeg=True)[:][0]
  raw_eog = raw.copy().pick_types(eog=True)[:][0]
  b = np.linalg.solve(np.dot(raw_eog, raw_eog.T), np.dot(raw_eog, raw_eeg.T))
  eeg_corrected = (raw_eeg.T - np.dot(raw_eog.T, b)).T
  raw_new = raw.copy()
  raw_new._data[pick_types(raw.info, eeg=True),:] = eeg_corrected
  return raw_new


def GrattonEmcpEpochs(epochs):
  '''
  # Correct EEG data for EOG artifacts with regression
  # INPUT - MNE epochs object (with eeg and eog channels)
  # OUTPUT - MNE epochs object (with eeg corrected)
  # After: Gratton, Coles, & Donchin, 1983
  # -compute the ERP in each condition
  # -subtract the ERP from each trial
  # -subtract the baseline (mean over the whole epoch)
  # -predict the eeg remainder from the eye channel remainder
  # -use the coefficients to subtract eog from eeg
  '''

  event_names = ['A_error','B_error']
  i = 0
  for key, value in sorted(epochs.event_id.items(), key=lambda x: (x[1], x[0])):
    event_names[i] = key
    i += 1

  #select the correct channels and data
  eeg_chans = pick_types(epochs.info, eeg=True, eog=False)
  eog_chans = pick_types(epochs.info, eeg=False, eog=True)
  original_data = epochs._data

  #subtract the average over trials from each trial
  rem = {}
  for event in event_names:
    data = epochs[event]._data
    avg = np.mean(epochs[event]._data, axis=0)
    rem[event] = data - avg

  #concatenate the trials of the different types
  ## then put them all back together in X (regression on all at once)
  allrem = np.concatenate([rem[event] for event in event_names])

  #separate eog and eeg
  X = allrem[:,eeg_chans,:]
  Y = allrem[:,eog_chans,:]

  #subtract mean over time from every trial/channel
  X = (X.T - np.mean(X,2).T).T
  Y = (Y.T - np.mean(Y,2).T).T

  #move electrodes first
  X = np.moveaxis(X,0,1)
  Y = np.moveaxis(Y,0,1)

  #make 2d and compute regression
  X = np.reshape(X,(X.shape[0],np.prod(X.shape[1:])))
  Y = np.reshape(Y,(Y.shape[0],np.prod(Y.shape[1:])))
  b = np.linalg.solve(np.dot(Y,Y.T), np.dot(Y,X.T))

  #get original data and electrodes first for matrix math
  raw_eeg = np.moveaxis(original_data[:,eeg_chans,:],0,1)
  raw_eog = np.moveaxis(original_data[:,eog_chans,:],0,1)

  #subtract weighted eye channels from eeg channels
  eeg_corrected = (raw_eeg.T - np.dot(raw_eog.T,b)).T

  #move back to match epochs
  eeg_corrected = np.moveaxis(eeg_corrected,0,1)

  #copy original epochs and replace with corrected data
  epochs_new = epochs.copy()
  epochs_new._data[:,eeg_chans,:] = eeg_corrected

  return epochs_new
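
# A minimal, self-contained sketch (added for illustration) of the regression
# step used above: estimate EOG-to-EEG propagation weights by least squares and
# subtract the weighted EOG from the EEG. Array sizes here are made up.
def _example_eog_regression():
  rng = np.random.RandomState(0)
  eog = rng.randn(2, 1000)                        # 2 EOG channels x samples
  eeg = 0.3 * eog[:1] + rng.randn(4, 1000)        # 4 EEG channels, contaminated
  b = np.linalg.solve(np.dot(eog, eog.T), np.dot(eog, eeg.T))
  eeg_corrected = (eeg.T - np.dot(eog.T, b)).T    # same algebra as GrattonEmcpRaw
  return eeg_corrected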


def PreProcess(raw, event_id, plot_psd=False, filter_data=True,
               filter_range=(1,30), plot_events=False, epoch_time=(-.2,1),
               baseline=(-.2,0), rej_thresh_uV=200, rereference=False,
               emcp_raw=False, emcp_epochs=False, epoch_decim=1,
               plot_electrodes=False, plot_erp=False):

  sfreq = raw.info['sfreq']
  #create new output freq for after epoch or wavelet decim
  nsfreq = sfreq/epoch_decim
  tmin = epoch_time[0]
  tmax = epoch_time[1]
  if filter_range[1] > nsfreq:
    #keep the low-pass comfortably below the post-decimation Nyquist to avoid aliasing
    filter_range = (filter_range[0], nsfreq/2.5)

  #pull event names in order of trigger number
  event_names = ['A_error','B_error']
  i = 0
  for key, value in sorted(event_id.items(), key=lambda x: (x[1], x[0])):
    event_names[i] = key
    i += 1

  #Rereferencing
  if rereference:
    print('Rereferencing to average mastoid')
    raw = mastoidReref(raw)

  #Filtering
  if filter_data:
    print('Filtering Data Between ' + str(filter_range[0]) +
            ' and ' + str(filter_range[1]) + ' Hz.')
    raw.filter(filter_range[0], filter_range[1],
               method='iir', verbose='WARNING')

  if plot_psd:
    raw.plot_psd(fmin=filter_range[0], fmax=nsfreq/2)

  #Eye Correction
  if emcp_raw:
    print('Raw Eye Movement Correction')
    raw = GrattonEmcpRaw(raw)

  #Epoching
  events = find_events(raw, shortest_event=1)
  color = {1: 'red', 2: 'black'}
  #artifact rejection threshold
  rej_thresh = rej_thresh_uV*1e-6

  #plot event timing
  if plot_events:
    viz.plot_events(events, sfreq, raw.first_samp, color=color,
                        event_id=event_id)

  #Construct epochs - main epoching function from MNE
  epochs = Epochs(raw, events=events, event_id=event_id,
                  tmin=tmin, tmax=tmax, baseline=baseline,
                  preload=True, reject={'eeg':rej_thresh},
                  verbose=False, decim=epoch_decim)
  print('Remaining Trials: ' + str(len(epochs)))

  #Gratton eye movement correction procedure on epochs
  if emcp_epochs:
    print('Epochs Eye Movement Correction')
    epochs = GrattonEmcpEpochs(epochs)

  ## plot ERP at each electrode
  evoked_dict = {event_names[0]: epochs[event_names[0]].average(),
                 event_names[1]: epochs[event_names[1]].average()}

  # butterfly plot
  if plot_electrodes:
    picks = pick_types(evoked_dict[event_names[0]].info, meg=False, eeg=True, eog=False)
    fig_zero = evoked_dict[event_names[0]].plot(spatial_colors=True, picks=picks)
    fig_one = evoked_dict[event_names[1]].plot(spatial_colors=True, picks=picks)

  # plot ERP in each condition on same plot
  if plot_erp:
    #find the electrode highest on the head (largest z coordinate)
    picks = np.argmax([evoked_dict[event_names[0]].info['chs'][i]['loc'][2]
              for i in range(len(evoked_dict[event_names[0]].info['chs']))])
    colors = {event_names[0]: "Red", event_names[1]: "Blue"}
    viz.plot_compare_evokeds(evoked_dict, colors=colors,
                            picks=picks, split_legend=True)

  return epochs
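
# Hedged usage sketch (illustrative): band-pass 1-30 Hz, reject epochs above
# 200 uV, epoch from -200 ms to 1 s around each trigger. `raw` can come from
# SimulateRaw(), LoadBVData(), or LoadMuseData(); `event_id` maps condition
# names to trigger values (SimulateRaw() returns one).
def _example_preprocess(raw, event_id):
  epochs = PreProcess(raw, event_id, filter_range=(1, 30), rej_thresh_uV=200,
                      epoch_time=(-.2, 1), baseline=(-.2, 0))
  return epochs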



def FeatureEngineer(epochs, model_type='NN',
                    frequency_domain=False,
                    normalization=False, electrode_median=False,
                    wavelet_decim=1, flims=(3,30), include_phase=False,
                    f_bins=20, wave_cycles=3,
                    wavelet_electrodes=[11,12,13,14,15],
                    spect_baseline=[-1,-.5],
                    test_split=0.2, val_split=0.2,
                    random_seed=1017, watermark=False):
  """Convert an MNE epochs object into a Feats object ready for model training.

  Builds either time-domain or frequency-domain (Morlet wavelet) features,
  optionally normalizes them, reshapes them for the chosen model_type, and
  splits them into training, validation, and test sets.

  TODO: take tfr? or autoencoder encoded object?
  """
  np.random.seed(random_seed)

  #pull event names in order of trigger number
  epochs.event_id = {'cond0': 1, 'cond1': 2}
  event_names = ['cond0','cond1']
  i = 0
  for key, value in sorted(epochs.event_id.items(),
                           key=lambda item: (item[1], item[0])):
    event_names[i] = key
    i += 1

  #Create feats object for output
  feats = Feats()
  feats.num_classes = len(epochs.event_id)
  feats.model_type = model_type

  if frequency_domain:
    print('Constructing Frequency Domain Features')

    #list of frequencies to output
    f_low = flims[0]
    f_high = flims[1]
    frequencies = np.linspace(f_low, f_high, f_bins, endpoint=True)

    #option to select all EEG electrodes for the wavelet transform
    if wavelet_electrodes == 'all':
      wavelet_electrodes = pick_types(epochs.info, eeg=True, eog=False)

    #type of output from wavelet analysis
    if include_phase:
      tfr_output_type = 'complex'
    else:
      tfr_output_type = 'power'

    tfr_dict = {}
    for event in event_names:
      print('Computing Morlet Wavelets on ' + event)
      tfr_temp = tfr_morlet(epochs[event], freqs=frequencies,
                            n_cycles=wave_cycles, return_itc=False,
                            picks=wavelet_electrodes, average=False,
                            decim=wavelet_decim, output=tfr_output_type)

      # Apply spectral baseline and find stim onset time
      tfr_temp = tfr_temp.apply_baseline(spect_baseline, mode='mean')
      stim_onset = np.argmax(tfr_temp.times > 0)

      # Reshape power output and save to tfr dict
      power_out_temp = np.moveaxis(tfr_temp.data[:,:,:,stim_onset:], 1, 3)
      power_out_temp = np.moveaxis(power_out_temp, 1, 2)
      print(event + ' trials: ' + str(len(power_out_temp)))
      tfr_dict[event] = power_out_temp

    #reshape times (sloppy but just use the last temp tfr)
    feats.new_times = tfr_temp.times[stim_onset:]

    for event in event_names:
      print(event + ' Time Points: ' + str(len(feats.new_times)))
      print(event + ' Frequencies: ' + str(len(tfr_temp.freqs)))

    #Construct X and Y
    for ievent, event in enumerate(event_names):
      if ievent == 0:
        X = tfr_dict[event]
        Y_class = np.zeros(len(tfr_dict[event]))
      else:
        X = np.append(X, tfr_dict[event], 0)
        Y_class = np.append(Y_class, np.ones(len(tfr_dict[event]))*ievent, 0)

    #concatenate real and imaginary data
    if include_phase:
      print('Concatenating the real and imaginary components')
      X = np.append(np.real(X), np.imag(X), 2)

    #compute median over electrodes to decrease features
    if electrode_median:
      print('Computing Median over electrodes')
      X = np.expand_dims(np.median(X, axis=len(X.shape)-1), 2)

    #reshape for various models
    if model_type == 'NN' or model_type == 'LSTM':
      X = np.reshape(X, (X.shape[0], X.shape[1], np.prod(X.shape[2:])))

    if model_type == 'CNN3D':
      X = np.expand_dims(X, 4)

    if model_type == 'AUTO' or model_type == 'AUTODeep':
      print('Auto model reshape')
      X = np.reshape(X, (X.shape[0], np.prod(X.shape[1:])))

  if not frequency_domain:
    print('Constructing Time Domain Features')

    #if using the muse aux port as eeg, it must be labeled as an eeg channel
    eeg_chans = pick_types(epochs.info, eeg=True, eog=False)

    #put channels last, remove eye and stim
    X = np.moveaxis(epochs._data[:,eeg_chans,:], 1, 2)

    #take post baseline only
    stim_onset = np.argmax(epochs.times > 0)
    feats.new_times = epochs.times[stim_onset:]
    X = X[:,stim_onset:,:]

    #convert markers to class labels
    #requires markers to be 1 and 2 in the data file;
    #this probably is not robust to other marker numbers
    Y_class = epochs.events[:,2] - 1  #subtract 1 to make 0 and 1

    #median over electrodes to reduce features
    if electrode_median:
      print('Computing Median over electrodes')
      X = np.expand_dims(np.median(X, axis=len(X.shape)-1), 2)

    ## Model Reshapes:
    # reshape for CNN
    if model_type == 'CNN':
      print('Size X before reshape for CNN: ' + str(X.shape))
      X = np.expand_dims(X, 3)
      print('Size X after reshape for CNN: ' + str(X.shape))

    # reshape for CNN3D
    if model_type == 'CNN3D':
      print('Size X before reshape for CNN3D: ' + str(X.shape))
      X = np.expand_dims(np.expand_dims(X, 3), 4)
      print('Size X after reshape for CNN3D: ' + str(X.shape))

    #reshape for autoencoder
    if model_type == 'AUTO' or model_type == 'AUTODeep':
      print('Size X before reshape for Auto: ' + str(X.shape))
      X = np.reshape(X, (X.shape[0], np.prod(X.shape[1:])))
      print('Size X after reshape for Auto: ' + str(X.shape))

  #Normalize X - TODO: need to save mean and std for future test + val
  if normalization:
    print('Normalizing X')
    X = (X - np.mean(X)) / np.std(X)

  # convert class vectors to one-hot Y and recast X
  Y = keras.utils.to_categorical(Y_class, feats.num_classes)
  X = X.astype('float32')

  # add watermark for testing models
  if watermark:
    X[Y[:,0]==0,0:2,] = 0
    X[Y[:,0]==1,0:2,] = 1

  # Compute model input shape
  feats.input_shape = X.shape[1:]

  # Split training, test, and validation data
  val_prop = val_split / (1 - test_split)
  (feats.x_train,
    feats.x_test,
    feats.y_train,
    feats.y_test) = train_test_split(X, Y,
                                     test_size=test_split,
                                     random_state=random_seed)
  (feats.x_train,
   feats.x_val,
   feats.y_train,
   feats.y_val) = train_test_split(feats.x_train, feats.y_train,
                                   test_size=val_prop,
                                   random_state=random_seed)

  #compute class weights for uneven classes
  y_ints = [y.argmax() for y in feats.y_train]
  feats.class_weights = class_weight.compute_class_weight('balanced',
                                                 np.unique(y_ints),
                                                 y_ints)

  #Print some outputs
  print('Combined X Shape: ' + str(X.shape))
  print('Combined Y Shape: ' + str(Y_class.shape))
  print('Y Example (should be 1s & 0s): ' + str(Y_class[0:10]))
  print('X Range: ' + str(np.min(X)) + ':' + str(np.max(X)))
  print('Input Shape: ' + str(feats.input_shape))
  print('x_train shape:', feats.x_train.shape)
  print(feats.x_train.shape[0], 'train samples')
  print(feats.x_test.shape[0], 'test samples')
  print(feats.x_val.shape[0], 'validation samples')
  print('Class Weights: ' + str(feats.class_weights))

  return feats
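
# Hedged usage sketch (illustrative): one time-domain and one frequency-domain
# call. wavelet_electrodes defaults to channels 11-15, so 'all' is passed here
# to stay valid for montages with fewer electrodes.
def _example_feature_engineer(epochs):
  feats_time = FeatureEngineer(epochs, model_type='NN')
  feats_freq = FeatureEngineer(epochs, model_type='CNN3D', frequency_domain=True,
                               flims=(3, 30), f_bins=20, wavelet_electrodes='all')
  return feats_time, feats_freq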


def CreateModel(feats, units=[16,8,4,8,16], dropout=.25,
                batch_norm=True, filt_size=3, pool_size=2):

  print('Creating ' + feats.model_type + ' Model')
  print('Input shape: ' + str(feats.input_shape))

  nunits = len(units)

  ##---LSTM - many to two, sequence of time steps to classes
  #units must contain at least two entries
  if feats.model_type == 'LSTM':
    if nunits < 2:
      print('Warning: Need at least two layers for LSTM')

    model = Sequential()
    model.add(LSTM(input_shape=(None, feats.input_shape[1]),
                   units=units[0], return_sequences=True))
    if batch_norm:
      model.add(BatchNormalization())
    model.add(Activation('relu'))
    if dropout:
      model.add(Dropout(dropout))

    if len(units) > 2:
      for unit in units[1:-1]:
        model.add(LSTM(units=unit, return_sequences=True))
        if batch_norm:
          model.add(BatchNormalization())
        model.add(Activation('relu'))
        if dropout:
          model.add(Dropout(dropout))

    model.add(LSTM(units=units[-1], return_sequences=False))
    if batch_norm:
      model.add(BatchNormalization())
    model.add(Activation('relu'))
    if dropout:
      model.add(Dropout(dropout))

    model.add(Dense(units=feats.num_classes))
    model.add(Activation("softmax"))

  ##---Dense feedforward network
  #makes a hidden layer for each item in units
  if feats.model_type == 'NN':
    model = Sequential()
    model.add(Flatten(input_shape=feats.input_shape))

    for unit in units:
      model.add(Dense(unit))
      if batch_norm:
        model.add(BatchNormalization())
      model.add(Activation('relu'))
      if dropout:
        model.add(Dropout(dropout))

    model.add(Dense(feats.num_classes, activation='softmax'))

  ##----Convolutional network
  if feats.model_type == 'CNN':
    if nunits < 2:
      print('Warning: Need at least two layers for CNN')
    model = Sequential()
    model.add(Conv2D(units[0], filt_size,
              input_shape=feats.input_shape, padding='same'))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=pool_size, padding='same'))

    if nunits > 2:
      for unit in units[1:-1]:
        model.add(Conv2D(unit, filt_size, padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=pool_size, padding='same'))

    model.add(Flatten())
    model.add(Dense(units[-1]))
    model.add(Activation('relu'))
    model.add(Dense(feats.num_classes))
    model.add(Activation('softmax'))

  ##----3D convolutional network
  if feats.model_type == 'CNN3D':
    if nunits < 2:
      print('Warning: Need at least two layers for CNN3D')
    model = Sequential()
    model.add(Conv3D(units[0], filt_size,
                     input_shape=feats.input_shape, padding='same'))
    model.add(Activation('relu'))
    model.add(MaxPooling3D(pool_size=pool_size, padding='same'))

    if nunits > 2:
      for unit in units[1:-1]:
        model.add(Conv3D(unit, filt_size, padding='same'))
        model.add(Activation('relu'))
        model.add(MaxPooling3D(pool_size=pool_size, padding='same'))

    model.add(Flatten())
    model.add(Dense(units[-1]))
    model.add(Activation('relu'))
    model.add(Dense(feats.num_classes))
    model.add(Activation('softmax'))

  ## Autoencoder
  #takes the first item in units as the hidden layer size
  if feats.model_type == 'AUTO':
    encoding_dim = units[0]
    input_data = Input(shape=(feats.input_shape[0],))
    #,activity_regularizer=regularizers.l1(10e-5)
    encoded = Dense(encoding_dim, activation='relu')(input_data)
    decoded = Dense(feats.input_shape[0], activation='sigmoid')(encoded)
    model = Model(input_data, decoded)

    encoder = Model(input_data, encoded)
    encoded_input = Input(shape=(encoding_dim,))
    decoder_layer = model.layers[-1]
    decoder = Model(encoded_input, decoder_layer(encoded_input))

  ## Deep autoencoder
  #takes an odd number of layers > 1
  #e.g. units = [64,32,16,32,64]
  if feats.model_type == 'AUTODeep':
    if nunits % 2 == 0:
      print('Warning: Please enter an odd number of layers into units')

    half = nunits/2
    midi = int(np.floor(half))

    input_data = Input(shape=(feats.input_shape[0],))
    encoded = Dense(units[0], activation='relu')(input_data)

    #encoder decreases
    if nunits >= 3:
      for unit in units[1:midi]:
        encoded = Dense(unit, activation='relu')(encoded)

    #latent space
    decoded = Dense(units[midi], activation='relu')(encoded)

    #decoder increases
    if nunits >= 3:
      for unit in units[midi+1:-1]:
        decoded = Dense(unit, activation='relu')(decoded)

    decoded = Dense(units[-1], activation='relu')(decoded)

    decoded = Dense(feats.input_shape[0], activation='sigmoid')(decoded)
    model = Model(input_data, decoded)

    encoder = Model(input_data, encoded)
    encoded_input = Input(shape=(units[midi],))

  if feats.model_type == 'AUTO' or feats.model_type == 'AUTODeep':
    opt = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999,
                                epsilon=None, decay=0.0, amsgrad=False)
    model.compile(optimizer=opt, loss='mean_squared_error')

  if ((feats.model_type == 'CNN') or
      (feats.model_type == 'CNN3D') or
      (feats.model_type == 'LSTM') or
      (feats.model_type == 'NN')):

    # initiate the Adam optimizer
    opt = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999,
                                epsilon=None, decay=0.0, amsgrad=False)
    # compile the classifier with binary cross-entropy loss
    model.compile(loss='binary_crossentropy',
                  optimizer=opt,
                  metrics=['accuracy'])
    encoder = []

  model.summary()

  return model, encoder
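
# Hedged usage sketch (illustrative): `feats` is the object returned by
# FeatureEngineer(); a symmetric 5-item `units` list suits the AUTODeep model,
# and for the classifier models `encoder` comes back as an empty list.
def _example_create_model(feats):
  model, encoder = CreateModel(feats, units=[16, 8, 4, 8, 16], dropout=0.25)
  return model, encoder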


def TrainTestVal(model, feats, batch_size=2,
                train_epochs=20, show_plots=True):

  print('Training Model:')
  # Train Model
  if feats.model_type == 'AUTO' or feats.model_type == 'AUTODeep':
    print('Training autoencoder:')

    history = model.fit(feats.x_train, feats.x_train,
                        batch_size=batch_size,
                        epochs=train_epochs,
                        validation_data=(feats.x_val, feats.x_val),
                        shuffle=True,
                        verbose=True,
                        class_weight=feats.class_weights
                       )

    # list all data in history
    print(history.history.keys())

    if show_plots:
      # summarize history for loss
      plt.semilogy(history.history['loss'])
      plt.semilogy(history.history['val_loss'])
      plt.title('model loss')
      plt.ylabel('loss')
      plt.xlabel('epoch')
      plt.legend(['train', 'val'], loc='upper left')
      plt.show()

  else:
    history = model.fit(feats.x_train, feats.y_train,
              batch_size=batch_size,
              epochs=train_epochs,
              validation_data=(feats.x_val, feats.y_val),
              shuffle=True,
              verbose=True,
              class_weight=feats.class_weights
              )

    # list all data in history
    print(history.history.keys())

    if show_plots:
      # summarize history for accuracy
      plt.plot(history.history['acc'])
      plt.plot(history.history['val_acc'])
      plt.title('model accuracy')
      plt.ylabel('accuracy')
      plt.xlabel('epoch')
      plt.legend(['train', 'val'], loc='upper left')
      plt.show()
      # summarize history for loss
      plt.semilogy(history.history['loss'])
      plt.semilogy(history.history['val_loss'])
      plt.title('model loss')
      plt.ylabel('loss')
      plt.xlabel('epoch')
      plt.legend(['train', 'val'], loc='upper left')
      plt.show()

    # Test on left out Test data
    score, acc = model.evaluate(feats.x_test, feats.y_test,
                                batch_size=batch_size)
    print(model.metrics_names)
    print('Test loss:', score)
    print('Test accuracy:', acc)

    # Build a dictionary of data to return
    data = {}
    data['score'] = score
    data['acc'] = acc

    return model, data
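

# Hedged end-to-end sketch (added for illustration, not part of the original
# pipeline): one plausible way to chain the functions above, assuming the MNE
# 'sample' dataset is available. Parameter values are examples, not defaults.
def _example_pipeline():
  raw, event_id = SimulateRaw(amp1=50, amp2=60, freq=10., batch=1)
  epochs = PreProcess(raw, event_id)
  feats = FeatureEngineer(epochs, model_type='NN')
  model, _ = CreateModel(feats, units=[16, 8])
  model, data = TrainTestVal(model, feats, train_epochs=5, show_plots=False)
  return data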