a b/DeepMod_tools/hm_cluster_predict.py
1
#!/usr/bin/env python
2
3
import os, sys, time
4
from collections import defaultdict
5
import glob
6
7
import numpy as np
8
9
from scipy import stats
10
11
import locale
12
locale.setlocale(locale.LC_ALL, 'en_US')
13
14
import tensorflow as tf
15
16
# Number of feature rows sent to the network per sess.run() call when the
# per-chromosome data set is split into batches.
batch_size = 4096

# Minimum read coverage a site needs to be kept by readBed(); the value is
# also embedded in the path of the trained model that gets restored.
cov_thrd = 5
19
20
def readBed(bedfile, t_chr=None, t_start=None, t_end=None, min_cov=None):
    """Read per-site methylation fractions from a whitespace-delimited BED file.

    The first line of the file is always discarded (presumably a header or
    track line -- confirm against the input files).  Every following line
    longer than 20 characters must split into exactly 11 fields; field 1 is
    the chromosome, field 2 the start position, field 6 the strand, field 10
    the coverage and field 11 an integer methylation percentage.

    Args:
        bedfile: path of the BED file to read.
        t_chr: keep only this chromosome; None keeps every chromosome.
        t_start: keep only sites whose start position is >= this value;
            None disables the lower bound.
        t_end: keep only sites whose start position is <= this value;
            None disables the upper bound.
        min_cov: minimum coverage a site needs to be kept.  Defaults to the
            module-level ``cov_thrd``.  Sites with zero coverage are always
            dropped.

    Returns:
        A mapping ``(chromosome, strand, start_pos) -> methylation fraction``
        where the fraction is the percentage divided by 100 and rounded to
        three decimals.

    Raises:
        ValueError: if a data line does not have exactly 11 fields or one of
            its numeric fields is not an integer.
    """
    if min_cov is None:
        min_cov = cov_thrd

    print('read {}'.format(bedfile)); sys.stdout.flush()
    beddict = defaultdict()
    with open(bedfile, 'r') as bedreader:
        # The original reader consumed one line before its loop; keep that.
        next(bedreader, None)
        for line in bedreader:
            line = line.strip()
            if len(line) <= 20:
                # Too short to be a data record (blank or truncated line).
                continue

            mchr, start_pos, _, _, _, m_strand, _, _, _, true_cov, meth_perc = line.split()
            start_pos = int(start_pos)
            true_cov = int(true_cov)
            if true_cov < min_cov or true_cov == 0:
                continue
            meth_perc = round(int(meth_perc)/100.0, 3)

            # Region filter: chromosome first, then the optional bounds.
            if t_chr not in (None, mchr):
                continue
            if t_start is not None and start_pos < t_start:
                continue
            if t_end is not None and start_pos > t_end:
                continue

            beddict[(mchr, m_strand, start_pos)] = meth_perc
    return beddict
42
43
def readpredmod(predmodf, preddict, t_chr=None, t_start=None, t_end=None, cgCposdict=None):
    """Load per-site prediction records from a file into ``preddict`` in place.

    Every non-blank line is split on whitespace.  Field 1 is the chromosome,
    field 2 the position, field 6 the strand, field 10 the coverage, field 11
    an integer methylation percentage and field 12 the methylated count.

    Args:
        predmodf: path of the prediction file to read.
        preddict: mapping that is updated in place.  Keys are
            ``(chromosome, strand, position)``; values are lists
            ``[coverage, methylation fraction, methylated count, line]``
            where ``line`` is the stripped text of the first record seen for
            that site.
        t_chr: keep only this chromosome; None keeps every chromosome.
        t_start: keep only positions >= this value; None disables the bound.
        t_end: keep only positions <= this value; None disables the bound.
        cgCposdict: optional container of allowed site keys; when given,
            sites that are not in it are skipped.

    Returns:
        None.  Results are delivered through ``preddict``.

    Raises:
        IndexError: if a kept line has fewer than 12 fields.
        ValueError: if a numeric field is not an integer.
    """
    print('read {}'.format(predmodf)); sys.stdout.flush()
    with open(predmodf, 'r') as mr:
        for line in mr:
            line = line.strip()
            if not line:
                continue

            lsp = line.split()
            cur_chr = lsp[0]
            cur_pos = int(lsp[1])
            cur_strand = lsp[5]
            site = (cur_chr, cur_strand, cur_pos)

            if cgCposdict is not None and site not in cgCposdict:
                continue

            cur_cov = int(lsp[9])
            cur_m_p = int(lsp[10])
            cur_m_c = int(lsp[11])

            # Region filter: chromosome first, then the optional bounds.
            if t_chr not in (None, cur_chr):
                continue
            if t_start is not None and cur_pos < t_start:
                continue
            if t_end is not None and cur_pos > t_end:
                continue
            if cur_cov == 0:
                continue

            if site not in preddict:
                preddict[site] = [cur_cov, round(cur_m_p/100.0, 3), cur_m_c, line]
            else:
                # Same site seen again: pool the counts and recompute the
                # fraction from them.  The stored text line is left untouched.
                print("Warning_duplicate {}".format(predmodf))
                entry = preddict[site]
                entry[0] += cur_cov
                entry[2] += cur_m_c
                if entry[0] > 0:
                    entry[1] = round(entry[2]/float(entry[0]), 3)
73
74
75
76
# Command-line arguments:
#   argv[1]  prefix of the per-chromosome prediction files; inputs are read
#            from '<prefix>.<chr>.C.bed' and results are written to
#            '<prefix>_clusterCpG.<chr>.C.bed'
#   argv[2]  folder holding the per-chromosome 'motif_<chr>_C.bed' files
if len(sys.argv) < 3:
    sys.exit("Usage: {} <prediction_prefix> <motif_folder>".format(sys.argv[0]))

pred_file = sys.argv[1]+'.%s.C.bed'
save_file = sys.argv[1]+'_clusterCpG.%s.C.bed'
gmotfolder = sys.argv[2]

# Motif settings; only mpat is used below, to build the model path.
mpat = 'Cg'; mposinpat = 0
stposofinterest = None; edposofinterest = None

# Number of positions inspected on each side of a site when neighbouring
# sites are summarised into features.
nbsize = 25
# Path prefix of the trained model (TensorFlow '.meta' graph + checkpoint).
train_mod = 'DeepMod/train_mod/na12878_cluster_train_mod-keep_prob0.7-nb25-chr1/{}.cov{}.nb{}'.format(mpat, cov_thrd, nbsize)

# Chromosomes processed, in order: chr1..chr22, then chrX, chrY and chrM.
chrkeys = ["chr%d" % i for i in range(1, 23)] + ["chrX", "chrY", "chrM"]
92
93
94
# Restore the trained graph and, chromosome by chromosome, turn the loaded
# per-site predictions into feature rows, run them through the network and
# write each original record followed by the new percentage.
new_saver = tf.train.import_meta_graph(train_mod+'.meta')
print(new_saver); sys.stdout.flush()
with tf.Session() as sess:
    ckpt_dir = train_mod[:train_mod.rindex('/')+1]
    print("restore model: {} {}".format(train_mod+'.meta', ckpt_dir))
    print(new_saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))); sys.stdout.flush()

    # Tensors are looked up by the names they were saved under.
    mgraph = tf.get_default_graph()
    output = mgraph.get_tensor_by_name('output:0')
    X = mgraph.get_tensor_by_name('X:0')
    keep_prob = mgraph.get_tensor_by_name('keep_prob:0')

    for chrofinterest in chrkeys:
        cur_cg_pos = '%s/motif_%s_C.bed' % (gmotfolder, chrofinterest)
        if not os.path.isfile(cur_cg_pos):
            print("Warning_motif!!! no file {}".format(cur_cg_pos))
            continue
        if not os.path.isfile(pred_file % chrofinterest):
            print("Warning_pred!!! no file {}".format(pred_file % chrofinterest))
            continue

        # Motif sites of this chromosome, keyed by (chromosome, strand, position);
        # the motif file carries chromosome, position, strand in its first
        # three columns.
        cgposdict = defaultdict()
        with open(cur_cg_pos, 'r') as mr:
            for line in mr:
                lsp = line.split()
                if not lsp:
                    # Tolerate blank lines instead of failing on lsp[0].
                    continue
                cgposdict[(lsp[0], lsp[2], int(lsp[1]))] = True
        print("{}: read {} done! ".format(len(cgposdict), cur_cg_pos)); sys.stdout.flush()

        # Predictions restricted to this chromosome and to the motif sites.
        preddict = defaultdict()
        readpredmod(pred_file % chrofinterest, preddict, chrofinterest, cgCposdict=cgposdict)
        print("size={} vs {}".format(len(preddict), len(cgposdict))); sys.stdout.flush()

        # One 14-value feature row per site:
        #   [0]     methylation fraction of the site itself
        #   [1]     fraction of the paired site on the opposite strand (0 if absent)
        #   [2]     number of neighbouring sites found within +/- nbsize
        #   [3..13] share of those neighbours in each of 11 fraction bins
        #           (0.0, 0.1, ..., 1.0; nearest bin)
        train_data = []
        pdkeys = sorted(preddict.keys())
        for cspk in pdkeys:
            if cspk not in cgposdict:
                print("not in cpg warning!!! {} {}".format(chrofinterest, cspk))

            chrom, strand, pos = cspk
            if strand == '+':
                partner_pos = (chrom, '-', pos+1)
            else:
                partner_pos = (chrom, '+', pos-1)

            cur_x = [preddict[cspk][1], preddict[partner_pos][1] if partner_pos in preddict else 0]
            cur_x.extend([0]*12)

            # Debug output is limited to the first ten sites of a chromosome.
            verbose = len(train_data) < 10
            if verbose: print("test")

            for rpos in range(pos-nbsize, pos+nbsize+1):
                if rpos in (pos, partner_pos[2]):
                    continue
                # A position is counted once; the '+' strand takes precedence.
                for nb_strand in ('+', '-'):
                    nb_key = (chrom, nb_strand, rpos)
                    if nb_key in cgposdict and nb_key in preddict:
                        cur_x[int(preddict[nb_key][1]/0.1+0.5) + 3] += 1
                        cur_x[2] += 1
                        if verbose: print("\t\t{}: {}".format(nb_key, preddict[nb_key]))
                        break

            # Turn the bin counts into shares of the neighbour count.
            if cur_x[2] > 0:
                for i in range(3, len(cur_x)):
                    cur_x[i] = round(cur_x[i]/float(cur_x[2]), 3)
            if verbose: print('\t{}'.format(cur_x)); sys.stdout.flush()
            train_data.append(cur_x)

        if not train_data:
            # Nothing to predict; avoids indexing into empty lists below.
            print("Warning_empty!!! no usable site for {}".format(chrofinterest))
            continue

        print("format data: data={}; {}".format(len(train_data), len(train_data[0]))); sys.stdout.flush()

        n_batches = len(train_data)//batch_size if len(train_data) > batch_size else 2
        batch_data = np.array_split(train_data, n_batches)
        m_pred_new_per = []
        for batch in batch_data:
            if len(batch) == 0:
                # array_split yields an empty chunk when there is a single row.
                continue
            moutp = sess.run([output], feed_dict={X: batch, keep_prob: 1})
            for mpind in moutp:
                m_pred_new_per.extend(mpind)
        print("new per: {}, {}  {} {}".format(len(pdkeys), len(train_data), len(m_pred_new_per), m_pred_new_per[-1]))

        # Preview of at most the first ten results.
        for wind in range(min(10, len(pdkeys))):
            print("'{}' <{}> {}".format(preddict[pdkeys[wind]][-1], m_pred_new_per[wind], train_data[wind]))

        # Each output line is the original record plus the new percentage.
        with open(save_file % chrofinterest, 'w') as mwriter:
            for wind, key in enumerate(pdkeys):
                mwriter.write("{} {}\n".format(preddict[key][-1], int(m_pred_new_per[wind]*100)))
171
 
172