a b/DeepMod_tools/cal_EcoliDetPerf.py
1
#!/usr/bin/env python
2
3
import os, sys, time
4
from collections import defaultdict
5
import glob
6
import copy
7
import numpy as np
8
9
import matplotlib
10
matplotlib.use('Agg')
11
import matplotlib.pyplot as plt
12
from itertools import cycle
13
14
from sklearn.metrics import roc_curve, auc
15
from sklearn.metrics import precision_recall_curve, average_precision_score
16
from sklearn.metrics import matthews_corrcoef
17
18
import rpy2.robjects as robjects
19
from rpy2.robjects.packages import importr
20
from pkg_resources import resource_string
21
22
from scipy.stats import binom
23
import copy
24
25
ggplot = importr('ggplot2')
26
importr('gridExtra')
27
importr('plyr')
28
29
na4com = {'A':'T', 'C':'G', 'T':'A', 'G':'C'}
30
31
def readFA(mfa, mpat='Cg', mposinpat=0, t_chr=None, t_start=None, t_end=None):
    """Scan a FASTA reference for a motif and index every position on both strands.

    Parameters
    ----------
    mfa : str
        Path to the FASTA reference.
    mpat : str
        Motif to search for, e.g. ``'Cg'``. It is upper-cased before matching.
    mposinpat : int
        0-based offset of the base of interest inside ``mpat``.
    t_chr : str or None
        Keep only this chromosome (first word of the ``>`` header); ``None`` keeps all.
    t_start, t_end : int or None
        Inclusive window of 0-based sequence indices to index; ``None`` means unbounded.

    Returns
    -------
    defaultdict
        ``{chrom: {(strand, i): [is_motif_site, forward_base]}}`` for every index
        ``i`` inside the window. ``is_motif_site`` is 1 when ``i`` is the base of
        interest of a motif occurrence on that strand, otherwise 0. The stored base
        is always the forward-strand base at ``i``, for both strands.
    """
    pat3 = mpat.upper()
    # Reverse complement of the motif, and where the base of interest lands in it.
    comp_pat3 = ''.join([na4com[curna] for curna in pat3][::-1])
    comp_mposinpat = len(comp_pat3) - 1 - mposinpat
    plen = len(pat3)

    # Collect sequence lines per kept chromosome. Chromosomes filtered out by
    # t_chr never get an entry, so only collected chromosomes are joined below.
    chunks = {}
    cur_chr = None
    with open(mfa, 'r') as mr:
        for line in mr:
            line = line.strip()
            if not line:
                continue
            if line[0] == '>':
                cur_chr = line[1:].split()[0]
                if t_chr in (None, cur_chr):
                    chunks[cur_chr] = []
            elif cur_chr in chunks:
                chunks[cur_chr].append(line)
    fadict = {chrom: ''.join(parts) for chrom, parts in chunks.items()}

    # NOTE(review): the sequence itself is not upper-cased, so soft-masked
    # (lowercase) motif occurrences are not matched against pat3.
    cpgdict = defaultdict(int)
    cpgnum = [0, 0]
    for fak, seq in fadict.items():
        sites = defaultdict()
        cpgdict[fak] = sites
        seqlen = len(seq)
        # Clamp the inclusive [t_start, t_end] window to valid indices so only
        # positions inside it are visited.
        first = 0 if t_start is None else max(t_start, 0)
        last = seqlen if t_end is None else min(t_end + 1, seqlen)
        for i in range(first, last):
            base = seq[i]
            fwd = i - mposinpat          # motif start if i is its base of interest on '+'
            rev = i - comp_mposinpat     # reverse-complement start if i is the '-' base of interest
            if fwd >= 0 and fwd + plen <= seqlen and seq[fwd:fwd + plen] == pat3:
                sites[('+', i)] = [1, base]
                cpgnum[0] += 1
                sites[('-', i)] = [0, base]
            elif rev >= 0 and rev + plen <= seqlen and seq[rev:rev + plen] == comp_pat3:
                sites[('+', i)] = [0, base]
                sites[('-', i)] = [1, base]
                cpgnum[1] += 1
            else:
                sites[('+', i)] = [0, base]
                sites[('-', i)] = [0, base]
    print('%s%d site: %d(+) %d(-)' % (pat3, mposinpat, cpgnum[0], cpgnum[1]))
    return cpgdict
75
76
def readmodf_dict(cpgdict, modf, pred_dict, mna, t_start=None, t_end=None):
    """Merge the per-site calls of one ``mod_pos`` BED file into ``pred_dict``.

    Each non-blank line is split on whitespace. The fields used are 0 (chromosome),
    1 (position), 3 (base), 5 (strand), 9 (coverage), 10 (methylation
    percentage) and 11 (methylated read count).

    Parameters
    ----------
    cpgdict : dict
        Site index as returned by ``readFA``. Every in-window record must have an
        entry ``cpgdict[chrom][(strand, pos)]``.
    modf : str
        Path to the BED file.
    pred_dict : dict
        Updated in place as ``{(chrom, pos, strand): [coverage, percent, methylated, base]}``.
        For a site that is already present, coverage and methylated counts are
        summed and the percentage is recomputed as ``int(methylated*100/coverage)``.
    mna : str
        Expected base. A mismatch with the file or the reference is printed,
        not raised.
    t_start, t_end : int or None
        Inclusive position window; records outside it are ignored.
    """
    with open(modf, 'r') as mr:
        for line in mr:
            lsp = line.split()
            if not lsp:
                continue
            cur_chr = lsp[0]
            cur_pos = int(lsp[1])
            cur_strand = lsp[5]

            cur_cov = int(lsp[9])
            cur_m_p = int(lsp[10])
            cur_m_c = int(lsp[11])

            # Just skip the record: the loop fetches the next line itself, so an
            # extra readline here would silently drop the following record.
            if (t_start is not None and cur_pos < t_start) or (t_end is not None and cur_pos > t_end):
                continue

            site = cpgdict[cur_chr][(cur_strand, cur_pos)]
            # The reference stores the forward-strand base; complement it for '-'.
            if mna != lsp[3] or lsp[3] != (site[1] if cur_strand == '+' else na4com[site[1]]):
                print('Error !! NA not equal %s == %s == %s %s' % (mna, lsp[3], site[1], modf))

            key = (cur_chr, cur_pos, cur_strand)
            if key not in pred_dict:
                pred_dict[key] = [cur_cov, cur_m_p, cur_m_c, lsp[3]]
            else:
                rec = pred_dict[key]
                rec[0] += cur_cov
                rec[2] += cur_m_c
                rec[1] = int(rec[2] * 100 / rec[0]) if rec[0] > 0 else 0
105
106
def add_from_dict(cpgdict, pred_dict, label, pred_list, mna, tp_fp_tn_fn, mpat='Cg', mposinpat=0):
    """Classify aggregated site predictions and tally read-level counts.

    For every ``(chrom, pos, strand) -> [coverage, percent, methylated, base]``
    entry of ``pred_dict`` (as built by ``readmodf_dict``), one tuple
    ``(row_label, coverage, percent, methylated, category, log_binom_pmf)`` is
    appended to ``pred_list``:

    * motif sites (``cpgdict[chrom][(strand, pos)][0] == 1``) get ``label`` and
      category ``mpat``;
    * other sites get label 0 and a category saying whether a motif site lies
      within 3 positions on the same strand (``<mpat>_n<d><base>`` when the
      site base is ``mpat[mposinpat]``, else ``<mpat>_nb``), or not
      (``Other<base>`` / ``Other``).

    ``log_binom_pmf`` is ``log(binom.pmf(methylated, coverage, 0.05))``.

    ``tp_fp_tn_fn`` is updated in place. For label-0 rows, methylated reads are
    added to index 1 and the rest to index 2. Otherwise, methylated reads are
    added to index 0 and the rest to index 3.

    ``mna`` is only used in the diagnostic message printed when a motif site
    carries an unexpected base.
    """
    mod_base = mpat[mposinpat]
    for (cur_chr, cur_pos, cur_strand), (cur_cov, cur_m_p, cur_m_c, lsp3) in pred_dict.items():
        sites = cpgdict[cur_chr]
        logp = np.log(binom.pmf(cur_m_c, cur_cov, 0.05))

        if sites[(cur_strand, cur_pos)][0] == 1:
            row_label = label
            pred_list.append((label, cur_cov, cur_m_p, cur_m_c, mpat, logp))
            if lsp3 != mod_base:
                # Three placeholders need three arguments; the position identifies the site.
                print('Error not methylated pos %s %s %s' % (mna, cur_strand, cur_pos))
        else:
            row_label = 0
            # Closest-first scan is not needed: the first hit from -3..3 defines the distance tag.
            near = None
            for off in range(-3, 4):
                key = (cur_strand, cur_pos + off)
                if key in sites and sites[key][0] == 1:
                    near = off
                    break
            if lsp3 == mod_base:
                base_info = mpat + '_n' + str(abs(near)) + mod_base if near is not None else 'Other' + mod_base
            else:
                base_info = mpat + '_nb' if near is not None else 'Other'
            pred_list.append((0, cur_cov, cur_m_p, cur_m_c, base_info, logp))

        if row_label == 0:
            tp_fp_tn_fn[2] += cur_cov - cur_m_c
            tp_fp_tn_fn[1] += cur_m_c
        else:
            tp_fp_tn_fn[0] += cur_m_c
            tp_fp_tn_fn[3] += cur_cov - cur_m_c
132
133
134
def readmodf(cpgdict, modf, label, pred_list, mna, tp_fp_tn_fn, mpat='Cg', mposinpat=0, t_start=None, t_end=None):
    """Classify every in-window record of one ``mod_pos`` BED file and tally counts.

    Each non-blank line is whitespace-split. The fields used are 0 (chromosome),
    1 (position), 3 (base), 5 (strand), 9 (coverage), 10 (methylation percentage)
    and 11 (methylated count). For each record inside the inclusive
    ``[t_start, t_end]`` window, a tuple
    ``(row_label, coverage, percent, methylated, category, log_binom_pmf)`` is
    appended to ``pred_list``:

    * motif sites (``cpgdict[chrom][(strand, pos)][0] == 1``) keep ``label`` and
      category ``mpat``;
    * other sites get label 0 and a category saying whether a motif site lies
      within 3 positions on the same strand (``<mpat>_n<d><base>`` /
      ``<mpat>_nb``) or not (``Other<base>`` / ``Other``).

    ``tp_fp_tn_fn`` is updated in place. For label-0 rows, methylated reads go to
    index 1 and the rest to index 2. Otherwise, methylated reads go to index 0 and
    the rest to index 3. Base mismatches are printed, not raised.
    """
    target = mpat[mposinpat]
    with open(modf, 'r') as handle:
        for raw in handle:
            text = raw.strip()
            if not text:
                continue
            fields = text.split()
            chrom, pos, strand = fields[0], int(fields[1]), fields[5]
            cov, pct, mcov = int(fields[9]), int(fields[10]), int(fields[11])

            # Guard clauses for the inclusive position window.
            if t_start is not None and pos < t_start:
                continue
            if t_end is not None and pos > t_end:
                continue

            ref_sites = cpgdict[chrom]
            site = ref_sites[(strand, pos)]
            # Reference keeps the forward base; the '-' strand is compared to its complement.
            if mna != fields[3] or fields[3] != (site[1] if strand == '+' else na4com[site[1]]):
                print('Error !! NA not equal %s == %s == %s %s' % (mna, fields[3], site[1], modf))

            logp = np.log(binom.pmf(mcov, cov, 0.05))
            if site[0] == 1:
                row_label = label
                pred_list.append((label, cov, pct, mcov, mpat, logp))
                if fields[3] != target:
                    print('Error not methylated pos %s %s %s' % (mna, strand, modf))
            else:
                row_label = 0
                offset = next(
                    (d for d in range(-3, 4)
                     if (strand, pos + d) in ref_sites and ref_sites[(strand, pos + d)][0] == 1),
                    None,
                )
                if fields[3] == target:
                    tag = 'Other' + target if offset is None else mpat + '_n' + str(abs(offset)) + target
                else:
                    tag = 'Other' if offset is None else mpat + '_nb'
                pred_list.append((0, cov, pct, mcov, tag, logp))

            unmethylated = cov - mcov
            if row_label == 0:
                tp_fp_tn_fn[1] += mcov
                tp_fp_tn_fn[2] += unmethylated
            else:
                tp_fp_tn_fn[0] += mcov
                tp_fp_tn_fn[3] += unmethylated
177
178
179
# ---------------------------------------------------------------------------
# Command-line driver.
#
#   argv[1] folder with positive-label mod_pos.*.<base>.bed predictions
#   argv[2] FASTA reference
#   argv[3] motif (e.g. Cg)          argv[4] 0-based offset of the base of interest
#   argv[5] chromosome ('' = all)    argv[6]/argv[7] start/end (<0 = unbounded)
#   argv[8] output folder for the ROC / precision-recall figures
#   argv[9] comma-separated folders with negative-label predictions
# ---------------------------------------------------------------------------
if len(sys.argv) < 10:
    print('Usage: %s sss_folder ref_fasta motif motif_pos chr start end fig_folder neg_folder[,neg_folder...]' % sys.argv[0])
    sys.exit(1)

sssfolder = sys.argv[1]
mreffile = sys.argv[2]
mpat = sys.argv[3]             # e.g. Cg
mposinpat = int(sys.argv[4])   # e.g. 0

chrofinterest = sys.argv[5] or None
stposofinterest = int(sys.argv[6])
if stposofinterest < 0:
    stposofinterest = None
edposofinterest = int(sys.argv[7])
if edposofinterest < 0:
    edposofinterest = None

basefig = sys.argv[8]
hastwoclass = 1

modbase = mpat[mposinpat]
bedglob = 'mod_pos.*.' + modbase + '.bed'

# Positive-label prediction files: the folder itself and up to two levels below it.
sssfiles = {modbase: glob.glob(os.path.join(sssfolder, bedglob))}
sssfiles[modbase].extend(glob.glob(os.path.join(sssfolder, '*', bedglob)))
sssfiles[modbase].extend(glob.glob(os.path.join(sssfolder, '*', '*', bedglob)))
print(str(len(sssfiles[modbase])) + " " + str(sssfolder))

## for negative;
umrfolders = sys.argv[9].split(',')
umrfiles = []
for cur_umr_f in umrfolders:
    if not os.path.isdir(cur_umr_f):
        print("No prediction folder {}".format(cur_umr_f))
        sys.exit(1)
    umrfiles.extend(glob.glob(os.path.join(cur_umr_f, '*', '*', bedglob)))
    umrfiles.extend(glob.glob(os.path.join(cur_umr_f, '*', bedglob)))
    umrfiles.extend(glob.glob(os.path.join(cur_umr_f, bedglob)))
print(str(len(umrfiles)) + "  " + str(umrfolders))
sys.stdout.flush()

for sa in sssfiles:
    print(sa)
    for nf in sssfiles[sa]:
        print('\t' + nf)

cpgdict = readFA(mreffile, mpat, mposinpat, chrofinterest, stposofinterest, edposofinterest)

# Negative-label predictions are merged per site across all files before classification.
pred_dict = defaultdict()
for modf in umrfiles:
    readmodf_dict(cpgdict, modf, pred_dict, modbase, stposofinterest, edposofinterest)

baseinfo = [mpat, mpat + '_n1' + modbase, mpat + '_n2' + modbase, mpat + '_n3' + modbase,
            'Other' + modbase, mpat + '_nb', 'Other']

classify_m = ['Methylation_Percentage']
classify_types = [baseinfo, [mpat]]
filename = [['all_mp', 'motif_mp']]
cov_thr = [1, 5]
mlinestyle = {1: 'bo-', 3: 'gx--', 5: 'r*-.', 7: 'cs-', 10: 'md--', 15: 'k+-.'}

pred_list = []
tp_fp_tn_fn = [0, 0, 0, 0]

add_from_dict(cpgdict, pred_dict, 0, pred_list, modbase, tp_fp_tn_fn, mpat, mposinpat)

for na4 in sssfiles:
    for cur_f in sssfiles[na4]:
        print('%s %s' % (na4, cur_f))
        sys.stdout.flush()
        readmodf(cpgdict, cur_f, hastwoclass, pred_list, na4, tp_fp_tn_fn, mpat, mposinpat,
                 stposofinterest, edposofinterest)
pred_list = np.array(pred_list, dtype=[('Methylation', np.uint), ('Coverage', np.uint64),
                                       ('Methylation_Percentage', np.uint64),
                                       ('Methylation_Coverage', np.uint64),
                                       ('BaseInfo', 'U20'), ('logp', np.float64)])

if hastwoclass == 1:
    cov_plot_thr = [1, 5]
    # savefig does not create missing folders; create it before any plotting work.
    if basefig:
        os.makedirs(basefig, exist_ok=True)
    for ct_ind, ct in enumerate(classify_types):
        cur_ct_data = pred_list[np.isin(pred_list['BaseInfo'], ct)]
        for cm_ind, cm in enumerate(classify_m):
            print('basetype={} classify_measure={}'.format(ct, cm))
            # 1: ROC curve, 2: precision-recall curve.
            for roc_or_pr in (1, 2):
                mfig = plt.figure()
                if roc_or_pr == 2:
                    cur_fn = os.path.join(basefig, 'ap_plot_met_pr_' + filename[cm_ind][ct_ind] + '.png')
                    xylab = ['Recall', 'Precision']
                    leg_mpos = "lower left"
                    for covt in cov_plot_thr:
                        sel = cur_ct_data['Coverage'] >= covt
                        y_true = cur_ct_data['Methylation'][sel]
                        y_score = cur_ct_data[cm][sel]
                        # sklearn raises ValueError on empty input; report and move on.
                        if y_true.size == 0:
                            print('\t\t %s %d skipped: no sites at this coverage' % (cur_fn, covt))
                            continue
                        precision, recall, thresholds = precision_recall_curve(y_true, y_score)
                        ap_pr = average_precision_score(y_true, y_score)
                        plt.plot(recall, precision, mlinestyle[covt], lw=2,
                                 label='Coverage>=%d (AP=%0.3f)' % (covt, ap_pr))
                        print('\t\t %s %d ap=%.5f' % (cur_fn, covt, ap_pr))
                else:
                    xylab = ['False Positive Rate', 'True Positive Rate']
                    leg_mpos = "lower right"
                    cur_fn = os.path.join(basefig, 'roc_plot_met_roc_' + filename[cm_ind][ct_ind] + '.png')
                    prev = 0
                    prev_ind = -1
                    for cov_ind, covt in enumerate(cov_plot_thr):
                        sel = cur_ct_data['Coverage'] >= covt
                        y_true = cur_ct_data['Methylation'][sel]
                        y_score = cur_ct_data[cm][sel]
                        if y_true.size == 0:
                            print('\t\t %s %d skipped: no sites at this coverage' % (cur_fn, covt))
                            continue
                        fpr, tpr, mthr = roc_curve(y_true, y_score)
                        roc_auc = auc(fpr, tpr)
                        # NaN AUC (single-class subset) is neither plotted nor printed.
                        if np.isnan(roc_auc):
                            continue
                        delta = abs(roc_auc - prev)
                        # Only draw a curve when it differs visibly from the last drawn one.
                        if delta >= 0.02 or (covt > 10 and delta >= 0.005) or (cov_ind - prev_ind > 1 and delta >= 0.005):
                            plt.plot(fpr, tpr, mlinestyle[covt], lw=2,
                                     label='Coverage>=%d (AUC=%0.3f)' % (covt, roc_auc))
                            prev = roc_auc
                            prev_ind = cov_ind
                        print('\t\t %s %d %.7f' % (cur_fn, covt, roc_auc))
                    plt.plot([0, 1], [0, 1])
                plt.xlim([0.0, 1.0])
                plt.ylim([0.0, 1.0])
                plt.xlabel(xylab[0])
                plt.ylabel(xylab[1])
                plt.legend(loc=leg_mpos)
                mfig.savefig(cur_fn, dpi=300)
                plt.close(mfig)
282
283