--- a
+++ b/bin/DeepMod_scripts/myMultiBiRNN.py
@@ -0,0 +1,477 @@
+
+
+import tensorflow as tf
+from tensorflow.contrib import rnn
+import numpy as np
+
+import math
+import glob, os, sys, time;
+
+from collections import defaultdict
+
+batchsize = 2048;
+
+# different class weights for unbalanced data
+class_weights = tf.constant([0.1, 0.9])
+
+#
+# create an RNN with LSTM cells
+# and define the performance evaluation operations
+#
+def mCreateSession(num_input, num_hidden, timesteps, moptions):
+    # two classes only
+    num_classes = 2;
+    # the number of layers
+    numlayers = 3;
+    # learning rate
+    learning_rate = 0.001
+
+    # define input and output
+    X = tf.placeholder("float", [None, timesteps, num_input]);
+    Y = tf.placeholder("float", [None, num_classes]);
+
+    # weights and biases of the last (output) layer
+    weights = {'out': tf.Variable(tf.truncated_normal([2*num_hidden, num_classes]))};
+    biases = {'out': tf.Variable(tf.truncated_normal([num_classes]))}
+
+    # define a bidirectional RNN
+    def BiRNN(x, weights, biases):
+        x = tf.unstack(x, timesteps, 1);
+
+        # define the stacked LSTM cells (forward and backward)
+        lstm_fw_cell = rnn.MultiRNNCell([rnn.BasicLSTMCell(num_hidden, forget_bias=1.0) for _ in range(numlayers)]);
+        lstm_bw_cell = rnn.MultiRNNCell([rnn.BasicLSTMCell(num_hidden, forget_bias=1.0) for _ in range(numlayers)]);
+
+        # build the bidirectional RNN (handle both return signatures of static_bidirectional_rnn)
+        try:
+            outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x, dtype=tf.float32);
+        except Exception:
+            outputs = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x, dtype=tf.float32);
+
+        # define the output layer on the middle timestep
+        if moptions['outputlayer'] in ['sigmoid']:
+            return tf.contrib.layers.fully_connected(outputs[int(timesteps/2)], num_outputs=num_classes, activation_fn=tf.nn.sigmoid);
+        else:
+            return tf.matmul(outputs[int(timesteps/2)], weights['out']) + biases['out']
+
+    # get prediction
+    logits = BiRNN(X, weights, biases);
+    prediction = tf.nn.softmax(logits)
+
+    mfpred = tf.argmax(prediction, 1)
+
+    ## use different class weights for unbalanced data, or not
+    if 'unbalanced' in moptions and (not moptions['unbalanced']==None) and moptions['unbalanced']==1: # class_weights
+        loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=tf.multiply(logits, class_weights), labels=Y))
+    else:
+        loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=Y))
+    #
+
+    # optimizer
+    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate);
+    train_op = optimizer.minimize(loss_op);
+
+    # accuracy
+    correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1));
+    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32));
+
+    # AUC
+    auc_op = tf.metrics.auc(Y, prediction)
+    # precision
+    mpre = tf.metrics.precision(tf.argmax(Y, 1), tf.argmax(prediction, 1))
+    # recall
+    mspf = tf.metrics.recall(tf.argmax(Y, 1), tf.argmax(prediction, 1))
+
+    # initialization of global and local variables
+    init = tf.global_variables_initializer();
+    init_l = tf.local_variables_initializer()
+
+    saver = tf.train.Saver();
+
+    return (init, init_l, loss_op, accuracy, train_op, X, Y, saver, auc_op, mpre, mspf, mfpred)
+
+#
+# train a model and save it.
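+# X is fed with batches of shape [batch, timesteps, num_input] and Y with
+# one-hot labels of shape [batch, 2]; predictions come from the middle
+# timestep of the bidirectional RNN built in mCreateSession() above, e.g.
+#   sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})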
+#
+def train_save_model(filelists, num_input, mhidden, timesteps, moptions):
+    training_steps = 4
+    #training_steps = 40
+
+    init, init_l, loss_op, accuracy, train_op, X, Y, saver, auc_op, mpre, mspf, mfpred = mCreateSession(num_input, mhidden, timesteps, moptions)
+
+    # display step: print progress every this many input files
+    desplay_files = len(filelists[0])/100
+    if desplay_files<2: desplay_files = 2;
+    if desplay_files>10: desplay_files = int(desplay_files/10) * 10; #desplay_files=2
+    if desplay_files>100: desplay_files = 100
+    file_group_id = [0 for _ in range(len(filelists))];
+    sumpsize = 25;
+
+    # session configuration
+    config = tf.ConfigProto()
+    if (timesteps>61 and num_input>50):
+        config.gpu_options.per_process_gpu_memory_fraction = 0.5
+    else: config.gpu_options.allow_growth = True
+    with tf.Session(config=config) as sess:
+        # initialization
+        sess.run(init);
+        sess.run(init_l)
+        start_time = time.time(); start_c_time = time.time();
+        io_time = 0;
+
+        # for each epoch
+        for step in range(1, training_steps+1):
+            print('===%d=====================step========================%d/%d' % (desplay_files, step, training_steps))
+            sys.stdout.flush()
+            last_desplay_files_num = -1;
+            file_group_id[0] = 0
+            while file_group_id[0] < len(filelists[0]):
+                io_start_time = time.time();
+
+                # for each input group:
+                # usually two groups, one positive and one negative;
+                # there might also be a single group containing both positively and negatively labeled data
+                featurelist = [[[], []] for _ in range(len(filelists))];
+                minsize = None; cur_batch_num = None;
+                # read data from all groups until 'minsize' data points are loaded
+                for ifl in range(len(filelists)):
+                    if ifl==0:
+                        minsize = batchsize * sumpsize
+                    else: minsize = batchsize * cur_batch_num;
+                    while len(featurelist[ifl][0])<minsize:
+                        if not file_group_id[ifl] < len(filelists[ifl]):
+                            if ifl==0: break;
+                            else: file_group_id[ifl] = 0
+                        # get more data
+                        batch_2_x, batch_2_y, _ = getDataFromFile_new(filelists[ifl][file_group_id[ifl]], moptions)
+                        if len(batch_2_y)>0:
+                            if len(featurelist[ifl][0])==0:
+                                featurelist[ifl][0] = batch_2_x
+                                featurelist[ifl][1] = batch_2_y
+                            else:
+                                # merge the newly loaded data with the previously loaded data
+                                featurelist[ifl][0] = np.concatenate((featurelist[ifl][0], batch_2_x), axis=0)
+                                featurelist[ifl][1] = np.concatenate((featurelist[ifl][1], batch_2_y), axis=0)
+                        file_group_id[ifl] += 1;
+                    # split into small batches for training
+                    if ifl==0:
+                        featurelist[ifl][0] = np.array_split(featurelist[ifl][0], int(len(featurelist[ifl][0])/batchsize))
+                        featurelist[ifl][1] = np.array_split(featurelist[ifl][1], int(len(featurelist[ifl][1])/batchsize))
+                        cur_batch_num = len(featurelist[ifl][0])
+                if len(featurelist[0][0])<sumpsize*0.8:
+                    for ifl in range(1, len(filelists)):
+                        if len(featurelist[0][0])*batchsize*1.2 < len(featurelist[ifl][0]):
+                            featurelist[ifl][0] = featurelist[ifl][0][:int(len(featurelist[0][0])*batchsize*1.2)]
+                            featurelist[ifl][1] = featurelist[ifl][1][:int(len(featurelist[0][0])*batchsize*1.2)]
+                if len(featurelist[0][0])<1: continue
+
+                if len(filelists)>1:
+                    for ifl in range(1, len(filelists)):
+                        #if (file_group_id[0]+1) - last_desplay_files_num >= desplay_files: msizeprint.append(str(len(featurelist[ifl][0])))
+                        featurelist[ifl][0] = np.array_split(featurelist[ifl][0], len(featurelist[0][0]))
+                        featurelist[ifl][1] = np.array_split(featurelist[ifl][1], len(featurelist[0][0]))
+                io_time += (time.time() - io_start_time)
+
+                ifl=3 if len(featurelist)>3 else len(featurelist)-1
+                if (file_group_id[0]+1) - last_desplay_files_num >= desplay_files:
+                    sess.run(init_l)
+                    try:
+                        # print some testing information as a progress indicator
+                        loss, aucm, acc, p, r = sess.run([loss_op, auc_op[1], accuracy, mpre[1], mspf[1]], feed_dict={X:featurelist[ifl][0][0], Y:featurelist[ifl][1][0]})
+                        print(">>>Train#files "+str(file_group_id[0]+1)+",loss="+"{:.3f}".format(loss)+",AUC="+"{:.3f}".format(aucm)+",acc="+"{:.3f}".format(acc)+",p="+"{:.3f}".format(p)+",r="+"{:.3f}".format(r)+(" Consuming time: %d(current=%d) IO=%d(%.3f)" % (time.time()-start_time, time.time()-start_c_time, io_time, io_time/float(time.time()-start_time))));
+                    except Exception:
+                        print(">>>Train#filesError "+str(file_group_id[0]+1)+(" Consuming time: %d(current=%d) IO=%d(%.3f)" % (time.time()-start_time, time.time()-start_c_time, io_time, io_time/float(time.time()-start_time))));
+                    sys.stdout.flush()
+                    start_c_time = time.time();
+
+                # use each subgroup of data for training
+                for subi in range(len(featurelist[0][0])):
+                    for ifl in range(len(filelists)):
+                        to = sess.run([train_op, loss_op], feed_dict={X:featurelist[ifl][0][subi], Y:featurelist[ifl][1][subi]})
+                        if len(featurelist)==1:
+                            # print some detail if a nan issue happens
+                            if math.isnan(to[1]):
+                                for toj in range(len(featurelist[ifl][0][subi])):
+                                    print('{} vs {}'.format(featurelist[ifl][1][subi][toj][0], featurelist[ifl][1][subi][toj][1]))
+                                    for tok in featurelist[ifl][0][subi][toj]:
+                                        opstr = []
+                                        for tol in tok:
+                                            opstr.append(str(round(tol, 2)))
+                                        print("\t\t\t"+','.join(opstr))
+                                sys.exit(1)
+
+                # adjust the progress output information
+                ifl=3 if len(featurelist)>3 else len(featurelist)-1
+                if (file_group_id[0]+1) - last_desplay_files_num >= desplay_files:
+                    last_desplay_files_num = (file_group_id[0]+1) - ((file_group_id[0]+1) % desplay_files)
+
+                # store intermediate models
+                if 49.5<int(file_group_id[0]*100/float(len(filelists[0])))<50.5:
+                    savp = '.50'
+                    if (not os.path.isdir(moptions['outFolder']+str(step-1)+savp)):
+                        os.system('mkdir -p '+moptions['outFolder']+str(step-1)+savp);
+                    saver.save(sess, moptions['outFolder']+str(step-1)+savp+'/'+moptions['FileID']);
+                if len(featurelist)==1:
+                    cur_per = int(file_group_id[0]*100/float(len(filelists[0])))
+                    if cur_per in [10, 20, 30, 40, 60, 70, 80, 90]:
+                        savp = str(round(cur_per/100.0, 2))
+                        if (not os.path.isdir(moptions['outFolder']+str(step-1)+savp)):
+                            os.system('mkdir -p '+moptions['outFolder']+str(step-1)+savp);
+                        saver.save(sess, moptions['outFolder']+str(step-1)+savp+'/'+moptions['FileID']);
+            # for each epoch, store the trained model
+            if (not os.path.isdir(moptions['outFolder']+str(step))):
+                os.system('mkdir -p '+moptions['outFolder']+str(step));
+            saver.save(sess, moptions['outFolder']+str(step)+'/'+moptions['FileID']);
+        print("Training Finished!")
+
+    return (accuracy, X, Y, auc_op, mpre, mspf, init_l, mfpred)
+
+#
+# get all data files in a folder
+#
+def getTFiles1(folder1, moptions):
+    t1files = glob.glob(os.path.join(folder1, "*.xy.gz"))
+    # get all data files recursively
+    if moptions['recursive']==1:
+        t1files.extend(glob.glob(os.path.join(folder1, "*/*.xy.gz")))
+        t1files.extend(glob.glob(os.path.join(folder1, "*/*/*.xy.gz")));
+        t1files.extend(glob.glob(os.path.join(folder1, "*/*/*/*.xy.gz")));
+        t1files.extend(glob.glob(os.path.join(folder1, "*/*/*/*/*.xy.gz")));
+    print("Get folder1");
+    # for read-based independent testing
+    print(t1files.__sizeof__(), len(t1files))
+    if moptions['test'][0] == '0':
+        if moptions['test'][1]>0.5:
+            t1files = t1files[:int(len(t1files)*moptions['test'][1])]
+        else: t1files = t1files[-int(len(t1files)*moptions['test'][1]):]
+    print('Sizeinfo: %s sizeof=%d len=%d' % (folder1, t1files.__sizeof__(), len(t1files)))
+    sys.stdout.flush();
+
+    return t1files
+
+#
+# get all data files in two separate folders
+#
+def getTFiles(folder1, folder2, moptions):
+    t1files = glob.glob(os.path.join(folder1, "*.xy.gz")); #print(t1files.__sizeof__(), len(t1files))
+    # get all data files recursively
+    if moptions['recursive']==1:
+        t1files.extend(glob.glob(os.path.join(folder1, "*/*.xy.gz"))); #print(t1files.__sizeof__(), len(t1files))
+        t1files.extend(glob.glob(os.path.join(folder1, "*/*/*.xy.gz"))); #print(t1files.__sizeof__(), len(t1files))
+        t1files.extend(glob.glob(os.path.join(folder1, "*/*/*/*.xy.gz"))); #print(t1files.__sizeof__(), len(t1files))
+        t1files.extend(glob.glob(os.path.join(folder1, "*/*/*/*/*.xy.gz"))); #print(t1files.__sizeof__(), len(t1files))
+    print("Get folder1");
+    # for read-based independent testing
+    print(t1files.__sizeof__(), len(t1files))
+    if moptions['test'][0] == '0':
+        if moptions['test'][1]>0.5:
+            t1files = t1files[:int(len(t1files)*moptions['test'][1])]
+        else: t1files = t1files[-int(len(t1files)*moptions['test'][1]):]
+    print(t1files.__sizeof__(), len(t1files))
+    sys.stdout.flush();
+
+    if folder2==None: t2files = []
+    else:
+        # get all data files recursively for the second folder
+        t2files = glob.glob(os.path.join(folder2, "*.xy.gz"))
+        if moptions['recursive']==1:
+            t2files.extend(glob.glob(os.path.join(folder2, "*/*.xy.gz")))
+            t2files.extend(glob.glob(os.path.join(folder2, "*/*/*.xy.gz")))
+            t2files.extend(glob.glob(os.path.join(folder2, "*/*/*/*.xy.gz")))
+            t2files.extend(glob.glob(os.path.join(folder2, "*/*/*/*/*.xy.gz")))
+        print("Get folder2");
+        # for read-based independent testing
+        print(t2files.__sizeof__(), len(t2files))
+        if moptions['test'][0] == '0':
+            if moptions['test'][1]>0.5:
+                t2files = t2files[:int(len(t2files)*moptions['test'][1])]
+            else: t2files = t2files[-int(len(t2files)*moptions['test'][1]):]
+        print(t2files.__sizeof__(), len(t2files))
+        sys.stdout.flush();
+    return t1files, t2files
+
+#
+# get data from a data file
+#
+def getDataFromFile(fn, moptions):
+    mdata = np.loadtxt(fn, dtype=np.float32)
+    # split into label (first two columns) and features
+    t0, ty, tx = np.split(mdata, [0,2], axis=1);
+    return (tx, ty, None)
+
+#
+# get X and Y from a data file
+#
+def getDataFromFile_new(fn, moptions, mfind0ld=None):
+    mdata = np.loadtxt(fn, dtype=np.float32)
+    # get genomic position, label, and features
+    t0, ty, tx = np.split(mdata, [1,3], axis=1);
+
+    if moptions['test'][0] in ['-', '+']:
+        t0 = t0.astype(int)
+
+    nan_file = []
+    m_data = []; m_y = [];
+    if not mfind0ld==None:
+        pos_to_file_dict = defaultdict(); preind = 0
+        mfind0ldkeys = sorted(list(mfind0ld.keys()));
+    for mind in range(len(ty)):
+        if not mfind0ld==None:
+            if preind<len(mfind0ldkeys) and mind == mfind0ldkeys[preind]:
+                pos_to_file_dict[len(m_y)] = mfind0ld[ mfind0ldkeys[preind] ]
+                preind += 1
+
+        # for region-based independent testing
+        if (ty[mind][0]<0.01 and ty[mind][1]<0.01): continue;
+        if (moptions['test'][0]=='-' and moptions['test'][1]<t0[mind]<moptions['test'][2]) or \
+           (moptions['test'][0]=='+' and (not moptions['test'][1]<t0[mind]<moptions['test'][2])):
+            continue;
+        ## check for nan values in the feature window
+        has_nan_value = False;
+        for cur_row in tx[(mind-int(moptions['windowsize']/2)):(mind+int(moptions['windowsize']/2)+1)]:
+            if np.isnan(cur_row).any():
+                has_nan_value = True;
+                break;
+        if has_nan_value:
+            if fn in nan_file: pass
+            else:
+                print ("Warning-nan-value {}".format(fn))
+                nan_file.append(fn);
+        else:
+            m_y.append(ty[mind])
+            m_data.append(tx[(mind-int(moptions['windowsize']/2)):(mind+int(moptions['windowsize']/2)+1)])
+    if not mfind0ld==None:
+        # since each feature file contains data from several fast5 files,
+        # get the starting rows of the fast5 files in the feature file
+        file_to_pos_dict = defaultdict();
+        ptofkeys = sorted(list(pos_to_file_dict.keys()))
+        for npk_ind in range(len(ptofkeys)):
+            if (npk_ind+1<len(ptofkeys) and ptofkeys[npk_ind+1]-ptofkeys[npk_ind]<500) or len(m_y)-ptofkeys[npk_ind]<500: continue;
+
+            file_to_pos_dict[ pos_to_file_dict[ptofkeys[npk_ind]] ] = [ptofkeys[npk_ind], (ptofkeys[npk_ind+1] if npk_ind+1<len(ptofkeys) else len(m_y))]
+
+    # reshape the data
+    if len(m_data)>0:
+        m_data = np.reshape(m_data, (len(m_data), len(m_data[0]), len(m_data[0][0])))
+        m_y = np.reshape(m_y, (len(m_y), len(m_y[0]))).astype(int)
+
+    if not mfind0ld==None:
+        return (m_data, m_y, file_to_pos_dict);
+    else: return (m_data, m_y, None)
+
+#
+# get index information for each fast5 file in a data file
+#
+def getGZFilePos(gzfile):
+    # get the starting row of each fast5 file in a feature file
+    mfind = defaultdict()
+    with open(gzfile[:-len('.gz')]+'.ind', 'r') as mr:
+        line = mr.readline()
+        while line:
+            line = line.strip();
+            lsp = line.split();
+            if len(lsp)>1:
+                mfind[int(lsp[0])] = lsp[1]
+            line = mr.readline()
+    return mfind
+
+#
+# make a prediction for data from a data file
+#
+def mPred(mfbase, mffolder, accuracy, X, Y, test_gzfile2, pf, num_input, auc_op, mpre, mspf, init_l, mfpred, timesteps, moptions):
+    config = tf.ConfigProto()
+    config.gpu_options.allow_growth = True
+    with tf.Session(config=config) as sess:
+        # restore the trained model
+        new_saver = tf.train.import_meta_graph(mfbase+'.meta')
+        new_saver.restore(sess, tf.train.latest_checkpoint(mffolder))
+
+        # write the prediction information
+        pfwriter = open(pf, 'w');
+        for test_gzfile in test_gzfile2:
+            for test_fn_ind in range(len(test_gzfile)):
+                # get feature files for prediction
+                test_gzfeature, test_gzlabel, _ = getDataFromFile_new(test_gzfile[test_fn_ind], moptions)
+                if len(test_gzfeature)<1: continue;
+
+                ftlist = np.array_split(test_gzfeature, int(len(test_gzfeature)/batchsize)+1)
+                lblist = np.array_split(test_gzlabel, int(len(test_gzlabel)/batchsize)+1)
+                for fti in range(len(ftlist)):
+                    sess.run(init_l)
+                    mfpred_output = sess.run([mfpred], feed_dict={X:ftlist[fti], Y:lblist[fti]})
+                    # count true positives, false positives, false negatives and true negatives
+                    tp, fp, fn, tn = 0, 0, 0, 0
+                    for pi in range(len(mfpred_output[0])):
+                        if mfpred_output[0][pi]==1:
+                            if lblist[fti][pi][1]==1: tp += 1
+                            else: fp += 1
+                        else:
+                            if lblist[fti][pi][1]==1: fn += 1
+                            else: tn += 1
+                    pfwriter.write('tp=%d fp=%d fn=%d tn=%d %s\n' % (tp, fp, fn, tn, test_gzfile[test_fn_ind]))
+                    pfwriter.flush()
+        pfwriter.close();
+
+#
+# entry for prediction
+#
+def pred_prepare(moptions, test_file, accuracy, X, Y, auc_op, mpre, mspf, init_l, mfpred):
+    mPred(moptions['modfile'][0], moptions['modfile'][1], accuracy, X, Y, test_file, moptions['outFolder']+moptions['FileID']+'_mpred.txt', moptions['fnum'], auc_op, mpre, mspf, init_l, mfpred, moptions['windowsize'], moptions)
+
+#
+# prepare the training process
+#
+def mMult_RNN_LSTM_train(moptions):
+    # get folders of feature files
+    filegroups = moptions['wrkBase'].split(';')
+    for i in range(len(filegroups)):
+        filegroups[i] = filegroups[i].split(',')
+
+    print(filegroups)
+
+    # get feature files
+    filelists = [[] for _ in range(len(filegroups))]
+    for i in range(len(filegroups)):
+        for fgj in range(len(filegroups[i])):
+            if not len(filegroups[i][fgj])>0: continue
+            filelists[i].extend(getTFiles1(filegroups[i][fgj], moptions))
+    mostnum, mostid = 0, -1;
+    # fix the seed so that the shuffling is reproducible
+    np.random.seed(3)
+    for i in range(len(filelists)):
+        np.random.shuffle(filelists[i])
+        if len(filelists[i])>mostnum:
+            mostnum = len(filelists[i])
+            mostid = i;
+
+    # fix the seed so that the shuffling is reproducible
+    np.random.seed(7)
+    # path for the model file
+    if 'modfile' in moptions and (not moptions['modfile']==None):
+        if moptions['modfile'].rfind('/')==-1:
+            moptions['modfile'] = [moptions['modfile'], './']
+        else:
+            moptions['modfile'] = [moptions['modfile'], moptions['modfile'][:moptions['modfile'].rfind('/')+1]]
+
+    # put the group with the most files first
+    if not mostid==0:
+        filelists[mostid], filelists[0] = filelists[0], filelists[mostid]
+
+    accuracy, X, Y, auc_op, mpre, mspf, init_l, mfpred = train_save_model(filelists, moptions['fnum'], moptions['hidden'], moptions['windowsize'], moptions)
+
+#
+# prepare the prediction process
+#
+def pred_entry(moptions):
+    # get feature files
+    tfiles = [getTFiles1(moptions['wrkBase'], moptions)]
+
+    init, init_l, loss_op, accuracy, train_op, X, Y, saver, auc_op, mpre, mspf, mfpred = mCreateSession(moptions['fnum'], moptions['hidden'], moptions['windowsize'], moptions)
+
+    # path for the model file
+    if moptions['modfile'].rfind('/')==-1:
+        moptions['modfile'] = [moptions['modfile'], './']
+    else:
+        moptions['modfile'] = [moptions['modfile'], moptions['modfile'][:moptions['modfile'].rfind('/')+1]]
+
+    pred_prepare(moptions, tfiles, accuracy, X, Y, auc_op, mpre, mspf, init_l, mfpred)
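+
+#
+# Illustrative sketch with placeholder values: the 'moptions' keys read by the
+# functions above. In DeepMod these values come from the command-line parser,
+# which is not shown here, so the numbers and paths below are only examples.
+#
+if __name__ == '__main__':
+    example_options = {
+        'wrkBase': 'features/pos;features/neg', # ';'-separated groups of ','-separated folders with *.xy.gz files
+        'outFolder': './train_mod/',            # where checkpoints are written
+        'FileID': 'mod',                        # prefix of the saved model files
+        'fnum': 7,                              # number of features per position (num_input)
+        'hidden': 100,                          # LSTM hidden size
+        'windowsize': 21,                       # timesteps: window of positions around the candidate site
+        'recursive': 1,                         # also search sub-folders for *.xy.gz
+        'test': ['N', 0, 0],                    # neither read-based ('0') nor region-based ('-'/'+') hold-out
+        'outputlayer': '',                      # 'sigmoid' for a sigmoid output layer; otherwise a linear layer
+        'unbalanced': None,                     # 1 to scale the logits by class_weights in the loss
+        'modfile': None,                        # path prefix of an existing model (used for prediction/resuming)
+    }
+    mMult_RNN_LSTM_train(example_options)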