|
a |
|
b/graph_algorithm/RandomWalk.java |
|
|
1 |
package GraphAlgorithm; |
|
|
2 |
|
|
|
3 |
import java.io.BufferedWriter; |
|
|
4 |
import java.io.File; |
|
|
5 |
import java.io.FileWriter; |
|
|
6 |
import java.io.IOException; |
|
|
7 |
import java.util.ArrayList; |
|
|
8 |
import java.util.Collections; |
|
|
9 |
import java.util.Comparator; |
|
|
10 |
import java.util.HashMap; |
|
|
11 |
import java.util.LinkedHashMap; |
|
|
12 |
import java.util.LinkedList; |
|
|
13 |
import java.util.List; |
|
|
14 |
import java.util.Map; |
|
|
15 |
import java.util.Set; |
|
|
16 |
import network.CommGeneticsLoader; |
|
|
17 |
import network.DisGraph; |
|
|
18 |
import network.SparseMatrix; |
|
|
19 |
import network.SparseVector; |
|
|
20 |
import util.DCNOMIMUMLSIDmap; |
|
|
21 |
import util.FileToList; |
|
|
22 |
|
|
|
23 |
/** |
|
|
24 |
* Implement random walk with restart algorithm using sparse matrix data structure |
|
|
25 |
* @author zhengc |
|
|
26 |
* |
|
|
27 |
*/ |
|
|
28 |
public class RandomWalk { |
|
|
29 |
|
|
|
30 |
// static final double P = 0.15; |
|
|
31 |
// static final double P = 0.15; |
|
|
32 |
static final double err = 0.000001; |
|
|
33 |
private Map<String, Integer> generank; |
|
|
34 |
|
|
|
35 |
/** |
|
|
36 |
* Constructs a new RandomWalk object from a graph and seed vector |
|
|
37 |
* @param cgg a CommGeneticsGraph |
|
|
38 |
* @param E the seed vector |
|
|
39 |
* @param p restart rate |
|
|
40 |
*/ |
|
|
41 |
public RandomWalk(DisGraph cgg, SparseVector E, double p) { |
|
|
42 |
Map<String, Double> genescore = new HashMap<>(); |
|
|
43 |
int size = cgg.getNodes(); |
|
|
44 |
SparseVector x = new SparseVector(size); |
|
|
45 |
SparseMatrix A = cgg.getNet(); |
|
|
46 |
A = normalizeMatrix(A); |
|
|
47 |
int iter = 0; |
|
|
48 |
SparseVector last_x = E; |
|
|
49 |
while((last_x.plus(x.scale(-1)).norm()) > err) { |
|
|
50 |
last_x = x; |
|
|
51 |
x = A.times(x).scale(1-p); |
|
|
52 |
x = x.plus(E.scale(p)); |
|
|
53 |
// System.out.print(x); |
|
|
54 |
iter++; |
|
|
55 |
// System.out.println("Finish iteration " + iter + ":"); |
|
|
56 |
} |
|
|
57 |
System.out.println("Total iteration " + iter); |
|
|
58 |
|
|
|
59 |
genescore = svToMap(x, E); |
|
|
60 |
generank = scoreToRank(genescore); |
|
|
61 |
|
|
|
62 |
} |
|
|
63 |
|
|
|
64 |
/** |
|
|
65 |
* Get the gene-rank pairs from random walk with restart |
|
|
66 |
* @return the gene-rank pairs from random walk with restart |
|
|
67 |
*/ |
|
|
68 |
public Map<String, Integer> getRWRank() { |
|
|
69 |
return generank; |
|
|
70 |
} |
|
|
71 |
|
|
|
72 |
|
|
|
73 |
/** |
|
|
74 |
* Generate column/left stochastic matrix by normalization |
|
|
75 |
* @param comMatrix |
|
|
76 |
* @return normalized SparseMatrix |
|
|
77 |
*/ |
|
|
78 |
private static SparseMatrix normalizeMatrix(SparseMatrix comMatrix) { |
|
|
79 |
|
|
|
80 |
int n = comMatrix.size(); |
|
|
81 |
double[] total_w = new double[n]; //Total weight of a column |
|
|
82 |
SparseMatrix A = new SparseMatrix(n); |
|
|
83 |
|
|
|
84 |
for (int j = 0; j < n; j++){ |
|
|
85 |
total_w[j] = 0; |
|
|
86 |
for (int i = 0; i < n; i++){ |
|
|
87 |
total_w[j] += comMatrix.get(i, j); |
|
|
88 |
} |
|
|
89 |
} |
|
|
90 |
|
|
|
91 |
|
|
|
92 |
//normalize the adjacency matrix |
|
|
93 |
for (int i = 0; i < n; i++){ |
|
|
94 |
|
|
|
95 |
for (int j = 0; j < n; j++){ |
|
|
96 |
double norm_w = comMatrix.get(i, j); |
|
|
97 |
norm_w = norm_w * 1.0 / total_w[j]; |
|
|
98 |
A.put(i, j, norm_w); |
|
|
99 |
} |
|
|
100 |
} |
|
|
101 |
|
|
|
102 |
return A; |
|
|
103 |
} |
|
|
104 |
|
|
|
105 |
/** |
|
|
106 |
* Create a seed from a list |
|
|
107 |
* @param nodesids a list of seed nodes |
|
|
108 |
* @return the seed vector |
|
|
109 |
* @throws IOException |
|
|
110 |
*/ |
|
|
111 |
public static SparseVector createSeedVector(List<String> nodesids) throws IOException{ |
|
|
112 |
|
|
|
113 |
int n = CommGeneticsLoader.entry_index.size(); |
|
|
114 |
SparseVector E = new SparseVector(n); |
|
|
115 |
|
|
|
116 |
for (String id : nodesids) { |
|
|
117 |
if (CommGeneticsLoader.entry_index.containsKey(id)) { |
|
|
118 |
int index = CommGeneticsLoader.entry_index.get(id); |
|
|
119 |
// E.put(index, 1); |
|
|
120 |
E.put(index, 1.0 / nodesids.size()); |
|
|
121 |
} else { |
|
|
122 |
System.out.println(id + " is NOT included in our network!"); |
|
|
123 |
} |
|
|
124 |
} |
|
|
125 |
return E; |
|
|
126 |
} |
|
|
127 |
|
|
|
128 |
|
|
|
129 |
public static SparseVector createSeedVectorFromIdx(List<Integer> nodeidxs) { |
|
|
130 |
int n = CommGeneticsLoader.entry_index.size(); |
|
|
131 |
SparseVector E = new SparseVector(n); |
|
|
132 |
|
|
|
133 |
for (int idx : nodeidxs) { |
|
|
134 |
if (idx > n - 1) { |
|
|
135 |
System.out.println(idx + " is out of network!"); |
|
|
136 |
} else { |
|
|
137 |
E.put(idx, 1.0 / nodeidxs.size());; |
|
|
138 |
} |
|
|
139 |
} |
|
|
140 |
return E; |
|
|
141 |
} |
|
|
142 |
|
|
|
143 |
/** |
|
|
144 |
* Convert a SparseVector from random walk into a HashMap reversely ordered by value |
|
|
145 |
* @param sv the SparseVector representing gene rank score result |
|
|
146 |
* @param seeds_vector the seed vector |
|
|
147 |
* @return a HashMap reversely ordered by value |
|
|
148 |
*/ |
|
|
149 |
private static Map<String, Double> svToMap(SparseVector sv, SparseVector seeds_vector) { |
|
|
150 |
Map<String, Double> genemap = new HashMap<String,Double>(); |
|
|
151 |
// System.out.println(CommGeneticsLoader.num_disease); |
|
|
152 |
// System.out.println(CommGeneticsLoader.entry_list.size()); |
|
|
153 |
for (int i = CommGeneticsLoader.num_disease; i < CommGeneticsLoader.entry_list.size(); i++){ |
|
|
154 |
// double val_seed = seeds_vector.get(i); |
|
|
155 |
Set<Integer> seeds_key = seeds_vector.getKeys(); |
|
|
156 |
if (!seeds_key.contains(i)) { // Test if the node is one of seeds |
|
|
157 |
double val = sv.get(i); |
|
|
158 |
// System.out.println(index_entry.get(num_disease)); |
|
|
159 |
|
|
|
160 |
if (val != 0){ |
|
|
161 |
String gene = CommGeneticsLoader.index_entry.get(i); |
|
|
162 |
genemap.put(gene, val); |
|
|
163 |
} |
|
|
164 |
} |
|
|
165 |
} |
|
|
166 |
genemap = util.CollectionsEx.sortByValueR(genemap); |
|
|
167 |
return genemap; |
|
|
168 |
} |
|
|
169 |
|
|
|
170 |
/** |
|
|
171 |
* Convert a gene rank score into a gene rank HashMap reversely ordered by value |
|
|
172 |
* @param sortedgenemap a reversely sorted gene score map |
|
|
173 |
* @return a gene rank HashMap |
|
|
174 |
*/ |
|
|
175 |
private static Map<String, Integer> scoreToRank(Map<String, Double> sortedgenemap) { |
|
|
176 |
Map<String, Integer> generank = new HashMap<String,Integer>(); |
|
|
177 |
int rank = 0; |
|
|
178 |
for (String s : sortedgenemap.keySet()) { |
|
|
179 |
// System.out.println(sortedgenemap.get(s)); |
|
|
180 |
rank++; |
|
|
181 |
generank.put(s, rank); |
|
|
182 |
generank = util.CollectionsEx.sortByValue(generank); |
|
|
183 |
} |
|
|
184 |
|
|
|
185 |
return generank; |
|
|
186 |
} |
|
|
187 |
|
|
|
188 |
// System.out.println(genemap.size()); |
|
|
189 |
|
|
|
190 |
|
|
|
191 |
|
|
|
192 |
/** |
|
|
193 |
* Write gene rank scores into CSV file, file format is "Rank, Gene, Score" |
|
|
194 |
* @param sortedgenemap HashMap sorted by descend values |
|
|
195 |
* @param rankfile file to be written |
|
|
196 |
* @throws IOException |
|
|
197 |
*/ |
|
|
198 |
public static void saveScorefile(Map<String, Double> sortedgenemap, String rankfile) throws IOException { |
|
|
199 |
|
|
|
200 |
BufferedWriter bw = new BufferedWriter(new FileWriter(new File(rankfile))); |
|
|
201 |
int rank = 0; |
|
|
202 |
bw.write("Rank" + "," + "Gene" + "," + "Score" + "," + "\n"); |
|
|
203 |
for (String s : sortedgenemap.keySet()) { |
|
|
204 |
// System.out.println(sortedgenemap.get(s)); |
|
|
205 |
Double score = sortedgenemap.get(s); |
|
|
206 |
rank++; |
|
|
207 |
bw.write(rank + "," + s + "," + score + "," + "\n"); |
|
|
208 |
} |
|
|
209 |
bw.close(); |
|
|
210 |
} |
|
|
211 |
|
|
|
212 |
/** |
|
|
213 |
* Write gene rank file into a CSV file, file format is "Rank, Gene" |
|
|
214 |
* @param generank gene rank HashMap |
|
|
215 |
* @param rankfile the file to be written |
|
|
216 |
* @throws IOException |
|
|
217 |
*/ |
|
|
218 |
public static void saveRankfile(Map<String, Integer> generank, String rankfile) throws IOException { |
|
|
219 |
|
|
|
220 |
BufferedWriter bw = new BufferedWriter(new FileWriter(new File(rankfile))); |
|
|
221 |
bw.write("Rank" + "," + "Gene" + "," + "\n"); |
|
|
222 |
for (String s : generank.keySet()) { |
|
|
223 |
// System.out.println(sortedgenemap.get(s)); |
|
|
224 |
int rank = generank.get(s); |
|
|
225 |
bw.write(rank + "," + s + "," + "\n"); |
|
|
226 |
} |
|
|
227 |
bw.close(); |
|
|
228 |
} |
|
|
229 |
|
|
|
230 |
|
|
|
231 |
public static void main(String[] args) throws IOException { |
|
|
232 |
|
|
|
233 |
/* |
|
|
234 |
* Create additional DCN_OMIM map for give disease |
|
|
235 |
*/ |
|
|
236 |
String dcnmapfile = "./data/term_umls_id_diso"; |
|
|
237 |
String omimmapfile = "./data/OMIM_umls_id_diso"; |
|
|
238 |
util.DCNOMIMUMLSIDmap.createDCNIdNameMap(dcnmapfile); |
|
|
239 |
util.DCNOMIMUMLSIDmap.createOMIMIdNameMap(omimmapfile); |
|
|
240 |
|
|
|
241 |
String DCN_dis = "dementia"; |
|
|
242 |
String OMIM_pat = "alzheimer"; |
|
|
243 |
|
|
|
244 |
DCNOMIMUMLSIDmap domim = new DCNOMIMUMLSIDmap(DCN_dis, OMIM_pat); |
|
|
245 |
Map<String, List<String>> dcn_omim = domim.getDCNOMIMUMLSIDmap(); |
|
|
246 |
|
|
|
247 |
/* |
|
|
248 |
* Build the bipartite network |
|
|
249 |
*/ |
|
|
250 |
|
|
|
251 |
String commnetfile = "./results/fares_comm_net_lift_final_abbr.txt"; |
|
|
252 |
String ppifile = "./data/gene_gene_string_cut.txt"; |
|
|
253 |
String disgenefile = "./data/OMIM_disease_gene_umls_id_diso"; |
|
|
254 |
DisGraph cgg = null; |
|
|
255 |
try { |
|
|
256 |
cgg = CommGeneticsLoader.createCommGeneticsGraph(commnetfile, ppifile, disgenefile, dcn_omim); |
|
|
257 |
} catch (IOException e) { |
|
|
258 |
// TODO Auto-generated catch block |
|
|
259 |
e.printStackTrace(); |
|
|
260 |
} |
|
|
261 |
|
|
|
262 |
/* |
|
|
263 |
* Create seed vector |
|
|
264 |
*/ |
|
|
265 |
int n = cgg.getNodes(); |
|
|
266 |
|
|
|
267 |
List<String> seeds = new ArrayList<>(); |
|
|
268 |
String seed_dis1 = "dementia"; |
|
|
269 |
String seed_dis1_umls = null; |
|
|
270 |
if (util.DCNOMIMUMLSIDmap.dcnnameidmap.containsKey(seed_dis1)) { |
|
|
271 |
seed_dis1_umls = util.DCNOMIMUMLSIDmap.dcnnameidmap.get(seed_dis1); |
|
|
272 |
System.out.println(seed_dis1 + ": " + seed_dis1_umls); |
|
|
273 |
seeds.add(seed_dis1_umls); |
|
|
274 |
} else { |
|
|
275 |
System.out.println(seed_dis1 + " is NOT included in our network!"); |
|
|
276 |
} |
|
|
277 |
|
|
|
278 |
String seed_genes_file = "./data/AD_omim_genes.txt"; |
|
|
279 |
List<String> seed_genes = FileToList.createListFromfile(seed_genes_file); |
|
|
280 |
|
|
|
281 |
seeds.addAll(seed_genes); |
|
|
282 |
// diseases.add(disease2); |
|
|
283 |
SparseVector seeds_vec = new SparseVector(n); |
|
|
284 |
try { |
|
|
285 |
seeds_vec = createSeedVector(seeds); |
|
|
286 |
} catch (IOException e) { |
|
|
287 |
// TODO Auto-generated catch block |
|
|
288 |
e.printStackTrace(); |
|
|
289 |
} |
|
|
290 |
System.out.println("Seed vector: " + seeds_vec); |
|
|
291 |
|
|
|
292 |
/* |
|
|
293 |
* Random walk |
|
|
294 |
*/ |
|
|
295 |
|
|
|
296 |
RandomWalk rwrank = new RandomWalk(cgg, seeds_vec, 0.5); |
|
|
297 |
Map<String, Integer> result = rwrank.getRWRank(); |
|
|
298 |
|
|
|
299 |
/* |
|
|
300 |
* Process result and write to file |
|
|
301 |
*/ |
|
|
302 |
|
|
|
303 |
String rankfile = "./results/AD_novel_genes.csv"; |
|
|
304 |
try { |
|
|
305 |
saveRankfile(result, rankfile); |
|
|
306 |
} catch (IOException e) { |
|
|
307 |
// TODO Auto-generated catch block |
|
|
308 |
e.printStackTrace(); |
|
|
309 |
} |
|
|
310 |
} |
|
|
311 |
} |