# Detection module of DeepMod: reads nanopore fast5 files, extracts events and
# raw signals, aligns base-called reads to a reference genome, and predicts
# base modifications with an RNN model.

import os;
import sys;
import string;
import glob;
import time
import copy

import h5py
import numpy as np
import multiprocessing

from collections import defaultdict
from distutils.version import LooseVersion

import tempfile
import subprocess

import re;

from . import myCom
#import myCom

import tensorflow as tf
from tensorflow.contrib import rnn
from . import myMultiBiRNN
from . import EventTable
from . import MoveTable

# number of feature windows fed to the RNN per prediction batch
rnn_pred_batch_size = 512

# default path for data in fast5 files
fast5_channel_id= 'UniqueGlobalKey/channel_id'
fast5_analysis = ''.join(['/', myCom.analyses_base]) #
fast5_events = myCom.basecall_events_base #
fast5_rawReads = ''.join(['/', myCom.raw_base, '/', myCom.reads_base]) #
fast5_basecall_fq = myCom.basecall_fastq_base #
fast5_signal = myCom.signal_base #

# base name used for the per-chromosome prediction index files
pre_base_str = 'rnn.pred.ind'

#
# get digitisation, offset, range, sampling_rate from fast5 files
#
def get_channel_info(moptions, sp_param):
    # Read the channel attributes of the fast5 into sp_param["channel_info"];
    # on any failure the read is flagged via raiseError() and no key is set.
    if not sp_param['f5status']=="": return;
    try:
        channel_info = sp_param['f5reader'][fast5_channel_id].attrs
        sp_param["channel_info"] = {'digitisation':channel_info['digitisation'], 'offset':channel_info['offset'], 'range':channel_info['range'], 'sampling_rate':channel_info['sampling_rate'], 'channel_number':channel_info['channel_number']}
    except:
        raiseError("No Channel Info", sp_param, "No Channel Info")

# print error message
def raiseError(sp_info, sp_param, errk):
    # Record the error key in sp_param['f5status'] (checked by all later
    # processing steps) and print a diagnostic for this fast5 file.
    sp_param['f5status'] = errk
    print ('Error!!! %s in %s' % (sp_info, sp_param['mfile_path']))
    sys.stdout.flush()

#
# get Albacore version used
# only support v1+ and v2+
#
def getAlbacoreVersion(moptions, sp_param):
    # Store 1 or 2 in sp_param['used_albacore_version'] according to the
    # 'version' attribute of the base-calling group; any failure (or a
    # missing attribute) falls back to version 1.
    if not sp_param['f5status']=="": return;
    try:
        ver_path = ''.join([fast5_analysis,'/', moptions['basecall_1d'] ])
        #add .decode("utf-8") to make it compatible to py3
        try:
            used_version = LooseVersion(sp_param['f5reader'][ver_path].attrs['version'].decode("utf-8") if 'version' in sp_param['f5reader'][ver_path].attrs else "0.0")
        except:
            used_version = LooseVersion(sp_param['f5reader'][ver_path].attrs['version'] if 'version' in sp_param['f5reader'][ver_path].attrs else "0.0")
        sp_param['get_albacore_version'] = used_version
        if used_version < LooseVersion("1.0"): #
            sp_param['used_albacore_version'] = 1;
        elif used_version < LooseVersion("2.0"): sp_param['used_albacore_version'] = 1;
        elif used_version >= LooseVersion("2.0"): sp_param['used_albacore_version'] = 2;
    except: # default verion is 1 now
        sp_param['used_albacore_version'] = 1;

# not used now.
def get_kmer_corrected_info(moptions):
    """Load a k-mer pore model into moptions['kmer_model_dict'].

    Each non-comment line is expected to hold at least
    (kmer, level_mean, level_stdv); the stored value per k-mer is
    (level_mean, 1/level_stdv**2).  Unparsable lines are skipped.
    Does nothing when no usable model file is configured.
    """
    if ('kmer_model_file' not in moptions) or moptions['kmer_model_file']==None or (not os.path.isfile(moptions['kmer_model_file'])): return;

    moptions['kmer_model_dict'] = defaultdict()
    # bug fix: the original used `string.strip(line)`, whose function form was
    # removed from the string module in Python 3; also read the file through a
    # context manager instead of manual readline()/close().
    with open(moptions['kmer_model_file'], 'r') as fr:
        for line in fr:
            line = line.strip()
            if len(line)>0 and (not line[0]=='#'):
                try:
                    c_kmer, c_level_mean, c_level_stdv = line.split()[:3]
                    c_level_mean, c_level_stdv = float(c_level_mean), float(c_level_stdv)
                    # mean and inverse variance of the model current level
                    moptions['kmer_model_dict'][c_kmer] = (c_level_mean, 1/(c_level_stdv*c_level_stdv))
                except ValueError: # malformed line: skip and keep reading
                    pass;

# not used now
# get shift and scale values for normalization
#
def get_cur_shift_scale(moptions, sp_param):
    """Least-squares fit of shift/scale between event means and the k-mer
    model, then rescale sp_param['raw_signals'] accordingly.

    Requires get_kmer_corrected_info() to have populated
    moptions['kmer_model_dict']; flags the read via raiseError() on failure.
    """
    if not sp_param['f5status']=="": return;
    if "kmer_model_dict" not in moptions: return;

    event_key = 'm_event'

    try:
        # np.float64 replaces the np.float alias removed in NumPy >= 1.20
        cur_model = np.array([moptions['kmer_model_dict'][c_model_state] for c_model_state in sp_param[event_key]['model_state']], dtype=[('level_mean', np.float64), ('level_stdv', np.float64)]);
        c_mean_stdv = cur_model['level_mean']*cur_model['level_stdv']
        c_mean_stdv_sum = c_mean_stdv.sum()
        # 2x2 normal-equation matrix of the linear fit
        model_coef_matrix = np.array(( (cur_model['level_stdv'].sum(), c_mean_stdv_sum), \
                                       (c_mean_stdv_sum, (c_mean_stdv*cur_model['level_mean']).sum()) \
                                    ))
        c_event_stdv = sp_param[event_key]['mean'] * cur_model['level_stdv']
        c_event_stdv_mean = c_event_stdv * cur_model['level_mean']
        dependent_array = np.array((c_event_stdv.sum(), c_event_stdv_mean.sum()));

        sp_param['shift_scale'] = {}
        sp_param['shift_scale']['cal_shift'], sp_param['shift_scale']['cal_scale'] = np.linalg.solve(model_coef_matrix, dependent_array)
        sp_param['shift_scale']['chn_shift'], sp_param['shift_scale']['chn_scale'] = -sp_param["channel_info"]['offset'], sp_param["channel_info"]['digitisation']/sp_param["channel_info"]['range']

        sp_param['shift_scale']['shift']=sp_param['shift_scale']['chn_shift']+sp_param['shift_scale']['chn_scale']*sp_param['shift_scale']['cal_shift']
        sp_param['shift_scale']['scale']=sp_param['shift_scale']['chn_scale']*sp_param['shift_scale']['cal_scale']

        sp_param['raw_signals'] = np.round(sp_param['raw_signals']/sp_param['shift_scale']['cal_scale'] - sp_param['shift_scale']['cal_shift']/sp_param['shift_scale']['cal_scale'], 6)
    except:
        raiseError('Cannot nanopore correction', sp_param, "Cannot nanopore correction")
#
# get events from a fast5 file
#
def getEvent(moptions, sp_param):
    """Build the per-base event table sp_param['m_event'] for one read.

    Uses the Move table when moptions['move'] is set; otherwise reads the
    basecaller event table and converts it according to the Albacore
    version (1.x or 2.x).  On success sets:
      sp_param['m_event']          - structured array of
                                     (mean, stdv, start, length, model_state)
      sp_param['m_event_basecall'] - base sequence derived from the events
      sp_param['left_right_skip']  - events skipped at the two read tails
    On failure sp_param['f5status'] is set via raiseError().
    """
    if not sp_param['f5status']=="": return;

    # If use move tables intead of event tables
    if moptions['move']:
        try: # get events from a fast5 file'
            mv_str = '/'.join(['', 'Analyses', moptions['basecall_1d'], moptions['basecall_2strand'], 'Move'])
            move_data = sp_param['f5reader'][mv_str][()]
            sp_param['events_data'] = move_data
        except:
            raiseError('No move data', sp_param, "No move data")
            return;
        m_event = MoveTable.getMove_Info(moptions, sp_param, move_data)
        sp_param['m_event'] = m_event
        # get sequence from events
        sp_param['m_event_basecall'] = sp_param['fq_seq']
        sp_param['left_right_skip'] = (0, 0)

        return
        # End the part of getting move tables

    try: # get events from a fast5 file
        event_path = ''.join([fast5_analysis, '/', moptions['basecall_1d'], '/', moptions['basecall_2strand'], '/', fast5_events])
        # NOTE(review): Dataset.value is deprecated in h5py (use [()] instead);
        # kept unchanged here
        events_data = sp_param['f5reader'][event_path].value
    except:
        raiseError('No events data', sp_param, "No events data")
        return;

    # becomes True once a start-index overflow has been reported for this read
    convertError = False;

    if sp_param['f5status'] == "":
        sp_param['events_data'] = events_data
        if sp_param['used_albacore_version']==1:
            move0_left = 0; move0_right = len(events_data)-1;
            while move0_left<move0_right: # get the first non-stay event at the left tail
                if events_data['move'][move0_left]==0: move0_left += 1;
                else: break;
            if move0_left>move0_right-20:
                raiseError(("Too many move0 at 3'(l%d, r%d)" % (move0_left, move0_right)), sp_param, "Remove too many bases on left")
                return;
            while move0_right>move0_left: # get the last non-stay event at the right tail
                if events_data['move'][move0_right]==0: move0_right -= 1
                else: break;
            if move0_right<move0_left+20:
                raiseError(("Too many move0 at 5'(l%d, r%d)" % (move0_left, move0_right)), sp_param, 'Remove too many bases on right')
                return

            # get the starting time
            based_ind = events_data['start'][move0_left].astype(np.float64)*sp_param["channel_info"]["sampling_rate"] - sp_param['raw_attributes']['start_time']
            first_base_index_in_raw_signal = np.round(events_data['start'][move0_left].astype(np.float64)*sp_param["channel_info"]["sampling_rate"]).astype(np.int64) - sp_param['raw_attributes']['start_time']
            # get the potential error of the starting time
            if first_base_index_in_raw_signal<-2:
                raiseError(('The index of the first base is less than -2(%d=%.6f*%d-%d)' % (first_base_index_in_raw_signal, events_data['start'][move0_left].astype(np.float64), sp_param["channel_info"]["sampling_rate"], sp_param['raw_attributes']['start_time'])), sp_param, "The index of the first base is less than -2")
                return;
            elif first_base_index_in_raw_signal<0:
                # small negative rounding slack is tolerated and clamped to 0
                first_base_index_in_raw_signal = 0
                if moptions['outLevel']<=myCom.OUTPUT_INFO: print ('Warning!!! first_base_index_in_raw_signal less than 0 ' + sp_param['mfile_path'])

            first_base_index_in_raw_signal = np.uint64(first_base_index_in_raw_signal)

            # merge stay events (move==0) into their preceding non-stay event
            m_event = []; pre_i = move0_left;
            cur_length=(events_data['length'][pre_i]*sp_param["channel_info"]["sampling_rate"]).astype('uint64');
            for i in range(move0_left+1, move0_right+1):
                if events_data['move'][i]>0: # for non-stay event
                    if pre_i==move0_left:
                        m_event.append((round(events_data['mean'][pre_i],3), round(events_data['stdv'][pre_i],3), first_base_index_in_raw_signal, cur_length, events_data['model_state'][pre_i].upper()))
                    else: # calculate starting index in raw signal
                        # calculated position
                        cal_st = (events_data['start'][pre_i]-events_data['start'][move0_left])*sp_param["channel_info"]["sampling_rate"]+based_ind
                        if cal_st<0: print("Wanging Less than 0")
                        if cal_st>0 and cal_st - (m_event[-1][2]+ m_event[-1][3])>0 and (cal_st - (m_event[-1][2]+ m_event[-1][3])).astype('uint64')>0:
                            if (cal_st - (m_event[-1][2]+ m_event[-1][3])).astype('uint64')>2: #
                                # gap of >2 samples: insert a filler event covering the gap
                                m_event.append((round(events_data['mean'][pre_i],3), round(events_data['stdv'][pre_i],3), m_event[-1][2]+ m_event[-1][3], (cal_st - (m_event[-1][2]+ m_event[-1][3])).astype('uint64'), events_data['model_state'][pre_i].upper()))
                                m_event.append((round(events_data['mean'][pre_i],3), round(events_data['stdv'][pre_i],3), cal_st.astype('uint64'), cur_length, events_data['model_state'][pre_i].upper()))
                            else: # for a normal event
                                # small gap: absorb it into this event's length
                                m_event.append((round(events_data['mean'][pre_i],3), round(events_data['stdv'][pre_i],3), m_event[-1][2]+ m_event[-1][3], (cal_st - (m_event[-1][2]+ m_event[-1][3])).astype('uint64')+ cur_length, events_data['model_state'][pre_i].upper()))
                        else:
                            # no gap: continue right after the previous event
                            m_event.append((round(events_data['mean'][pre_i],3), round(events_data['stdv'][pre_i],3), m_event[-1][2]+ m_event[-1][3], cur_length, events_data['model_state'][pre_i].upper()))
                    if m_event[-1][2]>np.iinfo(np.int64).max-2 or m_event[-1][2]<0:
                        # start-index overflow/underflow sanity check (reported once)
                        if not convertError:
                            # NOTE(review): looks like a leftover Python-2 print
                            # statement; in Python 3 only the formatted string is
                            # printed and the trailing values form a discarded
                            # tuple -- confirm intent
                            print ('ex: %.7f*%d=%.0f' % (events_data['start'][move0_left].astype(np.float64), sp_param["channel_info"]["sampling_rate"], events_data['start'][move0_left].astype(np.float64)*sp_param["channel_info"]["sampling_rate"])), sp_param['raw_attributes']['start_time'], sp_param['mfile_path'], m_event[-1][2], m_event[-1][3]
                            convertError = True;
                    pre_i = i;
                    cur_length=(events_data['length'][i]*sp_param["channel_info"]["sampling_rate"]).astype('uint64');
                else: # for stay event
                    cur_length += (events_data['length'][i]*sp_param["channel_info"]["sampling_rate"]).astype('uint64')
            if sp_param['f5status'] == "": # for the last event
                # calculated position
                cal_st = (events_data['start'][pre_i]-events_data['start'][move0_left])*sp_param["channel_info"]["sampling_rate"]+based_ind
                if cal_st<0: print("Wanging Less than 0")
                if cal_st>0 and cal_st - (m_event[-1][2]+ m_event[-1][3])>0 and (cal_st - (m_event[-1][2]+ m_event[-1][3])).astype('uint64')>0:
                    if (cal_st - (m_event[-1][2]+ m_event[-1][3])).astype('uint64')>2:
                        m_event.append((round(events_data['mean'][pre_i],3), round(events_data['stdv'][pre_i],3), m_event[-1][2]+ m_event[-1][3], (cal_st - (m_event[-1][2]+ m_event[-1][3])).astype('uint64'), events_data['model_state'][pre_i].upper()))
                        m_event.append((round(events_data['mean'][pre_i],3), round(events_data['stdv'][pre_i],3), cal_st.astype('uint64'), cur_length, events_data['model_state'][pre_i].upper()))
                    else:
                        m_event.append((round(events_data['mean'][pre_i],3), round(events_data['stdv'][pre_i],3), m_event[-1][2]+ m_event[-1][3], (cal_st - (m_event[-1][2]+ m_event[-1][3])).astype('uint64')+ cur_length, events_data['model_state'][pre_i].upper()))
                else:
                    m_event.append((round(events_data['mean'][pre_i],3), round(events_data['stdv'][pre_i],3), m_event[-1][2]+ m_event[-1][3], cur_length, events_data['model_state'][pre_i].upper()))

            # decode
            m_event = np.array(m_event, dtype=[('mean', '<f4'), ('stdv', '<f4'), ('start', np.uint64), ('length', np.uint64), ('model_state', 'U5')]) #'S5')]
            sp_param['m_event'] = m_event
            # get sequence from events
            # (the centre base of each 5-mer model state)
            sp_param['m_event_basecall'] = ''.join([event_model_state[2] for event_model_state in m_event['model_state']]);
            sp_param['left_right_skip'] = (move0_left, len(events_data)-move0_right-1)
        elif sp_param['used_albacore_version']==2:
            if moptions['SignalGroup']=='simple':
                m_event = [];
                pre_i = 0; pre_length = events_data['length'][pre_i].astype('uint64');
                for cur_i in range(1, len(events_data)):
                    if events_data['move'][cur_i]>0: # non-stay vents
                        m_event.append( (round(events_data['mean'][pre_i],3), round(events_data['stdv'][pre_i],3), events_data['start'][pre_i], pre_length, events_data['model_state'][pre_i]) )

                        pre_i = cur_i; pre_length = events_data['length'][pre_i].astype('uint64');
                    else: # for stay events
                        pre_length += events_data['length'][cur_i].astype('uint64');
                m_event.append( (round(events_data['mean'][pre_i],3), round(events_data['stdv'][pre_i],3), events_data['start'][pre_i], pre_length, events_data['model_state'][pre_i]) )
                # format events
                m_event = np.array(m_event, dtype=[('mean', '<f4'), ('stdv', '<f4'), ('start', np.uint64), ('length', np.uint64), ('model_state', 'U5')])
            else:
                m_event = EventTable.getEvent_Info(moptions, sp_param, events_data)

            sp_param['m_event'] = m_event
            # get sequence from events
            sp_param['m_event_basecall'] = ''.join([event_model_state[2] for event_model_state in m_event['model_state']]);
            sp_param['left_right_skip'] = (0, 0)
        else:
            raise RuntimeError ("This version of Albacore is not supported. Please use the version of Albacore 1.x or 2.x")
#
# normalize raw signals
#
def mnormalized(moptions, sp_param):
    """Median/MAD-normalize sp_param['raw_signals'] in place (nanoraw-style).

    The shift/scale statistics come from the part of the signal covered by
    the event table; after standardization the signal is winsorized at
    median +/- 5*MAD and rounded to 3 decimals.
    """
    events = sp_param['m_event']
    sig_beg = events['start'][0]
    sig_end = events['start'][-1] + events['length'][-1]

    # sanity check: the event-covered signal range must be non-empty
    if not sig_beg < sig_end:
        print ('Fatal error signal start position is less than the end position', sp_param['mfile_path'], events['start'][0], events['start'][-1], events['length'][-1])

    # shift/scale from the covered range of the raw signal
    covered = sp_param['raw_signals'][sig_beg:sig_end]
    mshift = np.median(covered)
    mscale = np.median(np.abs(covered - mshift))
    # standardize the whole signal
    sp_param['raw_signals'] = (sp_param['raw_signals'] - mshift) / mscale

    # clipping limits from the covered range of the re-scaled signal
    rescaled = sp_param['raw_signals'][sig_beg:sig_end]
    read_med = np.median(rescaled)
    read_mad = np.median(np.abs(rescaled - read_med))
    lower_lim = read_med - (read_mad * 5)
    upper_lim = read_med + (read_mad * 5)
    # normalize as nanoraw did.
    sp_param['raw_signals'] = np.round(np.clip(sp_param['raw_signals'], lower_lim, upper_lim), 3)
#
# get Signal from a fast5 file
#
def getRawInfo(moptions, sp_param):
    """Load the raw-signal read group of the fast5 into sp_param.

    Sets sp_param['raw_attributes'] (group attributes as a dict) and
    sp_param['raw_signals'] (the Signal dataset); flags the read via
    raiseError() when the group or dataset is missing.
    """
    if sp_param['f5status'] != "":
        return

    try:
        read_group = None
        # a fast5 holds a single read group under Raw/Reads; keep the last one
        for read_group in sp_param['f5reader'][fast5_rawReads].values():
            pass
        sp_param['raw_attributes'] = dict(read_group.attrs.items())
        sp_param['raw_signals'] = read_group['Signal'][()]
    except:
        raiseError(("No Raw_reads/Signal data %s" % (fast5_rawReads)), sp_param, "No Raw_reads/Signal")
#
# get channel_info, AlbacoreVersion, read_id, Raw Signals, Event from a fast5 file
#
def getFast5Info(moptions, sp_param):
    """Populate sp_param with everything needed from one fast5 file.

    Fills in channel info, basecaller version, read id, fastq sequence,
    raw signals and the per-base event table, then normalizes the signals
    and recomputes each event's mean/stdv from the normalized signal.
    Any failure sets sp_param['f5status'] to an error key.
    """
    # get channel info
    get_channel_info(moptions, sp_param)
    if "channel_info" not in sp_param:
        raiseError(("Channel information could not be found in %s " % fast5_channel_id), sp_param, "Channel information could not be found")
        return;
    # get albacore version
    getAlbacoreVersion(moptions, sp_param)
    if 'used_albacore_version' not in sp_param:
        return

    try:
        # get fastq attribute
        fq_path = ''.join([fast5_analysis,'/',moptions['basecall_1d'],'/',moptions['basecall_2strand'],'/',fast5_basecall_fq])
        fq_data = sp_param['f5reader'][fq_path][()]
    except:
        raiseError('No Fastq data', sp_param, "No Fastq data")
    if sp_param['f5status']=="":
        fq_data = (fq_data.decode(encoding="utf-8")).split('\n')
        # read id: strip a leading '@' and make spaces/tabs field-safe
        sp_param['read_id'] = (fq_data[0][1:] if fq_data[0][0]=='@' else fq_data[0]).replace(" ", ":::").replace("\t", "|||")
        sp_param['fq_seq'] = fq_data[1];
    # get raw signals
    getRawInfo(moptions, sp_param)
    # get events
    if sp_param['f5status']=="":
        getEvent(moptions, sp_param)
    # normalize signals.
    if sp_param['f5status']=="":
        mnormalized(moptions, sp_param)

    if sp_param['f5status']=="":
        # get mean, std for each event
        for i in range(len(sp_param['m_event'])):
            if (len(sp_param['raw_signals'][sp_param['m_event']['start'][i]:sp_param['m_event']['start'][i]+sp_param['m_event']['length'][i]])==0):
                print ('Signal out of range {}: {}-{} {};{} for {}'.format(i, sp_param['m_event']['start'][i], sp_param['m_event']['length'][i], len(sp_param['m_event']), len(sp_param['raw_signals']), sp_param['mfile_path']))
                if i>500:
                    # enough events collected so far: keep the usable prefix
                    sp_param['m_event'] = sp_param['m_event'][:i-1]
                else:
                    # bug fix: was `sp_param['f5status']=="Less event"`, a no-op
                    # comparison -- the failure was never flagged and a
                    # truncated read was processed as if it were fine
                    sp_param['f5status'] = "Less event"
                break;
            sp_param['m_event']['mean'][i] = round(np.mean(sp_param['raw_signals'][sp_param['m_event']['start'][i]:sp_param['m_event']['start'][i]+sp_param['m_event']['length'][i]]), 3)
            sp_param['m_event']['stdv'][i] = round(np.std(sp_param['raw_signals'][sp_param['m_event']['start'][i]:sp_param['m_event']['start'][i]+sp_param['m_event']['length'][i]]), 3)
#
# associate signals for each event in a fast5 file
#
def get_Event_Signals(moptions, sp_options, f5files):
    """Collect per-read event/signal data from a list of fast5 files.

    Returns a dict keyed by read id holding tuples of
    (event base-calls, event table, raw signals, file path,
    (left skip, right skip)); failed files are recorded per error key in
    sp_options["Error"], and basecaller versions are tallied in
    sp_options["get_albacore_version"].
    """
    debugging = moptions['outLevel']<=myCom.OUTPUT_DEBUG
    if debugging:
        start_time = time.time(); runnum = 0

    f5data = {}
    sp_options["Error"] = defaultdict(list)
    sp_options["get_albacore_version"] = defaultdict(int)
    # process every fast5 file independently; one bad file never stops the batch
    for fast5_path in f5files:
        try:
            with h5py.File(fast5_path, 'r') as f5_handle:
                per_read = {}
                per_read['mfile_path'] = fast5_path
                per_read['f5reader'] = f5_handle
                per_read['f5status'] = ""
                getFast5Info(moptions, per_read)
                if 'get_albacore_version' in per_read:
                    sp_options["get_albacore_version"][str(per_read['get_albacore_version'])] += 1
                if per_read['f5status'] == "":
                    if per_read['read_id'] in f5data:
                        print ('Duplicate id', per_read['read_id'], fast5_path)
                    f5data[per_read['read_id']] = (per_read['m_event_basecall'], per_read['m_event'], per_read['raw_signals'], fast5_path, per_read['left_right_skip'])
                else:
                    sp_options["Error"][per_read['f5status']].append(fast5_path)
            # progress reporting every 500 files
            if debugging:
                runnum += 1
                if runnum%500==0:
                    end_time = time.time()
                    print ("%d consuming time %d" % (runnum, end_time-start_time))
        except:
            sp_options["Error"]["Cannot open fast5 or other errors"].append(fast5_path)
            print("Cannot open fast5 or other errors: {}".format(fast5_path))
    if debugging:
        end_time = time.time()
        print ("All consuming time %d" % (end_time-start_time))

    return f5data
#
# get signals of events
# map bases from events to a reference genome
#
def mDetect1(moptions, sp_options, f5files):
    """Pipeline driver for one batch of fast5 files.

    Extracts event/signal data, writes the event base-calls to a temporary
    fasta, aligns it with bwa-mem or minimap2, parses the SAM output via
    handle_line(), and hands the mapped reads to handle_record() for
    feature generation, prediction and saving.  Alignment failures are
    recorded per file in sp_options["Error"].
    """
    f5data = get_Event_Signals(moptions, sp_options, f5files)

    if moptions['outLevel']<=myCom.OUTPUT_DEBUG: start_time = time.time();
    # for fa files of base sequences from events
    temp_fa = tempfile.NamedTemporaryFile(suffix='.fa', mode='w')
    f5keys = sorted(f5data.keys()); #f5keys.sort()
    for f5k in f5keys:
        temp_fa.write(''.join(['>', f5k, '\n', f5data[f5k][0], '\n']))
    temp_fa.flush();
    if moptions['outLevel']<=myCom.OUTPUT_DEBUG:
        end_time = time.time();
        print ("Write consuming time %d" % (end_time-start_time))

    # run alignment tools of bwa-mem or minimap2
    temp_sam = tempfile.NamedTemporaryFile()
    if moptions['alignStr']=='bwa':
        cmd_opt = ['mem', '-x', 'ont2d', '-v', '1', '-t', '1', moptions['Ref'], temp_fa.name]
    else:
        cmd_opt = ['-ax', 'map-ont', moptions['Ref'], temp_fa.name]
    curcmd = [moptions['alignStr'],]+cmd_opt
    returncode = subprocess.call(curcmd, stdout=temp_sam)
    if not returncode==0:
        # bug fix: the original formatted an undefined name `curcmd` here,
        # raising NameError and masking the aligner failure
        print ('Fatal Error!!! returncode is non-zero(%d) for "%s"' % (returncode, ' '.join(curcmd)))
        errkey = "Cannot running aligment"
        for f5k in f5keys:
            sp_options["Error"][errkey].append(f5data[f5k][3])
        # close (and thereby delete) the temporary files on failure too
        temp_fa.close(); temp_sam.close();
        return;

    temp_fa.close();
    temp_sam.seek(0);
    # get content from aligned results
    align_info = temp_sam.readlines()
    align_info = [str(align_info[i], 'utf-8').strip() for i in range(len(align_info))]
    temp_sam.close();

    sp_param = defaultdict();
    sp_param['f5data'] = f5data

    f5align = defaultdict()
    f5keydict = defaultdict();
    sp_param['ref_info'] = defaultdict()

    if moptions['outLevel']<=myCom.OUTPUT_DEBUG:start_time = time.time();
    ilid = 0;
    # get alignment records (skip header and empty lines)
    while ilid < len(align_info):
        if len(align_info[ilid])==0 or align_info[ilid][0]=='@':
            ilid += 1
            continue;

        sp_param['f5status'] = "";
        sp_param['line'] = align_info[ilid]
        qname = handle_line(moptions, sp_param, f5align)
        if sp_param['f5status'] == "":
            f5keydict[qname] = True;
        ilid += 1

    # get unmapped reads
    for f5k in f5keys:
        if f5k not in f5keydict:
            sp_options["Error"]["Not in alignment sam"].append(f5data[f5k][3])

    if moptions['outLevel']<=myCom.OUTPUT_DEBUG:
        end_time = time.time();
        print ("Get BAM consuming time %d" % (end_time-start_time))

    # get features, prediction for each fast5 files
    sp_param['f5status']= ""
    sp_param['line'] = ""
    if moptions['outLevel']<=myCom.OUTPUT_DEBUG:start_time = time.time();
    handle_record(moptions, sp_options, sp_param, f5align, f5data)
    if moptions['outLevel']<=myCom.OUTPUT_DEBUG:
        end_time = time.time();
        print ("Analyze & annotate & save consuming time %d" % (end_time-start_time))
#
# get reference sequenceng from a reference genome
#
def getRefSeq(moptions, sp_param, rname):
    """Fetch chromosome `rname` via `samtools faidx` into
    sp_param['ref_info'][rname] (upper-cased, newline-free).

    On samtools failure only an error is printed and no entry is stored.
    """
    temp_seq = tempfile.NamedTemporaryFile()
    # command to be run
    cmd_opt = ['faidx', moptions['Ref'], rname]
    returncode = subprocess.call(['samtools',]+cmd_opt, stdout=temp_seq)
    if not returncode==0:
        print ('Fatal Error!!! cannot find the chrosome sequence %s' % rname)
        # bug fix: release (and delete) the temporary file on failure too
        temp_seq.close()
    else:
        temp_seq.seek(0);
        seqinfo = temp_seq.readlines()
        seqinfo = [str(seqinfo[i], 'utf-8').strip() for i in range(len(seqinfo))]
        temp_seq.close();
        # drop the fasta header line and join the sequence lines
        sp_param['ref_info'][rname] = ''.join(seqinfo[1:]).strip().upper()

#
# get mapping information and associate it with events/signals.
#
def handle_record(moptions, sp_options, sp_param, f5align, f5data):
    """For every aligned read: parse the CIGAR, pair reference and read
    bases, trim unmatched tails, run the modification prediction and save
    per-read results plus an index of prediction records."""
    alignkeys = list(f5align.keys());
    # raw strings for the regex patterns (same patterns as before)
    numreg = re.compile(r'\d+')
    mdireg = re.compile(r'[MIDNSHPX=]{1}')

    # for each alignment record
    for readk_ind in range(len(alignkeys)):
        sp_param['f5status']= ""
        readk = alignkeys[readk_ind]
        sp_param['mfile_path'] = f5data[readk][3]

        mapq, flag, rname, pos, cigar, readseq = f5align[readk]

        # output chromosome of interest
        if (not moptions['ConUnk']) and ((not rname.find('_')==-1) or (not rname.find('-')==-1) or (not rname.find('/')==-1) or (not rname.find(':')==-1)):
            continue;
        isinreg = False;
        # check the region of interest
        for cur_mr in moptions['region']:
            if (cur_mr[0] in ['', None, rname]):
                isinreg = True;
                break;
        if not isinreg:
            continue;

        # get reference information
        if rname not in sp_param['ref_info']:
            getRefSeq(moptions, sp_param, rname)
        refseq = sp_param['ref_info'][rname]

        # mapped position (0-based) and strand
        pos = pos - 1
        forward_reverse = '-' if flag&0x10 else '+'

        numinfo = numreg.findall(cigar);
        mdiinfo = mdireg.findall(cigar)
        numinfo = [int(numinfo[i]) for i in range(len(numinfo))] #map(int, numinfo)

        leftclip = 0; rightclip = 0;
        # remove left clip
        while mdiinfo[0] in ['I', 'D', 'N', 'S', 'H', 'P', 'X']:
            if mdiinfo[0] in ['I', 'S', 'X']:
                leftclip += numinfo[0]; readseq = readseq[numinfo[0]:]
            if mdiinfo[0] in ['H']: leftclip += numinfo[0]
            if mdiinfo[0] in ['D', 'N', 'X']:
                pos += numinfo[0]
            numinfo = numinfo[1:]; mdiinfo = mdiinfo[1:]
        # remove right clip
        while mdiinfo[-1] in ['I', 'D', 'N', 'S', 'H', 'P', 'X']:
            if mdiinfo[-1] in ['I', 'S', 'X']:
                rightclip += numinfo[-1]; readseq = readseq[:-numinfo[-1]]
            if mdiinfo[-1] in ['H']: rightclip += numinfo[-1]
            numinfo = numinfo[:-1]; mdiinfo = mdiinfo[:-1]
        if forward_reverse=='+': # remove clipped events
            if rightclip>0: m_event = f5data[readk][1][leftclip:-rightclip]
            else: m_event = f5data[readk][1][leftclip:]
        else:
            if leftclip>0: m_event = f5data[readk][1][rightclip:-leftclip]
            else: m_event = f5data[readk][1][rightclip:]

        isinreg = False;
        # for specify regions
        for cur_mr in moptions['region']:
            if (cur_mr[0] in ['', None, rname]) and \
               (cur_mr[1] in ['', None] or pos>cur_mr[1]) and \
               (cur_mr[2] in ['', None] or pos+len(m_event)<cur_mr[2]):
                isinreg = True;
                break;
        ### for check consistency
        if not isinreg:
            continue;

        lastmatch = None; firstmatch = None;
        first_match_pos = None; last_match_pos = None
        last_al_match = None; first_al_match = None
        lasmtind = 0;
        base_map_info = []; #indel_groups = defaultdict()
        nummismatch = 0; numinsert = 0; numdel = 0;
        read_ind = 0;
        # get map detail: reference base, read base, ref position, read-position
        for n1ind in range(len(numinfo)):
            mdi = mdiinfo[n1ind];
            for n1i in range(numinfo[n1ind]):
                # for each cigar type
                if mdi=='M':
                    base_map_info.append((refseq[pos], readseq[read_ind], pos, read_ind, 0))
                    if refseq[pos]==readseq[read_ind]:
                        if firstmatch==None: firstmatch = read_ind
                        if lastmatch==None or lastmatch<read_ind: lastmatch = read_ind; lasmtind=n1ind
                        if first_al_match==None: first_al_match=len(base_map_info)-1
                        if last_al_match==None or last_al_match<len(base_map_info): last_al_match=len(base_map_info)-1
                        if first_match_pos==None: first_match_pos = pos
                        if last_match_pos==None or last_match_pos<pos: last_match_pos = pos
                    else: nummismatch += 1
                    pos += 1; read_ind += 1;
                elif mdi =='I':
                    base_map_info.append(('-', readseq[read_ind], pos, read_ind, 0))
                    read_ind += 1;
                    numinsert += 1
                elif mdi == 'D':
                    base_map_info.append((refseq[pos], '-', pos, read_ind, 0))
                    pos += 1;
                    numdel += 1
                elif mdi == 'N':
                    base_map_info.append((refseq[pos], '-', pos, read_ind, 0))
                    pos += 1;
                    if moptions['outLevel']<=myCom.OUTPUT_WARNING:
                        print ('CIGAR-Error N exist', f5data[readk][3])
                elif mdi == 'S':
                    read_ind += 1;
                    if moptions['outLevel']<=myCom.OUTPUT_WARNING:
                        print ('CIGAR-Error!!! S in the middle of the sequence', f5data[readk][3])
                elif mdi == 'H':
                    if moptions['outLevel']<=myCom.OUTPUT_WARNING:
                        print ('CIGAR-Error!!! H in the middle of the sequence', f5data[readk][3])
                elif mdi == 'P':
                    if moptions['outLevel']<=myCom.OUTPUT_WARNING:
                        print ('CIGAR-Error!!! P exist', f5data[readk][3])
                elif mdi == '=':
                    base_map_info.append((refseq[pos], readseq[read_ind], pos, read_ind, 0))
                    if first_match_pos==None: first_match_pos = pos
                    if last_match_pos==None or last_match_pos<pos: last_match_pos = pos
                    pos += 1; read_ind += 1;
                    if firstmatch==None: firstmatch = read_ind - 1
                    if lastmatch==None or lastmatch<read_ind-1: lastmatch = read_ind - 1; lasmtind=n1ind
                    if last_al_match==None or last_al_match<len(base_map_info): last_al_match=len(base_map_info)-1
                    if first_al_match==None: first_al_match=len(base_map_info)-1
                elif mdi == 'X':
                    base_map_info.append((refseq[pos], readseq[read_ind], pos, read_ind, 0))
                    pos += 1; read_ind += 1;
                    nummismatch += 1
                else:
                    if moptions['outLevel']<=myCom.OUTPUT_WARNING:
                        print ('CIGAR-Error!!!', 'Warning unknow CIGAR element ' + str(numinfo[n1ind]) + ' ' + mdi, f5data[readk][3])
        if firstmatch==None or lastmatch==None or firstmatch<0 or lastmatch<0:
            if moptions['outLevel']<=myCom.OUTPUT_WARNING:
                print ("Errorfast5 "+f5data[readk][3])
                print('match-Error!!! no first and/or last match',f5data[readk][3],('firstmatch=%d' % firstmatch) if not (firstmatch==None) else "N", ('lastmatch%d' % lastmatch) if not (lastmatch==None) else "N", str(flag), rname, str(pos));
                print('\tf=%d, chr=%s, p=%d, c=%s, s=%s' % (flag, rname, pos, cigar, readseq))
            continue;

        # remove more un-matched events
        if forward_reverse=='+':
            if not firstmatch==None: leftclip += firstmatch
            if (not lastmatch==None) and len(m_event)-lastmatch>1: rightclip += len(m_event)-lastmatch-1
        else:
            if not firstmatch==None: rightclip += firstmatch
            if (not lastmatch==None) and len(m_event)-lastmatch>1: leftclip += len(m_event)-lastmatch-1

        if forward_reverse=='+':
            if len(m_event)-lastmatch>1:
                m_event = m_event[firstmatch:(lastmatch+1-len(m_event))]
            elif firstmatch>0: m_event = m_event[firstmatch:]
        else:
            if firstmatch>0: m_event = m_event[(len(m_event)-1-lastmatch):-firstmatch]
            elif len(m_event)-lastmatch>1: m_event = m_event[(len(m_event)-1-lastmatch):]
        # check potential error
        if firstmatch>0 or len(base_map_info)-last_al_match>1:
            if moptions['outLevel']<=myCom.OUTPUT_WARNING and ((firstmatch>0) or (len(base_map_info)-last_al_match>1 and refseq[last_match_pos+1] not in ['N'])):
                print ("Errorfast5"+f5data[readk][3])
                print ('Warning!!! first not match', firstmatch, lastmatch, first_al_match, last_al_match, len(base_map_info), numinfo[lasmtind-2:(lasmtind+5)], mdiinfo[lasmtind-2:(lasmtind+5)], lasmtind, len(numinfo))
                print('\tref='+refseq[last_match_pos:last_match_pos+20]+"\n\tred="+readseq[lastmatch:lastmatch+20])
                if firstmatch>0:
                    print('\tref='+refseq[(first_match_pos-20 if first_match_pos-20>0 else 0):first_match_pos]+"\n\tred="+readseq[(firstmatch-20 if firstmatch-20>0 else 0):firstmatch])
                print('\tf=%d, chr=%s, p=%d, c=%s, s=%s' % (flag, rname, pos, cigar, readseq)) # flag, rname, pos, cigar, readseq

            if len(base_map_info)-last_al_match>1:
                base_map_info = base_map_info[first_al_match:(last_al_match+1-len(base_map_info))]
            elif first_al_match>0:
                base_map_info = base_map_info[first_al_match:]

        # format base
        # (plain `int` replaces the np.int alias removed in NumPy >= 1.20)
        base_map_info = np.array(base_map_info, dtype=[('refbase', 'U1'), ('readbase', 'U1'), ('refbasei', np.uint64), ('readbasei', np.uint64), ('mod_pred', int)])
        if forward_reverse=='-':
            # reverse strand: flip the alignment and complement both sequences
            base_map_info = np.flipud(base_map_info)
            for bmii in range(len(base_map_info)):
                base_map_info['refbase'][bmii] = get_complement(base_map_info['refbase'][bmii])
                base_map_info['readbase'][bmii] = get_complement(base_map_info['readbase'][bmii])
            leftclip, rightclip = rightclip, leftclip
        if False: #True: # for test base_map_info ### for check consistency
            ref_align_key = '/Analyses/NanomoCorrected_000/BaseCalled_template/Alignment/genome_alignment'
            read_align_key = '/Analyses/NanomoCorrected_000/BaseCalled_template/Alignment/read_alignment'
            with h5py.File(f5data[readk][3], 'r') as mf5:
                read_align_list = [bt.decode(encoding="utf-8") for bt in mf5[read_align_key]]
                ref_align_list = [bt.decode(encoding="utf-8") for bt in mf5[ref_align_key]]
                for rali in range(len(read_align_list)):
                    if not read_align_list[rali]==base_map_info['readbase'][rali]:
                        print ("Error not equal1! %s %s %d %s" % (read_align_list[rali], base_map_info['readbase'][rali], rali, f5data[readk][3]))
                    if not ref_align_list[rali]==base_map_info['refbase'][rali]:
                        print ("Error not equal2! %s %s %d %s" % (ref_align_list[rali], base_map_info['refbase'][rali], rali, f5data[readk][3]))

        # ## for check consistency
        if True:
            # for special alignment
            # CCG or CGG
            # C-G C-G
            # canonicalize gapped CpG alignments by swapping the gap so the
            # read G/C sits next to its partner base
            for ali in range(len(base_map_info)):
                if base_map_info['refbase'][ali]=='C' and base_map_info['readbase'][ali]=='C':
                    if ali+1<len(base_map_info) and base_map_info['readbase'][ali+1]=='-' and base_map_info['refbase'][ali+1]=='G':
                        addali = 2;
                        while ali + addali < len(base_map_info):
                            if base_map_info['readbase'][ali+addali]=='-' and base_map_info['refbase'][ali+addali]=='G': addali += 1;
                            else: break;
                        if ali + addali < len(base_map_info) and base_map_info['readbase'][ali+addali]=='G' and base_map_info['refbase'][ali+addali]=='G':
                            base_map_info['readbase'][ali+1], base_map_info['readbase'][ali+addali] = base_map_info['readbase'][ali+addali], base_map_info['readbase'][ali+1]
                if base_map_info['refbase'][ali]=='G' and base_map_info['readbase'][ali]=='G':
                    if ali-1>-1 and base_map_info['readbase'][ali-1]=='-' and base_map_info['refbase'][ali-1]=='C':
                        addali = 2;
                        while ali - addali >-1:
                            if base_map_info['readbase'][ali-addali]=='-' and base_map_info['refbase'][ali-addali]=='C': addali += 1;
                            else: break;
                        if ali - addali>-1 and base_map_info['readbase'][ali-addali]=='C' and base_map_info['refbase'][ali-addali]=='C':
                            base_map_info['readbase'][ali-1], base_map_info['readbase'][ali-addali] = base_map_info['readbase'][ali-addali], base_map_info['readbase'][ali-1]

        if len(m_event)<50:
            raiseError("Less Event", sp_param, "Less Event");
            sp_options["Error"]["Less Event"].append(f5data[readk][3])
            continue;

        # get feature
        mfeatures,isdif = get_Feature(moptions, sp_options, sp_param, f5align, f5data, readk, leftclip, rightclip, base_map_info, forward_reverse, rname, first_match_pos, numinsert, numdel)
        if isdif and moptions['outLevel']<=myCom.OUTPUT_WARNING:
            print("Dif is true")
            print([lastmatch, firstmatch, first_match_pos, last_match_pos, first_al_match, last_al_match, lasmtind, len(base_map_info), nummismatch, numinsert, numdel, len(base_map_info)-nummismatch-numinsert-numdel])
        if not sp_param['f5status']=="": continue

        # generate/save prediction information
        pred_mod_num = mPredict1(moptions, sp_options, sp_param, mfeatures, base_map_info, readk, leftclip, rightclip)
        predfile = (sp_options['ctfolder'] if sp_options['ctfolder'][-1] not in ['/', '\\'] else sp_options['ctfolder'][:-1])+'/rnn.pred.detail.fast5'+'.'+str(sp_options['batchid'])
        pred_f5_key = 'pred_'+str(readk_ind)
        sp_options['Mod'].append([rname, forward_reverse, f5align[readk][3]-1, pred_f5_key, f5data[readk][3][len(moptions['wrkBase'])+1:], predfile[len(moptions['outFolder']+moptions['FileID'])+1:]])
        # save the prediction
        with h5py.File(predfile, 'a') as save_data:
            if "pred" in save_data: base_group = save_data["pred"]
            else: base_group = save_data.create_group("pred")

            if pred_f5_key in base_group:
                del base_group[pred_f5_key]
                save_data.flush()
            pred_group = base_group.create_group(pred_f5_key)

            # save mapped chr, strand, positions
            pred_group.attrs['mapped_chr'] = rname
            pred_group.attrs['mapped_strand'] = forward_reverse
            pred_group.attrs['mapped_start'] = base_map_info['refbasei'][0] if forward_reverse=='+' else base_map_info['refbasei'][-1]
            pred_group.attrs['mapped_end'] = base_map_info['refbasei'][-1] if forward_reverse=='+' else base_map_info['refbasei'][0]

            if forward_reverse=='+':
                pred_group.attrs['clipped_bases_start'] = leftclip
                pred_group.attrs['clipped_bases_end'] = rightclip
            else:
                pred_group.attrs['clipped_bases_start'] = rightclip
                pred_group.attrs['clipped_bases_end'] = leftclip

            # save indel, matches/mismatches
            pred_group.attrs['num_insertions'] = numinsert
            pred_group.attrs['num_deletions'] = numdel
            pred_group.attrs['num_matches'] = len(base_map_info)-nummismatch-numinsert-numdel
            pred_group.attrs['num_mismatches'] = nummismatch

            # save the fast5 file and prediction
            pred_group.attrs['pred_mod_num'] = pred_mod_num
            pred_group.attrs['f5file'] = f5data[readk][3]
            pred_group.attrs['readk'] = readk
            # re-encode the bases as byte strings ('S1') for HDF5 storage
            base_map_info = np.array(base_map_info, dtype=[('refbase', 'S1'), ('readbase', 'S1'), ('refbasei', np.uint64), ('readbasei', np.uint64), ('mod_pred', int)])
            pred_group.create_dataset('predetail', data=base_map_info, compression="gzip")

            try:
                save_data.flush();
                save_data.close();
            except:
                sp_options["Error"]['Cannot save data'].append(f5data[readk][3])
                print ('Error!!! %s in %s' % ("Cannot save data", f5data[readk][3]))

    # save index information
    sp_options['Mod'] = sorted(sp_options['Mod'])
    # index file (one per chromosome)
    pred_ind_file = (sp_options['ctfolder'] if sp_options['ctfolder'][-1] not in ['/', '\\'] else sp_options['ctfolder'][:-1])+'/%s.' + pre_base_str + '.' + str(sp_options['batchid'])
    if len(sp_options['Mod'])>0:
        cur_chr = None; cur_writer = None;
        for mfi in sp_options['Mod']:
            if cur_chr==None or (not cur_chr == mfi[0]):
                if not cur_chr==None:
                    cur_writer.flush();
                    cur_writer.close()
                cur_chr = mfi[0]
                cur_writer = open((pred_ind_file % cur_chr), 'w')
            cur_m_f = []
            for mfidetail in mfi:
                cur_m_f.append(str(mfidetail))
            cur_m_f.append('\n')
            cur_writer.write(' '.join(cur_m_f))
        if not cur_writer==None:
            cur_writer.flush();
            cur_writer.close()

#
# make modificatoin prediction for a long read
#
def mPredict1(moptions, sp_options, sp_param, mfeatures, base_map_info, readk, start_clip, end_clip):
    #
    modevents = sp_param['f5data'][readk][1]
    # get features.
labels might be all zero + t0, ty, tx = np.split(mfeatures, [1,3], axis=1); + t0 = t0.astype(int) + m_data = []; m_y = []; + for ie in range(start_clip-100, len(modevents)-end_clip+100): + mind = ie - (start_clip-100) + if ie>=start_clip and ie<len(modevents)-end_clip: + m_y.append(ty[mind]) + # format to input with windoe size + m_data.append(tx[(mind-int(moptions['windowsize']/2)):(mind+int(moptions['windowsize']/2)+1)]) + + # for input + test_feature = np.reshape(m_data, (len(m_data), len(m_data[0]), len(m_data[0][0]))) + test_label = np.reshape(m_y, (len(m_y), len(m_y[0]))).astype(int) + + sp_options['rnn'][0].run(sp_options['rnn'][3]) + + # split into small group + if len(test_feature) > rnn_pred_batch_size*1.2: + x_sub_group = np.array_split(test_feature, int(len(test_feature)/rnn_pred_batch_size)) + y_sub_group = np.array_split(test_label, int(len(test_feature)/rnn_pred_batch_size)) + else: + x_sub_group = [test_feature]; y_sub_group = [test_label] + # make prediction on each small groups + for subi in range(len(x_sub_group)): + if subi==0: + mfpred_output = (sp_options['rnn'][0].run([sp_options['rnn'][4]], \ + feed_dict={sp_options['rnn'][1]:x_sub_group[subi], sp_options['rnn'][2]:y_sub_group[subi]}))[0]; + else: + mfpred_output = np.concatenate((mfpred_output, (sp_options['rnn'][0].run([sp_options['rnn'][4]], \ + feed_dict={sp_options['rnn'][1]:x_sub_group[subi], sp_options['rnn'][2]:y_sub_group[subi]}))[0]), axis=0); + + # associate the prediction with reference positions and read positions + modevents = sp_param['f5data'][readk][1] + aligni = 0; pred_mod_num = 0; + for ie in range(start_clip, len(modevents)-end_clip): + while base_map_info['readbase'][aligni]=='-': aligni += 1 + if not base_map_info['readbase'][aligni] == modevents['model_state'][ie][2]: + print ('Error Does not match', base_map_info['readbase'][aligni], modevents['model_state'][ie][2], aligni, ie) + if mfpred_output[ie-start_clip]==1: + base_map_info['mod_pred'][aligni] = 1; + 
pred_mod_num += 1; + + aligni += 1 + return pred_mod_num + +# +# get feature for a long read +# +def get_Feature(moptions, sp_options, sp_param, f5align, f5data, readk, start_clip, end_clip, base_map_info, forward_reverse, rname, mapped_start_pos, num_insertions, num_deletions): + modevents = sp_param['f5data'][readk][1] + # class num, bin num, and bin size + clnum = 2; binnum = 50; binlen = 0.2; + if forward_reverse=='+': + align_ref_pos = mapped_start_pos + else: + align_ref_pos = mapped_start_pos + len(base_map_info) - num_insertions - 1 + + # initialize feature matrix + if moptions['fnum']==57: + mfeatures = np.zeros((len(modevents)-end_clip+100-(start_clip-100), (binnum+3+3+4))); + else: mfeatures = np.zeros((len(modevents)-end_clip+100-(start_clip-100), (3+3+4))); + + aligni = 0; isdif = False; + # get feature for each event; each event is a row + for ie in range(start_clip-100, len(modevents)-end_clip+100): + cur_row_num = ie - (start_clip-100); cur_base = '' + if ie>=start_clip and ie<len(modevents)-end_clip: + if align_ref_pos<mapped_start_pos: + print ('ERRRR align_ref_pos(%d)<mapped_start_pos(%d)' % (align_ref_pos, mapped_start_pos)) + # get non-indel events + while base_map_info['readbase'][aligni]=='-': + if not align_ref_pos==base_map_info['refbasei'][aligni]: + print ('ERRRR align_ref_pos(%d) not equal to %d' % (align_ref_pos, base_map_info['refbasei'][aligni] )) + if not base_map_info['refbase'][aligni]=='-': + if forward_reverse=='+': align_ref_pos += 1 + else: align_ref_pos -= 1 + aligni += 1 + if not base_map_info['readbase'][aligni] == modevents['model_state'][ie][2]: + print ('Error Does not match', base_map_info['readbase'][aligni], modevents['model_state'][ie][2], aligni, ie) + sp_param['f5status']= "Error Does not match" + if f5data[readk][3] not in sp_options["Error"]['Error Does not match']: + sp_options["Error"]['Error Does not match'].append(f5data[readk][3]) + if aligni>50: break; + isdif = True; + mfeatures[cur_row_num][0] = 
align_ref_pos + cur_base = base_map_info['refbase'][aligni] + # both positive/negative labels is zero + mfeatures[cur_row_num][1] = 0; mfeatures[cur_row_num][2] = 0 + if not base_map_info['refbase'][aligni]=='-': + if forward_reverse=='+': align_ref_pos += 1 + else: align_ref_pos -= 1 + aligni += 1 + + # for bin features + if ie>=0 and ie<len(modevents) and moptions['fnum']==57: + for currs in sp_param['f5data'][readk][2][modevents['start'][ie]:int(modevents['start'][ie]+int(modevents['length'][ie]+0.5))]: + if currs>10 or currs<-10: print ('Error raw signal', currs, ie, modevents['start'][ie], modevents['length'][ie]) + curbin = int((currs+5)/binlen) + if curbin<0: curbin = 0 + elif not curbin<binnum: curbin = binnum-1 + mfeatures[cur_row_num][curbin+3] += 1 + if ie>=0 and ie<len(modevents): + # for reference base type + if cur_base in myCom.g_ACGT: + mfeatures[cur_row_num][moptions['fnum']-3+3-4+myCom.g_ACGT.index(cur_base)] = 1 + cur_index_add = moptions['fnum'] - 3 + 3 + # for mean, std and length of signals + mfeatures[cur_row_num][cur_index_add + 0] = modevents["mean"][ie] + mfeatures[cur_row_num][cur_index_add + 1] = modevents["stdv"][ie] + mfeatures[cur_row_num][cur_index_add + 2] = modevents["length"][ie] + + + return (mfeatures, isdif) + + +# +# get complementary base of a given base +# +def get_complement(na): + if na in myCom.acgt: return myCom.na_bp[na] + else: return na; + +# +# get mean/std of signals +# +def calculate_mean_std(m_event, event_ind, forward_reverse, raw_pv, moptions, sp_param): + if forward_reverse=='-': + pvsignal = raw_pv[m_event[-event_ind-1][2]:(m_event[-event_ind-1][2]+m_event[-event_ind-1][3])] + else: + pvsignal = raw_pv[m_event[event_ind][2]:(m_event[event_ind][2]+m_event[event_ind][3])] + # get mean/std for an event + c_mean = round(np.mean(pvsignal), 3) + c_std = round(np.std(pvsignal), 3) + return (c_mean, c_std) + +# +# get required information from a mapped record +# +def handle_line(moptions, sp_param, f5align): + lsp = 
sp_param['line'].split('\t') + qname, flag, rname, pos, mapq, cigar, _, _, _, seq, _ = lsp[:11] + # check query name, map quality, reference position, cigar and reference name + if qname=='*': sp_param['f5status'] = "qname is *" + elif int(mapq)==255: sp_param['f5status'] = "mapq is 255" + elif int(pos)==0: sp_param['f5status'] = "pos is 0" + elif cigar=='*': sp_param['f5status'] = "cigar is *" + elif rname=='*': sp_param['f5status'] = "rname is *" + if not sp_param['f5status']=="": return qname + + if (qname not in f5align) or f5align[qname][0]<int(mapq): + f5align[qname] = (int(mapq), int(flag), rname, int(pos), cigar, seq) + + return qname + +# +# the worker of the detection in a multiprocessing way +# +def detect_handler(moptions, h5files_Q, failed_Q, file_map_info_q): + + _, init_l, _, _, _, X, Y, _, _, _, _, mfpred = myMultiBiRNN.mCreateSession(moptions['fnum'], moptions['hidden'], moptions['windowsize'], moptions) + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + sess = tf.Session(config=config) + # load module + new_saver = tf.train.import_meta_graph(moptions['modfile'][0]+'.meta') + new_saver.restore(sess,tf.train.latest_checkpoint(moptions['modfile'][1])) + + while not h5files_Q.empty(): + cur_start_time = time.time() + try: + # get fast5 file + f5files, ctfolderid, batchid = h5files_Q.get(block=False) + except: + break; + + sp_options = defaultdict(); + sp_options['ctfolderid'] = ctfolderid + sp_options['ctfolder'] = moptions['outFolder']+moptions['FileID']+'/'+str(ctfolderid) + if not os.path.isdir(sp_options['ctfolder']): + os.system('mkdir '+sp_options['ctfolder']) + #if moptions['testrnn']: + sp_options['rnn'] = (sess, X, Y, init_l, mfpred) + sp_options['batchid'] = batchid + + sp_options['Mod'] = []; + # make modification prediction for each fast5 + mDetect1(moptions, sp_options, f5files) + # outputing errors + for errtype, errfiles in sp_options["Error"].items(): + failed_Q.put((errtype, errfiles)); + + print ("Cur Prediction 
consuming time %d for %d %d" % (time.time() - cur_start_time, ctfolderid, batchid)) + + sess.close() + +# +# read index files for each output prediction file +# +def read_file_list(cur_cif, cur_chr, cur_strand, sp_options): + cur_list = [] + with open(cur_cif, 'r') as mr: + line = mr.readline(); + while line: # get where is prediction file from each line + line = line.strip(); + if len(line)>0: + lsp = line.split(); + if line[0]=='#': + if lsp[1][0] not in ['/', '\\']: + lsp[1] = lsp[1] + '/' + if lsp[0]=='#base_folder_fast5': sp_options['base_folder_fast5'] = lsp[1]; + elif lsp[0]=='#base_folder_output': sp_options['base_folder_output'] = lsp[1]; + else: + if lsp[1] == cur_strand: + cur_list.append(lsp) + if not lsp[0]==cur_chr: + print ('Warning!!! The chr should be %s but % is found.' % (cur_chr, lsp[0])) + line = mr.readline(); + sp_options['handlingList'] = cur_list + +# +# get prediction detail +# +def read_pred_detail(moptions, sp_options, f5info): + f5pred_file = sp_options['base_folder_output'] + '/' + f5info[5] + f5_pred_key = ('/pred/%s/predetail' % f5info[3]) + # get prediction detail from saved prediction file + # each file contains predictions for multiple fast5 + with h5py.File(f5pred_file, 'r') as mr: + m_pred = mr[f5_pred_key].value; + mapped_chrom = mr['/pred/%s' % f5info[3]].attrs['mapped_chr'] #.decode(encoding="utf-8") + mapped_strand= mr['/pred/%s' % f5info[3]].attrs['mapped_strand'] #.decode(encoding="utf-8") + m_pred = np.array(m_pred, dtype=[('refbase', 'U1'), ('readbase', 'U1'), ('refbasei', np.uint64), ('readbasei', np.uint64), ('mod_pred', np.int)]) + return (m_pred, mapped_chrom, mapped_strand) + +# +# summarize modification for each genome position of interest +# +def sum_handler(moptions, chr_strand_Q): + while not chr_strand_Q.empty(): + try: + # get setting for summarization of predicted modifications + cur_cif, cur_chr, cur_strand = chr_strand_Q.get(block=False) + except: + break; + + sp_options = {} + # get prediction files + 
        read_file_list(cur_cif, cur_chr, cur_strand, sp_options)
        # per-base-type accumulator: (chr, strand, pos) -> [coverage, mod_count, refbase]
        sp_options['4NA'] = {moptions['Base']:defaultdict()}
        sp_options['4NAfile'] = {}
        for nak in sp_options['4NA']:
            if not moptions['mod_cluster']:
                sp_options['4NAfile'][nak] = ('%s/mod_pos.%s%s.%s.bed' % (moptions['outFolder'], cur_chr, cur_strand, nak))
            else:
                sp_options['4NAfile'][nak] = ('%s/cluster_mod_pos.%s%s.%s.bed' % (moptions['outFolder'], cur_chr, cur_strand, nak))

        cur_start_time = time.time(); hlnum = 0;
        for hl in sp_options['handlingList']:
            # read prediction detail for each fast5
            m_pred, mapped_chrom, mapped_strand = read_pred_detail(moptions, sp_options, hl)
            if not (mapped_chrom==cur_chr and mapped_strand==cur_strand):
                print("ERRoR not the same chr (real=%s vs expect=%s) and strand (real=%s VS expect=%s)" % (mapped_chrom, cur_chr, mapped_strand, cur_strand))
            #####################################################
            if moptions['mod_cluster']: # revised; should not used now
                # cluster mode: additionally call a 'C' modified when more than
                # half of the CpGs within 12 bases on either side are modified
                from numpy.lib.recfunctions import append_fields
                m_pred = append_fields(m_pred, 'mod_pred2', m_pred['mod_pred']+0, usemask=False)
                for mi in range(len(m_pred)):
                    if not m_pred['mod_pred2'][mi] == m_pred['mod_pred'][mi]:
                        print('mod_pred_Error %d %d %d' % (mi, m_pred['mod_pred2'][mi], m_pred['mod_pred'][mi]))
                    if m_pred['mod_pred2'][mi]==1 or m_pred['refbase'][mi] not in ['C']: continue;

                    # m_3/m_5: up to 12 non-gap neighbors on the 3'/5' side,
                    # stopping at any 'N'
                    m_3 = []
                    m_5 = []
                    mj = mi-1
                    while mj>-1 and len(m_3)<12:
                        if m_pred['refbase'][mj] in ['N', 'n']: break;
                        if m_pred['refbase'][mj] not in ['-']:
                            m_3.append((m_pred['refbase'][mj], m_pred['mod_pred2'][mj]))
                        mj -= 1;
                    if len(m_3)>0: m_3 = m_3[::-1]
                    mj = mi + 1;
                    while mj < len(m_pred) and len(m_5)<12:
                        if m_pred['refbase'][mj] in ['N', 'n']: break;
                        if m_pred['refbase'][mj] not in ['-']:
                            m_5.append((m_pred['refbase'][mj], m_pred['mod_pred2'][mj]))
                        mj += 1
                    cpgnum = 0; meth_cpgnum = 0;
                    m_3and5 = [m_3, m_5]
                    for m_53 in m_3and5:
                        for mj in range(len(m_53)-1):
                            if m_53[mj][0]=='C' and m_53[mj+1][0]=='G':
                                cpgnum +=1
                                if -0.1 < m_53[mj][1]-1 < 0.1:
                                    meth_cpgnum += 1

                    if cpgnum>0 and (meth_cpgnum>0 and meth_cpgnum/float(cpgnum)>0.5):
                        m_pred['mod_pred'][mi] = 1
            #####################################################################################
            for mi in range(len(m_pred)):
                # get prediction for each base type
                if m_pred['refbase'][mi] not in sp_options['4NA']: continue;
                if m_pred['refbase'][mi] in ['-', 'N', 'n']: continue;
                if (cur_chr, cur_strand, m_pred['refbasei'][mi]) not in sp_options['4NA'][m_pred['refbase'][mi]]:
                    sp_options['4NA'][m_pred['refbase'][mi]][(cur_chr, cur_strand, int(m_pred['refbasei'][mi]) )] = [0, 0, m_pred['refbase'][mi]]
                if not (m_pred['refbase'][mi] == sp_options['4NA'][m_pred['refbase'][mi]][(cur_chr, cur_strand, int(m_pred['refbasei'][mi]) )][2]):
                    print ('Error !!!! NA not equal %s == %s' % (m_pred['refbase'][mi], sp_options['4NA'][m_pred['refbase'][mi]][(cur_chr, cur_strand, int(m_pred['refbasei'][mi]) )][2]))
                # only reads that actually have a base here count as coverage
                if not m_pred['readbase'][mi]=='-':
                    sp_options['4NA'][m_pred['refbase'][mi]][(cur_chr, cur_strand, int(m_pred['refbasei'][mi]) )][0] += 1
                    if -0.1 < m_pred['mod_pred'][mi]-1 < 0.1:
                        sp_options['4NA'][m_pred['refbase'][mi]][(cur_chr, cur_strand, int(m_pred['refbasei'][mi]) )][1] += 1
            hlnum += 1
            if hlnum % 1000==0:
                print ("\tCurrent time consuming %d for %d" % (time.time() - cur_start_time, hlnum))
                cur_start_time = time.time()

        print ('====sum done! To save')
        for nak in sp_options['4NA']:
            print ('\tSave %s' % sp_options['4NAfile'][nak])
            if len(sp_options['4NA'][nak])>0:
                with open(sp_options['4NAfile'][nak], 'w') as mw:
                    # save the summary for each base type
                    # bed columns: chr start end base score(capped coverage)
                    # strand start end rgb coverage mod_percent mod_count
                    pos_keys = sp_options['4NA'][nak].keys();
                    pos_keys = sorted(pos_keys);
                    for pk in pos_keys:
                        neighna = [sp_options['4NA'][nak][pk][2]]
                        mw.write(' '.join([ pk[0], str(pk[2]), str(pk[2]+1), ''.join(neighna), \
                            str(1000 if sp_options['4NA'][nak][pk][0]>1000 else sp_options['4NA'][nak][pk][0]), \
                            pk[1], str(pk[2]), str(pk[2]+1), '0,0,0', str(sp_options['4NA'][nak][pk][0]), \
                            ('%d' % (100*sp_options['4NA'][nak][pk][1]/(sp_options['4NA'][nak][pk][0] if sp_options['4NA'][nak][pk][0]>0 else 1))), \
                            str(sp_options['4NA'][nak][pk][1]), '\n' ]))
#
# prediction manager of a multiprocess process
#
def mDetect_manager(moptions):
    # Top-level driver: optionally run per-read modification prediction over
    # all fast5 files (moptions['predDet']==1), then summarize the per-read
    # predictions into per-genomic-position bed files.
    pmanager = multiprocessing.Manager();
    # get input folder
    while (not moptions['wrkBase']==None) and len(moptions['wrkBase'])>0 and moptions['wrkBase'][-1] in ['/', '\\']:
        moptions['wrkBase'] = moptions['wrkBase'][:-1]

    # need to make prediction of modification
    if moptions['predDet']==1:

        # get well-trained model
        # modfile becomes [checkpoint prefix, checkpoint folder]
        if moptions['modfile'].rfind('/')==-1:
            moptions['modfile'] = [moptions['modfile'], './']
        else:
            moptions['modfile'] = [moptions['modfile'], moptions['modfile'][:moptions['modfile'].rfind('/')+1]]

        start_time = time.time();

        # get fast5 files in a recurisive way (up to 3 levels deep)
        f5files = glob.glob(os.path.join(moptions['wrkBase'],"*.fast5" ))
        if moptions['recursive']==1:
            f5files.extend(glob.glob(os.path.join(moptions['wrkBase'],"*/*.fast5" )))
            f5files.extend(glob.glob(os.path.join(moptions['wrkBase'],"*/*/*.fast5" )))
            f5files.extend(glob.glob(os.path.join(moptions['wrkBase'],"*/*/*/*.fast5" )))

        print('Total files=%d' % len(f5files))

        # output folder
        if not os.path.isdir(moptions['outFolder']+moptions['FileID']):
            os.system('mkdir -p '+moptions['outFolder']+moptions['FileID'])

        # prepare multiprocessing
        h5files_Q = pmanager.Queue();
        file_map_info_q = pmanager.Queue();
        failed_Q = pmanager.Queue()

        # spliting fast5 files into different lists
        # every 'files_per_thread' files form a batch; every 100 batches go
        # into a new numbered sub-folder
        h5_batch = []; h5batchind = 0;
        sub_folder_size = 100; sub_folder_id = 0;
        for f5f in f5files:
            h5_batch.append(f5f);
            if len(h5_batch)==moptions['files_per_thread']:
                # each batch
                h5files_Q.put((h5_batch, sub_folder_id, h5batchind))
                h5_batch = []; h5batchind += 1
                if h5batchind % sub_folder_size ==0:
                    sub_folder_id += 1
        if len(h5_batch)>0:
            # final partial batch
            h5files_Q.put((h5_batch, sub_folder_id, h5batchind))
            h5_batch = []; h5batchind += 1

        # start multiprocessing
        share_var = (moptions, h5files_Q, failed_Q, file_map_info_q)
        handlers = []
        for hid in range(moptions['threads']):
            p = multiprocessing.Process(target=detect_handler, args=share_var);
            p.start();
            handlers.append(p);

        # check potential errors: drain the failure queue while workers run
        failed_files = defaultdict(list);
        while any(p.is_alive() for p in handlers):
            try:
                errk, fns = failed_Q.get(block=False);
                failed_files[errk].extend(fns)

            except:
                time.sleep(1);
                continue;

        # prepare modificatoin summary for reference positions of interest
        # merge the per-batch index files into one index file per chromosome
        moptions['predpath'] = moptions['outFolder'] + '/'+moptions['FileID']
        pred_ind_pref = moptions['outFolder'] + '/'+moptions['FileID']+'/'+pre_base_str
        pred_chr_files = glob.glob(os.path.join(moptions['outFolder']+moptions['FileID'], '*/*.'+pre_base_str+'.*'))
        chr_dict = defaultdict(list);
        for pcf in pred_chr_files:
            chr_dict[ pcf.split('/')[-1].split('.'+pre_base_str)[0] ].append(pcf)
        chr_keys = chr_dict.keys();
        for ck in chr_keys:
            cur_ind_pred_f = pred_ind_pref + '.' + ck;
            cur_list = [ ['#base_folder_fast5', moptions['wrkBase']], ['#base_folder_output', os.path.abspath(moptions['outFolder']+moptions['FileID'])] ]
            for sub_c_f in chr_dict[ck]:
                with open(sub_c_f, 'r') as mr:
                    line = mr.readline()
                    while line:
                        line = line.strip()
                        if len(line)>0:
                            lsp = line.split();
                            # field 2 is the mapped position; make it numeric
                            # so sorting below is positional, not lexicographic
                            lsp[2] = int(lsp[2])
                            cur_list.append(lsp)
                        line = mr.readline()
            cur_list = sorted(cur_list)
            with open(cur_ind_pred_f, 'w') as indf_writer:
                for mfi in cur_list:
                    cur_m_f = []
                    for mfidetail in mfi:
                        cur_m_f.append(str(mfidetail))
                    cur_m_f.append('\n')
                    indf_writer.write(' '.join(cur_m_f))
        # error info
        if len(failed_files)>0:
            print ('Error information for different fast5 files:')
            for errtype, errfiles in failed_files.items():
                print ('\t'+errtype, len(errfiles))

        moptions['outFolder'] = moptions['outFolder']+moptions['FileID']
        end_time = time.time();
        print ("Per-read Prediction consuming time %d" % (end_time-start_time))

    ### for summarizing modificatoin prediction
    start_time = time.time();
    # get all index files of prediction
    all_chr_ind_files = glob.glob(os.path.join(moptions['predpath'], pre_base_str+'.*'))
    print('Find: %s %d %s' % (moptions['predpath'], len(all_chr_ind_files), pre_base_str))
    print (all_chr_ind_files)

    # for each chromosome, a thread will be initialized for multiprocessing summarization of modifications
    # (two jobs per chromosome: one per strand)
    chr_strand_Q = pmanager.Queue(); jobnum = 0;
    for cur_cif in all_chr_ind_files:
        chr_strand_Q.put((cur_cif, cur_cif.split(pre_base_str)[-1][1:], '+'))
        chr_strand_Q.put((cur_cif, cur_cif.split(pre_base_str)[-1][1:], '-'))
        jobnum +=2

    # star to summarize modificaiton prediction of reference genomes of interest
    share_var = (moptions, chr_strand_Q)
    handlers = []
    for hid in range(moptions['threads'] if moptions['threads']<jobnum else jobnum):
        p = multiprocessing.Process(target=sum_handler, args=share_var);
        p.start();
        handlers.append(p);
    while any(p.is_alive() for p in handlers):
        try:
            time.sleep(1);
        except:
            time.sleep(1);
            continue;

    end_time = time.time();
    print ("Genomic-position Detection consuming time %d" % (end_time-start_time))

    # marker file signalling that the whole run finished
    os.system('touch '+moptions['outFolder']+'.done')

# for independent testing of code
if __name__=='__main__':
# if len(sys.argv)>4:
    # minimal hand-edited configuration for a standalone run; repeated
    # assignments below are quick manual overrides of the preceding value
    moptions = {}
    moptions['basecall_1d'] = 'Basecall_1D_000'
    moptions['basecall_1d'] = ['Basecall_1D_000']
    moptions['basecall_2strand'] = 'BaseCalled_template'

    moptions['outLevel'] = myCom.OUTPUT_WARNING
    moptions['outLevel'] = myCom.OUTPUT_INFO

    moptions['modfile'] = '../../mod_output/train1/2/mod_train'

    moptions['fnum'] = 3;
    moptions['hidden'] = 100;
    moptions['windowsize'] = 21;

    moptions['threads'] = 8
    moptions['threads'] = 1
    moptions['files_per_thread'] = 500

    mDetect_manager(moptions)