"""EEG deep-learning helpers: data loading, preprocessing, features, models.

This module was reconstructed from a whitespace-collapsed patch
(``--- a +++ b/utils.py @@ -0,0 +1,939 @@``).  The pipeline is:

    Load*Data / SimulateRaw -> PreProcess -> FeatureEngineer
    -> CreateModel -> TrainTestVal
"""

# Seed numpy and tensorflow before anything else is imported.
from numpy.random import seed
seed(1017)
from tensorflow import set_random_seed
set_random_seed(1017)

import os
from glob import glob
from collections import OrderedDict

import mne
from mne.io import RawArray
from mne import read_evokeds, read_source_spaces, compute_covariance
from mne import channels, find_events, concatenate_raws
from mne import pick_types, viz, io, Epochs, create_info
from mne import pick_channels, concatenate_epochs
from mne.datasets import sample
from mne.simulation import simulate_sparse_stc, simulate_raw
from mne.channels import read_montage
from mne.time_frequency import tfr_morlet

import numpy as np
from numpy import genfromtxt

import pandas as pd
pd.options.display.precision = 4
pd.options.display.max_columns = None

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (12,12)

import keras
from keras import regularizers
from keras.callbacks import TensorBoard
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Activation, Input
from keras.layers import Flatten, Conv2D, MaxPooling2D, LSTM
from keras.layers import BatchNormalization, Conv3D, MaxPooling3D

from sklearn.utils import class_weight
from sklearn.model_selection import train_test_split


class Feats:
    """Plain container for the outputs of FeatureEngineer().

    The constructor only stores its arguments.  ``class_weights`` defaults
    to ``[1, 1]`` and ``input_shape`` to ``[16]``; a fresh list is built per
    instance so instances never share (and cannot corrupt) a default list.
    """

    def __init__(self, num_classes=2, class_weights=None, input_shape=None,
                 new_times=1, model_type='1',
                 x_train=1, y_train=1, x_test=1, y_test=1, x_val=1, y_val=1):
        self.num_classes = num_classes
        self.class_weights = [1, 1] if class_weights is None else class_weights
        self.input_shape = [16] if input_shape is None else input_shape
        self.new_times = new_times
        self.model_type = model_type
        self.x_train = x_train
        self.y_train = y_train
        self.x_test = x_test
        self.y_test = y_test
        self.x_val = x_val
        self.y_val = y_val


def LoadBVData(sub, session, data_dir, exp):
    """Load one BrainVision recording and return the Raw object.

    The file read is ``<data_dir><exp>/<sub>_<exp>_<session>.vhdr``; the
    pieces are concatenated as strings, so ``data_dir`` must already end
    with a path separator and ``sub``/``session`` must be strings.
    """
    print('Loading data for subject number: ' + sub)
    fname = data_dir + exp + '/' + sub + '_' + exp + '_' + session + '.vhdr'
    raw, _ = loadBV(fname, plot_sensors=False, plot_raw=False,
                    plot_raw_psd=False, stim_channel=True)
    return raw


def loadBV(filename, plot_sensors=True, plot_raw=True,
           plot_raw_psd=True, stim_channel=False, ):
    """Load a BrainVision recorder file.

    Args:
        filename (str): path of the ``.vhdr`` header file.

    Keyword Args:
        plot_sensors (bool): plot the sensor layout with channel names.
        plot_raw (bool): open the (blocking) raw data browser.
        plot_raw_psd (bool): plot the power spectrum from 0.1 to 100 Hz.
        stim_channel: forwarded to ``io.read_raw_brainvision``.

    Returns:
        tuple: ``(raw, sfreq)`` - the preloaded Raw object and the sampling
        rate read from ``raw.info['sfreq']``.
    """
    # load .vhdr files from brain vision recorder
    raw = io.read_raw_brainvision(filename,
                                  montage='standard_1020',
                                  eog=('HEOG', 'VEOG'),
                                  preload=True, stim_channel=stim_channel)

    # sampling rate as stored in the recording
    sfreq = raw.info['sfreq']
    print('Sampling Rate = ' + str(sfreq))

    # channel locations
    print('Loading Channel Locations')
    if plot_sensors:
        raw.plot_sensors(show_names=True)

    # raw data browser
    if plot_raw:
        raw.plot(n_channels=16, block=True)

    # raw power spectrum
    if plot_raw_psd:
        raw.plot_psd(fmin=.1, fmax=100)

    return raw, sfreq


def LoadMuseData(subs, nsesh, data_dir, load_verbose=False, sfreq=256.):
    """Load and concatenate Muse recordings for several subjects/sessions.

    Args:
        subs (sequence): subject numbers.
        nsesh (int): sessions per subject; sessions ``1..nsesh`` are loaded.
        data_dir (str): experiment directory passed to muse_load_data().

    Keyword Args:
        load_verbose: verbosity forwarded to the loader.
        sfreq (float): EEG sampling frequency.

    Returns:
        The concatenation of every session that had at least one CSV file.

    Raises:
        ValueError: if no session produced any data.
    """
    nsubs = len(subs)
    raw = []
    print('Loading Data')
    for isub, sub in enumerate(subs):
        print('Subject number ' + str(isub+1) + '/' + str(nsubs))
        for isesh in range(nsesh):
            print(' Session number ' + str(isesh+1) + '/' + str(nsesh))
            session_raw = muse_load_data(data_dir, sfreq=sfreq, subject_nb=sub,
                                         session_nb=isesh+1,
                                         verbose=load_verbose)
            # The loader returns an empty list when it finds no files;
            # such an entry would make concatenate_raws fail.
            if isinstance(session_raw, list):
                continue
            raw.append(session_raw)
    if not raw:
        raise ValueError('No Muse recordings found in ' + str(data_dir))
    raw = concatenate_raws(raw)
    return raw


#from eeg-notebooks load_data
def muse_load_data(data_dir, subject_nb=1, session_nb=1, sfreq=256.,
                   ch_ind=[0, 1, 2, 3], stim_ind=5, replace_ch_names=None,
                   verbose=1):
    """Load CSV files from the /data directory into a Raw object.

    Files are globbed from
    ``eeg-notebooks_v0.1/data/<data_dir>/subject<N>/session<M>/*.csv``
    (relative to the current working directory).

    Args:
        data_dir (str): directory inside /data that contains the
            CSV files to load, e.g., 'auditory/P300'

    Keyword Args:
        subject_nb (int or str): subject number. If 'all', load all
            subjects.
        session_nb (int or str): session number. If 'all', load all
            sessions.
        sfreq (float): EEG sampling frequency
        ch_ind (list): indices of the EEG channels to keep
        stim_ind (int): index of the stim channel
        replace_ch_names (dict or None): dictionary containing a mapping to
            rename channels. Useful when an external electrode was used.
        verbose: verbosity forwarded to the MNE calls.

    Returns:
        (mne.io.array.array.RawArray): loaded EEG, or an empty list when no
        file matched.
    """
    # 'all' becomes a glob wildcard
    if subject_nb == 'all':
        subject_nb = '*'
    if session_nb == 'all':
        session_nb = '*'

    data_path = os.path.join(
        'eeg-notebooks_v0.1/data', data_dir,
        'subject{}/session{}/*.csv'.format(subject_nb, session_nb))
    fnames = glob(data_path)

    return load_muse_csv_as_raw(fnames,
                                sfreq=sfreq,
                                ch_ind=ch_ind,
                                stim_ind=stim_ind,
                                replace_ch_names=replace_ch_names,
                                verbose=verbose)


#from eeg-notebooks
def load_muse_csv_as_raw(filename, sfreq=256., ch_ind=[0, 1, 2, 3],
                         stim_ind=5, replace_ch_names=None, verbose=1):
    """Load CSV files into a Raw object.

    Each CSV is read with its first column as the index; ``ch_ind`` and
    ``stim_ind`` index the remaining columns.

    Args:
        filename (str or list): path or paths to CSV files to load

    Keyword Args:
        sfreq (float): EEG sampling frequency
        ch_ind (list): indices of the EEG channels to keep
        stim_ind (int): index of the stim channel
        replace_ch_names (dict or None): dictionary containing a mapping to
            rename channels. Useful when an external electrode was used.
        verbose: verbosity forwarded to the MNE calls.

    Returns:
        (mne.io.array.array.RawArray): loaded EEG, or an empty list when
        ``filename`` contains no paths.
    """
    # A single path must not be iterated character by character.
    if isinstance(filename, str):
        filename = [filename]

    ch_ind = list(ch_ind)
    n_channel = len(ch_ind)
    montage = read_montage('standard_1005')

    raw = []
    for fname in filename:
        # read the file
        data = pd.read_csv(fname, index_col=0)

        # name of each channel: the selected columns, so labels always
        # match the data picked below
        columns = list(data.columns)
        ch_names = [columns[i] for i in ch_ind] + ['Stim']

        if replace_ch_names is not None:
            ch_names = [replace_ch_names.get(c, c) for c in ch_names]

        # type of each channel
        ch_types = ['eeg'] * n_channel + ['stim']

        # keep the selected EEG columns plus the stim column, channels first
        data = data.values[:, ch_ind + [stim_ind]].T

        # convert EEG rows to Volts (from uVolts); the stim row is untouched
        data[:-1] *= 1e-6

        # create MNE object
        info = create_info(ch_names=ch_names, ch_types=ch_types,
                           sfreq=sfreq, montage=montage, verbose=verbose)
        raw.append(RawArray(data=data, info=info, verbose=verbose))

    # concatenate all raw objects
    if len(raw) > 0:
        raws = concatenate_raws(raw, verbose=verbose)
    else:
        print('No files for subject with filename ' + str(filename))
        raws = raw

    return raws


def SimulateRaw(amp1 = 50, amp2 = 100, freq = 1., batch=1):
    """Create simulated raw data and events of two kinds

    Keyword Args:
        amp1 (float): amplitude of first condition effect
        amp2 (float): amplitude of second condition effect,
            null hypothesis amp1=amp2
        freq (float): Frequency of simulated signal 1. for ERP 10. for alpha
        batch (int): number of copies of the 255 s sample recording that
            are concatenated for each condition
    Returns:
        raw: simulated EEG MNE raw object with two event types
        event_id: dict of the two events for input to PreProcess()
    """
    data_path = sample.data_path()
    # os.path.join works whether data_path is a str or a path object
    raw_fname = os.path.join(data_path, 'MEG', 'sample',
                             'sample_audvis_raw.fif')
    trans_fname = os.path.join(data_path, 'MEG', 'sample',
                               'sample_audvis_raw-trans.fif')
    src_fname = os.path.join(data_path, 'subjects', 'sample', 'bem',
                             'sample-oct-6-src.fif')
    bem_fname = os.path.join(data_path, 'subjects', 'sample', 'bem',
                             'sample-5120-5120-5120-bem-sol.fif')

    raw_single = mne.io.read_raw_fif(raw_fname, preload=True)
    raw_single.set_eeg_reference(projection=True)
    raw_single = raw_single.crop(0., 255.)
    raw_single = raw_single.copy().pick_types(meg=False, eeg=True, eog=True,
                                              stim=True)

    # concatenate `batch` copies of the recording; independent copies are
    # required because concatenate_raws appends to its first element
    raw = [raw_single.copy() for _ in range(batch)]
    raw = concatenate_raws(raw)

    epoch_duration = 1.

    def data_fun(amp, freq):
        """Create function to create fake signal"""
        def data_fun_inner(times):
            """Create fake signal with no noise"""
            n_samp = len(times)
            window = np.zeros(n_samp)
            # Hamming window over the first half of the epoch
            start, stop = [int(ii * float(n_samp) / 2)
                           for ii in (0, 1)]
            window[start:stop] = np.hamming(stop - start)
            data = amp * 1e-9 * np.sin(2. * np.pi * freq * times)
            data *= window
            return data
        return data_fun_inner

    times = raw.times[:int(raw.info['sfreq'] * epoch_duration)]
    src = read_source_spaces(src_fname)

    # same dipole (random_state=0) for both conditions, different amplitude
    stc_zero = simulate_sparse_stc(src, n_dipoles=1, times=times,
                                   data_fun=data_fun(amp1, freq),
                                   random_state=0)
    stc_one = simulate_sparse_stc(src, n_dipoles=1, times=times,
                                  data_fun=data_fun(amp2, freq),
                                  random_state=0)

    raw_sim_zero = simulate_raw(raw, stc_zero, trans_fname, src, bem_fname,
                                cov='simple', blink=True, n_jobs=1,
                                verbose=True)
    raw_sim_one = simulate_raw(raw, stc_one, trans_fname, src, bem_fname,
                               cov='simple', blink=True, n_jobs=1,
                               verbose=True)

    # relabel the triggers of the second condition from 1 to 2
    stim_pick = raw_sim_one.info['ch_names'].index('STI 014')
    stim = raw_sim_one._data[stim_pick]  # view into the raw data
    stim[stim == 1] = 2
    raw = concatenate_raws([raw_sim_zero, raw_sim_one])
    event_id = {'CondZero': 1,'CondOne': 2}
    return raw, event_id


def mastoidReref(raw):
    """Subtract half of channel 'M2' from every EEG channel, in place.

    NOTE(review): presumably the recording is referenced to the other
    mastoid, making this an average-mastoid re-reference - confirm.

    Returns:
        The same ``raw`` object (its ``_data`` is modified).

    Raises:
        ValueError: if the recording has no channel named 'M2'.
    """
    ref_idx = pick_channels(raw.info['ch_names'], ['M2'])
    if len(ref_idx) == 0:
        raise ValueError("Cannot re-reference: channel 'M2' not found")
    eeg_idx = pick_types(raw.info, eeg=True)
    raw._data[eeg_idx, :] = raw._data[eeg_idx, :] - raw._data[ref_idx, :] * .5
    return raw


def GrattonEmcpRaw(raw):
    """Regress the EOG channels out of the EEG channels of a Raw object.

    Least-squares coefficients predicting EEG from EOG are estimated on the
    whole recording and the predicted EOG contribution is subtracted.

    Returns:
        A copy of ``raw`` with corrected EEG data; ``raw`` is not modified.
    """
    raw_eeg = raw.copy().pick_types(eeg=True)[:][0]
    raw_eog = raw.copy().pick_types(eog=True)[:][0]
    # normal equations: (EOG EOG^T) b = EOG EEG^T
    b = np.linalg.solve(np.dot(raw_eog, raw_eog.T),
                        np.dot(raw_eog, raw_eeg.T))
    eeg_corrected = (raw_eeg.T - np.dot(raw_eog.T, b)).T
    raw_new = raw.copy()
    raw_new._data[pick_types(raw.info, eeg=True), :] = eeg_corrected
    return raw_new


def GrattonEmcpEpochs(epochs):
    '''
    # Correct EEG data for EOG artifacts with regression
    # INPUT - MNE epochs object (with eeg and eog channels)
    # OUTPUT - MNE epochs object (with eeg corrected), a copy of the input
    # After: Gratton,Coles,Donchin, 1983
    # -compute the ERP in each condition
    # -subtract ERP from each trial
    # -subtract baseline (mean over all epoch)
    # -regress the eeg remainder on the eye channel remainder
    # -use coefficients to subtract eog from eeg
    # Works for any number of conditions in epochs.event_id.
    '''
    # condition names in order of trigger number
    event_names = [name for name, _ in
                   sorted(epochs.event_id.items(),
                          key=lambda item: (item[1], item[0]))]

    #select the correct channels and data
    eeg_chans = pick_types(epochs.info, eeg=True, eog=False)
    eog_chans = pick_types(epochs.info, eeg=False, eog=True)
    original_data = epochs._data

    #subtract the average over trials from each trial
    rem = {}
    for event in event_names:
        data = epochs[event]._data
        rem[event] = data - np.mean(data, axis=0)

    #concatenate trials of all conditions (regression on all at once)
    allrem = np.concatenate([rem[event] for event in event_names])

    #separate eeg (X) and eog (Y)
    X = allrem[:, eeg_chans, :]
    Y = allrem[:, eog_chans, :]

    #subtract mean over time from every trial/channel
    X = (X.T - np.mean(X, 2).T).T
    Y = (Y.T - np.mean(Y, 2).T).T

    #move electrodes first
    X = np.moveaxis(X, 0, 1)
    Y = np.moveaxis(Y, 0, 1)

    #make 2d (channels x samples) and compute regression
    X = np.reshape(X, (X.shape[0], np.prod(X.shape[1:])))
    Y = np.reshape(Y, (Y.shape[0], np.prod(Y.shape[1:])))
    b = np.linalg.solve(np.dot(Y, Y.T), np.dot(Y, X.T))

    #get original data and electrodes first for matrix math
    raw_eeg = np.moveaxis(original_data[:, eeg_chans, :], 0, 1)
    raw_eog = np.moveaxis(original_data[:, eog_chans, :], 0, 1)

    #subtract weighted eye channels from eeg channels
    eeg_corrected = (raw_eeg.T - np.dot(raw_eog.T, b)).T

    #move back to match epochs
    eeg_corrected = np.moveaxis(eeg_corrected, 0, 1)

    #copy original epochs and replace with corrected data
    epochs_new = epochs.copy()
    epochs_new._data[:, eeg_chans, :] = eeg_corrected

    return epochs_new


def PreProcess(raw, event_id, plot_psd=False, filter_data=True,
               filter_range=(1,30), plot_events=False, epoch_time=(-.2,1),
               baseline=(-.2,0), rej_thresh_uV=200, rereference=False,
               emcp_raw=False, emcp_epochs=False, epoch_decim=1,
               plot_electrodes=False, plot_erp=False):
    """Filter, epoch and artifact-correct a Raw object.

    Args:
        raw: MNE Raw object; re-referencing and filtering modify it in
            place.
        event_id (dict): condition name -> trigger value.

    Keyword Args:
        plot_psd (bool): plot the spectrum after filtering.
        filter_data (bool): apply an IIR band-pass filter.
        filter_range (sequence): (low, high) band edges in Hz; never
            modified. If the high edge exceeds the decimated sampling rate
            it is replaced by that rate divided by 2.5.
        plot_events (bool): plot event timing.
        epoch_time (sequence): (tmin, tmax) of each epoch in seconds.
        baseline (tuple): baseline interval passed to Epochs.
        rej_thresh_uV (float): peak-to-peak EEG rejection threshold in uV.
        rereference (bool): apply mastoidReref() first.
        emcp_raw (bool): apply GrattonEmcpRaw() before epoching.
        emcp_epochs (bool): apply GrattonEmcpEpochs() after epoching.
        epoch_decim (int): decimation factor used when epoching.
        plot_electrodes (bool): butterfly plot of every condition.
        plot_erp (bool): compare the condition ERPs at one electrode.

    Returns:
        The (optionally eye-movement corrected) Epochs object.
    """
    sfreq = raw.info['sfreq']
    #create new output freq for after epoch or wavelet decim
    nsfreq = sfreq/epoch_decim
    tmin = epoch_time[0]
    tmax = epoch_time[1]

    # Local band edges: the default filter_range is a tuple (item
    # assignment would raise) and a caller's list must not be altered.
    f_low = filter_range[0]
    f_high = filter_range[1]
    if f_high > nsfreq:
        f_high = nsfreq/2.5  #lower than 2 to avoid aliasing from decim??

    #pull event names in order of trigger number
    event_names = [name for name, _ in
                   sorted(event_id.items(),
                          key=lambda item: (item[1], item[0]))]

    #Rereferencing
    if rereference:
        print('Rerefering to average mastoid')
        raw = mastoidReref(raw)

    #Filtering
    if filter_data:
        print('Filtering Data Between ' + str(f_low) +
              ' and ' + str(f_high) + ' Hz.')
        raw.filter(f_low, f_high,
                   method='iir', verbose='WARNING')

    if plot_psd:
        raw.plot_psd(fmin=f_low, fmax=nsfreq/2)

    #Eye Correction
    if emcp_raw:
        print('Raw Eye Movement Correction')
        raw = GrattonEmcpRaw(raw)

    #Epoching
    events = find_events(raw, shortest_event=1)
    color = {1: 'red', 2: 'black'}
    #artifact rejection threshold in Volts
    rej_thresh = rej_thresh_uV*1e-6

    #plot event timing
    if plot_events:
        viz.plot_events(events, sfreq, raw.first_samp, color=color,
                        event_id=event_id)

    #Construct events - Main function from MNE
    epochs = Epochs(raw, events=events, event_id=event_id,
                    tmin=tmin, tmax=tmax, baseline=baseline,
                    preload=True, reject={'eeg': rej_thresh},
                    verbose=False, decim=epoch_decim)
    print('Remaining Trials: ' + str(len(epochs)))

    #Gratton eye movement correction procedure on epochs
    if emcp_epochs:
        print('Epochs Eye Movement Correct')
        epochs = GrattonEmcpEpochs(epochs)

    ## ERP of every condition
    evoked_dict = {name: epochs[name].average() for name in event_names}

    # butterfly plot
    if plot_electrodes:
        picks = pick_types(evoked_dict[event_names[0]].info, meg=False,
                           eeg=True, eog=False)
        for name in event_names:
            evoked_dict[name].plot(spatial_colors=True, picks=picks)

    # plot ERP in each condition on same plot
    if plot_erp:
        #find the electrode highest on the head (largest z coordinate)
        chs = evoked_dict[event_names[0]].info['chs']
        picks = np.argmax([ch['loc'][2] for ch in chs])
        # Red/Blue for up to two conditions, library defaults otherwise
        if len(event_names) <= 2:
            colors = dict(zip(event_names, ("Red", "Blue")))
        else:
            colors = None
        viz.plot_compare_evokeds(evoked_dict, colors=colors,
                                 picks=picks, split_legend=True)

    return epochs


def FeatureEngineer(epochs, model_type='NN',
                    frequency_domain=False,
                    normalization=False, electrode_median=False,
                    wavelet_decim=1, flims=(3,30), include_phase=False,
                    f_bins=20, wave_cycles=3,
                    wavelet_electrodes = [11,12,13,14,15],
                    spect_baseline=[-1,-.5],
                    test_split = 0.2, val_split = 0.2,
                    random_seed=1017, watermark = False):
    """Turn an Epochs object into train/validation/test arrays.

    Args:
        epochs: MNE Epochs object. Its ``event_id`` is overwritten with
            ``{'cond0': 1, 'cond1': 2}`` (see the review note in the body).

    Keyword Args:
        model_type (str): 'NN', 'LSTM', 'CNN', 'CNN3D', 'AUTO' or
            'AUTODeep'; controls the final reshape of X.
        frequency_domain (bool): use Morlet wavelet features instead of the
            time-domain signal.
        normalization (bool): z-score X using its global mean and std.
        electrode_median (bool): replace the electrode axis by its median.
        wavelet_decim (int): decimation passed to tfr_morlet.
        flims (sequence): (low, high) frequency limits in Hz.
        include_phase (bool): keep complex wavelet output and concatenate
            its real and imaginary parts.
        f_bins (int): number of linearly spaced frequencies.
        wave_cycles: n_cycles passed to tfr_morlet.
        wavelet_electrodes (list or 'all'): channels for the wavelets;
            'all' selects every EEG channel.
        spect_baseline (sequence): baseline interval for the spectrogram.
        test_split (float): proportion of all trials held out for testing.
        val_split (float): proportion of all trials used for validation.
        random_seed (int): seed for numpy and the data splits.
        watermark (bool): overwrite the first two time points with the
            class label, to sanity-check models.

    Returns:
        Feats: populated with the data splits, ``input_shape``,
        ``new_times``, ``num_classes``, ``model_type`` and balanced
        ``class_weights``.

    TODO: take tfr? or autoencoder encoded object?
    """
    np.random.seed(random_seed)

    #pull event names in order of trigger number
    # NOTE(review): this replaces the caller's event_id and assumes the
    # triggers are exactly 1 and 2 - confirm this is intended.
    epochs.event_id = {'cond0':1, 'cond1':2}
    event_names = [name for name, _ in
                   sorted(epochs.event_id.items(),
                          key=lambda item: (item[1], item[0]))]

    #Create feats object for output
    feats = Feats()
    feats.num_classes = len(epochs.event_id)
    feats.model_type = model_type

    if frequency_domain:
        print('Constructing Frequency Domain Features')

        #list of frequencies to output
        f_low = flims[0]
        f_high = flims[1]
        frequencies = np.linspace(f_low, f_high, f_bins, endpoint=True)

        #option to select all electrodes for fft
        # (isinstance avoids an element-wise array-vs-string comparison)
        if isinstance(wavelet_electrodes, str) and wavelet_electrodes == 'all':
            wavelet_electrodes = pick_types(epochs.info, eeg=True, eog=False)

        #type of output from wavelet analysis
        if include_phase:
            tfr_output_type = 'complex'
        else:
            tfr_output_type = 'power'

        tfr_dict = {}
        for event in event_names:
            print('Computing Morlet Wavelets on ' + event)
            tfr_temp = tfr_morlet(epochs[event], freqs=frequencies,
                                  n_cycles=wave_cycles, return_itc=False,
                                  picks=wavelet_electrodes, average=False,
                                  decim=wavelet_decim, output=tfr_output_type)

            # Apply spectral baseline and find stim onset time
            tfr_temp = tfr_temp.apply_baseline(spect_baseline, mode='mean')
            stim_onset = np.argmax(tfr_temp.times>0)

            # Keep post-stimulus samples and reorder the axes from
            # (trial, channel, freq, time) to (trial, time, freq, channel)
            power_out_temp = np.moveaxis(tfr_temp.data[:,:,:,stim_onset:],1,3)
            power_out_temp = np.moveaxis(power_out_temp,1,2)
            print(event + ' trials: ' + str(len(power_out_temp)))
            tfr_dict[event] = power_out_temp

        #reshape times (sloppy but just use the last temp tfr)
        feats.new_times = tfr_temp.times[stim_onset:]

        for event in event_names:
            print(event + ' Time Points: ' + str(len(feats.new_times)))
            print(event + ' Frequencies: ' + str(len(tfr_temp.freqs)))

        #Construct X and Y (class index = position in event_names)
        for ievent, event in enumerate(event_names):
            if ievent == 0:
                X = tfr_dict[event]
                Y_class = np.zeros(len(tfr_dict[event]))
            else:
                X = np.append(X, tfr_dict[event], 0)
                Y_class = np.append(Y_class,
                                    np.ones(len(tfr_dict[event]))*ievent, 0)

        #concatenate real and imaginary data
        if include_phase:
            print('Concatenating the real and imaginary components')
            X = np.append(np.real(X), np.imag(X), 2)

        #compute median over electrodes to decrease features
        if electrode_median:
            print('Computing Median over electrodes')
            X = np.expand_dims(np.median(X, axis=len(X.shape)-1), 2)

        #reshape for various models
        if model_type == 'NN' or model_type == 'LSTM':
            X = np.reshape(X, (X.shape[0], X.shape[1], np.prod(X.shape[2:])))

        if model_type == 'CNN3D':
            X = np.expand_dims(X, 4)

        if model_type == 'AUTO' or model_type == 'AUTODeep':
            print('Auto model reshape')
            X = np.reshape(X, (X.shape[0], np.prod(X.shape[1:])))

    else:
        print('Constructing Time Domain Features')

        #if using muse aux port as eeg must label it as such
        eeg_chans = pick_types(epochs.info, eeg=True, eog=False)

        #put channels last, remove eye and stim
        X = np.moveaxis(epochs._data[:, eeg_chans, :], 1, 2)

        #take post baseline only
        stim_onset = np.argmax(epochs.times>0)
        feats.new_times = epochs.times[stim_onset:]
        X = X[:, stim_onset:, :]

        #convert markers to class
        #requires markers to be 1 and 2 in data file?
        #This probably is not robust to other marker numbers
        Y_class = epochs.events[:,2]-1  #subtract 1 to make 0 and 1

        #median over electrodes to reduce features
        if electrode_median:
            print('Computing Median over electrodes')
            X = np.expand_dims(np.median(X, axis=len(X.shape)-1), 2)

        ## Model Reshapes:
        # reshape for CNN
        if model_type == 'CNN':
            print('Size X before reshape for CNN: ' + str(X.shape))
            X = np.expand_dims(X, 3)
            print('Size X after reshape for CNN: ' + str(X.shape))

        # reshape for CNN3D
        if model_type == 'CNN3D':
            print('Size X before reshape for CNN3D: ' + str(X.shape))
            X = np.expand_dims(np.expand_dims(X, 3), 4)
            print('Size X after reshape for CNN3D: ' + str(X.shape))

        #reshape for autoencoder
        if model_type == 'AUTO' or model_type == 'AUTODeep':
            print('Size X before reshape for Auto: ' + str(X.shape))
            X = np.reshape(X, (X.shape[0], np.prod(X.shape[1:])))
            print('Size X after reshape for Auto: ' + str(X.shape))

    #Normalize X - TODO: need to save mean and std for future test + val
    if normalization:
        print('Normalizing X')
        X = (X - np.mean(X)) / np.std(X)

    # convert class vectors to one hot Y and recast X
    Y = keras.utils.to_categorical(Y_class, feats.num_classes)
    X = X.astype('float32')

    # add watermark for testing models
    if watermark:
        X[Y[:,0]==0,0:2,] = 0
        X[Y[:,0]==1,0:2,] = 1

    # Compute model input shape
    feats.input_shape = X.shape[1:]

    # Split training test and validation data.  The validation set is cut
    # from what remains after the test split, hence the rescaling.
    val_prop = val_split / (1-test_split)
    (feats.x_train,
     feats.x_test,
     feats.y_train,
     feats.y_test) = train_test_split(X, Y,
                                      test_size=test_split,
                                      random_state=random_seed)
    (feats.x_train,
     feats.x_val,
     feats.y_train,
     feats.y_val) = train_test_split(feats.x_train, feats.y_train,
                                     test_size=val_prop,
                                     random_state=random_seed)

    #compute class weights for uneven classes (keyword arguments work with
    #both old and new scikit-learn signatures)
    y_ints = np.argmax(feats.y_train, axis=1)
    feats.class_weights = class_weight.compute_class_weight(
        'balanced', classes=np.unique(y_ints), y=y_ints)

    #Print some outputs
    print('Combined X Shape: ' + str(X.shape))
    print('Combined Y Shape: ' + str(Y_class.shape))
    print('Y Example (should be 1s & 0s): ' + str(Y_class[0:10]))
    print('X Range: ' + str(np.min(X)) + ':' + str(np.max(X)))
    print('Input Shape: ' + str(feats.input_shape))
    print('x_train shape:', feats.x_train.shape)
    print(feats.x_train.shape[0], 'train samples')
    print(feats.x_test.shape[0], 'test samples')
    print(feats.x_val.shape[0], 'validation samples')
    print('Class Weights: ' + str(feats.class_weights))

    return feats


def _add_norm_act_dropout(model, batch_norm, dropout):
    """Append optional BatchNormalization, a ReLU and optional Dropout."""
    if batch_norm:
        model.add(BatchNormalization())
    model.add(Activation('relu'))
    if dropout:
        model.add(Dropout(dropout))


def CreateModel(feats,units=[16,8,4,8,16], dropout=.25,
                batch_norm=True, filt_size=3, pool_size=2):
    """Build and compile the Keras model selected by ``feats.model_type``.

    Args:
        feats (Feats): provides ``model_type``, ``input_shape`` and
            ``num_classes``.

    Keyword Args:
        units (list): layer sizes / filter counts, one entry per layer.
        dropout (float): dropout rate; a falsy value disables dropout
            (used by the LSTM and NN models).
        batch_norm (bool): add batch normalisation (LSTM and NN models).
        filt_size (int): convolution kernel size (CNN models).
        pool_size (int): max-pooling size (CNN models).

    Returns:
        tuple: ``(model, encoder)``. ``encoder`` maps the input to the
        hidden/latent layer for 'AUTO' and 'AUTODeep' and is an empty list
        for the classifiers.

    Raises:
        ValueError: if ``feats.model_type`` is not a supported model.
    """
    print('Creating ' + feats.model_type + ' Model')
    print('Input shape: ' + str(feats.input_shape))

    nunits = len(units)
    model_type = feats.model_type
    encoder = []

    ##---LSTM - Many to two, sequence of time to classes
    #Units must be at least two
    if model_type == 'LSTM':
        if nunits < 2:
            print('Warning: Need at least two layers for LSTM')

        model = Sequential()
        model.add(LSTM(input_shape=(None, feats.input_shape[1]),
                       units=units[0], return_sequences=True))
        _add_norm_act_dropout(model, batch_norm, dropout)

        # middle layers (empty slice when there are two units or fewer)
        for unit in units[1:-1]:
            model.add(LSTM(units=unit, return_sequences=True))
            _add_norm_act_dropout(model, batch_norm, dropout)

        model.add(LSTM(units=units[-1], return_sequences=False))
        _add_norm_act_dropout(model, batch_norm, dropout)

        model.add(Dense(units=feats.num_classes))
        model.add(Activation("softmax"))

    ##---DenseFeedforward Network
    #Makes a hidden layer for each item in units
    elif model_type == 'NN':
        model = Sequential()
        model.add(Flatten(input_shape=feats.input_shape))

        for unit in units:
            model.add(Dense(unit))
            _add_norm_act_dropout(model, batch_norm, dropout)

        model.add(Dense(feats.num_classes, activation='softmax'))

    ##----Convolutional Networks (2D and 3D share one layout)
    elif model_type == 'CNN' or model_type == 'CNN3D':
        if nunits < 2:
            print('Warning: Need at least two layers for CNN')
        if model_type == 'CNN':
            conv_layer, pool_layer = Conv2D, MaxPooling2D
        else:
            conv_layer, pool_layer = Conv3D, MaxPooling3D

        model = Sequential()
        model.add(conv_layer(units[0], filt_size,
                             input_shape=feats.input_shape, padding='same'))
        model.add(Activation('relu'))
        model.add(pool_layer(pool_size=pool_size, padding='same'))

        for unit in units[1:-1]:
            model.add(conv_layer(unit, filt_size, padding='same'))
            model.add(Activation('relu'))
            model.add(pool_layer(pool_size=pool_size, padding='same'))

        # the last entry of units sizes the dense layer before the output
        model.add(Flatten())
        model.add(Dense(units[-1]))
        model.add(Activation('relu'))
        model.add(Dense(feats.num_classes))
        model.add(Activation('softmax'))

    ## Autoencoder
    #takes the first item in units for hidden layer size
    elif model_type == 'AUTO':
        encoding_dim = units[0]
        input_data = Input(shape=(feats.input_shape[0],))
        #,activity_regularizer=regularizers.l1(10e-5)
        encoded = Dense(encoding_dim, activation='relu')(input_data)
        decoded = Dense(feats.input_shape[0], activation='sigmoid')(encoded)
        model = Model(input_data, decoded)
        encoder = Model(input_data, encoded)

    #takes an odd number of layers > 1
    #e.g. units = [64,32,16,32,64]
    elif model_type == 'AUTODeep':
        if nunits % 2 == 0:
            print('Warning: Please enter odd number of layers into units')

        midi = int(np.floor(nunits/2))

        input_data = Input(shape=(feats.input_shape[0],))
        encoded = Dense(units[0], activation='relu')(input_data)

        #encoder decreases
        for unit in units[1:midi]:
            encoded = Dense(unit, activation='relu')(encoded)

        #latent space: kept as the encoder output so that `encoder`
        #really maps the input to the units[midi]-dimensional code
        encoded = Dense(units[midi], activation='relu')(encoded)

        #decoder increases
        decoded = encoded
        for unit in units[midi+1:-1]:
            decoded = Dense(unit, activation='relu')(decoded)

        decoded = Dense(units[-1], activation='relu')(decoded)

        decoded = Dense(feats.input_shape[0], activation='sigmoid')(decoded)
        model = Model(input_data, decoded)
        encoder = Model(input_data, encoded)

    else:
        raise ValueError('Unknown model_type: ' + str(model_type))

    # Adam optimizer for every model
    opt = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999,
                                epsilon=None, decay=0.0, amsgrad=False)
    if model_type == 'AUTO' or model_type == 'AUTODeep':
        # reconstruction loss, no metrics
        model.compile(optimizer=opt, loss='mean_squared_error')
    else:
        model.compile(loss='binary_crossentropy',
                      optimizer=opt,
                      metrics=['accuracy'])

    model.summary()

    return model, encoder


def _class_weight_dict(class_weights):
    """Return class weights as a ``{class_index: weight}`` dict.

    Keras documents ``class_weight`` as a dictionary.  A sequence is read
    as weights for classes 0, 1, 2, ... in order - this assumes every class
    occurs in the training split (TODO confirm for very small data sets).
    ``None`` and dicts are returned unchanged.
    """
    if class_weights is None or isinstance(class_weights, dict):
        return class_weights
    return {index: float(weight)
            for index, weight in enumerate(class_weights)}


def TrainTestVal(model, feats, batch_size=2,
                 train_epochs=20, show_plots=True):
    """Fit a model, plot its training history and score it on test data.

    Autoencoders ('AUTO', 'AUTODeep') are trained and evaluated to
    reconstruct their input; classifiers are trained on the one-hot labels
    with the class weights stored in ``feats``.

    Args:
        model: compiled Keras model from CreateModel().
        feats (Feats): data splits from FeatureEngineer().

    Keyword Args:
        batch_size (int): batch size for fitting and evaluation.
        train_epochs (int): number of training epochs.
        show_plots (bool): plot accuracy (classifiers only) and loss.

    Returns:
        tuple: ``(model, data)`` where ``data`` holds the test ``'score'``
        (loss) and ``'acc'`` (accuracy, ``None`` for autoencoders).
    """
    print('Training Model:')
    is_autoencoder = feats.model_type in ('AUTO', 'AUTODeep')

    # Train Model
    if is_autoencoder:
        print('Training autoencoder:')
        # class weights have no meaning for a reconstruction target
        history = model.fit(feats.x_train, feats.x_train,
                            batch_size=batch_size,
                            epochs=train_epochs,
                            validation_data=(feats.x_val, feats.x_val),
                            shuffle=True,
                            verbose=True)
    else:
        history = model.fit(feats.x_train, feats.y_train,
                            batch_size=batch_size,
                            epochs=train_epochs,
                            validation_data=(feats.x_val, feats.y_val),
                            shuffle=True,
                            verbose=True,
                            class_weight=_class_weight_dict(
                                feats.class_weights))

    # list all data in history
    print(history.history.keys())

    if show_plots:
        if not is_autoencoder:
            # summarize history for accuracy; the key is 'acc' in older
            # Keras releases and 'accuracy' in newer ones
            acc_key = 'acc' if 'acc' in history.history else 'accuracy'
            plt.plot(history.history[acc_key])
            plt.plot(history.history['val_' + acc_key])
            plt.title('model accuracy')
            plt.ylabel('accuracy')
            plt.xlabel('epoch')
            plt.legend(['train', 'val'], loc='upper left')
            plt.show()
        # summarize history for loss
        plt.semilogy(history.history['loss'])
        plt.semilogy(history.history['val_loss'])
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'val'], loc='upper left')
        plt.show()

    # Test on left out Test data
    if is_autoencoder:
        # reconstruction loss only: the target is the input itself
        result = model.evaluate(feats.x_test, feats.x_test,
                                batch_size=batch_size)
        score = result[0] if isinstance(result, (list, tuple)) else result
        acc = None
        print(model.metrics_names)
        print('Test loss:', score)
    else:
        score, acc = model.evaluate(feats.x_test, feats.y_test,
                                    batch_size=batch_size)
        print(model.metrics_names)
        print('Test loss:', score)
        print('Test accuracy:', acc)

    # Build a dictionary of data to return
    data = {}
    data['score'] = score
    data['acc'] = acc

    return model, data