Switch to side-by-side view

--- a
+++ b/DeepMod_tools/cal_EcoliDetPerf.py
@@ -0,0 +1,283 @@
+#!/usr/bin/env python
+
+import os, sys, time
+from collections import defaultdict
+import glob
+import copy
+import numpy as np
+
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+from itertools import cycle
+
+from sklearn.metrics import roc_curve, auc
+from sklearn.metrics import precision_recall_curve, average_precision_score
+from sklearn.metrics import matthews_corrcoef
+
+import rpy2.robjects as robjects
+from rpy2.robjects.packages import importr
+from pkg_resources import resource_string
+
+from scipy.stats import binom
+import copy
+
+ggplot = importr('ggplot2')
+importr('gridExtra')
+importr('plyr')
+
+na4com = {'A':'T', 'C':'G', 'T':'A', 'G':'C'}
+
+def readFA(mfa, mpat='Cg', mposinpat=0, t_chr=None, t_start=None, t_end=None):
+   pos_dict = defaultdict(int)
+  
+   pat3 = copy.deepcopy(mpat.upper())
+   comp_pat3 = ''.join([na4com[curna] for curna in pat3][::-1])
+   comp_mposinpat = len(comp_pat3)-1-mposinpat
+ 
+   fadict = defaultdict();
+   with open(mfa, 'r') as mr:
+      cur_chr = None;
+
+      line = mr.readline();
+      while line:
+         line = line.strip();
+         if len(line)>0:
+            if line[0]=='>': 
+               if not cur_chr==None:
+                  fadict[cur_chr] = ''.join(fadict[cur_chr])
+               cur_chr = line[1:].split()[0]
+               if t_chr in [None, cur_chr]:
+                  fadict[cur_chr] = []
+            else:
+               if t_chr in [None, cur_chr]: 
+                  fadict[cur_chr].append(line)
+         line = mr.readline();
+      if not cur_chr==None:
+         fadict[cur_chr] = ''.join(fadict[cur_chr]) 
+   fakeys = fadict.keys();
+   cpgdict = defaultdict(int); cpgnum = [0, 0]
+   for fak in fakeys:
+       cpgdict[fak] = defaultdict()
+       for i in range(len(fadict[fak])):
+          if (t_start==None or i>=t_start) and (t_end==None or i<=t_end):
+             if i-mposinpat>=0 and i+len(comp_pat3)-1-mposinpat<len(fadict[fak]) and ''.join(fadict[fak][i-mposinpat:(i+len(comp_pat3)-1-mposinpat+1)])==pat3:
+                cpgdict[fak][('+', i)] = [1, fadict[fak][i]]; cpgnum[0] += 1
+                cpgdict[fak][('-', i)] = [0, fadict[fak][i]]
+             elif i-comp_mposinpat>=0 and i+len(comp_pat3)-1-comp_mposinpat<len(fadict[fak]) and ''.join(fadict[fak][i-comp_mposinpat:(i+len(comp_pat3)-1-comp_mposinpat+1)])==comp_pat3:
+                cpgdict[fak][('+', i)] = [0, fadict[fak][i]]
+                cpgdict[fak][('-', i)] = [1, fadict[fak][i]]; cpgnum[1] += 1
+             else:
+                cpgdict[fak][('+', i)] = [0, fadict[fak][i]]
+                cpgdict[fak][('-', i)] = [0, fadict[fak][i]]
+   print('%s%d site: %d(+) %d(-)' % (pat3, mposinpat, cpgnum[0], cpgnum[1]))
+   return cpgdict
+
+def readmodf_dict(cpgdict, modf, pred_dict, mna, t_start=None, t_end=None):
+   with open(modf, 'r') as mr:
+      while True:
+          line = mr.readline();
+          if not line: break;
+          line = line.strip();
+          if len(line)>0:
+             lsp = line.split();
+             cur_chr = lsp[0];
+             cur_pos = int(lsp[1]);
+             cur_strand = lsp[5];
+
+             cur_cov = int(lsp[9]);
+             cur_m_p = int(lsp[10]);
+             cur_m_c = int(lsp[11]);
+
+             if not ((t_start==None or cur_pos>=t_start) and (t_end==None or cur_pos<=t_end)):
+                line = mr.readline();
+                continue;
+
+             if not (mna==lsp[3] and lsp[3]==(cpgdict[cur_chr][(cur_strand, cur_pos)][1] if cur_strand=='+' else na4com[cpgdict[cur_chr][(cur_strand, cur_pos)][1]])):
+                print ('Error !! NA not equal %s == %s == %s %s' % (mna, lsp[3], cpgdict[cur_chr][(cur_strand, cur_pos)][1], modf))
+
+             if (cur_chr, cur_pos, cur_strand) not in pred_dict: 
+                pred_dict[(cur_chr, cur_pos, cur_strand)] = [cur_cov, cur_m_p, cur_m_c, lsp[3]]
+             else:
+                pred_dict[(cur_chr, cur_pos, cur_strand)][0] += cur_cov
+                pred_dict[(cur_chr, cur_pos, cur_strand)][2] += cur_m_c
+                pred_dict[(cur_chr, cur_pos, cur_strand)][1] = int(pred_dict[(cur_chr, cur_pos, cur_strand)][2]*100/pred_dict[(cur_chr, cur_pos, cur_strand)][0]) if pred_dict[(cur_chr, cur_pos, cur_strand)][0]>0 else 0
+
+def add_from_dict(cpgdict, pred_dict, label, pred_list, mna, tp_fp_tn_fn, mpat='Cg', mposinpat=0):
+   for posk in pred_dict:
+             cur_chr, cur_pos, cur_strand = posk
+             cur_cov, cur_m_p, cur_m_c, lsp3 = pred_dict[posk]
+
+             iscpg = False;
+             if cpgdict[cur_chr][(cur_strand, cur_pos)][0]==1:
+                 iscpg = True;
+                 pred_list.append((label, cur_cov, cur_m_p, cur_m_c, mpat, np.log(binom.pmf(cur_m_c, cur_cov, 0.05)) ))
+                 if (lsp3==mpat[mposinpat]): pass
+                 else: print ('Error not methylated pos %s %s %s' % (mna, cur_strand))
+             if not iscpg:
+                isclosec = False;
+                for i in range(-3, 4):
+                   if (cur_strand, cur_pos+i) in cpgdict[cur_chr] and cpgdict[cur_chr][(cur_strand, cur_pos+i)][0]==1:
+                      isclosec = True; break;
+                if lsp3==mpat[mposinpat]:
+                   pred_list.append((0, cur_cov, cur_m_p, cur_m_c, mpat+'_n'+str(abs(i))+mpat[mposinpat] if isclosec else 'Other'+mpat[mposinpat], np.log(binom.pmf(cur_m_c, cur_cov, 0.05)) ))
+                else:
+                   pred_list.append((0, cur_cov, cur_m_p, cur_m_c, mpat+'_nb' if isclosec else 'Other', np.log(binom.pmf(cur_m_c, cur_cov, 0.05)) ))
+             if pred_list[-1][0]==0:
+                tp_fp_tn_fn[2] += cur_cov - cur_m_c
+                tp_fp_tn_fn[1] += cur_m_c
+             else:
+                tp_fp_tn_fn[0] += cur_m_c
+                tp_fp_tn_fn[3] += cur_cov - cur_m_c
+
+
+def readmodf(cpgdict, modf, label, pred_list, mna, tp_fp_tn_fn, mpat='Cg', mposinpat=0, t_start=None, t_end=None):
+   with open(modf, 'r') as mr:
+      line = mr.readline();
+      while line:
+          line = line.strip();
+          if len(line)>0:
+             lsp = line.split();
+             cur_chr = lsp[0];
+             cur_pos = int(lsp[1]);
+             cur_strand = lsp[5];
+             
+             cur_cov = int(lsp[9]);
+             cur_m_p = int(lsp[10]);
+             cur_m_c = int(lsp[11]); 
+
+             if not ((t_start==None or cur_pos>=t_start) and (t_end==None or cur_pos<=t_end)): 
+                line = mr.readline();
+                continue;
+
+             if not (mna==lsp[3] and lsp[3]==(cpgdict[cur_chr][(cur_strand, cur_pos)][1] if cur_strand=='+' else na4com[cpgdict[cur_chr][(cur_strand, cur_pos)][1]])):
+                print ('Error !! NA not equal %s == %s == %s %s' % (mna, lsp[3], cpgdict[cur_chr][(cur_strand, cur_pos)][1], modf))
+             iscpg = False;
+             if cpgdict[cur_chr][(cur_strand, cur_pos)][0]==1:
+                 iscpg = True;
+                 pred_list.append((label, cur_cov, cur_m_p, cur_m_c, mpat, np.log(binom.pmf(cur_m_c, cur_cov, 0.05)) ))
+                 if (lsp[3]==mpat[mposinpat]): pass
+                 else: print ('Error not methylated pos %s %s %s' % (mna, cur_strand, modf))
+             if not iscpg:
+                isclosec = False;
+                for i in range(-3, 4):
+                   if (cur_strand, cur_pos+i) in cpgdict[cur_chr] and cpgdict[cur_chr][(cur_strand, cur_pos+i)][0]==1:
+                      isclosec = True; break;
+                if lsp[3]==mpat[mposinpat]:
+                   pred_list.append((0, cur_cov, cur_m_p, cur_m_c, mpat+'_n'+str(abs(i))+mpat[mposinpat] if isclosec else 'Other'+mpat[mposinpat], np.log(binom.pmf(cur_m_c, cur_cov, 0.05)) ))
+                else:
+                   pred_list.append((0, cur_cov, cur_m_p, cur_m_c, mpat+'_nb' if isclosec else 'Other', np.log(binom.pmf(cur_m_c, cur_cov, 0.05)) ))
+             if pred_list[-1][0]==0:
+                tp_fp_tn_fn[2] += cur_cov - cur_m_c
+                tp_fp_tn_fn[1] += cur_m_c
+             else:
+                tp_fp_tn_fn[0] += cur_m_c
+                tp_fp_tn_fn[3] += cur_cov - cur_m_c 
+          line = mr.readline();   
+
+
+sssfolder = sys.argv[1];   # 
+mreffile = sys.argv[2];    # 
+mpat=sys.argv[3];          # Cg
+mposinpat=int(sys.argv[4]);# 0 
+
+chrofinterest = sys.argv[5];
+if chrofinterest=='': chrofinterest = None;
+stposofinterest = int(sys.argv[6]);
+if stposofinterest<0: stposofinterest = None;
+edposofinterest = int(sys.argv[7]);
+if edposofinterest<0: edposofinterest = None;
+
+basefig = sys.argv[8]
+hastwoclass = 1;
+
+sssfiles = {mpat[mposinpat]:glob.glob(os.path.join(sssfolder, 'mod_pos.*.'+mpat[mposinpat]+'.bed'))}
+sssfiles[mpat[mposinpat]].extend(glob.glob(os.path.join(sssfolder, '*/mod_pos.*.'+mpat[mposinpat]+'.bed')))
+sssfiles[mpat[mposinpat]].extend(glob.glob(os.path.join(sssfolder, '*/*/mod_pos.*.'+mpat[mposinpat]+'.bed')))
+print(str(len(sssfiles[mpat[mposinpat]])) + " " + str(sssfolder))
+
+## for negative;
+umrfiles = []
+for cur_umr_f in sys.argv[9].split(','):
+   if not os.path.isdir(cur_umr_f):
+       print("No prediction folder {}".format(cur_umr_f))
+       sys.exit(1);
+   umrfiles.extend(glob.glob(os.path.join(cur_umr_f, '*/*/mod_pos.*.'+mpat[mposinpat]+'.bed')))
+   umrfiles.extend(glob.glob(os.path.join(cur_umr_f, '*/mod_pos.*.'+mpat[mposinpat]+'.bed')))
+   umrfiles.extend(glob.glob(os.path.join(cur_umr_f, 'mod_pos.*.'+mpat[mposinpat]+'.bed')))
+print(str(len(umrfiles)) + "  " + str(sys.argv[9].split(',')))
+sys.stdout.flush()
+
+for sa in sssfiles:
+   print (sa)
+   for nf in sssfiles[sa]:
+      print ('\t'+nf)
+
+cpgdict = readFA(mreffile, mpat, mposinpat, chrofinterest, stposofinterest, edposofinterest)
+
+pred_dict = defaultdict();
+for modf in umrfiles:
+   readmodf_dict(cpgdict, modf, pred_dict, mpat[mposinpat], stposofinterest, edposofinterest)
+
+baseinfo = [mpat, mpat+'_n1'+mpat[mposinpat], mpat+'_n2'+mpat[mposinpat], mpat+'_n3'+mpat[mposinpat], 'Other'+mpat[mposinpat], mpat+'_nb', 'Other']
+
+classify_m = ['Methylation_Percentage']
+classify_types = [baseinfo, [mpat]]
+filename = [['all_mp','motif_mp'] ]
+cov_thr = [1, 5]
+mlinestyle = {1:'bo-', 3:'gx--', 5:'r*-.', 7:'cs-', 10:'md--', 15:'k+-.'}
+
+pred_list = []; tp_fp_tn_fn = [0, 0, 0, 0]
+
+add_from_dict(cpgdict, pred_dict, 0, pred_list, mpat[mposinpat], tp_fp_tn_fn, mpat, mposinpat)
+
+if True:
+   for na4 in sssfiles:
+      for cur_f in sssfiles[na4]:
+         print('%s %s' % (na4, cur_f)); sys.stdout.flush();
+         readmodf(cpgdict, cur_f, hastwoclass, pred_list, na4, tp_fp_tn_fn, mpat, mposinpat, stposofinterest, edposofinterest);
+   pred_list = np.array(pred_list, dtype=[('Methylation', np.uint), ('Coverage', np.uint64), ('Methylation_Percentage', np.uint64), ('Methylation_Coverage', np.uint64), ('BaseInfo', 'U20'), ('logp', np.float64)])
+   
+   if hastwoclass==1:
+      cov_plot_thr = [1, 5]
+      for ct_ind in range(len(classify_types)):
+         ct = classify_types[ct_ind]
+         cur_ct_data = pred_list[np.isin(pred_list['BaseInfo'], ct)]
+         for cm_ind in range(len(classify_m)):
+             print('basetype={} classify_measure={}'.format(ct, classify_m[cm_ind]))
+             cm = classify_m[cm_ind]
+           
+             # 1 for roc, 2: pr;
+             roc_or_pr = 2; roc_or_pr=0
+             for roc_or_pr in range(1,3):
+              if roc_or_pr>0: 
+                mfig= plt.figure()
+              if roc_or_pr==2:
+                cur_fn = basefig+'/ap_plot_met_pr_'+filename[cm_ind][ct_ind]+'.png'
+                xylab = ['Recall', 'Precision'];  leg_mpos = "lower left"
+                for covt in cov_plot_thr:
+                   precision, recall, thresholds = precision_recall_curve(cur_ct_data['Methylation'][cur_ct_data['Coverage']>=covt], cur_ct_data[cm][cur_ct_data['Coverage']>=covt])
+                   ap_pr = average_precision_score(cur_ct_data['Methylation'][cur_ct_data['Coverage']>=covt], cur_ct_data[cm][cur_ct_data['Coverage']>=covt])
+                   plt.plot(recall, precision, mlinestyle[covt], lw=2, label='Coverage>=%d (AP=%0.3f)' % (covt, ap_pr))
+                   print('\t\t %s %d ap=%.5f' % (cur_fn, covt, ap_pr))
+              elif roc_or_pr==1:
+                xylab = ['False Positive Rate', 'True Positive Rate']; leg_mpos = "lower right"
+                cur_fn = basefig+'/roc_plot_met_roc_'+filename[cm_ind][ct_ind]+'.png'
+                prev = 0; prev_ind = -1
+                for covt in cov_plot_thr:
+                   fpr, tpr, mthr = roc_curve(cur_ct_data['Methylation'][cur_ct_data['Coverage']>=covt], cur_ct_data[cm][cur_ct_data['Coverage']>=covt])
+                   #print(','.join([str(np.round(t1, 5)) for t1 in mthr]))
+                   roc_auc = auc(fpr, tpr)
+                   if (not np.isnan(roc_auc)) and (abs(roc_auc - prev)>=0.02 or (covt>10 and abs(roc_auc - prev)>=0.005) or (cov_plot_thr.index(covt)-prev_ind>1 and abs(roc_auc - prev)>=0.005)):
+                      plt.plot(fpr, tpr, mlinestyle[covt], lw=2, label='Coverage>=%d (AUC=%0.3f)' % (covt, roc_auc)) 
+                      prev = roc_auc; prev_ind = cov_plot_thr.index(covt)
+                   if not np.isnan(roc_auc):
+                      print ('\t\t %s %d %.7f' % (cur_fn, covt, roc_auc))
+                plt.plot([0, 1], [0, 1])
+              if roc_or_pr>0:
+                plt.xlim([0.0, 1.0]);              plt.ylim([0.0, 1.0])
+                plt.xlabel(xylab[0]);              plt.ylabel(xylab[1])
+                plt.legend(loc=leg_mpos)
+                mfig.savefig(cur_fn, dpi=300);              plt.close(mfig)
+
+